分散式 Autograd 設計¶
本說明將介紹分散式 Autograd 的詳細設計,並逐步介紹其內部結構。在繼續之前,請確保您熟悉 Autograd 機制 和 分散式 RPC 框架。
背景¶
假設您有兩個節點和一個非常簡單的模型,分佈在兩個節點上。這可以使用 torch.distributed.rpc 實現,如下所示
import torch
import torch.distributed.rpc as rpc
def my_add(t1, t2):
return torch.add(t1, t2)
# On worker 0:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
分散式 Autograd 背後的主要動機是能夠使用我們計算出的 loss 在此類分散式模型上執行反向傳播,並記錄所有需要梯度的張量的適當梯度。
正向傳播期間的 Autograd 記錄¶
PyTorch 在正向傳播期間建構 Autograd 圖,並且此圖用於執行反向傳播。有關更多詳細資訊,請參閱 Autograd 如何編碼歷史記錄。
對於分散式 Autograd,我們需要在正向傳播期間追蹤所有 RPC,以確保正確執行反向傳播。為此,我們在執行 RPC 時將 send 和 recv 函數附加到 Autograd 圖。
send函數附加到 RPC 的來源,其輸出邊緣指向 RPC 輸入張量的 Autograd 函數。反向傳播期間此函數的輸入是從目的地接收的,作為適當recv函數的輸出。recv函數附加到 RPC 的目的地,其輸入是使用輸入張量從目的地執行的運算子中檢索的。此函數的輸出梯度在反向傳播期間發送到源節點的適當send函數。每個
send-recv對都分配了一個全域唯一的autograd_message_id來唯一標識該對。這對於在反向傳播期間在遠端節點上查詢相應的函數很有用。對於 RRef,每當我們呼叫
torch.distributed.rpc.RRef.to_here()時,我們都會為涉及的張量附加一個適當的send-recv對。
舉例來說,這就是我們上面範例的 Autograd 圖的樣子(為簡單起見,不包括 t5.sum())
分散式 Autograd 上下文¶
每個使用分散式 Autograd 的正向和反向傳播都分配了一個唯一的 torch.distributed.autograd.context,並且此上下文具有一個全域唯一的 autograd_context_id。此上下文會在每個節點上視需要建立。
此上下文用於以下目的
執行分散式反向傳播的多個節點可能會在同一個張量上累積梯度,因此在我們有機會執行優化器之前,張量的
.grad欄位將具有來自各種分散式反向傳播的梯度。這類似於在本地多次呼叫torch.autograd.backward()。為了提供一種分離每個反向傳播的梯度的方法,梯度會在每個反向傳播的torch.distributed.autograd.context中累積。在正向傳播期間,我們將每個 Autograd 傳播的
send和recv函數儲存在此上下文中。這確保我們持有對 Autograd 圖中適當節點的引用,以使其保持活動狀態。除此之外,在反向傳播期間很容易查詢適當的send和recv函數。一般來說,我們也使用此上下文來儲存每個分散式 Autograd 傳播的一些中繼資料。
從使用者的角度來看,Autograd 上下文的設定如下
import torch.distributed.autograd as dist_autograd
with dist_autograd.context() as context_id:
loss = model.forward()
dist_autograd.backward(context_id, loss)
請務必注意,您的模型的前向傳遞必須在分佈式自動梯度上下文管理器中調用,因為需要有效的上下文才能確保所有 send 和 recv 函數都正確存儲,以便在所有參與節點上運行反向傳遞。
分佈式反向傳遞¶
在本節中,我們概述了在分佈式反向傳遞期間準確計算依賴關係的挑戰,並描述了如何在分佈式反向傳遞期間執行幾種算法(具有取捨)。
計算依賴關係¶
考慮在單台機器上運行的以下程式碼片段
import torch
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = a + b
e = b * c
d.sum.().backward()
這就是上面程式碼的自動梯度圖的樣子
自動梯度引擎作為反向傳遞的一部分執行的第一步是計算自動梯度圖中每個節點的依賴關係數量。這有助於自動梯度引擎知道圖中的節點何時可以執行。 add(1) 和 mul(0) 的括號中的數字表示依賴關係的數量。如您所見,這意味著在反向傳遞期間, add 節點需要 1 個輸入,而 mul 節點不需要任何輸入(換句話說,不需要執行)。本地自動梯度引擎通過從根節點(在本例中為 d)遍歷圖來計算這些依賴關係。
自動梯度圖中的某些節點可能未在反向傳遞中執行,這一事實對分佈式自動梯度提出了挑戰。請考慮這段使用 RPC 的程式碼。
import torch
import torch.distributed.rpc as rpc
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = rpc.rpc_sync("worker1", torch.add, args=(a, b))
e = rpc.rpc_sync("worker1", torch.mul, args=(b, c))
loss = d.sum()
上面程式碼的關聯自動梯度圖將是
計算此分佈式自動梯度圖的依賴關係更具挑戰性,並且需要一些開銷(計算或網路通訊方面)。
對於效能敏感的應用程式,我們可以通過假設每個 send 和 recv 函數在反向傳遞中都有效來避免大量開銷(大多數應用程式不會執行未使用的 RPC)。這簡化了分佈式自動梯度算法,並且效率更高,但代價是應用程式需要了解這些限制。此算法稱為 快速模式算法,並在下面詳細描述。
在一般情況下,並非每個 send 和 recv 函數都必須在反向傳遞中有效。為了解決這個問題,我們提出了一種 智慧模式算法,將在後面的章節中介紹。請注意,目前僅實作了「快速」模式算法。
快速模式算法¶
此算法的關鍵假設是,當我們運行反向傳遞時,每個 send 函數的依賴關係為 1。換句話說,我們假設我們將通過 RPC 從另一個節點接收梯度。
算法如下
我們從具有反向傳遞根的 Worker 開始(所有根都必須是本地的)。
查找當前 分佈式自動梯度上下文 的所有
send函數。從提供的根和我們檢索到的所有
send函數開始,在本地計算依賴關係。計算依賴關係後,使用提供的根啟動本地自動梯度引擎。
當自動梯度引擎執行
recv函數時,recv函數會通過 RPC 將輸入梯度發送到相應的 Worker。每個recv函數都知道目標 Worker ID,因為它在正向傳遞中被記錄。recv函數還會將autograd_context_id和autograd_message_id發送到遠端主機。當在遠端主機上收到此請求時,我們使用
autograd_context_id和autograd_message_id查找相應的send函數。如果這是 Worker 第一次收到給定
autograd_context_id的請求,它將如上文第 1-3 點所述在本地計算依賴關係。然後,在 6. 中檢索到的
send函數會排入佇列,以便在該 Worker 的本地自動梯度引擎上執行。最後,我們不是將梯度累加到 Tensor 的
.grad欄位中,而是為每個 分佈式自動梯度上下文 單獨累加梯度。梯度存儲在Dict[Tensor, Tensor]中,這基本上是從 Tensor 到其關聯梯度的映射,並且可以使用get_gradients()API 檢索此映射。
例如,具有分佈式自動梯度的完整程式碼如下所示
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
def my_add(t1, t2):
return torch.add(t1, t2)
# On worker 0:
# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
# Run the backward pass.
dist_autograd.backward(context_id, [loss])
# Retrieve the gradients from the context.
dist_autograd.get_gradients(context_id)
具有依賴關係的分佈式自動梯度圖如下所示(為簡單起見,不包括 t5.sum())
應用於上述範例的 快速模式算法 如下所示
在
Worker 0上,我們從根loss和send1開始計算依賴關係。因此,send1標記為依賴關係 1,而Worker 0上的mul標記為依賴關係 1。現在,我們在
Worker 0上啟動本地自動梯度引擎。我們首先執行mul函數,將其輸出累加到自動梯度上下文中作為t4的梯度。然後,我們執行recv2,它將梯度發送到Worker 1。由於這是
Worker 1第一次收到有關此反向傳遞的資訊,因此它會開始依賴關係計算,並相應地標記send2、add和recv1的依賴關係。接下來,我們將
send2排入Worker 1的本地自動梯度引擎的佇列中,該引擎又會執行add和recv1。當執行
recv1時,它會將梯度發送到Worker 0。由於
Worker 0已經為此反向傳遞計算了依賴關係,因此它只會在本地排入佇列並執行send1。最後,
t1、t2和t4的梯度會累加到 分佈式自動梯度上下文 中。
分佈式優化器¶
DistributedOptimizer 的運作方式如下
獲取要最佳化的遠端參數清單(
RRef)。這些也可以是在本地RRef中包裝的本地參數。獲取
Optimizer類作為本地優化器,以在所有不同的RRef所有者上運行。分佈式優化器在每個 Worker 節點上創建本地
Optimizer的實例,並持有對它們的RRef。調用
torch.distributed.optim.DistributedOptimizer.step()時,分佈式優化器使用 RPC 在相應的遠端 Worker 上遠端執行所有本地優化器。必須將分佈式自動梯度context_id作為輸入提供給torch.distributed.optim.DistributedOptimizer.step()。本地優化器使用它來應用存儲在相應上下文中的梯度。如果多個並發分佈式優化器正在更新 Worker 上的相同參數,則這些更新將通過鎖定進行序列化。
簡單的端到端範例¶
總的來說,以下是使用分佈式自動梯度和分佈式優化器的簡單端到端範例。如果程式碼放在名為「dist_autograd_simple.py」的檔案中,則可以使用命令 MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py 運行
import torch
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer
def random_tensor():
return torch.rand((3, 3), requires_grad=True)
def _run_process(rank, dst_rank, world_size):
name = "worker{}".format(rank)
dst_name = "worker{}".format(dst_rank)
# Initialize RPC.
rpc.init_rpc(
name=name,
rank=rank,
world_size=world_size
)
# Use a distributed autograd context.
with dist_autograd.context() as context_id:
# Forward pass (create references on remote nodes).
rref1 = rpc.remote(dst_name, random_tensor)
rref2 = rpc.remote(dst_name, random_tensor)
loss = rref1.to_here() + rref2.to_here()
# Backward pass (run distributed autograd).
dist_autograd.backward(context_id, [loss.sum()])
# Build DistributedOptimizer.
dist_optim = DistributedOptimizer(
optim.SGD,
[rref1, rref2],
lr=0.05,
)
# Run the distributed optimizer step.
dist_optim.step(context_id)
def run_process(rank, world_size):
dst_rank = (rank + 1) % world_size
_run_process(rank, dst_rank, world_size)
rpc.shutdown()
if __name__ == '__main__':
# Run world_size workers
world_size = 2
mp.spawn(run_process, args=(world_size,), nprocs=world_size)