"""
Phase C: Direct Vector Credit Field Pilot.

Compare the scalar credit bridge against a direct vector credit field on the
synthetic best regime (DFA is included as a reference baseline).

Vector field: a_phi(h_l, t_l, s) -> R^d, trained with symmetric
finite-difference directional targets (no hidden BP anchor).

Loss (per layer l < L, averaged over the batch and over layers):
    L_proj  = (1/M) sum_j ( <a_phi(h_l, t_l, s), v_j> - g_j )^2
        where v_j are random unit-norm directions and
        g_j = [ loss(h_l + eps*v_j) - loss(h_l - eps*v_j) ] / (2*eps)
    L_term  = || a_phi(h_L, 1, s) - delta_L ||^2
    L_total = L_term + beta * L_proj
"""

import os
import sys
import json
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import copy
import time

# Make the parent of this file's directory importable so that the project
# packages `models` and `metrics` resolve.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
from metrics.credit_metrics import (
    cosine_similarity_batch, perturbation_correlation, nudging_test, offline_bp_cosine
)


# =============================================================================
# Synthetic teacher-student (from synth_nonlinearity_ladder.py)
# =============================================================================

class TeacherNet(nn.Module):
    """Fixed residual teacher with controllable nonlinearity.

    Each block computes ``h <- h + phi(h @ W.T)`` and the logits are
    ``h @ U.T``. ``phi`` interpolates between the identity (``alpha=0``) and
    ``tanh`` (``alpha=1``). All weights are frozen (``requires_grad=False``).
    """

    def __init__(self, d_hidden, num_classes, num_blocks, alpha=1.0, seed=0):
        super().__init__()
        self.d_hidden = d_hidden
        self.num_blocks = num_blocks
        self.alpha = alpha
        # Private generator: the teacher depends only on `seed` and does not
        # consume the global torch RNG.
        rng = torch.Generator().manual_seed(seed)
        self.Ws = nn.ParameterList()
        for _ in range(num_blocks):
            W = torch.randn(d_hidden, d_hidden, generator=rng) * 0.3 / (d_hidden ** 0.5)
            U, S, Vh = torch.linalg.svd(W, full_matrices=False)
            # Cap every singular value of W at 0.3.
            S_clamped = S.clamp(max=0.3)
            W = U @ torch.diag(S_clamped) @ Vh
            self.Ws.append(nn.Parameter(W, requires_grad=False))
        self.U = nn.Parameter(
            torch.randn(num_classes, d_hidden, generator=rng) / (d_hidden ** 0.5),
            requires_grad=False)

    def phi(self, z):
        """Blend of identity and tanh weighted by ``alpha``."""
        return (1 - self.alpha) * z + self.alpha * torch.tanh(z)

    def forward(self, x):
        h = x
        for W in self.Ws:
            h = h + self.phi(h @ W.T)
        return h @ self.U.T


class StudentBlock(nn.Module):
    """Student residual branch: ``f(h) = W phi(LayerNorm(h))`` (pre-LN, no bias)."""

    def __init__(self, d_hidden, alpha=1.0):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.w = nn.Linear(d_hidden, d_hidden, bias=False)
        # Small init so each residual branch starts close to zero.
        nn.init.normal_(self.w.weight, std=0.01)
        self.alpha = alpha

    def phi(self, z):
        """Blend of identity and tanh weighted by ``alpha``."""
        return (1 - self.alpha) * z + self.alpha * torch.tanh(z)

    def forward(self, h):
        return self.w(self.phi(self.ln(h)))


