summaryrefslogtreecommitdiff
path: root/files/analysis/stability_monitor.py
blob: 18cad0fe568ea090fdf497efdb4025bbcae1ab52 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
"""
Stability monitoring utilities for SNN training.

Provides metrics to diagnose training stability:
- Lyapunov exponent (trajectory divergence)
- Gradient norms (vanishing/exploding)
- Firing rates (dead/saturated neurons)
- Membrane potential statistics
"""

import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn


@dataclass
class StabilityMetrics:
    """Container for one snapshot of stability measurements.

    Every field defaults to ``None``, which means "not measured". Both
    ``to_dict`` and ``__str__`` skip fields that are ``None``.
    """
    lyapunov: Optional[float] = None
    grad_norm: Optional[float] = None
    grad_max_sv: Optional[float] = None  # Max singular value of gradients
    grad_min_sv: Optional[float] = None  # Min singular value of gradients
    grad_condition: Optional[float] = None  # Condition number (max_sv / min_sv)
    firing_rate_mean: Optional[float] = None
    firing_rate_std: Optional[float] = None
    dead_neuron_frac: Optional[float] = None
    saturated_neuron_frac: Optional[float] = None
    membrane_mean: Optional[float] = None
    membrane_std: Optional[float] = None

    def to_dict(self) -> Dict[str, float]:
        """Return the measured (non-``None``) fields as a name -> value dict."""
        return {k: v for k, v in self.__dict__.items() if v is not None}

    def __str__(self) -> str:
        """Render a compact one-line summary, e.g. ``λ=0.1000 | ∇=1.2000``.

        Only the Lyapunov exponent, gradient norm, firing-rate and
        dead/saturated fractions are shown; unset fields are omitted. An
        instance with nothing set renders as the empty string.
        """
        parts = []
        if self.lyapunov is not None:
            parts.append(f"λ={self.lyapunov:.4f}")
        if self.grad_norm is not None:
            parts.append(f"∇={self.grad_norm:.4f}")
        if self.firing_rate_mean is not None:
            rate = f"fr={self.firing_rate_mean:.3f}"
            # The std is an independent optional field; formatting None with
            # ":.3f" would raise TypeError, so append it only when present.
            if self.firing_rate_std is not None:
                rate += f"±{self.firing_rate_std:.3f}"
            parts.append(rate)
        if self.dead_neuron_frac is not None:
            parts.append(f"dead={self.dead_neuron_frac:.1%}")
        if self.saturated_neuron_frac is not None:
            parts.append(f"sat={self.saturated_neuron_frac:.1%}")
        return " | ".join(parts)


