快捷方式

torch.autograd.function.FunctionCtx.mark_non_differentiable

FunctionCtx.mark_non_differentiable(*args)[source][source]

將輸出標記為不可微分。

這個方法最多隻能呼叫一次,可以在 setup_context()forward() 方法中呼叫,並且所有引數都應該是 tensor 輸出。

這將把輸出標記為不需要梯度,從而提高反向計算的效率。你仍然需要在 backward() 方法中為每個輸出接受一個梯度,但這始終是一個零張量,其形狀與對應輸出的形狀相同。

這用於例如從排序返回的索引。參見示例:
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         sorted, idx = x.sort()
>>>         ctx.mark_non_differentiable(idx)
>>>         ctx.save_for_backward(x, idx)
>>>         return sorted, idx
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):  # still need to accept g2
>>>         x, idx = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         grad_input.index_add_(0, idx, g1)
>>>         return grad_input

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源