author	blackhao <13851610112@163.com>	2025-12-10 21:32:46 -0600
committer	blackhao <13851610112@163.com>	2025-12-10 21:32:46 -0600
commit	b51af40db8ae73216789472903298589bb54df2b (patch)
tree	c172c0ebdc66ee805fb51663ec921f89ac547186
parent	718c7f50992656a97434ce5041e716145ec3a5c8 (diff)
set keys 2
-rw-r--r--	backend/app/main.py	57
-rw-r--r--	backend/app/services/llm.py	22
2 files changed, 59 insertions, 20 deletions
diff --git a/backend/app/main.py b/backend/app/main.py
index 902d693..c254652 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -8,7 +8,7 @@ from fastapi import UploadFile, File, Form
from pydantic import BaseModel
from app.schemas import NodeRunRequest, NodeRunResponse, MergeStrategy, Role, Message, Context, LLMConfig, ModelProvider, ReasoningEffort
from app.services.llm import llm_streamer, generate_title, get_openai_client
-from app.auth import auth_router, get_current_user, init_db, User, get_db
+from app.auth import auth_router, get_current_user, get_current_user_optional, init_db, User, get_db
from app.auth.utils import get_password_hash
from dotenv import load_dotenv
import os
@@ -67,6 +67,23 @@ VALID_FILE_PROVIDERS = {"local", "openai", "google"}
OPENAI_MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB limit per OpenAI docs
OPENAI_DEFAULT_FILE_PURPOSE = os.getenv("OPENAI_FILE_PURPOSE", "user_data")
+def get_user_api_key(user: User | None, provider: str) -> str | None:
+ """
+ Get API key for a provider from user's saved settings.
+ Falls back to environment variable if user has no key set.
+ """
+ if user:
+ if provider == "openai" and user.openai_api_key:
+ return user.openai_api_key
+ if provider in ("google", "gemini") and user.gemini_api_key:
+ return user.gemini_api_key
+ # Fallback to environment variables
+ if provider == "openai":
+ return os.getenv("OPENAI_API_KEY")
+ if provider in ("google", "gemini"):
+ return os.getenv("GOOGLE_API_KEY")
+ return None
+
def ensure_user_root(user: str) -> str:
"""
Ensures the new data root structure:
@@ -248,10 +265,23 @@ def smart_merge_messages(messages: list[Message]) -> list[Message]:
return merged
@app.post("/api/run_node_stream")
-async def run_node_stream(request: NodeRunRequest):
+async def run_node_stream(
+ request: NodeRunRequest,
+ current_user: User | None = Depends(get_current_user_optional)
+):
"""
Stream the response from the LLM.
"""
+ # Get API key from user settings if not provided in request
+ provider_name = request.config.provider.value if hasattr(request.config.provider, 'value') else str(request.config.provider)
+ if not request.config.api_key:
+ user_key = get_user_api_key(current_user, provider_name.lower())
+ if user_key:
+ request.config.api_key = user_key
+
+ # Get username for file operations
+ username = current_user.username if current_user else DEFAULT_USER
+
# 1. Concatenate all incoming contexts first
raw_messages = []
for ctx in request.incoming_contexts:
@@ -272,7 +302,7 @@ async def run_node_stream(request: NodeRunRequest):
if request.config.provider == ModelProvider.OPENAI:
vs_ids, debug_refs, filters = await prepare_openai_vector_search(
- user=DEFAULT_USER,
+ user=username,
attached_ids=request.attached_file_ids,
scopes=request.scopes,
llm_config=request.config,
@@ -283,7 +313,7 @@ async def run_node_stream(request: NodeRunRequest):
# Try to get user's vector store anyway
try:
client = get_openai_client(request.config.api_key)
- vs_id = await ensure_user_vector_store(DEFAULT_USER, client)
+ vs_id = await ensure_user_vector_store(username, client)
if vs_id:
vs_ids = [vs_id]
except Exception as e:
@@ -297,7 +327,7 @@ async def run_node_stream(request: NodeRunRequest):
print(f"[openai file_search] vs_ids={vs_ids} refs={debug_refs} filters={filters}")
elif request.config.provider == ModelProvider.GOOGLE:
attachments = await prepare_attachments(
- user=DEFAULT_USER,
+ user=username,
target_provider=request.config.provider,
attached_ids=request.attached_file_ids,
llm_config=request.config,
@@ -316,12 +346,16 @@ class TitleResponse(BaseModel):
title: str
@app.post("/api/generate_title", response_model=TitleResponse)
-async def generate_title_endpoint(request: TitleRequest):
+async def generate_title_endpoint(
+ request: TitleRequest,
+ current_user: User | None = Depends(get_current_user_optional)
+):
"""
Generate a short title for a Q-A pair using gpt-5-nano.
Returns 3-4 short English words summarizing the topic.
"""
- title = await generate_title(request.user_prompt, request.response)
+ api_key = get_user_api_key(current_user, "openai")
+ title = await generate_title(request.user_prompt, request.response, api_key)
return TitleResponse(title=title)
@@ -333,12 +367,17 @@ class SummarizeResponse(BaseModel):
summary: str
@app.post("/api/summarize", response_model=SummarizeResponse)
-async def summarize_endpoint(request: SummarizeRequest):
+async def summarize_endpoint(
+ request: SummarizeRequest,
+ current_user: User | None = Depends(get_current_user_optional)
+):
"""
Summarize the given content using the specified model.
"""
from app.services.llm import summarize_content
- summary = await summarize_content(request.content, request.model)
+ openai_key = get_user_api_key(current_user, "openai")
+ gemini_key = get_user_api_key(current_user, "gemini")
+ summary = await summarize_content(request.content, request.model, openai_key, gemini_key)
return SummarizeResponse(summary=summary)
# ---------------- Project / Blueprint APIs ----------------
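
The get_user_api_key helper added above resolves keys in a fixed order: a key saved on the authenticated user wins, then the matching environment variable, then None. A minimal sketch of that precedence, using a hypothetical UserStub in place of the real app.auth.User model:

# Sketch only: UserStub is a hypothetical stand-in for app.auth.User,
# kept here so the precedence can be exercised without a database.
import os
from dataclasses import dataclass

@dataclass
class UserStub:
    openai_api_key: str | None = None
    gemini_api_key: str | None = None

def resolve_api_key(user: UserStub | None, provider: str) -> str | None:
    # 1) prefer a key saved in the user's settings
    if user:
        if provider == "openai" and user.openai_api_key:
            return user.openai_api_key
        if provider in ("google", "gemini") and user.gemini_api_key:
            return user.gemini_api_key
    # 2) fall back to the server-wide environment variables
    if provider == "openai":
        return os.getenv("OPENAI_API_KEY")
    if provider in ("google", "gemini"):
        return os.getenv("GOOGLE_API_KEY")
    # 3) unknown provider: nothing to offer
    return None

# Anonymous request: falls back to whatever OPENAI_API_KEY holds (possibly None).
assert resolve_api_key(None, "openai") == os.getenv("OPENAI_API_KEY")
# A saved user key takes precedence over the environment.
assert resolve_api_key(UserStub(openai_api_key="sk-user"), "openai") == "sk-user"

In run_node_stream the same order gets one extra step in front: a key supplied in request.config.api_key is used as-is, and the helper is only consulted when that field is empty.
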
diff --git a/backend/app/services/llm.py b/backend/app/services/llm.py
index 96b0514..87b9155 100644
--- a/backend/app/services/llm.py
+++ b/backend/app/services/llm.py
@@ -4,18 +4,18 @@ import openai
import google.generativeai as genai
from app.schemas import LLMConfig, Message, Role, Context
-# Simple in-memory cache for clients to avoid re-initializing constantly
+# Cache OpenAI clients by API key to avoid re-initializing constantly
# In a real app, use dependency injection or singletons
-_openai_client = None
+_openai_clients: dict[str, openai.AsyncOpenAI] = {}
def get_openai_client(api_key: str = None):
- global _openai_client
+ global _openai_clients
key = api_key or os.getenv("OPENAI_API_KEY")
if not key:
raise ValueError("OpenAI API Key not found")
- if not _openai_client:
- _openai_client = openai.AsyncOpenAI(api_key=key)
- return _openai_client
+ if key not in _openai_clients:
+ _openai_clients[key] = openai.AsyncOpenAI(api_key=key)
+ return _openai_clients[key]
def configure_google(api_key: str = None):
key = api_key or os.getenv("GOOGLE_API_KEY")
@@ -345,12 +345,12 @@ async def llm_streamer(
yield f"Error calling LLM: {str(e)}"
-async def generate_title(user_prompt: str, response: str) -> str:
+async def generate_title(user_prompt: str, response: str, api_key: str = None) -> str:
"""
Generate a short title (3-4 words) for a Q-A pair using gpt-5-nano.
Uses Responses API (required for gpt-5 series), synchronous mode (no background).
"""
- client = get_openai_client()
+ client = get_openai_client(api_key)
instructions = """TASK: Extract a short topic title from the given Q&A. Do NOT answer the question - only extract the topic.
@@ -413,7 +413,7 @@ Q: "What's the weather in NYC?" -> "NYC Weather\""""
return "New Question"
-async def summarize_content(content: str, model: str) -> str:
+async def summarize_content(content: str, model: str, openai_api_key: str = None, gemini_api_key: str = None) -> str:
"""
Summarize the given content using the specified model.
Supports both OpenAI and Gemini models.
@@ -434,7 +434,7 @@ Output only the summary, no preamble."""
from google.genai import types
import os
- key = os.getenv("GOOGLE_API_KEY")
+ key = gemini_api_key or os.getenv("GOOGLE_API_KEY")
if not key:
return "Error: Google API Key not found"
@@ -456,7 +456,7 @@ Output only the summary, no preamble."""
else:
# Use OpenAI
- client = get_openai_client()
+ client = get_openai_client(openai_api_key)
# Check if model needs Responses API
responses_api_models = [
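
With the client cache in llm.py now keyed by API key, requests carrying different user keys get separate AsyncOpenAI instances instead of whichever key happened to initialize the old process-wide singleton first. A rough illustration of the cache behaviour; the keys shown are placeholders and no request is actually sent:

# Illustration only: the dict-based cache hands out one client per distinct
# key and reuses it on later calls. AsyncOpenAI does not contact the API at
# construction time, so nothing here hits the network.
from app.services.llm import get_openai_client

client_a = get_openai_client("sk-user-a")   # placeholder key
client_b = get_openai_client("sk-user-b")   # placeholder key

assert client_a is not client_b                      # different keys, different clients
assert get_openai_client("sk-user-a") is client_a    # same key, cached instance
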