summaryrefslogtreecommitdiff
path: root/experiments/synth_vector_credit.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/synth_vector_credit.py')
-rw-r--r--experiments/synth_vector_credit.py708
1 files changed, 708 insertions, 0 deletions
diff --git a/experiments/synth_vector_credit.py b/experiments/synth_vector_credit.py
new file mode 100644
index 0000000..14e28e2
--- /dev/null
+++ b/experiments/synth_vector_credit.py
@@ -0,0 +1,708 @@
+"""
+Phase C: Direct Vector Credit Field Pilot.
+
+Compare scalar credit bridge vs direct vector credit field on synthetic best regime.
+Vector field: a_phi(h_l, t_l, s) -> R^d, trained with symmetric finite-difference
+directional targets (no hidden BP anchor).
+
+Loss:
+ L_proj = (1/M) sum_j ( <a_phi(h_l, t_l, s), v_j> - g_j )^2
+ where g_j = [ loss(h_l + eps*v_j) - loss(h_l - eps*v_j) ] / (2*eps)
+
+ L_term = || a_phi(h_L, 1, s) - delta_L ||^2
+
+ L_total = L_term + beta * L_proj
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import copy
+import time
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test,
+ offline_bp_cosine
+)
+
+
+# =============================================================================
+# Synthetic teacher-student (from synth_nonlinearity_ladder.py)
+# =============================================================================
+class TeacherNet(nn.Module):
+ """Fixed teacher with controllable nonlinearity."""
+ def __init__(self, d_hidden, num_classes, num_blocks, alpha=1.0, seed=0):
+ super().__init__()
+ self.d_hidden = d_hidden
+ self.num_blocks = num_blocks
+ self.alpha = alpha
+ rng = torch.Generator().manual_seed(seed)
+ self.Ws = nn.ParameterList()
+ for _ in range(num_blocks):
+ W = torch.randn(d_hidden, d_hidden, generator=rng) * 0.3 / (d_hidden ** 0.5)
+ U, S, Vh = torch.linalg.svd(W, full_matrices=False)
+ S_clamped = S.clamp(max=0.3)
+ W = U @ torch.diag(S_clamped) @ Vh
+ self.Ws.append(nn.Parameter(W, requires_grad=False))
+ self.U = nn.Parameter(torch.randn(num_classes, d_hidden, generator=rng) / (d_hidden ** 0.5),
+ requires_grad=False)
+
+ def phi(self, z):
+ return (1 - self.alpha) * z + self.alpha * torch.tanh(z)
+
+ def forward(self, x):
+ h = x
+ for W in self.Ws:
+ h = h + self.phi(h @ W.T)
+ return h @ self.U.T
+
+
+class StudentBlock(nn.Module):
+ """Student block with pre-LayerNorm."""
+ def __init__(self, d_hidden, alpha=1.0):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.w = nn.Linear(d_hidden, d_hidden, bias=False)
+ nn.init.normal_(self.w.weight, std=0.01)
+ self.alpha = alpha
+
+ def phi(self, z):
+ return (1 - self.alpha) * z + self.alpha * torch.tanh(z)
+
+ def forward(self, h):
+ return self.w(self.phi(self.ln(h)))
+
+
+class StudentNet(nn.Module):
+ """Student network."""
+ def __init__(self, d_hidden, num_classes, num_blocks, alpha=1.0):
+ super().__init__()
+ self.blocks = nn.ModuleList([StudentBlock(d_hidden, alpha) for _ in range(num_blocks)])
+ self.out_head = nn.Linear(d_hidden, num_classes)
+ self.d_hidden = d_hidden
+ self.num_blocks = num_blocks
+
+ def forward(self, x, return_hidden=False):
+ h = x
+ hiddens = [h] if return_hidden else None
+ for block in self.blocks:
+ f = block(h)
+ h = h + f
+ if return_hidden:
+ hiddens.append(h)
+ logits = self.out_head(h)
+ if return_hidden:
+ return logits, hiddens
+ return logits
+
+ def forward_from_layer(self, h, start_layer):
+ for i in range(start_layer, self.num_blocks):
+ f = self.blocks[i](h)
+ h = h + f
+ return self.out_head(h)
+
+
+# =============================================================================
+# Vector Credit Field Network
+# =============================================================================
+class VectorCreditNet(nn.Module):
+ """
+ Direct vector credit field: a_phi(h_l, t_l, s) -> R^d.
+ Output is d-dimensional credit vector directly.
+ """
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ """Returns credit vector (batch, d_hidden)."""
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp)
+
+
+# =============================================================================
+# Training functions
+# =============================================================================
+def generate_batch(teacher, d_hidden, num_classes, batch_size, device):
+ """Generate synthetic data from teacher."""
+ x = torch.randn(batch_size, d_hidden, device=device)
+ with torch.no_grad():
+ teacher_logits = teacher(x)
+ y = teacher_logits.argmax(dim=-1)
+ return x, y
+
+
+def train_scalar_cb(model, teacher, device, args):
+ """Train scalar credit bridge (current method, as baseline)."""
+ d = model.d_hidden
+ L = model.num_blocks
+ num_classes = args.num_classes
+
+ value_net = ValueNet(d_hidden=d, s_dim=num_classes, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ value_net_ema = create_ema_model(value_net)
+
+ Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)]
+
+ block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=0.01) for b in model.blocks]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=0.01)
+ value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb)
+
+ warmup_epochs = max(1, int(args.epochs * args.warmup_ratio))
+ log = {'train_loss': [], 'test_acc': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ value_net.train()
+
+ if epoch <= warmup_epochs:
+ credit_blend = 0.0
+ else:
+ credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))
+
+ total_loss, correct, total = 0, 0, 0
+ for _ in range(args.steps_per_epoch):
+ x, y = generate_batch(teacher, d, num_classes, args.batch_size, device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ true_loss = F.cross_entropy(logits, y, reduction='none').detach()
+
+ hL_det = hiddens[-1].detach()
+
+ # Train value net
+ t_L = torch.ones(batch, device=device)
+ V_term = value_net(hL_det, t_L, s)
+ loss_term = ((V_term - true_loss) ** 2).mean()
+
+ # Terminal gradient matching
+ hL_req = hL_det.clone().requires_grad_(True)
+ V_at_L = value_net(hL_req, t_L, s)
+ grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
+ hL_req2 = hL_det.clone().requires_grad_(True)
+ logits_tgt = model.out_head(hL_req2)
+ ce = F.cross_entropy(logits_tgt, y, reduction='sum')
+ a_L_exact = torch.autograd.grad(ce, hL_req2, create_graph=False)[0].detach()
+ loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()
+
+ # Bridge consistency
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ t_next = torch.full((batch,), (l + 1) / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+ with torch.no_grad():
+ h_next = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(args.K):
+ noise = args.sigma_bridge * torch.randn_like(h_next)
+ V_next = value_net_ema(h_next + noise, t_next, s)
+ log_terms.append(-V_next / args.lam)
+ log_stack = torch.stack(log_terms, dim=-1)
+ V_target = -args.lam * (torch.logsumexp(log_stack, dim=-1) - np.log(args.K))
+ loss_bridge += ((V_l - V_target.detach()) ** 2).mean()
+ loss_bridge /= L
+
+ vloss = loss_term + loss_bridge + args.term_grad_weight * loss_tgrad
+ value_opt.zero_grad()
+ vloss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ value_opt.step()
+ update_ema(value_net, value_net_ema, args.ema_momentum)
+
+ # Compute credits
+ cb_credits = []
+ for l in range(L):
+ h_l_det = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0]
+ cb_credits.append(a_l.detach())
+
+ dfa_credits = [(e_T @ Bs[l].T).detach() for l in range(L)]
+
+ credits = []
+ for l in range(L):
+ if credit_blend >= 1.0:
+ credits.append(cb_credits[l])
+ elif credit_blend <= 0.0:
+ credits.append(dfa_credits[l])
+ else:
+ cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ credits.append(credit_blend * cb_credits[l] / cb_rms + (1 - credit_blend) * dfa_credits[l] / dfa_rms)
+
+ # Update head
+ logits_out = model.out_head(hL_det)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ # Update blocks
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * (a / rms)).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ if epoch % 10 == 0 or epoch == 1:
+ acc = correct / total
+ print(f" [scalar_cb] Ep {epoch}: loss={total_loss/total:.4f}, acc={acc:.4f}")
+
+ return value_net
+
+
+def train_vector_field(model, teacher, device, args, M=4):
+ """
+ Train direct vector credit field with perturbation-based targets.
+ No hidden BP anchor.
+ """
+ d = model.d_hidden
+ L = model.num_blocks
+ num_classes = args.num_classes
+
+ vector_net = VectorCreditNet(d_hidden=d, s_dim=num_classes, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+
+ Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)]
+
+ block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=0.01) for b in model.blocks]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=0.01)
+ vec_opt = optim.Adam(vector_net.parameters(), lr=args.lr_fb)
+
+ warmup_epochs = max(1, int(args.epochs * args.warmup_ratio))
+ eps = args.pert_eps
+ beta = args.pert_beta
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ vector_net.train()
+
+ if epoch <= warmup_epochs:
+ credit_blend = 0.0
+ else:
+ credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))
+
+ total_loss, correct, total = 0, 0, 0
+ total_vloss = 0
+
+ for _ in range(args.steps_per_epoch):
+ x, y = generate_batch(teacher, d, num_classes, args.batch_size, device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ hL_det = hiddens[-1].detach()
+
+ # Terminal matching: a_phi(h_L, 1, s) = delta_L
+ t_L = torch.ones(batch, device=device)
+ a_terminal = vector_net(hL_det, t_L, s)
+ # delta_L = grad_{h_L} CE (output-layer-local)
+ hL_req = hL_det.clone().requires_grad_(True)
+ logits_tgt = model.out_head(hL_req)
+ ce = F.cross_entropy(logits_tgt, y, reduction='sum')
+ delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
+ loss_term = ((a_terminal - delta_L) ** 2).sum(dim=-1).mean()
+
+ # Perturbation-based directional targets for all layers
+ loss_proj = torch.tensor(0.0, device=device)
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ a_l = vector_net(h_l_det, t_l, s)
+
+ # Compute directional targets using symmetric finite difference
+ layer_proj_loss = 0.0
+ for _ in range(M):
+ v = torch.randn_like(h_l_det)
+ v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
+
+ # Forward from perturbed h_l
+ with torch.no_grad():
+ logits_plus = model.forward_from_layer(h_l_det + eps * v, l)
+ loss_plus = F.cross_entropy(logits_plus, y, reduction='none')
+ logits_minus = model.forward_from_layer(h_l_det - eps * v, l)
+ loss_minus = F.cross_entropy(logits_minus, y, reduction='none')
+ g_j = (loss_plus - loss_minus) / (2 * eps) # (batch,)
+
+ # Predicted directional derivative
+ pred_j = (a_l * v).sum(dim=-1) # (batch,)
+ layer_proj_loss = layer_proj_loss + ((pred_j - g_j.detach()) ** 2).mean()
+
+ loss_proj = loss_proj + layer_proj_loss / M
+ loss_proj = loss_proj / L
+
+ vec_loss = loss_term + beta * loss_proj
+ vec_opt.zero_grad()
+ vec_loss.backward()
+ torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0)
+ vec_opt.step()
+ total_vloss += vec_loss.item() * batch
+
+ # Compute credits for block updates
+ with torch.no_grad():
+ vec_credits = []
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ a_l = vector_net(h_l_det, t_l, s)
+ vec_credits.append(a_l.detach())
+
+ dfa_credits = [(e_T @ Bs[l].T).detach() for l in range(L)]
+
+ credits = []
+ for l in range(L):
+ if credit_blend >= 1.0:
+ credits.append(vec_credits[l])
+ elif credit_blend <= 0.0:
+ credits.append(dfa_credits[l])
+ else:
+ vc_rms = (vec_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ credits.append(credit_blend * vec_credits[l] / vc_rms + (1 - credit_blend) * dfa_credits[l] / dfa_rms)
+
+ # Update head
+ logits_out = model.out_head(hL_det)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ # Update blocks
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * (a / rms)).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ if epoch % 10 == 0 or epoch == 1:
+ acc = correct / total
+ print(f" [vec_M={M}] Ep {epoch}: loss={total_loss/total:.4f}, acc={acc:.4f}, "
+ f"vloss={total_vloss/total:.6f}")
+
+ return vector_net
+
+
+def train_dfa(model, teacher, device, args):
+ """DFA baseline for comparison."""
+ d = model.d_hidden
+ L = model.num_blocks
+ num_classes = args.num_classes
+
+ Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)]
+
+ block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=0.01) for b in model.blocks]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=0.01)
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ total_loss, correct, total = 0, 0, 0
+ for _ in range(args.steps_per_epoch):
+ x, y = generate_batch(teacher, d, num_classes, args.batch_size, device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+
+ hL_det = hiddens[-1].detach()
+ logits_out = model.out_head(hL_det)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = (e_T @ Bs[l].T).detach()
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * (a / rms)).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ if epoch % 10 == 0 or epoch == 1:
+ print(f" [DFA] Ep {epoch}: loss={total_loss/total:.4f}, acc={correct/total:.4f}")
+
+ return Bs
+
+
+# =============================================================================
+# Diagnostics
+# =============================================================================
+def compute_diagnostics(model, teacher, device, method_name, args,
+ value_net=None, vector_net=None, dfa_Bs=None):
+ """Compute Gamma, rho, nudging per layer."""
+ model.eval()
+ if value_net is not None:
+ value_net.eval()
+ if vector_net is not None:
+ vector_net.eval()
+
+ d = model.d_hidden
+ L = model.num_blocks
+ num_classes = args.num_classes
+
+ x, y = generate_batch(teacher, d, num_classes, 512, device)
+ batch = x.size(0)
+
+ # BP gradients (evaluation only)
+ h = x.detach().requires_grad_(True)
+ hiddens_bp = [h]
+ for block in model.blocks:
+ f = block(hiddens_bp[-1])
+ h_next = hiddens_bp[-1] + f
+ hiddens_bp.append(h_next)
+ logits_bp = model.out_head(hiddens_bp[-1])
+ loss_bp = F.cross_entropy(logits_bp, y)
+ grads = torch.autograd.grad(loss_bp, hiddens_bp, retain_graph=False)
+ bp_grads = {l: grads[l].detach().clone() for l in range(L + 1)}
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ results = {'bp_cosine': [], 'perturbation_rho': [], 'nudging': []}
+
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+
+ if method_name == 'dfa':
+ a_l = (s @ dfa_Bs[l].T).detach()
+ elif method_name == 'scalar_cb':
+ h_l_req = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_req, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach()
+ elif method_name.startswith('vector'):
+ a_l = vector_net(h_l, t_l, s).detach()
+ else:
+ raise ValueError(f"Unknown: {method_name}")
+
+ bp_cos = cosine_similarity_batch(a_l, bp_grads[l])
+ results['bp_cosine'].append(float(bp_cos))
+
+ def make_fwd_fn(start_l):
+ def fwd_fn(h):
+ with torch.no_grad():
+ logits = model.forward_from_layer(h, start_l)
+ return F.cross_entropy(logits, y, reduction='none')
+ return fwd_fn
+
+ fwd_fn = make_fwd_fn(l)
+ rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=32)
+ results['perturbation_rho'].append(float(rho))
+
+ nud = nudging_test(h_l, a_l, fwd_fn, eta=0.003)
+ results['nudging'].append(float(nud))
+
+ return results
+
+
+# =============================================================================
+# Main
+# =============================================================================
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ all_results = []
+
+ for L in args.depths:
+ for seed in args.seeds:
+ print(f"\n{'='*60}")
+ print(f"L={L}, seed={seed}")
+ print(f"{'='*60}")
+
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ teacher = TeacherNet(args.d_hidden, args.num_classes, L,
+ alpha=args.alpha, seed=seed * 1000).to(device)
+
+ # --- DFA ---
+ print("\n --- DFA ---")
+ torch.manual_seed(seed)
+ model_dfa = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device)
+ Bs = train_dfa(model_dfa, teacher, device, args)
+ diag_dfa = compute_diagnostics(model_dfa, teacher, device, 'dfa', args, dfa_Bs=Bs)
+ r_dfa = {
+ 'method': 'dfa', 'L': L, 'seed': seed,
+ 'mean_gamma': float(np.mean(diag_dfa['bp_cosine'])),
+ 'mean_rho': float(np.mean(diag_dfa['perturbation_rho'])),
+ 'mean_nudge': float(np.mean(diag_dfa['nudging'])),
+ 'per_layer_gamma': diag_dfa['bp_cosine'],
+ 'per_layer_rho': diag_dfa['perturbation_rho'],
+ }
+ print(f" Result: Gamma={r_dfa['mean_gamma']:.4f}, rho={r_dfa['mean_rho']:.4f}")
+ all_results.append(r_dfa)
+
+ # --- Scalar CB ---
+ print("\n --- Scalar CB ---")
+ torch.manual_seed(seed)
+ model_cb = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device)
+ vnet = train_scalar_cb(model_cb, teacher, device, args)
+ diag_cb = compute_diagnostics(model_cb, teacher, device, 'scalar_cb', args, value_net=vnet)
+ r_cb = {
+ 'method': 'scalar_cb', 'L': L, 'seed': seed,
+ 'mean_gamma': float(np.mean(diag_cb['bp_cosine'])),
+ 'mean_rho': float(np.mean(diag_cb['perturbation_rho'])),
+ 'mean_nudge': float(np.mean(diag_cb['nudging'])),
+ 'per_layer_gamma': diag_cb['bp_cosine'],
+ 'per_layer_rho': diag_cb['perturbation_rho'],
+ }
+ print(f" Result: Gamma={r_cb['mean_gamma']:.4f}, rho={r_cb['mean_rho']:.4f}")
+ all_results.append(r_cb)
+
+ # --- Vector Field M=4 ---
+ for M in args.M_values:
+ print(f"\n --- Vector Field M={M} ---")
+ torch.manual_seed(seed)
+ model_vec = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device)
+ vec_net = train_vector_field(model_vec, teacher, device, args, M=M)
+ diag_vec = compute_diagnostics(model_vec, teacher, device, f'vector_M{M}', args,
+ vector_net=vec_net)
+ r_vec = {
+ 'method': f'vector_M{M}', 'L': L, 'seed': seed, 'M': M,
+ 'mean_gamma': float(np.mean(diag_vec['bp_cosine'])),
+ 'mean_rho': float(np.mean(diag_vec['perturbation_rho'])),
+ 'mean_nudge': float(np.mean(diag_vec['nudging'])),
+ 'per_layer_gamma': diag_vec['bp_cosine'],
+ 'per_layer_rho': diag_vec['perturbation_rho'],
+ }
+ print(f" Result: Gamma={r_vec['mean_gamma']:.4f}, rho={r_vec['mean_rho']:.4f}")
+ all_results.append(r_vec)
+
+ # Summary
+ print(f"\n{'='*80}")
+ print("SUMMARY")
+ print(f"{'='*80}")
+ print(f"{'Method':<20} {'L':>3} {'seed':>5} {'Gamma':>8} {'rho':>8} {'nudge':>10}")
+ print("-" * 60)
+ for r in all_results:
+ print(f"{r['method']:<20} {r['L']:>3} {r['seed']:>5} {r['mean_gamma']:>8.4f} "
+ f"{r['mean_rho']:>8.4f} {r['mean_nudge']:>10.6f}")
+
+ # Save
+ out_path = os.path.join(args.output_dir, 'results.json')
+ with open(out_path, 'w') as f:
+ json.dump(all_results, f, indent=2)
+ print(f"\nResults saved to {out_path}")
+
+ # Compare vector field vs scalar CB
+ print(f"\n{'='*60}")
+ print("COMPARISON: Vector Field vs Scalar CB")
+ print(f"{'='*60}")
+ for L in args.depths:
+ for seed in args.seeds:
+ cb_r = [r for r in all_results if r['method'] == 'scalar_cb' and r['L'] == L and r['seed'] == seed]
+ if not cb_r:
+ continue
+ cb_r = cb_r[0]
+ for M in args.M_values:
+ vec_r = [r for r in all_results if r['method'] == f'vector_M{M}' and r['L'] == L and r['seed'] == seed]
+ if not vec_r:
+ continue
+ vec_r = vec_r[0]
+ delta_gamma = vec_r['mean_gamma'] - cb_r['mean_gamma']
+ delta_rho = vec_r['mean_rho'] - cb_r['mean_rho']
+ print(f" L={L} seed={seed} M={M}: delta_Gamma={delta_gamma:+.4f}, delta_rho={delta_rho:+.4f}")
+ if delta_rho >= 0.05 or delta_gamma >= 0.05:
+ print(f" -> SIGNIFICANT IMPROVEMENT")
+ elif delta_rho > 0 and delta_gamma > 0:
+ print(f" -> Modest improvement")
+ else:
+ print(f" -> No clear improvement")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Phase C: Vector Credit Field Pilot')
+ parser.add_argument('--d_hidden', type=int, default=128)
+ parser.add_argument('--num_classes', type=int, default=10)
+ parser.add_argument('--alpha', type=float, default=1.0)
+ parser.add_argument('--depths', type=int, nargs='+', default=[4, 8])
+ parser.add_argument('--M_values', type=int, nargs='+', default=[4, 8])
+ parser.add_argument('--epochs', type=int, default=80)
+ parser.add_argument('--steps_per_epoch', type=int, default=50)
+ parser.add_argument('--batch_size', type=int, default=256)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--warmup_ratio', type=float, default=0.05)
+ parser.add_argument('--term_grad_weight', type=float, default=1.0)
+ parser.add_argument('--lam', type=float, default=0.1)
+ parser.add_argument('--K', type=int, default=4)
+ parser.add_argument('--sigma_bridge', type=float, default=0.05)
+ parser.add_argument('--ema_momentum', type=float, default=0.995)
+ parser.add_argument('--pert_eps', type=float, default=1e-3)
+ parser.add_argument('--pert_beta', type=float, default=1.0)
+ parser.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456])
+ parser.add_argument('--gpu', type=int, default=3)
+ parser.add_argument('--output_dir', type=str, default='results/vector_credit_pilot')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()