From 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Sat, 13 Jun 2026 12:35:36 -0500 Subject: rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipeline Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 --- .../MinimalGradientFlossinExample.py | 135 +++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 research/flossing/surrogate_flossing/MinimalGradientFlossinExample.py (limited to 'research/flossing/surrogate_flossing/MinimalGradientFlossinExample.py') diff --git a/research/flossing/surrogate_flossing/MinimalGradientFlossinExample.py b/research/flossing/surrogate_flossing/MinimalGradientFlossinExample.py new file mode 100644 index 0000000..7c8314e --- /dev/null +++ b/research/flossing/surrogate_flossing/MinimalGradientFlossinExample.py @@ -0,0 +1,135 @@ +import torch +import torch.optim as optim +import numpy as np +import matplotlib.pyplot as plt + +class VanillaRNN(torch.nn.Module): + def __init__(self, Nin, N, Nout): + super(VanillaRNN, self).__init__() + self.N = N + self.W_in = torch.nn.Parameter(torch.randn(N, Nin) / np.sqrt(Nin)) + self.W_h = torch.nn.Parameter(torch.randn(N, N) / np.sqrt(N)) + self.b_h = torch.nn.Parameter(torch.zeros(N)) + self.W_out = torch.nn.Parameter(torch.randn(Nout, N) / np.sqrt(N)) + self.b_out = torch.nn.Parameter(torch.zeros(Nout)) + + def forward(self, x, h_prev): + h = torch.tanh(self.W_in @ x + self.W_h @ h_prev + self.b_h) + y = self.W_out @ h + self.b_out + return h, y + +def calculate_jacobian_analytical(vanilla_rnn, h): + tanh_prime = 1 / torch.cosh(h)**2 + jacobian = vanilla_rnn.W_h @ torch.diag(tanh_prime) + return jacobian + +def calculate_lyapunov_spectrum(vanilla_rnn, x_data, nle, seedIC=1, seedONS=1): + n = vanilla_rnn.N + steps = len(x_data) + ONSstep = 1 + + torch.manual_seed(seedIC) + h = torch.zeros(n, dtype=torch.float32, requires_grad=True) + + torch.manual_seed(seedONS) + Q, R = torch.linalg.qr(torch.randn(n, nle)) + ls = torch.zeros(nle, dtype=torch.float32) + + for step in range(steps): + x = x_data[step] + D = calculate_jacobian_analytical(vanilla_rnn, h) + h, _ = vanilla_rnn(x, h) + Q = D @ Q + + if step % ONSstep == 0 and nle > 0: + Q, R = torch.linalg.qr(Q) + ls += torch.log(torch.abs(torch.diag(R))) / ONSstep + + return ls + +# Model parameters +Nin = 1 +nb_steps = 1024 +nb_hidden = 64 +Nout = 1 +nle = 16 # number of Lyapunov exponents to floss +Ef = 41 # number of flossing epochs + +# Initialize the RNN +vanilla_rnn = VanillaRNN(Nin, nb_hidden, Nout) +optimizer = optim.Adam(vanilla_rnn.parameters()) + +# Generate input data +pIn = 0.5 # input probability +x_data = [torch.tensor(np.random.rand(Nin) < pIn, dtype=torch.float32) for _ in range(nb_steps)] + +# Optimization setup +losses = [] +lyapunov_spectra = [] + +# Set up the initial plot +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5)) + +lyapunov_line, = ax1.plot([], [], "k", label="Lyapunov spectrum after flossing") +ax1.set_xlabel(r"Index $i$") +ax1.set_ylabel(r"Lyapunov Exponent $\lambda_i$ (1/step)") +ax1.legend() + +loss_line, = ax2.semilogy([], [], "r", label="Loss") +ax2.set_xlabel("Epoch") +ax2.set_ylabel("Loss") +ax2.set_title("Loss over Epochs") +ax2.legend() + +plt.ion() # Turn on interactive mode +plt.show() + +lyapunov_spectrum_initial = calculate_lyapunov_spectrum(vanilla_rnn, x_data[len(x_data)//2:], nle).detach().cpu().numpy() +ax1.plot(lyapunov_spectrum_initial, "r", label="Lyapunov spectrum before flossing") +ax1.legend() + +# Training loop +for epoch in range(Ef): + torch.manual_seed(epoch) + x_data = [torch.tensor(np.random.rand(Nin) < pIn, dtype=torch.float32) for _ in range(nb_steps)] + + optimizer.zero_grad() + + lyapunov_spectrum = calculate_lyapunov_spectrum(vanilla_rnn, x_data[len(x_data) // 2:], nle) + + # Calculate the loss (mean square of the first nle Lyapunov exponents) + loss = torch.mean(lyapunov_spectrum**2) + print(f"Epoch {epoch}: Loss = {loss.item()}") + + # Backward pass: compute gradients + loss.backward() + + # Optimization step + optimizer.step() + + # Store the loss and Lyapunov spectrum + losses.append(loss.item()) + lyapunov_spectra.append(lyapunov_spectrum.detach().cpu().numpy()) + + if epoch % 10 == 0: + print(f"Epoch {epoch}: Loss = {loss.item()}") + + # Update the Lyapunov spectrum plot + lyapunov_line.set_ydata(lyapunov_spectrum.detach().cpu().numpy()) + lyapunov_line.set_xdata(range(len(lyapunov_spectrum))) + ax1.relim() + ax1.autoscale_view() + ax1.legend() + + # Update the loss plot + loss_line.set_ydata(losses) + loss_line.set_xdata(range(len(losses))) + ax2.relim() + ax2.autoscale_view() + + # Draw the updated plots + fig.canvas.draw() + fig.canvas.flush_events() + +print("Flossing complete.") + -- cgit v1.2.3