path: root/backend/app/main.py
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.schemas import NodeRunRequest, MergeStrategy, Message, Context
from app.services.llm import llm_streamer, generate_title, summarize_content
from dotenv import load_dotenv

load_dotenv()

app = FastAPI(title="ContextFlow Backend")

# Wide-open CORS (any origin, with credentials) is convenient for development;
# tighten allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
def read_root():
    return {"message": "ContextFlow Backend is running"}

def smart_merge_messages(messages: list[Message]) -> list[Message]:
    """
    Merges messages using two steps:
    1. Deduplication by ID (to handle diamond dependencies).
    2. Merging consecutive messages from the same role.
    """
    if not messages:
        return []

    # 1. Deduplicate by ID, keeping order
    seen_ids = set()
    deduplicated = []
    for msg in messages:
        if msg.id not in seen_ids:
            deduplicated.append(msg)
            seen_ids.add(msg.id)
    
    # 2. Merge consecutive messages from the same role.
    # (deduplicated is guaranteed non-empty here: messages was non-empty.)
    merged = []
    current_msg = deduplicated[0].model_copy()
    
    for next_msg in deduplicated[1:]:
        if next_msg.role == current_msg.role:
            # Merge content
            current_msg.content += f"\n\n{next_msg.content}"
            # Keep the latest timestamp
            current_msg.timestamp = next_msg.timestamp
        else:
            merged.append(current_msg)
            current_msg = next_msg.model_copy()
            
    merged.append(current_msg)
    return merged
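
# Illustrative behavior (not executed; assumes Message is a pydantic model
# exposing id, role, content, and timestamp, as used above; the field values
# below are made up):
#
#   msgs = [
#       Message(id="m1", role="user", content="Hello", timestamp=t1),
#       Message(id="m1", role="user", content="Hello", timestamp=t1),  # diamond duplicate
#       Message(id="m2", role="user", content="More context", timestamp=t2),
#       Message(id="m3", role="assistant", content="Hi!", timestamp=t3),
#   ]
#   smart_merge_messages(msgs)
#   # -> [user "Hello\n\nMore context" (timestamp t2), assistant "Hi!"]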

@app.post("/api/run_node_stream")
async def run_node_stream(request: NodeRunRequest):
    """
    Stream the response from the LLM.
    """
    # 1. Concatenate all incoming contexts first
    raw_messages = []
    for ctx in request.incoming_contexts:
        raw_messages.extend(ctx.messages)
    
    # 2. Apply the requested merge strategy
    if request.merge_strategy == MergeStrategy.SMART:
        final_messages = smart_merge_messages(raw_messages)
    else:
        # RAW strategy: keep the concatenated messages as-is
        final_messages = raw_messages

    execution_context = Context(messages=final_messages)

    return StreamingResponse(
        llm_streamer(execution_context, request.user_prompt, request.config),
        media_type="text/event-stream"
    )
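
# Example request (illustrative; the exact JSON shapes of Context, Message,
# and LLMConfig live in app.schemas and may differ from this sketch):
#
#   curl -N -X POST http://localhost:8000/api/run_node_stream \
#     -H "Content-Type: application/json" \
#     -d '{"incoming_contexts": [{"messages": [...]}],
#          "merge_strategy": "smart",
#          "user_prompt": "Summarize the discussion",
#          "config": {...}}'
#
# The response body is a text/event-stream produced by llm_streamer.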

class TitleRequest(BaseModel):
    user_prompt: str
    response: str

class TitleResponse(BaseModel):
    title: str

@app.post("/api/generate_title", response_model=TitleResponse)
async def generate_title_endpoint(request: TitleRequest):
    """
    Generate a short title for a Q-A pair using gpt-5-nano.
    Returns 3-4 short English words summarizing the topic.
    """
    title = await generate_title(request.user_prompt, request.response)
    return TitleResponse(title=title)
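
# Example request (illustrative):
#
#   curl -X POST http://localhost:8000/api/generate_title \
#     -H "Content-Type: application/json" \
#     -d '{"user_prompt": "What is CORS?", "response": "CORS is..."}'
#
#   -> {"title": "..."}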


class SummarizeRequest(BaseModel):
    content: str
    model: str  # Model to use for summarization

class SummarizeResponse(BaseModel):
    summary: str

@app.post("/api/summarize", response_model=SummarizeResponse)
async def summarize_endpoint(request: SummarizeRequest):
    """
    Summarize the given content using the specified model.
    """
    summary = await summarize_content(request.content, request.model)
    return SummarizeResponse(summary=summary)
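
# Example request (illustrative; "model" is whatever identifier
# summarize_content expects):
#
#   curl -X POST http://localhost:8000/api/summarize \
#     -H "Content-Type: application/json" \
#     -d '{"content": "Long text to condense...", "model": "gpt-5-nano"}'
#
#   -> {"summary": "..."}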