快捷方式

RandomUnstructured

class torch.nn.utils.prune.RandomUnstructured(amount)[source][source]

隨機修剪張量中(當前未修剪的)單元。

引數
  • name (str) – module 中要進行修剪操作的引數名稱。

  • amount (intfloat) – 要修剪的引數數量。如果為 float,應在 0.0 和 1.0 之間,表示要修剪的引數比例。如果為 int,則表示要修剪的引數絕對數量。

classmethod apply(module, name, amount)[source][source]

動態新增修剪和張量的重引數化。

添加了前向預鉤子(forward pre-hook),該鉤子支援動態修剪以及將張量重引數化為原始張量和修剪掩碼的形式。

引數
  • module (nn.Module) – 包含要修剪的張量的模組

  • name (str) – module 中要進行修剪操作的引數名稱。

  • amount (intfloat) – 要修剪的引數數量。如果為 float,應在 0.0 和 1.0 之間,表示要修剪的引數比例。如果為 int,則表示要修剪的引數絕對數量。

apply_mask(module)[source]

簡單地處理要修剪的引數與生成的掩碼之間的乘法。

從模組中獲取掩碼和原始張量,並返回張量的修剪版本。

引數

module (nn.Module) – 包含要修剪的張量的模組

返回

輸入張量的修剪版本

返回型別

pruned_tensor (torch.Tensor)

prune(t, default_mask=None, importance_scores=None)[source]

計算並返回輸入張量 t 的修剪版本。

根據 compute_mask() 中指定的修剪規則。

引數
  • t (torch.Tensor) – 要修剪的張量(與 default_mask 維度相同)。

  • importance_scores (torch.Tensor) – 用於計算修剪 t 掩碼的重要性分數張量(與 t 形狀相同)。此張量中的值表示 t 中相應元素的重要性,該 t 正在被修剪。如果未指定或為 None,將使用張量 t 代替。

  • default_mask (torch.Tensor, 可選) – 先前修剪迭代(如果存在)的掩碼。在確定修剪應作用於張量的哪一部分時需要考慮。如果為 None,預設為全 1 掩碼。

返回

張量 t 的修剪版本。

remove(module)[source]

從模組中移除修剪重引數化。

名為 name 的修剪引數保持永久修剪狀態,名為 name+'_orig' 的引數從引數列表中移除。類似地,名為 name+'_mask' 的 buffer 從 buffers 中移除。

注意

修剪本身不會被撤銷或反轉!

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源