• 文件 >
  • Table Batched Embedding (TBE) 訓練模組
快捷方式

Table Batched Embedding (TBE) 訓練模組

穩定 API

class fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen(embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]], feature_table_map: List[int] | None = None, cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, cache_load_factor: float = 0.2, cache_sets: int = 0, cache_reserved_memory: float = 0.0, cache_precision: SparseType | None = None, weights_precision: SparseType = SparseType.FP32, output_dtype: SparseType = SparseType.FP32, enforce_hbm: bool = False, optimizer: EmbOptimType = EmbOptimType.EXACT_SGD, record_cache_metrics: RecordCacheMetrics | None = None, gather_uvm_cache_stats: bool | None = False, stochastic_rounding: bool = True, gradient_clipping: bool = False, max_gradient: float = 1.0, max_norm: float = 0.0, learning_rate: float = 0.01, eps: float = 1e-08, momentum: float = 0.9, weight_decay: float = 0.0, weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE, eta: float = 0.001, beta1: float = 0.9, beta2: float = 0.999, ensemble_mode: EnsembleModeDefinition | None = None, emainplace_mode: EmainplaceModeDefinition | None = None, counter_based_regularization: CounterBasedRegularizationDefinition | None = None, cowclip_regularization: CowClipDefinition | None = None, pooling_mode: PoolingMode = PoolingMode.SUM, device: str | device | int | None = None, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, uvm_non_rowwise_momentum: bool = False, use_experimental_tbe: bool = False, prefetch_pipeline: bool = False, stats_reporter_config: TBEStatsReporterConfig | None = None, table_names: List[str] | None = None, optimizer_state_dtypes: Dict[str, SparseType] | None = None, multipass_prefetch_config: MultiPassPrefetchConfig | None = None, global_weight_decay: GlobalWeightDecayDefinition | None = None, uvm_host_mapped: bool = False, extra_optimizer_config: UserEnabledConfigDefinition | None = None, tbe_input_multiplexer_config: TBEInputMultiplexerConfig | None = None, embedding_table_index_type: dtype = torch.int64, embedding_table_offset_type: dtype = torch.int64, embedding_shard_info: List[Tuple[int, int, int, int]] | None = None)[原始碼]

Table Batched Embedding (TBE) 運算元。查詢一個或多個嵌入表。此模組應用於訓練。反向運算元與最佳化器融合。因此,嵌入表會在反向傳播期間更新。

