Diffstat (limited to 'scripts/run_uph_base_per_user.py')
-rw-r--r--  scripts/run_uph_base_per_user.py  263
1 file changed, 263 insertions(+), 0 deletions(-)
diff --git a/scripts/run_uph_base_per_user.py b/scripts/run_uph_base_per_user.py
new file mode 100644
index 0000000..4a48396
--- /dev/null
+++ b/scripts/run_uph_base_per_user.py
@@ -0,0 +1,263 @@
+"""Run UPH and the unpersonalized Base model with complete per-user data saving.
+
+Saves predictions, references, all per-user metrics (ROUGE-L, METEOR, SFD, style-feature
+deltas), and metadata, then runs paired significance tests against saved PEFT baselines.
+
+Usage:
+ python scripts/run_uph_base_per_user.py --task review --setting user --device cuda:0
+"""
+
+import sys
+import os
+import json
+import time
+import numpy as np
+import torch
+from scipy import stats
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt
+from data.style_features import compute_sfd, compute_feature_deltas
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from eval.metrics import compute_rouge, compute_meteor
+
+
+def compute_per_user_metrics(pred, ref, support_texts):
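+    """Compute ROUGE, METEOR, SFD (with/without length), and style-feature deltas for one prediction."""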
+ r = compute_rouge([pred], [ref])
+ m = compute_meteor([pred], [ref])
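+    # Substitute a placeholder when the prediction is empty, presumably so SFD/feature extraction has non-empty text to work on.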
+ p = pred if pred.strip() else "empty"
+ sfd_all = compute_sfd(p, support_texts, exclude_length=False)
+ sfd_nolen = compute_sfd(p, support_texts, exclude_length=True)
+ deltas = compute_feature_deltas(p, support_texts)
+ return {
+ 'rouge1': r['rouge1'],
+ 'rougeL': r['rougeL'],
+ 'meteor': m,
+ 'sfd_all': sfd_all,
+ 'sfd_nolen': sfd_nolen,
+ 'length': len(pred.split()),
+ 'feature_deltas': {k: v['delta'] for k, v in deltas.items()},
+ }
+
+
+def generate_base(wrapper, prompt, max_new_tokens=512, min_new_tokens=128):
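+    """Generate an unpersonalized (Base) completion for a chat-formatted prompt via greedy decoding."""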
+ chat_messages = [
+ {"role": "system", "content": "You are a helpful writing assistant."},
+ {"role": "user", "content": prompt},
+ ]
+ prompt_text = wrapper.tokenizer.apply_chat_template(
+ chat_messages, tokenize=False, add_generation_prompt=True
+ )
+ input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)
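+    # Deterministic greedy decoding; temperature/top_p are passed as None, presumably to override any sampling defaults in the model's generation config.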
+ with torch.no_grad():
+ outputs = wrapper.model.generate(
+ input_ids,
+ max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens,
+ temperature=None, top_p=None, do_sample=False,
+ pad_token_id=wrapper.tokenizer.pad_token_id,
+ )
+ return wrapper.tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True)
+
+
+def paired_test(scores_a, scores_b, name_a, name_b, metric_name):
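+    """Paired t-test and Wilcoxon signed-rank test on per-user differences (a - b), with a normal-approximation 95% CI."""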
+ a, b = np.array(scores_a), np.array(scores_b)
+ diff = a - b
+ mean_diff = np.mean(diff)
+ t_stat, t_pval = stats.ttest_rel(a, b)
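+    # scipy's wilcoxon raises ValueError when, e.g., all paired differences are zero; fall back to NaN in that case.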
+ try:
+ w_stat, w_pval = stats.wilcoxon(a, b)
+ except ValueError:
+ w_stat, w_pval = float('nan'), float('nan')
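+    # 95% CI from a normal approximation: mean difference +/- 1.96 * SEM.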
+ se = stats.sem(diff)
+ ci_low, ci_high = mean_diff - 1.96 * se, mean_diff + 1.96 * se
+
+ print(f" {name_a} vs {name_b} ({metric_name}): "
+ f"diff={mean_diff:+.4f}, 95% CI=[{ci_low:+.4f}, {ci_high:+.4f}], "
+ f"t-test p={t_pval:.2e}, Wilcoxon p={w_pval:.2e}")
+ return {
+ 'mean_a': float(np.mean(a)), 'mean_b': float(np.mean(b)),
+ 'mean_diff': float(mean_diff),
+ 'ci_low': float(ci_low), 'ci_high': float(ci_high),
+ 't_pval': float(t_pval), 'w_pval': float(w_pval),
+ }
+
+
+def main():
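+    """Run Base and UPH over the evaluation split, save per-user results, and run significance tests against saved baselines."""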
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--num_eval', type=int, default=200)
+ parser.add_argument('--task', type=str, default='review', choices=['review', 'topic'])
+ parser.add_argument('--setting', type=str, default='user', choices=['user', 'temporal'])
+ parser.add_argument('--device', type=str, default='cuda:0')
+ parser.add_argument('--output_dir', type=str, default='outputs/per_user')
+ args = parser.parse_args()
+
+ N = args.num_eval
+ device = args.device
+ task = args.task
+ setting = args.setting
+
+ config_map = {
+ ('review', 'user'): 'product_review_user',
+ ('review', 'temporal'): 'product_review_temporal',
+ ('topic', 'user'): 'topic_writing_user',
+ ('topic', 'temporal'): 'topic_writing_temporal',
+ }
+ config_name = config_map[(task, setting)]
+
+ print(f"=== UPH + Base per-user: {task}_{setting}, N={N} ===")
+
+ print("\nLoading data...")
+ examples = load_longlamp(config_name, split='val')[:N]
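+    # Select a fixed K=4 support items per user (seed=0) and keep the raw support texts for the style metrics.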
+ K = 4
+ support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
+ references = [ex['target_output'] for ex in examples]
+ support_texts = [[s['support_output'] for s in ss] for ss in support_sets]
+
+ print(f"Loading model on {device}...")
+ wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=device)
+ H = wrapper.hidden_size
+
+ all_per_user = {}
+
+ # === Base ===
+ print("\n--- Base ---")
+ base_per_user = []
+ for i, ex in enumerate(examples):
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ t0 = time.time()
+ pred = generate_base(wrapper, prompt)
+ gen_time = time.time() - t0
+
+ metrics = compute_per_user_metrics(pred, references[i], support_texts[i])
+ base_per_user.append({
+ 'example_id': ex['example_id'],
+ 'user_id': ex['user_id'],
+ 'prediction': pred,
+ 'reference': references[i],
+ 'support_texts': support_texts[i],
+ 'K': K,
+ 'gen_time': gen_time,
+ 'metrics': metrics,
+ })
+ if (i + 1) % 40 == 0:
+ avg_rl = np.mean([u['metrics']['rougeL'] for u in base_per_user])
+ print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f})")
+
+ all_per_user['Base'] = base_per_user
+ avg_rl = np.mean([u['metrics']['rougeL'] for u in base_per_user])
+ print(f" Mean R-L: {avg_rl:.4f}")
+
+ # === UPH ===
+ print("\n--- UPH ---")
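+    # Shared unconditional head (d=64, alpha=0.1); a per-user theta is fit on cached support hidden states and blended into decoding (blend_gamma=0.5) below.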
+ uncond = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+ lm_head_bias = None
+ if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+ lm_head_bias = wrapper.model.lm_head.bias.data
+
+ uph_per_user = []
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ t0 = time.time()
+ cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
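+        # Cache hidden states from the user's support items; fall back to the Base generation if nothing was cached.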
+ if not cached_h:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = generate_base(wrapper, prompt)
+ else:
+ theta = fit_theta(
+ cached_h=cached_h,
+ lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=uncond,
+ d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+ max_grad_norm=5.0, device=device,
+ )
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_with_head_blended(
+ prompt, theta, uncond.forward_fn,
+ blend_gamma=0.5, max_new_tokens=512,
+ min_new_tokens=128, temperature=0.0,
+ )
+ del cached_h, theta
+ torch.cuda.empty_cache()
+
+ adapt_time = time.time() - t0
+ metrics = compute_per_user_metrics(pred, references[i], support_texts[i])
+ uph_per_user.append({
+ 'example_id': ex['example_id'],
+ 'user_id': ex['user_id'],
+ 'prediction': pred,
+ 'reference': references[i],
+ 'support_texts': support_texts[i],
+ 'K': K,
+ 'adapt_time': adapt_time,
+ 'metrics': metrics,
+ })
+ if (i + 1) % 40 == 0:
+ avg_rl = np.mean([u['metrics']['rougeL'] for u in uph_per_user])
+ print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f})")
+
+ all_per_user['UPH'] = uph_per_user
+ avg_rl = np.mean([u['metrics']['rougeL'] for u in uph_per_user])
+ print(f" Mean R-L: {avg_rl:.4f}")
+
+ # Save per-user data
+ os.makedirs(args.output_dir, exist_ok=True)
+ per_user_path = os.path.join(args.output_dir, f"{task}_{setting}_uph_base_per_user.json")
+ with open(per_user_path, 'w') as f:
+ json.dump({
+ 'per_user': all_per_user,
+ 'num_examples': N, 'task': task, 'setting': setting, 'K': K,
+ }, f, indent=2, default=str)
+ print(f"\nPer-user data saved to {per_user_path}")
+
+ # === Significance tests vs PEFT ===
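+    # Per-user PEFT baseline results from an earlier run are expected at this path (note: K4 is hard-coded here).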
+ peft_path = f"outputs/peft_baselines/{task}_{setting}_K4_N{N}_peft_per_user.json"
+ if os.path.exists(peft_path):
+ with open(peft_path) as f:
+ peft_data = json.load(f)
+
+ print("\n" + "=" * 80)
+ print("SIGNIFICANCE TESTS — ALL METRICS (UPH vs each baseline)")
+ print("=" * 80)
+
+ uph_rl = [u['metrics']['rougeL'] for u in uph_per_user]
+ uph_sfd = [u['metrics']['sfd_nolen'] for u in uph_per_user]
+ uph_meteor = [u['metrics']['meteor'] for u in uph_per_user]
+
+ all_tests = {}
+ comparisons = {
+ 'Base': base_per_user,
+ }
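+    # Include any PEFT baselines (lora, tiny_lora, vera) present in the loaded per-user file.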
+ for m in ['lora', 'tiny_lora', 'vera']:
+ if m in peft_data['per_user']:
+ comparisons[m] = peft_data['per_user'][m]
+
+ for name, users in comparisons.items():
+ other_rl = [u['metrics']['rougeL'] for u in users]
+ other_sfd = [u['metrics']['sfd_nolen'] for u in users]
+ other_meteor = [u['metrics']['meteor'] for u in users]
+
+ print(f"\n--- UPH vs {name} ---")
+ tests = {}
+ tests['rougeL'] = paired_test(uph_rl, other_rl, 'UPH', name, 'ROUGE-L')
+ tests['sfd_nolen'] = paired_test(uph_sfd, other_sfd, 'UPH', name, 'SFD_-len')
+ tests['meteor'] = paired_test(uph_meteor, other_meteor, 'UPH', name, 'METEOR')
+ all_tests[f'UPH_vs_{name}'] = tests
+
+ # Save significance results
+ sig_path = os.path.join(args.output_dir, f"{task}_{setting}_all_significance.json")
+ with open(sig_path, 'w') as f:
+ json.dump({
+ 'significance_tests': all_tests,
+ 'num_examples': N, 'task': task, 'setting': setting,
+ }, f, indent=2, default=str)
+ print(f"\nSignificance tests saved to {sig_path}")
+
+
+if __name__ == '__main__':
+ main()