From 5226438b16022d5f91890b385150804573af6c81 Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 24 Feb 2026 13:05:30 +0000 Subject: [PATCH] story-kit: merge 138_bug_no_heartbeat_to_detect_stale_websocket_connections --- frontend/src/api/client.test.ts | 104 +++++++++++++++++++++++++++++++- frontend/src/api/client.ts | 37 +++++++++++- server/src/http/ws.rs | 26 ++++++++ 3 files changed, 163 insertions(+), 4 deletions(-) diff --git a/frontend/src/api/client.test.ts b/frontend/src/api/client.test.ts index 6ab4eee..a0bd504 100644 --- a/frontend/src/api/client.test.ts +++ b/frontend/src/api/client.test.ts @@ -163,7 +163,8 @@ interface MockWsInstance { onmessage: ((e: { data: string }) => void) | null; onerror: (() => void) | null; readyState: number; - send: () => void; + sentMessages: string[]; + send: (data: string) => void; close: () => void; simulateClose: () => void; simulateMessage: (data: Record) => void; @@ -183,12 +184,15 @@ function makeMockWebSocket() { onmessage: ((e: { data: string }) => void) | null = null; onerror: (() => void) | null = null; readyState = 0; + sentMessages: string[] = []; constructor(_url: string) { instances.push(this as unknown as MockWsInstance); } - send() {} + send(data: string) { + this.sentMessages.push(data); + } close() { this.readyState = 3; @@ -330,3 +334,99 @@ describe("ChatWebSocket", () => { expect(instances).toHaveLength(4); }); }); + +describe("ChatWebSocket heartbeat", () => { + beforeEach(() => { + vi.useFakeTimers(); + const { MockWebSocket } = makeMockWebSocket(); + vi.stubGlobal("WebSocket", MockWebSocket); + (ChatWebSocket as unknown as { sharedSocket: null }).sharedSocket = null; + (ChatWebSocket as unknown as { refCount: number }).refCount = 0; + }); + + afterEach(() => { + vi.useRealTimers(); + vi.restoreAllMocks(); + }); + + it("sends ping after heartbeat interval", () => { + const { MockWebSocket, instances } = makeMockWebSocket(); + vi.stubGlobal("WebSocket", MockWebSocket); + + const ws = new ChatWebSocket(); + ws.connect({}); + instances[0].readyState = 1; // OPEN + instances[0].onopen?.(); // starts heartbeat + + vi.advanceTimersByTime(29_999); + expect(instances[0].sentMessages).toHaveLength(0); + + vi.advanceTimersByTime(1); + expect(instances[0].sentMessages).toHaveLength(1); + expect(JSON.parse(instances[0].sentMessages[0])).toEqual({ type: "ping" }); + + ws.close(); + }); + + it("closes stale connection when pong is not received", () => { + const { MockWebSocket, instances } = makeMockWebSocket(); + vi.stubGlobal("WebSocket", MockWebSocket); + + const ws = new ChatWebSocket(); + ws.connect({}); + instances[0].readyState = 1; // OPEN + instances[0].onopen?.(); // starts heartbeat + + // Fire heartbeat — sends ping and starts pong timeout + vi.advanceTimersByTime(30_000); + + // No pong received; advance past pong timeout → socket closed → reconnect scheduled + vi.advanceTimersByTime(5_000); + + // Advance past reconnect delay + vi.advanceTimersByTime(1_001); + + expect(instances).toHaveLength(2); + ws.close(); + }); + + it("does not close when pong is received before timeout", () => { + const { MockWebSocket, instances } = makeMockWebSocket(); + vi.stubGlobal("WebSocket", MockWebSocket); + + const ws = new ChatWebSocket(); + ws.connect({}); + instances[0].readyState = 1; // OPEN + instances[0].onopen?.(); // starts heartbeat + + // Fire heartbeat + vi.advanceTimersByTime(30_000); + + // Server responds with pong — clears the pong timeout + instances[0].simulateMessage({ type: "pong" }); + + // Advance past where pong timeout would have fired + vi.advanceTimersByTime(5_001); + + // No reconnect triggered + expect(instances).toHaveLength(1); + ws.close(); + }); + + it("stops sending pings after explicit close", () => { + const { MockWebSocket, instances } = makeMockWebSocket(); + vi.stubGlobal("WebSocket", MockWebSocket); + + const ws = new ChatWebSocket(); + ws.connect({}); + instances[0].readyState = 1; // OPEN + instances[0].onopen?.(); // starts heartbeat + + ws.close(); + + // Advance well past multiple heartbeat intervals + vi.advanceTimersByTime(90_000); + + expect(instances[0].sentMessages).toHaveLength(0); + }); +}); diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 7fffe9f..65805d2 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -11,7 +11,8 @@ export type WsRequest = type: "permission_response"; request_id: string; approved: boolean; - }; + } + | { type: "ping" }; export interface AgentAssignment { agent_name: string; @@ -60,7 +61,9 @@ export type WsResponse = } /** `.story_kit/project.toml` was modified; re-fetch the agent roster. */ | { type: "agent_config_changed" } - | { type: "tool_activity"; tool_name: string }; + | { type: "tool_activity"; tool_name: string } + /** Heartbeat response confirming the connection is alive. */ + | { type: "pong" }; export interface ProviderConfig { provider: string; @@ -283,6 +286,30 @@ export class ChatWebSocket { private reconnectTimer?: number; private reconnectDelay = 1000; private shouldReconnect = false; + private heartbeatInterval?: number; + private heartbeatTimeout?: number; + private static readonly HEARTBEAT_INTERVAL = 30_000; + private static readonly HEARTBEAT_TIMEOUT = 5_000; + + private _startHeartbeat(): void { + this._stopHeartbeat(); + this.heartbeatInterval = window.setInterval(() => { + if (!this.socket || this.socket.readyState !== WebSocket.OPEN) return; + const ping: WsRequest = { type: "ping" }; + this.socket.send(JSON.stringify(ping)); + this.heartbeatTimeout = window.setTimeout(() => { + // No pong received within timeout; close socket to trigger reconnect. + this.socket?.close(); + }, ChatWebSocket.HEARTBEAT_TIMEOUT); + }, ChatWebSocket.HEARTBEAT_INTERVAL); + } + + private _stopHeartbeat(): void { + window.clearInterval(this.heartbeatInterval); + window.clearTimeout(this.heartbeatTimeout); + this.heartbeatInterval = undefined; + this.heartbeatTimeout = undefined; + } private _buildWsUrl(): string { const protocol = window.location.protocol === "https:" ? "wss" : "ws"; @@ -298,6 +325,7 @@ export class ChatWebSocket { if (!this.socket) return; this.socket.onopen = () => { this.reconnectDelay = 1000; + this._startHeartbeat(); }; this.socket.onmessage = (event) => { try { @@ -327,6 +355,10 @@ export class ChatWebSocket { data.message, ); if (data.type === "agent_config_changed") this.onAgentConfigChanged?.(); + if (data.type === "pong") { + window.clearTimeout(this.heartbeatTimeout); + this.heartbeatTimeout = undefined; + } } catch (err) { this.onError?.(String(err)); } @@ -420,6 +452,7 @@ export class ChatWebSocket { close() { this.shouldReconnect = false; + this._stopHeartbeat(); window.clearTimeout(this.reconnectTimer); this.reconnectTimer = undefined; diff --git a/server/src/http/ws.rs b/server/src/http/ws.rs index 124a2c3..eaad195 100644 --- a/server/src/http/ws.rs +++ b/server/src/http/ws.rs @@ -29,6 +29,9 @@ enum WsRequest { request_id: String, approved: bool, }, + /// Heartbeat ping from the client. The server responds with `Pong` so the + /// client can detect stale (half-closed) connections. + Ping, } #[derive(Serialize)] @@ -91,6 +94,9 @@ enum WsResponse { status: String, message: String, }, + /// Heartbeat response to a client `Ping`. Lets the client confirm the + /// connection is alive and cancel any stale-connection timeout. + Pong, } impl From for Option { @@ -285,6 +291,9 @@ pub async fn ws_handler(ws: WebSocket, ctx: Data<&Arc>) -> impl poem Ok(WsRequest::Cancel) => { let _ = chat::cancel_chat(&ctx.state); } + Ok(WsRequest::Ping) => { + let _ = tx.send(WsResponse::Pong); + } _ => {} } } @@ -305,6 +314,9 @@ pub async fn ws_handler(ws: WebSocket, ctx: Data<&Arc>) -> impl poem Ok(WsRequest::Cancel) => { let _ = chat::cancel_chat(&ctx.state); } + Ok(WsRequest::Ping) => { + let _ = tx.send(WsResponse::Pong); + } Ok(WsRequest::PermissionResponse { .. }) => { // Permission responses outside an active chat are ignored. } @@ -385,6 +397,13 @@ mod tests { assert!(matches!(req, WsRequest::Cancel)); } + #[test] + fn deserialize_ping_request() { + let json = r#"{"type": "ping"}"#; + let req: WsRequest = serde_json::from_str(json).unwrap(); + assert!(matches!(req, WsRequest::Ping)); + } + #[test] fn deserialize_permission_response_approved() { let json = r#"{ @@ -538,6 +557,13 @@ mod tests { assert_eq!(json["type"], "agent_config_changed"); } + #[test] + fn serialize_pong_response() { + let resp = WsResponse::Pong; + let json = serde_json::to_value(&resp).unwrap(); + assert_eq!(json["type"], "pong"); + } + #[test] fn serialize_permission_request_response() { let resp = WsResponse::PermissionRequest {