huskies: merge 1148 story Per-sled upgrade chat command using huskies upgrade (1138), serial-locked

This commit is contained in:
dave
2026-05-19 18:34:44 +00:00
parent 34af2f1820
commit 2593b36072
3 changed files with 559 additions and 0 deletions
@@ -248,6 +248,7 @@ pub(in crate::chat::transport::matrix::bot) async fn on_room_message(
"new",
"config",
"project-rebuild",
"upgrade",
];
let stripped = crate::chat::util::strip_bot_mention(
@@ -467,6 +468,84 @@ pub(in crate::chat::transport::matrix::bot) async fn on_room_message(
return;
}
// In gateway mode, handle the `upgrade [<project>]` command to upgrade a
// sled's binary in-container, streaming phase markers to the room.
if ctx.is_gateway()
&& let Some(upgrade_cmd) = super::super::super::sled_upgrade::extract_upgrade_command(
&user_message,
&ctx.services.bot_name,
ctx.matrix_user_id.as_str(),
)
{
match upgrade_cmd {
super::super::super::sled_upgrade::UpgradeCommand::ListProjects => {
slog!("[matrix-bot] Handling 'upgrade' list-projects from {sender}");
let response = if let Some(ref store) = ctx.gateway_projects_store {
super::super::super::sled_upgrade::handle_upgrade_list_projects(store).await
} else {
"Gateway projects store unavailable.".to_string()
};
let html = markdown_to_html(&response);
if let Ok(msg_id) = ctx
.transport
.send_message(&room_id_str, &response, &html)
.await
&& let Ok(event_id) = msg_id.parse()
{
ctx.bot_sent_event_ids.lock().await.insert(event_id);
}
}
super::super::super::sled_upgrade::UpgradeCommand::Upgrade { project } => {
slog!("[matrix-bot] Handling 'upgrade {project}' from {sender}");
if let Some(ref store) = ctx.gateway_projects_store {
let transport = Arc::clone(&ctx.transport);
let bot_sent = Arc::clone(&ctx.bot_sent_event_ids);
let room = room_id_str.clone();
let response = super::super::super::sled_upgrade::handle_sled_upgrade(
&project,
store,
ctx.gateway_port,
|phase_msg| {
let transport = Arc::clone(&transport);
let bot_sent = Arc::clone(&bot_sent);
let room = room.clone();
async move {
let html = markdown_to_html(&phase_msg);
if let Ok(msg_id) =
transport.send_message(&room, &phase_msg, &html).await
&& let Ok(event_id) = msg_id.parse()
{
bot_sent.lock().await.insert(event_id);
}
}
},
)
.await;
let html = markdown_to_html(&response);
if let Ok(msg_id) = ctx
.transport
.send_message(&room_id_str, &response, &html)
.await
&& let Ok(event_id) = msg_id.parse()
{
ctx.bot_sent_event_ids.lock().await.insert(event_id);
}
} else {
let msg = "Gateway projects store unavailable — cannot upgrade sled.";
let html = markdown_to_html(msg);
if let Ok(msg_id) = ctx.transport.send_message(&room_id_str, msg, &html).await
&& let Ok(event_id) = msg_id.parse()
{
ctx.bot_sent_event_ids.lock().await.insert(event_id);
}
}
}
}
return;
}
// Check for bot-level commands (help, status, ambient, …) before invoking
// the LLM. All commands are registered in commands.rs — no special-casing
// needed here.
+2
View File
@@ -37,6 +37,8 @@ pub mod rebuild;
pub mod reset;
/// rmtree command — handles `!rmtree` bot commands to remove worktrees.
pub mod rmtree;
/// `upgrade [<project>]` gateway chat command — streaming per-sled binary upgrade.
pub mod sled_upgrade;
/// Start command — handles `!start` bot commands to launch agents on stories.
pub mod start;
/// Matrix `ChatTransport` implementation wrapping the Matrix SDK client.
@@ -0,0 +1,478 @@
//! `upgrade [<project>]` gateway chat command — streaming sled binary upgrade.
//!
//! Usage (gateway mode only):
//! - `{bot} upgrade <project>` — upgrade the named sled's binary in-container.
//! - `{bot} upgrade` — list registered projects (shows what can be targeted).
//!
//! The gateway orchestrates the upgrade in four phases, streaming a marker to
//! the chat room at each step:
//! 1. `[1/4] downloading` — POSTs to `{sled_url}/api/upgrade`; sled starts download.
//! 2. `[2/4] swapping binary` — gateway received 202; sled atomically renamed the binary.
//! 3. `[3/4] restarting sled` — sled re-execs with the new binary; HTTP goes dark briefly.
//! 4. `[4/4] reconnected to gateway` — sled's `/health` probe is responding again.
//!
//! Concurrent `upgrade` invocations are serialised via a global async mutex so
//! that two simultaneous upgrades cannot interleave their phase markers or race
//! on the sled restart.
use crate::service::gateway::config::ProjectEntry;
use std::collections::BTreeMap;
use std::future::Future;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use tokio::sync::{Mutex, RwLock};
// ── Serial lock ────────────────────────────────────────────────────────────────
static UPGRADE_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
fn upgrade_lock() -> &'static Mutex<()> {
UPGRADE_LOCK.get_or_init(|| Mutex::new(()))
}
// ── Command parsing ────────────────────────────────────────────────────────────
/// A parsed `upgrade` command.
#[derive(Debug, PartialEq)]
pub enum UpgradeCommand {
/// `upgrade <project>` — upgrade the named sled.
Upgrade {
/// The project/sled name to upgrade.
project: String,
},
/// `upgrade` with no argument — list available projects.
ListProjects,
}
/// Parse an `upgrade [<project>]` command from a raw message body.
///
/// Strips the bot mention prefix and checks whether the first word is `upgrade`.
/// Returns `None` when the message is not an upgrade command.
pub fn extract_upgrade_command(
message: &str,
bot_name: &str,
bot_user_id: &str,
) -> Option<UpgradeCommand> {
let stripped = crate::chat::util::strip_bot_mention(message, bot_name, bot_user_id);
let trimmed = stripped
.trim()
.trim_start_matches(|c: char| !c.is_alphanumeric());
let (cmd, rest) = match trimmed.split_once(char::is_whitespace) {
Some((c, r)) => (c, r.trim()),
None => (trimmed, ""),
};
if !cmd.eq_ignore_ascii_case("upgrade") {
return None;
}
if rest.is_empty() {
Some(UpgradeCommand::ListProjects)
} else {
Some(UpgradeCommand::Upgrade {
project: rest.split_whitespace().next().unwrap_or(rest).to_string(),
})
}
}
// ── Handlers ───────────────────────────────────────────────────────────────────
/// List available projects when `upgrade` is invoked without an argument.
///
/// Returns a Markdown string enumerating the registered project names so the
/// user knows which targets are valid for `upgrade <project>`.
pub async fn handle_upgrade_list_projects(
projects_store: &Arc<RwLock<BTreeMap<String, ProjectEntry>>>,
) -> String {
let projects = projects_store.read().await;
if projects.is_empty() {
return "No projects are currently registered with the gateway.".to_string();
}
let names: Vec<&String> = projects.keys().collect();
let list = names
.iter()
.map(|n| format!("- `{n}`"))
.collect::<Vec<_>>()
.join("\n");
format!("Registered projects (use `upgrade <project>` to upgrade one):\n{list}")
}
/// Upgrade a named sled by streaming phase markers to the chat room.
///
/// Acquires the global upgrade lock to serialise concurrent invocations. Each
/// phase is announced by calling `send_phase` before the corresponding work
/// begins. On any failure, an error message is returned and the previous
/// binary remains active on the sled.
///
/// `gateway_port` is used to derive the default binary source URL
/// (`http://gateway:<port>/api/huskies-binary`) when neither
/// `HUSKIES_GATEWAY_BINARY_URL` nor `--source` is set.
pub async fn handle_sled_upgrade<F, Fut>(
project: &str,
projects_store: &Arc<RwLock<BTreeMap<String, ProjectEntry>>>,
gateway_port: Option<u16>,
send_phase: F,
) -> String
where
F: Fn(String) -> Fut,
Fut: Future<Output = ()>,
{
// ── Look up project URL ──────────────────────────────────────────────────
let sled_url = {
let projects = projects_store.read().await;
match projects.get(project).and_then(|e| e.url.clone()) {
Some(u) => u,
None => {
let available: Vec<&String> = projects.keys().collect();
return format!(
"Project `{project}` not found. Registered projects: {}",
available
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>()
.join(", ")
);
}
}
};
// ── Resolve binary source URL ────────────────────────────────────────────
let source_url = std::env::var("HUSKIES_GATEWAY_BINARY_URL").unwrap_or_else(|_| {
format!(
"http://gateway:{}/api/huskies-binary",
gateway_port.unwrap_or(3000)
)
});
// ── Acquire serial lock ──────────────────────────────────────────────────
let _lock = upgrade_lock().lock().await;
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()
.unwrap_or_default();
// ── Phase 1: downloading ─────────────────────────────────────────────────
send_phase("[1/4] downloading\u{2026}".to_string()).await;
let upgrade_url = format!("{}/api/upgrade", sled_url.trim_end_matches('/'));
let body = serde_json::json!({ "source_url": source_url });
let resp = match client.post(&upgrade_url).json(&body).send().await {
Ok(r) => r,
Err(e) => {
return format!(
"Upgrade failed at **[1/4] downloading**: could not reach sled at `{upgrade_url}`.\n\
Error: {e}\n\n\
The previous version remains active."
);
}
};
if !resp.status().is_success() && resp.status().as_u16() != 202 {
let status = resp.status();
let body_text = resp.text().await.unwrap_or_default();
return format!(
"Upgrade failed at **[1/4] downloading**: sled returned HTTP {status}.\n\
Response: {body_text}\n\n\
The previous version remains active."
);
}
// ── Phase 2: swapping binary ─────────────────────────────────────────────
// The sled accepted the request (202) and is downloading + atomically
// replacing the binary in the background.
send_phase("[2/4] swapping binary\u{2026}".to_string()).await;
// ── Phase 3: restarting sled ─────────────────────────────────────────────
// The sled will re-exec momentarily; announce before the health loop.
send_phase("[3/4] restarting sled\u{2026}".to_string()).await;
// ── Wait for sled to come back up ────────────────────────────────────────
let health_url = format!("{}/health", sled_url.trim_end_matches('/'));
// Give the sled a few seconds to start the download + re-exec before polling.
tokio::time::sleep(Duration::from_secs(3)).await;
let reconnected = wait_for_health(&client, &health_url, 120).await;
if !reconnected {
return format!(
"Upgrade failed at **[4/4] reconnected to gateway**: sled at `{sled_url}` did not \
come back online within 120 seconds after the upgrade was triggered.\n\n\
Check the container logs: `docker logs huskies-{project}`"
);
}
// ── Phase 4: reconnected ─────────────────────────────────────────────────
send_phase("[4/4] reconnected to gateway".to_string()).await;
// ── Report new version ───────────────────────────────────────────────────
let version = fetch_sled_version(&client, &sled_url).await;
format!("{project} upgraded to version {version}")
}
// ── Helpers ───────────────────────────────────────────────────────────────────
/// Poll `GET {health_url}` every 3 seconds until it returns 200 or `timeout_secs` elapses.
///
/// Returns `true` when the probe succeeds, `false` on timeout.
async fn wait_for_health(client: &reqwest::Client, health_url: &str, timeout_secs: u64) -> bool {
let deadline = std::time::Instant::now() + Duration::from_secs(timeout_secs);
let poll = Duration::from_secs(3);
loop {
match client.get(health_url).send().await {
Ok(r) if r.status().is_success() => return true,
_ => {}
}
if std::time::Instant::now() >= deadline {
return false;
}
tokio::time::sleep(poll).await;
}
}
/// Fetch the running version from the sled's `get_version` MCP tool.
///
/// Returns the version string on success, or `"unknown"` on any error so the
/// final chat reply is still meaningful.
async fn fetch_sled_version(client: &reqwest::Client, sled_url: &str) -> String {
let mcp_url = format!("{}/mcp", sled_url.trim_end_matches('/'));
let body = serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": {
"name": "get_version",
"arguments": {}
}
});
let resp = match client.post(&mcp_url).json(&body).send().await {
Ok(r) => r,
Err(_) => return "unknown".to_string(),
};
let val: serde_json::Value = match resp.json().await {
Ok(v) => v,
Err(_) => return "unknown".to_string(),
};
// MCP tools/call response: result.content[0].text is a JSON string.
let text = val
.pointer("/result/content/0/text")
.and_then(|v| v.as_str())
.unwrap_or("");
if text.is_empty() {
return "unknown".to_string();
}
serde_json::from_str::<serde_json::Value>(text)
.ok()
.and_then(|v| v.get("version").and_then(|v| v.as_str()).map(String::from))
.unwrap_or_else(|| "unknown".to_string())
}
// ── Tests ──────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
// ── extract_upgrade_command ───────────────────────────────────────────────
#[test]
fn extract_upgrade_with_project() {
let cmd = extract_upgrade_command("Timmy upgrade huskies-server", "Timmy", "@timmy:home");
assert_eq!(
cmd,
Some(UpgradeCommand::Upgrade {
project: "huskies-server".to_string()
})
);
}
#[test]
fn extract_upgrade_no_arg_is_list() {
let cmd = extract_upgrade_command("Timmy upgrade", "Timmy", "@timmy:home");
assert_eq!(cmd, Some(UpgradeCommand::ListProjects));
}
#[test]
fn extract_upgrade_with_full_user_id() {
let cmd = extract_upgrade_command("@timmy:home upgrade myapp", "Timmy", "@timmy:home");
assert_eq!(
cmd,
Some(UpgradeCommand::Upgrade {
project: "myapp".to_string()
})
);
}
#[test]
fn extract_non_upgrade_returns_none() {
let cmd = extract_upgrade_command("Timmy status", "Timmy", "@timmy:home");
assert!(cmd.is_none());
}
#[test]
fn extract_upgrade_case_insensitive() {
let cmd = extract_upgrade_command("Timmy UPGRADE alpha", "Timmy", "@timmy:home");
assert_eq!(
cmd,
Some(UpgradeCommand::Upgrade {
project: "alpha".to_string()
})
);
}
// ── handle_upgrade_list_projects ─────────────────────────────────────────
#[tokio::test]
async fn list_projects_empty_store() {
let store: Arc<RwLock<BTreeMap<String, ProjectEntry>>> =
Arc::new(RwLock::new(BTreeMap::new()));
let msg = handle_upgrade_list_projects(&store).await;
assert!(
msg.contains("No projects"),
"empty store should say no projects: {msg}"
);
}
#[tokio::test]
async fn list_projects_shows_names() {
use std::collections::BTreeMap;
let mut map = BTreeMap::new();
map.insert(
"alpha".to_string(),
ProjectEntry {
url: Some("http://localhost:3001".into()),
auth_token: None,
ssh_port: None,
host_path: None,
},
);
map.insert(
"beta".to_string(),
ProjectEntry {
url: Some("http://localhost:3002".into()),
auth_token: None,
ssh_port: None,
host_path: None,
},
);
let store = Arc::new(RwLock::new(map));
let msg = handle_upgrade_list_projects(&store).await;
assert!(msg.contains("alpha"), "should list alpha: {msg}");
assert!(msg.contains("beta"), "should list beta: {msg}");
}
// ── handle_sled_upgrade validation ───────────────────────────────────────
#[tokio::test]
async fn upgrade_unknown_project_returns_error() {
let store: Arc<RwLock<BTreeMap<String, ProjectEntry>>> =
Arc::new(RwLock::new(BTreeMap::new()));
let phases: std::sync::Mutex<Vec<String>> = std::sync::Mutex::new(vec![]);
let result = handle_sled_upgrade("nonexistent", &store, Some(3000), |msg| {
phases.lock().unwrap().push(msg);
async {}
})
.await;
assert!(
result.contains("not found"),
"should say not found: {result}"
);
// No phase markers should have been emitted before the validation error.
assert!(
phases.lock().unwrap().is_empty(),
"no phases should be emitted for unknown project"
);
}
#[tokio::test]
async fn upgrade_project_with_no_url_fails_gracefully() {
let mut map = BTreeMap::new();
map.insert(
"myapp".to_string(),
ProjectEntry {
url: None,
auth_token: None,
ssh_port: None,
host_path: None,
},
);
let store = Arc::new(RwLock::new(map));
let result = handle_sled_upgrade("myapp", &store, Some(3000), |_msg| async {}).await;
assert!(
result.contains("not found"),
"project with no URL should say not found: {result}"
);
}
#[tokio::test]
async fn upgrade_unreachable_sled_reports_failure() {
let mut map = BTreeMap::new();
map.insert(
"myapp".to_string(),
ProjectEntry {
url: Some("http://127.0.0.1:1".into()), // port 1 is never listening
auth_token: None,
ssh_port: None,
host_path: None,
},
);
let store = Arc::new(RwLock::new(map));
let phases: std::sync::Mutex<Vec<String>> = std::sync::Mutex::new(vec![]);
let result = handle_sled_upgrade("myapp", &store, Some(3000), |msg| {
phases.lock().unwrap().push(msg);
async {}
})
.await;
// Phase 1 marker must have been sent before the failed request.
let sent = phases.lock().unwrap().clone();
assert!(
sent.iter().any(|m| m.contains("[1/4]")),
"phase 1 marker must be sent: {sent:?}"
);
assert!(
result.contains("downloading") || result.contains("reach"),
"error should mention the failure: {result}"
);
assert!(
result.contains("previous version"),
"error should confirm old version is active: {result}"
);
}
// ── wait_for_health ───────────────────────────────────────────────────────
#[tokio::test]
async fn wait_for_health_immediate_success() {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let handle = tokio::spawn(async move {
if let Ok((mut stream, _)) = listener.accept().await {
use tokio::io::AsyncWriteExt;
let mut buf = [0u8; 4096];
let _ = tokio::io::AsyncReadExt::read(&mut stream, &mut buf).await;
let _ = stream
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")
.await;
}
});
let client = reqwest::Client::new();
let url = format!("http://127.0.0.1:{port}/health");
let ok = wait_for_health(&client, &url, 5).await;
assert!(ok, "should return true when health probe succeeds");
handle.abort();
}
#[tokio::test]
async fn wait_for_health_timeout() {
let client = reqwest::Client::builder()
.timeout(Duration::from_millis(100))
.build()
.unwrap();
// Nothing listening on port 1.
let ok = wait_for_health(&client, "http://127.0.0.1:1/health", 1).await;
assert!(!ok, "should return false when health probe never succeeds");
}
}