• 文件 >
  • 使用預訓練模型
快捷方式

使用預訓練模型

本教程介紹如何在 TorchRL 中使用預訓練模型。

import tempfile

完成本教程後,你將能夠使用預訓練模型進行高效的影像表示,並對其進行微調。

TorchRL 提供了預訓練模型,這些模型既可以作為 transforms 使用,也可以作為策略的組成部分。由於它們的語義相同,因此可以在不同上下文中互換使用。在本教程中,我們將使用 R3M (https://arxiv.org/abs/2203.12601),但其他模型(例如 VIP)也同樣適用。

import torch.cuda
from tensordict.nn import TensorDictSequential
from torch import nn
from torchrl.envs import Compose, R3MTransform, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import Actor

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)

首先讓我們建立一個環境。為簡單起見,我們將使用一個常見的 gym 環境。在實踐中,這在更具挑戰性的具身 AI 場景(例如,檢視我們的 Habitat 封裝器)中也會有效。

base_env = GymEnv("Ant-v4", from_pixels=True, device=device)

讓我們獲取預訓練模型。我們透過設定 download=True 標誌來請求模型的預訓練版本。預設情況下此功能是關閉的。接下來,我們將 transform 附加到環境中。實際上,每次收集的資料批次都將透過 transform 並對映到輸出 tensordict 中的“r3m_vec”條目。我們的策略,由單層 MLP 組成,將讀取此向量並計算相應的動作。

r3m = R3MTransform(
    "resnet50",
    in_keys=["pixels"],
    download=False,  # Turn to true for real-life testing
)
env_transformed = TransformedEnv(base_env, r3m)
net = nn.Sequential(
    nn.LazyLinear(128, device=device),
    nn.Tanh(),
    nn.Linear(128, base_env.action_spec.shape[-1], device=device),
)
policy = Actor(net, in_keys=["r3m_vec"])

讓我們檢查策略的引數數量

print("number of params:", len(list(policy.parameters())))

我們收集 32 步的 rollout 並列印其輸出

rollout = env_transformed.rollout(32, policy)
print("rollout with transform:", rollout)

對於微調,我們在使引數可訓練後將 transform 整合到策略中。實際上,更明智的做法可能是將其限制在引數的子集(例如 MLP 的最後一層)。

r3m.train()
policy = TensorDictSequential(r3m, policy)
print("number of params after r3m is integrated:", len(list(policy.parameters())))

再次,我們使用 R3M 收集 rollout。輸出的結構略有變化,因為現在環境返回畫素(而不是 embedding)。embedding“r3m_vec”是我們策略的中間結果。

rollout = base_env.rollout(32, policy)
print("rollout, fine tuning:", rollout)

我們之所以能如此輕鬆地將 transform 從環境切換到策略,是因為它們都表現得像 TensorDictModule:它們有一組 “in_keys”“out_keys”,這使得在不同上下文中讀取和寫入輸出變得容易。

作為本教程的總結,讓我們看看如何使用 R3M 讀取儲存在回放緩衝區(例如,在離線 RL 場景中)中的影像。首先,讓我們構建資料集

from torchrl.data import LazyMemmapStorage, ReplayBuffer

buffer_scratch_dir = tempfile.TemporaryDirectory().name
storage = LazyMemmapStorage(1000, scratch_dir=buffer_scratch_dir)
rb = ReplayBuffer(storage=storage, transform=Compose(lambda td: td.to(device), r3m))

現在我們可以收集資料(為演示目的使用隨機 rollouts)並用它填充回放緩衝區

total = 0
while total < 1000:
    tensordict = base_env.rollout(1000)
    rb.extend(tensordict)
    total += tensordict.numel()

讓我們檢查回放緩衝區的儲存內容。它應該不包含“r3m_vec”條目,因為我們還沒有使用它

print("stored data:", storage._storage)

取樣時,資料將透過 R3M transform,得到我們想要的處理後的資料。透過這種方式,我們可以在由影像組成的資料集上離線訓練演算法

batch = rb.sample(32)
print("data after sampling:", batch)

由 Sphinx-Gallery 生成的相簿

文件

訪問 PyTorch 全面開發者文件

檢視文件

教程

獲取適合初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源