//! WebSocket transport adapter — accept connection, serialise/deserialise frames, //! invoke service methods. No business logic, no inline state transitions. use crate::config::ProjectConfig; use crate::http::context::AppContext; use crate::llm::chat; use crate::service::ws::{self, WsResponse}; use futures::{SinkExt, StreamExt}; use poem::handler; use poem::web::Data; use poem::web::websocket::{Message as WsMessage, WebSocket}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; use crate::http::context::PermissionDecision; // Re-export WizardStepInfo for any downstream code that imports it from here. #[allow(unused_imports)] pub use crate::service::ws::WizardStepInfo; #[handler] /// WebSocket endpoint for streaming chat responses, cancellation, and /// filesystem watcher notifications. /// /// Accepts JSON `WsRequest` messages and streams `WsResponse` messages. pub async fn ws_handler(ws: WebSocket, ctx: Data<&Arc>) -> impl poem::IntoResponse { let ctx = ctx.0.clone(); ws.on_upgrade(move |socket| async move { let (mut sink, mut stream) = socket.split(); let (tx, mut rx) = mpsc::unbounded_channel::(); let forward = tokio::spawn(async move { while let Some(msg) = rx.recv().await { if let Ok(text) = serde_json::to_string(&msg) && sink.send(WsMessage::Text(text)).await.is_err() { break; } } }); // ── Initial state burst ───────���───────────────────────────── if let Some(state) = ws::load_initial_pipeline_state(ctx.as_ref()) { let _ = tx.send(state); } let _ = tx.send(ws::check_onboarding(ctx.as_ref())); if let Some(wiz) = ws::load_wizard_state(ctx.as_ref()) { let _ = tx.send(wiz); } for log in ws::load_recent_logs(100) { let _ = tx.send(log); } // ── Background subscriptions ──────────────────────��───────── ws::subscribe_logs(tx.clone()); ws::subscribe_watcher(tx.clone(), ctx.clone(), ctx.watcher_tx.subscribe()); ws::subscribe_reconciliation(tx.clone(), ctx.reconciliation_tx.subscribe()); // Subscribe to the status broadcaster if web UI consumer is enabled (default: true). let status_enabled = ctx .state .get_project_root() .ok() .and_then(|root| ProjectConfig::load(&root).ok()) .map(|c| c.web_ui_status_consumer) .unwrap_or(true); if status_enabled { ws::subscribe_status(tx.clone(), ctx.services.status.subscribe()); } // Map of pending permission request_id -> oneshot responder. let mut pending_perms: HashMap> = HashMap::new(); loop { // Outer loop: wait for the next WebSocket message. let Some(Ok(WsMessage::Text(text))) = stream.next().await else { break; }; match ws::dispatch_outer(&text) { ws::DispatchResult::StartChat { messages, config } => { let tx_updates = tx.clone(); let tx_tokens = tx.clone(); let tx_thinking = tx.clone(); let tx_activity = tx.clone(); let ctx_clone = ctx.clone(); let chat_fut = chat::chat( messages, config, &ctx_clone.state, ctx_clone.store.as_ref(), move |history| { let _ = tx_updates.send(WsResponse::Update { messages: history.to_vec(), }); }, move |token| { let _ = tx_tokens.send(WsResponse::Token { content: token.to_string(), }); }, move |thinking: &str| { let _ = tx_thinking.send(WsResponse::ThinkingToken { content: thinking.to_string(), }); }, move |tool_name: &str| { let _ = tx_activity.send(WsResponse::ToolActivity { tool_name: tool_name.to_string(), }); }, ); tokio::pin!(chat_fut); let mut perm_rx = ctx.services.perm_rx.lock().await; let chat_result = loop { tokio::select! { result = &mut chat_fut => break result, Some(perm_fwd) = perm_rx.recv() => { let _ = tx.send(ws::permission_request_response( &perm_fwd.request_id, &perm_fwd.tool_name, &perm_fwd.tool_input, )); pending_perms.insert( perm_fwd.request_id, perm_fwd.response_tx, ); } Some(Ok(WsMessage::Text(inner_text))) = stream.next() => { match ws::dispatch_inner(&inner_text, &mut pending_perms) { ws::InnerDispatchResult::CancelChat => { let _ = chat::cancel_chat(&ctx.state); } ws::InnerDispatchResult::Pong => { let _ = tx.send(WsResponse::Pong); } ws::InnerDispatchResult::StartSideQuestion { question, context_messages, config } => { let tx_side = tx.clone(); let store = ctx.store.clone(); tokio::spawn(async move { let result = chat::side_question( context_messages, question, config, store.as_ref(), |token| { let _ = tx_side.send(WsResponse::SideQuestionToken { content: token.to_string(), }); }, ).await; match result { Ok(response) => { let _ = tx_side.send(WsResponse::SideQuestionDone { response }); } Err(err) => { let _ = tx_side.send(WsResponse::SideQuestionDone { response: format!("Error: {err}"), }); } } }); } ws::InnerDispatchResult::PermissionResolved | ws::InnerDispatchResult::Ignored => {} } } } }; match chat_result { Ok(chat_result) => { if let Some(sid) = chat_result.session_id { let _ = tx.send(WsResponse::SessionId { session_id: sid }); } } Err(err) => { let _ = tx.send(ws::error_response(err)); } } } ws::DispatchResult::CancelChat => { let _ = chat::cancel_chat(&ctx.state); } ws::DispatchResult::Pong => { let _ = tx.send(WsResponse::Pong); } ws::DispatchResult::IgnoredPermission => { // Permission responses outside an active chat are ignored. } ws::DispatchResult::StartSideQuestion { question, context_messages, config, } => { let tx_side = tx.clone(); let store = ctx.store.clone(); tokio::spawn(async move { let result = chat::side_question( context_messages, question, config, store.as_ref(), |token| { let _ = tx_side.send(WsResponse::SideQuestionToken { content: token.to_string(), }); }, ) .await; match result { Ok(response) => { let _ = tx_side.send(WsResponse::SideQuestionDone { response }); } Err(err) => { let _ = tx_side.send(WsResponse::SideQuestionDone { response: format!("Error: {err}"), }); } } }); } ws::DispatchResult::ParseError(msg) => { let _ = tx.send(ws::error_response(msg)); } } } drop(tx); let _ = forward.await; }) } #[cfg(test)] mod tests { use super::*; use crate::io::watcher::WatcherEvent; use crate::service::status::StatusEvent; // ── ws_handler integration tests (real WebSocket connection) ───── use futures::stream::SplitSink; use poem::EndpointExt; use tokio_tungstenite::tungstenite; /// Helper: construct a tungstenite text message from a string. fn ws_text(s: &str) -> tungstenite::Message { tungstenite::Message::Text(s.into()) } /// Helper: start a poem server with ws_handler on an ephemeral port /// and return the WebSocket URL. async fn start_test_server() -> (String, Arc) { let tmp = tempfile::tempdir().unwrap(); let root = tmp.path().to_path_buf(); // Ensure CRDT content store is initialised — load_pipeline_state // now reads from the in-memory CRDT, not the filesystem. crate::db::ensure_content_store(); let ctx = Arc::new(AppContext::new_test(root)); let ctx_data = ctx.clone(); let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); let app = poem::Route::new() .at("/ws", poem::get(ws_handler)) .data(ctx_data); tokio::spawn(async move { let acceptor = poem::listener::TcpAcceptor::from_tokio(listener).unwrap(); let _ = poem::Server::new_with_acceptor(acceptor).run(app).await; }); // Small delay to let the server start. tokio::time::sleep(std::time::Duration::from_millis(50)).await; let url = format!("ws://127.0.0.1:{}/ws", addr.port()); (url, ctx) } type WsSink = SplitSink< tokio_tungstenite::WebSocketStream< tokio_tungstenite::MaybeTlsStream, >, tungstenite::Message, >; /// Helper: connect and return (sink, stream) plus read the initial /// pipeline_state and onboarding_status messages that are always sent /// on connect. async fn connect_ws( url: &str, ) -> ( WsSink, futures::stream::SplitStream< tokio_tungstenite::WebSocketStream< tokio_tungstenite::MaybeTlsStream, >, >, serde_json::Value, ) { let (ws, _resp) = tokio_tungstenite::connect_async(url).await.unwrap(); let (sink, mut stream) = futures::StreamExt::split(ws); // The first message should be the initial pipeline_state. let first = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next()) .await .expect("timeout waiting for initial message") .expect("stream ended") .expect("ws error"); let initial: serde_json::Value = match first { tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(), other => panic!("expected text message, got: {other:?}"), }; // The second message is the onboarding_status — consume it so // callers only see application-level messages. let second = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next()) .await .expect("timeout waiting for onboarding_status") .expect("stream ended") .expect("ws error"); let onboarding: serde_json::Value = match second { tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(), other => panic!("expected text message, got: {other:?}"), }; assert_eq!( onboarding["type"], "onboarding_status", "expected onboarding_status, got: {onboarding}" ); // Drain any log_entry messages sent as initial history on connect. // These are buffered before tests send their own requests. loop { // Use a very short timeout: if nothing arrives quickly, the burst is done. let Ok(Some(Ok(msg))) = tokio::time::timeout(std::time::Duration::from_millis(200), stream.next()).await else { break; }; let val: serde_json::Value = match msg { tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(), _ => break, }; if val["type"] != "log_entry" { // Unexpected non-log message during drain — this shouldn't happen. panic!("unexpected message during log drain: {val}"); } } (sink, stream, initial) } /// Read next non-log_entry text message from the stream with a timeout. /// Skips any `log_entry` messages that arrive between events. async fn next_msg( stream: &mut futures::stream::SplitStream< tokio_tungstenite::WebSocketStream< tokio_tungstenite::MaybeTlsStream, >, >, ) -> serde_json::Value { loop { let msg = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next()) .await .expect("timeout waiting for message") .expect("stream ended") .expect("ws error"); let val: serde_json::Value = match msg { tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(), other => panic!("expected text message, got: {other:?}"), }; if val["type"] != "log_entry" { return val; } } } #[tokio::test] async fn ws_handler_sends_initial_pipeline_state_on_connect() { let (url, _ctx) = start_test_server().await; let (_sink, _stream, initial) = connect_ws(&url).await; assert_eq!(initial["type"], "pipeline_state"); // Verify stage arrays are present (may contain items from the // shared global CRDT store populated by other tests). assert!(initial["backlog"].as_array().is_some()); assert!(initial["current"].as_array().is_some()); assert!(initial["qa"].as_array().is_some()); assert!(initial["merge"].as_array().is_some()); } #[tokio::test] async fn ws_handler_returns_error_for_invalid_json() { let (url, _ctx) = start_test_server().await; let (mut sink, mut stream, _initial) = connect_ws(&url).await; // Send invalid JSON. sink.send(ws_text("not valid json")).await.unwrap(); let msg = next_msg(&mut stream).await; assert_eq!(msg["type"], "error"); assert!( msg["message"].as_str().unwrap().contains("Invalid request"), "error message should indicate invalid request, got: {}", msg["message"] ); } #[tokio::test] async fn ws_handler_returns_error_for_unknown_type() { let (url, _ctx) = start_test_server().await; let (mut sink, mut stream, _initial) = connect_ws(&url).await; // Send a message with an unknown type. sink.send(ws_text(r#"{"type": "bogus"}"#)).await.unwrap(); let msg = next_msg(&mut stream).await; assert_eq!(msg["type"], "error"); assert!(msg["message"].as_str().unwrap().contains("Invalid request")); } #[tokio::test] async fn ws_handler_cancel_outside_chat_does_not_error() { let (url, _ctx) = start_test_server().await; let (mut sink, mut stream, _initial) = connect_ws(&url).await; // Send cancel when no chat is active — should not produce an error. sink.send(ws_text(r#"{"type": "cancel"}"#)).await.unwrap(); // Send another invalid message to check the connection is still alive. sink.send(ws_text("{}")).await.unwrap(); let msg = next_msg(&mut stream).await; // The invalid JSON message should produce an error, confirming // the cancel didn't break the connection. assert_eq!(msg["type"], "error"); } #[tokio::test] async fn ws_handler_permission_response_outside_chat_is_ignored() { let (url, _ctx) = start_test_server().await; let (mut sink, mut stream, _initial) = connect_ws(&url).await; // Send permission response outside an active chat. sink.send(ws_text( r#"{"type": "permission_response", "request_id": "x", "approved": true}"#, )) .await .unwrap(); // Send a probe message to check the connection is still alive. sink.send(ws_text("bad")).await.unwrap(); let msg = next_msg(&mut stream).await; assert_eq!(msg["type"], "error"); assert!(msg["message"].as_str().unwrap().contains("Invalid request")); } #[tokio::test] async fn ws_handler_forwards_watcher_events() { let (url, ctx) = start_test_server().await; let (_sink, mut stream, _initial) = connect_ws(&url).await; // Broadcast a watcher event. ctx.watcher_tx .send(WatcherEvent::WorkItem { stage: "2_current".to_string(), item_id: "99_story_test".to_string(), action: "start".to_string(), commit_msg: "huskies: start 99_story_test".to_string(), from_stage: None, }) .unwrap(); let msg = next_msg(&mut stream).await; assert_eq!(msg["type"], "work_item_changed"); assert_eq!(msg["item_id"], "99_story_test"); assert_eq!(msg["stage"], "2_current"); // After a work-item event, a pipeline_state refresh is pushed. let state_msg = next_msg(&mut stream).await; assert_eq!(state_msg["type"], "pipeline_state"); } #[tokio::test] async fn ws_handler_forwards_config_changed_without_pipeline_refresh() { let (url, ctx) = start_test_server().await; let (_sink, mut stream, _initial) = connect_ws(&url).await; // Broadcast a config-changed event. ctx.watcher_tx.send(WatcherEvent::ConfigChanged).unwrap(); let msg = next_msg(&mut stream).await; assert_eq!(msg["type"], "agent_config_changed"); // Config-changed should NOT be followed by a pipeline_state refresh. // Send a probe to check no extra message is queued. ctx.watcher_tx.send(WatcherEvent::ConfigChanged).unwrap(); let msg2 = next_msg(&mut stream).await; assert_eq!(msg2["type"], "agent_config_changed"); } #[tokio::test] async fn ws_handler_forwards_reconciliation_events() { let (url, ctx) = start_test_server().await; let (_sink, mut stream, _initial) = connect_ws(&url).await; // Broadcast a reconciliation event. ctx.reconciliation_tx .send(crate::agents::ReconciliationEvent { story_id: "50_story_recon".to_string(), status: "checking".to_string(), message: "Checking story...".to_string(), }) .unwrap(); let msg = next_msg(&mut stream).await; assert_eq!(msg["type"], "reconciliation_progress"); assert_eq!(msg["story_id"], "50_story_recon"); assert_eq!(msg["status"], "checking"); assert_eq!(msg["message"], "Checking story..."); } #[tokio::test] async fn ws_handler_handles_client_disconnect_gracefully() { let (url, _ctx) = start_test_server().await; let (mut sink, _stream, _initial) = connect_ws(&url).await; // Close the connection — should not panic the server. sink.close().await.unwrap(); // Give the server a moment to process the close. tokio::time::sleep(std::time::Duration::from_millis(100)).await; // Connect again to verify server is still alive. let (_sink2, _stream2, initial2) = connect_ws(&url).await; assert_eq!(initial2["type"], "pipeline_state"); } /// Read the next `status_update` whose story_id or story_name contains `needle`, /// within a timeout. Skips `log_entry` noise and unrelated status events so /// genuine server log noise cannot cause false positives or negatives. async fn next_status_update_containing( stream: &mut futures::stream::SplitStream< tokio_tungstenite::WebSocketStream< tokio_tungstenite::MaybeTlsStream, >, >, needle: &str, timeout_ms: u64, ) -> Option { let deadline = std::time::Instant::now() + std::time::Duration::from_millis(timeout_ms); loop { let remaining = deadline.saturating_duration_since(std::time::Instant::now()); if remaining.is_zero() { return None; } let msg = tokio::time::timeout(remaining, stream.next()) .await .ok()? .expect("stream ended") .expect("ws error"); let val: serde_json::Value = match msg { tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).ok()?, _ => continue, }; if val["type"] == "status_update" { let event = &val["event"]; let story_id = event["story_id"].as_str().unwrap_or(""); let story_name = event["story_name"].as_str().unwrap_or(""); if story_id.contains(needle) || story_name.contains(needle) { return Some(val); } } // Skip log_entry and other unrelated messages. } } // ── Status broadcaster integration tests ───────────────────────── /// Publishing a status event via `services.status` must result in a /// `status_update` WebSocket message with structured fields delivered to the /// connected client. #[tokio::test] async fn ws_handler_forwards_status_events_as_status_update() { let (url, ctx) = start_test_server().await; let (_sink, mut stream, _initial) = connect_ws(&url).await; // Use a story ID unique enough that genuine server logs won't match it. ctx.services.status.publish(StatusEvent::StageTransition { story_id: "77_story_status_fwd_test".to_string(), story_name: Some("StatusFwdTest".to_string()), from_stage: "1_backlog".to_string(), to_stage: "2_current".to_string(), }); // The handler must forward it as a status_update with structured fields. let msg = next_status_update_containing(&mut stream, "StatusFwdTest", 2000) .await .expect("expected a status_update for the status event"); assert_eq!(msg["type"], "status_update"); let event = &msg["event"]; assert_eq!(event["type"], "stage_transition"); assert_eq!(event["story_id"], "77_story_status_fwd_test"); assert_eq!(event["story_name"], "StatusFwdTest"); assert_eq!(event["from_stage"], "1_backlog"); assert_eq!(event["to_stage"], "2_current"); } /// Multi-project isolation: a client connected to project A's server must /// NOT receive status events published on project B's broadcaster. #[tokio::test] async fn ws_handler_multi_project_status_isolation() { // Start two independent servers (each with its own AppContext / Services). let (url_a, ctx_a) = start_test_server().await; let (url_b, _ctx_b) = start_test_server().await; let (_sink_a, mut stream_a, _) = connect_ws(&url_a).await; let (_sink_b, mut stream_b, _) = connect_ws(&url_b).await; // Use a needle unique enough that genuine server logs won't match. let needle = "ProjAIsolation7734"; ctx_a.services.status.publish(StatusEvent::MergeFailure { story_id: "10_story_proj_a_isolation".to_string(), story_name: Some(needle.to_string()), reason: "conflict".to_string(), }); // Client A must receive the status_update with structured fields. let msg_a = next_status_update_containing(&mut stream_a, needle, 2000) .await .expect("client A should receive the status event"); assert_eq!(msg_a["type"], "status_update"); assert_eq!(msg_a["event"]["story_name"], needle); // Client B must NOT receive any status_update containing the needle. let msg_b = next_status_update_containing(&mut stream_b, needle, 300).await; assert!( msg_b.is_none(), "client B must not receive project A's status event, got: {msg_b:?}" ); } /// When `web_ui_status_consumer = false` in project.toml, the WebSocket /// handler must not forward status events to the connected client. #[tokio::test] async fn ws_handler_status_consumer_disabled_via_config() { let tmp = tempfile::tempdir().unwrap(); let root = tmp.path().to_path_buf(); // Write a project.toml that disables the web UI status consumer. let huskies_dir = root.join(".huskies"); std::fs::create_dir_all(&huskies_dir).unwrap(); std::fs::write( huskies_dir.join("project.toml"), "web_ui_status_consumer = false\n", ) .unwrap(); crate::db::ensure_content_store(); let ctx = Arc::new(AppContext::new_test(root)); let ctx_data = ctx.clone(); let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); let app = poem::Route::new() .at("/ws", poem::get(ws_handler)) .data(ctx_data); tokio::spawn(async move { let acceptor = poem::listener::TcpAcceptor::from_tokio(listener).unwrap(); let _ = poem::Server::new_with_acceptor(acceptor).run(app).await; }); tokio::time::sleep(std::time::Duration::from_millis(50)).await; let url = format!("ws://127.0.0.1:{}/ws", addr.port()); let (_sink, mut stream, _) = connect_ws(&url).await; // Use a unique needle — genuine server logs will never contain this. let needle = "DisabledConsumer9182"; ctx.services.status.publish(StatusEvent::StoryBlocked { story_id: "55_story_disabled_consumer".to_string(), story_name: Some(needle.to_string()), reason: "test".to_string(), }); // Consumer is disabled — no status_update with this needle should arrive. let msg = next_status_update_containing(&mut stream, needle, 500).await; assert!( msg.is_none(), "disabled consumer must not forward status events, got: {msg:?}" ); } }