Diffstat (limited to 'backend')
-rw-r--r--   backend/app/main.py          |  85
-rw-r--r--   backend/app/schemas.py       |  52
-rw-r--r--   backend/app/services/llm.py  | 116
-rw-r--r--   backend/requirements.txt     |   8
4 files changed, 261 insertions, 0 deletions
diff --git a/backend/app/main.py b/backend/app/main.py
new file mode 100644
index 0000000..48cb89f
--- /dev/null
+++ b/backend/app/main.py
@@ -0,0 +1,85 @@
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
+from app.schemas import NodeRunRequest, NodeRunResponse, MergeStrategy, Role, Message, Context
+from app.services.llm import llm_streamer
+from dotenv import load_dotenv
+import os
+
+load_dotenv()
+
+app = FastAPI(title="ContextFlow Backend")
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+@app.get("/")
+def read_root():
+ return {"message": "ContextFlow Backend is running"}
+
+def smart_merge_messages(messages: list[Message]) -> list[Message]:
+ """
+ Merges messages using two steps:
+ 1. Deduplication by ID (to handle diamond dependencies).
+ 2. Merging consecutive messages from the same role.
+ """
+ if not messages:
+ return []
+
+ # 1. Deduplicate by ID, keeping order
+ seen_ids = set()
+ deduplicated = []
+ for msg in messages:
+ if msg.id not in seen_ids:
+ deduplicated.append(msg)
+ seen_ids.add(msg.id)
+
+ # 2. Merge consecutive roles
+ if not deduplicated:
+ return []
+
+ merged = []
+ current_msg = deduplicated[0].model_copy()
+
+ for next_msg in deduplicated[1:]:
+ if next_msg.role == current_msg.role:
+ # Merge content
+ current_msg.content += f"\n\n{next_msg.content}"
+ # Keep the latest timestamp
+ current_msg.timestamp = next_msg.timestamp
+ else:
+ merged.append(current_msg)
+ current_msg = next_msg.model_copy()
+
+ merged.append(current_msg)
+ return merged
+
+@app.post("/api/run_node_stream")
+async def run_node_stream(request: NodeRunRequest):
+ """
+ Stream the response from the LLM.
+ """
+ # 1. Concatenate all incoming contexts first
+ raw_messages = []
+ for ctx in request.incoming_contexts:
+ raw_messages.extend(ctx.messages)
+
+ # 2. Apply Merge Strategy
+ final_messages = []
+ if request.merge_strategy == MergeStrategy.SMART:
+ final_messages = smart_merge_messages(raw_messages)
+ else:
+ # RAW strategy: just keep them as is
+ final_messages = raw_messages
+
+ execution_context = Context(messages=final_messages)
+
+ return StreamingResponse(
+ llm_streamer(execution_context, request.user_prompt, request.config),
+ media_type="text/event-stream"
+ )
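
The smart merge is easiest to see on a diamond-shaped graph, where two branches carry the same ancestor message. A minimal sketch of the expected behaviour (not part of this commit), assuming the backend directory is on the import path so the app package resolves; the ids and contents are made up:

    from app.main import smart_merge_messages
    from app.schemas import Message, Role

    shared = Message(id="m1", role=Role.USER, content="Summarise the report.")
    branch_a = Message(id="m2", role=Role.ASSISTANT, content="Summary from branch A.")
    branch_b = Message(id="m3", role=Role.ASSISTANT, content="Summary from branch B.")

    # Both branches carry the shared ancestor, so it appears twice after concatenation.
    merged = smart_merge_messages([shared, branch_a, shared, branch_b])

    assert [m.id for m in merged] == ["m1", "m2"]   # duplicate "m1" dropped, "m3" folded into "m2"
    assert "branch B" in merged[1].content          # consecutive assistant replies joined by a blank line

Deduplication keeps the first occurrence of each id, and the consecutive-role pass concatenates same-role content with a blank line while keeping the later timestamp.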
diff --git a/backend/app/schemas.py b/backend/app/schemas.py
new file mode 100644
index 0000000..ac90bc1
--- /dev/null
+++ b/backend/app/schemas.py
@@ -0,0 +1,52 @@
+from pydantic import BaseModel, Field
+from typing import List, Optional, Dict, Any, Union
+from enum import Enum
+import time
+
+class Role(str, Enum):
+ USER = "user"
+ ASSISTANT = "assistant"
+ SYSTEM = "system"
+
+class Message(BaseModel):
+ id: str = Field(..., description="Unique ID for the message")
+ role: Role
+ content: str
+ timestamp: float = Field(default_factory=time.time)
+ # Metadata to track where this message came from
+ source_node_id: Optional[str] = None
+ model_used: Optional[str] = None
+
+class Context(BaseModel):
+ messages: List[Message] = []
+
+class ModelProvider(str, Enum):
+ OPENAI = "openai"
+ GOOGLE = "google"
+
+class LLMConfig(BaseModel):
+ provider: ModelProvider
+ model_name: str
+ temperature: float = 0.7
+ max_tokens: int = 1000
+ system_prompt: Optional[str] = None
+ api_key: Optional[str] = None # Optional override, usually from env
+
+class MergeStrategy(str, Enum):
+ RAW = "raw"
+ SMART = "smart"
+
+class NodeRunRequest(BaseModel):
+ node_id: str
+ incoming_contexts: List[Context] = []
+ user_prompt: str
+ config: LLMConfig
+ merge_strategy: MergeStrategy = MergeStrategy.SMART
+
+class NodeRunResponse(BaseModel):
+ node_id: str
+ output_context: Context
+ response_content: str
+ raw_response: Optional[Dict[str, Any]] = None
+ usage: Optional[Dict[str, Any]] = None
+
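For reference, a request body that these models accept when POSTed to /api/run_node_stream could look like the sketch below (not part of this commit); the model name and other values are placeholders:

    from app.schemas import NodeRunRequest

    payload = {
        "node_id": "node-42",
        "incoming_contexts": [
            {"messages": [
                {"id": "m1", "role": "user", "content": "Hello"},
                {"id": "m2", "role": "assistant", "content": "Hi there"},
            ]}
        ],
        "user_prompt": "Continue the conversation.",
        "config": {
            "provider": "openai",
            "model_name": "gpt-4o-mini",   # placeholder model name
            "temperature": 0.2,
            "max_tokens": 512,
            "system_prompt": "You are a helpful assistant.",
        },
        "merge_strategy": "smart",
    }

    request = NodeRunRequest.model_validate(payload)  # raises pydantic.ValidationError if malformed
    print(request.config.provider)                    # -> ModelProvider.OPENAI

Omitted optional fields (timestamp, source_node_id, api_key) fall back to their defaults, so clients only need to send what they override.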
diff --git a/backend/app/services/llm.py b/backend/app/services/llm.py
new file mode 100644
index 0000000..958ab4c
--- /dev/null
+++ b/backend/app/services/llm.py
@@ -0,0 +1,116 @@
+import os
+from typing import AsyncGenerator
+import openai
+import google.generativeai as genai
+from app.schemas import LLMConfig, Message, Role, Context
+
+# Simple in-memory cache for clients to avoid re-initializing constantly
+# In a real app, use dependency injection or singletons
+_openai_client = None
+
+def get_openai_client(api_key: str = None):
+ global _openai_client
+ key = api_key or os.getenv("OPENAI_API_KEY")
+ if not key:
+ raise ValueError("OpenAI API Key not found")
+ if not _openai_client:
+ _openai_client = openai.AsyncOpenAI(api_key=key)
+ return _openai_client
+
+def configure_google(api_key: str = None):
+ key = api_key or os.getenv("GOOGLE_API_KEY")
+ if not key:
+ raise ValueError("Google API Key not found")
+ genai.configure(api_key=key)
+
+async def stream_openai(messages: list[Message], config: LLMConfig) -> AsyncGenerator[str, None]:
+ client = get_openai_client(config.api_key)
+
+ # Convert internal Message schema to OpenAI format
+ openai_messages = []
+ if config.system_prompt:
+ openai_messages.append({"role": "system", "content": config.system_prompt})
+
+ for msg in messages:
+ openai_messages.append({"role": msg.role.value, "content": msg.content})
+
+ stream = await client.chat.completions.create(
+ model=config.model_name,
+ messages=openai_messages,
+ temperature=config.temperature,
+ max_tokens=config.max_tokens,
+ stream=True
+ )
+
+ async for chunk in stream:
+        if chunk.choices and chunk.choices[0].delta.content:  # guard against chunks with no choices
+ yield chunk.choices[0].delta.content
+
+async def stream_google(messages: list[Message], config: LLMConfig) -> AsyncGenerator[str, None]:
+ configure_google(config.api_key)
+ model = genai.GenerativeModel(config.model_name)
+
+ # Google Generative AI history format:
+ # [{"role": "user", "parts": ["..."]}, {"role": "model", "parts": ["..."]}]
+ # System prompt is usually set on model init or prepended.
+
+ history = []
+    # If a system prompt is provided, re-create the model with a system_instruction
+    # (supported by recent Gemini models) instead of prepending it to the history.
+
+ system_instruction = config.system_prompt
+ if system_instruction:
+ model = genai.GenerativeModel(config.model_name, system_instruction=system_instruction)
+
+ # Convert messages
+    # Note: Gemini expects user/model roles to alternate in the history. For the MVP
+    # we assume the merged context is well-formed and map any non-user role to "model".
+ for msg in messages:
+ role = "user" if msg.role == Role.USER else "model"
+ history.append({"role": role, "parts": [msg.content]})
+
+    # `chat.send_message` expects the new message separately from the history, so if
+    # the final entry is a user message, pop it off and send it as the prompt;
+    # everything before it becomes the chat history.
+
+ if history and history[-1]["role"] == "user":
+ last_msg = history.pop()
+ chat = model.start_chat(history=history)
+ response_stream = await chat.send_message_async(last_msg["parts"][0], stream=True)
+ else:
+        # The history does not end with a user message, which should not happen when
+        # the node supplies a prompt. Send a placeholder so the call still succeeds
+        # instead of raising.
+ chat = model.start_chat(history=history)
+ response_stream = await chat.send_message_async("...", stream=True) # Fallback
+
+ async for chunk in response_stream:
+ if chunk.text:
+ yield chunk.text
+
+async def llm_streamer(context: Context, user_prompt: str, config: LLMConfig) -> AsyncGenerator[str, None]:
+ # 1. Merge Context + New User Prompt
+ # We create a temporary list of messages for this inference
+ messages_to_send = context.messages.copy()
+
+ # If user_prompt is provided (it should be for a Question Block)
+ if user_prompt.strip():
+ messages_to_send.append(Message(
+ id="temp_user_prompt", # ID doesn't matter for the API call
+ role=Role.USER,
+ content=user_prompt
+ ))
+
+ # 2. Call Provider
+ try:
+ if config.provider == "openai":
+ async for chunk in stream_openai(messages_to_send, config):
+ yield chunk
+ elif config.provider == "google":
+ async for chunk in stream_google(messages_to_send, config):
+ yield chunk
+ else:
+ yield f"Error: Unsupported provider {config.provider}"
+ except Exception as e:
+ yield f"Error calling LLM: {str(e)}"
+
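Because llm_streamer yields plain text chunks (the response declares text/event-stream, but the chunks are not SSE-framed with data: lines), a client can simply read the body incrementally as it arrives. A rough consumer sketch using httpx, which is already in requirements.txt; the URL, port, and payload values are assumptions:

    import asyncio
    import httpx

    async def main() -> None:
        payload = {
            "node_id": "node-42",
            "incoming_contexts": [],
            "user_prompt": "Say hello.",
            "config": {"provider": "openai", "model_name": "gpt-4o-mini"},  # placeholder model name
        }
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("POST", "http://localhost:8000/api/run_node_stream", json=payload) as resp:
                async for chunk in resp.aiter_text():
                    print(chunk, end="", flush=True)

    asyncio.run(main())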
diff --git a/backend/requirements.txt b/backend/requirements.txt
new file mode 100644
index 0000000..545f6b7
--- /dev/null
+++ b/backend/requirements.txt
@@ -0,0 +1,8 @@
+fastapi
+uvicorn
+pydantic
+openai
+google-generativeai
+python-dotenv
+httpx
+
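With this layout, a typical local run would be pip install -r requirements.txt followed by uvicorn app.main:app --reload from the backend directory, with OPENAI_API_KEY and/or GOOGLE_API_KEY set in the environment or in a .env file picked up by load_dotenv().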