// Submitted at: 2024-08-28 19:12:03
// Run ID: 31972
#include <bits/stdc++.h>

// Reads three whitespace-separated tokens from stdin -- n, w and a digit
// string s -- and prints one number modulo 998244353.
//
// NOTE(review): the problem statement is not part of this file. Judging from
// the arithmetic below, the program appears to count the ways to pick a
// subset of positions from "s written n times in a row" so that the picked
// digits form a number divisible by w (empty pick included) -- confirm
// against the original task before relying on this description.

namespace {

using ll = long long;

constexpr ll kMod = 998244353;

// 3x3 matrix of residues modulo kMod, stored by value (no heap allocation).
struct Mat3 {
    ll v[3][3];
};

// kStep[d] is the transition applied for a character whose (ch - '0') % 3
// equals d. Row r has a 1 in column r and a 1 in column (r + d) % 3; for
// d == 0 the two coincide, hence the 2 on the diagonal. This can be read as
// "keep residue r, or move from r to r + d".
constexpr Mat3 kStep[3] = {
    {{{2, 0, 0}, {0, 2, 0}, {0, 0, 2}}},
    {{{1, 1, 0}, {0, 1, 1}, {1, 0, 1}}},
    {{{1, 0, 1}, {1, 1, 0}, {0, 1, 1}}},
};

// Returns the 3x3 identity matrix.
Mat3 identity() {
    return Mat3{{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}};
}

// Returns lhs * rhs with every entry reduced modulo kMod.
// Entries are expected to lie in [0, kMod), so each partial sum stays below
// kMod + kMod^2 and fits in a signed 64-bit integer.
Mat3 multiply(const Mat3& lhs, const Mat3& rhs) {
    Mat3 out{};  // value-initialised: all zeros
    for (int i = 0; i < 3; ++i) {
        for (int k = 0; k < 3; ++k) {
            for (int j = 0; j < 3; ++j) {
                out.v[i][j] = (out.v[i][j] + lhs.v[i][k] * rhs.v[k][j]) % kMod;
            }
        }
    }
    return out;
}

// Returns base^exponent by binary exponentiation.
// A non-positive exponent yields the identity; the `> 0` guard also keeps a
// negative exponent from looping forever (an arithmetic right shift of a
// negative value never reaches zero).
Mat3 power(Mat3 base, ll exponent) {
    Mat3 result = identity();
    while (exponent > 0) {
        if (exponent & 1) {
            result = multiply(result, base);
        }
        base = multiply(base, base);
        exponent >>= 1;
    }
    return result;
}

// Returns (ratio^0 + ratio^1 + ... + ratio^(terms-1)) mod kMod in O(log terms).
// `ratio` must already be in [0, kMod). Returns 0 when terms <= 0.
//
// Uses doubling instead of the closed form (ratio^terms - 1) / (ratio - 1),
// because the closed form needs a modular inverse that does not exist when
// ratio == 1 (mod kMod).
ll geometric_sum(ll ratio, ll terms) {
    if (terms <= 0) {
        return 0;
    }
    ll sum = 0;    // S(m) = ratio^0 + ... + ratio^(m-1) for the prefix m of `terms`
    ll pw = 1;     // ratio^m
    for (int bit = 62; bit >= 0; --bit) {
        // m -> 2m:  S(2m) = S(m) * (1 + ratio^m),  ratio^(2m) = (ratio^m)^2
        sum = sum * (1 + pw) % kMod;
        pw = pw * pw % kMod;
        if ((terms >> bit) & 1) {
            // m -> m + 1:  S(m+1) = S(m) + ratio^m
            sum = (sum + pw) % kMod;
            pw = pw * ratio % kMod;
        }
    }
    return sum;
}

}  // namespace

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    ll n = 0;
    ll w = 0;
    std::string s;  // std::string removes the fixed-size buffer limit of the old scanf("%s")
    std::cin >> n >> w >> s;

    ll answer = 0;
    if (w == 2 || w == 5) {
        // per_copy = sum of 2^i over the 0-based positions i of s whose digit
        // matches. NOTE(review): '0' is deliberately not matched for either
        // w, exactly as in the original code -- presumably the input digits
        // are 1..9; confirm.
        ll weight = 1;  // 2^i mod kMod for the current position i
        ll per_copy = 0;
        for (const char ch : s) {
            const bool matches =
                (w == 2) ? (ch == '2' || ch == '4' || ch == '6' || ch == '8')
                         : (ch == '5');
            if (matches) {
                per_copy = (per_copy + weight) % kMod;
            }
            weight = weight * 2 % kMod;
        }
        // After the loop weight == 2^len. The k-th copy of s contributes
        // per_copy * (2^len)^k, and the leading 1 is the constant term the
        // original accumulator started from.
        answer = (1 + per_copy * geometric_sum(weight, n)) % kMod;
    } else {
        // NOTE(review): every w other than 2 and 5 takes this branch, which
        // works on residues modulo 3 -- presumably w is restricted to
        // {2, 3, 5}; confirm.
        Mat3 one_copy = identity();
        for (const char ch : s) {
            const int rem = (ch - '0') % 3;
            // Anything that is not 0 or 1 (including a negative remainder
            // from a non-digit character) uses kStep[2], as the original did.
            const int kind = (rem == 0) ? 0 : (rem == 1) ? 1 : 2;
            one_copy = multiply(one_copy, kStep[kind]);
        }
        // Multiplying the row vector (1, 0, 0) by the matrix and taking
        // component 0 is the same as reading entry [0][0].
        answer = power(one_copy, n).v[0][0];
    }

    std::cout << answer << '\n';
    return 0;
}