1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
|
"""Support-query distribution shift analysis.
For each user, compute:
s_u = cos(mean_support_hidden, mean_query_hidden)
Then correlate with CVH-UPH performance gap:
delta_u = ROUGE-L(CVH, u) - ROUGE-L(UPH, u)
If correlation is positive: CVH benefits when support-query are aligned.
"""
import sys
import os
import json
import numpy as np
from scipy import stats
import torch
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.longlamp import load_longlamp, select_k_profile_items
from data.templates import build_query_prompt, build_support_prompt
from models.qwen_wrapper import QwenWrapper
from models.cvh import CVHHead, UnconditionalHead
from adapt.cache_hidden import cache_support_hidden_states
from adapt.fit_theta import fit_theta
from eval.metrics import compute_rouge
def get_query_hidden_mean(wrapper, query_text, task):
    """Return the query prompt's mean final-layer hidden state as a numpy array.

    Renders the chat-formatted query prompt, runs one forward pass with
    hidden states enabled, and averages the last layer over the sequence
    dimension.  Shape of the result: (H,).
    """
    messages = [
        {"role": "system", "content": "You are a helpful writing assistant."},
        {"role": "user", "content": build_query_prompt(query_text, task)},
    ]
    rendered = wrapper.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    ids = wrapper.tokenizer.encode(rendered, return_tensors="pt").to(wrapper.device)
    with torch.no_grad():
        out = wrapper.model(input_ids=ids, output_hidden_states=True, return_dict=True)
    # hidden_states[-1] is the final layer; [0] drops the batch dim -> (seq_len, H)
    final_layer = out.hidden_states[-1][0]
    return final_layer.mean(dim=0).cpu().float().numpy()
