story-kit: merge 266_story_matrix_bot_structured_conversation_history

This commit is contained in:
Dave
2026-03-17 17:39:13 +00:00
parent 3fb48cdf51
commit a893a1cef7

View File

@@ -14,6 +14,7 @@ use matrix_sdk::{
},
};
use pulldown_cmark::{Options, Parser, html};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::sync::Arc;
@@ -34,7 +35,8 @@ use super::config::BotConfig;
// ---------------------------------------------------------------------------
/// Role of a participant in the conversation history.
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ConversationRole {
/// A message sent by a Matrix room participant.
User,
@@ -43,7 +45,7 @@ pub enum ConversationRole {
}
/// A single turn in the per-room conversation history.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ConversationEntry {
pub role: ConversationRole,
/// Matrix user ID (e.g. `@alice:example.com`). Empty for assistant turns.
@@ -51,11 +53,81 @@ pub struct ConversationEntry {
pub content: String,
}
/// Per-room conversation history, keyed by room ID.
/// Per-room state: conversation entries plus the Claude Code session ID for
/// structured conversation resumption.
///
/// Every field falls back to its `Default` value when it is absent from the
/// serialized form. A history file that lacks a field therefore still
/// deserializes instead of failing the parse of the whole file.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct RoomConversation {
    /// Claude Code session ID used to resume multi-turn conversations so the
    /// LLM receives prior turns as structured API messages rather than a
    /// flattened text prefix. Omitted from the serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Rolling conversation entries (used for turn counting and persistence).
    #[serde(default)]
    pub entries: Vec<ConversationEntry>,
}
/// Per-room conversation state, keyed by room ID (serialised as string).
///
/// Wrapped in `Arc<TokioMutex<…>>` so it can be shared across concurrent
/// event-handler tasks without blocking the sync loop.
pub type ConversationHistory = Arc<TokioMutex<HashMap<OwnedRoomId, Vec<ConversationEntry>>>>;
pub type ConversationHistory = Arc<TokioMutex<HashMap<OwnedRoomId, RoomConversation>>>;
/// On-disk format for persisted conversation history.
///
/// Room IDs are stored as plain strings: `save_history` produces them with
/// `to_string()` and `load_history` parses them back into `OwnedRoomId`.
/// The stated rationale is that `OwnedRoomId` does not implement `Serialize`
/// as a map key. NOTE(review): this cannot be verified from this file; confirm
/// against the matrix SDK before relying on it.
#[derive(Serialize, Deserialize)]
struct PersistedHistory {
/// Conversation state per room, keyed by the string form of the room ID.
rooms: HashMap<String, RoomConversation>,
}
/// Path to the persisted conversation history file, relative to the project
/// root. `load_history` and `save_history` join it onto their `project_root`
/// argument.
const HISTORY_FILE: &str = ".story_kit/matrix_history.json";
/// Load conversation history from disk, returning an empty map on any error.
///
/// Reads `HISTORY_FILE` under `project_root` and converts the persisted
/// string keys back into `OwnedRoomId`s.
///
/// Failure handling (never panics, never propagates):
/// - A missing file is the normal first-run case and is silent.
/// - Any other read error, or a parse error, is logged and yields an empty map.
/// - A room whose key is not a valid room ID is logged and skipped; the
///   remaining rooms are still loaded.
pub fn load_history(project_root: &std::path::Path) -> HashMap<OwnedRoomId, RoomConversation> {
    let path = project_root.join(HISTORY_FILE);
    let data = match std::fs::read_to_string(&path) {
        Ok(d) => d,
        // No history has been saved yet; nothing worth reporting.
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return HashMap::new(),
        Err(e) => {
            slog!("[matrix-bot] Failed to read history file: {e}");
            return HashMap::new();
        }
    };
    let persisted: PersistedHistory = match serde_json::from_str(&data) {
        Ok(p) => p,
        Err(e) => {
            slog!("[matrix-bot] Failed to parse history file: {e}");
            return HashMap::new();
        }
    };
    persisted
        .rooms
        .into_iter()
        .filter_map(|(k, v)| match k.parse::<OwnedRoomId>() {
            Ok(room_id) => Some((room_id, v)),
            Err(e) => {
                // Drop only the offending room rather than the whole file.
                slog!("[matrix-bot] Skipping persisted history for invalid room ID '{k}': {e}");
                None
            }
        })
        .collect()
}
/// Save conversation history to disk. Errors are logged but not propagated.
///
/// The map is serialized as pretty-printed JSON to `HISTORY_FILE` under
/// `project_root`, with room IDs converted to strings.
///
/// The parent directory is created if it does not exist. The JSON is first
/// written to a sibling `*.json.tmp` file and then renamed over the real
/// file, so an interrupted write does not leave a truncated history file
/// (which `load_history` would discard in its entirety).
pub fn save_history(
    project_root: &std::path::Path,
    history: &HashMap<OwnedRoomId, RoomConversation>,
) {
    let persisted = PersistedHistory {
        rooms: history
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect(),
    };
    let json = match serde_json::to_string_pretty(&persisted) {
        Ok(json) => json,
        Err(e) => {
            slog!("[matrix-bot] Failed to serialise history: {e}");
            return;
        }
    };

    let path = project_root.join(HISTORY_FILE);
    if let Some(Err(e)) = path.parent().map(std::fs::create_dir_all) {
        slog!("[matrix-bot] Failed to create history directory: {e}");
        return;
    }

    // Write-then-rename keeps the previous file intact until the new
    // contents are completely on disk.
    let tmp_path = path.with_extension("json.tmp");
    let written =
        std::fs::write(&tmp_path, json).and_then(|()| std::fs::rename(&tmp_path, &path));
    if let Err(e) = written {
        slog!("[matrix-bot] Failed to write history file: {e}");
        // Best-effort cleanup; the temp file may not exist if the write failed.
        let _ = std::fs::remove_file(&tmp_path);
    }
}
// ---------------------------------------------------------------------------
// Bot context
@@ -221,12 +293,18 @@ pub async fn run_bot(
let notif_room_ids = target_room_ids.clone();
let notif_project_root = project_root.clone();
let persisted = load_history(&project_root);
slog!(
"[matrix-bot] Loaded persisted conversation history for {} room(s)",
persisted.len()
);
let ctx = BotContext {
bot_user_id,
target_room_ids,
project_root,
allowed_users: config.allowed_users,
history: Arc::new(TokioMutex::new(HashMap::new())),
history: Arc::new(TokioMutex::new(persisted)),
history_size: config.history_size,
bot_sent_event_ids: Arc::new(TokioMutex::new(HashSet::new())),
};
@@ -574,33 +652,10 @@ async fn on_room_message(
// Message handler
// ---------------------------------------------------------------------------
/// Build a context string from the room's conversation history to prepend to
/// the user's current message. Returns an empty string when history is empty.
fn build_context_prefix(
history: &[ConversationEntry],
current_sender: &str,
current_message: &str,
) -> String {
if history.is_empty() {
return format!("{current_sender}: {current_message}");
}
let mut out = String::from("[Conversation history for this room]\n");
for entry in history {
match entry.role {
ConversationRole::User => {
out.push_str(&format!("User ({}): {}\n", entry.sender, entry.content));
}
ConversationRole::Assistant => {
out.push_str(&format!("Assistant: {}\n", entry.content));
}
}
}
out.push('\n');
out.push_str(&format!(
"Current message from {current_sender}: {current_message}"
));
out
/// Render a single user turn as the prompt text handed to the LLM.
///
/// The sender is prefixed to the message (`"<sender>: <message>"`) so that
/// participants in a multi-user room remain distinguishable.
fn format_user_prompt(sender: &str, message: &str) -> String {
    const SEPARATOR: &str = ": ";
    let mut prompt = String::with_capacity(sender.len() + SEPARATOR.len() + message.len());
    prompt.push_str(sender);
    prompt.push_str(SEPARATOR);
    prompt.push_str(message);
    prompt
}
async fn handle_message(
@@ -610,14 +665,19 @@ async fn handle_message(
sender: String,
user_message: String,
) {
// Read current history for this room before calling the LLM.
let history_snapshot: Vec<ConversationEntry> = {
// Look up the room's existing Claude Code session ID (if any) so we can
// resume the conversation with structured API messages instead of
// flattening history into a text prefix.
let resume_session_id: Option<String> = {
let guard = ctx.history.lock().await;
guard.get(&room_id).cloned().unwrap_or_default()
guard
.get(&room_id)
.and_then(|conv| conv.session_id.clone())
};
// Build the prompt with conversation context.
let prompt_with_context = build_context_prefix(&history_snapshot, &sender, &user_message);
// The prompt is just the current message with sender attribution.
// Prior conversation context is carried by the Claude Code session.
let prompt = format_user_prompt(&sender, &user_message);
let provider = ClaudeCodeProvider::new();
let (cancel_tx, mut cancel_rx) = watch::channel(false);
@@ -652,9 +712,9 @@ async fn handle_message(
let result = provider
.chat_stream(
&prompt_with_context,
&prompt,
&ctx.project_root.to_string_lossy(),
None, // Each Matrix conversation turn is independent at the Claude Code session level.
resume_session_id.as_deref(),
&mut cancel_rx,
move |token| {
let mut buf = buffer_for_callback.lock().unwrap();
@@ -675,9 +735,12 @@ async fn handle_message(
let remaining = buffer.lock().unwrap().trim().to_string();
let did_send_any = sent_any_chunk.load(Ordering::Relaxed);
let assistant_reply = match result {
Ok(ClaudeCodeResult { messages, .. }) => {
if !remaining.is_empty() {
let (assistant_reply, new_session_id) = match result {
Ok(ClaudeCodeResult {
messages,
session_id,
}) => {
let reply = if !remaining.is_empty() {
let _ = msg_tx.send(remaining.clone());
remaining
} else if !did_send_any {
@@ -696,13 +759,14 @@ async fn handle_message(
last_text
} else {
remaining
}
};
(reply, session_id)
}
Err(e) => {
slog!("[matrix-bot] LLM error: {e}");
let err_msg = format!("Error processing your request: {e}");
let _ = msg_tx.send(err_msg.clone());
err_msg
(err_msg, None)
}
};
@@ -711,25 +775,40 @@ async fn handle_message(
drop(msg_tx);
let _ = post_task.await;
// Record this exchange in the per-room conversation history.
// Record this exchange in the per-room conversation history and persist
// the session ID so the next turn resumes with structured API messages.
if !assistant_reply.starts_with("Error processing") {
let mut guard = ctx.history.lock().await;
let entries = guard.entry(room_id).or_default();
entries.push(ConversationEntry {
let conv = guard.entry(room_id).or_default();
// Store the session ID so the next turn uses --resume.
if new_session_id.is_some() {
conv.session_id = new_session_id;
}
conv.entries.push(ConversationEntry {
role: ConversationRole::User,
sender: sender.clone(),
content: user_message,
});
entries.push(ConversationEntry {
conv.entries.push(ConversationEntry {
role: ConversationRole::Assistant,
sender: String::new(),
content: assistant_reply,
});
// Trim to the configured maximum, dropping the oldest entries first.
if entries.len() > ctx.history_size {
let excess = entries.len() - ctx.history_size;
entries.drain(..excess);
// When trimming occurs, clear the session ID so the next turn starts
// a fresh Claude Code session (the old session's context would be
// stale since we've dropped entries from our tracking).
if conv.entries.len() > ctx.history_size {
let excess = conv.entries.len() - ctx.history_size;
conv.entries.drain(..excess);
conv.session_id = None;
}
// Persist to disk so history survives server restarts.
save_history(&ctx.project_root, &guard);
}
}
@@ -1128,62 +1207,18 @@ mod tests {
assert_eq!(buf, "Third.");
}
// -- build_context_prefix -----------------------------------------------
// -- format_user_prompt -------------------------------------------------
#[test]
fn build_context_prefix_empty_history() {
let prefix = build_context_prefix(&[], "@alice:example.com", "Hello!");
assert_eq!(prefix, "@alice:example.com: Hello!");
fn format_user_prompt_includes_sender_and_message() {
let prompt = format_user_prompt("@alice:example.com", "Hello!");
assert_eq!(prompt, "@alice:example.com: Hello!");
}
#[test]
fn build_context_prefix_includes_history_entries() {
let history = vec![
ConversationEntry {
role: ConversationRole::User,
sender: "@alice:example.com".to_string(),
content: "What is story 42?".to_string(),
},
ConversationEntry {
role: ConversationRole::Assistant,
sender: String::new(),
content: "Story 42 is about…".to_string(),
},
];
let prefix = build_context_prefix(&history, "@bob:example.com", "Tell me more.");
assert!(prefix.contains("[Conversation history for this room]"));
assert!(prefix.contains("User (@alice:example.com): What is story 42?"));
assert!(prefix.contains("Assistant: Story 42 is about…"));
assert!(prefix.contains("Current message from @bob:example.com: Tell me more."));
}
#[test]
fn build_context_prefix_attributes_multiple_users() {
let history = vec![
ConversationEntry {
role: ConversationRole::User,
sender: "@alice:example.com".to_string(),
content: "First question".to_string(),
},
ConversationEntry {
role: ConversationRole::Assistant,
sender: String::new(),
content: "First answer".to_string(),
},
ConversationEntry {
role: ConversationRole::User,
sender: "@bob:example.com".to_string(),
content: "Follow-up".to_string(),
},
ConversationEntry {
role: ConversationRole::Assistant,
sender: String::new(),
content: "Second answer".to_string(),
},
];
let prefix = build_context_prefix(&history, "@alice:example.com", "Another question");
assert!(prefix.contains("User (@alice:example.com): First question"));
assert!(prefix.contains("User (@bob:example.com): Follow-up"));
fn format_user_prompt_different_users() {
let prompt = format_user_prompt("@bob:example.com", "What's up?");
assert_eq!(prompt, "@bob:example.com: What's up?");
}
// -- conversation history trimming --------------------------------------
@@ -1197,37 +1232,44 @@ mod tests {
// Add 6 entries (3 user + 3 assistant turns).
{
let mut guard = history.lock().await;
let entries = guard.entry(room_id.clone()).or_default();
let conv = guard.entry(room_id.clone()).or_default();
conv.session_id = Some("test-session".to_string());
for i in 0..3usize {
entries.push(ConversationEntry {
conv.entries.push(ConversationEntry {
role: ConversationRole::User,
sender: "@user:example.com".to_string(),
content: format!("msg {i}"),
});
entries.push(ConversationEntry {
conv.entries.push(ConversationEntry {
role: ConversationRole::Assistant,
sender: String::new(),
content: format!("reply {i}"),
});
if entries.len() > history_size {
let excess = entries.len() - history_size;
entries.drain(..excess);
if conv.entries.len() > history_size {
let excess = conv.entries.len() - history_size;
conv.entries.drain(..excess);
conv.session_id = None;
}
}
}
let guard = history.lock().await;
let entries = guard.get(&room_id).unwrap();
let conv = guard.get(&room_id).unwrap();
assert_eq!(
entries.len(),
conv.entries.len(),
history_size,
"history must be trimmed to history_size"
);
// The oldest entries (msg 0 / reply 0) should have been dropped.
assert!(
entries.iter().all(|e| !e.content.contains("msg 0")),
conv.entries.iter().all(|e| !e.content.contains("msg 0")),
"oldest entries must be dropped"
);
// Session ID must be cleared when trimming occurs.
assert!(
conv.session_id.is_none(),
"session_id must be cleared on trim to start a fresh session"
);
}
#[tokio::test]
@@ -1241,6 +1283,7 @@ mod tests {
guard
.entry(room_a.clone())
.or_default()
.entries
.push(ConversationEntry {
role: ConversationRole::User,
sender: "@alice:example.com".to_string(),
@@ -1249,6 +1292,7 @@ mod tests {
guard
.entry(room_b.clone())
.or_default()
.entries
.push(ConversationEntry {
role: ConversationRole::User,
sender: "@bob:example.com".to_string(),
@@ -1257,12 +1301,131 @@ mod tests {
}
let guard = history.lock().await;
let entries_a = guard.get(&room_a).unwrap();
let entries_b = guard.get(&room_b).unwrap();
assert_eq!(entries_a.len(), 1);
assert_eq!(entries_b.len(), 1);
assert_eq!(entries_a[0].content, "Room A message");
assert_eq!(entries_b[0].content, "Room B message");
let conv_a = guard.get(&room_a).unwrap();
let conv_b = guard.get(&room_b).unwrap();
assert_eq!(conv_a.entries.len(), 1);
assert_eq!(conv_b.entries.len(), 1);
assert_eq!(conv_a.entries[0].content, "Room A message");
assert_eq!(conv_b.entries[0].content, "Room B message");
}
// -- persistence --------------------------------------------------------
/// A conversation written by `save_history` must come back unchanged from
/// `load_history`: session ID, entry count, roles, senders and contents.
#[test]
fn save_and_load_history_round_trip() {
    let tmp = tempfile::tempdir().unwrap();
    std::fs::create_dir_all(tmp.path().join(".story_kit")).unwrap();

    let room: OwnedRoomId = "!persist:example.com".parse().unwrap();
    let original = RoomConversation {
        session_id: Some("session-abc".to_string()),
        entries: vec![
            ConversationEntry {
                role: ConversationRole::User,
                sender: "@alice:example.com".to_string(),
                content: "hello".to_string(),
            },
            ConversationEntry {
                role: ConversationRole::Assistant,
                sender: String::new(),
                content: "hi there!".to_string(),
            },
        ],
    };
    let mut rooms: HashMap<OwnedRoomId, RoomConversation> = HashMap::new();
    rooms.insert(room.clone(), original);

    save_history(tmp.path(), &rooms);
    let restored = load_history(tmp.path());

    let conv = restored.get(&room).expect("room must exist after load");
    assert_eq!(conv.session_id.as_deref(), Some("session-abc"));
    assert_eq!(conv.entries.len(), 2);

    let (user_turn, assistant_turn) = (&conv.entries[0], &conv.entries[1]);
    assert_eq!(user_turn.role, ConversationRole::User);
    assert_eq!(user_turn.sender, "@alice:example.com");
    assert_eq!(user_turn.content, "hello");
    assert_eq!(assistant_turn.role, ConversationRole::Assistant);
    assert_eq!(assistant_turn.content, "hi there!");
}
/// With no history file under the project root, loading yields an empty map.
#[test]
fn load_history_returns_empty_on_missing_file() {
    let tmp = tempfile::tempdir().unwrap();
    assert!(load_history(tmp.path()).is_empty());
}
/// A history file that is not valid JSON is treated as absent: loading
/// yields an empty map instead of failing.
#[test]
fn load_history_returns_empty_on_corrupt_file() {
    let tmp = tempfile::tempdir().unwrap();
    std::fs::create_dir_all(tmp.path().join(".story_kit")).unwrap();

    let history_path = tmp.path().join(HISTORY_FILE);
    std::fs::write(history_path, "not valid json").unwrap();

    assert!(load_history(tmp.path()).is_empty());
}
// -- session_id tracking ------------------------------------------------
#[tokio::test]
async fn session_id_preserved_within_history_size() {
let history: ConversationHistory = Arc::new(TokioMutex::new(HashMap::new()));
let room_id: OwnedRoomId = "!session:example.com".parse().unwrap();
{
let mut guard = history.lock().await;
let conv = guard.entry(room_id.clone()).or_default();
conv.session_id = Some("sess-1".to_string());
conv.entries.push(ConversationEntry {
role: ConversationRole::User,
sender: "@alice:example.com".to_string(),
content: "hello".to_string(),
});
conv.entries.push(ConversationEntry {
role: ConversationRole::Assistant,
sender: String::new(),
content: "hi".to_string(),
});
// No trimming needed (2 entries, well under any reasonable limit).
}
let guard = history.lock().await;
let conv = guard.get(&room_id).unwrap();
assert_eq!(
conv.session_id.as_deref(),
Some("sess-1"),
"session_id must be preserved when no trimming occurs"
);
}
// -- multi-user room attribution ----------------------------------------
/// Entries from different room participants keep their own sender IDs.
#[tokio::test]
async fn multi_user_entries_preserve_sender() {
    let history: ConversationHistory = Arc::new(TokioMutex::new(HashMap::new()));
    let room: OwnedRoomId = "!multi:example.com".parse().unwrap();

    // (role, sender, content) for each turn, in conversation order.
    let turns = [
        (ConversationRole::User, "@alice:example.com", "from alice"),
        (ConversationRole::Assistant, "", "reply to alice"),
        (ConversationRole::User, "@bob:example.com", "from bob"),
    ];

    {
        let mut rooms = history.lock().await;
        let conv = rooms.entry(room.clone()).or_default();
        for (role, sender, content) in turns {
            conv.entries.push(ConversationEntry {
                role,
                sender: sender.to_string(),
                content: content.to_string(),
            });
        }
    }

    let rooms = history.lock().await;
    let entries = &rooms.get(&room).unwrap().entries;
    assert_eq!(entries[0].sender, "@alice:example.com");
    assert_eq!(entries[2].sender, "@bob:example.com");
}
// -- self-sign device key decision logic -----------------------------------