diff options
Diffstat (limited to 'collaborativeagents/scripts/test_multiturn.py')
| -rw-r--r-- | collaborativeagents/scripts/test_multiturn.py | 248 |
1 files changed, 248 insertions, 0 deletions
diff --git a/collaborativeagents/scripts/test_multiturn.py b/collaborativeagents/scripts/test_multiturn.py new file mode 100644 index 0000000..1909c34 --- /dev/null +++ b/collaborativeagents/scripts/test_multiturn.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python3 +""" +Minimal test script to validate multi-turn conversation works correctly. + +This runs a single profile with a single session to verify: +1. LocalUserAgent loads and generates responses +2. Multi-turn conversation loop works +3. Metrics are properly extracted +""" + +import sys +import json +from pathlib import Path + +# Add paths +sys.path.insert(0, str(Path(__file__).parent.parent)) +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from agents.local_user_agent import LocalUserAgent, SharedLocalUserAgent, TERMINATION_SIGNAL + +def test_user_agent_standalone(): + """Test LocalUserAgent in isolation.""" + print("=" * 60) + print("TEST 1: LocalUserAgent Standalone") + print("=" * 60) + + user_agent = LocalUserAgent( + user_task_description="Help solve a math problem", + problem="What is 2 + 2?", + user_persona="A student learning math", + user_preferences="- Show step by step solutions\n- Use simple language", + ) + + # Simulate a conversation + conversation = [{"role": "assistant", "content": "How can I help you today?"}] + + print("\nGenerating user response...") + response = user_agent.generate_user_response(conversation) + + if response: + print(f"SUCCESS! User response: {response.get('response', 'N/A')[:200]}...") + print(f"Should terminate: {response.get('should_terminate', 'N/A')}") + print(f"Draft answer: {response.get('draft_answer', 'N/A')[:100]}...") + return True + else: + print("FAILED! User agent returned None") + return False + + +def test_multiturn_conversation(): + """Test full multi-turn conversation with agent adapter.""" + print("\n" + "=" * 60) + print("TEST 2: Multi-turn Conversation") + print("=" * 60) + + from adapters.personalized_llm_adapter import create_baseline_adapter + + # Create a simple agent adapter (vanilla mode) + print("\nCreating vanilla adapter...") + adapter = create_baseline_adapter("vanilla") + adapter.initialize() + + # Load a test profile + profile_path = Path(__file__).parent.parent / "data/complex_profiles_v2/profiles_100.jsonl" + with open(profile_path) as f: + profile = json.loads(f.readline()) + + print(f"Loaded profile: {profile.get('user_id', 'unknown')}") + + # Create user agent + problem = "What is 15% of 80?" + user_prefs = profile.get("preferences", [])[:3] + pref_str = "\n".join([f"- {p}" for p in user_prefs]) + + print(f"\nUser preferences:\n{pref_str}") + + user_agent = SharedLocalUserAgent( + user_task_description="Solve the math problem", + problem=problem, + user_persona=profile.get("persona", "A user"), + user_preferences=pref_str, + ) + + # Start session + adapter.start_session(user_id=profile.get("user_id", "test")) + + # Run multi-turn conversation + conversation = [{"role": "assistant", "content": "How can I help you today?"}] + turns = [] + max_turns = 5 + + print(f"\nStarting {max_turns}-turn conversation...") + + for turn_num in range(max_turns): + print(f"\n--- Turn {turn_num + 1} ---") + + # User turn + user_response = user_agent.generate_user_response(conversation) + if user_response is None: + print("User agent failed!") + break + + user_msg = user_response.get("response", "") + print(f"USER: {user_msg[:150]}...") + + conversation.append({"role": "user", "content": user_msg}) + turns.append({"role": "user", "content": user_msg}) + + # Check termination + if user_response.get("should_terminate", False) or TERMINATION_SIGNAL in user_msg: + print("\n[User terminated conversation]") + break + + # Agent turn + response = adapter.generate_response(user_msg, conversation[:-1]) + agent_msg = response.get("response", str(response)) if isinstance(response, dict) else str(response) + print(f"AGENT: {agent_msg[:150]}...") + + conversation.append({"role": "assistant", "content": agent_msg}) + turns.append({"role": "assistant", "content": agent_msg}) + + # End session + adapter.end_session() + + print(f"\n--- Results ---") + print(f"Total turns: {len(turns)}") + print(f"User turns: {len([t for t in turns if t['role'] == 'user'])}") + print(f"Agent turns: {len([t for t in turns if t['role'] == 'assistant'])}") + + return len(turns) > 2 # Success if more than single turn + + +def test_full_session(): + """Test run_single_session from ExperimentRunner.""" + print("\n" + "=" * 60) + print("TEST 3: Full run_single_session") + print("=" * 60) + + from run_experiments import ExperimentRunner, ExperimentConfig + from adapters.personalized_llm_adapter import create_baseline_adapter + + config = ExperimentConfig( + methods=["vanilla"], + datasets=["math-500"], + n_profiles=1, + n_sessions_per_profile=1, + max_turns_per_session=5, + output_dir="/tmp/test_multiturn", + profile_path=str(Path(__file__).parent.parent / "data/complex_profiles_v2/profiles_100.jsonl"), + ) + + print("\nCreating ExperimentRunner...") + runner = ExperimentRunner(config) + + # Get first profile and problem + profile = runner.profiles[0] + dataset = list(runner.datasets.values())[0] + sample = dataset.get_testset()[0] + + problem = { + "problem": sample.problem, + "solution": sample.solution, + "problem_id": sample.problem_id, + "domain": sample.domain, + } + + print(f"\nRunning single session...") + print(f"Profile: {profile.get('user_id', 'unknown')}") + print(f"Problem: {problem['problem'][:100]}...") + + # Create adapter + adapter = create_baseline_adapter("vanilla") + adapter.initialize() + + result = runner.run_single_session( + method="vanilla", + profile=profile, + problem=problem, + is_conflict_query=False, + adapter=adapter, + ) + + print(f"\n--- Session Results ---") + print(f"Total turns: {result['metrics']['total_turns']}") + print(f"Task success: {result['metrics']['task_success']}") + print(f"Enforcement count: {result['metrics']['enforcement_count']}") + print(f"User tokens: {result['metrics']['user_token_count']}") + print(f"Agent tokens: {result['metrics']['agent_token_count']}") + print(f"Compliance scores: {result['metrics']['preference_compliance_scores']}") + + if result['conversation']: + print(f"\nConversation ({len(result['conversation']['turns'])} messages):") + for i, turn in enumerate(result['conversation']['turns'][:6]): + print(f" [{turn['role']}]: {turn['content'][:80]}...") + + return result['metrics']['total_turns'] > 2 + + +if __name__ == "__main__": + print("\n" + "=" * 60) + print("MULTI-TURN CONVERSATION VALIDATION TEST") + print("=" * 60) + + results = {} + + # Test 1: User agent standalone + try: + results["user_agent"] = test_user_agent_standalone() + except Exception as e: + print(f"TEST 1 FAILED: {e}") + import traceback + traceback.print_exc() + results["user_agent"] = False + + # Test 2: Multi-turn conversation + try: + results["multiturn"] = test_multiturn_conversation() + except Exception as e: + print(f"TEST 2 FAILED: {e}") + import traceback + traceback.print_exc() + results["multiturn"] = False + + # Test 3: Full session (only if test 2 passed) + if results.get("multiturn", False): + try: + results["full_session"] = test_full_session() + except Exception as e: + print(f"TEST 3 FAILED: {e}") + import traceback + traceback.print_exc() + results["full_session"] = False + else: + print("\nSkipping TEST 3 (TEST 2 failed)") + results["full_session"] = False + + # Summary + print("\n" + "=" * 60) + print("TEST SUMMARY") + print("=" * 60) + for test_name, passed in results.items(): + status = "PASS" if passed else "FAIL" + print(f" {test_name}: {status}") + + all_passed = all(results.values()) + print(f"\nOverall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}") + + sys.exit(0 if all_passed else 1) |
