"""
Background task management for debate & council operations.

Decouples execution from SSE delivery so tasks survive browser disconnects.
Results are persisted progressively to disk.
"""

import asyncio
import json
import logging
import os
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncGenerator, Dict, List, Optional

logger = logging.getLogger("contextflow.tasks")

DATA_ROOT = os.path.abspath(os.getenv("DATA_ROOT", os.path.join(os.getcwd(), "data")))


class TaskStatus(str, Enum):
    """Lifecycle states of a background task (serialized to disk by ``.value``)."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    INTERRUPTED = "interrupted"


# Chunk event types that are NOT persisted to disk (redundant with *_complete events)
_CHUNK_EVENT_TYPES = {"final_chunk", "stage3_chunk"}

# Event types that mark a successful end of the event stream.
_COMPLETE_EVENT_TYPES = frozenset({"complete", "debate_complete"})

# Statuses that only make sense while the task is alive in this process's
# registry. A task found on disk in one of these states but missing from memory
# can never make progress again (the backend restarted), so it is reported as
# interrupted. Plain strings, because that is what is read back from JSON.
_STALE_STATUSES = frozenset({TaskStatus.PENDING.value, TaskStatus.RUNNING.value})


@dataclass
class TaskInfo:
    """Live state of one background task."""

    task_id: str
    user: str
    node_id: str
    task_type: str  # "debate" | "council"
    status: TaskStatus
    created_at: float  # time.time() timestamp
    updated_at: float  # time.time() timestamp of the last change
    # Parsed SSE events in arrival order; chunk events are kept in memory only.
    events: List[Dict] = field(default_factory=list)
    result: Optional[Dict] = None
    error: Optional[str] = None
    # In-memory only (not serialized):
    # Pulsed (set, then immediately cleared) on every new event and on completion.
    new_event_signal: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
    asyncio_task: Optional[asyncio.Task] = field(default=None, repr=False)


# --------------- In-memory registry ---------------

# NOTE(review): entries are never removed here, so every task (including all of
# its streamed chunk events) stays in memory for the lifetime of the process.
# Consider evicting finished tasks once their SSE consumers are done.
_task_registry: Dict[str, TaskInfo] = {}


def register_task(task_id: str, user: str, node_id: str, task_type: str) -> TaskInfo:
    """Create and register a new task in memory + persist initial state to disk."""
    now = time.time()
    info = TaskInfo(
        task_id=task_id,
        user=user,
        node_id=node_id,
        task_type=task_type,
        status=TaskStatus.PENDING,
        created_at=now,
        updated_at=now,
    )
    _task_registry[task_id] = info
    _persist_task(info)
    return info


def get_task(task_id: str) -> Optional[TaskInfo]:
    """Get task from in-memory registry."""
    return _task_registry.get(task_id)


def get_task_status(user: str, task_id: str) -> Optional[Dict]:
    """
    Get a task status dict for ``user``, from memory first, falling back to disk.

    A task that is on disk as "pending" or "running" but is not in memory cannot
    be running in this process any more, so it is reported as "interrupted".
    Returns None when the task is unknown for this user.
    """
    info = _task_registry.get(task_id)
    # Only serve the in-memory entry to its owner; the disk fallback below is
    # already scoped to the requesting user's directory.
    if info is not None and info.user == user:
        return _task_to_dict(info)

    disk_data = _load_task_from_disk(user, task_id)
    if not disk_data:
        return None

    if disk_data.get("status") in _STALE_STATUSES:
        disk_data["status"] = TaskStatus.INTERRUPTED.value
    return disk_data


async def run_task_in_background(info: TaskInfo, event_generator: AsyncGenerator[str, None]) -> None:
    """
    Consume ``event_generator`` and record its events on ``info``.

    Each raw SSE string is parsed and appended to ``info.events``, and waiters on
    ``info.new_event_signal`` are woken. Non-chunk events are persisted as they
    arrive. A "complete"/"debate_complete" event (or exhausting the generator)
    marks the task COMPLETED with a result reconstructed from the events. An
    "error" event or an exception marks it FAILED. Cancellation marks it
    INTERRUPTED and re-raises CancelledError. Waiters are always woken once more
    when the task stops, so consumers blocked in ``wait()`` see the final status.

    NOTE(review): on an early return (complete/error event) the generator is left
    suspended, so any code after its final ``yield`` does not run.
    """
    _set_status(info, TaskStatus.RUNNING)

    try:
        async for raw_sse in event_generator:
            # raw_sse is a string like 'data: {"type": "...", ...}\n\n'
            evt = _parse_sse_event(raw_sse)
            if evt is None:
                continue
            evt_type = evt.get("type", "")

            # Always store in memory (even chunks, for live streaming)
            info.events.append(evt)
            info.updated_at = time.time()
            _notify_waiters(info)

            # Persist at milestones (skip chunk events)
            if evt_type not in _CHUNK_EVENT_TYPES:
                _persist_task(info)

            if evt_type in _COMPLETE_EVENT_TYPES:
                info.result = _extract_result(info.task_type, info.events)
                _set_status(info, TaskStatus.COMPLETED)
                return

            if evt_type == "error":
                info.error = _error_message(evt.get("data"))
                _set_status(info, TaskStatus.FAILED)
                return

        # Generator exhausted without explicit complete/error
        if info.status == TaskStatus.RUNNING:
            info.result = _extract_result(info.task_type, info.events)
            _set_status(info, TaskStatus.COMPLETED)

    except asyncio.CancelledError:
        _set_status(info, TaskStatus.INTERRUPTED)
        raise
    except Exception as e:
        logger.exception("Background task %s failed: %s", info.task_id, e)
        info.error = str(e)
        _set_status(info, TaskStatus.FAILED)
    finally:
        # Terminal transitions may happen without any event (exception,
        # cancellation); wake consumers so they observe the final status.
        _notify_waiters(info)


def _set_status(info: TaskInfo, status: TaskStatus) -> None:
    """Set ``info.status``, bump ``updated_at`` and persist the task."""
    info.status = status
    info.updated_at = time.time()
    _persist_task(info)


def _notify_waiters(info: TaskInfo) -> None:
    """Wake every coroutine currently blocked in ``info.new_event_signal.wait()``.

    ``set()`` resolves the pending waiters; ``clear()`` re-arms the event so the
    next ``wait()`` blocks until the next notification.
    """
    info.new_event_signal.set()
    info.new_event_signal.clear()


def _error_message(data: Any) -> Any:
    """Extract a human-readable message from an ``error`` event's payload."""
    if isinstance(data, dict):
        return data.get("message", "Unknown error")
    if data is None:
        return "Unknown error"
    return str(data)


