Diffstat (limited to 'backend/app/main.py')
-rw-r--r--  backend/app/main.py | 1042
1 file changed, 1036 insertions(+), 6 deletions(-)
diff --git a/backend/app/main.py b/backend/app/main.py
index 48cb89f..c254652 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -1,15 +1,31 @@
-from fastapi import FastAPI, HTTPException
+import asyncio
+import tempfile
+import time
+from fastapi import FastAPI, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
-from app.schemas import NodeRunRequest, NodeRunResponse, MergeStrategy, Role, Message, Context
-from app.services.llm import llm_streamer
+from fastapi.responses import StreamingResponse, FileResponse
+from fastapi import UploadFile, File, Form
+from pydantic import BaseModel
+from app.schemas import NodeRunRequest, NodeRunResponse, MergeStrategy, Role, Message, Context, LLMConfig, ModelProvider, ReasoningEffort
+from app.services.llm import llm_streamer, generate_title, get_openai_client
+from app.auth import auth_router, get_current_user, get_current_user_optional, init_db, User, get_db
+from app.auth.utils import get_password_hash
 from dotenv import load_dotenv
 import os
+import json
+import shutil
+from typing import List, Literal, Optional
+from uuid import uuid4
+from google import genai
+from sqlalchemy.orm import Session
 
 load_dotenv()
 
 app = FastAPI(title="ContextFlow Backend")
 
+# Include authentication router
+app.include_router(auth_router)
+
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -18,6 +34,195 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+# Initialize database on startup
+@app.on_event("startup")
+async def startup_event():
+    """Initialize database and create default test user if not exists"""
+    init_db()
+
+    # Create test user if not exists
+    from app.auth.models import SessionLocal
+    db = SessionLocal()
+    try:
+        existing = db.query(User).filter(User.username == "test").first()
+        if not existing:
+            test_user = User(
+                username="test",
+                email="test@contextflow.local",
+                hashed_password=get_password_hash("114514")
+            )
+            db.add(test_user)
+            db.commit()
+            print("[startup] Created default test user (test/114514)")
+        else:
+            print("[startup] Test user already exists")
+    finally:
+        db.close()
+
+# --------- Project / Blueprint storage ---------
+DATA_ROOT = os.path.abspath(os.getenv("DATA_ROOT", os.path.join(os.getcwd(), "data")))
+DEFAULT_USER = "test"
+ARCHIVE_FILENAME = "archived_nodes.json"
+VALID_FILE_PROVIDERS = {"local", "openai", "google"}
+OPENAI_MAX_FILE_SIZE = 50 * 1024 * 1024  # 50MB limit per OpenAI docs
+OPENAI_DEFAULT_FILE_PURPOSE = os.getenv("OPENAI_FILE_PURPOSE", "user_data")
+
+def get_user_api_key(user: User | None, provider: str) -> str | None:
+    """
+    Get API key for a provider from user's saved settings.
+    Falls back to environment variable if user has no key set.
+ """ + if user: + if provider == "openai" and user.openai_api_key: + return user.openai_api_key + if provider in ("google", "gemini") and user.gemini_api_key: + return user.gemini_api_key + # Fallback to environment variables + if provider == "openai": + return os.getenv("OPENAI_API_KEY") + if provider in ("google", "gemini"): + return os.getenv("GOOGLE_API_KEY") + return None + +def ensure_user_root(user: str) -> str: + """ + Ensures the new data root structure: + data/<user>/projects + data/<user>/archive + """ + user_root = os.path.join(DATA_ROOT, user) + projects_root = os.path.join(user_root, "projects") + archive_root = os.path.join(user_root, "archive") + os.makedirs(projects_root, exist_ok=True) + os.makedirs(archive_root, exist_ok=True) + return user_root + + +def projects_root(user: str) -> str: + return os.path.join(ensure_user_root(user), "projects") + + +def archive_root(user: str) -> str: + return os.path.join(ensure_user_root(user), "archive") + + +def files_root(user: str) -> str: + root = os.path.join(ensure_user_root(user), "files") + os.makedirs(root, exist_ok=True) + return root + + +def migrate_legacy_layout(user: str): + """ + Migrate from legacy ./projects/<user> and legacy archive folders to the new data/<user>/ structure. + """ + legacy_root = os.path.abspath(os.path.join(os.getcwd(), "projects", user)) + new_projects = projects_root(user) + if os.path.exists(legacy_root) and not os.listdir(new_projects): + try: + for name in os.listdir(legacy_root): + src = os.path.join(legacy_root, name) + dst = os.path.join(new_projects, name) + if not os.path.exists(dst): + shutil.move(src, dst) + except Exception: + pass + # migrate legacy archive (archived/ or .cf_archived/) + legacy_archives = [ + os.path.join(legacy_root, "archived", ARCHIVE_FILENAME), + os.path.join(legacy_root, ".cf_archived", ARCHIVE_FILENAME), + ] + new_archive_file = archived_path(user) + if not os.path.exists(new_archive_file): + for legacy in legacy_archives: + if os.path.exists(legacy): + os.makedirs(os.path.dirname(new_archive_file), exist_ok=True) + try: + shutil.move(legacy, new_archive_file) + except Exception: + pass + +def safe_path(user: str, relative_path: str) -> str: + root = projects_root(user) + norm = os.path.normpath(relative_path).lstrip(os.sep) + full = os.path.abspath(os.path.join(root, norm)) + if not full.startswith(root): + raise HTTPException(status_code=400, detail="Invalid path") + return full + +class FSItem(BaseModel): + name: str + path: str # path relative to user root + type: Literal["file", "folder"] + size: Optional[int] = None + mtime: Optional[float] = None + children: Optional[List["FSItem"]] = None + +FSItem.model_rebuild() + +def list_tree(user: str, relative_path: str = ".") -> List[FSItem]: + migrate_legacy_layout(user) + root = safe_path(user, relative_path) + items: List[FSItem] = [] + for name in sorted(os.listdir(root)): + full = os.path.join(root, name) + rel = os.path.relpath(full, projects_root(user)) + stat = os.stat(full) + if os.path.isdir(full): + items.append(FSItem( + name=name, + path=rel, + type="folder", + size=None, + mtime=stat.st_mtime, + children=list_tree(user, rel) + )) + else: + items.append(FSItem( + name=name, + path=rel, + type="file", + size=stat.st_size, + mtime=stat.st_mtime, + children=None + )) + return items + +class SaveBlueprintRequest(BaseModel): + user: str = DEFAULT_USER + path: str # relative path including filename.json + content: dict + +class RenameRequest(BaseModel): + user: str = DEFAULT_USER + path: str + new_name: 
+
+class SaveBlueprintRequest(BaseModel):
+    user: str = DEFAULT_USER
+    path: str  # relative path including filename.json
+    content: dict
+
+class RenameRequest(BaseModel):
+    user: str = DEFAULT_USER
+    path: str
+    new_name: Optional[str] = None
+    new_path: Optional[str] = None
+
+class FileMeta(BaseModel):
+    id: str
+    name: str
+    size: int
+    mime: str
+    created_at: float
+    provider: Optional[str] = None
+    provider_file_id: Optional[str] = None
+    openai_file_id: Optional[str] = None
+    openai_vector_store_id: Optional[str] = None
+    # Scopes for filtering: "project_path/node_id" composite keys
+    scopes: List[str] = []
+
+class FolderRequest(BaseModel):
+    user: str = DEFAULT_USER
+    path: str  # relative folder path
+
+class DeleteRequest(BaseModel):
+    user: str = DEFAULT_USER
+    path: str
+    is_folder: bool = False
+
+# -----------------------------------------------
+
 @app.get("/")
 def read_root():
     return {"message": "ContextFlow Backend is running"}
@@ -60,10 +265,23 @@ def smart_merge_messages(messages: list[Message]) -> list[Message]:
     return merged
 
 @app.post("/api/run_node_stream")
-async def run_node_stream(request: NodeRunRequest):
+async def run_node_stream(
+    request: NodeRunRequest,
+    current_user: User | None = Depends(get_current_user_optional)
+):
     """
     Stream the response from the LLM.
     """
+    # Get API key from user settings if not provided in request
+    provider_name = request.config.provider.value if hasattr(request.config.provider, 'value') else str(request.config.provider)
+    if not request.config.api_key:
+        user_key = get_user_api_key(current_user, provider_name.lower())
+        if user_key:
+            request.config.api_key = user_key
+
+    # Get username for file operations
+    username = current_user.username if current_user else DEFAULT_USER
+
     # 1. Concatenate all incoming contexts first
     raw_messages = []
     for ctx in request.incoming_contexts:
@@ -79,7 +297,819 @@ async def run_node_stream(request: NodeRunRequest):
     execution_context = Context(messages=final_messages)
 
+    tools: List[dict] = []
+    attachments: List[dict] = []
+
+    if request.config.provider == ModelProvider.OPENAI:
+        vs_ids, debug_refs, filters = await prepare_openai_vector_search(
+            user=username,
+            attached_ids=request.attached_file_ids,
+            scopes=request.scopes,
+            llm_config=request.config,
+        )
+        # Always enable file_search if vector store exists (even without explicit attachments)
+        # This allows nodes to access files attached in previous nodes of the trace
+        if not vs_ids:
+            # Try to get user's vector store anyway
+            try:
+                client = get_openai_client(request.config.api_key)
+                vs_id = await ensure_user_vector_store(username, client)
+                if vs_id:
+                    vs_ids = [vs_id]
+            except Exception as e:
+                print(f"[warn] Could not get vector store: {e}")
+
+        if vs_ids:
+            tool_def = {"type": "file_search", "vector_store_ids": vs_ids}
+            if filters:
+                tool_def["filters"] = filters
+            tools.append(tool_def)
+            print(f"[openai file_search] vs_ids={vs_ids} refs={debug_refs} filters={filters}")
+    elif request.config.provider == ModelProvider.GOOGLE:
+        attachments = await prepare_attachments(
+            user=username,
+            target_provider=request.config.provider,
+            attached_ids=request.attached_file_ids,
+            llm_config=request.config,
+        )
+
     return StreamingResponse(
-        llm_streamer(execution_context, request.user_prompt, request.config),
+        llm_streamer(execution_context, request.user_prompt, request.config, attachments, tools),
         media_type="text/event-stream"
     )
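
For orientation, a minimal client-side sketch of calling the streaming endpoint above (assumes a local server and the requests package; field names beyond those visible in this diff, such as "model", are assumptions about NodeRunRequest/LLMConfig):

import requests

payload = {
    "user_prompt": "Summarize the attached report.",
    "incoming_contexts": [],                    # no upstream node contexts
    "config": {"provider": "openai", "model": "gpt-4o-mini"},  # field names assumed
    "attached_file_ids": [],                    # ids returned by /api/files/upload
    "scopes": [],
}
with requests.post("http://localhost:8000/api/run_node_stream",
                   json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line:
            print(line.decode())                # event-stream chunks from llm_streamer
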
+
+class TitleRequest(BaseModel):
+    user_prompt: str
+    response: str
+
+class TitleResponse(BaseModel):
+    title: str
+
+@app.post("/api/generate_title", response_model=TitleResponse)
+async def generate_title_endpoint(
+    request: TitleRequest,
+    current_user: User | None = Depends(get_current_user_optional)
+):
+    """
+    Generate a short title for a Q-A pair using gpt-5-nano.
+    Returns 3-4 short English words summarizing the topic.
+    """
+    api_key = get_user_api_key(current_user, "openai")
+    title = await generate_title(request.user_prompt, request.response, api_key)
+    return TitleResponse(title=title)
+
+
+class SummarizeRequest(BaseModel):
+    content: str
+    model: str  # Model to use for summarization
+
+class SummarizeResponse(BaseModel):
+    summary: str
+
+@app.post("/api/summarize", response_model=SummarizeResponse)
+async def summarize_endpoint(
+    request: SummarizeRequest,
+    current_user: User | None = Depends(get_current_user_optional)
+):
+    """
+    Summarize the given content using the specified model.
+    """
+    from app.services.llm import summarize_content
+    openai_key = get_user_api_key(current_user, "openai")
+    gemini_key = get_user_api_key(current_user, "gemini")
+    summary = await summarize_content(request.content, request.model, openai_key, gemini_key)
+    return SummarizeResponse(summary=summary)
+
+# ---------------- Project / Blueprint APIs ----------------
+@app.get("/api/projects/tree", response_model=List[FSItem])
+def get_project_tree(user: str = DEFAULT_USER):
+    """
+    List all files/folders for the user under the projects root.
+    """
+    ensure_user_root(user)
+    return list_tree(user)
+
+
+@app.post("/api/projects/create_folder")
+def create_folder(req: FolderRequest):
+    """
+    Create a folder (and parents) under the user's project root.
+    """
+    try:
+        folder_path = safe_path(req.user, req.path)
+        os.makedirs(folder_path, exist_ok=True)
+        return {"ok": True}
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/api/projects/save_blueprint")
+def save_blueprint(req: SaveBlueprintRequest):
+    """
+    Save a blueprint JSON to disk.
+    """
+    try:
+        full_path = safe_path(req.user, req.path)
+        os.makedirs(os.path.dirname(full_path), exist_ok=True)
+        with open(full_path, "w", encoding="utf-8") as f:
+            json.dump(req.content, f, ensure_ascii=False, indent=2)
+        return {"ok": True}
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/api/projects/file")
+def read_blueprint(user: str = DEFAULT_USER, path: str = ""):
+    """
+    Read a blueprint JSON file.
+    """
+    if not path:
+        raise HTTPException(status_code=400, detail="path is required")
+    full_path = safe_path(user, path)
+    if not os.path.isfile(full_path):
+        raise HTTPException(status_code=404, detail="file not found")
+    try:
+        with open(full_path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+        return {"content": data}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/api/projects/download")
+def download_blueprint(user: str = DEFAULT_USER, path: str = ""):
+    """
+    Download a blueprint file.
+    """
+    if not path:
+        raise HTTPException(status_code=400, detail="path is required")
+    full_path = safe_path(user, path)
+    if not os.path.isfile(full_path):
+        raise HTTPException(status_code=404, detail="file not found")
+    return FileResponse(full_path, filename=os.path.basename(full_path), media_type="application/json")
+
+
+@app.post("/api/projects/rename")
+def rename_item(req: RenameRequest):
+    """
+    Rename or move a file or folder.
+    - If new_path is provided, it is treated as the target relative path (move).
+    - Else, new_name is used within the same directory.
+ """ + try: + src = safe_path(req.user, req.path) + if not os.path.exists(src): + raise HTTPException(status_code=404, detail="source not found") + if req.new_path: + dst = safe_path(req.user, req.new_path) + else: + if not req.new_name: + raise HTTPException(status_code=400, detail="new_name or new_path required") + base_dir = os.path.dirname(src) + dst = os.path.join(base_dir, req.new_name) + # Ensure still inside user root + safe_path(req.user, os.path.relpath(dst, ensure_user_root(req.user))) + os.rename(src, dst) + return {"ok": True} + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/api/projects/delete") +def delete_item(req: DeleteRequest): + """ + Delete a file or folder. + """ + try: + target = safe_path(req.user, req.path) + if not os.path.exists(target): + raise HTTPException(status_code=404, detail="not found") + if os.path.isdir(target): + if not req.is_folder: + # Prevent deleting folder accidentally unless flagged + raise HTTPException(status_code=400, detail="set is_folder=True to delete folder") + shutil.rmtree(target) + else: + os.remove(target) + return {"ok": True} + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) +# ---------------------------------------------------------- + +# --------------- Archived Nodes APIs ---------------------- +def archived_path(user: str) -> str: + root = archive_root(user) + return os.path.join(root, ARCHIVE_FILENAME) + +# ---------------- Files (uploads) ---------------- +def files_index_path(user: str) -> str: + return os.path.join(files_root(user), "index.json") + +def user_vector_store_path(user: str) -> str: + return os.path.join(files_root(user), "vector_store.json") + +async def ensure_user_vector_store(user: str, client=None) -> str: + """ + Ensure there is a vector store for the user (OpenAI). + Persist the id under data/<user>/files/vector_store.json. + """ + path = user_vector_store_path(user) + if client is None: + client = get_openai_client() + + # Try existing cached ID + if os.path.exists(path): + try: + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + vs_id_cached = data.get("id") + if vs_id_cached: + try: + await client.vector_stores.retrieve(vector_store_id=vs_id_cached) + return vs_id_cached + except Exception: + # Possibly deleted; recreate below + pass + except Exception: + pass + + # create new + vs = await client.vector_stores.create(name=f"{user}-vs") + vs_id = getattr(vs, "id", None) + if not vs_id: + raise HTTPException(status_code=500, detail="Failed to create vector store") + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump({"id": vs_id}, f) + return vs_id + +async def ensure_openai_file_and_index(user: str, meta: FileMeta, path: str, llm_config: Optional[LLMConfig] = None) -> tuple[str, str]: + """ + Ensure the file is uploaded to OpenAI Files and added to the user's vector store. + Returns (openai_file_id, vector_store_id). 
+ """ + client = get_openai_client(llm_config.api_key if llm_config else None) + vs_id = await ensure_user_vector_store(user, client) + + file_id = meta.openai_file_id or (meta.provider_file_id if meta.provider == "openai" else None) + if not file_id: + with open(path, "rb") as f: + content = f.read() + resp = await client.files.create( + file=(meta.name or "upload.bin", content), + purpose="assistants", + ) + file_id = getattr(resp, "id", None) + if not file_id: + raise HTTPException(status_code=500, detail="OpenAI file upload returned no file_id") + + await add_file_to_vector_store(vs_id, file_id, client=client) + return file_id, vs_id + +async def remove_file_from_vector_store(vs_id: str, file_id: str, client=None): + if not vs_id or not file_id: + return + if client is None: + client = get_openai_client() + try: + await client.vector_stores.files.delete(vector_store_id=vs_id, file_id=file_id) + except Exception as e: + print(f"[warn] remove_file_from_vector_store failed: {e}") + +async def add_file_to_vector_store(vs_id: str, file_id: str, client=None): + """ + Add a file to vector store with file_id as attribute for filtering. + We use file_id as the attribute so we can filter by specific files at query time. + """ + if client is None: + client = get_openai_client() + + # Use file_id as attribute for filtering + create_params = { + "vector_store_id": vs_id, + "file_id": file_id, + "attributes": {"file_id": file_id} # Enable filtering by file_id + } + + await client.vector_stores.files.create(**create_params) + # Poll until completed (limit capped at 100 per API spec) + for _ in range(20): + listing = await client.vector_stores.files.list(vector_store_id=vs_id, limit=100) + found = None + for item in getattr(listing, "data", []): + if getattr(item, "id", None) == file_id or getattr(item, "file_id", None) == file_id: + found = item + break + status = getattr(found, "status", None) if found else None + if status == "completed": + return + await asyncio.sleep(0.5) + # If not confirmed, still continue + return + +def load_files_index(user: str) -> List[FileMeta]: + path = files_index_path(user) + if not os.path.exists(path): + return [] + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + return [FileMeta(**item) for item in data] + + +def save_files_index(user: str, items: List[FileMeta]): + path = files_index_path(user) + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump([item.model_dump() for item in items], f, ensure_ascii=False, indent=2) + + +async def prepare_attachments( + user: str, + target_provider: str, + attached_ids: List[str], + llm_config: LLMConfig, +) -> list[dict]: + """ + For each attached file ID: + - If already uploaded to the target provider, reuse provider_file_id/uri. + - Otherwise, upload with the original filename (required by OpenAI). + Returns a list of dicts describing attachment references for the provider. 
+ """ + if not attached_ids: + return [] + + items = load_files_index(user) + items_map = {item.id: item for item in items} + attachments: list[dict] = [] + + if isinstance(target_provider, ModelProvider): + provider_norm = target_provider.value.lower() + else: + provider_norm = str(target_provider).lower() + + for fid in attached_ids: + meta = items_map.get(fid) + if not meta: + print(f"[warn] Attached file id not found, skipping: {fid}") + continue + + path = os.path.join(files_root(user), fid) + if not os.path.exists(path): + raise HTTPException(status_code=404, detail=f"Attached file missing on disk: {meta.name}") + + if provider_norm == ModelProvider.OPENAI or provider_norm == "openai": + # Reuse provider file id if available + if meta.provider == "openai" and meta.provider_file_id: + attachments.append({ + "provider": "openai", + "file_id": meta.provider_file_id, + "name": meta.name, + "mime": meta.mime, + }) + continue + + # Upload to OpenAI with original filename + with open(path, "rb") as f: + content = f.read() + size = len(content) + if size > OPENAI_MAX_FILE_SIZE: + raise HTTPException(status_code=400, detail=f"File {meta.name} exceeds OpenAI 50MB limit") + + try: + client = get_openai_client(llm_config.api_key) + resp = await client.files.create( + file=(meta.name or "upload.bin", content), + purpose=OPENAI_DEFAULT_FILE_PURPOSE, + ) + openai_file_id = getattr(resp, "id", None) + if not openai_file_id: + raise HTTPException(status_code=500, detail="OpenAI file upload returned no file_id") + attachments.append({ + "provider": "openai", + "file_id": openai_file_id, + "name": meta.name, + "mime": meta.mime, + }) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"OpenAI upload failed: {str(e)}") + + elif provider_norm == ModelProvider.GOOGLE or provider_norm == "google": + # Reuse uri/name if available and looks like a URI + if meta.provider == "google" and meta.provider_file_id and "://" in meta.provider_file_id: + attachments.append({ + "provider": "google", + "uri": meta.provider_file_id, + "name": meta.name, + "mime": meta.mime, + }) + continue + + key = llm_config.api_key or os.getenv("GOOGLE_API_KEY") + if not key: + raise HTTPException(status_code=500, detail="Google API Key not found") + client = genai.Client(api_key=key) + + tmp_path = None + try: + with open(path, "rb") as f: + content = f.read() + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp.write(content) + tmp_path = tmp.name + + google_resp = await asyncio.to_thread( + client.files.upload, + file=tmp_path, + config={"mimeType": meta.mime or "application/octet-stream"}, + ) + google_name = getattr(google_resp, "name", None) + google_uri = getattr(google_resp, "uri", None) + + # Poll for ACTIVE and uri if missing + if google_name: + for _ in range(10): + try: + info = await asyncio.to_thread(client.files.get, name=google_name) + state = getattr(info, "state", None) + google_uri = getattr(info, "uri", google_uri) + if str(state).upper().endswith("ACTIVE") or state == "ACTIVE": + break + await asyncio.sleep(1) + except Exception: + await asyncio.sleep(1) + print(f"[google upload] name={google_name} uri={google_uri}") + + uri = google_uri or google_name + if not uri: + raise HTTPException(status_code=500, detail="Google upload returned no uri/name") + attachments.append({ + "provider": "google", + "uri": uri, + "name": meta.name, + "mime": meta.mime, + }) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, 
detail=f"Google upload failed: {str(e)}") + finally: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + + else: + raise HTTPException(status_code=400, detail=f"Unsupported provider for attachments: {target_provider}") + + # Debug log + print(f"[attachments] provider={provider_norm} count={len(attachments)} detail={[{'name': a.get('name'), 'id': a.get('file_id', a.get('uri'))} for a in attachments]}") + return attachments + + +async def prepare_openai_vector_search( + user: str, + attached_ids: List[str], + scopes: List[str], + llm_config: LLMConfig, +) -> tuple[List[str], List[dict], Optional[dict]]: + """ + Ensure all attached files are uploaded to OpenAI Files (purpose=assistants) and added to the user's vector store. + Returns (vector_store_ids, openai_file_refs_for_debug, filters). + + Filtering logic: + - Include files whose scopes intersect with requested scopes + - ALSO include explicitly attached files (attached_ids) + - Deduplicate to avoid double-processing + - Filters are constructed using file_id attribute in vector store + """ + items = load_files_index(user) + items_map = {item.id: item for item in items} + + # Determine which files to include - combine scopes AND attached_ids + relevant_files_map: dict[str, FileMeta] = {} + + # First: add files matching scopes + if scopes: + for item in items: + if item.scopes and any(s in scopes for s in item.scopes): + relevant_files_map[item.id] = item + print(f"[file_search] scopes={scopes} matched_files={[f.name for f in relevant_files_map.values()]}") + + # Second: also add explicitly attached files (they should always be searchable) + if attached_ids: + for fid in attached_ids: + meta = items_map.get(fid) + if meta and fid not in relevant_files_map: + relevant_files_map[fid] = meta + print(f"[file_search] adding explicitly attached file: {meta.name}") + + relevant_files = list(relevant_files_map.values()) + + if not relevant_files: + return [], [], None + + changed = False + vs_ids: List[str] = [] + debug_refs: List[dict] = [] + file_ids_for_filter: List[str] = [] + + for meta in relevant_files: + path = os.path.join(files_root(user), meta.id) + if not os.path.exists(path): + print(f"[warn] Attached file missing on disk, skipping: {meta.id}") + continue + # Enforce 50MB OpenAI limit + file_size = os.path.getsize(path) + if file_size > OPENAI_MAX_FILE_SIZE: + print(f"[warn] File {meta.name} exceeds OpenAI 50MB limit, skipping") + continue + + openai_file_id, vs_id = await ensure_openai_file_and_index(user, meta, path, llm_config) + if meta.openai_file_id != openai_file_id or meta.openai_vector_store_id != vs_id: + meta.openai_file_id = openai_file_id + meta.openai_vector_store_id = vs_id + changed = True + vs_ids.append(vs_id) + debug_refs.append({"name": meta.name, "file_id": openai_file_id, "vs_id": vs_id}) + if openai_file_id: + file_ids_for_filter.append(openai_file_id) + + if changed: + save_files_index(user, list(items_map.values())) + + # deduplicate + vs_ids_unique = list({vid for vid in vs_ids if vid}) + + # Build filters to only search relevant files + filters = None + if file_ids_for_filter: + filters = {"type": "in", "key": "file_id", "value": file_ids_for_filter} + + return vs_ids_unique, debug_refs, filters + +# ------------------------------------------------- + +@app.get("/api/projects/archived") +def get_archived_nodes(user: str = DEFAULT_USER): + migrate_legacy_layout(user) + path = archived_path(user) + if not os.path.exists(path): + return {"archived": []} + try: + with open(path, "r", 
encoding="utf-8") as f: + return {"archived": json.load(f)} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/api/projects/archived") +def save_archived_nodes(payload: dict): + user = payload.get("user", DEFAULT_USER) + data = payload.get("archived", []) + try: + path = archived_path(user) + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + return {"ok": True} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/api/files") +def list_files(user: str = DEFAULT_USER): + migrate_legacy_layout(user) + items = load_files_index(user) + return {"files": [item.model_dump() for item in items]} + + +@app.post("/api/files/upload") +async def upload_file( + user: str = DEFAULT_USER, + file: UploadFile = File(...), + provider: str = Form("local"), + purpose: Optional[str] = Form(None), +): + migrate_legacy_layout(user) + items = load_files_index(user) + file_id = str(uuid4()) + dest_root = files_root(user) + dest_path = os.path.join(dest_root, file_id) + file_name = file.filename or "upload.bin" + provider_normalized = (provider or "local").lower() + if provider_normalized not in VALID_FILE_PROVIDERS: + raise HTTPException(status_code=400, detail="Unsupported provider") + + try: + content = await file.read() + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + size = len(content) + if provider_normalized == "openai" and size > OPENAI_MAX_FILE_SIZE: + raise HTTPException(status_code=400, detail="OpenAI provider limit: max 50MB per file") + + provider_file_id: Optional[str] = None + provider_created_at: Optional[float] = None + + if provider_normalized == "openai": + try: + client = get_openai_client() + upload_purpose = purpose or OPENAI_DEFAULT_FILE_PURPOSE + resp = await client.files.create( + file=(file_name, content), + purpose=upload_purpose, + ) + provider_file_id = getattr(resp, "id", None) + provider_created_at = getattr(resp, "created_at", None) + except Exception as e: + raise HTTPException(status_code=500, detail=f"OpenAI upload failed: {str(e)}") + elif provider_normalized == "google": + try: + key = os.getenv("GOOGLE_API_KEY") + if not key: + raise HTTPException(status_code=500, detail="Google API Key not found") + client = genai.Client(api_key=key) + # The Google GenAI SDK upload is synchronous; run in thread to avoid blocking the event loop. 
+            tmp_path = None
+            try:
+                with tempfile.NamedTemporaryFile(delete=False) as tmp:
+                    tmp.write(content)
+                    tmp_path = tmp.name
+                google_resp = await asyncio.to_thread(
+                    client.files.upload,
+                    file=tmp_path,
+                    config={"mimeType": file.content_type or "application/octet-stream"},
+                )
+                google_name = getattr(google_resp, "name", None)
+                google_uri = getattr(google_resp, "uri", None)
+
+                # Poll for ACTIVE and uri if missing
+                if google_name:
+                    for _ in range(10):
+                        try:
+                            info = await asyncio.to_thread(client.files.get, name=google_name)
+                            state = getattr(info, "state", None)
+                            google_uri = getattr(info, "uri", google_uri)
+                            if str(state).upper().endswith("ACTIVE") or state == "ACTIVE":
+                                break
+                            await asyncio.sleep(1)
+                        except Exception:
+                            await asyncio.sleep(1)
+
+                provider_file_id = google_uri or google_name
+            finally:
+                if tmp_path and os.path.exists(tmp_path):
+                    os.remove(tmp_path)
+
+            provider_created_at = time.time()
+        except HTTPException:
+            raise
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Google upload failed: {str(e)}")
+
+    try:
+        os.makedirs(dest_root, exist_ok=True)
+        with open(dest_path, "wb") as f:
+            f.write(content)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+    created_at = provider_created_at or os.path.getmtime(dest_path)
+
+    meta = FileMeta(
+        id=file_id,
+        name=file_name,
+        size=size,
+        mime=file.content_type or "application/octet-stream",
+        created_at=created_at,
+        provider=provider_normalized if provider_normalized != "local" else None,
+        provider_file_id=provider_file_id,
+        openai_file_id=None,
+        openai_vector_store_id=None,
+    )
+
+    # Always try to index into OpenAI vector store (if <=50MB)
+    if size <= OPENAI_MAX_FILE_SIZE:
+        try:
+            openai_file_id, vs_id = await ensure_openai_file_and_index(user, meta, dest_path, None)
+            meta.openai_file_id = openai_file_id
+            meta.openai_vector_store_id = vs_id
+            if provider_normalized == "openai" and not meta.provider_file_id:
+                meta.provider_file_id = openai_file_id
+        except Exception as e:
+            print(f"[warn] OpenAI indexing failed for {file_name}: {e}")
+    else:
+        print(f"[warn] Skipping OpenAI indexing for {file_name}: exceeds 50MB")
+
+    items.append(meta)
+    save_files_index(user, items)
+    return {"file": meta}
+
+
+@app.get("/api/files/download")
+def download_file(user: str = DEFAULT_USER, file_id: str = ""):
+    migrate_legacy_layout(user)
+    items = load_files_index(user)
+    meta = next((i for i in items if i.id == file_id), None)
+    if not meta:
+        raise HTTPException(status_code=404, detail="file not found")
+    path = os.path.join(files_root(user), file_id)
+    if not os.path.exists(path):
+        raise HTTPException(status_code=404, detail="file missing on disk")
+    return FileResponse(path, filename=meta.name, media_type=meta.mime)
+
+
+@app.post("/api/files/delete")
+async def delete_file(user: str = DEFAULT_USER, file_id: str = ""):
+    migrate_legacy_layout(user)
+    items = load_files_index(user)
+    meta = next((i for i in items if i.id == file_id), None)
+    if not meta:
+        raise HTTPException(status_code=404, detail="file not found")
+
+    # Remove from vector store and OpenAI Files if present
+    if meta.openai_vector_store_id and meta.openai_file_id:
+        await remove_file_from_vector_store(meta.openai_vector_store_id, meta.openai_file_id)
+    if meta.provider == "openai" and meta.provider_file_id:
+        try:
+            client = get_openai_client()
+            await client.files.delete(meta.provider_file_id)
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"OpenAI delete failed: {str(e)}")
{str(e)}") + if meta.provider == "google" and meta.provider_file_id: + try: + key = os.getenv("GOOGLE_API_KEY") + if not key: + raise HTTPException(status_code=500, detail="Google API Key not found") + client = genai.Client(api_key=key) + await asyncio.to_thread(client.files.delete, meta.provider_file_id) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Google delete failed: {str(e)}") + + path = os.path.join(files_root(user), file_id) + if os.path.exists(path): + os.remove(path) + items = [i for i in items if i.id != file_id] + save_files_index(user, items) + return {"ok": True} + + +class AddScopeRequest(BaseModel): + user: str = DEFAULT_USER + file_id: str + scope: str # "project_path/node_id" composite key + + +@app.post("/api/files/add_scope") +def add_file_scope(request: AddScopeRequest): + """ + Add a scope to a file's scopes list. + Called when user attaches a file to a node. + """ + migrate_legacy_layout(request.user) + items = load_files_index(request.user) + meta = next((i for i in items if i.id == request.file_id), None) + if not meta: + raise HTTPException(status_code=404, detail="file not found") + + if request.scope not in meta.scopes: + meta.scopes.append(request.scope) + save_files_index(request.user, items) + + return {"file": meta.model_dump()} + + +class RemoveScopeRequest(BaseModel): + user: str = DEFAULT_USER + file_id: str + scope: str + + +@app.post("/api/files/remove_scope") +def remove_file_scope(request: RemoveScopeRequest): + """ + Remove a scope from a file's scopes list. + Called when user detaches a file from a node. + """ + migrate_legacy_layout(request.user) + items = load_files_index(request.user) + meta = next((i for i in items if i.id == request.file_id), None) + if not meta: + raise HTTPException(status_code=404, detail="file not found") + + if request.scope in meta.scopes: + meta.scopes.remove(request.scope) + save_files_index(request.user, items) + + return {"file": meta.model_dump()} |
