提交时间:2025-10-18 16:35:01

运行 ID: 38694

// Tree pair-counting solution.
//
// Input : n, then n-1 integers; the i-th of them (i = 1..n-1) is the vertex that
//         vertex i+1 is joined to, so together they list the tree's n-1 edges.
// Output: one integer, the count computed by countPairs().
#include <array>
#include <cstddef>
#include <iostream>
#include <vector>

namespace {

// The tree re-rooted at one chosen vertex. Built without recursion so that
// long path-like trees (~10^6 vertices) cannot overflow the call stack.
struct RootedTree {
    std::vector<int> parent;      // parent[root] == 0 (vertices are 1-based)
    std::vector<int> depth;       // depth[root] == 0
    std::vector<int> branch;      // index into adj[root] of the root-child whose subtree holds v; -1 for root
    std::vector<long long> size;  // number of vertices in v's subtree
};

// Roots the tree given by adjacency lists `adj` (vertices 1..n) at `root`.
RootedTree rootTree(const std::vector<std::vector<int>>& adj, int n, int root) {
    RootedTree t;
    t.parent.assign(n + 1, 0);
    t.depth.assign(n + 1, 0);
    t.branch.assign(n + 1, -1);
    t.size.assign(n + 1, 1);

    // Explicit-stack DFS; `order` receives every vertex after its parent.
    std::vector<int> order;
    order.reserve(static_cast<std::size_t>(n));
    std::vector<int> pending{root};
    while (!pending.empty()) {
        const int x = pending.back();
        pending.pop_back();
        order.push_back(x);
        for (std::size_t i = 0; i < adj[x].size(); ++i) {
            const int v = adj[x][i];
            if (v == t.parent[x]) continue;
            t.parent[v] = x;
            t.depth[v] = t.depth[x] + 1;
            t.branch[v] = (x == root) ? static_cast<int>(i) : t.branch[x];
            pending.push_back(v);
        }
    }

    // Children come after their parents in `order`, so a reverse sweep
    // accumulates subtree sizes bottom-up.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        if (t.parent[*it] != 0) t.size[t.parent[*it]] += t.size[*it];
    }
    return t;
}

// Counts unordered vertex pairs {u, v}, u != v, for the tree on vertices 1..n
// whose edges are (w, joinedTo[w]) for w = 2..n. Returns:
//   * 0 if some vertex has degree > 3;
//   * n(n-1)/2 if no vertex has degree 3;
//   * otherwise the number of pairs whose u-v path contains every degree-3
//     vertex as an interior (non-endpoint) vertex.
// NOTE(review): the problem statement is not in this file; the last bullet is
// inferred from the formulas the solution uses.
// Assumes 1 <= joinedTo[w] <= n and that the edges form a tree.
long long countPairs(int n, const std::vector<int>& joinedTo) {
    if (n <= 1) return 0;

    std::vector<std::vector<int>> adj(n + 1);
    std::vector<int> deg(n + 1, 0);
    for (int w = 2; w <= n; ++w) {
        const int x = joinedTo[w];
        adj[w].push_back(x);
        adj[x].push_back(w);
        ++deg[w];
        ++deg[x];
    }

    int degree3Count = 0;
    int root = 0;  // any degree-3 vertex works as root; this keeps the last one found
    for (int v = 1; v <= n; ++v) {
        if (deg[v] > 3) return 0;
        if (deg[v] == 3) {
            ++degree3Count;
            root = v;
        }
    }
    if (root == 0) return 1LL * n * (n - 1) / 2;

    const RootedTree t = rootTree(adj, n, root);

    // Deepest degree-3 vertex inside each of the root's three branches (0 = none).
    std::array<int, 3> deepest{0, 0, 0};
    for (int v = 1; v <= n; ++v) {
        if (v == root || deg[v] != 3) continue;
        int& best = deepest[t.branch[v]];
        if (best == 0 || t.depth[v] > t.depth[best]) best = v;
    }

    // Every degree-3 vertex must lie on the chains from those deepest vertices
    // up to the root; otherwise no single path can contain them all.
    int onChains = 1;  // the root itself
    for (const int start : deepest) {
        for (int v = start; v != 0 && v != root; v = t.parent[v]) {
            if (deg[v] == 3) ++onChains;
        }
    }
    if (onChains != degree3Count) return 0;

    // Endpoint choices per branch: strictly below its deepest degree-3 vertex
    // (so that vertex is interior), or anywhere in the branch if it has none.
    // Reading the saved `deepest` vertex here (not the climbing cursor) is the
    // fix: the old code used siz[rt] - 1 == n - 1 after its climb loops.
    std::array<long long, 3> choices{};
    for (int b = 0; b < 3; ++b) {
        choices[b] = deepest[b] != 0 ? t.size[deepest[b]] - 1
                                     : t.size[adj[root][b]];
    }

    // The two endpoints sit in two different branches; the branch left out
    // must contain no degree-3 vertex (it would hang off the path).
    long long answer = 0;
    for (int a = 0; a < 3; ++a) {
        for (int b = a + 1; b < 3; ++b) {
            const int leftOut = 3 - a - b;
            if (deepest[leftOut] == 0) answer += choices[a] * choices[b];
        }
    }
    return answer;
}

}  // namespace

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    std::cin >> n;
    std::vector<int> joinedTo(n > 0 ? n + 1 : 1, 0);
    for (int w = 2; w <= n; ++w) std::cin >> joinedTo[w];

    std::cout << countPairs(n, joinedTo) << '\n';
    return 0;
}