summaryrefslogtreecommitdiff
path: root/broker.py
diff options
context:
space:
mode:
Diffstat (limited to 'broker.py')
-rw-r--r--broker.py1580
1 files changed, 1580 insertions, 0 deletions
diff --git a/broker.py b/broker.py
new file mode 100644
index 0000000..ee073fb
--- /dev/null
+++ b/broker.py
@@ -0,0 +1,1580 @@
+"""
+Claude Bridge - Broker Server
+消息中转 + Telegram Bot + 任务队列
+部署在私人服务器上,24h 运行
+"""
+
+import asyncio
+import hashlib
+import hmac
+import os
+import shlex
+import sqlite3
+import subprocess
+import time
+import uuid
+from contextlib import asynccontextmanager
+
+import httpx
+from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile, File, Form
+from fastapi.responses import JSONResponse, FileResponse
+from pydantic import BaseModel
+
+# === Config ===
+
+TELEGRAM_TOKEN = os.environ["TELEGRAM_TOKEN"]
+TELEGRAM_CHAT_ID = os.environ["TELEGRAM_CHAT_ID"] # 你自己的 chat id
+API_SECRET = os.environ["API_SECRET"] # broker 认证密钥
+SLACK_BOT_TOKEN = os.environ.get("SLACK_BOT_TOKEN", "")
+SLACK_APP_TOKEN = os.environ.get("SLACK_APP_TOKEN", "")
+SLACK_ALLOWED_USERS = os.environ.get("SLACK_ALLOWED_USERS", "").split(",") # 逗号分隔的 user ID 白名单,空=全部放行
+DEFAULT_DISPATCHER = "dispatcher" # 默认 dispatcher session(owner 的)
+DB_PATH = os.environ.get("DB_PATH", "bridge.db")
+POLL_INTERVAL = int(os.environ.get("POLL_INTERVAL", "2")) # Telegram 轮询间隔(秒)
+DISPATCHER_TMUX = os.environ.get("DISPATCHER_TMUX", "dispatcher") # dispatcher 的 tmux session 名
+HEARTBEAT_TIMEOUT = int(os.environ.get("HEARTBEAT_TIMEOUT", "600")) # 心跳超时秒数,默认 10 分钟
+FILES_DIR = os.environ.get("FILES_DIR", os.path.join(os.path.dirname(os.path.abspath(__file__)), "files"))
+os.makedirs(FILES_DIR, exist_ok=True)
+
+
+# === Database ===
+
+def get_db() -> sqlite3.Connection:
+ conn = sqlite3.connect(DB_PATH)
+ conn.row_factory = sqlite3.Row
+ conn.execute("PRAGMA journal_mode=WAL")
+ return conn
+
+
+def init_db():
+ db = get_db()
+ db.executescript("""
+ CREATE TABLE IF NOT EXISTS messages (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ chat_id TEXT NOT NULL,
+ text TEXT NOT NULL,
+ timestamp REAL NOT NULL,
+ processed INTEGER DEFAULT 0
+ );
+ CREATE TABLE IF NOT EXISTS tasks (
+ id TEXT PRIMARY KEY,
+ target TEXT NOT NULL DEFAULT '',
+ type TEXT NOT NULL DEFAULT 'task',
+ content TEXT NOT NULL,
+ status TEXT DEFAULT 'pending',
+ result TEXT DEFAULT '',
+ created_at REAL NOT NULL,
+ updated_at REAL NOT NULL
+ );
+ """)
+ db.execute("""
+ CREATE TABLE IF NOT EXISTS workers (
+ id TEXT PRIMARY KEY,
+ session_name TEXT NOT NULL,
+ host TEXT NOT NULL,
+ path TEXT NOT NULL,
+ status TEXT DEFAULT 'active',
+ created_at REAL NOT NULL
+ );
+ """)
+ db.execute("""
+ CREATE TABLE IF NOT EXISTS channel_bindings (
+ channel_id TEXT PRIMARY KEY,
+ worker_id TEXT NOT NULL,
+ channel_name TEXT DEFAULT '',
+ created_by TEXT DEFAULT '',
+ created_at REAL NOT NULL
+ );
+ """)
+ db.execute("""
+ CREATE TABLE IF NOT EXISTS schedules (
+ id TEXT PRIMARY KEY,
+ action TEXT NOT NULL,
+ trigger_at REAL NOT NULL,
+ repeat_seconds REAL DEFAULT 0,
+ status TEXT DEFAULT 'active',
+ created_at REAL NOT NULL
+ );
+ """)
+ db.execute("""
+ CREATE TABLE IF NOT EXISTS commands (
+ id TEXT PRIMARY KEY,
+ target TEXT NOT NULL,
+ action TEXT NOT NULL,
+ params TEXT NOT NULL DEFAULT '{}',
+ status TEXT DEFAULT 'pending',
+ result TEXT DEFAULT '',
+ created_at REAL NOT NULL
+ );
+ """)
+ db.execute("""
+ CREATE TABLE IF NOT EXISTS pending_context (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ text TEXT NOT NULL,
+ timestamp REAL NOT NULL
+ );
+ """)
+ db.execute("""
+ CREATE TABLE IF NOT EXISTS users (
+ id TEXT PRIMARY KEY,
+ slack_user_id TEXT DEFAULT '',
+ telegram_chat_id TEXT DEFAULT '',
+ dispatcher_session TEXT NOT NULL,
+ display_name TEXT DEFAULT '',
+ status TEXT DEFAULT 'active',
+ created_at REAL NOT NULL
+ );
+ """)
+ # 兼容旧表:加新列
+ migrations = [
+ ("tasks", "target", "''"),
+ ("tasks", "type", "'task'"),
+ ("tasks", "dispatcher_id", "''"),
+ ("schedules", "dispatcher_id", "''"),
+ ("messages", "dispatcher_id", "''"),
+ ("pending_context", "dispatcher_id", "''"),
+ ("commands", "dispatcher_id", "''"),
+ ("channel_bindings", "created_by", "''"),
+ ]
+ for table, col, default in migrations:
+ try:
+ db.execute(f"ALTER TABLE {table} ADD COLUMN {col} TEXT NOT NULL DEFAULT {default}")
+ except Exception:
+ pass
+ # 创建默认用户(迁移:把现有 dispatcher 绑定给 owner)
+ existing = db.execute("SELECT id FROM users WHERE id = 'owner'").fetchone()
+ if not existing:
+ db.execute(
+ "INSERT INTO users (id, slack_user_id, telegram_chat_id, dispatcher_session, display_name, status, created_at) VALUES (?, ?, ?, ?, ?, 'active', ?)",
+ ("owner", os.environ.get("SLACK_OWNER_ID", ""), TELEGRAM_CHAT_ID, "dispatcher", "Owner", time.time()),
+ )
+ db.commit()
+ db.close()
+
+
+# === User Helpers ===
+
+def _get_user(user_id: str) -> dict | None:
+ db = get_db()
+ row = db.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone()
+ db.close()
+ return dict(row) if row else None
+
+
+def _get_user_by_slack(slack_user_id: str) -> dict | None:
+ db = get_db()
+ row = db.execute("SELECT * FROM users WHERE slack_user_id = ?", (slack_user_id,)).fetchone()
+ db.close()
+ return dict(row) if row else None
+
+
+def _get_user_by_telegram(telegram_chat_id: str) -> dict | None:
+ db = get_db()
+ row = db.execute("SELECT * FROM users WHERE telegram_chat_id = ?", (telegram_chat_id,)).fetchone()
+ db.close()
+ return dict(row) if row else None
+
+
+def _get_dispatcher_session(dispatcher_id: str) -> str:
+ """Look up tmux session name for a dispatcher_id"""
+ if not dispatcher_id:
+ return DEFAULT_DISPATCHER
+ user = _get_user(dispatcher_id)
+ return user["dispatcher_session"] if user else DEFAULT_DISPATCHER
+
+
+# worker session → last dispatcher_id mapping (for reply_to_dispatcher routing)
+_worker_dispatcher_map: dict = {}
+
+
+# === Auth ===
+
+async def verify_token(request: Request):
+ auth = request.headers.get("Authorization", "")
+ if not auth.startswith("Bearer "):
+ raise HTTPException(401, "Missing token")
+ token = auth[7:]
+ if not hmac.compare_digest(token, API_SECRET):
+ raise HTTPException(403, "Invalid token")
+
+
+# === Notify Dispatcher ===
+
+_last_long_notify = 0 # 防抖:上次发长消息通知的时间
+
+
+def _notify_dispatcher(text: str, dispatcher_id: str = ""):
+ """推送消息到指定 dispatcher。短消息直接 send-keys,长消息写文件。"""
+ global _last_long_notify
+ session = _get_dispatcher_session(dispatcher_id)
+ try:
+ safe_text = text.replace("\n", " ").replace("\r", "")
+ if len(safe_text) <= 400:
+ subprocess.run(
+ ["tmux", "send-keys", "-t", session, safe_text, "Enter"],
+ timeout=5, capture_output=True,
+ )
+ else:
+ msg_id = uuid.uuid4().hex[:8]
+ msg_path = os.path.join(FILES_DIR, f"msg_{msg_id}.txt")
+ with open(msg_path, "w") as f:
+ f.write(text)
+ preview = safe_text[:80]
+ subprocess.run(
+ ["tmux", "send-keys", "-t", session,
+ f"[Long message saved to {msg_path} — Read it first then reply] Preview: {preview}...", "Enter"],
+ timeout=5, capture_output=True,
+ )
+ print(f"[Broker] Pushed to dispatcher: {text[:60]}...")
+ except Exception as e:
+ print(f"[Broker] Failed to notify dispatcher: {e}")
+
+
+# === Telegram Polling ===
+
+class TelegramPoller:
+ def __init__(self):
+ self.offset = 0
+ self.running = False
+
+ async def start(self):
+ self.running = True
+ async with httpx.AsyncClient(timeout=60) as client:
+ while self.running:
+ try:
+ resp = await client.get(
+ f"https://api.telegram.org/bot{TELEGRAM_TOKEN}/getUpdates",
+ params={"offset": self.offset, "timeout": 30},
+ )
+ data = resp.json()
+ if data.get("ok"):
+ for update in data.get("result", []):
+ self.offset = update["update_id"] + 1
+ await self._handle_update(update)
+ except Exception as e:
+ print(f"[Telegram Poll Error] {e}")
+ await asyncio.sleep(POLL_INTERVAL)
+
+ async def _handle_update(self, update: dict):
+ msg = update.get("message", {})
+ chat_id = str(msg.get("chat", {}).get("id", ""))
+
+ if not chat_id:
+ return
+
+ # 只接收指定 chat_id 的消息(安全)
+ if chat_id != TELEGRAM_CHAT_ID:
+ print(f"[Telegram] Ignored message from unknown chat: {chat_id}")
+ return
+
+ # 处理语音消息
+ voice = msg.get("voice")
+ if voice:
+ text = await self._transcribe_voice(voice["file_id"])
+ if not text:
+ return
+ text = f"[voice] {text}"
+ # 处理文件/文档
+ elif msg.get("document"):
+ file_id = msg["document"]["file_id"]
+ file_name = msg["document"].get("file_name", f"file_{file_id}")
+ saved_name = await self._download_telegram_file(file_id, file_name)
+ if saved_name:
+ caption = msg.get("caption", "")
+ text = f"[file] {saved_name}" + (f" caption: {caption}" if caption else "")
+ else:
+ return
+ # 处理图片
+ elif msg.get("photo"):
+ # Telegram 发多个尺寸,取最大的
+ photo = msg["photo"][-1]
+ file_id = photo["file_id"]
+ saved_name = await self._download_telegram_file(file_id, f"photo_{int(time.time())}.jpg")
+ if saved_name:
+ caption = msg.get("caption", "")
+ text = f"[image] {saved_name}" + (f" caption: {caption}" if caption else "")
+ else:
+ return
+ else:
+ text = msg.get("text", "")
+
+ if not text:
+ return
+
+ db = get_db()
+ db.execute(
+ "INSERT INTO messages (chat_id, text, timestamp) VALUES (?, ?, ?)",
+ (chat_id, text, time.time()),
+ )
+ db.commit()
+ db.close()
+ print(f"[Telegram] Stored message: {text[:50]}...")
+
+ # 推送到 dispatcher
+ # Telegram 只给 owner 的 dispatcher
+ user = _get_user_by_telegram(chat_id)
+ did = user["id"] if user else "owner"
+ _notify_dispatcher(f"[from_telegram reply with send_telegram_message] {text}", dispatcher_id=did)
+
+ async def _download_telegram_file(self, file_id: str, file_name: str) -> str:
+ """从 Telegram 下载文件,存到 FILES_DIR,返回保存的文件名"""
+ try:
+ async with httpx.AsyncClient(timeout=60) as client:
+ resp = await client.get(
+ f"https://api.telegram.org/bot{TELEGRAM_TOKEN}/getFile",
+ params={"file_id": file_id},
+ )
+ file_path = resp.json()["result"]["file_path"]
+ resp = await client.get(
+ f"https://api.telegram.org/file/bot{TELEGRAM_TOKEN}/{file_path}"
+ )
+ # 防重名
+ safe_name = f"{uuid.uuid4().hex[:8]}_{file_name}"
+ save_path = os.path.join(FILES_DIR, safe_name)
+ with open(save_path, "wb") as f:
+ f.write(resp.content)
+ print(f"[Telegram] File saved: {safe_name} ({len(resp.content)} bytes)")
+ return safe_name
+ except Exception as e:
+ print(f"[Telegram] File download failed: {e}")
+ return ""
+
+ async def _transcribe_voice(self, file_id: str) -> str:
+ """下载 Telegram 语音文件并用 whisper 转文字"""
+ try:
+ async with httpx.AsyncClient(timeout=30) as client:
+ # 获取文件路径
+ resp = await client.get(
+ f"https://api.telegram.org/bot{TELEGRAM_TOKEN}/getFile",
+ params={"file_id": file_id},
+ )
+ file_path = resp.json()["result"]["file_path"]
+
+ # 下载文件
+ resp = await client.get(
+ f"https://api.telegram.org/file/bot{TELEGRAM_TOKEN}/{file_path}"
+ )
+ ogg_path = f"/tmp/voice_{file_id}.ogg"
+ with open(ogg_path, "wb") as f:
+ f.write(resp.content)
+
+ print(f"[Telegram] Voice downloaded: {ogg_path}")
+
+ # 用 faster-whisper 转文字
+ from faster_whisper import WhisperModel
+ if not hasattr(self, '_whisper_model'):
+ print("[Whisper] Loading model (first time)...")
+ self._whisper_model = WhisperModel("base", device="cpu", compute_type="int8")
+ print("[Whisper] Model loaded.")
+
+ segments, info = self._whisper_model.transcribe(ogg_path, language=None)
+ text = "".join(seg.text for seg in segments).strip()
+
+ # 清理临时文件
+ os.remove(ogg_path)
+
+ print(f"[Whisper] Transcribed ({info.language}): {text[:80]}...")
+ return text
+
+ except Exception as e:
+ print(f"[Whisper] Transcription failed: {e}")
+ return ""
+
+ def stop(self):
+ self.running = False
+
+
+poller = TelegramPoller()
+
+
+# === Slack Socket Mode ===
+
+_slack_handler = None
+_slack_task_channels = {} # task_id → slack channel_id,worker 完成后回发
+
+def _start_slack_listener():
+ """在后台线程启动 Slack Socket Mode listener"""
+ if not SLACK_BOT_TOKEN or not SLACK_APP_TOKEN:
+ print("[Slack] No tokens configured, skipping")
+ return
+
+ import threading
+ from slack_sdk.web import WebClient
+ from slack_sdk.socket_mode import SocketModeClient
+ from slack_sdk.socket_mode.request import SocketModeRequest
+ from slack_sdk.socket_mode.response import SocketModeResponse
+
+ slack_client = WebClient(token=SLACK_BOT_TOKEN)
+
+ # 缓存 bot 自己的 user id
+ try:
+ auth = slack_client.auth_test()
+ bot_user_id = auth["user_id"]
+ print(f"[Slack] Bot user ID: {bot_user_id}")
+ except Exception as e:
+ print(f"[Slack] Auth failed: {e}")
+ return
+
+ # 缓存 user id → display name
+ _user_cache = {}
+
+ def _get_user_name(user_id):
+ if user_id not in _user_cache:
+ try:
+ info = slack_client.users_info(user=user_id)
+ _user_cache[user_id] = info["user"]["profile"].get("display_name") or info["user"]["real_name"] or user_id
+ except Exception:
+ _user_cache[user_id] = user_id
+ return _user_cache[user_id]
+
+ def _get_channel_binding(channel_id):
+ """查询 channel 是否绑定了 worker"""
+ db = get_db()
+ row = db.execute(
+ "SELECT w.session_name, w.host, w.path, w.id FROM channel_bindings cb JOIN workers w ON cb.worker_id = w.id WHERE cb.channel_id = ?",
+ (channel_id,),
+ ).fetchone()
+ db.close()
+ return dict(row) if row else None
+
+ def _find_worker_by_path(host, path):
+ """根据 host:path 查找已有 worker"""
+ db = get_db()
+ row = db.execute(
+ "SELECT * FROM workers WHERE host = ? AND path = ? AND status = 'active'",
+ (host, path),
+ ).fetchone()
+ db.close()
+ return dict(row) if row else None
+
+ def _create_worker(session_name, host, path):
+ """注册新 worker 到 DB"""
+ worker_id = uuid.uuid4().hex[:8]
+ db = get_db()
+ db.execute(
+ "INSERT INTO workers (id, session_name, host, path, status, created_at) VALUES (?, ?, ?, ?, 'active', ?)",
+ (worker_id, session_name, host, path, time.time()),
+ )
+ db.commit()
+ db.close()
+ return worker_id
+
+ def _bind_channel(channel_id, worker_id, channel_name="", created_by=""):
+ """绑定 channel 到 worker"""
+ db = get_db()
+ db.execute(
+ "INSERT OR REPLACE INTO channel_bindings (channel_id, worker_id, channel_name, created_by, created_at) VALUES (?, ?, ?, ?, ?)",
+ (channel_id, worker_id, channel_name, created_by, time.time()),
+ )
+ db.commit()
+ db.close()
+
+ def _handle_event(client: SocketModeClient, req: SocketModeRequest):
+ # 立刻 ack
+ client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id))
+
+ # Debug: log all request types
+ print(f"[Slack] req.type={req.type}")
+
+ # --- Slash Command ---
+ if req.type in ("slash_commands", "interactive"):
+ payload = req.payload
+ command = payload.get("command", "")
+ cmd_text = payload.get("text", "").strip()
+ channel = payload.get("channel_id", "")
+ user_id = payload.get("user_id", "")
+ channel_name = payload.get("channel_name", "")
+
+ print(f"[Slack] Slash command: {command} text={cmd_text} channel={channel}")
+
+ if command == "/help":
+ slack_client.chat_postMessage(channel=channel, text=(
+ "*Claude Bridge Commands*\n\n"
+ "`/register` — Register for a personal dispatcher\n"
+ "`/init-worker host:path` — Create or bind a worker to this channel\n"
+ "`/meta` — Show worker info (status, heartbeat, tasks)\n"
+ "`/restart-worker` — Restart the worker bound to this channel\n"
+ "`/stop-worker` — Stop worker and unbind from channel\n"
+ "`/unbind` — Unbind worker without stopping it\n"
+ "`/help` — Show this message\n\n"
+ "*In channels with a bound worker:*\n"
+ "• `@bot message` — Chat with the worker (instant, stays in context)\n"
+ "• `@bot /task description` — Submit a formal task (queued, tracked)\n\n"
+ "*In DM:*\n"
+ "• Message the bot directly to chat with your personal dispatcher"
+ ))
+ elif command == "/register":
+ _handle_register(channel, user_id)
+ elif command == "/init-worker":
+ _handle_init_worker(channel, channel_name, user_id, cmd_text)
+ elif command == "/meta":
+ _handle_meta(channel)
+ elif command == "/restart-worker":
+ _handle_worker_command(channel, "restart", user_id)
+ elif command == "/stop-worker":
+ _handle_worker_command(channel, "stop", user_id)
+ elif command == "/unbind":
+ _handle_unbind(channel, user_id)
+ return
+
+ if req.type != "events_api":
+ return
+
+ event = req.payload.get("event", {})
+ event_type = event.get("type", "")
+ user_id = event.get("user", "")
+ text = event.get("text", "")
+ channel = event.get("channel", "")
+ channel_type = event.get("channel_type", "")
+
+ # --- Bot 被加到 channel ---
+ if event_type == "member_joined_channel" and user_id == bot_user_id:
+ try:
+ ch_info = slack_client.conversations_info(channel=channel)
+ ch_name = ch_info["channel"].get("name", channel)
+ except Exception:
+ ch_name = channel
+ binding = _get_channel_binding(channel)
+ if binding:
+ slack_client.chat_postMessage(
+ channel=channel,
+ text=f"Bound to worker `{binding['session_name']}` ({binding['host']}:{binding['path']}). @mention me to chat.",
+ )
+ else:
+ slack_client.chat_postMessage(
+ channel=channel,
+ text=f"👋 Hi! No worker bound to this channel yet.\nUse `/init-worker host:path` to initialize, e.g.:\n`/init-worker timan1:/home/yurenh2/graph-grape`",
+ )
+ return
+
+ # 忽略 bot 自己的消息
+ if user_id == bot_user_id:
+ return
+ if not text:
+ return
+
+ # 权限过滤:已注册用户 或 白名单里的用户
+ is_registered = _get_user_by_slack(user_id) is not None
+ in_whitelist = SLACK_ALLOWED_USERS and SLACK_ALLOWED_USERS != [''] and user_id in SLACK_ALLOWED_USERS
+ if not is_registered and not in_whitelist:
+ return
+
+ # 去掉 @mention 标记
+ text = text.replace(f"<@{bot_user_id}>", "").strip()
+
+ user_name = _get_user_name(user_id)
+
+ # 获取 channel 名
+ channel_name = channel
+ try:
+ ch_info = slack_client.conversations_info(channel=channel)
+ channel_name = ch_info["channel"].get("name", channel)
+ except Exception:
+ pass
+
+ if event_type == "message" and channel_type == "im":
+ # DM → 查用户 → 路由到该用户的 dispatcher
+ user_rec = _get_user_by_slack(user_id)
+ if not user_rec:
+ slack_client.chat_postMessage(channel=channel, text="Please register first with /register")
+ return
+ did = user_rec["id"]
+ tag = f"[from_slack DM @{user_name} reply with send_slack_message channel={channel}]"
+ full_text = f"{tag} {text}"
+ db = get_db()
+ db.execute("INSERT INTO messages (chat_id, text, timestamp, dispatcher_id) VALUES (?, ?, ?, ?)",
+ (f"slack:{channel}:{user_id}", full_text, time.time(), did))
+ db.commit()
+ db.close()
+ print(f"[Slack] DM from {user_name} → dispatcher {did}: {text[:50]}...")
+ _notify_dispatcher(full_text, dispatcher_id=did)
+
+ elif event_type == "app_mention":
+ # Channel @bot → 检查是否绑定了 worker
+ binding = _get_channel_binding(channel)
+ if binding:
+ session = binding["session_name"]
+ # 默认走 message(问答模式),worker 用 reply_to_slack 回复
+ # 只有 /task 命令才走任务模式
+ if text.startswith("/task "):
+ # 任务模式
+ task_content = text[6:].strip()
+ task_id = uuid.uuid4().hex[:8]
+ now = time.time()
+ db = get_db()
+ db.execute(
+ "INSERT INTO tasks (id, target, type, content, status, created_at, updated_at) VALUES (?, ?, 'task', ?, 'pending', ?, ?)",
+ (task_id, session, f"[from Slack #{channel_name} @{user_name}] {task_content}\n\n[Use report_result when done — result will be posted to Slack channel]", now, now),
+ )
+ db.commit()
+ db.close()
+ _slack_task_channels[task_id] = channel
+ slack_client.chat_postMessage(channel=channel, text=f"📋 Task dispatched to `{session}`, ID: `{task_id}`")
+ print(f"[Slack] #{channel_name} TASK → {session}: {task_content[:50]}...")
+ else:
+ # 问答模式:message 直接注入 worker context
+ # worker 用 reply_to_slack 回复,broker 转发到 channel
+ task_id = uuid.uuid4().hex[:8]
+ now = time.time()
+ db = get_db()
+ db.execute(
+ "INSERT INTO tasks (id, target, type, content, status, created_at, updated_at) VALUES (?, ?, 'message', ?, 'pending', ?, ?)",
+ (task_id, session, f"[from Slack #{channel_name} @{user_name} — reply with reply_to_slack] {text}", now, now),
+ )
+ db.commit()
+ db.close()
+ # 记录 message→channel 映射,reply_to_slack 时回发
+ _slack_task_channels[task_id] = channel
+ print(f"[Slack] #{channel_name} MSG → {session}: {text[:50]}...")
+ else:
+ # 没绑定 → 路由到发送者的 dispatcher
+ user_rec = _get_user_by_slack(user_id)
+ if not user_rec:
+ slack_client.chat_postMessage(channel=channel, text="Please register first with /register")
+ return
+ did = user_rec["id"]
+ tag = f"[from_slack #{channel_name} @{user_name} reply with send_slack_message channel={channel}]"
+ full_text = f"{tag} {text}"
+ db = get_db()
+ db.execute("INSERT INTO messages (chat_id, text, timestamp, dispatcher_id) VALUES (?, ?, ?, ?)",
+ (f"slack:{channel}:{user_id}", full_text, time.time(), did))
+ db.commit()
+ db.close()
+ print(f"[Slack] #{channel_name} (unbound) → dispatcher {did}: {text[:50]}...")
+ _notify_dispatcher(full_text, dispatcher_id=did)
+
+ # task→channel 映射用全局变量
+ global _slack_task_channels
+
+ def _handle_register(channel, slack_user_id):
+ """处理 /register — 注册新用户 + 创建 dispatcher"""
+ existing = _get_user_by_slack(slack_user_id)
+ if existing:
+ slack_client.chat_postMessage(
+ channel=channel,
+ text=f"You are already registered. Dispatcher: `{existing['dispatcher_session']}`\nDM me to chat with your dispatcher.",
+ )
+ return
+
+ user_name = _get_user_name(slack_user_id)
+ user_id = uuid.uuid4().hex[:8]
+ session_name = f"dispatcher-{user_id}"
+
+ # 创建 user 记录
+ db = get_db()
+ db.execute(
+ "INSERT INTO users (id, slack_user_id, dispatcher_session, display_name, status, created_at) VALUES (?, ?, ?, ?, 'active', ?)",
+ (user_id, slack_user_id, session_name, user_name, time.time()),
+ )
+ db.commit()
+ db.close()
+
+ # 创建 dispatcher tmux session
+ workspace = os.path.expanduser(f"~/dispatchers/{user_id}")
+ os.makedirs(workspace, exist_ok=True)
+ # 复制英文 CLAUDE.md
+ en_md = os.path.join(os.path.dirname(os.path.abspath(__file__)), "CLAUDE.md.en")
+ if os.path.exists(en_md):
+ import shutil
+ shutil.copy2(en_md, os.path.join(workspace, "CLAUDE.md"))
+
+ subprocess.run(["tmux", "new-session", "-d", "-s", session_name], timeout=5)
+ subprocess.run(["tmux", "send-keys", "-t", session_name, f"cd {workspace}", "Enter"], timeout=5)
+ time.sleep(1)
+ subprocess.run(["tmux", "send-keys", "-t", session_name, f"export DISPATCHER_USER_ID={user_id}", "Enter"], timeout=5)
+ time.sleep(0.5)
+ subprocess.run(["tmux", "send-keys", "-t", session_name, "claude", "Enter"], timeout=5)
+
+ slack_client.chat_postMessage(
+ channel=channel,
+ text=f"✅ Registered! Your dispatcher: `{session_name}`\nDM me to start chatting with your personal assistant.",
+ )
+ print(f"[Register] Created user {user_id} ({user_name}) with dispatcher {session_name}")
+
+ def _handle_meta(channel):
+ """处理 /meta 命令 — 显示 channel 绑定的 worker 元数据"""
+ binding = _get_channel_binding(channel)
+ if not binding:
+ slack_client.chat_postMessage(channel=channel, text="No worker bound to this channel. Use `/init-worker host:path` first.")
+ return
+
+ session = binding["session_name"]
+ host = binding["host"]
+ path = binding["path"]
+ worker_id = binding["id"]
+
+ # 心跳
+ hb = heartbeats.get(session, {})
+ if hb:
+ age = int(time.time() - hb["last_seen"])
+ alive = "🟢 online" if age < HEARTBEAT_TIMEOUT else f"🔴 offline ({age}s ago)"
+ last_hb = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(hb["last_seen"]))
+ else:
+ alive = "⚪ no heartbeat"
+ last_hb = "N/A"
+
+ # 任务统计
+ db = get_db()
+ running = db.execute("SELECT COUNT(*) FROM tasks WHERE target = ? AND status = 'running'", (session,)).fetchone()[0]
+ pending = db.execute("SELECT COUNT(*) FROM tasks WHERE target = ? AND status = 'pending'", (session,)).fetchone()[0]
+ done = db.execute("SELECT COUNT(*) FROM tasks WHERE target = ? AND status = 'done'", (session,)).fetchone()[0]
+
+ # 获取 worker 创建时间
+ w_row = db.execute("SELECT created_at FROM workers WHERE id = ?", (worker_id,)).fetchone()
+ db.close()
+ created = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(w_row[0])) if w_row else "N/A"
+
+ text = (
+ f"*Worker Meta*\n"
+ f"• Session: `{session}`\n"
+ f"• Host: `{host}`\n"
+ f"• Path: `{path}`\n"
+ f"• Status: {alive}\n"
+ f"• Last heartbeat: {last_hb}\n"
+ f"• Created: {created}\n"
+ f"• Tasks — running: {running}, pending: {pending}, done: {done}"
+ )
+ slack_client.chat_postMessage(channel=channel, text=text)
+
+ def _check_pm(channel, slack_user_id):
+ """检查用户是否是该 channel 的 PM(绑定者)"""
+ db = get_db()
+ row = db.execute("SELECT created_by FROM channel_bindings WHERE channel_id = ?", (channel,)).fetchone()
+ db.close()
+ if not row or not row["created_by"]:
+ return True # 没记录创建者,放行
+ return row["created_by"] == slack_user_id
+
+ def _handle_worker_command(channel, action, slack_user_id=""):
+ """处理 /restart-worker 和 /stop-worker"""
+ binding = _get_channel_binding(channel)
+ if not binding:
+ slack_client.chat_postMessage(channel=channel, text="No worker bound to this channel.")
+ return
+ if not _check_pm(channel, slack_user_id):
+ slack_client.chat_postMessage(channel=channel, text="Permission denied. Only the PM (who ran `/init-worker`) can do this.")
+ return
+
+ session = binding["session_name"]
+ host = binding["host"]
+ cmd_id = uuid.uuid4().hex[:8]
+ import json as _json
+ db = get_db()
+
+ if action == "restart":
+ db.execute(
+ "INSERT INTO commands (id, target, action, params, status, created_at) VALUES (?, ?, 'restart', '{}', 'pending', ?)",
+ (cmd_id, session, time.time()),
+ )
+ db.commit()
+ db.close()
+ slack_client.chat_postMessage(channel=channel, text=f"🔄 Restarting worker `{session}` on {host}...")
+ elif action == "stop":
+ db.execute(
+ "INSERT INTO commands (id, target, action, params, status, created_at) VALUES (?, ?, 'stop', '{}', 'pending', ?)",
+ (cmd_id, session, time.time()),
+ )
+ # 解绑 channel + 标记 worker inactive
+ db.execute("DELETE FROM channel_bindings WHERE channel_id = ?", (channel,))
+ db.execute("UPDATE workers SET status = 'inactive' WHERE session_name = ?", (session,))
+ db.commit()
+ db.close()
+ slack_client.chat_postMessage(channel=channel, text=f"🛑 Stopping worker `{session}` on {host}... Channel unbound.")
+
+ def _handle_unbind(channel, slack_user_id=""):
+ """解绑 channel 和 worker"""
+ binding = _get_channel_binding(channel)
+ if not binding:
+ slack_client.chat_postMessage(channel=channel, text="No worker bound to this channel.")
+ return
+ if not _check_pm(channel, slack_user_id):
+ slack_client.chat_postMessage(channel=channel, text="Permission denied. Only the PM (who ran `/init-worker`) can do this.")
+ return
+ session = binding["session_name"]
+ db = get_db()
+ db.execute("DELETE FROM channel_bindings WHERE channel_id = ?", (channel,))
+ db.commit()
+ db.close()
+ slack_client.chat_postMessage(channel=channel, text=f"Unbound from worker `{session}`. Use `/init-worker` to bind a new one.")
+
+ def _handle_init_worker(channel, channel_name, user_id, cmd_text):
+ """处理 /init-worker 命令"""
+ if not cmd_text or ":" not in cmd_text:
+ slack_client.chat_postMessage(
+ channel=channel,
+ text="Usage: `/init-worker host:path`\nExample: `/init-worker timan1:/home/yurenh2/graph-grape`",
+ )
+ return
+
+ host, path = cmd_text.split(":", 1)
+ host = host.strip()
+ path = path.strip()
+
+ # 检查是否已有这个路径的 worker
+ existing = _find_worker_by_path(host, path)
+ if existing:
+ # 已有 worker,直接绑定
+ _bind_channel(channel, existing["id"], channel_name, created_by=user_id)
+ slack_client.chat_postMessage(
+ channel=channel,
+ text=f"✅ Bound to existing worker `{existing['session_name']}` ({host}:{path})\nYou can now @mention me to chat with the worker.",
+ )
+ print(f"[Slack] Channel {channel} bound to existing worker {existing['session_name']}")
+ else:
+ # 需要创建新 worker
+ # 生成 session 名
+ session_name = f"worker-{uuid.uuid4().hex[:6]}"
+ worker_id = _create_worker(session_name, host, path)
+ _bind_channel(channel, worker_id, channel_name, created_by=user_id)
+
+ # 下发创建命令到 lab
+ cmd_id = uuid.uuid4().hex[:8]
+ import json as _json
+ db = get_db()
+ db.execute(
+ "INSERT INTO commands (id, target, action, params, status, created_at) VALUES (?, ?, ?, ?, 'pending', ?)",
+ (cmd_id, session_name, "create_worker", _json.dumps({"host": host, "path": path, "session_name": session_name, "slack_channel": channel}), time.time()),
+ )
+ db.commit()
+ db.close()
+
+ slack_client.chat_postMessage(
+ channel=channel,
+ text=f"🚀 Creating new worker `{session_name}` on {host}:{path}\nCommand ID: `{cmd_id}`, please wait...",
+ )
+ print(f"[Slack] Creating new worker {session_name} at {host}:{path} for channel {channel}")
+
+ # 启动 Socket Mode
+ socket_client = SocketModeClient(
+ app_token=SLACK_APP_TOKEN,
+ web_client=slack_client,
+ )
+ socket_client.socket_mode_request_listeners.append(_handle_event)
+
+ def _run():
+ try:
+ socket_client.connect()
+ print("[Slack] Socket Mode connected")
+ # 保持线程活着
+ import select
+ while True:
+ select.select([], [], [], 60)
+ except Exception as e:
+ print(f"[Slack] Socket Mode error: {e}")
+
+ t = threading.Thread(target=_run, daemon=True)
+ t.start()
+ global _slack_handler
+ _slack_handler = slack_client
+ print("[Slack] Listener started in background thread")
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ init_db()
+ task = asyncio.create_task(poller.start())
+ hb_task = asyncio.create_task(_heartbeat_checker())
+ sched_task = asyncio.create_task(_schedule_checker())
+ _start_slack_listener()
+ yield
+ poller.stop()
+ task.cancel()
+ hb_task.cancel()
+ sched_task.cancel()
+
+
+# === FastAPI App ===
+
+app = FastAPI(title="Claude Bridge Broker", lifespan=lifespan)
+
+
+# --- Pydantic Models ---
+
+class TaskCreate(BaseModel):
+ content: str
+ target: str = ""
+ type: str = "task"
+ dispatcher_id: str = "" # 哪个 dispatcher 创建的
+
+
+class TaskResult(BaseModel):
+ result: str
+
+
+class TelegramSend(BaseModel):
+ message: str
+
+
+# --- Message Endpoints (调度端用) ---
+
+@app.get("/messages/new", dependencies=[Depends(verify_token)])
+def get_new_messages():
+ """获取未处理的用户消息"""
+ db = get_db()
+ rows = db.execute(
+ "SELECT id, chat_id, text, timestamp FROM messages WHERE processed = 0 ORDER BY id"
+ ).fetchall()
+ # 标记为已处理
+ if rows:
+ ids = [r["id"] for r in rows]
+ placeholders = ",".join("?" * len(ids))
+ db.execute(f"UPDATE messages SET processed = 1 WHERE id IN ({placeholders})", ids)
+ db.commit()
+ db.close()
+ return {
+ "messages": [
+ {"id": r["id"], "chat_id": r["chat_id"], "text": r["text"], "timestamp": r["timestamp"]}
+ for r in rows
+ ]
+ }
+
+
+# --- Task Endpoints ---
+
+@app.post("/tasks", dependencies=[Depends(verify_token)])
+def create_task(body: TaskCreate):
+ """创建任务(调度端 → 实验室)"""
+ task_id = uuid.uuid4().hex[:8]
+ now = time.time()
+ db = get_db()
+ db.execute(
+ "INSERT INTO tasks (id, target, type, content, status, dispatcher_id, created_at, updated_at) VALUES (?, ?, ?, ?, 'pending', ?, ?, ?)",
+ (task_id, body.target, body.type, body.content, body.dispatcher_id, now, now),
+ )
+ db.commit()
+ db.close()
+ return {"id": task_id, "target": body.target, "type": body.type, "content": body.content, "status": "pending"}
+
+
+@app.get("/tasks/pending", dependencies=[Depends(verify_token)])
+def get_pending_tasks(target: str = ""):
+ """获取待执行的任务。target 过滤目标 session,只返回匹配的或无指定目标的任务"""
+ db = get_db()
+ if target:
+ rows = db.execute(
+ "SELECT id, target, type, content, created_at FROM tasks WHERE status = 'pending' AND (target = ? OR target = '') ORDER BY created_at",
+ (target,),
+ ).fetchall()
+ else:
+ rows = db.execute(
+ "SELECT id, target, type, content, created_at FROM tasks WHERE status = 'pending' ORDER BY created_at"
+ ).fetchall()
+ db.close()
+ return {"tasks": [dict(r) for r in rows]}
+
+
+@app.post("/tasks/{task_id}/claim", dependencies=[Depends(verify_token)])
+def claim_task(task_id: str):
+ """领取任务(实验室端)"""
+ db = get_db()
+ cur = db.execute(
+ "UPDATE tasks SET status = 'running', updated_at = ? WHERE id = ? AND status = 'pending'",
+ (time.time(), task_id),
+ )
+ db.commit()
+ if cur.rowcount == 0:
+ db.close()
+ raise HTTPException(404, "Task not found or already claimed")
+ # 更新 worker→dispatcher 映射
+ row = db.execute("SELECT target, dispatcher_id FROM tasks WHERE id = ?", (task_id,)).fetchone()
+ if row and row["dispatcher_id"]:
+ _worker_dispatcher_map[row["target"]] = row["dispatcher_id"]
+ db.close()
+ return {"ok": True, "task_id": task_id}
+
+
+@app.post("/tasks/{task_id}/result", dependencies=[Depends(verify_token)])
+def submit_result(task_id: str, body: TaskResult):
+ """提交任务结果(实验室端)"""
+ db = get_db()
+ cur = db.execute(
+ "UPDATE tasks SET status = 'done', result = ?, updated_at = ? WHERE id = ? AND status = 'running'",
+ (body.result, time.time(), task_id),
+ )
+ db.commit()
+ if cur.rowcount == 0:
+ db.close()
+ raise HTTPException(404, "Task not found or not running")
+ db.close()
+ # 如果这个 task 关联了 Slack channel 且是 task 类型,发结果到 channel
+ # message 类型不发(worker 已经通过 reply_to_slack 直接回复了)
+ if _slack_handler:
+ channel = _slack_task_channels.pop(task_id, None)
+ # 查 task 类型
+ db2 = get_db()
+ task_row = db2.execute("SELECT type FROM tasks WHERE id = ?", (task_id,)).fetchone()
+ task_type = task_row[0] if task_row else "task"
+ db2.close()
+ if channel and task_type == "task":
+ result_preview = body.result[:3000] if body.result else "(无结果)"
+ try:
+ _slack_handler.chat_postMessage(
+ channel=channel,
+ text=f"✅ Task `{task_id}` completed:\n\n{result_preview}",
+ )
+ print(f"[Slack] Result for {task_id} sent to channel {channel}")
+ except Exception as e:
+ print(f"[Slack] Failed to send result to channel: {e}")
+ # 路由到正确的 dispatcher
+ db3 = get_db()
+ task_info = db3.execute("SELECT dispatcher_id FROM tasks WHERE id = ?", (task_id,)).fetchone()
+ db3.close()
+ did = task_info["dispatcher_id"] if task_info else ""
+ _notify_dispatcher(f"Task {task_id} completed. Use check_task_status to view result and report to user.", dispatcher_id=did)
+ return {"ok": True}
+
+
+@app.post("/tasks/{task_id}/fail", dependencies=[Depends(verify_token)])
+def fail_task(task_id: str, body: TaskResult):
+ """标记任务失败(实验室端)"""
+ db = get_db()
+ cur = db.execute(
+ "UPDATE tasks SET status = 'failed', result = ?, updated_at = ? WHERE id = ?",
+ (body.result, time.time(), task_id),
+ )
+ db.commit()
+ db.close()
+ db2 = get_db()
+ task_info = db2.execute("SELECT dispatcher_id FROM tasks WHERE id = ?", (task_id,)).fetchone()
+ db2.close()
+ did = task_info["dispatcher_id"] if task_info else ""
+ _notify_dispatcher(f"Task {task_id} failed. Use check_task_status to view reason and notify user.", dispatcher_id=did)
+ return {"ok": True}
+
+
+@app.get("/tasks/{task_id}", dependencies=[Depends(verify_token)])
+def get_task(task_id: str):
+ """查询任务状态"""
+ db = get_db()
+ row = db.execute("SELECT * FROM tasks WHERE id = ?", (task_id,)).fetchone()
+ db.close()
+ if not row:
+ raise HTTPException(404, "Task not found")
+ return dict(row)
+
+
+@app.get("/tasks", dependencies=[Depends(verify_token)])
+def list_tasks(status: str = None, limit: int = 20):
+ """列出任务"""
+ db = get_db()
+ if status:
+ rows = db.execute(
+ "SELECT * FROM tasks WHERE status = ? ORDER BY created_at DESC LIMIT ?",
+ (status, limit),
+ ).fetchall()
+ else:
+ rows = db.execute(
+ "SELECT * FROM tasks ORDER BY created_at DESC LIMIT ?", (limit,)
+ ).fetchall()
+ db.close()
+ return {"tasks": [dict(r) for r in rows]}
+
+
+# --- Workers list ---
+
+@app.get("/workers", dependencies=[Depends(verify_token)])
+def list_workers_api():
+ """List all active workers with channel bindings"""
+ db = get_db()
+ rows = db.execute("""
+ SELECT w.*, cb.channel_name
+ FROM workers w
+ LEFT JOIN channel_bindings cb ON cb.worker_id = w.id
+ WHERE w.status = 'active'
+ ORDER BY w.created_at
+ """).fetchall()
+ db.close()
+ return {"workers": [dict(r) for r in rows]}
+
+
+# --- Telegram Send (调度端用) ---
+
+@app.post("/telegram/send", dependencies=[Depends(verify_token)])
+async def send_telegram(body: TelegramSend):
+ """发送 Telegram 消息给用户"""
+ async with httpx.AsyncClient() as client:
+ resp = await client.post(
+ f"https://api.telegram.org/bot{TELEGRAM_TOKEN}/sendMessage",
+ json={
+ "chat_id": TELEGRAM_CHAT_ID,
+ "text": body.message,
+ # 不用 Markdown,避免特殊字符解析失败
+ },
+ )
+ if resp.status_code != 200:
+ raise HTTPException(502, f"Telegram API error: {resp.text}")
+ return {"ok": True}
+
+
+# --- File Transfer ---
+
+@app.post("/files/upload", dependencies=[Depends(verify_token)])
+async def upload_file(file: UploadFile = File(...), filename: str = Form("")):
+ """上传文件到 broker 文件存储"""
+ name = filename or file.filename or f"upload_{int(time.time())}"
+ safe_name = f"{uuid.uuid4().hex[:8]}_{name}"
+ save_path = os.path.join(FILES_DIR, safe_name)
+ content = await file.read()
+ with open(save_path, "wb") as f:
+ f.write(content)
+ return {"filename": safe_name, "size": len(content)}
+
+
+@app.get("/files/{filename}")
+async def download_file(filename: str, request: Request):
+ """下载文件(需认证)"""
+ await verify_token(request)
+ file_path = os.path.join(FILES_DIR, filename)
+ if not os.path.exists(file_path):
+ raise HTTPException(404, "File not found")
+ return FileResponse(file_path, filename=filename)
+
+
+@app.get("/files", dependencies=[Depends(verify_token)])
+def list_files():
+ """列出所有文件"""
+ files = []
+ for f in os.listdir(FILES_DIR):
+ fp = os.path.join(FILES_DIR, f)
+ files.append({"name": f, "size": os.path.getsize(fp)})
+ return {"files": sorted(files, key=lambda x: x["name"])}
+
+
+@app.post("/telegram/send_file", dependencies=[Depends(verify_token)])
+async def send_telegram_file(filename: str = Form(...), caption: str = Form("")):
+ """通过 Telegram 给用户发送文件"""
+ file_path = os.path.join(FILES_DIR, filename)
+ if not os.path.exists(file_path):
+ raise HTTPException(404, "File not found")
+ async with httpx.AsyncClient(timeout=60) as client:
+ with open(file_path, "rb") as f:
+ resp = await client.post(
+ f"https://api.telegram.org/bot{TELEGRAM_TOKEN}/sendDocument",
+ data={"chat_id": TELEGRAM_CHAT_ID, "caption": caption or ""},
+ files={"document": (filename, f)},
+ )
+ if resp.status_code != 200:
+ raise HTTPException(502, f"Telegram API error: {resp.text}")
+ return {"ok": True}
+
+
+# --- Slack Send ---
+
+class SlackSend(BaseModel):
+ channel: str # channel ID
+ message: str
+ thread_ts: str = "" # 可选,回复到 thread
+
+
+@app.post("/slack/send", dependencies=[Depends(verify_token)])
+def send_slack_message(body: SlackSend):
+ """发送 Slack 消息"""
+ if not _slack_handler:
+ raise HTTPException(503, "Slack not configured")
+ kwargs = {"channel": body.channel, "text": body.message}
+ if body.thread_ts:
+ kwargs["thread_ts"] = body.thread_ts
+ try:
+ _slack_handler.chat_postMessage(**kwargs)
+ except Exception as e:
+ raise HTTPException(502, f"Slack API error: {e}")
+ return {"ok": True}
+
+
+# --- Log / Reply (实验室 → dispatcher) ---
+
+class LogEntry(BaseModel):
+ source: str # claude / claude2 / claude3 / system
+ message: str
+
+
+@app.post("/log", dependencies=[Depends(verify_token)])
+def post_log(body: LogEntry):
+ """实验室 Claude 回传消息给 dispatcher"""
+ # 从 worker→dispatcher 映射推断目标
+ did = _worker_dispatcher_map.get(body.source, "")
+ _notify_dispatcher(f"[{body.source}] {body.message}", dispatcher_id=did)
+ return {"ok": True}
+
+
+class SlackHistory(BaseModel):
+ session: str # worker session name, 用来查绑定的 channel
+ limit: int = 20
+
+
+@app.post("/slack/history", dependencies=[Depends(verify_token)])
+def get_slack_history(body: SlackHistory):
+ """获取 worker 绑定的 Slack channel 的消息历史"""
+ if not _slack_handler:
+ raise HTTPException(503, "Slack not configured")
+ db = get_db()
+ row = db.execute(
+ "SELECT cb.channel_id FROM channel_bindings cb JOIN workers w ON cb.worker_id = w.id WHERE w.session_name = ?",
+ (body.session,),
+ ).fetchone()
+ db.close()
+ if not row:
+ raise HTTPException(404, f"No Slack channel bound to session {body.session}")
+ channel = row[0]
+ try:
+ resp = _slack_handler.conversations_history(channel=channel, limit=body.limit)
+ messages = []
+ for msg in resp.get("messages", []):
+ user = msg.get("user", "bot")
+ text = msg.get("text", "")
+ ts = msg.get("ts", "")
+ messages.append({"user": user, "text": text, "ts": ts})
+ # 反转让旧消息在前
+ messages.reverse()
+ return {"messages": messages}
+ except Exception as e:
+ raise HTTPException(502, f"Slack API error: {e}")
+
+
+class SlackReply(BaseModel):
+ session: str # worker session name
+ message: str
+
+
+@app.post("/slack/reply", dependencies=[Depends(verify_token)])
+def slack_reply(body: SlackReply):
+ """Worker 回复到关联的 Slack channel(通过 session→channel_binding 查找)"""
+ if not _slack_handler:
+ raise HTTPException(503, "Slack not configured")
+ # 从 worker session 找到绑定的 channel
+ db = get_db()
+ row = db.execute(
+ "SELECT cb.channel_id FROM channel_bindings cb JOIN workers w ON cb.worker_id = w.id WHERE w.session_name = ?",
+ (body.session,),
+ ).fetchone()
+ db.close()
+ if not row:
+ raise HTTPException(404, f"No Slack channel bound to session {body.session}")
+ channel = row[0]
+ try:
+ _slack_handler.chat_postMessage(channel=channel, text=body.message)
+ except Exception as e:
+ raise HTTPException(502, f"Slack API error: {e}")
+ return {"ok": True}
+
+
+# --- GPT-Pro Expert ---
+
+OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
+expert_requests: dict = {} # id -> {status, question, answer}
+
+
+class ExpertQuestion(BaseModel):
+ question: str
+
+
+@app.post("/expert/ask", dependencies=[Depends(verify_token)])
+def ask_expert(body: ExpertQuestion):
+ """提交问题给 GPT-Pro,后台异步处理"""
+ req_id = uuid.uuid4().hex[:8]
+ expert_requests[req_id] = {
+ "status": "thinking",
+ "question": body.question,
+ "answer": "",
+ }
+ # 后台线程调 OpenAI API
+ import threading
+ threading.Thread(target=_call_gpt_pro, args=(req_id, body.question), daemon=True).start()
+ return {"id": req_id, "status": "thinking"}
+
+
+@app.get("/expert/{req_id}", dependencies=[Depends(verify_token)])
+def get_expert(req_id: str):
+ if req_id not in expert_requests:
+ raise HTTPException(404, "Request not found")
+ return expert_requests[req_id]
+
+
+def _call_gpt_pro(req_id: str, question: str):
+ """同步调用 OpenAI GPT-Pro(在后台线程中运行)"""
+ import json
+ import urllib.request
+
+ try:
+ req = urllib.request.Request(
+ "https://api.openai.com/v1/responses",
+ data=json.dumps({
+ "model": "o3-pro",
+ "input": question,
+ }).encode(),
+ headers={
+ "Authorization": f"Bearer {OPENAI_API_KEY}",
+ "Content-Type": "application/json",
+ },
+ )
+ with urllib.request.urlopen(req, timeout=600) as resp:
+ data = json.loads(resp.read())
+
+ # 提取回复文本
+ answer = ""
+ for item in data.get("output", []):
+ if item.get("type") == "message":
+ for content in item.get("content", []):
+ if content.get("type") == "output_text":
+ answer += content.get("text", "")
+
+ expert_requests[req_id]["status"] = "done"
+ expert_requests[req_id]["answer"] = answer or "(empty response)"
+
+ except Exception as e:
+ expert_requests[req_id]["status"] = "error"
+ expert_requests[req_id]["answer"] = str(e)
+
+ # 通知 dispatcher
+ status = expert_requests[req_id]["status"]
+ _notify_dispatcher(
+ f"[system] GPT-Pro reply ready (ID: {req_id}, status: {status}). Use get_expert_answer to view and forward to user."
+ )
+
+
+# --- Commands (lab server执行的系统命令) ---
+
+class CommandCreate(BaseModel):
+ target: str # tmux session name
+ action: str # switch_project, restart, etc.
+ params: dict = {}
+
+
+@app.post("/commands", dependencies=[Depends(verify_token)])
+def create_command(body: CommandCreate):
+ """创建一个命令让 lab cron 执行"""
+ cmd_id = uuid.uuid4().hex[:8]
+ db = get_db()
+ import json as _json
+ db.execute(
+ "INSERT INTO commands (id, target, action, params, status, created_at) VALUES (?, ?, ?, ?, 'pending', ?)",
+ (cmd_id, body.target, body.action, _json.dumps(body.params), time.time()),
+ )
+ db.commit()
+ db.close()
+ return {"id": cmd_id, "target": body.target, "action": body.action}
+
+
+@app.get("/commands/pending", dependencies=[Depends(verify_token)])
+def get_pending_commands():
+ """lab cron 轮询待执行命令"""
+ db = get_db()
+ rows = db.execute(
+ "SELECT id, target, action, params FROM commands WHERE status = 'pending' ORDER BY created_at"
+ ).fetchall()
+ db.close()
+ return {"commands": [dict(r) for r in rows]}
+
+
+@app.post("/commands/{cmd_id}/done", dependencies=[Depends(verify_token)])
+def complete_command(cmd_id: str, body: TaskResult):
+ """标记命令完成"""
+ db = get_db()
+ db.execute(
+ "UPDATE commands SET status = 'done', result = ? WHERE id = ?",
+ (body.result, cmd_id),
+ )
+ # 查命令的 target session,找到绑定的 Slack channel 回报
+ row = db.execute("SELECT target FROM commands WHERE id = ?", (cmd_id,)).fetchone()
+ target_session = row[0] if row else ""
+ slack_channel = None
+ if target_session and _slack_handler:
+ ch_row = db.execute(
+ "SELECT cb.channel_id FROM channel_bindings cb JOIN workers w ON cb.worker_id = w.id WHERE w.session_name = ?",
+ (target_session,),
+ ).fetchone()
+ if ch_row:
+ slack_channel = ch_row[0]
+ db.commit()
+ db.close()
+
+ result = body.result
+ is_ok = result.startswith("OK")
+ if slack_channel:
+ emoji = "✅" if is_ok else "❌"
+ try:
+ _slack_handler.chat_postMessage(channel=slack_channel, text=f"{emoji} {result}")
+ except Exception:
+ pass
+ _notify_dispatcher(f"[system] Command {cmd_id} completed: {result}")
+ return {"ok": True}
+
+
+@app.get("/commands/{cmd_id}", dependencies=[Depends(verify_token)])
+def get_command(cmd_id: str):
+ """查询命令状态"""
+ db = get_db()
+ row = db.execute("SELECT * FROM commands WHERE id = ?", (cmd_id,)).fetchone()
+ db.close()
+ if not row:
+ raise HTTPException(404, "Command not found")
+ return dict(row)
+
+
+# --- Pending Context (for UserPromptSubmit hook) ---
+
+@app.get("/context/pending", dependencies=[Depends(verify_token)])
+def get_pending_context(dispatcher_id: str = ""):
+ """Hook 调用:拉取待注入的 context,拉取后自动删除"""
+ db = get_db()
+ if dispatcher_id:
+ rows = db.execute("SELECT id, text FROM pending_context WHERE dispatcher_id = ? OR dispatcher_id = '' ORDER BY id", (dispatcher_id,)).fetchall()
+ else:
+ rows = db.execute("SELECT id, text FROM pending_context ORDER BY id").fetchall()
+ if rows:
+ ids = [r["id"] for r in rows]
+ placeholders = ",".join("?" * len(ids))
+ db.execute(f"DELETE FROM pending_context WHERE id IN ({placeholders})", ids)
+ db.commit()
+ db.close()
+ texts = [r["text"] for r in rows]
+ return {"messages": texts}
+
+
+# --- Schedules (定时任务) ---
+
+class ScheduleCreate(BaseModel):
+ action: str
+ trigger_at: float = 0
+ delay_seconds: float = 0
+ repeat_seconds: float = 0
+ dispatcher_id: str = ""
+
+
+@app.post("/schedules", dependencies=[Depends(verify_token)])
+def create_schedule(body: ScheduleCreate):
+ """创建定时任务"""
+ schedule_id = uuid.uuid4().hex[:8]
+ now = time.time()
+ if body.trigger_at > 0:
+ trigger = body.trigger_at
+ elif body.delay_seconds > 0:
+ trigger = now + body.delay_seconds
+ else:
+ trigger = now
+ db = get_db()
+ db.execute(
+ "INSERT INTO schedules (id, action, trigger_at, repeat_seconds, dispatcher_id, status, created_at) VALUES (?, ?, ?, ?, ?, 'active', ?)",
+ (schedule_id, body.action, trigger, body.repeat_seconds, body.dispatcher_id, now),
+ )
+ db.commit()
+ db.close()
+ import datetime
+ trigger_str = datetime.datetime.fromtimestamp(trigger).strftime("%Y-%m-%d %H:%M:%S")
+ return {"id": schedule_id, "trigger_at": trigger_str, "repeat_seconds": body.repeat_seconds}
+
+
+@app.get("/schedules", dependencies=[Depends(verify_token)])
+def list_schedules():
+ """列出所有定时任务"""
+ db = get_db()
+ rows = db.execute("SELECT * FROM schedules WHERE status = 'active' ORDER BY trigger_at").fetchall()
+ db.close()
+ return {"schedules": [dict(r) for r in rows]}
+
+
+@app.delete("/schedules/{schedule_id}", dependencies=[Depends(verify_token)])
+def cancel_schedule(schedule_id: str):
+ """取消定时任务"""
+ db = get_db()
+ db.execute("UPDATE schedules SET status = 'cancelled' WHERE id = ?", (schedule_id,))
+ db.commit()
+ db.close()
+ return {"ok": True}
+
+
+async def _schedule_checker():
+ """后台协程:每 30 秒检查定时任务,到时间就触发"""
+ while True:
+ await asyncio.sleep(30)
+ now = time.time()
+ db = get_db()
+ rows = db.execute(
+ "SELECT id, action, trigger_at, repeat_seconds, dispatcher_id FROM schedules WHERE status = 'active' AND trigger_at <= ?",
+ (now,),
+ ).fetchall()
+ for r in rows:
+ sid, action, trigger_at, repeat = r["id"], r["action"], r["trigger_at"], r["repeat_seconds"]
+ did = r["dispatcher_id"] if "dispatcher_id" in r.keys() else ""
+ _notify_dispatcher(f"[scheduled {sid}] {action}", dispatcher_id=did)
+ if repeat > 0:
+ # 循环任务:更新下次触发时间
+ next_trigger = trigger_at + repeat
+ # 如果落后太多,跳到下一个未来时间点
+ while next_trigger <= now:
+ next_trigger += repeat
+ db.execute("UPDATE schedules SET trigger_at = ? WHERE id = ?", (next_trigger, sid))
+ else:
+ # 一次性任务:标记完成
+ db.execute("UPDATE schedules SET status = 'done' WHERE id = ?", (sid,))
+ db.commit()
+ db.close()
+
+
+# --- Heartbeat ---
+
+heartbeats: dict = {} # session -> {host, last_seen, alerted}
+
+
+class Heartbeat(BaseModel):
+ session: str
+ host: str = ""
+
+
+@app.post("/heartbeat", dependencies=[Depends(verify_token)])
+def post_heartbeat(body: Heartbeat):
+ """Worker 心跳上报"""
+ heartbeats[body.session] = {
+ "host": body.host,
+ "last_seen": time.time(),
+ "alerted": False,
+ }
+ return {"ok": True}
+
+
+@app.get("/heartbeat", dependencies=[Depends(verify_token)])
+def get_heartbeats():
+ """查看所有 worker 心跳状态"""
+ now = time.time()
+ result = {}
+ for session, info in heartbeats.items():
+ age = now - info["last_seen"]
+ result[session] = {
+ "host": info["host"],
+ "last_seen": info["last_seen"],
+ "age_seconds": int(age),
+ "alive": age < HEARTBEAT_TIMEOUT,
+ }
+ return result
+
+
+async def _heartbeat_checker():
+ """后台协程:定期检查心跳,超时则 Telegram 通知用户"""
+ while True:
+ await asyncio.sleep(60) # 每分钟检查一次
+ now = time.time()
+ for session, info in heartbeats.items():
+ age = now - info["last_seen"]
+ if age > HEARTBEAT_TIMEOUT and not info.get("alerted"):
+ # 超时!通知用户
+ info["alerted"] = True
+ host = info.get("host", "unknown")
+ msg = f"Worker {session} ({host}) heartbeat timeout ({int(age)}s). May be offline. Please check and restart."
+ try:
+ async with httpx.AsyncClient() as client:
+ await client.post(
+ f"https://api.telegram.org/bot{TELEGRAM_TOKEN}/sendMessage",
+ json={"chat_id": TELEGRAM_CHAT_ID, "text": msg},
+ )
+ print(f"[Heartbeat] ALERT: {session} timeout, notified user")
+ except Exception as e:
+ print(f"[Heartbeat] Failed to send alert: {e}")
+ # 也通知 dispatcher
+ _notify_dispatcher(f"[system] Worker {session} ({host}) heartbeat timeout, may be offline")
+
+
+# --- Health Check ---
+
+@app.get("/health")
+def health():
+ return {"status": "ok", "time": time.time()}