You are viewing a single comment's thread. Return to all comments →
Hard question.
O(n log n): it requires three passes over the tree — the first to construct it, the second to calculate the total sum of each subtree, and the third to find the suitable partitions.
I don't know whether there is an O(n) algorithm.
// Balanced Forest, O(n log n) overall.
// balancedForest returns the smallest X >= 0 such that two edges can be cut so
// that two of the three resulting pieces have the same sum m and the third has
// sum m - X (i.e. X must be added to it); -1 if no such cut exists.
//   1. treeCreator builds the rooted tree from the undirected edge list (BFS);
//   2. computeSum fills in every subtree sum and counts how often each occurs;
//   3. search checks, for every vertex, whether its subtree can be one of the
//      three pieces, binary-searching the sums on its root-to-vertex path.
// Traversals are iterative so a long path-shaped tree cannot overflow the call
// stack, and balancedForest frees every node it allocates.
//
// NOTE(review): sums are stored as `long`, which is only 32 bits on LLP64
// platforms (e.g. 64-bit Windows); large inputs would overflow there.

// One vertex of the rooted tree.
class Node {
public:
    int number;                   // vertex id (index into C)
    int value = 0;                // weight C[number]
    long sum = 0;                 // subtree sum, filled in by computeSum
    Node* parent;                 // nullptr for the root
    std::vector<Node*> children;  // heap-allocated; released by deleteTree
    Node(int num, int val, Node* p) : number(num), value(val), parent(p) {}
};

// Builds the tree rooted at rootNumber from `edges` (pairs {u, v}, vertex ids
// 1..n); C[v] is vertex v's weight. Every Node is allocated with `new`;
// release the returned tree with deleteTree.
Node* treeCreator(int n, int rootNumber, const std::vector<int>& C,
                  const std::vector<std::vector<int>>& edges) {
    std::vector<std::vector<int>> adj(n + 1);
    for (const std::vector<int>& edge : edges) {
        adj[edge[0]].push_back(edge[1]);
        adj[edge[1]].push_back(edge[0]);
    }

    Node* root = new Node(rootNumber, C[rootNumber], nullptr);
    std::queue<Node*> pending;
    pending.push(root);
    while (!pending.empty()) {
        Node* current = pending.front();
        pending.pop();
        for (int neighbour : adj[current->number]) {
            // The edge back to the parent was already used; every other
            // neighbour becomes a child.
            if (current->parent != nullptr && neighbour == current->parent->number) {
                continue;
            }
            Node* child = new Node(neighbour, C[neighbour], current);
            current->children.push_back(child);
            pending.push(child);
        }
    }
    return root;
}

// Deletes every node of the tree rooted at `root` without recursion.
// Passing nullptr is a no-op.
static void deleteTree(Node* root) {
    std::vector<Node*> pending;
    if (root != nullptr) {
        pending.push_back(root);
    }
    while (!pending.empty()) {
        Node* node = pending.back();
        pending.pop_back();
        for (Node* child : node->children) {
            pending.push_back(child);
        }
        delete node;
    }
}

// Sets `sum` on every node of the subtree rooted at `node` to that node's
// subtree sum, and increments hash[s] once per node whose subtree sum is s.
void computeSum(Node* node, std::unordered_map<long, int>& hash) {
    // Pre-order listing: each node precedes its descendants, so walking it
    // backwards finishes every child before its parent.
    std::vector<Node*> order;
    std::vector<Node*> pending{node};
    while (!pending.empty()) {
        Node* current = pending.back();
        pending.pop_back();
        order.push_back(current);
        for (Node* child : current->children) {
            pending.push_back(child);
        }
    }
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        Node* current = *it;
        current->sum = current->value;
        for (Node* child : current->children) {
            current->sum += child->sum;
        }
        ++hash[current->sum];
    }
}

