summaryrefslogtreecommitdiff
path: root/backend/app/services/llm.py
blob: 958ab4ceb84a1240b1f816f966f76ea45abb99f1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import logging
import os
from typing import AsyncGenerator

import openai
import google.generativeai as genai

from app.schemas import LLMConfig, Message, Role, Context

# Simple in-memory cache so repeated calls with the same key reuse one client
# instead of re-initializing it every time.
# In a real app, use dependency injection or singletons.
_openai_client = None
# API key the cached client was built with; a different key forces a rebuild.
_openai_client_key = None

def get_openai_client(api_key: str = None):
    """Return an ``openai.AsyncOpenAI`` client for ``api_key``.

    Falls back to the ``OPENAI_API_KEY`` environment variable when ``api_key``
    is empty. The client is cached and reused while the resolved key stays the
    same. A call with a different key builds (and caches) a new client instead
    of silently returning one bound to the previous key.

    Raises:
        ValueError: If neither ``api_key`` nor ``OPENAI_API_KEY`` is set.
    """
    global _openai_client, _openai_client_key
    key = api_key or os.getenv("OPENAI_API_KEY")
    if not key:
        raise ValueError("OpenAI API Key not found")
    if _openai_client is None or _openai_client_key != key:
        _openai_client = openai.AsyncOpenAI(api_key=key)
        _openai_client_key = key
    return _openai_client

def configure_google(api_key: str = None):
    """Point the ``google.generativeai`` SDK at an API key.

    Uses ``api_key`` when it is non-empty, otherwise the ``GOOGLE_API_KEY``
    environment variable.

    Raises:
        ValueError: If neither source provides a key.
    """
    # NOTE(review): genai.configure is a module-level call and appears to set
    # SDK-wide state, so concurrent requests using different keys may interfere
    # with each other — confirm whether per-request keys are expected here.
    resolved_key = api_key if api_key else os.getenv("GOOGLE_API_KEY")
    if not resolved_key:
        raise ValueError("Google API Key not found")
    genai.configure(api_key=resolved_key)

async def stream_openai(messages: list[Message], config: LLMConfig) -> AsyncGenerator[str, None]:
    """Stream a chat completion from OpenAI, yielding text deltas as they arrive.

    Args:
        messages: Conversation turns, oldest first. Each message's
            ``role.value`` is sent verbatim as the OpenAI role.
        config: Provider settings. ``api_key``, ``model_name``,
            ``temperature``, ``max_tokens`` and the optional ``system_prompt``
            are read.

    Yields:
        Non-empty content fragments from the streamed response.

    Raises:
        ValueError: If no OpenAI API key is available (from
            ``get_openai_client``).
    """
    client = get_openai_client(config.api_key)

    # The system prompt (if any) goes first, followed by the conversation
    # converted from the internal Message schema to OpenAI's dict format.
    openai_messages = []
    if config.system_prompt:
        openai_messages.append({"role": "system", "content": config.system_prompt})
    openai_messages.extend(
        {"role": msg.role.value, "content": msg.content} for msg in messages
    )

    stream = await client.chat.completions.create(
        model=config.model_name,
        messages=openai_messages,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        stream=True,
    )

    async for chunk in stream:
        # Some chunks carry an empty ``choices`` list. Examples are the trailing
        # usage-only chunk when stream_options={"include_usage": True}, and
        # content-filter chunks on some deployments. Indexing choices[0] on
        # those would raise IndexError and abort the stream.
        if not chunk.choices:
            continue
        content = chunk.choices[0].delta.content
        if content:
            yield content

