• + 1 comment

    python3

    def jeanisRoute(n, cities, roads):
        """Return the minimum distance needed to visit every city in ``cities``.

        Strategy:
        1) Prune every branch of the tree that contains no city to visit.
        2) A walk that visits every remaining node and returns to its start
           covers each remaining edge twice: 2 * sum(remaining edge weights).
        3) Since the walk need not return to its start, the most it can be
           shortened by is the longest path between two remaining nodes,
           i.e. the pruned tree's diameter.

        Args:
            n: number of nodes, labelled 1..n.
            cities: labels of the nodes that must be visited.
            roads: (u, v, w) triples, one per edge of weight w.

        Returns:
            2 * (total weight of the pruned tree) - (its diameter), or 0 when
            ``cities`` is empty.
        """
        # Build the adjacency list and sum (2x) the total edge weight.
        adj = {i: [] for i in range(1, n + 1)}
        total = 0
        for u, v, w in roads:
            adj[u].append((v, w))
            adj[v].append((u, w))
            total += 2 * w

        if not cities:
            return 0

        # Prune non-city leaves with a worklist instead of repeated full
        # passes: each node is pruned at most once, so this is linear rather
        # than quadratic on long chains. Live degrees are tracked in a dict,
        # which avoids the O(degree) cost of list.remove on high-degree nodes.
        targets = set(cities)
        degree = {u: len(nbrs) for u, nbrs in adj.items()}
        removed = set()
        leaves = [u for u in adj if degree[u] == 1 and u not in targets]
        while leaves:
            u = leaves.pop()
            removed.add(u)
            for v, w in adj[u]:
                if v in removed:
                    continue  # this edge was already subtracted when v was pruned
                total -= 2 * w
                degree[v] -= 1
                if degree[v] == 1 and v not in targets:
                    leaves.append(v)

        # Keep only edges between surviving nodes so the traversal below
        # cannot wander back into pruned branches.
        pruned = {
            u: [(v, w) for v, w in nbrs if v not in removed]
            for u, nbrs in adj.items()
            if u not in removed
        }

        # Diameter via two traversals. In a tree with non-negative weights,
        # the node farthest from any node is an endpoint of a diameter. The
        # farthest distance from that endpoint is the diameter length.
        far_end, _ = dfs(cities[0], pruned)
        _, diameter = dfs(far_end, pruned)

        return total - diameter
            
    
    def dfs(start, adj):
        """Return ``(node, distance)`` for the node farthest from ``start``.

        ``adj`` maps each node to a list of ``(neighbour, weight)`` pairs.
        The walk uses an explicit stack, so deep trees do not hit Python's
        recursion limit. Each node's distance is fixed the first time it is
        discovered. On a tree this is the length of the unique path from
        ``start``. On a graph with cycles it is the length along the
        discovering path, not necessarily the shortest one. Ties keep the
        first node discovered at the maximum distance. If nothing is
        reachable, ``(start, 0)`` is returned.
        """
        visited = {start}
        stack = [(start, 0)]
        best_node = start
        best_dist = 0

        while stack:
            node, dist_so_far = stack.pop()
            for neighbour, weight in adj[node]:
                if neighbour in visited:
                    continue
                visited.add(neighbour)
                reach = dist_so_far + weight
                stack.append((neighbour, reach))
                # Strict comparison keeps the earliest-discovered node on ties.
                if reach > best_dist:
                    best_node, best_dist = neighbour, reach

        return best_node, best_dist