表批處理嵌入 (TBE) 推理模組¶
穩定版 API¶
- class fbgemm_gpu.split_table_batched_embeddings_ops_inference.IntNBitTableBatchedEmbeddingBagsCodegen(embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]], feature_table_map: List[int] | None = None, index_remapping: List[Tensor] | None = None, pooling_mode: PoolingMode = PoolingMode.SUM, device: str | device | int | None = None, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, weight_lists: List[Tuple[Tensor, Tensor | None]] | None = None, pruning_hash_load_factor: float = 0.5, use_array_for_index_remapping: bool = True, output_dtype: SparseType = SparseType.FP16, cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, cache_load_factor: float = 0.2, cache_sets: int = 0, cache_reserved_memory: float = 0.0, enforce_hbm: bool = False, record_cache_metrics: RecordCacheMetrics | None = None, gather_uvm_cache_stats: bool | None = False, row_alignment: int | None = None, fp8_exponent_bits: int | None = None, fp8_exponent_bias: int | None = None, cache_assoc: int = 32, scale_bias_size_in_bytes: int = 4, cacheline_alignment: bool = True, uvm_host_mapped: bool = False, reverse_qparam: bool = False, feature_names_per_table: List[List[str]] | None = None, indices_dtype: dtype = torch.int32)[source]¶
nn.EmbeddingBag(sparse=False) 的推理版本,採用表批處理,支援 FP32/FP16/FP8/INT8/INT4/INT2 權重
- 引數:
embedding_specs (List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]) –
嵌入規格列表。每個規格描述一個物理嵌入表的詳細資訊。每個規格是一個元組,包含嵌入行的數量、嵌入維度(必須是 4 的倍數)、表位置(EmbeddingLocation)和計算裝置(ComputeDevice)。
可用的 EmbeddingLocation 選項包括
DEVICE = 放置嵌入表到 GPU 全域性記憶體 (HBM)
MANAGED = 放置嵌入到統一虛擬記憶體(可從 GPU 和 CPU 訪問)
MANAGED_CACHING = 放置嵌入表到統一虛擬記憶體,並使用 GPU 全域性記憶體 (HBM) 作為快取
HOST = 放置嵌入表到 CPU 記憶體 (DRAM)
MTIA = 放置嵌入表到 MTIA 記憶體
可用的 ComputeDevice 選項包括
CPU = 在 CPU 上執行表查詢
CUDA = 在 GPU 上執行表查詢
MTIA = 在 MTIA 上執行表查詢
feature_table_map (Optional[List[int]] = None) – 一個可選列表,指定特徵與表的對映。feature_table_map[i] 指示特徵 i 對映到的物理嵌入表。
index_remapping (Optional[List[Tensor]] = None) – 用於剪枝的索引重對映
pooling_mode (PoolingMode = PoolingMode.SUM) –
池化模式。可用的 PoolingMode 選項包括
SUM = 求和池化
MEAN = 平均池化
NONE = 無池化(序列嵌入)
device (Optional[Union[str, int, torch.device]] = None) – 放置張量的當前裝置
bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING) –
輸入檢查模式。可用的 BoundsCheckMode 選項包括
NONE = 跳過邊界檢查
FATAL = 當遇到無效索引/偏移時丟擲錯誤
WARNING = 當遇到無效索引/偏移時列印警告訊息並修正(將無效索引設為零,並將無效偏移調整到邊界內)
IGNORE = 靜默修正無效索引/偏移(將無效索引設為零,並將無效偏移調整到邊界內)
weight_lists (Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None) – [T]
pruning_hash_load_factor (float = 0.5) – 用於剪枝雜湊的負載因子
use_array_for_index_remapping (bool = True) – 如果為 True,則使用陣列進行索引重對映。否則,使用雜湊表。
output_dtype (SparseType = SparseType.FP16) – 輸出張量的資料型別。
cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU) –
快取演算法(當 EmbeddingLocation 設定為 MANAGED_CACHING 時使用)。選項包括
LRU = 最近最少使用
LFU = 最不經常使用
cache_load_factor (float = 0.2) – 用於確定快取容量的因子(當使用 EmbeddingLocation.MANAGED_CACHING 時)。快取容量為 cache_load_factor * 所有嵌入表中的總行數
cache_sets (int = 0) – 快取集的數量(當 EmbeddingLocation 設定為 MANAGED_CACHING 時使用)
cache_reserved_memory (float = 0.0) – 在 HBM 中為非快取目的保留的記憶體量(當 EmbeddingLocation 設定為 MANAGED_CACHING 時使用)。
enforce_hbm (bool = False) – 如果為 True,在使用 EmbeddingLocation.MANAGED_CACHING 時將所有權重/動量放置在 HBM 中
record_cache_metrics (Optional[RecordCacheMetrics] = None) – 如果 RecordCacheMetrics.record_cache_miss_counter 為 True,則記錄命中次數、請求次數等;如果 RecordCacheMetrics.record_tablewise_cache_miss 為 True,則按表記錄類似指標
gather_uvm_cache_stats (Optional[bool] = False) – 如果為 True,當 EmbeddingLocation 設定為 MANAGED_CACHING 時收集快取統計資訊
row_alignment (Optional[int] = None) – 行對齊
fp8_exponent_bits (Optional[int] = None) – 使用 FP8 時的指數位數
fp8_exponent_bias (Optional[int] = None) – 使用 FP8 時的指數偏差
cache_assoc (int = 32) – 快取的路數
scale_bias_size_in_bytes (int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES) – 縮放和偏差的大小(位元組)
cacheline_alignment (bool = True) – 如果為 True,將每個表對齊到 128b 快取行邊界
uvm_host_mapped (bool = False) – 如果為 True,使用 malloc + cudaHostRegister 分配每個 UVM 張量。否則使用 cudaMallocManaged
reverse_qparam (bool = False) – 如果為 True,在每行末尾載入 qparams。否則,在每行開頭載入 qparams。
feature_names_per_table (Optional[List[List[str]]] = None) – 一個可選列表,指定每張表的特徵名稱。feature_names_per_table[t] 指示表 t 的特徵名稱。
indices_dtype (torch.dtype = torch.int32) – 預期傳遞給 forward() 呼叫的索引張量的資料型別。此資訊將用於構建 remap_indices 陣列/雜湊表。選項包括 torch.int32 和 torch.int64。
- assign_embedding_weights(q_weight_list: List[Tuple[Tensor, Tensor | None]]) None[source]¶
將 self.split_embedding_weights() 分配為輸入權重和 scale_shifts 列表中的值。
- forward(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor | None = None) Tensor[source]¶
定義每次呼叫時執行的計算。
應被所有子類重寫。
注意
儘管前向傳播(forward pass)的實現需要在此函式中定義,但之後應呼叫
Module例項而非此函式本身,因為前者會負責執行已註冊的鉤子(hooks),而後者會靜默忽略它們。
- recompute_module_buffers() None[source]¶
計算位於元裝置上且尚未在 reset_weights_placements_and_offsets() 中例項化的模組緩衝區。當前這些緩衝區包括 weights_tys、rows_per_table、D_offsets 和 bounds_check_warning。當前不計算剪枝相關或 uvm 相關的緩衝區。