快捷方式

DataParallel

class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)[原始碼][原始碼]

在模組級別實現資料並行。

此容器透過沿批處理維度(其他物件將在每個裝置上覆制一次)對輸入進行分塊,將給定的 module 的應用並行化到指定的裝置上。在前向傳播中,模組在每個裝置上覆制,每個副本處理一部分輸入。在後向傳播中,每個副本的梯度會求和到原始模組中。

批處理大小應大於所使用的 GPU 數量。

警告

建議使用 DistributedDataParallel 類進行多 GPU 訓練,即使只有一個節點。參見:使用 nn.parallel.DistributedDataParallel 而非多程序或 nn.DataParallel 以及 分散式資料並行

允許將任意位置和關鍵字輸入傳遞給 DataParallel,但某些型別會特殊處理。張量會沿指定的維度(預設為 0)進行 分散。tuple、list 和 dict 型別會進行淺複製。其他型別將在不同的執行緒之間共享,如果在模型的正向傳播中寫入,則可能損壞。

並行化的 module 在執行此 DataParallel 模組之前,必須將其引數和緩衝區放置在 device_ids[0] 上。

警告

在前向傳播時, module 在每個裝置上被 複製,因此在 forward 中對正在執行的模組進行的任何更新都將丟失。例如,如果 module 有一個在每次 forward 中遞增的計數器屬性,它將始終保持初始值,因為更新是在副本上進行的,而副本在 forward 後即被銷燬。然而,DataParallel 保證 device[0] 上的副本將其引數和緩衝區與基礎並行化 module 共享儲存。因此,對 device[0] 上的引數或緩衝區的 就地 更新將被記錄下來。例如, BatchNorm2dspectral_norm() 依賴此行為來更新緩衝區。

警告

module 及其子模組上定義的前向和後向鉤子將被呼叫 len(device_ids) 次,每次輸入位於特定的裝置上。特別是,鉤子只保證按照與相應裝置上的操作相關的正確順序執行。例如,不保證透過 register_forward_pre_hook() 設定的鉤子在所有 len(device_ids)forward() 呼叫之前執行,但保證每個此類鉤子在該裝置相應的 forward() 呼叫之前執行。

警告

moduleforward() 中返回一個標量(即 0 維張量)時,此封裝器將返回一個長度等於資料並行所用裝置數量的向量,其中包含來自每個裝置的結果。

注意

在使用 Module 封裝在 DataParallel 中時,使用 打包序列 -> 迴圈網路 -> 解包序列 模式時存在一些細微之處。詳情請參閱 FAQ 中的 我的迴圈網路無法與資料並行一起使用 部分。

引數
  • module (Module) – 要並行化的模組

  • device_ids (list of int or torch.device) – CUDA 裝置 (預設: 所有裝置)

  • output_device (int or torch.device) – 輸出裝置的位置 (預設: device_ids[0])

變數

module (Module) – 要並行化的模組

示例

>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)  # input_var can be on any device, including CPU

文件

訪問 PyTorch 全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源