NumPyによる決定木の簡易実装

決定木

概要

決定木は言わずもがな木の根からフローチャートのようにたどっていくとそのデータがどのクラス属するのか予測できるモデルです。各ノードはある特徴量について regressionについては扱わないよ


https://en.wikipedia.org/wiki/Decisiontreelearning

ではどのようにして分岐に使う変数と閾値を決めていけばいいでしょうか。

データの分割方法

ジニ不純度(Gini impurity)

ジニ不純度を用いる場合、分割前のジニ不純度と分割後のジニ不純度の差が最大になるような特徴量、閾値を探索する。ジニ不純度は以下の式で求められる。

1i=1Jpi21 - \sum_{i=1}^{J} p_i^2

ジニ不純度の表しているところは、集合の中からランダムに1つを選ぶ試行を2回行ったときにそれぞれ違うクラスに属している確率である。違うクラスに属している確率が高まるということは、クラスに偏りが無く、うまく別れていない状態を表すことになるので、この数字ができるだけ小さくなるように分割を行えれば嬉しいということになる。

また、日本語のブログなどでもジニ係数とジニ不純度を混同しているものがあるが、2つは名前こそ似ているが全くの別物です。

情報利得(Information gain)

情報利得を用いる場合、分割前の情報量と分割後の情報量の差が最大になるような特徴量、閾値を探索する。

i=1Jpilog2pi- \sum_{i=1}^{J} p_i \log_{2}{p_i}

irisを使って分割前後のジニ不純度の差を見てみる

ひとまずirisを落としてくる。

wget https://raw.githubusercontent.com/pandas-dev/pandas/master/pandas/tests/data/iris.csv

ジニ不純度を求める関数は

def gini_impurity(target: np.ndarray, classes: Set[Any]) -> float:
    ret = 1.0
    if len(target) == 0:
        return ret
    for _class in classes:
        ret -= (len(target[target == _class]) / len(target))**2
    return ret

こんな感じで書ける。irisで例を出すとclasses={'Iris-setosa', 'Iris-versicolor', 'Iris-virginica'}となっていて、targetに実際のデータのインスタンスのクラスの集まりとなっている。もし、targetの中身がすべて一つのクラスだと、ジニ不純度は0になる。irisのデータセットはclasses={'Iris-setosa', 'Iris-versicolor', 'Iris-virginica'}の3クラスが全て均等になっているので、ジニ不純度は

1(13)2(13)2(13)2=0.6661 - \left(\frac{1}{3} \right)^2 - \left(\frac{1}{3} \right)^2 - \left(\frac{1}{3} \right)^2 = 0.666

となる。
仮に、ある特徴量についてある閾値できれいに2クラスのみが含まれる集合と1クラスのみが含まれる集合に分割できたとする。

decision tree split node

分割後の2クラスのみが含まれる集合(三角と四角のグループ)のジニ不純度は、

1(12)2(12)202=0.51 - \left(\frac{1}{2}\right)^2 - \left(\frac{1}{2}\right)^2 - 0^2 = 0.5

1クラスのみが含まれる集合のジニ不純度は0なので、分割後のジニ不純度の重みを付けた平均は

23×0.5+12×0=0.333\frac{2}{3} \times 0.5 + \frac{1}{2} \times 0 = 0.333

したがって分割によるジニ不純度の差は0.6660.333=0.3330.666 - 0.333 = 0.333となる。一方、分割したもののクラスの分布がきれいに別れずにそのままにになってしまった場合、分割後のジニ不純度は同様に0.6660.666となるので分割によるジニ不純度の差は0.6660.666=00.666-0.666=0となってしまうので、きれいに分かれている方が差が大きくなっていることがわかる。

このジニ不純度の差が大きくなるような特徴量と閾値を探す必要がある。今回は特徴量の種類がSepalLengthSepalWidthPetalLengthPetalWidthの4種類しかないのでナイーブにデータの数×特徴量の種類数分だけ計算すれば良い。

適当にプログラムを書くと、

def calc_gain(feature: np.ndarray, target: np.ndarray, threshold: float) -> float:
    classes = set(target)
    target_left = target[feature > threshold]
    target_right = target[feature <= threshold]
    criterion_before = gini_impurity(target, classes)
    criterion_left = gini_impurity(target_left, classes)
    criterion_right = gini_impurity(target_right, classes)
    criterion_after = criterion_left * len(target_left) / len(target) + criterion_right * len(target_right) / len(target)
    gain = criterion_before - criterion_after
    return gain


df = pd.read_csv('./iris.csv')
x = df.iloc[:, :4].values
y = df.iloc[:, 4].values

fig = plt.figure()

ax = None
for i, column in enumerate(df.columns[:-1]):
    kwargs: dict = {} if ax is None else {'sharey': ax}
    ax = fig.add_subplot('22{0}'.format(i+1), **kwargs)
    ax.title.set_text(column)
    ax.grid()

    gain = []
    feature_values = df[column].values
    thresholds = list(set(feature_values))
    thresholds.sort()
    for threshold in thresholds:
        gain.append(calc_gain(feature_values, y, threshold))

    ax.plot(thresholds, gain)
plt.show()

グラフから、PetalLengthPetalWidthを使えばジニ不純度が0.33くらいで最大を取れそうなことがわかる。厳密に最大を取る変数と閾値を使って繰り返し分割していけば決定木を作ることができそうだ。

ひとまずここまでの内容で実装してみる。

