快捷方式

torch.export IR 規範

Export IR 是用於編譯器的中間表示 (IR),與 MLIR 和 TorchScript 類似。它專門設計用於表達 PyTorch 程式的語義。Export IR 主要透過精簡的運算列表表示計算,對控制流等動態特性的支援有限。

為了建立 Export IR 圖,可以使用前端透過跟蹤特化機制可靠地捕獲 PyTorch 程式。生成的 Export IR 可以由後端進行最佳化和執行。目前可以透過 torch.export.export() 實現這一點。

本文件將涵蓋的關鍵概念包括

  • ExportedProgram:包含 Export IR 程式的資料結構

  • Graph:由節點列表組成。

  • Nodes:表示運算、控制流以及儲存在此節點上的元資料。

  • 值由節點產生和消費。

  • 型別與值和節點相關聯。

  • 還定義了值的大小和記憶體佈局。

假設

本文件假設讀者對 PyTorch 足夠熟悉,特別是對 torch.fx 及其相關工具。因此,將不再描述 torch.fx 文件和論文中已有的內容。

什麼是 Export IR

Export IR 是 PyTorch 程式的基於圖的中間表示 IR。Export IR 是在 torch.fx.Graph 之上實現的。換句話說,所有 Export IR 圖都是有效的 FX 圖,如果使用標準的 FX 語義解釋,Export IR 可以被可靠地解釋。一個推論是,匯出的圖可以透過標準的 FX 程式碼生成轉換為有效的 Python 程式。

本文件將主要重點介紹 Export IR 在嚴格性方面與 FX 的不同之處,同時跳過與 FX 相似的部分。

ExportedProgram

頂層 Export IR 構造是 torch.export.ExportedProgram 類。它將 PyTorch 模型的計算圖(通常是 torch.nn.Module)與該模型使用的引數或權重捆綁在一起。

torch.export.ExportedProgram 類的一些值得注意的屬性包括

  • graph_module (torch.fx.GraphModule):包含 PyTorch 模型展平計算圖的資料結構。可以透過 ExportedProgram.graph 直接訪問該圖。

  • graph_signature (torch.export.ExportGraphSignature):圖簽名,指定圖中使用和修改的引數和緩衝區名稱。引數和緩衝區不作為圖的屬性儲存,而是作為圖的輸入被提升。graph_signature 用於跟蹤這些引數和緩衝區的附加資訊。

  • state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]):包含引數和緩衝區的資料結構。

  • range_constraints (Dict[sympy.Symbol, RangeConstraint]):對於匯出時具有資料依賴行為的程式,每個節點上的元資料將包含符號形狀(例如 s0i0)。此屬性將符號形狀對映到其下限/上限範圍。

Graph

Export IR Graph 是以 DAG(有向無環圖)形式表示的 PyTorch 程式。此圖中的每個節點表示特定的計算或運算,圖的邊由節點之間的引用組成。

我們可以將 Graph 視為具有以下模式

class Graph:
  nodes: List[Node]

實際上,Export IR 的圖透過 torch.fx.Graph Python 類實現。

Export IR 圖包含以下節點(節點將在下一節中更詳細地描述)

  • 0 個或多個 op 型別為 placeholder 的節點

  • 0 個或多個 op 型別為 call_function 的節點

  • 恰好 1 個 op 型別為 output 的節點

推論:最小的有效 Graph 將包含一個節點。即,nodes 永遠不為空。

定義: Graph 的 placeholder 節點集合表示 GraphModule 的 Graph 的輸入。output 節點表示 GraphModule 的 Graph 的輸出。

示例

import torch
from torch import nn

class MyModule(nn.Module):

    def forward(self, x, y):
      return x + y

example_args = (torch.randn(1), torch.randn(1))
mod = torch.export.export(MyModule(), example_args)
print(mod.graph)
graph():
  %x : [num_users=1] = placeholder[target=x]
  %y : [num_users=1] = placeholder[target=y]
  %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
  return (add,)

上圖是 Graph 的文字表示,每行是一個節點。

節點

Node 表示特定的計算或運算,在 Python 中使用 torch.fx.Node 類表示。節點之間的邊透過 Node 類的 args 屬性表示為對其他節點的直接引用。使用相同的 FX 機制,我們可以表示計算圖通常需要的以下運算,例如運算元呼叫、placeholder(即輸入)、條件語句和迴圈。

Node 具有以下模式

