使用 Fully Sharded Data Parallel (FSDP) 進行高階模型訓練¶
創建於:2024 年 10 月 31 日 | 最後更新:2024 年 10 月 31 日 | 最後驗證:2024 年 11 月 5 日
作者:Hamid Shojanazeri, Less Wright, Rohan Varma, Yanli Zhao
PyTorch 的 Fully Sharded Data Parallel 模組:一個用於在
資料並行工作節點之間分片模組引數的包裝器。
PyTorch 1.12 或更高版本
閱讀關於 FSDP API 的內容。
本教程介紹了作為 PyTorch 1.12 釋出的一部分的 Fully Sharded Data Parallel (FSDP) 更高階的功能。要熟悉 FSDP,請參閱 FSDP 入門教程。
在本教程中,我們使用 FSDP 微調 HuggingFace (HF) T5 模型進行文字摘要,作為一個工作示例。
本示例使用 WikiHow 資料集,為簡單起見,我們將演示在具有 8 個 A100 GPU 的單節點 P4dn 例項上進行訓練。我們現在有多篇部落格文章 ((連結 1), (連結 2)) 和一篇關於在多節點叢集上進行大規模 FSDP 訓練的論文。
FSDP 是一個生產就緒的軟體包,專注於易用性、效能和長期支援。FSDP 的主要優勢之一是減少每個 GPU 的記憶體佔用。這使得可以使用比 DDP 更低的總記憶體訓練更大的模型,並利用計算和通訊的重疊來高效地訓練模型。這種減輕的記憶體壓力可以用來訓練更大的模型或增加批次大小,這可能有助於提高整體訓練吞吐量。你可以在此處閱讀更多關於 PyTorch FSDP 的資訊。
本教程中的 FSDP 特性¶
Transformer 自動包裝策略
混合精度
在裝置上初始化 FSDP 模型
分片策略
反向預取
透過流式傳輸到 CPU 儲存模型檢查點
FSDP 工作原理回顧¶
從高層次看,FSDP 工作流程如下
在建構函式中
分片模型引數,每個 Rank 只保留自己的分片
在前向傳播中
執行 all_gather 收集所有 Rank 的所有分片,以恢復此 FSDP 單元的完整引數,並執行前向計算
丟棄剛剛收集的非自身擁有的引數分片以釋放記憶體
在反向傳播中
執行 all_gather 收集所有 Rank 的所有分片,以恢復此 FSDP 單元的完整引數,並執行反向計算
丟棄非自身擁有的引數以釋放記憶體。
執行 reduce_scatter 以同步梯度
微調 HF T5¶
HF T5 預訓練模型有四種不同大小,從 6000 萬引數的小型模型到 110 億引數的 XXL 模型。在本教程中,我們演示了使用 FSDP 微調 T5 3B 模型進行文字摘要,使用 WikiHow 資料集。本教程的主要重點是突出 FSDP 中可用的不同功能,這些功能對於訓練超過 30 億引數的大規模模型非常有用。此外,我們還介紹了基於 Transformer 的模型的特定功能。本教程的程式碼可在 PyTorch examples 中找到。
設定
1.1 安裝最新版 PyTorch
pip3 install torch torchvision torchaudio
1.2 資料集設定
請建立一個 data 資料夾,從 wikihowAll.csv 和 wikihowSep.cs 下載 WikiHow 資料集,並將它們放在 data 資料夾中。我們將使用來自 summarization_dataset 的 wikihow 資料集。
接下來,我們將以下程式碼片段新增到 Python 指令碼“T5_training.py”中。
注意
本教程的完整原始碼可在 PyTorch examples 中找到。
1.3 匯入必要的包
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
import functools
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
apply_activation_checkpointing_wrapper)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
FullStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
enable_wrap,
wrap,
)
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
from summarization_dataset import *
from transformers.models.t5.modeling_t5 import T5Block
from typing import Type
import time
import tqdm
from datetime import datetime
1.4 分散式訓練設定。在這裡,我們使用兩個輔助函式來初始化分散式訓練的程序,並在訓練完成後進行清理。在本教程中,我們將使用 torch elastic,透過 torchrun,這將自動設定工作節點的 RANK 和 WORLD_SIZE。
def setup():
# initialize the process group
dist.init_process_group("nccl")
def cleanup():
dist.destroy_process_group()
2.1 設定 HuggingFace T5 模型
def setup_model(model_name):
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)
return model, tokenizer
此外,我們在這裡添加了幾個用於日期和格式化記憶體指標的輔助函式。
def get_date_of_run():
"""create date and time for file save uniqueness
example: 2022-05-07-08:31:12_PM'
"""
date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
print(f"--> current date and time of run = {date_of_run}")
return date_of_run
def format_metrics_to_gb(item):
"""quick function to format numbers to gigabyte and round to 4 digit precision"""
metric_num = item / g_gigabyte
metric_num = round(metric_num, ndigits=4)
return metric_num
2.2 定義訓練函式
def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
model.train()
local_rank = int(os.environ['LOCAL_RANK'])
fsdp_loss = torch.zeros(2).to(local_rank)
if sampler:
sampler.set_epoch(epoch)
if rank==0:
inner_pbar = tqdm.tqdm(
range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
)
for batch in train_loader:
for key in batch.keys():
batch[key] = batch[key].to(local_rank)
optimizer.zero_grad()
output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
loss = output["loss"]
loss.backward()
optimizer.step()
fsdp_loss[0] += loss.item()
fsdp_loss[1] += len(batch)
if rank==0:
inner_pbar.update(1)
dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
train_accuracy = fsdp_loss[0] / fsdp_loss[1]
if rank == 0:
inner_pbar.close()
print(
f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
)
return train_accuracy
2.3 定義驗證函式
def validation(model, rank, world_size, val_loader):
model.eval()
correct = 0
local_rank = int(os.environ['LOCAL_RANK'])
fsdp_loss = torch.zeros(3).to(local_rank)
if rank == 0:
inner_pbar = tqdm.tqdm(
range(len(val_loader)), colour="green", desc="Validation Epoch"
)
with torch.no_grad():
for batch in val_loader:
for key in batch.keys():
batch[key] = batch[key].to(local_rank)
output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"])
fsdp_loss[0] += output["loss"].item() # sum up batch loss
fsdp_loss[1] += len(batch)
if rank==0:
inner_pbar.update(1)
dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
val_loss = fsdp_loss[0] / fsdp_loss[1]
if rank == 0:
inner_pbar.close()
print(f"Validation Loss: {val_loss:.4f}")
return val_loss
2.4 定義一個將模型封裝在 FSDP 中的分散式訓練函式
def fsdp_main(args):
model, tokenizer = setup_model("t5-base")
local_rank = int(os.environ['LOCAL_RANK'])
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
dataset = load_dataset('wikihow', 'all', data_dir='data/')
print(dataset.keys())
print("Size of train dataset: ", dataset['train'].shape)
print("Size of Validation dataset: ", dataset['validation'].shape)
#wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)
sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)
setup()
train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
cuda_kwargs = {'num_workers': 2,
'pin_memory': True,
'shuffle': False}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)
t5_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
T5Block,
},
)
sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3
torch.cuda.set_device(local_rank)
#init_start_event = torch.cuda.Event(enable_timing=True)
#init_end_event = torch.cuda.Event(enable_timing=True)
#init_start_event.record()
bf16_ready = (
torch.version.cuda
and torch.cuda.is_bf16_supported()
and LooseVersion(torch.version.cuda) >= "11.0"
and dist.is_nccl_available()
and nccl.version() >= (2, 10)
)
if bf16_ready:
mp_policy = bfSixteen
else:
mp_policy = None # defaults to fp32
# model is on CPU before input to FSDP
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=mp_policy,
#sharding_strategy=sharding_strategy,
device_id=torch.cuda.current_device())
optimizer = optim.AdamW(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
best_val_loss = float("inf")
curr_val_loss = float("inf")
file_save_name = "T5-model-"
if rank == 0:
time_of_run = get_date_of_run()
dur = []
train_acc_tracking = []
val_acc_tracking = []
training_start_time = time.time()
if rank == 0 and args.track_memory:
mem_alloc_tracker = []
mem_reserved_tracker = []
for epoch in range(1, args.epochs + 1):
t0 = time.time()
train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
if args.run_validation:
curr_val_loss = validation(model, rank, world_size, val_loader)
scheduler.step()
if rank == 0:
print(f"--> epoch {epoch} completed...entering save and stats zone")
dur.append(time.time() - t0)
train_acc_tracking.append(train_accuracy.item())
if args.run_validation:
val_acc_tracking.append(curr_val_loss.item())
if args.track_memory:
mem_alloc_tracker.append(
format_metrics_to_gb(torch.cuda.memory_allocated())
)
mem_reserved_tracker.append(
format_metrics_to_gb(torch.cuda.memory_reserved())
)
print(f"completed save and stats zone...")
if args.save_model and curr_val_loss < best_val_loss:
# save
if rank == 0:
print(f"--> entering save model state")
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
model, StateDictType.FULL_STATE_DICT, save_policy
):
cpu_state = model.state_dict()
#print(f"saving process: rank {rank} done w state_dict")
if rank == 0:
print(f"--> saving model ...")
currEpoch = (
"-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
)
print(f"--> attempting to save model prefix {currEpoch}")
save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
print(f"--> saving as model name {save_name}")
torch.save(cpu_state, save_name)
if curr_val_loss < best_val_loss:
best_val_loss = curr_val_loss
if rank==0:
print(f"-->>>> New Val Loss Record: {best_val_loss}")
dist.barrier()
cleanup()
2.5 解析引數並設定主函式
if __name__ == '__main__':
# Training settings
parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
parser.add_argument('--batch-size', type=int, default=4, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=2, metavar='N',
help='number of epochs to train (default: 3)')
parser.add_argument('--lr', type=float, default=.002, metavar='LR',
help='learning rate (default: .002)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--track_memory', action='store_false', default=True,
help='track the gpu memory')
parser.add_argument('--run_validation', action='store_false', default=True,
help='running the validation')
parser.add_argument('--save-model', action='store_false', default=True,
help='For Saving the current Model')
args = parser.parse_args()
torch.manual_seed(args.seed)
fsdp_main(args)
要使用 torchrun 執行訓練
torchrun --nnodes 1 --nproc_per_node 4 T5_training.py
Transformer 包裝策略¶
如之前的教程所述,auto_wrap_policy 是 FSDP 的特性之一,它使得自動分片給定的模型並將模型、最佳化器和梯度分片放入不同的 FSDP 單元變得容易。
對於某些架構,例如 Transformer 編碼器-解碼器,模型的某些部分(例如嵌入表)與編碼器和解碼器共享。在這種情況下,我們需要將嵌入表放在外部 FSDP 單元中,以便編碼器和解碼器都可以訪問它。此外,透過註冊 Transformer 的層類,可以使分片計劃更具通訊效率。在 PyTorch 1.12 中,FSDP 添加了此支援,現在我們有了一個用於 Transformer 的包裝策略。
可以按如下方式建立,其中 T5Block 表示 T5 Transformer 層類(包含 MHSA 和 FFN)。
t5_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
T5Block,
},
)
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy)
要檢視包裝後的模型,你可以輕鬆列印模型並直觀檢查分片和 FSDP 單元。
混合精度¶
FSDP 支援靈活的混合精度訓練,允許任意的降低精度型別(例如 fp16 或 bfloat16)。目前 BFloat16 僅在 Ampere GPU 上可用,因此在使用之前需要確認原生支援。例如,在 V100 上仍然可以執行 BFloat16,但由於它是非原生執行的,可能會導致明顯的效能下降。
要檢查 BFloat16 是否原生支援,你可以使用以下方法
bf16_ready = (
torch.version.cuda
and torch.cuda.is_bf16_supported()
and LooseVersion(torch.version.cuda) >= "11.0"
and dist.is_nccl_available()
and nccl.version() >= (2, 10)
)
FSDP 中混合精度的一個優點是提供對引數、梯度和緩衝區的不同精度級別的細粒度控制,如下所示
fpSixteen = MixedPrecision(
param_dtype=torch.float16,
# Gradient communication precision.
reduce_dtype=torch.float16,
# Buffer precision.
buffer_dtype=torch.float16,
)
bfSixteen = MixedPrecision(
param_dtype=torch.bfloat16,
# Gradient communication precision.
reduce_dtype=torch.bfloat16,
# Buffer precision.
buffer_dtype=torch.bfloat16,
)
fp32_policy = MixedPrecision(
param_dtype=torch.float32,
# Gradient communication precision.
reduce_dtype=torch.float32,
# Buffer precision.
buffer_dtype=torch.float32,
)
請注意,如果未指定某種型別(引數、reduce、緩衝區),則不會進行任何型別轉換。
這種靈活性允許使用者進行細粒度控制,例如只設置梯度通訊在降低精度下進行,而所有引數/緩衝區計算都在全精度下進行。這在節點內通訊是主要瓶頸且引數/緩衝區必須是全精度以避免精度問題的情況下可能很有用。可以透過以下策略實現:
grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)
在 2.4 中,我們只需將相關的混合精度策略新增到 FSDP 包裝器中
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen)
在我們的實驗中,我們觀察到使用 BFloat16 進行訓練可以將速度提高多達 4 倍,並在某些實驗中將記憶體減少約 30%,這可用於增加批次大小。
在裝置上初始化 FSDP 模型¶
在 1.12 版本中,FSDP 支援一個 device_id 引數,用於在 device_id 指定的裝置上初始化輸入的 CPU 模組。這在整個模型無法容納在單個 GPU 上,但可以容納在主機 CPU 記憶體中的情況下非常有用。指定 device_id 後,FSDP 會按每個 FSDP 單元將模型移動到指定的裝置,從而避免 GPU OOM 問題,同時初始化速度比基於 CPU 的初始化快幾倍。
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen,
device_id=torch.cuda.current_device())
分片策略¶
FSDP 分片策略預設設定為完全分片模型引數、梯度和最佳化器狀態,將它們分片到所有 Rank。(也稱為 Zero3 分片)。如果你有興趣使用 Zero2 分片策略(僅對最佳化器狀態和梯度進行分片),FSDP 支援此功能,方法是在 FSDP 初始化時傳遞分片策略,使用“ShardingStrategy.SHARD_GRAD_OP”而不是“ShardingStrategy.FULL_SHARD”,如下所示
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # ZERO2)
這將減少 FSDP 中的通訊開銷,在這種情況下,它在前向和反向傳播後會保留完整的引數。
這在反向傳播期間節省了一次 all_gather 操作,因此通訊量減少,但代價是記憶體佔用較高。請注意,完整的模型引數在反向傳播結束時會釋放,並在下一次前向傳播時發生 all_gather。
反向預取¶
反向預取設定控制何時請求下一個 FSDP 單元引數的時機。透過將其設定為 BACKWARD_PRE,可以在當前單元計算開始之前更早地請求並獲取下一個 FSDP 單元的引數。這使得 all_gather 通訊與梯度計算重疊,從而提高訓練速度,代價是略微增加記憶體消耗。可以在 2.4 中的 FSDP 包裝器中使用它,如下所示
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen,
device_id=torch.cuda.current_device(),
backward_prefetch = BackwardPrefetch.BACKWARD_PRE)
backward_prefetch 有兩種模式:BACKWARD_PRE 和 BACKWARD_POST。BACKWARD_POST 意味著在當前 FSDP 單元處理完成之前,不會請求下一個 FSDP 單元的引數,從而最大限度地減少記憶體開銷。在某些情況下,使用 BACKWARD_PRE 可以將模型訓練速度提高 2-10%,對於較大的模型甚至可以觀察到更高的速度提升。
透過流式傳輸到 Rank0 CPU 儲存模型檢查點¶
為了使用 FULL_STATE_DICT 儲存模型檢查點(這種方式與儲存本地模型類似),PyTorch 1.12 提供了一些工具來支援儲存較大的模型。
首先,可以指定一個 FullStateDictConfig,允許僅在 Rank 0 上填充 state_dict 並解除安裝到 CPU。
使用此配置時,FSDP 將 allgather 模型引數,僅在 Rank 0 上將它們逐個解除安裝到 CPU。當最終儲存 state_dict 時,它將僅在 Rank 0 上填充幷包含 CPU 張量。這避免了對於大於單個 GPU 記憶體的模型可能出現的 OOM 問題,並允許使用者儲存大小大致與使用者機器上可用 CPU RAM 相當的模型檢查點。
此功能可以按如下方式執行
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
model, StateDictType.FULL_STATE_DICT, save_policy
):
cpu_state = model.state_dict()
if rank == 0:
save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
torch.save(cpu_state, save_name)
總結¶
在本教程中,我們介紹了 PyTorch 1.12 中 FSDP 的許多新特性,並使用 HF T5 作為執行示例。使用適當的包裝策略(特別是對於 Transformer 模型),以及混合精度和反向預取,應該可以加速你的訓練執行。此外,諸如在裝置上初始化模型和透過流式傳輸到 CPU 儲存檢查點等特性應該有助於避免處理大型模型時的 OOM 錯誤。
我們正在積極努力為 FSDP 的下一個版本新增新功能。如果你有反饋、功能請求、問題或在使用 FSDP 時遇到問題,請隨時透過在 PyTorch Github 倉庫中提出 issue 來聯絡我們。