# 1
# 0
# intelektikos-pagrindai/decision_tree.py
# 2024-03-24 21:13:24 +02:00
#
# 167 lines
# 5.3 KiB
# Python

import pandas as pd
import numpy as np
import math
from dataclasses import dataclass
from typing import Optional
from sklearn.base import BaseEstimator, ClassifierMixin
@dataclass
class Node:
    """One node of a binary decision tree.

    A node is either a leaf, in which case only ``value`` is set, or a
    branching node, in which case ``feature``, ``threshold``, ``left``,
    ``right`` and ``info_gain`` are set and ``value`` stays ``None``.
    Use the ``leaf`` / ``branch`` factories rather than filling the fields
    by hand.
    """

    # If leaf node
    value: Optional[float] = None
    # If branching node
    feature: Optional[str] = None
    threshold: Optional[float] = None
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    info_gain: Optional[float] = None

    # The factories take no ``self``; ``@staticmethod`` makes them callable
    # from an instance as well as from the class without misbinding arguments.
    @staticmethod
    def leaf(value: float) -> "Node":
        """Return a leaf node that carries only ``value``."""
        return Node(value=value)

    @staticmethod
    def branch(
        feature: str,
        threshold: float,
        info_gain: float,
        left: "Node",
        right: "Node",
    ) -> "Node":
        """Return a branching node; its ``value`` is left as ``None``.

        Args:
            feature: Name of the feature the node tests.
            threshold: Value the feature is compared against.
            info_gain: Information gain recorded for this split.
            left: Child node stored in ``left``.
            right: Child node stored in ``right``.
        """
        return Node(
            feature=feature,
            threshold=threshold,
            left=left,
            right=right,
            info_gain=info_gain,
        )
@dataclass
class SplitResult:
    """A candidate split of a data set on one feature at one threshold.

    The field order is part of the interface: instances are constructed
    positionally.
    """

    # Column the split tests and the value it is compared against.
    feature: str
    threshold: float
    # Information gain achieved by this split.
    info_gain: float
    # Features and labels of the "left" partition.
    left_X: pd.DataFrame
    left_Y: pd.Series
    # Features and labels of the "right" partition.
    right_X: pd.DataFrame
    right_Y: pd.Series
