Replace with sum of greater nodes

import java.util.*;
public class BST {
private class Node{
int data;
Node left;
Node right;
}

	private Node root;

	public BST(int[] arr){
		this.root = construct(arr,0,arr.length-1);
	}

	private Node construct(int[] arr,int lo,int hi){
		if(lo>hi){
			return null;
		}

		int mid = (lo+hi)/2;

		Node nn = new Node();
		nn.data = arr[mid];
		nn.left = construct(arr,lo,mid-1);
		nn.right = construct(arr,mid+1,hi);

		return nn;
	}

		public void preOrder() {
		this.preOrder(this.root);
		System.out.println(" END");
	}
	
	private void preOrder(Node node) {
		if(node == null) {
			return;
		}
		System.out.print(node.data+", ");
		preOrder(node.left);
		preOrder(node.right);
	}
	
	public void modifyBST() {
		int sum = 0;
		modifyBST(this.root,sum);
	}
	
	private void modifyBST(Node node,int sum) {
		if(node == null) {
			return;
		}
		modifyBST(node.right,sum);
		
		sum = sum + root.data;
		root.data = sum;
		
		modifyBST(node.left,sum);
	}
	
	public static void main(String[] args) {
		Scanner sc = new Scanner(System.in);
		int N = sc.nextInt();
		int[] in = new int[N];
		for(int i=0;i<N;i++) {
			in[i] = sc.nextInt();
		}
		BST tree = new BST(in);
		
		tree.modifyBST();
		tree.preOrder();
	}
	


}

//sir can you please tell what’s wrong with the code I’m getting wrong output

@Siddharth_sharma1808,

https://ide.codingblocks.com/s/223359 Corrected code.

Keep sum as global and do not pass it as a argument.
Update node not root.