From 937792f2085bd11c00af9bed0532e1ab40324d7b Mon Sep 17 00:00:00 2001 From: dave Date: Tue, 12 May 2026 21:29:04 +0000 Subject: [PATCH] huskies: merge 898 --- server/src/cli.rs | 53 +++ server/src/gateway/mod.rs | 3 + server/src/gateway/tests.rs | 30 +- server/src/http/gateway/mod.rs | 4 +- server/src/http/gateway/websocket.rs | 128 +++++++ server/src/main.rs | 10 + server/src/service/gateway/config.rs | 12 +- server/src/service/gateway/io.rs | 35 +- server/src/service/gateway/mod.rs | 34 +- server/src/sled_uplink.rs | 543 +++++++++++++++++++++++++++ 10 files changed, 829 insertions(+), 23 deletions(-) create mode 100644 server/src/sled_uplink.rs diff --git a/server/src/cli.rs b/server/src/cli.rs index 2ee8604a..25f073a9 100644 --- a/server/src/cli.rs +++ b/server/src/cli.rs @@ -21,6 +21,12 @@ pub(crate) struct CliArgs { pub(crate) join_token: Option, /// HTTP URL of the gateway to register with when a join token is provided (`--gateway-url`). pub(crate) gateway_url: Option, + /// WebSocket URL of the upstream gateway to forward permission requests to (`--upstream-gateway`). + /// + /// When set, the sled spawns a background uplink task that holds `perm_rx` and + /// forwards all `prompt_permission` tool calls to the gateway over a WebSocket. + /// Also readable from the `HUSKIES_UPSTREAM_GATEWAY` env var. + pub(crate) upstream_gateway: Option, } /// Parse CLI arguments into `CliArgs`, or exit early for `--help` / `--version`. @@ -34,6 +40,7 @@ pub(crate) fn parse_cli_args(args: &[String]) -> Result { let mut rendezvous: Option = None; let mut join_token: Option = None; let mut gateway_url: Option = None; + let mut upstream_gateway: Option = None; let mut i = 0; while i < args.len() { @@ -94,6 +101,16 @@ pub(crate) fn parse_cli_args(args: &[String]) -> Result { a if a.starts_with("--gateway-url=") => { gateway_url = Some(a["--gateway-url=".len()..].to_string()); } + "--upstream-gateway" => { + i += 1; + if i >= args.len() { + return Err("--upstream-gateway requires a value".to_string()); + } + upstream_gateway = Some(args[i].clone()); + } + a if a.starts_with("--upstream-gateway=") => { + upstream_gateway = Some(a["--upstream-gateway=".len()..].to_string()); + } "--gateway" => { gateway = true; } @@ -129,6 +146,7 @@ pub(crate) fn parse_cli_args(args: &[String]) -> Result { gateway, join_token, gateway_url, + upstream_gateway, }) } @@ -167,6 +185,11 @@ pub(crate) fn print_help() { println!(" --gateway-url HTTP URL of the gateway to register with when"); println!(" --join-token is provided (agent mode only)."); println!(" Also readable from HUSKIES_GATEWAY_URL env var."); + println!(" --upstream-gateway WebSocket URL of an upstream gateway to forward"); + println!(" permission requests to (sled mode). When set, the"); + println!(" sled connects to WS URL and forwards all"); + println!(" prompt_permission calls via the uplink protocol."); + println!(" Also readable from HUSKIES_UPSTREAM_GATEWAY env var."); } /// Resolve the optional positional path argument into an absolute `PathBuf`. @@ -343,6 +366,36 @@ mod tests { let result = parse_cli_args(&[]).unwrap(); assert_eq!(result.join_token, None); assert_eq!(result.gateway_url, None); + assert_eq!(result.upstream_gateway, None); + } + + #[test] + fn parse_upstream_gateway_flag() { + let args = vec![ + "--upstream-gateway".to_string(), + "ws://gateway:3001/api/sled-uplink?token=abc".to_string(), + ]; + let result = parse_cli_args(&args).unwrap(); + assert_eq!( + result.upstream_gateway, + Some("ws://gateway:3001/api/sled-uplink?token=abc".to_string()) + ); + } + + #[test] + fn parse_upstream_gateway_equals_syntax() { + let args = vec!["--upstream-gateway=ws://gw:3001/api/sled-uplink?token=x".to_string()]; + let result = parse_cli_args(&args).unwrap(); + assert_eq!( + result.upstream_gateway, + Some("ws://gw:3001/api/sled-uplink?token=x".to_string()) + ); + } + + #[test] + fn parse_upstream_gateway_missing_value_is_error() { + let args = vec!["--upstream-gateway".to_string()]; + assert!(parse_cli_args(&args).is_err()); } // ── resolve_path_arg ──────────────────────────────────────────── diff --git a/server/src/gateway/mod.rs b/server/src/gateway/mod.rs index 20d77d45..39488d97 100644 --- a/server/src/gateway/mod.rs +++ b/server/src/gateway/mod.rs @@ -55,6 +55,8 @@ pub fn build_gateway_route(state_arc: Arc) -> impl poem::Endpoint ) // Agent registration via CRDT-sync WebSocket. .at("/crdt-sync", poem::get(gateway_crdt_sync_handler)) + // Sled uplink: permission-forwarding WebSocket from sleds to gateway. + .at("/api/sled-uplink", poem::get(gateway_sled_uplink_handler)) // Agent management REST endpoints. .at( "/gateway/agents/:id/assign", @@ -126,6 +128,7 @@ pub async fn run(config_path: &Path, port: u16) -> Result<(), std::io::Error> { gateway_project_urls, port, Some(state_arc.event_tx.clone()), + Arc::clone(&state_arc.perm_rx), ); *state_arc.bot_handle.lock().await = bot_abort; *state_arc.bot_shutdown_tx.lock().await = Some(bot_shutdown_tx); diff --git a/server/src/gateway/tests.rs b/server/src/gateway/tests.rs index bebe5434..99fecf11 100644 --- a/server/src/gateway/tests.rs +++ b/server/src/gateway/tests.rs @@ -13,7 +13,10 @@ fn make_test_state() -> Arc { url: "http://test:3001".into(), }, ); - let config = GatewayConfig { projects }; + let config = GatewayConfig { + projects, + sled_tokens: BTreeMap::new(), + }; Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap()) } @@ -368,7 +371,10 @@ async fn init_project_registers_in_projects_toml_when_name_and_url_given() { url: "http://existing:3001".into(), }, ); - let config = GatewayConfig { projects }; + let config = GatewayConfig { + projects, + sled_tokens: BTreeMap::new(), + }; let state = Arc::new(GatewayState::new(config, config_dir.path().to_path_buf(), 3000).unwrap()); let result = gateway::init_project( @@ -395,7 +401,10 @@ async fn init_project_duplicate_name_returns_error() { url: "http://taken:3001".into(), }, ); - let config = GatewayConfig { projects }; + let config = GatewayConfig { + projects, + sled_tokens: BTreeMap::new(), + }; let state = Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap()); let result = gateway::init_project( @@ -444,7 +453,10 @@ async fn init_project_then_wizard_status_integration() { let mut projects = BTreeMap::new(); projects.insert("mock-project".into(), ProjectEntry { url: mock_url }); - let config = GatewayConfig { projects }; + let config = GatewayConfig { + projects, + sled_tokens: BTreeMap::new(), + }; let config_dir = tempfile::tempdir().unwrap(); let state = Arc::new(GatewayState::new(config, config_dir.path().to_path_buf(), 3000).unwrap()); @@ -966,7 +978,10 @@ async fn gateway_mcp_sse_proxy_streams_progress_and_final_response() { url: mock_sled.url(), }, ); - let config = GatewayConfig { projects }; + let config = GatewayConfig { + projects, + sled_tokens: BTreeMap::new(), + }; let state = Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap()); let app = poem::Route::new() @@ -1059,7 +1074,10 @@ async fn gateway_mcp_post_without_sse_returns_plain_json() { url: mock_sled.url(), }, ); - let config = GatewayConfig { projects }; + let config = GatewayConfig { + projects, + sled_tokens: BTreeMap::new(), + }; let state = Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap()); let app = poem::Route::new() diff --git a/server/src/http/gateway/mod.rs b/server/src/http/gateway/mod.rs index 43475d6d..cee9df9f 100644 --- a/server/src/http/gateway/mod.rs +++ b/server/src/http/gateway/mod.rs @@ -22,4 +22,6 @@ pub use rest::{ gateway_bot_config_save_handler, gateway_generate_token_handler, gateway_mode_handler, gateway_remove_project_handler, }; -pub use websocket::{gateway_crdt_sync_handler, gateway_event_push_handler}; +pub use websocket::{ + gateway_crdt_sync_handler, gateway_event_push_handler, gateway_sled_uplink_handler, +}; diff --git a/server/src/http/gateway/websocket.rs b/server/src/http/gateway/websocket.rs index 22dfa3de..f8074407 100644 --- a/server/src/http/gateway/websocket.rs +++ b/server/src/http/gateway/websocket.rs @@ -146,6 +146,134 @@ pub async fn gateway_crdt_sync_handler( .into_response() } +// ── Sled uplink WebSocket handler ──────────────────────────────────────────── + +/// Query parameters accepted on the `/api/sled-uplink` WebSocket upgrade. +#[derive(Deserialize)] +struct SledUplinkParams { + /// Shared-secret token identifying the connecting sled (from `[sled_tokens]` in `projects.toml`). + token: Option, +} + +/// `GET /api/sled-uplink` — gateway-side WebSocket endpoint for sled permission uplinks. +/// +/// # Authentication +/// +/// The connecting sled must supply a valid shared-secret token via the `token` +/// query parameter. Tokens are configured in `[sled_tokens]` in `projects.toml` +/// as `sled_id = "secret"` entries. +/// +/// # Protocol +/// +/// See `sled_uplink.rs` for the wire format ([`UplinkEnvelope`]). The gateway +/// accepts `perm_request` messages, injects them into the local permission +/// pipeline (via `state.perm_tx`), and sends `perm_response` frames back to the +/// sled once the Matrix bot resolves them. Multiple sleds are demuxed by +/// connection: each handler owns exactly one sled's request/response flow. +#[handler] +pub async fn gateway_sled_uplink_handler( + ws: WebSocket, + state: Data<&Arc>, + Query(params): Query, +) -> poem::Response { + let token = match params.token { + Some(t) if !t.is_empty() => t, + _ => { + return poem::Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body("token query parameter required"); + } + }; + + let sled_id = match state.sled_tokens.get(&token) { + Some(id) => id.clone(), + None => { + return poem::Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body("invalid token"); + } + }; + + use poem::IntoResponse as _; + let perm_tx = state.perm_tx.clone(); + ws.on_upgrade(move |socket| async move { + let (mut sink, mut stream) = socket.split(); + // Aggregator channel: spawned per-request tasks send (req_id, decision) here + // so the main loop can write perm_response frames back to the sled. + let (agg_tx, mut agg_rx) = tokio::sync::mpsc::unbounded_channel::<( + String, + crate::http::context::PermissionDecision, + )>(); + + crate::slog!("[gateway/sled-uplink] Sled '{}' connected", sled_id); + + loop { + tokio::select! { + msg = stream.next() => { + let text = match msg { + Some(Ok(WsMessage::Text(t))) => t, + Some(Ok(WsMessage::Close(_))) | None => break, + _ => continue, + }; + let Ok(env) = serde_json::from_str::(&text) else { + continue; + }; + if env.msg_type == "perm_request" { + let req_id = env.req_id.clone(); + let tool_name = env.payload.get("tool_name") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + let tool_input = env.payload.get("tool_input") + .cloned() + .unwrap_or(serde_json::Value::Null); + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let fwd = crate::http::context::PermissionForward { + request_id: format!("{sled_id}:{req_id}"), + tool_name, + tool_input, + response_tx, + }; + if perm_tx.send(fwd).is_err() { + break; + } + let agg_tx2 = agg_tx.clone(); + tokio::spawn(async move { + let decision = response_rx + .await + .unwrap_or(crate::http::context::PermissionDecision::Deny); + let _ = agg_tx2.send((req_id, decision)); + }); + } + } + Some((req_id, decision)) = agg_rx.recv() => { + use crate::http::context::PermissionDecision; + let (approved, always_allow) = match decision { + PermissionDecision::AlwaysAllow => (true, true), + PermissionDecision::Approve => (true, false), + PermissionDecision::Deny => (false, false), + }; + let resp = crate::sled_uplink::UplinkEnvelope { + msg_type: "perm_response".to_string(), + req_id, + payload: serde_json::json!({ + "approved": approved, + "always_allow": always_allow, + }), + }; + let Ok(text) = serde_json::to_string(&resp) else { continue }; + if sink.send(WsMessage::Text(text)).await.is_err() { + break; + } + } + } + } + + crate::slog!("[gateway/sled-uplink] Sled '{}' disconnected", sled_id); + }) + .into_response() +} + // ── Event-push WebSocket handler ───────────────────────────────────────────── /// Query parameters accepted on the `/gateway/events/push` WebSocket upgrade. diff --git a/server/src/main.rs b/server/src/main.rs index 20f7e11d..2aae0cd1 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -38,6 +38,8 @@ pub mod rebuild; mod service; /// Services — shared service bundle injected into HTTP handlers and bot tasks. pub mod services; +/// Sled uplink — background task that forwards permission requests to an upstream gateway. +pub mod sled_uplink; mod startup; mod state; mod store; @@ -233,6 +235,14 @@ async fn main() -> Result<(), std::io::Error> { status: agents.status_broadcaster(), }); + // Sled uplink: forward permission requests to an upstream gateway when configured. + let upstream_gateway = cli + .upstream_gateway + .clone() + .or_else(|| std::env::var("HUSKIES_UPSTREAM_GATEWAY").ok()) + .unwrap_or_default(); + sled_uplink::spawn_uplink_task(upstream_gateway, Arc::clone(&services)); + // ── Build bot contexts (WhatsApp / Slack / Discord) ─────────────────────── let (bot_ctxs, matrix_shutdown_rx) = startup::bots::build_bot_contexts(&startup_root, &services); diff --git a/server/src/service/gateway/config.rs b/server/src/service/gateway/config.rs index 24b88065..1d7866a7 100644 --- a/server/src/service/gateway/config.rs +++ b/server/src/service/gateway/config.rs @@ -19,6 +19,12 @@ pub struct GatewayConfig { /// Map of project name → container URL. #[serde(default)] pub projects: BTreeMap, + /// Map of sled_id → shared secret token for sled-uplink authentication. + /// + /// Each entry allows a sled identified by `sled_id` to connect to + /// `/api/sled-uplink` using the given secret token as a bearer credential. + #[serde(default)] + pub sled_tokens: BTreeMap, } /// Validate that a gateway config has at least one project. @@ -113,6 +119,7 @@ url = "http://localhost:3002" fn validate_config_rejects_empty() { let config = GatewayConfig { projects: BTreeMap::new(), + sled_tokens: BTreeMap::new(), }; assert!(validate_config(&config).is_err()); } @@ -132,7 +139,10 @@ url = "http://localhost:3002" url: "http://a".into(), }, ); - let config = GatewayConfig { projects }; + let config = GatewayConfig { + projects, + sled_tokens: BTreeMap::new(), + }; assert_eq!(validate_config(&config).unwrap(), "alpha"); } diff --git a/server/src/service/gateway/io.rs b/server/src/service/gateway/io.rs index 705f30df..85c6e9af 100644 --- a/server/src/service/gateway/io.rs +++ b/server/src/service/gateway/io.rs @@ -21,13 +21,23 @@ pub fn load_config(path: &Path) -> Result { /// Persist the current projects map to `/projects.toml`. /// Silently ignores write errors or skips when `config_dir` is empty. +/// +/// Existing `[sled_tokens]` entries are preserved so that adding or removing +/// projects via the UI does not wipe the sled authentication tokens. pub async fn save_config(projects: &BTreeMap, config_dir: &Path) { if config_dir.as_os_str().is_empty() { return; } let path = config_dir.join("projects.toml"); + let sled_tokens = tokio::fs::read_to_string(&path) + .await + .ok() + .and_then(|data| toml::from_str::(&data).ok()) + .map(|c| c.sled_tokens) + .unwrap_or_default(); let config = GatewayConfig { projects: projects.clone(), + sled_tokens, }; if let Ok(data) = toml::to_string_pretty(&config) { let _ = tokio::fs::write(&path, data).await; @@ -518,27 +528,20 @@ pub fn spawn_gateway_bot( gateway_project_urls: BTreeMap, port: u16, gateway_event_tx: Option>, + perm_rx: std::sync::Arc< + tokio::sync::Mutex< + tokio::sync::mpsc::UnboundedReceiver, + >, + >, ) -> ( Option, tokio::sync::watch::Sender>, ) { use crate::agents::AgentPool; use crate::services::Services; - use tokio::sync::{broadcast, mpsc}; - - let (watcher_tx, _) = broadcast::channel(16); - let (perm_tx, perm_rx) = mpsc::unbounded_channel(); - // Keep the sender alive for the gateway's lifetime so the matrix bot's - // `permission_listener` task doesn't exit immediately with - // "perm_rx channel closed". Previously `_perm_tx` was dropped when - // `spawn_gateway_bot` returned, closing the channel before the - // listener could even register. Story 898 (sled→gateway WS uplink) - // will eventually wire in a real sender; for now the leak keeps the - // channel open with no senders writing to it, matching the original - // intent of "listener watches forever, waiting for requests". - std::mem::forget(perm_tx); - let perm_rx = std::sync::Arc::new(tokio::sync::Mutex::new(perm_rx)); + use tokio::sync::broadcast; + let (watcher_tx, _) = broadcast::channel::(16); let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel::>(None); // shutdown_tx is intentionally NOT forgotten — the caller holds it and @@ -611,6 +614,9 @@ mod tests { let active = std::sync::Arc::new(tokio::sync::RwLock::new("proj".to_string())); let (event_tx, _) = tokio::sync::broadcast::channel(4); + let (_perm_tx, perm_rx) = + tokio::sync::mpsc::unbounded_channel::(); + let perm_rx = std::sync::Arc::new(tokio::sync::Mutex::new(perm_rx)); let (handle, shutdown_tx) = spawn_gateway_bot( tmp.path(), active, @@ -618,6 +624,7 @@ mod tests { std::collections::BTreeMap::new(), 3001, Some(event_tx), + perm_rx, ); // No bot.toml in tmp → no abort handle spawned. diff --git a/server/src/service/gateway/mod.rs b/server/src/service/gateway/mod.rs index e1f66156..631c9222 100644 --- a/server/src/service/gateway/mod.rs +++ b/server/src/service/gateway/mod.rs @@ -22,6 +22,7 @@ pub use io::{ spawn_gateway_notification_poller, }; +use crate::http::context::PermissionForward; use crate::rebuild::ShutdownReason; use io::Client; use std::collections::{BTreeMap, HashMap}; @@ -29,6 +30,7 @@ use std::path::PathBuf; use std::sync::Arc; use tokio::sync::Mutex as TokioMutex; use tokio::sync::RwLock; +use tokio::sync::mpsc; pub use crate::crdt_state::NodePresenceView; @@ -122,6 +124,22 @@ pub struct GatewayState { /// /// Call `event_tx.subscribe()` to obtain a receiver for outbound fan-out. pub event_tx: tokio::sync::broadcast::Sender, + /// Sender end of the gateway's permission channel. + /// + /// The sled-uplink handler uses this to inject `perm_request` messages + /// received from connected sleds into the gateway's Matrix bot permission + /// pipeline. + pub perm_tx: mpsc::UnboundedSender, + /// Receiver end of the gateway's permission channel (shared with the Matrix bot). + /// + /// The Matrix bot's `permission_listener` holds this locked for its lifetime; + /// the sled-uplink WS handler sends requests via `perm_tx`. + pub perm_rx: Arc>>, + /// Reversed sled-token map: token → sled_id. + /// + /// Built at startup from [`GatewayConfig::sled_tokens`] (which maps + /// sled_id → token). The handler looks up incoming tokens in O(1). + pub sled_tokens: HashMap, } impl GatewayState { @@ -141,6 +159,12 @@ impl GatewayState { .filter(|p| gateway_config.projects.contains_key(p)) .unwrap_or(first_from_config); let (event_tx, _) = tokio::sync::broadcast::channel(EVENT_CHANNEL_CAPACITY); + let (perm_tx, perm_rx) = mpsc::unbounded_channel::(); + let sled_tokens: HashMap = gateway_config + .sled_tokens + .iter() + .map(|(sled_id, token)| (token.clone(), sled_id.clone())) + .collect(); Ok(Self { projects: Arc::new(RwLock::new(gateway_config.projects)), active_project: Arc::new(RwLock::new(first)), @@ -151,6 +175,9 @@ impl GatewayState { bot_handle: Arc::new(TokioMutex::new(None)), bot_shutdown_tx: Arc::new(TokioMutex::new(None)), event_tx, + perm_tx, + perm_rx: Arc::new(TokioMutex::new(perm_rx)), + sled_tokens, }) } @@ -477,6 +504,7 @@ pub async fn save_bot_config_and_restart(state: &GatewayState, content: &str) -> gateway_project_urls, state.port, Some(state.event_tx.clone()), + Arc::clone(&state.perm_rx), ); *handle = new_handle; *state.bot_shutdown_tx.lock().await = Some(new_shutdown_tx); @@ -502,13 +530,17 @@ mod tests { }, ); } - GatewayConfig { projects } + GatewayConfig { + projects, + sled_tokens: BTreeMap::new(), + } } #[test] fn gateway_state_rejects_empty_config() { let config = GatewayConfig { projects: BTreeMap::new(), + sled_tokens: BTreeMap::new(), }; assert!(GatewayState::new(config, PathBuf::from("."), 3000).is_err()); } diff --git a/server/src/sled_uplink.rs b/server/src/sled_uplink.rs new file mode 100644 index 00000000..1629df65 --- /dev/null +++ b/server/src/sled_uplink.rs @@ -0,0 +1,543 @@ +//! Sled uplink — background task that maintains a WebSocket connection from a +//! sled (standard huskies instance) to an upstream gateway for permission +//! request forwarding. +//! +//! When `HUSKIES_UPSTREAM_GATEWAY` is set (or `--upstream-gateway` is passed +//! on the CLI), this module spawns a task that: +//! +//! 1. Acquires `services.perm_rx` for its lifetime (matching the Matrix bot's +//! `permission_listener` pattern), preventing `tool_prompt_permission` from +//! auto-denying requests with "no interactive session". +//! 2. Maintains a persistent WebSocket connection to the gateway's +//! `/api/sled-uplink` endpoint. +//! 3. Forwards each [`PermissionForward`] as a `perm_request` JSON envelope. +//! 4. Awaits the matching `perm_response` envelope from the gateway. +//! 5. Reconnects with exponential back-off on connection drop, fail-closing +//! any in-flight requests with [`PermissionDecision::Deny`]. + +use crate::http::context::{PermissionDecision, PermissionForward}; +use crate::services::Services; +use crate::slog; +use futures::SinkExt as _; +use futures::StreamExt as _; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::oneshot; +use tokio_tungstenite::tungstenite::Message as WsMessage; + +// ── Back-off constants ──────────────────────────────────────────────────────── + +const INITIAL_BACKOFF_SECS: u64 = 1; +const MAX_BACKOFF_SECS: u64 = 60; +const BACKOFF_MULTIPLIER: u64 = 2; + +// ── Wire protocol ───────────────────────────────────────────────────────────── + +/// Extensible JSON envelope for all sled↔gateway uplink messages. +/// +/// Phase 1 defines `perm_request` (sled→gateway) and `perm_response` +/// (gateway→sled). Future phases add new `type` values without changing this +/// framing. +#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] +pub struct UplinkEnvelope { + /// Message type discriminant (e.g. `"perm_request"`, `"perm_response"`). + #[serde(rename = "type")] + pub msg_type: String, + /// Correlation ID — the sled chooses this for each request; the gateway + /// echoes it back so the sled can demux concurrent responses. + pub req_id: String, + /// Message-specific payload. Varies by `msg_type`. + pub payload: serde_json::Value, +} + +// ── Public API ──────────────────────────────────────────────────────────────── + +/// Spawn the sled uplink background task. +/// +/// Does nothing when `upstream_url` is empty. When active, the task holds +/// `services.perm_rx` locked for its lifetime (preventing auto-deny in +/// `tool_prompt_permission`) and forwards all permission requests to the +/// gateway. Reconnects automatically with exponential back-off. +pub fn spawn_uplink_task(upstream_url: String, services: Arc) { + if upstream_url.is_empty() { + return; + } + slog!("[uplink] Spawning sled uplink task (gateway={upstream_url})"); + tokio::spawn(async move { + // Acquire perm_rx for this task's entire lifetime. While this lock is + // held, try_lock() inside tool_prompt_permission fails — meaning + // requests flow to perm_tx (which we drain here) rather than auto-deny. + let mut perm_rx = services.perm_rx.lock().await; + slog!("[uplink] Acquired perm_rx; maintaining gateway connection"); + + let mut backoff = INITIAL_BACKOFF_SECS; + loop { + match run_uplink_session(&upstream_url, &mut perm_rx).await { + Ok(()) => { + slog!("[uplink] Connection closed cleanly; reconnecting in {backoff}s"); + } + Err(ref e) => { + slog!("[uplink] Session error: {e}; reconnecting in {backoff}s"); + } + } + tokio::time::sleep(std::time::Duration::from_secs(backoff)).await; + backoff = backoff + .saturating_mul(BACKOFF_MULTIPLIER) + .min(MAX_BACKOFF_SECS); + } + }); +} + +// ── Private helpers ─────────────────────────────────────────────────────────── + +/// Run a single uplink session: connect, pump messages bidirectionally until +/// disconnect or channel close, then fail-close any in-flight requests. +async fn run_uplink_session( + url: &str, + perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver, +) -> Result<(), String> { + let (ws_stream, _) = tokio_tungstenite::connect_async(url) + .await + .map_err(|e| format!("WS connect to {url}: {e}"))?; + slog!("[uplink] Connected to gateway uplink endpoint"); + + let (mut ws_sink, mut ws_rx) = ws_stream.split(); + let mut in_flight: HashMap> = HashMap::new(); + + let result = pump_messages(&mut ws_sink, &mut ws_rx, perm_rx, &mut in_flight).await; + fail_close_all(&mut in_flight); + result +} + +/// Drive the bidirectional message loop for one session. +async fn pump_messages( + ws_sink: &mut (impl futures::Sink + Unpin), + ws_rx: &mut ( + impl futures::Stream> + + Unpin + ), + perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver, + in_flight: &mut HashMap>, +) -> Result<(), String> { + loop { + tokio::select! { + // New permission request from the MCP layer. + maybe_fwd = perm_rx.recv() => { + match maybe_fwd { + None => return Ok(()), // channel closed — server shutting down + Some(fwd) => { + let PermissionForward { + request_id, + tool_name, + tool_input, + response_tx, + } = fwd; + let env = UplinkEnvelope { + msg_type: "perm_request".to_string(), + req_id: request_id.clone(), + payload: serde_json::json!({ + "tool_name": tool_name, + "tool_input": tool_input, + }), + }; + let text = serde_json::to_string(&env) + .map_err(|e| format!("serialise perm_request: {e}"))?; + if ws_sink.send(WsMessage::Text(text.into())).await.is_err() { + // Connection dead: fail-close this request immediately. + let _ = response_tx.send(PermissionDecision::Deny); + return Err("WS send failed".to_string()); + } + in_flight.insert(request_id, response_tx); + } + } + } + + // Message arriving from the gateway. + msg = ws_rx.next() => { + match msg { + None | Some(Err(_)) => { + return Err("WS stream closed".to_string()); + } + Some(Ok(WsMessage::Close(_))) => { + return Err("Gateway sent Close frame".to_string()); + } + Some(Ok(WsMessage::Text(text))) => { + on_gateway_text(&text, in_flight); + } + Some(Ok(WsMessage::Ping(data))) => { + let _ = ws_sink.send(WsMessage::Pong(data)).await; + } + Some(Ok(_)) => {} + } + } + } + } +} + +/// Parse an incoming gateway text frame and resolve any matching in-flight request. +fn on_gateway_text( + text: &str, + in_flight: &mut HashMap>, +) { + let Ok(env) = serde_json::from_str::(text) else { + return; + }; + if env.msg_type == "perm_response" { + resolve_perm_response(env, in_flight); + } +} + +/// Map a `perm_response` envelope to a [`PermissionDecision`] and wake the +/// waiting MCP call. +fn resolve_perm_response( + env: UplinkEnvelope, + in_flight: &mut HashMap>, +) { + let Some(tx) = in_flight.remove(&env.req_id) else { + return; + }; + let approved = env + .payload + .get("approved") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + let always_allow = env + .payload + .get("always_allow") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + let decision = if always_allow { + PermissionDecision::AlwaysAllow + } else if approved { + PermissionDecision::Approve + } else { + PermissionDecision::Deny + }; + let _ = tx.send(decision); +} + +/// Deny all in-flight requests (fail-closed on connection drop — AC 8). +fn fail_close_all(in_flight: &mut HashMap>) { + for (_, tx) in in_flight.drain() { + let _ = tx.send(PermissionDecision::Deny); + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::http::context::PermissionForward; + use crate::services::Services; + use std::collections::HashMap; + use tokio::net::TcpListener; + use tokio::sync::oneshot; + use tokio_tungstenite::tungstenite::Message as WsMessage; + + // ── Pure unit tests ─────────────────────────────────────────────── + + #[test] + fn uplink_envelope_roundtrips_json() { + let env = UplinkEnvelope { + msg_type: "perm_request".to_string(), + req_id: "req-1".to_string(), + payload: serde_json::json!({"tool_name": "Bash", "tool_input": {}}), + }; + let text = serde_json::to_string(&env).unwrap(); + let back: UplinkEnvelope = serde_json::from_str(&text).unwrap(); + assert_eq!(back.msg_type, "perm_request"); + assert_eq!(back.req_id, "req-1"); + assert_eq!(back.payload["tool_name"], "Bash"); + } + + #[test] + fn resolve_perm_response_approve() { + let (tx, rx) = oneshot::channel(); + let mut in_flight = HashMap::new(); + in_flight.insert("r1".to_string(), tx); + let env = UplinkEnvelope { + msg_type: "perm_response".to_string(), + req_id: "r1".to_string(), + payload: serde_json::json!({"approved": true, "always_allow": false}), + }; + resolve_perm_response(env, &mut in_flight); + assert!(in_flight.is_empty()); + assert_eq!(rx.blocking_recv().unwrap(), PermissionDecision::Approve); + } + + #[test] + fn resolve_perm_response_deny() { + let (tx, rx) = oneshot::channel(); + let mut in_flight = HashMap::new(); + in_flight.insert("r2".to_string(), tx); + let env = UplinkEnvelope { + msg_type: "perm_response".to_string(), + req_id: "r2".to_string(), + payload: serde_json::json!({"approved": false}), + }; + resolve_perm_response(env, &mut in_flight); + assert_eq!(rx.blocking_recv().unwrap(), PermissionDecision::Deny); + } + + #[test] + fn resolve_perm_response_always_allow() { + let (tx, rx) = oneshot::channel(); + let mut in_flight = HashMap::new(); + in_flight.insert("r3".to_string(), tx); + let env = UplinkEnvelope { + msg_type: "perm_response".to_string(), + req_id: "r3".to_string(), + payload: serde_json::json!({"approved": true, "always_allow": true}), + }; + resolve_perm_response(env, &mut in_flight); + assert_eq!(rx.blocking_recv().unwrap(), PermissionDecision::AlwaysAllow); + } + + #[test] + fn resolve_perm_response_unknown_req_id_is_noop() { + let mut in_flight: HashMap> = HashMap::new(); + let env = UplinkEnvelope { + msg_type: "perm_response".to_string(), + req_id: "missing".to_string(), + payload: serde_json::json!({"approved": true}), + }; + resolve_perm_response(env, &mut in_flight); // must not panic + } + + #[test] + fn fail_close_all_denies_all_pending() { + let (tx1, rx1) = oneshot::channel(); + let (tx2, rx2) = oneshot::channel(); + let mut in_flight = HashMap::new(); + in_flight.insert("a".to_string(), tx1); + in_flight.insert("b".to_string(), tx2); + fail_close_all(&mut in_flight); + assert!(in_flight.is_empty()); + assert_eq!(rx1.blocking_recv().unwrap(), PermissionDecision::Deny); + assert_eq!(rx2.blocking_recv().unwrap(), PermissionDecision::Deny); + } + + #[test] + fn on_gateway_text_ignores_unknown_type() { + let mut in_flight: HashMap> = HashMap::new(); + on_gateway_text( + r#"{"type":"future_type","req_id":"x","payload":{}}"#, + &mut in_flight, + ); + assert!(in_flight.is_empty()); + } + + #[test] + fn on_gateway_text_ignores_invalid_json() { + let mut in_flight: HashMap> = HashMap::new(); + on_gateway_text("not-json", &mut in_flight); // must not panic + assert!(in_flight.is_empty()); + } + + #[test] + fn spawn_uplink_task_noop_when_url_empty() { + let (_perm_tx, perm_rx) = tokio::sync::mpsc::unbounded_channel(); + let agents = Arc::new(crate::agents::AgentPool::new_test(3000)); + let services = Arc::new(Services { + project_root: std::path::PathBuf::from("/tmp"), + status: agents.status_broadcaster(), + agents, + bot_name: "Test".to_string(), + bot_user_id: String::new(), + ambient_rooms: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())), + perm_rx: Arc::new(tokio::sync::Mutex::new(perm_rx)), + pending_perm_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())), + permission_timeout_secs: 120, + }); + // Empty URL → noop; if it panicked or blocked the test would fail. + spawn_uplink_task(String::new(), services); + } + + // ── AC 11: permission approved via uplink ──────────────────────── + // "Simulate matrix bot triggering a Bash permission, sled forwards via + // uplink, mock matrix transport approves, tool call proceeds." + + #[tokio::test] + async fn integration_perm_request_approved_via_uplink() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let url = format!("ws://127.0.0.1:{port}"); + + // Mock gateway: accept one connection, receive perm_request, reply approved. + let gw_task = tokio::spawn(async move { + let (tcp, _) = listener.accept().await.unwrap(); + let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap(); + let msg = ws.next().await.unwrap().unwrap(); + let text = match msg { + WsMessage::Text(t) => t.to_string(), + other => panic!("expected Text; got {other:?}"), + }; + let env: UplinkEnvelope = serde_json::from_str(&text).unwrap(); + assert_eq!(env.msg_type, "perm_request"); + assert_eq!(env.payload["tool_name"], "Bash"); + let resp = UplinkEnvelope { + msg_type: "perm_response".to_string(), + req_id: env.req_id, + payload: serde_json::json!({"approved": true, "always_allow": false}), + }; + ws.send(WsMessage::Text( + serde_json::to_string(&resp).unwrap().into(), + )) + .await + .unwrap(); + }); + + let (perm_tx, perm_rx) = tokio::sync::mpsc::unbounded_channel::(); + let agents = Arc::new(crate::agents::AgentPool::new_test(3000)); + let services = Arc::new(Services { + project_root: std::path::PathBuf::from("/tmp"), + status: agents.status_broadcaster(), + agents, + bot_name: "Test".to_string(), + bot_user_id: String::new(), + ambient_rooms: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())), + perm_rx: Arc::new(tokio::sync::Mutex::new(perm_rx)), + pending_perm_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())), + permission_timeout_secs: 120, + }); + + spawn_uplink_task(url, Arc::clone(&services)); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let (response_tx, response_rx) = oneshot::channel(); + perm_tx + .send(PermissionForward { + request_id: "req-test-1".to_string(), + tool_name: "Bash".to_string(), + tool_input: serde_json::json!({"command": "echo hello"}), + response_tx, + }) + .unwrap(); + + let decision = tokio::time::timeout(std::time::Duration::from_secs(5), response_rx) + .await + .expect("timeout waiting for decision") + .expect("oneshot dropped"); + + assert_eq!(decision, PermissionDecision::Approve); + gw_task.await.unwrap(); + } + + // ── AC 12: sled disconnects and reconnects ──────────────────────── + // "Sled disconnects and reconnects mid-session; subsequent permission + // requests succeed once reconnected." + + #[tokio::test] + async fn integration_reconnects_after_disconnect() { + use std::sync::Arc as StdArc; + use std::sync::atomic::{AtomicU32, Ordering}; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let url = format!("ws://127.0.0.1:{port}"); + + let conn_count = StdArc::new(AtomicU32::new(0)); + let conn_count2 = StdArc::clone(&conn_count); + + tokio::spawn(async move { + // First connection: receive the request then immediately drop (simulates network failure). + { + let (tcp, _) = listener.accept().await.unwrap(); + let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap(); + conn_count2.fetch_add(1, Ordering::SeqCst); + let _ = ws.next().await; // consume one frame + drop(ws); // close without replying → fail-close on sled side + } + // Second connection: approve the next request. + { + let (tcp, _) = listener.accept().await.unwrap(); + let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap(); + conn_count2.fetch_add(1, Ordering::SeqCst); + if let Some(Ok(WsMessage::Text(text))) = ws.next().await { + let env: UplinkEnvelope = serde_json::from_str(&text).unwrap(); + if env.msg_type == "perm_request" { + let resp = UplinkEnvelope { + msg_type: "perm_response".to_string(), + req_id: env.req_id, + payload: serde_json::json!({"approved": true, "always_allow": false}), + }; + let _ = ws + .send(WsMessage::Text( + serde_json::to_string(&resp).unwrap().into(), + )) + .await; + } + } + } + }); + + let (perm_tx, perm_rx) = tokio::sync::mpsc::unbounded_channel::(); + let agents = Arc::new(crate::agents::AgentPool::new_test(3000)); + let services = Arc::new(Services { + project_root: std::path::PathBuf::from("/tmp"), + status: agents.status_broadcaster(), + agents, + bot_name: "Test".to_string(), + bot_user_id: String::new(), + ambient_rooms: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())), + perm_rx: Arc::new(tokio::sync::Mutex::new(perm_rx)), + pending_perm_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())), + permission_timeout_secs: 120, + }); + + spawn_uplink_task(url, Arc::clone(&services)); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + // First request is sent on the connection that drops → denied. + let (tx1, rx1) = oneshot::channel(); + perm_tx + .send(PermissionForward { + request_id: "req-drop".to_string(), + tool_name: "Bash".to_string(), + tool_input: serde_json::json!({}), + response_tx: tx1, + }) + .unwrap(); + + let d1 = tokio::time::timeout(std::time::Duration::from_secs(5), rx1) + .await + .expect("timeout on first request") + .expect("oneshot dropped"); + assert_eq!( + d1, + PermissionDecision::Deny, + "dropped connection must fail-close" + ); + + // Wait for the 1-second reconnect backoff plus buffer. + tokio::time::sleep(std::time::Duration::from_millis(1500)).await; + + // Second request arrives on the reconnected session → approved. + let (tx2, rx2) = oneshot::channel(); + perm_tx + .send(PermissionForward { + request_id: "req-reconnect".to_string(), + tool_name: "Write".to_string(), + tool_input: serde_json::json!({}), + response_tx: tx2, + }) + .unwrap(); + + let d2 = tokio::time::timeout(std::time::Duration::from_secs(5), rx2) + .await + .expect("timeout on second request") + .expect("oneshot dropped"); + assert_eq!( + d2, + PermissionDecision::Approve, + "reconnected session must approve" + ); + + assert_eq!( + conn_count.load(Ordering::SeqCst), + 2, + "must have seen exactly 2 connections" + ); + } +}