huskies: merge 899
This commit is contained in:
@@ -119,7 +119,7 @@ pub async fn run(config_path: &Path, port: u16) -> Result<(), std::io::Error> {
|
||||
.read()
|
||||
.await
|
||||
.iter()
|
||||
.map(|(name, entry)| (name.clone(), entry.url.clone()))
|
||||
.filter_map(|(name, entry)| entry.url.as_ref().map(|u| (name.clone(), u.clone())))
|
||||
.collect();
|
||||
let (bot_abort, bot_shutdown_tx) = gateway::io::spawn_gateway_bot(
|
||||
&config_dir,
|
||||
|
||||
+206
-29
@@ -7,12 +7,7 @@ use std::path::PathBuf;
|
||||
|
||||
fn make_test_state() -> Arc<GatewayState> {
|
||||
let mut projects = BTreeMap::new();
|
||||
projects.insert(
|
||||
"test".into(),
|
||||
ProjectEntry {
|
||||
url: "http://test:3001".into(),
|
||||
},
|
||||
);
|
||||
projects.insert("test".into(), ProjectEntry::with_url("http://test:3001"));
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
@@ -367,9 +362,7 @@ async fn init_project_registers_in_projects_toml_when_name_and_url_given() {
|
||||
let mut projects = BTreeMap::new();
|
||||
projects.insert(
|
||||
"existing".into(),
|
||||
ProjectEntry {
|
||||
url: "http://existing:3001".into(),
|
||||
},
|
||||
ProjectEntry::with_url("http://existing:3001"),
|
||||
);
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
@@ -388,19 +381,17 @@ async fn init_project_registers_in_projects_toml_when_name_and_url_given() {
|
||||
|
||||
let projects = state.projects.read().await;
|
||||
assert!(projects.contains_key("new-project"));
|
||||
assert_eq!(projects["new-project"].url, "http://new-project:3002");
|
||||
assert_eq!(
|
||||
projects["new-project"].url.as_deref(),
|
||||
Some("http://new-project:3002")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn init_project_duplicate_name_returns_error() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut projects = BTreeMap::new();
|
||||
projects.insert(
|
||||
"taken".into(),
|
||||
ProjectEntry {
|
||||
url: "http://taken:3001".into(),
|
||||
},
|
||||
);
|
||||
projects.insert("taken".into(), ProjectEntry::with_url("http://taken:3001"));
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
@@ -452,7 +443,7 @@ async fn init_project_then_wizard_status_integration() {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
|
||||
let mut projects = BTreeMap::new();
|
||||
projects.insert("mock-project".into(), ProjectEntry { url: mock_url });
|
||||
projects.insert("mock-project".into(), ProjectEntry::with_url(mock_url));
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
@@ -972,12 +963,7 @@ async fn gateway_mcp_sse_proxy_streams_progress_and_final_response() {
|
||||
.await;
|
||||
|
||||
let mut projects = BTreeMap::new();
|
||||
projects.insert(
|
||||
"sled".to_string(),
|
||||
ProjectEntry {
|
||||
url: mock_sled.url(),
|
||||
},
|
||||
);
|
||||
projects.insert("sled".to_string(), ProjectEntry::with_url(mock_sled.url()));
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
@@ -1068,12 +1054,7 @@ async fn gateway_mcp_post_without_sse_returns_plain_json() {
|
||||
.await;
|
||||
|
||||
let mut projects = BTreeMap::new();
|
||||
projects.insert(
|
||||
"sled".to_string(),
|
||||
ProjectEntry {
|
||||
url: mock_sled.url(),
|
||||
},
|
||||
);
|
||||
projects.insert("sled".to_string(), ProjectEntry::with_url(mock_sled.url()));
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
@@ -1118,3 +1099,199 @@ async fn gateway_mcp_post_without_sse_returns_plain_json() {
|
||||
"Expected result in plain JSON response"
|
||||
);
|
||||
}
|
||||
|
||||
// ── Story 899: MCP-over-WS uplink integration ────────────────────────────────
|
||||
|
||||
/// Build a `SledConnection` plus a spawned "mock sled" task that:
|
||||
///
|
||||
/// * Reads outbound `mcp_request` envelopes off the connection's channel.
|
||||
/// * Invokes the supplied closure to build a `result` value for each request.
|
||||
/// * Resolves the matching in-flight oneshot directly (the same effect the WS
|
||||
/// handler has when it receives an `mcp_response` from a real sled).
|
||||
///
|
||||
/// Returns the registered `SledConnection`.
|
||||
fn spawn_mock_sled<F>(handler: F) -> crate::service::gateway::SledConnection
|
||||
where
|
||||
F: Fn(&serde_json::Value) -> serde_json::Value + Send + Sync + 'static,
|
||||
{
|
||||
use crate::service::gateway::SledConnection;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
|
||||
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
let last_heartbeat_ms = Arc::new(AtomicI64::new(chrono::Utc::now().timestamp_millis()));
|
||||
let in_flight = Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new()));
|
||||
|
||||
let conn = SledConnection {
|
||||
tx,
|
||||
last_heartbeat_ms,
|
||||
in_flight: Arc::clone(&in_flight),
|
||||
};
|
||||
let handler = Arc::new(handler);
|
||||
tokio::spawn(async move {
|
||||
while let Some(env) = rx.recv().await {
|
||||
if env.msg_type != "mcp_request" {
|
||||
continue;
|
||||
}
|
||||
let body_str = env
|
||||
.payload
|
||||
.get("body")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let body_json: serde_json::Value = match serde_json::from_str(&body_str) {
|
||||
Ok(v) => v,
|
||||
Err(_) => serde_json::Value::Null,
|
||||
};
|
||||
let result = handler(&body_json);
|
||||
let response = serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": body_json.get("id").cloned().unwrap_or(serde_json::Value::Null),
|
||||
"result": result,
|
||||
});
|
||||
if let Some(oneshot_tx) = in_flight.lock().await.remove(&env.req_id) {
|
||||
let _ = oneshot_tx.send(response);
|
||||
}
|
||||
}
|
||||
});
|
||||
conn
|
||||
}
|
||||
|
||||
/// AC 9: gateway switches active project, calls tools/list and tools/call
|
||||
/// against a sled connected only via WS uplink, gets correct responses.
|
||||
#[tokio::test]
|
||||
async fn ws_only_sled_handles_tools_list_and_tools_call() {
|
||||
use crate::service::gateway::ProjectEntry;
|
||||
|
||||
// Project entry with NO url — WS-only.
|
||||
let mut projects = BTreeMap::new();
|
||||
projects.insert(
|
||||
"ws-only".into(),
|
||||
ProjectEntry {
|
||||
url: None,
|
||||
auth_token: Some("secret".into()),
|
||||
},
|
||||
);
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
};
|
||||
let state = Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap());
|
||||
|
||||
let conn = spawn_mock_sled(|body| {
|
||||
let method = body.get("method").and_then(|m| m.as_str()).unwrap_or("");
|
||||
match method {
|
||||
"tools/list" => serde_json::json!({
|
||||
"tools": [
|
||||
{ "name": "my_tool", "description": "test" }
|
||||
]
|
||||
}),
|
||||
"tools/call" => serde_json::json!({
|
||||
"content": [{ "type": "text", "text": "called ok" }]
|
||||
}),
|
||||
_ => serde_json::json!({ "echo": method }),
|
||||
}
|
||||
});
|
||||
state
|
||||
.register_sled_connection("ws-only".to_string(), conn)
|
||||
.await;
|
||||
|
||||
// tools/list via proxy_active_mcp.
|
||||
let body = serde_json::to_vec(&serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
}))
|
||||
.unwrap();
|
||||
let resp = state.proxy_active_mcp(&body).await.expect("ws proxy works");
|
||||
let resp_json: serde_json::Value = serde_json::from_slice(&resp).unwrap();
|
||||
assert_eq!(
|
||||
resp_json["result"]["tools"][0]["name"], "my_tool",
|
||||
"tools/list response must come from the WS-connected sled, not HTTP"
|
||||
);
|
||||
|
||||
// tools/call via proxy_active_mcp.
|
||||
let body = serde_json::to_vec(&serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "tools/call",
|
||||
"params": { "name": "my_tool", "arguments": {} }
|
||||
}))
|
||||
.unwrap();
|
||||
let resp = state.proxy_active_mcp(&body).await.expect("ws proxy works");
|
||||
let resp_json: serde_json::Value = serde_json::from_slice(&resp).unwrap();
|
||||
assert_eq!(
|
||||
resp_json["result"]["content"][0]["text"], "called ok",
|
||||
"tools/call response must come from the WS-connected sled, not HTTP"
|
||||
);
|
||||
}
|
||||
|
||||
/// AC 10: two sleds connected to one gateway concurrently; gateway routes
|
||||
/// calls to the right sled based on active project.
|
||||
#[tokio::test]
|
||||
async fn two_concurrent_sleds_are_routed_by_active_project() {
|
||||
use crate::service::gateway::ProjectEntry;
|
||||
|
||||
let mut projects = BTreeMap::new();
|
||||
projects.insert(
|
||||
"alpha".into(),
|
||||
ProjectEntry {
|
||||
url: None,
|
||||
auth_token: Some("alpha-tok".into()),
|
||||
},
|
||||
);
|
||||
projects.insert(
|
||||
"beta".into(),
|
||||
ProjectEntry {
|
||||
url: None,
|
||||
auth_token: Some("beta-tok".into()),
|
||||
},
|
||||
);
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
};
|
||||
let state = Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap());
|
||||
|
||||
let alpha_conn = spawn_mock_sled(|_body| {
|
||||
serde_json::json!({
|
||||
"content": [{ "type": "text", "text": "from-alpha" }]
|
||||
})
|
||||
});
|
||||
let beta_conn = spawn_mock_sled(|_body| {
|
||||
serde_json::json!({
|
||||
"content": [{ "type": "text", "text": "from-beta" }]
|
||||
})
|
||||
});
|
||||
state
|
||||
.register_sled_connection("alpha".to_string(), alpha_conn)
|
||||
.await;
|
||||
state
|
||||
.register_sled_connection("beta".to_string(), beta_conn)
|
||||
.await;
|
||||
|
||||
// Switch to alpha.
|
||||
gateway::switch_project(&state, "alpha").await.unwrap();
|
||||
let body = serde_json::to_vec(&serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/call",
|
||||
"params": { "name": "noop", "arguments": {} }
|
||||
}))
|
||||
.unwrap();
|
||||
let resp = state.proxy_active_mcp(&body).await.expect("ws proxy works");
|
||||
let resp_json: serde_json::Value = serde_json::from_slice(&resp).unwrap();
|
||||
assert_eq!(
|
||||
resp_json["result"]["content"][0]["text"], "from-alpha",
|
||||
"When active project is alpha, calls must route to the alpha sled"
|
||||
);
|
||||
|
||||
// Switch to beta.
|
||||
gateway::switch_project(&state, "beta").await.unwrap();
|
||||
let resp = state.proxy_active_mcp(&body).await.expect("ws proxy works");
|
||||
let resp_json: serde_json::Value = serde_json::from_slice(&resp).unwrap();
|
||||
assert_eq!(
|
||||
resp_json["result"]["content"][0]["text"], "from-beta",
|
||||
"When active project is beta, calls must route to the beta sled"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -203,14 +203,11 @@ pub async fn gateway_mcp_post_handler(
|
||||
}
|
||||
|
||||
/// Proxy a request to the active project and format the response.
|
||||
///
|
||||
/// Prefers the live sled-uplink WebSocket when one is attached (story 899
|
||||
/// AC 2); falls back to the legacy HTTP proxy otherwise.
|
||||
async fn proxy_and_respond(state: &GatewayState, bytes: &[u8], id: Option<Value>) -> Response {
|
||||
let url = match state.active_url().await {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
return to_json_response(JsonRpcResponse::error(id, -32603, e.to_string()));
|
||||
}
|
||||
};
|
||||
match gateway::io::proxy_mcp_call(&state.client, &url, bytes).await {
|
||||
match state.proxy_active_mcp(bytes).await {
|
||||
Ok(resp_body) => Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
@@ -316,13 +313,23 @@ fn handle_initialize(id: Option<Value>) -> JsonRpcResponse {
|
||||
}
|
||||
|
||||
/// Fetch tools/list from the active project and merge in gateway tools.
|
||||
///
|
||||
/// Routes via the sled-uplink WS when one is attached (story 899 AC 2);
|
||||
/// falls back to HTTP otherwise.
|
||||
async fn handle_tools_list(
|
||||
state: &GatewayState,
|
||||
id: Option<Value>,
|
||||
) -> Result<JsonRpcResponse, String> {
|
||||
let url = state.active_url().await.map_err(|e| e.to_string())?;
|
||||
|
||||
let resp_json = gateway::io::fetch_tools_list(&state.client, &url).await?;
|
||||
let rpc_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
});
|
||||
let bytes = serde_json::to_vec(&rpc_body).map_err(|e| e.to_string())?;
|
||||
let resp_bytes = state.proxy_active_mcp(&bytes).await?;
|
||||
let resp_json: Value =
|
||||
serde_json::from_slice(&resp_bytes).map_err(|e| format!("invalid tools/list JSON: {e}"))?;
|
||||
|
||||
let mut tools: Vec<Value> = resp_json
|
||||
.get("result")
|
||||
@@ -414,21 +421,36 @@ async fn handle_gateway_status_tool(state: &GatewayState, id: Option<Value>) ->
|
||||
async fn handle_gateway_health_tool(state: &GatewayState, id: Option<Value>) -> JsonRpcResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
let project_entries: Vec<(String, String)> = state
|
||||
// Build the project list, preferring the WS-uplink heartbeat as the
|
||||
// source of truth for liveness (story 899 AC 3). HTTP polls are used
|
||||
// only as a fallback when no live sled is connected.
|
||||
let project_names: Vec<(String, Option<String>)> = state
|
||||
.projects
|
||||
.read()
|
||||
.await
|
||||
.iter()
|
||||
.map(|(n, e)| (n.clone(), e.url.clone()))
|
||||
.collect();
|
||||
for (name, url) in &project_entries {
|
||||
let status = match gateway::io::check_project_health(&state.client, url).await {
|
||||
let sled_conns = state.sled_connections.read().await;
|
||||
for (name, url_opt) in &project_names {
|
||||
let status = if let Some(conn) = sled_conns.get(name) {
|
||||
if conn.is_alive(crate::service::gateway::HEARTBEAT_MAX_AGE_MS) {
|
||||
"healthy (ws)".to_string()
|
||||
} else {
|
||||
"stale (ws heartbeat overdue)".to_string()
|
||||
}
|
||||
} else if let Some(url) = url_opt {
|
||||
match gateway::io::check_project_health(&state.client, url).await {
|
||||
Ok(true) => "healthy".to_string(),
|
||||
Ok(false) => "unhealthy".to_string(),
|
||||
Err(e) => e,
|
||||
}
|
||||
} else {
|
||||
"no uplink and no url configured".to_string()
|
||||
};
|
||||
results.insert(name.clone(), status);
|
||||
}
|
||||
drop(sled_conns);
|
||||
|
||||
let active = state.active_project.read().await.clone();
|
||||
JsonRpcResponse::success(
|
||||
@@ -512,7 +534,7 @@ async fn handle_aggregate_pipeline_status_tool(
|
||||
.read()
|
||||
.await
|
||||
.iter()
|
||||
.map(|(name, entry)| (name.clone(), entry.url.clone()))
|
||||
.filter_map(|(name, entry)| entry.url.as_ref().map(|u| (name.clone(), u.clone())))
|
||||
.collect();
|
||||
|
||||
let statuses =
|
||||
@@ -656,7 +678,7 @@ async fn handle_pipeline_get(state: &GatewayState, id: Option<Value>) -> JsonRpc
|
||||
.read()
|
||||
.await
|
||||
.iter()
|
||||
.map(|(n, e)| (n.clone(), e.url.clone()))
|
||||
.filter_map(|(n, e)| e.url.as_ref().map(|u| (n.clone(), u.clone())))
|
||||
.collect();
|
||||
|
||||
let results = gateway::io::fetch_all_project_pipeline_items(&project_urls, &state.client).await;
|
||||
|
||||
@@ -155,21 +155,30 @@ struct SledUplinkParams {
|
||||
token: Option<String>,
|
||||
}
|
||||
|
||||
/// `GET /api/sled-uplink` — gateway-side WebSocket endpoint for sled permission uplinks.
|
||||
/// `GET /api/sled-uplink` — gateway-side WebSocket endpoint for sled uplinks.
|
||||
///
|
||||
/// # Authentication
|
||||
///
|
||||
/// The connecting sled must supply a valid shared-secret token via the `token`
|
||||
/// query parameter. Tokens are configured in `[sled_tokens]` in `projects.toml`
|
||||
/// as `sled_id = "secret"` entries.
|
||||
/// query parameter. Tokens are configured either as per-project `auth_token`
|
||||
/// fields under `[projects.<name>]` in `projects.toml` (preferred, story 899)
|
||||
/// or, for backwards compatibility, in the deprecated `[sled_tokens]` table.
|
||||
///
|
||||
/// # Protocol
|
||||
///
|
||||
/// See `sled_uplink.rs` for the wire format ([`UplinkEnvelope`]). The gateway
|
||||
/// accepts `perm_request` messages, injects them into the local permission
|
||||
/// pipeline (via `state.perm_tx`), and sends `perm_response` frames back to the
|
||||
/// sled once the Matrix bot resolves them. Multiple sleds are demuxed by
|
||||
/// connection: each handler owns exactly one sled's request/response flow.
|
||||
/// See `sled_uplink.rs` for the wire format ([`UplinkEnvelope`]).
|
||||
///
|
||||
/// Phase 2 (story 899) expands the protocol from the original perm-only flow
|
||||
/// to a full bidirectional MCP transport:
|
||||
///
|
||||
/// - **sled → gateway:** `identity` (mandatory first frame after upgrade),
|
||||
/// `heartbeat`, `perm_request`, `mcp_response`.
|
||||
/// - **gateway → sled:** `mcp_request`, `perm_response`.
|
||||
///
|
||||
/// On `identity`, the connection is published to
|
||||
/// [`GatewayState::sled_connections`] under the project name, allowing the
|
||||
/// MCP proxy (`mcp.rs::proxy_and_respond`) to route subsequent calls over
|
||||
/// the live WS instead of HTTP. The entry is removed on disconnect.
|
||||
#[handler]
|
||||
pub async fn gateway_sled_uplink_handler(
|
||||
ws: WebSocket,
|
||||
@@ -196,56 +205,118 @@ pub async fn gateway_sled_uplink_handler(
|
||||
|
||||
use poem::IntoResponse as _;
|
||||
let perm_tx = state.perm_tx.clone();
|
||||
let state = Arc::clone(&state);
|
||||
ws.on_upgrade(move |socket| async move {
|
||||
let (mut sink, mut stream) = socket.split();
|
||||
// Aggregator channel: spawned per-request tasks send (req_id, decision) here
|
||||
// so the main loop can write perm_response frames back to the sled.
|
||||
let (agg_tx, mut agg_rx) = tokio::sync::mpsc::unbounded_channel::<(
|
||||
String,
|
||||
crate::http::context::PermissionDecision,
|
||||
)>();
|
||||
run_sled_uplink_session(socket, state, sled_id, perm_tx).await;
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
|
||||
crate::slog!("[gateway/sled-uplink] Sled '{}' connected", sled_id);
|
||||
/// Run a single connected sled's request/response flow until it disconnects.
|
||||
///
|
||||
/// Performs the identity handshake, registers a [`SledConnection`] in
|
||||
/// [`GatewayState::sled_connections`] under the published project name, and
|
||||
/// pumps messages bidirectionally for the lifetime of the WebSocket.
|
||||
async fn run_sled_uplink_session(
|
||||
socket: poem::web::websocket::WebSocketStream,
|
||||
state: Arc<GatewayState>,
|
||||
token_sled_id: String,
|
||||
perm_tx: tokio::sync::mpsc::UnboundedSender<crate::http::context::PermissionForward>,
|
||||
) {
|
||||
use crate::service::gateway::SledConnection;
|
||||
use crate::sled_uplink::UplinkEnvelope;
|
||||
use std::sync::Arc as StdArc;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
|
||||
let (mut sink, mut stream) = socket.split();
|
||||
|
||||
// ── Identity handshake: expect the first frame to be `identity` ──────
|
||||
let identity = match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(10),
|
||||
wait_for_identity(&mut stream),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Some(p)) => p,
|
||||
_ => {
|
||||
crate::slog!(
|
||||
"[gateway/sled-uplink] '{}' missing identity frame; closing",
|
||||
token_sled_id
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Project name in the identity must match the project resolved from the
|
||||
// auth token; otherwise the sled is claiming to be a different project.
|
||||
if identity != token_sled_id {
|
||||
crate::slog!(
|
||||
"[gateway/sled-uplink] identity mismatch (token says '{}', sled claims '{}'); closing",
|
||||
token_sled_id,
|
||||
identity
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// ── Build SledConnection and publish it ─────────────────────────────
|
||||
let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel::<UplinkEnvelope>();
|
||||
let conn = SledConnection {
|
||||
tx: out_tx,
|
||||
last_heartbeat_ms: StdArc::new(AtomicI64::new(chrono::Utc::now().timestamp_millis())),
|
||||
in_flight: StdArc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
|
||||
};
|
||||
state
|
||||
.register_sled_connection(identity.clone(), conn.clone())
|
||||
.await;
|
||||
crate::slog!(
|
||||
"[gateway/sled-uplink] Sled '{}' connected and registered",
|
||||
identity
|
||||
);
|
||||
|
||||
// Aggregator channel for perm responses produced by spawned per-request tasks.
|
||||
let (agg_tx, mut agg_rx) =
|
||||
tokio::sync::mpsc::unbounded_channel::<(String, crate::http::context::PermissionDecision)>(
|
||||
);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Outbound: forward queued envelopes (mcp_request, etc.) to the sled.
|
||||
Some(env) = out_rx.recv() => {
|
||||
let Ok(text) = serde_json::to_string(&env) else { continue };
|
||||
if sink.send(WsMessage::Text(text)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Inbound: messages from the sled.
|
||||
msg = stream.next() => {
|
||||
let text = match msg {
|
||||
Some(Ok(WsMessage::Text(t))) => t,
|
||||
Some(Ok(WsMessage::Close(_))) | None => break,
|
||||
_ => continue,
|
||||
};
|
||||
let Ok(env) = serde_json::from_str::<crate::sled_uplink::UplinkEnvelope>(&text) else {
|
||||
let Ok(env) = serde_json::from_str::<UplinkEnvelope>(&text) else {
|
||||
continue;
|
||||
};
|
||||
if env.msg_type == "perm_request" {
|
||||
let req_id = env.req_id.clone();
|
||||
let tool_name = env.payload.get("tool_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
let tool_input = env.payload.get("tool_input")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
|
||||
let fwd = crate::http::context::PermissionForward {
|
||||
request_id: format!("{sled_id}:{req_id}"),
|
||||
tool_name,
|
||||
tool_input,
|
||||
response_tx,
|
||||
};
|
||||
if perm_tx.send(fwd).is_err() {
|
||||
break;
|
||||
match env.msg_type.as_str() {
|
||||
"heartbeat" => {
|
||||
conn.last_heartbeat_ms.store(
|
||||
chrono::Utc::now().timestamp_millis(),
|
||||
std::sync::atomic::Ordering::Relaxed,
|
||||
);
|
||||
}
|
||||
let agg_tx2 = agg_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
let decision = response_rx
|
||||
.await
|
||||
.unwrap_or(crate::http::context::PermissionDecision::Deny);
|
||||
let _ = agg_tx2.send((req_id, decision));
|
||||
});
|
||||
"mcp_response" => {
|
||||
let tx_opt = conn.in_flight.lock().await.remove(&env.req_id);
|
||||
if let Some(tx) = tx_opt {
|
||||
let _ = tx.send(env.payload);
|
||||
}
|
||||
}
|
||||
"perm_request" => {
|
||||
forward_perm_request(env, &perm_tx, &agg_tx, &identity);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
// Aggregator: per-request permission resolutions get formatted as perm_response.
|
||||
Some((req_id, decision)) = agg_rx.recv() => {
|
||||
use crate::http::context::PermissionDecision;
|
||||
let (approved, always_allow) = match decision {
|
||||
@@ -253,7 +324,7 @@ pub async fn gateway_sled_uplink_handler(
|
||||
PermissionDecision::Approve => (true, false),
|
||||
PermissionDecision::Deny => (false, false),
|
||||
};
|
||||
let resp = crate::sled_uplink::UplinkEnvelope {
|
||||
let resp = UplinkEnvelope {
|
||||
msg_type: "perm_response".to_string(),
|
||||
req_id,
|
||||
payload: serde_json::json!({
|
||||
@@ -269,9 +340,76 @@ pub async fn gateway_sled_uplink_handler(
|
||||
}
|
||||
}
|
||||
|
||||
crate::slog!("[gateway/sled-uplink] Sled '{}' disconnected", sled_id);
|
||||
})
|
||||
.into_response()
|
||||
state.deregister_sled_connection(&identity).await;
|
||||
crate::slog!("[gateway/sled-uplink] Sled '{}' disconnected", identity);
|
||||
}
|
||||
|
||||
/// Read frames until an `identity` envelope is seen; return the project name
|
||||
/// from its payload. Returns `None` if the stream ends or an invalid frame
|
||||
/// arrives where the first frame is expected to be `identity`.
|
||||
async fn wait_for_identity(
|
||||
stream: &mut futures::stream::SplitStream<poem::web::websocket::WebSocketStream>,
|
||||
) -> Option<String> {
|
||||
while let Some(msg) = stream.next().await {
|
||||
let text = match msg {
|
||||
Ok(WsMessage::Text(t)) => t,
|
||||
Ok(WsMessage::Close(_)) | Err(_) => return None,
|
||||
_ => continue,
|
||||
};
|
||||
let Ok(env) = serde_json::from_str::<crate::sled_uplink::UplinkEnvelope>(&text) else {
|
||||
return None;
|
||||
};
|
||||
if env.msg_type != "identity" {
|
||||
return None;
|
||||
}
|
||||
return env
|
||||
.payload
|
||||
.get("project")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::to_string);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Convert an inbound `perm_request` envelope into a [`PermissionForward`] and
|
||||
/// inject it into the gateway's permission pipeline. The spawned waiter task
|
||||
/// publishes the resolved decision into the aggregator channel so the WS
|
||||
/// writer can emit a matching `perm_response` frame.
|
||||
fn forward_perm_request(
|
||||
env: crate::sled_uplink::UplinkEnvelope,
|
||||
perm_tx: &tokio::sync::mpsc::UnboundedSender<crate::http::context::PermissionForward>,
|
||||
agg_tx: &tokio::sync::mpsc::UnboundedSender<(String, crate::http::context::PermissionDecision)>,
|
||||
sled_id: &str,
|
||||
) {
|
||||
let req_id = env.req_id.clone();
|
||||
let tool_name = env
|
||||
.payload
|
||||
.get("tool_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
let tool_input = env
|
||||
.payload
|
||||
.get("tool_input")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
|
||||
let fwd = crate::http::context::PermissionForward {
|
||||
request_id: format!("{sled_id}:{req_id}"),
|
||||
tool_name,
|
||||
tool_input,
|
||||
response_tx,
|
||||
};
|
||||
if perm_tx.send(fwd).is_err() {
|
||||
return;
|
||||
}
|
||||
let agg_tx2 = agg_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
let decision = response_rx
|
||||
.await
|
||||
.unwrap_or(crate::http::context::PermissionDecision::Deny);
|
||||
let _ = agg_tx2.send((req_id, decision));
|
||||
});
|
||||
}
|
||||
|
||||
// ── Event-push WebSocket handler ─────────────────────────────────────────────
|
||||
|
||||
+19
-1
@@ -241,7 +241,25 @@ async fn main() -> Result<(), std::io::Error> {
|
||||
.clone()
|
||||
.or_else(|| std::env::var("HUSKIES_UPSTREAM_GATEWAY").ok())
|
||||
.unwrap_or_default();
|
||||
sled_uplink::spawn_uplink_task(upstream_gateway, Arc::clone(&services));
|
||||
// Project name for the identity frame (story 899 AC 5). Env-var override
|
||||
// wins; otherwise derive from the project_root basename.
|
||||
let project_name = std::env::var("HUSKIES_PROJECT_NAME").unwrap_or_else(|_| {
|
||||
services
|
||||
.project_root
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string()
|
||||
});
|
||||
let local_mcp_url = format!("http://127.0.0.1:{port}/mcp");
|
||||
sled_uplink::spawn_uplink_task(
|
||||
sled_uplink::UplinkConfig {
|
||||
upstream_url: upstream_gateway,
|
||||
project_name,
|
||||
local_mcp_url,
|
||||
},
|
||||
Arc::clone(&services),
|
||||
);
|
||||
|
||||
// ── Build bot contexts (WhatsApp / Slack / Discord) ───────────────────────
|
||||
let (bot_ctxs, matrix_shutdown_rx) =
|
||||
|
||||
@@ -7,20 +7,56 @@ use serde::{Deserialize, Serialize};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
/// A single project entry in `projects.toml`.
|
||||
///
|
||||
/// Phase 2 (story 899): `url` is now optional — a project served exclusively
|
||||
/// via the sled-uplink WebSocket does not need an HTTP base URL. The `url`
|
||||
/// field is deprecated for removal in a future release; configure
|
||||
/// `auth_token` instead and rely on the WS uplink for all traffic.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ProjectEntry {
|
||||
/// Base URL of the project's huskies container (e.g. `http://localhost:3001`).
|
||||
pub url: String,
|
||||
///
|
||||
/// **Deprecated** (story 899) — when a sled connects via the uplink WS the
|
||||
/// gateway routes all MCP traffic over that connection instead. The URL is
|
||||
/// used as a fallback when no live uplink exists. Omit for WS-only projects.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub url: Option<String>,
|
||||
/// Shared-secret token used to authenticate this project's sled when it
|
||||
/// connects to `/api/sled-uplink`. Takes precedence over the top-level
|
||||
/// `[sled_tokens]` table for projects that set this field.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub auth_token: Option<String>,
|
||||
}
|
||||
|
||||
impl ProjectEntry {
|
||||
/// Convenience constructor for entries that only have a URL (e.g. in tests
|
||||
/// and existing `projects.toml` files that have not yet been migrated to
|
||||
/// the WS-uplink model).
|
||||
pub fn with_url(url: impl Into<String>) -> Self {
|
||||
Self {
|
||||
url: Some(url.into()),
|
||||
auth_token: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if this entry has a configured HTTP base URL.
|
||||
pub fn has_url(&self) -> bool {
|
||||
self.url.as_ref().is_some_and(|u| !u.is_empty())
|
||||
}
|
||||
}
|
||||
|
||||
/// Top-level `projects.toml` config.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct GatewayConfig {
|
||||
/// Map of project name → container URL.
|
||||
/// Map of project name → container configuration.
|
||||
#[serde(default)]
|
||||
pub projects: BTreeMap<String, ProjectEntry>,
|
||||
/// Map of sled_id → shared secret token for sled-uplink authentication.
|
||||
///
|
||||
/// **Deprecated** (story 899) — move tokens into per-project
|
||||
/// `auth_token` fields instead. The gateway still honours entries here for
|
||||
/// one release to provide a smooth migration window.
|
||||
///
|
||||
/// Each entry allows a sled identified by `sled_id` to connect to
|
||||
/// `/api/sled-uplink` using the given secret token as a bearer credential.
|
||||
#[serde(default)]
|
||||
@@ -40,12 +76,16 @@ pub fn validate_config(config: &GatewayConfig) -> Result<String, String> {
|
||||
|
||||
/// Validate that a project name exists in the given project map.
|
||||
///
|
||||
/// Returns the project's URL on success.
|
||||
/// Returns the project's URL (may be empty for WS-uplink-only projects) on
|
||||
/// success.
|
||||
pub fn validate_project_exists(
|
||||
projects: &BTreeMap<String, ProjectEntry>,
|
||||
name: &str,
|
||||
) -> Result<String, String> {
|
||||
projects.get(name).map(|p| p.url.clone()).ok_or_else(|| {
|
||||
projects
|
||||
.get(name)
|
||||
.map(|p| p.url.clone().unwrap_or_default())
|
||||
.ok_or_else(|| {
|
||||
let available: Vec<&str> = projects.keys().map(|s| s.as_str()).collect();
|
||||
format!(
|
||||
"unknown project '{name}'. Available: {}",
|
||||
@@ -104,8 +144,29 @@ url = "http://localhost:3002"
|
||||
"#;
|
||||
let config: GatewayConfig = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(config.projects.len(), 2);
|
||||
assert_eq!(config.projects["huskies"].url, "http://localhost:3001");
|
||||
assert_eq!(config.projects["robot-studio"].url, "http://localhost:3002");
|
||||
assert_eq!(
|
||||
config.projects["huskies"].url.as_deref(),
|
||||
Some("http://localhost:3001")
|
||||
);
|
||||
assert_eq!(
|
||||
config.projects["robot-studio"].url.as_deref(),
|
||||
Some("http://localhost:3002")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_project_without_url_is_valid() {
|
||||
let toml_str = r#"
|
||||
[projects.ws-only]
|
||||
auth_token = "secret"
|
||||
"#;
|
||||
let config: GatewayConfig = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(config.projects.len(), 1);
|
||||
assert!(config.projects["ws-only"].url.is_none());
|
||||
assert_eq!(
|
||||
config.projects["ws-only"].auth_token.as_deref(),
|
||||
Some("secret")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -127,18 +188,8 @@ url = "http://localhost:3002"
|
||||
#[test]
|
||||
fn validate_config_returns_first_project_name() {
|
||||
let mut projects = BTreeMap::new();
|
||||
projects.insert(
|
||||
"beta".into(),
|
||||
ProjectEntry {
|
||||
url: "http://b".into(),
|
||||
},
|
||||
);
|
||||
projects.insert(
|
||||
"alpha".into(),
|
||||
ProjectEntry {
|
||||
url: "http://a".into(),
|
||||
},
|
||||
);
|
||||
projects.insert("beta".into(), ProjectEntry::with_url("http://b"));
|
||||
projects.insert("alpha".into(), ProjectEntry::with_url("http://a"));
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
@@ -147,14 +198,26 @@ url = "http://localhost:3002"
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_project_exists_succeeds() {
|
||||
fn validate_config_accepts_ws_only_project() {
|
||||
let mut projects = BTreeMap::new();
|
||||
projects.insert(
|
||||
"p1".into(),
|
||||
"ws-only".into(),
|
||||
ProjectEntry {
|
||||
url: "http://p1".into(),
|
||||
url: None,
|
||||
auth_token: Some("secret".into()),
|
||||
},
|
||||
);
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
};
|
||||
assert!(validate_config(&config).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_project_exists_succeeds() {
|
||||
let mut projects = BTreeMap::new();
|
||||
projects.insert("p1".into(), ProjectEntry::with_url("http://p1"));
|
||||
assert_eq!(
|
||||
validate_project_exists(&projects, "p1").unwrap(),
|
||||
"http://p1"
|
||||
@@ -167,6 +230,36 @@ url = "http://localhost:3002"
|
||||
assert!(validate_project_exists(&projects, "missing").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_project_exists_ws_only_returns_empty_url() {
|
||||
let mut projects = BTreeMap::new();
|
||||
projects.insert(
|
||||
"ws".into(),
|
||||
ProjectEntry {
|
||||
url: None,
|
||||
auth_token: Some("tok".into()),
|
||||
},
|
||||
);
|
||||
assert_eq!(validate_project_exists(&projects, "ws").unwrap(), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn project_entry_with_url_constructor() {
|
||||
let e = ProjectEntry::with_url("http://example.com");
|
||||
assert_eq!(e.url.as_deref(), Some("http://example.com"));
|
||||
assert!(e.auth_token.is_none());
|
||||
assert!(e.has_url());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn project_entry_has_url_false_when_none() {
|
||||
let e = ProjectEntry {
|
||||
url: None,
|
||||
auth_token: Some("tok".into()),
|
||||
};
|
||||
assert!(!e.has_url());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn toml_string_escapes_quotes() {
|
||||
assert_eq!(toml_string(r#"a"b"#), r#""a\"b""#);
|
||||
@@ -198,4 +291,28 @@ url = "http://localhost:3002"
|
||||
assert!(content.contains("transport = \"slack\""));
|
||||
assert!(content.contains("slack_bot_token = \"xoxb-123\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_project_entry_with_auth_token() {
|
||||
let entry = ProjectEntry {
|
||||
url: Some("http://a:3001".into()),
|
||||
auth_token: Some("mysecret".into()),
|
||||
};
|
||||
let mut projects = BTreeMap::new();
|
||||
projects.insert("myproj".into(), entry);
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
};
|
||||
let toml_str = toml::to_string_pretty(&config).unwrap();
|
||||
let parsed: GatewayConfig = toml::from_str(&toml_str).unwrap();
|
||||
assert_eq!(
|
||||
parsed.projects["myproj"].url.as_deref(),
|
||||
Some("http://a:3001")
|
||||
);
|
||||
assert_eq!(
|
||||
parsed.projects["myproj"].auth_token.as_deref(),
|
||||
Some("mysecret")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,29 +140,6 @@ pub async fn proxy_mcp_call_sse(
|
||||
.map_err(|e| format!("failed to reach {mcp_url}: {e}"))
|
||||
}
|
||||
|
||||
/// Fetch tools/list from a project's MCP endpoint.
|
||||
pub async fn fetch_tools_list(client: &Client, base_url: &str) -> Result<Value, String> {
|
||||
let mcp_url = format!("{}/mcp", base_url.trim_end_matches('/'));
|
||||
|
||||
let rpc_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
});
|
||||
|
||||
let resp = client
|
||||
.post(&mcp_url)
|
||||
.json(&rpc_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("failed to reach {mcp_url}: {e}"))?;
|
||||
|
||||
resp.json()
|
||||
.await
|
||||
.map_err(|e| format!("invalid JSON from upstream: {e}"))
|
||||
}
|
||||
|
||||
/// Fetch and aggregate pipeline status for a single project URL.
|
||||
pub async fn fetch_one_project_pipeline_status(url: &str, client: &Client) -> Value {
|
||||
let mcp_url = format!("{}/mcp", url.trim_end_matches('/'));
|
||||
|
||||
@@ -28,6 +28,7 @@ use io::Client;
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicI64, Ordering};
|
||||
use tokio::sync::Mutex as TokioMutex;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::mpsc;
|
||||
@@ -49,6 +50,95 @@ pub struct GatewayStatusEvent {
|
||||
pub event: crate::service::events::StoredEvent,
|
||||
}
|
||||
|
||||
// ── Sled connection ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Maximum age, in milliseconds, of a sled heartbeat before the gateway
|
||||
/// considers the connection stale (story 899 AC 3).
|
||||
pub const HEARTBEAT_MAX_AGE_MS: i64 = 30_000;
|
||||
|
||||
/// Default per-request timeout, in milliseconds, for an MCP call proxied over
|
||||
/// a sled uplink WebSocket. Mirrors the existing reqwest-based path which has
|
||||
/// no explicit cap; we set a generous bound so long-running tools (e.g.
|
||||
/// `run_tests`) still complete.
|
||||
pub const MCP_VIA_WS_TIMEOUT_MS: u64 = 1_200_000;
|
||||
|
||||
/// Handle to a sled currently connected to the gateway via the uplink WebSocket.
|
||||
///
|
||||
/// Created by the `/api/sled-uplink` WS handler on connect and stored in
|
||||
/// [`GatewayState::sled_connections`]. The gateway's MCP proxy reads this
|
||||
/// to forward requests over the live connection rather than via HTTP.
|
||||
#[derive(Clone)]
|
||||
pub struct SledConnection {
|
||||
/// Sender side of the channel the WS handler reads to forward outgoing
|
||||
/// frames (e.g. `mcp_request`) to the sled.
|
||||
pub tx: mpsc::UnboundedSender<crate::sled_uplink::UplinkEnvelope>,
|
||||
/// Timestamp (ms since Unix epoch) of the last `heartbeat` frame received
|
||||
/// from this sled. Updated atomically by the WS handler task.
|
||||
pub last_heartbeat_ms: Arc<AtomicI64>,
|
||||
/// In-flight MCP requests waiting for a matching `mcp_response` from this
|
||||
/// sled. Keyed by `req_id`; the oneshot sender is resolved when the
|
||||
/// response arrives.
|
||||
pub in_flight:
|
||||
Arc<TokioMutex<HashMap<String, tokio::sync::oneshot::Sender<serde_json::Value>>>>,
|
||||
}
|
||||
|
||||
impl SledConnection {
|
||||
/// Returns `true` if a heartbeat has been received within the last `max_age_ms`
|
||||
/// milliseconds.
|
||||
pub fn is_alive(&self, max_age_ms: i64) -> bool {
|
||||
let last = self.last_heartbeat_ms.load(Ordering::Relaxed);
|
||||
let now = chrono::Utc::now().timestamp_millis();
|
||||
now - last <= max_age_ms
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy a raw MCP request body to a sled over its uplink WebSocket and
|
||||
/// return the serialised JSON response bytes.
|
||||
///
|
||||
/// Generates a fresh correlation id, registers a oneshot in the connection's
|
||||
/// in-flight map, sends an `mcp_request` envelope, and waits for the
|
||||
/// matching `mcp_response` (or [`MCP_VIA_WS_TIMEOUT_MS`] to elapse).
|
||||
///
|
||||
/// The response payload is serialised back into JSON bytes so callers can
|
||||
/// return it directly to the HTTP client unchanged.
|
||||
pub async fn proxy_mcp_via_ws(
|
||||
conn: &SledConnection,
|
||||
request_bytes: &[u8],
|
||||
) -> Result<Vec<u8>, String> {
|
||||
let req_id = uuid::Uuid::new_v4().to_string();
|
||||
let body_str = std::str::from_utf8(request_bytes)
|
||||
.map_err(|e| format!("non-utf8 mcp request body: {e}"))?
|
||||
.to_string();
|
||||
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
conn.in_flight.lock().await.insert(req_id.clone(), tx);
|
||||
|
||||
let env = crate::sled_uplink::UplinkEnvelope {
|
||||
msg_type: "mcp_request".to_string(),
|
||||
req_id: req_id.clone(),
|
||||
payload: serde_json::json!({ "body": body_str }),
|
||||
};
|
||||
|
||||
if conn.tx.send(env).is_err() {
|
||||
conn.in_flight.lock().await.remove(&req_id);
|
||||
return Err("sled uplink connection closed".to_string());
|
||||
}
|
||||
|
||||
let timeout = std::time::Duration::from_millis(MCP_VIA_WS_TIMEOUT_MS);
|
||||
match tokio::time::timeout(timeout, rx).await {
|
||||
Ok(Ok(response_value)) => {
|
||||
serde_json::to_vec(&response_value).map_err(|e| format!("serialise mcp_response: {e}"))
|
||||
}
|
||||
Ok(Err(_)) => Err("sled response channel dropped".to_string()),
|
||||
Err(_) => {
|
||||
conn.in_flight.lock().await.remove(&req_id);
|
||||
Err(format!(
|
||||
"mcp call to sled timed out after {MCP_VIA_WS_TIMEOUT_MS} ms"
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Error type ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Typed errors returned by `service::gateway` functions.
|
||||
@@ -135,11 +225,18 @@ pub struct GatewayState {
|
||||
/// The Matrix bot's `permission_listener` holds this locked for its lifetime;
|
||||
/// the sled-uplink WS handler sends requests via `perm_tx`.
|
||||
pub perm_rx: Arc<TokioMutex<mpsc::UnboundedReceiver<PermissionForward>>>,
|
||||
/// Reversed sled-token map: token → sled_id.
|
||||
/// Reversed sled-token map: token → project_name (sled_id).
|
||||
///
|
||||
/// Built at startup from [`GatewayConfig::sled_tokens`] (which maps
|
||||
/// sled_id → token). The handler looks up incoming tokens in O(1).
|
||||
/// Built at startup from both [`GatewayConfig::sled_tokens`] AND the
|
||||
/// per-project `auth_token` field (story 899). The handler looks up
|
||||
/// incoming tokens in O(1) to identify the project the sled represents.
|
||||
pub sled_tokens: HashMap<String, String>,
|
||||
/// Live sled connections keyed by project name.
|
||||
///
|
||||
/// Populated by the `/api/sled-uplink` WS handler when a sled authenticates
|
||||
/// and depopulated when it disconnects. MCP proxy functions check here
|
||||
/// first (WS route), falling back to HTTP when no live connection exists.
|
||||
pub sled_connections: Arc<RwLock<HashMap<String, SledConnection>>>,
|
||||
}
|
||||
|
||||
impl GatewayState {
|
||||
@@ -160,11 +257,21 @@ impl GatewayState {
|
||||
.unwrap_or(first_from_config);
|
||||
let (event_tx, _) = tokio::sync::broadcast::channel(EVENT_CHANNEL_CAPACITY);
|
||||
let (perm_tx, perm_rx) = mpsc::unbounded_channel::<PermissionForward>();
|
||||
let sled_tokens: HashMap<String, String> = gateway_config
|
||||
|
||||
// Build token→project_name map from two sources:
|
||||
// 1. Legacy top-level [sled_tokens] section (sled_id → token, reversed)
|
||||
// 2. Per-project auth_token fields (project_name → token, reversed)
|
||||
let mut sled_tokens: HashMap<String, String> = gateway_config
|
||||
.sled_tokens
|
||||
.iter()
|
||||
.map(|(sled_id, token)| (token.clone(), sled_id.clone()))
|
||||
.collect();
|
||||
for (project_name, entry) in &gateway_config.projects {
|
||||
if let Some(ref token) = entry.auth_token {
|
||||
sled_tokens.insert(token.clone(), project_name.clone());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
projects: Arc::new(RwLock::new(gateway_config.projects)),
|
||||
active_project: Arc::new(RwLock::new(first)),
|
||||
@@ -178,26 +285,78 @@ impl GatewayState {
|
||||
perm_tx,
|
||||
perm_rx: Arc::new(TokioMutex::new(perm_rx)),
|
||||
sled_tokens,
|
||||
sled_connections: Arc::new(RwLock::new(HashMap::new())),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the URL of the currently active project.
|
||||
/// Get the URL of the currently active project, if one is configured.
|
||||
///
|
||||
/// Returns `Err` when the active project has no URL configured (WS-uplink
|
||||
/// only) or the project name is not found.
|
||||
pub async fn active_url(&self) -> Result<String, Error> {
|
||||
let name = self.active_project.read().await.clone();
|
||||
self.projects
|
||||
.read()
|
||||
.await
|
||||
.get(&name)
|
||||
.map(|p| p.url.clone())
|
||||
.and_then(|p| p.url.clone())
|
||||
.ok_or_else(|| {
|
||||
Error::ProjectNotFound(format!("active project '{name}' not found in config"))
|
||||
Error::ProjectNotFound(format!(
|
||||
"active project '{name}' has no URL configured \
|
||||
(use sled-uplink WS or add url to projects.toml)"
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
/// Register a live sled connection for the given project.
|
||||
pub async fn register_sled_connection(&self, project_name: String, conn: SledConnection) {
|
||||
self.sled_connections
|
||||
.write()
|
||||
.await
|
||||
.insert(project_name, conn);
|
||||
}
|
||||
|
||||
/// Remove the sled connection for the given project (on disconnect).
|
||||
pub async fn deregister_sled_connection(&self, project_name: &str) {
|
||||
self.sled_connections.write().await.remove(project_name);
|
||||
}
|
||||
|
||||
/// Look up the live sled connection for the active project, returning a
|
||||
/// clone if one exists and has a recent heartbeat.
|
||||
///
|
||||
/// Returns `None` when no sled has connected for this project or when its
|
||||
/// heartbeat is overdue.
|
||||
pub async fn active_sled_connection(&self) -> Option<SledConnection> {
|
||||
let name = self.active_project.read().await.clone();
|
||||
let conn = self.sled_connections.read().await.get(&name).cloned()?;
|
||||
if conn.is_alive(HEARTBEAT_MAX_AGE_MS) {
|
||||
Some(conn)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy an MCP request to the active project, preferring the live
|
||||
/// sled-uplink WebSocket when available (story 899 AC 2) and falling
|
||||
/// back to HTTP otherwise.
|
||||
///
|
||||
/// Returns the raw response body bytes ready to be relayed to the caller.
|
||||
pub async fn proxy_active_mcp(&self, bytes: &[u8]) -> Result<Vec<u8>, String> {
|
||||
if let Some(conn) = self.active_sled_connection().await {
|
||||
return proxy_mcp_via_ws(&conn, bytes).await;
|
||||
}
|
||||
let url = self.active_url().await.map_err(|e| e.to_string())?;
|
||||
crate::slog!(
|
||||
"[gateway] MCP proxy: WS uplink unavailable, falling back to HTTP \
|
||||
(deprecated, will be removed once all sleds are WS-only)"
|
||||
);
|
||||
crate::service::gateway::io::proxy_mcp_call(&self.client, &url, bytes).await
|
||||
}
|
||||
}
|
||||
|
||||
// ── Public API ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Switch the active project. Returns the project's URL on success.
|
||||
/// Switch the active project. Returns the project's URL (empty for WS-only projects).
|
||||
///
|
||||
/// Writes the new active project to the CRDT `gateway_config.active_project`
|
||||
/// register (LWW — last write wins) so the selection is persisted across
|
||||
@@ -358,7 +517,7 @@ pub async fn add_project(state: &GatewayState, name: &str, url: &str) -> Result<
|
||||
"project '{name}' already exists"
|
||||
)));
|
||||
}
|
||||
projects.insert(name.clone(), ProjectEntry { url: url.clone() });
|
||||
projects.insert(name.clone(), ProjectEntry::with_url(&url));
|
||||
}
|
||||
|
||||
let snapshot = state.projects.read().await.clone();
|
||||
@@ -441,7 +600,7 @@ pub async fn init_project(
|
||||
"project '{n}' is already registered. Choose a different name or use switch_project."
|
||||
)));
|
||||
}
|
||||
projects.insert(n.to_string(), ProjectEntry { url: u.to_string() });
|
||||
projects.insert(n.to_string(), ProjectEntry::with_url(u));
|
||||
io::save_config(&projects, &state.config_dir).await;
|
||||
crate::slog!("[gateway] init_project: registered '{n}' ({u})");
|
||||
Some(n.to_string())
|
||||
@@ -494,7 +653,7 @@ pub async fn save_bot_config_and_restart(state: &GatewayState, content: &str) ->
|
||||
.read()
|
||||
.await
|
||||
.iter()
|
||||
.map(|(name, entry)| (name.clone(), entry.url.clone()))
|
||||
.filter_map(|(name, entry)| entry.url.as_ref().map(|u| (name.clone(), u.clone())))
|
||||
.collect();
|
||||
|
||||
let (new_handle, new_shutdown_tx) = io::spawn_gateway_bot(
|
||||
@@ -523,12 +682,7 @@ mod tests {
|
||||
fn make_config(names: &[(&str, &str)]) -> GatewayConfig {
|
||||
let mut projects = BTreeMap::new();
|
||||
for (name, url) in names {
|
||||
projects.insert(
|
||||
name.to_string(),
|
||||
ProjectEntry {
|
||||
url: url.to_string(),
|
||||
},
|
||||
);
|
||||
projects.insert(name.to_string(), ProjectEntry::with_url(*url));
|
||||
}
|
||||
GatewayConfig {
|
||||
projects,
|
||||
@@ -584,6 +738,24 @@ mod tests {
|
||||
assert_eq!(url, "http://my:3001");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn active_url_fails_for_ws_only_project() {
|
||||
let mut projects = BTreeMap::new();
|
||||
projects.insert(
|
||||
"ws-proj".into(),
|
||||
ProjectEntry {
|
||||
url: None,
|
||||
auth_token: Some("tok".into()),
|
||||
},
|
||||
);
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
};
|
||||
let state = GatewayState::new(config, PathBuf::from("."), 3000).unwrap();
|
||||
assert!(state.active_url().await.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_display_variants() {
|
||||
assert!(
|
||||
@@ -682,4 +854,47 @@ mod tests {
|
||||
let result = init_project(&state, dir.path().to_str().unwrap(), None, None).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sled_connection_registration_and_lookup() {
|
||||
let config = make_config(&[("myproj", "http://myproj:3001")]);
|
||||
let state = GatewayState::new(config, PathBuf::new(), 3000).unwrap();
|
||||
|
||||
let (tx, _rx) = mpsc::unbounded_channel();
|
||||
let conn = SledConnection {
|
||||
tx,
|
||||
last_heartbeat_ms: Arc::new(AtomicI64::new(chrono::Utc::now().timestamp_millis())),
|
||||
in_flight: Arc::new(TokioMutex::new(HashMap::new())),
|
||||
};
|
||||
|
||||
state
|
||||
.register_sled_connection("myproj".to_string(), conn)
|
||||
.await;
|
||||
assert!(state.sled_connections.read().await.contains_key("myproj"));
|
||||
|
||||
state.deregister_sled_connection("myproj").await;
|
||||
assert!(!state.sled_connections.read().await.contains_key("myproj"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_token_in_project_entry_populates_sled_tokens_map() {
|
||||
let mut projects = BTreeMap::new();
|
||||
projects.insert(
|
||||
"huskies".into(),
|
||||
ProjectEntry {
|
||||
url: Some("http://huskies:3001".into()),
|
||||
auth_token: Some("secret-token".into()),
|
||||
},
|
||||
);
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
};
|
||||
let state = GatewayState::new(config, PathBuf::new(), 3000).unwrap();
|
||||
assert_eq!(
|
||||
state.sled_tokens.get("secret-token").map(|s| s.as_str()),
|
||||
Some("huskies"),
|
||||
"Per-project auth_token must be in reversed sled_tokens map"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
+250
-27
@@ -1,6 +1,5 @@
|
||||
//! Sled uplink — background task that maintains a WebSocket connection from a
|
||||
//! sled (standard huskies instance) to an upstream gateway for permission
|
||||
//! request forwarding.
|
||||
//! sled (standard huskies instance) to an upstream gateway.
|
||||
//!
|
||||
//! When `HUSKIES_UPSTREAM_GATEWAY` is set (or `--upstream-gateway` is passed
|
||||
//! on the CLI), this module spawns a task that:
|
||||
@@ -10,10 +9,17 @@
|
||||
//! auto-denying requests with "no interactive session".
|
||||
//! 2. Maintains a persistent WebSocket connection to the gateway's
|
||||
//! `/api/sled-uplink` endpoint.
|
||||
//! 3. Forwards each [`PermissionForward`] as a `perm_request` JSON envelope.
|
||||
//! 4. Awaits the matching `perm_response` envelope from the gateway.
|
||||
//! 5. Reconnects with exponential back-off on connection drop, fail-closing
|
||||
//! any in-flight requests with [`PermissionDecision::Deny`].
|
||||
//! 3. Sends an `identity` frame announcing the sled's project name + auth
|
||||
//! token immediately after connect (story 899 Phase 2).
|
||||
//! 4. Sends a `heartbeat` frame every [`HEARTBEAT_INTERVAL_SECS`] so the
|
||||
//! gateway can mark the connection live without extra HTTP polls.
|
||||
//! 5. Forwards each [`PermissionForward`] as a `perm_request` JSON envelope
|
||||
//! and awaits the matching `perm_response`.
|
||||
//! 6. Handles inbound `mcp_request` frames by replaying the MCP JSON-RPC
|
||||
//! body against the sled's own local `/mcp` HTTP endpoint and returning
|
||||
//! the response as an `mcp_response` frame.
|
||||
//! 7. Reconnects with exponential back-off on connection drop, fail-closing
|
||||
//! any in-flight permission requests with [`PermissionDecision::Deny`].
|
||||
|
||||
use crate::http::context::{PermissionDecision, PermissionForward};
|
||||
use crate::services::Services;
|
||||
@@ -25,12 +31,35 @@ use std::sync::Arc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_tungstenite::tungstenite::Message as WsMessage;
|
||||
|
||||
/// Configuration for [`spawn_uplink_task`].
|
||||
///
|
||||
/// Bundled into a single struct so `main.rs` can build it once and pass it in
|
||||
/// without proliferating positional arguments.
|
||||
pub struct UplinkConfig {
|
||||
/// WebSocket URL of the upstream gateway's `/api/sled-uplink` endpoint.
|
||||
/// Includes the `?token=` query parameter for auth.
|
||||
pub upstream_url: String,
|
||||
/// Project name this sled identifies as. Sent in the `identity` frame
|
||||
/// after WS connect (story 899 AC 5).
|
||||
pub project_name: String,
|
||||
/// HTTP base URL of this sled's own MCP endpoint (e.g.
|
||||
/// `http://127.0.0.1:3001/mcp`). Used to replay `mcp_request` frames
|
||||
/// received from the gateway against the local MCP handler.
|
||||
pub local_mcp_url: String,
|
||||
}
|
||||
|
||||
// ── Back-off constants ────────────────────────────────────────────────────────
|
||||
|
||||
const INITIAL_BACKOFF_SECS: u64 = 1;
|
||||
const MAX_BACKOFF_SECS: u64 = 60;
|
||||
const BACKOFF_MULTIPLIER: u64 = 2;
|
||||
|
||||
/// Interval between `heartbeat` frames sent to the gateway (story 899 AC 3).
|
||||
///
|
||||
/// Must be shorter than `crate::service::gateway::HEARTBEAT_MAX_AGE_MS` so
|
||||
/// the gateway never marks a healthy sled as stale.
|
||||
pub const HEARTBEAT_INTERVAL_SECS: u64 = 10;
|
||||
|
||||
// ── Wire protocol ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Extensible JSON envelope for all sled↔gateway uplink messages.
|
||||
@@ -54,15 +83,21 @@ pub struct UplinkEnvelope {
|
||||
|
||||
/// Spawn the sled uplink background task.
|
||||
///
|
||||
/// Does nothing when `upstream_url` is empty. When active, the task holds
|
||||
/// `services.perm_rx` locked for its lifetime (preventing auto-deny in
|
||||
/// Does nothing when `config.upstream_url` is empty (AC 8 — sleds without
|
||||
/// an upstream configured continue to work unchanged). When active, the task
|
||||
/// holds `services.perm_rx` locked for its lifetime (preventing auto-deny in
|
||||
/// `tool_prompt_permission`) and forwards all permission requests to the
|
||||
/// gateway. Reconnects automatically with exponential back-off.
|
||||
pub fn spawn_uplink_task(upstream_url: String, services: Arc<Services>) {
|
||||
if upstream_url.is_empty() {
|
||||
pub fn spawn_uplink_task(config: UplinkConfig, services: Arc<Services>) {
|
||||
if config.upstream_url.is_empty() {
|
||||
return;
|
||||
}
|
||||
slog!("[uplink] Spawning sled uplink task (gateway={upstream_url})");
|
||||
let UplinkConfig {
|
||||
upstream_url,
|
||||
project_name,
|
||||
local_mcp_url,
|
||||
} = config;
|
||||
slog!("[uplink] Spawning sled uplink task (gateway={upstream_url}, project={project_name})");
|
||||
tokio::spawn(async move {
|
||||
// Acquire perm_rx for this task's entire lifetime. While this lock is
|
||||
// held, try_lock() inside tool_prompt_permission fails — meaning
|
||||
@@ -70,9 +105,19 @@ pub fn spawn_uplink_task(upstream_url: String, services: Arc<Services>) {
|
||||
let mut perm_rx = services.perm_rx.lock().await;
|
||||
slog!("[uplink] Acquired perm_rx; maintaining gateway connection");
|
||||
|
||||
let http = reqwest::Client::new();
|
||||
|
||||
let mut backoff = INITIAL_BACKOFF_SECS;
|
||||
loop {
|
||||
match run_uplink_session(&upstream_url, &mut perm_rx).await {
|
||||
match run_uplink_session(
|
||||
&upstream_url,
|
||||
&project_name,
|
||||
&local_mcp_url,
|
||||
&http,
|
||||
&mut perm_rx,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
slog!("[uplink] Connection closed cleanly; reconnecting in {backoff}s");
|
||||
}
|
||||
@@ -90,10 +135,14 @@ pub fn spawn_uplink_task(upstream_url: String, services: Arc<Services>) {
|
||||
|
||||
// ── Private helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Run a single uplink session: connect, pump messages bidirectionally until
|
||||
/// disconnect or channel close, then fail-close any in-flight requests.
|
||||
/// Run a single uplink session: connect, send identity frame, pump messages
|
||||
/// bidirectionally until disconnect or channel close, then fail-close any
|
||||
/// in-flight requests.
|
||||
async fn run_uplink_session(
|
||||
url: &str,
|
||||
project_name: &str,
|
||||
local_mcp_url: &str,
|
||||
http: &reqwest::Client,
|
||||
perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver<PermissionForward>,
|
||||
) -> Result<(), String> {
|
||||
let (ws_stream, _) = tokio_tungstenite::connect_async(url)
|
||||
@@ -102,9 +151,31 @@ async fn run_uplink_session(
|
||||
slog!("[uplink] Connected to gateway uplink endpoint");
|
||||
|
||||
let (mut ws_sink, mut ws_rx) = ws_stream.split();
|
||||
|
||||
// ── Identity handshake (story 899 AC 5) ─────────────────────────────
|
||||
let identity = UplinkEnvelope {
|
||||
msg_type: "identity".to_string(),
|
||||
req_id: "identity".to_string(),
|
||||
payload: serde_json::json!({ "project": project_name }),
|
||||
};
|
||||
let identity_text =
|
||||
serde_json::to_string(&identity).map_err(|e| format!("serialise identity: {e}"))?;
|
||||
ws_sink
|
||||
.send(WsMessage::Text(identity_text.into()))
|
||||
.await
|
||||
.map_err(|e| format!("send identity: {e}"))?;
|
||||
|
||||
let mut in_flight: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();
|
||||
|
||||
let result = pump_messages(&mut ws_sink, &mut ws_rx, perm_rx, &mut in_flight).await;
|
||||
let result = pump_messages(
|
||||
&mut ws_sink,
|
||||
&mut ws_rx,
|
||||
perm_rx,
|
||||
&mut in_flight,
|
||||
local_mcp_url,
|
||||
http,
|
||||
)
|
||||
.await;
|
||||
fail_close_all(&mut in_flight);
|
||||
result
|
||||
}
|
||||
@@ -118,9 +189,44 @@ async fn pump_messages(
|
||||
),
|
||||
perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver<PermissionForward>,
|
||||
in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>,
|
||||
local_mcp_url: &str,
|
||||
http: &reqwest::Client,
|
||||
) -> Result<(), String> {
|
||||
// Channel for spawned mcp_request handlers to deliver their finished
|
||||
// mcp_response frames back to the WS writer.
|
||||
let (mcp_resp_tx, mut mcp_resp_rx) = tokio::sync::mpsc::unbounded_channel::<UplinkEnvelope>();
|
||||
|
||||
let mut heartbeat =
|
||||
tokio::time::interval(std::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
|
||||
// Skip the immediate tick so the first heartbeat fires after the interval,
|
||||
// not right after connect.
|
||||
heartbeat.tick().await;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Heartbeat tick (story 899 AC 3).
|
||||
_ = heartbeat.tick() => {
|
||||
let env = UplinkEnvelope {
|
||||
msg_type: "heartbeat".to_string(),
|
||||
req_id: String::new(),
|
||||
payload: serde_json::Value::Null,
|
||||
};
|
||||
let text = serde_json::to_string(&env)
|
||||
.map_err(|e| format!("serialise heartbeat: {e}"))?;
|
||||
if ws_sink.send(WsMessage::Text(text.into())).await.is_err() {
|
||||
return Err("WS send failed (heartbeat)".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Completed mcp_response from a spawned handler — forward to gateway.
|
||||
Some(env) = mcp_resp_rx.recv() => {
|
||||
let text = serde_json::to_string(&env)
|
||||
.map_err(|e| format!("serialise mcp_response: {e}"))?;
|
||||
if ws_sink.send(WsMessage::Text(text.into())).await.is_err() {
|
||||
return Err("WS send failed (mcp_response)".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// New permission request from the MCP layer.
|
||||
maybe_fwd = perm_rx.recv() => {
|
||||
match maybe_fwd {
|
||||
@@ -162,7 +268,7 @@ async fn pump_messages(
|
||||
return Err("Gateway sent Close frame".to_string());
|
||||
}
|
||||
Some(Ok(WsMessage::Text(text))) => {
|
||||
on_gateway_text(&text, in_flight);
|
||||
on_gateway_text(&text, in_flight, local_mcp_url, http, &mcp_resp_tx);
|
||||
}
|
||||
Some(Ok(WsMessage::Ping(data))) => {
|
||||
let _ = ws_sink.send(WsMessage::Pong(data)).await;
|
||||
@@ -174,19 +280,85 @@ async fn pump_messages(
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse an incoming gateway text frame and resolve any matching in-flight request.
|
||||
/// Parse an incoming gateway text frame and dispatch it to the appropriate
|
||||
/// handler. `perm_response` resolves a waiting in-flight permission request;
|
||||
/// `mcp_request` spawns a task that replays the JSON-RPC body against the
|
||||
/// sled's local `/mcp` HTTP endpoint and forwards the result back to the
|
||||
/// gateway as an `mcp_response`.
|
||||
fn on_gateway_text(
|
||||
text: &str,
|
||||
in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>,
|
||||
local_mcp_url: &str,
|
||||
http: &reqwest::Client,
|
||||
mcp_resp_tx: &tokio::sync::mpsc::UnboundedSender<UplinkEnvelope>,
|
||||
) {
|
||||
let Ok(env) = serde_json::from_str::<UplinkEnvelope>(text) else {
|
||||
return;
|
||||
};
|
||||
if env.msg_type == "perm_response" {
|
||||
resolve_perm_response(env, in_flight);
|
||||
match env.msg_type.as_str() {
|
||||
"perm_response" => resolve_perm_response(env, in_flight),
|
||||
"mcp_request" => spawn_mcp_request_handler(env, local_mcp_url, http, mcp_resp_tx),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Replay the gateway's `mcp_request` body against the sled's local MCP HTTP
|
||||
/// endpoint in a spawned task and forward the response back as an
|
||||
/// `mcp_response` envelope. The payload is expected to contain a `body`
|
||||
/// string field holding the raw JSON-RPC bytes.
|
||||
fn spawn_mcp_request_handler(
|
||||
env: UplinkEnvelope,
|
||||
local_mcp_url: &str,
|
||||
http: &reqwest::Client,
|
||||
mcp_resp_tx: &tokio::sync::mpsc::UnboundedSender<UplinkEnvelope>,
|
||||
) {
|
||||
let req_id = env.req_id.clone();
|
||||
let body = env
|
||||
.payload
|
||||
.get("body")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let mcp_url = local_mcp_url.to_string();
|
||||
let http = http.clone();
|
||||
let mcp_resp_tx = mcp_resp_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
let response_value = match http
|
||||
.post(&mcp_url)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => match r.json::<serde_json::Value>().await {
|
||||
Ok(v) => v,
|
||||
Err(e) => serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": null,
|
||||
"error": {
|
||||
"code": -32603,
|
||||
"message": format!("local mcp invalid JSON: {e}"),
|
||||
}
|
||||
}),
|
||||
},
|
||||
Err(e) => serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": null,
|
||||
"error": {
|
||||
"code": -32603,
|
||||
"message": format!("local mcp request failed: {e}"),
|
||||
}
|
||||
}),
|
||||
};
|
||||
let resp = UplinkEnvelope {
|
||||
msg_type: "mcp_response".to_string(),
|
||||
req_id,
|
||||
payload: response_value,
|
||||
};
|
||||
let _ = mcp_resp_tx.send(resp);
|
||||
});
|
||||
}
|
||||
|
||||
/// Map a `perm_response` envelope to a [`PermissionDecision`] and wake the
|
||||
/// waiting MCP call.
|
||||
fn resolve_perm_response(
|
||||
@@ -321,9 +493,13 @@ mod tests {
|
||||
#[test]
|
||||
fn on_gateway_text_ignores_unknown_type() {
|
||||
let mut in_flight: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();
|
||||
let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
on_gateway_text(
|
||||
r#"{"type":"future_type","req_id":"x","payload":{}}"#,
|
||||
&mut in_flight,
|
||||
"http://127.0.0.1:0/mcp",
|
||||
&reqwest::Client::new(),
|
||||
&tx,
|
||||
);
|
||||
assert!(in_flight.is_empty());
|
||||
}
|
||||
@@ -331,7 +507,14 @@ mod tests {
|
||||
#[test]
|
||||
fn on_gateway_text_ignores_invalid_json() {
|
||||
let mut in_flight: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();
|
||||
on_gateway_text("not-json", &mut in_flight); // must not panic
|
||||
let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
on_gateway_text(
|
||||
"not-json",
|
||||
&mut in_flight,
|
||||
"http://127.0.0.1:0/mcp",
|
||||
&reqwest::Client::new(),
|
||||
&tx,
|
||||
); // must not panic
|
||||
assert!(in_flight.is_empty());
|
||||
}
|
||||
|
||||
@@ -351,7 +534,14 @@ mod tests {
|
||||
permission_timeout_secs: 120,
|
||||
});
|
||||
// Empty URL → noop; if it panicked or blocked the test would fail.
|
||||
spawn_uplink_task(String::new(), services);
|
||||
spawn_uplink_task(
|
||||
UplinkConfig {
|
||||
upstream_url: String::new(),
|
||||
project_name: "test".to_string(),
|
||||
local_mcp_url: "http://127.0.0.1:0/mcp".to_string(),
|
||||
},
|
||||
services,
|
||||
);
|
||||
}
|
||||
|
||||
// ── AC 11: permission approved via uplink ────────────────────────
|
||||
@@ -364,10 +554,23 @@ mod tests {
|
||||
let port = listener.local_addr().unwrap().port();
|
||||
let url = format!("ws://127.0.0.1:{port}");
|
||||
|
||||
// Mock gateway: accept one connection, receive perm_request, reply approved.
|
||||
// Mock gateway: accept one connection, consume identity, receive
|
||||
// perm_request, reply approved.
|
||||
let gw_task = tokio::spawn(async move {
|
||||
let (tcp, _) = listener.accept().await.unwrap();
|
||||
let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
|
||||
|
||||
// First frame: identity (story 899 AC 5).
|
||||
let msg = ws.next().await.unwrap().unwrap();
|
||||
let text = match msg {
|
||||
WsMessage::Text(t) => t.to_string(),
|
||||
other => panic!("expected identity Text; got {other:?}"),
|
||||
};
|
||||
let env: UplinkEnvelope = serde_json::from_str(&text).unwrap();
|
||||
assert_eq!(env.msg_type, "identity");
|
||||
assert_eq!(env.payload["project"], "test-proj");
|
||||
|
||||
// Next frame: perm_request.
|
||||
let msg = ws.next().await.unwrap().unwrap();
|
||||
let text = match msg {
|
||||
WsMessage::Text(t) => t.to_string(),
|
||||
@@ -402,7 +605,14 @@ mod tests {
|
||||
permission_timeout_secs: 120,
|
||||
});
|
||||
|
||||
spawn_uplink_task(url, Arc::clone(&services));
|
||||
spawn_uplink_task(
|
||||
UplinkConfig {
|
||||
upstream_url: url,
|
||||
project_name: "test-proj".to_string(),
|
||||
local_mcp_url: "http://127.0.0.1:0/mcp".to_string(),
|
||||
},
|
||||
Arc::clone(&services),
|
||||
);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
@@ -441,12 +651,14 @@ mod tests {
|
||||
let conn_count2 = StdArc::clone(&conn_count);
|
||||
|
||||
tokio::spawn(async move {
|
||||
// First connection: receive the request then immediately drop (simulates network failure).
|
||||
// First connection: consume identity + the request frame, then
|
||||
// drop without replying (simulates network failure).
|
||||
{
|
||||
let (tcp, _) = listener.accept().await.unwrap();
|
||||
let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
|
||||
conn_count2.fetch_add(1, Ordering::SeqCst);
|
||||
let _ = ws.next().await; // consume one frame
|
||||
let _ = ws.next().await; // identity frame
|
||||
let _ = ws.next().await; // perm_request frame
|
||||
drop(ws); // close without replying → fail-close on sled side
|
||||
}
|
||||
// Second connection: approve the next request.
|
||||
@@ -454,7 +666,10 @@ mod tests {
|
||||
let (tcp, _) = listener.accept().await.unwrap();
|
||||
let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
|
||||
conn_count2.fetch_add(1, Ordering::SeqCst);
|
||||
if let Some(Ok(WsMessage::Text(text))) = ws.next().await {
|
||||
// Consume identity frame.
|
||||
let _ = ws.next().await;
|
||||
// Drain non-perm frames (heartbeat etc.) until perm_request arrives.
|
||||
while let Some(Ok(WsMessage::Text(text))) = ws.next().await {
|
||||
let env: UplinkEnvelope = serde_json::from_str(&text).unwrap();
|
||||
if env.msg_type == "perm_request" {
|
||||
let resp = UplinkEnvelope {
|
||||
@@ -467,6 +682,7 @@ mod tests {
|
||||
serde_json::to_string(&resp).unwrap().into(),
|
||||
))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -486,7 +702,14 @@ mod tests {
|
||||
permission_timeout_secs: 120,
|
||||
});
|
||||
|
||||
spawn_uplink_task(url, Arc::clone(&services));
|
||||
spawn_uplink_task(
|
||||
UplinkConfig {
|
||||
upstream_url: url,
|
||||
project_name: "test-proj".to_string(),
|
||||
local_mcp_url: "http://127.0.0.1:0/mcp".to_string(),
|
||||
},
|
||||
Arc::clone(&services),
|
||||
);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
|
||||
// First request is sent on the connection that drops → denied.
|
||||
|
||||
Reference in New Issue
Block a user