提交时间:2024-10-22 20:38:14
运行 ID: 33810
//100 #include<bits/stdc++.h> using namespace std; #define int long long #define pii pair<int,int> #define p1(x) x.first #define p2(x) x.second int n; vector<int>g[200300]; int f[200300][2]; int h[200300]; int p[200300]; int S,s[200300]; inline void upd(int u,int x){ if(f[u][0]<x)f[u][1]=f[u][0],f[u][0]=x; else f[u][1]=max(f[u][1],x); } inline void dfs(int u,int fa){ f[u][0]=0; f[u][1]=-1e9; h[u]=1e9; for(int v:g[u])if(v!=fa){ dfs(v,u); upd(u,f[v][0]+1); s[u]+=s[v]; if(s[v]) h[u]=min(h[u],f[v][0]+1); } } inline void dfs2(int u,int fa){ if(fa){ int x=f[fa][0]+1; if(f[fa][0]==f[u][0]+1)x=f[fa][1]+1; upd(u,x); if(S-s[u])h[u]=min(h[u],x); } // cout<<u<<" "<<f[u][0]<<" "<<f[u][1]<<endl; for(int v:g[u])if(v!=fa)dfs2(v,u); } signed main(){ ios::sync_with_stdio(0); cin.tie(0),cout.tie(0); // freopen("genshin.in","r",stdin); // freopen("genshin.out","w",stdout); cin>>n; for(int i=1;i<n;i++){ int u,v; cin>>u>>v; g[u].push_back(v); g[v].push_back(u); } for(int i=1;i<=n;i++){ char x; cin>>x; // assert(x!='0'&&x!='1'); p[i]=s[i]=bool(x-'0'); S+=s[i]; } int res=1; dfs(1,0); dfs2(1,0); for(int u=1;u<=n;u++){ if(p[u])h[u]=0; res+=max(0ll,min(f[u][0]-1,f[u][1]+1)-h[u]+1); } cout<<res<<endl; cout.flush(); return 0; }