• 文件 >
  • 表批處理嵌入運算子
快捷方式

表批處理嵌入運算子

std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>> get_unique_indices_cuda(const at::Tensor &linear_indices, const int64_t max_indices, const bool compute_count)

索引去重。

std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>> get_unique_indices_with_inverse_cuda(const at::Tensor &linear_indices, const int64_t max_indices, const bool compute_count, const bool compute_inverse_indices)

索引去重。

std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>> lru_cache_find_uncached_cuda(at::Tensor unique_indices, at::Tensor unique_indices_length, int64_t max_indices, at::Tensor lxu_cache_state, int64_t time_stamp, at::Tensor lru_state, bool gather_cache_stats, at::Tensor uvm_cache_stats, bool lock_cache_line, at::Tensor lxu_cache_locking_counter, const bool compute_inverse_indices)

查詢 LRU 快取以找出未快取的索引,然後依快取集(cache set)對它們進行排序。

int64_t host_lxu_cache_slot(int64_t h_in, int64_t C)

將索引對映到 cache_set。h_in:線性索引;C:快取集數量。

at::Tensor linearize_cache_indices_cuda(const at::Tensor &cache_hash_size_cumsum, const at::Tensor &indices, const at::Tensor &offsets, const std::optional<at::Tensor> &B_offsets, const int64_t max_B, const int64_t indices_base_offset)

將所有表的索引線性化以使其唯一。

at::Tensor linearize_cache_indices_from_row_idx_cuda(at::Tensor cache_hash_size_cumsum, at::Tensor update_table_indices, at::Tensor update_row_indices)

將所有表的索引線性化以使其唯一。注意 update_table_indices 和 update_row_indices 來自用於原地更新的行索引格式。

at::Tensor direct_mapped_lxu_cache_lookup_cuda(at::Tensor linear_cache_indices, at::Tensor lxu_cache_state, int64_t invalid_index, bool gather_cache_stats, std::optional<at::Tensor> uvm_cache_stats)

// LRU 快取:從 weights 中獲取與 linear_cache_indices 對應的行,並在時間步 time_stamp 將它們插入快取中。

void lru_cache_populate_cuda(

at::Tensor weights,

at::Tensor hash_size_cumsum,

int64_t total_cache_hash_size,

at::Tensor cache_index_table_map,

at::Tensor weights_offsets,

at::Tensor D_offsets,

at::Tensor linear_cache_indices,

at::Tensor lxu_cache_state,

at::Tensor lxu_cache_weights,

int64_t time_stamp,

at::Tensor lru_state,

bool stochastic_rounding,

bool gather_cache_stats,

std::optional<at::Tensor> uvm_cache_stats,

bool lock_cache_line,

std::optional<at::Tensor> lxu_cache_locking_counter);

// LRU 快取:從 weights 中獲取與 linear_cache_indices 對應的行,並在時間步 time_stamp 將它們插入快取中。

// weights 和 lxu_cache_weights 的元素型別為 "uint8_t"(位元組)。
void lru_cache_populate_byte_cuda(

at::Tensor weights,

at::Tensor hash_size_cumsum,

int64_t total_cache_hash_size,

at::Tensor cache_index_table_map,

at::Tensor weights_offsets,

at::Tensor weights_tys,

at::Tensor D_offsets,

at::Tensor linear_cache_indices,

at::Tensor lxu_cache_state,

at::Tensor lxu_cache_weights,

int64_t time_stamp,

at::Tensor lru_state,

int64_t row_alignment,

bool gather_cache_stats,

std::optional<at::Tensor> uvm_cache_stats);

// lru_cache_populate_byte_cuda 的直接映射(direct-mapped,assoc=1)版本。
void direct_mapped_lru_cache_populate_byte_cuda(

at::Tensor weights,

at::Tensor hash_size_cumsum,

int64_t total_cache_hash_size,

at::Tensor cache_index_table_map,

at::Tensor weights_offsets,

at::Tensor weights_tys,

at::Tensor D_offsets,

at::Tensor linear_cache_indices,

at::Tensor lxu_cache_state,

at::Tensor lxu_cache_weights,

int64_t time_stamp,

at::Tensor lru_state,

at::Tensor lxu_cache_miss_timestamp,

int64_t row_alignment,

bool gather_cache_stats,

std::optional<at::Tensor> uvm_cache_stats);

