跳轉到主要內容
部落格

透過 torch.compile 將 NumPy 程式碼編譯成 C++ 或 CUDA

作者: 2023 年 10 月 17 日2024 年 11 月 14 日暫無評論

Quansight 工程師在 PyTorch 2.1 中實現了透過 torch.compile 跟蹤 NumPy 程式碼的支援。此功能利用 PyTorch 的編譯器生成高效的融合向量化程式碼,而無需修改原始 NumPy 程式碼。更重要的是,它還允許透過在 torch.device("cuda") 下執行 torch.compile 來在 CUDA 上執行 NumPy 程式碼!

在這篇文章中,我們將介紹如何使用此功能,並提供一些充分利用它的技巧。

將 NumPy 程式碼編譯成並行 C++

我們以 K-Means 演算法中的一步為例。這段程式碼借用了這本 NumPy 書籍

import numpy as np

def kmeans(X, means):
    return np.argmin(np.linalg.norm(X - means[:, None], axis=2), axis=0)

我們建立了一個包含 2000 萬個隨機 2D 點的合成數據集。我們可以看到,如果均值選擇得當,該函式會為所有點返回正確的簇。

npts = 10_000_000
X = np.repeat([[5, 5], [10, 10]], [npts, npts], axis=0)
X = X + np.random.randn(*X.shape)  # 2 distinct "blobs"
means = np.array([[5, 5], [10, 10]])
np_pred = kmeans(X, means)

在 AMD 3970X CPU 上對該函式進行基準測試,我們得到了 1.26 秒的基線。

現在編譯該函式就像用 torch.compile 封裝它並用示例輸入執行它一樣簡單。

import torch

compiled_fn = torch.compile(kmeans)
compiled_pred = compiled_fn(X, means)
assert np.allclose(np_pred, compiled_pred)

在 1 個核心上執行時,編譯後的函式帶來了 9 倍的加速。更好的是,與 NumPy 不同,我們生成的程式碼確實利用了處理器中的所有核心。因此,當我們在 32 個核心上執行時,我們獲得了 57 倍的加速。請注意,除非明確限制,否則 PyTorch 始終使用所有可用核心,因此這是使用 torch.compile 時的預設行為。

我們可以透過使用環境變數 TORCH_LOGS=output_code 執行指令碼來檢查生成的 C++ 程式碼。這樣做時,我們可以看到 torch.compile 能夠將廣播和兩次歸約編譯成一個 for 迴圈,並使用 OpenMP 對其進行並行化。

