• + 0 comments

    There you go, PyPy 3 solution :-)

    # Enter your code here. Read input from STDIN. Print output to STDOUT
    import itertools
    
    class Graph:
        """Vertex-weighted graph with 1-based vertex ids and adjacency lists.

        Attributes set by ``__init__``:
            n: number of vertices, i.e. ``len(weights)``.
            weights: the given weights with a leading ``0`` so that
                ``weights[v]`` is the weight of vertex ``v`` for ``v >= 1``.
            edges: ``n + 1`` independent, initially empty lists; the caller
                appends neighbours to ``edges[v]``.  Index 0 is padding.
        """

        def __init__(self, weights):
            """Build an edgeless graph whose vertex ``i + 1`` weighs ``weights[i]``."""
            self.n = len(weights)
            self.weights = [0, *weights]
            # Size the adjacency list from this instance's own vertex count
            # instead of relying on a module-level ``n`` existing (and matching).
            self.edges = [[] for _ in range(self.n + 1)]
            
    def diffs(seq):
        """Return the list of consecutive differences ``seq[i + 1] - seq[i]``.

        The result has ``len(seq) - 1`` entries, or none when ``seq`` holds
        fewer than two items.
        """
        return [seq[pos] - seq[pos - 1] for pos in range(1, len(seq))]
    
    def distributions(elems, picks):
        """Lazily yield every way of spreading ``picks`` picks over ``elems``.

        Each yielded dict maps every entry of ``elems`` to the number of picks
        it received.  A ``None`` placeholder is added to the pool, so picks
        that land on it are simply not handed to any element (a distribution
        may therefore total fewer than ``picks``).

        ``elems`` must be a list (it is concatenated with ``[None]``).
        """
        pool = elems + [None]

        def tally(assignment):
            # Within one assignment, repeats of the same pool entry are
            # adjacent, so each groupby run is that entry's pick count.
            run_lengths = {}
            for elem, run in itertools.groupby(assignment):
                run_lengths[elem] = sum(1 for _ in run)
            return {elem: run_lengths.get(elem, 0) for elem in elems}

        return (
            tally(assignment)
            for assignment in itertools.combinations_with_replacement(pool, picks)
        )
          
    def _merge_cut_tables(left, right, total_cuts):
        """Combine two per-cut-count tables into one.

        ``merged[c]`` is the largest ``left[i] + right[j]`` over all index
        pairs with ``i + j == c``.  The result always has ``total_cuts + 1``
        entries; a cut count no pair reaches keeps the ``-(1 << 64)`` sentinel.
        """
        merged = [-(1 << 64)] * (total_cuts + 1)
        for i, left_sum in enumerate(left):
            for j, right_sum in enumerate(right):
                cuts = i + j
                if cuts > total_cuts:
                    # j only grows from here, so nothing else fits for this i.
                    break
                candidate = left_sum + right_sum
                if candidate > merged[cuts]:
                    merged[cuts] = candidate
        return merged

    def max_sum(g, k):
        """Return the best total vertex weight reachable with up to ``k`` cuts.

        The graph is treated as a tree rooted at vertex 1.  For every vertex
        ``v`` a table ``sums[v]`` is built bottom-up, where ``sums[v][c]`` is
        the best weight kept inside ``v``'s subtree when ``c`` cuts are spent
        there; cutting a non-root vertex removes its whole subtree (value 0).

        Args:
            g: object exposing ``n`` (vertex count), ``weights`` (1-based
                list of vertex weights) and ``edges`` (1-based undirected
                adjacency lists).
            k: maximum number of cuts.

        Returns:
            The last entry of the root's table that is within the ``k`` budget.

        Side effects:
            ``g.edges`` is modified in place: already-visited neighbours (the
            parent, in a tree) are removed, leaving ``g.edges[v]`` as the
            children of ``v``.
        """
        # Pass 1: iterative DFS from the root, recording the visit order and
        # turning each adjacency list into a child list.
        order = []
        seen = [False] * (g.n + 1)
        stack = [1]
        while stack:
            v = stack.pop()
            order.append(v)
            seen[v] = True
            children = [nv for nv in g.edges[v] if not seen[nv]]
            g.edges[v][:] = children  # same list object, parent link dropped
            stack.extend(children)

        # Pass 2: reversed visit order guarantees children are done before
        # their parent.
        sums = [[] for _ in range(g.n + 1)]
        for v in reversed(order):
            children = g.edges[v]
            if not children:
                # Leaf: nothing below to cut.
                best = [g.weights[v]]
            else:
                total_cuts = sum(len(sums[c]) - 1 for c in children)
                combined = sums[children[0]]
                # Fold in the remaining children one at a time, keeping the
                # best split of every cut count among the children so far.
                for child in children[1:]:
                    combined = _merge_cut_tables(combined, sums[child], total_cuts)
                best = [s + g.weights[v] for s in combined]

            # A non-root vertex may itself be cut off, costing one cut and
            # leaving 0.  Only a single negative entry needs an explicit
            # append; longer tables are covered by the clamp below.
            if v > 1 and len(best) == 1 and best[0] < 0:
                best.append(0)

            # NOTE(review): the clamp below also runs for the root, so a root
            # table with two or more entries can report 0 (everything removed)
            # while a lone negative root entry cannot -- confirm the intent.
            # Spending more cuts for a smaller sum never makes sense: clamp
            # entries with >= 1 cut to 0 and stop at the first maximum.
            max_idx = 1
            for i in range(1, len(best)):
                best[i] = max(best[i], 0)
                if best[i] > best[max_idx]:
                    max_idx = i
            sums[v] = best[:min(max_idx, k) + 1]

        return sums[1][min(k, len(sums[1]) - 1)]
    
    # Input: "n k", then the n vertex weights, then n - 1 undirected edges
    # "a b" (1-based vertex ids).  Output: the result of max_sum.
    n, k = (int(tok) for tok in input().split())
    w = [int(tok) for tok in input().split()]
    g = Graph(w)
    for _ in range(n - 1):
        a, b = (int(tok) for tok in input().split())
        # Undirected edge: record it on both endpoints.
        g.edges[a].append(b)
        g.edges[b].append(a)
    print(max_sum(g, k))