summaryrefslogtreecommitdiff
path: root/collaborativeagents/scripts/preflight_test.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
commitdc801c07cf38b0c495686463e6ca6f871a64440e (patch)
tree599f03114775921dbc472403c701f4a3a8ea188a /collaborativeagents/scripts/preflight_test.py
parente43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff)
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules - Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/scripts/preflight_test.py')
-rw-r--r--collaborativeagents/scripts/preflight_test.py311
1 files changed, 311 insertions, 0 deletions
diff --git a/collaborativeagents/scripts/preflight_test.py b/collaborativeagents/scripts/preflight_test.py
new file mode 100644
index 0000000..2411f1f
--- /dev/null
+++ b/collaborativeagents/scripts/preflight_test.py
@@ -0,0 +1,311 @@
+#!/usr/bin/env python
+"""
+Pre-flight tests before running full experiments.
+
+Tests:
+1. Timeout handling (infinite timeout)
+2. Large batch stress test (batch=100)
+3. Context length handling (auto-reduce max_tokens)
+4. Error recovery (partial failures)
+5. Sequential profile processing (for RAG/reflection methods)
+6. Memory usage estimation
+"""
+
+import sys
+import os
+import time
+import json
+import asyncio
+
+sys.path.insert(0, '/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents')
+
+from agents.batch_vllm_agent import BatchVLLMClient, BatchConversationGenerator
+
+
def test_1_timeout_handling(user_url: str):
    """Test 1: Verify the client round-trips with an infinite (None) timeout."""
    banner = "=" * 60
    print(f"\n{banner}")
    print("TEST 1: Timeout Handling (Infinite Timeout)")
    print(banner)

    # timeout=None lets requests run indefinitely instead of aborting.
    client = BatchVLLMClient(
        vllm_url=user_url,
        max_tokens=256,
        temperature=0.7,
        timeout=None,  # Infinite timeout
        max_concurrent=50,
    )

    print("✓ Client created with timeout=None (infinite)")
    print(f"  Model: {client.model_name}")
    print(f"  Max concurrent: {client.max_concurrent}")

    # A single trivial request proves the configured client works end to end.
    prompts = [[{"role": "user", "content": "Say 'hello' and nothing else."}]]

    started = time.time()
    responses = client.batch_completion(prompts)
    duration = time.time() - started

    if not responses[0]:
        print("✗ Single request failed")
        return False
    print(f"✓ Single request succeeded in {duration:.1f}s")
    print(f"  Response: {responses[0][:50]}...")
    return True
+
+
def test_2_large_batch(user_url: str, batch_size: int = 100):
    """Test 2: Stress the server with a large number of concurrent requests."""
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"TEST 2: Large Batch Stress Test (batch={batch_size})")
    print(banner)

    client = BatchVLLMClient(
        vllm_url=user_url,
        max_tokens=128,  # keep outputs short so the stress test finishes quickly
        temperature=0.7,
        timeout=None,
        max_concurrent=100,
    )

    # One trivial, uniquely numbered prompt per slot in the batch.
    batch = [
        [{"role": "user", "content": f"Count from 1 to 5. Request #{idx + 1}"}]
        for idx in range(batch_size)
    ]

    print(f"Sending {batch_size} concurrent requests...")
    started = time.time()
    outputs = client.batch_completion(batch)
    duration = time.time() - started

    ok_count = sum(1 for out in outputs if out is not None)

    print("\nResults:")
    print(f"  Successes: {ok_count}/{batch_size}")
    print(f"  Time: {duration:.1f}s")
    print(f"  Throughput: {ok_count * 3600 / duration:.0f} requests/hr")

    # Pass criterion: at least 90% of the batch completed.
    if ok_count >= batch_size * 0.9:
        print("✓ Batch test PASSED (>90% success)")
        return True
    print("✗ Batch test FAILED (<90% success)")
    return False
