diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
| commit | 00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch) | |
| tree | 77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /scripts/test_optimized_forward.py | |
| parent | c53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff) | |
| parent | cd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff) | |
Merge master into main
Diffstat (limited to 'scripts/test_optimized_forward.py')
| -rw-r--r-- | scripts/test_optimized_forward.py | 96 |
1 files changed, 96 insertions, 0 deletions
diff --git a/scripts/test_optimized_forward.py b/scripts/test_optimized_forward.py new file mode 100644 index 0000000..71f8923 --- /dev/null +++ b/scripts/test_optimized_forward.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +"""Quick test to verify the optimized SpikingVGG forward works correctly. + +Tests: +1. Forward pass without Lyapunov +2. Forward pass with Lyapunov (global renorm) +3. Backward pass with gradients +4. Multiple training steps +5. Verify global renorm produces consistent perturbation norm +""" + +import sys +sys.path.insert(0, '/projects/bfqt/users/yurenh2/ml-projects/snn-training') + +import torch +import torch.nn as nn + +from files.experiments.depth_scaling_benchmark import SpikingVGG + +def test_forward(): + """Test that forward pass works with and without Lyapunov computation.""" + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Testing on device: {device}") + print("Using: Global delta + Global renorm (Option 1 - textbook LE)") + + # Create model + model = SpikingVGG( + in_channels=3, + num_classes=10, + base_channels=64, + num_stages=3, + blocks_per_stage=2, + T=4, + ).to(device) + + print(f"Model depth: {model.depth} conv layers") + print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}") + + # Test input + B = 8 + x = torch.randn(B, 3, 32, 32, device=device) + y = torch.randint(0, 10, (B,), device=device) + + # Test 1: Forward without Lyapunov + print("\n[Test 1] Forward without Lyapunov...") + logits, lyap, _ = model(x, compute_lyapunov=False) + assert logits.shape == (B, 10), f"Expected (B, 10), got {logits.shape}" + assert lyap is None, "Expected lyap to be None" + print(f" Logits shape: {logits.shape} ✓") + + # Test 2: Forward with Lyapunov + print("\n[Test 2] Forward with Lyapunov...") + logits, lyap, _ = model(x, compute_lyapunov=True) + assert logits.shape == (B, 10), f"Expected (B, 10), got {logits.shape}" + assert lyap is not None, "Expected lyap to be a tensor" + assert isinstance(lyap.item(), float), "Expected lyap to be a scalar" + print(f" Logits shape: {logits.shape} ✓") + print(f" Lyapunov exponent: {lyap.item():.4f} ✓") + + # Test 3: Backward pass + print("\n[Test 3] Backward pass...") + criterion = nn.CrossEntropyLoss() + loss = criterion(logits, y) + 0.3 * (lyap ** 2) + loss.backward() + + grad_norm = sum(p.grad.norm().item()**2 for p in model.parameters() if p.grad is not None)**0.5 + print(f" Loss: {loss.item():.4f} ✓") + print(f" Gradient norm: {grad_norm:.4f} ✓") + + # Test 4: Check gradients are not NaN + has_nan = any(torch.isnan(p.grad).any() for p in model.parameters() if p.grad is not None) + assert not has_nan, "Found NaN in gradients!" + print(f" No NaN gradients ✓") + + # Test 5: Multiple forward-backward passes (training simulation) + print("\n[Test 4] Multiple training steps...") + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + for step in range(5): + optimizer.zero_grad() + x_batch = torch.randn(B, 3, 32, 32, device=device) + y_batch = torch.randint(0, 10, (B,), device=device) + + logits, lyap, _ = model(x_batch, compute_lyapunov=True) + loss = criterion(logits, y_batch) + 0.3 * (lyap ** 2) + loss.backward() + optimizer.step() + + print(f" Step {step+1}: loss={loss.item():.4f}, λ={lyap.item():.4f}") + + print("\n" + "="*50) + print("ALL TESTS PASSED!") + print("="*50) + +if __name__ == "__main__": + test_forward() |
