use crate::http::context::AppContext; use crate::http::workflow::{PipelineState, load_pipeline_state}; use crate::io::watcher::WatcherEvent; use crate::llm::chat; use crate::llm::types::Message; use futures::{SinkExt, StreamExt}; use poem::handler; use poem::web::Data; use poem::web::websocket::{Message as WsMessage, WebSocket}; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tokio::sync::mpsc; #[derive(Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] /// WebSocket request messages sent by the client. /// /// - `chat` starts a streaming chat session. /// - `cancel` stops the active session. enum WsRequest { Chat { messages: Vec, config: chat::ProviderConfig, }, Cancel, } #[derive(Serialize)] #[serde(tag = "type", rename_all = "snake_case")] /// WebSocket response messages sent by the server. /// /// - `token` streams partial model output. /// - `update` pushes the updated message history. /// - `error` reports a request or processing failure. /// - `work_item_changed` notifies that a `.story_kit/work/` file changed. enum WsResponse { Token { content: String, }, Update { messages: Vec, }, /// Session ID for Claude Code conversation resumption. SessionId { session_id: String, }, Error { message: String, }, /// Filesystem watcher notification: a work-pipeline file was created or /// modified and auto-committed. The frontend can use this to refresh its /// story/bug list without polling. WorkItemChanged { stage: String, item_id: String, action: String, commit_msg: String, }, /// Full pipeline state pushed on connect and after every watcher event. PipelineState { upcoming: Vec, current: Vec, qa: Vec, merge: Vec, }, } impl From for WsResponse { fn from(e: WatcherEvent) -> Self { WsResponse::WorkItemChanged { stage: e.stage, item_id: e.item_id, action: e.action, commit_msg: e.commit_msg, } } } impl From for WsResponse { fn from(s: PipelineState) -> Self { WsResponse::PipelineState { upcoming: s.upcoming, current: s.current, qa: s.qa, merge: s.merge, } } } #[handler] /// WebSocket endpoint for streaming chat responses, cancellation, and /// filesystem watcher notifications. /// /// Accepts JSON `WsRequest` messages and streams `WsResponse` messages. pub async fn ws_handler(ws: WebSocket, ctx: Data<&Arc>) -> impl poem::IntoResponse { let ctx = ctx.0.clone(); ws.on_upgrade(move |socket| async move { let (mut sink, mut stream) = socket.split(); let (tx, mut rx) = mpsc::unbounded_channel::(); let forward = tokio::spawn(async move { while let Some(msg) = rx.recv().await { if let Ok(text) = serde_json::to_string(&msg) && sink.send(WsMessage::Text(text)).await.is_err() { break; } } }); // Push initial pipeline state to the client on connect. if let Ok(state) = load_pipeline_state(ctx.as_ref()) { let _ = tx.send(state.into()); } // Subscribe to filesystem watcher events and forward them to the client. // After each watcher event, also push the updated pipeline state. let tx_watcher = tx.clone(); let ctx_watcher = ctx.clone(); let mut watcher_rx = ctx.watcher_tx.subscribe(); tokio::spawn(async move { loop { match watcher_rx.recv().await { Ok(evt) => { if tx_watcher.send(evt.into()).is_err() { break; } // Push refreshed pipeline state after the change. if let Ok(state) = load_pipeline_state(ctx_watcher.as_ref()) && tx_watcher.send(state.into()).is_err() { break; } } // Lagged: skip missed events, keep going. Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, Err(tokio::sync::broadcast::error::RecvError::Closed) => break, } } }); while let Some(Ok(msg)) = stream.next().await { if let WsMessage::Text(text) = msg { let parsed: Result = serde_json::from_str(&text); match parsed { Ok(WsRequest::Chat { messages, config }) => { let tx_updates = tx.clone(); let tx_tokens = tx.clone(); let ctx_clone = ctx.clone(); let result = chat::chat( messages, config, &ctx_clone.state, ctx_clone.store.as_ref(), |history| { let _ = tx_updates.send(WsResponse::Update { messages: history.to_vec(), }); }, |token| { let _ = tx_tokens.send(WsResponse::Token { content: token.to_string(), }); }, ) .await; match result { Ok(chat_result) => { if let Some(sid) = chat_result.session_id { let _ = tx.send(WsResponse::SessionId { session_id: sid }); } } Err(err) => { let _ = tx.send(WsResponse::Error { message: err }); } } } Ok(WsRequest::Cancel) => { let _ = chat::cancel_chat(&ctx.state); } Err(err) => { let _ = tx.send(WsResponse::Error { message: format!("Invalid request: {err}"), }); } } } } drop(tx); let _ = forward.await; }) }