• 文件 >
  • 在 XLA 設備上執行 PyTorch
捷徑

在 XLA 設備上執行 PyTorch

PyTorch 使用 torch_xla 套件 在 XLA 設備(如 TPU)上執行。本文檔描述如何在這些設備上執行您的模型。

建立 XLA 張量

PyTorch/XLA 為 PyTorch 新增了一個新的 xla 設備類型。此設備類型的工作方式與其他 PyTorch 設備類型相同。例如,以下是如何建立和列印 XLA 張量

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)

這段程式碼看起來應該很熟悉。PyTorch/XLA 使用與常規 PyTorch 相同的介面,並新增了一些功能。匯入 torch_xla 會初始化 PyTorch/XLA,而 xm.xla_device() 會傳回目前的 XLA 設備。根據您的環境,這可能是 CPU 或 TPU。

XLA 張量是 PyTorch 張量

PyTorch 操作可以在 XLA 張量上執行,就像 CPU 或 CUDA 張量一樣。

例如,可以將 XLA 張量加在一起

t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)

或進行矩陣乘法

print(t0.mm(t1))

或與神經網路模組一起使用

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)

與其他設備類型一樣,XLA 張量只能與相同設備上的其他 XLA 張量一起使用。所以像這樣的程式碼

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor

會擲回錯誤,因為 torch.nn.Linear 模組在 CPU 上。

在 XLA 設備上執行模型

建構新的 PyTorch 網路或轉換現有網路以在 XLA 設備上執行只需要幾行程 XLA 特定的程式碼。以下程式碼片段重點說明了在單一設備上執行以及使用 XLA 多程序在多個設備上執行時的這些行。

在單一 XLA 設備上執行

以下程式碼片段顯示了在單一 XLA 設備上進行網路訓練

import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for data, target in train_loader:
  optimizer.zero_grad()
  data = data.to(device)
  target = target.to(device)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()

  optimizer.step()
  xm.mark_step()

這段程式碼片段重點說明了切換模型以在 XLA 上執行是多麼容易。模型定義、資料載入器、優化器和訓練迴圈可以在任何設備上工作。唯一 XLA 特定的程式碼是獲取 XLA 設備並標記步驟的幾行程。在每次訓練迭代結束時呼叫 xm.mark_step() 會導致 XLA 執行其目前的圖形並更新模型的參數。有關 XLA 如何建立圖形和執行操作的更多資訊,請參閱 XLA 張量深入探討

使用多程序在多個 XLA 設備上執行

PyTorch/XLA 可以輕鬆地透過在多個 XLA 設備上執行來加速訓練。以下程式碼片段顯示了如何執行

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  device = xm.xla_device()
  mp_device_loader = pl.MpDeviceLoader(train_loader, device)

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for data, target in mp_device_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())

此多設備程式碼片段與先前的單一設備程式碼片段之間存在三個差異。讓我們逐一查看。

  • xmp.spawn()

    • 建立每個執行 XLA 設備的程序。

    • 每個程序只能存取分配給目前程序的設備。例如,在 TPU v4-8 上,將會產生 4 個程序,每個程序將擁有 1 個 TPU 設備。

    • 請注意,如果您在每個行程上列印 xm.xla_device(),您會在所有裝置上看到 xla:0。這是因為每個行程只能看到一個裝置。這並不代表多行程沒有作用。僅在 TPU v2 和 TPU v3 上使用 PJRT 執行階段時才會執行,因為將會有 #devices/2 個行程,並且每個行程將會有 2 個執行緒(請查看此 文件 以了解更多詳細資訊)。

  • MpDeviceLoader

    • 將訓練資料載入到每個裝置。

    • MpDeviceLoader 可以包裝在 PyTorch 資料載入器上。它可以將資料預載入到裝置,並將資料載入與裝置執行重疊,以提高效能。

    • MpDeviceLoader 也會在每個 batches_per_execution(預設為 1)批次產生時為您呼叫 xm.mark_step

  • xm.optimizer_step(optimizer)

    • 整合裝置之間的梯度,並發出 XLA 裝置步驟計算。

    • 它非常類似於 all_reduce_gradients + optimizer.step() + mark_step,並返回已減少的損失。

模型定義、優化器定義和訓練迴圈保持不變。

**注意:** 重要的是要注意,當使用多行程處理時,使用者只能從 xmp.spawn() 的目標函數(或在呼叫堆疊中具有 xmp.spawn() 作為父項的任何函數)內開始檢索和訪問 XLA 裝置。

如需有關在具有多行程處理的多個 XLA 裝置上訓練網路的更多資訊,請參閱 完整的多行程處理範例

在 TPU Pod 上執行

針對不同加速器的多主機設定可能非常不同。本文件將討論多主機訓練中與裝置無關的部分,並將以 TPU + PJRT 執行階段(目前在 1.13 和 2.x 版本中可用)為例。

在您開始之前,請先查看我們在 這裡 的使用者指南,其中將說明一些 Google Cloud 基礎知識,例如如何使用 gcloud 命令以及如何設定您的專案。您也可以查看 這裡 以獲取所有 Cloud TPU 的操作指南。本文件將重點介紹 PyTorch/XLA 角度的設定。

假設您在 train_mnist_xla.py 中有來自上述章節的 mnist 範例。如果是單一主機多裝置訓練,您將透過 ssh 連線到 TPUVM 並執行以下命令:

PJRT_DEVICE=TPU python3 train_mnist_xla.py

現在,為了在 TPU v4-16(具有 2 個主機,每個主機有 4 個 TPU 裝置)上執行相同的模型,您需要:

  • 確保每個主機都可以訪問訓練腳本和訓練資料。這通常是透過使用 gcloud scp 命令或 gcloud ssh 命令將訓練腳本複製到所有主機來完成的。

  • 在所有主機上同時執行相同的訓練命令。

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=$ZONE --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 train_mnist_xla.py"

上述 gcloud ssh 命令將透過 ssh 連線到 TPUVM Pod 中的所有主機,並同時在所有主機上執行相同的命令。

**注意:** 您需要在 TPUVM 虛擬機器外部執行上述 gcloud 命令。

多行程訓練和多主機訓練的模型程式碼和訓練腳本是相同的。PyTorch/XLA 和底層基礎架構將確保每個裝置都知道全域拓撲以及每個裝置的本地和全域序號。跨裝置通訊將發生在所有裝置之間,而不是僅限於本地裝置。

有關 PJRT 執行階段以及如何在 Pod 上執行它的更多詳細資訊,請參閱此 文件。有關 PyTorch/XLA 和 TPU Pod 的更多資訊,以及在 TPU Pod 上執行帶有虛擬資料的 resnet50 的完整指南,請參閱此 指南

XLA 張量深入探討

使用 XLA 張量和裝置只需要更改幾行程式碼。但是,即使 XLA 張量的行為與 CPU 和 CUDA 張量非常相似,但它們的內部結構卻有所不同。本節將說明 XLA 張量的獨特之處。

XLA 張量是惰性的

CPU 和 CUDA 張量會立即或急切地啟動操作。另一方面,XLA 張量是惰性的。它們會在圖形中記錄操作,直到需要結果為止。像這樣延遲執行可以让 XLA 對其進行優化。例如,可以將多個獨立操作的圖形融合成單個優化操作。

惰性執行通常對呼叫者是不可見的。PyTorch/XLA 會自動構建圖形,將其發送到 XLA 裝置,並在 XLA 裝置和 CPU 之間複製資料時進行同步。在執行優化器步驟時插入屏障會明確地同步 CPU 和 XLA 裝置。有關我們惰性張量設計的更多資訊,您可以閱讀 這篇論文

XLA 張量和 bfloat16

PyTorch/XLA 在 TPU 上執行時可以使用 bfloat16 資料類型。事實上,PyTorch/XLA 在 TPU 上處理浮點數類型(torch.floattorch.double)的方式有所不同。此行為由 XLA_USE_BF16XLA_DOWNCAST_BF16 環境變數控制

  • 默認情況下,torch.floattorch.double 在 TPU 上都是 torch.float

  • 如果設定了 XLA_USE_BF16,則 torch.floattorch.double 在 TPU 上都是 bfloat16

  • 如果設定了 XLA_DOWNCAST_BF16,則 torch.float 在 TPU 上是 bfloat16,而 torch.double 在 TPU 上是 float32

  • 如果 PyTorch 張量的資料類型為 torch.bfloat16,則它將直接映射到 TPU bfloat16(XLA BF16 基元類型)。

開發人員應該注意,*無論 TPU 上的 XLA 張量使用的是哪種實際資料類型,它們都會報告其 PyTorch 資料類型*。此轉換是自動且不透明的。如果將 TPU 上的 XLA 張量移回 CPU,則它將從其實際資料類型轉換為其 PyTorch 資料類型。根據您的程式碼運作方式,由處理單元類型觸發的這種轉換可能很重要。

記憶體佈局

XLA 張量的內部資料表示對使用者是不透明的。與 CPU 和 CUDA 張量不同,它們不會公開其儲存空間,並且始終看起來是連續的。這使得 XLA 可以調整張量的記憶體佈局以獲得更好的效能。

在 CPU 和 XLA 裝置之間移動 XLA 張量

可以將 XLA 張量從 CPU 移動到 XLA 裝置,也可以從 XLA 裝置移動到 CPU。如果移動了視圖,則它所查看的資料也會被複製到另一個裝置,並且視圖關係不會被保留。換句話說,一旦資料被複製到另一個裝置,它就與其先前的裝置或其上的任何張量沒有關係。同樣,根據您的程式碼運作方式,理解和適應這種轉換可能很重要。

儲存和載入 XLA 張量

在儲存 XLA 張量之前,應該將其移動到 CPU,如下列程式碼片段所示

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)

tensors = (t0.cpu(), t1.cpu())

torch.save(tensors, 'tensors.pt')

tensors = torch.load('tensors.pt')

t0 = tensors[0].to(device)
t1 = tensors[1].to(device)

這使您可以將載入的張量放在任何可用的裝置上,而不僅僅是初始化它們的裝置。

根據上面關於將 XLA 張量移動到 CPU 的說明,在處理視圖時必須小心。建議您在載入張量並將其移動到目標裝置後重新建立視圖,而不是儲存視圖。

提供了一個實用程式 API,可以透過先將資料移動到 CPU 來儲存資料

import torch
import torch_xla
import torch_xla.core.xla_model as xm

xm.save(model.state_dict(), path)

如果有多個裝置,則上述 API 將只儲存主裝置序號(0)的資料。

如果記憶體與模型參數的大小相比有限,則提供了一個 API 來減少主機上的記憶體佔用

import torch_xla.utils.serialization as xser

xser.save(model.state_dict(), path)

此 API 會將 XLA 張量一次一個地串流到 CPU,從而減少主機記憶體的使用量,但它需要匹配的載入 API 才能恢復

import torch_xla.utils.serialization as xser

state_dict = xser.load(path)
model.load_state_dict(state_dict)

可以直接儲存 XLA 張量,但建議不要這樣做。XLA 張量始終會載入回儲存它們的裝置,如果該裝置不可用,則載入將會失敗。與所有 PyTorch 一樣,PyTorch/XLA 正在積極開發中,這種行為在未來可能會發生變化。

編譯快取

XLA 編譯器會將追蹤的 HLO 轉換為在裝置上執行的可執行檔。編譯可能會非常耗時,並且在 HLO 在多次執行中沒有變化的情況下,可以將編譯結果持久化到磁碟以供重複使用,從而顯著減少開發迭代時間。

請注意,如果 HLO 在多次執行之間發生變化,則仍然會發生重新編譯。

這目前是一個實驗性的選擇性加入 API,必須在執行任何計算之前啟用。初始化是透過 initialize_cache API 完成的

import torch_xla.runtime as xr
xr.initialize_cache('YOUR_CACHE_PATH', readonly=False)

這將在指定的路徑初始化持久性編譯快取。readonly 參數可用於控制工作器是否能夠寫入快取,這在為 SPMD 工作負載使用共用快取掛載時非常有用。

進一步閱讀

更多文件可在 PyTorch/XLA 儲存庫 中找到。更多在 TPU 上執行網路的範例可在 這裡 找到。

PyTorch/XLA API

xla_model

torch_xla.core.xla_model.xla_device(n=None, devkind=None)[source]

返回給定 XLA 設備的執行個體。

參數
  • n (python:int, 選用) – 要返回的特定執行個體(序號)。如果指定,則會返回特定的 XLA 設備執行個體。否則,將返回 devkind 的第一個設備。

  • devkind (string..., 選用) – 如果指定,則為設備類型,例如 TPUCUDACPU 或自訂 PJRT 設備。已棄用。

返回

具有請求執行個體的 torch.device

torch_xla.core.xla_model.get_xla_supported_devices(devkind=None, max_devices=None)[原始碼]

返回給定類型支援的設備清單。

參數
  • devkind (string..., 選用) – 如果指定,則為設備類型,例如 TPUCUDACPU 或自訂 PJRT 設備的名稱。

  • max_devices (python:int, 選用) – 要返回的該類型設備的最大數量。

返回

0'、'xla:1',...]

返回類型

