• 文件 >
  • 多程序最佳實務
捷徑

多程序最佳實務

torch.multiprocessing 是 Python multiprocessing 模組的替代品。它支援完全相同的操作,但擴展了它,因此所有透過 multiprocessing.Queue 傳送的張量,其資料都會被移到共享記憶體中,並且只會將控點傳送到另一個程序。

注意

Tensor 被傳送到另一個程序時,Tensor 資料會被共享。如果 torch.Tensor.grad 不是 None,它也會被共享。在沒有 torch.Tensor.grad 欄位的 Tensor 被傳送到另一個程序之後,它會建立一個標準的程序特定 .grad Tensor,這個 Tensor 不會像 Tensor 的資料那樣自動在所有程序之間共享。

這允許實作各種訓練方法,例如 Hogwild、A3C 或任何其他需要非同步操作的方法。

多程序中的 CUDA

CUDA 執行階段不支援 fork 啟動方法;需要使用 spawnforkserver 啟動方法才能在子程序中使用 CUDA。

注意

啟動方法可以透過使用 multiprocessing.get_context(...) 建立上下文或直接使用 multiprocessing.set_start_method(...) 來設定。

與 CPU 張量不同,傳送程序需要保留原始張量,只要接收程序保留了該張量的副本。它是在底層實作的,但需要使用者遵循最佳實務才能使程式正確執行。例如,傳送程序必須保持活動狀態,只要消費者程序具有對該張量的引用,並且如果消費者程序透過致命訊號異常退出,則引用計數無法儲存您。請參閱 本節

另請參閱:使用 nn.parallel.DistributedDataParallel 而不是多程序或 nn.DataParallel

最佳實務和技巧

避免和解決死結

在產生新程序時,可能會出現很多問題,其中最常見的死結原因是背景執行緒。如果有任何執行緒持有鎖定或匯入模組,並且呼叫了 fork,則子程序很可能會處於損壞狀態,並會死結或以其他方式失敗。請注意,即使您沒有這樣做,Python 內建函式庫也會這樣做 - 無需多看,只要看看 multiprocessingmultiprocessing.Queue 實際上是一個非常複雜的類別,它會產生多個用於序列化、傳送和接收物件的執行緒,它們也可能導致上述問題。如果您發現自己處於這種情況,請嘗試使用 SimpleQueue,它不使用任何額外的執行緒。

我們正在盡力讓您輕鬆使用並確保這些死結不會發生,但有些事情我們無法控制。如果您有任何無法解決的問題,請嘗試在論壇上尋求幫助,我們會看看這是否是我們可以解決的問題。

重複使用透過佇列傳遞的緩衝區

請記住,每次將 Tensor 放入 multiprocessing.Queue 時,都必須將其移至共享記憶體。如果它已經是共享的,則此操作無效,否則會產生額外的記憶體複製,從而降低整個過程的速度。即使您有一個處理序池將數據發送到單個處理序,也要讓它將緩衝區發送回來 - 這幾乎是免費的,並且可以讓您在發送下一批數據時避免複製。

非同步多程序訓練(例如 Hogwild)

使用 torch.multiprocessing,可以非同步地訓練模型,參數可以一直共享,也可以定期同步。在前一種情況下,我們建議發送整個模型物件,而在後一種情況下,我們建議只發送 state_dict()

我們建議使用 multiprocessing.Queue 在處理序之間傳遞各種 PyTorch 物件。例如,在使用 fork 啟動方法時,可以繼承已經在共享記憶體中的張量和儲存,但是這很容易出錯,應該謹慎使用,並且只有進階用戶才應該使用。佇列雖然有時不是一種優雅的解決方案,但在所有情況下都能正常工作。

警告

您應該小心使用沒有用 if __name__ == '__main__' 保護的全域語句。如果使用 fork 以外的其他啟動方法,它們將在所有子處理序中執行。

Hogwild

