//! Sled uplink — background task that maintains a WebSocket connection from a
//! sled (standard huskies instance) to an upstream gateway.
//!
//! When `HUSKIES_UPSTREAM_GATEWAY` is set (or `--upstream-gateway` is passed
//! on the CLI), this module spawns a task that:
//!
//! 1. Acquires `services.perm_rx` for its lifetime (matching the Matrix bot's
//!    `permission_listener` pattern), preventing `tool_prompt_permission` from
//!    auto-denying requests with "no interactive session".
//! 2. Maintains a persistent WebSocket connection to the gateway's
//!    `/api/sled-uplink` endpoint.
//! 3. Sends an `identity` frame announcing the sled's project name immediately
//!    after connect (story 899 Phase 2). Authentication is carried by the
//!    `?token=` query parameter of the upstream URL, not by this frame.
//! 4. Sends a `heartbeat` frame every [`HEARTBEAT_INTERVAL_SECS`] so the
//!    gateway can mark the connection live without extra HTTP polls.
//! 5. Forwards each [`PermissionForward`] as a `perm_request` JSON envelope
//!    and awaits the matching `perm_response`.
//! 6. Handles inbound `mcp_request` frames by replaying the MCP JSON-RPC
//!    body against the sled's own local `/mcp` HTTP endpoint and returning
//!    the response as an `mcp_response` frame.
//! 7. Reconnects with exponential back-off on connection drop, fail-closing
//!    any in-flight permission requests with [`PermissionDecision::Deny`].
//!    The back-off resets after a session that stayed up for a while, and the
//!    task stops once the permission channel is closed (server shutdown).

use crate::http::context::{PermissionDecision, PermissionForward};
use crate::services::Services;
use crate::slog;
use futures::SinkExt as _;
use futures::StreamExt as _;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::oneshot;
use tokio_tungstenite::tungstenite::Message as WsMessage;

/// Configuration for [`spawn_uplink_task`].
///
/// Bundled into a single struct so `main.rs` can build it once and pass it in
/// without proliferating positional arguments.
pub struct UplinkConfig {
    /// WebSocket URL of the upstream gateway's `/api/sled-uplink` endpoint.
    /// Includes the `?token=` query parameter for auth. An empty string
    /// disables the uplink entirely.
    pub upstream_url: String,
    /// Project name this sled identifies as. Sent in the `identity` frame
    /// after WS connect (story 899 AC 5).
    pub project_name: String,
    /// HTTP base URL of this sled's own MCP endpoint (e.g.
    /// `http://127.0.0.1:3001/mcp`). Used to replay `mcp_request` frames
    /// received from the gateway against the local MCP handler.
    pub local_mcp_url: String,
}

// ── Back-off constants ────────────────────────────────────────────────────────

const INITIAL_BACKOFF_SECS: u64 = 1;
const MAX_BACKOFF_SECS: u64 = 60;
const BACKOFF_MULTIPLIER: u64 = 2;

/// A session that stays connected at least this long (several heartbeats) is
/// treated as healthy: the next reconnect starts again from
/// `INITIAL_BACKOFF_SECS` instead of continuing from wherever earlier
/// failures left the back-off.
const BACKOFF_RESET_AFTER_SECS: u64 = 3 * HEARTBEAT_INTERVAL_SECS;

/// Interval between `heartbeat` frames sent to the gateway (story 899 AC 3).
///
/// Must be shorter than `crate::service::gateway::HEARTBEAT_MAX_AGE_MS` so
/// the gateway never marks a healthy sled as stale.
pub const HEARTBEAT_INTERVAL_SECS: u64 = 10;

// ── Wire protocol ─────────────────────────────────────────────────────────────

/// Extensible JSON envelope for all sled↔gateway uplink messages.
///
/// Message types currently exchanged over the uplink:
///
/// * sled → gateway: `identity`, `heartbeat`, `perm_request`, `mcp_response`
/// * gateway → sled: `perm_response`, `mcp_request`
///
/// New `type` values can be added without changing this framing; the sled
/// ignores gateway message types it does not recognise.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct UplinkEnvelope {
    /// Message type discriminant (e.g. `"perm_request"`, `"perm_response"`).
    #[serde(rename = "type")]
    pub msg_type: String,
    /// Correlation ID — the sled chooses this for each request; the gateway
    /// echoes it back so the sled can demux concurrent responses.
    pub req_id: String,
    /// Message-specific payload. Varies by `msg_type`.
    pub payload: serde_json::Value,
}