設備字串清單,例如 ['xla

torch_xla.core.xla_model.xla_device_hw(device)[原始碼]

返回給定設備的硬體類型。

參數

device (stringtorch.device) – 將映射到實際設備的 xla 設備。

返回

給定設備的硬體類型的字串表示。

torch_xla.core.xla_model.get_ordinal(defval=0)[原始碼]

擷取目前執行緒的複寫序號。

序號範圍從 0 到 xrt_world_size() 減 1。

參數

defval (python:int, 選用) – 如果沒有可用的複寫資訊,則返回的預設值。執行階段會忽略。預設值:0

返回

目前執行緒的複寫序號。

torch_xla.core.xla_model.get_local_ordinal(defval=0)[原始碼]

擷取目前執行緒的複寫本機序號。

本機序號範圍從 0 到本機設備數量減 1。

參數

defval (python:int, 選用) – 如果沒有可用的複寫資訊,則返回的預設值。執行階段會忽略。預設值:0

返回

目前執行緒的複寫本機序號。

torch_xla.core.xla_model.is_master_ordinal(local=True)[原始碼]

檢查目前處理序是否為主序號 (0)。

參數

local (bool) – 是否應檢查本機或全域主序號。如果是多主機複寫,則只有一個全域主序號(主機 0,設備 0),而有 NUM_HOSTS 個本機主序號。預設值:True

返回

一個布林值,指示目前處理序是否為主序號。

torch_xla.core.xla_model.xrt_world_size(defval=1)[原始碼]

擷取參與複寫的設備數量。

參數

defval (python:int, 選用) – 如果沒有可用的複寫資訊,則返回的預設值。預設值:1

返回

參與複寫的設備數量。

torch_xla.core.xla_model.all_reduce(reduce_type, inputs, scale=1.0, groups=None, pin_layout=True)[原始碼]

對輸入張量執行就地歸約操作。

參數
  • reduce_type (string) – xm.REDUCE_SUMxm.REDUCE_MULxm.REDUCE_ANDxm.REDUCE_ORxm.REDUCE_MINxm.REDUCE_MAX 其中之一。

  • inputs – 單個 torch.Tensor 或要對其執行 all reduce 操作的 torch.Tensor 清單。

  • scale (python:float) – 在歸約後套用的預設縮放值。預設值:1.0

  • groups (list, 選用) –

    清單的清單,表示 all_reduce() 操作的複本群組。範例:[[0, 1, 2, 3], [4, 5, 6, 7]]

    定義兩個群組,一個包含 [0, 1, 2, 3] 複本,另一個包含 [4, 5, 6, 7] 複本。如果為 None,則只有一個群組包含所有複本。

  • pin_layout (bool, 選用) – 是否要為此通訊操作固定配置。配置固定可以防止參與通訊的每個處理序具有略有不同的程式時發生潛在的資料損毀,但可能會導致某些 xla 編譯失敗。當您看到類似「HloModule 具有配置約束的混合」的錯誤訊息時,請取消固定配置。

返回

如果傳遞單個 torch.Tensor,則返回值為包含歸約值(跨複本)的 torch.Tensor。如果傳遞清單/元組,則此函式會對輸入張量執行就地 all-reduce 操作,並返回清單/元組本身。

torch_xla.core.xla_model.all_gather(value, dim=0, groups=None, output=None, pin_layout=True)[原始碼]

沿著給定維度執行 all-gather 操作。

參數
  • value (torch.Tensor) – 輸入張量。

  • dim (python:int) – 聚集維度。預設值:0

  • groups (list, 選用) –

    清單的清單,表示 all_gather() 操作的複本群組。範例:[[0, 1, 2, 3], [4, 5, 6, 7]]

    定義兩個群組,一個包含 [0, 1, 2, 3] 複本,另一個包含 [4, 5, 6, 7] 複本。如果為 None,則只有一個群組包含所有複本。

  • output (torch.Tensor) – 選用的輸出張量。

  • pin_layout (bool, 選用) – 是否要為此通訊操作固定配置。配置固定可以防止參與通訊的每個處理序具有略有不同的程式時發生潛在的資料損毀,但可能會導致某些 xla 編譯失敗。當您看到類似「HloModule 具有配置約束的混合」的錯誤訊息時,請取消固定配置。

返回

一個張量,在 dim 維度中,包含來自參與複本的所有值。

torch_xla.core.xla_model.all_to_all(value, split_dimension, concat_dimension, split_count, groups=None, pin_layout=True)[原始碼]

對輸入張量執行 XLA AllToAll() 操作。

請參閱:https://www.tensorflow.org/xla/operation_semantics#alltoall

參數
  • value (torch.Tensor) – 輸入張量。

  • split_dimension (python:int) – 應該進行分割的維度。

  • concat_dimension (python:int) – 應該進行串聯的維度。

  • split_count (python:int) – 分割計數。

  • groups (list, 選用) –

    清單的清單,表示 all_reduce() 操作的複本群組。範例:[[0, 1, 2, 3], [4, 5, 6, 7]]

    定義兩個群組,一個包含 [0, 1, 2, 3] 複本,另一個包含 [4, 5, 6, 7] 複本。如果為 None,則只有一個群組包含所有複本。

  • pin_layout (bool, 選用) – 是否要為此通訊操作固定配置。配置固定可以防止參與通訊的每個處理序具有略有不同的程式時發生潛在的資料損毀,但可能會導致某些 xla 編譯失敗。當您看到類似「HloModule 具有配置約束的混合」的錯誤訊息時,請取消固定配置。

返回

all_to_all() 操作的結果 torch.Tensor

torch_xla.core.xla_model.add_step_closure(closure, args=(), run_async=False)[原始碼]

將閉包新增到要在步驟結束時執行的閉包清單中。

在模型訓練期間,很多時候都需要列印/報告(列印到主控台、發佈到 TensorBoard 等)需要檢查中間張量內容的資訊。在模型程式碼的不同點檢查不同張量的內容需要多次執行,並且通常會導致效能問題。新增步驟閉包將確保它會在屏障之後執行,屆時所有活動張量都將已具體化為設備資料。活動張量將包括由閉包參數捕獲的張量。因此,使用 add_step_closure() 將確保即使在多個閉包排隊、需要檢查多個張量時,也只會執行一次。步驟閉包將按照排隊順序依序執行。請注意,即使使用此 API 會最佳化執行,但建議每 N 個步驟節流一次列印/報告事件。

參數
  • closure (callable) – 要呼叫的函式。

  • args (tuple) – 要傳遞給閉包的參數。

  • run_async - 如果為 True,則異步執行閉包。

torch_xla.core.xla_model.wait_device_ops(devices=[])[原始碼]

等待給定裝置上的所有非同步操作完成。

參數

devices (字串..., 可選) - 需要等待其非同步操作的裝置。如果為空,則將等待所有本地裝置。

torch_xla.core.xla_model.optimizer_step(optimizer, barrier=False, optimizer_args={}, groups=None, pin_layout=True)[原始碼]

執行提供的優化器步驟並發出 XLA 裝置步驟計算。

參數
  • optimizer (torch.Optimizer) - 需要呼叫其 step() 函式的 torch.Optimizer 實例。將使用 optimizer_args 命名參數呼叫 step() 函式。

  • barrier (布林值, 可選) - 是否應在此 API 中發出 XLA 張量屏障。如果使用 PyTorch XLA ParallelLoaderDataParallel 支援,則不需要這樣做,因為屏障將由 XLA 資料載入器迭代器 next() 呼叫發出。預設值:False

  • optimizer_args (dict, 可選) - optimizer.step() 呼叫的命名參數字典。

  • groups (list, 選用) –

    清單的清單,表示 all_reduce() 操作的複本群組。範例:[[0, 1, 2, 3], [4, 5, 6, 7]]

    定義兩個群組,一個包含 [0, 1, 2, 3] 複本,另一個包含 [4, 5, 6, 7] 複本。如果為 None,則只有一個群組包含所有複本。

  • pin_layout (布林值, 可選) - 是否在減少梯度時固定佈局。有關詳細資訊,請參閱 xm.all_reduce

返回

optimizer.step() 呼叫返回的相同值。

torch_xla.core.xla_model.save(data, file_or_path, master_only=True, global_master=False)[原始碼]

將輸入資料儲存到檔案中。

儲存的資料在儲存之前會被轉移到 PyTorch CPU 裝置,因此後續的 torch.load() 將載入 CPU 資料。處理檢視時必須小心。建議您在載入張量並將其移至目標裝置後重新建立檢視,而不是儲存檢視。

參數
  • data - 要儲存的輸入資料。Python 物件的任何嵌套組合(清單、元組、集合、字典等)。

  • file_or_path - 資料儲存操作的目的地。檔案路徑或 Python 檔案物件。如果 master_onlyFalse,則路徑或檔案物件必須指向不同的目的地,否則來自同一主機的所有寫入都將相互覆蓋。

  • master_only (布林值, 可選) - 是否只有主裝置應該儲存資料。如果為 False,則對於參與複製的每個序號,file_or_path 參數應該是不同的檔案或路徑,否則同一主機上的所有副本都將寫入到同一個位置。預設值:True

  • global_master (布林值, 可選) - 當 master_onlyTrue 時,此標誌控制每個主機的主裝置(如果 global_masterFalse)是否儲存內容,或者只有全域主裝置(序號 0)儲存內容。預設值:False

  • sync (布林值, 可選) - 是否在儲存張量後同步所有副本。如果為 True,則所有副本都必須呼叫 xm.save,否則主程序將掛起。

torch_xla.core.xla_model.rendezvous(tag, payload=b'', replicas=[])[原始碼]

等待所有網格客戶端到達指定的會合點。

注意:PJRT 不支援 XRT 網格伺服器,因此這實際上是 xla_rendezvous 的別名。

參數
  • tag (字串) - 要加入的會合點的名稱。

  • payload (位元組, 可選) - 要發送到會合點的有效負載。

  • replicas (清單, python:int) - 參與會合點的副本序號。空表示網格中的所有副本。預設值:[]

返回

所有其他核心交換的有效負載,核心序號 i 的有效負載位於返回元組中的位置 i

torch_xla.core.xla_model.do_on_ordinals(target, data=(), ordinals=(0,))[原始碼]

僅在給定的序號集上執行函式。

參數
  • target (可呼叫) - 要在 ordinals 上執行的函式。

  • data - target 函式的任何輸入資料,其中包含張量。 target 函式使用的所有 XLA 張量都必須在此參數中傳遞。函式使用的所有其他資料都可以像往常一樣由 Python 解譯器捕獲。預設值:()

  • ordinals (清單, python:int) - 應該執行 target 函式的序號清單/集合。預設值:(0,)

返回

在執行 target 函式的序號中,函式返回值,否則為 None

torch_xla.core.xla_model.mesh_reduce(tag, data, reduce_fn)[原始碼]

執行圖外客戶端網格縮減。

參數
  • tag (字串) - 要加入的會合點的名稱。

  • data - 要縮減的資料。 reduce_fn 可呼叫物件將接收一個清單,其中包含來自所有網格客戶端程序(每個核心一個)的相同資料的副本。

  • reduce_fn (可呼叫) - 一個函式,接收一個 data 類物件的清單並返回縮減的結果。

返回

縮減後的值。

torch_xla.core.xla_model.set_rng_state(seed, device=None)[原始碼]

設定亂數產生器狀態。

參數
  • seed (python:integer) - 要設定的狀態。

  • device (字串, 可選) - 需要設定 RNG 狀態的裝置。如果缺少,則將設定預設裝置種子。

torch_xla.core.xla_model.get_rng_state(device=None)[原始碼]

取得目前執行的亂數產生器狀態。

參數

device (字串, 可選) - 需要擷取其 RNG 狀態的裝置。如果缺少,則將設定預設裝置種子。

返回

RNG 狀態,以整數表示。

torch_xla.core.xla_model.get_memory_info(device)[原始碼]

擷取裝置記憶體資訊。

參數

device (字串) - 請求其記憶體資訊的裝置。

返回

一個字典,包含 kb_free(可用記憶體,以 KB 為單位)和 kb_total(總記憶體,以 KB 為單位)鍵。

torch_xla.core.xla_model.get_stablehlo(tensors=None) str[原始碼]

以字串格式取得計算圖的 StableHLO。

如果 tensors 不為空,則將傾印以 tensors 作為輸出的圖。如果 tensors 為空,則將傾印整個計算圖。TODO(lsy323):當 tensors 為空時,一些中間張量也將作為輸出傾印。需要進一步調查。

對於推斷圖,建議將模型輸出傳遞給 tensors。對於訓練圖,識別“輸出”並不容易。建議使用空的 tensors

要在 StableHLO 中啟用原始碼行資訊,請設定環境變數 XLA_HLO_DEBUG=1。

參數

tensors (list[torch.Tensor], 可選) - 表示 StableHLO 圖的輸出/根的張量。

返回

以字串格式表示的 StableHLO 模組。

torch_xla.core.xla_model.get_stablehlo_bytecode(tensors=None) bytes[source]

以位元組碼格式取得計算圖的 StableHLO。

如果 tensors 不為空,則將傾印以 tensors 作為輸出的圖。如果 tensors 為空,則將傾印整個計算圖。TODO(lsy323):當 tensors 為空時,一些中間張量也將作為輸出傾印。需要進一步調查。

對於推斷圖,建議將模型輸出傳遞給 tensors。對於訓練圖,識別“輸出”並不容易。建議使用空的 tensors

參數

tensors (list[torch.Tensor], 可選) - 表示 StableHLO 圖的輸出/根的張量。

返回

以位元組碼格式呈現 StableHLO 模組。

torch_xla.core.functions.all_reduce(reduce_type, value, scale=1.0, groups=None)[source]

對輸入張量執行就地歸約運算。

這與 xm.all_reduce() 相同,但支援 Autograd 微分。

參數
  • reduce_type (字串) – REDUCE_SUMREDUCE_MULREDUCE_ANDREDUCE_ORREDUCE_MINREDUCE_MAX 其中之一。

  • value (torch.Tensor) – 要對其執行全部歸約運算的目標。

  • scale (python:float) – 在歸約後套用的預設縮放值。預設值:1.0

  • groups (list, 選用) –

    清單的清單,表示 all_reduce() 操作的複本群組。範例:[[0, 1, 2, 3], [4, 5, 6, 7]]

    定義兩個群組,一個包含 [0, 1, 2, 3] 複本,另一個包含 [4, 5, 6, 7] 複本。如果為 None,則只有一個群組包含所有複本。

返回

跨選定副本的歸約值。

torch_xla.core.functions.all_gather(value, dim=0)[source]

沿著給定維度執行 all-gather 操作。

這與 xm.all_gather() 相同,但支援 Autograd 微分。

參數
  • value (torch.Tensor) – 輸入張量。

  • dim (python:int) – 聚集維度。預設值:0

返回

一個張量,在 dim 維度中,包含來自參與複本的所有值。

torch_xla.core.functions.nms(boxes, scores, score_threshold, iou_threshold, output_size)[source]

執行非極大值抑制運算。

參數
  • boxes (torch.Tensor) – 形狀為 [N, 4]torch.Tensor,以 (y0, x0, y1, x1) 形式列出方框座標。

  • scores (torch.Tensor) – 形狀為 [N]torch.Tensor,列出每個方框的分數。

  • score_threshold (torch.Tensor) – 方框符合有效條件的最低分數。

  • iou_threshold (torch.Tensor) – 觸發重疊邏輯的最小 IOU(交集聯集)分數。

  • output_size (python:int) – 傳回索引的最大數量(必須小於或等於 N)。

返回

一個 torch.Tensor 的元組,第一個元素是選定的方框索引,第二個元素是有效方框的數量。

distributed

class torch_xla.distributed.parallel_loader.ParallelLoader(loader, devices, batchdim=0, batches_per_execution=1, loader_prefetch_size=8, device_prefetch_size=4, host_to_device_transfer_threads=1, input_sharding=None)[source]

使用背景資料上傳功能包裝現有的 PyTorch DataLoader。

參數
  • loader (torch.utils.data.DataLoader) – 要包裝的 PyTorch DataLoader。

  • devices (torch.device…) – 資料必須傳送到的裝置清單。 loader 傳回的第 i 個樣本將會傳送到 devices[i % len(devices)]

  • batchdim (python:int, 選用) – 保存批次大小的維度。預設值:0

  • loader_prefetch_size (python:int, 選用) – 從 loader 讀取樣本的執行緒所使用的佇列最大容量,將由將資料上傳到裝置的工作執行緒處理。預設值:8

  • device_prefetch_size (python:int, 選用) – 每個裝置佇列的最大大小,工作執行緒會將已傳送到裝置的張量存放到這些佇列中。預設值:4

  • host_to_device_transfer_threads (python:int, 選用) – 平行運作以將資料從載入器佇列傳輸到裝置佇列的執行緒數量。預設值:1

  • input_sharding (ShardingSpec, 選用) – 載入後要套用到相容輸入張量的分片規格。預設值:無

per_device_loader(device)[source]

擷取給定裝置的載入器迭代器物件。

參數

device (torch.device) – 正在請求其整個載入器的裝置。

返回

device 的載入器迭代器物件。這不是 torch.utils.data.DataLoader 介面,而是一個 Python 迭代器,它傳回與包裝的 torch.utils.data.DataLoader 傳回的張量資料結構相同,但駐留在 XLA 裝置上的資料結構。

torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn')[source]

啟用基於多程序的複寫。

參數
  • fn (可呼叫) – 要為參與複寫的每個裝置呼叫的函數。此函數將會使用第一個參數作為複寫中程序的全域索引來呼叫,後面接著 args 中傳遞的參數。

  • args (元組) – fn 的參數。預設值:空元組

  • nprocs (python:int) – 複寫的程序/裝置數量。目前,如果指定,可以是 1 或最大裝置數量。

  • join (布林值) – 呼叫是否應該封鎖,等待已產生的程序完成。預設值:True

  • daemon (布林值) – 正在產生的程序是否應該設定 daemon 旗標(請參閱 Python 多程序 API)。預設值:False

  • start_method (字串) – Python multiprocessing 程序建立方法。預設值:spawn

返回

torch.multiprocessing.spawn API 傳回的相同物件。如果 nprocs 為 1,則會直接呼叫 fn 函數,且 API 將會傳回 None。

class torch_xla.distributed.xla_multiprocessing.MpModelWrapper(model)[source]

包裝模型以在使用 fork 方法時,將主機記憶體使用量降到最低。

這個類別應該與 spawn(…, start_method=’fork’) API 一起使用,以將主機記憶體的使用量降到最低。模型不是在每個多程序程序上建立,而是複寫模型的初始主機記憶體,而是在全域範圍內建立一次,然後在 spawn() 目標函數內移動到每個裝置。範例

WRAPPED_MODEL = xmp.MpModelWrapper(MyNetwork())

def _mp_fn(index, ...):
  device = xm.xla_device()
  model = WRAPPED_MODEL.to(device)
  ...

xmp.spawn(_mp_fn, ..., start_method='fork')

這個方法有兩個優點。首先,它只使用一份記憶體頁面來儲存原始模型權重,其次,它會將包裝模型移動到每個裝置的動作序列化,方法是在過程中降低系統記憶體的負載。

to(device)[source]

擷取已移至指定裝置的模型。

參數

device (torch.device) – 模型應該移至的裝置。

返回

指定裝置上的模型。

class torch_xla.distributed.xla_multiprocessing.MpSerialExecutor[source]

用於以序列化方式在多核心程序之間執行函數的公用程式。

範例

# At global scope.
SERIAL_EXEC = xmp.MpSerialExecutor()

def load_dataset(path):
  return maybe_download_and_load(path)

def _mp_fn(index, ...):
  # Avoid all cores downloading the same data with the serial executor.
  dataset = SERIAL_EXEC.run(lambda: load_dataset('/tmp/mnist-data'))
  ...

xmp.spawn(_mp_fn, ...)
run(fn)[原始碼]

針對每個核心程序,以序列化方式執行所提供的函數。

參數

fn (可呼叫) – 以序列化方式執行的函數。

返回

fn 的回傳值。

utils

class torch_xla.utils.utils.SampleGenerator(data, sample_count)[原始碼]

迭代器,會傳回給定輸入資料的多個樣本。

可以用來取代 PyTorch DataLoader 來產生合成資料。

參數
  • data – 應該在每個迭代器步驟傳回的資料。

  • sample_count – 要傳回的 data 樣本的最大數量。

class torch_xla.utils.utils.DataWrapper[原始碼]

用於包裝要傳送到裝置的資料結構的工具類別。

torch_xla.utils.serialization.save(data, path, master_only=True, global_master=False)[原始碼]

將輸入資料儲存到檔案中。

儲存的資料在儲存之前會被轉移到 PyTorch CPU 裝置,因此後續的 torch.load() 將載入 CPU 資料。處理檢視時必須小心。建議您在載入張量並將其移至目標裝置後重新建立檢視,而不是儲存檢視。

參數
  • data - 要儲存的輸入資料。Python 物件的任何嵌套組合(清單、元組、集合、字典等)。

  • path – 資料儲存作業的目的地檔案。如果 master_onlyFalse,則路徑必須指向不同的目的地,否則來自相同主機的所有寫入都會互相覆蓋。

  • master_only (bool, 可選) – 是否只有主裝置應該儲存資料。如果為 False,則 path 參數應該是參與複製的每個序號的不同路徑,否則相同主機上的所有複本都將寫入相同的位置。預設值:True

  • global_master (布林值, 可選) - 當 master_onlyTrue 時,此標誌控制每個主機的主裝置(如果 global_masterFalse)是否儲存內容,或者只有全域主裝置(序號 0)儲存內容。預設值:False

torch_xla.utils.serialization.load(path)[原始碼]

載入先前使用 save() API 儲存的資料。

參數

path (str) – 傳遞給 save() API 的路徑。

返回

已載入的資料。

測試

PyTorch/XLA 新手指南

本文件提供 PyTorch XLA 的高階概觀,並說明幾個範例,說明如何轉換 PyTorch 程式碼以在 XLA 裝置(例如 TPU)上執行。這不是一個完整的解決方案,可能需要根據特定的程式碼進行其他更改。但是,本文件應該可以作為轉換過程的起點。

對一些 XLA 細節的基本高階理解

本節簡要概述 PyTorch XLA 的基本細節,

這將有助於讀者更好地理解所需的程式碼修改和優化。它是對 這裡 描述的 API 指南的補充。

與逐行執行程式碼且在擷取 PyTorch 張量 的值之前不會阻塞執行的常規 PyTorch 不同,PyTorch XLA 的工作方式不同。它會迭代 Python 程式碼並在遇到屏障(如下所述)之前,將 (PyTorch)XLA 張量 上的操作記錄在中間表示 (IR) 圖形中。這個生成 IR 圖形的過程稱為追蹤(LazyTensor 追蹤或程式碼追蹤)。然後,PyTorch XLA 會將 IR 圖形轉換為稱為 HLO(高階操作碼)的低階機器可讀格式。HLO 是一種針對 XLA 編譯器的計算表示形式,允許它為其執行的硬體生成高效的程式碼。HLO 被饋送到 XLA 編譯器以進行編譯和優化。然後,PyTorch XLA 會快取編譯,以便以後在需要時重複使用。圖形的編譯是在主機(CPU)上完成的,主機是執行 Python 程式碼的機器。如果有多個 XLA 裝置,則主機將分別為每個裝置編譯程式碼,除非使用 SPMD(單程式多資料)。例如,v4-8 有一個主機和 四個裝置。在這種情況下,主機將分別為四個裝置編譯程式碼。在 Pod 切片的情況下,當有多個主機時,每個主機都會為其連接的 XLA 裝置進行編譯。如果使用 SPMD,則程式碼只會在每個主機上為所有裝置編譯一次(對於給定的形狀和計算)。

img

如需更多詳細資訊和範例,請參閱 LazyTensor 指南

只有在需要張量的值時,才會執行 IR 圖形中的操作。這稱為張量的評估或具體化。有時這也被稱為惰性求值,它可以顯著 提高效能

Pytorch XLA 中的同步操作,如列印、記錄、檢查點或回呼,會阻塞追蹤並導致執行速度變慢。如果操作需要 XLA 張量的特定值,例如 print(xla_tensor_z),則追蹤會被阻塞,直到主機可以使用該張量的值。請注意,只有負責計算該張量值的圖形部分會被執行。這些操作不會切割 IR 圖形,但它們會透過 TransferFromDevice 觸發主機與裝置之間的通訊,這會導致效能下降。

屏障是一種特殊指令,它告訴 XLA 執行 IR 圖形並具體化張量。這表示 PyTorch XLA 張量將被評估,並且主機將可以使用結果。Pytorch XLA 中使用者公開的屏障是 xm.mark_step(),它會打破 IR 圖形並導致程式碼在 XLA 裝置上執行。xm.mark_step 的一個關鍵特性是,與同步操作不同,它不會在裝置執行圖形時阻塞進一步的追蹤。但是,它確實會阻止存取正在具體化的張量的值。

LazyTensor 指南中的範例說明了在兩個張量相加的簡單情況下會發生什麼。現在,假設我們有一個 for 迴圈,它會新增 XLA 張量並稍後使用該值

for x, y in tensors_on_device:
    z += x + y

如果沒有屏障,Python 追蹤將產生一個單一圖形,該圖形將張量的加法 len(tensors_on_device) 次包裝起來。這是因為追蹤不會捕捉到 for 迴圈,因此迴圈的每次迭代都會建立一個對應於 z += x+y 計算的新子圖形並將其新增到圖形中。以下是一個 len(tensors_on_device)=3 的範例。

img

但是,在迴圈的結尾引入屏障將產生一個較小的圖形,該圖形將在 for 迴圈內的第一次傳遞期間編譯一次,並將在下一個 len(tensors_on_device)-1 迭代中重複使用。屏障將向追蹤發出信號,表示到目前為止追蹤的圖形可以提交執行,並且如果該圖形之前已經見過,則將重複使用快取的編譯程式。

for x, y in tensors_on_device:
    z += x + y
    xm.mark_step()

在這種情況下,將有一個使用 len(tensors_on_device)=3 次的小圖形。

img

需要強調的是,在 PyTorch XLA 中,如果結尾處有屏障,則會追蹤 for 迴圈內的 Python 程式碼,並且會為每次迭代建構一個新的圖形。這可能是效能的嚴重瓶頸。

當在相同形狀的張量上發生相同計算時,可以重複使用 XLA 圖形。如果輸入或中間張量的形狀發生變化,則 XLA 編譯器將使用新的張量形狀重新編譯一個新的圖形。這表示,如果您有動態形狀,或者您的程式碼沒有重複使用張量圖形,則在 XLA 上執行您的模型將不適合該用例。將輸入填充到固定形狀中可以作為一種選擇,以幫助避免動態形狀。否則,編譯器將花費大量時間來優化和融合將不再使用的操作。

圖形大小和編譯時間之間的權衡也很重要。如果有一個大型 IR 圖形,則 XLA 編譯器可能會花費大量時間來優化和融合操作。這可能會導致非常長的編譯時間。但是,由於在編譯過程中執行了優化,因此後續執行可能會快得多。

有時值得使用 xm.mark_step() 打破 IR 圖形。如上所述,這將產生一個可以在以後重複使用的較小圖形。但是,使圖形變小會降低 XLA 編譯器可以執行的優化。

另一個需要考慮的重點是 MPDeviceLoader。一旦您的程式碼在 XLA 裝置上執行,請考慮使用 XLA MPDeviceLoader 包裝 torch 資料載入器,它會將資料預載入到裝置以提高效能,並在其中包含 xm.mark_step()。後者會自動打破對資料批次的迭代,並將其發送以供執行。請注意,如果您沒有使用 MPDeviceLoader,則可能需要在 optimizer_step() 中設定 barrier=True 以在執行訓練作業時啟用 xm.mark_step(),或明確新增 xm.mark_step()

TPU 設定

使用基礎映像建立 TPU 以使用每晚構建,或透過指定 RUNTIME_VERSION 從穩定版本建立 TPU。

export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v4-8 # v4-16, v4-32, …
export RUNTIME_VERSION=tpu-vm-v4-pt-2.0 # or tpu-vm-v4-base
export TPU_NAME=your_tpu_name

gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION} \
--subnetwork=tpusubnet

