Diffstat (limited to 'scripts/compute_bertscore.py')
-rw-r--r--  scripts/compute_bertscore.py  145
1 file changed, 145 insertions, 0 deletions
diff --git a/scripts/compute_bertscore.py b/scripts/compute_bertscore.py
new file mode 100644
index 0000000..4fb1dc2
--- /dev/null
+++ b/scripts/compute_bertscore.py
@@ -0,0 +1,145 @@
+"""Compute BERTScore from saved per-user predictions.
+
+Loads the per-user predictions saved by the significance tests (UPH, Base) and by
+the PEFT baselines (lora, tiny_lora, vera), scores them against the shared
+references with BERTScore, and runs paired tests of UPH against each other method.
+
+Usage:
+ python scripts/compute_bertscore.py --task review --setting user --device cuda:0
+"""
+
+import argparse
+import json
+import os
+import sys
+
+import numpy as np
+from scipy import stats
+
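+# Make the repository root importable when the script is run directly.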
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+def paired_test(scores_a, scores_b, name_a, name_b):
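+    """Paired comparison of two aligned per-example score lists.
+
+    Prints and returns the means, the mean paired difference with a
+    normal-approximation 95% CI (mean diff +/- 1.96 * SEM), and p-values from
+    a paired t-test and a Wilcoxon signed-rank test (NaN when Wilcoxon cannot
+    be run, e.g. all differences are zero).
+    """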
+ a = np.array(scores_a)
+ b = np.array(scores_b)
+ diff = a - b
+
+ mean_a, mean_b = np.mean(a), np.mean(b)
+ mean_diff = np.mean(diff)
+
+ t_stat, t_pval = stats.ttest_rel(a, b)
+ try:
+ w_stat, w_pval = stats.wilcoxon(a, b)
+ except ValueError:
+ w_stat, w_pval = float('nan'), float('nan')
+
+ se = stats.sem(diff)
+ ci_low = mean_diff - 1.96 * se
+ ci_high = mean_diff + 1.96 * se
+
+ print(f" {name_a} vs {name_b}:")
+ print(f" Mean {name_a}: {mean_a:.4f}, Mean {name_b}: {mean_b:.4f}, Diff: {mean_diff:+.4f}")
+ print(f" 95% CI: [{ci_low:+.4f}, {ci_high:+.4f}]")
+ print(f" t-test: p={t_pval:.2e}, Wilcoxon: p={w_pval:.2e}")
+
+ return {
+ 'mean_a': float(mean_a), 'mean_b': float(mean_b),
+ 'mean_diff': float(mean_diff),
+ 'ci_low': float(ci_low), 'ci_high': float(ci_high),
+ 't_pval': float(t_pval), 'w_pval': float(w_pval),
+ }
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='review')
+ parser.add_argument('--task', type=str, default='review')
+ parser.add_argument('--setting', type=str, default='user')
+ parser.add_argument('--device', type=str, default='cuda:0')
+ parser.add_argument('--bert_model', type=str, default='roberta-large')
+ args = parser.parse_args()
+
+ task = args.task
+ setting = args.setting
+    N = 200  # matches the N{N} suffix in the saved PEFT per-user file name
+
+ # Load saved predictions
+ sig_path = f"outputs/significance/{task}_{setting}_significance.json"
+ peft_path = f"outputs/peft_baselines/{task}_{setting}_K4_N{N}_peft_per_user.json"
+
+ if not os.path.exists(sig_path):
+ print(f"Significance data not found: {sig_path}")
+ return
+ if not os.path.exists(peft_path):
+ print(f"PEFT per-user data not found: {peft_path}")
+ return
+
+ with open(sig_path) as f:
+ sig_data = json.load(f)
+ with open(peft_path) as f:
+ peft_data = json.load(f)
+
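+    # Expected layout of the inputs, inferred from the lookups below:
+    #   sig_data:  {'uph_predictions': [...], 'base_predictions': [...]}
+    #   peft_data: {'per_user': {'lora': [{'prediction': str, 'reference': str}, ...],
+    #                            'tiny_lora': [...], 'vera': [...]}}
+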
+ # Collect all predictions and references
+ all_preds = {}
+ all_refs = {}
+
+ # UPH and Base from significance data
+ all_preds['UPH'] = sig_data['uph_predictions']
+ all_preds['Base'] = sig_data['base_predictions']
+
+ # References (same for all methods)
+ refs = [u['reference'] for u in peft_data['per_user']['lora']]
+
+ # PEFT predictions
+ for method in ['lora', 'tiny_lora', 'vera']:
+ all_preds[method] = [u['prediction'] for u in peft_data['per_user'][method]]
+
+ print(f"=== BERTScore: {task}_{setting}, N={len(refs)} ===")
+ print(f"Model: {args.bert_model}")
+ print(f"Methods: {list(all_preds.keys())}")
+
+    # Compute BERTScore for each method. bert_score (and its torch/transformers
+    # dependencies) is imported only after the input checks above have passed.
+    from bert_score import score as bert_score_fn
+
+ all_bertscore = {}
+ for method, preds in all_preds.items():
+ print(f"\n Computing BERTScore for {method}...")
+ P, R, F1 = bert_score_fn(
+ preds, refs,
+ model_type=args.bert_model,
+ device=args.device,
+ verbose=False,
+ )
+ all_bertscore[method] = F1.tolist()
+ print(f" Mean F1: {np.mean(F1.tolist()):.4f}")
+
+ # Summary table
+ print("\n" + "=" * 60)
+ print("BERTScore F1 Summary")
+ print("=" * 60)
+ for method in all_preds:
+ scores = all_bertscore[method]
+ print(f" {method:<15} Mean: {np.mean(scores):.4f}, Std: {np.std(scores):.4f}")
+
+ # Significance tests
+ print("\n" + "=" * 60)
+ print("Significance Tests — BERTScore F1 (paired)")
+ print("=" * 60)
+
+ test_results = {}
+ for other in ['Base', 'lora', 'tiny_lora', 'vera']:
+ r = paired_test(all_bertscore['UPH'], all_bertscore[other], 'UPH', other)
+ test_results[f'UPH_vs_{other}'] = r
+
+ # Save
+ output_path = f"outputs/significance/{task}_{setting}_bertscore.json"
+ with open(output_path, 'w') as f:
+ json.dump({
+ 'bertscore_f1': all_bertscore,
+ 'significance_tests': test_results,
+ 'model': args.bert_model,
+ 'task': task,
+ 'setting': setting,
+ 'num_examples': len(refs),
+ }, f, indent=2)
+ print(f"\nSaved to {output_path}")
+
+
+if __name__ == '__main__':
+ main()