注意
請前往末尾下載完整示例程式碼。
從你的第一個訓練迴圈開始¶
注意
要在 Notebook 中執行此教程,請在開頭新增一個包含以下內容的安裝單元格
!pip install tensordict !pip install torchrl
是時候總結一下我們在本入門系列中學到的所有知識了!
在本教程中,我們將使用前面課程中介紹過的元件,編寫最基本的訓練迴圈。
我們將使用帶有 CartPole 環境的 DQN 作為原型示例。
我們將故意將細節保持在最低限度,只將每個部分連結到相關的教程。
構建環境¶
我們將使用一個帶有 StepCounter 轉換的 gym 環境。如果需要回顧,請檢視這些功能在環境教程中的介紹。
import torch
torch.manual_seed(0)
import time
from torchrl.envs import GymEnv, StepCounter, TransformedEnv
env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
env.set_seed(0)
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
設計策略¶
下一步是構建我們的策略。我們將製作一個常規的、確定性版本的 Actor,用於損失模組內部和評估期間。接下來,我們將為其新增一個探索模組用於推理。
from torchrl.modules import EGreedyModule, MLP, QValueModule
value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64])
value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
policy_explore = Seq(policy, exploration_module)
資料收集器和回放緩衝區¶
接下來是資料部分:我們需要一個資料收集器來輕鬆獲取資料批次,還需要一個回放緩衝區來儲存這些資料用於訓練。
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
init_rand_steps = 5000
frames_per_batch = 100
optim_steps = 10
collector = SyncDataCollector(
env,
policy_explore,
frames_per_batch=frames_per_batch,
total_frames=-1,
init_random_frames=init_rand_steps,
)
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))
from torch.optim import Adam
損失模組和最佳化器¶
我們按照專用教程中的說明構建損失函式,以及其最佳化器和目標引數更新器
from torchrl.objectives import DQNLoss, SoftUpdate
loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters(), lr=0.02)
updater = SoftUpdate(loss, eps=0.99)
日誌記錄器¶
我們將使用 CSV 日誌記錄器來記錄結果並儲存渲染的影片。
from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder
path = "./training_loop"
logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4")
video_recorder = VideoRecorder(logger, tag="video")
record_env = TransformedEnv(
GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder
)
訓練迴圈¶
我們將不固定執行的迭代次數,而是持續訓練網路,直到它達到一定的效能(任意定義為在環境中達到 200 步 - 對於 CartPole,成功定義為具有更長的軌跡)。
total_count = 0
total_episodes = 0
t0 = time.time()
for i, data in enumerate(collector):
# Write data in replay buffer
rb.extend(data)
max_length = rb[:]["next", "step_count"].max()
if len(rb) > init_rand_steps:
# Optim loop (we do several optim steps
# per batch collected for efficiency)
for _ in range(optim_steps):
sample = rb.sample(128)
loss_vals = loss(sample)
loss_vals["loss"].backward()
optim.step()
optim.zero_grad()
# Update exploration factor
exploration_module.step(data.numel())
# Update target params
updater.step()
if i % 10:
torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}")
total_count += data.numel()
total_episodes += data["next", "done"].sum()
if max_length > 200:
break
t1 = time.time()
torchrl_logger.info(
f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s."
)
渲染¶
最後,我們讓環境執行儘可能多的步數,並在本地儲存影片(注意我們此時沒有進行探索)。
record_env.rollout(max_steps=1000, policy=policy)
video_recorder.dump()
完整的訓練迴圈結束後,你渲染的 CartPole 影片將看起來像這樣
至此,我們的“TorchRL 入門”系列教程就結束了!歡迎在 GitHub 上分享你的反饋。