class StabilityMonitor:
    """
    Monitor SNN stability during training.

    Usage:
        monitor = StabilityMonitor()

        # During training
        logits, lyap = model(x, compute_lyapunov=True)
        loss.backward()

        metrics = monitor.compute(
            model=model,
            lyapunov=lyap,
            spikes=spike_recordings,  # optional
            membrane=membrane_recordings,  # optional
        )
        print(metrics)
    """

    # Gradient matrices with more elements than this are analysed with a
    # randomized low-rank SVD instead of the full singular value spectrum.
    _LOWRANK_NUMEL = 1_000_000

    # Number of leading singular values reported per layer by default.
    _DEFAULT_TOP_K = 5

    def __init__(
        self,
        dead_threshold: float = 0.01,
        saturated_threshold: float = 0.9,
        history_size: int = 100,
    ):
        """
        Args:
            dead_threshold: Firing rate below this = dead neuron
            saturated_threshold: Firing rate above this = saturated neuron
            history_size: Number of batches to track for moving averages
        """
        self.dead_threshold = dead_threshold
        self.saturated_threshold = saturated_threshold
        self.history_size = history_size

        # History tracking (oldest entry first, newest last)
        self.lyap_history: List[float] = []
        self.grad_history: List[float] = []
        self.fr_history: List[float] = []

    def _record(self, history: List[float], value: float) -> None:
        """Append ``value`` to ``history`` and drop entries beyond ``history_size``."""
        history.append(value)
        # Trim by the actual excess so the window also shrinks correctly if
        # ``history_size`` was lowered after entries were recorded.
        excess = len(history) - self.history_size
        if excess > 0:
            del history[:excess]

    def compute_gradient_norm(self, model: nn.Module) -> float:
        """Compute total gradient norm across all parameters.

        Returns the L2 norm of all parameter gradients taken together;
        parameters without a gradient are ignored (0.0 if none has one).
        """
        squared_sum = 0.0
        for p in model.parameters():
            if p.grad is not None:
                squared_sum += p.grad.detach().norm(2).item() ** 2
        return squared_sum ** 0.5

    def _gradient_spectra(
        self,
        model: nn.Module,
        top_k: int,
    ) -> Dict[str, np.ndarray]:
        """Return the descending singular values of every 2D parameter gradient.

        Matrices with more than ``_LOWRANK_NUMEL`` elements only get their
        ``top_k`` largest singular values (randomized low-rank SVD); all other
        matrices get the full spectrum. Layers whose SVD fails are skipped
        with a ``RuntimeWarning`` so monitoring never aborts training.
        """
        spectra: Dict[str, np.ndarray] = {}
        for name, param in model.named_parameters():
            # Only weight matrices (2D) that actually received a gradient.
            if param.grad is None or param.ndim != 2:
                continue
            with torch.no_grad():
                grad = param.grad.detach().cpu()
                # CPU SVD kernels and ``Tensor.numpy`` do not handle reduced
                # precision floats, so upcast instead of skipping the layer.
                if grad.dtype in (torch.float16, torch.bfloat16):
                    grad = grad.float()
                try:
                    if grad.numel() > self._LOWRANK_NUMEL:
                        # Full SVD is expensive; only compute the leading values.
                        rank = min(top_k, min(grad.shape))
                        _, values, _ = torch.svd_lowrank(grad, q=rank)
                    else:
                        values = torch.linalg.svdvals(grad)
                    spectra[name] = values.numpy()
                except (RuntimeError, TypeError, ValueError) as exc:
                    warnings.warn(
                        f"Skipping gradient singular values for '{name}': {exc}",
                        RuntimeWarning,
                        stacklevel=2,
                    )
        return spectra

    @staticmethod
    def _condition_number(singular_values: np.ndarray) -> float:
        """Return largest / smallest singular value (0.0 for an empty spectrum)."""
        if len(singular_values) == 0:
            return 0.0
        largest = float(singular_values[0])
        smallest = float(singular_values[-1])
        # The epsilon keeps rank-deficient gradients from dividing by zero.
        return largest / (smallest + 1e-12)

    def compute_gradient_singular_values(
        self,
        model: nn.Module,
        top_k: int = _DEFAULT_TOP_K,
    ) -> Dict[str, Tuple[np.ndarray, float]]:
        """
        Compute singular values of gradient matrices.

        Per Gradient Flossing paper: singular value spectrum reveals
        gradient pathologies (vanishing/exploding/rank collapse).

        Args:
            model: The SNN model
            top_k: Number of largest singular values to return per layer

        Returns:
            Dict mapping layer name to (singular_values, condition_number).
            ``singular_values`` holds at most ``top_k`` values in descending
            order. ``condition_number`` is computed from the whole computed
            spectrum; for very large matrices only the ``top_k`` leading
            values are computed, so it is then a lower bound on the true
            condition number. Layers whose SVD fails are omitted.
        """
        return {
            name: (values[:top_k], self._condition_number(values))
            for name, values in self._gradient_spectra(model, top_k).items()
        }

    def get_aggregate_gradient_sv(self, model: nn.Module) -> Tuple[float, float, float]:
        """
        Get aggregate gradient singular value statistics.

        Returns:
            (max_sv, min_sv, avg_condition_number) across all layers, or
            (0.0, 0.0, 1.0) when no 2D gradient is available.
        """
        max_svs = []
        min_svs = []
        conditions = []

        # Work on the untruncated spectra: the smallest singular value must
        # come from the end of the full spectrum, not from a top-k slice.
        spectra = self._gradient_spectra(model, self._DEFAULT_TOP_K)
        for values in spectra.values():
            if len(values) > 0:
                max_svs.append(float(values[0]))
                min_svs.append(float(values[-1]))
                conditions.append(self._condition_number(values))

        if not max_svs:
            return 0.0, 0.0, 1.0

        return (
            float(np.max(max_svs)),
            float(np.min(min_svs)),
            float(np.mean(conditions))
        )

    def compute_firing_stats(
        self,
        spikes: torch.Tensor,
    ) -> Tuple[float, float, float, float]:
        """
        Compute firing rate statistics.

        Args:
            spikes: Spike tensor, shape (B, T, H) or (T, H) or (B, H)
                    Values should be 0/1.

        Returns:
            (mean_rate, std_rate, dead_frac, saturated_frac), all computed
            over per-neuron rates. ``std_rate`` is 0.0 when there is only a
            single neuron.
        """
        with torch.no_grad():
            spikes = spikes.float()
            if spikes.ndim == 3:
                # (B, T, H) -> rate per neuron per sample, (B, H)
                rates = spikes.mean(dim=1)
            elif spikes.ndim == 2:
                # (T, H) or (B, H): averaging the leading axis yields the
                # per-neuron rate in both cases, shape (1, H).
                rates = spikes.mean(dim=0, keepdim=True)
            else:
                rates = spikes.unsqueeze(0)

            # Per-neuron average rate across batch
            neuron_rates = rates.mean(dim=0)  # (H,)

            mean_rate = neuron_rates.mean().item()
            # The sample std of a single value is undefined (NaN); report
            # zero spread instead so downstream formatting stays meaningful.
            if neuron_rates.numel() > 1:
                std_rate = neuron_rates.std().item()
            else:
                std_rate = 0.0

            dead_frac = (neuron_rates < self.dead_threshold).float().mean().item()
            saturated_frac = (neuron_rates > self.saturated_threshold).float().mean().item()

            return mean_rate, std_rate, dead_frac, saturated_frac

    def compute_membrane_stats(
        self,
        membrane: torch.Tensor,
    ) -> Tuple[float, float]:
        """
        Compute membrane potential statistics.

        Args:
            membrane: Membrane potential tensor, any shape

        Returns:
            (mean, std) over all elements of the tensor
        """
        with torch.no_grad():
            return membrane.mean().item(), membrane.std().item()

    def compute(
        self,
        model: nn.Module,
        lyapunov: Optional[torch.Tensor] = None,
        spikes: Optional[torch.Tensor] = None,
        membrane: Optional[torch.Tensor] = None,
        compute_sv: bool = False,
    ) -> "StabilityMetrics":
        """
        Compute all available stability metrics.

        The Lyapunov exponent, gradient norm and mean firing rate are also
        appended to the monitor's bounded histories, which feed
        ``get_trends`` and ``diagnose``.

        Args:
            model: The SNN model (for gradient norms)
            lyapunov: Lyapunov exponent from forward pass (single-element
                tensor or plain number)
            spikes: Recorded spikes (optional)
            membrane: Recorded membrane potentials (optional)
            compute_sv: Whether to compute gradient singular values (expensive)

        Returns:
            StabilityMetrics object
        """
        metrics = StabilityMetrics()

        # Lyapunov exponent
        if lyapunov is not None:
            if isinstance(lyapunov, torch.Tensor):
                lyapunov = lyapunov.item()
            metrics.lyapunov = lyapunov
            self._record(self.lyap_history, lyapunov)

        # Gradient norm
        grad_norm = self.compute_gradient_norm(model)
        metrics.grad_norm = grad_norm
        self._record(self.grad_history, grad_norm)

        # Gradient singular values (optional, expensive)
        if compute_sv:
            max_sv, min_sv, avg_cond = self.get_aggregate_gradient_sv(model)
            metrics.grad_max_sv = max_sv
            metrics.grad_min_sv = min_sv
            metrics.grad_condition = avg_cond

        # Firing rate statistics
        if spikes is not None:
            fr_mean, fr_std, dead_frac, sat_frac = self.compute_firing_stats(spikes)
            metrics.firing_rate_mean = fr_mean
            metrics.firing_rate_std = fr_std
            metrics.dead_neuron_frac = dead_frac
            metrics.saturated_neuron_frac = sat_frac
            self._record(self.fr_history, fr_mean)

        # Membrane potential statistics
        if membrane is not None:
            mem_mean, mem_std = self.compute_membrane_stats(membrane)
            metrics.membrane_mean = mem_mean
            metrics.membrane_std = mem_std

        return metrics

    def get_trends(self) -> Dict[str, str]:
        """
        Analyze trends in stability metrics.

        Each metric needs at least 10 recorded values. Lyapunov and gradient
        trends compare the mean of the 10 newest entries against the mean of
        the 10 oldest entries still in the history; with fewer than 20
        entries these two windows overlap, which biases the result towards
        "stable". The firing trend looks at the 10 newest entries only.

        Returns:
            Dictionary with trend analysis for each metric.
        """
        trends: Dict[str, str] = {}

        if len(self.lyap_history) >= 10:
            recent = np.mean(self.lyap_history[-10:])
            older = np.mean(self.lyap_history[:10])
            if recent > older + 0.1:
                trends["lyapunov"] = "⚠️ INCREASING (becoming unstable)"
            elif recent < older - 0.1:
                trends["lyapunov"] = "✓ DECREASING (stabilizing)"
            else:
                trends["lyapunov"] = "→ STABLE"

        if len(self.grad_history) >= 10:
            recent = np.mean(self.grad_history[-10:])
            older = np.mean(self.grad_history[:10])
            # Epsilon guards against a history of all-zero gradient norms.
            ratio = recent / (older + 1e-8)
            if ratio > 10:
                trends["gradients"] = "⚠️ EXPLODING"
            elif ratio < 0.1:
                trends["gradients"] = "⚠️ VANISHING"
            else:
                trends["gradients"] = "✓ STABLE"

        if len(self.fr_history) >= 10:
            recent = np.mean(self.fr_history[-10:])
            if recent < 0.01:
                trends["firing"] = "⚠️ DEAD NETWORK"
            elif recent > 0.8:
                trends["firing"] = "⚠️ SATURATED"
            else:
                trends["firing"] = "✓ HEALTHY"

        return trends

    def diagnose(self) -> str:
        """Generate a diagnostic summary.

        Averages the 20 most recent entries of each non-empty history and
        returns a multi-line, human-readable report. The Lyapunov average is
        classified as chaotic (> 0.5), over-stable (< -0.5) or edge of chaos;
        gradient and firing sections include the trend from ``get_trends``
        when one is available.
        """
        trends = self.get_trends()

        lines = ["=== Stability Diagnosis ==="]

        if self.lyap_history:
            avg_lyap = np.mean(self.lyap_history[-20:])
            lines.append(f"Lyapunov exponent: {avg_lyap:.4f}")
            if avg_lyap > 0.5:
                lines.append("  → Network is CHAOTIC (trajectories diverge quickly)")
                lines.append("  → Suggestion: Increase lambda_reg or decrease learning rate")
            elif avg_lyap < -0.5:
                lines.append("  → Network is OVER-STABLE (trajectories collapse)")
                lines.append("  → Suggestion: May lose expressiveness, consider reducing regularization")
            else:
                lines.append("  → Network is at EDGE OF CHAOS (good for learning)")

        if self.grad_history:
            avg_grad = np.mean(self.grad_history[-20:])
            max_grad = max(self.grad_history[-20:])
            lines.append(f"Gradient norm: avg={avg_grad:.4f}, max={max_grad:.4f}")
            if "gradients" in trends:
                lines.append(f"  → {trends['gradients']}")

        if self.fr_history:
            avg_fr = np.mean(self.fr_history[-20:])
            lines.append(f"Firing rate: {avg_fr:.4f}")
            if "firing" in trends:
                lines.append(f"  → {trends['firing']}")

        return "\n".join(lines)


