summaryrefslogtreecommitdiff
path: root/scripts/test_normalized_cvh.py
blob: 0057341eb60ee5b10e85b9ba3285b55d19999b9e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Quick test: normalized CVH vs Uncond with blending."""

import sys
import os
import time
import torch

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.longlamp import load_longlamp, select_k_profile_items
from data.templates import build_query_prompt
from models.qwen_wrapper import QwenWrapper
from models.cvh import CVHHead, UnconditionalHead
from adapt.cache_hidden import cache_support_hidden_states
from adapt.fit_theta import fit_theta
from eval.metrics import evaluate_all


def run_blended(wrapper, examples, support_sets, head_module, d=64,
                beta=0.05, steps=30, lr=0.05, blend_gamma=0.5,
                device='cuda:1', max_new_tokens=512, min_new_tokens=128):
    """Fit a per-example theta and decode each query with the blended head.

    For every (example, support set) pair:
    1. cache hidden states for the support items;
    2. fit ``theta`` with ``fit_theta``;
    3. generate from the query prompt via ``wrapper.generate_with_head_blended``.

    Examples whose support cache comes back empty fall back to greedy base
    decoding.

    Args:
        wrapper: Model wrapper. Reads ``model.lm_head``, ``lm_head_weight``,
            ``generate_base`` and ``generate_with_head_blended``.
        examples: Example dicts; ``'task'`` and ``'query_input'`` are read.
        support_sets: One support set per example, zipped with ``examples``.
        head_module: Head passed to ``fit_theta``. Its ``forward_fn`` is
            used for blended decoding.
        d: Forwarded to ``fit_theta``.
        beta: Forwarded to ``fit_theta``.
        steps: Forwarded to ``fit_theta``.
        lr: Forwarded to ``fit_theta``.
        blend_gamma: Forwarded to ``generate_with_head_blended``.
        device: Device string forwarded to ``fit_theta``.
        max_new_tokens: Generation cap for both the blended and fallback paths.
        min_new_tokens: Minimum generation length for the blended path only.
            The fallback/base path has no minimum, so average lengths are not
            directly comparable between the two.

    Returns:
        ``(predictions, avg_theta_norm, avg_word_len)``.
        ``avg_theta_norm`` averages only over examples that fitted a theta.
        Both averages are 0 when there is nothing to average.
    """
    # Pass the LM-head bias through only when the head actually has one.
    bias = getattr(wrapper.model.lm_head, 'bias', None)
    lm_head_bias = bias.data if bias is not None else None

    predictions = []
    theta_norms = []
    total = len(examples)

    for i, (ex, support) in enumerate(zip(examples, support_sets)):
        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
        prompt = build_query_prompt(ex['query_input'], ex['task'])

        if cached_h:
            theta = fit_theta(
                cached_h=cached_h,
                lm_head_weight=wrapper.lm_head_weight,
                lm_head_bias=lm_head_bias,
                head_module=head_module,
                d=d, lr=lr, steps=steps, beta=beta, lam=1e-4,
                max_grad_norm=5.0, device=device, verbose=(i == 0),
            )
            theta_norms.append(theta.norm().item())

            pred = wrapper.generate_with_head_blended(
                prompt, theta, head_module.forward_fn,
                blend_gamma=blend_gamma,
                max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens,
                temperature=0.0,
            )
            # Release per-example tensors before the next fit.
            del theta
            torch.cuda.empty_cache()
        else:
            # No support hidden states: fall back to the base model. Decode
            # greedily, like every other generation in this script, so the
            # comparison stays consistent.
            pred = wrapper.generate_base(
                prompt, max_new_tokens=max_new_tokens, temperature=0.0)

        predictions.append(pred)
        del cached_h

        # Reached for every example, including fallbacks.
        if (i + 1) % 20 == 0:
            print(f"    {i+1}/{total}")

    avg_norm = sum(theta_norms) / max(len(theta_norms), 1)
    avg_len = sum(len(p.split()) for p in predictions) / max(len(predictions), 1)
    return predictions, avg_norm, avg_len


def _run_base(wrapper, examples, max_new_tokens=512):
    """Greedy-decode every example's query prompt with the base model.

    Prints a progress line every 20 examples. The total shown is the actual
    number of examples passed in.
    """
    total = len(examples)
    preds = []
    for i, ex in enumerate(examples):
        prompt = build_query_prompt(ex['query_input'], ex['task'])
        preds.append(wrapper.generate_base(
            prompt, max_new_tokens=max_new_tokens, temperature=0.0))
        if (i + 1) % 20 == 0:
            print(f"    {i+1}/{total}")
    return preds


def _avg_word_len(texts):
    """Return the mean whitespace-token count of *texts* (0.0 when empty)."""
    return sum(len(t.split()) for t in texts) / max(len(texts), 1)


def _print_summary(results):
    """Print one row per config: ROUGE-1, ROUGE-L, METEOR and SFD."""
    print("\n" + "=" * 90)
    print(f"{'Config':<30} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD':<8}")
    print("-" * 90)
    for name, r in results.items():
        print(f"{name:<30} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} "
              f"{r['meteor']:<8.4f} {r['sfd']:<8.4f}")


def main():
    """Compare base, Uncond and normalized-CVH blended decoding on LongLaMP.

    Workflow:
    - Load up to ``N`` ``product_review_user`` validation examples.
    - Sample a ``K``-item support set per example (seed 0).
    - Print ROUGE-L/METEOR/SFD for the base model and for each
      head/blend-gamma configuration.
    - Print a summary table.
    """
    N = 100  # maximum number of validation examples to evaluate
    K = 4    # support items sampled per example
    device = 'cuda:1'

    print(f"Loading data ({N} examples)...")
    examples = load_longlamp('product_review_user', split='val')[:N]
    support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0)
                    for ex in examples]
    references = [ex['target_output'] for ex in examples]
    # Support outputs are handed to evaluate_all alongside the references
    # (presumably for the SFD metric — confirm in eval.metrics).
    support_texts = [[s['support_output'] for s in ss] for ss in support_sets]

    print("Loading model...")
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=device)
    H = wrapper.hidden_size

    print("\n=== Base ===")
    base_preds = _run_base(wrapper, examples)
    base_r = evaluate_all(base_preds, references, support_texts)
    base_len = _avg_word_len(base_preds)
    print(f"  R-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, SFD: {base_r['sfd']:.4f}, len: {base_len:.0f}")

    results = {'Base': base_r}

    # Each entry: (label, head, d passed to run_blended, blend_gamma).
    # The three CVH rows use identically configured heads and differ only in gamma.
    configs = [
        ('Uncond d=64 g=0.5', UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64, 0.5),
        ('CVH-norm d=64 g=0.5', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64, 0.5),
        ('CVH-norm d=64 g=0.3', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64, 0.3),
        ('CVH-norm d=64 g=0.7', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64, 0.7),
    ]

    for name, head, d, gamma in configs:
        print(f"\n=== {name} ===")
        t0 = time.perf_counter()  # monotonic clock for interval timing
        preds, avg_norm, avg_len = run_blended(
            wrapper, examples, support_sets, head, d=d,
            beta=0.05, steps=30, lr=0.05, blend_gamma=gamma,
        )
        elapsed = time.perf_counter() - t0
        r = evaluate_all(preds, references, support_texts)
        results[name] = r
        print(f"  R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, SFD: {r['sfd']:.4f}, "
              f"|θ|: {avg_norm:.3f}, len: {avg_len:.0f}, time: {elapsed:.0f}s")

    _print_summary(results)


if __name__ == "__main__":
    # Run the experiment only when executed as a script, never on import.
    main()