控制流 - Cond¶
torch.cond 是一個結構化控制流運算子。它可用於指定 if-else 類的控制流,邏輯上可以看作是如下實現。
def cond(
pred: Union[bool, torch.Tensor],
true_fn: Callable,
false_fn: Callable,
operands: Tuple[torch.Tensor]
):
if pred:
return true_fn(*operands)
else:
return false_fn(*operands)
其獨特之處在於能夠表達資料依賴的控制流:它被降低為一個條件運算子 (torch.ops.higher_order.cond),該運算子保留了謂詞、真函式和假函式。這極大地靈活了模型的編寫和部署,使模型能夠根據輸入張量或中間張量操作的值或形狀改變架構。
警告
torch.cond 在 PyTorch 中是一個原型特性。它對輸入和輸出型別支援有限,目前不支援訓練。請期待未來版本中更穩定的實現。要詳細瞭解特性分類,請參閱:https://pytorch.com.tw/blog/pytorch-feature-classification-changes/#prototype
示例¶
以下是一個使用 cond 基於輸入形狀進行分支的示例
import torch
def true_fn(x: torch.Tensor):
return x.cos() + x.sin()
def false_fn(x: torch.Tensor):
return x.sin()
class DynamicShapeCondPredicate(torch.nn.Module):
"""
A basic usage of cond based on dynamic shape predicate.
"""
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
def true_fn(x: torch.Tensor):
return x.cos()
def false_fn(x: torch.Tensor):
return x.sin()
return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))
dyn_shape_mod = DynamicShapeCondPredicate()
我們可以急切地執行模型,並期望結果根據輸入形狀而變化
inp = torch.randn(3)
inp2 = torch.randn(5)
assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))
我們可以匯出模型以進行進一步的轉換和部署
inp = torch.randn(4, 3)
dim_batch = torch.export.Dim("batch", min=2)
ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
print(ep)
這給了我們如下所示的匯出程式
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
gt: Sym(s0 > 4) = sym_size > 4; sym_size = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
return (conditional,)
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
return add
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
return sin
注意,torch.cond 被降低為 torch.ops.higher_order.cond,其謂詞成為輸入形狀上的一個 Symbolic 表示式,分支函式成為頂層圖模組的兩個子圖屬性。
這是另一個展示如何表達資料依賴控制流的示例
class DataDependentCondPredicate(torch.nn.Module):
"""
A basic usage of cond based on data dependent predicate.
"""
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))
匯出後獲得的匯出程式
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
return (conditional,)
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
return add
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
return sin
torch.ops.higher_order.cond 的不變式¶
對於 torch.ops.higher_order.cond 有幾個有用的不變式
- 對於謂詞
謂詞的動態性得到保留(例如,上述示例中顯示的 gt)
如果使用者程式中的謂詞是常量(例如,一個 Python 布林常量),則運算子的 pred 將是常量。
- 對於分支
輸入和輸出簽名將是一個扁平化的元組。
它們是 torch.fx.GraphModule。
原始函式中的閉包變為顯式輸入。沒有閉包。
不允許對輸入或全域性變數進行改變。
- 對於運算元
它也將是一個扁平化的元組。
使用者程式中 torch.cond 的巢狀變為巢狀的圖模組。
API 參考¶
- torch._higher_order_ops.cond.cond(pred, true_fn, false_fn, operands=())[原始碼]¶
有條件地應用 true_fn 或 false_fn。
警告
torch.cond 在 PyTorch 中是一個原型特性。它對輸入和輸出型別支援有限,目前不支援訓練。請期待未來版本中更穩定的實現。要詳細瞭解特性分類,請參閱:https://pytorch.com.tw/blog/pytorch-feature-classification-changes/#prototype
cond 是一個結構化控制流運算子。也就是說,它類似於 Python 的 if 語句,但對 true_fn、false_fn 和 operands 有限制,使其能夠使用 torch.compile 和 torch.export 進行捕獲。
假設滿足 cond 引數的約束,cond 等價於以下內容
def cond(pred, true_branch, false_branch, operands): if pred: return true_branch(*operands) else: return false_branch(*operands)
- 引數
pred (Union[bool, torch.Tensor]) – 一個布林表示式或一個包含一個元素的張量,指示要應用哪個分支函式。
true_fn (Callable) – 一個可呼叫函式 (a -> b),位於正在追蹤的範圍之內。
false_fn (Callable) – 一個可呼叫函式 (a -> b),位於正在追蹤的範圍之內。真分支和假分支必須具有一致的輸入和輸出,這意味著輸入必須相同,輸出必須具有相同的型別和形狀。
operands (Tuple of possibly nested dict/list/tuple of torch.Tensor) – 一個包含真/假函式輸入的元組。如果 true_fn/false_fn 不需要輸入,則可以為空。預設為 ()。
- 返回型別
示例
def true_fn(x: torch.Tensor): return x.cos() def false_fn(x: torch.Tensor): return x.sin() return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
- 限制
條件語句(即 pred)必須滿足以下約束之一
它是一個 torch.Tensor,只包含一個元素且 dtype 為 torch.bool
它是一個布林表示式,例如 x.shape[0] > 10 或 x.dim() > 1 and x.shape[1] > 10
分支函式(即 true_fn/false_fn)必須滿足以下所有約束
函式簽名必須與 operands 匹配。
函式必須返回具有相同元資料的張量,例如 shape、dtype 等。
函式不能對輸入或全域性變數進行原地(in-place)改變。(注意:分支中允許對中間結果進行原地張量操作,例如 add_)