#!/usr/bin/env python3
"""
Script to estimate token lengths of data samples in a JSONL file.

Supports multiple tokenizers: tiktoken (GPT models), transformers (LLaMA, etc.)
and a dependency-free "simple" estimate (roughly 1 token per 4 characters).

Each sample is expected to be a JSON object with a ``messages`` list of
``{"role": ..., "content": ...}`` dicts; samples without messages are skipped.
"""

import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

# Length-histogram buckets as (label, inclusive lower bound, exclusive upper bound).
# Shared by analyze_token_lengths (counting) and print_statistics (display order).
DISTRIBUTION_BINS = [
    ('0-512', 0, 512),
    ('512-1K', 512, 1024),
    ('1K-2K', 1024, 2048),
    ('2K-4K', 2048, 4096),
    ('4K-8K', 4096, 8192),
    ('8K-16K', 8192, 16384),
    ('16K-32K', 16384, 32768),
    ('32K+', 32768, float('inf')),
]

# Percentiles reported for the per-sample total token count.
_PERCENTILES = (25, 50, 75, 90, 95, 99)


def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Load records from a JSONL file (one JSON value per line).

    A file whose (stripped) content starts with ``[`` is parsed as a single
    JSON array instead, so plain JSON exports are accepted as well. Blank
    lines are ignored.

    Args:
        file_path: Path to the ``.jsonl`` / ``.json`` file.

    Returns:
        The list of decoded records.

    Raises:
        json.JSONDecodeError: If a line (or the array) is not valid JSON.
    """
    # "utf-8-sig" decodes plain UTF-8 too, and strips a leading BOM that would
    # otherwise hide the '[' of a JSON array.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        content = f.read().strip()

    # Handle both JSONL and JSON array formats
    if content.startswith('['):
        return json.loads(content)

    # The whole file has already been read into `content`; iterating the file
    # handle again would yield nothing because it is positioned at EOF.
    return [json.loads(line) for line in content.splitlines() if line.strip()]


def get_tokenizer(tokenizer_type: str, model_name: Optional[str] = None):
    """Initialize and return the specified tokenizer.

    Args:
        tokenizer_type: One of "tiktoken", "transformers" or "simple".
        model_name: Model whose tokenizer to load. When omitted, tiktoken uses
            the "cl100k_base" encoding and transformers uses
            "meta-llama/Llama-2-7b-hf". Ignored for "simple".

    Returns:
        A ``(tokenizer, kind)`` pair. On success ``kind`` echoes the requested
        type. If the package is missing or the tokenizer cannot be loaded, a
        warning is printed and ``(None, None)`` is returned. "simple" always
        returns ``(None, "simple")``.

    Raises:
        ValueError: If ``tokenizer_type`` is not recognized.
    """
    if tokenizer_type == "tiktoken":
        try:
            import tiktoken
        except ImportError:
            print("Warning: tiktoken not installed. Install with: pip install tiktoken")
            return None, None
        try:
            if model_name:
                tokenizer = tiktoken.encoding_for_model(model_name)
            else:
                tokenizer = tiktoken.get_encoding("cl100k_base")  # GPT-4 default
        except Exception as e:  # e.g. unknown model name
            print(f"Warning: Could not load tiktoken encoding for {model_name}: {e}")
            return None, None
        return tokenizer, "tiktoken"

    if tokenizer_type == "transformers":
        try:
            from transformers import AutoTokenizer
        except ImportError:
            print("Warning: transformers not installed. Install with: pip install transformers")
            return None, None
        if not model_name:
            model_name = "meta-llama/Llama-2-7b-hf"  # Default to LLaMA
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
        except Exception as e:
            print(f"Warning: Could not load tokenizer {model_name}: {e}")
            return None, None
        return tokenizer, "transformers"

    if tokenizer_type == "simple":
        # Character-based estimation (rough approximation: 1 token ≈ 4 chars);
        # no tokenizer object is needed.
        return None, "simple"

    raise ValueError(f"Unknown tokenizer type: {tokenizer_type}")


def count_tokens_simple(text: str) -> int:
    """Estimate a token count as one token per four characters (floor)."""
    return len(text) // 4


def count_tokens(messages: List[Dict[str, str]], tokenizer, tokenizer_type: str) -> Dict[str, int]:
    """Count tokens for one conversation.

    All message contents are concatenated (without a separator) and encoded
    once. Each message additionally contributes the token count of its role
    string plus a fixed allowance of 3 tokens for role delimiters; this
    overhead is an approximation, not an exact chat-template count. A missing
    or ``None`` role/content is treated as an empty string.

    Args:
        messages: List of message dicts with optional "role" and "content".
        tokenizer: The object returned by ``get_tokenizer`` (unused for "simple").
        tokenizer_type: "tiktoken", "transformers" or "simple". Any other
            value yields all-zero counts.

    Returns:
        Dict with "content_tokens", "role_tokens" and "total_tokens".
    """
    if tokenizer_type == "tiktoken":
        def encode_len(text: str) -> int:
            return len(tokenizer.encode(text))
    elif tokenizer_type == "transformers":
        def encode_len(text: str) -> int:
            return len(tokenizer.encode(text, add_special_tokens=False))
    elif tokenizer_type == "simple":
        encode_len = count_tokens_simple
    else:
        encode_len = None

    contents = []
    role_tokens = 0
    for msg in messages:
        role = msg.get("role") or ""
        contents.append(msg.get("content") or "")
        if encode_len is not None:
            # Add overhead for role tokens (approximate): role + delimiters
            role_tokens += encode_len(role) + 3

    full_text = "".join(contents)
    content_tokens = encode_len(full_text) if encode_len is not None else 0

    return {
        "content_tokens": content_tokens,
        "role_tokens": role_tokens,
        "total_tokens": content_tokens + role_tokens,
    }


def _summarize_totals(counts: np.ndarray) -> Dict[str, Any]:
    """Return min/max/mean/median/std/percentiles/sum of counts (zeros if empty)."""
    if counts.size == 0:
        # np.min/np.max/np.percentile raise on empty arrays.
        return {
            "min": 0,
            "max": 0,
            "mean": 0.0,
            "median": 0.0,
            "std": 0.0,
            "percentiles": {f"{p}th": 0.0 for p in _PERCENTILES},
            "sum": 0,
        }
    return {
        "min": int(np.min(counts)),
        "max": int(np.max(counts)),
        "mean": float(np.mean(counts)),
        "median": float(np.median(counts)),
        "std": float(np.std(counts)),
        "percentiles": {f"{p}th": float(np.percentile(counts, p)) for p in _PERCENTILES},
        "sum": int(np.sum(counts)),
    }


def _mean_and_sum(counts: np.ndarray) -> Dict[str, Any]:
    """Return the mean and sum of counts (mean is 0.0 for an empty array)."""
    return {
        "mean": float(np.mean(counts)) if counts.size else 0.0,
        "sum": int(np.sum(counts)),
    }


def analyze_token_lengths(
    file_path: str,
    tokenizer_type: str = "simple",
    model_name: Optional[str] = None,
    max_samples: Optional[int] = None,
) -> Tuple[Dict[str, Any], np.ndarray]:
    """Analyze token lengths of all samples in the file.

    Args:
        file_path: JSONL (or JSON array) file to analyze.
        tokenizer_type: "simple", "tiktoken" or "transformers". If the
            requested tokenizer cannot be loaded, the simple estimate is used.
        model_name: Optional model name passed to ``get_tokenizer``.
        max_samples: If truthy, only the first ``max_samples`` records are used.

    Returns:
        ``(stats, token_counts)``. ``stats`` is a JSON-serializable summary
        (total/content/role token statistics plus a length histogram), and
        ``token_counts`` is the array of per-sample total token counts for
        samples that had a non-empty "messages" list.
    """
    print(f"Loading data from: {file_path}")
    data = load_jsonl(file_path)
    print(f"Loaded {len(data)} samples")

    if max_samples:
        data = data[:max_samples]
        print(f"Analyzing first {max_samples} samples")

    print(f"\nInitializing {tokenizer_type} tokenizer...")
    tokenizer, tokenizer_type = get_tokenizer(tokenizer_type, model_name)
    # get_tokenizer returns (None, None) when the requested backend is unavailable.
    if tokenizer is None and tokenizer_type != "simple":
        print("Falling back to simple estimation...")
        tokenizer_type = "simple"

    print(f"Counting tokens using {tokenizer_type} tokenizer...\n")

    token_counts = []
    content_token_counts = []
    role_token_counts = []
    skipped = 0

    for i, sample in enumerate(data):
        if (i + 1) % 100 == 0:
            print(f"Processed {i + 1}/{len(data)} samples...", end='\r')

        messages = sample.get("messages", []) if isinstance(sample, dict) else []
        if not messages:
            skipped += 1
            continue

        token_info = count_tokens(messages, tokenizer, tokenizer_type)
        token_counts.append(token_info["total_tokens"])
        content_token_counts.append(token_info["content_tokens"])
        role_token_counts.append(token_info["role_tokens"])

    # Trailing spaces overwrite the '\r' progress line.
    print(f"Processed {len(data)}/{len(data)} samples...          ")
    if skipped:
        print(f"Skipped {skipped} samples without a non-empty 'messages' list")

    # Calculate statistics (explicit dtype keeps empty arrays integer-typed)
    token_counts = np.array(token_counts, dtype=np.int64)
    content_token_counts = np.array(content_token_counts, dtype=np.int64)
    role_token_counts = np.array(role_token_counts, dtype=np.int64)

    stats = {
        "total_samples": len(token_counts),
        "skipped_samples": skipped,
        "tokenizer_type": tokenizer_type,
        "model_name": model_name,
        "total_tokens": _summarize_totals(token_counts),
        "content_tokens": _mean_and_sum(content_token_counts),
        "role_tokens": _mean_and_sum(role_token_counts),
    }

    # Token length distribution (only non-empty buckets are recorded)
    distribution = defaultdict(int)
    for count in token_counts:
        for label, low, high in DISTRIBUTION_BINS:
            if low <= count < high:
                distribution[label] += 1
                break

    stats["distribution"] = dict(distribution)

    return stats, token_counts


def print_statistics(stats: Dict[str, Any]):
    """Print a human-readable report of the dict built by analyze_token_lengths."""
    print("\n" + "="*70)
    print("TOKEN LENGTH STATISTICS")
    print("="*70)

    print(f"\nTokenizer: {stats['tokenizer_type']}")
    if stats['model_name']:
        print(f"Model: {stats['model_name']}")
    print(f"Total Samples: {stats['total_samples']:,}")
    if stats.get('skipped_samples'):
        print(f"Skipped Samples (no messages): {stats['skipped_samples']:,}")

    print("\n" + "-"*70)
    print("TOTAL TOKENS (including role tokens and formatting)")
    print("-"*70)
    total = stats['total_tokens']
    print(f"  Sum:        {total['sum']:>15,} tokens")
    print(f"  Mean:       {total['mean']:>15,.2f} tokens per sample")
    print(f"  Median:     {total['median']:>15,.2f} tokens per sample")
    print(f"  Std Dev:    {total['std']:>15,.2f}")
    print(f"  Min:        {total['min']:>15,} tokens")
    print(f"  Max:        {total['max']:>15,} tokens")

    print("\nPercentiles:")
    for pct, value in total['percentiles'].items():
        print(f"  {pct:>5}: {value:>15,.2f} tokens")

    print("\n" + "-"*70)
    print("CONTENT TOKENS (message content only)")
    print("-"*70)
    content = stats['content_tokens']
    print(f"  Sum:        {content['sum']:>15,} tokens")
    print(f"  Mean:       {content['mean']:>15,.2f} tokens per sample")

    print("\n" + "-"*70)
    print("ROLE TOKENS (role markers and formatting)")
    print("-"*70)
    role = stats['role_tokens']
    print(f"  Sum:        {role['sum']:>15,} tokens")
    print(f"  Mean:       {role['mean']:>15,.2f} tokens per sample")

    print("\n" + "-"*70)
    print("DISTRIBUTION")
    print("-"*70)
    distribution = stats['distribution']
    total_samples = stats['total_samples']
    for label, _low, _high in DISTRIBUTION_BINS:
        count = distribution.get(label, 0)
        percentage = (count / total_samples) * 100 if total_samples > 0 else 0
        bar = '█' * int(percentage / 2)
        print(f"  {label:>10}: {count:>6} ({percentage:>5.1f}%) {bar}")

    print("\n" + "="*70)


def main() -> int:
    """Parse CLI arguments, run the analysis, and return a process exit code."""
    parser = argparse.ArgumentParser(
        description="Estimate token lengths of data samples in a JSONL file",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Simple estimation (no external dependencies)
  python estimate_token_lengths.py data.jsonl

  # Using tiktoken for GPT models
  python estimate_token_lengths.py data.jsonl --tokenizer tiktoken --model gpt-4

  # Using transformers for LLaMA
  python estimate_token_lengths.py data.jsonl --tokenizer transformers --model meta-llama/Llama-2-7b-hf

  # Analyze only first 1000 samples
  python estimate_token_lengths.py data.jsonl --max-samples 1000

  # Save results to JSON
  python estimate_token_lengths.py data.jsonl --output stats.json
"""
    )

    parser.add_argument(
        "file_path",
        type=str,
        help="Path to JSONL file containing the data"
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        choices=["simple", "tiktoken", "transformers"],
        default="simple",
        help="Tokenizer to use (default: simple)"
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="Model name for the tokenizer (e.g., 'gpt-4', 'meta-llama/Llama-2-7b-hf')"
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Maximum number of samples to analyze (default: all)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output JSON file to save statistics (optional)"
    )

    args = parser.parse_args()

    # Check that the input is an existing regular file (not a directory)
    if not Path(args.file_path).is_file():
        print(f"Error: File not found: {args.file_path}", file=sys.stderr)
        return 1

    # Analyze token lengths
    stats, _token_counts = analyze_token_lengths(
        args.file_path,
        tokenizer_type=args.tokenizer,
        model_name=args.model,
        max_samples=args.max_samples
    )

    # Print statistics
    print_statistics(stats)

    # Save to JSON if requested
    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2)
        print(f"\nStatistics saved to: {args.output}")

    return 0


if __name__ == "__main__":
    sys.exit(main())

# Example invocation:
# python3 estimate_token_lengths.py /shared/storage-01/users/mehri2/mem/collaborativeagents/training/sft/training_data/session_level_reflection_sft_data.jsonl --tokenizer simple