diff options
Diffstat (limited to 'scripts/theta_analysis.py')
| -rw-r--r-- | scripts/theta_analysis.py | 281 |
1 files changed, 281 insertions, 0 deletions
"""User-state geometry / representational alignment analysis.

Computes:
1. RSA: Spearman(cos(theta_u, theta_v), cos(phi_u, phi_v)) for all-style and -len/newline
2. Self-consistency: Delta_self = E_u[cos(theta_a, theta_b)] - E_{u!=v}[cos(theta_a, theta_v)]
3. Ridge probe: R^2 for predicting style features from theta

Raw thetas and style prototypes are saved to JSON for downstream visualization.
"""

import sys
import os
import json
import numpy as np
from scipy import stats
from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_val_score
import torch

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.longlamp import load_longlamp, select_k_profile_items
from data.style_features import extract_style_features, FEATURE_NAMES
from models.qwen_wrapper import QwenWrapper
from models.cvh import UnconditionalHead
from adapt.cache_hidden import cache_support_hidden_states
from adapt.fit_theta import fit_theta

# Default CUDA device; override via --device or the `device` keyword args below.
DEFAULT_DEVICE = 'cuda:1'


def _load_head_and_bias(wrapper, device):
    """Build the shared unconditional head and fetch the LM-head bias (if any).

    Returns (head_module, lm_head_bias) where lm_head_bias is None when the
    wrapped model's lm_head has no bias.
    """
    head = UnconditionalHead(wrapper.hidden_size, d=64, alpha=0.1, basis_seed=42).to(device)
    lm_head_bias = None
    if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
        lm_head_bias = wrapper.model.lm_head.bias.data
    return head, lm_head_bias


def _fit_user_theta(cached_h, wrapper, head, lm_head_bias, device):
    """Fit one user's theta from cached hidden states.

    Single place for the shared fit hyperparameters so the main collection
    pass and the self-consistency pass cannot drift apart.
    """
    return fit_theta(
        cached_h=cached_h,
        lm_head_weight=wrapper.lm_head_weight,
        lm_head_bias=lm_head_bias,
        head_module=head,
        d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
        max_grad_norm=5.0, device=device, verbose=False,
    )


def _row_normalize(X):
    """Return X with L2-normalized rows; norms are clamped at 1e-8 to avoid /0."""
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    return X / np.maximum(norms, 1e-8)


def collect_thetas_and_styles(wrapper, examples, K=4, seed=0, device=DEFAULT_DEVICE):
    """Collect theta_u and style-feature prototypes for all users.

    Args:
        wrapper: model wrapper exposing hidden_size / lm_head_weight / model.
        examples: iterable of LongLaMP examples with 'profile_items', 'task', 'user_id'.
        K: number of profile items per support set.
        seed: seed for support-item selection.
        device: torch device string for head construction and theta fitting.

    Returns:
        (thetas, style_protos, user_ids) — two (N, *) float arrays and a list
        of the N user ids that produced usable hidden-state caches.
    """
    head, lm_head_bias = _load_head_and_bias(wrapper, device)

    thetas = []
    style_protos = []
    user_ids = []

    for i, ex in enumerate(examples):
        support = select_k_profile_items(ex['profile_items'], K, seed=seed)
        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
        if not cached_h:
            # No usable hidden states for this user; skip rather than fit garbage.
            continue

        theta = _fit_user_theta(cached_h, wrapper, head, lm_head_bias, device)
        thetas.append(theta.cpu().numpy())

        # Style prototype: mean feature vector over the support outputs.
        support_texts = [s['support_output'] for s in support]
        features_list = [extract_style_features(t) for t in support_texts]
        style_protos.append(np.mean(features_list, axis=0))
        user_ids.append(ex['user_id'])

        # Release GPU memory between users; caches can be large.
        del cached_h, theta
        torch.cuda.empty_cache()

        if (i + 1) % 40 == 0:
            print(f"  Collected {i+1}/{len(examples)}")

    return np.array(thetas), np.array(style_protos), user_ids


