summaryrefslogtreecommitdiff
path: root/scripts/shift_analysis.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-03 15:12:34 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-03 15:12:34 -0500
commit8fe28101366dd32562b8c5534d7fe359b252bdf3 (patch)
treec92a92184fb2f46f265ab84c1f754c3d5d6597bc /scripts/shift_analysis.py
Initial commit: UPH project codebase and experiment results
Includes model code, evaluation scripts, configs, analysis outputs, and experiment results for the User Prior Head personalization method. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'scripts/shift_analysis.py')
-rw-r--r--scripts/shift_analysis.py178
1 files changed, 178 insertions, 0 deletions
diff --git a/scripts/shift_analysis.py b/scripts/shift_analysis.py
new file mode 100644
index 0000000..99c7fd2
--- /dev/null
+++ b/scripts/shift_analysis.py
@@ -0,0 +1,178 @@
+"""Support-query distribution shift analysis.
+
+For each user, compute:
+ s_u = cos(mean_support_hidden, mean_query_hidden)
+Then correlate with CVH-UPH performance gap:
+ delta_u = ROUGE-L(CVH, u) - ROUGE-L(UPH, u)
+
+If correlation is positive: CVH benefits when support-query are aligned.
+"""
+
+import sys
+import os
+import json
+import numpy as np
+from scipy import stats
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt, build_support_prompt
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import CVHHead, UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from eval.metrics import compute_rouge
+
+
def get_query_hidden_mean(wrapper, query_text, task):
    """Return the mean last-layer hidden state (as a float32 numpy vector)
    for the chat-formatted query prompt.

    The prompt is rendered with the tokenizer's chat template (system +
    user turns, generation prompt appended), run through the model with
    hidden states enabled, and the final layer is averaged over tokens.
    """
    messages = [
        {"role": "system", "content": "You are a helpful writing assistant."},
        {"role": "user", "content": build_query_prompt(query_text, task)},
    ]
    rendered = wrapper.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    ids = wrapper.tokenizer.encode(rendered, return_tensors="pt").to(wrapper.device)

    with torch.no_grad():
        out = wrapper.model(
            input_ids=ids,
            output_hidden_states=True,
            return_dict=True,
        )

    # Final layer, batch element 0 -> (seq_len, H); pool by averaging tokens.
    final_layer = out.hidden_states[-1][0]
    return final_layer.mean(dim=0).cpu().float().numpy()
