| field | value |
|---|---|
| author | YurenHao0426 <blackhao0426@gmail.com>, 2026-01-13 23:50:59 -0600 |
| committer | YurenHao0426 <blackhao0426@gmail.com>, 2026-01-13 23:50:59 -0600 |
| commit | 00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch) |
| tree | 77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files/analysis |
| parent | c53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff) |
| parent | cd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff) |
Merge master into main
Diffstat (limited to 'files/analysis')
| mode | file | insertions |
|---|---|---|
| -rw-r--r-- | files/analysis/plot_mvp.py | 139 |
| -rw-r--r-- | files/analysis/stability_monitor.py | 395 |
2 files changed, 534 insertions, 0 deletions
diff --git a/files/analysis/plot_mvp.py b/files/analysis/plot_mvp.py
new file mode 100644
index 0000000..495a9c2
--- /dev/null
+++ b/files/analysis/plot_mvp.py
@@ -0,0 +1,139 @@

import argparse
import csv
import json
import os
from glob import glob
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt


def load_run(run_dir: str) -> Dict:
    args_path = os.path.join(run_dir, "args.json")
    metrics_path = os.path.join(run_dir, "metrics.csv")
    if not (os.path.exists(args_path) and os.path.exists(metrics_path)):
        return {}
    with open(args_path, "r") as f:
        args = json.load(f)
    epochs = []
    loss = []
    acc = []
    lyap = []
    with open(metrics_path, "r") as f:
        reader = csv.DictReader(f)
        for row in reader:
            if row.get("step", "") != "epoch":
                continue
            try:
                e = int(row.get("epoch", "0"))
                l = float(row.get("loss", "nan"))
                a = float(row.get("acc", "nan"))
                y = row.get("lyap", "nan")
                y = float("nan") if (y is None or y == "" or str(y).lower() == "nan") else float(y)
            except Exception:
                continue
            epochs.append(e)
            loss.append(l)
            acc.append(a)
            lyap.append(y)
    if not epochs:
        return {}
    return {
        "args": args,
        "epochs": epochs,
        "loss": loss,
        "acc": acc,
        "lyap": lyap,
        "run_dir": run_dir,
    }


def label_for_run(run: Dict) -> str:
    args = run["args"]
    if args.get("lyapunov", False):
        lam = args.get("lambda_reg", None)
        tgt = args.get("lambda_target", None)
        hid = args.get("hidden", None)
        return f"Lyap λ={lam}, tgt={tgt}, H={hid}"
    else:
        hid = args.get("hidden", None)
        return f"Baseline H={hid}"


def gather_runs(base_dir: str) -> List[Dict]:
    cand = sorted(glob(os.path.join(base_dir, "*")))
    runs = []
    for rd in cand:
        data = load_run(rd)
        if data:
            runs.append(data)
    return runs


def plot_runs(runs: List[Dict], out_path: str):
    if not runs:
        raise SystemExit("No valid runs found (expected args.json and metrics.csv under base_dir/*)")
    # Split baseline vs lyapunov
    base = [r for r in runs if not r["args"].get("lyapunov", False)]
    lyap = [r for r in runs if r["args"].get("lyapunov", False)]

    fig, axes = plt.subplots(1, 3, figsize=(14, 4.5))

    # Loss
    ax = axes[0]
    for r in base:
        ax.plot(r["epochs"], r["loss"], label=label_for_run(r), alpha=0.9)
    for r in lyap:
        ax.plot(r["epochs"], r["loss"], label=label_for_run(r), linestyle="--", alpha=0.9)
    ax.set_title("Training loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=8)

    # Accuracy
    ax = axes[1]
    for r in base:
        ax.plot(r["epochs"], r["acc"], label=label_for_run(r), alpha=0.9)
    for r in lyap:
        ax.plot(r["epochs"], r["acc"], label=label_for_run(r), linestyle="--", alpha=0.9)
    ax.set_title("Training accuracy")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=8)

    # Lyapunov estimate (only lyap runs)
    ax = axes[2]
    if lyap:
        for r in lyap:
            ax.plot(r["epochs"], r["lyap"], label=label_for_run(r), alpha=0.9)
        ax.set_title("Surrogate Lyapunov estimate")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Avg log growth")
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=8)
    else:
        ax.text(0.5, 0.5, "No Lyapunov runs found", ha="center", va="center", transform=ax.transAxes)
        ax.set_axis_off()

    fig.tight_layout()
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, dpi=150)
    print(f"Saved figure to {out_path}")


