""" Stability monitoring utilities for SNN training. Provides metrics to diagnose training stability: - Lyapunov exponent (trajectory divergence) - Gradient norms (vanishing/exploding) - Firing rates (dead/saturated neurons) - Membrane potential statistics """ from typing import Dict, List, Optional, Tuple from dataclasses import dataclass, field import torch import torch.nn as nn import numpy as np @dataclass class StabilityMetrics: """Container for stability measurements.""" lyapunov: Optional[float] = None grad_norm: Optional[float] = None grad_max_sv: Optional[float] = None # Max singular value of gradients grad_min_sv: Optional[float] = None # Min singular value of gradients grad_condition: Optional[float] = None # Condition number (max_sv / min_sv) firing_rate_mean: Optional[float] = None firing_rate_std: Optional[float] = None dead_neuron_frac: Optional[float] = None saturated_neuron_frac: Optional[float] = None membrane_mean: Optional[float] = None membrane_std: Optional[float] = None def to_dict(self) -> Dict[str, float]: return {k: v for k, v in self.__dict__.items() if v is not None} def __str__(self) -> str: parts = [] if self.lyapunov is not None: parts.append(f"λ={self.lyapunov:.4f}") if self.grad_norm is not None: parts.append(f"∇={self.grad_norm:.4f}") if self.firing_rate_mean is not None: parts.append(f"fr={self.firing_rate_mean:.3f}±{self.firing_rate_std:.3f}") if self.dead_neuron_frac is not None: parts.append(f"dead={self.dead_neuron_frac:.1%}") if self.saturated_neuron_frac is not None: parts.append(f"sat={self.saturated_neuron_frac:.1%}") return " | ".join(parts) class StabilityMonitor: """ Monitor SNN stability during training. Usage: monitor = StabilityMonitor() # During training logits, lyap = model(x, compute_lyapunov=True) loss.backward() metrics = monitor.compute( model=model, lyapunov=lyap, spikes=spike_recordings, # optional membrane=membrane_recordings, # optional ) print(metrics) """ def __init__( self, dead_threshold: float = 0.01, saturated_threshold: float = 0.9, history_size: int = 100, ): """ Args: dead_threshold: Firing rate below this = dead neuron saturated_threshold: Firing rate above this = saturated neuron history_size: Number of batches to track for moving averages """ self.dead_threshold = dead_threshold self.saturated_threshold = saturated_threshold self.history_size = history_size # History tracking self.lyap_history: List[float] = [] self.grad_history: List[float] = [] self.fr_history: List[float] = [] def compute_gradient_norm(self, model: nn.Module) -> float: """Compute total gradient norm across all parameters.""" total_norm = 0.0 for p in model.parameters(): if p.grad is not None: total_norm += p.grad.data.norm(2).item() ** 2 return total_norm ** 0.5 def compute_gradient_singular_values( self, model: nn.Module, top_k: int = 5, ) -> Dict[str, Tuple[np.ndarray, float]]: """ Compute singular values of gradient matrices. Per Gradient Flossing paper: singular value spectrum reveals gradient pathologies (vanishing/exploding/rank collapse). Args: model: The SNN model top_k: Number of top/bottom singular values to return Returns: Dict mapping layer name to (singular_values, condition_number) """ results = {} for name, param in model.named_parameters(): if param.grad is not None and param.ndim == 2: # Only compute for weight matrices (2D) with torch.no_grad(): G = param.grad.detach().cpu() try: # Full SVD is expensive; use truncated for large matrices if G.shape[0] * G.shape[1] > 1e6: # For very large matrices, just compute extremes U, S, V = torch.svd_lowrank(G, q=min(top_k, min(G.shape))) sv = S.numpy() else: sv = torch.linalg.svdvals(G).numpy() max_sv = sv[0] if len(sv) > 0 else 0 min_sv = sv[-1] if len(sv) > 0 else 0 condition = max_sv / (min_sv + 1e-12) results[name] = (sv[:top_k], condition) except Exception: pass # Skip if SVD fails return results def get_aggregate_gradient_sv(self, model: nn.Module) -> Tuple[float, float, float]: """ Get aggregate gradient singular value statistics. Returns: (max_sv, min_sv, avg_condition_number) across all layers """ sv_results = self.compute_gradient_singular_values(model) if not sv_results: return 0.0, 0.0, 1.0 max_svs = [] min_svs = [] conditions = [] for name, (sv, cond) in sv_results.items(): if len(sv) > 0: max_svs.append(sv[0]) min_svs.append(sv[-1]) conditions.append(cond) if not max_svs: return 0.0, 0.0, 1.0 return ( float(np.max(max_svs)), float(np.min(min_svs)), float(np.mean(conditions)) ) def compute_firing_stats( self, spikes: torch.Tensor, ) -> Tuple[float, float, float, float]: """ Compute firing rate statistics. Args: spikes: Spike tensor, shape (B, T, H) or (T, H) or (B, H) Values should be 0/1. Returns: (mean_rate, std_rate, dead_frac, saturated_frac) """ with torch.no_grad(): # Flatten to (num_samples, num_neurons) if needed if spikes.ndim == 3: # (B, T, H) -> compute rate per neuron per sample rates = spikes.float().mean(dim=1) # (B, H) elif spikes.ndim == 2: # Could be (T, H) or (B, H) - assume (T, H) for single sample rates = spikes.float().mean(dim=0, keepdim=True) # (1, H) else: rates = spikes.float().unsqueeze(0) # Per-neuron average rate across batch neuron_rates = rates.mean(dim=0) # (H,) mean_rate = neuron_rates.mean().item() std_rate = neuron_rates.std().item() dead_frac = (neuron_rates < self.dead_threshold).float().mean().item() saturated_frac = (neuron_rates > self.saturated_threshold).float().mean().item() return mean_rate, std_rate, dead_frac, saturated_frac def compute_membrane_stats( self, membrane: torch.Tensor, ) -> Tuple[float, float]: """ Compute membrane potential statistics. Args: membrane: Membrane potential tensor, any shape Returns: (mean, std) """ with torch.no_grad(): return membrane.mean().item(), membrane.std().item() def compute( self, model: nn.Module, lyapunov: Optional[torch.Tensor] = None, spikes: Optional[torch.Tensor] = None, membrane: Optional[torch.Tensor] = None, compute_sv: bool = False, ) -> StabilityMetrics: """ Compute all available stability metrics. Args: model: The SNN model (for gradient norms) lyapunov: Lyapunov exponent from forward pass spikes: Recorded spikes (optional) membrane: Recorded membrane potentials (optional) compute_sv: Whether to compute gradient singular values (expensive) Returns: StabilityMetrics object """ metrics = StabilityMetrics() # Lyapunov exponent if lyapunov is not None: if isinstance(lyapunov, torch.Tensor): lyapunov = lyapunov.item() metrics.lyapunov = lyapunov self.lyap_history.append(lyapunov) if len(self.lyap_history) > self.history_size: self.lyap_history.pop(0) # Gradient norm grad_norm = self.compute_gradient_norm(model) metrics.grad_norm = grad_norm self.grad_history.append(grad_norm) if len(self.grad_history) > self.history_size: self.grad_history.pop(0) # Gradient singular values (optional, expensive) if compute_sv: max_sv, min_sv, avg_cond = self.get_aggregate_gradient_sv(model) metrics.grad_max_sv = max_sv metrics.grad_min_sv = min_sv metrics.grad_condition = avg_cond # Firing rate statistics if spikes is not None: fr_mean, fr_std, dead_frac, sat_frac = self.compute_firing_stats(spikes) metrics.firing_rate_mean = fr_mean metrics.firing_rate_std = fr_std metrics.dead_neuron_frac = dead_frac metrics.saturated_neuron_frac = sat_frac self.fr_history.append(fr_mean) if len(self.fr_history) > self.history_size: self.fr_history.pop(0) # Membrane potential statistics if membrane is not None: mem_mean, mem_std = self.compute_membrane_stats(membrane) metrics.membrane_mean = mem_mean metrics.membrane_std = mem_std return metrics def get_trends(self) -> Dict[str, str]: """ Analyze trends in stability metrics. Returns: Dictionary with trend analysis for each metric. """ trends = {} if len(self.lyap_history) >= 10: recent = np.mean(self.lyap_history[-10:]) older = np.mean(self.lyap_history[:10]) if recent > older + 0.1: trends["lyapunov"] = "⚠️ INCREASING (becoming unstable)" elif recent < older - 0.1: trends["lyapunov"] = "✓ DECREASING (stabilizing)" else: trends["lyapunov"] = "→ STABLE" if len(self.grad_history) >= 10: recent = np.mean(self.grad_history[-10:]) older = np.mean(self.grad_history[:10]) ratio = recent / (older + 1e-8) if ratio > 10: trends["gradients"] = "⚠️ EXPLODING" elif ratio < 0.1: trends["gradients"] = "⚠️ VANISHING" else: trends["gradients"] = "✓ STABLE" if len(self.fr_history) >= 10: recent = np.mean(self.fr_history[-10:]) if recent < 0.01: trends["firing"] = "⚠️ DEAD NETWORK" elif recent > 0.8: trends["firing"] = "⚠️ SATURATED" else: trends["firing"] = "✓ HEALTHY" return trends def diagnose(self) -> str: """Generate a diagnostic summary.""" trends = self.get_trends() lines = ["=== Stability Diagnosis ==="] if self.lyap_history: avg_lyap = np.mean(self.lyap_history[-20:]) lines.append(f"Lyapunov exponent: {avg_lyap:.4f}") if avg_lyap > 0.5: lines.append(" → Network is CHAOTIC (trajectories diverge quickly)") lines.append(" → Suggestion: Increase lambda_reg or decrease learning rate") elif avg_lyap < -0.5: lines.append(" → Network is OVER-STABLE (trajectories collapse)") lines.append(" → Suggestion: May lose expressiveness, consider reducing regularization") else: lines.append(" → Network is at EDGE OF CHAOS (good for learning)") if self.grad_history: avg_grad = np.mean(self.grad_history[-20:]) max_grad = max(self.grad_history[-20:]) lines.append(f"Gradient norm: avg={avg_grad:.4f}, max={max_grad:.4f}") if "gradients" in trends: lines.append(f" → {trends['gradients']}") if self.fr_history: avg_fr = np.mean(self.fr_history[-20:]) lines.append(f"Firing rate: {avg_fr:.4f}") if "firing" in trends: lines.append(f" → {trends['firing']}") return "\n".join(lines) def compute_spectral_radius(weight_matrix: torch.Tensor) -> float: """ Compute spectral radius of a weight matrix. For recurrent networks, spectral radius > 1 indicates potential instability. Args: weight_matrix: 2D weight tensor Returns: Spectral radius (largest absolute eigenvalue) """ with torch.no_grad(): W = weight_matrix.detach().cpu().numpy() eigenvalues = np.linalg.eigvals(W) return float(np.max(np.abs(eigenvalues))) def analyze_weight_spectrum(model: nn.Module) -> Dict[str, float]: """ Analyze spectral properties of all weight matrices. Returns: Dictionary mapping layer names to spectral radii. """ results = {} for name, param in model.named_parameters(): if "weight" in name and param.ndim == 2: if param.shape[0] == param.shape[1]: # Square matrix (recurrent) results[name] = compute_spectral_radius(param) return results