• 文件 >
  • 從你的第一個訓練迴圈開始
快捷方式

從你的第一個訓練迴圈開始

作者Vincent Moens

注意

要在 Notebook 中執行此教程,請在開頭新增一個包含以下內容的安裝單元格

!pip install tensordict
!pip install torchrl

是時候總結一下我們在本入門系列中學到的所有知識了!

在本教程中,我們將使用前面課程中介紹過的元件,編寫最基本的訓練迴圈。

我們將使用帶有 CartPole 環境的 DQN 作為原型示例。

我們將故意將細節保持在最低限度,只將每個部分連結到相關的教程。

構建環境

我們將使用一個帶有 StepCounter 轉換的 gym 環境。如果需要回顧,請檢視這些功能在環境教程中的介紹。

import torch

torch.manual_seed(0)

import time

from torchrl.envs import GymEnv, StepCounter, TransformedEnv

env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
env.set_seed(0)

from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq

設計策略

下一步是構建我們的策略。我們將製作一個常規的、確定性版本的 Actor,用於損失模組內部和評估期間。接下來,我們將為其新增一個探索模組用於推理

from torchrl.modules import EGreedyModule, MLP, QValueModule

value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64])
value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
    env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
policy_explore = Seq(policy, exploration_module)

資料收集器和回放緩衝區

接下來是資料部分:我們需要一個資料收集器來輕鬆獲取資料批次,還需要一個回放緩衝區來儲存這些資料用於訓練。

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer

init_rand_steps = 5000
frames_per_batch = 100
optim_steps = 10
collector = SyncDataCollector(
    env,
    policy_explore,
    frames_per_batch=frames_per_batch,
    total_frames=-1,
    init_random_frames=init_rand_steps,
)
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))

from torch.optim import Adam

損失模組和最佳化器

我們按照專用教程中的說明構建損失函式,以及其最佳化器和目標引數更新器

from torchrl.objectives import DQNLoss, SoftUpdate

loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters(), lr=0.02)
updater = SoftUpdate(loss, eps=0.99)

日誌記錄器

我們將使用 CSV 日誌記錄器來記錄結果並儲存渲染的影片。

from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder

path = "./training_loop"
logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4")
video_recorder = VideoRecorder(logger, tag="video")
record_env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder
)

訓練迴圈

我們將不固定執行的迭代次數,而是持續訓練網路,直到它達到一定的效能(任意定義為在環境中達到 200 步 - 對於 CartPole,成功定義為具有更長的軌跡)。

total_count = 0
total_episodes = 0
t0 = time.time()
for i, data in enumerate(collector):
    # Write data in replay buffer
    rb.extend(data)
    max_length = rb[:]["next", "step_count"].max()
    if len(rb) > init_rand_steps:
        # Optim loop (we do several optim steps
        # per batch collected for efficiency)
        for _ in range(optim_steps):
            sample = rb.sample(128)
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            # Update exploration factor
            exploration_module.step(data.numel())
            # Update target params
            updater.step()
            if i % 10:
                torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}")
            total_count += data.numel()
            total_episodes += data["next", "done"].sum()
    if max_length > 200:
        break

t1 = time.time()

torchrl_logger.info(
    f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s."
)

渲染

最後,我們讓環境執行儘可能多的步數,並在本地儲存影片(注意我們此時沒有進行探索)。

record_env.rollout(max_steps=1000, policy=policy)
video_recorder.dump()

完整的訓練迴圈結束後,你渲染的 CartPole 影片將看起來像這樣

../_images/cartpole.gif

至此,我們的“TorchRL 入門”系列教程就結束了!歡迎在 GitHub 上分享你的反饋。

由 Sphinx-Gallery 生成的圖集

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源