torch.nn.utils.prune.random_structured¶
- torch.nn.utils.prune.random_structured(module, name, amount, dim)[source][source]¶
透過移除指定維度上的隨機通道來修剪張量。
透過移除指定
amount數量的(當前未修剪的)通道(沿著指定的dim隨機選擇),來修剪module中名為name的引數所對應的張量。透過以下方式原地修改模組(並返回修改後的模組):新增一個名為
name+'_mask'的命名 buffer,它對應於修剪方法應用於引數name的二進位制掩碼。將引數
name替換為其修剪後的版本,而原始(未修剪的)引數則儲存在一個名為name+'_orig'的新引數中。
- 引數
- 返回
輸入模組的修改(即已修剪的)版本
- 返回型別
模組 (nn.Module)
示例
>>> m = prune.random_structured( ... nn.Linear(5, 3), 'weight', amount=3, dim=1 ... ) >>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0)) >>> print(columns_pruned) 3