フリーランチ食べたい

No Free Lunch in ML and Life. Pythonや機械学習のことを書きます。

scikit-learnのRandomForest.feature_importances_のコードを追う

feature_importances_をちゃんと理解する

  • feature_importances_ とは sklearn.ensemble.RandomForestClassifiersklearn.ensemble.RandomForestRegressor (など)で特徴量の重要度を出力するメソッドです。
  • とても便利で、EDAやモデルの精度向上のためのアイディアを得るためによく使用しますが、「この重要度って何を表しているの?」と聞かれたときにパッと説明できなかったので調べてみました。
  • ちなみにドキュメントには↓の1行だけ説明があります。

    The importance of a feature is computed as the (normalized) total reduction of the criterion brought by that feature. It is also known as the Gini importance.

開発者が回答しているStackOverflow

もう少し細かい説明としてはscikit-learnの開発者 Gilles Louppe がstackoverflowで↓のように質問に回答しています。

stackoverflow.com

There are indeed several ways to get feature "importances". As often, there is no strict consensus about what this word means.

In scikit-learn, we implement the importance as described in [1] (often cited, but unfortunately rarely read...). It is sometimes called "gini importance" or "mean decrease impurity" and is defined as the total decrease in node impurity (weighted by the probability of reaching that node (which is approximated by the proportion of samples reaching that node)) averaged over all trees of the ensemble.

記載されている通り、gini importance あるいは mean decrease impurity と呼ばれ、ノードの不純度(impurity)をアンサンブル木で平均したものになります。 これが簡潔で正しい回答なのですが、一応コードベースでも見てみたいと思います。

該当コード

順番に見ていきます。まず大元の sklearn.ensemble.forest.BaseForest クラスのメソッドです。

https://github.com/scikit-learn/scikit-learn/blob/1128094271923c66f9e602372ba7ee8b7f565e52/sklearn/ensemble/forest.py#L365 ※ 該当部分以外を省略しています

    def feature_importances_(self):
        all_importances = Parallel(n_jobs=self.n_jobs,
                                   **_joblib_parallel_args(prefer='threads'))(
            delayed(getattr)(tree, 'feature_importances_')
            for tree in self.estimators_)

        return sum(all_importances) / len(self.estimators_)

わかりやすいコードで、ここで各アンサンブル木のall_importancesを平均していることがわかります。

呼ばれているのは、 sklearn.tree.tree.BaseDecisionTree のメソッドで、単純に compute_feature_importances を呼び出しているだけです。

https://github.com/scikit-learn/scikit-learn/blob/a80bbd9403fea9cf4aa46dfef26a4b31a608957b/sklearn/tree/tree.py#L513 ※ 該当部分以外を省略しています

    def feature_importances_(self):
        return self.tree_.compute_feature_importances()

ここからはCythonで書かれた sklearn.tree._tree.Tree クラスのメソッドを呼び出しています。

https://github.com/scikit-learn/scikit-learn/blob/a80bbd9403fea9cf4aa46dfef26a4b31a608957b/sklearn/tree/_tree.pyx#L1062 ※ 該当部分以外を省略しています

    cpdef compute_feature_importances(self, normalize=True):
        cdef np.ndarray[np.float64_t, ndim=1] importances
        importances = np.zeros((self.n_features,))
        cdef DOUBLE_t* importance_data = <DOUBLE_t*>importances.data

        with nogil:
            while node != end_node:
                if node.left_child != _TREE_LEAF:
                    # ... and node.right_child != _TREE_LEAF:
                    left = &nodes[node.left_child]
                    right = &nodes[node.right_child]

                    importance_data[node.feature] += (
                        node.weighted_n_node_samples * node.impurity -
                        left.weighted_n_node_samples * left.impurity -
                        right.weighted_n_node_samples * right.impurity)
                node += 1

        importances /= nodes[0].weighted_n_node_samples
        return importances

こちらが特に重要な部分でnodeの不純度(impurity)から左右のnodeの不純度を引いたもの(それぞれ重みをかけている)をその特徴量の不純度にしています。

importance_data[node.feature] += (
  node.weighted_n_node_samples * node.impurity -
  left.weighted_n_node_samples * left.impurity -
  right.weighted_n_node_samples * right.impurity)

この不純度を各アンサンブル木で平均したものが feature_importances_ になります。

さいごに

  • ざっくりとですが、コードを追って、 RandomForestfeature_importances_ を理解しました。
  • ちなみに他の指標としては mean decrease accuracy があります。
  • これはOOBを用いて測る指標で、該当変数をモデルから除いた際の予測精度の低下を計算します。RのRandomForestにはこの指標も実装されています。