diff --git a/server/src/chat/transport/matrix/bot/messages/on_room_message.rs b/server/src/chat/transport/matrix/bot/messages/on_room_message.rs index fa14ce7a..5bc842ce 100644 --- a/server/src/chat/transport/matrix/bot/messages/on_room_message.rs +++ b/server/src/chat/transport/matrix/bot/messages/on_room_message.rs @@ -248,6 +248,7 @@ pub(in crate::chat::transport::matrix::bot) async fn on_room_message( "new", "config", "project-rebuild", + "upgrade", ]; let stripped = crate::chat::util::strip_bot_mention( @@ -467,6 +468,84 @@ pub(in crate::chat::transport::matrix::bot) async fn on_room_message( return; } + // In gateway mode, handle the `upgrade []` command to upgrade a + // sled's binary in-container, streaming phase markers to the room. + if ctx.is_gateway() + && let Some(upgrade_cmd) = super::super::super::sled_upgrade::extract_upgrade_command( + &user_message, + &ctx.services.bot_name, + ctx.matrix_user_id.as_str(), + ) + { + match upgrade_cmd { + super::super::super::sled_upgrade::UpgradeCommand::ListProjects => { + slog!("[matrix-bot] Handling 'upgrade' list-projects from {sender}"); + let response = if let Some(ref store) = ctx.gateway_projects_store { + super::super::super::sled_upgrade::handle_upgrade_list_projects(store).await + } else { + "Gateway projects store unavailable.".to_string() + }; + let html = markdown_to_html(&response); + if let Ok(msg_id) = ctx + .transport + .send_message(&room_id_str, &response, &html) + .await + && let Ok(event_id) = msg_id.parse() + { + ctx.bot_sent_event_ids.lock().await.insert(event_id); + } + } + super::super::super::sled_upgrade::UpgradeCommand::Upgrade { project } => { + slog!("[matrix-bot] Handling 'upgrade {project}' from {sender}"); + if let Some(ref store) = ctx.gateway_projects_store { + let transport = Arc::clone(&ctx.transport); + let bot_sent = Arc::clone(&ctx.bot_sent_event_ids); + let room = room_id_str.clone(); + + let response = super::super::super::sled_upgrade::handle_sled_upgrade( + &project, + 
store, + ctx.gateway_port, + |phase_msg| { + let transport = Arc::clone(&transport); + let bot_sent = Arc::clone(&bot_sent); + let room = room.clone(); + async move { + let html = markdown_to_html(&phase_msg); + if let Ok(msg_id) = + transport.send_message(&room, &phase_msg, &html).await + && let Ok(event_id) = msg_id.parse() + { + bot_sent.lock().await.insert(event_id); + } + } + }, + ) + .await; + + let html = markdown_to_html(&response); + if let Ok(msg_id) = ctx + .transport + .send_message(&room_id_str, &response, &html) + .await + && let Ok(event_id) = msg_id.parse() + { + ctx.bot_sent_event_ids.lock().await.insert(event_id); + } + } else { + let msg = "Gateway projects store unavailable — cannot upgrade sled."; + let html = markdown_to_html(msg); + if let Ok(msg_id) = ctx.transport.send_message(&room_id_str, msg, &html).await + && let Ok(event_id) = msg_id.parse() + { + ctx.bot_sent_event_ids.lock().await.insert(event_id); + } + } + } + } + return; + } + // Check for bot-level commands (help, status, ambient, …) before invoking // the LLM. All commands are registered in commands.rs — no special-casing // needed here. diff --git a/server/src/chat/transport/matrix/mod.rs b/server/src/chat/transport/matrix/mod.rs index 37726629..8ed25881 100644 --- a/server/src/chat/transport/matrix/mod.rs +++ b/server/src/chat/transport/matrix/mod.rs @@ -37,6 +37,8 @@ pub mod rebuild; pub mod reset; /// rmtree command — handles `!rmtree` bot commands to remove worktrees. pub mod rmtree; +/// `upgrade []` gateway chat command — streaming per-sled binary upgrade. +pub mod sled_upgrade; /// Start command — handles `!start` bot commands to launch agents on stories. pub mod start; /// Matrix `ChatTransport` implementation wrapping the Matrix SDK client. 
diff --git a/server/src/chat/transport/matrix/sled_upgrade.rs b/server/src/chat/transport/matrix/sled_upgrade.rs
new file mode 100644
--- /dev/null
+++ b/server/src/chat/transport/matrix/sled_upgrade.rs
@@ -0,0 +1,504 @@
+//! `upgrade [<project>]` gateway chat command — streaming sled binary upgrade.
+//!
+//! Usage (gateway mode only):
+//! - `{bot} upgrade <project>` — upgrade the named sled's binary in-container.
+//! - `{bot} upgrade` — list registered projects (shows what can be targeted).
+//!
+//! The gateway orchestrates the upgrade in four phases, streaming a marker to
+//! the chat room at each step:
+//! 1. `[1/4] downloading` — sent before the gateway POSTs to `{sled_url}/api/upgrade`.
+//! 2. `[2/4] swapping binary` — sent once the sled answers that POST with a 2xx status.
+//! 3. `[3/4] restarting sled` — sent right after marker 2, before health polling starts.
+//! 4. `[4/4] reconnected to gateway` — sent once `{sled_url}/health` answers with 2xx.
+//!
+//! The gateway does not observe the binary swap or the restart directly: markers
+//! 2 and 3 are emitted back to back after the sled accepts the request.
+//!
+//! Concurrent `upgrade` invocations are serialised via a global async mutex so
+//! that two simultaneous upgrades cannot interleave their phase markers or race
+//! on the sled restart.
+
+use crate::service::gateway::config::ProjectEntry;
+use std::collections::BTreeMap;
+use std::future::Future;
+use std::sync::{Arc, OnceLock};
+use std::time::Duration;
+use tokio::sync::{Mutex, RwLock};
+
+// ── Serial lock ────────────────────────────────────────────────────────────────
+
+/// Process-wide lock that serialises upgrade runs; created on first use.
+static UPGRADE_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
+
+/// Return the global upgrade lock, initialising it on the first call.
+fn upgrade_lock() -> &'static Mutex<()> {
+    UPGRADE_LOCK.get_or_init(|| Mutex::new(()))
+}
+
+// ── Command parsing ────────────────────────────────────────────────────────────
+
+/// A parsed `upgrade` command.
+#[derive(Debug, PartialEq)]
+pub enum UpgradeCommand {
+    /// `upgrade <project>` — upgrade the named sled.
+    Upgrade {
+        /// The project/sled name to upgrade.
+        project: String,
+    },
+    /// `upgrade` with no argument — list available projects.
+    ListProjects,
+}
+
+/// Parse an `upgrade [<project>]` command from a raw message body.
+///
+/// Strips the bot mention prefix and checks whether the first word is `upgrade`
+/// (ASCII case-insensitive). Only the first whitespace-separated word after
+/// `upgrade` is taken as the project name; any further words are ignored.
+/// Returns `None` when the message is not an upgrade command.
+pub fn extract_upgrade_command(
+    message: &str,
+    bot_name: &str,
+    bot_user_id: &str,
+) -> Option<UpgradeCommand> {
+    let stripped = crate::chat::util::strip_bot_mention(message, bot_name, bot_user_id);
+    // Skip any leading non-alphanumeric characters before the command word.
+    let trimmed = stripped
+        .trim()
+        .trim_start_matches(|c: char| !c.is_alphanumeric());
+
+    let (cmd, rest) = match trimmed.split_once(char::is_whitespace) {
+        Some((c, r)) => (c, r.trim()),
+        None => (trimmed, ""),
+    };
+
+    if !cmd.eq_ignore_ascii_case("upgrade") {
+        return None;
+    }
+
+    Some(match rest.split_whitespace().next() {
+        Some(project) => UpgradeCommand::Upgrade {
+            project: project.to_string(),
+        },
+        None => UpgradeCommand::ListProjects,
+    })
+}
+
+// ── Handlers ───────────────────────────────────────────────────────────────────
+
+/// List available projects when `upgrade` is invoked without an argument.
+///
+/// Returns a Markdown string enumerating the registered project names so the
+/// user knows which targets are valid for `upgrade <project>`.
+pub async fn handle_upgrade_list_projects(
+    projects_store: &Arc<RwLock<BTreeMap<String, ProjectEntry>>>,
+) -> String {
+    let projects = projects_store.read().await;
+    if projects.is_empty() {
+        return "No projects are currently registered with the gateway.".to_string();
+    }
+    let list = projects
+        .keys()
+        .map(|name| format!("- `{name}`"))
+        .collect::<Vec<_>>()
+        .join("\n");
+    format!("Registered projects (use `upgrade <project>` to upgrade one):\n{list}")
+}
+
+/// Upgrade a named sled by streaming phase markers to the chat room.
+///
+/// Acquires the global upgrade lock to serialise concurrent invocations. Each
+/// phase is announced by awaiting `send_phase` with the marker text. On failure
+/// an error message is returned instead of the success line.
+///
+/// `gateway_port` is used to derive the default binary source URL
+/// (`http://gateway:<port>/api/huskies-binary`, port 3000 when `None`) when the
+/// `HUSKIES_GATEWAY_BINARY_URL` environment variable is not set.
+///
+/// A project that is registered without a URL is reported as such rather than
+/// as "not found", because it does appear in the registered-project list.
+pub async fn handle_sled_upgrade<F, Fut>(
+    project: &str,
+    projects_store: &Arc<RwLock<BTreeMap<String, ProjectEntry>>>,
+    gateway_port: Option<u16>,
+    send_phase: F,
+) -> String
+where
+    F: Fn(String) -> Fut,
+    Fut: Future<Output = ()>,
+{
+    // ── Look up project URL ──────────────────────────────────────────────────
+    let sled_url = {
+        let projects = projects_store.read().await;
+        match projects.get(project) {
+            Some(entry) => match entry.url.clone() {
+                Some(url) => url,
+                None => {
+                    return format!(
+                        "Project `{project}` is registered but has no URL configured, \
+                         so it cannot be upgraded."
+                    );
+                }
+            },
+            None => {
+                let available = projects
+                    .keys()
+                    .map(String::as_str)
+                    .collect::<Vec<_>>()
+                    .join(", ");
+                return format!("Project `{project}` not found. Registered projects: {available}");
+            }
+        }
+    };
+
+    // ── Resolve binary source URL ────────────────────────────────────────────
+    let source_url = std::env::var("HUSKIES_GATEWAY_BINARY_URL").unwrap_or_else(|_| {
+        format!(
+            "http://gateway:{}/api/huskies-binary",
+            gateway_port.unwrap_or(3000)
+        )
+    });
+
+    // ── Acquire serial lock ──────────────────────────────────────────────────
+    let _lock = upgrade_lock().lock().await;
+
+    let client = reqwest::Client::builder()
+        .timeout(Duration::from_secs(30))
+        .build()
+        .unwrap_or_default();
+
+    // ── Phase 1: downloading ─────────────────────────────────────────────────
+    send_phase("[1/4] downloading\u{2026}".to_string()).await;
+
+    let upgrade_url = format!("{}/api/upgrade", sled_url.trim_end_matches('/'));
+    let body = serde_json::json!({ "source_url": source_url });
+
+    // NOTE(review): `ProjectEntry::auth_token` is not sent with this request —
+    // confirm the sled's `/api/upgrade` endpoint does not require it.
+    let resp = match client.post(&upgrade_url).json(&body).send().await {
+        Ok(r) => r,
+        Err(e) => {
+            return format!(
+                "Upgrade failed at **[1/4] downloading**: could not reach sled at `{upgrade_url}`.\n\
+                 Error: {e}\n\n\
+                 The previous version remains active."
+            );
+        }
+    };
+
+    // Any 2xx status (including 202 Accepted) counts as "request accepted".
+    let status = resp.status();
+    if !status.is_success() {
+        let body_text = resp.text().await.unwrap_or_default();
+        return format!(
+            "Upgrade failed at **[1/4] downloading**: sled returned HTTP {status}.\n\
+             Response: {body_text}\n\n\
+             The previous version remains active."
+        );
+    }
+
+    // ── Phase 2: swapping binary ─────────────────────────────────────────────
+    // The sled accepted the request and is expected to download and replace
+    // its binary in the background; the gateway cannot observe that step.
+    send_phase("[2/4] swapping binary\u{2026}".to_string()).await;
+
+    // ── Phase 3: restarting sled ─────────────────────────────────────────────
+    // The sled will re-exec momentarily; announce before the health loop.
+    send_phase("[3/4] restarting sled\u{2026}".to_string()).await;
+
+    // ── Wait for sled to come back up ────────────────────────────────────────
+    let health_url = format!("{}/health", sled_url.trim_end_matches('/'));
+    // Give the sled a few seconds to start the download + re-exec before polling.
+    tokio::time::sleep(Duration::from_secs(3)).await;
+
+    // NOTE(review): if the old process still answers `/health` while the
+    // download is in progress, this probe can succeed before the restart —
+    // confirm the sled stops serving before it re-execs.
+    let reconnected = wait_for_health(&client, &health_url, 120).await;
+    if !reconnected {
+        return format!(
+            "Upgrade failed at **[4/4] reconnected to gateway**: sled at `{sled_url}` did not \
+             come back online within 120 seconds after the upgrade was triggered.\n\n\
+             Check the container logs: `docker logs huskies-{project}`"
+        );
+    }
+
+    // ── Phase 4: reconnected ─────────────────────────────────────────────────
+    send_phase("[4/4] reconnected to gateway".to_string()).await;
+
+    // ── Report new version ───────────────────────────────────────────────────
+    let version = fetch_sled_version(&client, &sled_url).await;
+    format!("{project} upgraded to version {version}")
+}
+
+// ── Helpers ───────────────────────────────────────────────────────────────────
+
+/// Poll `GET {health_url}` every 3 seconds until it answers with a 2xx status
+/// or `timeout_secs` elapses.
+///
+/// The whole wait — including a probe that is still in flight — is cut off at
+/// `timeout_secs`, so the caller's reported timeout is the real upper bound.
+/// Returns `true` when a probe succeeds, `false` on timeout.
+async fn wait_for_health(client: &reqwest::Client, health_url: &str, timeout_secs: u64) -> bool {
+    let poll = Duration::from_secs(3);
+    let probe_until_healthy = async {
+        loop {
+            let healthy = matches!(
+                client.get(health_url).send().await,
+                Ok(resp) if resp.status().is_success()
+            );
+            if healthy {
+                break;
+            }
+            tokio::time::sleep(poll).await;
+        }
+    };
+    tokio::time::timeout(Duration::from_secs(timeout_secs), probe_until_healthy)
+        .await
+        .is_ok()
+}
+
+/// Fetch the running version from the sled's `get_version` MCP tool.
+///
+/// Returns the version string on success, or `"unknown"` on any error so the
+/// final chat reply is still meaningful.
+async fn fetch_sled_version(client: &reqwest::Client, sled_url: &str) -> String {
+    let mcp_url = format!("{}/mcp", sled_url.trim_end_matches('/'));
+    let body = serde_json::json!({
+        "jsonrpc": "2.0",
+        "id": 1,
+        "method": "tools/call",
+        "params": {
+            "name": "get_version",
+            "arguments": {}
+        }
+    });
+    let resp = match client.post(&mcp_url).json(&body).send().await {
+        Ok(r) => r,
+        Err(_) => return "unknown".to_string(),
+    };
+    let val: serde_json::Value = match resp.json().await {
+        Ok(v) => v,
+        Err(_) => return "unknown".to_string(),
+    };
+    // MCP tools/call response: result.content[0].text is a JSON string.
+    let text = val
+        .pointer("/result/content/0/text")
+        .and_then(|v| v.as_str())
+        .unwrap_or("");
+    if text.is_empty() {
+        return "unknown".to_string();
+    }
+    serde_json::from_str::<serde_json::Value>(text)
+        .ok()
+        .and_then(|v| v.get("version").and_then(|v| v.as_str()).map(String::from))
+        .unwrap_or_else(|| "unknown".to_string())
+}
+
+// ── Tests ──────────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // ── extract_upgrade_command ───────────────────────────────────────────────
+
+    #[test]
+    fn extract_upgrade_with_project() {
+        let cmd = extract_upgrade_command("Timmy upgrade huskies-server", "Timmy", "@timmy:home");
+        assert_eq!(
+            cmd,
+            Some(UpgradeCommand::Upgrade {
+                project: "huskies-server".to_string()
+            })
+        );
+    }
+
+    #[test]
+    fn extract_upgrade_no_arg_is_list() {
+        let cmd = extract_upgrade_command("Timmy upgrade", "Timmy", "@timmy:home");
+        assert_eq!(cmd, Some(UpgradeCommand::ListProjects));
+    }
+
+    #[test]
+    fn extract_upgrade_with_full_user_id() {
+        let cmd = extract_upgrade_command("@timmy:home upgrade myapp", "Timmy", "@timmy:home");
+        assert_eq!(
+            cmd,
+            Some(UpgradeCommand::Upgrade {
+                project: "myapp".to_string()
+            })
+        );
+    }
+
+    #[test]
+    fn extract_non_upgrade_returns_none() {
+        let cmd = extract_upgrade_command("Timmy status", "Timmy", "@timmy:home");
+        assert!(cmd.is_none());
+    }
+
+    #[test]
+    fn extract_upgrade_case_insensitive() {
+        let cmd = extract_upgrade_command("Timmy UPGRADE alpha", "Timmy", "@timmy:home");
+        assert_eq!(
+            cmd,
+            Some(UpgradeCommand::Upgrade {
+                project: "alpha".to_string()
+            })
+        );
+    }
+
+    // ── handle_upgrade_list_projects ─────────────────────────────────────────
+
+    #[tokio::test]
+    async fn list_projects_empty_store() {
+        let store: Arc<RwLock<BTreeMap<String, ProjectEntry>>> =
+            Arc::new(RwLock::new(BTreeMap::new()));
+        let msg = handle_upgrade_list_projects(&store).await;
+        assert!(
+            msg.contains("No projects"),
+            "empty store should say no projects: {msg}"
+        );
+    }
+
+    #[tokio::test]
+    async fn list_projects_shows_names() {
+        use std::collections::BTreeMap;
+        let mut map = BTreeMap::new();
+        map.insert(
+            "alpha".to_string(),
+            ProjectEntry {
+                url: Some("http://localhost:3001".into()),
+                auth_token: None,
+                ssh_port: None,
+                host_path: None,
+            },
+        );
+        map.insert(
+            "beta".to_string(),
+            ProjectEntry {
+                url: Some("http://localhost:3002".into()),
+                auth_token: None,
+                ssh_port: None,
+                host_path: None,
+            },
+        );
+        let store = Arc::new(RwLock::new(map));
+        let msg = handle_upgrade_list_projects(&store).await;
+        assert!(msg.contains("alpha"), "should list alpha: {msg}");
+        assert!(msg.contains("beta"), "should list beta: {msg}");
+    }
+
+    // ── handle_sled_upgrade validation ───────────────────────────────────────
+
+    #[tokio::test]
+    async fn upgrade_unknown_project_returns_error() {
+        let store: Arc<RwLock<BTreeMap<String, ProjectEntry>>> =
+            Arc::new(RwLock::new(BTreeMap::new()));
+        let phases: std::sync::Mutex<Vec<String>> = std::sync::Mutex::new(vec![]);
+        let result = handle_sled_upgrade("nonexistent", &store, Some(3000), |msg| {
+            phases.lock().unwrap().push(msg);
+            async {}
+        })
+        .await;
+        assert!(
+            result.contains("not found"),
+            "should say not found: {result}"
+        );
+        // No phase markers should have been emitted before the validation error.
+        assert!(
+            phases.lock().unwrap().is_empty(),
+            "no phases should be emitted for unknown project"
+        );
+    }
+
+    #[tokio::test]
+    async fn upgrade_project_with_no_url_fails_gracefully() {
+        let mut map = BTreeMap::new();
+        map.insert(
+            "myapp".to_string(),
+            ProjectEntry {
+                url: None,
+                auth_token: None,
+                ssh_port: None,
+                host_path: None,
+            },
+        );
+        let store = Arc::new(RwLock::new(map));
+        let result = handle_sled_upgrade("myapp", &store, Some(3000), |_msg| async {}).await;
+        assert!(
+            result.contains("no URL"),
+            "project with no URL should say so: {result}"
+        );
+    }
+
+    #[tokio::test]
+    async fn upgrade_unreachable_sled_reports_failure() {
+        let mut map = BTreeMap::new();
+        map.insert(
+            "myapp".to_string(),
+            ProjectEntry {
+                url: Some("http://127.0.0.1:1".into()), // port 1 is never listening
+                auth_token: None,
+                ssh_port: None,
+                host_path: None,
+            },
+        );
+        let store = Arc::new(RwLock::new(map));
+        let phases: std::sync::Mutex<Vec<String>> = std::sync::Mutex::new(vec![]);
+        let result = handle_sled_upgrade("myapp", &store, Some(3000), |msg| {
+            phases.lock().unwrap().push(msg);
+            async {}
+        })
+        .await;
+        // Phase 1 marker must have been sent before the failed request.
+        let sent = phases.lock().unwrap().clone();
+        assert!(
+            sent.iter().any(|m| m.contains("[1/4]")),
+            "phase 1 marker must be sent: {sent:?}"
+        );
+        assert!(
+            result.contains("downloading") || result.contains("reach"),
+            "error should mention the failure: {result}"
+        );
+        assert!(
+            result.contains("previous version"),
+            "error should confirm old version is active: {result}"
+        );
+    }
+
+    // ── wait_for_health ───────────────────────────────────────────────────────
+
+    #[tokio::test]
+    async fn wait_for_health_immediate_success() {
+        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let port = listener.local_addr().unwrap().port();
+
+        let handle = tokio::spawn(async move {
+            if let Ok((mut stream, _)) = listener.accept().await {
+                use tokio::io::AsyncWriteExt;
+                let mut buf = [0u8; 4096];
+                let _ = tokio::io::AsyncReadExt::read(&mut stream, &mut buf).await;
+                let _ = stream
+                    .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")
+                    .await;
+            }
+        });
+
+        let client = reqwest::Client::new();
+        let url = format!("http://127.0.0.1:{port}/health");
+        let ok = wait_for_health(&client, &url, 5).await;
+        assert!(ok, "should return true when health probe succeeds");
+        handle.abort();
+    }
+
+    #[tokio::test]
+    async fn wait_for_health_timeout() {
+        let client = reqwest::Client::builder()
+            .timeout(Duration::from_millis(100))
+            .build()
+            .unwrap();
+        // Nothing listening on port 1.
+        let ok = wait_for_health(&client, "http://127.0.0.1:1/health", 1).await;
+        assert!(!ok, "should return false when health probe never succeeds");
+    }
+}