快捷方式

torch.cond

torch.cond(pred, true_fn, false_fn, operands=())[source]

有條件地應用 true_fnfalse_fn

警告

torch.cond 是 PyTorch 中的一個原型功能。目前它對輸入和輸出型別的支援有限,且不支援訓練。敬請期待 PyTorch 未來版本中更穩定的實現。閱讀更多關於功能分類的資訊:https://pytorch.com.tw/blog/pytorch-feature-classification-changes/#prototype

cond 是一個結構化控制流運算元。也就是說,它類似於 Python 的 if 語句,但對 true_fnfalse_fnoperands 有限制,這些限制使其可以使用 torch.compile 和 torch.export 進行捕獲。

假設 cond 引數的約束條件滿足,cond 等效於以下內容

def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)
引數
  • pred (Union[bool, torch.Tensor]) – 一個布林表示式或一個只有一個元素的 tensor,指示應用哪個分支函式。

  • true_fn (Callable) – 一個可呼叫函式 (a -> b),位於正在被跟蹤的作用域內。

  • false_fn (Callable) – 一個可呼叫函式 (a -> b),位於正在被跟蹤的作用域內。真分支和假分支的輸入和輸出必須一致,這意味著輸入必須相同,並且輸出必須具有相同的型別和形狀。

  • operands (Tuple of possibly nested dict/list/tuple of torch.Tensor) – true/false 函式的輸入元組。如果 true_fn/false_fn 不需要輸入,它可以是空的。預設為 ()。

返回型別

Any

示例

def true_fn(x: torch.Tensor):
    return x.cos()
def false_fn(x: torch.Tensor):
    return x.sin()
return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
限制條件
  • 條件語句(即 pred)必須滿足以下約束條件之一

    • 它是一個只有一個元素的 torch.Tensor,且資料型別為 torch.bool

    • 它是一個布林表示式,例如 x.shape[0] > 10x.dim() > 1 and x.shape[1] > 10

  • 分支函式(即 true_fn/false_fn)必須滿足以下所有約束條件

    • 函式簽名必須與 operands 匹配。

    • 函式必須返回具有相同元資料(例如形狀、資料型別等)的 tensor。

    • 函式不能對輸入或全域性變數進行原地修改。(注意:分支中允許使用原地 tensor 操作,例如用於中間結果的 add_

文件

查閱全面的 PyTorch 開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源