• + 0 comments

    Python solution:

    from bisect import bisect_left, bisect_right
    
    class DSU:
        """Disjoint-set forest with union by size and path compression.

        Slots 0..n are allocated, so both 0-based and 1-based node labels fit.
        """

        def __init__(self, n):
            # Every node starts as its own singleton set.
            self.parent = [node for node in range(n + 1)]
            self.size = [1 for _ in range(n + 1)]

        def find(self, u):
            """Return the root of u's set, pointing every visited node at it."""
            root = u
            while self.parent[root] != root:
                root = self.parent[root]
            # Second pass: rewire each node on the walked path directly to root.
            while u != root:
                nxt = self.parent[u]
                self.parent[u] = root
                u = nxt
            return root

        def union(self, u, v):
            """Merge the sets of u and v.

            Returns the number of node pairs that become connected by this
            merge (product of the two set sizes), or 0 if u and v already
            share a set.
            """
            big, small = self.find(u), self.find(v)
            if big == small:
                return 0
            # Hang the smaller tree beneath the larger one (ties: keep order).
            if self.size[small] > self.size[big]:
                big, small = small, big
            self.parent[small] = big
            joined = self.size[big] * self.size[small]
            self.size[big] += self.size[small]
            return joined
    
    def solve(tree, queries):
        """Count node pairs whose path's maximum edge weight lies in each query range.

        Edges are merged in ascending weight order (Kruskal style). Assuming
        ``tree`` is a tree, the edge that first joins two components is the
        heaviest edge on every path between them. It therefore contributes
        ``size_a * size_b`` pairs at its weight. Cumulative pair counts per
        distinct weight then answer each inclusive range with two binary
        searches.

        Args:
            tree: iterable of ``(u, v, w)`` edges, assumed to form a tree on
                ``len(tree) + 1`` nodes; labels must lie in ``0..len(tree) + 1``
                (the DSU's index range). The caller's sequence is not modified.
            queries: iterable of inclusive ``(l, r)`` weight ranges.

        Returns:
            list with one pair count per query; ``0`` when ``l > r``.
        """
        # sorted() rather than tree.sort(): never reorder the caller's list.
        edges = sorted(tree, key=lambda edge: edge[2])
        dsu = DSU(len(edges) + 1)

        # weights: distinct edge weights, ascending.
        # cum[i]: number of pairs whose path maximum is <= weights[i].
        weights = []
        cum = []
        pairs = 0
        for u, v, w in edges:
            pairs += dsu.union(u, v)
            if weights and weights[-1] == w:
                cum[-1] = pairs  # same weight as previous edge: extend its bucket
            else:
                weights.append(w)
                cum.append(pairs)

        def pairs_within_first(k):
            """Pairs whose path maximum is among the k smallest distinct weights."""
            # Explicit k == 0 case: avoids cum[-1] wrapping to the last entry.
            return cum[k - 1] if k else 0

        res = []
        for l, r in queries:
            # bisect_right -> number of weights <= r; bisect_left -> number < l.
            at_most_r = pairs_within_first(bisect_right(weights, r))
            below_l = pairs_within_first(bisect_left(weights, l))
            # cum is non-decreasing, so only l > r can make this negative.
            res.append(max(0, at_most_r - below_l))
        return res