# utils_bf16_sparsity.py
"""
bf16-Aware Update Sparsity Utilities.

This module implements the bf16-aware update sparsity metric from the RLVR
paper, which measures how many parameter updates are "visible" after bf16
quantization.

Key concepts:
- Due to bf16's limited precision (7 explicit mantissa bits), small updates
  may be "swallowed"
- The bf16 ULP (Unit in Last Place) creates a minimum relative update threshold
- Updates smaller than roughly half a ULP (~0.2-0.4% of a weight's magnitude)
  may not be reflected in the bf16 representation

Reference:
- RLVR paper Definition 2.1 & 2.2
- bf16 ULP analysis showing a relative update threshold of 2^{-8} to 2^{-7}
"""

import torch
import numpy as np
from typing import Dict, Any, Tuple, List, Optional
import logging
from tqdm import tqdm

logger = logging.getLogger(__name__)


# ============================================================================
# bf16 Equality Check
# ============================================================================

def bf16_approximately_equal(
    w: torch.Tensor,
    w_hat: torch.Tensor,
    eta: float = 1e-3
) -> torch.Tensor:
    """
    Check if two tensors are approximately equal under bf16 precision.

    From RLVR Definition 2.1, two values w and w_hat are considered bf16-equal
    if:

        |w_hat - w| <= eta * max(|w|, |w_hat|)

    When eta < 2^{-9}, this is equivalent to bit-wise bf16 equality.

    Args:
        w: Original weights tensor
        w_hat: Updated weights tensor
        eta: Relative tolerance (default 1e-3 as in RLVR)

    Returns:
        Boolean mask where True indicates bf16-equal
    """
    max_abs = torch.maximum(w.abs(), w_hat.abs())
    diff = (w_hat - w).abs()

    # Handle zero weights (avoid division by zero in relative comparison)
    # For zeros, use absolute comparison
    zero_mask = max_abs < 1e-10

    # Relative comparison
    relative_equal = diff <= eta * max_abs

    # For zeros, check if both are effectively zero
    both_zero = w.abs() < 1e-10
    both_zero = both_zero & (w_hat.abs() < 1e-10)

    # Combine: either relatively equal, or both effectively zero
    equal_mask = relative_equal | (zero_mask & both_zero)

    return equal_mask


def bf16_bitwise_equal(
    w: torch.Tensor,
    w_hat: torch.Tensor
) -> torch.Tensor:
    """
    Check if two tensors are bitwise equal in bf16 representation.

    This is the strictest equality check - values must have identical bf16
    bit patterns.

    Args:
        w: Original weights tensor
        w_hat: Updated weights tensor

    Returns:
        Boolean mask where True indicates bitwise bf16 equality
    """
    # Convert to bf16 and compare
    w_bf16 = w.to(torch.bfloat16)
    w_hat_bf16 = w_hat.to(torch.bfloat16)

    # Bitwise comparison via view as int16
    w_bits = w_bf16.view(torch.int16)
    w_hat_bits = w_hat_bf16.view(torch.int16)

    return w_bits == w_hat_bits
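

# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the RLVR metric itself): a relative update
# well below the bf16 half-ULP is swallowed by the cast, so both equality
# checks above report "equal". The magnitudes are arbitrary examples chosen to
# sit below / above the ~0.2-0.4% threshold discussed in the module docstring.
# ----------------------------------------------------------------------------
def _demo_bf16_swallowing() -> None:
    """Small self-check showing sub-ULP updates disappearing under bf16."""
    w = torch.tensor([1.0])
    tiny = w * (1 + 1e-4)   # 0.01% relative update: below the bf16 half-ULP
    large = w * (1 + 1e-2)  # 1% relative update: well above the bf16 ULP

    assert bool(bf16_approximately_equal(w, tiny).all())  # equal at default eta
    assert bool(bf16_bitwise_equal(w, tiny).all())        # identical bf16 bits
    assert not bool(bf16_bitwise_equal(w, large).all())   # visible after the cast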


# ============================================================================
# Update Count and Sparsity
# ============================================================================

def compute_bf16_update_count(
    w: torch.Tensor,
    w_hat: torch.Tensor,
    eta: float = 1e-3
) -> Tuple[int, int, int]:
    """
    Compute the bf16-aware update count.

    From RLVR Definition 2.2:

        |θ_1 - θ_0|_{0,bf16,η} = #{ i : w_hat_i ≉_{bf16,η} w_i }

    Args:
        w: Original weights tensor
        w_hat: Updated weights tensor
        eta: Relative tolerance

    Returns:
        Tuple of (num_changed, num_unchanged, total)
    """
    equal_mask = bf16_approximately_equal(w, w_hat, eta=eta)

    total = int(equal_mask.numel())
    num_unchanged = int(equal_mask.sum().item())
    num_changed = total - num_unchanged

    return num_changed, num_unchanged, total


def compute_bf16_sparsity(
    base_model: torch.nn.Module,
    finetuned_model: torch.nn.Module,
    eta: float = 1e-3,
    include_layer_stats: bool = False
) -> Dict[str, Any]:
    """
    Compute bf16-aware update sparsity between two models.

        sparsity = 1 - |θ_1 - θ_0|_{0,bf16,η} / n

    Values closer to 1 mean more sparse (fewer visible updates).
    Values closer to 0 mean more dense (more visible updates).

    RLVR Table 1 reports sparsity in the range 36%-92% for their experiments.

    Args:
        base_model: Original model (θ_0)
        finetuned_model: Updated model (θ_1)
        eta: Relative tolerance
        include_layer_stats: If True, include per-layer statistics

    Returns:
        Dictionary with sparsity metrics
    """
    base_params = dict(base_model.named_parameters())
    ft_params = dict(finetuned_model.named_parameters())

    total_elements = 0
    changed_elements = 0
    layer_stats: Dict[str, Dict[str, Any]] = {}

    for name, base_param in base_params.items():
        if name not in ft_params:
            logger.warning(f"Parameter {name} not found in finetuned model")
            continue

        ft_param = ft_params[name]
        if base_param.shape != ft_param.shape:
            logger.warning(
                f"Shape mismatch for {name}: "
                f"{base_param.shape} vs {ft_param.shape}"
            )
            continue

        # Move to CPU for computation
        w = base_param.detach().cpu().float().flatten()
        w_hat = ft_param.detach().cpu().float().flatten()

        # Compute update count
        num_changed, num_unchanged, total = compute_bf16_update_count(
            w, w_hat, eta=eta
        )

        total_elements += total
        changed_elements += num_changed

        if include_layer_stats:
            layer_sparsity = 1.0 - num_changed / total if total > 0 else 1.0
            layer_stats[name] = {
                "num_changed": num_changed,
                "num_unchanged": num_unchanged,
                "total": total,
                "sparsity": layer_sparsity,
                "shape": list(base_param.shape)
            }

    # Compute overall sparsity
    overall_sparsity = (
        1.0 - changed_elements / total_elements if total_elements > 0 else 1.0
    )

    result = {
        "sparsity": overall_sparsity,
        "sparsity_percent": overall_sparsity * 100,
        "num_changed": changed_elements,
        "num_unchanged": total_elements - changed_elements,
        "total_parameters": total_elements,
        "eta": eta,
        "update_fraction": (
            changed_elements / total_elements if total_elements > 0 else 0.0
        ),
    }

    if include_layer_stats:
        result["layer_stats"] = layer_stats

    return result
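

# ----------------------------------------------------------------------------
# Usage sketch for compute_bf16_sparsity. The model paths and the use of
# AutoModelForCausalLM are assumptions for illustration; any pair of
# like-shaped torch.nn.Module instances works. Loading in float32 keeps the
# comparison itself at full precision, mirroring compute_sparsity_trajectory
# further below.
# ----------------------------------------------------------------------------
def _example_model_sparsity(
    base_path: str = "path/to/base-model",          # hypothetical path
    finetuned_path: str = "path/to/rl-checkpoint",  # hypothetical path
    eta: float = 1e-3,
) -> Dict[str, Any]:
    """Illustrative helper: load two checkpoints and report their sparsity."""
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained(
        base_path, torch_dtype=torch.float32, device_map="cpu"
    )
    finetuned = AutoModelForCausalLM.from_pretrained(
        finetuned_path, torch_dtype=torch.float32, device_map="cpu"
    )

    result = compute_bf16_sparsity(base, finetuned, eta=eta)
    logger.info("bf16-aware sparsity: %.2f%%", result["sparsity_percent"])
    return result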


# ============================================================================
# Update Magnitude Analysis
# ============================================================================

def analyze_update_magnitudes(
    base_model: torch.nn.Module,
    finetuned_model: torch.nn.Module
) -> Dict[str, Any]:
    """
    Analyze the distribution of update magnitudes.

    This helps understand which updates are below the bf16 ULP threshold.

    Returns statistics about:
    - Absolute update magnitudes
    - Relative update magnitudes
    - Distribution relative to the bf16 ULP
    """
    base_params = dict(base_model.named_parameters())
    ft_params = dict(finetuned_model.named_parameters())

    all_relative_updates: List[float] = []
    all_absolute_updates: List[float] = []

    for name, base_param in base_params.items():
        if name not in ft_params:
            continue

        ft_param = ft_params[name]
        if base_param.shape != ft_param.shape:
            continue

        w = base_param.detach().cpu().float().flatten()
        w_hat = ft_param.detach().cpu().float().flatten()

        # Absolute updates
        abs_updates = (w_hat - w).abs()

        # Relative updates (avoid division by zero)
        max_abs = torch.maximum(w.abs(), w_hat.abs())
        valid_mask = max_abs > 1e-10
        rel_updates = torch.zeros_like(abs_updates)
        rel_updates[valid_mask] = abs_updates[valid_mask] / max_abs[valid_mask]

        # Sample for statistics (avoid memory issues)
        sample_size = min(10000, len(abs_updates))
        indices = np.random.choice(len(abs_updates), sample_size, replace=False)

        all_absolute_updates.extend(abs_updates[indices].tolist())
        all_relative_updates.extend(rel_updates[indices].tolist())

    abs_array = np.array(all_absolute_updates)
    rel_array = np.array(all_relative_updates)

    # bf16 ULP thresholds (approximately 2^{-8} to 2^{-7}, i.e. 0.39% to 0.78%)
    bf16_ulp_low = 2 ** -8   # ~0.39%
    bf16_ulp_high = 2 ** -7  # ~0.78%

    # Fraction of updates below the ULP thresholds
    below_low = (rel_array < bf16_ulp_low).mean()
    below_high = (rel_array < bf16_ulp_high).mean()

    result = {
        "absolute_updates": {
            "mean": float(np.mean(abs_array)),
            "std": float(np.std(abs_array)),
            "median": float(np.median(abs_array)),
            "min": float(np.min(abs_array)),
            "max": float(np.max(abs_array)),
            "percentiles": {
                "p25": float(np.percentile(abs_array, 25)),
                "p50": float(np.percentile(abs_array, 50)),
                "p75": float(np.percentile(abs_array, 75)),
                "p90": float(np.percentile(abs_array, 90)),
                "p99": float(np.percentile(abs_array, 99)),
            }
        },
        "relative_updates": {
            "mean": float(np.mean(rel_array)),
            "std": float(np.std(rel_array)),
            "median": float(np.median(rel_array)),
            "min": float(np.min(rel_array)),
            "max": float(np.max(rel_array)),
            "percentiles": {
                "p25": float(np.percentile(rel_array, 25)),
                "p50": float(np.percentile(rel_array, 50)),
                "p75": float(np.percentile(rel_array, 75)),
                "p90": float(np.percentile(rel_array, 90)),
                "p99": float(np.percentile(rel_array, 99)),
            }
        },
        "bf16_ulp_analysis": {
            "ulp_low_threshold": bf16_ulp_low,
            "ulp_high_threshold": bf16_ulp_high,
            "fraction_below_low": float(below_low),
            "fraction_below_high": float(below_high),
            "estimated_swallowed_fraction": float(below_low),
        }
    }

    return result
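

# ----------------------------------------------------------------------------
# Illustrative helper (not used by analyze_update_magnitudes): the per-element
# bf16 ULP behind the thresholds above. For a normal value w with
# 2^e <= |w| < 2^(e+1), ULP_bf16(w) = 2^(e-7), so the *relative* ULP falls in
# (2^-8, 2^-7], which is exactly the [ulp_low, ulp_high] band used in the
# analysis.
# ----------------------------------------------------------------------------
def _bf16_relative_ulp(w: torch.Tensor) -> torch.Tensor:
    """Return ULP_bf16(w) / |w| for nonzero, normal-range w (sketch only)."""
    abs_w = w.abs().float()
    exponent = torch.floor(torch.log2(abs_w))  # binade index e of each value
    ulp = torch.exp2(exponent - 7)             # bf16 keeps 7 explicit mantissa bits
    return ulp / abs_w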


# ============================================================================
# Sparsity Trajectory
# ============================================================================

def compute_sparsity_trajectory(
    base_model_path: str,
    checkpoint_paths: List[str],
    eta: float = 1e-3
) -> List[Dict[str, Any]]:
    """
    Compute bf16 sparsity for a sequence of checkpoints.

    Useful for understanding how sparsity evolves during training.

    Args:
        base_model_path: Path to base model
        checkpoint_paths: List of checkpoint paths (in training order)
        eta: Relative tolerance

    Returns:
        List of sparsity results for each checkpoint
    """
    from transformers import AutoModelForCausalLM

    # Load base model
    logger.info(f"Loading base model from {base_model_path}")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch.float32,
        device_map="cpu"
    )

    trajectory = []

    for ckpt_path in tqdm(checkpoint_paths, desc="Computing sparsity"):
        # Load checkpoint
        ckpt_model = AutoModelForCausalLM.from_pretrained(
            ckpt_path,
            torch_dtype=torch.float32,
            device_map="cpu"
        )

        # Compute sparsity
        sparsity_result = compute_bf16_sparsity(
            base_model=base_model,
            finetuned_model=ckpt_model,
            eta=eta,
            include_layer_stats=False
        )

        trajectory.append({
            "checkpoint": ckpt_path,
            "sparsity": sparsity_result["sparsity"],
            "sparsity_percent": sparsity_result["sparsity_percent"],
            "num_changed": sparsity_result["num_changed"],
        })

        # Free memory
        del ckpt_model

    return trajectory


# ============================================================================
# Layer-wise Sparsity Analysis
# ============================================================================

def analyze_layer_sparsity_patterns(
    base_model: torch.nn.Module,
    finetuned_model: torch.nn.Module,
    eta: float = 1e-3
) -> Dict[str, Any]:
    """
    Analyze sparsity patterns across different layer types.

    Groups layers by type (attention, MLP, embeddings, etc.) and reports
    aggregate sparsity statistics.
    """
    sparsity_result = compute_bf16_sparsity(
        base_model=base_model,
        finetuned_model=finetuned_model,
        eta=eta,
        include_layer_stats=True
    )

    layer_stats = sparsity_result.get("layer_stats", {})

    # Group by layer type
    groups: Dict[str, List[Dict[str, Any]]] = {
        "attention": [],
        "mlp": [],
        "embedding": [],
        "norm": [],
        "other": []
    }

    for name, stats in layer_stats.items():
        name_lower = name.lower()
        if any(k in name_lower for k in ["attn", "attention", "self_attn"]):
            groups["attention"].append(stats)
        elif any(k in name_lower for k in ["mlp", "fc", "dense", "linear"]):
            groups["mlp"].append(stats)
        elif any(k in name_lower for k in ["embed", "wte", "wpe"]):
            groups["embedding"].append(stats)
        elif any(k in name_lower for k in ["norm", "ln", "layer_norm"]):
            groups["norm"].append(stats)
        else:
            groups["other"].append(stats)

    # Compute aggregate statistics per group
    group_analysis = {}
    for group_name, layer_list in groups.items():
        if not layer_list:
            continue

        sparsities = [l["sparsity"] for l in layer_list]
        total_params = sum(l["total"] for l in layer_list)
        total_changed = sum(l["num_changed"] for l in layer_list)

        group_analysis[group_name] = {
            "num_layers": len(layer_list),
            "total_params": total_params,
            "mean_sparsity": float(np.mean(sparsities)),
            "std_sparsity": float(np.std(sparsities)),
            "min_sparsity": float(np.min(sparsities)),
            "max_sparsity": float(np.max(sparsities)),
            "aggregate_sparsity": (
                1.0 - total_changed / total_params if total_params > 0 else 1.0
            ),
        }

    return {
        "overall_sparsity": sparsity_result["sparsity"],
        "group_analysis": group_analysis,
    }
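

# ----------------------------------------------------------------------------
# End-to-end usage sketch. The checkpoint paths below are placeholders
# (assumptions for illustration only); running this loads each checkpoint in
# float32 on the CPU, so it needs enough host memory for the base model plus
# one checkpoint at a time.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Hypothetical checkpoint layout; replace with real paths before running.
    trajectory = compute_sparsity_trajectory(
        base_model_path="path/to/base-model",
        checkpoint_paths=[
            "path/to/checkpoints/step-100",
            "path/to/checkpoints/step-200",
        ],
        eta=1e-3,
    )
    for point in trajectory:
        print(f"{point['checkpoint']}: {point['sparsity_percent']:.2f}% sparse")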