story-kit: merge 138_bug_no_heartbeat_to_detect_stale_websocket_connections

This commit is contained in:
Dave
2026-02-24 13:05:30 +00:00
parent 71e07041cf
commit 5226438b16
3 changed files with 163 additions and 4 deletions

View File

@@ -163,7 +163,8 @@ interface MockWsInstance {
onmessage: ((e: { data: string }) => void) | null; onmessage: ((e: { data: string }) => void) | null;
onerror: (() => void) | null; onerror: (() => void) | null;
readyState: number; readyState: number;
send: () => void; sentMessages: string[];
send: (data: string) => void;
close: () => void; close: () => void;
simulateClose: () => void; simulateClose: () => void;
simulateMessage: (data: Record<string, unknown>) => void; simulateMessage: (data: Record<string, unknown>) => void;
@@ -183,12 +184,15 @@ function makeMockWebSocket() {
onmessage: ((e: { data: string }) => void) | null = null; onmessage: ((e: { data: string }) => void) | null = null;
onerror: (() => void) | null = null; onerror: (() => void) | null = null;
readyState = 0; readyState = 0;
sentMessages: string[] = [];
constructor(_url: string) { constructor(_url: string) {
instances.push(this as unknown as MockWsInstance); instances.push(this as unknown as MockWsInstance);
} }
send() {} send(data: string) {
this.sentMessages.push(data);
}
close() { close() {
this.readyState = 3; this.readyState = 3;
@@ -330,3 +334,99 @@ describe("ChatWebSocket", () => {
expect(instances).toHaveLength(4); expect(instances).toHaveLength(4);
}); });
}); });
describe("ChatWebSocket heartbeat", () => {
beforeEach(() => {
vi.useFakeTimers();
const { MockWebSocket } = makeMockWebSocket();
vi.stubGlobal("WebSocket", MockWebSocket);
(ChatWebSocket as unknown as { sharedSocket: null }).sharedSocket = null;
(ChatWebSocket as unknown as { refCount: number }).refCount = 0;
});
afterEach(() => {
vi.useRealTimers();
vi.restoreAllMocks();
});
it("sends ping after heartbeat interval", () => {
const { MockWebSocket, instances } = makeMockWebSocket();
vi.stubGlobal("WebSocket", MockWebSocket);
const ws = new ChatWebSocket();
ws.connect({});
instances[0].readyState = 1; // OPEN
instances[0].onopen?.(); // starts heartbeat
vi.advanceTimersByTime(29_999);
expect(instances[0].sentMessages).toHaveLength(0);
vi.advanceTimersByTime(1);
expect(instances[0].sentMessages).toHaveLength(1);
expect(JSON.parse(instances[0].sentMessages[0])).toEqual({ type: "ping" });
ws.close();
});
it("closes stale connection when pong is not received", () => {
const { MockWebSocket, instances } = makeMockWebSocket();
vi.stubGlobal("WebSocket", MockWebSocket);
const ws = new ChatWebSocket();
ws.connect({});
instances[0].readyState = 1; // OPEN
instances[0].onopen?.(); // starts heartbeat
// Fire heartbeat — sends ping and starts pong timeout
vi.advanceTimersByTime(30_000);
// No pong received; advance past pong timeout → socket closed → reconnect scheduled
vi.advanceTimersByTime(5_000);
// Advance past reconnect delay
vi.advanceTimersByTime(1_001);
expect(instances).toHaveLength(2);
ws.close();
});
it("does not close when pong is received before timeout", () => {
const { MockWebSocket, instances } = makeMockWebSocket();
vi.stubGlobal("WebSocket", MockWebSocket);
const ws = new ChatWebSocket();
ws.connect({});
instances[0].readyState = 1; // OPEN
instances[0].onopen?.(); // starts heartbeat
// Fire heartbeat
vi.advanceTimersByTime(30_000);
// Server responds with pong — clears the pong timeout
instances[0].simulateMessage({ type: "pong" });
// Advance past where pong timeout would have fired
vi.advanceTimersByTime(5_001);
// No reconnect triggered
expect(instances).toHaveLength(1);
ws.close();
});
it("stops sending pings after explicit close", () => {
const { MockWebSocket, instances } = makeMockWebSocket();
vi.stubGlobal("WebSocket", MockWebSocket);
const ws = new ChatWebSocket();
ws.connect({});
instances[0].readyState = 1; // OPEN
instances[0].onopen?.(); // starts heartbeat
ws.close();
// Advance well past multiple heartbeat intervals
vi.advanceTimersByTime(90_000);
expect(instances[0].sentMessages).toHaveLength(0);
});
});

View File

