"""UPH hyperparameter search for a given K. Searches over (lr, steps, d) on a small N, then confirms top configs on full N. Usage: python scripts/uph_hyperparam_search.py --task review --setting user --K 8 --device cuda:0 python scripts/uph_hyperparam_search.py --task review --setting user --K 8 --N_screen 50 --N_confirm 200 """ import sys import os import json import time import itertools import numpy as np import torch sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from data.longlamp import load_longlamp, select_k_profile_items from data.templates import build_query_prompt from data.style_features import compute_sfd, compute_feature_deltas from models.qwen_wrapper import QwenWrapper from models.cvh import UnconditionalHead from adapt.cache_hidden import cache_support_hidden_states from adapt.fit_theta import fit_theta from eval.metrics import compute_rouge, compute_meteor def run_uph_config(wrapper, examples, support_sets, references, support_texts, d, lr, steps, device, N): """Run UPH with specific hyperparams, return mean R-L.""" H = wrapper.hidden_size uncond = UnconditionalHead(H, d=d, alpha=0.1, basis_seed=42).to(device) lm_head_bias = None if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None: lm_head_bias = wrapper.model.lm_head.bias.data rl_scores = [] for i, (ex, support) in enumerate(zip(examples[:N], support_sets[:N])): cached_h = cache_support_hidden_states(wrapper, support, ex['task']) if not cached_h: prompt = build_query_prompt(ex['query_input'], ex['task']) # Base generation chat_msgs = [{"role": "system", "content": "You are a helpful writing assistant."}, {"role": "user", "content": prompt}] pt = wrapper.tokenizer.apply_chat_template(chat_msgs, tokenize=False, add_generation_prompt=True) ids = wrapper.tokenizer.encode(pt, return_tensors="pt").to(device) with torch.no_grad(): out = wrapper.model.generate(ids, max_new_tokens=512, min_new_tokens=128, temperature=None, top_p=None, do_sample=False, pad_token_id=wrapper.tokenizer.pad_token_id) pred = wrapper.tokenizer.decode(out[0, ids.shape[1]:], skip_special_tokens=True) else: theta = fit_theta( cached_h=cached_h, lm_head_weight=wrapper.lm_head_weight, lm_head_bias=lm_head_bias, head_module=uncond, d=d, lr=lr, steps=steps, beta=0.05, lam=1e-4, max_grad_norm=5.0, device=device, ) prompt = build_query_prompt(ex['query_input'], ex['task']) pred = wrapper.generate_with_head_blended( prompt, theta, uncond.forward_fn, blend_gamma=0.5, max_new_tokens=512, min_new_tokens=128, temperature=0.0, ) del cached_h, theta torch.cuda.empty_cache() r = compute_rouge([pred], [references[i]]) rl_scores.append(r['rougeL']) return np.mean(rl_scores) def main(): import argparse parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='review') parser.add_argument('--setting', type=str, default='user') parser.add_argument('--K', type=int, default=8) parser.add_argument('--N_screen', type=int, default=50) parser.add_argument('--N_confirm', type=int, default=200) parser.add_argument('--device', type=str, default='cuda:0') parser.add_argument('--output_dir', type=str, default='outputs/hyperparam') parser.add_argument('--top_k', type=int, default=3, help='Top configs to confirm') args = parser.parse_args() config_map = { ('review', 'user'): 'product_review_user', ('review', 'temporal'): 'product_review_temporal', ('topic', 'user'): 'topic_writing_user', ('topic', 'temporal'): 'topic_writing_temporal', } config_name = config_map[(args.task, args.setting)] # Search 
grid d_values = [32, 64, 128] lr_values = [0.03, 0.05, 0.1] steps_values = [30, 50, 100] print(f"=== UPH Hyperparam Search: {args.task}_{args.setting}, K={args.K} ===") print(f"Grid: d={d_values}, lr={lr_values}, steps={steps_values}") print(f"Screen N={args.N_screen}, Confirm N={args.N_confirm}") print("\nLoading data...") examples = load_longlamp(config_name, split='val')[:args.N_confirm] support_sets = [select_k_profile_items(ex['profile_items'], args.K, seed=0) for ex in examples] references = [ex['target_output'] for ex in examples] support_texts = [[s['support_output'] for s in ss] for ss in support_sets] print(f"Loading model on {args.device}...") wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=args.device) # Phase 1: Screen all configs on small N print(f"\n--- Phase 1: Screening ({args.N_screen} examples) ---") results = [] for d, lr, steps in itertools.product(d_values, lr_values, steps_values): t0 = time.time() mean_rl = run_uph_config( wrapper, examples, support_sets, references, support_texts, d=d, lr=lr, steps=steps, device=args.device, N=args.N_screen, ) elapsed = time.time() - t0 results.append({'d': d, 'lr': lr, 'steps': steps, 'mean_rl': mean_rl, 'time': elapsed}) print(f" d={d:3d} lr={lr:.3f} steps={steps:3d}: R-L={mean_rl:.4f} ({elapsed:.0f}s)") # Sort by R-L results.sort(key=lambda x: x['mean_rl'], reverse=True) print(f"\n--- Top {args.top_k} configs ---") for i, r in enumerate(results[:args.top_k]): print(f" #{i+1}: d={r['d']} lr={r['lr']} steps={r['steps']} R-L={r['mean_rl']:.4f}") # Phase 2: Confirm top configs on full N print(f"\n--- Phase 2: Confirming top {args.top_k} ({args.N_confirm} examples) ---") confirmed = [] for r in results[:args.top_k]: t0 = time.time() mean_rl = run_uph_config( wrapper, examples, support_sets, references, support_texts, d=r['d'], lr=r['lr'], steps=r['steps'], device=args.device, N=args.N_confirm, ) elapsed = time.time() - t0 confirmed.append({**r, 'confirmed_rl': mean_rl, 'confirm_time': elapsed}) print(f" d={r['d']} lr={r['lr']} steps={r['steps']}: " f"screen={r['mean_rl']:.4f} → confirm={mean_rl:.4f} ({elapsed:.0f}s)") # Save os.makedirs(args.output_dir, exist_ok=True) output_path = os.path.join(args.output_dir, f"{args.task}_{args.setting}_K{args.K}_search.json") with open(output_path, 'w') as f: json.dump({ 'screening': results, 'confirmed': confirmed, 'task': args.task, 'setting': args.setting, 'K': args.K, 'N_screen': args.N_screen, 'N_confirm': args.N_confirm, }, f, indent=2, default=str) print(f"\nSaved to {output_path}") if __name__ == '__main__': main()
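
# Illustrative follow-up (a minimal sketch, not part of the search itself):
# picking the winning config out of a finished search's JSON. The path below
# assumes the default --output_dir and a review/user/K=8 run; note that
# json.dump above is called with default=str, so numeric fields may come back
# as strings, hence the float() cast.
#
#   import json
#   with open('outputs/hyperparam/review_user_K8_search.json') as f:
#       search = json.load(f)
#   best = max(search['confirmed'], key=lambda r: float(r['confirmed_rl']))
#   print(f"best config: d={best['d']} lr={best['lr']} steps={best['steps']}")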