"""Sweep d values and test multi-basis CVH."""
import sys
import os
import time
import torch

# Make the project root importable when this file is run as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.longlamp import load_longlamp, select_k_profile_items
from data.templates import build_query_prompt
from models.qwen_wrapper import QwenWrapper
from models.cvh import CVHHead, UnconditionalHead
from adapt.cache_hidden import cache_support_hidden_states
from adapt.fit_theta import fit_theta
from eval.metrics import evaluate_all

# Single place to change the GPU used by the model, the heads and theta fitting.
DEVICE = 'cuda:1'


class MultiBasisCVH(torch.nn.Module):
    """Two-basis CVH: h'_t = h_t + alpha * (B1(theta⊙A1*h) + B2(theta⊙A2*h)).

    Both branches share the same ``theta`` and the same ``alpha``. The four
    projection matrices are fixed random buffers (not trainable parameters):
    ``A1``/``B1`` are drawn from a generator seeded with ``basis_seed`` and
    ``A2``/``B2`` from one seeded with ``basis_seed + 500``.

    Args:
        hidden_size: Width of the hidden states being edited.
        d: Width of the bottleneck; ``theta`` is multiplied elementwise
            against a ``d``-wide projection.
        alpha: Scale applied to the summed residual of both branches.
        basis_seed: Seed for the first basis; the second uses
            ``basis_seed + 500``.
    """

    def __init__(self, hidden_size, d=64, alpha=0.1, basis_seed=42):
        super().__init__()
        self.hidden_size = hidden_size
        self.d = d
        self.alpha = alpha

        gen1 = torch.Generator()
        gen1.manual_seed(basis_seed)
        gen2 = torch.Generator()
        gen2.manual_seed(basis_seed + 500)

        scale_a = 1.0 / (hidden_size ** 0.5)
        scale_b = 1.0 / (d ** 0.5)

        # Draw order per generator is A then B; keep it so the bases are
        # reproducible for a given seed.
        self.register_buffer('A1', torch.randn(d, hidden_size, generator=gen1) * scale_a)
        self.register_buffer('B1', torch.randn(hidden_size, d, generator=gen1) * scale_b)
        self.register_buffer('A2', torch.randn(d, hidden_size, generator=gen2) * scale_a)
        self.register_buffer('B2', torch.randn(hidden_size, d, generator=gen2) * scale_b)

    @staticmethod
    def _branch(down, up, h, theta):
        """Return ``up @ (theta ⊙ (down @ h^T))`` laid out like ``h``."""
        proj = (down.float() @ h.T).T
        gated = theta.unsqueeze(0) * proj
        return (up.float() @ gated.T).T

    def forward(self, h, theta):
        """Apply both gated low-rank residuals to ``h``.

        Args:
            h: 2-D tensor whose last dimension is ``hidden_size`` (the
                ``.T`` / matmul pattern requires rows of hidden states).
            theta: Gate that broadcasts against the ``d``-wide projection
                after ``unsqueeze(0)``.

        Returns:
            ``h + alpha * (res1 + res2)``.
        """
        res1 = self._branch(self.A1, self.B1, h, theta)
        res2 = self._branch(self.A2, self.B2, h, theta)
        return h + self.alpha * (res1 + res2)

    def forward_fn(self, h, theta):
        """Alias of :meth:`forward`; ``run_head`` passes this to generation."""
        return self.forward(h, theta)


