选择键¶
- class torchrl.trainers.SelectKeys(keys: Sequence[str])[source]¶
选择 TensorDict 批次中的键。
- 参数:
keys (字符串的可迭代对象) – 要在 tensordict 中选择的键。
示例
>>> trainer = make_trainer() >>> key1 = "first key" >>> key2 = "second key" >>> td = TensorDict( ... { ... key1: torch.randn(3), ... key2: torch.randn(3), ... }, ... [], ... ) >>> trainer.register_op("batch_process", SelectKeys([key1])) >>> td_out = trainer._process_batch_hook(td) >>> assert key1 in td_out.keys() >>> assert key2 not in td_out.keys()