"""Support-query distribution shift analysis.

For each user u, compute the alignment score
  s_u = cos(mean_support_hidden, mean_query_hidden)
and correlate it with the CVH-UPH performance gap
  delta_u = ROUGE-L(CVH, u) - ROUGE-L(UPH, u).

If the correlation is positive, CVH benefits when the support and query
representations are aligned.
"""

import sys
import os
import json
import numpy as np
from scipy import stats
import torch

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.longlamp import load_longlamp, select_k_profile_items
from data.templates import build_query_prompt, build_support_prompt
from models.qwen_wrapper import QwenWrapper
from models.cvh import CVHHead, UnconditionalHead
from adapt.cache_hidden import cache_support_hidden_states
from adapt.fit_theta import fit_theta
from eval.metrics import compute_rouge


def get_query_hidden_mean(wrapper, query_text, task):
    """Return the mean last-layer hidden state over the query prompt's tokens."""
    chat_messages = [
        {"role": "system", "content": "You are a helpful writing assistant."},
        {"role": "user", "content": build_query_prompt(query_text, task)},
    ]
    prompt_text = wrapper.tokenizer.apply_chat_template(
        chat_messages, tokenize=False, add_generation_prompt=True
    )
    input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)

    with torch.no_grad():
        outputs = wrapper.model(
            input_ids=input_ids,
            output_hidden_states=True,
            return_dict=True,
        )
    last_hidden = outputs.hidden_states[-1][0]  # (seq_len, H)
    return last_hidden.mean(dim=0).cpu().float().numpy()
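
# Note: this mean-pools over every prompt token (chat template included). The
# support-side vector built in main() is likewise a mean over cached hidden
# states, so the cosine compares two vectors constructed the same way.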


def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_eval', type=int, default=100)
    parser.add_argument('--config', type=str, default='product_review_user')
    args = parser.parse_args()

    N = args.num_eval
    print(f"=== Shift Analysis: {args.config}, N={N} ===")

    print("Loading data...")
    examples = load_longlamp(args.config, split='val')[:N]

    print("Loading model...")
    device = 'cuda:1'
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=device)
    H = wrapper.hidden_size

    uph_head = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
    cvh_head = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device)

    # Some checkpoints expose a bias on the LM head (Qwen2.5's lm_head is
    # typically bias-free, so this usually stays None)
    lm_head_bias = getattr(wrapper.model.lm_head, 'bias', None)
    if lm_head_bias is not None:
        lm_head_bias = lm_head_bias.data

    K = 4
    shift_cosines = []
    uph_rouges = []
    cvh_rouges = []

    for i, ex in enumerate(examples):
        support = select_k_profile_items(ex['profile_items'], K, seed=0)
        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
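        # Skip users whose profile yields no cacheable support hidden states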
        if not cached_h:
            continue

        # Mean support hidden state; cast to CPU float32 before numpy, matching
        # get_query_hidden_mean above (bf16/GPU tensors cannot convert directly)
        all_h = torch.cat([h for h, _ in cached_h], dim=0)
        support_mean = all_h.mean(dim=0).cpu().float().numpy()

        # Mean query hidden
        query_mean = get_query_hidden_mean(wrapper, ex['query_input'], ex['task'])

        # Cosine similarity (epsilon avoids division by zero for degenerate vectors)
        cos = np.dot(support_mean, query_mean) / (
            np.linalg.norm(support_mean) * np.linalg.norm(query_mean) + 1e-8)
        shift_cosines.append(float(cos))

        # Fit UPH theta
        theta_uph = fit_theta(
            cached_h=cached_h, lm_head_weight=wrapper.lm_head_weight,
            lm_head_bias=lm_head_bias, head_module=uph_head,
            d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
            max_grad_norm=5.0, device=device, verbose=False,
        )

        # Fit CVH theta
        theta_cvh = fit_theta(
            cached_h=cached_h, lm_head_weight=wrapper.lm_head_weight,
            lm_head_bias=lm_head_bias, head_module=cvh_head,
            d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
            max_grad_norm=5.0, device=device, verbose=False,
        )

        # Generate with both
        prompt = build_query_prompt(ex['query_input'], ex['task'])
        pred_uph = wrapper.generate_with_head_blended(
            prompt, theta_uph, uph_head.forward_fn,
            blend_gamma=0.5, max_new_tokens=512, min_new_tokens=128, temperature=0.0,
        )
        pred_cvh = wrapper.generate_with_head_blended(
            prompt, theta_cvh, cvh_head.forward_fn,
            blend_gamma=0.5, max_new_tokens=512, min_new_tokens=128, temperature=0.0,
        )

        # ROUGE-L for each
        rouge_uph = compute_rouge([pred_uph], [ex['target_output']])['rougeL']
        rouge_cvh = compute_rouge([pred_cvh], [ex['target_output']])['rougeL']
        uph_rouges.append(rouge_uph)
        cvh_rouges.append(rouge_cvh)

        # Free per-user tensors so GPU memory does not accumulate across users
        del cached_h, theta_uph, theta_cvh
        torch.cuda.empty_cache()

        if (i + 1) % 20 == 0:
            print(f"  {i+1}/{N}")

    # Compute correlation
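    # Spearman rank correlation (rather than Pearson): robust to outliers and
    # to monotone rescaling, so it only asks whether higher alignment tends to
    # come with a larger CVH advantage, not whether the relation is linear.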
    shift_cosines = np.array(shift_cosines)
    deltas = np.array(cvh_rouges) - np.array(uph_rouges)  # positive = CVH better

    rho, pval = stats.spearmanr(shift_cosines, deltas)

    print(f"\n=== Results (N={len(shift_cosines)}) ===")
    print(f"  Mean shift cosine: {shift_cosines.mean():.4f} +/- {shift_cosines.std():.4f}")
    print(f"  Mean delta (CVH - UPH): {deltas.mean():.4f} +/- {deltas.std():.4f}")
    print(f"  Spearman(shift_cos, delta): rho={rho:.4f}, p={pval:.4f}")
    print(f"  Mean UPH ROUGE-L: {np.mean(uph_rouges):.4f}")
    print(f"  Mean CVH ROUGE-L: {np.mean(cvh_rouges):.4f}")

    # Bin analysis: high vs low shift
    median_cos = np.median(shift_cosines)
    high_mask = shift_cosines >= median_cos
    low_mask = shift_cosines < median_cos

    print(f"\n  High-alignment (cos >= {median_cos:.3f}, n={high_mask.sum()}):")
    print(f"    UPH R-L: {np.mean(np.array(uph_rouges)[high_mask]):.4f}")
    print(f"    CVH R-L: {np.mean(np.array(cvh_rouges)[high_mask]):.4f}")
    print(f"  Low-alignment (cos < {median_cos:.3f}, n={low_mask.sum()}):")
    print(f"    UPH R-L: {np.mean(np.array(uph_rouges)[low_mask]):.4f}")
    print(f"    CVH R-L: {np.mean(np.array(cvh_rouges)[low_mask]):.4f}")

    # Save
    os.makedirs('outputs/analysis', exist_ok=True)
    save_data = {
        'shift_cosines': [float(x) for x in shift_cosines],
        'uph_rouges': [float(x) for x in uph_rouges],
        'cvh_rouges': [float(x) for x in cvh_rouges],
        'deltas': [float(x) for x in deltas],
        'spearman_rho': float(rho),
        'spearman_pval': float(pval),
    }
    with open('outputs/analysis/shift_analysis.json', 'w') as f:
        json.dump(save_data, f, indent=2)
    print("\nSaved to outputs/analysis/shift_analysis.json")


if __name__ == '__main__':
    main()
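
# Example invocation (defaults shown above; GPU index is hard-coded to cuda:1):
#   python scripts/shift_analysis.py --config product_review_user --num_eval 100
#
# Minimal sketch for inspecting the saved results afterwards:
#   import json
#   with open('outputs/analysis/shift_analysis.json') as f:
#       results = json.load(f)
#   print(results['spearman_rho'], results['spearman_pval'])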