class StudentNet(nn.Module):
    """Residual student: ``h_{l+1} = h_l + block_l(h_l)``, then a linear head."""

    def __init__(self, d_hidden, num_classes, num_blocks, alpha=1.0):
        super().__init__()
        self.blocks = nn.ModuleList([StudentBlock(d_hidden, alpha) for _ in range(num_blocks)])
        self.out_head = nn.Linear(d_hidden, num_classes)
        self.d_hidden = d_hidden
        self.num_blocks = num_blocks

    def forward(self, x, return_hidden=False):
        """Return logits, or ``(logits, hiddens)`` when ``return_hidden``.

        ``hiddens`` holds ``num_blocks + 1`` tensors: the input ``x`` followed
        by the residual stream after each block.
        """
        h = x
        hiddens = [h] if return_hidden else None
        for block in self.blocks:
            f = block(h)
            h = h + f
            if return_hidden:
                hiddens.append(h)
        logits = self.out_head(h)
        if return_hidden:
            return logits, hiddens
        return logits

    def forward_from_layer(self, h, start_layer):
        """Run blocks ``start_layer .. num_blocks-1`` on ``h`` and return logits."""
        for i in range(start_layer, self.num_blocks):
            f = self.blocks[i](h)
            h = h + f
        return self.out_head(h)


# =============================================================================
# Vector Credit Field Network
# =============================================================================

class VectorCreditNet(nn.Module):
    """
    Direct vector credit field: a_phi(h_l, t_l, s) -> R^d.

    The input is LayerNorm(h), a sinusoidal embedding of t
    (SinusoidalTimeEmbed) and the conditioning vector s, concatenated. They
    pass through ``num_layers`` Linear+GELU layers and a final Linear to
    ``d_hidden``, so the d-dimensional credit vector is output directly.
    """

    def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
        input_dim = d_hidden + time_embed_dim + s_dim
        layers = []
        for i in range(num_layers):
            in_d = input_dim if i == 0 else hidden_dim
            layers.append(nn.Linear(in_d, hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, d_hidden))
        self.net = nn.Sequential(*layers)

    def forward(self, h, t, s):
        """Returns credit vector (batch, d_hidden)."""
        h_normed = self.ln(h)
        t_emb = self.time_embed(t)
        inp = torch.cat([h_normed, t_emb, s], dim=-1)
        return self.net(inp)


# =============================================================================
# Training functions
# =============================================================================

def generate_batch(teacher, d_hidden, num_classes, batch_size, device):
    """Sample ``x ~ N(0, I)`` of shape (batch_size, d_hidden) and label it.

    The labels are the teacher's argmax. ``num_classes`` is not used here; it
    is kept for interface compatibility.
    """
    x = torch.randn(batch_size, d_hidden, device=device)
    with torch.no_grad():
        teacher_logits = teacher(x)
        y = teacher_logits.argmax(dim=-1)
    return x, y


def _credit_blend_weight(epoch, warmup_epochs):
    """Mixing weight of learned credit vs. DFA credit for a 1-based ``epoch``.

    The weight is 0 (pure DFA) for the first ``warmup_epochs`` epochs. It then
    ramps linearly to 1 over the next ``warmup_epochs`` epochs.
    """
    if epoch <= warmup_epochs:
        return 0.0
    return min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))


def _make_dfa_feedback(d, num_classes, L, device):
    """Random fixed DFA feedback matrices, one (d, num_classes) matrix per block."""
    return [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)]


def _make_student_optimizers(model, lr):
    """One AdamW per block (each block is trained by its own local loss) plus one for the head."""
    block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks]
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01)
    return block_opts, head_opt


def _output_error(logits, y):
    """Return ``softmax(logits) - onehot(y)``, i.e. the per-sample dCE/dlogits.

    Callers invoke this under ``torch.no_grad()``; the subtraction is in place.
    """
    e_T = logits.softmax(dim=-1)
    # Keep the row index on the same device as the tensor being indexed.
    e_T[torch.arange(logits.size(0), device=logits.device), y] -= 1
    return e_T


def _output_layer_grad(model, hL, y):
    """Per-sample gradient of CE w.r.t. ``h_L`` through the output head only.

    ``reduction='sum'`` makes each row that sample's own gradient, not scaled
    by 1/batch. Only the input gradient is requested, so the head's parameter
    ``.grad`` is left untouched.
    """
    hL_req = hL.detach().clone().requires_grad_(True)
    ce = F.cross_entropy(model.out_head(hL_req), y, reduction='sum')
    return torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()


