summaryrefslogtreecommitdiff
path: root/scripts/smoke_extractor_llm.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/smoke_extractor_llm.py')
-rw-r--r--scripts/smoke_extractor_llm.py54
1 files changed, 54 insertions, 0 deletions
diff --git a/scripts/smoke_extractor_llm.py b/scripts/smoke_extractor_llm.py
new file mode 100644
index 0000000..b16d0e2
--- /dev/null
+++ b/scripts/smoke_extractor_llm.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+"""
+Smoke test for PreferenceExtractorLLM (Qwen3-0.6B).
+Requires 'saves/qwen3-0.6b-full-sft-h200/checkpoint-4358' to be present.
+"""
+
+import sys
+import os
+import json
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.registry import get_preference_extractor
+from personalization.retrieval.preference_store.schemas import ChatTurn
+
+def main():
+ print("Initializing Preference Extractor (qwen3_0_6b_sft)...")
+ try:
+ extractor = get_preference_extractor("qwen3_0_6b_sft")
+ except Exception as e:
+ print(f"Failed to load extractor: {e}")
+ print("Please check if the checkpoint exists at saves/qwen3-0.6b-full-sft-h200/checkpoint-4358")
+ print("and local_models.yaml is configured correctly.")
+ sys.exit(1)
+
+ print("Extractor loaded successfully.")
+
+ # Construct dummy conversation
+ turns = [
+ ChatTurn(user_id="u1", session_id="s1", turn_id=0, role="user", text="Hi, I am learning Python. Please always use Python 3.11 in your code examples."),
+ ChatTurn(user_id="u1", session_id="s1", turn_id=1, role="assistant", text="Hello! Python is a great language. How can I help?"),
+ ChatTurn(user_id="u1", session_id="s1", turn_id=2, role="user", text="Please explain lists. And btw, always use snake_case for variables in your code examples."),
+ ]
+
+ print("\n--- Input Turns ---")
+ for t in turns:
+ print(f"[{t.role}]: {t.text}")
+
+ print("\n--- Extracting ---")
+ prefs = extractor.extract_turn(turns)
+
+ print("\n--- Output PreferenceList ---")
+ print(prefs.model_dump_json(indent=2))
+
+ # Validation
+ if prefs.preferences:
+ print("\nSUCCESS: Extracted preferences found.")
+ else:
+ print("\nWARNING: No preferences extracted. (Model might need warming up or prompt adjustment)")
+
+if __name__ == "__main__":
+ main()
+