summaryrefslogtreecommitdiff
path: root/scripts/compute_bertscore.py
blob: 4fb1dc25ce211f8b51a2313fc88eb26c372aa432 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""Compute BERTScore from saved per-user predictions.

Uses saved predictions from significance tests (UPH, Base) and PEFT per-user data.

Usage:
    python scripts/compute_bertscore.py --task review --setting user --device cuda:0
"""

import sys
import os
import json
import numpy as np
from scipy import stats

# Put the parent of this script's directory at the front of the import path.
# NOTE(review): presumably this is the repository root, added so project-local
# modules can be imported; the code below imports none — confirm it is needed.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


def paired_test(scores_a, scores_b, name_a, name_b):
    """Compare two paired score lists, print a report, and return the statistics.

    Computes the mean of each list, the mean of the paired differences
    (``a - b``), a normal-approximation 95% confidence interval for that mean
    difference (``mean_diff +/- 1.96 * SEM``), a paired t-test, and a Wilcoxon
    signed-rank test.

    Args:
        scores_a: Sequence of numeric scores for the first method.
        scores_b: Sequence of numeric scores for the second method; entry
            ``i`` is treated as the pair of ``scores_a[i]``.
        name_a: Label for the first method, used only in the printed report.
        name_b: Label for the second method, used only in the printed report.

    Returns:
        A dict of plain floats with keys ``mean_a``, ``mean_b``,
        ``mean_diff``, ``ci_low``, ``ci_high``, ``t_pval`` and ``w_pval``.
        ``w_pval`` is NaN when the Wilcoxon test could not be computed.

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    a = np.asarray(scores_a, dtype=float)
    b = np.asarray(scores_b, dtype=float)
    # Without this check NumPy would broadcast e.g. a length-1 array against a
    # length-N array in `a - b`, yielding statistics over unpaired data.
    if a.shape != b.shape:
        raise ValueError(
            f"paired_test requires equally shaped score arrays; "
            f"got {name_a}: {a.shape}, {name_b}: {b.shape}"
        )
    diff = a - b

    mean_a, mean_b = np.mean(a), np.mean(b)
    mean_diff = np.mean(diff)

    _, t_pval = stats.ttest_rel(a, b)
    try:
        _, w_pval = stats.wilcoxon(a, b)
    except ValueError:
        # Best effort: SciPy can reject the input (in some versions, for
        # example, when every paired difference is zero). Report NaN instead
        # of aborting the whole comparison.
        w_pval = float('nan')

    # 95% CI of the mean difference using the normal critical value 1.96.
    se = stats.sem(diff)
    ci_low = mean_diff - 1.96 * se
    ci_high = mean_diff + 1.96 * se

    print(f"  {name_a} vs {name_b}:")
    print(f"    Mean {name_a}: {mean_a:.4f}, Mean {name_b}: {mean_b:.4f}, Diff: {mean_diff:+.4f}")
    print(f"    95% CI: [{ci_low:+.4f}, {ci_high:+.4f}]")
    print(f"    t-test: p={t_pval:.2e}, Wilcoxon: p={w_pval:.2e}")

    return {
        'mean_a': float(mean_a), 'mean_b': float(mean_b),
        'mean_diff': float(mean_diff),
        'ci_low': float(ci_low), 'ci_high': float(ci_high),
        't_pval': float(t_pval), 'w_pval': float(w_pval),
    }


