story-kit: merge 266_story_matrix_bot_structured_conversation_history
This commit is contained in:
@@ -14,6 +14,7 @@ use matrix_sdk::{
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
use pulldown_cmark::{Options, Parser, html};
|
use pulldown_cmark::{Options, Parser, html};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -34,7 +35,8 @@ use super::config::BotConfig;
|
|||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
/// Role of a participant in the conversation history.
|
/// Role of a participant in the conversation history.
|
||||||
#[derive(Clone, Debug, PartialEq)]
|
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum ConversationRole {
|
pub enum ConversationRole {
|
||||||
/// A message sent by a Matrix room participant.
|
/// A message sent by a Matrix room participant.
|
||||||
User,
|
User,
|
||||||
@@ -43,7 +45,7 @@ pub enum ConversationRole {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// A single turn in the per-room conversation history.
|
/// A single turn in the per-room conversation history.
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
pub struct ConversationEntry {
|
pub struct ConversationEntry {
|
||||||
pub role: ConversationRole,
|
pub role: ConversationRole,
|
||||||
/// Matrix user ID (e.g. `@alice:example.com`). Empty for assistant turns.
|
/// Matrix user ID (e.g. `@alice:example.com`). Empty for assistant turns.
|
||||||
@@ -51,11 +53,81 @@ pub struct ConversationEntry {
|
|||||||
pub content: String,
|
pub content: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Per-room conversation history, keyed by room ID.
|
/// Per-room state: conversation entries plus the Claude Code session ID for
|
||||||
|
/// structured conversation resumption.
|
||||||
|
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||||
|
pub struct RoomConversation {
|
||||||
|
/// Claude Code session ID used to resume multi-turn conversations so the
|
||||||
|
/// LLM receives prior turns as structured API messages rather than a
|
||||||
|
/// flattened text prefix.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub session_id: Option<String>,
|
||||||
|
/// Rolling conversation entries (used for turn counting and persistence).
|
||||||
|
pub entries: Vec<ConversationEntry>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Per-room conversation state, keyed by room ID (serialised as string).
|
||||||
///
|
///
|
||||||
/// Wrapped in `Arc<TokioMutex<…>>` so it can be shared across concurrent
|
/// Wrapped in `Arc<TokioMutex<…>>` so it can be shared across concurrent
|
||||||
/// event-handler tasks without blocking the sync loop.
|
/// event-handler tasks without blocking the sync loop.
|
||||||
pub type ConversationHistory = Arc<TokioMutex<HashMap<OwnedRoomId, Vec<ConversationEntry>>>>;
|
pub type ConversationHistory = Arc<TokioMutex<HashMap<OwnedRoomId, RoomConversation>>>;
|
||||||
|
|
||||||
|
/// On-disk format for persisted conversation history. Room IDs are stored as
|
||||||
|
/// strings because `OwnedRoomId` does not implement `Serialize` as a map key.
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
struct PersistedHistory {
|
||||||
|
rooms: HashMap<String, RoomConversation>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Path to the persisted conversation history file relative to project root.
|
||||||
|
const HISTORY_FILE: &str = ".story_kit/matrix_history.json";
|
||||||
|
|
||||||
|
/// Load conversation history from disk, returning an empty map on any error.
|
||||||
|
pub fn load_history(project_root: &std::path::Path) -> HashMap<OwnedRoomId, RoomConversation> {
|
||||||
|
let path = project_root.join(HISTORY_FILE);
|
||||||
|
let data = match std::fs::read_to_string(&path) {
|
||||||
|
Ok(d) => d,
|
||||||
|
Err(_) => return HashMap::new(),
|
||||||
|
};
|
||||||
|
let persisted: PersistedHistory = match serde_json::from_str(&data) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => {
|
||||||
|
slog!("[matrix-bot] Failed to parse history file: {e}");
|
||||||
|
return HashMap::new();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
persisted
|
||||||
|
.rooms
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|(k, v)| {
|
||||||
|
k.parse::<OwnedRoomId>()
|
||||||
|
.ok()
|
||||||
|
.map(|room_id| (room_id, v))
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Save conversation history to disk. Errors are logged but not propagated.
|
||||||
|
pub fn save_history(
|
||||||
|
project_root: &std::path::Path,
|
||||||
|
history: &HashMap<OwnedRoomId, RoomConversation>,
|
||||||
|
) {
|
||||||
|
let persisted = PersistedHistory {
|
||||||
|
rooms: history
|
||||||
|
.iter()
|
||||||
|
.map(|(k, v)| (k.to_string(), v.clone()))
|
||||||
|
.collect(),
|
||||||
|
};
|
||||||
|
let path = project_root.join(HISTORY_FILE);
|
||||||
|
match serde_json::to_string_pretty(&persisted) {
|
||||||
|
Ok(json) => {
|
||||||
|
if let Err(e) = std::fs::write(&path, json) {
|
||||||
|
slog!("[matrix-bot] Failed to write history file: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => slog!("[matrix-bot] Failed to serialise history: {e}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Bot context
|
// Bot context
|
||||||
@@ -221,12 +293,18 @@ pub async fn run_bot(
|
|||||||
let notif_room_ids = target_room_ids.clone();
|
let notif_room_ids = target_room_ids.clone();
|
||||||
let notif_project_root = project_root.clone();
|
let notif_project_root = project_root.clone();
|
||||||
|
|
||||||
|
let persisted = load_history(&project_root);
|
||||||
|
slog!(
|
||||||
|
"[matrix-bot] Loaded persisted conversation history for {} room(s)",
|
||||||
|
persisted.len()
|
||||||
|
);
|
||||||
|
|
||||||
let ctx = BotContext {
|
let ctx = BotContext {
|
||||||
bot_user_id,
|
bot_user_id,
|
||||||
target_room_ids,
|
target_room_ids,
|
||||||
project_root,
|
project_root,
|
||||||
allowed_users: config.allowed_users,
|
allowed_users: config.allowed_users,
|
||||||
history: Arc::new(TokioMutex::new(HashMap::new())),
|
history: Arc::new(TokioMutex::new(persisted)),
|
||||||
history_size: config.history_size,
|
history_size: config.history_size,
|
||||||
bot_sent_event_ids: Arc::new(TokioMutex::new(HashSet::new())),
|
bot_sent_event_ids: Arc::new(TokioMutex::new(HashSet::new())),
|
||||||
};
|
};
|
||||||
@@ -574,33 +652,10 @@ async fn on_room_message(
|
|||||||
// Message handler
|
// Message handler
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
/// Build a context string from the room's conversation history to prepend to
|
/// Build the user-facing prompt for a single turn. In multi-user rooms the
|
||||||
/// the user's current message. Returns an empty string when history is empty.
|
/// sender is included so the LLM can distinguish participants.
|
||||||
fn build_context_prefix(
|
fn format_user_prompt(sender: &str, message: &str) -> String {
|
||||||
history: &[ConversationEntry],
|
format!("{sender}: {message}")
|
||||||
current_sender: &str,
|
|
||||||
current_message: &str,
|
|
||||||
) -> String {
|
|
||||||
if history.is_empty() {
|
|
||||||
return format!("{current_sender}: {current_message}");
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut out = String::from("[Conversation history for this room]\n");
|
|
||||||
for entry in history {
|
|
||||||
match entry.role {
|
|
||||||
ConversationRole::User => {
|
|
||||||
out.push_str(&format!("User ({}): {}\n", entry.sender, entry.content));
|
|
||||||
}
|
|
||||||
ConversationRole::Assistant => {
|
|
||||||
out.push_str(&format!("Assistant: {}\n", entry.content));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out.push('\n');
|
|
||||||
out.push_str(&format!(
|
|
||||||
"Current message from {current_sender}: {current_message}"
|
|
||||||
));
|
|
||||||
out
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_message(
|
async fn handle_message(
|
||||||
@@ -610,14 +665,19 @@ async fn handle_message(
|
|||||||
sender: String,
|
sender: String,
|
||||||
user_message: String,
|
user_message: String,
|
||||||
) {
|
) {
|
||||||
// Read current history for this room before calling the LLM.
|
// Look up the room's existing Claude Code session ID (if any) so we can
|
||||||
let history_snapshot: Vec<ConversationEntry> = {
|
// resume the conversation with structured API messages instead of
|
||||||
|
// flattening history into a text prefix.
|
||||||
|
let resume_session_id: Option<String> = {
|
||||||
let guard = ctx.history.lock().await;
|
let guard = ctx.history.lock().await;
|
||||||
guard.get(&room_id).cloned().unwrap_or_default()
|
guard
|
||||||
|
.get(&room_id)
|
||||||
|
.and_then(|conv| conv.session_id.clone())
|
||||||
};
|
};
|
||||||
|
|
||||||
// Build the prompt with conversation context.
|
// The prompt is just the current message with sender attribution.
|
||||||
let prompt_with_context = build_context_prefix(&history_snapshot, &sender, &user_message);
|
// Prior conversation context is carried by the Claude Code session.
|
||||||
|
let prompt = format_user_prompt(&sender, &user_message);
|
||||||
|
|
||||||
let provider = ClaudeCodeProvider::new();
|
let provider = ClaudeCodeProvider::new();
|
||||||
let (cancel_tx, mut cancel_rx) = watch::channel(false);
|
let (cancel_tx, mut cancel_rx) = watch::channel(false);
|
||||||
@@ -652,9 +712,9 @@ async fn handle_message(
|
|||||||
|
|
||||||
let result = provider
|
let result = provider
|
||||||
.chat_stream(
|
.chat_stream(
|
||||||
&prompt_with_context,
|
&prompt,
|
||||||
&ctx.project_root.to_string_lossy(),
|
&ctx.project_root.to_string_lossy(),
|
||||||
None, // Each Matrix conversation turn is independent at the Claude Code session level.
|
resume_session_id.as_deref(),
|
||||||
&mut cancel_rx,
|
&mut cancel_rx,
|
||||||
move |token| {
|
move |token| {
|
||||||
let mut buf = buffer_for_callback.lock().unwrap();
|
let mut buf = buffer_for_callback.lock().unwrap();
|
||||||
@@ -675,9 +735,12 @@ async fn handle_message(
|
|||||||
let remaining = buffer.lock().unwrap().trim().to_string();
|
let remaining = buffer.lock().unwrap().trim().to_string();
|
||||||
let did_send_any = sent_any_chunk.load(Ordering::Relaxed);
|
let did_send_any = sent_any_chunk.load(Ordering::Relaxed);
|
||||||
|
|
||||||
let assistant_reply = match result {
|
let (assistant_reply, new_session_id) = match result {
|
||||||
Ok(ClaudeCodeResult { messages, .. }) => {
|
Ok(ClaudeCodeResult {
|
||||||
if !remaining.is_empty() {
|
messages,
|
||||||
|
session_id,
|
||||||
|
}) => {
|
||||||
|
let reply = if !remaining.is_empty() {
|
||||||
let _ = msg_tx.send(remaining.clone());
|
let _ = msg_tx.send(remaining.clone());
|
||||||
remaining
|
remaining
|
||||||
} else if !did_send_any {
|
} else if !did_send_any {
|
||||||
@@ -696,13 +759,14 @@ async fn handle_message(
|
|||||||
last_text
|
last_text
|
||||||
} else {
|
} else {
|
||||||
remaining
|
remaining
|
||||||
}
|
};
|
||||||
|
(reply, session_id)
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
slog!("[matrix-bot] LLM error: {e}");
|
slog!("[matrix-bot] LLM error: {e}");
|
||||||
let err_msg = format!("Error processing your request: {e}");
|
let err_msg = format!("Error processing your request: {e}");
|
||||||
let _ = msg_tx.send(err_msg.clone());
|
let _ = msg_tx.send(err_msg.clone());
|
||||||
err_msg
|
(err_msg, None)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -711,25 +775,40 @@ async fn handle_message(
|
|||||||
drop(msg_tx);
|
drop(msg_tx);
|
||||||
let _ = post_task.await;
|
let _ = post_task.await;
|
||||||
|
|
||||||
// Record this exchange in the per-room conversation history.
|
// Record this exchange in the per-room conversation history and persist
|
||||||
|
// the session ID so the next turn resumes with structured API messages.
|
||||||
if !assistant_reply.starts_with("Error processing") {
|
if !assistant_reply.starts_with("Error processing") {
|
||||||
let mut guard = ctx.history.lock().await;
|
let mut guard = ctx.history.lock().await;
|
||||||
let entries = guard.entry(room_id).or_default();
|
let conv = guard.entry(room_id).or_default();
|
||||||
entries.push(ConversationEntry {
|
|
||||||
|
// Store the session ID so the next turn uses --resume.
|
||||||
|
if new_session_id.is_some() {
|
||||||
|
conv.session_id = new_session_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
conv.entries.push(ConversationEntry {
|
||||||
role: ConversationRole::User,
|
role: ConversationRole::User,
|
||||||
sender: sender.clone(),
|
sender: sender.clone(),
|
||||||
content: user_message,
|
content: user_message,
|
||||||
});
|
});
|
||||||
entries.push(ConversationEntry {
|
conv.entries.push(ConversationEntry {
|
||||||
role: ConversationRole::Assistant,
|
role: ConversationRole::Assistant,
|
||||||
sender: String::new(),
|
sender: String::new(),
|
||||||
content: assistant_reply,
|
content: assistant_reply,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Trim to the configured maximum, dropping the oldest entries first.
|
// Trim to the configured maximum, dropping the oldest entries first.
|
||||||
if entries.len() > ctx.history_size {
|
// When trimming occurs, clear the session ID so the next turn starts
|
||||||
let excess = entries.len() - ctx.history_size;
|
// a fresh Claude Code session (the old session's context would be
|
||||||
entries.drain(..excess);
|
// stale since we've dropped entries from our tracking).
|
||||||
|
if conv.entries.len() > ctx.history_size {
|
||||||
|
let excess = conv.entries.len() - ctx.history_size;
|
||||||
|
conv.entries.drain(..excess);
|
||||||
|
conv.session_id = None;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Persist to disk so history survives server restarts.
|
||||||
|
save_history(&ctx.project_root, &guard);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1128,62 +1207,18 @@ mod tests {
|
|||||||
assert_eq!(buf, "Third.");
|
assert_eq!(buf, "Third.");
|
||||||
}
|
}
|
||||||
|
|
||||||
// -- build_context_prefix -----------------------------------------------
|
// -- format_user_prompt -------------------------------------------------
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn build_context_prefix_empty_history() {
|
fn format_user_prompt_includes_sender_and_message() {
|
||||||
let prefix = build_context_prefix(&[], "@alice:example.com", "Hello!");
|
let prompt = format_user_prompt("@alice:example.com", "Hello!");
|
||||||
assert_eq!(prefix, "@alice:example.com: Hello!");
|
assert_eq!(prompt, "@alice:example.com: Hello!");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn build_context_prefix_includes_history_entries() {
|
fn format_user_prompt_different_users() {
|
||||||
let history = vec![
|
let prompt = format_user_prompt("@bob:example.com", "What's up?");
|
||||||
ConversationEntry {
|
assert_eq!(prompt, "@bob:example.com: What's up?");
|
||||||
role: ConversationRole::User,
|
|
||||||
sender: "@alice:example.com".to_string(),
|
|
||||||
content: "What is story 42?".to_string(),
|
|
||||||
},
|
|
||||||
ConversationEntry {
|
|
||||||
role: ConversationRole::Assistant,
|
|
||||||
sender: String::new(),
|
|
||||||
content: "Story 42 is about…".to_string(),
|
|
||||||
},
|
|
||||||
];
|
|
||||||
let prefix = build_context_prefix(&history, "@bob:example.com", "Tell me more.");
|
|
||||||
assert!(prefix.contains("[Conversation history for this room]"));
|
|
||||||
assert!(prefix.contains("User (@alice:example.com): What is story 42?"));
|
|
||||||
assert!(prefix.contains("Assistant: Story 42 is about…"));
|
|
||||||
assert!(prefix.contains("Current message from @bob:example.com: Tell me more."));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn build_context_prefix_attributes_multiple_users() {
|
|
||||||
let history = vec![
|
|
||||||
ConversationEntry {
|
|
||||||
role: ConversationRole::User,
|
|
||||||
sender: "@alice:example.com".to_string(),
|
|
||||||
content: "First question".to_string(),
|
|
||||||
},
|
|
||||||
ConversationEntry {
|
|
||||||
role: ConversationRole::Assistant,
|
|
||||||
sender: String::new(),
|
|
||||||
content: "First answer".to_string(),
|
|
||||||
},
|
|
||||||
ConversationEntry {
|
|
||||||
role: ConversationRole::User,
|
|
||||||
sender: "@bob:example.com".to_string(),
|
|
||||||
content: "Follow-up".to_string(),
|
|
||||||
},
|
|
||||||
ConversationEntry {
|
|
||||||
role: ConversationRole::Assistant,
|
|
||||||
sender: String::new(),
|
|
||||||
content: "Second answer".to_string(),
|
|
||||||
},
|
|
||||||
];
|
|
||||||
let prefix = build_context_prefix(&history, "@alice:example.com", "Another question");
|
|
||||||
assert!(prefix.contains("User (@alice:example.com): First question"));
|
|
||||||
assert!(prefix.contains("User (@bob:example.com): Follow-up"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// -- conversation history trimming --------------------------------------
|
// -- conversation history trimming --------------------------------------
|
||||||
@@ -1197,37 +1232,44 @@ mod tests {
|
|||||||
// Add 6 entries (3 user + 3 assistant turns).
|
// Add 6 entries (3 user + 3 assistant turns).
|
||||||
{
|
{
|
||||||
let mut guard = history.lock().await;
|
let mut guard = history.lock().await;
|
||||||
let entries = guard.entry(room_id.clone()).or_default();
|
let conv = guard.entry(room_id.clone()).or_default();
|
||||||
|
conv.session_id = Some("test-session".to_string());
|
||||||
for i in 0..3usize {
|
for i in 0..3usize {
|
||||||
entries.push(ConversationEntry {
|
conv.entries.push(ConversationEntry {
|
||||||
role: ConversationRole::User,
|
role: ConversationRole::User,
|
||||||
sender: "@user:example.com".to_string(),
|
sender: "@user:example.com".to_string(),
|
||||||
content: format!("msg {i}"),
|
content: format!("msg {i}"),
|
||||||
});
|
});
|
||||||
entries.push(ConversationEntry {
|
conv.entries.push(ConversationEntry {
|
||||||
role: ConversationRole::Assistant,
|
role: ConversationRole::Assistant,
|
||||||
sender: String::new(),
|
sender: String::new(),
|
||||||
content: format!("reply {i}"),
|
content: format!("reply {i}"),
|
||||||
});
|
});
|
||||||
if entries.len() > history_size {
|
if conv.entries.len() > history_size {
|
||||||
let excess = entries.len() - history_size;
|
let excess = conv.entries.len() - history_size;
|
||||||
entries.drain(..excess);
|
conv.entries.drain(..excess);
|
||||||
|
conv.session_id = None;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let guard = history.lock().await;
|
let guard = history.lock().await;
|
||||||
let entries = guard.get(&room_id).unwrap();
|
let conv = guard.get(&room_id).unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
entries.len(),
|
conv.entries.len(),
|
||||||
history_size,
|
history_size,
|
||||||
"history must be trimmed to history_size"
|
"history must be trimmed to history_size"
|
||||||
);
|
);
|
||||||
// The oldest entries (msg 0 / reply 0) should have been dropped.
|
// The oldest entries (msg 0 / reply 0) should have been dropped.
|
||||||
assert!(
|
assert!(
|
||||||
entries.iter().all(|e| !e.content.contains("msg 0")),
|
conv.entries.iter().all(|e| !e.content.contains("msg 0")),
|
||||||
"oldest entries must be dropped"
|
"oldest entries must be dropped"
|
||||||
);
|
);
|
||||||
|
// Session ID must be cleared when trimming occurs.
|
||||||
|
assert!(
|
||||||
|
conv.session_id.is_none(),
|
||||||
|
"session_id must be cleared on trim to start a fresh session"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -1241,6 +1283,7 @@ mod tests {
|
|||||||
guard
|
guard
|
||||||
.entry(room_a.clone())
|
.entry(room_a.clone())
|
||||||
.or_default()
|
.or_default()
|
||||||
|
.entries
|
||||||
.push(ConversationEntry {
|
.push(ConversationEntry {
|
||||||
role: ConversationRole::User,
|
role: ConversationRole::User,
|
||||||
sender: "@alice:example.com".to_string(),
|
sender: "@alice:example.com".to_string(),
|
||||||
@@ -1249,6 +1292,7 @@ mod tests {
|
|||||||
guard
|
guard
|
||||||
.entry(room_b.clone())
|
.entry(room_b.clone())
|
||||||
.or_default()
|
.or_default()
|
||||||
|
.entries
|
||||||
.push(ConversationEntry {
|
.push(ConversationEntry {
|
||||||
role: ConversationRole::User,
|
role: ConversationRole::User,
|
||||||
sender: "@bob:example.com".to_string(),
|
sender: "@bob:example.com".to_string(),
|
||||||
@@ -1257,12 +1301,131 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let guard = history.lock().await;
|
let guard = history.lock().await;
|
||||||
let entries_a = guard.get(&room_a).unwrap();
|
let conv_a = guard.get(&room_a).unwrap();
|
||||||
let entries_b = guard.get(&room_b).unwrap();
|
let conv_b = guard.get(&room_b).unwrap();
|
||||||
assert_eq!(entries_a.len(), 1);
|
assert_eq!(conv_a.entries.len(), 1);
|
||||||
assert_eq!(entries_b.len(), 1);
|
assert_eq!(conv_b.entries.len(), 1);
|
||||||
assert_eq!(entries_a[0].content, "Room A message");
|
assert_eq!(conv_a.entries[0].content, "Room A message");
|
||||||
assert_eq!(entries_b[0].content, "Room B message");
|
assert_eq!(conv_b.entries[0].content, "Room B message");
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- persistence --------------------------------------------------------
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn save_and_load_history_round_trip() {
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let story_kit_dir = dir.path().join(".story_kit");
|
||||||
|
std::fs::create_dir_all(&story_kit_dir).unwrap();
|
||||||
|
|
||||||
|
let room_id: OwnedRoomId = "!persist:example.com".parse().unwrap();
|
||||||
|
let mut map: HashMap<OwnedRoomId, RoomConversation> = HashMap::new();
|
||||||
|
let conv = map.entry(room_id.clone()).or_default();
|
||||||
|
conv.session_id = Some("session-abc".to_string());
|
||||||
|
conv.entries.push(ConversationEntry {
|
||||||
|
role: ConversationRole::User,
|
||||||
|
sender: "@alice:example.com".to_string(),
|
||||||
|
content: "hello".to_string(),
|
||||||
|
});
|
||||||
|
conv.entries.push(ConversationEntry {
|
||||||
|
role: ConversationRole::Assistant,
|
||||||
|
sender: String::new(),
|
||||||
|
content: "hi there!".to_string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
save_history(dir.path(), &map);
|
||||||
|
|
||||||
|
let loaded = load_history(dir.path());
|
||||||
|
let loaded_conv = loaded.get(&room_id).expect("room must exist after load");
|
||||||
|
assert_eq!(loaded_conv.session_id.as_deref(), Some("session-abc"));
|
||||||
|
assert_eq!(loaded_conv.entries.len(), 2);
|
||||||
|
assert_eq!(loaded_conv.entries[0].role, ConversationRole::User);
|
||||||
|
assert_eq!(loaded_conv.entries[0].sender, "@alice:example.com");
|
||||||
|
assert_eq!(loaded_conv.entries[0].content, "hello");
|
||||||
|
assert_eq!(loaded_conv.entries[1].role, ConversationRole::Assistant);
|
||||||
|
assert_eq!(loaded_conv.entries[1].content, "hi there!");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn load_history_returns_empty_on_missing_file() {
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let loaded = load_history(dir.path());
|
||||||
|
assert!(loaded.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn load_history_returns_empty_on_corrupt_file() {
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let story_kit_dir = dir.path().join(".story_kit");
|
||||||
|
std::fs::create_dir_all(&story_kit_dir).unwrap();
|
||||||
|
std::fs::write(dir.path().join(HISTORY_FILE), "not valid json").unwrap();
|
||||||
|
let loaded = load_history(dir.path());
|
||||||
|
assert!(loaded.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- session_id tracking ------------------------------------------------
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn session_id_preserved_within_history_size() {
|
||||||
|
let history: ConversationHistory = Arc::new(TokioMutex::new(HashMap::new()));
|
||||||
|
let room_id: OwnedRoomId = "!session:example.com".parse().unwrap();
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut guard = history.lock().await;
|
||||||
|
let conv = guard.entry(room_id.clone()).or_default();
|
||||||
|
conv.session_id = Some("sess-1".to_string());
|
||||||
|
conv.entries.push(ConversationEntry {
|
||||||
|
role: ConversationRole::User,
|
||||||
|
sender: "@alice:example.com".to_string(),
|
||||||
|
content: "hello".to_string(),
|
||||||
|
});
|
||||||
|
conv.entries.push(ConversationEntry {
|
||||||
|
role: ConversationRole::Assistant,
|
||||||
|
sender: String::new(),
|
||||||
|
content: "hi".to_string(),
|
||||||
|
});
|
||||||
|
// No trimming needed (2 entries, well under any reasonable limit).
|
||||||
|
}
|
||||||
|
|
||||||
|
let guard = history.lock().await;
|
||||||
|
let conv = guard.get(&room_id).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
conv.session_id.as_deref(),
|
||||||
|
Some("sess-1"),
|
||||||
|
"session_id must be preserved when no trimming occurs"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- multi-user room attribution ----------------------------------------
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn multi_user_entries_preserve_sender() {
|
||||||
|
let history: ConversationHistory = Arc::new(TokioMutex::new(HashMap::new()));
|
||||||
|
let room_id: OwnedRoomId = "!multi:example.com".parse().unwrap();
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut guard = history.lock().await;
|
||||||
|
let conv = guard.entry(room_id.clone()).or_default();
|
||||||
|
conv.entries.push(ConversationEntry {
|
||||||
|
role: ConversationRole::User,
|
||||||
|
sender: "@alice:example.com".to_string(),
|
||||||
|
content: "from alice".to_string(),
|
||||||
|
});
|
||||||
|
conv.entries.push(ConversationEntry {
|
||||||
|
role: ConversationRole::Assistant,
|
||||||
|
sender: String::new(),
|
||||||
|
content: "reply to alice".to_string(),
|
||||||
|
});
|
||||||
|
conv.entries.push(ConversationEntry {
|
||||||
|
role: ConversationRole::User,
|
||||||
|
sender: "@bob:example.com".to_string(),
|
||||||
|
content: "from bob".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let guard = history.lock().await;
|
||||||
|
let conv = guard.get(&room_id).unwrap();
|
||||||
|
assert_eq!(conv.entries[0].sender, "@alice:example.com");
|
||||||
|
assert_eq!(conv.entries[2].sender, "@bob:example.com");
|
||||||
}
|
}
|
||||||
|
|
||||||
// -- self-sign device key decision logic -----------------------------------
|
// -- self-sign device key decision logic -----------------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user