快捷方式

TanhModule

class torchrl.modules.tensordict_module.TanhModule(*args, **kwargs)[source]

用于具有有界动作空间的确定性策略的 Tanh 模块。

此转换用作 TensorDictModule 层以将网络输出映射到有界空间。

参数:
  • in_keys (str 列表或 str 元组) – 模块的输入键。

  • out_keys (str 列表或 str 元组,可选) – 模块的输出键。如果未提供,则假定与 in_keys 相同的键。

关键字参数:
  • spec (TensorSpec, 可选) – 如果提供,则为输出的规格。如果提供 CompositeSpec,则其键必须与 out_keys 中的键匹配。否则,将假定 out_keys 的键,并且所有输出都使用相同的规格。

  • low (float, np.ndarraytorch.Tensor) – 空间的下界。如果未提供并且未提供规格,则假定为 -1。如果提供规格,则将检索规格的最小值。

  • high (float, np.ndarraytorch.Tensor) – 空间的上界。如果未提供并且未提供规格,则假定为 1。如果提供规格,则将检索规格的最大值。

  • clamp (bool, 可选) – 如果为 True,则输出将被钳制以保持在边界内,但至少要与它们隔开一个最小分辨率。默认为 False

示例

>>> from tensordict import TensorDict
>>> # simplest use case: -1 - 1 boundaries
>>> torch.manual_seed(0)
>>> in_keys = ["action"]
>>> mod = TanhModule(
...     in_keys=in_keys,
... )
>>> data = TensorDict({"action": torch.randn(5) * 10}, [])
>>> data = mod(data)
>>> data['action']
tensor([ 1.0000, -0.9944, -1.0000,  1.0000, -1.0000])
>>> # low and high can be customized
>>> low = -2
>>> high = 1
>>> mod = TanhModule(
...     in_keys=in_keys,
...     low=low,
...     high=high,
... )
>>> data = TensorDict({"action": torch.randn(5) * 10}, [])
>>> data = mod(data)
>>> data['action']
tensor([-2.0000,  0.9991,  1.0000, -2.0000, -1.9991])
>>> # A spec can be provided
>>> from torchrl.data import BoundedTensorSpec
>>> spec = BoundedTensorSpec(low, high, shape=())
>>> mod = TanhModule(
...     in_keys=in_keys,
...     low=low,
...     high=high,
...     spec=spec,
...     clamp=False,
... )
>>> # One can also work with multiple keys
>>> in_keys = ['a', 'b']
>>> spec = CompositeSpec(
...     a=BoundedTensorSpec(-3, 0, shape=()),
...     b=BoundedTensorSpec(0, 3, shape=()))
>>> mod = TanhModule(
...     in_keys=in_keys,
...     spec=spec,
... )
>>> data = TensorDict(
...     {'a': torch.randn(10), 'b': torch.randn(10)}, batch_size=[])
>>> data = mod(data)
>>> data['a']
tensor([-2.3020, -1.2299, -2.5418, -0.2989, -2.6849, -1.3169, -2.2690, -0.9649,
        -2.5686, -2.8602])
>>> data['b']
tensor([2.0315, 2.8455, 2.6027, 2.4746, 1.7843, 2.7782, 0.2111, 0.5115, 1.4687,
        0.5760])
forward(tensordict)[source]

定义每次调用时执行的计算。

所有子类都应该重写。

注意

虽然前向传递的配方需要在此函数内定义,但应随后调用 Module 实例而不是此函数,因为前者会处理运行注册的钩子,而后者会静默忽略它们。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源