• 文件 >
  • torch.nn >
  • torch.nn.utils.parametrize.register_parametrization
快捷方式

torch.nn.utils.parametrize.register_parametrization

torch.nn.utils.parametrize.register_parametrization(module, tensor_name, parametrization, *, unsafe=False)[原始碼][原始碼]

在一個模組中為張量註冊引數化(parametrization)。

為簡單起見,假設 tensor_name="weight"。當訪問 module.weight 時,模組將返回引數化版本 parametrization(module.weight)。如果原始張量需要梯度,反向傳播將透過 parametrization 進行微分,並且最佳化器將相應地更新張量。

模組首次註冊引數化時,此函式會向模組新增一個型別為 ParametrizationList 的屬性 parametrizations

張量 weight 上的引數化列表可透過 module.parametrizations.weight 訪問。

原始張量可透過 module.parametrizations.weight.original 訪問。

透過在同一屬性上註冊多個引數化,可以將引數化串聯起來。

註冊的引數化的訓練模式在註冊時會更新,以匹配宿主模組的訓練模式。

引數化的引數和緩衝區有一個內建的快取系統,可以使用上下文管理器 cached() 啟用。

parametrization 可以選擇實現一個方法,簽名如下:

def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]

當註冊第一個引數化時,此方法在未引數化的張量上呼叫,以計算原始張量的初始值。如果未實現此方法,則原始張量即為未引數化的張量。

如果在張量上註冊的所有引數化都實現了 right_inverse,則可以透過賦值來初始化引數化張量,如下例所示。

第一個引數化可能依賴於多個輸入。這可以透過從 right_inverse 返回一個張量元組來實現(參見下面 RankOne 引數化的示例實現)。

在這種情況下,無約束張量也位於 module.parametrizations.weight 下,名稱分別為 original0original1 等。

注意

如果 unsafe=False(預設值),則會分別呼叫 forward 和 right_inverse 方法一次,以執行多項一致性檢查。如果 unsafe=True,則在張量未引數化時會呼叫 right_inverse,否則不會呼叫任何方法。

注意

在大多數情況下,right_inverse 將是一個函式,使得 forward(right_inverse(X)) == X(參見右逆)。有時,當引數化不是滿射時,放寬此要求可能是合理的。

警告

如果引數化依賴於多個輸入,register_parametrization() 將註冊一些新的引數。如果在建立最佳化器後註冊此類引數化,則需要手動將這些新引數新增到最佳化器中。參見 torch.Optimizer.add_param_group()

引數
  • module (nn.Module) – 要在其上註冊引數化的模組

  • tensor_name (str) – 要在其上註冊引數化的引數或緩衝區的名稱

  • parametrization (nn.Module) – 要註冊的引數化

關鍵字引數

unsafe (bool) – 一個布林標誌,表示引數化是否可以改變張量的資料型別和形狀。預設值:False 警告:註冊時不會檢查引數化的一致性。啟用此標誌的風險自負。

丟擲異常

ValueError – 如果模組沒有名為 tensor_name 的引數或緩衝區

返回型別

Module

示例

>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.utils.parametrize as P
>>>
>>> class Symmetric(nn.Module):
>>>     def forward(self, X):
>>>         return X.triu() + X.triu(1).T  # Return a symmetric matrix
>>>
>>>     def right_inverse(self, A):
>>>         return A.triu()
>>>
>>> m = nn.Linear(5, 5)
>>> P.register_parametrization(m, "weight", Symmetric())
>>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric
True
>>> A = torch.rand(5, 5)
>>> A = A + A.T   # A is now symmetric
>>> m.weight = A  # Initialize the weight to be the symmetric matrix A
>>> print(torch.allclose(m.weight, A))
True
>>> class RankOne(nn.Module):
>>>     def forward(self, x, y):
>>>         # Form a rank 1 matrix multiplying two vectors
>>>         return x.unsqueeze(-1) @ y.unsqueeze(-2)
>>>
>>>     def right_inverse(self, Z):
>>>         # Project Z onto the rank 1 matrices
>>>         U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
>>>         # Return rescaled singular vectors
>>>         s0_sqrt = S[0].sqrt().unsqueeze(-1)
>>>         return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
>>>
>>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
>>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
1

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源