提交时间:2023-12-14 20:53:51

运行 ID: 24204

// Tree-coverage counting via centroid decomposition.
//
// Input (stdin): n, then n integers a[1..n], then n-1 edges "x y" (1-based ids).
// Output (stdout): one integer modulo 998244353.
//
// Every edge is subdivided by a midpoint node (ids n+1..2n-1) and every radius is
// doubled, so distances in the subdivided tree are exactly twice the original ones.
// Original vertex j "covers" node u when dis(u, j) <= a[j]; a midpoint is therefore
// covered exactly when both endpoints of its edge lie within j's original radius.
// With c(u) = number of vertices covering u, the program prints
//     sum over vertices (3^c(v) - 1)  -  sum over edges (3^c(e) - 1)   (mod p).
// NOTE(review): this looks like the usual "vertices minus edges" identity for
// counting connected pieces in a tree; confirm against the original problem.

#include <algorithm>
#include <climits>
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <utility>
#include <vector>

namespace {

constexpr int kMaxNodes = 500000 + 10;  // n vertices + (n - 1) midpoints must fit
constexpr int kMod = 998244353;
constexpr int kLog = 20;                // 2^19 exceeds any possible depth

int n;                        // original vertices: ids 1..n
int m;                        // all nodes of the subdivided tree: ids 1..m
int a[kMaxNodes];             // doubled (and clamped) radius of each original vertex
int pw3[kMaxNodes];           // pw3[k] = 3^k mod kMod, for k in [0, n]
int dep[kMaxNodes];           // depth when rooted at node 1 (root has depth 1)
int fa[kMaxNodes][kLog];      // binary-lifting ancestors; 0 above the root
int fa2[kMaxNodes];           // parent in the centroid tree; 0 for its root
int res[kMaxNodes];           // number of original vertices covering each node
int siz[kMaxNodes];           // component-local subtree size (scratch)
int mxs[kMaxNodes];           // largest child subtree inside the component (scratch)
int bfs_par[kMaxNodes];       // component-local BFS parent (scratch)
bool removed[kMaxNodes];      // node has already been chosen as a centroid
std::vector<int> adj[kMaxNodes];
std::vector<int> bucket[kMaxNodes];  // per-centroid remaining radii, sorted before queries
std::vector<int> comp;               // nodes of the component currently being split

// Reads one (optionally negative) decimal integer from stdin; returns 0 at EOF
// instead of spinning forever.
int read() {
    int ch = std::getchar();
    int sign = 1;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    int x = 0;
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x * sign;
}

// Roots the subdivided tree at node 1 and fills dep[] and the binary-lifting
// table. Iterative, so a long path cannot overflow the call stack.
void root_tree() {
    std::vector<int> pending;
    pending.push_back(1);
    dep[1] = 1;
    fa[1][0] = 0;
    while (!pending.empty()) {
        const int u = pending.back();
        pending.pop_back();
        for (const int v : adj[u]) {
            if (v == fa[u][0]) continue;
            dep[v] = dep[u] + 1;
            fa[v][0] = u;
            pending.push_back(v);
        }
    }
    for (int k = 1; k < kLog; ++k)
        for (int v = 1; v <= m; ++v) fa[v][k] = fa[fa[v][k - 1]][k - 1];
}

// Lowest common ancestor of x and y in the tree rooted at node 1.
int lca(int x, int y) {
    if (dep[x] < dep[y]) std::swap(x, y);
    for (int k = kLog - 1; k >= 0; --k)
        if (dep[fa[x][k]] >= dep[y]) x = fa[x][k];  // dep[0] == 0 stops overshoot
    if (x == y) return x;
    for (int k = kLog - 1; k >= 0; --k) {
        if (fa[x][k] != fa[y][k]) {
            x = fa[x][k];
            y = fa[y][k];
        }
    }
    return fa[x][0];
}

// Path length between x and y in the subdivided tree.
int dis(int x, int y) { return dep[x] + dep[y] - 2 * dep[lca(x, y)]; }

// Returns a centroid of the component of non-removed nodes containing `start`.
int find_centroid(int start) {
    comp.clear();
    comp.push_back(start);
    bfs_par[start] = 0;
    for (std::size_t k = 0; k < comp.size(); ++k) {  // index loop: comp grows while scanned
        const int u = comp[k];
        for (const int v : adj[u]) {
            if (v == bfs_par[u] || removed[v]) continue;
            bfs_par[v] = u;
            comp.push_back(v);
        }
    }
    for (const int u : comp) {
        siz[u] = 1;
        mxs[u] = 0;
    }
    // Reverse BFS order finishes every child before its parent.
    for (auto it = comp.rbegin(); it != comp.rend(); ++it) {
        const int p = bfs_par[*it];
        if (p != 0) {
            siz[p] += siz[*it];
            mxs[p] = std::max(mxs[p], siz[*it]);
        }
    }
    const int total = static_cast<int>(comp.size());
    int best = start;
    int best_worst = INT_MAX;
    for (const int u : comp) {
        const int worst = std::max(mxs[u], total - siz[u]);
        if (worst < best_worst) {
            best_worst = worst;
            best = u;
        }
    }
    return best;
}

// Builds the centroid decomposition of the whole subdivided tree into fa2[]
// using an explicit work list instead of recursion.
void build_centroid_tree() {
    std::vector<std::pair<int, int>> pending;  // (node in component, parent centroid)
    pending.emplace_back(1, 0);
    while (!pending.empty()) {
        const std::pair<int, int> job = pending.back();
        pending.pop_back();
        const int c = find_centroid(job.first);
        fa2[c] = job.second;
        removed[c] = true;
        for (const int v : adj[c])
            if (!removed[v]) pending.emplace_back(v, c);
    }
}

// Records `value` units of remaining radius in bucket `b`. Queries always use a
// non-negative limit, so negative values could never be counted and are skipped.
void add_radius(int b, int value) {
    if (value >= 0) bucket[b].push_back(value);
}

// Sorts every bucket once, after all insertions of a pass and before its queries.
// This replaces per-insert sorted insertion, which was O(bucket size) each time.
void sort_buckets() {
    for (int b = 1; b <= m; ++b) std::sort(bucket[b].begin(), bucket[b].end());
}

// Empties every bucket so the next pass starts fresh.
void clear_buckets() {
    for (int b = 1; b <= m; ++b) bucket[b].clear();
}

// Number of recorded radii in (sorted) bucket `b` that are >= `lim`.
int count_at_least(int b, int lim) {
    const std::vector<int>& vals = bucket[b];
    return static_cast<int>(vals.end() - std::lower_bound(vals.begin(), vals.end(), lim));
}

// Reads the input, subdivides every edge with a midpoint node, precomputes
// powers of 3, and builds the rooted tree plus its centroid decomposition.
void init() {
    n = read();
    m = n;
    for (int i = 1; i <= n; ++i) {
        // Every original distance is < n, so clamping the radius to [-1, n]
        // changes no coverage but keeps the doubled value inside int range.
        a[i] = std::max(-1, std::min(read(), n)) * 2;
    }
    pw3[0] = 1;
    for (int k = 1; k <= n; ++k) pw3[k] = static_cast<int>(pw3[k - 1] * 3LL % kMod);
    for (int e = 1; e < n; ++e) {
        const int x = read();
        const int y = read();
        ++m;  // node m is the midpoint of edge (x, y)
        adj[m].push_back(x);
        adj[m].push_back(y);
        adj[x].push_back(m);
        adj[y].push_back(m);
    }
    root_tree();
    build_centroid_tree();
}

// Computes res[] for every node with the two-pass centroid technique and prints
// the sum over vertices minus the sum over edges of (3^res - 1), mod kMod.
void slv() {
    init();

    // Pass 1: at every centroid ancestor x of vertex i, store the radius left
    // after walking from i to x; node u is counted through x when that
    // remainder is >= dis(x, u).
    for (int i = 1; i <= n; ++i)
        for (int x = i; x != 0; x = fa2[x]) add_radius(x, a[i] - dis(x, i));
    sort_buckets();
    for (int u = 1; u <= m; ++u)
        for (int x = u; x != 0; x = fa2[x]) res[u] += count_at_least(x, dis(x, u));

    // Pass 2: a pair (i, u) was also counted at every centroid above its lowest
    // common centroid ancestor, via a detour through that centroid. Cancel those
    // counts by storing, in child x, the remainder measured to fa2[x].
    clear_buckets();
    for (int i = 1; i <= n; ++i)
        for (int x = i; fa2[x] != 0; x = fa2[x]) add_radius(x, a[i] - dis(fa2[x], i));
    sort_buckets();
    for (int u = 1; u <= m; ++u)
        for (int x = u; fa2[x] != 0; x = fa2[x]) res[u] -= count_at_least(x, dis(fa2[x], u));

    long long total = 0;
    for (int u = 1; u <= m; ++u) {
        const long long term = pw3[res[u]] - 1;  // >= 0 because 3^k mod p is never 0
        total += (u <= n) ? term : kMod - term;  // vertices add, edge midpoints subtract
        total %= kMod;
    }
    std::cout << total << '\n';
}

}  // namespace

int main() {
    // Local testing: std::freopen("tree.in", "r", stdin); std::freopen("tree.out", "w", stdout);
    slv();
    return 0;
}