如果您有一個單一主機虛擬機器(例如 v4-8),則可以直接 ssh 到您的虛擬機器並直接從虛擬機器執行以下命令。否則,在 TPU Pod 的情況下,您可以使用 --worker=all --command="",類似於

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=us-central2-b \
--worker=all \
--command="pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl"

接下來,如果您使用的是基礎映像,請安裝每晚套件和所需的程式庫

pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl
​​pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl
sudo apt-get install libopenblas-dev -y

sudo apt-get update && sudo apt-get install libgl1 -y # diffusion specific

將代碼轉換為 PyTorch XLA

修改代碼的一般準則

  • cuda 替換為 xm.xla_device()

  • 移除會存取 XLA 張量值的進度條和列印語句

  • 減少會存取 XLA 張量值的日誌記錄和回呼

  • 使用 MPDeviceLoader 包裝資料載入器

  • 進行效能分析以進一步優化代碼

請記住:每個案例都是獨特的,因此您可能需要針對每個案例採取不同的措施。

範例 1:在單個 TPU 裝置上使用 PyTorch Lightning 進行 Stable Diffusion 推論

作為第一個範例,請考慮 PyTorch Lightning 中 Stable Diffusion 模型的 推論代碼,可以透過以下命令行指令執行:

python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse"

如需參考,可以在 此處 找到下方所述修改的差異。讓我們逐步進行說明。如同上述一般準則,請先從與 cuda 裝置相關的變更開始。此推論代碼旨在 GPU 上執行,並且可以在多個位置找到 cuda。首先,從 此行 移除 model.cuda(),並從 此處 移除 precision_scope。此外,將 此行 中的 cuda 裝置替換為 xla 裝置,類似於以下代碼:

接下來,此特定模型配置使用的是 FrozenCLIPEmbedder,因此我們也將修改 此行。為了簡單起見,我們將直接在此教學中定義 device,但您也可以將 device 值傳遞給函數。

import torch_xla.core.xla_model as xm
self.device = xm.xla_device()

代碼中另一個具有 cuda 特定代碼的地方是 DDIM 排程器。在檔案頂部新增 import torch_xla.core.xla_model as xm,然後替換 這些 行:

if attr.device != torch.device("cuda"):
   attr = attr.to(torch.device("cuda"))

替換為:

device = xm.xla_device()
attr = attr.to(torch.device(device))

接下來,您可以透過移除列印語句、停用進度條以及減少或移除回呼和日誌記錄來減少裝置(TPU)和主機(CPU)之間的通訊。這些操作需要裝置停止執行、回到 CPU、執行日誌記錄/回呼,然後再返回裝置。這可能會成為嚴重的效能瓶頸,尤其是在大型模型上。

進行這些變更後,代碼將可以在 TPU 上執行。但是,效能會非常慢。這是因為 XLA 編譯器會嘗試構建一個單一(巨大)的圖形,其中包含推論步驟的數量(在本例中為 50 個),因為 for 迴圈內沒有屏障。編譯器難以優化圖形,這會導致嚴重的效能下降。如上所述,使用屏障 (xm.mark_step()) 打破 for 迴圈將產生較小的圖形,編譯器更容易優化。這也將允許編譯器重複使用上一步驟中的圖形,從而提高效能。

現在,代碼 已準備好在合理的時限內於 TPU 上執行。您可以透過 擷取效能分析 並進一步調查來進行更多優化和分析。但是,本文件未涵蓋此主題。

注意:如果您使用的是 v4-8 TPU,則您有 4 個可用的 XLA(TPU)裝置。如上所述執行代碼只會使用一個 XLA 裝置。為了在所有 4 個裝置上執行,您需要使用 xmp.spawn() 函數在所有裝置上產生代碼。我們將在下一個範例中討論 xmp.spawn

範例 2:HF Stable Diffusion 推論

現在,請考慮在 HuggingFace diffusers 函式庫中使用 Stable Diffusion 推論 來處理 SD-XL 和 2.1 版本的模型。如需參考,可以在此 儲存庫 中找到下方所述的變更。您可以複製儲存庫並在 TPU VM 上使用以下命令執行推論:

(vm)$ git clone https://github.com/pytorch-tpu/diffusers.git
(vm)$ cd diffusers/examples/text_to_image/
(vm)$ python3 inference_tpu_single_device.py

由於沒有 bf16 版本的 SD-XL 模型可用,您可以使用 XLA_USE_BF16=1 旗標將所有值轉換為 bf16 並加快訓練速度。

(vm)$ XLA_USE_BF16=1 python3 inference_tpu_single_device.py # uses sd-xl version

(vm)$ python3 inference_tpu_multidevice.py # uses 2.1 version

(2.1 版本的模型中已包含 torch.bfloat16)。

警告:請注意 此處 強調的注意事項。

在單個 TPU 裝置上執行

本節說明需要對 文字轉圖像推論範例 代碼進行哪些變更才能在 TPU 上執行。

原始代碼使用 Lora 進行推論,但本教學將不使用它。相反地,我們將在初始化管道時將 model_id 參數設定為 stabilityai/stable-diffusion-xl-base-0.9。我們還將使用預設的排程器 (DPMSolverMultistepScheduler)。但是,也可以對其他排程器進行類似的變更。

git clone https://github.com/huggingface/diffusers
cd diffusers
pip install . # pip install -e .

cd examples/text_to_image/
pip install -r requirements.txt
pip install invisible_watermark transformers accelerate safetensors

(如果找不到 accelerate,請登出並重新登入。)

登入 HF 並同意模型卡片上的 sd-xl 0.9 授權。接下來,前往 帳戶→設定→存取權杖 並產生新的權杖。複製權杖並在您的虛擬機器上使用該特定權杖值執行以下命令:

(vm)$ huggingface-cli login --token _your_copied_token__

HuggingFace 自述文件提供了旨在 GPU 上執行的 PyTorch 代碼。若要在 TPU 上執行,第一步是將 CUDA 裝置變更為 XLA 裝置。這可以透過將 pipe.to("cuda") 行替換為以下行來完成:

import torch_xla.core.xla_model as xm
device = xm.xla_device()
pipe.to(device)

此外,請務必注意,第一次使用 XLA 執行推論時,編譯時間會很長。例如,HuggingFace 的 Stable Diffusion XL 模型推論的編譯時間可能需要大約一個小時,而實際的推論時間可能只需要 5 秒,具體取決於批次大小。同樣地,GPT-2 模型可能需要大約 10-15 分鐘來編譯,之後訓練週期的時間會變得更快。這是因為 XLA 會構建將要執行的計算圖形,然後針對其執行的特定硬體優化此圖形。但是,一旦編譯了圖形,就可以將其重複用於後續的推論,這將會快得多。因此,如果您只執行一次推論,則可能無法從使用 XLA 中受益。但是,如果您要多次執行推論,或者您要對提示清單執行推論,則在執行前幾個推論後,您將開始看到 XLA 的優勢。例如,如果您對 10 個提示的清單執行推論,則第一個推論(可能是兩個1)的編譯時間可能會很長,但其餘的推論步驟將會快得多。這是因為 XLA 將重複使用其為第一個推論編譯的圖形。

如果您嘗試在不做任何其他變更的情況下執行代碼,您會注意到編譯時間非常長(超過 6 小時)。這是因為 XLA 編譯器會嘗試為所有排程器步驟構建單一圖形,類似於我們在上一個範例中討論的內容。為了加快代碼的執行速度,我們需要使用 xm.mark_step() 將圖形分解成更小的片段,並在後續步驟中重複使用它們。這發生在 pipe.__call__ 函數 中的 這些行 中。停用進度條、移除回呼並在 for 迴圈的末尾新增 xm.mark_step() 可以顯著加快代碼的速度。此 提交 中提供了變更。

此外,self.scheduler.step() 函數(預設情況下使用 DPMSolverMultistepScheduler 排程器)存在一些問題,這些問題在 PyTorch XLA 注意事項 中有所說明。此函數中的 .nonzero().item() 呼叫會向 CPU 發送張量評估請求,從而觸發裝置與主機之間的通訊。這並不可取,因為它會降低代碼的執行速度。在這種情況下,我們可以透過將索引直接傳遞給函數來避免這些呼叫。這將防止函數向 CPU 發送請求,並提高代碼的效能。變更包含在此 提交 中。現在,代碼已準備好在 TPU 上執行。

效能分析和效能分析

為了進一步調查模型的效能,我們可以使用效能分析 指南 對其進行效能分析。根據經驗,效能分析腳本應該使用適合記憶體的最大批次大小來執行,以實現 最佳記憶體使用率。這也有助於將代碼的追蹤與裝置執行重疊,從而更有效地利用裝置。效能分析的持續時間應該足夠長,以便至少擷取一個步驟。模型在 TPU 上的良好效能意味著裝置與主機之間的通訊最少,並且裝置在沒有閒置時間的情況下持續執行程序。

inference_tpu_*.py 檔案中啟動伺服器並按照指南中所述運行 capture_profile.py 腳本,將會提供有關在裝置上運行的進程的資訊。目前,只有一個 XLA 裝置會被分析。為了更好地理解 TPU 閒置時間(分析中的間隔),應該在程式碼中添加分析追蹤(xp.Trace())。xp.Trace() 會測量追蹤主機上使用追蹤包裝的 Python 程式碼所需的時間。在本例中,xp.Trace() 追蹤被添加到管道U-net 模型中,以測量在主機(CPU)上運行程式碼特定部分所需的時間。

如果分析中的間隔是由於在主機上進行的 Python 程式碼追蹤所致,那麼這可能是一個瓶頸,並且沒有其他直接的優化方法。否則,應進一步分析程式碼以了解注意事項並進一步提高效能。請注意,您不能在呼叫 xm.mark_step() 的程式碼部分使用 xp.Trace() 包裝。

