提交时间:2023-12-13 09:42:50

运行 ID: 24187

// Tree-counting solution (judge run 24187).
//
// NOTE(review): the problem statement is not part of this file; the description below is
// derived from the code only.
//
// Every input edge is subdivided by an extra "edge node", so nodes 1..n are the original
// vertices and n+1..2n-1 are edge nodes; distances in this subdivided tree (dist') are twice
// the original distances. For every node u (original or edge node) the program computes
//     res[u] = #{ original v : dist'(u, v) <= 2 * a_v }
// by walking the tree from root 1 while a sqrt decomposition over the original vertices
// (laid out in DFS order) keeps the shifted values 2*a_v - dist'(u, v) + depth(u).
// It prints  sum_{original u} (3^res[u] - 1)  -  sum_{edge nodes e} (3^res[e] - 1)  mod 998244353.
// NOTE(review): presumably a vertex-minus-edge inclusion-exclusion over "centers"; confirm
// against the problem statement.
//
// The original `#pragma gcc optimize(2)` was dropped: GCC only honours `#pragma GCC ...`
// (case-sensitive), so it never had any effect. Compile with -O2 instead.
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <utility>
#include <vector>

namespace {

constexpr int kMod = 998244353;
// Block length of the sqrt decomposition (kept from the original tuning).
constexpr int kBlockLen = 950;

// Reads one (optionally negative) decimal integer from stdin.
// Returns 0 if EOF is reached before any digit (the previous reader looped forever there).
int read_int() {
    int value = 0;
    int sign = 1;
    int ch = std::getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = std::getchar();
    }
    return value * sign;
}

// Sqrt decomposition over positions 1..size supporting
//   * add(l, r, delta): add delta to every position in [l, r], and
//   * count_at_least(lim): number of positions whose current value is >= lim.
// Each block keeps its raw values, a sorted copy of them, and a pending whole-block add.
class BlockCounter {
public:
    // `values` is 1-indexed: values[1..size] are the initial values, values[0] is unused.
    BlockCounter(std::vector<int> values, int block_len)
        : size_(static_cast<int>(values.size()) - 1),
          block_len_(block_len),
          num_blocks_(size_ > 0 ? (size_ - 1) / block_len + 1 : 0),
          val_(std::move(values)),
          sorted_(val_),
          lazy_(num_blocks_ + 1, 0) {
        for (int b = 1; b <= num_blocks_; ++b) {
            std::sort(sorted_.begin() + begin_of(b), sorted_.begin() + end_of(b) + 1);
        }
    }

    // Adds `delta` to positions [l, r] (1-indexed, inclusive); no-op when l > r.
    void add(int l, int r, int delta) {
        if (l > r) return;
        const int bl = block_of(l);
        const int br = block_of(r);
        if (bl == br) {
            for (int i = l; i <= r; ++i) val_[i] += delta;
            rebuild(bl);
            return;
        }
        // Partial left block, whole middle blocks (lazily), partial right block.
        for (int i = l; i <= end_of(bl); ++i) val_[i] += delta;
        rebuild(bl);
        for (int b = bl + 1; b < br; ++b) lazy_[b] += delta;
        for (int i = begin_of(br); i <= r; ++i) val_[i] += delta;
        rebuild(br);
    }

    // Returns how many positions currently hold a value >= lim.
    int count_at_least(int lim) const {
        int count = 0;
        for (int b = 1; b <= num_blocks_; ++b) {
            const auto first = sorted_.begin() + begin_of(b);
            const auto last = sorted_.begin() + end_of(b) + 1;
            // Stored values exclude lazy_[b], so compare against the shifted threshold.
            count += static_cast<int>(last - std::lower_bound(first, last, lim - lazy_[b]));
        }
        return count;
    }

private:
    int block_of(int i) const { return (i - 1) / block_len_ + 1; }
    int begin_of(int b) const { return (b - 1) * block_len_ + 1; }
    int end_of(int b) const { return std::min(b * block_len_, size_); }

    // Refreshes the sorted copy of block b after some of its raw values changed.
    void rebuild(int b) {
        const int first = begin_of(b);
        const int last = end_of(b);
        std::copy(val_.begin() + first, val_.begin() + last + 1, sorted_.begin() + first);
        std::sort(sorted_.begin() + first, sorted_.begin() + last + 1);
    }

    int size_;
    int block_len_;
    int num_blocks_;
    std::vector<int> val_;     // raw values (without lazy_), 1-indexed
    std::vector<int> sorted_;  // per-block sorted copy of val_
    std::vector<int> lazy_;    // pending add per block, 1-indexed
};

// Rooted-tree data produced by compute_dfs_order; all vectors are indexed by node id.
struct DfsOrder {
    std::vector<int> parent;        // 0 for the root
    std::vector<int> depth;         // root has depth 0
    std::vector<int> dfn;           // 1-based preorder index
    std::vector<int> subtree_size;
    std::vector<int> preorder;      // preorder[k - 1] is the node with dfn k
};

