diff options
Diffstat (limited to 'research/flossing/surrogate_flossing')
4 files changed, 471 insertions, 0 deletions
# diff --git a/research/flossing/surrogate_flossing/MinimalGradientFlossinExample.py b/research/flossing/surrogate_flossing/MinimalGradientFlossinExample.py
# new file mode 100644, index 0000000..7c8314e, --- /dev/null, +++ b/research/flossing/surrogate_flossing/MinimalGradientFlossinExample.py, @@ -0,0 +1,135 @@
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt


class VanillaRNN(torch.nn.Module):
    """Single-layer tanh RNN with a linear readout.

    All weight matrices are drawn from a standard normal and divided by the
    square root of their number of columns; both bias vectors start at zero.
    """

    def __init__(self, Nin, N, Nout):
        super(VanillaRNN, self).__init__()
        self.N = N
        # The order of the randn calls is kept so a given torch seed yields
        # the same initial weights as before.
        self.W_in = torch.nn.Parameter(torch.randn(N, Nin) / np.sqrt(Nin))
        self.W_h = torch.nn.Parameter(torch.randn(N, N) / np.sqrt(N))
        self.b_h = torch.nn.Parameter(torch.zeros(N))
        self.W_out = torch.nn.Parameter(torch.randn(Nout, N) / np.sqrt(N))
        self.b_out = torch.nn.Parameter(torch.zeros(Nout))

    def forward(self, x, h_prev):
        """Advance the network by one step.

        Args:
            x: input at this step (assumed 1-D of length Nin -- TODO confirm).
            h_prev: previous hidden state (assumed 1-D of length N).

        Returns:
            Tuple ``(h, y)``: the new hidden state ``tanh(W_in x + W_h h_prev
            + b_h)`` and the readout ``W_out h + b_out``.
        """
        h = torch.tanh(self.W_in @ x + self.W_h @ h_prev + self.b_h)
        y = self.W_out @ h + self.b_out
        return h, y


def calculate_jacobian_analytical(vanilla_rnn, h):
    """Return the one-step Jacobian of the recurrent dynamics at state ``h``.

    ``h`` is the hidden state returned by ``VanillaRNN.forward``, i.e. it is
    already the output of ``tanh``. Writing ``h = tanh(u)``, the recurrence in
    pre-activation coordinates is ``u' = W_h tanh(u) + W_in x + b_h`` and its
    Jacobian is ``W_h diag(1 - tanh(u)**2) = W_h diag(1 - h**2)``.

    The previous implementation used ``1 / cosh(h)**2``, which is the tanh
    derivative evaluated at the *post*-activation value and therefore does not
    correspond to the map implemented by ``forward``.

    Args:
        vanilla_rnn: object exposing the recurrent weight matrix ``W_h``.
        h: current hidden state (1-D tensor with values in (-1, 1)).

    Returns:
        Square tensor ``W_h @ diag(1 - h**2)``.
    """
    tanh_prime = 1 - h**2
    jacobian = vanilla_rnn.W_h @ torch.diag(tanh_prime)
    return jacobian


def calculate_lyapunov_spectrum(vanilla_rnn, x_data, nle, seedIC=1, seedONS=1):
    """Estimate the first ``nle`` Lyapunov exponents along an input-driven run.

    The network is started from the zero state and driven by ``x_data``. An
    orthonormal frame ``Q`` is pushed through the one-step Jacobians and
    re-orthonormalised by QR every ``ONSstep`` steps; the logs of ``|diag(R)|``
    are accumulated and finally averaged, so the result is in units of
    1/step. (Previously the raw sum was returned, which scaled with the
    length of ``x_data``.)

    Args:
        vanilla_rnn: the ``VanillaRNN`` whose dynamics are analysed.
        x_data: sequence of per-step inputs.
        nle: number of exponents to compute.
        seedIC: torch seed set before creating the initial state. The initial
            state is all zeros, so this seed does not change it.
        seedONS: torch seed used for the random initial orthonormal frame.

    Returns:
        1-D float32 tensor of length ``nle``; it is differentiable with
        respect to the network parameters.
    """
    n = vanilla_rnn.N
    steps = len(x_data)
    ONSstep = 1  # re-orthonormalisation interval in steps

    torch.manual_seed(seedIC)
    h = torch.zeros(n, dtype=torch.float32, requires_grad=True)

    torch.manual_seed(seedONS)
    Q, R = torch.linalg.qr(torch.randn(n, nle))
    ls = torch.zeros(nle, dtype=torch.float32)
    n_ons = 0  # number of QR steps that contributed to ls

    for step in range(steps):
        x = x_data[step]
        # The Jacobian is taken at the state *before* the update.
        D = calculate_jacobian_analytical(vanilla_rnn, h)
        h, _ = vanilla_rnn(x, h)
        Q = D @ Q

        if step % ONSstep == 0 and nle > 0:
            Q, R = torch.linalg.qr(Q)
            ls += torch.log(torch.abs(torch.diag(R))) / ONSstep
            n_ons += 1

    # Average over the re-orthonormalisation steps to obtain growth rates per
    # step; skipped for an empty run to avoid dividing by zero.
    if n_ons > 0:
        ls = ls / n_ons
    return ls


