summaryrefslogtreecommitdiff
path: root/scripts/run_peft_baselines.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/run_peft_baselines.py')
-rw-r--r--scripts/run_peft_baselines.py271
1 files changed, 271 insertions, 0 deletions
diff --git a/scripts/run_peft_baselines.py b/scripts/run_peft_baselines.py
new file mode 100644
index 0000000..c23256b
--- /dev/null
+++ b/scripts/run_peft_baselines.py
@@ -0,0 +1,271 @@
+"""Evaluate PEFT baselines (LoRA, Tiny LoRA, VeRA) with fair decode policy.
+
+Saves complete per-user data: predictions, references, scores, metadata.
+
+Usage:
+ python scripts/run_peft_baselines.py --task review --setting user
+ python scripts/run_peft_baselines.py --task topic --setting user
+ python scripts/run_peft_baselines.py --task review --setting user --methods lora
+"""
+
+import sys
+import os
+import json
+import time
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt
+from data.style_features import FEATURE_NAMES, compute_sfd, compute_feature_deltas
+from models.qwen_wrapper import QwenWrapper
+from baselines.peft_baseline import (
+ PEFTBaseline, get_lora_config, get_tiny_lora_config, get_vera_config,
+)
+from eval.metrics import compute_rouge, compute_meteor
+
+
# Registry of the PEFT baselines this script can evaluate, keyed by the name
# accepted on the command line via --methods. Each entry carries:
#   config_fn - zero-argument factory building the PEFT config on demand
#   lr        - learning rate used during per-user adaptation
#   steps     - number of adaptation steps (overridable with --steps)
#   desc      - human-readable label used in console output and saved results
PEFT_CONFIGS = dict(
    lora=dict(
        config_fn=lambda: get_lora_config(rank=8),
        lr=1e-4,
        steps=30,
        desc='LoRA (rank=8, q+v proj)',
    ),
    tiny_lora=dict(
        config_fn=lambda: get_tiny_lora_config(rank=1),
        lr=1e-4,
        steps=30,
        desc='Tiny LoRA (rank=1, q+v proj)',
    ),
    vera=dict(
        config_fn=lambda: get_vera_config(rank=256),
        lr=1e-3,
        steps=30,
        desc='VeRA (rank=256, q+v proj)',
    ),
)
+
+
def compute_per_user_metrics(pred, ref, support_texts):
    """Score one prediction against its reference and the user's support texts.

    Args:
        pred: generated text for a single example.
        ref: reference text for the same example.
        support_texts: the user's support outputs, passed to the style-feature
            helpers as the comparison set.

    Returns:
        Dict with 'rouge1', 'rougeL', 'meteor', 'sfd_all' (compute_sfd with
        exclude_length=False), 'sfd_nolen' (exclude_length=True), 'length'
        (whitespace-token count of ``pred``) and 'feature_deltas' (feature
        name -> the 'delta' field reported by compute_feature_deltas).
    """
    # The style helpers receive the placeholder "empty" instead of a blank
    # prediction; ROUGE, METEOR and the length still use the raw prediction.
    style_input = pred if pred.strip() else "empty"

    rouge = compute_rouge([pred], [ref])
    meteor = compute_meteor([pred], [ref])
    sfd_with_length = compute_sfd(style_input, support_texts, exclude_length=False)
    sfd_without_length = compute_sfd(style_input, support_texts, exclude_length=True)

    feature_deltas = {}
    for feature, info in compute_feature_deltas(style_input, support_texts).items():
        feature_deltas[feature] = info['delta']

    word_count = len(pred.split())

    return {
        'rouge1': rouge['rouge1'],
        'rougeL': rouge['rougeL'],
        'meteor': meteor,
        'sfd_all': sfd_with_length,
        'sfd_nolen': sfd_without_length,
        'length': word_count,
        'feature_deltas': feature_deltas,
    }