def main():
    ap = argparse.ArgumentParser(description="Plot MVP runs: baseline vs Lyapunov comparison")
    ap.add_argument("--base_dir", type=str, default="runs/mvp", help="Directory containing run subfolders")
    ap.add_argument("--out", type=str, default="runs/mvp_summary.png", help="Output figure path")
    args = ap.parse_args()

    runs = gather_runs(args.base_dir)
    plot_runs(runs, args.out)


if __name__ == "__main__":
    main()


diff --git a/files/analysis/stability_monitor.py b/files/analysis/stability_monitor.py
new file mode 100644
index 0000000..18cad0f
--- /dev/null
+++ b/files/analysis/stability_monitor.py
@@ -0,0 +1,395 @@

"""
Stability monitoring utilities for SNN training.

Provides metrics to diagnose training stability:
- Lyapunov exponent (trajectory divergence)
- Gradient norms (vanishing/exploding)
- Firing rates (dead/saturated neurons)
- Membrane potential statistics
"""

from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field

import torch
import torch.nn as nn
import numpy as np


@dataclass
class StabilityMetrics:
    """Container for stability measurements."""
    lyapunov: Optional[float] = None
    grad_norm: Optional[float] = None
    grad_max_sv: Optional[float] = None  # Max singular value of gradients
    grad_min_sv: Optional[float] = None  # Min singular value of gradients
    grad_condition: Optional[float] = None  # Condition number (max_sv / min_sv)
    firing_rate_mean: Optional[float] = None
    firing_rate_std: Optional[float] = None
    dead_neuron_frac: Optional[float] = None
    saturated_neuron_frac: Optional[float] = None
    membrane_mean: Optional[float] = None
    membrane_std: Optional[float] = None

    def to_dict(self) -> Dict[str, float]:
        return {k: v for k, v in self.__dict__.items() if v is not None}

    def __str__(self) -> str:
        parts = []
        if self.lyapunov is not None:
            parts.append(f"λ={self.lyapunov:.4f}")
        if self.grad_norm is not None:
            parts.append(f"∇={self.grad_norm:.4f}")
        if self.firing_rate_mean is not None:
            parts.append(f"fr={self.firing_rate_mean:.3f}±{self.firing_rate_std:.3f}")
        if self.dead_neuron_frac is not None:
            parts.append(f"dead={self.dead_neuron_frac:.1%}")
        if self.saturated_neuron_frac is not None:
            parts.append(f"sat={self.saturated_neuron_frac:.1%}")
        return " | ".join(parts)


