Diffstat (limited to 'files/analysis/stability_monitor.py')
-rw-r--r--  files/analysis/stability_monitor.py  395
1 files changed, 395 insertions, 0 deletions
diff --git a/files/analysis/stability_monitor.py b/files/analysis/stability_monitor.py
new file mode 100644
index 0000000..18cad0f
--- /dev/null
+++ b/files/analysis/stability_monitor.py
@@ -0,0 +1,395 @@
+"""
+Stability monitoring utilities for SNN training.
+
+Provides metrics to diagnose training stability:
+- Lyapunov exponent (trajectory divergence)
+- Gradient norms (vanishing/exploding)
+- Firing rates (dead/saturated neurons)
+- Membrane potential statistics
+"""
+
+from typing import Dict, List, Optional, Tuple
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+@dataclass
+class StabilityMetrics:
+ """Container for stability measurements."""
+ lyapunov: Optional[float] = None
+ grad_norm: Optional[float] = None
+ grad_max_sv: Optional[float] = None # Max singular value of gradients
+ grad_min_sv: Optional[float] = None # Min singular value of gradients
+ grad_condition: Optional[float] = None # Condition number (max_sv / min_sv)
+ firing_rate_mean: Optional[float] = None
+ firing_rate_std: Optional[float] = None
+ dead_neuron_frac: Optional[float] = None
+ saturated_neuron_frac: Optional[float] = None
+ membrane_mean: Optional[float] = None
+ membrane_std: Optional[float] = None
+
+ def to_dict(self) -> Dict[str, float]:
+ return {k: v for k, v in self.__dict__.items() if v is not None}
+
+ def __str__(self) -> str:
+ parts = []
+ if self.lyapunov is not None:
+ parts.append(f"λ={self.lyapunov:.4f}")
+ if self.grad_norm is not None:
+ parts.append(f"∇={self.grad_norm:.4f}")
+ if self.firing_rate_mean is not None:
+ parts.append(f"fr={self.firing_rate_mean:.3f}±{self.firing_rate_std:.3f}")
+ if self.dead_neuron_frac is not None:
+ parts.append(f"dead={self.dead_neuron_frac:.1%}")
+ if self.saturated_neuron_frac is not None:
+ parts.append(f"sat={self.saturated_neuron_frac:.1%}")
+ return " | ".join(parts)
+
+
+class StabilityMonitor:
+ """
+ Monitor SNN stability during training.
+
+ Usage:
+ monitor = StabilityMonitor()
+
+ # During training
+ logits, lyap = model(x, compute_lyapunov=True)
+ loss.backward()
+
+ metrics = monitor.compute(
+ model=model,
+ lyapunov=lyap,
+ spikes=spike_recordings, # optional
+ membrane=membrane_recordings, # optional
+ )
+ print(metrics)
+ """
+
+ def __init__(
+ self,
+ dead_threshold: float = 0.01,
+ saturated_threshold: float = 0.9,
+ history_size: int = 100,
+ ):
+ """
+ Args:
+ dead_threshold: Firing rate below this = dead neuron
+ saturated_threshold: Firing rate above this = saturated neuron
+ history_size: Number of batches to track for moving averages
+ """
+ self.dead_threshold = dead_threshold
+ self.saturated_threshold = saturated_threshold
+ self.history_size = history_size
+
+ # History tracking
+ self.lyap_history: List[float] = []
+ self.grad_history: List[float] = []
+ self.fr_history: List[float] = []
+
+ def compute_gradient_norm(self, model: nn.Module) -> float:
+ """Compute total gradient norm across all parameters."""
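+        # This matches the total norm reported by torch.nn.utils.clip_grad_norm_:
+        # the L2 norm of all parameter gradients viewed as one concatenated vector.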
+ total_norm = 0.0
+ for p in model.parameters():
+ if p.grad is not None:
+                total_norm += p.grad.detach().norm(2).item() ** 2
+ return total_norm ** 0.5
+
+ def compute_gradient_singular_values(
+ self,
+ model: nn.Module,
+ top_k: int = 5,
+ ) -> Dict[str, Tuple[np.ndarray, float]]:
+ """
+ Compute singular values of gradient matrices.
+
+ Per Gradient Flossing paper: singular value spectrum reveals
+ gradient pathologies (vanishing/exploding/rank collapse).
+
+ Args:
+ model: The SNN model
+ top_k: Number of top/bottom singular values to return
+
+ Returns:
+ Dict mapping layer name to (singular_values, condition_number)
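+
+        Example (illustrative numbers): a gradient matrix with singular values
+            [5.0, 4.8, 0.0001] has condition number 5.0 / 0.0001 = 5e4, i.e. the
+            useful gradient signal is confined to a nearly rank-deficient subspace.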
+ """
+ results = {}
+ for name, param in model.named_parameters():
+ if param.grad is not None and param.ndim == 2:
+ # Only compute for weight matrices (2D)
+ with torch.no_grad():
+ G = param.grad.detach().cpu()
+ try:
+                        # Full SVD is expensive; use a truncated SVD for large matrices.
+                        if G.shape[0] * G.shape[1] > 1e6:
+                            # svd_lowrank only returns the top-q singular values, so the
+                            # reported "min" is the q-th largest, not the true minimum.
+                            _, S, _ = torch.svd_lowrank(G, q=min(top_k, min(G.shape)))
+ sv = S.numpy()
+ else:
+ sv = torch.linalg.svdvals(G).numpy()
+
+ max_sv = sv[0] if len(sv) > 0 else 0
+ min_sv = sv[-1] if len(sv) > 0 else 0
+ condition = max_sv / (min_sv + 1e-12)
+
+ results[name] = (sv[:top_k], condition)
+ except Exception:
+ pass # Skip if SVD fails
+ return results
+
+ def get_aggregate_gradient_sv(self, model: nn.Module) -> Tuple[float, float, float]:
+ """
+ Get aggregate gradient singular value statistics.
+
+ Returns:
+ (max_sv, min_sv, avg_condition_number) across all layers
+ """
+ sv_results = self.compute_gradient_singular_values(model)
+ if not sv_results:
+ return 0.0, 0.0, 1.0
+
+ max_svs = []
+ min_svs = []
+ conditions = []
+
+        for sv, cond in sv_results.values():
+ if len(sv) > 0:
+ max_svs.append(sv[0])
+ min_svs.append(sv[-1])
+ conditions.append(cond)
+
+ if not max_svs:
+ return 0.0, 0.0, 1.0
+
+ return (
+ float(np.max(max_svs)),
+ float(np.min(min_svs)),
+ float(np.mean(conditions))
+ )
+
+ def compute_firing_stats(
+ self,
+ spikes: torch.Tensor,
+ ) -> Tuple[float, float, float, float]:
+ """
+ Compute firing rate statistics.
+
+ Args:
+ spikes: Spike tensor, shape (B, T, H) or (T, H) or (B, H)
+ Values should be 0/1.
+
+ Returns:
+ (mean_rate, std_rate, dead_frac, saturated_frac)
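+
+        Example (illustrative doctest; a silent network counts as fully dead):
+            >>> StabilityMonitor().compute_firing_stats(torch.zeros(4, 10, 8))
+            (0.0, 0.0, 1.0, 0.0)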
+ """
+ with torch.no_grad():
+ # Flatten to (num_samples, num_neurons) if needed
+ if spikes.ndim == 3:
+ # (B, T, H) -> compute rate per neuron per sample
+ rates = spikes.float().mean(dim=1) # (B, H)
+ elif spikes.ndim == 2:
+ # Could be (T, H) or (B, H) - assume (T, H) for single sample
+ rates = spikes.float().mean(dim=0, keepdim=True) # (1, H)
+ else:
+ rates = spikes.float().unsqueeze(0)
+
+ # Per-neuron average rate across batch
+ neuron_rates = rates.mean(dim=0) # (H,)
+
+ mean_rate = neuron_rates.mean().item()
+ std_rate = neuron_rates.std().item()
+
+ dead_frac = (neuron_rates < self.dead_threshold).float().mean().item()
+ saturated_frac = (neuron_rates > self.saturated_threshold).float().mean().item()
+
+ return mean_rate, std_rate, dead_frac, saturated_frac
+
+ def compute_membrane_stats(
+ self,
+ membrane: torch.Tensor,
+ ) -> Tuple[float, float]:
+ """
+ Compute membrane potential statistics.
+
+ Args:
+ membrane: Membrane potential tensor, any shape
+
+ Returns:
+ (mean, std)
+ """
+ with torch.no_grad():
+ return membrane.mean().item(), membrane.std().item()
+
+ def compute(
+ self,
+ model: nn.Module,
+ lyapunov: Optional[torch.Tensor] = None,
+ spikes: Optional[torch.Tensor] = None,
+ membrane: Optional[torch.Tensor] = None,
+ compute_sv: bool = False,
+ ) -> StabilityMetrics:
+ """
+ Compute all available stability metrics.
+
+ Args:
+ model: The SNN model (for gradient norms)
+ lyapunov: Lyapunov exponent from forward pass
+ spikes: Recorded spikes (optional)
+ membrane: Recorded membrane potentials (optional)
+ compute_sv: Whether to compute gradient singular values (expensive)
+
+ Returns:
+ StabilityMetrics object
+ """
+ metrics = StabilityMetrics()
+
+ # Lyapunov exponent
+ if lyapunov is not None:
+ if isinstance(lyapunov, torch.Tensor):
+ lyapunov = lyapunov.item()
+ metrics.lyapunov = lyapunov
+ self.lyap_history.append(lyapunov)
+ if len(self.lyap_history) > self.history_size:
+ self.lyap_history.pop(0)
+
+ # Gradient norm
+ grad_norm = self.compute_gradient_norm(model)
+ metrics.grad_norm = grad_norm
+ self.grad_history.append(grad_norm)
+ if len(self.grad_history) > self.history_size:
+ self.grad_history.pop(0)
+
+ # Gradient singular values (optional, expensive)
+ if compute_sv:
+ max_sv, min_sv, avg_cond = self.get_aggregate_gradient_sv(model)
+ metrics.grad_max_sv = max_sv
+ metrics.grad_min_sv = min_sv
+ metrics.grad_condition = avg_cond
+
+ # Firing rate statistics
+ if spikes is not None:
+ fr_mean, fr_std, dead_frac, sat_frac = self.compute_firing_stats(spikes)
+ metrics.firing_rate_mean = fr_mean
+ metrics.firing_rate_std = fr_std
+ metrics.dead_neuron_frac = dead_frac
+ metrics.saturated_neuron_frac = sat_frac
+ self.fr_history.append(fr_mean)
+ if len(self.fr_history) > self.history_size:
+ self.fr_history.pop(0)
+
+ # Membrane potential statistics
+ if membrane is not None:
+ mem_mean, mem_std = self.compute_membrane_stats(membrane)
+ metrics.membrane_mean = mem_mean
+ metrics.membrane_std = mem_std
+
+ return metrics
+
+ def get_trends(self) -> Dict[str, str]:
+ """
+ Analyze trends in stability metrics.
+
+ Returns:
+ Dictionary with trend analysis for each metric.
+ """
+ trends = {}
+
+ if len(self.lyap_history) >= 10:
+ recent = np.mean(self.lyap_history[-10:])
+ older = np.mean(self.lyap_history[:10])
+ if recent > older + 0.1:
+ trends["lyapunov"] = "⚠️ INCREASING (becoming unstable)"
+ elif recent < older - 0.1:
+ trends["lyapunov"] = "✓ DECREASING (stabilizing)"
+ else:
+ trends["lyapunov"] = "→ STABLE"
+
+ if len(self.grad_history) >= 10:
+ recent = np.mean(self.grad_history[-10:])
+ older = np.mean(self.grad_history[:10])
+ ratio = recent / (older + 1e-8)
+ if ratio > 10:
+ trends["gradients"] = "⚠️ EXPLODING"
+ elif ratio < 0.1:
+ trends["gradients"] = "⚠️ VANISHING"
+ else:
+ trends["gradients"] = "✓ STABLE"
+
+ if len(self.fr_history) >= 10:
+ recent = np.mean(self.fr_history[-10:])
+ if recent < 0.01:
+ trends["firing"] = "⚠️ DEAD NETWORK"
+ elif recent > 0.8:
+ trends["firing"] = "⚠️ SATURATED"
+ else:
+ trends["firing"] = "✓ HEALTHY"
+
+ return trends
+
+ def diagnose(self) -> str:
+ """Generate a diagnostic summary."""
+ trends = self.get_trends()
+
+ lines = ["=== Stability Diagnosis ==="]
+
+ if self.lyap_history:
+ avg_lyap = np.mean(self.lyap_history[-20:])
+ lines.append(f"Lyapunov exponent: {avg_lyap:.4f}")
+ if avg_lyap > 0.5:
+ lines.append(" → Network is CHAOTIC (trajectories diverge quickly)")
+ lines.append(" → Suggestion: Increase lambda_reg or decrease learning rate")
+ elif avg_lyap < -0.5:
+ lines.append(" → Network is OVER-STABLE (trajectories collapse)")
+ lines.append(" → Suggestion: May lose expressiveness, consider reducing regularization")
+ else:
+ lines.append(" → Network is at EDGE OF CHAOS (good for learning)")
+
+ if self.grad_history:
+ avg_grad = np.mean(self.grad_history[-20:])
+ max_grad = max(self.grad_history[-20:])
+ lines.append(f"Gradient norm: avg={avg_grad:.4f}, max={max_grad:.4f}")
+ if "gradients" in trends:
+ lines.append(f" → {trends['gradients']}")
+
+ if self.fr_history:
+ avg_fr = np.mean(self.fr_history[-20:])
+ lines.append(f"Firing rate: {avg_fr:.4f}")
+ if "firing" in trends:
+ lines.append(f" → {trends['firing']}")
+
+ return "\n".join(lines)
+
+
+def compute_spectral_radius(weight_matrix: torch.Tensor) -> float:
+ """
+ Compute spectral radius of a weight matrix.
+
+ For recurrent networks, spectral radius > 1 indicates potential instability.
+
+ Args:
+ weight_matrix: 2D weight tensor
+
+ Returns:
+ Spectral radius (largest absolute eigenvalue)
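+
+    Example (illustrative doctest):
+        >>> compute_spectral_radius(torch.eye(3) * 0.5)
+        0.5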
+ """
+ with torch.no_grad():
+ W = weight_matrix.detach().cpu().numpy()
+ eigenvalues = np.linalg.eigvals(W)
+ return float(np.max(np.abs(eigenvalues)))
+
+
+def analyze_weight_spectrum(model: nn.Module) -> Dict[str, float]:
+ """
+ Analyze spectral properties of all weight matrices.
+
+ Returns:
+ Dictionary mapping layer names to spectral radii.
+ """
+ results = {}
+ for name, param in model.named_parameters():
+ if "weight" in name and param.ndim == 2:
+ if param.shape[0] == param.shape[1]: # Square matrix (recurrent)
+ results[name] = compute_spectral_radius(param)
+ return results
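+
+
+if __name__ == "__main__":
+    # Minimal smoke-test sketch: a hypothetical toy model and fabricated spike /
+    # membrane recordings, only to illustrate how the monitor is wired into a
+    # training step. This is not a real SNN pipeline.
+    torch.manual_seed(0)
+    model = nn.Sequential(
+        nn.Linear(16, 32),
+        nn.Tanh(),
+        nn.Linear(32, 32),  # square matrix so analyze_weight_spectrum has something to report
+        nn.Tanh(),
+        nn.Linear(32, 4),
+    )
+    monitor = StabilityMonitor()
+
+    x = torch.randn(8, 16)
+    loss = model(x).pow(2).mean()
+    loss.backward()
+
+    # Fabricated 0/1 spikes and membrane potentials with shape (B, T, H).
+    spikes = (torch.rand(8, 50, 32) < 0.1).float()
+    membrane = torch.randn(8, 50, 32)
+
+    metrics = monitor.compute(model, spikes=spikes, membrane=membrane, compute_sv=True)
+    print(metrics)
+    print(monitor.diagnose())
+    print(analyze_weight_spectrum(model))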