為了說明這一點,我們可以查看已經按照分析指南上傳到 TensorBoard 的已捕獲分析。

從 Stable Diffusion 模型版本 2.1 開始

如果我們在不插入任何追蹤的情況下捕獲分析,我們將看到以下內容

Alt text

v4-8 上的單個 TPU 裝置(具有兩個核心)似乎很忙。除了中間有一個小的間隔外,它們的使用沒有明顯的間隔。如果我們向上滾動以嘗試找到哪個進程佔用了主機,我們將找不到任何資訊。因此,我們將在管道檔案以及 U-net函數中添加 xp.traces。後者對於此特定用例可能沒有用,但它確實演示了如何在不同位置添加追蹤以及它們的資訊如何在 TensorBoard 中顯示。

如果我們添加追蹤並使用可以容納在裝置上的最大批次大小(在本例中為 32)重新捕獲分析,我們將看到裝置中的間隔是由在主機上運行的 Python 進程引起的。

Alt text Alt text

我們可以使用適當的工具放大時間線,並查看在該期間內運行的進程。這就是 Python 程式碼追蹤在主機上發生的時間,此時我們無法進一步改進追蹤。

現在,讓我們檢查一下模型的 XL 版本並執行相同的操作。我們將以與 2.1 版本相同的方式在管道檔案中添加追蹤,並捕獲分析。

Alt text

這一次,除了由 pipe_watermark 追蹤引起的中間大間隔之外,此迴圈中的推理步驟之間還有許多小的間隔。

首先仔細查看由 pipe_watermark 引起的大間隔。間隔之前是 TransferFromDevice,這表明主機上正在發生一些事情,正在等待計算完成後再繼續進行。查看水印程式碼,我們可以看到張量被傳輸到 CPU 並轉換為 numpy 陣列,以便稍後使用 cv2pywt 庫進行處理。由於這部分不容易優化,我們將保持原樣。

現在,如果我們放大迴圈,我們可以看到迴圈內的圖表被分解成更小的部分,因為發生了 TransferFromDevice 操作。

Alt text

如果我們研究 U-Net 函數和排程器,我們可以看到 U-Net 程式碼不包含任何針對 PyTorch/XLA 的優化目標。但是,在scheduler.step 中有 .item().nonzero() 呼叫。我們可以重寫該函數以避免這些呼叫。如果我們解決了這個問題並重新運行分析,我們不會看到太大的差異。但是,由於我們減少了引入較小圖表的裝置-主機通信,因此我們允許編譯器更好地優化程式碼。scale_model_input 函數也有類似問題,我們可以通過對 step 函數進行上述更改來解決這些問題。總體而言,由於許多間隔是由於 Python 級別的程式碼追蹤和圖表構建造成的,因此在當前版本的 PyTorch XLA 中無法優化這些間隔,但是當 PyTorch XLA 中啟用 dynamo 時,我們可能會在將來看到改進。

在多個 TPU 裝置上運行

要使用多個 TPU 裝置,可以使用 xmp.spawn 函數將您在單個裝置上運行的函數產生到多個裝置上。xmp.spawn 函數將在多個 TPU 裝置上啟動進程,並在需要時同步它們。這可以通過將 index 參數傳遞給在單個裝置上運行的函數來完成。例如,

import torch_xla.distributed.xla_multiprocessing as xmp

def my_function(index):
  # function that runs on a single device

xmp.spawn(my_function, args=(0,), nprocs=4)

在此示例中,my_function 函數將在 v4-8 上的 4 個 TPU 裝置上產生,每個裝置分配一個從 0 到 3 的索引。

此檔案說明了如何使用 xmp.spawn 在多個 TPU 裝置上運行 Stable Diffusion 2.1 版本。對於此版本,與上述更改類似,對管道檔案進行了更改。

在 Pod 上運行

一旦您擁有在單個主機裝置上運行的程式碼,就不需要進一步更改。您可以按照這些說明創建 TPU Pod。然後使用以下命令運行您的腳本

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --worker=all \
  --command="python3 your_script.py"
1

0 和 1 是 XLA 中的魔數,在 HLO 中被視為常數。因此,如果程式碼中有一個隨機數生成器可以生成這些值,則程式碼將分別為每個值進行編譯。可以使用 XLA_NO_SPECIAL_SCALARS=1 環境變數禁用此功能。

故障排除

請注意,本節中的資訊可能會在未來版本的*PyTorch/XLA*軟體中刪除,因為其中許多資訊僅適用於可能會發生變化的特定內部實現。

健全性檢查

在執行任何深入的除錯之前,我們想對安裝的 PyTorch/XLA 進行健全性檢查。

檢查 PyTorch/XLA 版本

PyTorch 和 PyTorch/XLA 版本應匹配。有關可用版本的更多詳細資訊,請查看我們的README

vm:~$ python
>>> import torch
>>> import torch_xla
>>> print(torch.__version__)
2.1.0+cu121
>>> print(torch_xla.__version__)
2.1.0

執行簡單計算

vm:~$ export PJRT_DEVICE=TPU
vm:~$ python3
>>> import torch
>>> import torch_xla.core.xla_model as xm
>>> t1 = torch.tensor(100, device=xm.xla_device())
>>> t2 = torch.tensor(200, device=xm.xla_device())
>>> print(t1 + t2)
tensor(300, device='xla:0')

使用虛假資料運行 ResNet

對於每晚版本

vm:~$ git clone https://github.com/pytorch/xla.git
vm:~$ python xla/test/test_train_mp_imagenet.py --fake_data

對於發佈版本 x.y,您需要使用分支 rx.y。例如,如果您安裝了 2.1 發佈版本,則應執行

vm:~$ git clone --branch r2.1 https://github.com/pytorch/xla.git
vm:~$ python xla/test/test_train_mp_imagenet.py --fake_data

如果您可以運行 ResNet,我們可以得出結論,torch_xla 已正確安裝。

效能除錯

要診斷效能問題,我們可以使用*PyTorch/XLA*提供的執行指標和計數器。當模型速度較慢時,首先要檢查的是生成指標報告。

指標報告對於診斷問題非常有幫助。如果您有指標報告,請盡量將其包含在發送給我們的錯誤報告中。

PyTorch/XLA 除錯工具

您可以通過設置 PT_XLA_DEBUG=1 來啟用 PyTorch/XLA 除錯工具,該工具提供了一些有用的除錯功能。

PyTorch/XLA + Dynamo 除錯工具

您可以通過設置 XLA_DYNAMO_DEBUG=1 來啟用 PyTorch/XLA + Dynamo 除錯工具。

執行自動指標分析

除錯工具將分析指標報告並提供摘要。一些示例輸出將是

pt-xla-profiler: CompileTime too frequent: 21 counts during 11 steps
pt-xla-profiler: TransferFromDeviceTime too frequent: 11 counts during 11 steps
pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward,  Please open a GitHub issue with the above op lowering requests.
pt-xla-profiler: CompileTime too frequent: 23 counts during 12 steps
pt-xla-profiler: TransferFromDeviceTime too frequent: 12 counts during 12 steps

編譯和執行分析

除錯工具將分析模型的每次編譯和執行。一些示例輸出將是

Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis:   user mark_step
Compilation Analysis: Graph Info:
Compilation Analysis:   Graph Hash: 537d4b0264b029688281412214d252e9
Compilation Analysis:   Number of Graph Inputs: 588
Compilation Analysis:   Number of Graph Outputs: 320
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis:   mark_step (/workspaces/dk2/pytorch/xla/torch_xla/core/xla_model.py:840)
Compilation Analysis:   broadcast_master_param (/workspaces/dk2/pytorch/xla/torch_xla/core/xla_model.py:1230)
Compilation Analysis:   train_imagenet (/workspaces/dk2/pytorch/xla/test/test_train_mp_imagenet.py:261)
Compilation Analysis:   _mp_fn (/workspaces/dk2/pytorch/xla/test/test_train_mp_imagenet.py:365)
Compilation Analysis:   __call__ (/workspaces/dk2/pytorch/xla/torch_xla/_internal/pjrt.py:176)
Compilation Analysis:   _thread_fn (/workspaces/dk2/pytorch/xla/torch_xla/_internal/pjrt.py:70)
Compilation Analysis:   run (/usr/local/lib/python3.8/concurrent/futures/thread.py:57)
Compilation Analysis:   _worker (/usr/local/lib/python3.8/concurrent/futures/thread.py:80)
Compilation Analysis:   ..........
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================

Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis:   user mark_step
Execution Analysis: Graph Info:
Execution Analysis:   Graph Hash: 537d4b0264b029688281412214d252e9
Execution Analysis:   Number of Graph Inputs: 588
Execution Analysis:   Number of Graph Outputs: 320
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis:   mark_step (/workspaces/dk2/pytorch/xla/torch_xla/core/xla_model.py:840)
Execution Analysis:   broadcast_master_param (/workspaces/dk2/pytorch/xla/torch_xla/core/xla_model.py:1230)
Execution Analysis:   train_imagenet (/workspaces/dk2/pytorch/xla/test/test_train_mp_imagenet.py:261)
Execution Analysis:   _mp_fn (/workspaces/dk2/pytorch/xla/test/test_train_mp_imagenet.py:365)
Execution Analysis:   __call__ (/workspaces/dk2/pytorch/xla/torch_xla/_internal/pjrt.py:176)
Execution Analysis:   _thread_fn (/workspaces/dk2/pytorch/xla/torch_xla/_internal/pjrt.py:70)
Execution Analysis:   run (/usr/local/lib/python3.8/concurrent/futures/thread.py:57)
Execution Analysis:   _worker (/usr/local/lib/python3.8/concurrent/futures/thread.py:80)
Execution Analysis:   ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================

編譯/執行的一些常見原因是

  1. 使用者手動呼叫 mark_step

  2. 平行加載器為每 x 個(可配置)批次呼叫 mark_step

  3. 退出分析器 StepTrace 區域

  4. Dynamo 決定編譯/執行圖表。

  5. 使用者試圖在 mark_step 之前訪問張量的值(通常是由於日誌記錄)。

由 1-4 引起的執行是預期的,我們希望通過減少訪問張量值的頻率或在訪問之前手動添加 mark_step 來避免 5。

使用者應該預期在前幾個步驟中會看到 Compilation Cause + Executation Cause 的配對。在模型穩定後,使用者應該只會看到 Execution Cause。為了有效地使用 PyTorch/XLA,我們預期每個步驟都運行相同的模型程式碼,並且每個圖形只編譯一次。如果您不斷看到 Compilation Cause,您應該嘗試按照本節的說明傾印 IR/HLO,並比較每個步驟的圖形,了解差異的來源。

以下章節將說明如何取得和理解更詳細的指標報告。

取得指標報告

在您的程式中加入以下這一行,即可產生報告

import torch_xla.debug.metrics as met

# For short report that only contains a few key metrics.
print(met.short_metrics_report())
# For full report that includes all metrics.
print(met.metrics_report())

理解指標報告

報告包含以下內容

  • 我們發出 _XLA_ 編譯的次數和花費的時間。

  • 我們執行的次數和花費的時間

  • 我們建立/銷毀的裝置資料控點數量等等。

此資訊以樣本的百分比表示。例如

Metric: CompileTime
  TotalSamples: 202
  Counter: 06m09s401ms746.001us
  ValueRate: 778ms572.062us / second
  Rate: 0.425201 / second
  Percentiles: 1%=001ms32.778us; 5%=001ms61.283us; 10%=001ms79.236us; 20%=001ms110.973us; 50%=001ms228.773us; 80%=001ms339.183us; 90%=001ms434.305us; 95%=002ms921.063us; 99%=21s102ms853.173us

我們還提供計數器,它們是以整數變數命名的,用於追蹤內部軟體狀態。例如

Counter: CachedSyncTensors
  Value: 395

在本報告中,任何以 aten:: 開頭的計數器都表示 XLA 裝置和 CPU 之間的上下文切換,這可能是模型程式碼中潛在的效能優化區域。

計數器有助於了解哪些操作被路由回 _PyTorch_ 的 CPU 引擎。它們使用 C++ 命名空間進行完整限定

Counter: aten::nonzero
  Value: 33

如果您看到 aten:: 操作不是 nonzero_local_scalar_dense,那通常表示 PyTorch/XLA 中缺少降低。歡迎在 GitHub 問題上為其開啟功能請求。

清除指標報告

如果您想在步驟/時期之間清除指標,可以使用

import torch_xla.debug.metrics as met

met.clear_all()

效能分析

如需深入分析您的工作負載以了解瓶頸,請查看以下資源

已知的效能注意事項

PyTorch/XLA 在語義上表現得像普通的 PyTorch,XLA 張量與 CPU 和 GPU 張量共用完整的張量介面。但是,XLA/硬體中的限制和延遲評估模型表明,某些模式可能會導致效能不佳。

如果您的模型顯示效能不佳,請牢記以下注意事項

  1. 如果重新編譯次數過多,XLA/TPU 的效能會下降。

    XLA 編譯成本很高。PyTorch/XLA 會在每次遇到新形狀時自動重新編譯圖形。通常,模型應該在幾個步驟內穩定下來,並且您可以看到其餘訓練的巨大加速。

    為了避免重新編譯,不僅形狀必須恆定,而且所有主機上 XLA 裝置的計算也必須恆定。

    可能的原因:

    • 直接或間接使用 nonzero 會引入動態形狀;例如,遮罩索引 base[index],其中 index 是遮罩張量。

    • 步驟之間迭代次數不同的迴圈可能會導致不同的執行圖形,因此需要重新編譯。

    解決方案:

    • 張量形狀在迭代之間應保持一致,或者應使用少量形狀變化。

    • 盡可能將張量填充到固定大小。

  2. 某些操作沒有到 XLA 的原生轉換。

    對於這些操作,PyTorch/XLA 會自動轉移到 CPU 記憶體,在 CPU 上評估,然後將結果轉移回 XLA 裝置。在訓練步驟中執行太多此類操作可能會導致顯著的速度減慢。

    可能的原因:

    • item() 操作明確要求評估結果。除非必要,否則不要使用它。

    解決方案:

    • 對於大多數操作,我們可以將它們降低到 XLA 來修復它。查看指標報告章節以找出缺少的操作,並在 GitHub 上開啟功能請求。

    • 即使已知 PyTorch 張量是純量,也要避免使用 tensor.item()。將其保留為張量,並對其使用張量操作。

    • 在適用時使用 torch.where 來替代控制流程。例如,clip_grad*norm* 中使用的帶有 item() 的控制流程是有問題的,會影響效能,因此我們已修補 clip_grad_norm_,改為呼叫 torch.where,這使我們的效能得到了顯著提升。 .. code-block:: python

      … else

      device = parameters[0].device total_norm = torch.zeros([], device=device if parameters else None) for p in parameters

      param_norm = p.grad.data.norm(norm_type) ** norm_type total_norm.add_(param_norm)

      total_norm = (total_norm ** (1. / norm_type))

      clip_coef = torch.tensor(max_norm, device=device) / (total_norm + 1e-6) for p in parameters

      p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))

  3. ``torch_xla.distributed.data_parallel`` 中的迭代器可能會丟棄輸入迭代器中的最後幾個批次。

    這是為了確保我們在所有 XLA 裝置上執行相同數量的工作。

    解決方案:

    • 當資料集很小,並且步驟太少時,這可能會導致無效的時期。因此,在這種情況下最好使用小批次大小。

XLA 張量怪癖

  1. **XLA 張量內部是不透明的。** XLA 張量始終顯示為連續的,並且沒有儲存空間。網路不應嘗試檢查 XLA 張量的步幅。

  2. **儲存 XLA 張量之前,應先將其移至 CPU。** 直接儲存 XLA 張量會導致它們在儲存它們的裝置上重新載入。如果載入時裝置不可用,則載入將失敗。在儲存 XLA 張量之前將其移至 CPU,您可以決定將載入的張量放在哪些裝置上。如果您想在沒有 XLA 裝置的機器上載入張量,則這是必要的。但是,在儲存 XLA 張量之前,應小心地將其移至 CPU,因為跨裝置類型移動張量不會保留檢視關係。相反,應在載入張量後根據需要重建檢視。

  3. **使用 Python 的 copy.copy 複製 XLA 張量會返回深拷貝,而不是淺拷貝。** 使用 XLA 張量的檢視來獲取其淺拷貝。

  4. **處理共用權重。** 模組可以通過將一個模組的參數設定為另一個模組來共用權重。模組權重的這種“綁定”應該在將模組移至 XLA 裝置之後完成。否則,將在 XLA 裝置上建立共用張量的兩個獨立副本。

更多除錯工具

我們不希望使用者使用本節中的工具來除錯他們的模型。但是,當您提交錯誤報告時,我們可能會要求提供這些資訊,因為它們提供了指標報告中沒有的額外資訊。

  • print(torch_xla._XLAC._get_xla_tensors_text([res])),其中 res 是結果張量,會印出 IR。

  • print(torch_xla._XLAC._get_xla_tensors_hlo([res])),其中 res 是結果張量,會印出生成的 XLA HLO。

