跳轉到主要內容
公告

PyTorch 2.2 中的新庫更新

作者: 2024 年 1 月 30 日2025 年 4 月 30 日暫無評論

總結

在 PyTorch 2.2 釋出的同時,我們為當前的 PyTorch 庫帶來了一系列改進。這些更新表明我們致力於在所有領域開發通用且可擴充套件的 API,以便我們的社群更容易在 PyTorch 上構建生態系統專案。

最新的穩定庫版本 (完整列表)*
TorchArrow 0.1.0 TorchRec 0.6.0 TorchVision 0.17
TorchAudio 2.2.0 TorchServe 0.9.0 TorchX 0.7.0
TorchData 0.7.1 TorchText 0.17.0 PyTorch on XLA Devices 2.1

*要檢視 以前的版本 或(不穩定)的夜間版本,請點選“搜尋文件”上方左上角選單中的版本。

TorchRL

功能:TorchRL 的離線強化學習資料中心

TorchRL 現在提供了一個最大的離線強化學習和模仿學習資料集中心,所有資料都採用單一資料格式(TED,即 TorchRL Episode Data 格式)。這使得在單個訓練迴圈中輕鬆切換不同來源成為可能。現在還可以透過 ReplayBufferEnsemble 類輕鬆組合來自不同來源的資料集。資料處理完全可定製。來源包括模擬任務(Minari、D4RL、VD4RL)、機器人資料集(Roboset、OpenX Embodied 資料集)和遊戲(GenDGRL/ProcGen、Atari/DQN)。請在文件中檢視。

除了這些更改之外,我們的重放緩衝區現在可以使用 .dumps() 方法轉儲到磁碟上,該方法將使用 TensorDict API 將緩衝區序列化到磁碟上,該 API 比使用 torch.save 更快、更安全、更高效。

最後,重放緩衝區現在可以在同一臺機器上的不同程序中讀取和寫入,而無需使用者編寫任何額外的程式碼!

TorchRL2Gym 環境 API

為了促進 TorchRL 在現有程式碼庫中的整合並享受 TorchRL 環境 API 的所有功能(裝置上的執行、批處理操作、變換等),我們提供了一個 TorchRL 到 gym 的 API,允許使用者在 gym 或 gymnasium 中註冊他們想要的任何環境。這反過來可以使 TorchRL 成為一個通用的庫到 gym 轉換器,適用於無狀態(例如,dm_control)和有狀態(Brax、Jumanji)環境。該功能在文件中詳細說明。info_dict 讀取 API 也得到了改進。

環境加速

我們增加了在 ParallelEnv 中在與用於提供資料的環境不同的環境上執行環境的選項。我們還加快了 GymLikeEnv 類的速度,使其現在可以與 gym 本身競爭。

縮放目標

最流行的 RLHF 和大規模訓練目標(PPO 和 A2C)現在與 FSDP 和 DDP 模型相容!

TensorDict

功能:MemoryMappedTensor 取代 MemmapTensor

我們為 TensorDict 提供了一個更高效的 mmap 後端;MemoryMappedTensor,它直接繼承自 torch.Tensor。它附帶了一系列用於構建的實用程式,例如 from_tensorempty 等。MemoryMappedTensor 現在比其對應物更安全、更快。該庫與以前的類完全相容,以方便過渡。

我們還引入了一組新的多執行緒序列化方法,使 tensordict 序列化與 torch.save 極具競爭力,LLM 的序列化和反序列化速度比使用 torch.save 快 3 倍以上

功能:TensorDict 中的非張量資料

現在可以透過 NonTensorData 張量類攜帶非張量資料。這使得可以構建帶有元資料的 tensordict。memmap-API 與這些值完全相容,允許使用者無縫序列化和反序列化此類物件。要在 tensordict 中儲存非張量資料,只需使用 __setitem__ 方法進行賦值即可。

效率提升

一些方法的執行時已得到改進,例如 unbind、split、map 甚至 TensorDict 例項化。請檢視我們的基準測試

TorchRec/fbgemm_gpu

VBE

TorchRec 現在在 EmbeddingBagCollection 模組中原生支援 VBE(可變批次嵌入)。這允許每個特徵的可變批次大小,從而解鎖稀疏輸入資料去重,這可以大大加快嵌入查詢和全對全時間。要啟用,只需使用 stride_per_key_per_rankinverse_indices 欄位初始化 KeyedJaggedTensor,它們分別指定每個特徵的批次大小和用於重新索引嵌入輸出的逆索引。

除了 TorchRec 庫的更改之外,fbgemm_gpu 已在 TBE 中添加了對每個特徵可變批次大小的支援。VBE 在加權和未加權情況下都啟用了拆分 TBE 訓練。要使用 VBE,請務必使用最新的 fbgemm_gpu 版本。

嵌入解除安裝

此技術是指使用 CUDA UVM 快取“熱”嵌入(即在主機記憶體中儲存嵌入表並在 HBM 記憶體中進行快取),並預取快取。嵌入解除安裝允許使用更少的 GPU 執行更大的模型,同時保持具有競爭力的效能。使用預取管道(PrefetchTrainPipelineSparseDist)並在規劃器中透過約束條件傳入每個表的快取載入因子預取管道標誌以使用此功能。

Fbgemm_gpu 在 v0.5.0 中引入了 UVM 快取管道預取,以提高 TBE 效能。這允許快取插入與 TBE 正向/反向並行執行。要啟用此功能,請務必使用最新的 fbgemm_gpu 版本。

Trec.shard/shard_modules

這些 API 將嵌入子模組替換為其分片變體。shard API 適用於單個嵌入模組,而 shard_modules API 替換所有嵌入模組,並且不會觸及其他非嵌入子模組。

嵌入分片遵循與之前的 TorchRec DistributedModuleParallel 行為相似的行為,不同之處在於 ShardedModules 已變為可組合的,這意味著這些模組由 TableBatchedEmbeddingSlices 支援,這些是底層 TBE(包括 .grad)的檢視。這意味著現在使用 named_parameters() 返回融合引數,包括在 DistributedModuleParallel 中。

TorchVision

V2 轉換現在穩定了!

torchvision.transforms.v2 名稱空間直到現在仍處於 BETA 階段。它現在穩定了!無論您是 Torchvision 轉換的新手,還是已經有經驗,我們都鼓勵您從V2 轉換入門開始,以瞭解新 V2 轉換可以做什麼。

瀏覽我們的主文件以獲取一般資訊和效能提示。可用的轉換和函式在API 參考中列出。您還可以在我們的示例庫中找到更多資訊和教程,例如V2 轉換:端到端目標檢測/分割示例如何編寫您自己的 V2 轉換

邁向 torch.compile() 支援

我們正在逐步將 torch.compile() 支援新增到 torchvision 介面,減少圖中斷並允許動態形狀。

torchvision ops(nms[ps_]roi_align[ps_]roi_pooldeform_conv_2d)現在與 torch.compile 和動態形狀相容。

在轉換方面,大部分低階核心(如 resize_image()crop_image())應該能夠正確編譯,沒有圖中斷和動態形狀。我們仍在解決其餘的邊緣情況,朝著完整的函式支援和類邁進,您應該會在下一個版本中看到這方面的更多進展。