快捷方式

RandomStructured

class torch.nn.utils.prune.RandomStructured(amount, dim=-1)[source][source]

隨機修剪張量中整個(當前未修剪的)通道。

引數
  • amount (intfloat) – 要修剪的引數數量。如果為 float,應介於 0.0 和 1.0 之間,表示要修剪的引數比例。如果為 int,則表示要修剪的引數的絕對數量。

  • dim (int,可選) – 定義要修剪通道的維度索引。預設值:-1。

classmethod apply(module, name, amount, dim=-1)[source][source]

動態新增修剪並重新引數化張量。

新增前向預鉤子,啟用動態修剪,並根據原始張量和修剪掩碼重新引數化張量。

引數
  • module (nn.Module) – 包含要修剪張量的模組

  • name (str) – 模組內將執行修剪的引數名稱。

  • amount (intfloat) – 要修剪的引數數量。如果為 float,應介於 0.0 和 1.0 之間,表示要修剪的引數比例。如果為 int,則表示要修剪的引數的絕對數量。

  • dim (int,可選) – 定義要修剪通道的維度索引。預設值:-1。

apply_mask(module)[source]

簡單處理被修剪引數與生成的掩碼之間的乘法。

從模組中獲取掩碼和原始張量,並返回張量的修剪版本。

引數

module (nn.Module) – 包含要修剪張量的模組

返回值

輸入張量的修剪版本

返回型別

pruned_tensor (torch.Tensor)

compute_mask(t, default_mask)[source][source]

計算並返回輸入張量 t 的掩碼。

從基礎 default_mask 開始(如果張量尚未修剪,則應為全一掩碼),透過隨機將張量沿指定維度上的通道歸零,生成一個隨機掩碼應用在 default_mask 之上。

引數
  • t (torch.Tensor) – 表示要修剪引數的張量

  • default_mask (torch.Tensor) – 來自先前修剪迭代的基礎掩碼,應用新掩碼後仍需遵守。與 t 具有相同的維度。

返回值

應用於 t 的掩碼,與 t 維度相同

返回型別

mask (torch.Tensor)

丟擲異常

IndexError – 如果 self.dim >= len(t.shape)

prune(t, default_mask=None, importance_scores=None)[source]

計算並返回輸入張量 t 的修剪版本。

根據 compute_mask() 中指定的修剪規則。

引數
  • t (torch.Tensor) – 要修剪的張量(與 default_mask 維度相同)。

  • importance_scores (torch.Tensor) – 重要性分數張量(與 t 形狀相同),用於計算修剪 t 的掩碼。此張量中的值表示待修剪的 t 中相應元素的重要性。如果未指定或為 None,則將使用張量 t 本身代替。

  • default_mask (torch.Tensor,可選) – 來自先前修剪迭代的掩碼(如果存在)。確定修剪應作用於張量的哪一部分時應考慮此掩碼。如果為 None,則預設為全一掩碼。

返回值

張量 t 的修剪版本。

remove(module)[source]

從模組中移除修剪重新引數化。

名為 name 的修剪引數永久保持修剪狀態,名為 name+'_orig' 的引數將從引數列表中移除。類似地,名為 name+'_mask' 的緩衝區將從緩衝區列表中移除。

注意

修剪本身不會被撤銷或反轉!

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源