def _blend_credits(learned, dfa, blend):
    """Mix per-layer learned credits with DFA credits using weight ``blend``.

    At ``blend >= 1`` or ``blend <= 0`` the raw credits of one source are
    returned. In between, each source is RMS-normalised per sample before the
    convex combination, so neither source dominates by scale alone.
    """
    if blend >= 1.0:
        return list(learned)
    if blend <= 0.0:
        return list(dfa)
    mixed = []
    for a_learned, a_dfa in zip(learned, dfa):
        learned_rms = (a_learned ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
        dfa_rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
        mixed.append(blend * a_learned / learned_rms + (1 - blend) * a_dfa / dfa_rms)
    return mixed


def _update_head(model, head_opt, hL_det, y):
    """One optimizer step on the output head using the exact local CE loss."""
    loss_out = F.cross_entropy(model.out_head(hL_det), y)
    head_opt.zero_grad()
    loss_out.backward()
    head_opt.step()


def _update_blocks(model, block_opts, hiddens, credits):
    """Take one local step per block: minimise ``<f_l(h_l), a_l / rms(a_l)>``.

    Each block sees only its detached input ``h_l`` and its credit ``a_l``.
    The credit is RMS-normalised per sample, so only its direction matters.
    Gradients are clipped to norm 1 per block.

    NOTE(review): block l's output feeds h_{l+1}, so the exact BP signal for
    it would be dL/dh_{l+1}. Callers pass the credit evaluated at h_l (t=l/L).
    The two are close for small residual branches; confirm this offset is
    intended.
    """
    for l, (block, opt) in enumerate(zip(model.blocks, block_opts)):
        h_l = hiddens[l].detach()
        a = credits[l]
        rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
        local_loss = (block(h_l) * (a / rms)).sum(dim=-1).mean()
        opt.zero_grad()
        local_loss.backward()
        torch.nn.utils.clip_grad_norm_(block.parameters(), 1.0)
        opt.step()


def train_scalar_cb(model, teacher, device, args):
    """Train the student with the scalar credit bridge (baseline method).

    A scalar value net V(h, t, s) is fitted each step with three losses:
      * terminal value:    V(h_L, 1, s) ~ per-sample CE;
      * terminal gradient: grad_h V(h_L, 1, s) ~ dCE/dh_L through the head
        (weighted by ``args.term_grad_weight``);
      * bridge consistency: V(h_l, l/L, s) ~ a soft-min
        ``-lam * log mean_k exp(-V_ema(h_{l+1} + noise_k) / lam)`` over K noisy
        evaluations of the EMA copy (create_ema_model / update_ema).
    The block credit is ``grad_h V(h_l, l/L, s)``, blended with DFA credit
    following the warmup schedule of `_credit_blend_weight`.

    Args:
        model: StudentNet to train (updated in place).
        teacher: label generator (see `generate_batch`).
        device: torch device.
        args: namespace; reads num_classes, lr, lr_fb, epochs, warmup_ratio,
            steps_per_epoch, batch_size, K, lam, sigma_bridge,
            term_grad_weight, ema_momentum.

    Returns:
        The trained ValueNet.
    """
    d = model.d_hidden
    L = model.num_blocks
    num_classes = args.num_classes

    value_net = ValueNet(d_hidden=d, s_dim=num_classes, time_embed_dim=32,
                         hidden_dim=256, num_layers=3).to(device)
    value_net_ema = create_ema_model(value_net)
    Bs = _make_dfa_feedback(d, num_classes, L, device)
    block_opts, head_opt = _make_student_optimizers(model, args.lr)
    value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb)
    warmup_epochs = max(1, int(args.epochs * args.warmup_ratio))

    for epoch in range(1, args.epochs + 1):
        model.train()
        value_net.train()
        credit_blend = _credit_blend_weight(epoch, warmup_epochs)
        total_loss, correct, total = 0, 0, 0

        for _ in range(args.steps_per_epoch):
            x, y = generate_batch(teacher, d, num_classes, args.batch_size, device)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
                e_T = _output_error(logits, y)
                s = e_T.detach()
                true_loss = F.cross_entropy(logits, y, reduction='none').detach()
            hL_det = hiddens[-1].detach()

            # --- Value-net update -------------------------------------------
            # Terminal value. NOTE(review): assumes ValueNet returns shape
            # (batch,); a (batch, 1) output would broadcast against true_loss.
            t_L = torch.ones(batch, device=device)
            V_term = value_net(hL_det, t_L, s)
            loss_term = ((V_term - true_loss) ** 2).mean()

            # Terminal gradient matching (create_graph so the loss can be
            # back-propagated into value_net's parameters).
            hL_req = hL_det.clone().requires_grad_(True)
            V_at_L = value_net(hL_req, t_L, s)
            grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
            a_L_exact = _output_layer_grad(model, hL_det, y)
            loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()

            # Bridge consistency against soft-min EMA targets at the next layer.
            loss_bridge = 0.0
            for l in range(L):
                h_l_det = hiddens[l].detach()
                t_l = torch.full((batch,), l / L, device=device)
                t_next = torch.full((batch,), (l + 1) / L, device=device)
                V_l = value_net(h_l_det, t_l, s)
                with torch.no_grad():
                    h_next = hiddens[l + 1].detach()
                    log_terms = []
                    for _k in range(args.K):
                        noise = args.sigma_bridge * torch.randn_like(h_next)
                        V_next = value_net_ema(h_next + noise, t_next, s)
                        log_terms.append(-V_next / args.lam)
                    log_stack = torch.stack(log_terms, dim=-1)
                    V_target = -args.lam * (torch.logsumexp(log_stack, dim=-1) - np.log(args.K))
                loss_bridge += ((V_l - V_target.detach()) ** 2).mean()
            loss_bridge /= L

            vloss = loss_term + loss_bridge + args.term_grad_weight * loss_tgrad
            value_opt.zero_grad()
            vloss.backward()
            torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
            value_opt.step()
            update_ema(value_net, value_net_ema, args.ema_momentum)

            # --- Credits: grad_h V at each block input ----------------------
            cb_credits = []
            for l in range(L):
                h_l_req = hiddens[l].detach().requires_grad_(True)
                t_l = torch.full((batch,), l / L, device=device)
                V_l = value_net(h_l_req, t_l, s)
                a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0]
                cb_credits.append(a_l.detach())
            dfa_credits = [(e_T @ B.T).detach() for B in Bs]
            credits = _blend_credits(cb_credits, dfa_credits, credit_blend)

            # --- Student updates --------------------------------------------
            _update_head(model, head_opt, hL_det, y)
            _update_blocks(model, block_opts, hiddens, credits)

            # Statistics use the pre-update forward pass.
            total_loss += loss_val.item() * batch
            correct += (logits.argmax(1) == y).sum().item()
            total += batch

        if epoch % 10 == 0 or epoch == 1:
            acc = correct / total
            print(f" [scalar_cb] Ep {epoch}: loss={total_loss/total:.4f}, acc={acc:.4f}")

    return value_net


