注意
前往末尾下載完整的示例程式碼。
匯出 tensordict 模組¶
作者: Vincent Moens
先決條件¶
建議閱讀 TensorDictModule 教程,以便充分利用本教程。
使用 tensordict.nn 編寫模組後,通常將計算圖隔離並匯出該圖很有用。這樣做的目的是在硬體(例如,機器人、無人機、邊緣裝置)上執行模型,或者完全消除對 tensordict 的依賴。
PyTorch 提供了多種匯出模組的方法,包括 onnx 和 torch.export,它們都與 tensordict 相容。
在這個簡短的教程中,我們將看到如何使用 torch.export 來隔離模型的計算圖。torch.onnx 支援遵循相同的邏輯。
主要學習內容¶
在沒有 TensorDict 輸入的情況下執行
tensordict.nn模組;選擇模型的輸出;
處理隨機模型;
使用 torch.export 匯出此類模型;
將模型儲存到檔案;
隔離 PyTorch 模型;
import time
import torch
from tensordict.nn import (
InteractionType,
NormalParamExtractor,
ProbabilisticTensorDictModule as Prob,
set_interaction_type,
TensorDictModule as Mod,
TensorDictSequential as Seq,
)
from torch import distributions as dists, nn
設計模型¶
在許多應用中,使用隨機模型非常有用,即輸出不是確定性定義而是根據引數化分佈取樣的變數的模型。例如,生成式 AI 模型在給定相同輸入時通常會生成不同的輸出,因為它們根據其引數由輸入定義的分佈來取樣輸出。
tensordict 庫透過 ProbabilisticTensorDictModule 類處理此問題。此原語使用一個分佈類(在我們的例子中是 Normal)和一個指示符構建,指示在執行時將用於構建該分佈的輸入鍵。
因此,我們構建的網路將是三個主要元件的組合
一個將輸入對映到潛在引數的網路;
一個
tensordict.nn.NormalParamExtractor模組,將輸入分割為位置 “loc” 和比例 “scale” 引數,以便傳遞給Normal分佈;一個分佈構造器模組。
model = Seq(
# 1. A small network for embedding
Mod(nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
Mod(nn.ReLU(), in_keys=["hidden"], out_keys=["hidden"]),
Mod(nn.Linear(4, 4), in_keys=["hidden"], out_keys=["latent"]),
# 2. Extracting params
Mod(NormalParamExtractor(), in_keys=["latent"], out_keys=["loc", "scale"]),
# 3. Probabilistic module
Prob(
in_keys=["loc", "scale"],
out_keys=["sample"],
distribution_class=dists.Normal,
),
)
讓我們執行這個模型,看看輸出是什麼樣子
x = torch.randn(1, 3)
print(model(x=x))
(tensor([[0.2805, 0.4506, 0.0000, 0.1332]], grad_fn=<ReluBackward0>), tensor([[ 0.6580, -0.1202, 0.2788, -0.4807]], grad_fn=<AddmmBackward0>), tensor([[ 0.6580, -0.1202]], grad_fn=<SplitBackward0>), tensor([[1.1840, 0.7258]], grad_fn=<ClampMinBackward0>), tensor([[ 0.6580, -0.1202]], grad_fn=<SplitBackward0>))
正如預期的那樣,使用張量輸入執行模型會返回與模組輸出鍵數量相同的張量!對於大型模型,這可能非常令人惱火且浪費。稍後,我們將看到如何限制模型的輸出數量來解決這個問題。
將 torch.export 與 TensorDictModule 一起使用¶
現在我們已經成功構建了模型,我們希望將其計算圖提取到一個獨立於 tensordict 的單一物件中。torch.export 是一個 PyTorch 模組,專門用於隔離模組的圖並以標準化方式表示它。其主要入口點是 export(),它返回一個 ExportedProgram 物件。反過來,此物件具有我們將在下面探索的幾個感興趣的屬性:一個 graph_module,表示由 export 捕獲的 FX 圖,一個 graph_signature,包含圖的輸入、輸出等,最後是一個 module(),返回一個可呼叫物件,可以代替原始模組使用。
儘管我們的模組同時接受 args 和 kwargs,但我們將重點關注其與 kwargs 一起使用的情況,因為這樣更清晰。
from torch.export import export
model_export = export(model, args=(), kwargs={"x": x})
讓我們看看這個模組
print("module:", model_export.module())
module: GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); module_2_module_weight = module_2_module_bias = None
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1)
getitem = chunk[0]
getitem_1 = chunk[1]; chunk = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]
getitem_3 = broadcast_tensors[1]; broadcast_tensors = None
return pytree.tree_unflatten((relu, linear_1, getitem_2, getitem_3, getitem_2), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
這個模組可以像我們的原始模組一樣執行(開銷更低)
Time for TDModule: 519.51 micro-seconds
Time for exported module: 352.14 micro-seconds
以及 FX 圖
print("fx graph:", model_export.graph_module.print_readable())
class GraphModule(torch.nn.Module):
def forward(self, p_l__args___0_module_0_module_weight: "f32[4, 3]", p_l__args___0_module_0_module_bias: "f32[4]", p_l__args___0_module_2_module_weight: "f32[4, 4]", p_l__args___0_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
# File: /pytorch/tensordict/tensordict/nn/common.py:1109 in _call_module, code: out = self.module(*tensors, **kwargs)
linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_l__args___0_module_0_module_weight, p_l__args___0_module_0_module_bias); x = p_l__args___0_module_0_module_weight = p_l__args___0_module_0_module_bias = None
relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear); linear = None
linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_l__args___0_module_2_module_weight, p_l__args___0_module_2_module_bias); p_l__args___0_module_2_module_weight = p_l__args___0_module_2_module_bias = None
# File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:85 in forward, code: loc, scale = tensor.chunk(2, -1)
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1)
getitem: "f32[1, 2]" = chunk[0]
getitem_1: "f32[1, 2]" = chunk[1]; chunk = None
# File: /pytorch/tensordict/tensordict/nn/utils.py:71 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add); add = None
add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
# File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:86 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/distributions/utils.py:57 in broadcast_all, code: return torch.broadcast_tensors(*values)
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2: "f32[1, 2]" = broadcast_tensors[0]
getitem_3: "f32[1, 2]" = broadcast_tensors[1]; broadcast_tensors = None
return (relu, linear_1, getitem_2, getitem_3, getitem_2)
fx graph: class GraphModule(torch.nn.Module):
def forward(self, p_l__args___0_module_0_module_weight: "f32[4, 3]", p_l__args___0_module_0_module_bias: "f32[4]", p_l__args___0_module_2_module_weight: "f32[4, 4]", p_l__args___0_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
# File: /pytorch/tensordict/tensordict/nn/common.py:1109 in _call_module, code: out = self.module(*tensors, **kwargs)
linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_l__args___0_module_0_module_weight, p_l__args___0_module_0_module_bias); x = p_l__args___0_module_0_module_weight = p_l__args___0_module_0_module_bias = None
relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear); linear = None
linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_l__args___0_module_2_module_weight, p_l__args___0_module_2_module_bias); p_l__args___0_module_2_module_weight = p_l__args___0_module_2_module_bias = None
# File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:85 in forward, code: loc, scale = tensor.chunk(2, -1)
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1)
getitem: "f32[1, 2]" = chunk[0]
getitem_1: "f32[1, 2]" = chunk[1]; chunk = None
# File: /pytorch/tensordict/tensordict/nn/utils.py:71 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add); add = None
add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
# File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:86 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/distributions/utils.py:57 in broadcast_all, code: return torch.broadcast_tensors(*values)
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2: "f32[1, 2]" = broadcast_tensors[0]
getitem_3: "f32[1, 2]" = broadcast_tensors[1]; broadcast_tensors = None
return (relu, linear_1, getitem_2, getitem_3, getitem_2)
處理巢狀鍵¶
巢狀鍵是 tensordict 庫的核心功能,因此能夠匯出讀取和寫入巢狀條目的模組是一項重要的支援功能。由於關鍵字引數必須是常規字串,因此 dispatch 無法直接使用它們。相反,dispatch 將解包用常規下劃線(“_”)連線的巢狀鍵,如下例所示。
model_nested = Seq(
Mod(lambda x: x + 1, in_keys=[("some", "key")], out_keys=["hidden"]),
Mod(lambda x: x - 1, in_keys=["hidden"], out_keys=[("some", "output")]),
).select_out_keys(("some", "output"))
model_nested_export = export(model_nested, args=(), kwargs={"some_key": x})
print("exported module with nested input:", model_nested_export.module())
exported module with nested input: GraphModule()
def forward(self, some_key):
some_key, = fx_pytree.tree_flatten_spec(([], {'some_key':some_key}), self._in_spec)
add = torch.ops.aten.add.Tensor(some_key, 1); some_key = None
sub = torch.ops.aten.sub.Tensor(add, 1); add = None
return pytree.tree_unflatten((sub,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
請注意,module() 返回的可呼叫物件是純 Python 可呼叫物件,可以反過來使用 compile() 進行編譯。
儲存匯出的模組¶
torch.export 有自己的序列化協議,即 save() 和 load()。通常使用 “.pt2” 副檔名。
>>> torch.export.save(model_export, "model.pt2")
選擇輸出¶
回想一下,tensordict.nn 會在輸出中保留所有中間值,除非使用者明確要求只保留特定值。在訓練期間,這非常有用:可以輕鬆記錄圖的中間值,或將其用於其他目的(例如,根據儲存的引數重建分佈,而不是儲存 Distribution 物件本身)。還可以爭辯說,在訓練期間,註冊中間值對記憶體的影響可以忽略不計,因為它們是 torch.autograd 用於計算引數梯度的計算圖的一部分。
然而,在推理期間,我們最可能只對模型的最終樣本感興趣。因為我們希望提取與 tensordict 庫無關的模型,所以隔離我們想要的唯一輸出是有意義的。為此,我們有幾個選項
使用
selected_out_keys關鍵字引數構建TensorDictSequential(),這將在使用模組時引導選擇所需的條目;使用
select_out_keys()方法,該方法將就地修改out_keys屬性(可以透過reset_out_keys()恢復)。將現有例項包裝在
TensorDictSequential()中,該類將過濾掉不需要的鍵>>> module_filtered = Seq(module, selected_out_keys=["sample"])
讓我們在選擇其輸出鍵後測試模型。當提供一個 x 輸入時,我們期望模型輸出一個對應於分佈樣本的單個張量
tensor([[ 0.6580, -0.1202]], grad_fn=<SplitBackward0>)
我們看到輸出現在是一個單個張量,對應於分佈的樣本。我們可以從中建立一個新的匯出圖。其計算圖應該被簡化
model_export = export(model, args=(), kwargs={"x": x})
print("module:", model_export.module())
module: GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); relu = module_2_module_weight = module_2_module_bias = None
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1); linear_1 = None
getitem = chunk[0]
getitem_1 = chunk[1]; chunk = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]
getitem_3 = broadcast_tensors[1]; broadcast_tensors = getitem_3 = None
return pytree.tree_unflatten((getitem_2,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
控制取樣策略¶
我們尚未討論 ProbabilisticTensorDictModule 如何從分佈中取樣。取樣是指根據特定策略在分佈定義的空間內獲取一個值。例如,人們可能希望在訓練期間獲得隨機樣本,但在推理時獲得確定性樣本(例如,均值或眾數)。為了解決這個問題,tensordict 利用了 set_interaction_type 裝飾器和上下文管理器,它們接受 InteractionType 列舉輸入
>>> with set_interaction_type(InteractionType.MEAN):
... output = module(input) # takes the input of the distribution, if ProbabilisticTensorDictModule is invoked
預設的 InteractionType 是 InteractionType.DETERMINISTIC,如果未直接實現,則為實數域分佈的均值或離散域分佈的眾數。可以使用 ProbabilisticTensorDictModule 的 default_interaction_type 關鍵字引數更改此預設值。
我們來回顧一下:為了控制網路的取樣策略,我們可以在建構函式中定義一個預設取樣策略,或者透過 set_interaction_type 上下文管理器在執行時覆蓋它。
正如我們從以下示例中看到的,torch.export 正確響應了裝飾器的使用:如果我們要求一個隨機樣本,輸出與我們要求均值時不同
with set_interaction_type(InteractionType.RANDOM):
model_export = export(model, args=(), kwargs={"x": x})
print(model_export.module())
with set_interaction_type(InteractionType.MEAN):
model_export = export(model, args=(), kwargs={"x": x})
print(model_export.module())
GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); relu = module_2_module_weight = module_2_module_bias = None
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1); linear_1 = None
getitem = chunk[0]
getitem_1 = chunk[1]; chunk = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]
getitem_3 = broadcast_tensors[1]; broadcast_tensors = None
empty = torch.ops.aten.empty.memory_format([1, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
normal_ = torch.ops.aten.normal_.default(empty); empty = None
mul = torch.ops.aten.mul.Tensor(normal_, getitem_3); normal_ = getitem_3 = None
add_2 = torch.ops.aten.add.Tensor(getitem_2, mul); getitem_2 = mul = None
return pytree.tree_unflatten((add_2,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); relu = module_2_module_weight = module_2_module_bias = None
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1); linear_1 = None
getitem = chunk[0]
getitem_1 = chunk[1]; chunk = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]
getitem_3 = broadcast_tensors[1]; broadcast_tensors = getitem_3 = None
return pytree.tree_unflatten((getitem_2,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
這就是使用 torch.export 所需瞭解的全部內容。更多資訊請參考 官方文件。
下一步和進一步閱讀¶
查閱
torch.export教程,可在此處獲取;ONNX 支援:查閱 ONNX 教程,瞭解更多關於此功能的資訊。匯出到 ONNX 與此處解釋的 torch.export 非常相似。
對於在沒有 Python 環境的伺服器上部署 PyTorch 程式碼,請查閱 AOTInductor 文件。
指令碼總執行時間: (0 分鐘 1.521 秒)