# path: root/learn_torch/basics/nd_grad.py (blob d8af59e338aca2b094554073a76c7af4d8729fd1)
# -*- coding: utf-8 -*-
import numpy as np
import math

# Fit the cubic polynomial y = a + b x + c x^2 + d x^3 to sin(x) on
# [-pi, pi] using plain NumPy and hand-derived gradients (full-batch
# gradient descent on the mean squared error).

# Build the training data: 2000 evenly spaced inputs and their sine values
x = np.linspace(-math.pi, math.pi, 2000)
y = np.sin(x)

# The powers of x never change between iterations, so compute them once
# instead of re-evaluating them in every forward and backward pass.
x_squared = x ** 2
x_cubed = x ** 3

# Randomly initialize weights (scalars drawn from a standard normal)
a = np.random.randn()
b = np.random.randn()
c = np.random.randn()
d = np.random.randn()

learning_rate = 1e-3
for t in range(500):
    # Forward pass: compute predicted y
    # y = a + b x + c x^2 + d x^3
    y_pred = a + b * x + c * x_squared + d * x_cubed

    # Compute the mean squared error and report progress every 10 steps
    loss = np.square(y_pred - y).mean()
    if t % 10 == 0:
        print(x.shape, y_pred.shape, y.shape)
        print(t, loss)

    # Backprop: compute gradients of the loss with respect to a, b, c, d.
    # grad_y_pred is the per-sample derivative of the squared error; the
    # 1/N factor of the mean loss is applied by the .mean() calls below.
    grad_y_pred = 2.0 * (y_pred - y)

    grad_a = grad_y_pred.mean()
    grad_b = (grad_y_pred * x).mean()
    grad_c = (grad_y_pred * x_squared).mean()
    grad_d = (grad_y_pred * x_cubed).mean()

    # Update weights (gradient descent step)
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d

print(f'Result: y = {a} + {b} x + {c} x^2 + {d} x^3')