summaryrefslogtreecommitdiff
path: root/collaborativeagents/scripts/test_multiturn.py
diff options
context:
space:
mode:
Diffstat (limited to 'collaborativeagents/scripts/test_multiturn.py')
-rw-r--r--collaborativeagents/scripts/test_multiturn.py248
1 files changed, 248 insertions, 0 deletions
diff --git a/collaborativeagents/scripts/test_multiturn.py b/collaborativeagents/scripts/test_multiturn.py
new file mode 100644
index 0000000..1909c34
--- /dev/null
+++ b/collaborativeagents/scripts/test_multiturn.py
@@ -0,0 +1,248 @@
+#!/usr/bin/env python3
+"""
+Minimal test script to validate multi-turn conversation works correctly.
+
+This runs a single profile with a single session to verify:
+1. LocalUserAgent loads and generates responses
+2. Multi-turn conversation loop works
+3. Metrics are properly extracted
+"""
+
+import sys
+import json
+from pathlib import Path
+
+# Add paths
+sys.path.insert(0, str(Path(__file__).parent.parent))
+sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+from agents.local_user_agent import LocalUserAgent, SharedLocalUserAgent, TERMINATION_SIGNAL
+
def test_user_agent_standalone():
    """Test LocalUserAgent in isolation.

    Builds a minimal user agent, feeds it a one-message conversation,
    and checks that generate_user_response returns a usable payload.

    Returns:
        bool: True if the agent produced a non-empty response dict,
        False otherwise.
    """
    print("=" * 60)
    print("TEST 1: LocalUserAgent Standalone")
    print("=" * 60)

    user_agent = LocalUserAgent(
        user_task_description="Help solve a math problem",
        problem="What is 2 + 2?",
        user_persona="A student learning math",
        user_preferences="- Show step by step solutions\n- Use simple language",
    )

    # Simulate a conversation opener the user agent must respond to.
    conversation = [{"role": "assistant", "content": "How can I help you today?"}]

    print("\nGenerating user response...")
    response = user_agent.generate_user_response(conversation)

    # Guard clause: None or an empty dict both count as failure
    # (matches the original truthiness check).
    if not response:
        print("FAILED! User agent returned None")
        return False

    # Coerce to str() before slicing: a present-but-None value for
    # 'response' or 'draft_answer' would make bare `.get(...)[:200]`
    # raise TypeError instead of reporting the result.
    print(f"SUCCESS! User response: {str(response.get('response', 'N/A'))[:200]}...")
    print(f"Should terminate: {response.get('should_terminate', 'N/A')}")
    print(f"Draft answer: {str(response.get('draft_answer', 'N/A'))[:100]}...")
    return True
+
+
def test_multiturn_conversation():
    """Test full multi-turn conversation with agent adapter."""
    print("\n" + "=" * 60)
    print("TEST 2: Multi-turn Conversation")
    print("=" * 60)

    from adapters.personalized_llm_adapter import create_baseline_adapter

    # A plain "vanilla" adapter keeps the conversation loop itself
    # the only thing under test.
    print("\nCreating vanilla adapter...")
    adapter = create_baseline_adapter("vanilla")
    adapter.initialize()

    # Use the first profile from the fixture file as the simulated user.
    profile_file = Path(__file__).parent.parent / "data/complex_profiles_v2/profiles_100.jsonl"
    with open(profile_file) as fh:
        profile = json.loads(fh.readline())

    print(f"Loaded profile: {profile.get('user_id', 'unknown')}")

    # The user agent is conditioned on at most three preferences.
    problem = "What is 15% of 80?"
    selected_prefs = profile.get("preferences", [])[:3]
    pref_str = "\n".join(f"- {p}" for p in selected_prefs)

    print(f"\nUser preferences:\n{pref_str}")

    user_agent = SharedLocalUserAgent(
        user_task_description="Solve the math problem",
        problem=problem,
        user_persona=profile.get("persona", "A user"),
        user_preferences=pref_str,
    )

    # Open the adapter session before any turns are exchanged.
    adapter.start_session(user_id=profile.get("user_id", "test"))

    # Seed the dialogue with the assistant's greeting.
    conversation = [{"role": "assistant", "content": "How can I help you today?"}]
    turns = []
    max_turns = 5

    print(f"\nStarting {max_turns}-turn conversation...")

    for turn in range(1, max_turns + 1):
        print(f"\n--- Turn {turn} ---")

        # User side of the turn.
        user_response = user_agent.generate_user_response(conversation)
        if user_response is None:
            print("User agent failed!")
            break

        user_msg = user_response.get("response", "")
        print(f"USER: {user_msg[:150]}...")

        user_entry = {"role": "user", "content": user_msg}
        conversation.append(user_entry)
        turns.append(user_entry)

        # Stop if the user signalled completion, explicitly or via the
        # in-band termination marker.
        terminated = user_response.get("should_terminate", False)
        if terminated or TERMINATION_SIGNAL in user_msg:
            print("\n[User terminated conversation]")
            break

        # Agent side of the turn; history excludes the just-appended
        # user message because it is passed separately.
        raw_reply = adapter.generate_response(user_msg, conversation[:-1])
        if isinstance(raw_reply, dict):
            agent_msg = raw_reply.get("response", str(raw_reply))
        else:
            agent_msg = str(raw_reply)
        print(f"AGENT: {agent_msg[:150]}...")

        agent_entry = {"role": "assistant", "content": agent_msg}
        conversation.append(agent_entry)
        turns.append(agent_entry)

    # Close the adapter session regardless of how the loop ended.
    adapter.end_session()

    print(f"\n--- Results ---")
    print(f"Total turns: {len(turns)}")
    print(f"User turns: {sum(1 for t in turns if t['role'] == 'user')}")
    print(f"Agent turns: {sum(1 for t in turns if t['role'] == 'assistant')}")

    # More than a single exchange counts as success.
    return len(turns) > 2