// ── Public API ────────────────────────────────────────────────────────────────

/// Spawn the sled uplink background task.
///
/// Does nothing when `config.upstream_url` is empty (AC 8 — sleds without an
/// upstream configured continue to work unchanged). When active, the task
/// holds `services.perm_rx` locked for its lifetime (preventing auto-deny in
/// `tool_prompt_permission`) and forwards all permission requests to the
/// gateway.
///
/// Failed or dropped sessions are retried with exponential back-off
/// (`INITIAL_BACKOFF_SECS`, multiplied by `BACKOFF_MULTIPLIER` up to
/// `MAX_BACKOFF_SECS`). The back-off resets after any session that stayed up
/// for at least `BACKOFF_RESET_AFTER_SECS`, so a long-lived sled that
/// occasionally loses its connection is not stuck waiting the maximum delay.
///
/// A session only ends with `Ok(())` when the permission channel has been
/// closed (every sender dropped — the server is shutting down). A closed
/// channel can never yield another request, so the task then exits, releasing
/// `perm_rx`, instead of reconnecting forever.
pub fn spawn_uplink_task(config: UplinkConfig, services: Arc<Services>) {
    if config.upstream_url.is_empty() {
        return;
    }

    let UplinkConfig {
        upstream_url,
        project_name,
        local_mcp_url,
    } = config;

    slog!("[uplink] Spawning sled uplink task (gateway={upstream_url}, project={project_name})");

    tokio::spawn(async move {
        // Acquire perm_rx for this task's entire lifetime. While this lock is
        // held, try_lock() inside tool_prompt_permission fails — meaning
        // requests flow to perm_tx (which we drain here) rather than auto-deny.
        let mut perm_rx = services.perm_rx.lock().await;
        slog!("[uplink] Acquired perm_rx; maintaining gateway connection");

        let http = reqwest::Client::new();
        let mut backoff = INITIAL_BACKOFF_SECS;

        loop {
            let session_started = std::time::Instant::now();
            let result = run_uplink_session(
                &upstream_url,
                &project_name,
                &local_mcp_url,
                &http,
                &mut perm_rx,
            )
            .await;

            match result {
                Ok(()) => {
                    // perm channel closed: nothing left to forward, ever.
                    slog!("[uplink] Permission channel closed; stopping uplink task");
                    break;
                }
                Err(e) => {
                    // A session that stayed up for a while proves the gateway
                    // is reachable again; don't penalise it for old failures.
                    if session_started.elapsed()
                        >= std::time::Duration::from_secs(BACKOFF_RESET_AFTER_SECS)
                    {
                        backoff = INITIAL_BACKOFF_SECS;
                    }
                    slog!("[uplink] Session error: {e}; reconnecting in {backoff}s");
                }
            }

            tokio::time::sleep(std::time::Duration::from_secs(backoff)).await;
            backoff = backoff
                .saturating_mul(BACKOFF_MULTIPLIER)
                .min(MAX_BACKOFF_SECS);
        }
    });
}

// ── Private helpers ───────────────────────────────────────────────────────────