def cleanup_stale_tasks(data_root: Optional[str] = None) -> int:
    """
    On startup, scan task files on disk and mark any "pending"/"running" task
    as "interrupted" (nothing can be executing them after a restart).

    Best-effort: unreadable or malformed files are logged and skipped.
    Returns count of tasks marked interrupted.
    """
    root = data_root or DATA_ROOT
    count = 0
    if not os.path.isdir(root):
        return count

    for user_dir in os.listdir(root):
        tasks_dir = os.path.join(root, user_dir, "tasks")
        if not os.path.isdir(tasks_dir):
            continue

        for fname in os.listdir(tasks_dir):
            if not fname.endswith(".json"):
                continue
            fpath = os.path.join(tasks_dir, fname)
            try:
                with open(fpath, "r", encoding="utf-8") as f:
                    data = json.load(f)
                if not isinstance(data, dict) or data.get("status") not in _STALE_STATUSES:
                    continue
                data["status"] = TaskStatus.INTERRUPTED.value
                data["updated_at"] = time.time()
                _write_json_atomic(fpath, data)
            except Exception as e:
                # One bad file must not abort startup.
                logger.warning("Failed to process stale task file %s: %s", fpath, e)
                continue
            count += 1
            logger.info("Marked stale task as interrupted: %s", fname)

    return count


# --------------- Persistence helpers ---------------

def _task_dir(user: str, create: bool = True) -> str:
    """Return ``DATA_ROOT/<user>/tasks``, creating it unless ``create`` is False."""
    d = os.path.join(DATA_ROOT, user, "tasks")
    if create:
        os.makedirs(d, exist_ok=True)
    return d


def _task_file_path(user: str, task_id: str, create: bool = True) -> str:
    """Return the JSON file path for ``task_id`` of ``user``.

    The tasks directory is created unless ``create`` is False.
    """
    # NOTE(review): user and task_id are joined into the path unvalidated; if
    # either can come from a request, values containing path separators or ".."
    # could escape DATA_ROOT — confirm callers sanitize them.
    return os.path.join(_task_dir(user, create=create), f"{task_id}.json")


def _write_json_atomic(path: str, data: Any) -> None:
    """Serialize ``data`` as JSON into ``path`` via a temp file + ``os.replace``.

    A crash mid-write leaves the previous file intact instead of a truncated one.
    """
    tmp_path = f"{path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def _persist_task(info: TaskInfo) -> None:
    """Write task state to disk (best-effort; failures are logged, not raised).

    Excludes chunk events for smaller files.
    """
    try:
        data = _task_to_dict(info, exclude_chunks=True)
        _write_json_atomic(_task_file_path(info.user, info.task_id), data)
    except Exception as e:
        logger.warning("Failed to persist task %s: %s", info.task_id, e)


