diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-24 01:44:34 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-24 01:44:34 -0500 |
| commit | c09ae4244033a7a2703f0c36279d598ca869a95f (patch) | |
| tree | ac09c1dc29d228865df5796b2a842ca0a42add88 /report_explore | |
| parent | 8f786597d1007f0ef6012f53c22958d9c4e9b81a (diff) | |
Add CIFAR deltaL test (failed) and pivot design memo
- CIFAR deltaL: s=grad_hL CE (dim=512) -> acc=17.2%, Gamma≈0
Confirms scalar value field has dimensionality bottleneck on CIFAR
- Pivot memo: direct vector credit field a_phi(h,t,s) -> R^d
Trained with perturbation-based target, avoids curvature problem
Still satisfies no hidden BP anchor constraint
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'report_explore')
| -rw-r--r-- | report_explore/MEMO_pivot_vector_field.md | 128 |
1 files changed, 128 insertions, 0 deletions
diff --git a/report_explore/MEMO_pivot_vector_field.md b/report_explore/MEMO_pivot_vector_field.md new file mode 100644 index 0000000..f73ac3d --- /dev/null +++ b/report_explore/MEMO_pivot_vector_field.md @@ -0,0 +1,128 @@ +# Pivot Design Memo: Direct Vector Credit Field + +## Why Scalar V May Have Value-Correct but Curvature-Wrong Gradients + +The current credit bridge learns a scalar function V_phi(h_l, t_l, s) and defines credit as a_l = grad_h V_phi. The bridge consistency loss constrains V's **values** at successive layers: + + V(h_l, t_l, s) ≈ soft-min_noise V_bar(h_{l+1} + noise, t_{l+1}, s) + +This gives correct V values but provides only **indirect** constraints on grad_h V. The gradient of V depends on its curvature with respect to h, which is a second-order property that the value-matching loss doesn't directly optimize. + +Terminal gradient matching addresses this at the boundary (l=L), but the information must propagate backward through the bridge consistency, which is a value-level (zeroth-order) constraint. Each layer of bridge consistency loses gradient information. + +**Evidence from experiments:** +- Without terminal gradient matching: V values converge but gradients are uninformative (cosine → 0.03) +- With terminal gradient matching: gradients improve but degrade with distance from terminal layer +- On CIFAR (d=512), the gradient information from 10-dim terminal code is insufficient +- deltaL (d-dim conditioning) helps on synthetic but fundamental issue remains + +The core problem: **optimizing a scalar function's values does not efficiently constrain its d-dimensional gradient field**, especially in high dimensions. + +## Direct Vector Credit Field: The Alternative + +Instead of V_phi: R^d x R x R^s -> R, learn the credit directly: + + a_phi(h_l, t_l, s) in R^d + +This outputs the credit vector without going through a scalar intermediate. The gradient computation disappears — the network output IS the credit. + +### Architecture + +``` +Input: [LN(h_l), time_embed(t_l), s] +-> MLP (same as current ValueNet architecture) +-> Linear(hidden_dim, d_hidden) # Output d-dimensional credit +``` + +### Training Objective + +The bridge consistency becomes a **vector** consistency: + + a_phi(h_l, t_l, s) ≈ J_l^T a_phi(h_{l+1}, t_{l+1}, s) + +where J_l = I + dF_l/dh_l is the block Jacobian. But computing J_l^T v requires hidden BP, which violates the constraint! + +**Alternative 1: Forward-mode approximation** + +Use finite differences along the forward dynamics: + + a_phi(h_l, t_l, s) ≈ E_xi [ (a_phi(h_{l+1} + sigma*xi, t_{l+1}, s) - a_phi(h_{l+1}, t_{l+1}, s)) / sigma * xi + a_phi(h_{l+1}, t_{l+1}, s) ] + +Wait — this doesn't work either because it would need J_l^T, not J_l. + +**Alternative 2: Perturbation-based target** + +Train a_phi to predict local loss sensitivity directly: + + L_pert = E_v [ (<a_phi(h_l, t_l, s), v> - (loss(h_l + eps*v) - loss(h_l))/eps )^2 ] + +This is computationally expensive (need M forward passes per layer per sample) but provides a direct training signal for the credit vector. It doesn't require any Jacobian or hidden BP. + +**Alternative 3: Terminal matching + interpolation smoothness** + +- Terminal: a_phi(h_L, 1, s) = delta_L (exact output-layer gradient) +- Smoothness: ||a_phi(h_{l+0.5}, ...) - 0.5*a_phi(h_l, ...) - 0.5*a_phi(h_{l+1}, ...)||^2 + +This is similar to FM auxiliary but applied to the credit vector directly. + +**Alternative 4: Soft contrastive target** + + a_phi(h_l, t_l, s) should point in the direction that makes + V_target(h_l + eps*a_phi) < V_target(h_l - eps*a_phi) + +Using the EMA target network: + + L_contrastive = -log sigmoid( (V_bar(h_l - eps*a_norm, t_l, s) - V_bar(h_l + eps*a_norm, t_l, s)) / tau ) + +This trains a_phi to point "downhill" on the value landscape without needing the exact gradient. + +### Recommended Approach: Alternative 2 + Terminal Matching + +The perturbation-based target is the most principled because it directly measures what we want: local loss sensitivity. Combined with terminal matching: + + L_total = L_terminal + beta * L_perturbation + +Where: +- L_terminal = ||a_phi(h_L, 1, s) - delta_L||^2 +- L_perturbation = sum_l E_v [ (<a_phi(h_l, t_l, s), v> - (loss(h_l + eps*v) - loss(h_l))/eps)^2 ] + +With M=4 directions per layer, this needs 4*L extra forward-from-layer passes per batch. For L=4, that's 16 passes — expensive but tractable. + +## Does It Still Satisfy No Hidden BP Anchor? + +**Yes.** The perturbation-based target uses: +1. Forward-from-layer passes (no backprop through hidden layers) +2. Output-layer loss evaluation (no gradient extraction) +3. Terminal gradient matching (output-layer-local) + +No hidden-layer BP gradients are used as training targets at any point. + +## Minimal Test Setup + +**Task**: Synthetic teacher-student, alpha=1.0, L=4, d=128 (same as Phase 1 best regime) + +**Comparison**: +1. Current scalar credit bridge (V_phi -> grad_h V) — baseline +2. Direct vector credit field with perturbation target (M=4) +3. Direct vector credit field with perturbation target (M=8) + +**Metrics**: Same as Phase 1 (Gamma, rho, nudge) + +**Expected outcome**: +- Direct vector field should achieve higher rho than scalar V (it's directly trained to predict perturbation sensitivity) +- Gamma may or may not improve (depends on whether the perturbation target implicitly aligns with BP gradient) +- Training cost: ~4x per-step for M=4 due to extra forward passes + +**Implementation effort**: ~100 lines of new code. Reuse existing StudentNet and diagnostics. + +## Risk Assessment + +**Upside**: Direct vector field avoids the fundamental curvature problem. It's trained on exactly the quantity we care about (local loss sensitivity). + +**Downside**: The perturbation target is noisier than the scalar bridge consistency. With M=4 random directions, the variance of the gradient estimate is high in d=512 dimensions. + +**Mitigation**: Start with d=128 synthetic. If it works, gradually increase d. The perturbation target quality scales as sqrt(M/d), so d=512 with M=4 gives signal/noise ~ 0.09. May need M=32+ for CIFAR. + +## Bottom Line + +The scalar V approach has a fundamental curvature-vs-value disconnect. The direct vector field addresses this head-on. The recommended first step is a minimal test on the synthetic alpha=1.0, L=4 regime, comparing perturbation-trained vector field against the current scalar bridge. If it shows improved rho, scale up to CIFAR with higher M. |
