torch.autograd.function.FunctionCtx.save_for_backward¶
- FunctionCtx.save_for_backward(*tensors)[原始碼][原始碼]¶
儲存給定的張量供將來呼叫
backward()時使用。save_for_backward最多隻能呼叫一次,可以在setup_context()或forward()方法中呼叫,並且只能用於張量。所有打算在反向傳播中使用張量都應該透過
save_for_backward進行儲存(而不是直接儲存在ctx上),以防止不正確的梯度和記憶體洩漏,並啟用儲存張量鉤子的應用。參見torch.autograd.graph.saved_tensors_hooks。請注意,如果中間張量(既不是
forward()的輸入也不是輸出的張量)被儲存用於反向傳播,您的自定義 Function 可能不支援二次反向傳播。不支援二次反向傳播的自定義 Function 應該使用@once_differentiable裝飾其backward()方法,以便在執行二次反向傳播時引發錯誤。如果您想支援二次反向傳播,可以在反向傳播期間根據輸入重新計算中間張量,或者將中間張量作為自定義 Function 的輸出返回。更多詳細資訊請參閱二次反向傳播教程。在
backward()中,可以透過saved_tensors屬性訪問儲存的張量。在將其返回給使用者之前,會進行檢查以確保它們未在任何修改其內容的原地操作中使用。引數也可以是
None。這是一個空操作。有關如何使用此方法的更多詳細資訊,請參閱擴充套件 torch.autograd。
- 示例:
>>> class Func(Function): >>> @staticmethod >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int): >>> w = x * z >>> out = x * y + y * z + w * y >>> ctx.save_for_backward(x, y, w, out) >>> ctx.z = z # z is not a tensor >>> return out >>> >>> @staticmethod >>> @once_differentiable >>> def backward(ctx, grad_out): >>> x, y, w, out = ctx.saved_tensors >>> z = ctx.z >>> gx = grad_out * (y + y * z) >>> gy = grad_out * (x + z + w) >>> gz = None >>> return gx, gy, gz >>> >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double) >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double) >>> c = 4 >>> d = Func.apply(a, b, c)