summaryrefslogtreecommitdiff
path: root/backend/app/services/tasks.py
diff options
context:
space:
mode:
Diffstat (limited to 'backend/app/services/tasks.py')
-rw-r--r--backend/app/services/tasks.py304
1 files changed, 304 insertions, 0 deletions
diff --git a/backend/app/services/tasks.py b/backend/app/services/tasks.py
new file mode 100644
index 0000000..003c02d
--- /dev/null
+++ b/backend/app/services/tasks.py
@@ -0,0 +1,304 @@
+"""
+Background task management for debate & council operations.
+
+Decouples execution from SSE delivery so tasks survive browser disconnects.
+Results are persisted progressively to disk.
+"""
+
+import asyncio
+import json
+import logging
+import os
+import time
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, AsyncGenerator, Dict, List, Optional
+
+logger = logging.getLogger("contextflow.tasks")
+
+DATA_ROOT = os.path.abspath(os.getenv("DATA_ROOT", os.path.join(os.getcwd(), "data")))
+
+
class TaskStatus(str, Enum):
    """Lifecycle states of a background task.

    Subclasses ``str`` so members compare equal to their plain string
    values (e.g. loaded back from JSON) and serialize as strings.
    """

    PENDING = "pending"        # registered, not yet started
    RUNNING = "running"        # generator is being consumed
    COMPLETED = "completed"    # finished with a result
    FAILED = "failed"          # finished with an error
    INTERRUPTED = "interrupted"  # cancelled or orphaned by a backend restart
+
+
# Streaming chunk event types that are kept in memory for live SSE delivery
# but NOT persisted to disk (their content is redundant with the
# corresponding *_complete events and would bloat the task files).
_CHUNK_EVENT_TYPES = {"final_chunk", "stage3_chunk"}
+
+
@dataclass
class TaskInfo:
    """In-memory record of one background task.

    Serialized to disk (minus the asyncio-only fields) via ``_task_to_dict``.
    """

    task_id: str          # unique identifier, also used as the on-disk filename
    user: str             # owner; selects the per-user tasks directory
    node_id: str          # node this task is attached to
    task_type: str  # "debate" | "council"
    status: TaskStatus    # current lifecycle state
    created_at: float     # epoch seconds at registration
    updated_at: float     # epoch seconds of last state/event change
    events: List[Dict] = field(default_factory=list)  # accumulated parsed SSE events
    result: Optional[Dict] = None  # final payload, set on completion
    error: Optional[str] = None    # error message, set on failure
    # In-memory only (not serialized):
    # Pulsed after each appended event to wake live SSE consumers.
    new_event_signal: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
    # Handle to the running asyncio.Task, if any (enables cancellation).
    asyncio_task: Optional[asyncio.Task] = field(default=None, repr=False)
+
+
# --------------- In-memory registry ---------------
# task_id -> TaskInfo. Lost on process restart; disk snapshots are the
# fallback (see get_task_status / cleanup_stale_tasks).
_task_registry: Dict[str, TaskInfo] = {}
+
+
def register_task(task_id: str, user: str, node_id: str, task_type: str) -> TaskInfo:
    """Create a new task record, add it to the in-memory registry, and
    write its initial (pending) snapshot to disk.
    """
    timestamp = time.time()
    task = TaskInfo(
        task_id=task_id,
        user=user,
        node_id=node_id,
        task_type=task_type,
        status=TaskStatus.PENDING,
        created_at=timestamp,
        updated_at=timestamp,
    )
    _task_registry[task_id] = task
    _persist_task(task)
    return task
+
+
def get_task(task_id: str) -> Optional[TaskInfo]:
    """Look up a task in the in-memory registry; None if not present."""
    try:
        return _task_registry[task_id]
    except KeyError:
        return None
+
+
def get_task_status(user: str, task_id: str) -> Optional[Dict]:
    """
    Return a JSON-serializable status dict for *task_id*, or None if unknown.

    The in-memory registry is checked first; otherwise the on-disk snapshot
    is used. A snapshot that still says "running" while the task is absent
    from memory means the backend restarted mid-task, so it is reported as
    "interrupted".
    """
    info = _task_registry.get(task_id)
    if info:
        return _task_to_dict(info)

    # Fallback: read from disk
    disk_data = _load_task_from_disk(user, task_id)
    if not disk_data:
        return None

    # If disk says running but it's not in memory, backend must have restarted.
    if disk_data.get("status") == TaskStatus.RUNNING:
        # Store the plain string value (not the enum member) so the dict
        # stays uniformly JSON-typed, matching _task_to_dict's output.
        disk_data["status"] = TaskStatus.INTERRUPTED.value
    return disk_data
