diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:49:05 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:49:05 -0600 |
| commit | cd99d6b874d9d09b3bb87b8485cc787885af71f1 (patch) | |
| tree | 59a233959932ca0e4f12f196275e07fcf443b33f /files/analysis/stability_monitor.py | |
init commit
Diffstat (limited to 'files/analysis/stability_monitor.py')
| -rw-r--r-- | files/analysis/stability_monitor.py | 395 |
1 files changed, 395 insertions, 0 deletions
diff --git a/files/analysis/stability_monitor.py b/files/analysis/stability_monitor.py new file mode 100644 index 0000000..18cad0f --- /dev/null +++ b/files/analysis/stability_monitor.py @@ -0,0 +1,395 @@ +""" +Stability monitoring utilities for SNN training. + +Provides metrics to diagnose training stability: +- Lyapunov exponent (trajectory divergence) +- Gradient norms (vanishing/exploding) +- Firing rates (dead/saturated neurons) +- Membrane potential statistics +""" + +from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import numpy as np + + +@dataclass +class StabilityMetrics: + """Container for stability measurements.""" + lyapunov: Optional[float] = None + grad_norm: Optional[float] = None + grad_max_sv: Optional[float] = None # Max singular value of gradients + grad_min_sv: Optional[float] = None # Min singular value of gradients + grad_condition: Optional[float] = None # Condition number (max_sv / min_sv) + firing_rate_mean: Optional[float] = None + firing_rate_std: Optional[float] = None + dead_neuron_frac: Optional[float] = None + saturated_neuron_frac: Optional[float] = None + membrane_mean: Optional[float] = None + membrane_std: Optional[float] = None + + def to_dict(self) -> Dict[str, float]: + return {k: v for k, v in self.__dict__.items() if v is not None} + + def __str__(self) -> str: + parts = [] + if self.lyapunov is not None: + parts.append(f"λ={self.lyapunov:.4f}") + if self.grad_norm is not None: + parts.append(f"∇={self.grad_norm:.4f}") + if self.firing_rate_mean is not None: + parts.append(f"fr={self.firing_rate_mean:.3f}±{self.firing_rate_std:.3f}") + if self.dead_neuron_frac is not None: + parts.append(f"dead={self.dead_neuron_frac:.1%}") + if self.saturated_neuron_frac is not None: + parts.append(f"sat={self.saturated_neuron_frac:.1%}") + return " | ".join(parts) + + +class StabilityMonitor: + """ + Monitor SNN stability during training. + + Usage: + monitor = StabilityMonitor() + + # During training + logits, lyap = model(x, compute_lyapunov=True) + loss.backward() + + metrics = monitor.compute( + model=model, + lyapunov=lyap, + spikes=spike_recordings, # optional + membrane=membrane_recordings, # optional + ) + print(metrics) + """ + + def __init__( + self, + dead_threshold: float = 0.01, + saturated_threshold: float = 0.9, + history_size: int = 100, + ): + """ + Args: + dead_threshold: Firing rate below this = dead neuron + saturated_threshold: Firing rate above this = saturated neuron + history_size: Number of batches to track for moving averages + """ + self.dead_threshold = dead_threshold + self.saturated_threshold = saturated_threshold + self.history_size = history_size + + # History tracking + self.lyap_history: List[float] = [] + self.grad_history: List[float] = [] + self.fr_history: List[float] = [] + + def compute_gradient_norm(self, model: nn.Module) -> float: + """Compute total gradient norm across all parameters.""" + total_norm = 0.0 + for p in model.parameters(): + if p.grad is not None: + total_norm += p.grad.data.norm(2).item() ** 2 + return total_norm ** 0.5 + + def compute_gradient_singular_values( + self, + model: nn.Module, + top_k: int = 5, + ) -> Dict[str, Tuple[np.ndarray, float]]: + """ + Compute singular values of gradient matrices. + + Per Gradient Flossing paper: singular value spectrum reveals + gradient pathologies (vanishing/exploding/rank collapse). + + Args: + model: The SNN model + top_k: Number of top/bottom singular values to return + + Returns: + Dict mapping layer name to (singular_values, condition_number) + """ + results = {} + for name, param in model.named_parameters(): + if param.grad is not None and param.ndim == 2: + # Only compute for weight matrices (2D) + with torch.no_grad(): + G = param.grad.detach().cpu() + try: + # Full SVD is expensive; use truncated for large matrices + if G.shape[0] * G.shape[1] > 1e6: + # For very large matrices, just compute extremes + U, S, V = torch.svd_lowrank(G, q=min(top_k, min(G.shape))) + sv = S.numpy() + else: + sv = torch.linalg.svdvals(G).numpy() + + max_sv = sv[0] if len(sv) > 0 else 0 + min_sv = sv[-1] if len(sv) > 0 else 0 + condition = max_sv / (min_sv + 1e-12) + + results[name] = (sv[:top_k], condition) + except Exception: + pass # Skip if SVD fails + return results + + def get_aggregate_gradient_sv(self, model: nn.Module) -> Tuple[float, float, float]: + """ + Get aggregate gradient singular value statistics. + + Returns: + (max_sv, min_sv, avg_condition_number) across all layers + """ + sv_results = self.compute_gradient_singular_values(model) + if not sv_results: + return 0.0, 0.0, 1.0 + + max_svs = [] + min_svs = [] + conditions = [] + + for name, (sv, cond) in sv_results.items(): + if len(sv) > 0: + max_svs.append(sv[0]) + min_svs.append(sv[-1]) + conditions.append(cond) + + if not max_svs: + return 0.0, 0.0, 1.0 + + return ( + float(np.max(max_svs)), + float(np.min(min_svs)), + float(np.mean(conditions)) + ) + + def compute_firing_stats( + self, + spikes: torch.Tensor, + ) -> Tuple[float, float, float, float]: + """ + Compute firing rate statistics. + + Args: + spikes: Spike tensor, shape (B, T, H) or (T, H) or (B, H) + Values should be 0/1. + + Returns: + (mean_rate, std_rate, dead_frac, saturated_frac) + """ + with torch.no_grad(): + # Flatten to (num_samples, num_neurons) if needed + if spikes.ndim == 3: + # (B, T, H) -> compute rate per neuron per sample + rates = spikes.float().mean(dim=1) # (B, H) + elif spikes.ndim == 2: + # Could be (T, H) or (B, H) - assume (T, H) for single sample + rates = spikes.float().mean(dim=0, keepdim=True) # (1, H) + else: + rates = spikes.float().unsqueeze(0) + + # Per-neuron average rate across batch + neuron_rates = rates.mean(dim=0) # (H,) + + mean_rate = neuron_rates.mean().item() + std_rate = neuron_rates.std().item() + + dead_frac = (neuron_rates < self.dead_threshold).float().mean().item() + saturated_frac = (neuron_rates > self.saturated_threshold).float().mean().item() + + return mean_rate, std_rate, dead_frac, saturated_frac + + def compute_membrane_stats( + self, + membrane: torch.Tensor, + ) -> Tuple[float, float]: + """ + Compute membrane potential statistics. + + Args: + membrane: Membrane potential tensor, any shape + + Returns: + (mean, std) + """ + with torch.no_grad(): + return membrane.mean().item(), membrane.std().item() + + def compute( + self, + model: nn.Module, + lyapunov: Optional[torch.Tensor] = None, + spikes: Optional[torch.Tensor] = None, + membrane: Optional[torch.Tensor] = None, + compute_sv: bool = False, + ) -> StabilityMetrics: + """ + Compute all available stability metrics. + + Args: + model: The SNN model (for gradient norms) + lyapunov: Lyapunov exponent from forward pass + spikes: Recorded spikes (optional) + membrane: Recorded membrane potentials (optional) + compute_sv: Whether to compute gradient singular values (expensive) + + Returns: + StabilityMetrics object + """ + metrics = StabilityMetrics() + + # Lyapunov exponent + if lyapunov is not None: + if isinstance(lyapunov, torch.Tensor): + lyapunov = lyapunov.item() + metrics.lyapunov = lyapunov + self.lyap_history.append(lyapunov) + if len(self.lyap_history) > self.history_size: + self.lyap_history.pop(0) + + # Gradient norm + grad_norm = self.compute_gradient_norm(model) + metrics.grad_norm = grad_norm + self.grad_history.append(grad_norm) + if len(self.grad_history) > self.history_size: + self.grad_history.pop(0) + + # Gradient singular values (optional, expensive) + if compute_sv: + max_sv, min_sv, avg_cond = self.get_aggregate_gradient_sv(model) + metrics.grad_max_sv = max_sv + metrics.grad_min_sv = min_sv + metrics.grad_condition = avg_cond + + # Firing rate statistics + if spikes is not None: + fr_mean, fr_std, dead_frac, sat_frac = self.compute_firing_stats(spikes) + metrics.firing_rate_mean = fr_mean + metrics.firing_rate_std = fr_std + metrics.dead_neuron_frac = dead_frac + metrics.saturated_neuron_frac = sat_frac + self.fr_history.append(fr_mean) + if len(self.fr_history) > self.history_size: + self.fr_history.pop(0) + + # Membrane potential statistics + if membrane is not None: + mem_mean, mem_std = self.compute_membrane_stats(membrane) + metrics.membrane_mean = mem_mean + metrics.membrane_std = mem_std + + return metrics + + def get_trends(self) -> Dict[str, str]: + """ + Analyze trends in stability metrics. + + Returns: + Dictionary with trend analysis for each metric. + """ + trends = {} + + if len(self.lyap_history) >= 10: + recent = np.mean(self.lyap_history[-10:]) + older = np.mean(self.lyap_history[:10]) + if recent > older + 0.1: + trends["lyapunov"] = "⚠️ INCREASING (becoming unstable)" + elif recent < older - 0.1: + trends["lyapunov"] = "✓ DECREASING (stabilizing)" + else: + trends["lyapunov"] = "→ STABLE" + + if len(self.grad_history) >= 10: + recent = np.mean(self.grad_history[-10:]) + older = np.mean(self.grad_history[:10]) + ratio = recent / (older + 1e-8) + if ratio > 10: + trends["gradients"] = "⚠️ EXPLODING" + elif ratio < 0.1: + trends["gradients"] = "⚠️ VANISHING" + else: + trends["gradients"] = "✓ STABLE" + + if len(self.fr_history) >= 10: + recent = np.mean(self.fr_history[-10:]) + if recent < 0.01: + trends["firing"] = "⚠️ DEAD NETWORK" + elif recent > 0.8: + trends["firing"] = "⚠️ SATURATED" + else: + trends["firing"] = "✓ HEALTHY" + + return trends + + def diagnose(self) -> str: + """Generate a diagnostic summary.""" + trends = self.get_trends() + + lines = ["=== Stability Diagnosis ==="] + + if self.lyap_history: + avg_lyap = np.mean(self.lyap_history[-20:]) + lines.append(f"Lyapunov exponent: {avg_lyap:.4f}") + if avg_lyap > 0.5: + lines.append(" → Network is CHAOTIC (trajectories diverge quickly)") + lines.append(" → Suggestion: Increase lambda_reg or decrease learning rate") + elif avg_lyap < -0.5: + lines.append(" → Network is OVER-STABLE (trajectories collapse)") + lines.append(" → Suggestion: May lose expressiveness, consider reducing regularization") + else: + lines.append(" → Network is at EDGE OF CHAOS (good for learning)") + + if self.grad_history: + avg_grad = np.mean(self.grad_history[-20:]) + max_grad = max(self.grad_history[-20:]) + lines.append(f"Gradient norm: avg={avg_grad:.4f}, max={max_grad:.4f}") + if "gradients" in trends: + lines.append(f" → {trends['gradients']}") + + if self.fr_history: + avg_fr = np.mean(self.fr_history[-20:]) + lines.append(f"Firing rate: {avg_fr:.4f}") + if "firing" in trends: + lines.append(f" → {trends['firing']}") + + return "\n".join(lines) + + +def compute_spectral_radius(weight_matrix: torch.Tensor) -> float: + """ + Compute spectral radius of a weight matrix. + + For recurrent networks, spectral radius > 1 indicates potential instability. + + Args: + weight_matrix: 2D weight tensor + + Returns: + Spectral radius (largest absolute eigenvalue) + """ + with torch.no_grad(): + W = weight_matrix.detach().cpu().numpy() + eigenvalues = np.linalg.eigvals(W) + return float(np.max(np.abs(eigenvalues))) + + +def analyze_weight_spectrum(model: nn.Module) -> Dict[str, float]: + """ + Analyze spectral properties of all weight matrices. + + Returns: + Dictionary mapping layer names to spectral radii. + """ + results = {} + for name, param in model.named_parameters(): + if "weight" in name and param.ndim == 2: + if param.shape[0] == param.shape[1]: # Square matrix (recurrent) + results[name] = compute_spectral_radius(param) + return results |
