From 2180c68999eb8dc0c7bcec015b2703f5b8b20223 Mon Sep 17 00:00:00 2001
From: zhang
Date: Wed, 4 May 2022 08:47:54 +0800
Subject: ndarray axis

---
 learn_torch/basics/nn_demo_optim.py | 59 +++++++++++++++++++++++++++++++++++++
 1 file changed, 59 insertions(+)
 create mode 100644 learn_torch/basics/nn_demo_optim.py

diff --git a/learn_torch/basics/nn_demo_optim.py b/learn_torch/basics/nn_demo_optim.py
new file mode 100644
index 0000000..38d95dc
--- /dev/null
+++ b/learn_torch/basics/nn_demo_optim.py
@@ -0,0 +1,59 @@
+import math
+
+import torch
+
+# Use the GPU when available; otherwise fall back to the CPU.
+device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+dtype = torch.float
+lr = 1e-3
+
+
+def train(X, y):
+    for i in range(2000):
+        y_pred = model(X)
+        loss = loss_fn(y_pred, y)
+
+        if i % 100 == 0:
+            print('{}/{}: {}'.format(i, 2000, loss.item()))
+
+        # Clear gradients accumulated by the previous iteration;
+        # model.zero_grad() would work equally well here.
+        opt.zero_grad()
+
+        loss.backward()
+
+        # opt.step() replaces the manual parameter update:
+        # with torch.no_grad():
+        #     for param in model.parameters():
+        #         param -= lr * param.grad
+        opt.step()
+
+
+if __name__ == '__main__':
+
+    X = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
+    y = torch.sin(X)
+
+    # Expand x into the polynomial features [x, x^2, x^3]. Create p on
+    # the same device as X so pow() also works when CUDA is available.
+    p = torch.tensor([1, 2, 3], device=device, dtype=dtype)
+    X = X.unsqueeze(-1).pow(p)
+
+    # Linear(3, 1) learns the cubic's coefficients; Flatten(0, 1) turns
+    # the (2000, 1) output into a (2000,) vector matching y. The model
+    # must live on the same device as the data.
+    model = torch.nn.Sequential(
+        torch.nn.Linear(3, 1),
+        torch.nn.Flatten(0, 1)
+    ).to(device)
+
+    loss_fn = torch.nn.MSELoss(reduction='sum')
+    opt = torch.optim.RMSprop(model.parameters(), lr=lr)
+
+    train(X, y)
+    weight_layer = model[0]
+
+    print('y = {} + {}x + {}x^2 + {}x^3'.format(weight_layer.bias.item(),
+                                                weight_layer.weight[0, 0].item(),
+                                                weight_layer.weight[0, 1].item(),
+                                                weight_layer.weight[0, 2].item()))
--
cgit v1.2.3
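
A quick usage sketch (not part of the patch): once train() has run, the fitted
cubic can be sanity-checked against sin(x) directly. This reuses model, p,
device and dtype from the script above; the choice of five test points is
arbitrary.

    # Evaluate the fitted polynomial at a few points and compare with sin(x).
    x_test = torch.linspace(-math.pi, math.pi, 5, device=device, dtype=dtype)
    with torch.no_grad():
        y_fit = model(x_test.unsqueeze(-1).pow(p))
    print(y_fit)               # the model's predictions
    print(torch.sin(x_test))   # the ground truth

Because Flatten(0, 1) collapses the (5, 1) output to shape (5,), y_fit lines
up elementwise with torch.sin(x_test).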