def compute_rsa(thetas, style_protos, exclude_indices=None):
    """RSA: Spearman correlation between theta similarity and style similarity.

    Builds the pairwise cosine-similarity matrices for thetas and for style
    prototypes, then correlates their strict upper triangles.

    Args:
        thetas: (N, d) array of user states.
        style_protos: (N, F) array of style prototypes.
        exclude_indices: optional style-feature columns to drop before
            computing style similarity (e.g. length / newline rate).

    Returns:
        (rho, pval); (nan, nan) when there are fewer than 2 users, since no
        pair exists to correlate.
    """
    N = len(thetas)
    if N < 2:
        return float('nan'), float('nan')

    theta_sim = _row_normalize(thetas) @ _row_normalize(thetas).T

    if exclude_indices is not None:
        style = np.delete(style_protos, exclude_indices, axis=1)
    else:
        style = style_protos
    style_sim = _row_normalize(style) @ _row_normalize(style).T

    # Strict upper triangle: each unordered user pair exactly once.
    idx = np.triu_indices(N, k=1)
    rho, pval = stats.spearmanr(theta_sim[idx], style_sim[idx])
    return rho, pval


def compute_self_consistency(wrapper, examples, K=4, device=DEFAULT_DEVICE):
    """Self-consistency of theta under resampled support sets.

    For each user with >= 2K profile items, fits theta on two disjoint-seeded
    support subsets and compares within-user vs across-user cosine similarity.

    Returns:
        (avg_self, avg_cross, delta_self) where delta_self = avg_self - avg_cross;
        (0.0, 0.0, 0.0) when fewer than 5 users are valid.
    """
    head, lm_head_bias = _load_head_and_bias(wrapper, device)

    thetas_a = []
    thetas_b = []
    valid_indices = []

    for i, ex in enumerate(examples):
        profile = ex['profile_items']
        if len(profile) < 2 * K:
            # Not enough items to draw two K-sized subsets.
            continue

        # Two differently-seeded subsets (seeds chosen to differ; overlap is
        # possible depending on select_k_profile_items' sampling — acceptable
        # since we only need the subsets to be non-identical in expectation).
        support_a = select_k_profile_items(profile, K, seed=100)
        support_b = select_k_profile_items(profile, K, seed=200)

        cached_a = cache_support_hidden_states(wrapper, support_a, ex['task'])
        cached_b = cache_support_hidden_states(wrapper, support_b, ex['task'])
        if not cached_a or not cached_b:
            continue

        theta_a = _fit_user_theta(cached_a, wrapper, head, lm_head_bias, device)
        theta_b = _fit_user_theta(cached_b, wrapper, head, lm_head_bias, device)

        thetas_a.append(theta_a.cpu().numpy())
        thetas_b.append(theta_b.cpu().numpy())
        valid_indices.append(i)

        del cached_a, cached_b, theta_a, theta_b
        torch.cuda.empty_cache()

        if (i + 1) % 40 == 0:
            print(f"  Self-consistency: {i+1}/{len(examples)} ({len(valid_indices)} valid)")

    thetas_a = np.array(thetas_a)
    thetas_b = np.array(thetas_b)
    N = len(thetas_a)

    if N < 5:
        # Too few users for a meaningful estimate.
        return 0.0, 0.0, 0.0

    # One matrix multiply replaces the O(N^2) Python loops:
    # C[u, v] = cos(theta_a_u, theta_b_v).
    C = _row_normalize(thetas_a) @ _row_normalize(thetas_b).T
    avg_self = float(np.mean(np.diag(C)))                       # u == v pairs
    avg_cross = float((C.sum() - np.trace(C)) / (N * (N - 1)))  # u != v pairs

    delta_self = avg_self - avg_cross
    return avg_self, avg_cross, delta_self


def compute_ridge_probe(thetas, style_protos):
    """Probe: predict each style feature from theta using Ridge regression.

    Cross-validated R^2 per feature, clipped at 0 (negative R^2 means "worse
    than the mean predictor", which we report as no signal).

    Returns:
        dict mapping feature name -> R^2 in [0, 1]. All zeros when fewer than
        2 users are available (cross-validation needs cv >= 2).
    """
    N = len(thetas)
    if N < 2:
        return {name: 0.0 for name in FEATURE_NAMES}

    results = {}
    for i, feat_name in enumerate(FEATURE_NAMES):
        y = style_protos[:, i]

        # Constant targets make R^2 undefined; report no signal.
        if np.std(y) < 1e-8:
            results[feat_name] = 0.0
            continue

        ridge = Ridge(alpha=1.0)
        scores = cross_val_score(ridge, thetas, y, cv=min(5, N), scoring='r2')
        results[feat_name] = max(np.mean(scores), 0.0)  # Clip negative R2 to 0

    return results


