注意
請跳轉至末尾 下載完整的示例程式碼。
TorchRL 簡介¶
本簡報在 ICML 2022 工業演示日上進行展示。
它很好地概述了 TorchRL 的功能。如果您對此有疑問或意見,請隨時聯絡 vmoens@fb.com 或提交 issue。
TorchRL 是一個用於 PyTorch 的開源強化學習 (RL) 庫。
PyTorch 生態系統團隊 (Meta) 已決定投資該庫,以提供一個領先的平臺,用於在研究環境中開發 RL 解決方案。
它提供了 pytorch 和Python 優先的低階和高階抽象,旨在實現高效、文件齊全和經過充分測試。程式碼旨在支援 RL 研究。大部分程式碼以高度模組化的方式用 Python 編寫,以便研究人員可以輕鬆地替換元件、轉換元件或編寫新元件,而無需太多精力。
該倉庫力求與現有的 pytorch 生態系統庫對齊,它具有資料集支柱 (torchrl/envs)、轉換、模型、資料工具(例如收集器和容器)等。TorchRL 的目標是儘可能少地依賴項(Python 標準庫、numpy 和 pytorch)。常見的環境庫(例如 OpenAI gym)僅為可選依賴。
與許多其他領域不同,RL 更多關注演算法而非媒體。因此,很難構建真正獨立的元件。
TorchRL 不是什麼
演算法集合:我們不打算提供 RL 演算法的 SOTA(最新最好)實現,我們提供這些演算法僅作為如何使用該庫的示例。
研究框架:TorchRL 的模組化有兩種形式。首先,我們嘗試構建可重用元件,以便它們可以輕鬆地相互替換。其次,我們盡最大努力使元件可以獨立於庫的其餘部分使用。
TorchRL 的核心依賴項非常少,主要是 PyTorch 和 numpy。所有其他依賴項(gym、torchvision、wandb / tensorboard)都是可選的。
資料¶
TensorDict¶
import torch
from tensordict import TensorDict
讓我們建立一個 TensorDict。建構函式接受許多不同的格式,例如傳入一個 dict 或使用關鍵字引數。
batch_size = 5
data = TensorDict(
key1=torch.zeros(batch_size, 3),
key2=torch.zeros(batch_size, 5, 6, dtype=torch.bool),
batch_size=[batch_size],
)
print(data)
您可以沿著 batch_size 索引 TensorDict,也可以查詢鍵 (keys)。
print(data[2])
print(data["key1"] is data.get("key1"))
以下展示瞭如何堆疊多個 TensorDict。這在編寫 rollout 迴圈時特別有用!
data1 = TensorDict(
{
"key1": torch.zeros(batch_size, 1),
"key2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
},
batch_size=[batch_size],
)
data2 = TensorDict(
{
"key1": torch.ones(batch_size, 1),
"key2": torch.ones(batch_size, 5, 6, dtype=torch.bool),
},
batch_size=[batch_size],
)
data = torch.stack([data1, data2], 0)
data.batch_size, data["key1"]
以下是 TensorDict 的其他一些功能:檢視、置換、共享記憶體或擴充套件。
print(
"view(-1): ",
data.view(-1).batch_size,
data.view(-1).get("key1").shape,
)
print("to device: ", data.to("cpu"))
# print("pin_memory: ", data.pin_memory())
print("share memory: ", data.share_memory_())
print(
"permute(1, 0): ",
data.permute(1, 0).batch_size,
data.permute(1, 0).get("key1").shape,
)
print(
"expand: ",
data.expand(3, *data.batch_size).batch_size,
data.expand(3, *data.batch_size).get("key1").shape,
)
您也可以建立巢狀資料。
data = TensorDict(
source={
"key1": torch.zeros(batch_size, 3),
"key2": TensorDict(
source={"sub_key1": torch.zeros(batch_size, 2, 1)},
batch_size=[batch_size, 2],
),
},
batch_size=[batch_size],
)
data
經驗回放緩衝區¶
經驗回放緩衝區是許多 RL 演算法中的關鍵組成部分。TorchRL 提供了一系列經驗回放緩衝區實現。大多數基本功能適用於任何資料結構(list、tuples、dict),但為了充分發揮經驗回放緩衝區的作用並實現快速讀寫訪問,應優先使用 TensorDict API。
from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer
rb = ReplayBuffer(collate_fn=lambda x: x)
可以使用 add() (n=1) 或 extend() (n>1) 新增。
rb.add(1)
rb.sample(1)
rb.extend([2, 3])
rb.sample(3)
也可以使用優先經驗回放緩衝區
rb = PrioritizedReplayBuffer(alpha=0.7, beta=1.1, collate_fn=lambda x: x)
rb.add(1)
rb.sample(1)
rb.update_priority(1, 0.5)
以下是使用帶 data_stack 的經驗回放緩衝區的示例。使用它們可以輕鬆地抽象經驗回放緩衝區在多種用例中的行為。
collate_fn = torch.stack
rb = ReplayBuffer(collate_fn=collate_fn)
rb.add(TensorDict({"a": torch.randn(3)}, batch_size=[]))
len(rb)
rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
print(len(rb))
print(rb.sample(10))
print(rb.sample(2).contiguous())
torch.manual_seed(0)
from torchrl.data import TensorDictPrioritizedReplayBuffer
rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, priority_key="td_error")
rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
data_sample = rb.sample(2).contiguous()
print(data_sample)
print(data_sample["index"])
data_sample["td_error"] = torch.rand(2)
rb.update_tensordict_priority(data_sample)
for i, val in enumerate(rb._sampler._sum_tree):
print(i, val)
if i == len(rb):
break
環境¶
TorchRL 提供了一系列環境封裝器和工具。
Gym 環境¶
try:
import gymnasium as gym
except ModuleNotFoundError:
import gym
from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend
gym_env = gym.make("Pendulum-v1")
env = GymWrapper(gym_env)
env = GymEnv("Pendulum-v1")
data = env.reset()
env.rand_step(data)
更改環境配置¶
env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env.reset()
env.close()
del env
from torchrl.envs import (
Compose,
NoopResetEnv,
ObservationNorm,
ToTensorImage,
TransformedEnv,
)
base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
環境轉換¶
轉換的作用類似於 Gym 封裝器,但其 API 更接近 torchvision 的 torch.distributions 的轉換。有多種轉換可供選擇。
from torchrl.envs import (
Compose,
NoopResetEnv,
ObservationNorm,
StepCounter,
ToTensorImage,
TransformedEnv,
)
base_env = GymEnv("HalfCheetah-v4", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env = env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
env.reset()
print("env: ", env)
print("last transform parent: ", env.transform[2].parent)
向量化環境¶
向量化/並行環境可以提供顯著的加速。
from torchrl.envs import ParallelEnv
def make_env():
# You can control whether to use gym or gymnasium for your env
with set_gym_backend("gym"):
return GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
base_env = ParallelEnv(
4,
make_env,
mp_start_method="fork", # This will break on Windows machines! Remove and decorate with if __name__ == "__main__"
)
env = TransformedEnv(
base_env, Compose(StepCounter(), ToTensorImage())
) # applies transforms on batch of envs
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
env.reset()
print(env.action_spec)
env.close()
del env
模組¶
庫中提供了多種模組(工具、模型和封裝器)。
模型¶
MLP 模型示例
from torch import nn
from torchrl.modules import ConvNet, MLP
from torchrl.modules.models.utils import SquashDims
net = MLP(num_cells=[32, 64], out_features=4, activation_class=nn.ELU)
print(net)
print(net(torch.randn(10, 3)).shape)
CNN 模型示例
cnn = ConvNet(
num_cells=[32, 64],
kernel_sizes=[8, 4],
strides=[2, 1],
aggregator_class=SquashDims,
)
print(cnn)
print(cnn(torch.randn(10, 3, 32, 32)).shape) # last tensor is squashed
TensorDict 模組¶
一些模組專門設計用於處理 TensorDict 輸入。
from tensordict.nn import TensorDictModule
data = TensorDict({"key1": torch.randn(10, 3)}, batch_size=[10])
module = nn.Linear(3, 4)
td_module = TensorDictModule(module, in_keys=["key1"], out_keys=["key2"])
td_module(data)
print(data)
模組序列¶
TensorDictSequential 使構建模組序列變得容易。
from tensordict.nn import TensorDictSequential
backbone_module = nn.Linear(5, 3)
backbone = TensorDictModule(
backbone_module, in_keys=["observation"], out_keys=["hidden"]
)
actor_module = nn.Linear(3, 4)
actor = TensorDictModule(actor_module, in_keys=["hidden"], out_keys=["action"])
value_module = MLP(out_features=1, num_cells=[4, 5])
value = TensorDictModule(value_module, in_keys=["hidden", "action"], out_keys=["value"])
sequence = TensorDictSequential(backbone, actor, value)
print(sequence)
print(sequence.in_keys, sequence.out_keys)
data = TensorDict(
{"observation": torch.randn(3, 5)},
[3],
)
backbone(data)
actor(data)
value(data)
data = TensorDict(
{"observation": torch.randn(3, 5)},
[3],
)
sequence(data)
print(data)
函數語言程式設計(整合 / 元強化學習)¶
函式式呼叫從未如此簡單。使用 from_module() 提取引數,然後使用 to_module() 替換它們。
from tensordict import from_module
params = from_module(sequence)
print("extracted params", params)
使用 TensorDict 的函式式呼叫
with params.to_module(sequence):
data = sequence(data)
VMAP¶
快速執行多個相似架構的副本是快速訓練模型的關鍵。vmap() 正是為此量身定製的。
專門類¶
TorchRL 還提供了一些專門模組,它們對輸出值執行檢查。
torch.manual_seed(0)
from torchrl.data import Bounded
from torchrl.modules import SafeModule
spec = Bounded(-torch.ones(3), torch.ones(3))
base_module = nn.Linear(5, 3)
module = SafeModule(
module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True
)
data = TensorDict({"obs": torch.randn(5)}, batch_size=[])
module(data)["action"]
data = TensorDict({"obs": torch.randn(5) * 100}, batch_size=[])
module(data)["action"] # safe=True projects the result within the set
Actor 類有一個預定義的輸出鍵 ("action")。
from torchrl.modules import Actor
base_module = nn.Linear(5, 3)
actor = Actor(base_module, in_keys=["obs"])
data = TensorDict({"obs": torch.randn(5)}, batch_size=[])
actor(data) # action is the default value
from tensordict.nn import (
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
)
藉助 tensordict.nn API,處理機率模型也變得容易。
from torchrl.modules import NormalParamExtractor, TanhNormal
td = TensorDict({"input": torch.randn(3, 5)}, [3])
net = nn.Sequential(
nn.Linear(5, 4), NormalParamExtractor()
) # splits the output in loc and scale
module = TensorDictModule(net, in_keys=["input"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
module,
ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=TanhNormal,
return_log_prob=False,
),
)
td_module(td)
print(td)
# returning the log-probability
td = TensorDict({"input": torch.randn(3, 5)}, [3])
td_module = ProbabilisticTensorDictSequential(
module,
ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=TanhNormal,
return_log_prob=True,
),
)
td_module(td)
print(td)
透過上下文管理器 set_exploration_type 可以控制隨機性和取樣策略。
from torchrl.envs.utils import ExplorationType, set_exploration_type
td = TensorDict({"input": torch.randn(3, 5)}, [3])
torch.manual_seed(0)
with set_exploration_type(ExplorationType.RANDOM):
td_module(td)
print("random:", td["action"])
with set_exploration_type(ExplorationType.DETERMINISTIC):
td_module(td)
print("mode:", td["action"])
使用環境和模組¶
讓我們看看環境和模組如何結合使用。
from torchrl.envs.utils import step_mdp
env = GymEnv("Pendulum-v1")
action_spec = env.action_spec
actor_module = nn.Linear(3, 1)
actor = SafeModule(
actor_module, spec=action_spec, in_keys=["observation"], out_keys=["action"]
)
torch.manual_seed(0)
env.set_seed(0)
max_steps = 100
data = env.reset()
data_stack = TensorDict(batch_size=[max_steps])
for i in range(max_steps):
actor(data)
data_stack[i] = env.step(data)
if data["done"].any():
break
data = step_mdp(data) # roughly equivalent to obs = next_obs
tensordicts_prealloc = data_stack.clone()
print("total steps:", i)
print(data_stack)
# equivalent
torch.manual_seed(0)
env.set_seed(0)
max_steps = 100
data = env.reset()
data_stack = []
for _ in range(max_steps):
actor(data)
data_stack.append(env.step(data))
if data["done"].any():
break
data = step_mdp(data) # roughly equivalent to obs = next_obs
tensordicts_stack = torch.stack(data_stack, 0)
print("total steps:", i)
print(tensordicts_stack)
(tensordicts_stack == tensordicts_prealloc).all()
torch.manual_seed(0)
env.set_seed(0)
tensordict_rollout = env.rollout(policy=actor, max_steps=max_steps)
tensordict_rollout
(tensordict_rollout == tensordicts_prealloc).all()
from tensordict.nn import TensorDictModule
收集器¶
我們還提供了一組資料收集器,它們會自動按照需求收集每批次的幀數。它們支援從單節點、單 worker 到多節點、多 worker 的設定。
from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector
from torchrl.envs import EnvCreator, SerialEnv
from torchrl.envs.libs.gym import GymEnv
EnvCreator 確保我們可以在程序之間傳送 lambda 函式。為了簡單起見(單個 worker),我們使用 SerialEnv,但對於大型任務,ParallelEnv(多個 worker)將更適合。
注意
多程序環境和多程序收集器可以結合使用!
parallel_env = SerialEnv(
3,
EnvCreator(lambda: GymEnv("Pendulum-v1")),
)
create_env_fn = [parallel_env, parallel_env]
actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])
同步多程序資料收集器¶
devices = ["cpu", "cpu"]
collector = MultiSyncDataCollector(
create_env_fn=create_env_fn, # either a list of functions or a ParallelEnv
policy=actor,
total_frames=240,
max_frames_per_traj=-1, # envs are terminating, we don't need to stop them early
frames_per_batch=60, # we want 60 frames at a time (we have 3 envs per sub-collector)
device=devices,
)
for i, d in enumerate(collector):
if i == 0:
print(d) # trajectories are split automatically in [6 workers x 10 steps]
collector.update_policy_weights_() # make sure that our policies have the latest weights if working on multiple devices
print(i)
collector.shutdown()
del collector
非同步多程序資料收集器¶
此類允許您在模型訓練時收集資料。這在離策略(off-policy)設定中特別有用,因為它解耦了推理和模型訓練。資料按先就緒先服務(first-ready-first-served)原則交付(workers 將排隊其結果)。
collector = MultiaSyncDataCollector(
create_env_fn=create_env_fn, # either a list of functions or a ParallelEnv
policy=actor,
total_frames=240,
max_frames_per_traj=-1, # envs are terminating, we don't need to stop them early
frames_per_batch=60, # we want 60 frames at a time (we have 3 envs per sub-collector)
device=devices,
)
for i, d in enumerate(collector):
if i == 0:
print(d) # trajectories are split automatically in [6 workers x 10 steps]
collector.update_policy_weights_() # make sure that our policies have the latest weights if working on multiple devices
print(i)
collector.shutdown()
del collector
del create_env_fn
del parallel_env
目標¶
目標是編寫新演算法時的主要入口點。
from torchrl.objectives import DDPGLoss
actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])
class ConcatModule(nn.Linear):
def forward(self, obs, action):
return super().forward(torch.cat([obs, action], -1))
value_module = ConcatModule(4, 1)
value = TensorDictModule(
value_module, in_keys=["observation", "action"], out_keys=["state_action_value"]
)
loss_fn = DDPGLoss(actor, value)
loss_fn.make_value_estimator(loss_fn.default_value_estimator, gamma=0.99)
data = TensorDict(
{
"observation": torch.randn(10, 3),
"next": {
"observation": torch.randn(10, 3),
"reward": torch.randn(10, 1),
"done": torch.zeros(10, 1, dtype=torch.bool),
},
"action": torch.randn(10, 1),
},
batch_size=[10],
device="cpu",
)
loss_td = loss_fn(data)
print(loss_td)
print(data)
安裝庫¶
該庫已釋出到 PyPI:pip install torchrl 更多資訊請參閱 README 檔案。
貢獻¶
我們正在積極尋找貢獻者和早期使用者。如果您正在從事 RL 工作(或只是好奇),請嘗試一下!請給我們反饋:TorchRL 的成功取決於它如何很好地滿足研究人員的需求。為此,我們需要他們的意見!由於該庫尚處於萌芽階段,現在是塑造您想要它的方式的絕佳時機!
更多資訊請參閱 貢獻指南。