import os
from typing import AsyncGenerator

import openai
import google.generativeai as genai

from app.schemas import LLMConfig, Message, Role, Context

# Simple in-memory cache so the client is not re-initialized on every call.
# In a real app, prefer dependency injection or a proper singleton.
_openai_client = None


def get_openai_client(api_key: str | None = None) -> openai.AsyncOpenAI:
    global _openai_client
    key = api_key or os.getenv("OPENAI_API_KEY")
    if not key:
        raise ValueError("OpenAI API key not found")
    if not _openai_client:
        # Note: the cached client keeps the first key it was created with;
        # a different api_key passed on later calls is ignored.
        _openai_client = openai.AsyncOpenAI(api_key=key)
    return _openai_client
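
# --- Illustrative sketch only (an assumption, not part of this module) ---
# If this app uses FastAPI, the module-level cache above could be replaced by
# a dependency that route handlers declare explicitly, for example:
#
#   from fastapi import Depends
#
#   def openai_client_dependency() -> openai.AsyncOpenAI:
#       return get_openai_client()
#
#   # in a route handler:
#   #   client: openai.AsyncOpenAI = Depends(openai_client_dependency)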

def configure_google(api_key: str | None = None) -> None:
    key = api_key or os.getenv("GOOGLE_API_KEY")
    if not key:
        raise ValueError("Google API key not found")
    genai.configure(api_key=key)

async def stream_openai(messages: list[Message], config: LLMConfig) -> AsyncGenerator[str, None]:
    client = get_openai_client(config.api_key)

    # Convert the internal Message schema to the OpenAI chat format.
    openai_messages = []
    if config.system_prompt:
        openai_messages.append({"role": "system", "content": config.system_prompt})
    for msg in messages:
        openai_messages.append({"role": msg.role.value, "content": msg.content})

    stream = await client.chat.completions.create(
        model=config.model_name,
        messages=openai_messages,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        stream=True,
    )
    async for chunk in stream:
        # Some chunks carry no content delta (e.g. the final chunk); skip them.
        if chunk.choices and chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content

async def stream_google(messages: list[Message], config: LLMConfig) -> AsyncGenerator[str, None]:
    configure_google(config.api_key)

    # Gemini supports a system instruction at model construction time
    # (e.g. Gemini 1.5 Pro), so pass it there when one is configured.
    if config.system_prompt:
        model = genai.GenerativeModel(config.model_name, system_instruction=config.system_prompt)
    else:
        model = genai.GenerativeModel(config.model_name)

    # Google Generative AI history format:
    # [{"role": "user", "parts": ["..."]}, {"role": "model", "parts": ["..."]}]
    # Note: Gemini generally requires strict user/model alternation in history.
    # For the MVP we assume the input is already clean and map roles directly
    # (see the illustrative helper sketched after this function).
    history = []
    for msg in messages:
        role = "user" if msg.role == Role.USER else "model"
        history.append({"role": role, "parts": [msg.content]})

    # `chat.send_message_async` expects the new prompt separately from the
    # history, so split off the trailing user message when there is one.
    if history and history[-1]["role"] == "user":
        last_msg = history.pop()
        chat = model.start_chat(history=history)
        response_stream = await chat.send_message_async(last_msg["parts"][0], stream=True)
    else:
        # The last message is not from the user, which should not happen when
        # the node always supplies a prompt. Fall back to a placeholder prompt.
        chat = model.start_chat(history=history)
        response_stream = await chat.send_message_async("...", stream=True)

    async for chunk in response_stream:
        if chunk.text:
            yield chunk.text
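
# --- Illustrative sketch only (not called above) ---
# stream_google notes that Gemini expects user/model roles to alternate.
# One possible way to enforce that, under the assumption that back-to-back
# messages with the same role can simply be concatenated, is a hypothetical
# helper like this:
def _merge_consecutive_roles(history: list[dict]) -> list[dict]:
    merged: list[dict] = []
    for entry in history:
        if merged and merged[-1]["role"] == entry["role"]:
            # Fold this message's parts into the previous same-role message.
            merged[-1]["parts"].extend(entry["parts"])
        else:
            merged.append({"role": entry["role"], "parts": list(entry["parts"])})
    return merged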

async def llm_streamer(context: Context, user_prompt: str, config: LLMConfig) -> AsyncGenerator[str, None]:
    # 1. Merge the existing context with the new user prompt into a temporary
    #    message list used only for this inference.
    messages_to_send = context.messages.copy()
    if user_prompt.strip():
        # A Question Block should always provide a prompt, but guard anyway.
        messages_to_send.append(Message(
            id="temp_user_prompt",  # The ID is irrelevant for the API call.
            role=Role.USER,
            content=user_prompt,
        ))

    # 2. Dispatch to the configured provider.
    try:
        if config.provider == "openai":
            async for chunk in stream_openai(messages_to_send, config):
                yield chunk
        elif config.provider == "google":
            async for chunk in stream_google(messages_to_send, config):
                yield chunk
        else:
            yield f"Error: Unsupported provider {config.provider}"
    except Exception as e:
        yield f"Error calling LLM: {str(e)}"