+
+
def test_full_session():
    """Test run_single_session from ExperimentRunner.

    Drives one vanilla-method session end to end: loads the first
    profile and the first dataset sample, runs the session through
    ExperimentRunner.run_single_session, and prints the extracted
    metrics.

    Returns:
        bool: True if the session lasted more than two turns.
    """
    print("\n" + "=" * 60)
    print("TEST 3: Full run_single_session")
    print("=" * 60)

    from run_experiments import ExperimentRunner, ExperimentConfig
    from adapters.personalized_llm_adapter import create_baseline_adapter

    # Minimal config: one profile, one session, capped at 5 turns.
    config = ExperimentConfig(
        methods=["vanilla"],
        datasets=["math-500"],
        n_profiles=1,
        n_sessions_per_profile=1,
        max_turns_per_session=5,
        output_dir="/tmp/test_multiturn",
        profile_path=str(Path(__file__).parent.parent / "data/complex_profiles_v2/profiles_100.jsonl"),
    )

    print("\nCreating ExperimentRunner...")
    runner = ExperimentRunner(config)

    # Pick the first profile and the first problem of the first dataset.
    profile = runner.profiles[0]
    dataset = list(runner.datasets.values())[0]
    sample = dataset.get_testset()[0]

    problem = {
        "problem": sample.problem,
        "solution": sample.solution,
        "problem_id": sample.problem_id,
        "domain": sample.domain,
    }

    print("\nRunning single session...")
    print(f"Profile: {profile.get('user_id', 'unknown')}")
    print(f"Problem: {problem['problem'][:100]}...")

    # Create adapter
    adapter = create_baseline_adapter("vanilla")
    adapter.initialize()

    result = runner.run_single_session(
        method="vanilla",
        profile=profile,
        problem=problem,
        is_conflict_query=False,
        adapter=adapter,
    )

    print("\n--- Session Results ---")
    print(f"Total turns: {result['metrics']['total_turns']}")
    print(f"Task success: {result['metrics']['task_success']}")
    print(f"Enforcement count: {result['metrics']['enforcement_count']}")
    print(f"User tokens: {result['metrics']['user_token_count']}")
    print(f"Agent tokens: {result['metrics']['agent_token_count']}")
    print(f"Compliance scores: {result['metrics']['preference_compliance_scores']}")

    if result['conversation']:
        print(f"\nConversation ({len(result['conversation']['turns'])} messages):")
        # Preview at most the first 6 messages; the index was never
        # used, so plain iteration replaces the enumerate() loop.
        for turn in result['conversation']['turns'][:6]:
            print(f"  [{turn['role']}]: {turn['content'][:80]}...")

    return result['metrics']['total_turns'] > 2
+
+
if __name__ == "__main__":
    # Hoisted once instead of being re-imported inside each except block.
    import traceback

    print("\n" + "=" * 60)
    print("MULTI-TURN CONVERSATION VALIDATION TEST")
    print("=" * 60)

    results = {}

    def _run(banner, fn):
        """Run one test callable, reporting any exception as a failure.

        Prints "<banner> FAILED: <exc>" plus a traceback on error, so
        the output matches the previous copy-pasted try/except blocks.
        """
        try:
            return fn()
        except Exception as e:
            print(f"{banner} FAILED: {e}")
            traceback.print_exc()
            return False

    # Test 1: User agent standalone
    results["user_agent"] = _run("TEST 1", test_user_agent_standalone)

    # Test 2: Multi-turn conversation
    results["multiturn"] = _run("TEST 2", test_multiturn_conversation)

    # Test 3: Full session (only if test 2 passed)
    if results.get("multiturn", False):
        results["full_session"] = _run("TEST 3", test_full_session)
    else:
        print("\nSkipping TEST 3 (TEST 2 failed)")
        results["full_session"] = False

    # Summary
    print("\n" + "=" * 60)
    print("TEST SUMMARY")
    print("=" * 60)
    for test_name, passed in results.items():
        status = "PASS" if passed else "FAIL"
        print(f"  {test_name}: {status}")

    all_passed = all(results.values())
    print(f"\nOverall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}")

    # Non-zero exit code signals failure to CI/shell callers.
    sys.exit(0 if all_passed else 1)