async def stream_google(messages: list[Message], config: LLMConfig) -> AsyncGenerator[str, None]:
    """Stream a Gemini chat response, yielding text fragments as they arrive.

    The conversation is replayed as chat history, and the trailing user turn
    is sent as the new message. Consecutive messages with the same role are
    merged into one turn so the history alternates between user and model.

    Args:
        messages: Conversation turns, oldest first. ``Role.USER`` maps to
            Gemini's ``"user"``; every other role maps to ``"model"``.
        config: Provider settings. ``api_key``, ``model_name`` and the
            optional ``system_prompt`` are read.

    Yields:
        Truthy ``chunk.text`` fragments from the streamed response.

    Raises:
        ValueError: If no Google API key is available (from
            ``configure_google``).
    """
    configure_google(config.api_key)

    # Gemini takes the system prompt as a model-level system instruction
    # rather than as a chat turn, so build the model once, with it if present.
    if config.system_prompt:
        model = genai.GenerativeModel(config.model_name, system_instruction=config.system_prompt)
    else:
        model = genai.GenerativeModel(config.model_name)

    history = _to_gemini_history(messages)

    # chat.send_message takes the new message separately from the history, so
    # split off the trailing user turn as the prompt.
    if history and history[-1]["role"] == "user":
        prompt = history.pop()["parts"][0]
    else:
        # No trailing user turn (e.g. empty context, or it ends with a model
        # reply). Keep the placeholder prompt so the call still goes through.
        prompt = "..."

    chat = model.start_chat(history=history)
    response_stream = await chat.send_message_async(prompt, stream=True)

    async for chunk in response_stream:
        if chunk.text:
            yield chunk.text

def _to_gemini_history(messages: list[Message]) -> list[dict]:
    """Convert messages to Gemini turns, merging runs of the same role.

    Each turn has the form ``{"role": "user" | "model", "parts": [text]}``.
    Adjacent messages that map to the same role are joined with a blank line
    into a single part, so the result strictly alternates roles.
    """
    # NOTE(review): any non-USER role (including a system role, if Role has
    # one) is mapped to "model" — confirm that is intended.
    history: list[dict] = []
    for msg in messages:
        role = "user" if msg.role == Role.USER else "model"
        if history and history[-1]["role"] == role:
            history[-1]["parts"][0] += "\n\n" + msg.content
        else:
            history.append({"role": role, "parts": [msg.content]})
    return history

async def llm_streamer(context: Context, user_prompt: str, config: LLMConfig) -> AsyncGenerator[str, None]:
    """Stream an LLM answer for ``context`` plus an optional new user prompt.

    Args:
        context: Prior conversation. Its ``messages`` list is copied, never
            modified.
        user_prompt: New user text to append as a final ``Role.USER`` message.
            It is skipped when empty, blank or ``None``.
        config: Provider settings. ``config.provider`` selects ``"openai"`` or
            ``"google"``.

    Yields:
        Text fragments from the provider. Errors are not raised. An
        unsupported provider yields a single ``"Error: ..."`` fragment, and
        any exception from the provider call yields ``"Error calling LLM: ..."``
        (after logging the traceback).
    """
    # 1. Merge context + new user prompt into a temporary message list used
    #    only for this inference.
    messages_to_send = context.messages.copy()

    # Question Blocks normally supply a prompt; tolerate None/blank by skipping.
    if user_prompt and user_prompt.strip():
        messages_to_send.append(Message(
            id="temp_user_prompt",  # the provider streamers send only role + content
            role=Role.USER,
            content=user_prompt,
        ))

    # 2. Dispatch to the provider-specific streamer.
    try:
        if config.provider == "openai":
            stream = stream_openai(messages_to_send, config)
        elif config.provider == "google":
            stream = stream_google(messages_to_send, config)
        else:
            yield f"Error: Unsupported provider {config.provider}"
            return
        async for chunk in stream:
            yield chunk
    except Exception as e:
        # Best-effort: surface the failure to the client as stream text, but
        # also log the traceback so the root cause is not lost server-side.
        logging.getLogger(__name__).exception(
            "LLM call failed (provider=%s)", config.provider
        )
        yield f"Error calling LLM: {str(e)}"