extern "C" void kernel(const double* in_ptr0, const long* in_ptr1, long* out_ptr0) {
    #pragma omp parallel num_threads(32)
    #pragma omp for
    for(long i0=0L; i0<20000000L; i0+=1L) {
        auto tmp0 = in_ptr0[2L*i0];
        auto tmp1 = in_ptr1[0L];
        auto tmp5 = in_ptr0[1L + (2L*i0)];
        auto tmp6 = in_ptr1[1L];
        // Rest of the kernel omitted for brevity

將 NumPy 程式碼編譯成 CUDA

將我們的程式碼編譯成在 CUDA 上執行,就像將預設裝置設定為 CUDA 一樣簡單。

with torch.device("cuda"):
    cuda_pred = compiled_fn(X, means)
assert np.allclose(np_pred, cuda_pred)

透過 TORCH_LOGS=output_code 檢查生成的程式碼,我們看到,torch.compile 沒有直接生成 CUDA 程式碼,而是生成了相當可讀的 Triton 程式碼。

def triton_(in_ptr0, in_ptr1, out_ptr0, XBLOCK : tl.constexpr):
    xnumel = 20000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (2*x0), xmask)
    tmp1 = tl.load(in_ptr1 + (0))
    // Rest of the kernel omitted for brevity

在 RTX 2060 上執行這個小片段,比原始 NumPy 程式碼實現了 8 倍的加速。這還不錯,但考慮到我們在 CPU 上看到的加速,這並不是特別令人印象深刻。讓我們看看如何透過一些小的更改來充分利用我們的 GPU。

float64 vs float32。許多 GPU,特別是消費級 GPU,在執行 float64 上的操作時相當遲鈍。因此,將資料生成更改為 float32 後,原始 NumPy 程式碼只快了一點,大約 9%,但我們的 CUDA 程式碼 快了 40%,比純 NumPy 程式碼帶來了 11 倍的加速

torch.compile 預設遵循 NumPy 語義,因此它將 np.float64 用作其所有建立操作的預設 dtype。正如所討論的,這會影響效能,因此可以透過設定以下內容來更改此預設值:

from torch._dynamo import config
config.numpy_default_float = "float32"

CPU <-> CUDA 複製。11 倍的加速很好,但它甚至不接近 CPU 的數字。這是由 torch.compile 在幕後進行的小轉換引起的。上面的程式碼接受 NumPy 陣列並返回 NumPy 陣列。所有這些陣列都在 CPU 上,但計算在 GPU 上執行。這意味著每次呼叫函式時,torch.compile 都必須將所有這些陣列從 CPU 複製到 GPU,然後將結果複製回 CPU 以保留原始語義。NumPy 中沒有針對此問題的原生解決方案,因為 NumPy 沒有 device 的概念。儘管如此,我們可以透過為該函式建立一個包裝器來解決它,以便它接受 PyTorch 張量並返回 PyTorch 張量。

@torch.compile
def tensor_fn(X, means):
    X, means = X.numpy(), means.numpy()
    ret = kmeans(X, means)
    return torch.from_numpy(ret)

def cuda_fn(X, means):
    with torch.device("cuda"):
        return tensor_fn(X, means)

這個函式現在接受 CUDA 記憶體中的張量並返回 CUDA 記憶體中的張量,但函式本身是用 NumPy 編寫的!torch.compilenumpy()from_numpy() 呼叫用作提示,並將其最佳化掉,在內部它只是使用 PyTorch 張量,完全不移動記憶體。當我們將張量保留在 CUDA 中並以 float32 執行計算時,我們看到比初始 NumPy 實現(在 float32 陣列上)200 倍的加速

混合 NumPy 和 PyTorch。在這個例子中,我們必須編寫一個小的介面卡來將張量轉換為 ndarrays,然後再轉換回張量。在混合 PyTorch 和 NumPy 的程式中,將張量轉換為 ndarray 通常實現為 x.detach().cpu().numpy(),或者簡單地 x.numpy(force=True)。由於在 torch.compile 下執行時我們可以在 CUDA 中執行 NumPy 程式碼,我們可以將這種轉換模式實現為對 x.numpy() 的呼叫,就像我們上面所做的那樣。這樣做並在 device("cuda") 下執行生成的程式碼將從原始 NumPy 呼叫生成高效的 CUDA 程式碼,而無需將資料從 CUDA 複製到 CPU。請注意,生成的程式碼在沒有 torch.compile 的情況下無法執行。為了在 eager 模式下執行,需要回滾到 x.numpy(force=True)

進一步的加速技巧

一般建議。我們展示的 CUDA 程式碼已經相當高效,但執行示例確實相當短。在處理更大的程式時,我們可能需要調整其部分以使其更高效。一個好的起點是多個 torch.compile 的教程和常見問題解答。這展示了檢查跟蹤過程的多種方法,以及如何識別可能導致速度下降的問題程式碼。

編譯 NumPy 程式碼時的建議。NumPy 儘管與 PyTorch 相當相似,但通常使用方式非常不同。在 NumPy 中執行計算,然後根據陣列中的值執行 if/else,或透過布林掩碼就地執行操作是很常見的。這些構造雖然受 torch.compile 支援,但會阻礙其效能。像以無分支方式編寫程式碼以避免圖中斷,或避免就地操作等更改可以起到很大的作用。

要編寫快速的 NumPy 程式碼,最好避免迴圈,但有時它們是不可避免的。在跟蹤迴圈時,torch.compile 會嘗試完全展開它。這有時是可取的,但有時甚至不可能,例如當我們有一個動態停止條件時,就像在 while 迴圈中一樣。在這些情況下,最好只編譯迴圈的主體,也許一次迭代幾次(迴圈展開)。

除錯 NumPy 程式碼。當涉及編譯器時,除錯相當棘手。為了確定您遇到的錯誤是 torch.compile 錯誤,還是程式錯誤,您可以透過將 NumPy 匯入替換為 import torch._numpy as np 來在不使用 torch.compile 的情況下執行 NumPy 程式。這應該只用於 除錯目的,絕不能替代 PyTorch API,因為它 慢得多,並且作為私有 API,可能會在不通知的情況下更改。另請參閱 此常見問題解答 以獲取其他技巧。

NumPy 與 torch.compile NumPy 之間的差異

NumPy 標量。在幾乎所有 PyTorch 會返回 0-D 張量(例如從 np.sum)的情況下,NumPy 都會返回 NumPy 標量。在 torch.compile 下,NumPy 標量被視為 0-D 陣列。這在大多數情況下都很好。它們行為不同的唯一情況是當 NumPy 標量隱式用作 Python 標量時。例如,

>>> np.asarray(2) * [1, 2, 3]  # 0-D array is an array-like
array([2, 4, 6])
>>> u = np.int32(2)
>>> u * [1, 2, 3]              # scalar decays into a Python int
[1, 2, 3, 1, 2, 3]
>>> torch.compile(lambda: u * [1, 2, 3])()
array([2, 4, 6])               # acts as a 0-D array, not as a scalar ?!?!

如果我們編譯前兩行,我們看到 torch.compileu 視為 0-D 陣列。要恢復 eager 語義,我們只需明確進行型別轉換。

>>> torch.compile(lambda: int(u) * [1, 2, 3])()
[1, 2, 3, 1, 2, 3]

型別提升和版本控制。NumPy 的型別提升規則有時可能有點令人驚訝。

>>> np.zeros(1, dtype=np.int8) + 127
array([127], dtype=int8)
>>> np.zeros(1, dtype=np.int8) + 128
array([128], dtype=int16)

NumPy 2.0 正在更改這些規則,以遵循更接近 PyTorch 的其他規則。相關的技術文件是 NEP 50torch.compile 已經實現了 NEP 50,而不是即將棄用的規則。

通常,torch.compile 中的 NumPy 遵循 NumPy 2.0 預釋出版。

超越 NumPy:SciPy 和 scikit-learn

在使 torch.compile 理解 NumPy 程式碼的努力的同時,其他 Quansight 工程師設計並提出了一種支援 scikit-learn 和 SciPy 中 PyTorch 張量的方法。這受到了這些庫的其他維護者的熱烈歡迎,因為事實證明,使用 PyTorch 作為後端通常會帶來顯著的加速。這兩個專案現在都已在多個 API 和子模組中合併了對 PyTorch 張量的初始支援。

這為邁向未來奠定了基礎,未來 PyTorch 張量可以在 Python 資料生態系統中的其他庫中使用。更重要的是,這將使這些其他庫能夠在 GPU 上執行,甚至可以編譯混合這些庫和 PyTorch 的程式碼,類似於我們在這篇文章中討論的內容。

如果您想了解更多關於這項工作、如何使用它或如何幫助推動它,請參閱 這篇博文

總結

PyTorch 自成立以來就致力於成為一個與 Python 生態系統其餘部分相容的框架。啟用編譯 NumPy 程式,並建立必要的工具來對其他知名庫做同樣的事情,是朝著這個方向邁出的另外兩步。Quansight 和 Meta 繼續攜手合作,提高 PyTorch 與生態系統其餘部分之間的相容性。

來自 Quansight,我們要感謝 Mengwei、Voz 和 Ed 在將我們的工作與 torch.compile 整合方面提供的寶貴幫助。我們還要感謝 Meta 為該專案以及之前為提高 PyTorch 中 NumPy 相容性和導致支援 scikit-learn 和 SciPy 中 PyTorch 的專案提供資金。這些是鞏固 PyTorch 作為開源 Python 資料生態系統首選框架的巨大飛躍。