Files
huskies/server/src/sled_uplink.rs
T

544 lines
22 KiB
Rust
Raw Normal View History

2026-05-12 21:29:04 +00:00
//! Sled uplink — background task that maintains a WebSocket connection from a
//! sled (standard huskies instance) to an upstream gateway for permission
//! request forwarding.
//!
//! When `HUSKIES_UPSTREAM_GATEWAY` is set (or `--upstream-gateway` is passed
//! on the CLI), this module spawns a task that:
//!
//! 1. Acquires `services.perm_rx` for its lifetime (matching the Matrix bot's
//! `permission_listener` pattern), preventing `tool_prompt_permission` from
//! auto-denying requests with "no interactive session".
//! 2. Maintains a persistent WebSocket connection to the gateway's
//! `/api/sled-uplink` endpoint.
//! 3. Forwards each [`PermissionForward`] as a `perm_request` JSON envelope.
//! 4. Awaits the matching `perm_response` envelope from the gateway.
//! 5. Reconnects with exponential back-off on connection drop, fail-closing
//! any in-flight requests with [`PermissionDecision::Deny`].
use crate::http::context::{PermissionDecision, PermissionForward};
use crate::services::Services;
use crate::slog;
use futures::SinkExt as _;
use futures::StreamExt as _;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::oneshot;
use tokio_tungstenite::tungstenite::Message as WsMessage;
// ── Back-off constants ────────────────────────────────────────────────────────
/// Delay, in seconds, slept before the first reconnect attempt.
const INITIAL_BACKOFF_SECS: u64 = 1;
/// Upper bound, in seconds, that the reconnect delay is clamped to.
const MAX_BACKOFF_SECS: u64 = 60;
/// Growth factor applied to the reconnect delay between successive attempts.
const BACKOFF_MULTIPLIER: u64 = 2;
// ── Wire protocol ─────────────────────────────────────────────────────────────
/// Extensible JSON envelope for all sled↔gateway uplink messages.
///
/// Phase 1 defines `perm_request` (sled→gateway) and `perm_response`
/// (gateway→sled). Future phases add new `type` values without changing this
/// framing.
///
/// On the wire an envelope is a JSON object with the keys `type`, `req_id`
/// and `payload`.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct UplinkEnvelope {
/// Message type discriminant (e.g. `"perm_request"`, `"perm_response"`).
/// Serialised under the JSON key `type` rather than the field name.
#[serde(rename = "type")]
pub msg_type: String,
/// Correlation ID — the sled chooses this for each request; the gateway
/// echoes it back so the sled can demux concurrent responses.
pub req_id: String,
/// Message-specific payload. Varies by `msg_type`: this module sends
/// `tool_name` and `tool_input` in a `perm_request`, and reads the boolean
/// fields `approved` and `always_allow` from a `perm_response`.
pub payload: serde_json::Value,
}
// ── Public API ────────────────────────────────────────────────────────────────
/// Spawn the sled uplink background task.
///
/// Does nothing when `upstream_url` is empty. When active, the task holds
/// `services.perm_rx` locked for its lifetime (preventing auto-deny in
/// `tool_prompt_permission`) and forwards all permission requests to the
/// gateway.
///
/// The task reconnects automatically with exponential back-off after every
/// failed or dropped session. The delay starts at [`INITIAL_BACKOFF_SECS`],
/// is multiplied by [`BACKOFF_MULTIPLIER`] after each failure, is capped at
/// [`MAX_BACKOFF_SECS`], and is reset to the initial value once a session has
/// lasted at least [`MAX_BACKOFF_SECS`].
///
/// The task ends when the permission channel is closed, i.e. when every
/// sender feeding `services.perm_rx` has been dropped.
pub fn spawn_uplink_task(upstream_url: String, services: Arc<Services>) {
    if upstream_url.is_empty() {
        return;
    }
    slog!("[uplink] Spawning sled uplink task (gateway={upstream_url})");
    tokio::spawn(async move {
        // Acquire perm_rx for this task's entire lifetime. While this lock is
        // held, try_lock() inside tool_prompt_permission fails — meaning
        // requests flow to perm_tx (which we drain here) rather than auto-deny.
        let mut perm_rx = services.perm_rx.lock().await;
        slog!("[uplink] Acquired perm_rx; maintaining gateway connection");
        let mut backoff = INITIAL_BACKOFF_SECS;
        loop {
            let session_started = std::time::Instant::now();
            let outcome = run_uplink_session(&upstream_url, &mut perm_rx).await;

            // A session that outlived the longest possible delay was healthy,
            // so the next drop should be retried quickly instead of inheriting
            // the delay accumulated by failures that happened long ago.
            if session_started.elapsed() >= std::time::Duration::from_secs(MAX_BACKOFF_SECS) {
                backoff = INITIAL_BACKOFF_SECS;
            }

            match outcome {
                // `run_uplink_session` only returns `Ok(())` when `perm_rx`
                // yielded `None`: every sender is gone and no request can ever
                // arrive again, so reconnecting would just churn connections.
                Ok(()) => {
                    slog!("[uplink] Permission channel closed; uplink task exiting");
                    break;
                }
                Err(e) => {
                    slog!("[uplink] Session error: {e}; reconnecting in {backoff}s");
                }
            }

            tokio::time::sleep(std::time::Duration::from_secs(backoff)).await;
            backoff = backoff
                .saturating_mul(BACKOFF_MULTIPLIER)
                .min(MAX_BACKOFF_SECS);
        }
    });
}
// ── Private helpers ───────────────────────────────────────────────────────────
/// Execute exactly one uplink session against `url`.
///
/// Opens the WebSocket, splits it into a sink and a stream, and hands both to
/// [`pump_messages`] until that returns. Whatever the outcome, every request
/// still waiting for a gateway answer is then denied via [`fail_close_all`].
///
/// # Errors
///
/// Returns a description of the failure when the connection cannot be
/// established, or whatever error [`pump_messages`] reported.
async fn run_uplink_session(
    url: &str,
    perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver<PermissionForward>,
) -> Result<(), String> {
    let ws_stream = match tokio_tungstenite::connect_async(url).await {
        Ok((stream, _handshake_response)) => stream,
        Err(e) => return Err(format!("WS connect to {url}: {e}")),
    };
    slog!("[uplink] Connected to gateway uplink endpoint");

    let (mut ws_sink, mut ws_rx) = ws_stream.split();
    // Reply channels of requests sent to the gateway but not yet answered.
    let mut pending = HashMap::<String, oneshot::Sender<PermissionDecision>>::new();

    let outcome = pump_messages(&mut ws_sink, &mut ws_rx, perm_rx, &mut pending).await;

    // Runs on success and failure alike so no caller is left waiting.
    fail_close_all(&mut pending);
    outcome
}
/// Drive the bidirectional message loop for one session.
///
/// Each iteration waits for whichever comes first:
///
/// * a [`PermissionForward`] on `perm_rx`, which is serialised as a
///   `perm_request` envelope, written to `ws_sink`, and then recorded in
///   `in_flight` under its request ID; or
/// * a frame on `ws_rx`: text frames go to [`on_gateway_text`], pings are
///   answered with a pong, and any other non-close frame is ignored.
///
/// Returns `Ok(())` only when `perm_rx` is closed. Entries left in
/// `in_flight` on return are NOT resolved here; the caller is expected to
/// fail-close them.
///
/// # Errors
///
/// Returns a description of the failure when a request cannot be serialised
/// or sent, when the stream ends or yields a read error, or when the gateway
/// sends a Close frame. A request that could not be serialised or sent is
/// denied before returning.
async fn pump_messages(
    ws_sink: &mut (impl futures::Sink<WsMessage, Error = tokio_tungstenite::tungstenite::Error> + Unpin),
    ws_rx: &mut (
             impl futures::Stream<Item = Result<WsMessage, tokio_tungstenite::tungstenite::Error>>
             + Unpin
         ),
    perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver<PermissionForward>,
    in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>,
) -> Result<(), String> {
    loop {
        tokio::select! {
            // New permission request from the MCP layer.
            maybe_fwd = perm_rx.recv() => {
                let Some(fwd) = maybe_fwd else {
                    return Ok(()); // channel closed — server shutting down
                };
                let PermissionForward {
                    request_id,
                    tool_name,
                    tool_input,
                    response_tx,
                } = fwd;
                let env = UplinkEnvelope {
                    msg_type: "perm_request".to_string(),
                    req_id: request_id.clone(),
                    payload: serde_json::json!({
                        "tool_name": tool_name,
                        "tool_input": tool_input,
                    }),
                };
                let text = match serde_json::to_string(&env) {
                    Ok(text) => text,
                    Err(e) => {
                        // Deny explicitly rather than just dropping the sender.
                        let _ = response_tx.send(PermissionDecision::Deny);
                        return Err(format!("serialise perm_request: {e}"));
                    }
                };
                if let Err(e) = ws_sink.send(WsMessage::Text(text.into())).await {
                    // Connection dead: fail-close this request immediately.
                    let _ = response_tx.send(PermissionDecision::Deny);
                    return Err(format!("WS send failed: {e}"));
                }
                // Forget requests whose waiter has already gone away (its
                // receiver was dropped) so the map cannot grow without bound
                // on a long-lived session with an unresponsive gateway.
                in_flight.retain(|_, tx| !tx.is_closed());
                if let Some(displaced) = in_flight.insert(request_id, response_tx) {
                    // Same request ID registered twice: the earlier waiter can
                    // no longer be matched to a response, so deny it.
                    let _ = displaced.send(PermissionDecision::Deny);
                }
            }
            // Message arriving from the gateway.
            msg = ws_rx.next() => {
                match msg {
                    None => {
                        return Err("WS stream closed".to_string());
                    }
                    Some(Err(e)) => {
                        return Err(format!("WS stream closed: {e}"));
                    }
                    Some(Ok(WsMessage::Close(_))) => {
                        return Err("Gateway sent Close frame".to_string());
                    }
                    Some(Ok(WsMessage::Text(text))) => {
                        on_gateway_text(&text, in_flight);
                    }
                    Some(Ok(WsMessage::Ping(data))) => {
                        // Best effort: if the pong cannot be written the next
                        // read reports the broken connection.
                        let _ = ws_sink.send(WsMessage::Pong(data)).await;
                    }
                    // Other frame kinds carry nothing this protocol uses.
                    Some(Ok(_)) => {}
                }
            }
        }
    }
}
/// Handle one text frame received from the gateway.
///
/// The frame is decoded as an [`UplinkEnvelope`]. A `perm_response` envelope
/// is passed to [`resolve_perm_response`]; frames that fail to decode, and
/// envelopes of any other type, are dropped without side effects.
fn on_gateway_text(
    text: &str,
    in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>,
) {
    match serde_json::from_str::<UplinkEnvelope>(text) {
        Ok(envelope) if envelope.msg_type == "perm_response" => {
            resolve_perm_response(envelope, in_flight);
        }
        // Undecodable frame, or a message type this sled does not handle.
        Ok(_) | Err(_) => {}
    }
}
/// Map a `perm_response` envelope to a [`PermissionDecision`] and wake the
/// waiting MCP call.
///
/// The entry for `env.req_id` is removed from `in_flight`; when there is no
/// such entry the envelope is ignored. The decision is derived from the
/// boolean payload fields `approved` and `always_allow` (a missing or
/// non-boolean field counts as absent):
///
/// | `approved`      | `always_allow` | decision      |
/// |-----------------|----------------|---------------|
/// | `false`         | anything       | `Deny`        |
/// | `true` / absent | `true`         | `AlwaysAllow` |
/// | `true`          | `false`/absent | `Approve`     |
/// | absent          | `false`/absent | `Deny`        |
///
/// An explicit `approved: false` therefore can never grant access, even when
/// a contradictory `always_allow: true` accompanies it.
fn resolve_perm_response(
    env: UplinkEnvelope,
    in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>,
) {
    let Some(tx) = in_flight.remove(&env.req_id) else {
        return;
    };
    // Reads a payload field as a bool; `None` when missing or not a boolean.
    let flag = |key: &str| env.payload.get(key).and_then(|v| v.as_bool());
    let decision = match (flag("approved"), flag("always_allow")) {
        // Fail closed: an explicit denial wins over every other field.
        (Some(false), _) => PermissionDecision::Deny,
        (_, Some(true)) => PermissionDecision::AlwaysAllow,
        (Some(true), _) => PermissionDecision::Approve,
        (None, _) => PermissionDecision::Deny,
    };
    // The waiter may have given up already; nothing to do in that case.
    let _ = tx.send(decision);
}
/// Empty `in_flight`, answering every pending request with
/// [`PermissionDecision::Deny`] (fail-closed on connection drop — AC 8).
fn fail_close_all(in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>) {
    in_flight.drain().for_each(|(_req_id, reply_tx)| {
        // A failed send only means the waiter is already gone.
        let _ = reply_tx.send(PermissionDecision::Deny);
    });
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::context::PermissionForward;
    use crate::services::Services;
    use std::collections::HashMap;
    use tokio::net::TcpListener;
    use tokio::sync::oneshot;
    use tokio_tungstenite::tungstenite::Message as WsMessage;

    // ── Shared fixtures ───────────────────────────────────────────────

    /// Build a minimal [`Services`] whose `perm_rx` wraps the given receiver.
    fn test_services(
        perm_rx: tokio::sync::mpsc::UnboundedReceiver<PermissionForward>,
    ) -> Arc<Services> {
        let agents = Arc::new(crate::agents::AgentPool::new_test(3000));
        Arc::new(Services {
            project_root: std::path::PathBuf::from("/tmp"),
            status: agents.status_broadcaster(),
            agents,
            bot_name: "Test".to_string(),
            bot_user_id: String::new(),
            ambient_rooms: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
            perm_rx: Arc::new(tokio::sync::Mutex::new(perm_rx)),
            pending_perm_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
            permission_timeout_secs: 120,
        })
    }

    /// Build a `perm_response` envelope for `req_id` carrying `payload`.
    fn perm_response(req_id: &str, payload: serde_json::Value) -> UplinkEnvelope {
        UplinkEnvelope {
            msg_type: "perm_response".to_string(),
            req_id: req_id.to_string(),
            payload,
        }
    }

    /// Build the text frame a mock gateway sends to approve `req_id` once.
    fn approval_frame(req_id: &str) -> WsMessage {
        let resp = perm_response(
            req_id,
            serde_json::json!({"approved": true, "always_allow": false}),
        );
        WsMessage::Text(serde_json::to_string(&resp).unwrap().into())
    }

    /// Queue a permission request on `perm_tx` and return the receiver on
    /// which its decision will arrive.
    fn forward_request(
        perm_tx: &tokio::sync::mpsc::UnboundedSender<PermissionForward>,
        request_id: &str,
        tool_name: &str,
        tool_input: serde_json::Value,
    ) -> oneshot::Receiver<PermissionDecision> {
        let (response_tx, response_rx) = oneshot::channel();
        perm_tx
            .send(PermissionForward {
                request_id: request_id.to_string(),
                tool_name: tool_name.to_string(),
                tool_input,
                response_tx,
            })
            .unwrap();
        response_rx
    }

    // ── Pure unit tests ───────────────────────────────────────────────

    #[test]
    fn uplink_envelope_roundtrips_json() {
        let env = UplinkEnvelope {
            msg_type: "perm_request".to_string(),
            req_id: "req-1".to_string(),
            payload: serde_json::json!({"tool_name": "Bash", "tool_input": {}}),
        };
        let text = serde_json::to_string(&env).unwrap();
        let back: UplinkEnvelope = serde_json::from_str(&text).unwrap();
        assert_eq!(back.msg_type, "perm_request");
        assert_eq!(back.req_id, "req-1");
        assert_eq!(back.payload["tool_name"], "Bash");
    }

    #[test]
    fn resolve_perm_response_approve() {
        let (tx, rx) = oneshot::channel();
        let mut in_flight = HashMap::new();
        in_flight.insert("r1".to_string(), tx);
        let env = perm_response(
            "r1",
            serde_json::json!({"approved": true, "always_allow": false}),
        );
        resolve_perm_response(env, &mut in_flight);
        assert!(in_flight.is_empty());
        assert_eq!(rx.blocking_recv().unwrap(), PermissionDecision::Approve);
    }

    #[test]
    fn resolve_perm_response_deny() {
        let (tx, rx) = oneshot::channel();
        let mut in_flight = HashMap::new();
        in_flight.insert("r2".to_string(), tx);
        let env = perm_response("r2", serde_json::json!({"approved": false}));
        resolve_perm_response(env, &mut in_flight);
        assert_eq!(rx.blocking_recv().unwrap(), PermissionDecision::Deny);
    }

    #[test]
    fn resolve_perm_response_always_allow() {
        let (tx, rx) = oneshot::channel();
        let mut in_flight = HashMap::new();
        in_flight.insert("r3".to_string(), tx);
        let env = perm_response(
            "r3",
            serde_json::json!({"approved": true, "always_allow": true}),
        );
        resolve_perm_response(env, &mut in_flight);
        assert_eq!(rx.blocking_recv().unwrap(), PermissionDecision::AlwaysAllow);
    }

    #[test]
    fn resolve_perm_response_unknown_req_id_is_noop() {
        let mut in_flight: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();
        let env = perm_response("missing", serde_json::json!({"approved": true}));
        resolve_perm_response(env, &mut in_flight); // must not panic
    }

    #[test]
    fn fail_close_all_denies_all_pending() {
        let (tx1, rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();
        let mut in_flight = HashMap::new();
        in_flight.insert("a".to_string(), tx1);
        in_flight.insert("b".to_string(), tx2);
        fail_close_all(&mut in_flight);
        assert!(in_flight.is_empty());
        assert_eq!(rx1.blocking_recv().unwrap(), PermissionDecision::Deny);
        assert_eq!(rx2.blocking_recv().unwrap(), PermissionDecision::Deny);
    }

    #[test]
    fn on_gateway_text_ignores_unknown_type() {
        let mut in_flight: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();
        on_gateway_text(
            r#"{"type":"future_type","req_id":"x","payload":{}}"#,
            &mut in_flight,
        );
        assert!(in_flight.is_empty());
    }

    #[test]
    fn on_gateway_text_ignores_invalid_json() {
        let mut in_flight: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();
        on_gateway_text("not-json", &mut in_flight); // must not panic
        assert!(in_flight.is_empty());
    }

    #[test]
    fn spawn_uplink_task_noop_when_url_empty() {
        let (_perm_tx, perm_rx) = tokio::sync::mpsc::unbounded_channel();
        let services = test_services(perm_rx);
        // Empty URL → noop; if it panicked or blocked the test would fail.
        spawn_uplink_task(String::new(), services);
    }

    // ── AC 11: permission approved via uplink ────────────────────────
    // "Simulate matrix bot triggering a Bash permission, sled forwards via
    // uplink, mock matrix transport approves, tool call proceeds."
    #[tokio::test]
    async fn integration_perm_request_approved_via_uplink() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let url = format!("ws://127.0.0.1:{port}");

        // Mock gateway: accept one connection, receive perm_request, reply approved.
        let gw_task = tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
            let msg = ws.next().await.unwrap().unwrap();
            let text = match msg {
                WsMessage::Text(t) => t.to_string(),
                other => panic!("expected Text; got {other:?}"),
            };
            let env: UplinkEnvelope = serde_json::from_str(&text).unwrap();
            assert_eq!(env.msg_type, "perm_request");
            assert_eq!(env.payload["tool_name"], "Bash");
            ws.send(approval_frame(&env.req_id)).await.unwrap();
        });

        let (perm_tx, perm_rx) = tokio::sync::mpsc::unbounded_channel::<PermissionForward>();
        let services = test_services(perm_rx);
        spawn_uplink_task(url, Arc::clone(&services));
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        let response_rx = forward_request(
            &perm_tx,
            "req-test-1",
            "Bash",
            serde_json::json!({"command": "echo hello"}),
        );
        let decision = tokio::time::timeout(std::time::Duration::from_secs(5), response_rx)
            .await
            .expect("timeout waiting for decision")
            .expect("oneshot dropped");
        assert_eq!(decision, PermissionDecision::Approve);
        gw_task.await.unwrap();
    }

    // ── AC 12: sled disconnects and reconnects ────────────────────────
    // "Sled disconnects and reconnects mid-session; subsequent permission
    // requests succeed once reconnected."
    #[tokio::test]
    async fn integration_reconnects_after_disconnect() {
        use std::sync::Arc as StdArc;
        use std::sync::atomic::{AtomicU32, Ordering};

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let url = format!("ws://127.0.0.1:{port}");
        let conn_count = StdArc::new(AtomicU32::new(0));
        let conn_count2 = StdArc::clone(&conn_count);

        tokio::spawn(async move {
            // First connection: receive the request then immediately drop (simulates network failure).
            {
                let (tcp, _) = listener.accept().await.unwrap();
                let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
                conn_count2.fetch_add(1, Ordering::SeqCst);
                let _ = ws.next().await; // consume one frame
                drop(ws); // close without replying → fail-close on sled side
            }
            // Second connection: approve the next request.
            {
                let (tcp, _) = listener.accept().await.unwrap();
                let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
                conn_count2.fetch_add(1, Ordering::SeqCst);
                if let Some(Ok(WsMessage::Text(text))) = ws.next().await {
                    let env: UplinkEnvelope = serde_json::from_str(&text).unwrap();
                    if env.msg_type == "perm_request" {
                        let _ = ws.send(approval_frame(&env.req_id)).await;
                    }
                }
            }
        });

        let (perm_tx, perm_rx) = tokio::sync::mpsc::unbounded_channel::<PermissionForward>();
        let services = test_services(perm_rx);
        spawn_uplink_task(url, Arc::clone(&services));
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // First request is sent on the connection that drops → denied.
        let rx1 = forward_request(&perm_tx, "req-drop", "Bash", serde_json::json!({}));
        let d1 = tokio::time::timeout(std::time::Duration::from_secs(5), rx1)
            .await
            .expect("timeout on first request")
            .expect("oneshot dropped");
        assert_eq!(
            d1,
            PermissionDecision::Deny,
            "dropped connection must fail-close"
        );

        // Wait for the 1-second reconnect backoff plus buffer.
        tokio::time::sleep(std::time::Duration::from_millis(1500)).await;

        // Second request arrives on the reconnected session → approved.
        let rx2 = forward_request(&perm_tx, "req-reconnect", "Write", serde_json::json!({}));
        let d2 = tokio::time::timeout(std::time::Duration::from_secs(5), rx2)
            .await
            .expect("timeout on second request")
            .expect("oneshot dropped");
        assert_eq!(
            d2,
            PermissionDecision::Approve,
            "reconnected session must approve"
        );
        assert_eq!(
            conn_count.load(Ordering::SeqCst),
            2,
            "must have seen exactly 2 connections"
        );
    }
}