We use cookies to ensure you have the best browsing experience on our website. Please read our cookie policy for more information about how we use cookies.
# Enter your code here. Read input from STDIN. Print output to STDOUT
import itertools
import sys


class Graph:
    """Undirected tree with 1-indexed nodes and per-node weights.

    ``weights[i]`` is the weight of node ``i + 1``; slot 0 of ``self.weights``
    is padding so node ids can index it directly. ``self.edges[v]`` is the
    adjacency list of node ``v``; callers append both directions of each edge.
    """

    def __init__(self, weights):
        self.n = len(weights)
        self.weights = [0] + weights
        # Sized from self.n (not a module-level ``n``) so the class does not
        # depend on the script's globals.
        self.edges = [[] for _ in range(self.n + 1)]


def diffs(seq):
    """Return the consecutive differences ``seq[i + 1] - seq[i]`` of a sequence."""
    return [b - a for a, b in zip(seq, seq[1:])]


def distributions(elems, picks):
    """Lazily yield every way of handing out ``picks`` picks among ``elems``.

    Each yielded dict maps every element of ``elems`` to how many picks it
    received. Picks given to the extra ``None`` slot are discarded, so a
    dict's counts sum to at most ``picks``.
    """
    assignments = itertools.combinations_with_replacement(elems + [None], picks)

    def dist_for_assignment(assgn):
        # combinations_with_replacement keeps input order, so (assuming elems
        # are distinct) equal elements are adjacent and groupby counts them.
        picks_by_elem = {
            elem: len(list(group)) for elem, group in itertools.groupby(assgn)
        }
        return {elem: picks_by_elem.get(elem, 0) for elem in elems}

    return (dist_for_assignment(a) for a in assignments)


def _orient_from_root(g, root=1):
    """Turn ``g.edges`` into child lists for a tree rooted at ``root``.

    Each node's edge back to its parent is removed in place. Returns the nodes
    in DFS preorder, so every parent appears before its children.
    """
    order = []
    seen = [False] * (g.n + 1)
    stack = [root]
    while stack:
        v = stack.pop()
        order.append(v)
        seen[v] = True
        children = [u for u in g.edges[v] if not seen[u]]
        g.edges[v][:] = children
        stack.extend(children)
    return order


def _merge(left, right, max_len):
    """Max-plus convolve two best-sum-by-cut-count lists.

    ``result[c] = max(left[i] + right[c - i])`` over all valid ``i``, keeping
    only the first ``max_len`` entries (cut counts above k are never used).
    """
    len_l, len_r = len(left), len(right)
    size = min(len_l + len_r - 1, max_len)
    merged = []
    for cuts in range(size):
        # Only split points where both indices exist; never empty for cuts < size.
        lo = max(0, cuts - len_r + 1)
        hi = min(cuts, len_l - 1)
        merged.append(max(left[i] + right[cuts - i] for i in range(lo, hi + 1)))
    return merged


def _allow_cutting(best, k):
    """Adjust a non-root node's list for removing its entire subtree.

    Removing the subtree costs one cut and leaves 0, so any entry with at
    least one cut is raised to at least 0. Entries after the first maximum
    are dropped (spending more cuts for no gain never helps), and at most
    ``k + 1`` entries (0..k cuts) are kept.
    """
    if len(best) == 1 and best[0] < 0:
        best = best + [0]
    clamped = best[:1] + [max(s, 0) for s in best[1:]]
    tail = clamped[1:]
    peak = 1 + tail.index(max(tail)) if tail else 1
    return clamped[:min(peak, k) + 1]


def max_sum(g, k):
    """Return the best total weight of the tree after removing at most ``k`` subtrees.

    The tree is rooted at node 1, which can never be removed. ``sums[v][c]`` is
    the best weight kept in v's subtree using at most ``c`` cuts. These lists
    are non-decreasing, so truncating them to ``k + 1`` entries loses nothing.
    Note: this mutates ``g.edges`` into parent-to-children lists.
    """
    order = _orient_from_root(g)
    sums = [[] for _ in range(g.n + 1)]
    for v in reversed(order):  # children are finished before their parent
        combined = [0]
        for child in g.edges[v]:
            combined = _merge(combined, sums[child], k + 1)
        best = [s + g.weights[v] for s in combined]
        if v > 1:  # only non-root nodes may be cut off
            best = _allow_cutting(best, k)
        sums[v] = best
    root = sums[1]
    return root[min(k, len(root) - 1)]


def main():
    """Read ``n k``, the n weights, and n-1 edges from stdin; print the answer."""
    tokens = iter(sys.stdin.read().split())
    n, k = int(next(tokens)), int(next(tokens))
    g = Graph([int(next(tokens)) for _ in range(n)])
    for _ in range(n - 1):
        a, b = int(next(tokens)), int(next(tokens))
        g.edges[a].append(b)
        g.edges[b].append(a)
    print(max_sum(g, k))


if __name__ == "__main__":
    main()
Cookie support is required to access HackerRank.
It seems that cookies are disabled in this browser; please enable them to open this website.
An unexpected error occurred. Please try reloading the page. If problem persists, please contact support@hackerrank.com
Tree Pruning
You are viewing a single comment's thread. Return to all comments →
There you go: a PyPy 3 solution :-)