# File listing metadata: 118 lines, 4.3 KiB, Python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, roc_curve, confusion_matrix
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, plot_tree

# Load the Spotify music data and drop the columns that are not used as
# predictors: 'X', 'artist' and 'song_title'.
MusicSpotify = pd.read_csv('music_spotify.csv')
MusicSpotify = MusicSpotify.drop(columns=['X', 'artist', 'song_title'])

# Encode the class labels as integers 0..n_classes-1.  fit_transform()
# returns a new array and leaves its input untouched, so the result has to be
# written back to the frame; otherwise this step has no effect at all.
label_encoder = LabelEncoder()
MusicSpotify['target'] = label_encoder.fit_transform(MusicSpotify['target'])

# Data division: every remaining column except 'target' is a feature,
# 'target' is the class to predict.
X = MusicSpotify.drop(columns=['target'])
y = MusicSpotify['target']

# Hold out 20% of the rows for testing; the fixed seed keeps the split
# identical between runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# Tree training (no cost complexity pruning): ccp_alpha=0.0 means no pruning
# is applied, and min_samples_split=2 is the minimum number of samples a node
# needs before it may be split.
# random_state is pinned so the fitted tree -- and every score, importance and
# plot derived from it -- is the same on each run; 20 is the seed the other
# seeded trees in this script use.
tree_music_spotify = DecisionTreeClassifier(ccp_alpha=0.00, min_samples_split=2, random_state=20)
tree_music_spotify.fit(X_train, y_train)

# Test-set predictions of the unpruned tree (variable name kept as in the
# original script).
pred_rbf = tree_music_spotify.predict(X_test)

if False:  # set to True to draw the unpruned tree with class labels
    # Plot the Decision Tree with matplotlib
    feature_names = X_train.columns.tolist()
    # plot_tree expects the class names in ascending class order, which is the
    # order of the fitted estimator's classes_.  y_train.unique() returns the
    # labels in order of first appearance and can therefore mislabel nodes.
    # A plain list of str is the form documented for class_names.
    class_names = [str(label) for label in tree_music_spotify.classes_]
    plt.figure(figsize=(12, 8))
    plot_tree(
        tree_music_spotify,
        filled=True,
        feature_names=feature_names,
        class_names=class_names,
        rounded=True,
        fontsize=10,
    )
    plt.show()

# Cost-complexity pruning path computed from the training data: the effective
# alphas and the total leaf impurity that goes with each of them.
path = tree_music_spotify.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas
impurities = path.impurities

if False:  # disabled: impurity-vs-alpha step plot
    # The last entry of the path is left out of the plot.
    fig, ax = plt.subplots()
    ax.plot(ccp_alphas[:-1], impurities[:-1], marker="o", drawstyle="steps-post")
    ax.set(
        xlabel="effective alpha",
        ylabel="total impurity of leaves",
        title="Total Impurity vs effective alpha for training set",
    )
    ax.grid()
    plt.show()

if False:  # disabled: cross-validated error as a function of tree depth
    # Mean 10-fold cross-validated accuracy for every max_depth in 1..24,
    # computed on the full data set (X, y).
    depth_values = range(1, 25)
    cv_scores = [
        cross_val_score(
            DecisionTreeClassifier(max_depth=depth, random_state=20),
            X,
            y,
            cv=10,
            scoring='accuracy',
        ).mean()
        for depth in depth_values
    ]

    # Error rate is one minus the accuracy.
    cv_error = 1 - np.array(cv_scores)
    plt.plot(depth_values, cv_error, marker='o')
    plt.xlabel('Tree depth')
    plt.ylabel('Cross-validated error rate')
    plt.title('Cross-validated error vs. tree depth')
    plt.grid(True)
    plt.show()

if False:  # set to True to plot accuracy along the pruning path
    # Refit one tree per candidate alpha of the pruning path and score it.
    # random_state is pinned to 20, the seed the other seeded trees in this
    # script use, so the curve is reproducible from run to run.
    # NOTE: the accuracy is measured on the held-out test split, so an alpha
    # chosen from this curve has been tuned on the test data.
    ccp_alphas_collect = list(ccp_alphas)
    accuracy_collect = [
        DecisionTreeClassifier(ccp_alpha=ccp_alpha, random_state=20)
        .fit(X_train, y_train)
        .score(X_test, y_test)
        for ccp_alpha in ccp_alphas_collect
    ]

    plt.plot(np.array(ccp_alphas_collect), np.array(accuracy_collect))
    plt.grid()
    plt.xlim(0, 0.03)
    plt.xlabel('CP alpha')
    plt.ylabel('Accuracy')
    plt.show()

if False:  # disabled: plain drawing of the unpruned tree
    # Filled nodes, labelled with the feature column names of X.
    column_names = X.columns.tolist()
    plt.figure(figsize=(12, 6))
    plot_tree(tree_music_spotify, feature_names=column_names, filled=True)
    plt.show()

# Prediction and accuracy of the unpruned tree on the held-out test split
pred = tree_music_spotify.predict(X_test)
accuracy = accuracy_score(y_test, pred) * 100
print("Unprunned:")
print(f"Accuracy: {accuracy:.2f}%")
# Mean score over 10 folds of the full data set (X, y).
print("10-fold cross-validation score: ", cross_val_score(tree_music_spotify, X, y, cv=10).mean())
# confusion_matrix takes (y_true, y_pred): rows are the true classes and
# columns the predicted ones.  Passing the predictions first would print the
# transposed matrix.
print("Confusion Matrix:\n", confusion_matrix(y_test, pred))

# Pruning: refit the tree with cost-complexity pruning (ccp_alpha=0.006) and a
# fixed seed, then evaluate it the same way as the unpruned tree.
music_spotify_prunned = DecisionTreeClassifier(ccp_alpha=0.006, random_state=20)
music_spotify_prunned.fit(X_train, y_train)
pred = music_spotify_prunned.predict(X_test)
print("Prunned:")
print(f"Accuracy: {accuracy_score(y_test, pred) * 100:.2f}%")
# Mean score over 10 folds of the full data set (X, y).
print("10-fold cross-validation score: ", cross_val_score(music_spotify_prunned, X, y, cv=10).mean())
# confusion_matrix takes (y_true, y_pred): rows are the true classes and
# columns the predicted ones.  Passing the predictions first would print the
# transposed matrix.
print("Confusion Matrix:\n", confusion_matrix(y_test, pred))

if False:  # disabled: plain drawing of the pruned tree
    # Filled nodes, labelled with the feature column names of X.
    column_names = X.columns.tolist()
    plt.figure(figsize=(12, 6))
    plot_tree(music_spotify_prunned, feature_names=column_names, filled=True)
    plt.show()

if True:  # enabled: bar chart of feature importances
    # One row per training feature holding the unpruned tree's importance for
    # it, ordered from most to least important.
    feat_importances = pd.DataFrame(
        {"Importance": tree_music_spotify.feature_importances_},
        index=X_train.columns,
    )
    feat_importances = feat_importances.sort_values(by='Importance', ascending=False)
    feat_importances.plot.bar(figsize=(8, 6))
    plt.show()