快捷方式

torch.nn.utils.prune.custom_from_mask

torch.nn.utils.prune.custom_from_mask(module, name, mask)[source][source]

透過應用 mask 中的預計算掩碼,對 module 中名為 name 的引數對應的張量進行剪枝。

透過以下方式就地修改模組(並返回修改後的模組):

  1. 新增一個名為 name+'_mask' 的命名緩衝區,對應於剪枝方法應用於引數 name 的二進位制掩碼。

  2. 將其引數 name 替換為其剪枝後的版本,同時將原始(未剪枝的)引數儲存在一個名為 name+'_orig' 的新引數中。

引數
  • module (nn.Module) – 包含要剪枝張量的模組

  • name (str) – 模組 module 中將對其進行剪枝的引數名稱。

  • mask (Tensor) – 要應用於引數的二進位制掩碼。

返回值

輸入模組的修改(即已剪枝)版本

返回值型別

module (nn.Module)

示例

>>> from torch.nn.utils import prune
>>> m = prune.custom_from_mask(
...     nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
... )
>>> print(m.bias_mask)
tensor([0., 1., 0.])

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源