368 lines
14 KiB
Rust
368 lines
14 KiB
Rust
//! Protocol-agnostic chat dispatcher — coalesce window + per-session serial lock.
//!
//! Sits between every inbound transport (Matrix, Slack, WhatsApp, …) and the
//! `claude -p` spawner. Transport handlers call [`ChatDispatcher::submit`]
//! instead of spawning directly; the dispatcher enforces two invariants:
//!
//! 1. **Coalesce window**: messages arriving for the same session within
//!    `coalesce_ms` of each other are concatenated (joined with a blank line)
//!    and delivered to a single spawn. The window is a *debounce*: each new
//!    message extends the window by `coalesce_ms` from its arrival time, so
//!    bursts flush as one batch.
//!
//! 2. **Per-session serial lock**: while one `claude -p` run is active, further
//!    messages for that session queue up and are dispatched as a single batch
//!    once the running invocation completes.
//!
//! A [`ChatDispatcher::stop`] call cancels the active run for a session and
//! discards the pending queue.
|
|
|
|
use crate::slog;
|
|
use std::collections::HashMap;
|
|
use std::pin::Pin;
|
|
use std::sync::{Arc, Mutex};
|
|
use std::time::Duration;
|
|
use tokio::sync::{mpsc, watch};
|
|
|
|
/// A factory function that produces one LLM execution future per dispatch.
|
|
///
|
|
/// Arguments:
|
|
/// - `String` — the (possibly concatenated) prompt to send to `claude -p`.
|
|
/// - `watch::Receiver<bool>` — send `true` on this channel to cancel the run.
|
|
///
|
|
/// Returns a boxed, pinned `Send + 'static` future that resolves when the LLM
|
|
/// session ends (whether normally or via cancellation).
|
|
pub type SpawnFn = Arc<
|
|
dyn Fn(
|
|
String,
|
|
watch::Receiver<bool>,
|
|
) -> Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>>
|
|
+ Send
|
|
+ Sync,
|
|
>;
|
|
|
|
enum SessionMsg {
|
|
UserMessage { text: String, factory: SpawnFn },
|
|
Stop,
|
|
}
|
|
|
|
struct SessionHandle {
|
|
tx: mpsc::UnboundedSender<SessionMsg>,
|
|
}
|
|
|
|
/// Coalescing, serialising dispatcher for chat-to-LLM message routing.
|
|
///
|
|
/// Construct once at startup via [`ChatDispatcher::new`] and share via `Arc`.
|
|
/// Call [`submit`](ChatDispatcher::submit) from every transport handler instead
|
|
/// of spawning `claude -p` directly.
|
|
pub struct ChatDispatcher {
|
|
sessions: Mutex<HashMap<String, SessionHandle>>,
|
|
coalesce_ms: u64,
|
|
}
|
|
|
|
impl ChatDispatcher {
|
|
/// Create a new dispatcher with the given coalesce window in milliseconds.
|
|
pub fn new(coalesce_ms: u64) -> Self {
|
|
Self {
|
|
sessions: Mutex::new(HashMap::new()),
|
|
coalesce_ms,
|
|
}
|
|
}
|
|
|
|
/// Submit a message for a chat session.
|
|
///
|
|
/// If no session task exists for `session_key`, one is created lazily.
|
|
/// The `factory` is called by the session task when the coalesce window
|
|
/// closes (or immediately after the current run finishes, for pending
|
|
/// messages).
|
|
pub fn submit(&self, session_key: String, message: String, factory: SpawnFn) {
|
|
let mut guard = self.sessions.lock().unwrap();
|
|
let coalesce_ms = self.coalesce_ms;
|
|
let handle = guard.entry(session_key.clone()).or_insert_with(|| {
|
|
let (tx, rx) = mpsc::unbounded_channel();
|
|
tokio::spawn(session_task(session_key.clone(), rx, coalesce_ms));
|
|
SessionHandle { tx }
|
|
});
|
|
let _ = handle.tx.send(SessionMsg::UserMessage {
|
|
text: message,
|
|
factory,
|
|
});
|
|
}
|
|
|
|
/// Stop the active LLM run for `session_key` and clear its pending queue.
|
|
///
|
|
/// Returns `true` if the session existed (whether or not anything was
|
|
/// actually running), `false` if no session for that key has been created.
|
|
pub fn stop(&self, session_key: &str) -> bool {
|
|
let guard = self.sessions.lock().unwrap();
|
|
if let Some(handle) = guard.get(session_key) {
|
|
let _ = handle.tx.send(SessionMsg::Stop);
|
|
true
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Per-session background task.
|
|
///
|
|
/// Phases:
|
|
/// 1. **Wait** — blocks until the first `UserMessage` arrives.
|
|
/// 2. **Coalesce** — extends the window by `coalesce_ms` on each new message;
|
|
/// fires when no message arrives within the window.
|
|
/// 3. **Run** — calls the factory with the concatenated batch; while running,
|
|
/// collects further `UserMessage`s into a pending list and logs a warn per
|
|
/// message. A `Stop` message cancels the running call and clears pending.
|
|
/// 4. **Drain** — after the run, if pending is non-empty, fires a second run
|
|
/// with the accumulated batch and loops back to step 3.
|
|
/// 5. Returns to step 1 when pending is empty.
|
|
async fn session_task(
|
|
session_key: String,
|
|
mut rx: mpsc::UnboundedReceiver<SessionMsg>,
|
|
coalesce_ms: u64,
|
|
) {
|
|
let coalesce_dur = Duration::from_millis(coalesce_ms);
|
|
|
|
loop {
|
|
// ── Phase 1: wait for the first message ─────────────────────────────
|
|
let (first_text, first_factory) = loop {
|
|
match rx.recv().await {
|
|
None => return,
|
|
Some(SessionMsg::Stop) => continue,
|
|
Some(SessionMsg::UserMessage { text, factory }) => break (text, factory),
|
|
}
|
|
};
|
|
|
|
// ── Phase 2: coalesce window (debounce) ──────────────────────────────
|
|
let mut batch: Vec<String> = vec![first_text];
|
|
let mut latest_factory: SpawnFn = first_factory;
|
|
let mut deadline = tokio::time::Instant::now() + coalesce_dur;
|
|
|
|
'coalesce: loop {
|
|
let now = tokio::time::Instant::now();
|
|
if now >= deadline {
|
|
break 'coalesce;
|
|
}
|
|
let remaining = deadline - now;
|
|
match tokio::time::timeout(remaining, rx.recv()).await {
|
|
Err(_) => break 'coalesce, // window closed
|
|
Ok(None) => return, // channel closed → exit task
|
|
Ok(Some(SessionMsg::Stop)) => {
|
|
batch.clear();
|
|
break 'coalesce;
|
|
}
|
|
Ok(Some(SessionMsg::UserMessage { text, factory })) => {
|
|
batch.push(text);
|
|
latest_factory = factory;
|
|
// Extend deadline on each new message (debounce).
|
|
deadline = tokio::time::Instant::now() + coalesce_dur;
|
|
}
|
|
}
|
|
}
|
|
|
|
if batch.is_empty() {
|
|
continue; // Stop received during coalesce — restart
|
|
}
|
|
|
|
// ── Phase 3 + 4: run → drain pending → repeat ───────────────────────
|
|
let mut prompt = batch.join("\n\n");
|
|
let mut factory = latest_factory;
|
|
|
|
loop {
|
|
let (cancel_tx, cancel_rx) = watch::channel(false);
|
|
let llm_fut = factory(prompt, cancel_rx);
|
|
let mut llm_task = tokio::spawn(llm_fut);
|
|
|
|
let mut pending_texts: Vec<String> = vec![];
|
|
let mut pending_factory: Option<SpawnFn> = None;
|
|
let mut stopped = false;
|
|
|
|
// Wait for the LLM to finish, collecting messages that arrive during the run.
|
|
loop {
|
|
tokio::select! {
|
|
_ = &mut llm_task => { break; }
|
|
msg = rx.recv() => {
|
|
match msg {
|
|
None => {
|
|
llm_task.abort();
|
|
return;
|
|
}
|
|
Some(SessionMsg::Stop) => {
|
|
let _ = cancel_tx.send(true);
|
|
let _ = llm_task.await;
|
|
pending_texts.clear();
|
|
stopped = true;
|
|
break;
|
|
}
|
|
Some(SessionMsg::UserMessage { text, factory: f }) => {
|
|
pending_texts.push(text);
|
|
let depth = pending_texts.len();
|
|
slog!(
|
|
"[chat-dispatcher] coalescing message for session={}, queue_depth={}",
|
|
session_key,
|
|
depth,
|
|
);
|
|
pending_factory = Some(f);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if stopped || pending_texts.is_empty() {
|
|
break; // back to Phase 1
|
|
}
|
|
|
|
// Fire the pending batch as the next run (no additional coalesce window).
|
|
prompt = pending_texts.join("\n\n");
|
|
factory = pending_factory.unwrap();
|
|
}
|
|
}
|
|
}
|
|
|
|
// ── Tests ─────────────────────────────────────────────────────────────────────
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a factory whose future bumps `spawn_count` once and then sleeps
    /// for `run_ms`. The prompt and the cancel signal are both ignored.
    fn make_factory(spawn_count: Arc<AtomicUsize>, run_ms: u64) -> SpawnFn {
        Arc::new(move |_prompt: String, _cancel_rx: watch::Receiver<bool>| {
            let count = Arc::clone(&spawn_count);
            Box::pin(async move {
                count.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_millis(run_ms)).await;
            })
        })
    }

    /// AC 6 regression: three messages arriving 20 ms / (long gap) apart on the
    /// same session must produce two spawns, never three concurrent processes.
    ///
    /// Setup:
    ///   coalesce_ms = 50 ms  (short window so test runs fast)
    ///   LLM "run"   = 150 ms
    ///   msg1 @ t=0
    ///   msg2 @ t=20 ms  — within coalesce window, merged with msg1 → 1 spawn
    ///   msg3 @ t=300 ms — after run completes → 2nd spawn
    ///
    /// Expected: exactly 2 spawns, never 3.
    #[tokio::test]
    async fn three_messages_never_three_concurrent_spawns() {
        let spawn_count = Arc::new(AtomicUsize::new(0));
        let dispatcher = Arc::new(ChatDispatcher::new(50));
        let session = "room1".to_string();

        // msg1 at t=0
        dispatcher.submit(
            session.clone(),
            "msg1".to_string(),
            make_factory(Arc::clone(&spawn_count), 150),
        );

        // msg2 at t=20 ms — inside the 50 ms coalesce window
        tokio::time::sleep(Duration::from_millis(20)).await;
        dispatcher.submit(
            session.clone(),
            "msg2".to_string(),
            make_factory(Arc::clone(&spawn_count), 150),
        );

        // msg3 at t=300 ms — after the coalesce window fires (t≈70 ms) and the
        // 150 ms run completes (t≈220 ms), so msg3 starts a second coalesce cycle.
        tokio::time::sleep(Duration::from_millis(280)).await;
        dispatcher.submit(
            session.clone(),
            "msg3".to_string(),
            make_factory(Arc::clone(&spawn_count), 150),
        );

        // Wait long enough for both runs to finish.
        tokio::time::sleep(Duration::from_millis(500)).await;

        // Exactly two: accepting a single spawn would hide a lost msg3.
        let count = spawn_count.load(Ordering::SeqCst);
        assert_eq!(
            count, 2,
            "expected 2 spawns (msgs 1+2 coalesced, msg3 separate), got {count}"
        );
    }

    /// Messages that arrive while the LLM is running are not lost — they are
    /// delivered as a single follow-up spawn once the first run completes.
    #[tokio::test]
    async fn pending_messages_dispatched_after_run_completes() {
        let spawn_count = Arc::new(AtomicUsize::new(0));
        let dispatcher = Arc::new(ChatDispatcher::new(50));
        let session = "room2".to_string();

        // First message — starts a 200 ms run.
        dispatcher.submit(
            session.clone(),
            "first".to_string(),
            make_factory(Arc::clone(&spawn_count), 200),
        );

        // Wait for coalesce window to fire, then send two more.
        tokio::time::sleep(Duration::from_millis(100)).await;
        dispatcher.submit(
            session.clone(),
            "second".to_string(),
            make_factory(Arc::clone(&spawn_count), 50),
        );
        dispatcher.submit(
            session.clone(),
            "third".to_string(),
            make_factory(Arc::clone(&spawn_count), 50),
        );

        // Wait long enough for both runs.
        tokio::time::sleep(Duration::from_millis(600)).await;

        let count = spawn_count.load(Ordering::SeqCst);
        assert_eq!(
            count, 2,
            "first run + one pending-batch run = 2 total spawns"
        );
    }

    /// Stop discards messages queued behind the running LLM call.
    ///
    /// The test factory ignores the cancel signal, so this verifies only that
    /// the pending message never spawns, not that the run itself ends early.
    #[tokio::test]
    async fn stop_cancels_run_and_clears_pending() {
        let spawn_count = Arc::new(AtomicUsize::new(0));
        let dispatcher = Arc::new(ChatDispatcher::new(30));
        let session = "room3".to_string();

        // Start a long run.
        dispatcher.submit(
            session.clone(),
            "long-running".to_string(),
            make_factory(Arc::clone(&spawn_count), 500),
        );

        // Wait for coalesce window to fire.
        tokio::time::sleep(Duration::from_millis(80)).await;

        // Queue a pending message.
        dispatcher.submit(
            session.clone(),
            "pending".to_string(),
            make_factory(Arc::clone(&spawn_count), 50),
        );

        // Stop immediately.
        assert!(dispatcher.stop(&session), "session should exist");

        // Wait longer than the run would have taken if not stopped.
        tokio::time::sleep(Duration::from_millis(700)).await;

        // The first run was started before stop (spawn_count=1).
        // The pending message must NOT have produced a second spawn.
        let count = spawn_count.load(Ordering::SeqCst);
        assert_eq!(
            count, 1,
            "stop should discard pending; got {count} spawns"
        );
    }
}
|