From cd214d7246162006fff6a47df4513479244f54ea Mon Sep 17 00:00:00 2001 From: dave Date: Tue, 12 May 2026 23:11:34 +0000 Subject: [PATCH] huskies: merge 899 --- server/src/gateway/mod.rs | 2 +- server/src/gateway/tests.rs | 235 +++++++++++++++++++--- server/src/http/gateway/mcp.rs | 58 ++++-- server/src/http/gateway/websocket.rs | 288 ++++++++++++++++++++------- server/src/main.rs | 20 +- server/src/service/gateway/config.rs | 171 +++++++++++++--- server/src/service/gateway/io.rs | 23 --- server/src/service/gateway/mod.rs | 249 +++++++++++++++++++++-- server/src/sled_uplink.rs | 277 +++++++++++++++++++++++--- 9 files changed, 1105 insertions(+), 218 deletions(-) diff --git a/server/src/gateway/mod.rs b/server/src/gateway/mod.rs index 39488d97..d4c60423 100644 --- a/server/src/gateway/mod.rs +++ b/server/src/gateway/mod.rs @@ -119,7 +119,7 @@ pub async fn run(config_path: &Path, port: u16) -> Result<(), std::io::Error> { .read() .await .iter() - .map(|(name, entry)| (name.clone(), entry.url.clone())) + .filter_map(|(name, entry)| entry.url.as_ref().map(|u| (name.clone(), u.clone()))) .collect(); let (bot_abort, bot_shutdown_tx) = gateway::io::spawn_gateway_bot( &config_dir, diff --git a/server/src/gateway/tests.rs b/server/src/gateway/tests.rs index f2372707..c8b448fb 100644 --- a/server/src/gateway/tests.rs +++ b/server/src/gateway/tests.rs @@ -7,12 +7,7 @@ use std::path::PathBuf; fn make_test_state() -> Arc { let mut projects = BTreeMap::new(); - projects.insert( - "test".into(), - ProjectEntry { - url: "http://test:3001".into(), - }, - ); + projects.insert("test".into(), ProjectEntry::with_url("http://test:3001")); let config = GatewayConfig { projects, sled_tokens: BTreeMap::new(), @@ -367,9 +362,7 @@ async fn init_project_registers_in_projects_toml_when_name_and_url_given() { let mut projects = BTreeMap::new(); projects.insert( "existing".into(), - ProjectEntry { - url: "http://existing:3001".into(), - }, + ProjectEntry::with_url("http://existing:3001"), ); let config = GatewayConfig { projects, @@ -388,19 +381,17 @@ async fn init_project_registers_in_projects_toml_when_name_and_url_given() { let projects = state.projects.read().await; assert!(projects.contains_key("new-project")); - assert_eq!(projects["new-project"].url, "http://new-project:3002"); + assert_eq!( + projects["new-project"].url.as_deref(), + Some("http://new-project:3002") + ); } #[tokio::test] async fn init_project_duplicate_name_returns_error() { let dir = tempfile::tempdir().unwrap(); let mut projects = BTreeMap::new(); - projects.insert( - "taken".into(), - ProjectEntry { - url: "http://taken:3001".into(), - }, - ); + projects.insert("taken".into(), ProjectEntry::with_url("http://taken:3001")); let config = GatewayConfig { projects, sled_tokens: BTreeMap::new(), @@ -452,7 +443,7 @@ async fn init_project_then_wizard_status_integration() { tokio::time::sleep(std::time::Duration::from_millis(10)).await; let mut projects = BTreeMap::new(); - projects.insert("mock-project".into(), ProjectEntry { url: mock_url }); + projects.insert("mock-project".into(), ProjectEntry::with_url(mock_url)); let config = GatewayConfig { projects, sled_tokens: BTreeMap::new(), @@ -972,12 +963,7 @@ async fn gateway_mcp_sse_proxy_streams_progress_and_final_response() { .await; let mut projects = BTreeMap::new(); - projects.insert( - "sled".to_string(), - ProjectEntry { - url: mock_sled.url(), - }, - ); + projects.insert("sled".to_string(), ProjectEntry::with_url(mock_sled.url())); let config = GatewayConfig { projects, sled_tokens: BTreeMap::new(), @@ -1068,12 +1054,7 @@ async fn gateway_mcp_post_without_sse_returns_plain_json() { .await; let mut projects = BTreeMap::new(); - projects.insert( - "sled".to_string(), - ProjectEntry { - url: mock_sled.url(), - }, - ); + projects.insert("sled".to_string(), ProjectEntry::with_url(mock_sled.url())); let config = GatewayConfig { projects, sled_tokens: BTreeMap::new(), @@ -1118,3 +1099,199 @@ async fn gateway_mcp_post_without_sse_returns_plain_json() { "Expected result in plain JSON response" ); } + +// ── Story 899: MCP-over-WS uplink integration ──────────────────────────────── + +/// Build a `SledConnection` plus a spawned "mock sled" task that: +/// +/// * Reads outbound `mcp_request` envelopes off the connection's channel. +/// * Invokes the supplied closure to build a `result` value for each request. +/// * Resolves the matching in-flight oneshot directly (the same effect the WS +/// handler has when it receives an `mcp_response` from a real sled). +/// +/// Returns the registered `SledConnection`. +fn spawn_mock_sled(handler: F) -> crate::service::gateway::SledConnection +where + F: Fn(&serde_json::Value) -> serde_json::Value + Send + Sync + 'static, +{ + use crate::service::gateway::SledConnection; + use std::sync::atomic::AtomicI64; + + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + let last_heartbeat_ms = Arc::new(AtomicI64::new(chrono::Utc::now().timestamp_millis())); + let in_flight = Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())); + + let conn = SledConnection { + tx, + last_heartbeat_ms, + in_flight: Arc::clone(&in_flight), + }; + let handler = Arc::new(handler); + tokio::spawn(async move { + while let Some(env) = rx.recv().await { + if env.msg_type != "mcp_request" { + continue; + } + let body_str = env + .payload + .get("body") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let body_json: serde_json::Value = match serde_json::from_str(&body_str) { + Ok(v) => v, + Err(_) => serde_json::Value::Null, + }; + let result = handler(&body_json); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": body_json.get("id").cloned().unwrap_or(serde_json::Value::Null), + "result": result, + }); + if let Some(oneshot_tx) = in_flight.lock().await.remove(&env.req_id) { + let _ = oneshot_tx.send(response); + } + } + }); + conn +} + +/// AC 9: gateway switches active project, calls tools/list and tools/call +/// against a sled connected only via WS uplink, gets correct responses. +#[tokio::test] +async fn ws_only_sled_handles_tools_list_and_tools_call() { + use crate::service::gateway::ProjectEntry; + + // Project entry with NO url — WS-only. + let mut projects = BTreeMap::new(); + projects.insert( + "ws-only".into(), + ProjectEntry { + url: None, + auth_token: Some("secret".into()), + }, + ); + let config = GatewayConfig { + projects, + sled_tokens: BTreeMap::new(), + }; + let state = Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap()); + + let conn = spawn_mock_sled(|body| { + let method = body.get("method").and_then(|m| m.as_str()).unwrap_or(""); + match method { + "tools/list" => serde_json::json!({ + "tools": [ + { "name": "my_tool", "description": "test" } + ] + }), + "tools/call" => serde_json::json!({ + "content": [{ "type": "text", "text": "called ok" }] + }), + _ => serde_json::json!({ "echo": method }), + } + }); + state + .register_sled_connection("ws-only".to_string(), conn) + .await; + + // tools/list via proxy_active_mcp. + let body = serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list", + "params": {} + })) + .unwrap(); + let resp = state.proxy_active_mcp(&body).await.expect("ws proxy works"); + let resp_json: serde_json::Value = serde_json::from_slice(&resp).unwrap(); + assert_eq!( + resp_json["result"]["tools"][0]["name"], "my_tool", + "tools/list response must come from the WS-connected sled, not HTTP" + ); + + // tools/call via proxy_active_mcp. + let body = serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { "name": "my_tool", "arguments": {} } + })) + .unwrap(); + let resp = state.proxy_active_mcp(&body).await.expect("ws proxy works"); + let resp_json: serde_json::Value = serde_json::from_slice(&resp).unwrap(); + assert_eq!( + resp_json["result"]["content"][0]["text"], "called ok", + "tools/call response must come from the WS-connected sled, not HTTP" + ); +} + +/// AC 10: two sleds connected to one gateway concurrently; gateway routes +/// calls to the right sled based on active project. +#[tokio::test] +async fn two_concurrent_sleds_are_routed_by_active_project() { + use crate::service::gateway::ProjectEntry; + + let mut projects = BTreeMap::new(); + projects.insert( + "alpha".into(), + ProjectEntry { + url: None, + auth_token: Some("alpha-tok".into()), + }, + ); + projects.insert( + "beta".into(), + ProjectEntry { + url: None, + auth_token: Some("beta-tok".into()), + }, + ); + let config = GatewayConfig { + projects, + sled_tokens: BTreeMap::new(), + }; + let state = Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap()); + + let alpha_conn = spawn_mock_sled(|_body| { + serde_json::json!({ + "content": [{ "type": "text", "text": "from-alpha" }] + }) + }); + let beta_conn = spawn_mock_sled(|_body| { + serde_json::json!({ + "content": [{ "type": "text", "text": "from-beta" }] + }) + }); + state + .register_sled_connection("alpha".to_string(), alpha_conn) + .await; + state + .register_sled_connection("beta".to_string(), beta_conn) + .await; + + // Switch to alpha. + gateway::switch_project(&state, "alpha").await.unwrap(); + let body = serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { "name": "noop", "arguments": {} } + })) + .unwrap(); + let resp = state.proxy_active_mcp(&body).await.expect("ws proxy works"); + let resp_json: serde_json::Value = serde_json::from_slice(&resp).unwrap(); + assert_eq!( + resp_json["result"]["content"][0]["text"], "from-alpha", + "When active project is alpha, calls must route to the alpha sled" + ); + + // Switch to beta. + gateway::switch_project(&state, "beta").await.unwrap(); + let resp = state.proxy_active_mcp(&body).await.expect("ws proxy works"); + let resp_json: serde_json::Value = serde_json::from_slice(&resp).unwrap(); + assert_eq!( + resp_json["result"]["content"][0]["text"], "from-beta", + "When active project is beta, calls must route to the beta sled" + ); +} diff --git a/server/src/http/gateway/mcp.rs b/server/src/http/gateway/mcp.rs index ee051f8f..a1604643 100644 --- a/server/src/http/gateway/mcp.rs +++ b/server/src/http/gateway/mcp.rs @@ -203,14 +203,11 @@ pub async fn gateway_mcp_post_handler( } /// Proxy a request to the active project and format the response. +/// +/// Prefers the live sled-uplink WebSocket when one is attached (story 899 +/// AC 2); falls back to the legacy HTTP proxy otherwise. async fn proxy_and_respond(state: &GatewayState, bytes: &[u8], id: Option) -> Response { - let url = match state.active_url().await { - Ok(u) => u, - Err(e) => { - return to_json_response(JsonRpcResponse::error(id, -32603, e.to_string())); - } - }; - match gateway::io::proxy_mcp_call(&state.client, &url, bytes).await { + match state.proxy_active_mcp(bytes).await { Ok(resp_body) => Response::builder() .status(StatusCode::OK) .header("Content-Type", "application/json") @@ -316,13 +313,23 @@ fn handle_initialize(id: Option) -> JsonRpcResponse { } /// Fetch tools/list from the active project and merge in gateway tools. +/// +/// Routes via the sled-uplink WS when one is attached (story 899 AC 2); +/// falls back to HTTP otherwise. async fn handle_tools_list( state: &GatewayState, id: Option, ) -> Result { - let url = state.active_url().await.map_err(|e| e.to_string())?; - - let resp_json = gateway::io::fetch_tools_list(&state.client, &url).await?; + let rpc_body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list", + "params": {} + }); + let bytes = serde_json::to_vec(&rpc_body).map_err(|e| e.to_string())?; + let resp_bytes = state.proxy_active_mcp(&bytes).await?; + let resp_json: Value = + serde_json::from_slice(&resp_bytes).map_err(|e| format!("invalid tools/list JSON: {e}"))?; let mut tools: Vec = resp_json .get("result") @@ -414,21 +421,36 @@ async fn handle_gateway_status_tool(state: &GatewayState, id: Option) -> async fn handle_gateway_health_tool(state: &GatewayState, id: Option) -> JsonRpcResponse { let mut results = BTreeMap::new(); - let project_entries: Vec<(String, String)> = state + // Build the project list, preferring the WS-uplink heartbeat as the + // source of truth for liveness (story 899 AC 3). HTTP polls are used + // only as a fallback when no live sled is connected. + let project_names: Vec<(String, Option)> = state .projects .read() .await .iter() .map(|(n, e)| (n.clone(), e.url.clone())) .collect(); - for (name, url) in &project_entries { - let status = match gateway::io::check_project_health(&state.client, url).await { - Ok(true) => "healthy".to_string(), - Ok(false) => "unhealthy".to_string(), - Err(e) => e, + let sled_conns = state.sled_connections.read().await; + for (name, url_opt) in &project_names { + let status = if let Some(conn) = sled_conns.get(name) { + if conn.is_alive(crate::service::gateway::HEARTBEAT_MAX_AGE_MS) { + "healthy (ws)".to_string() + } else { + "stale (ws heartbeat overdue)".to_string() + } + } else if let Some(url) = url_opt { + match gateway::io::check_project_health(&state.client, url).await { + Ok(true) => "healthy".to_string(), + Ok(false) => "unhealthy".to_string(), + Err(e) => e, + } + } else { + "no uplink and no url configured".to_string() }; results.insert(name.clone(), status); } + drop(sled_conns); let active = state.active_project.read().await.clone(); JsonRpcResponse::success( @@ -512,7 +534,7 @@ async fn handle_aggregate_pipeline_status_tool( .read() .await .iter() - .map(|(name, entry)| (name.clone(), entry.url.clone())) + .filter_map(|(name, entry)| entry.url.as_ref().map(|u| (name.clone(), u.clone()))) .collect(); let statuses = @@ -656,7 +678,7 @@ async fn handle_pipeline_get(state: &GatewayState, id: Option) -> JsonRpc .read() .await .iter() - .map(|(n, e)| (n.clone(), e.url.clone())) + .filter_map(|(n, e)| e.url.as_ref().map(|u| (n.clone(), u.clone()))) .collect(); let results = gateway::io::fetch_all_project_pipeline_items(&project_urls, &state.client).await; diff --git a/server/src/http/gateway/websocket.rs b/server/src/http/gateway/websocket.rs index f8074407..9302a700 100644 --- a/server/src/http/gateway/websocket.rs +++ b/server/src/http/gateway/websocket.rs @@ -155,21 +155,30 @@ struct SledUplinkParams { token: Option, } -/// `GET /api/sled-uplink` — gateway-side WebSocket endpoint for sled permission uplinks. +/// `GET /api/sled-uplink` — gateway-side WebSocket endpoint for sled uplinks. /// /// # Authentication /// /// The connecting sled must supply a valid shared-secret token via the `token` -/// query parameter. Tokens are configured in `[sled_tokens]` in `projects.toml` -/// as `sled_id = "secret"` entries. +/// query parameter. Tokens are configured either as per-project `auth_token` +/// fields under `[projects.]` in `projects.toml` (preferred, story 899) +/// or, for backwards compatibility, in the deprecated `[sled_tokens]` table. /// /// # Protocol /// -/// See `sled_uplink.rs` for the wire format ([`UplinkEnvelope`]). The gateway -/// accepts `perm_request` messages, injects them into the local permission -/// pipeline (via `state.perm_tx`), and sends `perm_response` frames back to the -/// sled once the Matrix bot resolves them. Multiple sleds are demuxed by -/// connection: each handler owns exactly one sled's request/response flow. +/// See `sled_uplink.rs` for the wire format ([`UplinkEnvelope`]). +/// +/// Phase 2 (story 899) expands the protocol from the original perm-only flow +/// to a full bidirectional MCP transport: +/// +/// - **sled → gateway:** `identity` (mandatory first frame after upgrade), +/// `heartbeat`, `perm_request`, `mcp_response`. +/// - **gateway → sled:** `mcp_request`, `perm_response`. +/// +/// On `identity`, the connection is published to +/// [`GatewayState::sled_connections`] under the project name, allowing the +/// MCP proxy (`mcp.rs::proxy_and_respond`) to route subsequent calls over +/// the live WS instead of HTTP. The entry is removed on disconnect. #[handler] pub async fn gateway_sled_uplink_handler( ws: WebSocket, @@ -196,82 +205,211 @@ pub async fn gateway_sled_uplink_handler( use poem::IntoResponse as _; let perm_tx = state.perm_tx.clone(); + let state = Arc::clone(&state); ws.on_upgrade(move |socket| async move { - let (mut sink, mut stream) = socket.split(); - // Aggregator channel: spawned per-request tasks send (req_id, decision) here - // so the main loop can write perm_response frames back to the sled. - let (agg_tx, mut agg_rx) = tokio::sync::mpsc::unbounded_channel::<( - String, - crate::http::context::PermissionDecision, - )>(); + run_sled_uplink_session(socket, state, sled_id, perm_tx).await; + }) + .into_response() +} - crate::slog!("[gateway/sled-uplink] Sled '{}' connected", sled_id); +/// Run a single connected sled's request/response flow until it disconnects. +/// +/// Performs the identity handshake, registers a [`SledConnection`] in +/// [`GatewayState::sled_connections`] under the published project name, and +/// pumps messages bidirectionally for the lifetime of the WebSocket. +async fn run_sled_uplink_session( + socket: poem::web::websocket::WebSocketStream, + state: Arc, + token_sled_id: String, + perm_tx: tokio::sync::mpsc::UnboundedSender, +) { + use crate::service::gateway::SledConnection; + use crate::sled_uplink::UplinkEnvelope; + use std::sync::Arc as StdArc; + use std::sync::atomic::AtomicI64; - loop { - tokio::select! { - msg = stream.next() => { - let text = match msg { - Some(Ok(WsMessage::Text(t))) => t, - Some(Ok(WsMessage::Close(_))) | None => break, - _ => continue, - }; - let Ok(env) = serde_json::from_str::(&text) else { - continue; - }; - if env.msg_type == "perm_request" { - let req_id = env.req_id.clone(); - let tool_name = env.payload.get("tool_name") - .and_then(|v| v.as_str()) - .unwrap_or("unknown") - .to_string(); - let tool_input = env.payload.get("tool_input") - .cloned() - .unwrap_or(serde_json::Value::Null); - let (response_tx, response_rx) = tokio::sync::oneshot::channel(); - let fwd = crate::http::context::PermissionForward { - request_id: format!("{sled_id}:{req_id}"), - tool_name, - tool_input, - response_tx, - }; - if perm_tx.send(fwd).is_err() { - break; - } - let agg_tx2 = agg_tx.clone(); - tokio::spawn(async move { - let decision = response_rx - .await - .unwrap_or(crate::http::context::PermissionDecision::Deny); - let _ = agg_tx2.send((req_id, decision)); - }); - } + let (mut sink, mut stream) = socket.split(); + + // ── Identity handshake: expect the first frame to be `identity` ────── + let identity = match tokio::time::timeout( + std::time::Duration::from_secs(10), + wait_for_identity(&mut stream), + ) + .await + { + Ok(Some(p)) => p, + _ => { + crate::slog!( + "[gateway/sled-uplink] '{}' missing identity frame; closing", + token_sled_id + ); + return; + } + }; + + // Project name in the identity must match the project resolved from the + // auth token; otherwise the sled is claiming to be a different project. + if identity != token_sled_id { + crate::slog!( + "[gateway/sled-uplink] identity mismatch (token says '{}', sled claims '{}'); closing", + token_sled_id, + identity + ); + return; + } + + // ── Build SledConnection and publish it ───────────────────────────── + let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel::(); + let conn = SledConnection { + tx: out_tx, + last_heartbeat_ms: StdArc::new(AtomicI64::new(chrono::Utc::now().timestamp_millis())), + in_flight: StdArc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())), + }; + state + .register_sled_connection(identity.clone(), conn.clone()) + .await; + crate::slog!( + "[gateway/sled-uplink] Sled '{}' connected and registered", + identity + ); + + // Aggregator channel for perm responses produced by spawned per-request tasks. + let (agg_tx, mut agg_rx) = + tokio::sync::mpsc::unbounded_channel::<(String, crate::http::context::PermissionDecision)>( + ); + + loop { + tokio::select! { + // Outbound: forward queued envelopes (mcp_request, etc.) to the sled. + Some(env) = out_rx.recv() => { + let Ok(text) = serde_json::to_string(&env) else { continue }; + if sink.send(WsMessage::Text(text)).await.is_err() { + break; } - Some((req_id, decision)) = agg_rx.recv() => { - use crate::http::context::PermissionDecision; - let (approved, always_allow) = match decision { - PermissionDecision::AlwaysAllow => (true, true), - PermissionDecision::Approve => (true, false), - PermissionDecision::Deny => (false, false), - }; - let resp = crate::sled_uplink::UplinkEnvelope { - msg_type: "perm_response".to_string(), - req_id, - payload: serde_json::json!({ - "approved": approved, - "always_allow": always_allow, - }), - }; - let Ok(text) = serde_json::to_string(&resp) else { continue }; - if sink.send(WsMessage::Text(text)).await.is_err() { - break; + } + // Inbound: messages from the sled. + msg = stream.next() => { + let text = match msg { + Some(Ok(WsMessage::Text(t))) => t, + Some(Ok(WsMessage::Close(_))) | None => break, + _ => continue, + }; + let Ok(env) = serde_json::from_str::(&text) else { + continue; + }; + match env.msg_type.as_str() { + "heartbeat" => { + conn.last_heartbeat_ms.store( + chrono::Utc::now().timestamp_millis(), + std::sync::atomic::Ordering::Relaxed, + ); } + "mcp_response" => { + let tx_opt = conn.in_flight.lock().await.remove(&env.req_id); + if let Some(tx) = tx_opt { + let _ = tx.send(env.payload); + } + } + "perm_request" => { + forward_perm_request(env, &perm_tx, &agg_tx, &identity); + } + _ => {} + } + } + // Aggregator: per-request permission resolutions get formatted as perm_response. + Some((req_id, decision)) = agg_rx.recv() => { + use crate::http::context::PermissionDecision; + let (approved, always_allow) = match decision { + PermissionDecision::AlwaysAllow => (true, true), + PermissionDecision::Approve => (true, false), + PermissionDecision::Deny => (false, false), + }; + let resp = UplinkEnvelope { + msg_type: "perm_response".to_string(), + req_id, + payload: serde_json::json!({ + "approved": approved, + "always_allow": always_allow, + }), + }; + let Ok(text) = serde_json::to_string(&resp) else { continue }; + if sink.send(WsMessage::Text(text)).await.is_err() { + break; } } } + } - crate::slog!("[gateway/sled-uplink] Sled '{}' disconnected", sled_id); - }) - .into_response() + state.deregister_sled_connection(&identity).await; + crate::slog!("[gateway/sled-uplink] Sled '{}' disconnected", identity); +} + +/// Read frames until an `identity` envelope is seen; return the project name +/// from its payload. Returns `None` if the stream ends or an invalid frame +/// arrives where the first frame is expected to be `identity`. +async fn wait_for_identity( + stream: &mut futures::stream::SplitStream, +) -> Option { + while let Some(msg) = stream.next().await { + let text = match msg { + Ok(WsMessage::Text(t)) => t, + Ok(WsMessage::Close(_)) | Err(_) => return None, + _ => continue, + }; + let Ok(env) = serde_json::from_str::(&text) else { + return None; + }; + if env.msg_type != "identity" { + return None; + } + return env + .payload + .get("project") + .and_then(|v| v.as_str()) + .map(str::to_string); + } + None +} + +/// Convert an inbound `perm_request` envelope into a [`PermissionForward`] and +/// inject it into the gateway's permission pipeline. The spawned waiter task +/// publishes the resolved decision into the aggregator channel so the WS +/// writer can emit a matching `perm_response` frame. +fn forward_perm_request( + env: crate::sled_uplink::UplinkEnvelope, + perm_tx: &tokio::sync::mpsc::UnboundedSender, + agg_tx: &tokio::sync::mpsc::UnboundedSender<(String, crate::http::context::PermissionDecision)>, + sled_id: &str, +) { + let req_id = env.req_id.clone(); + let tool_name = env + .payload + .get("tool_name") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + let tool_input = env + .payload + .get("tool_input") + .cloned() + .unwrap_or(serde_json::Value::Null); + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let fwd = crate::http::context::PermissionForward { + request_id: format!("{sled_id}:{req_id}"), + tool_name, + tool_input, + response_tx, + }; + if perm_tx.send(fwd).is_err() { + return; + } + let agg_tx2 = agg_tx.clone(); + tokio::spawn(async move { + let decision = response_rx + .await + .unwrap_or(crate::http::context::PermissionDecision::Deny); + let _ = agg_tx2.send((req_id, decision)); + }); } // ── Event-push WebSocket handler ───────────────────────────────────────────── diff --git a/server/src/main.rs b/server/src/main.rs index 2aae0cd1..22669402 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -241,7 +241,25 @@ async fn main() -> Result<(), std::io::Error> { .clone() .or_else(|| std::env::var("HUSKIES_UPSTREAM_GATEWAY").ok()) .unwrap_or_default(); - sled_uplink::spawn_uplink_task(upstream_gateway, Arc::clone(&services)); + // Project name for the identity frame (story 899 AC 5). Env-var override + // wins; otherwise derive from the project_root basename. + let project_name = std::env::var("HUSKIES_PROJECT_NAME").unwrap_or_else(|_| { + services + .project_root + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown") + .to_string() + }); + let local_mcp_url = format!("http://127.0.0.1:{port}/mcp"); + sled_uplink::spawn_uplink_task( + sled_uplink::UplinkConfig { + upstream_url: upstream_gateway, + project_name, + local_mcp_url, + }, + Arc::clone(&services), + ); // ── Build bot contexts (WhatsApp / Slack / Discord) ─────────────────────── let (bot_ctxs, matrix_shutdown_rx) = diff --git a/server/src/service/gateway/config.rs b/server/src/service/gateway/config.rs index 1d7866a7..996e68b9 100644 --- a/server/src/service/gateway/config.rs +++ b/server/src/service/gateway/config.rs @@ -7,20 +7,56 @@ use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; /// A single project entry in `projects.toml`. +/// +/// Phase 2 (story 899): `url` is now optional — a project served exclusively +/// via the sled-uplink WebSocket does not need an HTTP base URL. The `url` +/// field is deprecated for removal in a future release; configure +/// `auth_token` instead and rely on the WS uplink for all traffic. #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ProjectEntry { /// Base URL of the project's huskies container (e.g. `http://localhost:3001`). - pub url: String, + /// + /// **Deprecated** (story 899) — when a sled connects via the uplink WS the + /// gateway routes all MCP traffic over that connection instead. The URL is + /// used as a fallback when no live uplink exists. Omit for WS-only projects. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub url: Option, + /// Shared-secret token used to authenticate this project's sled when it + /// connects to `/api/sled-uplink`. Takes precedence over the top-level + /// `[sled_tokens]` table for projects that set this field. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub auth_token: Option, +} + +impl ProjectEntry { + /// Convenience constructor for entries that only have a URL (e.g. in tests + /// and existing `projects.toml` files that have not yet been migrated to + /// the WS-uplink model). + pub fn with_url(url: impl Into) -> Self { + Self { + url: Some(url.into()), + auth_token: None, + } + } + + /// Returns `true` if this entry has a configured HTTP base URL. + pub fn has_url(&self) -> bool { + self.url.as_ref().is_some_and(|u| !u.is_empty()) + } } /// Top-level `projects.toml` config. #[derive(Debug, Clone, Deserialize, Serialize)] pub struct GatewayConfig { - /// Map of project name → container URL. + /// Map of project name → container configuration. #[serde(default)] pub projects: BTreeMap, /// Map of sled_id → shared secret token for sled-uplink authentication. /// + /// **Deprecated** (story 899) — move tokens into per-project + /// `auth_token` fields instead. The gateway still honours entries here for + /// one release to provide a smooth migration window. + /// /// Each entry allows a sled identified by `sled_id` to connect to /// `/api/sled-uplink` using the given secret token as a bearer credential. #[serde(default)] @@ -40,18 +76,22 @@ pub fn validate_config(config: &GatewayConfig) -> Result { /// Validate that a project name exists in the given project map. /// -/// Returns the project's URL on success. +/// Returns the project's URL (may be empty for WS-uplink-only projects) on +/// success. pub fn validate_project_exists( projects: &BTreeMap, name: &str, ) -> Result { - projects.get(name).map(|p| p.url.clone()).ok_or_else(|| { - let available: Vec<&str> = projects.keys().map(|s| s.as_str()).collect(); - format!( - "unknown project '{name}'. Available: {}", - available.join(", ") - ) - }) + projects + .get(name) + .map(|p| p.url.clone().unwrap_or_default()) + .ok_or_else(|| { + let available: Vec<&str> = projects.keys().map(|s| s.as_str()).collect(); + format!( + "unknown project '{name}'. Available: {}", + available.join(", ") + ) + }) } /// Escape a string as a TOML quoted string. @@ -104,8 +144,29 @@ url = "http://localhost:3002" "#; let config: GatewayConfig = toml::from_str(toml_str).unwrap(); assert_eq!(config.projects.len(), 2); - assert_eq!(config.projects["huskies"].url, "http://localhost:3001"); - assert_eq!(config.projects["robot-studio"].url, "http://localhost:3002"); + assert_eq!( + config.projects["huskies"].url.as_deref(), + Some("http://localhost:3001") + ); + assert_eq!( + config.projects["robot-studio"].url.as_deref(), + Some("http://localhost:3002") + ); + } + + #[test] + fn parse_project_without_url_is_valid() { + let toml_str = r#" +[projects.ws-only] +auth_token = "secret" +"#; + let config: GatewayConfig = toml::from_str(toml_str).unwrap(); + assert_eq!(config.projects.len(), 1); + assert!(config.projects["ws-only"].url.is_none()); + assert_eq!( + config.projects["ws-only"].auth_token.as_deref(), + Some("secret") + ); } #[test] @@ -127,18 +188,8 @@ url = "http://localhost:3002" #[test] fn validate_config_returns_first_project_name() { let mut projects = BTreeMap::new(); - projects.insert( - "beta".into(), - ProjectEntry { - url: "http://b".into(), - }, - ); - projects.insert( - "alpha".into(), - ProjectEntry { - url: "http://a".into(), - }, - ); + projects.insert("beta".into(), ProjectEntry::with_url("http://b")); + projects.insert("alpha".into(), ProjectEntry::with_url("http://a")); let config = GatewayConfig { projects, sled_tokens: BTreeMap::new(), @@ -147,14 +198,26 @@ url = "http://localhost:3002" } #[test] - fn validate_project_exists_succeeds() { + fn validate_config_accepts_ws_only_project() { let mut projects = BTreeMap::new(); projects.insert( - "p1".into(), + "ws-only".into(), ProjectEntry { - url: "http://p1".into(), + url: None, + auth_token: Some("secret".into()), }, ); + let config = GatewayConfig { + projects, + sled_tokens: BTreeMap::new(), + }; + assert!(validate_config(&config).is_ok()); + } + + #[test] + fn validate_project_exists_succeeds() { + let mut projects = BTreeMap::new(); + projects.insert("p1".into(), ProjectEntry::with_url("http://p1")); assert_eq!( validate_project_exists(&projects, "p1").unwrap(), "http://p1" @@ -167,6 +230,36 @@ url = "http://localhost:3002" assert!(validate_project_exists(&projects, "missing").is_err()); } + #[test] + fn validate_project_exists_ws_only_returns_empty_url() { + let mut projects = BTreeMap::new(); + projects.insert( + "ws".into(), + ProjectEntry { + url: None, + auth_token: Some("tok".into()), + }, + ); + assert_eq!(validate_project_exists(&projects, "ws").unwrap(), ""); + } + + #[test] + fn project_entry_with_url_constructor() { + let e = ProjectEntry::with_url("http://example.com"); + assert_eq!(e.url.as_deref(), Some("http://example.com")); + assert!(e.auth_token.is_none()); + assert!(e.has_url()); + } + + #[test] + fn project_entry_has_url_false_when_none() { + let e = ProjectEntry { + url: None, + auth_token: Some("tok".into()), + }; + assert!(!e.has_url()); + } + #[test] fn toml_string_escapes_quotes() { assert_eq!(toml_string(r#"a"b"#), r#""a\"b""#); @@ -198,4 +291,28 @@ url = "http://localhost:3002" assert!(content.contains("transport = \"slack\"")); assert!(content.contains("slack_bot_token = \"xoxb-123\"")); } + + #[test] + fn roundtrip_project_entry_with_auth_token() { + let entry = ProjectEntry { + url: Some("http://a:3001".into()), + auth_token: Some("mysecret".into()), + }; + let mut projects = BTreeMap::new(); + projects.insert("myproj".into(), entry); + let config = GatewayConfig { + projects, + sled_tokens: BTreeMap::new(), + }; + let toml_str = toml::to_string_pretty(&config).unwrap(); + let parsed: GatewayConfig = toml::from_str(&toml_str).unwrap(); + assert_eq!( + parsed.projects["myproj"].url.as_deref(), + Some("http://a:3001") + ); + assert_eq!( + parsed.projects["myproj"].auth_token.as_deref(), + Some("mysecret") + ); + } } diff --git a/server/src/service/gateway/io.rs b/server/src/service/gateway/io.rs index 85c6e9af..f05ae384 100644 --- a/server/src/service/gateway/io.rs +++ b/server/src/service/gateway/io.rs @@ -140,29 +140,6 @@ pub async fn proxy_mcp_call_sse( .map_err(|e| format!("failed to reach {mcp_url}: {e}")) } -/// Fetch tools/list from a project's MCP endpoint. -pub async fn fetch_tools_list(client: &Client, base_url: &str) -> Result { - let mcp_url = format!("{}/mcp", base_url.trim_end_matches('/')); - - let rpc_body = json!({ - "jsonrpc": "2.0", - "id": 1, - "method": "tools/list", - "params": {} - }); - - let resp = client - .post(&mcp_url) - .json(&rpc_body) - .send() - .await - .map_err(|e| format!("failed to reach {mcp_url}: {e}"))?; - - resp.json() - .await - .map_err(|e| format!("invalid JSON from upstream: {e}")) -} - /// Fetch and aggregate pipeline status for a single project URL. pub async fn fetch_one_project_pipeline_status(url: &str, client: &Client) -> Value { let mcp_url = format!("{}/mcp", url.trim_end_matches('/')); diff --git a/server/src/service/gateway/mod.rs b/server/src/service/gateway/mod.rs index 631c9222..feebcf6c 100644 --- a/server/src/service/gateway/mod.rs +++ b/server/src/service/gateway/mod.rs @@ -28,6 +28,7 @@ use io::Client; use std::collections::{BTreeMap, HashMap}; use std::path::PathBuf; use std::sync::Arc; +use std::sync::atomic::{AtomicI64, Ordering}; use tokio::sync::Mutex as TokioMutex; use tokio::sync::RwLock; use tokio::sync::mpsc; @@ -49,6 +50,95 @@ pub struct GatewayStatusEvent { pub event: crate::service::events::StoredEvent, } +// ── Sled connection ───────────────────────────────────────────────────────── + +/// Maximum age, in milliseconds, of a sled heartbeat before the gateway +/// considers the connection stale (story 899 AC 3). +pub const HEARTBEAT_MAX_AGE_MS: i64 = 30_000; + +/// Default per-request timeout, in milliseconds, for an MCP call proxied over +/// a sled uplink WebSocket. Mirrors the existing reqwest-based path which has +/// no explicit cap; we set a generous bound so long-running tools (e.g. +/// `run_tests`) still complete. +pub const MCP_VIA_WS_TIMEOUT_MS: u64 = 1_200_000; + +/// Handle to a sled currently connected to the gateway via the uplink WebSocket. +/// +/// Created by the `/api/sled-uplink` WS handler on connect and stored in +/// [`GatewayState::sled_connections`]. The gateway's MCP proxy reads this +/// to forward requests over the live connection rather than via HTTP. +#[derive(Clone)] +pub struct SledConnection { + /// Sender side of the channel the WS handler reads to forward outgoing + /// frames (e.g. `mcp_request`) to the sled. + pub tx: mpsc::UnboundedSender, + /// Timestamp (ms since Unix epoch) of the last `heartbeat` frame received + /// from this sled. Updated atomically by the WS handler task. + pub last_heartbeat_ms: Arc, + /// In-flight MCP requests waiting for a matching `mcp_response` from this + /// sled. Keyed by `req_id`; the oneshot sender is resolved when the + /// response arrives. + pub in_flight: + Arc>>>, +} + +impl SledConnection { + /// Returns `true` if a heartbeat has been received within the last `max_age_ms` + /// milliseconds. + pub fn is_alive(&self, max_age_ms: i64) -> bool { + let last = self.last_heartbeat_ms.load(Ordering::Relaxed); + let now = chrono::Utc::now().timestamp_millis(); + now - last <= max_age_ms + } +} + +/// Proxy a raw MCP request body to a sled over its uplink WebSocket and +/// return the serialised JSON response bytes. +/// +/// Generates a fresh correlation id, registers a oneshot in the connection's +/// in-flight map, sends an `mcp_request` envelope, and waits for the +/// matching `mcp_response` (or [`MCP_VIA_WS_TIMEOUT_MS`] to elapse). +/// +/// The response payload is serialised back into JSON bytes so callers can +/// return it directly to the HTTP client unchanged. +pub async fn proxy_mcp_via_ws( + conn: &SledConnection, + request_bytes: &[u8], +) -> Result, String> { + let req_id = uuid::Uuid::new_v4().to_string(); + let body_str = std::str::from_utf8(request_bytes) + .map_err(|e| format!("non-utf8 mcp request body: {e}"))? + .to_string(); + + let (tx, rx) = tokio::sync::oneshot::channel(); + conn.in_flight.lock().await.insert(req_id.clone(), tx); + + let env = crate::sled_uplink::UplinkEnvelope { + msg_type: "mcp_request".to_string(), + req_id: req_id.clone(), + payload: serde_json::json!({ "body": body_str }), + }; + + if conn.tx.send(env).is_err() { + conn.in_flight.lock().await.remove(&req_id); + return Err("sled uplink connection closed".to_string()); + } + + let timeout = std::time::Duration::from_millis(MCP_VIA_WS_TIMEOUT_MS); + match tokio::time::timeout(timeout, rx).await { + Ok(Ok(response_value)) => { + serde_json::to_vec(&response_value).map_err(|e| format!("serialise mcp_response: {e}")) + } + Ok(Err(_)) => Err("sled response channel dropped".to_string()), + Err(_) => { + conn.in_flight.lock().await.remove(&req_id); + Err(format!( + "mcp call to sled timed out after {MCP_VIA_WS_TIMEOUT_MS} ms" + )) + } + } +} + // ── Error type ────────────────────────────────────────────────────────────── /// Typed errors returned by `service::gateway` functions. @@ -135,11 +225,18 @@ pub struct GatewayState { /// The Matrix bot's `permission_listener` holds this locked for its lifetime; /// the sled-uplink WS handler sends requests via `perm_tx`. pub perm_rx: Arc>>, - /// Reversed sled-token map: token → sled_id. + /// Reversed sled-token map: token → project_name (sled_id). /// - /// Built at startup from [`GatewayConfig::sled_tokens`] (which maps - /// sled_id → token). The handler looks up incoming tokens in O(1). + /// Built at startup from both [`GatewayConfig::sled_tokens`] AND the + /// per-project `auth_token` field (story 899). The handler looks up + /// incoming tokens in O(1) to identify the project the sled represents. pub sled_tokens: HashMap, + /// Live sled connections keyed by project name. + /// + /// Populated by the `/api/sled-uplink` WS handler when a sled authenticates + /// and depopulated when it disconnects. MCP proxy functions check here + /// first (WS route), falling back to HTTP when no live connection exists. + pub sled_connections: Arc>>, } impl GatewayState { @@ -160,11 +257,21 @@ impl GatewayState { .unwrap_or(first_from_config); let (event_tx, _) = tokio::sync::broadcast::channel(EVENT_CHANNEL_CAPACITY); let (perm_tx, perm_rx) = mpsc::unbounded_channel::(); - let sled_tokens: HashMap = gateway_config + + // Build token→project_name map from two sources: + // 1. Legacy top-level [sled_tokens] section (sled_id → token, reversed) + // 2. Per-project auth_token fields (project_name → token, reversed) + let mut sled_tokens: HashMap = gateway_config .sled_tokens .iter() .map(|(sled_id, token)| (token.clone(), sled_id.clone())) .collect(); + for (project_name, entry) in &gateway_config.projects { + if let Some(ref token) = entry.auth_token { + sled_tokens.insert(token.clone(), project_name.clone()); + } + } + Ok(Self { projects: Arc::new(RwLock::new(gateway_config.projects)), active_project: Arc::new(RwLock::new(first)), @@ -178,26 +285,78 @@ impl GatewayState { perm_tx, perm_rx: Arc::new(TokioMutex::new(perm_rx)), sled_tokens, + sled_connections: Arc::new(RwLock::new(HashMap::new())), }) } - /// Get the URL of the currently active project. + /// Get the URL of the currently active project, if one is configured. + /// + /// Returns `Err` when the active project has no URL configured (WS-uplink + /// only) or the project name is not found. pub async fn active_url(&self) -> Result { let name = self.active_project.read().await.clone(); self.projects .read() .await .get(&name) - .map(|p| p.url.clone()) + .and_then(|p| p.url.clone()) .ok_or_else(|| { - Error::ProjectNotFound(format!("active project '{name}' not found in config")) + Error::ProjectNotFound(format!( + "active project '{name}' has no URL configured \ + (use sled-uplink WS or add url to projects.toml)" + )) }) } + + /// Register a live sled connection for the given project. + pub async fn register_sled_connection(&self, project_name: String, conn: SledConnection) { + self.sled_connections + .write() + .await + .insert(project_name, conn); + } + + /// Remove the sled connection for the given project (on disconnect). + pub async fn deregister_sled_connection(&self, project_name: &str) { + self.sled_connections.write().await.remove(project_name); + } + + /// Look up the live sled connection for the active project, returning a + /// clone if one exists and has a recent heartbeat. + /// + /// Returns `None` when no sled has connected for this project or when its + /// heartbeat is overdue. + pub async fn active_sled_connection(&self) -> Option { + let name = self.active_project.read().await.clone(); + let conn = self.sled_connections.read().await.get(&name).cloned()?; + if conn.is_alive(HEARTBEAT_MAX_AGE_MS) { + Some(conn) + } else { + None + } + } + + /// Proxy an MCP request to the active project, preferring the live + /// sled-uplink WebSocket when available (story 899 AC 2) and falling + /// back to HTTP otherwise. + /// + /// Returns the raw response body bytes ready to be relayed to the caller. + pub async fn proxy_active_mcp(&self, bytes: &[u8]) -> Result, String> { + if let Some(conn) = self.active_sled_connection().await { + return proxy_mcp_via_ws(&conn, bytes).await; + } + let url = self.active_url().await.map_err(|e| e.to_string())?; + crate::slog!( + "[gateway] MCP proxy: WS uplink unavailable, falling back to HTTP \ + (deprecated, will be removed once all sleds are WS-only)" + ); + crate::service::gateway::io::proxy_mcp_call(&self.client, &url, bytes).await + } } // ── Public API ────────────────────────────────────────────────────────────── -/// Switch the active project. Returns the project's URL on success. +/// Switch the active project. Returns the project's URL (empty for WS-only projects). /// /// Writes the new active project to the CRDT `gateway_config.active_project` /// register (LWW — last write wins) so the selection is persisted across @@ -358,7 +517,7 @@ pub async fn add_project(state: &GatewayState, name: &str, url: &str) -> Result< "project '{name}' already exists" ))); } - projects.insert(name.clone(), ProjectEntry { url: url.clone() }); + projects.insert(name.clone(), ProjectEntry::with_url(&url)); } let snapshot = state.projects.read().await.clone(); @@ -441,7 +600,7 @@ pub async fn init_project( "project '{n}' is already registered. Choose a different name or use switch_project." ))); } - projects.insert(n.to_string(), ProjectEntry { url: u.to_string() }); + projects.insert(n.to_string(), ProjectEntry::with_url(u)); io::save_config(&projects, &state.config_dir).await; crate::slog!("[gateway] init_project: registered '{n}' ({u})"); Some(n.to_string()) @@ -494,7 +653,7 @@ pub async fn save_bot_config_and_restart(state: &GatewayState, content: &str) -> .read() .await .iter() - .map(|(name, entry)| (name.clone(), entry.url.clone())) + .filter_map(|(name, entry)| entry.url.as_ref().map(|u| (name.clone(), u.clone()))) .collect(); let (new_handle, new_shutdown_tx) = io::spawn_gateway_bot( @@ -523,12 +682,7 @@ mod tests { fn make_config(names: &[(&str, &str)]) -> GatewayConfig { let mut projects = BTreeMap::new(); for (name, url) in names { - projects.insert( - name.to_string(), - ProjectEntry { - url: url.to_string(), - }, - ); + projects.insert(name.to_string(), ProjectEntry::with_url(*url)); } GatewayConfig { projects, @@ -584,6 +738,24 @@ mod tests { assert_eq!(url, "http://my:3001"); } + #[tokio::test] + async fn active_url_fails_for_ws_only_project() { + let mut projects = BTreeMap::new(); + projects.insert( + "ws-proj".into(), + ProjectEntry { + url: None, + auth_token: Some("tok".into()), + }, + ); + let config = GatewayConfig { + projects, + sled_tokens: BTreeMap::new(), + }; + let state = GatewayState::new(config, PathBuf::from("."), 3000).unwrap(); + assert!(state.active_url().await.is_err()); + } + #[test] fn error_display_variants() { assert!( @@ -682,4 +854,47 @@ mod tests { let result = init_project(&state, dir.path().to_str().unwrap(), None, None).await; assert!(result.is_err()); } + + #[tokio::test] + async fn sled_connection_registration_and_lookup() { + let config = make_config(&[("myproj", "http://myproj:3001")]); + let state = GatewayState::new(config, PathBuf::new(), 3000).unwrap(); + + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = SledConnection { + tx, + last_heartbeat_ms: Arc::new(AtomicI64::new(chrono::Utc::now().timestamp_millis())), + in_flight: Arc::new(TokioMutex::new(HashMap::new())), + }; + + state + .register_sled_connection("myproj".to_string(), conn) + .await; + assert!(state.sled_connections.read().await.contains_key("myproj")); + + state.deregister_sled_connection("myproj").await; + assert!(!state.sled_connections.read().await.contains_key("myproj")); + } + + #[tokio::test] + async fn auth_token_in_project_entry_populates_sled_tokens_map() { + let mut projects = BTreeMap::new(); + projects.insert( + "huskies".into(), + ProjectEntry { + url: Some("http://huskies:3001".into()), + auth_token: Some("secret-token".into()), + }, + ); + let config = GatewayConfig { + projects, + sled_tokens: BTreeMap::new(), + }; + let state = GatewayState::new(config, PathBuf::new(), 3000).unwrap(); + assert_eq!( + state.sled_tokens.get("secret-token").map(|s| s.as_str()), + Some("huskies"), + "Per-project auth_token must be in reversed sled_tokens map" + ); + } } diff --git a/server/src/sled_uplink.rs b/server/src/sled_uplink.rs index 1629df65..163efadf 100644 --- a/server/src/sled_uplink.rs +++ b/server/src/sled_uplink.rs @@ -1,6 +1,5 @@ //! Sled uplink — background task that maintains a WebSocket connection from a -//! sled (standard huskies instance) to an upstream gateway for permission -//! request forwarding. +//! sled (standard huskies instance) to an upstream gateway. //! //! When `HUSKIES_UPSTREAM_GATEWAY` is set (or `--upstream-gateway` is passed //! on the CLI), this module spawns a task that: @@ -10,10 +9,17 @@ //! auto-denying requests with "no interactive session". //! 2. Maintains a persistent WebSocket connection to the gateway's //! `/api/sled-uplink` endpoint. -//! 3. Forwards each [`PermissionForward`] as a `perm_request` JSON envelope. -//! 4. Awaits the matching `perm_response` envelope from the gateway. -//! 5. Reconnects with exponential back-off on connection drop, fail-closing -//! any in-flight requests with [`PermissionDecision::Deny`]. +//! 3. Sends an `identity` frame announcing the sled's project name + auth +//! token immediately after connect (story 899 Phase 2). +//! 4. Sends a `heartbeat` frame every [`HEARTBEAT_INTERVAL_SECS`] so the +//! gateway can mark the connection live without extra HTTP polls. +//! 5. Forwards each [`PermissionForward`] as a `perm_request` JSON envelope +//! and awaits the matching `perm_response`. +//! 6. Handles inbound `mcp_request` frames by replaying the MCP JSON-RPC +//! body against the sled's own local `/mcp` HTTP endpoint and returning +//! the response as an `mcp_response` frame. +//! 7. Reconnects with exponential back-off on connection drop, fail-closing +//! any in-flight permission requests with [`PermissionDecision::Deny`]. use crate::http::context::{PermissionDecision, PermissionForward}; use crate::services::Services; @@ -25,12 +31,35 @@ use std::sync::Arc; use tokio::sync::oneshot; use tokio_tungstenite::tungstenite::Message as WsMessage; +/// Configuration for [`spawn_uplink_task`]. +/// +/// Bundled into a single struct so `main.rs` can build it once and pass it in +/// without proliferating positional arguments. +pub struct UplinkConfig { + /// WebSocket URL of the upstream gateway's `/api/sled-uplink` endpoint. + /// Includes the `?token=` query parameter for auth. + pub upstream_url: String, + /// Project name this sled identifies as. Sent in the `identity` frame + /// after WS connect (story 899 AC 5). + pub project_name: String, + /// HTTP base URL of this sled's own MCP endpoint (e.g. + /// `http://127.0.0.1:3001/mcp`). Used to replay `mcp_request` frames + /// received from the gateway against the local MCP handler. + pub local_mcp_url: String, +} + // ── Back-off constants ──────────────────────────────────────────────────────── const INITIAL_BACKOFF_SECS: u64 = 1; const MAX_BACKOFF_SECS: u64 = 60; const BACKOFF_MULTIPLIER: u64 = 2; +/// Interval between `heartbeat` frames sent to the gateway (story 899 AC 3). +/// +/// Must be shorter than `crate::service::gateway::HEARTBEAT_MAX_AGE_MS` so +/// the gateway never marks a healthy sled as stale. +pub const HEARTBEAT_INTERVAL_SECS: u64 = 10; + // ── Wire protocol ───────────────────────────────────────────────────────────── /// Extensible JSON envelope for all sled↔gateway uplink messages. @@ -54,15 +83,21 @@ pub struct UplinkEnvelope { /// Spawn the sled uplink background task. /// -/// Does nothing when `upstream_url` is empty. When active, the task holds -/// `services.perm_rx` locked for its lifetime (preventing auto-deny in +/// Does nothing when `config.upstream_url` is empty (AC 8 — sleds without +/// an upstream configured continue to work unchanged). When active, the task +/// holds `services.perm_rx` locked for its lifetime (preventing auto-deny in /// `tool_prompt_permission`) and forwards all permission requests to the /// gateway. Reconnects automatically with exponential back-off. -pub fn spawn_uplink_task(upstream_url: String, services: Arc) { - if upstream_url.is_empty() { +pub fn spawn_uplink_task(config: UplinkConfig, services: Arc) { + if config.upstream_url.is_empty() { return; } - slog!("[uplink] Spawning sled uplink task (gateway={upstream_url})"); + let UplinkConfig { + upstream_url, + project_name, + local_mcp_url, + } = config; + slog!("[uplink] Spawning sled uplink task (gateway={upstream_url}, project={project_name})"); tokio::spawn(async move { // Acquire perm_rx for this task's entire lifetime. While this lock is // held, try_lock() inside tool_prompt_permission fails — meaning @@ -70,9 +105,19 @@ pub fn spawn_uplink_task(upstream_url: String, services: Arc) { let mut perm_rx = services.perm_rx.lock().await; slog!("[uplink] Acquired perm_rx; maintaining gateway connection"); + let http = reqwest::Client::new(); + let mut backoff = INITIAL_BACKOFF_SECS; loop { - match run_uplink_session(&upstream_url, &mut perm_rx).await { + match run_uplink_session( + &upstream_url, + &project_name, + &local_mcp_url, + &http, + &mut perm_rx, + ) + .await + { Ok(()) => { slog!("[uplink] Connection closed cleanly; reconnecting in {backoff}s"); } @@ -90,10 +135,14 @@ pub fn spawn_uplink_task(upstream_url: String, services: Arc) { // ── Private helpers ─────────────────────────────────────────────────────────── -/// Run a single uplink session: connect, pump messages bidirectionally until -/// disconnect or channel close, then fail-close any in-flight requests. +/// Run a single uplink session: connect, send identity frame, pump messages +/// bidirectionally until disconnect or channel close, then fail-close any +/// in-flight requests. async fn run_uplink_session( url: &str, + project_name: &str, + local_mcp_url: &str, + http: &reqwest::Client, perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver, ) -> Result<(), String> { let (ws_stream, _) = tokio_tungstenite::connect_async(url) @@ -102,9 +151,31 @@ async fn run_uplink_session( slog!("[uplink] Connected to gateway uplink endpoint"); let (mut ws_sink, mut ws_rx) = ws_stream.split(); + + // ── Identity handshake (story 899 AC 5) ───────────────────────────── + let identity = UplinkEnvelope { + msg_type: "identity".to_string(), + req_id: "identity".to_string(), + payload: serde_json::json!({ "project": project_name }), + }; + let identity_text = + serde_json::to_string(&identity).map_err(|e| format!("serialise identity: {e}"))?; + ws_sink + .send(WsMessage::Text(identity_text.into())) + .await + .map_err(|e| format!("send identity: {e}"))?; + let mut in_flight: HashMap> = HashMap::new(); - let result = pump_messages(&mut ws_sink, &mut ws_rx, perm_rx, &mut in_flight).await; + let result = pump_messages( + &mut ws_sink, + &mut ws_rx, + perm_rx, + &mut in_flight, + local_mcp_url, + http, + ) + .await; fail_close_all(&mut in_flight); result } @@ -118,9 +189,44 @@ async fn pump_messages( ), perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver, in_flight: &mut HashMap>, + local_mcp_url: &str, + http: &reqwest::Client, ) -> Result<(), String> { + // Channel for spawned mcp_request handlers to deliver their finished + // mcp_response frames back to the WS writer. + let (mcp_resp_tx, mut mcp_resp_rx) = tokio::sync::mpsc::unbounded_channel::(); + + let mut heartbeat = + tokio::time::interval(std::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS)); + // Skip the immediate tick so the first heartbeat fires after the interval, + // not right after connect. + heartbeat.tick().await; + loop { tokio::select! { + // Heartbeat tick (story 899 AC 3). + _ = heartbeat.tick() => { + let env = UplinkEnvelope { + msg_type: "heartbeat".to_string(), + req_id: String::new(), + payload: serde_json::Value::Null, + }; + let text = serde_json::to_string(&env) + .map_err(|e| format!("serialise heartbeat: {e}"))?; + if ws_sink.send(WsMessage::Text(text.into())).await.is_err() { + return Err("WS send failed (heartbeat)".to_string()); + } + } + + // Completed mcp_response from a spawned handler — forward to gateway. + Some(env) = mcp_resp_rx.recv() => { + let text = serde_json::to_string(&env) + .map_err(|e| format!("serialise mcp_response: {e}"))?; + if ws_sink.send(WsMessage::Text(text.into())).await.is_err() { + return Err("WS send failed (mcp_response)".to_string()); + } + } + // New permission request from the MCP layer. maybe_fwd = perm_rx.recv() => { match maybe_fwd { @@ -162,7 +268,7 @@ async fn pump_messages( return Err("Gateway sent Close frame".to_string()); } Some(Ok(WsMessage::Text(text))) => { - on_gateway_text(&text, in_flight); + on_gateway_text(&text, in_flight, local_mcp_url, http, &mcp_resp_tx); } Some(Ok(WsMessage::Ping(data))) => { let _ = ws_sink.send(WsMessage::Pong(data)).await; @@ -174,19 +280,85 @@ async fn pump_messages( } } -/// Parse an incoming gateway text frame and resolve any matching in-flight request. +/// Parse an incoming gateway text frame and dispatch it to the appropriate +/// handler. `perm_response` resolves a waiting in-flight permission request; +/// `mcp_request` spawns a task that replays the JSON-RPC body against the +/// sled's local `/mcp` HTTP endpoint and forwards the result back to the +/// gateway as an `mcp_response`. fn on_gateway_text( text: &str, in_flight: &mut HashMap>, + local_mcp_url: &str, + http: &reqwest::Client, + mcp_resp_tx: &tokio::sync::mpsc::UnboundedSender, ) { let Ok(env) = serde_json::from_str::(text) else { return; }; - if env.msg_type == "perm_response" { - resolve_perm_response(env, in_flight); + match env.msg_type.as_str() { + "perm_response" => resolve_perm_response(env, in_flight), + "mcp_request" => spawn_mcp_request_handler(env, local_mcp_url, http, mcp_resp_tx), + _ => {} } } +/// Replay the gateway's `mcp_request` body against the sled's local MCP HTTP +/// endpoint in a spawned task and forward the response back as an +/// `mcp_response` envelope. The payload is expected to contain a `body` +/// string field holding the raw JSON-RPC bytes. +fn spawn_mcp_request_handler( + env: UplinkEnvelope, + local_mcp_url: &str, + http: &reqwest::Client, + mcp_resp_tx: &tokio::sync::mpsc::UnboundedSender, +) { + let req_id = env.req_id.clone(); + let body = env + .payload + .get("body") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let mcp_url = local_mcp_url.to_string(); + let http = http.clone(); + let mcp_resp_tx = mcp_resp_tx.clone(); + tokio::spawn(async move { + let response_value = match http + .post(&mcp_url) + .header("Content-Type", "application/json") + .body(body) + .send() + .await + { + Ok(r) => match r.json::().await { + Ok(v) => v, + Err(e) => serde_json::json!({ + "jsonrpc": "2.0", + "id": null, + "error": { + "code": -32603, + "message": format!("local mcp invalid JSON: {e}"), + } + }), + }, + Err(e) => serde_json::json!({ + "jsonrpc": "2.0", + "id": null, + "error": { + "code": -32603, + "message": format!("local mcp request failed: {e}"), + } + }), + }; + let resp = UplinkEnvelope { + msg_type: "mcp_response".to_string(), + req_id, + payload: response_value, + }; + let _ = mcp_resp_tx.send(resp); + }); +} + /// Map a `perm_response` envelope to a [`PermissionDecision`] and wake the /// waiting MCP call. fn resolve_perm_response( @@ -321,9 +493,13 @@ mod tests { #[test] fn on_gateway_text_ignores_unknown_type() { let mut in_flight: HashMap> = HashMap::new(); + let (tx, _rx) = tokio::sync::mpsc::unbounded_channel(); on_gateway_text( r#"{"type":"future_type","req_id":"x","payload":{}}"#, &mut in_flight, + "http://127.0.0.1:0/mcp", + &reqwest::Client::new(), + &tx, ); assert!(in_flight.is_empty()); } @@ -331,7 +507,14 @@ mod tests { #[test] fn on_gateway_text_ignores_invalid_json() { let mut in_flight: HashMap> = HashMap::new(); - on_gateway_text("not-json", &mut in_flight); // must not panic + let (tx, _rx) = tokio::sync::mpsc::unbounded_channel(); + on_gateway_text( + "not-json", + &mut in_flight, + "http://127.0.0.1:0/mcp", + &reqwest::Client::new(), + &tx, + ); // must not panic assert!(in_flight.is_empty()); } @@ -351,7 +534,14 @@ mod tests { permission_timeout_secs: 120, }); // Empty URL → noop; if it panicked or blocked the test would fail. - spawn_uplink_task(String::new(), services); + spawn_uplink_task( + UplinkConfig { + upstream_url: String::new(), + project_name: "test".to_string(), + local_mcp_url: "http://127.0.0.1:0/mcp".to_string(), + }, + services, + ); } // ── AC 11: permission approved via uplink ──────────────────────── @@ -364,10 +554,23 @@ mod tests { let port = listener.local_addr().unwrap().port(); let url = format!("ws://127.0.0.1:{port}"); - // Mock gateway: accept one connection, receive perm_request, reply approved. + // Mock gateway: accept one connection, consume identity, receive + // perm_request, reply approved. let gw_task = tokio::spawn(async move { let (tcp, _) = listener.accept().await.unwrap(); let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap(); + + // First frame: identity (story 899 AC 5). + let msg = ws.next().await.unwrap().unwrap(); + let text = match msg { + WsMessage::Text(t) => t.to_string(), + other => panic!("expected identity Text; got {other:?}"), + }; + let env: UplinkEnvelope = serde_json::from_str(&text).unwrap(); + assert_eq!(env.msg_type, "identity"); + assert_eq!(env.payload["project"], "test-proj"); + + // Next frame: perm_request. let msg = ws.next().await.unwrap().unwrap(); let text = match msg { WsMessage::Text(t) => t.to_string(), @@ -402,7 +605,14 @@ mod tests { permission_timeout_secs: 120, }); - spawn_uplink_task(url, Arc::clone(&services)); + spawn_uplink_task( + UplinkConfig { + upstream_url: url, + project_name: "test-proj".to_string(), + local_mcp_url: "http://127.0.0.1:0/mcp".to_string(), + }, + Arc::clone(&services), + ); tokio::time::sleep(std::time::Duration::from_millis(100)).await; let (response_tx, response_rx) = oneshot::channel(); @@ -441,12 +651,14 @@ mod tests { let conn_count2 = StdArc::clone(&conn_count); tokio::spawn(async move { - // First connection: receive the request then immediately drop (simulates network failure). + // First connection: consume identity + the request frame, then + // drop without replying (simulates network failure). { let (tcp, _) = listener.accept().await.unwrap(); let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap(); conn_count2.fetch_add(1, Ordering::SeqCst); - let _ = ws.next().await; // consume one frame + let _ = ws.next().await; // identity frame + let _ = ws.next().await; // perm_request frame drop(ws); // close without replying → fail-close on sled side } // Second connection: approve the next request. @@ -454,7 +666,10 @@ mod tests { let (tcp, _) = listener.accept().await.unwrap(); let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap(); conn_count2.fetch_add(1, Ordering::SeqCst); - if let Some(Ok(WsMessage::Text(text))) = ws.next().await { + // Consume identity frame. + let _ = ws.next().await; + // Drain non-perm frames (heartbeat etc.) until perm_request arrives. + while let Some(Ok(WsMessage::Text(text))) = ws.next().await { let env: UplinkEnvelope = serde_json::from_str(&text).unwrap(); if env.msg_type == "perm_request" { let resp = UplinkEnvelope { @@ -467,6 +682,7 @@ mod tests { serde_json::to_string(&resp).unwrap().into(), )) .await; + break; } } } @@ -486,7 +702,14 @@ mod tests { permission_timeout_secs: 120, }); - spawn_uplink_task(url, Arc::clone(&services)); + spawn_uplink_task( + UplinkConfig { + upstream_url: url, + project_name: "test-proj".to_string(), + local_mcp_url: "http://127.0.0.1:0/mcp".to_string(), + }, + Arc::clone(&services), + ); tokio::time::sleep(std::time::Duration::from_millis(100)).await; // First request is sent on the connection that drops → denied.