請注意,必須在 mark_step() 之前呼叫這些函式,否則張量將已被具體化。

環境變數

還有一些環境變數可以控制 _PyTorch/XLA_ 軟體堆疊的行為。

設定此類變數將導致不同程度的效能下降,因此應僅啟用它們以進行除錯。

  • XLA_IR_DEBUG:啟用在建立 IR 節點時擷取 _Python_ 堆疊追蹤,從而可以了解哪個 _PyTorch_ 操作負責生成 IR。

  • XLA_HLO_DEBUG:啟用在啟用 _XLA_IR_DEBUG_ 時擷取的 _Python_ 堆疊框架,以傳播到 _XLA_ _HLO_ 中繼資料。

  • XLA_SAVE_TENSORS_FILE:用於在執行期間傾印 IR 圖形的檔案路徑。請注意,如果選項保持啟用狀態並且 _PyTorch_ 程式長時間執行,則檔案可能會變得非常大。圖形會附加到檔案中,因此要在每次執行時都有一個乾淨的表單,應明確刪除該檔案。

  • XLA_SAVE_TENSORS_FMT:_XLA_SAVE_TENSORS_FILE_ 檔案中儲存的圖形格式。可以是 text(預設值)、dot(_Graphviz_ 格式)或 hlo

  • XLA_FLAGS=--xla_dump_to:如果設定為 =/tmp/dir_name,XLA 編譯器將在每次編譯時傾印未優化和優化的 HLO。

  • XLA_METRICS_FILE:如果設定,則為本地檔案的路徑,其中內部指標將在每個步驟中儲存到該檔案。指標將附加到檔案中(如果已存在)。

  • XLA_SAVE_HLO_FILE:如果設定,則為本地檔案的路徑,如果發生編譯/執行錯誤,將儲存違規的 HLO 圖形。

  • XLA_SYNC_WAIT:強制 XLA 張量同步操作等待其完成,然後再轉到下一步。

  • XLA_USE_EAGER_DEBUG_MODE:強制 XLA 張量立即執行,表示逐一編譯和執行 torch 操作。這在略過冗長的編譯時間時很有用,但整體步驟時間會慢很多,而且記憶體使用量會比較高,因為會略過所有編譯器最佳化。

  • XLA_USE_BF16:如果設定為 1,則在傳送到 _TPU_ 裝置時,會將所有 _PyTorch_ _Float_ 值轉換為 _BiFloat16_。請注意,使用 XLA_USE_BF16=1 時,張量算術將以降低的精度完成,因此如果隨著時間累積,張量將不準確。例如

    # In reduced bfloat16 precision
    >>> torch.tensor(4096, dtype=torch.bfloat16) + torch.tensor(1, dtype=torch.bfloat16)
    tensor(4096., dtype=torch.bfloat16)
    # Whereas in full float32 precision
    >>> torch.tensor(4096) + torch.tensor(1)
    tensor(4097)
    

    因此,若要取得準確的指標,例如多個步驟的平均損失值,請使用手動混合精度,讓指標保持在 FP32。

  • XLA_USE_F16:如果設定為 1,則在傳送到支援的裝置時,會將所有 _PyTorch_ _Float_ 值轉換為 _Float16_ (_PyTorch_ _Half_ 類型)。

  • TF_CPP_LOG_THREAD_ID:如果設定為 1,TF 記錄將會顯示執行緒 ID,有助於對多執行緒處理序進行偵錯。

  • TF_CPP_VMODULE:用於 TF VLOG 的環境變數,採用 TF_CPP_VMODULE=name=value,... 的形式。請注意,對於 VLOG,您必須設定 TF_CPP_MIN_LOG_LEVEL=0

  • TF_CPP_MIN_LOG_LEVEL:要列印訊息的層級。 TF_CPP_MIN_LOG_LEVEL=0 將會開啟 INFO 記錄, TF_CPP_MIN_LOG_LEVEL=1 將會開啟 WARNING 記錄,依此類推。我們的 PyTorch/XLA TF_VLOG 預設使用 tensorflow::INFO 層級,因此若要查看 VLOG,請設定 TF_CPP_MIN_LOG_LEVEL=0

  • XLA_DUMP_HLO_GRAPH:如果在編譯或執行錯誤時設定為 =1,則會將違規的 HLO 圖表作為 xla_util.cc 引發的執行階段錯誤的一部分傾印。

常見的偵錯環境變數組合

  • 以 IR 格式記錄圖表執行

    XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="text" XLA_SAVE_TENSORS_FILE="/tmp/save1.ir"
    
  • 以 HLO 格式記錄圖表執行

    XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="hlo" XLA_SAVE_TENSORS_FILE="/tmp/save1.hlo"
    
  • 顯示執行階段和圖表編譯/執行的偵錯 VLOG

    TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE="xla_graph_executor=5,pjrt_computation_client=3"
    

重現 PyTorch/XLA CI/CD 單元測試失敗。

您可能會看到 PR 的一些測試失敗,例如

To execute this test, run the following from the base repo dir:
    PYTORCH_TEST_WITH_SLOW=1 python ../test/test_torch.py -k test_put_xla_uint8

直接在命令列中執行這個命令行是無效的。您需要在您的本機 pytorch/xla/test/pytorch_test_base.py 中設定環境變數 TORCH_TEST_DEVICES。例如

TORCH_TEST_DEVICES=/path/to/pytorch/xla/test/pytorch_test_base.py PYTORCH_TEST_WITH_SLOW=1 python ../test/test_torch.py -k test_put_xla_uint8 應該可以正常運作。

PJRT 執行階段

PyTorch/XLA 已從基於 TensorFlow 的 XRT 執行階段遷移到 PJRT 執行階段,這是 JAX 使用的執行階段。

如果您在使用 PJRT 時遇到錯誤,請在 GitHub 上提出問題,並標記為 runtime

PyTorch/XLA r2.1 中的新功能:

  • PJRT 在 PyTorch/XLA r2.1 中很穩定!

  • 公開執行階段 API 已從 torch_xla.experimental.pjrt 移至 torch_xla.runtime

    • pjrt:// 初始化方法已重新命名為 xla://,並由 torch_xla.distributed.xla_backend 註冊。

    • 先前的 torch_xla.experimental.* 名稱在此版本中仍然可用,以確保相容性。

  • 使用 init_method='xla://' 時,現在支援 torchrun

  • 透過 PJRT C API 為 XPU 和 Neuron 提供的新外掛。

PyTorch/XLA r2.0 中的新功能:

  • 如果您沒有傳入任何其他執行階段設定,則預設會設定 PJRT。如果您繼續設定 XRT 設定 (XRT_TPU_CONFIG),則此變更不會產生影響

  • libtpu 中新的 TPU 執行階段實作將效能提升了高達 30%。

  • 新的 xm.rendezvous 實作可擴展至數千個 TPU 核心

  • [實驗性] torch.distributed 支援 TPU v2 和 v3,包括 pjrt:// init_method

TL;DR

  • 若要使用 PJRT 預覽執行階段,請將 PJRT_DEVICE 環境變數設定為 CPUTPUCUDA

  • 在 XRT 中,所有分散式工作負載都是多程序的,每個裝置一個程序。在 PJRT 的 TPU v2 和 v3 上,工作負載是多程序和多執行緒的(4 個程序,每個程序 2 個執行緒),因此您的工作負載應該是執行緒安全的。如需更多資訊,請參閱 TPU v2/v3 上的多執行緒API 指南的多程序區段。需要牢記的幾個主要差異

    • 若要以執行緒安全的方式初始化模型,請在初始化後跨複本廣播參數 (torch_xla.experimental.pjrt.broadcast_master_param),或從共用檢查點載入每個複本的參數。

    • 對於其他隨機數產生,請盡可能使用 torch.Generator。全域 torch RNG _不是_ 執行緒安全的,即使您在複本中設定相同的 torch.manual_seed 也是如此。

    • 若要使用 torch.distributed,請匯入 torch_xla.experimental.pjrt_backend 並使用 xla:// init_method

    • 這些步驟對於 GPU 和 TPU v4 是選用的。

從 XRT 到 PJRT 的範例差異

 import os

 import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.optim as optim
 import torch.distributed as dist
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.parallel_loader as pl
 import torch_xla.distributed.xla_backend
 import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.runtime as xr


 def _mp_fn(index):
   device = xm.xla_device()
-  dist.init_process_group('xla', rank=xm.get_ordinal(), world_size=xm.xrt_world_size())
+  dist.init_process_group('xla', init_method='xla://')

   torch.manual_seed(42)
   model = nn.Linear(128, 10).to(device)

+  # Optional for TPU v4 and GPU
+  xm.broadcast_master_param(model)
   model = DDP(model, gradient_as_bucket_view=True)

   loss_fn = nn.MSELoss()
   optimizer = optim.SGD(model.parameters(), lr=.001)

   for i in range(10):
     data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)

     optimizer.zero_grad()
     output = model(data)
     loss = loss_fn(output, target)
     loss.backward()

     optimizer.step()
     xm.mark_step()

   # Print mean parameters so we can confirm they're the same across replicas
   print([p.mean() for p in model.parameters()])

 if __name__ == '__main__':
-  os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'
-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'

+  # Recommended: set PJRT_DEVICE to your local device type
+  os.environ['PJRT_DEVICE'] = 'TPU'

   xmp.spawn(_mp_fn)

優點

  • 簡單的執行階段設定:只需將 PJRT_DEVICE 設定為 TPUCPUCUDA,即可開始使用 XLA!或者,讓 PJRT 根據您的環境自動選擇裝置。

  • 效能提升:減少 gRPC 的負擔,表示端對端執行速度更快。在 TorchBench 2.0 上,我們觀察到 TPU v4 上的訓練時間縮短了 >35%。

  • 輕鬆執行 Pod:只需將您的程式碼複製到每個 TPU 工作節點,並使用 gcloud compute tpus tpuvm ssh --worker=all 同時執行它們即可。

  • 更好的擴展性:移除 XRT 對參數大小的限制,並支援多達 2048 個 TPU 晶片。

快速入門

若要開始在 PyTorch/XLA 中使用 PJRT,您只需要設定 PJRT_DEVICE 環境變數即可。如果您使用的是 TPU v2 或 v3,請继续阅读以了解 TPU v2 和 v3 與 v4 之间的差异。

CPU

在任何安裝了 PyTorch/XLA 的機器上,您都可以在 CPU 上執行我們的 MNIST 範例,如下所示

PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data

TPU

若要使用安裝的 PyTorch/XLA r2.0 建立新的 TPU

gcloud alpha compute tpus tpu-vm create $USER-pjrt --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0 --zone=us-central2-b --project=$PROJECT

在 v4-8 上,您可以執行我們的 ResNet50 範例,如下所示

git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

預設情況下,PJRT 將會使用所有 TPU 晶片。若要僅使用一個 TPU 晶片,請設定 TPU_PROCESS_BOUNDSTPU_VISIBLE_CHIPS

TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_CHIPS=0 PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

Pod

在 TPU Pod 上,使用 gcloud 在每個 TPU 上平行執行您的命令

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git"
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"

Docker

您也可以使用 Docker 在預先安裝了 PyTorch/XLA 的容器中執行您的工作負載

export DOCKER_IMAGE=gcr.io/...

# Optional: authenticate docker if your image is in a private GCP repository
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo gcloud auth configure-docker"

# Run your workload
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU $DOCKER_IMAGE python pytorch/xla/test/test_train_mp_imagenet.py --fake_data"

請注意,docker run 需要對主機的專屬存取權 (--privileged),才能將 TPU 裝置公開給容器。目前,TPU Pod 上的 Docker 僅支援主機網路 --net=host。如需更多資訊,請參閱 Cloud TPU 文件

GPU

單一節點 GPU 訓練

若要將 GPU 與 PJRT 搭配使用,只需設定 PJRT_DEVICE=CUDA 並將 GPU_NUM_DEVICES 設定為主機上的裝置數量。例如

PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1

您也可以使用 torchrun 來啟動單一節點多 GPU 訓練。例如,

PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

在上述範例中,--nnodes 表示要使用多少台機器(實體機器或 VM)(因為我們進行單一節點訓練,所以為 1)。--nproc-per-node 表示要使用多少個 GPU 裝置。

多節點 GPU 訓練

**請注意,此功能僅適用於 cuda 12 以上版本**。與 PyTorch 使用多節點訓練的方式類似,您可以執行以下命令

PJRT_DEVICE=CUDA torchrun \
--nnodes=${NUMBER_GPU_VM} \
--node_rank=${CURRENT_NODE_RANK} \
--nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \
--rdzv_endpoint=<internal_ip_address:port> multinode_training.py
  • --nnodes:要使用多少台 GPU 機器。

  • --node_rank:目前 GPU 機器的索引。值可以是 0、1、…、${NUMBER_GPU_VM}-1。

  • --nproc_per_node:要在目前機器上使用的 GPU 裝置數量。

  • –rdzv_endpoint:節點排名 (node_rank) 為 0 的 GPU 機器的端點,格式為 host:port`。 ``host將會是內部 IP 位址。 port` 可以是機器上任何可用的埠。 對於單節點訓練/推論,可以省略此參數。

例如,如果您想在 2 台 GPU 機器上進行訓練:machine_0 和 machine_1,請在第一台 GPU 機器 machine_0 上執行

# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

在第二台 GPU 機器上,執行

# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=1 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

上述 2 個命令之間的差異在於 --node_rank 以及如果要在每台機器上使用不同數量的 GPU 裝置,則可能還有 --nproc_per_node。 所有其餘部分都相同。 如需有關 torchrun 的更多資訊,請參閱此 頁面

與 XRT 的差異

雖然在大多數情況下,我們預計 PJRT 和 XRT 從終端使用者的角度來看大部分可以互換使用(尤其是在 TPU v4 上),但還是有一些細微的差異需要注意。重要的是,XRT 是圍繞 TPU 節點架構設計的,因此它始終會產生一個客戶端程序和一個伺服器程序,即使是在 TPU VM 上也是如此。因此,每一批輸入都會因為序列化和反序列化資料以通過網路發送而產生額外的延遲。

PJRT 直接使用本機裝置,沒有中間伺服器程序。在預設配置中,PJRT 將為每個 TPU 晶片建立一個程序,或為每個 TPU 主機建立 4 個程序。如需有關 TPU 架構的更多資訊,請參閱 Cloud TPU 文件

  • 對於受限於開銷的工作負載,可以提高效能。

  • 在 XRT 下,伺服器程序是唯一與 TPU 裝置互動的程序,而客戶端程序無法直接存取 TPU 裝置。在分析單主機 TPU(例如 v3-8 或 v4-8)時,您通常會看到 8 個裝置追蹤(每個 TPU 核心一個)。使用 PJRT,每個程序都有一個晶片,並且來自該程序的配置文件將僅顯示 2 個 TPU 核心。

    • 出於相同的原因,使用 XRT 在 TPU Pod 上進行分析無法正常運作,因為伺服器程序獨立於使用者的模型程式碼執行。PJRT 沒有這個限制,因此可以在 TPU Pod 中為每個程序分析 2 個 TPU 核心。

  • PJRT 僅支援 TPU VM 架構,我們沒有計劃使用 PJRT 支援 TPU 節點架構。

  • 使用 PJRT,執行時配置要簡單得多。執行 TPU Pod 工作負載不需要 xla_dist。相反,將您的程式碼複製到每個 TPU 主機([gcloud compute tpus tpu-vm scp](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp)),並在每個主機上並行執行程式碼(例如 [gcloud compute tpus tpu-vm ssh --workers=all --command="PJRT_DEVICE=TPU python run.py"](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh))。

  • 已使用 XLA 原生集體通訊重新實作 xm.rendezvous,以增強大型 TPU pod 上的穩定性。有關更多詳細資訊,請參見下文。

TPU v2/v3 上的多執行緒

在 TPU v2 和 v3 上,分散式工作負載始終以多執行緒方式執行,因為每個 TPU 晶片公開兩個 TPU 核心作為裝置,並且一次只有一個程序可以開啟一個 TPU 晶片。在其預設配置中,xmp.spawn 會自動產生盡可能多的程序(每個 TPU 主機 4 個),並為每個程序建立兩個執行緒(每個 TPU 核心一個)。

注意:在 TPU v4 上,每個 TPU 晶片表示為一個 PyTorch 裝置,因此分散式工作負載將在 4 個程序中執行,每個程序只有一個執行緒。這與 XRT 的行為相同。

在大多數情況下,這不需要對您現有的程式碼進行重大更改。在大多數情況下,您必須進行的主要更改是模型初始化。因為 torch 的全域 RNG 在執行緒之間共用,即使您在每個副本中將 torch.manual_seed 設定為相同的值,結果也會因執行緒和執行而異。要獲得副本之間一致的參數,請使用 torch_xla.experimental.pjrt.broadcast_master_param 將一個副本的參數廣播到所有其他副本,或從公共檢查點載入每個副本的參數。

