summaryrefslogtreecommitdiff
path: root/websocket-server/src
diff options
context:
space:
mode:
authorIlan Bigio <ilan@openai.com>2024-12-16 13:06:08 -0800
committerIlan Bigio <ilan@openai.com>2024-12-19 16:08:22 -0500
commit20009aed53d8864c9204d43a17895168a777d2cc (patch)
tree754dded819869bc34a8a2a02c66ea72dac1ccd24 /websocket-server/src
Initial commit
Diffstat (limited to 'websocket-server/src')
-rw-r--r--websocket-server/src/functionHandlers.ts33
-rw-r--r--websocket-server/src/server.ts77
-rw-r--r--websocket-server/src/sessionManager.ts286
-rw-r--r--websocket-server/src/twiml.xml8
-rw-r--r--websocket-server/src/types.ts31
5 files changed, 435 insertions, 0 deletions
diff --git a/websocket-server/src/functionHandlers.ts b/websocket-server/src/functionHandlers.ts
new file mode 100644
index 0000000..512a7af
--- /dev/null
+++ b/websocket-server/src/functionHandlers.ts
@@ -0,0 +1,33 @@
+import { FunctionHandler } from "./types";
+
+const functions: FunctionHandler[] = [];
+
+functions.push({
+ schema: {
+ name: "get_weather_from_coords",
+ type: "function",
+ description: "Get the current weather",
+ parameters: {
+ type: "object",
+ properties: {
+ latitude: {
+ type: "number",
+ },
+ longitude: {
+ type: "number",
+ },
+ },
+ required: ["latitude", "longitude"],
+ },
+ },
+ handler: async (args: { latitude: number; longitude: number }) => {
+ const response = await fetch(
+ `https://api.open-meteo.com/v1/forecast?latitude=${args.latitude}&longitude=${args.longitude}&current=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m`
+ );
+ const data = await response.json();
+ const currentTemp = data.current?.temperature_2m;
+ return JSON.stringify({ temp: currentTemp });
+ },
+});
+
+export default functions;
diff --git a/websocket-server/src/server.ts b/websocket-server/src/server.ts
new file mode 100644
index 0000000..8ab3ea6
--- /dev/null
+++ b/websocket-server/src/server.ts
@@ -0,0 +1,77 @@
+import express from "express";
+import { WebSocketServer, WebSocket } from "ws";
+import { IncomingMessage } from "http";
+import dotenv from "dotenv";
+import http from "http";
+import { readFileSync } from "fs";
+import { join } from "path";
+import cors from "cors";
+import {
+ handleCallConnection,
+ handleFrontendConnection,
+} from "./sessionManager";
+import functions from "./functionHandlers";
+
+dotenv.config();
+
+const PORT = parseInt(process.env.PORT || "8081", 10);
+const PUBLIC_URL = process.env.PUBLIC_URL || "";
+
+const app = express();
+app.use(cors());
+const server = http.createServer(app);
+const wss = new WebSocketServer({ server });
+
+app.use(express.urlencoded({ extended: false }));
+
+const twimlPath = join(__dirname, "twiml.xml");
+const twimlTemplate = readFileSync(twimlPath, "utf-8");
+
+app.get("/public-url", (req, res) => {
+ res.json({ publicUrl: PUBLIC_URL });
+});
+
+app.all("/twiml", (req, res) => {
+ const wsUrl = new URL(PUBLIC_URL);
+ wsUrl.protocol = "wss:";
+ wsUrl.pathname = `/call`;
+
+ const twimlContent = twimlTemplate.replace("{{WS_URL}}", wsUrl.toString());
+ res.type("text/xml").send(twimlContent);
+});
+
+// New endpoint to list available tools (schemas)
+app.get("/tools", (req, res) => {
+ res.json(functions.map((f) => f.schema));
+});
+
+let currentCall: WebSocket | null = null;
+let currentLogs: WebSocket | null = null;
+
+wss.on("connection", (ws: WebSocket, req: IncomingMessage) => {
+ const url = new URL(req.url || "", `http://${req.headers.host}`);
+ const parts = url.pathname.split("/").filter(Boolean);
+
+ if (parts.length < 1) {
+ ws.close();
+ return;
+ }
+
+ const type = parts[0];
+
+ if (type === "call") {
+ if (currentCall) currentCall.close();
+ currentCall = ws;
+ handleCallConnection(currentCall);
+ } else if (type === "logs") {
+ if (currentLogs) currentLogs.close();
+ currentLogs = ws;
+ handleFrontendConnection(currentLogs);
+ } else {
+ ws.close();
+ }
+});
+
+server.listen(PORT, () => {
+ console.log(`Server running on http://localhost:${PORT}`);
+});
diff --git a/websocket-server/src/sessionManager.ts b/websocket-server/src/sessionManager.ts
new file mode 100644
index 0000000..7cf6336
--- /dev/null
+++ b/websocket-server/src/sessionManager.ts
@@ -0,0 +1,286 @@
+import { RawData, WebSocket } from "ws";
+import functions from "./functionHandlers";
+
+const OPENAI_API_KEY = process.env.OPENAI_API_KEY || "";
+
+interface Session {
+ twilioConn?: WebSocket;
+ frontendConn?: WebSocket;
+ modelConn?: WebSocket;
+ streamSid?: string;
+ saved_config?: any;
+ lastAssistantItem?: string;
+ responseStartTimestamp?: number;
+ latestMediaTimestamp?: number;
+}
+
+let session: Session = {};
+
+export function handleCallConnection(ws: WebSocket) {
+ cleanupConnection(session.twilioConn);
+ session.twilioConn = ws;
+
+ ws.on("message", handleTwilioMessage);
+ ws.on("error", ws.close);
+ ws.on("close", () => {
+ cleanupConnection(session.modelConn);
+ cleanupConnection(session.twilioConn);
+ session.twilioConn = undefined;
+ session.modelConn = undefined;
+ session.streamSid = undefined;
+ session.lastAssistantItem = undefined;
+ session.responseStartTimestamp = undefined;
+ session.latestMediaTimestamp = undefined;
+ if (!session.frontendConn) session = {};
+ });
+}
+
+export function handleFrontendConnection(ws: WebSocket) {
+ cleanupConnection(session.frontendConn);
+ session.frontendConn = ws;
+
+ ws.on("message", handleFrontendMessage);
+ ws.on("close", () => {
+ cleanupConnection(session.frontendConn);
+ session.frontendConn = undefined;
+ if (!session.twilioConn && !session.modelConn) session = {};
+ });
+}
+
+async function handleFunctionCall(item: { name: string; arguments: string }) {
+ console.log("Handling function call:", item);
+ const fnDef = functions.find((f) => f.schema.name === item.name);
+ if (!fnDef) {
+ throw new Error(`No handler found for function: ${item.name}`);
+ }
+
+ let args: unknown;
+ try {
+ args = JSON.parse(item.arguments);
+ } catch {
+ return JSON.stringify({
+ error: "Invalid JSON arguments for function call.",
+ });
+ }
+
+ try {
+ console.log("Calling function:", fnDef.schema.name, args);
+ const result = await fnDef.handler(args as any);
+ return result;
+ } catch (err: any) {
+ console.error("Error running function:", err);
+ return JSON.stringify({
+ error: `Error running function ${item.name}: ${err.message}`,
+ });
+ }
+}
+
+function handleTwilioMessage(data: RawData) {
+ const msg = parseMessage(data);
+ if (!msg) return;
+
+ switch (msg.event) {
+ case "start":
+ session.streamSid = msg.start.streamSid;
+ session.latestMediaTimestamp = 0;
+ session.lastAssistantItem = undefined;
+ session.responseStartTimestamp = undefined;
+ tryConnectModel();
+ break;
+ case "media":
+ session.latestMediaTimestamp = msg.media.timestamp;
+ if (isOpen(session.modelConn)) {
+ jsonSend(session.modelConn, {
+ type: "input_audio_buffer.append",
+ audio: msg.media.payload,
+ });
+ }
+ break;
+ case "close":
+ closeAllConnections();
+ break;
+ }
+}
+
+function handleFrontendMessage(data: RawData) {
+ const msg = parseMessage(data);
+ if (!msg) return;
+
+ if (isOpen(session.modelConn)) {
+ jsonSend(session.modelConn, msg);
+ }
+
+ if (msg.type === "session.update") {
+ session.saved_config = msg.session;
+ }
+}
+
+function tryConnectModel() {
+ if (!session.twilioConn || !session.streamSid) return;
+ if (isOpen(session.modelConn)) return;
+
+ session.modelConn = new WebSocket(
+ "wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-12-17",
+ {
+ headers: {
+ Authorization: `Bearer ${OPENAI_API_KEY}`,
+ "OpenAI-Beta": "realtime=v1",
+ },
+ }
+ );
+
+ session.modelConn.on("open", () => {
+ const config = session.saved_config || {};
+ jsonSend(session.modelConn, {
+ type: "session.update",
+ session: {
+ modalities: ["text", "audio"],
+ turn_detection: { type: "server_vad" },
+ voice: "ash",
+ input_audio_transcription: { model: "whisper-1" },
+ input_audio_format: "g711_ulaw",
+ output_audio_format: "g711_ulaw",
+ ...config,
+ },
+ });
+ });
+
+ session.modelConn.on("message", handleModelMessage);
+ session.modelConn.on("error", closeModel);
+ session.modelConn.on("close", closeModel);
+}
+
+function handleModelMessage(data: RawData) {
+ const event = parseMessage(data);
+ if (!event) return;
+
+ jsonSend(session.frontendConn, event);
+
+ switch (event.type) {
+ case "input_audio_buffer.speech_started":
+ handleTruncation();
+ break;
+
+ case "response.audio.delta":
+ if (session.twilioConn && session.streamSid) {
+ if (session.responseStartTimestamp === undefined) {
+ session.responseStartTimestamp = session.latestMediaTimestamp || 0;
+ }
+ if (event.item_id) session.lastAssistantItem = event.item_id;
+
+ jsonSend(session.twilioConn, {
+ event: "media",
+ streamSid: session.streamSid,
+ media: { payload: event.delta },
+ });
+
+ jsonSend(session.twilioConn, {
+ event: "mark",
+ streamSid: session.streamSid,
+ });
+ }
+ break;
+
+ case "response.output_item.done": {
+ const { item } = event;
+ if (item.type === "function_call") {
+ handleFunctionCall(item)
+ .then((output) => {
+ if (session.modelConn) {
+ jsonSend(session.modelConn, {
+ type: "conversation.item.create",
+ item: {
+ type: "function_call_output",
+ call_id: item.call_id,
+ output: JSON.stringify(output),
+ },
+ });
+ jsonSend(session.modelConn, { type: "response.create" });
+ }
+ })
+ .catch((err) => {
+ console.error("Error handling function call:", err);
+ });
+ }
+ break;
+ }
+ }
+}
+
+function handleTruncation() {
+ if (
+ !session.lastAssistantItem ||
+ session.responseStartTimestamp === undefined
+ )
+ return;
+
+ const elapsedMs =
+ (session.latestMediaTimestamp || 0) - (session.responseStartTimestamp || 0);
+ const audio_end_ms = elapsedMs > 0 ? elapsedMs : 0;
+
+ if (isOpen(session.modelConn)) {
+ jsonSend(session.modelConn, {
+ type: "conversation.item.truncate",
+ item_id: session.lastAssistantItem,
+ content_index: 0,
+ audio_end_ms,
+ });
+ }
+
+ if (session.twilioConn && session.streamSid) {
+ jsonSend(session.twilioConn, {
+ event: "clear",
+ streamSid: session.streamSid,
+ });
+ }
+
+ session.lastAssistantItem = undefined;
+ session.responseStartTimestamp = undefined;
+}
+
+function closeModel() {
+ cleanupConnection(session.modelConn);
+ session.modelConn = undefined;
+ if (!session.twilioConn && !session.frontendConn) session = {};
+}
+
+function closeAllConnections() {
+ if (session.twilioConn) {
+ session.twilioConn.close();
+ session.twilioConn = undefined;
+ }
+ if (session.modelConn) {
+ session.modelConn.close();
+ session.modelConn = undefined;
+ }
+ if (session.frontendConn) {
+ session.frontendConn.close();
+ session.frontendConn = undefined;
+ }
+ session.streamSid = undefined;
+ session.lastAssistantItem = undefined;
+ session.responseStartTimestamp = undefined;
+ session.latestMediaTimestamp = undefined;
+ session.saved_config = undefined;
+}
+
+function cleanupConnection(ws?: WebSocket) {
+ if (isOpen(ws)) ws.close();
+}
+
+function parseMessage(data: RawData): any {
+ try {
+ return JSON.parse(data.toString());
+ } catch {
+ return null;
+ }
+}
+
+function jsonSend(ws: WebSocket | undefined, obj: unknown) {
+ if (!isOpen(ws)) return;
+ ws.send(JSON.stringify(obj));
+}
+
+function isOpen(ws?: WebSocket): ws is WebSocket {
+ return !!ws && ws.readyState === WebSocket.OPEN;
+}
diff --git a/websocket-server/src/twiml.xml b/websocket-server/src/twiml.xml
new file mode 100644
index 0000000..2f25108
--- /dev/null
+++ b/websocket-server/src/twiml.xml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<Response>
+ <Say>Connected</Say>
+ <Connect>
+ <Stream url="{{WS_URL}}" />
+ </Connect>
+ <Say>Disconnected</Say>
+</Response>
diff --git a/websocket-server/src/types.ts b/websocket-server/src/types.ts
new file mode 100644
index 0000000..6c544c9
--- /dev/null
+++ b/websocket-server/src/types.ts
@@ -0,0 +1,31 @@
+import { WebSocket } from "ws";
+
+export interface Session {
+ twilioConn?: WebSocket;
+ frontendConn?: WebSocket;
+ modelConn?: WebSocket;
+ config?: any;
+ streamSid?: string;
+}
+
+export interface FunctionCallItem {
+ name: string;
+ arguments: string;
+ call_id?: string;
+}
+
+export interface FunctionSchema {
+ name: string;
+ type: "function";
+ description?: string;
+ parameters: {
+ type: string;
+ properties: Record<string, { type: string; description?: string }>;
+ required: string[];
+ };
+}
+
+export interface FunctionHandler {
+ schema: FunctionSchema;
+ handler: (args: any) => Promise<string>;
+}