diff options
Diffstat (limited to 'broker.py')
| -rw-r--r-- | broker.py | 1580 |
1 files changed, 1580 insertions, 0 deletions
diff --git a/broker.py b/broker.py new file mode 100644 index 0000000..ee073fb --- /dev/null +++ b/broker.py @@ -0,0 +1,1580 @@ +""" +Claude Bridge - Broker Server +消息中转 + Telegram Bot + 任务队列 +部署在私人服务器上,24h 运行 +""" + +import asyncio +import hashlib +import hmac +import os +import shlex +import sqlite3 +import subprocess +import time +import uuid +from contextlib import asynccontextmanager + +import httpx +from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile, File, Form +from fastapi.responses import JSONResponse, FileResponse +from pydantic import BaseModel + +# === Config === + +TELEGRAM_TOKEN = os.environ["TELEGRAM_TOKEN"] +TELEGRAM_CHAT_ID = os.environ["TELEGRAM_CHAT_ID"] # 你自己的 chat id +API_SECRET = os.environ["API_SECRET"] # broker 认证密钥 +SLACK_BOT_TOKEN = os.environ.get("SLACK_BOT_TOKEN", "") +SLACK_APP_TOKEN = os.environ.get("SLACK_APP_TOKEN", "") +SLACK_ALLOWED_USERS = os.environ.get("SLACK_ALLOWED_USERS", "").split(",") # 逗号分隔的 user ID 白名单,空=全部放行 +DEFAULT_DISPATCHER = "dispatcher" # 默认 dispatcher session(owner 的) +DB_PATH = os.environ.get("DB_PATH", "bridge.db") +POLL_INTERVAL = int(os.environ.get("POLL_INTERVAL", "2")) # Telegram 轮询间隔(秒) +DISPATCHER_TMUX = os.environ.get("DISPATCHER_TMUX", "dispatcher") # dispatcher 的 tmux session 名 +HEARTBEAT_TIMEOUT = int(os.environ.get("HEARTBEAT_TIMEOUT", "600")) # 心跳超时秒数,默认 10 分钟 +FILES_DIR = os.environ.get("FILES_DIR", os.path.join(os.path.dirname(os.path.abspath(__file__)), "files")) +os.makedirs(FILES_DIR, exist_ok=True) + + +# === Database === + +def get_db() -> sqlite3.Connection: + conn = sqlite3.connect(DB_PATH) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + return conn + + +def init_db(): + db = get_db() + db.executescript(""" + CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + chat_id TEXT NOT NULL, + text TEXT NOT NULL, + timestamp REAL NOT NULL, + processed INTEGER DEFAULT 0 + ); + CREATE TABLE IF NOT EXISTS tasks ( + id TEXT PRIMARY KEY, + target TEXT NOT NULL DEFAULT '', + type TEXT NOT NULL DEFAULT 'task', + content TEXT NOT NULL, + status TEXT DEFAULT 'pending', + result TEXT DEFAULT '', + created_at REAL NOT NULL, + updated_at REAL NOT NULL + ); + """) + db.execute(""" + CREATE TABLE IF NOT EXISTS workers ( + id TEXT PRIMARY KEY, + session_name TEXT NOT NULL, + host TEXT NOT NULL, + path TEXT NOT NULL, + status TEXT DEFAULT 'active', + created_at REAL NOT NULL + ); + """) + db.execute(""" + CREATE TABLE IF NOT EXISTS channel_bindings ( + channel_id TEXT PRIMARY KEY, + worker_id TEXT NOT NULL, + channel_name TEXT DEFAULT '', + created_by TEXT DEFAULT '', + created_at REAL NOT NULL + ); + """) + db.execute(""" + CREATE TABLE IF NOT EXISTS schedules ( + id TEXT PRIMARY KEY, + action TEXT NOT NULL, + trigger_at REAL NOT NULL, + repeat_seconds REAL DEFAULT 0, + status TEXT DEFAULT 'active', + created_at REAL NOT NULL + ); + """) + db.execute(""" + CREATE TABLE IF NOT EXISTS commands ( + id TEXT PRIMARY KEY, + target TEXT NOT NULL, + action TEXT NOT NULL, + params TEXT NOT NULL DEFAULT '{}', + status TEXT DEFAULT 'pending', + result TEXT DEFAULT '', + created_at REAL NOT NULL + ); + """) + db.execute(""" + CREATE TABLE IF NOT EXISTS pending_context ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + text TEXT NOT NULL, + timestamp REAL NOT NULL + ); + """) + db.execute(""" + CREATE TABLE IF NOT EXISTS users ( + id TEXT PRIMARY KEY, + slack_user_id TEXT DEFAULT '', + telegram_chat_id TEXT DEFAULT '', + dispatcher_session TEXT NOT NULL, + display_name TEXT DEFAULT '', + status TEXT DEFAULT 'active', + created_at REAL NOT NULL + ); + """) + # 兼容旧表:加新列 + migrations = [ + ("tasks", "target", "''"), + ("tasks", "type", "'task'"), + ("tasks", "dispatcher_id", "''"), + ("schedules", "dispatcher_id", "''"), + ("messages", "dispatcher_id", "''"), + ("pending_context", "dispatcher_id", "''"), + ("commands", "dispatcher_id", "''"), + ("channel_bindings", "created_by", "''"), + ] + for table, col, default in migrations: + try: + db.execute(f"ALTER TABLE {table} ADD COLUMN {col} TEXT NOT NULL DEFAULT {default}") + except Exception: + pass + # 创建默认用户(迁移:把现有 dispatcher 绑定给 owner) + existing = db.execute("SELECT id FROM users WHERE id = 'owner'").fetchone() + if not existing: + db.execute( + "INSERT INTO users (id, slack_user_id, telegram_chat_id, dispatcher_session, display_name, status, created_at) VALUES (?, ?, ?, ?, ?, 'active', ?)", + ("owner", os.environ.get("SLACK_OWNER_ID", ""), TELEGRAM_CHAT_ID, "dispatcher", "Owner", time.time()), + ) + db.commit() + db.close() + + +# === User Helpers === + +def _get_user(user_id: str) -> dict | None: + db = get_db() + row = db.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone() + db.close() + return dict(row) if row else None + + +def _get_user_by_slack(slack_user_id: str) -> dict | None: + db = get_db() + row = db.execute("SELECT * FROM users WHERE slack_user_id = ?", (slack_user_id,)).fetchone() + db.close() + return dict(row) if row else None + + +def _get_user_by_telegram(telegram_chat_id: str) -> dict | None: + db = get_db() + row = db.execute("SELECT * FROM users WHERE telegram_chat_id = ?", (telegram_chat_id,)).fetchone() + db.close() + return dict(row) if row else None + + +def _get_dispatcher_session(dispatcher_id: str) -> str: + """Look up tmux session name for a dispatcher_id""" + if not dispatcher_id: + return DEFAULT_DISPATCHER + user = _get_user(dispatcher_id) + return user["dispatcher_session"] if user else DEFAULT_DISPATCHER + + +# worker session → last dispatcher_id mapping (for reply_to_dispatcher routing) +_worker_dispatcher_map: dict = {} + + +# === Auth === + +async def verify_token(request: Request): + auth = request.headers.get("Authorization", "") + if not auth.startswith("Bearer "): + raise HTTPException(401, "Missing token") + token = auth[7:] + if not hmac.compare_digest(token, API_SECRET): + raise HTTPException(403, "Invalid token") + + +# === Notify Dispatcher === + +_last_long_notify = 0 # 防抖:上次发长消息通知的时间 + + +def _notify_dispatcher(text: str, dispatcher_id: str = ""): + """推送消息到指定 dispatcher。短消息直接 send-keys,长消息写文件。""" + global _last_long_notify + session = _get_dispatcher_session(dispatcher_id) + try: + safe_text = text.replace("\n", " ").replace("\r", "") + if len(safe_text) <= 400: + subprocess.run( + ["tmux", "send-keys", "-t", session, safe_text, "Enter"], + timeout=5, capture_output=True, + ) + else: + msg_id = uuid.uuid4().hex[:8] + msg_path = os.path.join(FILES_DIR, f"msg_{msg_id}.txt") + with open(msg_path, "w") as f: + f.write(text) + preview = safe_text[:80] + subprocess.run( + ["tmux", "send-keys", "-t", session, + f"[Long message saved to {msg_path} — Read it first then reply] Preview: {preview}...", "Enter"], + timeout=5, capture_output=True, + ) + print(f"[Broker] Pushed to dispatcher: {text[:60]}...") + except Exception as e: + print(f"[Broker] Failed to notify dispatcher: {e}") + + +# === Telegram Polling === + +class TelegramPoller: + def __init__(self): + self.offset = 0 + self.running = False + + async def start(self): + self.running = True + async with httpx.AsyncClient(timeout=60) as client: + while self.running: + try: + resp = await client.get( + f"https://api.telegram.org/bot{TELEGRAM_TOKEN}/getUpdates", + params={"offset": self.offset, "timeout": 30}, + ) + data = resp.json() + if data.get("ok"): + for update in data.get("result", []): + self.offset = update["update_id"] + 1 + await self._handle_update(update) + except Exception as e: + print(f"[Telegram Poll Error] {e}") + await asyncio.sleep(POLL_INTERVAL) + + async def _handle_update(self, update: dict): + msg = update.get("message", {}) + chat_id = str(msg.get("chat", {}).get("id", "")) + + if not chat_id: + return + + # 只接收指定 chat_id 的消息(安全) + if chat_id != TELEGRAM_CHAT_ID: + print(f"[Telegram] Ignored message from unknown chat: {chat_id}") + return + + # 处理语音消息 + voice = msg.get("voice") + if voice: + text = await self._transcribe_voice(voice["file_id"]) + if not text: + return + text = f"[voice] {text}" + # 处理文件/文档 + elif msg.get("document"): + file_id = msg["document"]["file_id"] + file_name = msg["document"].get("file_name", f"file_{file_id}") + saved_name = await self._download_telegram_file(file_id, file_name) + if saved_name: + caption = msg.get("caption", "") + text = f"[file] {saved_name}" + (f" caption: {caption}" if caption else "") + else: + return + # 处理图片 + elif msg.get("photo"): + # Telegram 发多个尺寸,取最大的 + photo = msg["photo"][-1] + file_id = photo["file_id"] + saved_name = await self._download_telegram_file(file_id, f"photo_{int(time.time())}.jpg") + if saved_name: + caption = msg.get("caption", "") + text = f"[image] {saved_name}" + (f" caption: {caption}" if caption else "") + else: + return + else: + text = msg.get("text", "") + + if not text: + return + + db = get_db() + db.execute( + "INSERT INTO messages (chat_id, text, timestamp) VALUES (?, ?, ?)", + (chat_id, text, time.time()), + ) + db.commit() + db.close() + print(f"[Telegram] Stored message: {text[:50]}...") + + # 推送到 dispatcher + # Telegram 只给 owner 的 dispatcher + user = _get_user_by_telegram(chat_id) + did = user["id"] if user else "owner" + _notify_dispatcher(f"[from_telegram reply with send_telegram_message] {text}", dispatcher_id=did) + + async def _download_telegram_file(self, file_id: str, file_name: str) -> str: + """从 Telegram 下载文件,存到 FILES_DIR,返回保存的文件名""" + try: + async with httpx.AsyncClient(timeout=60) as client: + resp = await client.get( + f"https://api.telegram.org/bot{TELEGRAM_TOKEN}/getFile", + params={"file_id": file_id}, + ) + file_path = resp.json()["result"]["file_path"] + resp = await client.get( + f"https://api.telegram.org/file/bot{TELEGRAM_TOKEN}/{file_path}" + ) + # 防重名 + safe_name = f"{uuid.uuid4().hex[:8]}_{file_name}" + save_path = os.path.join(FILES_DIR, safe_name) + with open(save_path, "wb") as f: + f.write(resp.content) + print(f"[Telegram] File saved: {safe_name} ({len(resp.content)} bytes)") + return safe_name + except Exception as e: + print(f"[Telegram] File download failed: {e}") + return "" + + async def _transcribe_voice(self, file_id: str) -> str: + """下载 Telegram 语音文件并用 whisper 转文字""" + try: + async with httpx.AsyncClient(timeout=30) as client: + # 获取文件路径 + resp = await client.get( + f"https://api.telegram.org/bot{TELEGRAM_TOKEN}/getFile", + params={"file_id": file_id}, + ) + file_path = resp.json()["result"]["file_path"] + + # 下载文件 + resp = await client.get( + f"https://api.telegram.org/file/bot{TELEGRAM_TOKEN}/{file_path}" + ) + ogg_path = f"/tmp/voice_{file_id}.ogg" + with open(ogg_path, "wb") as f: + f.write(resp.content) + + print(f"[Telegram] Voice downloaded: {ogg_path}") + + # 用 faster-whisper 转文字 + from faster_whisper import WhisperModel + if not hasattr(self, '_whisper_model'): + print("[Whisper] Loading model (first time)...") + self._whisper_model = WhisperModel("base", device="cpu", compute_type="int8") + print("[Whisper] Model loaded.") + + segments, info = self._whisper_model.transcribe(ogg_path, language=None) + text = "".join(seg.text for seg in segments).strip() + + # 清理临时文件 + os.remove(ogg_path) + + print(f"[Whisper] Transcribed ({info.language}): {text[:80]}...") + return text + + except Exception as e: + print(f"[Whisper] Transcription failed: {e}") + return "" + + def stop(self): + self.running = False + + +poller = TelegramPoller() + + +# === Slack Socket Mode === + +_slack_handler = None +_slack_task_channels = {} # task_id → slack channel_id,worker 完成后回发 + +def _start_slack_listener(): + """在后台线程启动 Slack Socket Mode listener""" + if not SLACK_BOT_TOKEN or not SLACK_APP_TOKEN: + print("[Slack] No tokens configured, skipping") + return + + import threading + from slack_sdk.web import WebClient + from slack_sdk.socket_mode import SocketModeClient + from slack_sdk.socket_mode.request import SocketModeRequest + from slack_sdk.socket_mode.response import SocketModeResponse + + slack_client = WebClient(token=SLACK_BOT_TOKEN) + + # 缓存 bot 自己的 user id + try: + auth = slack_client.auth_test() + bot_user_id = auth["user_id"] + print(f"[Slack] Bot user ID: {bot_user_id}") + except Exception as e: + print(f"[Slack] Auth failed: {e}") + return + + # 缓存 user id → display name + _user_cache = {} + + def _get_user_name(user_id): + if user_id not in _user_cache: + try: + info = slack_client.users_info(user=user_id) + _user_cache[user_id] = info["user"]["profile"].get("display_name") or info["user"]["real_name"] or user_id + except Exception: + _user_cache[user_id] = user_id + return _user_cache[user_id] + + def _get_channel_binding(channel_id): + """查询 channel 是否绑定了 worker""" + db = get_db() + row = db.execute( + "SELECT w.session_name, w.host, w.path, w.id FROM channel_bindings cb JOIN workers w ON cb.worker_id = w.id WHERE cb.channel_id = ?", + (channel_id,), + ).fetchone() + db.close() + return dict(row) if row else None + + def _find_worker_by_path(host, path): + """根据 host:path 查找已有 worker""" + db = get_db() + row = db.execute( + "SELECT * FROM workers WHERE host = ? AND path = ? AND status = 'active'", + (host, path), + ).fetchone() + db.close() + return dict(row) if row else None + + def _create_worker(session_name, host, path): + """注册新 worker 到 DB""" + worker_id = uuid.uuid4().hex[:8] + db = get_db() + db.execute( + "INSERT INTO workers (id, session_name, host, path, status, created_at) VALUES (?, ?, ?, ?, 'active', ?)", + (worker_id, session_name, host, path, time.time()), + ) + db.commit() + db.close() + return worker_id + + def _bind_channel(channel_id, worker_id, channel_name="", created_by=""): + """绑定 channel 到 worker""" + db = get_db() + db.execute( + "INSERT OR REPLACE INTO channel_bindings (channel_id, worker_id, channel_name, created_by, created_at) VALUES (?, ?, ?, ?, ?)", + (channel_id, worker_id, channel_name, created_by, time.time()), + ) + db.commit() + db.close() + + def _handle_event(client: SocketModeClient, req: SocketModeRequest): + # 立刻 ack + client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id)) + + # Debug: log all request types + print(f"[Slack] req.type={req.type}") + + # --- Slash Command --- + if req.type in ("slash_commands", "interactive"): + payload = req.payload + command = payload.get("command", "") + cmd_text = payload.get("text", "").strip() + channel = payload.get("channel_id", "") + user_id = payload.get("user_id", "") + channel_name = payload.get("channel_name", "") + + print(f"[Slack] Slash command: {command} text={cmd_text} channel={channel}") + + if command == "/help": + slack_client.chat_postMessage(channel=channel, text=( + "*Claude Bridge Commands*\n\n" + "`/register` — Register for a personal dispatcher\n" + "`/init-worker host:path` — Create or bind a worker to this channel\n" + "`/meta` — Show worker info (status, heartbeat, tasks)\n" + "`/restart-worker` — Restart the worker bound to this channel\n" + "`/stop-worker` — Stop worker and unbind from channel\n" + "`/unbind` — Unbind worker without stopping it\n" + "`/help` — Show this message\n\n" + "*In channels with a bound worker:*\n" + "• `@bot message` — Chat with the worker (instant, stays in context)\n" + "• `@bot /task description` — Submit a formal task (queued, tracked)\n\n" + "*In DM:*\n" + "• Message the bot directly to chat with your personal dispatcher" + )) + elif command == "/register": + _handle_register(channel, user_id) + elif command == "/init-worker": + _handle_init_worker(channel, channel_name, user_id, cmd_text) + elif command == "/meta": + _handle_meta(channel) + elif command == "/restart-worker": + _handle_worker_command(channel, "restart", user_id) + elif command == "/stop-worker": + _handle_worker_command(channel, "stop", user_id) + elif command == "/unbind": + _handle_unbind(channel, user_id) + return + + if req.type != "events_api": + return + + event = req.payload.get("event", {}) + event_type = event.get("type", "") + user_id = event.get("user", "") + text = event.get("text", "") + channel = event.get("channel", "") + channel_type = event.get("channel_type", "") + + # --- Bot 被加到 channel --- + if event_type == "member_joined_channel" and user_id == bot_user_id: + try: + ch_info = slack_client.conversations_info(channel=channel) + ch_name = ch_info["channel"].get("name", channel) + except Exception: + ch_name = channel + binding = _get_channel_binding(channel) + if binding: + slack_client.chat_postMessage( + channel=channel, + text=f"Bound to worker `{binding['session_name']}` ({binding['host']}:{binding['path']}). @mention me to chat.", + ) + else: + slack_client.chat_postMessage( + channel=channel, + text=f"👋 Hi! No worker bound to this channel yet.\nUse `/init-worker host:path` to initialize, e.g.:\n`/init-worker timan1:/home/yurenh2/graph-grape`", + ) + return + + # 忽略 bot 自己的消息 + if user_id == bot_user_id: + return + if not text: + return + + # 权限过滤:已注册用户 或 白名单里的用户 + is_registered = _get_user_by_slack(user_id) is not None + in_whitelist = SLACK_ALLOWED_USERS and SLACK_ALLOWED_USERS != [''] and user_id in SLACK_ALLOWED_USERS + if not is_registered and not in_whitelist: + return + + # 去掉 @mention 标记 + text = text.replace(f"<@{bot_user_id}>", "").strip() + + user_name = _get_user_name(user_id) + + # 获取 channel 名 + channel_name = channel + try: + ch_info = slack_client.conversations_info(channel=channel) + channel_name = ch_info["channel"].get("name", channel) + except Exception: + pass + + if event_type == "message" and channel_type == "im": + # DM → 查用户 → 路由到该用户的 dispatcher + user_rec = _get_user_by_slack(user_id) + if not user_rec: + slack_client.chat_postMessage(channel=channel, text="Please register first with /register") + return + did = user_rec["id"] + tag = f"[from_slack DM @{user_name} reply with send_slack_message channel={channel}]" + full_text = f"{tag} {text}" + db = get_db() + db.execute("INSERT INTO messages (chat_id, text, timestamp, dispatcher_id) VALUES (?, ?, ?, ?)", + (f"slack:{channel}:{user_id}", full_text, time.time(), did)) + db.commit() + db.close() + print(f"[Slack] DM from {user_name} → dispatcher {did}: {text[:50]}...") + _notify_dispatcher(full_text, dispatcher_id=did) + + elif event_type == "app_mention": + # Channel @bot → 检查是否绑定了 worker + binding = _get_channel_binding(channel) + if binding: + session = binding["session_name"] + # 默认走 message(问答模式),worker 用 reply_to_slack 回复 + # 只有 /task 命令才走任务模式 + if text.startswith("/task "): + # 任务模式 + task_content = text[6:].strip() + task_id = uuid.uuid4().hex[:8] + now = time.time() + db = get_db() + db.execute( + "INSERT INTO tasks (id, target, type, content, status, created_at, updated_at) VALUES (?, ?, 'task', ?, 'pending', ?, ?)", + (task_id, session, f"[from Slack #{channel_name} @{user_name}] {task_content}\n\n[Use report_result when done — result will be posted to Slack channel]", now, now), + ) + db.commit() + db.close() + _slack_task_channels[task_id] = channel + slack_client.chat_postMessage(channel=channel, text=f"📋 Task dispatched to `{session}`, ID: `{task_id}`") + print(f"[Slack] #{channel_name} TASK → {session}: {task_content[:50]}...") + else: + # 问答模式:message 直接注入 worker context + # worker 用 reply_to_slack 回复,broker 转发到 channel + task_id = uuid.uuid4().hex[:8] + now = time.time() + db = get_db() + db.execute( + "INSERT INTO tasks (id, target, type, content, status, created_at, updated_at) VALUES (?, ?, 'message', ?, 'pending', ?, ?)", + (task_id, session, f"[from Slack #{channel_name} @{user_name} — reply with reply_to_slack] {text}", now, now), + ) + db.commit() + db.close() + # 记录 message→channel 映射,reply_to_slack 时回发 + _slack_task_channels[task_id] = channel + print(f"[Slack] #{channel_name} MSG → {session}: {text[:50]}...") + else: + # 没绑定 → 路由到发送者的 dispatcher + user_rec = _get_user_by_slack(user_id) + if not user_rec: + slack_client.chat_postMessage(channel=channel, text="Please register first with /register") + return + did = user_rec["id"] + tag = f"[from_slack #{channel_name} @{user_name} reply with send_slack_message channel={channel}]" + full_text = f"{tag} {text}" + db = get_db() + db.execute("INSERT INTO messages (chat_id, text, timestamp, dispatcher_id) VALUES (?, ?, ?, ?)", + (f"slack:{channel}:{user_id}", full_text, time.time(), did)) + db.commit() + db.close() + print(f"[Slack] #{channel_name} (unbound) → dispatcher {did}: {text[:50]}...") + _notify_dispatcher(full_text, dispatcher_id=did) + + # task→channel 映射用全局变量 + global _slack_task_channels + + def _handle_register(channel, slack_user_id): + """处理 /register — 注册新用户 + 创建 dispatcher""" + existing = _get_user_by_slack(slack_user_id) + if existing: + slack_client.chat_postMessage( + channel=channel, + text=f"You are already registered. Dispatcher: `{existing['dispatcher_session']}`\nDM me to chat with your dispatcher.", + ) + return + + user_name = _get_user_name(slack_user_id) + user_id = uuid.uuid4().hex[:8] + session_name = f"dispatcher-{user_id}" + + # 创建 user 记录 + db = get_db() + db.execute( + "INSERT INTO users (id, slack_user_id, dispatcher_session, display_name, status, created_at) VALUES (?, ?, ?, ?, 'active', ?)", + (user_id, slack_user_id, session_name, user_name, time.time()), + ) + db.commit() + db.close() + + # 创建 dispatcher tmux session + workspace = os.path.expanduser(f"~/dispatchers/{user_id}") + os.makedirs(workspace, exist_ok=True) + # 复制英文 CLAUDE.md + en_md = os.path.join(os.path.dirname(os.path.abspath(__file__)), "CLAUDE.md.en") + if os.path.exists(en_md): + import shutil + shutil.copy2(en_md, os.path.join(workspace, "CLAUDE.md")) + + subprocess.run(["tmux", "new-session", "-d", "-s", session_name], timeout=5) + subprocess.run(["tmux", "send-keys", "-t", session_name, f"cd {workspace}", "Enter"], timeout=5) + time.sleep(1) + subprocess.run(["tmux", "send-keys", "-t", session_name, f"export DISPATCHER_USER_ID={user_id}", "Enter"], timeout=5) + time.sleep(0.5) + subprocess.run(["tmux", "send-keys", "-t", session_name, "claude", "Enter"], timeout=5) + + slack_client.chat_postMessage( + channel=channel, + text=f"✅ Registered! Your dispatcher: `{session_name}`\nDM me to start chatting with your personal assistant.", + ) + print(f"[Register] Created user {user_id} ({user_name}) with dispatcher {session_name}") + + def _handle_meta(channel): + """处理 /meta 命令 — 显示 channel 绑定的 worker 元数据""" + binding = _get_channel_binding(channel) + if not binding: + slack_client.chat_postMessage(channel=channel, text="No worker bound to this channel. Use `/init-worker host:path` first.") + return + + session = binding["session_name"] + host = binding["host"] + path = binding["path"] + worker_id = binding["id"] + + # 心跳 + hb = heartbeats.get(session, {}) + if hb: + age = int(time.time() - hb["last_seen"]) + alive = "🟢 online" if age < HEARTBEAT_TIMEOUT else f"🔴 offline ({age}s ago)" + last_hb = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(hb["last_seen"])) + else: + alive = "⚪ no heartbeat" + last_hb = "N/A" + + # 任务统计 + db = get_db() + running = db.execute("SELECT COUNT(*) FROM tasks WHERE target = ? AND status = 'running'", (session,)).fetchone()[0] + pending = db.execute("SELECT COUNT(*) FROM tasks WHERE target = ? AND status = 'pending'", (session,)).fetchone()[0] + done = db.execute("SELECT COUNT(*) FROM tasks WHERE target = ? AND status = 'done'", (session,)).fetchone()[0] + + # 获取 worker 创建时间 + w_row = db.execute("SELECT created_at FROM workers WHERE id = ?", (worker_id,)).fetchone() + db.close() + created = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(w_row[0])) if w_row else "N/A" + + text = ( + f"*Worker Meta*\n" + f"• Session: `{session}`\n" + f"• Host: `{host}`\n" + f"• Path: `{path}`\n" + f"• Status: {alive}\n" + f"• Last heartbeat: {last_hb}\n" + f"• Created: {created}\n" + f"• Tasks — running: {running}, pending: {pending}, done: {done}" + ) + slack_client.chat_postMessage(channel=channel, text=text) + + def _check_pm(channel, slack_user_id): + """检查用户是否是该 channel 的 PM(绑定者)""" + db = get_db() + row = db.execute("SELECT created_by FROM channel_bindings WHERE channel_id = ?", (channel,)).fetchone() + db.close() + if not row or not row["created_by"]: + return True # 没记录创建者,放行 + return row["created_by"] == slack_user_id + + def _handle_worker_command(channel, action, slack_user_id=""): + """处理 /restart-worker 和 /stop-worker""" + binding = _get_channel_binding(channel) + if not binding: + slack_client.chat_postMessage(channel=channel, text="No worker bound to this channel.") + return + if not _check_pm(channel, slack_user_id): + slack_client.chat_postMessage(channel=channel, text="Permission denied. Only the PM (who ran `/init-worker`) can do this.") + return + + session = binding["session_name"] + host = binding["host"] + cmd_id = uuid.uuid4().hex[:8] + import json as _json + db = get_db() + + if action == "restart": + db.execute( + "INSERT INTO commands (id, target, action, params, status, created_at) VALUES (?, ?, 'restart', '{}', 'pending', ?)", + (cmd_id, session, time.time()), + ) + db.commit() + db.close() + slack_client.chat_postMessage(channel=channel, text=f"🔄 Restarting worker `{session}` on {host}...") + elif action == "stop": + db.execute( + "INSERT INTO commands (id, target, action, params, status, created_at) VALUES (?, ?, 'stop', '{}', 'pending', ?)", + (cmd_id, session, time.time()), + ) + # 解绑 channel + 标记 worker inactive + db.execute("DELETE FROM channel_bindings WHERE channel_id = ?", (channel,)) + db.execute("UPDATE workers SET status = 'inactive' WHERE session_name = ?", (session,)) + db.commit() + db.close() + slack_client.chat_postMessage(channel=channel, text=f"🛑 Stopping worker `{session}` on {host}... Channel unbound.") + + def _handle_unbind(channel, slack_user_id=""): + """解绑 channel 和 worker""" + binding = _get_channel_binding(channel) + if not binding: + slack_client.chat_postMessage(channel=channel, text="No worker bound to this channel.") + return + if not _check_pm(channel, slack_user_id): + slack_client.chat_postMessage(channel=channel, text="Permission denied. Only the PM (who ran `/init-worker`) can do this.") + return + session = binding["session_name"] + db = get_db() + db.execute("DELETE FROM channel_bindings WHERE channel_id = ?", (channel,)) + db.commit() + db.close() + slack_client.chat_postMessage(channel=channel, text=f"Unbound from worker `{session}`. Use `/init-worker` to bind a new one.") + + def _handle_init_worker(channel, channel_name, user_id, cmd_text): + """处理 /init-worker 命令""" + if not cmd_text or ":" not in cmd_text: + slack_client.chat_postMessage( + channel=channel, + text="Usage: `/init-worker host:path`\nExample: `/init-worker timan1:/home/yurenh2/graph-grape`", + ) + return + + host, path = cmd_text.split(":", 1) + host = host.strip() + path = path.strip() + + # 检查是否已有这个路径的 worker + existing = _find_worker_by_path(host, path) + if existing: + # 已有 worker,直接绑定 + _bind_channel(channel, existing["id"], channel_name, created_by=user_id) + slack_client.chat_postMessage( + channel=channel, + text=f"✅ Bound to existing worker `{existing['session_name']}` ({host}:{path})\nYou can now @mention me to chat with the worker.", + ) + print(f"[Slack] Channel {channel} bound to existing worker {existing['session_name']}") + else: + # 需要创建新 worker + # 生成 session 名 + session_name = f"worker-{uuid.uuid4().hex[:6]}" + worker_id = _create_worker(session_name, host, path) + _bind_channel(channel, worker_id, channel_name, created_by=user_id) + + # 下发创建命令到 lab + cmd_id = uuid.uuid4().hex[:8] + import json as _json + db = get_db() + db.execute( + "INSERT INTO commands (id, target, action, params, status, created_at) VALUES (?, ?, ?, ?, 'pending', ?)", + (cmd_id, session_name, "create_worker", _json.dumps({"host": host, "path": path, "session_name": session_name, "slack_channel": channel}), time.time()), + ) + db.commit() + db.close() + + slack_client.chat_postMessage( + channel=channel, + text=f"🚀 Creating new worker `{session_name}` on {host}:{path}\nCommand ID: `{cmd_id}`, please wait...", + ) + print(f"[Slack] Creating new worker {session_name} at {host}:{path} for channel {channel}") + + # 启动 Socket Mode + socket_client = SocketModeClient( + app_token=SLACK_APP_TOKEN, + web_client=slack_client, + ) + socket_client.socket_mode_request_listeners.append(_handle_event) + + def _run(): + try: + socket_client.connect() + print("[Slack] Socket Mode connected") + # 保持线程活着 + import select + while True: + select.select([], [], [], 60) + except Exception as e: + print(f"[Slack] Socket Mode error: {e}") + + t = threading.Thread(target=_run, daemon=True) + t.start() + global _slack_handler + _slack_handler = slack_client + print("[Slack] Listener started in background thread") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + init_db() + task = asyncio.create_task(poller.start()) + hb_task = asyncio.create_task(_heartbeat_checker()) + sched_task = asyncio.create_task(_schedule_checker()) + _start_slack_listener() + yield + poller.stop() + task.cancel() + hb_task.cancel() + sched_task.cancel() + + +# === FastAPI App === + +app = FastAPI(title="Claude Bridge Broker", lifespan=lifespan) + + +# --- Pydantic Models --- + +class TaskCreate(BaseModel): + content: str + target: str = "" + type: str = "task" + dispatcher_id: str = "" # 哪个 dispatcher 创建的 + + +class TaskResult(BaseModel): + result: str + + +class TelegramSend(BaseModel): + message: str + + +# --- Message Endpoints (调度端用) --- + +@app.get("/messages/new", dependencies=[Depends(verify_token)]) +def get_new_messages(): + """获取未处理的用户消息""" + db = get_db() + rows = db.execute( + "SELECT id, chat_id, text, timestamp FROM messages WHERE processed = 0 ORDER BY id" + ).fetchall() + # 标记为已处理 + if rows: + ids = [r["id"] for r in rows] + placeholders = ",".join("?" * len(ids)) + db.execute(f"UPDATE messages SET processed = 1 WHERE id IN ({placeholders})", ids) + db.commit() + db.close() + return { + "messages": [ + {"id": r["id"], "chat_id": r["chat_id"], "text": r["text"], "timestamp": r["timestamp"]} + for r in rows + ] + } + + +# --- Task Endpoints --- + +@app.post("/tasks", dependencies=[Depends(verify_token)]) +def create_task(body: TaskCreate): + """创建任务(调度端 → 实验室)""" + task_id = uuid.uuid4().hex[:8] + now = time.time() + db = get_db() + db.execute( + "INSERT INTO tasks (id, target, type, content, status, dispatcher_id, created_at, updated_at) VALUES (?, ?, ?, ?, 'pending', ?, ?, ?)", + (task_id, body.target, body.type, body.content, body.dispatcher_id, now, now), + ) + db.commit() + db.close() + return {"id": task_id, "target": body.target, "type": body.type, "content": body.content, "status": "pending"} + + +@app.get("/tasks/pending", dependencies=[Depends(verify_token)]) +def get_pending_tasks(target: str = ""): + """获取待执行的任务。target 过滤目标 session,只返回匹配的或无指定目标的任务""" + db = get_db() + if target: + rows = db.execute( + "SELECT id, target, type, content, created_at FROM tasks WHERE status = 'pending' AND (target = ? OR target = '') ORDER BY created_at", + (target,), + ).fetchall() + else: + rows = db.execute( + "SELECT id, target, type, content, created_at FROM tasks WHERE status = 'pending' ORDER BY created_at" + ).fetchall() + db.close() + return {"tasks": [dict(r) for r in rows]} + + +@app.post("/tasks/{task_id}/claim", dependencies=[Depends(verify_token)]) +def claim_task(task_id: str): + """领取任务(实验室端)""" + db = get_db() + cur = db.execute( + "UPDATE tasks SET status = 'running', updated_at = ? WHERE id = ? AND status = 'pending'", + (time.time(), task_id), + ) + db.commit() + if cur.rowcount == 0: + db.close() + raise HTTPException(404, "Task not found or already claimed") + # 更新 worker→dispatcher 映射 + row = db.execute("SELECT target, dispatcher_id FROM tasks WHERE id = ?", (task_id,)).fetchone() + if row and row["dispatcher_id"]: + _worker_dispatcher_map[row["target"]] = row["dispatcher_id"] + db.close() + return {"ok": True, "task_id": task_id} + + +@app.post("/tasks/{task_id}/result", dependencies=[Depends(verify_token)]) +def submit_result(task_id: str, body: TaskResult): + """提交任务结果(实验室端)""" + db = get_db() + cur = db.execute( + "UPDATE tasks SET status = 'done', result = ?, updated_at = ? WHERE id = ? AND status = 'running'", + (body.result, time.time(), task_id), + ) + db.commit() + if cur.rowcount == 0: + db.close() + raise HTTPException(404, "Task not found or not running") + db.close() + # 如果这个 task 关联了 Slack channel 且是 task 类型,发结果到 channel + # message 类型不发(worker 已经通过 reply_to_slack 直接回复了) + if _slack_handler: + channel = _slack_task_channels.pop(task_id, None) + # 查 task 类型 + db2 = get_db() + task_row = db2.execute("SELECT type FROM tasks WHERE id = ?", (task_id,)).fetchone() + task_type = task_row[0] if task_row else "task" + db2.close() + if channel and task_type == "task": + result_preview = body.result[:3000] if body.result else "(无结果)" + try: + _slack_handler.chat_postMessage( + channel=channel, + text=f"✅ Task `{task_id}` completed:\n\n{result_preview}", + ) + print(f"[Slack] Result for {task_id} sent to channel {channel}") + except Exception as e: + print(f"[Slack] Failed to send result to channel: {e}") + # 路由到正确的 dispatcher + db3 = get_db() + task_info = db3.execute("SELECT dispatcher_id FROM tasks WHERE id = ?", (task_id,)).fetchone() + db3.close() + did = task_info["dispatcher_id"] if task_info else "" + _notify_dispatcher(f"Task {task_id} completed. Use check_task_status to view result and report to user.", dispatcher_id=did) + return {"ok": True} + + +@app.post("/tasks/{task_id}/fail", dependencies=[Depends(verify_token)]) +def fail_task(task_id: str, body: TaskResult): + """标记任务失败(实验室端)""" + db = get_db() + cur = db.execute( + "UPDATE tasks SET status = 'failed', result = ?, updated_at = ? WHERE id = ?", + (body.result, time.time(), task_id), + ) + db.commit() + db.close() + db2 = get_db() + task_info = db2.execute("SELECT dispatcher_id FROM tasks WHERE id = ?", (task_id,)).fetchone() + db2.close() + did = task_info["dispatcher_id"] if task_info else "" + _notify_dispatcher(f"Task {task_id} failed. Use check_task_status to view reason and notify user.", dispatcher_id=did) + return {"ok": True} + + +@app.get("/tasks/{task_id}", dependencies=[Depends(verify_token)]) +def get_task(task_id: str): + """查询任务状态""" + db = get_db() + row = db.execute("SELECT * FROM tasks WHERE id = ?", (task_id,)).fetchone() + db.close() + if not row: + raise HTTPException(404, "Task not found") + return dict(row) + + +@app.get("/tasks", dependencies=[Depends(verify_token)]) +def list_tasks(status: str = None, limit: int = 20): + """列出任务""" + db = get_db() + if status: + rows = db.execute( + "SELECT * FROM tasks WHERE status = ? ORDER BY created_at DESC LIMIT ?", + (status, limit), + ).fetchall() + else: + rows = db.execute( + "SELECT * FROM tasks ORDER BY created_at DESC LIMIT ?", (limit,) + ).fetchall() + db.close() + return {"tasks": [dict(r) for r in rows]} + + +# --- Workers list --- + +@app.get("/workers", dependencies=[Depends(verify_token)]) +def list_workers_api(): + """List all active workers with channel bindings""" + db = get_db() + rows = db.execute(""" + SELECT w.*, cb.channel_name + FROM workers w + LEFT JOIN channel_bindings cb ON cb.worker_id = w.id + WHERE w.status = 'active' + ORDER BY w.created_at + """).fetchall() + db.close() + return {"workers": [dict(r) for r in rows]} + + +# --- Telegram Send (调度端用) --- + +@app.post("/telegram/send", dependencies=[Depends(verify_token)]) +async def send_telegram(body: TelegramSend): + """发送 Telegram 消息给用户""" + async with httpx.AsyncClient() as client: + resp = await client.post( + f"https://api.telegram.org/bot{TELEGRAM_TOKEN}/sendMessage", + json={ + "chat_id": TELEGRAM_CHAT_ID, + "text": body.message, + # 不用 Markdown,避免特殊字符解析失败 + }, + ) + if resp.status_code != 200: + raise HTTPException(502, f"Telegram API error: {resp.text}") + return {"ok": True} + + +# --- File Transfer --- + +@app.post("/files/upload", dependencies=[Depends(verify_token)]) +async def upload_file(file: UploadFile = File(...), filename: str = Form("")): + """上传文件到 broker 文件存储""" + name = filename or file.filename or f"upload_{int(time.time())}" + safe_name = f"{uuid.uuid4().hex[:8]}_{name}" + save_path = os.path.join(FILES_DIR, safe_name) + content = await file.read() + with open(save_path, "wb") as f: + f.write(content) + return {"filename": safe_name, "size": len(content)} + + +@app.get("/files/{filename}") +async def download_file(filename: str, request: Request): + """下载文件(需认证)""" + await verify_token(request) + file_path = os.path.join(FILES_DIR, filename) + if not os.path.exists(file_path): + raise HTTPException(404, "File not found") + return FileResponse(file_path, filename=filename) + + +@app.get("/files", dependencies=[Depends(verify_token)]) +def list_files(): + """列出所有文件""" + files = [] + for f in os.listdir(FILES_DIR): + fp = os.path.join(FILES_DIR, f) + files.append({"name": f, "size": os.path.getsize(fp)}) + return {"files": sorted(files, key=lambda x: x["name"])} + + +@app.post("/telegram/send_file", dependencies=[Depends(verify_token)]) +async def send_telegram_file(filename: str = Form(...), caption: str = Form("")): + """通过 Telegram 给用户发送文件""" + file_path = os.path.join(FILES_DIR, filename) + if not os.path.exists(file_path): + raise HTTPException(404, "File not found") + async with httpx.AsyncClient(timeout=60) as client: + with open(file_path, "rb") as f: + resp = await client.post( + f"https://api.telegram.org/bot{TELEGRAM_TOKEN}/sendDocument", + data={"chat_id": TELEGRAM_CHAT_ID, "caption": caption or ""}, + files={"document": (filename, f)}, + ) + if resp.status_code != 200: + raise HTTPException(502, f"Telegram API error: {resp.text}") + return {"ok": True} + + +# --- Slack Send --- + +class SlackSend(BaseModel): + channel: str # channel ID + message: str + thread_ts: str = "" # 可选,回复到 thread + + +@app.post("/slack/send", dependencies=[Depends(verify_token)]) +def send_slack_message(body: SlackSend): + """发送 Slack 消息""" + if not _slack_handler: + raise HTTPException(503, "Slack not configured") + kwargs = {"channel": body.channel, "text": body.message} + if body.thread_ts: + kwargs["thread_ts"] = body.thread_ts + try: + _slack_handler.chat_postMessage(**kwargs) + except Exception as e: + raise HTTPException(502, f"Slack API error: {e}") + return {"ok": True} + + +# --- Log / Reply (实验室 → dispatcher) --- + +class LogEntry(BaseModel): + source: str # claude / claude2 / claude3 / system + message: str + + +@app.post("/log", dependencies=[Depends(verify_token)]) +def post_log(body: LogEntry): + """实验室 Claude 回传消息给 dispatcher""" + # 从 worker→dispatcher 映射推断目标 + did = _worker_dispatcher_map.get(body.source, "") + _notify_dispatcher(f"[{body.source}] {body.message}", dispatcher_id=did) + return {"ok": True} + + +class SlackHistory(BaseModel): + session: str # worker session name, 用来查绑定的 channel + limit: int = 20 + + +@app.post("/slack/history", dependencies=[Depends(verify_token)]) +def get_slack_history(body: SlackHistory): + """获取 worker 绑定的 Slack channel 的消息历史""" + if not _slack_handler: + raise HTTPException(503, "Slack not configured") + db = get_db() + row = db.execute( + "SELECT cb.channel_id FROM channel_bindings cb JOIN workers w ON cb.worker_id = w.id WHERE w.session_name = ?", + (body.session,), + ).fetchone() + db.close() + if not row: + raise HTTPException(404, f"No Slack channel bound to session {body.session}") + channel = row[0] + try: + resp = _slack_handler.conversations_history(channel=channel, limit=body.limit) + messages = [] + for msg in resp.get("messages", []): + user = msg.get("user", "bot") + text = msg.get("text", "") + ts = msg.get("ts", "") + messages.append({"user": user, "text": text, "ts": ts}) + # 反转让旧消息在前 + messages.reverse() + return {"messages": messages} + except Exception as e: + raise HTTPException(502, f"Slack API error: {e}") + + +class SlackReply(BaseModel): + session: str # worker session name + message: str + + +@app.post("/slack/reply", dependencies=[Depends(verify_token)]) +def slack_reply(body: SlackReply): + """Worker 回复到关联的 Slack channel(通过 session→channel_binding 查找)""" + if not _slack_handler: + raise HTTPException(503, "Slack not configured") + # 从 worker session 找到绑定的 channel + db = get_db() + row = db.execute( + "SELECT cb.channel_id FROM channel_bindings cb JOIN workers w ON cb.worker_id = w.id WHERE w.session_name = ?", + (body.session,), + ).fetchone() + db.close() + if not row: + raise HTTPException(404, f"No Slack channel bound to session {body.session}") + channel = row[0] + try: + _slack_handler.chat_postMessage(channel=channel, text=body.message) + except Exception as e: + raise HTTPException(502, f"Slack API error: {e}") + return {"ok": True} + + +# --- GPT-Pro Expert --- + +OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") +expert_requests: dict = {} # id -> {status, question, answer} + + +class ExpertQuestion(BaseModel): + question: str + + +@app.post("/expert/ask", dependencies=[Depends(verify_token)]) +def ask_expert(body: ExpertQuestion): + """提交问题给 GPT-Pro,后台异步处理""" + req_id = uuid.uuid4().hex[:8] + expert_requests[req_id] = { + "status": "thinking", + "question": body.question, + "answer": "", + } + # 后台线程调 OpenAI API + import threading + threading.Thread(target=_call_gpt_pro, args=(req_id, body.question), daemon=True).start() + return {"id": req_id, "status": "thinking"} + + +@app.get("/expert/{req_id}", dependencies=[Depends(verify_token)]) +def get_expert(req_id: str): + if req_id not in expert_requests: + raise HTTPException(404, "Request not found") + return expert_requests[req_id] + + +def _call_gpt_pro(req_id: str, question: str): + """同步调用 OpenAI GPT-Pro(在后台线程中运行)""" + import json + import urllib.request + + try: + req = urllib.request.Request( + "https://api.openai.com/v1/responses", + data=json.dumps({ + "model": "o3-pro", + "input": question, + }).encode(), + headers={ + "Authorization": f"Bearer {OPENAI_API_KEY}", + "Content-Type": "application/json", + }, + ) + with urllib.request.urlopen(req, timeout=600) as resp: + data = json.loads(resp.read()) + + # 提取回复文本 + answer = "" + for item in data.get("output", []): + if item.get("type") == "message": + for content in item.get("content", []): + if content.get("type") == "output_text": + answer += content.get("text", "") + + expert_requests[req_id]["status"] = "done" + expert_requests[req_id]["answer"] = answer or "(empty response)" + + except Exception as e: + expert_requests[req_id]["status"] = "error" + expert_requests[req_id]["answer"] = str(e) + + # 通知 dispatcher + status = expert_requests[req_id]["status"] + _notify_dispatcher( + f"[system] GPT-Pro reply ready (ID: {req_id}, status: {status}). Use get_expert_answer to view and forward to user." + ) + + +# --- Commands (lab server执行的系统命令) --- + +class CommandCreate(BaseModel): + target: str # tmux session name + action: str # switch_project, restart, etc. + params: dict = {} + + +@app.post("/commands", dependencies=[Depends(verify_token)]) +def create_command(body: CommandCreate): + """创建一个命令让 lab cron 执行""" + cmd_id = uuid.uuid4().hex[:8] + db = get_db() + import json as _json + db.execute( + "INSERT INTO commands (id, target, action, params, status, created_at) VALUES (?, ?, ?, ?, 'pending', ?)", + (cmd_id, body.target, body.action, _json.dumps(body.params), time.time()), + ) + db.commit() + db.close() + return {"id": cmd_id, "target": body.target, "action": body.action} + + +@app.get("/commands/pending", dependencies=[Depends(verify_token)]) +def get_pending_commands(): + """lab cron 轮询待执行命令""" + db = get_db() + rows = db.execute( + "SELECT id, target, action, params FROM commands WHERE status = 'pending' ORDER BY created_at" + ).fetchall() + db.close() + return {"commands": [dict(r) for r in rows]} + + +@app.post("/commands/{cmd_id}/done", dependencies=[Depends(verify_token)]) +def complete_command(cmd_id: str, body: TaskResult): + """标记命令完成""" + db = get_db() + db.execute( + "UPDATE commands SET status = 'done', result = ? WHERE id = ?", + (body.result, cmd_id), + ) + # 查命令的 target session,找到绑定的 Slack channel 回报 + row = db.execute("SELECT target FROM commands WHERE id = ?", (cmd_id,)).fetchone() + target_session = row[0] if row else "" + slack_channel = None + if target_session and _slack_handler: + ch_row = db.execute( + "SELECT cb.channel_id FROM channel_bindings cb JOIN workers w ON cb.worker_id = w.id WHERE w.session_name = ?", + (target_session,), + ).fetchone() + if ch_row: + slack_channel = ch_row[0] + db.commit() + db.close() + + result = body.result + is_ok = result.startswith("OK") + if slack_channel: + emoji = "✅" if is_ok else "❌" + try: + _slack_handler.chat_postMessage(channel=slack_channel, text=f"{emoji} {result}") + except Exception: + pass + _notify_dispatcher(f"[system] Command {cmd_id} completed: {result}") + return {"ok": True} + + +@app.get("/commands/{cmd_id}", dependencies=[Depends(verify_token)]) +def get_command(cmd_id: str): + """查询命令状态""" + db = get_db() + row = db.execute("SELECT * FROM commands WHERE id = ?", (cmd_id,)).fetchone() + db.close() + if not row: + raise HTTPException(404, "Command not found") + return dict(row) + + +# --- Pending Context (for UserPromptSubmit hook) --- + +@app.get("/context/pending", dependencies=[Depends(verify_token)]) +def get_pending_context(dispatcher_id: str = ""): + """Hook 调用:拉取待注入的 context,拉取后自动删除""" + db = get_db() + if dispatcher_id: + rows = db.execute("SELECT id, text FROM pending_context WHERE dispatcher_id = ? OR dispatcher_id = '' ORDER BY id", (dispatcher_id,)).fetchall() + else: + rows = db.execute("SELECT id, text FROM pending_context ORDER BY id").fetchall() + if rows: + ids = [r["id"] for r in rows] + placeholders = ",".join("?" * len(ids)) + db.execute(f"DELETE FROM pending_context WHERE id IN ({placeholders})", ids) + db.commit() + db.close() + texts = [r["text"] for r in rows] + return {"messages": texts} + + +# --- Schedules (定时任务) --- + +class ScheduleCreate(BaseModel): + action: str + trigger_at: float = 0 + delay_seconds: float = 0 + repeat_seconds: float = 0 + dispatcher_id: str = "" + + +@app.post("/schedules", dependencies=[Depends(verify_token)]) +def create_schedule(body: ScheduleCreate): + """创建定时任务""" + schedule_id = uuid.uuid4().hex[:8] + now = time.time() + if body.trigger_at > 0: + trigger = body.trigger_at + elif body.delay_seconds > 0: + trigger = now + body.delay_seconds + else: + trigger = now + db = get_db() + db.execute( + "INSERT INTO schedules (id, action, trigger_at, repeat_seconds, dispatcher_id, status, created_at) VALUES (?, ?, ?, ?, ?, 'active', ?)", + (schedule_id, body.action, trigger, body.repeat_seconds, body.dispatcher_id, now), + ) + db.commit() + db.close() + import datetime + trigger_str = datetime.datetime.fromtimestamp(trigger).strftime("%Y-%m-%d %H:%M:%S") + return {"id": schedule_id, "trigger_at": trigger_str, "repeat_seconds": body.repeat_seconds} + + +@app.get("/schedules", dependencies=[Depends(verify_token)]) +def list_schedules(): + """列出所有定时任务""" + db = get_db() + rows = db.execute("SELECT * FROM schedules WHERE status = 'active' ORDER BY trigger_at").fetchall() + db.close() + return {"schedules": [dict(r) for r in rows]} + + +@app.delete("/schedules/{schedule_id}", dependencies=[Depends(verify_token)]) +def cancel_schedule(schedule_id: str): + """取消定时任务""" + db = get_db() + db.execute("UPDATE schedules SET status = 'cancelled' WHERE id = ?", (schedule_id,)) + db.commit() + db.close() + return {"ok": True} + + +async def _schedule_checker(): + """后台协程:每 30 秒检查定时任务,到时间就触发""" + while True: + await asyncio.sleep(30) + now = time.time() + db = get_db() + rows = db.execute( + "SELECT id, action, trigger_at, repeat_seconds, dispatcher_id FROM schedules WHERE status = 'active' AND trigger_at <= ?", + (now,), + ).fetchall() + for r in rows: + sid, action, trigger_at, repeat = r["id"], r["action"], r["trigger_at"], r["repeat_seconds"] + did = r["dispatcher_id"] if "dispatcher_id" in r.keys() else "" + _notify_dispatcher(f"[scheduled {sid}] {action}", dispatcher_id=did) + if repeat > 0: + # 循环任务:更新下次触发时间 + next_trigger = trigger_at + repeat + # 如果落后太多,跳到下一个未来时间点 + while next_trigger <= now: + next_trigger += repeat + db.execute("UPDATE schedules SET trigger_at = ? WHERE id = ?", (next_trigger, sid)) + else: + # 一次性任务:标记完成 + db.execute("UPDATE schedules SET status = 'done' WHERE id = ?", (sid,)) + db.commit() + db.close() + + +# --- Heartbeat --- + +heartbeats: dict = {} # session -> {host, last_seen, alerted} + + +class Heartbeat(BaseModel): + session: str + host: str = "" + + +@app.post("/heartbeat", dependencies=[Depends(verify_token)]) +def post_heartbeat(body: Heartbeat): + """Worker 心跳上报""" + heartbeats[body.session] = { + "host": body.host, + "last_seen": time.time(), + "alerted": False, + } + return {"ok": True} + + +@app.get("/heartbeat", dependencies=[Depends(verify_token)]) +def get_heartbeats(): + """查看所有 worker 心跳状态""" + now = time.time() + result = {} + for session, info in heartbeats.items(): + age = now - info["last_seen"] + result[session] = { + "host": info["host"], + "last_seen": info["last_seen"], + "age_seconds": int(age), + "alive": age < HEARTBEAT_TIMEOUT, + } + return result + + +async def _heartbeat_checker(): + """后台协程:定期检查心跳,超时则 Telegram 通知用户""" + while True: + await asyncio.sleep(60) # 每分钟检查一次 + now = time.time() + for session, info in heartbeats.items(): + age = now - info["last_seen"] + if age > HEARTBEAT_TIMEOUT and not info.get("alerted"): + # 超时!通知用户 + info["alerted"] = True + host = info.get("host", "unknown") + msg = f"Worker {session} ({host}) heartbeat timeout ({int(age)}s). May be offline. Please check and restart." + try: + async with httpx.AsyncClient() as client: + await client.post( + f"https://api.telegram.org/bot{TELEGRAM_TOKEN}/sendMessage", + json={"chat_id": TELEGRAM_CHAT_ID, "text": msg}, + ) + print(f"[Heartbeat] ALERT: {session} timeout, notified user") + except Exception as e: + print(f"[Heartbeat] Failed to send alert: {e}") + # 也通知 dispatcher + _notify_dispatcher(f"[system] Worker {session} ({host}) heartbeat timeout, may be offline") + + +# --- Health Check --- + +@app.get("/health") +def health(): + return {"status": "ok", "time": time.time()} |
