// Submission time: 2024-10-23 08:09:52

// Run ID: 33819

// Rerooting solution over a tree whose vertices 1..n are marked by the 0/1 string S.
// An Euler-tour max segment tree stores every vertex's distance to the vertex currently
// being visited.
// For each vertex u this yields:
//   - f[u], g[u]: the largest and second-largest farthest distance over u's neighbour
//     directions;
//   - ff[u]: the height of u's own subtree;
//   - gg[u]: the farthest vertex outside u's subtree;
//   - h[u]: a lower bound.
// The program prints 1 + sum over u of max(0, min(g+1, f-1) - h + 1).
// NOTE(review): this appears to be the known counting formula for "paint every vertex within
// distance d of a marked vertex; count the distinct painted sets" (AtCoder AGC008 F) — confirm.
#include <algorithm>
#include <cstdio>
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
using namespace std;

using ll = long long;
using pi = pair<int, int>;

const int maxn = 200000 + 10;
const int kInf = 1000000000;      // "no lower bound found yet" sentinel for h[]
const int kNegInf = -1000000000;  // value returned when querying an empty position range

// Reads one decimal integer from stdin, skipping any non-digit characters before it.
// A '-' seen among those skipped characters makes the result negative.
// Returns 0 if EOF is reached before any digit, instead of looping forever.
inline ll read() {
    ll x = 0;
    int sign = 1;
    int ch = getchar();  // int, not char, so EOF can be told apart from a real byte
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * sign;
}

int n;                // number of vertices
int ct;               // preorder counter used by dfs
int dep[maxn];        // depth of each vertex with the tree rooted at 1
int dfn[maxn];        // preorder index of each vertex (rooted at 1)
int rnk[maxn];        // inverse of dfn: rnk[dfn[u]] == u
int siz[maxn];        // subtree size (rooted at 1)
int sizf[maxn];       // number of '1'-marked vertices in the subtree (rooted at 1)
int f[maxn];          // largest farthest-distance over u's neighbour directions
int g[maxn];          // second largest of those (stays 0 when u has fewer than two neighbours)
int ff[maxn];         // farthest distance from u into its own subtree (parent direction excluded)
int gg[maxn];         // farthest distance from u to a vertex outside its subtree
int h[maxn];          // per-vertex lower bound used in the final sum
string S;             // mark string; 1-indexed once slv() prepends a space
vector<int> v[maxn];  // adjacency lists

// Max segment tree over preorder positions 1..n with lazy range-add.
// Leaf i starts at dep[rnk[i]], i.e. that vertex's distance to vertex 1.
// dfs2 keeps each leaf equal to the vertex's distance to the vertex currently being visited.
struct SegTree {
    struct nd {
        int lz, mx;
    } d[maxn << 2];

    static int ls(int p) { return p << 1; }
    static int rs(int p) { return p << 1 | 1; }

    void pu(int p) { d[p].mx = max(d[ls(p)].mx, d[rs(p)].mx); }

    // Adds x to every position covered by node p and records it as a pending tag.
    void cl(int p, int x) {
        d[p].lz += x;
        d[p].mx += x;
    }

    void pd(int p) {
        cl(ls(p), d[p].lz);
        cl(rs(p), d[p].lz);
        d[p].lz = 0;
    }

    // Builds node p covering positions [l, r].
    void bd(int l, int r, int p) {
        if (l == r) {
            d[p].mx = dep[rnk[l]];
            return;
        }
        const int mid = (l + r) >> 1;
        bd(l, mid, ls(p));
        bd(mid + 1, r, rs(p));
        pu(p);
    }

    // Adds x to positions [l, r]; node p covers [s, t]. An empty range (l > r) is a no-op.
    void modify(int l, int r, int s, int t, int p, int x) {
        if (l > r) return;
        if (l <= s && t <= r) {
            cl(p, x);
            return;
        }
        pd(p);
        const int mid = (s + t) >> 1;
        if (l <= mid) modify(l, r, s, mid, ls(p), x);
        if (r > mid) modify(l, r, mid + 1, t, rs(p), x);
        pu(p);
    }

    // Max over positions [l, r]; node p covers [s, t]. Returns kNegInf for an empty range.
    int qry(int l, int r, int s, int t, int p) {
        if (l > r) return kNegInf;
        if (l <= s && t <= r) return d[p].mx;
        pd(p);
        const int mid = (s + t) >> 1;
        int best = 0;  // stored values are distances, assumed non-negative
        if (l <= mid) best = max(best, qry(l, r, s, mid, ls(p)));
        if (r > mid) best = max(best, qry(l, r, mid + 1, t, rs(p)));
        return best;
    }
} T;