def main():
    """Run the full analysis pipeline and save results to JSON."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_eval', type=int, default=200)
    parser.add_argument('--config', type=str, default='product_review_user')
    parser.add_argument('--device', type=str, default=DEFAULT_DEVICE)
    args = parser.parse_args()

    N = args.num_eval
    print(f"=== Theta Analysis: {args.config}, N={N} ===")

    print("\nLoading data...")
    examples = load_longlamp(args.config, split='val')[:N]
    print(f"Loaded {len(examples)} examples")

    print("\nLoading model...")
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=args.device)

    # 1. Collect thetas and style prototypes
    print("\n--- Collecting thetas and style prototypes ---")
    thetas, style_protos, user_ids = collect_thetas_and_styles(
        wrapper, examples, K=4, seed=0, device=args.device)
    print(f"Collected {len(thetas)} vectors")

    # 2. RSA
    print("\n--- RSA (Representational Similarity Analysis) ---")
    rho_all, pval_all = compute_rsa(thetas, style_protos)
    # Exclude length (index 0) and newline_rate (index 3)
    rho_nolen, pval_nolen = compute_rsa(thetas, style_protos, exclude_indices=[0, 3])
    print(f"  rho_all: {rho_all:.4f} (p={pval_all:.2e})")
    print(f"  rho_-len/newline: {rho_nolen:.4f} (p={pval_nolen:.2e})")

    # 3. Self-consistency
    print("\n--- Self-Consistency ---")
    avg_self, avg_cross, delta_self = compute_self_consistency(
        wrapper, examples, K=4, device=args.device)
    print(f"  avg_self_cos: {avg_self:.4f}")
    print(f"  avg_cross_cos: {avg_cross:.4f}")
    print(f"  Delta_self: {delta_self:.4f}")

    # 4. Ridge probe
    print("\n--- Ridge Probe (R^2) ---")
    probe_results = compute_ridge_probe(thetas, style_protos)
    for feat_name in FEATURE_NAMES:
        r2 = probe_results[feat_name]
        print(f"  {feat_name:<20}: R^2 = {r2:.4f}")

    # Summary: the 6 key numbers
    print("\n" + "=" * 60)
    print("KEY NUMBERS FOR PAPER DECISION")
    print("=" * 60)
    print(f"  rho_all: {rho_all:.4f}")
    print(f"  rho_-len/newline: {rho_nolen:.4f}")
    print(f"  Delta_self: {delta_self:.4f}")
    print(f"  R^2_TTR: {probe_results.get('TTR', 0.0):.4f}")
    print(f"  R^2_first_person: {probe_results.get('first_person_rate', 0.0):.4f}")
    print(f"  R^2_newline: {probe_results.get('newline_rate', 0.0):.4f}")

    # Save results (thetas/protos as nested lists so json can serialize them).
    os.makedirs('outputs/analysis', exist_ok=True)
    save_data = {
        'rsa_all': {'rho': float(rho_all), 'pval': float(pval_all)},
        'rsa_nolen': {'rho': float(rho_nolen), 'pval': float(pval_nolen)},
        'self_consistency': {'avg_self': float(avg_self), 'avg_cross': float(avg_cross), 'delta_self': float(delta_self)},
        'probe_r2': {k: float(v) for k, v in probe_results.items()},
        'num_users': len(thetas),
        'thetas': np.asarray(thetas).tolist(),
        'style_protos': np.asarray(style_protos).tolist(),
        'user_ids': user_ids,
    }
    with open('outputs/analysis/theta_analysis.json', 'w') as f:
        json.dump(save_data, f, indent=2)
    print("\nSaved to outputs/analysis/theta_analysis.json")


if __name__ == '__main__':
    main()
