diff options
| -rw-r--r-- | scripts/uph_hyperparam_search.py | 169 |
1 file changed, 169 insertions(+), 0 deletions(-)
"""UPH hyperparameter search for a given K.

Searches over (lr, steps, d) on a small N, then confirms top configs on full N.

Usage:
    python scripts/uph_hyperparam_search.py --task review --setting user --K 8 --device cuda:0
    python scripts/uph_hyperparam_search.py --task review --setting user --K 8 --N_screen 50 --N_confirm 200
"""

import argparse
import itertools
import json
import os
import sys
import time

import numpy as np
import torch

# Make the repository root importable when running as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.longlamp import load_longlamp, select_k_profile_items
from data.templates import build_query_prompt
from data.style_features import compute_sfd, compute_feature_deltas
from models.qwen_wrapper import QwenWrapper
from models.cvh import UnconditionalHead
from adapt.cache_hidden import cache_support_hidden_states
from adapt.fit_theta import fit_theta
from eval.metrics import compute_rouge, compute_meteor


def _generate_base(wrapper, prompt, device):
    """Greedy baseline generation (no personalization head) for `prompt`.

    Used as the fallback when no support hidden states could be cached.
    """
    chat_msgs = [
        {"role": "system", "content": "You are a helpful writing assistant."},
        {"role": "user", "content": prompt},
    ]
    pt = wrapper.tokenizer.apply_chat_template(
        chat_msgs, tokenize=False, add_generation_prompt=True)
    ids = wrapper.tokenizer.encode(pt, return_tensors="pt").to(device)
    # Decoder-only tokenizers frequently define no pad token; fall back to
    # EOS so model.generate() does not warn or fail on pad_token_id=None.
    pad_id = wrapper.tokenizer.pad_token_id
    if pad_id is None:
        pad_id = wrapper.tokenizer.eos_token_id
    with torch.no_grad():
        out = wrapper.model.generate(
            ids, max_new_tokens=512, min_new_tokens=128,
            temperature=None, top_p=None, do_sample=False,
            pad_token_id=pad_id)
    return wrapper.tokenizer.decode(out[0, ids.shape[1]:], skip_special_tokens=True)


def run_uph_config(wrapper, examples, support_sets, references, support_texts,
                   d, lr, steps, device, N):
    """Run UPH with specific hyperparams, return mean R-L.

    Args:
        wrapper: QwenWrapper around the base language model.
        examples: evaluation examples (parallel to support_sets/references).
        support_sets: K selected profile items per example.
        references: gold target outputs, one per example.
        support_texts: UNUSED — kept only for signature compatibility with
            existing callers.  TODO(review): drop once callers are updated.
        d: UPH head dimension.
        lr: learning rate for theta fitting.
        steps: number of optimization steps for theta fitting.
        device: torch device string (e.g. 'cuda:0').
        N: number of examples (prefix of the lists) to evaluate.

    Returns:
        Mean ROUGE-L over the evaluated examples as a float, or NaN when no
        examples were evaluated (N == 0 or empty inputs).
    """
    H = wrapper.hidden_size
    uncond = UnconditionalHead(H, d=d, alpha=0.1, basis_seed=42).to(device)
    lm_head_bias = None
    if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
        lm_head_bias = wrapper.model.lm_head.bias.data

    rl_scores = []
    for i, (ex, support) in enumerate(zip(examples[:N], support_sets[:N])):
        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
        prompt = build_query_prompt(ex['query_input'], ex['task'])
        if not cached_h:
            # No support hidden states available: plain greedy decoding.
            pred = _generate_base(wrapper, prompt, device)
        else:
            theta = fit_theta(
                cached_h=cached_h,
                lm_head_weight=wrapper.lm_head_weight,
                lm_head_bias=lm_head_bias,
                head_module=uncond,
                d=d, lr=lr, steps=steps, beta=0.05, lam=1e-4,
                max_grad_norm=5.0, device=device,
            )
            pred = wrapper.generate_with_head_blended(
                prompt, theta, uncond.forward_fn,
                blend_gamma=0.5, max_new_tokens=512,
                min_new_tokens=128, temperature=0.0,
            )
            # Release per-example tensors before the next iteration to keep
            # peak GPU memory down across the (large) hyperparameter grid.
            del cached_h, theta
            torch.cuda.empty_cache()

        r = compute_rouge([pred], [references[i]])
        rl_scores.append(r['rougeL'])

    # Explicit guard: np.mean([]) emits a RuntimeWarning and returns NaN;
    # return NaN deliberately (and silently) for the degenerate case.
    return float(np.mean(rl_scores)) if rl_scores else float('nan')


