// Submission time: 2024-08-28 15:35:53

// Run ID: 31955

// Tree / permutation query solution.
//
// Input (whitespace-separated integers):
//   n q
//   n-1 pairs "x y"   - undirected tree edges on vertices 1..n
//   n values p_1..p_n - vertex at permutation position i (pos[p_i] = i);
//                       presumably a permutation of 1..n
//   q lines "a b c d" - one query each (see solve())
// Output: one integer per query.
//
// For every positional gap i (1 <= i <= n) a segment tree keeps
//   * a "blocked" count: tree edges (child, parent) whose parent appears
//     earlier in the permutation than the child and whose span covers i;
//   * a "weight": tree edges whose parent appears later and whose span covers i.
// Each answer is 1 + (sum of weights over positions whose blocked count is 0).
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

namespace {

using ll = long long;

// Reads the next decimal integer from stdin. A '-' seen while skipping
// non-digit characters makes the result negative. Returns 0 at end of input
// instead of looping forever.
ll readInt() {
    int ch = std::getchar();  // int, not char, so EOF is distinguishable
    ll sign = 1;
    while ((ch < '0' || ch > '9') && ch != EOF) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    ll value = 0;
    for (; ch >= '0' && ch <= '9'; ch = std::getchar()) {
        value = value * 10 + (ch - '0');
    }
    return value * sign;
}

// Convenience wrapper for values used as vertex ids / counts.
int readIndex() { return static_cast<int>(readInt()); }

// Summary of a range of positions: the minimum blocked count, how many
// positions attain it, and the sum of weights over exactly those positions.
struct Agg {
    int minBlocked = 0;
    int minCount = 0;
    ll weightSum = 0;

    // Adds dBlocked to every position's blocked count and dWeight to every
    // position's weight. The set of minimum positions is unchanged, so only
    // their weight sum needs adjusting.
    void apply(int dBlocked, int dWeight) {
        minBlocked += dBlocked;
        weightSum += static_cast<ll>(minCount) * dWeight;
    }
};

// Merges two adjacent ranges, keeping only the positions at the overall minimum.
Agg combine(const Agg& lhs, const Agg& rhs) {
    if (lhs.minBlocked < rhs.minBlocked) return lhs;
    if (rhs.minBlocked < lhs.minBlocked) return rhs;
    return Agg{lhs.minBlocked, lhs.minCount + rhs.minCount,
               lhs.weightSum + rhs.weightSum};
}

// Lazy segment tree over positions 1..n supporting range add of
// (blocked, weight) and a query of the whole-range Agg.
class SegTree {
public:
    // Resets positions 1..n to blocked = 0, weight = 0.
    void build(int n) {
        n_ = std::max(n, 0);
        tree_.assign(4 * static_cast<std::size_t>(n_) + 4, Node{});
        if (n_ > 0) buildRec(1, n_, 1);  // n == 0 would otherwise recurse forever
    }

    // Adds dBlocked / dWeight to every position in [ql, qr] (must be non-empty
    // and inside [1, n]).
    void add(int ql, int qr, int dBlocked, int dWeight) {
        addRec(ql, qr, 1, n_, 1, dBlocked, dWeight);
    }

    // Aggregate over all positions 1..n.
    const Agg& root() const { return tree_[1].agg; }

private:
    struct Node {
        Agg agg;
        int lazyBlocked = 0;  // pending add for both children
        int lazyWeight = 0;
    };

    int n_ = 0;
    std::vector<Node> tree_;

    void buildRec(int l, int r, int node) {
        // Every position starts at blocked = 0, so all of them are minimal.
        tree_[node].agg.minCount = r - l + 1;
        if (l == r) return;
        const int mid = (l + r) / 2;
        buildRec(l, mid, 2 * node);
        buildRec(mid + 1, r, 2 * node + 1);
    }

    void applyTo(int node, int dBlocked, int dWeight) {
        Node& cur = tree_[node];
        cur.agg.apply(dBlocked, dWeight);
        cur.lazyBlocked += dBlocked;
        cur.lazyWeight += dWeight;
    }

