167 lines
5.3 KiB
Python
167 lines
5.3 KiB
Python
import pandas as pd
|
|
import numpy as np
|
|
import math
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
from sklearn.base import BaseEstimator, ClassifierMixin
|
|
|
|
@dataclass
class Node:
    """A single node of a binary decision tree.

    A node is either a leaf, in which case only ``value`` is set, or a
    branching node, in which case ``feature``, ``threshold``, ``left``,
    ``right`` and ``info_gain`` are set and ``value`` stays ``None``.
    Build nodes through :meth:`leaf` and :meth:`branch` rather than by
    filling the fields by hand.
    """

    # If leaf node
    value: Optional[float] = None

    # If branching node
    feature: Optional[str] = None
    threshold: Optional[float] = None
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    info_gain: Optional[float] = None

    @staticmethod
    def leaf(value: float) -> "Node":
        """Return a leaf node carrying ``value``; all branch fields stay ``None``."""
        return Node(value=value)

    @staticmethod
    def branch(
        feature: str,
        threshold: float,
        info_gain: float,
        left: "Node",
        right: "Node",
    ) -> "Node":
        """Return a branching node; its ``value`` stays ``None``.

        Args:
            feature: Name of the feature the node tests.
            threshold: Value the feature is compared against.
            info_gain: Information gain recorded for this split.
            left: Child node stored as ``left``.
            right: Child node stored as ``right``.
        """
        return Node(
            feature=feature,
            threshold=threshold,
            left=left,
            right=right,
            info_gain=info_gain,
        )
|
|
|
|
@dataclass
class SplitResult:
    """Record of one candidate split of a dataset.

    Holds the feature and threshold that define the split, the information
    gain computed for it, and the two resulting partitions of the feature
    frame and label series.
    """

    # What the split tests and how much it gained.
    feature: str
    threshold: float
    info_gain: float

    # Rows (and their labels) assigned to the left partition.
    left_X: pd.DataFrame
    left_Y: pd.Series

    # Rows (and their labels) assigned to the right partition.
    right_X: pd.DataFrame
    right_Y: pd.Series
