Files
huskies/server/src/chat/transport/matrix/bot/run.rs
T

434 lines
17 KiB
Rust
Raw Normal View History

use crate::agents::AgentPool;
use crate::slog;
use matrix_sdk::{Client, LoopCtrl, config::SyncSettings};
use matrix_sdk::ruma::OwnedRoomId;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex as TokioMutex;
use tokio::sync::{mpsc, watch};
use super::context::BotContext;
use super::format::{format_startup_announcement, markdown_to_html};
use super::history::load_history;
use super::messages::on_room_message;
use super::verification::on_to_device_verification_request;
/// Connect to the Matrix homeserver, join all configured rooms, and start
/// listening for messages. Runs the full Matrix sync loop — call from a
/// `tokio::spawn` task so it doesn't block the main thread.
pub async fn run_bot(
config: super::super::config::BotConfig,
project_root: PathBuf,
watcher_rx: tokio::sync::broadcast::Receiver<crate::io::watcher::WatcherEvent>,
perm_rx: Arc<TokioMutex<mpsc::UnboundedReceiver<crate::http::context::PermissionForward>>>,
agents: Arc<AgentPool>,
shutdown_rx: watch::Receiver<Option<crate::rebuild::ShutdownReason>>,
) -> Result<(), String> {
let store_path = project_root.join(".storkit").join("matrix_store");
let client = Client::builder()
.homeserver_url(config.homeserver.as_deref().unwrap_or_default())
.sqlite_store(&store_path, None)
.build()
.await
.map_err(|e| format!("Failed to build Matrix client: {e}"))?;
// Persist device ID so E2EE crypto state survives restarts.
let device_id_path = project_root.join(".storkit").join("matrix_device_id");
let saved_device_id: Option<String> = std::fs::read_to_string(&device_id_path)
.ok()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty());
let mut login_builder = client
.matrix_auth()
.login_username(
config.username.as_deref().unwrap_or_default(),
config.password.as_deref().unwrap_or_default(),
)
.initial_device_display_name("Storkit Bot");
if let Some(ref device_id) = saved_device_id {
login_builder = login_builder.device_id(device_id);
}
let login_response = login_builder
.await
.map_err(|e| format!("Matrix login failed: {e}"))?;
// Save device ID on first login so subsequent restarts reuse the same device.
if saved_device_id.is_none() {
let _ = std::fs::write(&device_id_path, &login_response.device_id);
slog!(
"[matrix-bot] Saved device ID {} for future restarts",
login_response.device_id
);
}
let bot_user_id = client
.user_id()
.ok_or_else(|| "No user ID after login".to_string())?
.to_owned();
slog!("[matrix-bot] Logged in as {bot_user_id} (device: {})", login_response.device_id);
// Bootstrap cross-signing keys for E2EE verification support.
// Pass the bot's password for UIA (User-Interactive Authentication) —
// the homeserver requires proof of identity before accepting cross-signing keys.
{
use matrix_sdk::ruma::api::client::uiaa;
let password_auth = uiaa::AuthData::Password(uiaa::Password::new(
uiaa::UserIdentifier::UserIdOrLocalpart(
config.username.clone().unwrap_or_default(),
),
config.password.clone().unwrap_or_default(),
));
if let Err(e) = client
.encryption()
.bootstrap_cross_signing(Some(password_auth))
.await
{
slog!("[matrix-bot] Cross-signing bootstrap note: {e}");
}
}
// Self-sign own device keys so other clients don't show
// "encrypted by a device not verified by its owner" warnings.
match client.encryption().get_own_device().await {
Ok(Some(own_device)) => {
if own_device.is_cross_signed_by_owner() {
slog!("[matrix-bot] Device already self-signed");
} else {
slog!("[matrix-bot] Device not self-signed, signing now...");
match own_device.verify().await {
Ok(()) => slog!("[matrix-bot] Successfully self-signed device keys"),
Err(e) => slog!("[matrix-bot] Failed to self-sign device keys: {e}"),
}
}
}
Ok(None) => slog!("[matrix-bot] Could not find own device in crypto store"),
Err(e) => slog!("[matrix-bot] Error retrieving own device: {e}"),
}
if config.allowed_users.is_empty() {
return Err(
"allowed_users is empty in bot.toml — refusing to start (fail-closed). \
Add at least one Matrix user ID to allowed_users."
.to_string(),
);
}
slog!("[matrix-bot] Allowed users: {:?}", config.allowed_users);
// Parse and join all configured rooms.
let mut target_room_ids: Vec<OwnedRoomId> = Vec::new();
for room_id_str in config.effective_room_ids() {
let room_id: OwnedRoomId = room_id_str
.parse()
.map_err(|_| format!("Invalid room ID '{room_id_str}'"))?;
// Try to join with a timeout. Conduit sometimes hangs or returns
// errors on join if the bot is already a member.
match tokio::time::timeout(
std::time::Duration::from_secs(10),
client.join_room_by_id(&room_id),
)
.await
{
Ok(Ok(_)) => slog!("[matrix-bot] Joined room {room_id}"),
Ok(Err(e)) => {
slog!("[matrix-bot] Join room error (may already be a member): {e}")
}
Err(_) => slog!("[matrix-bot] Join room timed out (may already be a member)"),
}
target_room_ids.push(room_id);
}
if target_room_ids.is_empty() {
return Err("No valid room IDs configured — cannot start".to_string());
}
slog!(
"[matrix-bot] Listening in {} room(s): {:?}",
target_room_ids.len(),
target_room_ids
);
// Clone values needed by the notification listener and startup announcement
// before they are moved into BotContext.
let notif_room_ids = target_room_ids.clone();
let notif_project_root = project_root.clone();
let announce_room_ids = target_room_ids.clone();
let persisted = load_history(&project_root);
slog!(
"[matrix-bot] Loaded persisted conversation history for {} room(s)",
persisted.len()
);
// Restore persisted ambient rooms from config.
let persisted_ambient: HashSet<String> = config
.ambient_rooms
.iter()
.cloned()
.collect();
if !persisted_ambient.is_empty() {
slog!(
"[matrix-bot] Restored ambient mode for {} room(s): {:?}",
persisted_ambient.len(),
persisted_ambient
);
}
// Create the transport abstraction based on the configured transport type.
let transport: Arc<dyn crate::chat::ChatTransport> = match config.transport.as_str() {
"whatsapp" => {
if config.whatsapp_provider == "twilio" {
slog!("[matrix-bot] Using WhatsApp/Twilio transport");
Arc::new(crate::chat::transport::whatsapp::TwilioWhatsAppTransport::new(
config.twilio_account_sid.clone().unwrap_or_default(),
config.twilio_auth_token.clone().unwrap_or_default(),
config.twilio_whatsapp_number.clone().unwrap_or_default(),
))
} else {
slog!("[matrix-bot] Using WhatsApp/Meta transport");
Arc::new(crate::chat::transport::whatsapp::WhatsAppTransport::new(
config.whatsapp_phone_number_id.clone().unwrap_or_default(),
config.whatsapp_access_token.clone().unwrap_or_default(),
config
.whatsapp_notification_template
.clone()
.unwrap_or_else(|| "pipeline_notification".to_string()),
))
}
}
_ => {
slog!("[matrix-bot] Using Matrix transport");
Arc::new(super::super::transport_impl::MatrixTransport::new(client.clone()))
}
};
let bot_name = config
.display_name
.clone()
.unwrap_or_else(|| "Assistant".to_string());
let announce_bot_name = bot_name.clone();
let timer_store = Arc::new(crate::chat::timer::TimerStore::load(
project_root.join(".storkit").join("timers.json"),
));
crate::chat::timer::spawn_timer_tick_loop(
Arc::clone(&timer_store),
Arc::clone(&agents),
project_root.clone(),
);
let ctx = BotContext {
bot_user_id,
target_room_ids,
project_root,
allowed_users: config.allowed_users,
history: Arc::new(TokioMutex::new(persisted)),
history_size: config.history_size,
bot_sent_event_ids: Arc::new(TokioMutex::new(HashSet::new())),
perm_rx,
pending_perm_replies: Arc::new(TokioMutex::new(HashMap::new())),
permission_timeout_secs: config.permission_timeout_secs,
bot_name,
ambient_rooms: Arc::new(std::sync::Mutex::new(persisted_ambient)),
agents,
htop_sessions: Arc::new(TokioMutex::new(HashMap::new())),
transport: Arc::clone(&transport),
timer_store,
};
slog!("[matrix-bot] Cryptographic identity verification is always ON — commands from unencrypted rooms or unverified devices are rejected");
// Register event handlers and inject shared context.
client.add_event_handler_context(ctx);
client.add_event_handler(on_room_message);
client.add_event_handler(on_to_device_verification_request);
// Spawn the stage-transition notification listener before entering the
// sync loop so it starts receiving watcher events immediately.
let notif_room_id_strings: Vec<String> =
notif_room_ids.iter().map(|r| r.to_string()).collect();
super::super::notifications::spawn_notification_listener(
Arc::clone(&transport),
move || notif_room_id_strings.clone(),
watcher_rx,
notif_project_root,
);
// Spawn a shutdown watcher that sends a best-effort goodbye message to all
// configured rooms when the server is about to stop (SIGINT/SIGTERM or rebuild).
{
let shutdown_transport = Arc::clone(&transport);
let shutdown_rooms: Vec<String> =
announce_room_ids.iter().map(|r| r.to_string()).collect();
let shutdown_bot_name = announce_bot_name.clone();
let mut rx = shutdown_rx;
tokio::spawn(async move {
// Wait until the channel holds Some(reason).
if rx.wait_for(|v| v.is_some()).await.is_ok() {
let reason = rx.borrow().clone();
let notifier = crate::rebuild::BotShutdownNotifier::new(
shutdown_transport,
shutdown_rooms,
shutdown_bot_name,
);
if let Some(r) = reason {
notifier.notify(r).await;
}
}
});
}
// Send a startup announcement to each configured room so users know the
// bot is online. This runs once per process start — the sync loop handles
// reconnects internally so this code is never reached again on a network
// blip or sync resumption.
let announce_msg = format_startup_announcement(&announce_bot_name);
let announce_html = markdown_to_html(&announce_msg);
slog!("[matrix-bot] Sending startup announcement: {announce_msg}");
for room_id in &announce_room_ids {
let room_id_str = room_id.to_string();
if let Err(e) = transport
.send_message(&room_id_str, &announce_msg, &announce_html)
.await
{
slog!("[matrix-bot] Failed to send startup announcement to {room_id}: {e}");
}
}
slog!("[matrix-bot] Starting Matrix sync loop");
// Retry state — shared across `Fn` closure invocations via Arc atomics.
const MAX_BACKOFF_SECS: u64 = 300;
const INITIAL_BACKOFF_SECS: u64 = 5;
let backoff = Arc::new(AtomicU64::new(INITIAL_BACKOFF_SECS));
let was_disconnected = Arc::new(AtomicBool::new(false));
let sync_transport = Arc::clone(&transport);
let sync_rooms: Vec<String> = announce_room_ids.iter().map(|r| r.to_string()).collect();
let sync_bot_name = announce_bot_name.clone();
let backoff_cb = Arc::clone(&backoff);
let was_disconnected_cb = Arc::clone(&was_disconnected);
// Use sync_with_result_callback so transient errors (network blips, DNS
// hiccups, temporary homeserver outages) are handled in the callback
// rather than bubbling up as fatal errors. Fatal errors (HTTP 401/403)
// still terminate the loop and propagate to the caller.
client
.sync_with_result_callback(SyncSettings::default(), move |result| {
let backoff = Arc::clone(&backoff_cb);
let was_disconnected = Arc::clone(&was_disconnected_cb);
let recovery_transport = Arc::clone(&sync_transport);
let recovery_rooms = sync_rooms.clone();
let recovery_bot_name = sync_bot_name.clone();
async move {
match result {
Ok(_) => {
// If we previously lost the connection, announce recovery.
if was_disconnected.swap(false, Ordering::Relaxed) {
backoff.store(INITIAL_BACKOFF_SECS, Ordering::Relaxed);
slog!("[matrix-bot] Reconnected to homeserver — resuming normal operation");
let msg = format!(
"⚡ **{recovery_bot_name}** reconnected to homeserver."
);
let html = format!(
"<p>⚡ <strong>{recovery_bot_name}</strong> reconnected to homeserver.</p>"
);
for room_id in &recovery_rooms {
if let Err(e) = recovery_transport
.send_message(room_id, &msg, &html)
.await
{
slog!(
"[matrix-bot] Failed to send recovery notification to {room_id}: {e}"
);
}
}
}
Ok(LoopCtrl::Continue)
}
Err(e) if is_fatal_sync_error(&e) => Err(e),
Err(e) => {
// Transient error: log, back off, and let the stream retry.
let delay = backoff.load(Ordering::Relaxed);
slog!("[matrix-bot] Sync warning (retrying in {delay}s): {e}");
was_disconnected.store(true, Ordering::Relaxed);
tokio::time::sleep(std::time::Duration::from_secs(delay)).await;
let new_delay = (delay * 2).min(MAX_BACKOFF_SECS);
backoff.store(new_delay, Ordering::Relaxed);
Ok(LoopCtrl::Continue)
}
}
}
})
.await
.map_err(|e| format!("Matrix sync error: {e}"))?;
Ok(())
}
/// Decides whether a sync failure should stop the bot.
///
/// Only client-API errors carrying HTTP status 401 (Unauthorized) or
/// 403 (Forbidden) are fatal: they indicate the session is no longer
/// accepted by the homeserver. Every other error — I/O failures, timeouts,
/// 5xx responses, or errors that are not client-API errors at all — yields
/// `false`, so the sync loop keeps retrying with exponential back-off.
fn is_fatal_sync_error(e: &matrix_sdk::Error) -> bool {
    match e.as_client_api_error() {
        Some(api_err) => matches!(api_err.status_code.as_u16(), 401 | 403),
        None => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mirrors the back-off ceiling used by `run_bot`'s sync retry loop.
    /// That constant is local to `run_bot`, so it cannot be imported here.
    const MAX_BACKOFF_SECS: u64 = 300;

    /// Mirrors the initial back-off delay used by `run_bot`'s sync retry loop.
    const INITIAL_BACKOFF_SECS: u64 = 5;

    /// The doubling-with-ceiling step `run_bot` applies after each transient
    /// sync failure.
    fn next_backoff(delay: u64) -> u64 {
        (delay * 2).min(MAX_BACKOFF_SECS)
    }

    /// An I/O error (e.g. connection refused) must NOT be treated as fatal so
    /// that the sync loop retries rather than shutting the bot down.
    #[test]
    fn io_error_is_not_fatal() {
        let e: matrix_sdk::Error =
            std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "connection refused")
                .into();
        assert!(!is_fatal_sync_error(&e));
    }

    /// Exponential back-off must clamp at MAX_BACKOFF_SECS (300 s) regardless
    /// of how many consecutive failures occur, and stay there once clamped.
    #[test]
    fn backoff_clamps_at_max() {
        let mut delay = INITIAL_BACKOFF_SECS;
        for _ in 0..20 {
            delay = next_backoff(delay);
        }
        assert_eq!(delay, MAX_BACKOFF_SECS);
        assert_eq!(next_backoff(MAX_BACKOFF_SECS), MAX_BACKOFF_SECS);
    }

    /// Back-off must double on every step until it reaches the ceiling.
    #[test]
    fn backoff_doubles_each_step() {
        // Collect every delay strictly below the ceiling.
        let steps: Vec<u64> = std::iter::successors(Some(INITIAL_BACKOFF_SECS), |&d| {
            Some(next_backoff(d)).filter(|&next| next < MAX_BACKOFF_SECS)
        })
        .collect();
        // Pin the full pre-clamp sequence, not just a prefix of it.
        assert_eq!(steps, vec![5, 10, 20, 40, 80, 160]);
    }
}