"""Support-query distribution shift analysis. For each user, compute: s_u = cos(mean_support_hidden, mean_query_hidden) Then correlate with CVH-UPH performance gap: delta_u = ROUGE-L(CVH, u) - ROUGE-L(UPH, u) If correlation is positive: CVH benefits when support-query are aligned. """ import sys import os import json import numpy as np from scipy import stats import torch sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from data.longlamp import load_longlamp, select_k_profile_items from data.templates import build_query_prompt, build_support_prompt from models.qwen_wrapper import QwenWrapper from models.cvh import CVHHead, UnconditionalHead from adapt.cache_hidden import cache_support_hidden_states from adapt.fit_theta import fit_theta from eval.metrics import compute_rouge def get_query_hidden_mean(wrapper, query_text, task): """Get mean hidden state from the query prompt.""" chat_messages = [ {"role": "system", "content": "You are a helpful writing assistant."}, {"role": "user", "content": build_query_prompt(query_text, task)}, ] prompt_text = wrapper.tokenizer.apply_chat_template( chat_messages, tokenize=False, add_generation_prompt=True ) input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device) with torch.no_grad(): outputs = wrapper.model( input_ids=input_ids, output_hidden_states=True, return_dict=True, ) last_hidden = outputs.hidden_states[-1][0] # (seq_len, H) return last_hidden.mean(dim=0).cpu().float().numpy() def main(): import argparse parser = argparse.ArgumentParser() parser.add_argument('--num_eval', type=int, default=100) parser.add_argument('--config', type=str, default='product_review_user') args = parser.parse_args() N = args.num_eval print(f"=== Shift Analysis: {args.config}, N={N} ===") print("Loading data...") examples = load_longlamp(args.config, split='val')[:N] print("Loading model...") wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1') H = wrapper.hidden_size device = 'cuda:1' uph_head = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device) cvh_head = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device) lm_head_bias = None if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None: lm_head_bias = wrapper.model.lm_head.bias.data K = 4 shift_cosines = [] uph_rouges = [] cvh_rouges = [] for i, ex in enumerate(examples): support = select_k_profile_items(ex['profile_items'], K, seed=0) cached_h = cache_support_hidden_states(wrapper, support, ex['task']) if not cached_h: continue # Mean support hidden all_h = torch.cat([h for h, _ in cached_h], dim=0) support_mean = all_h.mean(dim=0).numpy() # Mean query hidden query_mean = get_query_hidden_mean(wrapper, ex['query_input'], ex['task']) # Cosine similarity cos = np.dot(support_mean, query_mean) / ( np.linalg.norm(support_mean) * np.linalg.norm(query_mean) + 1e-8) shift_cosines.append(float(cos)) # Fit UPH theta theta_uph = fit_theta( cached_h=cached_h, lm_head_weight=wrapper.lm_head_weight, lm_head_bias=lm_head_bias, head_module=uph_head, d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4, max_grad_norm=5.0, device=device, verbose=False, ) # Fit CVH theta theta_cvh = fit_theta( cached_h=cached_h, lm_head_weight=wrapper.lm_head_weight, lm_head_bias=lm_head_bias, head_module=cvh_head, d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4, max_grad_norm=5.0, device=device, verbose=False, ) # Generate with both prompt = build_query_prompt(ex['query_input'], ex['task']) pred_uph = wrapper.generate_with_head_blended( prompt, theta_uph, uph_head.forward_fn, blend_gamma=0.5, max_new_tokens=512, min_new_tokens=128, temperature=0.0, ) pred_cvh = wrapper.generate_with_head_blended( prompt, theta_cvh, cvh_head.forward_fn, blend_gamma=0.5, max_new_tokens=512, min_new_tokens=128, temperature=0.0, ) # ROUGE-L for each rouge_uph = compute_rouge([pred_uph], [ex['target_output']])['rougeL'] rouge_cvh = compute_rouge([pred_cvh], [ex['target_output']])['rougeL'] uph_rouges.append(rouge_uph) cvh_rouges.append(rouge_cvh) del cached_h, theta_uph, theta_cvh torch.cuda.empty_cache() if (i + 1) % 20 == 0: print(f" {i+1}/{N}") # Compute correlation shift_cosines = np.array(shift_cosines) deltas = np.array(cvh_rouges) - np.array(uph_rouges) # positive = CVH better rho, pval = stats.spearmanr(shift_cosines, deltas) print(f"\n=== Results (N={len(shift_cosines)}) ===") print(f" Mean shift cosine: {shift_cosines.mean():.4f} +/- {shift_cosines.std():.4f}") print(f" Mean delta (CVH - UPH): {deltas.mean():.4f} +/- {deltas.std():.4f}") print(f" Spearman(shift_cos, delta): rho={rho:.4f}, p={pval:.4f}") print(f" Mean UPH ROUGE-L: {np.mean(uph_rouges):.4f}") print(f" Mean CVH ROUGE-L: {np.mean(cvh_rouges):.4f}") # Bin analysis: high vs low shift median_cos = np.median(shift_cosines) high_mask = shift_cosines >= median_cos low_mask = shift_cosines < median_cos print(f"\n High-alignment (cos >= {median_cos:.3f}, n={high_mask.sum()}):") print(f" UPH R-L: {np.mean(np.array(uph_rouges)[high_mask]):.4f}") print(f" CVH R-L: {np.mean(np.array(cvh_rouges)[high_mask]):.4f}") print(f" Low-alignment (cos < {median_cos:.3f}, n={low_mask.sum()}):") print(f" UPH R-L: {np.mean(np.array(uph_rouges)[low_mask]):.4f}") print(f" CVH R-L: {np.mean(np.array(cvh_rouges)[low_mask]):.4f}") # Save os.makedirs('outputs/analysis', exist_ok=True) save_data = { 'shift_cosines': [float(x) for x in shift_cosines], 'uph_rouges': [float(x) for x in uph_rouges], 'cvh_rouges': [float(x) for x in cvh_rouges], 'deltas': [float(x) for x in deltas], 'spearman_rho': float(rho), 'spearman_pval': float(pval), } with open('outputs/analysis/shift_analysis.json', 'w') as f: json.dump(save_data, f, indent=2) print("\nSaved to outputs/analysis/shift_analysis.json") if __name__ == '__main__': main()