| author | blackhao <13851610112@163.com> | 2025-12-05 20:40:40 -0600 |
|---|---|---|
| committer | blackhao <13851610112@163.com> | 2025-12-05 20:40:40 -0600 |
| commit | d9868550e66fe8aaa7fff55a8e24b871ee51e3b1 (patch) | |
| tree | 147757f77def085c5649c4d930d5a51ff44a1e3d /backend/app/services | |
| parent | d87c364dc43ca241fadc9dccbad9ec8896c93a1e (diff) | |
init: add project files and ignore secrets
Diffstat (limited to 'backend/app/services')
| mode | path | lines changed |
|---|---|---|
| -rw-r--r-- | backend/app/services/llm.py | 116 |
1 files changed, 116 insertions, 0 deletions
diff --git a/backend/app/services/llm.py b/backend/app/services/llm.py
new file mode 100644
index 0000000..958ab4c
--- /dev/null
+++ b/backend/app/services/llm.py
@@ -0,0 +1,116 @@
+import os
+from typing import AsyncGenerator
+import openai
+import google.generativeai as genai
+from app.schemas import LLMConfig, Message, Role, Context
+
+# Simple in-memory cache for clients to avoid re-initializing constantly
+# In a real app, use dependency injection or singletons
+_openai_client = None
+
+def get_openai_client(api_key: str = None):
+    global _openai_client
+    key = api_key or os.getenv("OPENAI_API_KEY")
+    if not key:
+        raise ValueError("OpenAI API Key not found")
+    if not _openai_client:
+        _openai_client = openai.AsyncOpenAI(api_key=key)
+    return _openai_client
+
+def configure_google(api_key: str = None):
+    key = api_key or os.getenv("GOOGLE_API_KEY")
+    if not key:
+        raise ValueError("Google API Key not found")
+    genai.configure(api_key=key)
+
+async def stream_openai(messages: list[Message], config: LLMConfig) -> AsyncGenerator[str, None]:
+    client = get_openai_client(config.api_key)
+
+    # Convert internal Message schema to OpenAI format
+    openai_messages = []
+    if config.system_prompt:
+        openai_messages.append({"role": "system", "content": config.system_prompt})
+
+    for msg in messages:
+        openai_messages.append({"role": msg.role.value, "content": msg.content})
+
+    stream = await client.chat.completions.create(
+        model=config.model_name,
+        messages=openai_messages,
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        stream=True
+    )
+
+    async for chunk in stream:
+        if chunk.choices[0].delta.content:
+            yield chunk.choices[0].delta.content
+
+async def stream_google(messages: list[Message], config: LLMConfig) -> AsyncGenerator[str, None]:
+    configure_google(config.api_key)
+    model = genai.GenerativeModel(config.model_name)
+
+    # Google Generative AI history format:
+    # [{"role": "user", "parts": ["..."]}, {"role": "model", "parts": ["..."]}]
+    # System prompt is usually set on model init or prepended.
+
+    history = []
+    # If system prompt exists, we might prepend it to the first user message or use specific system instruction if supported
+    # Gemini 1.5 Pro supports system instructions. For simplicity, let's prepend to history if possible or context.
+
+    system_instruction = config.system_prompt
+    if system_instruction:
+        model = genai.GenerativeModel(config.model_name, system_instruction=system_instruction)
+
+    # Convert messages
+    # Note: Gemini strictly requires user/model alternation in history usually.
+    # We will need to handle this. For MVP, we assume the input is clean or we blindly map.
+    for msg in messages:
+        role = "user" if msg.role == Role.USER else "model"
+        history.append({"role": role, "parts": [msg.content]})
+
+    # The last message should be the prompt, strictly speaking, `chat.send_message` takes the new message
+    # But if we are treating everything as history...
+    # Let's separate the last user message as the prompt if possible.
+
+    if history and history[-1]["role"] == "user":
+        last_msg = history.pop()
+        chat = model.start_chat(history=history)
+        response_stream = await chat.send_message_async(last_msg["parts"][0], stream=True)
+    else:
+        # If the last message is not user, we might be in a weird state.
+        # Just send an empty prompt or handle error?
+        # For now, assume the user always provides a prompt in the node.
+        chat = model.start_chat(history=history)
+        response_stream = await chat.send_message_async("...", stream=True)  # Fallback
+
+    async for chunk in response_stream:
+        if chunk.text:
+            yield chunk.text
+
+async def llm_streamer(context: Context, user_prompt: str, config: LLMConfig) -> AsyncGenerator[str, None]:
+    # 1. Merge Context + New User Prompt
+    # We create a temporary list of messages for this inference
+    messages_to_send = context.messages.copy()
+
+    # If user_prompt is provided (it should be for a Question Block)
+    if user_prompt.strip():
+        messages_to_send.append(Message(
+            id="temp_user_prompt",  # ID doesn't matter for the API call
+            role=Role.USER,
+            content=user_prompt
+        ))
+
+    # 2. Call Provider
+    try:
+        if config.provider == "openai":
+            async for chunk in stream_openai(messages_to_send, config):
+                yield chunk
+        elif config.provider == "google":
+            async for chunk in stream_google(messages_to_send, config):
+                yield chunk
+        else:
+            yield f"Error: Unsupported provider {config.provider}"
+    except Exception as e:
+        yield f"Error calling LLM: {str(e)}"
+
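
The only coupling to the rest of the backend is `from app.schemas import LLMConfig, Message, Role, Context`, and that module is not part of this diff. Below is a rough reconstruction of what those models probably look like, inferred purely from how `llm.py` uses them; the field names appear in the code above, while the Pydantic base class and the defaults are assumptions.

```python
# Hypothetical sketch of app/schemas.py, inferred from usage in llm.py.
# Field names come from the code above; the base class and defaults are guesses.
from enum import Enum
from pydantic import BaseModel

class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"

class Message(BaseModel):
    id: str
    role: Role      # msg.role.value is passed to OpenAI, so a str-valued enum fits
    content: str

class LLMConfig(BaseModel):
    provider: str                   # "openai" or "google"
    model_name: str
    api_key: str | None = None      # llm.py falls back to env vars when unset
    system_prompt: str | None = None
    temperature: float = 0.7
    max_tokens: int | None = None

class Context(BaseModel):
    messages: list[Message] = []
```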

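The comments in `stream_google` note that Gemini chat history usually has to alternate user/model turns and leave that handling for later. A minimal sketch of one way to enforce it before building the history, assuming the same `Message`/`Role` schema; `merge_consecutive_roles` is a hypothetical helper, not part of this commit.

```python
# Hypothetical helper (not in this commit): collapse consecutive same-role
# messages so the resulting Gemini history alternates user/model turns.
from app.schemas import Message, Role

def merge_consecutive_roles(messages: list[Message]) -> list[dict]:
    history: list[dict] = []
    for msg in messages:
        role = "user" if msg.role == Role.USER else "model"
        if history and history[-1]["role"] == role:
            # Same role twice in a row: fold the text into the previous turn.
            history[-1]["parts"].append(msg.content)
        else:
            history.append({"role": role, "parts": [msg.content]})
    return history
```

Inside `stream_google`, this could replace the blind per-message mapping loop without changing the rest of the streaming logic.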