"""Test SVD-based CVH vs random basis CVH vs Unconditional."""
import sys
import os
import time

import torch

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.longlamp import load_longlamp, select_k_profile_items
from data.templates import build_query_prompt
from models.qwen_wrapper import QwenWrapper
from models.cvh import CVHHead, UnconditionalHead
from models.svd_cvh import SVDCVHHead, SVDUncondHead
from adapt.cache_hidden import cache_support_hidden_states
from adapt.fit_theta import fit_theta
from eval.metrics import evaluate_all

# Single source of truth for the GPU used by the model, the heads and fit_theta.
DEVICE = 'cuda:1'


def _avg_word_count(texts):
    """Return the mean whitespace-separated word count of *texts* (0.0 if empty)."""
    return sum(len(t.split()) for t in texts) / max(len(texts), 1)


def run_head(wrapper, examples, support_sets, head_module, d=64, beta=0.05,
             steps=30, lr=0.05, max_new_tokens=512, min_new_tokens=64,
             device=DEVICE):
    """Fit ``head_module`` per example and generate one prediction for each.

    For every (example, support set) pair this caches the support hidden
    states, fits a per-example ``theta`` with ``fit_theta`` and decodes with
    ``wrapper.generate_with_head`` at temperature 0.0. When no hidden states
    can be cached for an example, it falls back to ``wrapper.generate_base``,
    also at temperature 0.0, so every prediction uses the same decoding.

    Args:
        wrapper: Model wrapper exposing ``model.lm_head``, ``lm_head_weight``,
            ``generate_base`` and ``generate_with_head``.
        examples: Dicts with at least ``'query_input'`` and ``'task'`` keys.
        support_sets: One support set per example, passed to
            ``cache_support_hidden_states``.
        head_module: Head whose ``forward_fn`` is used at generation time.
        d, beta, steps, lr: Forwarded unchanged to ``fit_theta``.
        max_new_tokens: Generation length cap for both paths.
        min_new_tokens: Minimum length, applied only on the adapted path.
        device: Device string passed to ``fit_theta``.

    Returns:
        ``(predictions, avg_norm, avg_len)``:
        - predictions: one prediction per example.
        - avg_norm: mean ``theta.norm()``, taken over adapted examples only
          (0.0 if none were adapted).
        - avg_len: mean prediction length in words.
    """
    lm_head_bias = None
    if getattr(wrapper.model.lm_head, 'bias', None) is not None:
        lm_head_bias = wrapper.model.lm_head.bias.data

    predictions = []
    theta_norms = []
    for i, (ex, support) in enumerate(zip(examples, support_sets)):
        prompt = build_query_prompt(ex['query_input'], ex['task'])
        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
        if not cached_h:
            # Nothing to adapt on: use the unadapted model, decoded greedily
            # like the Base run and the adapted path.
            pred = wrapper.generate_base(
                prompt, max_new_tokens=max_new_tokens, temperature=0.0,
            )
        else:
            theta = fit_theta(
                cached_h=cached_h,
                lm_head_weight=wrapper.lm_head_weight,
                lm_head_bias=lm_head_bias,
                head_module=head_module,
                d=d, lr=lr, steps=steps, beta=beta,
                lam=1e-4, max_grad_norm=5.0,
                device=device,
                # Log the first fit actually performed, even if example 0
                # took the fallback path.
                verbose=not theta_norms,
            )
            theta_norms.append(theta.norm().item())
            pred = wrapper.generate_with_head(
                prompt, theta, head_module.forward_fn,
                max_new_tokens=max_new_tokens, temperature=0.0,
                min_new_tokens=min_new_tokens,
            )
            del theta
        predictions.append(pred)

        # Cleanup and progress run for every example, whichever path it took.
        del cached_h
        torch.cuda.empty_cache()
        if (i + 1) % 20 == 0:
            print(f"  {i+1}/{len(examples)}")

    avg_norm = sum(theta_norms) / max(len(theta_norms), 1)
    avg_len = _avg_word_count(predictions)
    return predictions, avg_norm, avg_len


def main():
    """Compare Base, random-basis and SVD-basis heads on LongLaMP product reviews."""
    n_examples = 50
    k_support = 4
    max_new_tokens = 512

    print(f"Loading data ({n_examples} examples)...")
    examples = load_longlamp('product_review_user', split='val')[:n_examples]
    support_sets = [
        select_k_profile_items(ex['profile_items'], k_support, seed=0)
        for ex in examples
    ]
    references = [ex['target_output'] for ex in examples]
    support_texts = [[s['support_output'] for s in ss] for ss in support_sets]
    print(f"Avg reference length: {_avg_word_count(references):.0f} words")

    print("Loading model...")
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=DEVICE)
    H = wrapper.hidden_size

    # Base: unadapted model, greedy decoding.
    print("\n=== Base ===")
    base_preds = []
    for ex in examples:
        prompt = build_query_prompt(ex['query_input'], ex['task'])
        base_preds.append(
            wrapper.generate_base(prompt, max_new_tokens=max_new_tokens, temperature=0.0)
        )
    base_r = evaluate_all(base_preds, references, support_texts)
    base_len = _avg_word_count(base_preds)
    print(f"  ROUGE-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, "
          f"SFD: {base_r['sfd']:.4f}, avg_len: {base_len:.0f}")

    results = {'Base': {**base_r, 'avg_len': base_len}}

    # SVD-based heads. Going by the log message, these constructors
    # presumably decompose lm_head_weight — confirm in models.svd_cvh.
    print("\nComputing SVD of lm_head...")
    svd_cvh = SVDCVHHead(wrapper.lm_head_weight, d=64, alpha=0.1).to(DEVICE)
    svd_uncond = SVDUncondHead(wrapper.lm_head_weight, d=64, alpha=0.1).to(DEVICE)

    # (display name, head module, d forwarded to fit_theta)
    configs = [
        ('Random CVH d=64', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(DEVICE), 64),
        ('Random Uncond d=64', UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(DEVICE), 64),
        ('SVD CVH d=64', svd_cvh, 64),
        ('SVD Uncond d=64', svd_uncond, 64),
        # Try different alpha with SVD
        ('SVD CVH d=64 a=0.05', SVDCVHHead(wrapper.lm_head_weight, d=64, alpha=0.05).to(DEVICE), 64),
        ('SVD CVH d=64 a=0.2', SVDCVHHead(wrapper.lm_head_weight, d=64, alpha=0.2).to(DEVICE), 64),
    ]

    for name, head, d in configs:
        print(f"\n=== {name} ===")
        t0 = time.time()
        preds, avg_norm, avg_len = run_head(
            wrapper, examples, support_sets, head, d=d,
            beta=0.05, steps=30, lr=0.05,
            max_new_tokens=max_new_tokens, min_new_tokens=64,
            device=DEVICE,
        )
        elapsed = time.time() - t0
        r = evaluate_all(preds, references, support_texts)
        results[name] = {**r, 'avg_len': avg_len}
        print(f"  ROUGE-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
              f"SFD: {r['sfd']:.4f}, avg|theta|: {avg_norm:.3f}, "
              f"avg_len: {avg_len:.0f}, time: {elapsed:.0f}s")

    # Summary
    print("\n" + "=" * 100)
    print(f"{'Config':<25} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD':<8} {'Len':<6}")
    print("-" * 100)
    for name, r in results.items():
        print(f"{name:<25} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} "
              f"{r['meteor']:<8.4f} {r['sfd']:<8.4f} {r.get('avg_len', 0):<6.0f}")


if __name__ == '__main__':
    main()