torch.autograd.function.FunctionCtx.mark_non_differentiable¶
- FunctionCtx.mark_non_differentiable(*args)[source][source]¶
將輸出標記為不可微分。
這個方法最多隻能呼叫一次,可以在
setup_context()或forward()方法中呼叫,並且所有引數都應該是 tensor 輸出。這將把輸出標記為不需要梯度,從而提高反向計算的效率。你仍然需要在
backward()方法中為每個輸出接受一個梯度,但這始終是一個零張量,其形狀與對應輸出的形狀相同。- 這用於例如從排序返回的索引。參見示例:
>>> class Func(Function): >>> @staticmethod >>> def forward(ctx, x): >>> sorted, idx = x.sort() >>> ctx.mark_non_differentiable(idx) >>> ctx.save_for_backward(x, idx) >>> return sorted, idx >>> >>> @staticmethod >>> @once_differentiable >>> def backward(ctx, g1, g2): # still need to accept g2 >>> x, idx = ctx.saved_tensors >>> grad_input = torch.zeros_like(x) >>> grad_input.index_add_(0, idx, g1) >>> return grad_input