import pandas as pd import numpy as np import math from dataclasses import dataclass from typing import Optional from sklearn.base import BaseEstimator, ClassifierMixin @dataclass class Node: # If leaf node value: Optional[float] = None # If branching node feature: Optional[str] = None threshold: Optional[float] = None left: Optional["Node"] = None right: Optional["Node"] = None info_gain: Optional[float] = None def leaf(value: float) -> "Node": return Node(value=value) def branch(feature: str, threshold: float, info_gain: float, left: "Node", right: "Node") -> "Node": return Node( feature=feature, threshold=threshold, left=left, right=right, info_gain=info_gain ) @dataclass class SplitResult: feature: str threshold: float info_gain: float left_X: pd.DataFrame left_Y: pd.Series right_X: pd.DataFrame right_Y: pd.Series class DecisionTree(BaseEstimator, ClassifierMixin): root_node: Optional[Node] def __init__(self, min_samples_split=2, max_depth=10) -> None: self.root_node = None self.min_samples_split = min_samples_split self.max_depth = max_depth def build_tree(self, X: pd.DataFrame, Y: pd.Series, current_depth = 0): num_samples = np.shape(X)[0] if num_samples >= self.min_samples_split and current_depth <= self.max_depth: best_split = self.get_best_split(X, Y) if best_split and best_split.info_gain > 0: left_node = self.build_tree(best_split.left_X , best_split.left_Y , current_depth + 1) right_node = self.build_tree(best_split.right_X, best_split.right_Y, current_depth + 1) return Node.branch( best_split.feature, best_split.threshold, best_split.info_gain, left_node, right_node, ) return Node.leaf(self.calculate_leaf_value(Y)) def get_best_split(self, X: pd.DataFrame, Y: pd.Series): best_split = SplitResult( "unknown", 0, -math.inf, pd.DataFrame(), pd.Series(), pd.DataFrame(), pd.Series() ) for column_name in X: column = X[column_name] thresholds = column.unique() # TODO: Should probably be cached if len(thresholds) > 20: continue for threshold in thresholds: left_X, left_Y, right_X, right_Y = self.split(X, Y, column_name, threshold) if not left_X.empty and not right_X.empty: info_gain = self.gini_information_gain(Y, left_Y, right_Y) if info_gain > best_split.info_gain: best_split.info_gain = info_gain best_split.feature = column_name best_split.threshold = threshold best_split.left_X = left_X best_split.left_Y = left_Y best_split.right_X = right_X best_split.right_Y = right_Y if best_split.info_gain == -math.inf: return None return best_split def split(self, X: pd.DataFrame, Y: pd.Series, column_name: str, threshold: float): left_rows = X[column_name] <= threshold right_rows = ~left_rows return X[left_rows], Y[left_rows], X[right_rows], Y[right_rows] def gini_information_gain(self, Y, left_Y, right_Y): weight_left = len(left_Y) / len(Y) weight_right = len(right_Y) / len(Y) return self.gini_index(Y) - weight_left*self.gini_index(left_Y) - weight_right*self.gini_index(right_Y) # return self.entropy(Y) - weight_left*self.entropy(left_Y) - weight_right*self.entropy(right_Y) def gini_index(self, Y): gini = 0 for y_value in np.unique(Y): probability = (Y == y_value).sum() / len(Y) gini += probability*probability return 1 - gini def entropy(self, Y): entropy = 0 for y_value in np.unique(Y): probability = (Y == y_value).sum() / len(Y) entropy += -probability * np.log2(probability) return entropy def calculate_leaf_value(self, Y): return Y.mode()[0] def fit(self, X, Y): self.root_node = self.build_tree(X, Y) def predict(self, X: pd.DataFrame): return X.apply(lambda x: self.make_prediction(x, self.root_node), axis=1) def make_prediction(self, x, tree: Node): if tree.value != None: return tree.value if x[tree.feature] <= tree.threshold: return self.make_prediction(x, tree.left) else: return self.make_prediction(x, tree.right) def print(self, tree=None, indent=" "): if not tree: tree = self.root_node if tree.value is not None: print(tree.value) else: print(str(tree.feature), "<=", tree.threshold, "?", tree.info_gain) print("%sleft: " % (indent), end="") self.print(tree.left, indent + indent) print("%sright: " % (indent), end="") self.print(tree.right, indent + indent) def __sklearn_is_fitted__(self): return self.root_node != None