path: root/learn_torch/basics/v5_diff.py
author    zhang <zch921005@126.com>  2022-05-04 08:47:54 +0800
committer zhang <zch921005@126.com>  2022-05-04 08:47:54 +0800
commit    2180c68999eb8dc0c7bcec015b2703f5b8b20223 (patch)
tree      3ec71623038ff8b90a5bc4e32da14a7382d42d9d /learn_torch/basics/v5_diff.py
parent    70aebb73b81b50911e2107cd4519e69f09971021 (diff)
ndarray axis
Diffstat (limited to 'learn_torch/basics/v5_diff.py')
-rw-r--r--  learn_torch/basics/v5_diff.py  85
1 file changed, 85 insertions, 0 deletions
diff --git a/learn_torch/basics/v5_diff.py b/learn_torch/basics/v5_diff.py
new file mode 100644
index 0000000..5b247c0
--- /dev/null
+++ b/learn_torch/basics/v5_diff.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+import torch
+import math
+
+
+class LegendrePolynomial3(torch.autograd.Function):
+ """
+ We can implement our own custom autograd Functions by subclassing
+ torch.autograd.Function and implementing the forward and backward passes
+ which operate on Tensors.
+ """
+
+ @staticmethod
+ def forward(ctx, input):
+ """
+ In the forward pass we receive a Tensor containing the input and return
+ a Tensor containing the output. ctx is a context object that can be used
+ to stash information for backward computation. You can cache arbitrary
+ objects for use in the backward pass using the ctx.save_for_backward method.
+ """
+ ctx.save_for_backward(input)
+ return 0.5 * (5 * input ** 3 - 3 * input)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ """
+ In the backward pass we receive a Tensor containing the gradient of the loss
+ with respect to the output, and we need to compute the gradient of the loss
+ with respect to the input.
+ """
+ input, = ctx.saved_tensors
+ return grad_output * 1.5 * (5 * input ** 2 - 1)
+
+
+dtype = torch.float
+device = torch.device("cpu")
+# device = torch.device("cuda:0") # Uncomment this to run on GPU
+
+# Create Tensors to hold input and outputs.
+# By default, requires_grad=False, which indicates that we do not need to
+# compute gradients with respect to these Tensors during the backward pass.
+x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
+y = torch.sin(x)
+
+# Create random Tensors for the weights. For this example, we need
+# 4 weights: y = a + b * P3(c + d * x); these weights need to be initialized
+# not too far from the correct result to ensure convergence.
+# Setting requires_grad=True indicates that we want to compute gradients with
+# respect to these Tensors during the backward pass.
+a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
+b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
+c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
+d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)
+
+learning_rate = 5e-6
+for t in range(2000):
+    # To apply our Function, we use the Function.apply method. We alias this as 'P3'.
+    P3 = LegendrePolynomial3.apply
+
+    # Forward pass: compute predicted y using operations; we compute
+    # P3 using our custom autograd operation.
+    y_pred = a + b * P3(c + d * x)
+
+    # Compute and print loss
+    loss = (y_pred - y).pow(2).sum()
+    if t % 100 == 99:
+        print(t, loss.item())
+
+    # Use autograd to compute the backward pass.
+    loss.backward()
+
+    # Update weights using gradient descent
+    with torch.no_grad():
+        a -= learning_rate * a.grad
+        b -= learning_rate * b.grad
+        c -= learning_rate * c.grad
+        d -= learning_rate * d.grad
+
+        # Manually zero the gradients after updating weights
+        a.grad = None
+        b.grad = None
+        c.grad = None
+        d.grad = None
+
+print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')
\ No newline at end of file
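
Two short follow-ups to the file added above. First, a hand-written backward() is easy to get subtly wrong; torch.autograd.gradcheck compares it against a finite-difference estimate. A minimal sketch, assuming LegendrePolynomial3 from v5_diff.py is in scope (note the module has no __main__ guard, so importing it as-is would also run the training loop):

import torch

# gradcheck needs double precision and requires_grad on the inputs to be reliable.
x_check = torch.randn(20, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(LegendrePolynomial3.apply, (x_check,), eps=1e-6, atol=1e-4)
print(ok)  # True if backward() matches the analytic derivative 1.5 * (5 * x**2 - 1)

Second, the manual no_grad() updates and gradient zeroing in the loop are what torch.optim.SGD does for you; a sketch of the same loop using the optimizer, reusing a, b, c, d, x, y, and learning_rate as defined above:

P3 = LegendrePolynomial3.apply
optimizer = torch.optim.SGD([a, b, c, d], lr=learning_rate)
for t in range(2000):
    y_pred = a + b * P3(c + d * x)
    loss = (y_pred - y).pow(2).sum()
    optimizer.zero_grad()   # replaces the manual .grad = None assignments
    loss.backward()
    optimizer.step()        # replaces the manual p -= learning_rate * p.grad updates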