OpenMLEnv
- torchrl.envs.OpenMLEnv(*args, **kwargs)
An environment interface to OpenML data, for use in bandit settings.
Doc: https://www.openml.org/search?type=data
Scikit-learn interface: https://scikit-learn.cn/stable/modules/generated/sklearn.datasets.fetch_openml.html
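To make the bandit framing concrete, here is a minimal pure-Python sketch of a one-step contextual bandit built on a labelled dataset. It is not the torchrl implementation. `ToyLabelBandit` is a hypothetical class, and the reward rule (1.0 when the chosen action equals the true label `y`, else 0.0) is an assumption that is consistent with the `reward`, `done` and `y` entries shown in the example output. Each reset samples a row whose features serve as the context.

```python
import random
from typing import Dict, Optional, Sequence


class ToyLabelBandit:
    """Hypothetical one-step bandit over a labelled dataset (not the torchrl API)."""

    def __init__(self, features: Sequence[Sequence[float]], labels: Sequence[int], seed: int = 0):
        if len(features) != len(labels):
            raise ValueError("features and labels must have the same length")
        self.features = [list(row) for row in features]
        self.labels = list(labels)
        self.rng = random.Random(seed)
        self._y: Optional[int] = None

    def reset(self) -> Dict[str, object]:
        # Sample one row: its features are the context ("observation"),
        # and its label is kept as "y" so that step() can score the action.
        i = self.rng.randrange(len(self.labels))
        self._y = self.labels[i]
        return {"observation": self.features[i], "y": self._y}

    def step(self, action: int) -> Dict[str, object]:
        if self._y is None:
            raise RuntimeError("call reset() before step()")
        # Assumed reward rule: 1.0 if the chosen class matches the true label.
        reward = 1.0 if action == self._y else 0.0
        self._y = None
        # Every episode lasts a single step, so done is always True.
        return {"reward": reward, "done": True}


env = ToyLabelBandit(features=[[0.0, 1.0], [1.0, 0.0]], labels=[0, 1], seed=0)
context = env.reset()
result = env.step(context["y"])  # an oracle policy that always answers the true label
# → {'reward': 1.0, 'done': True}
```

A learning agent would see only `observation` and would have to pick the class itself. `y` is exposed here just so the oracle line can demonstrate the reward.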
- Parameters:
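dataset_name (str) – the following datasets are supported: "adult_num", "adult_onehot", "mushroom_num", "mushroom_onehot", "covertype", "shuttle" and "magic".
device (torch.device or compatible, optional) – the device on which the input and output data are expected to live. Defaults to "cpu".
batch_size (torch.Size or compatible, optional) – the batch size of the environment, i.e. the number of elements sampled and returned when reset() is called. Defaults to an empty batch size, i.e. a single element is sampled at a time.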
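The effect of `batch_size` on what `reset()` returns can be summarised as a shape rule. The sketch below is a hypothetical helper, not part of torchrl. Its rule is read off the "adult_onehot" example in this section, where `batch_size=[2, 3]` produced `observation` of shape `[2, 3, 106]`, `y` of shape `[2, 3]`, and `reward`/`done` of shape `[2, 3, 1]`. Applying the same rule to the default empty batch size is an extrapolation.

```python
from typing import Dict, Sequence, Tuple


def expected_reset_shapes(batch_size: Sequence[int], n_features: int) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical helper: tensor shapes expected from reset() for a given batch size."""
    b = tuple(batch_size)
    return {
        "observation": b + (n_features,),  # one feature vector per batch element
        "y": b,                             # one integer label per batch element
        "reward": b + (1,),                 # trailing singleton dimension
        "done": b + (1,),                   # trailing singleton dimension
    }


# 106 is the feature count shown for "adult_onehot" in the example.
shapes = expected_reset_shapes([2, 3], n_features=106)
```

With an empty batch size the leading dimensions vanish, so under this rule a single sample would carry an observation of shape `(n_features,)` and a scalar label.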
- Variables:
available_envs (List[str]) – the list of environments that this class can build.
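Because only the datasets listed for `dataset_name` are supported, a caller may want to validate a name before constructing the environment. The sketch below is hypothetical. It keeps a local copy of the names from the parameter description in place of `available_envs`, and `check_dataset_name` is an illustrative function, not a torchrl API.

```python
from typing import List

# Local copy of the dataset names listed in the dataset_name parameter description.
AVAILABLE_ENVS: List[str] = [
    "adult_num", "adult_onehot", "mushroom_num", "mushroom_onehot",
    "covertype", "shuttle", "magic",
]


def check_dataset_name(name: str, available: List[str] = AVAILABLE_ENVS) -> str:
    """Hypothetical guard: fail early on a name that is not in the supported list."""
    if name not in available:
        raise ValueError(f"unknown dataset {name!r}; expected one of {sorted(available)}")
    return name
```

Failing on a typo like "adult-onehot" before any download starts gives a clearer error than a failure deep inside the data-loading code.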
Examples
>>> from torchrl.envs import OpenMLEnv
>>> env = OpenMLEnv("adult_onehot", batch_size=[2, 3])
>>> print(env.reset())
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3, 106]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([2, 3]),
    device=cpu,
    is_shared=False)