天色グラフィティ

技術ちっくなことを書きます

LightGBMのcallbackを利用して学習履歴をロガー経由で出力する

KaggleなどでLightGBMを使っていて学習履歴を見たとき、ログファイルにも残してほしいと思うことがあります。

公式にはそのような機能は実装されていないようなので、LightGBMのコールバックで対応したいと思います。 LightGBMではfitメソッドの引数としてコールバック関数を渡すことができ、内部的にはEarlyStoppingや学習履歴を標準出力に吐くのに使われています。

LightGBMのコールバックを実装

lightgbm/callback.pyを見ると、学習時に標準出力に履歴を表示する関数があります。

def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation results.
    Parameters
    ----------
    period : int, optional (default=1)
        The period to print the evaluation results.
    show_stdv : bool, optional (default=True)
        Whether to show stdv (if provided).
    Returns
    -------
    callback : function
        The callback that prints the evaluation results every ``period`` iteration(s).
    """
    def _callback(env):
        if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
            result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
            print('[%d]\t%s' % (env.iteration + 1, result))
    _callback.order = 10
    return _callback

これをベースにしてカスタマイズしてあげればよさそうです。

実装例は以下のとおりです。基本的にはprint_evaluationと一緒で、logging.Loggerとログレベルを新たに受け取れるようにしてあります。

import logging

from lightgbm.callback import _format_eval_result


def log_evaluation(logger, period=1, show_stdv=True, level=logging.DEBUG):
    def _callback(env):
        if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
            result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
            logger.log(level, '[{}]\t{}'.format(env.iteration+1, result))
    _callback.order = 10
    return _callback

使い方

ロガーを予め作成しておいて log_evaluation の引数に渡してあげれば動くはずです。

# ロガーの作成
logger = logging.getLogger('main')
logger.setLevel(logging.DEBUG)
sc = logging.StreamHandler()
logger.addHandler(sc)
fh = logging.FileHandler('hoge.log')
logger.addHandler(fh)

# データのロードなどは省略

# 訓練時にコールバックのリストを渡す
clf = lgb.LGBMClassifier()
callbacks = [log_evaluation(logger, period=10)]
clf.fit(X_train, y_train, eval_set=[(X_val, y_val)], callbacks=callbacks)

LightGBMのコールバックの実装方法

コールバック関数のリストをfit()時のcallbacks引数に渡してあげるとコールバックが呼ばれます。

コールバック関数は引数をひとつ取る必要があります。その引数には以下のように学習状況などが入ったnamedtupleが渡されます。

namedtuple(
    "LightGBMCallbackEnv",
    ["model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"])

例えば env.iteration のようにすれば現在のイテレーション数が取れます。(0始まりなので+1するのが良いと思います)

コールバック関数にパラメータを設定したい場合は今回のようにクロージャ(関数を返す関数)にしてあげると良いです。

まとめ

  • LightGBMのコールバックを作成すると学習履歴をロガーに渡すことができる
  • lightgbm/callback.pyはシンプルなので読んでみるとよい

お役立ちリンク