// Checks whether the subtree of `node` can be one of the three pieces and, if
// so, lowers X to the value that would have to be added. `ancestors` holds the
// subtree sums on the root-to-node path (node included). The descending
// binary_search requires them to be non-increasing, which holds when all
// weights are non-negative (assumed, not checked here).
static void considerCut(long& X, const Node* node,
                        const std::unordered_map<long, int>& hash,
                        const std::vector<long>& ancestors, const long& total) {
    const long s = node->sum;
    if (total / 3.0 < s && s <= total / 2.0) {
        // node's subtree is one of the two equal, largest pieces. The other
        // piece of sum s exists if an ancestor's subtree sums to 2s, if an
        // ancestor's subtree sums to total - s (the part above it is then s),
        // or if another vertex also has subtree sum s. The leftover piece sums
        // to total - 2s and needs 3s - total more.
        const bool viaDouble = std::binary_search(ancestors.begin(), ancestors.end(),
                                                  2 * s, std::greater<long>());
        const bool viaRest = std::binary_search(ancestors.begin(), ancestors.end(),
                                                total - s, std::greater<long>());
        const bool viaTwin = hash.at(s) > 1;
        if (viaDouble || viaRest || viaTwin) {
            X = std::min(X, 3 * s - total);
        }
    } else if (s <= total / 3.0 && (total - s) % 2 == 0) {
        // node's subtree is the smallest piece; the other two must each sum
        // to t, and node's piece needs t - s more.
        const long t = (total - s) / 2;
        const bool viaAncestor = std::binary_search(ancestors.begin(), ancestors.end(),
                                                    s + t, std::greater<long>());
        const bool tOnPath = std::binary_search(ancestors.begin(), ancestors.end(),
                                                t, std::greater<long>());
        const auto it = hash.find(t);
        // A vertex with subtree sum t must lie off the root-to-node path; if
        // one path vertex already sums to t, a second such vertex is needed.
        const bool viaOffPath = it != hash.end() && (!tOnPath || it->second > 1);
        if (viaAncestor || viaOffPath) {
            X = std::min(X, t - s);
        }
    }
}

// Runs considerCut on every vertex of the subtree rooted at `node`, in
// pre-order, with `ancestors` extended by that vertex's root-to-vertex path
// sums. Uses an explicit stack; `ancestors` is back in its entry state on return.
void search(long& X, Node* node, const std::unordered_map<long, int>& hash,
            std::vector<long>& ancestors, const long& total) {
    struct Frame {
        Node* node;
        std::vector<Node*>::size_type nextChild;
    };
    std::vector<Frame> frames;

    ancestors.push_back(node->sum);
    considerCut(X, node, hash, ancestors, total);
    frames.push_back({node, 0});

    while (!frames.empty()) {
        Frame& top = frames.back();
        if (top.nextChild < top.node->children.size()) {
            Node* child = top.node->children[top.nextChild++];
            ancestors.push_back(child->sum);
            considerCut(X, child, hash, ancestors, total);
            frames.push_back({child, 0});  // `top` may dangle now; not used again
        } else {
            ancestors.pop_back();
            frames.pop_back();
        }
    }
}

// Solves one query: n vertices, weights C[1..n] (C[0] unused), n - 1 edges.
// Returns the minimum value to add (see the header comment) or -1.
long balancedForest(int n, const std::vector<int>& C,
                    const std::vector<std::vector<int>>& edges) {
    std::unordered_map<long, int> sumCount;
    std::vector<long> pathSums;

    Node* root = treeCreator(n, 1, C, edges);
    computeSum(root, sumCount);

    long X = LONG_MAX;
    search(X, root, sumCount, pathSums, root->sum);
    deleteTree(root);  // the tree is no longer needed once X is known

    return (X == LONG_MAX) ? -1 : X;
}

// Reads q queries from stdin. Each query is n, then n weights, then n - 1
// edges "u v". Prints one answer per line.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int q = 0;
    std::cin >> q;
    for (int query = 1; query <= q; ++query) {
        int n = 0;
        std::cin >> n;

        std::vector<int> C = {0};  // padding so that C[v] is vertex v's weight
        C.reserve(n + 1);
        for (int j = 1; j <= n; ++j) {
            int k = 0;
            std::cin >> k;
            C.push_back(k);
        }

        std::vector<std::vector<int>> edges;
        edges.reserve(n > 0 ? n - 1 : 0);
        for (int j = 1; j <= n - 1; ++j) {
            int u = 0;
            int v = 0;
            std::cin >> u >> v;
            edges.push_back({u, v});
        }

        std::cout << balancedForest(n, C, edges) << '\n';
    }
}
Seems like cookies are disabled on this browser, please enable them to open this website
Balanced Forest
You are viewing a single comment's thread. Return to all comments →
Hard question.
O(n log n): it requires three passes over the tree — the first to construct it, the second to calculate the total sum of each subtree, and the third to find the suitable partitions.
I don't know whether there is an O(n) algorithm.