快捷方式

torch.nn.utils.prune.identity

torch.nn.utils.prune.identity(module, name)[原始碼][原始碼]

應用剪枝重引數化,但不實際剪除任何單元。

module 中名為 name 的引數對應的張量應用剪枝重引數化,但不實際剪除任何單元。該函式會原地修改模組(並返回修改後的模組),具體操作包括:

  1. 新增一個名為 name+'_mask' 的具名 buffer,它對應於剪枝方法應用於引數 name 的二值掩碼。

  2. 將引數 name 替換為其剪枝後的版本,而原始(未剪枝)引數則儲存在一個名為 name+'_orig' 的新引數中。

注意

掩碼是一個全為 1 的張量。

引數
  • module (nn.Module) – 包含要剪枝張量的模組。

  • name (str) – 模組中將要進行剪枝操作的引數名稱。

返回

輸入模組的修改(即剪枝後)版本

返回型別

module (nn.Module)

示例

>>> m = prune.identity(nn.Linear(2, 3), 'bias')
>>> print(m.bias_mask)
tensor([1., 1., 1.])

文件

查閱 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源