|
|
|
|
class DecisionTree(BaseEstimator, ClassifierMixin):
    """Binary decision-tree classifier that splits on Gini information gain.

    Every branching node tests ``row[feature] <= threshold``: rows that pass
    go to the left child, all others to the right child. A leaf predicts the
    most frequent label among the training rows that reached it.

    Parameters
    ----------
    min_samples_split : int, default=2
        A node is only considered for splitting when it holds at least this
        many samples.
    max_depth : int, default=10
        A node is only considered for splitting while its depth (root is 0)
        is ``<= max_depth``; leaves can therefore sit at depth
        ``max_depth + 1``.
    """

    # Features with more distinct values than this are skipped entirely when
    # searching for a split, which bounds the candidate thresholds per feature.
    _MAX_UNIQUE_VALUES = 20

    root_node: Optional[Node]

    def __init__(self, min_samples_split=2, max_depth=10) -> None:
        # Set by fit(); None marks an estimator that has not been fitted yet.
        self.root_node = None

        self.min_samples_split = min_samples_split
        self.max_depth = max_depth

    def build_tree(self, X: pd.DataFrame, Y: pd.Series, current_depth=0):
        """Recursively build the (sub)tree for the samples ``X`` / ``Y``.

        Returns a branching node when the sample-count and depth limits allow
        a split and the best split has strictly positive information gain;
        otherwise returns a leaf holding the most frequent label in ``Y``.
        """
        num_samples = np.shape(X)[0]
        may_split = (
            num_samples >= self.min_samples_split
            and current_depth <= self.max_depth
        )

        if may_split:
            best_split = self.get_best_split(X, Y)

            if best_split is not None and best_split.info_gain > 0:
                left_node = self.build_tree(
                    best_split.left_X, best_split.left_Y, current_depth + 1
                )
                right_node = self.build_tree(
                    best_split.right_X, best_split.right_Y, current_depth + 1
                )
                return Node.branch(
                    best_split.feature,
                    best_split.threshold,
                    best_split.info_gain,
                    left_node,
                    right_node,
                )

        return Node.leaf(self.calculate_leaf_value(Y))

    def get_best_split(self, X: pd.DataFrame, Y: pd.Series):
        """Return the split with the highest Gini information gain.

        Every distinct value of every column is tried as a threshold, except
        that columns with more than ``_MAX_UNIQUE_VALUES`` distinct values are
        skipped. Splits that leave either side empty are ignored. Ties keep
        the first candidate encountered.

        Returns
        -------
        SplitResult or None
            ``None`` when no candidate produced two non-empty partitions.
        """
        best_split = None
        best_gain = -math.inf

        for column_name in X:
            # TODO: Should probably be cached
            thresholds = X[column_name].unique()
            if len(thresholds) > self._MAX_UNIQUE_VALUES:
                continue

            for threshold in thresholds:
                left_X, left_Y, right_X, right_Y = self.split(
                    X, Y, column_name, threshold
                )
                if left_X.empty or right_X.empty:
                    continue

                info_gain = self.gini_information_gain(Y, left_Y, right_Y)
                if info_gain > best_gain:
                    best_gain = info_gain
                    best_split = SplitResult(
                        column_name,
                        threshold,
                        info_gain,
                        left_X,
                        left_Y,
                        right_X,
                        right_Y,
                    )

        return best_split

    def split(self, X: pd.DataFrame, Y: pd.Series, column_name: str, threshold: float):
        """Partition ``X`` and ``Y`` on ``X[column_name] <= threshold``.

        Returns ``(left_X, left_Y, right_X, right_Y)`` where the left parts
        hold the rows satisfying the comparison and the right parts hold the
        remaining rows.
        """
        left_rows = X[column_name] <= threshold
        right_rows = ~left_rows

        return X[left_rows], Y[left_rows], X[right_rows], Y[right_rows]

    def gini_information_gain(self, Y, left_Y, right_Y):
        """Return the drop in Gini impurity from splitting ``Y`` in two.

        Computed as the parent's impurity minus the children's impurities,
        each weighted by its share of the parent's samples.
        """
        total = len(Y)
        weight_left = len(left_Y) / total
        weight_right = len(right_Y) / total
        return (
            self.gini_index(Y)
            - weight_left * self.gini_index(left_Y)
            - weight_right * self.gini_index(right_Y)
        )

    @staticmethod
    def _class_probabilities(Y):
        """Return the relative frequency of each distinct label in ``Y``."""
        _, counts = np.unique(Y, return_counts=True)
        if len(Y) == 0:
            # No labels: return an empty float array instead of dividing by zero.
            return counts.astype(float)
        return counts / len(Y)

    def gini_index(self, Y):
        """Return the Gini impurity of ``Y``: one minus the sum of squared label frequencies."""
        probabilities = self._class_probabilities(Y)
        return 1 - np.sum(probabilities * probabilities)

    def entropy(self, Y):
        """Return the base-2 Shannon entropy of the labels in ``Y``.

        Not used by :meth:`gini_information_gain`; kept as an alternative
        impurity measure.
        """
        probabilities = self._class_probabilities(Y)
        return np.sum(-probabilities * np.log2(probabilities))

    def calculate_leaf_value(self, Y):
        """Return the first entry of ``Y.mode()``, i.e. a most frequent label."""
        return Y.mode()[0]

    def fit(self, X, Y):
        """Build the tree from ``X`` / ``Y`` and store it in ``root_node``.

        Returns
        -------
        DecisionTree
            ``self``, so calls can be chained as scikit-learn estimators
            conventionally allow.
        """
        self.root_node = self.build_tree(X, Y)
        return self

    def predict(self, X: pd.DataFrame):
        """Predict a label for every row of ``X``.

        Returns the result of applying :meth:`make_prediction` row by row
        (``X.apply(..., axis=1)``).

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If :meth:`fit` has not been called. This exception subclasses
            both ``ValueError`` and ``AttributeError``.
        """
        if self.root_node is None:
            from sklearn.exceptions import NotFittedError

            raise NotFittedError(
                "This DecisionTree instance is not fitted yet. "
                "Call 'fit' before using 'predict'."
            )

        return X.apply(lambda x: self.make_prediction(x, self.root_node), axis=1)

    def make_prediction(self, x, tree: Node):
        """Walk from ``tree`` down to a leaf for the sample ``x`` and return its value.

        At each branching node the sample goes left when
        ``x[feature] <= threshold`` and right otherwise.
        """
        node = tree
        # A node with a value is a leaf; anything else is a branching node.
        while node.value is None:
            if x[node.feature] <= node.threshold:
                node = node.left
            else:
                node = node.right
        return node.value

    def print(self, tree=None, indent=" "):
        """Print a text rendering of ``tree`` (default: the fitted root).

        Branching nodes print ``feature <= threshold ? info_gain`` followed by
        their ``left:`` and ``right:`` children; leaves print their value.
        Each nesting level is prefixed with one more copy of ``indent``.
        """
        if tree is None:
            tree = self.root_node

        if tree is None:
            # Nothing to render before fit() has been called.
            print("(tree is not fitted)")
            return

        self._print_node(tree, indent, indent)

    def _print_node(self, node, prefix, step):
        """Print ``node`` and its descendants, growing ``prefix`` by ``step`` per level."""
        if node.value is not None:
            print(node.value)
            return

        print(str(node.feature), "<=", node.threshold, "?", node.info_gain)
        print("%sleft: " % (prefix), end="")
        self._print_node(node.left, prefix + step, step)
        print("%sright: " % (prefix), end="")
        self._print_node(node.right, prefix + step, step)

    def __sklearn_is_fitted__(self):
        """Return ``True`` once ``fit`` has stored a tree in ``root_node``.

        Follows scikit-learn's ``__sklearn_is_fitted__`` naming convention for
        fitted-state checks.
        """
        return self.root_node is not None