author     YurenHao0426 <Blackhao0426@gmail.com>  2026-04-03 15:12:34 -0500
committer  YurenHao0426 <Blackhao0426@gmail.com>  2026-04-03 15:12:34 -0500
commit     8fe28101366dd32562b8c5534d7fe359b252bdf3 (patch)
tree       c92a92184fb2f46f265ab84c1f754c3d5d6597bc /scripts/run_dev.py
Initial commit: UPH project codebase and experiment results
Includes model code, evaluation scripts, configs, analysis outputs, and
experiment results for the User Prior Head personalization method.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'scripts/run_dev.py')
-rw-r--r--  scripts/run_dev.py  |  310
1 file changed, 310 insertions(+), 0 deletions(-)
diff --git a/scripts/run_dev.py b/scripts/run_dev.py
new file mode 100644
index 0000000..d8ccee5
--- /dev/null
+++ b/scripts/run_dev.py
@@ -0,0 +1,310 @@
+"""Main experiment runner for CVH development.
+
+Runs all methods (Base, Prompt-All-K, BM25-Top1, Unconditional, CVH) on the
+LongLaMP validation set and reports metrics.
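+
+Example invocation (the config path is illustrative):
+
+    python scripts/run_dev.py --config configs/dev.yaml \
+        --methods base,cvh --num_eval 50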
+"""
+
+import sys
+import os
+import json
+import time
+import argparse
+import yaml
+import torch
+
+# Add project root to path
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import CVHHead, UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from baselines.prompt_all_k import generate_prompt_all_k
+from baselines.bm25_top1 import generate_bm25_top1
+from eval.metrics import evaluate_all, print_results_table, compute_compression
+
+
+def run_method_base(wrapper, examples, cfg):
+ """Run Base (no personalization) method."""
+ predictions = []
+ for i, ex in enumerate(examples):
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_base(
+ prompt,
+ max_new_tokens=cfg['max_new_tokens'],
+ temperature=cfg.get('temperature', 0.7),
+ top_p=cfg.get('top_p', 0.9),
+ )
+ predictions.append(pred)
+ if (i + 1) % 20 == 0:
+ print(f" Base: {i+1}/{len(examples)}")
+ return predictions
+
+
+def run_method_prompt_all_k(wrapper, examples, support_sets, cfg):
+ """Run Prompt-All-K baseline."""
+ predictions = []
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ pred = generate_prompt_all_k(
+ wrapper, ex['query_input'], support, ex['task'],
+ max_new_tokens=cfg['max_new_tokens'],
+ temperature=cfg.get('temperature', 0.7),
+ top_p=cfg.get('top_p', 0.9),
+ )
+ predictions.append(pred)
+ if (i + 1) % 20 == 0:
+ print(f" Prompt-All-K: {i+1}/{len(examples)}")
+ return predictions
+
+
+def run_method_bm25_top1(wrapper, examples, support_sets, cfg):
+ """Run BM25-Top1 baseline."""
+ predictions = []
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ pred = generate_bm25_top1(
+ wrapper, ex['query_input'], support, ex['task'],
+ max_new_tokens=cfg['max_new_tokens'],
+ temperature=cfg.get('temperature', 0.7),
+ top_p=cfg.get('top_p', 0.9),
+ )
+ predictions.append(pred)
+ if (i + 1) % 20 == 0:
+ print(f" BM25-Top1: {i+1}/{len(examples)}")
+ return predictions
+
+
+def run_method_vector_head(wrapper, examples, support_sets, cfg, head_type='cvh'):
+ """Run CVH or Unconditional head method."""
+ device = cfg['device']
+ d = cfg['d']
+ alpha = cfg['alpha']
+ basis_seed = cfg['basis_seed']
+ H = wrapper.hidden_size
+
+ # Create head
+ if head_type == 'cvh':
+ head = CVHHead(H, d=d, alpha=alpha, basis_seed=basis_seed).to(device)
+ else:
+ head = UnconditionalHead(H, d=d, alpha=alpha, basis_seed=basis_seed).to(device)
+
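+    # Qwen lm_heads typically have no bias, but pass one through when present
+    # (presumably so fit_theta can reproduce the base model's logits exactly).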
+ lm_head_bias = None
+ if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+ lm_head_bias = wrapper.model.lm_head.bias.data
+
+ predictions = []
+ adapt_times = []
+
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ # Step 1: Cache hidden states from support set
+ t0 = time.time()
+ cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
+
+ if not cached_h:
+ # Fallback to base generation if caching fails
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_base(prompt, max_new_tokens=cfg['max_new_tokens'])
+ predictions.append(pred)
+ adapt_times.append(0.0)
+ continue
+
+ # Step 2: Fit theta_u
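+        # (fits a d-dim user vector on the cached hidden states; beta and lam
+        # are presumably loss-weighting / regularization coefficients)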
+ theta = fit_theta(
+ cached_h=cached_h,
+ lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=head,
+ d=d,
+ lr=cfg['lr'],
+ steps=cfg['adapt_steps'],
+ beta=cfg['beta'],
+ lam=cfg['lam'],
+ max_grad_norm=cfg['max_grad_norm'],
+ device=device,
+ verbose=False,
+ )
+ adapt_time = time.time() - t0
+ adapt_times.append(adapt_time)
+
+ # Step 3: Generate with personalized head (blended)
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ blend_gamma = cfg.get('blend_gamma', 0.5)
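+        # blend_gamma presumably interpolates the base lm_head logits with the
+        # personalized head's logits (0 = base only, 1 = head only); inferred
+        # from the name of generate_with_head_blended.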
+ pred = wrapper.generate_with_head_blended(
+ prompt, theta, head.forward_fn,
+ blend_gamma=blend_gamma,
+ max_new_tokens=cfg['max_new_tokens'],
+ min_new_tokens=cfg.get('min_new_tokens', 128),
+ temperature=cfg.get('temperature', 0.0),
+ )
+ predictions.append(pred)
+
+ # Cleanup GPU memory
+ del cached_h, theta
+ torch.cuda.empty_cache()
+
+ if (i + 1) % 10 == 0:
+ avg_adapt = sum(adapt_times) / len(adapt_times)
+ print(f" {head_type.upper()}: {i+1}/{len(examples)} "
+ f"(avg adapt: {avg_adapt:.2f}s)")
+
+ avg_adapt_time = sum(adapt_times) / max(len(adapt_times), 1)
+ print(f" Average adaptation time: {avg_adapt_time:.2f}s")
+
+ return predictions, avg_adapt_time
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str, required=True)
+ parser.add_argument('--num_eval', type=int, default=None,
+ help='Override number of examples to evaluate')
+ parser.add_argument('--methods', type=str, default='all',
+ help='Comma-separated methods: base,prompt_all_k,bm25_top1,uncond,cvh')
+ parser.add_argument('--output_dir', type=str, default='outputs')
+ parser.add_argument('--seed', type=int, default=0)
+ args = parser.parse_args()
+
+ # Load config
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ if args.num_eval is not None:
+ cfg['num_eval'] = args.num_eval
+
+ torch.manual_seed(args.seed)
+
+ print(f"=== CVH Dev Experiment ===")
+ print(f"Config: {args.config}")
+ print(f"Task: {cfg['task']}, Setting: {cfg['setting']}")
+ print(f"Model: {cfg['model_name']}, Device: {cfg['device']}")
+ print(f"d={cfg['d']}, K={cfg['K']}, alpha={cfg['alpha']}, steps={cfg['adapt_steps']}")
+
+ # Load data
+ print("\nLoading LongLaMP data...")
+ examples = load_longlamp(cfg['dataset_config'], split='val')
+ print(f"Loaded {len(examples)} validation examples")
+
+ # Limit examples if specified
+ num_eval = cfg.get('num_eval', -1)
+ if num_eval > 0:
+ examples = examples[:num_eval]
+ print(f"Evaluating on {len(examples)} examples")
+
+ # Prepare support sets
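+    # (built once with a fixed seed so every method sees the same K items per user)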
+ K = cfg['K']
+ support_sets = []
+ for ex in examples:
+ support = select_k_profile_items(ex['profile_items'], K, seed=args.seed)
+ support_sets.append(support)
+
+ # Gather references and support texts for metrics
+ references = [ex['target_output'] for ex in examples]
+ support_texts_per_example = [
+ [s['support_output'] for s in support]
+ for support in support_sets
+ ]
+
+ # Parse methods to run
+ if args.methods == 'all':
+ methods_to_run = ['base', 'prompt_all_k', 'bm25_top1', 'uncond', 'cvh']
+ else:
+ methods_to_run = args.methods.split(',')
+
+ # Load model
+ print(f"\nLoading model {cfg['model_name']}...")
+ wrapper = QwenWrapper(cfg['model_name'], device=cfg['device'])
+ print(f"Model loaded. Hidden size: {wrapper.hidden_size}")
+
+ results = {}
+ all_predictions = {}
+
+ # Run each method
+ if 'base' in methods_to_run:
+ print("\n--- Running Base ---")
+ preds = run_method_base(wrapper, examples, cfg)
+ all_predictions['Base'] = preds
+ results['Base'] = evaluate_all(preds, references, support_texts_per_example)
+ print(f" ROUGE-L: {results['Base']['rougeL']:.4f}, METEOR: {results['Base']['meteor']:.4f}, SFD: {results['Base']['sfd']:.4f}")
+
+ if 'prompt_all_k' in methods_to_run:
+ print(f"\n--- Running Prompt-All-K (K={K}) ---")
+ preds = run_method_prompt_all_k(wrapper, examples, support_sets, cfg)
+ all_predictions['Prompt-All-K'] = preds
+ results['Prompt-All-K'] = evaluate_all(preds, references, support_texts_per_example)
+ print(f" ROUGE-L: {results['Prompt-All-K']['rougeL']:.4f}, METEOR: {results['Prompt-All-K']['meteor']:.4f}, SFD: {results['Prompt-All-K']['sfd']:.4f}")
+
+ if 'bm25_top1' in methods_to_run:
+ print(f"\n--- Running BM25-Top1 ---")
+ preds = run_method_bm25_top1(wrapper, examples, support_sets, cfg)
+ all_predictions['BM25-Top1'] = preds
+ results['BM25-Top1'] = evaluate_all(preds, references, support_texts_per_example)
+ print(f" ROUGE-L: {results['BM25-Top1']['rougeL']:.4f}, METEOR: {results['BM25-Top1']['meteor']:.4f}, SFD: {results['BM25-Top1']['sfd']:.4f}")
+
+ if 'uncond' in methods_to_run:
+ print(f"\n--- Running Unconditional Head ---")
+ preds, adapt_time = run_method_vector_head(wrapper, examples, support_sets, cfg, head_type='uncond')
+ all_predictions['Uncond-Head'] = preds
+ results['Uncond-Head'] = evaluate_all(preds, references, support_texts_per_example)
+ results['Uncond-Head']['adapt_time'] = adapt_time
+ print(f" ROUGE-L: {results['Uncond-Head']['rougeL']:.4f}, METEOR: {results['Uncond-Head']['meteor']:.4f}, SFD: {results['Uncond-Head']['sfd']:.4f}")
+
+ if 'cvh' in methods_to_run:
+ print(f"\n--- Running CVH ---")
+ preds, adapt_time = run_method_vector_head(wrapper, examples, support_sets, cfg, head_type='cvh')
+ all_predictions['CVH'] = preds
+ results['CVH'] = evaluate_all(preds, references, support_texts_per_example)
+ results['CVH']['adapt_time'] = adapt_time
+ print(f" ROUGE-L: {results['CVH']['rougeL']:.4f}, METEOR: {results['CVH']['meteor']:.4f}, SFD: {results['CVH']['sfd']:.4f}")
+
+ # Print comparison table
+ print("\n" + "=" * 70)
+ print("RESULTS SUMMARY")
+ print("=" * 70)
+ print_results_table(results)
+
+ # Compute compression for vector methods
+ print("\n--- Efficiency ---")
+ theta_bytes = cfg['d'] * 2 # bf16 = 2 bytes per dim
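+    # e.g. d=64 -> a 128-byte user vector; compute_compression presumably
+    # divides the raw support-text size by this figure.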
+ for ex_support in support_texts_per_example[:5]:
+ comp = compute_compression(ex_support, theta_bytes)
+ print(f" Compression ratio (example): {comp:.0f}x")
+
+ print(f" Theta size: {theta_bytes} bytes")
+ print(f" Personalization prompt tokens at inference: 0")
+
+ # Save results
+ os.makedirs(args.output_dir, exist_ok=True)
+ exp_name = f"{cfg['task']}_{cfg['setting']}_K{K}_d{cfg['d']}"
+ output_path = os.path.join(args.output_dir, f"{exp_name}_results.json")
+
+ # Convert results to serializable format
+ save_data = {
+ 'config': cfg,
+ 'results': results,
+ 'num_examples': len(examples),
+ }
+ with open(output_path, 'w') as f:
+ json.dump(save_data, f, indent=2)
+ print(f"\nResults saved to {output_path}")
+
+ # Save predictions
+ pred_path = os.path.join(args.output_dir, f"{exp_name}_predictions.json")
+ save_preds = {}
+ for method, preds in all_predictions.items():
+ save_preds[method] = [
+ {
+ 'example_id': examples[i]['example_id'],
+ 'prediction': preds[i],
+ 'reference': references[i],
+ }
+ for i in range(len(preds))
+ ]
+ with open(pred_path, 'w') as f:
+ json.dump(save_preds, f, indent=2)
+ print(f"Predictions saved to {pred_path}")
+
+
+if __name__ == '__main__':
+ main()