# Experiment Guide
## Requirements
- Python 3.10+
- PyTorch 2.x with CUDA
- torchvision, numpy, scipy, matplotlib
## Project Structure
```
models/
    residual_mlp.py       - Deep residual MLP (pre-LayerNorm + GELU blocks)
    value_net.py          - Scalar value network V_phi for credit bridge
    state_bridge.py       - State predictor G_psi for state bridge
experiments/
    toy_lq_v2.py          - Phase A: Linear-quadratic sanity check
    cifar_resmlp.py       - Phase B: CIFAR-10 main experiment
    plot_toy_final.py     - Generate toy plots
    plot_cifar_final.py   - Generate CIFAR plots
metrics/
    credit_metrics.py     - Diagnostic metrics (cosine, rho, nudging, etc.)
configs/                  - YAML configs
report/                   - Plots and final report
results/                  - Experiment outputs
```
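
`models/residual_mlp.py` is described only as a stack of pre-LayerNorm + GELU residual blocks. A common reading of that phrase is `h + W2 · GELU(W1 · LN(h))`: normalize first, apply the nonlinearity inside the branch, and add the residual last. The dependency-free sketch below assumes exactly that layout (no biases, no learnable LayerNorm gain/bias, no width expansion), so the actual module may differ; all function names are hypothetical.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """LayerNorm without learnable affine parameters (gain 1, bias 0)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def gelu(v: float) -> float:
    """Exact GELU: v * Phi(v), where Phi is the standard normal CDF."""
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def matvec(w: Matrix, x: Vector) -> Vector:
    """Dense matrix-vector product, one output per row of w."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def residual_block(h: Vector, w1: Matrix, w2: Matrix) -> Vector:
    """Assumed pre-LN block: h + W2 @ GELU(W1 @ LN(h))."""
    z = matvec(w1, layer_norm(h))
    u = matvec(w2, [gelu(v) for v in z])
    return [hi + ui for hi, ui in zip(h, u)]
```

Because the residual is added after the branch, a block whose second weight matrix is zero reduces to the identity, which is one reason pre-LN residual stacks stay trainable at depth.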
## Running Experiments
### Phase A: Toy LQ Sanity Check
```bash
# Single seed
CUDA_VISIBLE_DEVICES=0 python experiments/toy_lq_v2.py \
    --gpu 0 --seed 42 --num_steps 8000 \
    --sigma_bridge 0.1 --lam 0.1 \
    --term_grad_weight 1.0 --fm_weight 0.0 \
    --output_dir results/toy_lq_frozen

# All 3 seeds
for seed in 42 123 456; do
    CUDA_VISIBLE_DEVICES=0 python experiments/toy_lq_v2.py \
        --gpu 0 --seed "$seed" --num_steps 8000 \
        --sigma_bridge 0.1 --lam 0.1 \
        --term_grad_weight 1.0 --fm_weight 0.0 \
        --output_dir results/toy_lq_frozen
done
```
### Phase B: CIFAR-10 Main Experiment
```bash
# Single seed (runs BP, DFA, State Bridge, and Credit Bridge sequentially)
CUDA_VISIBLE_DEVICES=0 python experiments/cifar_resmlp.py \
    --dataset cifar10 --d_hidden 512 --num_blocks 12 \
    --epochs 100 --seeds 42 --gpu 0 \
    --output_dir results/cifar10

# Parallel across GPUs, one seed per GPU. CUDA_VISIBLE_DEVICES exposes a single
# device to each process, where it appears as index 0, hence --gpu 0 everywhere.
CUDA_VISIBLE_DEVICES=0 python experiments/cifar_resmlp.py --seeds 42 --output_dir results/cifar10 --gpu 0 &
CUDA_VISIBLE_DEVICES=1 python experiments/cifar_resmlp.py --seeds 123 --output_dir results/cifar10_seed123 --gpu 0 &
CUDA_VISIBLE_DEVICES=2 python experiments/cifar_resmlp.py --seeds 456 --output_dir results/cifar10_seed456 --gpu 0 &
wait
```
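
The same parallel launch can be driven from Python when a shell loop is inconvenient. This is a minimal sketch, assuming one seed per physical GPU and the output directories used in the commands above; `build_jobs`, `launch`, and `OUTPUT_DIRS` are hypothetical helpers, not part of the repository. It only assembles environment overrides and argument lists, and `launch` starts every job before waiting on any of them, mirroring `&` followed by `wait`.

```python
import os
import subprocess
from typing import Dict, List, Sequence, Tuple

# Output directories from the parallel commands above (seed 42 writes to results/cifar10).
OUTPUT_DIRS: Dict[int, str] = {
    42: "results/cifar10",
    123: "results/cifar10_seed123",
    456: "results/cifar10_seed456",
}

Job = Tuple[Dict[str, str], List[str]]


def build_jobs(seeds: Sequence[int], gpus: Sequence[int]) -> List[Job]:
    """Pair each seed with one physical GPU; return (env overrides, argv) per job."""
    if len(gpus) < len(seeds):
        raise ValueError("need at least one GPU per seed")
    jobs: List[Job] = []
    for seed, gpu in zip(seeds, gpus):
        env = {"CUDA_VISIBLE_DEVICES": str(gpu)}
        argv = [
            "python", "experiments/cifar_resmlp.py",
            "--seeds", str(seed),
            "--output_dir", OUTPUT_DIRS[seed],
            "--gpu", "0",  # the only visible device is index 0 inside the child
        ]
        jobs.append((env, argv))
    return jobs


def launch(jobs: List[Job]) -> List[int]:
    """Start all jobs concurrently, then wait for each; returns exit codes."""
    procs = [subprocess.Popen(argv, env={**os.environ, **env}) for env, argv in jobs]
    return [p.wait() for p in procs]
```

Usage: `launch(build_jobs([42, 123, 456], [0, 1, 2]))`. Keeping job construction separate from process creation makes the command lines easy to inspect (or print) before anything touches a GPU.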
### Generate Plots
```bash
python experiments/plot_toy_final.py
python experiments/plot_cifar_final.py
```
## Key Parameters
| Parameter | Toy LQ | CIFAR-10 | Description |
|-----------|--------|----------|-------------|
| d_hidden | 64 | 512 | Hidden dimension |
| num_layers/blocks | 12 | 12 | Depth (layers / residual blocks) |
| sigma_bridge | 0.1 | 0.05 | Bridge noise standard deviation |
| lam | 0.1 | 0.1 | Temperature |
| K | 8 | 4 | Monte Carlo samples for the bridge target |
| term_grad_weight | 1.0 | 1.0 | Terminal gradient matching weight |
| ema_momentum | 0.995 | 0.995 | EMA momentum for the target network |
| lr_fb | 1e-3 | 1e-3 | Feedback network learning rate |
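
The table can be transcribed into plain dictionaries, and `ema_momentum` can be read as the coefficient of a standard exponential-moving-average target update, `target <- m * target + (1 - m) * online`. Both are sketches under assumptions: the dictionary keys (for example `num_blocks` for the depth row) and the `ema_update` helper are hypothetical, and the scripts may apply the EMA per parameter tensor rather than to flat lists.

```python
from typing import Dict, List, Union

Number = Union[int, float]

# Values transcribed from the Key Parameters table; key names are illustrative.
TOY_LQ: Dict[str, Number] = {
    "d_hidden": 64, "num_blocks": 12, "sigma_bridge": 0.1, "lam": 0.1,
    "K": 8, "term_grad_weight": 1.0, "ema_momentum": 0.995, "lr_fb": 1e-3,
}
# CIFAR-10 differs only in width, bridge noise, and Monte Carlo sample count.
CIFAR10: Dict[str, Number] = {**TOY_LQ, "d_hidden": 512, "sigma_bridge": 0.05, "K": 4}


def ema_update(target: List[float], online: List[float], momentum: float) -> None:
    """In place: target <- momentum * target + (1 - momentum) * online."""
    for i, (t, o) in enumerate(zip(target, online)):
        target[i] = momentum * t + (1.0 - momentum) * o
```

Starting from zero and tracking a constant online value of 1, the target after `n` updates is `1 - m**n`; with `m = 0.995` the effective averaging horizon is roughly `1 / (1 - m) = 200` updates.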
## Implementation Notes
- **No hidden BP anchor**: Non-BP methods never use exact backprop through hidden layers.
- **Detached hidden copies**: All feedback-net and value-net inputs are created with `detach().requires_grad_(True)`, so gradients taken with respect to them stop there and never reach earlier blocks.
- **Block-local updates**: Each block's parameters are updated only from its local forward pass and its credit signal.
- **Output head**: Trained with the exact cross-entropy (CE) gradient, computed on a detached `h_L`.
- **Terminal gradient matching**: Matches `grad_h V` at the terminal layer to `grad_{h_L} CE`. This uses only output-layer-local information, not hidden-layer BP.
- **Credit bridge warmup**: The first 20% of epochs use DFA credits; after that, credits are linearly blended toward credit-bridge credits.
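
Two of these notes can be made concrete. Assuming the output head is a single linear layer `z = W h_L + b` (an assumption; the head's form is not stated here), the terminal target is `grad_{h_L} CE = W^T (softmax(z) - onehot(y))`, and terminal matching can be read as a squared error between that target and `grad_h V`, scaled by `term_grad_weight`. For the warmup, the notes say the blend is linear but not how long it lasts, so `blend_epochs` below is a hypothetical parameter. All function names are illustrative, not the repository's API.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def softmax(z: Vector) -> Vector:
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def ce_grad_wrt_hidden(w: Matrix, b: Vector, h_l: Vector, y: int) -> Vector:
    """Gradient of CE(W h_L + b, y) w.r.t. h_L: W^T (softmax(z) - onehot(y))."""
    z = [sum(wij * hj for wij, hj in zip(row, h_l)) + bi for row, bi in zip(w, b)]
    delta = softmax(z)
    delta[y] -= 1.0
    return [sum(w[k][j] * delta[k] for k in range(len(w))) for j in range(len(h_l))]


def terminal_matching_loss(grad_v: Vector, grad_ce: Vector, weight: float = 1.0) -> float:
    """Assumed form: weight * squared L2 distance between grad_h V and the CE target."""
    return weight * sum((a - c) ** 2 for a, c in zip(grad_v, grad_ce))


def credit_mix(epoch: int, total_epochs: int, blend_epochs: int) -> float:
    """Weight on credit-bridge credits: 0 for the first 20% of epochs, then a linear ramp to 1."""
    warmup = total_epochs // 5  # first 20% of epochs use pure DFA credits
    if epoch < warmup:
        return 0.0
    return min(1.0, (epoch - warmup) / max(1, blend_epochs))


def blended_credit(dfa: Vector, bridge: Vector, alpha: float) -> Vector:
    """Convex combination: (1 - alpha) * DFA credit + alpha * credit-bridge credit."""
    return [(1.0 - alpha) * d + alpha * c for d, c in zip(dfa, bridge)]
```

The CE target needs only the head weights, the logits, and the label, which is why matching against it stays output-layer-local rather than reintroducing backprop through hidden blocks.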