diff options
| author | zhang <zch921005@126.com> | 2022-05-04 08:47:54 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2022-05-04 08:47:54 +0800 |
| commit | 2180c68999eb8dc0c7bcec015b2703f5b8b20223 (patch) | |
| tree | 3ec71623038ff8b90a5bc4e32da14a7382d42d9d /learn_torch/basics/regression_v3.py | |
| parent | 70aebb73b81b50911e2107cd4519e69f09971021 (diff) | |
ndarray axis
Diffstat (limited to 'learn_torch/basics/regression_v3.py')
| -rw-r--r-- | learn_torch/basics/regression_v3.py | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/learn_torch/basics/regression_v3.py b/learn_torch/basics/regression_v3.py new file mode 100644 index 0000000..f6bd467 --- /dev/null +++ b/learn_torch/basics/regression_v3.py @@ -0,0 +1,40 @@ +import math +import torch + + +dtype = torch.float +device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + + +lr = 1e-6 + + +def train(X, y): + a = torch.randn((), device=device, dtype=dtype, requires_grad=True) + b = torch.randn((), device=device, dtype=dtype, requires_grad=True) + c = torch.randn((), device=device, dtype=dtype, requires_grad=True) + d = torch.randn((), device=device, dtype=dtype, requires_grad=True) + + for i in range(2000): + y_pred = a + b*X + c*X**2 + d*X**3 + loss = (y_pred - y).pow(2).sum() + if i % 100 == 0: + print('{}/{}: {}'.format(i, 2000, loss.item())) + loss.backward() + with torch.no_grad(): + a -= lr * a.grad + b -= lr * b.grad + c -= lr * c.grad + d -= lr * d.grad + a.grad = None + b.grad = None + c.grad = None + d.grad = None + print('a = {}, b = {}, c = {}, d = {}'.format(a.item(), b.item(), c.item(), d.item())) + +if __name__ == '__main__': + + X = torch.linspace(-math.pi, math.pi, 2000) + y = torch.sin(X) + train(X, y) + |