+
+
async def run_task_in_background(info: TaskInfo, event_generator: AsyncGenerator[str, None]) -> None:
    """
    Consume the async generator, appending events to info.events.
    Signals new_event_signal for SSE consumers.
    Persists at milestones, skips chunk events from persistence.

    Terminal states ("complete"/"debate_complete" → COMPLETED, "error" →
    FAILED, generator exhaustion → COMPLETED, cancellation → INTERRUPTED,
    any other exception → FAILED) are persisted to disk before returning.
    CancelledError is re-raised so asyncio cancellation semantics hold.
    """
    info.status = TaskStatus.RUNNING
    info.updated_at = time.time()
    _persist_task(info)

    try:
        async for raw_sse in event_generator:
            # raw_sse is a string like 'data: {"type": "...", ...}\n\n'
            evt = _parse_sse_event(raw_sse)
            if evt is None:
                # Unparseable / non-data frames are dropped silently.
                continue

            evt_type = evt.get("type", "")

            # Always store in memory (even chunks, for live streaming)
            info.events.append(evt)
            info.updated_at = time.time()

            # Signal waiting SSE consumers
            # NOTE(review): set() immediately followed by clear() only wakes
            # coroutines already blocked in wait(); confirm consumers tolerate
            # missed pulses (e.g. by re-checking len(info.events)).
            info.new_event_signal.set()
            info.new_event_signal.clear()

            # Persist at milestones (skip chunk events)
            if evt_type not in _CHUNK_EVENT_TYPES:
                _persist_task(info)

            # Detect completion / error
            if evt_type in ("complete", "debate_complete"):
                info.result = _extract_result(info.task_type, info.events)
                info.status = TaskStatus.COMPLETED
                info.updated_at = time.time()
                _persist_task(info)
                return
            if evt_type == "error":
                # "data" may be a dict with a "message" key or any scalar.
                info.error = evt.get("data", {}).get("message", "Unknown error") if isinstance(evt.get("data"), dict) else str(evt.get("data", "Unknown error"))
                info.status = TaskStatus.FAILED
                info.updated_at = time.time()
                _persist_task(info)
                return

        # Generator exhausted without explicit complete/error
        if info.status == TaskStatus.RUNNING:
            info.result = _extract_result(info.task_type, info.events)
            info.status = TaskStatus.COMPLETED
            info.updated_at = time.time()
            _persist_task(info)

    except asyncio.CancelledError:
        # Mark as interrupted but propagate so the task is truly cancelled.
        info.status = TaskStatus.INTERRUPTED
        info.updated_at = time.time()
        _persist_task(info)
        raise
    except Exception as e:
        logger.exception("Background task %s failed: %s", info.task_id, e)
        info.error = str(e)
        info.status = TaskStatus.FAILED
        info.updated_at = time.time()
        _persist_task(info)
+
+
def cleanup_stale_tasks(data_root: Optional[str] = None) -> int:
    """
    Startup sweep over on-disk task files: any snapshot still marked
    "running" was orphaned by a crash/restart and is rewritten as
    "interrupted". Returns the number of files updated.
    """
    base = data_root or DATA_ROOT
    interrupted = 0
    if not os.path.exists(base):
        return interrupted

    for user_dir in os.listdir(base):
        tasks_dir = os.path.join(base, user_dir, "tasks")
        if not os.path.isdir(tasks_dir):
            continue
        json_names = (n for n in os.listdir(tasks_dir) if n.endswith(".json"))
        for fname in json_names:
            fpath = os.path.join(tasks_dir, fname)
            try:
                with open(fpath, "r", encoding="utf-8") as fh:
                    payload = json.load(fh)
                if payload.get("status") != TaskStatus.RUNNING:
                    continue
                payload["status"] = TaskStatus.INTERRUPTED
                payload["updated_at"] = time.time()
                with open(fpath, "w", encoding="utf-8") as fh:
                    json.dump(payload, fh, ensure_ascii=False)
                interrupted += 1
                logger.info("Marked stale task as interrupted: %s", fname)
            except Exception as e:
                # Best-effort sweep: a bad file must not abort startup.
                logger.warning("Failed to process stale task file %s: %s", fpath, e)

    return interrupted
