提交时间:2023-12-13 10:07:11

运行 ID: 24192

// Solution program: reads a tree with n vertices and one integer a_i per
// vertex from stdin and prints a single integer modulo 998244353 to stdout.
//
// Method, as implemented below:
//   * Every edge (x, y) is subdivided by a virtual vertex (ids n+1 .. 2n-1),
//     which doubles all distances. dist'(u, j) below is a distance in that
//     subdivided tree.
//   * For every vertex u of the subdivided tree, res[u] is the number of
//     original vertices j with 2*a_j >= dist'(u, j).
//   * The printed value is
//         sum over original u of (3^res[u] - 1)
//       - sum over virtual  u of (3^res[u] - 1)        (mod 998244353).
// NOTE(review): the problem statement is not part of this file; the meaning of
// that sum is inferred from the code only.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <random>
#include <vector>

// Was "#pragma gcc ...": the lowercase spelling is not recognised by GCC and
// was silently ignored.
#pragma GCC optimize(2)

constexpr int maxn = 500000 + 10;  // must exceed 2n-1 (ids of virtual vertices)
constexpr int mod = 998244353;

// Reads one (optionally negative) decimal integer from stdin, skipping any
// leading non-digit characters. Returns 0 at end of input instead of spinning.
inline int read() {
    int x = 0;
    int sign = 1;
    int ch = std::getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x * sign;
}

int n;                      // number of original vertices
int m;                      // number of vertices after subdivision (2n-1)
int dep[maxn];              // depth in the subdivided tree, root 1 at depth 0
int pw3[maxn];              // pw3[k] = 3^k mod `mod`, filled for k in [0, n]
int a[maxn];                // a[k]: initial key of the original vertex of k-th smallest dfn
int a2[maxn];               // a2[i] = 2 * a_i (a_i clamped, see slv)
int dfn[maxn];              // preorder index
int siz[maxn];              // subtree size; subtree of u is dfn[u] .. dfn[u]+siz[u]-1
int res[maxn];              // res[u], see file header
int cnt;                    // preorder counter
std::vector<int> v2[maxn];  // adjacency lists of the subdivided tree
std::vector<int> g;         // sorted dfn values of the original vertices

// Range-add segment tree over positions 1..n that can count how many positions
// hold a value >= lim. Counting only descends into nodes whose [mn, mx] range
// straddles lim, so its cost depends on the data.
struct SegTree {
    struct Node {
        int mx, mn, lz;  // subtree max, subtree min, pending add for children
    };
    Node d[maxn << 2];

    static int ls(int p) { return p << 1; }
    static int rs(int p) { return p << 1 | 1; }

    void pull(int p) {
        d[p].mx = std::max(d[ls(p)].mx, d[rs(p)].mx);
        d[p].mn = std::min(d[ls(p)].mn, d[rs(p)].mn);
    }

    // Adds x to every position under node p.
    void apply(int p, int x) {
        d[p].lz += x;
        d[p].mx += x;
        d[p].mn += x;
    }

    void push(int p) {
        if (d[p].lz != 0) {
            apply(ls(p), d[p].lz);
            apply(rs(p), d[p].lz);
            d[p].lz = 0;
        }
    }

    // Builds node p covering [s, t] from a[s..t].
    void build(int s, int t, int p) {
        if (s == t) {
            d[p].mx = d[p].mn = a[s];
            return;
        }
        const int mid = (s + t) >> 1;
        build(s, mid, ls(p));
        build(mid + 1, t, rs(p));
        pull(p);
    }

    // Adds x to every position in [l, r]; node p covers [s, t].
    void update(int l, int r, int s, int t, int p, int x) {
        if (l <= s && t <= r) {
            apply(p, x);
            return;
        }
        push(p);
        const int mid = (s + t) >> 1;
        if (l <= mid) update(l, r, s, mid, ls(p), x);
        if (r > mid) update(l, r, mid + 1, t, rs(p), x);
        pull(p);
    }

    // Number of positions in [s, t] (node p) whose value is >= lim.
    // At a leaf mn == mx, so one of the first two checks always returns.
    int count_at_least(int s, int t, int p, int lim) {
        if (d[p].mn >= lim) return t - s + 1;
        if (d[p].mx < lim) return 0;
        push(p);
        const int mid = (s + t) >> 1;
        return count_at_least(s, mid, ls(p), lim) +
               count_at_least(mid + 1, t, rs(p), lim);
    }
};

SegTree T;  // global: the node array is tens of MB and must not live on the stack

// Adds val to every original vertex whose dfn lies in [l, r]. dfn values are
// mapped to tree positions 1..n by their rank in g; does nothing if no
// original vertex falls in the range.
void add_on_dfn_range(int l, int r, int val) {
    const int lo = int(std::lower_bound(g.begin(), g.end(), l) - g.begin()) + 1;
    const int hi = int(std::upper_bound(g.begin(), g.end(), r) - g.begin());
    if (lo > hi) return;
    T.update(lo, hi, 1, n, 1, val);
}

