| Run ID | 作者 | 问题 | 语言 | 测评结果 | 分数 | 时间 | 内存 | 代码长度 | 提交时间 |
|---|---|---|---|---|---|---|---|---|---|
| 38581 | 沈仲恩 | 【J】T4 | C++ | 通过 | 100 | 431 MS | 261984 KB | 4224 | 2025-10-18 12:52:55 |
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

// Whole-program solution. Reads a tree on vertices 1..n from stdin: first n,
// then n-1 integers where the i-th one is the vertex joined to vertex i+1.
// Prints one integer.
//
// What the program computes, as far as the code shows:
//   * any vertex of degree > 3            -> 0
//   * no vertex of degree 3 (a path)      -> n*(n-1)/2
//   * exactly one vertex c of degree 3    -> a*b + a*c + b*c, the pairwise
//     products of the sizes of the three branches hanging off c
//   * several vertices of degree 3        -> they must all lie on one simple
//     path (otherwise 0); the answer is (total size of the branches leaving
//     one extreme degree-3 vertex away from the path) * (same for the other
//     extreme).
// NOTE(review): this looks like "count unordered pairs (u, v) whose path has
// every degree-3 vertex as an interior vertex" -- confirm against the
// problem statement, which is not part of this file.
//
// Vertex id 0 is used as a sentinel "above the root": dep[0] == 0 and every
// ancestor pointer that would leave the tree is 0.

long long n;                       // number of vertices
int rt;                            // DFS root: lowest-numbered vertex of degree 1
int l;                             // floor(log2(n)) + 1, top binary-lifting level
std::vector<std::vector<int>> f;   // f[i][u] = 2^i-th ancestor of u (0 if none)
std::vector<int> dep;              // dep[u] = depth of u, with dep[rt] == 1
std::vector<int> sz;               // sz[u] = number of vertices in u's subtree
std::vector<std::vector<int>> ed;  // adjacency lists
std::vector<int> sl;               // vertices whose degree is exactly 3

// Fills dep[], sz[] and f[0][] for the subtree hanging below `u`, treating
// f[0][u] (already set by the caller, 0 for the root) as u's parent.
// Iterative on purpose: the root is a leaf, so on a path-shaped tree a
// recursive walk would nest n levels deep.
void dfs(int u) {
    std::vector<int> order;  // vertices in visiting order (parents first)
    order.reserve(static_cast<std::size_t>(n));
    std::vector<int> pending;
    pending.push_back(u);
    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();
        dep[cur] = dep[f[0][cur]] + 1;
        sz[cur] = 1;
        order.push_back(cur);
        for (const int nxt : ed[cur]) {
            if (nxt == f[0][cur]) continue;  // do not walk back to the parent
            f[0][nxt] = cur;
            pending.push_back(nxt);
        }
    }
    // Every vertex appears after its parent in `order`, so a reverse sweep
    // sees each subtree size complete before adding it to the parent.
    // Index 0 is `u` itself; its size is not pushed further up.
    for (std::size_t i = order.size(); i-- > 1;) {
        sz[f[0][order[i]]] += sz[order[i]];
    }
}

// Runs the DFS from rt and builds the binary-lifting table f[1..l].
void init() {
    dfs(rt);
    for (int i = 1; i <= l; ++i) {
        for (int u = 1; u <= n; ++u) {
            f[i][u] = f[i - 1][f[i - 1][u]];
        }
    }
}

// Returns the ancestor `steps` levels above u (0 once the walk leaves the tree).
static int lift(int u, int steps) {
    for (int i = 0; steps != 0; ++i, steps >>= 1) {
        if (steps & 1) u = f[i][u];
    }
    return u;
}

// Lowest common ancestor of u and v in the tree rooted at rt.
int lca(int u, int v) {
    if (dep[u] < dep[v]) std::swap(u, v);
    u = lift(u, dep[u] - dep[v]);
    if (u == v) return u;
    for (int i = l; i >= 0; --i) {
        if (f[i][u] != f[i][v]) {
            u = f[i][u];
            v = f[i][v];
        }
    }
    return f[0][u];
}