+
+
+# --------------- Persistence helpers ---------------
+
def _task_dir(user: str) -> str:
    """Return the per-user tasks directory, creating it if needed."""
    path = os.path.join(DATA_ROOT, user, "tasks")
    os.makedirs(path, exist_ok=True)
    return path
+
+
def _task_file_path(user: str, task_id: str) -> str:
    """Path of the JSON snapshot file for one task."""
    filename = f"{task_id}.json"
    return os.path.join(_task_dir(user), filename)
+
+
def _persist_task(info: TaskInfo) -> None:
    """Best-effort snapshot of the task to disk (chunk events omitted to
    keep files small). Failures are logged, never raised, so persistence
    problems cannot kill a running task.
    """
    try:
        serializable = _task_to_dict(info, exclude_chunks=True)
        target = _task_file_path(info.user, info.task_id)
        with open(target, "w", encoding="utf-8") as fh:
            json.dump(serializable, fh, ensure_ascii=False)
    except Exception as e:
        logger.warning("Failed to persist task %s: %s", info.task_id, e)
+
+
def _load_task_from_disk(user: str, task_id: str) -> Optional[Dict]:
    """Read a task snapshot from disk; None if missing or unreadable."""
    snapshot_path = _task_file_path(user, task_id)
    if not os.path.exists(snapshot_path):
        return None
    try:
        with open(snapshot_path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    except Exception as e:
        logger.warning("Failed to load task %s from disk: %s", task_id, e)
        return None
+
+
def _task_to_dict(info: TaskInfo, exclude_chunks: bool = False) -> Dict:
    """Serialize a TaskInfo into a plain JSON-compatible dict.

    With ``exclude_chunks=True``, streaming chunk events are dropped (they
    are redundant with the corresponding *_complete events).
    """
    if exclude_chunks:
        events = [e for e in info.events if e.get("type") not in _CHUNK_EVENT_TYPES]
    else:
        events = info.events
    # info.status may be a plain string when round-tripped through JSON.
    status = info.status.value if isinstance(info.status, TaskStatus) else info.status
    return {
        "task_id": info.task_id,
        "user": info.user,
        "node_id": info.node_id,
        "task_type": info.task_type,
        "status": status,
        "created_at": info.created_at,
        "updated_at": info.updated_at,
        "events": events,
        "result": info.result,
        "error": info.error,
    }
+
+
+def _parse_sse_event(raw: str) -> Optional[Dict]:
+ """Parse a raw SSE string like 'data: {...}\\n\\n' into a dict."""
+ raw = raw.strip()
+ if not raw:
+ return None
+ for line in raw.split("\n"):
+ line = line.strip()
+ if line.startswith("data: "):
+ try:
+ return json.loads(line[6:])
+ except json.JSONDecodeError:
+ return None
+ return None
+
+
+def _extract_result(task_type: str, events: List[Dict]) -> Dict:
+ """Reconstruct final result data from accumulated events."""
+ if task_type == "council":
+ stage1 = None
+ stage2 = None
+ stage3 = None
+ for evt in events:
+ t = evt.get("type", "")
+ if t == "stage1_complete":
+ stage1 = evt.get("data")
+ elif t == "stage1_model_complete":
+ if stage1 is None:
+ stage1 = []
+ stage1.append(evt.get("data"))
+ elif t == "stage2_complete":
+ stage2 = evt.get("data")
+ elif t == "stage3_complete":
+ stage3 = evt.get("data")
+ return {"councilData": {"stage1": stage1, "stage2": stage2, "stage3": stage3}}
+
+ elif task_type == "debate":
+ rounds: List[Dict] = []
+ final_verdict = None
+ for evt in events:
+ t = evt.get("type", "")
+ if t == "round_complete":
+ rounds.append(evt.get("data", {}))
+ elif t == "final_complete":
+ final_verdict = evt.get("data")
+ elif t == "judge_decision":
+ # Attach to last round
+ if rounds:
+ rounds[-1]["judgeDecision"] = evt.get("data")
+ elif t == "model_eliminated":
+ if rounds:
+ if "eliminated" not in rounds[-1]:
+ rounds[-1]["eliminated"] = []
+ rounds[-1]["eliminated"].append(evt.get("data"))
+ return {"debateRounds": rounds, "finalVerdict": final_verdict}
+
+ return {}