• + 0 comments
    from collections import defaultdict, deque
    
    class Node:
        """A tree node that caches the sum of the values in its subtree."""

        def __init__(self, num, val):
            self.num = num  # Node number (identifier)
            self.val = val  # Value stored at this node
            self.data = None  # Cached subtree sum; None until computed
            self.children = set()  # Child Node objects

        def tree_sum(self):
            """Return the sum of values in the subtree rooted at this node.

            The result is cached in ``self.data``; a cached value is returned
            as-is. Descendants whose sums are not cached yet are computed first,
            bottom-up, using an explicit stack instead of recursion so that very
            deep trees do not hit Python's recursion limit.
            """
            if self.data is not None:
                return self.data

            # Post-order traversal: each node is pushed once unexpanded (to
            # schedule its children) and once expanded (to sum them afterwards).
            stack = [(self, False)]
            while stack:
                node, expanded = stack.pop()
                if node.data is not None:
                    continue
                if expanded:
                    # Every child has been processed by the time we get here.
                    if node.children:
                        node.data = sum(child.data for child in node.children) + node.val
                    else:
                        node.data = node.val
                else:
                    stack.append((node, True))
                    stack.extend(
                        (child, False) for child in node.children if child.data is None
                    )
            return self.data
    
    def build_tree_and_calculate_sums(data, edges):
        """Build a tree rooted at node 1 and cache every reachable node's subtree sum.

        Args:
            data: Node values; ``data[i - 1]`` is the value of node ``i``.
            edges: Iterable of ``(a, b)`` pairs of 1-based node numbers
                describing undirected edges.

        Returns:
            Dict mapping each node number in ``1..len(data)`` to its ``Node``.
            Nodes reachable from node 1 have ``children`` populated (edges
            oriented away from node 1) and ``data`` set to their subtree sum.
            An empty dict is returned when ``data`` is empty.
        """
        connections = defaultdict(set)  # Undirected adjacency sets
        for a, b in edges:
            connections[a].add(b)
            connections[b].add(a)

        n = len(data)
        nodes = {i: Node(i, data[i - 1]) for i in range(1, n + 1)}
        if not nodes:
            # Nothing to root the tree at; avoid indexing node 1.
            return nodes

        # Iterative DFS from node 1: orient each edge parent -> child and record
        # the order in which nodes are discovered. Sums are computed afterwards.
        visited = [False] * (n + 1)
        visited[1] = True
        stack = [1]
        order = []
        while stack:
            current = stack.pop()
            order.append(current)
            for j in connections[current]:
                if not visited[j]:
                    visited[j] = True
                    nodes[current].children.add(nodes[j])
                    stack.append(j)

        # A child is always discovered after its parent, so walking the
        # discovery order backwards handles all children before their parent.
        # Each tree_sum() call then only reads already-cached child sums.
        for num in reversed(order):
            nodes[num].tree_sum()

        return nodes
    
    def cutTheTree(data, edges):
        """Return the minimum difference between the sums of the two parts of a
        tree obtained by removing exactly one edge.

        Args:
            data: Node values; ``data[i - 1]`` is the value of node ``i``.
            edges: Iterable of ``(a, b)`` pairs of 1-based node numbers.

        Returns:
            ``min |total - 2 * subtree_sum(v)|`` over every node ``v`` other
            than the root (node 1). Removing the edge above ``v`` splits off
            exactly ``v``'s subtree. If there is no edge to cut (single node),
            the total sum is returned; for empty ``data`` the result is 0.
        """
        nodes = build_tree_and_calculate_sums(data, edges)
        if not nodes:
            return 0

        total_sum = nodes[1].tree_sum()

        # Cutting above node i leaves parts with sums s and total - s, whose
        # absolute difference is |total - 2s|. Using ``default`` (instead of
        # seeding the minimum with total_sum) keeps negative totals from
        # winning the min when values can be negative.
        return min(
            (abs(total_sum - 2 * nodes[i].tree_sum()) for i in range(2, len(nodes) + 1)),
            default=total_sum,
        )