summaryrefslogtreecommitdiff
path: root/scripts/theta_analysis.py
blob: 94d4010cb83c3b540139e6a2edfdc96af918a43e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
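"""User-state geometry / representational alignment analysis.

Computes:
1. RSA: Spearman(cos(theta_u, theta_v), cos(phi_u, phi_v)) over all user pairs, once with
   all style features and once with the length/newline features excluded
2. Self-consistency: Delta_self = E_u[cos(theta_a_u, theta_b_u)] - E_{u!=v}[cos(theta_a_u, theta_b_v)],
   where theta_a and theta_b are fit on two different support subsets of the same user
3. Ridge probe: cross-validated R^2 for predicting each style feature from theta
4. PCA visualization (not implemented in this script; the saved thetas and style
   prototypes in the output JSON can be used for it)
"""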

import sys
import os
import json
import numpy as np
from scipy import stats
from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_val_score
import torch

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.longlamp import load_longlamp, select_k_profile_items
from data.style_features import extract_style_features, FEATURE_NAMES
from models.qwen_wrapper import QwenWrapper
from models.cvh import UnconditionalHead
from adapt.cache_hidden import cache_support_hidden_states
from adapt.fit_theta import fit_theta


def collect_thetas_and_styles(wrapper, examples, K=4, seed=0, device='cuda:1'):
    """Fit a theta vector and compute a style prototype for every usable user.

    For each example, K support items are drawn from ``ex['profile_items']``
    with ``select_k_profile_items(..., seed=seed)``, their hidden states are
    cached, and ``fit_theta`` is run against a single shared
    ``UnconditionalHead``. The style prototype is the element-wise mean of
    ``extract_style_features`` over the support items' ``'support_output'``
    texts. Examples whose cached hidden states come back empty are skipped.

    Args:
        wrapper: model wrapper exposing ``hidden_size``, ``lm_head_weight`` and
            ``model.lm_head``.
        examples: example dicts with ``'profile_items'``, ``'task'`` and
            ``'user_id'`` keys.
        K: number of support items per user.
        seed: seed forwarded to ``select_k_profile_items``.
        device: torch device for the head and the theta fit. Defaults to the
            previously hard-coded ``'cuda:1'``; presumably it should match the
            device the wrapper was loaded on (TODO confirm).

    Returns:
        ``(thetas, style_protos, user_ids)``: ``np.array`` of the stacked
        thetas, ``np.array`` of the stacked style prototypes, and the list of
        user ids, all aligned by position.
    """
    head = UnconditionalHead(wrapper.hidden_size, d=64, alpha=0.1, basis_seed=42).to(device)

    # Only forward a bias when the LM head actually has one.
    lm_head = wrapper.model.lm_head
    lm_head_bias = None
    if getattr(lm_head, 'bias', None) is not None:
        lm_head_bias = lm_head.bias.data

    thetas = []
    style_protos = []
    user_ids = []

    for i, ex in enumerate(examples):
        support = select_k_profile_items(ex['profile_items'], K, seed=seed)
        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
        if not cached_h:
            continue

        theta = fit_theta(
            cached_h=cached_h,
            lm_head_weight=wrapper.lm_head_weight,
            lm_head_bias=lm_head_bias,
            head_module=head,
            d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
            max_grad_norm=5.0, device=device, verbose=False,
        )

        # detach() is a no-op on tensors that don't track gradients, and it
        # prevents .numpy() from raising if fit_theta ever returns one that does.
        thetas.append(theta.detach().cpu().numpy())

        # Style prototype = mean feature vector over this user's support outputs.
        support_texts = [s['support_output'] for s in support]
        features_list = [extract_style_features(t) for t in support_texts]
        style_protos.append(np.mean(features_list, axis=0))
        user_ids.append(ex['user_id'])

        # Free GPU memory held by the cached activations before the next user.
        del cached_h, theta
        torch.cuda.empty_cache()

        if (i + 1) % 40 == 0:
            print(f"  Collected {i+1}/{len(examples)}")

    return np.array(thetas), np.array(style_protos), user_ids


def compute_rsa(thetas, style_protos, exclude_indices=None):
    """Spearman RSA between theta geometry and style geometry.

    Builds the pairwise cosine-similarity matrix of the theta rows and of the
    style-prototype rows. If ``exclude_indices`` is given, those style columns
    are dropped first. The two matrices' strict upper triangles (one entry per
    unordered user pair) are then rank-correlated with
    ``scipy.stats.spearmanr``.

    Args:
        thetas: 2-D array, one row per user.
        style_protos: 2-D array of style features, rows aligned with ``thetas``.
        exclude_indices: optional column indices of ``style_protos`` to remove
            before computing style similarity.

    Returns:
        ``(rho, pval)`` as reported by ``scipy.stats.spearmanr``.
    """
    def _pairwise_cosine(mat):
        # Floor row norms at 1e-8 so an all-zero row does not divide by zero.
        lengths = np.maximum(np.linalg.norm(mat, axis=1, keepdims=True), 1e-8)
        unit = mat / lengths
        return unit @ unit.T

    if exclude_indices is None:
        style = style_protos
    else:
        style = np.delete(style_protos, exclude_indices, axis=1)

    rows, cols = np.triu_indices(len(thetas), k=1)
    theta_pairs = _pairwise_cosine(thetas)[rows, cols]
    style_pairs = _pairwise_cosine(style)[rows, cols]

    rho, pval = stats.spearmanr(theta_pairs, style_pairs)
    return rho, pval


def compute_self_consistency(wrapper, examples, K=4, device='cuda:1', split_seed=0):
    """Measure how stable a user's theta is across different support subsets.

    Only examples with at least ``2 * K`` profile items are used. For each one:
    the profile is split into two disjoint halves by a seeded permutation of
    item positions; ``select_k_profile_items`` draws K items from each half
    (seeds 100 and 200); and a theta is fit on each subset.

    Self similarity is the mean cosine between a user's two thetas. Cross
    similarity is the mean cosine between theta_a of user u and theta_b of a
    different user v.

    Args:
        wrapper: model wrapper exposing ``hidden_size``, ``lm_head_weight`` and
            ``model.lm_head``.
        examples: example dicts with ``'profile_items'`` and ``'task'`` keys.
        K: number of support items per subset.
        device: torch device for the head and the theta fit. Defaults to the
            previously hard-coded ``'cuda:1'``.
        split_seed: base seed for the per-example disjoint split. Each
            example's permutation is seeded with ``[split_seed, index]``.

    Returns:
        ``(avg_self, avg_cross, delta_self)`` with
        ``delta_self = avg_self - avg_cross``. Returns ``(0.0, 0.0, 0.0)`` when
        fewer than 5 users yield both thetas.
    """
    head = UnconditionalHead(wrapper.hidden_size, d=64, alpha=0.1, basis_seed=42).to(device)

    # Only forward a bias when the LM head actually has one.
    lm_head = wrapper.model.lm_head
    lm_head_bias = None
    if getattr(lm_head, 'bias', None) is not None:
        lm_head_bias = lm_head.bias.data

    # Identical fitting hyper-parameters for both subsets.
    fit_kwargs = dict(
        lm_head_weight=wrapper.lm_head_weight,
        lm_head_bias=lm_head_bias,
        head_module=head,
        d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
        max_grad_norm=5.0, device=device, verbose=False,
    )

    thetas_a = []
    thetas_b = []
    valid_indices = []

    for i, ex in enumerate(examples):
        profile = ex['profile_items']
        if len(profile) < 2 * K:
            continue

        # Split into disjoint halves *before* sampling. Two independent draws
        # from the full profile can overlap or even coincide, which would
        # inflate self-similarity. Each half has >= K items because
        # len(profile) >= 2 * K.
        order = np.random.default_rng([split_seed, i]).permutation(len(profile))
        half = len(profile) // 2
        half_a = [profile[j] for j in order[:half]]
        half_b = [profile[j] for j in order[half:]]
        support_a = select_k_profile_items(half_a, K, seed=100)
        support_b = select_k_profile_items(half_b, K, seed=200)

        cached_a = cache_support_hidden_states(wrapper, support_a, ex['task'])
        cached_b = cache_support_hidden_states(wrapper, support_b, ex['task'])
        if not cached_a or not cached_b:
            continue

        theta_a = fit_theta(cached_h=cached_a, **fit_kwargs)
        theta_b = fit_theta(cached_h=cached_b, **fit_kwargs)

        thetas_a.append(theta_a.detach().cpu().numpy())
        thetas_b.append(theta_b.detach().cpu().numpy())
        valid_indices.append(i)

        # Free GPU memory held by the cached activations before the next user.
        del cached_a, cached_b, theta_a, theta_b
        torch.cuda.empty_cache()

        if (i + 1) % 40 == 0:
            print(f"  Self-consistency: {i+1}/{len(examples)} ({len(valid_indices)} valid)")

    N = len(thetas_a)
    if N < 5:
        return 0.0, 0.0, 0.0

    # Stack the thetas into (N, d) matrices; the original per-pair np.dot usage
    # implies each theta is a 1-D vector.
    A = np.asarray(thetas_a)
    B = np.asarray(thetas_b)

    # cos[u, v] = <a_u, b_v> / (|a_u| * |b_v| + 1e-8), the same epsilon placement
    # as the per-pair formula, computed for all pairs at once.
    denom = np.outer(np.linalg.norm(A, axis=1), np.linalg.norm(B, axis=1)) + 1e-8
    cos = (A @ B.T) / denom

    avg_self = float(np.mean(np.diag(cos)))
    avg_cross = float(np.mean(cos[~np.eye(N, dtype=bool)]))

    return avg_self, avg_cross, avg_self - avg_cross


def compute_ridge_probe(thetas, style_protos, feature_names=None):
    """Cross-validated Ridge probe: how well does theta linearly predict each style feature?

    Column ``i`` of ``style_protos`` is paired with ``feature_names[i]``. For
    each column, ``Ridge(alpha=1.0)`` is fit on ``thetas`` using
    ``cross_val_score`` with ``min(5, N)`` folds and R^2 scoring. The
    reported value is the mean of the finite fold scores, clipped below at 0.

    Args:
        thetas: 2-D array, one row per user.
        style_protos: 2-D array of style features, rows aligned with ``thetas``.
        feature_names: names for the style columns, in column order. Defaults
            to ``FEATURE_NAMES``.

    Returns:
        dict mapping feature name -> non-negative R^2. A feature gets 0.0 when:
        its target is (near-)constant; there are fewer than 2 users (k-fold CV
        needs at least two samples); or no fold produced a finite score (e.g.
        single-sample test folds, where R^2 is undefined).
    """
    if feature_names is None:
        feature_names = FEATURE_NAMES

    n_users = len(thetas)
    if n_users < 2:
        # No train/test split is possible; report "no signal" instead of crashing.
        return {feat_name: 0.0 for feat_name in feature_names}

    results = {}
    for i, feat_name in enumerate(feature_names):
        y = style_protos[:, i]

        # A constant target has nothing to predict (and R^2 is undefined).
        if np.std(y) < 1e-8:
            results[feat_name] = 0.0
            continue

        scores = cross_val_score(
            Ridge(alpha=1.0), thetas, y, cv=min(5, n_users), scoring='r2')

        # Drop NaN fold scores: max(nan, 0.0) is nan, which would leak into the
        # printed table and the JSON output.
        finite = scores[np.isfinite(scores)]
        mean_r2 = float(np.mean(finite)) if finite.size else 0.0
        results[feat_name] = max(mean_r2, 0.0)  # Clip negative R2 to 0

    return results


def main():
    """Command-line entry point: run every analysis and save the results as JSON.

    Flags:
        --num_eval  number of validation examples to analyse (default 200)
        --config    LongLaMP config name (default 'product_review_user')

    Steps:
        1. Load the validation split and Qwen2.5-1.5B-Instruct on 'cuda:1'.
        2. Run RSA, self-consistency and the Ridge probe, and print a summary.
        3. Write every number, plus the raw thetas, style prototypes and user
           ids, to outputs/analysis/theta_analysis.json.
    """
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--num_eval', type=int, default=200)
    cli.add_argument('--config', type=str, default='product_review_user')
    opts = cli.parse_args()

    n_eval = opts.num_eval
    print(f"=== Theta Analysis: {opts.config}, N={n_eval} ===")

    print("\nLoading data...")
    examples = load_longlamp(opts.config, split='val')[:n_eval]
    print(f"Loaded {len(examples)} examples")

    print("\nLoading model...")
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')

    # 1. One theta and one style prototype per usable user.
    print("\n--- Collecting thetas and style prototypes ---")
    thetas, style_protos, user_ids = collect_thetas_and_styles(wrapper, examples, K=4, seed=0)
    print(f"Collected {len(thetas)} vectors")

    # 2. RSA on all style features, then with two columns removed.
    # NOTE(review): indices 0 and 3 are assumed to be length and newline_rate
    # in FEATURE_NAMES order — confirm against data/style_features.py.
    print("\n--- RSA (Representational Similarity Analysis) ---")
    rho_all, pval_all = compute_rsa(thetas, style_protos)
    rho_nolen, pval_nolen = compute_rsa(thetas, style_protos, exclude_indices=[0, 3])
    print(f"  rho_all:           {rho_all:.4f} (p={pval_all:.2e})")
    print(f"  rho_-len/newline:  {rho_nolen:.4f} (p={pval_nolen:.2e})")

    # 3. Stability of theta across two support subsets of the same user.
    print("\n--- Self-Consistency ---")
    avg_self, avg_cross, delta_self = compute_self_consistency(wrapper, examples, K=4)
    for line in (
        f"  avg_self_cos:  {avg_self:.4f}",
        f"  avg_cross_cos: {avg_cross:.4f}",
        f"  Delta_self:    {delta_self:.4f}",
    ):
        print(line)

    # 4. Linear decodability of each style feature from theta.
    print("\n--- Ridge Probe (R^2) ---")
    probe_results = compute_ridge_probe(thetas, style_protos)
    for feat_name in FEATURE_NAMES:
        print(f"  {feat_name:<20}: R^2 = {probe_results[feat_name]:.4f}")

    # Summary: the 6 key numbers. Missing probe keys fall back to 0.0.
    rule = "=" * 60
    print("\n" + rule)
    print("KEY NUMBERS FOR PAPER DECISION")
    print(rule)
    key_numbers = [
        ("  rho_all:              ", rho_all),
        ("  rho_-len/newline:     ", rho_nolen),
        ("  Delta_self:           ", delta_self),
        ("  R^2_TTR:              ", probe_results.get('TTR', 0.0)),
        ("  R^2_first_person:     ", probe_results.get('first_person_rate', 0.0)),
        ("  R^2_newline:          ", probe_results.get('newline_rate', 0.0)),
    ]
    for label, value in key_numbers:
        print(f"{label}{value:.4f}")

    # Persist everything; numpy scalars are cast to float so json can encode them.
    os.makedirs('outputs/analysis', exist_ok=True)
    save_data = {
        'rsa_all': {'rho': float(rho_all), 'pval': float(pval_all)},
        'rsa_nolen': {'rho': float(rho_nolen), 'pval': float(pval_nolen)},
        'self_consistency': {
            'avg_self': float(avg_self),
            'avg_cross': float(avg_cross),
            'delta_self': float(delta_self),
        },
        'probe_r2': {name: float(r2) for name, r2 in probe_results.items()},
        'num_users': len(thetas),
        'thetas': [list(map(float, row)) for row in thetas],
        'style_protos': [list(map(float, row)) for row in style_protos],
        'user_ids': user_ids,
    }
    with open('outputs/analysis/theta_analysis.json', 'w') as fh:
        json.dump(save_data, fh, indent=2)
    print("\nSaved to outputs/analysis/theta_analysis.json")


if __name__ == '__main__':
    main()