From ea4a8f837e81b5e5fab6086cb3014c711c5e58e9 Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Sun, 5 Apr 2026 10:31:36 -0500
Subject: Add PEFT baselines, ICL baselines, a profile-based baseline, and a unified pipeline

New baselines:
- baselines/peft_baseline.py: LoRA, Tiny LoRA, VeRA (per-user PEFT adaptation)
- baselines/dense_retrieval.py: Dense-retrieval ICL (sentence-transformers)
- baselines/profile_based.py: Generation conditioned on an LLM-generated user profile

New scripts:
- scripts/run_all_methods.py: Unified pipeline running all 9 methods, with a
  per-method output directory structure (method/per_user.json)
- scripts/run_peft_baselines.py: PEFT-only evaluation (legacy)
- scripts/run_significance.py: Significance tests (UPH+Base, per-user)
- scripts/run_uph_base_per_user.py: UPH+Base with full per-user data
- scripts/compute_bertscore.py: BERTScore computed from saved predictions
- scripts/significance_test.py: Standalone significance-test framework

Updated .gitignore to exclude the outputs/ directory.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 scripts/run_peft_baselines.py | 271 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 271 insertions(+)
 create mode 100644 scripts/run_peft_baselines.py

diff --git a/scripts/run_peft_baselines.py b/scripts/run_peft_baselines.py
new file mode 100644
index 0000000..c23256b
--- /dev/null
+++ b/scripts/run_peft_baselines.py
@@ -0,0 +1,271 @@
+"""Evaluate PEFT baselines (LoRA, Tiny LoRA, VeRA) with a fair decode policy.
+
+Saves complete per-user data: predictions, references, scores, metadata.
+
+Usage:
+    python scripts/run_peft_baselines.py --task review --setting user
+    python scripts/run_peft_baselines.py --task topic --setting user
+    python scripts/run_peft_baselines.py --task review --setting user --methods lora
+"""
+
+import sys
+import os
+import json
+import time
+import torch
+
+# Make the repo root importable regardless of where the script is invoked from.
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt
+from data.style_features import FEATURE_NAMES, compute_sfd, compute_feature_deltas
+from models.qwen_wrapper import QwenWrapper
+from baselines.peft_baseline import (
+    PEFTBaseline, get_lora_config, get_tiny_lora_config, get_vera_config,
+)
+from eval.metrics import compute_rouge, compute_meteor
+
+
+PEFT_CONFIGS = {
+    'lora': {
+        'config_fn': lambda: get_lora_config(rank=8),
+        'lr': 1e-4,
+        'steps': 30,
+        'desc': 'LoRA (rank=8, q+v proj)',
+    },
+    'tiny_lora': {
+        'config_fn': lambda: get_tiny_lora_config(rank=1),
+        'lr': 1e-4,
+        'steps': 30,
+        'desc': 'Tiny LoRA (rank=1, q+v proj)',
+    },
+    'vera': {
+        'config_fn': lambda: get_vera_config(rank=256),
+        'lr': 1e-3,
+        'steps': 30,
+        'desc': 'VeRA (rank=256, q+v proj)',
+    },
+}
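+
+# The three config helpers above live in baselines/peft_baseline.py, which is
+# not part of this patch. A minimal sketch, assuming they are thin wrappers
+# around the Hugging Face `peft` config classes (names and defaults here are
+# illustrative, not the repo's verified code):
+#
+#     from peft import LoraConfig, VeraConfig
+#
+#     def get_lora_config(rank=8):
+#         return LoraConfig(r=rank, lora_alpha=2 * rank,
+#                           target_modules=["q_proj", "v_proj"],
+#                           task_type="CAUSAL_LM")
+#
+#     def get_tiny_lora_config(rank=1):
+#         return get_lora_config(rank=rank)
+#
+#     def get_vera_config(rank=256):
+#         return VeraConfig(r=rank, target_modules=["q_proj", "v_proj"],
+#                           task_type="CAUSAL_LM")
+#
+# VeRA shares frozen random projections across layers and trains only small
+# scaling vectors, which is why its rank can be large (256) while the trainable
+# parameter count stays tiny.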
+
+
+def compute_per_user_metrics(pred, ref, support_texts):
+    """Compute all metrics for a single prediction."""
+    # Guard against empty generations so the style-feature code gets real text.
+    text = pred if pred.strip() else "empty"
+    r = compute_rouge([pred], [ref])
+    m = compute_meteor([pred], [ref])
+    sfd_all = compute_sfd(text, support_texts, exclude_length=False)
+    sfd_nolen = compute_sfd(text, support_texts, exclude_length=True)
+    deltas = compute_feature_deltas(text, support_texts)
+
+    return {
+        'rouge1': r['rouge1'],
+        'rougeL': r['rougeL'],
+        'meteor': m,
+        'sfd_all': sfd_all,
+        'sfd_nolen': sfd_nolen,
+        'length': len(pred.split()),
+        'feature_deltas': {k: v['delta'] for k, v in deltas.items()},
+    }
+
+
+def run_peft_method(wrapper, examples, support_sets, references, support_texts,
+                    method_name, config_entry, N):
+    """Run one PEFT baseline and return per-user results plus aggregates."""
+    cfg = config_entry['config_fn']()
+    lr = config_entry['lr']
+    steps = config_entry['steps']
+
+    print(f"\n--- {config_entry['desc']} ---")
+
+    baseline = PEFTBaseline(wrapper, cfg)
+    print(f"  Trainable params: {baseline.n_params:,} ({baseline.n_bytes:,} bytes)")
+
+    per_user = []
+
+    for i, (ex, support) in enumerate(zip(examples, support_sets)):
+        t0 = time.time()
+
+        pred = baseline.adapt_and_generate(
+            support_items=support,
+            query_input=ex['query_input'],
+            task=ex['task'],
+            lr=lr,
+            steps=steps,
+            max_new_tokens=512,
+            min_new_tokens=128,
+            verbose=False,
+        )
+        adapt_time = time.time() - t0
+
+        # Per-user metrics
+        metrics = compute_per_user_metrics(pred, references[i], support_texts[i])
+
+        per_user.append({
+            'example_id': ex['example_id'],
+            'user_id': ex['user_id'],
+            'prediction': pred,
+            'reference': references[i],
+            'support_texts': support_texts[i],
+            'K': len(support),
+            'adapt_time': adapt_time,
+            'metrics': metrics,
+        })
+
+        if (i + 1) % 20 == 0:
+            avg_t = sum(u['adapt_time'] for u in per_user) / len(per_user)
+            avg_rl = sum(u['metrics']['rougeL'] for u in per_user) / len(per_user)
+            print(f"  {i+1}/{N} (avg time: {avg_t:.1f}s, avg R-L: {avg_rl:.4f})")
+
+    # Aggregate metrics, averaged over the users actually evaluated (this can
+    # be fewer than the requested N if the split is small).
+    n_users = len(per_user)
+    agg = {
+        'rouge1': sum(u['metrics']['rouge1'] for u in per_user) / n_users,
+        'rougeL': sum(u['metrics']['rougeL'] for u in per_user) / n_users,
+        'meteor': sum(u['metrics']['meteor'] for u in per_user) / n_users,
+        'sfd_all': sum(u['metrics']['sfd_all'] for u in per_user) / n_users,
+        'sfd_nolen': sum(u['metrics']['sfd_nolen'] for u in per_user) / n_users,
+        'avg_len': sum(u['metrics']['length'] for u in per_user) / n_users,
+        'adapt_time': sum(u['adapt_time'] for u in per_user) / n_users,
+        'n_params': baseline.n_params,
+        'n_bytes': baseline.n_bytes,
+    }
+
+    # Free the adapter weights before the next method runs.
+    baseline.cleanup()
+
+    print(f"  R-L: {agg['rougeL']:.4f}, METEOR: {agg['meteor']:.4f}, "
+          f"SFD_-len: {agg['sfd_nolen']:.4f}, len: {agg['avg_len']:.0f}, "
+          f"adapt: {agg['adapt_time']:.1f}s")
+
+    return per_user, agg
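+
+# PEFTBaseline is also defined in baselines/peft_baseline.py and not shown in
+# this patch. A minimal sketch of adapt_and_generate, assuming it fine-tunes a
+# fresh adapter on the K support items and then greedy-decodes the query
+# (base_model, support_ids, and query_ids are hypothetical names):
+#
+#     from peft import get_peft_model
+#
+#     peft_model = get_peft_model(base_model, cfg)
+#     opt = torch.optim.AdamW(
+#         [p for p in peft_model.parameters() if p.requires_grad], lr=lr)
+#     for _ in range(steps):
+#         # Standard LM loss on the user's support items.
+#         loss = peft_model(input_ids=support_ids, labels=support_ids).loss
+#         loss.backward()
+#         opt.step()
+#         opt.zero_grad()
+#     gen_ids = peft_model.generate(query_ids, do_sample=False,
+#                                   min_new_tokens=128, max_new_tokens=512)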
+
+
+def main():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--num_eval', type=int, default=200)
+    parser.add_argument('--task', type=str, default='review', choices=['review', 'topic'])
+    parser.add_argument('--setting', type=str, default='user', choices=['user', 'temporal'])
+    parser.add_argument('--methods', type=str, default='all',
+                        help='Comma-separated methods: lora,tiny_lora,vera or "all"')
+    parser.add_argument('--output_dir', type=str, default='outputs/peft_baselines')
+    parser.add_argument('--device', type=str, default='cuda:1')
+    parser.add_argument('--steps', type=int, default=None, help='Override adaptation steps')
+    args = parser.parse_args()
+
+    N = args.num_eval
+    task = args.task
+    setting = args.setting
+
+    config_map = {
+        ('review', 'user'): 'product_review_user',
+        ('review', 'temporal'): 'product_review_temporal',
+        ('topic', 'user'): 'topic_writing_user',
+        ('topic', 'temporal'): 'topic_writing_temporal',
+    }
+    config_name = config_map[(task, setting)]
+
+    if args.methods == 'all':
+        methods = list(PEFT_CONFIGS.keys())
+    else:
+        methods = [m.strip() for m in args.methods.split(',')]
+        for m in methods:
+            if m not in PEFT_CONFIGS:
+                print(f"Unknown method: {m}. Available: {list(PEFT_CONFIGS.keys())}")
+                return
+
+    print(f"=== PEFT Baselines: {task}_{setting}, N={N} ===")
+    print(f"Methods: {methods}")
+    print("Decode policy: greedy, min_new_tokens=128, max_new_tokens=512")
+
+    print("\nLoading data...")
+    examples = load_longlamp(config_name, split='val')[:N]
+    K = 4
+    support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
+    references = [ex['target_output'] for ex in examples]
+    support_texts = [[s['support_output'] for s in ss] for ss in support_sets]
+
+    avg_ref_len = sum(len(r.split()) for r in references) / len(references)
+    print(f"Examples: {len(examples)}, Avg reference len: {avg_ref_len:.0f}")
+
+    print(f"\nLoading model on {args.device}...")
+    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=args.device)
+
+    all_agg = {}
+    all_per_user = {}
+
+    for method_name in methods:
+        config_entry = PEFT_CONFIGS[method_name].copy()
+        if args.steps is not None:
+            config_entry['steps'] = args.steps
+
+        per_user, agg = run_peft_method(
+            wrapper, examples, support_sets, references, support_texts,
+            method_name, config_entry, N,
+        )
+        all_agg[method_name] = agg
+        all_per_user[method_name] = per_user
+
+    # Print summary
+    print("\n" + "=" * 100)
+    print("PEFT BASELINES SUMMARY")
+    print("=" * 100)
+    header = (f"{'Method':<25} {'R-L':<8} {'METEOR':<8} {'SFD_-len':<9} "
+              f"{'Len':<6} {'Params':<12} {'Bytes':<10} {'Time/user':<10}")
+    print(header)
+    print("-" * 100)
+
+    uph_path = f"outputs/fair_audit/{task}_{setting}_K4_d64_N{N}_fair_results.json"
+    if os.path.exists(uph_path):
+        with open(uph_path) as f:
+            uph_data = json.load(f)
+        if 'Uncond-Head' in uph_data.get('results', {}):
+            uph_r = uph_data['results']['Uncond-Head']
+            print(f"{'UPH (reference)':<25} {uph_r['rougeL']:<8.4f} {uph_r['meteor']:<8.4f} "
+                  f"{uph_r['sfd_nolen']:<9.4f} {uph_r['avg_len']:<6.0f} "
+                  f"{'64':<12} {'128':<10} {'~7s':<10}")
+        if 'Base' in uph_data.get('results', {}):
+            base_r = uph_data['results']['Base']
+            print(f"{'Base (reference)':<25} {base_r['rougeL']:<8.4f} {base_r['meteor']:<8.4f} "
+                  f"{base_r['sfd_nolen']:<9.4f} {base_r['avg_len']:<6.0f} "
+                  f"{'0':<12} {'0':<10} {'0s':<10}")
+        print("-" * 100)
+
+    for name, agg in all_agg.items():
+        print(f"{PEFT_CONFIGS[name]['desc']:<25} {agg['rougeL']:<8.4f} {agg['meteor']:<8.4f} "
+              f"{agg['sfd_nolen']:<9.4f} {agg['avg_len']:<6.0f} "
+              f"{agg['n_params']:<12,} {agg['n_bytes']:<10,} "
+              f"{agg['adapt_time']:<10.1f}s")
+
+    # Save complete results with per-user data
+    os.makedirs(args.output_dir, exist_ok=True)
+    exp_name = f"{task}_{setting}_K4_N{N}_peft"
+
+    # Aggregate results (lightweight)
+    agg_path = os.path.join(args.output_dir, f"{exp_name}_results.json")
+    with open(agg_path, 'w') as f:
+        json.dump({
+            'aggregate': all_agg,
+            'num_examples': N,
+            'task': task,
+            'setting': setting,
+            'K': K,
+            'decode_policy': 'greedy, min_new_tokens=128, max_new_tokens=512',
+            'methods': {k: PEFT_CONFIGS[k]['desc'] for k in methods},
+        }, f, indent=2, default=str)
+
+    # Per-user data (complete)
+    per_user_path = os.path.join(args.output_dir, f"{exp_name}_per_user.json")
+    with open(per_user_path, 'w') as f:
+        json.dump({
+            'per_user': all_per_user,
+            'num_examples': N,
+            'task': task,
+            'setting': setting,
+            'K': K,
+        }, f, indent=2, default=str)
+
+    print(f"\nAggregate results saved to {agg_path}")
+    print(f"Per-user data saved to {per_user_path}")
+
+
+if __name__ == '__main__':
+    main()
--
cgit v1.2.3