您可以在 範例儲存庫 中找到具體的 Hogwild 實作,但為了展示程式碼的整體結構,下方還有一個最小範例

import torch.multiprocessing as mp
from model import MyModel

def train(model):
    # Construct data_loader, optimizer, etc.
    for data, labels in data_loader:
        optimizer.zero_grad()
        loss_fn(model(data), labels).backward()
        optimizer.step()  # This will update the shared parameters

if __name__ == '__main__':
    num_processes = 4
    model = MyModel()
    # NOTE: this is required for the ``fork`` method to work
    model.share_memory()
    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=train, args=(model,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

多程序中的 CPU

不當的多程序處理可能會導致 CPU 超載,導致不同的處理序競爭 CPU 資源,從而導致效率低下。

本教學將說明什麼是 CPU 超載以及如何避免它。

CPU 超載

CPU 超載是一個技術術語,指的是分配給系統的虛擬 CPU 總數超過硬體上可用的虛擬 CPU 總數的情況。

這會導致對 CPU 資源的嚴重爭用。在這種情況下,處理序之間會頻繁切換,這會增加處理序切換的開銷,並降低整體系統效率。

請參閱 範例儲存庫 中 Hogwild 實作中的程式碼範例,了解 CPU 超載的情況。

當使用以下命令在 CPU 上使用 4 個處理序運行訓練範例時

python main.py --num-processes 4

假設機器上有 N 個虛擬 CPU 可用,執行上述命令將產生 4 個子處理序。每個子處理序將為自身分配 N 個虛擬 CPU,導致需要 4*N 個虛擬 CPU。但是,機器只有 N 個虛擬 CPU 可用。因此,不同的處理序將競爭資源,導致頻繁的處理序切換。

以下觀察結果表明存在 CPU 超載

  1. 高 CPU 使用率:通過使用 htop 命令,您可以觀察到 CPU 使用率一直很高,經常達到或超過其最大容量。這表明對 CPU 資源的需求超過了可用的物理核心,導致處理序之間爭奪 CPU 時間。

  2. 頻繁的上下文切換和低系統效率:在 CPU 超載的情況下,處理序會競爭 CPU 時間,作業系統需要在不同的處理序之間快速切換,以便公平地分配資源。這種頻繁的上下文切換會增加開銷,並降低整體系統效率。

避免 CPU 超載

避免 CPU 超載的一個好方法是適當的資源分配。確保同時運行的處理序或執行緒的數量不超過可用的 CPU 資源。

在這種情況下,一個解決方案是在子處理序中指定適當的執行緒數。這可以通過使用子處理序中的 torch.set_num_threads(int) 函數為每個處理序設置執行緒數來實現。

假設機器上有 N 個虛擬 CPU,並且將產生 M 個處理序,則每個處理序使用的最大 num_threads 值為 floor(N/M)。為避免 mnist_hogwild 範例中的 CPU 超載,需要對 範例儲存庫 中的檔案 train.py 進行以下更改。

def train(rank, args, model, device, dataset, dataloader_kwargs):
    torch.manual_seed(args.seed + rank)

    #### define the num threads used in current sub-processes
    torch.set_num_threads(floor(N/M))

    train_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    for epoch in range(1, args.epochs + 1):
        train_epoch(epoch, args, model, device, train_loader, optimizer)

使用 torch.set_num_threads(floor(N/M)) 為每個處理序設置 num_thread。其中,您將 N 替換為可用的虛擬 CPU 數,將 M 替換為選擇的處理序數。適當的 num_thread 值將根據具體任務而有所不同。但是,作為一般準則,num_thread 的最大值應為 floor(N/M),以避免 CPU 超載。在 mnist_hogwild 訓練範例中,在避免 CPU 超載之後,您可以實現 30 倍的效能提升。

文件

訪問 PyTorch 的完整開發者文檔

查看文檔

教學

獲取針對初學者和進階開發者的深入教學

查看教學

資源

查找開發資源並獲得問題的答案

查看資源