huskies: merge 899

This commit is contained in:
dave
2026-05-12 23:11:34 +00:00
parent 0f0cf59329
commit cd214d7246
9 changed files with 1105 additions and 218 deletions
+250 -27
View File
@@ -1,6 +1,5 @@
//! Sled uplink — background task that maintains a WebSocket connection from a
//! sled (standard huskies instance) to an upstream gateway for permission
//! request forwarding.
//! sled (standard huskies instance) to an upstream gateway.
//!
//! When `HUSKIES_UPSTREAM_GATEWAY` is set (or `--upstream-gateway` is passed
//! on the CLI), this module spawns a task that:
@@ -10,10 +9,17 @@
//! auto-denying requests with "no interactive session".
//! 2. Maintains a persistent WebSocket connection to the gateway's
//! `/api/sled-uplink` endpoint.
//! 3. Forwards each [`PermissionForward`] as a `perm_request` JSON envelope.
//! 4. Awaits the matching `perm_response` envelope from the gateway.
//! 5. Reconnects with exponential back-off on connection drop, fail-closing
//! any in-flight requests with [`PermissionDecision::Deny`].
//! 3. Sends an `identity` frame announcing the sled's project name
//! immediately after connect (story 899 Phase 2). The auth token is not
//! part of that frame; it travels in the upstream URL's `?token=` query
//! parameter.
//! 4. Sends a `heartbeat` frame every [`HEARTBEAT_INTERVAL_SECS`] so the
//! gateway can mark the connection live without extra HTTP polls.
//! 5. Forwards each [`PermissionForward`] as a `perm_request` JSON envelope
//! and awaits the matching `perm_response`.
//! 6. Handles inbound `mcp_request` frames by replaying the MCP JSON-RPC
//! body against the sled's own local `/mcp` HTTP endpoint and returning
//! the response as an `mcp_response` frame.
//! 7. Reconnects with exponential back-off on connection drop, fail-closing
//! any in-flight permission requests with [`PermissionDecision::Deny`].
use crate::http::context::{PermissionDecision, PermissionForward};
use crate::services::Services;
@@ -25,12 +31,35 @@ use std::sync::Arc;
use tokio::sync::oneshot;
use tokio_tungstenite::tungstenite::Message as WsMessage;
/// Configuration for [`spawn_uplink_task`].
///
/// Bundled into a single struct so `main.rs` can build it once and pass it in
/// without proliferating positional arguments. [`spawn_uplink_task`] takes the
/// value by ownership and destructures it into its three fields.
pub struct UplinkConfig {
/// WebSocket URL of the upstream gateway's `/api/sled-uplink` endpoint.
/// Includes the `?token=` query parameter for auth, so the whole string
/// should be handled as a credential.
///
/// An empty string disables the uplink: [`spawn_uplink_task`] returns
/// without spawning anything.
pub upstream_url: String,
/// Project name this sled identifies as. Sent as the `project` field of the
/// `identity` frame after WS connect (story 899 AC 5).
pub project_name: String,
/// Full HTTP URL of this sled's own MCP endpoint (e.g.
/// `http://127.0.0.1:3001/mcp`). The body of each `mcp_request` frame
/// received from the gateway is POSTed to this URL as-is.
pub local_mcp_url: String,
}
// ── Back-off constants ────────────────────────────────────────────────────────
/// Reconnect delay, in seconds, that the uplink task starts from: the
/// `backoff` variable in `spawn_uplink_task` is initialised to this value.
const INITIAL_BACKOFF_SECS: u64 = 1;
/// Upper bound, in seconds, on the reconnect delay.
/// NOTE(review): the back-off update is outside this view — confirm the
/// clamp at the use site.
const MAX_BACKOFF_SECS: u64 = 60;
/// Factor applied to the reconnect delay between attempts.
/// NOTE(review): presumably multiplied in after each failed session — confirm
/// at the use site.
const BACKOFF_MULTIPLIER: u64 = 2;
/// Interval, in seconds, between `heartbeat` frames sent to the gateway
/// (story 899 AC 3).
///
/// Must be shorter than `crate::service::gateway::HEARTBEAT_MAX_AGE_MS` so
/// the gateway never marks a healthy sled as stale. Note the unit mismatch
/// when comparing: this constant is in seconds, the gateway's is in
/// milliseconds.
pub const HEARTBEAT_INTERVAL_SECS: u64 = 10;
// ── Wire protocol ─────────────────────────────────────────────────────────────
/// Extensible JSON envelope for all sled↔gateway uplink messages.
@@ -54,15 +83,21 @@ pub struct UplinkEnvelope {
/// Spawn the sled uplink background task.
///
/// Does nothing when `upstream_url` is empty. When active, the task holds
/// `services.perm_rx` locked for its lifetime (preventing auto-deny in
/// Does nothing when `config.upstream_url` is empty (AC 8 — sleds without
/// an upstream configured continue to work unchanged). When active, the task
/// holds `services.perm_rx` locked for its lifetime (preventing auto-deny in
/// `tool_prompt_permission`) and forwards all permission requests to the
/// gateway. Reconnects automatically with exponential back-off.
pub fn spawn_uplink_task(upstream_url: String, services: Arc<Services>) {
if upstream_url.is_empty() {
pub fn spawn_uplink_task(config: UplinkConfig, services: Arc<Services>) {
if config.upstream_url.is_empty() {
return;
}
slog!("[uplink] Spawning sled uplink task (gateway={upstream_url})");
let UplinkConfig {
upstream_url,
project_name,
local_mcp_url,
} = config;
slog!("[uplink] Spawning sled uplink task (gateway={upstream_url}, project={project_name})");
tokio::spawn(async move {
// Acquire perm_rx for this task's entire lifetime. While this lock is
// held, try_lock() inside tool_prompt_permission fails — meaning
@@ -70,9 +105,19 @@ pub fn spawn_uplink_task(upstream_url: String, services: Arc<Services>) {
let mut perm_rx = services.perm_rx.lock().await;
slog!("[uplink] Acquired perm_rx; maintaining gateway connection");
let http = reqwest::Client::new();
let mut backoff = INITIAL_BACKOFF_SECS;
loop {
match run_uplink_session(&upstream_url, &mut perm_rx).await {
match run_uplink_session(
&upstream_url,
&project_name,
&local_mcp_url,
&http,
&mut perm_rx,
)
.await
{
Ok(()) => {
slog!("[uplink] Connection closed cleanly; reconnecting in {backoff}s");
}
@@ -90,10 +135,14 @@ pub fn spawn_uplink_task(upstream_url: String, services: Arc<Services>) {
// ── Private helpers ───────────────────────────────────────────────────────────
/// Run a single uplink session: connect, pump messages bidirectionally until
/// disconnect or channel close, then fail-close any in-flight requests.
/// Run a single uplink session: connect, send identity frame, pump messages
/// bidirectionally until disconnect or channel close, then fail-close any
/// in-flight requests.
async fn run_uplink_session(
url: &str,
project_name: &str,
local_mcp_url: &str,
http: &reqwest::Client,
perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver<PermissionForward>,
) -> Result<(), String> {
let (ws_stream, _) = tokio_tungstenite::connect_async(url)
@@ -102,9 +151,31 @@ async fn run_uplink_session(
slog!("[uplink] Connected to gateway uplink endpoint");
let (mut ws_sink, mut ws_rx) = ws_stream.split();
// ── Identity handshake (story 899 AC 5) ─────────────────────────────
let identity = UplinkEnvelope {
msg_type: "identity".to_string(),
req_id: "identity".to_string(),
payload: serde_json::json!({ "project": project_name }),
};
let identity_text =
serde_json::to_string(&identity).map_err(|e| format!("serialise identity: {e}"))?;
ws_sink
.send(WsMessage::Text(identity_text.into()))
.await
.map_err(|e| format!("send identity: {e}"))?;
let mut in_flight: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();
let result = pump_messages(&mut ws_sink, &mut ws_rx, perm_rx, &mut in_flight).await;
let result = pump_messages(
&mut ws_sink,
&mut ws_rx,
perm_rx,
&mut in_flight,
local_mcp_url,
http,
)
.await;
fail_close_all(&mut in_flight);
result
}
@@ -118,9 +189,44 @@ async fn pump_messages(
),
perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver<PermissionForward>,
in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>,
local_mcp_url: &str,
http: &reqwest::Client,
) -> Result<(), String> {
// Channel for spawned mcp_request handlers to deliver their finished
// mcp_response frames back to the WS writer.
let (mcp_resp_tx, mut mcp_resp_rx) = tokio::sync::mpsc::unbounded_channel::<UplinkEnvelope>();
let mut heartbeat =
tokio::time::interval(std::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
// Skip the immediate tick so the first heartbeat fires after the interval,
// not right after connect.
heartbeat.tick().await;
loop {
tokio::select! {
// Heartbeat tick (story 899 AC 3).
_ = heartbeat.tick() => {
let env = UplinkEnvelope {
msg_type: "heartbeat".to_string(),
req_id: String::new(),
payload: serde_json::Value::Null,
};
let text = serde_json::to_string(&env)
.map_err(|e| format!("serialise heartbeat: {e}"))?;
if ws_sink.send(WsMessage::Text(text.into())).await.is_err() {
return Err("WS send failed (heartbeat)".to_string());
}
}
// Completed mcp_response from a spawned handler — forward to gateway.
Some(env) = mcp_resp_rx.recv() => {
let text = serde_json::to_string(&env)
.map_err(|e| format!("serialise mcp_response: {e}"))?;
if ws_sink.send(WsMessage::Text(text.into())).await.is_err() {
return Err("WS send failed (mcp_response)".to_string());
}
}
// New permission request from the MCP layer.
maybe_fwd = perm_rx.recv() => {
match maybe_fwd {
@@ -162,7 +268,7 @@ async fn pump_messages(
return Err("Gateway sent Close frame".to_string());
}
Some(Ok(WsMessage::Text(text))) => {
on_gateway_text(&text, in_flight);
on_gateway_text(&text, in_flight, local_mcp_url, http, &mcp_resp_tx);
}
Some(Ok(WsMessage::Ping(data))) => {
let _ = ws_sink.send(WsMessage::Pong(data)).await;
@@ -174,19 +280,85 @@ async fn pump_messages(
}
}
/// Parse an incoming gateway text frame and resolve any matching in-flight request.
/// Parse an incoming gateway text frame and dispatch it to the appropriate
/// handler. `perm_response` resolves a waiting in-flight permission request;
/// `mcp_request` spawns a task that replays the JSON-RPC body against the
/// sled's local `/mcp` HTTP endpoint and forwards the result back to the
/// gateway as an `mcp_response`.
fn on_gateway_text(
text: &str,
in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>,
local_mcp_url: &str,
http: &reqwest::Client,
mcp_resp_tx: &tokio::sync::mpsc::UnboundedSender<UplinkEnvelope>,
) {
let Ok(env) = serde_json::from_str::<UplinkEnvelope>(text) else {
return;
};
if env.msg_type == "perm_response" {
resolve_perm_response(env, in_flight);
match env.msg_type.as_str() {
"perm_response" => resolve_perm_response(env, in_flight),
"mcp_request" => spawn_mcp_request_handler(env, local_mcp_url, http, mcp_resp_tx),
_ => {}
}
}
/// Build a JSON-RPC 2.0 error response with a `null` id.
///
/// The id is `null` because the uplink never parses the forwarded JSON-RPC
/// body, so the id of the original request is not known at this layer.
fn mcp_error_value(code: i64, message: String) -> serde_json::Value {
    serde_json::json!({
        "jsonrpc": "2.0",
        "id": null,
        "error": {
            "code": code,
            "message": message,
        }
    })
}

/// Replay the gateway's `mcp_request` body against the sled's local MCP HTTP
/// endpoint in a spawned task and forward the response back as an
/// `mcp_response` envelope.
///
/// The envelope's payload is expected to contain a `body` string field
/// holding the raw JSON-RPC request. The reply envelope reuses the request's
/// `req_id` so the gateway can correlate the two.
///
/// Failure handling — every path still produces exactly one `mcp_response`:
/// * `body` missing or not a string → JSON-RPC error `-32600` (Invalid
///   Request), sent immediately without touching the local endpoint.
/// * HTTP request to `local_mcp_url` fails → JSON-RPC error `-32603`.
/// * Local endpoint replies with something that is not JSON → JSON-RPC error
///   `-32603`.
///
/// This function never blocks: the HTTP round-trip runs in its own task and
/// the result is delivered through `mcp_resp_tx`.
fn spawn_mcp_request_handler(
    env: UplinkEnvelope,
    local_mcp_url: &str,
    http: &reqwest::Client,
    mcp_resp_tx: &tokio::sync::mpsc::UnboundedSender<UplinkEnvelope>,
) {
    let req_id = env.req_id;

    // Reject malformed frames up front instead of POSTing an empty body and
    // relaying whatever parse error the local endpoint happens to produce.
    let Some(body) = env
        .payload
        .get("body")
        .and_then(|v| v.as_str())
        .map(str::to_owned)
    else {
        // A send error means the receiver is gone (session ended); there is
        // nobody left to report to.
        let _ = mcp_resp_tx.send(UplinkEnvelope {
            msg_type: "mcp_response".to_string(),
            req_id,
            payload: mcp_error_value(
                -32600,
                "mcp_request payload is missing a string `body` field".to_string(),
            ),
        });
        return;
    };

    // The spawned task must own everything it touches.
    let mcp_url = local_mcp_url.to_string();
    let http = http.clone();
    let mcp_resp_tx = mcp_resp_tx.clone();

    tokio::spawn(async move {
        let outcome = http
            .post(&mcp_url)
            .header("Content-Type", "application/json")
            .body(body)
            .send()
            .await;

        let payload = match outcome {
            Ok(reply) => match reply.json::<serde_json::Value>().await {
                Ok(value) => value,
                Err(e) => mcp_error_value(-32603, format!("local mcp invalid JSON: {e}")),
            },
            Err(e) => mcp_error_value(-32603, format!("local mcp request failed: {e}")),
        };

        // Best-effort: if the session that asked for this has already been
        // torn down, the receiver is dropped and the reply is discarded.
        let _ = mcp_resp_tx.send(UplinkEnvelope {
            msg_type: "mcp_response".to_string(),
            req_id,
            payload,
        });
    });
}
/// Map a `perm_response` envelope to a [`PermissionDecision`] and wake the
/// waiting MCP call.
fn resolve_perm_response(
@@ -321,9 +493,13 @@ mod tests {
#[test]
fn on_gateway_text_ignores_unknown_type() {
    // A well-formed envelope whose `type` the sled does not recognise must be
    // dropped without touching the in-flight permission table.
    let frame = r#"{"type":"future_type","req_id":"x","payload":{}}"#;
    let client = reqwest::Client::new();
    // Keep the receiver bound so the channel stays open for the whole call.
    let (mcp_tx, _mcp_rx) = tokio::sync::mpsc::unbounded_channel();
    let mut pending: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();

    on_gateway_text(
        frame,
        &mut pending,
        "http://127.0.0.1:0/mcp",
        &client,
        &mcp_tx,
    );

    assert!(pending.is_empty());
}
@@ -331,7 +507,14 @@ mod tests {
#[test]
fn on_gateway_text_ignores_invalid_json() {
let mut in_flight: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();
on_gateway_text("not-json", &mut in_flight); // must not panic
let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
on_gateway_text(
"not-json",
&mut in_flight,
"http://127.0.0.1:0/mcp",
&reqwest::Client::new(),
&tx,
); // must not panic
assert!(in_flight.is_empty());
}
@@ -351,7 +534,14 @@ mod tests {
permission_timeout_secs: 120,
});
// Empty URL → noop; if it panicked or blocked the test would fail.
spawn_uplink_task(String::new(), services);
spawn_uplink_task(
UplinkConfig {
upstream_url: String::new(),
project_name: "test".to_string(),
local_mcp_url: "http://127.0.0.1:0/mcp".to_string(),
},
services,
);
}
// ── AC 11: permission approved via uplink ────────────────────────
@@ -364,10 +554,23 @@ mod tests {
let port = listener.local_addr().unwrap().port();
let url = format!("ws://127.0.0.1:{port}");
// Mock gateway: accept one connection, receive perm_request, reply approved.
// Mock gateway: accept one connection, consume identity, receive
// perm_request, reply approved.
let gw_task = tokio::spawn(async move {
let (tcp, _) = listener.accept().await.unwrap();
let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
// First frame: identity (story 899 AC 5).
let msg = ws.next().await.unwrap().unwrap();
let text = match msg {
WsMessage::Text(t) => t.to_string(),
other => panic!("expected identity Text; got {other:?}"),
};
let env: UplinkEnvelope = serde_json::from_str(&text).unwrap();
assert_eq!(env.msg_type, "identity");
assert_eq!(env.payload["project"], "test-proj");
// Next frame: perm_request.
let msg = ws.next().await.unwrap().unwrap();
let text = match msg {
WsMessage::Text(t) => t.to_string(),
@@ -402,7 +605,14 @@ mod tests {
permission_timeout_secs: 120,
});
spawn_uplink_task(url, Arc::clone(&services));
spawn_uplink_task(
UplinkConfig {
upstream_url: url,
project_name: "test-proj".to_string(),
local_mcp_url: "http://127.0.0.1:0/mcp".to_string(),
},
Arc::clone(&services),
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let (response_tx, response_rx) = oneshot::channel();
@@ -441,12 +651,14 @@ mod tests {
let conn_count2 = StdArc::clone(&conn_count);
tokio::spawn(async move {
// First connection: receive the request then immediately drop (simulates network failure).
// First connection: consume identity + the request frame, then
// drop without replying (simulates network failure).
{
let (tcp, _) = listener.accept().await.unwrap();
let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
conn_count2.fetch_add(1, Ordering::SeqCst);
let _ = ws.next().await; // consume one frame
let _ = ws.next().await; // identity frame
let _ = ws.next().await; // perm_request frame
drop(ws); // close without replying → fail-close on sled side
}
// Second connection: approve the next request.
@@ -454,7 +666,10 @@ mod tests {
let (tcp, _) = listener.accept().await.unwrap();
let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
conn_count2.fetch_add(1, Ordering::SeqCst);
if let Some(Ok(WsMessage::Text(text))) = ws.next().await {
// Consume identity frame.
let _ = ws.next().await;
// Drain non-perm frames (heartbeat etc.) until perm_request arrives.
while let Some(Ok(WsMessage::Text(text))) = ws.next().await {
let env: UplinkEnvelope = serde_json::from_str(&text).unwrap();
if env.msg_type == "perm_request" {
let resp = UplinkEnvelope {
@@ -467,6 +682,7 @@ mod tests {
serde_json::to_string(&resp).unwrap().into(),
))
.await;
break;
}
}
}
@@ -486,7 +702,14 @@ mod tests {
permission_timeout_secs: 120,
});
spawn_uplink_task(url, Arc::clone(&services));
spawn_uplink_task(
UplinkConfig {
upstream_url: url,
project_name: "test-proj".to_string(),
local_mcp_url: "http://127.0.0.1:0/mcp".to_string(),
},
Arc::clone(&services),
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// First request is sent on the connection that drops → denied.