• 文档 >
  • SSD Embedding 算子
快捷方式

SSD Embedding 算子

CUDA 算子

enum RocksdbWriteMode

rocksdb 写入模式

在 SSD 卸载中,每个训练迭代有 3 种写入方式:FWD_ROCKSDB_READ:缓存查找会将 rocksdb 中的未缓存数据移动到前向路径上的 L2 缓存中

FWD_L1_EVICTION:L1 缓存驱逐会将数据驱逐到前向路径上的 L2 缓存中

BWD_L1_CNFLCT_MISS_WRITE_BACK:L1 冲突未命中会将数据插入到 L2 中以进行后向路径上的嵌入更新

以上所有的 L2 缓存填充在 L2 缓存满时都可能触发 rocksdb 写入

此外,我们将在 L2 刷新时执行 ssd IO

enumerator FWD_ROCKSDB_READ
enumerator FWD_L1_EVICTION
enumerator BWD_L1_CNFLCT_MISS_WRITE_BACK
enumerator FLUSH
inline size_t hash_shard(int64_t id, size_t num_shards)

用于 SSD L2 缓存和 rocksdb 分片算法的哈希函数

参数:
  • id – 分片键

  • num_shards – 分片范围

返回值:

分片 ID 范围为 [0, num_shards)

std::tuple<at::Tensor, at::Tensor> get_bucket_sorted_indices_and_bucket_tensor(const at::Tensor &unordered_indices, int64_t hash_mode, int64_t bucket_start, int64_t bucket_end, std::optional<int64_t> bucket_size, std::optional<int64_t> total_num_buckets)

给定一个包含随机顺序 id 的张量,返回 2 个张量。张量 1 包含按桶升序排序的 id,例如给定 [1,2,3,4] 和 2 个桶 [1, 4) 和 [4, 7),输出将是 [1,2,3,4] 或 [2, 1, 3, 4],id 1, 2, 3 必须在 4 之前,但 1 2 3 可以按任意顺序排列。张量 2 包含每个桶 ID(张量偏移)中的嵌入数量,在上面的示例中,张量 2 将是 [3, 1],其中第一个项对应于第一个桶 ID,值 3 表示第一个桶 ID 中有 3 个 id

参数:
  • unordered_indices – 无序 id,此处的 id 可能是原始(非线性化)id

  • hash_mode – 0 表示按模哈希,1 表示按交织哈希

  • bucket_start – 全局桶 ID,桶范围的起始

  • bucket_end – 全局桶 ID,桶范围的结束

  • bucket_size – 可选的桶的虚拟大小(输入空间,例如 2^50)

  • total_num_buckets – 可选的,每个训练模型的总桶数

返回值:

包含 2 个张量的列表,第一个张量是按桶排序的 id,第二个张量是桶大小

void cuda_callback_func(cudaStream_t stream, cudaError_t status, void *functor)

cudaStreamAddCallback 的回调函数

cudaStreamAddCallback 的一个通用回调函数,即 cudaStreamCallback_t callback。此函数将 functor 转换为 void 函数,调用它然后删除它(删除发生在另一个线程中)

参数:
  • streamcudaStreamAddCallback 操作的 CUDA 流

  • status – CUDA 状态

  • functor – 将被调用的函数对象

返回值:

Tensor masked_index_put_cuda(Tensor self, Tensor indices, Tensor values, Tensor count, const bool use_pipeline, const int64_t preferred_sms)

类似于 torch.Tensor.index_put,但忽略 indices < 0

masked_index_put_cuda 仅支持 2D 输入 values。它使用 indices 中 >= 0 的行索引,将 values 中的 count 行放入 self 中。

# Equivalent PyTorch Python code
indices = indices[:count]
filter_ = indices >= 0
indices_ = indices[filter_]
self[indices_] = values[filter_.nonzero().flatten()]
参数:
  • self – 2D 输出张量(被索引的张量)

  • indices – 1D 索引张量

  • values – 2D 输入张量

  • count – 包含要处理的 indices 长度的张量

  • use_pipeline – 一个标志,指示此核函数将与其他核函数重叠。如果为 true,则使用一部分 SM 以减少资源竞争

  • preferred_sms – 当 use_pipeline=true 时,核函数首选使用的 SM 数量。当 use_pipeline=false 时,此值被忽略。