// True when u and v lie on one root-to-leaf chain, i.e. one of them is an
// ancestor of (or equal to) the other. With v == 0 this is always true,
// because lifting any vertex dep[u] levels reaches the sentinel 0.
bool onc(int u, int v) {
    if (dep[u] < dep[v]) std::swap(u, v);
    return lift(u, dep[u] - dep[v]) == v;
}

// For adjacent vertices u and v: the number of vertices on v's side once the
// edge (u, v) is cut. If v is shallower it is u's parent, so that side is
// everything outside u's subtree; otherwise it is v's own subtree.
long long st(int u, int v) {
    if (dep[v] < dep[u]) return n - sz[u];
    return sz[v];
}

// Sum of the sizes of the branches hanging below u (one per child of u).
static long long child_branch_total(int u) {
    long long total = 0;
    for (const int v : ed[u]) {
        if (dep[v] < dep[u]) continue;  // skip the parent
        total += st(u, v);
    }
    return total;
}

int main() {
    // Malformed input (unreadable or out-of-range numbers) ends the program
    // with a non-zero status instead of indexing out of bounds.
    if (std::scanf("%lld", &n) != 1 || n < 1) return 1;
    if (n == 1) {
        std::puts("0");
        return 0;
    }

    l = 1;
    while ((1LL << l) <= n) ++l;  // l = floor(log2(n)) + 1

    const std::size_t slots = static_cast<std::size_t>(n) + 1;
    ed.assign(slots, std::vector<int>());
    for (int v = 2; v <= n; ++v) {
        long long other = 0;
        if (std::scanf("%lld", &other) != 1 || other < 1 || other > n) return 1;
        ed[other].push_back(v);
        ed[v].push_back(static_cast<int>(other));
    }

    for (int v = 1; v <= n; ++v) {
        if (ed[v].size() > 3) {
            std::puts("0");
            return 0;
        }
        if (!rt && ed[v].size() == 1) rt = v;
    }

    f.assign(static_cast<std::size_t>(l) + 1, std::vector<int>(slots, 0));
    dep.assign(slots, 0);
    sz.assign(slots, 0);
    init();

    for (int v = 1; v <= n; ++v) {
        if (ed[v].size() == 3) sl.push_back(v);
    }

    if (sl.empty()) {
        // 1 + 2 + ... + (n - 1)
        std::printf("%lld", n * (n - 1) / 2);
        return 0;
    }

    if (sl.size() == 1) {
        const int c = sl[0];
        const long long a = st(c, ed[c][0]);
        const long long b = st(c, ed[c][1]);
        const long long cc = st(c, ed[c][2]);
        std::printf("%lld", a * b + a * cc + b * cc);
        return 0;
    }

    // Two or more degree-3 vertices.
    //   top : LCA of all of them
    //   d   : the deepest one (first one wins on ties)
    //   d1  : the deepest one that is NOT on d's root chain, or 0 if none
    int top = sl[0];
    int d = 0;
    int d1 = 0;
    for (const int u : sl) {
        top = lca(top, u);
        if (dep[u] > dep[d]) d = u;
    }
    for (const int u : sl) {
        if (onc(u, d)) continue;
        // A vertex off d's chain must branch away exactly at `top`.
        if (lca(u, d) != top) {
            std::puts("0");
            return 0;
        }
        if (dep[u] > dep[d1]) d1 = u;
    }
    for (const int u : sl) {
        if (!onc(u, d) && !onc(u, d1)) {
            std::puts("0");
            return 0;
        }
    }

    long long ans = 0;
    if (!d1) {
        // All degree-3 vertices sit on the chain top -> d. Take every branch
        // of `top` except the child that leads down to d. The depth test
        // keeps the parent, which is also on d's chain, in the sum.
        long long away_from_d = 0;
        for (const int v : ed[top]) {
            const bool leads_to_d = dep[v] > dep[top] && onc(v, d);
            if (!leads_to_d) away_from_d += st(top, v);
        }
        ans = away_from_d * child_branch_total(d);
    } else {
        // The degree-3 vertices span d -> top -> d1; both extremes contribute
        // the sizes of the branches below them.
        ans = child_branch_total(d) * child_branch_total(d1);
    }
    std::printf("%lld\n", ans);
    return 0;
}