huskies: merge 905
This commit is contained in:
@@ -5,10 +5,12 @@ use crate::service::gateway::{self, GatewayState};
|
||||
use poem::handler;
|
||||
use poem::http::StatusCode;
|
||||
use poem::web::Data;
|
||||
use poem::{Body, Request, Response};
|
||||
use poem::web::sse::{Event, SSE};
|
||||
use poem::{Body, IntoResponse, Request, Response};
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
// ── MCP tool definitions ─────────────────────────────────────────────────────
|
||||
|
||||
@@ -146,6 +148,31 @@ pub async fn gateway_mcp_post_handler(
|
||||
return to_json_response(JsonRpcResponse::error(None, -32600, "Missing id".into()));
|
||||
}
|
||||
|
||||
// SSE proxy: tools/call with Accept: text/event-stream + progressToken for
|
||||
// non-gateway tools is forwarded to the sled's SSE endpoint so progress
|
||||
// notifications flow through to the gateway client unchanged.
|
||||
if rpc.method == "tools/call" {
|
||||
let accepts_sse = req
|
||||
.header("accept")
|
||||
.map(|h| h.contains("text/event-stream"))
|
||||
.unwrap_or(false);
|
||||
let has_progress_token = rpc
|
||||
.params
|
||||
.get("_meta")
|
||||
.and_then(|m| m.get("progressToken"))
|
||||
.is_some();
|
||||
if accepts_sse && has_progress_token {
|
||||
let tool_name = rpc
|
||||
.params
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
if !GATEWAY_TOOLS.contains(&tool_name) {
|
||||
return proxy_and_respond_sse(&state, &bytes, rpc.id).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match rpc.method.as_str() {
|
||||
"initialize" => to_json_response(handle_initialize(rpc.id)),
|
||||
"tools/list" => match handle_tools_list(&state, rpc.id.clone()).await {
|
||||
@@ -193,6 +220,73 @@ async fn proxy_and_respond(state: &GatewayState, bytes: &[u8], id: Option<Value>
|
||||
}
|
||||
}
|
||||
|
||||
/// Stream an MCP tool call to the active sled as SSE, re-emitting each `data:`
|
||||
/// event from the sled to the originating gateway client without buffering.
|
||||
///
|
||||
/// On sled disconnect mid-stream a JSON-RPC error event is emitted so the
|
||||
/// client does not hang forever.
|
||||
async fn proxy_and_respond_sse(state: &GatewayState, bytes: &[u8], id: Option<Value>) -> Response {
|
||||
let url = match state.active_url().await {
|
||||
Ok(u) => u,
|
||||
Err(e) => return sse_error_response(id, -32603, e.to_string()),
|
||||
};
|
||||
|
||||
let resp = match gateway::io::proxy_mcp_call_sse(&state.client, &url, bytes).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => return sse_error_response(id, -32603, format!("proxy error: {e}")),
|
||||
};
|
||||
|
||||
let id_for_error = id;
|
||||
let stream = async_stream::stream! {
|
||||
use futures::StreamExt as _;
|
||||
let mut buf = String::new();
|
||||
let byte_stream = resp.bytes_stream();
|
||||
tokio::pin!(byte_stream);
|
||||
|
||||
while let Some(chunk) = byte_stream.next().await {
|
||||
match chunk {
|
||||
Ok(bytes) => {
|
||||
if let Ok(text) = std::str::from_utf8(&bytes) {
|
||||
buf.push_str(text);
|
||||
// Emit a gateway SSE event for each complete `data:` line.
|
||||
while let Some(pos) = buf.find('\n') {
|
||||
let line = buf[..pos].trim_end_matches('\r').to_string();
|
||||
buf = buf[pos + 1..].to_string();
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
yield Event::message(data.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let err = JsonRpcResponse::error(
|
||||
id_for_error.clone(),
|
||||
-32603,
|
||||
format!("upstream disconnected: {e}"),
|
||||
);
|
||||
let data = serde_json::to_string(&err).unwrap_or_default();
|
||||
yield Event::message(data);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
SSE::new(stream)
|
||||
.keep_alive(Duration::from_secs(15))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Build a minimal SSE response containing a single JSON-RPC error event.
|
||||
fn sse_error_response(id: Option<Value>, code: i64, msg: String) -> Response {
|
||||
let err = JsonRpcResponse::error(id, code, msg);
|
||||
let data = serde_json::to_string(&err).unwrap_or_default();
|
||||
let stream = async_stream::stream! {
|
||||
yield Event::message(data);
|
||||
};
|
||||
SSE::new(stream).into_response()
|
||||
}
|
||||
|
||||
/// GET handler — method not allowed.
|
||||
#[handler]
|
||||
pub async fn gateway_mcp_get_handler() -> Response {
|
||||
|
||||
Reference in New Issue
Block a user