huskies: merge 898
This commit is contained in:
@@ -21,6 +21,12 @@ pub(crate) struct CliArgs {
|
||||
pub(crate) join_token: Option<String>,
|
||||
/// HTTP URL of the gateway to register with when a join token is provided (`--gateway-url`).
|
||||
pub(crate) gateway_url: Option<String>,
|
||||
/// WebSocket URL of the upstream gateway to forward permission requests to (`--upstream-gateway`).
|
||||
///
|
||||
/// When set, the sled spawns a background uplink task that holds `perm_rx` and
|
||||
/// forwards all `prompt_permission` tool calls to the gateway over a WebSocket.
|
||||
/// Also readable from the `HUSKIES_UPSTREAM_GATEWAY` env var.
|
||||
pub(crate) upstream_gateway: Option<String>,
|
||||
}
|
||||
|
||||
/// Parse CLI arguments into `CliArgs`, or exit early for `--help` / `--version`.
|
||||
@@ -34,6 +40,7 @@ pub(crate) fn parse_cli_args(args: &[String]) -> Result<CliArgs, String> {
|
||||
let mut rendezvous: Option<String> = None;
|
||||
let mut join_token: Option<String> = None;
|
||||
let mut gateway_url: Option<String> = None;
|
||||
let mut upstream_gateway: Option<String> = None;
|
||||
let mut i = 0;
|
||||
|
||||
while i < args.len() {
|
||||
@@ -94,6 +101,16 @@ pub(crate) fn parse_cli_args(args: &[String]) -> Result<CliArgs, String> {
|
||||
a if a.starts_with("--gateway-url=") => {
|
||||
gateway_url = Some(a["--gateway-url=".len()..].to_string());
|
||||
}
|
||||
"--upstream-gateway" => {
|
||||
i += 1;
|
||||
if i >= args.len() {
|
||||
return Err("--upstream-gateway requires a value".to_string());
|
||||
}
|
||||
upstream_gateway = Some(args[i].clone());
|
||||
}
|
||||
a if a.starts_with("--upstream-gateway=") => {
|
||||
upstream_gateway = Some(a["--upstream-gateway=".len()..].to_string());
|
||||
}
|
||||
"--gateway" => {
|
||||
gateway = true;
|
||||
}
|
||||
@@ -129,6 +146,7 @@ pub(crate) fn parse_cli_args(args: &[String]) -> Result<CliArgs, String> {
|
||||
gateway,
|
||||
join_token,
|
||||
gateway_url,
|
||||
upstream_gateway,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -167,6 +185,11 @@ pub(crate) fn print_help() {
|
||||
println!(" --gateway-url <URL> HTTP URL of the gateway to register with when");
|
||||
println!(" --join-token is provided (agent mode only).");
|
||||
println!(" Also readable from HUSKIES_GATEWAY_URL env var.");
|
||||
println!(" --upstream-gateway <URL> WebSocket URL of an upstream gateway to forward");
|
||||
println!(" permission requests to (sled mode). When set, the");
|
||||
println!(" sled connects to WS URL and forwards all");
|
||||
println!(" prompt_permission calls via the uplink protocol.");
|
||||
println!(" Also readable from HUSKIES_UPSTREAM_GATEWAY env var.");
|
||||
}
|
||||
|
||||
/// Resolve the optional positional path argument into an absolute `PathBuf`.
|
||||
@@ -343,6 +366,36 @@ mod tests {
|
||||
let result = parse_cli_args(&[]).unwrap();
|
||||
assert_eq!(result.join_token, None);
|
||||
assert_eq!(result.gateway_url, None);
|
||||
assert_eq!(result.upstream_gateway, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_upstream_gateway_flag() {
|
||||
let args = vec![
|
||||
"--upstream-gateway".to_string(),
|
||||
"ws://gateway:3001/api/sled-uplink?token=abc".to_string(),
|
||||
];
|
||||
let result = parse_cli_args(&args).unwrap();
|
||||
assert_eq!(
|
||||
result.upstream_gateway,
|
||||
Some("ws://gateway:3001/api/sled-uplink?token=abc".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_upstream_gateway_equals_syntax() {
|
||||
let args = vec!["--upstream-gateway=ws://gw:3001/api/sled-uplink?token=x".to_string()];
|
||||
let result = parse_cli_args(&args).unwrap();
|
||||
assert_eq!(
|
||||
result.upstream_gateway,
|
||||
Some("ws://gw:3001/api/sled-uplink?token=x".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_upstream_gateway_missing_value_is_error() {
|
||||
let args = vec!["--upstream-gateway".to_string()];
|
||||
assert!(parse_cli_args(&args).is_err());
|
||||
}
|
||||
|
||||
// ── resolve_path_arg ────────────────────────────────────────────
|
||||
|
||||
@@ -55,6 +55,8 @@ pub fn build_gateway_route(state_arc: Arc<GatewayState>) -> impl poem::Endpoint
|
||||
)
|
||||
// Agent registration via CRDT-sync WebSocket.
|
||||
.at("/crdt-sync", poem::get(gateway_crdt_sync_handler))
|
||||
// Sled uplink: permission-forwarding WebSocket from sleds to gateway.
|
||||
.at("/api/sled-uplink", poem::get(gateway_sled_uplink_handler))
|
||||
// Agent management REST endpoints.
|
||||
.at(
|
||||
"/gateway/agents/:id/assign",
|
||||
@@ -126,6 +128,7 @@ pub async fn run(config_path: &Path, port: u16) -> Result<(), std::io::Error> {
|
||||
gateway_project_urls,
|
||||
port,
|
||||
Some(state_arc.event_tx.clone()),
|
||||
Arc::clone(&state_arc.perm_rx),
|
||||
);
|
||||
*state_arc.bot_handle.lock().await = bot_abort;
|
||||
*state_arc.bot_shutdown_tx.lock().await = Some(bot_shutdown_tx);
|
||||
|
||||
@@ -13,7 +13,10 @@ fn make_test_state() -> Arc<GatewayState> {
|
||||
url: "http://test:3001".into(),
|
||||
},
|
||||
);
|
||||
let config = GatewayConfig { projects };
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
};
|
||||
Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap())
|
||||
}
|
||||
|
||||
@@ -368,7 +371,10 @@ async fn init_project_registers_in_projects_toml_when_name_and_url_given() {
|
||||
url: "http://existing:3001".into(),
|
||||
},
|
||||
);
|
||||
let config = GatewayConfig { projects };
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
};
|
||||
let state = Arc::new(GatewayState::new(config, config_dir.path().to_path_buf(), 3000).unwrap());
|
||||
|
||||
let result = gateway::init_project(
|
||||
@@ -395,7 +401,10 @@ async fn init_project_duplicate_name_returns_error() {
|
||||
url: "http://taken:3001".into(),
|
||||
},
|
||||
);
|
||||
let config = GatewayConfig { projects };
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
};
|
||||
let state = Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap());
|
||||
|
||||
let result = gateway::init_project(
|
||||
@@ -444,7 +453,10 @@ async fn init_project_then_wizard_status_integration() {
|
||||
|
||||
let mut projects = BTreeMap::new();
|
||||
projects.insert("mock-project".into(), ProjectEntry { url: mock_url });
|
||||
let config = GatewayConfig { projects };
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
};
|
||||
let config_dir = tempfile::tempdir().unwrap();
|
||||
let state = Arc::new(GatewayState::new(config, config_dir.path().to_path_buf(), 3000).unwrap());
|
||||
|
||||
@@ -966,7 +978,10 @@ async fn gateway_mcp_sse_proxy_streams_progress_and_final_response() {
|
||||
url: mock_sled.url(),
|
||||
},
|
||||
);
|
||||
let config = GatewayConfig { projects };
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
};
|
||||
let state = Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap());
|
||||
|
||||
let app = poem::Route::new()
|
||||
@@ -1059,7 +1074,10 @@ async fn gateway_mcp_post_without_sse_returns_plain_json() {
|
||||
url: mock_sled.url(),
|
||||
},
|
||||
);
|
||||
let config = GatewayConfig { projects };
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
};
|
||||
let state = Arc::new(GatewayState::new(config, PathBuf::new(), 3000).unwrap());
|
||||
|
||||
let app = poem::Route::new()
|
||||
|
||||
@@ -22,4 +22,6 @@ pub use rest::{
|
||||
gateway_bot_config_save_handler, gateway_generate_token_handler, gateway_mode_handler,
|
||||
gateway_remove_project_handler,
|
||||
};
|
||||
pub use websocket::{gateway_crdt_sync_handler, gateway_event_push_handler};
|
||||
pub use websocket::{
|
||||
gateway_crdt_sync_handler, gateway_event_push_handler, gateway_sled_uplink_handler,
|
||||
};
|
||||
|
||||
@@ -146,6 +146,134 @@ pub async fn gateway_crdt_sync_handler(
|
||||
.into_response()
|
||||
}
|
||||
|
||||
// ── Sled uplink WebSocket handler ────────────────────────────────────────────
|
||||
|
||||
/// Query parameters accepted on the `/api/sled-uplink` WebSocket upgrade.
|
||||
#[derive(Deserialize)]
|
||||
struct SledUplinkParams {
|
||||
/// Shared-secret token identifying the connecting sled (from `[sled_tokens]` in `projects.toml`).
|
||||
token: Option<String>,
|
||||
}
|
||||
|
||||
/// `GET /api/sled-uplink` — gateway-side WebSocket endpoint for sled permission uplinks.
|
||||
///
|
||||
/// # Authentication
|
||||
///
|
||||
/// The connecting sled must supply a valid shared-secret token via the `token`
|
||||
/// query parameter. Tokens are configured in `[sled_tokens]` in `projects.toml`
|
||||
/// as `sled_id = "secret"` entries.
|
||||
///
|
||||
/// # Protocol
|
||||
///
|
||||
/// See `sled_uplink.rs` for the wire format ([`UplinkEnvelope`]). The gateway
|
||||
/// accepts `perm_request` messages, injects them into the local permission
|
||||
/// pipeline (via `state.perm_tx`), and sends `perm_response` frames back to the
|
||||
/// sled once the Matrix bot resolves them. Multiple sleds are demuxed by
|
||||
/// connection: each handler owns exactly one sled's request/response flow.
|
||||
#[handler]
|
||||
pub async fn gateway_sled_uplink_handler(
|
||||
ws: WebSocket,
|
||||
state: Data<&Arc<GatewayState>>,
|
||||
Query(params): Query<SledUplinkParams>,
|
||||
) -> poem::Response {
|
||||
let token = match params.token {
|
||||
Some(t) if !t.is_empty() => t,
|
||||
_ => {
|
||||
return poem::Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body("token query parameter required");
|
||||
}
|
||||
};
|
||||
|
||||
let sled_id = match state.sled_tokens.get(&token) {
|
||||
Some(id) => id.clone(),
|
||||
None => {
|
||||
return poem::Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body("invalid token");
|
||||
}
|
||||
};
|
||||
|
||||
use poem::IntoResponse as _;
|
||||
let perm_tx = state.perm_tx.clone();
|
||||
ws.on_upgrade(move |socket| async move {
|
||||
let (mut sink, mut stream) = socket.split();
|
||||
// Aggregator channel: spawned per-request tasks send (req_id, decision) here
|
||||
// so the main loop can write perm_response frames back to the sled.
|
||||
let (agg_tx, mut agg_rx) = tokio::sync::mpsc::unbounded_channel::<(
|
||||
String,
|
||||
crate::http::context::PermissionDecision,
|
||||
)>();
|
||||
|
||||
crate::slog!("[gateway/sled-uplink] Sled '{}' connected", sled_id);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = stream.next() => {
|
||||
let text = match msg {
|
||||
Some(Ok(WsMessage::Text(t))) => t,
|
||||
Some(Ok(WsMessage::Close(_))) | None => break,
|
||||
_ => continue,
|
||||
};
|
||||
let Ok(env) = serde_json::from_str::<crate::sled_uplink::UplinkEnvelope>(&text) else {
|
||||
continue;
|
||||
};
|
||||
if env.msg_type == "perm_request" {
|
||||
let req_id = env.req_id.clone();
|
||||
let tool_name = env.payload.get("tool_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
let tool_input = env.payload.get("tool_input")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
|
||||
let fwd = crate::http::context::PermissionForward {
|
||||
request_id: format!("{sled_id}:{req_id}"),
|
||||
tool_name,
|
||||
tool_input,
|
||||
response_tx,
|
||||
};
|
||||
if perm_tx.send(fwd).is_err() {
|
||||
break;
|
||||
}
|
||||
let agg_tx2 = agg_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
let decision = response_rx
|
||||
.await
|
||||
.unwrap_or(crate::http::context::PermissionDecision::Deny);
|
||||
let _ = agg_tx2.send((req_id, decision));
|
||||
});
|
||||
}
|
||||
}
|
||||
Some((req_id, decision)) = agg_rx.recv() => {
|
||||
use crate::http::context::PermissionDecision;
|
||||
let (approved, always_allow) = match decision {
|
||||
PermissionDecision::AlwaysAllow => (true, true),
|
||||
PermissionDecision::Approve => (true, false),
|
||||
PermissionDecision::Deny => (false, false),
|
||||
};
|
||||
let resp = crate::sled_uplink::UplinkEnvelope {
|
||||
msg_type: "perm_response".to_string(),
|
||||
req_id,
|
||||
payload: serde_json::json!({
|
||||
"approved": approved,
|
||||
"always_allow": always_allow,
|
||||
}),
|
||||
};
|
||||
let Ok(text) = serde_json::to_string(&resp) else { continue };
|
||||
if sink.send(WsMessage::Text(text)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
crate::slog!("[gateway/sled-uplink] Sled '{}' disconnected", sled_id);
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
|
||||
// ── Event-push WebSocket handler ─────────────────────────────────────────────
|
||||
|
||||
/// Query parameters accepted on the `/gateway/events/push` WebSocket upgrade.
|
||||
|
||||
@@ -38,6 +38,8 @@ pub mod rebuild;
|
||||
mod service;
|
||||
/// Services — shared service bundle injected into HTTP handlers and bot tasks.
|
||||
pub mod services;
|
||||
/// Sled uplink — background task that forwards permission requests to an upstream gateway.
|
||||
pub mod sled_uplink;
|
||||
mod startup;
|
||||
mod state;
|
||||
mod store;
|
||||
@@ -233,6 +235,14 @@ async fn main() -> Result<(), std::io::Error> {
|
||||
status: agents.status_broadcaster(),
|
||||
});
|
||||
|
||||
// Sled uplink: forward permission requests to an upstream gateway when configured.
|
||||
let upstream_gateway = cli
|
||||
.upstream_gateway
|
||||
.clone()
|
||||
.or_else(|| std::env::var("HUSKIES_UPSTREAM_GATEWAY").ok())
|
||||
.unwrap_or_default();
|
||||
sled_uplink::spawn_uplink_task(upstream_gateway, Arc::clone(&services));
|
||||
|
||||
// ── Build bot contexts (WhatsApp / Slack / Discord) ───────────────────────
|
||||
let (bot_ctxs, matrix_shutdown_rx) =
|
||||
startup::bots::build_bot_contexts(&startup_root, &services);
|
||||
|
||||
@@ -19,6 +19,12 @@ pub struct GatewayConfig {
|
||||
/// Map of project name → container URL.
|
||||
#[serde(default)]
|
||||
pub projects: BTreeMap<String, ProjectEntry>,
|
||||
/// Map of sled_id → shared secret token for sled-uplink authentication.
|
||||
///
|
||||
/// Each entry allows a sled identified by `sled_id` to connect to
|
||||
/// `/api/sled-uplink` using the given secret token as a bearer credential.
|
||||
#[serde(default)]
|
||||
pub sled_tokens: BTreeMap<String, String>,
|
||||
}
|
||||
|
||||
/// Validate that a gateway config has at least one project.
|
||||
@@ -113,6 +119,7 @@ url = "http://localhost:3002"
|
||||
fn validate_config_rejects_empty() {
|
||||
let config = GatewayConfig {
|
||||
projects: BTreeMap::new(),
|
||||
sled_tokens: BTreeMap::new(),
|
||||
};
|
||||
assert!(validate_config(&config).is_err());
|
||||
}
|
||||
@@ -132,7 +139,10 @@ url = "http://localhost:3002"
|
||||
url: "http://a".into(),
|
||||
},
|
||||
);
|
||||
let config = GatewayConfig { projects };
|
||||
let config = GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
};
|
||||
assert_eq!(validate_config(&config).unwrap(), "alpha");
|
||||
}
|
||||
|
||||
|
||||
@@ -21,13 +21,23 @@ pub fn load_config(path: &Path) -> Result<GatewayConfig, String> {
|
||||
|
||||
/// Persist the current projects map to `<config_dir>/projects.toml`.
|
||||
/// Silently ignores write errors or skips when `config_dir` is empty.
|
||||
///
|
||||
/// Existing `[sled_tokens]` entries are preserved so that adding or removing
|
||||
/// projects via the UI does not wipe the sled authentication tokens.
|
||||
pub async fn save_config(projects: &BTreeMap<String, ProjectEntry>, config_dir: &Path) {
|
||||
if config_dir.as_os_str().is_empty() {
|
||||
return;
|
||||
}
|
||||
let path = config_dir.join("projects.toml");
|
||||
let sled_tokens = tokio::fs::read_to_string(&path)
|
||||
.await
|
||||
.ok()
|
||||
.and_then(|data| toml::from_str::<GatewayConfig>(&data).ok())
|
||||
.map(|c| c.sled_tokens)
|
||||
.unwrap_or_default();
|
||||
let config = GatewayConfig {
|
||||
projects: projects.clone(),
|
||||
sled_tokens,
|
||||
};
|
||||
if let Ok(data) = toml::to_string_pretty(&config) {
|
||||
let _ = tokio::fs::write(&path, data).await;
|
||||
@@ -518,27 +528,20 @@ pub fn spawn_gateway_bot(
|
||||
gateway_project_urls: BTreeMap<String, String>,
|
||||
port: u16,
|
||||
gateway_event_tx: Option<tokio::sync::broadcast::Sender<super::GatewayStatusEvent>>,
|
||||
perm_rx: std::sync::Arc<
|
||||
tokio::sync::Mutex<
|
||||
tokio::sync::mpsc::UnboundedReceiver<crate::http::context::PermissionForward>,
|
||||
>,
|
||||
>,
|
||||
) -> (
|
||||
Option<tokio::task::AbortHandle>,
|
||||
tokio::sync::watch::Sender<Option<crate::rebuild::ShutdownReason>>,
|
||||
) {
|
||||
use crate::agents::AgentPool;
|
||||
use crate::services::Services;
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
|
||||
let (watcher_tx, _) = broadcast::channel(16);
|
||||
let (perm_tx, perm_rx) = mpsc::unbounded_channel();
|
||||
// Keep the sender alive for the gateway's lifetime so the matrix bot's
|
||||
// `permission_listener` task doesn't exit immediately with
|
||||
// "perm_rx channel closed". Previously `_perm_tx` was dropped when
|
||||
// `spawn_gateway_bot` returned, closing the channel before the
|
||||
// listener could even register. Story 898 (sled→gateway WS uplink)
|
||||
// will eventually wire in a real sender; for now the leak keeps the
|
||||
// channel open with no senders writing to it, matching the original
|
||||
// intent of "listener watches forever, waiting for requests".
|
||||
std::mem::forget(perm_tx);
|
||||
let perm_rx = std::sync::Arc::new(tokio::sync::Mutex::new(perm_rx));
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
let (watcher_tx, _) = broadcast::channel::<crate::io::watcher::WatcherEvent>(16);
|
||||
let (shutdown_tx, shutdown_rx) =
|
||||
tokio::sync::watch::channel::<Option<crate::rebuild::ShutdownReason>>(None);
|
||||
// shutdown_tx is intentionally NOT forgotten — the caller holds it and
|
||||
@@ -611,6 +614,9 @@ mod tests {
|
||||
let active = std::sync::Arc::new(tokio::sync::RwLock::new("proj".to_string()));
|
||||
let (event_tx, _) = tokio::sync::broadcast::channel(4);
|
||||
|
||||
let (_perm_tx, perm_rx) =
|
||||
tokio::sync::mpsc::unbounded_channel::<crate::http::context::PermissionForward>();
|
||||
let perm_rx = std::sync::Arc::new(tokio::sync::Mutex::new(perm_rx));
|
||||
let (handle, shutdown_tx) = spawn_gateway_bot(
|
||||
tmp.path(),
|
||||
active,
|
||||
@@ -618,6 +624,7 @@ mod tests {
|
||||
std::collections::BTreeMap::new(),
|
||||
3001,
|
||||
Some(event_tx),
|
||||
perm_rx,
|
||||
);
|
||||
|
||||
// No bot.toml in tmp → no abort handle spawned.
|
||||
|
||||
@@ -22,6 +22,7 @@ pub use io::{
|
||||
spawn_gateway_notification_poller,
|
||||
};
|
||||
|
||||
use crate::http::context::PermissionForward;
|
||||
use crate::rebuild::ShutdownReason;
|
||||
use io::Client;
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
@@ -29,6 +30,7 @@ use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex as TokioMutex;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
pub use crate::crdt_state::NodePresenceView;
|
||||
|
||||
@@ -122,6 +124,22 @@ pub struct GatewayState {
|
||||
///
|
||||
/// Call `event_tx.subscribe()` to obtain a receiver for outbound fan-out.
|
||||
pub event_tx: tokio::sync::broadcast::Sender<GatewayStatusEvent>,
|
||||
/// Sender end of the gateway's permission channel.
|
||||
///
|
||||
/// The sled-uplink handler uses this to inject `perm_request` messages
|
||||
/// received from connected sleds into the gateway's Matrix bot permission
|
||||
/// pipeline.
|
||||
pub perm_tx: mpsc::UnboundedSender<PermissionForward>,
|
||||
/// Receiver end of the gateway's permission channel (shared with the Matrix bot).
|
||||
///
|
||||
/// The Matrix bot's `permission_listener` holds this locked for its lifetime;
|
||||
/// the sled-uplink WS handler sends requests via `perm_tx`.
|
||||
pub perm_rx: Arc<TokioMutex<mpsc::UnboundedReceiver<PermissionForward>>>,
|
||||
/// Reversed sled-token map: token → sled_id.
|
||||
///
|
||||
/// Built at startup from [`GatewayConfig::sled_tokens`] (which maps
|
||||
/// sled_id → token). The handler looks up incoming tokens in O(1).
|
||||
pub sled_tokens: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl GatewayState {
|
||||
@@ -141,6 +159,12 @@ impl GatewayState {
|
||||
.filter(|p| gateway_config.projects.contains_key(p))
|
||||
.unwrap_or(first_from_config);
|
||||
let (event_tx, _) = tokio::sync::broadcast::channel(EVENT_CHANNEL_CAPACITY);
|
||||
let (perm_tx, perm_rx) = mpsc::unbounded_channel::<PermissionForward>();
|
||||
let sled_tokens: HashMap<String, String> = gateway_config
|
||||
.sled_tokens
|
||||
.iter()
|
||||
.map(|(sled_id, token)| (token.clone(), sled_id.clone()))
|
||||
.collect();
|
||||
Ok(Self {
|
||||
projects: Arc::new(RwLock::new(gateway_config.projects)),
|
||||
active_project: Arc::new(RwLock::new(first)),
|
||||
@@ -151,6 +175,9 @@ impl GatewayState {
|
||||
bot_handle: Arc::new(TokioMutex::new(None)),
|
||||
bot_shutdown_tx: Arc::new(TokioMutex::new(None)),
|
||||
event_tx,
|
||||
perm_tx,
|
||||
perm_rx: Arc::new(TokioMutex::new(perm_rx)),
|
||||
sled_tokens,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -477,6 +504,7 @@ pub async fn save_bot_config_and_restart(state: &GatewayState, content: &str) ->
|
||||
gateway_project_urls,
|
||||
state.port,
|
||||
Some(state.event_tx.clone()),
|
||||
Arc::clone(&state.perm_rx),
|
||||
);
|
||||
*handle = new_handle;
|
||||
*state.bot_shutdown_tx.lock().await = Some(new_shutdown_tx);
|
||||
@@ -502,13 +530,17 @@ mod tests {
|
||||
},
|
||||
);
|
||||
}
|
||||
GatewayConfig { projects }
|
||||
GatewayConfig {
|
||||
projects,
|
||||
sled_tokens: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gateway_state_rejects_empty_config() {
|
||||
let config = GatewayConfig {
|
||||
projects: BTreeMap::new(),
|
||||
sled_tokens: BTreeMap::new(),
|
||||
};
|
||||
assert!(GatewayState::new(config, PathBuf::from("."), 3000).is_err());
|
||||
}
|
||||
|
||||
@@ -0,0 +1,543 @@
|
||||
//! Sled uplink — background task that maintains a WebSocket connection from a
|
||||
//! sled (standard huskies instance) to an upstream gateway for permission
|
||||
//! request forwarding.
|
||||
//!
|
||||
//! When `HUSKIES_UPSTREAM_GATEWAY` is set (or `--upstream-gateway` is passed
|
||||
//! on the CLI), this module spawns a task that:
|
||||
//!
|
||||
//! 1. Acquires `services.perm_rx` for its lifetime (matching the Matrix bot's
|
||||
//! `permission_listener` pattern), preventing `tool_prompt_permission` from
|
||||
//! auto-denying requests with "no interactive session".
|
||||
//! 2. Maintains a persistent WebSocket connection to the gateway's
|
||||
//! `/api/sled-uplink` endpoint.
|
||||
//! 3. Forwards each [`PermissionForward`] as a `perm_request` JSON envelope.
|
||||
//! 4. Awaits the matching `perm_response` envelope from the gateway.
|
||||
//! 5. Reconnects with exponential back-off on connection drop, fail-closing
|
||||
//! any in-flight requests with [`PermissionDecision::Deny`].
|
||||
|
||||
use crate::http::context::{PermissionDecision, PermissionForward};
|
||||
use crate::services::Services;
|
||||
use crate::slog;
|
||||
use futures::SinkExt as _;
|
||||
use futures::StreamExt as _;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_tungstenite::tungstenite::Message as WsMessage;
|
||||
|
||||
// ── Back-off constants ────────────────────────────────────────────────────────
|
||||
|
||||
const INITIAL_BACKOFF_SECS: u64 = 1;
|
||||
const MAX_BACKOFF_SECS: u64 = 60;
|
||||
const BACKOFF_MULTIPLIER: u64 = 2;
|
||||
|
||||
// ── Wire protocol ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Extensible JSON envelope for all sled↔gateway uplink messages.
|
||||
///
|
||||
/// Phase 1 defines `perm_request` (sled→gateway) and `perm_response`
|
||||
/// (gateway→sled). Future phases add new `type` values without changing this
|
||||
/// framing.
|
||||
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
|
||||
pub struct UplinkEnvelope {
|
||||
/// Message type discriminant (e.g. `"perm_request"`, `"perm_response"`).
|
||||
#[serde(rename = "type")]
|
||||
pub msg_type: String,
|
||||
/// Correlation ID — the sled chooses this for each request; the gateway
|
||||
/// echoes it back so the sled can demux concurrent responses.
|
||||
pub req_id: String,
|
||||
/// Message-specific payload. Varies by `msg_type`.
|
||||
pub payload: serde_json::Value,
|
||||
}
|
||||
|
||||
// ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Spawn the sled uplink background task.
|
||||
///
|
||||
/// Does nothing when `upstream_url` is empty. When active, the task holds
|
||||
/// `services.perm_rx` locked for its lifetime (preventing auto-deny in
|
||||
/// `tool_prompt_permission`) and forwards all permission requests to the
|
||||
/// gateway. Reconnects automatically with exponential back-off.
|
||||
pub fn spawn_uplink_task(upstream_url: String, services: Arc<Services>) {
|
||||
if upstream_url.is_empty() {
|
||||
return;
|
||||
}
|
||||
slog!("[uplink] Spawning sled uplink task (gateway={upstream_url})");
|
||||
tokio::spawn(async move {
|
||||
// Acquire perm_rx for this task's entire lifetime. While this lock is
|
||||
// held, try_lock() inside tool_prompt_permission fails — meaning
|
||||
// requests flow to perm_tx (which we drain here) rather than auto-deny.
|
||||
let mut perm_rx = services.perm_rx.lock().await;
|
||||
slog!("[uplink] Acquired perm_rx; maintaining gateway connection");
|
||||
|
||||
let mut backoff = INITIAL_BACKOFF_SECS;
|
||||
loop {
|
||||
match run_uplink_session(&upstream_url, &mut perm_rx).await {
|
||||
Ok(()) => {
|
||||
slog!("[uplink] Connection closed cleanly; reconnecting in {backoff}s");
|
||||
}
|
||||
Err(ref e) => {
|
||||
slog!("[uplink] Session error: {e}; reconnecting in {backoff}s");
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_secs(backoff)).await;
|
||||
backoff = backoff
|
||||
.saturating_mul(BACKOFF_MULTIPLIER)
|
||||
.min(MAX_BACKOFF_SECS);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// ── Private helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Run a single uplink session: connect, pump messages bidirectionally until
|
||||
/// disconnect or channel close, then fail-close any in-flight requests.
|
||||
async fn run_uplink_session(
|
||||
url: &str,
|
||||
perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver<PermissionForward>,
|
||||
) -> Result<(), String> {
|
||||
let (ws_stream, _) = tokio_tungstenite::connect_async(url)
|
||||
.await
|
||||
.map_err(|e| format!("WS connect to {url}: {e}"))?;
|
||||
slog!("[uplink] Connected to gateway uplink endpoint");
|
||||
|
||||
let (mut ws_sink, mut ws_rx) = ws_stream.split();
|
||||
let mut in_flight: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();
|
||||
|
||||
let result = pump_messages(&mut ws_sink, &mut ws_rx, perm_rx, &mut in_flight).await;
|
||||
fail_close_all(&mut in_flight);
|
||||
result
|
||||
}
|
||||
|
||||
/// Drive the bidirectional message loop for one session.
|
||||
async fn pump_messages(
|
||||
ws_sink: &mut (impl futures::Sink<WsMessage, Error = tokio_tungstenite::tungstenite::Error> + Unpin),
|
||||
ws_rx: &mut (
|
||||
impl futures::Stream<Item = Result<WsMessage, tokio_tungstenite::tungstenite::Error>>
|
||||
+ Unpin
|
||||
),
|
||||
perm_rx: &mut tokio::sync::mpsc::UnboundedReceiver<PermissionForward>,
|
||||
in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>,
|
||||
) -> Result<(), String> {
|
||||
loop {
|
||||
tokio::select! {
|
||||
// New permission request from the MCP layer.
|
||||
maybe_fwd = perm_rx.recv() => {
|
||||
match maybe_fwd {
|
||||
None => return Ok(()), // channel closed — server shutting down
|
||||
Some(fwd) => {
|
||||
let PermissionForward {
|
||||
request_id,
|
||||
tool_name,
|
||||
tool_input,
|
||||
response_tx,
|
||||
} = fwd;
|
||||
let env = UplinkEnvelope {
|
||||
msg_type: "perm_request".to_string(),
|
||||
req_id: request_id.clone(),
|
||||
payload: serde_json::json!({
|
||||
"tool_name": tool_name,
|
||||
"tool_input": tool_input,
|
||||
}),
|
||||
};
|
||||
let text = serde_json::to_string(&env)
|
||||
.map_err(|e| format!("serialise perm_request: {e}"))?;
|
||||
if ws_sink.send(WsMessage::Text(text.into())).await.is_err() {
|
||||
// Connection dead: fail-close this request immediately.
|
||||
let _ = response_tx.send(PermissionDecision::Deny);
|
||||
return Err("WS send failed".to_string());
|
||||
}
|
||||
in_flight.insert(request_id, response_tx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Message arriving from the gateway.
|
||||
msg = ws_rx.next() => {
|
||||
match msg {
|
||||
None | Some(Err(_)) => {
|
||||
return Err("WS stream closed".to_string());
|
||||
}
|
||||
Some(Ok(WsMessage::Close(_))) => {
|
||||
return Err("Gateway sent Close frame".to_string());
|
||||
}
|
||||
Some(Ok(WsMessage::Text(text))) => {
|
||||
on_gateway_text(&text, in_flight);
|
||||
}
|
||||
Some(Ok(WsMessage::Ping(data))) => {
|
||||
let _ = ws_sink.send(WsMessage::Pong(data)).await;
|
||||
}
|
||||
Some(Ok(_)) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse an incoming gateway text frame and resolve any matching in-flight request.
|
||||
fn on_gateway_text(
|
||||
text: &str,
|
||||
in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>,
|
||||
) {
|
||||
let Ok(env) = serde_json::from_str::<UplinkEnvelope>(text) else {
|
||||
return;
|
||||
};
|
||||
if env.msg_type == "perm_response" {
|
||||
resolve_perm_response(env, in_flight);
|
||||
}
|
||||
}
|
||||
|
||||
/// Map a `perm_response` envelope to a [`PermissionDecision`] and wake the
|
||||
/// waiting MCP call.
|
||||
fn resolve_perm_response(
|
||||
env: UplinkEnvelope,
|
||||
in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>,
|
||||
) {
|
||||
let Some(tx) = in_flight.remove(&env.req_id) else {
|
||||
return;
|
||||
};
|
||||
let approved = env
|
||||
.payload
|
||||
.get("approved")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
let always_allow = env
|
||||
.payload
|
||||
.get("always_allow")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
let decision = if always_allow {
|
||||
PermissionDecision::AlwaysAllow
|
||||
} else if approved {
|
||||
PermissionDecision::Approve
|
||||
} else {
|
||||
PermissionDecision::Deny
|
||||
};
|
||||
let _ = tx.send(decision);
|
||||
}
|
||||
|
||||
/// Deny all in-flight requests (fail-closed on connection drop — AC 8).
|
||||
fn fail_close_all(in_flight: &mut HashMap<String, oneshot::Sender<PermissionDecision>>) {
|
||||
for (_, tx) in in_flight.drain() {
|
||||
let _ = tx.send(PermissionDecision::Deny);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::http::context::PermissionForward;
|
||||
use crate::services::Services;
|
||||
use std::collections::HashMap;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_tungstenite::tungstenite::Message as WsMessage;
|
||||
|
||||
// ── Pure unit tests ───────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn uplink_envelope_roundtrips_json() {
|
||||
let env = UplinkEnvelope {
|
||||
msg_type: "perm_request".to_string(),
|
||||
req_id: "req-1".to_string(),
|
||||
payload: serde_json::json!({"tool_name": "Bash", "tool_input": {}}),
|
||||
};
|
||||
let text = serde_json::to_string(&env).unwrap();
|
||||
let back: UplinkEnvelope = serde_json::from_str(&text).unwrap();
|
||||
assert_eq!(back.msg_type, "perm_request");
|
||||
assert_eq!(back.req_id, "req-1");
|
||||
assert_eq!(back.payload["tool_name"], "Bash");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_perm_response_approve() {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let mut in_flight = HashMap::new();
|
||||
in_flight.insert("r1".to_string(), tx);
|
||||
let env = UplinkEnvelope {
|
||||
msg_type: "perm_response".to_string(),
|
||||
req_id: "r1".to_string(),
|
||||
payload: serde_json::json!({"approved": true, "always_allow": false}),
|
||||
};
|
||||
resolve_perm_response(env, &mut in_flight);
|
||||
assert!(in_flight.is_empty());
|
||||
assert_eq!(rx.blocking_recv().unwrap(), PermissionDecision::Approve);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_perm_response_deny() {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let mut in_flight = HashMap::new();
|
||||
in_flight.insert("r2".to_string(), tx);
|
||||
let env = UplinkEnvelope {
|
||||
msg_type: "perm_response".to_string(),
|
||||
req_id: "r2".to_string(),
|
||||
payload: serde_json::json!({"approved": false}),
|
||||
};
|
||||
resolve_perm_response(env, &mut in_flight);
|
||||
assert_eq!(rx.blocking_recv().unwrap(), PermissionDecision::Deny);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_perm_response_always_allow() {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let mut in_flight = HashMap::new();
|
||||
in_flight.insert("r3".to_string(), tx);
|
||||
let env = UplinkEnvelope {
|
||||
msg_type: "perm_response".to_string(),
|
||||
req_id: "r3".to_string(),
|
||||
payload: serde_json::json!({"approved": true, "always_allow": true}),
|
||||
};
|
||||
resolve_perm_response(env, &mut in_flight);
|
||||
assert_eq!(rx.blocking_recv().unwrap(), PermissionDecision::AlwaysAllow);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_perm_response_unknown_req_id_is_noop() {
|
||||
let mut in_flight: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();
|
||||
let env = UplinkEnvelope {
|
||||
msg_type: "perm_response".to_string(),
|
||||
req_id: "missing".to_string(),
|
||||
payload: serde_json::json!({"approved": true}),
|
||||
};
|
||||
resolve_perm_response(env, &mut in_flight); // must not panic
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fail_close_all_denies_all_pending() {
|
||||
let (tx1, rx1) = oneshot::channel();
|
||||
let (tx2, rx2) = oneshot::channel();
|
||||
let mut in_flight = HashMap::new();
|
||||
in_flight.insert("a".to_string(), tx1);
|
||||
in_flight.insert("b".to_string(), tx2);
|
||||
fail_close_all(&mut in_flight);
|
||||
assert!(in_flight.is_empty());
|
||||
assert_eq!(rx1.blocking_recv().unwrap(), PermissionDecision::Deny);
|
||||
assert_eq!(rx2.blocking_recv().unwrap(), PermissionDecision::Deny);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn on_gateway_text_ignores_unknown_type() {
|
||||
let mut in_flight: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();
|
||||
on_gateway_text(
|
||||
r#"{"type":"future_type","req_id":"x","payload":{}}"#,
|
||||
&mut in_flight,
|
||||
);
|
||||
assert!(in_flight.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn on_gateway_text_ignores_invalid_json() {
|
||||
let mut in_flight: HashMap<String, oneshot::Sender<PermissionDecision>> = HashMap::new();
|
||||
on_gateway_text("not-json", &mut in_flight); // must not panic
|
||||
assert!(in_flight.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spawn_uplink_task_noop_when_url_empty() {
|
||||
let (_perm_tx, perm_rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
let agents = Arc::new(crate::agents::AgentPool::new_test(3000));
|
||||
let services = Arc::new(Services {
|
||||
project_root: std::path::PathBuf::from("/tmp"),
|
||||
status: agents.status_broadcaster(),
|
||||
agents,
|
||||
bot_name: "Test".to_string(),
|
||||
bot_user_id: String::new(),
|
||||
ambient_rooms: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
|
||||
perm_rx: Arc::new(tokio::sync::Mutex::new(perm_rx)),
|
||||
pending_perm_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||
permission_timeout_secs: 120,
|
||||
});
|
||||
// Empty URL → noop; if it panicked or blocked the test would fail.
|
||||
spawn_uplink_task(String::new(), services);
|
||||
}
|
||||
|
||||
// ── AC 11: permission approved via uplink ────────────────────────
|
||||
// "Simulate matrix bot triggering a Bash permission, sled forwards via
|
||||
// uplink, mock matrix transport approves, tool call proceeds."
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration_perm_request_approved_via_uplink() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let port = listener.local_addr().unwrap().port();
|
||||
let url = format!("ws://127.0.0.1:{port}");
|
||||
|
||||
// Mock gateway: accept one connection, receive perm_request, reply approved.
|
||||
let gw_task = tokio::spawn(async move {
|
||||
let (tcp, _) = listener.accept().await.unwrap();
|
||||
let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
|
||||
let msg = ws.next().await.unwrap().unwrap();
|
||||
let text = match msg {
|
||||
WsMessage::Text(t) => t.to_string(),
|
||||
other => panic!("expected Text; got {other:?}"),
|
||||
};
|
||||
let env: UplinkEnvelope = serde_json::from_str(&text).unwrap();
|
||||
assert_eq!(env.msg_type, "perm_request");
|
||||
assert_eq!(env.payload["tool_name"], "Bash");
|
||||
let resp = UplinkEnvelope {
|
||||
msg_type: "perm_response".to_string(),
|
||||
req_id: env.req_id,
|
||||
payload: serde_json::json!({"approved": true, "always_allow": false}),
|
||||
};
|
||||
ws.send(WsMessage::Text(
|
||||
serde_json::to_string(&resp).unwrap().into(),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let (perm_tx, perm_rx) = tokio::sync::mpsc::unbounded_channel::<PermissionForward>();
|
||||
let agents = Arc::new(crate::agents::AgentPool::new_test(3000));
|
||||
let services = Arc::new(Services {
|
||||
project_root: std::path::PathBuf::from("/tmp"),
|
||||
status: agents.status_broadcaster(),
|
||||
agents,
|
||||
bot_name: "Test".to_string(),
|
||||
bot_user_id: String::new(),
|
||||
ambient_rooms: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
|
||||
perm_rx: Arc::new(tokio::sync::Mutex::new(perm_rx)),
|
||||
pending_perm_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||
permission_timeout_secs: 120,
|
||||
});
|
||||
|
||||
spawn_uplink_task(url, Arc::clone(&services));
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
perm_tx
|
||||
.send(PermissionForward {
|
||||
request_id: "req-test-1".to_string(),
|
||||
tool_name: "Bash".to_string(),
|
||||
tool_input: serde_json::json!({"command": "echo hello"}),
|
||||
response_tx,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let decision = tokio::time::timeout(std::time::Duration::from_secs(5), response_rx)
|
||||
.await
|
||||
.expect("timeout waiting for decision")
|
||||
.expect("oneshot dropped");
|
||||
|
||||
assert_eq!(decision, PermissionDecision::Approve);
|
||||
gw_task.await.unwrap();
|
||||
}
|
||||
|
||||
// ── AC 12: sled disconnects and reconnects ────────────────────────
|
||||
// "Sled disconnects and reconnects mid-session; subsequent permission
|
||||
// requests succeed once reconnected."
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration_reconnects_after_disconnect() {
|
||||
use std::sync::Arc as StdArc;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let port = listener.local_addr().unwrap().port();
|
||||
let url = format!("ws://127.0.0.1:{port}");
|
||||
|
||||
let conn_count = StdArc::new(AtomicU32::new(0));
|
||||
let conn_count2 = StdArc::clone(&conn_count);
|
||||
|
||||
tokio::spawn(async move {
|
||||
// First connection: receive the request then immediately drop (simulates network failure).
|
||||
{
|
||||
let (tcp, _) = listener.accept().await.unwrap();
|
||||
let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
|
||||
conn_count2.fetch_add(1, Ordering::SeqCst);
|
||||
let _ = ws.next().await; // consume one frame
|
||||
drop(ws); // close without replying → fail-close on sled side
|
||||
}
|
||||
// Second connection: approve the next request.
|
||||
{
|
||||
let (tcp, _) = listener.accept().await.unwrap();
|
||||
let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();
|
||||
conn_count2.fetch_add(1, Ordering::SeqCst);
|
||||
if let Some(Ok(WsMessage::Text(text))) = ws.next().await {
|
||||
let env: UplinkEnvelope = serde_json::from_str(&text).unwrap();
|
||||
if env.msg_type == "perm_request" {
|
||||
let resp = UplinkEnvelope {
|
||||
msg_type: "perm_response".to_string(),
|
||||
req_id: env.req_id,
|
||||
payload: serde_json::json!({"approved": true, "always_allow": false}),
|
||||
};
|
||||
let _ = ws
|
||||
.send(WsMessage::Text(
|
||||
serde_json::to_string(&resp).unwrap().into(),
|
||||
))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let (perm_tx, perm_rx) = tokio::sync::mpsc::unbounded_channel::<PermissionForward>();
|
||||
let agents = Arc::new(crate::agents::AgentPool::new_test(3000));
|
||||
let services = Arc::new(Services {
|
||||
project_root: std::path::PathBuf::from("/tmp"),
|
||||
status: agents.status_broadcaster(),
|
||||
agents,
|
||||
bot_name: "Test".to_string(),
|
||||
bot_user_id: String::new(),
|
||||
ambient_rooms: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
|
||||
perm_rx: Arc::new(tokio::sync::Mutex::new(perm_rx)),
|
||||
pending_perm_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||
permission_timeout_secs: 120,
|
||||
});
|
||||
|
||||
spawn_uplink_task(url, Arc::clone(&services));
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
|
||||
// First request is sent on the connection that drops → denied.
|
||||
let (tx1, rx1) = oneshot::channel();
|
||||
perm_tx
|
||||
.send(PermissionForward {
|
||||
request_id: "req-drop".to_string(),
|
||||
tool_name: "Bash".to_string(),
|
||||
tool_input: serde_json::json!({}),
|
||||
response_tx: tx1,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let d1 = tokio::time::timeout(std::time::Duration::from_secs(5), rx1)
|
||||
.await
|
||||
.expect("timeout on first request")
|
||||
.expect("oneshot dropped");
|
||||
assert_eq!(
|
||||
d1,
|
||||
PermissionDecision::Deny,
|
||||
"dropped connection must fail-close"
|
||||
);
|
||||
|
||||
// Wait for the 1-second reconnect backoff plus buffer.
|
||||
tokio::time::sleep(std::time::Duration::from_millis(1500)).await;
|
||||
|
||||
// Second request arrives on the reconnected session → approved.
|
||||
let (tx2, rx2) = oneshot::channel();
|
||||
perm_tx
|
||||
.send(PermissionForward {
|
||||
request_id: "req-reconnect".to_string(),
|
||||
tool_name: "Write".to_string(),
|
||||
tool_input: serde_json::json!({}),
|
||||
response_tx: tx2,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let d2 = tokio::time::timeout(std::time::Duration::from_secs(5), rx2)
|
||||
.await
|
||||
.expect("timeout on second request")
|
||||
.expect("oneshot dropped");
|
||||
assert_eq!(
|
||||
d2,
|
||||
PermissionDecision::Approve,
|
||||
"reconnected session must approve"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
conn_count.load(Ordering::SeqCst),
|
||||
2,
|
||||
"must have seen exactly 2 connections"
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user