import numpy as np
from queue import Queue
from typing import Optional, Union, List, Set, Any, Tuple, cast


class DecisionTreeNode(object):
    def __init__(self, features: np.ndarray, target: Union[np.ndarray, List[Any]], depth=0) -> None:
        self.left: Optional[DecisionTreeNode] = None
        self.right: Optional[DecisionTreeNode] = None
        self.threshold: Optional[float] = None
        self.feature_index: Optional[int] = None
        self.gain = 0.0
        self.has_child = False
        self.depth = depth

        self.features = features
        self.target = target if isinstance(target, np.ndarray) else np.array(target)

    def _split_node(self, threshold: float, feature_index: int, gain: float) -> None:
        self.threshold = threshold
        self.feature_index = feature_index
        self.gain = gain

        left_indices = self.features[:, feature_index] > threshold
        right_indices = self.features[:, feature_index] <= threshold

        # Todo 枝刈り
        if len(left_indices) < 5 or len(right_indices) < 5:
            pass
        else:
            self.left = DecisionTreeNode(self.features[left_indices],
                                         self.target[left_indices],
                                         depth=self.depth+1)
            self.right = DecisionTreeNode(self.features[right_indices],
                                          self.target[right_indices],
                                          depth=self.depth+1)
            self.has_child = True

            del self.features, self.target
            self.features = None
            self.target = None


class DecisionTree(object):
    def __init__(self, criterion: str = 'gini') -> None:
        self.root: Optional[DecisionTreeNode] = None
        self.criterion = criterion

    def fit(self, X: np.ndarray, y: np.ndarray) -> Any:
        self.root = DecisionTreeNode(X, y)
        nodes: Queue = Queue()
        nodes.put(self.root)

        while nodes.qsize() > 0:
            current_node: DecisionTreeNode = nodes.get()
            threshold, feature_index, gain = self.calc_best_gain(current_node.features, current_node.target)
            if gain > 0:
                current_node._split_node(threshold, feature_index, gain)
                if current_node.has_child:
                    nodes.put(current_node.left)
                    nodes.put(current_node.right)
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        if self.root is None:
            raise Exception('This instance is not fitted yet. Call "fit" function before calling "predict".')
        ret = []
        for sample in X:
            current_node: DecisionTreeNode = self.root
            while current_node.has_child:
                if sample[current_node.feature_index] > current_node.threshold:
                    current_node = cast(DecisionTreeNode, current_node.left)
                else:
                    current_node = cast(DecisionTreeNode, current_node.right)
            classes, counts = np.unique(current_node.target, return_counts=True)
            ret.append(classes[counts.argmax()])
        return np.array(ret)

    def gini_impurity(self, target: np.ndarray, classes: Set[Any]) -> float:
        ret = 1.0
        if len(target) == 0:
            return ret
        for _class in classes:
            ret -= (len(target[target == _class]) / len(target))**2
        return ret

    def entropy(self, target: np.ndarray, classes: Set[Any]) -> float:
        ret = 0.0
        if len(target) == 0:
            return ret
        for _class in classes:
            p = len(target[target == _class]) / len(target)
            ret -= p * np.log2(p)
        return ret

    def calc_gain(self, feature: np.ndarray, target: np.ndarray, threshold: float) -> float:
        classes = set(target)
        if self.criterion == 'gini':
            criterion = self.gini_impurity
        elif self.criterion == 'entropy':
            criterion = self.entropy
        target_left = target[feature > threshold]
        target_right = target[feature <= threshold]
        criterion_before = criterion(target, classes)
        criterion_left = criterion(target_left, classes)
        criterion_right = criterion(target_right, classes)
        criterion_after = criterion_left * len(target_left) / len(target) + criterion_right * len(target_right) / len(target)
        gain = criterion_before - criterion_after
        return gain

    def calc_best_gain(self, features: np.ndarray, target: np.ndarray) -> Tuple[float, int, float]:
        best_feature_index = -1
        best_gain = 0.0
        best_threshold = 0.0
        for feature_index in range(features.shape[1]):
            feature = features[:, feature_index]
            thresholds = list(set(feature))
            # thresholds.sort()
            for threshold in thresholds:
                gain = self.calc_gain(feature, target, threshold)
                if gain > best_gain:
                    best_gain = gain
                    best_feature_index = feature_index
                    best_threshold = threshold
        return best_threshold, best_feature_index, best_gain

scikit-learnの力も借りて5分割交差検証をしてみる。

from sklearn.model_selection import KFold
from model import DecisionTree

iris = fetch_mldata('iris')
x: np.ndarray = iris.data
y: np.ndarray = iris.target

kf = KFold(n_splits=5, shuffle=True)

for i, (train, test) in enumerate(kf.split(x)):
    clf = DecisionTree().fit(x[train], y[train])
    accuracy = (clf.predict(x[test]) == y[test]).sum() / len(y[test])
    print('{0}-validation accuracy: {1}'.format((i+1), accuracy))

出力:

1-validation accuracy: 0.9
2-validation accuracy: 0.9333333333333333
3-validation accuracy: 0.9
4-validation accuracy: 0.9666666666666667
5-validation accuracy: 0.9333333333333333

コード

http://github.com/minatosato/decision-tree


Written by@Minato Sato
Senior Software Engineer - Embedded AI

GitHubTwitterFacebookLinkedIn

© 2023 Minato Sato. All Rights Reserved