# Model parameters
Nin = 1
nb_steps = 1024
nb_hidden = 64
Nout = 1
nle = 16  # number of Lyapunov exponents to floss
Ef = 41  # number of flossing epochs

# Initialize the RNN
vanilla_rnn = VanillaRNN(Nin, nb_hidden, Nout)
optimizer = optim.Adam(vanilla_rnn.parameters())

# Generate input data
pIn = 0.5  # input probability
x_data = [
    torch.tensor(np.random.rand(Nin) < pIn, dtype=torch.float32)
    for _ in range(nb_steps)
]

# Optimization setup
losses = []
lyapunov_spectra = []

# Set up the initial plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

lyapunov_line, = ax1.plot([], [], "k", label="Lyapunov spectrum after flossing")
ax1.set_xlabel(r"Index $i$")
ax1.set_ylabel(r"Lyapunov Exponent $\lambda_i$ (1/step)")
ax1.legend()

loss_line, = ax2.semilogy([], [], "r", label="Loss")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Loss")
ax2.set_title("Loss over Epochs")
ax2.legend()

plt.ion()  # Turn on interactive mode
plt.show()

# Only the second half of the input sequence is used for the estimate.
lyapunov_spectrum_initial = (
    calculate_lyapunov_spectrum(vanilla_rnn, x_data[len(x_data) // 2:], nle)
    .detach()
    .cpu()
    .numpy()
)
ax1.plot(lyapunov_spectrum_initial, "r", label="Lyapunov spectrum before flossing")
ax1.legend()

# Training loop
for epoch in range(Ef):
    # The inputs below are drawn with NumPy, so NumPy has to be seeded as
    # well; seeding torch alone left the per-epoch inputs unseeded.
    torch.manual_seed(epoch)
    np.random.seed(epoch)
    x_data = [
        torch.tensor(np.random.rand(Nin) < pIn, dtype=torch.float32)
        for _ in range(nb_steps)
    ]

    optimizer.zero_grad()

    lyapunov_spectrum = calculate_lyapunov_spectrum(
        vanilla_rnn, x_data[len(x_data) // 2:], nle
    )

    # Calculate the loss (mean square of the first nle Lyapunov exponents)
    loss = torch.mean(lyapunov_spectrum**2)
    print(f"Epoch {epoch}: Loss = {loss.item()}")

    # Backward pass: compute gradients
    loss.backward()

    # Optimization step
    optimizer.step()

    # Store the loss and Lyapunov spectrum
    losses.append(loss.item())
    lyapunov_spectra.append(lyapunov_spectrum.detach().cpu().numpy())

    # NOTE(review): indentation was lost in this diff rendering; the plot
    # refresh is assumed to belong to the every-10-epochs branch. The second,
    # identical loss print that used to live here was dropped as redundant.
    if epoch % 10 == 0:
        # Update the Lyapunov spectrum plot
        lyapunov_line.set_ydata(lyapunov_spectrum.detach().cpu().numpy())
        lyapunov_line.set_xdata(range(len(lyapunov_spectrum)))
        ax1.relim()
        ax1.autoscale_view()
        ax1.legend()

        # Update the loss plot
        loss_line.set_ydata(losses)
        loss_line.set_xdata(range(len(losses)))
        ax2.relim()
        ax2.autoscale_view()

        # Draw the updated plots
        fig.canvas.draw()
        fig.canvas.flush_events()

print("Flossing complete.")

# diff --git a/research/flossing/surrogate_flossing/MinimalSurrogateGradientFlossinExample.py b/research/flossing/surrogate_flossing/MinimalSurrogateGradientFlossinExample.py
# new file mode 100644, index 0000000..678112c, --- /dev/null, +++ b/research/flossing/surrogate_flossing/MinimalSurrogateGradientFlossinExample.py, @@ -0,0 +1,228 @@
import torch
import torch.optim as optim
import numpy as np


class SurrGradSpike(torch.autograd.Function):
    """Spiking nonlinearity with a surrogate gradient.

    The forward pass is a step function; the backward pass substitutes the
    normalized negative part of a fast sigmoid, as in Zenke & Ganguli (2018).
    Subclassing torch.autograd.Function keeps the rest of PyTorch's autograd
    machinery usable.
    """

    scale = 10.0  # controls steepness of surrogate gradient

    @staticmethod
    def forward(ctx, inp):
        """Return 1.0 where ``inp > 0`` and 0.0 elsewhere.

        ``inp`` is stashed on ``ctx`` with ``save_for_backward`` because the
        surrogate gradient in ``backward`` is a function of it.
        """
        ctx.save_for_backward(inp)
        out = torch.zeros_like(inp)
        out[inp > 0] = 1.0
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Scale the incoming gradient by ``1 / (scale * |inp| + 1)**2``."""
        inp, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = grad_input / (SurrGradSpike.scale * torch.abs(inp) + 1.0) ** 2
        return grad


# here we overwrite our naive spike function by the "SurrGradSpike"
# nonlinearity which implements a surrogate gradient
spike_fn = SurrGradSpike.apply


def sgprime(x, g):
    """Surrogate derivative ``1 / (g * |x| + 1)**2`` with steepness ``g``."""
    return 1 / (g * torch.abs(x) + 1) ** 2


def manual_jacobian_clif15NoReset(mem, syn, x, alpha, beta, v1, iext, gj):
    """Build the surrogate one-step Jacobian of the (mem, syn) update.

    The spike nonlinearity is differentiated with ``sgprime`` at ``mem - 1``
    using steepness ``gj``. The reset term applied in
    ``neuron_update7hardOut`` is not represented here.

    Args:
        mem: membrane state, 1-D tensor of length n.
        syn: synaptic state (unused; kept for a uniform signature).
        x: input current (unused).
        alpha: synaptic decay factor.
        beta: membrane decay factor.
        v1: recurrent weight matrix (n x n).
        iext: external current (unused).
        gj: steepness of the surrogate derivative.

    Returns:
        ``(2n, 2n)`` tensor laid out as
        ``[[d mem'/d mem, d mem'/d syn], [d syn'/d mem, d syn'/d syn]]``.
    """
    n = len(mem)
    mthr = mem - 1
    out_derivative = sgprime(mthr, g=gj)

    eye_n = torch.eye(n, dtype=torch.float32, device=mem.device)
    J_dmem_dmem = beta * eye_n + v1 @ torch.diag(out_derivative * (1 - beta))
    J_dsyn_dmem = v1 @ torch.diag(out_derivative)
    J_dmem_dsyn = (1 - beta) * alpha * eye_n
    J_dsyn_dsyn = alpha * eye_n

    # Assemble the full Jacobian from its four n x n blocks.
    J_upper = torch.cat([J_dmem_dmem, J_dmem_dsyn], dim=1)
    J_lower = torch.cat([J_dsyn_dmem, J_dsyn_dsyn], dim=1)
    J = torch.cat([J_upper, J_lower], dim=0)

    return J


def neuron_update7hardOut(mem, syn, x, gu):
    """Advance the membrane and synaptic state by one time step.

    Reads the module-level ``alpha``, ``beta`` and ``v1``. ``gu`` is accepted
    but not used by the update.

    Returns:
        Tuple ``(mem, syn)`` with the updated states.
    """
    mthr = mem - 1
    out = spike_fn(mthr)
    # like spytorch tutorial: https://github.com/fzenke/spytorch/blob/main/notebooks/SpyTorchTutorial5.ipynb
    rst = out.detach()
    syn = alpha * syn + x + v1 @ out

    # The membrane integrates the already-updated synaptic state and is
    # zeroed wherever a spike was emitted.
    mem = (beta * mem + (1 - beta) * syn) * (1 - rst)

    return mem, syn


def calculate_lyapunov_spectrum(gForward, gBackward, x_data, nle, Win, v1, g):
    """Estimate the first ``nle`` surrogate Lyapunov exponents.

    The network starts from zero ``mem``/``syn`` and is driven by
    ``Win @ x`` for every ``x`` in ``x_data``. An orthonormal frame is pushed
    through the surrogate Jacobians and re-orthonormalised by QR every
    ``ONSstep`` steps; the logs of ``|diag(R)|`` are averaged over those
    steps, so the result is a per-step growth rate. (Previously the raw sum
    was returned, which scaled with the length of ``x_data``.)

    Reads the module-level ``nb_hidden``, ``seedIC``, ``seedONS``, ``alpha``,
    ``beta`` and ``iext``.

    Args:
        gForward: forwarded to ``neuron_update7hardOut`` (unused there).
        gBackward: surrogate steepness used for the Jacobian.
        x_data: sequence of per-step input vectors.
        nle: number of exponents to compute.
        Win: input weight matrix.
        v1: recurrent weight matrix handed to the Jacobian.
        g: unused; kept for signature compatibility.

    Returns:
        1-D float32 tensor of length ``nle``.
    """
    n = nb_hidden
    steps = len(x_data)
    ONSstep = 1  # re-orthonormalisation interval in steps

    torch.manual_seed(seedIC)
    mem = torch.zeros(n, dtype=torch.float32)
    syn = torch.zeros(n, dtype=torch.float32)

    torch.manual_seed(seedONS)
    Q, R = torch.linalg.qr(torch.randn(2 * n, nle))
    ls = torch.zeros(nle, dtype=torch.float32)
    n_ons = 0  # number of QR steps that contributed to ls

    for step in range(steps):
        x = x_data[step]
        # The Jacobian is taken at the state *before* the update.
        D = manual_jacobian_clif15NoReset(mem, syn, Win @ x, alpha, beta, v1, iext, gBackward)
        mem, syn = neuron_update7hardOut(mem, syn, Win @ x, gForward)
        Q = D @ Q

        if step % ONSstep == 0 and nle > 0:
            Q, R = torch.linalg.qr(Q)
            ls += torch.log(torch.abs(torch.diag(R))) / ONSstep
            n_ons += 1

    # Average over the re-orthonormalisation steps; skipped for an empty run
    # to avoid dividing by zero.
    if n_ons > 0:
        ls = ls / n_ons
    return ls


# Model parameters
tau_mem = 10e-3
tau_syn = 5e-3
time_step = 1e-3
alpha = torch.exp(torch.tensor(-time_step / tau_syn))
beta = torch.exp(torch.tensor(-time_step / tau_mem))

nb_inputs = 100
nb_steps = 1000
nb_hidden = 64
iext = 0

# Initialize parameters to optimize as leaf tensors
Win = torch.randn(nb_hidden, nb_inputs, dtype=torch.float32, requires_grad=True) / torch.sqrt(
    torch.tensor(nb_inputs, dtype=torch.float32)
)
Win = Win - torch.mean(Win, dim=1, keepdim=True)
Win = Win.clone().detach().requires_grad_(True)  # Ensure it is a leaf tensor

v1 = torch.randn(nb_hidden, nb_hidden, dtype=torch.float32, requires_grad=True) / torch.sqrt(
    torch.tensor(nb_hidden, dtype=torch.float32)
)
v1 = v1 - torch.mean(v1, dim=1, keepdim=True)
v1 = v1.clone().detach().requires_grad_(True)  # Ensure it is a leaf tensor

g = torch.tensor(5.0, dtype=torch.float32, requires_grad=False)

# Generate input data
pSpike = 0.1  # for dt=1e-3 this is 10 Hz input
x_data = [
    torch.tensor(np.random.rand(nb_inputs) < pSpike, dtype=torch.float32)
    for _ in range(nb_steps)
]

nle = 50  # 2 * nb_hidden
seedIC = seedONS = 1
subDir = "cLIF_spectrum"

resetSwitch = False

# Optimization setup
# optimizer = optim.SGD([Win, v1, g], lr=1e-3)
optimizer = optim.Adam([Win, v1])  # , g])
import matplotlib.pyplot as plt
from IPython.display import clear_output

lyapunov_spectrum = calculate_lyapunov_spectrum(1e9, g, x_data[len(x_data) // 2:], nle, Win, v1, g)
# Initialize
# lists to store the loss and Lyapunov spectrum for plotting
losses = []
lyapunov_spectra = []

# Set up the initial plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

lyapunov_line, = ax1.plot([], [], "k", label="Lyapunov Spectrum")
ax1.set_xlabel("Index")
ax1.set_ylabel("Lyapunov Exponent")
ax1.set_title("Lyapunov Spectrum")
ax1.legend()

loss_line, = ax2.plot([], [], "r", label="Loss")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Loss")
ax2.set_title("Loss over Epochs")
ax2.legend()

plt.ion()  # Turn on interactive mode
plt.show()

# Draw the spectrum computed before training on the left axes. The axes are
# addressed explicitly instead of via plt.subplot(121), and the stray bare
# `0` statement that followed the plot call was removed.
lyapunov_spectrumInitial = lyapunov_spectrum.detach().cpu().numpy()
ax1.plot(lyapunov_spectrumInitial, "r")

# Training loop
for epoch in range(100):
    # The inputs below are drawn with NumPy, so NumPy has to be seeded as
    # well; seeding torch alone left the per-epoch inputs unseeded.
    torch.manual_seed(epoch)
    np.random.seed(epoch)
    x_data = [
        torch.tensor(np.random.rand(nb_inputs) < pSpike, dtype=torch.float32)
        for _ in range(nb_steps)
    ]

    optimizer.zero_grad()

    # Only the second half of the input sequence is used for the estimate.
    lyapunov_spectrum = calculate_lyapunov_spectrum(
        1e9, g, x_data[len(x_data) // 2:], nle, Win, v1, g
    )

    # Calculate the loss (mean square of the first nle Lyapunov exponents)
    loss = torch.mean(lyapunov_spectrum**2)
    print(f"Epoch {epoch}: Loss = {loss.item()}, Loss dtype = {loss.dtype}")

    # Backward pass: compute gradients
    loss.backward()

    # Optimization step
    optimizer.step()

    # Store the loss and Lyapunov spectrum
    losses.append(loss.item())
    lyapunov_spectra.append(lyapunov_spectrum.detach().cpu().numpy())

    # NOTE(review): indentation was lost in this diff rendering; the plot
    # refresh is assumed to belong to the every-10-epochs branch. The second
    # loss print that used to live here duplicated the one above and was
    # dropped.
    if epoch % 10 == 0:
        # Update the Lyapunov spectrum plot
        lyapunov_line.set_ydata(lyapunov_spectrum.detach().cpu().numpy())
        lyapunov_line.set_xdata(range(len(lyapunov_spectrum)))
        ax1.relim()
        ax1.autoscale_view()

        # Update the loss plot
        loss_line.set_ydata(losses)
        loss_line.set_xdata(range(len(losses)))
        ax2.relim()
        ax2.autoscale_view()

        # Draw the updated plots
        fig.canvas.draw()
        fig.canvas.flush_events()

print("Training complete.")
# diff
--git a/research/flossing/surrogate_flossing/README.md b/research/flossing/surrogate_flossing/README.md new file mode 100644 index 0000000..40f97aa --- /dev/null +++ b/research/flossing/surrogate_flossing/README.md @@ -0,0 +1,106 @@ +# SurrogateGradientFlossing +This repository contains the implementation code for the manuscript: <br> + __Using Dynamical Systems Theory to Improve Surrogate Gradient Learning in Spiking Neural Networks__ <br> + + +## Overview +We analyze and optimize gradients of binary and spiking recurrent neural networks using concepts from dynamical systems theory. Specifically, we show that surrogate gradient training can be improved by pushing surrogate Lyapunov exponents to zero during or before training. + +## Installation + +#### Prerequisites +- Download [Julia](https://julialang.org/downloads/) + +#### Dependencies +- Julia (1.6) +- Flux, BackwardsLinalg + +## Getting started +To install the required packages, run the following in the julia REPL after installing Julia: + +``` +using Pkg + +for pkg in ["Flux", "BackwardsLinalg"] + Pkg.add(pkg) +end +``` + +For example, to train a spiking neural network on the delayed XOR task, run: +``` +include("SurrogateGradientFlossing_ExampleCode.jl") +# setting parameters: +N, E, Ef, Ei, Ep, Ni, B, S, T, Tp, Ti, sIC, sIn, sNet, sONS, lr, b1, b2, IC, g, gbar, I1, delay, wsS, wsM, wrS, wrM, bS, bM, nLE, task, intype, Lwnt= +80, 3001, 100, 500, 500, 2, 16, 1, 300, 55, 300, 1,1,1,1, 0.001f0, 0.9, 0.999, 1, 1.0, 0.0, 1.0,10, 1.0f0, 0.0f0, 1.0f0, 0.0f0, 0.1f0, 0.0f0,75, -1, 3, 0.0 + +trainSRNNflossing(N, E, Ef, Ei, Ep, Ni, B, S, T, Tp, Ti, sIC, sIn, sNet, sONS, lr, b1, b2, IC, g, gbar, I1, delay, wsS, wsM, wrS, wrM, bS, bM, nLE, task, intype, Lwnt) +``` + +## Repository Overview +_GradientFlossing_ExampleCode.jl_:\ +Example scripts for training networks with gradient flossing before training, with gradient flossing before and during training and without gradient flossing. 
+ + +_GradientFlossing_XOR.jl_:\ +Generates input and target output for copy task and delayed XOR task. + +<!--- + +runOneStimulus.jl trains an RNN on tracking one OU-signal showing that the network becomes more tightly balanced over training epochs.\ +runTwoStimuli.jl trains an RNN on two OU-signal stimuli showing that the network becomes more tightly balanced over training epochs and breaks up into two weakly-connected subnetworks.\ +runTheeStimuli.jl trains an RNN on two OU-signal stimuli showing that the network becomes more tightly balanced over training epochs and breaks up into three weakly-connected subnetworks.\ + +--> + + +<!--- + +### Training dynamics of eigenvalues: +Here is a visualization of the recurrent weight matrix and the eigenvalues across training epochs. + +--> + + +### Implementation details +A full specification of packages used and their versions can be found in _packages.txt_ .\ +For learning rates, the default ADAM parameters were used to avoid any impression of fine-tuning.\ +All simulations were run on a single CPU and took on the order of minutes to a few hours. + +## Additional results: +We here provide additional results on surrogate gradient flossing in binary RNNs. The following figure shows that we can manipulate one or several surrogate Lyapunov exponents in binary networks: + +**Figure 1: Surrogate gradient flossing** regularizes *surrogate Lyapunov exponents* and facilitates gradient signal propagation in binary neural networks. + + + + +**A)** The first *surrogate Lyapunov exponent* of a recurrent binary network plotted as a function of training epochs for different surrogate sharpness $g$. The square of the first *surrogate Lyapunov exponent* is minimized using gradient descent. + +**B)** *Surrogate Lyapunov spectrum* of a recurrent binary network after different numbers of Lyapunov exponents $k$ have been driven towards zero via *surrogate gradient flossing* for $k\in\{1,16,32\}$. 
The gray lines show the *surrogate Lyapunov spectra* before *surrogate gradient flossing*. Parameters: network size $N=80$, $g=1$ for **B**. Input as in Fig. 3. The thin semitransparent lines in **A** and **B** indicate nine network realizations; the full lines are their average. + +The following figure shows that surrogate gradient flossing improves training in binary RNNs: + +**Figure 2: Surrogate gradient flossing improves binary RNN training.** + + + + +**A)** Test accuracy for binary RNNs trained on the delayed temporal binary XOR task $y_t=x_{t-d/2} \oplus x_{t-d}$ with *adaptive gradient flossing* during training (orange) and without *gradient flossing* (blue) for $d=18$. Solid lines are the median across 9 network realizations, and individual network realizations are shown in transparent fine lines. + +**B)** Mean final test accuracy as a function of task difficulty (delay $d$) for delayed XOR task. + +**C)** Gradient norm with respect to initial network state $\mathbf{h}_0$. + +**D)** Gradient norm with respect to initial network state as a function of temporal task complexity averaged over training epochs. + + +<!--- +### figures/ +Contains all figures of the main text and the supplement. +--> + + +<!--- +### tex/ +Contains the raw text of the main text and the supplement. +--> diff --git a/research/flossing/surrogate_flossing/figures/binary_surrogate_gradient_flossing.md b/research/flossing/surrogate_flossing/figures/binary_surrogate_gradient_flossing.md new file mode 100644 index 0000000..abc0c63 --- /dev/null +++ b/research/flossing/surrogate_flossing/figures/binary_surrogate_gradient_flossing.md @@ -0,0 +1,2 @@ + +**Figure:** Surrogate gradient flossing regularizes surrogate Lyapunov exponents and facilitates gradient signal propagation in binary neural networks. |
