Tracing TensorDictModule

We support tracing the execution of a TensorDictModule to create an FX graph. Simply import symbolic_trace from tensordict.prototype.fx instead of torch.fx.
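As a minimal sketch, the only change compared with plain torch.fx tracing is the import:

>>> # instead of: from torch.fx import symbolic_trace
>>> from tensordict.prototype.fx import symbolic_trace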

Note

torch.fx support is highly experimental and subject to change. Use with caution, and raise an issue if you encounter problems.

Tracing TensorDictModule

We illustrate with an example from the overview: we create a TensorDictModule, trace it, and inspect the resulting graph and code.

Tracing TensorDictModule
>>> import torch
>>> import torch.nn as nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from tensordict.prototype.fx import symbolic_trace

>>> class Net(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.LazyLinear(1)
...
...     def forward(self, x):
...         logits = self.linear(x)
...         return logits, torch.sigmoid(logits)
>>> module = TensorDictModule(
...     Net(),
...     in_keys=["input"],
...     out_keys=[("outputs", "logits"), ("outputs", "probabilities")],
... )
>>> graph_module = symbolic_trace(module)
>>> print(graph_module.graph)
graph():
    %tensordict : [#users=1] = placeholder[target=tensordict]
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%tensordict, input), kwargs = {})
    %linear : [#users=2] = call_module[target=linear](args = (%getitem,), kwargs = {})
    %sigmoid : [#users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {})
    return (linear, sigmoid)
>>> print(graph_module.code)

def forward(self, tensordict):
    getitem = tensordict['input'];  tensordict = None
    linear = self.linear(getitem);  getitem = None
    sigmoid = torch.sigmoid(linear)
    return (linear, sigmoid)

We can check that the forward pass of each module produces the same output.

>>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32])
>>> module_out = module(tensordict, tensordict_out=TensorDict())
>>> graph_module_out = graph_module(tensordict, tensordict_out=TensorDict())
>>> assert (
...     module_out["outputs", "logits"] == graph_module_out["outputs", "logits"]
... ).all()
>>> assert (
...     module_out["outputs", "probabilities"]
...     == graph_module_out["outputs", "probabilities"]
... ).all()

Tracing TensorDictSequential

We can also trace a TensorDictSequential. In this case the entire execution of the module is traced into a single graph, eliminating intermediate reads from and writes to the input TensorDict.

We demonstrate this by tracing the sequential example from the overview.

Tracing TensorDictSequential
>>> import torch
>>> import torch.nn as nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> from tensordict.prototype.fx import symbolic_trace

>>> class Net(nn.Module):
...     def __init__(self, input_size=100, hidden_size=50, output_size=10):
...         super().__init__()
...         self.fc1 = nn.Linear(input_size, hidden_size)
...         self.fc2 = nn.Linear(hidden_size, output_size)
...
...     def forward(self, x):
...         x = torch.relu(self.fc1(x))
...         return self.fc2(x)
>>> class Masker(nn.Module):
...     def forward(self, x, mask):
...         return torch.softmax(x * mask, dim=1)
>>> net = TensorDictModule(
...     Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
...     Masker(),
...     in_keys=[("intermediate", "x"), ("input", "mask")],
...     out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>> graph_module = symbolic_trace(module)
>>> print(graph_module.code)

def forward(self, tensordict):
    getitem = tensordict[('input', 'x')]
    _0_fc1 = getattr(self, "0").module.fc1(getitem);  getitem = None
    relu = torch.relu(_0_fc1);  _0_fc1 = None
    _0_fc2 = getattr(self, "0").module.fc2(relu);  relu = None
    getitem_1 = tensordict[('input', 'mask')];  tensordict = None
    mul = _0_fc2 * getitem_1;  getitem_1 = None
    softmax = torch.softmax(mul, dim = 1);  mul = None
    return (_0_fc2, softmax)
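As with the single-module example, we can check that the traced module and the original produce the same output. The following is a small sketch: the batch size of 32 and the shape of the mask tensor are illustrative assumptions, chosen to match Net's default input_size=100 and output_size=10.

>>> tensordict = TensorDict({
...     "input": TensorDict(
...         {"x": torch.randn(32, 100), "mask": torch.randint(2, (32, 10))}, [32]
...     ),
... }, [32])
>>> module_out = module(tensordict, tensordict_out=TensorDict())
>>> graph_module_out = graph_module(tensordict, tensordict_out=TensorDict())
>>> assert (
...     module_out["output", "probabilities"]
...     == graph_module_out["output", "probabilities"]
... ).all()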

The graph and code generated in this case are a little more complex. We can visualize them as follows (requires pydot).

Visualizing the graph
>>> from torch.fx.passes.graph_drawer import FxGraphDrawer
>>> g = FxGraphDrawer(graph_module, "sequential")
>>> with open("graph.svg", "wb") as f:
...     f.write(g.get_dot_graph().create_svg())

The result looks like this:

Visualization of the traced graph.
