快捷方式

ConsistentDropoutModule

class torchrl.modules.ConsistentDropoutModule(*args, **kwargs)[source]

ConsistentDropout 的 TensorDictModule 包装器。

参数:
  • p (float, optional) – Dropout 概率。默认值: 0.5

  • in_keys (NestedKeyNestedKeys 列表) – 要从输入 tensordict 中读取并传递给此模块的键。

  • out_keys (NestedKeyNestedKeys 的可迭代对象) – 要写入输入 tensordict 的键。默认为 in_keys 值。

关键字参数:
  • input_shape (tuple, optional) – 输入的形状(非批处理),用于使用 make_tensordict_primer() 生成 tensordict 引子。

  • input_dtype (torch.dtype, optional) – 引子的输入 dtype。如果未传递,则假定为 torch.get_default_dtype

注意

要在策略中使用此类,需要mask在重置时重置。这可以通过 TensorDictPrimer 变换来实现,该变换可以使用 make_tensordict_primer() 获得。有关更多信息,请参阅此方法。

示例

>>> from tensordict import TensorDict
>>> module = ConsistentDropoutModule(p = 0.1)
>>> td = TensorDict({"x": torch.randn(3, 4)}, [3])
>>> module(td)
TensorDict(
    fields={
        mask_6127171760: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.bool, is_shared=False),
        x: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
forward(tensordict)[source]

定义每次调用时执行的计算。

应由所有子类覆盖。

注意

尽管前向传递的配方需要在该函数内定义,但应在此之后调用 Module 实例而不是此函数,因为前者负责运行注册的钩子,而后者则静默地忽略它们。

make_tensordict_primer()[source]

为环境创建一个 tensordict 引子,以便在重置调用期间生成随机 mask。

另请参阅

torchrl.modules.utils.get_primers_from_module(),用于生成给定模块的所有引子的方法。

模块。

示例

>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> from torchrl.envs import GymEnv, StepCounter, SerialEnv
>>> m = Seq(
...     Mod(torch.nn.Linear(7, 4), in_keys=["observation"], out_keys=["intermediate"]),
...     ConsistentDropoutModule(
...         p=0.5,
...         input_shape=(2, 4),
...         in_keys="intermediate",
...     ),
...     Mod(torch.nn.Linear(4, 7), in_keys=["intermediate"], out_keys=["action"]),
... )
>>> primer = get_primers_from_module(m)
>>> env0 = GymEnv("Pendulum-v1").append_transform(StepCounter(5))
>>> env1 = GymEnv("Pendulum-v1").append_transform(StepCounter(6))
>>> env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env])
>>> env = env.append_transform(primer)
>>> r = env.rollout(10, m, break_when_any_done=False)
>>> mask = [k for k in r.keys() if k.startswith("mask")][0]
>>> assert (r[mask][0, :5] != r[mask][0, 5:6]).any()
>>> assert (r[mask][0, :4] == r[mask][0, 4:5]).all()

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源