快捷方式

torch.nn.utils.clip_grad_norm_

torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None)[原始碼][原始碼]

裁剪可迭代引數的梯度範數。

範數是根據所有引數的單獨梯度的範數計算的,就像將單獨梯度的範數連線成一個向量一樣。梯度會在原地修改。

此函式等效於先呼叫 torch.nn.utils.get_total_norm(),然後使用 get_total_norm 返回的 total_norm 呼叫 torch.nn.utils.clip_grads_with_norm_()

引數
  • parameters (Iterable[Tensor] or Tensor) – 一個 Tensor 的可迭代物件或單個 Tensor,將對其梯度進行歸一化

  • max_norm (float) – 梯度的最大範數

  • norm_type (float) – 使用的 p-範數型別。對於無窮範數,可以是 'inf'

  • error_if_nonfinite (bool) – 如果為 True,則當 parameters 中梯度的總範數為 naninf-inf 時會丟擲錯誤。預設值:False (未來會切換為 True)

  • foreach (bool) – 使用更快的基於 foreach 的實現。如果為 None,則對 CUDA 和 CPU 原生 Tensor 使用 foreach 實現,對其他裝置型別靜默回退到慢速實現。預設值:None

返回值

引數梯度的總範數(視為一個向量)。

返回值型別

Tensor

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源