捷徑

廣播語義

許多 PyTorch 運算支援 NumPy 的廣播語義。如需詳細資訊,請參閱 https://numpy.org/doc/stable/user/basics.broadcasting.html

簡而言之,如果 PyTorch 運算支援廣播,則其張量引數可以自動擴展為相同的大小(無需複製資料)。

一般語義

如果滿足以下規則,則兩個張量「可廣播」

  • 每個張量至少有一個維度。

  • 從尾隨維度開始迭代維度大小時,維度大小必須相等,其中一個為 1,或者其中一個不存在。

例如

>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension

# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist

# but:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3

如果兩個張量 xy「可廣播」,則結果張量大小的計算方式如下

  • 如果 xy 的維度數不相等,則在維度較少的張量的維度前面加上 1,使其長度相等。

  • 然後,對於每個維度大小,結果維度大小是 xy 沿該維度的最大大小。

例如

# can line up trailing dimensions to make reading easier
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty(  3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])

# but not necessary:
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])

>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

就地語義

一個複雜因素是,就地運算不允許就地張量因廣播而改變形狀。

例如

>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])

# but:
>>> x=torch.empty(1,3,1)
>>> y=torch.empty(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.

回溯相容性

先前版本的 PyTorch 允許某些逐點函數在形狀不同的張量上執行,只要每個張量的元素數量相等即可。然後,透過將每個張量視為一維來執行逐點運算。PyTorch 現在支援廣播,「一維」逐點行為被認為已棄用,並且在張量不可廣播但元素數量相同的情況下會產生 Python 警告。

請注意,在兩個張量形狀不同,但可廣播且元素數量相同的情況下,引入廣播可能會導致回溯不相容的變更。例如

>>> torch.add(torch.ones(4,1), torch.randn(4))

先前會產生大小為:torch.Size([4,1]) 的張量,但現在會產生大小為:torch.Size([4,4]) 的張量。為了協助識別程式碼中可能存在廣播引入的回溯不相容性的情況,您可以將 torch.utils.backcompat.broadcast_warning.enabled 設定為 True,這會在這種情況下產生 Python 警告。

例如

>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得適用於初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源