快捷方式

EGreedyModule

class torchrl.modules.EGreedyModule(*args, **kwargs)[源码]

Epsilon-Greedy 探索模块。

该模块根据 epsilon-greedy 探索策略随机更新 tensordict 中的动作。每次调用时,会根据某个概率阈值执行随机抽取(每个动作一次)。如果抽取成功,对应的动作将被替换为从提供的动作规范 (action spec) 中抽取的随机样本。未被抽取的动作将保持不变。

参数:
  • spec (TensorSpec) – 用于采样动作的规范。

  • eps_init (标量, 可选) – 初始 epsilon 值。默认为 1.0

  • eps_end (标量, 可选) – 最终 epsilon 值。默认为 0.1

  • annealing_num_steps (int, 可选) – epsilon 达到 eps_end 值所需的步数。默认为 1000

关键字参数:
  • action_key (NestedKey, 可选) – 输入 tensordict 中动作所在的键。默认为 "action"

  • action_mask_key (NestedKey, 可选) – 输入 tensordict 中动作掩码所在的键。默认为 None(表示没有掩码)。

  • device (torch.device, 可选) – 探索模块所在的设备。

注意

至关重要的是,要在训练循环中调用 step() 来更新探索因子。由于很难捕获这种遗漏,因此如果遗漏了此步骤,不会引发警告或异常!

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictSequential
>>> from torchrl.modules import EGreedyModule, Actor
>>> from torchrl.data import Bounded
>>> torch.manual_seed(0)
>>> spec = Bounded(-1, 1, torch.Size([4]))
>>> module = torch.nn.Linear(4, 4, bias=False)
>>> policy = Actor(spec=spec, module=module)
>>> explorative_policy = TensorDictSequential(policy,  EGreedyModule(eps_init=0.2))
>>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
>>> print(explorative_policy(td).get("action"))
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.9055, -0.9277, -0.6295, -0.2532],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000]], grad_fn=<AddBackward0>)
forward(tensordict: TensorDictBase) TensorDictBase[源码]

定义每次调用时执行的计算。

应由所有子类覆盖。

注意

虽然前向传播的实现需要在该函数中定义,但之后应调用 Module 实例而不是直接调用此函数,因为前者负责运行已注册的钩子,而后者则会静默忽略它们。

step(frames: int = 1) None[源码]

一次 epsilon 衰减。

在此方法被调用 self.annealing_num_steps 次后,后续调用将无效。

参数:

frames (int, 可选) – 自上次步进以来的帧数。默认为 1

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并解答疑问

查看资源