story-kit: merge 266_story_matrix_bot_structured_conversation_history

This commit is contained in:
Dave
2026-03-17 17:39:13 +00:00
parent 3fb48cdf51
commit a893a1cef7

View File

@@ -14,6 +14,7 @@ use matrix_sdk::{
}, },
}; };
use pulldown_cmark::{Options, Parser, html}; use pulldown_cmark::{Options, Parser, html};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
@@ -34,7 +35,8 @@ use super::config::BotConfig;
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
/// Role of a participant in the conversation history. /// Role of a participant in the conversation history.
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ConversationRole { pub enum ConversationRole {
/// A message sent by a Matrix room participant. /// A message sent by a Matrix room participant.
User, User,
@@ -43,7 +45,7 @@ pub enum ConversationRole {
} }
/// A single turn in the per-room conversation history. /// A single turn in the per-room conversation history.
#[derive(Clone, Debug)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ConversationEntry { pub struct ConversationEntry {
pub role: ConversationRole, pub role: ConversationRole,
/// Matrix user ID (e.g. `@alice:example.com`). Empty for assistant turns. /// Matrix user ID (e.g. `@alice:example.com`). Empty for assistant turns.
@@ -51,11 +53,81 @@ pub struct ConversationEntry {
pub content: String, pub content: String,
} }
/// Per-room conversation history, keyed by room ID. /// Per-room state: conversation entries plus the Claude Code session ID for
/// structured conversation resumption.
// `Default` yields no session ID and an empty entry list, which is what
// `entry(..).or_default()` produces for a room seen for the first time.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct RoomConversation {
/// Claude Code session ID used to resume multi-turn conversations so the
/// LLM receives prior turns as structured API messages rather than a
/// flattened text prefix.
///
/// `None` means there is no session to resume. The field is left out of
/// the serialised form entirely when it is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub session_id: Option<String>,
/// Rolling conversation entries (used for turn counting and persistence),
/// oldest first.
pub entries: Vec<ConversationEntry>,
}
/// Per-room conversation state, keyed by room ID (serialised as string).
/// ///
/// Wrapped in `Arc<TokioMutex<…>>` so it can be shared across concurrent /// Wrapped in `Arc<TokioMutex<…>>` so it can be shared across concurrent
/// event-handler tasks without blocking the sync loop. /// event-handler tasks without blocking the sync loop.
pub type ConversationHistory = Arc<TokioMutex<HashMap<OwnedRoomId, Vec<ConversationEntry>>>>; pub type ConversationHistory = Arc<TokioMutex<HashMap<OwnedRoomId, RoomConversation>>>;
/// On-disk format for persisted conversation history. Room IDs are stored as
/// strings because `OwnedRoomId` does not implement `Serialize` as a map key.
///
/// This type is private to the module: `save_history` builds it from the
/// in-memory map and `load_history` converts it back.
#[derive(Serialize, Deserialize)]
struct PersistedHistory {
// Room ID in its string form, mapped to that room's conversation state.
rooms: HashMap<String, RoomConversation>,
}
/// Path to the persisted conversation history file relative to project root.
/// Both `load_history` and `save_history` join it onto the project root.
const HISTORY_FILE: &str = ".story_kit/matrix_history.json";
/// Load conversation history from disk, returning an empty map on any error.
///
/// Reads `HISTORY_FILE` under `project_root`, parses it as JSON and converts
/// the persisted string keys back into `OwnedRoomId`s.
///
/// This function never fails. A missing file is treated as "nothing saved
/// yet" and is silent. Every other problem is logged so that lost history can
/// be diagnosed:
///
/// * any other read error (permissions, I/O failure, ...) yields an empty map;
/// * a file that is not valid JSON for the expected shape yields an empty map;
/// * a key that is not a valid room ID causes only that room to be skipped.
pub fn load_history(project_root: &std::path::Path) -> HashMap<OwnedRoomId, RoomConversation> {
    let path = project_root.join(HISTORY_FILE);
    let data = match std::fs::read_to_string(&path) {
        Ok(d) => d,
        // No history file exists yet; nothing worth reporting.
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return HashMap::new(),
        // Anything else means a file is probably there but could not be read.
        Err(e) => {
            slog!("[matrix-bot] Failed to read history file: {e}");
            return HashMap::new();
        }
    };
    let persisted: PersistedHistory = match serde_json::from_str(&data) {
        Ok(p) => p,
        Err(e) => {
            slog!("[matrix-bot] Failed to parse history file: {e}");
            return HashMap::new();
        }
    };
    persisted
        .rooms
        .into_iter()
        .filter_map(|(k, v)| match k.parse::<OwnedRoomId>() {
            Ok(room_id) => Some((room_id, v)),
            // Keep the rooms that are valid rather than discarding the whole
            // file because of one bad key.
            Err(_) => {
                slog!("[matrix-bot] Skipping invalid room ID {k:?} in history file");
                None
            }
        })
        .collect()
}
/// Save conversation history to disk. Errors are logged but not propagated.
///
/// The map is converted to `PersistedHistory` (room IDs as strings),
/// serialised as pretty-printed JSON and stored at `HISTORY_FILE` under
/// `project_root`.
///
/// The parent directory of the history file is created if it does not exist.
/// The JSON is first written to a sibling `.tmp` file and then renamed over
/// the real file, so an interrupted write cannot leave a truncated history
/// file behind (`load_history` would discard such a file entirely).
pub fn save_history(
    project_root: &std::path::Path,
    history: &HashMap<OwnedRoomId, RoomConversation>,
) {
    let persisted = PersistedHistory {
        rooms: history
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect(),
    };
    // Serialise before touching the filesystem so a serialisation failure
    // leaves the existing file exactly as it was.
    let json = match serde_json::to_string_pretty(&persisted) {
        Ok(json) => json,
        Err(e) => {
            slog!("[matrix-bot] Failed to serialise history: {e}");
            return;
        }
    };

    let path = project_root.join(HISTORY_FILE);
    if let Some(parent) = path.parent() {
        if let Err(e) = std::fs::create_dir_all(parent) {
            slog!("[matrix-bot] Failed to create history directory: {e}");
            return;
        }
    }

    // Write the new contents next to the target, then swap them into place.
    let tmp_path = path.with_extension("json.tmp");
    if let Err(e) = std::fs::write(&tmp_path, json) {
        slog!("[matrix-bot] Failed to write history file: {e}");
        // Best effort: do not leave a partial temp file lying around.
        let _ = std::fs::remove_file(&tmp_path);
        return;
    }
    if let Err(e) = std::fs::rename(&tmp_path, &path) {
        slog!("[matrix-bot] Failed to replace history file: {e}");
        let _ = std::fs::remove_file(&tmp_path);
    }
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Bot context // Bot context
@@ -221,12 +293,18 @@ pub async fn run_bot(
let notif_room_ids = target_room_ids.clone(); let notif_room_ids = target_room_ids.clone();
let notif_project_root = project_root.clone(); let notif_project_root = project_root.clone();
let persisted = load_history(&project_root);
slog!(
"[matrix-bot] Loaded persisted conversation history for {} room(s)",
persisted.len()
);
let ctx = BotContext { let ctx = BotContext {
bot_user_id, bot_user_id,
target_room_ids, target_room_ids,
project_root, project_root,
allowed_users: config.allowed_users, allowed_users: config.allowed_users,
history: Arc::new(TokioMutex::new(HashMap::new())), history: Arc::new(TokioMutex::new(persisted)),
history_size: config.history_size, history_size: config.history_size,
bot_sent_event_ids: Arc::new(TokioMutex::new(HashSet::new())), bot_sent_event_ids: Arc::new(TokioMutex::new(HashSet::new())),
}; };
@@ -574,33 +652,10 @@ async fn on_room_message(
// Message handler // Message handler
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
/// Build a context string from the room's conversation history to prepend to /// Build the user-facing prompt for a single turn. In multi-user rooms the
/// the user's current message. Returns an empty string when history is empty. /// sender is included so the LLM can distinguish participants.
fn build_context_prefix( fn format_user_prompt(sender: &str, message: &str) -> String {
history: &[ConversationEntry], format!("{sender}: {message}")
current_sender: &str,
current_message: &str,
) -> String {
if history.is_empty() {
return format!("{current_sender}: {current_message}");
}
let mut out = String::from("[Conversation history for this room]\n");
for entry in history {
match entry.role {
ConversationRole::User => {
out.push_str(&format!("User ({}): {}\n", entry.sender, entry.content));
}
ConversationRole::Assistant => {
out.push_str(&format!("Assistant: {}\n", entry.content));
}
}
}
out.push('\n');
out.push_str(&format!(
"Current message from {current_sender}: {current_message}"
));
out
} }
async fn handle_message( async fn handle_message(
@@ -610,14 +665,19 @@ async fn handle_message(
sender: String, sender: String,
user_message: String, user_message: String,
) { ) {
// Read current history for this room before calling the LLM. // Look up the room's existing Claude Code session ID (if any) so we can
let history_snapshot: Vec<ConversationEntry> = { // resume the conversation with structured API messages instead of
// flattening history into a text prefix.
let resume_session_id: Option<String> = {
let guard = ctx.history.lock().await; let guard = ctx.history.lock().await;
guard.get(&room_id).cloned().unwrap_or_default() guard
.get(&room_id)
.and_then(|conv| conv.session_id.clone())
}; };
// Build the prompt with conversation context. // The prompt is just the current message with sender attribution.
let prompt_with_context = build_context_prefix(&history_snapshot, &sender, &user_message); // Prior conversation context is carried by the Claude Code session.
let prompt = format_user_prompt(&sender, &user_message);
let provider = ClaudeCodeProvider::new(); let provider = ClaudeCodeProvider::new();
let (cancel_tx, mut cancel_rx) = watch::channel(false); let (cancel_tx, mut cancel_rx) = watch::channel(false);
@@ -652,9 +712,9 @@ async fn handle_message(
let result = provider let result = provider
.chat_stream( .chat_stream(
&prompt_with_context, &prompt,
&ctx.project_root.to_string_lossy(), &ctx.project_root.to_string_lossy(),
None, // Each Matrix conversation turn is independent at the Claude Code session level. resume_session_id.as_deref(),
&mut cancel_rx, &mut cancel_rx,
move |token| { move |token| {
let mut buf = buffer_for_callback.lock().unwrap(); let mut buf = buffer_for_callback.lock().unwrap();
@@ -675,9 +735,12 @@ async fn handle_message(
let remaining = buffer.lock().unwrap().trim().to_string(); let remaining = buffer.lock().unwrap().trim().to_string();
let did_send_any = sent_any_chunk.load(Ordering::Relaxed); let did_send_any = sent_any_chunk.load(Ordering::Relaxed);
let assistant_reply = match result { let (assistant_reply, new_session_id) = match result {
Ok(ClaudeCodeResult { messages, .. }) => { Ok(ClaudeCodeResult {
if !remaining.is_empty() { messages,
session_id,
}) => {
let reply = if !remaining.is_empty() {
let _ = msg_tx.send(remaining.clone()); let _ = msg_tx.send(remaining.clone());
remaining remaining
} else if !did_send_any { } else if !did_send_any {
@@ -696,13 +759,14 @@ async fn handle_message(
last_text last_text
} else { } else {
remaining remaining
} };
(reply, session_id)
} }
Err(e) => { Err(e) => {
slog!("[matrix-bot] LLM error: {e}"); slog!("[matrix-bot] LLM error: {e}");
let err_msg = format!("Error processing your request: {e}"); let err_msg = format!("Error processing your request: {e}");
let _ = msg_tx.send(err_msg.clone()); let _ = msg_tx.send(err_msg.clone());
err_msg (err_msg, None)
} }
}; };
@@ -711,25 +775,40 @@ async fn handle_message(
drop(msg_tx); drop(msg_tx);
let _ = post_task.await; let _ = post_task.await;
// Record this exchange in the per-room conversation history. // Record this exchange in the per-room conversation history and persist
// the session ID so the next turn resumes with structured API messages.
if !assistant_reply.starts_with("Error processing") { if !assistant_reply.starts_with("Error processing") {
let mut guard = ctx.history.lock().await; let mut guard = ctx.history.lock().await;
let entries = guard.entry(room_id).or_default(); let conv = guard.entry(room_id).or_default();
entries.push(ConversationEntry {
// Store the session ID so the next turn uses --resume.
if new_session_id.is_some() {
conv.session_id = new_session_id;
}
conv.entries.push(ConversationEntry {
role: ConversationRole::User, role: ConversationRole::User,
sender: sender.clone(), sender: sender.clone(),
content: user_message, content: user_message,
}); });
entries.push(ConversationEntry { conv.entries.push(ConversationEntry {
role: ConversationRole::Assistant, role: ConversationRole::Assistant,
sender: String::new(), sender: String::new(),
content: assistant_reply, content: assistant_reply,
}); });
// Trim to the configured maximum, dropping the oldest entries first. // Trim to the configured maximum, dropping the oldest entries first.
if entries.len() > ctx.history_size { // When trimming occurs, clear the session ID so the next turn starts
let excess = entries.len() - ctx.history_size; // a fresh Claude Code session (the old session's context would be
entries.drain(..excess); // stale since we've dropped entries from our tracking).
if conv.entries.len() > ctx.history_size {
let excess = conv.entries.len() - ctx.history_size;
conv.entries.drain(..excess);
conv.session_id = None;
} }
// Persist to disk so history survives server restarts.
save_history(&ctx.project_root, &guard);
} }
} }
@@ -1128,62 +1207,18 @@ mod tests {
assert_eq!(buf, "Third."); assert_eq!(buf, "Third.");
} }
// -- build_context_prefix ----------------------------------------------- // -- format_user_prompt -------------------------------------------------
#[test] #[test]
fn build_context_prefix_empty_history() { fn format_user_prompt_includes_sender_and_message() {
let prefix = build_context_prefix(&[], "@alice:example.com", "Hello!"); let prompt = format_user_prompt("@alice:example.com", "Hello!");
assert_eq!(prefix, "@alice:example.com: Hello!"); assert_eq!(prompt, "@alice:example.com: Hello!");
} }
#[test] #[test]
fn build_context_prefix_includes_history_entries() { fn format_user_prompt_different_users() {
let history = vec![ let prompt = format_user_prompt("@bob:example.com", "What's up?");
ConversationEntry { assert_eq!(prompt, "@bob:example.com: What's up?");
role: ConversationRole::User,
sender: "@alice:example.com".to_string(),
content: "What is story 42?".to_string(),
},
ConversationEntry {
role: ConversationRole::Assistant,
sender: String::new(),
content: "Story 42 is about…".to_string(),
},
];
let prefix = build_context_prefix(&history, "@bob:example.com", "Tell me more.");
assert!(prefix.contains("[Conversation history for this room]"));
assert!(prefix.contains("User (@alice:example.com): What is story 42?"));
assert!(prefix.contains("Assistant: Story 42 is about…"));
assert!(prefix.contains("Current message from @bob:example.com: Tell me more."));
}
#[test]
fn build_context_prefix_attributes_multiple_users() {
let history = vec![
ConversationEntry {
role: ConversationRole::User,
sender: "@alice:example.com".to_string(),
content: "First question".to_string(),
},
ConversationEntry {
role: ConversationRole::Assistant,
sender: String::new(),
content: "First answer".to_string(),
},
ConversationEntry {
role: ConversationRole::User,
sender: "@bob:example.com".to_string(),
content: "Follow-up".to_string(),
},
ConversationEntry {
role: ConversationRole::Assistant,
sender: String::new(),
content: "Second answer".to_string(),
},
];
let prefix = build_context_prefix(&history, "@alice:example.com", "Another question");
assert!(prefix.contains("User (@alice:example.com): First question"));
assert!(prefix.contains("User (@bob:example.com): Follow-up"));
} }
// -- conversation history trimming -------------------------------------- // -- conversation history trimming --------------------------------------
@@ -1197,37 +1232,44 @@ mod tests {
// Add 6 entries (3 user + 3 assistant turns). // Add 6 entries (3 user + 3 assistant turns).
{ {
let mut guard = history.lock().await; let mut guard = history.lock().await;
let entries = guard.entry(room_id.clone()).or_default(); let conv = guard.entry(room_id.clone()).or_default();
conv.session_id = Some("test-session".to_string());
for i in 0..3usize { for i in 0..3usize {
entries.push(ConversationEntry { conv.entries.push(ConversationEntry {
role: ConversationRole::User, role: ConversationRole::User,
sender: "@user:example.com".to_string(), sender: "@user:example.com".to_string(),
content: format!("msg {i}"), content: format!("msg {i}"),
}); });
entries.push(ConversationEntry { conv.entries.push(ConversationEntry {
role: ConversationRole::Assistant, role: ConversationRole::Assistant,
sender: String::new(), sender: String::new(),
content: format!("reply {i}"), content: format!("reply {i}"),
}); });
if entries.len() > history_size { if conv.entries.len() > history_size {
let excess = entries.len() - history_size; let excess = conv.entries.len() - history_size;
entries.drain(..excess); conv.entries.drain(..excess);
conv.session_id = None;
} }
} }
} }
let guard = history.lock().await; let guard = history.lock().await;
let entries = guard.get(&room_id).unwrap(); let conv = guard.get(&room_id).unwrap();
assert_eq!( assert_eq!(
entries.len(), conv.entries.len(),
history_size, history_size,
"history must be trimmed to history_size" "history must be trimmed to history_size"
); );
// The oldest entries (msg 0 / reply 0) should have been dropped. // The oldest entries (msg 0 / reply 0) should have been dropped.
assert!( assert!(
entries.iter().all(|e| !e.content.contains("msg 0")), conv.entries.iter().all(|e| !e.content.contains("msg 0")),
"oldest entries must be dropped" "oldest entries must be dropped"
); );
// Session ID must be cleared when trimming occurs.
assert!(
conv.session_id.is_none(),
"session_id must be cleared on trim to start a fresh session"
);
} }
#[tokio::test] #[tokio::test]
@@ -1241,6 +1283,7 @@ mod tests {
guard guard
.entry(room_a.clone()) .entry(room_a.clone())
.or_default() .or_default()
.entries
.push(ConversationEntry { .push(ConversationEntry {
role: ConversationRole::User, role: ConversationRole::User,
sender: "@alice:example.com".to_string(), sender: "@alice:example.com".to_string(),
@@ -1249,6 +1292,7 @@ mod tests {
guard guard
.entry(room_b.clone()) .entry(room_b.clone())
.or_default() .or_default()
.entries
.push(ConversationEntry { .push(ConversationEntry {
role: ConversationRole::User, role: ConversationRole::User,
sender: "@bob:example.com".to_string(), sender: "@bob:example.com".to_string(),
@@ -1257,12 +1301,131 @@ mod tests {
} }
let guard = history.lock().await; let guard = history.lock().await;
let entries_a = guard.get(&room_a).unwrap(); let conv_a = guard.get(&room_a).unwrap();
let entries_b = guard.get(&room_b).unwrap(); let conv_b = guard.get(&room_b).unwrap();
assert_eq!(entries_a.len(), 1); assert_eq!(conv_a.entries.len(), 1);
assert_eq!(entries_b.len(), 1); assert_eq!(conv_b.entries.len(), 1);
assert_eq!(entries_a[0].content, "Room A message"); assert_eq!(conv_a.entries[0].content, "Room A message");
assert_eq!(entries_b[0].content, "Room B message"); assert_eq!(conv_b.entries[0].content, "Room B message");
}
// -- persistence --------------------------------------------------------
/// Whatever `save_history` writes must come back unchanged from
/// `load_history`: the session ID and every entry's role, sender and content.
#[test]
fn save_and_load_history_round_trip() {
    let tmp = tempfile::tempdir().unwrap();
    // The history file lives inside `.story_kit/`, so create that directory.
    std::fs::create_dir_all(tmp.path().join(".story_kit")).unwrap();

    let room_id: OwnedRoomId = "!persist:example.com".parse().unwrap();
    let saved_conv = RoomConversation {
        session_id: Some("session-abc".to_string()),
        entries: vec![
            ConversationEntry {
                role: ConversationRole::User,
                sender: "@alice:example.com".to_string(),
                content: "hello".to_string(),
            },
            ConversationEntry {
                role: ConversationRole::Assistant,
                sender: String::new(),
                content: "hi there!".to_string(),
            },
        ],
    };
    let mut map: HashMap<OwnedRoomId, RoomConversation> = HashMap::new();
    map.insert(room_id.clone(), saved_conv);

    save_history(tmp.path(), &map);
    let loaded = load_history(tmp.path());

    let loaded_conv = loaded.get(&room_id).expect("room must exist after load");
    assert_eq!(loaded_conv.session_id.as_deref(), Some("session-abc"));
    assert_eq!(loaded_conv.entries.len(), 2);

    let user_turn = &loaded_conv.entries[0];
    assert_eq!(user_turn.role, ConversationRole::User);
    assert_eq!(user_turn.sender, "@alice:example.com");
    assert_eq!(user_turn.content, "hello");

    let assistant_turn = &loaded_conv.entries[1];
    assert_eq!(assistant_turn.role, ConversationRole::Assistant);
    assert_eq!(assistant_turn.content, "hi there!");
}
/// With no history file under the project root, loading yields an empty map.
#[test]
fn load_history_returns_empty_on_missing_file() {
    // A freshly created temporary directory contains no history file.
    let empty_root = tempfile::tempdir().unwrap();
    assert!(load_history(empty_root.path()).is_empty());
}
/// A history file that is not valid JSON is treated as empty history.
#[test]
fn load_history_returns_empty_on_corrupt_file() {
    let root = tempfile::tempdir().unwrap();
    let history_path = root.path().join(HISTORY_FILE);
    // `HISTORY_FILE` sits inside `.story_kit/`; make that directory first.
    std::fs::create_dir_all(history_path.parent().unwrap()).unwrap();
    std::fs::write(&history_path, "not valid json").unwrap();

    assert!(load_history(root.path()).is_empty());
}
// -- session_id tracking ------------------------------------------------
/// Adding entries to a room without any trimming leaves its session ID as set.
#[tokio::test]
async fn session_id_preserved_within_history_size() {
    let room_id: OwnedRoomId = "!session:example.com".parse().unwrap();
    let history: ConversationHistory = Arc::new(TokioMutex::new(HashMap::new()));

    {
        let mut rooms = history.lock().await;
        let conv = rooms.entry(room_id.clone()).or_default();
        conv.session_id = Some("sess-1".to_string());
        // One user turn plus one assistant turn; no trimming is performed
        // (2 entries, well under any reasonable limit).
        conv.entries.extend([
            ConversationEntry {
                role: ConversationRole::User,
                sender: "@alice:example.com".to_string(),
                content: "hello".to_string(),
            },
            ConversationEntry {
                role: ConversationRole::Assistant,
                sender: String::new(),
                content: "hi".to_string(),
            },
        ]);
    }

    let rooms = history.lock().await;
    let stored_session = rooms.get(&room_id).unwrap().session_id.as_deref();
    assert_eq!(
        stored_session,
        Some("sess-1"),
        "session_id must be preserved when no trimming occurs"
    );
}
// -- multi-user room attribution ----------------------------------------
#[tokio::test]
async fn multi_user_entries_preserve_sender() {
let history: ConversationHistory = Arc::new(TokioMutex::new(HashMap::new()));
let room_id: OwnedRoomId = "!multi:example.com".parse().unwrap();
{
let mut guard = history.lock().await;
let conv = guard.entry(room_id.clone()).or_default();
conv.entries.push(ConversationEntry {
role: ConversationRole::User,
sender: "@alice:example.com".to_string(),
content: "from alice".to_string(),
});
conv.entries.push(ConversationEntry {
role: ConversationRole::Assistant,
sender: String::new(),
content: "reply to alice".to_string(),
});
conv.entries.push(ConversationEntry {
role: ConversationRole::User,
sender: "@bob:example.com".to_string(),
content: "from bob".to_string(),
});
}
let guard = history.lock().await;
let conv = guard.get(&room_id).unwrap();
assert_eq!(conv.entries[0].sender, "@alice:example.com");
assert_eq!(conv.entries[2].sender, "@bob:example.com");
} }
// -- self-sign device key decision logic ----------------------------------- // -- self-sign device key decision logic -----------------------------------