Diffstat (limited to 'eval/metrics.py')
| -rw-r--r-- | eval/metrics.py | 130 |
1 file changed, 130 insertions, 0 deletions
diff --git a/eval/metrics.py b/eval/metrics.py
new file mode 100644
index 0000000..a9b91fa
--- /dev/null
+++ b/eval/metrics.py
@@ -0,0 +1,130 @@
+"""Evaluation metrics: ROUGE-1, ROUGE-L, METEOR, SFD, Recovery, Compression."""
+
+from rouge_score import rouge_scorer
+from nltk.translate.meteor_score import meteor_score as nltk_meteor
+from nltk.tokenize import word_tokenize
+
+from data.style_features import compute_sfd
+
+# NOTE: word_tokenize needs the NLTK 'punkt' data package and meteor_score
+# needs 'wordnet'; fetch both once with nltk.download() before evaluating.
+
+
+def compute_rouge(predictions: list, references: list) -> dict:
+    """Compute ROUGE-1 and ROUGE-L F1 scores, averaged over examples."""
+    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
+
+    r1_scores = []
+    rl_scores = []
+
+    for pred, ref in zip(predictions, references):
+        if not pred.strip():
+            # Substitute a placeholder so degenerate empty outputs score
+            # near zero instead of tripping up the scorer.
+            pred = "empty"
+        scores = scorer.score(ref, pred)
+        r1_scores.append(scores['rouge1'].fmeasure)
+        rl_scores.append(scores['rougeL'].fmeasure)
+
+    return {
+        'rouge1': sum(r1_scores) / max(len(r1_scores), 1),
+        'rougeL': sum(rl_scores) / max(len(rl_scores), 1),
+    }
+
+
+def compute_meteor(predictions: list, references: list) -> float:
+    """Compute the average METEOR score."""
+    scores = []
+    for pred, ref in zip(predictions, references):
+        if not pred.strip():
+            pred = "empty"
+        try:
+            ref_tokens = word_tokenize(ref)
+            pred_tokens = word_tokenize(pred)
+            score = nltk_meteor([ref_tokens], pred_tokens)
+            scores.append(score)
+        except Exception:
+            # Count tokenization/scoring failures as zero rather than
+            # aborting the whole evaluation run.
+            scores.append(0.0)
+    return sum(scores) / max(len(scores), 1)
+
+
+def compute_avg_sfd(predictions: list, support_texts_per_example: list) -> float:
+    """Compute the average SFD across examples.
+
+    Args:
+        predictions: List of generated texts.
+        support_texts_per_example: List of lists; each inner list contains
+            the user's support output texts.
+    """
+    sfds = []
+    for pred, support_texts in zip(predictions, support_texts_per_example):
+        if not pred.strip():
+            pred = "empty"
+        sfds.append(compute_sfd(pred, support_texts))
+    return sum(sfds) / max(len(sfds), 1)
+
+
+def compute_recovery(method_score: float, base_score: float, bm25_score: float) -> float:
+    """Compute the Recovery metric.
+
+    Recovery = (M_CVH - M_Base) / (M_BM25 - M_Base)
+    """
+    denom = bm25_score - base_score
+    if abs(denom) < 1e-8:
+        # Base and BM25 are indistinguishable, so Recovery is undefined;
+        # report 0.0 rather than dividing by (near) zero.
+        return 0.0
+    return (method_score - base_score) / denom
+
+
+def compute_compression(support_texts: list, theta_bytes: int = 128) -> float:
+    """Compute the Compression ratio.
+
+    Compression = bytes_of_K_support_texts / bytes_of_theta_u
+    """
+    total_bytes = sum(len(t.encode('utf-8')) for t in support_texts)
+    return total_bytes / theta_bytes
+
+
+def evaluate_all(predictions: list, references: list, support_texts_per_example: list) -> dict:
+    """Run all per-example metrics and return their averages."""
+    rouge = compute_rouge(predictions, references)
+    meteor = compute_meteor(predictions, references)
+    sfd = compute_avg_sfd(predictions, support_texts_per_example)
+
+    return {
+        'rouge1': rouge['rouge1'],
+        'rougeL': rouge['rougeL'],
+        'meteor': meteor,
+        'sfd': sfd,
+        'num_examples': len(predictions),
+    }
+
+
+def print_results_table(results_dict: dict):
+    """Print a comparison table of results across methods."""
+    methods = list(results_dict.keys())
+    metrics = ['rouge1', 'rougeL', 'meteor', 'sfd']
+
+    # Header
+    header = f"{'Method':<25}" + "".join(f"{m:<12}" for m in metrics)
+    print(header)
+    print("-" * len(header))
+
+    for method in methods:
+        r = results_dict[method]
+        row = f"{method:<25}"
+        for m in metrics:
+            val = r.get(m, 0.0)
+            row += f"{val:<12.4f}"
+        print(row)
+
+    # Recovery needs both the Base and BM25-Top1 reference rows.
+    if 'Base' in results_dict and 'BM25-Top1' in results_dict:
+        print("\n--- Recovery ---")
+        base_rl = results_dict['Base']['rougeL']
+        bm25_rl = results_dict['BM25-Top1']['rougeL']
+        base_m = results_dict['Base']['meteor']
+        bm25_m = results_dict['BM25-Top1']['meteor']
+
+        for method in methods:
+            if method in ('Base', 'BM25-Top1'):
+                continue
+            rl_rec = compute_recovery(results_dict[method]['rougeL'], base_rl, bm25_rl)
+            m_rec = compute_recovery(results_dict[method]['meteor'], base_m, bm25_m)
+            print(f"  {method}: ROUGE-L Recovery={rl_rec:.3f}, METEOR Recovery={m_rec:.3f}")
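A minimal usage sketch for the module above, assuming data.style_features.compute_sfd is importable and the NLTK 'punkt' and 'wordnet' data are installed. All strings and the 'CVH' method label are made-up placeholders, not project data; 'Base' and 'BM25-Top1' are included so the Recovery block is exercised:

    from eval.metrics import evaluate_all, print_results_table

    preds = ["the cat sat on the mat"]           # hypothetical generations
    refs = ["the cat sat on a mat"]              # hypothetical references
    support = [["an earlier post by the user"]]  # hypothetical support texts

    # Identical inputs per method here, so Base and BM25-Top1 tie and
    # compute_recovery falls back to 0.0 via its near-zero-denominator guard.
    results = {
        'Base': evaluate_all(preds, refs, support),
        'BM25-Top1': evaluate_all(preds, refs, support),
        'CVH': evaluate_all(preds, refs, support),
    }
    print_results_table(results)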