+
+
def test_3_context_length_handling(user_url: str):
    """Test 3: A too-long request must not break unrelated requests in the batch."""
    banner = "=" * 60
    print(f"\n{banner}")
    print("TEST 3: Context Length Handling")
    print(banner)

    client = BatchVLLMClient(
        vllm_url=user_url,
        max_tokens=512,  # deliberately large output budget
        temperature=0.7,
        timeout=None,
        max_concurrent=10,
    )

    # The first prompt is padded to roughly 2000 tokens so prompt + 512
    # output tokens approaches the 4096-token context window; the second
    # is a trivial control request.
    padding = "This is a test. " * 500  # ~2000 tokens
    batch = [
        [{"role": "user", "content": f"Summarize: {padding}"}],  # may hit the limit
        [{"role": "user", "content": "Say hello."}],  # control, should pass
    ]

    print("Testing with 1 long + 1 short request...")
    outputs = client.batch_completion(batch)

    # Only the short control request is required to succeed; the long one
    # may complete or be degraded/dropped by the client's handling.
    if outputs[1] is None:
        print("✗ Short request should not have failed")
        return False
    print("✓ Short request succeeded despite long request")
    print(f"  Long request result: {'OK' if outputs[0] else 'Handled gracefully'}")
    return True
+
+
def test_4_error_recovery(user_url: str, agent_url: str):
    """Test 4: Batch conversation generation should tolerate partial failures."""
    banner = "=" * 60
    print(f"\n{banner}")
    print("TEST 4: Error Recovery (Partial Failures)")
    print(banner)

    generator = BatchConversationGenerator(
        user_vllm_url=user_url,
        agent_vllm_url=agent_url,
        max_turns=3,
        user_max_tokens=256,
        agent_max_tokens=256,
    )

    # Three easy arithmetic samples; a failure on any one of them should
    # not abort the rest of the batch.
    samples = [
        {"problem": "What is 2+2?", "solution": "4"},
        {"problem": "What is 3+3?", "solution": "6"},
        {"problem": "What is 4+4?", "solution": "8"},
    ]

    print("Testing batch generation with 3 samples, 3 turns...")
    started = time.time()
    outputs = generator.generate_batch(
        samples=samples,
        user_persona="A student.",
        user_preferences=None,
    )
    duration = time.time() - started

    ok_count = sum(out is not None for out in outputs)
    print("\nResults:")
    print(f"  Successes: {ok_count}/{len(samples)}")
    print(f"  Time: {duration:.1f}s")

    # Pass criterion: at least 2 of the 3 samples completed.
    if ok_count < 2:
        print("✗ Error recovery FAILED")
        return False
    print("✓ Error recovery PASSED")
    return True
+
+
def test_5_sequential_profile(user_url: str, agent_url: str):
    """Test 5: Sequential per-profile session processing (RAG/reflection shape)."""
    banner = "=" * 60
    print(f"\n{banner}")
    print("TEST 5: Sequential Profile Processing (RAG/Reflection Simulation)")
    print(banner)

    # RAG/reflection methods process each profile's sessions in order
    # (a later session can depend on an earlier one), so reproduce that
    # access pattern: 3 profiles, each with 2 sequential sessions.
    generator = BatchConversationGenerator(
        user_vllm_url=user_url,
        agent_vllm_url=agent_url,
        max_turns=2,
        user_max_tokens=256,
        agent_max_tokens=256,
    )

    n_profiles = 3
    sessions_per_profile = 2
    total_time = 0
    total_sessions = 0

    for p in range(n_profiles):
        profile_started = time.time()

        # Sessions within a profile run strictly one after another.
        for s in range(sessions_per_profile):
            samples = [{
                "problem": f"Profile {p+1}, Session {s+1}: What is {p+s}+1?",
                "solution": str(p + s + 1),
            }]

            outputs = generator.generate_batch(
                samples=samples,
                user_persona=f"User profile {p+1}",
                user_preferences="Be concise.",
            )

            if outputs[0]:
                total_sessions += 1

        profile_duration = time.time() - profile_started
        total_time += profile_duration
        print(f"  Profile {p+1}: {profile_duration:.1f}s for {sessions_per_profile} sessions")

    print("\nResults:")
    print(f"  Total sessions: {total_sessions}/{n_profiles * sessions_per_profile}")
    print(f"  Total time: {total_time:.1f}s")
    print(f"  Throughput: {total_sessions * 3600 / total_time:.0f} sessions/hr")

    # Pass criterion: at least 80% of all sessions completed.
    if total_sessions >= n_profiles * sessions_per_profile * 0.8:
        print("✓ Sequential profile test PASSED")
        return True
    print("✗ Sequential profile test FAILED")
    return False
