"""Quick sweep over alpha values to find the right perturbation scale."""
import sys
import os
import json
import time
import torch
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.longlamp import load_longlamp, select_k_profile_items
from data.templates import build_query_prompt
from models.qwen_wrapper import QwenWrapper
from models.cvh import CVHHead
from adapt.cache_hidden import cache_support_hidden_states
from adapt.fit_theta import fit_theta
from eval.metrics import evaluate_all
def run_cvh_with_params(wrapper, examples, support_sets, alpha, beta, steps, d=64, lr=0.05,
                        device='cuda:1'):
    """Run CVH adaptation with specific hyperparameters over a set of examples.

    For each example, caches hidden states from its support set, fits a
    low-dimensional theta against the LM head, and generates with the
    theta-perturbed model. Falls back to the unadapted model when no support
    hidden states are available.

    Args:
        wrapper: model wrapper exposing `hidden_size`, `model.lm_head`,
            `lm_head_weight`, `generate_base`, and `generate_with_head`.
        examples: list of example dicts with 'query_input' and 'task' keys.
        support_sets: per-example support item lists, aligned with `examples`.
        alpha: perturbation scale for the CVH head.
        beta: loss weight passed to `fit_theta`.
        steps: number of optimization steps for `fit_theta`.
        d: theta dimensionality (default 64).
        lr: learning rate for `fit_theta` (default 0.05).
        device: torch device for the head and theta fitting. Defaults to
            'cuda:1' to preserve the previously hard-coded behavior.

    Returns:
        (predictions, avg_norm): list of generated texts (one per example) and
        the mean ||theta|| over examples where theta was actually fitted
        (0.0 if none were).
    """
    H = wrapper.hidden_size
    head = CVHHead(H, d=d, alpha=alpha, basis_seed=42).to(device)
    # Some LM heads have no bias; fit_theta accepts None in that case.
    lm_head_bias = None
    if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
        lm_head_bias = wrapper.model.lm_head.bias.data
    predictions = []
    theta_norms = []
    for ex, support in zip(examples, support_sets):
        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
        if not cached_h:
            # No usable support states: fall back to the unadapted model.
            # temperature=0.0 keeps the fallback greedy/deterministic,
            # consistent with the adapted path below and the base run.
            prompt = build_query_prompt(ex['query_input'], ex['task'])
            pred = wrapper.generate_base(prompt, max_new_tokens=256, temperature=0.0)
            predictions.append(pred)
            continue
        theta = fit_theta(
            cached_h=cached_h,
            lm_head_weight=wrapper.lm_head_weight,
            lm_head_bias=lm_head_bias,
            head_module=head,
            d=d, lr=lr, steps=steps, beta=beta, lam=1e-4,
            max_grad_norm=5.0, device=device, verbose=False,
        )
        theta_norms.append(theta.norm().item())
        prompt = build_query_prompt(ex['query_input'], ex['task'])
        pred = wrapper.generate_with_head(
            prompt, theta, head.forward_fn,
            max_new_tokens=256, temperature=0.0,
        )
        predictions.append(pred)
        # Free per-example tensors eagerly to keep GPU memory flat across the sweep.
        del cached_h, theta
        torch.cuda.empty_cache()
    # max(..., 1) guards against division by zero when every example fell back.
    avg_norm = sum(theta_norms) / max(len(theta_norms), 1)
    return predictions, avg_norm
def main():
    """Compare the unadapted base model against several CVH hyperparameter
    configurations on a 50-example validation slice, then print a summary table."""
    print("Loading data...")
    examples = load_longlamp('product_review_user', split='val')[:50]
    K = 4
    support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
    references = [ex['target_output'] for ex in examples]
    support_texts = [[s['support_output'] for s in ss] for ss in support_sets]

    print("Loading model...")
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')

    # Baseline: greedy generation with no adaptation.
    print("\n=== Base ===")
    base_preds = [
        wrapper.generate_base(
            build_query_prompt(ex['query_input'], ex['task']),
            max_new_tokens=256, temperature=0.0,
        )
        for ex in examples
    ]
    base_results = evaluate_all(base_preds, references, support_texts)
    print(f"  ROUGE-L: {base_results['rougeL']:.4f}, METEOR: {base_results['meteor']:.4f}, SFD: {base_results['sfd']:.4f}")

    # Hyperparameter grid: coarse alpha sweep first, then longer/looser fits.
    configs = [
        {'alpha': 0.1, 'beta': 0.05, 'steps': 30, 'lr': 0.05},
        {'alpha': 0.3, 'beta': 0.05, 'steps': 30, 'lr': 0.05},
        {'alpha': 0.5, 'beta': 0.05, 'steps': 30, 'lr': 0.05},
        {'alpha': 0.3, 'beta': 0.01, 'steps': 50, 'lr': 0.05},
        {'alpha': 0.5, 'beta': 0.01, 'steps': 50, 'lr': 0.05},
        {'alpha': 0.3, 'beta': 0.01, 'steps': 50, 'lr': 0.1},
    ]

    all_results = {'Base': base_results}
    for params in configs:
        label = f"a{params['alpha']}_b{params['beta']}_s{params['steps']}_lr{params['lr']}"
        print(f"\n=== CVH {label} ===")
        start = time.time()
        preds, avg_norm = run_cvh_with_params(
            wrapper, examples, support_sets,
            alpha=params['alpha'], beta=params['beta'],
            steps=params['steps'], lr=params['lr'],
        )
        wall = time.time() - start
        metrics = evaluate_all(preds, references, support_texts)
        all_results[label] = metrics
        print(f"  ROUGE-L: {metrics['rougeL']:.4f}, METEOR: {metrics['meteor']:.4f}, "
              f"SFD: {metrics['sfd']:.4f}, avg|theta|: {avg_norm:.3f}, time: {wall:.0f}s")

    # Final side-by-side table across all runs, base included.
    print("\n" + "=" * 80)
    print(f"{'Config':<40} {'ROUGE-L':<10} {'METEOR':<10} {'SFD':<10}")
    print("-" * 80)
    for label, metrics in all_results.items():
        print(f"{label:<40} {metrics['rougeL']:<10.4f} {metrics['meteor']:<10.4f} {metrics['sfd']:<10.4f}")
# Run the sweep only when executed as a script (not on import).
if __name__ == '__main__':
    main()