跳轉到主要內容
部落格

透過 PyTorch/XLA 在 Cloud TPU 上了解 LazyTensor 系統性能

作者: 2022年3月2日2024年11月15日暫無評論

簡介

易用性、表達性和可除錯性是 PyTorch 的核心原則。易用性的關鍵驅動因素之一是 PyTorch 預設的“即時(eager)”執行方式,即逐操作執行保留了程式的命令式特性。然而,即時執行不提供基於編譯器的最佳化,例如當計算可以表示為圖時進行的最佳化。

LazyTensor [1] 最早隨 PyTorch/XLA 引入,它有助於結合這些看似不同的方法。雖然 PyTorch 的即時執行被廣泛使用、直觀且易於理解,但惰性執行尚未普及。

在這篇文章中,我們將探討 LazyTensor 系統的一些基本概念,目標是將這些概念應用於理解和除錯 PyTorch 中基於 LazyTensor 的實現的效能。儘管我們將使用 PyTorch/XLA 在 Cloud TPU 上作為探索這些概念的載體,但我們希望這些想法對於理解其他基於 LazyTensor 構建的系統也會有所幫助。

LazyTensor

對 PyTorch 張量執行的任何操作預設都會作為核心或核心組合排程到底層硬體。這些核心在底層硬體上非同步執行。程式執行不會被阻塞,直到張量的值被獲取。這種方法與大規模並行程式設計硬體(如 GPU)配合得非常好。

LazyTensor 系統的起點是一個自定義張量型別。在 PyTorch/XLA 中,這種型別被稱為 XLA 張量。與 PyTorch 的原生張量型別不同,在 XLA 張量上執行的操作會被記錄到一個 IR 圖中。讓我們看一個計算兩個張量乘積之和的例子:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

dev = xm.xla_device()

x1 = torch.rand((3, 3)).to(dev)
x2 = torch.rand((3, 8)).to(dev)

y1 = torch.einsum('bs,st->bt', x1, x2)
print(torch_xla._XLAC._get_xla_tensors_text([y1]))

您可以執行  colab 筆記本,檢視 y1 的結果圖。請注意,目前尚未執行任何計算。

y1 = y1 + x2
print(torch_xla._XLAC._get_xla_tensors_text([y1]))

這些操作將繼續執行,直到 PyTorch/XLA 遇到一個屏障。這個屏障可以是 mark_step() API 呼叫,也可以是任何其他強制執行迄今為止記錄的圖的事件。

xm.mark_step()
print(torch_xla._XLAC._get_xla_tensors_text([y1]))

一旦呼叫 mark_step(),圖就會被編譯,然後在 TPU 上執行,即張量已經被例項化。因此,現在圖被簡化為一行 y1 張量,其中包含計算結果。

一次編譯,多次執行

XLA 編譯過程提供最佳化(例如操作融合,透過為多個操作使用暫存記憶體來減少 HBM 壓力,參考),並利用底層的 XLA 基礎設施來最佳化使用底層硬體。然而,有一個注意事項,編譯過程開銷很大,即可能會增加訓練步長。因此,這種方法只有在我們能夠 一次編譯,多次執行 (編譯快取有助於確保相同的圖不會被多次編譯)時才能很好地擴充套件。

在以下示例中,我們建立一個小的計算圖並計時執行:

y1 = torch.rand((3, 8)).to(dev)
def dummy_step() :
  y1 = torch.einsum('bs,st->bt', y1, x)
  xm.mark_step()
  return y1
%timeit dummy_step
The slowest run took 29.74 times longer than the fastest. This could mean that an intermediate result is being cached.
10000000 loops, best of 5: 34.2 ns per loop

您會注意到最慢的步驟比最快的步驟要長得多。這是因為圖編譯開銷,對於給定形狀的圖、輸入形狀和輸出形狀,這種開銷只發生一次。後續步驟更快,因為無需圖編譯。

這也意味著,當“一次編譯,多次執行”的假設被打破時,我們預期會看到效能陡降。理解何時打破這個假設是理解和最佳化 LazyTensor 系統性能的關鍵。讓我們來看看是什麼觸發了編譯。

圖編譯與執行和 LazyTensor 屏障

我們看到,當遇到 LazyTensor 屏障時,計算圖會被編譯和執行。有三種情況會自動或手動引入 LazyTensor 屏障。第一種是顯式呼叫 mark_step() API,如前一個示例所示。當您使用 MpDeviceLoader 包裝您的資料載入器時(強烈建議將計算和資料上傳到 TPU 裝置重疊),mark_step() 也會在每一步隱式呼叫。xla_model 的 Optimizer step 方法也允許隱式呼叫 mark_step(當您設定 barrier=True 時)。

引入屏障的第二種情況是當 PyTorch/XLA 發現一個沒有等效 XLA HLO 操作對映(降級)的操作時。PyTorch 有 2000+ 個操作。儘管這些操作中的大多數是複合的(即可以用其他基本操作來表達),但其中一些操作在 XLA 中沒有相應的降級。

