from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, Optional
import torch
import yaml
from personalization.config import settings
# Avoid circular imports by NOT importing extractors here at top level
# from personalization.models.preference_extractor.base import PreferenceExtractorBase
# from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
# from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor
# from personalization.models.preference_extractor.llm_extractor import PreferenceExtractorLLM

_DTYPE_MAP: Dict[str, torch.dtype] = {
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
    "float32": torch.float32,
}


def choose_dtype(preferred: Optional[str] = None) -> torch.dtype:
    """Resolve a dtype name to a torch.dtype, with a safe host-dependent default."""
    if preferred and preferred.lower() in _DTYPE_MAP:
        dt = _DTYPE_MAP[preferred.lower()]
    else:
        dt = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    # Downgrade bfloat16 to float32 when no CUDA device is available.
    if dt is torch.bfloat16 and not torch.cuda.is_available():
        return torch.float32
    return dt
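# Resolution behavior, for reference (a sketch; actual results depend on the host):
#   choose_dtype("float16")  -> torch.float16 (kept even on CPU-only hosts)
#   choose_dtype("bfloat16") -> torch.bfloat16 on CUDA hosts, torch.float32 otherwise
#   choose_dtype(None)       -> torch.bfloat16 on CUDA hosts, torch.float32 otherwise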


def choose_device_map(spec: Optional[str] = "auto") -> Any:
    """Return the requested device map, defaulting to "auto" when unset."""
    return spec or "auto"


def ensure_local_path(path_str: str) -> str:
    """Create the directory if it does not already exist and return it as a string."""
    path = Path(path_str)
    if not path.exists():
        path.mkdir(parents=True, exist_ok=True)
    return str(path)


# --- Chat Model Factory ---
def get_chat_model(name: str, device_override: Optional[str] = None):
    """
    Get a chat model by name.

    Args:
        name: Model name (e.g., "qwen_1_5b", "llama_8b").
        device_override: Optional device override (e.g., "cuda:2"). If None,
            the device from the config entry is used.
    """
    from personalization.models.llm.base import ChatModel
    from personalization.models.llm.qwen_instruct import QwenInstruct
    from personalization.models.llm.llama_instruct import LlamaChatModel

    cfg = settings.load_local_models_config()
    # Also load the raw YAML to support the multi-backend model map.
    with open("configs/local_models.yaml", "r") as f:
        raw_cfg = yaml.safe_load(f)
    models = raw_cfg.get("models", {}).get("llm", {})

    # New-style config: models.llm maps each model name to a spec with a "backend" key.
    if isinstance(models, dict) and "backend" in models.get(name, {}):
        spec = models[name]
        backend = spec.get("backend", "qwen")
        path = spec["path"]
        device = device_override or spec.get("device", "cuda")  # Use the override if provided
        dtype = spec.get("dtype", "bfloat16")
        max_len = spec.get("max_context_length", 4096)
        if backend == "qwen":
            return QwenInstruct(
                model_path=path,
                device=device,
                dtype=choose_dtype(dtype),  # Convert the string to a torch.dtype
                max_context_length=max_len,
            )
        elif backend == "llama":
            return LlamaChatModel(
                model_path=path,
                device=device,
                dtype=choose_dtype(dtype),  # Convert the string to a torch.dtype
                max_context_length=max_len,
            )

    # Fall back to the legacy single-model config.
    return QwenInstruct.from_config(cfg)
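# The multi-backend branch above expects a configs/local_models.yaml entry roughly like
# the following (a sketch inferred from the keys read above, not an authoritative schema;
# the model name and path are placeholders):
#
#   models:
#     llm:
#       qwen_1_5b:
#         backend: qwen
#         path: /models/qwen2.5-1.5b-instruct
#         device: cuda:0
#         dtype: bfloat16
#         max_context_length: 4096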


def get_preference_extractor(name: Optional[str] = None):
    """Get a preference extractor by name ("qwen3_0_6b_sft", "gpt4o", or "rule")."""
    # Deferred imports to break the circular dependency with the extractor modules.
    from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
    from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor
    from personalization.models.preference_extractor.llm_extractor import PreferenceExtractorLLM

    cfg = settings.load_local_models_config()
    pref_cfg = cfg.preference_extractor

    # Default to the SFT extractor when it is configured, otherwise the rule-based one.
    if name is None:
        if isinstance(pref_cfg, dict) and "qwen3_0_6b_sft" in pref_cfg:
            name = "qwen3_0_6b_sft"
        else:
            name = "rule"

    if isinstance(pref_cfg, dict) and name in pref_cfg:
        spec = pref_cfg[name]
        if name == "qwen3_0_6b_sft":
            # Use QwenRuleExtractor, which has been updated for the SFT end-to-end logic.
            return QwenRuleExtractor(
                model_path=spec["path"],
                device_map=spec.get("device", "auto"),
                dtype=choose_dtype(spec.get("dtype", "bfloat16")),
            )

    # Add 'default' handling here if it is ever mapped to the rule/gpt extractors.
    if name == "default":
        pass

    if name == "gpt4o":
        return GPT4OExtractor.from_config(cfg)
    elif name == "rule":
        if isinstance(pref_cfg, dict) and "default" in pref_cfg:
            # Manually construct to bypass a ModelSpec mismatch if needed.
            spec_dict = pref_cfg["default"]
            return QwenRuleExtractor(
                model_path=spec_dict["local_path"],
                dtype=choose_dtype(spec_dict.get("dtype")),
                device_map=choose_device_map(spec_dict.get("device_map")),
            )
        return QwenRuleExtractor.from_config(cfg)

    raise ValueError(f"Could not load preference extractor: {name}")
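# Minimal usage sketch (assuming this module is importable as
# personalization.models.factory and the configs above are populated):
#
#   from personalization.models.factory import get_chat_model, get_preference_extractor
#   chat = get_chat_model("qwen_1_5b", device_override="cuda:1")
#   extractor = get_preference_extractor()        # prefers "qwen3_0_6b_sft" if configured
#   rule_extractor = get_preference_extractor("rule")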