966 lines
34 KiB
Rust
966 lines
34 KiB
Rust
use crate::llm::providers::claude_code::{ClaudeCodeProvider, ClaudeCodeResult};
|
|
use crate::slog;
|
|
use pulldown_cmark::{Options, Parser, html};
|
|
use matrix_sdk::{
|
|
Client,
|
|
config::SyncSettings,
|
|
event_handler::Ctx,
|
|
room::Room,
|
|
ruma::{
|
|
OwnedEventId, OwnedRoomId, OwnedUserId,
|
|
events::room::message::{
|
|
MessageType, OriginalSyncRoomMessageEvent, Relation,
|
|
RoomMessageEventContent, RoomMessageEventContentWithoutRelation,
|
|
},
|
|
},
|
|
};
|
|
use std::collections::{HashMap, HashSet};
|
|
use std::path::PathBuf;
|
|
use std::sync::Arc;
|
|
use std::sync::atomic::{AtomicBool, Ordering};
|
|
use tokio::sync::watch;
|
|
use tokio::sync::Mutex as TokioMutex;
|
|
|
|
use super::config::BotConfig;
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Conversation history types
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Who produced a given turn in a room's conversation history.
///
/// The enum carries no data, so it is `Copy` and fully comparable (`Eq`);
/// callers can pass it by value and compare it without cloning.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConversationRole {
    /// A message sent by a Matrix room participant.
    User,
    /// A response generated by the bot / LLM.
    Assistant,
}
|
|
|
|
/// A single turn in the per-room conversation history.
|
|
#[derive(Clone, Debug)]
|
|
pub struct ConversationEntry {
|
|
pub role: ConversationRole,
|
|
/// Matrix user ID (e.g. `@alice:example.com`). Empty for assistant turns.
|
|
pub sender: String,
|
|
pub content: String,
|
|
}
|
|
|
|
/// Per-room conversation history, keyed by room ID.
|
|
///
|
|
/// Wrapped in `Arc<TokioMutex<…>>` so it can be shared across concurrent
|
|
/// event-handler tasks without blocking the sync loop.
|
|
pub type ConversationHistory = Arc<TokioMutex<HashMap<OwnedRoomId, Vec<ConversationEntry>>>>;
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Bot context
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Shared context injected into Matrix event handlers.
|
|
#[derive(Clone)]
|
|
pub struct BotContext {
|
|
pub bot_user_id: OwnedUserId,
|
|
/// All room IDs the bot listens in.
|
|
pub target_room_ids: Vec<OwnedRoomId>,
|
|
pub project_root: PathBuf,
|
|
pub allowed_users: Vec<String>,
|
|
/// Shared, per-room rolling conversation history.
|
|
pub history: ConversationHistory,
|
|
/// Maximum number of entries to keep per room before trimming the oldest.
|
|
pub history_size: usize,
|
|
/// Event IDs of messages the bot has sent. Used to detect replies to the
|
|
/// bot so it can continue a conversation thread without requiring an
|
|
/// explicit `@mention` on every follow-up.
|
|
pub bot_sent_event_ids: Arc<TokioMutex<HashSet<OwnedEventId>>>,
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Bot entry point
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Connect to the Matrix homeserver, join all configured rooms, and start
|
|
/// listening for messages. Runs the full Matrix sync loop — call from a
|
|
/// `tokio::spawn` task so it doesn't block the main thread.
|
|
pub async fn run_bot(config: BotConfig, project_root: PathBuf) -> Result<(), String> {
|
|
let store_path = project_root.join(".story_kit").join("matrix_store");
|
|
let client = Client::builder()
|
|
.homeserver_url(&config.homeserver)
|
|
.sqlite_store(&store_path, None)
|
|
.build()
|
|
.await
|
|
.map_err(|e| format!("Failed to build Matrix client: {e}"))?;
|
|
|
|
// Login
|
|
client
|
|
.matrix_auth()
|
|
.login_username(&config.username, &config.password)
|
|
.initial_device_display_name("Story Kit Bot")
|
|
.await
|
|
.map_err(|e| format!("Matrix login failed: {e}"))?;
|
|
|
|
let bot_user_id = client
|
|
.user_id()
|
|
.ok_or_else(|| "No user ID after login".to_string())?
|
|
.to_owned();
|
|
|
|
slog!("[matrix-bot] Logged in as {bot_user_id}");
|
|
|
|
if config.allowed_users.is_empty() {
|
|
return Err(
|
|
"allowed_users is empty in bot.toml — refusing to start (fail-closed). \
|
|
Add at least one Matrix user ID to allowed_users."
|
|
.to_string(),
|
|
);
|
|
}
|
|
|
|
slog!(
|
|
"[matrix-bot] Allowed users: {:?}",
|
|
config.allowed_users
|
|
);
|
|
|
|
// Parse and join all configured rooms.
|
|
let mut target_room_ids: Vec<OwnedRoomId> = Vec::new();
|
|
for room_id_str in config.effective_room_ids() {
|
|
let room_id: OwnedRoomId = room_id_str
|
|
.parse()
|
|
.map_err(|_| format!("Invalid room ID '{room_id_str}'"))?;
|
|
|
|
// Try to join with a timeout. Conduit sometimes hangs or returns
|
|
// errors on join if the bot is already a member.
|
|
match tokio::time::timeout(
|
|
std::time::Duration::from_secs(10),
|
|
client.join_room_by_id(&room_id),
|
|
)
|
|
.await
|
|
{
|
|
Ok(Ok(_)) => slog!("[matrix-bot] Joined room {room_id}"),
|
|
Ok(Err(e)) => {
|
|
slog!("[matrix-bot] Join room error (may already be a member): {e}")
|
|
}
|
|
Err(_) => slog!("[matrix-bot] Join room timed out (may already be a member)"),
|
|
}
|
|
|
|
target_room_ids.push(room_id);
|
|
}
|
|
|
|
if target_room_ids.is_empty() {
|
|
return Err("No valid room IDs configured — cannot start".to_string());
|
|
}
|
|
|
|
slog!(
|
|
"[matrix-bot] Listening in {} room(s): {:?}",
|
|
target_room_ids.len(),
|
|
target_room_ids
|
|
);
|
|
|
|
let ctx = BotContext {
|
|
bot_user_id,
|
|
target_room_ids,
|
|
project_root,
|
|
allowed_users: config.allowed_users,
|
|
history: Arc::new(TokioMutex::new(HashMap::new())),
|
|
history_size: config.history_size,
|
|
bot_sent_event_ids: Arc::new(TokioMutex::new(HashSet::new())),
|
|
};
|
|
|
|
// Register event handler and inject shared context
|
|
client.add_event_handler_context(ctx);
|
|
client.add_event_handler(on_room_message);
|
|
|
|
slog!("[matrix-bot] Starting Matrix sync loop");
|
|
|
|
// This blocks until the connection is terminated or an error occurs.
|
|
client
|
|
.sync(SyncSettings::default())
|
|
.await
|
|
.map_err(|e| format!("Matrix sync error: {e}"))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Address-filtering helpers
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Returns `true` if `body` contains a mention of the bot.
|
|
///
|
|
/// Two forms are recognised:
|
|
/// - The bot's full Matrix user ID (e.g. `@timmy:homeserver.local`)
|
|
/// - The bot's local part prefixed with `@` (e.g. `@timmy`)
|
|
///
|
|
/// A short mention (`@timmy`) is only counted when it is not immediately
|
|
/// followed by an alphanumeric character, hyphen, or underscore — this avoids
|
|
/// false-positives where a longer username (e.g. `@timmybot`) shares the same
|
|
/// prefix.
|
|
pub fn mentions_bot(body: &str, bot_user_id: &OwnedUserId) -> bool {
|
|
let full_id = bot_user_id.as_str();
|
|
if body.contains(full_id) {
|
|
return true;
|
|
}
|
|
|
|
let short = format!("@{}", bot_user_id.localpart());
|
|
let mut start = 0;
|
|
while let Some(rel) = body[start..].find(short.as_str()) {
|
|
let abs = start + rel;
|
|
let after = abs + short.len();
|
|
let next = body[after..].chars().next();
|
|
let is_word_end =
|
|
next.is_none_or(|c| !c.is_alphanumeric() && c != '-' && c != '_');
|
|
if is_word_end {
|
|
return true;
|
|
}
|
|
start = abs + 1;
|
|
}
|
|
false
|
|
}
|
|
|
|
/// Returns `true` if the message's `relates_to` field references an event that
|
|
/// the bot previously sent (i.e. the message is a reply or thread-reply to a
|
|
/// bot message).
|
|
async fn is_reply_to_bot(
|
|
relates_to: Option<&Relation<RoomMessageEventContentWithoutRelation>>,
|
|
bot_sent_event_ids: &TokioMutex<HashSet<OwnedEventId>>,
|
|
) -> bool {
|
|
let candidate_ids: Vec<&OwnedEventId> = match relates_to {
|
|
Some(Relation::Reply { in_reply_to }) => vec![&in_reply_to.event_id],
|
|
Some(Relation::Thread(thread)) => {
|
|
let mut ids = vec![&thread.event_id];
|
|
if let Some(irti) = &thread.in_reply_to {
|
|
ids.push(&irti.event_id);
|
|
}
|
|
ids
|
|
}
|
|
_ => return false,
|
|
};
|
|
let guard = bot_sent_event_ids.lock().await;
|
|
candidate_ids.iter().any(|id| guard.contains(*id))
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Event handler
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Matrix event handler for room messages. Each invocation spawns an
|
|
/// independent task so the sync loop is not blocked by LLM calls.
|
|
async fn on_room_message(
|
|
ev: OriginalSyncRoomMessageEvent,
|
|
room: Room,
|
|
Ctx(ctx): Ctx<BotContext>,
|
|
) {
|
|
let incoming_room_id = room.room_id().to_owned();
|
|
|
|
slog!(
|
|
"[matrix-bot] Event received: room={} sender={}",
|
|
incoming_room_id,
|
|
ev.sender,
|
|
);
|
|
|
|
// Only handle messages from rooms we are configured to listen in.
|
|
if !ctx
|
|
.target_room_ids
|
|
.iter()
|
|
.any(|r| r == &incoming_room_id)
|
|
{
|
|
slog!("[matrix-bot] Ignoring message from unconfigured room {incoming_room_id}");
|
|
return;
|
|
}
|
|
|
|
// Ignore the bot's own messages to prevent echo loops.
|
|
if ev.sender == ctx.bot_user_id {
|
|
return;
|
|
}
|
|
|
|
// Only respond to users on the allowlist (fail-closed).
|
|
if !ctx.allowed_users.iter().any(|u| u == ev.sender.as_str()) {
|
|
slog!(
|
|
"[matrix-bot] Ignoring message from unauthorised user: {}",
|
|
ev.sender
|
|
);
|
|
return;
|
|
}
|
|
|
|
// Only handle plain text messages.
|
|
let body = match &ev.content.msgtype {
|
|
MessageType::Text(t) => t.body.clone(),
|
|
_ => return,
|
|
};
|
|
|
|
// Only respond when the bot is directly addressed (mentioned by name/ID)
|
|
// or when the message is a reply to one of the bot's own messages.
|
|
if !mentions_bot(&body, &ctx.bot_user_id)
|
|
&& !is_reply_to_bot(
|
|
ev.content.relates_to.as_ref(),
|
|
&ctx.bot_sent_event_ids,
|
|
)
|
|
.await
|
|
{
|
|
slog!(
|
|
"[matrix-bot] Ignoring unaddressed message from {}",
|
|
ev.sender
|
|
);
|
|
return;
|
|
}
|
|
|
|
let sender = ev.sender.to_string();
|
|
let user_message = body;
|
|
slog!("[matrix-bot] Message from {sender}: {user_message}");
|
|
|
|
// Spawn a separate task so the Matrix sync loop is not blocked while we
|
|
// wait for the LLM response (which can take several seconds).
|
|
tokio::spawn(async move {
|
|
handle_message(room, incoming_room_id, ctx, sender, user_message).await;
|
|
});
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Message handler
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Build a context string from the room's conversation history to prepend to
|
|
/// the user's current message. Returns an empty string when history is empty.
|
|
fn build_context_prefix(
|
|
history: &[ConversationEntry],
|
|
current_sender: &str,
|
|
current_message: &str,
|
|
) -> String {
|
|
if history.is_empty() {
|
|
return format!("{current_sender}: {current_message}");
|
|
}
|
|
|
|
let mut out = String::from("[Conversation history for this room]\n");
|
|
for entry in history {
|
|
match entry.role {
|
|
ConversationRole::User => {
|
|
out.push_str(&format!("User ({}): {}\n", entry.sender, entry.content));
|
|
}
|
|
ConversationRole::Assistant => {
|
|
out.push_str(&format!("Assistant: {}\n", entry.content));
|
|
}
|
|
}
|
|
}
|
|
out.push('\n');
|
|
out.push_str(&format!(
|
|
"Current message from {current_sender}: {current_message}"
|
|
));
|
|
out
|
|
}
|
|
|
|
async fn handle_message(
|
|
room: Room,
|
|
room_id: OwnedRoomId,
|
|
ctx: BotContext,
|
|
sender: String,
|
|
user_message: String,
|
|
) {
|
|
// Read current history for this room before calling the LLM.
|
|
let history_snapshot: Vec<ConversationEntry> = {
|
|
let guard = ctx.history.lock().await;
|
|
guard.get(&room_id).cloned().unwrap_or_default()
|
|
};
|
|
|
|
// Build the prompt with conversation context.
|
|
let prompt_with_context =
|
|
build_context_prefix(&history_snapshot, &sender, &user_message);
|
|
|
|
let provider = ClaudeCodeProvider::new();
|
|
let (cancel_tx, mut cancel_rx) = watch::channel(false);
|
|
// Keep the sender alive for the duration of the call.
|
|
let _cancel_tx = cancel_tx;
|
|
|
|
// Channel for sending complete paragraphs to the Matrix posting task.
|
|
let (msg_tx, mut msg_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
|
|
let msg_tx_for_callback = msg_tx.clone();
|
|
|
|
// Spawn a task to post messages to Matrix as they arrive so we don't
|
|
// block the LLM stream while waiting for Matrix send round-trips.
|
|
let post_room = room.clone();
|
|
let sent_ids = Arc::clone(&ctx.bot_sent_event_ids);
|
|
let post_task = tokio::spawn(async move {
|
|
while let Some(chunk) = msg_rx.recv().await {
|
|
let html = markdown_to_html(&chunk);
|
|
if let Ok(response) = post_room
|
|
.send(RoomMessageEventContent::text_html(chunk, html))
|
|
.await
|
|
{
|
|
sent_ids.lock().await.insert(response.event_id);
|
|
}
|
|
}
|
|
});
|
|
|
|
// Shared state between the sync token callback and the async outer scope.
|
|
let buffer = Arc::new(std::sync::Mutex::new(String::new()));
|
|
let buffer_for_callback = Arc::clone(&buffer);
|
|
let sent_any_chunk = Arc::new(AtomicBool::new(false));
|
|
let sent_any_chunk_for_callback = Arc::clone(&sent_any_chunk);
|
|
|
|
let result = provider
|
|
.chat_stream(
|
|
&prompt_with_context,
|
|
&ctx.project_root.to_string_lossy(),
|
|
None, // Each Matrix conversation turn is independent at the Claude Code session level.
|
|
&mut cancel_rx,
|
|
move |token| {
|
|
let mut buf = buffer_for_callback.lock().unwrap();
|
|
buf.push_str(token);
|
|
// Flush complete paragraphs as they arrive.
|
|
let paragraphs = drain_complete_paragraphs(&mut buf);
|
|
for chunk in paragraphs {
|
|
sent_any_chunk_for_callback.store(true, Ordering::Relaxed);
|
|
let _ = msg_tx_for_callback.send(chunk);
|
|
}
|
|
},
|
|
|_thinking| {}, // Discard thinking tokens
|
|
|_activity| {}, // Discard activity signals
|
|
)
|
|
.await;
|
|
|
|
// Flush any remaining text that didn't end with a paragraph boundary.
|
|
let remaining = buffer.lock().unwrap().trim().to_string();
|
|
let did_send_any = sent_any_chunk.load(Ordering::Relaxed);
|
|
|
|
let assistant_reply = match result {
|
|
Ok(ClaudeCodeResult { messages, .. }) => {
|
|
if !remaining.is_empty() {
|
|
let _ = msg_tx.send(remaining.clone());
|
|
remaining
|
|
} else if !did_send_any {
|
|
// Nothing was streamed at all (e.g. only tool calls with no
|
|
// final text) — fall back to the last assistant message from
|
|
// the structured result.
|
|
let last_text = messages
|
|
.iter()
|
|
.rev()
|
|
.find(|m| m.role == crate::llm::types::Role::Assistant && !m.content.is_empty())
|
|
.map(|m| m.content.clone())
|
|
.unwrap_or_default();
|
|
if !last_text.is_empty() {
|
|
let _ = msg_tx.send(last_text.clone());
|
|
}
|
|
last_text
|
|
} else {
|
|
remaining
|
|
}
|
|
}
|
|
Err(e) => {
|
|
slog!("[matrix-bot] LLM error: {e}");
|
|
let err_msg = format!("Error processing your request: {e}");
|
|
let _ = msg_tx.send(err_msg.clone());
|
|
err_msg
|
|
}
|
|
};
|
|
|
|
// Drop the sender to signal the posting task that no more messages will
|
|
// arrive, then wait for all pending Matrix sends to complete.
|
|
drop(msg_tx);
|
|
let _ = post_task.await;
|
|
|
|
// Record this exchange in the per-room conversation history.
|
|
if !assistant_reply.starts_with("Error processing") {
|
|
let mut guard = ctx.history.lock().await;
|
|
let entries = guard.entry(room_id).or_default();
|
|
entries.push(ConversationEntry {
|
|
role: ConversationRole::User,
|
|
sender: sender.clone(),
|
|
content: user_message,
|
|
});
|
|
entries.push(ConversationEntry {
|
|
role: ConversationRole::Assistant,
|
|
sender: String::new(),
|
|
content: assistant_reply,
|
|
});
|
|
// Trim to the configured maximum, dropping the oldest entries first.
|
|
if entries.len() > ctx.history_size {
|
|
let excess = entries.len() - ctx.history_size;
|
|
entries.drain(..excess);
|
|
}
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Markdown rendering helper
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Convert a Markdown string to an HTML string using pulldown-cmark.
|
|
///
|
|
/// Enables the standard extension set (tables, footnotes, strikethrough,
|
|
/// tasklists) so that common Markdown constructs render correctly in Matrix
|
|
/// clients such as Element.
|
|
pub fn markdown_to_html(markdown: &str) -> String {
|
|
let options = Options::ENABLE_TABLES
|
|
| Options::ENABLE_FOOTNOTES
|
|
| Options::ENABLE_STRIKETHROUGH
|
|
| Options::ENABLE_TASKLISTS;
|
|
let parser = Parser::new_ext(markdown, options);
|
|
let mut html_output = String::new();
|
|
html::push_html(&mut html_output, parser);
|
|
html_output
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Paragraph buffering helper
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Returns `true` when `text` ends while inside an open fenced code block.
///
/// Fences are recognised following the CommonMark rules that matter for
/// streaming output:
///
/// - An opening fence is a line that starts (after leading whitespace) with
///   three or more backticks or three or more tildes. A backtick fence whose
///   remainder contains another backtick is an inline code span, not a
///   fence, and is ignored.
/// - A fence is closed only by a line using the **same** character, with
///   **at least as many** markers as the opener, followed by nothing but
///   whitespace. A line such as `` ```rust `` inside an open block is
///   therefore content, not a closer, and a 3-backtick line does not close a
///   4-backtick block.
///
/// Leading whitespace is ignored entirely (more lenient than CommonMark's
/// three-space limit) so that fences nested in list items are still seen.
fn is_inside_code_fence(text: &str) -> bool {
    // Marker character and run length of the currently open fence, if any.
    let mut open_fence: Option<(char, usize)> = None;

    for line in text.lines() {
        let trimmed = line.trim_start();
        let Some(marker) = trimmed.chars().next().filter(|c| matches!(c, '`' | '~')) else {
            continue;
        };
        let run = trimmed.chars().take_while(|&c| c == marker).count();
        if run < 3 {
            continue;
        }
        // Both markers are single-byte ASCII, so `run` is also a byte offset.
        let rest = &trimmed[run..];

        match open_fence {
            None => {
                // An info string of a backtick fence may not contain
                // backticks; such a line is inline code.
                if marker == '`' && rest.contains('`') {
                    continue;
                }
                open_fence = Some((marker, run));
            }
            Some((open_marker, open_len)) => {
                let closes =
                    marker == open_marker && run >= open_len && rest.trim().is_empty();
                if closes {
                    open_fence = None;
                }
            }
        }
    }

    open_fence.is_some()
}
|
|
|
|
/// Drain all complete paragraphs from `buffer` and return them.
|
|
///
|
|
/// A paragraph boundary is a double newline (`\n\n`). Each drained paragraph
|
|
/// is trimmed of surrounding whitespace; empty paragraphs are discarded.
|
|
/// The buffer is left with only the remaining incomplete text.
|
|
///
|
|
/// **Code-fence awareness:** a `\n\n` that occurs *inside* a fenced code
|
|
/// block (delimited by ` ``` ` lines) is **not** treated as a paragraph
|
|
/// boundary. This prevents a blank line inside a code block from splitting
|
|
/// the fence across multiple Matrix messages, which would corrupt the
|
|
/// rendering of the second half.
|
|
pub fn drain_complete_paragraphs(buffer: &mut String) -> Vec<String> {
|
|
let mut paragraphs = Vec::new();
|
|
let mut search_from = 0;
|
|
loop {
|
|
let Some(pos) = buffer[search_from..].find("\n\n") else {
|
|
break;
|
|
};
|
|
let abs_pos = search_from + pos;
|
|
// Only split at this boundary when we are NOT inside a code fence.
|
|
if is_inside_code_fence(&buffer[..abs_pos]) {
|
|
// Skip past this \n\n and keep looking for the next boundary.
|
|
search_from = abs_pos + 2;
|
|
} else {
|
|
let chunk = buffer[..abs_pos].trim().to_string();
|
|
*buffer = buffer[abs_pos + 2..].to_string();
|
|
search_from = 0;
|
|
if !chunk.is_empty() {
|
|
paragraphs.push(chunk);
|
|
}
|
|
}
|
|
}
|
|
paragraphs
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Tests
|
|
// ---------------------------------------------------------------------------
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    // -- shared fixtures ----------------------------------------------------

    type MessageRelation = Relation<RoomMessageEventContentWithoutRelation>;

    /// Parse a Matrix user ID, panicking on malformed test input.
    fn make_user_id(s: &str) -> OwnedUserId {
        s.parse().unwrap()
    }

    /// The bot identity used by all `mentions_bot` tests.
    fn timmy() -> OwnedUserId {
        make_user_id("@timmy:homeserver.local")
    }

    /// Parse a Matrix event ID, panicking on malformed test input.
    fn event_id(raw: &str) -> OwnedEventId {
        raw.parse().unwrap()
    }

    /// Build the "events sent by the bot" set from the given IDs.
    fn sent_set(ids: Vec<OwnedEventId>) -> TokioMutex<HashSet<OwnedEventId>> {
        TokioMutex::new(ids.into_iter().collect())
    }

    /// A plain (non-thread) reply relation pointing at `target`.
    fn reply_to(target: OwnedEventId) -> MessageRelation {
        Relation::Reply {
            in_reply_to: matrix_sdk::ruma::events::relation::InReplyTo::new(target),
        }
    }

    fn user_turn(sender: &str, content: &str) -> ConversationEntry {
        ConversationEntry {
            role: ConversationRole::User,
            sender: sender.to_string(),
            content: content.to_string(),
        }
    }

    fn assistant_turn(content: &str) -> ConversationEntry {
        ConversationEntry {
            role: ConversationRole::Assistant,
            sender: String::new(),
            content: content.to_string(),
        }
    }

    /// Run `drain_complete_paragraphs` on `input`; returns the drained
    /// paragraphs and what is left in the buffer.
    fn drain(input: &str) -> (Vec<String>, String) {
        let mut buf = input.to_string();
        let paras = drain_complete_paragraphs(&mut buf);
        (paras, buf)
    }

    // -- mentions_bot -------------------------------------------------------

    #[test]
    fn mentions_bot_by_full_id() {
        assert!(mentions_bot(
            "hello @timmy:homeserver.local can you help?",
            &timmy()
        ));
    }

    #[test]
    fn mentions_bot_by_localpart_at_start() {
        assert!(mentions_bot("@timmy please list open stories", &timmy()));
    }

    #[test]
    fn mentions_bot_by_localpart_mid_sentence() {
        assert!(mentions_bot("hey @timmy what's the status?", &timmy()));
    }

    #[test]
    fn mentions_bot_not_mentioned() {
        assert!(!mentions_bot("can someone help me with this PR?", &timmy()));
    }

    #[test]
    fn mentions_bot_no_false_positive_longer_username() {
        // "@timmybot" must NOT match "@timmy"
        assert!(!mentions_bot("hey @timmybot can you help?", &timmy()));
    }

    #[test]
    fn mentions_bot_at_end_of_string() {
        assert!(mentions_bot("shoutout to @timmy", &timmy()));
    }

    #[test]
    fn mentions_bot_followed_by_comma() {
        assert!(mentions_bot("@timmy, can you help?", &timmy()));
    }

    // -- is_reply_to_bot ----------------------------------------------------

    #[tokio::test]
    async fn is_reply_to_bot_direct_reply_match() {
        let bot_event = event_id("$abc123:example.com");
        let sent = sent_set(vec![bot_event.clone()]);
        let relates_to = Some(reply_to(bot_event));

        assert!(is_reply_to_bot(relates_to.as_ref(), &sent).await);
    }

    #[tokio::test]
    async fn is_reply_to_bot_direct_reply_no_match() {
        // sent is empty — this event was not sent by the bot
        let sent = sent_set(Vec::new());
        let relates_to = Some(reply_to(event_id("$other:example.com")));

        assert!(!is_reply_to_bot(relates_to.as_ref(), &sent).await);
    }

    #[tokio::test]
    async fn is_reply_to_bot_no_relation() {
        let sent = sent_set(Vec::new());
        let relates_to: Option<MessageRelation> = None;

        assert!(!is_reply_to_bot(relates_to.as_ref(), &sent).await);
    }

    #[tokio::test]
    async fn is_reply_to_bot_thread_root_match() {
        let root_id = event_id("$root123:example.com");
        let sent = sent_set(vec![root_id.clone()]);

        // Thread reply where the thread root is the bot's message
        let thread = matrix_sdk::ruma::events::relation::Thread::plain(
            root_id,
            event_id("$latest:example.com"),
        );
        let relates_to: Option<MessageRelation> = Some(Relation::Thread(thread));

        assert!(is_reply_to_bot(relates_to.as_ref(), &sent).await);
    }

    // -- markdown_to_html ---------------------------------------------------

    #[test]
    fn markdown_to_html_bold() {
        let html = markdown_to_html("**bold**");
        assert!(
            html.contains("<strong>bold</strong>"),
            "expected <strong>: {html}"
        );
    }

    #[test]
    fn markdown_to_html_unordered_list() {
        let html = markdown_to_html("- item one\n- item two");
        assert!(html.contains("<ul>"), "expected <ul>: {html}");
        assert!(
            html.contains("<li>item one</li>"),
            "expected list item: {html}"
        );
    }

    #[test]
    fn markdown_to_html_inline_code() {
        let html = markdown_to_html("`inline_code()`");
        assert!(
            html.contains("<code>inline_code()</code>"),
            "expected <code>: {html}"
        );
    }

    #[test]
    fn markdown_to_html_code_block() {
        let html = markdown_to_html("```rust\nfn main() {}\n```");
        assert!(html.contains("<pre>"), "expected <pre>: {html}");
        assert!(html.contains("<code"), "expected <code> inside pre: {html}");
        assert!(
            html.contains("fn main() {}"),
            "expected code content: {html}"
        );
    }

    #[test]
    fn markdown_to_html_plain_text_passthrough() {
        let html = markdown_to_html("Hello, world!");
        assert!(
            html.contains("Hello, world!"),
            "expected plain text passthrough: {html}"
        );
    }

    // -- bot_context_is_clone -----------------------------------------------

    #[test]
    fn bot_context_is_clone() {
        // BotContext must be Clone for the Matrix event handler injection.
        fn assert_clone<T: Clone>() {}
        assert_clone::<BotContext>();
    }

    // -- drain_complete_paragraphs ------------------------------------------

    #[test]
    fn drain_complete_paragraphs_no_boundary_returns_empty() {
        let (paras, rest) = drain("Hello World");
        assert!(paras.is_empty());
        assert_eq!(rest, "Hello World");
    }

    #[test]
    fn drain_complete_paragraphs_single_boundary() {
        let (paras, rest) = drain("Paragraph one.\n\nParagraph two.");
        assert_eq!(paras, vec!["Paragraph one."]);
        assert_eq!(rest, "Paragraph two.");
    }

    #[test]
    fn drain_complete_paragraphs_multiple_boundaries() {
        let (paras, rest) = drain("A\n\nB\n\nC");
        assert_eq!(paras, vec!["A", "B"]);
        assert_eq!(rest, "C");
    }

    #[test]
    fn drain_complete_paragraphs_trailing_boundary() {
        let (paras, rest) = drain("A\n\nB\n\n");
        assert_eq!(paras, vec!["A", "B"]);
        assert_eq!(rest, "");
    }

    #[test]
    fn drain_complete_paragraphs_empty_input() {
        let (paras, rest) = drain("");
        assert!(paras.is_empty());
        assert_eq!(rest, "");
    }

    #[test]
    fn drain_complete_paragraphs_skips_empty_chunks() {
        // Consecutive double-newlines produce no empty paragraphs.
        let (paras, rest) = drain("\n\n\n\nHello");
        assert!(paras.is_empty());
        assert_eq!(rest, "Hello");
    }

    #[test]
    fn drain_complete_paragraphs_trims_whitespace() {
        let (paras, rest) = drain(" Hello \n\n World ");
        assert_eq!(paras, vec!["Hello"]);
        assert_eq!(rest, " World ");
    }

    // -- drain_complete_paragraphs: code-fence awareness -------------------

    #[test]
    fn drain_complete_paragraphs_code_fence_blank_line_not_split() {
        // A blank line inside a fenced code block must NOT trigger a split.
        // Before the fix the function would split at the blank line and the
        // second half would be sent without the opening fence, breaking rendering.
        let (paras, rest) =
            drain("```rust\nfn foo() {\n let x = 1;\n\n let y = 2;\n}\n```\n\nNext paragraph.");
        assert_eq!(
            paras.len(),
            1,
            "code fence with blank line should not be split into multiple messages: {paras:?}"
        );
        let fence = &paras[0];
        assert!(
            fence.starts_with("```rust"),
            "first paragraph should be the code fence: {:?}",
            fence
        );
        assert!(
            fence.contains("let y = 2;"),
            "code fence should contain content from both sides of the blank line: {:?}",
            fence
        );
        assert_eq!(rest, "Next paragraph.");
    }

    #[test]
    fn drain_complete_paragraphs_text_before_and_after_fenced_block() {
        // Text paragraph, then a code block with an internal blank line, then more text.
        let (paras, rest) = drain("Before\n\n```\ncode\n\nmore code\n```\n\nAfter");
        assert_eq!(paras.len(), 2, "expected two paragraphs: {paras:?}");
        assert_eq!(paras[0], "Before");
        let fence = &paras[1];
        assert!(
            fence.starts_with("```"),
            "second paragraph should be the code fence: {:?}",
            fence
        );
        assert!(
            fence.contains("more code"),
            "code fence content must include the part after the blank line: {:?}",
            fence
        );
        assert_eq!(rest, "After");
    }

    #[test]
    fn drain_complete_paragraphs_incremental_simulation() {
        // Simulate tokens arriving one character at a time.
        let mut buf = String::new();
        let mut all_paragraphs = Vec::new();

        for ch in "First para.\n\nSecond para.\n\nThird.".chars() {
            buf.push(ch);
            all_paragraphs.extend(drain_complete_paragraphs(&mut buf));
        }

        assert_eq!(all_paragraphs, vec!["First para.", "Second para."]);
        assert_eq!(buf, "Third.");
    }

    // -- build_context_prefix -----------------------------------------------

    #[test]
    fn build_context_prefix_empty_history() {
        let prefix = build_context_prefix(&[], "@alice:example.com", "Hello!");
        assert_eq!(prefix, "@alice:example.com: Hello!");
    }

    #[test]
    fn build_context_prefix_includes_history_entries() {
        let history = vec![
            user_turn("@alice:example.com", "What is story 42?"),
            assistant_turn("Story 42 is about…"),
        ];
        let prefix = build_context_prefix(&history, "@bob:example.com", "Tell me more.");
        assert!(prefix.contains("[Conversation history for this room]"));
        assert!(prefix.contains("User (@alice:example.com): What is story 42?"));
        assert!(prefix.contains("Assistant: Story 42 is about…"));
        assert!(prefix.contains("Current message from @bob:example.com: Tell me more."));
    }

    #[test]
    fn build_context_prefix_attributes_multiple_users() {
        let history = vec![
            user_turn("@alice:example.com", "First question"),
            assistant_turn("First answer"),
            user_turn("@bob:example.com", "Follow-up"),
            assistant_turn("Second answer"),
        ];
        let prefix = build_context_prefix(&history, "@alice:example.com", "Another question");
        assert!(prefix.contains("User (@alice:example.com): First question"));
        assert!(prefix.contains("User (@bob:example.com): Follow-up"));
    }

    // -- conversation history trimming --------------------------------------

    #[tokio::test]
    async fn history_trims_to_configured_size() {
        let history: ConversationHistory = Arc::new(TokioMutex::new(HashMap::new()));
        let room_id: OwnedRoomId = "!test:example.com".parse().unwrap();
        let history_size = 4usize; // keep at most 4 entries

        // Add 6 entries (3 user + 3 assistant turns).
        {
            let mut guard = history.lock().await;
            let entries = guard.entry(room_id.clone()).or_default();
            for i in 0..3usize {
                entries.push(user_turn("@user:example.com", &format!("msg {i}")));
                entries.push(assistant_turn(&format!("reply {i}")));
                if entries.len() > history_size {
                    let excess = entries.len() - history_size;
                    entries.drain(..excess);
                }
            }
        }

        let guard = history.lock().await;
        let entries = guard.get(&room_id).unwrap();
        assert_eq!(
            entries.len(),
            history_size,
            "history must be trimmed to history_size"
        );
        // The oldest entries (msg 0 / reply 0) should have been dropped.
        assert!(
            entries.iter().all(|e| !e.content.contains("msg 0")),
            "oldest entries must be dropped"
        );
    }

    #[tokio::test]
    async fn each_room_has_independent_history() {
        let history: ConversationHistory = Arc::new(TokioMutex::new(HashMap::new()));
        let room_a: OwnedRoomId = "!room_a:example.com".parse().unwrap();
        let room_b: OwnedRoomId = "!room_b:example.com".parse().unwrap();

        {
            let mut guard = history.lock().await;
            guard
                .entry(room_a.clone())
                .or_default()
                .push(user_turn("@alice:example.com", "Room A message"));
            guard
                .entry(room_b.clone())
                .or_default()
                .push(user_turn("@bob:example.com", "Room B message"));
        }

        let guard = history.lock().await;
        let entries_a = guard.get(&room_a).unwrap();
        let entries_b = guard.get(&room_b).unwrap();
        assert_eq!(entries_a.len(), 1);
        assert_eq!(entries_b.len(), 1);
        assert_eq!(entries_a[0].content, "Room A message");
        assert_eq!(entries_b[0].content, "Room B message");
    }
}
|