"""Quick test: normalized CVH vs Uncond with blending."""

import sys
import os
import time

import torch

# Make the project root importable when this script is run directly from
# its own sub-directory; must happen before the project-local imports below.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.longlamp import load_longlamp, select_k_profile_items
from data.templates import build_query_prompt
from models.qwen_wrapper import QwenWrapper
from models.cvh import CVHHead, UnconditionalHead
from adapt.cache_hidden import cache_support_hidden_states
from adapt.fit_theta import fit_theta
from eval.metrics import evaluate_all

# Single source of truth for the GPU used by the model, the heads and the
# theta fit (previously repeated as a string literal in three places).
DEFAULT_DEVICE = 'cuda:1'


def run_blended(wrapper, examples, support_sets, head_module, d=64, beta=0.05,
                steps=30, lr=0.05, blend_gamma=0.5, device=DEFAULT_DEVICE):
    """Fit a per-example theta and generate with the blended head.

    For each (example, support set) pair the support hidden states are
    cached, ``fit_theta`` is run on them, and the query is generated with
    ``wrapper.generate_with_head_blended``. When the cache comes back empty
    (falsy) the example falls back to plain ``wrapper.generate_base``.

    Args:
        wrapper: model wrapper exposing ``model.lm_head``, ``lm_head_weight``,
            ``generate_base`` and ``generate_with_head_blended``.
        examples: sequence of dicts with at least ``'task'`` and
            ``'query_input'`` keys.
        support_sets: per-example support items, aligned with ``examples``.
        head_module: head passed to ``fit_theta``; its ``forward_fn`` is used
            at generation time.
        d, beta, steps, lr: forwarded unchanged to ``fit_theta``.
        blend_gamma: forwarded to ``generate_with_head_blended``.
        device: device string forwarded to ``fit_theta``.

    Returns:
        ``(predictions, avg_norm, avg_len)`` where ``predictions`` has one
        entry per example, ``avg_norm`` is the mean ``theta.norm()`` over the
        examples that were actually fitted (0 if none were), and ``avg_len``
        is the mean whitespace-token count of the predictions.
    """
    # The LM head may or may not carry a bias term.
    bias_param = getattr(wrapper.model.lm_head, 'bias', None)
    lm_head_bias = bias_param.data if bias_param is not None else None

    predictions = []
    theta_norms = []
    total = len(examples)

    for i, (ex, support) in enumerate(zip(examples, support_sets)):
        prompt = build_query_prompt(ex['query_input'], ex['task'])
        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])

        if not cached_h:
            # Nothing to fit on: fall back to the base model. Decode greedily
            # (temperature=0.0) so fallback outputs are comparable with the
            # baseline and the blended generations.
            pred = wrapper.generate_base(
                prompt, max_new_tokens=512, temperature=0.0,
            )
            predictions.append(pred)
        else:
            theta = fit_theta(
                cached_h=cached_h,
                lm_head_weight=wrapper.lm_head_weight,
                lm_head_bias=lm_head_bias,
                head_module=head_module,
                d=d,
                lr=lr,
                steps=steps,
                beta=beta,
                lam=1e-4,
                max_grad_norm=5.0,
                device=device,
                verbose=(i == 0),  # only log the fit for the first example
            )
            theta_norms.append(theta.norm().item())

            pred = wrapper.generate_with_head_blended(
                prompt,
                theta,
                head_module.forward_fn,
                blend_gamma=blend_gamma,
                max_new_tokens=512,
                min_new_tokens=128,
                temperature=0.0,
            )
            predictions.append(pred)

            # Drop per-example tensors before the next fit to limit GPU
            # memory growth across the loop.
            del cached_h, theta
            torch.cuda.empty_cache()

        # Report progress for every example, including fallbacks.
        if (i + 1) % 20 == 0:
            print(f"  {i+1}/{total}")

    avg_norm = sum(theta_norms) / max(len(theta_norms), 1)
    avg_len = sum(len(p.split()) for p in predictions) / max(len(predictions), 1)
    return predictions, avg_norm, avg_len


def main():
    """Run the base model and several blended-head configs, then print a table."""
    N = 100
    print(f"Loading data ({N} examples)...")
    examples = load_longlamp('product_review_user', split='val')[:N]
    total = len(examples)  # may be smaller than N if the split is short
    K = 4
    support_sets = [
        select_k_profile_items(ex['profile_items'], K, seed=0)
        for ex in examples
    ]
    references = [ex['target_output'] for ex in examples]
    support_texts = [[s['support_output'] for s in ss] for ss in support_sets]

    print("Loading model...")
    device = DEFAULT_DEVICE
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=device)
    H = wrapper.hidden_size

    # Base
    print("\n=== Base ===")
    base_preds = []
    for i, ex in enumerate(examples):
        prompt = build_query_prompt(ex['query_input'], ex['task'])
        pred = wrapper.generate_base(prompt, max_new_tokens=512, temperature=0.0)
        base_preds.append(pred)
        if (i + 1) % 20 == 0:
            print(f"  {i+1}/{total}")
    base_r = evaluate_all(base_preds, references, support_texts)
    base_len = sum(len(p.split()) for p in base_preds) / max(len(base_preds), 1)
    print(f"  R-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, SFD: {base_r['sfd']:.4f}, len: {base_len:.0f}")

    results = {'Base': base_r}

    # (label, head module, d passed to fit_theta, blend gamma)
    configs = [
        ('Uncond d=64 g=0.5',
         UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64, 0.5),
        ('CVH-norm d=64 g=0.5',
         CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64, 0.5),
        ('CVH-norm d=64 g=0.3',
         CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64, 0.3),
        ('CVH-norm d=64 g=0.7',
         CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64, 0.7),
    ]

    for name, head, d, gamma in configs:
        print(f"\n=== {name} ===")
        t0 = time.time()
        preds, avg_norm, avg_len = run_blended(
            wrapper, examples, support_sets, head,
            d=d, beta=0.05, steps=30, lr=0.05, blend_gamma=gamma,
            device=device,
        )
        elapsed = time.time() - t0
        r = evaluate_all(preds, references, support_texts)
        results[name] = r
        print(f"  R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, SFD: {r['sfd']:.4f}, "
              f"|θ|: {avg_norm:.3f}, len: {avg_len:.0f}, time: {elapsed:.0f}s")

    # Summary
    print("\n" + "=" * 90)
    print(f"{'Config':<30} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD':<8}")
    print("-" * 90)
    for name, r in results.items():
        print(f"{name:<30} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} "
              f"{r['meteor']:<8.4f} {r['sfd']:<8.4f}")


if __name__ == '__main__':
    main()