對 xm.rendezvous 的變更

PyTorch/XLA r2.0 中的新功能

使用 XRT,工作節點 0 執行網格主服務,所有工作節點上的所有程序都通過 gRPC 連接到該服務。在實務中,我們發現由於到工作節點 0 的入站連接數量眾多,在具有數千個晶片的 TPU pod 上執行單個網格主程序是不可靠的。單個客戶端程序超時可能會導致故障,並強制整個工作負載重新啟動。

因此,我們使用原生 XLA 集體通訊重新實作了 xm.rendezvous,它在大型 TPU pod 上更加穩定且經過充分測試。與 XRT 實作相比,這引入了兩個新的約束

  • 因為有效負載必須成為 XLA 圖的一部分,所以在傳輸資料之前和之後都會呼叫 xm.mark_step。在模型程式碼中間呼叫 xm.rendezvous 可能會強制進行不必要的編譯。

  • 因為 XLA 不允許在工作節點的子集上執行集體操作,所以所有工作節點都必須參與 rendezvous

如果您需要 xm.rendezvous 的舊行為(即在不更改 XLA 圖和/或同步工作節點子集的情況下傳輸資料),請考慮使用 ``torch.distributed.barrier` <https://pytorch.com.tw/docs/stable/distributed.html#torch.distributed.barrier>`_ 或 ``torch.distributed.all_gather_object` <https://pytorch.com.tw/docs/stable/distributed.html#torch.distributed.all_gather_object>`_ 和 gloo 程序群組。如果您也在使用 xla torch.distributed 後端,則可以使用 torch.new_group 建立 gloo 子群組。請參閱 PyTorch 文件中的 此範例。請牢記這些限制

  • torch.distributed 在 TPU v2/v3 上不受完全支援。僅實作了 xla 後端的一組操作,並且 gloo 在多執行緒環境中可能無法按預期運作。

  • 在我們的實驗中,gloo 無法很好地擴展到數千個 TPU 晶片,因此預計這種替代方案的可靠性不如在大型規模上使用 PJRT 的 xm.rendezvous

PJRT 和 torch.distributed

PyTorch/XLA r2.0 中的新功能

將 PJRT 與 torch.distributed[torch.nn.parallel.DistributedDataParallel](https://github.com/pytorch/xla/blob/master/docs/ddp.md) 一起使用時,我們強烈建議使用新的 xla:// init_method,它通過查詢執行時自動找到副本 ID、世界大小和主 IP。例如

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.experimental import pjrt

# Required for `xla://` init_method and `xla` backend
import torch_xla.distributed.xla_backend

def _all_gather(index: int):
  # No need to pass in `rank` or `world_size`
  dist.init_process_group('xla', init_method='xla://')

  t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
  output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
  dist.all_gather(output, t)

  xm.mark_step()
  print(output)

if __name__ == '__main__':
  xmp.spawn(_all_gather)

注意:雖然 TPU v4 上不需要 xla:// init_method,但仍然建議使用。如果您使用 env://,則必須將 MASTER_ADDR 設定為具有裝置 0 的 IP 主機,這並非始終是工作節點 0。xla:// init_method 會自動找到此 IP。

注意:對於 TPU v2/v3,您仍然需要匯入 torch_xla.experimental.pjrt_backend,因為 torch.distributed 中的 TPU v2/v3 支援仍處於實驗階段。

如需有關在 PyTorch/XLA 上使用 DistributedDataParallel 的更多資訊,請參閱 TPU V4 上的 ``ddp.md` <./ddp.md>`_。有關同時使用 DDP 和 PJRT 的範例,請在 TPU 上執行以下 範例腳本

PJRT_DEVICE=TPU python xla/test/test_train_mp_mnist.py --ddp --pjrt_distributed --fake_data --num_epochs 1

效能

TorchBench 顯示,與 XRT 相比,PJRT 在任務的平均訓練時間方面有所改進,在 TPU v4-8 上的平均改進超過 35%。收益因任務和模型類型而異,從 0% 到 175% 不等。下表顯示了按任務劃分的明細

PJRT vs XRT

新的 TPU 執行時

PyTorch/XLA r2.0 中的新功能

PyTorch/XLA r2.0 版本引入了對 PJRT 外掛 API 的支援,該 API 用於存取 libtpu 中基於 TFRT 的新 TPU 執行時。現在,當設定 PJRT_DEVICE=TPU 時,這是預設執行時。1.13 中使用的基於 StreamExecutor 的舊版 TPU 執行時在 2.0 版本中仍可通過 PJRT_DEVICE=TPU_LEGACY 使用,但將在未來版本中移除。如果您遇到僅發生在 TPU 而非 TPU_LEGACY 上的問題,請在 GitHub 上提交問題。

在大多數情況下,我們預計兩個執行時之間的效能相似,但在某些情況下,新執行時的速度可能提高 30%。下表顯示了按任務劃分的明細

TFRT vs StreamExecutor

注意:此圖表中顯示的改進也包含在 PJRT 與 XRT 的比較中。

PyTorch XLA 中的 TorchDynamo(torch.compile)整合

TorchDynamo 是一種 Python 級 JIT 編譯器,旨在使未經修改的 PyTorch 程式更快。它為編譯器後端提供了一個乾淨的 API 來掛鉤,其最大特點是在執行之前動態修改 Python 位元組碼。在 pytorch/xla 2.0 版本中,PyTorch/XLA 為 TorchDynamo 提供了一個用於推論和訓練的實驗性後端。

XLA 橋接的工作方式是,當 Dynamo 識別到模型模式時,它將提供一個 TorchFX 圖,而 PyTorch/XLA 將使用現有的 Lazy Tensor 技術來編譯 FX 圖並返回編譯後的函數。

整合

目前,通過將 backend='openxla' 參數添加到 torch.compile 來支援 PyTorch/XLA 和 Dynamo。例如

import torch
import torch_xla.core.xla_model as xm

def add(a, b):
  a_xla = a.to(xm.xla_device())
  b_xla = b.to(xm.xla_device())
  return a_xla + b_xla

compiled_code = torch.compile(add, backend='openxla')
print(compiled_code(torch.randn(10), torch.randn(10)))

推論

以下是一個使用 torch.compile 執行 resnet18 的小型程式碼範例

import torch
import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.eval()
  dynamo_resnet18 = torch.compile(
    xla_resnet18, backend='openxla')
  for data, _ in loader:
    with torch.no_grad():
      output = dynamo_resnet18(data)

使用 torch.compile,您會發現 PyTorch/XLA 僅在初始化期間追蹤 resent18 模型一次,並在每次呼叫 dynamo_resnet18 時執行已編譯的二進制文件,而不是每次都追蹤模型。以下是在 Cloud TPU v4-8 上使用 torch bench 比較 Dynamo 和 Lazy 的推論速度分析。

resnet18 | 2.59 resnet50 | 2.64 resnext50_32x4d | 1.91 alexnet | 1.28 mobilenet_v2 | 18.62 mnasnet1_0 | 2.68 vgg16 | 1.33 BERT_pytorch | 7.49 squeezenet1_1 | 2.29 timm_vision_transformer | 3.52 geomean | 3.04

訓練

PyTorch/XLA 也支援使用 Dynamo 進行訓練,但這仍處於實驗階段,我們正與 PyTorch 編譯器團隊合作迭代實作。以下是如何使用 torch.compile 訓練 resnet18 的範例。

import torch
import torchvision
import torch_xla.core.xla_model as xm

def train_model(model, data, target, optimizer):
  loss_fn = torch.nn.CrossEntropyLoss()
  pred = model(data)
  loss = loss_fn(pred, target)
  loss.backward()
  optimizer.step()
  return pred

def train_model_main(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.train()
  dynamo_train_model = torch.compile(
        train_model, backend='openxla')
  for data, target in loader:
    xla_optimizer = optim.SGD(data, lr=0.1, weight_decay=1e-2)
    output = dynamo_train_model(xla_resnet18, data, target, xla_optimizer)

我們預計每個訓練步驟會提取並執行 3 個圖形,而不是像使用 Lazy 張量那樣每個訓練步驟執行 1 個圖形。以下是在 Cloud TPU v4-8 上使用 torch bench 比較 Dynamo 和 Lazy 的訓練速度分析。

resnet50 | 1.33 resnet18 | 1.33 BERT_pytorch | 3.07 resnext50_32x4d | 1.43 alexnet | 1.12 mobilenet_v2 | 1.4 mnasnet1_0 | 1.19 vgg16 | 0.81 timm_vision_transformer | 1.87 squeezenet1_1 | 1.41 geomean | 1.41

**注意:**我們針對每個模型的前向和後向傳播執行單一步驟,然後收集端到端時間。在實際情況中,我們會在每個訓練作業中執行多個步驟,這可以輕鬆隱藏執行中的追蹤成本(因為它是異步的)。在這種情況下,Lazy 張量將具有更好的效能。

功能差異

我們要指出一個阻礙我們在更大規模模型上使用 TorchDynamo 的差異。

  1. TorchDynamo 會將前向和後向追蹤到不同的圖形中。對於 PyTorch/XLA,讓 XLA 編譯器將整個步驟視為一個圖形以最佳化速度非常重要。啟動每個裝置執行也有一個固定的開銷,這使得每個訓練步驟執行多個圖形不太理想。

與 Lazy 張量相比,這種差異使其在實際訓練用例中的效率較低,尤其是在訓練中追蹤成本可以與執行重疊。

總結

TorchDynamo 為編譯器後端提供了一種非常有希望的方式來向使用者隱藏複雜性,並輕鬆以圖形格式檢索建模程式碼。與 PyTorch/XLA 傳統的 Lazy 張量提取圖形的方式相比,TorchDynamo 可以跳過每次迭代的圖形追蹤,因此提供了更好的推論回應時間。

大多數 PyTorch/XLA 支援的模型在使用新的 dynamo-xla 橋接器執行推論時,速度都顯著提高。我們的社群正在努力擴展支援的模型集。關於上面提到的訓練功能差異,PyTorch/XLA 社群非常興奮地在我們即將進行的開發工作中改進訓練差異。團隊將繼續大力投資 TorchDynamo 並與上游合作,使訓練更加成熟。

PyTorch XLA 中的全分片資料平行 (FSDP)

PyTorch XLA 中的全分片資料平行 (FSDP) 是一個用於在資料平行工作器之間分片模組參數的工具。

使用範例

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()

也可以分別分片各個層,並使用外部包裝器處理任何剩餘的參數。

注意事項

  • XlaFullyShardedDataParallel 類別在 https://arxiv.org/abs/1910.02054 中同時支援 ZeRO-2 優化器(分片梯度和優化器狀態)和 ZeRO-3 優化器(分片參數、梯度和優化器狀態)。

    • ZeRO-3 優化器應通過嵌套 FSDP 和 reshard_after_forward=True 來實作。請參閱 test/test_train_mp_mnist_fsdp_with_ckpt.pytest/test_train_mp_imagenet_fsdp.py 以獲取範例。

    • 對於無法放入單個 TPU 記憶體或主機 CPU 記憶體的大型模型,應將子模組建構與內部 FSDP 包裝交錯進行。請參閱 ``FSDPViTModel` <https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py>`_ 以獲取範例。

  • 提供了一個簡單的包裝器 checkpoint_module(基於 https://github.com/pytorch/xla/pull/3524 中的 torch_xla.utils.checkpoint.checkpoint),以便對給定的 nn.Module 實例執行 梯度檢查點。請參閱 test/test_train_mp_mnist_fsdp_with_ckpt.pytest/test_train_mp_imagenet_fsdp.py 以獲取範例。

  • 自動包裝子模組:除了手動嵌套 FSDP 包裝外,還可以指定 auto_wrap_policy 參數以使用內部 FSDP 自動包裝子模組。torch_xla.distributed.fsdp.wrap 中的 size_based_auto_wrap_policyauto_wrap_policy 可呼叫物件的範例,此策略包裝參數數量大於 100M 的層。torch_xla.distributed.fsdp.wrap 中的 transformer_auto_wrap_policy 是適用於類似變壓器的模型架構的 auto_wrap_policy 可呼叫物件的範例。

例如,要使用內部 FSDP 自動包裝所有 torch.nn.Conv2d 子模組,可以使用

from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})

此外,還可以指定 auto_wrapper_callable 參數,以便對子模組使用自定義的可呼叫包裝器(預設包裝器只是 XlaFullyShardedDataParallel 類別本身)。例如,可以使用以下內容將梯度檢查點(即啟動檢查點/重新實例化)應用於每個自動包裝的子模組。

from torch_xla.distributed.fsdp import checkpoint_module
auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
    checkpoint_module(m), *args, **kwargs)
  • 在逐步執行優化器時,直接呼叫 optimizer.step,而不要呼叫 xm.optimizer_step。後者會降低各個等級的梯度,這對於 FSDP(其中參數已經分片)是不需要的。

  • 在訓練期間儲存模型和優化器檢查點時,每個訓練過程都需要儲存其自己的(分片的)模型和優化器狀態字典的檢查點(使用 master_only=False 並為 xm.save 中的每個等級設定不同的路徑)。恢復時,它需要為相應的等級載入檢查點。

  • 還請將 model.get_shard_metadata()model.state_dict() 一起儲存,如下所示,並使用 consolidate_sharded_model_checkpoints 將分片的模型檢查點拼接在一起,形成完整的模型狀態字典。請參閱 test/test_train_mp_mnist_fsdp_with_ckpt.py 以獲取範例。 .. code-block:: python3

    ckpt = {

    ‘model’: model.state_dict(), ‘shard_metadata’: model.get_shard_metadata(), ‘optimizer’: optimizer.state_dict(),

    } ckpt_path = f’/tmp/rank-{xm.get_ordinal()}-of-{xm.xrt_world_size()}.pth’ xm.save(ckpt, ckpt_path, master_only=False)

  • 也可以從命令列啟動檢查點合併腳本,如下所示。 .. code-block:: bash

    # 通過命令列工具合併已儲存的檢查點 python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts –ckpt_prefix /path/to/your_sharded_checkpoint_files –ckpt_suffix “_rank-*-of-*.pth”

此類別的實作很大程度上受到 https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html 中的 fairscale.nn.FullyShardedDataParallel 的啟發,並且大部分遵循其結構。與 fairscale.nn.FullyShardedDataParallel 的最大區別之一是,在 XLA 中,我們沒有顯式的參數儲存,因此我們採用不同的方法來釋放 ZeRO-3 的完整參數。


MNIST 和 ImageNet 上的訓練腳本範例

安裝

FSDP 在 PyTorch/XLA 1.12 版本和更新的每夜版本中可用。有關安裝指南,請參閱 https://github.com/pytorch/xla#-available-images-and-wheels

複製 PyTorch/XLA 儲存庫

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/

在 v3-8 TPU 上訓練 MNIST

2 個時期的準確度約為 98.9%

python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --use_nested_fsdp --use_gradient_checkpointing

此腳本在最後會自動測試檢查點合併。您也可以通過以下方式手動合併分片的檢查點:

# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
  --ckpt_suffix "_rank-*-of-*.pth"

在 v3-8 TPU 上使用 ResNet-50 訓練 ImageNet

100 個時期的準確度約為 75.9%;將 ImageNet-1k 下載到 /datasets/imagenet-1k

python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
  --use_nested_fsdp

您還可以添加 --use_gradient_checkpointing(需要與 --use_nested_fsdp--auto_wrap_policy 一起使用)以對殘差塊應用梯度檢查點。


TPU pod 上的訓練腳本範例(具有 100 億個參數)

要訓練無法放入單個 TPU 的大型模型,在構建整個模型以實作 ZeRO-3 演算法時,應應用自動包裝或使用內部 FSDP 手動包裝子模組。

有關使用此 XLA FSDP PR 對 Vision Transformer (ViT) 模型進行分片訓練的範例,請參閱 https://github.com/ronghanghu/vit_10b_fsdp_example

如何執行 DistributedDataParallel

本文檔說明瞭如何在 xla 中使用 torch.nn.parallel.DistributedDataParallel,並進一步描述了它與原生 xla 資料平行方法的區別。

背景/動機

客戶長期以來一直要求能夠在 xla 中使用 PyTorch 的 DistributedDataParallel API。現在,我們將其作為一項實驗性功能啟用。

如何使用 DistributedDataParallel

對於那些從 PyTorch 急切模式切換到 XLA 的人來說,以下是要將急切 DDP 模型轉換為 XLA 模型所需進行的所有更改。我們假設您已經知道如何在 單個裝置上使用 XLA。

  1. 匯入 xla 特定的分散式套件

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend
  1. 初始化 xla 處理序組,類似於其他處理序組,例如 nccl 和 gloo。

dist.init_process_group("xla", rank=rank, world_size=world_size)
  1. 如果需要,請使用 xla 特定的 API 來獲取等級和 world_size。

new_rank = xm.get_ordinal()
world_size = xm.xrt_world_size()
  1. gradient_as_bucket_view=True 傳遞給 DDP 包裝器。