def train_vector_field(model, teacher, device, args, M=4):
    """
    Train the student with a direct vector credit field and perturbation-based
    targets. There is no hidden BP anchor.

    The field a_phi(h_l, t_l, s) is fitted to:
      * a terminal target: a_phi(h_L, 1, s) ~ delta_L = dCE/dh_L through the
        output head only (per sample);
      * directional targets at every block input h_l: for M random unit
        directions v_j, <a_phi(h_l, l/L, s), v_j> ~ g_j. Here g_j is the
        symmetric finite difference of the per-sample CE obtained by
        re-running the student from layer l at h_l +/- eps * v_j.
    The fitting loss is L_term + args.pert_beta * L_proj. Block updates use the
    same DFA -> learned credit schedule as `train_scalar_cb`.

    Args:
        model: StudentNet to train (updated in place).
        teacher: label generator (see `generate_batch`).
        device: torch device.
        args: namespace; reads num_classes, lr, lr_fb, epochs, warmup_ratio,
            steps_per_epoch, batch_size, pert_eps, pert_beta.
        M: number of random directions per layer per step (must be >= 1).

    Returns:
        The trained VectorCreditNet.

    Raises:
        ValueError: if ``M < 1``.
    """
    if M < 1:
        raise ValueError(f"M must be >= 1, got {M}")

    d = model.d_hidden
    L = model.num_blocks
    num_classes = args.num_classes

    vector_net = VectorCreditNet(d_hidden=d, s_dim=num_classes, time_embed_dim=32,
                                 hidden_dim=256, num_layers=3).to(device)
    Bs = _make_dfa_feedback(d, num_classes, L, device)
    block_opts, head_opt = _make_student_optimizers(model, args.lr)
    vec_opt = optim.Adam(vector_net.parameters(), lr=args.lr_fb)
    warmup_epochs = max(1, int(args.epochs * args.warmup_ratio))
    eps = args.pert_eps
    beta = args.pert_beta

    for epoch in range(1, args.epochs + 1):
        model.train()
        vector_net.train()
        credit_blend = _credit_blend_weight(epoch, warmup_epochs)
        total_loss, correct, total = 0, 0, 0
        total_vloss = 0

        for _ in range(args.steps_per_epoch):
            x, y = generate_batch(teacher, d, num_classes, args.batch_size, device)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
                e_T = _output_error(logits, y)
                s = e_T.detach()
            hL_det = hiddens[-1].detach()

            # --- Vector-field update ----------------------------------------
            # Terminal matching: a_phi(h_L, 1, s) = delta_L.
            t_L = torch.ones(batch, device=device)
            a_terminal = vector_net(hL_det, t_L, s)
            delta_L = _output_layer_grad(model, hL_det, y)
            loss_term = ((a_terminal - delta_L) ** 2).sum(dim=-1).mean()

            # Directional targets at every block input (symmetric finite
            # differences). NOTE(review): in float32 with a small eps the
            # difference loss_plus - loss_minus can be dominated by rounding;
            # worth checking g_j against the BP directional derivative.
            loss_proj = torch.tensor(0.0, device=device)
            for l in range(L):
                h_l_det = hiddens[l].detach()
                t_l = torch.full((batch,), l / L, device=device)
                a_l = vector_net(h_l_det, t_l, s)

                layer_proj_loss = 0.0
                for _j in range(M):
                    v = torch.randn_like(h_l_det)
                    v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)  # unit-norm rows
                    with torch.no_grad():
                        logits_plus = model.forward_from_layer(h_l_det + eps * v, l)
                        loss_plus = F.cross_entropy(logits_plus, y, reduction='none')
                        logits_minus = model.forward_from_layer(h_l_det - eps * v, l)
                        loss_minus = F.cross_entropy(logits_minus, y, reduction='none')
                        g_j = (loss_plus - loss_minus) / (2 * eps)  # (batch,)
                    pred_j = (a_l * v).sum(dim=-1)  # (batch,) predicted <a_l, v>
                    layer_proj_loss = layer_proj_loss + ((pred_j - g_j.detach()) ** 2).mean()
                loss_proj = loss_proj + layer_proj_loss / M
            loss_proj = loss_proj / L

            vec_loss = loss_term + beta * loss_proj
            vec_opt.zero_grad()
            vec_loss.backward()
            torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0)
            vec_opt.step()
            total_vloss += vec_loss.item() * batch

            # --- Credits: direct field output at each block input -----------
            with torch.no_grad():
                vec_credits = [
                    vector_net(hiddens[l].detach(),
                               torch.full((batch,), l / L, device=device), s).detach()
                    for l in range(L)
                ]
            dfa_credits = [(e_T @ B.T).detach() for B in Bs]
            credits = _blend_credits(vec_credits, dfa_credits, credit_blend)

            # --- Student updates --------------------------------------------
            _update_head(model, head_opt, hL_det, y)
            _update_blocks(model, block_opts, hiddens, credits)

            # Statistics use the pre-update forward pass.
            total_loss += loss_val.item() * batch
            correct += (logits.argmax(1) == y).sum().item()
            total += batch

        if epoch % 10 == 0 or epoch == 1:
            acc = correct / total
            print(f" [vec_M={M}] Ep {epoch}: loss={total_loss/total:.4f}, acc={acc:.4f}, "
                  f"vloss={total_vloss/total:.6f}")

    return vector_net


