8 |
Pongと強化学習 (69) |
DQN損失関数を深掘りしてみます。
DQN損失関数において2乗を取る理由は、確かに予測誤差の正負を問わず、その大きさ(絶対値)を考慮するためです。この損失関数は、TD誤差(Temporal Difference error)の二乗として定義され、エージェントの行動価値関数の予測が実際の報酬からどれだけ離れているかを量る指標となります。具体的には、以下の形で表されます:
$$L(\theta) = \mathbb{E}\left[\left(r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta)\right)^2\right]$$
2乗した数値の扱い方
損失の平均化:一般的に、学習データセット全体またはバッチ内のサンプルに対する損失の平均(または総和)を取ることで、モデルのパラメータを更新する際の目的関数とします。期待値$\mathbb{E}[\cdot]$は、この平均化された損失を表しています。
勾配降下法:損失関数の平均化された値に基づいて、勾配降下法(またはそのバリエーション)を用いてネットワークのパラメータ$\theta$を更新します。具体的には、損失関数のパラメータに対する勾配(偏微分)を計算し、その勾配の方向にパラメータを少しずつ動かすことで、損失を最小化するパラメータの値を探します。
安定性と収束の向上:2乗損失は、特に大きな誤差に対してペナルティを大きくするため、予測を正確に行うようにモデルを強く促します。また、二乗損失関数は微分可能であり、勾配降下法による最適化が比較的容易であるため、学習プロセスの安定性と収束の向上に寄与します。
注意点
ただし、二乗損失関数は外れ値に対して敏感であるため、非常に大きな誤差を持つサンプルが存在する場合には、モデルの学習がそのようなサンプルに過剰に影響を受ける可能性があります。この点を考慮し、場合によっては損失関数として絶対値損失(L1損失)を使用することもありますが、DQNの文脈では二乗損失が一般的に使用されています。