Prim's (MST) : Special Subtree

  • + 0 comments

    Java8

    /**
     * Computes the total weight of a minimum spanning tree with Prim's algorithm.
     *
     * <p>Each entry of {@code edges} is read as {@code [u, v, weight]} and treated as an
     * undirected edge. The tree is grown from {@code start} by repeatedly taking the
     * cheapest edge that leads from an already reached node to a node not reached yet.
     * The {@code edges} list is only read; it is neither sorted nor modified, so
     * read-only lists are accepted.
     *
     * <p>Growth stops once {@code n} nodes have been reached or no edge leaves the
     * reached set. For a graph that is not connected the result is therefore the weight
     * of the minimum spanning tree of the component containing {@code start}.
     *
     * @param n     number of nodes the spanning tree should cover
     * @param edges undirected edges, each given as {@code [u, v, weight]}
     * @param start node the tree is grown from
     * @return sum of the weights of the edges selected for the tree
     */
    public static int prims(int n, List<List<Integer>> edges, int start) {
        // java.util types the previous version did not use are fully qualified so this
        // method does not depend on import lines outside the visible snippet.
        java.util.Map<Integer, List<int[]>> adjacency = new java.util.HashMap<>();
        for (List<Integer> edge : edges) {
            int u = edge.get(0);
            int v = edge.get(1);
            int weight = edge.get(2);
            // Store the edge under both endpoints as {neighbour, weight}.
            adjacency.computeIfAbsent(u, k -> new java.util.ArrayList<>())
                    .add(new int[] {v, weight});
            adjacency.computeIfAbsent(v, k -> new java.util.ArrayList<>())
                    .add(new int[] {u, weight});
        }

        // comparingInt avoids the overflow a subtraction-based comparator can produce.
        Comparator<int[]> byWeight = Comparator.comparingInt(candidate -> candidate[1]);
        java.util.PriorityQueue<int[]> frontier = new java.util.PriorityQueue<>(byWeight);

        HashSet<Integer> visited = new HashSet<>();
        visited.add(start);
        frontier.addAll(adjacency.getOrDefault(start, Collections.emptyList()));

        int totalWeight = 0;

        // The emptiness check guarantees termination when some nodes are unreachable.
        while (visited.size() < n && !frontier.isEmpty()) {
            int[] cheapest = frontier.poll();
            int node = cheapest[0];

            // add() returns false when the node is already in the tree, i.e. the queued
            // edge became internal after it was inserted and must be skipped.
            if (!visited.add(node)) {
                continue;
            }

            totalWeight += cheapest[1];

            for (int[] next : adjacency.get(node)) {
                if (!visited.contains(next[0])) {
                    frontier.add(next);
                }
            }
        }

        return totalWeight;
    }