diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
| commit | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch) | |
| tree | 6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/smoke_extractor_llm.py | |
Diffstat (limited to 'scripts/smoke_extractor_llm.py')
| -rw-r--r-- | scripts/smoke_extractor_llm.py | 54 |
1 files changed, 54 insertions, 0 deletions
#!/usr/bin/env python3
"""
Smoke test for PreferenceExtractorLLM (Qwen3-0.6B).

Loads the "qwen3_0_6b_sft" extractor from the registry, runs it over a short
hard-coded conversation and prints the result.
Requires 'saves/qwen3-0.6b-full-sft-h200/checkpoint-4358' to be present.
"""

import sys
import os
import json

# Make the sibling "src" directory importable before pulling in project modules.
sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))

from personalization.config.registry import get_preference_extractor
from personalization.retrieval.preference_store.schemas import ChatTurn


def main():
    """Run the extractor once on a canned three-turn dialogue and report.

    Exits the process with status 1 if the extractor cannot be constructed.
    Otherwise prints the input turns, the extractor output serialized with
    ``model_dump_json(indent=2)``, and a SUCCESS or WARNING line depending on
    whether ``preferences`` on the result is truthy.
    """
    print("Initializing Preference Extractor (qwen3_0_6b_sft)...")
    try:
        extractor = get_preference_extractor("qwen3_0_6b_sft")
    except Exception as e:
        # Top-level boundary of a manual smoke script: report and bail out.
        failure_lines = (
            f"Failed to load extractor: {e}",
            "Please check if the checkpoint exists at saves/qwen3-0.6b-full-sft-h200/checkpoint-4358",
            "and local_models.yaml is configured correctly.",
        )
        for line in failure_lines:
            print(line)
        sys.exit(1)

    print("Extractor loaded successfully.")

    # Scripted (role, text) pairs; turn ids are assigned in order from 0.
    dialogue = (
        ("user", "Hi, I am learning Python. Please always use Python 3.11 in your code examples."),
        ("assistant", "Hello! Python is a great language. How can I help?"),
        ("user", "Please explain lists. And btw, always use snake_case for variables in your code examples."),
    )
    turns = [
        ChatTurn(user_id="u1", session_id="s1", turn_id=index, role=speaker, text=utterance)
        for index, (speaker, utterance) in enumerate(dialogue)
    ]

    print("\n--- Input Turns ---")
    for turn in turns:
        print(f"[{turn.role}]: {turn.text}")

    print("\n--- Extracting ---")
    result = extractor.extract_turn(turns)

    print("\n--- Output PreferenceList ---")
    print(result.model_dump_json(indent=2))

    # An empty result is reported as a warning only; the exit status is unchanged.
    verdict = (
        "\nSUCCESS: Extracted preferences found."
        if result.preferences
        else "\nWARNING: No preferences extracted. (Model might need warming up or prompt adjustment)"
    )
    print(verdict)


if __name__ == "__main__":
    main()
