提交时间:2023-12-11 19:17:34

运行 ID: 24166

#include <bits/stdc++.h>
using namespace std;

using ll = long long;

const int mod = 998244353;

// Reads one signed decimal integer from stdin.
// Any '-' seen while skipping non-digit characters makes the result negative.
// Returns 0 if EOF is reached before a digit. `ch` is an int so EOF is
// distinguishable from a real character; with a plain char the skip loop
// below would spin forever on truncated input.
inline int read() {
    int x = 0;
    int t = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') t = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * t;
}

// Node count of the original tree, and the highest node id in the
// subdivided tree (ids n+1..2n-1 are edge midpoints).
int n, m;
vector<int> dep;        // dep[u]: distance from the current centre to u (-1 = not reached)
vector<int> pw3;        // pw3[k] = 3^k mod `mod`, for k = 0..n
vector<ll> a;           // a[j]: doubled radius of original vertex j (ll avoids 2*x overflow)
vector<vector<int>> v;  // adjacency lists of the subdivided tree, indices 1..2n-1
vector<int> que;        // BFS queue, reused across calls to avoid reallocating

// Fills dep[] with distances from `src` in the subdivided tree.
// Uses an explicit queue rather than recursion: the subdivided tree can be a
// path of up to 2n-1 nodes, which would make a recursive walk very deep.
void bfs(int src) {
    fill(dep.begin(), dep.end(), -1);
    que.clear();
    dep[src] = 0;
    que.push_back(src);
    for (size_t head = 0; head < que.size(); ++head) {
        const int u = que[head];
        for (int x : v[u]) {
            if (dep[x] < 0) {
                dep[x] = dep[u] + 1;
                que.push_back(x);
            }
        }
    }
}

// Reads n, the n radii, and the n-1 tree edges, then prints the answer mod 998244353.
//
// Each edge (x, y) is subdivided by a new node, so the tree has 2n-1 nodes.
// Radii are doubled to match, since every original distance also doubles.
// For every centre c (an original vertex or an edge midpoint):
//   cnt(c) = #{ original j : dist(c, j) <= a[j] }.
// The answer is the sum of (3^cnt - 1) over vertex centres, minus the same sum
// over edge-midpoint centres.
// NOTE(review): the vertex-minus-edge form presumably relies on V - E = 1 for
// every subtree; confirm this against the problem statement.
// Runs in O(n * m) = O(n^2) time, so it is only suitable for small n.
void slv() {
    n = read();
    if (n < 1) {  // no vertices: the sum is empty
        cout << 0 << '\n';
        return;
    }
    m = n;
    const int total = 2 * n - 1;  // n vertices + (n - 1) edge midpoints

    a.assign(n + 1, 0);
    for (int i = 1; i <= n; ++i) a[i] = 2LL * read();

    // cnt never exceeds n, so powers up to 3^n suffice.
    pw3.assign(n + 1, 1);
    for (int i = 1; i <= n; ++i) pw3[i] = static_cast<int>(1LL * pw3[i - 1] * 3 % mod);

    // Sized by the real node count rather than a fixed bound.
    v.assign(total + 1, vector<int>());
    dep.assign(total + 1, -1);
    que.reserve(total);

    for (int i = 1; i < n; ++i) {
        const int x = read();
        const int y = read();
        ++m;  // midpoint node of edge (x, y)
        v[m].push_back(x);
        v[m].push_back(y);
        v[x].push_back(m);
        v[y].push_back(m);
    }

    int res = 0;
    for (int c = 1; c <= m; ++c) {
        bfs(c);
        int cnt = 0;
        for (int j = 1; j <= n; ++j)
            if (dep[j] >= 0 && dep[j] <= a[j]) ++cnt;
        const int term = (pw3[cnt] - 1 + mod) % mod;  // 3^cnt - 1 (mod p)
        if (c <= n)
            res = (res + term) % mod;        // original vertex: add
        else
            res = (res - term + mod) % mod;  // edge midpoint: subtract
    }
    cout << res << '\n';
}

int main() {
    // freopen("tree.in","r",stdin);
    // freopen("tree.out","w",stdout);
    slv();
    return 0;
}