class Node:
  name: str # name of node
  op_name: str  # type of operation

  # interpretation of the fields below depends on op_name
  target: [str|Callable]
  args: List[object]
  kwargs: Dict[str, object]
  meta: Dict[str, object]

FX 文字格式

如上例所示,請注意每行都具有此格式

%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})

此格式以緊湊形式捕獲 Node 類中的所有內容,但 meta 除外。

具體來說

  • <name> 是節點的名稱,如它在 node.name 中顯示的那樣。

  • <op_name>node.op 欄位,必須是以下之一:<call_function><placeholder><get_attr><output>

  • <target> 是節點的 target,如 node.target。此欄位的含義取決於 op_name

  • args1, … args 4…node.args 元組中列出的內容。如果列表中的值是 torch.fx.Node,則會特別用前導 %. 指示。

例如,對 add 運算元的呼叫將顯示為

%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})

其中 %x%y 是另外兩個名稱為 x 和 y 的 Node。值得注意的是,字串 torch.op.aten.add.Tensor 表示實際儲存在 target 欄位中的可呼叫物件,而不僅僅是其字串名稱。

此文字格式的最後一行是

return [add]

它是一個 op_name = output 的 Node,表示我們正在返回此元素。

call_function

一個 call_function 節點表示對運算元的呼叫。

定義

  • 函式式:如果一個可呼叫物件滿足以下所有要求,我們稱之為“函式式”

    • 非變異:運算元不修改其輸入的值(對於張量,這包括元資料和資料)。

    • 無副作用:運算元不修改外部可見的狀態,例如更改模組引數的值。

  • 運算元:是具有預定義模式的函式式可呼叫物件。此類運算元示例包括函式式 ATen 運算元。

在 FX 中的表示

%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})

與普通 FX call_function 的區別

  1. 在 FX 圖中,一個 call_function 可以引用任何可呼叫物件,而在 Export IR 中,我們將其限制為僅一部分選定的 ATen 運算元、自定義運算元和控制流運算元。

  2. 在 Export IR 中,常量引數將嵌入到圖中。

  3. 在 FX 圖中,get_attr 節點可以表示讀取圖模組中儲存的任何屬性。然而,在 Export IR 中,這被限制為僅讀取子模組,因為所有引數/緩衝區都將作為輸入傳遞給圖模組。

元資料

Node.meta 是附加到每個 FX 節點的字典。然而,FX 規範並未指定其中可以或將包含哪些元資料。Export IR 提供了更強的契約,特別是所有 call_function 節點都保證具有且僅具有以下元資料欄位

  • node.meta["stack_trace"] 是包含引用原始 Python 原始碼的 Python 堆疊跟蹤的字串。示例堆疊跟蹤如下所示

    File "my_module.py", line 19, in forward
    return x + dummy_helper(y)
    File "helper_utility.py", line 89, in dummy_helper
    return y + 1
    
  • node.meta["val"] 描述了執行運算的輸出。它可以是 <symint><FakeTensor>List[Union[FakeTensor, SymInt]]None 型別。

  • node.meta["nn_module_stack"] 描述了節點來源的 torch.nn.Module 的“堆疊跟蹤”,如果它來自 torch.nn.Module 呼叫的話。例如,如果包含來自 torch.nn.Sequential 模組內部的 torch.nn.Linear 模組呼叫的 addmm 運算元的節點,則 nn_module_stack 將類似於

    {'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
    
  • node.meta["source_fn_stack"] 包含此節點在分解之前從哪個 torch 函式或葉子 torch.nn.Module 類呼叫。例如,包含來自 torch.nn.Linear 模組呼叫的 addmm 運算元的節點將在其 source_fn 中包含 torch.nn.Linear,而包含來自 torch.nn.functional.Linear 模組呼叫的 addmm 運算元的節點將在其 source_fn 中包含 torch.nn.functional.Linear

placeholder

Placeholder 表示圖的輸入。其語義與 FX 中完全相同。Placeholder 節點必須是圖節點列表中的前 N 個節點。N 可以為零。

在 FX 中的表示

%name = placeholder[target = name](args = ())

target 欄位是一個字串,它是輸入的名稱。

args,如果非空,其大小應為 1,表示此輸入的預設值。

元資料

Placeholder 節點也具有 meta[‘val’],就像 call_function 節點一樣。在這種情況下,val 欄位表示圖期望為此輸入引數接收的輸入形狀/dtype。

output

一個 output 呼叫表示函式中的返回語句;因此它終止當前圖。只有一個 output 節點,並且它始終是圖的最後一個節點。

在 FX 中的表示

output[](args = (%something, …))

這與 torch.fx 中的語義完全相同。args 表示要返回的節點。

元資料

Output 節點具有與 call_function 節點相同的元資料。

get_attr

get_attr 節點表示從封裝的 torch.fx.GraphModule 讀取子模組。與透過 torch.fx.symbolic_trace() 生成的普通 FX 圖不同,在普通 FX 圖中,get_attr 節點用於從頂層 torch.fx.GraphModule 讀取引數和緩衝區等屬性,而在 Export IR 中,引數和緩衝區作為輸入傳遞給圖模組,並存儲在頂層 torch.export.ExportedProgram 中。

在 FX 中的表示

%name = get_attr[target = name](args = ())

示例

考慮以下模型

from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x, y):
    return cond(y, true_fn, false_fn, [x])

