提交时间:2025-10-18 16:56:37
运行 ID: 38699
#include<bits/stdc++.h> using namespace std; int n,in[1000005],f[1000005],sum,ii,ss,ans,ds; bool b[1000005]; vector<int> p[1000005]; queue<int> q; int dfs(int x,int fa){ int ans=1; for(int i=0;i<p[x].size();i++){ if(p[x][i]==fa) continue; ans+=dfs(p[x][i],x); } return ans; } void dfs2(int x,int fa,int lenn){ if(lenn>ans&&b[x]) { ans=lenn,ii=x; } for(int i=0;i<p[x].size();i++){ if(p[x][i]!=fa) dfs2(p[x][i],x,lenn+1); }//cout<<x<<' '<<fa<<'\n'; } void dfs3(int x,int fa,int lenn){ f[x]=fa; if(lenn>ans&&b[x]) { ans=lenn,ii=x; } for(int i=0;i<p[x].size();i++){ if(p[x][i]!=fa) dfs3(p[x][i],x,lenn+1); } } int main(){ freopen("flower.in","r",stdin); // freopen("flower.out","w",stdout); cin>>n; for(int i=1;i<=n-1;i++){ int v; cin>>v; p[i+1].push_back(v); p[v].push_back(i+1); in[v]++; in[i+1]++; } for(int i=1;i<=n;i++){ if(in[i]>2) sum++,q.push(i); } if(sum==0){ cout<<n*1ll*(n-1)/2; return 0; } else if(sum==1&&in[q.front()]>3){ cout<<0; return 0; } else if(sum==1&&in[q.front()]==3){ if(n==4){ cout<<3; return 0; } int x=q.front(); long long y,z,e; y=dfs(p[x][0],x)*1ll; z=dfs(p[x][1],x)*1ll; e=dfs(p[x][2],x)*1ll; cout<<y*z+z*e+y*e; return 0; } else if(in[q.front()]==3){ ds=q.size(); while(q.size()>=2){ b[q.front()]=1; q.pop(); } int x=q.front(); b[x]=1; ans=0; dfs2(x,0,0);//cout<<x<<' '; ss=ii; dfs3(ii,0,0); int suum=2; while(f[ii]!=ss){ if(b[f[ii]]) suum++; ii=f[ii]; } if(suum!=ds) cout<<0; } return 0; }