• + 0 comments

    Refined version that does not build a Node class

    from collections import defaultdict, deque
    
    def cutTheTree(data, edges):
        """Return the minimum tree difference obtainable by cutting one edge.

        The tree has ``len(data)`` nodes labelled ``1..n``; node ``k`` carries
        the value ``data[k - 1]``. Removing one edge splits the tree into two
        components, and the difference for that cut is the absolute difference
        of the two components' value sums. The smallest such difference over
        all edges is returned.

        Args:
            data: node values, where ``data[k - 1]`` belongs to node ``k``.
            edges: iterable of ``(a, b)`` pairs describing undirected edges
                between nodes labelled ``1..len(data)``.

        Returns:
            The minimum absolute difference between the two component sums.

        Raises:
            ValueError: if the tree has fewer than two nodes (no edge to cut).
        """
        n = len(data)
        if n < 2:
            raise ValueError("a tree needs at least two nodes to be cut")

        # Adjacency list of the undirected tree.
        adjacency = defaultdict(list)
        for a, b in edges:
            adjacency[a].append(b)
            adjacency[b].append(a)

        # Iterative DFS from node 1 that records each node's parent and the
        # order in which nodes are discovered. A node is always discovered
        # after its parent, so walking that order backwards visits every child
        # before its parent. Iteration avoids Python's recursion limit.
        parent = [0] * (n + 1)
        visited = [False] * (n + 1)
        visited[1] = True
        order = []
        stack = [1]
        while stack:
            u = stack.pop()
            order.append(u)
            for v in adjacency[u]:
                if not visited[v]:
                    visited[v] = True
                    parent[v] = u
                    stack.append(v)

        # subtree_sums[u] = sum of values in the subtree rooted at u.
        subtree_sums = [0] * (n + 1)
        for u in reversed(order):
            subtree_sums[u] += data[u - 1]
            if u != 1:
                subtree_sums[parent[u]] += subtree_sums[u]

        total_sum = subtree_sums[1]

        # Cutting the edge above node i leaves subtree_sums[i] on one side and
        # total_sum - subtree_sums[i] on the other, so the difference is
        # |total_sum - 2 * subtree_sums[i]|. The root (node 1) has no edge
        # above it, hence the range starts at 2.
        return min(
            abs(total_sum - 2 * subtree_sums[i]) for i in range(2, n + 1)
        )