Introduction to TorchRL

This demo was presented at ICML 2022 on the industry demo day.

It gives a good overview of TorchRL's functionalities. Feel free to reach out to vmoens@fb.com or submit an issue if you have questions or comments about it.

TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.

https://github.com/pytorch/rl

The PyTorch ecosystem team (Meta) has decided to invest in the library to provide a leading platform for developing RL solutions in research settings.

It provides pytorch- and python-first, low- and high-level abstractions for RL that are intended to be efficient, documented and properly tested. The code is aimed at supporting research in RL. Most of it is written in Python in a highly modular way, so that researchers can easily swap components, transform them, or write new ones with little effort.

The repo attempts to align with the existing pytorch ecosystem libraries in that it has a dataset pillar (torchrl/envs), transforms, models, and data utilities (e.g. collectors and containers). TorchRL aims at having as few dependencies as possible (Python standard library, numpy and pytorch). Common environment libraries (e.g. OpenAI gym) are only optional dependencies.

Contents:

[Diagram: overview of the TorchRL components]

Unlike many other domains, RL is less about media than algorithms. As such, it is harder to make truly independent components.

What TorchRL is not

  • A collection of algorithms: we do not intend to provide SOTA (state-of-the-art) implementations of RL algorithms; we include them only as examples of how to use the library.

  • A research framework: modularity in TorchRL comes in two flavours. First, we try to build reusable components so that they can easily be swapped with one another. Second, we do our best to make components usable independently of the rest of the library.

TorchRL has very few core dependencies, mainly PyTorch and numpy. All other dependencies (gym, torchvision, wandb / tensorboard) are optional.

Data

TensorDict

import torch
from tensordict import TensorDict

Let's create a TensorDict. The constructor accepts many different formats, such as passing a dict or using keyword arguments:

batch_size = 5
data = TensorDict(
    key1=torch.zeros(batch_size, 3),
    key2=torch.zeros(batch_size, 5, 6, dtype=torch.bool),
    batch_size=[batch_size],
)
print(data)

You can index a TensorDict along its batch_size, as well as query its keys.

print(data[2])
print(data["key1"] is data.get("key1"))

Here is how to stack multiple TensorDicts. This is particularly useful when writing rollout loops!

data1 = TensorDict(
    {
        "key1": torch.zeros(batch_size, 1),
        "key2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
    },
    batch_size=[batch_size],
)

data2 = TensorDict(
    {
        "key1": torch.ones(batch_size, 1),
        "key2": torch.ones(batch_size, 5, 6, dtype=torch.bool),
    },
    batch_size=[batch_size],
)

data = torch.stack([data1, data2], 0)
data.batch_size, data["key1"]

Here are some other functionalities of TensorDict: view, permute, shared memory, or expand.

print(
    "view(-1): ",
    data.view(-1).batch_size,
    data.view(-1).get("key1").shape,
)

print("to device: ", data.to("cpu"))

# print("pin_memory: ", data.pin_memory())

print("share memory: ", data.share_memory_())

print(
    "permute(1, 0): ",
    data.permute(1, 0).batch_size,
    data.permute(1, 0).get("key1").shape,
)

print(
    "expand: ",
    data.expand(3, *data.batch_size).batch_size,
    data.expand(3, *data.batch_size).get("key1").shape,
)

You can also create nested data:

data = TensorDict(
    source={
        "key1": torch.zeros(batch_size, 3),
        "key2": TensorDict(
            source={"sub_key1": torch.zeros(batch_size, 2, 1)},
            batch_size=[batch_size, 2],
        ),
    },
    batch_size=[batch_size],
)
data
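
Nested values can be retrieved (or set) with tuple keys. A small illustrative addition, using the nested TensorDict built above:

print(data["key2", "sub_key1"].shape)        # tuple-key indexing into the nested TensorDict
print(data.get(("key2", "sub_key1")).shape)  # equivalent with get()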

Replay Buffers

Replay buffers are a crucial component of many RL algorithms. TorchRL provides a range of replay buffer implementations. Most basic features will work with any data structure (list, tuples, dict), but to use the replay buffers to their full extent and with fast read/write access, the TensorDict API should be preferred.

from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer

rb = ReplayBuffer(collate_fn=lambda x: x)

Adding can be done with add() (n=1) or extend() (n>1).

rb.add(1)
rb.sample(1)
rb.extend([2, 3])
rb.sample(3)

Prioritized replay buffers can also be used:

rb = PrioritizedReplayBuffer(alpha=0.7, beta=1.1, collate_fn=lambda x: x)
rb.add(1)
rb.sample(1)
rb.update_priority(1, 0.5)

Here is an example of using a replay buffer with a data_stack. Using these makes it easy to abstract away how the replay buffer behaves across many use cases.

collate_fn = torch.stack
rb = ReplayBuffer(collate_fn=collate_fn)
rb.add(TensorDict({"a": torch.randn(3)}, batch_size=[]))
len(rb)

rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
print(len(rb))
print(rb.sample(10))
print(rb.sample(2).contiguous())

torch.manual_seed(0)
from torchrl.data import TensorDictPrioritizedReplayBuffer

rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, priority_key="td_error")
rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
data_sample = rb.sample(2).contiguous()
print(data_sample)

print(data_sample["index"])

data_sample["td_error"] = torch.rand(2)
rb.update_tensordict_priority(data_sample)

for i, val in enumerate(rb._sampler._sum_tree):
    print(i, val)
    if i == len(rb):
        break

Environments

TorchRL offers a range of environment wrappers and utilities.

Gym Environments

try:
    import gymnasium as gym
except ModuleNotFoundError:
    import gym

from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend

gym_env = gym.make("Pendulum-v1")
env = GymWrapper(gym_env)
env = GymEnv("Pendulum-v1")

data = env.reset()
env.rand_step(data)

Changing environment configurations

env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env.reset()

env.close()
del env

from torchrl.envs import (
    Compose,
    NoopResetEnv,
    ObservationNorm,
    ToTensorImage,
    TransformedEnv,
)

base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))

Environment Transforms

Transforms act like Gym wrappers, but with an API closer to the transforms of torchvision and torch.distributions. There is a wide range of transforms to choose from.

from torchrl.envs import (
    Compose,
    NoopResetEnv,
    ObservationNorm,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)

base_env = GymEnv("HalfCheetah-v4", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env = env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))

env.reset()

print("env: ", env)
print("last transform parent: ", env.transform[2].parent)

Vectorized Environments

Vectorized / parallel environments can provide significant speed-ups.

from torchrl.envs import ParallelEnv


def make_env():
    # You can control whether to use gym or gymnasium for your env
    with set_gym_backend("gym"):
        return GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)


base_env = ParallelEnv(
    4,
    make_env,
    mp_start_method="fork",  # This will break on Windows machines! Remove and decorate with if __name__ == "__main__"
)
env = TransformedEnv(
    base_env, Compose(StepCounter(), ToTensorImage())
)  # applies transforms on batch of envs
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
env.reset()

print(env.action_spec)
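
Specs can also generate random values consistent with their bounds; for instance, a random action can be drawn directly from the action spec. A small illustrative addition (not in the original demo):

print(env.action_spec.rand())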

env.close()
del env

Modules

Multiple modules (utils, models and wrappers) are available in the library.

Models

Example of an MLP model:

from torch import nn
from torchrl.modules import ConvNet, MLP
from torchrl.modules.models.utils import SquashDims

net = MLP(num_cells=[32, 64], out_features=4, activation_class=nn.ELU)
print(net)
print(net(torch.randn(10, 3)).shape)

Example of a CNN model:

cnn = ConvNet(
    num_cells=[32, 64],
    kernel_sizes=[8, 4],
    strides=[2, 1],
    aggregator_class=SquashDims,
)
print(cnn)
print(cnn(torch.randn(10, 3, 32, 32)).shape)  # last tensor is squashed

TensorDict Modules

Some modules are specifically designed to work with TensorDict inputs.

from tensordict.nn import TensorDictModule

data = TensorDict({"key1": torch.randn(10, 3)}, batch_size=[10])
module = nn.Linear(3, 4)
td_module = TensorDictModule(module, in_keys=["key1"], out_keys=["key2"])
td_module(data)
print(data)

Sequences of Modules

TensorDictSequential makes it easy to build sequences of modules.

from tensordict.nn import TensorDictSequential

backbone_module = nn.Linear(5, 3)
backbone = TensorDictModule(
    backbone_module, in_keys=["observation"], out_keys=["hidden"]
)
actor_module = nn.Linear(3, 4)
actor = TensorDictModule(actor_module, in_keys=["hidden"], out_keys=["action"])
value_module = MLP(out_features=1, num_cells=[4, 5])
value = TensorDictModule(value_module, in_keys=["hidden", "action"], out_keys=["value"])

sequence = TensorDictSequential(backbone, actor, value)
print(sequence)

print(sequence.in_keys, sequence.out_keys)

data = TensorDict(
    {"observation": torch.randn(3, 5)},
    [3],
)
backbone(data)
actor(data)
value(data)

data = TensorDict(
    {"observation": torch.randn(3, 5)},
    [3],
)
sequence(data)
print(data)
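
A subsequence that computes only a subset of the outputs can also be extracted. A brief sketch, assuming select_subsequence() is available in your tensordict version:

sub_sequence = sequence.select_subsequence(out_keys=["hidden"])
print(sub_sequence)  # only the modules required to produce "hidden" are kept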

Functional Programming (Ensembling / Meta-RL)

Functional calls have never been easier. Extract the parameters with from_module(), and replace them with to_module().

from tensordict import from_module

params = from_module(sequence)
print("extracted params", params)

Functional call using a TensorDict:

with params.to_module(sequence):
    data = sequence(data)

VMAP

Quickly executing multiple copies of similar architectures is key to training models fast. vmap() is tailored for exactly that.

from torch import vmap

params_expand = params.expand(4)


def exec_sequence(params, data):
    with params.to_module(sequence):
        return sequence(data)


tensordict_exp = vmap(exec_sequence, (0, None))(params_expand, data)
print(tensordict_exp)
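
The vmapped call adds the ensemble dimension in front of the original batch dimension; a quick sanity check (illustrative addition, not in the original demo):

print(tensordict_exp.batch_size)  # expected: torch.Size([4, 3])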

Specialized Classes

TorchRL also provides specialized modules that run checks on their output values.

torch.manual_seed(0)
from torchrl.data import Bounded
from torchrl.modules import SafeModule

spec = Bounded(-torch.ones(3), torch.ones(3))
base_module = nn.Linear(5, 3)
module = SafeModule(
    module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True
)
data = TensorDict({"obs": torch.randn(5)}, batch_size=[])
module(data)["action"]

data = TensorDict({"obs": torch.randn(5) * 100}, batch_size=[])
module(data)["action"]  # safe=True projects the result within the set

The Actor class has a predefined output key ("action").

from torchrl.modules import Actor

base_module = nn.Linear(5, 3)
actor = Actor(base_module, in_keys=["obs"])
data = TensorDict({"obs": torch.randn(5)}, batch_size=[])
actor(data)  # action is the default value

from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
)

Working with probabilistic models is also made easy thanks to the tensordict.nn API.

from torchrl.modules import NormalParamExtractor, TanhNormal

td = TensorDict({"input": torch.randn(3, 5)}, [3])
net = nn.Sequential(
    nn.Linear(5, 4), NormalParamExtractor()
)  # splits the output in loc and scale
module = TensorDictModule(net, in_keys=["input"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
    module,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=TanhNormal,
        return_log_prob=False,
    ),
)
td_module(td)
print(td)
# returning the log-probability
td = TensorDict({"input": torch.randn(3, 5)}, [3])
td_module = ProbabilisticTensorDictSequential(
    module,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=TanhNormal,
        return_log_prob=True,
    ),
)
td_module(td)
print(td)

Randomness and sampling strategies can be controlled via the set_exploration_type context manager.

from torchrl.envs.utils import ExplorationType, set_exploration_type

td = TensorDict({"input": torch.randn(3, 5)}, [3])

torch.manual_seed(0)
with set_exploration_type(ExplorationType.RANDOM):
    td_module(td)
    print("random:", td["action"])

with set_exploration_type(ExplorationType.DETERMINISTIC):
    td_module(td)
    print("mode:", td["action"])

Using Environments and Modules

Let's see how environments and modules can be combined.

from torchrl.envs.utils import step_mdp

env = GymEnv("Pendulum-v1")

action_spec = env.action_spec
actor_module = nn.Linear(3, 1)
actor = SafeModule(
    actor_module, spec=action_spec, in_keys=["observation"], out_keys=["action"]
)

torch.manual_seed(0)
env.set_seed(0)

max_steps = 100
data = env.reset()
data_stack = TensorDict(batch_size=[max_steps])
for i in range(max_steps):
    actor(data)
    data_stack[i] = env.step(data)
    if data["done"].any():
        break
    data = step_mdp(data)  # roughly equivalent to obs = next_obs

tensordicts_prealloc = data_stack.clone()
print("total steps:", i)
print(data_stack)
# equivalent
torch.manual_seed(0)
env.set_seed(0)

max_steps = 100
data = env.reset()
data_stack = []
for i in range(max_steps):
    actor(data)
    data_stack.append(env.step(data))
    if data["done"].any():
        break
    data = step_mdp(data)  # roughly equivalent to obs = next_obs
tensordicts_stack = torch.stack(data_stack, 0)
print("total steps:", i)
print(tensordicts_stack)
(tensordicts_stack == tensordicts_prealloc).all()
torch.manual_seed(0)
env.set_seed(0)
tensordict_rollout = env.rollout(policy=actor, max_steps=max_steps)
tensordict_rollout


(tensordict_rollout == tensordicts_prealloc).all()

from tensordict.nn import TensorDictModule

Collectors

We also provide a set of data collectors that automatically gather as many frames per batch as required. They work from single-node, single-worker to multi-node, multi-worker settings.

from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector

from torchrl.envs import EnvCreator, SerialEnv
from torchrl.envs.libs.gym import GymEnv

EnvCreator makes sure that we can send a lambda function between processes. We use a SerialEnv for simplicity (a single worker), but for larger jobs a ParallelEnv (multiple workers) would be better suited.

Note

Multiprocessed environments and multiprocessed collectors can be combined!

parallel_env = SerialEnv(
    3,
    EnvCreator(lambda: GymEnv("Pendulum-v1")),
)
create_env_fn = [parallel_env, parallel_env]

actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])

Synchronous multiprocessed data collector

devices = ["cpu", "cpu"]

collector = MultiSyncDataCollector(
    create_env_fn=create_env_fn,  # either a list of functions or a ParallelEnv
    policy=actor,
    total_frames=240,
    max_frames_per_traj=-1,  # envs are terminating, we don't need to stop them early
    frames_per_batch=60,  # we want 60 frames at a time (we have 3 envs per sub-collector)
    device=devices,
)
for i, d in enumerate(collector):
    if i == 0:
        print(d)  # trajectories are split automatically in [6 workers x 10 steps]
    collector.update_policy_weights_()  # make sure that our policies have the latest weights if working on multiple devices
print(i)
collector.shutdown()
del collector

Asynchronous multiprocessed data collector

This class allows you to collect data while the model is training. This is particularly useful in off-policy settings, as it decouples inference from model training. Data is delivered on a first-ready-first-served basis (workers queue their results).

collector = MultiaSyncDataCollector(
    create_env_fn=create_env_fn,  # either a list of functions or a ParallelEnv
    policy=actor,
    total_frames=240,
    max_frames_per_traj=-1,  # envs are terminating, we don't need to stop them early
    frames_per_batch=60,  # we want 60 frames at a time (we have 3 envs per sub-collector)
    device=devices,
)

for i, d in enumerate(collector):
    if i == 0:
        print(d)  # trajectories are split automatically in [6 workers x 10 steps]
    collector.update_policy_weights_()  # make sure that our policies have the latest weights if working on multiple devices
print(i)
collector.shutdown()
del collector
del create_env_fn
del parallel_env

Objectives

Objectives are the main entry point when coding up new algorithms.

from torchrl.objectives import DDPGLoss

actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])


class ConcatModule(nn.Linear):
    def forward(self, obs, action):
        return super().forward(torch.cat([obs, action], -1))


value_module = ConcatModule(4, 1)
value = TensorDictModule(
    value_module, in_keys=["observation", "action"], out_keys=["state_action_value"]
)

loss_fn = DDPGLoss(actor, value)
loss_fn.make_value_estimator(loss_fn.default_value_estimator, gamma=0.99)
data = TensorDict(
    {
        "observation": torch.randn(10, 3),
        "next": {
            "observation": torch.randn(10, 3),
            "reward": torch.randn(10, 1),
            "done": torch.zeros(10, 1, dtype=torch.bool),
        },
        "action": torch.randn(10, 1),
    },
    batch_size=[10],
    device="cpu",
)
loss_td = loss_fn(data)

print(loss_td)

print(data)
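
The loss terms stored in loss_td can then drive a regular optimization step. A minimal sketch (not part of the original demo), assuming the DDPGLoss output carries "loss_actor" and "loss_value" entries and using a single optimizer over all loss parameters for simplicity:

from torch.optim import Adam

optim = Adam(loss_fn.parameters(), lr=2e-4)
loss = loss_td["loss_actor"] + loss_td["loss_value"]  # aggregate the loss terms
loss.backward()
optim.step()
optim.zero_grad()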

Installing the Library

The library is published on PyPI: pip install torchrl. See the README for more information.

Contributing

We are actively looking for contributors and early users. If you're working in RL (or just curious), try it out! Give us feedback: what will make TorchRL succeed is how well it addresses researchers' needs. To do that, we need their input! Since the library is nascent, it is a great time for you to shape it the way you want!

See the contributing guide for more information.
