快捷方式

OpenMLEnv

torchrl.envs.OpenMLEnv(*args, **kwargs)[源代码]

用于 bandit 上下文中的 OpenML 数据的环境接口。

文档:https://www.openml.org/search?type=data

Scikit-learn 接口:https://scikit-learn.cn/stable/modules/generated/sklearn.datasets.fetch_openml.html

参数:
  • dataset_name (str) – 支持以下数据集:"adult_num""adult_onehot""mushroom_num""mushroom_onehot""covertype""shuttle""magic"

  • device (torch.device兼容, 可选) – 预计输入和输出数据所在的设备。默认为 "cpu"

  • batch_size (torch.Size兼容, 可选) – 环境的批次大小,即调用 reset() 时采样和返回的元素数量。默认为空批次大小,即一次采样一个元素。

变量:

available_envs (List[str]) – 此类将构建的环境列表。

示例

>>> env = OpenMLEnv("adult_onehot", batch_size=[2, 3])
>>> print(env.reset())
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3, 106]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([2, 3]),
    device=cpu,
    is_shared=False)

文档

访问 PyTorch 的全面的开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源