summaryrefslogtreecommitdiff
path: root/src/personalization/feedback/local_llm_reward.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/personalization/feedback/local_llm_reward.py')
-rw-r--r--src/personalization/feedback/local_llm_reward.py342
1 files changed, 342 insertions, 0 deletions
diff --git a/src/personalization/feedback/local_llm_reward.py b/src/personalization/feedback/local_llm_reward.py
new file mode 100644
index 0000000..9837ff0
--- /dev/null
+++ b/src/personalization/feedback/local_llm_reward.py
@@ -0,0 +1,342 @@
+"""
+Local LLM reward model using vLLM server for batch inference.
+
+Drop-in replacement for LLMRewardClient when you want to use a local model
+(e.g., Llama-3.1-8B-Instruct) instead of OpenAI API.
+
+Uses BatchVLLMClient for efficient concurrent requests - vLLM's continuous
+batching will process them together for high throughput.
+"""
+from __future__ import annotations
+
+import asyncio
+import hashlib
+import json
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import aiohttp
+
+from personalization.feedback.schemas import TurnSample
+from personalization.feedback.llm_reward import (
+ REWARD_MAP,
+ VALID_LABELS,
+ JUDGE_SYSTEM_PROMPT,
+ JUDGE_USER_TEMPLATE,
+ RewardResult,
+)
+
+
@dataclass
class LocalLLMRewardConfig:
    """Configuration for the local (vLLM-backed) LLM reward model.

    Defaults target a vLLM server exposing the OpenAI-compatible API on
    localhost; tune `max_concurrent` to exploit vLLM's continuous batching.
    """
    vllm_url: str = "http://localhost:8005/v1"  # base URL of the vLLM OpenAI-compatible server
    model_name: Optional[str] = None  # served model id; auto-discovered via GET /models if None
    max_tokens: int = 256  # completion budget for the judge's JSON answer
    temperature: float = 0.1  # near-deterministic sampling for stable judgments
    max_concurrent: int = 100  # in-flight request cap; high so vLLM can batch aggressively
    timeout: Optional[int] = 60  # per-request timeout in seconds; None disables the timeout
    confidence_threshold: float = 0.6  # tau_c: judgments below this confidence produce no update
    enable_cache: bool = True  # memoize results by hash of (q_t, a_t, q_{t+1})
