huskies: merge 905
This commit is contained in:
@@ -927,3 +927,176 @@ enabled = false
|
|||||||
let config = BotConfig::load(tmp.path());
|
let config = BotConfig::load(tmp.path());
|
||||||
assert!(config.is_none());
|
assert!(config.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Gateway MCP SSE proxy integration tests ──────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn gateway_mcp_sse_proxy_streams_progress_and_final_response() {
|
||||||
|
let mut mock_sled = mockito::Server::new_async().await;
|
||||||
|
|
||||||
|
let prog1 = serde_json::json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "notifications/progress",
|
||||||
|
"params": { "progressToken": "tok1", "progress": 1.0 }
|
||||||
|
});
|
||||||
|
let prog2 = serde_json::json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "notifications/progress",
|
||||||
|
"params": { "progressToken": "tok1", "progress": 2.0 }
|
||||||
|
});
|
||||||
|
let final_resp = serde_json::json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 1,
|
||||||
|
"result": { "content": [{ "type": "text", "text": "tests passed" }] }
|
||||||
|
});
|
||||||
|
let sse_body = format!("data: {prog1}\n\ndata: {prog2}\n\ndata: {final_resp}\n\n");
|
||||||
|
|
||||||
|
let _mock = mock_sled
|
||||||
|
.mock("POST", "/mcp")
|
||||||
|
.with_status(200)
|
||||||
|
.with_header("content-type", "text/event-stream")
|
||||||
|
.with_body(&sse_body)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let mut projects = BTreeMap::new();
|
||||||
|
projects.insert(
|
||||||
|
"sled".to_string(),
|
||||||
|
ProjectEntry {
|
||||||
|
url: mock_sled.url(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
let config = GatewayConfig { projects };
|
||||||
|
let state = Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap());
|
||||||
|
|
||||||
|
let app = poem::Route::new()
|
||||||
|
.at("/mcp", poem::post(gateway_mcp_post_handler))
|
||||||
|
.data(state.clone());
|
||||||
|
let cli = poem::test::TestClient::new(app);
|
||||||
|
|
||||||
|
let rpc_body = serde_json::to_vec(&serde_json::json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 1,
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": "run_tests",
|
||||||
|
"arguments": {},
|
||||||
|
"_meta": { "progressToken": "tok1" }
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let resp = cli
|
||||||
|
.post("/mcp")
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.header("accept", "text/event-stream")
|
||||||
|
.body(rpc_body)
|
||||||
|
.send()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let body = resp.0.into_body().into_string().await.unwrap();
|
||||||
|
|
||||||
|
let data_lines: Vec<&str> = body
|
||||||
|
.lines()
|
||||||
|
.filter_map(|l| l.strip_prefix("data: "))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
data_lines.len(),
|
||||||
|
3,
|
||||||
|
"Expected 3 SSE events (2 progress + 1 final); got {}: {:?}",
|
||||||
|
data_lines.len(),
|
||||||
|
body
|
||||||
|
);
|
||||||
|
|
||||||
|
let ev1: serde_json::Value =
|
||||||
|
serde_json::from_str(data_lines[0]).expect("event 1 is valid JSON");
|
||||||
|
assert_eq!(
|
||||||
|
ev1["method"], "notifications/progress",
|
||||||
|
"event 1 must be a progress notification"
|
||||||
|
);
|
||||||
|
assert_eq!(ev1["params"]["progress"], 1.0);
|
||||||
|
|
||||||
|
let ev2: serde_json::Value =
|
||||||
|
serde_json::from_str(data_lines[1]).expect("event 2 is valid JSON");
|
||||||
|
assert_eq!(
|
||||||
|
ev2["method"], "notifications/progress",
|
||||||
|
"event 2 must be a progress notification"
|
||||||
|
);
|
||||||
|
assert_eq!(ev2["params"]["progress"], 2.0);
|
||||||
|
|
||||||
|
let ev3: serde_json::Value =
|
||||||
|
serde_json::from_str(data_lines[2]).expect("event 3 is valid JSON");
|
||||||
|
assert_eq!(ev3["id"], 1, "event 3 must be the final JSON-RPC response");
|
||||||
|
assert!(
|
||||||
|
ev3.get("result").is_some(),
|
||||||
|
"event 3 must carry a result field"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn gateway_mcp_post_without_sse_returns_plain_json() {
|
||||||
|
let mut mock_sled = mockito::Server::new_async().await;
|
||||||
|
|
||||||
|
let json_resp = serde_json::json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 2,
|
||||||
|
"result": { "content": [{ "type": "text", "text": "done" }] }
|
||||||
|
});
|
||||||
|
|
||||||
|
let _mock = mock_sled
|
||||||
|
.mock("POST", "/mcp")
|
||||||
|
.with_status(200)
|
||||||
|
.with_header("content-type", "application/json")
|
||||||
|
.with_body(serde_json::to_string(&json_resp).unwrap())
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let mut projects = BTreeMap::new();
|
||||||
|
projects.insert(
|
||||||
|
"sled".to_string(),
|
||||||
|
ProjectEntry {
|
||||||
|
url: mock_sled.url(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
let config = GatewayConfig { projects };
|
||||||
|
let state = Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap());
|
||||||
|
|
||||||
|
let app = poem::Route::new()
|
||||||
|
.at("/mcp", poem::post(gateway_mcp_post_handler))
|
||||||
|
.data(state.clone());
|
||||||
|
let cli = poem::test::TestClient::new(app);
|
||||||
|
|
||||||
|
let rpc_body = serde_json::to_vec(&serde_json::json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 2,
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": { "name": "run_tests", "arguments": {} }
|
||||||
|
}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let resp = cli
|
||||||
|
.post("/mcp")
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.body(rpc_body)
|
||||||
|
.send()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let ct = resp
|
||||||
|
.0
|
||||||
|
.headers()
|
||||||
|
.get("content-type")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.unwrap_or("");
|
||||||
|
assert!(
|
||||||
|
ct.contains("application/json"),
|
||||||
|
"Non-SSE path must return application/json; got: {ct}"
|
||||||
|
);
|
||||||
|
|
||||||
|
let body: serde_json::Value = resp.0.into_body().into_json().await.unwrap();
|
||||||
|
assert_eq!(body["id"], 2);
|
||||||
|
assert!(
|
||||||
|
body.get("result").is_some(),
|
||||||
|
"Expected result in plain JSON response"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,10 +5,12 @@ use crate::service::gateway::{self, GatewayState};
|
|||||||
use poem::handler;
|
use poem::handler;
|
||||||
use poem::http::StatusCode;
|
use poem::http::StatusCode;
|
||||||
use poem::web::Data;
|
use poem::web::Data;
|
||||||
use poem::{Body, Request, Response};
|
use poem::web::sse::{Event, SSE};
|
||||||
|
use poem::{Body, IntoResponse, Request, Response};
|
||||||
use serde_json::{Value, json};
|
use serde_json::{Value, json};
|
||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
// ── MCP tool definitions ─────────────────────────────────────────────────────
|
// ── MCP tool definitions ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -146,6 +148,31 @@ pub async fn gateway_mcp_post_handler(
|
|||||||
return to_json_response(JsonRpcResponse::error(None, -32600, "Missing id".into()));
|
return to_json_response(JsonRpcResponse::error(None, -32600, "Missing id".into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SSE proxy: tools/call with Accept: text/event-stream + progressToken for
|
||||||
|
// non-gateway tools is forwarded to the sled's SSE endpoint so progress
|
||||||
|
// notifications flow through to the gateway client unchanged.
|
||||||
|
if rpc.method == "tools/call" {
|
||||||
|
let accepts_sse = req
|
||||||
|
.header("accept")
|
||||||
|
.map(|h| h.contains("text/event-stream"))
|
||||||
|
.unwrap_or(false);
|
||||||
|
let has_progress_token = rpc
|
||||||
|
.params
|
||||||
|
.get("_meta")
|
||||||
|
.and_then(|m| m.get("progressToken"))
|
||||||
|
.is_some();
|
||||||
|
if accepts_sse && has_progress_token {
|
||||||
|
let tool_name = rpc
|
||||||
|
.params
|
||||||
|
.get("name")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("");
|
||||||
|
if !GATEWAY_TOOLS.contains(&tool_name) {
|
||||||
|
return proxy_and_respond_sse(&state, &bytes, rpc.id).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
match rpc.method.as_str() {
|
match rpc.method.as_str() {
|
||||||
"initialize" => to_json_response(handle_initialize(rpc.id)),
|
"initialize" => to_json_response(handle_initialize(rpc.id)),
|
||||||
"tools/list" => match handle_tools_list(&state, rpc.id.clone()).await {
|
"tools/list" => match handle_tools_list(&state, rpc.id.clone()).await {
|
||||||
@@ -193,6 +220,73 @@ async fn proxy_and_respond(state: &GatewayState, bytes: &[u8], id: Option<Value>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Stream an MCP tool call to the active sled as SSE, re-emitting each `data:`
|
||||||
|
/// event from the sled to the originating gateway client without buffering.
|
||||||
|
///
|
||||||
|
/// On sled disconnect mid-stream a JSON-RPC error event is emitted so the
|
||||||
|
/// client does not hang forever.
|
||||||
|
async fn proxy_and_respond_sse(state: &GatewayState, bytes: &[u8], id: Option<Value>) -> Response {
|
||||||
|
let url = match state.active_url().await {
|
||||||
|
Ok(u) => u,
|
||||||
|
Err(e) => return sse_error_response(id, -32603, e.to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let resp = match gateway::io::proxy_mcp_call_sse(&state.client, &url, bytes).await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => return sse_error_response(id, -32603, format!("proxy error: {e}")),
|
||||||
|
};
|
||||||
|
|
||||||
|
let id_for_error = id;
|
||||||
|
let stream = async_stream::stream! {
|
||||||
|
use futures::StreamExt as _;
|
||||||
|
let mut buf = String::new();
|
||||||
|
let byte_stream = resp.bytes_stream();
|
||||||
|
tokio::pin!(byte_stream);
|
||||||
|
|
||||||
|
while let Some(chunk) = byte_stream.next().await {
|
||||||
|
match chunk {
|
||||||
|
Ok(bytes) => {
|
||||||
|
if let Ok(text) = std::str::from_utf8(&bytes) {
|
||||||
|
buf.push_str(text);
|
||||||
|
// Emit a gateway SSE event for each complete `data:` line.
|
||||||
|
while let Some(pos) = buf.find('\n') {
|
||||||
|
let line = buf[..pos].trim_end_matches('\r').to_string();
|
||||||
|
buf = buf[pos + 1..].to_string();
|
||||||
|
if let Some(data) = line.strip_prefix("data: ") {
|
||||||
|
yield Event::message(data.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let err = JsonRpcResponse::error(
|
||||||
|
id_for_error.clone(),
|
||||||
|
-32603,
|
||||||
|
format!("upstream disconnected: {e}"),
|
||||||
|
);
|
||||||
|
let data = serde_json::to_string(&err).unwrap_or_default();
|
||||||
|
yield Event::message(data);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
SSE::new(stream)
|
||||||
|
.keep_alive(Duration::from_secs(15))
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a minimal SSE response containing a single JSON-RPC error event.
|
||||||
|
fn sse_error_response(id: Option<Value>, code: i64, msg: String) -> Response {
|
||||||
|
let err = JsonRpcResponse::error(id, code, msg);
|
||||||
|
let data = serde_json::to_string(&err).unwrap_or_default();
|
||||||
|
let stream = async_stream::stream! {
|
||||||
|
yield Event::message(data);
|
||||||
|
};
|
||||||
|
SSE::new(stream).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
/// GET handler — method not allowed.
|
/// GET handler — method not allowed.
|
||||||
#[handler]
|
#[handler]
|
||||||
pub async fn gateway_mcp_get_handler() -> Response {
|
pub async fn gateway_mcp_get_handler() -> Response {
|
||||||
|
|||||||
@@ -108,6 +108,28 @@ pub async fn proxy_mcp_call(
|
|||||||
.map_err(|e| format!("failed to read response from {mcp_url}: {e}"))
|
.map_err(|e| format!("failed to read response from {mcp_url}: {e}"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Proxy an MCP `tools/call` request to the sled with `Accept: text/event-stream`
|
||||||
|
/// and return the raw response for streaming. No per-request timeout is applied
|
||||||
|
/// so long-running tool calls (e.g. `run_tests`, up to 1200 s) are not cut short.
|
||||||
|
///
|
||||||
|
/// The caller reads `.bytes_stream()` from the returned response and re-emits
|
||||||
|
/// each SSE `data:` line as a new event to the originating client.
|
||||||
|
pub async fn proxy_mcp_call_sse(
|
||||||
|
client: &Client,
|
||||||
|
base_url: &str,
|
||||||
|
request_bytes: &[u8],
|
||||||
|
) -> Result<reqwest::Response, String> {
|
||||||
|
let mcp_url = format!("{}/mcp", base_url.trim_end_matches('/'));
|
||||||
|
client
|
||||||
|
.post(&mcp_url)
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Accept", "text/event-stream")
|
||||||
|
.body(request_bytes.to_vec())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("failed to reach {mcp_url}: {e}"))
|
||||||
|
}
|
||||||
|
|
||||||
/// Fetch tools/list from a project's MCP endpoint.
|
/// Fetch tools/list from a project's MCP endpoint.
|
||||||
pub async fn fetch_tools_list(client: &Client, base_url: &str) -> Result<Value, String> {
|
pub async fn fetch_tools_list(client: &Client, base_url: &str) -> Result<Value, String> {
|
||||||
let mcp_url = format!("{}/mcp", base_url.trim_end_matches('/'));
|
let mcp_url = format!("{}/mcp", base_url.trim_end_matches('/'));
|
||||||
|
|||||||
Reference in New Issue
Block a user