"""Support-query distribution shift analysis.

For each user, compute:
    s_u = cos(mean_support_hidden, mean_query_hidden)
Then correlate with CVH-UPH performance gap:
    delta_u = ROUGE-L(CVH, u) - ROUGE-L(UPH, u)

If correlation is positive: CVH benefits when support-query are aligned.
"""

import argparse
import json
import os
import sys

import numpy as np
import torch
from scipy import stats

# Make the project root importable when this file is run as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.longlamp import load_longlamp, select_k_profile_items
from data.templates import build_query_prompt, build_support_prompt
from models.qwen_wrapper import QwenWrapper
from models.cvh import CVHHead, UnconditionalHead
from adapt.cache_hidden import cache_support_hidden_states
from adapt.fit_theta import fit_theta
from eval.metrics import compute_rouge


def get_query_hidden_mean(wrapper, query_text, task):
    """Return the token-mean last-layer hidden state of the query prompt.

    Args:
        wrapper: QwenWrapper exposing ``tokenizer``, ``model`` and ``device``.
        query_text: Raw query input for one user.
        task: Task identifier forwarded to ``build_query_prompt``.

    Returns:
        1-D ``np.ndarray`` of length ``hidden_size`` — the mean over sequence
        positions of the final hidden layer, on CPU as float32.
    """
    chat_messages = [
        {"role": "system", "content": "You are a helpful writing assistant."},
        {"role": "user", "content": build_query_prompt(query_text, task)},
    ]
    prompt_text = wrapper.tokenizer.apply_chat_template(
        chat_messages, tokenize=False, add_generation_prompt=True
    )
    input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)

    with torch.no_grad():
        outputs = wrapper.model(
            input_ids=input_ids,
            output_hidden_states=True,
            return_dict=True,
        )
        last_hidden = outputs.hidden_states[-1][0]  # (seq_len, H)
        return last_hidden.mean(dim=0).cpu().float().numpy()