def compute_spectral_radius(weight_matrix: torch.Tensor) -> float:
    """
    Return the spectral radius of a square weight matrix.

    The spectral radius is the largest eigenvalue magnitude. For recurrent
    networks a value above 1 indicates potential instability.

    Args:
        weight_matrix: 2D (square) weight tensor

    Returns:
        Largest absolute eigenvalue as a Python float
    """
    with torch.no_grad():
        dense = weight_matrix.detach().cpu().numpy()
    # Eigenvalues may be complex; np.abs yields their moduli.
    magnitudes = np.abs(np.linalg.eigvals(dense))
    return float(magnitudes.max())


def analyze_weight_spectrum(model: nn.Module) -> Dict[str, float]:
    """
    Compute the spectral radius of every square weight matrix in a model.

    A parameter is included when its name contains "weight", it is 2D, and
    both dimensions are equal (treated as a recurrent matrix).

    Returns:
        Dictionary mapping parameter names to spectral radii.
    """
    def is_square_weight(name, param) -> bool:
        # Short-circuits in the same order as the nested checks it replaces.
        return (
            "weight" in name
            and param.ndim == 2
            and param.shape[0] == param.shape[1]
        )

    return {
        name: compute_spectral_radius(param)
        for name, param in model.named_parameters()
        if is_square_weight(name, param)
    }