Diffstat (limited to 'backend/app/services/llm.py')
-rw-r--r--  backend/app/services/llm.py  116
1 file changed, 116 insertions, 0 deletions
diff --git a/backend/app/services/llm.py b/backend/app/services/llm.py
new file mode 100644
index 0000000..958ab4c
--- /dev/null
+++ b/backend/app/services/llm.py
@@ -0,0 +1,116 @@
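+"""Streaming chat completion helpers for OpenAI and Google Gemini models."""
+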
+import os
+from typing import AsyncGenerator
+import openai
+import google.generativeai as genai
+from app.schemas import LLMConfig, Message, Role, Context
+
+# Simple in-memory cache for clients to avoid re-initializing constantly
+# In a real app, use dependency injection or singletons
+_openai_client = None
+_openai_client_key = None
+
+def get_openai_client(api_key: str | None = None):
+    global _openai_client, _openai_client_key
+    key = api_key or os.getenv("OPENAI_API_KEY")
+    if not key:
+        raise ValueError("OpenAI API Key not found")
+    # Recreate the cached client if the effective key changed since the last call
+    if _openai_client is None or _openai_client_key != key:
+        _openai_client = openai.AsyncOpenAI(api_key=key)
+        _openai_client_key = key
+    return _openai_client
+
+def configure_google(api_key: str | None = None):
+ key = api_key or os.getenv("GOOGLE_API_KEY")
+ if not key:
+ raise ValueError("Google API Key not found")
+ genai.configure(api_key=key)
+
+async def stream_openai(messages: list[Message], config: LLMConfig) -> AsyncGenerator[str, None]:
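+    """Stream response text deltas from the OpenAI Chat Completions API."""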
+ client = get_openai_client(config.api_key)
+
+ # Convert internal Message schema to OpenAI format
+ openai_messages = []
+ if config.system_prompt:
+ openai_messages.append({"role": "system", "content": config.system_prompt})
+
+ for msg in messages:
+ openai_messages.append({"role": msg.role.value, "content": msg.content})
+
+ stream = await client.chat.completions.create(
+ model=config.model_name,
+ messages=openai_messages,
+ temperature=config.temperature,
+ max_tokens=config.max_tokens,
+ stream=True
+ )
+
+    async for chunk in stream:
+        # Guard against chunks with no choices (e.g. a trailing usage-only chunk)
+        if chunk.choices and chunk.choices[0].delta.content:
+            yield chunk.choices[0].delta.content
+
+async def stream_google(messages: list[Message], config: LLMConfig) -> AsyncGenerator[str, None]:
+    """Stream response text chunks from the Google Gemini API."""
+    configure_google(config.api_key)
+
+    # Gemini accepts a system instruction at model construction time, so build
+    # the model once, with the system prompt when one is configured.
+    if config.system_prompt:
+        model = genai.GenerativeModel(config.model_name, system_instruction=config.system_prompt)
+    else:
+        model = genai.GenerativeModel(config.model_name)
+
+    # Convert messages to the Google Generative AI history format:
+    # [{"role": "user", "parts": ["..."]}, {"role": "model", "parts": ["..."]}]
+    # Note: Gemini expects strict user/model alternation in history; for the MVP
+    # we assume the input is already clean and map roles directly.
+    history = []
+    for msg in messages:
+        role = "user" if msg.role == Role.USER else "model"
+        history.append({"role": role, "parts": [msg.content]})
+
+    # `send_message_async` takes the newest prompt separately from the history,
+    # so split off the trailing user message when there is one.
+    if history and history[-1]["role"] == "user":
+        last_msg = history.pop()
+        chat = model.start_chat(history=history)
+        response_stream = await chat.send_message_async(last_msg["parts"][0], stream=True)
+    else:
+        # The conversation does not end with a user turn; this should not happen
+        # for a Question Block. Fall back to a placeholder prompt rather than failing.
+        chat = model.start_chat(history=history)
+        response_stream = await chat.send_message_async("...", stream=True)
+
+    async for chunk in response_stream:
+        if chunk.text:
+            yield chunk.text
+
+async def llm_streamer(context: Context, user_prompt: str, config: LLMConfig) -> AsyncGenerator[str, None]:
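+    """Merge the existing Context with the new user prompt and stream the selected provider's response."""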
+ # 1. Merge Context + New User Prompt
+ # We create a temporary list of messages for this inference
+ messages_to_send = context.messages.copy()
+
+ # If user_prompt is provided (it should be for a Question Block)
+ if user_prompt.strip():
+ messages_to_send.append(Message(
+ id="temp_user_prompt", # ID doesn't matter for the API call
+ role=Role.USER,
+ content=user_prompt
+ ))
+
+ # 2. Call Provider
+ try:
+ if config.provider == "openai":
+ async for chunk in stream_openai(messages_to_send, config):
+ yield chunk
+ elif config.provider == "google":
+ async for chunk in stream_google(messages_to_send, config):
+ yield chunk
+ else:
+ yield f"Error: Unsupported provider {config.provider}"
+ except Exception as e:
+ yield f"Error calling LLM: {str(e)}"
+
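For context, below is a minimal sketch of how this service might be consumed from a FastAPI route. The route path, the InferenceRequest wrapper, and the assumption that Context and LLMConfig are Pydantic models are illustrative guesses based on how the schemas are used above; none of this is part of the commit.

# Hypothetical consumer of llm_streamer (illustrative sketch, not in this diff).
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from app.schemas import Context, LLMConfig
from app.services.llm import llm_streamer

router = APIRouter()


class InferenceRequest(BaseModel):
    context: Context
    user_prompt: str
    config: LLMConfig


@router.post("/inference/stream")
async def stream_inference(req: InferenceRequest):
    # llm_streamer yields plain-text chunks (including its own error strings),
    # so a text/plain streaming response is enough for a minimal consumer.
    return StreamingResponse(
        llm_streamer(req.context, req.user_prompt, req.config),
        media_type="text/plain",
    )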