提交时间:2023-12-14 21:07:21

运行 ID: 24206

// Centroid-decomposition solution (competitive-programming submission).
//
// Input (stdin): n; then a_1..a_n; then n-1 edges "x y" of a tree on 1..n.
// Every edge is subdivided by an extra vertex (ids n+1 .. 2n-1), so all
// distances double; a_i is doubled to stay in the same units.
// For every vertex y of the subdivided tree the program computes
//   res[y] = #{ original x : dist(x, y) <= a_x }
// and prints, modulo 998244353,
//     sum over original vertices v of (3^res[v] - 1)
//   - sum over edge vertices    e of (3^res[e] - 1).
// NOTE(review): presumably a "vertices minus edges" count of connected
// regions; the problem statement is not part of this file.

// Was "#pragma gcc optimize(2)": GCC only recognises the upper-case "GCC"
// namespace and silently ignored the lower-case form.
#pragma GCC optimize(2)

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <vector>

// The subdivided tree has m = 2n - 1 vertices, so these arrays support
// n up to roughly 2.5e5.
const int maxn = 5e5 + 10, mod = 998244353;

// Reads one (optionally negative) decimal integer from stdin.
// Returns 0 if EOF is reached before any digit (the old version, which kept
// the character in a `char`, spun forever at EOF).
inline int read() {
    int x = 0, sign = 1;
    int ch = std::getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x * sign;
}

int n, m;                   // n original vertices; m = total after subdivision
int dep[maxn];              // depth relative to the current traversal root
int pw3[maxn];              // pw3[k] = 3^k mod `mod`, for k = 0..n
int a[maxn];                // doubled radius of each original vertex
int res[maxn];              // res[y] = #original x with dist(x, y) <= a[x]
int siz[maxn], mxs[maxn];   // subtree size / largest remaining piece
int sum[maxn];              // suffix-count buffer used by add_coverage()
int par[maxn];              // BFS parent within the current traversal
std::vector<int> v[maxn];   // adjacency list of the subdivided tree
bool vis[maxn];             // vertex already used as a centroid
std::vector<int> P;         // vertices visited by the latest traversal

// Appends to P every non-removed vertex reachable from `root` without going
// through `fa`, setting dep[x] = dep[parent] + 1. The caller must set
// dep[root] first. Iterative BFS, so deep (path-like) trees cannot overflow
// the call stack.
void collect_depths(int root, int fa) {
    par[root] = fa;
    P.push_back(root);
    for (std::size_t head = P.size() - 1; head < P.size(); ++head) {
        const int u = P[head];  // copy: P may reallocate below
        for (int x : v[u]) {
            if (x == par[u] || vis[x]) continue;
            dep[x] = dep[u] + 1;
            par[x] = u;
            P.push_back(x);
        }
    }
}

// Appends the component of `root` (excluding `fa` and removed vertices) to P
// and fills siz[] (subtree size) and mxs[] (largest child subtree) for it.
// BFS order guarantees every child appears after its parent, so walking P
// backwards folds children into parents before the parents are read.
void compute_sizes(int root, int fa) {
    const std::size_t start = P.size();
    par[root] = fa;
    P.push_back(root);
    for (std::size_t head = start; head < P.size(); ++head) {
        const int u = P[head];
        siz[u] = 1;
        mxs[u] = 0;
        for (int x : v[u]) {
            if (x == par[u] || vis[x]) continue;
            par[x] = u;
            P.push_back(x);
        }
    }
    for (std::size_t i = P.size(); i-- > start + 1;) {  // skip the root itself
        const int x = P[i];
        siz[par[x]] += siz[x];
        mxs[par[x]] = std::max(mxs[par[x]], siz[x]);
    }
}

// For the vertices currently in P (with dep[] already set), adds
// sign * #{ original x in P : dep[x] + dep[y] <= a[x] } to res[y} for every y in P.
// sum[d] ends up as the number of sources with a[x] - dep[x] >= d; the slack
// is capped at |P|, which is never smaller than any dep[] in P.
void add_coverage(int sign) {
    const int sz = static_cast<int>(P.size());
    std::fill(sum, sum + sz + 1, 0);
    for (int x : P)
        if (x <= n && a[x] >= dep[x]) ++sum[std::min(sz, a[x] - dep[x])];
    for (int i = sz - 1; i >= 0; --i) sum[i] += sum[i + 1];
    for (int x : P) res[x] += sign * sum[dep[x]];
}

// Returns a centroid of the component containing `root` (a vertex whose
// largest remaining piece is minimal). Ties may resolve to any such vertex;
// the final counts do not depend on which centroid is picked.
int find_centroid(int root) {
    P.clear();
    compute_sizes(root, 0);
    const int total = siz[root];
    int best = root;  // P[0] == root, so best is updated before it is compared
    for (int x : P) {
        mxs[x] = std::max(mxs[x], total - siz[x]);
        if (mxs[x] < mxs[best]) best = x;
    }
    return best;
}

// Centroid decomposition: count every (source, target) pair through centroid
// c, subtract the pairs that stay inside one child subtree (they are counted
// again, correctly, deeper in the recursion), then recurse. Recursion depth
// is O(log m).
void decompose(int c) {
    vis[c] = true;
    dep[c] = 0;
    P.clear();
    collect_depths(c, 0);
    add_coverage(+1);
    for (int x : v[c]) {
        if (vis[x]) continue;
        P.clear();
        dep[x] = 1;  // same depths as in the pass above, measured from c
        collect_depths(x, c);
        add_coverage(-1);
        decompose(find_centroid(x));
    }
}

// Reads the input, builds the subdivided tree and fills res[].
void init() {
    n = read();
    m = n;
    // A radius >= n - 1 already reaches every vertex, so clamping to n keeps
    // the answer unchanged while preventing `* 2` from overflowing int.
    for (int i = 1; i <= n; ++i) a[i] = std::min(read(), n) * 2;
    // res[] never exceeds n, so powers up to n suffice (was a fixed 4e5).
    pw3[0] = 1;
    for (int i = 1; i <= n; ++i)
        pw3[i] = static_cast<int>(pw3[i - 1] * 3LL % mod);
    for (int i = 1; i < n; ++i) {
        const int x = read(), y = read();
        ++m;  // virtual vertex standing for edge (x, y)
        v[m].push_back(x);
        v[m].push_back(y);
        v[x].push_back(m);
        v[y].push_back(m);
    }
    decompose(find_centroid(1));
}

// Combines res[] into the final answer and prints it.
void slv() {
    init();
    int ans = 0;
    for (int i = 1; i <= m; ++i) {
        const int term = pw3[res[i]] - 1;  // in [0, mod - 2] because pw3 >= 1
        if (i <= n)
            ans = (ans + term) % mod;
        else
            ans = (ans - term + mod) % mod;
    }
    std::cout << ans << '\n';
}

int main() {
    // freopen("tree.in", "r", stdin);
    // freopen("tree.out", "w", stdout);
    // No explicit fclose(stdout): std::cout is flushed at exit, and closing
    // stdout first would leave that flush operating on a closed FILE*.
    slv();
    return 0;
}