summaryrefslogtreecommitdiff
path: root/scripts/smoke_extractor_llm.py
blob: b16d0e2452cf3a65e19c62a75e894a091eedd35f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#!/usr/bin/env python3
"""
Smoke test for PreferenceExtractorLLM (Qwen3-0.6B).
Requires 'saves/qwen3-0.6b-full-sft-h200/checkpoint-4358' to be present.
"""

import sys
import os
import json

# Append the sibling "../src" directory (resolved relative to this file) to
# sys.path. This must run before the project imports below.
# NOTE(review): presumably the `personalization` package lives under src/ and
# is not installed into the environment -- confirm.
sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))

from personalization.config.registry import get_preference_extractor
from personalization.retrieval.preference_store.schemas import ChatTurn

def main():
    """Run the preference extractor once on a tiny hard-coded conversation.

    Steps:
      1. Build the extractor registered as "qwen3_0_6b_sft". If construction
         raises, print a hint about the checkpoint/config and exit with
         status 1.
      2. Assemble three ChatTurn objects (user / assistant / user) and echo
         them to stdout.
      3. Call ``extractor.extract_turn`` on the turns and print the result
         via ``model_dump_json(indent=2)``.
      4. Print SUCCESS when the result's ``preferences`` is truthy, otherwise
         print a WARNING. An empty result does not change the exit status.
    """
    extractor_key = "qwen3_0_6b_sft"

    print("Initializing Preference Extractor (qwen3_0_6b_sft)...")
    try:
        extractor = get_preference_extractor(extractor_key)
    except Exception as e:
        # Top-level boundary of the script: report the failure and bail out
        # with a non-zero status instead of showing a traceback.
        print(f"Failed to load extractor: {e}")
        print("Please check if the checkpoint exists at saves/qwen3-0.6b-full-sft-h200/checkpoint-4358")
        print("and local_models.yaml is configured correctly.")
        sys.exit(1)
    else:
        print("Extractor loaded successfully.")

    # Dummy conversation as (role, text) pairs; turn ids follow list order.
    dialogue = [
        ("user", "Hi, I am learning Python. Please always use Python 3.11 in your code examples."),
        ("assistant", "Hello! Python is a great language. How can I help?"),
        ("user", "Please explain lists. And btw, always use snake_case for variables in your code examples."),
    ]
    turns = [
        ChatTurn(user_id="u1", session_id="s1", turn_id=idx, role=role, text=text)
        for idx, (role, text) in enumerate(dialogue)
    ]

    print("\n--- Input Turns ---")
    for turn in turns:
        print(f"[{turn.role}]: {turn.text}")

    print("\n--- Extracting ---")
    prefs = extractor.extract_turn(turns)

    print("\n--- Output PreferenceList ---")
    print(prefs.model_dump_json(indent=2))

    # Validation: only report the outcome; an empty result is not fatal.
    if not prefs.preferences:
        print("\nWARNING: No preferences extracted. (Model might need warming up or prompt adjustment)")
        return
    print("\nSUCCESS: Extracted preferences found.")

# Run the smoke test only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()