Article #894

既に発行済みのブログであっても適宜修正・追加することがあります。
We may make changes and additions to blogs already published.

Pongと強化学習 (73)

posted by sakurai on November 14, 2024 #894

atari_dqnのレポジトリにおいて、Q-networkwork構造を記述する以下のpythonコードピースを見つけたので、同様な形式で解説して貰います。

class QNetwork(nn.Module):
    def __init__(self, env):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, env.single_action_space.n),

このQNetworkクラスは、PyTorchを使用して定義された畳み込みニューラルネットワーク(CNN)で、特にAtariゲームのようなビジュアルベースの環境で強化学習エージェントのQ値を推定するために設計されています。以下に、各層の役割について説明します。

入力層

このネットワークは、4チャネルの84x84ピクセルの画像を入力として受け取ります。4つのチャネルは、現在のフレームと過去3フレームを表し、動きや時間的な変化を捉えることができます。

第1の畳み込み層

  • : nn.Conv2d(4, 32, 8, stride=4)
  • 機能: この層は、8x8のカーネル(フィルター)を32枚使用し、ストライド(ステップサイズ)は4です。4つの入力チャネル(前の4フレーム)に対して畳み込みを行い、特徴マップを32枚生成します。この畳み込みにより、画像の初期特徴が抽出されます。
  • 活性化関数: ReLU。非線形性を導入し、モデルの表現力を高めます。

第2の畳み込み層

  • : nn.Conv2d(32, 64, 4, stride=2)
  • 機能: 第1層で抽出された特徴マップをさらに精密化します。4x4のカーネルを64枚使用し、ストライドは2です。これにより、より複雑な特徴が抽出されます。
  • 活性化関数: ReLU。

第3の畳み込み層

  • : nn.Conv2d(64, 64, 3, stride=1)
  • 機能: 更に精緻な特徴を抽出するため、3x3のカーネルを64枚使用し、ストライドは1です。この層は、より細かな特徴を捉えることができます。
  • 活性化関数: ReLU。

全結合層

  • : nn.Linear(3136, 512)
  • 機能: 畳み込み層からの出力を平坦化し、3136個のユニットからなるベクトルを入力として受け取ります。この層は512個の隠れユニットを持ち、畳み込み層で抽出された特徴を基に高レベルの特徴を学習します。
  • 活性化関数: ReLU。

出力層

  • : nn.Linear(512, env.single_action_space.n)
  • 機能: 512個の隠れユニットから、行動空間のサイズに対応する数の出力(各行動に対するQ値)を生成します。env.single_action_space.nは、エージェントが選択可能な行動の総数を示します。この層は、各行動の期待される報酬の推定値を出力します。

このネットワークは、入力画像から直接、特定の行動に対するQ値を推定する能力を提供し、エージェントが最適な行動を選択するための情報を生成します。


左矢前のブログ 次のブログ右矢

Leave a Comment

Your email address will not be published.

You may use Markdown syntax. If you include an ad such as http://, it will be invalidated by our AI system.

Please enter the numbers as they are shown in the image above.