We use cookies to ensure you have the best browsing experience on our website. Please read our cookie policy for more information about how we use cookies.
# Enter your code here. Read input from STDIN. Print output to STDOUT
import itertools


class Graph:
    """Undirected graph with 1-based vertex ids and a weight per vertex.

    Attributes:
        n: Number of vertices (``len(weights)``).
        weights: Vertex weights, padded with a dummy ``0`` at index 0 so
            that ``weights[v]`` is the weight of vertex ``v``.
        edges: Adjacency lists; ``edges[v]`` lists the neighbours of ``v``.
            Index 0 is unused.
    """

    def __init__(self, weights):
        self.n = len(weights)
        self.weights = [0] + weights
        # Size from the instance, not from a module-level ``n``, so the class
        # is usable without the stdin driver having run first.
        self.edges = [[] for _ in range(self.n + 1)]


def diffs(seq):
    """Return the list of differences between consecutive items of *seq*."""
    return [seq[i + 1] - seq[i] for i in range(len(seq) - 1)]


def distributions(elems, picks):
    """Lazily yield every way to hand out at most *picks* picks over *elems*.

    Args:
        elems: List of distinct hashable elements.
        picks: Total number of picks available.

    Returns:
        A generator of dicts mapping each element of *elems* to the number of
        picks it received. Picks assigned to the extra ``None`` slot are
        "unused", which is how totals smaller than *picks* arise.
    """
    assignments = itertools.combinations_with_replacement(
        elems + [None], picks
    )

    def dist_for_assignment(assgn):
        # Equal elements are adjacent in each combination, so groupby counts
        # them without sorting.
        picks_by_elem = {
            elem: len(list(group))
            for elem, group in itertools.groupby(assgn)
        }
        return {elem: picks_by_elem.get(elem, 0) for elem in elems}

    return (dist_for_assignment(a) for a in assignments)


def _merge_cut_tables(left, right):
    """Max-plus merge of two tables indexed by number of cuts.

    ``result[c]`` is the largest ``left[i] + right[j]`` with ``i + j == c``.
    """
    merged = [None] * (len(left) + len(right) - 1)
    for i, left_sum in enumerate(left):
        for j, right_sum in enumerate(right):
            candidate = left_sum + right_sum
            slot = i + j
            if merged[slot] is None or candidate > merged[slot]:
                merged[slot] = candidate
    return merged


def max_sum(g, k):
    """Return the best remaining weight of the tree after up to *k* cuts.

    The tree is rooted at vertex 1. A cut removes a non-root vertex together
    with its whole subtree.

    Args:
        g: A ``Graph`` whose ``edges`` describe a tree. NOTE: the adjacency
            lists are modified in place - each vertex's link to its parent
            is removed so that ``g.edges[v]`` ends up listing children only.
        k: Maximum number of cuts allowed.

    Returns:
        The maximum total weight of the vertices that remain.
    """
    # Iterative DFS from the root: record the visit order and orient edges.
    visits = []
    seen = [False] * (g.n + 1)
    stack = [1]
    while stack:
        v = stack.pop()
        visits.append(v)
        seen[v] = True
        children = [nv for nv in g.edges[v] if not seen[nv]]
        g.edges[v][:] = children  # drop the edge back to the parent
        stack.extend(children)

    # sums[v][c] is the best weight of v's subtree using c cuts inside it.
    sums = [[] for _ in range(g.n + 1)]
    # Reversed visit order guarantees children are processed before parents.
    for v in reversed(visits):
        children = g.edges[v]
        if not children:
            # no edges to remove
            sums[v] = [g.weights[v]]
        else:
            combined_sums = sums[children[0]]
            for child in children[1:]:
                combined_sums = _merge_cut_tables(combined_sums, sums[child])
            sums[v] = [s + g.weights[v] for s in combined_sums]

        # NOTE(review): the source this was recovered from had lost its
        # indentation; the nesting of this section follows the "if not root"
        # comment. As written the root itself can never be removed - confirm
        # against the problem statement.
        if v > 1:
            # we could cut off this node if not root
            if len(sums[v]) == 1:
                if sums[v][0] < 0:
                    sums[v].append(0)
            elif sums[v][1] < 0:
                sums[v][1] = 0
            # more cuts for less never makes sense
            max_sum_idx = 1
            for i in range(1, len(sums[v])):
                # With at least one cut available, removing v yields 0.
                sums[v][i] = max(sums[v][i], 0)
                if sums[v][i] > sums[v][max_sum_idx]:
                    max_sum_idx = i
            sums[v] = sums[v][:min(max_sum_idx, k) + 1]

    return sums[1][min(k, len(sums[1]) - 1)]


def main():
    """Read the tree from STDIN and print the best achievable weight."""
    n, k = map(int, input().split())
    w = list(map(int, input().split()))
    g = Graph(w)
    for _ in range(n - 1):
        a, b = map(int, input().split())
        g.edges[a].append(b)
        g.edges[b].append(a)
    print(max_sum(g, k))


if __name__ == "__main__":
    main()
Cookie support is required to access HackerRank
Seems like cookies are disabled on this browser, please enable them to open this website
Tree Pruning
You are viewing a single comment's thread. Return to all comments →
There you go, PyPy 3 solution :-)