summaryrefslogtreecommitdiff
path: root/collaborativeagents/training/estimate_token_lengths.py
diff options
context:
space:
mode:
Diffstat (limited to 'collaborativeagents/training/estimate_token_lengths.py')
-rw-r--r--collaborativeagents/training/estimate_token_lengths.py340
1 files changed, 340 insertions, 0 deletions
diff --git a/collaborativeagents/training/estimate_token_lengths.py b/collaborativeagents/training/estimate_token_lengths.py
new file mode 100644
index 0000000..d00ca8f
--- /dev/null
+++ b/collaborativeagents/training/estimate_token_lengths.py
@@ -0,0 +1,340 @@
+#!/usr/bin/env python3
+"""
+Script to estimate token lengths of data samples in a JSONL file.
+Supports multiple tokenizers: tiktoken (GPT models) and transformers (LLaMA, etc.)
+"""
+
+import json
+import argparse
+from pathlib import Path
+from typing import List, Dict, Any
+import numpy as np
+from collections import defaultdict
+
def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Load records from a JSONL file (one JSON object per line).

    Also accepts a file containing a single JSON array, in which case the
    array's elements are returned directly.

    Args:
        file_path: Path to the .jsonl (or JSON-array) file.

    Returns:
        List of decoded records (empty list for an empty file).

    Raises:
        json.JSONDecodeError: If any line (or the array) is not valid JSON.
        OSError: If the file cannot be opened.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read().strip()
    # Handle both JSONL and JSON array formats.
    if content.startswith('['):
        return json.loads(content)
    # BUG FIX: the original read the whole file and then iterated the
    # already-exhausted handle, so JSONL input always yielded zero records.
    # Parse the in-memory text line by line instead; blank lines are skipped.
    return [json.loads(line) for line in content.splitlines() if line.strip()]
+
def get_tokenizer(tokenizer_type: str, model_name: str = None):
    """Initialize and return the specified tokenizer.

    Args:
        tokenizer_type: One of "simple", "tiktoken", or "transformers".
        model_name: Optional model identifier forwarded to the backend.

    Returns:
        A (tokenizer, backend_name) pair. (None, "simple") for the
        dependency-free estimator; (None, None) when the requested backend
        could not be initialized.

    Raises:
        ValueError: If tokenizer_type is not a recognized backend.
    """
    if tokenizer_type == "simple":
        # Character-count heuristic (~1 token per 4 chars) needs no object.
        return None, "simple"

    if tokenizer_type == "tiktoken":
        try:
            import tiktoken
        except ImportError:
            print("Warning: tiktoken not installed. Install with: pip install tiktoken")
            return None, None
        if model_name:
            return tiktoken.encoding_for_model(model_name), "tiktoken"
        return tiktoken.get_encoding("cl100k_base"), "tiktoken"  # GPT-4 default

    if tokenizer_type == "transformers":
        if not model_name:
            model_name = "meta-llama/Llama-2-7b-hf"  # Default to LLaMA
        try:
            from transformers import AutoTokenizer
            return AutoTokenizer.from_pretrained(model_name), "transformers"
        except ImportError:
            print("Warning: transformers not installed. Install with: pip install transformers")
            return None, None
        except Exception as e:
            print(f"Warning: Could not load tokenizer {model_name}: {e}")
            return None, None

    raise ValueError(f"Unknown tokenizer type: {tokenizer_type}")
+
def count_tokens_simple(text: str) -> int:
    """Simple token estimation based on character count."""
    # Rough heuristic: roughly one token per four characters.
    return len(text) // 4

def count_tokens(messages: List[Dict[str, str]], tokenizer, tokenizer_type: str) -> Dict[str, int]:
    """Count tokens for a conversation with messages.

    Args:
        messages: Chat messages, each a dict with "role" and "content" keys.
        tokenizer: Backend tokenizer object (unused for "simple").
        tokenizer_type: "tiktoken", "transformers", or "simple"; any other
            value yields all-zero counts.

    Returns:
        Dict with "content_tokens", "role_tokens", and "total_tokens".
    """
    def _encode_len(text: str) -> int:
        # Token count for one string under the active backend.
        if tokenizer_type == "tiktoken":
            return len(tokenizer.encode(text))
        if tokenizer_type == "transformers":
            return len(tokenizer.encode(text, add_special_tokens=False))
        return count_tokens_simple(text)  # "simple"

    recognized = tokenizer_type in ("tiktoken", "transformers", "simple")
    full_text = "".join(msg.get("content", "") for msg in messages)

    # Role overhead: each message adds its role's tokens plus ~3 delimiter
    # tokens (approximation of chat-format framing).
    role_tokens = 0
    content_tokens = 0
    if recognized:
        role_tokens = sum(_encode_len(msg.get("role", "")) + 3 for msg in messages)
        content_tokens = _encode_len(full_text)

    return {
        "content_tokens": content_tokens,
        "role_tokens": role_tokens,
        "total_tokens": content_tokens + role_tokens,
    }
