"""
Stability monitoring utilities for SNN training.
Provides metrics to diagnose training stability:
- Lyapunov exponent (trajectory divergence)
- Gradient norms (vanishing/exploding)
- Firing rates (dead/saturated neurons)
- Membrane potential statistics
"""
from typing import Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
import torch
import torch.nn as nn
import numpy as np
@dataclass
class StabilityMetrics:
"""Container for stability measurements."""
lyapunov: Optional[float] = None
grad_norm: Optional[float] = None
grad_max_sv: Optional[float] = None # Max singular value of gradients
grad_min_sv: Optional[float] = None # Min singular value of gradients
grad_condition: Optional[float] = None # Condition number (max_sv / min_sv)
firing_rate_mean: Optional[float] = None
firing_rate_std: Optional[float] = None
dead_neuron_frac: Optional[float] = None
saturated_neuron_frac: Optional[float] = None
membrane_mean: Optional[float] = None
membrane_std: Optional[float] = None
def to_dict(self) -> Dict[str, float]:
return {k: v for k, v in self.__dict__.items() if v is not None}
def __str__(self) -> str:
parts = []
if self.lyapunov is not None:
parts.append(f"λ={self.lyapunov:.4f}")
if self.grad_norm is not None:
parts.append(f"∇={self.grad_norm:.4f}")
if self.firing_rate_mean is not None:
parts.append(f"fr={self.firing_rate_mean:.3f}±{self.firing_rate_std:.3f}")
if self.dead_neuron_frac is not None:
parts.append(f"dead={self.dead_neuron_frac:.1%}")
if self.saturated_neuron_frac is not None:
parts.append(f"sat={self.saturated_neuron_frac:.1%}")
return " | ".join(parts)
class StabilityMonitor:
"""
Monitor SNN stability during training.
Usage:
monitor = StabilityMonitor()
# During training
logits, lyap = model(x, compute_lyapunov=True)
loss.backward()
metrics = monitor.compute(
model=model,
lyapunov=lyap,
spikes=spike_recordings, # optional
membrane=membrane_recordings, # optional
)
print(metrics)
"""
def __init__(
self,
dead_threshold: float = 0.01,
saturated_threshold: float = 0.9,
history_size: int = 100,
):
"""
Args:
dead_threshold: Firing rate below this = dead neuron
saturated_threshold: Firing rate above this = saturated neuron
history_size: Number of batches to track for moving averages
"""
self.dead_threshold = dead_threshold
self.saturated_threshold = saturated_threshold
self.history_size = history_size
# History tracking
self.lyap_history: List[float] = []
self.grad_history: List[float] = []
self.fr_history: List[float] = []
    def compute_gradient_norm(self, model: nn.Module) -> float:
        """Compute the total L2 gradient norm across all parameters.

        Call this after loss.backward() and before optimizer.zero_grad();
        otherwise all grads are None and the norm is 0.
        """
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                # Accumulate squared per-parameter norms; sqrt at the end
                # gives the global L2 norm.
                total_norm += p.grad.detach().norm(2).item() ** 2
        return total_norm ** 0.5
def compute_gradient_singular_values(
self,
model: nn.Module,
top_k: int = 5,
) -> Dict[str, Tuple[np.ndarray, float]]:
"""
Compute singular values of gradient matrices.
        Per the Gradient Flossing paper, the singular value spectrum
        reveals gradient pathologies (vanishing/exploding/rank collapse).
Args:
model: The SNN model
            top_k: Number of leading singular values to return per layer
Returns:
Dict mapping layer name to (singular_values, condition_number)
"""
results = {}
for name, param in model.named_parameters():
if param.grad is not None and param.ndim == 2:
# Only compute for weight matrices (2D)
with torch.no_grad():
G = param.grad.detach().cpu()
                    try:
                        # Full SVD is expensive; fall back to a randomized
                        # low-rank approximation for large matrices.
                        if G.shape[0] * G.shape[1] > 1e6:
                            # svd_lowrank approximates only the LEADING q
                            # singular values, so sv[-1] is not the true
                            # minimum and the condition number below is a
                            # lower bound in this branch.
                            _, S, _ = torch.svd_lowrank(G, q=min(top_k, min(G.shape)))
                            sv = S.numpy()
                        else:
                            sv = torch.linalg.svdvals(G).numpy()
                        max_sv = sv[0] if len(sv) > 0 else 0.0
                        min_sv = sv[-1] if len(sv) > 0 else 0.0
                        condition = max_sv / (min_sv + 1e-12)
                        results[name] = (sv[:top_k], condition)
                    except RuntimeError:
                        continue  # Skip layers whose SVD fails to converge
return results
def get_aggregate_gradient_sv(self, model: nn.Module) -> Tuple[float, float, float]:
"""
Get aggregate gradient singular value statistics.
Returns:
(max_sv, min_sv, avg_condition_number) across all layers
"""
sv_results = self.compute_gradient_singular_values(model)
if not sv_results:
return 0.0, 0.0, 1.0
max_svs = []
min_svs = []
conditions = []
        for sv, cond in sv_results.values():
if len(sv) > 0:
max_svs.append(sv[0])
min_svs.append(sv[-1])
conditions.append(cond)
if not max_svs:
return 0.0, 0.0, 1.0
return (
float(np.max(max_svs)),
float(np.min(min_svs)),
float(np.mean(conditions))
)
def compute_firing_stats(
self,
spikes: torch.Tensor,
) -> Tuple[float, float, float, float]:
"""
Compute firing rate statistics.
Args:
spikes: Spike tensor, shape (B, T, H) or (T, H) or (B, H)
Values should be 0/1.
Returns:
(mean_rate, std_rate, dead_frac, saturated_frac)
"""
with torch.no_grad():
# Flatten to (num_samples, num_neurons) if needed
if spikes.ndim == 3:
# (B, T, H) -> compute rate per neuron per sample
rates = spikes.float().mean(dim=1) # (B, H)
elif spikes.ndim == 2:
# Could be (T, H) or (B, H) - assume (T, H) for single sample
rates = spikes.float().mean(dim=0, keepdim=True) # (1, H)
else:
rates = spikes.float().unsqueeze(0)
# Per-neuron average rate across batch
neuron_rates = rates.mean(dim=0) # (H,)
mean_rate = neuron_rates.mean().item()
std_rate = neuron_rates.std().item()
dead_frac = (neuron_rates < self.dead_threshold).float().mean().item()
saturated_frac = (neuron_rates > self.saturated_threshold).float().mean().item()
return mean_rate, std_rate, dead_frac, saturated_frac
def compute_membrane_stats(
self,
membrane: torch.Tensor,
) -> Tuple[float, float]:
"""
Compute membrane potential statistics.
Args:
membrane: Membrane potential tensor, any shape
Returns:
(mean, std)
"""
with torch.no_grad():
return membrane.mean().item(), membrane.std().item()
def compute(
self,
model: nn.Module,
        lyapunov: Optional[Union[torch.Tensor, float]] = None,
spikes: Optional[torch.Tensor] = None,
membrane: Optional[torch.Tensor] = None,
compute_sv: bool = False,
) -> StabilityMetrics:
"""
Compute all available stability metrics.
Args:
model: The SNN model (for gradient norms)
lyapunov: Lyapunov exponent from forward pass
spikes: Recorded spikes (optional)
membrane: Recorded membrane potentials (optional)
compute_sv: Whether to compute gradient singular values (expensive)
Returns:
StabilityMetrics object
"""
metrics = StabilityMetrics()
# Lyapunov exponent
if lyapunov is not None:
if isinstance(lyapunov, torch.Tensor):
lyapunov = lyapunov.item()
metrics.lyapunov = lyapunov
self.lyap_history.append(lyapunov)
if len(self.lyap_history) > self.history_size:
self.lyap_history.pop(0)
# Gradient norm
grad_norm = self.compute_gradient_norm(model)
metrics.grad_norm = grad_norm
self.grad_history.append(grad_norm)
if len(self.grad_history) > self.history_size:
self.grad_history.pop(0)
# Gradient singular values (optional, expensive)
if compute_sv:
max_sv, min_sv, avg_cond = self.get_aggregate_gradient_sv(model)
metrics.grad_max_sv = max_sv
metrics.grad_min_sv = min_sv
metrics.grad_condition = avg_cond
# Firing rate statistics
if spikes is not None:
fr_mean, fr_std, dead_frac, sat_frac = self.compute_firing_stats(spikes)
metrics.firing_rate_mean = fr_mean
metrics.firing_rate_std = fr_std
metrics.dead_neuron_frac = dead_frac
metrics.saturated_neuron_frac = sat_frac
self.fr_history.append(fr_mean)
if len(self.fr_history) > self.history_size:
self.fr_history.pop(0)
# Membrane potential statistics
if membrane is not None:
mem_mean, mem_std = self.compute_membrane_stats(membrane)
metrics.membrane_mean = mem_mean
metrics.membrane_std = mem_std
return metrics
def get_trends(self) -> Dict[str, str]:
"""
Analyze trends in stability metrics.
Returns:
Dictionary with trend analysis for each metric.
"""
trends = {}
if len(self.lyap_history) >= 10:
recent = np.mean(self.lyap_history[-10:])
older = np.mean(self.lyap_history[:10])
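            # History is capped at history_size, so "older" means the oldest
            # retained window within the cap, not the start of training.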
if recent > older + 0.1:
trends["lyapunov"] = "⚠️ INCREASING (becoming unstable)"
elif recent < older - 0.1:
trends["lyapunov"] = "✓ DECREASING (stabilizing)"
else:
trends["lyapunov"] = "→ STABLE"
if len(self.grad_history) >= 10:
recent = np.mean(self.grad_history[-10:])
older = np.mean(self.grad_history[:10])
ratio = recent / (older + 1e-8)
if ratio > 10:
trends["gradients"] = "⚠️ EXPLODING"
elif ratio < 0.1:
trends["gradients"] = "⚠️ VANISHING"
else:
trends["gradients"] = "✓ STABLE"
if len(self.fr_history) >= 10:
recent = np.mean(self.fr_history[-10:])
if recent < 0.01:
trends["firing"] = "⚠️ DEAD NETWORK"
elif recent > 0.8:
trends["firing"] = "⚠️ SATURATED"
else:
trends["firing"] = "✓ HEALTHY"
return trends
def diagnose(self) -> str:
"""Generate a diagnostic summary."""
trends = self.get_trends()
lines = ["=== Stability Diagnosis ==="]
if self.lyap_history:
avg_lyap = np.mean(self.lyap_history[-20:])
lines.append(f"Lyapunov exponent: {avg_lyap:.4f}")
if avg_lyap > 0.5:
lines.append(" → Network is CHAOTIC (trajectories diverge quickly)")
lines.append(" → Suggestion: Increase lambda_reg or decrease learning rate")
elif avg_lyap < -0.5:
lines.append(" → Network is OVER-STABLE (trajectories collapse)")
lines.append(" → Suggestion: May lose expressiveness, consider reducing regularization")
else:
lines.append(" → Network is at EDGE OF CHAOS (good for learning)")
if self.grad_history:
avg_grad = np.mean(self.grad_history[-20:])
max_grad = max(self.grad_history[-20:])
lines.append(f"Gradient norm: avg={avg_grad:.4f}, max={max_grad:.4f}")
if "gradients" in trends:
lines.append(f" → {trends['gradients']}")
if self.fr_history:
avg_fr = np.mean(self.fr_history[-20:])
lines.append(f"Firing rate: {avg_fr:.4f}")
if "firing" in trends:
lines.append(f" → {trends['firing']}")
return "\n".join(lines)
def compute_spectral_radius(weight_matrix: torch.Tensor) -> float:
"""
Compute spectral radius of a weight matrix.
For recurrent networks, spectral radius > 1 indicates potential instability.
Args:
        weight_matrix: Square 2D weight tensor (eigenvalues require a square matrix)
Returns:
Spectral radius (largest absolute eigenvalue)
"""
with torch.no_grad():
W = weight_matrix.detach().cpu().numpy()
eigenvalues = np.linalg.eigvals(W)
return float(np.max(np.abs(eigenvalues)))
def analyze_weight_spectrum(model: nn.Module) -> Dict[str, float]:
"""
Analyze spectral properties of all weight matrices.
Returns:
Dictionary mapping layer names to spectral radii.
"""
results = {}
for name, param in model.named_parameters():
if "weight" in name and param.ndim == 2:
if param.shape[0] == param.shape[1]: # Square matrix (recurrent)
results[name] = compute_spectral_radius(param)
return results
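
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative assumption, not part of the module
# API): the toy nn.Sequential model, tensor shapes, and the hard-coded
# Lyapunov value below are stand-ins chosen only to exercise the monitor
# end-to-end. A real SNN would supply its own spike/membrane recordings and
# a Lyapunov estimate from its forward pass.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # Toy "network": the second layer is square so that
    # analyze_weight_spectrum has a recurrent-style matrix to report on.
    model = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 64))
    x = torch.randn(8, 32)
    model(x).pow(2).mean().backward()  # populate gradients

    monitor = StabilityMonitor()
    # Fake recordings with shape (B, T, H) = (8, 16, 64).
    spikes = (torch.rand(8, 16, 64) < 0.2).float()
    membrane = torch.randn(8, 16, 64)

    metrics = monitor.compute(
        model=model,
        lyapunov=torch.tensor(0.05),  # stand-in for a real estimate
        spikes=spikes,
        membrane=membrane,
        compute_sv=True,
    )
    print(metrics)
    print(monitor.diagnose())
    print(analyze_weight_spectrum(model))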