diff --git a/server/src/matrix/bot.rs b/server/src/matrix/bot.rs index e6c1f23..a9caacc 100644 --- a/server/src/matrix/bot.rs +++ b/server/src/matrix/bot.rs @@ -7,13 +7,14 @@ use matrix_sdk::{ event_handler::Ctx, room::Room, ruma::{ - OwnedRoomId, OwnedUserId, + OwnedEventId, OwnedRoomId, OwnedUserId, events::room::message::{ - MessageType, OriginalSyncRoomMessageEvent, RoomMessageEventContent, + MessageType, OriginalSyncRoomMessageEvent, Relation, + RoomMessageEventContent, RoomMessageEventContentWithoutRelation, }, }, }; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; @@ -66,6 +67,10 @@ pub struct BotContext { pub history: ConversationHistory, /// Maximum number of entries to keep per room before trimming the oldest. pub history_size: usize, + /// Event IDs of messages the bot has sent. Used to detect replies to the + /// bot so it can continue a conversation thread without requiring an + /// explicit `@mention` on every follow-up. + pub bot_sent_event_ids: Arc>>, } // --------------------------------------------------------------------------- @@ -154,6 +159,7 @@ pub async fn run_bot(config: BotConfig, project_root: PathBuf) -> Result<(), Str allowed_users: config.allowed_users, history: Arc::new(TokioMutex::new(HashMap::new())), history_size: config.history_size, + bot_sent_event_ids: Arc::new(TokioMutex::new(HashSet::new())), }; // Register event handler and inject shared context @@ -171,6 +177,64 @@ pub async fn run_bot(config: BotConfig, project_root: PathBuf) -> Result<(), Str Ok(()) } +// --------------------------------------------------------------------------- +// Address-filtering helpers +// --------------------------------------------------------------------------- + +/// Returns `true` if `body` contains a mention of the bot. +/// +/// Two forms are recognised: +/// - The bot's full Matrix user ID (e.g. `@timmy:homeserver.local`) +/// - The bot's local part prefixed with `@` (e.g. `@timmy`) +/// +/// A short mention (`@timmy`) is only counted when it is not immediately +/// followed by an alphanumeric character, hyphen, or underscore — this avoids +/// false-positives where a longer username (e.g. `@timmybot`) shares the same +/// prefix. +pub fn mentions_bot(body: &str, bot_user_id: &OwnedUserId) -> bool { + let full_id = bot_user_id.as_str(); + if body.contains(full_id) { + return true; + } + + let short = format!("@{}", bot_user_id.localpart()); + let mut start = 0; + while let Some(rel) = body[start..].find(short.as_str()) { + let abs = start + rel; + let after = abs + short.len(); + let next = body[after..].chars().next(); + let is_word_end = + next.is_none_or(|c| !c.is_alphanumeric() && c != '-' && c != '_'); + if is_word_end { + return true; + } + start = abs + 1; + } + false +} + +/// Returns `true` if the message's `relates_to` field references an event that +/// the bot previously sent (i.e. the message is a reply or thread-reply to a +/// bot message). +async fn is_reply_to_bot( + relates_to: Option<&Relation>, + bot_sent_event_ids: &TokioMutex>, +) -> bool { + let candidate_ids: Vec<&OwnedEventId> = match relates_to { + Some(Relation::Reply { in_reply_to }) => vec![&in_reply_to.event_id], + Some(Relation::Thread(thread)) => { + let mut ids = vec![&thread.event_id]; + if let Some(irti) = &thread.in_reply_to { + ids.push(&irti.event_id); + } + ids + } + _ => return false, + }; + let guard = bot_sent_event_ids.lock().await; + candidate_ids.iter().any(|id| guard.contains(*id)) +} + // --------------------------------------------------------------------------- // Event handler // --------------------------------------------------------------------------- @@ -215,12 +279,29 @@ async fn on_room_message( } // Only handle plain text messages. - let MessageType::Text(text_content) = ev.content.msgtype else { - return; + let body = match &ev.content.msgtype { + MessageType::Text(t) => t.body.clone(), + _ => return, }; + // Only respond when the bot is directly addressed (mentioned by name/ID) + // or when the message is a reply to one of the bot's own messages. + if !mentions_bot(&body, &ctx.bot_user_id) + && !is_reply_to_bot( + ev.content.relates_to.as_ref(), + &ctx.bot_sent_event_ids, + ) + .await + { + slog!( + "[matrix-bot] Ignoring unaddressed message from {}", + ev.sender + ); + return; + } + let sender = ev.sender.to_string(); - let user_message = text_content.body.clone(); + let user_message = body; slog!("[matrix-bot] Message from {sender}: {user_message}"); // Spawn a separate task so the Matrix sync loop is not blocked while we @@ -292,12 +373,16 @@ async fn handle_message( // Spawn a task to post messages to Matrix as they arrive so we don't // block the LLM stream while waiting for Matrix send round-trips. let post_room = room.clone(); + let sent_ids = Arc::clone(&ctx.bot_sent_event_ids); let post_task = tokio::spawn(async move { while let Some(chunk) = msg_rx.recv().await { let html = markdown_to_html(&chunk); - let _ = post_room + if let Ok(response) = post_room .send(RoomMessageEventContent::text_html(chunk, html)) - .await; + .await + { + sent_ids.lock().await.insert(response.event_id); + } } }); @@ -439,6 +524,112 @@ pub fn drain_complete_paragraphs(buffer: &mut String) -> Vec { mod tests { use super::*; + // -- mentions_bot ------------------------------------------------------- + + fn make_user_id(s: &str) -> OwnedUserId { + s.parse().unwrap() + } + + #[test] + fn mentions_bot_by_full_id() { + let uid = make_user_id("@timmy:homeserver.local"); + assert!(mentions_bot("hello @timmy:homeserver.local can you help?", &uid)); + } + + #[test] + fn mentions_bot_by_localpart_at_start() { + let uid = make_user_id("@timmy:homeserver.local"); + assert!(mentions_bot("@timmy please list open stories", &uid)); + } + + #[test] + fn mentions_bot_by_localpart_mid_sentence() { + let uid = make_user_id("@timmy:homeserver.local"); + assert!(mentions_bot("hey @timmy what's the status?", &uid)); + } + + #[test] + fn mentions_bot_not_mentioned() { + let uid = make_user_id("@timmy:homeserver.local"); + assert!(!mentions_bot("can someone help me with this PR?", &uid)); + } + + #[test] + fn mentions_bot_no_false_positive_longer_username() { + // "@timmybot" must NOT match "@timmy" + let uid = make_user_id("@timmy:homeserver.local"); + assert!(!mentions_bot("hey @timmybot can you help?", &uid)); + } + + #[test] + fn mentions_bot_at_end_of_string() { + let uid = make_user_id("@timmy:homeserver.local"); + assert!(mentions_bot("shoutout to @timmy", &uid)); + } + + #[test] + fn mentions_bot_followed_by_comma() { + let uid = make_user_id("@timmy:homeserver.local"); + assert!(mentions_bot("@timmy, can you help?", &uid)); + } + + // -- is_reply_to_bot ---------------------------------------------------- + + #[tokio::test] + async fn is_reply_to_bot_direct_reply_match() { + let sent: Arc>> = + Arc::new(TokioMutex::new(HashSet::new())); + let event_id: OwnedEventId = "$abc123:example.com".parse().unwrap(); + sent.lock().await.insert(event_id.clone()); + + let in_reply_to = matrix_sdk::ruma::events::relation::InReplyTo::new(event_id); + let relates_to: Option> = + Some(Relation::Reply { in_reply_to }); + + assert!(is_reply_to_bot(relates_to.as_ref(), &sent).await); + } + + #[tokio::test] + async fn is_reply_to_bot_direct_reply_no_match() { + let sent: Arc>> = + Arc::new(TokioMutex::new(HashSet::new())); + // sent is empty — this event was not sent by the bot + + let in_reply_to = matrix_sdk::ruma::events::relation::InReplyTo::new( + "$other:example.com".parse::().unwrap(), + ); + let relates_to: Option> = + Some(Relation::Reply { in_reply_to }); + + assert!(!is_reply_to_bot(relates_to.as_ref(), &sent).await); + } + + #[tokio::test] + async fn is_reply_to_bot_no_relation() { + let sent: Arc>> = + Arc::new(TokioMutex::new(HashSet::new())); + let relates_to: Option> = None; + assert!(!is_reply_to_bot(relates_to.as_ref(), &sent).await); + } + + #[tokio::test] + async fn is_reply_to_bot_thread_root_match() { + let sent: Arc>> = + Arc::new(TokioMutex::new(HashSet::new())); + let root_id: OwnedEventId = "$root123:example.com".parse().unwrap(); + sent.lock().await.insert(root_id.clone()); + + // Thread reply where the thread root is the bot's message + let thread = matrix_sdk::ruma::events::relation::Thread::plain( + root_id, + "$latest:example.com".parse::().unwrap(), + ); + let relates_to: Option> = + Some(Relation::Thread(thread)); + + assert!(is_reply_to_bot(relates_to.as_ref(), &sent).await); + } + // -- markdown_to_html --------------------------------------------------- #[test]