14 |
Pongと強化学習 (73) |
atari_dqnのレポジトリにおいて、Q-networkwork構造を記述する以下のpythonコードピースを見つけたので、同様な形式で解説して貰います。
class QNetwork(nn.Module):
def __init__(self, env):
super().__init__()
self.network = nn.Sequential(
nn.Conv2d(4, 32, 8, stride=4),
nn.ReLU(),
nn.Conv2d(32, 64, 4, stride=2),
nn.ReLU(),
nn.Conv2d(64, 64, 3, stride=1),
nn.ReLU(),
nn.Flatten(),
nn.Linear(3136, 512),
nn.ReLU(),
nn.Linear(512, env.single_action_space.n),
このQNetwork
クラスは、PyTorchを使用して定義された畳み込みニューラルネットワーク(CNN)で、特にAtariゲームのようなビジュアルベースの環境で強化学習エージェントのQ値を推定するために設計されています。以下に、各層の役割について説明します。
入力層
このネットワークは、4チャネルの84x84ピクセルの画像を入力として受け取ります。4つのチャネルは、現在のフレームと過去3フレームを表し、動きや時間的な変化を捉えることができます。
第1の畳み込み層
- 層:
nn.Conv2d(4, 32, 8, stride=4)
- 機能: この層は、8x8のカーネル(フィルター)を32枚使用し、ストライド(ステップサイズ)は4です。4つの入力チャネル(前の4フレーム)に対して畳み込みを行い、特徴マップを32枚生成します。この畳み込みにより、画像の初期特徴が抽出されます。
- 活性化関数: ReLU。非線形性を導入し、モデルの表現力を高めます。
第2の畳み込み層
- 層:
nn.Conv2d(32, 64, 4, stride=2)
- 機能: 第1層で抽出された特徴マップをさらに精密化します。4x4のカーネルを64枚使用し、ストライドは2です。これにより、より複雑な特徴が抽出されます。
- 活性化関数: ReLU。
第3の畳み込み層
- 層:
nn.Conv2d(64, 64, 3, stride=1)
- 機能: 更に精緻な特徴を抽出するため、3x3のカーネルを64枚使用し、ストライドは1です。この層は、より細かな特徴を捉えることができます。
- 活性化関数: ReLU。
全結合層
- 層:
nn.Linear(3136, 512)
- 機能: 畳み込み層からの出力を平坦化し、3136個のユニットからなるベクトルを入力として受け取ります。この層は512個の隠れユニットを持ち、畳み込み層で抽出された特徴を基に高レベルの特徴を学習します。
- 活性化関数: ReLU。
出力層
- 層:
nn.Linear(512, env.single_action_space.n)
- 機能: 512個の隠れユニットから、行動空間のサイズに対応する数の出力(各行動に対するQ値)を生成します。
env.single_action_space.n
は、エージェントが選択可能な行動の総数を示します。この層は、各行動の期待される報酬の推定値を出力します。
このネットワークは、入力画像から直接、特定の行動に対するQ値を推定する能力を提供し、エージェントが最適な行動を選択するための情報を生成します。