引數:
  • embedding_specs (List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]) –

    嵌入規範列表。每個規範描述了一個物理嵌入表的規格。每個規範是一個元組,包含嵌入行的數量、嵌入維度(必須是 4 的倍數)、表位置(EmbeddingLocation)和計算裝置(ComputeDevice)。

    可用的 EmbeddingLocation 選項包括

    1. DEVICE = 將嵌入表放置在 GPU 全域性記憶體 (HBM) 中

    2. MANAGED = 將嵌入表放置在統一虛擬記憶體中(GPU 和 CPU 均可訪問)

    3. MANAGED_CACHING = 將嵌入表放置在統一虛擬記憶體中,並使用 GPU 全域性記憶體 (HBM) 作為快取

    4. HOST = 將嵌入表放置在 CPU 記憶體 (DRAM) 中

    5. MTIA = 將嵌入表放置在 MTIA 記憶體中

    可用的 ComputeDevice 選項包括

    1. CPU = 在 CPU 上執行表查詢

    2. CUDA = 在 GPU 上執行表查詢

    3. MTIA = 在 MTIA 上執行表查詢

  • feature_table_map (Optional[List[int]] = None) – 可選列表,指定特徵到表的對映。feature_table_map[i] 表示特徵 i 對映到的物理嵌入表。

  • cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU) –

    快取演算法(當 EmbeddingLocation 設定為 MANAGED_CACHING 時使用)。選項包括

    1. LRU = 最近最少使用

    2. LFU = 最不常用

  • cache_load_factor (float = 0.2) – 用於確定使用 EmbeddingLocation.MANAGED_CACHING 時的快取容量的因子。快取容量為 cache_load_factor * 所有嵌入表中的總行數。

  • cache_sets (int = 0) – 快取集的數量(當 EmbeddingLocation 設定為 MANAGED_CACHING 時使用)。

  • cache_reserved_memory (float = 0.0) – 在 HBM 中為非快取目的保留的記憶體量(當 EmbeddingLocation 設定為 MANAGED_CACHING 時使用)。

  • cache_precision (SparseType = SparseType.FP32) – 快取的資料型別(當 EmbeddingLocation 設定為 MANAGED_CACHING 時使用)。選項為 SparseType.FP32SparseType.FP16

  • weights_precision (SparseType = SparseType.FP32) – 嵌入表(也稱為權重)的資料型別。選項為 SparseType.FP32SparseType.FP16

  • output_dtype (SparseType = SparseType.FP32) – 輸出張量的資料型別。選項為 SparseType.FP32SparseType.FP16

  • enforce_hbm (bool = False) – 如果為 True,則在使用 EmbeddingLocation.MANAGED_CACHING 時將所有權重/動量放置在 HBM 中。

  • optimizer (OptimType = OptimType.EXACT_SGD) –

    在反向傳播中用於更新嵌入表的最佳化器。可用的 OptimType 選項包括

    1. ADAM = Adam

    2. EXACT_ADAGRAD = Adagrad

    3. EXACT_ROWWISE_ADAGRAD = 按行 Adagrad

    4. EXACT_SGD = SGD

    5. LAMB = Lamb

    6. LARS_SGD = LARS-SGD

    7. PARTIAL_ROWWISE_ADAM = 部分按行 Adam

    8. PARTIAL_ROWWISE_LAMB = 部分按行 Lamb

    9. ENSEMBLE_ROWWISE_ADAGRAD = 整合按行 Adagrad

    10. EMAINPLACE_ROWWISE_ADAGRAD = EMA 就地按行 Adagrad

    11. NONE = 在反向傳播中不應用最佳化器更新

    並輸出稀疏權重梯度

  • record_cache_metrics (Optional[RecordCacheMetrics] = None) – 記錄命中次數、請求次數等,如果 RecordCacheMetrics.record_cache_miss_counter 為 True;如果 RecordCacheMetrics.record_tablewise_cache_miss 為 True,則按表記錄類似指標。

  • gather_uvm_cache_stats (Optional[bool] = False) – 如果為 True,則當 EmbeddingLocation 設定為 MANAGED_CACHING 時收集快取統計資訊。

  • stochastic_rounding (bool = True) – 如果為 True,則對非 SparseType.FP32 的權重型別應用隨機舍入。

  • gradient_clipping (bool = False) – 如果為 True,則應用梯度裁剪。

  • max_gradient (float = 1.0) – 梯度裁剪的值。

  • max_norm (float = 0.0) – 最大範數值。

  • learning_rate (float = 0.01) – 學習率。

  • eps (float = 1.0e-8) – Adagrad、LAMB 和 Adam 使用的 epsilon 值。注意,此預設值與 torch.nn.optim.Adagrad 的預設值 1e-10 不同。

  • momentum (float = 0.9) – LARS-SGD 使用的動量。

  • weight_decay (float = 0.0) –

    LARS-SGD、LAMB、ADAM 和按行 Adagrad 使用的權重衰減。

    1. EXACT_ADAGRAD、SGD、EXACT_SGD 不支援權重衰減

    2. LAMB、ADAM、PARTIAL_ROWWISE_ADAM、PARTIAL_ROWWISE_LAMB、LARS_SGD 支援解耦權重衰減

    3. EXACT_ROWWISE_ADAGRAD 支援 L2 和解耦權重衰減(透過 weight_decay_mode)

  • weight_decay_mode (WeightDecayMode = WeightDecayMode.NONE) – 權重衰減模式。選項為 WeightDecayMode.NONEWeightDecayMode.L2WeightDecayMode.DECOUPLE

  • eta (float = 0.001) – LARS-SGD 使用的 eta 值。

  • beta1 (float = 0.9) – LAMB 和 ADAM 使用的 beta1 值。

  • beta2 (float = 0.999) – LAMB 和 ADAM 使用的 beta2 值。

  • ensemble_mode (Optional[EnsembleModeDefinition] = None) – 由整合按行 Adagrad 使用。

  • emainplace_mode (Optional[EmainplaceModeDefinition] = None) – 由 EMA 就地按行 Adagrad 使用。

  • counter_based_regularization (Optional[CounterBasedRegularizationDefinition] = None) – 由按行 Adagrad 使用。

  • cowclip_regularization (Optional[CowClipDefinition] = None) – 由按行 Adagrad 使用。

  • pooling_mode (PoolingMode = PoolingMode.SUM) –

    池化模式。可用的 PoolingMode 選項包括

    1. SUM = 求和池化

    2. MEAN = 平均值池化

    3. NONE = 不進行池化(序列嵌入)

  • device (Optional[Union[str, int, torch.device]] = None) – 當前放置張量的裝置。

  • bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING) –

    輸入檢查模式。可用的 BoundsCheckMode 選項包括

    1. NONE = 跳過邊界檢查

    2. FATAL = 遇到無效索引/偏移量時丟擲錯誤

    3. WARNING = 遇到無效索引/偏移量時列印警告訊息並修復(將無效索引設定為零,並將無效偏移量調整到邊界內)

    4. IGNORE = 靜默修復無效索引/偏移量(將無效索引設定為零,並將無效偏移量調整到邊界內)

  • uvm_non_rowwise_momentum (bool = False) – 如果為 True,則將非按行動量放置在統一虛擬記憶體中。

  • use_experimental_tbe (bool = False) – 如果為 True,則使用最佳化的 TBE 實現(TBE v2)。請注意,這僅支援 NVIDIA GPU。

  • prefetch_pipeline (bool = False) – 如果為 True,則在使用 EmbeddingLocation.MANAGED_CACHING 時啟用快取預取流水線。目前僅支援 LRU 快取策略。如果使用單獨的流進行預取,則必須設定預取函式的可選引數 forward_stream

  • stats_reporter_config (Optional[TBEStatsReporterConfig] = None) – TBE 統計報告器的配置。

  • table_names (Optional[List[str]] = None) – 此 TBE 中的嵌入表名稱列表。

  • optimizer_state_dtypes (Optional[Dict[str, SparseType]] = None) – 最佳化器狀態資料型別字典。鍵是最佳化器狀態名稱,值是其對應的型別

  • multipass_prefetch_config (Optional[MultiPassPrefetchConfig] = None) – 用於多遍快取預取的配置(當使用 EmbeddingLocation.MANAGED_CACHING 時)

  • global_weight_decay (Optional[GlobalWeightDecayDefinition] = None) – 用於全域性權重衰減的配置

  • uvm_host_mapped (bool = False) – 如果為 True,則使用 malloc + cudaHostRegister 分配每個 UVM 張量。否則使用 cudaMallocManaged

  • None) (extra_optimizer_config Optional[UserEnabledConfigDefinition] =) –

    一個額外的配置,用於為最佳化器啟用某些模式。這些模式預設不啟用。- 在 Adam 中使用 use_rowwise_bias_correction 啟用逐行偏差校正

    計算

  • embedding_table_index_type (torch.dtype = torch.int64) – 嵌入表索引張量的資料型別。選項包括 torch.int32torch.int64

  • embedding_table_offset_type (torch.dtype = torch.int64) – 嵌入表偏移張量的資料型別。選項包括 torch.int32torch.int64

  • embedding_shard_info (Optional[List[Tuple[int, int, int, int]]] = None) – 關於分片位置和預分片表大小的資訊。如果未設定,則表不分片。(preshard_table_height, preshard_table_dim, height_offset, dim_offset)

