快捷方式

选择键

class torchrl.trainers.SelectKeys(keys: Sequence[str])[source]

选择 TensorDict 批次中的键。

参数:

keys (字符串的可迭代对象) – 要在 tensordict 中选择的键。

示例

>>> trainer = make_trainer()
>>> key1 = "first key"
>>> key2 = "second key"
>>> td = TensorDict(
...     {
...         key1: torch.randn(3),
...         key2: torch.randn(3),
...     },
...     [],
... )
>>> trainer.register_op("batch_process", SelectKeys([key1]))
>>> td_out = trainer._process_batch_hook(td)
>>> assert key1 in td_out.keys()
>>> assert key2 not in td_out.keys()
register(trainer, name='select_keys') None[source]

在默认位置的训练器中注册钩子。

参数:
  • trainer (Trainer) – 必须注册钩子的训练器。

  • name (str) – 钩子的名称。

注意

要将钩子注册到默认位置以外的位置,请使用 register_op()

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源