Run ID | 作者 | 问题 | 语言 | 测评结果 | 分数 | 时间 | 内存 | 代码长度 | 提交时间 |
---|---|---|---|---|---|---|---|---|---|
29484 | LYLAKIOI | 【BJ】T3 | C++ | 运行超时 | 50 | 1000 MS | 29292 KB | 3152 | 2024-05-08 16:57:07 |
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;
typedef long long ll;

// Centroid-decomposition solution.  For a directed tree path x = s_0, ..., s_k = y
// define G(x -> y) = sum_i a[s_i] * K^i (mod P); the path is "good" when G == R.
// For every vertex y the program accumulates
//   out[y][0] = #z with G(y -> z) == R,   out[y][1] = #z with G(y -> z) != R,
//   in[y][0]  = #z with G(z -> y) == R,   in[y][1]  = #z with G(z -> y) != R
// (z ranges over every vertex reachable from vertex 1, z == y included) and
// prints n^3 - res/2, with res combining those counts as in slv().
//
// NOTE(review): assumes P is prime and K % P != 0 (so pw_inv[i] is a true
// inverse of K^i) and 0 <= a[i], R < P -- confirm against the problem statement.
//
// Why this shape: a std::map-based version with recursive DFS exceeded the
// 1000 ms limit; counting here uses sorted vectors + binary search, and all
// traversals are iterative (no recursion depth proportional to n).

const int maxn = 1e5 + 10;

int n, P, K, R;
int a[maxn];
vector<int> v[maxn];             // adjacency lists
bool vis[maxn];                  // vertex already used as a centroid
int par[maxn], dep[maxn];        // BFS parent / depth relative to current centroid
int siz[maxn], mxs[maxn];        // subtree size / largest child subtree (centroid search)
int h[maxn];                     // G(y -> c) for the current centroid c
int g[maxn];                     // G(c -> y) with c excluded, exponents starting at 0
int pw[maxn], pw_inv[maxn];      // K^i and K^-i (mod P)
int in[maxn][2], out[maxn][2];   // see header comment

// Scratch buffers reused by count_through(); the decomposition is iterative,
// so only one centroid is ever being processed at a time.
vector<int> comp, need, all_g, all_need, sub_g, sub_need;
vector<size_t> seg;

// Reads one (optionally negative) decimal integer from stdin, skipping any
// other characters before it.  Returns 0 if EOF is reached before a digit.
inline ll read() {
    ll x = 0;
    int sign = 1;
    int ch = getchar();  // int, not char, so EOF is distinguishable
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * sign;
}

// Returns a^b mod `mod` by binary exponentiation (b >= 0).
int qpow(int a, int b, int mod) {
    int res = 1;
    while (b) {
        if (b & 1) res = int(res * 1ll * a % mod);
        a = int(a * 1ll * a % mod);
        b >>= 1;
    }
    return res;
}

// Number of occurrences of `key` in the ascending-sorted vector `s`.
static int count_eq(const vector<int>& s, int key) {
    const auto range = equal_range(s.begin(), s.end(), key);
    return int(range.second - range.first);
}

// Appends to `nodes`, in BFS order (every vertex after its parent), the
// unvisited vertices reachable from `root` without stepping back to `from`.
// Records each appended vertex's BFS parent in par[] (par[root] = from).
static void gather(int root, int from, vector<int>& nodes) {
    size_t head = nodes.size();
    par[root] = from;
    nodes.push_back(root);
    for (; head < nodes.size(); ++head) {
        const int x = nodes[head];
        for (int y : v[x]) {
            if (y != par[x] && !vis[y]) {
                par[y] = x;
                nodes.push_back(y);
            }
        }
    }
}

// Returns a vertex minimising the largest remaining piece when removed from
// the tree stored in nodes[lo, hi) (BFS order from gather(), nodes[lo] = root).
static int find_centroid(const vector<int>& nodes, size_t lo, size_t hi) {
    const int total = int(hi - lo);
    for (size_t i = lo; i < hi; ++i) {
        siz[nodes[i]] = 1;
        mxs[nodes[i]] = 0;
    }
    int best = nodes[lo];
    int best_load = total;
    // Reverse BFS order: all children of y are finished before y is examined.
    for (size_t i = hi; i-- > lo;) {
        const int y = nodes[i];
        const int load = max(mxs[y], total - siz[y]);
        if (load < best_load) {
            best_load = load;
            best = y;
        }
        if (i != lo) {
            siz[par[y]] += siz[y];
            mxs[par[y]] = max(mxs[par[y]], siz[y]);
        }
    }
    return best;
}

