Sort by

recency

|

163 Discussions

|

  • + 1 comment

    We can map each node to the sum of the subtree rooted at it. The root node's sum is then the sum of the entire tree, and every other node's sum is the total value of its own subtree.

    If we cut an edge, we get two trees: the subtree below the cut (whose sum is the mapping of its top node) and the remaining 'bigger' tree, whose sum is that of the entire tree (i.e., mapping[1]) minus the subtree we just removed. With this idea, we can map each node to its subtree sum in O(V + E), and then try cutting every edge once, performing the check in O(1) per edge, which costs a further O(V + E).

    The total time complexity then becomes O(V + E), with O(V + E) space, too.

    def cutTheTree(data, edges):
        """Return the smallest difference between the two parts after one cut.

        The tree is rooted at node 1. For every non-root node, cutting the
        edge to its parent splits the tree into that node's subtree and the
        rest, and the absolute difference of the two sums is a candidate.

        The traversal is iterative, so deep trees (for example a long chain)
        do not hit Python's recursion limit.

        Args:
            data: Node values; node ``k`` (1-based) has value ``data[k - 1]``.
            edges: Iterable of ``(u, v)`` pairs describing undirected edges.

        Returns:
            The minimum absolute difference over all cuts, or ``float("inf")``
            when node 1 has no edges (there is nothing to cut).
        """
        from collections import defaultdict

        adj = defaultdict(list)
        for u, v in edges:
            adj[u].append(v)
            adj[v].append(u)

        # Pre-order walk from the root using an explicit stack. `parent`
        # doubles as the visited set; `order` lists every node after its parent.
        parent = {1: None}
        order = []
        stack = [1]
        while stack:
            node = stack.pop()
            order.append(node)
            for ch in adj[node]:
                if ch in parent:
                    continue
                parent[ch] = node
                stack.append(ch)

        # Walking the order backwards finishes every descendant of a node
        # before the node itself, so each subtree sum is complete by the time
        # it is folded into its parent.
        lookup = {node: data[node - 1] for node in order}
        for node in reversed(order):
            par = parent[node]
            if par is not None:
                lookup[par] += lookup[node]

        total = lookup[1]
        # One candidate per non-root node: |subtree - (total - subtree)|.
        return min(
            (abs(total - 2 * lookup[node]) for node in order[1:]),
            default=float("inf"),
        )
    
  • + 2 comments
    class Result {
    
    static List<List<Integer>> graph;
    static List<Integer> data;
    static int root,res,total;
    public static void create(List<List<Integer>> edges, int n){
    for (int i = 0; i <= n; i++) {
    graph.add(new ArrayList<Integer>()); // Change to LinkedList<Integer>()
    }
    
    for(List<Integer> e : edges) {
    graph.get(e.get(0)).add(e.get(1));
    graph.get(e.get(1)).add(e.get(0));
    }
    
    
    }
    
    private static int dfs(int r, boolean[] vis){
    vis[r]=true;
    if(graph.get(r).size() == 0) return 0;
    
    int sum =0;
    for(int i : graph.get(r)){
    if(!vis[i])
    sum+=dfs(i,vis);
    }
    
    sum+=data.get(r-1);
    res=Math.min(res,Math.abs(total-sum*2));
    return sum;
    
    
    } 
    public static int cutTheTree(List<Integer> data, List<List<Integer>> edges) {
    // Write your code here
    int s = 0;
    graph = new ArrayList<>();
    root =-1;
    res =Integer.MAX_VALUE;
    for(int i : data)
    s+=i;
    total = s;
    create(edges,data.size());
    Result.data = data;
    int v =dfs(1,new boolean[data.size()+2]);
    return res;
            
            
    
        }
    
    }
    
  • + 0 comments
    #include<iostream>
    #include<vector>
    #include<climits>
    using namespace std;
    
    // Program-wide state shared by dfs() and main(). Nodes are addressed by
    // 1-based id; the fixed array length of 100001 bounds the largest usable id.
    // value[i]: value read for node i; n: node count; tot: sum of all values;
    // best: smallest difference found so far.
    int value[100001], n, tot, best;
    // Adjacency list: v[a] holds the neighbours of node a (resized to n + 1 in main).
    vector<vector<int> > v;
    // visited[i] is set by dfs() the first time node i is entered.
    bool visited[100001];
    // sum[i] is written by dfs(): value[i] plus the sums of the nodes dfs reached from i.
    int sum[100001];
    
    int dfs(int node)
    {
        int ret = 0;
        visited[node] = true;
        int sz = (int)v[node].size();
        for (int i = 0; i < sz; i++)
        {
            int next = v[node][i];
            if (!visited[next])
                ret += dfs(next);
        }
        return sum[node] = value[node] + ret;
    }
    
    // Reads n, the n node values and n - 1 edges from standard input, computes
    // all subtree sums from node 1, and prints the smallest value of
    // |tot - 2 * sum[i]| over every node i.
    int main()
    {
        cin >> n;

        // Node values, accumulated into the grand total as they are read.
        tot = 0;
        for (int id = 1; id <= n; ++id)
        {
            cin >> value[id];
            tot += value[id];
        }

        // A tree on n nodes has n - 1 edges; store each in both directions.
        v.resize(n + 1);
        for (int edge = 1; edge < n; ++edge)
        {
            int a, b;
            cin >> a >> b;
            v[a].push_back(b);
            v[b].push_back(a);
        }

        dfs(1);

        // sum[id] is one side of a split and tot - sum[id] is the other.
        best = INT_MAX;
        for (int id = 1; id <= n; ++id)
        {
            int diff = abs(tot - sum[id] - sum[id]);
            if (diff < best)
                best = diff;
        }

        cout << best << endl;
        return 0;
    }
    
  • + 0 comments

    Refined Version without building Node class

    from collections import defaultdict, deque
    
    def cutTheTree(data, edges):
        """Return the smallest difference between the two parts after one cut.

        The tree is rooted at node 1 and every node's subtree sum is computed
        with an explicit stack (no recursion). Cutting the edge above a
        non-root node ``i`` yields parts with sums ``subtree_sums[i]`` and
        ``total_sum - subtree_sums[i]``.

        Args:
            data: Node values; node ``k`` (1-based) has value ``data[k - 1]``.
            edges: Iterable of ``(a, b)`` pairs describing undirected edges.

        Returns:
            The minimum absolute difference over all cuts. When there is no
            node other than the root, the total sum is returned.

        Raises:
            IndexError: If ``data`` is empty.
        """
        # Adjacency sets for the undirected tree.
        neighbours = defaultdict(set)
        for a, b in edges:
            neighbours[a].add(b)
            neighbours[b].add(a)

        n = len(data)
        parent = [0] * (n + 1)
        visited = [False] * (n + 1)
        visited[1] = True

        # Pre-order walk from the root: every node lands in `order` after its
        # parent, and `parent` remembers which edge it was reached through.
        order = []
        stack = [1]
        while stack:
            i = stack.pop()
            order.append(i)
            for j in neighbours[i]:
                if not visited[j]:
                    visited[j] = True
                    parent[j] = i
                    stack.append(j)

        # Start from each node's own value (index 0 is padding), then fold
        # children into parents in reverse discovery order so every subtree
        # is complete before it is added to its parent.
        subtree_sums = [0] + list(data)
        for i in reversed(order[1:]):
            subtree_sums[parent[i]] += subtree_sums[i]

        total_sum = subtree_sums[1]

        # Minimum over every non-root node's cut.
        return min(
            (abs(total_sum - 2 * subtree_sums[i]) for i in order[1:]),
            default=total_sum,
        )
    
  • + 0 comments
    from collections import defaultdict, deque
    
    class Node:
        """A tree vertex that remembers the sum of its own subtree."""

        def __init__(self, num, val):
            # num: node number; val: the node's own value.
            self.num, self.val = num, val
            # Subtree sum, filled in by tree_sum(); None until then.
            self.data = None
            # Child Node objects, attached by whoever builds the tree.
            self.children = set()

        def tree_sum(self):
            """Return the cached subtree sum, computing it on first use.

            The computation reads each child's already-cached ``data``; it
            does not call ``tree_sum`` on the children.
            """
            if self.data is not None:
                return self.data
            if self.children:
                self.data = sum(kid.data for kid in self.children) + self.val
            else:
                # A node without children contributes only its own value.
                self.data = self.val
            return self.data
    
    def build_tree_and_calculate_sums(data, edges):
        """Create Node objects rooted at node 1 and fill in their subtree sums.

        Returns a dict mapping each node number (1-based) to its Node. Every
        node reachable from node 1 has its children attached and its cached
        sum computed.
        """
        # Undirected adjacency built from the edge list.
        neighbours = defaultdict(set)
        for a, b in edges:
            neighbours[a].add(b)
            neighbours[b].add(a)

        n = len(data)
        nodes = {num: Node(num, val) for num, val in enumerate(data, start=1)}

        # Orient the tree away from node 1, recording the order in which
        # nodes are taken off the stack (a parent always precedes its children).
        seen = [False] * (n + 1)
        seen[1] = True
        pending = [1]
        discovery = []
        while pending:
            num = pending.pop()
            discovery.append(num)
            for other in neighbours[num]:
                if seen[other]:
                    continue
                seen[other] = True
                nodes[num].children.add(nodes[other])
                pending.append(other)

        # Going through the discovery order backwards guarantees that all
        # children of a node have their sums cached before the node itself.
        for num in reversed(discovery):
            nodes[num].tree_sum()

        return nodes
    
    def cutTheTree(data, edges):
        """Return the smallest difference between the two parts after one cut.

        The result is the minimum of the whole-tree sum and, for every node
        numbered 2 and up, ``abs(total - 2 * subtree_sum)``.
        """
        nodes = build_tree_and_calculate_sums(data, edges)

        # Node 1 is the root, so its subtree sum is the sum of the whole tree.
        total_sum = nodes[1].tree_sum()

        # Cutting above node i leaves parts of size s and total_sum - s.
        candidates = [
            abs(total_sum - 2 * nodes[i].tree_sum())
            for i in range(2, len(nodes) + 1)
        ]
        return min([total_sum] + candidates)