huskies: merge 847

This commit is contained in:
dave
2026-04-29 18:35:32 +00:00
parent 39013be535
commit a956a98197
3 changed files with 741 additions and 737 deletions
-737
View File
@@ -1,737 +0,0 @@
//! WebSocket transport adapter — accept connection, serialise/deserialise frames,
//! invoke service methods. No business logic, no inline state transitions.
use crate::config::ProjectConfig;
use crate::http::context::AppContext;
use crate::llm::chat;
use crate::service::ws::{self, WsResponse};
use futures::{SinkExt, StreamExt};
use poem::handler;
use poem::web::Data;
use poem::web::websocket::{Message as WsMessage, WebSocket};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use crate::http::context::PermissionDecision;
// Re-export WizardStepInfo for any downstream code that imports it from here.
#[allow(unused_imports)]
pub use crate::service::ws::WizardStepInfo;
#[handler]
/// WebSocket endpoint for streaming chat responses, cancellation, and
/// filesystem watcher notifications.
///
/// Accepts JSON `WsRequest` messages and streams `WsResponse` messages.
///
/// Lifecycle: upgrade → spawn a single writer ("forward") task → send the
/// initial state burst → register background subscriptions → run the
/// read/dispatch loop until the client disconnects.
pub async fn ws_handler(ws: WebSocket, ctx: Data<&Arc<AppContext>>) -> impl poem::IntoResponse {
    let ctx = ctx.0.clone();
    ws.on_upgrade(move |socket| async move {
        let (mut sink, mut stream) = socket.split();
        // Typed channel: `WsResponse` values are serialized to JSON by the
        // forward task below.
        let (tx, mut rx) = mpsc::unbounded_channel::<WsResponse>();
        // Separate channel for pre-serialized messages (e.g. RPC responses).
        let (raw_tx, mut raw_rx) = mpsc::unbounded_channel::<String>();
        // Single writer task: all outbound traffic funnels through `tx` /
        // `raw_tx` so the sink is never written from two tasks at once.
        // Exits when a socket send fails or either channel closes.
        let forward = tokio::spawn(async move {
            loop {
                tokio::select! {
                    msg = rx.recv() => match msg {
                        Some(msg) => {
                            // A serialization failure silently drops the
                            // message; only a failed socket send ends the task.
                            if let Ok(text) = serde_json::to_string(&msg)
                                && sink.send(WsMessage::Text(text)).await.is_err()
                            {
                                break;
                            }
                        }
                        None => break,
                    },
                    raw = raw_rx.recv() => match raw {
                        Some(text) => {
                            if sink.send(WsMessage::Text(text)).await.is_err() {
                                break;
                            }
                        }
                        None => break,
                    },
                }
            }
        });
        // ── Initial state burst ────────────────────────────────────
        if let Some(state) = ws::load_initial_pipeline_state(ctx.as_ref()) {
            let _ = tx.send(state);
        }
        let _ = tx.send(ws::check_onboarding(ctx.as_ref()));
        if let Some(wiz) = ws::load_wizard_state(ctx.as_ref()) {
            let _ = tx.send(wiz);
        }
        for log in ws::load_recent_logs(100) {
            let _ = tx.send(log);
        }
        // ── Background subscriptions ───────────────────────────────
        ws::subscribe_logs(tx.clone());
        ws::subscribe_watcher(tx.clone(), ctx.clone(), ctx.watcher_tx.subscribe());
        ws::subscribe_reconciliation(tx.clone(), ctx.reconciliation_tx.subscribe());
        // Subscribe to the status broadcaster if web UI consumer is enabled (default: true).
        // Missing project root or unreadable config falls back to enabled.
        let status_enabled = ctx
            .state
            .get_project_root()
            .ok()
            .and_then(|root| ProjectConfig::load(&root).ok())
            .map(|c| c.web_ui_status_consumer)
            .unwrap_or(true);
        if status_enabled {
            ws::subscribe_status(tx.clone(), ctx.services.status.subscribe());
        }
        // Map of pending permission request_id -> oneshot responder.
        let mut pending_perms: HashMap<String, oneshot::Sender<PermissionDecision>> =
            HashMap::new();
        loop {
            // Outer loop: wait for the next WebSocket message.
            // Any non-text frame, read error, or stream end closes the session.
            let Some(Ok(WsMessage::Text(text))) = stream.next().await else {
                break;
            };
            // Handle read-RPC frames (discriminated by "kind", not "type").
            if let Some(rpc_resp) = crate::crdt_sync::try_handle_rpc_text(&text) {
                if let Ok(resp_text) = serde_json::to_string(&rpc_resp) {
                    let _ = raw_tx.send(resp_text);
                }
                continue;
            }
            match ws::dispatch_outer(&text) {
                ws::DispatchResult::StartChat { messages, config } => {
                    // One sender clone per callback: each closure is handed to
                    // chat() and must own its channel handle independently.
                    let tx_updates = tx.clone();
                    let tx_tokens = tx.clone();
                    let tx_thinking = tx.clone();
                    let tx_activity = tx.clone();
                    let ctx_clone = ctx.clone();
                    let chat_fut = chat::chat(
                        messages,
                        config,
                        &ctx_clone.state,
                        ctx_clone.store.as_ref(),
                        move |history| {
                            let _ = tx_updates.send(WsResponse::Update {
                                messages: history.to_vec(),
                            });
                        },
                        move |token| {
                            let _ = tx_tokens.send(WsResponse::Token {
                                content: token.to_string(),
                            });
                        },
                        move |thinking: &str| {
                            let _ = tx_thinking.send(WsResponse::ThinkingToken {
                                content: thinking.to_string(),
                            });
                        },
                        move |tool_name: &str| {
                            let _ = tx_activity.send(WsResponse::ToolActivity {
                                tool_name: tool_name.to_string(),
                            });
                        },
                    );
                    // Pin on the stack so the future can be polled repeatedly
                    // across select! iterations without being consumed.
                    tokio::pin!(chat_fut);
                    // The permission-forward receiver lock is held for the
                    // whole chat and released when this arm ends.
                    let mut perm_rx = ctx.services.perm_rx.lock().await;
                    let chat_result = loop {
                        tokio::select! {
                            result = &mut chat_fut => break result,
                            Some(perm_fwd) = perm_rx.recv() => {
                                // Forward the permission prompt to the client
                                // and park the responder until it answers.
                                let _ = tx.send(ws::permission_request_response(
                                    &perm_fwd.request_id,
                                    &perm_fwd.tool_name,
                                    &perm_fwd.tool_input,
                                ));
                                pending_perms.insert(
                                    perm_fwd.request_id,
                                    perm_fwd.response_tx,
                                );
                            }
                            // NOTE(review): non-text frames (or stream end)
                            // fail this pattern and disable the arm for the
                            // current select evaluation; the chat continues.
                            Some(Ok(WsMessage::Text(inner_text))) = stream.next() => {
                                // Handle read-RPC frames during active chat.
                                if let Some(rpc_resp) = crate::crdt_sync::try_handle_rpc_text(&inner_text) {
                                    if let Ok(resp_text) = serde_json::to_string(&rpc_resp) {
                                        let _ = raw_tx.send(resp_text);
                                    }
                                    continue;
                                }
                                match ws::dispatch_inner(&inner_text, &mut pending_perms) {
                                    ws::InnerDispatchResult::CancelChat => {
                                        let _ = chat::cancel_chat(&ctx.state);
                                    }
                                    ws::InnerDispatchResult::Pong => {
                                        let _ = tx.send(WsResponse::Pong);
                                    }
                                    // Side questions run on their own task so
                                    // they do not block the active chat.
                                    ws::InnerDispatchResult::StartSideQuestion { question, context_messages, config } => {
                                        let tx_side = tx.clone();
                                        let store = ctx.store.clone();
                                        tokio::spawn(async move {
                                            let result = chat::side_question(
                                                context_messages,
                                                question,
                                                config,
                                                store.as_ref(),
                                                |token| {
                                                    let _ = tx_side.send(WsResponse::SideQuestionToken {
                                                        content: token.to_string(),
                                                    });
                                                },
                                            ).await;
                                            match result {
                                                Ok(response) => {
                                                    let _ = tx_side.send(WsResponse::SideQuestionDone { response });
                                                }
                                                Err(err) => {
                                                    let _ = tx_side.send(WsResponse::SideQuestionDone {
                                                        response: format!("Error: {err}"),
                                                    });
                                                }
                                            }
                                        });
                                    }
                                    ws::InnerDispatchResult::PermissionResolved
                                    | ws::InnerDispatchResult::Ignored => {}
                                }
                            }
                        }
                    };
                    match chat_result {
                        Ok(chat_result) => {
                            if let Some(sid) = chat_result.session_id {
                                let _ = tx.send(WsResponse::SessionId { session_id: sid });
                            }
                        }
                        Err(err) => {
                            let _ = tx.send(ws::error_response(err));
                        }
                    }
                }
                ws::DispatchResult::CancelChat => {
                    let _ = chat::cancel_chat(&ctx.state);
                }
                ws::DispatchResult::Pong => {
                    let _ = tx.send(WsResponse::Pong);
                }
                ws::DispatchResult::IgnoredPermission => {
                    // Permission responses outside an active chat are ignored.
                }
                ws::DispatchResult::StartSideQuestion {
                    question,
                    context_messages,
                    config,
                } => {
                    let tx_side = tx.clone();
                    let store = ctx.store.clone();
                    tokio::spawn(async move {
                        let result = chat::side_question(
                            context_messages,
                            question,
                            config,
                            store.as_ref(),
                            |token| {
                                let _ = tx_side.send(WsResponse::SideQuestionToken {
                                    content: token.to_string(),
                                });
                            },
                        )
                        .await;
                        match result {
                            Ok(response) => {
                                let _ =
                                    tx_side.send(WsResponse::SideQuestionDone { response });
                            }
                            Err(err) => {
                                let _ = tx_side.send(WsResponse::SideQuestionDone {
                                    response: format!("Error: {err}"),
                                });
                            }
                        }
                    });
                }
                ws::DispatchResult::ParseError(msg) => {
                    let _ = tx.send(ws::error_response(msg));
                }
            }
        }
        // Close the typed channel so the forward task drains and exits,
        // then wait for it to finish before tearing down the connection.
        drop(tx);
        let _ = forward.await;
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::watcher::WatcherEvent;
    use crate::service::status::StatusEvent;
    // ── ws_handler integration tests (real WebSocket connection) ─────
    use futures::stream::SplitSink;
    use poem::EndpointExt;
    use tokio_tungstenite::tungstenite;
    /// Helper: construct a tungstenite text message from a string.
    fn ws_text(s: &str) -> tungstenite::Message {
        tungstenite::Message::Text(s.into())
    }
    /// Helper: start a poem server with ws_handler on an ephemeral port
    /// and return the WebSocket URL.
    async fn start_test_server() -> (String, Arc<AppContext>) {
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path().to_path_buf();
        // Ensure CRDT content store is initialised — load_pipeline_state
        // now reads from the in-memory CRDT, not the filesystem.
        crate::db::ensure_content_store();
        let ctx = Arc::new(AppContext::new_test(root));
        let ctx_data = ctx.clone();
        // Bind first so the ephemeral port is known before the server task runs.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let app = poem::Route::new()
            .at("/ws", poem::get(ws_handler))
            .data(ctx_data);
        tokio::spawn(async move {
            let acceptor = poem::listener::TcpAcceptor::from_tokio(listener).unwrap();
            let _ = poem::Server::new_with_acceptor(acceptor).run(app).await;
        });
        // Small delay to let the server start.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let url = format!("ws://127.0.0.1:{}/ws", addr.port());
        (url, ctx)
    }
    /// Client-side write half of a test WebSocket connection.
    type WsSink = SplitSink<
        tokio_tungstenite::WebSocketStream<
            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
        >,
        tungstenite::Message,
    >;
    /// Helper: connect and return (sink, stream) plus read the initial
    /// pipeline_state and onboarding_status messages that are always sent
    /// on connect.
    async fn connect_ws(
        url: &str,
    ) -> (
        WsSink,
        futures::stream::SplitStream<
            tokio_tungstenite::WebSocketStream<
                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
            >,
        >,
        serde_json::Value,
    ) {
        let (ws, _resp) = tokio_tungstenite::connect_async(url).await.unwrap();
        let (sink, mut stream) = futures::StreamExt::split(ws);
        // The first message should be the initial pipeline_state.
        let first = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next())
            .await
            .expect("timeout waiting for initial message")
            .expect("stream ended")
            .expect("ws error");
        let initial: serde_json::Value = match first {
            tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(),
            other => panic!("expected text message, got: {other:?}"),
        };
        // The second message is the onboarding_status — consume it so
        // callers only see application-level messages.
        let second = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next())
            .await
            .expect("timeout waiting for onboarding_status")
            .expect("stream ended")
            .expect("ws error");
        let onboarding: serde_json::Value = match second {
            tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(),
            other => panic!("expected text message, got: {other:?}"),
        };
        assert_eq!(
            onboarding["type"], "onboarding_status",
            "expected onboarding_status, got: {onboarding}"
        );
        // Drain any log_entry messages sent as initial history on connect.
        // These are buffered before tests send their own requests.
        loop {
            // Use a very short timeout: if nothing arrives quickly, the burst is done.
            let Ok(Some(Ok(msg))) =
                tokio::time::timeout(std::time::Duration::from_millis(200), stream.next()).await
            else {
                break;
            };
            let val: serde_json::Value = match msg {
                tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(),
                _ => break,
            };
            if val["type"] != "log_entry" {
                // Unexpected non-log message during drain — this shouldn't happen.
                panic!("unexpected message during log drain: {val}");
            }
        }
        (sink, stream, initial)
    }
    /// Read next non-log_entry text message from the stream with a timeout.
    /// Skips any `log_entry` messages that arrive between events.
    async fn next_msg(
        stream: &mut futures::stream::SplitStream<
            tokio_tungstenite::WebSocketStream<
                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
            >,
        >,
    ) -> serde_json::Value {
        loop {
            let msg = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next())
                .await
                .expect("timeout waiting for message")
                .expect("stream ended")
                .expect("ws error");
            let val: serde_json::Value = match msg {
                tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(),
                other => panic!("expected text message, got: {other:?}"),
            };
            if val["type"] != "log_entry" {
                return val;
            }
        }
    }
    /// Connecting must immediately yield a pipeline_state with all four
    /// stage arrays present.
    #[tokio::test]
    async fn ws_handler_sends_initial_pipeline_state_on_connect() {
        let (url, _ctx) = start_test_server().await;
        let (_sink, _stream, initial) = connect_ws(&url).await;
        assert_eq!(initial["type"], "pipeline_state");
        // Verify stage arrays are present (may contain items from the
        // shared global CRDT store populated by other tests).
        assert!(initial["backlog"].as_array().is_some());
        assert!(initial["current"].as_array().is_some());
        assert!(initial["qa"].as_array().is_some());
        assert!(initial["merge"].as_array().is_some());
    }
    /// Malformed JSON must produce an `error` response, not a disconnect.
    #[tokio::test]
    async fn ws_handler_returns_error_for_invalid_json() {
        let (url, _ctx) = start_test_server().await;
        let (mut sink, mut stream, _initial) = connect_ws(&url).await;
        // Send invalid JSON.
        sink.send(ws_text("not valid json")).await.unwrap();
        let msg = next_msg(&mut stream).await;
        assert_eq!(msg["type"], "error");
        assert!(
            msg["message"].as_str().unwrap().contains("Invalid request"),
            "error message should indicate invalid request, got: {}",
            msg["message"]
        );
    }
    /// Well-formed JSON with an unrecognised `type` must also produce an error.
    #[tokio::test]
    async fn ws_handler_returns_error_for_unknown_type() {
        let (url, _ctx) = start_test_server().await;
        let (mut sink, mut stream, _initial) = connect_ws(&url).await;
        // Send a message with an unknown type.
        sink.send(ws_text(r#"{"type": "bogus"}"#)).await.unwrap();
        let msg = next_msg(&mut stream).await;
        assert_eq!(msg["type"], "error");
        assert!(msg["message"].as_str().unwrap().contains("Invalid request"));
    }
    /// Cancelling when no chat is active must be a no-op that keeps the
    /// connection usable.
    #[tokio::test]
    async fn ws_handler_cancel_outside_chat_does_not_error() {
        let (url, _ctx) = start_test_server().await;
        let (mut sink, mut stream, _initial) = connect_ws(&url).await;
        // Send cancel when no chat is active — should not produce an error.
        sink.send(ws_text(r#"{"type": "cancel"}"#)).await.unwrap();
        // Send another invalid message to check the connection is still alive.
        sink.send(ws_text("{}")).await.unwrap();
        let msg = next_msg(&mut stream).await;
        // The invalid JSON message should produce an error, confirming
        // the cancel didn't break the connection.
        assert_eq!(msg["type"], "error");
    }
    /// Permission responses arriving outside an active chat are ignored
    /// without closing the connection.
    #[tokio::test]
    async fn ws_handler_permission_response_outside_chat_is_ignored() {
        let (url, _ctx) = start_test_server().await;
        let (mut sink, mut stream, _initial) = connect_ws(&url).await;
        // Send permission response outside an active chat.
        sink.send(ws_text(
            r#"{"type": "permission_response", "request_id": "x", "approved": true}"#,
        ))
        .await
        .unwrap();
        // Send a probe message to check the connection is still alive.
        sink.send(ws_text("bad")).await.unwrap();
        let msg = next_msg(&mut stream).await;
        assert_eq!(msg["type"], "error");
        assert!(msg["message"].as_str().unwrap().contains("Invalid request"));
    }
    /// A broadcast watcher work-item event must be forwarded to the client,
    /// followed by a pipeline_state refresh.
    #[tokio::test]
    async fn ws_handler_forwards_watcher_events() {
        let (url, ctx) = start_test_server().await;
        let (_sink, mut stream, _initial) = connect_ws(&url).await;
        // Broadcast a watcher event.
        ctx.watcher_tx
            .send(WatcherEvent::WorkItem {
                stage: "2_current".to_string(),
                item_id: "99_story_test".to_string(),
                action: "start".to_string(),
                commit_msg: "huskies: start 99_story_test".to_string(),
                from_stage: None,
            })
            .unwrap();
        let msg = next_msg(&mut stream).await;
        assert_eq!(msg["type"], "work_item_changed");
        assert_eq!(msg["item_id"], "99_story_test");
        assert_eq!(msg["stage"], "2_current");
        // After a work-item event, a pipeline_state refresh is pushed.
        let state_msg = next_msg(&mut stream).await;
        assert_eq!(state_msg["type"], "pipeline_state");
    }
    /// Config-change events are forwarded as-is, with no pipeline_state
    /// refresh in between.
    #[tokio::test]
    async fn ws_handler_forwards_config_changed_without_pipeline_refresh() {
        let (url, ctx) = start_test_server().await;
        let (_sink, mut stream, _initial) = connect_ws(&url).await;
        // Broadcast a config-changed event.
        ctx.watcher_tx.send(WatcherEvent::ConfigChanged).unwrap();
        let msg = next_msg(&mut stream).await;
        assert_eq!(msg["type"], "agent_config_changed");
        // Config-changed should NOT be followed by a pipeline_state refresh.
        // Send a probe to check no extra message is queued.
        ctx.watcher_tx.send(WatcherEvent::ConfigChanged).unwrap();
        let msg2 = next_msg(&mut stream).await;
        assert_eq!(msg2["type"], "agent_config_changed");
    }
    /// Reconciliation progress events must be forwarded with all fields intact.
    #[tokio::test]
    async fn ws_handler_forwards_reconciliation_events() {
        let (url, ctx) = start_test_server().await;
        let (_sink, mut stream, _initial) = connect_ws(&url).await;
        // Broadcast a reconciliation event.
        ctx.reconciliation_tx
            .send(crate::agents::ReconciliationEvent {
                story_id: "50_story_recon".to_string(),
                status: "checking".to_string(),
                message: "Checking story...".to_string(),
            })
            .unwrap();
        let msg = next_msg(&mut stream).await;
        assert_eq!(msg["type"], "reconciliation_progress");
        assert_eq!(msg["story_id"], "50_story_recon");
        assert_eq!(msg["status"], "checking");
        assert_eq!(msg["message"], "Checking story...");
    }
    /// A client closing its connection must not bring the server down;
    /// subsequent connections must still work.
    #[tokio::test]
    async fn ws_handler_handles_client_disconnect_gracefully() {
        let (url, _ctx) = start_test_server().await;
        let (mut sink, _stream, _initial) = connect_ws(&url).await;
        // Close the connection — should not panic the server.
        sink.close().await.unwrap();
        // Give the server a moment to process the close.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        // Connect again to verify server is still alive.
        let (_sink2, _stream2, initial2) = connect_ws(&url).await;
        assert_eq!(initial2["type"], "pipeline_state");
    }
    /// Read the next `status_update` whose story_id or story_name contains `needle`,
    /// within a timeout. Skips `log_entry` noise and unrelated status events so
    /// genuine server log noise cannot cause false positives or negatives.
    async fn next_status_update_containing(
        stream: &mut futures::stream::SplitStream<
            tokio_tungstenite::WebSocketStream<
                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
            >,
        >,
        needle: &str,
        timeout_ms: u64,
    ) -> Option<serde_json::Value> {
        let deadline = std::time::Instant::now() + std::time::Duration::from_millis(timeout_ms);
        loop {
            // Re-derive the remaining budget so the overall wait never
            // exceeds `timeout_ms`, regardless of how many messages arrive.
            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
            if remaining.is_zero() {
                return None;
            }
            let msg = tokio::time::timeout(remaining, stream.next())
                .await
                .ok()?
                .expect("stream ended")
                .expect("ws error");
            let val: serde_json::Value = match msg {
                tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).ok()?,
                _ => continue,
            };
            if val["type"] == "status_update" {
                let event = &val["event"];
                let story_id = event["story_id"].as_str().unwrap_or("");
                let story_name = event["story_name"].as_str().unwrap_or("");
                if story_id.contains(needle) || story_name.contains(needle) {
                    return Some(val);
                }
            }
            // Skip log_entry and other unrelated messages.
        }
    }
    // ── Status broadcaster integration tests ─────────────────────────
    /// Publishing a status event via `services.status` must result in a
    /// `status_update` WebSocket message with structured fields delivered to the
    /// connected client.
    #[tokio::test]
    async fn ws_handler_forwards_status_events_as_status_update() {
        let (url, ctx) = start_test_server().await;
        let (_sink, mut stream, _initial) = connect_ws(&url).await;
        // Use a story ID unique enough that genuine server logs won't match it.
        ctx.services.status.publish(StatusEvent::StageTransition {
            story_id: "77_story_status_fwd_test".to_string(),
            story_name: Some("StatusFwdTest".to_string()),
            from_stage: "1_backlog".to_string(),
            to_stage: "2_current".to_string(),
        });
        // The handler must forward it as a status_update with structured fields.
        let msg = next_status_update_containing(&mut stream, "StatusFwdTest", 2000)
            .await
            .expect("expected a status_update for the status event");
        assert_eq!(msg["type"], "status_update");
        let event = &msg["event"];
        assert_eq!(event["type"], "stage_transition");
        assert_eq!(event["story_id"], "77_story_status_fwd_test");
        assert_eq!(event["story_name"], "StatusFwdTest");
        assert_eq!(event["from_stage"], "1_backlog");
        assert_eq!(event["to_stage"], "2_current");
    }
    /// Multi-project isolation: a client connected to project A's server must
    /// NOT receive status events published on project B's broadcaster.
    #[tokio::test]
    async fn ws_handler_multi_project_status_isolation() {
        // Start two independent servers (each with its own AppContext / Services).
        let (url_a, ctx_a) = start_test_server().await;
        let (url_b, _ctx_b) = start_test_server().await;
        let (_sink_a, mut stream_a, _) = connect_ws(&url_a).await;
        let (_sink_b, mut stream_b, _) = connect_ws(&url_b).await;
        // Use a needle unique enough that genuine server logs won't match.
        let needle = "ProjAIsolation7734";
        ctx_a.services.status.publish(StatusEvent::MergeFailure {
            story_id: "10_story_proj_a_isolation".to_string(),
            story_name: Some(needle.to_string()),
            reason: "conflict".to_string(),
        });
        // Client A must receive the status_update with structured fields.
        let msg_a = next_status_update_containing(&mut stream_a, needle, 2000)
            .await
            .expect("client A should receive the status event");
        assert_eq!(msg_a["type"], "status_update");
        assert_eq!(msg_a["event"]["story_name"], needle);
        // Client B must NOT receive any status_update containing the needle.
        let msg_b = next_status_update_containing(&mut stream_b, needle, 300).await;
        assert!(
            msg_b.is_none(),
            "client B must not receive project A's status event, got: {msg_b:?}"
        );
    }
    /// When `web_ui_status_consumer = false` in project.toml, the WebSocket
    /// handler must not forward status events to the connected client.
    #[tokio::test]
    async fn ws_handler_status_consumer_disabled_via_config() {
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path().to_path_buf();
        // Write a project.toml that disables the web UI status consumer.
        let huskies_dir = root.join(".huskies");
        std::fs::create_dir_all(&huskies_dir).unwrap();
        std::fs::write(
            huskies_dir.join("project.toml"),
            "web_ui_status_consumer = false\n",
        )
        .unwrap();
        crate::db::ensure_content_store();
        let ctx = Arc::new(AppContext::new_test(root));
        let ctx_data = ctx.clone();
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let app = poem::Route::new()
            .at("/ws", poem::get(ws_handler))
            .data(ctx_data);
        tokio::spawn(async move {
            let acceptor = poem::listener::TcpAcceptor::from_tokio(listener).unwrap();
            let _ = poem::Server::new_with_acceptor(acceptor).run(app).await;
        });
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let url = format!("ws://127.0.0.1:{}/ws", addr.port());
        let (_sink, mut stream, _) = connect_ws(&url).await;
        // Use a unique needle — genuine server logs will never contain this.
        let needle = "DisabledConsumer9182";
        ctx.services.status.publish(StatusEvent::StoryBlocked {
            story_id: "55_story_disabled_consumer".to_string(),
            story_name: Some(needle.to_string()),
            reason: "test".to_string(),
        });
        // Consumer is disabled — no status_update with this needle should arrive.
        let msg = next_status_update_containing(&mut stream, needle, 500).await;
        assert!(
            msg.is_none(),
            "disabled consumer must not forward status events, got: {msg:?}"
        );
    }
}
+275
View File
@@ -0,0 +1,275 @@
//! WebSocket transport adapter — accept connection, serialise/deserialise frames,
//! invoke service methods. No business logic, no inline state transitions.
use crate::config::ProjectConfig;
use crate::http::context::AppContext;
use crate::llm::chat;
use crate::service::ws::{self, WsResponse};
use futures::{SinkExt, StreamExt};
use poem::handler;
use poem::web::Data;
use poem::web::websocket::{Message as WsMessage, WebSocket};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use crate::http::context::PermissionDecision;
// Re-export WizardStepInfo for any downstream code that imports it from here.
#[allow(unused_imports)]
pub use crate::service::ws::WizardStepInfo;
#[handler]
/// WebSocket endpoint for streaming chat responses, cancellation, and
/// filesystem watcher notifications.
///
/// Accepts JSON `WsRequest` messages and streams `WsResponse` messages.
///
/// Lifecycle: upgrade → spawn a single writer ("forward") task → send the
/// initial state burst → register background subscriptions → run the
/// read/dispatch loop until the client disconnects.
pub async fn ws_handler(ws: WebSocket, ctx: Data<&Arc<AppContext>>) -> impl poem::IntoResponse {
    let ctx = ctx.0.clone();
    ws.on_upgrade(move |socket| async move {
        let (mut sink, mut stream) = socket.split();
        // Typed channel: `WsResponse` values are serialized to JSON by the
        // forward task below.
        let (tx, mut rx) = mpsc::unbounded_channel::<WsResponse>();
        // Separate channel for pre-serialized messages (e.g. RPC responses).
        let (raw_tx, mut raw_rx) = mpsc::unbounded_channel::<String>();
        // Single writer task: all outbound traffic funnels through `tx` /
        // `raw_tx` so the sink is never written from two tasks at once.
        // Exits when a socket send fails or either channel closes.
        let forward = tokio::spawn(async move {
            loop {
                tokio::select! {
                    msg = rx.recv() => match msg {
                        Some(msg) => {
                            // A serialization failure silently drops the
                            // message; only a failed socket send ends the task.
                            if let Ok(text) = serde_json::to_string(&msg)
                                && sink.send(WsMessage::Text(text)).await.is_err()
                            {
                                break;
                            }
                        }
                        None => break,
                    },
                    raw = raw_rx.recv() => match raw {
                        Some(text) => {
                            if sink.send(WsMessage::Text(text)).await.is_err() {
                                break;
                            }
                        }
                        None => break,
                    },
                }
            }
        });
        // ── Initial state burst ────────────────────────────────────────
        if let Some(state) = ws::load_initial_pipeline_state(ctx.as_ref()) {
            let _ = tx.send(state);
        }
        let _ = tx.send(ws::check_onboarding(ctx.as_ref()));
        if let Some(wiz) = ws::load_wizard_state(ctx.as_ref()) {
            let _ = tx.send(wiz);
        }
        for log in ws::load_recent_logs(100) {
            let _ = tx.send(log);
        }
        // ── Background subscriptions ───────────────────────────────────
        ws::subscribe_logs(tx.clone());
        ws::subscribe_watcher(tx.clone(), ctx.clone(), ctx.watcher_tx.subscribe());
        ws::subscribe_reconciliation(tx.clone(), ctx.reconciliation_tx.subscribe());
        // Subscribe to the status broadcaster if web UI consumer is enabled (default: true).
        // Missing project root or unreadable config falls back to enabled.
        let status_enabled = ctx
            .state
            .get_project_root()
            .ok()
            .and_then(|root| ProjectConfig::load(&root).ok())
            .map(|c| c.web_ui_status_consumer)
            .unwrap_or(true);
        if status_enabled {
            ws::subscribe_status(tx.clone(), ctx.services.status.subscribe());
        }
        // Map of pending permission request_id -> oneshot responder.
        let mut pending_perms: HashMap<String, oneshot::Sender<PermissionDecision>> =
            HashMap::new();
        loop {
            // Outer loop: wait for the next WebSocket message.
            // Any non-text frame, read error, or stream end closes the session.
            let Some(Ok(WsMessage::Text(text))) = stream.next().await else {
                break;
            };
            // Handle read-RPC frames (discriminated by "kind", not "type").
            if let Some(rpc_resp) = crate::crdt_sync::try_handle_rpc_text(&text) {
                if let Ok(resp_text) = serde_json::to_string(&rpc_resp) {
                    let _ = raw_tx.send(resp_text);
                }
                continue;
            }
            match ws::dispatch_outer(&text) {
                ws::DispatchResult::StartChat { messages, config } => {
                    // One sender clone per callback: each closure is handed to
                    // chat() and must own its channel handle independently.
                    let tx_updates = tx.clone();
                    let tx_tokens = tx.clone();
                    let tx_thinking = tx.clone();
                    let tx_activity = tx.clone();
                    let ctx_clone = ctx.clone();
                    let chat_fut = chat::chat(
                        messages,
                        config,
                        &ctx_clone.state,
                        ctx_clone.store.as_ref(),
                        move |history| {
                            let _ = tx_updates.send(WsResponse::Update {
                                messages: history.to_vec(),
                            });
                        },
                        move |token| {
                            let _ = tx_tokens.send(WsResponse::Token {
                                content: token.to_string(),
                            });
                        },
                        move |thinking: &str| {
                            let _ = tx_thinking.send(WsResponse::ThinkingToken {
                                content: thinking.to_string(),
                            });
                        },
                        move |tool_name: &str| {
                            let _ = tx_activity.send(WsResponse::ToolActivity {
                                tool_name: tool_name.to_string(),
                            });
                        },
                    );
                    // Pin on the stack so the future can be polled repeatedly
                    // across select! iterations without being consumed.
                    tokio::pin!(chat_fut);
                    // The permission-forward receiver lock is held for the
                    // whole chat and released when this arm ends.
                    let mut perm_rx = ctx.services.perm_rx.lock().await;
                    let chat_result = loop {
                        tokio::select! {
                            result = &mut chat_fut => break result,
                            Some(perm_fwd) = perm_rx.recv() => {
                                // Forward the permission prompt to the client
                                // and park the responder until it answers.
                                let _ = tx.send(ws::permission_request_response(
                                    &perm_fwd.request_id,
                                    &perm_fwd.tool_name,
                                    &perm_fwd.tool_input,
                                ));
                                pending_perms.insert(
                                    perm_fwd.request_id,
                                    perm_fwd.response_tx,
                                );
                            }
                            // NOTE(review): non-text frames (or stream end)
                            // fail this pattern and disable the arm for the
                            // current select evaluation; the chat continues.
                            Some(Ok(WsMessage::Text(inner_text))) = stream.next() => {
                                // Handle read-RPC frames during active chat.
                                if let Some(rpc_resp) = crate::crdt_sync::try_handle_rpc_text(&inner_text) {
                                    if let Ok(resp_text) = serde_json::to_string(&rpc_resp) {
                                        let _ = raw_tx.send(resp_text);
                                    }
                                    continue;
                                }
                                match ws::dispatch_inner(&inner_text, &mut pending_perms) {
                                    ws::InnerDispatchResult::CancelChat => {
                                        let _ = chat::cancel_chat(&ctx.state);
                                    }
                                    ws::InnerDispatchResult::Pong => {
                                        let _ = tx.send(WsResponse::Pong);
                                    }
                                    // Side questions run on their own task so
                                    // they do not block the active chat.
                                    ws::InnerDispatchResult::StartSideQuestion { question, context_messages, config } => {
                                        let tx_side = tx.clone();
                                        let store = ctx.store.clone();
                                        tokio::spawn(async move {
                                            let result = chat::side_question(
                                                context_messages,
                                                question,
                                                config,
                                                store.as_ref(),
                                                |token| {
                                                    let _ = tx_side.send(WsResponse::SideQuestionToken {
                                                        content: token.to_string(),
                                                    });
                                                },
                                            ).await;
                                            match result {
                                                Ok(response) => {
                                                    let _ = tx_side.send(WsResponse::SideQuestionDone { response });
                                                }
                                                Err(err) => {
                                                    let _ = tx_side.send(WsResponse::SideQuestionDone {
                                                        response: format!("Error: {err}"),
                                                    });
                                                }
                                            }
                                        });
                                    }
                                    ws::InnerDispatchResult::PermissionResolved
                                    | ws::InnerDispatchResult::Ignored => {}
                                }
                            }
                        }
                    };
                    match chat_result {
                        Ok(chat_result) => {
                            if let Some(sid) = chat_result.session_id {
                                let _ = tx.send(WsResponse::SessionId { session_id: sid });
                            }
                        }
                        Err(err) => {
                            let _ = tx.send(ws::error_response(err));
                        }
                    }
                }
                ws::DispatchResult::CancelChat => {
                    let _ = chat::cancel_chat(&ctx.state);
                }
                ws::DispatchResult::Pong => {
                    let _ = tx.send(WsResponse::Pong);
                }
                ws::DispatchResult::IgnoredPermission => {
                    // Permission responses outside an active chat are ignored.
                }
                ws::DispatchResult::StartSideQuestion {
                    question,
                    context_messages,
                    config,
                } => {
                    let tx_side = tx.clone();
                    let store = ctx.store.clone();
                    tokio::spawn(async move {
                        let result = chat::side_question(
                            context_messages,
                            question,
                            config,
                            store.as_ref(),
                            |token| {
                                let _ = tx_side.send(WsResponse::SideQuestionToken {
                                    content: token.to_string(),
                                });
                            },
                        )
                        .await;
                        match result {
                            Ok(response) => {
                                let _ =
                                    tx_side.send(WsResponse::SideQuestionDone { response });
                            }
                            Err(err) => {
                                let _ = tx_side.send(WsResponse::SideQuestionDone {
                                    response: format!("Error: {err}"),
                                });
                            }
                        }
                    });
                }
                ws::DispatchResult::ParseError(msg) => {
                    let _ = tx.send(ws::error_response(msg));
                }
            }
        }
        // Close the typed channel so the forward task drains and exits,
        // then wait for it to finish before tearing down the connection.
        drop(tx);
        let _ = forward.await;
    })
}
#[cfg(test)]
mod tests;
+466
View File
@@ -0,0 +1,466 @@
//! Integration tests for the WebSocket handler (`ws_handler`).
//!
//! Tests cover: initial state burst on connect, error responses for invalid
//! messages, cancel/permission handling outside chat, watcher event forwarding,
//! reconciliation event forwarding, status broadcaster forwarding, and graceful
//! client disconnect.
use super::*;
use crate::io::watcher::WatcherEvent;
use crate::service::status::StatusEvent;
// ── ws_handler integration tests (real WebSocket connection) ─────
use futures::stream::SplitSink;
use poem::EndpointExt;
use tokio_tungstenite::tungstenite;
/// Build a tungstenite text frame carrying `s` as its payload.
fn ws_text(s: &str) -> tungstenite::Message {
    tungstenite::Message::text(s.to_string())
}
/// Helper: start a poem server with ws_handler on an ephemeral port
/// and return the WebSocket URL.
///
/// Returns the `ws://` URL and the shared `AppContext` so tests can
/// publish events through the same context the handler serves.
async fn start_test_server() -> (String, Arc<AppContext>) {
    let tmp = tempfile::tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    // Ensure CRDT content store is initialised — load_pipeline_state
    // now reads from the in-memory CRDT, not the filesystem.
    crate::db::ensure_content_store();
    let ctx = Arc::new(AppContext::new_test(root));
    let ctx_data = ctx.clone();
    // Bind first so the ephemeral port is known before the server task runs.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    let app = poem::Route::new()
        .at("/ws", poem::get(ws_handler))
        .data(ctx_data);
    tokio::spawn(async move {
        let acceptor = poem::listener::TcpAcceptor::from_tokio(listener).unwrap();
        let _ = poem::Server::new_with_acceptor(acceptor).run(app).await;
    });
    // Small delay to let the server start.
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
    let url = format!("ws://127.0.0.1:{}/ws", addr.port());
    (url, ctx)
}
/// Client-side write half of a test WebSocket connection.
type WsSink = SplitSink<
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
    tungstenite::Message,
>;
/// Helper: connect and return (sink, stream) plus read the initial
/// pipeline_state and onboarding_status messages that are always sent
/// on connect.
///
/// Also drains any buffered `log_entry` messages so callers start from a
/// quiet stream; panics if any other message type arrives during the drain.
async fn connect_ws(
    url: &str,
) -> (
    WsSink,
    futures::stream::SplitStream<
        tokio_tungstenite::WebSocketStream<
            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
        >,
    >,
    serde_json::Value,
) {
    let (ws, _resp) = tokio_tungstenite::connect_async(url).await.unwrap();
    let (sink, mut stream) = futures::StreamExt::split(ws);
    // The first message should be the initial pipeline_state.
    let first = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next())
        .await
        .expect("timeout waiting for initial message")
        .expect("stream ended")
        .expect("ws error");
    let initial: serde_json::Value = match first {
        tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(),
        other => panic!("expected text message, got: {other:?}"),
    };
    // The second message is the onboarding_status — consume it so
    // callers only see application-level messages.
    let second = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next())
        .await
        .expect("timeout waiting for onboarding_status")
        .expect("stream ended")
        .expect("ws error");
    let onboarding: serde_json::Value = match second {
        tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(),
        other => panic!("expected text message, got: {other:?}"),
    };
    assert_eq!(
        onboarding["type"], "onboarding_status",
        "expected onboarding_status, got: {onboarding}"
    );
    // Drain any log_entry messages sent as initial history on connect.
    // These are buffered before tests send their own requests.
    loop {
        // Use a very short timeout: if nothing arrives quickly, the burst is done.
        let Ok(Some(Ok(msg))) =
            tokio::time::timeout(std::time::Duration::from_millis(200), stream.next()).await
        else {
            break;
        };
        let val: serde_json::Value = match msg {
            tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(),
            _ => break,
        };
        if val["type"] != "log_entry" {
            // Unexpected non-log message during drain — this shouldn't happen.
            panic!("unexpected message during log drain: {val}");
        }
    }
    (sink, stream, initial)
}
/// Read next non-log_entry text message from the stream with a timeout.
/// Skips any `log_entry` messages that arrive between events.
async fn next_msg(
    stream: &mut futures::stream::SplitStream<
        tokio_tungstenite::WebSocketStream<
            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
        >,
    >,
) -> serde_json::Value {
    loop {
        // Each frame gets its own 2-second budget; a quiet stream is a failure.
        let frame = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next())
            .await
            .expect("timeout waiting for message")
            .expect("stream ended")
            .expect("ws error");
        let val: serde_json::Value = match frame {
            tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(),
            other => panic!("expected text message, got: {other:?}"),
        };
        // Interleaved log noise is expected; everything else is the answer.
        if val["type"] == "log_entry" {
            continue;
        }
        return val;
    }
}
#[tokio::test]
async fn ws_handler_sends_initial_pipeline_state_on_connect() {
    let (url, _ctx) = start_test_server().await;
    let (_sink, _stream, initial) = connect_ws(&url).await;
    assert_eq!(initial["type"], "pipeline_state");
    // Every stage array must be present; contents may include items from
    // the shared global CRDT store populated by other tests.
    for stage in ["backlog", "current", "qa", "merge"] {
        assert!(initial[stage].as_array().is_some());
    }
}
#[tokio::test]
async fn ws_handler_returns_error_for_invalid_json() {
    let (url, _ctx) = start_test_server().await;
    let (mut sink, mut stream, _initial) = connect_ws(&url).await;
    // A frame that is not JSON at all must be rejected with an error reply.
    sink.send(ws_text("not valid json")).await.unwrap();
    let reply = next_msg(&mut stream).await;
    assert_eq!(reply["type"], "error");
    let text = reply["message"].as_str().unwrap();
    assert!(
        text.contains("Invalid request"),
        "error message should indicate invalid request, got: {}",
        reply["message"]
    );
}
#[tokio::test]
async fn ws_handler_returns_error_for_unknown_type() {
    let (url, _ctx) = start_test_server().await;
    let (mut sink, mut stream, _initial) = connect_ws(&url).await;
    // A structurally valid JSON object with an unrecognised `type` tag must
    // be rejected the same way as malformed JSON.
    sink.send(ws_text(r#"{"type": "bogus"}"#)).await.unwrap();
    let reply = next_msg(&mut stream).await;
    assert_eq!(reply["type"], "error");
    assert!(reply["message"].as_str().unwrap().contains("Invalid request"));
}
#[tokio::test]
async fn ws_handler_cancel_outside_chat_does_not_error() {
    let (url, _ctx) = start_test_server().await;
    let (mut sink, mut stream, _initial) = connect_ws(&url).await;
    // A cancel with no chat in flight must be silently accepted.
    sink.send(ws_text(r#"{"type": "cancel"}"#)).await.unwrap();
    // Follow up with a deliberately invalid payload: if the connection
    // survived the cancel, the server answers this probe with an error.
    sink.send(ws_text("{}")).await.unwrap();
    let reply = next_msg(&mut stream).await;
    assert_eq!(reply["type"], "error");
}
#[tokio::test]
async fn ws_handler_permission_response_outside_chat_is_ignored() {
    let (url, _ctx) = start_test_server().await;
    let (mut sink, mut stream, _initial) = connect_ws(&url).await;
    // A permission_response with no pending request must be dropped silently.
    let permission = r#"{"type": "permission_response", "request_id": "x", "approved": true}"#;
    sink.send(ws_text(permission)).await.unwrap();
    // Probe with garbage: an error reply proves the connection is still alive.
    sink.send(ws_text("bad")).await.unwrap();
    let reply = next_msg(&mut stream).await;
    assert_eq!(reply["type"], "error");
    assert!(reply["message"].as_str().unwrap().contains("Invalid request"));
}
#[tokio::test]
async fn ws_handler_forwards_watcher_events() {
    let (url, ctx) = start_test_server().await;
    let (_sink, mut stream, _initial) = connect_ws(&url).await;
    // Publish a filesystem-watcher work-item event on the server side.
    let event = WatcherEvent::WorkItem {
        stage: "2_current".to_string(),
        item_id: "99_story_test".to_string(),
        action: "start".to_string(),
        commit_msg: "huskies: start 99_story_test".to_string(),
        from_stage: None,
    };
    ctx.watcher_tx.send(event).unwrap();
    // The client first sees the translated work_item_changed message…
    let changed = next_msg(&mut stream).await;
    assert_eq!(changed["type"], "work_item_changed");
    assert_eq!(changed["item_id"], "99_story_test");
    assert_eq!(changed["stage"], "2_current");
    // …followed by a full pipeline_state refresh.
    let refreshed = next_msg(&mut stream).await;
    assert_eq!(refreshed["type"], "pipeline_state");
}
#[tokio::test]
async fn ws_handler_forwards_config_changed_without_pipeline_refresh() {
    let (url, ctx) = start_test_server().await;
    let (_sink, mut stream, _initial) = connect_ws(&url).await;
    // A config-changed event must surface as agent_config_changed.
    ctx.watcher_tx.send(WatcherEvent::ConfigChanged).unwrap();
    assert_eq!(next_msg(&mut stream).await["type"], "agent_config_changed");
    // Unlike work-item events, no pipeline_state refresh may follow. Send a
    // second event: if a refresh had been queued, it would arrive here first.
    ctx.watcher_tx.send(WatcherEvent::ConfigChanged).unwrap();
    assert_eq!(next_msg(&mut stream).await["type"], "agent_config_changed");
}
#[tokio::test]
async fn ws_handler_forwards_reconciliation_events() {
    let (url, ctx) = start_test_server().await;
    let (_sink, mut stream, _initial) = connect_ws(&url).await;
    // Publish a reconciliation progress event on the server side.
    let event = crate::agents::ReconciliationEvent {
        story_id: "50_story_recon".to_string(),
        status: "checking".to_string(),
        message: "Checking story...".to_string(),
    };
    ctx.reconciliation_tx.send(event).unwrap();
    // The client receives it as a structured reconciliation_progress message.
    let progress = next_msg(&mut stream).await;
    assert_eq!(progress["type"], "reconciliation_progress");
    assert_eq!(progress["story_id"], "50_story_recon");
    assert_eq!(progress["status"], "checking");
    assert_eq!(progress["message"], "Checking story...");
}
#[tokio::test]
async fn ws_handler_handles_client_disconnect_gracefully() {
    let (url, _ctx) = start_test_server().await;
    let (mut sink, _stream, _initial) = connect_ws(&url).await;
    // An orderly client close must not take the server down.
    sink.close().await.unwrap();
    // Give the server a moment to process the close.
    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
    // A fresh connection succeeding — with its normal initial burst —
    // proves the server outlived the disconnect.
    let (_sink2, _stream2, initial2) = connect_ws(&url).await;
    assert_eq!(initial2["type"], "pipeline_state");
}
/// Read the next `status_update` whose story_id or story_name contains `needle`,
/// within a timeout. Skips `log_entry` noise and unrelated status events so
/// genuine server log noise cannot cause false positives or negatives.
///
/// Returns `Some(value)` for the first matching `status_update`, or `None`
/// once the overall deadline elapses — callers use `None` both for "event
/// never arrived" failures and for "event must NOT arrive" assertions.
async fn next_status_update_containing(
    stream: &mut futures::stream::SplitStream<
        tokio_tungstenite::WebSocketStream<
            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
        >,
    >,
    needle: &str,
    timeout_ms: u64,
) -> Option<serde_json::Value> {
    // One overall deadline for the whole wait; each recv gets the remainder.
    let deadline = std::time::Instant::now() + std::time::Duration::from_millis(timeout_ms);
    loop {
        let remaining = deadline.saturating_duration_since(std::time::Instant::now());
        if remaining.is_zero() {
            return None;
        }
        let msg = tokio::time::timeout(remaining, stream.next())
            .await
            .ok()? // deadline reached with no matching event
            .expect("stream ended")
            .expect("ws error");
        // Non-text frames (ping/pong/binary) are noise for this helper.
        let text = match msg {
            tungstenite::Message::Text(t) => t,
            _ => continue,
        };
        // BUGFIX: this previously used `.ok()?` on the parse, so a single
        // unparseable text frame aborted the entire wait with `None` — which
        // could spuriously fail a "must arrive" assertion (or spuriously pass
        // a "must NOT arrive" one). Treat unparseable frames as noise and
        // keep waiting, consistent with the non-text arm above.
        let Ok(val) = serde_json::from_str::<serde_json::Value>(text.as_ref()) else {
            continue;
        };
        if val["type"] != "status_update" {
            // Skip log_entry and other unrelated messages.
            continue;
        }
        let event = &val["event"];
        let story_id = event["story_id"].as_str().unwrap_or("");
        let story_name = event["story_name"].as_str().unwrap_or("");
        if story_id.contains(needle) || story_name.contains(needle) {
            return Some(val);
        }
    }
}
// ── Status broadcaster integration tests ─────────────────────────
/// Publishing a status event via `services.status` must result in a
/// `status_update` WebSocket message with structured fields delivered to the
/// connected client.
#[tokio::test]
async fn ws_handler_forwards_status_events_as_status_update() {
    let (url, ctx) = start_test_server().await;
    let (_sink, mut stream, _initial) = connect_ws(&url).await;
    // Use a story ID unique enough that genuine server logs won't match it.
    let published = StatusEvent::StageTransition {
        story_id: "77_story_status_fwd_test".to_string(),
        story_name: Some("StatusFwdTest".to_string()),
        from_stage: "1_backlog".to_string(),
        to_stage: "2_current".to_string(),
    };
    ctx.services.status.publish(published);
    // The handler must forward it as a status_update with structured fields.
    let msg = next_status_update_containing(&mut stream, "StatusFwdTest", 2000)
        .await
        .expect("expected a status_update for the status event");
    assert_eq!(msg["type"], "status_update");
    let event = &msg["event"];
    assert_eq!(event["type"], "stage_transition");
    assert_eq!(event["story_id"], "77_story_status_fwd_test");
    assert_eq!(event["story_name"], "StatusFwdTest");
    assert_eq!(event["from_stage"], "1_backlog");
    assert_eq!(event["to_stage"], "2_current");
}
/// Multi-project isolation: a client connected to project A's server must
/// NOT receive status events published on project B's broadcaster.
#[tokio::test]
async fn ws_handler_multi_project_status_isolation() {
    // Two fully independent servers, each with its own AppContext / Services.
    let (url_a, ctx_a) = start_test_server().await;
    let (url_b, _ctx_b) = start_test_server().await;
    let (_sink_a, mut stream_a, _) = connect_ws(&url_a).await;
    let (_sink_b, mut stream_b, _) = connect_ws(&url_b).await;
    // Use a needle unique enough that genuine server logs won't match.
    let needle = "ProjAIsolation7734";
    let event = StatusEvent::MergeFailure {
        story_id: "10_story_proj_a_isolation".to_string(),
        story_name: Some(needle.to_string()),
        reason: "conflict".to_string(),
    };
    ctx_a.services.status.publish(event);
    // Client A must receive the status_update with structured fields.
    let received = next_status_update_containing(&mut stream_a, needle, 2000)
        .await
        .expect("client A should receive the status event");
    assert_eq!(received["type"], "status_update");
    assert_eq!(received["event"]["story_name"], needle);
    // Client B must NOT receive any status_update containing the needle.
    let msg_b = next_status_update_containing(&mut stream_b, needle, 300).await;
    assert!(
        msg_b.is_none(),
        "client B must not receive project A's status event, got: {msg_b:?}"
    );
}
/// When `web_ui_status_consumer = false` in project.toml, the WebSocket
/// handler must not forward status events to the connected client.
#[tokio::test]
async fn ws_handler_status_consumer_disabled_via_config() {
    let tmp = tempfile::tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    // Opt out of the web UI status consumer via project.toml BEFORE the
    // context is created, so the handler observes the setting on startup.
    let huskies_dir = root.join(".huskies");
    std::fs::create_dir_all(&huskies_dir).unwrap();
    std::fs::write(
        huskies_dir.join("project.toml"),
        "web_ui_status_consumer = false\n",
    )
    .unwrap();
    // Boot a dedicated server over this pre-configured root; the shared
    // start_test_server helper would not pick up the custom project.toml.
    crate::db::ensure_content_store();
    let ctx = Arc::new(AppContext::new_test(root));
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    let app = poem::Route::new()
        .at("/ws", poem::get(ws_handler))
        .data(ctx.clone());
    tokio::spawn(async move {
        let acceptor = poem::listener::TcpAcceptor::from_tokio(listener).unwrap();
        let _ = poem::Server::new_with_acceptor(acceptor).run(app).await;
    });
    // Small delay to let the server start.
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
    let url = format!("ws://127.0.0.1:{}/ws", port);
    let (_sink, mut stream, _) = connect_ws(&url).await;
    // Use a unique needle — genuine server logs will never contain this.
    let needle = "DisabledConsumer9182";
    ctx.services.status.publish(StatusEvent::StoryBlocked {
        story_id: "55_story_disabled_consumer".to_string(),
        story_name: Some(needle.to_string()),
        reason: "test".to_string(),
    });
    // Consumer is disabled — no status_update with this needle should arrive.
    let msg = next_status_update_containing(&mut stream, needle, 500).await;
    assert!(
        msg.is_none(),
        "disabled consumer must not forward status events, got: {msg:?}"
    );
}