def train_dfa(model, teacher, device, args):
    """Train the student with Direct Feedback Alignment (reference baseline).

    Block l receives the credit ``e_T @ B_l.T``, where
    ``e_T = softmax(logits) - onehot(y)`` and ``B_l`` is a fixed random matrix.
    The head is trained on its exact local CE loss.

    Returns:
        The list of feedback matrices ``Bs`` (needed by the diagnostics).
    """
    d = model.d_hidden
    L = model.num_blocks
    num_classes = args.num_classes
    Bs = _make_dfa_feedback(d, num_classes, L, device)
    block_opts, head_opt = _make_student_optimizers(model, args.lr)

    for epoch in range(1, args.epochs + 1):
        model.train()
        total_loss, correct, total = 0, 0, 0

        for _ in range(args.steps_per_epoch):
            x, y = generate_batch(teacher, d, num_classes, args.batch_size, device)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
                e_T = _output_error(logits, y)
            hL_det = hiddens[-1].detach()

            _update_head(model, head_opt, hL_det, y)
            credits = [(e_T @ B.T).detach() for B in Bs]
            _update_blocks(model, block_opts, hiddens, credits)

            total_loss += loss_val.item() * batch
            correct += (logits.argmax(1) == y).sum().item()
            total += batch

        if epoch % 10 == 0 or epoch == 1:
            print(f" [DFA] Ep {epoch}: loss={total_loss/total:.4f}, acc={correct/total:.4f}")

    return Bs