Graph

graph():
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %y_1 : [num_users=1] = placeholder[target=y_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
    return conditional

%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] 讀取子模組 true_graph_0,其中包含 sin 運算元。

參考

SymInt

SymInt 是一個可以表示字面整數或代表整數(在 Python 中由 sympy.Symbol 類表示)的符號的物件。當 SymInt 是一個符號時,它描述了一個在編譯時未知、其值僅在執行時已知的整數型別變數。

FakeTensor

FakeTensor 是包含張量元資料的物件。可以將其視為具有以下元資料。

class FakeTensor:
  size: List[SymInt]
  dtype: torch.dtype
  device: torch.device
  dim_order: List[int]  # This doesn't exist yet

FakeTensor 的 size 欄位是整數或 SymInt 的列表。如果存在 SymInt,則表示此張量具有動態形狀。如果存在整數,則假定該張量將具有該精確的靜態形狀。TensorMeta 的秩永遠不是動態的。dtype 欄位表示該節點輸出的 dtype。Edge IR 中沒有隱式型別提升。FakeTensor 中沒有 strides。

換句話說

  • 如果 node.target 中的運算元返回 Tensor,則 node.meta['val'] 是描述該張量的 FakeTensor。

  • 如果 node.target 中的運算元返回一個由 Tensor 組成的 n 元組,則 node.meta['val'] 是描述每個張量的 FakeTensor 組成的 n 元組。

  • 如果 node.target 中的運算元返回在編譯時已知的 int/float/標量,則 node.meta['val'] 為 None。

  • 如果 node.target 中的運算元返回在編譯時未知的 int/float/標量,則 node.meta['val'] 的型別為 SymInt。

例如

  • aten::add 返回一個 Tensor;因此其 spec 將是 FakeTensor,包含此運算元返回的張量的 dtype 和 size。

  • aten::sym_size 返回一個整數;因此其 val 將是 SymInt,因為其值僅在執行時可用。

  • max_pool2d_with_indexes 返回一個 (Tensor, Tensor) 元組;因此 spec 也將是由 FakeTensor 物件組成的 2 元組,第一個 TensorMeta 描述返回值的第一個元素,依此類推。

Python 程式碼

def add_one(x):
  return torch.ops.aten(x, 1)

Graph

graph():
  %ph_0 : [#users=1] = placeholder[target=ph_0]
  %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
  return [add_tensor]

FakeTensor

FakeTensor(dtype=torch.int, size=[2,], device=CPU)

Pytree-able 型別

我們定義一種型別為“Pytree-able”,如果它要麼是葉子型別,要麼是包含其他 Pytree-able 型別的容器型別。

注意

pytree 的概念與此處為 JAX 記錄的概念相同

以下型別定義為葉子型別

型別

定義

Tensor

torch.Tensor

標量

Python 中的任何數值型別,包括整數型別、浮點型別和零維張量。

int

Python int(在 C++ 中繫結為 int64_t)

float

Python float(在 C++ 中繫結為 double)

bool

Python bool

str

Python 字串

ScalarType

torch.dtype

Layout

torch.layout

MemoryFormat

torch.memory_format

裝置

torch.device

以下型別定義為容器型別

型別

定義

元組

Python 元組

列表

Python 列表

字典

鍵為 Scalar 的 Python 字典

命名元組

Python namedtuple

Dataclass

必須透過 register_dataclass 註冊

自定義類

使用 _register_pytree_node 定義的任何自定義類

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源