+
class LocalLLMRewardClient:
    """
    Local LLM reward client using a vLLM server.

    Designed for batch processing: requests are issued concurrently over
    HTTP so vLLM's continuous batching can schedule them together for
    high throughput.  Results are optionally memoized by a hash of the
    (q_t, a_t, q_{t+1}) triple so repeated turns are not re-judged.
    """

    def __init__(self, config: Optional[LocalLLMRewardConfig] = None):
        """
        Args:
            config: Client configuration; a default LocalLLMRewardConfig
                is used when None.
        """
        self.config = config or LocalLLMRewardConfig()
        self._model_name = self.config.model_name
        self._cache: Dict[str, RewardResult] = {}

        # The OpenAI-compatible endpoint requires a model name in each
        # payload; ask the server for one when the caller did not pin it.
        if self._model_name is None:
            self._discover_model_sync()

    def _discover_model_sync(self) -> None:
        """Synchronously discover the served model name from the vLLM server.

        Best-effort: falls back to "default" (with a printed warning) when
        the server is unreachable or reports no models, so construction
        never fails just because the server is momentarily down.
        """
        import requests  # local import: only needed once, at construction time
        try:
            response = requests.get(
                f"{self.config.vllm_url}/models",
                timeout=10
            )
            response.raise_for_status()
            data = response.json().get("data") or []
            self._model_name = data[0]["id"] if data else "default"
        except Exception as e:  # deliberate best-effort: log, don't raise
            print(f"[LocalLLMReward] Warning: Could not discover model ({e})")
            self._model_name = "default"

    def _cache_key(self, query_t: str, answer_t: str, query_t1: str) -> str:
        """Deterministic SHA-256 hash of the judge input triple.

        NUL separators keep distinct triples from colliding after
        concatenation.
        """
        content = f"{query_t}\x00{answer_t}\x00{query_t1}"
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    def _build_messages(self, sample: TurnSample) -> List[dict]:
        """Construct the judge chat messages from (q_t, a_t, q_{t+1})."""
        user_content = JUDGE_USER_TEMPLATE.format(
            query_t=sample.query_t,
            answer_t=sample.answer_t,
            query_t1=sample.query_t1,
        )
        return [
            {"role": "system", "content": JUDGE_SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding markdown ``` fence from *text*, if present."""
        text = text.strip()
        if text.startswith("```"):
            lines = text.split("\n")
            # Drop the opening fence line; drop the closing one if present.
            body = lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
            text = "\n".join(body)
        return text

    @staticmethod
    def _failure_result(rationale: str) -> RewardResult:
        """Neutral, non-updating result used for failed or unparseable outputs."""
        return RewardResult(
            label="neutral",
            confidence=0.0,
            rationale=rationale,
            reward=0.0,
            should_update=False,
        )

    def _result_from_parsed(self, parsed: dict) -> RewardResult:
        """Build a RewardResult from a parsed judge JSON object.

        Applies label validation, confidence gating (tau_c) and the
        topic_shift skip: gated-out judgments yield reward 0.0 and
        should_update=False.
        """
        label = parsed.get("label", "neutral")
        confidence = float(parsed.get("confidence", 0.0))
        rationale = parsed.get("rationale", "")

        # Unknown labels are treated as a zero-confidence neutral verdict.
        if label not in VALID_LABELS:
            label = "neutral"
            confidence = 0.0

        should_update = (
            confidence >= self.config.confidence_threshold
            and label != "topic_shift"
        )
        # Gated-out turns must not move the learner: zero the reward too.
        reward = REWARD_MAP[label] if should_update else 0.0

        return RewardResult(
            label=label,
            confidence=confidence,
            rationale=rationale,
            reward=reward,
            should_update=should_update,
        )

    def _parse_result(self, raw: Optional[str]) -> RewardResult:
        """Parse the judge's (ideally JSON) output into a RewardResult.

        Handles markdown code fences and, as a fallback, extracts a
        {...} span embedded in free-form text.  Any unrecoverable
        failure yields a neutral, non-updating result instead of raising.
        """
        if raw is None:
            return self._failure_result("request_failed")

        # Primary path: the whole (de-fenced) text is the JSON object.
        try:
            return self._result_from_parsed(json.loads(self._strip_code_fence(raw)))
        except (json.JSONDecodeError, KeyError, TypeError, ValueError):
            pass

        # Fallback: pull a JSON object out of surrounding prose.  Greedy
        # match so a '}' inside the rationale string does not truncate
        # the object (the original `[^}]+` pattern stopped at the first '}').
        import re
        match = re.search(r"\{.*\}", raw, re.DOTALL)
        if match:
            try:
                return self._result_from_parsed(json.loads(match.group()))
            except (json.JSONDecodeError, KeyError, TypeError, ValueError):
                pass

        return self._failure_result("parse_failure")

    async def _single_request(
        self,
        session: aiohttp.ClientSession,
        messages: List[dict],
        idx: int,
    ) -> tuple:
        """POST one chat-completion request to the vLLM server.

        Retries up to 3 times on HTTP 429 (exponential backoff) and on
        timeouts; other errors return immediately.

        Returns:
            (idx, content, error) where exactly one of content/error is
            None; idx lets the caller reorder out-of-order completions.
        """
        payload = {
            "model": self._model_name,
            "messages": messages,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
            # Ask the server to constrain output to a JSON object.
            "response_format": {"type": "json_object"},
        }
        # Loop-invariant: build the timeout object once, not per attempt.
        timeout_config = (
            aiohttp.ClientTimeout(total=self.config.timeout)
            if self.config.timeout else None
        )

        for attempt in range(3):
            try:
                async with session.post(
                    f"{self.config.vllm_url}/chat/completions",
                    json=payload,
                    timeout=timeout_config,
                ) as response:
                    if response.status == 200:
                        result = await response.json()
                        content = result["choices"][0]["message"]["content"]
                        return (idx, content, None)
                    if response.status == 429:
                        # Rate limited: back off exponentially, then retry.
                        await asyncio.sleep(2 ** attempt)
                        continue
                    error_text = await response.text()
                    return (idx, None, f"HTTP {response.status}: {error_text[:200]}")
            except asyncio.TimeoutError:
                if attempt < 2:
                    continue
                return (idx, None, "Timeout")
            except Exception as e:  # connection errors etc. are reported, not raised
                return (idx, None, str(e))

        return (idx, None, "Max retries exceeded")

    async def judge_batch_async(
        self,
        samples: List[TurnSample],
        show_progress: bool = False,
    ) -> List[RewardResult]:
        """
        Judge a batch of turns using concurrent vLLM requests.

        Cache hits are served from memory; the rest are sent concurrently
        (bounded by config.max_concurrent) so vLLM's continuous batching
        can process them together for high throughput.

        Args:
            samples: Turns to judge.
            show_progress: Print a progress line every 10 completions.

        Returns:
            One RewardResult per input sample, in input order.  Failed
            requests still yield a (neutral, non-updating) result so the
            output always aligns with the input.
        """
        results: List[Optional[RewardResult]] = [None] * len(samples)

        # Serve cache hits; collect misses as (original_index, messages).
        to_request: List[tuple] = []
        for i, sample in enumerate(samples):
            if self.config.enable_cache:
                key = self._cache_key(sample.query_t, sample.answer_t, sample.query_t1)
                if key in self._cache:
                    results[i] = self._cache[key]
                    continue
            to_request.append((i, self._build_messages(sample)))

        if not to_request:
            return results

        # Bound in-flight requests; vLLM batches whatever arrives together.
        semaphore = asyncio.Semaphore(self.config.max_concurrent)

        async def limited_request(session, messages, idx):
            async with semaphore:
                return await self._single_request(session, messages, idx)

        connector = aiohttp.TCPConnector(limit=self.config.max_concurrent)
        headers = {"Content-Type": "application/json"}

        async with aiohttp.ClientSession(
            connector=connector, headers=headers
        ) as session:
            tasks = [
                limited_request(session, messages, orig_idx)
                for orig_idx, messages in to_request
            ]

            completed = 0
            for coro in asyncio.as_completed(tasks):
                orig_idx, content, error = await coro
                completed += 1

                if error:
                    print(f"[LocalLLMReward] Request {orig_idx} failed: {error}")

                result = self._parse_result(content)
                results[orig_idx] = result

                # Only cache successful responses: caching a transient
                # failure would pin a neutral verdict for this triple
                # forever and prevent any future retry.
                if self.config.enable_cache and error is None:
                    sample = samples[orig_idx]
                    key = self._cache_key(
                        sample.query_t, sample.answer_t, sample.query_t1
                    )
                    self._cache[key] = result

                if show_progress and completed % 10 == 0:
                    print(f" [LocalLLMReward {completed}/{len(to_request)}] completed")

        return results

    async def judge_async(self, sample: TurnSample) -> RewardResult:
        """Judge a single turn (async)."""
        results = await self.judge_batch_async([sample])
        return results[0]

    def judge_batch(self, samples: List[TurnSample]) -> List[RewardResult]:
        """
        Judge a batch of turns (sync wrapper).

        Main entry point for batch reward estimation.  Uses asyncio.run,
        so it must not be called from inside a running event loop.
        """
        return asyncio.run(self.judge_batch_async(samples))

    def judge(self, sample: TurnSample) -> RewardResult:
        """Judge a single turn (sync wrapper)."""
        return asyncio.run(self.judge_async(sample))
+
+
+# --- Convenience Functions ---
+
def estimate_reward_local(
    sample: TurnSample,
    config: Optional[LocalLLMRewardConfig] = None,
) -> tuple:
    """
    Estimate the reward for one turn synchronously via the local LLM.

    Returns (reward, should_update).
    """
    verdict = LocalLLMRewardClient(config).judge(sample)
    return verdict.reward, verdict.should_update
+
+
def estimate_rewards_batch_local(
    samples: List[TurnSample],
    config: Optional[LocalLLMRewardConfig] = None,
) -> List[tuple]:
    """
    Estimate rewards for a batch of turns synchronously via the local LLM.

    Returns a list of (reward, should_update) tuples, one per sample.
    """
    judge = LocalLLMRewardClient(config)
    verdicts = judge.judge_batch(samples)
    return [(v.reward, v.should_update) for v in verdicts]