From a893a1cef74d84a3bd2c5aadf7586144d2f1c07e Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 17 Mar 2026 17:39:13 +0000 Subject: [PATCH] story-kit: merge 266_story_matrix_bot_structured_conversation_history --- server/src/matrix/bot.rs | 397 +++++++++++++++++++++++++++------------ 1 file changed, 280 insertions(+), 117 deletions(-) diff --git a/server/src/matrix/bot.rs b/server/src/matrix/bot.rs index ba4aa16..940e2a7 100644 --- a/server/src/matrix/bot.rs +++ b/server/src/matrix/bot.rs @@ -14,6 +14,7 @@ use matrix_sdk::{ }, }; use pulldown_cmark::{Options, Parser, html}; +use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::path::PathBuf; use std::sync::Arc; @@ -34,7 +35,8 @@ use super::config::BotConfig; // --------------------------------------------------------------------------- /// Role of a participant in the conversation history. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] pub enum ConversationRole { /// A message sent by a Matrix room participant. User, @@ -43,7 +45,7 @@ pub enum ConversationRole { } /// A single turn in the per-room conversation history. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct ConversationEntry { pub role: ConversationRole, /// Matrix user ID (e.g. `@alice:example.com`). Empty for assistant turns. @@ -51,11 +53,81 @@ pub struct ConversationEntry { pub content: String, } -/// Per-room conversation history, keyed by room ID. +/// Per-room state: conversation entries plus the Claude Code session ID for +/// structured conversation resumption. +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct RoomConversation { + /// Claude Code session ID used to resume multi-turn conversations so the + /// LLM receives prior turns as structured API messages rather than a + /// flattened text prefix. + #[serde(skip_serializing_if = "Option::is_none")] + pub session_id: Option, + /// Rolling conversation entries (used for turn counting and persistence). + pub entries: Vec, +} + +/// Per-room conversation state, keyed by room ID (serialised as string). /// /// Wrapped in `Arc>` so it can be shared across concurrent /// event-handler tasks without blocking the sync loop. -pub type ConversationHistory = Arc>>>; +pub type ConversationHistory = Arc>>; + +/// On-disk format for persisted conversation history. Room IDs are stored as +/// strings because `OwnedRoomId` does not implement `Serialize` as a map key. +#[derive(Serialize, Deserialize)] +struct PersistedHistory { + rooms: HashMap, +} + +/// Path to the persisted conversation history file relative to project root. +const HISTORY_FILE: &str = ".story_kit/matrix_history.json"; + +/// Load conversation history from disk, returning an empty map on any error. +pub fn load_history(project_root: &std::path::Path) -> HashMap { + let path = project_root.join(HISTORY_FILE); + let data = match std::fs::read_to_string(&path) { + Ok(d) => d, + Err(_) => return HashMap::new(), + }; + let persisted: PersistedHistory = match serde_json::from_str(&data) { + Ok(p) => p, + Err(e) => { + slog!("[matrix-bot] Failed to parse history file: {e}"); + return HashMap::new(); + } + }; + persisted + .rooms + .into_iter() + .filter_map(|(k, v)| { + k.parse::() + .ok() + .map(|room_id| (room_id, v)) + }) + .collect() +} + +/// Save conversation history to disk. Errors are logged but not propagated. +pub fn save_history( + project_root: &std::path::Path, + history: &HashMap, +) { + let persisted = PersistedHistory { + rooms: history + .iter() + .map(|(k, v)| (k.to_string(), v.clone())) + .collect(), + }; + let path = project_root.join(HISTORY_FILE); + match serde_json::to_string_pretty(&persisted) { + Ok(json) => { + if let Err(e) = std::fs::write(&path, json) { + slog!("[matrix-bot] Failed to write history file: {e}"); + } + } + Err(e) => slog!("[matrix-bot] Failed to serialise history: {e}"), + } +} // --------------------------------------------------------------------------- // Bot context @@ -221,12 +293,18 @@ pub async fn run_bot( let notif_room_ids = target_room_ids.clone(); let notif_project_root = project_root.clone(); + let persisted = load_history(&project_root); + slog!( + "[matrix-bot] Loaded persisted conversation history for {} room(s)", + persisted.len() + ); + let ctx = BotContext { bot_user_id, target_room_ids, project_root, allowed_users: config.allowed_users, - history: Arc::new(TokioMutex::new(HashMap::new())), + history: Arc::new(TokioMutex::new(persisted)), history_size: config.history_size, bot_sent_event_ids: Arc::new(TokioMutex::new(HashSet::new())), }; @@ -574,33 +652,10 @@ async fn on_room_message( // Message handler // --------------------------------------------------------------------------- -/// Build a context string from the room's conversation history to prepend to -/// the user's current message. Returns an empty string when history is empty. -fn build_context_prefix( - history: &[ConversationEntry], - current_sender: &str, - current_message: &str, -) -> String { - if history.is_empty() { - return format!("{current_sender}: {current_message}"); - } - - let mut out = String::from("[Conversation history for this room]\n"); - for entry in history { - match entry.role { - ConversationRole::User => { - out.push_str(&format!("User ({}): {}\n", entry.sender, entry.content)); - } - ConversationRole::Assistant => { - out.push_str(&format!("Assistant: {}\n", entry.content)); - } - } - } - out.push('\n'); - out.push_str(&format!( - "Current message from {current_sender}: {current_message}" - )); - out +/// Build the user-facing prompt for a single turn. In multi-user rooms the +/// sender is included so the LLM can distinguish participants. +fn format_user_prompt(sender: &str, message: &str) -> String { + format!("{sender}: {message}") } async fn handle_message( @@ -610,14 +665,19 @@ async fn handle_message( sender: String, user_message: String, ) { - // Read current history for this room before calling the LLM. - let history_snapshot: Vec = { + // Look up the room's existing Claude Code session ID (if any) so we can + // resume the conversation with structured API messages instead of + // flattening history into a text prefix. + let resume_session_id: Option = { let guard = ctx.history.lock().await; - guard.get(&room_id).cloned().unwrap_or_default() + guard + .get(&room_id) + .and_then(|conv| conv.session_id.clone()) }; - // Build the prompt with conversation context. - let prompt_with_context = build_context_prefix(&history_snapshot, &sender, &user_message); + // The prompt is just the current message with sender attribution. + // Prior conversation context is carried by the Claude Code session. + let prompt = format_user_prompt(&sender, &user_message); let provider = ClaudeCodeProvider::new(); let (cancel_tx, mut cancel_rx) = watch::channel(false); @@ -652,9 +712,9 @@ async fn handle_message( let result = provider .chat_stream( - &prompt_with_context, + &prompt, &ctx.project_root.to_string_lossy(), - None, // Each Matrix conversation turn is independent at the Claude Code session level. + resume_session_id.as_deref(), &mut cancel_rx, move |token| { let mut buf = buffer_for_callback.lock().unwrap(); @@ -675,9 +735,12 @@ async fn handle_message( let remaining = buffer.lock().unwrap().trim().to_string(); let did_send_any = sent_any_chunk.load(Ordering::Relaxed); - let assistant_reply = match result { - Ok(ClaudeCodeResult { messages, .. }) => { - if !remaining.is_empty() { + let (assistant_reply, new_session_id) = match result { + Ok(ClaudeCodeResult { + messages, + session_id, + }) => { + let reply = if !remaining.is_empty() { let _ = msg_tx.send(remaining.clone()); remaining } else if !did_send_any { @@ -696,13 +759,14 @@ async fn handle_message( last_text } else { remaining - } + }; + (reply, session_id) } Err(e) => { slog!("[matrix-bot] LLM error: {e}"); let err_msg = format!("Error processing your request: {e}"); let _ = msg_tx.send(err_msg.clone()); - err_msg + (err_msg, None) } }; @@ -711,25 +775,40 @@ async fn handle_message( drop(msg_tx); let _ = post_task.await; - // Record this exchange in the per-room conversation history. + // Record this exchange in the per-room conversation history and persist + // the session ID so the next turn resumes with structured API messages. if !assistant_reply.starts_with("Error processing") { let mut guard = ctx.history.lock().await; - let entries = guard.entry(room_id).or_default(); - entries.push(ConversationEntry { + let conv = guard.entry(room_id).or_default(); + + // Store the session ID so the next turn uses --resume. + if new_session_id.is_some() { + conv.session_id = new_session_id; + } + + conv.entries.push(ConversationEntry { role: ConversationRole::User, sender: sender.clone(), content: user_message, }); - entries.push(ConversationEntry { + conv.entries.push(ConversationEntry { role: ConversationRole::Assistant, sender: String::new(), content: assistant_reply, }); + // Trim to the configured maximum, dropping the oldest entries first. - if entries.len() > ctx.history_size { - let excess = entries.len() - ctx.history_size; - entries.drain(..excess); + // When trimming occurs, clear the session ID so the next turn starts + // a fresh Claude Code session (the old session's context would be + // stale since we've dropped entries from our tracking). + if conv.entries.len() > ctx.history_size { + let excess = conv.entries.len() - ctx.history_size; + conv.entries.drain(..excess); + conv.session_id = None; } + + // Persist to disk so history survives server restarts. + save_history(&ctx.project_root, &guard); } } @@ -1128,62 +1207,18 @@ mod tests { assert_eq!(buf, "Third."); } - // -- build_context_prefix ----------------------------------------------- + // -- format_user_prompt ------------------------------------------------- #[test] - fn build_context_prefix_empty_history() { - let prefix = build_context_prefix(&[], "@alice:example.com", "Hello!"); - assert_eq!(prefix, "@alice:example.com: Hello!"); + fn format_user_prompt_includes_sender_and_message() { + let prompt = format_user_prompt("@alice:example.com", "Hello!"); + assert_eq!(prompt, "@alice:example.com: Hello!"); } #[test] - fn build_context_prefix_includes_history_entries() { - let history = vec![ - ConversationEntry { - role: ConversationRole::User, - sender: "@alice:example.com".to_string(), - content: "What is story 42?".to_string(), - }, - ConversationEntry { - role: ConversationRole::Assistant, - sender: String::new(), - content: "Story 42 is about…".to_string(), - }, - ]; - let prefix = build_context_prefix(&history, "@bob:example.com", "Tell me more."); - assert!(prefix.contains("[Conversation history for this room]")); - assert!(prefix.contains("User (@alice:example.com): What is story 42?")); - assert!(prefix.contains("Assistant: Story 42 is about…")); - assert!(prefix.contains("Current message from @bob:example.com: Tell me more.")); - } - - #[test] - fn build_context_prefix_attributes_multiple_users() { - let history = vec![ - ConversationEntry { - role: ConversationRole::User, - sender: "@alice:example.com".to_string(), - content: "First question".to_string(), - }, - ConversationEntry { - role: ConversationRole::Assistant, - sender: String::new(), - content: "First answer".to_string(), - }, - ConversationEntry { - role: ConversationRole::User, - sender: "@bob:example.com".to_string(), - content: "Follow-up".to_string(), - }, - ConversationEntry { - role: ConversationRole::Assistant, - sender: String::new(), - content: "Second answer".to_string(), - }, - ]; - let prefix = build_context_prefix(&history, "@alice:example.com", "Another question"); - assert!(prefix.contains("User (@alice:example.com): First question")); - assert!(prefix.contains("User (@bob:example.com): Follow-up")); + fn format_user_prompt_different_users() { + let prompt = format_user_prompt("@bob:example.com", "What's up?"); + assert_eq!(prompt, "@bob:example.com: What's up?"); } // -- conversation history trimming -------------------------------------- @@ -1197,37 +1232,44 @@ mod tests { // Add 6 entries (3 user + 3 assistant turns). { let mut guard = history.lock().await; - let entries = guard.entry(room_id.clone()).or_default(); + let conv = guard.entry(room_id.clone()).or_default(); + conv.session_id = Some("test-session".to_string()); for i in 0..3usize { - entries.push(ConversationEntry { + conv.entries.push(ConversationEntry { role: ConversationRole::User, sender: "@user:example.com".to_string(), content: format!("msg {i}"), }); - entries.push(ConversationEntry { + conv.entries.push(ConversationEntry { role: ConversationRole::Assistant, sender: String::new(), content: format!("reply {i}"), }); - if entries.len() > history_size { - let excess = entries.len() - history_size; - entries.drain(..excess); + if conv.entries.len() > history_size { + let excess = conv.entries.len() - history_size; + conv.entries.drain(..excess); + conv.session_id = None; } } } let guard = history.lock().await; - let entries = guard.get(&room_id).unwrap(); + let conv = guard.get(&room_id).unwrap(); assert_eq!( - entries.len(), + conv.entries.len(), history_size, "history must be trimmed to history_size" ); // The oldest entries (msg 0 / reply 0) should have been dropped. assert!( - entries.iter().all(|e| !e.content.contains("msg 0")), + conv.entries.iter().all(|e| !e.content.contains("msg 0")), "oldest entries must be dropped" ); + // Session ID must be cleared when trimming occurs. + assert!( + conv.session_id.is_none(), + "session_id must be cleared on trim to start a fresh session" + ); } #[tokio::test] @@ -1241,6 +1283,7 @@ mod tests { guard .entry(room_a.clone()) .or_default() + .entries .push(ConversationEntry { role: ConversationRole::User, sender: "@alice:example.com".to_string(), @@ -1249,6 +1292,7 @@ mod tests { guard .entry(room_b.clone()) .or_default() + .entries .push(ConversationEntry { role: ConversationRole::User, sender: "@bob:example.com".to_string(), @@ -1257,12 +1301,131 @@ mod tests { } let guard = history.lock().await; - let entries_a = guard.get(&room_a).unwrap(); - let entries_b = guard.get(&room_b).unwrap(); - assert_eq!(entries_a.len(), 1); - assert_eq!(entries_b.len(), 1); - assert_eq!(entries_a[0].content, "Room A message"); - assert_eq!(entries_b[0].content, "Room B message"); + let conv_a = guard.get(&room_a).unwrap(); + let conv_b = guard.get(&room_b).unwrap(); + assert_eq!(conv_a.entries.len(), 1); + assert_eq!(conv_b.entries.len(), 1); + assert_eq!(conv_a.entries[0].content, "Room A message"); + assert_eq!(conv_b.entries[0].content, "Room B message"); + } + + // -- persistence -------------------------------------------------------- + + #[test] + fn save_and_load_history_round_trip() { + let dir = tempfile::tempdir().unwrap(); + let story_kit_dir = dir.path().join(".story_kit"); + std::fs::create_dir_all(&story_kit_dir).unwrap(); + + let room_id: OwnedRoomId = "!persist:example.com".parse().unwrap(); + let mut map: HashMap = HashMap::new(); + let conv = map.entry(room_id.clone()).or_default(); + conv.session_id = Some("session-abc".to_string()); + conv.entries.push(ConversationEntry { + role: ConversationRole::User, + sender: "@alice:example.com".to_string(), + content: "hello".to_string(), + }); + conv.entries.push(ConversationEntry { + role: ConversationRole::Assistant, + sender: String::new(), + content: "hi there!".to_string(), + }); + + save_history(dir.path(), &map); + + let loaded = load_history(dir.path()); + let loaded_conv = loaded.get(&room_id).expect("room must exist after load"); + assert_eq!(loaded_conv.session_id.as_deref(), Some("session-abc")); + assert_eq!(loaded_conv.entries.len(), 2); + assert_eq!(loaded_conv.entries[0].role, ConversationRole::User); + assert_eq!(loaded_conv.entries[0].sender, "@alice:example.com"); + assert_eq!(loaded_conv.entries[0].content, "hello"); + assert_eq!(loaded_conv.entries[1].role, ConversationRole::Assistant); + assert_eq!(loaded_conv.entries[1].content, "hi there!"); + } + + #[test] + fn load_history_returns_empty_on_missing_file() { + let dir = tempfile::tempdir().unwrap(); + let loaded = load_history(dir.path()); + assert!(loaded.is_empty()); + } + + #[test] + fn load_history_returns_empty_on_corrupt_file() { + let dir = tempfile::tempdir().unwrap(); + let story_kit_dir = dir.path().join(".story_kit"); + std::fs::create_dir_all(&story_kit_dir).unwrap(); + std::fs::write(dir.path().join(HISTORY_FILE), "not valid json").unwrap(); + let loaded = load_history(dir.path()); + assert!(loaded.is_empty()); + } + + // -- session_id tracking ------------------------------------------------ + + #[tokio::test] + async fn session_id_preserved_within_history_size() { + let history: ConversationHistory = Arc::new(TokioMutex::new(HashMap::new())); + let room_id: OwnedRoomId = "!session:example.com".parse().unwrap(); + + { + let mut guard = history.lock().await; + let conv = guard.entry(room_id.clone()).or_default(); + conv.session_id = Some("sess-1".to_string()); + conv.entries.push(ConversationEntry { + role: ConversationRole::User, + sender: "@alice:example.com".to_string(), + content: "hello".to_string(), + }); + conv.entries.push(ConversationEntry { + role: ConversationRole::Assistant, + sender: String::new(), + content: "hi".to_string(), + }); + // No trimming needed (2 entries, well under any reasonable limit). + } + + let guard = history.lock().await; + let conv = guard.get(&room_id).unwrap(); + assert_eq!( + conv.session_id.as_deref(), + Some("sess-1"), + "session_id must be preserved when no trimming occurs" + ); + } + + // -- multi-user room attribution ---------------------------------------- + + #[tokio::test] + async fn multi_user_entries_preserve_sender() { + let history: ConversationHistory = Arc::new(TokioMutex::new(HashMap::new())); + let room_id: OwnedRoomId = "!multi:example.com".parse().unwrap(); + + { + let mut guard = history.lock().await; + let conv = guard.entry(room_id.clone()).or_default(); + conv.entries.push(ConversationEntry { + role: ConversationRole::User, + sender: "@alice:example.com".to_string(), + content: "from alice".to_string(), + }); + conv.entries.push(ConversationEntry { + role: ConversationRole::Assistant, + sender: String::new(), + content: "reply to alice".to_string(), + }); + conv.entries.push(ConversationEntry { + role: ConversationRole::User, + sender: "@bob:example.com".to_string(), + content: "from bob".to_string(), + }); + } + + let guard = history.lock().await; + let conv = guard.get(&room_id).unwrap(); + assert_eq!(conv.entries[0].sender, "@alice:example.com"); + assert_eq!(conv.entries[2].sender, "@bob:example.com"); } // -- self-sign device key decision logic -----------------------------------