Files
huskies/server/src/gateway_relay.rs
T

480 lines
18 KiB
Rust

//! Gateway relay task — pushes project status events to the gateway via WebSocket.
//!
//! When `gateway_url` is configured in `project.toml` (or the
//! `HUSKIES_GATEWAY_URL` environment variable is set), this module spawns a
//! background task that:
//!
//! 1. Obtains a one-time join token from the gateway via `POST /gateway/tokens`.
//! 2. Connects to the gateway's `/gateway/events/push` WebSocket endpoint.
//! 3. Forwards every [`StatusEvent`] from the local broadcaster that has a
//!    [`StoredEvent`] equivalent as a JSON-encoded text frame (rate-limit
//!    events are skipped).
//! 4. Reconnects with exponential back-off when the connection drops.
use crate::service::events::StoredEvent;
use crate::service::status::{StatusBroadcaster, StatusEvent};
use crate::slog;
use futures::SinkExt as _;
use futures::StreamExt as _;
use std::sync::Arc;
use tokio_tungstenite::tungstenite::Message as WsMessage;
// ── Back-off constants ────────────────────────────────────────────────────────
// Reconnect delays start at INITIAL_BACKOFF_SECS and grow by a factor of
// BACKOFF_MULTIPLIER per reconnect attempt, never exceeding MAX_BACKOFF_SECS.