# =============================================================================
# Diagnostics
# =============================================================================

def compute_diagnostics(model, teacher, device, method_name, args,
                        value_net=None, vector_net=None, dfa_Bs=None):
    """Compute per-layer credit diagnostics (Gamma, rho, nudging).

    A fresh batch of 512 teacher-labelled samples is drawn. For each block
    input h_l (l = 0..L-1), the method's credit a_l is computed and scored:
      * 'bp_cosine'        -- cosine_similarity_batch(a_l, exact BP dL/dh_l)
                              (Gamma);
      * 'perturbation_rho' -- perturbation_correlation of a_l with the loss
                              obtained by running the student from layer l
                              (epsilon=1e-3, M=32);
      * 'nudging'          -- nudging_test with eta=0.003.

    Args:
        method_name: 'dfa', 'scalar_cb', or any name starting with 'vector'.
        value_net: trained ValueNet (for 'scalar_cb').
        vector_net: trained VectorCreditNet (for 'vector*').
        dfa_Bs: DFA feedback matrices (for 'dfa').

    Returns:
        dict mapping each metric name above to a list of L floats.

    Raises:
        ValueError: if method_name is not recognised (raised at the first layer).
    """
    model.eval()
    if value_net is not None:
        value_net.eval()
    if vector_net is not None:
        vector_net.eval()
    d = model.d_hidden
    L = model.num_blocks
    num_classes = args.num_classes

    x, y = generate_batch(teacher, d, num_classes, 512, device)
    batch = x.size(0)

    # Exact BP gradients dL/dh_l (evaluation only, never used for training).
    # Only input gradients are requested, so parameter .grad stays untouched.
    hiddens_bp = [x.detach().requires_grad_(True)]
    for block in model.blocks:
        hiddens_bp.append(hiddens_bp[-1] + block(hiddens_bp[-1]))
    loss_bp = F.cross_entropy(model.out_head(hiddens_bp[-1]), y)
    grads = torch.autograd.grad(loss_bp, hiddens_bp, retain_graph=False)
    bp_grads = [g.detach().clone() for g in grads]

    with torch.no_grad():
        logits, hiddens = model(x, return_hidden=True)
        e_T = logits.softmax(dim=-1)
        e_T[torch.arange(batch, device=e_T.device), y] -= 1
        s = e_T.detach()

    def make_fwd_fn(start_l):
        """Per-sample CE as a function of the residual stream at layer start_l."""
        def fwd_fn(h):
            with torch.no_grad():
                return F.cross_entropy(model.forward_from_layer(h, start_l), y,
                                       reduction='none')
        return fwd_fn

    results = {'bp_cosine': [], 'perturbation_rho': [], 'nudging': []}
    for l in range(L):
        h_l = hiddens[l].detach()
        t_l = torch.full((batch,), l / L, device=device)

        if method_name == 'dfa':
            a_l = (s @ dfa_Bs[l].T).detach()
        elif method_name == 'scalar_cb':
            h_l_req = h_l.clone().requires_grad_(True)
            V_l = value_net(h_l_req, t_l, s)
            a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach()
        elif method_name.startswith('vector'):
            with torch.no_grad():
                a_l = vector_net(h_l, t_l, s)
        else:
            raise ValueError(f"Unknown: {method_name}")

        results['bp_cosine'].append(float(cosine_similarity_batch(a_l, bp_grads[l])))

        fwd_fn = make_fwd_fn(l)
        rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=32)
        results['perturbation_rho'].append(float(rho))
        nud = nudging_test(h_l, a_l, fwd_fn, eta=0.003)
        results['nudging'].append(float(nud))

    return results


