3D scatter visualization

I tried to make a 3D scatter plot of the training data points using the following code:

import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

X = pd.read_csv("Logistic_X_Train.csv").values
Y = pd.read_csv("Logistic_Y_Train.csv").values

fig = plt.figure()
ax = plt.axes(projection='3d')
ax.scatter3D(X[:, 0], X[:, 1], X[:, 2], cmap='rainbow')
plt.show()

I want to color each point according to its class (its Y value). How can I do that?

Right now the plot looks like this:

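Hey @SanchitSayala, you can pass the class labels to the scatter function through the c argument. cmap only has an effect when c is given, because it is what maps those label values to colors:

ax.scatter3D(X[:, 0], X[:, 1], X[:, 2], c=Y.ravel(), cmap='rainbow')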
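If you want to try this without the CSV files, here is a minimal self-contained sketch. It assumes synthetic data in place of Logistic_X_Train.csv / Logistic_Y_Train.csv (two 3D Gaussian blobs labelled 0 and 1, shaped like the .values arrays from the question), and it adds a colorbar so you can read off which color belongs to which class. The Agg backend line is only there so the sketch runs headless; drop it to get an interactive window.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; delete this line to get an interactive window
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the "3d" projection on older matplotlib)

# Hypothetical stand-in for the CSV data: two 3D Gaussian blobs labelled 0 and 1.
rng = np.random.default_rng(0)
n = 50
X = np.vstack([rng.normal(0.0, 1.0, size=(n, 3)),
               rng.normal(3.0, 1.0, size=(n, 3))])
# Shape (2n, 1), mimicking pd.read_csv(...).values on a one-column file
Y = np.concatenate([np.zeros(n), np.ones(n)]).reshape(-1, 1)

fig = plt.figure()
ax = fig.add_subplot(projection="3d")

# c gives one value per point; cmap turns each class label into a color.
# ravel() flattens the (N, 1) label column into a plain 1D array of length N.
sc = ax.scatter3D(X[:, 0], X[:, 1], X[:, 2], c=Y.ravel(), cmap="rainbow")
fig.colorbar(sc, ax=ax, label="class")  # shows which color corresponds to which label
ax.set_xlabel("feature 0")
ax.set_ylabel("feature 1")
ax.set_zlabel("feature 2")
plt.show()
```

With only two classes, the colorbar effectively shows the two ends of the rainbow colormap; with more classes each label gets its own hue along the map.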

Hope this helps.