Merge Sort: Counting Inversions

Sort by

recency

|

507 Discussions

|

  • + 0 comments

    Java Solution

        /**
         * Returns the number of inversions in arr: index pairs (i, j) with i < j
         * and arr.get(i) > arr.get(j).
         *
         * The list is sorted in place as a side effect (merge writes back via
         * List.set), so callers must pass a mutable list.
         */
        public static long countInversions(List<Integer> arr) {
            final AtomicLong inversions = new AtomicLong();
            mergeSort(arr, 0, arr.size(), inversions);
            return inversions.get();
        }
        
        /**
         * Sorts the half-open range arr[firstIndex, lastIndex) in place with
         * top-down merge sort, adding the range's inversion count to atomicLong.
         * Ranges with fewer than two elements are left untouched.
         */
        private static void mergeSort(List<Integer> arr, int firstIndex, int lastIndex, AtomicLong atomicLong){
            if (firstIndex < lastIndex) {
                final int split = firstIndex + (lastIndex - firstIndex) / 2;
                // split coincides with firstIndex only for a one-element range.
                if (split != firstIndex) {
                    mergeSort(arr, firstIndex, split, atomicLong);
                    mergeSort(arr, split, lastIndex, atomicLong);
                    merge(arr, firstIndex, split, lastIndex, atomicLong);
                }
            }
        }
        
        /**
         * Merges the adjacent sorted runs arr[firstIndex, mediumIndex) and
         * arr[mediumIndex, lastIndex) back into arr, adding to atomicLong the
         * number of inversions between the two runs.
         *
         * Both runs are copied before arr is overwritten. Whenever an element of
         * the right run is placed ahead of left elements not yet placed, every one
         * of those pending left elements is larger and forms an inversion with it.
         * Ties are taken from the left run, so equal values are never counted.
         */
        private static void merge(List<Integer> arr, int firstIndex, int mediumIndex, int lastIndex, AtomicLong atomicLong){
            final List<Integer> left = arr.subList(firstIndex, mediumIndex).stream().collect(Collectors.toList());
            final List<Integer> right = arr.subList(mediumIndex, lastIndex).stream().collect(Collectors.toList());

            // Nothing to interleave when either run is empty.
            if (left.isEmpty() || right.isEmpty()) {
                return;
            }

            final int leftCount = left.size();
            final int rightCount = right.size();
            int l = 0;
            int r = 0;
            for (int dest = firstIndex; l < leftCount || r < rightCount; dest++) {
                final boolean takeLeft = r == rightCount || (l < leftCount && left.get(l) <= right.get(r));
                if (takeLeft) {
                    arr.set(dest, left.get(l++));
                } else {
                    if (l < leftCount) {
                        atomicLong.addAndGet(leftCount - l);
                    }
                    arr.set(dest, right.get(r++));
                }
            }
        }
    
  • + 0 comments

    O(n log n) C++ solution:

    // Merges the sorted runs V[left..mid] and V[mid+1..right] (inclusive bounds)
    // back into V, adding to swaps the number of inversions between the runs.
    // Whenever a right-run element is placed before left-run elements that are
    // still pending, each of those pending elements is greater than it, so all
    // of them are counted at once. Ties are taken from the left run, so equal
    // values are never counted.
    void merge(vector<int>& V, int left, int mid, int right, long& swaps) {
        const vector<int> lo(V.begin() + left, V.begin() + mid + 1);
        const vector<int> hi(V.begin() + mid + 1, V.begin() + right + 1);

        size_t a = 0;
        size_t b = 0;
        for (int out = left; a < lo.size() || b < hi.size(); ++out) {
            const bool fromLo = b == hi.size() || (a < lo.size() && lo[a] <= hi[b]);
            if (fromLo) {
                V[out] = lo[a++];
            } else {
                if (a < lo.size()) {
                    swaps += lo.size() - a;
                }
                V[out] = hi[b++];
            }
        }
    }
    
    // Sorts V[left..right] (inclusive bounds) in place with top-down merge sort,
    // adding the number of inversions inside that range to swaps. Ranges of
    // zero or one element are left untouched.
    void merge_sort(vector<int>& V, int left, int right, long& swaps) {
        if (left < right) {
            // left + (right - left) / 2 avoids the int overflow of (left + right) / 2.
            const int pivot = left + (right - left) / 2;
            merge_sort(V, left, pivot, swaps);
            merge_sort(V, pivot + 1, right, swaps);
            merge(V, left, pivot, right, swaps);
        }
    }
    
    // Returns the number of inversions in arr: pairs i < j with arr[i] > arr[j].
    // arr is passed by reference and is sorted in place as a side effect.
    //
    // Vectors with fewer than two elements are answered up front. Without that
    // guard, arr.size() - 1 wraps to SIZE_MAX for an empty vector. It then only
    // becomes -1 through an implementation-defined size_t -> int conversion.
    //
    // NOTE(review): `long` is only 32 bits on LLP64 targets (e.g. MSVC). The
    // count can reach n(n-1)/2, which overflows 32 bits once n exceeds ~65,536.
    long countInversions(vector<int>& arr) {
        long swaps = 0;
        if (arr.size() < 2) {
            return swaps;
        }
        const int last = static_cast<int>(arr.size()) - 1;
        merge_sort(arr, 0, last, swaps);
        return swaps;
    }
    
  • + 0 comments

    C++ solution. Merge sort is probably the quickest way to solve this: it runs in O(n log n) time, whereas the straightforward approach of comparing every pair runs in O(n^2).

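    The only addition to the standard merge sort is counting the swaps (inversions). Whenever an element from the right half is placed before the elements still remaining in the left half, each of those remaining elements forms an inversion with it. The count therefore grows by the number of elements left in the left half.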

    // Returns a sorted copy of arr (top-down merge sort) and adds the number of
    // inversions in arr -- pairs i < j with arr[i] > arr[j] -- to swapCount.
    // The first half receives arr.size() / 2 elements and the second the rest.
    // Ties are taken from the first half, so equal values are never counted.
    vector<int> merge_sort(vector<int> arr, long& swapCount) {
        if (arr.size() < 2) {
            return arr;
        }

        const size_t half = arr.size() / 2;
        const vector<int> lower = merge_sort(vector<int>(arr.begin(), arr.begin() + half), swapCount);
        const vector<int> upper = merge_sort(vector<int>(arr.begin() + half, arr.end()), swapCount);

        vector<int> merged;
        merged.reserve(arr.size());
        size_t a = 0;
        size_t b = 0;
        while (a < lower.size() && b < upper.size()) {
            if (upper[b] < lower[a]) {
                // upper[b] is smaller than every element still pending in lower.
                swapCount += lower.size() - a;
                merged.push_back(upper[b++]);
            } else {
                merged.push_back(lower[a++]);
            }
        }
        // At most one of these leftover ranges is non-empty; both are sorted.
        merged.insert(merged.end(), lower.begin() + a, lower.end());
        merged.insert(merged.end(), upper.begin() + b, upper.end());
        return merged;
    }
    
    // Returns the number of inversions in arr: pairs i < j with arr[i] > arr[j].
    // arr is taken by value, so the caller's vector is left as it was. Only the
    // counter filled in by merge_sort is needed; its sorted result is dropped.
    long countInversions(vector<int> arr) {
        long inversions = 0;
        static_cast<void>(merge_sort(arr, inversions));
        return inversions;
    }
    
  • + 0 comments

    For anyone looking for a PHP solution: you can use merge sort to solve this.

    /**
     * Returns the number of inversions in $arr: pairs i < j with $arr[i] > $arr[j].
     * $arr is taken by value, so the caller's array is not reordered.
     * NOTE(review): assumes $arr is a zero-indexed list (keys 0..n-1), since the
     * merge step addresses elements by position -- confirm for callers.
     */
    function countInversions($arr) {
        $n = count($arr);
        // One scratch buffer, sized like the input, is reused by every merge.
        $buffer = array_fill(0, $n, 0);
        return mergeSort($arr, $buffer, 0, $n - 1);
    }
    
    /**
     * Sorts $arr[$left..$right] (inclusive bounds) in place with top-down merge
     * sort and returns the number of inversions inside that range.
     * $temp is scratch space that is handed through to merge().
     */
    function mergeSort(&$arr, &$temp, $left, $right) {
        if ($left >= $right) {
            return 0; // zero or one element: already sorted, no inversions
        }

        $mid = (int)(($left + $right) / 2);
        // Separate statements keep the order explicit: both halves must be
        // sorted before they are merged.
        $count = mergeSort($arr, $temp, $left, $mid);
        $count += mergeSort($arr, $temp, $mid + 1, $right);
        $count += merge($arr, $temp, $left, $mid, $right);
        return $count;
    }
    
    /**
     * Merges the sorted runs $arr[$left..$mid] and $arr[$mid+1..$right]
     * (inclusive bounds) through $temp and copies the result back into $arr.
     * Returns the number of inversions between the two runs. When a right-run
     * element is placed before left-run elements still pending, each of those
     * pending elements is greater and is counted. Ties are taken from the left
     * run, so equal values are never counted.
     */
    function merge(&$arr, &$temp, $left, $mid, $right) {
        $a = $left;     // cursor into the left run
        $b = $mid + 1;  // cursor into the right run
        $count = 0;

        for ($k = $left; $k <= $right; $k++) {
            $takeLeft = $b > $right || ($a <= $mid && $arr[$a] <= $arr[$b]);
            if ($takeLeft) {
                $temp[$k] = $arr[$a];
                $a++;
            } else {
                if ($a <= $mid) {
                    $count += $mid - $a + 1;
                }
                $temp[$k] = $arr[$b];
                $b++;
            }
        }

        // Copy the merged range back into the original array.
        for ($k = $left; $k <= $right; $k++) {
            $arr[$k] = $temp[$k];
        }

        return $count;
    }
    
  • + 0 comments

    My Python 3 solution:

    # Module-level inversion counter: merge_sort adds to it, and countInversions
    # resets it to 0 before each run.
    N_swaps = 0
    def merge_sort(arr):
        """Return the items of ``arr`` in ascending order, adding its inversion count to ``N_swaps``.

        An inversion is a pair of positions i < j with arr[i] > arr[j]. ``arr``
        itself is never modified. Lists of length 0 or 1 are returned as-is;
        longer ones are sorted into a new list.

        Args:
            arr: List of mutually comparable items.

        Returns:
            A sorted list. The merge is stable: equal items keep their relative
            order and are not counted as inversions.
        """
        global N_swaps

        n = len(arr)
        # `<= 1` rather than `== 1`: an empty list would otherwise be split into
        # arr[:0] and recurse until RecursionError.
        if n <= 1:
            return arr
        med = n // 2
        sub1 = merge_sort(arr[:med])
        sub2 = merge_sort(arr[med:])

        sorted_arr = []
        i, j = 0, 0
        while i < len(sub1) and j < len(sub2):
            if sub1[i] <= sub2[j]:
                sorted_arr.append(sub1[i])
                i += 1
            else:
                # Plain `else` (not `elif sub1[i] > sub2[j]`). With an explicit
                # second test, incomparable values such as NaN match neither
                # branch and the loop never advances.
                # sub2[j] precedes every one of the len(sub1) - i items still
                # pending in sub1, and each of them forms an inversion with it.
                sorted_arr.append(sub2[j])
                j += 1
                N_swaps += len(sub1) - i
        # At most one half still has items left; both halves are already sorted.
        sorted_arr.extend(sub1[i:])
        sorted_arr.extend(sub2[j:])
        return sorted_arr
    
    def countInversions(arr):
        """Return the number of inversions in ``arr``: pairs i < j with arr[i] > arr[j].

        Resets the module-level ``N_swaps`` counter, lets ``merge_sort``
        accumulate into it, and returns the total. ``merge_sort`` only slices
        ``arr``, so the caller's list is not reordered. Its sorted result is
        discarded here.
        """
        global N_swaps
        N_swaps = 0          # start fresh: the counter persists between calls
        _ = merge_sort(arr)  # only the side effect on N_swaps is needed
        total = N_swaps
        return total