快捷方式

torch.nn.utils.prune.random_unstructured

torch.nn.utils.prune.random_unstructured(module, name, amount)[source][source]

透過移除隨機(當前未修剪的)單元來修剪張量。

透過移除隨機選擇的指定 amount 數量的(當前未修剪的)單元,修剪 module 中名為 name 的引數對應的張量。此函式會就地修改模組(並返回修改後的模組),具體方式如下:

  1. 新增一個名為 name+'_mask' 的命名緩衝區,對應於剪枝方法應用於引數 name 的二進位制掩碼。

  2. 將引數 name 替換為其修剪後的版本,同時將原始(未修剪的)引數儲存在一個名為 name+'_orig' 的新引數中。

引數
  • module (nn.Module) – 包含要修剪的張量的模組

  • name (str) – module 中將進行剪枝操作的引數名稱。

  • amount (intfloat) – 要修剪的引數數量。如果為 float,應介於 0.0 和 1.0 之間,表示要修剪的引數比例。如果為 int,表示要修剪的引數的絕對數量。

返回

輸入模組的修改(即修剪後的)版本

返回型別

module (nn.Module)

示例

>>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
>>> torch.sum(m.weight_mask == 0)
tensor(1)

文件

查閱 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並解答疑問

檢視資源