多程序最佳實務¶
torch.multiprocessing 是 Python multiprocessing 模組的替代品。它支援完全相同的操作,但擴展了它,因此所有透過 multiprocessing.Queue 傳送的張量,其資料都會被移到共享記憶體中,並且只會將控點傳送到另一個程序。
注意
當 Tensor 被傳送到另一個程序時,Tensor 資料會被共享。如果 torch.Tensor.grad 不是 None,它也會被共享。在沒有 torch.Tensor.grad 欄位的 Tensor 被傳送到另一個程序之後,它會建立一個標準的程序特定 .grad Tensor,這個 Tensor 不會像 Tensor 的資料那樣自動在所有程序之間共享。
這允許實作各種訓練方法,例如 Hogwild、A3C 或任何其他需要非同步操作的方法。
多程序中的 CUDA¶
CUDA 執行階段不支援 fork 啟動方法;需要使用 spawn 或 forkserver 啟動方法才能在子程序中使用 CUDA。
注意
啟動方法可以透過使用 multiprocessing.get_context(...) 建立上下文或直接使用 multiprocessing.set_start_method(...) 來設定。
與 CPU 張量不同,傳送程序需要保留原始張量,只要接收程序保留了該張量的副本。它是在底層實作的,但需要使用者遵循最佳實務才能使程式正確執行。例如,傳送程序必須保持活動狀態,只要消費者程序具有對該張量的引用,並且如果消費者程序透過致命訊號異常退出,則引用計數無法儲存您。請參閱 本節。
另請參閱:使用 nn.parallel.DistributedDataParallel 而不是多程序或 nn.DataParallel
最佳實務和技巧¶
避免和解決死結¶
在產生新程序時,可能會出現很多問題,其中最常見的死結原因是背景執行緒。如果有任何執行緒持有鎖定或匯入模組,並且呼叫了 fork,則子程序很可能會處於損壞狀態,並會死結或以其他方式失敗。請注意,即使您沒有這樣做,Python 內建函式庫也會這樣做 - 無需多看,只要看看 multiprocessing。 multiprocessing.Queue 實際上是一個非常複雜的類別,它會產生多個用於序列化、傳送和接收物件的執行緒,它們也可能導致上述問題。如果您發現自己處於這種情況,請嘗試使用 SimpleQueue,它不使用任何額外的執行緒。
我們正在盡力讓您輕鬆使用並確保這些死結不會發生,但有些事情我們無法控制。如果您有任何無法解決的問題,請嘗試在論壇上尋求幫助,我們會看看這是否是我們可以解決的問題。
重複使用透過佇列傳遞的緩衝區¶
請記住,每次將 Tensor 放入 multiprocessing.Queue 時,都必須將其移至共享記憶體。如果它已經是共享的,則此操作無效,否則會產生額外的記憶體複製,從而降低整個過程的速度。即使您有一個處理序池將數據發送到單個處理序,也要讓它將緩衝區發送回來 - 這幾乎是免費的,並且可以讓您在發送下一批數據時避免複製。
非同步多程序訓練(例如 Hogwild)¶
使用 torch.multiprocessing,可以非同步地訓練模型,參數可以一直共享,也可以定期同步。在前一種情況下,我們建議發送整個模型物件,而在後一種情況下,我們建議只發送 state_dict()。
我們建議使用 multiprocessing.Queue 在處理序之間傳遞各種 PyTorch 物件。例如,在使用 fork 啟動方法時,可以繼承已經在共享記憶體中的張量和儲存,但是這很容易出錯,應該謹慎使用,並且只有進階用戶才應該使用。佇列雖然有時不是一種優雅的解決方案,但在所有情況下都能正常工作。
警告
您應該小心使用沒有用 if __name__ == '__main__' 保護的全域語句。如果使用 fork 以外的其他啟動方法,它們將在所有子處理序中執行。
Hogwild¶
您可以在 範例儲存庫 中找到具體的 Hogwild 實作,但為了展示程式碼的整體結構,下方還有一個最小範例
import torch.multiprocessing as mp
from model import MyModel
def train(model):
    # Construct data_loader, optimizer, etc.
    for data, labels in data_loader:
        optimizer.zero_grad()
        loss_fn(model(data), labels).backward()
        optimizer.step()  # This will update the shared parameters
if __name__ == '__main__':
    num_processes = 4
    model = MyModel()
    # NOTE: this is required for the ``fork`` method to work
    model.share_memory()
    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=train, args=(model,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
多程序中的 CPU¶
不當的多程序處理可能會導致 CPU 超載,導致不同的處理序競爭 CPU 資源,從而導致效率低下。
本教學將說明什麼是 CPU 超載以及如何避免它。
CPU 超載¶
CPU 超載是一個技術術語,指的是分配給系統的虛擬 CPU 總數超過硬體上可用的虛擬 CPU 總數的情況。
這會導致對 CPU 資源的嚴重爭用。在這種情況下,處理序之間會頻繁切換,這會增加處理序切換的開銷,並降低整體系統效率。
請參閱 範例儲存庫 中 Hogwild 實作中的程式碼範例,了解 CPU 超載的情況。
當使用以下命令在 CPU 上使用 4 個處理序運行訓練範例時
python main.py --num-processes 4
假設機器上有 N 個虛擬 CPU 可用,執行上述命令將產生 4 個子處理序。每個子處理序將為自身分配 N 個虛擬 CPU,導致需要 4*N 個虛擬 CPU。但是,機器只有 N 個虛擬 CPU 可用。因此,不同的處理序將競爭資源,導致頻繁的處理序切換。
以下觀察結果表明存在 CPU 超載
- 高 CPU 使用率:通過使用 - htop命令,您可以觀察到 CPU 使用率一直很高,經常達到或超過其最大容量。這表明對 CPU 資源的需求超過了可用的物理核心,導致處理序之間爭奪 CPU 時間。
- 頻繁的上下文切換和低系統效率:在 CPU 超載的情況下,處理序會競爭 CPU 時間,作業系統需要在不同的處理序之間快速切換,以便公平地分配資源。這種頻繁的上下文切換會增加開銷,並降低整體系統效率。 
避免 CPU 超載¶
避免 CPU 超載的一個好方法是適當的資源分配。確保同時運行的處理序或執行緒的數量不超過可用的 CPU 資源。
在這種情況下,一個解決方案是在子處理序中指定適當的執行緒數。這可以通過使用子處理序中的 torch.set_num_threads(int) 函數為每個處理序設置執行緒數來實現。
假設機器上有 N 個虛擬 CPU,並且將產生 M 個處理序,則每個處理序使用的最大 num_threads 值為 floor(N/M)。為避免 mnist_hogwild 範例中的 CPU 超載,需要對 範例儲存庫 中的檔案 train.py 進行以下更改。
def train(rank, args, model, device, dataset, dataloader_kwargs):
    torch.manual_seed(args.seed + rank)
    #### define the num threads used in current sub-processes
    torch.set_num_threads(floor(N/M))
    train_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    for epoch in range(1, args.epochs + 1):
        train_epoch(epoch, args, model, device, train_loader, optimizer)
使用 torch.set_num_threads(floor(N/M)) 為每個處理序設置 num_thread。其中,您將 N 替換為可用的虛擬 CPU 數,將 M 替換為選擇的處理序數。適當的 num_thread 值將根據具體任務而有所不同。但是,作為一般準則,num_thread 的最大值應為 floor(N/M),以避免 CPU 超載。在 mnist_hogwild 訓練範例中,在避免 CPU 超載之後,您可以實現 30 倍的效能提升。