// 提交时间 (submission time): 2024-08-28 15:40:57

// 运行 ID (run ID): 31956

// Reformatted: the original had every directive, function and `//` comment
// on a single line, so the preprocessor treated the whole program as extra
// tokens after `#include` and nothing was compiled.
#include <algorithm>
#include <cstdio>
#include <vector>

typedef long long ll;

const int maxn = 5e5 + 10;  // array capacity for node ids 1..n

// Reads one (optionally negative) integer from stdin, skipping any non-digit
// characters before it. Returns 0 if EOF is hit before a digit (the original
// looped forever there because it never tested for EOF).
inline ll read() {
    ll x = 0;
    int sign = 1;
    int ch = getchar();  // int, not char, so EOF is distinguishable
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * sign;
}

int n;                     // number of nodes, ids 1..n
int p[maxn];               // activation order p[1..n], read from input
int FA[maxn];              // parent in the tree rooted at 1 (0 for the root)
int siz[maxn];             // subtree size in the tree rooted at 1
int fa[maxn];              // DSU parent pointers
int sz[maxn];              // 0 = not yet activated; at a DSU representative: component size
int cc;                    // number of merges performed in the current query
int tag;                   // components whose size != subtree size of their representative
std::vector<int> v[maxn];  // adjacency lists, kept sorted ascending

// DSU find with full path compression. Iterative: without union by rank the
// parent chains can be O(n) long, which risked a stack overflow recursively.
int fd(int x) {
    int root = x;
    while (fa[root] != root) root = fa[root];
    while (fa[x] != root) {  // point every node on the path straight at root
        const int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}

// Merges the components of x and y; y's representative becomes the
// representative of the result. Updates `cc` and the `tag` count.
void mg(int x, int y) {
    x = fd(x);
    y = fd(y);
    if (sz[x] && sz[y]) ++cc;
    // Retract each old component's contribution to `tag`, then add the merged one's.
    if (sz[x]) tag -= (sz[x] != siz[x]);
    if (sz[y]) tag -= (sz[y] != siz[y]);
    fa[x] = y;
    sz[y] += sz[x];
    tag += (sz[y] != siz[y]);
}

// Removes `val` from the sorted list `adj` if present. The original erased
// at lower_bound unconditionally: wrong neighbour or erase(end()) (UB) when
// the edge was missing.
static void erase_sorted(std::vector<int>& adj, int val) {
    const auto it = std::lower_bound(adj.begin(), adj.end(), val);
    if (it != adj.end() && *it == val) adj.erase(it);
}

// Inserts `val` into the sorted list `adj`, keeping it sorted.
static void insert_sorted(std::vector<int>& adj, int val) {
    adj.insert(std::lower_bound(adj.begin(), adj.end(), val), val);
}

// Reads four integers a b c d: removes edge (a, b), then adds edge (c, d).
void upd() {
    const int a = read(), b = read();
    erase_sorted(v[a], b);
    erase_sorted(v[b], a);
    const int c = read(), d = read();
    insert_sorted(v[c], d);
    insert_sorted(v[d], c);
}

// Fills FA[] and siz[] for the tree rooted at u. FA[u] must not be a node id
// (qry sets it to 0). Iterative to avoid recursion depth up to n (~5e5) on
// path-shaped trees; the resulting FA/siz are the same as a recursive DFS.
void dfs(int u) {
    std::vector<int> order;  // nodes in visit order; parents precede children
    std::vector<int> stk;
    order.reserve(n);
    stk.reserve(n);
    stk.push_back(u);
    while (!stk.empty()) {
        const int x = stk.back();
        stk.pop_back();
        order.push_back(x);
        siz[x] = 1;
        for (int y : v[x]) {
            if (y == FA[x]) continue;
            FA[y] = x;
            stk.push_back(y);
        }
    }
    // Children come after their parent in `order`, so a reverse sweep
    // accumulates complete subtree sizes (index 0 is the root itself).
    for (int i = static_cast<int>(order.size()) - 1; i > 0; --i)
        siz[FA[order[i]]] += siz[order[i]];
}

// Roots the current tree at 1, then activates nodes in the order p[1..n],
// merging each newly active node with its active neighbours. After each
// step i, if every component is an entire rooted subtree (tag == 0), adds
// the number of components (i - cc) to the answer, which is printed.
void qry() {
    for (int i = 1; i <= n; ++i) FA[i] = 0;
    dfs(1);
    for (int i = 1; i <= n; ++i) {
        fa[i] = i;
        sz[i] = 0;
    }
    tag = cc = 0;
    ll res = 0;
    for (int i = 1; i <= n; ++i) {
        const int x = p[i];
        sz[x] = 1;
        tag += (sz[x] != siz[x]);
        // Active children are attached under x, then x's component under its
        // parent's. So every representative is the shallowest node of its
        // component, and sz[r] == siz[r] means "component is r's whole subtree".
        for (int y : v[x]) {
            if (!sz[y] || y == FA[x]) continue;
            mg(y, x);
        }
        if (FA[x] && sz[FA[x]]) mg(x, FA[x]);
        // Each activation adds one component and each merge removes one.
        if (!tag) res += i - cc;
    }
    printf("%lld\n", res);
}

// Reads n and the query count, the n-1 tree edges and the permutation p,
// then runs the update/query rounds.
void slv() {
    n = static_cast<int>(read());
    read();  // query count from input; consumed so later reads stay aligned
    // NOTE(review): the original overwrote the read query count with 1, so only
    // one update+query round runs regardless of input. Preserved as-is;
    // confirm this is intentional (e.g. a single-subtask submission).
    int q = 1;
    for (int i = 1; i < n; ++i) {
        const int x = read(), y = read();
        v[x].push_back(y);
        v[y].push_back(x);
    }
    for (int i = 1; i <= n; ++i) p[i] = static_cast<int>(read());
    for (int i = 1; i <= n; ++i) std::sort(v[i].begin(), v[i].end());  // upd relies on sorted lists
    while (q--) {
        upd();
        qry();
    }
}

int main() {
    // freopen("bird.in", "r", stdin);
    // freopen("bird.out", "w", stdout);
    slv();
    fclose(stdin);
    fclose(stdout);
    return 0;
}