+
def analyze_token_lengths(
    file_path: str,
    tokenizer_type: str = "simple",
    model_name: str = None,
    max_samples: int = None
) -> tuple:
    """Analyze token lengths of all samples in the file.

    Args:
        file_path: Path to the JSONL / JSON-array data file.
        tokenizer_type: "simple", "tiktoken", or "transformers".
        model_name: Optional model identifier for the chosen tokenizer.
        max_samples: If given, only the first max_samples records are analyzed.

    Returns:
        Tuple of (stats dict, numpy array of per-sample total token counts).
        (BUG FIX: the original annotation claimed Dict[str, Any] although the
        function has always returned a 2-tuple.)

    Raises:
        ValueError: If no sample contains a non-empty "messages" list.
    """
    print(f"Loading data from: {file_path}")
    data = load_jsonl(file_path)
    print(f"Loaded {len(data)} samples")

    if max_samples:
        data = data[:max_samples]
        print(f"Analyzing first {max_samples} samples")

    print(f"\nInitializing {tokenizer_type} tokenizer...")
    tokenizer, tokenizer_type = get_tokenizer(tokenizer_type, model_name)

    # get_tokenizer returns (None, None) when the requested backend is
    # unavailable; degrade gracefully to the dependency-free estimator.
    if tokenizer is None and tokenizer_type != "simple":
        print("Falling back to simple estimation...")
        tokenizer_type = "simple"

    print(f"Counting tokens using {tokenizer_type} tokenizer...\n")

    token_counts = []
    content_token_counts = []
    role_token_counts = []

    for i, sample in enumerate(data):
        if (i + 1) % 100 == 0:
            print(f"Processed {i + 1}/{len(data)} samples...", end='\r')

        messages = sample.get("messages", [])
        if not messages:
            # Skip records without a conversation payload.
            continue

        token_info = count_tokens(messages, tokenizer, tokenizer_type)
        token_counts.append(token_info["total_tokens"])
        content_token_counts.append(token_info["content_tokens"])
        role_token_counts.append(token_info["role_tokens"])

    print(f"Processed {len(data)}/{len(data)} samples...   ")

    # BUG FIX: np.min/np.max raise an opaque "zero-size array" error when no
    # sample had a "messages" list; fail with a clear message instead.
    if not token_counts:
        raise ValueError("No samples with a non-empty 'messages' field were found.")

    # Calculate statistics
    token_counts = np.array(token_counts)
    content_token_counts = np.array(content_token_counts)
    role_token_counts = np.array(role_token_counts)

    stats = {
        "total_samples": len(token_counts),
        "tokenizer_type": tokenizer_type,
        "model_name": model_name,
        "total_tokens": {
            "min": int(np.min(token_counts)),
            "max": int(np.max(token_counts)),
            "mean": float(np.mean(token_counts)),
            "median": float(np.median(token_counts)),
            "std": float(np.std(token_counts)),
            "percentiles": {
                "25th": float(np.percentile(token_counts, 25)),
                "50th": float(np.percentile(token_counts, 50)),
                "75th": float(np.percentile(token_counts, 75)),
                "90th": float(np.percentile(token_counts, 90)),
                "95th": float(np.percentile(token_counts, 95)),
                "99th": float(np.percentile(token_counts, 99)),
            },
            "sum": int(np.sum(token_counts))
        },
        "content_tokens": {
            "mean": float(np.mean(content_token_counts)),
            "sum": int(np.sum(content_token_counts))
        },
        "role_tokens": {
            "mean": float(np.mean(role_token_counts)),
            "sum": int(np.sum(role_token_counts))
        }
    }

    # Token length distribution (histogram over context-length buckets).
    bins = [0, 512, 1024, 2048, 4096, 8192, 16384, 32768, float('inf')]
    bin_labels = ['0-512', '512-1K', '1K-2K', '2K-4K', '4K-8K', '8K-16K', '16K-32K', '32K+']
    distribution = defaultdict(int)

    for count in token_counts:
        for i, (low, high) in enumerate(zip(bins[:-1], bins[1:])):
            if low <= count < high:
                distribution[bin_labels[i]] += 1
                break

    stats["distribution"] = dict(distribution)

    return stats, token_counts
