• 教程 >
  • 擺錘:使用 TorchRL 編寫你的環境和變換
快捷方式

擺錘:使用 TorchRL 編寫你的環境和變換

創建於:2023 年 11 月 9 日 | 最後更新:2025 年 1 月 27 日 | 最後驗證:2024 年 11 月 5 日

作者Vincent Moens

建立環境(模擬器或物理控制系統的介面)是強化學習和控制工程的組成部分。

TorchRL 提供了一套工具,可在多種情境下實現這一點。本教程演示如何使用 PyTorch 和 TorchRL 從頭開始編寫一個擺錘模擬器。其靈感來自於 OpenAI-Gym/Farama-Gymnasium 控制庫中的 Pendulum-v1 實現。

Pendulum

簡單擺錘

關鍵學習內容

  • 如何在 TorchRL 中設計環境:- 編寫規範 (輸入、觀測和獎勵);- 實現行為:設定種子、重置和步進。

  • 變換你的環境輸入和輸出,以及編寫你自己的變換;

  • 如何使用 TensorDictcodebase 中傳遞任意資料結構。

    在此過程中,我們將接觸 TorchRL 的三個關鍵組成部分

為了讓大家瞭解 TorchRL 的環境能實現什麼,我們將設計一個無狀態環境。有狀態環境會記錄遇到的最新物理狀態並以此來模擬狀態間的轉換,而無狀態環境則期望在每個步驟中接收當前狀態以及採取的動作。TorchRL 支援這兩種型別的環境,但無狀態環境更具通用性,因此涵蓋了 TorchRL 環境 API 的更廣泛功能。

對無狀態環境進行建模使使用者能夠完全控制模擬器的輸入和輸出:可以在任何階段重置實驗或從外部主動修改動態。然而,這假定我們對任務有一定控制權,但這並非總是如此:解決一個我們無法控制當前狀態的問題更具挑戰性,但應用範圍也更廣。

無狀態環境的另一個優點是它們可以實現轉換模擬的批次執行。如果後端和實現允許,代數運算可以在標量、向量或張量上無縫執行。本教程提供了此類示例。

本教程結構如下

  • 我們將首先熟悉環境屬性:其形狀 (batch_size)、其方法(主要是 step()reset()set_seed()),最後是其規範。

  • 編寫好模擬器後,我們將演示如何在訓練中使用變換。

  • 我們將探索 TorchRL API 引出的新途徑,包括:變換輸入的可能性、模擬的向量化執行以及透過模擬圖進行反向傳播的可能性。

  • 最後,我們將訓練一個簡單的策略來解決我們實現的系統。

from collections import defaultdict
from typing import Optional

import numpy as np
import torch
import tqdm
from tensordict import TensorDict, TensorDictBase
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import (
    CatTensors,
    EnvBase,
    Transform,
    TransformedEnv,
    UnsqueezeTransform,
)
from torchrl.envs.transforms.transforms import _apply_to_composite
from torchrl.envs.utils import check_env_specs, step_mdp

DEFAULT_X = np.pi
DEFAULT_Y = 1.0