def _load_task_from_disk(user: str, task_id: str) -> Optional[Dict]:
    """Load a persisted task dict, or None if it is missing, unreadable or malformed.

    Does not create any directories.
    """
    path = _task_file_path(user, task_id, create=False)
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return None
    except Exception as e:
        logger.warning("Failed to load task %s from disk: %s", task_id, e)
        return None

    if not isinstance(data, dict):
        logger.warning("Ignoring malformed task file %s (expected a JSON object)", path)
        return None
    return data


def _task_to_dict(info: TaskInfo, exclude_chunks: bool = False) -> Dict:
    """Serialize ``info`` into a JSON-compatible dict (in-memory-only fields omitted).

    The events list is always a new list, so callers cannot mutate the live
    task's event list through it. ``exclude_chunks`` drops streaming chunk events.
    """
    if exclude_chunks:
        events = [e for e in info.events if e.get("type") not in _CHUNK_EVENT_TYPES]
    else:
        events = list(info.events)
    status = info.status.value if isinstance(info.status, TaskStatus) else info.status
    return {
        "task_id": info.task_id,
        "user": info.user,
        "node_id": info.node_id,
        "task_type": info.task_type,
        "status": status,
        "created_at": info.created_at,
        "updated_at": info.updated_at,
        "events": events,
        "result": info.result,
        "error": info.error,
    }


def _parse_sse_event(raw: str) -> Optional[Dict]:
    """Parse a raw SSE string like 'data: {...}\\n\\n' into a dict.

    Only the first ``data:`` line is considered (the space after the colon is
    optional, as in the SSE format). Returns None when there is no data line,
    the payload is not valid JSON, or the payload is not a JSON object.
    """
    for line in raw.strip().split("\n"):
        line = line.strip()
        if not line.startswith("data:"):
            continue
        try:
            payload = json.loads(line[len("data:"):])
        except json.JSONDecodeError:
            return None
        # Callers use .get() on the result, so anything but an object is ignored.
        return payload if isinstance(payload, dict) else None
    return None


def _extract_result(task_type: str, events: List[Dict]) -> Dict:
    """Reconstruct final result data from accumulated events.

    Returns {} for unknown task types. The events themselves are not modified.
    """
    if task_type == "council":
        return _extract_council_result(events)
    if task_type == "debate":
        return _extract_debate_result(events)
    return {}


def _extract_council_result(events: List[Dict]) -> Dict:
    """Build ``{"councilData": {...}}`` from the stage completion events."""
    stage1: Any = None
    stage2: Any = None
    stage3: Any = None

    for evt in events:
        evt_type = evt.get("type", "")
        if evt_type == "stage1_complete":
            stage1 = evt.get("data")
        elif evt_type == "stage1_model_complete":
            # Copy before appending, so a list taken from a stage1_complete
            # payload (still referenced by that stored event) is never mutated.
            if stage1 is None:
                stage1 = []
            elif isinstance(stage1, list):
                stage1 = list(stage1)
            stage1.append(evt.get("data"))
        elif evt_type == "stage2_complete":
            stage2 = evt.get("data")
        elif evt_type == "stage3_complete":
            stage3 = evt.get("data")

    return {"councilData": {"stage1": stage1, "stage2": stage2, "stage3": stage3}}


def _extract_debate_result(events: List[Dict]) -> Dict:
    """Build ``{"debateRounds": [...], "finalVerdict": ...}`` from debate events.

    Judge decisions and eliminations are attached to the most recent round.
    Round payloads are copied first so the stored events stay untouched.
    """
    rounds: List[Any] = []
    final_verdict = None

    for evt in events:
        evt_type = evt.get("type", "")
        if evt_type == "round_complete":
            round_data = evt.get("data")
            if round_data is None:
                round_data = {}
            elif isinstance(round_data, dict):
                round_data = dict(round_data)
            rounds.append(round_data)
        elif evt_type == "final_complete":
            final_verdict = evt.get("data")
        elif evt_type == "judge_decision":
            # Attach to last round
            if rounds:
                rounds[-1]["judgeDecision"] = evt.get("data")
        elif evt_type == "model_eliminated":
            if rounds:
                current = rounds[-1]
                # Build a new list so a list inside the payload is never mutated.
                eliminated = list(current.get("eliminated") or [])
                eliminated.append(evt.get("data"))
                current["eliminated"] = eliminated

    return {"debateRounds": rounds, "finalVerdict": final_verdict}