Merge Sort: Counting Inversions

  • + 0 comments

    Time complexity: O(n log(n)), counting inversions while merge sorting.

    // Merges the two adjacent sorted runs V[left..mid] and V[mid+1..right]
    // (all bounds inclusive) into a single ascending run, in place, and adds to
    // `swaps` the number of pairs (a, b) with a in the left run, b in the right
    // run and a > b. Equal elements are taken from the left run first, so ties
    // are never counted.
    //
    // Preconditions: 0 <= left <= mid <= right < V.size(), and both runs are
    // already sorted ascending.
    void merge(vector<int>& V, int left, int mid, int right, long& swaps) {
        using Index = vector<int>::size_type;

        // Only the left run needs a scratch copy: the write position never
        // overtakes the read position in the right run while left elements are
        // still pending, so the right run can be consumed directly from V.
        const vector<int> leftRun(V.begin() + left, V.begin() + mid + 1);
        const Index leftLen = leftRun.size();

        Index i = 0;      // next unread element of leftRun
        int j = mid + 1;  // next unread element of the right run (inside V)
        int k = left;     // next write position in V

        while (i < leftLen && j <= right) {
            if (leftRun[i] <= V[j]) {
                V[k] = leftRun[i];
                ++i;
            } else {
                V[k] = V[j];
                ++j;
                // Every element still pending in the left run is greater than
                // the one just taken from the right run.
                swaps += static_cast<long>(leftLen - i);
            }
            ++k;
        }

        // Flush what remains of the left run. Any remainder of the right run is
        // already sitting in its final position.
        while (i < leftLen) {
            V[k] = leftRun[i];
            ++i;
            ++k;
        }
    }
    
    // Recursively sorts V[left..right] (inclusive bounds) by splitting the range
    // at its midpoint, sorting each half, and merging the halves with merge().
    // `swaps` is passed through to every recursive call and merge step.
    void merge_sort(vector<int>& V, int left, int right, long& swaps) {
        // Ranges of zero or one element need no work.
        if (left < right) {
            // Offset form of the midpoint avoids overflowing left + right.
            const int half = (right - left) / 2;
            const int split = left + half;

            merge_sort(V, left, split, swaps);
            merge_sort(V, split + 1, right, swaps);
            merge(V, left, split, right, swaps);
        }
    }
    
    // Returns the number of inversions in arr, i.e. pairs of positions i < j
    // with arr[i] > arr[j], by running merge_sort over the whole vector.
    //
    // Side effect: arr is sorted in ascending order when the function returns.
    // The count is accumulated in a `long`, so it must fit in that type.
    long countInversions(vector<int>& arr) {
        long swaps = 0;

        // Fewer than two elements: nothing to sort or count. Returning here also
        // avoids evaluating size() - 1 on an empty vector, which wraps around in
        // unsigned arithmetic before being narrowed to int.
        if (arr.size() < 2) {
            return swaps;
        }

        const int last = static_cast<int>(arr.size() - 1);
        merge_sort(arr, 0, last, swaps);
        return swaps;
    }