@@ -11,7 +11,8 @@ export type WsRequest =
type: "permission_response"; type: "permission_response";
request_id: string; request_id: string;
approved: boolean; approved: boolean;
}; }
| { type: "ping" };
export interface AgentAssignment { export interface AgentAssignment {
agent_name: string; agent_name: string;
@@ -60,7 +61,9 @@ export type WsResponse =
} }
/** `.story_kit/project.toml` was modified; re-fetch the agent roster. */ /** `.story_kit/project.toml` was modified; re-fetch the agent roster. */
| { type: "agent_config_changed" } | { type: "agent_config_changed" }
| { type: "tool_activity"; tool_name: string }; | { type: "tool_activity"; tool_name: string }
/** Heartbeat response confirming the connection is alive. */
| { type: "pong" };
export interface ProviderConfig { export interface ProviderConfig {
provider: string; provider: string;
@@ -283,6 +286,30 @@ export class ChatWebSocket {
private reconnectTimer?: number; private reconnectTimer?: number;
private reconnectDelay = 1000; private reconnectDelay = 1000;
private shouldReconnect = false; private shouldReconnect = false;
private heartbeatInterval?: number;
private heartbeatTimeout?: number;
private static readonly HEARTBEAT_INTERVAL = 30_000;
private static readonly HEARTBEAT_TIMEOUT = 5_000;
private _startHeartbeat(): void {
this._stopHeartbeat();
this.heartbeatInterval = window.setInterval(() => {
if (!this.socket || this.socket.readyState !== WebSocket.OPEN) return;
const ping: WsRequest = { type: "ping" };
this.socket.send(JSON.stringify(ping));
this.heartbeatTimeout = window.setTimeout(() => {
// No pong received within timeout; close socket to trigger reconnect.
this.socket?.close();
}, ChatWebSocket.HEARTBEAT_TIMEOUT);
}, ChatWebSocket.HEARTBEAT_INTERVAL);
}
private _stopHeartbeat(): void {
window.clearInterval(this.heartbeatInterval);
window.clearTimeout(this.heartbeatTimeout);
this.heartbeatInterval = undefined;
this.heartbeatTimeout = undefined;
}
private _buildWsUrl(): string { private _buildWsUrl(): string {
const protocol = window.location.protocol === "https:" ? "wss" : "ws"; const protocol = window.location.protocol === "https:" ? "wss" : "ws";
@@ -298,6 +325,7 @@ export class ChatWebSocket {
if (!this.socket) return; if (!this.socket) return;
this.socket.onopen = () => { this.socket.onopen = () => {
this.reconnectDelay = 1000; this.reconnectDelay = 1000;
this._startHeartbeat();
}; };
this.socket.onmessage = (event) => { this.socket.onmessage = (event) => {
try { try {
@@ -327,6 +355,10 @@ export class ChatWebSocket {
data.message, data.message,
); );
if (data.type === "agent_config_changed") this.onAgentConfigChanged?.(); if (data.type === "agent_config_changed") this.onAgentConfigChanged?.();
if (data.type === "pong") {
window.clearTimeout(this.heartbeatTimeout);
this.heartbeatTimeout = undefined;
}
} catch (err) { } catch (err) {
this.onError?.(String(err)); this.onError?.(String(err));
} }
@@ -420,6 +452,7 @@ export class ChatWebSocket {
close() { close() {
this.shouldReconnect = false; this.shouldReconnect = false;
this._stopHeartbeat();
window.clearTimeout(this.reconnectTimer); window.clearTimeout(this.reconnectTimer);
this.reconnectTimer = undefined; this.reconnectTimer = undefined;

View File

@@ -29,6 +29,9 @@ enum WsRequest {
request_id: String, request_id: String,
approved: bool, approved: bool,
}, },
/// Heartbeat ping from the client. The server responds with `Pong` so the
/// client can detect stale (half-closed) connections.
Ping,
} }
#[derive(Serialize)] #[derive(Serialize)]
@@ -91,6 +94,9 @@ enum WsResponse {
status: String, status: String,
message: String, message: String,
}, },
/// Heartbeat response to a client `Ping`. Lets the client confirm the
/// connection is alive and cancel any stale-connection timeout.
Pong,
} }
impl From<WatcherEvent> for Option<WsResponse> { impl From<WatcherEvent> for Option<WsResponse> {
@@ -285,6 +291,9 @@ pub async fn ws_handler(ws: WebSocket, ctx: Data<&Arc<AppContext>>) -> impl poem
Ok(WsRequest::Cancel) => { Ok(WsRequest::Cancel) => {
let _ = chat::cancel_chat(&ctx.state); let _ = chat::cancel_chat(&ctx.state);
} }
Ok(WsRequest::Ping) => {
let _ = tx.send(WsResponse::Pong);
}
_ => {} _ => {}
} }
} }
@@ -305,6 +314,9 @@ pub async fn ws_handler(ws: WebSocket, ctx: Data<&Arc<AppContext>>) -> impl poem
Ok(WsRequest::Cancel) => { Ok(WsRequest::Cancel) => {
let _ = chat::cancel_chat(&ctx.state); let _ = chat::cancel_chat(&ctx.state);
} }
Ok(WsRequest::Ping) => {
let _ = tx.send(WsResponse::Pong);
}
Ok(WsRequest::PermissionResponse { .. }) => { Ok(WsRequest::PermissionResponse { .. }) => {
// Permission responses outside an active chat are ignored. // Permission responses outside an active chat are ignored.
} }
@@ -385,6 +397,13 @@ mod tests {
assert!(matches!(req, WsRequest::Cancel)); assert!(matches!(req, WsRequest::Cancel));
} }
#[test]
fn deserialize_ping_request() {
let json = r#"{"type": "ping"}"#;
let req: WsRequest = serde_json::from_str(json).unwrap();
assert!(matches!(req, WsRequest::Ping));
}
#[test] #[test]
fn deserialize_permission_response_approved() { fn deserialize_permission_response_approved() {
let json = r#"{ let json = r#"{
@@ -538,6 +557,13 @@ mod tests {
assert_eq!(json["type"], "agent_config_changed"); assert_eq!(json["type"], "agent_config_changed");
} }
#[test]
fn serialize_pong_response() {
let resp = WsResponse::Pong;
let json = serde_json::to_value(&resp).unwrap();
assert_eq!(json["type"], "pong");
}
#[test] #[test]
fn serialize_permission_request_response() { fn serialize_permission_request_response() {
let resp = WsResponse::PermissionRequest { let resp = WsResponse::PermissionRequest {