提交时间:2024-11-12 15:49:21
运行 ID: 34683
#include <bits/stdc++.h>

using ll = long long;

namespace {

constexpr ll kMod = 1000000007;
// Modular inverse of 2 modulo kMod: 2 * 500000004 = 1000000008 = 1 (mod kMod).
constexpr ll kInvTwo = 500000004;
// Same bit pattern the original memset(dis, 0x3f, ...) produced for each 64-bit slot.
constexpr ll kInf = 0x3f3f3f3f3f3f3f3fLL;

// Heap entry for Dijkstra. Only the tentative distance takes part in the
// ordering, so std::greater<QueueEntry> turns the priority queue into a min-heap
// on `dist`.
struct QueueEntry {
    ll vertex;
    ll dist;
    bool operator>(const QueueEntry& other) const { return dist > other.dist; }
};

using Graph = std::vector<std::vector<std::pair<ll, ll>>>;  // adj[u] = {(v, w), ...}

// Runs Dijkstra from vertex 1 over the directed graph `adj` (vertices 1..n) and,
// for every vertex v, maintains modulo kMod:
//   path_cnt[v] - number of shortest paths from 1 to v,
//   len_sum[v]  - sum of the edge counts of those paths,
//   sq_sum[v]   - sum of the squared edge counts of those paths.
// Returns sum over v of (len_sum[v]^2 - sq_sum[v]) / 2 modulo kMod, always in
// the range [0, kMod).
//
// NOTE(review): accumulating ties while a vertex is being expanded is only
// sound when every edge weight is strictly positive - confirm against the
// problem constraints.
ll solve(ll n, const Graph& adj) {
    std::vector<ll> dist(n + 1, kInf);
    std::vector<ll> path_cnt(n + 1, 0);
    std::vector<ll> len_sum(n + 1, 0);
    std::vector<ll> sq_sum(n + 1, 0);
    std::vector<char> done(n + 1, 0);

    std::priority_queue<QueueEntry, std::vector<QueueEntry>, std::greater<QueueEntry>> pq;
    dist[1] = 0;
    path_cnt[1] = 1;
    pq.push(QueueEntry{1, 0});

    while (!pq.empty()) {
        const ll u = pq.top().vertex;
        pq.pop();
        if (done[u]) continue;  // stale heap entry: u was already finalized
        done[u] = 1;

        for (const auto& edge : adj[u]) {
            const ll v = edge.first;
            const ll cand = dist[u] + edge.second;

            // Appending one edge to a path of k edges gives k + 1 edges, and
            // (k + 1)^2 = k^2 + 2k + 1; summed over all paths ending in u this is
            // sq_sum[u] + 2 * len_sum[u] + path_cnt[u].
            const ll ext_len = (len_sum[u] + path_cnt[u]) % kMod;
            const ll ext_sq = ((sq_sum[u] + 2 * len_sum[u] % kMod) % kMod + path_cnt[u]) % kMod;

            if (cand < dist[v]) {
                // Strictly better distance: previous statistics of v are discarded.
                dist[v] = cand;
                path_cnt[v] = path_cnt[u];
                len_sum[v] = ext_len;
                sq_sum[v] = ext_sq;
                pq.push(QueueEntry{v, cand});
            } else if (cand == dist[v]) {
                // Another family of shortest paths reaching v: merge the statistics.
                path_cnt[v] = (path_cnt[v] + path_cnt[u]) % kMod;
                len_sum[v] = (len_sum[v] + ext_len) % kMod;
                sq_sum[v] = (sq_sum[v] + ext_sq) % kMod;
            }
        }
    }

    // (sum^2 - sigma d_i^2) / 2, accumulated per vertex.
    ll ans = 0;
    for (ll i = 1; i <= n; ++i) {
        // Built-in % keeps the sign of the dividend, so add kMod before reducing;
        // otherwise the difference (and the final answer) could be negative.
        const ll twice_pairs = (len_sum[i] * len_sum[i] % kMod - sq_sum[i] + kMod) % kMod;
        ans = (ans + twice_pairs) % kMod;
    }
    return ans * kInvTwo % kMod;
}

}  // namespace

// Reads "n m" followed by m directed edges "u v w" from stdin and prints the
// value computed by solve() on one line.
int main() {
    ll n = 0;
    ll m = 0;
    if (std::scanf("%lld %lld", &n, &m) != 2 || n < 1) {
        // No vertices to sum over: the result is 0.
        std::printf("%lld\n", 0LL);
        return 0;
    }

    Graph adj(n + 1);
    for (ll i = 0; i < m; ++i) {
        ll u = 0;
        ll v = 0;
        ll w = 0;
        if (std::scanf("%lld %lld %lld", &u, &v, &w) != 3) break;  // truncated input
        if (u < 1 || u > n || v < 1 || v > n) continue;  // endpoint outside 1..n: ignore
        adj[u].emplace_back(v, w);
    }

    std::printf("%lld\n", solve(n, adj));
    return 0;
}