Diffstat (limited to 'backend/app/services/tasks.py')
-rw-r--r--   backend/app/services/tasks.py   304
1 files changed, 304 insertions, 0 deletions
diff --git a/backend/app/services/tasks.py b/backend/app/services/tasks.py
new file mode 100644
index 0000000..003c02d
--- /dev/null
+++ b/backend/app/services/tasks.py
@@ -0,0 +1,304 @@
+"""
+Background task management for debate & council operations.
+
+Decouples execution from SSE delivery so tasks survive browser disconnects.
+Results are persisted progressively to disk.
+"""
+
+import asyncio
+import json
+import logging
+import os
+import time
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, AsyncGenerator, Dict, List, Optional
+
+logger = logging.getLogger("contextflow.tasks")
+
+DATA_ROOT = os.path.abspath(os.getenv("DATA_ROOT", os.path.join(os.getcwd(), "data")))
+
+
+class TaskStatus(str, Enum):
+    PENDING = "pending"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+    INTERRUPTED = "interrupted"
+
+
+# Chunk event types that are NOT persisted to disk (redundant with *_complete events)
+_CHUNK_EVENT_TYPES = {"final_chunk", "stage3_chunk"}
+
+
+@dataclass
+class TaskInfo:
+    task_id: str
+    user: str
+    node_id: str
+    task_type: str  # "debate" | "council"
+    status: TaskStatus
+    created_at: float
+    updated_at: float
+    events: List[Dict] = field(default_factory=list)
+    result: Optional[Dict] = None
+    error: Optional[str] = None
+    # In-memory only (not serialized):
+    new_event_signal: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
+    asyncio_task: Optional[asyncio.Task] = field(default=None, repr=False)
+
+
+# --------------- In-memory registry ---------------
+_task_registry: Dict[str, TaskInfo] = {}
+
+
+def register_task(task_id: str, user: str, node_id: str, task_type: str) -> TaskInfo:
+    """Create and register a new task in memory + persist initial state to disk."""
+    now = time.time()
+    info = TaskInfo(
+        task_id=task_id,
+        user=user,
+        node_id=node_id,
+        task_type=task_type,
+        status=TaskStatus.PENDING,
+        created_at=now,
+        updated_at=now,
+    )
+    _task_registry[task_id] = info
+    _persist_task(info)
+    return info
+
+
+def get_task(task_id: str) -> Optional[TaskInfo]:
+    """Get task from in-memory registry."""
+    return _task_registry.get(task_id)
+
+
+def get_task_status(user: str, task_id: str) -> Optional[Dict]:
+    """
+    Get task status dict. In-memory first, fallback to disk.
+    If disk says "running" but not in memory → report "interrupted".
+    """
+    info = _task_registry.get(task_id)
+    if info:
+        return _task_to_dict(info)
+
+    # Fallback: read from disk
+    disk_data = _load_task_from_disk(user, task_id)
+    if not disk_data:
+        return None
+
+    # If disk says running but it's not in memory, backend must have restarted
+    if disk_data.get("status") == TaskStatus.RUNNING:
+        disk_data["status"] = TaskStatus.INTERRUPTED
+    return disk_data
+
+
+async def run_task_in_background(info: TaskInfo, event_generator: AsyncGenerator[str, None]) -> None:
+    """
+    Consume the async generator, appending events to info.events.
+    Signals new_event_signal for SSE consumers.
+    Persists at milestones, skips chunk events from persistence.
+    """
+    info.status = TaskStatus.RUNNING
+    info.updated_at = time.time()
+    _persist_task(info)
+
+    try:
+        async for raw_sse in event_generator:
+            # raw_sse is a string like 'data: {"type": "...", ...}\n\n'
+            evt = _parse_sse_event(raw_sse)
+            if evt is None:
+                continue
+
+            evt_type = evt.get("type", "")
+
+            # Always store in memory (even chunks, for live streaming)
+            info.events.append(evt)
+            info.updated_at = time.time()
+
+            # Signal waiting SSE consumers
+            info.new_event_signal.set()
+            info.new_event_signal.clear()
+
+            # Persist at milestones (skip chunk events)
+            if evt_type not in _CHUNK_EVENT_TYPES:
+                _persist_task(info)
+
+            # Detect completion / error
+            if evt_type in ("complete", "debate_complete"):
+                info.result = _extract_result(info.task_type, info.events)
+                info.status = TaskStatus.COMPLETED
+                info.updated_at = time.time()
+                _persist_task(info)
+                return
+            if evt_type == "error":
+                info.error = evt.get("data", {}).get("message", "Unknown error") if isinstance(evt.get("data"), dict) else str(evt.get("data", "Unknown error"))
+                info.status = TaskStatus.FAILED
+                info.updated_at = time.time()
+                _persist_task(info)
+                return
+
+        # Generator exhausted without explicit complete/error
+        if info.status == TaskStatus.RUNNING:
+            info.result = _extract_result(info.task_type, info.events)
+            info.status = TaskStatus.COMPLETED
+            info.updated_at = time.time()
+            _persist_task(info)
+
+    except asyncio.CancelledError:
+        info.status = TaskStatus.INTERRUPTED
+        info.updated_at = time.time()
+        _persist_task(info)
+        raise
+    except Exception as e:
+        logger.exception("Background task %s failed: %s", info.task_id, e)
+        info.error = str(e)
+        info.status = TaskStatus.FAILED
+        info.updated_at = time.time()
+        _persist_task(info)
+
+
+def cleanup_stale_tasks(data_root: Optional[str] = None) -> int:
+    """
+    On startup, scan task files on disk and mark any "running" as "interrupted".
+    Returns count of tasks marked interrupted.
+    """
+    root = data_root or DATA_ROOT
+    count = 0
+    if not os.path.exists(root):
+        return count
+
+    for user_dir in os.listdir(root):
+        tasks_dir = os.path.join(root, user_dir, "tasks")
+        if not os.path.isdir(tasks_dir):
+            continue
+        for fname in os.listdir(tasks_dir):
+            if not fname.endswith(".json"):
+                continue
+            fpath = os.path.join(tasks_dir, fname)
+            try:
+                with open(fpath, "r", encoding="utf-8") as f:
+                    data = json.load(f)
+                if data.get("status") == TaskStatus.RUNNING:
+                    data["status"] = TaskStatus.INTERRUPTED
+                    data["updated_at"] = time.time()
+                    with open(fpath, "w", encoding="utf-8") as f:
+                        json.dump(data, f, ensure_ascii=False)
+                    count += 1
+                    logger.info("Marked stale task as interrupted: %s", fname)
+            except Exception as e:
+                logger.warning("Failed to process stale task file %s: %s", fpath, e)
+
+    return count
+
+
+# --------------- Persistence helpers ---------------
+
+def _task_dir(user: str) -> str:
+    d = os.path.join(DATA_ROOT, user, "tasks")
+    os.makedirs(d, exist_ok=True)
+    return d
+
+
+def _task_file_path(user: str, task_id: str) -> str:
+    return os.path.join(_task_dir(user), f"{task_id}.json")
+
+
+def _persist_task(info: TaskInfo) -> None:
+    """Write task state to disk. Excludes chunk events for smaller files."""
+    try:
+        data = _task_to_dict(info, exclude_chunks=True)
+        path = _task_file_path(info.user, info.task_id)
+        with open(path, "w", encoding="utf-8") as f:
+            json.dump(data, f, ensure_ascii=False)
+    except Exception as e:
+        logger.warning("Failed to persist task %s: %s", info.task_id, e)
+
+
+def _load_task_from_disk(user: str, task_id: str) -> Optional[Dict]:
+    path = _task_file_path(user, task_id)
+    if not os.path.exists(path):
+        return None
+    try:
+        with open(path, "r", encoding="utf-8") as f:
+            return json.load(f)
+    except Exception as e:
+        logger.warning("Failed to load task %s from disk: %s", task_id, e)
+        return None
+
+
+def _task_to_dict(info: TaskInfo, exclude_chunks: bool = False) -> Dict:
+    events = info.events
+    if exclude_chunks:
+        events = [e for e in events if e.get("type") not in _CHUNK_EVENT_TYPES]
+    return {
+        "task_id": info.task_id,
+        "user": info.user,
+        "node_id": info.node_id,
+        "task_type": info.task_type,
+        "status": info.status.value if isinstance(info.status, TaskStatus) else info.status,
+        "created_at": info.created_at,
+        "updated_at": info.updated_at,
+        "events": events,
+        "result": info.result,
+        "error": info.error,
+    }
+
+
+def _parse_sse_event(raw: str) -> Optional[Dict]:
+    """Parse a raw SSE string like 'data: {...}\\n\\n' into a dict."""
+    raw = raw.strip()
+    if not raw:
+        return None
+    for line in raw.split("\n"):
+        line = line.strip()
+        if line.startswith("data: "):
+            try:
+                return json.loads(line[6:])
+            except json.JSONDecodeError:
+                return None
+    return None
+
+
+def _extract_result(task_type: str, events: List[Dict]) -> Dict:
+    """Reconstruct final result data from accumulated events."""
+    if task_type == "council":
+        stage1 = None
+        stage2 = None
+        stage3 = None
+        for evt in events:
+            t = evt.get("type", "")
+            if t == "stage1_complete":
+                stage1 = evt.get("data")
+            elif t == "stage1_model_complete":
+                if stage1 is None:
+                    stage1 = []
+                stage1.append(evt.get("data"))
+            elif t == "stage2_complete":
+                stage2 = evt.get("data")
+            elif t == "stage3_complete":
+                stage3 = evt.get("data")
+        return {"councilData": {"stage1": stage1, "stage2": stage2, "stage3": stage3}}
+
+    elif task_type == "debate":
+        rounds: List[Dict] = []
+        final_verdict = None
+        for evt in events:
+            t = evt.get("type", "")
+            if t == "round_complete":
+                rounds.append(evt.get("data", {}))
+            elif t == "final_complete":
+                final_verdict = evt.get("data")
+            elif t == "judge_decision":
+                # Attach to last round
+                if rounds:
+                    rounds[-1]["judgeDecision"] = evt.get("data")
+            elif t == "model_eliminated":
+                if rounds:
+                    if "eliminated" not in rounds[-1]:
+                        rounds[-1]["eliminated"] = []
+                    rounds[-1]["eliminated"].append(evt.get("data"))
+        return {"debateRounds": rounds, "finalVerdict": final_verdict}
+
+    return {}
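
The module above decouples debate/council execution from SSE delivery: the task keeps running and persisting events even after the browser disconnects, and a client can re-attach later. Below is a minimal sketch of how a caller might wire these functions together, assuming a FastAPI backend; the route paths, the `user` query parameter, and the `run_debate_stream()` placeholder generator are illustrative and do not appear in this diff.

# Sketch only: FastAPI wiring and run_debate_stream() are assumptions,
# not part of backend/app/services/tasks.py.
import asyncio
import json
import uuid

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

from app.services import tasks

app = FastAPI()


async def run_debate_stream(node_id: str):
    # Placeholder for the real debate generator: yields SSE-formatted strings.
    yield 'data: {"type": "round_complete", "data": {"round": 1}}\n\n'
    yield 'data: {"type": "debate_complete", "data": {}}\n\n'


@app.post("/nodes/{node_id}/debate")
async def start_debate(node_id: str, user: str):
    task_id = uuid.uuid4().hex
    info = tasks.register_task(task_id, user=user, node_id=node_id, task_type="debate")
    # Execution is owned by the event loop, not the HTTP request,
    # so the client may disconnect and re-attach later.
    info.asyncio_task = asyncio.create_task(
        tasks.run_task_in_background(info, run_debate_stream(node_id))
    )
    return {"task_id": task_id}


@app.get("/tasks/{task_id}/stream")
async def stream_task(task_id: str, user: str):
    async def event_source():
        cursor = 0
        while True:
            info = tasks.get_task(task_id)
            if info is None:
                # Not in memory (e.g. after a restart): emit the persisted snapshot once.
                snapshot = tasks.get_task_status(user, task_id)
                if snapshot:
                    yield f"data: {json.dumps(snapshot)}\n\n"
                return
            # Replay anything collected since the last loop iteration.
            while cursor < len(info.events):
                yield f"data: {json.dumps(info.events[cursor])}\n\n"
                cursor += 1
            if info.status in (tasks.TaskStatus.COMPLETED, tasks.TaskStatus.FAILED, tasks.TaskStatus.INTERRUPTED):
                return
            try:
                # Wait for new_event_signal; the timeout guards against a signal
                # raced between the length check above and wait().
                await asyncio.wait_for(info.new_event_signal.wait(), timeout=1.0)
            except asyncio.TimeoutError:
                pass

    return StreamingResponse(event_source(), media_type="text/event-stream")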

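cleanup_stale_tasks() is documented as a startup scan that downgrades any "running" task left on disk to "interrupted". A short sketch of hooking it into process startup, again assuming FastAPI; the lifespan hook below is illustrative and independent of the sketch above.

# Sketch: run the stale-task sweep once per process start (FastAPI lifespan assumed).
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.services.tasks import cleanup_stale_tasks


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Any task still marked "running" on disk was orphaned by a previous
    # process, so report it as "interrupted" before serving requests.
    interrupted = cleanup_stale_tasks()
    if interrupted:
        print(f"Marked {interrupted} stale task(s) as interrupted")
    yield


app = FastAPI(lifespan=lifespan)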