提交时间:2023-12-09 08:55:20

运行 ID: 24050

// 30pts: partial solution. Exhaustive search when n <= 8; otherwise a closed-form
// special-case formula (see Sub2).
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

const int N = 2e5 + 9, Mod = 998244353;

// Binary exponentiation: returns a^b mod Mod for b >= 0 (a is expected to lie in [0, Mod)).
int fpow(int a, int b) {
    int res = 1;
    while (b) {
        if (b & 1) res = 1ll * res * a % Mod;
        a = 1ll * a * a % Mod;
        b >>= 1;
    }
    return res;
}

std::vector<int> G[N];  // adjacency lists of the input tree, vertices 1..n
// fa[u][j]: the 2^j-th ancestor of u (0 once the jump passes the root).
// dep[u]:   depth of u with the root at depth 1 (dep[0] == 0 acts as a sentinel).
// pw2[i]:   2^i mod Mod.
int n, a[N], fa[N][20], dep[N], pw2[N];

// Roots the tree at `root`, filling fa[v][0] (parent; 0 for the root) and dep[v].
// Deliberately iterative: N allows path-shaped trees ~2e5 deep, and one recursive
// call frame per vertex can overflow the call stack.
void dfs(int root) {
    dep[root] = dep[fa[root][0]] + 1;
    std::vector<int> stk{root};
    while (!stk.empty()) {
        const int u = stk.back();
        stk.pop_back();
        for (int v : G[u]) {
            if (v == fa[u][0]) continue;  // do not walk back up to the parent
            fa[v][0] = u;
            dep[v] = dep[u] + 1;
            stk.push_back(v);
        }
    }
}

// Lowest common ancestor of u and v via binary lifting (needs the full fa[][] table).
// Bits 17..0 cover any depth difference below 2^18, which exceeds N. A jump past the
// root lands on vertex 0, whose depth 0 is never >= dep[v], so it is never taken.
int lca(int u, int v) {
    if (dep[u] < dep[v]) std::swap(u, v);
    for (int i = 17; i >= 0; i--) {
        if (dep[fa[u][i]] >= dep[v]) u = fa[u][i];
    }
    if (u == v) return u;
    for (int i = 17; i >= 0; i--) {
        if (fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
    }
    return fa[u][0];
}

// Number of edges on the tree path between u and v.
int dis(int u, int v) { return dep[u] + dep[v] - 2 * dep[lca(u, v)]; }

// Brute force for small n: enumerates every subset D of {1..n} and, for each subset that
// passes chk(), adds 2^|D| to the answer.
namespace Sub1 {
std::vector<int> D;  // the subset currently being built

// True iff D is non-empty and some vertex u satisfies dis(v, u) <= a[v] for every v in D.
bool chk() {
    if (D.empty()) return false;
    for (int u = 1; u <= n; u++) {
        bool ok = true;
        for (int v : D) {
            if (dis(v, u) > a[v]) {
                ok = false;
                break;
            }
        }
        if (ok) return true;
    }
    return false;
}

int ans = 0;

// Decides vertex d (exclude, then include) and recurses; at d > n the subset is complete.
void dfs1(int d) {
    if (d > n) {
        if (chk()) ans = (ans + pw2[D.size()]) % Mod;
        return;
    }
    dfs1(d + 1);
    D.push_back(d);
    dfs1(d + 1);
    D.pop_back();
}

void Main() {
    dfs1(1);
    std::printf("%d\n", ans);
}
}  // namespace Sub1

// Closed form used for n > 8: sum over vertices of 2 * 3^deg(i), minus 4 per edge, mod Mod.
// NOTE(review): the subtask condition under which this formula holds is not visible here
// (presumably a restriction on a[]); confirm against the problem statement.
namespace Sub2 {
int ans = 0;

void Main() {
    for (int i = 1; i <= n; i++) {
        ans = (ans + 2ll * fpow(3, static_cast<int>(G[i].size()))) % Mod;
        if (i > 1) ans = (ans - 4 + Mod) % Mod;  // one subtraction per edge (n - 1 total)
    }
    std::printf("%d\n", ans);
}
}  // namespace Sub2

// Input: n, then a[1..n], then n-1 tree edges. Output: the answer modulo 998244353.
int main() {
    std::scanf("%d", &n);
    pw2[0] = 1;
    for (int i = 1; i <= n; i++) {
        std::scanf("%d", &a[i]);
        pw2[i] = 2ll * pw2[i - 1] % Mod;
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        std::scanf("%d%d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }

    // Root at vertex 1 and build the binary-lifting table used by lca().
    dfs(1);
    for (int j = 1; (1 << j) <= n; j++) {
        for (int i = 1; i <= n; i++) {
            fa[i][j] = fa[fa[i][j - 1]][j - 1];
        }
    }

    if (n <= 8) {
        Sub1::Main();
    } else {
        Sub2::Main();
    }
    // Returning from main flushes stdout, so no explicit fclose is needed.
    return 0;
}