Posts Issued on July 25, 2024

Pongと強化学習 (46)

posted by sakurai on July 25, 2024 #841

ターゲット$Q$ネットワークを深掘りしてみます。


ターゲットQネットワークは、Deep Q-Network (DQN) アルゴリズムにおいて、学習の安定性を向上させるために導入される重要な概念です。DQNでは、行動価値関数(Q関数)を近似するために深層ニューラルネットワークが使用されますが、この学習プロセスは、特に高次元の状態空間を扱う場合に不安定になりがちです。ターゲットQネットワークは、この問題に対処するために設計されています。

ターゲットQネットワークの役割

ターゲットQネットワークは、主Qネットワークのコピーであり、学習中にQ値の更新ターゲットとして使用されます。主Qネットワークは学習の各ステップで更新されるのに対し、ターゲットQネットワークは定期的にのみ更新されます(例えば、数百ステップごと)。この遅延更新により、学習ターゲットが一定期間固定され、学習プロセスの安定性が向上します。

ターゲットQネットワークの実装の概要

  1. ネットワークの初期化: 学習を開始する際、主QネットワークとターゲットQネットワークの両方を同じ重みで初期化します。
  2. 経験リプレイ: エージェントの経験(状態、行動、報酬、次の状態)をリプレイメモリに保存します。
  3. 学習バッチの抽出: リプレイメモリからランダムにバッチを抽出します。
  4. Q値の更新: 抽出したバッチに基づき、主Qネットワークを使用して現在のQ値を計算し、ターゲットQネットワークを使用してターゲットQ値を計算します。次に、これらのQ値の差(TD誤差)に基づいて、主Qネットワークの重みを更新します。
  5. ターゲットQネットワークの更新: 定期的に主Qネットワークの重みをターゲットQネットワークにコピーします。

コードスニペット(疑似コード)

# 主QネットワークとターゲットQネットワークの初期化
Q_network = initialize_network()
target_Q_network = copy_network(Q_network)

# 学習ループ
for step in range(num_steps):
    # 経験リプレイからバッチを抽出
    batch = replay_memory.sample(batch_size)
    
    # ターゲットQ値の計算
    target_Q_values = calculate_target_Q_values(batch, target_Q_network)
    
    # 主Qネットワークの更新
    update_Q_network(Q_network, batch, target_Q_values)
    
    # 定期的にターゲットQネットワークを更新
    if step % update_target_network_steps == 0:
        target_Q_network = copy_network(Q_network)

ターゲットQネットワークを使用することで、学習中の目標値が頻繁に変動することを防ぎ、DQNの学習プロセスをより安定させることができます。この安定化は、特に複雑な環境やタスクにおいて、学習の収束を助ける重要な要素となります。


左矢前のブログ 次のブログ右矢