返回值:

self 张量

Tensor masked_index_select_cuda(Tensor self, Tensor indices, Tensor values, Tensor count, const bool use_pipeline, const int64_t preferred_sms)

类似于 torch.index_select,但忽略 indices < 0

masked_index_select_cuda 仅支持 2D 输入 values。它将 values 中由 indices(其中 indices >= 0)指定的 count 行放入 self 中。

# Equivalent PyTorch Python code
indices = indices[:count]
filter_ = indices >= 0
indices_ = indices[filter_]
self[filter_.nonzero().flatten()] = values[indices_]
参数:
  • self – 2D 输出张量

  • indices – 1D 索引张量

  • values – 2D 输入张量(被索引的张量)

  • count – 包含要处理的 indices 长度的张量

  • use_pipeline – 一个标志,指示此核函数将与其他核函数重叠。如果为 true,则使用一部分 SM 以减少资源竞争

  • preferred_sms – 当 use_pipeline=true 时,核函数首选使用的 SM 数量。当 use_pipeline=false 时,此值被忽略。

返回值:

self 张量

std::tuple<Tensor, Tensor> ssd_generate_row_addrs_cuda(const Tensor &lxu_cache_locations, const Tensor &assigned_cache_slots, const Tensor &linear_index_inverse_indices, const Tensor &unique_indices_count_cumsum, const Tensor &cache_set_inverse_indices, const Tensor &lxu_cache_weights, const Tensor &inserted_ssd_weights, const Tensor &unique_indices_length, const Tensor &cache_set_sorted_unique_indices)

为 SSD TBE 数据生成内存地址。

从 SSD 检索的数据可以存储在暂存区 (HBM) 或 LXU 缓存 (同样是 HBM) 中。lxu_cache_locations 用于指定数据的位置。如果位置为 -1,则关联索引的数据在暂存区中;否则,它在缓存中。为了方便 TBE 核函数访问数据,此算子为每个索引生成首字节的内存地址。访问数据时,TBE 核函数只需将地址转换为指针。

此外,此算子还会生成后向驱逐索引的列表,这些索引的数据基本上位于暂存区中。

参数:
  • lxu_cache_locations – 包含用于存储完整索引列表数据的缓存槽位的张量。-1 是一个指示数据不在缓存中的哨兵值。

  • assigned_cache_slots – 包含用于唯一索引列表的缓存槽位的张量。-1 指示数据不在缓存中

  • linear_index_inverse_indices – 包含线性索引排序前原始位置的张量

  • unique_indices_count_cumsum – 包含唯一索引计数(count)的排他前缀和结果的张量

  • cache_set_inverse_indices_curr – 包含当前迭代中缓存集排序前原始位置的张量

  • lxu_cache_weights – LXU 缓存张量

  • inserted_ssd_weights – 暂存区张量

  • unique_indices_length – 包含唯一索引数量的张量(GPU 张量)

  • cache_set_sorted_unique_indices – 包含与排序后的唯一缓存集关联的唯一索引的张量

返回值:

一个张量元组(SSD 行地址张量和后向驱逐索引张量)

void ssd_update_row_addrs_cuda(const Tensor &ssd_row_addrs_curr, const Tensor &inserted_ssd_weights_curr_next_map, const Tensor &lxu_cache_locations_curr, const Tensor &linear_index_inverse_indices_curr, const Tensor &unique_indices_count_cumsum_curr, const Tensor &cache_set_inverse_indices_curr, const Tensor &lxu_cache_weights, const Tensor &inserted_ssd_weights_next, const Tensor &unique_indices_length_curr)

更新 SSD TBE 数据的内存地址。

启用管道预取时,当前迭代暂存区中的数据可以在预取步骤期间移动到 L1 或下一迭代的暂存区。此算子更新已重定位到正确位置的数据的内存地址。

