diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
| commit | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch) | |
| tree | 6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /src/personalization/config/settings.py | |
Diffstat (limited to 'src/personalization/config/settings.py')
| -rw-r--r-- | src/personalization/config/settings.py | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/src/personalization/config/settings.py b/src/personalization/config/settings.py new file mode 100644 index 0000000..1bb1bbe --- /dev/null +++ b/src/personalization/config/settings.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import os +from pathlib import Path +from typing import Optional, Any, Dict + +import yaml +from pydantic import BaseModel, Field + + +class ModelSpec(BaseModel): + hf_id: str = Field(..., description="Hugging Face repository id") + local_path: str = Field(..., description="Local directory for model weights") + dtype: Optional[str] = Field( + default="bfloat16", description="Preferred torch dtype: bfloat16|float16|float32" + ) + device_map: Optional[str] = Field(default="auto", description="Device map policy") + + +class EmbeddingModelsConfig(BaseModel): + qwen3: Optional[ModelSpec] = None + nemotron: Optional[ModelSpec] = None + + +class RerankerModelsConfig(BaseModel): + qwen3_8b: Optional[ModelSpec] = None + + +class LocalModelsConfig(BaseModel): + llm: ModelSpec + preference_extractor: Any # Allow flexible dict or ModelSpec for now to support map + embedding: Optional[EmbeddingModelsConfig] = None + reranker: Optional[RerankerModelsConfig] = None + + +def _resolve_config_path(env_key: str, default_rel: str) -> Path: + value = os.getenv(env_key) + if value: + return Path(value).expanduser().resolve() + return (Path.cwd() / default_rel).resolve() + + +def load_local_models_config(path: Optional[str] = None) -> LocalModelsConfig: + config_path = Path(path) if path else _resolve_config_path( + "LOCAL_MODELS_CONFIG", "configs/local_models.yaml" + ) + with open(config_path, "r", encoding="utf-8") as f: + raw = yaml.safe_load(f) or {} + models = raw.get("models", {}) + embedding_cfg = None + if "embedding" in models: + emb = models["embedding"] or {} + # dtype/device_map are not necessary for embedders; ModelSpec still accepts them + embedding_cfg = EmbeddingModelsConfig( + qwen3=ModelSpec(**emb["qwen3"]) if "qwen3" in emb else None, + nemotron=ModelSpec(**emb["nemotron"]) if "nemotron" in emb else None, + ) + + reranker_cfg = None + if "reranker" in models: + rer = models["reranker"] or {} + reranker_cfg = RerankerModelsConfig( + qwen3_8b=ModelSpec(**rer["qwen3_8b"]) if "qwen3_8b" in rer else None + ) + + return LocalModelsConfig( + llm=ModelSpec(**models["llm"]), + preference_extractor=models["preference_extractor"], # Pass raw dict/value + embedding=embedding_cfg, + reranker=reranker_cfg, + ) + + |
