From 396a47d7c20e03e6116fbafb999a4af8b4f05dfe Mon Sep 17 00:00:00 2001 From: dave Date: Wed, 13 May 2026 10:03:25 +0000 Subject: [PATCH] huskies: merge 957 --- .../src/chat/transport/matrix/bot/context.rs | 172 +++++++++++++++--- 1 file changed, 148 insertions(+), 24 deletions(-) diff --git a/server/src/chat/transport/matrix/bot/context.rs b/server/src/chat/transport/matrix/bot/context.rs index e475629b..a3c29085 100644 --- a/server/src/chat/transport/matrix/bot/context.rs +++ b/server/src/chat/transport/matrix/bot/context.rs @@ -48,7 +48,7 @@ pub struct BotContext { /// Empty in standalone mode. pub gateway_projects: Vec, /// In gateway mode: mapping of project name → base URL (e.g. `"http://localhost:3001"`). - /// Used to proxy bot commands to the active project's `/api/bot/command` endpoint. + /// Used to proxy bot commands to the active project over WebSocket (`/ws`). /// Empty in standalone mode. pub gateway_project_urls: BTreeMap, } @@ -81,35 +81,99 @@ impl BotContext { self.gateway_project_urls.get(&name).cloned() } - /// Proxy a bot command to the active project's `/api/bot/command` endpoint. + /// Proxy a bot command to the active project over a WebSocket RPC call. /// - /// Returns the Markdown response from the project server, or an error - /// message if the request failed. + /// Connects to `{base_url}/ws`, sends an `rpc_request` frame for the + /// `bot.command` method, and returns the Markdown response from the + /// `rpc_response` frame. Returns an error message string if the + /// connection or command fails. pub async fn proxy_bot_command(&self, command: &str, args: &str) -> Option { + use futures::{SinkExt, StreamExt}; + use tokio_tungstenite::tungstenite::Message as WsMsg; + let base_url = self.active_project_url().await?; - let url = format!("{base_url}/api/bot/command"); - let client = reqwest::Client::new(); - let body = serde_json::json!({ - "command": command, - "args": args, + + // Convert http(s):// → ws(s):// + let ws_base = if let Some(rest) = base_url.strip_prefix("https://") { + format!("wss://{rest}") + } else if let Some(rest) = base_url.strip_prefix("http://") { + format!("ws://{rest}") + } else { + base_url.clone() + }; + let ws_url = format!("{ws_base}/ws"); + + let correlation_id = uuid::Uuid::new_v4().to_string(); + let request = serde_json::json!({ + "kind": "rpc_request", + "version": 1, + "correlation_id": correlation_id, + "ttl_ms": 30_000u64, + "method": "bot.command", + "params": { "command": command, "args": args }, }); - match client.post(&url).json(&body).send().await { - Ok(resp) if resp.status().is_success() => { - match resp.json::().await { - Ok(json) => json - .get("response") - .and_then(|v| v.as_str()) - .map(String::from), - Err(e) => Some(format!("Failed to parse response from project server: {e}")), - } + let request_text = match serde_json::to_string(&request) { + Ok(t) => t, + Err(e) => return Some(format!("Failed to serialize RPC request: {e}")), + }; + + let ws_stream = match tokio_tungstenite::connect_async(&ws_url).await { + Ok((stream, _)) => stream, + Err(e) => { + return Some(format!( + "Failed to connect to project server at {ws_url}: {e}" + )); } - Ok(resp) => Some(format!( - "Project server returned HTTP {}: {}", - resp.status(), - resp.text().await.unwrap_or_default() - )), - Err(e) => Some(format!("Failed to reach project server at {url}: {e}")), + }; + + let (mut sink, mut stream) = ws_stream.split(); + + if let Err(e) = sink.send(WsMsg::Text(request_text.into())).await { + return Some(format!("Failed to send RPC request: {e}")); } + + while let Some(msg) = stream.next().await { + match msg { + Ok(WsMsg::Text(text)) => { + let Ok(frame) = serde_json::from_str::(&text) else { + continue; + }; + if frame.get("kind").and_then(|v| v.as_str()) != Some("rpc_response") { + continue; + } + if frame + .get("correlation_id") + .and_then(|v| v.as_str()) + .map(|id| id != correlation_id) + .unwrap_or(true) + { + continue; + } + let ok = frame.get("ok").and_then(|v| v.as_bool()).unwrap_or(false); + if ok { + return frame + .get("result") + .and_then(|r| r.get("response")) + .and_then(|v| v.as_str()) + .map(String::from) + .or_else(|| { + Some("Command succeeded with no response text".to_string()) + }); + } else { + let err = frame + .get("error") + .and_then(|v| v.as_str()) + .unwrap_or("unknown error"); + return Some(format!("Project server command failed: {err}")); + } + } + Ok(WsMsg::Close(_)) => break, + Err(e) => return Some(format!("WebSocket error: {e}")), + _ => continue, + } + } + + Some("Connection closed before receiving command response".to_string()) } } @@ -244,4 +308,64 @@ mod tests { let ctx = test_bot_context(services, None, vec![], BTreeMap::new()); let _cloned = ctx.clone(); } + + /// A bot command issued in gateway mode must round-trip over WebSocket + /// (using the `bot.command` RPC method) and must NOT use HTTP transport. + #[tokio::test] + async fn proxy_bot_command_uses_websocket_not_http() { + use futures::{SinkExt, StreamExt}; + use tokio::net::TcpListener; + use tokio_tungstenite::tungstenite::Message as WsMsg; + + // Bind an ephemeral port for our mock WebSocket server. + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + + // Spawn a minimal WS server: accept one connection, verify the + // request uses the `bot.command` RPC method (not HTTP), and reply. + let server = tokio::spawn(async move { + let (tcp, _addr) = listener.accept().await.unwrap(); + let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap(); + while let Some(Ok(msg)) = ws.next().await { + if let WsMsg::Text(text) = msg { + let req: serde_json::Value = + serde_json::from_str(&text).expect("valid JSON from proxy"); + assert_eq!( + req["kind"], "rpc_request", + "transport must use rpc_request, not HTTP" + ); + assert_eq!(req["method"], "bot.command"); + assert_eq!(req["params"]["command"], "status"); + let correlation_id = req["correlation_id"].clone(); + let resp = serde_json::json!({ + "kind": "rpc_response", + "correlation_id": correlation_id, + "ok": true, + "result": { "response": "all systems go" }, + }); + ws.send(WsMsg::Text(resp.to_string().into())).await.unwrap(); + break; + } + } + }); + + let base_url = format!("http://127.0.0.1:{port}"); + let services = test_services(PathBuf::from("/gateway")); + let active = Arc::new(RwLock::new("huskies".to_string())); + let ctx = test_bot_context( + services, + Some(Arc::clone(&active)), + vec!["huskies".into()], + BTreeMap::from([("huskies".into(), base_url)]), + ); + + let result = ctx.proxy_bot_command("status", "").await; + assert_eq!( + result.as_deref(), + Some("all systems go"), + "proxy must return the response text from the rpc_response frame" + ); + + server.await.unwrap(); + } }