快捷方式

TanhModule

class torchrl.modules.tensordict_module.TanhModule(*args, **kwargs)[source]

用于确定性策略的 Tanh 模块,具有有界动作空间。

此变换用作 TensorDictModule 层,将网络输出映射到有界空间。

参数:
  • in_keys (list of str or tuples of str) – 模块的输入键。

  • out_keys (list of str or tuples of str, optional) – 模块的输出键。如果未提供,则假定与 in_keys 相同的键。

关键字参数:
  • spec (TensorSpec, optional) – 如果提供,则为输出的 spec。如果提供了 Composite,则其键必须与 out_keys 中的键匹配。否则,假定 out_keys 的键,并且所有输出都使用相同的 spec。

  • low (float, np.ndarray or torch.Tensor) – 空间的下界。如果未提供且未提供 spec,则假定为 -1。如果提供了 spec,则将检索 spec 的最小值。

  • high (float, np.ndarray or torch.Tensor) – 空间的上界。如果未提供且未提供 spec,则假定为 1。如果提供了 spec,则将检索 spec 的最大值。

  • clamp (bool, optional) – 如果 True,则输出将被钳制在边界内,但与边界保持最小分辨率。默认为 False

示例

>>> from tensordict import TensorDict
>>> # simplest use case: -1 - 1 boundaries
>>> torch.manual_seed(0)
>>> in_keys = ["action"]
>>> mod = TanhModule(
...     in_keys=in_keys,
... )
>>> data = TensorDict({"action": torch.randn(5) * 10}, [])
>>> data = mod(data)
>>> data['action']
tensor([ 1.0000, -0.9944, -1.0000,  1.0000, -1.0000])
>>> # low and high can be customized
>>> low = -2
>>> high = 1
>>> mod = TanhModule(
...     in_keys=in_keys,
...     low=low,
...     high=high,
... )
>>> data = TensorDict({"action": torch.randn(5) * 10}, [])
>>> data = mod(data)
>>> data['action']
tensor([-2.0000,  0.9991,  1.0000, -2.0000, -1.9991])
>>> # A spec can be provided
>>> from torchrl.data import Bounded
>>> spec = Bounded(low, high, shape=())
>>> mod = TanhModule(
...     in_keys=in_keys,
...     low=low,
...     high=high,
...     spec=spec,
...     clamp=False,
... )
>>> # One can also work with multiple keys
>>> in_keys = ['a', 'b']
>>> spec = Composite(
...     a=Bounded(-3, 0, shape=()),
...     b=Bounded(0, 3, shape=()))
>>> mod = TanhModule(
...     in_keys=in_keys,
...     spec=spec,
... )
>>> data = TensorDict(
...     {'a': torch.randn(10), 'b': torch.randn(10)}, batch_size=[])
>>> data = mod(data)
>>> data['a']
tensor([-2.3020, -1.2299, -2.5418, -0.2989, -2.6849, -1.3169, -2.2690, -0.9649,
        -2.5686, -2.8602])
>>> data['b']
tensor([2.0315, 2.8455, 2.6027, 2.4746, 1.7843, 2.7782, 0.2111, 0.5115, 1.4687,
        0.5760])
forward(tensordict=None)[source]

定义每次调用时执行的计算。

应由所有子类重写。

注意

尽管 forward 传递的配方需要在该函数中定义,但之后应调用 Module 实例而不是此函数,因为前者负责运行注册的钩子,而后者会静默忽略它们。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源