• + 0 comments
    from collections import deque
    
    class FenwickTree:
        """Binary indexed tree over 1-based positions 1..n.

        Supports point updates (``update``) and prefix-sum queries (``query``);
        both walk at most O(log n) nodes by adding/stripping the lowest set bit.
        """

        def __init__(self, n):
            self.n = n
            # Index 0 is unused: positions are 1-based so that ``i & -i``
            # (the lowest set bit) is non-zero for every valid position.
            self.tree = [0] * (n + 1)

        def update(self, i, delta):
            """Add ``delta`` to the value stored at position ``i``.

            Args:
                i: 1-based position; expected to satisfy 1 <= i <= n.
                delta: amount to add.

            Raises:
                IndexError: if ``i`` < 1. Without this check the loop never
                    terminates: ``i & -i`` is 0 for ``i == 0``, and a negative
                    ``i`` climbs to 0 (writing into ``tree[-k]`` on the way).
            """
            if i < 1:
                raise IndexError(f"FenwickTree index must be >= 1, got {i}")
            while i <= self.n:
                self.tree[i] += delta
                i += i & (-i)  # move to the next node whose range covers i

        def query(self, i):
            """Return the sum of the values at positions 1..i.

            Returns 0 for ``i <= 0``. ``i`` greater than ``n`` indexes past the
            internal list and raises ``IndexError``.
            """
            total = 0
            while i > 0:
                total += self.tree[i]
                i -= i & (-i)  # drop the lowest set bit to reach the parent range
            return total
    
    def similarPair(n, k, edges):
        """Count pairs (a, b) where a is a proper ancestor of b and |a - b| <= k.

        An iterative DFS keeps the labels of the current node's ancestors in a
        Fenwick tree. On entering a node, one range query counts the ancestors
        whose label lies in [node - k, node + k]. The node is then added, and it
        is removed again once its whole subtree has been processed.

        Args:
            n: number of nodes; labels are expected to lie in 1..n.
            k: maximum allowed label difference.
            edges: iterable of (parent, child) pairs describing the tree.

        Returns:
            The number of similar pairs.
        """
        children = [[] for _ in range(n + 1)]
        has_parent = [False] * (n + 1)
        for parent, child in edges:
            children[parent].append(child)
            has_parent[child] = True

        fenwick = FenwickTree(n)
        visited = [False] * (n + 1)
        result = 0

        # The root is the node that never appears as a child. edges[0][0] is not
        # guaranteed to be it, and it does not exist at all when n == 1.
        # Starting from every parentless node also handles a forest.
        for root in range(1, n + 1):
            if has_parent[root]:
                continue
            stack = [(root, False)]  # (node, children_already_pushed)
            while stack:
                node, expanded = stack.pop()
                if expanded:
                    # Subtree finished: node is no longer an ancestor.
                    fenwick.update(node, -1)
                    continue
                if visited[node]:
                    continue  # defensive: only reachable on non-tree input
                visited[node] = True

                left = max(1, node - k)
                right = min(n, node + k)
                if left <= right:  # empty window (e.g. negative k) contributes 0
                    result += fenwick.query(right) - fenwick.query(left - 1)

                fenwick.update(node, 1)
                # Re-push node below its children so its removal runs only
                # after the entire subtree has been visited.
                stack.append((node, True))
                stack.extend((child, False) for child in children[node])

        return result