+
+
def test_6_memory_estimation():
    """Test 6: Report current GPU memory usage via nvidia-smi.

    Returns:
        True when GPU usage was successfully queried and printed; False when
        nvidia-smi is missing, exits with an error, or times out.
    """
    print("\n" + "="*60)
    print("TEST 6: Memory Usage Estimation")
    print("="*60)

    try:
        import subprocess
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=index,memory.used,memory.total', '--format=csv,noheader,nounits'],
            capture_output=True, text=True,
            timeout=30,  # don't hang the whole preflight if the driver is wedged
        )

        # BUGFIX: a non-zero exit (e.g. driver error) previously fell through,
        # printed no data, and still reported success. Treat it as a failure.
        if result.returncode != 0:
            print(f"✗ nvidia-smi exited with code {result.returncode}: {result.stderr.strip()}")
            return False

        print("GPU Memory Usage:")
        for line in result.stdout.strip().split('\n'):
            parts = line.split(', ')
            if len(parts) == 3:  # skip malformed/blank lines defensively
                gpu_idx, used, total = parts
                used_pct = float(used) / float(total) * 100
                print(f"  GPU {gpu_idx}: {used}/{total} MiB ({used_pct:.1f}%)")

        print("✓ Memory estimation completed")
        return True
    except Exception as e:
        # Covers FileNotFoundError (no nvidia-smi), TimeoutExpired, parse errors.
        print(f"✗ Could not get memory info: {e}")
        return False
+
+
def run_all_tests(user_url: str, agent_url: str):
    """Run every pre-flight test in order and print a pass/fail summary.

    Returns:
        True only when all six tests passed.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print("PRE-FLIGHT TESTS FOR FULL EXPERIMENTS")
    print(banner)
    print(f"User URL: {user_url}")
    print(f"Agent URL: {agent_url}")
    print(f"Time: {time.strftime('%Y-%m-%d %H:%M:%S')}")

    # Dict literals evaluate left-to-right, so the tests run in order.
    results = {
        'timeout': test_1_timeout_handling(user_url),
        'large_batch': test_2_large_batch(user_url, batch_size=50),
        'context_length': test_3_context_length_handling(user_url),
        'error_recovery': test_4_error_recovery(user_url, agent_url),
        'sequential_profile': test_5_sequential_profile(user_url, agent_url),
        'memory': test_6_memory_estimation(),
    }

    print(f"\n{banner}")
    print("PRE-FLIGHT TEST SUMMARY")
    print(banner)

    for test_name, passed in results.items():
        print(f"  {test_name}: {'✓ PASSED' if passed else '✗ FAILED'}")
    all_passed = all(results.values())

    print()
    if all_passed:
        print("✓ ALL TESTS PASSED - Ready for full experiments!")
    else:
        print("✗ SOME TESTS FAILED - Review before proceeding")

    return all_passed
+
+
if __name__ == "__main__":
    # Optional CLI overrides: preflight_test.py [user_url] [agent_url]
    cli_args = sys.argv[1:]
    user_url = cli_args[0] if cli_args else "http://localhost:8004/v1"
    agent_url = cli_args[1] if len(cli_args) > 1 else "http://localhost:8003/v1"

    # Exit 0 on full success so shell scripts can gate on the result.
    sys.exit(0 if run_all_tests(user_url, agent_url) else 1)