def main():
    """Run the support-query distribution-shift analysis end to end.

    For each evaluated user: compute the cosine between the mean support
    hidden state and the mean query hidden state, fit UPH and CVH heads on
    the cached support hiddens, generate with both heads, and score ROUGE-L.
    Finally, report the Spearman correlation between alignment and the
    CVH-UPH gap, a median-split breakdown, and save all per-user numbers
    to outputs/analysis/shift_analysis.json.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_eval', type=int, default=100)
    parser.add_argument('--config', type=str, default='product_review_user')
    # New knobs; defaults reproduce the previously hard-coded behavior.
    parser.add_argument('--device', type=str, default='cuda:1',
                        help='Torch device for the model and heads')
    parser.add_argument('--k_support', type=int, default=4,
                        help='Number of profile items used as the support set')
    args = parser.parse_args()
    N = args.num_eval
    device = args.device
    K = args.k_support
    print(f"=== Shift Analysis: {args.config}, N={N} ===")
    print("Loading data...")
    examples = load_longlamp(args.config, split='val')[:N]
    print("Loading model...")
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=device)
    H = wrapper.hidden_size
    # Same hyperparameters for both heads so the only difference is the head type.
    uph_head = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
    cvh_head = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
    lm_head_bias = None
    if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
        lm_head_bias = wrapper.model.lm_head.bias.data
    shift_cosines = []
    uph_rouges = []
    cvh_rouges = []
    for i, ex in enumerate(examples):
        support = select_k_profile_items(ex['profile_items'], K, seed=0)
        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
        if not cached_h:
            # No usable support hiddens for this user; skip entirely so the
            # three result lists stay aligned.
            continue
        # Mean support hidden over all cached tokens.
        all_h = torch.cat([h for h, _ in cached_h], dim=0)
        support_mean = all_h.mean(dim=0).numpy()
        # Mean query hidden.
        query_mean = get_query_hidden_mean(wrapper, ex['query_input'], ex['task'])
        # Cosine similarity (epsilon guards against zero-norm vectors).
        cos = np.dot(support_mean, query_mean) / (
            np.linalg.norm(support_mean) * np.linalg.norm(query_mean) + 1e-8)
        shift_cosines.append(float(cos))
        # Fit both heads on identical cached data / optimizer settings.
        fit_kwargs = dict(
            cached_h=cached_h, lm_head_weight=wrapper.lm_head_weight,
            lm_head_bias=lm_head_bias,
            d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
            max_grad_norm=5.0, device=device, verbose=False,
        )
        theta_uph = fit_theta(head_module=uph_head, **fit_kwargs)
        theta_cvh = fit_theta(head_module=cvh_head, **fit_kwargs)
        # Generate with both heads from the same prompt.
        prompt = build_query_prompt(ex['query_input'], ex['task'])
        pred_uph = wrapper.generate_with_head_blended(
            prompt, theta_uph, uph_head.forward_fn,
            blend_gamma=0.5, max_new_tokens=512, min_new_tokens=128, temperature=0.0,
        )
        pred_cvh = wrapper.generate_with_head_blended(
            prompt, theta_cvh, cvh_head.forward_fn,
            blend_gamma=0.5, max_new_tokens=512, min_new_tokens=128, temperature=0.0,
        )
        # ROUGE-L against the reference output for each head.
        rouge_uph = compute_rouge([pred_uph], [ex['target_output']])['rougeL']
        rouge_cvh = compute_rouge([pred_cvh], [ex['target_output']])['rougeL']
        uph_rouges.append(rouge_uph)
        cvh_rouges.append(rouge_cvh)
        # Free per-user tensors before the next iteration to cap GPU memory.
        del cached_h, theta_uph, theta_cvh
        torch.cuda.empty_cache()
        if (i + 1) % 20 == 0:
            print(f"  {i+1}/{N}")
    if not shift_cosines:
        # Without at least one completed user, median/correlation below would
        # crash with an opaque numpy/scipy error — fail loudly instead.
        print("No users produced cached support hidden states; nothing to report.")
        return
    # Compute correlation between alignment and the CVH-UPH gap.
    shift_cosines = np.array(shift_cosines)
    uph_arr = np.array(uph_rouges)
    cvh_arr = np.array(cvh_rouges)
    deltas = cvh_arr - uph_arr  # positive = CVH better
    rho, pval = stats.spearmanr(shift_cosines, deltas)
    print(f"\n=== Results (N={len(shift_cosines)}) ===")
    print(f"  Mean shift cosine: {shift_cosines.mean():.4f} +/- {shift_cosines.std():.4f}")
    print(f"  Mean delta (CVH - UPH): {deltas.mean():.4f} +/- {deltas.std():.4f}")
    print(f"  Spearman(shift_cos, delta): rho={rho:.4f}, p={pval:.4f}")
    print(f"  Mean UPH ROUGE-L: {uph_arr.mean():.4f}")
    print(f"  Mean CVH ROUGE-L: {cvh_arr.mean():.4f}")
    # Bin analysis: median split into high- vs low-alignment users.
    median_cos = np.median(shift_cosines)
    high_mask = shift_cosines >= median_cos
    low_mask = shift_cosines < median_cos
    print(f"\n  High-alignment (cos >= {median_cos:.3f}, n={high_mask.sum()}):")
    print(f"    UPH R-L: {uph_arr[high_mask].mean():.4f}")
    print(f"    CVH R-L: {cvh_arr[high_mask].mean():.4f}")
    print(f"  Low-alignment (cos < {median_cos:.3f}, n={low_mask.sum()}):")
    print(f"    UPH R-L: {uph_arr[low_mask].mean():.4f}")
    print(f"    CVH R-L: {cvh_arr[low_mask].mean():.4f}")
    # Save raw per-user numbers and summary statistics.
    os.makedirs('outputs/analysis', exist_ok=True)
    save_data = {
        'shift_cosines': [float(x) for x in shift_cosines],
        'uph_rouges': [float(x) for x in uph_arr],
        'cvh_rouges': [float(x) for x in cvh_arr],
        'deltas': [float(x) for x in deltas],
        'spearman_rho': float(rho),
        'spearman_pval': float(pval),
    }
    with open('outputs/analysis/shift_analysis.json', 'w') as f:
        json.dump(save_data, f, indent=2)
    print("\nSaved to outputs/analysis/shift_analysis.json")
# Script entry point: run the full shift analysis when executed directly.
if __name__ == '__main__':
    main()
|