torch.nn.utils.prune.identity¶
- torch.nn.utils.prune.identity(module, name)[原始碼][原始碼]¶
應用剪枝重引數化,但不實際剪除任何單元。
對
module中名為name的引數對應的張量應用剪枝重引數化,但不實際剪除任何單元。該函式會原地修改模組(並返回修改後的模組),具體操作包括:新增一個名為
name+'_mask'的具名 buffer,它對應於剪枝方法應用於引數name的二值掩碼。將引數
name替換為其剪枝後的版本,而原始(未剪枝)引數則儲存在一個名為name+'_orig'的新引數中。
注意
掩碼是一個全為 1 的張量。
- 引數
- 返回
輸入模組的修改(即剪枝後)版本
- 返回型別
module (nn.Module)
示例
>>> m = prune.identity(nn.Linear(2, 3), 'bias') >>> print(m.bias_mask) tensor([1., 1., 1.])