From d20f9b80183419ed24190a7e66fe1ff1c6e57d4b Mon Sep 17 00:00:00 2001 From: Dave Date: Wed, 25 Feb 2026 17:00:29 +0000 Subject: [PATCH] Fixed up @ mentions on the bot --- server/src/matrix/bot.rs | 173 ++++++++++++++++++++++----------------- 1 file changed, 100 insertions(+), 73 deletions(-) diff --git a/server/src/matrix/bot.rs b/server/src/matrix/bot.rs index d292537..1dba292 100644 --- a/server/src/matrix/bot.rs +++ b/server/src/matrix/bot.rs @@ -1,6 +1,5 @@ use crate::llm::providers::claude_code::{ClaudeCodeProvider, ClaudeCodeResult}; use crate::slog; -use pulldown_cmark::{Options, Parser, html}; use matrix_sdk::{ Client, config::SyncSettings, @@ -9,17 +8,18 @@ use matrix_sdk::{ ruma::{ OwnedEventId, OwnedRoomId, OwnedUserId, events::room::message::{ - MessageType, OriginalSyncRoomMessageEvent, Relation, - RoomMessageEventContent, RoomMessageEventContentWithoutRelation, + MessageType, OriginalSyncRoomMessageEvent, Relation, RoomMessageEventContent, + RoomMessageEventContentWithoutRelation, }, }, }; +use pulldown_cmark::{Options, Parser, html}; use std::collections::{HashMap, HashSet}; use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; -use tokio::sync::watch; use tokio::sync::Mutex as TokioMutex; +use tokio::sync::watch; use super::config::BotConfig; @@ -112,10 +112,7 @@ pub async fn run_bot(config: BotConfig, project_root: PathBuf) -> Result<(), Str ); } - slog!( - "[matrix-bot] Allowed users: {:?}", - config.allowed_users - ); + slog!("[matrix-bot] Allowed users: {:?}", config.allowed_users); // Parse and join all configured rooms. let mut target_room_ids: Vec = Vec::new(); @@ -181,30 +178,46 @@ pub async fn run_bot(config: BotConfig, project_root: PathBuf) -> Result<(), Str // Address-filtering helpers // --------------------------------------------------------------------------- -/// Returns `true` if `body` contains a mention of the bot. +/// Returns `true` if the message mentions the bot. /// -/// Two forms are recognised: -/// - The bot's full Matrix user ID (e.g. `@timmy:homeserver.local`) -/// - The bot's local part prefixed with `@` (e.g. `@timmy`) +/// Checks both the plain-text `body` and an optional `formatted_body` (HTML). +/// Recognised forms: +/// - The bot's full Matrix user ID (e.g. `@timmy:homeserver.local`) in either body +/// - The localpart with `@` prefix (e.g. `@timmy`) with word-boundary check +/// - A `matrix.to` link containing the user ID (in `formatted_body`) /// -/// A short mention (`@timmy`) is only counted when it is not immediately -/// followed by an alphanumeric character, hyphen, or underscore — this avoids -/// false-positives where a longer username (e.g. `@timmybot`) shares the same -/// prefix. -pub fn mentions_bot(body: &str, bot_user_id: &OwnedUserId) -> bool { +/// Short mentions are only counted when not immediately followed by an +/// alphanumeric character, hyphen, or underscore to avoid false positives. +pub fn mentions_bot(body: &str, formatted_body: Option<&str>, bot_user_id: &OwnedUserId) -> bool { let full_id = bot_user_id.as_str(); + let localpart = bot_user_id.localpart(); + + // Check formatted_body for a matrix.to link containing the bot's user ID. + if formatted_body.is_some_and(|html| html.contains(full_id)) { + return true; + } + + // Check plain body for the full ID. if body.contains(full_id) { return true; } - let short = format!("@{}", bot_user_id.localpart()); + // Check plain body for @localpart (e.g. "@timmy") with word boundaries. + if contains_word(body, &format!("@{localpart}")) { + return true; + } + + false +} + +/// Returns `true` if `haystack` contains `needle` at a word boundary. +fn contains_word(haystack: &str, needle: &str) -> bool { let mut start = 0; - while let Some(rel) = body[start..].find(short.as_str()) { + while let Some(rel) = haystack[start..].find(needle) { let abs = start + rel; - let after = abs + short.len(); - let next = body[after..].chars().next(); - let is_word_end = - next.is_none_or(|c| !c.is_alphanumeric() && c != '-' && c != '_'); + let after = abs + needle.len(); + let next = haystack[after..].chars().next(); + let is_word_end = next.is_none_or(|c| !c.is_alphanumeric() && c != '-' && c != '_'); if is_word_end { return true; } @@ -241,11 +254,7 @@ async fn is_reply_to_bot( /// Matrix event handler for room messages. Each invocation spawns an /// independent task so the sync loop is not blocked by LLM calls. -async fn on_room_message( - ev: OriginalSyncRoomMessageEvent, - room: Room, - Ctx(ctx): Ctx, -) { +async fn on_room_message(ev: OriginalSyncRoomMessageEvent, room: Room, Ctx(ctx): Ctx) { let incoming_room_id = room.room_id().to_owned(); slog!( @@ -255,11 +264,7 @@ async fn on_room_message( ); // Only handle messages from rooms we are configured to listen in. - if !ctx - .target_room_ids - .iter() - .any(|r| r == &incoming_room_id) - { + if !ctx.target_room_ids.iter().any(|r| r == &incoming_room_id) { slog!("[matrix-bot] Ignoring message from unconfigured room {incoming_room_id}"); return; } @@ -279,19 +284,15 @@ async fn on_room_message( } // Only handle plain text messages. - let body = match &ev.content.msgtype { - MessageType::Text(t) => t.body.clone(), + let (body, formatted_body) = match &ev.content.msgtype { + MessageType::Text(t) => (t.body.clone(), t.formatted.as_ref().map(|f| f.body.clone())), _ => return, }; // Only respond when the bot is directly addressed (mentioned by name/ID) // or when the message is a reply to one of the bot's own messages. - if !mentions_bot(&body, &ctx.bot_user_id) - && !is_reply_to_bot( - ev.content.relates_to.as_ref(), - &ctx.bot_sent_event_ids, - ) - .await + if !mentions_bot(&body, formatted_body.as_deref(), &ctx.bot_user_id) + && !is_reply_to_bot(ev.content.relates_to.as_ref(), &ctx.bot_sent_event_ids).await { slog!( "[matrix-bot] Ignoring unaddressed message from {}", @@ -358,8 +359,7 @@ async fn handle_message( }; // Build the prompt with conversation context. - let prompt_with_context = - build_context_prefix(&history_snapshot, &sender, &user_message); + let prompt_with_context = build_context_prefix(&history_snapshot, &sender, &user_message); let provider = ClaudeCodeProvider::new(); let (cancel_tx, mut cancel_rx) = watch::channel(false); @@ -566,44 +566,52 @@ mod tests { #[test] fn mentions_bot_by_full_id() { let uid = make_user_id("@timmy:homeserver.local"); - assert!(mentions_bot("hello @timmy:homeserver.local can you help?", &uid)); + assert!(mentions_bot( + "hello @timmy:homeserver.local can you help?", + None, + &uid + )); } #[test] fn mentions_bot_by_localpart_at_start() { let uid = make_user_id("@timmy:homeserver.local"); - assert!(mentions_bot("@timmy please list open stories", &uid)); + assert!(mentions_bot("@timmy please list open stories", None, &uid)); } #[test] fn mentions_bot_by_localpart_mid_sentence() { let uid = make_user_id("@timmy:homeserver.local"); - assert!(mentions_bot("hey @timmy what's the status?", &uid)); + assert!(mentions_bot("hey @timmy what's the status?", None, &uid)); } #[test] fn mentions_bot_not_mentioned() { let uid = make_user_id("@timmy:homeserver.local"); - assert!(!mentions_bot("can someone help me with this PR?", &uid)); + assert!(!mentions_bot( + "can someone help me with this PR?", + None, + &uid + )); } #[test] fn mentions_bot_no_false_positive_longer_username() { // "@timmybot" must NOT match "@timmy" let uid = make_user_id("@timmy:homeserver.local"); - assert!(!mentions_bot("hey @timmybot can you help?", &uid)); + assert!(!mentions_bot("hey @timmybot can you help?", None, &uid)); } #[test] fn mentions_bot_at_end_of_string() { let uid = make_user_id("@timmy:homeserver.local"); - assert!(mentions_bot("shoutout to @timmy", &uid)); + assert!(mentions_bot("shoutout to @timmy", None, &uid)); } #[test] fn mentions_bot_followed_by_comma() { let uid = make_user_id("@timmy:homeserver.local"); - assert!(mentions_bot("@timmy, can you help?", &uid)); + assert!(mentions_bot("@timmy, can you help?", None, &uid)); } // -- is_reply_to_bot ---------------------------------------------------- @@ -668,20 +676,29 @@ mod tests { #[test] fn markdown_to_html_bold() { let html = markdown_to_html("**bold**"); - assert!(html.contains("bold"), "expected : {html}"); + assert!( + html.contains("bold"), + "expected : {html}" + ); } #[test] fn markdown_to_html_unordered_list() { let html = markdown_to_html("- item one\n- item two"); assert!(html.contains("
    "), "expected
      : {html}"); - assert!(html.contains("
    • item one
    • "), "expected list item: {html}"); + assert!( + html.contains("
    • item one
    • "), + "expected list item: {html}" + ); } #[test] fn markdown_to_html_inline_code() { let html = markdown_to_html("`inline_code()`"); - assert!(html.contains("inline_code()"), "expected : {html}"); + assert!( + html.contains("inline_code()"), + "expected : {html}" + ); } #[test] @@ -689,13 +706,19 @@ mod tests { let html = markdown_to_html("```rust\nfn main() {}\n```"); assert!(html.contains("
      "), "expected 
      : {html}");
               assert!(html.contains(" inside pre: {html}");
      -        assert!(html.contains("fn main() {}"), "expected code content: {html}");
      +        assert!(
      +            html.contains("fn main() {}"),
      +            "expected code content: {html}"
      +        );
           }
       
           #[test]
           fn markdown_to_html_plain_text_passthrough() {
               let html = markdown_to_html("Hello, world!");
      -        assert!(html.contains("Hello, world!"), "expected plain text passthrough: {html}");
      +        assert!(
      +            html.contains("Hello, world!"),
      +            "expected plain text passthrough: {html}"
      +        );
           }
       
           // -- bot_context_is_clone -----------------------------------------------
      @@ -773,8 +796,9 @@ mod tests {
               // A blank line inside a fenced code block must NOT trigger a split.
               // Before the fix the function would split at the blank line and the
               // second half would be sent without the opening fence, breaking rendering.
      -        let mut buf = "```rust\nfn foo() {\n    let x = 1;\n\n    let y = 2;\n}\n```\n\nNext paragraph."
      -            .to_string();
      +        let mut buf =
      +            "```rust\nfn foo() {\n    let x = 1;\n\n    let y = 2;\n}\n```\n\nNext paragraph."
      +                .to_string();
               let paras = drain_complete_paragraphs(&mut buf);
               assert_eq!(
                   paras.len(),
      @@ -797,8 +821,7 @@ mod tests {
           #[test]
           fn drain_complete_paragraphs_text_before_and_after_fenced_block() {
               // Text paragraph, then a code block with an internal blank line, then more text.
      -        let mut buf =
      -            "Before\n\n```\ncode\n\nmore code\n```\n\nAfter".to_string();
      +        let mut buf = "Before\n\n```\ncode\n\nmore code\n```\n\nAfter".to_string();
               let paras = drain_complete_paragraphs(&mut buf);
               assert_eq!(paras.len(), 2, "expected two paragraphs: {paras:?}");
               assert_eq!(paras[0], "Before");
      @@ -892,8 +915,7 @@ mod tests {
       
           #[tokio::test]
           async fn history_trims_to_configured_size() {
      -        let history: ConversationHistory =
      -            Arc::new(TokioMutex::new(HashMap::new()));
      +        let history: ConversationHistory = Arc::new(TokioMutex::new(HashMap::new()));
               let room_id: OwnedRoomId = "!test:example.com".parse().unwrap();
               let history_size = 4usize; // keep at most 4 entries
       
      @@ -935,23 +957,28 @@ mod tests {
       
           #[tokio::test]
           async fn each_room_has_independent_history() {
      -        let history: ConversationHistory =
      -            Arc::new(TokioMutex::new(HashMap::new()));
      +        let history: ConversationHistory = Arc::new(TokioMutex::new(HashMap::new()));
               let room_a: OwnedRoomId = "!room_a:example.com".parse().unwrap();
               let room_b: OwnedRoomId = "!room_b:example.com".parse().unwrap();
       
               {
                   let mut guard = history.lock().await;
      -            guard.entry(room_a.clone()).or_default().push(ConversationEntry {
      -                role: ConversationRole::User,
      -                sender: "@alice:example.com".to_string(),
      -                content: "Room A message".to_string(),
      -            });
      -            guard.entry(room_b.clone()).or_default().push(ConversationEntry {
      -                role: ConversationRole::User,
      -                sender: "@bob:example.com".to_string(),
      -                content: "Room B message".to_string(),
      -            });
      +            guard
      +                .entry(room_a.clone())
      +                .or_default()
      +                .push(ConversationEntry {
      +                    role: ConversationRole::User,
      +                    sender: "@alice:example.com".to_string(),
      +                    content: "Room A message".to_string(),
      +                });
      +            guard
      +                .entry(room_b.clone())
      +                .or_default()
      +                .push(ConversationEntry {
      +                    role: ConversationRole::User,
      +                    sender: "@bob:example.com".to_string(),
      +                    content: "Room B message".to_string(),
      +                });
               }
       
               let guard = history.lock().await;