summaryrefslogtreecommitdiff
path: root/scripts/theta_analysis.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/theta_analysis.py')
-rw-r--r--scripts/theta_analysis.py281
1 files changed, 281 insertions, 0 deletions
diff --git a/scripts/theta_analysis.py b/scripts/theta_analysis.py
new file mode 100644
index 0000000..94d4010
--- /dev/null
+++ b/scripts/theta_analysis.py
@@ -0,0 +1,281 @@
+"""User-state geometry / representational alignment analysis.
+
+Computes:
+1. RSA: Spearman(cos(theta_u, theta_v), cos(phi_u, phi_v)) for all-style and -len/newline
+2. Self-consistency: Delta_self = E_u[cos(theta_a, theta_b)] - E_{u!=v}[cos(theta_a, theta_v)]
+3. Ridge probe: R^2 for predicting style features from theta
+4. PCA visualization
+"""
+
+import sys
+import os
+import json
+import numpy as np
+from scipy import stats
+from sklearn.linear_model import Ridge
+from sklearn.model_selection import cross_val_score
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.style_features import extract_style_features, FEATURE_NAMES
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+
+
+def collect_thetas_and_styles(wrapper, examples, K=4, seed=0):
+ """Collect theta_u and style prototypes for all users."""
+ device = 'cuda:1'
+ H = wrapper.hidden_size
+ head = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+
+ lm_head_bias = None
+ if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+ lm_head_bias = wrapper.model.lm_head.bias.data
+
+ thetas = []
+ style_protos = []
+ user_ids = []
+
+ for i, ex in enumerate(examples):
+ support = select_k_profile_items(ex['profile_items'], K, seed=seed)
+ cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
+ if not cached_h:
+ continue
+
+ theta = fit_theta(
+ cached_h=cached_h,
+ lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=head,
+ d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+ max_grad_norm=5.0, device=device, verbose=False,
+ )
+
+ thetas.append(theta.cpu().numpy())
+
+ # Compute style prototype
+ support_texts = [s['support_output'] for s in support]
+ features_list = [extract_style_features(t) for t in support_texts]
+ proto = np.mean(features_list, axis=0)
+ style_protos.append(proto)
+ user_ids.append(ex['user_id'])
+
+ del cached_h, theta
+ torch.cuda.empty_cache()
+
+ if (i + 1) % 40 == 0:
+ print(f" Collected {i+1}/{len(examples)}")
+
+ return np.array(thetas), np.array(style_protos), user_ids
+
+
+def compute_rsa(thetas, style_protos, exclude_indices=None):
+ """Compute RSA: Spearman correlation between theta similarity and style similarity."""
+ N = len(thetas)
+
+ # Theta cosine similarity matrix
+ theta_norms = np.linalg.norm(thetas, axis=1, keepdims=True)
+ theta_norms = np.maximum(theta_norms, 1e-8)
+ theta_normed = thetas / theta_norms
+ theta_sim = theta_normed @ theta_normed.T
+
+ # Style cosine similarity matrix
+ if exclude_indices is not None:
+ style = np.delete(style_protos, exclude_indices, axis=1)
+ else:
+ style = style_protos.copy()
+
+ style_norms = np.linalg.norm(style, axis=1, keepdims=True)
+ style_norms = np.maximum(style_norms, 1e-8)
+ style_normed = style / style_norms
+ style_sim = style_normed @ style_normed.T
+
+ # Extract upper triangle
+ idx = np.triu_indices(N, k=1)
+ theta_upper = theta_sim[idx]
+ style_upper = style_sim[idx]
+
+ rho, pval = stats.spearmanr(theta_upper, style_upper)
+ return rho, pval
+
+
+def compute_self_consistency(wrapper, examples, K=4):
+ """Compute self-consistency by fitting theta with different support subsets."""
+ device = 'cuda:1'
+ H = wrapper.hidden_size
+ head = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+
+ lm_head_bias = None
+ if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+ lm_head_bias = wrapper.model.lm_head.bias.data
+
+ thetas_a = []
+ thetas_b = []
+ valid_indices = []
+
+ for i, ex in enumerate(examples):
+ profile = ex['profile_items']
+ if len(profile) < 2 * K:
+ continue
+
+ # Two different subsets
+ support_a = select_k_profile_items(profile, K, seed=100)
+ support_b = select_k_profile_items(profile, K, seed=200)
+
+ cached_a = cache_support_hidden_states(wrapper, support_a, ex['task'])
+ cached_b = cache_support_hidden_states(wrapper, support_b, ex['task'])
+
+ if not cached_a or not cached_b:
+ continue
+
+ theta_a = fit_theta(
+ cached_h=cached_a, lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias, head_module=head,
+ d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+ max_grad_norm=5.0, device=device, verbose=False,
+ )
+ theta_b = fit_theta(
+ cached_h=cached_b, lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias, head_module=head,
+ d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+ max_grad_norm=5.0, device=device, verbose=False,
+ )
+
+ thetas_a.append(theta_a.cpu().numpy())
+ thetas_b.append(theta_b.cpu().numpy())
+ valid_indices.append(i)
+
+ del cached_a, cached_b, theta_a, theta_b
+ torch.cuda.empty_cache()
+
+ if (i + 1) % 40 == 0:
+ print(f" Self-consistency: {i+1}/{len(examples)} ({len(valid_indices)} valid)")
+
+ thetas_a = np.array(thetas_a)
+ thetas_b = np.array(thetas_b)
+ N = len(thetas_a)
+
+ if N < 5:
+ return 0.0, 0.0, 0.0
+
+ # Self similarity: cos(theta_a_u, theta_b_u)
+ self_cos = []
+ for u in range(N):
+ cos = np.dot(thetas_a[u], thetas_b[u]) / (
+ np.linalg.norm(thetas_a[u]) * np.linalg.norm(thetas_b[u]) + 1e-8)
+ self_cos.append(cos)
+ avg_self = np.mean(self_cos)
+
+ # Cross similarity: cos(theta_a_u, theta_b_v) for u != v
+ cross_cos = []
+ for u in range(N):
+ for v in range(N):
+ if u == v:
+ continue
+ cos = np.dot(thetas_a[u], thetas_b[v]) / (
+ np.linalg.norm(thetas_a[u]) * np.linalg.norm(thetas_b[v]) + 1e-8)
+ cross_cos.append(cos)
+ avg_cross = np.mean(cross_cos)
+
+ delta_self = avg_self - avg_cross
+ return avg_self, avg_cross, delta_self
+
+
+def compute_ridge_probe(thetas, style_protos):
+ """Probe: predict each style feature from theta using Ridge regression."""
+ results = {}
+ N = len(thetas)
+
+ for i, feat_name in enumerate(FEATURE_NAMES):
+ y = style_protos[:, i]
+
+ # Check if target has variance
+ if np.std(y) < 1e-8:
+ results[feat_name] = 0.0
+ continue
+
+ ridge = Ridge(alpha=1.0)
+ scores = cross_val_score(ridge, thetas, y, cv=min(5, N), scoring='r2')
+ results[feat_name] = max(np.mean(scores), 0.0) # Clip negative R2 to 0
+
+ return results
+
+
+def main():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--num_eval', type=int, default=200)
+ parser.add_argument('--config', type=str, default='product_review_user')
+ args = parser.parse_args()
+
+ N = args.num_eval
+ print(f"=== Theta Analysis: {args.config}, N={N} ===")
+
+ print("\nLoading data...")
+ examples = load_longlamp(args.config, split='val')[:N]
+ print(f"Loaded {len(examples)} examples")
+
+ print("\nLoading model...")
+ wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
+
+ # 1. Collect thetas and style prototypes
+ print("\n--- Collecting thetas and style prototypes ---")
+ thetas, style_protos, user_ids = collect_thetas_and_styles(wrapper, examples, K=4, seed=0)
+ print(f"Collected {len(thetas)} vectors")
+
+ # 2. RSA
+ print("\n--- RSA (Representational Similarity Analysis) ---")
+ rho_all, pval_all = compute_rsa(thetas, style_protos)
+ # Exclude length (index 0) and newline_rate (index 3)
+ rho_nolen, pval_nolen = compute_rsa(thetas, style_protos, exclude_indices=[0, 3])
+ print(f" rho_all: {rho_all:.4f} (p={pval_all:.2e})")
+ print(f" rho_-len/newline: {rho_nolen:.4f} (p={pval_nolen:.2e})")
+
+ # 3. Self-consistency
+ print("\n--- Self-Consistency ---")
+ avg_self, avg_cross, delta_self = compute_self_consistency(wrapper, examples, K=4)
+ print(f" avg_self_cos: {avg_self:.4f}")
+ print(f" avg_cross_cos: {avg_cross:.4f}")
+ print(f" Delta_self: {delta_self:.4f}")
+
+ # 4. Ridge probe
+ print("\n--- Ridge Probe (R^2) ---")
+ probe_results = compute_ridge_probe(thetas, style_protos)
+ for feat_name in FEATURE_NAMES:
+ r2 = probe_results[feat_name]
+ print(f" {feat_name:<20}: R^2 = {r2:.4f}")
+
+ # Summary: the 6 key numbers
+ print("\n" + "=" * 60)
+ print("KEY NUMBERS FOR PAPER DECISION")
+ print("=" * 60)
+ print(f" rho_all: {rho_all:.4f}")
+ print(f" rho_-len/newline: {rho_nolen:.4f}")
+ print(f" Delta_self: {delta_self:.4f}")
+ print(f" R^2_TTR: {probe_results.get('TTR', 0.0):.4f}")
+ print(f" R^2_first_person: {probe_results.get('first_person_rate', 0.0):.4f}")
+ print(f" R^2_newline: {probe_results.get('newline_rate', 0.0):.4f}")
+
+ # Save results
+ os.makedirs('outputs/analysis', exist_ok=True)
+ save_data = {
+ 'rsa_all': {'rho': float(rho_all), 'pval': float(pval_all)},
+ 'rsa_nolen': {'rho': float(rho_nolen), 'pval': float(pval_nolen)},
+ 'self_consistency': {'avg_self': float(avg_self), 'avg_cross': float(avg_cross), 'delta_self': float(delta_self)},
+ 'probe_r2': {k: float(v) for k, v in probe_results.items()},
+ 'num_users': len(thetas),
+ 'thetas': [[float(x) for x in row] for row in thetas],
+ 'style_protos': [[float(x) for x in row] for row in style_protos],
+ 'user_ids': user_ids,
+ }
+ with open('outputs/analysis/theta_analysis.json', 'w') as f:
+ json.dump(save_data, f, indent=2)
+ print("\nSaved to outputs/analysis/theta_analysis.json")
+
+
+if __name__ == '__main__':
+ main()