diff --git a/server/src/http/ws.rs b/server/src/http/ws.rs deleted file mode 100644 index d0e6f83d..00000000 --- a/server/src/http/ws.rs +++ /dev/null @@ -1,737 +0,0 @@ -//! WebSocket transport adapter — accept connection, serialise/deserialise frames, -//! invoke service methods. No business logic, no inline state transitions. - -use crate::config::ProjectConfig; -use crate::http::context::AppContext; -use crate::llm::chat; -use crate::service::ws::{self, WsResponse}; -use futures::{SinkExt, StreamExt}; -use poem::handler; -use poem::web::Data; -use poem::web::websocket::{Message as WsMessage, WebSocket}; -use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::{mpsc, oneshot}; - -use crate::http::context::PermissionDecision; - -// Re-export WizardStepInfo for any downstream code that imports it from here. -#[allow(unused_imports)] -pub use crate::service::ws::WizardStepInfo; - -#[handler] -/// WebSocket endpoint for streaming chat responses, cancellation, and -/// filesystem watcher notifications. -/// -/// Accepts JSON `WsRequest` messages and streams `WsResponse` messages. -pub async fn ws_handler(ws: WebSocket, ctx: Data<&Arc>) -> impl poem::IntoResponse { - let ctx = ctx.0.clone(); - ws.on_upgrade(move |socket| async move { - let (mut sink, mut stream) = socket.split(); - let (tx, mut rx) = mpsc::unbounded_channel::(); - // Separate channel for pre-serialized messages (e.g. RPC responses). - let (raw_tx, mut raw_rx) = mpsc::unbounded_channel::(); - - let forward = tokio::spawn(async move { - loop { - tokio::select! { - msg = rx.recv() => match msg { - Some(msg) => { - if let Ok(text) = serde_json::to_string(&msg) - && sink.send(WsMessage::Text(text)).await.is_err() - { - break; - } - } - None => break, - }, - raw = raw_rx.recv() => match raw { - Some(text) => { - if sink.send(WsMessage::Text(text)).await.is_err() { - break; - } - } - None => break, - }, - } - } - }); - - // ── Initial state burst ───────���───────────────────────────── - if let Some(state) = ws::load_initial_pipeline_state(ctx.as_ref()) { - let _ = tx.send(state); - } - let _ = tx.send(ws::check_onboarding(ctx.as_ref())); - if let Some(wiz) = ws::load_wizard_state(ctx.as_ref()) { - let _ = tx.send(wiz); - } - for log in ws::load_recent_logs(100) { - let _ = tx.send(log); - } - - // ── Background subscriptions ──────────────────────��───────── - ws::subscribe_logs(tx.clone()); - ws::subscribe_watcher(tx.clone(), ctx.clone(), ctx.watcher_tx.subscribe()); - ws::subscribe_reconciliation(tx.clone(), ctx.reconciliation_tx.subscribe()); - - // Subscribe to the status broadcaster if web UI consumer is enabled (default: true). - let status_enabled = ctx - .state - .get_project_root() - .ok() - .and_then(|root| ProjectConfig::load(&root).ok()) - .map(|c| c.web_ui_status_consumer) - .unwrap_or(true); - if status_enabled { - ws::subscribe_status(tx.clone(), ctx.services.status.subscribe()); - } - - // Map of pending permission request_id -> oneshot responder. - let mut pending_perms: HashMap> = - HashMap::new(); - - loop { - // Outer loop: wait for the next WebSocket message. - let Some(Ok(WsMessage::Text(text))) = stream.next().await else { - break; - }; - - // Handle read-RPC frames (discriminated by "kind", not "type"). - if let Some(rpc_resp) = crate::crdt_sync::try_handle_rpc_text(&text) { - if let Ok(resp_text) = serde_json::to_string(&rpc_resp) { - let _ = raw_tx.send(resp_text); - } - continue; - } - - match ws::dispatch_outer(&text) { - ws::DispatchResult::StartChat { messages, config } => { - let tx_updates = tx.clone(); - let tx_tokens = tx.clone(); - let tx_thinking = tx.clone(); - let tx_activity = tx.clone(); - let ctx_clone = ctx.clone(); - - let chat_fut = chat::chat( - messages, - config, - &ctx_clone.state, - ctx_clone.store.as_ref(), - move |history| { - let _ = tx_updates.send(WsResponse::Update { - messages: history.to_vec(), - }); - }, - move |token| { - let _ = tx_tokens.send(WsResponse::Token { - content: token.to_string(), - }); - }, - move |thinking: &str| { - let _ = tx_thinking.send(WsResponse::ThinkingToken { - content: thinking.to_string(), - }); - }, - move |tool_name: &str| { - let _ = tx_activity.send(WsResponse::ToolActivity { - tool_name: tool_name.to_string(), - }); - }, - ); - tokio::pin!(chat_fut); - - let mut perm_rx = ctx.services.perm_rx.lock().await; - - let chat_result = loop { - tokio::select! { - result = &mut chat_fut => break result, - - Some(perm_fwd) = perm_rx.recv() => { - let _ = tx.send(ws::permission_request_response( - &perm_fwd.request_id, - &perm_fwd.tool_name, - &perm_fwd.tool_input, - )); - pending_perms.insert( - perm_fwd.request_id, - perm_fwd.response_tx, - ); - } - - Some(Ok(WsMessage::Text(inner_text))) = stream.next() => { - // Handle read-RPC frames during active chat. - if let Some(rpc_resp) = crate::crdt_sync::try_handle_rpc_text(&inner_text) { - if let Ok(resp_text) = serde_json::to_string(&rpc_resp) { - let _ = raw_tx.send(resp_text); - } - continue; - } - match ws::dispatch_inner(&inner_text, &mut pending_perms) { - ws::InnerDispatchResult::CancelChat => { - let _ = chat::cancel_chat(&ctx.state); - } - ws::InnerDispatchResult::Pong => { - let _ = tx.send(WsResponse::Pong); - } - ws::InnerDispatchResult::StartSideQuestion { question, context_messages, config } => { - let tx_side = tx.clone(); - let store = ctx.store.clone(); - tokio::spawn(async move { - let result = chat::side_question( - context_messages, - question, - config, - store.as_ref(), - |token| { - let _ = tx_side.send(WsResponse::SideQuestionToken { - content: token.to_string(), - }); - }, - ).await; - match result { - Ok(response) => { - let _ = tx_side.send(WsResponse::SideQuestionDone { response }); - } - Err(err) => { - let _ = tx_side.send(WsResponse::SideQuestionDone { - response: format!("Error: {err}"), - }); - } - } - }); - } - ws::InnerDispatchResult::PermissionResolved - | ws::InnerDispatchResult::Ignored => {} - } - } - } - }; - - match chat_result { - Ok(chat_result) => { - if let Some(sid) = chat_result.session_id { - let _ = tx.send(WsResponse::SessionId { session_id: sid }); - } - } - Err(err) => { - let _ = tx.send(ws::error_response(err)); - } - } - } - ws::DispatchResult::CancelChat => { - let _ = chat::cancel_chat(&ctx.state); - } - ws::DispatchResult::Pong => { - let _ = tx.send(WsResponse::Pong); - } - ws::DispatchResult::IgnoredPermission => { - // Permission responses outside an active chat are ignored. - } - ws::DispatchResult::StartSideQuestion { - question, - context_messages, - config, - } => { - let tx_side = tx.clone(); - let store = ctx.store.clone(); - tokio::spawn(async move { - let result = chat::side_question( - context_messages, - question, - config, - store.as_ref(), - |token| { - let _ = tx_side.send(WsResponse::SideQuestionToken { - content: token.to_string(), - }); - }, - ) - .await; - match result { - Ok(response) => { - let _ = - tx_side.send(WsResponse::SideQuestionDone { response }); - } - Err(err) => { - let _ = tx_side.send(WsResponse::SideQuestionDone { - response: format!("Error: {err}"), - }); - } - } - }); - } - ws::DispatchResult::ParseError(msg) => { - let _ = tx.send(ws::error_response(msg)); - } - } - } - - drop(tx); - let _ = forward.await; - }) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::io::watcher::WatcherEvent; - use crate::service::status::StatusEvent; - - // ── ws_handler integration tests (real WebSocket connection) ───── - - use futures::stream::SplitSink; - use poem::EndpointExt; - use tokio_tungstenite::tungstenite; - - /// Helper: construct a tungstenite text message from a string. - fn ws_text(s: &str) -> tungstenite::Message { - tungstenite::Message::Text(s.into()) - } - - /// Helper: start a poem server with ws_handler on an ephemeral port - /// and return the WebSocket URL. - async fn start_test_server() -> (String, Arc) { - let tmp = tempfile::tempdir().unwrap(); - let root = tmp.path().to_path_buf(); - - // Ensure CRDT content store is initialised — load_pipeline_state - // now reads from the in-memory CRDT, not the filesystem. - crate::db::ensure_content_store(); - - let ctx = Arc::new(AppContext::new_test(root)); - let ctx_data = ctx.clone(); - - let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - - let app = poem::Route::new() - .at("/ws", poem::get(ws_handler)) - .data(ctx_data); - - tokio::spawn(async move { - let acceptor = poem::listener::TcpAcceptor::from_tokio(listener).unwrap(); - let _ = poem::Server::new_with_acceptor(acceptor).run(app).await; - }); - - // Small delay to let the server start. - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - - let url = format!("ws://127.0.0.1:{}/ws", addr.port()); - (url, ctx) - } - - type WsSink = SplitSink< - tokio_tungstenite::WebSocketStream< - tokio_tungstenite::MaybeTlsStream, - >, - tungstenite::Message, - >; - - /// Helper: connect and return (sink, stream) plus read the initial - /// pipeline_state and onboarding_status messages that are always sent - /// on connect. - async fn connect_ws( - url: &str, - ) -> ( - WsSink, - futures::stream::SplitStream< - tokio_tungstenite::WebSocketStream< - tokio_tungstenite::MaybeTlsStream, - >, - >, - serde_json::Value, - ) { - let (ws, _resp) = tokio_tungstenite::connect_async(url).await.unwrap(); - let (sink, mut stream) = futures::StreamExt::split(ws); - - // The first message should be the initial pipeline_state. - let first = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next()) - .await - .expect("timeout waiting for initial message") - .expect("stream ended") - .expect("ws error"); - - let initial: serde_json::Value = match first { - tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(), - other => panic!("expected text message, got: {other:?}"), - }; - - // The second message is the onboarding_status — consume it so - // callers only see application-level messages. - let second = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next()) - .await - .expect("timeout waiting for onboarding_status") - .expect("stream ended") - .expect("ws error"); - let onboarding: serde_json::Value = match second { - tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(), - other => panic!("expected text message, got: {other:?}"), - }; - assert_eq!( - onboarding["type"], "onboarding_status", - "expected onboarding_status, got: {onboarding}" - ); - - // Drain any log_entry messages sent as initial history on connect. - // These are buffered before tests send their own requests. - loop { - // Use a very short timeout: if nothing arrives quickly, the burst is done. - let Ok(Some(Ok(msg))) = - tokio::time::timeout(std::time::Duration::from_millis(200), stream.next()).await - else { - break; - }; - let val: serde_json::Value = match msg { - tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(), - _ => break, - }; - if val["type"] != "log_entry" { - // Unexpected non-log message during drain — this shouldn't happen. - panic!("unexpected message during log drain: {val}"); - } - } - - (sink, stream, initial) - } - - /// Read next non-log_entry text message from the stream with a timeout. - /// Skips any `log_entry` messages that arrive between events. - async fn next_msg( - stream: &mut futures::stream::SplitStream< - tokio_tungstenite::WebSocketStream< - tokio_tungstenite::MaybeTlsStream, - >, - >, - ) -> serde_json::Value { - loop { - let msg = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next()) - .await - .expect("timeout waiting for message") - .expect("stream ended") - .expect("ws error"); - let val: serde_json::Value = match msg { - tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(), - other => panic!("expected text message, got: {other:?}"), - }; - if val["type"] != "log_entry" { - return val; - } - } - } - - #[tokio::test] - async fn ws_handler_sends_initial_pipeline_state_on_connect() { - let (url, _ctx) = start_test_server().await; - let (_sink, _stream, initial) = connect_ws(&url).await; - - assert_eq!(initial["type"], "pipeline_state"); - // Verify stage arrays are present (may contain items from the - // shared global CRDT store populated by other tests). - assert!(initial["backlog"].as_array().is_some()); - assert!(initial["current"].as_array().is_some()); - assert!(initial["qa"].as_array().is_some()); - assert!(initial["merge"].as_array().is_some()); - } - - #[tokio::test] - async fn ws_handler_returns_error_for_invalid_json() { - let (url, _ctx) = start_test_server().await; - let (mut sink, mut stream, _initial) = connect_ws(&url).await; - - // Send invalid JSON. - sink.send(ws_text("not valid json")).await.unwrap(); - - let msg = next_msg(&mut stream).await; - assert_eq!(msg["type"], "error"); - assert!( - msg["message"].as_str().unwrap().contains("Invalid request"), - "error message should indicate invalid request, got: {}", - msg["message"] - ); - } - - #[tokio::test] - async fn ws_handler_returns_error_for_unknown_type() { - let (url, _ctx) = start_test_server().await; - let (mut sink, mut stream, _initial) = connect_ws(&url).await; - - // Send a message with an unknown type. - sink.send(ws_text(r#"{"type": "bogus"}"#)).await.unwrap(); - - let msg = next_msg(&mut stream).await; - assert_eq!(msg["type"], "error"); - assert!(msg["message"].as_str().unwrap().contains("Invalid request")); - } - - #[tokio::test] - async fn ws_handler_cancel_outside_chat_does_not_error() { - let (url, _ctx) = start_test_server().await; - let (mut sink, mut stream, _initial) = connect_ws(&url).await; - - // Send cancel when no chat is active — should not produce an error. - sink.send(ws_text(r#"{"type": "cancel"}"#)).await.unwrap(); - - // Send another invalid message to check the connection is still alive. - sink.send(ws_text("{}")).await.unwrap(); - - let msg = next_msg(&mut stream).await; - // The invalid JSON message should produce an error, confirming - // the cancel didn't break the connection. - assert_eq!(msg["type"], "error"); - } - - #[tokio::test] - async fn ws_handler_permission_response_outside_chat_is_ignored() { - let (url, _ctx) = start_test_server().await; - let (mut sink, mut stream, _initial) = connect_ws(&url).await; - - // Send permission response outside an active chat. - sink.send(ws_text( - r#"{"type": "permission_response", "request_id": "x", "approved": true}"#, - )) - .await - .unwrap(); - - // Send a probe message to check the connection is still alive. - sink.send(ws_text("bad")).await.unwrap(); - - let msg = next_msg(&mut stream).await; - assert_eq!(msg["type"], "error"); - assert!(msg["message"].as_str().unwrap().contains("Invalid request")); - } - - #[tokio::test] - async fn ws_handler_forwards_watcher_events() { - let (url, ctx) = start_test_server().await; - let (_sink, mut stream, _initial) = connect_ws(&url).await; - - // Broadcast a watcher event. - ctx.watcher_tx - .send(WatcherEvent::WorkItem { - stage: "2_current".to_string(), - item_id: "99_story_test".to_string(), - action: "start".to_string(), - commit_msg: "huskies: start 99_story_test".to_string(), - from_stage: None, - }) - .unwrap(); - - let msg = next_msg(&mut stream).await; - assert_eq!(msg["type"], "work_item_changed"); - assert_eq!(msg["item_id"], "99_story_test"); - assert_eq!(msg["stage"], "2_current"); - - // After a work-item event, a pipeline_state refresh is pushed. - let state_msg = next_msg(&mut stream).await; - assert_eq!(state_msg["type"], "pipeline_state"); - } - - #[tokio::test] - async fn ws_handler_forwards_config_changed_without_pipeline_refresh() { - let (url, ctx) = start_test_server().await; - let (_sink, mut stream, _initial) = connect_ws(&url).await; - - // Broadcast a config-changed event. - ctx.watcher_tx.send(WatcherEvent::ConfigChanged).unwrap(); - - let msg = next_msg(&mut stream).await; - assert_eq!(msg["type"], "agent_config_changed"); - - // Config-changed should NOT be followed by a pipeline_state refresh. - // Send a probe to check no extra message is queued. - ctx.watcher_tx.send(WatcherEvent::ConfigChanged).unwrap(); - let msg2 = next_msg(&mut stream).await; - assert_eq!(msg2["type"], "agent_config_changed"); - } - - #[tokio::test] - async fn ws_handler_forwards_reconciliation_events() { - let (url, ctx) = start_test_server().await; - let (_sink, mut stream, _initial) = connect_ws(&url).await; - - // Broadcast a reconciliation event. - ctx.reconciliation_tx - .send(crate::agents::ReconciliationEvent { - story_id: "50_story_recon".to_string(), - status: "checking".to_string(), - message: "Checking story...".to_string(), - }) - .unwrap(); - - let msg = next_msg(&mut stream).await; - assert_eq!(msg["type"], "reconciliation_progress"); - assert_eq!(msg["story_id"], "50_story_recon"); - assert_eq!(msg["status"], "checking"); - assert_eq!(msg["message"], "Checking story..."); - } - - #[tokio::test] - async fn ws_handler_handles_client_disconnect_gracefully() { - let (url, _ctx) = start_test_server().await; - let (mut sink, _stream, _initial) = connect_ws(&url).await; - - // Close the connection — should not panic the server. - sink.close().await.unwrap(); - - // Give the server a moment to process the close. - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - - // Connect again to verify server is still alive. - let (_sink2, _stream2, initial2) = connect_ws(&url).await; - assert_eq!(initial2["type"], "pipeline_state"); - } - - /// Read the next `status_update` whose story_id or story_name contains `needle`, - /// within a timeout. Skips `log_entry` noise and unrelated status events so - /// genuine server log noise cannot cause false positives or negatives. - async fn next_status_update_containing( - stream: &mut futures::stream::SplitStream< - tokio_tungstenite::WebSocketStream< - tokio_tungstenite::MaybeTlsStream, - >, - >, - needle: &str, - timeout_ms: u64, - ) -> Option { - let deadline = std::time::Instant::now() + std::time::Duration::from_millis(timeout_ms); - loop { - let remaining = deadline.saturating_duration_since(std::time::Instant::now()); - if remaining.is_zero() { - return None; - } - let msg = tokio::time::timeout(remaining, stream.next()) - .await - .ok()? - .expect("stream ended") - .expect("ws error"); - let val: serde_json::Value = match msg { - tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).ok()?, - _ => continue, - }; - if val["type"] == "status_update" { - let event = &val["event"]; - let story_id = event["story_id"].as_str().unwrap_or(""); - let story_name = event["story_name"].as_str().unwrap_or(""); - if story_id.contains(needle) || story_name.contains(needle) { - return Some(val); - } - } - // Skip log_entry and other unrelated messages. - } - } - - // ── Status broadcaster integration tests ───────────────────────── - - /// Publishing a status event via `services.status` must result in a - /// `status_update` WebSocket message with structured fields delivered to the - /// connected client. - #[tokio::test] - async fn ws_handler_forwards_status_events_as_status_update() { - let (url, ctx) = start_test_server().await; - let (_sink, mut stream, _initial) = connect_ws(&url).await; - - // Use a story ID unique enough that genuine server logs won't match it. - ctx.services.status.publish(StatusEvent::StageTransition { - story_id: "77_story_status_fwd_test".to_string(), - story_name: Some("StatusFwdTest".to_string()), - from_stage: "1_backlog".to_string(), - to_stage: "2_current".to_string(), - }); - - // The handler must forward it as a status_update with structured fields. - let msg = next_status_update_containing(&mut stream, "StatusFwdTest", 2000) - .await - .expect("expected a status_update for the status event"); - assert_eq!(msg["type"], "status_update"); - let event = &msg["event"]; - assert_eq!(event["type"], "stage_transition"); - assert_eq!(event["story_id"], "77_story_status_fwd_test"); - assert_eq!(event["story_name"], "StatusFwdTest"); - assert_eq!(event["from_stage"], "1_backlog"); - assert_eq!(event["to_stage"], "2_current"); - } - - /// Multi-project isolation: a client connected to project A's server must - /// NOT receive status events published on project B's broadcaster. - #[tokio::test] - async fn ws_handler_multi_project_status_isolation() { - // Start two independent servers (each with its own AppContext / Services). - let (url_a, ctx_a) = start_test_server().await; - let (url_b, _ctx_b) = start_test_server().await; - - let (_sink_a, mut stream_a, _) = connect_ws(&url_a).await; - let (_sink_b, mut stream_b, _) = connect_ws(&url_b).await; - - // Use a needle unique enough that genuine server logs won't match. - let needle = "ProjAIsolation7734"; - ctx_a.services.status.publish(StatusEvent::MergeFailure { - story_id: "10_story_proj_a_isolation".to_string(), - story_name: Some(needle.to_string()), - reason: "conflict".to_string(), - }); - - // Client A must receive the status_update with structured fields. - let msg_a = next_status_update_containing(&mut stream_a, needle, 2000) - .await - .expect("client A should receive the status event"); - assert_eq!(msg_a["type"], "status_update"); - assert_eq!(msg_a["event"]["story_name"], needle); - - // Client B must NOT receive any status_update containing the needle. - let msg_b = next_status_update_containing(&mut stream_b, needle, 300).await; - assert!( - msg_b.is_none(), - "client B must not receive project A's status event, got: {msg_b:?}" - ); - } - - /// When `web_ui_status_consumer = false` in project.toml, the WebSocket - /// handler must not forward status events to the connected client. - #[tokio::test] - async fn ws_handler_status_consumer_disabled_via_config() { - let tmp = tempfile::tempdir().unwrap(); - let root = tmp.path().to_path_buf(); - - // Write a project.toml that disables the web UI status consumer. - let huskies_dir = root.join(".huskies"); - std::fs::create_dir_all(&huskies_dir).unwrap(); - std::fs::write( - huskies_dir.join("project.toml"), - "web_ui_status_consumer = false\n", - ) - .unwrap(); - - crate::db::ensure_content_store(); - let ctx = Arc::new(AppContext::new_test(root)); - let ctx_data = ctx.clone(); - - let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let app = poem::Route::new() - .at("/ws", poem::get(ws_handler)) - .data(ctx_data); - tokio::spawn(async move { - let acceptor = poem::listener::TcpAcceptor::from_tokio(listener).unwrap(); - let _ = poem::Server::new_with_acceptor(acceptor).run(app).await; - }); - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - - let url = format!("ws://127.0.0.1:{}/ws", addr.port()); - let (_sink, mut stream, _) = connect_ws(&url).await; - - // Use a unique needle — genuine server logs will never contain this. - let needle = "DisabledConsumer9182"; - ctx.services.status.publish(StatusEvent::StoryBlocked { - story_id: "55_story_disabled_consumer".to_string(), - story_name: Some(needle.to_string()), - reason: "test".to_string(), - }); - - // Consumer is disabled — no status_update with this needle should arrive. - let msg = next_status_update_containing(&mut stream, needle, 500).await; - assert!( - msg.is_none(), - "disabled consumer must not forward status events, got: {msg:?}" - ); - } -} diff --git a/server/src/http/ws/mod.rs b/server/src/http/ws/mod.rs new file mode 100644 index 00000000..7c8da623 --- /dev/null +++ b/server/src/http/ws/mod.rs @@ -0,0 +1,275 @@ +//! WebSocket transport adapter — accept connection, serialise/deserialise frames, +//! invoke service methods. No business logic, no inline state transitions. + +use crate::config::ProjectConfig; +use crate::http::context::AppContext; +use crate::llm::chat; +use crate::service::ws::{self, WsResponse}; +use futures::{SinkExt, StreamExt}; +use poem::handler; +use poem::web::Data; +use poem::web::websocket::{Message as WsMessage, WebSocket}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot}; + +use crate::http::context::PermissionDecision; + +// Re-export WizardStepInfo for any downstream code that imports it from here. +#[allow(unused_imports)] +pub use crate::service::ws::WizardStepInfo; + +#[handler] +/// WebSocket endpoint for streaming chat responses, cancellation, and +/// filesystem watcher notifications. +/// +/// Accepts JSON `WsRequest` messages and streams `WsResponse` messages. +pub async fn ws_handler(ws: WebSocket, ctx: Data<&Arc>) -> impl poem::IntoResponse { + let ctx = ctx.0.clone(); + ws.on_upgrade(move |socket| async move { + let (mut sink, mut stream) = socket.split(); + let (tx, mut rx) = mpsc::unbounded_channel::(); + // Separate channel for pre-serialized messages (e.g. RPC responses). + let (raw_tx, mut raw_rx) = mpsc::unbounded_channel::(); + + let forward = tokio::spawn(async move { + loop { + tokio::select! { + msg = rx.recv() => match msg { + Some(msg) => { + if let Ok(text) = serde_json::to_string(&msg) + && sink.send(WsMessage::Text(text)).await.is_err() + { + break; + } + } + None => break, + }, + raw = raw_rx.recv() => match raw { + Some(text) => { + if sink.send(WsMessage::Text(text)).await.is_err() { + break; + } + } + None => break, + }, + } + } + }); + + // ── Initial state burst ──────────────────────────────────────── + if let Some(state) = ws::load_initial_pipeline_state(ctx.as_ref()) { + let _ = tx.send(state); + } + let _ = tx.send(ws::check_onboarding(ctx.as_ref())); + if let Some(wiz) = ws::load_wizard_state(ctx.as_ref()) { + let _ = tx.send(wiz); + } + for log in ws::load_recent_logs(100) { + let _ = tx.send(log); + } + + // ── Background subscriptions ─────────────────────────────────── + ws::subscribe_logs(tx.clone()); + ws::subscribe_watcher(tx.clone(), ctx.clone(), ctx.watcher_tx.subscribe()); + ws::subscribe_reconciliation(tx.clone(), ctx.reconciliation_tx.subscribe()); + + // Subscribe to the status broadcaster if web UI consumer is enabled (default: true). + let status_enabled = ctx + .state + .get_project_root() + .ok() + .and_then(|root| ProjectConfig::load(&root).ok()) + .map(|c| c.web_ui_status_consumer) + .unwrap_or(true); + if status_enabled { + ws::subscribe_status(tx.clone(), ctx.services.status.subscribe()); + } + + // Map of pending permission request_id -> oneshot responder. + let mut pending_perms: HashMap> = + HashMap::new(); + + loop { + // Outer loop: wait for the next WebSocket message. + let Some(Ok(WsMessage::Text(text))) = stream.next().await else { + break; + }; + + // Handle read-RPC frames (discriminated by "kind", not "type"). + if let Some(rpc_resp) = crate::crdt_sync::try_handle_rpc_text(&text) { + if let Ok(resp_text) = serde_json::to_string(&rpc_resp) { + let _ = raw_tx.send(resp_text); + } + continue; + } + + match ws::dispatch_outer(&text) { + ws::DispatchResult::StartChat { messages, config } => { + let tx_updates = tx.clone(); + let tx_tokens = tx.clone(); + let tx_thinking = tx.clone(); + let tx_activity = tx.clone(); + let ctx_clone = ctx.clone(); + + let chat_fut = chat::chat( + messages, + config, + &ctx_clone.state, + ctx_clone.store.as_ref(), + move |history| { + let _ = tx_updates.send(WsResponse::Update { + messages: history.to_vec(), + }); + }, + move |token| { + let _ = tx_tokens.send(WsResponse::Token { + content: token.to_string(), + }); + }, + move |thinking: &str| { + let _ = tx_thinking.send(WsResponse::ThinkingToken { + content: thinking.to_string(), + }); + }, + move |tool_name: &str| { + let _ = tx_activity.send(WsResponse::ToolActivity { + tool_name: tool_name.to_string(), + }); + }, + ); + tokio::pin!(chat_fut); + + let mut perm_rx = ctx.services.perm_rx.lock().await; + + let chat_result = loop { + tokio::select! { + result = &mut chat_fut => break result, + + Some(perm_fwd) = perm_rx.recv() => { + let _ = tx.send(ws::permission_request_response( + &perm_fwd.request_id, + &perm_fwd.tool_name, + &perm_fwd.tool_input, + )); + pending_perms.insert( + perm_fwd.request_id, + perm_fwd.response_tx, + ); + } + + Some(Ok(WsMessage::Text(inner_text))) = stream.next() => { + // Handle read-RPC frames during active chat. + if let Some(rpc_resp) = crate::crdt_sync::try_handle_rpc_text(&inner_text) { + if let Ok(resp_text) = serde_json::to_string(&rpc_resp) { + let _ = raw_tx.send(resp_text); + } + continue; + } + match ws::dispatch_inner(&inner_text, &mut pending_perms) { + ws::InnerDispatchResult::CancelChat => { + let _ = chat::cancel_chat(&ctx.state); + } + ws::InnerDispatchResult::Pong => { + let _ = tx.send(WsResponse::Pong); + } + ws::InnerDispatchResult::StartSideQuestion { question, context_messages, config } => { + let tx_side = tx.clone(); + let store = ctx.store.clone(); + tokio::spawn(async move { + let result = chat::side_question( + context_messages, + question, + config, + store.as_ref(), + |token| { + let _ = tx_side.send(WsResponse::SideQuestionToken { + content: token.to_string(), + }); + }, + ).await; + match result { + Ok(response) => { + let _ = tx_side.send(WsResponse::SideQuestionDone { response }); + } + Err(err) => { + let _ = tx_side.send(WsResponse::SideQuestionDone { + response: format!("Error: {err}"), + }); + } + } + }); + } + ws::InnerDispatchResult::PermissionResolved + | ws::InnerDispatchResult::Ignored => {} + } + } + } + }; + + match chat_result { + Ok(chat_result) => { + if let Some(sid) = chat_result.session_id { + let _ = tx.send(WsResponse::SessionId { session_id: sid }); + } + } + Err(err) => { + let _ = tx.send(ws::error_response(err)); + } + } + } + ws::DispatchResult::CancelChat => { + let _ = chat::cancel_chat(&ctx.state); + } + ws::DispatchResult::Pong => { + let _ = tx.send(WsResponse::Pong); + } + ws::DispatchResult::IgnoredPermission => { + // Permission responses outside an active chat are ignored. + } + ws::DispatchResult::StartSideQuestion { + question, + context_messages, + config, + } => { + let tx_side = tx.clone(); + let store = ctx.store.clone(); + tokio::spawn(async move { + let result = chat::side_question( + context_messages, + question, + config, + store.as_ref(), + |token| { + let _ = tx_side.send(WsResponse::SideQuestionToken { + content: token.to_string(), + }); + }, + ) + .await; + match result { + Ok(response) => { + let _ = + tx_side.send(WsResponse::SideQuestionDone { response }); + } + Err(err) => { + let _ = tx_side.send(WsResponse::SideQuestionDone { + response: format!("Error: {err}"), + }); + } + } + }); + } + ws::DispatchResult::ParseError(msg) => { + let _ = tx.send(ws::error_response(msg)); + } + } + } + + drop(tx); + let _ = forward.await; + }) +} + +#[cfg(test)] +mod tests; diff --git a/server/src/http/ws/tests.rs b/server/src/http/ws/tests.rs new file mode 100644 index 00000000..a13310d9 --- /dev/null +++ b/server/src/http/ws/tests.rs @@ -0,0 +1,466 @@ +//! Integration tests for the WebSocket handler (`ws_handler`). +//! +//! Tests cover: initial state burst on connect, error responses for invalid +//! messages, cancel/permission handling outside chat, watcher event forwarding, +//! reconciliation event forwarding, status broadcaster forwarding, and graceful +//! client disconnect. + +use super::*; +use crate::io::watcher::WatcherEvent; +use crate::service::status::StatusEvent; + +// ── ws_handler integration tests (real WebSocket connection) ───── + +use futures::stream::SplitSink; +use poem::EndpointExt; +use tokio_tungstenite::tungstenite; + +/// Helper: construct a tungstenite text message from a string. +fn ws_text(s: &str) -> tungstenite::Message { + tungstenite::Message::Text(s.into()) +} + +/// Helper: start a poem server with ws_handler on an ephemeral port +/// and return the WebSocket URL. +async fn start_test_server() -> (String, Arc) { + let tmp = tempfile::tempdir().unwrap(); + let root = tmp.path().to_path_buf(); + + // Ensure CRDT content store is initialised — load_pipeline_state + // now reads from the in-memory CRDT, not the filesystem. + crate::db::ensure_content_store(); + + let ctx = Arc::new(AppContext::new_test(root)); + let ctx_data = ctx.clone(); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let app = poem::Route::new() + .at("/ws", poem::get(ws_handler)) + .data(ctx_data); + + tokio::spawn(async move { + let acceptor = poem::listener::TcpAcceptor::from_tokio(listener).unwrap(); + let _ = poem::Server::new_with_acceptor(acceptor).run(app).await; + }); + + // Small delay to let the server start. + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + + let url = format!("ws://127.0.0.1:{}/ws", addr.port()); + (url, ctx) +} + +type WsSink = SplitSink< + tokio_tungstenite::WebSocketStream>, + tungstenite::Message, +>; + +/// Helper: connect and return (sink, stream) plus read the initial +/// pipeline_state and onboarding_status messages that are always sent +/// on connect. +async fn connect_ws( + url: &str, +) -> ( + WsSink, + futures::stream::SplitStream< + tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, + >, + serde_json::Value, +) { + let (ws, _resp) = tokio_tungstenite::connect_async(url).await.unwrap(); + let (sink, mut stream) = futures::StreamExt::split(ws); + + // The first message should be the initial pipeline_state. + let first = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next()) + .await + .expect("timeout waiting for initial message") + .expect("stream ended") + .expect("ws error"); + + let initial: serde_json::Value = match first { + tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(), + other => panic!("expected text message, got: {other:?}"), + }; + + // The second message is the onboarding_status — consume it so + // callers only see application-level messages. + let second = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next()) + .await + .expect("timeout waiting for onboarding_status") + .expect("stream ended") + .expect("ws error"); + let onboarding: serde_json::Value = match second { + tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(), + other => panic!("expected text message, got: {other:?}"), + }; + assert_eq!( + onboarding["type"], "onboarding_status", + "expected onboarding_status, got: {onboarding}" + ); + + // Drain any log_entry messages sent as initial history on connect. + // These are buffered before tests send their own requests. + loop { + // Use a very short timeout: if nothing arrives quickly, the burst is done. + let Ok(Some(Ok(msg))) = + tokio::time::timeout(std::time::Duration::from_millis(200), stream.next()).await + else { + break; + }; + let val: serde_json::Value = match msg { + tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(), + _ => break, + }; + if val["type"] != "log_entry" { + // Unexpected non-log message during drain — this shouldn't happen. + panic!("unexpected message during log drain: {val}"); + } + } + + (sink, stream, initial) +} + +/// Read next non-log_entry text message from the stream with a timeout. +/// Skips any `log_entry` messages that arrive between events. +async fn next_msg( + stream: &mut futures::stream::SplitStream< + tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, + >, +) -> serde_json::Value { + loop { + let msg = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next()) + .await + .expect("timeout waiting for message") + .expect("stream ended") + .expect("ws error"); + let val: serde_json::Value = match msg { + tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(), + other => panic!("expected text message, got: {other:?}"), + }; + if val["type"] != "log_entry" { + return val; + } + } +} + +#[tokio::test] +async fn ws_handler_sends_initial_pipeline_state_on_connect() { + let (url, _ctx) = start_test_server().await; + let (_sink, _stream, initial) = connect_ws(&url).await; + + assert_eq!(initial["type"], "pipeline_state"); + // Verify stage arrays are present (may contain items from the + // shared global CRDT store populated by other tests). + assert!(initial["backlog"].as_array().is_some()); + assert!(initial["current"].as_array().is_some()); + assert!(initial["qa"].as_array().is_some()); + assert!(initial["merge"].as_array().is_some()); +} + +#[tokio::test] +async fn ws_handler_returns_error_for_invalid_json() { + let (url, _ctx) = start_test_server().await; + let (mut sink, mut stream, _initial) = connect_ws(&url).await; + + // Send invalid JSON. + sink.send(ws_text("not valid json")).await.unwrap(); + + let msg = next_msg(&mut stream).await; + assert_eq!(msg["type"], "error"); + assert!( + msg["message"].as_str().unwrap().contains("Invalid request"), + "error message should indicate invalid request, got: {}", + msg["message"] + ); +} + +#[tokio::test] +async fn ws_handler_returns_error_for_unknown_type() { + let (url, _ctx) = start_test_server().await; + let (mut sink, mut stream, _initial) = connect_ws(&url).await; + + // Send a message with an unknown type. + sink.send(ws_text(r#"{"type": "bogus"}"#)).await.unwrap(); + + let msg = next_msg(&mut stream).await; + assert_eq!(msg["type"], "error"); + assert!(msg["message"].as_str().unwrap().contains("Invalid request")); +} + +#[tokio::test] +async fn ws_handler_cancel_outside_chat_does_not_error() { + let (url, _ctx) = start_test_server().await; + let (mut sink, mut stream, _initial) = connect_ws(&url).await; + + // Send cancel when no chat is active — should not produce an error. + sink.send(ws_text(r#"{"type": "cancel"}"#)).await.unwrap(); + + // Send another invalid message to check the connection is still alive. + sink.send(ws_text("{}")).await.unwrap(); + + let msg = next_msg(&mut stream).await; + // The invalid JSON message should produce an error, confirming + // the cancel didn't break the connection. + assert_eq!(msg["type"], "error"); +} + +#[tokio::test] +async fn ws_handler_permission_response_outside_chat_is_ignored() { + let (url, _ctx) = start_test_server().await; + let (mut sink, mut stream, _initial) = connect_ws(&url).await; + + // Send permission response outside an active chat. + sink.send(ws_text( + r#"{"type": "permission_response", "request_id": "x", "approved": true}"#, + )) + .await + .unwrap(); + + // Send a probe message to check the connection is still alive. + sink.send(ws_text("bad")).await.unwrap(); + + let msg = next_msg(&mut stream).await; + assert_eq!(msg["type"], "error"); + assert!(msg["message"].as_str().unwrap().contains("Invalid request")); +} + +#[tokio::test] +async fn ws_handler_forwards_watcher_events() { + let (url, ctx) = start_test_server().await; + let (_sink, mut stream, _initial) = connect_ws(&url).await; + + // Broadcast a watcher event. + ctx.watcher_tx + .send(WatcherEvent::WorkItem { + stage: "2_current".to_string(), + item_id: "99_story_test".to_string(), + action: "start".to_string(), + commit_msg: "huskies: start 99_story_test".to_string(), + from_stage: None, + }) + .unwrap(); + + let msg = next_msg(&mut stream).await; + assert_eq!(msg["type"], "work_item_changed"); + assert_eq!(msg["item_id"], "99_story_test"); + assert_eq!(msg["stage"], "2_current"); + + // After a work-item event, a pipeline_state refresh is pushed. + let state_msg = next_msg(&mut stream).await; + assert_eq!(state_msg["type"], "pipeline_state"); +} + +#[tokio::test] +async fn ws_handler_forwards_config_changed_without_pipeline_refresh() { + let (url, ctx) = start_test_server().await; + let (_sink, mut stream, _initial) = connect_ws(&url).await; + + // Broadcast a config-changed event. + ctx.watcher_tx.send(WatcherEvent::ConfigChanged).unwrap(); + + let msg = next_msg(&mut stream).await; + assert_eq!(msg["type"], "agent_config_changed"); + + // Config-changed should NOT be followed by a pipeline_state refresh. + // Send a probe to check no extra message is queued. + ctx.watcher_tx.send(WatcherEvent::ConfigChanged).unwrap(); + let msg2 = next_msg(&mut stream).await; + assert_eq!(msg2["type"], "agent_config_changed"); +} + +#[tokio::test] +async fn ws_handler_forwards_reconciliation_events() { + let (url, ctx) = start_test_server().await; + let (_sink, mut stream, _initial) = connect_ws(&url).await; + + // Broadcast a reconciliation event. + ctx.reconciliation_tx + .send(crate::agents::ReconciliationEvent { + story_id: "50_story_recon".to_string(), + status: "checking".to_string(), + message: "Checking story...".to_string(), + }) + .unwrap(); + + let msg = next_msg(&mut stream).await; + assert_eq!(msg["type"], "reconciliation_progress"); + assert_eq!(msg["story_id"], "50_story_recon"); + assert_eq!(msg["status"], "checking"); + assert_eq!(msg["message"], "Checking story..."); +} + +#[tokio::test] +async fn ws_handler_handles_client_disconnect_gracefully() { + let (url, _ctx) = start_test_server().await; + let (mut sink, _stream, _initial) = connect_ws(&url).await; + + // Close the connection — should not panic the server. + sink.close().await.unwrap(); + + // Give the server a moment to process the close. + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + // Connect again to verify server is still alive. + let (_sink2, _stream2, initial2) = connect_ws(&url).await; + assert_eq!(initial2["type"], "pipeline_state"); +} + +/// Read the next `status_update` whose story_id or story_name contains `needle`, +/// within a timeout. Skips `log_entry` noise and unrelated status events so +/// genuine server log noise cannot cause false positives or negatives. +async fn next_status_update_containing( + stream: &mut futures::stream::SplitStream< + tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, + >, + needle: &str, + timeout_ms: u64, +) -> Option { + let deadline = std::time::Instant::now() + std::time::Duration::from_millis(timeout_ms); + loop { + let remaining = deadline.saturating_duration_since(std::time::Instant::now()); + if remaining.is_zero() { + return None; + } + let msg = tokio::time::timeout(remaining, stream.next()) + .await + .ok()? + .expect("stream ended") + .expect("ws error"); + let val: serde_json::Value = match msg { + tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).ok()?, + _ => continue, + }; + if val["type"] == "status_update" { + let event = &val["event"]; + let story_id = event["story_id"].as_str().unwrap_or(""); + let story_name = event["story_name"].as_str().unwrap_or(""); + if story_id.contains(needle) || story_name.contains(needle) { + return Some(val); + } + } + // Skip log_entry and other unrelated messages. + } +} + +// ── Status broadcaster integration tests ───────────────────────── + +/// Publishing a status event via `services.status` must result in a +/// `status_update` WebSocket message with structured fields delivered to the +/// connected client. +#[tokio::test] +async fn ws_handler_forwards_status_events_as_status_update() { + let (url, ctx) = start_test_server().await; + let (_sink, mut stream, _initial) = connect_ws(&url).await; + + // Use a story ID unique enough that genuine server logs won't match it. + ctx.services.status.publish(StatusEvent::StageTransition { + story_id: "77_story_status_fwd_test".to_string(), + story_name: Some("StatusFwdTest".to_string()), + from_stage: "1_backlog".to_string(), + to_stage: "2_current".to_string(), + }); + + // The handler must forward it as a status_update with structured fields. + let msg = next_status_update_containing(&mut stream, "StatusFwdTest", 2000) + .await + .expect("expected a status_update for the status event"); + assert_eq!(msg["type"], "status_update"); + let event = &msg["event"]; + assert_eq!(event["type"], "stage_transition"); + assert_eq!(event["story_id"], "77_story_status_fwd_test"); + assert_eq!(event["story_name"], "StatusFwdTest"); + assert_eq!(event["from_stage"], "1_backlog"); + assert_eq!(event["to_stage"], "2_current"); +} + +/// Multi-project isolation: a client connected to project A's server must +/// NOT receive status events published on project B's broadcaster. +#[tokio::test] +async fn ws_handler_multi_project_status_isolation() { + // Start two independent servers (each with its own AppContext / Services). + let (url_a, ctx_a) = start_test_server().await; + let (url_b, _ctx_b) = start_test_server().await; + + let (_sink_a, mut stream_a, _) = connect_ws(&url_a).await; + let (_sink_b, mut stream_b, _) = connect_ws(&url_b).await; + + // Use a needle unique enough that genuine server logs won't match. + let needle = "ProjAIsolation7734"; + ctx_a.services.status.publish(StatusEvent::MergeFailure { + story_id: "10_story_proj_a_isolation".to_string(), + story_name: Some(needle.to_string()), + reason: "conflict".to_string(), + }); + + // Client A must receive the status_update with structured fields. + let msg_a = next_status_update_containing(&mut stream_a, needle, 2000) + .await + .expect("client A should receive the status event"); + assert_eq!(msg_a["type"], "status_update"); + assert_eq!(msg_a["event"]["story_name"], needle); + + // Client B must NOT receive any status_update containing the needle. + let msg_b = next_status_update_containing(&mut stream_b, needle, 300).await; + assert!( + msg_b.is_none(), + "client B must not receive project A's status event, got: {msg_b:?}" + ); +} + +/// When `web_ui_status_consumer = false` in project.toml, the WebSocket +/// handler must not forward status events to the connected client. +#[tokio::test] +async fn ws_handler_status_consumer_disabled_via_config() { + let tmp = tempfile::tempdir().unwrap(); + let root = tmp.path().to_path_buf(); + + // Write a project.toml that disables the web UI status consumer. + let huskies_dir = root.join(".huskies"); + std::fs::create_dir_all(&huskies_dir).unwrap(); + std::fs::write( + huskies_dir.join("project.toml"), + "web_ui_status_consumer = false\n", + ) + .unwrap(); + + crate::db::ensure_content_store(); + let ctx = Arc::new(AppContext::new_test(root)); + let ctx_data = ctx.clone(); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let app = poem::Route::new() + .at("/ws", poem::get(ws_handler)) + .data(ctx_data); + tokio::spawn(async move { + let acceptor = poem::listener::TcpAcceptor::from_tokio(listener).unwrap(); + let _ = poem::Server::new_with_acceptor(acceptor).run(app).await; + }); + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + + let url = format!("ws://127.0.0.1:{}/ws", addr.port()); + let (_sink, mut stream, _) = connect_ws(&url).await; + + // Use a unique needle — genuine server logs will never contain this. + let needle = "DisabledConsumer9182"; + ctx.services.status.publish(StatusEvent::StoryBlocked { + story_id: "55_story_disabled_consumer".to_string(), + story_name: Some(needle.to_string()), + reason: "test".to_string(), + }); + + // Consumer is disabled — no status_update with this needle should arrive. + let msg = next_status_update_containing(&mut stream, needle, 500).await; + assert!( + msg.is_none(), + "disabled consumer must not forward status events, got: {msg:?}" + ); +}