TanhModule¶
- class torchrl.modules.tensordict_module.TanhModule(*args, **kwargs)[source]¶
用于具有有界动作空间的确定性策略的 Tanh 模块。
此转换用作 TensorDictModule 层以将网络输出映射到有界空间。
- 参数:
in_keys (str 列表或 str 元组) – 模块的输入键。
out_keys (str 列表或 str 元组,可选) – 模块的输出键。如果未提供,则假定与 in_keys 相同的键。
- 关键字参数:
spec (TensorSpec, 可选) – 如果提供,则为输出的规格。如果提供 CompositeSpec,则其键必须与 out_keys 中的键匹配。否则,将假定 out_keys 的键,并且所有输出都使用相同的规格。
low (float, np.ndarray 或 torch.Tensor) – 空间的下界。如果未提供并且未提供规格,则假定为 -1。如果提供规格,则将检索规格的最小值。
high (float, np.ndarray 或 torch.Tensor) – 空间的上界。如果未提供并且未提供规格,则假定为 1。如果提供规格,则将检索规格的最大值。
clamp (bool, 可选) – 如果为
True
,则输出将被钳制以保持在边界内,但至少要与它们隔开一个最小分辨率。默认为False
。
示例
>>> from tensordict import TensorDict >>> # simplest use case: -1 - 1 boundaries >>> torch.manual_seed(0) >>> in_keys = ["action"] >>> mod = TanhModule( ... in_keys=in_keys, ... ) >>> data = TensorDict({"action": torch.randn(5) * 10}, []) >>> data = mod(data) >>> data['action'] tensor([ 1.0000, -0.9944, -1.0000, 1.0000, -1.0000]) >>> # low and high can be customized >>> low = -2 >>> high = 1 >>> mod = TanhModule( ... in_keys=in_keys, ... low=low, ... high=high, ... ) >>> data = TensorDict({"action": torch.randn(5) * 10}, []) >>> data = mod(data) >>> data['action'] tensor([-2.0000, 0.9991, 1.0000, -2.0000, -1.9991]) >>> # A spec can be provided >>> from torchrl.data import BoundedTensorSpec >>> spec = BoundedTensorSpec(low, high, shape=()) >>> mod = TanhModule( ... in_keys=in_keys, ... low=low, ... high=high, ... spec=spec, ... clamp=False, ... ) >>> # One can also work with multiple keys >>> in_keys = ['a', 'b'] >>> spec = CompositeSpec( ... a=BoundedTensorSpec(-3, 0, shape=()), ... b=BoundedTensorSpec(0, 3, shape=())) >>> mod = TanhModule( ... in_keys=in_keys, ... spec=spec, ... ) >>> data = TensorDict( ... {'a': torch.randn(10), 'b': torch.randn(10)}, batch_size=[]) >>> data = mod(data) >>> data['a'] tensor([-2.3020, -1.2299, -2.5418, -0.2989, -2.6849, -1.3169, -2.2690, -0.9649, -2.5686, -2.8602]) >>> data['b'] tensor([2.0315, 2.8455, 2.6027, 2.4746, 1.7843, 2.7782, 0.2111, 0.5115, 1.4687, 0.5760])