當使用了沒有 XLA 降級(lowering)的操作時會發生什麼?PyTorch XLA 會停止操作記錄,並切斷導致未降級操作輸入(inputs)的圖。然後,這個被切斷的圖會被編譯並排程執行。執行結果(具體化的張量)會從裝置傳送回主機,然後未降級的操作會在主機(CPU)上執行,接著下游的 LazyTensor 操作會建立新的圖,直到再次遇到屏障。

導致 LazyTensor 屏障的第三種也是最後一種情況是存在控制結構/語句或需要張量值的其他方法。這種語句至少會導致導致該張量的計算圖的執行(如果該圖已被看到),或者導致兩者的編譯和執行。

這類方法的其他例子包括 .item()、isEqual()。通常,任何將 Tensor 對映到 Scalar 的操作都會導致這種行為。

動態圖

如前一節所示,如果相同的圖形狀被多次執行,圖編譯成本會得到分攤。這是因為編譯後的圖會根據圖形狀、輸入形狀和輸出形狀派生的雜湊值進行快取。如果這些形狀發生變化,將觸發重新編譯,而過於頻繁的編譯會導致訓練時間下降。

我們來看下面的例子:

def dummy_step(x, y, loss, acc=False):
  z = torch.einsum('bs,st->bt', y, x)
  step_loss = z.sum().view(1,)
  if acc:
    loss = torch.cat((loss, step_loss))
  else:
    loss = step_loss
  xm.mark_step()
  return loss


import time
def measure_time(acc=False):
  exec_times = []
  iter_count = 100
  x = torch.rand((512, 8)).to(dev)
  y = torch.rand((512, 512)).to(dev)
  loss = torch.zeros(1).to(dev)
  for i in range(iter_count):
    tic = time.time()
    loss = dummy_step(x, y, loss, acc=acc)
    toc = time.time()
    exec_times.append(toc - tic)
  return exec_times

dyn = measure_time(acc=True) # acc= True Results in dynamic graph
st = measure_time(acc=False) # Static graph, computation shape, inputs and output shapes don't change

import matplotlib.pyplot as plt
plt.plot(st, label = 'static graph')
plt.plot(dyn, label = 'dynamic graph')
plt.legend()
plt.title('Execution time in seconds')

請注意,靜態和動態案例具有相同的計算,但動態圖每次都會編譯,導致整體執行時間更長。在實踐中,重新編譯的訓練步驟有時會慢一個數量級甚至更多。在下一節中,我們將討論 PyTorch/XLA 的一些工具來除錯訓練效能下降的問題。

使用 PyTorch/XLA 分析訓練效能

PyTorch/XLA 效能分析包含兩個主要元件。首先是客戶端效能分析。只需將環境變數 PT_XLA_DEBUG 設定為 1 即可啟用此功能。客戶端效能分析會指向您原始碼中未降級(unlowered)的操作或裝置到主機的資料傳輸。客戶端效能分析還會報告訓練期間是否發生過於頻繁的編譯。您可以結合效能分析器在  筆記本中探索 PyTorch/XLA 提供的一些指標和計數器。

PyTorch/XLA 分析器提供的第二個元件是內聯跟蹤註釋。例如:

import torch_xla.debug.profiler as xp

def train_imagenet():
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  ....
  server = xp.start_server(3294)
  def train_loop_fn(loader, epoch):
    ....
    model.train()
    for step, (data, target) in enumerate(loader):
      with xp.StepTrace('Train_Step', step_num=step):
        ....
        if FLAGS.amp:
        ....
        else:
          with xp.Trace('build_graph'):
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
          xm.optimizer_step(optimizer)

注意 start_server API 呼叫。您在此處使用的埠號將與 TensorBoard 分析器使用的埠號相同,以便檢視類似於以下內容的 op 跟蹤:

操作跟蹤與客戶端除錯功能是除錯和最佳化 PyTorch/XLA 訓練效能的強大工具集。有關分析器使用的更詳細說明,建議讀者探索 PyTorch/XLA 效能除錯系列部落格的 第1部分、 第2部分和 第3部分

總結

本文回顧了 LazyTensor 系統的基本原理。我們以 PyTorch/XLA 為基礎,深入理解了訓練效能下降的潛在原因。我們討論了為什麼“一次編譯,多次執行”有助於在 LazyTensor 系統上獲得最佳效能,以及當此假設被打破時訓練速度會變慢的原因。

我們希望 PyTorch 使用者會發現這些見解對他們使用 LazyTensor 系統進行的新穎工作有所幫助。

致謝

衷心感謝我的傑出同事 Jack Cao、Milad Mohammedi、Karl Weinmeister、Rajesh Thallam、Jordan Tottan(Google)和 Geeta Chauhan(Meta)的嚴謹評審和反饋。同時感謝來自 Google、Meta 和開源社群的 PyTorch/XLA 擴充套件開發團隊,使得 PyTorch 在 TPU 上成為可能。最後,感謝 LazyTensor 論文的作者,他們不僅開發了 LazyTensor,還撰寫了如此易懂的論文。

參考文獻

[1] LazyTensor: 將即時執行與領域特定編譯器結合