summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-06 06:56:48 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-06 06:56:48 -0500
commit26c899101dbb192981cc67d73fc00a2d158b503e (patch)
tree3b84a3d07ae24c6f0f070ec1216f908e7ba91fa5 /scripts
parent86a7ef5b8d12cea1032602f30c18d52392f1cc42 (diff)
Add UPH hyperparameter search script for K ablation tuning
Two-phase search: screen 27 configs (d×lr×steps) on N=50, then confirm top-3 on N=200. Needed because K=8 UPH R-L (0.137) unexpectedly lower than K=4 (0.140) with default hyperparams. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'scripts')
-rw-r--r--scripts/uph_hyperparam_search.py169
1 files changed, 169 insertions, 0 deletions
diff --git a/scripts/uph_hyperparam_search.py b/scripts/uph_hyperparam_search.py
new file mode 100644
index 0000000..78706f1
--- /dev/null
+++ b/scripts/uph_hyperparam_search.py
@@ -0,0 +1,169 @@
+"""UPH hyperparameter search for a given K.
+
+Searches over (lr, steps, d) on a small N, then confirms top configs on full N.
+
+Usage:
+ python scripts/uph_hyperparam_search.py --task review --setting user --K 8 --device cuda:0
+ python scripts/uph_hyperparam_search.py --task review --setting user --K 8 --N_screen 50 --N_confirm 200
+"""
+
+import sys
+import os
+import json
+import time
+import itertools
+import numpy as np
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt
+from data.style_features import compute_sfd, compute_feature_deltas
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from eval.metrics import compute_rouge, compute_meteor
+
+
def run_uph_config(wrapper, examples, support_sets, references, support_texts,
                   d, lr, steps, device, N):
    """Evaluate one UPH hyperparameter configuration.

    For each of the first *N* examples, caches hidden states over the
    example's support set, fits a per-user theta with the given
    (d, lr, steps), and generates with the blended unconditional head.
    Examples whose support set yields no cached hidden states fall back
    to plain greedy base-model generation.

    Returns the mean ROUGE-L of the predictions against ``references``.

    NOTE(review): ``support_texts`` is currently unused; it is kept only
    for signature compatibility with callers.
    """
    hidden_size = wrapper.hidden_size
    head = UnconditionalHead(hidden_size, d=d, alpha=0.1, basis_seed=42).to(device)

    # lm_head bias is optional on the underlying model.
    bias = getattr(wrapper.model.lm_head, 'bias', None)
    bias = bias.data if bias is not None else None

    scores = []
    for idx, (example, support) in enumerate(zip(examples[:N], support_sets[:N])):
        cached = cache_support_hidden_states(wrapper, support, example['task'])
        prompt = build_query_prompt(example['query_input'], example['task'])

        if cached:
            theta = fit_theta(
                cached_h=cached,
                lm_head_weight=wrapper.lm_head_weight,
                lm_head_bias=bias,
                head_module=head,
                d=d, lr=lr, steps=steps, beta=0.05, lam=1e-4,
                max_grad_norm=5.0, device=device,
            )
            prediction = wrapper.generate_with_head_blended(
                prompt, theta, head.forward_fn,
                blend_gamma=0.5, max_new_tokens=512,
                min_new_tokens=128, temperature=0.0,
            )
            # Free per-example tensors before the next iteration.
            del cached, theta
            torch.cuda.empty_cache()
        else:
            # Fallback: greedy generation from the base model via the
            # chat template, no personalization head.
            messages = [
                {"role": "system", "content": "You are a helpful writing assistant."},
                {"role": "user", "content": prompt},
            ]
            chat_text = wrapper.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True)
            input_ids = wrapper.tokenizer.encode(
                chat_text, return_tensors="pt").to(device)
            with torch.no_grad():
                generated = wrapper.model.generate(
                    input_ids, max_new_tokens=512, min_new_tokens=128,
                    temperature=None, top_p=None, do_sample=False,
                    pad_token_id=wrapper.tokenizer.pad_token_id)
            prediction = wrapper.tokenizer.decode(
                generated[0, input_ids.shape[1]:], skip_special_tokens=True)

        scores.append(compute_rouge([prediction], [references[idx]])['rougeL'])

    return np.mean(scores)
