Fixed up @ mentions on the bot
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
use crate::llm::providers::claude_code::{ClaudeCodeProvider, ClaudeCodeResult};
|
use crate::llm::providers::claude_code::{ClaudeCodeProvider, ClaudeCodeResult};
|
||||||
use crate::slog;
|
use crate::slog;
|
||||||
use pulldown_cmark::{Options, Parser, html};
|
|
||||||
use matrix_sdk::{
|
use matrix_sdk::{
|
||||||
Client,
|
Client,
|
||||||
config::SyncSettings,
|
config::SyncSettings,
|
||||||
@@ -9,17 +8,18 @@ use matrix_sdk::{
|
|||||||
ruma::{
|
ruma::{
|
||||||
OwnedEventId, OwnedRoomId, OwnedUserId,
|
OwnedEventId, OwnedRoomId, OwnedUserId,
|
||||||
events::room::message::{
|
events::room::message::{
|
||||||
MessageType, OriginalSyncRoomMessageEvent, Relation,
|
MessageType, OriginalSyncRoomMessageEvent, Relation, RoomMessageEventContent,
|
||||||
RoomMessageEventContent, RoomMessageEventContentWithoutRelation,
|
RoomMessageEventContentWithoutRelation,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
use pulldown_cmark::{Options, Parser, html};
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use tokio::sync::watch;
|
|
||||||
use tokio::sync::Mutex as TokioMutex;
|
use tokio::sync::Mutex as TokioMutex;
|
||||||
|
use tokio::sync::watch;
|
||||||
|
|
||||||
use super::config::BotConfig;
|
use super::config::BotConfig;
|
||||||
|
|
||||||
@@ -112,10 +112,7 @@ pub async fn run_bot(config: BotConfig, project_root: PathBuf) -> Result<(), Str
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
slog!(
|
slog!("[matrix-bot] Allowed users: {:?}", config.allowed_users);
|
||||||
"[matrix-bot] Allowed users: {:?}",
|
|
||||||
config.allowed_users
|
|
||||||
);
|
|
||||||
|
|
||||||
// Parse and join all configured rooms.
|
// Parse and join all configured rooms.
|
||||||
let mut target_room_ids: Vec<OwnedRoomId> = Vec::new();
|
let mut target_room_ids: Vec<OwnedRoomId> = Vec::new();
|
||||||
@@ -181,30 +178,46 @@ pub async fn run_bot(config: BotConfig, project_root: PathBuf) -> Result<(), Str
|
|||||||
// Address-filtering helpers
|
// Address-filtering helpers
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
/// Returns `true` if `body` contains a mention of the bot.
|
/// Returns `true` if the message mentions the bot.
|
||||||
///
|
///
|
||||||
/// Two forms are recognised:
|
/// Checks both the plain-text `body` and an optional `formatted_body` (HTML).
|
||||||
/// - The bot's full Matrix user ID (e.g. `@timmy:homeserver.local`)
|
/// Recognised forms:
|
||||||
/// - The bot's local part prefixed with `@` (e.g. `@timmy`)
|
/// - The bot's full Matrix user ID (e.g. `@timmy:homeserver.local`) in either body
|
||||||
|
/// - The localpart with `@` prefix (e.g. `@timmy`) with word-boundary check
|
||||||
|
/// - A `matrix.to` link containing the user ID (in `formatted_body`)
|
||||||
///
|
///
|
||||||
/// A short mention (`@timmy`) is only counted when it is not immediately
|
/// Short mentions are only counted when not immediately followed by an
|
||||||
/// followed by an alphanumeric character, hyphen, or underscore — this avoids
|
/// alphanumeric character, hyphen, or underscore to avoid false positives.
|
||||||
/// false-positives where a longer username (e.g. `@timmybot`) shares the same
|
pub fn mentions_bot(body: &str, formatted_body: Option<&str>, bot_user_id: &OwnedUserId) -> bool {
|
||||||
/// prefix.
|
|
||||||
pub fn mentions_bot(body: &str, bot_user_id: &OwnedUserId) -> bool {
|
|
||||||
let full_id = bot_user_id.as_str();
|
let full_id = bot_user_id.as_str();
|
||||||
|
let localpart = bot_user_id.localpart();
|
||||||
|
|
||||||
|
// Check formatted_body for a matrix.to link containing the bot's user ID.
|
||||||
|
if formatted_body.is_some_and(|html| html.contains(full_id)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check plain body for the full ID.
|
||||||
if body.contains(full_id) {
|
if body.contains(full_id) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
let short = format!("@{}", bot_user_id.localpart());
|
// Check plain body for @localpart (e.g. "@timmy") with word boundaries.
|
||||||
|
if contains_word(body, &format!("@{localpart}")) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns `true` if `haystack` contains `needle` at a word boundary.
|
||||||
|
fn contains_word(haystack: &str, needle: &str) -> bool {
|
||||||
let mut start = 0;
|
let mut start = 0;
|
||||||
while let Some(rel) = body[start..].find(short.as_str()) {
|
while let Some(rel) = haystack[start..].find(needle) {
|
||||||
let abs = start + rel;
|
let abs = start + rel;
|
||||||
let after = abs + short.len();
|
let after = abs + needle.len();
|
||||||
let next = body[after..].chars().next();
|
let next = haystack[after..].chars().next();
|
||||||
let is_word_end =
|
let is_word_end = next.is_none_or(|c| !c.is_alphanumeric() && c != '-' && c != '_');
|
||||||
next.is_none_or(|c| !c.is_alphanumeric() && c != '-' && c != '_');
|
|
||||||
if is_word_end {
|
if is_word_end {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@@ -241,11 +254,7 @@ async fn is_reply_to_bot(
|
|||||||
|
|
||||||
/// Matrix event handler for room messages. Each invocation spawns an
|
/// Matrix event handler for room messages. Each invocation spawns an
|
||||||
/// independent task so the sync loop is not blocked by LLM calls.
|
/// independent task so the sync loop is not blocked by LLM calls.
|
||||||
async fn on_room_message(
|
async fn on_room_message(ev: OriginalSyncRoomMessageEvent, room: Room, Ctx(ctx): Ctx<BotContext>) {
|
||||||
ev: OriginalSyncRoomMessageEvent,
|
|
||||||
room: Room,
|
|
||||||
Ctx(ctx): Ctx<BotContext>,
|
|
||||||
) {
|
|
||||||
let incoming_room_id = room.room_id().to_owned();
|
let incoming_room_id = room.room_id().to_owned();
|
||||||
|
|
||||||
slog!(
|
slog!(
|
||||||
@@ -255,11 +264,7 @@ async fn on_room_message(
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Only handle messages from rooms we are configured to listen in.
|
// Only handle messages from rooms we are configured to listen in.
|
||||||
if !ctx
|
if !ctx.target_room_ids.iter().any(|r| r == &incoming_room_id) {
|
||||||
.target_room_ids
|
|
||||||
.iter()
|
|
||||||
.any(|r| r == &incoming_room_id)
|
|
||||||
{
|
|
||||||
slog!("[matrix-bot] Ignoring message from unconfigured room {incoming_room_id}");
|
slog!("[matrix-bot] Ignoring message from unconfigured room {incoming_room_id}");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -279,19 +284,15 @@ async fn on_room_message(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Only handle plain text messages.
|
// Only handle plain text messages.
|
||||||
let body = match &ev.content.msgtype {
|
let (body, formatted_body) = match &ev.content.msgtype {
|
||||||
MessageType::Text(t) => t.body.clone(),
|
MessageType::Text(t) => (t.body.clone(), t.formatted.as_ref().map(|f| f.body.clone())),
|
||||||
_ => return,
|
_ => return,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Only respond when the bot is directly addressed (mentioned by name/ID)
|
// Only respond when the bot is directly addressed (mentioned by name/ID)
|
||||||
// or when the message is a reply to one of the bot's own messages.
|
// or when the message is a reply to one of the bot's own messages.
|
||||||
if !mentions_bot(&body, &ctx.bot_user_id)
|
if !mentions_bot(&body, formatted_body.as_deref(), &ctx.bot_user_id)
|
||||||
&& !is_reply_to_bot(
|
&& !is_reply_to_bot(ev.content.relates_to.as_ref(), &ctx.bot_sent_event_ids).await
|
||||||
ev.content.relates_to.as_ref(),
|
|
||||||
&ctx.bot_sent_event_ids,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
{
|
||||||
slog!(
|
slog!(
|
||||||
"[matrix-bot] Ignoring unaddressed message from {}",
|
"[matrix-bot] Ignoring unaddressed message from {}",
|
||||||
@@ -358,8 +359,7 @@ async fn handle_message(
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Build the prompt with conversation context.
|
// Build the prompt with conversation context.
|
||||||
let prompt_with_context =
|
let prompt_with_context = build_context_prefix(&history_snapshot, &sender, &user_message);
|
||||||
build_context_prefix(&history_snapshot, &sender, &user_message);
|
|
||||||
|
|
||||||
let provider = ClaudeCodeProvider::new();
|
let provider = ClaudeCodeProvider::new();
|
||||||
let (cancel_tx, mut cancel_rx) = watch::channel(false);
|
let (cancel_tx, mut cancel_rx) = watch::channel(false);
|
||||||
@@ -566,44 +566,52 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn mentions_bot_by_full_id() {
|
fn mentions_bot_by_full_id() {
|
||||||
let uid = make_user_id("@timmy:homeserver.local");
|
let uid = make_user_id("@timmy:homeserver.local");
|
||||||
assert!(mentions_bot("hello @timmy:homeserver.local can you help?", &uid));
|
assert!(mentions_bot(
|
||||||
|
"hello @timmy:homeserver.local can you help?",
|
||||||
|
None,
|
||||||
|
&uid
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn mentions_bot_by_localpart_at_start() {
|
fn mentions_bot_by_localpart_at_start() {
|
||||||
let uid = make_user_id("@timmy:homeserver.local");
|
let uid = make_user_id("@timmy:homeserver.local");
|
||||||
assert!(mentions_bot("@timmy please list open stories", &uid));
|
assert!(mentions_bot("@timmy please list open stories", None, &uid));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn mentions_bot_by_localpart_mid_sentence() {
|
fn mentions_bot_by_localpart_mid_sentence() {
|
||||||
let uid = make_user_id("@timmy:homeserver.local");
|
let uid = make_user_id("@timmy:homeserver.local");
|
||||||
assert!(mentions_bot("hey @timmy what's the status?", &uid));
|
assert!(mentions_bot("hey @timmy what's the status?", None, &uid));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn mentions_bot_not_mentioned() {
|
fn mentions_bot_not_mentioned() {
|
||||||
let uid = make_user_id("@timmy:homeserver.local");
|
let uid = make_user_id("@timmy:homeserver.local");
|
||||||
assert!(!mentions_bot("can someone help me with this PR?", &uid));
|
assert!(!mentions_bot(
|
||||||
|
"can someone help me with this PR?",
|
||||||
|
None,
|
||||||
|
&uid
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn mentions_bot_no_false_positive_longer_username() {
|
fn mentions_bot_no_false_positive_longer_username() {
|
||||||
// "@timmybot" must NOT match "@timmy"
|
// "@timmybot" must NOT match "@timmy"
|
||||||
let uid = make_user_id("@timmy:homeserver.local");
|
let uid = make_user_id("@timmy:homeserver.local");
|
||||||
assert!(!mentions_bot("hey @timmybot can you help?", &uid));
|
assert!(!mentions_bot("hey @timmybot can you help?", None, &uid));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn mentions_bot_at_end_of_string() {
|
fn mentions_bot_at_end_of_string() {
|
||||||
let uid = make_user_id("@timmy:homeserver.local");
|
let uid = make_user_id("@timmy:homeserver.local");
|
||||||
assert!(mentions_bot("shoutout to @timmy", &uid));
|
assert!(mentions_bot("shoutout to @timmy", None, &uid));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn mentions_bot_followed_by_comma() {
|
fn mentions_bot_followed_by_comma() {
|
||||||
let uid = make_user_id("@timmy:homeserver.local");
|
let uid = make_user_id("@timmy:homeserver.local");
|
||||||
assert!(mentions_bot("@timmy, can you help?", &uid));
|
assert!(mentions_bot("@timmy, can you help?", None, &uid));
|
||||||
}
|
}
|
||||||
|
|
||||||
// -- is_reply_to_bot ----------------------------------------------------
|
// -- is_reply_to_bot ----------------------------------------------------
|
||||||
@@ -668,20 +676,29 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn markdown_to_html_bold() {
|
fn markdown_to_html_bold() {
|
||||||
let html = markdown_to_html("**bold**");
|
let html = markdown_to_html("**bold**");
|
||||||
assert!(html.contains("<strong>bold</strong>"), "expected <strong>: {html}");
|
assert!(
|
||||||
|
html.contains("<strong>bold</strong>"),
|
||||||
|
"expected <strong>: {html}"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn markdown_to_html_unordered_list() {
|
fn markdown_to_html_unordered_list() {
|
||||||
let html = markdown_to_html("- item one\n- item two");
|
let html = markdown_to_html("- item one\n- item two");
|
||||||
assert!(html.contains("<ul>"), "expected <ul>: {html}");
|
assert!(html.contains("<ul>"), "expected <ul>: {html}");
|
||||||
assert!(html.contains("<li>item one</li>"), "expected list item: {html}");
|
assert!(
|
||||||
|
html.contains("<li>item one</li>"),
|
||||||
|
"expected list item: {html}"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn markdown_to_html_inline_code() {
|
fn markdown_to_html_inline_code() {
|
||||||
let html = markdown_to_html("`inline_code()`");
|
let html = markdown_to_html("`inline_code()`");
|
||||||
assert!(html.contains("<code>inline_code()</code>"), "expected <code>: {html}");
|
assert!(
|
||||||
|
html.contains("<code>inline_code()</code>"),
|
||||||
|
"expected <code>: {html}"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -689,13 +706,19 @@ mod tests {
|
|||||||
let html = markdown_to_html("```rust\nfn main() {}\n```");
|
let html = markdown_to_html("```rust\nfn main() {}\n```");
|
||||||
assert!(html.contains("<pre>"), "expected <pre>: {html}");
|
assert!(html.contains("<pre>"), "expected <pre>: {html}");
|
||||||
assert!(html.contains("<code"), "expected <code> inside pre: {html}");
|
assert!(html.contains("<code"), "expected <code> inside pre: {html}");
|
||||||
assert!(html.contains("fn main() {}"), "expected code content: {html}");
|
assert!(
|
||||||
|
html.contains("fn main() {}"),
|
||||||
|
"expected code content: {html}"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn markdown_to_html_plain_text_passthrough() {
|
fn markdown_to_html_plain_text_passthrough() {
|
||||||
let html = markdown_to_html("Hello, world!");
|
let html = markdown_to_html("Hello, world!");
|
||||||
assert!(html.contains("Hello, world!"), "expected plain text passthrough: {html}");
|
assert!(
|
||||||
|
html.contains("Hello, world!"),
|
||||||
|
"expected plain text passthrough: {html}"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// -- bot_context_is_clone -----------------------------------------------
|
// -- bot_context_is_clone -----------------------------------------------
|
||||||
@@ -773,8 +796,9 @@ mod tests {
|
|||||||
// A blank line inside a fenced code block must NOT trigger a split.
|
// A blank line inside a fenced code block must NOT trigger a split.
|
||||||
// Before the fix the function would split at the blank line and the
|
// Before the fix the function would split at the blank line and the
|
||||||
// second half would be sent without the opening fence, breaking rendering.
|
// second half would be sent without the opening fence, breaking rendering.
|
||||||
let mut buf = "```rust\nfn foo() {\n let x = 1;\n\n let y = 2;\n}\n```\n\nNext paragraph."
|
let mut buf =
|
||||||
.to_string();
|
"```rust\nfn foo() {\n let x = 1;\n\n let y = 2;\n}\n```\n\nNext paragraph."
|
||||||
|
.to_string();
|
||||||
let paras = drain_complete_paragraphs(&mut buf);
|
let paras = drain_complete_paragraphs(&mut buf);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
paras.len(),
|
paras.len(),
|
||||||
@@ -797,8 +821,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn drain_complete_paragraphs_text_before_and_after_fenced_block() {
|
fn drain_complete_paragraphs_text_before_and_after_fenced_block() {
|
||||||
// Text paragraph, then a code block with an internal blank line, then more text.
|
// Text paragraph, then a code block with an internal blank line, then more text.
|
||||||
let mut buf =
|
let mut buf = "Before\n\n```\ncode\n\nmore code\n```\n\nAfter".to_string();
|
||||||
"Before\n\n```\ncode\n\nmore code\n```\n\nAfter".to_string();
|
|
||||||
let paras = drain_complete_paragraphs(&mut buf);
|
let paras = drain_complete_paragraphs(&mut buf);
|
||||||
assert_eq!(paras.len(), 2, "expected two paragraphs: {paras:?}");
|
assert_eq!(paras.len(), 2, "expected two paragraphs: {paras:?}");
|
||||||
assert_eq!(paras[0], "Before");
|
assert_eq!(paras[0], "Before");
|
||||||
@@ -892,8 +915,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn history_trims_to_configured_size() {
|
async fn history_trims_to_configured_size() {
|
||||||
let history: ConversationHistory =
|
let history: ConversationHistory = Arc::new(TokioMutex::new(HashMap::new()));
|
||||||
Arc::new(TokioMutex::new(HashMap::new()));
|
|
||||||
let room_id: OwnedRoomId = "!test:example.com".parse().unwrap();
|
let room_id: OwnedRoomId = "!test:example.com".parse().unwrap();
|
||||||
let history_size = 4usize; // keep at most 4 entries
|
let history_size = 4usize; // keep at most 4 entries
|
||||||
|
|
||||||
@@ -935,23 +957,28 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn each_room_has_independent_history() {
|
async fn each_room_has_independent_history() {
|
||||||
let history: ConversationHistory =
|
let history: ConversationHistory = Arc::new(TokioMutex::new(HashMap::new()));
|
||||||
Arc::new(TokioMutex::new(HashMap::new()));
|
|
||||||
let room_a: OwnedRoomId = "!room_a:example.com".parse().unwrap();
|
let room_a: OwnedRoomId = "!room_a:example.com".parse().unwrap();
|
||||||
let room_b: OwnedRoomId = "!room_b:example.com".parse().unwrap();
|
let room_b: OwnedRoomId = "!room_b:example.com".parse().unwrap();
|
||||||
|
|
||||||
{
|
{
|
||||||
let mut guard = history.lock().await;
|
let mut guard = history.lock().await;
|
||||||
guard.entry(room_a.clone()).or_default().push(ConversationEntry {
|
guard
|
||||||
role: ConversationRole::User,
|
.entry(room_a.clone())
|
||||||
sender: "@alice:example.com".to_string(),
|
.or_default()
|
||||||
content: "Room A message".to_string(),
|
.push(ConversationEntry {
|
||||||
});
|
role: ConversationRole::User,
|
||||||
guard.entry(room_b.clone()).or_default().push(ConversationEntry {
|
sender: "@alice:example.com".to_string(),
|
||||||
role: ConversationRole::User,
|
content: "Room A message".to_string(),
|
||||||
sender: "@bob:example.com".to_string(),
|
});
|
||||||
content: "Room B message".to_string(),
|
guard
|
||||||
});
|
.entry(room_b.clone())
|
||||||
|
.or_default()
|
||||||
|
.push(ConversationEntry {
|
||||||
|
role: ConversationRole::User,
|
||||||
|
sender: "@bob:example.com".to_string(),
|
||||||
|
content: "Room B message".to_string(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let guard = history.lock().await;
|
let guard = history.lock().await;
|
||||||
|
|||||||
Reference in New Issue
Block a user