快捷方式

torch.nn.utils.prune.global_unstructured

torch.nn.utils.prune.global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs)[source][source]

透過應用指定的 pruning_method,全域性剪枝 parameters 中所有引數對應的張量。

透過以下方式原地修改模組:

  1. 新增一個名為 name+'_mask' 的命名緩衝區,對應於剪枝方法應用於引數 name 的二值掩碼。

  2. 將引數 name 替換為其剪枝版本,而原始(未剪枝的)引數則儲存在一個名為 name+'_orig' 的新引數中。

引數
  • parameters (Iterable of (module, name) tuples) – 需要以全域性方式(即在決定剪枝哪些權重之前先聚合所有權重)剪枝的模型引數。module 必須是 nn.Module 型別,name 必須是字串。

  • pruning_method (function) – 此模組中有效的剪枝函式,或使用者實現的滿足實現指南且 PRUNING_TYPE='unstructured' 的自定義函式。

  • importance_scores (dict) – 一個字典,將 (module, name) 元組對映到相應的引數的重要性分數張量。該張量的形狀應與引數相同,用於計算剪枝掩碼。如果未指定或為 None,將使用引數本身作為其重要性分數。

  • kwargs – 其他關鍵字引數,例如:amount (int or float): 跨指定引數剪枝的數量。如果為 float,應介於 0.0 和 1.0 之間,表示要剪枝的引數比例。如果為 int,表示要剪枝的絕對引數數量。

引發

TypeError – 如果 PRUNING_TYPE != 'unstructured'

注意

由於全域性結構化剪枝在未對引數大小進行範數歸一化的情況下意義不大,我們目前將全域性剪枝的範圍限制在非結構化方法。

示例

>>> from torch.nn.utils import prune
>>> from collections import OrderedDict
>>> net = nn.Sequential(OrderedDict([
...     ('first', nn.Linear(10, 4)),
...     ('second', nn.Linear(4, 1)),
... ]))
>>> parameters_to_prune = (
...     (net.first, 'weight'),
...     (net.second, 'weight'),
... )
>>> prune.global_unstructured(
...     parameters_to_prune,
...     pruning_method=prune.L1Unstructured,
...     amount=10,
... )
>>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0))
tensor(10)

文件

查閱 PyTorch 全面的開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源