author     YurenHao0426 <blackhao0426@gmail.com>  2026-02-04 18:59:35 -0600
committer  YurenHao0426 <blackhao0426@gmail.com>  2026-02-04 18:59:35 -0600
commit     f1c2cc22d46a6976df3555391e667c7e61592fad (patch)
tree       0b37b52c8ff91042a742d3b3ec54542cb6d6e2f6 /analyze_results.py

Initial commit: RL floating-point noise project  (HEAD, main)

Diffstat (limited to 'analyze_results.py'):
 -rw-r--r--  analyze_results.py  741
 1 file changed, 741 insertions, 0 deletions
diff --git a/analyze_results.py b/analyze_results.py
new file mode 100644
index 0000000..265c005
--- /dev/null
+++ b/analyze_results.py
@@ -0,0 +1,741 @@
+#!/usr/bin/env python3
+# analyze_results.py
+"""
+Results Analysis for RLVR Floating-Point Precision Experiments.
+
+This script analyzes the experimental results to verify the hypotheses:
+1. On-task performance is insensitive to floating-point noise
+2. Off-task performance and KL divergence are sensitive to precision
+
+Computes:
+- Mean and variance of ΔJ_k for each precision mode
+- KL divergence patterns for on-task vs off-task
+- Statistical tests for significance
+- Visualizations
+
+Usage:
+ python analyze_results.py \
+ --results_dir results/eval_metrics \
+ --output_dir results/analysis
+"""
+
+import argparse
+import json
+import os
+import glob
+from typing import Dict, Any, List, Tuple, Optional
+from collections import defaultdict
+import logging
+
+import numpy as np
+from scipy import stats
+
+# Optional: for visualizations
+try:
+    import matplotlib
+    matplotlib.use('Agg')  # Select the non-interactive backend before importing pyplot
+    import matplotlib.pyplot as plt
+    HAS_MATPLOTLIB = True
+except ImportError:
+ HAS_MATPLOTLIB = False
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Data Loading
+# ============================================================================
+
+def load_eval_results(results_dir: str) -> Dict[str, Dict[str, Any]]:
+ """
+ Load all evaluation results from a directory.
+
+ Returns:
+ Dictionary mapping "{precision_mode}_seed{seed}" to results
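+
+    Each (non-sparsity) results file is expected to contain per-task metrics,
+    roughly of the form:
+        {"tasks": {"<task_name>": {"ft_avg_score": ..., "delta_j": ..., "avg_kl": ...}}}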
+ """
+ results = {}
+
+ pattern = os.path.join(results_dir, "*.json")
+ for filepath in glob.glob(pattern):
+ filename = os.path.basename(filepath)
+
+ # Skip sparsity files
+ if "sparsity" in filename:
+ continue
+
+ # Parse filename: {precision_mode}_seed{seed}.json
+ name = filename.replace(".json", "")
+
+ try:
+ with open(filepath, "r") as f:
+ data = json.load(f)
+ results[name] = data
+ logger.info(f"Loaded {filepath}")
+ except Exception as e:
+ logger.warning(f"Failed to load {filepath}: {e}")
+
+ return results
+
+
+def load_sparsity_results(results_dir: str) -> Dict[str, Dict[str, Any]]:
+ """Load bf16 sparsity analysis results."""
+ results = {}
+
+ pattern = os.path.join(results_dir, "*_sparsity.json")
+ for filepath in glob.glob(pattern):
+ filename = os.path.basename(filepath)
+ name = filename.replace("_sparsity.json", "")
+
+ try:
+ with open(filepath, "r") as f:
+ data = json.load(f)
+ results[name] = data
+ logger.info(f"Loaded sparsity: {filepath}")
+ except Exception as e:
+ logger.warning(f"Failed to load {filepath}: {e}")
+
+ return results
+
+
+def parse_run_name(name: str) -> Tuple[str, int]:
+ """Parse run name to get precision mode and seed."""
+ # Format: {precision_mode}_seed{seed}
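+    # e.g. "bf16_seed42" -> ("bf16", 42)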
+ parts = name.split("_seed")
+ if len(parts) == 2:
+ precision_mode = parts[0]
+ seed = int(parts[1])
+ return precision_mode, seed
+
+ raise ValueError(f"Cannot parse run name: {name}")
+
+
+# ============================================================================
+# Metrics Aggregation
+# ============================================================================
+
+def aggregate_by_precision(
+ results: Dict[str, Dict[str, Any]]
+) -> Dict[str, Dict[str, List[float]]]:
+ """
+ Aggregate results by precision mode.
+
+ Returns:
+        {precision_mode: {"{task_name}_{metric}": [values from different seeds]}}
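+        e.g. {"bf16": {"dm_val_score": [...], "dm_val_delta_j": [...], "dm_val_kl": [...]}}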
+ """
+ aggregated: Dict[str, Dict[str, List[float]]] = defaultdict(
+ lambda: defaultdict(list)
+ )
+
+ for run_name, run_results in results.items():
+ try:
+ precision_mode, seed = parse_run_name(run_name)
+ except ValueError:
+ continue
+
+ tasks = run_results.get("tasks", {})
+ for task_name, task_data in tasks.items():
+ # Score
+ if "ft_avg_score" in task_data:
+ aggregated[precision_mode][f"{task_name}_score"].append(
+ task_data["ft_avg_score"]
+ )
+
+ # Delta J
+ if "delta_j" in task_data:
+ aggregated[precision_mode][f"{task_name}_delta_j"].append(
+ task_data["delta_j"]
+ )
+
+ # KL
+ if "avg_kl" in task_data:
+ aggregated[precision_mode][f"{task_name}_kl"].append(
+ task_data["avg_kl"]
+ )
+
+ return dict(aggregated)
+
+
+def compute_statistics(values: List[float]) -> Dict[str, float]:
+ """Compute statistics for a list of values."""
+ if not values:
+ return {
+ "mean": 0.0,
+ "std": 0.0,
+ "var": 0.0,
+ "min": 0.0,
+ "max": 0.0,
+ "n": 0,
+ }
+
+ arr = np.array(values)
+ return {
+ "mean": float(np.mean(arr)),
+ "std": float(np.std(arr)),
+ "var": float(np.var(arr)),
+ "min": float(np.min(arr)),
+ "max": float(np.max(arr)),
+ "n": len(arr),
+ }
+
+
+# ============================================================================
+# Hypothesis Testing
+# ============================================================================
+
+def test_variance_ratio(
+ values1: List[float],
+ values2: List[float],
+ alpha: float = 0.05
+) -> Dict[str, Any]:
+ """
+ Test if variance of values2 is significantly greater than values1.
+
+ Uses F-test (Levene's test is more robust but F-test is simpler).
+ """
+ if len(values1) < 2 or len(values2) < 2:
+ return {
+ "test": "variance_ratio",
+ "valid": False,
+ "reason": "Not enough samples",
+ }
+
+ var1 = np.var(values1, ddof=1)
+ var2 = np.var(values2, ddof=1)
+
+ # F statistic
+ if var1 > 0:
+ f_stat = var2 / var1
+ else:
+ f_stat = float("inf")
+
+ df1 = len(values2) - 1
+ df2 = len(values1) - 1
+
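+    # Under the null hypothesis of equal variances (and roughly normal samples),
+    # var2/var1 follows an F(df1, df2) distribution, so a small one-tailed
+    # p-value indicates that var2 is significantly larger than var1.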
+ # p-value (one-tailed: var2 > var1)
+ p_value = 1 - stats.f.cdf(f_stat, df1, df2)
+
+    return {
+        "test": "variance_ratio",
+        "valid": True,
+        "var1": float(var1),
+        "var2": float(var2),
+        "f_statistic": float(f_stat),
+        "p_value": float(p_value),
+        "significant": bool(p_value < alpha),
+        "alpha": alpha,
+    }
+
+
+def test_mean_difference(
+ values1: List[float],
+ values2: List[float],
+ alpha: float = 0.05
+) -> Dict[str, Any]:
+ """
+ Test if means are significantly different (two-tailed t-test).
+ """
+ if len(values1) < 2 or len(values2) < 2:
+ return {
+ "test": "mean_difference",
+ "valid": False,
+ "reason": "Not enough samples",
+ }
+
+ t_stat, p_value = stats.ttest_ind(values1, values2)
+
+ return {
+ "test": "mean_difference",
+ "valid": True,
+ "mean1": float(np.mean(values1)),
+ "mean2": float(np.mean(values2)),
+ "t_statistic": float(t_stat),
+ "p_value": float(p_value),
+ "significant": p_value < alpha,
+ "alpha": alpha,
+ }
+
+
+# ============================================================================
+# Hypothesis Verification
+# ============================================================================
+
+def verify_hypothesis_1(
+ aggregated: Dict[str, Dict[str, List[float]]],
+ on_task_names: List[str]
+) -> Dict[str, Any]:
+ """
+ Verify Hypothesis 1: On-task performance is insensitive to precision.
+
+ Expected:
+ - E[ΔJ_0^{fp32}] ≈ E[ΔJ_0^{bf16}] > 0
+ - Var[ΔJ_0^{fp32}] and Var[ΔJ_0^{bf16}] are both small
+ """
+ results = {
+ "hypothesis": "On-task performance insensitive to precision",
+ "tasks": {},
+ }
+
+ fp32_data = aggregated.get("fp32", {})
+ bf16_data = aggregated.get("bf16", {})
+
+ for task_name in on_task_names:
+ key = f"{task_name}_delta_j"
+
+ fp32_values = fp32_data.get(key, [])
+ bf16_values = bf16_data.get(key, [])
+
+ task_result = {
+ "fp32": compute_statistics(fp32_values),
+ "bf16": compute_statistics(bf16_values),
+ }
+
+ # Test mean difference (should NOT be significant)
+ mean_test = test_mean_difference(fp32_values, bf16_values)
+ task_result["mean_test"] = mean_test
+
+ # Test variance ratio (should NOT show bf16 >> fp32)
+ var_test = test_variance_ratio(fp32_values, bf16_values)
+ task_result["variance_test"] = var_test
+
+ # Verify expected pattern
+ fp32_mean = task_result["fp32"]["mean"]
+ bf16_mean = task_result["bf16"]["mean"]
+ fp32_var = task_result["fp32"]["var"]
+ bf16_var = task_result["bf16"]["var"]
+
+ task_result["verification"] = {
+ "means_similar": abs(fp32_mean - bf16_mean) < 0.05, # Within 5%
+ "both_positive": fp32_mean > 0 and bf16_mean > 0,
+ "variances_small": fp32_var < 0.01 and bf16_var < 0.01,
+ "hypothesis_supported": (
+ abs(fp32_mean - bf16_mean) < 0.05 and
+ fp32_mean > 0 and bf16_mean > 0
+ ),
+ }
+
+ results["tasks"][task_name] = task_result
+
+ # Overall verdict
+ all_supported = all(
+ t["verification"]["hypothesis_supported"]
+ for t in results["tasks"].values()
+ )
+ results["overall_supported"] = all_supported
+
+ return results
+
+
+def verify_hypothesis_2(
+ aggregated: Dict[str, Dict[str, List[float]]],
+ off_task_names: List[str],
+ on_task_names: List[str]
+) -> Dict[str, Any]:
+ """
+ Verify Hypothesis 2: Off-task performance is sensitive to precision.
+
+ Expected:
+ - Var[ΔJ_k^{bf16}] >> Var[ΔJ_k^{fp32}] >> Var[ΔJ_0]
+ """
+ results = {
+ "hypothesis": "Off-task performance sensitive to precision",
+ "tasks": {},
+ }
+
+ fp32_data = aggregated.get("fp32", {})
+ bf16_data = aggregated.get("bf16", {})
+
+ # Get on-task variance for comparison
+ on_task_variances = []
+ for task_name in on_task_names:
+ key = f"{task_name}_delta_j"
+ for precision_data in [fp32_data, bf16_data]:
+ values = precision_data.get(key, [])
+ if values:
+ on_task_variances.append(np.var(values))
+
+ on_task_avg_var = np.mean(on_task_variances) if on_task_variances else 0.0
+
+ for task_name in off_task_names:
+ key = f"{task_name}_delta_j"
+
+ fp32_values = fp32_data.get(key, [])
+ bf16_values = bf16_data.get(key, [])
+
+ task_result = {
+ "fp32": compute_statistics(fp32_values),
+ "bf16": compute_statistics(bf16_values),
+ "on_task_avg_var": on_task_avg_var,
+ }
+
+ # Test variance ratio (bf16 should be >> fp32)
+ var_test = test_variance_ratio(fp32_values, bf16_values)
+ task_result["variance_test"] = var_test
+
+ # Verify expected pattern
+ fp32_var = task_result["fp32"]["var"]
+ bf16_var = task_result["bf16"]["var"]
+
+ task_result["verification"] = {
+ "bf16_var_gt_fp32": bf16_var > fp32_var,
+ "bf16_var_gt_fp32_by_5x": bf16_var > 5 * fp32_var if fp32_var > 0 else False,
+ "fp32_var_gt_ontask": fp32_var > on_task_avg_var,
+ "variance_ratio": bf16_var / fp32_var if fp32_var > 0 else float("inf"),
+ "hypothesis_supported": bf16_var > fp32_var,
+ }
+
+ results["tasks"][task_name] = task_result
+
+ # Count how many tasks show expected pattern
+ supported_count = sum(
+ 1 for t in results["tasks"].values()
+ if t["verification"]["hypothesis_supported"]
+ )
+ results["num_tasks_supported"] = supported_count
+ results["num_tasks_total"] = len(off_task_names)
+ results["overall_supported"] = supported_count > len(off_task_names) // 2
+
+ return results
+
+
+def verify_hypothesis_3(
+ aggregated: Dict[str, Dict[str, List[float]]],
+ on_task_names: List[str],
+ off_task_names: List[str]
+) -> Dict[str, Any]:
+ """
+ Verify Hypothesis 3: KL divergence patterns.
+
+ Expected:
+ - On-task KL is similar between fp32 and bf16 (DAPO implicit leash)
+ - Off-task KL has higher variance in bf16
+ """
+ results = {
+ "hypothesis": "KL divergence patterns differ by task type",
+ "on_task": {},
+ "off_task": {},
+ }
+
+ fp32_data = aggregated.get("fp32", {})
+ bf16_data = aggregated.get("bf16", {})
+
+ # On-task KL analysis
+ for task_name in on_task_names:
+ key = f"{task_name}_kl"
+
+ fp32_values = fp32_data.get(key, [])
+ bf16_values = bf16_data.get(key, [])
+
+ task_result = {
+ "fp32": compute_statistics(fp32_values),
+ "bf16": compute_statistics(bf16_values),
+ }
+
+ mean_test = test_mean_difference(fp32_values, bf16_values)
+ task_result["mean_test"] = mean_test
+
+ # Verify: KL should be similar (implicit leash working)
+ task_result["kl_similar"] = not mean_test.get("significant", True)
+
+ results["on_task"][task_name] = task_result
+
+ # Off-task KL analysis
+ for task_name in off_task_names:
+ key = f"{task_name}_kl"
+
+ fp32_values = fp32_data.get(key, [])
+ bf16_values = bf16_data.get(key, [])
+
+ task_result = {
+ "fp32": compute_statistics(fp32_values),
+ "bf16": compute_statistics(bf16_values),
+ }
+
+ var_test = test_variance_ratio(fp32_values, bf16_values)
+ task_result["variance_test"] = var_test
+
+ # Verify: bf16 should have higher variance
+ task_result["bf16_higher_variance"] = var_test.get("significant", False)
+
+ results["off_task"][task_name] = task_result
+
+ # Overall assessment
+ on_task_similar = all(
+ t.get("kl_similar", False) for t in results["on_task"].values()
+ )
+ off_task_variance_higher = sum(
+ 1 for t in results["off_task"].values()
+ if t.get("bf16_higher_variance", False)
+ )
+
+ results["summary"] = {
+ "on_task_kl_similar": on_task_similar,
+ "off_task_higher_variance_count": off_task_variance_higher,
+ "off_task_total": len(off_task_names),
+ }
+
+ return results
+
+
+# ============================================================================
+# Visualization
+# ============================================================================
+
+def plot_delta_j_comparison(
+ aggregated: Dict[str, Dict[str, List[float]]],
+ task_names: List[str],
+ output_path: str
+) -> None:
+ """Plot ΔJ comparison between precision modes."""
+ if not HAS_MATPLOTLIB:
+ logger.warning("matplotlib not available, skipping plot")
+ return
+
+ fig, ax = plt.subplots(figsize=(12, 6))
+
+ x = np.arange(len(task_names))
+ width = 0.35
+
+ fp32_data = aggregated.get("fp32", {})
+ bf16_data = aggregated.get("bf16", {})
+
+ fp32_means = []
+ fp32_stds = []
+ bf16_means = []
+ bf16_stds = []
+
+ for task_name in task_names:
+ key = f"{task_name}_delta_j"
+
+ fp32_values = fp32_data.get(key, [0])
+ bf16_values = bf16_data.get(key, [0])
+
+ fp32_means.append(np.mean(fp32_values))
+ fp32_stds.append(np.std(fp32_values))
+ bf16_means.append(np.mean(bf16_values))
+ bf16_stds.append(np.std(bf16_values))
+
+ ax.bar(x - width/2, fp32_means, width, yerr=fp32_stds,
+ label='FP32', color='steelblue', capsize=5)
+ ax.bar(x + width/2, bf16_means, width, yerr=bf16_stds,
+ label='bf16', color='coral', capsize=5)
+
+ ax.set_ylabel('ΔJ (Performance Delta)')
+ ax.set_xlabel('Task')
+ ax.set_title('Performance Delta by Precision Mode')
+ ax.set_xticks(x)
+ ax.set_xticklabels(task_names, rotation=45, ha='right')
+ ax.legend()
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+
+ plt.tight_layout()
+ plt.savefig(output_path, dpi=150)
+ plt.close()
+
+ logger.info(f"Saved plot to {output_path}")
+
+
+def plot_variance_comparison(
+ aggregated: Dict[str, Dict[str, List[float]]],
+ task_names: List[str],
+ output_path: str
+) -> None:
+ """Plot variance comparison between precision modes."""
+ if not HAS_MATPLOTLIB:
+ logger.warning("matplotlib not available, skipping plot")
+ return
+
+ fig, ax = plt.subplots(figsize=(12, 6))
+
+ x = np.arange(len(task_names))
+ width = 0.35
+
+ fp32_data = aggregated.get("fp32", {})
+ bf16_data = aggregated.get("bf16", {})
+
+ fp32_vars = []
+ bf16_vars = []
+
+ for task_name in task_names:
+ key = f"{task_name}_delta_j"
+
+ fp32_values = fp32_data.get(key, [0])
+ bf16_values = bf16_data.get(key, [0])
+
+ fp32_vars.append(np.var(fp32_values))
+ bf16_vars.append(np.var(bf16_values))
+
+ ax.bar(x - width/2, fp32_vars, width, label='FP32', color='steelblue')
+ ax.bar(x + width/2, bf16_vars, width, label='bf16', color='coral')
+
+ ax.set_ylabel('Variance of ΔJ')
+ ax.set_xlabel('Task')
+ ax.set_title('Variance of Performance Delta by Precision Mode')
+ ax.set_xticks(x)
+ ax.set_xticklabels(task_names, rotation=45, ha='right')
+ ax.legend()
+ ax.set_yscale('log')
+
+ plt.tight_layout()
+ plt.savefig(output_path, dpi=150)
+ plt.close()
+
+ logger.info(f"Saved plot to {output_path}")
+
+
+# ============================================================================
+# Main Analysis
+# ============================================================================
+
+def parse_args() -> argparse.Namespace:
+ """Parse command line arguments."""
+ parser = argparse.ArgumentParser(
+ description="Analyze RLVR floating-point precision experiment results"
+ )
+ parser.add_argument(
+ "--results_dir",
+ type=str,
+ required=True,
+ help="Directory containing evaluation results"
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ required=True,
+ help="Directory to save analysis outputs"
+ )
+ parser.add_argument(
+ "--on_task",
+ type=str,
+ nargs="+",
+ default=["dm_val"],
+ help="On-task (training distribution) task names"
+ )
+ parser.add_argument(
+ "--off_task",
+ type=str,
+ nargs="+",
+ default=["aime24", "aime25", "amc23", "math500", "mmlu_stem", "humaneval"],
+ help="Off-task task names"
+ )
+ return parser.parse_args()
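+
+
+# Example invocation overriding the default task lists:
+#   python analyze_results.py \
+#       --results_dir results/eval_metrics \
+#       --output_dir results/analysis \
+#       --on_task dm_val \
+#       --off_task aime24 math500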
+
+
+def main() -> None:
+ """Main analysis function."""
+ args = parse_args()
+
+ # Create output directory
+ os.makedirs(args.output_dir, exist_ok=True)
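+    # Everything below writes into this directory: full_analysis.json plus, when
+    # matplotlib is available, delta_j_comparison.png and variance_comparison.png.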
+
+ # Load results
+ logger.info(f"Loading results from {args.results_dir}")
+ eval_results = load_eval_results(args.results_dir)
+ sparsity_results = load_sparsity_results(args.results_dir)
+
+ if not eval_results:
+ logger.error("No evaluation results found!")
+ return
+
+ # Aggregate by precision mode
+ aggregated = aggregate_by_precision(eval_results)
+
+ # Get all task names from results
+ all_tasks = set()
+ for run_results in eval_results.values():
+ all_tasks.update(run_results.get("tasks", {}).keys())
+
+ # Filter task names to those present in results
+ on_task_names = [t for t in args.on_task if t in all_tasks]
+ off_task_names = [t for t in args.off_task if t in all_tasks]
+
+ logger.info(f"On-task: {on_task_names}")
+ logger.info(f"Off-task: {off_task_names}")
+
+ # Verify hypotheses
+ analysis = {}
+
+ logger.info("\n" + "="*60)
+ logger.info("HYPOTHESIS 1: On-task insensitivity")
+ logger.info("="*60)
+ h1_result = verify_hypothesis_1(aggregated, on_task_names)
+ analysis["hypothesis_1"] = h1_result
+ logger.info(f"Supported: {h1_result['overall_supported']}")
+
+ logger.info("\n" + "="*60)
+ logger.info("HYPOTHESIS 2: Off-task sensitivity")
+ logger.info("="*60)
+ h2_result = verify_hypothesis_2(aggregated, off_task_names, on_task_names)
+ analysis["hypothesis_2"] = h2_result
+ logger.info(f"Supported: {h2_result['overall_supported']} "
+ f"({h2_result['num_tasks_supported']}/{h2_result['num_tasks_total']} tasks)")
+
+ logger.info("\n" + "="*60)
+ logger.info("HYPOTHESIS 3: KL divergence patterns")
+ logger.info("="*60)
+ h3_result = verify_hypothesis_3(aggregated, on_task_names, off_task_names)
+ analysis["hypothesis_3"] = h3_result
+ logger.info(f"On-task KL similar: {h3_result['summary']['on_task_kl_similar']}")
+
+ # Add sparsity analysis
+ if sparsity_results:
+ sparsity_summary = {}
+ for name, data in sparsity_results.items():
+ sparsity_info = data.get("sparsity", {})
+ sparsity_summary[name] = {
+ "sparsity_percent": sparsity_info.get("sparsity_percent", 0),
+ "num_changed": sparsity_info.get("num_changed", 0),
+ }
+ analysis["bf16_sparsity"] = sparsity_summary
+
+ # Save full analysis
+ analysis_path = os.path.join(args.output_dir, "full_analysis.json")
+ with open(analysis_path, "w") as f:
+        json.dump(analysis, f, indent=2, default=str)  # default=str covers any remaining numpy scalars/bools
+ logger.info(f"Saved analysis to {analysis_path}")
+
+ # Generate plots
+ all_task_names = on_task_names + off_task_names
+
+ plot_delta_j_comparison(
+ aggregated,
+ all_task_names,
+ os.path.join(args.output_dir, "delta_j_comparison.png")
+ )
+
+ plot_variance_comparison(
+ aggregated,
+ all_task_names,
+ os.path.join(args.output_dir, "variance_comparison.png")
+ )
+
+ # Print summary
+ print("\n" + "="*80)
+ print("ANALYSIS SUMMARY")
+ print("="*80)
+
+ print("\nHypothesis 1 (On-task insensitivity):")
+ print(f" Supported: {h1_result['overall_supported']}")
+
+ print("\nHypothesis 2 (Off-task sensitivity):")
+ print(f" Supported: {h2_result['overall_supported']}")
+ print(f" Tasks showing expected pattern: {h2_result['num_tasks_supported']}/{h2_result['num_tasks_total']}")
+
+ print("\nHypothesis 3 (KL patterns):")
+ print(f" On-task KL similar across precision: {h3_result['summary']['on_task_kl_similar']}")
+ print(f" Off-task with higher bf16 variance: {h3_result['summary']['off_task_higher_variance_count']}/{h3_result['summary']['off_task_total']}")
+
+ if sparsity_results:
+ print("\nbf16 Sparsity:")
+ for name, data in sorted(sparsity_summary.items()):
+ print(f" {name}: {data['sparsity_percent']:.1f}% sparse")
+
+ print("="*80)
+
+
+if __name__ == "__main__":
+ main()
+