• + 0 comments

    Took me ages to debug.

    Time complexity: O(n).

    // Prime modulus applied to every count in this file. Declared const because
    // it is only ever read; it fits in 31 bits, so any value already reduced
    // modulo `mod` fits in a `long` on every platform.
    const long mod = 1000000007L;
    
    // One vertex of the rooted tree. All members are public and are read and
    // written directly by treeMaker() and dehaka().
    class Node {
    public:
        int number;                 // label of this vertex
        long divide;                // DP value, overwritten by dehaka() for nodes with children
        long loneRed;               // DP value, overwritten by dehaka() for nodes with children
        Node* parent;               // non-owning link to the parent; null for the root
        vector<Node*> children;     // starts empty; filled in by treeMaker()

        // Stores the four values as given; `children` is left empty.
        Node (int label, long divideInit, long loneInit, Node* up)
            : number(label),
              divide(divideInit),
              loneRed(loneInit),
              parent(up) {}
    };
    
    // Builds a tree of heap-allocated Nodes rooted at label 1 from an undirected
    // edge list, using a breadth-first walk.
    //
    //   n     - number of vertices; the adjacency table gets n + 1 slots so that
    //           labels 1..n can be used as indices directly.
    //   roads - edges, each row holding the two endpoint labels.
    //
    // Returns the root. Every node is created with divide = 0 and loneRed = 1,
    // and children appear in the order their edges were listed. Nothing in this
    // function frees the nodes; the caller is responsible for them.
    //
    // Throws std::out_of_range (from vector::at) when a row has fewer than two
    // entries or an endpoint is not a valid index into the adjacency table.
    // Vertices not connected to label 1 are ignored.
    Node* treeMaker (int n, const vector<vector<int>>& roads) {
        // Always keep a slot for label 1 so the root lookup is valid even if n < 1.
        const size_t slots = static_cast<size_t>(n < 1 ? 1 : n) + 1;

        vector<vector<int>> adj(slots);
        for (const vector<int>& road : roads) {
            const int u = road.at(0);
            const int v = road.at(1);
            // at() turns a bad endpoint into an exception instead of an
            // out-of-bounds write (a negative label converts to a huge index).
            adj.at(static_cast<size_t>(u)).push_back(v);
            adj.at(static_cast<size_t>(v)).push_back(u);
        }

        // Visited flags: each label becomes a node at most once, so repeated
        // edges, self-loops and cycles cannot duplicate nodes or loop forever.
        // For a valid tree this picks exactly the non-parent neighbours.
        vector<char> seen(slots, 0);

        Node* root = new Node(1, 0, 1, nullptr);
        seen[1] = 1;

        queue<Node*> pending;
        pending.push(root);
        while (!pending.empty()) {
            Node* t = pending.front();
            pending.pop();
            for (int next : adj[t->number]) {
                if (seen[next]) continue;
                seen[next] = 1;
                Node* child = new Node(next, 0, 1, t);
                t->children.push_back(child);
                pending.push(child);
            }
        }
        return root;
    }
    
    // Bottom-up pass over the tree rooted at `current` that fills in `divide`
    // and `loneRed` for every node that has children. Nodes without children
    // are left untouched, so they keep the values they were constructed with.
    // A null `current` is a no-op.
    //
    // For a node with children c_1..c_k (all arithmetic modulo `mod`):
    //   loneRed = product of c_i.divide
    //   divide  = product of (2 * c_i.divide + c_i.loneRed)  -  loneRed
    //
    // NOTE(review): the field names suggest `loneRed` counts colourings where
    // the node has no same-coloured child and `divide` those where it has at
    // least one -- confirm against the problem statement.
    //
    // Implemented without recursion so that a very deep (path-like) tree
    // cannot exhaust the call stack.
    void dehaka (Node* current) {
        if (current == nullptr) return;

        // List the subtree so that every node comes after its parent; walking
        // the list backwards then handles all children before their parent.
        vector<Node*> order;
        order.push_back(current);
        for (size_t i = 0; i < order.size(); ++i) {
            Node* node = order[i];
            for (Node* child : node->children) order.push_back(child);
        }

        for (size_t i = order.size(); i-- > 0; ) {
            Node* node = order[i];
            if (node->children.empty()) continue;

            // long long keeps the products exact even where `long` is 32-bit.
            long long a = 1, b = 1;
            for (Node* child : node->children) {
                b = (b * child->divide) % mod;
                a = (a * ((2LL * child->divide + child->loneRed) % mod)) % mod;
            }
            // Both results are below `mod`, so they fit back into a long.
            node->loneRed = static_cast<long>(b);
            // Add mod before reducing so the difference never goes negative.
            node->divide = static_cast<long>((a - b + mod) % mod);
        }
    }
    
    // Reads the vertex count n followed by n - 1 edges (two labels each) from
    // standard input, builds the tree with treeMaker(), runs dehaka() on it and
    // prints (2 * root->divide) % mod with no trailing newline.
    //
    // Returns 1 with a message on standard error when the input cannot be
    // parsed, n < 1, or an edge endpoint lies outside 1..n; previously such
    // input left variables uninitialized and fed garbage to treeMaker().
    // The tree is not freed; the process ends right after printing.
    int main()
    {
        int n;
        if (!(cin >> n) || n < 1) {
            cerr << "invalid node count\n";
            return 1;
        }

        vector<vector<int>> roads;
        for (int i = 1; i <= n-1; i++) {
            int k, l;
            if (!(cin >> k >> l) || k < 1 || k > n || l < 1 || l > n) {
                cerr << "invalid road\n";
                return 1;
            }
            roads.push_back({k, l});
        }

        Node* root = treeMaker(n, roads);
        dehaka(root);
        // 2LL keeps the doubling exact even where `long` is 32-bit.
        // NOTE(review): the factor 2 presumably accounts for the two possible
        // colours of the root -- confirm against the problem statement.
        cout << (2LL * root->divide) % mod;
    }