提交时间:2023-12-09 14:27:07
运行 ID: 24144
#include <bits/stdc++.h> using namespace std; #define int long long const int MAXN=3010,Mod=998244353; struct Edge{int v,nx;}edge[MAXN<<1];int h[MAXN],CNT;void init(){memset(h,-1,sizeof(h));CNT=0;}; void add_side(int u,int v){edge[++CNT]={v,h[u]};h[u]=CNT;edge[++CNT]={u,h[v]};h[v]=CNT;} int n,ans,a[MAXN]; int Pow(int x,int y){ int rt=1; while(y){ if(y&1)rt=rt*x%Mod; x=x*x%Mod; y>>=1; } return rt; } int jc[MAXN],ny[MAXN]; int C(int x,int y){ if(y>x)return 0; return jc[x]*ny[y]%Mod*ny[x-y]%Mod; } int j2sm[MAXN]; bool favs[MAXN],vs[MAXN]; void mj(int now,int tp,int lst){ vs[now]=1; for(int i=h[now];i!=-1;i=edge[i].nx){ if(edge[i].v!=lst&&tp>0)mj(edge[i].v,tp-1,now); } } void dfs(int now,int fa){ for(int i=1;i<=n;i++){ favs[i]=vs[i]; vs[i]=0; } mj(now,a[now],-1); int sum=0,jj=0; for(int i=1;i<=n;i++){ sum+=vs[i]; jj+=(vs[i]==favs[i]&&favs[i]); } //cout<<now<<" "<<sum<<" "<<jj<<endl; ans=(ans+j2sm[sum]-j2sm[jj]+Mod)%Mod; for(int i=h[now];i!=-1;i=edge[i].nx){ if(edge[i].v!=fa)dfs(edge[i].v,now); } } signed main(){ init(); jc[0]=ny[0]=jc[1]=ny[1]=1; for(int i=2;i<=2000;i++){ jc[i]=jc[i-1]*i; ny[i]=Pow(jc[i],Mod-2); } scanf("%lld",&n); for(int i=1;i<=n;i++){ j2sm[i]=(Pow(3,i)-1)%Mod; } for(int i=1;i<=n;i++){ scanf("%lld",&a[i]); } for(int i=1;i<n;i++){ int u,v; scanf("%lld%lld",&u,&v); add_side(u,v); } dfs(1,-1); printf("%lld",ans); return 0; }