// Computes a preorder numbering of the tree in `adj` (nodes 1..adj.size()-1) rooted at `root`.
// Iterative on purpose: the subdivided tree may be a path of 2n-1 nodes, which is deep
// enough to overflow the call stack with recursion.
DfsOrder compute_dfs_order(const std::vector<std::vector<int>>& adj, int root) {
    const int count = static_cast<int>(adj.size()) - 1;
    DfsOrder t;
    t.parent.assign(count + 1, 0);
    t.depth.assign(count + 1, 0);
    t.dfn.assign(count + 1, 0);
    t.subtree_size.assign(count + 1, 1);
    t.preorder.reserve(count);

    std::vector<int> stack{root};
    while (!stack.empty()) {
        const int u = stack.back();
        stack.pop_back();
        t.dfn[u] = static_cast<int>(t.preorder.size()) + 1;
        t.preorder.push_back(u);
        for (int x : adj[u]) {
            if (x == t.parent[u]) continue;
            t.parent[x] = u;
            t.depth[x] = t.depth[u] + 1;
            stack.push_back(x);
        }
    }
    // Every node follows its parent in preorder, so a reverse sweep accumulates subtree sizes.
    for (auto it = t.preorder.rbegin(); it != t.preorder.rend(); ++it) {
        if (t.parent[*it] != 0) t.subtree_size[t.parent[*it]] += t.subtree_size[*it];
    }
    return t;
}

// Reads the input from stdin and prints the answer to stdout.
// Input: n, then a_1..a_n, then n-1 edges (x, y) over vertices 1..n.
void slv() {
    const int n = read_int();
    if (n <= 0) {
        std::cout << 0 << '\n';
        return;
    }

    std::vector<int> radius(n + 1, 0);
    for (int v = 1; v <= n; ++v) {
        // Distances are at most n-1, so any radius >= n covers everything and any negative
        // radius covers nothing; clamping keeps 2*radius (plus later shifts) inside int range.
        radius[v] = std::min(std::max(read_int(), -1), n);
    }

    // Subdivided tree: edge number e (n+1..2n-1) becomes a node between its two endpoints.
    const int m = 2 * n - 1;
    std::vector<std::vector<int>> adj(m + 1);
    for (int e = n + 1; e <= m; ++e) {
        const int x = read_int();
        const int y = read_int();
        adj[e].push_back(x);
        adj[e].push_back(y);
        adj[x].push_back(e);
        adj[y].push_back(e);
    }

    const DfsOrder tree = compute_dfs_order(adj, 1);

    // Original vertices are laid out at positions 1..n in dfn order. prefix[k] counts original
    // vertices with dfn <= k, so the subtree of x (dfns [dfn[x], dfn[x]+size[x]-1]) occupies
    // positions [prefix[dfn[x]-1] + 1, prefix[dfn[x]+size[x]-1]].
    std::vector<int> prefix(m + 1, 0);
    std::vector<int> initial(n + 1, 0);
    for (int k = 1; k <= m; ++k) {
        const int u = tree.preorder[k - 1];
        prefix[k] = prefix[k - 1];
        if (u <= n) {
            ++prefix[k];
            // Value for root 1: 2*radius[v] - dist'(1, v).
            initial[prefix[k]] = 2 * radius[u] - tree.depth[u];
        }
    }
    BlockCounter counter(std::move(initial), kBlockLen);

    auto shift_subtree = [&](int x, int delta) {
        const int first = prefix[tree.dfn[x] - 1] + 1;
        const int last = prefix[tree.dfn[x] + tree.subtree_size[x] - 1];
        counter.add(first, last, delta);
    };

    // Walk the tree from the root. Entering u != root adds +2 to u's subtree, so at node u the
    // stored value of v equals 2*radius[v] - dist'(u, v) + depth[u]; hence
    // "stored >= depth[u]" <=> "dist'(u, v) <= 2*radius[v]".
    // Explicit event stack: a positive entry enters node u, a negative one leaves it.
    std::vector<int> res(m + 1, 0);
    std::vector<int> events{1};
    while (!events.empty()) {
        const int ev = events.back();
        events.pop_back();
        if (ev < 0) {
            shift_subtree(-ev, -2);
            continue;
        }
        const int u = ev;
        if (u != 1) shift_subtree(u, +2);
        res[u] = counter.count_at_least(tree.depth[u]);
        if (u != 1) events.push_back(-u);
        for (int x : adj[u]) {
            if (x != tree.parent[u]) events.push_back(x);
        }
    }

    // res[u] <= n (it counts original vertices), so powers up to 3^n suffice.
    std::vector<int> pw3(n + 1, 1);
    for (int i = 1; i <= n; ++i) pw3[i] = static_cast<int>(3LL * pw3[i - 1] % kMod);

    long long answer = 0;
    for (int u = 1; u <= m; ++u) {
        const long long term = pw3[res[u]] - 1;  // in [0, kMod - 2]
        answer = (answer + (u <= n ? term : kMod - term)) % kMod;
    }
    std::cout << answer << '\n';
}

}  // namespace

int main() {
    // Local runs redirected tree.in/tree.out with freopen; the judge reads stdin directly.
    // stdin/stdout are intentionally not fclose'd: the standard streams flush stdout at exit.
    slv();
    return 0;
}