Files
huskies/server/src/chat/transport/matrix/bot/mentions.rs
T

199 lines
7.0 KiB
Rust
Raw Normal View History

use matrix_sdk::ruma::events::room::message::{Relation, RoomMessageEventContentWithoutRelation};
use matrix_sdk::ruma::{OwnedEventId, OwnedUserId};
use std::collections::HashSet;
use tokio::sync::Mutex as TokioMutex;
/// Returns `true` if the message mentions the bot.
///
/// Checks both the plain-text `body` and an optional `formatted_body` (HTML).
/// Recognised forms:
/// - The bot's full Matrix user ID (e.g. `@timmy:homeserver.local`) in either body
/// - The percent-encoded full user ID (e.g. `%40timmy%3Ahomeserver.local`) in
///   `formatted_body`, as found in `matrix.to` links from clients that encode
///   the identifier
/// - The localpart with `@` prefix (e.g. `@timmy`) in `body`
///
/// Short mentions are only counted when not immediately followed by an
/// alphanumeric character, hyphen, or underscore to avoid false positives.
/// All checks are plain substring searches; the HTML is not parsed.
pub fn mentions_bot(body: &str, formatted_body: Option<&str>, bot_user_id: &OwnedUserId) -> bool {
    let full_id = bot_user_id.as_str();
    // Pills in `formatted_body` are `matrix.to` links. The Matrix spec asks
    // clients to percent-encode the identifier in those URIs, while some clients
    // emit it raw, so accept either form. The encoded form is only built when the
    // raw form is absent.
    let in_formatted = formatted_body.is_some_and(|html| {
        html.contains(full_id) || html.contains(&percent_encode_user_id(full_id))
    });
    in_formatted
        || body.contains(full_id)
        || contains_word(body, &format!("@{}", bot_user_id.localpart()))
}

