注意
轉到末尾 下載完整示例程式碼。
TensorDictModule¶
作者: Nicolas Dufour, Vincent Moens
在本教程中,您將學習如何使用 TensorDictModule 和 TensorDictSequential 來建立可接受 TensorDict 作為輸入的通用且可重用的模組。
為了方便地將 TensorDict 類與 Module 一起使用,tensordict 提供了一個名為 TensorDictModule 的介面來連線它們。
TensorDictModule 類是一個 Module,它在呼叫時接受 TensorDict 作為輸入。它將讀取一系列輸入鍵,將它們作為輸入傳遞給包裝的模組或函式,並在執行完成後將輸出寫入同一個 tensordict 中。
輸入和輸出的鍵由使用者定義。
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
簡單示例:編寫一個迴圈層¶
下面示例了 TensorDictModule 的最簡單用法。雖然乍一看使用這個類似乎會引入不必要的複雜性,但我們稍後會看到,這個 API 允許使用者以程式設計方式將模組連線在一起,在模組之間快取值,或者以程式設計方式構建模組。其中一個最簡單的例子是 ResNet 等架構中的迴圈模組,其中模組的輸入被快取並新增到微型多層感知器 (MLP) 的輸出中。
首先,讓我們考慮如何將一個 MLP 分塊,並使用 tensordict.nn 對其進行編碼。堆疊中的第一層可能是一個 Linear 層,它接受一個輸入項(我們將其命名為 x),並輸出另一個項(我們將將其命名為 y)。
為了饋送給我們的模組,我們有一個包含單個項 "x" 的 TensorDict 例項
tensordict = TensorDict(
x=torch.randn(5, 3),
batch_size=[5],
)
現在,我們使用 tensordict.nn.TensorDictModule 構建我們的簡單模組。預設情況下,此類會就地寫入輸入 tensordict(這意味著條目被寫入與輸入相同的 tensordict 中,而不是條目被就地覆蓋!),這樣我們就不需要顯式地指示輸出是什麼
linear0 = TensorDictModule(nn.Linear(3, 128), in_keys=["x"], out_keys=["linear0"])
linear0(tensordict)
assert "linear0" in tensordict
如果模組輸出多個張量(或 tensordicts!),它們的條目必須按照正確的順序傳遞給 TensorDictModule。
支援可呼叫物件¶
在設計模型時,您經常會希望在網路中包含任意的非引數函式。例如,您可能希望在影像傳遞給卷積網路或視覺 transformer 時對其維度進行置換,或者將值除以 255。有幾種方法可以實現這一點:例如,您可以使用 forward_hook,或者設計一個新的 Module 來執行此操作。
TensorDictModule 可以與任何可呼叫物件一起使用,而不僅僅是模組,這使得將任意函式整合到模組中變得容易。例如,讓我們看看如何在不使用 ReLU 模組的情況下整合 relu 啟用函式
relu0 = TensorDictModule(torch.relu, in_keys=["linear0"], out_keys=["relu0"])
堆疊模組¶
我們的 MLP 不是由單層組成的,所以現在我們需要為其新增另一層。這一層將是一個啟用函式,例如 ReLU。我們可以使用 TensorDictSequential 將此模組與前一個模組堆疊在一起。
注意
tensordict.nn 的真正強大之處在於:與 Sequential 不同,TensorDictSequential 會在記憶體中保留所有先前的輸入和輸出(之後可以選擇過濾掉它們),這使得能夠輕鬆地動態且以程式設計方式構建複雜的網路結構。
block0 = TensorDictSequential(linear0, relu0)
block0(tensordict)
assert "linear0" in tensordict
assert "relu0" in tensordict
我們可以重複這個邏輯來得到一個完整的 MLP
linear1 = TensorDictModule(nn.Linear(128, 128), in_keys=["relu0"], out_keys=["linear1"])
relu1 = TensorDictModule(nn.ReLU(), in_keys=["linear1"], out_keys=["relu1"])
linear2 = TensorDictModule(nn.Linear(128, 3), in_keys=["relu1"], out_keys=["linear2"])
block1 = TensorDictSequential(linear1, relu1, linear2)
多個輸入鍵¶
殘差網路的最後一步是將輸入新增到最後一個線性層的輸出中。無需為此編寫特殊的 Module 子類!TensorDictModule 也可以用於包裝簡單的函式
residual = TensorDictModule(
lambda x, y: x + y, in_keys=["x", "linear2"], out_keys=["y"]
)
現在我們可以將 block0、block1 和 residual 組合起來,形成一個完整的殘差塊
block = TensorDictSequential(block0, block1, residual)
block(tensordict)
assert "y" in tensordict
一個真正令人擔憂的問題可能是作為輸入的 tensordict 中條目的累積:在某些情況下(例如,需要梯度時),中間值無論如何都可能被快取,但這並非總是如此,通知垃圾回收器某些條目可以被丟棄可能會很有用。tensordict.nn.TensorDictModuleBase 及其子類(包括 tensordict.nn.TensorDictModule 和 tensordict.nn.TensorDictSequential)可以選擇在執行後過濾其輸出鍵。為此,只需呼叫 tensordict.nn.TensorDictModuleBase.select_out_keys 方法。這將就地更新模組,並且所有不需要的條目將被丟棄
block.select_out_keys("y")
tensordict = TensorDict(x=torch.randn(1, 3), batch_size=[1])
block(tensordict)
assert "y" in tensordict
assert "linear1" not in tensordict
但是,輸入鍵會被保留
assert "x" in tensordict
附帶一提,selected_out_keys 也可以傳遞給 tensordict.nn.TensorDictSequential,以避免單獨呼叫此方法。
不使用 tensordict 來使用 TensorDictModule¶
tensordict.nn.TensorDictSequential 提供的動態構建複雜架構的能力並不意味著必須切換到 tensordict 來表示資料。藉助 dispatch,tensordict.nn 中的模組也支援與條目名稱匹配的引數和關鍵字引數
x = torch.randn(1, 3)
y = block(x=x)
assert isinstance(y, torch.Tensor)
在底層,dispatch 重建一個 tensordict,執行模組,然後將其分解。這可能會帶來一些開銷,但是,正如我們稍後將看到的,有辦法解決這個問題。
執行時效能¶
tensordict.nn.TensorDictModule 和 tensordict.nn.TensorDictSequential 在執行時確實會產生一些開銷,因為它們需要從 tensordict 讀取和寫入。但是,我們可以透過使用 compile() 來大大減少這種開銷。為此,讓我們比較一下此程式碼在使用 compile 和不使用 compile 時的三種版本
class ResidualBlock(nn.Module):
def __init__(self):
super().__init__()
self.linear0 = nn.Linear(3, 128)
self.relu0 = nn.ReLU()
self.linear1 = nn.Linear(128, 128)
self.relu1 = nn.ReLU()
self.linear2 = nn.Linear(128, 3)
def forward(self, x):
y = self.linear0(x)
y = self.relu0(y)
y = self.linear1(y)
y = self.relu1(y)
return self.linear2(y) + x
print("Without compile")
x = torch.randn(256, 3)
block_notd = ResidualBlock()
block_tdm = TensorDictModule(block_notd, in_keys=["x"], out_keys=["y"])
block_tds = block
from torch.utils.benchmark import Timer
print(
f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
print(
f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
print(
f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
print("Compiled versions")
block_notd_c = torch.compile(block_notd, mode="reduce-overhead")
for _ in range(5): # warmup
block_notd_c(x)
print(
f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead")
for _ in range(5): # warmup
block_tdm_c(x=x)
print(
f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
block_tds_c = torch.compile(block_tds, mode="reduce-overhead")
for _ in range(5): # warmup
block_tds_c(x=x)
print(
f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
Without compile
Regular: 215.9519 us
TDM: 276.4528 us
Sequential: 486.3226 us
Compiled versions
Compiled regular: 327.3375 us
Compiled TDM: 370.3780 us
Compiled sequential: 382.0440 us
正如大家所見,TensorDictSequential 引入的開銷已得到完全解決。
使用 TensorDictModule 的注意事項¶
不要在來自
tensordict.nn的模組周圍使用Sequence。這會破壞輸入/輸出鍵結構。始終嘗試依賴nn:TensorDictSequential。不要將輸出 tensordict 賦值給一個新變數,因為輸出 tensordict 只是就地修改的輸入。賦值一個新變數名並非嚴格禁止,但這意味著當一個被刪除時,您可能希望兩者都消失,而實際上垃圾回收器仍然會看到工作區中的張量,並且不會釋放記憶體
>>> tensordict = module(tensordict) # ok! >>> tensordict_out = module(tensordict) # don't!
處理分佈:ProbabilisticTensorDictModule¶
ProbabilisticTensorDictModule 是一個表示機率分佈的非引數模組。分佈引數從 tensordict 輸入中讀取,輸出寫入輸出 tensordict。根據輸入 default_interaction_type 引數和 interaction_type() 全域性函式指定的規則對輸出進行取樣。如果它們衝突,上下文管理器優先。
它可以與使用 ProbabilisticTensorDictSequential 更新了分佈引數的 TensorDictModule 結合使用。這是 TensorDictSequential 的一個特例,其最後一層是一個 ProbabilisticTensorDictModule 例項。
ProbabilisticTensorDictModule 負責構建分佈(透過 get_dist() 方法)和/或從該分佈中進行取樣(透過對模組進行常規的 forward 呼叫)。相同的 get_dist() 方法也在 ProbabilisticTensorDictSequential 中公開。
如果需要,可以在輸出 tensordict 中找到引數以及對數機率。
from tensordict.nn import (
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist
td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
net = torch.nn.GRUCell(4, 8)
net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
extractor = NormalParamExtractor()
extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
net,
extractor,
ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=dist.Normal,
return_log_prob=True,
),
)
print(f"TensorDict before going through module: {td}")
td_module(td)
print(f"TensorDict after going through module now as keys action, loc and scale: {td}")
TensorDict before going through module: TensorDict(
fields={
hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
TensorDict after going through module now as keys action, loc and scale: TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
結論¶
我們已經看到 tensordict.nn 如何用於動態地即時構建複雜的神經網路架構。這開啟了構建對模型簽名無感知的管道的可能性,也就是說,可以編寫通用的程式碼,以靈活的方式使用具有任意數量輸入或輸出的網路。
我們還看到 dispatch 如何使得能夠使用 tensordict.nn 構建此類網路並使用它們,而無需直接使用 TensorDict。得益於 compile(),tensordict.nn.TensorDictSequential 引入的開銷可以被完全消除,從而為使用者提供了一個整潔的、無需 tensordict 的模組版本。
在下一個教程中,我們將看到如何使用 torch.export 來隔離模組並將其匯出。
指令碼總執行時間: (0 分鐘 16.867 秒)