注意
點選 此處 下載完整示例程式碼
torch.vmap¶
建立日期:2020 年 10 月 26 日 | 最後更新:2021 年 9 月 01 日 | 最後驗證:未驗證
本教程介紹了 torch.vmap,一個用於 PyTorch 操作的自動向量化工具。torch.vmap 是一個原型功能,目前無法處理許多用例;但是,我們希望收集其用例以改進設計。如果您正在考慮使用 torch.vmap 或認為它在某些方面會非常有用,請透過 https://github.com/pytorch/pytorch/issues/42368 聯絡我們。
那麼,什麼是 vmap?¶
vmap 是一個高階函式。它接受一個函式 func,並返回一個新函式,該函式將 func 對映到輸入資料的某個維度上。它在很大程度上受到了 JAX 的 vmap 的啟發。
語義上,vmap 將“對映”推入由 func 呼叫的 PyTorch 操作中,有效地向量化了這些操作。
import torch
# NB: vmap is only available on nightly builds of PyTorch.
# You can download one at pytorch.org if you're interested in testing it out.
from torch import vmap
vmap 的第一個用例是使處理程式碼中的批處理維度變得更容易。可以編寫一個對單個樣本執行的函式 func,然後使用 vmap(func) 將其提升為一個可以處理批次樣本的函式。但是,func 受制於許多限制:
它必須是函式式的(不能在其中修改 Python 資料結構),就地 (in-place) PyTorch 操作除外。
批次樣本必須以張量形式提供。這意味著 vmap 本身無法處理變長序列。
使用 vmap 的一個例子是計算批處理的點積。PyTorch 沒有提供批處理的 torch.dot API;與其徒勞地查閱文件,不如使用 vmap 構建一個新函式。
torch.dot # [D], [D] -> []
batched_dot = torch.vmap(torch.dot) # [N, D], [N, D] -> [N]
x, y = torch.randn(2, 5), torch.randn(2, 5)
batched_dot(x, y)
vmap 有助於隱藏批處理維度,從而簡化模型編寫體驗。
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)
# Note that model doesn't work with a batch of feature vectors because
# torch.dot must take 1D tensors. It's pretty easy to rewrite this
# to use `torch.matmul` instead, but if we didn't want to do that or if
# the code is more complicated (e.g., does some advanced indexing
# shenanigins), we can simply call `vmap`. `vmap` batches over ALL
# inputs, unless otherwise specified (with the in_dims argument,
# please see the documentation for more details).
def model(feature_vec):
# Very simple linear model with activation
return feature_vec.dot(weights).relu()
examples = torch.randn(batch_size, feature_size)
result = torch.vmap(model)(examples)
expected = torch.stack([model(example) for example in examples.unbind()])
assert torch.allclose(result, expected)
vmap 還可以幫助向量化之前難以或無法進行批次處理的計算。這引出了我們的第二個用例:批處理梯度計算。
PyTorch 的 autograd 引擎計算 vjps(向量-雅可比乘積)。使用 vmap,我們可以計算(批處理向量)- 雅可比乘積。
一個例子是計算完整的雅可比矩陣(這也可應用於計算完整的 Hessian 矩陣)。對於函式 f: R^N -> R^N 計算完整的雅可比矩陣通常需要呼叫 autograd.grad N 次,雅可比的每一行呼叫一次。
# Setup
N = 5
def f(x):
return x ** 2
x = torch.randn(N, requires_grad=True)
y = f(x)
basis_vectors = torch.eye(N)
# Sequential approach
jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
for v in basis_vectors.unbind()]
jacobian = torch.stack(jacobian_rows)
# Using `vmap`, we can vectorize the whole computation, computing the
# Jacobian in a single call to `autograd.grad`.
def get_vjp(v):
return torch.autograd.grad(y, x, v)[0]
jacobian_vmap = vmap(get_vjp)(basis_vectors)
assert torch.allclose(jacobian_vmap, jacobian)
vmap 的第三個主要用例是計算逐樣本梯度 (per-sample-gradients)。這是 vmap 原型目前無法高效能處理的功能。我們不確定計算逐樣本梯度的 API 應該是什麼樣子,如果您有想法,請在 https://github.com/pytorch/pytorch/issues/7786 中評論。
def model(sample, weight):
# do something...
return torch.dot(sample, weight)
def grad_sample(sample):
return torch.autograd.functional.vjp(lambda weight: model(sample), weight)[1]
# The following doesn't actually work in the vmap prototype. But it
# could be an API for computing per-sample-gradients.
# batch_of_samples = torch.randn(64, 5)
# vmap(grad_sample)(batch_of_samples)
指令碼總執行時間: ( 0 分 0.000 秒)