author    YurenHao0426 <blackhao0426@gmail.com>  2026-02-04 18:59:35 -0600
committer YurenHao0426 <blackhao0426@gmail.com>  2026-02-04 18:59:35 -0600
commit    f1c2cc22d46a6976df3555391e667c7e61592fad (patch)
tree      0b37b52c8ff91042a742d3b3ec54542cb6d6e2f6 /utils_bf16_sparsity.py

Initial commit: RL floating-point noise project (HEAD, main)

Diffstat (limited to 'utils_bf16_sparsity.py')
 utils_bf16_sparsity.py | 459 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 459 insertions(+), 0 deletions(-)
diff --git a/utils_bf16_sparsity.py b/utils_bf16_sparsity.py
new file mode 100644
index 0000000..2e0729d
--- /dev/null
+++ b/utils_bf16_sparsity.py
@@ -0,0 +1,459 @@
+# utils_bf16_sparsity.py
+"""
+bf16-Aware Update Sparsity Utilities.
+
+This module implements the bf16-aware update sparsity metric from the RLVR paper,
+which measures how many parameter updates are "visible" after bf16 quantization.
+
+Key concepts:
+- Due to bf16's limited precision (8 significand bits, 7 explicitly stored),
+  small updates may be "swallowed" by rounding
+- The bf16 ULP (Unit in Last Place) sets a minimum relative update size:
+  adjacent representable values are 2^{-8} to 2^{-7} (~0.39%-0.78%) apart
+- Under round-to-nearest, updates below roughly half a ULP (~0.2%-0.4% relative)
+  round back to the original value and are not reflected in bf16
+
+Reference:
+- RLVR paper Definition 2.1 & 2.2
+- bf16 ULP analysis showing relative update threshold of 2^{-8} to 2^{-7}
+"""
+
+import torch
+import numpy as np
+from typing import Dict, Any, Tuple, List, Optional
+import logging
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# bf16 Equality Check
+# ============================================================================
+
+def bf16_approximately_equal(
+ w: torch.Tensor,
+ w_hat: torch.Tensor,
+ eta: float = 1e-3
+) -> torch.Tensor:
+ """
+ Check if two tensors are approximately equal under bf16 precision.
+
+ From RLVR Definition 2.1:
+ Two values w and w_hat are considered bf16-equal if:
+ |w_hat - w| <= eta * max(|w|, |w_hat|)
+
+ When eta < 2^{-9}, this is equivalent to bit-wise bf16 equality.
+
+ Args:
+ w: Original weights tensor
+ w_hat: Updated weights tensor
+ eta: Relative tolerance (default 1e-3 as in RLVR)
+
+ Returns:
+ Boolean mask where True indicates bf16-equal
+ """
+ max_abs = torch.maximum(w.abs(), w_hat.abs())
+ diff = (w_hat - w).abs()
+
+ # Handle zero weights (avoid division by zero in relative comparison)
+ # For zeros, use absolute comparison
+ zero_mask = max_abs < 1e-10
+
+ # Relative comparison
+ relative_equal = diff <= eta * max_abs
+
+ # For zeros, check if both are effectively zero
+ both_zero = w.abs() < 1e-10
+ both_zero = both_zero & (w_hat.abs() < 1e-10)
+
+ # Combine: either relatively equal, or both effectively zero
+ equal_mask = relative_equal | (zero_mask & both_zero)
+
+ return equal_mask
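+
+# Illustrative check (not from the paper): with the default eta=1e-3, a 0.05%
+# relative change counts as bf16-equal while a 0.5% change does not.
+#
+#   >>> bf16_approximately_equal(torch.tensor([1.0, 1.0]),
+#   ...                          torch.tensor([1.0005, 1.005]))
+#   tensor([ True, False])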
+
+
+def bf16_bitwise_equal(
+ w: torch.Tensor,
+ w_hat: torch.Tensor
+) -> torch.Tensor:
+ """
+ Check if two tensors are bitwise equal in bf16 representation.
+
+ This is the strictest equality check - values must have identical
+ bf16 bit patterns.
+
+ Args:
+ w: Original weights tensor
+ w_hat: Updated weights tensor
+
+ Returns:
+ Boolean mask where True indicates bitwise bf16 equality
+ """
+ # Convert to bf16 and compare
+ w_bf16 = w.to(torch.bfloat16)
+ w_hat_bf16 = w_hat.to(torch.bfloat16)
+
+ # Bitwise comparison via view as int16
+ w_bits = w_bf16.view(torch.int16)
+ w_hat_bits = w_hat_bf16.view(torch.int16)
+
+ return w_bits == w_hat_bits
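+
+# Example (illustrative): 1.0005 rounds to the same bf16 bit pattern as 1.0
+# (the nearest representable values around 1.0 are 1.0 and 1.0078125), so the
+# strict check also reports equality here:
+#
+#   >>> bf16_bitwise_equal(torch.tensor([1.0]), torch.tensor([1.0005]))
+#   tensor([True])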
+
+
+# ============================================================================
+# Update Count and Sparsity
+# ============================================================================
+
+def compute_bf16_update_count(
+ w: torch.Tensor,
+ w_hat: torch.Tensor,
+ eta: float = 1e-3
+) -> Tuple[int, int, int]:
+ """
+ Compute bf16-aware update count.
+
+ From RLVR Definition 2.2:
+        |θ_1 - θ_0|_{0,bf16,η} = #{ i : w_hat_i ≉_{bf16,η} w_i }
+
+ Args:
+ w: Original weights tensor
+ w_hat: Updated weights tensor
+ eta: Relative tolerance
+
+ Returns:
+ Tuple of (num_changed, num_unchanged, total)
+ """
+ equal_mask = bf16_approximately_equal(w, w_hat, eta=eta)
+
+ total = int(equal_mask.numel())
+ num_unchanged = int(equal_mask.sum().item())
+ num_changed = total - num_unchanged
+
+ return num_changed, num_unchanged, total
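+
+# Worked example (illustrative): with eta=1e-3, a 0.05% change is invisible
+# while 1% and 100% changes count as visible updates.
+#
+#   >>> compute_bf16_update_count(torch.ones(4),
+#   ...                           torch.tensor([1.0, 1.0005, 1.01, 2.0]))
+#   (2, 2, 4)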
+
+
+def compute_bf16_sparsity(
+ base_model: torch.nn.Module,
+ finetuned_model: torch.nn.Module,
+ eta: float = 1e-3,
+ include_layer_stats: bool = False
+) -> Dict[str, Any]:
+ """
+ Compute bf16-aware update sparsity between two models.
+
+ Sparsity = 1 - |θ_1 - θ_0|_{0,bf16,η} / n
+
+ Values closer to 1 mean more sparse (fewer visible updates).
+ Values closer to 0 mean more dense (more visible updates).
+
+    RLVR Table 1 reports sparsity in the 36%-92% range for their experiments.
+
+ Args:
+ base_model: Original model (θ_0)
+ finetuned_model: Updated model (θ_1)
+ eta: Relative tolerance
+ include_layer_stats: If True, include per-layer statistics
+
+ Returns:
+ Dictionary with sparsity metrics
+ """
+ base_params = dict(base_model.named_parameters())
+ ft_params = dict(finetuned_model.named_parameters())
+
+ total_elements = 0
+ changed_elements = 0
+
+ layer_stats: Dict[str, Dict[str, Any]] = {}
+
+ for name, base_param in base_params.items():
+ if name not in ft_params:
+ logger.warning(f"Parameter {name} not found in finetuned model")
+ continue
+
+ ft_param = ft_params[name]
+
+ if base_param.shape != ft_param.shape:
+ logger.warning(
+ f"Shape mismatch for {name}: "
+ f"{base_param.shape} vs {ft_param.shape}"
+ )
+ continue
+
+        # Move to CPU and upcast to float32 for the comparison
+        w = base_param.detach().cpu().float().flatten()
+ w_hat = ft_param.detach().cpu().float().flatten()
+
+ # Compute update count
+ num_changed, num_unchanged, total = compute_bf16_update_count(
+ w, w_hat, eta=eta
+ )
+
+ total_elements += total
+ changed_elements += num_changed
+
+ if include_layer_stats:
+ layer_sparsity = 1.0 - num_changed / total if total > 0 else 1.0
+ layer_stats[name] = {
+ "num_changed": num_changed,
+ "num_unchanged": num_unchanged,
+ "total": total,
+ "sparsity": layer_sparsity,
+ "shape": list(base_param.shape)
+ }
+
+ # Compute overall sparsity
+ overall_sparsity = 1.0 - changed_elements / total_elements if total_elements > 0 else 1.0
+
+ result = {
+ "sparsity": overall_sparsity,
+ "sparsity_percent": overall_sparsity * 100,
+ "num_changed": changed_elements,
+ "num_unchanged": total_elements - changed_elements,
+ "total_parameters": total_elements,
+ "eta": eta,
+ "update_fraction": changed_elements / total_elements if total_elements > 0 else 0.0,
+ }
+
+ if include_layer_stats:
+ result["layer_stats"] = layer_stats
+
+ return result
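+
+# Usage sketch (model paths are placeholders, not part of this repo):
+#
+#   from transformers import AutoModelForCausalLM
+#   base = AutoModelForCausalLM.from_pretrained("<base-model>", torch_dtype=torch.float32)
+#   ft = AutoModelForCausalLM.from_pretrained("<rl-checkpoint>", torch_dtype=torch.float32)
+#   stats = compute_bf16_sparsity(base, ft, eta=1e-3)
+#   print(f"sparsity: {stats['sparsity_percent']:.1f}% "
+#         f"({stats['num_changed']:,}/{stats['total_parameters']:,} changed)")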
+
+
+# ============================================================================
+# Update Magnitude Analysis
+# ============================================================================
+
+def analyze_update_magnitudes(
+ base_model: torch.nn.Module,
+ finetuned_model: torch.nn.Module
+) -> Dict[str, Any]:
+ """
+ Analyze the distribution of update magnitudes.
+
+ This helps understand which updates are below the bf16 ULP threshold.
+
+ Returns statistics about:
+ - Absolute update magnitudes
+ - Relative update magnitudes
+ - Distribution relative to bf16 ULP
+ """
+ base_params = dict(base_model.named_parameters())
+ ft_params = dict(finetuned_model.named_parameters())
+
+ all_relative_updates: List[float] = []
+ all_absolute_updates: List[float] = []
+
+ for name, base_param in base_params.items():
+ if name not in ft_params:
+ continue
+
+ ft_param = ft_params[name]
+ if base_param.shape != ft_param.shape:
+ continue
+
+ w = base_param.detach().cpu().float().flatten()
+ w_hat = ft_param.detach().cpu().float().flatten()
+
+ # Absolute updates
+ abs_updates = (w_hat - w).abs()
+
+ # Relative updates (avoid division by zero)
+ max_abs = torch.maximum(w.abs(), w_hat.abs())
+ valid_mask = max_abs > 1e-10
+
+ rel_updates = torch.zeros_like(abs_updates)
+ rel_updates[valid_mask] = abs_updates[valid_mask] / max_abs[valid_mask]
+
+        # Sample per layer to bound memory; note the fixed per-layer sample
+        # size over-represents small layers in the pooled statistics
+        sample_size = min(10000, len(abs_updates))
+        indices = np.random.choice(len(abs_updates), sample_size, replace=False)
+
+ all_absolute_updates.extend(abs_updates[indices].tolist())
+ all_relative_updates.extend(rel_updates[indices].tolist())
+
+ abs_array = np.array(all_absolute_updates)
+ rel_array = np.array(all_relative_updates)
+
+    # bf16 relative ULP spacing lies between 2^{-8} and 2^{-7} (~0.39%-0.78%)
+    bf16_ulp_low = 2 ** -8   # ~0.39%
+    bf16_ulp_high = 2 ** -7  # ~0.78%
+
+ # Fraction of updates below ULP threshold
+ below_low = (rel_array < bf16_ulp_low).mean()
+ below_high = (rel_array < bf16_ulp_high).mean()
+
+ result = {
+ "absolute_updates": {
+ "mean": float(np.mean(abs_array)),
+ "std": float(np.std(abs_array)),
+ "median": float(np.median(abs_array)),
+ "min": float(np.min(abs_array)),
+ "max": float(np.max(abs_array)),
+ "percentiles": {
+ "p25": float(np.percentile(abs_array, 25)),
+ "p50": float(np.percentile(abs_array, 50)),
+ "p75": float(np.percentile(abs_array, 75)),
+ "p90": float(np.percentile(abs_array, 90)),
+ "p99": float(np.percentile(abs_array, 99)),
+ }
+ },
+ "relative_updates": {
+ "mean": float(np.mean(rel_array)),
+ "std": float(np.std(rel_array)),
+ "median": float(np.median(rel_array)),
+ "min": float(np.min(rel_array)),
+ "max": float(np.max(rel_array)),
+ "percentiles": {
+ "p25": float(np.percentile(rel_array, 25)),
+ "p50": float(np.percentile(rel_array, 50)),
+ "p75": float(np.percentile(rel_array, 75)),
+ "p90": float(np.percentile(rel_array, 90)),
+ "p99": float(np.percentile(rel_array, 99)),
+ }
+ },
+ "bf16_ulp_analysis": {
+ "ulp_low_threshold": bf16_ulp_low,
+ "ulp_high_threshold": bf16_ulp_high,
+ "fraction_below_low": float(below_low),
+ "fraction_below_high": float(below_high),
+ "estimated_swallowed_fraction": float(below_low),
+ }
+ }
+
+ return result
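+
+# The 2^{-8}..2^{-7} band above is the relative spacing for normalized bf16
+# values; the exact ULP of a specific weight can be checked directly by
+# incrementing its bf16 bit pattern (a minimal sketch, not used elsewhere):
+#
+#   def bf16_relative_ulp(x: torch.Tensor) -> torch.Tensor:
+#       b = x.to(torch.bfloat16).view(torch.int16)
+#       nxt = (b + 1).view(torch.bfloat16).float()  # next representable value
+#       return (nxt - x.to(torch.bfloat16).float()).abs() / x.abs()
+#
+#   >>> bf16_relative_ulp(torch.tensor([1.0]))   # ULP(1.0) = 2^{-7}
+#   tensor([0.0078])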
+
+
+# ============================================================================
+# Sparsity Trajectory
+# ============================================================================
+
+def compute_sparsity_trajectory(
+ base_model_path: str,
+ checkpoint_paths: List[str],
+ eta: float = 1e-3
+) -> List[Dict[str, Any]]:
+ """
+ Compute bf16 sparsity for a sequence of checkpoints.
+
+ Useful for understanding how sparsity evolves during training.
+
+ Args:
+ base_model_path: Path to base model
+ checkpoint_paths: List of checkpoint paths (in training order)
+ eta: Relative tolerance
+
+ Returns:
+ List of sparsity results for each checkpoint
+ """
+ from transformers import AutoModelForCausalLM
+
+ # Load base model
+ logger.info(f"Loading base model from {base_model_path}")
+ base_model = AutoModelForCausalLM.from_pretrained(
+ base_model_path,
+ torch_dtype=torch.float32,
+ device_map="cpu"
+ )
+
+ trajectory = []
+
+ for ckpt_path in tqdm(checkpoint_paths, desc="Computing sparsity"):
+ # Load checkpoint
+ ckpt_model = AutoModelForCausalLM.from_pretrained(
+ ckpt_path,
+ torch_dtype=torch.float32,
+ device_map="cpu"
+ )
+
+ # Compute sparsity
+ sparsity_result = compute_bf16_sparsity(
+ base_model=base_model,
+ finetuned_model=ckpt_model,
+ eta=eta,
+ include_layer_stats=False
+ )
+
+ trajectory.append({
+ "checkpoint": ckpt_path,
+ "sparsity": sparsity_result["sparsity"],
+ "sparsity_percent": sparsity_result["sparsity_percent"],
+ "num_changed": sparsity_result["num_changed"],
+ })
+
+ # Free memory
+ del ckpt_model
+
+ return trajectory
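+
+# Usage sketch (checkpoint paths are placeholders):
+#
+#   traj = compute_sparsity_trajectory(
+#       base_model_path="<base-model>",
+#       checkpoint_paths=["<ckpt-step100>", "<ckpt-step200>", "<ckpt-step300>"],
+#       eta=1e-3,
+#   )
+#   for point in traj:
+#       print(f"{point['checkpoint']}: {point['sparsity_percent']:.1f}%")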
+
+
+# ============================================================================
+# Layer-wise Sparsity Analysis
+# ============================================================================
+
+def analyze_layer_sparsity_patterns(
+ base_model: torch.nn.Module,
+ finetuned_model: torch.nn.Module,
+ eta: float = 1e-3
+) -> Dict[str, Any]:
+ """
+ Analyze sparsity patterns across different layer types.
+
+ Groups layers by type (attention, MLP, embeddings, etc.) and
+ reports aggregate sparsity statistics.
+ """
+ sparsity_result = compute_bf16_sparsity(
+ base_model=base_model,
+ finetuned_model=finetuned_model,
+ eta=eta,
+ include_layer_stats=True
+ )
+
+ layer_stats = sparsity_result.get("layer_stats", {})
+
+ # Group by layer type
+ groups: Dict[str, List[Dict[str, Any]]] = {
+ "attention": [],
+ "mlp": [],
+ "embedding": [],
+ "norm": [],
+ "other": []
+ }
+
+ for name, stats in layer_stats.items():
+ name_lower = name.lower()
+
+ if any(k in name_lower for k in ["attn", "attention", "self_attn"]):
+ groups["attention"].append(stats)
+ elif any(k in name_lower for k in ["mlp", "fc", "dense", "linear"]):
+ groups["mlp"].append(stats)
+ elif any(k in name_lower for k in ["embed", "wte", "wpe"]):
+ groups["embedding"].append(stats)
+ elif any(k in name_lower for k in ["norm", "ln", "layer_norm"]):
+ groups["norm"].append(stats)
+ else:
+ groups["other"].append(stats)
+
+ # Compute aggregate statistics per group
+ group_analysis = {}
+ for group_name, layer_list in groups.items():
+ if not layer_list:
+ continue
+
+ sparsities = [l["sparsity"] for l in layer_list]
+ total_params = sum(l["total"] for l in layer_list)
+ total_changed = sum(l["num_changed"] for l in layer_list)
+
+ group_analysis[group_name] = {
+ "num_layers": len(layer_list),
+ "total_params": total_params,
+ "mean_sparsity": float(np.mean(sparsities)),
+ "std_sparsity": float(np.std(sparsities)),
+ "min_sparsity": float(np.min(sparsities)),
+ "max_sparsity": float(np.max(sparsities)),
+ "aggregate_sparsity": 1.0 - total_changed / total_params if total_params > 0 else 1.0,
+ }
+
+ return {
+ "overall_sparsity": sparsity_result["sparsity"],
+ "group_analysis": group_analysis,
+ }
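+
+# Usage sketch (illustrative, reusing base/ft from the earlier sketch): the
+# per-group aggregates make it easy to see, e.g., whether norm/embedding
+# parameters receive denser updates than attention/MLP weights.
+#
+#   patterns = analyze_layer_sparsity_patterns(base, ft, eta=1e-3)
+#   for group, g in patterns["group_analysis"].items():
+#       print(f"{group:>10}: {g['aggregate_sparsity']:.2%} "
+#             f"over {g['num_layers']} tensors ({g['total_params']:,} params)")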
+