You are viewing a single comment's thread. Return to all comments →
// Number Game on a Tree.
//
// For every test case, count the unordered pairs of distinct nodes (u, v)
// whose connecting path contains at least one edge weight an odd number of
// times. That equals C(n, 2) minus the number of "balanced" pairs, i.e. pairs
// whose path contains every weight an even number of times.
//
// Technique: give each distinct weight a random 64-bit key and let prefix[v]
// be the XOR of the keys on the path root -> v. Edges shared by root -> u and
// root -> v cancel out, so prefix[u] ^ prefix[v] is the XOR of the keys on the
// u-v path, and a key survives iff its weight occurs an odd number of times.
// Hence (u, v) is balanced iff prefix[u] == prefix[v]. This is exact except for
// a hash-collision probability of roughly n^2 / 2^65.
//
// The tree is walked iteratively: a recursive DFS can go n = 5 * 10^5 levels
// deep on a path-shaped tree and overflow the call stack.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <random>
#include <unordered_map>
#include <utility>
#include <vector>

namespace {

// C(x, 2): the number of unordered pairs that can be formed from x items.
long long nC2(long long x) { return x * (x - 1) / 2; }

// Solves one test case.
//
// Reads n followed by n - 1 edges "u v w" (1-based node ids) from `in`, and
// returns the number of unordered pairs whose path has some weight occurring
// an odd number of times. `rng` supplies the random key for each newly seen
// weight.
long long solve_case(std::istream& in, std::mt19937_64& rng) {
    int n = 0;
    in >> n;
    if (n <= 1) {
        return 0;  // No pairs exist, and for n == 1 there are no edges to read.
    }

    // adj[u] holds (neighbour, random key of the connecting edge's weight).
    std::vector<std::vector<std::pair<int, std::uint64_t>>> adj(n + 1);
    std::unordered_map<int, std::uint64_t> key_of;  // weight -> random key
    key_of.reserve(static_cast<std::size_t>(n));
    for (int i = 0; i < n - 1; ++i) {
        int u = 0, v = 0, w = 0;
        in >> u >> v >> w;
        auto it = key_of.find(w);
        if (it == key_of.end()) {
            it = key_of.emplace(w, rng()).first;  // First time this weight is seen.
        }
        adj[u].emplace_back(v, it->second);
        adj[v].emplace_back(u, it->second);
    }

    // Iterative traversal from node 1 computing prefix XORs (prefix[1] == 0).
    std::vector<std::uint64_t> prefix(n + 1, 0);
    std::vector<char> seen(n + 1, 0);
    std::vector<int> todo;
    todo.reserve(static_cast<std::size_t>(n));
    todo.push_back(1);
    seen[1] = 1;
    while (!todo.empty()) {
        const int u = todo.back();
        todo.pop_back();
        for (const auto& edge : adj[u]) {
            const int v = edge.first;
            if (seen[v]) {
                continue;  // Parent (or already visited) node.
            }
            seen[v] = 1;
            prefix[v] = prefix[u] ^ edge.second;
            todo.push_back(v);
        }
    }

    // Nodes with equal prefixes form balanced pairs; a group of size c
    // contributes C(c, 2) of them. Index 0 is unused, so sort from index 1.
    std::sort(prefix.begin() + 1, prefix.end());
    long long balanced = 0;
    for (int i = 1; i <= n;) {
        int j = i;
        while (j <= n && prefix[j] == prefix[i]) {
            ++j;
        }
        balanced += nC2(j - i);
        i = j;
    }
    return nC2(n) - balanced;
}

}  // namespace

// Reads t test cases from stdin and prints one answer per line.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // A fixed seed keeps runs reproducible; collisions are still negligible.
    std::mt19937_64 rng(0x9E3779B97F4A7C15ULL);

    int t = 0;
    std::cin >> t;
    while (t-- > 0) {
        std::cout << solve_case(std::cin, rng) << '\n';
    }
    return 0;
}
Cookies seem to be disabled in this browser; please enable them to open this website.
Number Game on a Tree
You are viewing a single comment's thread. Return to all comments →
include