def main():
    """Run the shift analysis over validation users and save JSON results."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_eval', type=int, default=100)
    parser.add_argument('--config', type=str, default='product_review_user')
    args = parser.parse_args()

    N = args.num_eval
    print(f"=== Shift Analysis: {args.config}, N={N} ===")

    print("Loading data...")
    examples = load_longlamp(args.config, split='val')[:N]

    print("Loading model...")
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
    H = wrapper.hidden_size
    device = 'cuda:1'

    uph_head = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
    cvh_head = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device)

    # The LM head may be bias-free; only pass a bias tensor when one exists.
    lm_head_bias = None
    if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
        lm_head_bias = wrapper.model.lm_head.bias.data

    # Shared optimization settings for both heads, hoisted so the two
    # fit_theta calls cannot silently drift apart and bias the comparison.
    fit_kwargs = dict(
        lm_head_weight=wrapper.lm_head_weight,
        lm_head_bias=lm_head_bias,
        d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
        max_grad_norm=5.0, device=device, verbose=False,
    )

    K = 4
    shift_cosines = []
    uph_rouges = []
    cvh_rouges = []

    for i, ex in enumerate(examples):
        support = select_k_profile_items(ex['profile_items'], K, seed=0)
        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
        if not cached_h:
            # No usable support items for this user; skip (lists stay aligned
            # because nothing is appended for skipped users).
            continue

        # Mean support hidden state over all cached support tokens.
        all_h = torch.cat([h for h, _ in cached_h], dim=0)
        support_mean = all_h.mean(dim=0).numpy()

        # Mean query hidden state.
        query_mean = get_query_hidden_mean(wrapper, ex['query_input'], ex['task'])

        # Cosine similarity; the epsilon guards against zero-norm vectors.
        cos = np.dot(support_mean, query_mean) / (
            np.linalg.norm(support_mean) * np.linalg.norm(query_mean) + 1e-8)
        shift_cosines.append(float(cos))

        # Fit per-user theta for each head under identical settings.
        theta_uph = fit_theta(cached_h=cached_h, head_module=uph_head, **fit_kwargs)
        theta_cvh = fit_theta(cached_h=cached_h, head_module=cvh_head, **fit_kwargs)

        # Generate with both heads using the same blended decoding settings.
        prompt = build_query_prompt(ex['query_input'], ex['task'])
        pred_uph = wrapper.generate_with_head_blended(
            prompt, theta_uph, uph_head.forward_fn,
            blend_gamma=0.5, max_new_tokens=512, min_new_tokens=128, temperature=0.0,
        )
        pred_cvh = wrapper.generate_with_head_blended(
            prompt, theta_cvh, cvh_head.forward_fn,
            blend_gamma=0.5, max_new_tokens=512, min_new_tokens=128, temperature=0.0,
        )

        # ROUGE-L of each prediction against the user's target output.
        uph_rouges.append(compute_rouge([pred_uph], [ex['target_output']])['rougeL'])
        cvh_rouges.append(compute_rouge([pred_cvh], [ex['target_output']])['rougeL'])

        # Release large per-user tensors before the next iteration.
        del cached_h, theta_uph, theta_cvh
        torch.cuda.empty_cache()

        if (i + 1) % 20 == 0:
            print(f"  {i+1}/{N}")

    # spearmanr / median / masked means are undefined on <2 observations —
    # possible when every example was skipped above. Bail out gracefully.
    if len(shift_cosines) < 2:
        print(f"\nOnly {len(shift_cosines)} valid example(s); "
              "skipping correlation analysis.")
        return

    # Correlate support-query alignment with the CVH-UPH performance gap.
    shift_cosines = np.array(shift_cosines)
    deltas = np.array(cvh_rouges) - np.array(uph_rouges)  # positive = CVH better

    rho, pval = stats.spearmanr(shift_cosines, deltas)

    print(f"\n=== Results (N={len(shift_cosines)}) ===")
    print(f"  Mean shift cosine: {shift_cosines.mean():.4f} +/- {shift_cosines.std():.4f}")
    print(f"  Mean delta (CVH - UPH): {deltas.mean():.4f} +/- {deltas.std():.4f}")
    print(f"  Spearman(shift_cos, delta): rho={rho:.4f}, p={pval:.4f}")
    print(f"  Mean UPH ROUGE-L: {np.mean(uph_rouges):.4f}")
    print(f"  Mean CVH ROUGE-L: {np.mean(cvh_rouges):.4f}")

    # Bin analysis: split users at the median alignment cosine.
    median_cos = np.median(shift_cosines)
    high_mask = shift_cosines >= median_cos
    low_mask = shift_cosines < median_cos

    print(f"\n  High-alignment (cos >= {median_cos:.3f}, n={high_mask.sum()}):")
    print(f"    UPH R-L: {np.mean(np.array(uph_rouges)[high_mask]):.4f}")
    print(f"    CVH R-L: {np.mean(np.array(cvh_rouges)[high_mask]):.4f}")
    print(f"  Low-alignment (cos < {median_cos:.3f}, n={low_mask.sum()}):")
    print(f"    UPH R-L: {np.mean(np.array(uph_rouges)[low_mask]):.4f}")
    print(f"    CVH R-L: {np.mean(np.array(cvh_rouges)[low_mask]):.4f}")

    # Persist per-user values and summary statistics for downstream plots.
    os.makedirs('outputs/analysis', exist_ok=True)
    save_data = {
        'shift_cosines': [float(x) for x in shift_cosines],
        'uph_rouges': [float(x) for x in uph_rouges],
        'cvh_rouges': [float(x) for x in cvh_rouges],
        'deltas': [float(x) for x in deltas],
        'spearman_rho': float(rho),
        'spearman_pval': float(pval),
    }
    with open('outputs/analysis/shift_analysis.json', 'w') as f:
        json.dump(save_data, f, indent=2)
    print("\nSaved to outputs/analysis/shift_analysis.json")


if __name__ == '__main__':
    main()