diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-02 16:19:14 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-02 16:19:14 -0500 |
| commit | 9d1eaacab11510793e36fc9bba271fd7c330f6e4 (patch) | |
| tree | fac2c1fc308a5479c48e89615abd69d25b5c6565 /experiments | |
| parent | 6e280e59d492203ea7f7765a65949a6c256bf73a (diff) | |
Add SB and CB methods to cnn_baseline.py
State bridge: per-layer StateBridgeNet predicting h3 from flattened h_l
Credit bridge: per-layer ValueNet with terminal + bridge consistency + DFA warmup
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/cnn_baseline.py | 314 |
1 files changed, 311 insertions, 3 deletions
diff --git a/experiments/cnn_baseline.py b/experiments/cnn_baseline.py index f55b77b..af754c0 100644 --- a/experiments/cnn_baseline.py +++ b/experiments/cnn_baseline.py @@ -38,6 +38,8 @@ from torch.utils.data import DataLoader sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation +from models.state_bridge import StateBridgeNet +from models.value_net import ValueNet, create_ema_model, update_ema import torchvision, torchvision.transforms as transforms @@ -415,6 +417,264 @@ def train_ep(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01, # --------------------------------------------------------------------------- +# Training: State Bridge +# --------------------------------------------------------------------------- + +def train_state_bridge(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01, lr_fb=3e-4): + """ + State Bridge for CNN. + + StateBridgeNet G_psi(h_l_flat, t_l, s) -> predicted h3 (256-dim terminal state). + s = e_T (10-dim softmax error). + Credit: a_l = grad_{h_l_flat} CE(out_head(SB(h_l_flat, t_l, s)), y). + Local update: <flat(F_l(h_{l-1})), a_l_norm>. + """ + L = model.num_blocks # 4 + C = 10 + flat_dims = model.flat_dims # [32768, 16384, 8192, 256] + d_terminal = 256 # h3 is the terminal hidden state + + # One SB net per layer (each takes flat_dim_l as input, outputs 256) + state_preds = nn.ModuleList([ + StateBridgeNet(d_hidden=flat_dims[l], s_dim=C, + time_embed_dim=32, hidden_dim=256, num_layers=3).to(dev) + for l in range(L) + ]) + + block_opts = [optim.AdamW(model.blocks[l].parameters(), lr=lr, weight_decay=wd) for l in range(L)] + head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd) + state_opts = [optim.Adam(state_preds[l].parameters(), lr=lr_fb) for l in range(L)] + all_main_opts = block_opts + [head_opt] + schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in all_main_opts] + + for ep in range(1, epochs + 1): + model.train() + for sp in state_preds: + sp.train() + + for x, y in trl: + x, y = x.to(dev), y.to(dev) + B = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + probs = logits.softmax(-1) + e_T = probs.clone() + e_T[torch.arange(B), y] -= 1.0 + s = e_T.detach() + + h3_det = hiddens[3].detach() # (B, 256) terminal hidden state + + # --- Train each state predictor: G_psi_l(h_l_flat, t_l, s) -> h3 --- + for l in range(L): + h_l_flat = flat(hiddens[l]).detach() + t_l = torch.full((B,), l / L, device=dev) + pred_h3 = state_preds[l](h_l_flat, t_l, s) + target = h3_det + target_norm = target.norm(dim=-1, keepdim=True).clamp(min=1.0) + state_loss = (((pred_h3 - target) / target_norm) ** 2).sum(dim=-1).mean() + state_opts[l].zero_grad() + state_loss.backward() + state_opts[l].step() + + # --- Compute credits: a_l = grad_{h_l_flat} CE(out_head(SB(h_l_flat, t_l, s)), y) --- + credits = [] + for l in range(L): + h_l_flat_req = flat(hiddens[l]).detach().requires_grad_(True) + t_l = torch.full((B,), l / L, device=dev) + pred_h3 = state_preds[l](h_l_flat_req, t_l, s) + pred_logits = model.out_head(pred_h3) + pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') + a_l = torch.autograd.grad(pred_loss, h_l_flat_req, create_graph=False)[0] + credits.append(a_l.detach()) # (B, flat_dim_l) + + # --- Train out_head with CE on detached h3 --- + ce_loss = F.cross_entropy(model.out_head(h3_det), y) + head_opt.zero_grad() + ce_loss.backward() + head_opt.step() + + # --- Train each block with local surrogate <F_l(inp), a_l_norm> --- + inputs = [x, hiddens[0].detach(), hiddens[1].detach(), hiddens[2].detach()] + for l in range(L): + a_l = credits[l] + rms = (a_l ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + a_l_norm = a_l / rms + + inp = inputs[l].detach() + if l == 3: + out_l = model.blocks[l](inp.flatten(1) if inp.dim() > 2 else inp) + else: + out_l = model.blocks[l](inp) + + out_flat = flat(out_l) + local_loss = (out_flat * a_l_norm).sum(-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + + for s in schedulers: + s.step() + if ep % 20 == 0: + print(f" Ep {ep}: acc={evaluate(model, tel, dev):.4f}", flush=True) + + return model, state_preds + + +# --------------------------------------------------------------------------- +# Training: Credit Bridge +# --------------------------------------------------------------------------- + +def train_credit_bridge(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01, lr_fb=3e-4, + lam=0.1, K=4, sigma_bridge=0.01, ema_momentum=0.99, + term_grad_weight=0.1): + """ + Credit Bridge for CNN. + + ValueNet V(h_l_flat, t_l, s) -> scalar. + Credit: a_l = grad_{h_l_flat} V. + Training: terminal boundary + bridge consistency. + DFA warmup for first 20% epochs. + """ + L = model.num_blocks # 4 + C = 10 + flat_dims = model.flat_dims # [32768, 16384, 8192, 256] + warmup_epochs = max(1, epochs // 5) + + # One ValueNet per layer (each takes flat_dim_l as h input) + value_nets = nn.ModuleList([ + ValueNet(d_hidden=flat_dims[l], s_dim=C, + time_embed_dim=32, hidden_dim=256, num_layers=3).to(dev) + for l in range(L) + ]) + value_nets_ema = nn.ModuleList([create_ema_model(value_nets[l]) for l in range(L)]) + + # DFA fallback matrices + Bs_fallback = [torch.randn(flat_dims[l], C, device=dev) / np.sqrt(C) for l in range(L)] + + block_opts = [optim.AdamW(model.blocks[l].parameters(), lr=lr, weight_decay=wd) for l in range(L)] + head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd) + value_opts = [optim.Adam(value_nets[l].parameters(), lr=lr_fb) for l in range(L)] + all_main_opts = block_opts + [head_opt] + schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in all_main_opts] + + print(f" [CB] Warmup phase: {warmup_epochs} epochs (DFA fallback + value net training)") + + for ep in range(1, epochs + 1): + model.train() + for vn in value_nets: + vn.train() + + if ep <= warmup_epochs: + credit_blend = 0.0 + else: + credit_blend = min(1.0, (ep - warmup_epochs) / max(1, warmup_epochs)) + + for x, y in trl: + x, y = x.to(dev), y.to(dev) + B = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + probs = logits.softmax(-1) + e_T = probs.clone() + e_T[torch.arange(B), y] -= 1.0 + s = e_T.detach() + true_loss = F.cross_entropy(logits, y, reduction='none').detach() + + h3_det = flat(hiddens[3]).detach() # (B, 256) terminal + + # --- Train value nets (always) --- + for l in range(L): + h_l_flat = flat(hiddens[l]).detach() + t_l = torch.full((B,), l / L, device=dev) + t_l_next = torch.full((B,), (l + 1) / L, device=dev) + + # Terminal boundary loss (only for last layer) + if l == L - 1: + V_l = value_nets[l](h_l_flat, t_l, s) + loss_term = ((V_l - true_loss) ** 2).mean() + else: + # Bridge consistency: V_l ~ -lam * log E[exp(-V_{l+1}/lam)] + V_l = value_nets[l](h_l_flat, t_l, s) + h_next_flat = flat(hiddens[l + 1]).detach() + with torch.no_grad(): + log_terms = [] + for k in range(K): + noise = sigma_bridge * torch.randn_like(h_next_flat) + V_next = value_nets_ema[l + 1](h_next_flat + noise, t_l_next, s) + log_terms.append(-V_next / lam) + log_stack = torch.stack(log_terms, dim=-1) + V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K)) + loss_term = ((V_l - V_target.detach()) ** 2).mean() + + value_opts[l].zero_grad() + loss_term.backward() + torch.nn.utils.clip_grad_norm_(value_nets[l].parameters(), 1.0) + value_opts[l].step() + update_ema(value_nets[l], value_nets_ema[l], ema_momentum) + + # --- Compute credits --- + cb_credits = [] + for l in range(L): + h_l_flat_req = flat(hiddens[l]).detach().requires_grad_(True) + t_l = torch.full((B,), l / L, device=dev) + V_l = value_nets[l](h_l_flat_req, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_flat_req, create_graph=False)[0] + cb_credits.append(a_l.detach()) + + dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)] + + credits = [] + for l in range(L): + if credit_blend >= 1.0: + a = cb_credits[l] + elif credit_blend <= 0.0: + a = dfa_credits[l] + else: + cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a = (credit_blend * (cb_credits[l] / cb_rms) + + (1 - credit_blend) * (dfa_credits[l] / dfa_rms)) + credits.append(a) + + # --- Train out_head with CE on detached h3 --- + ce_loss = F.cross_entropy(model.out_head(h3_det), y) + head_opt.zero_grad() + ce_loss.backward() + head_opt.step() + + # --- Train each block with local surrogate --- + inputs = [x, hiddens[0].detach(), hiddens[1].detach(), hiddens[2].detach()] + for l in range(L): + a_l = credits[l] + rms = (a_l ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + a_l_norm = a_l / rms + + inp = inputs[l].detach() + if l == 3: + out_l = model.blocks[l](inp.flatten(1) if inp.dim() > 2 else inp) + else: + out_l = model.blocks[l](inp) + + out_flat = flat(out_l) + local_loss = (out_flat * a_l_norm).sum(-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + + for s in schedulers: + s.step() + if ep % 20 == 0: + phase = "warmup" if ep <= warmup_epochs else f"blend={credit_blend:.2f}" + print(f" Ep {ep} ({phase}): acc={evaluate(model, tel, dev):.4f}", flush=True) + + return model, value_nets, value_nets_ema + + +# --------------------------------------------------------------------------- # Diagnostics # --------------------------------------------------------------------------- @@ -443,8 +703,15 @@ def compute_bp_grads(model, x, y): return [g.detach() if g is not None else torch.zeros_like(h[i]) for i, g in enumerate(gs)], h -def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nudge=0.05): +def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nudge=0.05, + state_preds=None, value_nets=None): model.eval() + if state_preds is not None: + for sp in state_preds: + sp.eval() + if value_nets is not None: + for vn in value_nets: + vn.eval() L = model.num_blocks # Grab one batch @@ -452,6 +719,8 @@ def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nud x, y = x.to(dev), y.to(dev) break + B = x.size(0) + # BP gradients bp_grads, h_bp = compute_bp_grads(model, x, y) @@ -461,6 +730,26 @@ def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nud _, h_free = model(x, return_hidden=True) h_nudged = ep_nudged_phase_cnn(model, x, y, h_free, beta, T_nudge, alpha_nudge) credits = [flat((h_nudged[l] - h_free[l]) / beta) for l in range(L)] + elif method in ('state_bridge', 'credit_bridge'): + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + probs = logits.softmax(-1) + e_T = probs.clone() + e_T[torch.arange(B), y] -= 1.0 + s = e_T.detach() + credits = [] + for l in range(L): + h_l_flat_req = flat(hiddens[l]).detach().requires_grad_(True) + t_l = torch.full((B,), l / L, device=dev) + if method == 'state_bridge': + pred_h3 = state_preds[l](h_l_flat_req, t_l, s) + pred_logits = model.out_head(pred_h3) + pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') + a_l = torch.autograd.grad(pred_loss, h_l_flat_req, create_graph=False)[0] + else: # credit_bridge + V_l = value_nets[l](h_l_flat_req, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_flat_req, create_graph=False)[0] + credits.append(a_l.detach()) else: # For BP and DFA, use BP grads directly (BP self-cosine = 1 by definition) credits = [flat(bp_grads[l]) for l in range(L)] @@ -526,7 +815,7 @@ def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nud def main(): p = argparse.ArgumentParser(description='CNN baseline for CIFAR-10') - p.add_argument('--method', type=str, required=True, choices=['bp', 'dfa', 'ep']) + p.add_argument('--method', type=str, required=True, choices=['bp', 'dfa', 'ep', 'state_bridge', 'credit_bridge']) p.add_argument('--seed', type=int, required=True) p.add_argument('--gpu', type=int, default=0) p.add_argument('--output_dir', type=str, default='results/cnn_baseline') @@ -537,6 +826,13 @@ def main(): p.add_argument('--beta', type=float, default=0.5) p.add_argument('--T_nudge', type=int, default=20) p.add_argument('--alpha_nudge', type=float, default=0.05) + # SB/CB hyperparameters + p.add_argument('--lr_fb', type=float, default=3e-4, help='Learning rate for SB/CB feedback nets') + p.add_argument('--lam', type=float, default=0.1, help='CB soft-min temperature') + p.add_argument('--K', type=int, default=4, help='CB bridge consistency samples') + p.add_argument('--sigma_bridge', type=float, default=0.01, help='CB bridge noise std') + p.add_argument('--ema_momentum', type=float, default=0.99, help='CB EMA momentum') + p.add_argument('--term_grad_weight', type=float, default=0.1, help='CB terminal gradient matching weight') args = p.parse_args() os.makedirs(args.output_dir, exist_ok=True) @@ -550,6 +846,9 @@ def main(): print(f"[{args.method} s={args.seed}] Training CNN on CIFAR-10 for {args.epochs} epochs...", flush=True) + state_preds = None + value_nets = None + if args.method == 'bp': model = train_bp(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd) elif args.method == 'dfa': @@ -557,10 +856,19 @@ def main(): elif args.method == 'ep': model = train_ep(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd, beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge) + elif args.method == 'state_bridge': + model, state_preds = train_state_bridge( + model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd, lr_fb=args.lr_fb) + elif args.method == 'credit_bridge': + model, value_nets, _ = train_credit_bridge( + model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd, lr_fb=args.lr_fb, + lam=args.lam, K=args.K, sigma_bridge=args.sigma_bridge, + ema_momentum=args.ema_momentum, term_grad_weight=args.term_grad_weight) acc = evaluate(model, tel, dev) diag = compute_diagnostics(model, tel, dev, args.method, - beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge) + beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge, + state_preds=state_preds, value_nets=value_nets) # Save checkpoint ckpt_path = os.path.join(args.output_dir, f'{args.method}_s{args.seed}.pt') |