/// Percent-encodes every byte of `id` outside the RFC 3986 "unreserved" set
/// (`A-Z a-z 0-9 - . _ ~`), using uppercase hex digits.
///
/// Example: `@timmy:homeserver.local` becomes `%40timmy%3Ahomeserver.local`.
/// Links that use lowercase hex digits (e.g. `%3a`) are not matched.
fn percent_encode_user_id(id: &str) -> String {
    let mut encoded = String::with_capacity(id.len() * 3);
    for byte in id.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push_str(&format!("%{byte:02X}"));
        }
    }
    encoded
}
/// Returns `true` if `haystack` contains `needle` followed by a word boundary.
///
/// An occurrence counts when it is at the end of `haystack` or is immediately
/// followed by a character that is not alphanumeric, `-`, or `_`. Only the
/// trailing edge is checked; the character before the occurrence is not
/// inspected. Overlapping occurrences are all considered.
pub(super) fn contains_word(haystack: &str, needle: &str) -> bool {
    let mut start = 0;
    while let Some(rel) = haystack[start..].find(needle) {
        let abs = start + rel;
        let next = haystack[abs + needle.len()..].chars().next();
        let is_word_end = next.is_none_or(|c| !c.is_alphanumeric() && c != '-' && c != '_');
        if is_word_end {
            return true;
        }
        // Step past the whole character at the match position, not a single
        // byte: `abs + 1` would land inside a multi-byte char (e.g. a needle
        // starting with `é`) and make the next slice panic.
        match haystack[abs..].chars().next() {
            Some(c) => start = abs + c.len_utf8(),
            None => break,
        }
    }
    false
}
/// Reports whether a message is a reply (direct or in a thread) to something
/// the bot itself sent.
///
/// - `Relation::Reply`: the replied-to event ID is checked.
/// - `Relation::Thread`: the thread root is checked, plus the thread's
///   `in_reply_to` event ID when one is present.
/// - Any other relation, or no relation, yields `false` without locking.
///
/// `bot_sent_event_ids` is locked only for the duration of the membership test.
pub(super) async fn is_reply_to_bot(
    relates_to: Option<&Relation<RoomMessageEventContentWithoutRelation>>,
    bot_sent_event_ids: &TokioMutex<HashSet<OwnedEventId>>,
) -> bool {
    let (primary, secondary): (&OwnedEventId, Option<&OwnedEventId>) = match relates_to {
        Some(Relation::Reply { in_reply_to }) => (&in_reply_to.event_id, None),
        Some(Relation::Thread(thread)) => (
            &thread.event_id,
            thread.in_reply_to.as_ref().map(|reply| &reply.event_id),
        ),
        _ => return false,
    };
    let sent = bot_sent_event_ids.lock().await;
    sent.contains(primary) || secondary.is_some_and(|id| sent.contains(id))
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use matrix_sdk::ruma::events::relation::{InReplyTo, Thread};

    type MaybeRelation = Option<Relation<RoomMessageEventContentWithoutRelation>>;

    fn make_user_id(s: &str) -> OwnedUserId {
        s.parse().unwrap()
    }

    fn make_event_id(s: &str) -> OwnedEventId {
        s.parse().unwrap()
    }

    /// Builds the bot's sent-event set, pre-populated with `ids`.
    fn sent_with(ids: &[&str]) -> TokioMutex<HashSet<OwnedEventId>> {
        TokioMutex::new(ids.iter().copied().map(make_event_id).collect())
    }

    /// A direct (non-thread) reply to `target`.
    fn reply_to(target: &str) -> MaybeRelation {
        Some(Relation::Reply {
            in_reply_to: InReplyTo::new(make_event_id(target)),
        })
    }

    /// A thread relation rooted at `root` whose `in_reply_to` points at `latest`.
    fn thread_reply(root: &str, latest: &str) -> MaybeRelation {
        Some(Relation::Thread(Thread::plain(
            make_event_id(root),
            make_event_id(latest),
        )))
    }

    // -- mentions_bot -------------------------------------------------------
    #[test]
    fn mentions_bot_by_full_id() {
        let uid = make_user_id("@timmy:homeserver.local");
        assert!(mentions_bot(
            "hello @timmy:homeserver.local can you help?",
            None,
            &uid
        ));
    }
    #[test]
    fn mentions_bot_by_localpart_at_start() {
        let uid = make_user_id("@timmy:homeserver.local");
        assert!(mentions_bot("@timmy please list open stories", None, &uid));
    }
    #[test]
    fn mentions_bot_by_localpart_mid_sentence() {
        let uid = make_user_id("@timmy:homeserver.local");
        assert!(mentions_bot("hey @timmy what's the status?", None, &uid));
    }
    #[test]
    fn mentions_bot_not_mentioned() {
        let uid = make_user_id("@timmy:homeserver.local");
        assert!(!mentions_bot(
            "can someone help me with this PR?",
            None,
            &uid
        ));
    }
    #[test]
    fn mentions_bot_no_false_positive_longer_username() {
        // "@timmybot" must NOT match "@timmy"
        let uid = make_user_id("@timmy:homeserver.local");
        assert!(!mentions_bot("hey @timmybot can you help?", None, &uid));
    }
    #[test]
    fn mentions_bot_at_end_of_string() {
        let uid = make_user_id("@timmy:homeserver.local");
        assert!(mentions_bot("shoutout to @timmy", None, &uid));
    }
    #[test]
    fn mentions_bot_followed_by_comma() {
        let uid = make_user_id("@timmy:homeserver.local");
        assert!(mentions_bot("@timmy, can you help?", None, &uid));
    }
    #[test]
    fn mentions_bot_via_matrix_to_link_in_formatted_body() {
        // A pill: the plain body only carries the display name.
        let uid = make_user_id("@timmy:homeserver.local");
        assert!(mentions_bot(
            "Timmy: can you help?",
            Some("<a href=\"https://matrix.to/#/@timmy:homeserver.local\">Timmy</a>: can you help?"),
            &uid
        ));
    }
    #[test]
    fn mentions_bot_formatted_body_without_bot() {
        let uid = make_user_id("@timmy:homeserver.local");
        assert!(!mentions_bot(
            "hello there",
            Some("<b>hello</b> there"),
            &uid
        ));
    }

    // -- is_reply_to_bot ----------------------------------------------------
    #[tokio::test]
    async fn is_reply_to_bot_direct_reply_match() {
        let sent = sent_with(&["$abc123:example.com"]);
        let relates_to = reply_to("$abc123:example.com");
        assert!(is_reply_to_bot(relates_to.as_ref(), &sent).await);
    }
    #[tokio::test]
    async fn is_reply_to_bot_direct_reply_no_match() {
        // sent is empty — this event was not sent by the bot
        let sent = sent_with(&[]);
        let relates_to = reply_to("$other:example.com");
        assert!(!is_reply_to_bot(relates_to.as_ref(), &sent).await);
    }
    #[tokio::test]
    async fn is_reply_to_bot_no_relation() {
        let sent = sent_with(&["$abc123:example.com"]);
        let relates_to: MaybeRelation = None;
        assert!(!is_reply_to_bot(relates_to.as_ref(), &sent).await);
    }
    #[tokio::test]
    async fn is_reply_to_bot_thread_root_match() {
        // Thread reply where the thread root is the bot's message
        let sent = sent_with(&["$root123:example.com"]);
        let relates_to = thread_reply("$root123:example.com", "$latest:example.com");
        assert!(is_reply_to_bot(relates_to.as_ref(), &sent).await);
    }
    #[tokio::test]
    async fn is_reply_to_bot_thread_in_reply_to_match() {
        // Thread rooted elsewhere, but the in-thread reply target is the bot's message
        let sent = sent_with(&["$latest:example.com"]);
        let relates_to = thread_reply("$root123:example.com", "$latest:example.com");
        assert!(is_reply_to_bot(relates_to.as_ref(), &sent).await);
    }
    #[tokio::test]
    async fn is_reply_to_bot_thread_no_match() {
        let sent = sent_with(&["$unrelated:example.com"]);
        let relates_to = thread_reply("$root123:example.com", "$latest:example.com");
        assert!(!is_reply_to_bot(relates_to.as_ref(), &sent).await);
    }
}