forward(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor | None = None, feature_requires_grad: Tensor | None = None, batch_size_per_feature_per_rank: List[List[int]] | None = None, total_unique_indices: int | None = None) Tensor[source]

正向傳播函式,它執行以下操作:

  1. 執行輸入邊界檢查

  2. 生成必要的變長批次嵌入 (VBE) 元資料(如果使用 VBE)

  3. 將資料從 UVM 預取到快取(如果使用 EmbeddingLocation.MANAGED_CACHING 且使用者未明確預取資料)

  4. 透過呼叫相應的 Autograd 函式(根據所選最佳化器)執行嵌入表查詢

引數:
  • indices (Tensor) – 一個 1D 張量,包含要從所有嵌入表中查詢的索引

  • offsets (Tensor) – 一個 1D 張量,包含索引的偏移量。形狀為 (B * T + 1),其中 B = 批次大小,T = 特徵數量。offsets[t * B + b + 1] - offsets[t * B + b] 是特徵 t 中 bag b 的長度

  • per_sample_weights (Optional[Tensor]) – 一個可選的 1D float 張量,包含每個樣本的權重。如果為 None,將執行無權重嵌入查詢。否則,將使用加權查詢。此張量的長度必須與 indices 張量的長度相同。per_sample_weights[i] 的值將用於乘以查詢的行 indices[i] 中的每個元素,其中 0 <= i < len(per_sample_weights)

  • feature_requires_grad (Optional[Tensor]) – 一個可選的 1D 張量,用於指示 per_sample_weights 是否需要梯度。張量的長度必須等於特徵數量

  • batch_size_per_feature_per_rank (Optional[List[List[int]]]) – 一個可選的 2D 張量,包含每個 rank 和每個特徵的批次大小。如果為 None,TBE 假定每個特徵具有相同的批次大小,並從 offsets 形狀計算批次大小。否則,TBE 假定不同特徵可以具有不同的批次大小,並使用變長批次嵌入查詢模式 (VBE)。形狀為(特徵數量,rank 數量)。batch_size_per_feature_per_rank[f][r] 表示特徵 f 和 rank r 的批次大小

  • total_unique_indices (Optional[int]) – 一個可選的整數,表示唯一索引的總數。當使用 OptimType.NONE 時,必須設定此值。這是因為 TBE 在反向傳播中分配權重梯度張量需要此資訊。

