提交时间:2025-02-09 20:22:21

运行 ID: 36188

// Reads two trees on the same n vertices from stdin: first n, then the n-1
// edges of tree T, then the n-1 edges of tree U (vertex labels are assumed to
// lie in [0, n] -- NOTE(review): confirm labels are 1-based as in the original).
//
// Prints, modulo 998244353, S / 2^n where S is the sum of three terms. Each term
// is (number of ordered pairs of objects) * 2^(n-k), with k the number of
// distinct vertices the pair covers, i.e. the number of vertex subsets that
// contain the whole pair:
//   + vertex x vertex : n(n-1) pairs of distinct vertices,             2^(n-2)
//   - vertex x edge   : (n-1)(n-2) pairs (vertex, edge not touching it),
//                       counted once for T and once for U,             2^(n-3)
//   + edge x edge     : pairs (T-edge, U-edge) sharing no endpoint,    2^(n-4)
// Dividing by 2^n turns S into an average over a uniformly random vertex subset.
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

namespace {

using i64 = long long;
constexpr i64 kMod = 998244353;

// Returns base^exp mod kMod by binary exponentiation; exp must be >= 0.
i64 pow_mod(i64 base, i64 exp) {
    i64 result = 1;
    base %= kMod;
    while (exp > 0) {
        if (exp & 1) result = result * base % kMod;
        base = base * base % kMod;
        exp >>= 1;
    }
    return result;
}

}  // namespace

int main() {
    i64 n = 0;
    if (std::scanf("%lld", &n) != 1 || n < 1) return 0;

    // pow2[k] = 2^k mod kMod for 0 <= k <= n.
    std::vector<i64> pow2(n + 1, 1);
    for (i64 k = 1; k <= n; ++k) pow2[k] = pow2[k - 1] * 2 % kMod;
    // A negative exponent only arises when n is too small for the pair to exist,
    // where the coefficient is 0 anyway; returning 0 avoids indexing before the
    // start of the table (the original read p[-1], p[-2] for n <= 3).
    auto pw = [&pow2](i64 k) -> i64 { return k < 0 ? 0 : pow2[k]; };

    // Tree T: only its edge list is needed.
    std::vector<std::pair<i64, i64>> t_edges;
    t_edges.reserve(n - 1);
    for (i64 i = 0; i + 1 < n; ++i) {
        i64 a = 0, b = 0;
        if (std::scanf("%lld%lld", &a, &b) != 2) return 0;
        t_edges.emplace_back(a, b);
    }

    // Tree U: degrees and adjacency lists (sorted once, for edge lookups).
    std::vector<i64> deg_u(n + 1, 0);
    std::vector<std::vector<i64>> adj_u(n + 1);
    for (i64 i = 0; i + 1 < n; ++i) {
        i64 a = 0, b = 0;
        if (std::scanf("%lld%lld", &a, &b) != 2) return 0;
        ++deg_u[a];
        ++deg_u[b];
        adj_u[a].push_back(b);
        adj_u[b].push_back(a);
    }
    // Sorting once here replaces the original per-query re-sort, which was
    // O(deg log deg) on every lookup.
    for (auto& neighbours : adj_u) std::sort(neighbours.begin(), neighbours.end());

    // vertex x vertex
    i64 ans = n * (n - 1) % kMod * pw(n - 2) % kMod;

    // vertex x edge, for T and for U: sum_i (n-1-deg(i)) = n(n-1) - 2(n-1) = (n-1)(n-2).
    const i64 vertex_edge = 2 * (n - 1) * (n - 2) % kMod * pw(n - 3) % kMod;
    ans = (ans - vertex_edge + kMod) % kMod;

    // edge x edge
    for (const auto& edge : t_edges) {
        const i64 a = edge.first;
        const i64 b = edge.second;
        const bool also_in_u =
            std::binary_search(adj_u[a].begin(), adj_u[a].end(), b);
        // Inclusion-exclusion: U edges touching a or b are deg_u[a] + deg_u[b],
        // minus one if (a,b) itself is a U edge (it was counted at both ends).
        // The result is non-negative for valid trees.
        const i64 disjoint_u_edges =
            (n - 1) - deg_u[a] - deg_u[b] + (also_in_u ? 1 : 0);
        ans = (ans + disjoint_u_edges * pw(n - 4) % kMod) % kMod;
    }

    // Divide by 2^n via Fermat's little theorem (kMod is prime).
    ans = ans * pow_mod(pow2[n], kMod - 2) % kMod;
    std::printf("%lld\n", ans);
    return 0;
}