def main():
    """Score saved predictions with BERTScore and test UPH against each baseline.

    Reads two JSON files under ``outputs/``: the significance file (supplying
    the ``uph_predictions`` and ``base_predictions`` lists) and the PEFT
    per-user file (supplying, for each of ``lora``, ``tiny_lora`` and
    ``vera``, a list of records with ``prediction`` and ``reference``
    fields). The references of the ``lora`` records are used as the
    references for every method.

    For each method the per-example BERTScore F1 is computed, a summary table
    is printed, UPH is compared to every other method with ``paired_test``,
    and everything is written to
    ``outputs/significance/{task}_{setting}_bertscore.json``.

    If either input file is missing, a message is printed and the function
    returns without doing any work.

    Raises:
        ValueError: If any method's prediction count differs from the number
            of references.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='review')
    parser.add_argument('--setting', type=str, default='user')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--bert_model', type=str, default='roberta-large')
    # These two only select which PEFT results file is read; the defaults
    # reproduce the previously hard-coded "K4_N200" filename.
    parser.add_argument('--peft_k', type=int, default=4,
                        help='K component of the PEFT per-user results filename')
    parser.add_argument('--peft_n', type=int, default=200,
                        help='N component of the PEFT per-user results filename')
    args = parser.parse_args()

    task = args.task
    setting = args.setting
    peft_methods = ('lora', 'tiny_lora', 'vera')

    # Load saved predictions
    sig_path = f"outputs/significance/{task}_{setting}_significance.json"
    peft_path = (
        f"outputs/peft_baselines/"
        f"{task}_{setting}_K{args.peft_k}_N{args.peft_n}_peft_per_user.json"
    )

    if not os.path.exists(sig_path):
        print(f"Significance data not found: {sig_path}")
        return
    if not os.path.exists(peft_path):
        print(f"PEFT per-user data not found: {peft_path}")
        return

    with open(sig_path, encoding='utf-8') as f:
        sig_data = json.load(f)
    with open(peft_path, encoding='utf-8') as f:
        peft_data = json.load(f)

    # UPH and Base come from the significance data, PEFT methods from the
    # per-user data. Insertion order determines the order of the reports.
    all_preds = {
        'UPH': sig_data['uph_predictions'],
        'Base': sig_data['base_predictions'],
    }
    for method in peft_methods:
        all_preds[method] = [u['prediction'] for u in peft_data['per_user'][method]]

    # References (shared by all methods) are taken from the lora records.
    refs = [u['reference'] for u in peft_data['per_user']['lora']]

    # Predictions come from two different files; fail before the expensive
    # scoring step if any method cannot be paired one-to-one with the
    # references. NOTE(review): only the counts are checked here — that the
    # two files list examples in the same order is assumed, not verified.
    mismatched = {m: len(p) for m, p in all_preds.items() if len(p) != len(refs)}
    if mismatched:
        raise ValueError(
            f"Prediction/reference count mismatch: {len(refs)} references, "
            f"but prediction counts are {mismatched}"
        )

    print(f"=== BERTScore: {task}_{setting}, N={len(refs)} ===")
    print(f"Model: {args.bert_model}")
    print(f"Methods: {list(all_preds.keys())}")

    # Imported lazily so the missing-file early exits above do not require
    # the bert_score package to be importable.
    from bert_score import score as bert_score_fn

    # Compute BERTScore for each method; only F1 is kept.
    all_bertscore = {}
    for method, preds in all_preds.items():
        print(f"\n  Computing BERTScore for {method}...")
        _, _, F1 = bert_score_fn(
            preds, refs,
            model_type=args.bert_model,
            device=args.device,
            verbose=False,
        )
        f1_scores = F1.tolist()
        all_bertscore[method] = f1_scores
        print(f"    Mean F1: {np.mean(f1_scores):.4f}")

    # Summary table
    print("\n" + "=" * 60)
    print("BERTScore F1 Summary")
    print("=" * 60)
    for method, scores in all_bertscore.items():
        print(f"  {method:<15} Mean: {np.mean(scores):.4f}, Std: {np.std(scores):.4f}")

    # Significance tests
    print("\n" + "=" * 60)
    print("Significance Tests — BERTScore F1 (paired)")
    print("=" * 60)

    test_results = {}
    for other in ('Base',) + peft_methods:
        test_results[f'UPH_vs_{other}'] = paired_test(
            all_bertscore['UPH'], all_bertscore[other], 'UPH', other
        )

    # Save
    output_path = f"outputs/significance/{task}_{setting}_bertscore.json"
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump({
            'bertscore_f1': all_bertscore,
            'significance_tests': test_results,
            'model': args.bert_model,
            'task': task,
            'setting': setting,
            'num_examples': len(refs),
        }, f, indent=2)
    print(f"\nSaved to {output_path}")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()