• 文件 >
  • 表批處理嵌入 (TBE) 推理模組
快捷方式

表批處理嵌入 (TBE) 推理模組

穩定版 API

class fbgemm_gpu.split_table_batched_embeddings_ops_inference.IntNBitTableBatchedEmbeddingBagsCodegen(embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]], feature_table_map: List[int] | None = None, index_remapping: List[Tensor] | None = None, pooling_mode: PoolingMode = PoolingMode.SUM, device: str | device | int | None = None, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, weight_lists: List[Tuple[Tensor, Tensor | None]] | None = None, pruning_hash_load_factor: float = 0.5, use_array_for_index_remapping: bool = True, output_dtype: SparseType = SparseType.FP16, cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, cache_load_factor: float = 0.2, cache_sets: int = 0, cache_reserved_memory: float = 0.0, enforce_hbm: bool = False, record_cache_metrics: RecordCacheMetrics | None = None, gather_uvm_cache_stats: bool | None = False, row_alignment: int | None = None, fp8_exponent_bits: int | None = None, fp8_exponent_bias: int | None = None, cache_assoc: int = 32, scale_bias_size_in_bytes: int = 4, cacheline_alignment: bool = True, uvm_host_mapped: bool = False, reverse_qparam: bool = False, feature_names_per_table: List[List[str]] | None = None, indices_dtype: dtype = torch.int32)[source]

nn.EmbeddingBag(sparse=False) 的推理版本,採用表批處理,支援 FP32/FP16/FP8/INT8/INT4/INT2 權重