設計新的環境類時,必須注意四件事

  • EnvBase._reset(),用於在(可能隨機的)初始狀態重置模擬器;

  • EnvBase._step(),用於編寫狀態轉移動態;

  • EnvBase._set_seed`(),用於實現種子機制;

  • 環境規範。

首先讓我們描述一下當前的問題:我們想要模擬一個簡單的擺錘,我們可以控制施加在其固定點上的扭矩。我們的目標是將擺錘置於向上位置(根據約定,角位置為 0),並使其在該位置保持靜止。為了設計我們的動態系統,我們需要定義兩個方程:執行動作(施加扭矩)後的運動方程,以及將構成我們目標函式的獎勵方程。

對於運動方程,我們將根據以下公式更新角速度

\[\dot{\theta}_{t+1} = \dot{\theta}_t + (3 * g / (2 * L) * \sin(\theta_t) + 3 / (m * L^2) * u) * dt\]

其中 \(\dot{\theta}\) 是以 rad/sec 為單位的角速度,\(g\) 是重力加速度,\(L\) 是擺錘長度,\(m\) 是其質量,\(\theta\) 是其角位置,\(u\) 是扭矩。然後角位置根據以下公式更新

\[\theta_{t+1} = \theta_{t} + \dot{\theta}_{t+1} dt\]

我們將獎勵定義為

\[r = -(\theta^2 + 0.1 * \dot{\theta}^2 + 0.001 * u^2)\]

當角度接近 0(擺錘在向上位置)、角速度接近 0(無運動)且扭矩也為 0 時,該獎勵將被最大化。

編寫動作的效果:_step()

步進方法是首先要考慮的,因為它將編碼我們感興趣的模擬過程。在 TorchRL 中,EnvBase 類有一個 EnvBase.step() 方法,該方法接收一個 tensordict.TensorDict 例項,其中包含一個 "action" 條目,指示將要採取的動作。

為了方便從該 tensordict 中讀取和寫入,並確保鍵與庫期望的一致,模擬部分已委託給一個私有抽象方法 _step(),該方法從 tensordict 讀取輸入資料,並使用輸出資料寫入一個新的 tensordict

_step() 方法應執行以下操作

  1. 讀取輸入鍵(如 "action"),並基於這些鍵執行模擬;

  2. 獲取觀測、完成狀態和獎勵;

  3. 將觀測值集以及獎勵和完成狀態寫入新的 TensorDict 中相應的條目。

接下來,step() 方法將把 step() 的輸出合併到輸入 tensordict 中,以強制輸入/輸出一致性。

通常,對於有狀態環境,這看起來像這樣

>>> policy(env.reset())
>>> print(tensordict)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
>>> env.step(tensordict)
>>> print(tensordict)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

注意,根 tensordict 沒有改變,唯一的修改是出現了一個新的 "next" 條目,其中包含新資訊。

在擺錘示例中,我們的 _step() 方法將從輸入 tensordict 中讀取相關條目,並計算施加由 "action" 鍵編碼的力後襬錘的位置和速度。我們將擺錘的新角位置 "new_th" 計算為前一位置 "th" 加上新速度 "new_thdot" 在時間間隔 dt 內的結果。

由於我們的目標是將擺錘向上翻轉並保持在該位置靜止,因此對於接近目標位置和低速度的情況,我們的 cost (負獎勵) 函式值較低。實際上,我們希望避免遠離“向上”位置和/或遠離 0 的速度。

在我們的示例中,EnvBase._step() 被編碼為一個靜態方法,因為我們的環境是無狀態的。在有狀態設定中,需要 self 引數,因為狀態需要從環境中讀取。

def _step(tensordict):
    th, thdot = tensordict["th"], tensordict["thdot"]  # th := theta

    g_force = tensordict["params", "g"]
    mass = tensordict["params", "m"]
    length = tensordict["params", "l"]
    dt = tensordict["params", "dt"]
    u = tensordict["action"].squeeze(-1)
    u = u.clamp(-tensordict["params", "max_torque"], tensordict["params", "max_torque"])
    costs = angle_normalize(th) ** 2 + 0.1 * thdot**2 + 0.001 * (u**2)

    new_thdot = (
        thdot
        + (3 * g_force / (2 * length) * th.sin() + 3.0 / (mass * length**2) * u) * dt
    )
    new_thdot = new_thdot.clamp(
        -tensordict["params", "max_speed"], tensordict["params", "max_speed"]
    )
    new_th = th + new_thdot * dt
    reward = -costs.view(*tensordict.shape, 1)
    done = torch.zeros_like(reward, dtype=torch.bool)
    out = TensorDict(
        {
            "th": new_th,
            "thdot": new_thdot,
            "params": tensordict["params"],
            "reward": reward,
            "done": done,
        },
        tensordict.shape,
    )
    return out


def angle_normalize(x):
    return ((x + torch.pi) % (2 * torch.pi)) - torch.pi

重置模擬器:_reset()

我們需要關注的第二個方法是 _reset() 方法。與 _step() 一樣,它應將觀測條目和可能的完成狀態寫入其輸出的 tensordict 中(如果省略完成狀態,父方法 reset() 會將其填充為 False)。在某些情況下,_reset 方法需要接收來自呼叫它的函式的命令(例如,在多智慧體設定中,我們可能需要指示哪些智慧體需要重置)。這就是為什麼 _reset() 方法也期望接收一個 tensordict 作為輸入,儘管它可以是空的或 None

父方法 EnvBase.reset() 會像 EnvBase.step() 那樣執行一些簡單的檢查,例如確保輸出 tensordict 中返回了 "done" 狀態,並且形狀與規範中期望的匹配。

對我們來說,唯一需要考慮的重要事情是 EnvBase._reset() 是否包含所有預期的觀測。再說一次,由於我們處理的是無狀態環境,我們將擺錘的配置在一個名為 "params" 的巢狀 tensordict 中傳遞。

在此示例中,我們沒有傳遞完成狀態,因為這對 _reset() 不是強制的,並且我們的環境是非終止的,所以我們始終期望它為 False

def _reset(self, tensordict):
    if tensordict is None or tensordict.is_empty():
        # if no ``tensordict`` is passed, we generate a single set of hyperparameters
        # Otherwise, we assume that the input ``tensordict`` contains all the relevant
        # parameters to get started.
        tensordict = self.gen_params(batch_size=self.batch_size)

    high_th = torch.tensor(DEFAULT_X, device=self.device)
    high_thdot = torch.tensor(DEFAULT_Y, device=self.device)
    low_th = -high_th
    low_thdot = -high_thdot

    # for non batch-locked environments, the input ``tensordict`` shape dictates the number
    # of simulators run simultaneously. In other contexts, the initial
    # random state's shape will depend upon the environment batch-size instead.
    th = (
        torch.rand(tensordict.shape, generator=self.rng, device=self.device)
        * (high_th - low_th)
        + low_th
    )
    thdot = (
        torch.rand(tensordict.shape, generator=self.rng, device=self.device)
        * (high_thdot - low_thdot)
        + low_thdot
    )
    out = TensorDict(
        {
            "th": th,
            "thdot": thdot,
            "params": tensordict["params"],
        },
        batch_size=tensordict.shape,
    )
    return out

環境元資料:env.*_spec

規範定義了環境的輸入和輸出域。規範準確定義執行時將接收到的張量非常重要,因為它們經常用於在多程序和分散式設定中攜帶有關環境的資訊。它們還可以用於例項化延遲定義的神經網路和測試指令碼,而無需實際查詢環境(例如,對於真實世界的物理系統來說,這可能成本很高)。

在我們的環境中,必須編寫以下四種規範

  • EnvBase.observation_spec:這將是一個 CompositeSpec 例項,其中每個鍵都是一個觀測(CompositeSpec 可以看作是規範的字典)。

  • EnvBase.action_spec:可以是任何型別的規範,但要求它與輸入 tensordict 中的 "action" 條目相對應;

  • EnvBase.reward_spec:提供關於獎勵空間的資訊;

  • EnvBase.done_spec:提供關於完成標誌空間的資訊。

TorchRL 規範組織在兩個通用容器中:input_spec,包含步進函式讀取的資訊規範(分為包含動作的 action_spec 和包含其餘所有內容的 state_spec);以及 output_spec,編碼了步進輸出的規範(observation_specreward_specdone_spec)。一般來說,你不應該直接與 output_specinput_spec 互動,而只與其內容互動:observation_specreward_specdone_specaction_specstate_spec。原因是這些規範在 output_specinput_spec 內以非平凡的方式組織,並且它們都不應該被直接修改。

換句話說,observation_spec 及相關屬性是訪問輸出和輸入規範容器內容的便捷快捷方式。

TorchRL 提供了多種 TensorSpec 子類來編碼環境的輸入和輸出特徵。

規範形狀

環境規範的主要維度必須與環境的批次大小匹配。這樣做是為了強制環境的每個元件(包括其變換)都能準確表示預期的輸入和輸出形狀。這在有狀態設定中應該被準確編碼。

對於非批次鎖定的環境,例如我們示例中的環境(見下文),這是不相關的,因為環境的批次大小很可能是空的。

def _make_spec(self, td_params):
    # Under the hood, this will populate self.output_spec["observation"]
    self.observation_spec = CompositeSpec(
        th=BoundedTensorSpec(
            low=-torch.pi,
            high=torch.pi,
            shape=(),
            dtype=torch.float32,
        ),
        thdot=BoundedTensorSpec(
            low=-td_params["params", "max_speed"],
            high=td_params["params", "max_speed"],
            shape=(),
            dtype=torch.float32,
        ),
        # we need to add the ``params`` to the observation specs, as we want
        # to pass it at each step during a rollout
        params=make_composite_from_td(td_params["params"]),
        shape=(),
    )
    # since the environment is stateless, we expect the previous output as input.
    # For this, ``EnvBase`` expects some state_spec to be available
    self.state_spec = self.observation_spec.clone()
    # action-spec will be automatically wrapped in input_spec when
    # `self.action_spec = spec` will be called supported
    self.action_spec = BoundedTensorSpec(
        low=-td_params["params", "max_torque"],
        high=td_params["params", "max_torque"],
        shape=(1,),
        dtype=torch.float32,
    )
    self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1))


def make_composite_from_td(td):
    # custom function to convert a ``tensordict`` in a similar spec structure
    # of unbounded values.
    composite = CompositeSpec(
        {
            key: make_composite_from_td(tensor)
            if isinstance(tensor, TensorDictBase)
            else UnboundedContinuousTensorSpec(
                dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
            )
            for key, tensor in td.items()
        },
        shape=td.shape,
    )
    return composite

可復現實驗:設定種子

設定環境種子是初始化實驗時的常見操作。EnvBase._set_seed() 的唯一目標是設定所包含模擬器的種子。如果可能,此操作不應呼叫 reset() 或與環境執行互動。父方法 EnvBase.set_seed() 集成了一個機制,允許使用不同的偽隨機且可復現的種子為多個環境設定種子。

def _set_seed(self, seed: Optional[int]):
    rng = torch.manual_seed(seed)
    self.rng = rng

整合各部分:EnvBase

我們終於可以將各部分整合起來設計我們的環境類了。規範的初始化需要在環境構建過程中執行,因此我們必須確保在 PendulumEnv.__init__() 中呼叫 _make_spec() 方法。

我們新增一個靜態方法 PendulumEnv.gen_params(),它確定性地生成一組用於執行過程的超引數

def gen_params(g=10.0, batch_size=None) -> TensorDictBase:
    """Returns a ``tensordict`` containing the physical parameters such as gravitational force and torque or speed limits."""
    if batch_size is None:
        batch_size = []
    td = TensorDict(
        {
            "params": TensorDict(
                {
                    "max_speed": 8,
                    "max_torque": 2.0,
                    "dt": 0.05,
                    "g": g,
                    "m": 1.0,
                    "l": 1.0,
                },
                [],
            )
        },
        [],
    )
    if batch_size:
        td = td.expand(batch_size).contiguous()
    return td

我們將環境定義為非 batch_locked,方法是將 homonymous 屬性設定為 False。這意味著我們不會強制輸入 tensordictbatch-size 與環境的批次大小匹配。

以下程式碼將整合我們上面編寫的各部分。

class PendulumEnv(EnvBase):
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 30,
    }
    batch_locked = False

    def __init__(self, td_params=None, seed=None, device="cpu"):
        if td_params is None:
            td_params = self.gen_params()

        super().__init__(device=device, batch_size=[])
        self._make_spec(td_params)
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)

    # Helpers: _make_step and gen_params
    gen_params = staticmethod(gen_params)
    _make_spec = _make_spec

    # Mandatory methods: _step, _reset and _set_seed
    _reset = _reset
    _step = staticmethod(_step)
    _set_seed = _set_seed

測試我們的環境

TorchRL 提供了一個簡單的函式 check_env_specs(),用於檢查(變換後的)環境的輸入/輸出結構是否與其規範規定的結構匹配。讓我們試一下

env = PendulumEnv()
check_env_specs(env)

我們可以檢視我們的規範,以直觀瞭解環境的簽名

print("observation_spec:", env.observation_spec)
print("state_spec:", env.state_spec)
print("reward_spec:", env.reward_spec)

我們也可以執行一些命令來檢查輸出結構是否符合預期。

td = env.reset()
print("reset tensordict", td)

我們可以執行 env.rand_step()action_spec 域中隨機生成一個動作。由於我們的環境是無狀態的,必須傳遞一個包含超引數和當前狀態的 tensordict。在有狀態環境中,env.rand_step() 也能完美工作。

td = env.rand_step(td)
print("random step tensordict", td)

變換環境

為無狀態模擬器編寫環境變換比有狀態模擬器稍微複雜一些:對需要在下一次迭代中讀取的輸出條目進行變換,需要在下一步呼叫 meth.step() 之前應用逆變換。這是一個展示 TorchRL 變換所有功能的理想場景!

例如,在以下變換後的環境中,我們對條目 ["th", "thdot"] 進行 unsqueeze 操作,以便能夠沿著最後一個維度堆疊它們。我們還將它們作為 in_keys_inv 傳遞,以便在下一次迭代中作為輸入傳遞時,將它們擠壓回原始形狀。

env = TransformedEnv(
    env,
    # ``Unsqueeze`` the observations that we will concatenate
    UnsqueezeTransform(
        dim=-1,
        in_keys=["th", "thdot"],
        in_keys_inv=["th", "thdot"],
    ),
)

編寫自定義變換

TorchRL 的變換可能無法涵蓋環境執行後所有想要執行的操作。編寫變換不需要太多精力。與環境設計一樣,編寫變換也有兩個步驟

  • 正確實現動態(正向和逆向);

  • 調整環境規範。

變換可以在兩種設定中使用:獨立使用時,它可以作為 Module 使用。它也可以附加到 TransformedEnv 使用。類結構允許在不同上下文中自定義行為。

Transform 骨架可以概括如下

class Transform(nn.Module):
    def forward(self, tensordict):
        ...
    def _apply_transform(self, tensordict):
        ...
    def _step(self, tensordict):
        ...
    def _call(self, tensordict):
        ...
    def inv(self, tensordict):
        ...
    def _inv_apply_transform(self, tensordict):
        ...

有三個入口點(forward()_step()inv()),它們都接收 tensordict.TensorDict 例項。前兩個最終將遍歷由 in_keys 指示的鍵,並對每個鍵呼叫 _apply_transform()。結果將寫入由 Transform.out_keys 指示的條目中(如果提供了該屬性;如果未提供,in_keys 將用變換後的值更新)。如果需要執行逆變換,將執行類似的資料流,但使用 Transform.inv()Transform._inv_apply_transform() 方法,並作用於 in_keys_invout_keys_inv 鍵列表。下圖總結了環境和回放緩衝區的資料流。

變換 API

在某些情況下,變換不會以單元方式處理鍵的子集,而是會在父環境上執行某些操作或處理整個輸入 tensordict。在這些情況下,應重寫 _call()forward() 方法,並可以跳過 _apply_transform() 方法。

讓我們編寫新的變換,計算位置角的 sinecosine 值,因為這些值對我們學習策略比原始角度值更有用

class SinTransform(Transform):
    def _apply_transform(self, obs: torch.Tensor) -> None:
        return obs.sin()

    # The transform must also modify the data at reset time
    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        return self._call(tensordict_reset)

    # _apply_to_composite will execute the observation spec transform across all
    # in_keys/out_keys pairs and write the result in the observation_spec which
    # is of type ``Composite``
    @_apply_to_composite
    def transform_observation_spec(self, observation_spec):
        return BoundedTensorSpec(
            low=-1,
            high=1,
            shape=observation_spec.shape,
            dtype=observation_spec.dtype,
            device=observation_spec.device,
        )


class CosTransform(Transform):
    def _apply_transform(self, obs: torch.Tensor) -> None:
        return obs.cos()

    # The transform must also modify the data at reset time
    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        return self._call(tensordict_reset)

    # _apply_to_composite will execute the observation spec transform across all
    # in_keys/out_keys pairs and write the result in the observation_spec which
    # is of type ``Composite``
    @_apply_to_composite
    def transform_observation_spec(self, observation_spec):
        return BoundedTensorSpec(
            low=-1,
            high=1,
            shape=observation_spec.shape,
            dtype=observation_spec.dtype,
            device=observation_spec.device,
        )


t_sin = SinTransform(in_keys=["th"], out_keys=["sin"])
t_cos = CosTransform(in_keys=["th"], out_keys=["cos"])
env.append_transform(t_sin)
env.append_transform(t_cos)

將觀測值連線到“observation”條目上。del_keys=False 確保我們將這些值保留到下一次迭代。

cat_transform = CatTensors(
    in_keys=["sin", "cos", "thdot"], dim=-1, out_key="observation", del_keys=False
)
env.append_transform(cat_transform)

我們再次檢查環境規格是否與接收到的相匹配

check_env_specs(env)

執行 Rollout

執行 Rollout 是一個簡單的步驟序列

  • 重置環境

  • 當某個條件未滿足時

    • 根據策略計算動作

    • 根據此動作執行一步

    • 收集資料

    • 執行一步 MDP

  • 收集資料並返回

這些操作已方便地封裝在 rollout() 方法中,我們在下方提供了其簡化版本。

def simple_rollout(steps=100):
    # preallocate:
    data = TensorDict({}, [steps])
    # reset
    _data = env.reset()
    for i in range(steps):
        _data["action"] = env.action_spec.rand()
        _data = env.step(_data)
        data[i] = _data
        _data = step_mdp(_data, keep_other=True)
    return data


print("data from rollout:", simple_rollout(100))

批處理計算

本教程最後一個尚未探索的部分是我們在 TorchRL 中進行批處理計算的能力。由於我們的環境對輸入資料形狀沒有任何假設,我們可以無縫地在資料批次上執行它。更妙的是:對於像我們的 Pendulum 這樣非批次鎖定(non-batch-locked)的環境,我們無需重新建立環境就可以動態更改批次大小。為此,我們只需生成具有所需形狀的引數即可。

batch_size = 10  # number of environments to be executed in batch
td = env.reset(env.gen_params(batch_size=[batch_size]))
print("reset (batch size of 10)", td)
td = env.rand_step(td)
print("rand step (batch size of 10)", td)

使用資料批次執行 rollout 需要我們在 rollout 函式之外重置環境,因為我們需要動態定義 `batch_size`,而 rollout() 不支援此功能。

rollout = env.rollout(
    3,
    auto_reset=False,  # we're executing the reset out of the ``rollout`` call
    tensordict=env.reset(env.gen_params(batch_size=[batch_size])),
)
print("rollout of len 3 (batch size of 10):", rollout)

訓練一個簡單策略

在此示例中,我們將使用獎勵作為可微分目標(例如負損失)來訓練一個簡單策略。我們將利用我們的動態系統完全可微分的事實,透過軌跡回報進行反向傳播,並調整我們策略的權重以直接最大化此值。當然,在許多設定中,我們所做的許多假設(例如可微分系統和對底層機制的完全訪問)並不成立。

儘管如此,這是一個非常簡單的示例,展示瞭如何在 TorchRL 中使用自定義環境編寫訓練迴圈。

我們先編寫策略網路

torch.manual_seed(0)
env.set_seed(0)

net = nn.Sequential(
    nn.LazyLinear(64),
    nn.Tanh(),
    nn.LazyLinear(64),
    nn.Tanh(),
    nn.LazyLinear(64),
    nn.Tanh(),
    nn.LazyLinear(1),
)
policy = TensorDictModule(
    net,
    in_keys=["observation"],
    out_keys=["action"],
)

以及我們的最佳化器

optim = torch.optim.Adam(policy.parameters(), lr=2e-3)

訓練迴圈

我們將依次執行

  • 生成軌跡

  • 對獎勵求和

  • 透過這些操作定義的圖進行反向傳播

  • 裁剪梯度範數並執行最佳化步驟

  • 重複

在訓練迴圈結束時,我們應該得到一個接近 0 的最終獎勵,這表明擺錘向上並保持靜止,達到了預期目標。

batch_size = 32
pbar = tqdm.tqdm(range(20_000 // batch_size))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 20_000)
logs = defaultdict(list)

for _ in pbar:
    init_td = env.reset(env.gen_params(batch_size=[batch_size]))
    rollout = env.rollout(100, policy, tensordict=init_td, auto_reset=False)
    traj_return = rollout["next", "reward"].mean()
    (-traj_return).backward()
    gn = torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
    optim.step()
    optim.zero_grad()
    pbar.set_description(
        f"reward: {traj_return: 4.4f}, "
        f"last reward: {rollout[..., -1]['next', 'reward'].mean(): 4.4f}, gradient norm: {gn: 4.4}"
    )
    logs["return"].append(traj_return.item())
    logs["last_reward"].append(rollout[..., -1]["next", "reward"].mean().item())
    scheduler.step()


def plot():
    import matplotlib
    from matplotlib import pyplot as plt

    is_ipython = "inline" in matplotlib.get_backend()
    if is_ipython:
        from IPython import display

    with plt.ion():
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.plot(logs["return"])
        plt.title("returns")
        plt.xlabel("iteration")
        plt.subplot(1, 2, 2)
        plt.plot(logs["last_reward"])
        plt.title("last reward")
        plt.xlabel("iteration")
        if is_ipython:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        plt.show()


plot()

結論

在本教程中,我們學習瞭如何從零開始編寫一個無狀態環境。我們涉及了以下主題:

  • 編碼環境時需要關注的四個基本元件(stepreset、seeding 和構建 specs)。我們瞭解了這些方法和類如何與 TensorDict 類互動;

  • 如何使用 check_env_specs() 檢查環境是否正確編碼;

  • 如何在無狀態環境的上下文中新增 transforms 以及如何編寫自定義 transformations;

  • 如何在完全可微分的模擬器上訓練策略。

指令碼總執行時間: ( 0 minutes 0.000 seconds)

由 Sphinx-Gallery 生成的畫廊

文件

訪問 PyTorch 的綜合開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源