控制流程 - Cond¶
torch.cond 是一個結構化控制流程運算子。它可以用於指定類似 if-else 的控制流程,並且在邏輯上可以視為以下實現方式。
def cond(
    pred: Union[bool, torch.Tensor],
    true_fn: Callable,
    false_fn: Callable,
    operands: Tuple[torch.Tensor]
):
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)
它的獨特之處在於它能夠表達**數據依賴的控制流程**:它會降級為條件運算子 (torch.ops.higher_order.cond),該運算子會保留謂詞、真函數和假函數。這為編寫和部署基於張量運算的輸入或中間輸出的**值**或**形狀**來更改模型架構的模型提供了極大的靈活性。
警告
torch.cond 是 PyTorch 中的原型功能。它對輸入和輸出類型的支持有限,目前不支持訓練。請期待在未來版本的 PyTorch 中獲得更穩定的實現。在以下位置閱讀有關功能分類的更多信息: https://pytorch.com.tw/blog/pytorch-feature-classification-changes/#prototype
範例¶
以下是一個使用 cond 根據輸入形狀進行分支的範例
import torch
def true_fn(x: torch.Tensor):
    return x.cos() + x.sin()
def false_fn(x: torch.Tensor):
    return x.sin()
class DynamicShapeCondPredicate(torch.nn.Module):
    """
    A basic usage of cond based on dynamic shape predicate.
    """
    def __init__(self):
        super().__init__()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def true_fn(x: torch.Tensor):
            return x.cos()
        def false_fn(x: torch.Tensor):
            return x.sin()
        return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))
dyn_shape_mod = DynamicShapeCondPredicate()
我們可以急切地運行模型,並期望結果會根據輸入形狀而有所不同
inp = torch.randn(3)
inp2 = torch.randn(5)
assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))
我們可以匯出模型以進行進一步的轉換和部署
inp = torch.randn(4, 3)
dim_batch = torch.export.Dim("batch", min=2)
ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
print(ep)
這將為我們提供如下所示的匯出程序
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: f32[s0, 3]):
        sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
        gt: Sym(s0 > 4) = sym_size > 4;  sym_size = None
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
        return (conditional,)
    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
            add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
            return add
    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
            return sin
請注意,torch.cond 已降級為 torch.ops.higher_order.cond,其謂詞已成為輸入形狀的符號表達式,而分支函數已成為頂級圖形模組的兩個子圖形屬性。
以下是另一個展示如何表達數據依賴的控制流程的範例
class DataDependentCondPredicate(torch.nn.Module):
    """
    A basic usage of cond based on data dependent predicate.
    """
    def __init__(self):
        super().__init__()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))
我們在匯出後獲得的匯出程序
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: f32[s0, 3]):
        sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
        gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0);  sum_1 = None
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
        return (conditional,)
    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
            add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
            return add
    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
            return sin
torch.ops.higher_order.cond 的不變量¶
torch.ops.higher_order.cond 有幾個有用的不變量
- 對於謂詞
- 謂詞的動態性得以保留(例如,上述範例中顯示的 gt) 
- 如果用戶程序中的謂詞是常量(例如,Python 布林常量),則運算子的 pred 將是一個常量。 
 
 
- 對於分支
- 輸入和輸出簽章將是一個扁平化的元組。 
- 它們是 torch.fx.GraphModule。 
- 原始函數中的閉包會變成顯式的輸入。沒有閉包。 
- 不允許對輸入或全局變數進行突變。 
 
 
- 對於運算元
- 它也將是一個扁平化的元組。 
 
 
- 用戶程序中 torch.cond 的嵌套會變成嵌套的圖形模組。 
API 參考¶
- torch._higher_order_ops.cond.cond(pred, true_fn, false_fn, operands)¶
- 有條件地應用 true_fn 或 false_fn。 - 警告 - torch.cond 是 PyTorch 中的原型功能。它對輸入和輸出類型的支持有限,目前不支持訓練。請期待在未來版本的 PyTorch 中獲得更穩定的實現。在以下位置閱讀有關功能分類的更多信息: https://pytorch.com.tw/blog/pytorch-feature-classification-changes/#prototype - cond 是一個結構化控制流程運算子。也就是說,它類似於 Python if 語句,但對 true_fn、false_fn 和 operands 有一些限制,這些限制使其能夠使用 torch.compile 和 torch.export 進行捕獲。 - 假設滿足對 cond 參數的約束,則 cond 等效於以下內容 - def cond(pred, true_branch, false_branch, operands): if pred: return true_branch(*operands) else: return false_branch(*operands) - 參數
- **pred** (Union[bool, torch.Tensor]) – 一個布林表達式或具有一個元素的張量,指示要應用哪個分支函數。 
- **true_fn** (Callable) – 一個可調用函數 (a -> b),它在正在跟踪的範圍內。 
- **false_fn** (Callable) – 一個可調用函數 (a -> b),它在正在跟踪的範圍內。真分支和假分支必須具有一致的輸入和輸出,這意味著輸入必須相同,並且輸出必須具有相同的類型和形狀。 
- **operands** (Tuple of 可能是嵌套的 dict/list/tuple of torch.Tensor) – 真/假函數的輸入元組。 
 
 - 範例 - def true_fn(x: torch.Tensor): return x.cos() def false_fn(x: torch.Tensor): return x.sin() return cond(x.shape[0] > 4, true_fn, false_fn, (x,)) - 限制
- 條件語句(又稱 pred)必須滿足以下約束之一 - 它是一個只有一個元素且數據類型為 torch.bool 的 torch.Tensor 
- 它是一個布林表達式,例如 x.shape[0] > 10 或 x.dim() > 1 and x.shape[1] > 10 
 
- 分支函數(又稱 true_fn/false_fn)必須滿足以下所有約束 - 函數簽章必須與運算元匹配。 
- 函數必須返回具有相同元數據的張量,例如形狀、數據類型等。 
- 函數不能對輸入或全局變數進行就地突變。(注意:在分支中允許對中間結果進行就地張量運算,例如 add_) 
 
 
 - 警告 - 時間限制 - cond 目前僅支持**推論**。未來將支持 Autograd。 
- 分支的**輸出**必須是**單個張量**。未來將支持張量 Pytree。