注意
转到末尾 以下载完整示例代码。
多任务环境中的任务特定策略¶
本教程详细介绍了如何使用多任务策略和批量环境。
完成本教程后,你将能够编写使用不同权重集在不同设置中计算动作的策略。你还将能够并行执行不同的环境。
from tensordict import LazyStackedTensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import nn
from torchrl.envs import CatTensors, Compose, DoubleToFloat, ParallelEnv, TransformedEnv
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.modules import MLP
我们设计了两个环境,一个仿人环境必须完成站立任务,另一个必须学习行走。
env1 = DMControlEnv("humanoid", "stand")
env1_obs_keys = list(env1.observation_spec.keys())
env1 = TransformedEnv(
env1,
Compose(
CatTensors(env1_obs_keys, "observation_stand", del_keys=False),
CatTensors(env1_obs_keys, "observation"),
DoubleToFloat(
in_keys=["observation_stand", "observation"],
in_keys_inv=["action"],
),
),
)
env2 = DMControlEnv("humanoid", "walk")
env2_obs_keys = list(env2.observation_spec.keys())
env2 = TransformedEnv(
env2,
Compose(
CatTensors(env2_obs_keys, "observation_walk", del_keys=False),
CatTensors(env2_obs_keys, "observation"),
DoubleToFloat(
in_keys=["observation_walk", "observation"],
in_keys_inv=["action"],
),
),
)
tdreset1 = env1.reset()
tdreset2 = env2.reset()
# With LazyStackedTensorDict, stacking is done in a lazy manner: the original tensordicts
# can still be recovered by indexing the main tensordict
tdreset = LazyStackedTensorDict.lazy_stack([tdreset1, tdreset2], 0)
assert tdreset[0] is tdreset1
print(tdreset[0])
策略¶
我们将设计一个策略,其中主干网络读取“observation”键。然后,如果存在的话,特定的子组件将读取堆叠 tensordicts 的“observation_stand”和“observation_walk”键,并将它们传递给专用的子网络。
action_dim = env1.action_spec.shape[-1]
policy_common = TensorDictModule(
nn.Linear(67, 64), in_keys=["observation"], out_keys=["hidden"]
)
policy_stand = TensorDictModule(
MLP(67 + 64, action_dim, depth=2),
in_keys=["observation_stand", "hidden"],
out_keys=["action"],
)
policy_walk = TensorDictModule(
MLP(67 + 64, action_dim, depth=2),
in_keys=["observation_walk", "hidden"],
out_keys=["action"],
)
seq = TensorDictSequential(
policy_common, policy_stand, policy_walk, partial_tolerant=True
)
我们来检查一下我们的序列是否为单个环境(站立)输出了动作。
seq(env1.reset())
我们来检查一下我们的序列是否为单个环境(行走)输出了动作。
seq(env2.reset())
这也适用于堆叠:现在 stand 和 walk 键已经消失了,因为它们并非所有 tensordicts 共享。但是 TensorDictSequential
仍然执行了操作。注意,主干网络是以向量化方式执行的——而不是在循环中——这更高效。
seq(tdreset)
并行执行不同任务¶
如果公共键值对共享相同的 specs(特别是它们的形状和 dtype 必须匹配:如果 observation 形状不同但指向同一个键,你将无法执行以下操作),我们可以并行化操作。
如果 ParallelEnv 接收单个环境创建函数,它将假定只需执行单个任务。如果提供函数列表,则将假定我们处于多任务设置。
def env1_maker():
return TransformedEnv(
DMControlEnv("humanoid", "stand"),
Compose(
CatTensors(env1_obs_keys, "observation_stand", del_keys=False),
CatTensors(env1_obs_keys, "observation"),
DoubleToFloat(
in_keys=["observation_stand", "observation"],
in_keys_inv=["action"],
),
),
)
def env2_maker():
return TransformedEnv(
DMControlEnv("humanoid", "walk"),
Compose(
CatTensors(env2_obs_keys, "observation_walk", del_keys=False),
CatTensors(env2_obs_keys, "observation"),
DoubleToFloat(
in_keys=["observation_walk", "observation"],
in_keys_inv=["action"],
),
),
)
env = ParallelEnv(2, [env1_maker, env2_maker])
assert not env._single_task
tdreset = env.reset()
print(tdreset)
print(tdreset[0])
print(tdreset[1]) # should be different
让我们将输出通过我们的网络。
tdreset = seq(tdreset)
print(tdreset)
print(tdreset[0])
print(tdreset[1]) # should be different but all have an "action" key
env.step(tdreset) # computes actions and execute steps in parallel
print(tdreset)
print(tdreset[0])
print(tdreset[1]) # next_observation has now been written
Rollout¶
td_rollout = env.rollout(100, policy=seq, return_contiguous=False)
td_rollout[:, 0] # tensordict of the first step: only the common keys are shown
td_rollout[0] # tensordict of the first env: the stand obs is present
env.close()
del env