TanhModule
- class torchrl.modules.tensordict_module.TanhModule(*args, **kwargs)
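A Tanh module for deterministic policies with a bounded action space.
This transform is meant to be used as a TensorDictModule layer that maps a network output to a bounded space.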
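To make the mapping concrete, here is a minimal scalar sketch. It assumes the usual tanh-then-rescale formula, `low + (high - low) * (tanh(x) + 1) / 2`, which agrees with the sample outputs on this page but is not copied from the torchrl source. The helper name `tanh_to_bounds` is hypothetical and not part of torchrl.

```python
import math


def tanh_to_bounds(x, low=-1.0, high=1.0):
    """Hypothetical helper: squash x into (-1, 1), then rescale to [low, high]."""
    squashed = math.tanh(x)  # value in [-1, 1]
    # Shift to [0, 2], scale to the width of the interval, then offset by low.
    return low + (high - low) * (squashed + 1.0) / 2.0


# A zero input lands on the midpoint of the interval.
midpoint = tanh_to_bounds(0.0, low=-2.0, high=1.0)
# Large-magnitude inputs saturate at the bounds in floating point.
upper = tanh_to_bounds(100.0, low=-2.0, high=1.0)
lower = tanh_to_bounds(-100.0, low=-2.0, high=1.0)
```

With the default bounds the rescaling is the identity, so the result is simply `tanh(x)`.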
- Parameters:
in_keys (list of str or tuples of str) – the input keys of the module.
out_keys (list of str or tuples of str, optional) – the output keys of the module. If none is provided, the same keys as in_keys are assumed.
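- Keyword Arguments:
spec (TensorSpec, optional) – if provided, the spec of the output. If a Composite is provided, its keys must match the keys in out_keys. Otherwise, the keys of out_keys are assumed and the same spec is used for all outputs.
low (float, np.ndarray or torch.Tensor) – the lower bound of the space. If none is provided and no spec is provided, -1 is assumed. If a spec is provided, the minimum value of the spec is retrieved.
high (float, np.ndarray or torch.Tensor) – the upper bound of the space. If none is provided and no spec is provided, 1 is assumed. If a spec is provided, the maximum value of the spec is retrieved.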
clamp (bool, optional) – if True, the outputs are clamped to lie within the boundaries while staying at a minimum resolution away from them. Defaults to False.
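The effect of clamp can be illustrated with plain PyTorch. The sketch below is an assumption about the behavior described above, not the torchrl implementation: it takes "minimum resolution" to mean `torch.finfo(dtype).resolution` applied to the tanh output before rescaling. The helper name `bounded_tanh` is hypothetical.

```python
import torch


def bounded_tanh(x, low=-1.0, high=1.0, clamp=False):
    """Hypothetical helper: tanh squashing followed by rescaling to [low, high]."""
    y = torch.tanh(x)
    if clamp:
        # Assumption: keep the squashed value one dtype resolution away from +/-1.
        eps = torch.finfo(y.dtype).resolution
        y = y.clamp(-1 + eps, 1 - eps)
    return low + (high - low) * (y + 1) / 2


x = torch.tensor([-50.0, 0.0, 50.0])
# Without clamping, saturated inputs land exactly on the bounds.
unclamped = bounded_tanh(x, low=-2.0, high=1.0)
# With clamping, the outputs stay strictly inside the interval.
clamped = bounded_tanh(x, low=-2.0, high=1.0, clamp=True)
```

Keeping outputs strictly inside the interval can be useful when a later step needs to invert the tanh, since the inverse is unbounded at the exact boundaries.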
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.modules.tensordict_module import TanhModule
>>> # simplest use case: boundaries of -1 and 1
>>> torch.manual_seed(0)
>>> in_keys = ["action"]
>>> mod = TanhModule(
...     in_keys=in_keys,
... )
>>> data = TensorDict({"action": torch.randn(5) * 10}, [])
>>> data = mod(data)
>>> data['action']
tensor([ 1.0000, -0.9944, -1.0000,  1.0000, -1.0000])
>>> # low and high can be customized
>>> low = -2
>>> high = 1
>>> mod = TanhModule(
...     in_keys=in_keys,
...     low=low,
...     high=high,
... )
>>> data = TensorDict({"action": torch.randn(5) * 10}, [])
>>> data = mod(data)
>>> data['action']
tensor([-2.0000,  0.9991,  1.0000, -2.0000, -1.9991])
>>> # A spec can be provided
>>> from torchrl.data import Bounded, Composite
>>> spec = Bounded(low, high, shape=())
>>> mod = TanhModule(
...     in_keys=in_keys,
...     low=low,
...     high=high,
...     spec=spec,
...     clamp=False,
... )
>>> # One can also work with multiple keys
>>> in_keys = ['a', 'b']
>>> spec = Composite(
...     a=Bounded(-3, 0, shape=()),
...     b=Bounded(0, 3, shape=()))
>>> mod = TanhModule(
...     in_keys=in_keys,
...     spec=spec,
... )
>>> data = TensorDict(
...     {'a': torch.randn(10), 'b': torch.randn(10)}, batch_size=[])
>>> data = mod(data)
>>> data['a']
tensor([-2.3020, -1.2299, -2.5418, -0.2989, -2.6849, -1.3169, -2.2690, -0.9649,
        -2.5686, -2.8602])
>>> data['b']
tensor([2.0315, 2.8455, 2.6027, 2.4746, 1.7843, 2.7782, 0.2111, 0.5115, 1.4687,
        0.5760])