提交时间:2024-10-20 16:07:06

运行 ID: 33727

#include <bits/stdc++.h>
using namespace std;

// Array capacity (vertex ids are used directly as indices) and the modulus
// used for every arithmetic result.
const int N = 5050, mod = 998244353;

bool mk[N];        // mk[v]: vertex v has been marked by one of the queries read so far
int n;             // number of vertices
vector<int> E[N];  // adjacency lists built from the n-1 input edges

// a = (a + b) mod `mod`. Both operands are expected to already be in [0, mod),
// so a single conditional subtraction is enough.
void plu(int &a, int b) {
    a += b;
    if (a >= mod) a -= mod;
}

// a = (a * b) mod `mod`; the product is formed in 64 bits.
void mult(int &a, int b) {
    a = 1ll * a * b % mod;
}

int f[N];        // f[u]: value returned by dfs(u, .) during the current query
bool vis[N];     // vis[u]: u was reached by dfs during the current query
int g[N], h[N];  // per-vertex DP values filled in by dfs1

// Walks the marked vertices reachable from `u` (never stepping back to `fa`)
// and counts adjacency entries that lead to an unmarked vertex.
// Stores the count for u's side of the walk in f[u] and returns it.
// Relies on the parent check alone to avoid revisiting, i.e. assumes the
// graph is a tree (n-1 edges are read) -- TODO confirm against the statement.
int dfs(int u, int fa) {
    int cnt = 0;
    vis[u] = 1;
    for (int v : E[u]) {
        if (v == fa) continue;
        if (!mk[v]) {
            ++cnt;
        } else {
            cnt += dfs(v, u);
        }
    }
    f[u] = cnt;
    return cnt;
}

int tmp = 0;  // dfs() result for the root of the component currently being processed
int pw2[N];   // pw2[i] = 2^i mod `mod`
int ans = 0;  // running product printed after each query

// 2^a mod `mod` for a >= 0, and 0 for negative a.
int g2(int a) {
    if (a >= 0) return pw2[a];
    return 0;
}

// Computes g[u] and h[u] over the marked component hanging below `u`
// (parent `fa`), starting from g = 1, h = 0 and folding in each neighbour:
//   unmarked neighbour:  h = 2*h + g;                      g = 2*g
//   marked child v:      h = h*(2*h[v]+g[v]) + g*h[v];     g = g*(g[v]+2*h[v])
// (in both cases the old g[u] is used when updating h[u]).
// The combinatorial meaning of g/h is not stated in the source.
//
// Returns a secondary sum: the children's return values plus, for every
// non-root vertex with 0 < f[u] < tmp, the term
//   2 * 2^(f[u]-1) * f[u] * 2^(tmp-f[u]-1) * (tmp-f[u])   (mod `mod`).
// The only caller in this file ignores the return value.
int dfs1(int u, int fa) {
    int cnt = 0;
    g[u] = 1, h[u] = 0;
    for (int v : E[u]) {
        if (v == fa) continue;
        if (mk[v]) {
            // Recurse exactly once per child. The original recursed twice
            // here, which is exponential in the depth of the component.
            const int sub = dfs1(v, u);
            // 64-bit sum: 2*h[v] + g[v] can reach ~3e9 and overflow int.
            const int factor = (2ll * h[v] + g[v]) % mod;
            mult(h[u], factor);
            plu(h[u], 1ll * g[u] * h[v] % mod);  // uses g[u] before it is updated
            mult(g[u], factor);
            plu(cnt, sub);
        } else {
            mult(h[u], 2);
            plu(h[u], g[u]);
            mult(g[u], 2);
        }
    }
    // The root is called with fa == u, so it never contributes a term here.
    if (u != fa && f[u] != 0 && f[u] != tmp) {
        const long long below = 1ll * g2(f[u] - 1) * f[u] % mod;
        const long long above = 1ll * g2(tmp - f[u] - 1) * (tmp - f[u]) % mod;
        plu(cnt, below * above % mod * 2 % mod);
    }
    return cnt;
}

// Reads n and n-1 edges, then n-1 vertices to mark one at a time. After each
// marking, prints the product (mod `mod`) of h[root] over all connected
// components of marked vertices, where root is the smallest-numbered vertex
// of the component.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n;
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        E[u].push_back(v);
        E[v].push_back(u);
    }

    pw2[0] = 1;
    for (int i = 1; i <= n; i++) pw2[i] = 1ll * pw2[i - 1] * 2 % mod;

    for (int i = 1; i < n; i++) {
        int x;
        cin >> x;
        mk[x] = 1;

        // Everything is recomputed from scratch for the new marked set.
        memset(f, 0, sizeof(f));
        memset(vis, 0, sizeof(vis));
        ans = 1;
        for (int j = 1; j <= n; j++) {
            if (!vis[j] && mk[j]) {
                // 0 is used as the "no parent" sentinel; vertices are
                // iterated from 1, so it never matches a real neighbour.
                tmp = dfs(j, 0);
                dfs1(j, j);
                mult(ans, h[j]);
            }
        }
        cout << ans << '\n';
    }
    return 0;
}