summaryrefslogtreecommitdiff
path: root/scripts/smoke_llms.py
blob: 109020a4c76d131ca2a0c8fcd475d8a83964c118 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from __future__ import annotations

from pathlib import Path
import sys
import json
import os

# Repository root: this script lives in <root>/scripts/, so take the
# grandparent of the resolved script path.
ROOT = Path(__file__).resolve().parents[1]
# Make the in-repo ``personalization`` package importable without installing
# it. This has to run before the project imports that follow.
_SRC_DIR = ROOT / "src"
sys.path.insert(0, os.fspath(_SRC_DIR))

from personalization.config.settings import load_local_models_config
from personalization.models.llm.qwen_instruct import QwenInstruct
from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor


def ensure_models_present() -> bool:
    """Check that every required local model directory exists and is non-empty.

    Loads the local models config and inspects the directories configured for
    the LLM (``cfg.llm``) and the preference extractor
    (``cfg.preference_extractor``), resolved against the repository root.
    Every entry is checked, not just the first failing one, so the user sees
    all problems at once.

    Returns:
        True if all configured model directories exist and contain at least
        one entry; False otherwise. A ``[missing]`` line is printed for each
        problem found.
    """
    cfg = load_local_models_config()
    hint = "Run: python scripts/pull_models.py --target all"
    ok = True
    for spec in (cfg.llm, cfg.preference_extractor):
        path = ROOT / spec.local_path
        if not path.is_dir():
            # Covers both "does not exist" and "exists but is a regular file".
            # Calling iterdir() on a file raises NotADirectoryError, so check first.
            print(f"[missing] {path} does not exist or is not a directory. {hint}")
            ok = False
        elif not any(path.iterdir()):
            print(f"[missing] {path} is empty. {hint}")
            ok = False
    return ok


# (name, query) pairs fed to the preference extractor. Each query ends with an
# explicit formatting/style request the extractor is expected to pick up.
_SCENARIOS = (
    (
        "math_latex",
        "Consider the sequence defined by a_1 = 1 and a_{n+1} = a_n + 1/n for n >= 1. "
        "(1) Prove that a_n diverges. (2) Derive an asymptotic expression for a_n in terms of the harmonic numbers H_n. "
        "(3) Compute the limit of (a_n - ln n) as n -> infinity. Please use LaTeX for the output.",
    ),
    (
        "code_python311",
        "I have a performance bottleneck in my Python code that processes large CSV files. It reads rows, aggregates stats, and writes summaries. "
        "Explain how to optimize I/O and memory, discuss multiprocessing vs async. When you show code, please use Python 3.11 syntax and include type hints in the snippets.",
    ),
    (
        "data_json_only",
        "Given a dataset of user events with timestamps, device types, and regions, outline steps to compute DAU, WAU, and retention. "
        "List pitfalls and how to handle missing data. Return your final answer as JSON only.",
    ),
    (
        "writing_concise_no_emoji",
        "Explain the difference between supervised and reinforcement learning with practical examples and cautions. "
        "Keep answers concise and avoid emojis.",
    ),
)


def main() -> None:
    """Smoke-test the local LLM and the rule-based preference extractor.

    If the model directories are missing or empty, exits with status 1 so
    shell/CI callers see a failure rather than a silent success. Otherwise it
    generates one short reply with the LLM, then runs the extractor over each
    scenario in ``_SCENARIOS`` and prints the result as indented JSON.
    """
    if not ensure_models_present():
        # Non-zero exit: the smoke check did not run and must not look like a pass.
        sys.exit(1)

    cfg = load_local_models_config()
    # Set before the models are constructed. NOTE(review): presumably enables
    # debug output in the personalization package; confirm there.
    os.environ["PREF_DEBUG"] = "1"
    llm = QwenInstruct.from_config(cfg)
    extractor = QwenRuleExtractor.from_config(cfg)

    print("[llm] generating...")
    out = llm.generate("Say hello in one short sentence.", max_new_tokens=32, temperature=0.2)
    print(out)

    print("[extractor] extracting...")
    for name, query in _SCENARIOS:
        print(f"\n[scenario] {name}")
        prefs = extractor.extract_preferences(query)
        print(json.dumps(prefs, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()