提交时间:2023-12-09 08:32:00

运行 ID: 24039

#include <bits/stdc++.h> using namespace std; #define int long long const int MAXN=3010,Mod=998244353; struct Edge{int v,nx;}edge[MAXN<<1];int h[MAXN],CNT;void init(){memset(h,-1,sizeof(h));CNT=0;}; void add_side(int u,int v){edge[++CNT]={v,h[u]};h[u]=CNT;edge[++CNT]={u,h[v]};h[v]=CNT;} bool b[MAXN]; int n,ans,a[MAXN],vs[MAXN]; void check(int now,int lst,int tp){ vs[now]++; for(int i=h[now];i!=-1;i=edge[i].nx){ if(tp>0&&edge[i].v!=lst)check(edge[i].v,now,tp-1); } } int Pow(int x,int y){ int rt=1; while(y){ if(y&1)rt=rt*x%Mod; x=x*x%Mod; y>>=1; } return rt; } int jc[MAXN],ny[MAXN]; int C(int x,int y){ if(y>x)return 0; return jc[x]*ny[y]%Mod*ny[x-y]%Mod; } int j2sm[MAXN]; void dfs(int now){ if(now==n+1){ int sum=0; for(int i=1;i<=n;i++)vs[i]=0; for(int i=1;i<=n;i++){ if(b[i]){ sum++; check(i,-1,a[i]); } } if(sum==0)return; for(int i=1;i<=n;i++){ if(vs[i]==sum){ ans=(ans+Pow(2,sum))%Mod; break; } } return; } b[now]=1; dfs(now+1); b[now]=0; dfs(now+1); } int d[MAXN]; bool lian=1,dqw1=1; struct Pt{ int dn,a,id; }pt[MAXN]; int cnt; void dfn(int now,int lst){ pt[now].dn=++cnt; for(int i=h[now];i!=-1;i=edge[i].nx){ if(edge[i].v!=lst)dfn(edge[i].v,now); } } bool cmp(Pt x,Pt y){ return x.dn<y.dn; } struct Po{ int x,q; }op[MAXN]; bool cmp1(Po x,Po y){ return x.x<y.x; } signed main(){ init(); jc[0]=ny[0]=jc[1]=ny[1]=1; for(int i=2;i<=2000;i++){ jc[i]=jc[i-1]*i; ny[i]=Pow(jc[i],Mod-2); } scanf("%lld",&n); for(int i=1;i<=n;i++){ j2sm[i]=(Pow(3,i)-1)%Mod; } for(int i=1;i<=n;i++){ scanf("%lld",&a[i]);pt[i].a=a[i];pt[i].id=i; if(a[i]!=1)dqw1=0; } for(int i=1;i<n;i++){ int u,v; scanf("%lld%lld",&u,&v); add_side(u,v);d[u]++;d[v]++; if(d[u]>2||d[v]>2)lian=0; } if(n<=8){ dfs(1); } else if(dqw1){ for(int i=1;i<=n;i++){ ans=(ans+Pow(2,d[i]))%Mod; } } else if(lian){ for(int i=1;i<=n;i++){ if(d[i]==1){ dfn(i,-1); break; } } sort(pt+1,pt+n+1,cmp); for(int i=1;i<=n;i++){ op[i]={max((int)1,i-pt[i].a),1}; op[i+n]={min(n,i+pt[i].a+1),-1}; } sort(op+1,op+2*n+1,cmp1); int jian=0,zong=0,lst=0; for(int i=1;i<=n*2;i++){ if(op[i].q==-1)jian--; zong+=op[i].q; if(op[i].x!=op[i+1].x||i==n){ ans=(ans+j2sm[lst+zong]-j2sm[lst+jian])%Mod; lst=lst+zong; jian=0;zong=0; } } } printf("%lld",ans); return 0; }