class DecisionTree(BaseEstimator, ClassifierMixin):
    """Binary decision-tree classifier grown greedily on Gini information gain.

    Every branching node tests ``x[feature] <= threshold``: rows that satisfy
    the test go to the left child, all others to the right child. A leaf
    predicts the first mode of the training labels that reached it.

    Args:
        min_samples_split: Nodes holding fewer samples than this become
            leaves.
        max_depth: A node is split only while its depth (the root is 0) is
            ``<= max_depth``, so the deepest leaves can sit at depth
            ``max_depth + 1``.
        max_thresholds: Features with more distinct values than this are
            skipped by the split search. ``None`` disables the limit.
    """

    root_node: Optional[Node]

    def __init__(self, min_samples_split=2, max_depth=10, max_thresholds=20) -> None:
        self.root_node = None
        self.min_samples_split = min_samples_split
        self.max_depth = max_depth
        self.max_thresholds = max_thresholds

    def build_tree(self, X: pd.DataFrame, Y: pd.Series, current_depth=0) -> Node:
        """Recursively grow the subtree for the samples ``X`` / ``Y``.

        Returns a branching node when the stopping rules allow a split and
        the best split has strictly positive information gain; otherwise a
        leaf holding the value from ``calculate_leaf_value``.
        """
        num_samples = len(X)
        may_split = (
            num_samples >= self.min_samples_split
            and current_depth <= self.max_depth
        )
        if may_split:
            best_split = self.get_best_split(X, Y)
            if best_split is not None and best_split.info_gain > 0:
                left_node = self.build_tree(
                    best_split.left_X, best_split.left_Y, current_depth + 1
                )
                right_node = self.build_tree(
                    best_split.right_X, best_split.right_Y, current_depth + 1
                )
                return Node.branch(
                    best_split.feature,
                    best_split.threshold,
                    best_split.info_gain,
                    left_node,
                    right_node,
                )
        return Node.leaf(self.calculate_leaf_value(Y))

    def get_best_split(self, X: pd.DataFrame, Y: pd.Series) -> Optional[SplitResult]:
        """Return the split with the highest information gain, or ``None``.

        Every distinct value of every column is tried as a threshold, except
        for columns with more than ``max_thresholds`` distinct values, which
        are skipped entirely. Splits that leave one side empty are ignored.
        ``None`` is returned when no usable split exists.
        """
        best_gain = -math.inf
        best_feature = None
        best_threshold = None
        limit = self.max_thresholds

        for column_name in X:
            column = X[column_name]
            thresholds = column.unique()
            if limit is not None and len(thresholds) > limit:
                continue
            for threshold in thresholds:
                # Score the candidate on the labels alone; the feature frame
                # is only partitioned once, for the winning candidate.
                left_rows = column <= threshold
                left_Y = Y[left_rows]
                right_Y = Y[~left_rows]
                if len(left_Y) == 0 or len(right_Y) == 0:
                    continue
                info_gain = self.gini_information_gain(Y, left_Y, right_Y)
                if info_gain > best_gain:
                    best_gain = info_gain
                    best_feature = column_name
                    best_threshold = threshold

        if best_gain == -math.inf:
            return None
        left_X, left_Y, right_X, right_Y = self.split(
            X, Y, best_feature, best_threshold
        )
        return SplitResult(
            best_feature,
            best_threshold,
            best_gain,
            left_X,
            left_Y,
            right_X,
            right_Y,
        )

    def split(self, X: pd.DataFrame, Y: pd.Series, column_name: str, threshold: float):
        """Partition ``X`` and ``Y`` by ``X[column_name] <= threshold``.

        Returns:
            ``(left_X, left_Y, right_X, right_Y)`` where "left" holds the
            rows that satisfy the test and "right" holds the rest.
        """
        left_rows = X[column_name] <= threshold
        right_rows = ~left_rows
        return X[left_rows], Y[left_rows], X[right_rows], Y[right_rows]

    def gini_information_gain(self, Y, left_Y, right_Y):
        """Return the drop in Gini impurity from splitting ``Y`` in two.

        The children's impurities are weighted by their share of the parent's
        samples. ``entropy`` is available as an alternative impurity measure
        but is not used here.
        """
        weight_left = len(left_Y) / len(Y)
        weight_right = len(right_Y) / len(Y)
        return (
            self.gini_index(Y)
            - weight_left * self.gini_index(left_Y)
            - weight_right * self.gini_index(right_Y)
        )

    @staticmethod
    def _class_probabilities(Y):
        """Return the relative frequency of each distinct label in ``Y``."""
        _, counts = np.unique(Y, return_counts=True)
        return counts / len(Y)

    def gini_index(self, Y):
        """Return the Gini impurity of the labels: ``1 - sum(p_i ** 2)``."""
        probabilities = self._class_probabilities(Y)
        return 1 - np.sum(probabilities * probabilities)

    def entropy(self, Y):
        """Return the Shannon entropy of the labels in bits."""
        probabilities = self._class_probabilities(Y)
        return -np.sum(probabilities * np.log2(probabilities))

    def calculate_leaf_value(self, Y):
        """Return the prediction for a leaf: the first mode of ``Y``."""
        return Y.mode().iloc[0]

    def fit(self, X, Y):
        """Grow the tree from the training data and return ``self``.

        Returning the estimator follows the scikit-learn ``fit`` convention
        and allows call chaining such as ``DecisionTree().fit(X, Y)``.
        """
        self.root_node = self.build_tree(X, Y)
        return self

    def predict(self, X: pd.DataFrame):
        """Predict a label for every row of ``X``.

        Raises:
            NotFittedError: If ``fit`` has not been called yet.
        """
        self._check_fitted()
        return X.apply(lambda row: self.make_prediction(row, self.root_node), axis=1)

    def make_prediction(self, x, tree: Node):
        """Walk ``tree`` with the single sample ``x`` and return the leaf value."""
        node = tree
        # A node with a value is a leaf; otherwise descend by the node's test.
        while node.value is None:
            if x[node.feature] <= node.threshold:
                node = node.left
            else:
                node = node.right
        return node.value

    def print(self, tree=None, indent=" ", _step=None):
        """Print the tree (or the subtree ``tree``) to standard output.

        Args:
            tree: Subtree to print; defaults to the fitted root node.
            indent: Prefix written before the children of ``tree``.
            _step: Internal. Text added to ``indent`` per level; defaults to
                the initial ``indent`` so indentation grows linearly with
                depth instead of doubling at every level.

        Raises:
            NotFittedError: If no ``tree`` is given and the model is unfitted.
        """
        if tree is None:
            self._check_fitted()
            tree = self.root_node
        if _step is None:
            _step = indent

        if tree.value is not None:
            print(tree.value)
            return

        print(str(tree.feature), "<=", tree.threshold, "?", tree.info_gain)
        print("%sleft: " % (indent), end="")
        self.print(tree.left, indent + _step, _step)
        print("%sright: " % (indent), end="")
        self.print(tree.right, indent + _step, _step)

    def _check_fitted(self):
        """Raise ``NotFittedError`` unless the tree has been fitted."""
        if not self.__sklearn_is_fitted__():
            # Imported locally so the module's import block stays untouched.
            # NotFittedError subclasses AttributeError, which is what an
            # unfitted model raised before this check existed.
            from sklearn.exceptions import NotFittedError

            raise NotFittedError(
                "This DecisionTree instance is not fitted yet. "
                "Call 'fit' before using this estimator."
            )

    def __sklearn_is_fitted__(self):
        """Return ``True`` once ``fit`` has built a tree."""
        return self.root_node is not None