Diffstat (limited to 'scripts/test_optimized_forward.py')
 scripts/test_optimized_forward.py | 96 +++++++++++++++++++++++++++++++++++++
 1 file changed, 96 insertions(+), 0 deletions(-)
diff --git a/scripts/test_optimized_forward.py b/scripts/test_optimized_forward.py
new file mode 100644
index 0000000..71f8923
--- /dev/null
+++ b/scripts/test_optimized_forward.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python
+"""Quick test to verify the optimized SpikingVGG forward works correctly.
+
+Tests:
+1. Forward pass without Lyapunov
+2. Forward pass with Lyapunov (global renorm)
+3. Backward pass with gradients
+4. Multiple training steps
+5. Verify global renorm produces consistent perturbation norm
+"""
+
+import sys
+sys.path.insert(0, '/projects/bfqt/users/yurenh2/ml-projects/snn-training')
+
+import torch
+import torch.nn as nn
+
+from files.experiments.depth_scaling_benchmark import SpikingVGG
+
+def test_forward():
+ """Test that forward pass works with and without Lyapunov computation."""
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print(f"Testing on device: {device}")
+ print("Using: Global delta + Global renorm (Option 1 - textbook LE)")
+
+ # Create model
+ model = SpikingVGG(
+ in_channels=3,
+ num_classes=10,
+ base_channels=64,
+ num_stages=3,
+ blocks_per_stage=2,
+ T=4,
+ ).to(device)
+
+ print(f"Model depth: {model.depth} conv layers")
+ print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+ # Test input
+ B = 8
+ x = torch.randn(B, 3, 32, 32, device=device)
+ y = torch.randint(0, 10, (B,), device=device)
+
+ # Test 1: Forward without Lyapunov
+ print("\n[Test 1] Forward without Lyapunov...")
+ logits, lyap, _ = model(x, compute_lyapunov=False)
+ assert logits.shape == (B, 10), f"Expected (B, 10), got {logits.shape}"
+ assert lyap is None, "Expected lyap to be None"
+ print(f" Logits shape: {logits.shape} ✓")
+
+ # Test 2: Forward with Lyapunov
+ print("\n[Test 2] Forward with Lyapunov...")
+ logits, lyap, _ = model(x, compute_lyapunov=True)
+ assert logits.shape == (B, 10), f"Expected (B, 10), got {logits.shape}"
+ assert lyap is not None, "Expected lyap to be a tensor"
+    assert lyap.dim() == 0, "Expected lyap to be a scalar tensor"
+ print(f" Logits shape: {logits.shape} ✓")
+ print(f" Lyapunov exponent: {lyap.item():.4f} ✓")
+
+ # Test 3: Backward pass
+ print("\n[Test 3] Backward pass...")
+ criterion = nn.CrossEntropyLoss()
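+    # Add a Lyapunov penalty on top of the task loss: minimizing lyap**2 pushes the
+    # estimated exponent toward 0 (perturbations neither explode nor vanish).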
+ loss = criterion(logits, y) + 0.3 * (lyap ** 2)
+ loss.backward()
+
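+    # Global gradient norm: L2 norm over all parameter gradients,
+    # i.e. the square root of the sum of squared per-tensor norms.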
+ grad_norm = sum(p.grad.norm().item()**2 for p in model.parameters() if p.grad is not None)**0.5
+ print(f" Loss: {loss.item():.4f} ✓")
+ print(f" Gradient norm: {grad_norm:.4f} ✓")
+
+    # Still part of Test 3: check that no gradient contains NaNs
+    has_nan = any(torch.isnan(p.grad).any() for p in model.parameters() if p.grad is not None)
+    assert not has_nan, "Found NaN in gradients!"
+    print("  No NaN gradients ✓")
+
+    # Test 4: Multiple forward-backward passes (training simulation)
+    print("\n[Test 4] Multiple training steps...")
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+ for step in range(5):
+ optimizer.zero_grad()
+ x_batch = torch.randn(B, 3, 32, 32, device=device)
+ y_batch = torch.randint(0, 10, (B,), device=device)
+
+ logits, lyap, _ = model(x_batch, compute_lyapunov=True)
+ loss = criterion(logits, y_batch) + 0.3 * (lyap ** 2)
+ loss.backward()
+ optimizer.step()
+
+ print(f" Step {step+1}: loss={loss.item():.4f}, λ={lyap.item():.4f}")
+
+ print("\n" + "="*50)
+ print("ALL TESTS PASSED!")
+ print("="*50)
+
+if __name__ == "__main__":
+ test_forward()
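
For reference, a minimal standalone sketch of the "global delta + global renorm" (textbook largest-Lyapunov-exponent) estimate that this script exercises is given below. It is an assumption about the general technique, not the actual SpikingVGG implementation: step_fn, x0, and eps are hypothetical names, and the real model folds the computation into its time-stepped forward pass when compute_lyapunov=True.

import torch

def lyapunov_estimate(step_fn, x0, T=4, eps=1e-3):
    """Textbook largest-Lyapunov-exponent estimate with global renormalization.

    step_fn: callable mapping a state tensor to the next state (hypothetical).
    x0:      initial state tensor.
    T:       number of time steps over which log growth rates are averaged.
    eps:     norm the perturbation is renormalized to after every step.
    """
    x = x0
    # A single global perturbation vector, rescaled to norm eps.
    delta = torch.randn_like(x0)
    delta = eps * delta / delta.norm()

    log_growth = []
    for _ in range(T):
        x_next = step_fn(x)
        x_pert_next = step_fn(x + delta)

        # Growth of the perturbation over one step, measured with a single
        # global norm (not per-layer), which is what "global renorm" refers to.
        new_delta = x_pert_next - x_next
        growth = new_delta.norm() / eps
        log_growth.append(torch.log(growth + 1e-12))

        # Renormalize globally so the perturbation stays infinitesimal.
        delta = eps * new_delta / (new_delta.norm() + 1e-12)
        x = x_next

    # Average log growth rate per step approximates the largest Lyapunov exponent.
    return torch.stack(log_growth).mean()

A per-layer variant would renormalize each layer's slice of delta separately; the global version above matches the classical definition of the largest Lyapunov exponent, which is presumably why the script labels it "Option 1 - textbook LE".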