    void pushDown(int node) {
        Node& cur = tree_[node];
        if (cur.lazyBlocked == 0 && cur.lazyWeight == 0) return;  // nothing pending
        applyTo(2 * node, cur.lazyBlocked, cur.lazyWeight);
        applyTo(2 * node + 1, cur.lazyBlocked, cur.lazyWeight);
        cur.lazyBlocked = 0;
        cur.lazyWeight = 0;
    }

    void addRec(int ql, int qr, int l, int r, int node, int dBlocked, int dWeight) {
        if (ql <= l && r <= qr) {
            applyTo(node, dBlocked, dWeight);
            return;
        }
        pushDown(node);
        const int mid = (l + r) / 2;
        if (ql <= mid) addRec(ql, qr, l, mid, 2 * node, dBlocked, dWeight);
        if (qr > mid) addRec(ql, qr, mid + 1, r, 2 * node + 1, dBlocked, dWeight);
        tree_[node].agg = combine(tree_[2 * node].agg, tree_[2 * node + 1].agg);
    }
};

// Adds `sign` times the contribution of tree edge (child, par).
// If par sits earlier in the permutation, positions [pos[par], pos[child]-1]
// gain one blocked edge; otherwise positions [pos[child], pos[par]-1] gain one
// unit of weight. Both ranges stay inside [1, n-1], so position n is never
// covered and the global minimum blocked count is always 0 when n >= 1.
void applyEdge(SegTree& seg, const std::vector<int>& pos, int child, int par, int sign) {
    if (pos[child] > pos[par]) {
        seg.add(pos[par], pos[child] - 1, sign, 0);
    } else {
        seg.add(pos[child], pos[par] - 1, 0, sign);
    }
}

// Moves child's edge from oldPar to newPar in the segment tree (no-op if equal).
void reattach(SegTree& seg, const std::vector<int>& pos, int child, int oldPar, int newPar) {
    if (oldPar == newPar) return;
    applyEdge(seg, pos, child, oldPar, -1);
    applyEdge(seg, pos, child, newPar, +1);
}

// Returns the parent of every vertex with the tree rooted at `root`
// (parent[root] == 0). Uses an explicit stack: a recursive DFS can overflow
// the call stack on a path-shaped tree with hundreds of thousands of vertices.
std::vector<int> rootTree(const std::vector<std::vector<int>>& adj, int root) {
    std::vector<int> parent(adj.size(), 0);
    if (root < 0 || root >= static_cast<int>(adj.size())) return parent;
    std::vector<int> todo{root};
    while (!todo.empty()) {
        const int u = todo.back();
        todo.pop_back();
        for (const int x : adj[u]) {
            if (x != parent[u]) {
                parent[x] = u;
                todo.push_back(x);
            }
        }
    }
    return parent;
}

// Reads the whole input, answers every query and prints one line per query.
void solve() {
    const int n = static_cast<int>(std::max<ll>(readInt(), 0));
    int q = readIndex();

    std::vector<std::vector<int>> adj(static_cast<std::size_t>(n) + 1);
    for (int i = 1; i < n; ++i) {
        const int x = readIndex();
        const int y = readIndex();
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    std::vector<int> pos(static_cast<std::size_t>(n) + 1, 0);  // pos[vertex] = permutation index
    for (int i = 1; i <= n; ++i) pos[readIndex()] = i;

    std::vector<int> parent = rootTree(adj, 1);

    SegTree seg;
    seg.build(n);
    for (int v = 2; v <= n; ++v) applyEdge(seg, pos, v, parent[v], +1);

    for (; q > 0; --q) {
        int a = readIndex();
        int b = readIndex();
        const int c = readIndex();
        const int d = readIndex();
        // Orient the removed edge {a, b} so that a is the child (parent[a] == b).
        if (parent[b] == a) std::swap(a, b);
        // Re-hang a below c, or below d when c is a itself.
        // NOTE(review): assumes the new edge {c, d} has a as an endpoint;
        // confirm against the problem statement.
        const int newPar = (c != a) ? c : d;
        reattach(seg, pos, a, parent[a], newPar);
        parent[a] = newPar;

        const Agg& all = seg.root();
        std::printf("%lld\n", (all.minBlocked != 0 ? 0LL : all.weightSum) + 1);
    }
}

}  // namespace

int main() {
    // freopen("bird.in", "r", stdin);
    // freopen("bird.out", "w", stdout);
    solve();
    std::fclose(stdin);
    std::fclose(stdout);
    return 0;
}