huskies: merge 899

This commit is contained in:
dave
2026-05-12 23:11:34 +00:00
parent 0f0cf59329
commit cd214d7246
9 changed files with 1105 additions and 218 deletions
+1 -1
View File
@@ -119,7 +119,7 @@ pub async fn run(config_path: &Path, port: u16) -> Result<(), std::io::Error> {
.read()
.await
.iter()
.map(|(name, entry)| (name.clone(), entry.url.clone()))
.filter_map(|(name, entry)| entry.url.as_ref().map(|u| (name.clone(), u.clone())))
.collect();
let (bot_abort, bot_shutdown_tx) = gateway::io::spawn_gateway_bot(
&config_dir,
+206 -29
View File
@@ -7,12 +7,7 @@ use std::path::PathBuf;
fn make_test_state() -> Arc<GatewayState> {
let mut projects = BTreeMap::new();
projects.insert(
"test".into(),
ProjectEntry {
url: "http://test:3001".into(),
},
);
projects.insert("test".into(), ProjectEntry::with_url("http://test:3001"));
let config = GatewayConfig {
projects,
sled_tokens: BTreeMap::new(),
@@ -367,9 +362,7 @@ async fn init_project_registers_in_projects_toml_when_name_and_url_given() {
let mut projects = BTreeMap::new();
projects.insert(
"existing".into(),
ProjectEntry {
url: "http://existing:3001".into(),
},
ProjectEntry::with_url("http://existing:3001"),
);
let config = GatewayConfig {
projects,
@@ -388,19 +381,17 @@ async fn init_project_registers_in_projects_toml_when_name_and_url_given() {
let projects = state.projects.read().await;
assert!(projects.contains_key("new-project"));
assert_eq!(projects["new-project"].url, "http://new-project:3002");
assert_eq!(
projects["new-project"].url.as_deref(),
Some("http://new-project:3002")
);
}
#[tokio::test]
async fn init_project_duplicate_name_returns_error() {
let dir = tempfile::tempdir().unwrap();
let mut projects = BTreeMap::new();
projects.insert(
"taken".into(),
ProjectEntry {
url: "http://taken:3001".into(),
},
);
projects.insert("taken".into(), ProjectEntry::with_url("http://taken:3001"));
let config = GatewayConfig {
projects,
sled_tokens: BTreeMap::new(),
@@ -452,7 +443,7 @@ async fn init_project_then_wizard_status_integration() {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
let mut projects = BTreeMap::new();
projects.insert("mock-project".into(), ProjectEntry { url: mock_url });
projects.insert("mock-project".into(), ProjectEntry::with_url(mock_url));
let config = GatewayConfig {
projects,
sled_tokens: BTreeMap::new(),
@@ -972,12 +963,7 @@ async fn gateway_mcp_sse_proxy_streams_progress_and_final_response() {
.await;
let mut projects = BTreeMap::new();
projects.insert(
"sled".to_string(),
ProjectEntry {
url: mock_sled.url(),
},
);
projects.insert("sled".to_string(), ProjectEntry::with_url(mock_sled.url()));
let config = GatewayConfig {
projects,
sled_tokens: BTreeMap::new(),
@@ -1068,12 +1054,7 @@ async fn gateway_mcp_post_without_sse_returns_plain_json() {
.await;
let mut projects = BTreeMap::new();
projects.insert(
"sled".to_string(),
ProjectEntry {
url: mock_sled.url(),
},
);
projects.insert("sled".to_string(), ProjectEntry::with_url(mock_sled.url()));
let config = GatewayConfig {
projects,
sled_tokens: BTreeMap::new(),
@@ -1118,3 +1099,199 @@ async fn gateway_mcp_post_without_sse_returns_plain_json() {
"Expected result in plain JSON response"
);
}
// ── Story 899: MCP-over-WS uplink integration ────────────────────────────────
/// Build a `SledConnection` plus a spawned "mock sled" task that:
///
/// * Reads outbound `mcp_request` envelopes off the connection's channel.
/// * Invokes the supplied closure to build a `result` value for each request.
/// * Resolves the matching in-flight oneshot directly (the same effect the WS
/// handler has when it receives an `mcp_response` from a real sled).
///
/// Returns the registered `SledConnection`.
fn spawn_mock_sled<F>(handler: F) -> crate::service::gateway::SledConnection
where
F: Fn(&serde_json::Value) -> serde_json::Value + Send + Sync + 'static,
{
use crate::service::gateway::SledConnection;
use std::sync::atomic::AtomicI64;
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
let last_heartbeat_ms = Arc::new(AtomicI64::new(chrono::Utc::now().timestamp_millis()));
let in_flight = Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new()));
let conn = SledConnection {
tx,
last_heartbeat_ms,
in_flight: Arc::clone(&in_flight),
};
let handler = Arc::new(handler);
tokio::spawn(async move {
while let Some(env) = rx.recv().await {
if env.msg_type != "mcp_request" {
continue;
}
let body_str = env
.payload
.get("body")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let body_json: serde_json::Value = match serde_json::from_str(&body_str) {
Ok(v) => v,
Err(_) => serde_json::Value::Null,
};
let result = handler(&body_json);
let response = serde_json::json!({
"jsonrpc": "2.0",
"id": body_json.get("id").cloned().unwrap_or(serde_json::Value::Null),
"result": result,
});
if let Some(oneshot_tx) = in_flight.lock().await.remove(&env.req_id) {
let _ = oneshot_tx.send(response);
}
}
});
conn
}
/// AC 9: gateway switches active project, calls tools/list and tools/call
/// against a sled connected only via WS uplink, gets correct responses.
#[tokio::test]
async fn ws_only_sled_handles_tools_list_and_tools_call() {
use crate::service::gateway::ProjectEntry;
// Project entry with NO url — WS-only.
let mut projects = BTreeMap::new();
projects.insert(
"ws-only".into(),
ProjectEntry {
url: None,
auth_token: Some("secret".into()),
},
);
let config = GatewayConfig {
projects,
sled_tokens: BTreeMap::new(),
};
let state = Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap());
let conn = spawn_mock_sled(|body| {
let method = body.get("method").and_then(|m| m.as_str()).unwrap_or("");
match method {
"tools/list" => serde_json::json!({
"tools": [
{ "name": "my_tool", "description": "test" }
]
}),
"tools/call" => serde_json::json!({
"content": [{ "type": "text", "text": "called ok" }]
}),
_ => serde_json::json!({ "echo": method }),
}
});
state
.register_sled_connection("ws-only".to_string(), conn)
.await;
// tools/list via proxy_active_mcp.
let body = serde_json::to_vec(&serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list",
"params": {}
}))
.unwrap();
let resp = state.proxy_active_mcp(&body).await.expect("ws proxy works");
let resp_json: serde_json::Value = serde_json::from_slice(&resp).unwrap();
assert_eq!(
resp_json["result"]["tools"][0]["name"], "my_tool",
"tools/list response must come from the WS-connected sled, not HTTP"
);
// tools/call via proxy_active_mcp.
let body = serde_json::to_vec(&serde_json::json!({
"jsonrpc": "2.0",
"id": 2,
"method": "tools/call",
"params": { "name": "my_tool", "arguments": {} }
}))
.unwrap();
let resp = state.proxy_active_mcp(&body).await.expect("ws proxy works");
let resp_json: serde_json::Value = serde_json::from_slice(&resp).unwrap();
assert_eq!(
resp_json["result"]["content"][0]["text"], "called ok",
"tools/call response must come from the WS-connected sled, not HTTP"
);
}
/// AC 10: two sleds connected to one gateway concurrently; gateway routes
/// calls to the right sled based on active project.
#[tokio::test]
async fn two_concurrent_sleds_are_routed_by_active_project() {
use crate::service::gateway::ProjectEntry;
let mut projects = BTreeMap::new();
projects.insert(
"alpha".into(),
ProjectEntry {
url: None,
auth_token: Some("alpha-tok".into()),
},
);
projects.insert(
"beta".into(),
ProjectEntry {
url: None,
auth_token: Some("beta-tok".into()),
},
);
let config = GatewayConfig {
projects,
sled_tokens: BTreeMap::new(),
};
let state = Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap());
let alpha_conn = spawn_mock_sled(|_body| {
serde_json::json!({
"content": [{ "type": "text", "text": "from-alpha" }]
})
});
let beta_conn = spawn_mock_sled(|_body| {
serde_json::json!({
"content": [{ "type": "text", "text": "from-beta" }]
})
});
state
.register_sled_connection("alpha".to_string(), alpha_conn)
.await;
state
.register_sled_connection("beta".to_string(), beta_conn)
.await;
// Switch to alpha.
gateway::switch_project(&state, "alpha").await.unwrap();
let body = serde_json::to_vec(&serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": { "name": "noop", "arguments": {} }
}))
.unwrap();
let resp = state.proxy_active_mcp(&body).await.expect("ws proxy works");
let resp_json: serde_json::Value = serde_json::from_slice(&resp).unwrap();
assert_eq!(
resp_json["result"]["content"][0]["text"], "from-alpha",
"When active project is alpha, calls must route to the alpha sled"
);
// Switch to beta.
gateway::switch_project(&state, "beta").await.unwrap();
let resp = state.proxy_active_mcp(&body).await.expect("ws proxy works");
let resp_json: serde_json::Value = serde_json::from_slice(&resp).unwrap();
assert_eq!(
resp_json["result"]["content"][0]["text"], "from-beta",
"When active project is beta, calls must route to the beta sled"
);
}
+40 -18
View File
@@ -203,14 +203,11 @@ pub async fn gateway_mcp_post_handler(
}
/// Proxy a request to the active project and format the response.
///
/// Prefers the live sled-uplink WebSocket when one is attached (story 899
/// AC 2); falls back to the legacy HTTP proxy otherwise.
async fn proxy_and_respond(state: &GatewayState, bytes: &[u8], id: Option<Value>) -> Response {
let url = match state.active_url().await {
Ok(u) => u,
Err(e) => {
return to_json_response(JsonRpcResponse::error(id, -32603, e.to_string()));
}
};
match gateway::io::proxy_mcp_call(&state.client, &url, bytes).await {
match state.proxy_active_mcp(bytes).await {
Ok(resp_body) => Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
@@ -316,13 +313,23 @@ fn handle_initialize(id: Option<Value>) -> JsonRpcResponse {
}
/// Fetch tools/list from the active project and merge in gateway tools.
///
/// Routes via the sled-uplink WS when one is attached (story 899 AC 2);
/// falls back to HTTP otherwise.
async fn handle_tools_list(
state: &GatewayState,
id: Option<Value>,
) -> Result<JsonRpcResponse, String> {
let url = state.active_url().await.map_err(|e| e.to_string())?;
let resp_json = gateway::io::fetch_tools_list(&state.client, &url).await?;
let rpc_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list",
"params": {}
});
let bytes = serde_json::to_vec(&rpc_body).map_err(|e| e.to_string())?;
let resp_bytes = state.proxy_active_mcp(&bytes).await?;
let resp_json: Value =
serde_json::from_slice(&resp_bytes).map_err(|e| format!("invalid tools/list JSON: {e}"))?;
let mut tools: Vec<Value> = resp_json
.get("result")
@@ -414,21 +421,36 @@ async fn handle_gateway_status_tool(state: &GatewayState, id: Option<Value>) ->
async fn handle_gateway_health_tool(state: &GatewayState, id: Option<Value>) -> JsonRpcResponse {
let mut results = BTreeMap::new();
let project_entries: Vec<(String, String)> = state
// Build the project list, preferring the WS-uplink heartbeat as the
// source of truth for liveness (story 899 AC 3). HTTP polls are used
// only as a fallback when no live sled is connected.
let project_names: Vec<(String, Option<String>)> = state
.projects
.read()
.await
.iter()
.map(|(n, e)| (n.clone(), e.url.clone()))
.collect();
for (name, url) in &project_entries {
let status = match gateway::io::check_project_health(&state.client, url).await {
Ok(true) => "healthy".to_string(),
Ok(false) => "unhealthy".to_string(),
Err(e) => e,
let sled_conns = state.sled_connections.read().await;
for (name, url_opt) in &project_names {
let status = if let Some(conn) = sled_conns.get(name) {
if conn.is_alive(crate::service::gateway::HEARTBEAT_MAX_AGE_MS) {
"healthy (ws)".to_string()
} else {
"stale (ws heartbeat overdue)".to_string()
}
} else if let Some(url) = url_opt {
match gateway::io::check_project_health(&state.client, url).await {
Ok(true) => "healthy".to_string(),
Ok(false) => "unhealthy".to_string(),
Err(e) => e,
}
} else {
"no uplink and no url configured".to_string()
};
results.insert(name.clone(), status);
}
drop(sled_conns);
let active = state.active_project.read().await.clone();
JsonRpcResponse::success(
@@ -512,7 +534,7 @@ async fn handle_aggregate_pipeline_status_tool(
.read()
.await
.iter()
.map(|(name, entry)| (name.clone(), entry.url.clone()))
.filter_map(|(name, entry)| entry.url.as_ref().map(|u| (name.clone(), u.clone())))
.collect();
let statuses =
@@ -656,7 +678,7 @@ async fn handle_pipeline_get(state: &GatewayState, id: Option<Value>) -> JsonRpc
.read()
.await
.iter()
.map(|(n, e)| (n.clone(), e.url.clone()))
.filter_map(|(n, e)| e.url.as_ref().map(|u| (n.clone(), u.clone())))
.collect();
let results = gateway::io::fetch_all_project_pipeline_items(&project_urls, &state.client).await;
+213 -75
View File
@@ -155,21 +155,30 @@ struct SledUplinkParams {
token: Option<String>,
}
/// `GET /api/sled-uplink` — gateway-side WebSocket endpoint for sled permission uplinks.
/// `GET /api/sled-uplink` — gateway-side WebSocket endpoint for sled uplinks.
///
/// # Authentication
///
/// The connecting sled must supply a valid shared-secret token via the `token`
/// query parameter. Tokens are configured in `[sled_tokens]` in `projects.toml`
/// as `sled_id = "secret"` entries.
/// query parameter. Tokens are configured either as per-project `auth_token`
/// fields under `[projects.<name>]` in `projects.toml` (preferred, story 899)
/// or, for backwards compatibility, in the deprecated `[sled_tokens]` table.
///
/// # Protocol
///
/// See `sled_uplink.rs` for the wire format ([`UplinkEnvelope`]). The gateway
/// accepts `perm_request` messages, injects them into the local permission
/// pipeline (via `state.perm_tx`), and sends `perm_response` frames back to the
/// sled once the Matrix bot resolves them. Multiple sleds are demuxed by
/// connection: each handler owns exactly one sled's request/response flow.
/// See `sled_uplink.rs` for the wire format ([`UplinkEnvelope`]).
///
/// Phase 2 (story 899) expands the protocol from the original perm-only flow
/// to a full bidirectional MCP transport:
///
/// - **sled → gateway:** `identity` (mandatory first frame after upgrade),
/// `heartbeat`, `perm_request`, `mcp_response`.
/// - **gateway → sled:** `mcp_request`, `perm_response`.
///
/// On `identity`, the connection is published to
/// [`GatewayState::sled_connections`] under the project name, allowing the
/// MCP proxy (`mcp.rs::proxy_and_respond`) to route subsequent calls over
/// the live WS instead of HTTP. The entry is removed on disconnect.
#[handler]
pub async fn gateway_sled_uplink_handler(
ws: WebSocket,
@@ -196,82 +205,211 @@ pub async fn gateway_sled_uplink_handler(
use poem::IntoResponse as _;
let perm_tx = state.perm_tx.clone();
let state = Arc::clone(&state);
ws.on_upgrade(move |socket| async move {
let (mut sink, mut stream) = socket.split();
// Aggregator channel: spawned per-request tasks send (req_id, decision) here
// so the main loop can write perm_response frames back to the sled.
let (agg_tx, mut agg_rx) = tokio::sync::mpsc::unbounded_channel::<(
String,
crate::http::context::PermissionDecision,
)>();
run_sled_uplink_session(socket, state, sled_id, perm_tx).await;
})
.into_response()
}
crate::slog!("[gateway/sled-uplink] Sled '{}' connected", sled_id);
/// Run a single connected sled's request/response flow until it disconnects.
///
/// Performs the identity handshake, registers a [`SledConnection`] in
/// [`GatewayState::sled_connections`] under the published project name, and
/// pumps messages bidirectionally for the lifetime of the WebSocket.
async fn run_sled_uplink_session(
socket: poem::web::websocket::WebSocketStream,
state: Arc<GatewayState>,
token_sled_id: String,
perm_tx: tokio::sync::mpsc::UnboundedSender<crate::http::context::PermissionForward>,
) {
use crate::service::gateway::SledConnection;
use crate::sled_uplink::UplinkEnvelope;
use std::sync::Arc as StdArc;
use std::sync::atomic::AtomicI64;
loop {
tokio::select! {
msg = stream.next() => {
let text = match msg {
Some(Ok(WsMessage::Text(t))) => t,
Some(Ok(WsMessage::Close(_))) | None => break,
_ => continue,
};
let Ok(env) = serde_json::from_str::<crate::sled_uplink::UplinkEnvelope>(&text) else {
continue;
};
if env.msg_type == "perm_request" {
let req_id = env.req_id.clone();
let tool_name = env.payload.get("tool_name")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
let tool_input = env.payload.get("tool_input")
.cloned()
.unwrap_or(serde_json::Value::Null);
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
let fwd = crate::http::context::PermissionForward {
request_id: format!("{sled_id}:{req_id}"),
tool_name,
tool_input,
response_tx,
};
if perm_tx.send(fwd).is_err() {
break;
}
let agg_tx2 = agg_tx.clone();
tokio::spawn(async move {
let decision = response_rx
.await
.unwrap_or(crate::http::context::PermissionDecision::Deny);
let _ = agg_tx2.send((req_id, decision));
});
}
let (mut sink, mut stream) = socket.split();
// ── Identity handshake: expect the first frame to be `identity` ──────
let identity = match tokio::time::timeout(
std::time::Duration::from_secs(10),
wait_for_identity(&mut stream),
)
.await
{
Ok(Some(p)) => p,
_ => {
crate::slog!(
"[gateway/sled-uplink] '{}' missing identity frame; closing",
token_sled_id
);
return;
}
};
// Project name in the identity must match the project resolved from the
// auth token; otherwise the sled is claiming to be a different project.
if identity != token_sled_id {
crate::slog!(
"[gateway/sled-uplink] identity mismatch (token says '{}', sled claims '{}'); closing",
token_sled_id,
identity
);
return;
}
// ── Build SledConnection and publish it ─────────────────────────────
let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel::<UplinkEnvelope>();
let conn = SledConnection {
tx: out_tx,
last_heartbeat_ms: StdArc::new(AtomicI64::new(chrono::Utc::now().timestamp_millis())),
in_flight: StdArc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
};
state
.register_sled_connection(identity.clone(), conn.clone())
.await;
crate::slog!(
"[gateway/sled-uplink] Sled '{}' connected and registered",
identity
);
// Aggregator channel for perm responses produced by spawned per-request tasks.
let (agg_tx, mut agg_rx) =
tokio::sync::mpsc::unbounded_channel::<(String, crate::http::context::PermissionDecision)>(
);
loop {
tokio::select! {
// Outbound: forward queued envelopes (mcp_request, etc.) to the sled.
Some(env) = out_rx.recv() => {
let Ok(text) = serde_json::to_string(&env) else { continue };
if sink.send(WsMessage::Text(text)).await.is_err() {
break;
}
Some((req_id, decision)) = agg_rx.recv() => {
use crate::http::context::PermissionDecision;
let (approved, always_allow) = match decision {
PermissionDecision::AlwaysAllow => (true, true),
PermissionDecision::Approve => (true, false),
PermissionDecision::Deny => (false, false),
};
let resp = crate::sled_uplink::UplinkEnvelope {
msg_type: "perm_response".to_string(),
req_id,
payload: serde_json::json!({
"approved": approved,
"always_allow": always_allow,
}),
};
let Ok(text) = serde_json::to_string(&resp) else { continue };
if sink.send(WsMessage::Text(text)).await.is_err() {
break;
}
// Inbound: messages from the sled.
msg = stream.next() => {
let text = match msg {
Some(Ok(WsMessage::Text(t))) => t,
Some(Ok(WsMessage::Close(_))) | None => break,
_ => continue,
};
let Ok(env) = serde_json::from_str::<UplinkEnvelope>(&text) else {
continue;
};
match env.msg_type.as_str() {
"heartbeat" => {
conn.last_heartbeat_ms.store(
chrono::Utc::now().timestamp_millis(),
std::sync::atomic::Ordering::Relaxed,
);
}
"mcp_response" => {
let tx_opt = conn.in_flight.lock().await.remove(&env.req_id);
if let Some(tx) = tx_opt {
let _ = tx.send(env.payload);
}
}
"perm_request" => {
forward_perm_request(env, &perm_tx, &agg_tx, &identity);
}
_ => {}
}
}
// Aggregator: per-request permission resolutions get formatted as perm_response.
Some((req_id, decision)) = agg_rx.recv() => {
use crate::http::context::PermissionDecision;
let (approved, always_allow) = match decision {
PermissionDecision::AlwaysAllow => (true, true),
PermissionDecision::Approve => (true, false),
PermissionDecision::Deny => (false, false),
};
let resp = UplinkEnvelope {
msg_type: "perm_response".to_string(),
req_id,
payload: serde_json::json!({
"approved": approved,
"always_allow": always_allow,
}),
};
let Ok(text) = serde_json::to_string(&resp) else { continue };
if sink.send(WsMessage::Text(text)).await.is_err() {
break;
}
}
}
}
crate::slog!("[gateway/sled-uplink] Sled '{}' disconnected", sled_id);
})
.into_response()
state.deregister_sled_connection(&identity).await;
crate::slog!("[gateway/sled-uplink] Sled '{}' disconnected", identity);
}
/// Read frames until an `identity` envelope is seen; return the project name
/// from its payload. Returns `None` if the stream ends or an invalid frame
/// arrives where the first frame is expected to be `identity`.
async fn wait_for_identity(
stream: &mut futures::stream::SplitStream<poem::web::websocket::WebSocketStream>,
) -> Option<String> {
while let Some(msg) = stream.next().await {
let text = match msg {
Ok(WsMessage::Text(t)) => t,
Ok(WsMessage::Close(_)) | Err(_) => return None,
_ => continue,
};
let Ok(env) = serde_json::from_str::<crate::sled_uplink::UplinkEnvelope>(&text) else {
return None;
};
if env.msg_type != "identity" {
return None;
}
return env
.payload
.get("project")
.and_then(|v| v.as_str())
.map(str::to_string);
}
None
}
/// Convert an inbound `perm_request` envelope into a [`PermissionForward`] and
/// inject it into the gateway's permission pipeline. The spawned waiter task
/// publishes the resolved decision into the aggregator channel so the WS
/// writer can emit a matching `perm_response` frame.
fn forward_perm_request(
env: crate::sled_uplink::UplinkEnvelope,
perm_tx: &tokio::sync::mpsc::UnboundedSender<crate::http::context::PermissionForward>,
agg_tx: &tokio::sync::mpsc::UnboundedSender<(String, crate::http::context::PermissionDecision)>,
sled_id: &str,
) {
let req_id = env.req_id.clone();
let tool_name = env
.payload
.get("tool_name")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
let tool_input = env
.payload
.get("tool_input")
.cloned()
.unwrap_or(serde_json::Value::Null);
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
let fwd = crate::http::context::PermissionForward {
request_id: format!("{sled_id}:{req_id}"),
tool_name,
tool_input,
response_tx,
};
if perm_tx.send(fwd).is_err() {
return;
}
let agg_tx2 = agg_tx.clone();
tokio::spawn(async move {
let decision = response_rx
.await
.unwrap_or(crate::http::context::PermissionDecision::Deny);
let _ = agg_tx2.send((req_id, decision));
});
}
// ── Event-push WebSocket handler ─────────────────────────────────────────────
+19 -1
View File
@@ -241,7 +241,25 @@ async fn main() -> Result<(), std::io::Error> {
.clone()
.or_else(|| std::env::var("HUSKIES_UPSTREAM_GATEWAY").ok())
.unwrap_or_default();
sled_uplink::spawn_uplink_task(upstream_gateway, Arc::clone(&services));
// Project name for the identity frame (story 899 AC 5). Env-var override
// wins; otherwise derive from the project_root basename.
let project_name = std::env::var("HUSKIES_PROJECT_NAME").unwrap_or_else(|_| {
services
.project_root
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("unknown")
.to_string()
});
let local_mcp_url = format!("http://127.0.0.1:{port}/mcp");
sled_uplink::spawn_uplink_task(
sled_uplink::UplinkConfig {
upstream_url: upstream_gateway,
project_name,
local_mcp_url,
},
Arc::clone(&services),
);
// ── Build bot contexts (WhatsApp / Slack / Discord) ───────────────────────
let (bot_ctxs, matrix_shutdown_rx) =
+144 -27
View File
@@ -7,20 +7,56 @@ use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
/// A single project entry in `projects.toml`.
///
/// Phase 2 (story 899): `url` is now optional — a project served exclusively
/// via the sled-uplink WebSocket does not need an HTTP base URL. The `url`
/// field is deprecated for removal in a future release; configure
/// `auth_token` instead and rely on the WS uplink for all traffic.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ProjectEntry {
/// Base URL of the project's huskies container (e.g. `http://localhost:3001`).
pub url: String,
///
/// **Deprecated** (story 899) — when a sled connects via the uplink WS the
/// gateway routes all MCP traffic over that connection instead. The URL is
/// used as a fallback when no live uplink exists. Omit for WS-only projects.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub url: Option<String>,
/// Shared-secret token used to authenticate this project's sled when it
/// connects to `/api/sled-uplink`. Takes precedence over the top-level
/// `[sled_tokens]` table for projects that set this field.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub auth_token: Option<String>,
}
impl ProjectEntry {
/// Convenience constructor for entries that only have a URL (e.g. in tests
/// and existing `projects.toml` files that have not yet been migrated to
/// the WS-uplink model).
pub fn with_url(url: impl Into<String>) -> Self {
Self {
url: Some(url.into()),
auth_token: None,
}
}
/// Returns `true` if this entry has a configured HTTP base URL.
pub fn has_url(&self) -> bool {
self.url.as_ref().is_some_and(|u| !u.is_empty())
}
}
/// Top-level `projects.toml` config.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GatewayConfig {
/// Map of project name → container URL.
/// Map of project name → container configuration.
#[serde(default)]
pub projects: BTreeMap<String, ProjectEntry>,
/// Map of sled_id → shared secret token for sled-uplink authentication.
///
/// **Deprecated** (story 899) — move tokens into per-project
/// `auth_token` fields instead. The gateway still honours entries here for
/// one release to provide a smooth migration window.
///
/// Each entry allows a sled identified by `sled_id` to connect to
/// `/api/sled-uplink` using the given secret token as a bearer credential.
#[serde(default)]
@@ -40,18 +76,22 @@ pub fn validate_config(config: &GatewayConfig) -> Result<String, String> {
/// Validate that a project name exists in the given project map.
///
/// Returns the project's URL on success.
/// Returns the project's URL (may be empty for WS-uplink-only projects) on
/// success.
pub fn validate_project_exists(
projects: &BTreeMap<String, ProjectEntry>,
name: &str,
) -> Result<String, String> {
projects.get(name).map(|p| p.url.clone()).ok_or_else(|| {
let available: Vec<&str> = projects.keys().map(|s| s.as_str()).collect();
format!(
"unknown project '{name}'. Available: {}",
available.join(", ")
)
})
projects
.get(name)
.map(|p| p.url.clone().unwrap_or_default())
.ok_or_else(|| {
let available: Vec<&str> = projects.keys().map(|s| s.as_str()).collect();
format!(
"unknown project '{name}'. Available: {}",
available.join(", ")
)
})
}
/// Escape a string as a TOML quoted string.
@@ -104,8 +144,29 @@ url = "http://localhost:3002"
"#;
let config: GatewayConfig = toml::from_str(toml_str).unwrap();
assert_eq!(config.projects.len(), 2);
assert_eq!(config.projects["huskies"].url, "http://localhost:3001");
assert_eq!(config.projects["robot-studio"].url, "http://localhost:3002");
assert_eq!(
config.projects["huskies"].url.as_deref(),
Some("http://localhost:3001")
);
assert_eq!(
config.projects["robot-studio"].url.as_deref(),
Some("http://localhost:3002")
);
}
#[test]
fn parse_project_without_url_is_valid() {
let toml_str = r#"
[projects.ws-only]
auth_token = "secret"
"#;
let config: GatewayConfig = toml::from_str(toml_str).unwrap();
assert_eq!(config.projects.len(), 1);
assert!(config.projects["ws-only"].url.is_none());
assert_eq!(
config.projects["ws-only"].auth_token.as_deref(),
Some("secret")
);
}
#[test]
@@ -127,18 +188,8 @@ url = "http://localhost:3002"
#[test]
fn validate_config_returns_first_project_name() {
let mut projects = BTreeMap::new();
projects.insert(
"beta".into(),
ProjectEntry {
url: "http://b".into(),
},
);
projects.insert(
"alpha".into(),
ProjectEntry {
url: "http://a".into(),
},
);
projects.insert("beta".into(), ProjectEntry::with_url("http://b"));
projects.insert("alpha".into(), ProjectEntry::with_url("http://a"));
let config = GatewayConfig {
projects,
sled_tokens: BTreeMap::new(),
@@ -147,14 +198,26 @@ url = "http://localhost:3002"
}
#[test]
fn validate_project_exists_succeeds() {
fn validate_config_accepts_ws_only_project() {
let mut projects = BTreeMap::new();
projects.insert(
"p1".into(),
"ws-only".into(),
ProjectEntry {
url: "http://p1".into(),
url: None,
auth_token: Some("secret".into()),
},
);
let config = GatewayConfig {
projects,
sled_tokens: BTreeMap::new(),
};
assert!(validate_config(&config).is_ok());
}
#[test]
fn validate_project_exists_succeeds() {
let mut projects = BTreeMap::new();
projects.insert("p1".into(), ProjectEntry::with_url("http://p1"));
assert_eq!(
validate_project_exists(&projects, "p1").unwrap(),
"http://p1"
@@ -167,6 +230,36 @@ url = "http://localhost:3002"
assert!(validate_project_exists(&projects, "missing").is_err());
}
#[test]
fn validate_project_exists_ws_only_returns_empty_url() {
let mut projects = BTreeMap::new();
projects.insert(
"ws".into(),
ProjectEntry {
url: None,
auth_token: Some("tok".into()),
},
);
assert_eq!(validate_project_exists(&projects, "ws").unwrap(), "");
}
#[test]
fn project_entry_with_url_constructor() {
let e = ProjectEntry::with_url("http://example.com");
assert_eq!(e.url.as_deref(), Some("http://example.com"));
assert!(e.auth_token.is_none());
assert!(e.has_url());
}
#[test]
fn project_entry_has_url_false_when_none() {
let e = ProjectEntry {
url: None,
auth_token: Some("tok".into()),
};
assert!(!e.has_url());
}
#[test]
fn toml_string_escapes_quotes() {
assert_eq!(toml_string(r#"a"b"#), r#""a\"b""#);
@@ -198,4 +291,28 @@ url = "http://localhost:3002"
assert!(content.contains("transport = \"slack\""));
assert!(content.contains("slack_bot_token = \"xoxb-123\""));
}
#[test]
fn roundtrip_project_entry_with_auth_token() {
let entry = ProjectEntry {
url: Some("http://a:3001".into()),
auth_token: Some("mysecret".into()),
};
let mut projects = BTreeMap::new();
projects.insert("myproj".into(), entry);
let config = GatewayConfig {
projects,
sled_tokens: BTreeMap::new(),
};
let toml_str = toml::to_string_pretty(&config).unwrap();
let parsed: GatewayConfig = toml::from_str(&toml_str).unwrap();
assert_eq!(
parsed.projects["myproj"].url.as_deref(),
Some("http://a:3001")
);
assert_eq!(
parsed.projects["myproj"].auth_token.as_deref(),
Some("mysecret")
);
}
}
-23
View File
@@ -140,29 +140,6 @@ pub async fn proxy_mcp_call_sse(
.map_err(|e| format!("failed to reach {mcp_url}: {e}"))
}
/// Fetch tools/list from a project's MCP endpoint.
pub async fn fetch_tools_list(client: &Client, base_url: &str) -> Result<Value, String> {
let mcp_url = format!("{}/mcp", base_url.trim_end_matches('/'));
let rpc_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list",
"params": {}
});
let resp = client
.post(&mcp_url)
.json(&rpc_body)
.send()
.await
.map_err(|e| format!("failed to reach {mcp_url}: {e}"))?;
resp.json()
.await
.map_err(|e| format!("invalid JSON from upstream: {e}"))
}
/// Fetch and aggregate pipeline status for a single project URL.
pub async fn fetch_one_project_pipeline_status(url: &str, client: &Client) -> Value {
let mcp_url = format!("{}/mcp", url.trim_end_matches('/'));
+232 -17
View File
@@ -28,6 +28,7 @@ use io::Client;
use std::collections::{BTreeMap, HashMap};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicI64, Ordering};
use tokio::sync::Mutex as TokioMutex;
use tokio::sync::RwLock;
use tokio::sync::mpsc;
@@ -49,6 +50,95 @@ pub struct GatewayStatusEvent {
pub event: crate::service::events::StoredEvent,
}
// ── Sled connection ─────────────────────────────────────────────────────────
/// Maximum age, in milliseconds, of a sled heartbeat before the gateway
/// considers the connection stale (story 899 AC 3).
pub const HEARTBEAT_MAX_AGE_MS: i64 = 30_000;
/// Default per-request timeout, in milliseconds, for an MCP call proxied over
/// a sled uplink WebSocket. Mirrors the existing reqwest-based path which has
/// no explicit cap; we set a generous bound so long-running tools (e.g.
/// `run_tests`) still complete.
pub const MCP_VIA_WS_TIMEOUT_MS: u64 = 1_200_000;
/// Handle to a sled currently connected to the gateway via the uplink WebSocket.
///
/// Created by the `/api/sled-uplink` WS handler on connect and stored in
/// [`GatewayState::sled_connections`]. The gateway's MCP proxy reads this
/// to forward requests over the live connection rather than via HTTP.
#[derive(Clone)]
pub struct SledConnection {
/// Sender side of the channel the WS handler reads to forward outgoing
/// frames (e.g. `mcp_request`) to the sled.
pub tx: mpsc::UnboundedSender<crate::sled_uplink::UplinkEnvelope>,
/// Timestamp (ms since Unix epoch) of the last `heartbeat` frame received
/// from this sled. Updated atomically by the WS handler task.
pub last_heartbeat_ms: Arc<AtomicI64>,
/// In-flight MCP requests waiting for a matching `mcp_response` from this
/// sled. Keyed by `req_id`; the oneshot sender is resolved when the
/// response arrives.
pub in_flight:
Arc<TokioMutex<HashMap<String, tokio::sync::oneshot::Sender<serde_json::Value>>>>,
}
impl SledConnection {
/// Returns `true` if a heartbeat has been received within the last `max_age_ms`
/// milliseconds.
pub fn is_alive(&self, max_age_ms: i64) -> bool {
let last = self.last_heartbeat_ms.load(Ordering::Relaxed);
let now = chrono::Utc::now().timestamp_millis();
now - last <= max_age_ms
}
}
/// Proxy a raw MCP request body to a sled over its uplink WebSocket and
/// return the serialised JSON response bytes.
///
/// Generates a fresh correlation id, registers a oneshot in the connection's
/// in-flight map, sends an `mcp_request` envelope, and waits for the
/// matching `mcp_response` (or [`MCP_VIA_WS_TIMEOUT_MS`] to elapse).
///
/// The response payload is serialised back into JSON bytes so callers can
/// return it directly to the HTTP client unchanged.
pub async fn proxy_mcp_via_ws(
conn: &SledConnection,
request_bytes: &[u8],
) -> Result<Vec<u8>, String> {
let req_id = uuid::Uuid::new_v4().to_string();
let body_str = std::str::from_utf8(request_bytes)
.map_err(|e| format!("non-utf8 mcp request body: {e}"))?
.to_string();
let (tx, rx) = tokio::sync::oneshot::channel();
conn.in_flight.lock().await.insert(req_id.clone(), tx);
let env = crate::sled_uplink::UplinkEnvelope {
msg_type: "mcp_request".to_string(),
req_id: req_id.clone(),
payload: serde_json::json!({ "body": body_str }),
};
if conn.tx.send(env).is_err() {
conn.in_flight.lock().await.remove(&req_id);
return Err("sled uplink connection closed".to_string());
}
let timeout = std::time::Duration::from_millis(MCP_VIA_WS_TIMEOUT_MS);
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(response_value)) => {
serde_json::to_vec(&response_value).map_err(|e| format!("serialise mcp_response: {e}"))
}
Ok(Err(_)) => Err("sled response channel dropped".to_string()),
Err(_) => {
conn.in_flight.lock().await.remove(&req_id);
Err(format!(
"mcp call to sled timed out after {MCP_VIA_WS_TIMEOUT_MS} ms"
))
}
}
}
// ── Error type ──────────────────────────────────────────────────────────────
/// Typed errors returned by `service::gateway` functions.
@@ -135,11 +225,18 @@ pub struct GatewayState {
/// The Matrix bot's `permission_listener` holds this locked for its lifetime;
/// the sled-uplink WS handler sends requests via `perm_tx`.
pub perm_rx: Arc<TokioMutex<mpsc::UnboundedReceiver<PermissionForward>>>,
/// Reversed sled-token map: token → sled_id.
/// Reversed sled-token map: token → project_name (sled_id).
///
/// Built at startup from [`GatewayConfig::sled_tokens`] (which maps
/// sled_id → token). The handler looks up incoming tokens in O(1).
/// Built at startup from both [`GatewayConfig::sled_tokens`] AND the
/// per-project `auth_token` field (story 899). The handler looks up
/// incoming tokens in O(1) to identify the project the sled represents.
pub sled_tokens: HashMap<String, String>,
/// Live sled connections keyed by project name.
///
/// Populated by the `/api/sled-uplink` WS handler when a sled authenticates
/// and depopulated when it disconnects. MCP proxy functions check here
/// first (WS route), falling back to HTTP when no live connection exists.
pub sled_connections: Arc<RwLock<HashMap<String, SledConnection>>>,
}
impl GatewayState {
@@ -160,11 +257,21 @@ impl GatewayState {
.unwrap_or(first_from_config);
let (event_tx, _) = tokio::sync::broadcast::channel(EVENT_CHANNEL_CAPACITY);
let (perm_tx, perm_rx) = mpsc::unbounded_channel::<PermissionForward>();
let sled_tokens: HashMap<String, String> = gateway_config
// Build token→project_name map from two sources:
// 1. Legacy top-level [sled_tokens] section (sled_id → token, reversed)
// 2. Per-project auth_token fields (project_name → token, reversed)
let mut sled_tokens: HashMap<String, String> = gateway_config
.sled_tokens
.iter()
.map(|(sled_id, token)| (token.clone(), sled_id.clone()))
.collect();
for (project_name, entry) in &gateway_config.projects {
if let Some(ref token) = entry.auth_token {
sled_tokens.insert(token.clone(), project_name.clone());
}
}
Ok(Self {
projects: Arc::new(RwLock::new(gateway_config.projects)),
active_project: Arc::new(RwLock::new(first)),
@@ -178,26 +285,78 @@ impl GatewayState {
perm_tx,
perm_rx: Arc::new(TokioMutex::new(perm_rx)),
sled_tokens,
sled_connections: Arc::new(RwLock::new(HashMap::new())),
})
}
/// Get the URL of the currently active project.
/// Get the URL of the currently active project, if one is configured.
///
/// Returns `Err` when the active project has no URL configured (WS-uplink
/// only) or the project name is not found.
pub async fn active_url(&self) -> Result<String, Error> {
let name = self.active_project.read().await.clone();
self.projects
.read()
.await
.get(&name)
.map(|p| p.url.clone())
.and_then(|p| p.url.clone())
.ok_or_else(|| {
Error::ProjectNotFound(format!("active project '{name}' not found in config"))
Error::ProjectNotFound(format!(
"active project '{name}' has no URL configured \
(use sled-uplink WS or add url to projects.toml)"
))
})
}
/// Register a live sled connection for the given project.
pub async fn register_sled_connection(&self, project_name: String, conn: SledConnection) {
self.sled_connections
.write()
.await
.insert(project_name, conn);
}
/// Remove the sled connection for the given project (on disconnect).
pub async fn deregister_sled_connection(&self, project_name: &str) {
self.sled_connections.write().await.remove(project_name);
}
/// Look up the live sled connection for the active project, returning a
/// clone if one exists and has a recent heartbeat.
///
/// Returns `None` when no sled has connected for this project or when its
/// heartbeat is overdue.
pub async fn active_sled_connection(&self) -> Option<SledConnection> {
let name = self.active_project.read().await.clone();
let conn = self.sled_connections.read().await.get(&name).cloned()?;
if conn.is_alive(HEARTBEAT_MAX_AGE_MS) {
Some(conn)
} else {
None
}
}
/// Proxy an MCP request to the active project, preferring the live
/// sled-uplink WebSocket when available (story 899 AC 2) and falling
/// back to HTTP otherwise.
///
/// Returns the raw response body bytes ready to be relayed to the caller.
pub async fn proxy_active_mcp(&self, bytes: &[u8]) -> Result<Vec<u8>, String> {
if let Some(conn) = self.active_sled_connection().await {
return proxy_mcp_via_ws(&conn, bytes).await;
}
let url = self.active_url().await.map_err(|e| e.to_string())?;
crate::slog!(
"[gateway] MCP proxy: WS uplink unavailable, falling back to HTTP \
(deprecated, will be removed once all sleds are WS-only)"
);
crate::service::gateway::io::proxy_mcp_call(&self.client, &url, bytes).await
}
}
// ── Public API ──────────────────────────────────────────────────────────────
/// Switch the active project. Returns the project's URL on success.
/// Switch the active project. Returns the project's URL (empty for WS-only projects).
///
/// Writes the new active project to the CRDT `gateway_config.active_project`
/// register (LWW — last write wins) so the selection is persisted across
@@ -358,7 +517,7 @@ pub async fn add_project(state: &GatewayState, name: &str, url: &str) -> Result<
"project '{name}' already exists"
)));
}
projects.insert(name.clone(), ProjectEntry { url: url.clone() });
projects.insert(name.clone(), ProjectEntry::with_url(&url));
}
let snapshot = state.projects.read().await.clone();
@@ -441,7 +600,7 @@ pub async fn init_project(
"project '{n}' is already registered. Choose a different name or use switch_project."
)));
}
projects.insert(n.to_string(), ProjectEntry { url: u.to_string() });
projects.insert(n.to_string(), ProjectEntry::with_url(u));
io::save_config(&projects, &state.config_dir).await;
crate::slog!("[gateway] init_project: registered '{n}' ({u})");
Some(n.to_string())
@@ -494,7 +653,7 @@ pub async fn save_bot_config_and_restart(state: &GatewayState, content: &str) ->
.read()
.await
.iter()
.map(|(name, entry)| (name.clone(), entry.url.clone()))
.filter_map(|(name, entry)| entry.url.as_ref().map(|u| (name.clone(), u.clone())))
.collect();
let (new_handle, new_shutdown_tx) = io::spawn_gateway_bot(
@@ -523,12 +682,7 @@ mod tests {
fn make_config(names: &[(&str, &str)]) -> GatewayConfig {
let mut projects = BTreeMap::new();
for (name, url) in names {
projects.insert(
name.to_string(),
ProjectEntry {
url: url.to_string(),
},
);
projects.insert(name.to_string(), ProjectEntry::with_url(*url));
}
GatewayConfig {
projects,
@@ -584,6 +738,24 @@ mod tests {
assert_eq!(url, "http://my:3001");
}
#[tokio::test]
async fn active_url_fails_for_ws_only_project() {
let mut projects = BTreeMap::new();
projects.insert(
"ws-proj".into(),
ProjectEntry {
url: None,
auth_token: Some("tok".into()),
},
);
let config = GatewayConfig {
projects,
sled_tokens: BTreeMap::new(),
};
let state = GatewayState::new(config, PathBuf::from("."), 3000).unwrap();
assert!(state.active_url().await.is_err());
}
#[test]
fn error_display_variants() {
assert!(
@@ -682,4 +854,47 @@ mod tests {
let result = init_project(&state, dir.path().to_str().unwrap(), None, None).await;
assert!(result.is_err());
}
#[tokio::test]
async fn sled_connection_registration_and_lookup() {
let config = make_config(&[("myproj", "http://myproj:3001")]);
let state = GatewayState::new(config, PathBuf::new(), 3000).unwrap();
let (tx, _rx) = mpsc::unbounded_channel();
let conn = SledConnection {
tx,
last_heartbeat_ms: Arc::new(AtomicI64::new(chrono::Utc::now().timestamp_millis())),
in_flight: Arc::new(TokioMutex::new(HashMap::new())),
};
state
.register_sled_connection("myproj".to_string(), conn)
.await;
assert!(state.sled_connections.read().await.contains_key("myproj"));
state.deregister_sled_connection("myproj").await;
assert!(!state.sled_connections.read().await.contains_key("myproj"));
}
#[tokio::test]
async fn auth_token_in_project_entry_populates_sled_tokens_map() {
let mut projects = BTreeMap::new();
projects.insert(
"huskies".into(),
ProjectEntry {
url: Some("http://huskies:3001".into()),
auth_token: Some("secret-token".into()),
},
);
let config = GatewayConfig {
projects,
sled_tokens: BTreeMap::new(),
};
let state = GatewayState::new(config, PathBuf::new(), 3000).unwrap();
assert_eq!(
state.sled_tokens.get("secret-token").map(|s| s.as_str()),
Some("huskies"),
"Per-project auth_token must be in reversed sled_tokens map"
);
}
}
+250 -27
View File
@@ -1,6 +1,5 @@
//! Sled uplink — background task that maintains a WebSocket connection from a
//! sled (standard huskies instance) to an upstream gateway for permission
//! request forwarding.
//! sled (standard huskies instance) to an upstream gateway.
//!
//! When `HUSKIES_UPSTREAM_GATEWAY` is set (or `--upstream-gateway` is passed
//! on the CLI), this module spawns a task that:
@@ -10,10 +9,17 @@
//! auto-denying requests with "no interactive session".
//! 2. Maintains a persistent WebSocket connection to the gateway's
//! `/api/sled-uplink` endpoint.
//! 3. Forwards each [`PermissionForward`] as a `perm_request` JSON envelope.
//! 4. Awaits the matching `perm_response` envelope from the gateway.
//! 5. Reconnects with exponential back-off on connection drop, fail-closing
//! any in-flight requests with [`PermissionDecision::Deny`].
//! 3. Sends an `identity` frame announcing the sled's project name + auth
//! token immediately after connect (story 899 Phase 2).
//! 4. Sends a `heartbeat` frame every [`HEARTBEAT_INTERVAL_SECS`] so the
//! gateway can mark the connection live without extra HTTP polls.
//! 5. Forwards each [`PermissionForward`] as a `perm_request` JSON envelope
//! and awaits the matching `perm_response`.
//! 6. Handles inbound `mcp_request` frames by replaying the MCP JSON-RPC
//! body against the sled's own local `/mcp` HTTP endpoint and returning
//! the response as an `mcp_response` frame.
//! 7. Reconnects with exponential back-off on connection drop, fail-closing
//! any in-flight permission requests with [`PermissionDecision::Deny`].
use crate::http::context::{PermissionDecision, PermissionForward};
use crate::services::Services;
@@ -25,12 +31,35 @@ use std::sync::Arc;
use tokio::sync::oneshot;
use tokio_tungstenite::tungstenite::Message as WsMessage;
/// Configuration for [`spawn_uplink_task`].
///
/// Bundled into a single struct so `main.rs` can build it once and pass it in
/// without proliferating positional arguments.
pub struct UplinkConfig {
/// WebSocket URL of the upstream gateway's `/api/sled-uplink` endpoint.
/// Includes the `?token=` query parameter for auth.
pub upstream_url: String,
/// Project name this sled identifies as. Sent in the `identity` frame
/// after WS connect (story 899 AC 5).
pub project_name: String,
/// HTTP base URL of this sled's own MCP endpoint (e.g.
/// `http://127.0.0.1:3001/mcp`). Used to replay `mcp_request` frames
/// received from the gateway against the local MCP handler.
pub local_mcp_url: String,
}
// ── Back-off constants ────────────────────────────────────────────────────────
const INITIAL_BACKOFF_SECS: u64 = 1;
const MAX_BACKOFF_SECS: u64 = 60;
const BACKOFF_MULTIPLIER: u64 = 2;
/// Interval between `heartbeat` frames sent to the gateway (story 899 AC 3).
///
/// Must be shorter than `crate::service::gateway::HEARTBEAT_MAX_AGE_MS` so
/// the gateway never marks a healthy sled as stale.
pub const HEARTBEAT_INTERVAL_SECS: u64 = 10;
// ── Wire protocol ─────────────────────────────────────────────────────────────
/// Extensible JSON envelope for all sled↔gateway uplink messages.
@@ -54,15 +83,21 @@ pub struct UplinkEnvelope {
/// Spawn the sled uplink background task.
///
/// Does nothing when `upstream_url` is empty. When active, the task holds
/// `services.perm_rx` locked for its lifetime (preventing auto-deny in
/// Does nothing when `config.upstream_url` is empty (AC 8 — sleds without
/// an upstream configured continue to work unchanged). When active, the task
/// holds `services.perm_rx` locked for its lifetime (preventing auto-deny in
/// `tool_prompt_permission`) and forwards all permission requests to the
/// gateway. Reconnects automatically with exponential back-off.
pub fn spawn_uplink_task(upstream_url: String, services: Arc<Services>) {
if upstream_url.is_empty() {
pub fn spawn_uplink_task(config: UplinkConfig, services: Arc<Services>) {
if config.upstream_url.is_empty() {
return;
}
slog!("[uplink] Spawning sled uplink task (gateway={upstream_url})");
let UplinkConfig {
upstream_url,
project_name,
local_mcp_url,
} = config;
slog!("[uplink] Spawning sled uplink task (gateway={upstream_url}, project={project_name})");
tokio::spawn(async move {
// Acquire perm_rx for this task's entire lifetime. While this lock is
// held, try_lock() inside tool_prompt_permission fails — meaning
@@ -70,9 +105,19 @@ pub fn spawn_uplink_task(upstream_url: String, services: Arc<Services>) {
let mut perm_rx = services.perm_rx.lock().await;
slog!("[uplink] Acquired perm_rx; maintaining gateway connection");
let http = reqwest::Client::new();
let mut backoff = INITIAL_BACKOFF_SECS;
loop {
match run_uplink_session(&upstream_url, &mut perm_rx).await {
match run_uplink_session(
&upstream_url,
&project_name,
&local_mcp_url,
&http,
&mut perm_rx,
)
.await
{
Ok(()) => {
slog!("[uplink] Connection closed cleanly; reconnecting in {backoff}s");
}
@@ -90,10 +135,14 @@ pub fn spawn_uplink_task(upstream_url: String, services: Arc<Services>) {
// ── Private helpers ───────────────────────────────────────────────────────────
/// Run a single uplink session: connect, pump messages bidirectionally until
/// disconnect or channel close, then fail-close any in-flight requests.
/// Run a single uplink session: connect, send identity frame, pump messages
/// bidirectionally until disconnect or channel close, then fail-close any
/// in-flight requests.
async fn run_uplink_session(
url: &str,
project_name: &str,
local_mcp_url: &str,
http: &reqwest::Client,
perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver<PermissionForward>,
) -> Result<(), String> {
let (ws_stream, _) = tokio_tungstenite::connect_async(url)
@@ -102,9 +151,31 @@ async fn run_uplink_session(
slog!("[uplink] Connected to gateway uplink endpoint");
let (mut ws_sink, mut ws_rx) = ws_stream.split();
// ── Identity handshake (story 899 AC 5) ─────────────────────────────
let identity = UplinkEnvelope {
msg_type: "identity".to_string(),
req_id: "identity".to_string(),
payload: serde_json::json!({ "project": project_name }),
};
let identity_text =
serde_json::to_string(&identity).map_err(|e| format!("serialise identity: {e}"))?;
ws_sink
.send(WsMessage::Text(identity_text.into()))
.await
.map_err(|e| format!("send identity: {e}"))?;
let mut in_flight: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();
let result = pump_messages(&mut ws_sink, &mut ws_rx, perm_rx, &mut in_flight).await;
let result = pump_messages(
&mut ws_sink,
&mut ws_rx,
perm_rx,
&mut in_flight,
local_mcp_url,
http,
)
.await;
fail_close_all(&mut in_flight);
result
}
@@ -118,9 +189,44 @@ async fn pump_messages(
),
perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver<PermissionForward>,
in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>,
local_mcp_url: &str,
http: &reqwest::Client,
) -> Result<(), String> {
// Channel for spawned mcp_request handlers to deliver their finished
// mcp_response frames back to the WS writer.
let (mcp_resp_tx, mut mcp_resp_rx) = tokio::sync::mpsc::unbounded_channel::<UplinkEnvelope>();
let mut heartbeat =
tokio::time::interval(std::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
// Skip the immediate tick so the first heartbeat fires after the interval,
// not right after connect.
heartbeat.tick().await;
loop {
tokio::select! {
// Heartbeat tick (story 899 AC 3).
_ = heartbeat.tick() => {
let env = UplinkEnvelope {
msg_type: "heartbeat".to_string(),
req_id: String::new(),
payload: serde_json::Value::Null,
};
let text = serde_json::to_string(&env)
.map_err(|e| format!("serialise heartbeat: {e}"))?;
if ws_sink.send(WsMessage::Text(text.into())).await.is_err() {
return Err("WS send failed (heartbeat)".to_string());
}
}
// Completed mcp_response from a spawned handler — forward to gateway.
Some(env) = mcp_resp_rx.recv() => {
let text = serde_json::to_string(&env)
.map_err(|e| format!("serialise mcp_response: {e}"))?;
if ws_sink.send(WsMessage::Text(text.into())).await.is_err() {
return Err("WS send failed (mcp_response)".to_string());
}
}
// New permission request from the MCP layer.
maybe_fwd = perm_rx.recv() => {
match maybe_fwd {
@@ -162,7 +268,7 @@ async fn pump_messages(
return Err("Gateway sent Close frame".to_string());
}
Some(Ok(WsMessage::Text(text))) => {
on_gateway_text(&text, in_flight);
on_gateway_text(&text, in_flight, local_mcp_url, http, &mcp_resp_tx);
}
Some(Ok(WsMessage::Ping(data))) => {
let _ = ws_sink.send(WsMessage::Pong(data)).await;
@@ -174,19 +280,85 @@ async fn pump_messages(
}
}
/// Parse an incoming gateway text frame and resolve any matching in-flight request.
/// Parse an incoming gateway text frame and dispatch it to the appropriate
/// handler. `perm_response` resolves a waiting in-flight permission request;
/// `mcp_request` spawns a task that replays the JSON-RPC body against the
/// sled's local `/mcp` HTTP endpoint and forwards the result back to the
/// gateway as an `mcp_response`.
fn on_gateway_text(
text: &str,
in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>,
local_mcp_url: &str,
http: &reqwest::Client,
mcp_resp_tx: &tokio::sync::mpsc::UnboundedSender<UplinkEnvelope>,
) {
let Ok(env) = serde_json::from_str::<UplinkEnvelope>(text) else {
return;
};
if env.msg_type == "perm_response" {
resolve_perm_response(env, in_flight);
match env.msg_type.as_str() {
"perm_response" => resolve_perm_response(env, in_flight),
"mcp_request" => spawn_mcp_request_handler(env, local_mcp_url, http, mcp_resp_tx),
_ => {}
}
}
/// Replay the gateway's `mcp_request` body against the sled's local MCP HTTP
/// endpoint in a spawned task and forward the response back as an
/// `mcp_response` envelope. The payload is expected to contain a `body`
/// string field holding the raw JSON-RPC bytes.
fn spawn_mcp_request_handler(
env: UplinkEnvelope,
local_mcp_url: &str,
http: &reqwest::Client,
mcp_resp_tx: &tokio::sync::mpsc::UnboundedSender<UplinkEnvelope>,
) {
let req_id = env.req_id.clone();
let body = env
.payload
.get("body")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let mcp_url = local_mcp_url.to_string();
let http = http.clone();
let mcp_resp_tx = mcp_resp_tx.clone();
tokio::spawn(async move {
let response_value = match http
.post(&mcp_url)
.header("Content-Type", "application/json")
.body(body)
.send()
.await
{
Ok(r) => match r.json::<serde_json::Value>().await {
Ok(v) => v,
Err(e) => serde_json::json!({
"jsonrpc": "2.0",
"id": null,
"error": {
"code": -32603,
"message": format!("local mcp invalid JSON: {e}"),
}
}),
},
Err(e) => serde_json::json!({
"jsonrpc": "2.0",
"id": null,
"error": {
"code": -32603,
"message": format!("local mcp request failed: {e}"),
}
}),
};
let resp = UplinkEnvelope {
msg_type: "mcp_response".to_string(),
req_id,
payload: response_value,
};
let _ = mcp_resp_tx.send(resp);
});
}
/// Map a `perm_response` envelope to a [`PermissionDecision`] and wake the
/// waiting MCP call.
fn resolve_perm_response(
@@ -321,9 +493,13 @@ mod tests {
#[test]
fn on_gateway_text_ignores_unknown_type() {
let mut in_flight: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();
let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
on_gateway_text(
r#"{"type":"future_type","req_id":"x","payload":{}}"#,
&mut in_flight,
"http://127.0.0.1:0/mcp",
&reqwest::Client::new(),
&tx,
);
assert!(in_flight.is_empty());
}
@@ -331,7 +507,14 @@ mod tests {
#[test]
fn on_gateway_text_ignores_invalid_json() {
let mut in_flight: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();
on_gateway_text("not-json", &mut in_flight); // must not panic
let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
on_gateway_text(
"not-json",
&mut in_flight,
"http://127.0.0.1:0/mcp",
&reqwest::Client::new(),
&tx,
); // must not panic
assert!(in_flight.is_empty());
}
@@ -351,7 +534,14 @@ mod tests {
permission_timeout_secs: 120,
});
// Empty URL → noop; if it panicked or blocked the test would fail.
spawn_uplink_task(String::new(), services);
spawn_uplink_task(
UplinkConfig {
upstream_url: String::new(),
project_name: "test".to_string(),
local_mcp_url: "http://127.0.0.1:0/mcp".to_string(),
},
services,
);
}
// ── AC 11: permission approved via uplink ────────────────────────
@@ -364,10 +554,23 @@ mod tests {
let port = listener.local_addr().unwrap().port();
let url = format!("ws://127.0.0.1:{port}");
// Mock gateway: accept one connection, receive perm_request, reply approved.
// Mock gateway: accept one connection, consume identity, receive
// perm_request, reply approved.
let gw_task = tokio::spawn(async move {
let (tcp, _) = listener.accept().await.unwrap();
let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
// First frame: identity (story 899 AC 5).
let msg = ws.next().await.unwrap().unwrap();
let text = match msg {
WsMessage::Text(t) => t.to_string(),
other => panic!("expected identity Text; got {other:?}"),
};
let env: UplinkEnvelope = serde_json::from_str(&text).unwrap();
assert_eq!(env.msg_type, "identity");
assert_eq!(env.payload["project"], "test-proj");
// Next frame: perm_request.
let msg = ws.next().await.unwrap().unwrap();
let text = match msg {
WsMessage::Text(t) => t.to_string(),
@@ -402,7 +605,14 @@ mod tests {
permission_timeout_secs: 120,
});
spawn_uplink_task(url, Arc::clone(&services));
spawn_uplink_task(
UplinkConfig {
upstream_url: url,
project_name: "test-proj".to_string(),
local_mcp_url: "http://127.0.0.1:0/mcp".to_string(),
},
Arc::clone(&services),
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let (response_tx, response_rx) = oneshot::channel();
@@ -441,12 +651,14 @@ mod tests {
let conn_count2 = StdArc::clone(&conn_count);
tokio::spawn(async move {
// First connection: receive the request then immediately drop (simulates network failure).
// First connection: consume identity + the request frame, then
// drop without replying (simulates network failure).
{
let (tcp, _) = listener.accept().await.unwrap();
let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
conn_count2.fetch_add(1, Ordering::SeqCst);
let _ = ws.next().await; // consume one frame
let _ = ws.next().await; // identity frame
let _ = ws.next().await; // perm_request frame
drop(ws); // close without replying → fail-close on sled side
}
// Second connection: approve the next request.
@@ -454,7 +666,10 @@ mod tests {
let (tcp, _) = listener.accept().await.unwrap();
let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
conn_count2.fetch_add(1, Ordering::SeqCst);
if let Some(Ok(WsMessage::Text(text))) = ws.next().await {
// Consume identity frame.
let _ = ws.next().await;
// Drain non-perm frames (heartbeat etc.) until perm_request arrives.
while let Some(Ok(WsMessage::Text(text))) = ws.next().await {
let env: UplinkEnvelope = serde_json::from_str(&text).unwrap();
if env.msg_type == "perm_request" {
let resp = UplinkEnvelope {
@@ -467,6 +682,7 @@ mod tests {
serde_json::to_string(&resp).unwrap().into(),
))
.await;
break;
}
}
}
@@ -486,7 +702,14 @@ mod tests {
permission_timeout_secs: 120,
});
spawn_uplink_task(url, Arc::clone(&services));
spawn_uplink_task(
UplinkConfig {
upstream_url: url,
project_name: "test-proj".to_string(),
local_mcp_url: "http://127.0.0.1:0/mcp".to_string(),
},
Arc::clone(&services),
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// First request is sent on the connection that drops → denied.