From bdf381a2c8a0337f7459000f487a80f9cbbbdd2f Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Sat, 14 Feb 2026 03:40:31 +0000
Subject: Add background task persistence for debate & council operations

Decouple debate/council execution from SSE connection lifecycle so tasks
survive browser disconnects. Backend runs work as asyncio.Tasks with
progressive disk persistence; frontend can reconnect and recover state.

- New backend/app/services/tasks.py: task registry, broadcast pattern,
  disk persistence at milestones, stale task cleanup on startup
- New endpoints: POST start_debate/start_council, GET task stream/poll
- Frontend stores taskId on nodes, recovers running tasks on page load
- _applyPartialEvents rebuilds stage text + data from accumulated events

Co-Authored-By: Claude Opus 4.6
---
 backend/app/main.py | 355 +++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 354 insertions(+), 1 deletion(-)

diff --git a/backend/app/main.py b/backend/app/main.py
index 89c5dd0..746b731 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -12,6 +12,10 @@ from app.schemas import NodeRunRequest, NodeRunResponse, MergeStrategy, Role, Me
 from app.services.llm import llm_streamer, generate_title, get_openai_client, get_anthropic_client, resolve_provider
 from app.services.council import council_event_stream
 from app.services.debate import debate_event_stream
+from app.services.tasks import (
+    register_task, get_task, get_task_status, run_task_in_background,
+    cleanup_stale_tasks, TaskStatus, _CHUNK_EVENT_TYPES, _task_registry,
+)
 from app.auth import auth_router, get_current_user, get_current_user_optional, init_db, User, get_db
 from app.auth.utils import get_password_hash
 from dotenv import load_dotenv
@@ -57,7 +61,12 @@ app.add_middleware(
 async def startup_event():
     """Initialize database and create default test user if not exists"""
     init_db()
-    
+
+    # Mark any in-progress tasks from prior run as interrupted
+    stale_count = cleanup_stale_tasks(DATA_ROOT)
+    if stale_count:
+        logger.info("Cleaned up %d stale background tasks", stale_count)
+
     # Create test user if not exists
     from app.auth.models import SessionLocal
     db = SessionLocal()
@@ -711,6 +720,350 @@ async def run_debate_stream(
     )
 
 
+# --------------- Background Task Helpers ---------------
+
+async def _prepare_council_args(request: CouncilRunRequest, resolved: User | None, username: str) -> dict:
+    """Extract shared council request processing into a reusable helper.
+    Returns kwargs suitable for council_event_stream()."""
+    raw_messages = []
+    for ctx in request.incoming_contexts:
+        raw_messages.extend(ctx.messages)
+    if request.merge_strategy == MergeStrategy.SMART:
+        final_messages = smart_merge_messages(raw_messages)
+    else:
+        final_messages = raw_messages
+    execution_context = Context(messages=final_messages)
+
+    images, non_image_file_ids = extract_image_attachments(username, request.attached_file_ids)
+    openrouter_key = get_user_api_key(resolved, "openrouter")
+
+    member_configs: list[LLMConfig] = []
+    attachments_per_model: list[list[dict] | None] = []
+    tools_per_model: list[list[dict] | None] = []
+    contexts_per_model: list[Context | None] = []
+
+    for member in request.council_models:
+        provider = resolve_provider(member.model_name)
+        provider_str = provider.value
+        api_key = get_user_api_key(resolved, provider_str)
+        config = LLMConfig(
+            provider=provider,
+            model_name=member.model_name,
+            temperature=member.temperature if member.temperature is not None else request.temperature,
+            system_prompt=request.system_prompt,
+            api_key=api_key,
+            reasoning_effort=member.reasoning_effort if member.reasoning_effort is not None else request.reasoning_effort,
+            enable_google_search=member.enable_google_search if member.enable_google_search is not None else request.enable_google_search,
+        )
+        member_configs.append(config)
+
+        tools: list[dict] = []
+        attachments: list[dict] = []
+        scoped_file_ids = resolve_scoped_file_ids(username, request.scopes, non_image_file_ids)
+
+        if provider == ModelProvider.OPENAI:
+            vs_ids, debug_refs, filters = await prepare_openai_vector_search(
+                user=username, attached_ids=non_image_file_ids,
+                scopes=request.scopes, llm_config=config,
+            )
+            if not vs_ids:
+                try:
+                    client = get_openai_client(config.api_key)
+                    vs_id = await ensure_user_vector_store(username, client)
+                    if vs_id:
+                        vs_ids = [vs_id]
+                except Exception:
+                    pass
+            if vs_ids:
+                tool_def = {"type": "file_search", "vector_store_ids": vs_ids}
+                if filters:
+                    tool_def["filters"] = filters
+                tools.append(tool_def)
+        elif provider == ModelProvider.GOOGLE:
+            attachments = await prepare_attachments(
+                user=username, target_provider=provider,
+                attached_ids=scoped_file_ids, llm_config=config,
+            )
+        elif provider == ModelProvider.CLAUDE:
+            attachments = await prepare_attachments(
+                user=username, target_provider=provider,
+                attached_ids=scoped_file_ids, llm_config=config,
+            )
+
+        attachments_per_model.append(attachments or None)
+        tools_per_model.append(tools or None)
+
+        if member.incoming_contexts:
+            raw = []
+            for ctx in member.incoming_contexts:
+                raw.extend(ctx.messages)
+            if request.merge_strategy == MergeStrategy.SMART:
+                merged = smart_merge_messages(raw)
+            else:
+                merged = raw
+            contexts_per_model.append(Context(messages=merged))
+        else:
+            contexts_per_model.append(None)
+
+    chairman = request.chairman_model
+    chairman_provider = resolve_provider(chairman.model_name)
+    chairman_api_key = get_user_api_key(resolved, chairman_provider.value)
+    chairman_config = LLMConfig(
+        provider=chairman_provider,
+        model_name=chairman.model_name,
+        temperature=chairman.temperature if chairman.temperature is not None else request.temperature,
+        system_prompt=request.system_prompt,
+        api_key=chairman_api_key,
+        reasoning_effort=chairman.reasoning_effort if chairman.reasoning_effort is not None else request.reasoning_effort,
+        enable_google_search=chairman.enable_google_search if chairman.enable_google_search is not None else request.enable_google_search,
+    )
+
+    return dict(
+        user_prompt=request.user_prompt,
+        context=execution_context,
+        member_configs=member_configs,
+        chairman_config=chairman_config,
+        attachments_per_model=attachments_per_model,
+        tools_per_model=tools_per_model,
+        openrouter_api_key=openrouter_key,
+        images=images,
+        contexts_per_model=contexts_per_model,
+    )
+
+
+async def _prepare_debate_args(request: DebateRunRequest, resolved: User | None, username: str) -> dict:
+    """Extract shared debate request processing into a reusable helper.
+    Returns kwargs suitable for debate_event_stream()."""
+    raw_messages = []
+    for ctx in request.incoming_contexts:
+        raw_messages.extend(ctx.messages)
+    if request.merge_strategy == MergeStrategy.SMART:
+        final_messages = smart_merge_messages(raw_messages)
+    else:
+        final_messages = raw_messages
+    execution_context = Context(messages=final_messages)
+
+    images, non_image_file_ids = extract_image_attachments(username, request.attached_file_ids)
+    openrouter_key = get_user_api_key(resolved, "openrouter")
+
+    member_configs: list[LLMConfig] = []
+    attachments_per_model: list[list[dict] | None] = []
+    tools_per_model: list[list[dict] | None] = []
+
+    for member in request.debate_models:
+        provider = resolve_provider(member.model_name)
+        provider_str = provider.value
+        api_key = get_user_api_key(resolved, provider_str)
+        config = LLMConfig(
+            provider=provider,
+            model_name=member.model_name,
+            temperature=member.temperature if member.temperature is not None else request.temperature,
+            system_prompt=request.system_prompt,
+            api_key=api_key,
+            reasoning_effort=member.reasoning_effort if member.reasoning_effort is not None else request.reasoning_effort,
+            enable_google_search=member.enable_google_search if member.enable_google_search is not None else request.enable_google_search,
+        )
+        member_configs.append(config)
+
+        tools: list[dict] = []
+        attachments: list[dict] = []
+        scoped_file_ids = resolve_scoped_file_ids(username, request.scopes, non_image_file_ids)
+
+        if provider == ModelProvider.OPENAI:
+            vs_ids, debug_refs, filters = await prepare_openai_vector_search(
+                user=username, attached_ids=non_image_file_ids,
+                scopes=request.scopes, llm_config=config,
+            )
+            if not vs_ids:
+                try:
+                    client = get_openai_client(config.api_key)
+                    vs_id = await ensure_user_vector_store(username, client)
+                    if vs_id:
+                        vs_ids = [vs_id]
+                except Exception:
+                    pass
+            if vs_ids:
+                tool_def = {"type": "file_search", "vector_store_ids": vs_ids}
+                if filters:
+                    tool_def["filters"] = filters
+                tools.append(tool_def)
+        elif provider == ModelProvider.GOOGLE:
+            attachments = await prepare_attachments(
+                user=username, target_provider=provider,
+                attached_ids=scoped_file_ids, llm_config=config,
+            )
+        elif provider == ModelProvider.CLAUDE:
+            attachments = await prepare_attachments(
+                user=username, target_provider=provider,
+                attached_ids=scoped_file_ids, llm_config=config,
+            )
+
+        attachments_per_model.append(attachments or None)
+        tools_per_model.append(tools or None)
+
+    judge_config = None
+    if request.judge_mode == DebateJudgeMode.EXTERNAL_JUDGE and request.judge_model:
+        judge = request.judge_model
+        judge_provider = resolve_provider(judge.model_name)
+        judge_api_key = get_user_api_key(resolved, judge_provider.value)
+        judge_config = LLMConfig(
+            provider=judge_provider,
+            model_name=judge.model_name,
+            temperature=judge.temperature if judge.temperature is not None else request.temperature,
+            system_prompt=request.system_prompt,
+            api_key=judge_api_key,
+            reasoning_effort=judge.reasoning_effort if judge.reasoning_effort is not None else request.reasoning_effort,
+            enable_google_search=judge.enable_google_search if judge.enable_google_search is not None else request.enable_google_search,
+        )
+
+    return dict(
+        user_prompt=request.user_prompt,
+        context=execution_context,
+        member_configs=member_configs,
+        judge_config=judge_config,
+        judge_mode=request.judge_mode,
+        debate_format=request.debate_format,
+        max_rounds=request.max_rounds,
+        custom_format_prompt=request.custom_format_prompt,
+        attachments_per_model=attachments_per_model,
+        tools_per_model=tools_per_model,
+        openrouter_api_key=openrouter_key,
+        images=images,
+    )
+
+
+# --------------- Background Task Endpoints ---------------
+
+@app.post("/api/task/start_council")
+async def start_council_task(
+    request: CouncilRunRequest,
+    user: str = DEFAULT_USER,
+    current_user: User | None = Depends(get_current_user_optional),
+):
+    """Start council as a background task. Returns {task_id}."""
+    resolved = resolve_user(current_user, user)
+    username = resolved.username if resolved else DEFAULT_USER
+
+    # Cancel existing task on same node if still running
+    for existing in list(_task_registry.values()):
+        if existing.node_id == request.node_id and existing.user == username and existing.status == TaskStatus.RUNNING:
+            if existing.asyncio_task and not existing.asyncio_task.done():
+                existing.asyncio_task.cancel()
+                logger.info("Cancelled previous task %s for node %s", existing.task_id, request.node_id)
+
+    kwargs = await _prepare_council_args(request, resolved, username)
+    generator = council_event_stream(**kwargs)
+
+    task_id = str(uuid4())
+    info = register_task(task_id, username, request.node_id, "council")
+    info.asyncio_task = asyncio.create_task(run_task_in_background(info, generator))
+    return {"task_id": task_id}
+
+
+@app.post("/api/task/start_debate")
+async def start_debate_task(
+    request: DebateRunRequest,
+    user: str = DEFAULT_USER,
+    current_user: User | None = Depends(get_current_user_optional),
+):
+    """Start debate as a background task. Returns {task_id}."""
+    resolved = resolve_user(current_user, user)
+    username = resolved.username if resolved else DEFAULT_USER
+
+    # Cancel existing task on same node if still running
+    for existing in list(_task_registry.values()):
+        if existing.node_id == request.node_id and existing.user == username and existing.status == TaskStatus.RUNNING:
+            if existing.asyncio_task and not existing.asyncio_task.done():
+                existing.asyncio_task.cancel()
+                logger.info("Cancelled previous task %s for node %s", existing.task_id, request.node_id)
+
+    kwargs = await _prepare_debate_args(request, resolved, username)
+    generator = debate_event_stream(**kwargs)
+
+    task_id = str(uuid4())
+    info = register_task(task_id, username, request.node_id, "debate")
+    info.asyncio_task = asyncio.create_task(run_task_in_background(info, generator))
+    return {"task_id": task_id}
+
+
+@app.get("/api/task/{task_id}/stream")
+async def stream_task_events(
+    task_id: str,
+    from_event: int = 0,
+    user: str = DEFAULT_USER,
+    current_user: User | None = Depends(get_current_user_optional),
+):
+    """SSE stream: replay missed events then stream live."""
+    resolved = resolve_user(current_user, user)
+    username = resolved.username if resolved else DEFAULT_USER
+
+    info = get_task(task_id)
+    if not info:
+        raise HTTPException(status_code=404, detail="Task not found")
+    if info.user != username:
+        raise HTTPException(status_code=403, detail="Not your task")
+
+    async def event_generator():
+        cursor = from_event
+        # 1. Replay accumulated events from cursor
+        while cursor < len(info.events):
+            evt = info.events[cursor]
+            yield f"data: {json.dumps(evt)}\n\n"
+            cursor += 1
+
+        # 2. If already done, send terminal and return
+        if info.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.INTERRUPTED):
+            yield f"data: {json.dumps({'type': 'task_status', 'data': {'status': info.status.value}})}\n\n"
+            return
+
+        # 3. Stream live events
+        while True:
+            # Wait for new events with timeout (keepalive)
+            try:
+                await asyncio.wait_for(
+                    asyncio.shield(info.new_event_signal.wait()),
+                    timeout=15.0,
+                )
+            except asyncio.TimeoutError:
+                # Send keepalive comment
+                yield ": keepalive\n\n"
+                # Check if task finished while waiting
+                if info.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.INTERRUPTED):
+                    break
+                continue
+
+            # Yield any new events since cursor
+            while cursor < len(info.events):
+                evt = info.events[cursor]
+                yield f"data: {json.dumps(evt)}\n\n"
+                cursor += 1
+
+            if info.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.INTERRUPTED):
+                break
+
+        yield f"data: {json.dumps({'type': 'task_status', 'data': {'status': info.status.value}})}\n\n"
+
+    return StreamingResponse(event_generator(), media_type="text/event-stream")
+
+
+@app.get("/api/task/{task_id}")
+async def poll_task(
+    task_id: str,
+    user: str = DEFAULT_USER,
+    current_user: User | None = Depends(get_current_user_optional),
+):
+    """Poll: return status + accumulated results."""
+    resolved = resolve_user(current_user, user)
+    username = resolved.username if resolved else DEFAULT_USER
+
+    status_data = get_task_status(username, task_id)
+    if not status_data:
+        raise HTTPException(status_code=404, detail="Task not found")
+    if status_data.get("user") != username:
+        raise HTTPException(status_code=403, detail="Not your task")
+    return status_data
+
+
 class TitleRequest(BaseModel):
     user_prompt: str
     response: str
-- 
cgit v1.2.3
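
Editor's note: this patch is limited to backend/app/main.py, so the new
app/services/tasks.py module it imports from is not shown. Below is a minimal
sketch of what that module could look like, inferred purely from the imports
and call sites in the diff. The names match main.py's usage; the internals,
and the omitted get_task_status, cleanup_stale_tasks, and milestone disk
writes, are assumptions rather than the actual implementation.

# Hypothetical sketch of backend/app/services/tasks.py. Names come from
# main.py's imports and call sites; everything else is an assumption.
import asyncio
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from enum import Enum


class TaskStatus(Enum):
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    INTERRUPTED = "interrupted"


# Event types considered too chatty to persist one-by-one; the real module
# presumably flushes these to disk only at milestones (assumed values).
_CHUNK_EVENT_TYPES: frozenset[str] = frozenset({"chunk", "reasoning_chunk"})

_task_registry: dict[str, "TaskInfo"] = {}


@dataclass
class TaskInfo:
    task_id: str
    user: str
    node_id: str
    kind: str  # "council" or "debate"
    status: TaskStatus = TaskStatus.RUNNING
    events: list[dict] = field(default_factory=list)
    new_event_signal: asyncio.Event = field(default_factory=asyncio.Event)
    asyncio_task: asyncio.Task | None = None


def register_task(task_id: str, user: str, node_id: str, kind: str) -> TaskInfo:
    info = TaskInfo(task_id=task_id, user=user, node_id=node_id, kind=kind)
    _task_registry[task_id] = info
    return info


def get_task(task_id: str) -> TaskInfo | None:
    return _task_registry.get(task_id)


async def run_task_in_background(info: TaskInfo, generator: AsyncIterator[dict]) -> None:
    """Drain the event generator, buffering events and waking any listeners."""
    try:
        async for event in generator:
            info.events.append(event)
            # Broadcast: wake every waiter, then re-arm for the next event.
            info.new_event_signal.set()
            info.new_event_signal.clear()
        info.status = TaskStatus.COMPLETED
    except asyncio.CancelledError:
        info.status = TaskStatus.INTERRUPTED
        raise
    except Exception:
        info.status = TaskStatus.FAILED
    finally:
        # Wake listeners one last time so they observe the terminal status.
        info.new_event_signal.set()

The set()/clear() pair is the broadcast pattern the commit message mentions:
every stream_task_events generator blocked on new_event_signal.wait() wakes,
drains info.events from its own cursor, and goes back to waiting, so any
number of SSE connections can follow one task.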
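
On the client side, recovery amounts to remembering the task_id (the frontend
stores it on the node) plus how many events have already been applied, then
reopening GET /api/task/{task_id}/stream with from_event set to that count. A
hypothetical Python client sketch follows; the httpx usage and event handling
are illustrative only, not part of the patch:

import asyncio
import json

import httpx


async def resume_task(base_url: str, task_id: str, seen: int) -> None:
    # Reopen the SSE stream; the server replays events[seen:] before going live.
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "GET",
            f"{base_url}/api/task/{task_id}/stream",
            params={"from_event": seen},
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skips blank separators and ": keepalive" comments
                event = json.loads(line[len("data: "):])
                if event.get("type") == "task_status":
                    print("terminal status:", event["data"]["status"])
                    return
                seen += 1  # persist this cursor so the next reconnect resumes here
                print("event", seen, event.get("type"))


if __name__ == "__main__":
    asyncio.run(resume_task("http://localhost:8000", "some-task-id", 0))

GET /api/task/{task_id} is the non-streaming alternative: poll until the
returned status is terminal, then apply the accumulated events in one pass,
which is what the frontend's _applyPartialEvents does on page load.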