Inversion Count TLE

I am getting TLE for all the 3 test cases. I’m creating the temporary array only once which is used in the merge step.

Here is my code:

import java.util.*;
public class Main {

	public int countInversions(int[] nums, int low, int high, int[] temp) {
		if (low >= high)
			return 0;
		int mid = low + (high - low) / 2;
		int left = countInversions(nums, 0, mid, temp);
		int right = countInversions(nums, mid + 1, high, temp);
		int merged = merge(nums, low, mid, high, temp);
		return left + right + merged;
	}

	public int merge(int[] nums, int low, int mid, int high, int[] temp) {
		int k = 0, i = low, j = mid + 1, count = 0;
		while (i <= mid && j <= high) {
			if (nums[i] <= nums[j]) {
				temp[k] = nums[i];
				i += 1;
			} else {
				count += (mid - i + 1);
				temp[k] = nums[j];
				j += 1;
			}
			k += 1;
		}

		while (i <= mid) {
			temp[k] = nums[i];
			i += 1;
			k += 1;
		}
		while (j <= high) {
			temp[k] = nums[j];
			j += 1;
			k += 1;
		}

		int size = high - low + 1;
		for (k = 0; k < size; k++) {
			nums[low] = temp[k];
			low += 1;
		}

		return count;
	}

    public static void main(String args[]) {
		Main main = new Main();
		int[] temp = new int[1000005];
		Scanner sc = new Scanner(System.in);
		int tests = sc.nextInt();
		for (int t = 0; t < tests; t++) {
			int n = sc.nextInt();
			int[] nums = new int[n];
			for (int i = 0; i < n; i++)
				nums[i] = sc.nextInt();
			System.out.println(main.countInversions(nums, 0, nums.length - 1, temp));
		}
    }
}