Diffstat (limited to 'backend')
-rw-r--r--  backend/app/main.py           85
-rw-r--r--  backend/app/schemas.py        52
-rw-r--r--  backend/app/services/llm.py  116
-rw-r--r--  backend/requirements.txt       8
4 files changed, 261 insertions, 0 deletions
diff --git a/backend/app/main.py b/backend/app/main.py
new file mode 100644
index 0000000..48cb89f
--- /dev/null
+++ b/backend/app/main.py
@@ -0,0 +1,85 @@
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
+from app.schemas import NodeRunRequest, NodeRunResponse, MergeStrategy, Role, Message, Context
+from app.services.llm import llm_streamer
+from dotenv import load_dotenv
+import os
+
+load_dotenv()
+
+app = FastAPI(title="ContextFlow Backend")
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+@app.get("/")
+def read_root():
+    return {"message": "ContextFlow Backend is running"}
+
+def smart_merge_messages(messages: list[Message]) -> list[Message]:
+    """
+    Merges messages using two steps:
+    1. Deduplication by ID (to handle diamond dependencies).
+    2. Merging consecutive messages from the same role.
+    """
+    if not messages:
+        return []
+
+    # 1. Deduplicate by ID, keeping order
+    seen_ids = set()
+    deduplicated = []
+    for msg in messages:
+        if msg.id not in seen_ids:
+            deduplicated.append(msg)
+            seen_ids.add(msg.id)
+
+    # 2. Merge consecutive roles
+    if not deduplicated:
+        return []
+
+    merged = []
+    current_msg = deduplicated[0].model_copy()
+
+    for next_msg in deduplicated[1:]:
+        if next_msg.role == current_msg.role:
+            # Merge content
+            current_msg.content += f"\n\n{next_msg.content}"
+            # Keep the latest timestamp
+            current_msg.timestamp = next_msg.timestamp
+        else:
+            merged.append(current_msg)
+            current_msg = next_msg.model_copy()
+
+    merged.append(current_msg)
+    return merged
+
+@app.post("/api/run_node_stream")
+async def run_node_stream(request: NodeRunRequest):
+    """
+    Stream the response from the LLM.
+    """
+    # 1. Concatenate all incoming contexts first
+    raw_messages = []
+    for ctx in request.incoming_contexts:
+        raw_messages.extend(ctx.messages)
+
+    # 2. Apply Merge Strategy
+    final_messages = []
+    if request.merge_strategy == MergeStrategy.SMART:
+        final_messages = smart_merge_messages(raw_messages)
+    else:
+        # RAW strategy: just keep them as is
+        final_messages = raw_messages
+
+    execution_context = Context(messages=final_messages)
+
+    return StreamingResponse(
+        llm_streamer(execution_context, request.user_prompt, request.config),
+        media_type="text/event-stream"
+    )
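
For orientation, here is a minimal sketch of how smart_merge_messages behaves on a diamond-shaped merge, where the same upstream message reaches a node along two branches; the message IDs and contents are made up for illustration.

# Illustration only: exercises smart_merge_messages from backend/app/main.py.
from app.main import smart_merge_messages
from app.schemas import Message, Role

shared = Message(id="m1", role=Role.USER, content="Shared upstream question")
a = Message(id="m2", role=Role.ASSISTANT, content="Answer from branch A")
b = Message(id="m3", role=Role.ASSISTANT, content="Answer from branch B")

# Both branches carry the shared message, so it appears twice after concatenation.
merged = smart_merge_messages([shared, a, shared, b])

# "m1" is deduplicated, and the two consecutive assistant replies are collapsed
# into a single message whose contents are joined with a blank line.
assert [m.id for m in merged] == ["m1", "m2"]
assert merged[1].content == "Answer from branch A\n\nAnswer from branch B"
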
diff --git a/backend/app/schemas.py b/backend/app/schemas.py
new file mode 100644
index 0000000..ac90bc1
--- /dev/null
+++ b/backend/app/schemas.py
@@ -0,0 +1,52 @@
+from pydantic import BaseModel, Field
+from typing import List, Optional, Dict, Any, Union
+from enum import Enum
+import time
+
+class Role(str, Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+
+class Message(BaseModel):
+    id: str = Field(..., description="Unique ID for the message")
+    role: Role
+    content: str
+    timestamp: float = Field(default_factory=time.time)
+    # Metadata to track where this message came from
+    source_node_id: Optional[str] = None
+    model_used: Optional[str] = None
+
+class Context(BaseModel):
+    messages: List[Message] = []
+
+class ModelProvider(str, Enum):
+    OPENAI = "openai"
+    GOOGLE = "google"
+
+class LLMConfig(BaseModel):
+    provider: ModelProvider
+    model_name: str
+    temperature: float = 0.7
+    max_tokens: int = 1000
+    system_prompt: Optional[str] = None
+    api_key: Optional[str] = None # Optional override, usually from env
+
+class MergeStrategy(str, Enum):
+    RAW = "raw"
+    SMART = "smart"
+
+class NodeRunRequest(BaseModel):
+    node_id: str
+    incoming_contexts: List[Context] = []
+    user_prompt: str
+    config: LLMConfig
+    merge_strategy: MergeStrategy = MergeStrategy.SMART
+
+class NodeRunResponse(BaseModel):
+    node_id: str
+    output_context: Context
+    response_content: str
+    raw_response: Optional[Dict[str, Any]] = None
+    usage: Optional[Dict[str, Any]] = None
+
diff --git a/backend/app/services/llm.py b/backend/app/services/llm.py
new file mode 100644
index 0000000..958ab4c
--- /dev/null
+++ b/backend/app/services/llm.py
@@ -0,0 +1,116 @@
+import os
+from typing import AsyncGenerator
+import openai
+import google.generativeai as genai
+from app.schemas import LLMConfig, Message, Role, Context
+
+# Simple in-memory cache for clients to avoid re-initializing constantly
+# In a real app, use dependency injection or singletons
+_openai_client = None
+
+def get_openai_client(api_key: str = None):
+    global _openai_client
+    key = api_key or os.getenv("OPENAI_API_KEY")
+    if not key:
+        raise ValueError("OpenAI API Key not found")
+    if not _openai_client:
+        _openai_client = openai.AsyncOpenAI(api_key=key)
+    return _openai_client
+
+def configure_google(api_key: str = None):
+    key = api_key or os.getenv("GOOGLE_API_KEY")
+    if not key:
+        raise ValueError("Google API Key not found")
+    genai.configure(api_key=key)
+
+async def stream_openai(messages: list[Message], config: LLMConfig) -> AsyncGenerator[str, None]:
+    client = get_openai_client(config.api_key)
+
+    # Convert internal Message schema to OpenAI format
+    openai_messages = []
+    if config.system_prompt:
+        openai_messages.append({"role": "system", "content": config.system_prompt})
+
+    for msg in messages:
+        openai_messages.append({"role": msg.role.value, "content": msg.content})
+
+    stream = await client.chat.completions.create(
+        model=config.model_name,
+        messages=openai_messages,
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        stream=True
+    )
+
+    async for chunk in stream:
+        if chunk.choices[0].delta.content:
+            yield chunk.choices[0].delta.content
+
+async def stream_google(messages: list[Message], config: LLMConfig) -> AsyncGenerator[str, None]:
+    configure_google(config.api_key)
+    model = genai.GenerativeModel(config.model_name)
+
+    # Google Generative AI history format:
+    # [{"role": "user", "parts": ["..."]}, {"role": "model", "parts": ["..."]}]
+    # System prompt is usually set on model init or prepended.
+
+    history = []
+    # If system prompt exists, we might prepend it to the first user message or use specific system instruction if supported
+    # Gemini 1.5 Pro supports system instructions. For simplicity, let's prepend to history if possible or context.
+
+    system_instruction = config.system_prompt
+    if system_instruction:
+        model = genai.GenerativeModel(config.model_name, system_instruction=system_instruction)
+
+    # Convert messages
+    # Note: Gemini strictly requires user/model alternation in history usually.
+    # We will need to handle this. For MVP, we assume the input is clean or we blindly map.
+    for msg in messages:
+        role = "user" if msg.role == Role.USER else "model"
+        history.append({"role": role, "parts": [msg.content]})
+
+    # The last message should be the prompt, strictly speaking, `chat.send_message` takes the new message
+    # But if we are treating everything as history...
+    # Let's separate the last user message as the prompt if possible.
+
+    if history and history[-1]["role"] == "user":
+        last_msg = history.pop()
+        chat = model.start_chat(history=history)
+        response_stream = await chat.send_message_async(last_msg["parts"][0], stream=True)
+    else:
+        # If the last message is not user, we might be in a weird state.
+        # Just send an empty prompt or handle error?
+        # For now, assume the user always provides a prompt in the node.
+        chat = model.start_chat(history=history)
+        response_stream = await chat.send_message_async("...", stream=True) # Fallback
+
+    async for chunk in response_stream:
+        if chunk.text:
+            yield chunk.text
+
+async def llm_streamer(context: Context, user_prompt: str, config: LLMConfig) -> AsyncGenerator[str, None]:
+    # 1. Merge Context + New User Prompt
+    # We create a temporary list of messages for this inference
+    messages_to_send = context.messages.copy()
+
+    # If user_prompt is provided (it should be for a Question Block)
+    if user_prompt.strip():
+        messages_to_send.append(Message(
+            id="temp_user_prompt", # ID doesn't matter for the API call
+            role=Role.USER,
+            content=user_prompt
+        ))
+
+    # 2. Call Provider
+    try:
+        if config.provider == "openai":
+            async for chunk in stream_openai(messages_to_send, config):
+                yield chunk
+        elif config.provider == "google":
+            async for chunk in stream_google(messages_to_send, config):
+                yield chunk
+        else:
+            yield f"Error: Unsupported provider {config.provider}"
+    except Exception as e:
+        yield f"Error calling LLM: {str(e)}"
+
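
As a quick smoke test, llm_streamer can also be driven directly, outside FastAPI; this is a minimal sketch that assumes OPENAI_API_KEY is set in the environment and uses a placeholder model name.

# Illustration only: model name and prompt are placeholders, not part of the diff.
import asyncio
from app.schemas import Context, LLMConfig, ModelProvider
from app.services.llm import llm_streamer

async def main():
    config = LLMConfig(provider=ModelProvider.OPENAI, model_name="gpt-4o-mini")
    # Empty context: only the user prompt is sent to the provider.
    async for chunk in llm_streamer(Context(), "Say hello in one sentence.", config):
        print(chunk, end="", flush=True)

asyncio.run(main())
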
diff --git a/backend/requirements.txt b/backend/requirements.txt
new file mode 100644
index 0000000..545f6b7
--- /dev/null
+++ b/backend/requirements.txt
@@ -0,0 +1,8 @@
+fastapi
+uvicorn
+pydantic
+openai
+google-generativeai
+python-dotenv
+httpx
+
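
Since httpx is already listed in requirements.txt, it can double as a simple client for manually exercising the streaming endpoint. The sketch below assumes the backend is served locally on uvicorn's default port; the node ID and model name are placeholders.

# Illustration only: start the backend first, e.g. `uvicorn app.main:app` from backend/.
import httpx

payload = {
    "node_id": "node-1",
    "incoming_contexts": [],
    "user_prompt": "Summarize the project in one sentence.",
    "merge_strategy": "smart",
    "config": {
        "provider": "openai",
        "model_name": "gpt-4o-mini",  # placeholder model name
        "temperature": 0.7,
        "max_tokens": 1000,
    },
}

with httpx.stream("POST", "http://127.0.0.1:8000/api/run_node_stream", json=payload, timeout=None) as response:
    for chunk in response.iter_text():
        print(chunk, end="", flush=True)  # tokens arrive incrementally
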
