"""Main experiment runner for CVH development. Runs all methods (Base, Prompt-All-K, BM25-Top1, Unconditional, CVH) on LongLaMP validation set and reports metrics. """ import sys import os import json import time import argparse import yaml import torch # Add project root to path sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from data.longlamp import load_longlamp, select_k_profile_items from data.templates import build_query_prompt from models.qwen_wrapper import QwenWrapper from models.cvh import CVHHead, UnconditionalHead from adapt.cache_hidden import cache_support_hidden_states from adapt.fit_theta import fit_theta from baselines.prompt_all_k import generate_prompt_all_k from baselines.bm25_top1 import generate_bm25_top1 from eval.metrics import evaluate_all, print_results_table, compute_recovery def run_method_base(wrapper, examples, cfg): """Run Base (no personalization) method.""" predictions = [] for i, ex in enumerate(examples): prompt = build_query_prompt(ex['query_input'], ex['task']) pred = wrapper.generate_base( prompt, max_new_tokens=cfg['max_new_tokens'], temperature=cfg.get('temperature', 0.7), top_p=cfg.get('top_p', 0.9), ) predictions.append(pred) if (i + 1) % 20 == 0: print(f" Base: {i+1}/{len(examples)}") return predictions def run_method_prompt_all_k(wrapper, examples, support_sets, cfg): """Run Prompt-All-K baseline.""" predictions = [] for i, (ex, support) in enumerate(zip(examples, support_sets)): pred = generate_prompt_all_k( wrapper, ex['query_input'], support, ex['task'], max_new_tokens=cfg['max_new_tokens'], temperature=cfg.get('temperature', 0.7), top_p=cfg.get('top_p', 0.9), ) predictions.append(pred) if (i + 1) % 20 == 0: print(f" Prompt-All-K: {i+1}/{len(examples)}") return predictions def run_method_bm25_top1(wrapper, examples, support_sets, cfg): """Run BM25-Top1 baseline.""" predictions = [] for i, (ex, support) in enumerate(zip(examples, support_sets)): pred = generate_bm25_top1( wrapper, ex['query_input'], support, ex['task'], max_new_tokens=cfg['max_new_tokens'], temperature=cfg.get('temperature', 0.7), top_p=cfg.get('top_p', 0.9), ) predictions.append(pred) if (i + 1) % 20 == 0: print(f" BM25-Top1: {i+1}/{len(examples)}") return predictions def run_method_vector_head(wrapper, examples, support_sets, cfg, head_type='cvh'): """Run CVH or Unconditional head method.""" device = cfg['device'] d = cfg['d'] alpha = cfg['alpha'] basis_seed = cfg['basis_seed'] H = wrapper.hidden_size # Create head if head_type == 'cvh': head = CVHHead(H, d=d, alpha=alpha, basis_seed=basis_seed).to(device) else: head = UnconditionalHead(H, d=d, alpha=alpha, basis_seed=basis_seed).to(device) lm_head_bias = None if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None: lm_head_bias = wrapper.model.lm_head.bias.data predictions = [] adapt_times = [] for i, (ex, support) in enumerate(zip(examples, support_sets)): # Step 1: Cache hidden states from support set t0 = time.time() cached_h = cache_support_hidden_states(wrapper, support, ex['task']) if not cached_h: # Fallback to base generation if caching fails prompt = build_query_prompt(ex['query_input'], ex['task']) pred = wrapper.generate_base(prompt, max_new_tokens=cfg['max_new_tokens']) predictions.append(pred) adapt_times.append(0.0) continue # Step 2: Fit theta_u theta = fit_theta( cached_h=cached_h, lm_head_weight=wrapper.lm_head_weight, lm_head_bias=lm_head_bias, head_module=head, d=d, lr=cfg['lr'], steps=cfg['adapt_steps'], beta=cfg['beta'], lam=cfg['lam'], 
def run_method_vector_head(wrapper, examples, support_sets, cfg, head_type='cvh'):
    """Run CVH or Unconditional head method."""
    device = cfg['device']
    d = cfg['d']
    alpha = cfg['alpha']
    basis_seed = cfg['basis_seed']
    H = wrapper.hidden_size

    # Create head
    if head_type == 'cvh':
        head = CVHHead(H, d=d, alpha=alpha, basis_seed=basis_seed).to(device)
    else:
        head = UnconditionalHead(H, d=d, alpha=alpha, basis_seed=basis_seed).to(device)

    lm_head_bias = None
    if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
        lm_head_bias = wrapper.model.lm_head.bias.data

    predictions = []
    adapt_times = []
    for i, (ex, support) in enumerate(zip(examples, support_sets)):
        # Step 1: Cache hidden states from support set
        t0 = time.time()
        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
        if not cached_h:
            # Fallback to base generation if caching fails
            prompt = build_query_prompt(ex['query_input'], ex['task'])
            pred = wrapper.generate_base(prompt, max_new_tokens=cfg['max_new_tokens'])
            predictions.append(pred)
            adapt_times.append(0.0)  # fallback cases count as zero adapt time in the average
            continue

        # Step 2: Fit theta_u
        theta = fit_theta(
            cached_h=cached_h,
            lm_head_weight=wrapper.lm_head_weight,
            lm_head_bias=lm_head_bias,
            head_module=head,
            d=d,
            lr=cfg['lr'],
            steps=cfg['adapt_steps'],
            beta=cfg['beta'],
            lam=cfg['lam'],
            max_grad_norm=cfg['max_grad_norm'],
            device=device,
            verbose=False,
        )
        adapt_time = time.time() - t0
        adapt_times.append(adapt_time)

        # Step 3: Generate with personalized head (blended)
        prompt = build_query_prompt(ex['query_input'], ex['task'])
        blend_gamma = cfg.get('blend_gamma', 0.5)
        pred = wrapper.generate_with_head_blended(
            prompt,
            theta,
            head.forward_fn,
            blend_gamma=blend_gamma,
            max_new_tokens=cfg['max_new_tokens'],
            min_new_tokens=cfg.get('min_new_tokens', 128),
            # Note: head methods default to temperature 0.0 (greedy decoding),
            # unlike the sampling defaults used by the prompting baselines above.
            temperature=cfg.get('temperature', 0.0),
        )
        predictions.append(pred)

        # Cleanup GPU memory
        del cached_h, theta
        torch.cuda.empty_cache()

        if (i + 1) % 10 == 0:
            avg_adapt = sum(adapt_times) / len(adapt_times)
            print(f"  {head_type.upper()}: {i+1}/{len(examples)} "
                  f"(avg adapt: {avg_adapt:.2f}s)")

    avg_adapt_time = sum(adapt_times) / max(len(adapt_times), 1)
    print(f"  Average adaptation time: {avg_adapt_time:.2f}s")
    return predictions, avg_adapt_time
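
# Orientation for the "--- Efficiency ---" block in main(): the compression
# ratio reduces to simple byte arithmetic. The real computation is
# eval.metrics.compute_compression; the sketch below only illustrates the
# assumed definition (raw UTF-8 bytes of the support texts per byte of theta)
# and is never called. Worked example under that assumption: d=1024 stored in
# bf16 gives theta_bytes = 1024 * 2 = 2048; if the K support texts total
# ~200 KB, the ratio is ~100x.
def _compression_sketch(support_texts, theta_bytes):
    """Hypothetical stand-in for compute_compression; illustration only."""
    support_bytes = sum(len(t.encode('utf-8')) for t in support_texts)
    return support_bytes / max(theta_bytes, 1)
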
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--num_eval', type=int, default=None,
                        help='Override number of examples to evaluate')
    parser.add_argument('--methods', type=str, default='all',
                        help='Comma-separated methods: base,prompt_all_k,bm25_top1,uncond,cvh')
    parser.add_argument('--output_dir', type=str, default='outputs')
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()

    # Load config
    with open(args.config) as f:
        cfg = yaml.safe_load(f)
    if args.num_eval is not None:
        cfg['num_eval'] = args.num_eval

    torch.manual_seed(args.seed)

    print("=== CVH Dev Experiment ===")
    print(f"Config: {args.config}")
    print(f"Task: {cfg['task']}, Setting: {cfg['setting']}")
    print(f"Model: {cfg['model_name']}, Device: {cfg['device']}")
    print(f"d={cfg['d']}, K={cfg['K']}, alpha={cfg['alpha']}, steps={cfg['adapt_steps']}")

    # Load data
    print("\nLoading LongLaMP data...")
    examples = load_longlamp(cfg['dataset_config'], split='val')
    print(f"Loaded {len(examples)} validation examples")

    # Limit examples if specified
    num_eval = cfg.get('num_eval', -1)
    if num_eval > 0:
        examples = examples[:num_eval]
    print(f"Evaluating on {len(examples)} examples")

    # Prepare support sets
    K = cfg['K']
    support_sets = []
    for ex in examples:
        support = select_k_profile_items(ex['profile_items'], K, seed=args.seed)
        support_sets.append(support)

    # Gather references and support texts for metrics
    references = [ex['target_output'] for ex in examples]
    support_texts_per_example = [
        [s['support_output'] for s in support]
        for support in support_sets
    ]

    # Parse methods to run
    if args.methods == 'all':
        methods_to_run = ['base', 'prompt_all_k', 'bm25_top1', 'uncond', 'cvh']
    else:
        methods_to_run = args.methods.split(',')

    # Load model
    print(f"\nLoading model {cfg['model_name']}...")
    wrapper = QwenWrapper(cfg['model_name'], device=cfg['device'])
    print(f"Model loaded. Hidden size: {wrapper.hidden_size}")

    results = {}
    all_predictions = {}

    # Run each method
    if 'base' in methods_to_run:
        print("\n--- Running Base ---")
        preds = run_method_base(wrapper, examples, cfg)
        all_predictions['Base'] = preds
        results['Base'] = evaluate_all(preds, references, support_texts_per_example)
        print(f"  ROUGE-L: {results['Base']['rougeL']:.4f}, "
              f"METEOR: {results['Base']['meteor']:.4f}, "
              f"SFD: {results['Base']['sfd']:.4f}")

    if 'prompt_all_k' in methods_to_run:
        print(f"\n--- Running Prompt-All-K (K={K}) ---")
        preds = run_method_prompt_all_k(wrapper, examples, support_sets, cfg)
        all_predictions['Prompt-All-K'] = preds
        results['Prompt-All-K'] = evaluate_all(preds, references, support_texts_per_example)
        print(f"  ROUGE-L: {results['Prompt-All-K']['rougeL']:.4f}, "
              f"METEOR: {results['Prompt-All-K']['meteor']:.4f}, "
              f"SFD: {results['Prompt-All-K']['sfd']:.4f}")

    if 'bm25_top1' in methods_to_run:
        print("\n--- Running BM25-Top1 ---")
        preds = run_method_bm25_top1(wrapper, examples, support_sets, cfg)
        all_predictions['BM25-Top1'] = preds
        results['BM25-Top1'] = evaluate_all(preds, references, support_texts_per_example)
        print(f"  ROUGE-L: {results['BM25-Top1']['rougeL']:.4f}, "
              f"METEOR: {results['BM25-Top1']['meteor']:.4f}, "
              f"SFD: {results['BM25-Top1']['sfd']:.4f}")

    if 'uncond' in methods_to_run:
        print("\n--- Running Unconditional Head ---")
        preds, adapt_time = run_method_vector_head(
            wrapper, examples, support_sets, cfg, head_type='uncond')
        all_predictions['Uncond-Head'] = preds
        results['Uncond-Head'] = evaluate_all(preds, references, support_texts_per_example)
        results['Uncond-Head']['adapt_time'] = adapt_time
        print(f"  ROUGE-L: {results['Uncond-Head']['rougeL']:.4f}, "
              f"METEOR: {results['Uncond-Head']['meteor']:.4f}, "
              f"SFD: {results['Uncond-Head']['sfd']:.4f}")

    if 'cvh' in methods_to_run:
        print("\n--- Running CVH ---")
        preds, adapt_time = run_method_vector_head(
            wrapper, examples, support_sets, cfg, head_type='cvh')
        all_predictions['CVH'] = preds
        results['CVH'] = evaluate_all(preds, references, support_texts_per_example)
        results['CVH']['adapt_time'] = adapt_time
        print(f"  ROUGE-L: {results['CVH']['rougeL']:.4f}, "
              f"METEOR: {results['CVH']['meteor']:.4f}, "
              f"SFD: {results['CVH']['sfd']:.4f}")

    # Print comparison table
    print("\n" + "=" * 70)
    print("RESULTS SUMMARY")
    print("=" * 70)
    print_results_table(results)

    # Compute compression for vector methods
    print("\n--- Efficiency ---")
    theta_bytes = cfg['d'] * 2  # bf16 = 2 bytes per dim
    for ex_support in support_texts_per_example[:5]:
        comp = compute_compression(ex_support, theta_bytes)
        print(f"  Compression ratio (example): {comp:.0f}x")
    print(f"  Theta size: {theta_bytes} bytes")
    print("  Personalization prompt tokens at inference: 0")

    # Save results
    os.makedirs(args.output_dir, exist_ok=True)
    exp_name = f"{cfg['task']}_{cfg['setting']}_K{K}_d{cfg['d']}"
    output_path = os.path.join(args.output_dir, f"{exp_name}_results.json")

    # Convert results to serializable format
    save_data = {
        'config': cfg,
        'results': results,
        'num_examples': len(examples),
    }
    with open(output_path, 'w') as f:
        json.dump(save_data, f, indent=2)
    print(f"\nResults saved to {output_path}")

    # Save predictions
    pred_path = os.path.join(args.output_dir, f"{exp_name}_predictions.json")
    save_preds = {}
    for method, preds in all_predictions.items():
        save_preds[method] = [
            {
                'example_id': examples[i]['example_id'],
                'prediction': preds[i],
                'reference': references[i],
            }
            for i in range(len(preds))
        ]
    with open(pred_path, 'w') as f:
        json.dump(save_preds, f, indent=2)
    print(f"Predictions saved to {pred_path}")
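
# Example invocation (the script path and config filename are placeholders;
# the flags match the argparse setup in main()):
#
#   python run_dev_experiment.py --config configs/dev.yaml \
#       --methods base,cvh --num_eval 50 --output_dir outputs --seed 0
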
print(f"Predictions saved to {pred_path}") if __name__ == '__main__': main()