快捷方式

torch.autograd.Function.vmap

static Function.vmap(info, in_dims, *args)[source][source]

定義此 autograd.Function 在 torch.vmap() 下的行為。

對於一個 torch.autograd.Function() 要支援 torch.vmap(),你必須要麼覆蓋這個靜態方法,要麼將 generate_vmap_rule 設定為 True(不能同時執行這兩項操作)。

如果你選擇覆蓋此靜態方法:它必須接受

  • 一個 info 物件作為第一個引數。 info.batch_size 指定正在進行 vmap 的維度的批大小,而 info.randomness 是傳遞給 torch.vmap() 的隨機性選項。

  • 一個 in_dims 元組作為第二個引數。對於 args 中的每個引數,in_dims 都有一個對應的 Optional[int]。如果該引數不是 Tensor 或未進行 vmap 操作,則該值為 None;否則,它是一個整數,指定 Tensor 的哪個維度正在進行 vmap 操作。

  • *args,與 forward() 方法的引數相同。

vmap 靜態方法的返回是一個 (output, out_dims) 元組。與 in_dims 類似,out_dims 的結構應與 output 相同,併為每個輸出包含一個 out_dim,指定輸出是否包含 vmap 維度以及該維度所在的索引。

更多詳細資訊請參閱使用 autograd.Function 擴充套件 torch.func

文件

獲取 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源