提交时间:2025-10-27 19:12:11
运行 ID: 38784
#include <algorithm>
#include <cmath>
#include <iostream>
#include <utility>
#include <vector>

// NOTE(review): the problem statement is not part of this file. The
// descriptions below are reconstructed from the recurrences only: a walk is a
// sequence of jump lengths whose lengths change by exactly +1 (weight 1) or
// -1 (weight w) between consecutive jumps, and the program sums the weights
// of the walks that land exactly on position n. Confirm against the statement.

namespace {

using ll = long long;

constexpr ll kMod = 998244353;

/// Returns floor(sqrt(n)) for n >= 0, correcting any floating-point rounding.
ll floor_sqrt(ll n) {
    ll r = static_cast<ll>(std::sqrt(static_cast<long double>(n)));
    while (r > 0 && r * r > n) {
        --r;
    }
    while ((r + 1) * (r + 1) <= n) {
        ++r;
    }
    return r;
}

/// Phase 1: walks whose first jump length is in [1, m].
///
/// State (pos, len) = weighted count of walks currently standing on `pos`
/// whose most recent jump had length `len`. Only lengths 1..2m are tracked,
/// and only a window of 3m consecutive positions is kept alive (ring buffer
/// indexed by pos % 3m), exactly as in the original fixed-array version.
///
/// @param n target position (n >= 1)
/// @param w weight of a "-1" length change, already reduced into [0, kMod)
/// @param m floor(sqrt(n)), m >= 1
/// @return  sum of the states sitting on position n, modulo kMod
ll count_small_first_jump(ll n, ll w, ll m) {
    const ll ring = 3 * m;
    const ll max_len = 2 * m;

    // Column max_len + 1 exists so the "+1" transition from len == max_len
    // has somewhere to land; it is never read back.
    std::vector<std::vector<ll>> ways(
        static_cast<std::size_t>(ring),
        std::vector<ll>(static_cast<std::size_t>(max_len + 2), 0));

    // A first jump of length len lands on position len.
    for (ll len = 1; len <= m; ++len) {
        ways[len % ring][len] = 1;
    }

    for (ll pos = 1; pos < n; ++pos) {
        std::vector<ll>& row = ways[pos % ring];
        for (ll len = 1; len <= max_len; ++len) {
            const ll cur = row[len];
            if (cur == 0) {
                continue;
            }
            // Shrink the jump by one (never down to length 0), weight w.
            if (len > 1 && pos + len - 1 <= n) {
                ll& dst = ways[(pos + len - 1) % ring][len - 1];
                dst = (dst + cur * w % kMod) % kMod;
            }
            // Grow the jump by one, weight 1.
            if (pos + len + 1 <= n) {
                ll& dst = ways[(pos + len + 1) % ring][len + 1];
                dst = (dst + cur) % kMod;
            }
        }
        // This ring slot will be reused for position pos + 3m.
        std::fill(row.begin(), row.end(), ll{0});
    }

    ll total = 0;
    const std::vector<ll>& last_row = ways[n % ring];
    for (ll len = 1; len <= max_len; ++len) {
        total = (total + last_row[len]) % kMod;
    }
    return total;
}

/// Phase 2: walks whose first jump length is greater than m.
///
/// Only walks with at most 2m jumps are enumerated (same bound as the
/// original). After `changes` length changes the walk has made
/// `changes + 1` jumps and stands on
///     first_len * (changes + 1) + offset,
/// where `offset` accumulates +-1, +-2, ..., +-changes (the changes are
/// processed from the last one, which affects one jump, backwards). A state
/// contributes when that position equals n for an integer first_len > m.
///
/// @param n target position (n >= 1)
/// @param w weight of a "-1" length change, already reduced into [0, kMod)
/// @param m floor(sqrt(n)), m >= 1
/// @return  weighted count modulo kMod
ll count_large_first_jump(ll n, ll w, ll m) {
    const ll max_jumps = 2 * m;
    // Largest |offset| that is ever read: 1 + 2 + ... + (2m - 1).
    const ll reach = m * (2 * m - 1);
    const std::size_t width = static_cast<std::size_t>(2 * reach + 1);

    std::vector<ll> cur(width, 0);
    std::vector<ll> nxt(width, 0);
    cur[reach] = 1;  // offset 0, no changes yet

    ll total = 0;
    ll span = 0;  // current maximum |offset|
    for (ll changes = 0; changes < max_jumps; ++changes) {
        span += changes;
        const ll jumps = changes + 1;
        // The original also expanded the final layer, whose results are never
        // read; doing so indexed g[][] at 2(n - m*m) - m, which is negative
        // for n close to a perfect square (n = 1, 4, 9, ...). Skip it.
        const bool is_last_layer = (jumps == max_jumps);

        for (ll off = -span; off <= span; ++off) {
            const ll cnt = cur[off + reach];
            if (cnt == 0) {
                continue;
            }
            const ll rest = n - off;
            if (rest % jumps == 0 && rest / jumps > m) {
                total = (total + cnt) % kMod;
            }
            if (is_last_layer) {
                continue;
            }
            if (off + jumps <= n) {
                ll& up = nxt[off + jumps + reach];
                up = (up + cnt) % kMod;
            }
            ll& down = nxt[off - jumps + reach];
            down = (down + cnt * w % kMod) % kMod;
        }

        std::fill(cur.begin() + (reach - span),
                  cur.begin() + (reach + span + 1), ll{0});
        std::swap(cur, nxt);
    }
    return total;
}

/// Combines both phases. Returns 0 for n < 1, where m would be 0 and the
/// original divided by zero.
ll count_ways(ll n, ll w) {
    if (n < 1) {
        return 0;
    }
    // Keep every product cnt * w below 2^63 and every residue non-negative.
    w = ((w % kMod) + kMod) % kMod;
    const ll m = floor_sqrt(n);
    return (count_small_first_jump(n, w, m) + count_large_first_jump(n, w, m)) % kMod;
}

}  // namespace

/// Reads `n w` from standard input and prints the weighted count modulo
/// 998244353 (no trailing newline, as before).
int main() {
    ll n = 0;
    ll w = 0;
    if (!(std::cin >> n >> w)) {
        return 1;  // missing or malformed input
    }
    std::cout << count_ways(n, w);
    return 0;
}