快捷方式

複數

複數是可以表示為 a+bja + bj 形式的數,其中 a 和 b 是實數,j 是虛數單位,滿足方程 j2=1j^2 = -1。複數在數學和工程領域中頻繁出現,尤其是在訊號處理等主題中。傳統上,許多使用者和庫(例如 TorchAudio)透過使用形狀為 (...,2)(..., 2) 的浮點張量來表示資料,其中最後一個維度包含實部和虛部值來處理複數。

複數資料型別的張量在使用複數時提供了更自然的使用者體驗。對複數張量執行的操作(例如,torch.mv(), torch.matmul())可能比模擬它們的浮點張量操作更快、更節省記憶體。PyTorch 中涉及複數的運算已最佳化,以使用向量化彙編指令和專用核心(例如 LAPACK、cuBlas)。

注意

torch.fft 模組中的譜運算支援原生複數張量。

警告

複數張量是 Beta 特性,可能會有所更改。

建立複數張量

我們支援兩種複數資料型別:torch.cfloattorch.cdouble

>>> x = torch.randn(2,2, dtype=torch.cfloat)
>>> x
tensor([[-0.4621-0.0303j, -0.2438-0.5874j],
     [ 0.7706+0.1421j,  1.2110+0.1918j]])

注意

複數張量的預設資料型別由預設浮點資料型別決定。如果預設浮點資料型別是 torch.float64,則推斷複數的資料型別為 torch.complex128,否則假定其資料型別為 torch.complex64

torch.linspace()torch.logspace()torch.arange() 外,所有工廠函式都支援複數張量。

從舊錶示轉換

目前使用形狀為 (...,2)(..., 2) 的實數張量來解決缺少複數張量問題的使用者,可以使用 torch.view_as_complex()torch.view_as_real() 輕鬆地在其程式碼中切換到使用複數張量。請注意,這些函式不執行任何複製,並返回輸入張量的檢視。

>>> x = torch.randn(3, 2)
>>> x
tensor([[ 0.6125, -0.1681],
     [-0.3773,  1.3487],
     [-0.0861, -0.7981]])
>>> y = torch.view_as_complex(x)
>>> y
tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])
>>> torch.view_as_real(y)
tensor([[ 0.6125, -0.1681],
     [-0.3773,  1.3487],
     [-0.0861, -0.7981]])

訪問實部和虛部

可以使用 realimag 屬性訪問複數張量的實部和虛部值。

注意

訪問 realimag 屬性不會分配任何記憶體,並且對 realimag 張量進行原地更新將更新原始複數張量。此外,返回的 realimag 張量不是連續的。

>>> y.real
tensor([ 0.6125, -0.3773, -0.0861])
>>> y.imag
tensor([-0.1681,  1.3487, -0.7981])

>>> y.real.mul_(2)
tensor([ 1.2250, -0.7546, -0.1722])
>>> y
tensor([ 1.2250-0.1681j, -0.7546+1.3487j, -0.1722-0.7981j])
>>> y.real.stride()
(2,)

幅角和模長

可以使用 torch.angle()torch.abs() 計算複數張量的幅角和絕對值。

>>> x1=torch.tensor([3j, 4+4j])
>>> x1.abs()
tensor([3.0000, 5.6569])
>>> x1.angle()
tensor([1.5708, 0.7854])

線性代數

許多線性代數運算,例如 torch.matmul()torch.linalg.svd()torch.linalg.solve() 等,都支援複數。如果您希望請求我們目前不支援的運算,請搜尋是否已提交過相關問題,如果尚未提交,請提交一個

序列化

複數張量可以被序列化,從而允許資料以複數形式儲存。

>>> torch.save(y, 'complex_tensor.pt')
>>> torch.load('complex_tensor.pt')
tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])

自動微分

PyTorch 支援複數張量的自動微分。計算出的梯度是 Conjugate Wirtinger 導數,其負數恰好是梯度下降演算法中使用的最速下降方向。因此,所有現有最佳化器都可以直接用於複數引數。有關更多詳細資訊,請檢視註釋 複數自動微分

最佳化器

從語義上講,我們定義對具有複數引數的 PyTorch 最佳化器執行步驟等同於對這些複數引數的 torch.view_as_real() 等效項執行相同的最佳化器步驟。更具體地說,

>>> params = [torch.rand(2, 3, dtype=torch.complex64) for _ in range(5)]
>>> real_params = [torch.view_as_real(p) for p in params]

>>> complex_optim = torch.optim.AdamW(params)
>>> real_optim = torch.optim.AdamW(real_params)

real_optimcomplex_optim 將計算對引數的相同更新,儘管兩個最佳化器之間可能存在細微的數值差異,類似於 foreach 與 forloop 最佳化器以及 capturable 與預設最佳化器之間的數值差異。有關更多詳細資訊,請參閱 https://pytorch.com.tw/docs/stable/notes/numerical_accuracy.html

具體來說,雖然您可以認為我們的最佳化器處理複數張量與分別對它們的 p.realp.imag 部分進行最佳化是相同的,但實現細節並非完全如此。請注意,torch.view_as_real() 等效項會將複數張量轉換為形狀為 (...,2)(..., 2) 的實數張量,而將複數張量拆分為兩個張量是 2 個大小為 (...)(...) 的張量。這種區別對逐點最佳化器(如 AdamW)沒有影響,但會導致執行全域性歸約(如 LBFGS)的最佳化器產生細微差異。我們目前沒有執行逐張量歸約的最佳化器,因此尚未定義此行為。如果您有需要精確定義此行為的用例,請提交一個問題。

我們尚未完全支援以下子系統

  • 量化

  • JIT

  • 稀疏張量

  • 分散式

如果其中任何一項有助於您的用例,請搜尋是否已提交過相關問題,如果尚未提交,請提交一個

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取適合初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源