引數:
  • embedding_specs (List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]) –

    嵌入規格列表。每個規格描述一個物理嵌入表的詳細資訊。每個規格是一個元組,包含嵌入行的數量、嵌入維度(必須是 4 的倍數)、表位置(EmbeddingLocation)和計算裝置(ComputeDevice)。

    可用的 EmbeddingLocation 選項包括

    1. DEVICE = 放置嵌入表到 GPU 全域性記憶體 (HBM)

    2. MANAGED = 放置嵌入到統一虛擬記憶體(可從 GPU 和 CPU 訪問)

    3. MANAGED_CACHING = 放置嵌入表到統一虛擬記憶體,並使用 GPU 全域性記憶體 (HBM) 作為快取

    4. HOST = 放置嵌入表到 CPU 記憶體 (DRAM)

    5. MTIA = 放置嵌入表到 MTIA 記憶體

    可用的 ComputeDevice 選項包括

    1. CPU = 在 CPU 上執行表查詢

    2. CUDA = 在 GPU 上執行表查詢

    3. MTIA = 在 MTIA 上執行表查詢

  • feature_table_map (Optional[List[int]] = None) – 一個可選列表,指定特徵與表的對映。feature_table_map[i] 指示特徵 i 對映到的物理嵌入表。

  • index_remapping (Optional[List[Tensor]] = None) – 用於剪枝的索引重對映

  • pooling_mode (PoolingMode = PoolingMode.SUM) –

    池化模式。可用的 PoolingMode 選項包括

    1. SUM = 求和池化

    2. MEAN = 平均池化

    3. NONE = 無池化(序列嵌入)

  • device (Optional[Union[str, int, torch.device]] = None) – 放置張量的當前裝置

  • bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING) –

    輸入檢查模式。可用的 BoundsCheckMode 選項包括

    1. NONE = 跳過邊界檢查

    2. FATAL = 當遇到無效索引/偏移時丟擲錯誤

    3. WARNING = 當遇到無效索引/偏移時列印警告訊息並修正(將無效索引設為零,並將無效偏移調整到邊界內)

    4. IGNORE = 靜默修正無效索引/偏移(將無效索引設為零,並將無效偏移調整到邊界內)

  • weight_lists (Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None) – [T]

  • pruning_hash_load_factor (float = 0.5) – 用於剪枝雜湊的負載因子

  • use_array_for_index_remapping (bool = True) – 如果為 True,則使用陣列進行索引重對映。否則,使用雜湊表。

  • output_dtype (SparseType = SparseType.FP16) – 輸出張量的資料型別。

  • cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU) –

    快取演算法(當 EmbeddingLocation 設定為 MANAGED_CACHING 時使用)。選項包括

    1. LRU = 最近最少使用

    2. LFU = 最不經常使用

  • cache_load_factor (float = 0.2) – 用於確定快取容量的因子(當使用 EmbeddingLocation.MANAGED_CACHING 時)。快取容量為 cache_load_factor * 所有嵌入表中的總行數

  • cache_sets (int = 0) – 快取集的數量(當 EmbeddingLocation 設定為 MANAGED_CACHING 時使用)

  • cache_reserved_memory (float = 0.0) – 在 HBM 中為非快取目的保留的記憶體量(當 EmbeddingLocation 設定為 MANAGED_CACHING 時使用)。

  • enforce_hbm (bool = False) – 如果為 True,在使用 EmbeddingLocation.MANAGED_CACHING 時將所有權重/動量放置在 HBM 中

  • record_cache_metrics (Optional[RecordCacheMetrics] = None) – 如果 RecordCacheMetrics.record_cache_miss_counter 為 True,則記錄命中次數、請求次數等;如果 RecordCacheMetrics.record_tablewise_cache_miss 為 True,則按表記錄類似指標

  • gather_uvm_cache_stats (Optional[bool] = False) – 如果為 True,當 EmbeddingLocation 設定為 MANAGED_CACHING 時收集快取統計資訊

  • row_alignment (Optional[int] = None) – 行對齊

  • fp8_exponent_bits (Optional[int] = None) – 使用 FP8 時的指數位數

  • fp8_exponent_bias (Optional[int] = None) – 使用 FP8 時的指數偏差

  • cache_assoc (int = 32) – 快取的路數

  • scale_bias_size_in_bytes (int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES) – 縮放和偏差的大小(位元組)

  • cacheline_alignment (bool = True) – 如果為 True,將每個表對齊到 128b 快取行邊界

  • uvm_host_mapped (bool = False) – 如果為 True,使用 malloc + cudaHostRegister 分配每個 UVM 張量。否則使用 cudaMallocManaged

  • reverse_qparam (bool = False) – 如果為 True,在每行末尾載入 qparams。否則,在每行開頭載入 qparams

  • feature_names_per_table (Optional[List[List[str]]] = None) – 一個可選列表,指定每張表的特徵名稱。feature_names_per_table[t] 指示表 t 的特徵名稱。

  • indices_dtype (torch.dtype = torch.int32) – 預期傳遞給 forward() 呼叫的索引張量的資料型別。此資訊將用於構建 remap_indices 陣列/雜湊表。選項包括 torch.int32torch.int64

assign_embedding_weights(q_weight_list: List[Tuple[Tensor, Tensor | None]]) None[source]

將 self.split_embedding_weights() 分配為輸入權重和 scale_shifts 列表中的值。

fill_random_weights() None[source]

按表用隨機權重填充緩衝區

forward(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor | None = None) Tensor[source]

定義每次呼叫時執行的計算。

應被所有子類重寫。

注意

儘管前向傳播(forward pass)的實現需要在此函式中定義,但之後應呼叫 Module 例項而非此函式本身,因為前者會負責執行已註冊的鉤子(hooks),而後者會靜默忽略它們。

recompute_module_buffers() None[source]

計算位於元裝置上且尚未在 reset_weights_placements_and_offsets() 中例項化的模組緩衝區。當前這些緩衝區包括 weights_tysrows_per_tableD_offsetsbounds_check_warning。當前不計算剪枝相關或 uvm 相關的緩衝區。

split_embedding_weights(split_scale_shifts: bool = True) List[Tuple[Tensor, Tensor | None]][source]

返回按表拆分的權重列表。

split_embedding_weights_with_scale_bias(split_scale_bias_mode: int = 1) List[Tuple[Tensor, Tensor | None, Tensor | None]][source]

返回按 split_scale_bias_mode 模式拆分的權重列表。

0: 返回一行;1: 返回 weights + scale_bias;2: 返回 weights, scale, bias。

其他 API

文件

查閱 PyTorch 完整的開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並解答您的疑問

檢視資源