def main():
    parser = argparse.ArgumentParser(description=__doc__)
    # choices= turns an invalid task/setting combo into a clear argparse
    # error instead of a bare KeyError on config_map below.
    parser.add_argument('--task', type=str, default='review',
                        choices=['review', 'topic'])
    parser.add_argument('--setting', type=str, default='user',
                        choices=['user', 'temporal'])
    parser.add_argument('--K', type=int, default=8)
    parser.add_argument('--N_screen', type=int, default=50)
    parser.add_argument('--N_confirm', type=int, default=200)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--output_dir', type=str, default='outputs/hyperparam')
    parser.add_argument('--top_k', type=int, default=3, help='Top configs to confirm')
    args = parser.parse_args()

    config_map = {
        ('review', 'user'): 'product_review_user',
        ('review', 'temporal'): 'product_review_temporal',
        ('topic', 'user'): 'topic_writing_user',
        ('topic', 'temporal'): 'topic_writing_temporal',
    }
    config_name = config_map[(args.task, args.setting)]

    # Search grid
    d_values = [32, 64, 128]
    lr_values = [0.03, 0.05, 0.1]
    steps_values = [30, 50, 100]

    print(f"=== UPH Hyperparam Search: {args.task}_{args.setting}, K={args.K} ===")
    print(f"Grid: d={d_values}, lr={lr_values}, steps={steps_values}")
    print(f"Screen N={args.N_screen}, Confirm N={args.N_confirm}")

    print("\nLoading data...")
    examples = load_longlamp(config_name, split='val')[:args.N_confirm]
    support_sets = [select_k_profile_items(ex['profile_items'], args.K, seed=0) for ex in examples]
    references = [ex['target_output'] for ex in examples]
    support_texts = [[s['support_output'] for s in ss] for ss in support_sets]

    print(f"Loading model on {args.device}...")
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=args.device)

    # Phase 1: screen every (d, lr, steps) combination on a small N.
    print(f"\n--- Phase 1: Screening ({args.N_screen} examples) ---")
    results = []
    for d, lr, steps in itertools.product(d_values, lr_values, steps_values):
        t0 = time.time()
        mean_rl = run_uph_config(
            wrapper, examples, support_sets, references, support_texts,
            d=d, lr=lr, steps=steps, device=args.device, N=args.N_screen,
        )
        elapsed = time.time() - t0
        results.append({'d': d, 'lr': lr, 'steps': steps, 'mean_rl': mean_rl, 'time': elapsed})
        print(f"  d={d:3d} lr={lr:.3f} steps={steps:3d}: R-L={mean_rl:.4f} ({elapsed:.0f}s)")

    # Rank screened configs by ROUGE-L, best first.
    results.sort(key=lambda x: x['mean_rl'], reverse=True)
    print(f"\n--- Top {args.top_k} configs ---")
    for i, r in enumerate(results[:args.top_k]):
        print(f"  #{i+1}: d={r['d']} lr={r['lr']} steps={r['steps']} R-L={r['mean_rl']:.4f}")

    # Phase 2: re-evaluate only the top configs on the full confirm set.
    print(f"\n--- Phase 2: Confirming top {args.top_k} ({args.N_confirm} examples) ---")
    confirmed = []
    for r in results[:args.top_k]:
        t0 = time.time()
        mean_rl = run_uph_config(
            wrapper, examples, support_sets, references, support_texts,
            d=r['d'], lr=r['lr'], steps=r['steps'],
            device=args.device, N=args.N_confirm,
        )
        elapsed = time.time() - t0
        confirmed.append({**r, 'confirmed_rl': mean_rl, 'confirm_time': elapsed})
        print(f"  d={r['d']} lr={r['lr']} steps={r['steps']}: "
              f"screen={r['mean_rl']:.4f} → confirm={mean_rl:.4f} ({elapsed:.0f}s)")

    # Persist both phases; default=str covers any residual non-JSON types.
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir,
                               f"{args.task}_{args.setting}_K{args.K}_search.json")
    with open(output_path, 'w') as f:
        json.dump({
            'screening': results,
            'confirmed': confirmed,
            'task': args.task,
            'setting': args.setting,
            'K': args.K,
            'N_screen': args.N_screen,
            'N_confirm': args.N_confirm,
        }, f, indent=2, default=str)
    print(f"\nSaved to {output_path}")


if __name__ == '__main__':
    main()
