注意
点击此处下载完整示例代码
PyTorch:张量¶
创建日期:2020 年 12 月 03 日 | 最后更新:2020 年 12 月 03 日 | 最后验证:2024 年 11 月 05 日
一个三阶多项式,通过最小化平方欧几里得距离,训练用于预测 \(y=\sin(x)\) 在 \(-\pi\) 到 \(pi\) 之间的值。
此实现使用 PyTorch 张量手动计算前向传播、损失和反向传播。
PyTorch 张量与 numpy 数组基本相同:它不知道深度学习、计算图或梯度,只是一个用于任意数值计算的通用 n 维数组。
numpy 数组和 PyTorch 张量之间最大的区别在于 PyTorch 张量可以在 CPU 或 GPU 上运行。要在 GPU 上运行操作,只需将张量转换为 cuda 数据类型即可。
99 463.81201171875
199 312.108154296875
299 211.08497619628906
399 143.78140258789062
499 98.92195129394531
599 69.00719451904297
699 49.04833984375
799 35.724693298339844
899 26.82526397705078
999 20.877582550048828
1099 16.900102615356445
1199 14.238385200500488
1299 12.455989837646484
1399 11.26158618927002
1499 10.46059799194336
1599 9.923055648803711
1699 9.561992645263672
1799 9.319270133972168
1899 9.15597152709961
1999 9.045997619628906
Result: y = 0.009059899486601353 + 0.8446162343025208 x + -0.0015629848930984735 x^2 + -0.0916057676076889 x^3
import torch
import math
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU
# Create random input and output data
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)
# Randomly initialize weights
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)
learning_rate = 1e-6
for t in range(2000):
# Forward pass: compute predicted y
y_pred = a + b * x + c * x ** 2 + d * x ** 3
# Compute and print loss
loss = (y_pred - y).pow(2).sum().item()
if t % 100 == 99:
print(t, loss)
# Backprop to compute gradients of a, b, c, d with respect to loss
grad_y_pred = 2.0 * (y_pred - y)
grad_a = grad_y_pred.sum()
grad_b = (grad_y_pred * x).sum()
grad_c = (grad_y_pred * x ** 2).sum()
grad_d = (grad_y_pred * x ** 3).sum()
# Update weights using gradient descent
a -= learning_rate * grad_a
b -= learning_rate * grad_b
c -= learning_rate * grad_c
d -= learning_rate * grad_d
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
脚本总运行时间: ( 0 分钟 0.211 秒)