"""Quick sweep over alpha values to find the right perturbation scale."""
import sys
import os
import json  # NOTE(review): currently unused; presumably meant for dumping all_results
import time

import torch

# Make the project root importable when this script is run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.longlamp import load_longlamp, select_k_profile_items
from data.templates import build_query_prompt
from models.qwen_wrapper import QwenWrapper
from models.cvh import CVHHead
from adapt.cache_hidden import cache_support_hidden_states
from adapt.fit_theta import fit_theta
from eval.metrics import evaluate_all

# Device used for the model, the CVH head and theta fitting.
DEVICE = 'cuda:1'
# Generation budget shared by the base run and every CVH config.
MAX_NEW_TOKENS = 256


def run_cvh_with_params(wrapper, examples, support_sets, alpha, beta, steps,
                        d=64, lr=0.05, device=DEVICE,
                        max_new_tokens=MAX_NEW_TOKENS):
    """Run CVH with specific hyperparameters.

    For each (example, support set) pair this caches the support hidden
    states, fits a per-example ``theta`` with ``fit_theta`` and generates with
    the CVH head applied (``temperature=0.0``).

    If the support cache comes back empty, the example falls back to plain
    base generation. That fallback also uses ``temperature=0.0``, so it
    matches the base run in ``main``.

    Args:
        wrapper: Loaded model wrapper. It must expose ``hidden_size``,
            ``model.lm_head``, ``lm_head_weight``, ``generate_base`` and
            ``generate_with_head``.
        examples: Evaluation examples. Each is indexed with ``'query_input'``
            and ``'task'``.
        support_sets: One support set per example, aligned with ``examples``.
        alpha: ``alpha`` passed to ``CVHHead``.
        beta: ``beta`` passed to ``fit_theta``.
        steps: Number of ``fit_theta`` steps.
        d: ``d`` passed to both ``CVHHead`` and ``fit_theta``.
        lr: Learning rate passed to ``fit_theta``.
        device: Device for the CVH head and theta fitting.
        max_new_tokens: Generation budget per example.

    Returns:
        ``(predictions, avg_norm)``.
        - ``predictions``: one generated output per example.
        - ``avg_norm``: the mean ``theta.norm()`` over examples that were
          fitted, or 0.0 if none were.
    """
    head = CVHHead(wrapper.hidden_size, d=d, alpha=alpha, basis_seed=42).to(device)

    # The LM head may have no bias (attribute missing or None); only pass it
    # to the fitter when it exists.
    bias = getattr(wrapper.model.lm_head, 'bias', None)
    lm_head_bias = bias.data if bias is not None else None

    predictions = []
    theta_norms = []
    for ex, support in zip(examples, support_sets):
        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
        prompt = build_query_prompt(ex['query_input'], ex['task'])

        if not cached_h:
            # No support states to fit theta on: fall back to the base model.
            # Decode greedily like every other generation in this script so
            # the fallback matches the base run.
            predictions.append(wrapper.generate_base(
                prompt, max_new_tokens=max_new_tokens, temperature=0.0,
            ))
            continue

        theta = fit_theta(
            cached_h=cached_h,
            lm_head_weight=wrapper.lm_head_weight,
            lm_head_bias=lm_head_bias,
            head_module=head,
            d=d,
            lr=lr,
            steps=steps,
            beta=beta,
            lam=1e-4,
            max_grad_norm=5.0,
            device=device,
            verbose=False,
        )
        theta_norms.append(theta.norm().item())

        predictions.append(wrapper.generate_with_head(
            prompt, theta, head.forward_fn,
            max_new_tokens=max_new_tokens, temperature=0.0,
        ))

        # Drop per-example tensors before the next example to keep GPU memory flat.
        del cached_h, theta
        torch.cuda.empty_cache()

    avg_norm = sum(theta_norms) / max(len(theta_norms), 1)
    return predictions, avg_norm


def _generate_base_predictions(wrapper, examples, max_new_tokens=MAX_NEW_TOKENS):
    """Generate one greedy (temperature=0.0) base-model output per example."""
    return [
        wrapper.generate_base(
            build_query_prompt(ex['query_input'], ex['task']),
            max_new_tokens=max_new_tokens,
            temperature=0.0,
        )
        for ex in examples
    ]


def _format_metrics(results):
    """Format the ROUGE-L / METEOR / SFD values from an ``evaluate_all`` result."""
    return (f"ROUGE-L: {results['rougeL']:.4f}, METEOR: {results['meteor']:.4f}, "
            f"SFD: {results['sfd']:.4f}")


def _print_summary(all_results):
    """Print a fixed-width table with one row per evaluated config."""
    print("\n" + "=" * 80)
    print(f"{'Config':<40} {'ROUGE-L':<10} {'METEOR':<10} {'SFD':<10}")
    print("-" * 80)
    for name, r in all_results.items():
        print(f"{name:<40} {r['rougeL']:<10.4f} {r['meteor']:<10.4f} {r['sfd']:<10.4f}")


def main():
    """Evaluate the base model, then each CVH config, and print a summary table."""
    print("Loading data...")
    examples = load_longlamp('product_review_user', split='val')[:50]
    K = 4  # profile items drawn per example as support
    support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0)
                    for ex in examples]
    references = [ex['target_output'] for ex in examples]
    # Per-example support outputs, passed to evaluate_all as its third argument.
    support_texts = [[s['support_output'] for s in ss] for ss in support_sets]

    print("Loading model...")
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=DEVICE)

    # Run base
    print("\n=== Base ===")
    base_preds = _generate_base_predictions(wrapper, examples)
    base_results = evaluate_all(base_preds, references, support_texts)
    print(f" {_format_metrics(base_results)}")

    # Sweep
    configs = [
        {'alpha': 0.1, 'beta': 0.05, 'steps': 30, 'lr': 0.05},
        {'alpha': 0.3, 'beta': 0.05, 'steps': 30, 'lr': 0.05},
        {'alpha': 0.5, 'beta': 0.05, 'steps': 30, 'lr': 0.05},
        {'alpha': 0.3, 'beta': 0.01, 'steps': 50, 'lr': 0.05},
        {'alpha': 0.5, 'beta': 0.01, 'steps': 50, 'lr': 0.05},
        {'alpha': 0.3, 'beta': 0.01, 'steps': 50, 'lr': 0.1},
    ]

    all_results = {'Base': base_results}
    for cfg in configs:
        name = f"a{cfg['alpha']}_b{cfg['beta']}_s{cfg['steps']}_lr{cfg['lr']}"
        print(f"\n=== CVH {name} ===")

        t0 = time.perf_counter()
        preds, avg_norm = run_cvh_with_params(
            wrapper, examples, support_sets,
            alpha=cfg['alpha'], beta=cfg['beta'], steps=cfg['steps'], lr=cfg['lr'],
            device=DEVICE,
        )
        elapsed = time.perf_counter() - t0

        results = evaluate_all(preds, references, support_texts)
        all_results[name] = results
        print(f" {_format_metrics(results)}, avg|theta|: {avg_norm:.3f}, "
              f"time: {elapsed:.0f}s")

    _print_summary(all_results)


if __name__ == '__main__':
    main()