• + 0 comments

    hard question

    O(n log(n)); requires three passes over the tree: the first to construct it, the second to calculate the total sum of each subtree, and the third to find the suitable partitions.

    I don't know if there's an O(n) algorithm.

    // O(n log(n))
    
    // One vertex of the rooted tree. Nodes refer to each other through raw,
    // non-owning pointers; this type never frees anything itself.
    struct Node {
        int number;                    // label of the vertex
        int value = 0;                 // the vertex's own value
        long sum = 0;                  // total of the subtree rooted here; stays 0 until it is filled in
        Node* parent;                  // node one level up (the tree builder passes null for the root)
        std::vector<Node*> children;   // nodes one level down

        // Creates a childless node labelled `num`, holding `val`, attached below `p`.
        Node(int num, int val, Node* p)
            : number{num}, value{val}, parent{p} {}
    };
    
    // Builds a rooted tree of heap-allocated Node objects from an undirected edge list.
    //
    //   n          - number of vertices; the adjacency table has n + 1 slots, so every
    //                endpoint label must lie in [0, n].
    //   rootNumber - label of the vertex that becomes the root.
    //   C          - per-vertex values, indexed by label.
    //   edges      - each entry holds the two endpoint labels of one edge (edge[0], edge[1]).
    //
    // Returns the root. Every node is created with `new`, so the caller owns the whole
    // tree and is responsible for releasing it.
    //
    // Precondition: the edges must form a tree. Only the edge leading back to the parent
    // is skipped while expanding a node, so a cycle would make the expansion run forever.
    Node* treeCreator(int n, int rootNumber, const std::vector<int>& C, const std::vector<std::vector<int>>& edges) {
        std::vector<std::vector<int>> adj(n + 1);
        for (const std::vector<int>& edge : edges) {  // bind by reference: no per-edge copy
            adj[edge[0]].push_back(edge[1]);
            adj[edge[1]].push_back(edge[0]);
        }

        Node* root = new Node(rootNumber, C[rootNumber], nullptr);
        std::queue<Node*> pending;
        pending.push(root);
        // Breadth-first expansion: every dequeued node gets one child per neighbour.
        while (!pending.empty()) {
            Node* current = pending.front();
            pending.pop();
            for (int neighbour : adj[current->number]) {
                // Do not walk back up the edge we arrived through (the root has no parent).
                if (current->parent != nullptr && neighbour == current->parent->number) continue;
                Node* child = new Node(neighbour, C[neighbour], current);
                current->children.push_back(child);
                pending.push(child);
            }
        }
        return root;
    }
    
    // Post-order walk over the subtree rooted at `node`: stores in every node's `sum`
    // the total of the values in its subtree, and counts in `hash` how many nodes
    // end up with each subtree total.
    void computeSum(Node* node, std::unordered_map<long, int>& hash) {
        long subtotal = node->value;
        for (Node* kid : node->children) {
            computeSum(kid, hash);   // the child's sum is final once this returns
            subtotal += kid->sum;
        }
        node->sum = subtotal;
        hash[subtotal] += 1;
    }
    
    void search(long& X, Node* node, const unordered_map<long, int>& hash, vector<long>& ancestors, const long& total) {
        ancestors.push_back(node->sum);
        if (total/3.0 < node->sum and node->sum <= total/2.0) {
            bool found1 = binary_search(ancestors.begin(), ancestors.end(), 2 * node->sum, greater<long>());
            bool found2 = binary_search(ancestors.begin(), ancestors.end(), total - node->sum, greater<long>());
            int I = hash.at(node->sum);
            if (found1 or found2 or I > 1) X = min(X, 3 * node->sum - total);
        }
        else if (node->sum <= total/3.0 and (total - node->sum) % 2 == 0) {
            long t = (total - node->sum) / 2;
            bool found1 = binary_search(ancestors.begin(), ancestors.end(), node->sum + t, greater<long>());
            bool found2 = binary_search(ancestors.begin(), ancestors.end(), t, greater<long>());
            auto it = hash.find(t);
            bool flag = (it != hash.end()) and (!found2 or it->second > 1);
            if (found1 or flag) X = min(X, t - node->sum);
        }
        for (Node* child : node->children) search(X, child, hash, ancestors, total);
        ancestors.pop_back();
    }
    
    long balancedForest(int n, const vector<int>& C, const vector<vector<int>>& edges) {
        long X = LONG_MAX;
        unordered_map<long, int> hash;
        vector<long> ancestorSums;
        Node* root = treeCreator(n, 1, C, edges);
        computeSum(root, hash);
        search(X, root, hash, ancestorSums, root->sum);
        return (X == LONG_MAX) ? -1 : X;
    }
    
    // Reads q queries from standard input and prints one answer per line.
    // Each query is: n, then n vertex values, then n - 1 edges given as label pairs.
    // Returns 1 (after a note on stderr) as soon as the input is truncated, not
    // numeric, or names a vertex outside 1..n; returns 0 otherwise.
    int main()
    {
        const auto malformed = [] {
            std::cerr << "malformed input\n";
            return 1;
        };

        int q = 0;
        if (!(std::cin >> q)) return malformed();

        for (int i = 1; i <= q; ++i) {
            int n = 0;
            if (!(std::cin >> n) || n < 1) return malformed();

            std::vector<int> C;
            C.reserve(static_cast<std::size_t>(n) + 1);
            C.push_back(0);  // slot 0 is a placeholder so values are indexed by 1-based label
            for (int j = 1; j <= n; ++j) {
                int k = 0;
                if (!(std::cin >> k)) return malformed();
                C.push_back(k);
            }

            std::vector<std::vector<int>> edges;
            edges.reserve(static_cast<std::size_t>(n) - 1);
            for (int j = 1; j <= n - 1; ++j) {
                int u = 0;
                int v = 0;
                if (!(std::cin >> u >> v)) return malformed();
                // Labels index the adjacency table and the value table further down.
                if (u < 1 || u > n || v < 1 || v > n) return malformed();
                edges.push_back({u, v});
            }

            std::cout << balancedForest(n, C, edges) << '\n';
        }
        return 0;
    }