ddp_model = DDP(model, gradient_as_bucket_view=True)
  1. 最後使用 xla 特定的啟動器啟動您的模型。

xmp.spawn(demo_fn)

在這裡,我們將所有內容整合在一起(該示例實際上取自 DDP 教學)。您編寫代碼的方式與 eager 體驗非常相似。只需在單個設備上進行 xla 特定的調整,以及對腳本進行上述五項更改即可。

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP

# additional imports for xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend
import torch_xla.distributed.xla_multiprocessing as xmp

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the xla process group
    dist.init_process_group("xla", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 1000000)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(1000000, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def demo_basic(rank):
    # xla specific APIs to get rank, world_size.
    new_rank = xm.get_ordinal()
    assert new_rank == rank
    world_size = xm.xrt_world_size()

    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to XLA device
    device = xm.xla_device()
    model = ToyModel().to(device)
    # currently, graident_as_bucket_view is needed to make DDP work for xla
    ddp_model = DDP(model, gradient_as_bucket_view=True)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10).to(device))
    labels = torch.randn(20, 5).to(device)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    # xla specific API to execute the graph
    xm.mark_step()

    cleanup()


def run_demo(demo_fn):
    # xla specific launcher
    xmp.spawn(demo_fn)

if __name__ == "__main__":
    run_demo(demo_basic)

基準測試

使用虛擬數據的 Resnet50

以下結果是使用以下命令收集的:python test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1,環境為 TPU VM V3-8,使用 ToT PyTorch 和 PyTorch/XLA。統計指標是使用此 pull request 中的腳本生成的。速率單位為每秒處理的圖像數。

類型 平均值 中位數 第 90 個百分位數 標準差 變異係數
xm.optimizer_step 418.54 419.22 430.40 9.76 0.02
DDP 395.97 395.54 407.13 7.60 0.02

我們原生分散式數據平行方法與 DistributedDataParallel 包裝器之間的性能差異為:1 - 395.97 / 418.54 = 5.39%。鑑於 DDP 包裝器在追蹤 DDP 運行時引入了額外開銷,因此該結果似乎是合理的。

使用虛擬數據的 MNIST

以下結果是使用以下命令收集的:python test/test_train_mp_mnist.py --fake_data,環境為 TPU VM V3-8,使用 ToT PyTorch 和 PyTorch/XLA。統計指標是使用此 pull request 中的腳本生成的。速率單位為每秒處理的圖像數。

類型 平均值 中位數 第 90 個百分位數 標準差 變異係數
xm.optimizer_step 17864.19 20108.96 24351.74 5866.83 0.33
DDP 10701.39 11770.00 14313.78 3102.92 0.29

我們原生分散式數據平行方法與 DistributedDataParallel 包裝器之間的性能差異為:1 - 14313.78 / 24351.74 = 41.22%。由於數據集很小,前幾輪受數據加載的影響很大,因此我們比較的是第 90 個百分位數。這種速度下降是巨大的,但考虑到模型很小,這是合理的。額外的 DDP 運行時追蹤開銷難以攤銷。

使用真實數據的 MNIST

以下結果是使用以下命令收集的:python test/test_train_mp_mnist.py --logdir mnist/,環境為 TPU VM V3-8,使用 ToT PyTorch 和 PyTorch/XLA。

learning_curves

我們可以觀察到,即使 DDP 包裝器在最後仍能達到 97.48% 的高準確率,但其收斂速度比原生 XLA 方法慢。(原生方法的準確率達到 99%。)

免責聲明

此功能仍處於實驗階段,並且正在積極開發中。請謹慎使用,並歡迎將任何錯誤報告給 xla github 倉庫。對於那些對原生 xla 數據平行方法感興趣的人,這裡有一份 教學

以下是一些已知的正在調查中的問題

  • 需要強制執行 gradient_as_bucket_view=True

  • torch.utils.data.DataLoader 一起使用時,會出現一些問題。​​test_train_mp_mnist.py 在使用真實數據時會在退出前崩潰。

如何使用 PyTorch/XLA:GPU 運行

PyTorch/XLA 使 PyTorch 用戶能夠利用 XLA 編譯器,該編譯器支持包括 TPU、GPU 和 CPU 在內的加速器。本文檔將介紹在 nvidia GPU 实例上運行 PyTorch/XLA 的基本步驟。

创建 GPU 实例

您可以使用连接了 GPU 的本地机器,也可以使用云上的 GPU 虚拟机。例如,在 Google Cloud 中,您可以按照此 文档 创建 GPU 虚拟机。

环境设置

Docker

Pytorch/XLA 目前发布了使用 cuda11.7/8 和 python 3.8 的预构建 Docker 镜像和 wheel 包。我们建议用户使用相应的配置创建 Docker 容器。有关 Docker 镜像和 wheel 包的完整列表,请参阅 此文档

sudo docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1
sudo apt-get install -y apt-transport-https ca-certificates curl gnupg-agent    software-properties-common
distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add -
curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list
sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit
sudo systemctl restart docker
sudo docker run --shm-size=16g --net=host --gpus all -it -d us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1 bin/bash
sudo docker exec -it $(sudo docker ps | awk 'NR==2 { print $1 }') /bin/bash

请注意,您需要重新启动 Docker 才能在 Docker 容器中显示 GPU 设备。登录 Docker 后,您可以使用 nvidia-smi 验证设备是否已正确设置。

(pytorch) root@20ab2c7a2d06:/# nvidia-smi
Thu Dec  8 06:24:29 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    38W / 300W |      0MiB / 16384MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

检查环境变量

确保 PATHLD_LIBRARY_PATH 环境变量包含 cuda。请执行 echo $PATHecho $LD_LIBRARY_PATH 进行验证。如果不是,请按照 链接 进行操作。示例

echo "export PATH=/usr/local/cuda-12.1/bin${PATH:+:${PATH}}" >> ~/.bashrc
echo "export LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> ~/.bashrc
source ~/.bashrc

Wheel

pip3 install torch==2.2.0
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl

運行簡單模型

要运行以下示例,您需要克隆 pytorch/xla 存储库以访问 imagenet 示例(我们已经在 Docker 中克隆了它)。

(pytorch) root@20ab2c7a2d06:/# export GPU_NUM_DEVICES=1 PJRT_DEVICE=CUDA
(pytorch) root@20ab2c7a2d06:/# git clone --recursive https://github.com/pytorch/xla.git
(pytorch) root@20ab2c7a2d06:/# python xla/test/test_train_mp_imagenet.py --fake_data
==> Preparing data..
Epoch 1 train begin 06:12:38
| Training Device=xla:0/0 Epoch=1 Step=0 Loss=6.89059 Rate=2.82 GlobalRate=2.82 Time=06:13:23
| Training Device=xla:0/0 Epoch=1 Step=20 Loss=6.79297 Rate=117.16 GlobalRate=45.84 Time=06:13:36
| Training Device=xla:0/0 Epoch=1 Step=40 Loss=6.43628 Rate=281.16 GlobalRate=80.49 Time=06:13:43
| Training Device=xla:0/0 Epoch=1 Step=60 Loss=5.83108 Rate=346.88 GlobalRate=108.82 Time=06:13:49
| Training Device=xla:0/0 Epoch=1 Step=80 Loss=4.99023 Rate=373.62 GlobalRate=132.43 Time=06:13:56
| Training Device=xla:0/0 Epoch=1 Step=100 Loss=3.92699 Rate=384.33 GlobalRate=152.40 Time=06:14:02
| Training Device=xla:0/0 Epoch=1 Step=120 Loss=2.68816 Rate=388.35 GlobalRate=169.49 Time=06:14:09

AMP(自動混合精度)

AMP 在 GPU 训练中非常有用,PyTorch/XLA 重用了 Cuda 的 AMP 规则。您可以查看我们的 mnist 示例imagenet 示例。请注意,我们还使用了修改版的 优化器,以避免设备和主机之间的额外同步。

在 GPU 实例上开发 PyTorch/XLA(从源代码构建支持 GPU 的 PyTorch/XLA)

  1. 在 GPU 虚拟机中,从开发 Docker 镜像创建一个 Docker 容器。例如:

sudo docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1
sudo apt-get install -y apt-transport-https ca-certificates curl gnupg-agent    software-properties-common
distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add -
curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list
sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit
sudo systemctl restart docker
sudo docker run --shm-size=16g --net=host --gpus all -it -d us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1
sudo docker exec -it $(sudo docker ps | awk 'NR==2 { print $1 }') /bin/bash
  1. 从源代码构建 PyTorch 和 PyTorch/XLA。

确保 PATHLD_LIBRARY_PATH 环境变量包含 cuda。有关更多信息,请参阅以上内容

git clone https://github.com/pytorch/pytorch.git
cd pytorch
USE_CUDA=1 python setup.py install

git clone https://github.com/pytorch/xla.git
cd xla
XLA_CUDA=1 python setup.py install
  1. 验证 PyTorch 和 PyTorch/XLA 是否已成功安装。

如果您可以成功运行 运行简单模型 部分中的测试,则说明 PyTorch 和 PyTorch/XLA 已成功安装。

PyTorch/XLA SPMD 用户指南

在本用户指南中,我们将讨论如何在 PyTorch/XLA 中集成 GSPMD,并提供设计概述来说明 SPMD 分片注释 API 及其结构是如何工作的。然后,我们将提供一些参考示例供用户尝试。

什么是 PyTorch/XLA SPMD?

GSPMD 是一种针对常见机器学习工作负载的自动并行化系统。XLA 编译器将根据用户提供的分片提示,将单设备程序转换为具有适当集合操作的分区程序。此功能允许开发人员编写 PyTorch 程序,就好像它们位于一个大型设备上一样,而无需任何自定义分片计算操作和/或集合通信即可进行扩展。

alt_text

*图 1. 两种不同执行策略的比较,(a) 非 SPMD 和 (b) SPMD。*

为了在 PyTorch/XLA 中支持 GSPMD,我们引入了一种新的执行模式。在 GSPMD 之前,PyTorch/XLA 中的执行模式假设有多个模型副本,每个副本都有一个核心(图 1.a)。如上所述,这种执行模式适用于数据并行框架,例如流行的 PyTorch 分布式数据并行 (DDP) 或完全分片数据并行 (FSDP),但也有其局限性,即一个副本只能驻留在一个设备核心上执行。PyTorch/XLA SPMD 引入了一种新的执行模式,该模式假设有一个具有多个核心的副本(图 1.b),允许一个副本跨多个设备核心运行。这种转变解锁了更高级的并行策略,以提高大型模型训练的性能。

PyTorch/XLA SPMD 在新的 PJRT 运行时上可用。要启用 PyTorch/XLA SPMD 执行模式,用户必须调用 [use_spmd() API](https://github.com/pytorch/xla/blob/b8b484515a97f74e013dcf38125c44d53a41f011/torch_xla/runtime.py#L214)

import torch_xla.runtime as xr

# Enable PyTorch/XLA SPMD execution mode.
xr.use_spmd()
assert xr.is_spmd() == True

重要的是要注意,SPMD 是对任何现有并行机制(包括 DDP 和 FSDP)的替代。用户不能混合使用两种不同的执行模式(SPMD 和非 SPMD),在本指南的后面部分,我们将介绍如何使用 SPMD 注释来执行 DDP 和 FSDP。

此外,此版本的 SPMD 目前仅在 Google Cloud TPU 上进行了测试和优化。对 GPU 的支持和优化将在 2.2 版本中提供。

PyTorch/XLA SPMD 设计概述

简单示例和分片注释 API

用户可以使用 mark_sharding API (源代码) 注释原生 PyTorch 张量。它接受 torch.Tensor 作为输入,并返回 XLAShardedTensor 作为输出。

def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Union[int, None]]) -> XLAShardedTensor

调用 mark_sharding API 需要用户定义的逻辑 网格分区规范,并为 XLA 编译器生成分片注释。分片规范附加到 XLATensor。以下是 [RFC 中的一个简单使用示例,用于说明分片注释 API 的工作原理:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh

# Enable XLA SPMD execution mode.
xr.use_spmd()

# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
mesh_shape = (2, 4)
num_devices = xr.global_runtime_device_count()
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

t = torch.randn(8, 4).to(xm.xla_device())

# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = (0, 1)
m1_sharded = xs.mark_sharding(t, mesh, partition_spec)
assert isinstance(m1_sharded, XLAShardedTensor) == True

我们可以注释 PyTorch 程序中的不同张量,以启用不同的并行技术,如下面的注释所述:

# Sharding annotate the linear layer weights.
model = SimpleLinear().to(xm.xla_device())
xs.mark_sharding(model.fc1.weight, mesh, partition_spec)

# Training loop
model.train()
for step, (data, target) in enumerate(loader):
  # Assumes `loader` returns data, target on XLA device
  optimizer.zero_grad()
  # Sharding annotate input data, we can shard any input
  # dimensions. Sharidng the batch dimension enables
  # in data parallelism, sharding the feature dimension enables
  # spatial partitioning.
  xs.mark_sharding(data, mesh, partition_spec)
  ouput = model(data)
  loss = loss_fn(output, target)
  optimizer.step()
  xm.mark_step()

更多完整的单元测试用例和集成测试示例可在 PyTorch/XLA 仓库 中找到。

网格

对于给定的设备集群,物理网格表示互连拓扑。

我们基于此拓扑派生出一个逻辑网格,以创建设备子组,这些子组可用于对模型中张量的不同轴进行分区。

alt_text

我們使用 Mesh API 來抽象化邏輯網格。邏輯網格的軸可以命名。以下是一個範例

import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh

# Assuming you are running on a TPU host that has 8 devices attached
num_devices = xr.global_runtime_device_count()
# mesh shape will be (4,2) in this example
mesh_shape = (num_devices // 2, 2)
device_ids = np.array(range(num_devices))
# axis_names 'x' nad 'y' are optional
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

mesh.get_logical_mesh()
>> array([[0, 1],
          [2, 3],
          [4, 5],
          [6, 7]])
mesh.shape()
>> OrderedDict([('x', 4), ('y', 2)])

一般而言,SPMD 程式應建立單一網格並將其重複用於所有分片,以確保分片分配與預期的分片策略一致。透過操作分區規格(如下所述),可以將相同的網格重複用於不同形狀和分片的張量。

混合網格

網格很好地抽象化了物理設備網格的建構方式。使用者可以使用邏輯網格以任何形狀和順序排列設備。但是,可以根據物理拓撲定義效能更高的網格,尤其是在涉及資料中心網路 (DCN) 跨切片連線時。HybridMesh 建立了一個網格,可以在此類多切片環境中提供良好的開箱即用效能。它接受 ici_mesh_shape 和 dcn_mesh_shape,它們表示內部和外部網路的邏輯網格形狀。

from torch_xla.distributed.spmd import HybridMesh

# This example is assuming 2 slices of v4-8.
# - ici_mesh_shape: shape of the logical mesh for inner connected devices.
# - dcn_mesh_shape: shape of logical mesh for outer connected devices.
ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
dcn_mesh_shape = (2, 1, 1)

mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor'))
print(mesh.shape())
>> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])

分區規格

partition_spec 的秩與輸入張量相同。每個維度描述了相應的輸入張量維度如何在設備網格(由 mesh_shape 邏輯定義)上分片。partition_spec 是一個由 device_mesh 維度 index 或 None 組成的元組。如果相應的網格維度已命名,則索引可以是 intstr。這指定了每個輸入秩如何分片(indexmesh_shape)或複製(None)。

# Provide optional mesh axis names and use them in the partition spec
mesh = Mesh(device_ids, (4, 2), ('data', 'model'))
partition_spec = ('model', 'data')
xs.mark_sharding(input_tensor, mesh, partition_spec)

我們支援原始 GSPMD 論文中描述的所有三種類型的分片。例如,可以像這樣指定部分複製

# Provide optional mesh axis names and use them in the partition spec
mesh = Mesh(device_ids, (2, 2, 2), ('x', 'y', 'z'))

# evenly shard across x and z and replicate among y
partition_spec = ('x', 'z')  # equivalent to ('x', None, 'z')
xs.mark_sharding(input_tensor, mesh, partition_spec)

分區規格允許為不同的張量形狀和所需的分片策略重複使用相同的網格。以下範例使用 3D 網格演示了這一點

# Create a 3-D mesh of 8 devices with logical dimensions replica, fsdp, and
# tensor
mesh = Mesh(device_ids, (2, 2, 2), ('replica', 'fsdp', 'tensor'))

# A 2D tensor can be sharded along the fsdp and tensor axes and replicated
# along the replica axis by omitting `replica` from the partition spec.
two_d_partially_replicated = torch.randn(64, 64, device='xla')
xs.mark_sharding(two_d_partially_replicated, mesh, ('fsdp', 'tensor'))

# A 2D tensor can be sharded across all dimensions by combining, for example,
# the replica and fsdp mesh axes using a tuple
two_d_fully_sharded = torch.randn(64, 64, device='xla')
xs.mark_sharding(two_d_fully_sharded, mesh, (('replica', 'fsdp'), 'tensor'))