/// Delay, in seconds, before the first reconnect attempt.
const INITIAL_BACKOFF_SECS: u64 = 1;
/// Factor by which the reconnect delay grows after each reconnect attempt.
const BACKOFF_MULTIPLIER: u64 = 2;
/// Upper bound, in seconds, on any single reconnect delay.
const MAX_BACKOFF_SECS: u64 = 60;
// ── Public API ────────────────────────────────────────────────────────────────
/// Spawn the gateway relay background task.
///
/// Does nothing when `gateway_url` is empty. Otherwise a detached tokio task
/// repeatedly runs a relay session. Each session obtains a join token, holds a
/// WebSocket connection to `{gateway_url}/gateway/events/push`, and forwards
/// each [`StatusEvent`] that has a [`StoredEvent`] equivalent as a
/// JSON-encoded text frame.
///
/// Whenever a session ends, cleanly or with an error, the task sleeps and then
/// reconnects with exponential back-off (initial 1 s, doubling, cap 60 s). A
/// session that stayed up for at least [`MAX_BACKOFF_SECS`] counts as healthy
/// and resets the delay to the initial value. Without that reset, a few
/// disconnects spread over the process lifetime would make every later
/// reconnect wait the full cap.
///
/// Must be called from within a tokio runtime when `gateway_url` is non-empty,
/// because it uses [`tokio::spawn`].
pub fn spawn_relay_task(
    gateway_url: String,
    project_name: String,
    broadcaster: Arc<StatusBroadcaster>,
    client: reqwest::Client,
) {
    if gateway_url.is_empty() {
        return;
    }
    slog!("[relay] Spawning gateway relay task (project={project_name}, gateway={gateway_url})");
    tokio::spawn(async move {
        // A session lasting at least this long is treated as having been
        // healthy, so the following reconnect starts from the initial delay.
        let healthy_session = std::time::Duration::from_secs(MAX_BACKOFF_SECS);
        let mut backoff = INITIAL_BACKOFF_SECS;
        loop {
            let started = std::time::Instant::now();
            let outcome = relay_once(&gateway_url, &project_name, &broadcaster, &client).await;
            if started.elapsed() >= healthy_session {
                backoff = INITIAL_BACKOFF_SECS;
            }
            match outcome {
                Ok(()) => {
                    slog!("[relay] Gateway connection closed cleanly; reconnecting in {backoff}s");
                }
                Err(e) => {
                    slog!("[relay] Relay error: {e}; reconnecting in {backoff}s");
                }
            }
            tokio::time::sleep(std::time::Duration::from_secs(backoff)).await;
            // Exponential back-off with a hard cap.
            backoff = backoff.saturating_mul(BACKOFF_MULTIPLIER).min(MAX_BACKOFF_SECS);
        }
    });
}
// ── Private helpers ───────────────────────────────────────────────────────────
/// Run a single relay session: obtain a token, connect, forward events until
/// disconnect or broadcaster close.
async fn relay_once(
gateway_url: &str,
project_name: &str,
broadcaster: &StatusBroadcaster,
client: &reqwest::Client,
) -> Result<(), String> {
// Subscribe before initiating the network round-trip so no events are
// missed during the connection setup window.
let mut sub = broadcaster.subscribe();
// Step 1: obtain a one-time join token from the gateway.
let token_url = format!("{}/gateway/tokens", gateway_url.trim_end_matches('/'));
let resp = client
.post(&token_url)
.send()
.await
.map_err(|e| format!("token request: {e}"))?;
if !resp.status().is_success() {
return Err(format!("token request returned HTTP {}", resp.status()));
}
let body: serde_json::Value = resp.json().await.map_err(|e| format!("token parse: {e}"))?;
let token = body
.get("token")
.and_then(|t| t.as_str())
.ok_or_else(|| "no 'token' field in gateway response".to_string())?
.to_string();
// Step 2: connect to the WebSocket push endpoint.
let ws_base = to_ws_url(gateway_url.trim_end_matches('/'));
let ws_url = format!("{ws_base}/gateway/events/push?token={token}&project={project_name}");
slog!("[relay] Connecting to gateway events endpoint (project={project_name})");
let (ws_stream, _) = tokio_tungstenite::connect_async(ws_url.as_str())
.await
.map_err(|e| format!("WebSocket connect: {e}"))?;
let (mut sink, _rx) = ws_stream.split();
slog!("[relay] Connected to gateway events endpoint (project={project_name})");
// Step 3: forward StatusEvents until the broadcaster or connection closes.
loop {
match sub.recv().await {
None => {
// Broadcaster was dropped — server is shutting down.
return Ok(());
}
Some(event) => {
let Some(stored) = status_to_stored(event) else {
continue;
};
let text = serde_json::to_string(&stored).map_err(|e| format!("serialise: {e}"))?;
sink.send(WsMessage::Text(text.into()))
.await
.map_err(|e| format!("WebSocket send: {e}"))?;
}
}
}
}
/// Map a [`StatusEvent`] onto the [`StoredEvent`] forwarded to the gateway.
///
/// Each produced event carries `timestamp_ms`. This is the current wall-clock
/// time in milliseconds since the Unix epoch, or 0 if the system clock reads
/// earlier than the epoch. Rate-limit variants have no stored counterpart and
/// yield `None`.
fn status_to_stored(event: StatusEvent) -> Option<StoredEvent> {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    let timestamp_ms = since_epoch.as_millis() as u64;

    let stored = match event {
        StatusEvent::StageTransition {
            story_id,
            story_name,
            from_stage,
            to_stage,
        } => StoredEvent::StageTransition {
            story_id,
            story_name,
            from_stage,
            to_stage,
            timestamp_ms,
        },
        StatusEvent::MergeFailure {
            story_id,
            story_name,
            reason,
        } => StoredEvent::MergeFailure {
            story_id,
            story_name,
            reason,
            timestamp_ms,
        },
        StatusEvent::StoryBlocked {
            story_id,
            story_name,
            reason,
        } => StoredEvent::StoryBlocked {
            story_id,
            story_name,
            reason,
            timestamp_ms,
        },
        // Rate-limit notifications are not persisted, so there is nothing to relay.
        StatusEvent::RateLimitWarning { .. } | StatusEvent::RateLimitHardBlock { .. } => {
            return None;
        }
    };
    Some(stored)
}
/// Convert an `http://` or `https://` base URL to its `ws://` / `wss://`
/// equivalent.
///
/// The scheme is matched ASCII case-insensitively, since URL schemes are
/// case-insensitive (RFC 3986 §3.1). For example, `HTTPS://gw` becomes
/// `wss://gw`. Everything after the scheme is copied verbatim. Input that does
/// not begin with `http://` or `https://` (including `ws://` / `wss://` URLs)
/// is returned unchanged.
fn to_ws_url(base: &str) -> String {
    // "https://" must be tried first only for clarity; the two prefixes cannot
    // both match the same input.
    const SCHEME_MAP: [(&str, &str); 2] = [("https://", "wss://"), ("http://", "ws://")];
    for (http_prefix, ws_prefix) in SCHEME_MAP {
        // `get` returns None when `base` is too short or the cut would split a
        // multi-byte character, so the slice below is always on a boundary.
        if let Some(head) = base.get(..http_prefix.len()) {
            if head.eq_ignore_ascii_case(http_prefix) {
                return format!("{ws_prefix}{}", &base[http_prefix.len()..]);
            }
        }
    }
    base.to_string()
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;

    // ── to_ws_url: scheme conversion ─────────────────────────────────────────

    #[test]
    fn to_ws_url_converts_http() {
        assert_eq!(to_ws_url("http://gateway:3000"), "ws://gateway:3000");
    }

    #[test]
    fn to_ws_url_converts_https() {
        assert_eq!(to_ws_url("https://gateway:3000"), "wss://gateway:3000");
    }

    // Input that is already a WebSocket URL must come back untouched.
    #[test]
    fn to_ws_url_passes_through_ws() {
        assert_eq!(to_ws_url("ws://already:3000"), "ws://already:3000");
    }

    // ── status_to_stored: StatusEvent → StoredEvent mapping ──────────────────

    #[test]
    fn status_to_stored_stage_transition() {
        let ev = StatusEvent::StageTransition {
            story_id: "42".into(),
            story_name: String::new(),
            from_stage: "1_backlog".into(),
            to_stage: "2_current".into(),
        };
        let stored = status_to_stored(ev).unwrap();
        assert!(
            matches!(stored, StoredEvent::StageTransition { story_id, .. } if story_id == "42")
        );
    }

    #[test]
    fn status_to_stored_merge_failure() {
        let ev = StatusEvent::MergeFailure {
            story_id: "7".into(),
            story_name: String::new(),
            reason: "conflict".into(),
        };
        let stored = status_to_stored(ev).unwrap();
        assert!(matches!(stored, StoredEvent::MergeFailure { story_id, .. } if story_id == "7"));
    }

    #[test]
    fn status_to_stored_story_blocked() {
        let ev = StatusEvent::StoryBlocked {
            story_id: "3".into(),
            story_name: String::new(),
            reason: "retry limit".into(),
        };
        let stored = status_to_stored(ev).unwrap();
        assert!(matches!(stored, StoredEvent::StoryBlocked { story_id, .. } if story_id == "3"));
    }

    // Rate-limit variants have no StoredEvent form and must map to None.
    #[test]
    fn status_to_stored_rate_limit_warning_is_none() {
        let ev = StatusEvent::RateLimitWarning {
            story_id: "1".into(),
            story_name: String::new(),
            agent_name: "coder".into(),
        };
        assert!(status_to_stored(ev).is_none());
    }

    #[test]
    fn status_to_stored_rate_limit_hard_block_is_none() {
        let ev = StatusEvent::RateLimitHardBlock {
            story_id: "2".into(),
            story_name: String::new(),
            agent_name: "coder".into(),
            reset_at: chrono::Utc::now(),
        };
        assert!(status_to_stored(ev).is_none());
    }

    // ── spawn_relay_task ─────────────────────────────────────────────────────

    #[test]
    fn spawn_relay_task_noop_when_url_empty() {
        // Should not panic or spawn anything meaningful.
        // This is a plain #[test] with no tokio runtime. That only works because
        // spawn_relay_task returns before reaching tokio::spawn when the URL is
        // empty; tokio::spawn outside a runtime would panic.
        let broadcaster = Arc::new(StatusBroadcaster::new());
        let client = reqwest::Client::new();
        spawn_relay_task(String::new(), "test".into(), broadcaster, client);
        // If we reach here without panic, the guard worked.
    }

    /// End-to-end: a `TransitionFired`-equivalent event published on the sled's
    /// broadcaster must reach the gateway's [`GatewayStatusEvent`] broadcast
    /// within 1 second.
    ///
    /// Spins up a real poem HTTP server (token endpoint + WS event-push endpoint),
    /// spawns the relay task pointing at it, fires a [`StatusEvent::StageTransition`],
    /// and asserts the gateway broadcast receives the matching [`StoredEvent`].
    #[tokio::test]
    async fn relay_end_to_end_stage_transition_reaches_gateway_broadcast() {
        use crate::http::gateway::{gateway_event_push_handler, gateway_generate_token_handler};
        use crate::service::gateway::{GatewayConfig, GatewayState, ProjectEntry};
        use poem::EndpointExt as _;
        use poem::listener::TcpAcceptor;
        use std::collections::BTreeMap;
        use std::path::PathBuf;
        use tokio::net::TcpListener;

        crate::crdt_state::init_for_test();

        // Gateway state: one project whose name matches the relay project name.
        let mut projects = BTreeMap::new();
        projects.insert(
            "sled-test".to_string(),
            ProjectEntry::with_url("http://sled-test:3001"),
        );
        let config = GatewayConfig {
            projects,
            sled_tokens: BTreeMap::new(),
        };
        let state = Arc::new(GatewayState::new(config, PathBuf::new(), 9000).unwrap());

        // Subscribe before the relay connects so the event is not missed.
        let mut gw_rx = state.event_tx.subscribe();

        // Start a poem server on an ephemeral loopback port exposing the real
        // token and event-push handlers.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let gateway_url = format!("http://127.0.0.1:{}", addr.port());
        let route = poem::Route::new()
            .at(
                "/gateway/tokens",
                poem::post(gateway_generate_token_handler),
            )
            .at(
                "/gateway/events/push",
                poem::get(gateway_event_push_handler),
            )
            .data(state.clone());
        // The listener is already bound, so the relay's first connection
        // attempt can succeed even before this task is scheduled.
        tokio::spawn(async move {
            let acceptor = TcpAcceptor::from_tokio(listener).unwrap();
            let _ = poem::Server::new_with_acceptor(acceptor).run(route).await;
        });

        // Spawn the relay task pointing at our in-process gateway server.
        let broadcaster = Arc::new(StatusBroadcaster::new());
        spawn_relay_task(
            gateway_url,
            "sled-test".into(),
            Arc::clone(&broadcaster),
            reqwest::Client::new(),
        );

        // Give the relay time to obtain a join token, connect the WebSocket,
        // and enter its event-receive loop before we publish.
        // NOTE(review): a fixed sleep can be flaky on a heavily loaded CI host.
        // The relay subscribes at the start of each session, so an event
        // published before the first session subscribes would be missed.
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;

        // Publish a stage transition on the sled side.
        broadcaster.publish(StatusEvent::StageTransition {
            story_id: "42_story_relay_e2e".into(),
            story_name: "Relay E2E".into(),
            from_stage: "1_backlog".into(),
            to_stage: "2_current".into(),
        });

        // The event must arrive at the gateway broadcast within 1 second.
        let received = tokio::time::timeout(std::time::Duration::from_secs(1), gw_rx.recv())
            .await
            .expect("timed out: event did not arrive at gateway broadcast within 1 s")
            .expect("gateway broadcast channel closed unexpectedly");
        // The gateway must tag the event with the project name the relay sent.
        assert_eq!(received.project, "sled-test");
        assert!(
            matches!(
                received.event,
                StoredEvent::StageTransition { ref story_id, .. } if story_id == "42_story_relay_e2e"
            ),
            "unexpected gateway event: {:?}",
            received.event
        );
    }

    /// Extends `relay_end_to_end_stage_transition_reaches_gateway_broadcast` to
    /// cover the full wiring path. `project_docker_run_args` embeds
    /// `HUSKIES_GATEWAY_URL` in the sled's argv as an `-e` pair. When that URL is
    /// used to start the relay, a transition fired inside the sled must reach
    /// the gateway's event broadcast (`state.event_tx`) within 1 second.
    #[tokio::test]
    async fn project_docker_run_args_gateway_url_wires_relay() {
        use crate::chat::transport::matrix::new_project::project_docker_run_args;
        use crate::http::gateway::{gateway_event_push_handler, gateway_generate_token_handler};
        use crate::service::gateway::{GatewayConfig, GatewayState, ProjectEntry};
        use poem::EndpointExt as _;
        use poem::listener::TcpAcceptor;
        use std::collections::BTreeMap;
        use std::path::PathBuf;
        use tokio::net::TcpListener;

        crate::crdt_state::init_for_test();

        // Spin up an in-process gateway server on an ephemeral port so we have
        // a real URL to embed in the docker run args.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let gateway_url = format!("http://127.0.0.1:{}", addr.port());

        // project_docker_run_args embeds the gateway URL: this is the production
        // code path that sets HUSKIES_GATEWAY_URL on the sled container.
        // NOTE(review): the positional arguments' meanings are defined in
        // new_project.rs; only the final gateway URL matters to this test.
        let docker_args = project_docker_run_args(
            "huskies-sled-relay",
            3200,
            2300,
            "ssh-ed25519 AAAA...",
            "Test User",
            "test@example.com",
            None,
            &gateway_url,
        );

        // Extract the injected URL exactly as the sled would read it from its env.
        // The scan looks for an adjacent ("-e", "HUSKIES_GATEWAY_URL=...") pair.
        let injected_url = docker_args
            .windows(2)
            .find(|w| w[0] == "-e" && w[1].starts_with("HUSKIES_GATEWAY_URL="))
            .map(|w| w[1].trim_start_matches("HUSKIES_GATEWAY_URL=").to_string())
            .expect("project_docker_run_args must inject HUSKIES_GATEWAY_URL");
        assert_eq!(injected_url, gateway_url, "injected URL must match input");

        // Set up gateway state for the relay project.
        let mut projects = BTreeMap::new();
        projects.insert(
            "sled-relay".to_string(),
            ProjectEntry::with_url("http://sled-relay:3001"),
        );
        let config = GatewayConfig {
            projects,
            sled_tokens: BTreeMap::new(),
        };
        let state = Arc::new(GatewayState::new(config, PathBuf::new(), 9001).unwrap());
        // Subscribe before the relay connects so the event is not missed.
        let mut gw_rx = state.event_tx.subscribe();

        let route = poem::Route::new()
            .at(
                "/gateway/tokens",
                poem::post(gateway_generate_token_handler),
            )
            .at(
                "/gateway/events/push",
                poem::get(gateway_event_push_handler),
            )
            .data(state.clone());
        tokio::spawn(async move {
            let acceptor = TcpAcceptor::from_tokio(listener).unwrap();
            let _ = poem::Server::new_with_acceptor(acceptor).run(route).await;
        });

        // Spawn the relay using the URL extracted from the docker run args —
        // this simulates what the sled does when it reads HUSKIES_GATEWAY_URL
        // from its container environment.
        let broadcaster = Arc::new(StatusBroadcaster::new());
        spawn_relay_task(
            injected_url,
            "sled-relay".into(),
            Arc::clone(&broadcaster),
            reqwest::Client::new(),
        );

        // Same fixed settle delay as the first end-to-end test (same flakiness caveat).
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;

        broadcaster.publish(StatusEvent::StageTransition {
            story_id: "99_docker_args_relay".into(),
            story_name: "Docker Args Relay".into(),
            from_stage: "1_backlog".into(),
            to_stage: "2_current".into(),
        });

        let received = tokio::time::timeout(std::time::Duration::from_secs(1), gw_rx.recv())
            .await
            .expect("timed out: event did not reach gateway within 1 s")
            .expect("gateway broadcast channel closed unexpectedly");
        assert_eq!(received.project, "sled-relay");
        assert!(
            matches!(
                received.event,
                StoredEvent::StageTransition { ref story_id, .. } if story_id == "99_docker_args_relay"
            ),
            "unexpected gateway event: {:?}",
            received.event
        );
    }
}