I tried to plot the training data points as a scatter in 3D space using the following code:
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
X = pd.read_csv("Logistic_X_Train.csv").values
Y = pd.read_csv("Logistic_Y_Train.csv").values
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.scatter3D(X[:, 0], X[:, 1], X[:, 2], cmap='rainbow')
plt.show()
I want to color each point according to its class (its Y value). How can I do that?
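A minimal sketch of one common approach: `cmap` only takes effect when you also pass per-point values through `c=`, so supply the labels there. It assumes `X` has at least three numeric feature columns and that `Y` is a single label column, which `pd.read_csv(...).values` returns as an `(n, 1)` array and which therefore needs flattening. The helper name `plot_3d_by_class` and the random demo data are hypothetical stand-ins for your CSV files.

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers '3d' on matplotlib < 3.2; harmless on newer versions


def plot_3d_by_class(X, Y, cmap="rainbow"):
    """Hypothetical helper: 3D scatter of X's first three columns, colored by class label Y."""
    X = np.asarray(X)
    # A single-column DataFrame's .values is 2-D with shape (n, 1);
    # ravel() turns it into the 1-D length-n array that c= expects.
    y = np.asarray(Y).ravel()

    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    # c= gives one value per point, and cmap maps those values to colors.
    sc = ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y, cmap=cmap)
    # legend_elements() builds one legend handle per distinct value in c.
    ax.legend(*sc.legend_elements(), title="Class")
    return fig, ax, sc


if __name__ == "__main__":
    # Hypothetical demo data; replace with the X and Y loaded from the CSV files.
    rng = np.random.default_rng(0)
    X_demo = rng.normal(size=(100, 3))
    Y_demo = (X_demo.sum(axis=1) > 0).astype(int).reshape(-1, 1)
    plot_3d_by_class(X_demo, Y_demo)
    plt.show()
```

With your data, this would be `plot_3d_by_class(X, Y)` followed by `plt.show()`. In your original call the only change that matters is adding `c=Y.ravel()` to `ax.scatter3D(...)`.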