#!/usr/bin/env python
"""Smoke test for the optimized SpikingVGG forward pass.

Checks performed by ``test_forward``:
    1. Forward pass without Lyapunov computation (``lyap`` must be ``None``).
    2. Forward pass with Lyapunov computation (``lyap`` must be a scalar).
    3. Backward pass through the combined loss: at least one parameter must
       receive a gradient and no gradient may contain NaN.
    4. Five optimizer steps on fresh random batches.

NOTE: the consistency of the perturbation norm under global renormalization
is not asserted by this script.
"""

import sys

# Machine-specific project root; required for the ``files.*`` import below.
sys.path.insert(0, '/projects/bfqt/users/yurenh2/ml-projects/snn-training')

import torch
import torch.nn as nn

from files.experiments.depth_scaling_benchmark import SpikingVGG


def test_forward() -> None:
    """Run the forward/backward smoke checks on a small SpikingVGG.

    Builds the model on CUDA when available (CPU otherwise), feeds it random
    32x32 three-channel inputs, and prints progress for each check.

    Raises:
        AssertionError: if the logits have the wrong shape, the Lyapunov
            output is missing/present when it should not be or is not a
            float scalar, no parameter receives a gradient, or any gradient
            contains NaN.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Testing on device: {device}")
    print("Using: Global delta + Global renorm (Option 1 - textbook LE)")

    # Create model
    model = SpikingVGG(
        in_channels=3,
        num_classes=10,
        base_channels=64,
        num_stages=3,
        blocks_per_stage=2,
        T=4,
    ).to(device)
    print(f"Model depth: {model.depth} conv layers")
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Test input
    B = 8
    x = torch.randn(B, 3, 32, 32, device=device)
    y = torch.randint(0, 10, (B,), device=device)

    # Weight of the squared-Lyapunov penalty added to the cross-entropy loss.
    lyap_weight = 0.3

    # Test 1: Forward without Lyapunov
    print("\n[Test 1] Forward without Lyapunov...")
    logits, lyap, _ = model(x, compute_lyapunov=False)
    assert logits.shape == (B, 10), f"Expected (B, 10), got {logits.shape}"
    assert lyap is None, "Expected lyap to be None"
    print(f" Logits shape: {logits.shape} ✓")

    # Test 2: Forward with Lyapunov
    print("\n[Test 2] Forward with Lyapunov...")
    logits, lyap, _ = model(x, compute_lyapunov=True)
    assert logits.shape == (B, 10), f"Expected (B, 10), got {logits.shape}"
    assert lyap is not None, "Expected lyap to be a tensor"
    # Check the element count first: .item() raises on a multi-element
    # tensor, which would hide the assertion message below.
    assert lyap.numel() == 1 and isinstance(lyap.item(), float), \
        "Expected lyap to be a scalar"
    print(f" Logits shape: {logits.shape} ✓")
    print(f" Lyapunov exponent: {lyap.item():.4f} ✓")

    # Test 3: Backward pass
    print("\n[Test 3] Backward pass...")
    criterion = nn.CrossEntropyLoss()
    loss = criterion(logits, y) + lyap_weight * (lyap ** 2)
    loss.backward()

    grads = [p.grad for p in model.parameters() if p.grad is not None]
    # Without this guard the norm and NaN checks pass vacuously when
    # backward() populated nothing.
    assert grads, "No parameter received a gradient"
    grad_norm = sum(g.norm().item() ** 2 for g in grads) ** 0.5
    print(f" Loss: {loss.item():.4f} ✓")
    print(f" Gradient norm: {grad_norm:.4f} ✓")

    # Still part of Test 3: gradients must not contain NaN.
    has_nan = any(torch.isnan(g).any() for g in grads)
    assert not has_nan, "Found NaN in gradients!"
    print(" No NaN gradients ✓")

    # Test 4: Multiple forward-backward passes (training simulation)
    print("\n[Test 4] Multiple training steps...")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for step in range(5):
        # Also clears the gradients left over from Test 3 on the first step.
        optimizer.zero_grad()
        x_batch = torch.randn(B, 3, 32, 32, device=device)
        y_batch = torch.randint(0, 10, (B,), device=device)

        logits, lyap, _ = model(x_batch, compute_lyapunov=True)
        loss = criterion(logits, y_batch) + lyap_weight * (lyap ** 2)
        loss.backward()
        optimizer.step()

        print(f" Step {step+1}: loss={loss.item():.4f}, λ={lyap.item():.4f}")

    print("\n" + "=" * 50)
    print("ALL TESTS PASSED!")
    print("=" * 50)


if __name__ == "__main__":
    test_forward()