快捷方式

torch.autograd.function.FunctionCtx.save_for_backward

FunctionCtx.save_for_backward(*tensors)[原始碼][原始碼]

儲存給定的張量供將來呼叫 backward() 時使用。

save_for_backward 最多隻能呼叫一次,可以在 setup_context()forward() 方法中呼叫,並且只能用於張量。

所有打算在反向傳播中使用張量都應該透過 save_for_backward 進行儲存(而不是直接儲存在 ctx 上),以防止不正確的梯度和記憶體洩漏,並啟用儲存張量鉤子的應用。參見 torch.autograd.graph.saved_tensors_hooks

請注意,如果中間張量(既不是 forward() 的輸入也不是輸出的張量)被儲存用於反向傳播,您的自定義 Function 可能不支援二次反向傳播。不支援二次反向傳播的自定義 Function 應該使用 @once_differentiable 裝飾其 backward() 方法,以便在執行二次反向傳播時引發錯誤。如果您想支援二次反向傳播,可以在反向傳播期間根據輸入重新計算中間張量,或者將中間張量作為自定義 Function 的輸出返回。更多詳細資訊請參閱二次反向傳播教程

backward() 中,可以透過 saved_tensors 屬性訪問儲存的張量。在將其返回給使用者之前,會進行檢查以確保它們未在任何修改其內容的原地操作中使用。

引數也可以是 None。這是一個空操作。

有關如何使用此方法的更多詳細資訊,請參閱擴充套件 torch.autograd

示例:
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         w = x * z
>>>         out = x * y + y * z + w * y
>>>         ctx.save_for_backward(x, y, w, out)
>>>         ctx.z = z  # z is not a tensor
>>>         return out
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_out):
>>>         x, y, w, out = ctx.saved_tensors
>>>         z = ctx.z
>>>         gx = grad_out * (y + y * z)
>>>         gy = grad_out * (x + z + w)
>>>         gz = None
>>>         return gx, gy, gz
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>> d = Func.apply(a, b, c)

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並解答您的疑問

檢視資源