summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-02 16:19:14 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-02 16:19:14 -0500
commit9d1eaacab11510793e36fc9bba271fd7c330f6e4 (patch)
treefac2c1fc308a5479c48e89615abd69d25b5c6565 /experiments
parent6e280e59d492203ea7f7765a65949a6c256bf73a (diff)
Add SB and CB methods to cnn_baseline.py
State bridge: per-layer StateBridgeNet predicting h3 from flattened h_l Credit bridge: per-layer ValueNet with terminal + bridge consistency + DFA warmup Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/cnn_baseline.py314
1 files changed, 311 insertions, 3 deletions
diff --git a/experiments/cnn_baseline.py b/experiments/cnn_baseline.py
index f55b77b..af754c0 100644
--- a/experiments/cnn_baseline.py
+++ b/experiments/cnn_baseline.py
@@ -38,6 +38,8 @@ from torch.utils.data import DataLoader
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation
+from models.state_bridge import StateBridgeNet
+from models.value_net import ValueNet, create_ema_model, update_ema
import torchvision, torchvision.transforms as transforms
@@ -415,6 +417,264 @@ def train_ep(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01,
# ---------------------------------------------------------------------------
+# Training: State Bridge
+# ---------------------------------------------------------------------------
+
+def train_state_bridge(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01, lr_fb=3e-4):
+ """
+ State Bridge for CNN.
+
+ StateBridgeNet G_psi(h_l_flat, t_l, s) -> predicted h3 (256-dim terminal state).
+ s = e_T (10-dim softmax error).
+ Credit: a_l = grad_{h_l_flat} CE(out_head(SB(h_l_flat, t_l, s)), y).
+ Local update: <flat(F_l(h_{l-1})), a_l_norm>.
+ """
+ L = model.num_blocks # 4
+ C = 10
+ flat_dims = model.flat_dims # [32768, 16384, 8192, 256]
+ d_terminal = 256 # h3 is the terminal hidden state
+
+ # One SB net per layer (each takes flat_dim_l as input, outputs 256)
+ state_preds = nn.ModuleList([
+ StateBridgeNet(d_hidden=flat_dims[l], s_dim=C,
+ time_embed_dim=32, hidden_dim=256, num_layers=3).to(dev)
+ for l in range(L)
+ ])
+
+ block_opts = [optim.AdamW(model.blocks[l].parameters(), lr=lr, weight_decay=wd) for l in range(L)]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
+ state_opts = [optim.Adam(state_preds[l].parameters(), lr=lr_fb) for l in range(L)]
+ all_main_opts = block_opts + [head_opt]
+ schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in all_main_opts]
+
+ for ep in range(1, epochs + 1):
+ model.train()
+ for sp in state_preds:
+ sp.train()
+
+ for x, y in trl:
+ x, y = x.to(dev), y.to(dev)
+ B = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ probs = logits.softmax(-1)
+ e_T = probs.clone()
+ e_T[torch.arange(B), y] -= 1.0
+ s = e_T.detach()
+
+ h3_det = hiddens[3].detach() # (B, 256) terminal hidden state
+
+ # --- Train each state predictor: G_psi_l(h_l_flat, t_l, s) -> h3 ---
+ for l in range(L):
+ h_l_flat = flat(hiddens[l]).detach()
+ t_l = torch.full((B,), l / L, device=dev)
+ pred_h3 = state_preds[l](h_l_flat, t_l, s)
+ target = h3_det
+ target_norm = target.norm(dim=-1, keepdim=True).clamp(min=1.0)
+ state_loss = (((pred_h3 - target) / target_norm) ** 2).sum(dim=-1).mean()
+ state_opts[l].zero_grad()
+ state_loss.backward()
+ state_opts[l].step()
+
+ # --- Compute credits: a_l = grad_{h_l_flat} CE(out_head(SB(h_l_flat, t_l, s)), y) ---
+ credits = []
+ for l in range(L):
+ h_l_flat_req = flat(hiddens[l]).detach().requires_grad_(True)
+ t_l = torch.full((B,), l / L, device=dev)
+ pred_h3 = state_preds[l](h_l_flat_req, t_l, s)
+ pred_logits = model.out_head(pred_h3)
+ pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
+ a_l = torch.autograd.grad(pred_loss, h_l_flat_req, create_graph=False)[0]
+ credits.append(a_l.detach()) # (B, flat_dim_l)
+
+ # --- Train out_head with CE on detached h3 ---
+ ce_loss = F.cross_entropy(model.out_head(h3_det), y)
+ head_opt.zero_grad()
+ ce_loss.backward()
+ head_opt.step()
+
+ # --- Train each block with local surrogate <F_l(inp), a_l_norm> ---
+ inputs = [x, hiddens[0].detach(), hiddens[1].detach(), hiddens[2].detach()]
+ for l in range(L):
+ a_l = credits[l]
+ rms = (a_l ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_l_norm = a_l / rms
+
+ inp = inputs[l].detach()
+ if l == 3:
+ out_l = model.blocks[l](inp.flatten(1) if inp.dim() > 2 else inp)
+ else:
+ out_l = model.blocks[l](inp)
+
+ out_flat = flat(out_l)
+ local_loss = (out_flat * a_l_norm).sum(-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ for s in schedulers:
+ s.step()
+ if ep % 20 == 0:
+ print(f" Ep {ep}: acc={evaluate(model, tel, dev):.4f}", flush=True)
+
+ return model, state_preds
+
+
+# ---------------------------------------------------------------------------
+# Training: Credit Bridge
+# ---------------------------------------------------------------------------
+
+def train_credit_bridge(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01, lr_fb=3e-4,
+ lam=0.1, K=4, sigma_bridge=0.01, ema_momentum=0.99,
+ term_grad_weight=0.1):
+ """
+ Credit Bridge for CNN.
+
+ ValueNet V(h_l_flat, t_l, s) -> scalar.
+ Credit: a_l = grad_{h_l_flat} V.
+ Training: terminal boundary + bridge consistency.
+ DFA warmup for first 20% epochs.
+ """
+ L = model.num_blocks # 4
+ C = 10
+ flat_dims = model.flat_dims # [32768, 16384, 8192, 256]
+ warmup_epochs = max(1, epochs // 5)
+
+ # One ValueNet per layer (each takes flat_dim_l as h input)
+ value_nets = nn.ModuleList([
+ ValueNet(d_hidden=flat_dims[l], s_dim=C,
+ time_embed_dim=32, hidden_dim=256, num_layers=3).to(dev)
+ for l in range(L)
+ ])
+ value_nets_ema = nn.ModuleList([create_ema_model(value_nets[l]) for l in range(L)])
+
+ # DFA fallback matrices
+ Bs_fallback = [torch.randn(flat_dims[l], C, device=dev) / np.sqrt(C) for l in range(L)]
+
+ block_opts = [optim.AdamW(model.blocks[l].parameters(), lr=lr, weight_decay=wd) for l in range(L)]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
+ value_opts = [optim.Adam(value_nets[l].parameters(), lr=lr_fb) for l in range(L)]
+ all_main_opts = block_opts + [head_opt]
+ schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in all_main_opts]
+
+ print(f" [CB] Warmup phase: {warmup_epochs} epochs (DFA fallback + value net training)")
+
+ for ep in range(1, epochs + 1):
+ model.train()
+ for vn in value_nets:
+ vn.train()
+
+ if ep <= warmup_epochs:
+ credit_blend = 0.0
+ else:
+ credit_blend = min(1.0, (ep - warmup_epochs) / max(1, warmup_epochs))
+
+ for x, y in trl:
+ x, y = x.to(dev), y.to(dev)
+ B = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ probs = logits.softmax(-1)
+ e_T = probs.clone()
+ e_T[torch.arange(B), y] -= 1.0
+ s = e_T.detach()
+ true_loss = F.cross_entropy(logits, y, reduction='none').detach()
+
+ h3_det = flat(hiddens[3]).detach() # (B, 256) terminal
+
+ # --- Train value nets (always) ---
+ for l in range(L):
+ h_l_flat = flat(hiddens[l]).detach()
+ t_l = torch.full((B,), l / L, device=dev)
+ t_l_next = torch.full((B,), (l + 1) / L, device=dev)
+
+ # Terminal boundary loss (only for last layer)
+ if l == L - 1:
+ V_l = value_nets[l](h_l_flat, t_l, s)
+ loss_term = ((V_l - true_loss) ** 2).mean()
+ else:
+ # Bridge consistency: V_l ~ -lam * log E[exp(-V_{l+1}/lam)]
+ V_l = value_nets[l](h_l_flat, t_l, s)
+ h_next_flat = flat(hiddens[l + 1]).detach()
+ with torch.no_grad():
+ log_terms = []
+ for k in range(K):
+ noise = sigma_bridge * torch.randn_like(h_next_flat)
+ V_next = value_nets_ema[l + 1](h_next_flat + noise, t_l_next, s)
+ log_terms.append(-V_next / lam)
+ log_stack = torch.stack(log_terms, dim=-1)
+ V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K))
+ loss_term = ((V_l - V_target.detach()) ** 2).mean()
+
+ value_opts[l].zero_grad()
+ loss_term.backward()
+ torch.nn.utils.clip_grad_norm_(value_nets[l].parameters(), 1.0)
+ value_opts[l].step()
+ update_ema(value_nets[l], value_nets_ema[l], ema_momentum)
+
+ # --- Compute credits ---
+ cb_credits = []
+ for l in range(L):
+ h_l_flat_req = flat(hiddens[l]).detach().requires_grad_(True)
+ t_l = torch.full((B,), l / L, device=dev)
+ V_l = value_nets[l](h_l_flat_req, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_flat_req, create_graph=False)[0]
+ cb_credits.append(a_l.detach())
+
+ dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)]
+
+ credits = []
+ for l in range(L):
+ if credit_blend >= 1.0:
+ a = cb_credits[l]
+ elif credit_blend <= 0.0:
+ a = dfa_credits[l]
+ else:
+ cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a = (credit_blend * (cb_credits[l] / cb_rms)
+ + (1 - credit_blend) * (dfa_credits[l] / dfa_rms))
+ credits.append(a)
+
+ # --- Train out_head with CE on detached h3 ---
+ ce_loss = F.cross_entropy(model.out_head(h3_det), y)
+ head_opt.zero_grad()
+ ce_loss.backward()
+ head_opt.step()
+
+ # --- Train each block with local surrogate ---
+ inputs = [x, hiddens[0].detach(), hiddens[1].detach(), hiddens[2].detach()]
+ for l in range(L):
+ a_l = credits[l]
+ rms = (a_l ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_l_norm = a_l / rms
+
+ inp = inputs[l].detach()
+ if l == 3:
+ out_l = model.blocks[l](inp.flatten(1) if inp.dim() > 2 else inp)
+ else:
+ out_l = model.blocks[l](inp)
+
+ out_flat = flat(out_l)
+ local_loss = (out_flat * a_l_norm).sum(-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ for s in schedulers:
+ s.step()
+ if ep % 20 == 0:
+ phase = "warmup" if ep <= warmup_epochs else f"blend={credit_blend:.2f}"
+ print(f" Ep {ep} ({phase}): acc={evaluate(model, tel, dev):.4f}", flush=True)
+
+ return model, value_nets, value_nets_ema
+
+
+# ---------------------------------------------------------------------------
# Diagnostics
# ---------------------------------------------------------------------------
@@ -443,8 +703,15 @@ def compute_bp_grads(model, x, y):
return [g.detach() if g is not None else torch.zeros_like(h[i]) for i, g in enumerate(gs)], h
-def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nudge=0.05):
+def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nudge=0.05,
+ state_preds=None, value_nets=None):
model.eval()
+ if state_preds is not None:
+ for sp in state_preds:
+ sp.eval()
+ if value_nets is not None:
+ for vn in value_nets:
+ vn.eval()
L = model.num_blocks
# Grab one batch
@@ -452,6 +719,8 @@ def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nud
x, y = x.to(dev), y.to(dev)
break
+ B = x.size(0)
+
# BP gradients
bp_grads, h_bp = compute_bp_grads(model, x, y)
@@ -461,6 +730,26 @@ def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nud
_, h_free = model(x, return_hidden=True)
h_nudged = ep_nudged_phase_cnn(model, x, y, h_free, beta, T_nudge, alpha_nudge)
credits = [flat((h_nudged[l] - h_free[l]) / beta) for l in range(L)]
+ elif method in ('state_bridge', 'credit_bridge'):
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ probs = logits.softmax(-1)
+ e_T = probs.clone()
+ e_T[torch.arange(B), y] -= 1.0
+ s = e_T.detach()
+ credits = []
+ for l in range(L):
+ h_l_flat_req = flat(hiddens[l]).detach().requires_grad_(True)
+ t_l = torch.full((B,), l / L, device=dev)
+ if method == 'state_bridge':
+ pred_h3 = state_preds[l](h_l_flat_req, t_l, s)
+ pred_logits = model.out_head(pred_h3)
+ pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
+ a_l = torch.autograd.grad(pred_loss, h_l_flat_req, create_graph=False)[0]
+ else: # credit_bridge
+ V_l = value_nets[l](h_l_flat_req, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_flat_req, create_graph=False)[0]
+ credits.append(a_l.detach())
else:
# For BP and DFA, use BP grads directly (BP self-cosine = 1 by definition)
credits = [flat(bp_grads[l]) for l in range(L)]
@@ -526,7 +815,7 @@ def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nud
def main():
p = argparse.ArgumentParser(description='CNN baseline for CIFAR-10')
- p.add_argument('--method', type=str, required=True, choices=['bp', 'dfa', 'ep'])
+ p.add_argument('--method', type=str, required=True, choices=['bp', 'dfa', 'ep', 'state_bridge', 'credit_bridge'])
p.add_argument('--seed', type=int, required=True)
p.add_argument('--gpu', type=int, default=0)
p.add_argument('--output_dir', type=str, default='results/cnn_baseline')
@@ -537,6 +826,13 @@ def main():
p.add_argument('--beta', type=float, default=0.5)
p.add_argument('--T_nudge', type=int, default=20)
p.add_argument('--alpha_nudge', type=float, default=0.05)
+ # SB/CB hyperparameters
+ p.add_argument('--lr_fb', type=float, default=3e-4, help='Learning rate for SB/CB feedback nets')
+ p.add_argument('--lam', type=float, default=0.1, help='CB soft-min temperature')
+ p.add_argument('--K', type=int, default=4, help='CB bridge consistency samples')
+ p.add_argument('--sigma_bridge', type=float, default=0.01, help='CB bridge noise std')
+ p.add_argument('--ema_momentum', type=float, default=0.99, help='CB EMA momentum')
+ p.add_argument('--term_grad_weight', type=float, default=0.1, help='CB terminal gradient matching weight')
args = p.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
@@ -550,6 +846,9 @@ def main():
print(f"[{args.method} s={args.seed}] Training CNN on CIFAR-10 for {args.epochs} epochs...", flush=True)
+ state_preds = None
+ value_nets = None
+
if args.method == 'bp':
model = train_bp(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd)
elif args.method == 'dfa':
@@ -557,10 +856,19 @@ def main():
elif args.method == 'ep':
model = train_ep(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd,
beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge)
+ elif args.method == 'state_bridge':
+ model, state_preds = train_state_bridge(
+ model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd, lr_fb=args.lr_fb)
+ elif args.method == 'credit_bridge':
+ model, value_nets, _ = train_credit_bridge(
+ model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd, lr_fb=args.lr_fb,
+ lam=args.lam, K=args.K, sigma_bridge=args.sigma_bridge,
+ ema_momentum=args.ema_momentum, term_grad_weight=args.term_grad_weight)
acc = evaluate(model, tel, dev)
diag = compute_diagnostics(model, tel, dev, args.method,
- beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge)
+ beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge,
+ state_preds=state_preds, value_nets=value_nets)
# Save checkpoint
ckpt_path = os.path.join(args.output_dir, f'{args.method}_s{args.seed}.pt')