快捷方式

廣播語義

許多 PyTorch 操作支援 NumPy 的廣播語義。詳情請參見 https://numpy.org/doc/stable/user/basics.broadcasting.html

簡而言之,如果一個 PyTorch 操作支援廣播,那麼其張量引數可以自動擴充套件到相同大小(無需複製資料)。

一般語義

如果滿足以下規則,則兩個張量“可廣播”:

  • 每個張量至少有一個維度。

  • 當從末尾維度開始迭代維度大小時,維度大小必須相等,或者其中一個維度大小為 1,或者其中一個維度不存在。

例如

>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension

# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist

# but:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3

如果兩個張量 xy “可廣播”,則結果張量的大小按以下方式計算:

  • 如果 xy 的維度數量不相等,則在維度較少的張量前面補 1,使其維度長度相等。

  • 然後,對於每個維度大小,結果維度大小是 xy 在該維度上的大小的最大值。

例如

# can line up trailing dimensions to make reading easier
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty(  3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])

# but not necessary:
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])

>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

就地語義

一個複雜之處在於,就地操作不允許就地張量因廣播而改變形狀。

例如

>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])

# but:
>>> x=torch.empty(1,3,1)
>>> y=torch.empty(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.

向後相容性

PyTorch 的早期版本允許某些逐元素函式在形狀不同的張量上執行,只要每個張量中的元素數量相等即可。然後,逐元素操作透過將每個張量視為 1 維來執行。PyTorch 現在支援廣播,並且“1 維”逐元素行為被認為是已棄用的,在張量不可廣播但元素數量相等的情況下會生成 Python 警告。

請注意,引入廣播可能會導致在兩個張量形狀不同但可廣播且元素數量相等的情況下出現向後不相容的更改。例如

>>> torch.add(torch.ones(4,1), torch.randn(4))

之前會產生一個大小為 torch.Size([4,1]) 的張量,但現在產生一個大小為 torch.Size([4,4]) 的張量。為了幫助識別程式碼中可能存在因廣播引入的向後不相容情況,可以將 torch.utils.backcompat.broadcast_warning.enabled 設定為 True,這會在此類情況下生成 Python 警告。

例如

>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源