huskies: merge 957

This commit is contained in:
dave
2026-05-13 10:03:25 +00:00
parent 765d54fc4b
commit 396a47d7c2
+148 -24
View File
@@ -48,7 +48,7 @@ pub struct BotContext {
/// Empty in standalone mode.
pub gateway_projects: Vec<String>,
/// In gateway mode: mapping of project name → base URL (e.g. `"http://localhost:3001"`).
/// Used to proxy bot commands to the active project's `/api/bot/command` endpoint.
/// Used to proxy bot commands to the active project over WebSocket (`/ws`).
/// Empty in standalone mode.
pub gateway_project_urls: BTreeMap<String, String>,
}
@@ -81,35 +81,99 @@ impl BotContext {
self.gateway_project_urls.get(&name).cloned()
}
/// Proxy a bot command to the active project's `/api/bot/command` endpoint.
/// Proxy a bot command to the active project over a WebSocket RPC call.
///
/// Returns the Markdown response from the project server, or an error
/// message if the request failed.
/// Connects to `{base_url}/ws`, sends an `rpc_request` frame for the
/// `bot.command` method, and returns the Markdown response from the
/// `rpc_response` frame. Returns an error message string if the
/// connection or command fails.
pub async fn proxy_bot_command(&self, command: &str, args: &str) -> Option<String> {
use futures::{SinkExt, StreamExt};
use tokio_tungstenite::tungstenite::Message as WsMsg;
let base_url = self.active_project_url().await?;
let url = format!("{base_url}/api/bot/command");
let client = reqwest::Client::new();
let body = serde_json::json!({
"command": command,
"args": args,
// Convert http(s):// → ws(s)://
let ws_base = if let Some(rest) = base_url.strip_prefix("https://") {
format!("wss://{rest}")
} else if let Some(rest) = base_url.strip_prefix("http://") {
format!("ws://{rest}")
} else {
base_url.clone()
};
let ws_url = format!("{ws_base}/ws");
let correlation_id = uuid::Uuid::new_v4().to_string();
let request = serde_json::json!({
"kind": "rpc_request",
"version": 1,
"correlation_id": correlation_id,
"ttl_ms": 30_000u64,
"method": "bot.command",
"params": { "command": command, "args": args },
});
match client.post(&url).json(&body).send().await {
Ok(resp) if resp.status().is_success() => {
match resp.json::<serde_json::Value>().await {
Ok(json) => json
.get("response")
.and_then(|v| v.as_str())
.map(String::from),
Err(e) => Some(format!("Failed to parse response from project server: {e}")),
}
let request_text = match serde_json::to_string(&request) {
Ok(t) => t,
Err(e) => return Some(format!("Failed to serialize RPC request: {e}")),
};
let ws_stream = match tokio_tungstenite::connect_async(&ws_url).await {
Ok((stream, _)) => stream,
Err(e) => {
return Some(format!(
"Failed to connect to project server at {ws_url}: {e}"
));
}
Ok(resp) => Some(format!(
"Project server returned HTTP {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
)),
Err(e) => Some(format!("Failed to reach project server at {url}: {e}")),
};
let (mut sink, mut stream) = ws_stream.split();
if let Err(e) = sink.send(WsMsg::Text(request_text.into())).await {
return Some(format!("Failed to send RPC request: {e}"));
}
while let Some(msg) = stream.next().await {
match msg {
Ok(WsMsg::Text(text)) => {
let Ok(frame) = serde_json::from_str::<serde_json::Value>(&text) else {
continue;
};
if frame.get("kind").and_then(|v| v.as_str()) != Some("rpc_response") {
continue;
}
if frame
.get("correlation_id")
.and_then(|v| v.as_str())
.map(|id| id != correlation_id)
.unwrap_or(true)
{
continue;
}
let ok = frame.get("ok").and_then(|v| v.as_bool()).unwrap_or(false);
if ok {
return frame
.get("result")
.and_then(|r| r.get("response"))
.and_then(|v| v.as_str())
.map(String::from)
.or_else(|| {
Some("Command succeeded with no response text".to_string())
});
} else {
let err = frame
.get("error")
.and_then(|v| v.as_str())
.unwrap_or("unknown error");
return Some(format!("Project server command failed: {err}"));
}
}
Ok(WsMsg::Close(_)) => break,
Err(e) => return Some(format!("WebSocket error: {e}")),
_ => continue,
}
}
Some("Connection closed before receiving command response".to_string())
}
}
@@ -244,4 +308,64 @@ mod tests {
let ctx = test_bot_context(services, None, vec![], BTreeMap::new());
let _cloned = ctx.clone();
}
/// A bot command issued in gateway mode must round-trip over WebSocket
/// (using the `bot.command` RPC method) and must NOT use HTTP transport.
#[tokio::test]
async fn proxy_bot_command_uses_websocket_not_http() {
use futures::{SinkExt, StreamExt};
use tokio::net::TcpListener;
use tokio_tungstenite::tungstenite::Message as WsMsg;
// Bind an ephemeral port for our mock WebSocket server.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
// Spawn a minimal WS server: accept one connection, verify the
// request uses the `bot.command` RPC method (not HTTP), and reply.
let server = tokio::spawn(async move {
let (tcp, _addr) = listener.accept().await.unwrap();
let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
while let Some(Ok(msg)) = ws.next().await {
if let WsMsg::Text(text) = msg {
let req: serde_json::Value =
serde_json::from_str(&text).expect("valid JSON from proxy");
assert_eq!(
req["kind"], "rpc_request",
"transport must use rpc_request, not HTTP"
);
assert_eq!(req["method"], "bot.command");
assert_eq!(req["params"]["command"], "status");
let correlation_id = req["correlation_id"].clone();
let resp = serde_json::json!({
"kind": "rpc_response",
"correlation_id": correlation_id,
"ok": true,
"result": { "response": "all systems go" },
});
ws.send(WsMsg::Text(resp.to_string().into())).await.unwrap();
break;
}
}
});
let base_url = format!("http://127.0.0.1:{port}");
let services = test_services(PathBuf::from("/gateway"));
let active = Arc::new(RwLock::new("huskies".to_string()));
let ctx = test_bot_context(
services,
Some(Arc::clone(&active)),
vec!["huskies".into()],
BTreeMap::from([("huskies".into(), base_url)]),
);
let result = ctx.proxy_bot_command("status", "").await;
assert_eq!(
result.as_deref(),
Some("all systems go"),
"proxy must return the response text from the rpc_response frame"
);
server.await.unwrap();
}
}