ITと哲学と

IT系エンジニアによる技術と哲学のお話。

【機械学習のトピックス】TrainingLossよりValidationLossが小さくなることってたまにあるけどなんで?というお話

モデルの学習をしていて、TrainingLossよりValidationLossが小さくなることが稀に良くある。 これをどう説明解釈すれば良いか?という点についてちょっと悩んだ際に見つけたAurélien Geron氏のツイートをメモっておく。

はじめに

最近の実務の中でよく感じているが、モデルの学習(というか機械学習関連のコーディング全般に言えるが)について、アプリ作ったりする普通のプログラミングにはない難しさがある。 どういうことかというと、厳密には間違っててもなんとなくそれっぽく動いちゃうことが多いように感じる。

例えばTrainVal間でデータの漏洩があったり、データ前処理で意図していない動作になっていたりしても、コードは最後まで動く。 普通のプログラミングの文脈では、コンパイラが強力にサポートしてくれるので、大変助けられているが、その安心感が機械学習の際のコーディングの際には得られない。

そのため、限りあるデータから不具合の兆候を早期に掴むことはプロジェクトの成否に関わる問題で、神経を使うなぁと。。と感じている。

モデルの学習の際に起こる典型的な不具合として、過学習という問題がある。 これは、Trainingデータに強く適合することで、未知のデータに対する精度が劣化してしまうということであり、モデルを実世界で運用していく際に思ったような精度が出なくなるという問題を引き起こす。

この過学習を検知するために、モデルの学習に使うデータをTrainingデータとValidationデータという形で分割し、Validationデータはモデル学習には全く使わないようにした上で、定期的にモデルの性能をValidationデータを使って評価するといったことを行う。 もしもTrainingデータに対する性能は良いがValidationデータに対する性能が大きく劣化しているなどあれば、これは未知のデータに対する精度が劣化していることを意味し、過学習が起きていることを示す。

一般的に、モデルの学習においてはValidationの精度はTrainingの精度を超えることはない。 モデルはTrainingのデータに対して最適な推論結果を出すように学習するため、未知のデータに対する推論結果はこれを超えないと考えられるからだ。

つまり、過学習が起こっていない良い学習が行われたモデルでは、ValidationLossがTrainingLossに伴走するような挙動を見せる。

ここで、表題にあるように、ValidationLossがTrainingLossより良い精度を見せることが稀に良くある。

これをどう解釈していいのか?これはどっかに不具合があるのか?それとも正常だけどこんなことが起こるのか?について悩んだりしていた。

本題

色々調べたりしていたら、最終的に以下の一連のツイートに行きあたった。 詳しくは以下のツイートをたどって欲しいが、いくつか可能性として考えられるものがあるので、まとめる。

正則化手法による影響

正則化を目的としてDropoutなどを行っている場合、表題のような現象が起こることがあり得る。

Trainingの際には正則化を目的として、あえてネットワークのノードを確率的にDropoutして全体としてのロバスト性を高めることを狙う。この結果としてTrainingLossは劣化する。 Validation時には正則化は不要なため、Dropoutせずにネットワークの全力を持って推論を行う。 この結果としてValidationLossがTrainingLossより良いという結果を産む可能性がある。

あえてValidationLossの計算時にもDropoutを入れることで、TrainingLossとValidationLossが同じような値になるといったケースがあるようだ。

各Lossを計算するタイミングの差による影響

TrainingLossはEpochの途中で計算されるが、ValidationLossはEpochの終了時点で計算される。 当たり前のことだがEpochの中でモデルは成長するので、TrainingLoss計算時点とValidationLoss計算時点ではモデルの学習進捗が異なる。 この影響によってValidationLossが先行して精度が上がっているように見えることがある。

この影響を補正することでValidationLossがTrainingLossと同じような値になるといったケースもあるようだ。

データの取扱における不具合

TrainingとValidationのデータ分割の際の不具合によって表題の現象が起こることがある。 例えばTraininingに比べてValidationの方が簡単に推論ができるようなデータセットになっているなど。

また、Trainingデータの一部がValidationデータに流出している場合などでも同じような問題が発生する。

データ分割の際には本質的にTrainingとValidationが同じデータ群から抽出され、データセット間でこのような差異がないようにする必要があるが、これが達成できていないケース。 このままでは過学習が起きたとしてもそれに気がつけないので、これは是正する必要がある。

まとめ

ここまで見てきたように、表題の現象はさまざまな要因によって発生しうる。 場合によっては複雑な事情が組み合わさり、TrainingLossとValidationLossがそれぞれ想定通りのいい感じに見えていたとしても、過学習が起きてしまっているようなことも起こりうる。

データの流出はないか?TrainingとValidationの計算タイミングのズレによる影響は?正則化によってTrainingLossが見かけより劣化してみえる事による影響は?など気をつけて評価していく必要がありそうだ。