author    YurenHao0426 <blackhao0426@gmail.com>  2025-12-17 04:29:37 -0600
committer YurenHao0426 <blackhao0426@gmail.com>  2025-12-17 04:29:37 -0600
commit    e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3
tree      6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /src/personalization/config
Initial commit (clean history)  (HEAD, main)
Diffstat (limited to 'src/personalization/config')
-rw-r--r--  src/personalization/config/__init__.py    0
-rw-r--r--  src/personalization/config/registry.py   131
-rw-r--r--  src/personalization/config/settings.py    73
3 files changed, 204 insertions(+), 0 deletions(-)
diff --git a/src/personalization/config/__init__.py b/src/personalization/config/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/config/__init__.py
diff --git a/src/personalization/config/registry.py b/src/personalization/config/registry.py
new file mode 100644
index 0000000..d825ad3
--- /dev/null
+++ b/src/personalization/config/registry.py
@@ -0,0 +1,131 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Dict, Optional
+import torch
+import yaml
+
+from personalization.config import settings
+
+# Avoid circular imports by NOT importing extractors here at top level
+# from personalization.models.preference_extractor.base import PreferenceExtractorBase
+# from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
+# from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor
+# from personalization.models.preference_extractor.llm_extractor import PreferenceExtractorLLM
+
+_DTYPE_MAP: Dict[str, torch.dtype] = {
+ "bfloat16": torch.bfloat16,
+ "float16": torch.float16,
+ "float32": torch.float32,
+}
+
+def choose_dtype(preferred: Optional[str] = None) -> torch.dtype:
+ if preferred and preferred.lower() in _DTYPE_MAP:
+ dt = _DTYPE_MAP[preferred.lower()]
+ else:
+ dt = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+ if dt is torch.bfloat16 and not torch.cuda.is_available():
+ return torch.float32
+ return dt
+
+def choose_device_map(spec: Optional[str] = "auto") -> Any:
+ return spec or "auto"
+
+def ensure_local_path(path_str: str) -> str:
+ path = Path(path_str)
+ if not path.exists():
+ path.mkdir(parents=True, exist_ok=True)
+ return str(path)
+
+# --- Chat Model Factory ---
+def get_chat_model(name: str, device_override: Optional[str] = None):
+ """
+ Get a chat model by name.
+
+ Args:
+ name: Model name (e.g., "qwen_1_5b", "llama_8b")
+ device_override: Optional device override (e.g., "cuda:2"). If None, uses config default.
+ """
+    # Deferred imports to avoid a circular dependency at module import time.
+    from personalization.models.llm.qwen_instruct import QwenInstruct
+    from personalization.models.llm.llama_instruct import LlamaChatModel
+
+ cfg = settings.load_local_models_config()
+
+    # Also load the raw YAML: the typed LocalModelsConfig does not model the
+    # newer multi-backend "llm" map, so inspect it directly. Note that this
+    # hardcoded path ignores the LOCAL_MODELS_CONFIG override honoured by
+    # settings.load_local_models_config().
+    with open("configs/local_models.yaml", "r", encoding="utf-8") as f:
+        raw_cfg = yaml.safe_load(f) or {}
+
+ models = raw_cfg.get("models", {}).get("llm", {})
+
+    # If models['llm'] is a map of named specs (new multi-backend style)
+    if isinstance(models, dict) and "backend" in (models.get(name) or {}):
+ spec = models[name]
+ backend = spec.get("backend", "qwen")
+ path = spec["path"]
+ device = device_override or spec.get("device", "cuda") # Use override if provided
+ dtype = spec.get("dtype", "bfloat16")
+ max_len = spec.get("max_context_length", 4096)
+
+ if backend == "qwen":
+ return QwenInstruct(
+ model_path=path,
+ device=device,
+ dtype=choose_dtype(dtype), # Converts string to torch.dtype
+ max_context_length=max_len
+ )
+ elif backend == "llama":
+ return LlamaChatModel(
+ model_path=path,
+ device=device,
+ dtype=choose_dtype(dtype), # Converts string to torch.dtype
+ max_context_length=max_len
+ )
+
+ # Fallback to legacy single config
+ return QwenInstruct.from_config(cfg)
+
+def get_preference_extractor(name: Optional[str] = None):
+    # Deferred imports to break the circular dependency with the extractor package.
+    from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
+    from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor
+
+ cfg = settings.load_local_models_config()
+ pref_cfg = cfg.preference_extractor
+
+ if name is None:
+ if isinstance(pref_cfg, dict) and "qwen3_0_6b_sft" in pref_cfg:
+ name = "qwen3_0_6b_sft"
+ else:
+ name = "rule"
+
+ if isinstance(pref_cfg, dict) and name in pref_cfg:
+ spec = pref_cfg[name]
+ if name == "qwen3_0_6b_sft":
+ # Use QwenRuleExtractor which we have updated for SFT End-to-End logic
+ return QwenRuleExtractor(
+ model_path=spec["path"],
+ device_map=spec.get("device", "auto"),
+ dtype=choose_dtype(spec.get("dtype", "bfloat16")),
+ )
+        # A "default" entry simply falls through to the rule/gpt4o branches below.
+
+ if name == "gpt4o":
+ return GPT4OExtractor.from_config(cfg)
+ elif name == "rule":
+ if isinstance(pref_cfg, dict):
+ if "default" in pref_cfg:
+ # Manually construct to bypass ModelSpec mismatch if needed
+ spec_dict = pref_cfg["default"]
+ return QwenRuleExtractor(
+ model_path=spec_dict["local_path"],
+ dtype=choose_dtype(spec_dict.get("dtype")),
+ device_map=choose_device_map(spec_dict.get("device_map"))
+ )
+ else:
+ return QwenRuleExtractor.from_config(cfg)
+
+ raise ValueError(f"Could not load preference extractor: {name}")
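
For orientation, a minimal usage sketch of the two factories above, assuming configs/local_models.yaml defines the referenced entries; the model names ("qwen_1_5b", "qwen3_0_6b_sft") and the "cuda:2" override mirror the docstring examples and are illustrative only, not part of the commit's own code.

# Usage sketch (assumes the named entries exist in configs/local_models.yaml).
from personalization.config.registry import get_chat_model, get_preference_extractor

# Multi-backend style: the "backend" field of models.llm.qwen_1_5b selects
# QwenInstruct or LlamaChatModel; device_override takes precedence over the
# configured device.
chat_model = get_chat_model("qwen_1_5b", device_override="cuda:2")

# Extractor selection: prefers the SFT extractor when configured, otherwise
# falls back to the rule-based extractor; "gpt4o" uses GPT4OExtractor.from_config.
extractor = get_preference_extractor()
gpt_extractor = get_preference_extractor("gpt4o")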
diff --git a/src/personalization/config/settings.py b/src/personalization/config/settings.py
new file mode 100644
index 0000000..1bb1bbe
--- /dev/null
+++ b/src/personalization/config/settings.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import Optional, Any
+
+import yaml
+from pydantic import BaseModel, Field
+
+
+class ModelSpec(BaseModel):
+ hf_id: str = Field(..., description="Hugging Face repository id")
+ local_path: str = Field(..., description="Local directory for model weights")
+ dtype: Optional[str] = Field(
+ default="bfloat16", description="Preferred torch dtype: bfloat16|float16|float32"
+ )
+ device_map: Optional[str] = Field(default="auto", description="Device map policy")
+
+
+class EmbeddingModelsConfig(BaseModel):
+ qwen3: Optional[ModelSpec] = None
+ nemotron: Optional[ModelSpec] = None
+
+
+class RerankerModelsConfig(BaseModel):
+ qwen3_8b: Optional[ModelSpec] = None
+
+
+class LocalModelsConfig(BaseModel):
+ llm: ModelSpec
+ preference_extractor: Any # Allow flexible dict or ModelSpec for now to support map
+ embedding: Optional[EmbeddingModelsConfig] = None
+ reranker: Optional[RerankerModelsConfig] = None
+
+
+def _resolve_config_path(env_key: str, default_rel: str) -> Path:
+ value = os.getenv(env_key)
+ if value:
+ return Path(value).expanduser().resolve()
+ return (Path.cwd() / default_rel).resolve()
+
+
+def load_local_models_config(path: Optional[str] = None) -> LocalModelsConfig:
+ config_path = Path(path) if path else _resolve_config_path(
+ "LOCAL_MODELS_CONFIG", "configs/local_models.yaml"
+ )
+ with open(config_path, "r", encoding="utf-8") as f:
+ raw = yaml.safe_load(f) or {}
+ models = raw.get("models", {})
+ embedding_cfg = None
+ if "embedding" in models:
+ emb = models["embedding"] or {}
+ # dtype/device_map are not necessary for embedders; ModelSpec still accepts them
+ embedding_cfg = EmbeddingModelsConfig(
+ qwen3=ModelSpec(**emb["qwen3"]) if "qwen3" in emb else None,
+ nemotron=ModelSpec(**emb["nemotron"]) if "nemotron" in emb else None,
+ )
+
+ reranker_cfg = None
+ if "reranker" in models:
+ rer = models["reranker"] or {}
+ reranker_cfg = RerankerModelsConfig(
+ qwen3_8b=ModelSpec(**rer["qwen3_8b"]) if "qwen3_8b" in rer else None
+ )
+
+ return LocalModelsConfig(
+ llm=ModelSpec(**models["llm"]),
+ preference_extractor=models["preference_extractor"], # Pass raw dict/value
+ embedding=embedding_cfg,
+ reranker=reranker_cfg,
+ )
+
+
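
Finally, a sketch of the YAML shape that load_local_models_config() accepts, written as a small self-contained script. The hf_id and local_path values are placeholders; the llm/embedding field names mirror ModelSpec, and preference_extractor is typed as Any, so its nested keys (path, device, dtype) only need to match what registry.get_preference_extractor() reads.

# Sketch of the expected configs/local_models.yaml shape (placeholder paths and repo ids).
import tempfile

import yaml

from personalization.config.settings import load_local_models_config

example = {
    "models": {
        "llm": {
            "hf_id": "Qwen/Qwen2.5-1.5B-Instruct",       # placeholder repo id
            "local_path": "models/qwen2.5-1.5b-instruct",
            "dtype": "bfloat16",
            "device_map": "auto",
        },
        # Typed as Any in LocalModelsConfig, so this mapping is passed through
        # untouched and interpreted by registry.get_preference_extractor().
        "preference_extractor": {
            "qwen3_0_6b_sft": {
                "path": "models/qwen3-0.6b-sft",
                "device": "cuda:0",
                "dtype": "bfloat16",
            }
        },
        "embedding": {
            "qwen3": {
                "hf_id": "Qwen/Qwen3-Embedding-0.6B",     # placeholder repo id
                "local_path": "models/qwen3-embedding-0.6b",
            }
        },
    }
}

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    yaml.safe_dump(example, f)
    config_path = f.name

cfg = load_local_models_config(config_path)
print(cfg.llm.local_path)              # models/qwen2.5-1.5b-instruct
print(list(cfg.preference_extractor))  # ['qwen3_0_6b_sft']
print(cfg.embedding.qwen3.hf_id)       # Qwen/Qwen3-Embedding-0.6B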