提交时间:2024-10-20 16:49:45

运行 ID: 33739

#include <bits/stdc++.h>

// Reads a graph with n nodes and n-1 undirected edges (presumably a tree --
// confirm against the problem statement), then n-1 node ids, one per step.
// Each step marks one more node and prints, on its own line, the product
// modulo 998244353 of f[r][1] over every connected group of marked nodes,
// where r is the lowest-numbered node of the group.

using ll = long long;  // 64-bit type for all modular arithmetic

const ll mod = 998244353;
const int N = 5005;  // capacity of the static tables; n must stay below it

ll a[N];               // a[i] = 2^i mod `mod`
ll f[N][2];            // two DP states per node, filled by dfs()
bool mk[N];            // mk[v]: node v has been marked in this or an earlier step
bool vis[N];           // vis[v]: node v was reached by dfs() in the current step
std::vector<int> g[N]; // adjacency lists
int n;

// Fills f[u][0] and f[u][1] for the marked node u.
//
// Neighbours that were already visited are skipped. Marked neighbours are
// recursed into and merged into u's states one at a time. Unmarked neighbours
// are never entered; they are only counted and folded in once at the end.
void dfs(int u) {
    int c = 0;  // number of unmarked neighbours of u
    vis[u] = true;
    f[u][0] = 1;
    f[u][1] = 0;

    for (int v : g[u]) {
        if (vis[v]) continue;
        if (!mk[v]) {
            ++c;
            continue;
        }
        dfs(v);
        const ll t0 = f[u][0];
        const ll t1 = f[u][1];
        // Every factor is < mod, so each product is < mod^2 (~1e18); the
        // largest sum below is under 4 * mod^2, which fits in a signed 64-bit.
        f[u][1] = (t1 * f[v][0] + t0 * f[v][1] + t1 * f[v][1] * 2) % mod;
        f[u][0] = (t0 * f[v][0] + t0 * f[v][1] * 2) % mod;
    }

    if (c) {
        // Fold in the c unmarked neighbours using the powers of two a[c], a[c-1].
        f[u][1] = (f[u][1] * a[c] % mod + f[u][0] * c % mod * a[c - 1] % mod) % mod;
        f[u][0] = f[u][0] * a[c] % mod;
    }
}

int main() {
    // Only iostreams are used, so the C stdio sync and the cin/cout tie can go.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> n;
    for (int i = 1; i < n; i++) {
        int x, y;
        std::cin >> x >> y;
        g[x].push_back(y);
        g[y].push_back(x);
    }

    a[0] = 1;
    for (int i = 1; i <= n; i++) {
        a[i] = (a[i - 1] * 2) % mod;
    }

    for (int i = 1; i < n; i++) {
        // Marks accumulate across steps; visitation is per step.
        std::memset(vis, 0, sizeof vis);
        ll ans = 1;
        int x;
        std::cin >> x;
        mk[x] = true;
        for (int j = 1; j <= n; j++) {
            if (mk[j] && !vis[j]) {
                dfs(j);
                ans = ans * f[j][1] % mod;
            }
        }
        // '\n' instead of std::endl: same output bytes, no flush per query.
        std::cout << ans << '\n';
    }
}