# =============================================================================
# Main
# =============================================================================

def _summarize(method, L, seed, diag, **extra):
    """Collapse per-layer diagnostics into one JSON-serialisable result record.

    ``extra`` keys (e.g. ``M``) are placed right after ``seed``.
    """
    record = {'method': method, 'L': L, 'seed': seed}
    record.update(extra)
    record.update({
        'mean_gamma': float(np.mean(diag['bp_cosine'])),
        'mean_rho': float(np.mean(diag['perturbation_rho'])),
        'mean_nudge': float(np.mean(diag['nudging'])),
        'per_layer_gamma': diag['bp_cosine'],
        'per_layer_rho': diag['perturbation_rho'],
        'per_layer_nudge': diag['nudging'],
    })
    return record


def _record_result(all_results, record):
    """Print the headline metrics of ``record`` and append it to ``all_results``."""
    print(f" Result: Gamma={record['mean_gamma']:.4f}, rho={record['mean_rho']:.4f}")
    all_results.append(record)


def _find_result(all_results, method, L, seed):
    """Return the first record matching (method, L, seed), or None."""
    return next((r for r in all_results
                 if r['method'] == method and r['L'] == L and r['seed'] == seed), None)


def run_experiment(args):
    """Train and diagnose DFA, scalar CB and vector fields for every (depth, seed).

    For each depth in ``args.depths`` and seed in ``args.seeds``, a teacher of
    the same depth is built and each method trains a fresh student. Every
    student is initialised after ``torch.manual_seed(seed)``. The function
    then prints a summary table, writes all records to
    ``<output_dir>/results.json`` and prints a vector-vs-scalar-CB comparison.
    """
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)

    all_results = []
    for L in args.depths:
        for seed in args.seeds:
            print(f"\n{'='*60}")
            print(f"L={L}, seed={seed}")
            print(f"{'='*60}")
            torch.manual_seed(seed)
            np.random.seed(seed)
            torch.cuda.manual_seed_all(seed)
            teacher = TeacherNet(args.d_hidden, args.num_classes, L,
                                 alpha=args.alpha, seed=seed * 1000).to(device)

            # --- DFA ---
            print("\n --- DFA ---")
            torch.manual_seed(seed)
            model_dfa = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device)
            Bs = train_dfa(model_dfa, teacher, device, args)
            diag_dfa = compute_diagnostics(model_dfa, teacher, device, 'dfa', args, dfa_Bs=Bs)
            _record_result(all_results, _summarize('dfa', L, seed, diag_dfa))

            # --- Scalar CB ---
            print("\n --- Scalar CB ---")
            torch.manual_seed(seed)
            model_cb = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device)
            vnet = train_scalar_cb(model_cb, teacher, device, args)
            diag_cb = compute_diagnostics(model_cb, teacher, device, 'scalar_cb', args,
                                          value_net=vnet)
            _record_result(all_results, _summarize('scalar_cb', L, seed, diag_cb))

            # --- Vector Field, one run per M in args.M_values ---
            for M in args.M_values:
                print(f"\n --- Vector Field M={M} ---")
                torch.manual_seed(seed)
                model_vec = StudentNet(args.d_hidden, args.num_classes, L,
                                       alpha=args.alpha).to(device)
                vec_net = train_vector_field(model_vec, teacher, device, args, M=M)
                diag_vec = compute_diagnostics(model_vec, teacher, device, f'vector_M{M}', args,
                                               vector_net=vec_net)
                _record_result(all_results,
                               _summarize(f'vector_M{M}', L, seed, diag_vec, M=M))

    # Summary
    print(f"\n{'='*80}")
    print("SUMMARY")
    print(f"{'='*80}")
    print(f"{'Method':<20} {'L':>3} {'seed':>5} {'Gamma':>8} {'rho':>8} {'nudge':>10}")
    print("-" * 60)
    for r in all_results:
        print(f"{r['method']:<20} {r['L']:>3} {r['seed']:>5} {r['mean_gamma']:>8.4f} "
              f"{r['mean_rho']:>8.4f} {r['mean_nudge']:>10.6f}")

    # Save
    out_path = os.path.join(args.output_dir, 'results.json')
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2)
    print(f"\nResults saved to {out_path}")

    # Compare vector field vs scalar CB. The verdict is a heuristic: a gain of
    # >= 0.05 in EITHER metric counts as significant, regardless of the other.
    print(f"\n{'='*60}")
    print("COMPARISON: Vector Field vs Scalar CB")
    print(f"{'='*60}")
    for L in args.depths:
        for seed in args.seeds:
            cb_r = _find_result(all_results, 'scalar_cb', L, seed)
            if cb_r is None:
                continue
            for M in args.M_values:
                vec_r = _find_result(all_results, f'vector_M{M}', L, seed)
                if vec_r is None:
                    continue
                delta_gamma = vec_r['mean_gamma'] - cb_r['mean_gamma']
                delta_rho = vec_r['mean_rho'] - cb_r['mean_rho']
                print(f" L={L} seed={seed} M={M}: delta_Gamma={delta_gamma:+.4f}, "
                      f"delta_rho={delta_rho:+.4f}")
                if delta_rho >= 0.05 or delta_gamma >= 0.05:
                    print(" -> SIGNIFICANT IMPROVEMENT")
                elif delta_rho > 0 and delta_gamma > 0:
                    print(" -> Modest improvement")
                else:
                    print(" -> No clear improvement")


