欢迎使用 TensorDict 文档!¶
TensorDict 是一个类似字典的类,它继承了张量的属性,例如索引、形状操作、转换为设备等。
您可以直接从 PyPI 安装 tensordict(有关安装说明,请参见下面的专用部分)。
$ pip install tensordict
TensorDict 的主要目的是通过抽象定制操作来使代码库更具可读性和模块化。
>>> for i, tensordict in enumerate(dataset):
... # the model reads and writes tensordicts
... tensordict = model(tensordict)
... loss = loss_module(tensordict)
... loss.backward()
... optimizer.step()
... optimizer.zero_grad()
通过这种级别的抽象,可以为高度异构的任务循环利用一个训练循环。训练循环的每个步骤(数据收集和转换、模型预测、损失计算等)可以根据具体的用例进行调整,而不会影响其他步骤。例如,以上示例可以轻松用于分类和分割任务,以及其他许多任务。
安装¶
Tensordict 版本与 PyTorch 同步,因此请确保您始终使用最新版本的 PyTorch 来享受库的最新功能(尽管核心功能保证向后兼容于 pytorch>=1.13)。夜间版本可以通过以下方式安装
$ pip install tensordict-nightly
或者如果您愿意为库做出贡献,可以通过 git clone 安装
$ cd path/to/root
$ git clone https://github.com/pytorch/tensordict
$ cd tensordict
$ python setup.py develop