より良いエンジニアを目指して

1日1つ。良くなる!上手くなる!

CatBoostを使って、ポケモンの戦闘結果予測をリトライ

Kaggleで他の方のKernelを勉強させていただいていると、Catboostというアルゴリズムがあることを知りました。

CatboostはロシアのYandexという会社が公開したライブラリです。

tekenuko.hatenablog.com

過学習を避けるアルゴリズムとのこと。

Yandexとは?

ロシアのGoogleとも言われる会社。

jp.rbth.com

偉大な天気予報は気になります。

さっそくトライ

rimever.hatenablog.com

もっとも手っ取り早く、上記でやったポケモンのバトルを予測する処理に使ってみます。

tech.yandex.com

from catboost import CatBoostClassifier
# Initialize data
train_target = []

for item in y_train:
    if item:
        train_target.append(1)
    else:
        train_target.append(0)

# Initialize CatBoostClassifier
model = CatBoostClassifier(iterations=2, learning_rate=1, depth=2, loss_function='Logloss')
# Fit model
model.fit(X_train, train_target)
# Get predicted classes
preds_class = model.predict(X_test)
# Get predicted probabilities for each class
preds_proba = model.predict_proba(X_test)
# Get predicted RawFormulaVal
y_pred = model.predict(X_test, prediction_type='RawFormulaVal')



y_result = []

for value in y_pred:
    if value > 0.5:
        y_result.append(True)
    else:
        y_result.append(False)


# 精度 (Accuracy) を計算する
accuracy = sum(y_test == y_result) / len(y_test)
print("CatBoostの精度",accuracy * 100,"%")

結果は93%でした。

qiita.com

上記サイトでは、ちゃんとパラメーター設定しないとパフォーマンスでないよ、ということでもう少し工夫して見る事にしました。

iterationは5000としlearning_rateは1のまま、depthは6に変更しました。

f:id:rimever:20190126223825p:plain

99.6%!

このアルゴリズムのポテンシャルは、なかなかありそうです。