返回:

一個包含查詢資料的 2D 張量。形狀為 (B, total_D),其中 B = 批次大小,total_D = 表中所有嵌入維度的總和

示例

>>> import torch
>>>
>>> from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
>>>    EmbeddingLocation,
>>> )
>>> from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
>>>    SplitTableBatchedEmbeddingBagsCodegen,
>>>    ComputeDevice,
>>> )
>>>
>>> # Two tables
>>> embedding_specs = [
>>>     (3, 8, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
>>>     (5, 4, EmbeddingLocation.MANAGED, ComputeDevice.CUDA)
>>> ]
>>>
>>> tbe = SplitTableBatchedEmbeddingBagsCodegen(embedding_specs)
>>> tbe.init_embedding_weights_uniform(-1, 1)
>>>
>>> print(tbe.split_embedding_weights())
[tensor([[-0.9426,  0.7046,  0.4214, -0.0419,  0.1331, -0.7856, -0.8124, -0.2021],
        [-0.5771,  0.5911, -0.7792, -0.1068, -0.6203,  0.4813, -0.1677,  0.4790],
        [-0.5587, -0.0941,  0.5754,  0.3475, -0.8952, -0.1964,  0.0810, -0.4174]],
       device='cuda:0'), tensor([[-0.2513, -0.4039, -0.3775,  0.3273],
        [-0.5399, -0.0229, -0.1455, -0.8770],
        [-0.9520,  0.4593, -0.7169,  0.6307],
        [-0.1765,  0.8757,  0.8614,  0.2051],
        [-0.0603, -0.9980, -0.7958, -0.5826]], device='cuda:0')]
>>> # Batch size = 3
>>> indices = torch.tensor([0, 1, 2, 0, 1, 2, 0, 3, 1, 4, 2, 0, 0],
>>>                        device="cuda",
>>>                        dtype=torch.long)
>>> offsets = torch.tensor([0, 2, 5, 7, 9, 12, 13],
>>>                        device="cuda",
>>>                        dtype=torch.long)
>>>
>>> output = tbe(indices, offsets)
>>>
>>> # Batch size = 3, total embedding dimension = 12
>>> print(output.shape)
torch.Size([3, 12])
>>> print(output)
tensor([[-1.5197,  1.2957, -0.3578, -0.1487, -0.4873, -0.3044, -0.9801,  0.2769,
         -0.7164,  0.8528,  0.7159, -0.6719],
        [-2.0784,  1.2016,  0.2176,  0.1988, -1.3825, -0.5008, -0.8991, -0.1405,
         -1.2637, -0.9427, -1.8902,  0.3754],
        [-1.5013,  0.6105,  0.9968,  0.3057, -0.7621, -0.9821, -0.7314, -0.6195,
         -0.2513, -0.4039, -0.3775,  0.3273]], device='cuda:0',
       grad_fn=<CppNode<SplitLookupFunction_sgd_Op>>)
set_learning_rate(lr: float) None[source]

設定學習率。

引數:

lr (float) – 要設定的學習率值

set_optimizer_step(step: int) None[source]

設定最佳化器步數。

引數:

step (int) – 要設定的步數值

split_embedding_weights() List[Tensor][source]

返回一個嵌入權重列表(檢視),按表分割

返回:

權重列表。長度 = 表的數量

split_optimizer_states() List[List[Tensor]][source]

返回一個最佳化器狀態列表(檢視),按表分割

返回:

狀態列表的列表。形狀 =(表的數量,狀態的數量)。

以下顯示了每個最佳化器的狀態列表(按返回順序)

  1. ADAMmomentum1momentum2

  2. EXACT_ADAGRADmomentum1

  3. EXACT_ROWWISE_ADAGRADmomentum1(逐行),prev_iter(逐行;僅當使用 WeightDecayMode = COUNTERCOWCLIPglobal_weight_decay 不為 None 時),row_counter(逐行;僅當使用 WeightDecayMode = COUNTERCOWCLIP 時)

  4. EXACT_SGD:無狀態

  5. LAMBmomentum1momentum2

  6. LARS_SGDmomentum1

  7. PARTIAL_ROWWISE_ADAMmomentum1momentum2(逐行)

  8. PARTIAL_ROWWISE_LAMBmomentum1momentum2(逐行)

  9. ENSEMBLE_ROWWISE_ADAGRADmomentum1(逐行),momentum2

  10. NONE:無狀態(丟擲錯誤)

update_hyper_parameters(params_dict: Dict[str, float]) None[source]

從外部控制流設定超引數。

引數:

params_dict (Dict[str, float]) – 包含超引數名稱及其值的字典

其他 API

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取適合初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源