path: root/src/personalization/config/settings.py
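"""Typed configuration for locally served models.

Defines pydantic schemas for the ``models`` section of the local-models
YAML file, plus a loader that resolves the file from the
``LOCAL_MODELS_CONFIG`` environment variable and falls back to
``configs/local_models.yaml`` under the current working directory.
"""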
from __future__ import annotations

import os
from pathlib import Path
from typing import Any, Optional

import yaml
from pydantic import BaseModel, Field


class ModelSpec(BaseModel):
    hf_id: str = Field(..., description="Hugging Face repository id")
    local_path: str = Field(..., description="Local directory for model weights")
    dtype: Optional[str] = Field(
        default="bfloat16", description="Preferred torch dtype: bfloat16|float16|float32"
    )
    device_map: Optional[str] = Field(default="auto", description="Device map policy")


class EmbeddingModelsConfig(BaseModel):
    qwen3: Optional[ModelSpec] = None
    nemotron: Optional[ModelSpec] = None


class RerankerModelsConfig(BaseModel):
    qwen3_8b: Optional[ModelSpec] = None


class LocalModelsConfig(BaseModel):
    llm: ModelSpec
    preference_extractor: Any  # flexible for now: accepts a raw dict/mapping or ModelSpec-shaped data
    embedding: Optional[EmbeddingModelsConfig] = None
    reranker: Optional[RerankerModelsConfig] = None


def _resolve_config_path(env_key: str, default_rel: str) -> Path:
    value = os.getenv(env_key)
    if value:
        return Path(value).expanduser().resolve()
    return (Path.cwd() / default_rel).resolve()


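# Illustrative shape of the YAML this loader consumes. Field names follow the
# schemas above; every id and path below is a placeholder, not a value taken
# from the real config:
#
#   models:
#     llm:
#       hf_id: some-org/some-llm
#       local_path: /models/llm
#       dtype: bfloat16
#       device_map: auto
#     preference_extractor:            # free-form mapping, see LocalModelsConfig
#       some_key: some_value
#     embedding:
#       qwen3:
#         hf_id: some-org/some-embedder
#         local_path: /models/embedding/qwen3
#     reranker:
#       qwen3_8b:
#         hf_id: some-org/some-reranker
#         local_path: /models/reranker/qwen3-8b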
def load_local_models_config(path: Optional[str] = None) -> LocalModelsConfig:
    config_path = (
        Path(path).expanduser().resolve()
        if path
        else _resolve_config_path("LOCAL_MODELS_CONFIG", "configs/local_models.yaml")
    )
    with open(config_path, "r", encoding="utf-8") as f:
        raw = yaml.safe_load(f) or {}
    models = raw.get("models", {})
    embedding_cfg = None
    if "embedding" in models:
        emb = models["embedding"] or {}
        # dtype/device_map are not necessary for embedders; ModelSpec still accepts them
        embedding_cfg = EmbeddingModelsConfig(
            qwen3=ModelSpec(**emb["qwen3"]) if "qwen3" in emb else None,
            nemotron=ModelSpec(**emb["nemotron"]) if "nemotron" in emb else None,
        )
    
    reranker_cfg = None
    if "reranker" in models:
        rer = models["reranker"] or {}
        reranker_cfg = RerankerModelsConfig(
            qwen3_8b=ModelSpec(**rer["qwen3_8b"]) if "qwen3_8b" in rer else None
        )

    missing = [key for key in ("llm", "preference_extractor") if key not in models]
    if missing:
        raise KeyError(f"{config_path}: 'models' section is missing required keys: {missing}")

    return LocalModelsConfig(
        llm=ModelSpec(**models["llm"]),
        preference_extractor=models["preference_extractor"],  # pass through the raw dict/value
        embedding=embedding_cfg,
        reranker=reranker_cfg,
    )
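

# Minimal usage sketch (illustrative; assumes the package is importable as
# ``personalization.config.settings`` and that a config file exists at the
# resolved path):
#
#   from personalization.config.settings import load_local_models_config
#
#   cfg = load_local_models_config()           # or pass an explicit path
#   print(cfg.llm.hf_id, cfg.llm.local_path)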