huskies: merge 905

This commit is contained in:
dave
2026-05-12 14:57:53 +00:00
parent 2c5326f339
commit 379ff16d3e
3 changed files with 290 additions and 1 deletions
+95 -1
View File
@@ -5,10 +5,12 @@ use crate::service::gateway::{self, GatewayState};
use poem::handler;
use poem::http::StatusCode;
use poem::web::Data;
use poem::{Body, Request, Response};
use poem::web::sse::{Event, SSE};
use poem::{Body, IntoResponse, Request, Response};
use serde_json::{Value, json};
use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Duration;
// ── MCP tool definitions ─────────────────────────────────────────────────────
@@ -146,6 +148,31 @@ pub async fn gateway_mcp_post_handler(
return to_json_response(JsonRpcResponse::error(None, -32600, "Missing id".into()));
}
// SSE proxy: tools/call with Accept: text/event-stream + progressToken for
// non-gateway tools is forwarded to the sled's SSE endpoint so progress
// notifications flow through to the gateway client unchanged.
if rpc.method == "tools/call" {
let accepts_sse = req
.header("accept")
.map(|h| h.contains("text/event-stream"))
.unwrap_or(false);
let has_progress_token = rpc
.params
.get("_meta")
.and_then(|m| m.get("progressToken"))
.is_some();
if accepts_sse && has_progress_token {
let tool_name = rpc
.params
.get("name")
.and_then(|v| v.as_str())
.unwrap_or("");
if !GATEWAY_TOOLS.contains(&tool_name) {
return proxy_and_respond_sse(&state, &bytes, rpc.id).await;
}
}
}
match rpc.method.as_str() {
"initialize" => to_json_response(handle_initialize(rpc.id)),
"tools/list" => match handle_tools_list(&state, rpc.id.clone()).await {
@@ -193,6 +220,73 @@ async fn proxy_and_respond(state: &GatewayState, bytes: &[u8], id: Option<Value>
}
}
/// Stream an MCP tool call to the active sled as SSE, re-emitting each `data:`
/// event from the sled to the originating gateway client without buffering.
///
/// On sled disconnect mid-stream a JSON-RPC error event is emitted so the
/// client does not hang forever.
async fn proxy_and_respond_sse(state: &GatewayState, bytes: &[u8], id: Option<Value>) -> Response {
let url = match state.active_url().await {
Ok(u) => u,
Err(e) => return sse_error_response(id, -32603, e.to_string()),
};
let resp = match gateway::io::proxy_mcp_call_sse(&state.client, &url, bytes).await {
Ok(r) => r,
Err(e) => return sse_error_response(id, -32603, format!("proxy error: {e}")),
};
let id_for_error = id;
let stream = async_stream::stream! {
use futures::StreamExt as _;
let mut buf = String::new();
let byte_stream = resp.bytes_stream();
tokio::pin!(byte_stream);
while let Some(chunk) = byte_stream.next().await {
match chunk {
Ok(bytes) => {
if let Ok(text) = std::str::from_utf8(&bytes) {
buf.push_str(text);
// Emit a gateway SSE event for each complete `data:` line.
while let Some(pos) = buf.find('\n') {
let line = buf[..pos].trim_end_matches('\r').to_string();
buf = buf[pos + 1..].to_string();
if let Some(data) = line.strip_prefix("data: ") {
yield Event::message(data.to_string());
}
}
}
}
Err(e) => {
let err = JsonRpcResponse::error(
id_for_error.clone(),
-32603,
format!("upstream disconnected: {e}"),
);
let data = serde_json::to_string(&err).unwrap_or_default();
yield Event::message(data);
break;
}
}
}
};
SSE::new(stream)
.keep_alive(Duration::from_secs(15))
.into_response()
}
/// Build a minimal SSE response containing a single JSON-RPC error event.
fn sse_error_response(id: Option<Value>, code: i64, msg: String) -> Response {
let err = JsonRpcResponse::error(id, code, msg);
let data = serde_json::to_string(&err).unwrap_or_default();
let stream = async_stream::stream! {
yield Event::message(data);
};
SSE::new(stream).into_response()
}
/// GET handler — method not allowed.
#[handler]
pub async fn gateway_mcp_get_handler() -> Response {