767 lines
31 KiB
Rust
767 lines
31 KiB
Rust
//! Sled uplink — background task that maintains a WebSocket connection from a
|
|
//! sled (standard huskies instance) to an upstream gateway.
|
|
//!
|
|
//! When `HUSKIES_UPSTREAM_GATEWAY` is set (or `--upstream-gateway` is passed
|
|
//! on the CLI), this module spawns a task that:
|
|
//!
|
|
//! 1. Acquires `services.perm_rx` for its lifetime (matching the Matrix bot's
|
|
//! `permission_listener` pattern), preventing `tool_prompt_permission` from
|
|
//! auto-denying requests with "no interactive session".
|
|
//! 2. Maintains a persistent WebSocket connection to the gateway's
|
|
//! `/api/sled-uplink` endpoint.
|
|
//! 3. Sends an `identity` frame announcing the sled's project name
//!    immediately after connect (story 899 Phase 2). The auth token is not
//!    part of that frame; it travels in the `?token=` query parameter of the
//!    upstream URL.
|
|
//! 4. Sends a `heartbeat` frame every [`HEARTBEAT_INTERVAL_SECS`] so the
|
|
//! gateway can mark the connection live without extra HTTP polls.
|
|
//! 5. Forwards each [`PermissionForward`] as a `perm_request` JSON envelope
|
|
//! and awaits the matching `perm_response`.
|
|
//! 6. Handles inbound `mcp_request` frames by replaying the MCP JSON-RPC
|
|
//! body against the sled's own local `/mcp` HTTP endpoint and returning
|
|
//! the response as an `mcp_response` frame.
|
|
//! 7. Reconnects with exponential back-off on connection drop, fail-closing
|
|
//! any in-flight permission requests with [`PermissionDecision::Deny`].
|
|
|
|
use crate::http::context::{PermissionDecision, PermissionForward};
|
|
use crate::services::Services;
|
|
use crate::slog;
|
|
use futures::SinkExt as _;
|
|
use futures::StreamExt as _;
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use tokio::sync::oneshot;
|
|
use tokio_tungstenite::tungstenite::Message as WsMessage;
|
|
|
|
/// Settings consumed by [`spawn_uplink_task`].
///
/// Grouped into one struct so `main.rs` can assemble the values once and hand
/// them over as a single argument rather than a growing list of positional
/// strings.
pub struct UplinkConfig {
    /// WebSocket URL of the upstream gateway's `/api/sled-uplink` endpoint,
    /// including the `?token=` query parameter used for auth. An empty string
    /// means "no upstream configured" and disables the uplink.
    pub upstream_url: String,
    /// Project name this sled identifies as. It is announced in the
    /// `identity` frame sent right after the WebSocket connects
    /// (story 899 AC 5).
    pub project_name: String,
    /// HTTP URL of this sled's own MCP endpoint (e.g.
    /// `http://127.0.0.1:3001/mcp`). `mcp_request` frames received from the
    /// gateway are replayed against this local handler.
    pub local_mcp_url: String,
}
|
|
|
|
// ── Back-off constants ────────────────────────────────────────────────────────
|
|
|
|
/// Delay, in seconds, before the first reconnect attempt.
const INITIAL_BACKOFF_SECS: u64 = 1;
/// Upper bound, in seconds, that the reconnect delay is clamped to.
const MAX_BACKOFF_SECS: u64 = 60;
/// Factor the reconnect delay is multiplied by after every attempt.
const BACKOFF_MULTIPLIER: u64 = 2;
|
|
|
|
/// Seconds between consecutive `heartbeat` frames sent to the gateway
/// (story 899 AC 3).
///
/// Keep this below `crate::service::gateway::HEARTBEAT_MAX_AGE_MS`, otherwise
/// the gateway would mark a healthy sled as stale between two heartbeats.
pub const HEARTBEAT_INTERVAL_SECS: u64 = 10;
|
|
|
|
// ── Wire protocol ─────────────────────────────────────────────────────────────
|
|
|
|
/// Extensible JSON envelope for all sled↔gateway uplink messages.
|
|
///
|
|
/// Phase 1 defines `perm_request` (sled→gateway) and `perm_response`
|
|
/// (gateway→sled). Future phases add new `type` values without changing this
|
|
/// framing.
|
|
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
|
|
pub struct UplinkEnvelope {
|
|
/// Message type discriminant (e.g. `"perm_request"`, `"perm_response"`).
|
|
#[serde(rename = "type")]
|
|
pub msg_type: String,
|
|
/// Correlation ID — the sled chooses this for each request; the gateway
|
|
/// echoes it back so the sled can demux concurrent responses.
|
|
pub req_id: String,
|
|
/// Message-specific payload. Varies by `msg_type`.
|
|
pub payload: serde_json::Value,
|
|
}
|
|
|
|
// ── Public API ────────────────────────────────────────────────────────────────
|
|
|
|
/// Spawn the sled uplink background task.
|
|
///
|
|
/// Does nothing when `config.upstream_url` is empty (AC 8 — sleds without
|
|
/// an upstream configured continue to work unchanged). When active, the task
|
|
/// holds `services.perm_rx` locked for its lifetime (preventing auto-deny in
|
|
/// `tool_prompt_permission`) and forwards all permission requests to the
|
|
/// gateway. Reconnects automatically with exponential back-off.
|
|
pub fn spawn_uplink_task(config: UplinkConfig, services: Arc<Services>) {
|
|
if config.upstream_url.is_empty() {
|
|
return;
|
|
}
|
|
let UplinkConfig {
|
|
upstream_url,
|
|
project_name,
|
|
local_mcp_url,
|
|
} = config;
|
|
slog!("[uplink] Spawning sled uplink task (gateway={upstream_url}, project={project_name})");
|
|
tokio::spawn(async move {
|
|
// Acquire perm_rx for this task's entire lifetime. While this lock is
|
|
// held, try_lock() inside tool_prompt_permission fails — meaning
|
|
// requests flow to perm_tx (which we drain here) rather than auto-deny.
|
|
let mut perm_rx = services.perm_rx.lock().await;
|
|
slog!("[uplink] Acquired perm_rx; maintaining gateway connection");
|
|
|
|
let http = reqwest::Client::new();
|
|
|
|
let mut backoff = INITIAL_BACKOFF_SECS;
|
|
loop {
|
|
match run_uplink_session(
|
|
&upstream_url,
|
|
&project_name,
|
|
&local_mcp_url,
|
|
&http,
|
|
&mut perm_rx,
|
|
)
|
|
.await
|
|
{
|
|
Ok(()) => {
|
|
slog!("[uplink] Connection closed cleanly; reconnecting in {backoff}s");
|
|
}
|
|
Err(ref e) => {
|
|
slog!("[uplink] Session error: {e}; reconnecting in {backoff}s");
|
|
}
|
|
}
|
|
tokio::time::sleep(std::time::Duration::from_secs(backoff)).await;
|
|
backoff = backoff
|
|
.saturating_mul(BACKOFF_MULTIPLIER)
|
|
.min(MAX_BACKOFF_SECS);
|
|
}
|
|
});
|
|
}
|
|
|
|
// ── Private helpers ───────────────────────────────────────────────────────────
|
|
|
|
/// Run a single uplink session: connect, send identity frame, pump messages
|
|
/// bidirectionally until disconnect or channel close, then fail-close any
|
|
/// in-flight requests.
|
|
async fn run_uplink_session(
|
|
url: &str,
|
|
project_name: &str,
|
|
local_mcp_url: &str,
|
|
http: &reqwest::Client,
|
|
perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver<PermissionForward>,
|
|
) -> Result<(), String> {
|
|
let (ws_stream, _) = tokio_tungstenite::connect_async(url)
|
|
.await
|
|
.map_err(|e| format!("WS connect to {url}: {e}"))?;
|
|
slog!("[uplink] Connected to gateway uplink endpoint");
|
|
|
|
let (mut ws_sink, mut ws_rx) = ws_stream.split();
|
|
|
|
// ── Identity handshake (story 899 AC 5) ─────────────────────────────
|
|
let identity = UplinkEnvelope {
|
|
msg_type: "identity".to_string(),
|
|
req_id: "identity".to_string(),
|
|
payload: serde_json::json!({ "project": project_name }),
|
|
};
|
|
let identity_text =
|
|
serde_json::to_string(&identity).map_err(|e| format!("serialise identity: {e}"))?;
|
|
ws_sink
|
|
.send(WsMessage::Text(identity_text.into()))
|
|
.await
|
|
.map_err(|e| format!("send identity: {e}"))?;
|
|
|
|
let mut in_flight: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();
|
|
|
|
let result = pump_messages(
|
|
&mut ws_sink,
|
|
&mut ws_rx,
|
|
perm_rx,
|
|
&mut in_flight,
|
|
local_mcp_url,
|
|
http,
|
|
)
|
|
.await;
|
|
fail_close_all(&mut in_flight);
|
|
result
|
|
}
|
|
|
|
/// Drive the bidirectional message loop for one session.
|
|
async fn pump_messages(
|
|
ws_sink: &mut (impl futures::Sink<WsMessage, Error = tokio_tungstenite::tungstenite::Error> + Unpin),
|
|
ws_rx: &mut (
|
|
impl futures::Stream<Item = Result<WsMessage, tokio_tungstenite::tungstenite::Error>>
|
|
+ Unpin
|
|
),
|
|
perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver<PermissionForward>,
|
|
in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>,
|
|
local_mcp_url: &str,
|
|
http: &reqwest::Client,
|
|
) -> Result<(), String> {
|
|
// Channel for spawned mcp_request handlers to deliver their finished
|
|
// mcp_response frames back to the WS writer.
|
|
let (mcp_resp_tx, mut mcp_resp_rx) = tokio::sync::mpsc::unbounded_channel::<UplinkEnvelope>();
|
|
|
|
let mut heartbeat =
|
|
tokio::time::interval(std::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
|
|
// Skip the immediate tick so the first heartbeat fires after the interval,
|
|
// not right after connect.
|
|
heartbeat.tick().await;
|
|
|
|
loop {
|
|
tokio::select! {
|
|
// Heartbeat tick (story 899 AC 3).
|
|
_ = heartbeat.tick() => {
|
|
let env = UplinkEnvelope {
|
|
msg_type: "heartbeat".to_string(),
|
|
req_id: String::new(),
|
|
payload: serde_json::Value::Null,
|
|
};
|
|
let text = serde_json::to_string(&env)
|
|
.map_err(|e| format!("serialise heartbeat: {e}"))?;
|
|
if ws_sink.send(WsMessage::Text(text.into())).await.is_err() {
|
|
return Err("WS send failed (heartbeat)".to_string());
|
|
}
|
|
}
|
|
|
|
// Completed mcp_response from a spawned handler — forward to gateway.
|
|
Some(env) = mcp_resp_rx.recv() => {
|
|
let text = serde_json::to_string(&env)
|
|
.map_err(|e| format!("serialise mcp_response: {e}"))?;
|
|
if ws_sink.send(WsMessage::Text(text.into())).await.is_err() {
|
|
return Err("WS send failed (mcp_response)".to_string());
|
|
}
|
|
}
|
|
|
|
// New permission request from the MCP layer.
|
|
maybe_fwd = perm_rx.recv() => {
|
|
match maybe_fwd {
|
|
None => return Ok(()), // channel closed — server shutting down
|
|
Some(fwd) => {
|
|
let PermissionForward {
|
|
request_id,
|
|
tool_name,
|
|
tool_input,
|
|
response_tx,
|
|
} = fwd;
|
|
let env = UplinkEnvelope {
|
|
msg_type: "perm_request".to_string(),
|
|
req_id: request_id.clone(),
|
|
payload: serde_json::json!({
|
|
"tool_name": tool_name,
|
|
"tool_input": tool_input,
|
|
}),
|
|
};
|
|
let text = serde_json::to_string(&env)
|
|
.map_err(|e| format!("serialise perm_request: {e}"))?;
|
|
if ws_sink.send(WsMessage::Text(text.into())).await.is_err() {
|
|
// Connection dead: fail-close this request immediately.
|
|
let _ = response_tx.send(PermissionDecision::Deny);
|
|
return Err("WS send failed".to_string());
|
|
}
|
|
in_flight.insert(request_id, response_tx);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Message arriving from the gateway.
|
|
msg = ws_rx.next() => {
|
|
match msg {
|
|
None | Some(Err(_)) => {
|
|
return Err("WS stream closed".to_string());
|
|
}
|
|
Some(Ok(WsMessage::Close(_))) => {
|
|
return Err("Gateway sent Close frame".to_string());
|
|
}
|
|
Some(Ok(WsMessage::Text(text))) => {
|
|
on_gateway_text(&text, in_flight, local_mcp_url, http, &mcp_resp_tx);
|
|
}
|
|
Some(Ok(WsMessage::Ping(data))) => {
|
|
let _ = ws_sink.send(WsMessage::Pong(data)).await;
|
|
}
|
|
Some(Ok(_)) => {}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Parse an incoming gateway text frame and dispatch it to the appropriate
|
|
/// handler. `perm_response` resolves a waiting in-flight permission request;
|
|
/// `mcp_request` spawns a task that replays the JSON-RPC body against the
|
|
/// sled's local `/mcp` HTTP endpoint and forwards the result back to the
|
|
/// gateway as an `mcp_response`.
|
|
fn on_gateway_text(
|
|
text: &str,
|
|
in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>,
|
|
local_mcp_url: &str,
|
|
http: &reqwest::Client,
|
|
mcp_resp_tx: &tokio::sync::mpsc::UnboundedSender<UplinkEnvelope>,
|
|
) {
|
|
let Ok(env) = serde_json::from_str::<UplinkEnvelope>(text) else {
|
|
return;
|
|
};
|
|
match env.msg_type.as_str() {
|
|
"perm_response" => resolve_perm_response(env, in_flight),
|
|
"mcp_request" => spawn_mcp_request_handler(env, local_mcp_url, http, mcp_resp_tx),
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
/// Replay the gateway's `mcp_request` body against the sled's local MCP HTTP
|
|
/// endpoint in a spawned task and forward the response back as an
|
|
/// `mcp_response` envelope. The payload is expected to contain a `body`
|
|
/// string field holding the raw JSON-RPC bytes.
|
|
fn spawn_mcp_request_handler(
|
|
env: UplinkEnvelope,
|
|
local_mcp_url: &str,
|
|
http: &reqwest::Client,
|
|
mcp_resp_tx: &tokio::sync::mpsc::UnboundedSender<UplinkEnvelope>,
|
|
) {
|
|
let req_id = env.req_id.clone();
|
|
let body = env
|
|
.payload
|
|
.get("body")
|
|
.and_then(|v| v.as_str())
|
|
.unwrap_or("")
|
|
.to_string();
|
|
let mcp_url = local_mcp_url.to_string();
|
|
let http = http.clone();
|
|
let mcp_resp_tx = mcp_resp_tx.clone();
|
|
tokio::spawn(async move {
|
|
let response_value = match http
|
|
.post(&mcp_url)
|
|
.header("Content-Type", "application/json")
|
|
.body(body)
|
|
.send()
|
|
.await
|
|
{
|
|
Ok(r) => match r.json::<serde_json::Value>().await {
|
|
Ok(v) => v,
|
|
Err(e) => serde_json::json!({
|
|
"jsonrpc": "2.0",
|
|
"id": null,
|
|
"error": {
|
|
"code": -32603,
|
|
"message": format!("local mcp invalid JSON: {e}"),
|
|
}
|
|
}),
|
|
},
|
|
Err(e) => serde_json::json!({
|
|
"jsonrpc": "2.0",
|
|
"id": null,
|
|
"error": {
|
|
"code": -32603,
|
|
"message": format!("local mcp request failed: {e}"),
|
|
}
|
|
}),
|
|
};
|
|
let resp = UplinkEnvelope {
|
|
msg_type: "mcp_response".to_string(),
|
|
req_id,
|
|
payload: response_value,
|
|
};
|
|
let _ = mcp_resp_tx.send(resp);
|
|
});
|
|
}
|
|
|
|
/// Map a `perm_response` envelope to a [`PermissionDecision`] and wake the
|
|
/// waiting MCP call.
|
|
fn resolve_perm_response(
|
|
env: UplinkEnvelope,
|
|
in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>,
|
|
) {
|
|
let Some(tx) = in_flight.remove(&env.req_id) else {
|
|
return;
|
|
};
|
|
let approved = env
|
|
.payload
|
|
.get("approved")
|
|
.and_then(|v| v.as_bool())
|
|
.unwrap_or(false);
|
|
let always_allow = env
|
|
.payload
|
|
.get("always_allow")
|
|
.and_then(|v| v.as_bool())
|
|
.unwrap_or(false);
|
|
let decision = if always_allow {
|
|
PermissionDecision::AlwaysAllow
|
|
} else if approved {
|
|
PermissionDecision::Approve
|
|
} else {
|
|
PermissionDecision::Deny
|
|
};
|
|
let _ = tx.send(decision);
|
|
}
|
|
|
|
/// Deny all in-flight requests (fail-closed on connection drop — AC 8).
|
|
fn fail_close_all(in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>) {
|
|
for (_, tx) in in_flight.drain() {
|
|
let _ = tx.send(PermissionDecision::Deny);
|
|
}
|
|
}
|
|
|
|
// ── Tests ─────────────────────────────────────────────────────────────────────
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::context::PermissionForward;
    use crate::services::Services;
    use std::collections::HashMap;
    use tokio::net::TcpListener;
    use tokio::sync::oneshot;
    use tokio_tungstenite::tungstenite::Message as WsMessage;

    // ── Shared fixtures ───────────────────────────────────────────────

    /// Map of pending permission requests, as used by the code under test.
    type Pending = HashMap<String, oneshot::Sender<PermissionDecision>>;

    /// Build a `Services` value whose permission channel is `perm_rx`.
    fn test_services(
        perm_rx: tokio::sync::mpsc::UnboundedReceiver<PermissionForward>,
    ) -> Arc<Services> {
        let agents = Arc::new(crate::agents::AgentPool::new_test(3000));
        Arc::new(Services {
            project_root: std::path::PathBuf::from("/tmp"),
            status: agents.status_broadcaster(),
            agents,
            bot_name: "Test".to_string(),
            bot_user_id: String::new(),
            ambient_rooms: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
            perm_rx: Arc::new(tokio::sync::Mutex::new(perm_rx)),
            pending_perm_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
            permission_timeout_secs: 120,
        })
    }

    /// Uplink configuration pointing at `upstream_url` with a dummy MCP URL.
    fn uplink_config(upstream_url: String, project_name: &str) -> UplinkConfig {
        UplinkConfig {
            upstream_url,
            project_name: project_name.to_string(),
            local_mcp_url: "http://127.0.0.1:0/mcp".to_string(),
        }
    }

    /// A `perm_response` envelope for `req_id` with the given payload.
    fn perm_response(req_id: &str, payload: serde_json::Value) -> UplinkEnvelope {
        UplinkEnvelope {
            msg_type: "perm_response".to_string(),
            req_id: req_id.to_string(),
            payload,
        }
    }

    /// An in-flight map holding exactly one waiter, plus its receiving end.
    fn one_pending(req_id: &str) -> (Pending, oneshot::Receiver<PermissionDecision>) {
        let (tx, rx) = oneshot::channel();
        let mut in_flight = Pending::new();
        in_flight.insert(req_id.to_string(), tx);
        (in_flight, rx)
    }

    /// Feed `frame` to `on_gateway_text` with nothing pending; return the map.
    fn dispatch_without_pending(frame: &str) -> Pending {
        let mut in_flight = Pending::new();
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        on_gateway_text(
            frame,
            &mut in_flight,
            "http://127.0.0.1:0/mcp",
            &reqwest::Client::new(),
            &tx,
        );
        in_flight
    }

    /// Decode a text frame into an envelope; panic on any other frame kind.
    fn expect_text_envelope(msg: WsMessage, label: &str) -> UplinkEnvelope {
        let text = match msg {
            WsMessage::Text(t) => t.to_string(),
            other => panic!("expected {label}; got {other:?}"),
        };
        serde_json::from_str(&text).unwrap()
    }

    /// The text frame a mock gateway sends to approve request `req_id` once.
    fn approval_frame(req_id: String) -> WsMessage {
        let reply = UplinkEnvelope {
            msg_type: "perm_response".to_string(),
            req_id,
            payload: serde_json::json!({"approved": true, "always_allow": false}),
        };
        WsMessage::Text(serde_json::to_string(&reply).unwrap().into())
    }

    /// A permission request plus the receiver its decision will arrive on.
    fn forward(
        request_id: &str,
        tool_name: &str,
        tool_input: serde_json::Value,
    ) -> (PermissionForward, oneshot::Receiver<PermissionDecision>) {
        let (response_tx, response_rx) = oneshot::channel();
        let fwd = PermissionForward {
            request_id: request_id.to_string(),
            tool_name: tool_name.to_string(),
            tool_input,
            response_tx,
        };
        (fwd, response_rx)
    }

    // ── Pure unit tests ───────────────────────────────────────────────

    #[test]
    fn uplink_envelope_roundtrips_json() {
        let original = UplinkEnvelope {
            msg_type: "perm_request".to_string(),
            req_id: "req-1".to_string(),
            payload: serde_json::json!({"tool_name": "Bash", "tool_input": {}}),
        };
        let wire = serde_json::to_string(&original).unwrap();
        let decoded: UplinkEnvelope = serde_json::from_str(&wire).unwrap();
        assert_eq!(decoded.msg_type, "perm_request");
        assert_eq!(decoded.req_id, "req-1");
        assert_eq!(decoded.payload["tool_name"], "Bash");
    }

    #[test]
    fn resolve_perm_response_approve() {
        let (mut in_flight, rx) = one_pending("r1");
        let env = perm_response(
            "r1",
            serde_json::json!({"approved": true, "always_allow": false}),
        );
        resolve_perm_response(env, &mut in_flight);
        assert!(in_flight.is_empty());
        assert_eq!(rx.blocking_recv().unwrap(), PermissionDecision::Approve);
    }

    #[test]
    fn resolve_perm_response_deny() {
        let (mut in_flight, rx) = one_pending("r2");
        let env = perm_response("r2", serde_json::json!({"approved": false}));
        resolve_perm_response(env, &mut in_flight);
        assert_eq!(rx.blocking_recv().unwrap(), PermissionDecision::Deny);
    }

    #[test]
    fn resolve_perm_response_always_allow() {
        let (mut in_flight, rx) = one_pending("r3");
        let env = perm_response(
            "r3",
            serde_json::json!({"approved": true, "always_allow": true}),
        );
        resolve_perm_response(env, &mut in_flight);
        assert_eq!(rx.blocking_recv().unwrap(), PermissionDecision::AlwaysAllow);
    }

    #[test]
    fn resolve_perm_response_unknown_req_id_is_noop() {
        let mut in_flight = Pending::new();
        let env = perm_response("missing", serde_json::json!({"approved": true}));
        resolve_perm_response(env, &mut in_flight); // must not panic
    }

    #[test]
    fn fail_close_all_denies_all_pending() {
        let (tx1, rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();
        let mut in_flight = Pending::new();
        in_flight.insert("a".to_string(), tx1);
        in_flight.insert("b".to_string(), tx2);
        fail_close_all(&mut in_flight);
        assert!(in_flight.is_empty());
        assert_eq!(rx1.blocking_recv().unwrap(), PermissionDecision::Deny);
        assert_eq!(rx2.blocking_recv().unwrap(), PermissionDecision::Deny);
    }

    #[test]
    fn on_gateway_text_ignores_unknown_type() {
        let in_flight =
            dispatch_without_pending(r#"{"type":"future_type","req_id":"x","payload":{}}"#);
        assert!(in_flight.is_empty());
    }

    #[test]
    fn on_gateway_text_ignores_invalid_json() {
        let in_flight = dispatch_without_pending("not-json"); // must not panic
        assert!(in_flight.is_empty());
    }

    #[test]
    fn spawn_uplink_task_noop_when_url_empty() {
        let (_perm_tx, perm_rx) = tokio::sync::mpsc::unbounded_channel();
        let services = test_services(perm_rx);
        // Empty URL → noop; if it panicked or blocked the test would fail.
        spawn_uplink_task(uplink_config(String::new(), "test"), services);
    }

    // ── AC 11: permission approved via uplink ────────────────────────
    // "Simulate matrix bot triggering a Bash permission, sled forwards via
    // uplink, mock matrix transport approves, tool call proceeds."

    #[tokio::test]
    async fn integration_perm_request_approved_via_uplink() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let url = format!("ws://127.0.0.1:{port}");

        // Mock gateway: accept one connection, consume identity, receive
        // perm_request, reply approved.
        let gw_task = tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();

            // First frame: identity (story 899 AC 5).
            let identity =
                expect_text_envelope(ws.next().await.unwrap().unwrap(), "identity Text");
            assert_eq!(identity.msg_type, "identity");
            assert_eq!(identity.payload["project"], "test-proj");

            // Next frame: perm_request.
            let request = expect_text_envelope(ws.next().await.unwrap().unwrap(), "Text");
            assert_eq!(request.msg_type, "perm_request");
            assert_eq!(request.payload["tool_name"], "Bash");
            ws.send(approval_frame(request.req_id)).await.unwrap();
        });

        let (perm_tx, perm_rx) = tokio::sync::mpsc::unbounded_channel::<PermissionForward>();
        let services = test_services(perm_rx);

        spawn_uplink_task(uplink_config(url, "test-proj"), Arc::clone(&services));
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        let (fwd, response_rx) = forward(
            "req-test-1",
            "Bash",
            serde_json::json!({"command": "echo hello"}),
        );
        perm_tx.send(fwd).unwrap();

        let decision = tokio::time::timeout(std::time::Duration::from_secs(5), response_rx)
            .await
            .expect("timeout waiting for decision")
            .expect("oneshot dropped");

        assert_eq!(decision, PermissionDecision::Approve);
        gw_task.await.unwrap();
    }

    // ── AC 12: sled disconnects and reconnects ────────────────────────
    // "Sled disconnects and reconnects mid-session; subsequent permission
    // requests succeed once reconnected."

    #[tokio::test]
    async fn integration_reconnects_after_disconnect() {
        use std::sync::Arc as StdArc;
        use std::sync::atomic::{AtomicU32, Ordering};

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let url = format!("ws://127.0.0.1:{port}");

        let conn_count = StdArc::new(AtomicU32::new(0));
        let conn_count2 = StdArc::clone(&conn_count);

        tokio::spawn(async move {
            // First connection: consume identity + the request frame, then
            // drop without replying (simulates network failure).
            {
                let (tcp, _) = listener.accept().await.unwrap();
                let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
                conn_count2.fetch_add(1, Ordering::SeqCst);
                let _ = ws.next().await; // identity frame
                let _ = ws.next().await; // perm_request frame
                drop(ws); // close without replying → fail-close on sled side
            }
            // Second connection: approve the next request.
            {
                let (tcp, _) = listener.accept().await.unwrap();
                let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
                conn_count2.fetch_add(1, Ordering::SeqCst);
                // Consume identity frame.
                let _ = ws.next().await;
                // Drain non-perm frames (heartbeat etc.) until perm_request arrives.
                while let Some(Ok(WsMessage::Text(text))) = ws.next().await {
                    let env: UplinkEnvelope = serde_json::from_str(&text).unwrap();
                    if env.msg_type == "perm_request" {
                        let _ = ws.send(approval_frame(env.req_id)).await;
                        break;
                    }
                }
            }
        });

        let (perm_tx, perm_rx) = tokio::sync::mpsc::unbounded_channel::<PermissionForward>();
        let services = test_services(perm_rx);

        spawn_uplink_task(uplink_config(url, "test-proj"), Arc::clone(&services));
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // First request is sent on the connection that drops → denied.
        let (first, rx1) = forward("req-drop", "Bash", serde_json::json!({}));
        perm_tx.send(first).unwrap();

        let d1 = tokio::time::timeout(std::time::Duration::from_secs(5), rx1)
            .await
            .expect("timeout on first request")
            .expect("oneshot dropped");
        assert_eq!(
            d1,
            PermissionDecision::Deny,
            "dropped connection must fail-close"
        );

        // Wait for the 1-second reconnect backoff plus buffer.
        tokio::time::sleep(std::time::Duration::from_millis(1500)).await;

        // Second request arrives on the reconnected session → approved.
        let (second, rx2) = forward("req-reconnect", "Write", serde_json::json!({}));
        perm_tx.send(second).unwrap();

        let d2 = tokio::time::timeout(std::time::Duration::from_secs(5), rx2)
            .await
            .expect("timeout on second request")
            .expect("oneshot dropped");
        assert_eq!(
            d2,
            PermissionDecision::Approve,
            "reconnected session must approve"
        );

        assert_eq!(
            conn_count.load(Ordering::SeqCst),
            2,
            "must have seen exactly 2 connections"
        );
    }
}
|