// Roots the tree at u (with parent fa) and fills dep, dfn, rnk, siz and sizf for u's subtree.
// NOTE(review): recursion depth equals the tree height (up to n), so path-shaped inputs may
// need a large stack.
void dfs(int u, int fa) {
    dfn[u] = ++ct;
    rnk[ct] = u;
    siz[u] = 1;
    sizf[u] = (S[u] == '1');
    for (int x : v[u]) {
        if (x == fa) continue;
        dep[x] = dep[u] + 1;
        dfs(x, u);
        siz[u] += siz[x];
        sizf[u] += sizf[x];
    }
}

// Moves T's reference vertex across the edge between x and its parent.
// Every vertex outside subtree(x) gets `delta` farther; every vertex inside gets `delta` closer.
// delta = +1 moves from the parent to x; delta = -1 undoes that move.
static void shift_root(int x, int delta) {
    T.modify(1, dfn[x] - 1, 1, n, 1, delta);
    T.modify(dfn[x] + siz[x], n, 1, n, 1, delta);
    T.modify(dfn[x], dfn[x] + siz[x] - 1, 1, n, 1, -delta);
}

// Records gg[u], f[u], g[u] and ff[u]. Requires T to hold every vertex's distance to u.
static void record_directions(int u, int fa) {
    // Farthest vertex outside subtree(u); kNegInf at the root, where that set is empty.
    const int outside =
        max(T.qry(1, dfn[u] - 1, 1, n, 1), T.qry(dfn[u] + siz[u], n, 1, n, 1));
    gg[u] = outside;

    // (farthest distance from u when leaving through neighbour x, x) for each neighbour x.
    vector<pi> dirs;
    dirs.reserve(v[u].size());
    for (int x : v[u]) {
        const int dist = (x == fa) ? outside : T.qry(dfn[x], dfn[x] + siz[x] - 1, 1, n, 1);
        dirs.emplace_back(dist, x);
    }

    // A vertex with no neighbours (n == 1) has no directions.
    // Leave f, g and ff at 0 instead of reading dirs[0] out of bounds.
    if (dirs.empty()) return;
    sort(dirs.begin(), dirs.end(), greater<pi>());
    f[u] = dirs[0].first;
    if (dirs.size() > 1) g[u] = dirs[1].first;
    // Largest direction that is not the parent, i.e. the height of u's own subtree.
    ff[u] = 0;
    if (dirs[0].second != fa) {
        ff[u] = dirs[0].first;
    } else if (dirs.size() > 1) {
        ff[u] = dirs[1].first;
    }
}

// Visits u's subtree, rerooting T at each vertex in turn.
// Requires T to hold every vertex's distance to u on entry, and restores that state on return.
void dfs2(int u, int fa) {
    record_directions(u, fa);
    for (int x : v[u]) {
        if (x == fa) continue;
        shift_root(x, +1);
        dfs2(x, u);
        shift_root(x, -1);
    }
}

// Reads n, then n-1 edges, then the mark string, and runs both passes.
// Prints 1 + sum over vertices of max(0, min(g+1, f-1) - h + 1).
void slv() {
    n = static_cast<int>(read());
    for (int i = 1; i < n; ++i) {
        const int x = static_cast<int>(read());
        const int y = static_cast<int>(read());
        v[x].push_back(y);
        v[y].push_back(x);
    }
    cin >> S;
    S = " " + S;  // shift to 1-based indexing

    dfs(1, 0);
    T.bd(1, n, 1);
    dfs2(1, 0);

    ll res = 1;
    for (int i = 1; i <= n; ++i) {
        h[i] = kInf;
        if (S[i] == '1') h[i] = 0;
        // Some marked vertex lies outside subtree(i).
        if (sizf[i] != sizf[1]) h[i] = min(h[i], gg[i]);
        for (int x : v[i]) {
            // x is a child of i whose subtree contains a marked vertex.
            if (dfn[x] > dfn[i] && sizf[x]) h[i] = min(h[i], ff[x] + 1);
        }
        res += max(min(g[i] + 1, f[i] - 1) - h[i] + 1, 0);
    }
    cout << res << '\n';
}

int main() {
    // Judge file I/O, kept disabled as in the submission:
    // freopen("apple.in", "r", stdin);
    // freopen("apple.out", "w", stdout);
    // There is deliberately no fclose(stdin/stdout) here. With freopen disabled it would close
    // the real standard streams, and cout's exit-time flush could then use a closed FILE*.
    slv();
    return 0;
}