def run_head(wrapper, examples, support_sets, head_module,
             d=64, alpha=0.1, beta=0.05, steps=30, lr=0.05,
             max_new_tokens=512, device=DEVICE):
    """Fit a per-example theta and generate with ``head_module``.

    For every ``(example, support)`` pair the support hidden states are
    cached, ``fit_theta`` is run on them, and the query is generated through
    ``wrapper.generate_with_head``. When caching yields nothing, the example
    falls back to plain ``wrapper.generate_base``.

    Args:
        wrapper: Model wrapper exposing ``model.lm_head``, ``lm_head_weight``,
            ``generate_base`` and ``generate_with_head``.
        examples: Examples with ``'task'`` and ``'query_input'`` keys.
        support_sets: One support set per example, aligned with ``examples``.
        head_module: Head exposing ``forward_fn``; also handed to ``fit_theta``.
        d: Forwarded to ``fit_theta``.
        alpha: Accepted for call-site symmetry; not used here (the head
            carries its own alpha).
        beta: Forwarded to ``fit_theta``.
        steps: Forwarded to ``fit_theta``.
        lr: Forwarded to ``fit_theta``.
        max_new_tokens: Generation length cap for both code paths.
        device: Device string forwarded to ``fit_theta``.

    Returns:
        ``(predictions, avg_norm)`` where ``avg_norm`` is the mean L2 norm of
        the fitted thetas (``0.0`` if no theta was fitted).
    """
    bias_param = getattr(wrapper.model.lm_head, 'bias', None)
    lm_head_bias = bias_param.data if bias_param is not None else None

    predictions = []
    theta_norms = []

    for i, (ex, support) in enumerate(zip(examples, support_sets)):
        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
        prompt = build_query_prompt(ex['query_input'], ex['task'])

        if not cached_h:
            # Nothing to adapt on: fall back to the base model. Decode with
            # temperature=0.0 like every other generation call in this sweep
            # so the fallback is comparable and deterministic.
            pred = wrapper.generate_base(
                prompt, max_new_tokens=max_new_tokens, temperature=0.0,
            )
            predictions.append(pred)
        else:
            theta = fit_theta(
                cached_h=cached_h,
                lm_head_weight=wrapper.lm_head_weight,
                lm_head_bias=lm_head_bias,
                head_module=head_module,
                d=d,
                lr=lr,
                steps=steps,
                beta=beta,
                lam=1e-4,
                max_grad_norm=5.0,
                device=device,
                verbose=False,
            )
            theta_norms.append(theta.norm().item())

            pred = wrapper.generate_with_head(
                prompt, theta, head_module.forward_fn,
                max_new_tokens=max_new_tokens, temperature=0.0,
            )
            predictions.append(pred)

            # Drop per-example tensors before the next example is cached.
            del cached_h, theta
            torch.cuda.empty_cache()

        # Report progress on every path, including the fallback.
        if (i + 1) % 20 == 0:
            print(f" {i+1}/{len(examples)}")

    avg_norm = sum(theta_norms) / max(len(theta_norms), 1)
    return predictions, avg_norm


def main():
    """Run the base model and every head configuration, then print a table."""
    print("Loading data...")
    examples = load_longlamp('product_review_user', split='val')[:50]
    K = 4
    support_sets = [
        select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples
    ]
    references = [ex['target_output'] for ex in examples]
    support_texts = [[s['support_output'] for s in ss] for ss in support_sets]

    print("Loading model...")
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=DEVICE)
    H = wrapper.hidden_size

    # Base
    print("\n=== Base ===")
    base_preds = [
        wrapper.generate_base(
            build_query_prompt(ex['query_input'], ex['task']),
            max_new_tokens=512, temperature=0.0,
        )
        for ex in examples
    ]
    base_r = evaluate_all(base_preds, references, support_texts)
    print(f" ROUGE-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, SFD: {base_r['sfd']:.4f}")

    results = {'Base': base_r}

    default_beta = 0.05
    # (display name, head, d passed to fit_theta, beta)
    configs = [
        ('CVH d=64', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(DEVICE), 64, default_beta),
        ('CVH d=128', CVHHead(H, d=128, alpha=0.1, basis_seed=42).to(DEVICE), 128, default_beta),
        ('CVH d=256', CVHHead(H, d=256, alpha=0.1, basis_seed=42).to(DEVICE), 256, default_beta),
        ('Uncond d=64', UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(DEVICE), 64, default_beta),
        ('Uncond d=128', UnconditionalHead(H, d=128, alpha=0.1, basis_seed=42).to(DEVICE), 128, default_beta),
        ('MultiBasis d=64', MultiBasisCVH(H, d=64, alpha=0.1, basis_seed=42).to(DEVICE), 64, default_beta),
        # Higher beta to preserve content
        ('CVH d=64 b=0.1', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(DEVICE), 64, 0.1),
        ('CVH d=64 b=0.2', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(DEVICE), 64, 0.2),
    ]

    for name, head, d, beta in configs:
        print(f"\n=== {name} (beta={beta}) ===")
        t0 = time.time()
        preds, avg_norm = run_head(
            wrapper, examples, support_sets, head,
            d=d, alpha=0.1, beta=beta, steps=30, lr=0.05,
            max_new_tokens=512, device=DEVICE,
        )
        elapsed = time.time() - t0
        r = evaluate_all(preds, references, support_texts)
        results[name] = r
        print(f" ROUGE-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
              f"SFD: {r['sfd']:.4f}, avg|theta|: {avg_norm:.3f}, time: {elapsed:.0f}s")

    # Summary
    print("\n" + "=" * 90)
    print(f"{'Config':<25} {'ROUGE-1':<10} {'ROUGE-L':<10} {'METEOR':<10} {'SFD':<10}")
    print("-" * 90)
    for name, r in results.items():
        print(f"{name:<25} {r['rouge1']:<10.4f} {r['rougeL']:<10.4f} {r['meteor']:<10.4f} {r['sfd']:<10.4f}")


if __name__ == '__main__':
    main()