77 Discussions



    The return type should be long, not int, to pass all test cases.


    For C++11, remember to also change the result type in the main function, not just in your solution function. What a joke.


    Hard question.

    O(n log(n)). It requires three passes over the tree: the first to construct it, the second to calculate the sum of each subtree, and the third to find suitable partitions.

    I don't know whether there's an O(n) algorithm.

    // O(n log(n)) overall:
    //   - one BFS builds the rooted tree;
    //   - one post-order pass computes every subtree sum;
    //   - one DFS does two binary searches over the root-to-node path per node,
    //     plus average-O(1) hash lookups.
    // Assumes every C[i] >= 1 (the problem's constraint). Subtree sums then strictly
    // decrease from root to leaf, which keeps `ancestors` sorted in descending order.
    // It also means a node's ancestors have larger sums and its descendants smaller ones.

    /// A vertex of the rooted tree produced by treeCreator.
    struct Node {
        int number, value = 0;        // vertex id and its own value C[number]
        long sum = 0;                 // subtree sum, filled in by computeSum
        Node* parent;                 // nullptr for the root
        std::vector<Node*> children;  // heap-allocated; released by destroyTree
        Node(int num, int val, Node* p) : number(num), value(val), parent(p) {}
    };

    /// Builds the tree rooted at `rootNumber` from an undirected edge list (BFS).
    /// @param n      number of vertices; ids are 1..n
    /// @param C      C[i] is the value of vertex i (C[0] is a placeholder)
    /// @param edges  each entry is {u, v}
    /// @return heap-allocated root; release it with destroyTree
    Node* treeCreator(int n, int rootNumber, const std::vector<int>& C,
                      const std::vector<std::vector<int>>& edges) {
        std::vector<std::vector<int>> adj(n + 1);
        for (const std::vector<int>& edge : edges) {
            adj[edge[0]].push_back(edge[1]);
            adj[edge[1]].push_back(edge[0]);
        }
        Node* root = new Node(rootNumber, C[rootNumber], nullptr);
        std::queue<Node*> Q;
        Q.push(root);
        while (!Q.empty()) {
            Node* current = Q.front();
            Q.pop();
            for (int neighbour : adj[current->number]) {
                // In a tree, the only already-visited neighbour is the parent.
                if (current->parent != nullptr && neighbour == current->parent->number) continue;
                Node* child = new Node(neighbour, C[neighbour], current);
                current->children.push_back(child);
                Q.push(child);
            }
        }
        return root;
    }

    /// Frees every node created by treeCreator. Iterative, so deep trees cannot overflow the stack.
    void destroyTree(Node* root) {
        std::vector<Node*> pending;
        if (root != nullptr) pending.push_back(root);
        while (!pending.empty()) {
            Node* node = pending.back();
            pending.pop_back();
            for (Node* child : node->children) pending.push_back(child);
            delete node;
        }
    }

    /// Post-order pass: sets node->sum for every node in the subtree and
    /// increments hash[s] once for every subtree whose sum is s.
    /// Recursion depth equals tree height (up to n for a path-shaped tree).
    void computeSum(Node* node, std::unordered_map<long, int>& hash) {
        node->sum = node->value;
        for (Node* child : node->children) {
            computeSum(child, hash);
            node->sum += child->sum;
        }
        ++hash[node->sum];
    }

    /// DFS that, for every node, tries to use the node's subtree as one of the
    /// three final pieces and lowers X to the cheapest new-node value found.
    /// @param X          running minimum (LONG_MAX while nothing is found)
    /// @param hash       subtree-sum -> number of subtrees with that sum
    /// @param ancestors  sums of the strict ancestors of `node`, root first (descending)
    /// @param total      sum of the whole tree
    void search(long& X, Node* node, const std::unordered_map<long, int>& hash,
                std::vector<long>& ancestors, const long& total) {
        const long s = node->sum;
        if (total / 3.0 < s && s <= total / 2.0) {
            // This subtree is one of the two equal, larger pieces. The new node lifts
            // the third piece, of size total - 2s, up to s. The other s-piece comes from:
            //   - a nested cut at an ancestor of sum 2s or total - s; or
            //   - another subtree of sum s, necessarily disjoint because
            //     ancestors are larger and descendants smaller.
            const bool found1 = std::binary_search(ancestors.begin(), ancestors.end(), 2 * s, std::greater<long>());
            const bool found2 = std::binary_search(ancestors.begin(), ancestors.end(), total - s, std::greater<long>());
            const int sameSumCount = hash.at(s);  // includes this node itself
            if (found1 || found2 || sameSumCount > 1) X = std::min(X, 3 * s - total);
        } else if (s <= total / 3.0 && (total - s) % 2 == 0) {
            // This subtree is the smallest piece. The rest must split into two pieces
            // of t each, and the new node costs t - s.
            const long t = (total - s) / 2;
            const bool found1 = std::binary_search(ancestors.begin(), ancestors.end(), s + t, std::greater<long>());
            const bool found2 = std::binary_search(ancestors.begin(), ancestors.end(), t, std::greater<long>());
            // Subtrees of sum t that cannot be the disjoint second cut:
            //   - the (at most one) ancestor with that sum;
            //   - this node itself when t == s (the three-equal-pieces case).
            const int excluded = (found2 ? 1 : 0) + (t == s ? 1 : 0);
            const auto it = hash.find(t);
            const bool disjointFound = it != hash.end() && it->second > excluded;
            if (found1 || disjointFound) X = std::min(X, t - s);
        }
        ancestors.push_back(s);
        for (Node* child : node->children) search(X, child, hash, ancestors, total);
        ancestors.pop_back();
    }

    /// Minimum value of one added node that lets two edge cuts split the tree into
    /// three trees of equal sum, or -1 if that is impossible.
    /// @param n      number of vertices (ids 1..n, tree rooted at 1)
    /// @param C      C[i] is the value of vertex i (C[0] is a placeholder)
    /// @param edges  n - 1 undirected edges {u, v}
    long balancedForest(int n, const std::vector<int>& C, const std::vector<std::vector<int>>& edges) {
        if (n < 1) return -1;
        long X = LONG_MAX;
        std::unordered_map<long, int> hash;
        std::vector<long> ancestorSums;
        Node* root = treeCreator(n, 1, C, edges);
        computeSum(root, hash);
        search(X, root, hash, ancestorSums, root->sum);
        destroyTree(root);
        return (X == LONG_MAX) ? -1 : X;
    }
    int main()
        int q, n, k, u, v;
        cin >> q;
        for (int i=1; i <= q; ++i) {
            vector<int> C = {0};
            vector<vector<int>> edges;
            cin >> n;
            for (int j=1; j <= n; ++j) {
                cin >> k;
            for (int j=1; j <= n-1; ++j) {
                cin >> u >> v;
                edges.push_back({u, v});
            cout << balancedForest(n, C, edges) << '\n';

    Anyone can try solving it with this JS code:

    function addChild(tree, c, node, child) {
        tree[node] = tree[node] ?? { children: new Set(), val: c[node - 1] };
    function cleanUpTree(tree, node = 1, parent = -1, vMap = {}) {
        let { children, val } = tree[node];
        for (const child of children) {
            const [v] = cleanUpTree(tree, child, node, vMap);
            val += v;
        vMap[val] = (vMap[val] ?? 0) + 1;
        tree[node].val = val;
        return [val, vMap];
    function getMin(a, b) {
        return a === -1 || a > b ? b : a;
    /**
     * Generator: depth-first walk below `node` that yields [subtreeSum, parentVMap]
     * for every node except node 1. Children are yielded before their parent (post-order).
     * `parentVMap` counts the subtree sums recorded on the way down. When a node is
     * yielded it already includes that node's own sum.
     * Assumes `val` already holds subtree sums and `children` no longer contains the
     * parent (presumably both done by cleanUpTree — TODO confirm); otherwise this
     * recursion would not terminate.
     * NOTE(review): the scraped post appears to have lost lines here. Nothing visible
     * removes `val` from parentVMap when leaving a subtree, so sums of finished sibling
     * subtrees would remain in later "ancestor" lookups. `vMap` is also never read in
     * the visible body, and the closing brace is missing. TODO restore from the original post.
     */
    function* dfsIter(tree, vMap, node = 1, parentVMap = {}) {
        const { val, children } = tree[node];
        // Mark this subtree sum as present on the current root-to-node path.
        parentVMap[val] = (parentVMap[val] ?? 0) + 1;
        for (const child of children)
            yield* dfsIter(tree, vMap, child, parentVMap);
        // Node 1 (the default root) has no parent edge to cut, so it is never yielded.
        if (node === 1) return;
        yield [val, parentVMap];
    /**
     * Searches for the smallest new-node value that balances the forest.
     * Candidates are combined with getMin; -1 means no candidate was found
     * or the input is empty.
     * @param {number[]} c node values; node k has value c[k - 1]
     * @param {number[][]} edges undirected edges [u, v] with 1-based ids
     * @returns {number}
     * NOTE(review): this snippet was scraped with its closing braces missing. Several
     * blocks below do not visibly end, so the intended nesting is partly unknown.
     */
    function balancedForest(c, edges) {
        let ret = -1;
        // Nothing to split without values or edges.
        if (!c?.length || !edges?.length) return ret;
        const tree = {};
        // Insert every edge in both directions; the parent links still need pruning
        // before a rooted walk (presumably done inside cleanUpTree — TODO confirm).
        for (const [begin, end] of edges) {
            addChild(tree, c, begin, end);
            addChild(tree, c, end, begin);
        // NOTE(review): the closing brace of the edge loop is missing in the post.
        // total: sum of all values; vMap: subtree sum -> number of subtrees with that sum.
        const [total, vMap] = cleanUpTree(tree);
        // v1: subtree sum of a non-root node (first cut above it);
        // parentVMap: sums recorded by dfsIter on the path to that node.
        for (const [v1, parentVMap] of dfsIter(tree, vMap)) {
            const rest = total - v1;
            // A single cut splits the tree into two equal halves: add a separate node worth v1.
            // NOTE(review): this `if` has no visible closing brace; it is unclear whether the
            // loop below was meant to be inside it.
            if (rest === v1) {
                ret = getMin(ret, v1);
            // Try the three candidate sizes v2 for the piece produced by the second cut.
            for (const v2 of [rest / 2, v1, rest - v1]) {
                // A piece of size v2 must exist, either as some subtree (vMap) or via an
                // ancestor of sum v1 + v2 (cutting it and this node leaves v2 in between).
                // NOTE(review): vMap counts every subtree, including this node's own (so
                // vMap[v1] >= 1), its ancestors and its descendants. None of these is a
                // valid disjoint second cut. Unless lost lines adjust vMap, this check can
                // accept invalid splits.
                if (!vMap[v2] && !parentVMap[v1 + v2]) continue;
                const v3 = rest - v2;
                // Sort the three pieces in descending order (this `c` shadows the parameter).
                // The two largest must already be equal, so the new node only tops up the smallest.
                const [a, b, c] = [v1, v2, v3].sort((a, b) => b - a);
                if (a !== b) continue;
                ret = getMin(ret, a - c);
        // NOTE(review): closing braces for the loops and the function are missing in the post.
        return ret;