#!/usr/bin/env python
"""
Pre-flight tests before running full experiments.

Tests:
1. Timeout handling (infinite timeout)
2. Large batch stress test (batch=100)
3. Context length handling (auto-reduce max_tokens)
4. Error recovery (partial failures)
5. Sequential profile processing (for RAG/reflection methods)
6. Memory usage estimation
"""
import sys
import os
import time
import json
import asyncio

# Make the project package importable when this script is run directly.
sys.path.insert(0, '/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents')

from agents.batch_vllm_agent import BatchVLLMClient, BatchConversationGenerator


def test_1_timeout_handling(user_url: str) -> bool:
    """Test 1: Infinite timeout configuration.

    Creates a client with ``timeout=None`` and verifies a single trivial
    request completes. Returns True on success.
    """
    print("\n" + "="*60)
    print("TEST 1: Timeout Handling (Infinite Timeout)")
    print("="*60)

    # Create client with infinite timeout
    client = BatchVLLMClient(
        vllm_url=user_url,
        max_tokens=256,
        temperature=0.7,
        timeout=None,  # Infinite timeout
        max_concurrent=50
    )
    print("✓ Client created with timeout=None (infinite)")
    print(f" Model: {client.model_name}")
    print(f" Max concurrent: {client.max_concurrent}")

    # Test with a simple request
    messages = [[{"role": "user", "content": "Say 'hello' and nothing else."}]]
    start = time.time()
    results = client.batch_completion(messages)
    elapsed = time.time() - start

    if results[0]:
        print(f"✓ Single request succeeded in {elapsed:.1f}s")
        print(f" Response: {results[0][:50]}...")
        return True
    else:
        print("✗ Single request failed")
        return False


def test_2_large_batch(user_url: str, batch_size: int = 100) -> bool:
    """Test 2: Large batch stress test.

    Fires ``batch_size`` concurrent requests and passes when at least 90%
    of them return a non-None result.
    """
    print("\n" + "="*60)
    print(f"TEST 2: Large Batch Stress Test (batch={batch_size})")
    print("="*60)

    client = BatchVLLMClient(
        vllm_url=user_url,
        max_tokens=128,  # Small to speed up test
        temperature=0.7,
        timeout=None,
        max_concurrent=100
    )

    # Create batch of simple requests
    messages_list = [
        [{"role": "user", "content": f"Count from 1 to 5. Request #{i+1}"}]
        for i in range(batch_size)
    ]

    print(f"Sending {batch_size} concurrent requests...")
    start = time.time()
    results = client.batch_completion(messages_list)
    elapsed = time.time() - start

    successes = sum(1 for r in results if r is not None)
    print("\nResults:")
    print(f" Successes: {successes}/{batch_size}")
    print(f" Time: {elapsed:.1f}s")
    # Guard against a zero-duration run (e.g. an instant/mocked backend),
    # which would otherwise raise ZeroDivisionError.
    if elapsed > 0:
        print(f" Throughput: {successes * 3600 / elapsed:.0f} requests/hr")

    if successes >= batch_size * 0.9:
        print("✓ Batch test PASSED (>90% success)")
        return True
    else:
        print("✗ Batch test FAILED (<90% success)")
        return False


def test_3_context_length_handling(user_url: str) -> bool:
    """Test 3: Context length error handling.

    Sends one near-context-limit request alongside one short request; the
    short request must succeed regardless of how the long one is handled.
    """
    print("\n" + "="*60)
    print("TEST 3: Context Length Handling")
    print("="*60)

    client = BatchVLLMClient(
        vllm_url=user_url,
        max_tokens=512,  # Request large output
        temperature=0.7,
        timeout=None,
        max_concurrent=10
    )

    # Create request with very long input (near 4096 token limit)
    long_text = "This is a test. " * 500  # ~2000 tokens
    messages_list = [
        [{"role": "user", "content": f"Summarize: {long_text}"}],  # Will hit limit
        [{"role": "user", "content": "Say hello."}],  # Should succeed
    ]

    print("Testing with 1 long + 1 short request...")
    results = client.batch_completion(messages_list)

    # The long one might fail or get reduced max_tokens
    # The short one should succeed
    short_success = results[1] is not None

    if short_success:
        print("✓ Short request succeeded despite long request")
        print(f" Long request result: {'OK' if results[0] else 'Handled gracefully'}")
        return True
    else:
        print("✗ Short request should not have failed")
        return False


def test_4_error_recovery(user_url: str, agent_url: str) -> bool:
    """Test 4: Error recovery in batch processing.

    Runs a small batch of conversations and passes when at least 2 of the
    3 samples complete, confirming one failure does not sink the batch.
    """
    print("\n" + "="*60)
    print("TEST 4: Error Recovery (Partial Failures)")
    print("="*60)

    generator = BatchConversationGenerator(
        user_vllm_url=user_url,
        agent_vllm_url=agent_url,
        max_turns=3,
        user_max_tokens=256,
        agent_max_tokens=256,
    )

    # Three simple arithmetic samples; the pass threshold below tolerates
    # one of them failing.
    samples = [
        {"problem": "What is 2+2?", "solution": "4"},
        {"problem": "What is 3+3?", "solution": "6"},
        {"problem": "What is 4+4?", "solution": "8"},
    ]

    print("Testing batch generation with 3 samples, 3 turns...")
    start = time.time()
    results = generator.generate_batch(
        samples=samples,
        user_persona="A student.",
        user_preferences=None,
    )
    elapsed = time.time() - start

    successes = sum(1 for r in results if r is not None)
    print("\nResults:")
    print(f" Successes: {successes}/{len(samples)}")
    print(f" Time: {elapsed:.1f}s")

    if successes >= 2:
        print("✓ Error recovery PASSED")
        return True
    else:
        print("✗ Error recovery FAILED")
        return False


