RolloutFromModel¶
- class torchrl.data.RolloutFromModel(model, ref_model, reward_model, kl_coef=0.1, max_new_tokens=50, score_clip=10.0, kl_scheduler: KLControllerBase | None = None, num_steps: int | None = None)[source]¶
一个用于执行因果语言模型滚动的类。
假设此类包装的模型以标记化的文本作为输入,其任务是在阅读了 n 个之前的单词后预测句子中的下一个单词。
- 参数:
model (transformers.Transformer) – 要使用的模型。应具有
generate()
方法。ref_model (transformers.Transformer) –
model
的冻结版本,其中参数处于其初始配置中。这用于计算奖励的 KL 惩罚,以防止模型在训练期间偏离参考模型太远。reward_model – (nn.Module, tensordict.nn.TensorDictModule): 一个模型,它在给定
input_ids
和attention_mask
的情况下,计算每个标记的奖励和 end_scores(每个序列中最后一个标记的奖励)。kl_coef – (float, 可选): 初始 kl 系数。
max_new_tokens (int, 可选) – 序列的最大长度。默认为 50。
score_clip (float, 可选) – 奖励模型的得分被剪裁到范围
(-score_clip, score_clip)
。默认为 10。kl_scheduler (KLControllerBase, 可选) – KL 系数调度器。
num_steps (int, 可选) – 两次优化之间的步数。
示例
>>> from tensordict.nn import TensorDictModule >>> from torchrl.modules.models.rlhf import GPT2RewardModel >>> from torchrl.data.rlhf.utils import RolloutFromModel >>> from torchrl.data.rlhf.dataset import get_dataloader >>> from torchrl.data.rlhf.prompt import PromptData >>> from transformers import GPT2LMHeadModel >>> >>> dl = get_dataloader( ... batch_size=4, ... block_size=550, ... tensorclass_type=PromptData, ... device="cpu", ... dataset_name="CarperAI/openai_summarize_tldr", ... ) >>> model = GPT2LMHeadModel.from_pretrained("gpt2") >>> # we load ref_model with random weights so it differs from model >>> ref_model = GPT2LMHeadModel(GPT2LMHeadModel.config_class()) >>> reward_model = GPT2RewardModel(model_path="gpt2") >>> rollout_from_model = RolloutFromModel(model, ref_model, reward_model) >>> >>> batch = next(dl) >>> rollout = rollout_from_model.rollout_from_data(batch) >>> rollout TensorDict( fields={ action: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False), attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False), input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False), next: TensorDict( fields={ attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False), done: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.bool, is_shared=False), input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False), reward: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False), reward_kl: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False), reward_raw: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4, 50]), device=cpu, is_shared=False), sample_log_prob: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4, 50]), device=cpu, is_shared=False)
- create_rollout_td(batch, generated, log_probs, log_ratio)[source]¶
生成的数据的 TensorDict 包装器。
此函数获取一个批次以及生成的标记,并复制从 TorchRL 环境中采样每个时间步长一个标记的滚动中获得的 tensordict 结构。
- 参数:
batch (TensorDict) – 包含原始提示以及指示提示右侧索引的字段“rindex”的数据批次。
generated (torch.Tensor) – 标记化的提示,后跟生成的标记。这可以通过调用
generate
方法获得。log_probs (torch.Tensor) – 生成的标记的对数概率。可以通过调用
generate
方法获得。log_ratio (torch.Tensor) – 根据生成模型和参考模型对生成的标记概率的对数比率。可以通过调用
generate
方法获得。
- 返回值:
"action"
: 操作序列(生成的标记)"input_ids"
: 在每个时间步长传递给生成模型的 input_ids。"attention_mask"
: 在每个时间步长传递给生成模型的 attention_masks"sample_log_prob"
: 生成期间每个标记的对数概率("next", "input_ids")
: 生成后的标记序列。构成生成下一个标记时将使用的输入的一部分。("next", "attention_mask")
: 标记生成后更新的 attention_mask。在下一个时间步长传递给生成模型("next", "terminated")
: 布尔数组,指示我们是否已到达终止状态(因为我们生成了 EOS 标记或因为我们到达了标记限制)("next", "done")
: 布尔数组,指示我们是否已到达最终状态。当前是"terminated"
的副本。("next", "reward")
: 在每个时间步长收到的奖励("next", "reward_raw")
: 来自奖励模型的原始奖励,不含 KL 项。这主要用于调试和日志记录,不会用于训练("next", "reward_kl")
: 来自奖励的 KL 项。这主要用于调试和日志记录,不会用于训练。
- 返回类型:
一个具有以下键的
TensorDict
- generate(batch: PromptData, generation_config=None)[source]¶
从从数据收集器采样的数据批次中生成一系列标记。
- 参数:
batch (PromptData) – 用于训练的数据。必须具有
input_ids
和prompt_rindex
字段。generation_config (GenerationConfig, optional) – 生成调用配置。
- 返回值:
- 一个 [B x (Ti +To)] 整数(token)序列,
其中 Ti 是输入序列的长度,To 是生成序列的长度。
log_probs_gen: 生成token的对数概率。log_ratio: 生成模型与冻结版本之间的概率的对数比率
模型。
- 返回类型:
generated (torch.Tensor)