| author | blackhao <13851610112@163.com> | 2025-12-10 21:32:46 -0600 |
|---|---|---|
| committer | blackhao <13851610112@163.com> | 2025-12-10 21:32:46 -0600 |
| commit | b51af40db8ae73216789472903298589bb54df2b | |
| tree | c172c0ebdc66ee805fb51663ec921f89ac547186 /backend/app | |
| parent | 718c7f50992656a97434ce5041e716145ec3a5c8 | |
set keys 2
Diffstat (limited to 'backend/app')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | backend/app/main.py | 57 |
| -rw-r--r-- | backend/app/services/llm.py | 22 |
2 files changed, 59 insertions, 20 deletions
```diff
diff --git a/backend/app/main.py b/backend/app/main.py
index 902d693..c254652 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -8,7 +8,7 @@ from fastapi import UploadFile, File, Form
 from pydantic import BaseModel
 from app.schemas import NodeRunRequest, NodeRunResponse, MergeStrategy, Role, Message, Context, LLMConfig, ModelProvider, ReasoningEffort
 from app.services.llm import llm_streamer, generate_title, get_openai_client
-from app.auth import auth_router, get_current_user, init_db, User, get_db
+from app.auth import auth_router, get_current_user, get_current_user_optional, init_db, User, get_db
 from app.auth.utils import get_password_hash
 from dotenv import load_dotenv
 import os
@@ -67,6 +67,23 @@ VALID_FILE_PROVIDERS = {"local", "openai", "google"}
 OPENAI_MAX_FILE_SIZE = 50 * 1024 * 1024  # 50MB limit per OpenAI docs
 OPENAI_DEFAULT_FILE_PURPOSE = os.getenv("OPENAI_FILE_PURPOSE", "user_data")
 
+def get_user_api_key(user: User | None, provider: str) -> str | None:
+    """
+    Get API key for a provider from user's saved settings.
+    Falls back to environment variable if user has no key set.
+    """
+    if user:
+        if provider == "openai" and user.openai_api_key:
+            return user.openai_api_key
+        if provider in ("google", "gemini") and user.gemini_api_key:
+            return user.gemini_api_key
+    # Fallback to environment variables
+    if provider == "openai":
+        return os.getenv("OPENAI_API_KEY")
+    if provider in ("google", "gemini"):
+        return os.getenv("GOOGLE_API_KEY")
+    return None
+
 def ensure_user_root(user: str) -> str:
     """
     Ensures the new data root structure:
@@ -248,10 +265,23 @@ def smart_merge_messages(messages: list[Message]) -> list[Message]:
     return merged
 
 @app.post("/api/run_node_stream")
-async def run_node_stream(request: NodeRunRequest):
+async def run_node_stream(
+    request: NodeRunRequest,
+    current_user: User | None = Depends(get_current_user_optional)
+):
     """
     Stream the response from the LLM.
     """
+    # Get API key from user settings if not provided in request
+    provider_name = request.config.provider.value if hasattr(request.config.provider, 'value') else str(request.config.provider)
+    if not request.config.api_key:
+        user_key = get_user_api_key(current_user, provider_name.lower())
+        if user_key:
+            request.config.api_key = user_key
+
+    # Get username for file operations
+    username = current_user.username if current_user else DEFAULT_USER
+
     # 1. Concatenate all incoming contexts first
     raw_messages = []
     for ctx in request.incoming_contexts:
```
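Note that `get_current_user_optional` is imported from `app.auth` but its definition is not part of this commit. For readers following along, here is a minimal sketch of how such an optional dependency is commonly written in FastAPI; the `decode_token` helper and the SQLAlchemy-style `db.query(...)` lookup are assumptions, not code from this repo:

```python
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer

from app.auth import User, get_db  # both exist per the import hunk above

# auto_error=False makes the scheme resolve to None instead of raising
# 401 when no Authorization header is present.
_oauth2_optional = OAuth2PasswordBearer(tokenUrl="/api/auth/login", auto_error=False)

async def get_current_user_optional(
    token: str | None = Depends(_oauth2_optional),
    db=Depends(get_db),
) -> User | None:
    """Like get_current_user, but returns None for anonymous requests."""
    if token is None:
        return None
    try:
        # decode_token is hypothetical; the real app.auth module presumably
        # validates the JWT and extracts the subject claim.
        payload = decode_token(token)
        return db.query(User).filter(User.username == payload["sub"]).first()
    except Exception:
        # Invalid or expired token: treat the caller as anonymous rather
        # than failing the request, matching the endpoints' None handling.
        return None
```

The key design point, whatever the real implementation looks like, is `auto_error=False`: it is what lets unauthenticated requests reach the handler with `current_user=None` instead of being rejected with 401.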
```diff
@@ -272,7 +302,7 @@ async def run_node_stream(request: NodeRunRequest):
 
     if request.config.provider == ModelProvider.OPENAI:
         vs_ids, debug_refs, filters = await prepare_openai_vector_search(
-            user=DEFAULT_USER,
+            user=username,
             attached_ids=request.attached_file_ids,
             scopes=request.scopes,
             llm_config=request.config,
@@ -283,7 +313,7 @@
         # Try to get user's vector store anyway
         try:
             client = get_openai_client(request.config.api_key)
-            vs_id = await ensure_user_vector_store(DEFAULT_USER, client)
+            vs_id = await ensure_user_vector_store(username, client)
             if vs_id:
                 vs_ids = [vs_id]
         except Exception as e:
@@ -297,7 +327,7 @@
         print(f"[openai file_search] vs_ids={vs_ids} refs={debug_refs} filters={filters}")
     elif request.config.provider == ModelProvider.GOOGLE:
         attachments = await prepare_attachments(
-            user=DEFAULT_USER,
+            user=username,
             target_provider=request.config.provider,
             attached_ids=request.attached_file_ids,
             llm_config=request.config,
@@ -316,12 +346,16 @@
 class TitleResponse(BaseModel):
     title: str
 
 @app.post("/api/generate_title", response_model=TitleResponse)
-async def generate_title_endpoint(request: TitleRequest):
+async def generate_title_endpoint(
+    request: TitleRequest,
+    current_user: User | None = Depends(get_current_user_optional)
+):
     """
     Generate a short title for a Q-A pair using gpt-5-nano.
     Returns 3-4 short English words summarizing the topic.
     """
-    title = await generate_title(request.user_prompt, request.response)
+    api_key = get_user_api_key(current_user, "openai")
+    title = await generate_title(request.user_prompt, request.response, api_key)
     return TitleResponse(title=title)
 
@@ -333,12 +367,17 @@
 class SummarizeResponse(BaseModel):
     summary: str
 
 @app.post("/api/summarize", response_model=SummarizeResponse)
-async def summarize_endpoint(request: SummarizeRequest):
+async def summarize_endpoint(
+    request: SummarizeRequest,
+    current_user: User | None = Depends(get_current_user_optional)
+):
     """
     Summarize the given content using the specified model.
     """
     from app.services.llm import summarize_content
-    summary = await summarize_content(request.content, request.model)
+    openai_key = get_user_api_key(current_user, "openai")
+    gemini_key = get_user_api_key(current_user, "gemini")
+    summary = await summarize_content(request.content, request.model, openai_key, gemini_key)
     return SummarizeResponse(summary=summary)
 
 # ---------------- Project / Blueprint APIs ----------------
```
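The effective key-resolution order after this change is: key in the request body, then the user's saved key, then the server's environment variable. A quick sketch of exercising that from a client, assuming the server runs on localhost:8000 and that the login flow issuing the bearer token exists elsewhere in `app.auth` (the token value here is a placeholder):

```python
import asyncio
import httpx

async def main() -> None:
    # TitleRequest fields taken from the endpoint: user_prompt and response.
    payload = {
        "user_prompt": "What is a vector store?",
        "response": "A vector store indexes embeddings for similarity search.",
    }
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        # Anonymous call: current_user is None, so get_user_api_key()
        # falls back to the server's OPENAI_API_KEY env var.
        r1 = await client.post("/api/generate_title", json=payload)
        # Authenticated call: the openai_api_key saved on the User row
        # takes precedence over the environment variable.
        r2 = await client.post(
            "/api/generate_title",
            json=payload,
            headers={"Authorization": "Bearer <jwt-from-login>"},
        )
        print(r1.json(), r2.json())

asyncio.run(main())
```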
""" from app.services.llm import summarize_content - summary = await summarize_content(request.content, request.model) + openai_key = get_user_api_key(current_user, "openai") + gemini_key = get_user_api_key(current_user, "gemini") + summary = await summarize_content(request.content, request.model, openai_key, gemini_key) return SummarizeResponse(summary=summary) # ---------------- Project / Blueprint APIs ---------------- diff --git a/backend/app/services/llm.py b/backend/app/services/llm.py index 96b0514..87b9155 100644 --- a/backend/app/services/llm.py +++ b/backend/app/services/llm.py @@ -4,18 +4,18 @@ import openai import google.generativeai as genai from app.schemas import LLMConfig, Message, Role, Context -# Simple in-memory cache for clients to avoid re-initializing constantly +# Cache OpenAI clients by API key to avoid re-initializing constantly # In a real app, use dependency injection or singletons -_openai_client = None +_openai_clients: dict[str, openai.AsyncOpenAI] = {} def get_openai_client(api_key: str = None): - global _openai_client + global _openai_clients key = api_key or os.getenv("OPENAI_API_KEY") if not key: raise ValueError("OpenAI API Key not found") - if not _openai_client: - _openai_client = openai.AsyncOpenAI(api_key=key) - return _openai_client + if key not in _openai_clients: + _openai_clients[key] = openai.AsyncOpenAI(api_key=key) + return _openai_clients[key] def configure_google(api_key: str = None): key = api_key or os.getenv("GOOGLE_API_KEY") @@ -345,12 +345,12 @@ async def llm_streamer( yield f"Error calling LLM: {str(e)}" -async def generate_title(user_prompt: str, response: str) -> str: +async def generate_title(user_prompt: str, response: str, api_key: str = None) -> str: """ Generate a short title (3-4 words) for a Q-A pair using gpt-5-nano. Uses Responses API (required for gpt-5 series), synchronous mode (no background). """ - client = get_openai_client() + client = get_openai_client(api_key) instructions = """TASK: Extract a short topic title from the given Q&A. Do NOT answer the question - only extract the topic. @@ -413,7 +413,7 @@ Q: "What's the weather in NYC?" -> "NYC Weather\"""" return "New Question" -async def summarize_content(content: str, model: str) -> str: +async def summarize_content(content: str, model: str, openai_api_key: str = None, gemini_api_key: str = None) -> str: """ Summarize the given content using the specified model. Supports both OpenAI and Gemini models. @@ -434,7 +434,7 @@ Output only the summary, no preamble.""" from google.genai import types import os - key = os.getenv("GOOGLE_API_KEY") + key = gemini_api_key or os.getenv("GOOGLE_API_KEY") if not key: return "Error: Google API Key not found" @@ -456,7 +456,7 @@ Output only the summary, no preamble.""" else: # Use OpenAI - client = get_openai_client() + client = get_openai_client(openai_api_key) # Check if model needs Responses API responses_api_models = [ |