// Assigns dfn (preorder), siz and dep for the subdivided tree rooted at root.
// Uses an explicit stack: on a path the tree is 2n-2 levels deep, which can
// overflow a default call stack when done recursively.
void dfs(int root) {
    std::vector<int> parent(m + 1, 0);  // 0 = no parent (ids start at 1)
    std::vector<int> order;
    order.reserve(m);
    std::vector<int> stk{root};
    dep[root] = 0;
    while (!stk.empty()) {
        const int u = stk.back();
        stk.pop_back();
        dfn[u] = ++cnt;
        siz[u] = 1;
        order.push_back(u);
        for (int x : v2[u]) {
            if (x == parent[u]) continue;
            parent[x] = u;
            dep[x] = dep[u] + 1;
            stk.push_back(x);  // whole subtree of x is popped before its siblings
        }
    }
    // Reverse preorder visits every child before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        if (parent[*it] != 0) siz[parent[*it]] += siz[*it];
    }
}

// Computes res[] for every vertex. Invariant: while vertex u is current, the
// tree holds a2[j] - dep[j] + 2*dep(lca(u, j)) for each original vertex j, so
// "value >= dep[u]" is exactly "a2[j] >= dist'(u, j)". Entering child x adds 2
// on x's subtree; leaving x undoes it. Iterative for the same reason as dfs.
void dfs2(int root) {
    struct Frame {
        int u;             // current vertex
        int fa;            // its parent (0 for the root)
        std::size_t next;  // next index into v2[u] to examine
    };
    std::vector<Frame> stk;
    res[root] = T.count_at_least(1, n, 1, dep[root]);
    stk.push_back({root, 0, 0});
    while (!stk.empty()) {
        Frame& top = stk.back();
        if (top.next == v2[top.u].size()) {
            const int done = top.u;
            stk.pop_back();
            if (!stk.empty()) add_on_dfn_range(dfn[done], dfn[done] + siz[done] - 1, -2);
            continue;
        }
        const int x = v2[top.u][top.next++];
        if (x == top.fa) continue;
        const int u = top.u;  // copy before push_back invalidates `top`
        add_on_dfn_range(dfn[x], dfn[x] + siz[x] - 1, 2);
        res[x] = T.count_at_least(1, n, 1, dep[x]);
        stk.push_back({x, u, 0});
    }
}

// Reads the input, runs both traversals and prints the answer.
void slv() {
    n = read();
    m = n;
    // dist' never exceeds 2(n-1), so clamping a_i into [-1, n] leaves every
    // comparison "2*a_j >= dist'" unchanged while keeping 2*a_i (and the
    // segment-tree values derived from it) far from int overflow.
    for (int i = 1; i <= n; ++i) a2[i] = std::clamp(read(), -1, n) * 2;

    // res[u] <= n, so powers up to 3^n are all that is ever looked up.
    pw3[0] = 1;
    for (int i = 1; i <= n; ++i) pw3[i] = int(1LL * pw3[i - 1] * 3 % mod);

    for (int i = 1; i < n; ++i) {
        int x = read();
        int y = read();
        ++m;  // virtual vertex subdividing edge (x, y)
        v2[m].push_back(x);
        v2[m].push_back(y);
        v2[x].push_back(m);
        v2[y].push_back(m);
    }

    // The result does not depend on adjacency order; the shuffle only
    // randomises the dfn layout (presumably to keep the pruned counting
    // queries fast on adversarial inputs). random_shuffle was removed in
    // C++17; a fixed seed keeps runs reproducible like the unseeded rand() did.
    std::mt19937 rng(20231213u);
    for (int i = 1; i <= m; ++i) std::shuffle(v2[i].begin(), v2[i].end(), rng);

    dfs(1);

    g.reserve(n);
    for (int i = 1; i <= n; ++i) g.push_back(dfn[i]);
    std::sort(g.begin(), g.end());
    for (int i = 1; i <= n; ++i) {
        const int pos = int(std::lower_bound(g.begin(), g.end(), dfn[i]) - g.begin()) + 1;
        a[pos] = a2[i] - dep[i];
    }
    T.build(1, n, 1);
    dfs2(1);

    // At most 2n-1 terms, each below mod: the sum fits easily in long long.
    long long ans = 0;
    for (int i = 1; i <= m; ++i) {
        const long long term = pw3[res[i]] - 1;
        ans += (i <= n) ? term : -term;  // original vertices add, virtual subtract
    }
    ans %= mod;
    if (ans < 0) ans += mod;
    std::cout << ans << '\n';
}

int main() {
    // freopen("tree.in", "r", stdin);
    // freopen("tree.out", "w", stdout);
    slv();
    return 0;
}