DataParallel¶
- class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)[原始碼][原始碼]¶
在模組級別實現資料並行。
此容器透過沿批處理維度(其他物件將在每個裝置上覆制一次)對輸入進行分塊,將給定的
module的應用並行化到指定的裝置上。在前向傳播中,模組在每個裝置上覆制,每個副本處理一部分輸入。在後向傳播中,每個副本的梯度會求和到原始模組中。批處理大小應大於所使用的 GPU 數量。
警告
建議使用
DistributedDataParallel類進行多 GPU 訓練,即使只有一個節點。參見:使用 nn.parallel.DistributedDataParallel 而非多程序或 nn.DataParallel 以及 分散式資料並行。允許將任意位置和關鍵字輸入傳遞給 DataParallel,但某些型別會特殊處理。張量會沿指定的維度(預設為 0)進行 分散。tuple、list 和 dict 型別會進行淺複製。其他型別將在不同的執行緒之間共享,如果在模型的正向傳播中寫入,則可能損壞。
並行化的
module在執行此DataParallel模組之前,必須將其引數和緩衝區放置在device_ids[0]上。警告
在前向傳播時,
module在每個裝置上被 複製,因此在forward中對正在執行的模組進行的任何更新都將丟失。例如,如果module有一個在每次forward中遞增的計數器屬性,它將始終保持初始值,因為更新是在副本上進行的,而副本在forward後即被銷燬。然而,DataParallel保證device[0]上的副本將其引數和緩衝區與基礎並行化module共享儲存。因此,對device[0]上的引數或緩衝區的 就地 更新將被記錄下來。例如,BatchNorm2d和spectral_norm()依賴此行為來更新緩衝區。警告
在
module及其子模組上定義的前向和後向鉤子將被呼叫len(device_ids)次,每次輸入位於特定的裝置上。特別是,鉤子只保證按照與相應裝置上的操作相關的正確順序執行。例如,不保證透過register_forward_pre_hook()設定的鉤子在所有len(device_ids)次forward()呼叫之前執行,但保證每個此類鉤子在該裝置相應的forward()呼叫之前執行。警告
當
module在forward()中返回一個標量(即 0 維張量)時,此封裝器將返回一個長度等於資料並行所用裝置數量的向量,其中包含來自每個裝置的結果。注意
在使用
Module封裝在DataParallel中時,使用打包序列 -> 迴圈網路 -> 解包序列模式時存在一些細微之處。詳情請參閱 FAQ 中的 我的迴圈網路無法與資料並行一起使用 部分。- 引數
module (Module) – 要並行化的模組
device_ids (list of int or torch.device) – CUDA 裝置 (預設: 所有裝置)
output_device (int or torch.device) – 輸出裝置的位置 (預設: device_ids[0])
- 變數
module (Module) – 要並行化的模組
示例
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) >>> output = net(input_var) # input_var can be on any device, including CPU