def test_5_sequential_profile(user_url: str, agent_url: str) -> bool:
    """Test 5: Sequential profile processing (simulating RAG/reflection).

    RAG/reflection methods process sessions sequentially within a profile;
    this runs 3 profiles x 2 sessions that way and requires >=80% success.
    """
    print("\n" + "="*60)
    print("TEST 5: Sequential Profile Processing (RAG/Reflection Simulation)")
    print("="*60)

    # Simulate 3 profiles, each with 2 sequential sessions
    # This is how RAG/reflection methods work - sequential within profile
    generator = BatchConversationGenerator(
        user_vllm_url=user_url,
        agent_vllm_url=agent_url,
        max_turns=2,
        user_max_tokens=256,
        agent_max_tokens=256,
    )

    n_profiles = 3
    sessions_per_profile = 2
    total_time = 0
    total_sessions = 0

    for profile_idx in range(n_profiles):
        profile_start = time.time()

        # Sequential sessions for this profile
        for session_idx in range(sessions_per_profile):
            samples = [
                {"problem": f"Profile {profile_idx+1}, Session {session_idx+1}: What is {profile_idx+session_idx}+1?",
                 "solution": str(profile_idx + session_idx + 1)}
            ]
            results = generator.generate_batch(
                samples=samples,
                user_persona=f"User profile {profile_idx+1}",
                user_preferences="Be concise.",
            )
            if results[0]:
                total_sessions += 1

        profile_elapsed = time.time() - profile_start
        total_time += profile_elapsed
        print(f" Profile {profile_idx+1}: {profile_elapsed:.1f}s for {sessions_per_profile} sessions")

    print("\nResults:")
    print(f" Total sessions: {total_sessions}/{n_profiles * sessions_per_profile}")
    print(f" Total time: {total_time:.1f}s")
    # Guard against a zero-duration run to avoid ZeroDivisionError.
    if total_time > 0:
        print(f" Throughput: {total_sessions * 3600 / total_time:.0f} sessions/hr")

    if total_sessions >= n_profiles * sessions_per_profile * 0.8:
        print("✓ Sequential profile test PASSED")
        return True
    else:
        print("✗ Sequential profile test FAILED")
        return False


def test_6_memory_estimation() -> bool:
    """Test 6: Memory usage estimation.

    Best-effort query of GPU memory via ``nvidia-smi``; returns False when
    the tool is missing or fails rather than raising.
    """
    print("\n" + "="*60)
    print("TEST 6: Memory Usage Estimation")
    print("="*60)

    try:
        import subprocess
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=index,memory.used,memory.total',
             '--format=csv,noheader,nounits'],
            capture_output=True, text=True
        )
        # A non-zero exit previously fell through and printed success over
        # empty output; treat it as a failure instead.
        if result.returncode != 0:
            print(f"✗ Could not get memory info: nvidia-smi exited with {result.returncode}")
            return False

        print("GPU Memory Usage:")
        for line in result.stdout.strip().split('\n'):
            parts = line.split(', ')
            if len(parts) == 3:
                gpu_idx, used, total = parts
                used_pct = float(used) / float(total) * 100
                print(f" GPU {gpu_idx}: {used}/{total} MiB ({used_pct:.1f}%)")

        print("✓ Memory estimation completed")
        return True
    except Exception as e:
        # Deliberate best-effort boundary: a missing nvidia-smi or parse
        # error must not abort the pre-flight run.
        print(f"✗ Could not get memory info: {e}")
        return False


def run_all_tests(user_url: str, agent_url: str) -> bool:
    """Run all pre-flight tests and print a pass/fail summary.

    Returns True only when every test passed.
    """
    print("\n" + "="*60)
    print("PRE-FLIGHT TESTS FOR FULL EXPERIMENTS")
    print("="*60)
    print(f"User URL: {user_url}")
    print(f"Agent URL: {agent_url}")
    print(f"Time: {time.strftime('%Y-%m-%d %H:%M:%S')}")

    results = {}

    # Run each test
    results['timeout'] = test_1_timeout_handling(user_url)
    results['large_batch'] = test_2_large_batch(user_url, batch_size=50)
    results['context_length'] = test_3_context_length_handling(user_url)
    results['error_recovery'] = test_4_error_recovery(user_url, agent_url)
    results['sequential_profile'] = test_5_sequential_profile(user_url, agent_url)
    results['memory'] = test_6_memory_estimation()

    # Summary
    print("\n" + "="*60)
    print("PRE-FLIGHT TEST SUMMARY")
    print("="*60)

    all_passed = True
    for test_name, passed in results.items():
        status = "✓ PASSED" if passed else "✗ FAILED"
        print(f" {test_name}: {status}")
        if not passed:
            all_passed = False

    print()
    if all_passed:
        print("✓ ALL TESTS PASSED - Ready for full experiments!")
    else:
        print("✗ SOME TESTS FAILED - Review before proceeding")

    return all_passed


if __name__ == "__main__":
    user_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8004/v1"
    agent_url = sys.argv[2] if len(sys.argv) > 2 else "http://localhost:8003/v1"

    success = run_all_tests(user_url, agent_url)
    sys.exit(0 if success else 1)