提交时间:2024-10-23 08:07:30

运行 ID: 33817

// Original author's note (translated from Chinese): "100; the arrays were
// allocated too small" -- presumably this version scored 100 after the
// arrays were enlarged.
//
// Reads a tree on vertices 1..n (n-1 edge lines "u v") followed by an
// n-character string; vertex i is "marked" when the i-th character is '1'.
// For every vertex u it counts the integers d with lo(u) <= d <= hi(u), where
//   depth via neighbour v = largest distance from u to a vertex on v's side,
//   far1/far2             = largest / second-largest depth via two different
//                           neighbours of u (0 when there is no such one),
//   hi(u)                 = min(far1 - 1, far2 + 1),
//   lo(u)                 = 0 if u is marked, otherwise the smallest depth via
//                           a neighbour whose side contains a marked vertex,
// and prints that total plus one.
// NOTE(review): this matches the "Black Radius" counting problem (AtCoder
// AGC008 F), where the trailing +1 would be the all-black colouring --
// confirm against the problem statement.
//
// Both tree passes are iterative (BFS order), so a path-shaped tree with
// ~2e5 vertices cannot overflow the call stack as the recursive version could.
// Assumes the input really is a tree on vertices 1..n.
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <utility>

namespace {

const int kMaxN = 200010;
const int kInf = 1000000000;

// (depth, neighbour id). std::pair compares lexicographically, so equal
// depths are ordered by neighbour id and every inserted entry is distinct.
typedef std::pair<int, int> DepthVia;

// The two largest (depth, neighbour) entries recorded for one vertex.
struct Top2 {
  DepthVia mx;  // largest entry
  DepthVia nx;  // second largest; mx > nx unless both are still (0, 0)

  // Returns a copy with g inserted, keeping only the two largest entries.
  // The result depends only on the set of inserted entries, not their order.
  Top2 operator*(const DepthVia& g) const {
    Top2 res;
    if (g > mx) {
      res.mx = g;
      res.nx = mx;
    } else {
      res.mx = mx;
      res.nx = std::max(g, nx);
    }
    return res;
  }
};

// Undirected tree as linked adjacency lists; every edge is stored twice.
struct Edge {
  int v;
  int nx;
};
Edge edge[kMaxN << 1];
int head[kMaxN];
int edge_cnt;

// First holds depths into u's own subtree (rooted at 1); the top-down pass
// then adds the parent-side entry so it covers every direction from u.
Top2 f[kMaxN];
int marked_below[kMaxN];  // marked vertices in u's subtree (rooted at 1)
int up_depth[kMaxN];      // depth via the parent, as seen from u
int parent_of[kMaxN];     // 0 for the root
int bfs_order[kMaxN];     // parents always appear before their children
bool liked[kMaxN];        // liked[i] == (i-th string character == '1')

void add_side(int u, int v) {
  edge[++edge_cnt] = {v, head[u]};
  head[u] = edge_cnt;
  edge[++edge_cnt] = {u, head[v]};
  head[v] = edge_cnt;
}

// Reads one (optionally negative) decimal integer from stdin. Keeps the
// getchar() result in an int so EOF ends the scan instead of looping forever.
int read_int() {
  int x = 0;
  int sign = 1;
  int c = std::getchar();
  while (c != EOF && (c < '0' || c > '9')) {
    if (c == '-') sign = -1;
    c = std::getchar();
  }
  while (c >= '0' && c <= '9') {
    x = x * 10 + (c - '0');
    c = std::getchar();
  }
  return x * sign;
}

// Reads the next whitespace-delimited token into liked[1..n], storing at most
// n characters so an over-long token cannot overrun the array.
void read_marks(int n) {
  int c = std::getchar();
  while (c != EOF && std::isspace(c)) c = std::getchar();
  for (int i = 1; i <= n && c != EOF && !std::isspace(c); ++i) {
    liked[i] = (c == '1');
    c = std::getchar();
  }
}

}  // namespace

int main() {
  const int n = read_int();
  if (n < 1 || n >= kMaxN) return 1;  // would not fit the fixed-size arrays
  for (int i = 1; i < n; ++i) {
    const int u = read_int();
    const int v = read_int();
    add_side(u, v);
  }
  read_marks(n);

  // BFS from vertex 1 to get parents and a parents-first ordering.
  int tail = 0;
  bfs_order[tail++] = 1;
  parent_of[1] = 0;
  for (int i = 0; i < tail; ++i) {
    const int u = bfs_order[i];
    for (int e = head[u]; e; e = edge[e].nx) {
      const int v = edge[e].v;
      if (v != parent_of[u]) {
        parent_of[v] = u;
        bfs_order[tail++] = v;
      }
    }
  }

  // Bottom-up: children (later in BFS order) are finished before their parent.
  for (int u = 1; u <= n; ++u) marked_below[u] = liked[u] ? 1 : 0;
  for (int i = n - 1; i >= 1; --i) {
    const int u = bfs_order[i];
    const int par = parent_of[u];
    f[par] = f[par] * DepthVia(f[u].mx.first + 1, u);
    marked_below[par] += marked_below[u];
  }
  const int total_marked = marked_below[1];

  // Top-down: when u is reached, f[u] already includes its parent direction,
  // while its children's f[] still describe only their own subtrees.
  long long ans = 0;
  for (int i = 0; i < n; ++i) {
    const int u = bfs_order[i];
    const int par = parent_of[u];

    int lo = liked[u] ? 0 : kInf;
    if (par != 0 && total_marked - marked_below[u] > 0) {
      lo = std::min(lo, up_depth[u]);
    }
    for (int e = head[u]; e; e = edge[e].nx) {
      const int v = edge[e].v;
      if (v != par && marked_below[v] > 0) {
        lo = std::min(lo, f[v].mx.first + 1);
      }
    }

    const int hi = std::min(f[u].mx.first - 1, f[u].nx.first + 1);
    if (hi >= lo) ans += hi - lo + 1;

    // Hand each child the best depth from u that avoids that child.
    for (int e = head[u]; e; e = edge[e].nx) {
      const int v = edge[e].v;
      if (v == par) continue;
      const DepthVia& best = (f[u].mx.second == v) ? f[u].nx : f[u].mx;
      up_depth[v] = best.first + 1;
      f[v] = f[v] * DepthVia(up_depth[v], u);
    }
  }

  std::printf("%lld\n", ans + 1);
  return 0;
}