// Marks c as a centroid and, for every vertex y of c's component, adds to
// out[y]/in[y] the partners z of that component that lie outside y's own
// child subtree of c (c itself is always a partner; for y == c every vertex
// is).  Pushes the centroids of the remaining sub-components onto `work`.
static void count_through(int c, vector<int>& work) {
    vis[c] = true;
    dep[c] = 0;
    h[c] = a[c];
    g[c] = 0;

    // comp[0] = c; child subtree s occupies comp[seg[s], seg[s + 1]).
    comp.clear();
    comp.push_back(c);
    seg.clear();
    seg.push_back(1);
    for (int x : v[c]) {
        if (!vis[x]) {
            gather(x, c, comp);
            seg.push_back(comp.size());
        }
    }
    const int total = int(comp.size());

    // BFS order guarantees par[y] is finalised before y.
    for (int i = 1; i < total; ++i) {
        const int y = comp[i], p = par[y];
        dep[y] = dep[p] + 1;
        h[y] = int((h[p] * 1ll * K % P + a[y]) % P);
        g[y] = int((g[p] + a[y] * 1ll * pw[dep[y] - 1] % P) % P);
    }

    // G(y -> z) = h[y] + K^(dep[y]+1) * g[z]  (mod P), so G(y -> z) == R
    // exactly when g[z] == need[y] := (R - h[y]) * K^-(dep[y]+1).
    need.resize(total);
    all_g.resize(total);
    for (int i = 0; i < total; ++i) {
        const int y = comp[i];
        need[i] = int(((ll)R - h[y] + P) % P * pw_inv[dep[y] + 1] % P);
        all_g[i] = g[y];
    }
    all_need = need;
    sort(all_g.begin(), all_g.end());
    sort(all_need.begin(), all_need.end());

    // The centroid pairs with every vertex of its component (itself included).
    {
        const int hit_out = count_eq(all_g, need[0]);
        out[c][0] += hit_out;
        out[c][1] += total - hit_out;
        const int hit_in = count_eq(all_need, g[c]);
        in[c][0] += hit_in;
        in[c][1] += total - hit_in;
    }

    // Other vertices pair with everything outside their own child subtree:
    // whole-component matches minus same-subtree matches.
    for (size_t s = 0; s + 1 < seg.size(); ++s) {
        const size_t lo = seg[s], hi = seg[s + 1];
        sub_g.clear();
        sub_need.clear();
        for (size_t i = lo; i < hi; ++i) {
            sub_g.push_back(g[comp[i]]);
            sub_need.push_back(need[i]);
        }
        sort(sub_g.begin(), sub_g.end());
        sort(sub_need.begin(), sub_need.end());

        const int outside = total - int(hi - lo);
        for (size_t i = lo; i < hi; ++i) {
            const int y = comp[i];
            const int hit_out = count_eq(all_g, need[i]) - count_eq(sub_g, need[i]);
            out[y][0] += hit_out;
            out[y][1] += outside - hit_out;
            const int hit_in = count_eq(all_need, g[y]) - count_eq(sub_need, g[y]);
            in[y][0] += hit_in;
            in[y][1] += outside - hit_in;
        }
    }

    // Sub-components are disjoint, so their centroids can all be found now.
    for (size_t s = 0; s + 1 < seg.size(); ++s)
        work.push_back(find_centroid(comp, seg[s], seg[s + 1]));
}

// Runs the centroid decomposition of the component containing vertex 1
// with an explicit work stack instead of recursion.
static void decompose() {
    comp.clear();
    gather(1, 0, comp);
    vector<int> work{find_centroid(comp, 0, comp.size())};
    while (!work.empty()) {
        const int c = work.back();
        work.pop_back();
        count_through(c, work);
    }
}

// Reads n P K R, the n vertex values and the n-1 edges from stdin, then
// prints n^3 - res/2 (no trailing newline).
void slv() {
    n = int(read());
    P = int(read());
    K = int(read());
    R = int(read());

    pw[0] = 1;
    for (int i = 1; i <= n; ++i) pw[i] = int(pw[i - 1] * 1ll * K % P);
    // (K^i)^(P-2) == (K^(P-2))^i (mod P): one exponentiation instead of n+1.
    const int inv_k = qpow(K, P - 2, P);
    pw_inv[0] = 1;
    for (int i = 1; i <= n; ++i) pw_inv[i] = int(pw_inv[i - 1] * 1ll * inv_k % P);

    for (int i = 1; i <= n; ++i) a[i] = int(read());
    for (int i = 1; i < n; ++i) {
        const int x = int(read());
        const int y = int(read());
        v[x].push_back(y);
        v[y].push_back(x);
    }

    if (n >= 1) decompose();

    ll res = 0;
    for (int i = 1; i <= n; ++i) {
        res += in[i][0] * 2ll * in[i][1];
        res += out[i][0] * 2ll * out[i][1];
        res += in[i][0] * 1ll * out[i][1] + in[i][1] * 1ll * out[i][0];
    }
    cout << (n * 1ll * n * 1ll * n - res / 2);
}

int main() {
    slv();
    return 0;
}