class StabilityMonitor:
    """
    Monitor SNN stability during training.

    Usage:
        monitor = StabilityMonitor()

        # During training
        logits, lyap = model(x, compute_lyapunov=True)
        loss.backward()

        metrics = monitor.compute(
            model=model,
            lyapunov=lyap,
            spikes=spike_recordings,       # optional
            membrane=membrane_recordings,  # optional
        )
        print(metrics)
    """

    def __init__(
        self,
        dead_threshold: float = 0.01,
        saturated_threshold: float = 0.9,
        history_size: int = 100,
    ):
        """
        Args:
            dead_threshold: Firing rate below this = dead neuron
            saturated_threshold: Firing rate above this = saturated neuron
            history_size: Number of batches to track for moving averages
        """
        self.dead_threshold = dead_threshold
        self.saturated_threshold = saturated_threshold
        self.history_size = history_size

        # History tracking
        self.lyap_history: List[float] = []
        self.grad_history: List[float] = []
        self.fr_history: List[float] = []

    def compute_gradient_norm(self, model: nn.Module) -> float:
        """Compute total gradient norm across all parameters."""
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                total_norm += p.grad.data.norm(2).item() ** 2
        return total_norm ** 0.5

    def compute_gradient_singular_values(
        self,
        model: nn.Module,
        top_k: int = 5,
    ) -> Dict[str, Tuple[np.ndarray, float]]:
        """
        Compute singular values of gradient matrices.

        Per Gradient Flossing paper: singular value spectrum reveals
        gradient pathologies (vanishing/exploding/rank collapse).

        Args:
            model: The SNN model
            top_k: Number of top/bottom singular values to return

        Returns:
            Dict mapping layer name to (singular_values, condition_number)
        """
        results = {}
        for name, param in model.named_parameters():
            if param.grad is not None and param.ndim == 2:
                # Only compute for weight matrices (2D)
                with torch.no_grad():
                    G = param.grad.detach().cpu()
                    try:
                        # Full SVD is expensive; use truncated for large matrices
                        if G.shape[0] * G.shape[1] > 1e6:
                            # For very large matrices, just compute extremes
                            U, S, V = torch.svd_lowrank(G, q=min(top_k, min(G.shape)))
                            sv = S.numpy()
                        else:
                            sv = torch.linalg.svdvals(G).numpy()

                        max_sv = sv[0] if len(sv) > 0 else 0
                        min_sv = sv[-1] if len(sv) > 0 else 0
                        condition = max_sv / (min_sv + 1e-12)

                        results[name] = (sv[:top_k], condition)
                    except Exception:
                        pass  # Skip if SVD fails
        return results

    def get_aggregate_gradient_sv(self, model: nn.Module) -> Tuple[float, float, float]:
        """
        Get aggregate gradient singular value statistics.

        Returns:
            (max_sv, min_sv, avg_condition_number) across all layers
        """
        sv_results = self.compute_gradient_singular_values(model)
        if not sv_results:
            return 0.0, 0.0, 1.0

        max_svs = []
        min_svs = []
        conditions = []

        for name, (sv, cond) in sv_results.items():
            if len(sv) > 0:
                max_svs.append(sv[0])
                min_svs.append(sv[-1])
                conditions.append(cond)

        if not max_svs:
            return 0.0, 0.0, 1.0

        return (
            float(np.max(max_svs)),
            float(np.min(min_svs)),
            float(np.mean(conditions))
        )

    def compute_firing_stats(
        self,
        spikes: torch.Tensor,
    ) -> Tuple[float, float, float, float]:
        """
        Compute firing rate statistics.

        Args:
            spikes: Spike tensor, shape (B, T, H) or (T, H) or (B, H)
                Values should be 0/1.

        Returns:
            (mean_rate, std_rate, dead_frac, saturated_frac)
        """
        with torch.no_grad():
            # Flatten to (num_samples, num_neurons) if needed
            if spikes.ndim == 3:
                # (B, T, H) -> compute rate per neuron per sample
                rates = spikes.float().mean(dim=1)  # (B, H)
            elif spikes.ndim == 2:
                # Could be (T, H) or (B, H) - assume (T, H) for single sample
                rates = spikes.float().mean(dim=0, keepdim=True)  # (1, H)
            else:
                rates = spikes.float().unsqueeze(0)

            # Per-neuron average rate across batch
            neuron_rates = rates.mean(dim=0)  # (H,)

            mean_rate = neuron_rates.mean().item()
            std_rate = neuron_rates.std().item()

            dead_frac = (neuron_rates < self.dead_threshold).float().mean().item()
            saturated_frac = (neuron_rates > self.saturated_threshold).float().mean().item()

            return mean_rate, std_rate, dead_frac, saturated_frac

    def compute_membrane_stats(
        self,
        membrane: torch.Tensor,
    ) -> Tuple[float, float]:
        """
        Compute membrane potential statistics.

        Args:
            membrane: Membrane potential tensor, any shape

        Returns:
            (mean, std)
        """
        with torch.no_grad():
            return membrane.mean().item(), membrane.std().item()

    def compute(
        self,
        model: nn.Module,
        lyapunov: Optional[torch.Tensor] = None,
        spikes: Optional[torch.Tensor] = None,
        membrane: Optional[torch.Tensor] = None,
        compute_sv: bool = False,
    ) -> StabilityMetrics:
        """
        Compute all available stability metrics.

        Args:
            model: The SNN model (for gradient norms)
            lyapunov: Lyapunov exponent from forward pass
            spikes: Recorded spikes (optional)
            membrane: Recorded membrane potentials (optional)
            compute_sv: Whether to compute gradient singular values (expensive)

        Returns:
            StabilityMetrics object
        """
        metrics = StabilityMetrics()

        # Lyapunov exponent
        if lyapunov is not None:
            if isinstance(lyapunov, torch.Tensor):
                lyapunov = lyapunov.item()
            metrics.lyapunov = lyapunov
            self.lyap_history.append(lyapunov)
            if len(self.lyap_history) > self.history_size:
                self.lyap_history.pop(0)

        # Gradient norm
        grad_norm = self.compute_gradient_norm(model)
        metrics.grad_norm = grad_norm
        self.grad_history.append(grad_norm)
        if len(self.grad_history) > self.history_size:
            self.grad_history.pop(0)

        # Gradient singular values (optional, expensive)
        if compute_sv:
            max_sv, min_sv, avg_cond = self.get_aggregate_gradient_sv(model)
            metrics.grad_max_sv = max_sv
            metrics.grad_min_sv = min_sv
            metrics.grad_condition = avg_cond

        # Firing rate statistics
        if spikes is not None:
            fr_mean, fr_std, dead_frac, sat_frac = self.compute_firing_stats(spikes)
            metrics.firing_rate_mean = fr_mean
            metrics.firing_rate_std = fr_std
            metrics.dead_neuron_frac = dead_frac
            metrics.saturated_neuron_frac = sat_frac
            self.fr_history.append(fr_mean)
            if len(self.fr_history) > self.history_size:
                self.fr_history.pop(0)

        # Membrane potential statistics
        if membrane is not None:
            mem_mean, mem_std = self.compute_membrane_stats(membrane)
            metrics.membrane_mean = mem_mean
            metrics.membrane_std = mem_std

        return metrics

    def get_trends(self) -> Dict[str, str]:
        """
        Analyze trends in stability metrics.

        Returns:
            Dictionary with trend analysis for each metric.
        """
        trends = {}

        if len(self.lyap_history) >= 10:
            recent = np.mean(self.lyap_history[-10:])
            older = np.mean(self.lyap_history[:10])
            if recent > older + 0.1:
                trends["lyapunov"] = "⚠️ INCREASING (becoming unstable)"
            elif recent < older - 0.1:
                trends["lyapunov"] = "✓ DECREASING (stabilizing)"
            else:
                trends["lyapunov"] = "→ STABLE"

        if len(self.grad_history) >= 10:
            recent = np.mean(self.grad_history[-10:])
            older = np.mean(self.grad_history[:10])
            ratio = recent / (older + 1e-8)
            if ratio > 10:
                trends["gradients"] = "⚠️ EXPLODING"
            elif ratio < 0.1:
                trends["gradients"] = "⚠️ VANISHING"
            else:
                trends["gradients"] = "✓ STABLE"

        if len(self.fr_history) >= 10:
            recent = np.mean(self.fr_history[-10:])
            if recent < 0.01:
                trends["firing"] = "⚠️ DEAD NETWORK"
            elif recent > 0.8:
                trends["firing"] = "⚠️ SATURATED"
            else:
                trends["firing"] = "✓ HEALTHY"

        return trends

    def diagnose(self) -> str:
        """Generate a diagnostic summary."""
        trends = self.get_trends()

        lines = ["=== Stability Diagnosis ==="]

        if self.lyap_history:
            avg_lyap = np.mean(self.lyap_history[-20:])
            lines.append(f"Lyapunov exponent: {avg_lyap:.4f}")
            if avg_lyap > 0.5:
                lines.append("  → Network is CHAOTIC (trajectories diverge quickly)")
                lines.append("  → Suggestion: Increase lambda_reg or decrease learning rate")
            elif avg_lyap < -0.5:
                lines.append("  → Network is OVER-STABLE (trajectories collapse)")
                lines.append("  → Suggestion: May lose expressiveness, consider reducing regularization")
            else:
                lines.append("  → Network is at EDGE OF CHAOS (good for learning)")

        if self.grad_history:
            avg_grad = np.mean(self.grad_history[-20:])
            max_grad = max(self.grad_history[-20:])
            lines.append(f"Gradient norm: avg={avg_grad:.4f}, max={max_grad:.4f}")
            if "gradients" in trends:
                lines.append(f"  → {trends['gradients']}")

        if self.fr_history:
            avg_fr = np.mean(self.fr_history[-20:])
            lines.append(f"Firing rate: {avg_fr:.4f}")
            if "firing" in trends:
                lines.append(f"  → {trends['firing']}")

        return "\n".join(lines)


def compute_spectral_radius(weight_matrix: torch.Tensor) -> float:
    """
    Compute spectral radius of a weight matrix.

    For recurrent networks, spectral radius > 1 indicates potential instability.

    Args:
        weight_matrix: 2D weight tensor

    Returns:
        Spectral radius (largest absolute eigenvalue)
    """
    with torch.no_grad():
        W = weight_matrix.detach().cpu().numpy()
        eigenvalues = np.linalg.eigvals(W)
        return float(np.max(np.abs(eigenvalues)))


def analyze_weight_spectrum(model: nn.Module) -> Dict[str, float]:
    """
    Analyze spectral properties of all weight matrices.

    Returns:
        Dictionary mapping layer names to spectral radii.
    """
    results = {}
    for name, param in model.named_parameters():
        if "weight" in name and param.ndim == 2:
            if param.shape[0] == param.shape[1]:  # Square matrix (recurrent)
                results[name] = compute_spectral_radius(param)
    return results