+
+
def main():
    """Run the support-query distribution-shift analysis.

    Per user: cache support hidden states, compute the cosine between the
    mean support hidden state and the mean query hidden state, fit both the
    unconditional (UPH) and conditional (CVH) heads on the same cache,
    generate with each, and score ROUGE-L.  Finally report the Spearman
    correlation between alignment cosine and the CVH-UPH gap, a
    median-split summary, and save everything to
    outputs/analysis/shift_analysis.json.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_eval', type=int, default=100)
    parser.add_argument('--config', type=str, default='product_review_user')
    # New knobs; defaults reproduce the previously hard-coded behavior.
    parser.add_argument('--device', type=str, default='cuda:1',
                        help='torch device for the model and heads')
    parser.add_argument('--k', type=int, default=4,
                        help='number of support profile items per user')
    args = parser.parse_args()

    N = args.num_eval
    device = args.device  # single source of truth (was hard-coded twice)
    print(f"=== Shift Analysis: {args.config}, N={N} ===")

    print("Loading data...")
    examples = load_longlamp(args.config, split='val')[:N]

    print("Loading model...")
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=device)
    H = wrapper.hidden_size

    uph_head = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
    cvh_head = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device)

    # Some lm_heads have no bias; fit_theta accepts None in that case.
    lm_head_bias = None
    if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
        lm_head_bias = wrapper.model.lm_head.bias.data

    K = args.k
    shift_cosines = []
    uph_rouges = []
    cvh_rouges = []

    # Identical optimizer settings for both heads so the comparison is fair;
    # defined once to keep the two fits from drifting apart.
    fit_kwargs = dict(
        lm_head_weight=wrapper.lm_head_weight,
        lm_head_bias=lm_head_bias,
        d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
        max_grad_norm=5.0, device=device, verbose=False,
    )

    for i, ex in enumerate(examples):
        support = select_k_profile_items(ex['profile_items'], K, seed=0)
        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
        if not cached_h:
            # No usable support items; skip so all three lists stay aligned.
            continue

        # Mean support hidden state over all cached tokens.
        all_h = torch.cat([h for h, _ in cached_h], dim=0)
        support_mean = all_h.mean(dim=0).numpy()

        # Mean query hidden state.
        query_mean = get_query_hidden_mean(wrapper, ex['query_input'], ex['task'])

        # Cosine similarity (epsilon guards a zero-norm vector).
        cos = np.dot(support_mean, query_mean) / (
            np.linalg.norm(support_mean) * np.linalg.norm(query_mean) + 1e-8)
        shift_cosines.append(float(cos))

        # Fit both heads on the same support cache.
        theta_uph = fit_theta(cached_h=cached_h, head_module=uph_head, **fit_kwargs)
        theta_cvh = fit_theta(cached_h=cached_h, head_module=cvh_head, **fit_kwargs)

        # Generate with both heads under identical decoding settings.
        prompt = build_query_prompt(ex['query_input'], ex['task'])
        pred_uph = wrapper.generate_with_head_blended(
            prompt, theta_uph, uph_head.forward_fn,
            blend_gamma=0.5, max_new_tokens=512, min_new_tokens=128, temperature=0.0,
        )
        pred_cvh = wrapper.generate_with_head_blended(
            prompt, theta_cvh, cvh_head.forward_fn,
            blend_gamma=0.5, max_new_tokens=512, min_new_tokens=128, temperature=0.0,
        )

        # ROUGE-L against the user's reference output.
        uph_rouges.append(compute_rouge([pred_uph], [ex['target_output']])['rougeL'])
        cvh_rouges.append(compute_rouge([pred_cvh], [ex['target_output']])['rougeL'])

        # Free per-user tensors before the next iteration to cap GPU memory.
        del cached_h, theta_uph, theta_cvh
        torch.cuda.empty_cache()

        if (i + 1) % 20 == 0:
            print(f"  {i+1}/{N}")

    # Guard: spearmanr and the median split need at least two data points;
    # previously an all-skipped run crashed on empty arrays.
    if len(shift_cosines) < 2:
        print(f"\nOnly {len(shift_cosines)} usable users; "
              "need >= 2 for correlation analysis. Aborting.")
        return

    # Correlate alignment with the per-user performance gap.
    shift_cosines = np.array(shift_cosines)
    uph_arr = np.array(uph_rouges)
    cvh_arr = np.array(cvh_rouges)
    deltas = cvh_arr - uph_arr  # positive = CVH better

    rho, pval = stats.spearmanr(shift_cosines, deltas)

    print(f"\n=== Results (N={len(shift_cosines)}) ===")
    print(f"  Mean shift cosine: {shift_cosines.mean():.4f} +/- {shift_cosines.std():.4f}")
    print(f"  Mean delta (CVH - UPH): {deltas.mean():.4f} +/- {deltas.std():.4f}")
    print(f"  Spearman(shift_cos, delta): rho={rho:.4f}, p={pval:.4f}")
    print(f"  Mean UPH ROUGE-L: {uph_arr.mean():.4f}")
    print(f"  Mean CVH ROUGE-L: {cvh_arr.mean():.4f}")

    # Bin analysis: high- vs low-alignment users split at the median cosine.
    median_cos = np.median(shift_cosines)
    high_mask = shift_cosines >= median_cos
    low_mask = shift_cosines < median_cos

    print(f"\n  High-alignment (cos >= {median_cos:.3f}, n={high_mask.sum()}):")
    print(f"    UPH R-L: {uph_arr[high_mask].mean():.4f}")
    print(f"    CVH R-L: {cvh_arr[high_mask].mean():.4f}")
    print(f"  Low-alignment (cos < {median_cos:.3f}, n={low_mask.sum()}):")
    print(f"    UPH R-L: {uph_arr[low_mask].mean():.4f}")
    print(f"    CVH R-L: {cvh_arr[low_mask].mean():.4f}")

    # Persist raw per-user values plus the summary statistics.
    os.makedirs('outputs/analysis', exist_ok=True)
    save_data = {
        'shift_cosines': [float(x) for x in shift_cosines],
        'uph_rouges': [float(x) for x in uph_arr],
        'cvh_rouges': [float(x) for x in cvh_arr],
        'deltas': [float(x) for x in deltas],
        'spearman_rho': float(rho),
        'spearman_pval': float(pval),
    }
    with open('outputs/analysis/shift_analysis.json', 'w') as f:
        json.dump(save_data, f, indent=2)
    print("\nSaved to outputs/analysis/shift_analysis.json")


if __name__ == '__main__':
    main()