Files
huskies/server/src/http/ws.rs
T

538 lines
22 KiB
Rust
Raw Normal View History

2026-04-24 14:32:36 +00:00
//! WebSocket transport adapter — accept connection, serialise/deserialise frames,
//! invoke service methods. No business logic, no inline state transitions.
use crate::http::context::AppContext;
use crate::llm::chat;
2026-04-24 14:32:36 +00:00
use crate::service::ws::{self, WsResponse};
use futures::{SinkExt, StreamExt};
use poem::handler;
use poem::web::Data;
use poem::web::websocket::{Message as WsMessage, WebSocket};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
2026-04-24 14:32:36 +00:00
use crate::http::context::PermissionDecision;
2026-04-24 14:32:36 +00:00
// Re-export WizardStepInfo for any downstream code that imports it from here.
#[allow(unused_imports)]
pub use crate::service::ws::WizardStepInfo;
#[handler]
/// WebSocket endpoint for streaming chat responses, cancellation, and
/// filesystem watcher notifications.
///
/// Accepts JSON `WsRequest` messages and streams `WsResponse` messages.
pub async fn ws_handler(ws: WebSocket, ctx: Data<&Arc<AppContext>>) -> impl poem::IntoResponse {
let ctx = ctx.0.clone();
ws.on_upgrade(move |socket| async move {
let (mut sink, mut stream) = socket.split();
let (tx, mut rx) = mpsc::unbounded_channel::<WsResponse>();
let forward = tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
if let Ok(text) = serde_json::to_string(&msg)
&& sink.send(WsMessage::Text(text)).await.is_err()
{
break;
}
}
});
2026-04-24 14:32:36 +00:00
// ── Initial state burst ────────────────────────────────────
if let Some(state) = ws::load_initial_pipeline_state(ctx.as_ref()) {
let _ = tx.send(state);
}
2026-04-24 14:32:36 +00:00
let _ = tx.send(ws::check_onboarding(ctx.as_ref()));
if let Some(wiz) = ws::load_wizard_state(ctx.as_ref()) {
let _ = tx.send(wiz);
}
2026-04-24 14:32:36 +00:00
for log in ws::load_recent_logs(100) {
let _ = tx.send(log);
}
2026-04-24 14:32:36 +00:00
// ── Background subscriptions ───────────────────────────────
ws::subscribe_logs(tx.clone());
ws::subscribe_watcher(tx.clone(), ctx.clone(), ctx.watcher_tx.subscribe());
ws::subscribe_reconciliation(tx.clone(), ctx.reconciliation_tx.subscribe());
2026-04-24 14:32:36 +00:00
// Map of pending permission request_id -> oneshot responder.
let mut pending_perms: HashMap<String, oneshot::Sender<PermissionDecision>> =
HashMap::new();
loop {
// Outer loop: wait for the next WebSocket message.
let Some(Ok(WsMessage::Text(text))) = stream.next().await else {
break;
};
2026-04-24 14:32:36 +00:00
match ws::dispatch_outer(&text) {
ws::DispatchResult::StartChat { messages, config } => {
let tx_updates = tx.clone();
let tx_tokens = tx.clone();
let tx_thinking = tx.clone();
let tx_activity = tx.clone();
let ctx_clone = ctx.clone();
let chat_fut = chat::chat(
messages,
config,
&ctx_clone.state,
ctx_clone.store.as_ref(),
move |history| {
let _ = tx_updates.send(WsResponse::Update {
messages: history.to_vec(),
});
},
move |token| {
let _ = tx_tokens.send(WsResponse::Token {
content: token.to_string(),
});
},
move |thinking: &str| {
let _ = tx_thinking.send(WsResponse::ThinkingToken {
content: thinking.to_string(),
});
},
move |tool_name: &str| {
let _ = tx_activity.send(WsResponse::ToolActivity {
tool_name: tool_name.to_string(),
});
},
);
tokio::pin!(chat_fut);
let mut perm_rx = ctx.services.perm_rx.lock().await;
let chat_result = loop {
tokio::select! {
result = &mut chat_fut => break result,
Some(perm_fwd) = perm_rx.recv() => {
2026-04-24 14:32:36 +00:00
let _ = tx.send(ws::permission_request_response(
&perm_fwd.request_id,
&perm_fwd.tool_name,
&perm_fwd.tool_input,
));
pending_perms.insert(
perm_fwd.request_id,
perm_fwd.response_tx,
);
}
Some(Ok(WsMessage::Text(inner_text))) = stream.next() => {
2026-04-24 14:32:36 +00:00
match ws::dispatch_inner(&inner_text, &mut pending_perms) {
ws::InnerDispatchResult::CancelChat => {
let _ = chat::cancel_chat(&ctx.state);
}
2026-04-24 14:32:36 +00:00
ws::InnerDispatchResult::Pong => {
let _ = tx.send(WsResponse::Pong);
}
2026-04-24 14:32:36 +00:00
ws::InnerDispatchResult::StartSideQuestion { question, context_messages, config } => {
let tx_side = tx.clone();
let store = ctx.store.clone();
tokio::spawn(async move {
let result = chat::side_question(
context_messages,
question,
config,
store.as_ref(),
|token| {
let _ = tx_side.send(WsResponse::SideQuestionToken {
content: token.to_string(),
});
},
).await;
match result {
Ok(response) => {
let _ = tx_side.send(WsResponse::SideQuestionDone { response });
}
Err(err) => {
let _ = tx_side.send(WsResponse::SideQuestionDone {
response: format!("Error: {err}"),
});
}
}
});
}
2026-04-24 14:32:36 +00:00
ws::InnerDispatchResult::PermissionResolved
| ws::InnerDispatchResult::Ignored => {}
}
}
}
};
match chat_result {
Ok(chat_result) => {
if let Some(sid) = chat_result.session_id {
let _ = tx.send(WsResponse::SessionId { session_id: sid });
}
}
Err(err) => {
2026-04-24 14:32:36 +00:00
let _ = tx.send(ws::error_response(err));
}
}
}
2026-04-24 14:32:36 +00:00
ws::DispatchResult::CancelChat => {
let _ = chat::cancel_chat(&ctx.state);
}
2026-04-24 14:32:36 +00:00
ws::DispatchResult::Pong => {
let _ = tx.send(WsResponse::Pong);
}
2026-04-24 14:32:36 +00:00
ws::DispatchResult::IgnoredPermission => {
// Permission responses outside an active chat are ignored.
}
2026-04-24 14:32:36 +00:00
ws::DispatchResult::StartSideQuestion {
question,
context_messages,
config,
2026-04-24 14:32:36 +00:00
} => {
let tx_side = tx.clone();
let store = ctx.store.clone();
tokio::spawn(async move {
let result = chat::side_question(
context_messages,
question,
config,
store.as_ref(),
|token| {
let _ = tx_side.send(WsResponse::SideQuestionToken {
content: token.to_string(),
});
},
)
.await;
match result {
Ok(response) => {
2026-04-24 14:32:36 +00:00
let _ =
tx_side.send(WsResponse::SideQuestionDone { response });
}
Err(err) => {
let _ = tx_side.send(WsResponse::SideQuestionDone {
response: format!("Error: {err}"),
});
}
}
});
}
2026-04-24 14:32:36 +00:00
ws::DispatchResult::ParseError(msg) => {
let _ = tx.send(ws::error_response(msg));
}
}
}
drop(tx);
let _ = forward.await;
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::watcher::WatcherEvent;
    // ── ws_handler integration tests (real WebSocket connection) ─────
    use futures::stream::SplitSink;
    use poem::EndpointExt;
    use tokio_tungstenite::tungstenite;

    type WsConn = tokio_tungstenite::WebSocketStream<
        tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
    >;
    type WsSink = SplitSink<WsConn, tungstenite::Message>;
    type WsStream = futures::stream::SplitStream<WsConn>;

    /// Helper: construct a tungstenite text message from a string.
    fn ws_text(s: &str) -> tungstenite::Message {
        tungstenite::Message::Text(s.into())
    }

    /// Helper: start a poem server with ws_handler on an ephemeral port.
    ///
    /// Returns the WebSocket URL, the shared context, and the guard for the
    /// temporary project directory. Callers must keep the guard alive (bind
    /// it to a named variable such as `_tmp`, not `_`). Dropping a `TempDir`
    /// deletes the directory the context's root points at.
    async fn start_test_server() -> (String, Arc<AppContext>, tempfile::TempDir) {
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path().to_path_buf();
        // Ensure CRDT content store is initialised — load_pipeline_state
        // now reads from the in-memory CRDT, not the filesystem.
        crate::db::ensure_content_store();
        let ctx = Arc::new(AppContext::new_test(root));
        let ctx_data = ctx.clone();
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let app = poem::Route::new()
            .at("/ws", poem::get(ws_handler))
            .data(ctx_data);
        tokio::spawn(async move {
            let acceptor = poem::listener::TcpAcceptor::from_tokio(listener).unwrap();
            let _ = poem::Server::new_with_acceptor(acceptor).run(app).await;
        });
        // Small delay to let the server start.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let url = format!("ws://127.0.0.1:{}/ws", addr.port());
        (url, ctx, tmp)
    }

    /// Helper: connect and return (sink, stream) plus read the initial
    /// pipeline_state and onboarding_status messages that are always sent
    /// on connect.
    async fn connect_ws(url: &str) -> (WsSink, WsStream, serde_json::Value) {
        let (ws, _resp) = tokio_tungstenite::connect_async(url).await.unwrap();
        let (sink, mut stream) = futures::StreamExt::split(ws);
        // The first message should be the initial pipeline_state.
        let first = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next())
            .await
            .expect("timeout waiting for initial message")
            .expect("stream ended")
            .expect("ws error");
        let initial: serde_json::Value = match first {
            tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(),
            other => panic!("expected text message, got: {other:?}"),
        };
        // The second message is the onboarding_status — consume it so
        // callers only see application-level messages.
        let second = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next())
            .await
            .expect("timeout waiting for onboarding_status")
            .expect("stream ended")
            .expect("ws error");
        let onboarding: serde_json::Value = match second {
            tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(),
            other => panic!("expected text message, got: {other:?}"),
        };
        assert_eq!(
            onboarding["type"], "onboarding_status",
            "expected onboarding_status, got: {onboarding}"
        );
        // Drain any log_entry messages sent as initial history on connect.
        // These are buffered before tests send their own requests.
        loop {
            // Use a very short timeout: if nothing arrives quickly, the burst is done.
            let Ok(Some(Ok(msg))) =
                tokio::time::timeout(std::time::Duration::from_millis(200), stream.next()).await
            else {
                break;
            };
            let val: serde_json::Value = match msg {
                tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(),
                _ => break,
            };
            if val["type"] != "log_entry" {
                // Unexpected non-log message during drain — this shouldn't happen.
                panic!("unexpected message during log drain: {val}");
            }
        }
        (sink, stream, initial)
    }

    /// Read next non-log_entry text message from the stream with a timeout.
    /// Skips any `log_entry` messages that arrive between events.
    async fn next_msg(stream: &mut WsStream) -> serde_json::Value {
        loop {
            let msg = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next())
                .await
                .expect("timeout waiting for message")
                .expect("stream ended")
                .expect("ws error");
            let val: serde_json::Value = match msg {
                tungstenite::Message::Text(t) => serde_json::from_str(t.as_ref()).unwrap(),
                other => panic!("expected text message, got: {other:?}"),
            };
            if val["type"] != "log_entry" {
                return val;
            }
        }
    }

    #[tokio::test]
    async fn ws_handler_sends_initial_pipeline_state_on_connect() {
        let (url, _ctx, _tmp) = start_test_server().await;
        let (_sink, _stream, initial) = connect_ws(&url).await;
        assert_eq!(initial["type"], "pipeline_state");
        // Verify stage arrays are present (may contain items from the
        // shared global CRDT store populated by other tests).
        assert!(initial["backlog"].as_array().is_some());
        assert!(initial["current"].as_array().is_some());
        assert!(initial["qa"].as_array().is_some());
        assert!(initial["merge"].as_array().is_some());
    }

    #[tokio::test]
    async fn ws_handler_returns_error_for_invalid_json() {
        let (url, _ctx, _tmp) = start_test_server().await;
        let (mut sink, mut stream, _initial) = connect_ws(&url).await;
        // Send invalid JSON.
        sink.send(ws_text("not valid json")).await.unwrap();
        let msg = next_msg(&mut stream).await;
        assert_eq!(msg["type"], "error");
        assert!(
            msg["message"].as_str().unwrap().contains("Invalid request"),
            "error message should indicate invalid request, got: {}",
            msg["message"]
        );
    }

    #[tokio::test]
    async fn ws_handler_returns_error_for_unknown_type() {
        let (url, _ctx, _tmp) = start_test_server().await;
        let (mut sink, mut stream, _initial) = connect_ws(&url).await;
        // Send a message with an unknown type.
        sink.send(ws_text(r#"{"type": "bogus"}"#)).await.unwrap();
        let msg = next_msg(&mut stream).await;
        assert_eq!(msg["type"], "error");
        assert!(msg["message"].as_str().unwrap().contains("Invalid request"));
    }

    #[tokio::test]
    async fn ws_handler_cancel_outside_chat_does_not_error() {
        let (url, _ctx, _tmp) = start_test_server().await;
        let (mut sink, mut stream, _initial) = connect_ws(&url).await;
        // Send cancel when no chat is active — should not produce an error.
        sink.send(ws_text(r#"{"type": "cancel"}"#)).await.unwrap();
        // Probe with a request that has no `type` field (valid JSON, invalid
        // request) to check the connection is still alive.
        sink.send(ws_text("{}")).await.unwrap();
        let msg = next_msg(&mut stream).await;
        // The invalid request should produce an error, confirming the cancel
        // didn't break the connection.
        assert_eq!(msg["type"], "error");
    }

    #[tokio::test]
    async fn ws_handler_permission_response_outside_chat_is_ignored() {
        let (url, _ctx, _tmp) = start_test_server().await;
        let (mut sink, mut stream, _initial) = connect_ws(&url).await;
        // Send permission response outside an active chat.
        sink.send(ws_text(
            r#"{"type": "permission_response", "request_id": "x", "approved": true}"#,
        ))
        .await
        .unwrap();
        // Send a probe message to check the connection is still alive.
        sink.send(ws_text("bad")).await.unwrap();
        let msg = next_msg(&mut stream).await;
        assert_eq!(msg["type"], "error");
        assert!(msg["message"].as_str().unwrap().contains("Invalid request"));
    }

    #[tokio::test]
    async fn ws_handler_ignores_non_text_frames() {
        let (url, _ctx, _tmp) = start_test_server().await;
        let (mut sink, mut stream, _initial) = connect_ws(&url).await;
        // A binary frame carries no request and must not end the session.
        sink.send(tungstenite::Message::Binary(vec![1u8, 2, 3].into()))
            .await
            .unwrap();
        // Probe: the session should still answer an invalid request.
        sink.send(ws_text("bad")).await.unwrap();
        let msg = next_msg(&mut stream).await;
        assert_eq!(msg["type"], "error");
        assert!(msg["message"].as_str().unwrap().contains("Invalid request"));
    }

    #[tokio::test]
    async fn ws_handler_forwards_watcher_events() {
        let (url, ctx, _tmp) = start_test_server().await;
        let (_sink, mut stream, _initial) = connect_ws(&url).await;
        // Broadcast a watcher event.
        ctx.watcher_tx
            .send(WatcherEvent::WorkItem {
                stage: "2_current".to_string(),
                item_id: "99_story_test".to_string(),
                action: "start".to_string(),
                commit_msg: "huskies: start 99_story_test".to_string(),
                from_stage: None,
            })
            .unwrap();
        let msg = next_msg(&mut stream).await;
        assert_eq!(msg["type"], "work_item_changed");
        assert_eq!(msg["item_id"], "99_story_test");
        assert_eq!(msg["stage"], "2_current");
        // After a work-item event, a pipeline_state refresh is pushed.
        let state_msg = next_msg(&mut stream).await;
        assert_eq!(state_msg["type"], "pipeline_state");
    }

    #[tokio::test]
    async fn ws_handler_forwards_config_changed_without_pipeline_refresh() {
        let (url, ctx, _tmp) = start_test_server().await;
        let (_sink, mut stream, _initial) = connect_ws(&url).await;
        // Broadcast a config-changed event.
        ctx.watcher_tx.send(WatcherEvent::ConfigChanged).unwrap();
        let msg = next_msg(&mut stream).await;
        assert_eq!(msg["type"], "agent_config_changed");
        // Config-changed should NOT be followed by a pipeline_state refresh.
        // Send a second config-changed event as a probe: had a refresh been
        // queued, it would arrive before this one.
        ctx.watcher_tx.send(WatcherEvent::ConfigChanged).unwrap();
        let msg2 = next_msg(&mut stream).await;
        assert_eq!(msg2["type"], "agent_config_changed");
    }

    #[tokio::test]
    async fn ws_handler_forwards_reconciliation_events() {
        let (url, ctx, _tmp) = start_test_server().await;
        let (_sink, mut stream, _initial) = connect_ws(&url).await;
        // Broadcast a reconciliation event.
        ctx.reconciliation_tx
            .send(crate::agents::ReconciliationEvent {
                story_id: "50_story_recon".to_string(),
                status: "checking".to_string(),
                message: "Checking story...".to_string(),
            })
            .unwrap();
        let msg = next_msg(&mut stream).await;
        assert_eq!(msg["type"], "reconciliation_progress");
        assert_eq!(msg["story_id"], "50_story_recon");
        assert_eq!(msg["status"], "checking");
        assert_eq!(msg["message"], "Checking story...");
    }

    #[tokio::test]
    async fn ws_handler_handles_client_disconnect_gracefully() {
        let (url, _ctx, _tmp) = start_test_server().await;
        let (mut sink, _stream, _initial) = connect_ws(&url).await;
        // Close the connection — should not panic the server.
        sink.close().await.unwrap();
        // Give the server a moment to process the close.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        // Connect again to verify server is still alive.
        let (_sink2, _stream2, initial2) = connect_ws(&url).await;
        assert_eq!(initial2["type"], "pipeline_state");
    }
}