summaryrefslogtreecommitdiff
path: root/scripts/test_optimized_forward.py
blob: 71f8923100288451295549871abe22755db3549a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/env python
"""Quick test to verify the optimized SpikingVGG forward works correctly.

Tests:
1. Forward pass without Lyapunov
2. Forward pass with Lyapunov (global renorm)
3. Backward pass with gradients (gradients are also checked for NaN)
4. Multiple training steps
5. Verify global renorm produces consistent perturbation norm (not yet implemented in this script)
"""

import sys
sys.path.insert(0, '/projects/bfqt/users/yurenh2/ml-projects/snn-training')

import torch
import torch.nn as nn

from files.experiments.depth_scaling_benchmark import SpikingVGG

def test_forward():
    """Smoke-test SpikingVGG forward/backward with and without Lyapunov.

    Builds a SpikingVGG (in_channels=3, num_classes=10, base_channels=64,
    num_stages=3, blocks_per_stage=2, T=4) on CUDA if available, else CPU.
    On random 3x32x32 inputs with batch size 8 it then runs:

    1. A forward pass with ``compute_lyapunov=False``; expects ``lyap is None``.
    2. A forward pass with ``compute_lyapunov=True``; expects a finite,
       single-element floating-point ``lyap`` tensor.
    3. A backward pass through ``CrossEntropy + 0.3 * lyap**2``; checks that
       at least one parameter received a gradient and that every gradient is
       finite (no NaN/Inf).
    4. Five Adam steps on fresh random batches with the same loss, checking
       that the loss stays finite.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    # Seed so random inputs and weight initialization are repeatable run-to-run.
    torch.manual_seed(0)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Testing on device: {device}")
    print("Using: Global delta + Global renorm (Option 1 - textbook LE)")

    num_classes = 10
    lyap_weight = 0.3  # weight of the lambda**2 regularizer added to the loss

    # Create model
    model = SpikingVGG(
        in_channels=3,
        num_classes=num_classes,
        base_channels=64,
        num_stages=3,
        blocks_per_stage=2,
        T=4,
    ).to(device)

    print(f"Model depth: {model.depth} conv layers")
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Test input
    B = 8
    x = torch.randn(B, 3, 32, 32, device=device)
    y = torch.randint(0, num_classes, (B,), device=device)

    # Test 1: Forward without Lyapunov
    print("\n[Test 1] Forward without Lyapunov...")
    logits, lyap, _ = model(x, compute_lyapunov=False)
    assert logits.shape == (B, num_classes), f"Expected (B, 10), got {logits.shape}"
    assert lyap is None, "Expected lyap to be None"
    print(f"  Logits shape: {logits.shape} ✓")

    # Test 2: Forward with Lyapunov
    print("\n[Test 2] Forward with Lyapunov...")
    logits, lyap, _ = model(x, compute_lyapunov=True)
    assert logits.shape == (B, num_classes), f"Expected (B, 10), got {logits.shape}"
    assert lyap is not None, "Expected lyap to be a tensor"
    # Check the element count explicitly: .item() on a multi-element tensor
    # raises a RuntimeError instead of failing this assertion cleanly.
    assert lyap.numel() == 1 and lyap.is_floating_point(), "Expected lyap to be a scalar"
    assert torch.isfinite(lyap).all(), f"Lyapunov exponent is not finite: {lyap.item()}"
    print(f"  Logits shape: {logits.shape} ✓")
    print(f"  Lyapunov exponent: {lyap.item():.4f} ✓")

    # Test 3: Backward pass
    print("\n[Test 3] Backward pass...")
    criterion = nn.CrossEntropyLoss()
    loss = criterion(logits, y) + lyap_weight * (lyap ** 2)
    loss.backward()

    grads = {
        name: p.grad
        for name, p in model.named_parameters()
        if p.grad is not None
    }
    # An empty dict would make the norm/NaN checks below pass vacuously.
    assert grads, "No parameter received a gradient!"

    grad_norm = sum(g.norm().item() ** 2 for g in grads.values()) ** 0.5
    print(f"  Loss: {loss.item():.4f} ✓")
    print(f"  Gradient norm: {grad_norm:.4f} ✓")

    # Test 3 (cont.): gradients must be finite -- isfinite catches NaN and Inf.
    bad_params = [name for name, g in grads.items() if not torch.isfinite(g).all()]
    assert not bad_params, f"Found NaN/Inf in gradients of: {bad_params}"
    print("  No NaN/Inf gradients ✓")

    # Test 4: Multiple forward-backward passes (training simulation)
    print("\n[Test 4] Multiple training steps...")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for step in range(1, 6):
        optimizer.zero_grad()
        x_batch = torch.randn(B, 3, 32, 32, device=device)
        y_batch = torch.randint(0, num_classes, (B,), device=device)

        logits, lyap, _ = model(x_batch, compute_lyapunov=True)
        loss = criterion(logits, y_batch) + lyap_weight * (lyap ** 2)
        assert torch.isfinite(loss).all(), f"Non-finite loss at step {step}"
        loss.backward()
        optimizer.step()

        print(f"  Step {step}: loss={loss.item():.4f}, λ={lyap.item():.4f}")

    print("\n" + "="*50)
    print("ALL TESTS PASSED!")
    print("="*50)

if __name__ == "__main__":
    # Executed directly rather than imported: run the full check sequence.
    test_forward()