diff options
| author | zhang <zch921005@126.com> | 2022-05-04 08:47:54 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2022-05-04 08:47:54 +0800 |
| commit | 2180c68999eb8dc0c7bcec015b2703f5b8b20223 (patch) | |
| tree | 3ec71623038ff8b90a5bc4e32da14a7382d42d9d /learn_torch/basics/autograd_v5.py | |
| parent | 70aebb73b81b50911e2107cd4519e69f09971021 (diff) | |
ndarray axis
Diffstat (limited to 'learn_torch/basics/autograd_v5.py')
| -rw-r--r-- | learn_torch/basics/autograd_v5.py | 53 |
1 files changed, 53 insertions, 0 deletions
diff --git a/learn_torch/basics/autograd_v5.py b/learn_torch/basics/autograd_v5.py new file mode 100644 index 0000000..810a702 --- /dev/null +++ b/learn_torch/basics/autograd_v5.py @@ -0,0 +1,53 @@ + +import torch +import math + +device = 'cuda:0' if torch.cuda.is_available() else 'cpu' +dtype = torch.float + +lr = 5e-6 + +class LegendrePolynomial3(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + return 0.5*(5*input**3 - 3*input) + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + return grad_output*(7.5*input**2 - 1.5) + + +def train(X, y): + a = torch.full((), 0, device=device, dtype=dtype, requires_grad=True) + b = torch.full((), -1, device=device, dtype=dtype, requires_grad=True) + c = torch.full((), 0, device=device, dtype=dtype, requires_grad=True) + d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True) + + for i in range(2000): + P3 = LegendrePolynomial3.apply + # 执行 forward + y_pred = a + b * P3(c + d*X) + loss = (y_pred - y).pow(2).sum() + if i % 100 == 0: + print('{}/{}: {}'.format(i, 2000, loss.item())) + # 执行 backward + loss.backward() + with torch.no_grad(): + a -= lr * a.grad + b -= lr * b.grad + c -= lr * c.grad + d -= lr * d.grad + + a.grad = None + b.grad = None + c.grad = None + d.grad = None + print('a = {}, b = {}, c = {}, d = {}'.format(a.item(), b.item(), c.item(), d.item())) + + + +if __name__ == '__main__': + X = torch.linspace(-math.pi, math.pi, 2000, dtype=dtype, device=device) + y = torch.sin(X) + train(X, y) |