+
+
def main():
    """Two-phase UPH hyperparameter search.

    Phase 1 screens the full (d, lr, steps) grid on a small subset of
    examples (``--N_screen``); phase 2 re-evaluates the top ``--top_k``
    configs on the larger ``--N_confirm`` subset.  Results are written
    as JSON under ``--output_dir``.

    Fixes vs. previous version:
      * ``run_uph_config`` returns ``np.mean(...)`` (a ``np.float64``);
        previously ``json.dump(..., default=str)`` silently serialized
        the scores as strings.  Scores are now cast to plain ``float``
        and the blanket ``default=str`` is dropped.
      * ``--task``/``--setting`` now use argparse ``choices`` so invalid
        values produce a clear usage error instead of a ``KeyError`` on
        the config map.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description='Two-phase UPH hyperparameter search for a given K.')
    # choices match config_map below; invalid values fail fast with a
    # readable argparse error rather than a KeyError.
    parser.add_argument('--task', type=str, default='review',
                        choices=['review', 'topic'])
    parser.add_argument('--setting', type=str, default='user',
                        choices=['user', 'temporal'])
    parser.add_argument('--K', type=int, default=8)
    parser.add_argument('--N_screen', type=int, default=50)
    parser.add_argument('--N_confirm', type=int, default=200)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--output_dir', type=str, default='outputs/hyperparam')
    parser.add_argument('--top_k', type=int, default=3, help='Top configs to confirm')
    args = parser.parse_args()

    config_map = {
        ('review', 'user'): 'product_review_user',
        ('review', 'temporal'): 'product_review_temporal',
        ('topic', 'user'): 'topic_writing_user',
        ('topic', 'temporal'): 'topic_writing_temporal',
    }
    config_name = config_map[(args.task, args.setting)]

    # Search grid: 3 x 3 x 3 = 27 configurations.
    d_values = [32, 64, 128]
    lr_values = [0.03, 0.05, 0.1]
    steps_values = [30, 50, 100]

    print(f"=== UPH Hyperparam Search: {args.task}_{args.setting}, K={args.K} ===")
    print(f"Grid: d={d_values}, lr={lr_values}, steps={steps_values}")
    print(f"Screen N={args.N_screen}, Confirm N={args.N_confirm}")

    print("\nLoading data...")
    # Load enough examples for the confirm phase; the screen phase uses a
    # prefix of the same list, so confirm scores subsume the screen set.
    examples = load_longlamp(config_name, split='val')[:args.N_confirm]
    support_sets = [select_k_profile_items(ex['profile_items'], args.K, seed=0)
                    for ex in examples]
    references = [ex['target_output'] for ex in examples]
    support_texts = [[s['support_output'] for s in ss] for ss in support_sets]

    print(f"Loading model on {args.device}...")
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=args.device)

    # Phase 1: screen every config on the small subset.
    print(f"\n--- Phase 1: Screening ({args.N_screen} examples) ---")
    results = []
    for d, lr, steps in itertools.product(d_values, lr_values, steps_values):
        t0 = time.time()
        # Cast to float: run_uph_config returns np.float64, which json
        # cannot serialize natively.
        mean_rl = float(run_uph_config(
            wrapper, examples, support_sets, references, support_texts,
            d=d, lr=lr, steps=steps, device=args.device, N=args.N_screen,
        ))
        elapsed = time.time() - t0
        results.append({'d': d, 'lr': lr, 'steps': steps,
                        'mean_rl': mean_rl, 'time': elapsed})
        print(f"  d={d:3d} lr={lr:.3f} steps={steps:3d}: "
              f"R-L={mean_rl:.4f} ({elapsed:.0f}s)")

    # Rank configs by screening ROUGE-L, best first.
    results.sort(key=lambda x: x['mean_rl'], reverse=True)
    print(f"\n--- Top {args.top_k} configs ---")
    for i, r in enumerate(results[:args.top_k]):
        print(f"  #{i+1}: d={r['d']} lr={r['lr']} steps={r['steps']} "
              f"R-L={r['mean_rl']:.4f}")

    # Phase 2: confirm the top configs on the full evaluation subset.
    print(f"\n--- Phase 2: Confirming top {args.top_k} ({args.N_confirm} examples) ---")
    confirmed = []
    for r in results[:args.top_k]:
        t0 = time.time()
        mean_rl = float(run_uph_config(
            wrapper, examples, support_sets, references, support_texts,
            d=r['d'], lr=r['lr'], steps=r['steps'],
            device=args.device, N=args.N_confirm,
        ))
        elapsed = time.time() - t0
        confirmed.append({**r, 'confirmed_rl': mean_rl, 'confirm_time': elapsed})
        print(f"  d={r['d']} lr={r['lr']} steps={r['steps']}: "
              f"screen={r['mean_rl']:.4f} → confirm={mean_rl:.4f} ({elapsed:.0f}s)")

    # Persist both phases; all values are JSON-native (int/float/str),
    # so no custom serializer is needed.
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(
        args.output_dir, f"{args.task}_{args.setting}_K{args.K}_search.json")
    with open(output_path, 'w') as f:
        json.dump({
            'screening': results,
            'confirmed': confirmed,
            'task': args.task,
            'setting': args.setting,
            'K': args.K,
            'N_screen': args.N_screen,
            'N_confirm': args.N_confirm,
        }, f, indent=2)
    print(f"\nSaved to {output_path}")


if __name__ == '__main__':
    main()