From f1c2cc22d46a6976df3555391e667c7e61592fad Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Wed, 4 Feb 2026 18:59:35 -0600
Subject: Initial commit: RL floating-point noise project

---
 utils_bf16_sparsity.py | 459 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 459 insertions(+)
 create mode 100644 utils_bf16_sparsity.py

diff --git a/utils_bf16_sparsity.py b/utils_bf16_sparsity.py
new file mode 100644
index 0000000..2e0729d
--- /dev/null
+++ b/utils_bf16_sparsity.py
@@ -0,0 +1,459 @@
+# utils_bf16_sparsity.py
+"""
+bf16-Aware Update Sparsity Utilities.
+
+This module implements the bf16-aware update sparsity metric from the RLVR paper,
+which measures how many parameter updates remain "visible" after bf16 quantization.
+
+Key concepts:
+- Due to bf16's limited precision (7 explicit mantissa bits), small updates may
+  be "swallowed" by rounding
+- The bf16 ULP (Unit in the Last Place) creates a minimum relative update threshold
+- Updates smaller than ~0.2-0.4% of a weight's magnitude (half a ULP) may not be
+  reflected in the bf16 representation
+
+Reference:
+- RLVR paper, Definitions 2.1 & 2.2
+- bf16 ULP analysis showing a relative update threshold of 2^{-8} to 2^{-7}
+"""
+
+import torch
+import numpy as np
+from typing import Dict, Any, Tuple, List, Optional
+import logging
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# bf16 Equality Check
+# ============================================================================
+
+def bf16_approximately_equal(
+    w: torch.Tensor,
+    w_hat: torch.Tensor,
+    eta: float = 1e-3
+) -> torch.Tensor:
+    """
+    Check whether two tensors are approximately equal under bf16 precision.
+
+    From RLVR Definition 2.1, two values w and w_hat are bf16-equal if:
+
+        |w_hat - w| <= eta * max(|w|, |w_hat|)
+
+    When eta < 2^{-9}, this is equivalent to bit-wise bf16 equality
+    (for weights that are themselves bf16-representable).
+
+    Args:
+        w: Original weights tensor
+        w_hat: Updated weights tensor
+        eta: Relative tolerance (default 1e-3, as in RLVR)
+
+    Returns:
+        Boolean mask where True indicates bf16-equal
+    """
+    max_abs = torch.maximum(w.abs(), w_hat.abs())
+    diff = (w_hat - w).abs()
+
+    # Relative comparison
+    relative_equal = diff <= eta * max_abs
+
+    # Near-zero weights need explicit handling: the relative test degenerates
+    # when both sides are ~0, so treat "both effectively zero" as equal
+    zero_mask = max_abs < 1e-10
+    both_zero = w.abs() < 1e-10
+    both_zero = both_zero & (w_hat.abs() < 1e-10)
+
+    # Combine: either relatively equal, or both effectively zero
+    equal_mask = relative_equal | (zero_mask & both_zero)
+
+    return equal_mask
+
+
+def bf16_bitwise_equal(
+    w: torch.Tensor,
+    w_hat: torch.Tensor
+) -> torch.Tensor:
+    """
+    Check whether two tensors are bitwise equal in bf16 representation.
+
+    This is the strictest equality check: values must have identical
+    bf16 bit patterns.
+
+    Args:
+        w: Original weights tensor
+        w_hat: Updated weights tensor
+
+    Returns:
+        Boolean mask where True indicates bitwise bf16 equality
+    """
+    # Convert to bf16 and compare
+    w_bf16 = w.to(torch.bfloat16)
+    w_hat_bf16 = w_hat.to(torch.bfloat16)
+
+    # Bitwise comparison via a view as int16
+    w_bits = w_bf16.view(torch.int16)
+    w_hat_bits = w_hat_bf16.view(torch.int16)
+
+    return w_bits == w_hat_bits
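+
+
+# A minimal sketch (illustrative values only) contrasting the two checks: the
+# first update is a full bf16 ULP at 1.0 (2^{-7}, ~0.78% relative), which both
+# checks flag as changed; the second (0.1%) is swallowed by bf16 rounding, so
+# both checks treat it as unchanged.
+def _example_equality_checks() -> None:
+    w = torch.tensor([1.0, 1.0], dtype=torch.bfloat16)
+    w_hat = torch.tensor([1.0078125, 1.001], dtype=torch.bfloat16)  # 1.001 rounds back to 1.0
+    print(bf16_approximately_equal(w.float(), w_hat.float()))  # tensor([False,  True])
+    print(bf16_bitwise_equal(w, w_hat))                        # tensor([False,  True])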
+
+
+# ============================================================================
+# Update Count and Sparsity
+# ============================================================================
+
+def compute_bf16_update_count(
+    w: torch.Tensor,
+    w_hat: torch.Tensor,
+    eta: float = 1e-3
+) -> Tuple[int, int, int]:
+    """
+    Compute the bf16-aware update count.
+
+    From RLVR Definition 2.2:
+
+        ‖θ_1 - θ_0‖_{0,bf16,η} = #{ i : w_hat_i ≉_{bf16,η} w_i }
+
+    Args:
+        w: Original weights tensor
+        w_hat: Updated weights tensor
+        eta: Relative tolerance
+
+    Returns:
+        Tuple of (num_changed, num_unchanged, total)
+    """
+    equal_mask = bf16_approximately_equal(w, w_hat, eta=eta)
+
+    total = int(equal_mask.numel())
+    num_unchanged = int(equal_mask.sum().item())
+    num_changed = total - num_unchanged
+
+    return num_changed, num_unchanged, total
+
+
+def compute_bf16_sparsity(
+    base_model: torch.nn.Module,
+    finetuned_model: torch.nn.Module,
+    eta: float = 1e-3,
+    include_layer_stats: bool = False
+) -> Dict[str, Any]:
+    """
+    Compute the bf16-aware update sparsity between two models.
+
+        sparsity = 1 - ‖θ_1 - θ_0‖_{0,bf16,η} / n
+
+    Values closer to 1 mean sparser (fewer visible updates); values closer to 0
+    mean denser (more visible updates). RLVR Table 1 reports sparsity in the
+    range 36%-92% for their experiments.
+
+    Args:
+        base_model: Original model (θ_0)
+        finetuned_model: Updated model (θ_1)
+        eta: Relative tolerance
+        include_layer_stats: If True, include per-layer statistics
+
+    Returns:
+        Dictionary with sparsity metrics
+    """
+    base_params = dict(base_model.named_parameters())
+    ft_params = dict(finetuned_model.named_parameters())
+
+    total_elements = 0
+    changed_elements = 0
+
+    layer_stats: Dict[str, Dict[str, Any]] = {}
+
+    for name, base_param in base_params.items():
+        if name not in ft_params:
+            logger.warning(f"Parameter {name} not found in finetuned model")
+            continue
+
+        ft_param = ft_params[name]
+
+        if base_param.shape != ft_param.shape:
+            logger.warning(
+                f"Shape mismatch for {name}: "
+                f"{base_param.shape} vs {ft_param.shape}"
+            )
+            continue
+
+        # Move to CPU and flatten for the element-wise comparison
+        w = base_param.detach().cpu().float().flatten()
+        w_hat = ft_param.detach().cpu().float().flatten()
+
+        # Compute update count
+        num_changed, num_unchanged, total = compute_bf16_update_count(
+            w, w_hat, eta=eta
+        )
+
+        total_elements += total
+        changed_elements += num_changed
+
+        if include_layer_stats:
+            layer_sparsity = 1.0 - num_changed / total if total > 0 else 1.0
+            layer_stats[name] = {
+                "num_changed": num_changed,
+                "num_unchanged": num_unchanged,
+                "total": total,
+                "sparsity": layer_sparsity,
+                "shape": list(base_param.shape),
+            }
+
+    # Overall sparsity across all compared parameters
+    overall_sparsity = 1.0 - changed_elements / total_elements if total_elements > 0 else 1.0
+
+    result = {
+        "sparsity": overall_sparsity,
+        "sparsity_percent": overall_sparsity * 100,
+        "num_changed": changed_elements,
+        "num_unchanged": total_elements - changed_elements,
+        "total_parameters": total_elements,
+        "eta": eta,
+        "update_fraction": changed_elements / total_elements if total_elements > 0 else 0.0,
+    }
+
+    if include_layer_stats:
+        result["layer_stats"] = layer_stats
+
+    return result
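+
+
+# A small end-to-end sketch on a made-up module: clone a Linear layer, perturb
+# half of its weights by ~1% (well above eta = 1e-3), and the reported sparsity
+# comes out at ~50%. The layer size and perturbation are arbitrary choices.
+def _example_sparsity_on_toy_model() -> None:
+    torch.manual_seed(0)
+    base = torch.nn.Linear(16, 16, bias=False)
+    tuned = torch.nn.Linear(16, 16, bias=False)
+    tuned.load_state_dict(base.state_dict())
+    with torch.no_grad():
+        flat = tuned.weight.view(-1)
+        flat[: flat.numel() // 2] *= 1.01  # ~1% relative updates -> visible
+    stats = compute_bf16_sparsity(base, tuned, eta=1e-3)
+    print(f"sparsity = {stats['sparsity_percent']:.1f}%")  # -> sparsity = 50.0%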
result["layer_stats"] = layer_stats + + return result + + +# ============================================================================ +# Update Magnitude Analysis +# ============================================================================ + +def analyze_update_magnitudes( + base_model: torch.nn.Module, + finetuned_model: torch.nn.Module +) -> Dict[str, Any]: + """ + Analyze the distribution of update magnitudes. + + This helps understand which updates are below the bf16 ULP threshold. + + Returns statistics about: + - Absolute update magnitudes + - Relative update magnitudes + - Distribution relative to bf16 ULP + """ + base_params = dict(base_model.named_parameters()) + ft_params = dict(finetuned_model.named_parameters()) + + all_relative_updates: List[float] = [] + all_absolute_updates: List[float] = [] + + for name, base_param in base_params.items(): + if name not in ft_params: + continue + + ft_param = ft_params[name] + if base_param.shape != ft_param.shape: + continue + + w = base_param.detach().cpu().float().flatten() + w_hat = ft_param.detach().cpu().float().flatten() + + # Absolute updates + abs_updates = (w_hat - w).abs() + + # Relative updates (avoid division by zero) + max_abs = torch.maximum(w.abs(), w_hat.abs()) + valid_mask = max_abs > 1e-10 + + rel_updates = torch.zeros_like(abs_updates) + rel_updates[valid_mask] = abs_updates[valid_mask] / max_abs[valid_mask] + + # Sample for statistics (avoid memory issues) + sample_size = min(10000, len(abs_updates)) + indices = np.random.choice(len(abs_updates), sample_size, replace=False) + + all_absolute_updates.extend(abs_updates[indices].tolist()) + all_relative_updates.extend(rel_updates[indices].tolist()) + + abs_array = np.array(all_absolute_updates) + rel_array = np.array(all_relative_updates) + + # bf16 ULP threshold (approximately 2^{-8} to 2^{-7}, or 0.2% to 0.4%) + bf16_ulp_low = 2 ** -8 # ~0.39% + bf16_ulp_high = 2 ** -7 # ~0.78% + + # Fraction of updates below ULP threshold + below_low = (rel_array < bf16_ulp_low).mean() + below_high = (rel_array < bf16_ulp_high).mean() + + result = { + "absolute_updates": { + "mean": float(np.mean(abs_array)), + "std": float(np.std(abs_array)), + "median": float(np.median(abs_array)), + "min": float(np.min(abs_array)), + "max": float(np.max(abs_array)), + "percentiles": { + "p25": float(np.percentile(abs_array, 25)), + "p50": float(np.percentile(abs_array, 50)), + "p75": float(np.percentile(abs_array, 75)), + "p90": float(np.percentile(abs_array, 90)), + "p99": float(np.percentile(abs_array, 99)), + } + }, + "relative_updates": { + "mean": float(np.mean(rel_array)), + "std": float(np.std(rel_array)), + "median": float(np.median(rel_array)), + "min": float(np.min(rel_array)), + "max": float(np.max(rel_array)), + "percentiles": { + "p25": float(np.percentile(rel_array, 25)), + "p50": float(np.percentile(rel_array, 50)), + "p75": float(np.percentile(rel_array, 75)), + "p90": float(np.percentile(rel_array, 90)), + "p99": float(np.percentile(rel_array, 99)), + } + }, + "bf16_ulp_analysis": { + "ulp_low_threshold": bf16_ulp_low, + "ulp_high_threshold": bf16_ulp_high, + "fraction_below_low": float(below_low), + "fraction_below_high": float(below_high), + "estimated_swallowed_fraction": float(below_low), + } + } + + return result + + +# ============================================================================ +# Sparsity Trajectory +# ============================================================================ + +def compute_sparsity_trajectory( + base_model_path: str, + 
+
+
+# ============================================================================
+# Sparsity Trajectory
+# ============================================================================
+
+def compute_sparsity_trajectory(
+    base_model_path: str,
+    checkpoint_paths: List[str],
+    eta: float = 1e-3
+) -> List[Dict[str, Any]]:
+    """
+    Compute bf16 sparsity for a sequence of checkpoints.
+
+    Useful for understanding how sparsity evolves during training.
+
+    Args:
+        base_model_path: Path to the base model
+        checkpoint_paths: List of checkpoint paths (in training order)
+        eta: Relative tolerance
+
+    Returns:
+        List of sparsity results, one per checkpoint
+    """
+    from transformers import AutoModelForCausalLM
+
+    # Load base model
+    logger.info(f"Loading base model from {base_model_path}")
+    base_model = AutoModelForCausalLM.from_pretrained(
+        base_model_path,
+        torch_dtype=torch.float32,
+        device_map="cpu"
+    )
+
+    trajectory = []
+
+    for ckpt_path in tqdm(checkpoint_paths, desc="Computing sparsity"):
+        # Load checkpoint
+        ckpt_model = AutoModelForCausalLM.from_pretrained(
+            ckpt_path,
+            torch_dtype=torch.float32,
+            device_map="cpu"
+        )
+
+        # Compute sparsity
+        sparsity_result = compute_bf16_sparsity(
+            base_model=base_model,
+            finetuned_model=ckpt_model,
+            eta=eta,
+            include_layer_stats=False
+        )
+
+        trajectory.append({
+            "checkpoint": ckpt_path,
+            "sparsity": sparsity_result["sparsity"],
+            "sparsity_percent": sparsity_result["sparsity_percent"],
+            "num_changed": sparsity_result["num_changed"],
+        })
+
+        # Free memory before loading the next checkpoint
+        del ckpt_model
+
+    return trajectory
+
+
+# ============================================================================
+# Layer-wise Sparsity Analysis
+# ============================================================================
+
+def analyze_layer_sparsity_patterns(
+    base_model: torch.nn.Module,
+    finetuned_model: torch.nn.Module,
+    eta: float = 1e-3
+) -> Dict[str, Any]:
+    """
+    Analyze sparsity patterns across different layer types.
+
+    Groups layers by type (attention, MLP, embeddings, etc.) and
+    reports aggregate sparsity statistics per group.
+    """
+    sparsity_result = compute_bf16_sparsity(
+        base_model=base_model,
+        finetuned_model=finetuned_model,
+        eta=eta,
+        include_layer_stats=True
+    )
+
+    layer_stats = sparsity_result.get("layer_stats", {})
+
+    # Group by layer type, matching common substrings in parameter names
+    groups: Dict[str, List[Dict[str, Any]]] = {
+        "attention": [],
+        "mlp": [],
+        "embedding": [],
+        "norm": [],
+        "other": [],
+    }
+
+    for name, stats in layer_stats.items():
+        name_lower = name.lower()
+
+        if any(k in name_lower for k in ["attn", "attention", "self_attn"]):
+            groups["attention"].append(stats)
+        elif any(k in name_lower for k in ["mlp", "fc", "dense", "linear"]):
+            groups["mlp"].append(stats)
+        elif any(k in name_lower for k in ["embed", "wte", "wpe"]):
+            groups["embedding"].append(stats)
+        elif any(k in name_lower for k in ["norm", "ln", "layer_norm"]):
+            groups["norm"].append(stats)
+        else:
+            groups["other"].append(stats)
+
+    # Compute aggregate statistics per group
+    group_analysis = {}
+    for group_name, layer_list in groups.items():
+        if not layer_list:
+            continue
+
+        sparsities = [layer["sparsity"] for layer in layer_list]
+        total_params = sum(layer["total"] for layer in layer_list)
+        total_changed = sum(layer["num_changed"] for layer in layer_list)
+
+        group_analysis[group_name] = {
+            "num_layers": len(layer_list),
+            "total_params": total_params,
+            "mean_sparsity": float(np.mean(sparsities)),
+            "std_sparsity": float(np.std(sparsities)),
+            "min_sparsity": float(np.min(sparsities)),
+            "max_sparsity": float(np.max(sparsities)),
+            "aggregate_sparsity": 1.0 - total_changed / total_params if total_params > 0 else 1.0,
+        }
+
+    return {
+        "overall_sparsity": sparsity_result["sparsity"],
+        "group_analysis": group_analysis,
+    }
--
cgit v1.2.3