/// Run a single uplink session: connect, send identity frame, pump messages
/// bidirectionally until disconnect or channel close, then fail-close any
/// in-flight requests.
async fn run_uplink_session( url: &str, project_name: &str, local_mcp_url: &str, http: &reqwest::Client, perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver, ) -> Result<(), String> { let (ws_stream, _) = tokio_tungstenite::connect_async(url) .await .map_err(|e| format!("WS connect to {url}: {e}"))?; slog!("[uplink] Connected to gateway uplink endpoint"); let (mut ws_sink, mut ws_rx) = ws_stream.split(); // ── Identity handshake (story 899 AC 5) ───────────────────────────── let identity = UplinkEnvelope { msg_type: "identity".to_string(), req_id: "identity".to_string(), payload: serde_json::json!({ "project": project_name }), }; let identity_text = serde_json::to_string(&identity).map_err(|e| format!("serialise identity: {e}"))?; ws_sink .send(WsMessage::Text(identity_text.into())) .await .map_err(|e| format!("send identity: {e}"))?; let mut in_flight: HashMap> = HashMap::new(); let result = pump_messages( &mut ws_sink, &mut ws_rx, perm_rx, &mut in_flight, local_mcp_url, http, ) .await; fail_close_all(&mut in_flight); result } /// Drive the bidirectional message loop for one session. async fn pump_messages( ws_sink: &mut (impl futures::Sink + Unpin), ws_rx: &mut ( impl futures::Stream> + Unpin ), perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver, in_flight: &mut HashMap>, local_mcp_url: &str, http: &reqwest::Client, ) -> Result<(), String> { // Channel for spawned mcp_request handlers to deliver their finished // mcp_response frames back to the WS writer. let (mcp_resp_tx, mut mcp_resp_rx) = tokio::sync::mpsc::unbounded_channel::(); let mut heartbeat = tokio::time::interval(std::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS)); // Skip the immediate tick so the first heartbeat fires after the interval, // not right after connect. heartbeat.tick().await; loop { tokio::select! { // Heartbeat tick (story 899 AC 3). 
_ = heartbeat.tick() => { let env = UplinkEnvelope { msg_type: "heartbeat".to_string(), req_id: String::new(), payload: serde_json::Value::Null, }; let text = serde_json::to_string(&env) .map_err(|e| format!("serialise heartbeat: {e}"))?; if ws_sink.send(WsMessage::Text(text.into())).await.is_err() { return Err("WS send failed (heartbeat)".to_string()); } } // Completed mcp_response from a spawned handler — forward to gateway. Some(env) = mcp_resp_rx.recv() => { let text = serde_json::to_string(&env) .map_err(|e| format!("serialise mcp_response: {e}"))?; if ws_sink.send(WsMessage::Text(text.into())).await.is_err() { return Err("WS send failed (mcp_response)".to_string()); } } // New permission request from the MCP layer. maybe_fwd = perm_rx.recv() => { match maybe_fwd { None => return Ok(()), // channel closed — server shutting down Some(fwd) => { let PermissionForward { request_id, tool_name, tool_input, response_tx, } = fwd; let env = UplinkEnvelope { msg_type: "perm_request".to_string(), req_id: request_id.clone(), payload: serde_json::json!({ "tool_name": tool_name, "tool_input": tool_input, }), }; let text = serde_json::to_string(&env) .map_err(|e| format!("serialise perm_request: {e}"))?; if ws_sink.send(WsMessage::Text(text.into())).await.is_err() { // Connection dead: fail-close this request immediately. let _ = response_tx.send(PermissionDecision::Deny); return Err("WS send failed".to_string()); } in_flight.insert(request_id, response_tx); } } } // Message arriving from the gateway. 
msg = ws_rx.next() => { match msg { None | Some(Err(_)) => { return Err("WS stream closed".to_string()); } Some(Ok(WsMessage::Close(_))) => { return Err("Gateway sent Close frame".to_string()); } Some(Ok(WsMessage::Text(text))) => { on_gateway_text(&text, in_flight, local_mcp_url, http, &mcp_resp_tx); } Some(Ok(WsMessage::Ping(data))) => { let _ = ws_sink.send(WsMessage::Pong(data)).await; } Some(Ok(_)) => {} } } } } } /// Parse an incoming gateway text frame and dispatch it to the appropriate /// handler. `perm_response` resolves a waiting in-flight permission request; /// `mcp_request` spawns a task that replays the JSON-RPC body against the /// sled's local `/mcp` HTTP endpoint and forwards the result back to the /// gateway as an `mcp_response`. fn on_gateway_text( text: &str, in_flight: &mut HashMap>, local_mcp_url: &str, http: &reqwest::Client, mcp_resp_tx: &tokio::sync::mpsc::UnboundedSender, ) { let Ok(env) = serde_json::from_str::(text) else { return; }; match env.msg_type.as_str() { "perm_response" => resolve_perm_response(env, in_flight), "mcp_request" => spawn_mcp_request_handler(env, local_mcp_url, http, mcp_resp_tx), _ => {} } } /// Replay the gateway's `mcp_request` body against the sled's local MCP HTTP /// endpoint in a spawned task and forward the response back as an /// `mcp_response` envelope. The payload is expected to contain a `body` /// string field holding the raw JSON-RPC bytes. 
fn spawn_mcp_request_handler( env: UplinkEnvelope, local_mcp_url: &str, http: &reqwest::Client, mcp_resp_tx: &tokio::sync::mpsc::UnboundedSender, ) { let req_id = env.req_id.clone(); let body = env .payload .get("body") .and_then(|v| v.as_str()) .unwrap_or("") .to_string(); let mcp_url = local_mcp_url.to_string(); let http = http.clone(); let mcp_resp_tx = mcp_resp_tx.clone(); tokio::spawn(async move { let response_value = match http .post(&mcp_url) .header("Content-Type", "application/json") .body(body) .send() .await { Ok(r) => match r.json::().await { Ok(v) => v, Err(e) => serde_json::json!({ "jsonrpc": "2.0", "id": null, "error": { "code": -32603, "message": format!("local mcp invalid JSON: {e}"), } }), }, Err(e) => serde_json::json!({ "jsonrpc": "2.0", "id": null, "error": { "code": -32603, "message": format!("local mcp request failed: {e}"), } }), }; let resp = UplinkEnvelope { msg_type: "mcp_response".to_string(), req_id, payload: response_value, }; let _ = mcp_resp_tx.send(resp); }); } /// Map a `perm_response` envelope to a [`PermissionDecision`] and wake the /// waiting MCP call. fn resolve_perm_response( env: UplinkEnvelope, in_flight: &mut HashMap>, ) { let Some(tx) = in_flight.remove(&env.req_id) else { return; }; let approved = env .payload .get("approved") .and_then(|v| v.as_bool()) .unwrap_or(false); let always_allow = env .payload .get("always_allow") .and_then(|v| v.as_bool()) .unwrap_or(false); let decision = if always_allow { PermissionDecision::AlwaysAllow } else if approved { PermissionDecision::Approve } else { PermissionDecision::Deny }; let _ = tx.send(decision); } /// Deny all in-flight requests (fail-closed on connection drop — AC 8). 
fn fail_close_all(in_flight: &mut HashMap>) { for (_, tx) in in_flight.drain() { let _ = tx.send(PermissionDecision::Deny); } } // ── Tests ───────────────────────────────────────────────────────────────────── #[cfg(test)] mod tests { use super::*; use crate::http::context::PermissionForward; use crate::services::Services; use std::collections::HashMap; use tokio::net::TcpListener; use tokio::sync::oneshot; use tokio_tungstenite::tungstenite::Message as WsMessage; // ── Pure unit tests ─────────────────────────────────────────────── #[test] fn uplink_envelope_roundtrips_json() { let env = UplinkEnvelope { msg_type: "perm_request".to_string(), req_id: "req-1".to_string(), payload: serde_json::json!({"tool_name": "Bash", "tool_input": {}}), }; let text = serde_json::to_string(&env).unwrap(); let back: UplinkEnvelope = serde_json::from_str(&text).unwrap(); assert_eq!(back.msg_type, "perm_request"); assert_eq!(back.req_id, "req-1"); assert_eq!(back.payload["tool_name"], "Bash"); } #[test] fn resolve_perm_response_approve() { let (tx, rx) = oneshot::channel(); let mut in_flight = HashMap::new(); in_flight.insert("r1".to_string(), tx); let env = UplinkEnvelope { msg_type: "perm_response".to_string(), req_id: "r1".to_string(), payload: serde_json::json!({"approved": true, "always_allow": false}), }; resolve_perm_response(env, &mut in_flight); assert!(in_flight.is_empty()); assert_eq!(rx.blocking_recv().unwrap(), PermissionDecision::Approve); } #[test] fn resolve_perm_response_deny() { let (tx, rx) = oneshot::channel(); let mut in_flight = HashMap::new(); in_flight.insert("r2".to_string(), tx); let env = UplinkEnvelope { msg_type: "perm_response".to_string(), req_id: "r2".to_string(), payload: serde_json::json!({"approved": false}), }; resolve_perm_response(env, &mut in_flight); assert_eq!(rx.blocking_recv().unwrap(), PermissionDecision::Deny); } #[test] fn resolve_perm_response_always_allow() { let (tx, rx) = oneshot::channel(); let mut in_flight = HashMap::new(); 
in_flight.insert("r3".to_string(), tx); let env = UplinkEnvelope { msg_type: "perm_response".to_string(), req_id: "r3".to_string(), payload: serde_json::json!({"approved": true, "always_allow": true}), }; resolve_perm_response(env, &mut in_flight); assert_eq!(rx.blocking_recv().unwrap(), PermissionDecision::AlwaysAllow); } #[test] fn resolve_perm_response_unknown_req_id_is_noop() { let mut in_flight: HashMap> = HashMap::new(); let env = UplinkEnvelope { msg_type: "perm_response".to_string(), req_id: "missing".to_string(), payload: serde_json::json!({"approved": true}), }; resolve_perm_response(env, &mut in_flight); // must not panic } #[test] fn fail_close_all_denies_all_pending() { let (tx1, rx1) = oneshot::channel(); let (tx2, rx2) = oneshot::channel(); let mut in_flight = HashMap::new(); in_flight.insert("a".to_string(), tx1); in_flight.insert("b".to_string(), tx2); fail_close_all(&mut in_flight); assert!(in_flight.is_empty()); assert_eq!(rx1.blocking_recv().unwrap(), PermissionDecision::Deny); assert_eq!(rx2.blocking_recv().unwrap(), PermissionDecision::Deny); } #[test] fn on_gateway_text_ignores_unknown_type() { let mut in_flight: HashMap> = HashMap::new(); let (tx, _rx) = tokio::sync::mpsc::unbounded_channel(); on_gateway_text( r#"{"type":"future_type","req_id":"x","payload":{}}"#, &mut in_flight, "http://127.0.0.1:0/mcp", &reqwest::Client::new(), &tx, ); assert!(in_flight.is_empty()); } #[test] fn on_gateway_text_ignores_invalid_json() { let mut in_flight: HashMap> = HashMap::new(); let (tx, _rx) = tokio::sync::mpsc::unbounded_channel(); on_gateway_text( "not-json", &mut in_flight, "http://127.0.0.1:0/mcp", &reqwest::Client::new(), &tx, ); // must not panic assert!(in_flight.is_empty()); } #[test] fn spawn_uplink_task_noop_when_url_empty() { let (_perm_tx, perm_rx) = tokio::sync::mpsc::unbounded_channel(); let agents = Arc::new(crate::agents::AgentPool::new_test(3000)); let services = Arc::new(Services { project_root: std::path::PathBuf::from("/tmp"), 
status: agents.status_broadcaster(), agents, bot_name: "Test".to_string(), bot_user_id: String::new(), ambient_rooms: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())), perm_rx: Arc::new(tokio::sync::Mutex::new(perm_rx)), pending_perm_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())), permission_timeout_secs: 120, }); // Empty URL → noop; if it panicked or blocked the test would fail. spawn_uplink_task( UplinkConfig { upstream_url: String::new(), project_name: "test".to_string(), local_mcp_url: "http://127.0.0.1:0/mcp".to_string(), }, services, ); } // ── AC 11: permission approved via uplink ──────────────────────── // "Simulate matrix bot triggering a Bash permission, sled forwards via // uplink, mock matrix transport approves, tool call proceeds." #[tokio::test] async fn integration_perm_request_approved_via_uplink() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let port = listener.local_addr().unwrap().port(); let url = format!("ws://127.0.0.1:{port}"); // Mock gateway: accept one connection, consume identity, receive // perm_request, reply approved. let gw_task = tokio::spawn(async move { let (tcp, _) = listener.accept().await.unwrap(); let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap(); // First frame: identity (story 899 AC 5). let msg = ws.next().await.unwrap().unwrap(); let text = match msg { WsMessage::Text(t) => t.to_string(), other => panic!("expected identity Text; got {other:?}"), }; let env: UplinkEnvelope = serde_json::from_str(&text).unwrap(); assert_eq!(env.msg_type, "identity"); assert_eq!(env.payload["project"], "test-proj"); // Next frame: perm_request. 
let msg = ws.next().await.unwrap().unwrap(); let text = match msg { WsMessage::Text(t) => t.to_string(), other => panic!("expected Text; got {other:?}"), }; let env: UplinkEnvelope = serde_json::from_str(&text).unwrap(); assert_eq!(env.msg_type, "perm_request"); assert_eq!(env.payload["tool_name"], "Bash"); let resp = UplinkEnvelope { msg_type: "perm_response".to_string(), req_id: env.req_id, payload: serde_json::json!({"approved": true, "always_allow": false}), }; ws.send(WsMessage::Text( serde_json::to_string(&resp).unwrap().into(), )) .await .unwrap(); }); let (perm_tx, perm_rx) = tokio::sync::mpsc::unbounded_channel::(); let agents = Arc::new(crate::agents::AgentPool::new_test(3000)); let services = Arc::new(Services { project_root: std::path::PathBuf::from("/tmp"), status: agents.status_broadcaster(), agents, bot_name: "Test".to_string(), bot_user_id: String::new(), ambient_rooms: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())), perm_rx: Arc::new(tokio::sync::Mutex::new(perm_rx)), pending_perm_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())), permission_timeout_secs: 120, }); spawn_uplink_task( UplinkConfig { upstream_url: url, project_name: "test-proj".to_string(), local_mcp_url: "http://127.0.0.1:0/mcp".to_string(), }, Arc::clone(&services), ); tokio::time::sleep(std::time::Duration::from_millis(100)).await; let (response_tx, response_rx) = oneshot::channel(); perm_tx .send(PermissionForward { request_id: "req-test-1".to_string(), tool_name: "Bash".to_string(), tool_input: serde_json::json!({"command": "echo hello"}), response_tx, }) .unwrap(); let decision = tokio::time::timeout(std::time::Duration::from_secs(5), response_rx) .await .expect("timeout waiting for decision") .expect("oneshot dropped"); assert_eq!(decision, PermissionDecision::Approve); gw_task.await.unwrap(); } // ── AC 12: sled disconnects and reconnects ──────────────────────── // "Sled disconnects and reconnects mid-session; subsequent permission // requests 
succeed once reconnected." #[tokio::test] async fn integration_reconnects_after_disconnect() { use std::sync::Arc as StdArc; use std::sync::atomic::{AtomicU32, Ordering}; let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let port = listener.local_addr().unwrap().port(); let url = format!("ws://127.0.0.1:{port}"); let conn_count = StdArc::new(AtomicU32::new(0)); let conn_count2 = StdArc::clone(&conn_count); tokio::spawn(async move { // First connection: consume identity + the request frame, then // drop without replying (simulates network failure). { let (tcp, _) = listener.accept().await.unwrap(); let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap(); conn_count2.fetch_add(1, Ordering::SeqCst); let _ = ws.next().await; // identity frame let _ = ws.next().await; // perm_request frame drop(ws); // close without replying → fail-close on sled side } // Second connection: approve the next request. { let (tcp, _) = listener.accept().await.unwrap(); let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap(); conn_count2.fetch_add(1, Ordering::SeqCst); // Consume identity frame. let _ = ws.next().await; // Drain non-perm frames (heartbeat etc.) until perm_request arrives. 
while let Some(Ok(WsMessage::Text(text))) = ws.next().await { let env: UplinkEnvelope = serde_json::from_str(&text).unwrap(); if env.msg_type == "perm_request" { let resp = UplinkEnvelope { msg_type: "perm_response".to_string(), req_id: env.req_id, payload: serde_json::json!({"approved": true, "always_allow": false}), }; let _ = ws .send(WsMessage::Text( serde_json::to_string(&resp).unwrap().into(), )) .await; break; } } } }); let (perm_tx, perm_rx) = tokio::sync::mpsc::unbounded_channel::(); let agents = Arc::new(crate::agents::AgentPool::new_test(3000)); let services = Arc::new(Services { project_root: std::path::PathBuf::from("/tmp"), status: agents.status_broadcaster(), agents, bot_name: "Test".to_string(), bot_user_id: String::new(), ambient_rooms: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())), perm_rx: Arc::new(tokio::sync::Mutex::new(perm_rx)), pending_perm_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())), permission_timeout_secs: 120, }); spawn_uplink_task( UplinkConfig { upstream_url: url, project_name: "test-proj".to_string(), local_mcp_url: "http://127.0.0.1:0/mcp".to_string(), }, Arc::clone(&services), ); tokio::time::sleep(std::time::Duration::from_millis(100)).await; // First request is sent on the connection that drops → denied. let (tx1, rx1) = oneshot::channel(); perm_tx .send(PermissionForward { request_id: "req-drop".to_string(), tool_name: "Bash".to_string(), tool_input: serde_json::json!({}), response_tx: tx1, }) .unwrap(); let d1 = tokio::time::timeout(std::time::Duration::from_secs(5), rx1) .await .expect("timeout on first request") .expect("oneshot dropped"); assert_eq!( d1, PermissionDecision::Deny, "dropped connection must fail-close" ); // Wait for the 1-second reconnect backoff plus buffer. tokio::time::sleep(std::time::Duration::from_millis(1500)).await; // Second request arrives on the reconnected session → approved. 
let (tx2, rx2) = oneshot::channel(); perm_tx .send(PermissionForward { request_id: "req-reconnect".to_string(), tool_name: "Write".to_string(), tool_input: serde_json::json!({}), response_tx: tx2, }) .unwrap(); let d2 = tokio::time::timeout(std::time::Duration::from_secs(5), rx2) .await .expect("timeout on second request") .expect("oneshot dropped"); assert_eq!( d2, PermissionDecision::Approve, "reconnected session must approve" ); assert_eq!( conn_count.load(Ordering::SeqCst), 2, "must have seen exactly 2 connections" ); } }