# Pivot Design Memo: Direct Vector Credit Field

## Why Scalar V May Have Value-Correct but Curvature-Wrong Gradients

The current credit bridge learns a scalar function V_phi(h_l, t_l, s) and defines credit as a_l = grad_h V_phi. The bridge consistency loss constrains V's **values** at successive layers:

    V(h_l, t_l, s) ≈ soft-min_noise V_bar(h_{l+1} + noise, t_{l+1}, s)

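This gives correct V values but provides only **indirect** constraints on grad_h V. Matching values is a zeroth-order constraint imposed at sampled points, whereas the gradient is a first-order property of V and the way it varies across h is set by V's curvature, a second-order property. The value-matching loss optimizes neither directly, so V can fit its targets well while its local slope and curvature in h are wrong.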
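A small numerical illustration of this point, as a sketch rather than a reproduction of the experiments: the toy function `V_true` and the ripple parameters `amp` and `freq` below are hypothetical. Adding a tiny high-frequency ripple changes the values by at most `amp` but changes the gradient by an amount comparable to the gradient itself.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 128
w = rng.standard_normal(d)
w /= np.linalg.norm(w)          # unit direction of the ripple
amp, freq = 1e-3, 1e3           # tiny amplitude, high frequency (hypothetical)

def V_true(h):
    return 0.5 * np.sum(h * h, axis=-1)

def grad_true(h):
    return h

def V_fit(h):
    # Same values up to +/- amp, but with a fast ripple along w.
    return V_true(h) + amp * np.sin(freq * (h @ w))

def grad_fit(h):
    ripple = amp * freq * np.cos(freq * (h @ w))
    return h + ripple[..., None] * w

h = 0.05 * rng.standard_normal((1000, d))      # sample points
value_err = np.max(np.abs(V_fit(h) - V_true(h)))
g_fit, g_true = grad_fit(h), grad_true(h)
cosine = np.sum(g_fit * g_true, axis=-1) / (
    np.linalg.norm(g_fit, axis=-1) * np.linalg.norm(g_true, axis=-1)
)
mean_cos = float(np.mean(cosine))

print("max value error:", value_err)           # bounded by amp
print("mean gradient cosine:", mean_cos)       # noticeably below 1
```

A value-matching loss evaluated at these sample points would consider `V_fit` an almost perfect fit, even though its gradient field is a poor stand-in for `grad_true`.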

Terminal gradient matching addresses this at the boundary (l=L), but that information must then propagate backward through the bridge consistency, which is a value-level (zeroth-order) constraint. Each bridge-consistency step away from the terminal layer therefore loses gradient information.

**Evidence from experiments:**
- Without terminal gradient matching: V values converge but gradients are uninformative (cosine → 0.03)
- With terminal gradient matching: gradients improve but degrade with distance from terminal layer
- On CIFAR (d=512), the 10-dim terminal code carries too little gradient information
- Conditioning on delta_L (d-dimensional) helps on the synthetic task, but the fundamental issue remains

The core problem: **optimizing a scalar function's values does not efficiently constrain its d-dimensional gradient field**, especially in high dimensions.

## Direct Vector Credit Field: The Alternative

Instead of V_phi: R^d x R x R^s -> R, learn the credit directly:

    a_phi(h_l, t_l, s) in R^d

This outputs the credit vector without going through a scalar intermediate. The gradient computation disappears — the network output IS the credit.

### Architecture

```
Input: [LN(h_l), time_embed(t_l), s]
-> MLP (same as current ValueNet architecture)
-> Linear(hidden_dim, d_hidden)  # Output credit vector with the same dimension d as h_l (assuming d_hidden denotes d)
```
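The architecture above can be sketched as a forward-only NumPy module. This is a minimal sketch under stated assumptions, not the existing ValueNet: the class name `VectorCreditField`, the two-layer tanh MLP, the sinusoidal `time_embed`, and all sizes are hypothetical choices made only to show the input and output shapes.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def time_embed(t, dim=16):
    # Sinusoidal embedding of the layer time t in [0, 1] (assumed form; dim must be even).
    freqs = 2.0 ** np.arange(dim // 2)
    ang = np.asarray(t, dtype=float)[..., None] * freqs * np.pi
    return np.concatenate([np.sin(ang), np.cos(ang)], axis=-1)

class VectorCreditField:
    """a_phi(h_l, t_l, s) -> R^d, forward pass only (hypothetical sketch)."""

    def __init__(self, d, s_dim, hidden_dim=256, t_dim=16, seed=0):
        rng = np.random.default_rng(seed)
        in_dim = d + t_dim + s_dim
        self.t_dim = t_dim
        self.W1 = rng.standard_normal((in_dim, hidden_dim)) / np.sqrt(in_dim)
        self.b1 = np.zeros(hidden_dim)
        self.W2 = rng.standard_normal((hidden_dim, hidden_dim)) / np.sqrt(hidden_dim)
        self.b2 = np.zeros(hidden_dim)
        # Output head maps to d dimensions: the credit has the same shape as h_l.
        self.W_out = rng.standard_normal((hidden_dim, d)) / np.sqrt(hidden_dim)
        self.b_out = np.zeros(d)

    def __call__(self, h, t, s):
        x = np.concatenate([layer_norm(h), time_embed(t, self.t_dim), s], axis=-1)
        z = np.tanh(x @ self.W1 + self.b1)
        z = np.tanh(z @ self.W2 + self.b2)
        return z @ self.W_out + self.b_out

rng = np.random.default_rng(1)
net = VectorCreditField(d=128, s_dim=10)
h = rng.standard_normal((8, 128))
s = rng.standard_normal((8, 10))
a = net(h, np.full(8, 0.5), s)
print(a.shape)  # one credit vector per sample, same shape as h
```

The only structural change relative to a scalar value network is the output head: it emits d numbers per sample instead of one, so no differentiation with respect to h is needed to read off the credit.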

### Training Objective

The bridge consistency becomes a **vector** consistency:

    a_phi(h_l, t_l, s) ≈ J_l^T a_phi(h_{l+1}, t_{l+1}, s)

where J_l = I + dF_l/dh_l is the Jacobian of the residual block h_{l+1} = h_l + F_l(h_l). But computing J_l^T v requires backpropagating through the hidden block, which violates the no-hidden-BP constraint!

**Alternative 1: Forward-mode approximation**

Use finite differences along the forward dynamics:

    a_phi(h_l, t_l, s) ≈ E_xi [ (a_phi(h_{l+1} + sigma*xi, t_{l+1}, s) - a_phi(h_{l+1}, t_{l+1}, s)) / sigma * xi + a_phi(h_{l+1}, t_{l+1}, s) ]

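On reflection, this doesn't work either: finite differences along the forward dynamics yield Jacobian-vector products J_l v, whereas the vector consistency target needs the transposed product J_l^T v.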
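To make the obstacle concrete, the sketch below checks two things on a toy residual block: (a) J_l^T a_{l+1} really is the exact credit at layer l, and (b) finite differences of the forward map recover J_l v, which generally differs from J_l^T v. The block `F_l(h) = W2 tanh(W1 h)` and the quadratic stand-in loss are hypothetical, and the explicit Jacobian is used for illustration only, since the method itself is not allowed to compute it.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 6, 8
W1 = 0.5 * rng.standard_normal((k, d))
W2 = 0.5 * rng.standard_normal((d, k))
y = rng.standard_normal(d)

def block(h):
    """Toy residual block: h_{l+1} = h_l + F_l(h_l), F_l(h) = W2 tanh(W1 h)."""
    return h + W2 @ np.tanh(W1 @ h)

def jacobian(h):
    """Exact J_l = I + dF_l/dh_l (illustration only)."""
    g = 1.0 - np.tanh(W1 @ h) ** 2
    return np.eye(d) + W2 @ (g[:, None] * W1)

def loss_after_block(h_l):
    """Stand-in for the rest of the network: 0.5 * ||h_{l+1} - y||^2."""
    return 0.5 * np.sum((block(h_l) - y) ** 2)

h_l = rng.standard_normal(d)
J = jacobian(h_l)
step = 1e-5

# (a) Credit recursion: a_l = J_l^T a_{l+1}, checked against a central
#     finite-difference gradient of the downstream loss w.r.t. h_l.
a_next = block(h_l) - y
a_l = J.T @ a_next
grad_fd = np.array([
    (loss_after_block(h_l + step * e) - loss_after_block(h_l - step * e)) / (2 * step)
    for e in np.eye(d)
])

# (b) Finite differences of the forward map give J_l v, not J_l^T v.
v = rng.standard_normal(d)
jvp_fd = (block(h_l + step * v) - block(h_l - step * v)) / (2 * step)

print("recursion matches true gradient:", np.allclose(a_l, grad_fd, atol=1e-6))
print("finite difference matches J v:  ", np.allclose(jvp_fd, J @ v, atol=1e-6))
print("||J v - J^T v|| =", np.linalg.norm(J @ v - J.T @ v))
```

Forward perturbations are cheap and BP-free, but they probe the Jacobian from the wrong side, which is why the remaining alternatives target the scalar loss rather than the hidden-state map.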

**Alternative 2: Perturbation-based target**

Train a_phi to predict local loss sensitivity directly:

    L_pert = E_v [ (<a_phi(h_l, t_l, s), v> - (loss(h_l + eps*v) - loss(h_l))/eps )^2 ]

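Here v is a random direction in R^d, eps is a small step size, and loss(h_l + eps*v) is the output loss obtained by running the network forward from layer l with the perturbed hidden state. This is computationally expensive (it needs M extra forward passes per layer per sample) but provides a direct training signal for the credit vector. It doesn't require any Jacobian or hidden BP.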
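A minimal sketch of the perturbation target at a single hidden state, assuming a toy quadratic `loss_from_layer` in place of a real forward-from-layer pass (that function, the sizes, and the closed-form fit are all hypothetical). In the memo a_phi is a trained network; here the minimizer of L_pert at one point is found by least squares to show that it recovers the true gradient when M >= d.

```python
import numpy as np

rng = np.random.default_rng(0)
d, M, eps = 16, 64, 1e-4
B = rng.standard_normal((d, d))
A = B @ B.T / d                      # positive semi-definite curvature
b = rng.standard_normal(d)

def loss_from_layer(h):
    """Stand-in for a forward-from-layer pass ending in the output loss."""
    return 0.5 * h @ A @ h + b @ h

def perturbation_targets(h, V, eps):
    """(loss(h + eps*v) - loss(h)) / eps for each row v of V. No gradients used."""
    base = loss_from_layer(h)
    return np.array([(loss_from_layer(h + eps * v) - base) / eps for v in V])

def l_pert(a, V, targets):
    """Monte Carlo estimate of E_v[(<a, v> - target_v)^2]."""
    return np.mean((V @ a - targets) ** 2)

h = rng.standard_normal(d)
V = rng.standard_normal((M, d))      # M random probe directions
targets = perturbation_targets(h, V, eps)

# Closed-form minimizer of l_pert at this single point (well posed since M >= d).
a_star, *_ = np.linalg.lstsq(V, targets, rcond=None)

true_grad = A @ h + b                # used only to evaluate the sketch
cosine = a_star @ true_grad / (np.linalg.norm(a_star) * np.linalg.norm(true_grad))
print("L_pert at a = 0:     ", l_pert(np.zeros(d), V, targets))
print("L_pert at minimizer: ", l_pert(a_star, V, targets))
print("cosine to true grad: ", cosine)
```

With M < d the per-point minimizer is no longer unique, so in that regime the network has to pool information across samples and steps to pin down the full credit vector.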

**Alternative 3: Terminal matching + interpolation smoothness**

- Terminal: a_phi(h_L, 1, s) = delta_L (exact output-layer gradient)
- Smoothness: ||a_phi(h_{l+0.5}, ...) - 0.5*a_phi(h_l, ...) - 0.5*a_phi(h_{l+1}, ...)||^2

This is similar to FM auxiliary but applied to the credit vector directly.
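The two terms of Alternative 3 can be written as small loss functions. This is a sketch under one loud assumption: h_{l+0.5} is taken to be the linear midpoint of h_l and h_{l+1} (and likewise for t), which the memo does not specify. The function names and the toy credit fields are hypothetical.

```python
import numpy as np

def terminal_loss(a_phi, h_L, s, delta_L):
    """||a_phi(h_L, 1, s) - delta_L||^2, averaged over the batch."""
    return np.mean(np.sum((a_phi(h_L, 1.0, s) - delta_L) ** 2, axis=-1))

def smoothness_loss(a_phi, h_l, h_next, t_l, t_next, s):
    """Penalize deviation of the midpoint credit from the average of its neighbors."""
    h_mid = 0.5 * (h_l + h_next)     # assumption: linear interpolation for h_{l+0.5}
    t_mid = 0.5 * (t_l + t_next)
    resid = (a_phi(h_mid, t_mid, s)
             - 0.5 * a_phi(h_l, t_l, s)
             - 0.5 * a_phi(h_next, t_next, s))
    return np.mean(np.sum(resid ** 2, axis=-1))

# Toy credit fields: one affine in (h, t), one nonlinear in h.
a_affine = lambda h, t, s: 2.0 * h + t
a_square = lambda h, t, s: h ** 2

rng = np.random.default_rng(0)
h0 = rng.standard_normal((4, 8))
h1 = rng.standard_normal((4, 8))
smooth_affine = smoothness_loss(a_affine, h0, h1, 0.25, 0.5, None)
smooth_square = smoothness_loss(a_square, h0, h1, 0.25, 0.5, None)
print(smooth_affine, smooth_square)
```

The penalty vanishes for any credit field that is affine along the interpolation path, so on its own it constrains the shape of a_phi between layers but not its direction; the terminal term is what ties it to the loss.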

**Alternative 4: Soft contrastive target**

    a_phi(h_l, t_l, s) should point in the direction that makes
    V_target(h_l + eps*a_phi) < V_target(h_l - eps*a_phi)

Using the EMA target network:

    L_contrastive = -log sigmoid( (V_bar(h_l - eps*a_norm, t_l, s) - V_bar(h_l + eps*a_norm, t_l, s)) / tau )

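This trains a_phi to point "downhill" on the value landscape without needing the exact gradient. Here a_norm is presumably a_phi normalized to unit length. Note the sign: under this loss a_phi is a descent direction, the opposite sign from a_l = grad_h V and from the terminal target delta_L, so the convention would need to be reconciled before combining this loss with terminal matching.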
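A minimal sketch of the contrastive loss, assuming a toy target `v_bar(h) = 0.5 * ||h||^2` in place of the EMA network and assuming a_norm means a_phi scaled to unit length (both are assumptions; `eps` and `tau` values are hypothetical).

```python
import numpy as np

def v_bar(h):
    """Toy stand-in for the EMA target value network."""
    return 0.5 * np.sum(h * h, axis=-1)

def contrastive_loss(a, h, eps=0.1, tau=0.1):
    a_norm = a / (np.linalg.norm(a, axis=-1, keepdims=True) + 1e-12)
    margin = (v_bar(h - eps * a_norm) - v_bar(h + eps * a_norm)) / tau
    # -log sigmoid(x) == log(1 + exp(-x)), evaluated stably with logaddexp.
    return np.mean(np.logaddexp(0.0, -margin))

rng = np.random.default_rng(0)
h = rng.standard_normal((32, 16))

loss_downhill = contrastive_loss(-h, h)   # -grad of v_bar: a descent direction
loss_uphill = contrastive_loss(+h, h)     # +grad of v_bar: an ascent direction
print(loss_downhill, loss_uphill)
```

For this toy target the loss is below log 2 for the descent direction and above it for the ascent direction, which confirms that the objective as written rewards credit vectors pointing against the gradient of V_bar.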

### Recommended Approach: Alternative 2 + Terminal Matching

The perturbation-based target is the most principled because it directly measures what we want: local loss sensitivity. Combined with terminal matching:

    L_total = L_terminal + beta * L_perturbation

Where:
- L_terminal = ||a_phi(h_L, 1, s) - delta_L||^2
- L_perturbation = sum_l E_v [ (<a_phi(h_l, t_l, s), v> - (loss(h_l + eps*v) - loss(h_l))/eps)^2 ]

With M=4 directions per layer, this needs 4*L extra forward-from-layer passes per batch; the unperturbed loss(h_l) can be reused from the regular forward pass. For L=4, that's 16 passes: expensive but tractable.
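The pass count can be tallied with a short helper. The first function reproduces the memo's M*L figure; the second is a rougher estimate that rests on an assumption not stated in the memo, namely that layers are indexed 1..L and a pass from layer l only re-runs the L - l blocks after it (the output head is not counted). Both function names are hypothetical.

```python
def extra_passes(num_layers, num_directions):
    """Forward-from-layer passes added per batch: one per (layer, direction)."""
    return num_layers * num_directions

def extra_block_evals(num_layers, num_directions):
    """Block evaluations added per batch.

    Assumption: a pass starting at layer l re-runs only blocks l+1..L,
    so later layers are cheaper to probe than earlier ones.
    """
    remaining = sum(num_layers - l for l in range(1, num_layers + 1))
    return num_directions * remaining

for M in (4, 8):
    print(f"L=4, M={M}: {extra_passes(4, M)} passes, "
          f"{extra_block_evals(4, M)} block evaluations")
```

Under that indexing assumption, the partial passes cost less than the same number of full forward passes, since a full pass would evaluate all L blocks each time.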

## Does It Still Satisfy the No-Hidden-BP Anchor?

**Yes.** The perturbation-based target uses:
1. Forward-from-layer passes (no backprop through hidden layers)
2. Output-layer loss evaluation (no gradient extraction)
3. Terminal gradient matching (output-layer-local)

No hidden-layer BP gradients are used as training targets at any point.

## Minimal Test Setup

**Task**: Synthetic teacher-student, alpha=1.0, L=4, d=128 (same as Phase 1 best regime)

**Comparison**:
1. Current scalar credit bridge (V_phi -> grad_h V) — baseline
2. Direct vector credit field with perturbation target (M=4)
3. Direct vector credit field with perturbation target (M=8)

**Metrics**: Same as Phase 1 (Gamma, rho, nudge)

**Expected outcome**:
- Direct vector field should achieve higher rho than scalar V (it's directly trained to predict perturbation sensitivity)
- Gamma may or may not improve (depends on whether the perturbation target implicitly aligns with BP gradient)
- Training cost: ~4x per-step for M=4 due to extra forward passes

**Implementation effort**: ~100 lines of new code. Reuse existing StudentNet and diagnostics.

## Risk Assessment

**Upside**: Direct vector field avoids the fundamental curvature problem. It's trained on exactly the quantity we care about (local loss sensitivity).

**Downside**: The perturbation target is noisier than the scalar bridge consistency. With M=4 random directions, the variance of the gradient estimate is high in d=512 dimensions.

**Mitigation**: Start with d=128 synthetic. If it works, gradually increase d. The perturbation target quality scales as sqrt(M/d), so d=128 with M=4 gives signal/noise ~ 0.18, while d=512 with M=4 gives only ~ 0.09. CIFAR may need M=32 or more (signal/noise ~ 0.25 at d=512).

## Bottom Line

The scalar V approach has a fundamental curvature-vs-value disconnect. The direct vector field addresses this head-on. The recommended first step is a minimal test on the synthetic alpha=1.0, L=4 regime, comparing perturbation-trained vector field against the current scalar bridge. If it shows improved rho, scale up to CIFAR with higher M.