提交时间:2023-12-03 21:09:33
运行 ID: 23904
#include <bits/stdc++.h> #define LL long long #define PII pair<int,int> using namespace std; const int MAXN=1e5+5; LL dp[MAXN][4];//0-m,1-w&down/w&indis,2-w&!down,3-no LL ans[MAXN]; int N,Q,K,L; vector<int>gra[MAXN]; int m[MAXN],w[MAXN],f[MAXN]; vector<int>vec[MAXN]; void dpfs(int u,int fa,int dis) { dp[u][0]=m[u]; if(dis<=L) dp[u][1]=w[u]; else dp[u][2]=w[u]; dp[u][3]=0; vec[dis].push_back(u); f[u]=fa; for(int v:gra[u]) { if(v==fa) continue; dpfs(v,u,dis+1); dp[u][0]+=max(max(dp[v][0],dp[v][1]),dp[v][3]); dp[u][1]=max(dp[u][1]+max(max(dp[v][0],dp[v][1]),dp[v][3]),dp[u][2]+dp[v][3]); if(dis>L) dp[u][2]+=max(dp[v][0],dp[v][1]); dp[u][3]+=max(max(dp[v][0],dp[v][1]),max(dp[v][2],dp[v][3])); } return; } int main() { // freopen("tree.in","r",stdin); // freopen("tree.out","w",stdout); scanf("%d %d",&N,&Q); for(int i=1;i<=N;i++) scanf("%d",&m[i]); for(int i=1;i<=N;i++) scanf("%d",&w[i]); scanf("%d",&K); for(int i=1;i<N;i++) { int u,v; scanf("%d %d",&u,&v); gra[u].push_back(v); gra[v].push_back(u); } L=0; for(int i=1;i<=N;i++) for(int j=0;j<4;j++) dp[i][j]=-1e17; dpfs(K,0,0); LL x=0; ans[0]=max(max(dp[K][0],dp[K][1]),dp[K][3]); for(int i=1;i<=N;i++) { for(int j:vec[i-1]) x+=max(m[j],w[j]); ans[i]=x; for(int j:vec[i]) { LL res[4]; res[0]=m[j]; res[1]=w[j]; res[3]=0; for(int v:gra[j]) { if(v==f[j]) continue; res[0]+=max(max(dp[v][0],dp[v][1]),dp[v][3]); res[1]+=max(max(dp[v][0],dp[v][1]),dp[v][3]); res[3]+=max(max(dp[v][0],dp[v][1]),max(dp[v][2],dp[v][3])); } ans[i]+=max(res[0],max(res[1],res[3])); } } while(Q--) { scanf("%d",&L); printf("%lld\n",ans[L]); } return 0; }