path: root/src/personalization/config/registry.py
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, Optional
import torch
import yaml

from personalization.config import settings

# Avoid circular imports by NOT importing extractors here at top level
# from personalization.models.preference_extractor.base import PreferenceExtractorBase
# from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
# from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor
# from personalization.models.preference_extractor.llm_extractor import PreferenceExtractorLLM

_DTYPE_MAP: Dict[str, torch.dtype] = {
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
    "float32": torch.float32,
}

def choose_dtype(preferred: Optional[str] = None) -> torch.dtype:
    """Resolve a dtype name to a torch.dtype, falling back to float32 on CPU-only hosts."""
    if preferred and preferred.lower() in _DTYPE_MAP:
        dt = _DTYPE_MAP[preferred.lower()]
    else:
        dt = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    # bfloat16 is only used when a GPU is available; otherwise downgrade to float32.
    if dt is torch.bfloat16 and not torch.cuda.is_available():
        return torch.float32
    return dt
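
# Illustrative behavior (a sketch, not an exhaustive spec): choose_dtype("bfloat16")
# yields torch.bfloat16 on a CUDA host and torch.float32 on a CPU-only machine;
# unknown or missing names fall back to the same CUDA-dependent default.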

def choose_device_map(spec: Optional[str] = "auto") -> Any:
    """Return the requested device map, defaulting to "auto" when unset."""
    return spec or "auto"

def ensure_local_path(path_str: str) -> str:
    """Create the directory (and parents) if it does not exist, then return it as a string."""
    path = Path(path_str)
    if not path.exists():
        path.mkdir(parents=True, exist_ok=True)
    return str(path)

# --- Chat Model Factory ---
def get_chat_model(name: str, device_override: Optional[str] = None):
    """
    Get a chat model by name.
    
    Args:
        name: Model name (e.g., "qwen_1_5b", "llama_8b")
        device_override: Optional device override (e.g., "cuda:2"). If None, uses config default.
    """
    # Deferred imports to avoid a circular dependency at module load time
    from personalization.models.llm.qwen_instruct import QwenInstruct
    from personalization.models.llm.llama_instruct import LlamaChatModel
    
    cfg = settings.load_local_models_config()
    
    # Try to load raw config to support multi-backend map
    with open("configs/local_models.yaml", "r") as f:
        raw_cfg = yaml.safe_load(f)
    
    models = raw_cfg.get("models", {}).get("llm", {})
    
    # If models['llm'] is a dict of configs (new style)
    if isinstance(models, dict) and "backend" in models.get(name, {}):
        spec = models[name]
        backend = spec.get("backend", "qwen")
        path = spec["path"]
        device = device_override or spec.get("device", "cuda")  # Use override if provided
        dtype = spec.get("dtype", "bfloat16")
        max_len = spec.get("max_context_length", 4096)
        
        if backend == "qwen":
            return QwenInstruct(
                model_path=path,
                device=device,
                dtype=choose_dtype(dtype), # Converts string to torch.dtype
                max_context_length=max_len
            )
        elif backend == "llama":
            return LlamaChatModel(
                model_path=path,
                device=device,
                dtype=choose_dtype(dtype), # Converts string to torch.dtype
                max_context_length=max_len
            )
    
    # Fallback to the legacy single-model config (also reached for unknown names/backends)
    return QwenInstruct.from_config(cfg)
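
# Illustrative layout of the multi-backend section of configs/local_models.yaml,
# inferred from the keys read above; the paths and model names are placeholders,
# not the project's actual values:
#
#   models:
#     llm:
#       qwen_1_5b:
#         backend: qwen
#         path: /path/to/Qwen2.5-1.5B-Instruct
#         device: cuda:0
#         dtype: bfloat16
#         max_context_length: 4096
#       llama_8b:
#         backend: llama
#         path: /path/to/Llama-3.1-8B-Instruct
#         device: cuda:1
#         dtype: bfloat16
#         max_context_length: 8192
#
# Typical call (hypothetical): get_chat_model("qwen_1_5b", device_override="cuda:2")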

def get_preference_extractor(name: Optional[str] = None):
    """
    Get a preference extractor by name (e.g., "qwen3_0_6b_sft", "rule", "gpt4o").

    If name is None, the SFT extractor is preferred when configured; otherwise
    the rule-based extractor is used.
    """
    # Deferred imports to break the circular dependency with the extractor modules
    from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
    from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor
    
    cfg = settings.load_local_models_config()
    pref_cfg = cfg.preference_extractor
    
    if name is None:
        if isinstance(pref_cfg, dict) and "qwen3_0_6b_sft" in pref_cfg:
            name = "qwen3_0_6b_sft"
        else:
            name = "rule"

    if isinstance(pref_cfg, dict) and name in pref_cfg:
        spec = pref_cfg[name]
        if name == "qwen3_0_6b_sft":
            # QwenRuleExtractor has been updated to handle the SFT end-to-end extraction logic
            return QwenRuleExtractor(
                model_path=spec["path"],
                device_map=spec.get("device", "auto"),
                dtype=choose_dtype(spec.get("dtype", "bfloat16")),
            )
        # TODO: add explicit 'default' handling here if it should map to the rule/GPT extractor
        if name == "default":
            pass

    if name == "gpt4o":
        return GPT4OExtractor.from_config(cfg)
    elif name == "rule":
        if isinstance(pref_cfg, dict):
             if "default" in pref_cfg:
                 # Manually construct to bypass ModelSpec mismatch if needed
                 spec_dict = pref_cfg["default"]
                 return QwenRuleExtractor(
                     model_path=spec_dict["local_path"],
                     dtype=choose_dtype(spec_dict.get("dtype")),
                     device_map=choose_device_map(spec_dict.get("device_map"))
                 )
        else:
            return QwenRuleExtractor.from_config(cfg)

    raise ValueError(f"Could not load preference extractor: {name}")
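
# Illustrative usage (the extractor names come from the branches above; whether each
# one is present depends on the local config):
#
#   extractor = get_preference_extractor()          # prefers "qwen3_0_6b_sft" if configured
#   extractor = get_preference_extractor("rule")    # rule-based QwenRuleExtractor
#   extractor = get_preference_extractor("gpt4o")   # GPT4OExtractor.from_config(cfg)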