注意
点击此处下载完整的示例代码
在 PyTorch 中推断形状¶
创建日期:2023 年 3 月 27 日 | 最后更新日期:2023 年 3 月 27 日 | 最后验证日期:未验证
使用 PyTorch 编写模型时,给定层的参数通常取决于前一层的输出形状。例如,nn.Linear
层的 in_features
必须与输入的 size(-1)
匹配。对于某些层,形状计算涉及复杂的方程,例如卷积运算。
一种解决方法是使用随机输入运行前向传播,但这会浪费内存和计算资源。
相反,我们可以利用 meta
设备来确定层的输出形状,而无需具体化任何数据。
import torch
import timeit
t = torch.rand(2, 3, 10, 10, device="meta")
conv = torch.nn.Conv2d(3, 5, 2, device="meta")
start = timeit.default_timer()
out = conv(t)
end = timeit.default_timer()
print(out)
print(f"Time taken: {end-start}")
请注意,由于数据没有具体化,传递任意大的输入也不会显著改变形状计算所需的时间。
t_large = torch.rand(2**10, 3, 2**16, 2**16, device="meta")
start = timeit.default_timer()
out = conv(t_large)
end = timeit.default_timer()
print(out)
print(f"Time taken: {end-start}")
考虑一个任意的网络,如下所示
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1) # flatten all dimensions except batch
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
我们可以通过为每个层注册一个打印输出形状的前向 hook 来查看整个网络中的中间形状。
def fw_hook(module, input, output):
print(f"Shape of output to {module} is {output.shape}.")
# Any tensor created within this torch.device context manager will be
# on the meta device.
with torch.device("meta"):
net = Net()
inp = torch.randn((1024, 3, 32, 32))
for name, layer in net.named_modules():
layer.register_forward_hook(fw_hook)
out = net(inp)
脚本总运行时间: ( 0 分 0.000 秒)