// Submit time: 2026-01-31 20:41:07
// Run ID: 39781
// Solution for the "matrix" task (reads matrix.in / writes matrix.out unless
// built with ONLINE_JUDGE defined).
//
// Reads integers n and m and prints, modulo 998244353,
//
//     Ans  = sum_{k=1}^{min(n,m)} (-1)^(k+1) * k * f(k)
//     f(k) = C(n,k) * C(m,k) * (k!)^2 * g(k)
//     g(k) = C(n*m, x) * prod_{i=1}^{k} A(x - S_{i-1} - i, p_i)
//
// where
//     x   = (n+m)*k - k*k   (cells covered by the union of k rows and k
//                            columns of an n-by-m grid),
//     p_i = n + m - 2k + 2i - 2,
//     S_j = p_1 + ... + p_j (S_0 = 0),
//     C(a,b) = a! / ((a-b)! b!),  A(a,b) = a! / (a-b)!  (both 0 if b<0 or b>a).
//
// Since sum_i p_i = x - k, each step i also removes one value that is not
// permuted by A, so k of the x chosen values are left out of the product.
// NOTE(review): this looks like an inclusion-exclusion over k distinguished
// row/column pairs; confirm the combinatorial meaning against the task text.
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>

using i64 = long long;

namespace {

constexpr i64 kMod = 998244353;

// Factorial / inverse-factorial tables modulo kMod on the index range [0, limit].
// Requires 0 <= limit < kMod (kMod is prime, so every 1..limit is invertible).
struct Binomials {
    std::vector<i64> fac;
    std::vector<i64> ifac;

    explicit Binomials(i64 limit)
        : fac(static_cast<std::size_t>(limit) + 1, 1),
          ifac(static_cast<std::size_t>(limit) + 1, 1) {
        // Pass 1: ifac[i] = i^{-1} via the linear recurrence
        // inv(i) = -(kMod / i) * inv(kMod % i)  (mod kMod); kMod % i < i.
        for (i64 i = 2; i <= limit; ++i) {
            ifac[i] = ifac[kMod % i] * (kMod - kMod / i) % kMod;
        }
        // Pass 2: prefix products turn the inverses into inverse factorials.
        for (i64 i = 2; i <= limit; ++i) {
            fac[i] = fac[i - 1] * i % kMod;
            ifac[i] = ifac[i] * ifac[i - 1] % kMod;
        }
    }

    // Binomial coefficient C(a, b) mod kMod; 0 when b < 0 or b > a.
    i64 C(i64 a, i64 b) const {
        if (a < b || b < 0) return 0;
        return fac[a] * ifac[a - b] % kMod * ifac[b] % kMod;
    }

    // Falling factorial A(a, b) = a! / (a-b)! mod kMod; 0 when b < 0 or b > a.
    i64 A(i64 a, i64 b) const {
        if (a < b || b < 0) return 0;
        return fac[a] * ifac[a - b] % kMod;
    }
};

}  // namespace

int main() {
#ifndef ONLINE_JUDGE
    std::freopen("matrix.in", "r", stdin);
    std::freopen("matrix.out", "w", stdout);
#endif
    i64 n = 0;
    i64 m = 0;
    std::cin >> n >> m;

    // Every argument passed to C/A below is at most n*m (x never exceeds the
    // grid size) or at most n / m; size the tables to exactly that.
    const Binomials bin(std::max({n * m, n, m, static_cast<i64>(0)}));

    const i64 k_max = std::min(n, m);
    i64 ans = 0;
    for (i64 k = 1; k <= k_max; ++k) {
        const i64 x = (n + m) * k - k * k;

        // g(k): choose the x values, then run the k-step product.
        i64 g = bin.C(n * m, x);
        i64 used = 0;  // S_{i-1}: total of p_j taken by earlier steps
        for (i64 i = 1; i <= k; ++i) {
            const i64 p = n + m - 2 * k + 2 * i - 2;
            g = g * bin.A(x - used - i, p) % kMod;
            used += p;
        }

        // f(k) = C(n,k) * C(m,k) * k! * k! * g(k)
        const i64 f = bin.C(n, k) * bin.C(m, k) % kMod
                      * bin.fac[k] % kMod
                      * bin.fac[k] % kMod
                      * g % kMod;

        // Alternating, k-weighted accumulation: + for odd k, - for even k.
        const i64 term = f * k % kMod;
        ans = (k & 1) ? (ans + term) % kMod : (ans + kMod - term) % kMod;
    }
    std::cout << ans << '\n';
    return 0;
}