• 文档 >
  • Embedding 运算符
快捷方式

Embedding 运算符

CUDA 运算符

Tensor int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, int64_t total_D, int64_t max_int2_D, int64_t max_int4_D, int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, Tensor indices, Tensor offsets, int64_t pooling_mode, std::optional<Tensor> indice_weights, int64_t output_dtype, std::optional<Tensor> lxu_cache_weights, std::optional<Tensor> lxu_cache_locations, std::optional<int64_t> row_alignment, std::optional<int64_t> max_float8_D, std::optional<int64_t> fp8_exponent_bits, std::optional<int64_t> fp8_exponent_bias)
Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, int64_t total_D, int64_t max_int2_D, int64_t max_int4_D, int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, Tensor indices, Tensor offsets, int64_t pooling_mode, std::optional<Tensor> indice_weights, int64_t output_dtype, std::optional<Tensor> lxu_cache_weights, std::optional<Tensor> lxu_cache_locations, std::optional<int64_t> row_alignment, std::optional<int64_t> max_float8_D, std::optional<int64_t> fp8_exponent_bits, std::optional<int64_t> fp8_exponent_bias, std::optional<Tensor> cache_hash_size_cumsum, std::optional<int64_t> total_cache_hash_size, std::optional<Tensor> cache_index_table_map, std::optional<Tensor> lxu_cache_state, std::optional<Tensor> lxu_state)

与 int_nbit_split_embedding_codegen_lookup_function 类似,但它执行 UVM_CACHING 查找。

Tensor pruned_hashmap_lookup_cuda(Tensor indices, Tensor offsets, Tensor hash_table, Tensor hash_table_offsets)
Tensor pruned_array_lookup_cuda(Tensor indices, Tensor offsets, Tensor index_remappings, Tensor index_remappings_offsets)
void bounds_check_indices_cuda(Tensor &rows_per_table, Tensor &indices, Tensor &offsets, int64_t bounds_check_mode, Tensor &warning, const std::optional<Tensor> &weights, const std::optional<Tensor> &B_offsets, const int64_t max_B, const std::optional<Tensor> &b_t_map, const int64_t info_B_num_bits, const int64_t info_B_mask, const int8_t bounds_check_version)

CPU 运算符

Tensor int_nbit_split_embedding_codegen_lookup_function_cpu(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, int64_t total_D, int64_t max_int2_D, int64_t max_int4_D, int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, Tensor indices, Tensor offsets, int64_t pooling_mode, std::optional<Tensor> indice_weights, int64_t output_dtype, std::optional<Tensor> lxu_cache_weights, std::optional<Tensor> lxu_cache_locations, std::optional<int64_t> row_alignment, std::optional<int64_t> max_float8_D, std::optional<int64_t> fp8_exponent_bits, std::optional<int64_t> fp8_exponent_bias)
Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, int64_t total_D, int64_t max_int2_D, int64_t max_int4_D, int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, Tensor indices, Tensor offsets, int64_t pooling_mode, std::optional<Tensor> indice_weights, int64_t output_dtype, std::optional<Tensor> lxu_cache_weights, std::optional<Tensor> lxu_cache_locations, std::optional<int64_t> row_alignment, std::optional<int64_t> max_float8_D, std::optional<int64_t> fp8_exponent_bits, std::optional<int64_t> fp8_exponent_bias, std::optional<Tensor> cache_hash_size_cumsum, std::optional<int64_t> total_cache_hash_size, std::optional<Tensor> cache_index_table_map, std::optional<Tensor> lxu_cache_state, std::optional<Tensor> lxu_state)
void pruned_hashmap_insert_unweighted_cpu(Tensor indices, Tensor dense_indices, Tensor offsets, Tensor hash_table, Tensor hash_table_offsets)
Tensor pruned_hashmap_lookup_unweighted_cpu(Tensor indices, Tensor offsets, Tensor hash_table, Tensor hash_table_offsets)
Tensor pruned_array_lookup_cpu(Tensor indices, Tensor offsets, Tensor index_remappings, Tensor index_remappings_offsets)

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获取问题的解答

查看资源