提交时间:2025-10-27 15:03:42
运行 ID: 38768
// Counts, modulo 998244353, weighted sequences a_1..a_L of positive integers with
// |a_{t+1} - a_t| = 1 and a_1 + ... + a_L = n. Every step from v to v-1 contributes
// a factor w; every step from v to v+1 contributes a factor 1.
// NOTE(review): the problem statement is not in this file; this reading is
// reconstructed from the transitions below -- confirm against the original task.
#include <bits/stdc++.h>
#define up(i, l, r) for (int i = (l); i <= (r); ++i)
using namespace std;
typedef long long ll;

// g/h below are indexed by offsets in [-n, n], so this program requires n < maxn.
// B splits the count: first element <= B is handled by a DP over (sum, value);
// first element > B is handled by enumerating offset walks of at most B + 6 elements.
const int maxn = 2e5 + 10, B = 650;
const int mod = 998244353;

// Reads an integer from stdin, skipping leading non-digit characters.
// A '-' seen while skipping makes the result negative.
// Returns 0 if EOF is reached before any digit.
inline ll read() {
    ll x = 0;
    short t = 1;
    int ch = getchar();  // int, not char: EOF must remain distinguishable
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') t = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
    return x * t;
}

int A[maxn * 4];
int n, w, f[1005][1005];
// g and h point into A so that g[j] / h[j] are valid for j in [-n, n] (n < maxn).
int *g = A + maxn, *h = A + maxn * 3;

// Modular addition; both operands must already be in [0, mod).
inline int add(int a, int b) {
    if ((a += b) >= mod) a -= mod;
    return a;
}

// Sequences whose first element is <= B (this includes the one-element
// sequence a_1 = n when n <= B).
// f[s % 1000][v] = weighted number of prefixes with sum s whose last value is v.
// From (s, v) a step goes to (s + v + 1, v + 1) with weight 1, or to
// (s + v - 1, v - 1) with weight w. Value 0 is written but never extended or
// counted, which enforces positivity.
// NOTE(review): reusing 1000 rows cyclically is only safe while v < 999.
// Reaching such values from a_1 <= 650 would need a sum far above 2e5; this
// assumes n < maxn.
static int count_small_start() {
    up(i, 1, min(n, B)) f[i][i] = 1;
    up(i, 1, n) {
        int o = i % 1000, nx1 = (o + 1) % 1000, nx2 = (o + 999) % 1000;
        up(j, 1, min(i, 1000)) {
            if ((++nx1) >= 1000) nx1 = 0;  // nx1 == (i + j + 1) % 1000
            if ((++nx2) >= 1000) nx2 = 0;  // nx2 == (i + j - 1) % 1000
            f[nx1][j + 1] = add(f[nx1][j + 1], f[o][j]);
            f[nx2][j - 1] = (f[nx2][j - 1] + f[o][j] * 1ll * w) % mod;
        }
        // Row o will next hold sum i + 1000; keep it only if it is the answer row.
        if (i != n) memset(f[o], 0, sizeof(f[o]));
    }
    int res = 0;
    up(i, 1, 1000) res = add(res, f[n % 1000][i]);
    return res;
}

// Sequences whose first element is > B and that have at least two elements.
// Write a_t = a_1 + d_t. The step that is i-th from the end shifts exactly i
// elements by +1 (weight 1) or -1 (weight w). So after processing steps 1..i,
// g[s] is the weighted number of (i+1)-element offset patterns whose offsets
// sum to s. Such a pattern sums to n iff (n - s) is divisible by i + 1,
// giving a_1 = (n - s) / (i + 1), which must exceed B.
// Offsets outside [-n, n] are dropped.
// NOTE(review): B + 5 steps is assumed to cover every admissible length for n < maxn.
static int count_large_start() {
    int res = 0;
    g[0] = 1;
    up(i, 1, B + 5) {
        up(j, -n, n) h[j] = 0;
        up(j, -n, n) if (g[j]) {
            if (j + i <= n) h[j + i] = add(h[j + i], g[j]);
            if (j - i >= -n) h[j - i] = (h[j - i] + g[j] * 1ll * w) % mod;
        }
        up(j, -n, n) g[j] = h[j];
        up(j, -n, n) if ((n - j) % (i + 1) == 0) {
            int k = (n - j) / (i + 1);  // candidate first element a_1
            if (k > B) res = add(res, g[j]);
        }
    }
    return res;
}

// Reads n and w, then prints the total weighted count modulo mod.
void slv() {
    n = read();
    // Normalise w into [0, mod) so add() keeps its precondition and products stay
    // in range even for negative or oversized input.
    w = (int)((read() % mod + mod) % mod);
    if (n <= 0) {  // no non-empty sequence of positive integers sums to n <= 0
        cout << 0 << '\n';
        return;
    }
    int res = count_small_start();
    if (n > B) res = add(res, 1);  // the one-element sequence a_1 = n with n > B
    res = add(res, count_large_start());
    cout << res << '\n';
}

int main() {
    // freopen("dazzling.in", "r", stdin), freopen("dazzling.out", "w", stdout);
    slv();
    return 0;
}