diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 09:57:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 09:57:37 -0600 |
| commit | dc801c07cf38b0c495686463e6ca6f871a64440e (patch) | |
| tree | 599f03114775921dbc472403c701f4a3a8ea188a /collaborativeagents/scripts/preflight_test.py | |
| parent | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff) | |
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules
- Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/scripts/preflight_test.py')
| -rw-r--r-- | collaborativeagents/scripts/preflight_test.py | 311 |
1 file changed, 311 insertions, 0 deletions
#!/usr/bin/env python
"""
Pre-flight tests before running full experiments.

Tests:
1. Timeout handling (infinite timeout)
2. Large batch stress test (batch=100)
3. Context length handling (auto-reduce max_tokens)
4. Error recovery (partial failures)
5. Sequential profile processing (for RAG/reflection methods)
6. Memory usage estimation
"""

import subprocess
import sys
import time
from pathlib import Path

# Make the collaborativeagents package importable no matter where the script
# is launched from. (Previously this was a hard-coded absolute cluster path,
# which broke the script on any other machine or checkout location.)
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from agents.batch_vllm_agent import BatchVLLMClient, BatchConversationGenerator


def _banner(title: str) -> None:
    """Print the standard section banner used by every test."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


def _per_hour(count: int, seconds: float) -> float:
    """Convert a count over an elapsed duration to an hourly rate.

    Guards against ZeroDivisionError when the work finishes faster than
    the timer resolution (e.g. an instantly failing batch).
    """
    return count * 3600 / seconds if seconds > 0 else 0.0


def test_1_timeout_handling(user_url: str) -> bool:
    """Test 1: Infinite timeout configuration.

    Verifies a client built with ``timeout=None`` can complete a single
    trivial request. Returns True on success.
    """
    _banner("TEST 1: Timeout Handling (Infinite Timeout)")

    # Create client with infinite timeout
    client = BatchVLLMClient(
        vllm_url=user_url,
        max_tokens=256,
        temperature=0.7,
        timeout=None,  # Infinite timeout: never abort a slow request
        max_concurrent=50,
    )

    print("✓ Client created with timeout=None (infinite)")
    print(f"  Model: {client.model_name}")
    print(f"  Max concurrent: {client.max_concurrent}")

    # Test with a simple request
    messages = [[{"role": "user", "content": "Say 'hello' and nothing else."}]]

    start = time.time()
    results = client.batch_completion(messages)
    elapsed = time.time() - start

    if results[0]:
        print(f"✓ Single request succeeded in {elapsed:.1f}s")
        print(f"  Response: {results[0][:50]}...")
        return True
    print("✗ Single request failed")
    return False


def test_2_large_batch(user_url: str, batch_size: int = 100) -> bool:
    """Test 2: Large batch stress test.

    Fires ``batch_size`` concurrent trivial requests and passes when at
    least 90% of them return a non-None result.
    """
    _banner(f"TEST 2: Large Batch Stress Test (batch={batch_size})")

    client = BatchVLLMClient(
        vllm_url=user_url,
        max_tokens=128,  # Small to speed up test
        temperature=0.7,
        timeout=None,
        max_concurrent=100,
    )

    # Create batch of simple requests
    messages_list = [
        [{"role": "user", "content": f"Count from 1 to 5. Request #{i+1}"}]
        for i in range(batch_size)
    ]

    print(f"Sending {batch_size} concurrent requests...")
    start = time.time()
    results = client.batch_completion(messages_list)
    elapsed = time.time() - start

    successes = sum(1 for r in results if r is not None)

    print("\nResults:")
    print(f"  Successes: {successes}/{batch_size}")
    print(f"  Time: {elapsed:.1f}s")
    print(f"  Throughput: {_per_hour(successes, elapsed):.0f} requests/hr")

    if successes >= batch_size * 0.9:
        print("✓ Batch test PASSED (>90% success)")
        return True
    print("✗ Batch test FAILED (<90% success)")
    return False


def test_3_context_length_handling(user_url: str) -> bool:
    """Test 3: Context length error handling.

    Sends one oversized prompt alongside one trivial prompt; passes when
    the short request succeeds regardless of how the long one is handled.
    """
    _banner("TEST 3: Context Length Handling")

    client = BatchVLLMClient(
        vllm_url=user_url,
        max_tokens=512,  # Request large output
        temperature=0.7,
        timeout=None,
        max_concurrent=10,
    )

    # Create request with very long input (near 4096 token limit)
    long_text = "This is a test. " * 500  # ~2000 tokens
    messages_list = [
        [{"role": "user", "content": f"Summarize: {long_text}"}],  # Will hit limit
        [{"role": "user", "content": "Say hello."}],  # Should succeed
    ]

    print("Testing with 1 long + 1 short request...")
    results = client.batch_completion(messages_list)

    # The long one might fail or get reduced max_tokens.
    # The short one should succeed.
    short_success = results[1] is not None

    if short_success:
        print("✓ Short request succeeded despite long request")
        print(f"  Long request result: {'OK' if results[0] else 'Handled gracefully'}")
        return True
    print("✗ Short request should not have failed")
    return False


def test_4_error_recovery(user_url: str, agent_url: str) -> bool:
    """Test 4: Error recovery in batch processing.

    Runs a small multi-turn batch and passes when at least 2 of 3 samples
    complete (i.e. a partial failure does not poison the whole batch).
    """
    _banner("TEST 4: Error Recovery (Partial Failures)")

    generator = BatchConversationGenerator(
        user_vllm_url=user_url,
        agent_vllm_url=agent_url,
        max_turns=3,
        user_max_tokens=256,
        agent_max_tokens=256,
    )

    # Mix of valid and problematic samples
    samples = [
        {"problem": "What is 2+2?", "solution": "4"},
        {"problem": "What is 3+3?", "solution": "6"},
        {"problem": "What is 4+4?", "solution": "8"},
    ]

    print("Testing batch generation with 3 samples, 3 turns...")
    start = time.time()
    results = generator.generate_batch(
        samples=samples,
        user_persona="A student.",
        user_preferences=None,
    )
    elapsed = time.time() - start

    successes = sum(1 for r in results if r is not None)
    print("\nResults:")
    print(f"  Successes: {successes}/{len(samples)}")
    print(f"  Time: {elapsed:.1f}s")

    if successes >= 2:
        print("✓ Error recovery PASSED")
        return True
    print("✗ Error recovery FAILED")
    return False


def test_5_sequential_profile(user_url: str, agent_url: str) -> bool:
    """Test 5: Sequential profile processing (simulating RAG/reflection).

    Simulates 3 profiles with 2 sequential sessions each — the access
    pattern used by RAG/reflection methods — and passes when at least 80%
    of sessions complete.
    """
    _banner("TEST 5: Sequential Profile Processing (RAG/Reflection Simulation)")

    # Simulate 3 profiles, each with 2 sequential sessions.
    # This is how RAG/reflection methods work - sequential within profile.
    generator = BatchConversationGenerator(
        user_vllm_url=user_url,
        agent_vllm_url=agent_url,
        max_turns=2,
        user_max_tokens=256,
        agent_max_tokens=256,
    )

    n_profiles = 3
    sessions_per_profile = 2
    total_time = 0.0
    total_sessions = 0

    for profile_idx in range(n_profiles):
        profile_start = time.time()

        # Sequential sessions for this profile
        for session_idx in range(sessions_per_profile):
            samples = [
                {"problem": f"Profile {profile_idx+1}, Session {session_idx+1}: "
                            f"What is {profile_idx+session_idx}+1?",
                 "solution": str(profile_idx + session_idx + 1)}
            ]

            results = generator.generate_batch(
                samples=samples,
                user_persona=f"User profile {profile_idx+1}",
                user_preferences="Be concise.",
            )

            if results[0]:
                total_sessions += 1

        profile_elapsed = time.time() - profile_start
        total_time += profile_elapsed
        print(f"  Profile {profile_idx+1}: {profile_elapsed:.1f}s "
              f"for {sessions_per_profile} sessions")

    print("\nResults:")
    print(f"  Total sessions: {total_sessions}/{n_profiles * sessions_per_profile}")
    print(f"  Total time: {total_time:.1f}s")
    print(f"  Throughput: {_per_hour(total_sessions, total_time):.0f} sessions/hr")

    if total_sessions >= n_profiles * sessions_per_profile * 0.8:
        print("✓ Sequential profile test PASSED")
        return True
    print("✗ Sequential profile test FAILED")
    return False


def test_6_memory_estimation() -> bool:
    """Test 6: Memory usage estimation.

    Queries nvidia-smi for per-GPU memory usage. Returns False (rather
    than raising) when the tool is missing, times out, or errors.
    """
    _banner("TEST 6: Memory Usage Estimation")

    try:
        result = subprocess.run(
            ['nvidia-smi',
             '--query-gpu=index,memory.used,memory.total',
             '--format=csv,noheader,nounits'],
            capture_output=True, text=True, timeout=30,
        )
        # A non-zero exit (e.g. driver not loaded) previously went
        # unnoticed and printed an empty table; treat it as a failure.
        if result.returncode != 0:
            print(f"✗ Could not get memory info: {result.stderr.strip()}")
            return False

        print("GPU Memory Usage:")
        for line in result.stdout.strip().split('\n'):
            parts = line.split(', ')
            if len(parts) == 3:
                gpu_idx, used, total = parts
                used_pct = float(used) / float(total) * 100
                print(f"  GPU {gpu_idx}: {used}/{total} MiB ({used_pct:.1f}%)")

        print("✓ Memory estimation completed")
        return True
    except Exception as e:
        print(f"✗ Could not get memory info: {e}")
        return False


def run_all_tests(user_url: str, agent_url: str) -> bool:
    """Run all pre-flight tests and print a pass/fail summary.

    Returns True only when every individual test passed.
    """
    _banner("PRE-FLIGHT TESTS FOR FULL EXPERIMENTS")
    print(f"User URL: {user_url}")
    print(f"Agent URL: {agent_url}")
    print(f"Time: {time.strftime('%Y-%m-%d %H:%M:%S')}")

    results = {}

    # Run each test
    results['timeout'] = test_1_timeout_handling(user_url)
    results['large_batch'] = test_2_large_batch(user_url, batch_size=50)
    results['context_length'] = test_3_context_length_handling(user_url)
    results['error_recovery'] = test_4_error_recovery(user_url, agent_url)
    results['sequential_profile'] = test_5_sequential_profile(user_url, agent_url)
    results['memory'] = test_6_memory_estimation()

    # Summary
    _banner("PRE-FLIGHT TEST SUMMARY")

    all_passed = True
    for test_name, passed in results.items():
        status = "✓ PASSED" if passed else "✗ FAILED"
        print(f"  {test_name}: {status}")
        if not passed:
            all_passed = False

    print()
    if all_passed:
        print("✓ ALL TESTS PASSED - Ready for full experiments!")
    else:
        print("✗ SOME TESTS FAILED - Review before proceeding")

    return all_passed


if __name__ == "__main__":
    user_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8004/v1"
    agent_url = sys.argv[2] if len(sys.argv) > 2 else "http://localhost:8003/v1"

    success = run_all_tests(user_url, agent_url)
    sys.exit(0 if success else 1)
