哈希¶
- class torchrl.envs.transforms.Hash(in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], *, hash_fn: Callable = None, seed: Any | None = None, use_raw_nontensor: bool = False)[源]¶
向 tensordict 中添加哈希值。
- 参数:
in_keys (嵌套键序列) – 要进行哈希处理的值对应的键。
out_keys (嵌套键序列) – 生成的哈希值对应的键。
in_keys_inv (嵌套键序列, 可选) –
在 inv 调用期间要进行哈希处理的值对应的键。
注意
如果需要反向映射,应与键列表一起传递一个哈希到值的映射集
Dict[Tuple[int], Any]
,以便让Hash
transform 知道如何从给定的哈希中恢复值。此映射集不会被复制,因此在 transform 实例化后可以在同一工作区中修改它,并且这些修改将反映在映射中。缺失的哈希将被映射到None
。out_keys_inv (嵌套键序列, 可选) – 在 inv 调用期间生成的哈希值对应的键。
- 关键字参数:
hash_fn (可调用对象, 可选) – 要使用的哈希函数。如果提供了
seed
,则哈希函数必须接受它作为第二个参数。默认值为Hash.reproducible_hash
。seed (可选) – 哈希函数要使用的种子,如果需要的话。
use_raw_nontensor (bool, 可选) – 如果为
False
,则在对fn
调用NonTensorData
/NonTensorStack
输入之前,会从中提取数据。如果为True
,则直接将原始NonTensorData
/NonTensorStack
输入提供给fn
,fn
必须支持这些输入。默认值为False
。Hash (>>> from torchrl.envs import GymEnv, UnaryTransform,) –
GymEnv (>>> env =) –
output (>>> # 处理字符串) –
env.append_transform( (>>> env =) –
UnaryTransform( (...) –
in_keys=["observation"], (...) –
out_keys=["observation_str"], (...) –
tensor (... fn=lambda) – str(tensor.numpy().tobytes())))
output –
env.append_transform( –
Hash( (...) –
in_keys=["observation_str"], (...) –
out_keys=["observation_hash"],) (...) –
) (...) –
env.observation_spec (>>>) –
Composite( –
- observation: BoundedContinuous(
shape=torch.Size([3]), space=ContinuousBox(
low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu, dtype=torch.float32, domain=continuous),
- observation_str: NonTensor(
shape=torch.Size([]), space=None, device=cpu, dtype=None, domain=None),
- observation_hash: UnboundedDiscrete(
shape=torch.Size([32]), space=ContinuousBox(
low=Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.uint8, contiguous=True), high=Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.uint8, contiguous=True)),
device=cpu, dtype=torch.uint8, domain=discrete),
device=None, shape=torch.Size([]))
env.rollout (>>>) –
TensorDict( –
- fields={
action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict(
- fields={
done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), observation_hash: Tensor(shape=torch.Size([3, 32]), device=cpu, dtype=torch.uint8, is_shared=False), observation_str: NonTensorStack(
[“b’g\x08\x8b\xbexav\xbf\x00\xee(>’”, “b’\x…, batch_size=torch.Size([3]), device=None),
reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3]), device=None, is_shared=False),
observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), observation_hash: Tensor(shape=torch.Size([3, 32]), device=cpu, dtype=torch.uint8, is_shared=False), observation_str: NonTensorStack(
[“b’\xb5\x17\x8f\xbe\x88\xccu\xbf\xc0Vr?’”…, batch_size=torch.Size([3]), device=None),
terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3]), device=None, is_shared=False)
env.check_env_specs() (>>>) –
succeeded! ([torchrl][INFO] check_env_specs)
- classmethod reproducible_hash(string, seed=None)[源]¶
使用种子从字符串创建可重现的 256 位哈希。
- 参数:
string (str or None) – 输入字符串。如果为
None
,则使用空字符串""
。seed (str, 可选) – 种子值。默认值为
None
。
- 返回:
形状为
(32,)
,dtype 为torch.uint8
的张量。- 返回类型:
Tensor
- state_dict(*args, destination=None, prefix='', keep_vars=False)[源]¶
返回一个包含对模块整个状态的引用的字典。
包含参数和持久性缓冲区(例如,运行平均值)。键是相应的参数和缓冲区名称。设置为
None
的参数和缓冲区不包含在内。注意
返回的对象是浅拷贝。它包含对模块参数和缓冲区的引用。
警告
目前
state_dict()
也按顺序接受destination
、prefix
和keep_vars
的位置参数。然而,这已被弃用,未来版本将强制使用关键字参数。警告
请避免使用参数
destination
,因为它不是为最终用户设计的。- 参数:
destination (dict, 可选) – 如果提供,模块的状态将更新到此字典中并返回同一对象。否则,将创建一个并返回
OrderedDict
。默认值:None
。prefix (str, 可选) – 添加到参数和缓冲区名称前的字符串,用于构成 state_dict 中的键。默认值:
''
。keep_vars (bool, 可选) – 默认情况下,state dict 中返回的
Tensor
会从 autograd 中分离。如果设置为True
,则不会执行分离操作。默认值:False
。
- 返回:
包含模块整个状态的字典
- 返回类型:
dict
示例
>>> # xdoctest: +SKIP("undefined vars") >>> module.state_dict().keys() ['bias', 'weight']