From 379ff16d3ecc0e50c9c92053ed0fb7d03ed08967 Mon Sep 17 00:00:00 2001 From: dave Date: Tue, 12 May 2026 14:57:53 +0000 Subject: [PATCH] huskies: merge 905 --- server/src/gateway/tests.rs | 173 +++++++++++++++++++++++++++++++ server/src/http/gateway/mcp.rs | 96 ++++++++++++++++- server/src/service/gateway/io.rs | 22 ++++ 3 files changed, 290 insertions(+), 1 deletion(-) diff --git a/server/src/gateway/tests.rs b/server/src/gateway/tests.rs index 5fd87ae1..bebe5434 100644 --- a/server/src/gateway/tests.rs +++ b/server/src/gateway/tests.rs @@ -927,3 +927,176 @@ enabled = false let config = BotConfig::load(tmp.path()); assert!(config.is_none()); } + +// ── Gateway MCP SSE proxy integration tests ────────────────────────── + +#[tokio::test] +async fn gateway_mcp_sse_proxy_streams_progress_and_final_response() { + let mut mock_sled = mockito::Server::new_async().await; + + let prog1 = serde_json::json!({ + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": { "progressToken": "tok1", "progress": 1.0 } + }); + let prog2 = serde_json::json!({ + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": { "progressToken": "tok1", "progress": 2.0 } + }); + let final_resp = serde_json::json!({ + "jsonrpc": "2.0", + "id": 1, + "result": { "content": [{ "type": "text", "text": "tests passed" }] } + }); + let sse_body = format!("data: {prog1}\n\ndata: {prog2}\n\ndata: {final_resp}\n\n"); + + let _mock = mock_sled + .mock("POST", "/mcp") + .with_status(200) + .with_header("content-type", "text/event-stream") + .with_body(&sse_body) + .create_async() + .await; + + let mut projects = BTreeMap::new(); + projects.insert( + "sled".to_string(), + ProjectEntry { + url: mock_sled.url(), + }, + ); + let config = GatewayConfig { projects }; + let state = Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap()); + + let app = poem::Route::new() + .at("/mcp", poem::post(gateway_mcp_post_handler)) + .data(state.clone()); + let cli = poem::test::TestClient::new(app); + + let rpc_body = serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "run_tests", + "arguments": {}, + "_meta": { "progressToken": "tok1" } + } + })) + .unwrap(); + + let resp = cli + .post("/mcp") + .header("content-type", "application/json") + .header("accept", "text/event-stream") + .body(rpc_body) + .send() + .await; + + let body = resp.0.into_body().into_string().await.unwrap(); + + let data_lines: Vec<&str> = body + .lines() + .filter_map(|l| l.strip_prefix("data: ")) + .collect(); + + assert_eq!( + data_lines.len(), + 3, + "Expected 3 SSE events (2 progress + 1 final); got {}: {:?}", + data_lines.len(), + body + ); + + let ev1: serde_json::Value = + serde_json::from_str(data_lines[0]).expect("event 1 is valid JSON"); + assert_eq!( + ev1["method"], "notifications/progress", + "event 1 must be a progress notification" + ); + assert_eq!(ev1["params"]["progress"], 1.0); + + let ev2: serde_json::Value = + serde_json::from_str(data_lines[1]).expect("event 2 is valid JSON"); + assert_eq!( + ev2["method"], "notifications/progress", + "event 2 must be a progress notification" + ); + assert_eq!(ev2["params"]["progress"], 2.0); + + let ev3: serde_json::Value = + serde_json::from_str(data_lines[2]).expect("event 3 is valid JSON"); + assert_eq!(ev3["id"], 1, "event 3 must be the final JSON-RPC response"); + assert!( + ev3.get("result").is_some(), + "event 3 must carry a result field" + ); +} + +#[tokio::test] +async fn gateway_mcp_post_without_sse_returns_plain_json() { + let mut mock_sled = mockito::Server::new_async().await; + + let json_resp = serde_json::json!({ + "jsonrpc": "2.0", + "id": 2, + "result": { "content": [{ "type": "text", "text": "done" }] } + }); + + let _mock = mock_sled + .mock("POST", "/mcp") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(serde_json::to_string(&json_resp).unwrap()) + .create_async() + .await; + + let mut projects = BTreeMap::new(); + projects.insert( + "sled".to_string(), + ProjectEntry { + url: mock_sled.url(), + }, + ); + let config = GatewayConfig { projects }; + let state = Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap()); + + let app = poem::Route::new() + .at("/mcp", poem::post(gateway_mcp_post_handler)) + .data(state.clone()); + let cli = poem::test::TestClient::new(app); + + let rpc_body = serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { "name": "run_tests", "arguments": {} } + })) + .unwrap(); + + let resp = cli + .post("/mcp") + .header("content-type", "application/json") + .body(rpc_body) + .send() + .await; + + let ct = resp + .0 + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + assert!( + ct.contains("application/json"), + "Non-SSE path must return application/json; got: {ct}" + ); + + let body: serde_json::Value = resp.0.into_body().into_json().await.unwrap(); + assert_eq!(body["id"], 2); + assert!( + body.get("result").is_some(), + "Expected result in plain JSON response" + ); +} diff --git a/server/src/http/gateway/mcp.rs b/server/src/http/gateway/mcp.rs index 3a5d9dc9..a6f154d0 100644 --- a/server/src/http/gateway/mcp.rs +++ b/server/src/http/gateway/mcp.rs @@ -5,10 +5,12 @@ use crate::service::gateway::{self, GatewayState}; use poem::handler; use poem::http::StatusCode; use poem::web::Data; -use poem::{Body, Request, Response}; +use poem::web::sse::{Event, SSE}; +use poem::{Body, IntoResponse, Request, Response}; use serde_json::{Value, json}; use std::collections::BTreeMap; use std::sync::Arc; +use std::time::Duration; // ── MCP tool definitions ───────────────────────────────────────────────────── @@ -146,6 +148,31 @@ pub async fn gateway_mcp_post_handler( return to_json_response(JsonRpcResponse::error(None, -32600, "Missing id".into())); } + // SSE proxy: tools/call with Accept: text/event-stream + progressToken for + // non-gateway tools is forwarded to the sled's SSE endpoint so progress + // notifications flow through to the gateway client unchanged. + if rpc.method == "tools/call" { + let accepts_sse = req + .header("accept") + .map(|h| h.contains("text/event-stream")) + .unwrap_or(false); + let has_progress_token = rpc + .params + .get("_meta") + .and_then(|m| m.get("progressToken")) + .is_some(); + if accepts_sse && has_progress_token { + let tool_name = rpc + .params + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or(""); + if !GATEWAY_TOOLS.contains(&tool_name) { + return proxy_and_respond_sse(&state, &bytes, rpc.id).await; + } + } + } + match rpc.method.as_str() { "initialize" => to_json_response(handle_initialize(rpc.id)), "tools/list" => match handle_tools_list(&state, rpc.id.clone()).await { @@ -193,6 +220,73 @@ async fn proxy_and_respond(state: &GatewayState, bytes: &[u8], id: Option } } +/// Stream an MCP tool call to the active sled as SSE, re-emitting each `data:` +/// event from the sled to the originating gateway client without buffering. +/// +/// On sled disconnect mid-stream a JSON-RPC error event is emitted so the +/// client does not hang forever. +async fn proxy_and_respond_sse(state: &GatewayState, bytes: &[u8], id: Option) -> Response { + let url = match state.active_url().await { + Ok(u) => u, + Err(e) => return sse_error_response(id, -32603, e.to_string()), + }; + + let resp = match gateway::io::proxy_mcp_call_sse(&state.client, &url, bytes).await { + Ok(r) => r, + Err(e) => return sse_error_response(id, -32603, format!("proxy error: {e}")), + }; + + let id_for_error = id; + let stream = async_stream::stream! { + use futures::StreamExt as _; + let mut buf = String::new(); + let byte_stream = resp.bytes_stream(); + tokio::pin!(byte_stream); + + while let Some(chunk) = byte_stream.next().await { + match chunk { + Ok(bytes) => { + if let Ok(text) = std::str::from_utf8(&bytes) { + buf.push_str(text); + // Emit a gateway SSE event for each complete `data:` line. + while let Some(pos) = buf.find('\n') { + let line = buf[..pos].trim_end_matches('\r').to_string(); + buf = buf[pos + 1..].to_string(); + if let Some(data) = line.strip_prefix("data: ") { + yield Event::message(data.to_string()); + } + } + } + } + Err(e) => { + let err = JsonRpcResponse::error( + id_for_error.clone(), + -32603, + format!("upstream disconnected: {e}"), + ); + let data = serde_json::to_string(&err).unwrap_or_default(); + yield Event::message(data); + break; + } + } + } + }; + + SSE::new(stream) + .keep_alive(Duration::from_secs(15)) + .into_response() +} + +/// Build a minimal SSE response containing a single JSON-RPC error event. +fn sse_error_response(id: Option, code: i64, msg: String) -> Response { + let err = JsonRpcResponse::error(id, code, msg); + let data = serde_json::to_string(&err).unwrap_or_default(); + let stream = async_stream::stream! { + yield Event::message(data); + }; + SSE::new(stream).into_response() +} + /// GET handler — method not allowed. #[handler] pub async fn gateway_mcp_get_handler() -> Response { diff --git a/server/src/service/gateway/io.rs b/server/src/service/gateway/io.rs index f1f2b08f..ff5aa81b 100644 --- a/server/src/service/gateway/io.rs +++ b/server/src/service/gateway/io.rs @@ -108,6 +108,28 @@ pub async fn proxy_mcp_call( .map_err(|e| format!("failed to read response from {mcp_url}: {e}")) } +/// Proxy an MCP `tools/call` request to the sled with `Accept: text/event-stream` +/// and return the raw response for streaming. No per-request timeout is applied +/// so long-running tool calls (e.g. `run_tests`, up to 1200 s) are not cut short. +/// +/// The caller reads `.bytes_stream()` from the returned response and re-emits +/// each SSE `data:` line as a new event to the originating client. +pub async fn proxy_mcp_call_sse( + client: &Client, + base_url: &str, + request_bytes: &[u8], +) -> Result { + let mcp_url = format!("{}/mcp", base_url.trim_end_matches('/')); + client + .post(&mcp_url) + .header("Content-Type", "application/json") + .header("Accept", "text/event-stream") + .body(request_bytes.to_vec()) + .send() + .await + .map_err(|e| format!("failed to reach {mcp_url}: {e}")) +} + /// Fetch tools/list from a project's MCP endpoint. pub async fn fetch_tools_list(client: &Client, base_url: &str) -> Result { let mcp_url = format!("{}/mcp", base_url.trim_end_matches('/'));