Kitty's Calculations on a Tree

Sort by

recency

|

63 Discussions

|

  • + 1 comment
    import sys, math
    
    # Raise the interpreter's recursion cap so deep recursion over large
    # trees is not cut short by the default limit.
    sys.setrecursionlimit(1000000)
    
    def preprocess_original_tree(node, parent):
        """Depth-first walk of the original tree from ``node`` (whose parent is ``parent``).

        Fills the module-level tables used by the LCA helpers:
        ``tin``/``tout`` get entry/exit stamps from the shared global ``timer``,
        ``up[v][i]`` becomes the 2^i-th ancestor of ``v`` (or -1), and
        ``depth[child]`` is set to ``depth[v] + 1`` for every descendant.

        An explicit stack replaces recursion: a path-shaped tree needs one
        frame per level, which can exhaust the C stack even with a raised
        recursion limit. Neighbours are visited in ``adj`` order, exactly as
        the recursive version did, so all stamps are identical.
        """
        global timer

        def enter(v, p):
            # "Jumping" into v: stamp its entry time and fill its ancestor row.
            global timer
            tin[v] = timer
            timer += 1
            up[v][0] = p
            for i in range(1, LOG):
                # The 2^i ancestor is the 2^(i-1) ancestor of the 2^(i-1)
                # ancestor; rows of ancestors are already complete.
                if up[v][i - 1] != -1:
                    up[v][i] = up[up[v][i - 1]][i - 1]

        enter(node, parent)
        # Each frame keeps a live iterator over the vertex's neighbours, so
        # returning to a frame resumes with the next unvisited child.
        stack = [(node, parent, iter(adj[node]))]
        while stack:
            v, p, neighbours = stack[-1]
            for child in neighbours:
                if child != p:
                    depth[child] = depth[v] + 1
                    enter(child, v)
                    stack.append((child, v, iter(adj[child])))
                    break
            else:
                # All children of v are finished: "jumping" out of v.
                stack.pop()
                tout[v] = timer
                timer += 1
    
    def lift_node(node, k):
        """Return the vertex reached by climbing ``k`` parent links from ``node``.

        Each set bit of ``k`` is resolved with a single lookup in the
        binary-lifting table ``up``, scanning from bit ``LOG - 1`` down to 0.
        Assumes ``0 <= k < 2**LOG`` and that ``node`` has at least ``k``
        ancestors (the caller lifts by a depth difference).
        """
        bit = LOG - 1
        while bit >= 0:
            if (k >> bit) & 1:
                node = up[node][bit]
            bit -= 1
        return node
    
    def get_lca(u, v):
        """Return the lowest common ancestor of ``u`` and ``v`` in the original tree.

        The deeper vertex is first lifted to the other's depth. If they then
        coincide, that vertex is the answer. Otherwise both climb in lockstep
        by every power-of-two jump that keeps them distinct, ending as children
        of the LCA.
        """
        deep, shallow = (u, v) if depth[u] >= depth[v] else (v, u)
        deep = lift_node(deep, depth[deep] - depth[shallow])
        if deep == shallow:
            return deep

        for i in reversed(range(LOG)):
            a, b = up[deep][i], up[shallow][i]
            if a != b:
                deep, shallow = a, b
        return up[deep][0]
    
    def get_distance(u, v):
        """Return the number of edges on the path between ``u`` and ``v``.

        Each endpoint contributes its depth below their lowest common
        ancestor; depths come from the original tree's preprocessing.
        """
        lca_depth = depth[get_lca(u, v)]
        return (depth[u] - lca_depth) + (depth[v] - lca_depth)
    
    def build_virtual_tree(nodes):
        """Build the virtual tree over ``nodes`` and return ``(tree, vt_nodes)``.

        ``nodes`` is sorted in place by DFS entry time. The vertex set is
        ``nodes`` plus the LCA of every pair that is adjacent in that order.
        ``vt_nodes`` lists those vertices by entry time, and ``tree`` maps
        each of them to its children in the virtual tree.
        """
        entry = tin.__getitem__
        nodes.sort(key=entry)
        members = set(nodes)
        members.update(get_lca(a, b) for a, b in zip(nodes, nodes[1:]))
        vt_nodes = sorted(members, key=entry)

        tree = {v: [] for v in vt_nodes}
        # `chain` is the current path of virtual ancestors, innermost last.
        chain = []
        for v in vt_nodes:
            # A vertex whose exit stamp precedes v's entry is not v's
            # ancestor; discard it (and so on up the path).
            while chain and tout[chain[-1]] < tin[v]:
                chain.pop()
            if chain:
                tree[chain[-1]].append(v)
            chain.append(v)
        return tree, vt_nodes
    
    def solve_query(query_nodes):
        """Return the sum of u * v * dist(u, v) over unordered pairs of ``query_nodes``, mod MOD.

        Every original-tree edge contributes (sum of queried labels below it)
        * (sum of queried labels above it). A virtual-tree edge parent -> child
        stands for ``depth[child] - depth[parent]`` original edges that all
        split the query set the same way, so it contributes that product times
        its length.

        ``query_nodes`` is sorted in place as a side effect of
        ``build_virtual_tree``. Subtree sums are accumulated iteratively: a
        recursive DFS over a chain-shaped virtual tree can get deep enough to
        exhaust the stack.
        """
        if len(query_nodes) < 2:
            return 0
        s_total = sum(query_nodes)
        vt, vt_nodes = build_virtual_tree(query_nodes)

        # Queried vertices weigh their own label; LCA-only helper vertices weigh 0.
        members = set(query_nodes)
        subtree_sum = {v: (v if v in members else 0) for v in vt_nodes}

        res = 0
        # vt_nodes is ordered by DFS entry time, so every virtual child comes
        # after its parent; walking it backwards finishes children first.
        for u in reversed(vt_nodes):
            for v in vt[u]:
                sub = subtree_sum[v]
                # u is v's ancestor in the virtual tree, so their distance is
                # a plain depth difference -- no LCA query needed.
                length = depth[v] - depth[u]
                res = (res + sub * (s_total - sub) % MOD * length) % MOD
                subtree_sum[u] += sub
        return res
    
    MOD = 10**9 + 7
    timer = 0  # shared DFS clock, advanced by preprocess_original_tree

    # The whole input as one token stream: "n q", n-1 edges, then q queries
    # of the form "k v1 ... vk".
    tokens = iter(sys.stdin.read().split())
    n, q = int(next(tokens)), int(next(tokens))

    # Binary-lifting table width and per-vertex arrays (labels 1..n).
    LOG = int(math.log2(n)) + 1
    up = [[-1] * LOG for _ in range(n + 1)]
    depth = [0] * (n + 1)
    tin = [0] * (n + 1)
    tout = [0] * (n + 1)

    # Undirected adjacency lists of the original tree.
    adj = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        a, b = int(next(tokens)), int(next(tokens))
        adj[a].append(b)
        adj[b].append(a)

    preprocess_original_tree(1, -1)

    answers = []
    for _ in range(q):
        size = int(next(tokens))
        members = [int(next(tokens)) for _ in range(size)]
        answers.append(str(solve_query(members)))

    sys.stdout.write("\n".join(answers))
    
  • + 0 comments

    from collections import Counter, defaultdict

    MOD = 10**9 + 7


    def read_row():
        """Read one line from stdin and lazily yield its integers."""
        return (int(x) for x in input().split())


    def mul(x, y):
        """Return x * y reduced mod MOD."""
        return (x * y) % MOD


    def add(*args):
        """Return the sum of args reduced mod MOD."""
        return sum(args) % MOD


    def sub(x, y):
        """Return x - y reduced mod MOD (always non-negative)."""
        return (x - y) % MOD


    def kitty_scores(n, edges, queries):
        """Return, per query set, the sum of u * v * dist(u, v) over its pairs, mod MOD.

        ``edges`` are the n - 1 tree edges ``(u, v)`` on labels 1..n, and
        ``queries`` is a list of node-label lists.

        Subtree summaries are merged small-to-large. For each query set present
        in a subtree, a node keeps an entry ``(anchor_depth, node_sum, flow)``:
        ``node_sum`` is the sum of that set's labels in the subtree, and
        ``flow`` is the sum of ``x * (depth[x] - anchor_depth)`` over those
        labels ``x``.
        """
        # Adjacency list of the tree.
        adj_list = defaultdict(list)
        for u, v in edges:
            adj_list[u].append(v)
            adj_list[v].append(u)

        # Element -> indices of the query sets containing it. Keyed by every
        # label (not only those appearing in an edge) so n == 1 works.
        elements = {v: set() for v in range(1, n + 1)}
        for set_no, members in enumerate(queries):
            for x in members:
                elements[x].add(set_no)

        # BFS from node 1 for parents and depths. `order` is level order, so
        # walking it backwards visits every child before its parent.
        root = 1
        parent = {root: None}
        depth = {root: 0}
        order = []
        current = [root]
        current_depth = 0
        while current:
            current_depth += 1
            order.extend(current)
            nxt = []
            for node in current:
                for neighbor in adj_list[node]:
                    if neighbor not in parent:
                        parent[neighbor] = node
                        depth[neighbor] = current_depth
                        nxt.append(neighbor)
            current = nxt

        score = Counter()
        # state[node] = {set_no: (anchor_depth, node_sum, flow)}
        state = {}
        for node in reversed(order):
            # Each child's summary is consumed exactly once; pop it to free memory.
            states = [state.pop(neighbor) for neighbor in adj_list[node]
                      if neighbor != parent[node]]
            largest = {s: (depth[node], node, 0) for s in elements[node]}

            # Small-to-large: reuse the biggest summary as the accumulator and
            # merge every other summary into it.
            if states:
                max_index = max(range(len(states)), key=lambda i: len(states[i]))
                if len(states[max_index]) > len(largest):
                    states[max_index], largest = largest, states[max_index]

            sets = defaultdict(list)
            for cur_state in states:
                for set_no, entry in cur_state.items():
                    sets[set_no].append(entry)

            for set_no, entries in sets.items():
                if len(entries) == 1 and set_no not in largest:
                    # Only one branch holds this set, so no new pairs meet here.
                    largest[set_no] = entries[0]
                    continue

                if set_no in largest:
                    entries.append(largest.pop(set_no))

                # Re-anchor every entry at `node` and total them.
                total_flow = 0
                total_node_sum = 0
                for node_depth, node_sum, node_flow in entries:
                    flow_delta = mul(node_depth - depth[node], node_sum)
                    total_flow = add(total_flow, flow_delta, node_flow)
                    total_node_sum += node_sum

                # Pairs split across different entries meet at `node`; their
                # total is sum_i node_sum_i * (total_flow - flow_i).
                set_score = 0
                for node_depth, node_sum, node_flow in entries:
                    node_flow = add(mul(node_depth - depth[node], node_sum), node_flow)
                    diff = mul(sub(total_flow, node_flow), node_sum)
                    set_score = add(set_score, diff)

                score[set_no] = add(score[set_no], set_score)
                largest[set_no] = (depth[node], total_node_sum, total_flow)

            state[node] = largest

        return [score[i] for i in range(len(queries))]


    if __name__ == "__main__":
        n, q = read_row()
        edges = [tuple(read_row()) for _ in range(n - 1)]
        queries = []
        for _ in range(q):
            read_row()  # the set size k; the next line lists its k members
            queries.append(list(read_row()))
        print(*kitty_scores(n, edges, queries), sep='\n')

  • + 0 comments

    I was able to find some useful hints in the discussions, but they're buried below a mountain of misinformation and nonfunctional solutions. So if you're looking for help, here are the basic components of the solution (or at least, of a solution).

    1. Use an O(log n) lowest-common-ancestor (LCA) algorithm such as binary lifting.
    2. Order the nodes of the full tree using a post-order traversal.
    3. Process the nodes of each individual query using a priority queue keyed by the aforementioned node ordering.
    4. Each term in the summation you are computing can be broken into two terms by considering the lowest common ancestor $w = \mathrm{lca}(u, v)$ of the two nodes: $u v \, \mathrm{dist}(u, v) = u v \, (d_u - d_w) + u v \, (d_v - d_w)$, where $d_x$ is the depth of $x$. Using this fact, you can precompute a few values for any pair of subtrees that allow you to compute their combined sum in constant time.
  • + 0 comments

    I'm not sure whether a JavaScript/TypeScript solution can pass. I adapted a working C++ solution: it passes tests 1–6, gets 7–9 wrong, and times out on everything else.

    Code:

    'use strict';

    // Buffer all of stdin as UTF-8 text; it is split into lines once input ends.
    process.stdin.resume();
    process.stdin.setEncoding('utf-8');

    let inputString: string = '';
    let inputLines: string[] = [];

    process.stdin.on('data', (chunk: string): void => {
        inputString = inputString + chunk;
    });
    
    // Index of the next unread line. Advancing an index is O(1), whereas
    // Array.prototype.shift() may move every remaining element on each call.
    let currentLine = 0;

    /** Return the next input line, or '' once all lines have been consumed. */
    function readLine():string{
        // Returning '' rather than undefined keeps the declared string type honest.
        return currentLine < inputLines.length ? inputLines[currentLine++] : '';
    }
    
    const K = 1000000007;
    const L = 200001;

    // parents[v] is v's parent label; children[v] accumulates a per-query
    // subtree sum. Both are indexed by vertex label.
    let parents : number[] = Array(L);
    let children : number[] = Array(L).fill(0);
    /** Zero every per-query subtree sum in place instead of reallocating. */
    function resetChildren(){
        children.fill(0);
    }
    // Length L, like the other per-vertex arrays; the original L-1 left no
    // preallocated slot for index L-1 (= 200000, presumably the largest label).
    let valuesSet : boolean[] = Array(L).fill(false);
    /** Clear every per-query membership flag in place. */
    function resetValuesSet(){
        valuesSet.fill(false);
    }
    
    // Once stdin closes, split the buffered text into lines and run the solver.
    process.stdin.on('end', (): void => {
        inputLines = inputString.split('\n');
        inputString = '';
        main();
    });
    
    
    /**
     * Answer every query: for each set, the sum of u * v * dist(u, v) over its
     * pairs, mod K. Each tree edge j -> parents[j] contributes
     * (queried labels below it) * (queried labels above it).
     *
     * NOTE(review): each query is still a full O(n) pass over the tree, so the
     * largest inputs may still time out; a virtual tree over the queried
     * vertices avoids that.
     */
    function main(){
        const [n, q] = initMain();
        // `order` lists every vertex after its parent (BFS from vertex 1).
        const order = prepareEdges(n);
        for (let i = 0; i < q; i++) {
            const k = parseInt(readLine());
            // The member list always occupies its own line, even when k < 2.
            const memberLine = readLine();
            if (k < 2) {
                // A single vertex forms no pairs.
                process.stdout.write(`0\n`);
                continue;
            }
            const query = parseInts(memberLine);

            resetValuesSet();
            resetChildren();

            let valuesSum = 0;
            for (const v of query) {
                valuesSum += v;
                valuesSet[v] = true;
            }

            // Children before parents: children[j] already holds the sum of
            // queried labels strictly below j when j is reached.
            let sum = 0;
            for (let idx = order.length - 1; idx >= 0; idx--) {
                const j = order[idx];
                let a = children[j];
                if (valuesSet[j]) {
                    a += j;
                }
                if (a) {
                    // The raw product a * (valuesSum - a) can exceed 2^53 and
                    // lose precision, so reduce both factors and multiply safely.
                    sum = (sum + mulMod(a % K, (valuesSum - a) % K)) % K;
                }
                children[parents[j]] += a;
            }
            process.stdout.write(`${sum}\n`);
        }
    }

    /** Read the header line "n q". */
    function initMain(){
        return parseInts(readLine());
    }

    /** Split a line on whitespace into integers, ignoring empty tokens. */
    function parseInts(line: string): number[] {
        return line.trim().split(/\s+/).filter(s => s.length > 0).map(s => parseInt(s, 10));
    }

    /** Return (x * y) % K for 0 <= x, y < K while keeping intermediates below 2^53. */
    function mulMod(x: number, y: number): number {
        // Split y into 16-bit halves: x * hi < 2^44 and x * lo < 2^46.
        const hi = Math.floor(y / 65536);
        const lo = y % 65536;
        return ((x * hi) % K * 65536 + x * lo) % K;
    }

    /**
     * Read the n - 1 edges and orient the tree from root 1 with a BFS.
     * Edge endpoints say nothing about direction, so the smaller label is not
     * necessarily the parent. Sets parents[v] for every vertex (parents[1] = 0)
     * and returns the BFS order, in which every vertex follows its parent.
     */
    function prepareEdges(n:number): number[] {
        const adj: number[][] = [];
        for (let v = 0; v <= n; v++) {
            adj.push([]);
        }
        for (let i = 1; i < n; i++){
            const [a, b] = parseInts(readLine());
            adj[a].push(b);
            adj[b].push(a);
        }
        parents[1] = 0;
        const order: number[] = [1];
        const seen = new Uint8Array(n + 1);
        seen[1] = 1;
        for (let head = 0; head < order.length; head++) {
            const u = order[head];
            for (const w of adj[u]) {
                if (!seen[w]) {
                    seen[w] = 1;
                    parents[w] = u;
                    order.push(w);
                }
            }
        }
        return order;
    }
    
  • + 0 comments

    WTF, I can do each query in O(n) time, but I get timeouts on 3 test cases and a stack overflow on 2 because the tree is too deep.

    this question is fuking RIDICULOUS