フリーランチ食べたい

機械学習、Python、ソフトウェアエンジニアリング、プロダクティビティなど

Stackingでモデルの精度UP 実装と直感的な解説

Stackingとは何か

  • 機械学習モデルの精度を向上させる手法の1つで、モデルを積み重ねる(Stackする)ことで精度を高めます。
  • ポピュラーかつ、強力な手法なKaggleのKernelで見ることも多いですね。
  • アンサンブル学習の一種で、他のアンサンブル学習にはAveraging/Bagging/Boostingがあります。
  • Stackingは他のアンサンブル学習の手法と比べると文章だけ読んでも理解し辛いので、直感的に理解できるように図と簡単な実装で説明してみたいと思います。

注意: この記事で書かないこと

今回は概念の理解にフォーカスしたいので、バリデーションの手法については触れません。特にSecondLevelでのバリデーションについてはいくつか選択肢があり、データの性質によって使い分けが必要なので次回以降書きたいと思います。

Stackingの概念

まず直感的な説明から始めたいと思います。「ある物件の価格を過去のデータから予測するタスク」を考えます。

基本のモデル

基本のモデルは1つのモデルがあって、そのモデルが予測を出力します。 ここではゾウがモデルだと思ってください。

f:id:mergyi:20181021184007p:plain

Stacking

それではStackingはどうなるかというと、図で表すと下記のようになります。

f:id:mergyi:20181021183745p:plain

先ほどと違う点として

  • 複数のモデルが存在する
  • モデルの出力を統括する meta_model (メタモデル) が存在する

があります。

ここで動物たちが表している各モデルをベースモデル、先生が表す統括するモデルをメタモデルと呼びます。ポイントは メタモデル「ベースモデルの出力」を入力として学習し、結果を出力すること です。メタモデルで使われるアルゴリズムは特に制限はないのですが、回帰問題であればLinearRegressionを分類であればLogisticRegressionを使うのが一般的です。

メタモデルは何を学習するか

メタモデルの入力は、「ベースモデルの出力」なので、学習データ内の特徴量を知りません。それではメタモデルは何を学習しているのでしょうか。それは、「ベースモデルがどのくらい信頼できるか」です。図の例でいうと動物たち(ベースモデル)の予測を何度も聞いて「ぞう(LR)はあまり当てにならないな、さる(LGBM)はかなり正確だな」という感覚を学習していっているようなイメージです。最終的にはベースモデルごとに重み(パラメーター)を割り当てて最終結果を出力します。(※「重み」という書き方は線形モデルに限定しているような書き方になってしまうのですが、わかりやすさのためにこの表現を使っています。)

なぜStackingで精度を向上できるのか

アンサンブル手法は複数のモデルを組み合わせて強力なモデルを作る、という手法ですが、その組み合わせ方として、単純な平均などで組み合わせるよりも、「それぞれのモデルの良さを知っているメタモデルが調整する」ことにより、より良い組み合わせが行われる、と理解できます。

Python実装

直感的な説明が終わったので、実際にどんな処理が行われているか実装を通して見てみたいと思います。

データ

KaggleのTurorial用Competitionを使います。

House Prices: Advanced Regression Techniques | Kaggle

コード

実装は下のGithubに上げてあります。

github.com

解説

1つ1つ解説していきたいと思います。まず前提として以下のデータを準備しています。

  • X_train/y_train: 学習データ
  • X_valid/y_valid: バリデーションデータ
  • X_meta_valid/y_meta_valid: メタモデルのバリデーションデータ

ベースモデルの学習、メタモデルの学習、というように2段階に分かれているので、バリデーションデータが2種類必要なことに注意してください。

ベースモデルの学習

まずはベースモデルの学習です。単純にモデルを3つ作って学習させているだけです。 base_pred_1~3はメタモデルの学習入力データになります。

# train base model
base_model_1 = LinearRegression()
base_model_2 = LGBMRegressor()
base_model_3 = KNeighborsRegressor()

base_model_1.fit(X_train, y_train)
base_model_2.fit(X_train, y_train)
base_model_3.fit(X_train, y_train)

# base predicts
base_pred_1 = base_model_1.predict(X_valid)
base_pred_2 = base_model_2.predict(X_valid)
base_pred_3 = base_model_3.predict(X_valid)

メタモデルの学習

ベースモデルの出力を使ってメタモデルの学習を行います。 そんなに難しいことはしておらず、「ベースモデルの出力を連結する」「それを入力として学習させる」だけです。

# stack base predicts for training meta model
stacked_predictions = np.column_stack((base_pred_1, base_pred_2, base_pred_3))

# train meta model 
meta_model = LinearRegression()
meta_model.fit(stacked_predictions, y_valid)

結果の検証

ベースモデル、メタモデルともに学習が終わったらメタモデルの出力を使って精度を検証してみます。この時に、必ず「ベースモデルで予測」->「予測結果をメタモデルに入力」という工程を踏むことを忘れないように注意してください。

# final result 
valid_pred_1 = base_model_1.predict(X_meta_valid)
valid_pred_2 = base_model_2.predict(X_meta_valid)
valid_pred_3 = base_model_3.predict(X_meta_valid)
stacked_valid_predictions = np.column_stack((valid_pred_1, valid_pred_2, valid_pred_3))
meta_valid_pred = meta_model.predict(stacked_valid_predictions)

print ("mean squared error of model 1: {:.4f}".format(mean_squared_error(y_meta_valid, valid_pred_1)) )
print ("mean squared error of model 2: {:.4f}".format(mean_squared_error(y_meta_valid, valid_pred_2)) )
print ("mean squared error of model 3: {:.4f}".format(mean_squared_error(y_meta_valid, valid_pred_3)) )
print ("mean squared error of meta model: {:.4f}".format(mean_squared_error(y_meta_valid, meta_valid_pred)) )

# => mean squared error of model 1: 0.0239
# => mean squared error of model 2: 0.0181
# => mean squared error of model 3: 0.0634
# => mean squared error of meta model: 0.0175

このように、メタモデルの結果が一つ一つのベースモデルよりも良い精度が得られたことを確認できました。一工程ずつ追っていくと難しいことは何もしていないことがわかっていただけたのではないでしょうか?

Stackingを使う上で大事なこと/まとめ

最後にStackingを使う上で気をつけることなどを記載しておきます。

  • アンサンブル手法全体に言えることですが、ベースモデルの多様性が高いほど良い結果を得られやすくなります。
    • 多様性、というのはモデルの種類(線形モデル/決定木モデル/KNNモデル)と特量量の2つの側面があります。
    • 今回はベースモデルはモデルの種類だけ変えましたが特量量自体を変えるという戦略も有効です。
  • その他のアンサンブル手法に比べてStackingが常に優れているわけではないので、データや特徴量によって手法を使い分ける必要があります。
  • 今回はメタモデルのバリデーションをシンプルなHoldOut法で行いましたが、適切でない場合もありDataLeakageが発生する可能性があります。特に時系列データを扱う際は気をつけてください。

参考リンク

Stackingに関しての資料は他のアンサンブル手法に比べて少ないのですが、下記のリンクが参考になりました。

Stackingを簡単に行えるStackNetというフレームワークもあるので興味ある方はチェックしてみてください。

blog.kaggle.com