快捷方式

OpenMLEnv

torchrl.envs.OpenMLEnv(*args, **kwargs)[source]

OpenML 数据的环境接口,用于 bandits 上下文。

文档: https://www.openml.org/search?type=data

Scikit-learn 接口: https://scikit-learn.cn/stable/modules/generated/sklearn.datasets.fetch_openml.html

参数:
  • dataset_name (str) – 支持以下数据集: "adult_num", "adult_onehot", "mushroom_num", "mushroom_onehot", "covertype", "shuttle""magic"

  • device (torch.device兼容设备, 可选) – 输入和输出数据预期所在的设备。默认为 "cpu"

  • batch_size (torch.Size兼容设备, 可选) – 环境的批大小,即调用 reset() 时采样和返回的元素数量。默认为空批大小,即每次采样一个元素。

变量:

available_envs (List[str]) – 此类要构建的环境列表。

示例

>>> env = OpenMLEnv("adult_onehot", batch_size=[2, 3])
>>> print(env.reset())
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3, 106]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([2, 3]),
    device=cpu,
    is_shared=False)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源