def main():
    """Parse command-line options and run the pilot experiment."""
    parser = argparse.ArgumentParser(description='Phase C: Vector Credit Field Pilot')
    parser.add_argument('--d_hidden', type=int, default=128, help='residual stream width')
    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--alpha', type=float, default=1.0,
                        help='nonlinearity blend: 0 = identity, 1 = tanh')
    parser.add_argument('--depths', type=int, nargs='+', default=[4, 8],
                        help='number of residual blocks (teacher and student)')
    parser.add_argument('--M_values', type=int, nargs='+', default=[4, 8],
                        help='random directions per layer for the vector field')
    parser.add_argument('--epochs', type=int, default=80)
    parser.add_argument('--steps_per_epoch', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=1e-3, help='student learning rate')
    parser.add_argument('--lr_fb', type=float, default=1e-3,
                        help='learning rate of the value net / vector field')
    parser.add_argument('--warmup_ratio', type=float, default=0.05,
                        help='fraction of epochs trained with pure DFA credit')
    parser.add_argument('--term_grad_weight', type=float, default=1.0,
                        help='scalar CB: weight of terminal gradient matching')
    parser.add_argument('--lam', type=float, default=0.1, help='scalar CB: soft-min temperature')
    parser.add_argument('--K', type=int, default=4, help='scalar CB: noisy samples per bridge target')
    parser.add_argument('--sigma_bridge', type=float, default=0.05,
                        help='scalar CB: bridge noise std')
    parser.add_argument('--ema_momentum', type=float, default=0.995)
    parser.add_argument('--pert_eps', type=float, default=1e-3,
                        help='vector field: finite-difference step')
    parser.add_argument('--pert_beta', type=float, default=1.0,
                        help='vector field: weight of the projection loss')
    parser.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456])
    parser.add_argument('--gpu', type=int, default=3, help='CUDA device index')
    parser.add_argument('--output_dir', type=str, default='results/vector_credit_pilot')
    args = parser.parse_args()
    run_experiment(args)


if __name__ == '__main__':
    main()