廣播語義¶
許多 PyTorch 運算支援 NumPy 的廣播語義。如需詳細資訊,請參閱 https://numpy.org/doc/stable/user/basics.broadcasting.html。
簡而言之,如果 PyTorch 運算支援廣播,則其張量引數可以自動擴展為相同的大小(無需複製資料)。
一般語義¶
如果滿足以下規則,則兩個張量「可廣播」
- 每個張量至少有一個維度。 
- 從尾隨維度開始迭代維度大小時,維度大小必須相等,其中一個為 1,或者其中一個不存在。 
例如
>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)
>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension
# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist
# but:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3
如果兩個張量 x、y「可廣播」,則結果張量大小的計算方式如下
- 如果 - x和- y的維度數不相等,則在維度較少的張量的維度前面加上 1,使其長度相等。
- 然後,對於每個維度大小,結果維度大小是 - x和- y沿該維度的最大大小。
例如
# can line up trailing dimensions to make reading easier
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty(  3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])
# but not necessary:
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
就地語義¶
一個複雜因素是,就地運算不允許就地張量因廣播而改變形狀。
例如
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])
# but:
>>> x=torch.empty(1,3,1)
>>> y=torch.empty(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.
回溯相容性¶
先前版本的 PyTorch 允許某些逐點函數在形狀不同的張量上執行,只要每個張量的元素數量相等即可。然後,透過將每個張量視為一維來執行逐點運算。PyTorch 現在支援廣播,「一維」逐點行為被認為已棄用,並且在張量不可廣播但元素數量相同的情況下會產生 Python 警告。
請注意,在兩個張量形狀不同,但可廣播且元素數量相同的情況下,引入廣播可能會導致回溯不相容的變更。例如
>>> torch.add(torch.ones(4,1), torch.randn(4))
先前會產生大小為:torch.Size([4,1]) 的張量,但現在會產生大小為:torch.Size([4,4]) 的張量。為了協助識別程式碼中可能存在廣播引入的回溯不相容性的情況,您可以將 torch.utils.backcompat.broadcast_warning.enabled 設定為 True,這會在這種情況下產生 Python 警告。
例如
>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.