// LFU 快取:從 weights 中獲取與 linear_cache_indices 對應的行,並將它們插入快取中。

void lfu_cache_populate_cuda(

at::Tensor weights,

at::Tensor cache_hash_size_cumsum,

int64_t total_cache_hash_size,

at::Tensor cache_index_table_map,

at::Tensor weights_offsets,

at::Tensor D_offsets,

at::Tensor linear_cache_indices,

at::Tensor lxu_cache_state,

at::Tensor lxu_cache_weights,

at::Tensor lfu_state,

bool stochastic_rounding);

// LFU 快取:從 weights 中獲取與 linear_cache_indices 對應的行,並將它們插入快取中。
// weights 和 lxu_cache_weights 的元素型別為 "uint8_t"(位元組)。

void lfu_cache_populate_byte_cuda(

at::Tensor weights,

at::Tensor cache_hash_size_cumsum,

int64_t total_cache_hash_size,

at::Tensor cache_index_table_map,

at::Tensor weights_offsets,

at::Tensor weights_tys,

at::Tensor D_offsets,

at::Tensor linear_cache_indices,

at::Tensor lxu_cache_state,

at::Tensor lxu_cache_weights,

at::Tensor lfu_state,

int64_t row_alignment);

// 查詢 LRU/LFU 快取:找出所有索引對應的快取權重位置。
// 在快取中查找與 linear_cache_indices 對應的槽位,未命中項則以哨兵值(sentinel)表示。

at::Tensor lxu_cache_lookup_cuda(

at::Tensor linear_cache_indices,

at::Tensor lxu_cache_state,

int64_t invalid_index,

bool gather_cache_stats,

std::optional<at::Tensor> uvm_cache_stats,

std::optional<at::Tensor> num_uniq_cache_indices,

std::optional<at::Tensor> lxu_cache_locations_output);

at::Tensor emulate_cache_miss(

at::Tensor lxu_cache_locations,

const int64_t enforced_misses_per_256,

const bool gather_cache_stats,

at::Tensor uvm_cache_stats);

// 查詢 LRU/LFU 快取:找出所有索引對應的快取權重位置。
// 在快取中查找與 linear_cache_indices 對應的槽位,未命中項則以哨兵值(sentinel)表示。

void lxu_cache_flush_cuda(at::Tensor uvm_weights, at::Tensor cache_hash_size_cumsum, at::Tensor cache_index_table_map, at::Tensor weights_offsets, at::Tensor D_offsets, int64_t total_D, at::Tensor lxu_cache_state, at::Tensor lxu_cache_weights, bool stochastic_rounding)

回寫(flush)快取:將快取中的權重儲存回後備儲存中。

void reset_weight_momentum_cuda(at::Tensor dev_weights, at::Tensor uvm_weights, at::Tensor lxu_cache_weights, at::Tensor weights_placements, at::Tensor weights_offsets, at::Tensor momentum1_dev, at::Tensor momentum1_uvm, at::Tensor momentum1_placements, at::Tensor momentum1_offsets, at::Tensor D_offsets, at::Tensor pruned_indices, at::Tensor pruned_indices_offsets, at::Tensor logical_table_ids, at::Tensor buffer_ids, at::Tensor cache_hash_size_cumsum, at::Tensor lxu_cache_state, int64_t total_cache_hash_size)

void lxu_cache_locking_counter_decrement_cuda(at::Tensor lxu_cache_locking_counter, at::Tensor lxu_cache_locations)

根據 lxu_cache_locations 減少 LRU/LFU 快取計數器。

void lxu_cache_locations_update_cuda(at::Tensor lxu_cache_locations, at::Tensor lxu_cache_locations_new, std::optional<at::Tensor> num_uniq_cache_indices)

原地更新 lxu_cache_locations 為新值,僅當 lxu_cache_locations[i] == -1 且 lxu_cache_locations_new[i] >= 0 時進行更新。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

獲取適合初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並解答疑問

檢視資源