+
+
def run_peft_method(wrapper, examples, support_sets, references, support_texts,
                    method_name, config_entry, N):
    """Adapt and evaluate one PEFT baseline on every example.

    For each (example, support set) pair the baseline is adapted on the
    support items, a prediction is generated for the query, and per-user
    metrics are computed with ``compute_per_user_metrics``.

    Args:
        wrapper: model wrapper handed to ``PEFTBaseline``.
        examples: sequence of example dicts; the keys 'query_input', 'task',
            'example_id' and 'user_id' are read.
        support_sets: per-example support items, aligned with ``examples``.
        references: per-example reference texts, aligned with ``examples``.
        support_texts: per-example lists of support output texts, aligned
            with ``examples``.
        method_name: not used here; kept so existing callers keep working.
        config_entry: dict providing 'config_fn', 'lr', 'steps' and 'desc'.
        N: number of examples the caller requested. Kept for interface
            compatibility only; averages and progress output are based on
            the number of examples actually evaluated, which can be smaller.

    Returns:
        Tuple ``(per_user, agg)``: ``per_user`` is a list with one record per
        evaluated example (ids, prediction, reference, support texts, K,
        adaptation time, metrics); ``agg`` holds the means over those records
        plus the baseline's ``n_params`` and ``n_bytes``.
    """
    cfg = config_entry['config_fn']()
    lr = config_entry['lr']
    steps = config_entry['steps']

    print(f"\n--- {config_entry['desc']} ---")

    def _mean(values):
        # Average over the users actually evaluated; 0.0 when there are none
        # so that an empty run reports zeros instead of dividing by zero.
        values = list(values)
        return sum(values) / len(values) if values else 0.0

    # zip() below stops at the shorter input, so this is the true loop length.
    total = min(len(examples), len(support_sets))

    baseline = PEFTBaseline(wrapper, cfg)
    try:
        print(f"  Trainable params: {baseline.n_params:,} ({baseline.n_bytes:,} bytes)")

        per_user = []

        for i, (ex, support) in enumerate(zip(examples, support_sets)):
            t0 = time.time()

            pred = baseline.adapt_and_generate(
                support_items=support,
                query_input=ex['query_input'],
                task=ex['task'],
                lr=lr,
                steps=steps,
                max_new_tokens=512,
                min_new_tokens=128,
                verbose=False,
            )
            # Timed before metric computation, so it covers adaptation and
            # generation only.
            adapt_time = time.time() - t0

            # Per-user metrics
            metrics = compute_per_user_metrics(pred, references[i], support_texts[i])

            per_user.append({
                'example_id': ex['example_id'],
                'user_id': ex['user_id'],
                'prediction': pred,
                'reference': references[i],
                'support_texts': support_texts[i],
                'K': len(support),
                'adapt_time': adapt_time,
                'metrics': metrics,
            })

            if (i + 1) % 20 == 0:
                avg_t = _mean(u['adapt_time'] for u in per_user)
                avg_rl = _mean(u['metrics']['rougeL'] for u in per_user)
                print(f"  {i+1}/{total} (avg time: {avg_t:.1f}s, avg R-L: {avg_rl:.4f})")

        # Aggregate metrics. These are read before cleanup() runs, matching
        # the original ordering for n_params / n_bytes.
        agg = {
            'rouge1': _mean(u['metrics']['rouge1'] for u in per_user),
            'rougeL': _mean(u['metrics']['rougeL'] for u in per_user),
            'meteor': _mean(u['metrics']['meteor'] for u in per_user),
            'sfd_all': _mean(u['metrics']['sfd_all'] for u in per_user),
            'sfd_nolen': _mean(u['metrics']['sfd_nolen'] for u in per_user),
            'avg_len': _mean(u['metrics']['length'] for u in per_user),
            'adapt_time': _mean(u['adapt_time'] for u in per_user),
            'n_params': baseline.n_params,
            'n_bytes': baseline.n_bytes,
        }
    finally:
        # Always run the baseline's cleanup, even when adaptation or metric
        # computation raised, because the same wrapper is reused afterwards.
        baseline.cleanup()

    print(f"  R-L: {agg['rougeL']:.4f}, METEOR: {agg['meteor']:.4f}, "
          f"SFD_-len: {agg['sfd_nolen']:.4f}, len: {agg['avg_len']:.0f}, "
          f"adapt: {agg['adapt_time']:.1f}s")

    return per_user, agg
