diff options
Diffstat (limited to 'src/personalization/evaluation/baselines/base.py')
| -rw-r--r-- | src/personalization/evaluation/baselines/base.py | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/src/personalization/evaluation/baselines/base.py b/src/personalization/evaluation/baselines/base.py new file mode 100644 index 0000000..a3051bd --- /dev/null +++ b/src/personalization/evaluation/baselines/base.py @@ -0,0 +1,83 @@ +""" +Base class for all baseline agents. + +All agents must implement: +- respond(): Generate a response to user query +- end_session(): Called when a session ends (for memory updates) +- reset_user(): Reset all state for a user +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import List, Dict, Any, Optional + + +@dataclass +class AgentResponse: + """Response from an agent.""" + answer: str + debug_info: Dict[str, Any] = field(default_factory=dict) + + +class BaselineAgent(ABC): + """Abstract base class for all baseline agents.""" + + def __init__(self, model_name: str, **kwargs): + """ + Args: + model_name: Name/path of the LLM to use + **kwargs: Additional configuration + """ + self.model_name = model_name + self.config = kwargs + + @abstractmethod + def respond( + self, + user_id: str, + query: str, + conversation_history: List[Dict[str, str]], + **kwargs + ) -> AgentResponse: + """ + Generate a response to the user's query. + + Args: + user_id: Unique identifier for the user + query: Current user message + conversation_history: List of previous messages [{"role": "user/assistant", "content": "..."}] + **kwargs: Additional context (e.g., task info) + + Returns: + AgentResponse with answer and debug info + """ + pass + + @abstractmethod + def end_session(self, user_id: str, conversation: List[Dict[str, str]]): + """ + Called when a session (one task) ends. + Use this to update memory, notes, etc. + + Args: + user_id: User identifier + conversation: Complete conversation from this session + """ + pass + + @abstractmethod + def reset_user(self, user_id: str): + """ + Completely reset all state for a user. + Called at the start of a new experiment. + + Args: + user_id: User identifier + """ + pass + + def get_name(self) -> str: + """Get a descriptive name for this agent.""" + return self.__class__.__name__ + + |
