summaryrefslogtreecommitdiff
path: root/research/flossing/surrogate_flossing/MinimalGradientFlossinExample.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/surrogate_flossing/MinimalGradientFlossinExample.py')
-rw-r--r--research/flossing/surrogate_flossing/MinimalGradientFlossinExample.py135
1 files changed, 135 insertions, 0 deletions
diff --git a/research/flossing/surrogate_flossing/MinimalGradientFlossinExample.py b/research/flossing/surrogate_flossing/MinimalGradientFlossinExample.py
new file mode 100644
index 0000000..7c8314e
--- /dev/null
+++ b/research/flossing/surrogate_flossing/MinimalGradientFlossinExample.py
@@ -0,0 +1,135 @@
+import torch
+import torch.optim as optim
+import numpy as np
+import matplotlib.pyplot as plt
+
+class VanillaRNN(torch.nn.Module):
+ def __init__(self, Nin, N, Nout):
+ super(VanillaRNN, self).__init__()
+ self.N = N
+ self.W_in = torch.nn.Parameter(torch.randn(N, Nin) / np.sqrt(Nin))
+ self.W_h = torch.nn.Parameter(torch.randn(N, N) / np.sqrt(N))
+ self.b_h = torch.nn.Parameter(torch.zeros(N))
+ self.W_out = torch.nn.Parameter(torch.randn(Nout, N) / np.sqrt(N))
+ self.b_out = torch.nn.Parameter(torch.zeros(Nout))
+
+ def forward(self, x, h_prev):
+ h = torch.tanh(self.W_in @ x + self.W_h @ h_prev + self.b_h)
+ y = self.W_out @ h + self.b_out
+ return h, y
+
+def calculate_jacobian_analytical(vanilla_rnn, h):
+ tanh_prime = 1 / torch.cosh(h)**2
+ jacobian = vanilla_rnn.W_h @ torch.diag(tanh_prime)
+ return jacobian
+
+def calculate_lyapunov_spectrum(vanilla_rnn, x_data, nle, seedIC=1, seedONS=1):
+ n = vanilla_rnn.N
+ steps = len(x_data)
+ ONSstep = 1
+
+ torch.manual_seed(seedIC)
+ h = torch.zeros(n, dtype=torch.float32, requires_grad=True)
+
+ torch.manual_seed(seedONS)
+ Q, R = torch.linalg.qr(torch.randn(n, nle))
+ ls = torch.zeros(nle, dtype=torch.float32)
+
+ for step in range(steps):
+ x = x_data[step]
+ D = calculate_jacobian_analytical(vanilla_rnn, h)
+ h, _ = vanilla_rnn(x, h)
+ Q = D @ Q
+
+ if step % ONSstep == 0 and nle > 0:
+ Q, R = torch.linalg.qr(Q)
+ ls += torch.log(torch.abs(torch.diag(R))) / ONSstep
+
+ return ls
+
+# Model parameters
+Nin = 1
+nb_steps = 1024
+nb_hidden = 64
+Nout = 1
+nle = 16 # number of Lyapunov exponents to floss
+Ef = 41 # number of flossing epochs
+
+# Initialize the RNN
+vanilla_rnn = VanillaRNN(Nin, nb_hidden, Nout)
+optimizer = optim.Adam(vanilla_rnn.parameters())
+
+# Generate input data
+pIn = 0.5 # input probability
+x_data = [torch.tensor(np.random.rand(Nin) < pIn, dtype=torch.float32) for _ in range(nb_steps)]
+
+# Optimization setup
+losses = []
+lyapunov_spectra = []
+
+# Set up the initial plot
+fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
+
+lyapunov_line, = ax1.plot([], [], "k", label="Lyapunov spectrum after flossing")
+ax1.set_xlabel(r"Index $i$")
+ax1.set_ylabel(r"Lyapunov Exponent $\lambda_i$ (1/step)")
+ax1.legend()
+
+loss_line, = ax2.semilogy([], [], "r", label="Loss")
+ax2.set_xlabel("Epoch")
+ax2.set_ylabel("Loss")
+ax2.set_title("Loss over Epochs")
+ax2.legend()
+
+plt.ion() # Turn on interactive mode
+plt.show()
+
+lyapunov_spectrum_initial = calculate_lyapunov_spectrum(vanilla_rnn, x_data[len(x_data)//2:], nle).detach().cpu().numpy()
+ax1.plot(lyapunov_spectrum_initial, "r", label="Lyapunov spectrum before flossing")
+ax1.legend()
+
+# Training loop
+for epoch in range(Ef):
+ torch.manual_seed(epoch)
+ x_data = [torch.tensor(np.random.rand(Nin) < pIn, dtype=torch.float32) for _ in range(nb_steps)]
+
+ optimizer.zero_grad()
+
+ lyapunov_spectrum = calculate_lyapunov_spectrum(vanilla_rnn, x_data[len(x_data) // 2:], nle)
+
+ # Calculate the loss (mean square of the first nle Lyapunov exponents)
+ loss = torch.mean(lyapunov_spectrum**2)
+ print(f"Epoch {epoch}: Loss = {loss.item()}")
+
+ # Backward pass: compute gradients
+ loss.backward()
+
+ # Optimization step
+ optimizer.step()
+
+ # Store the loss and Lyapunov spectrum
+ losses.append(loss.item())
+ lyapunov_spectra.append(lyapunov_spectrum.detach().cpu().numpy())
+
+ if epoch % 10 == 0:
+ print(f"Epoch {epoch}: Loss = {loss.item()}")
+
+ # Update the Lyapunov spectrum plot
+ lyapunov_line.set_ydata(lyapunov_spectrum.detach().cpu().numpy())
+ lyapunov_line.set_xdata(range(len(lyapunov_spectrum)))
+ ax1.relim()
+ ax1.autoscale_view()
+ ax1.legend()
+
+ # Update the loss plot
+ loss_line.set_ydata(losses)
+ loss_line.set_xdata(range(len(losses)))
+ ax2.relim()
+ ax2.autoscale_view()
+
+ # Draw the updated plots
+ fig.canvas.draw()
+ fig.canvas.flush_events()
+
+print("Flossing complete.")
+