| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-04 18:59:35 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-04 18:59:35 -0600 |
| commit | f1c2cc22d46a6976df3555391e667c7e61592fad (patch) | |
| tree | 0b37b52c8ff91042a742d3b3ec54542cb6d6e2f6 /analyze_results.py | |
Diffstat (limited to 'analyze_results.py')
| -rw-r--r-- | analyze_results.py | 741 |
1 file changed, 741 insertions, 0 deletions
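The script added below scans `--results_dir` for per-run metric files named `{precision_mode}_seed{seed}.json`, plus optional `*_sparsity.json` companions. A minimal sketch of the layout the parsing code appears to expect; the key names are taken from the diff, while the file names in the comments and every number are invented for illustration:

```python
# Hypothetical contents of results/eval_metrics/bf16_seed0.json -- the keys
# ("tasks", "ft_avg_score", "delta_j", "avg_kl") mirror what the script reads;
# all numbers are made up.
example_eval_run = {
    "tasks": {
        "dm_val":  {"ft_avg_score": 0.62, "delta_j": 0.11,  "avg_kl": 0.03},
        "math500": {"ft_avg_score": 0.45, "delta_j": -0.02, "avg_kl": 0.09},
    },
}

# Hypothetical companion file results/eval_metrics/bf16_seed0_sparsity.json,
# matching the "sparsity" block summarized at the end of main().
example_sparsity = {
    "sparsity": {"sparsity_percent": 93.5, "num_changed": 412000},
}
```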
diff --git a/analyze_results.py b/analyze_results.py
new file mode 100644
index 0000000..265c005
--- /dev/null
+++ b/analyze_results.py
@@ -0,0 +1,741 @@
+#!/usr/bin/env python3
+# analyze_results.py
+"""
+Results Analysis for RLVR Floating-Point Precision Experiments.
+
+This script analyzes the experimental results to verify the hypotheses:
+1. On-task performance is insensitive to floating-point noise
+2. Off-task performance and KL divergence are sensitive to precision
+
+Computes:
+- Mean and variance of ΔJ_k for each precision mode
+- KL divergence patterns for on-task vs off-task
+- Statistical tests for significance
+- Visualizations
+
+Usage:
+    python analyze_results.py \
+        --results_dir results/eval_metrics \
+        --output_dir results/analysis
+"""
+
+import argparse
+import json
+import os
+import glob
+from typing import Dict, Any, List, Tuple, Optional
+from collections import defaultdict
+import logging
+
+import numpy as np
+from scipy import stats
+
+# Optional: for visualizations
+try:
+    import matplotlib.pyplot as plt
+    import matplotlib
+    matplotlib.use('Agg')  # Non-interactive backend
+    HAS_MATPLOTLIB = True
+except ImportError:
+    HAS_MATPLOTLIB = False
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Data Loading
+# ============================================================================
+
+def load_eval_results(results_dir: str) -> Dict[str, Dict[str, Any]]:
+    """
+    Load all evaluation results from a directory.
+
+    Returns:
+        Dictionary mapping "{precision_mode}_seed{seed}" to results
+    """
+    results = {}
+
+    pattern = os.path.join(results_dir, "*.json")
+    for filepath in glob.glob(pattern):
+        filename = os.path.basename(filepath)
+
+        # Skip sparsity files
+        if "sparsity" in filename:
+            continue
+
+        # Parse filename: {precision_mode}_seed{seed}.json
+        name = filename.replace(".json", "")
+
+        try:
+            with open(filepath, "r") as f:
+                data = json.load(f)
+            results[name] = data
+            logger.info(f"Loaded {filepath}")
+        except Exception as e:
+            logger.warning(f"Failed to load {filepath}: {e}")
+
+    return results
+
+
+def load_sparsity_results(results_dir: str) -> Dict[str, Dict[str, Any]]:
+    """Load bf16 sparsity analysis results."""
+    results = {}
+
+    pattern = os.path.join(results_dir, "*_sparsity.json")
+    for filepath in glob.glob(pattern):
+        filename = os.path.basename(filepath)
+        name = filename.replace("_sparsity.json", "")
+
+        try:
+            with open(filepath, "r") as f:
+                data = json.load(f)
+            results[name] = data
+            logger.info(f"Loaded sparsity: {filepath}")
+        except Exception as e:
+            logger.warning(f"Failed to load {filepath}: {e}")
+
+    return results
+
+
+def parse_run_name(name: str) -> Tuple[str, int]:
+    """Parse run name to get precision mode and seed."""
+    # Format: {precision_mode}_seed{seed}
+    parts = name.split("_seed")
+    if len(parts) == 2:
+        precision_mode = parts[0]
+        seed = int(parts[1])
+        return precision_mode, seed
+
+    raise ValueError(f"Cannot parse run name: {name}")
+
+
+# ============================================================================
+# Metrics Aggregation
+# ============================================================================
+
+def aggregate_by_precision(
+    results: Dict[str, Dict[str, Any]]
+) -> Dict[str, Dict[str, List[float]]]:
+    """
+    Aggregate results by precision mode.
+
+    Returns:
+        {precision_mode: {task_name: [scores from different seeds]}}
+    """
+    aggregated: Dict[str, Dict[str, List[float]]] = defaultdict(
+        lambda: defaultdict(list)
+    )
+
+    for run_name, run_results in results.items():
+        try:
+            precision_mode, seed = parse_run_name(run_name)
+        except ValueError:
+            continue
+
+        tasks = run_results.get("tasks", {})
+        for task_name, task_data in tasks.items():
+            # Score
+            if "ft_avg_score" in task_data:
+                aggregated[precision_mode][f"{task_name}_score"].append(
+                    task_data["ft_avg_score"]
+                )
+
+            # Delta J
+            if "delta_j" in task_data:
+                aggregated[precision_mode][f"{task_name}_delta_j"].append(
+                    task_data["delta_j"]
+                )
+
+            # KL
+            if "avg_kl" in task_data:
+                aggregated[precision_mode][f"{task_name}_kl"].append(
+                    task_data["avg_kl"]
+                )
+
+    return dict(aggregated)
+
+
+def compute_statistics(values: List[float]) -> Dict[str, float]:
+    """Compute statistics for a list of values."""
+    if not values:
+        return {
+            "mean": 0.0,
+            "std": 0.0,
+            "var": 0.0,
+            "min": 0.0,
+            "max": 0.0,
+            "n": 0,
+        }
+
+    arr = np.array(values)
+    return {
+        "mean": float(np.mean(arr)),
+        "std": float(np.std(arr)),
+        "var": float(np.var(arr)),
+        "min": float(np.min(arr)),
+        "max": float(np.max(arr)),
+        "n": len(arr),
+    }
+
+
+# ============================================================================
+# Hypothesis Testing
+# ============================================================================
+
+def test_variance_ratio(
+    values1: List[float],
+    values2: List[float],
+    alpha: float = 0.05
+) -> Dict[str, Any]:
+    """
+    Test if variance of values2 is significantly greater than values1.
+
+    Uses F-test (Levene's test is more robust but F-test is simpler).
+    """
+    if len(values1) < 2 or len(values2) < 2:
+        return {
+            "test": "variance_ratio",
+            "valid": False,
+            "reason": "Not enough samples",
+        }
+
+    var1 = np.var(values1, ddof=1)
+    var2 = np.var(values2, ddof=1)
+
+    # F statistic
+    if var1 > 0:
+        f_stat = var2 / var1
+    else:
+        f_stat = float("inf")
+
+    df1 = len(values2) - 1
+    df2 = len(values1) - 1
+
+    # p-value (one-tailed: var2 > var1)
+    p_value = 1 - stats.f.cdf(f_stat, df1, df2)
+
+    return {
+        "test": "variance_ratio",
+        "valid": True,
+        "var1": var1,
+        "var2": var2,
+        "f_statistic": f_stat,
+        "p_value": p_value,
+        "significant": p_value < alpha,
+        "alpha": alpha,
+    }
+
+
+def test_mean_difference(
+    values1: List[float],
+    values2: List[float],
+    alpha: float = 0.05
+) -> Dict[str, Any]:
+    """
+    Test if means are significantly different (two-tailed t-test).
+    """
+    if len(values1) < 2 or len(values2) < 2:
+        return {
+            "test": "mean_difference",
+            "valid": False,
+            "reason": "Not enough samples",
+        }
+
+    t_stat, p_value = stats.ttest_ind(values1, values2)
+
+    return {
+        "test": "mean_difference",
+        "valid": True,
+        "mean1": float(np.mean(values1)),
+        "mean2": float(np.mean(values2)),
+        "t_statistic": float(t_stat),
+        "p_value": float(p_value),
+        "significant": p_value < alpha,
+        "alpha": alpha,
+    }
+
+
+# ============================================================================
+# Hypothesis Verification
+# ============================================================================
+
+def verify_hypothesis_1(
+    aggregated: Dict[str, Dict[str, List[float]]],
+    on_task_names: List[str]
+) -> Dict[str, Any]:
+    """
+    Verify Hypothesis 1: On-task performance is insensitive to precision.
+
+    Expected:
+    - E[ΔJ_0^{fp32}] ≈ E[ΔJ_0^{bf16}] > 0
+    - Var[ΔJ_0^{fp32}] and Var[ΔJ_0^{bf16}] are both small
+    """
+    results = {
+        "hypothesis": "On-task performance insensitive to precision",
+        "tasks": {},
+    }
+
+    fp32_data = aggregated.get("fp32", {})
+    bf16_data = aggregated.get("bf16", {})
+
+    for task_name in on_task_names:
+        key = f"{task_name}_delta_j"
+
+        fp32_values = fp32_data.get(key, [])
+        bf16_values = bf16_data.get(key, [])
+
+        task_result = {
+            "fp32": compute_statistics(fp32_values),
+            "bf16": compute_statistics(bf16_values),
+        }
+
+        # Test mean difference (should NOT be significant)
+        mean_test = test_mean_difference(fp32_values, bf16_values)
+        task_result["mean_test"] = mean_test
+
+        # Test variance ratio (should NOT show bf16 >> fp32)
+        var_test = test_variance_ratio(fp32_values, bf16_values)
+        task_result["variance_test"] = var_test
+
+        # Verify expected pattern
+        fp32_mean = task_result["fp32"]["mean"]
+        bf16_mean = task_result["bf16"]["mean"]
+        fp32_var = task_result["fp32"]["var"]
+        bf16_var = task_result["bf16"]["var"]
+
+        task_result["verification"] = {
+            "means_similar": abs(fp32_mean - bf16_mean) < 0.05,  # Within 5%
+            "both_positive": fp32_mean > 0 and bf16_mean > 0,
+            "variances_small": fp32_var < 0.01 and bf16_var < 0.01,
+            "hypothesis_supported": (
+                abs(fp32_mean - bf16_mean) < 0.05 and
+                fp32_mean > 0 and bf16_mean > 0
+            ),
+        }
+
+        results["tasks"][task_name] = task_result
+
+    # Overall verdict
+    all_supported = all(
+        t["verification"]["hypothesis_supported"]
+        for t in results["tasks"].values()
+    )
+    results["overall_supported"] = all_supported
+
+    return results
+
+
+def verify_hypothesis_2(
+    aggregated: Dict[str, Dict[str, List[float]]],
+    off_task_names: List[str],
+    on_task_names: List[str]
+) -> Dict[str, Any]:
+    """
+    Verify Hypothesis 2: Off-task performance is sensitive to precision.
+
+    Expected:
+    - Var[ΔJ_k^{bf16}] >> Var[ΔJ_k^{fp32}] >> Var[ΔJ_0]
+    """
+    results = {
+        "hypothesis": "Off-task performance sensitive to precision",
+        "tasks": {},
+    }
+
+    fp32_data = aggregated.get("fp32", {})
+    bf16_data = aggregated.get("bf16", {})
+
+    # Get on-task variance for comparison
+    on_task_variances = []
+    for task_name in on_task_names:
+        key = f"{task_name}_delta_j"
+        for precision_data in [fp32_data, bf16_data]:
+            values = precision_data.get(key, [])
+            if values:
+                on_task_variances.append(np.var(values))
+
+    on_task_avg_var = np.mean(on_task_variances) if on_task_variances else 0.0
+
+    for task_name in off_task_names:
+        key = f"{task_name}_delta_j"
+
+        fp32_values = fp32_data.get(key, [])
+        bf16_values = bf16_data.get(key, [])
+
+        task_result = {
+            "fp32": compute_statistics(fp32_values),
+            "bf16": compute_statistics(bf16_values),
+            "on_task_avg_var": on_task_avg_var,
+        }
+
+        # Test variance ratio (bf16 should be >> fp32)
+        var_test = test_variance_ratio(fp32_values, bf16_values)
+        task_result["variance_test"] = var_test
+
+        # Verify expected pattern
+        fp32_var = task_result["fp32"]["var"]
+        bf16_var = task_result["bf16"]["var"]
+
+        task_result["verification"] = {
+            "bf16_var_gt_fp32": bf16_var > fp32_var,
+            "bf16_var_gt_fp32_by_5x": bf16_var > 5 * fp32_var if fp32_var > 0 else False,
+            "fp32_var_gt_ontask": fp32_var > on_task_avg_var,
+            "variance_ratio": bf16_var / fp32_var if fp32_var > 0 else float("inf"),
+            "hypothesis_supported": bf16_var > fp32_var,
+        }
+
+        results["tasks"][task_name] = task_result
+
+    # Count how many tasks show expected pattern
+    supported_count = sum(
+        1 for t in results["tasks"].values()
+        if t["verification"]["hypothesis_supported"]
+    )
+    results["num_tasks_supported"] = supported_count
+    results["num_tasks_total"] = len(off_task_names)
+    results["overall_supported"] = supported_count > len(off_task_names) // 2
+
+    return results
+
+
+def verify_hypothesis_3(
+    aggregated: Dict[str, Dict[str, List[float]]],
+    on_task_names: List[str],
+    off_task_names: List[str]
+) -> Dict[str, Any]:
+    """
+    Verify Hypothesis 3: KL divergence patterns.
+
+    Expected:
+    - On-task KL is similar between fp32 and bf16 (DAPO implicit leash)
+    - Off-task KL has higher variance in bf16
+    """
+    results = {
+        "hypothesis": "KL divergence patterns differ by task type",
+        "on_task": {},
+        "off_task": {},
+    }
+
+    fp32_data = aggregated.get("fp32", {})
+    bf16_data = aggregated.get("bf16", {})
+
+    # On-task KL analysis
+    for task_name in on_task_names:
+        key = f"{task_name}_kl"
+
+        fp32_values = fp32_data.get(key, [])
+        bf16_values = bf16_data.get(key, [])
+
+        task_result = {
+            "fp32": compute_statistics(fp32_values),
+            "bf16": compute_statistics(bf16_values),
+        }
+
+        mean_test = test_mean_difference(fp32_values, bf16_values)
+        task_result["mean_test"] = mean_test
+
+        # Verify: KL should be similar (implicit leash working)
+        task_result["kl_similar"] = not mean_test.get("significant", True)
+
+        results["on_task"][task_name] = task_result
+
+    # Off-task KL analysis
+    for task_name in off_task_names:
+        key = f"{task_name}_kl"
+
+        fp32_values = fp32_data.get(key, [])
+        bf16_values = bf16_data.get(key, [])
+
+        task_result = {
+            "fp32": compute_statistics(fp32_values),
+            "bf16": compute_statistics(bf16_values),
+        }
+
+        var_test = test_variance_ratio(fp32_values, bf16_values)
+        task_result["variance_test"] = var_test
+
+        # Verify: bf16 should have higher variance
+        task_result["bf16_higher_variance"] = var_test.get("significant", False)
+
+        results["off_task"][task_name] = task_result
+
+    # Overall assessment
+    on_task_similar = all(
+        t.get("kl_similar", False) for t in results["on_task"].values()
+    )
+    off_task_variance_higher = sum(
+        1 for t in results["off_task"].values()
+        if t.get("bf16_higher_variance", False)
+    )
+
+    results["summary"] = {
+        "on_task_kl_similar": on_task_similar,
+        "off_task_higher_variance_count": off_task_variance_higher,
+        "off_task_total": len(off_task_names),
+    }
+
+    return results
+
+
+# ============================================================================
+# Visualization
+# ============================================================================
+
+def plot_delta_j_comparison(
+    aggregated: Dict[str, Dict[str, List[float]]],
+    task_names: List[str],
+    output_path: str
+) -> None:
+    """Plot ΔJ comparison between precision modes."""
+    if not HAS_MATPLOTLIB:
+        logger.warning("matplotlib not available, skipping plot")
+        return
+
+    fig, ax = plt.subplots(figsize=(12, 6))
+
+    x = np.arange(len(task_names))
+    width = 0.35
+
+    fp32_data = aggregated.get("fp32", {})
+    bf16_data = aggregated.get("bf16", {})
+
+    fp32_means = []
+    fp32_stds = []
+    bf16_means = []
+    bf16_stds = []
+
+    for task_name in task_names:
+        key = f"{task_name}_delta_j"
+
+        fp32_values = fp32_data.get(key, [0])
+        bf16_values = bf16_data.get(key, [0])
+
+        fp32_means.append(np.mean(fp32_values))
+        fp32_stds.append(np.std(fp32_values))
+        bf16_means.append(np.mean(bf16_values))
+        bf16_stds.append(np.std(bf16_values))
+
+    ax.bar(x - width/2, fp32_means, width, yerr=fp32_stds,
+           label='FP32', color='steelblue', capsize=5)
+    ax.bar(x + width/2, bf16_means, width, yerr=bf16_stds,
+           label='bf16', color='coral', capsize=5)
+
+    ax.set_ylabel('ΔJ (Performance Delta)')
+    ax.set_xlabel('Task')
+    ax.set_title('Performance Delta by Precision Mode')
+    ax.set_xticks(x)
+    ax.set_xticklabels(task_names, rotation=45, ha='right')
+    ax.legend()
+    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+
+    plt.tight_layout()
+    plt.savefig(output_path, dpi=150)
+    plt.close()
+
+    logger.info(f"Saved plot to {output_path}")
+
+
+def plot_variance_comparison(
+    aggregated: Dict[str, Dict[str, List[float]]],
+    task_names: List[str],
+    output_path: str
+) -> None:
+    """Plot variance comparison between precision modes."""
+    if not HAS_MATPLOTLIB:
+        logger.warning("matplotlib not available, skipping plot")
+        return
+
+    fig, ax = plt.subplots(figsize=(12, 6))
+
+    x = np.arange(len(task_names))
+    width = 0.35
+
+    fp32_data = aggregated.get("fp32", {})
+    bf16_data = aggregated.get("bf16", {})
+
+    fp32_vars = []
+    bf16_vars = []
+
+    for task_name in task_names:
+        key = f"{task_name}_delta_j"
+
+        fp32_values = fp32_data.get(key, [0])
+        bf16_values = bf16_data.get(key, [0])
+
+        fp32_vars.append(np.var(fp32_values))
+        bf16_vars.append(np.var(bf16_values))
+
+    ax.bar(x - width/2, fp32_vars, width, label='FP32', color='steelblue')
+    ax.bar(x + width/2, bf16_vars, width, label='bf16', color='coral')
+
+    ax.set_ylabel('Variance of ΔJ')
+    ax.set_xlabel('Task')
+    ax.set_title('Variance of Performance Delta by Precision Mode')
+    ax.set_xticks(x)
+    ax.set_xticklabels(task_names, rotation=45, ha='right')
+    ax.legend()
+    ax.set_yscale('log')
+
+    plt.tight_layout()
+    plt.savefig(output_path, dpi=150)
+    plt.close()
+
+    logger.info(f"Saved plot to {output_path}")
+
+
+# ============================================================================
+# Main Analysis
+# ============================================================================
+
+def parse_args() -> argparse.Namespace:
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser(
+        description="Analyze RLVR floating-point precision experiment results"
+    )
+    parser.add_argument(
+        "--results_dir",
+        type=str,
+        required=True,
+        help="Directory containing evaluation results"
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        required=True,
+        help="Directory to save analysis outputs"
+    )
+    parser.add_argument(
+        "--on_task",
+        type=str,
+        nargs="+",
+        default=["dm_val"],
+        help="On-task (training distribution) task names"
+    )
+    parser.add_argument(
+        "--off_task",
+        type=str,
+        nargs="+",
+        default=["aime24", "aime25", "amc23", "math500", "mmlu_stem", "humaneval"],
+        help="Off-task task names"
+    )
+    return parser.parse_args()
+
+
+def main() -> None:
+    """Main analysis function."""
+    args = parse_args()
+
+    # Create output directory
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Load results
+    logger.info(f"Loading results from {args.results_dir}")
+    eval_results = load_eval_results(args.results_dir)
+    sparsity_results = load_sparsity_results(args.results_dir)
+
+    if not eval_results:
+        logger.error("No evaluation results found!")
+        return
+
+    # Aggregate by precision mode
+    aggregated = aggregate_by_precision(eval_results)
+
+    # Get all task names from results
+    all_tasks = set()
+    for run_results in eval_results.values():
+        all_tasks.update(run_results.get("tasks", {}).keys())
+
+    # Filter task names to those present in results
+    on_task_names = [t for t in args.on_task if t in all_tasks]
+    off_task_names = [t for t in args.off_task if t in all_tasks]
+
+    logger.info(f"On-task: {on_task_names}")
+    logger.info(f"Off-task: {off_task_names}")
+
+    # Verify hypotheses
+    analysis = {}
+
+    logger.info("\n" + "="*60)
+    logger.info("HYPOTHESIS 1: On-task insensitivity")
+    logger.info("="*60)
+    h1_result = verify_hypothesis_1(aggregated, on_task_names)
+    analysis["hypothesis_1"] = h1_result
+    logger.info(f"Supported: {h1_result['overall_supported']}")
+
+    logger.info("\n" + "="*60)
+    logger.info("HYPOTHESIS 2: Off-task sensitivity")
+    logger.info("="*60)
+    h2_result = verify_hypothesis_2(aggregated, off_task_names, on_task_names)
+    analysis["hypothesis_2"] = h2_result
+    logger.info(f"Supported: {h2_result['overall_supported']} "
+                f"({h2_result['num_tasks_supported']}/{h2_result['num_tasks_total']} tasks)")
+
+    logger.info("\n" + "="*60)
+    logger.info("HYPOTHESIS 3: KL divergence patterns")
+    logger.info("="*60)
+    h3_result = verify_hypothesis_3(aggregated, on_task_names, off_task_names)
+    analysis["hypothesis_3"] = h3_result
+    logger.info(f"On-task KL similar: {h3_result['summary']['on_task_kl_similar']}")
+
+    # Add sparsity analysis
+    if sparsity_results:
+        sparsity_summary = {}
+        for name, data in sparsity_results.items():
+            sparsity_info = data.get("sparsity", {})
+            sparsity_summary[name] = {
+                "sparsity_percent": sparsity_info.get("sparsity_percent", 0),
+                "num_changed": sparsity_info.get("num_changed", 0),
+            }
+        analysis["bf16_sparsity"] = sparsity_summary
+
+    # Save full analysis
+    analysis_path = os.path.join(args.output_dir, "full_analysis.json")
+    with open(analysis_path, "w") as f:
+        json.dump(analysis, f, indent=2, default=str)
+    logger.info(f"Saved analysis to {analysis_path}")
+
+    # Generate plots
+    all_task_names = on_task_names + off_task_names
+
+    plot_delta_j_comparison(
+        aggregated,
+        all_task_names,
+        os.path.join(args.output_dir, "delta_j_comparison.png")
+    )
+
+    plot_variance_comparison(
+        aggregated,
+        all_task_names,
+        os.path.join(args.output_dir, "variance_comparison.png")
+    )
+
+    # Print summary
+    print("\n" + "="*80)
+    print("ANALYSIS SUMMARY")
+    print("="*80)
+
+    print("\nHypothesis 1 (On-task insensitivity):")
+    print(f"  Supported: {h1_result['overall_supported']}")
+
+    print("\nHypothesis 2 (Off-task sensitivity):")
+    print(f"  Supported: {h2_result['overall_supported']}")
+    print(f"  Tasks showing expected pattern: {h2_result['num_tasks_supported']}/{h2_result['num_tasks_total']}")
+
+    print("\nHypothesis 3 (KL patterns):")
+    print(f"  On-task KL similar across precision: {h3_result['summary']['on_task_kl_similar']}")
+    print(f"  Off-task with higher bf16 variance: {h3_result['summary']['off_task_higher_variance_count']}/{h3_result['summary']['off_task_total']}")
+
+    if sparsity_results:
+        print("\nbf16 Sparsity:")
+        for name, data in sorted(sparsity_summary.items()):
+            print(f"  {name}: {data['sparsity_percent']:.1f}% sparse")
+
+    print("="*80)
+
+
+if __name__ == "__main__":
+    main()
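The docstring of `test_variance_ratio` notes that Levene's test is more robust than the F-test it uses. A minimal sketch of what a drop-in variant could look like; the function name and return layout are assumptions that mirror the script's style, and note that `scipy.stats.levene` is a two-sided test of equal variances rather than the one-tailed check above:

```python
from typing import Any, Dict, List

from scipy import stats


def test_variance_levene(values1: List[float], values2: List[float],
                         alpha: float = 0.05) -> Dict[str, Any]:
    """Hypothetical Levene-based alternative to test_variance_ratio.

    A small p-value only says the two variances differ; compare the sample
    variances directly to see which group is more spread out.
    """
    if len(values1) < 2 or len(values2) < 2:
        return {"test": "levene", "valid": False, "reason": "Not enough samples"}

    stat, p_value = stats.levene(values1, values2, center="median")
    return {
        "test": "levene",
        "valid": True,
        "statistic": float(stat),
        "p_value": float(p_value),
        "significant": p_value < alpha,
        "alpha": alpha,
    }
```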

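A small sketch of how the saved verdicts could be read back afterwards; the path follows the usage example in the module docstring and the keys follow what `main()` stores in `full_analysis.json` (purely illustrative):

```python
import json

# Load the analysis written by analyze_results.py and print the headline
# verdicts; the directory comes from the docstring's usage example.
with open("results/analysis/full_analysis.json") as f:
    analysis = json.load(f)

h2 = analysis["hypothesis_2"]
print("H1 on-task insensitive:", analysis["hypothesis_1"]["overall_supported"])
print("H2 off-task sensitive: ", h2["overall_supported"],
      f"({h2['num_tasks_supported']}/{h2['num_tasks_total']} tasks)")
print("H3 on-task KL similar: ",
      analysis["hypothesis_3"]["summary"]["on_task_kl_similar"])
```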