+
+
def main():
    """Run the selected PEFT baselines from the command line and save results.

    Steps: parse and validate the CLI options, load up to ``--num_eval``
    validation examples with K=4 support items each, load the model, run each
    requested method through ``run_peft_method``, print a summary table
    (including UPH/Base reference rows when a fair-audit results file
    exists), and write two JSON files (aggregate and per-user) into
    ``--output_dir``.

    Invalid options terminate through ``parser.error`` (non-zero exit
    status); an empty dataset terminates with ``SystemExit``.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_eval', type=int, default=200)
    parser.add_argument('--task', type=str, default='review', choices=['review', 'topic'])
    parser.add_argument('--setting', type=str, default='user', choices=['user', 'temporal'])
    parser.add_argument('--methods', type=str, default='all',
                        help='Comma-separated methods: lora,tiny_lora,vera or "all"')
    parser.add_argument('--output_dir', type=str, default='outputs/peft_baselines')
    parser.add_argument('--device', type=str, default='cuda:1')
    parser.add_argument('--steps', type=int, default=None, help='Override adaptation steps')
    args = parser.parse_args()

    # A non-positive value would make the `[:N]` slice below silently select
    # nothing (0) or everything except the tail (negative).
    if args.num_eval <= 0:
        parser.error("--num_eval must be a positive integer")

    N = args.num_eval
    task = args.task
    setting = args.setting

    config_map = {
        ('review', 'user'): 'product_review_user',
        ('review', 'temporal'): 'product_review_temporal',
        ('topic', 'user'): 'topic_writing_user',
        ('topic', 'temporal'): 'topic_writing_temporal',
    }
    config_name = config_map[(task, setting)]

    if args.methods == 'all':
        methods = list(PEFT_CONFIGS.keys())
    else:
        # Results are keyed by method name, so a repeated name would only
        # re-run the method and overwrite its entry; dict.fromkeys drops
        # duplicates while keeping the requested order.
        methods = list(dict.fromkeys(
            m.strip() for m in args.methods.split(',') if m.strip()
        ))
        for m in methods:
            if m not in PEFT_CONFIGS:
                # parser.error exits with a non-zero status so calling
                # scripts can detect the failure.
                parser.error(f"Unknown method: {m}. Available: {list(PEFT_CONFIGS.keys())}")
        if not methods:
            parser.error("--methods must name at least one method")

    print(f"=== PEFT Baselines: {task}_{setting}, N={N} ===")
    print(f"Methods: {methods}")
    print("Decode policy: greedy, min_new_tokens=128, max_new_tokens=512")

    print("\nLoading data...")
    examples = load_longlamp(config_name, split='val')[:N]
    if not examples:
        raise SystemExit(f"No examples loaded for {config_name} (split='val')")
    # The split may hold fewer than N examples; everything that averages or
    # reports a count must use the number actually loaded.
    n_eval = len(examples)
    if n_eval < N:
        print(f"Warning: requested {N} examples but only {n_eval} are available")

    K = 4
    support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
    references = [ex['target_output'] for ex in examples]
    support_texts = [[s['support_output'] for s in ss] for ss in support_sets]

    avg_ref_len = sum(len(r.split()) for r in references) / len(references)
    print(f"Examples: {len(examples)}, Avg reference len: {avg_ref_len:.0f}")

    print(f"\nLoading model on {args.device}...")
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=args.device)

    all_agg = {}
    all_per_user = {}

    for method_name in methods:
        # Shallow copy so a --steps override never mutates the shared registry.
        config_entry = PEFT_CONFIGS[method_name].copy()
        if args.steps is not None:
            config_entry['steps'] = args.steps

        per_user, agg = run_peft_method(
            wrapper, examples, support_sets, references, support_texts,
            method_name, config_entry, n_eval,
        )
        all_agg[method_name] = agg
        all_per_user[method_name] = per_user

    # Print summary
    # Size the first column to the longest label so the columns stay aligned
    # (some method descriptions are longer than 25 characters).
    name_w = max(25, *(len(PEFT_CONFIGS[m]['desc']) for m in methods))

    print("\n" + "=" * 100)
    print("PEFT BASELINES SUMMARY")
    print("=" * 100)
    header = (f"{'Method':<{name_w}} {'R-L':<8} {'METEOR':<8} {'SFD_-len':<9} "
              f"{'Len':<6} {'Params':<12} {'Bytes':<10} {'Time/user':<10}")
    print(header)
    print("-" * 100)

    # Optional reference rows from a fair-audit results file; the path uses
    # the requested N, exactly as before.
    uph_path = f"outputs/fair_audit/{task}_{setting}_K{K}_d64_N{N}_fair_results.json"
    if os.path.exists(uph_path):
        with open(uph_path, encoding='utf-8') as f:
            uph_data = json.load(f)
        reference_results = uph_data.get('results', {})
        if 'Uncond-Head' in reference_results:
            uph_r = reference_results['Uncond-Head']
            print(f"{'UPH (reference)':<{name_w}} {uph_r['rougeL']:<8.4f} {uph_r['meteor']:<8.4f} "
                  f"{uph_r['sfd_nolen']:<9.4f} {uph_r['avg_len']:<6.0f} "
                  f"{'64':<12} {'128':<10} {'~7s':<10}")
        if 'Base' in reference_results:
            base_r = reference_results['Base']
            print(f"{'Base (reference)':<{name_w}} {base_r['rougeL']:<8.4f} {base_r['meteor']:<8.4f} "
                  f"{base_r['sfd_nolen']:<9.4f} {base_r['avg_len']:<6.0f} "
                  f"{'0':<12} {'0':<10} {'0s':<10}")
        print("-" * 100)

    for name, agg in all_agg.items():
        # Build the time cell first so the "s" unit sits next to the number
        # instead of after the column padding.
        time_cell = f"{agg['adapt_time']:.1f}s"
        print(f"{PEFT_CONFIGS[name]['desc']:<{name_w}} {agg['rougeL']:<8.4f} {agg['meteor']:<8.4f} "
              f"{agg['sfd_nolen']:<9.4f} {agg['avg_len']:<6.0f} "
              f"{agg['n_params']:<12,} {agg['n_bytes']:<10,} "
              f"{time_cell:<10}")

    # Save complete results with per-user data
    os.makedirs(args.output_dir, exist_ok=True)
    # File names keep the requested N so they match earlier runs and the
    # fair-audit naming; the true count is stored inside the files.
    exp_name = f"{task}_{setting}_K{K}_N{N}_peft"

    # Aggregate results (lightweight)
    agg_path = os.path.join(args.output_dir, f"{exp_name}_results.json")
    with open(agg_path, 'w', encoding='utf-8') as f:
        json.dump({
            'aggregate': all_agg,
            'num_examples': n_eval,
            'task': task,
            'setting': setting,
            'K': K,
            'decode_policy': 'greedy, min_new_tokens=128, max_new_tokens=512',
            'methods': {k: PEFT_CONFIGS[k]['desc'] for k in methods},
        }, f, indent=2, default=str)

    # Per-user data (complete)
    per_user_path = os.path.join(args.output_dir, f"{exp_name}_per_user.json")
    with open(per_user_path, 'w', encoding='utf-8') as f:
        json.dump({
            'per_user': all_per_user,
            'num_examples': n_eval,
            'task': task,
            'setting': setting,
            'K': K,
        }, f, indent=2, default=str)

    print(f"\nAggregate results saved to {agg_path}")
    print(f"Per-user data saved to {per_user_path}")


if __name__ == '__main__':
    main()