複數¶
複數是可以表示為 形式的數字,其中 a 和 b 是實數,而 *j* 稱為虛數單位,滿足方程式 。複數經常出現在數學和工程學中,特別是在信號處理等主題中。傳統上,許多使用者和函式庫(例如 TorchAudio)透過使用形狀為 的浮點數張量表示資料來處理複數,其中最後一個維度包含實部和虛部值。
複數 dtype 的張量在處理複數時提供了更自然的使用者體驗。與模擬它們的浮點數張量上的運算相比,複數張量上的運算(例如 torch.mv()、torch.matmul())可能會更快且記憶體效率更高。PyTorch 中涉及複數的運算已最佳化,可使用向量化組譯指令和專用的核心(例如 LAPACK、cuBlas)。
備註
torch.fft 模組 中的頻譜運算支援原生複數張量。
警告
複數張量是一項測試功能,可能會有所變更。
建立複數張量¶
我們支援兩種複數 dtype:torch.cfloat 和 torch.cdouble
>>> x = torch.randn(2,2, dtype=torch.cfloat)
>>> x
tensor([[-0.4621-0.0303j, -0.2438-0.5874j],
     [ 0.7706+0.1421j,  1.2110+0.1918j]])
備註
複數張量的預設 dtype 由預設浮點數 dtype 決定。如果預設浮點數 dtype 為 torch.float64,則複數的 dtype 為 torch.complex128,否則假設它們的 dtype 為 torch.complex64。
除了 torch.linspace()、torch.logspace() 和 torch.arange() 之外,所有工廠函數都支援複數張量。
從舊表示法轉換¶
目前使用形狀為  的實數張量來解決缺乏複數張量問題的使用者,可以使用 torch.view_as_complex() 和 torch.view_as_real() 輕鬆地在程式碼中切換使用複數張量。請注意,這些函數不會執行任何複製,而是傳回輸入張量的檢視。
>>> x = torch.randn(3, 2)
>>> x
tensor([[ 0.6125, -0.1681],
     [-0.3773,  1.3487],
     [-0.0861, -0.7981]])
>>> y = torch.view_as_complex(x)
>>> y
tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])
>>> torch.view_as_real(y)
tensor([[ 0.6125, -0.1681],
     [-0.3773,  1.3487],
     [-0.0861, -0.7981]])
存取實部和虛部¶
可以使用 real 和 imag 來存取複數張量的實部和虛部值。
備註
存取 real 和 imag 屬性不會分配任何記憶體,並且對 real 和 imag 張量的就地更新將會更新原始的複數張量。此外,返回的 real 和 imag 張量不是連續的。
>>> y.real
tensor([ 0.6125, -0.3773, -0.0861])
>>> y.imag
tensor([-0.1681,  1.3487, -0.7981])
>>> y.real.mul_(2)
tensor([ 1.2250, -0.7546, -0.1722])
>>> y
tensor([ 1.2250-0.1681j, -0.7546+1.3487j, -0.1722-0.7981j])
>>> y.real.stride()
(2,)
角度和絕對值¶
可以使用 torch.angle() 和 torch.abs() 計算複數張量的角度和絕對值。
>>> x1=torch.tensor([3j, 4+4j])
>>> x1.abs()
tensor([3.0000, 5.6569])
>>> x1.angle()
tensor([1.5708, 0.7854])
線性代數¶
許多線性代數運算,例如 torch.matmul()、torch.linalg.svd()、torch.linalg.solve() 等,都支援複數。如果您想請求我們目前不支援的運算,請先 搜尋 是否已有相關的問題被提出,如果沒有,請 提出一個新問題。
序列化¶
複數張量可以被序列化,允許將資料儲存為複數值。
>>> torch.save(y, 'complex_tensor.pt')
>>> torch.load('complex_tensor.pt')
tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])
自動微分¶
PyTorch 支援複數張量的自動微分。計算出的梯度是共軛 Wirtinger 導數,其負數恰好是梯度下降演算法中使用的最速下降方向。因此,所有現有的優化器都可以直接用於具有複數參數的情況。如需更多詳細資訊,請查看備註 複數的自動微分。
優化器¶
在語義上,我們將使用複數參數逐步執行 PyTorch 優化器定義為等同於在複數參數的 torch.view_as_real() 等效項上逐步執行相同的優化器。更具體地說
>>> params = [torch.rand(2, 3, dtype=torch.complex64) for _ in range(5)]
>>> real_params = [torch.view_as_real(p) for p in params]
>>> complex_optim = torch.optim.AdamW(params)
>>> real_optim = torch.optim.AdamW(real_params)
real_optim 和 complex_optim 將計算對參數相同的更新,儘管兩個優化器之間可能會有一些細微的數值差異,類似於 foreach 與 for 迴圈優化器以及可捕獲與預設優化器之間的數值差異。如需更多詳細資訊,請參閱 https://pytorch.com.tw/docs/stable/notes/numerical_accuracy.html。
具體來說,雖然您可以將我們的優化器處理複數張量的方式視為與分別優化其 p.real 和 p.imag 部分相同,但實作細節並非完全如此。請注意,torch.view_as_real() 等效項會將複數張量轉換為形狀為  的實數張量,而將複數張量拆分為兩個張量則是兩個大小為  的張量。這種區別對逐點優化器(如 AdamW)沒有影響,但會在執行全局縮減的優化器(如 LBFGS)中造成輕微的差異。我們目前沒有執行逐個張量縮減的優化器,因此尚未定義此行為。如果您有需要精確定義此行為的用例,請提出一個問題。
我們不完全支援以下子系統
- 量化 
- JIT 
- 稀疏張量 
- 分散式