summaryrefslogtreecommitdiff
path: root/collaborativeagents/scripts/test_70b_pilot.py
blob: 4bb27a372e573dd59815e8de7b2cdf9f46af72ff (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
#!/usr/bin/env python3
"""
Pilot test for 70B AWQ user model.

Tests:
1. 70B AWQ model loads without OOM
2. User simulation works correctly
3. Multi-turn conversation completes
4. Memory usage is acceptable

Run with 4xA100 GPUs.
"""

import sys
import json
import torch
from pathlib import Path

# Make the directory above scripts/ and the one above that importable, so the
# `agents` / `adapters` packages imported inside the tests can presumably be
# found. The second insert lands in front of the first, so the outer
# directory is searched first.
_SCRIPT_DIR = Path(__file__).parent
sys.path.insert(0, str(_SCRIPT_DIR.parent))
sys.path.insert(0, str(_SCRIPT_DIR.parent.parent))


def print_gpu_memory():
    """Report allocated / reserved / total CUDA memory for every visible device, in GB.

    Output is a header line, one line per device reported by
    ``torch.cuda.device_count()``, then a blank line. With no visible devices
    only the header and the blank line are printed.
    """
    bytes_per_gb = 1e9
    print("\n=== GPU Memory Usage ===")
    for dev in range(torch.cuda.device_count()):
        total_gb = torch.cuda.get_device_properties(dev).total_memory / bytes_per_gb
        alloc_gb = torch.cuda.memory_allocated(dev) / bytes_per_gb
        resv_gb = torch.cuda.memory_reserved(dev) / bytes_per_gb
        line = f"  GPU {dev}: {alloc_gb:.1f}GB allocated, {resv_gb:.1f}GB reserved, {total_gb:.1f}GB total"
        print(line)
    print()


def test_70b_user_agent():
    """Build a ``LocalUserAgent`` from ``DEFAULT_MODEL_PATH`` and request one user reply.

    The default path is expected to point at the 70B AWQ model; the function
    prints the path and whether it contains "awq". Constructing the agent
    presumably loads the model (TODO confirm in agents.local_user_agent).
    GPU memory is printed before and after generating the reply.

    Returns:
        True if ``generate_user_response`` returned a truthy value, else False.
    """
    banner = "=" * 60
    print(banner)
    print("TEST 1: 70B AWQ User Agent Loading")
    print(banner)

    from agents.local_user_agent import LocalUserAgent, DEFAULT_MODEL_PATH

    print(f"Default model path: {DEFAULT_MODEL_PATH}")
    looks_like_awq = 'awq' in DEFAULT_MODEL_PATH.lower()
    print(f"Is AWQ model: {looks_like_awq}")

    agent = LocalUserAgent(
        user_task_description="Help solve a math problem",
        problem="What is 2 + 2?",
        user_persona="A student learning math",
        user_preferences="- Show step by step solutions\n- Use simple language",
    )

    print("\nGenerating user response...")
    print_gpu_memory()

    # A single assistant greeting is enough to elicit the first user message.
    opening = [{"role": "assistant", "content": "How can I help you today?"}]
    reply = agent.generate_user_response(opening)

    print_gpu_memory()

    if not reply:
        print("FAILED! User agent returned None")
        return False

    preview = reply.get('response', 'N/A')[:200]
    print(f"SUCCESS! User response: {preview}...")
    print(f"Should terminate: {reply.get('should_terminate', 'N/A')}")
    return True


def test_multiturn_with_70b():
    """Run a short multi-turn conversation between the 70B user model and a vanilla agent.

    Loads the first profile from ``data/complex_profiles_v2/profiles_100.jsonl``,
    builds a ``SharedLocalUserAgent`` from its persona and first three
    preferences, and alternates user/agent messages for up to ``max_turns``
    exchanges. The conversation stops early if the user agent returns None or
    signals termination.

    Returns:
        True if more than two messages were exchanged (the conversation got past
        the first user -> agent pair), False otherwise.
    """
    print("\n" + "=" * 60)
    print("TEST 2: Multi-turn Conversation with 70B User Model")
    print("=" * 60)

    from agents.local_user_agent import SharedLocalUserAgent, TERMINATION_SIGNAL
    from adapters.personalized_llm_adapter import create_baseline_adapter

    # Create vanilla adapter (uses Qwen 1.5B for agent)
    print("\nCreating vanilla adapter...")
    adapter = create_baseline_adapter("vanilla")
    adapter.initialize()

    print_gpu_memory()

    # Use the first profile in the JSONL file as the simulated user.
    profile_path = Path(__file__).parent.parent / "data/complex_profiles_v2/profiles_100.jsonl"
    with open(profile_path, encoding="utf-8") as f:
        profile = json.loads(f.readline())

    print(f"Loaded profile: {profile.get('user_id', 'unknown')}")

    problem = "What is 15% of 80?"
    # Only the first three preferences are given to the user simulator.
    # NOTE(review): assumes each preference renders sensibly via str(); confirm schema.
    user_prefs = profile.get("preferences", [])[:3]
    pref_str = "\n".join(f"- {p}" for p in user_prefs)

    print(f"\nUser preferences:\n{pref_str}")

    user_agent = SharedLocalUserAgent(
        user_task_description="Solve the math problem",
        problem=problem,
        user_persona=profile.get("persona", "A user"),
        user_preferences=pref_str,
    )

    print_gpu_memory()

    adapter.start_session(user_id=profile.get("user_id", "test"))

    conversation = [{"role": "assistant", "content": "How can I help you today?"}]
    turns = []  # messages exchanged in this test (excludes the opening greeting)
    max_turns = 5

    print(f"\nStarting {max_turns}-turn conversation...")

    try:
        for turn_num in range(max_turns):
            print(f"\n--- Turn {turn_num + 1} ---")

            # User turn
            user_response = user_agent.generate_user_response(conversation)
            if user_response is None:
                print("User agent failed!")
                break

            user_msg = user_response.get("response", "")
            print(f"USER: {user_msg[:150]}...")

            conversation.append({"role": "user", "content": user_msg})
            turns.append({"role": "user", "content": user_msg})

            if user_response.get("should_terminate", False) or TERMINATION_SIGNAL in user_msg:
                print("\n[User terminated conversation]")
                break

            # Agent turn: the new user message is passed separately, so the
            # history handed to the adapter excludes it.
            response = adapter.generate_response(user_msg, conversation[:-1])
            # The adapter may return a dict with a "response" key or a plain value.
            agent_msg = (
                response.get("response", str(response))
                if isinstance(response, dict)
                else str(response)
            )
            print(f"AGENT: {agent_msg[:150]}...")

            conversation.append({"role": "assistant", "content": agent_msg})
            turns.append({"role": "assistant", "content": agent_msg})
    finally:
        # Close the adapter session even if a turn raised.
        adapter.end_session()

    n_user = sum(1 for t in turns if t["role"] == "user")
    n_agent = sum(1 for t in turns if t["role"] == "assistant")
    print("\n--- Results ---")
    print(f"Total turns: {len(turns)}")
    print(f"User turns: {n_user}")
    print(f"Agent turns: {n_agent}")

    print_gpu_memory()

    # Success requires at least user -> agent -> user, i.e. more than one exchange.
    return len(turns) > 2


def test_memory_after_multiple_sessions(n_sessions=3, max_growth_gb=2.0):
    """Check that GPU memory does not keep growing across repeated sessions.

    Runs ``n_sessions`` short sessions (up to 3 user/agent exchanges each) on a
    single vanilla adapter. After each session it runs ``gc.collect()`` and
    ``torch.cuda.empty_cache()``, then records PyTorch-allocated memory summed
    over all visible CUDA devices.

    Args:
        n_sessions: Number of sessions to run.
        max_growth_gb: Largest allowed increase (GB, i.e. 1e9 bytes) in
            allocated memory between the end of the first and the end of the
            last session. The default is a loose pilot threshold, not a
            measured budget.

    Returns:
        True if the growth stays within ``max_growth_gb`` (or fewer than two
        sessions ran), False otherwise.

    NOTE(review): ``torch.cuda.memory_allocated`` only sees this process's
    PyTorch allocations; if the user model is served by a separate process,
    its memory is not covered by this check.
    """
    import gc

    print("\n" + "=" * 60)
    print("TEST 3: Memory Stability Across Sessions")
    print("=" * 60)

    from agents.local_user_agent import SharedLocalUserAgent, TERMINATION_SIGNAL
    from adapters.personalized_llm_adapter import create_baseline_adapter

    adapter = create_baseline_adapter("vanilla")
    adapter.initialize()

    profile_path = Path(__file__).parent.parent / "data/complex_profiles_v2/profiles_100.jsonl"
    with open(profile_path, encoding="utf-8") as f:
        profile = json.loads(f.readline())

    print(f"\nRunning {n_sessions} sessions to check memory stability...")

    allocated_gb = []  # PyTorch-allocated GB (all devices) after each session
    for session_idx in range(n_sessions):
        print(f"\n--- Session {session_idx + 1}/{n_sessions} ---")

        user_agent = SharedLocalUserAgent(
            user_task_description="Solve math",
            problem=f"What is {session_idx + 1} + {session_idx + 2}?",
            user_persona="A student",
            user_preferences="- Be concise",
        )

        adapter.start_session(user_id=profile.get("user_id", "test"))

        conversation = [{"role": "assistant", "content": "How can I help?"}]
        try:
            for _ in range(3):
                user_response = user_agent.generate_user_response(conversation)
                if user_response is None or user_response.get("should_terminate"):
                    break
                user_msg = user_response.get("response", "")
                conversation.append({"role": "user", "content": user_msg})
                # Same termination rule as the multi-turn test: don't ask the
                # agent to answer a message that ends the conversation.
                if TERMINATION_SIGNAL in user_msg:
                    break

                response = adapter.generate_response(user_msg, conversation[:-1])
                # The adapter may return a dict with a "response" key or a plain value.
                agent_msg = (
                    response.get("response", str(response))
                    if isinstance(response, dict)
                    else str(response)
                )
                conversation.append({"role": "assistant", "content": agent_msg})
        finally:
            # Close the adapter session even if a turn raised.
            adapter.end_session()

        # Free unreachable objects and release cached allocator blocks (visible in
        # the "reserved" figure) before taking this session's measurement.
        gc.collect()
        torch.cuda.empty_cache()
        allocated_gb.append(
            sum(torch.cuda.memory_allocated(i) for i in range(torch.cuda.device_count())) / 1e9
        )
        print_gpu_memory()

    print("\nMemory stability test completed.")
    if len(allocated_gb) < 2:
        return True

    growth_gb = allocated_gb[-1] - allocated_gb[0]
    print(
        f"Allocated memory growth (session 1 -> {n_sessions}): {growth_gb:+.2f}GB "
        f"(limit {max_growth_gb:.2f}GB)"
    )
    return growth_gb <= max_growth_gb


if __name__ == "__main__":
    import os
    import traceback

    # Point the Hugging Face cache at the shared project cache. The model-loading
    # modules are imported lazily inside the test functions, so this runs first.
    # NOTE(review): huggingface_hub resolves cache paths when first imported;
    # confirm nothing imported above this line pulls it in earlier.
    os.environ["HF_HOME"] = "/projects/bfqt/users/yurenh2/hf_cache/huggingface"

    print("\n" + "=" * 60)
    print("70B AWQ USER MODEL PILOT TEST")
    print("=" * 60)
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"GPU count: {torch.cuda.device_count()}")

    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")

    print_gpu_memory()

    results = {}

    def _run_test(key, label, test_fn):
        """Run one test and store its result under ``key``; any exception counts as a failure."""
        try:
            results[key] = test_fn()
        except Exception as e:
            print(f"{label} FAILED: {e}")
            traceback.print_exc()
            results[key] = False

    # Tests run in order. Each is attempted only if the previous one passed;
    # otherwise it is reported as skipped and recorded as a failure.
    test_plan = [
        ("70b_load", "TEST 1", test_70b_user_agent),
        ("multiturn", "TEST 2", test_multiturn_with_70b),
        ("memory_stable", "TEST 3", test_memory_after_multiple_sessions),
    ]
    prev_key, prev_label = None, None
    for key, label, test_fn in test_plan:
        if prev_key is None or results.get(prev_key, False):
            _run_test(key, label, test_fn)
        else:
            print(f"\nSkipping {label} ({prev_label} failed)")
            results[key] = False
        prev_key, prev_label = key, label

    # Summary
    print("\n" + "=" * 60)
    print("TEST SUMMARY")
    print("=" * 60)
    for test_name, passed in results.items():
        status = "PASS" if passed else "FAIL"
        print(f"  {test_name}: {status}")

    all_passed = all(results.values())
    print(f"\nOverall: {'ALL TESTS PASSED - Ready for full experiment!' if all_passed else 'SOME TESTS FAILED'}")

    print_gpu_memory()

    sys.exit(0 if all_passed else 1)