// 提交时间 (submission time): 2024-11-12 13:38:19

// 运行 ID (run ID): 34622

//100
// Sum of pairwise hop-length products over all shortest paths.
//
// Input (stdin): n m, then m lines "u v w", each adding a directed edge u -> v
// with weight w. For every vertex i, consider all minimum-weight paths from
// vertex 1 to i and the number of edges (hops) on each. The program prints,
// modulo 1e9+7,
//     sum over i of  sum over unordered pairs {a, b} of distinct such paths
//                    of hops(a) * hops(b)
// via the identity  sum_{a<b} x_a*x_b = ((sum x)^2 - sum x^2) / 2.
//
// NOTE(review): both Dijkstra variants below assume every w >= 1. With a
// zero-weight edge, a vertex can be finalized before all of its
// equal-distance predecessors have been merged into it. Negative weights
// break Dijkstra outright. Confirm against the problem's constraints.
#include <cstdio>
#include <cstring>
#include <queue>
#include <utility>
#include <vector>

using namespace std;
using ll = long long;

// Reads one (optionally negative) integer from stdin. `c` is an int so EOF is
// distinguishable from a byte. On EOF the function stops and returns what it
// has read so far instead of spinning forever.
ll read() {
    ll x = 0, f = 1;
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}

const int MAXN = 100010;
const ll Mod = 1000000007;

// Linked adjacency list: h[u] is the index of u's most recently added edge,
// edge[e].nx the next edge of the same source, and index 0 ends the list.
struct Edge {
    int v, nx;
    ll w;
} edge[MAXN << 1];
int h[MAXN], CNT;

void add_side(int u, int v, ll w) {
    edge[++CNT] = {v, h[u], w};
    h[u] = CNT;
}

// x^y mod Mod by repeated squaring.
ll Pow(ll x, ll y) {
    ll rt = 1;
    while (y) {
        if (y & 1) rt = rt * x % Mod;
        x = x * x % Mod;
        y >>= 1;
    }
    return rt;
}

int n, m;
ll inv;          // modular inverse of 2 (Mod is prime, so 2^(Mod-2))
ll ans;
ll dis[MAXN];    // best known distance from vertex 1
int vis[MAXN];   // 1 once a vertex has been popped (finalized) by Dijkstra

// Aggregate over a set of paths, all fields mod Mod:
//   ct  = number of paths, sum = total hop count, sul = total squared hop count.
struct node {
    ll ct, sum, sul;

    // Appends one edge to every path in the set: each length L becomes L + 1,
    // so sum grows by ct and the sum of squares by 2*sum + ct.
    node extend() const {
        node res;
        res.ct = ct;
        res.sum = (ct + sum) % Mod;
        res.sul = (sul + sum * 2 + ct) % Mod;
        return res;
    }

    // Union of two disjoint path sets: every field simply adds.
    node merge(const node& o) const {
        node res;
        res.ct = (ct + o.ct) % Mod;
        res.sum = (sum + o.sum) % Mod;
        res.sul = (sul + o.sul) % Mod;
        return res;
    }
} d[MAXN];

vector<ll> p[MAXN];            // brute force: hop count of each shortest path
priority_queue<pair<ll, int>> q;  // (-distance, vertex): max-heap as min-heap

// Brute-force Dijkstra from vertex 1. Afterwards p[v] lists the hop count of
// every shortest path 1 -> v, one entry per path. Only used for tiny inputs.
void dji2() {
    memset(dis, 0x3f, sizeof(dis));
    q.push({0, 1});
    dis[1] = 0;
    p[1].push_back(0);
    while (!q.empty()) {
        int u = q.top().second;
        q.pop();
        if (vis[u]) continue;
        vis[u] = 1;
        for (int e = h[u]; e; e = edge[e].nx) {
            int v = edge[e].v;
            ll w = edge[e].w;
            if (dis[u] + w > dis[v]) continue;
            if (dis[u] + w < dis[v]) {
                // Strictly shorter: previously collected paths are obsolete.
                p[v].clear();
                dis[v] = dis[u] + w;
                q.push({-dis[v], v});
            }
            // Fixed-count index loop rather than range-for. When u == v
            // (zero-weight self-loop), push_back would otherwise invalidate
            // the iterator being walked.
            for (size_t i = 0, cnt = p[u].size(); i < cnt; ++i)
                p[v].push_back(p[u][i] + 1);
        }
    }
}

// Dijkstra from vertex 1 that keeps, per vertex, the (ct, sum, sul) aggregate
// over all shortest paths instead of listing them.
void dji() {
    memset(dis, 0x3f, sizeof(dis));
    q.push({0, 1});
    dis[1] = 0;
    d[1] = {1, 0, 0};  // the empty path: one path of length 0
    while (!q.empty()) {
        int u = q.top().second;
        q.pop();
        if (vis[u]) continue;
        vis[u] = 1;
        for (int e = h[u]; e; e = edge[e].nx) {
            int v = edge[e].v;
            ll w = edge[e].w;
            if (dis[u] + w > dis[v]) continue;
            const node via_u = d[u].extend();
            if (dis[u] + w == dis[v]) {
                d[v] = d[v].merge(via_u);  // another equally short route
            } else {
                d[v] = via_u;              // strictly shorter: replace
                dis[v] = dis[u] + w;
                q.push({-dis[v], v});
            }
        }
    }
}

// Fast path: ((sum)^2 - sum of squares) / 2 per vertex, using inv for the /2.
void slv1() {
    dji();
    for (int i = 1; i <= n; i++) {
        ans += (d[i].sum * d[i].sum % Mod - d[i].sul + Mod) * inv % Mod;
        ans %= Mod;
    }
    printf("%lld", (ans % Mod + Mod) % Mod);
}

// Brute-force path: same formula, computed exactly from the explicit lists.
void slv2() {
    dji2();
    for (int i = 1; i <= n; i++) {
        ll sum1 = 0, sum2 = 0;
        for (ll len : p[i]) {
            sum1 += len;
            sum2 += len * len;
        }
        ans += (sum1 * sum1 - sum2) / 2;
        ans %= Mod;
    }
    printf("%lld", (ans % Mod + Mod) % Mod);
}

// Reads the graph and dispatches: brute force for n, m <= 10, aggregates otherwise.
void slv() {
    inv = Pow(2, Mod - 2);
    n = read(), m = read();
    for (int i = 1; i <= m; i++) {
        int u = read(), v = read();
        ll w = read();
        add_side(u, v, w);
    }
    if (n <= 10 && m <= 10)
        slv2();
    else
        slv1();
}

int main() {
    slv();
    return 0;
}