"""User-state geometry / representational alignment analysis. Computes: 1. RSA: Spearman(cos(theta_u, theta_v), cos(phi_u, phi_v)) for all-style and -len/newline 2. Self-consistency: Delta_self = E_u[cos(theta_a, theta_b)] - E_{u!=v}[cos(theta_a, theta_v)] 3. Ridge probe: R^2 for predicting style features from theta 4. PCA visualization """ import sys import os import json import numpy as np from scipy import stats from sklearn.linear_model import Ridge from sklearn.model_selection import cross_val_score import torch sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from data.longlamp import load_longlamp, select_k_profile_items from data.style_features import extract_style_features, FEATURE_NAMES from models.qwen_wrapper import QwenWrapper from models.cvh import UnconditionalHead from adapt.cache_hidden import cache_support_hidden_states from adapt.fit_theta import fit_theta def collect_thetas_and_styles(wrapper, examples, K=4, seed=0): """Collect theta_u and style prototypes for all users.""" device = 'cuda:1' H = wrapper.hidden_size head = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device) lm_head_bias = None if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None: lm_head_bias = wrapper.model.lm_head.bias.data thetas = [] style_protos = [] user_ids = [] for i, ex in enumerate(examples): support = select_k_profile_items(ex['profile_items'], K, seed=seed) cached_h = cache_support_hidden_states(wrapper, support, ex['task']) if not cached_h: continue theta = fit_theta( cached_h=cached_h, lm_head_weight=wrapper.lm_head_weight, lm_head_bias=lm_head_bias, head_module=head, d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4, max_grad_norm=5.0, device=device, verbose=False, ) thetas.append(theta.cpu().numpy()) # Compute style prototype support_texts = [s['support_output'] for s in support] features_list = [extract_style_features(t) for t in support_texts] proto = np.mean(features_list, axis=0) style_protos.append(proto) user_ids.append(ex['user_id']) del cached_h, theta torch.cuda.empty_cache() if (i + 1) % 40 == 0: print(f" Collected {i+1}/{len(examples)}") return np.array(thetas), np.array(style_protos), user_ids def compute_rsa(thetas, style_protos, exclude_indices=None): """Compute RSA: Spearman correlation between theta similarity and style similarity.""" N = len(thetas) # Theta cosine similarity matrix theta_norms = np.linalg.norm(thetas, axis=1, keepdims=True) theta_norms = np.maximum(theta_norms, 1e-8) theta_normed = thetas / theta_norms theta_sim = theta_normed @ theta_normed.T # Style cosine similarity matrix if exclude_indices is not None: style = np.delete(style_protos, exclude_indices, axis=1) else: style = style_protos.copy() style_norms = np.linalg.norm(style, axis=1, keepdims=True) style_norms = np.maximum(style_norms, 1e-8) style_normed = style / style_norms style_sim = style_normed @ style_normed.T # Extract upper triangle idx = np.triu_indices(N, k=1) theta_upper = theta_sim[idx] style_upper = style_sim[idx] rho, pval = stats.spearmanr(theta_upper, style_upper) return rho, pval def compute_self_consistency(wrapper, examples, K=4): """Compute self-consistency by fitting theta with different support subsets.""" device = 'cuda:1' H = wrapper.hidden_size head = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device) lm_head_bias = None if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None: lm_head_bias = wrapper.model.lm_head.bias.data thetas_a = [] thetas_b = [] valid_indices = [] for i, ex in enumerate(examples): profile = ex['profile_items'] if len(profile) < 2 * K: continue # Two different subsets support_a = select_k_profile_items(profile, K, seed=100) support_b = select_k_profile_items(profile, K, seed=200) cached_a = cache_support_hidden_states(wrapper, support_a, ex['task']) cached_b = cache_support_hidden_states(wrapper, support_b, ex['task']) if not cached_a or not cached_b: continue theta_a = fit_theta( cached_h=cached_a, lm_head_weight=wrapper.lm_head_weight, lm_head_bias=lm_head_bias, head_module=head, d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4, max_grad_norm=5.0, device=device, verbose=False, ) theta_b = fit_theta( cached_h=cached_b, lm_head_weight=wrapper.lm_head_weight, lm_head_bias=lm_head_bias, head_module=head, d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4, max_grad_norm=5.0, device=device, verbose=False, ) thetas_a.append(theta_a.cpu().numpy()) thetas_b.append(theta_b.cpu().numpy()) valid_indices.append(i) del cached_a, cached_b, theta_a, theta_b torch.cuda.empty_cache() if (i + 1) % 40 == 0: print(f" Self-consistency: {i+1}/{len(examples)} ({len(valid_indices)} valid)") thetas_a = np.array(thetas_a) thetas_b = np.array(thetas_b) N = len(thetas_a) if N < 5: return 0.0, 0.0, 0.0 # Self similarity: cos(theta_a_u, theta_b_u) self_cos = [] for u in range(N): cos = np.dot(thetas_a[u], thetas_b[u]) / ( np.linalg.norm(thetas_a[u]) * np.linalg.norm(thetas_b[u]) + 1e-8) self_cos.append(cos) avg_self = np.mean(self_cos) # Cross similarity: cos(theta_a_u, theta_b_v) for u != v cross_cos = [] for u in range(N): for v in range(N): if u == v: continue cos = np.dot(thetas_a[u], thetas_b[v]) / ( np.linalg.norm(thetas_a[u]) * np.linalg.norm(thetas_b[v]) + 1e-8) cross_cos.append(cos) avg_cross = np.mean(cross_cos) delta_self = avg_self - avg_cross return avg_self, avg_cross, delta_self def compute_ridge_probe(thetas, style_protos): """Probe: predict each style feature from theta using Ridge regression.""" results = {} N = len(thetas) for i, feat_name in enumerate(FEATURE_NAMES): y = style_protos[:, i] # Check if target has variance if np.std(y) < 1e-8: results[feat_name] = 0.0 continue ridge = Ridge(alpha=1.0) scores = cross_val_score(ridge, thetas, y, cv=min(5, N), scoring='r2') results[feat_name] = max(np.mean(scores), 0.0) # Clip negative R2 to 0 return results def main(): import argparse parser = argparse.ArgumentParser() parser.add_argument('--num_eval', type=int, default=200) parser.add_argument('--config', type=str, default='product_review_user') args = parser.parse_args() N = args.num_eval print(f"=== Theta Analysis: {args.config}, N={N} ===") print("\nLoading data...") examples = load_longlamp(args.config, split='val')[:N] print(f"Loaded {len(examples)} examples") print("\nLoading model...") wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1') # 1. Collect thetas and style prototypes print("\n--- Collecting thetas and style prototypes ---") thetas, style_protos, user_ids = collect_thetas_and_styles(wrapper, examples, K=4, seed=0) print(f"Collected {len(thetas)} vectors") # 2. RSA print("\n--- RSA (Representational Similarity Analysis) ---") rho_all, pval_all = compute_rsa(thetas, style_protos) # Exclude length (index 0) and newline_rate (index 3) rho_nolen, pval_nolen = compute_rsa(thetas, style_protos, exclude_indices=[0, 3]) print(f" rho_all: {rho_all:.4f} (p={pval_all:.2e})") print(f" rho_-len/newline: {rho_nolen:.4f} (p={pval_nolen:.2e})") # 3. Self-consistency print("\n--- Self-Consistency ---") avg_self, avg_cross, delta_self = compute_self_consistency(wrapper, examples, K=4) print(f" avg_self_cos: {avg_self:.4f}") print(f" avg_cross_cos: {avg_cross:.4f}") print(f" Delta_self: {delta_self:.4f}") # 4. Ridge probe print("\n--- Ridge Probe (R^2) ---") probe_results = compute_ridge_probe(thetas, style_protos) for feat_name in FEATURE_NAMES: r2 = probe_results[feat_name] print(f" {feat_name:<20}: R^2 = {r2:.4f}") # Summary: the 6 key numbers print("\n" + "=" * 60) print("KEY NUMBERS FOR PAPER DECISION") print("=" * 60) print(f" rho_all: {rho_all:.4f}") print(f" rho_-len/newline: {rho_nolen:.4f}") print(f" Delta_self: {delta_self:.4f}") print(f" R^2_TTR: {probe_results.get('TTR', 0.0):.4f}") print(f" R^2_first_person: {probe_results.get('first_person_rate', 0.0):.4f}") print(f" R^2_newline: {probe_results.get('newline_rate', 0.0):.4f}") # Save results os.makedirs('outputs/analysis', exist_ok=True) save_data = { 'rsa_all': {'rho': float(rho_all), 'pval': float(pval_all)}, 'rsa_nolen': {'rho': float(rho_nolen), 'pval': float(pval_nolen)}, 'self_consistency': {'avg_self': float(avg_self), 'avg_cross': float(avg_cross), 'delta_self': float(delta_self)}, 'probe_r2': {k: float(v) for k, v in probe_results.items()}, 'num_users': len(thetas), 'thetas': [[float(x) for x in row] for row in thetas], 'style_protos': [[float(x) for x in row] for row in style_protos], 'user_ids': user_ids, } with open('outputs/analysis/theta_analysis.json', 'w') as f: json.dump(save_data, f, indent=2) print("\nSaved to outputs/analysis/theta_analysis.json") if __name__ == '__main__': main()