# A 4D tensor can be sharded along up to three of its axes using the 3D mesh
four_d = torch.randn(64, 64, 64, 64, device='xla')
xs.mark_sharding(four_d, ('replica', 'fsdp', None, 'tensor'))

XLAShardedTensor

XLAShardedTensor [RFC] 的主要用例是用分片規格註釋原生 torch.tensor(在單個設備上)。註釋會立即進行,但張量的實際分片會延遲,因為計算是延遲執行的,但輸入張量會立即分片。將張量註釋並包裝在 XLAShardedTensor 中後,可以將其作為 torch.Tensor 傳遞到現有的 PyTorch 運算和 nn.Module 層。這一點很重要,可以確保相同的 PyTorch 層和張量運算可以與 XLAShardedTensor 堆疊在一起。這意味著使用者不需要為分片計算重寫現有的運算和模型程式碼。也就是說,XLAShardedTensor 將滿足以下要求

  • XLAShardedTensortorch.Tensor 的子類,可以直接與原生 torch 運算和 module.layers 一起使用。我們使用 __torch_dispatch__XLAShardedTensor 發送到 XLA 後端。PyTorch/XLA 檢索附加的分片註釋以追蹤圖形並呼叫 XLA SPMDPartitioner。

  • 在內部,XLAShardedTensor(及其 global_tensor 輸入)由 XLATensor 支援,該 XLATensor 具有特殊的資料結構,其中包含對分片設備資料的引用。

  • 延遲執行後的分片張量可以在主機上請求時(例如,列印全域張量的值)作為 global_tensor 收集並具體化回主機。

  • 對本地分片的控制代碼會在延遲執行後嚴格具體化。XLAShardedTensor 公開了 local_shards,以將可定址設備上的本地分片作為 List[[XLAShard](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L12)] 返回。

目前也在努力將 XLAShardedTensor 整合到 DistributedTensor API 中,以支援 XLA 後端 [RFC]。

DTensor 整合

PyTorch 已在 2.1 版中原型發布了 DTensor。我們正在將 PyTorch/XLA SPMD 整合到 DTensor API 中 RFC。我們對 distribute_tensor 進行了概念驗證整合,它會呼叫 mark_sharding 註釋 API 來使用 XLA 對張量及其計算進行分片

import torch
from torch.distributed import DeviceMesh, Shard, distribute_tensor

# distribute_tensor now works with `xla` backend using PyTorch/XLA SPMD.
mesh = DeviceMesh("xla", list(range(world_size)))
big_tensor = torch.randn(100000, 88)
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])

此功能尚處於實驗階段,請持續關注後續版本中的更多更新、範例和教學。

分片感知的主機到設備資料載入

PyTorch/XLA SPMD 採用單設備程式,並行分片和執行。SPMD 執行需要使用原生 PyTorch DataLoader,它會將資料從主機同步傳輸到 XLA 設備。這會在每個步驟的輸入資料傳輸過程中阻塞訓練。為了提高原生資料載入效能,當傳遞可選的 kwarg _input_sharding_ 時,我們讓 PyTorch/XLA ParallelLoader 直接支援輸入分片(src)。

# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
         train_loader,  # wraps PyTorch DataLoader
         device,
      # optional input_sharding field
         input_sharding=xs.ShardingSpec(input_mesh, (0, 1, 2, 3)))

分散式檢查點

PyTorch/XLA SPMD 透過專用的 Planner 執行個體與 torch.distributed.checkpoint 程式庫相容。使用者能夠透過這個通用介面同步儲存和載入檢查點。

SPMDSavePlanner 和 SPMDLoadPlanner (src) 類別允許 saveload 函數直接對 XLAShardedTensor 的分片進行操作,從而實現 SPMD 訓練中分散式檢查點的所有優點。

以下是同步分散式檢查點 API 的演示

import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc

# Saving a state_dict
state_dict = {
    "model": model.state_dict(),
    "optim": optim.state_dict(),
}

dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    planner=xc.SPMDSavePlanner(),
)
...

# Loading the model's state_dict from the checkpoint. The model should
# already be on the XLA device and have the desired sharding applied.
state_dict = {
    "model": model.state_dict(),
}

dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    planner=xc.SPMDLoadPlanner(),
)
model.load_state_dict(state_dict["model"])

CheckpointManager

實驗性的 CheckpointManager 介面在 torch.distributed.checkpoint 函數之上提供了一個更高級別的 API,以啟用一些關鍵功能

  • **託管檢查點**:由 CheckpointManager 執行的每個檢查點都由執行該檢查點的步驟標識。所有追蹤的步驟都可以透過 CheckpointManager.all_steps 方法存取,並且可以使用 CheckpointManager.restore 還原任何追蹤的步驟。

  • **非同步檢查點**:透過 CheckpointManager.save_async API 執行的檢查點會非同步寫入永續性儲存體,以在檢查點期間解除對訓練的阻塞。在將檢查點分派給背景執行緒之前,首先會將輸入分片的 state_dict 移至 CPU。

  • **搶佔時的自動檢查點**:在 Cloud TPU 上,可以偵測到搶佔並在程序終止之前執行檢查點。要使用此功能,請確保您的 TPU 是透過啟用了 自動檢查點 的 QueuedResource 配置的,並確保在建構 CheckpointManager 時設定了 chkpt_on_preemption 參數(預設情況下啟用此選項)。

  • **FSSpec 支援**:CheckpointManager 使用 fsspec 儲存體後端直接啟用對任何與 fsspec 相容的檔案系統(包括 GCS)的檢查點。

以下是 CheckpointManager 的用法範例

from torch_xla.experimental.distributed_checkpoint import CheckpointManager

# Create a CheckpointManager to checkpoint every 10 steps into GCS.
chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10)

# Select a checkpoint to restore from, and restore if applicable
tracked_steps = chkpt_mgr.all_steps()
if tracked_steps:
    # Choose the highest step
    best_step = max(tracked_steps)
    state_dict = {'model': model.state_dict()}
    chkpt_mgr.restore(best_step, state_dict)
    model.load_state_dict(state_dict['model'])

# Call `save` or `save_async` every step within the train loop. These methods
# return True when a checkpoint is taken.
for step, data in enumerate(dataloader):
    ...
    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
    if chkpt_mgr.save_async(step, state_dict):
        print(f'Checkpoint taken at step {step}')

程序群組

要使用 torch.distributed API(例如分散式檢查點),需要一個程序群組。在 SPMD 模式下,不支援 xla 後端,因為編譯器負責所有集合。

相反,必須使用 CPU 程序群組,例如 gloo。在 TPU 上,仍然支援 xla:// init_method 來發現主 IP、全域世界大小和主機秩。以下是一個初始化範例

import torch.distributed as dist
# Import to register the `xla://` init_method
import torch_xla.distributed.xla_backend
import torch_xla.runtime as xr

xr.use_spmd()

# The `xla://` init_method will automatically discover master worker IP, rank,
# and global world size without requiring environment configuration on TPUs.
dist.init_process_group('gloo', init_method='xla://')

虛擬設備優化

PyTorch/XLA 通常在定義張量後,將張量資料從主機非同步傳輸到設備。這是為了將資料傳輸與圖形追蹤時間重疊。但是,由於 GSPMD 允許使用者在定義張量_之後_修改張量分片,因此我們需要一種優化措施來防止張量資料在主機和設備之間來回傳輸。我們引入了虛擬設備優化,這是一種在將張量資料上傳到物理設備之前,先將其放置在虛擬設備 SPMD:0 上的技術,所有分片決策都已完成。SPMD 模式下的每個張量資料都放置在虛擬設備 SPMD:0 上。虛擬設備作為 XLA 設備 XLA:0 向使用者公開,實際分片位於物理設備上,例如 TPU:0、TPU:1 等。

程序數量

與現有的 DDP 和 FSDP 不同,在 SPMD 模式下,每個加速器主機上始終只有一個程序在運行。這提供了一個好處,即 PyTorch/XLA 只需要編譯每個圖形一次,並且可以將其重複用於連接到此主機的所有加速器。

在 TPU Pod 上運行 SPMD

如果您根據設備數量而不是某些硬編碼常數來建構網格和分區規格,則從單個 TPU 主機轉到 TPU Pod 不需要更改程式碼。要在 TPU Pod 上運行 PyTorch/XLA 工作負載,請參閱 PJRT 指南的 Pod 部分

在 GPU 上運行 SPMD

PyTorch/XLA 在 NVIDIA GPU(單節點或多節點)上支援 SPMD。訓練/推論腳本與用於 TPU 的腳本相同,例如這個 ResNet 腳本。為了使用 SPMD 執行腳本,我們利用 torchrun

PJRT_DEVICE=CUDA \
torchrun \
--nnodes=${NUM_GPU_MACHINES} \
--node_rank=${RANK_OF_CURRENT_MACHINE} \
--nproc_per_node=1 \
--rdzv_endpoint="<MACHINE_0_IP_ADDRESS>:<PORT>" \
training_or_inference_script_using_spmd.py
  • --nnodes:要使用多少台 GPU 機器。

  • --node_rank:目前 GPU 機器的索引。值可以是 0、1、…、${NUMBER_GPU_VM}-1。

  • --nproc_per_node:由於 SPMD 的要求,該值必須為 1。

  • –rdzv_endpoint:節點等級為 0 的 GPU 機器的端點,格式為 主機:端口`。主機將是內部 IP 位址。端口` 可以是機器上任何可用的端口。對於單節點訓練/推論,可以省略此參數。

例如,如果您想使用 SPMD 在 2 台 GPU 機器上訓練 ResNet 模型,則可以在第一台機器上執行以下腳本

XLA_USE_SPMD=1 PJRT_DEVICE=CUDA \
torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=1 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" \
pytorch/xla/test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 128

並在第二台機器上執行以下操作

XLA_USE_SPMD=1 PJRT_DEVICE=CUDA \
torchrun \
--nnodes=2 \
--node_rank=1 \
--nproc_per_node=1 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" \
pytorch/xla/test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 128

有關更多資訊,請參閱 GPU 上的 SPMD 支援 RFC

參考範例

使用 SPMD 來表示資料平行

SPMD API 足夠通用,可以表示資料平行和模型平行。可以通過註釋輸入批次維度以進行分片來簡單地實現資料平行。這裡,我們已經在所有可用設備(N 路)上分片了批次維度:有兩種使用 SPMD 來表示資料平行或批次分片的方法

num_devices = xr.global_runtime_device_count()

# Assume data is 4d and 0th dimension is the batch dimension
mesh_shape = (num_devices, 1, 1, 1)
input_mesh = xs.Mesh(device_ids, mesh_shape, ('B', 'C', 'W', 'H'))
partition_spec = range(num_devices)

# Shard the batch dimension
xs.mark_sharding(input_tensor, input_mesh, partition_spec)

PyTorch/XLA 的 MpDeviceLoader 支援輸入批次分片,它還在後台將批次加載到設備

num_devices = xr.global_runtime_device_count()

# Assume data is 4d and 0th dimension is the batch dimension
mesh_shape = (num_devices, 1, 1, 1)
input_mesh = xs.Mesh(device_ids, mesh_shape, ('B', 'C', 'W', 'H'))
partition_spec = range(num_devices)

# Use MpDeviceLoader to load data in background
train_loader = pl.MpDeviceLoader(
     train_loader,
     device,
     input_sharding=xs.ShardingSpec(input_mesh, partition_spec))

我們強烈建議使用第二種方法,因為它應該會產生更好的訓練效能。

使用 SPMD 來表示 FSDP(完全分片資料平行)

PyTorch 的 FSDP 是資料平行 + 在第 0 維度分片的模型參數。使用者首先需要使用 SPMD 來表示資料平行,如上一節所述。

for name, param in model.named_parameters():
    shape = (num_devices,) + (1,) * (len(param.shape) - 1)
    mesh = xs.Mesh(device_ids, shape)
    xs.mark_sharding(param, mesh, range(len(param.shape)))

使用 SPMD 執行 ResNet50 範例

我們提供了一個 resnet50 的快速範例,其中包含一些不同的 SPMD 分片策略供您試用。您可以先在不使用 SPMD 的情況下使用以下命令運行它

python test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 512

並檢查吞吐量。之後,您可以使用以下命令啟用批次分片

XLA_USE_SPMD=1 python test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 2048 --model=resnet50 --sharding=batch

請注意,我使用的批次大小是之前的 4 倍,因為我是在 TPU v4 上運行它,它連接了 4 個 TPU 設備。您應該會看到吞吐量大約是非 spmd 運行的 4 倍。

SPMD 除錯工具

我們為在 TPU/GPU/CPU 上使用單主機/多主機的 PyTorch/XLA SPMD 使用者提供 分片 放置 視覺化 除錯 工具:您可以使用 visualize_tensor_sharding 來視覺化分片張量,或者您可以使用 visualize_sharding 來視覺化分片字串。以下是在 TPU 單主機(v4-8)上使用 visualize_tensor_shardingvisualize_sharding 的兩個程式碼範例

  • 使用 visualize_tensor_sharding 的程式碼片段和視覺化結果

import rich

# Here, mesh is a 2x2 mesh with axes 'x' and 'y'
t = torch.randn(8, 4, device='xla')
xs.mark_sharding(t, mesh, ('x', 'y'))

# A tensor's sharding can be visualized using the `visualize_tensor_sharding` method
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
generated_table = visualize_tensor_sharding(t, use_color=False)
alt_text
  • 使用 visualize_sharding 的程式碼片段和視覺化結果

from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,2]0,1,2,3}'
generated_table = visualize_sharding(sharding, use_color=False)
alt_text

您可以在 TPU/GPU/CPU 單主機上使用這些範例,並修改它以在多主機上運行。並且您可以將其修改為分片樣式 tiledpartial_replicationreplicated

自動分片

我們正在引入一個新的 PyTorch/XLA SPMD 功能,稱為 自動分片RFC。這是 r2.3nightly 中的一個實驗性功能,支援 XLA:TPU 和單個 TPUVM 主機。

可以通过以下方式之一启用 PyTorch/XLA 自动分片

  • 設定環境變數 XLA_SPMD_AUTO=1

  • 在程式碼的開頭呼叫 SPMD API

import torch_xla.runtime as xr
xr.use_spmd(auto=True)
  • 使用 auto-policyxla 呼叫 pytorch.distributed._tensor.distribute_module

import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy

device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# Currently, model should be loaded to xla device via distribute_module.
model = MyModule()  # nn.module
sharded_model = distribute_module(model, device_mesh, auto_policy)

或者,可以設定以下選項/環境變數來控制基於 XLA 的自動分片傳遞的行為

  • XLA_AUTO_USE_GROUP_SHARDING:參數的群組重新分片。預設情況下設定。

  • XLA_AUTO_SPMD_MESH:用於自動分片的邏輯網格形狀。例如,XLA_AUTO_SPMD_MESH=2,2 對應於具有 4 個全域設備的 2x2 網格。如果未設定,將使用 num_devices,1 的預設設備網格形狀。

通過 SPMD 完全分片資料平行

通過 SPMD 或 FSDPv2 完全分片資料平行是一個實用程式,它在 SPMD 中重新表達了著名的 FSDP 演算法。功能是一項實驗性功能,旨在為使用者提供一個熟悉的介面,讓他們能夠享受 SPMD 帶來的所有好處。設計文件此處

在繼續之前,請先查看 SPMD 使用者指南

使用範例

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2

# Define the mesh following common SPMD practice
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
# To be noted, the mesh must have an axis named 'fsdp', which the weights and activations will be sharded on.
mesh = Mesh(device_ids, mesh_shape, ('fsdp', 'model'))

# Shard the input, and assume x is a 2D tensor.
x = xs.mark_sharding(x, mesh, ('fsdp', None))

# As normal FSDP, but an extra mesh is needed.
model = FSDPv2(my_module, mesh)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()

也可以單獨對各個層進行分片,並使用外部包裝器處理任何剩餘的參數。自動包裝功能將在未來版本中提供。

分片輸出

為了確保 XLA 編譯器正確實現 FSDP 演算法,我們需要對權重和激活進行分片。這意味著要對正向方法的輸出進行分片。由於正向函數的輸出可能有所不同,因此我們提供 shard_output 來在模組輸出不屬於以下類別之一的情況下對激活進行分片

  1. 單個張量

  2. 張量元組,其中第 0 個元素是激活。

使用範例

def shard_output(output, mesh):
    xs.mark_sharding(output.logits, mesh, ('fsdp', None, None))

model = FSDPv2(my_module, mesh, shard_output)

梯度檢查點

目前,需要在 FSDP 包裝器之前將梯度檢查點應用於模組。否則,遞迴迴圈進入子模組將導致無限迴圈。我們將在未來版本中修復此問題。

使用範例

from torch_xla.distributed.fsdp import checkpoint_module

model = FSDPv2(checkpoint_module(my_module), mesh)

HuggingFace Llama 2 範例

我們有一個 HF Llama 2 的分支來演示潛在的整合此處

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

取得適用於初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並取得您的問題解答

查看資源