// 提交时间 (submission time): 2023-12-09 11:31:32
// 运行 ID (run ID): 24111
// Centroid-decomposition solution on a tree.
//
// Input (stdin): n, then a[1..n], then n-1 undirected edges "u v".
// read() only understands non-negative decimal integers.
//
// Phase 1 (solve) computes, for every vertex i,
//     cnt[i] = #{ j : dist(i, j) <= a[j] }        (j = i is included)
// Phase 2 (main), with the tree rooted at vertex 1, prints modulo P
//     sum_i 3^cnt[i]  -  sum_{i != root} 3^x[i]  -  1,
// where x[i] = cnt[i] minus the vertices j in subtree(i) with a[j] == dist(j, i),
// i.e. the j that satisfy the condition for both i and its parent.
#include <algorithm>
#include <bitset>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;
#define int long long

// Reads one non-negative integer from stdin, skipping any non-digit prefix.
// Returns 0 if EOF is reached before a digit (the original spun forever there).
inline int read() {
    int ch = getchar(), r = 0;
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    while (ch >= '0' && ch <= '9') r = r * 10 + (ch - '0'), ch = getchar();
    return r;
}

const int N = 200100, P = 998244353;

// a^b mod P by binary exponentiation (expects b >= 0).
inline int fpow(int a, int b) {
    int c = 1;
    for (; b; b >>= 1, a = a * a % P)
        if (b & 1) c = c * a % P;
    return c;
}

int n, a[N];
// Edge lists as linked lists: edge ids start at 2 (1 << 1), so 0 ends a list.
int head[N], to[N << 1], nex[N << 1];
// Loop over edges x leaving nd, executing the following statement only if y holds.
#define forson(x, nd, y) for (int x = head[nd]; x; x = nex[x]) if (y)

// Filled by dfs(1): parent (0 for the root), preorder index (dfn[0] is the
// running counter), depth (root = 0) and subtree size.
int F[N], dfn[N], dep[N], S[N];
// v[k] holds the preorder indices of vertices u with max(0, dep[u] - a[u] + 1) == k.
// Indices are appended in preorder, so every list is sorted ascending.
vector<int> v[N];

// Roots the tree at nd and fills F, dfn, dep, S and v for its subtree.
void dfs(int nd) {
    dfn[nd] = ++dfn[0];
    S[nd] = 1;
    v[max(0ll, dep[nd] - a[nd] + 1)].emplace_back(dfn[nd]);
    forson(i, nd, to[i] != F[nd]) {
        F[to[i]] = nd;
        dep[to[i]] = dep[nd] + 1;
        dfs(to[i]);
        S[nd] += S[to[i]];
    }
}

// Number of elements of the sorted list V lying in the half-open range [l, r).
inline int count_range(const vector<int>& V, int l, int r) {
    return lower_bound(V.begin(), V.end(), r) - lower_bound(V.begin(), V.end(), l);
}

// Reads the input, builds the adjacency lists and roots the tree at vertex 1.
void init() {
    cin >> n;
    for (int i = 1; i <= n; i++) a[i] = read();
    for (int i = 1; i < n; i++) {
        int u = read(), w = read();
        to[i << 1] = w;
        nex[i << 1] = head[u];
        head[u] = i << 1;
        to[i << 1 | 1] = u;
        nex[i << 1 | 1] = head[w];
        head[w] = i << 1 | 1;
    }
    dfs(1);
}

int cnt[N];
int siz[N];
bool vis[N];  // true once a vertex has been used as a centroid

// Subtree sizes of the current (not yet decomposed) component, rooted at nd.
void get_siz(int nd, int fa) {
    siz[nd] = 1;
    forson(i, nd, to[i] != fa && !vis[to[i]]) get_siz(to[i], nd), siz[nd] += siz[to[i]];
}

// Returns a vertex of the component of size tot whose largest remaining piece
// has at most tot/2 vertices, searching the branch below nd; 0 if none here.
int barycenter(int nd, int fa, int tot) {
    int mx = tot - siz[nd];
    forson(i, nd, to[i] != fa && !vis[to[i]]) mx = max(mx, siz[to[i]]);
    if (mx * 2 <= tot) return nd;
    forson(i, nd, to[i] != fa && !vis[to[i]]) {
        int p = barycenter(to[i], nd, tot);
        if (p) return p;
    }
    return 0;
}

// Fenwick tree over keys 0..n stored at t[1..n+1]. It is "reversed": add()
// walks toward smaller indices and get_sum() toward larger ones, so
// get_sum(i) is the total weight of all keys >= i.
int t[N];
inline void add(int i, int k) { i++; while (i) t[i] += k, i -= i & -i; }
inline int get_sum(int i) { i++; int r = 0; while (i <= n + 1) r += t[i], i += i & -i; return r; }

// For every vertex u reachable from nd without crossing fa or a centroid,
// at distance d(u) = d + depth below nd, and with d(u) <= a[u], adds k at key
// a[u] - d(u) (the part of a[u] left after reaching the centroid).
// The key is clamped to n: queries only use get_sum(d) with d <= n-1, so
// every key >= n counts the same. An unclamped a[u] > n would write past t[].
void fenwick_fill(int nd, int fa, int d, int k) {
    if (d <= a[nd]) add(min(a[nd] - d, n), k);
    forson(i, nd, !vis[to[i]] && to[i] != fa) fenwick_fill(to[i], nd, d + 1, k);
}

// For every vertex u in the same region (distance d(u) as above), adds
// s * #{keys >= d(u)} to cnt[u].
void update(int nd, int fa, int d, int s) {
    cnt[nd] += s * get_sum(d);
    forson(i, nd, !vis[to[i]] && to[i] != fa) update(to[i], nd, d + 1, s);
}

// Centroid decomposition. At centroid c, each pair (u, j) of the component is
// counted when dist(u, c) + dist(c, j) <= a[j]. Pairs inside one child subtree
// of c are then subtracted, so each pair is counted exactly once, at the
// centroid that separates it. The tree t is emptied again after every pass.
void solve(int nd) {
    get_siz(nd, 0);
    nd = barycenter(nd, 0, siz[nd]);
    vis[nd] = true;

    fenwick_fill(nd, 0, 0, 1);
    update(nd, 0, 0, 1);
    fenwick_fill(nd, 0, 0, -1);

    forson(i, nd, !vis[to[i]]) {
        fenwick_fill(to[i], nd, 1, 1);
        update(to[i], nd, 1, -1);
        fenwick_fill(to[i], nd, 1, -1);
    }

    forson(i, nd, !vis[to[i]]) solve(to[i]);
}

signed main() {
    // freopen("tree.in", "r", stdin);
    init();
    solve(1);
    int ans = 0;
    for (int i = 1; i <= n; i++) {
        ans = (ans + fpow(3, cnt[i])) % P;
        if (F[i]) {
            // Vertices j in subtree(i) with a[j] == dist(j, i) reach i but not F[i].
            int x = cnt[i] - count_range(v[dep[i] + 1], dfn[i], dfn[i] + S[i]);
            ans = (ans - fpow(3, x)) % P;
        } else {
            ans--;
        }
    }
    cout << (ans + P) % P;
    return 0;
}