参数:
  • ssd_row_addrs_curr – 包含当前迭代行地址的张量

  • inserted_ssd_weights_curr_next_map – 包含当前迭代中每个索引在下一迭代暂存区中的位置映射的张量。(-1 = 数据尚未移动)。inserted_ssd_weights_curr_next_map[i] 即为该位置

  • lxu_cache_locations_curr – 包含用于存储当前迭代的完整索引列表数据的缓存槽位的张量。-1 是一个指示数据不在缓存中的哨兵值。

  • linear_index_inverse_indices_curr – 包含当前迭代中线性索引排序前原始位置的张量

  • unique_indices_count_cumsum_curr – 包含当前迭代中唯一索引计数(count)的排他前缀和结果的张量

  • cache_set_inverse_indices_curr – 包含当前迭代中缓存集排序前原始位置的张量

  • lxu_cache_weights – LXU 缓存张量

  • inserted_ssd_weights_next – 下一迭代的暂存区张量

  • unique_indices_length_curr – 包含当前迭代唯一索引数量的张量(GPU 张量)

返回值:

void compact_indices_cuda(std::vector<Tensor> compact_indices, Tensor compact_count, std::vector<Tensor> indices, Tensor masks, Tensor count)

压缩给定的索引列表。

此算子根据给定的掩码(一个包含 0 或 1 的张量)压缩给定的索引列表。该算子移除对应掩码为 0 的索引。它只对 count 个元素进行操作(而非整个张量)。

示例

indices = [[0, 3, -1, 3, -1, -1, 7], [0, 2, 2, 3, -1, 9, 7]]
masks = [1, 1, 0, 1, 0, 0, 1]
count = 5

# x represents an arbitrary value
compact_indices = [[0, 3, 3, x, x, x, x], [0, 2, 3, x, x, x, x]]
compact_count = 3
参数:
  • compact_indices – 压缩索引的列表(输出索引)。

  • compact_count – 一个 tensor,包含压缩后的元素数量

  • indices – 要压缩的索引输入列表

  • masks – 一个 tensor,包含 0 或 1,用于指示是否删除/保留元素。0 = 移除对应的索引。1 = 保留对应的索引。@count count 一个 tensor,包含要压缩的元素数量

class CacheLibCache
#include <cachelib_cache.h>

一个用于 Cachelib 交互的 Cachelib 包装类。

它用于维护所有与缓存相关的操作,包括初始化、插入、查找和逐出。它在逐出逻辑方面是状态化的,调用者必须专门获取和重置与逐出相关的状态。Cachelib 相关的优化将被捕获在此类中,例如 fetch 和延迟 markUseful 以提高 get 性能

注意

此类仅处理单个 Cachelib 读取/更新。并行化在调用者端完成

class EmbeddingParameterServer : public EmbeddingKVDB
#include <ps_table_batched_embeddings.h>

EmbeddingKVDB 为训练参数服务 (TPS) 客户端实现的一个类。

class CacheContext
#include <kv_db_table_batched_embeddings.h>

它保存 l2cache 查找结果。

num_misses 是 l2 缓存查找中的未命中数量,cached_addr_list 是预分配的,其大小与查找次数相同,以实现更好的并行性,并且无效位置(缓存未命中)将保留 sentinel 值

struct QueueItem
#include <kv_db_table_batched_embeddings.h>

用于后台 L2/rocksdb 更新的队列项

indices/weights/count 是相应的 set() 参数

read_handles 是 cachelib 抽象的索引/嵌入对元数据,稍后将在更新 cachelib LRU 队列时使用,因为它与 EmbeddingKVDB::get_cache() 分离

mode 用于监控 rocksdb 写入,详细解释请查看 RocksdbWriteMode

class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB>
#include <kv_db_table_batched_embeddings.h>

一个用于与不同缓存层和存储层交互的类,公共调用在 cuda stream 上执行。

目前它被 TBE 用于将 Key(Embedding Index)Value(Embeddings)卸载到 DRAM、SSD 或远程存储,以在不耗尽 HBM 资源的情况下提供更好的可扩展性

继承自 DramKVEmbeddingCache< weight_type >, EmbeddingParameterServer, EmbeddingRocksDB

class EmbeddingRocksDB : public EmbeddingKVDB
#include <ssd_table_batched_embeddings.h>

EmbeddingKVDB 为 RocksDB 实现的一个类。

继承自 MockEmbeddingRocksDB

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源