快捷方式

torch.jit.fork

torch.jit.fork(func, *args, **kwargs)[源][源]

建立一個執行 func 的非同步任務,並返回對該執行結果值的引用。

函式 fork 會立即返回,因此 func 的返回值可能尚未計算完畢。要強制完成任務並訪問返回值,請在 Future 上呼叫 torch.jit.wait。以返回 Tfunc 呼叫的 fork 型別為 torch.jit.Future[T]fork 呼叫可以任意巢狀,並且可以使用位置引數和關鍵字引數進行呼叫。非同步執行僅在 TorchScript 中執行時才會發生。如果在純 Python 中執行,fork 將不會並行執行。fork 在追蹤(tracing)時也不會並行執行,但 forkwait 呼叫將被捕獲在匯出的 IR 圖中。

警告

fork 任務將以非確定性方式執行。我們建議僅對不修改輸入、模組屬性或全域性狀態的純函式建立並行 fork 任務。

引數
  • func (可呼叫物件或 torch.nn.Module) – 要呼叫的 Python 函式或 torch.nn.Module。如果在 TorchScript 中執行,它將非同步執行,否則不會。對 fork 的追蹤呼叫將被捕獲在 IR 中。

  • *args – 用於呼叫 func 的引數。

  • **kwargs – 用於呼叫 func 的引數。

返回

func 執行的引用。值 T 只能透過 torch.jit.wait 強制完成 func 後才能訪問。

返回型別

torch.jit.Future[T]

示例(分叉一個自由函式)

import torch
from torch import Tensor


def foo(a: Tensor, b: int) -> Tensor:
    return a + b


def bar(a):
    fut: torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
    return torch.jit.wait(fut)


script_bar = torch.jit.script(bar)
input = torch.tensor(2)
# only the scripted version executes asynchronously
assert script_bar(input) == bar(input)
# trace is not run asynchronously, but fork is captured in IR
graph = torch.jit.trace(bar, (input,)).graph
assert "fork" in str(graph)

示例(分叉一個模組方法)

import torch
from torch import Tensor


class AddMod(torch.nn.Module):
    def forward(self, a: Tensor, b: int):
        return a + b


class Mod(torch.nn.Module):
    def __init__(self) -> None:
        super(self).__init__()
        self.mod = AddMod()

    def forward(self, input):
        fut = torch.jit.fork(self.mod, a, b=2)
        return torch.jit.wait(fut)


input = torch.tensor(2)
mod = Mod()
assert mod(input) == torch.jit.script(mod).forward(input)

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並解答你的問題

檢視資源