+
def print_statistics(stats: Dict[str, Any]):
    """Print formatted statistics.

    Expects the dict produced by analyze_token_lengths (keys:
    tokenizer_type, model_name, total_samples, total_tokens,
    content_tokens, role_tokens, distribution). Writes a human-readable
    report to stdout; returns None.
    """
    width = 70

    def _section(title: str) -> None:
        # Divider printed before each report section.
        print("\n" + "-" * width)
        print(title)
        print("-" * width)

    print("\n" + "=" * width)
    print("TOKEN LENGTH STATISTICS")
    print("=" * width)
    print(f"\nTokenizer: {stats['tokenizer_type']}")
    if stats['model_name']:
        print(f"Model: {stats['model_name']}")
    print(f"Total Samples: {stats['total_samples']:,}")

    _section("TOTAL TOKENS (including role tokens and formatting)")
    total = stats['total_tokens']
    print(f"  Sum:        {total['sum']:>15,} tokens")
    print(f"  Mean:       {total['mean']:>15,.2f} tokens per sample")
    print(f"  Median:     {total['median']:>15,.2f} tokens per sample")
    print(f"  Std Dev:    {total['std']:>15,.2f}")
    print(f"  Min:        {total['min']:>15,} tokens")
    print(f"  Max:        {total['max']:>15,} tokens")

    print("\nPercentiles:")
    for pct, value in total['percentiles'].items():
        print(f"  {pct:>5}: {value:>15,.2f} tokens")

    # Content-only and role-only summaries share one layout.
    for title, key in (
        ("CONTENT TOKENS (message content only)", "content_tokens"),
        ("ROLE TOKENS (role markers and formatting)", "role_tokens"),
    ):
        _section(title)
        section = stats[key]
        print(f"  Sum:  {section['sum']:>15,} tokens")
        print(f"  Mean: {section['mean']:>15,.2f} tokens per sample")

    _section("DISTRIBUTION")
    distribution = stats['distribution']
    sample_total = stats['total_samples']
    for label in ['0-512', '512-1K', '1K-2K', '2K-4K', '4K-8K', '8K-16K', '16K-32K', '32K+']:
        count = distribution.get(label, 0)
        percentage = (count / sample_total) * 100 if sample_total > 0 else 0
        bar = '█' * int(percentage / 2)
        print(f"  {label:>10}: {count:>6} ({percentage:>5.1f}%) {bar}")

    print("\n" + "=" * width)
+
def main():
    """CLI entry point: parse arguments, run the analysis, print the report.

    Returns:
        Process exit code: 0 on success, 1 if the input file does not exist.
        argparse itself exits with code 2 on invalid arguments.
    """
    parser = argparse.ArgumentParser(
        description="Estimate token lengths of data samples in a JSONL file",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Simple estimation (no external dependencies)
  python estimate_token_lengths.py data.jsonl

  # Using tiktoken for GPT models
  python estimate_token_lengths.py data.jsonl --tokenizer tiktoken --model gpt-4

  # Using transformers for LLaMA
  python estimate_token_lengths.py data.jsonl --tokenizer transformers --model meta-llama/Llama-2-7b-hf

  # Analyze only first 1000 samples
  python estimate_token_lengths.py data.jsonl --max-samples 1000

  # Save results to JSON
  python estimate_token_lengths.py data.jsonl --output stats.json
  """
    )

    parser.add_argument(
        "file_path",
        type=str,
        help="Path to JSONL file containing the data"
    )

    parser.add_argument(
        "--tokenizer",
        type=str,
        choices=["simple", "tiktoken", "transformers"],
        default="simple",
        help="Tokenizer to use (default: simple)"
    )

    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="Model name for the tokenizer (e.g., 'gpt-4', 'meta-llama/Llama-2-7b-hf')"
    )

    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Maximum number of samples to analyze (default: all)"
    )

    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output JSON file to save statistics (optional)"
    )

    args = parser.parse_args()

    # Fail fast with a readable message instead of an open() traceback.
    if not Path(args.file_path).exists():
        print(f"Error: File not found: {args.file_path}")
        return 1

    # Analyze token lengths
    stats, token_counts = analyze_token_lengths(
        args.file_path,
        tokenizer_type=args.tokenizer,
        model_name=args.model,
        max_samples=args.max_samples
    )

    # Print statistics
    print_statistics(stats)

    # Save to JSON if requested
    if args.output:
        with open(args.output, 'w') as f:
            json.dump(stats, f, indent=2)
        print(f"\nStatistics saved to: {args.output}")

    return 0

if __name__ == "__main__":
    # BUG FIX: use sys.exit() rather than the site-injected exit() builtin,
    # which is intended for the interactive prompt and is not guaranteed to
    # exist (e.g. under embedded interpreters or python -S).
    import sys
    sys.exit(main())
+
+
+
# Example: python3 estimate_token_lengths.py /shared/storage-01/users/mehri2/mem/collaborativeagents/training/sft/training_data/session_level_reflection_sft_data.jsonl --tokenizer simple