import os
from typing import AsyncGenerator

import openai
import google.generativeai as genai

from app.schemas import LLMConfig, Message, Role, Context

# Simple in-memory cache for the client to avoid re-initializing it on every call.
# In a real app, use dependency injection or a proper singleton.
_openai_client = None


def get_openai_client(api_key: str | None = None):
    global _openai_client
    key = api_key or os.getenv("OPENAI_API_KEY")
    if not key:
        raise ValueError("OpenAI API Key not found")
    # Note: the first key wins; later calls with a different key still get the cached client.
    if not _openai_client:
        _openai_client = openai.AsyncOpenAI(api_key=key)
    return _openai_client


def configure_google(api_key: str | None = None):
    key = api_key or os.getenv("GOOGLE_API_KEY")
    if not key:
        raise ValueError("Google API Key not found")
    genai.configure(api_key=key)


async def stream_openai(messages: list[Message], config: LLMConfig) -> AsyncGenerator[str, None]:
    client = get_openai_client(config.api_key)

    # Convert the internal Message schema to the OpenAI chat format.
    openai_messages = []
    if config.system_prompt:
        openai_messages.append({"role": "system", "content": config.system_prompt})
    for msg in messages:
        openai_messages.append({"role": msg.role.value, "content": msg.content})

    stream = await client.chat.completions.create(
        model=config.model_name,
        messages=openai_messages,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        stream=True,
    )
    async for chunk in stream:
        # Some chunks can arrive with an empty choices list, so guard before indexing.
        if chunk.choices and chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content


async def stream_google(messages: list[Message], config: LLMConfig) -> AsyncGenerator[str, None]:
    configure_google(config.api_key)

    # Gemini history format:
    # [{"role": "user", "parts": ["..."]}, {"role": "model", "parts": ["..."]}]
    # The system prompt is passed as a system instruction on model init (supported by
    # Gemini 1.5 models); otherwise it would have to be prepended to the history.
    if config.system_prompt:
        model = genai.GenerativeModel(config.model_name, system_instruction=config.system_prompt)
    else:
        model = genai.GenerativeModel(config.model_name)

    # Convert messages.
    # Note: Gemini strictly requires user/model alternation in the history. For the MVP
    # we assume the input is clean and map roles directly (see the optional coalescing
    # sketch below).
    history = []
    for msg in messages:
        role = "user" if msg.role == Role.USER else "model"
        history.append({"role": role, "parts": [msg.content]})

    # `chat.send_message` expects the new prompt separately from the history, so split
    # off the last user message if there is one.
    if history and history[-1]["role"] == "user":
        last_msg = history.pop()
        chat = model.start_chat(history=history)
        response_stream = await chat.send_message_async(last_msg["parts"][0], stream=True)
    else:
        # The last message is not from the user; we assume the caller always provides a
        # prompt in the node, so this branch is only a fallback.
        chat = model.start_chat(history=history)
        response_stream = await chat.send_message_async("...", stream=True)

    async for chunk in response_stream:
        if chunk.text:
            yield chunk.text
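
# Optional sketch (not wired into stream_google above): one minimal way to satisfy
# Gemini's strict user/model alternation is to merge consecutive same-role turns before
# starting the chat. `_coalesce_history` is a hypothetical helper, not part of the
# schemas or the google.generativeai API, and joining merged turns with a newline is an
# assumption about what is acceptable for this app.
def _coalesce_history(history: list[dict]) -> list[dict]:
    coalesced: list[dict] = []
    for entry in history:
        if coalesced and coalesced[-1]["role"] == entry["role"]:
            # Merge consecutive turns from the same role into a single entry.
            coalesced[-1]["parts"] = ["\n".join(coalesced[-1]["parts"] + entry["parts"])]
        else:
            coalesced.append({"role": entry["role"], "parts": list(entry["parts"])})
    return coalesced
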
async def llm_streamer(context: Context, user_prompt: str, config: LLMConfig) -> AsyncGenerator[str, None]:
    # 1. Merge Context + New User Prompt
    # We create a temporary list of messages for this inference only.
    messages_to_send = context.messages.copy()

    # If user_prompt is provided (it should be for a Question Block)
    if user_prompt.strip():
        messages_to_send.append(Message(
            id="temp_user_prompt",  # ID doesn't matter for the API call
            role=Role.USER,
            content=user_prompt,
        ))

    # 2. Call Provider
    try:
        if config.provider == "openai":
            async for chunk in stream_openai(messages_to_send, config):
                yield chunk
        elif config.provider == "google":
            async for chunk in stream_google(messages_to_send, config):
                yield chunk
        else:
            yield f"Error: Unsupported provider {config.provider}"
    except Exception as e:
        yield f"Error calling LLM: {str(e)}"
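
# Usage sketch: how the streamer might be consumed from a script. This is assumption-
# heavy — the LLMConfig/Context constructor fields shown are guesses inferred from how
# they are used above, and the model name is purely illustrative. A real deployment
# would consume llm_streamer from an API route (e.g. a streaming HTTP response).
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        config = LLMConfig(
            provider="openai",
            model_name="gpt-4o-mini",  # hypothetical model choice
            temperature=0.7,
            max_tokens=512,
            system_prompt="You are a helpful assistant.",
            api_key=None,  # falls back to the OPENAI_API_KEY environment variable
        )
        context = Context(messages=[])  # empty prior conversation
        async for chunk in llm_streamer(context, "Hello!", config):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(_demo())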