提交时间:2023-12-09 14:43:34

运行 ID: 24146

#include <bits/stdc++.h> using namespace std; #define int long long #pragma GCC optimize(2) const int MAXN=3010,Mod=998244353; struct Edge{int v,nx;}edge[MAXN<<1];int h[MAXN],CNT;void init(){memset(h,-1,sizeof(h));CNT=0;}; void add_side(int u,int v){edge[++CNT]={v,h[u]};h[u]=CNT;edge[++CNT]={u,h[v]};h[v]=CNT;} int n,ans,a[MAXN]; int Pow(int x,int y){ int rt=1; while(y){ if(y&1)rt=rt*x%Mod; x=x*x%Mod; y>>=1; } return rt; } int jc[MAXN],ny[MAXN]; int C(int x,int y){ if(y>x)return 0; return jc[x]*ny[y]%Mod*ny[x-y]%Mod; } int j2sm[MAXN]; bool vs[MAXN][MAXN]; void mj(int now,int tp,int lst,int frm){ vs[now][frm]=1; for(int i=h[now];i!=-1;i=edge[i].nx){ if(edge[i].v!=lst&&tp>0)mj(edge[i].v,tp-1,now,frm); } } void dfs(int now,int fa){ int sum=0,jj=0; for(int i=1;i<=n;i++){ sum+=vs[now][i]; jj+=(vs[now][i]==vs[fa][i]&&vs[fa][i]); } //cout<<now<<" "<<sum<<":"<<j2sm[sum]<<" "<<jj<<":"<<j2sm[jj]<<endl; ans=(ans+j2sm[sum]-j2sm[jj]+Mod)%Mod; for(int i=h[now];i!=-1;i=edge[i].nx){ if(edge[i].v!=fa)dfs(edge[i].v,now); } } signed main(){ init(); jc[0]=ny[0]=jc[1]=ny[1]=1; for(int i=2;i<=2000;i++){ jc[i]=jc[i-1]*i; ny[i]=Pow(jc[i],Mod-2); } scanf("%lld",&n); for(int i=1;i<=n;i++){ j2sm[i]=(Pow(3,i)-1)%Mod; } for(int i=1;i<=n;i++){ scanf("%lld",&a[i]); } for(int i=1;i<n;i++){ int u,v; scanf("%lld%lld",&u,&v); add_side(u,v); } for(int i=1;i<=n;i++){ mj(i,a[i],-1,i); } dfs(1,-1); printf("%lld",ans); return 0; }