huskies: merge 472_story_discord_chat_transport
This commit is contained in:
+1
-1
@@ -34,6 +34,7 @@ walkdir = { workspace = true }
|
||||
matrix-sdk = { workspace = true }
|
||||
pulldown-cmark = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
|
||||
# Force bundled SQLite so static musl builds don't need a system libsqlite3
|
||||
libsqlite3-sys = { version = "0.35.0", features = ["bundled"] }
|
||||
@@ -44,6 +45,5 @@ libc = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
mockito = "1"
|
||||
filetime = { workspace = true }
|
||||
|
||||
@@ -0,0 +1,641 @@
|
||||
//! Discord incoming message dispatch and command handling.
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::sync::{Mutex as TokioMutex, oneshot};
|
||||
|
||||
use crate::agents::AgentPool;
|
||||
use crate::chat::transport::matrix::{ConversationEntry, ConversationRole, RoomConversation};
|
||||
use crate::chat::util::is_permission_approval;
|
||||
use crate::chat::ChatTransport;
|
||||
use crate::http::context::{PermissionDecision, PermissionForward};
|
||||
use crate::slog;
|
||||
|
||||
use super::format::markdown_to_discord;
|
||||
use super::history::{DiscordConversationHistory, save_discord_history};
|
||||
use super::meta::DiscordTransport;
|
||||
|
||||
// ── Shared context ────────────────────────────────────────────────────
|
||||
|
||||
/// Shared context for the Discord bot, used by both the Gateway listener
|
||||
/// and any future webhook handlers.
|
||||
pub struct DiscordContext {
|
||||
pub bot_token: String,
|
||||
pub transport: Arc<DiscordTransport>,
|
||||
pub project_root: PathBuf,
|
||||
pub agents: Arc<AgentPool>,
|
||||
pub bot_name: String,
|
||||
/// The bot's Discord user ID (set dynamically from READY event, but
|
||||
/// also stored here for command dispatch).
|
||||
pub bot_user_id: String,
|
||||
pub ambient_rooms: Arc<Mutex<HashSet<String>>>,
|
||||
/// Per-channel conversation history for LLM passthrough.
|
||||
pub history: DiscordConversationHistory,
|
||||
/// Maximum number of conversation entries to keep per channel.
|
||||
pub history_size: usize,
|
||||
/// Allowed channel IDs (messages from other channels are ignored).
|
||||
pub channel_ids: HashSet<String>,
|
||||
/// Allowed Discord user IDs. When non-empty, only listed users can
|
||||
/// interact with the bot. When empty, all users are allowed.
|
||||
pub allowed_users: HashSet<String>,
|
||||
/// Permission requests from the MCP `prompt_permission` tool arrive here.
|
||||
pub perm_rx: Arc<TokioMutex<tokio::sync::mpsc::UnboundedReceiver<PermissionForward>>>,
|
||||
/// Pending permission replies keyed by channel ID.
|
||||
pub pending_perm_replies:
|
||||
Arc<TokioMutex<HashMap<String, oneshot::Sender<PermissionDecision>>>>,
|
||||
/// Seconds before an unanswered permission prompt is auto-denied.
|
||||
pub permission_timeout_secs: u64,
|
||||
}
|
||||
|
||||
// ── Incoming message dispatch ───────────────────────────────────────────
|
||||
|
||||
pub(super) async fn handle_incoming_message(
|
||||
ctx: &DiscordContext,
|
||||
channel: &str,
|
||||
user: &str,
|
||||
message: &str,
|
||||
) {
|
||||
use crate::chat::commands::{CommandDispatch, try_handle_command};
|
||||
|
||||
// If there is a pending permission prompt for this channel, interpret the
|
||||
// message as a yes/no response.
|
||||
{
|
||||
let mut pending = ctx.pending_perm_replies.lock().await;
|
||||
if let Some(tx) = pending.remove(channel) {
|
||||
let decision = if is_permission_approval(message) {
|
||||
PermissionDecision::Approve
|
||||
} else {
|
||||
PermissionDecision::Deny
|
||||
};
|
||||
let _ = tx.send(decision);
|
||||
let confirmation = if decision == PermissionDecision::Approve {
|
||||
"Permission approved."
|
||||
} else {
|
||||
"Permission denied."
|
||||
};
|
||||
let formatted = markdown_to_discord(confirmation);
|
||||
let _ = ctx.transport.send_message(channel, &formatted, "").await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let dispatch = CommandDispatch {
|
||||
bot_name: &ctx.bot_name,
|
||||
bot_user_id: &ctx.bot_user_id,
|
||||
project_root: &ctx.project_root,
|
||||
agents: &ctx.agents,
|
||||
ambient_rooms: &ctx.ambient_rooms,
|
||||
room_id: channel,
|
||||
};
|
||||
|
||||
if let Some(response) = try_handle_command(&dispatch, message) {
|
||||
slog!("[discord] Sending command response to {channel}");
|
||||
let response = markdown_to_discord(&response);
|
||||
if let Err(e) = ctx.transport.send_message(channel, &response, "").await {
|
||||
slog!("[discord] Failed to send reply to {channel}: {e}");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Check for async commands (htop, delete).
|
||||
if let Some(htop_cmd) = crate::chat::transport::matrix::htop::extract_htop_command(
|
||||
message,
|
||||
&ctx.bot_name,
|
||||
&ctx.bot_user_id,
|
||||
) {
|
||||
use crate::chat::transport::matrix::htop::HtopCommand;
|
||||
slog!("[discord] Handling htop command from {user} in {channel}");
|
||||
match htop_cmd {
|
||||
HtopCommand::Stop => {
|
||||
let _ = ctx
|
||||
.transport
|
||||
.send_message(channel, "htop stopped.", "")
|
||||
.await;
|
||||
}
|
||||
HtopCommand::Start { duration_secs } => {
|
||||
let snapshot = crate::chat::transport::matrix::htop::build_htop_message(
|
||||
&ctx.agents,
|
||||
0,
|
||||
duration_secs,
|
||||
);
|
||||
let snapshot = markdown_to_discord(&snapshot);
|
||||
let msg_id = match ctx.transport.send_message(channel, &snapshot, "").await {
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
slog!("[discord] Failed to send htop message: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let transport = Arc::clone(&ctx.transport) as Arc<dyn ChatTransport>;
|
||||
let agents = Arc::clone(&ctx.agents);
|
||||
let ch = channel.to_string();
|
||||
tokio::spawn(async move {
|
||||
let interval = std::time::Duration::from_secs(2);
|
||||
let total_ticks = (duration_secs as usize) / 2;
|
||||
for tick in 1..=total_ticks {
|
||||
tokio::time::sleep(interval).await;
|
||||
let updated =
|
||||
crate::chat::transport::matrix::htop::build_htop_message(
|
||||
&agents,
|
||||
(tick * 2) as u32,
|
||||
duration_secs,
|
||||
);
|
||||
let updated = markdown_to_discord(&updated);
|
||||
if let Err(e) =
|
||||
transport.edit_message(&ch, &msg_id, &updated, "").await
|
||||
{
|
||||
slog!("[discord] Failed to edit htop message: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(del_cmd) = crate::chat::transport::matrix::delete::extract_delete_command(
|
||||
message,
|
||||
&ctx.bot_name,
|
||||
&ctx.bot_user_id,
|
||||
) {
|
||||
let response = match del_cmd {
|
||||
crate::chat::transport::matrix::delete::DeleteCommand::Delete { story_number } => {
|
||||
slog!("[discord] Handling delete command from {user}: story {story_number}");
|
||||
crate::chat::transport::matrix::delete::handle_delete(
|
||||
&ctx.bot_name,
|
||||
&story_number,
|
||||
&ctx.project_root,
|
||||
&ctx.agents,
|
||||
)
|
||||
.await
|
||||
}
|
||||
crate::chat::transport::matrix::delete::DeleteCommand::BadArgs => {
|
||||
format!("Usage: `{} delete <number>`", ctx.bot_name)
|
||||
}
|
||||
};
|
||||
let response = markdown_to_discord(&response);
|
||||
let _ = ctx.transport.send_message(channel, &response, "").await;
|
||||
return;
|
||||
}
|
||||
|
||||
if crate::chat::transport::matrix::rebuild::extract_rebuild_command(
|
||||
message,
|
||||
&ctx.bot_name,
|
||||
&ctx.bot_user_id,
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
slog!("[discord] Handling rebuild command from {user} in {channel}");
|
||||
let ack = "Rebuilding server… this may take a moment.";
|
||||
let _ = ctx.transport.send_message(channel, ack, "").await;
|
||||
let response = crate::chat::transport::matrix::rebuild::handle_rebuild(
|
||||
&ctx.bot_name,
|
||||
&ctx.project_root,
|
||||
&ctx.agents,
|
||||
)
|
||||
.await;
|
||||
let response = markdown_to_discord(&response);
|
||||
let _ = ctx.transport.send_message(channel, &response, "").await;
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(rmtree_cmd) = crate::chat::transport::matrix::rmtree::extract_rmtree_command(
|
||||
message,
|
||||
&ctx.bot_name,
|
||||
&ctx.bot_user_id,
|
||||
) {
|
||||
let response = match rmtree_cmd {
|
||||
crate::chat::transport::matrix::rmtree::RmtreeCommand::Rmtree { story_number } => {
|
||||
slog!(
|
||||
"[discord] Handling rmtree command from {user} in {channel}: story {story_number}"
|
||||
);
|
||||
crate::chat::transport::matrix::rmtree::handle_rmtree(
|
||||
&ctx.bot_name,
|
||||
&story_number,
|
||||
&ctx.project_root,
|
||||
&ctx.agents,
|
||||
)
|
||||
.await
|
||||
}
|
||||
crate::chat::transport::matrix::rmtree::RmtreeCommand::BadArgs => {
|
||||
format!("Usage: `{} rmtree <number>`", ctx.bot_name)
|
||||
}
|
||||
};
|
||||
let response = markdown_to_discord(&response);
|
||||
let _ = ctx.transport.send_message(channel, &response, "").await;
|
||||
return;
|
||||
}
|
||||
|
||||
if crate::chat::transport::matrix::reset::extract_reset_command(
|
||||
message,
|
||||
&ctx.bot_name,
|
||||
&ctx.bot_user_id,
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
slog!("[discord] Handling reset command from {user} in {channel}");
|
||||
{
|
||||
let mut guard = ctx.history.lock().await;
|
||||
let conv = guard
|
||||
.entry(channel.to_string())
|
||||
.or_insert_with(RoomConversation::default);
|
||||
conv.session_id = None;
|
||||
conv.entries.clear();
|
||||
save_discord_history(&ctx.project_root, &guard);
|
||||
}
|
||||
let _ = ctx
|
||||
.transport
|
||||
.send_message(channel, "Session cleared.", "")
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(start_cmd) = crate::chat::transport::matrix::start::extract_start_command(
|
||||
message,
|
||||
&ctx.bot_name,
|
||||
&ctx.bot_user_id,
|
||||
) {
|
||||
let response = match start_cmd {
|
||||
crate::chat::transport::matrix::start::StartCommand::Start {
|
||||
story_number,
|
||||
agent_hint,
|
||||
} => {
|
||||
slog!(
|
||||
"[discord] Handling start command from {user} in {channel}: story {story_number}"
|
||||
);
|
||||
crate::chat::transport::matrix::start::handle_start(
|
||||
&ctx.bot_name,
|
||||
&story_number,
|
||||
agent_hint.as_deref(),
|
||||
&ctx.project_root,
|
||||
&ctx.agents,
|
||||
)
|
||||
.await
|
||||
}
|
||||
crate::chat::transport::matrix::start::StartCommand::BadArgs => {
|
||||
format!("Usage: `{} start <number>`", ctx.bot_name)
|
||||
}
|
||||
};
|
||||
let response = markdown_to_discord(&response);
|
||||
let _ = ctx.transport.send_message(channel, &response, "").await;
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(assign_cmd) = crate::chat::transport::matrix::assign::extract_assign_command(
|
||||
message,
|
||||
&ctx.bot_name,
|
||||
&ctx.bot_user_id,
|
||||
) {
|
||||
let response = match assign_cmd {
|
||||
crate::chat::transport::matrix::assign::AssignCommand::Assign {
|
||||
story_number,
|
||||
model,
|
||||
} => {
|
||||
slog!(
|
||||
"[discord] Handling assign command from {user} in {channel}: story {story_number} model {model}"
|
||||
);
|
||||
crate::chat::transport::matrix::assign::handle_assign(
|
||||
&ctx.bot_name,
|
||||
&story_number,
|
||||
&model,
|
||||
&ctx.project_root,
|
||||
&ctx.agents,
|
||||
)
|
||||
.await
|
||||
}
|
||||
crate::chat::transport::matrix::assign::AssignCommand::BadArgs => {
|
||||
format!("Usage: `{} assign <number> <model>`", ctx.bot_name)
|
||||
}
|
||||
};
|
||||
let response = markdown_to_discord(&response);
|
||||
let _ = ctx.transport.send_message(channel, &response, "").await;
|
||||
return;
|
||||
}
|
||||
|
||||
// No command matched — forward to LLM for conversational response.
|
||||
slog!("[discord] No command matched, forwarding to LLM for {user} in {channel}");
|
||||
handle_llm_message(ctx, channel, user, message).await;
|
||||
}
|
||||
|
||||
/// Forward a message to Claude Code and send the response back via Discord.
|
||||
async fn handle_llm_message(
|
||||
ctx: &DiscordContext,
|
||||
channel: &str,
|
||||
user: &str,
|
||||
user_message: &str,
|
||||
) {
|
||||
use crate::chat::util::drain_complete_paragraphs;
|
||||
use crate::llm::providers::claude_code::{ClaudeCodeProvider, ClaudeCodeResult};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use tokio::sync::watch;
|
||||
|
||||
// Look up existing session ID for this channel.
|
||||
let resume_session_id: Option<String> = {
|
||||
let guard = ctx.history.lock().await;
|
||||
guard
|
||||
.get(channel)
|
||||
.and_then(|conv| conv.session_id.clone())
|
||||
};
|
||||
|
||||
let bot_name = &ctx.bot_name;
|
||||
let prompt = format!(
|
||||
"[Your name is {bot_name}. Refer to yourself as {bot_name}, not Claude.]\n\n{user}: {user_message}"
|
||||
);
|
||||
|
||||
let provider = ClaudeCodeProvider::new();
|
||||
let (_cancel_tx, mut cancel_rx) = watch::channel(false);
|
||||
|
||||
// Channel for sending complete chunks to the Discord posting task.
|
||||
let (msg_tx, mut msg_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
|
||||
let msg_tx_for_callback = msg_tx.clone();
|
||||
|
||||
// Spawn a task to post messages as they arrive.
|
||||
let post_transport = Arc::clone(&ctx.transport) as Arc<dyn ChatTransport>;
|
||||
let post_channel = channel.to_string();
|
||||
let post_task = tokio::spawn(async move {
|
||||
while let Some(chunk) = msg_rx.recv().await {
|
||||
let formatted = markdown_to_discord(&chunk);
|
||||
let _ = post_transport
|
||||
.send_message(&post_channel, &formatted, "")
|
||||
.await;
|
||||
}
|
||||
});
|
||||
|
||||
// Shared buffer between the sync token callback and the async scope.
|
||||
let buffer = Arc::new(std::sync::Mutex::new(String::new()));
|
||||
let buffer_for_callback = Arc::clone(&buffer);
|
||||
let sent_any_chunk = Arc::new(AtomicBool::new(false));
|
||||
let sent_any_chunk_for_callback = Arc::clone(&sent_any_chunk);
|
||||
|
||||
let project_root_str = ctx.project_root.to_string_lossy().to_string();
|
||||
let chat_fut = provider.chat_stream(
|
||||
&prompt,
|
||||
&project_root_str,
|
||||
resume_session_id.as_deref(),
|
||||
None,
|
||||
&mut cancel_rx,
|
||||
move |token| {
|
||||
let mut buf = buffer_for_callback.lock().unwrap();
|
||||
buf.push_str(token);
|
||||
let paragraphs = drain_complete_paragraphs(&mut buf);
|
||||
for chunk in paragraphs {
|
||||
sent_any_chunk_for_callback.store(true, Ordering::Relaxed);
|
||||
let _ = msg_tx_for_callback.send(chunk);
|
||||
}
|
||||
},
|
||||
|_thinking| {},
|
||||
|_activity| {},
|
||||
);
|
||||
tokio::pin!(chat_fut);
|
||||
|
||||
// Lock the permission receiver for the duration of this chat session.
|
||||
let mut perm_rx_guard = ctx.perm_rx.lock().await;
|
||||
|
||||
let result = loop {
|
||||
tokio::select! {
|
||||
r = &mut chat_fut => break r,
|
||||
|
||||
Some(perm_fwd) = perm_rx_guard.recv() => {
|
||||
let prompt_msg = format!(
|
||||
"**Permission Request**\n\nTool: `{}`\n```json\n{}\n```\n\nReply **yes** to approve or **no** to deny.",
|
||||
perm_fwd.tool_name,
|
||||
serde_json::to_string_pretty(&perm_fwd.tool_input)
|
||||
.unwrap_or_else(|_| perm_fwd.tool_input.to_string()),
|
||||
);
|
||||
let formatted = markdown_to_discord(&prompt_msg);
|
||||
let _ = ctx.transport.send_message(channel, &formatted, "").await;
|
||||
|
||||
ctx.pending_perm_replies
|
||||
.lock()
|
||||
.await
|
||||
.insert(channel.to_string(), perm_fwd.response_tx);
|
||||
|
||||
// Spawn a timeout task: auto-deny if the user does not respond.
|
||||
let pending = Arc::clone(&ctx.pending_perm_replies);
|
||||
let timeout_channel = channel.to_string();
|
||||
let timeout_transport = Arc::clone(&ctx.transport) as Arc<dyn ChatTransport>;
|
||||
let timeout_secs = ctx.permission_timeout_secs;
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(timeout_secs)).await;
|
||||
if let Some(tx) = pending.lock().await.remove(&timeout_channel) {
|
||||
let _ = tx.send(PermissionDecision::Deny);
|
||||
let msg = "Permission request timed out — denied (fail-closed).";
|
||||
let _ = timeout_transport.send_message(&timeout_channel, msg, "").await;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
drop(perm_rx_guard);
|
||||
|
||||
// Flush remaining text.
|
||||
let remaining = buffer.lock().unwrap().trim().to_string();
|
||||
let did_send_any = sent_any_chunk.load(Ordering::Relaxed);
|
||||
|
||||
let (assistant_reply, new_session_id) = match result {
|
||||
Ok(ClaudeCodeResult {
|
||||
messages,
|
||||
session_id,
|
||||
}) => {
|
||||
let reply = if !remaining.is_empty() {
|
||||
let _ = msg_tx.send(remaining.clone());
|
||||
remaining
|
||||
} else if !did_send_any {
|
||||
let last_text = messages
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|m| {
|
||||
m.role == crate::llm::types::Role::Assistant && !m.content.is_empty()
|
||||
})
|
||||
.map(|m| m.content.clone())
|
||||
.unwrap_or_default();
|
||||
if !last_text.is_empty() {
|
||||
let _ = msg_tx.send(last_text.clone());
|
||||
}
|
||||
last_text
|
||||
} else {
|
||||
remaining
|
||||
};
|
||||
slog!("[discord] session_id from chat_stream: {:?}", session_id);
|
||||
(reply, session_id)
|
||||
}
|
||||
Err(e) => {
|
||||
slog!("[discord] LLM error: {e}");
|
||||
let err_msg = format!("Error processing your request: {e}");
|
||||
let _ = msg_tx.send(err_msg.clone());
|
||||
(err_msg, None)
|
||||
}
|
||||
};
|
||||
|
||||
// Signal the posting task to finish and wait for it.
|
||||
drop(msg_tx);
|
||||
let _ = post_task.await;
|
||||
|
||||
// Record this exchange in conversation history.
|
||||
if !assistant_reply.starts_with("Error processing") {
|
||||
let mut guard = ctx.history.lock().await;
|
||||
let conv = guard.entry(channel.to_string()).or_default();
|
||||
|
||||
if new_session_id.is_some() {
|
||||
conv.session_id = new_session_id;
|
||||
}
|
||||
|
||||
conv.entries.push(ConversationEntry {
|
||||
role: ConversationRole::User,
|
||||
sender: user.to_string(),
|
||||
content: user_message.to_string(),
|
||||
});
|
||||
conv.entries.push(ConversationEntry {
|
||||
role: ConversationRole::Assistant,
|
||||
sender: String::new(),
|
||||
content: assistant_reply,
|
||||
});
|
||||
|
||||
// Trim to configured maximum.
|
||||
if conv.entries.len() > ctx.history_size {
|
||||
let excess = conv.entries.len() - ctx.history_size;
|
||||
conv.entries.drain(..excess);
|
||||
}
|
||||
|
||||
save_discord_history(&ctx.project_root, &guard);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ───────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_agents() -> Arc<crate::agents::AgentPool> {
|
||||
Arc::new(crate::agents::AgentPool::new_test(3000))
|
||||
}
|
||||
|
||||
fn test_ambient_rooms() -> Arc<Mutex<HashSet<String>>> {
|
||||
Arc::new(Mutex::new(HashSet::new()))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_dispatches_through_command_registry() {
|
||||
use crate::chat::commands::{CommandDispatch, try_handle_command};
|
||||
|
||||
let agents = test_agents();
|
||||
let ambient_rooms = test_ambient_rooms();
|
||||
let room_id = "123456789".to_string();
|
||||
|
||||
let bot_name = "Huskies";
|
||||
let synthetic = format!("{bot_name} status");
|
||||
|
||||
let dispatch = CommandDispatch {
|
||||
bot_name,
|
||||
bot_user_id: "discord-bot",
|
||||
project_root: std::path::Path::new("/tmp"),
|
||||
agents: &agents,
|
||||
ambient_rooms: &ambient_rooms,
|
||||
room_id: &room_id,
|
||||
};
|
||||
|
||||
let result = try_handle_command(&dispatch, &synthetic);
|
||||
assert!(
|
||||
result.is_some(),
|
||||
"status command should produce output via registry"
|
||||
);
|
||||
assert!(result.unwrap().contains("Pipeline Status"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rebuild_command_extracted_from_discord_message() {
|
||||
let result = crate::chat::transport::matrix::rebuild::extract_rebuild_command(
|
||||
"Huskies rebuild",
|
||||
"Huskies",
|
||||
"discord-bot",
|
||||
);
|
||||
assert!(result.is_some(), "'Huskies rebuild' should be recognised");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reset_command_extracted_from_discord_message() {
|
||||
let result = crate::chat::transport::matrix::reset::extract_reset_command(
|
||||
"Huskies reset",
|
||||
"Huskies",
|
||||
"discord-bot",
|
||||
);
|
||||
assert!(result.is_some(), "'Huskies reset' should be recognised");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn start_command_extracted_from_discord_message() {
|
||||
let result = crate::chat::transport::matrix::start::extract_start_command(
|
||||
"start 42",
|
||||
"Huskies",
|
||||
"discord-bot",
|
||||
);
|
||||
assert!(result.is_some(), "plain 'start 42' should be recognised");
|
||||
assert_eq!(
|
||||
result,
|
||||
Some(crate::chat::transport::matrix::start::StartCommand::Start {
|
||||
story_number: "42".to_string(),
|
||||
agent_hint: None,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn assign_command_extracted_from_discord_message() {
|
||||
let result = crate::chat::transport::matrix::assign::extract_assign_command(
|
||||
"assign 42 opus",
|
||||
"Huskies",
|
||||
"discord-bot",
|
||||
);
|
||||
assert!(
|
||||
matches!(
|
||||
result,
|
||||
Some(crate::chat::transport::matrix::assign::AssignCommand::Assign { .. })
|
||||
),
|
||||
"plain 'assign 42 opus' should be recognised on Discord"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reset_command_clears_discord_session() {
|
||||
use crate::chat::transport::matrix::RoomConversation;
|
||||
|
||||
let channel = "123456789";
|
||||
let history: DiscordConversationHistory = Arc::new(TokioMutex::new({
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
channel.to_string(),
|
||||
RoomConversation {
|
||||
session_id: Some("old-session".to_string()),
|
||||
entries: vec![ConversationEntry {
|
||||
role: ConversationRole::User,
|
||||
sender: "user123".to_string(),
|
||||
content: "previous message".to_string(),
|
||||
}],
|
||||
},
|
||||
);
|
||||
m
|
||||
}));
|
||||
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let sk = tmp.path().join(".huskies");
|
||||
std::fs::create_dir_all(&sk).unwrap();
|
||||
|
||||
{
|
||||
let mut guard = history.lock().await;
|
||||
let conv = guard
|
||||
.entry(channel.to_string())
|
||||
.or_insert_with(RoomConversation::default);
|
||||
conv.session_id = None;
|
||||
conv.entries.clear();
|
||||
save_discord_history(tmp.path(), &guard);
|
||||
}
|
||||
|
||||
let guard = history.lock().await;
|
||||
let conv = guard.get(channel).unwrap();
|
||||
assert!(conv.session_id.is_none(), "session_id should be cleared");
|
||||
assert!(conv.entries.is_empty(), "entries should be cleared");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
//! Markdown to Discord format conversion.
|
||||
//!
|
||||
//! Discord supports standard Markdown natively, so unlike Slack we only need
|
||||
//! minimal adjustments (e.g. truncating messages that exceed Discord's limit).
|
||||
|
||||
/// Convert Markdown text to Discord-compatible format.
|
||||
///
|
||||
/// Discord supports most standard Markdown:
|
||||
/// - `**bold**`, `*italic*`, `~~strikethrough~~`
|
||||
/// - `` `code` `` and ````code blocks````
|
||||
/// - `> blockquotes`
|
||||
/// - `[text](url)` links
|
||||
///
|
||||
/// The main adjustment is ensuring messages stay within Discord's 2000-char
|
||||
/// limit. Actual truncation is handled by the transport layer; this function
|
||||
/// performs any lightweight text normalization needed.
|
||||
pub fn markdown_to_discord(text: &str) -> String {
|
||||
use crate::chat::util::normalize_line_breaks;
|
||||
normalize_line_breaks(text)
|
||||
}
|
||||
|
||||
// ── Tests ───────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn discord_plain_text_unchanged() {
|
||||
let plain = "Hello, this is a plain message.";
|
||||
assert_eq!(markdown_to_discord(plain), plain);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discord_bold_preserved() {
|
||||
assert_eq!(markdown_to_discord("**bold text**"), "**bold text**");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discord_code_block_preserved() {
|
||||
let input = "```rust\nlet x = 1;\n```";
|
||||
assert_eq!(markdown_to_discord(input), input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discord_empty_string_unchanged() {
|
||||
assert_eq!(markdown_to_discord(""), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discord_link_preserved() {
|
||||
let input = "[click here](https://example.com)";
|
||||
assert_eq!(markdown_to_discord(input), input);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,507 @@
|
||||
//! Minimal Discord Gateway WebSocket client.
|
||||
//!
|
||||
//! Connects to the Discord Gateway, authenticates with a bot token, maintains
|
||||
//! the heartbeat keepalive, and dispatches incoming `MESSAGE_CREATE` events to
|
||||
//! the command handler.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio_tungstenite::tungstenite::Message as WsMessage;
|
||||
|
||||
use crate::slog;
|
||||
|
||||
use super::commands::{self, DiscordContext};
|
||||
|
||||
// ── Gateway opcodes ─────────────────────────────────────────────────────
|
||||
|
||||
const OP_DISPATCH: u8 = 0;
|
||||
const OP_HEARTBEAT: u8 = 1;
|
||||
const OP_IDENTIFY: u8 = 2;
|
||||
const OP_RECONNECT: u8 = 7;
|
||||
const OP_INVALID_SESSION: u8 = 9;
|
||||
const OP_HELLO: u8 = 10;
|
||||
const OP_HEARTBEAT_ACK: u8 = 11;
|
||||
|
||||
// ── Gateway intents ─────────────────────────────────────────────────────
|
||||
|
||||
/// GUILDS | GUILD_MESSAGES | MESSAGE_CONTENT (privileged)
|
||||
const GATEWAY_INTENTS: u64 = (1 << 0) | (1 << 9) | (1 << 15);
|
||||
|
||||
// ── Gateway payload types ───────────────────────────────────────────────
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct GatewayPayload {
|
||||
op: u8,
|
||||
#[serde(default)]
|
||||
d: Option<serde_json::Value>,
|
||||
#[serde(default)]
|
||||
s: Option<u64>,
|
||||
#[serde(default)]
|
||||
t: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct HelloData {
|
||||
heartbeat_interval: u64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub(super) struct MessageCreateData {
|
||||
#[serde(default)]
|
||||
pub channel_id: String,
|
||||
#[serde(default)]
|
||||
pub content: String,
|
||||
#[serde(default)]
|
||||
pub author: Option<MessageAuthor>,
|
||||
#[serde(default)]
|
||||
pub mentions: Vec<MentionUser>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub(super) struct MessageAuthor {
|
||||
pub id: String,
|
||||
#[serde(default)]
|
||||
pub bot: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub(super) struct MentionUser {
|
||||
pub id: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct ReadyData {
|
||||
user: ReadyUser,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct ReadyUser {
|
||||
id: String,
|
||||
}
|
||||
|
||||
// ── Identify payload ────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct IdentifyPayload {
|
||||
op: u8,
|
||||
d: IdentifyData,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct IdentifyData {
|
||||
token: String,
|
||||
intents: u64,
|
||||
properties: IdentifyProperties,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct IdentifyProperties {
|
||||
os: &'static str,
|
||||
browser: &'static str,
|
||||
device: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct HeartbeatPayload {
|
||||
op: u8,
|
||||
d: Option<u64>,
|
||||
}
|
||||
|
||||
// ── Gateway URL ─────────────────────────────────────────────────────────
|
||||
|
||||
const GATEWAY_URL: &str = "wss://gateway.discord.gg/?v=10&encoding=json";
|
||||
|
||||
// ── Public entry point ──────────────────────────────────────────────────
|
||||
|
||||
/// Spawn a background task that connects to the Discord Gateway, maintains
|
||||
/// the heartbeat, and dispatches incoming messages.
|
||||
pub fn spawn_gateway(ctx: Arc<DiscordContext>) {
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
if let Err(e) = run_gateway(Arc::clone(&ctx)).await {
|
||||
slog!("[discord] Gateway connection ended: {e}");
|
||||
}
|
||||
// Wait before reconnecting.
|
||||
slog!("[discord] Reconnecting in 5 seconds…");
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Connect to the Discord Gateway and process events until the connection
|
||||
/// closes or an unrecoverable error occurs.
|
||||
async fn run_gateway(ctx: Arc<DiscordContext>) -> Result<(), String> {
|
||||
slog!("[discord] Connecting to Discord Gateway…");
|
||||
|
||||
let (ws_stream, _) = tokio_tungstenite::connect_async(GATEWAY_URL)
|
||||
.await
|
||||
.map_err(|e| format!("Gateway connect failed: {e}"))?;
|
||||
|
||||
slog!("[discord] Gateway WebSocket connected");
|
||||
|
||||
let (mut write, mut read) = ws_stream.split();
|
||||
|
||||
// Wait for Hello (op 10) to get the heartbeat interval.
|
||||
let hello = read
|
||||
.next()
|
||||
.await
|
||||
.ok_or("Gateway closed before Hello")?
|
||||
.map_err(|e| format!("Gateway read error: {e}"))?;
|
||||
|
||||
let hello_payload: GatewayPayload =
|
||||
parse_ws_message(&hello).ok_or("Failed to parse Hello")?;
|
||||
|
||||
if hello_payload.op != OP_HELLO {
|
||||
return Err(format!(
|
||||
"Expected Hello (op 10), got op {}",
|
||||
hello_payload.op
|
||||
));
|
||||
}
|
||||
|
||||
let hello_data: HelloData =
|
||||
serde_json::from_value(hello_payload.d.ok_or("Hello missing data")?)
|
||||
.map_err(|e| format!("Failed to parse Hello data: {e}"))?;
|
||||
|
||||
let heartbeat_interval =
|
||||
std::time::Duration::from_millis(hello_data.heartbeat_interval);
|
||||
slog!(
|
||||
"[discord] Heartbeat interval: {}ms",
|
||||
hello_data.heartbeat_interval
|
||||
);
|
||||
|
||||
// Send Identify.
|
||||
let identify = IdentifyPayload {
|
||||
op: OP_IDENTIFY,
|
||||
d: IdentifyData {
|
||||
token: ctx.bot_token.clone(),
|
||||
intents: GATEWAY_INTENTS,
|
||||
properties: IdentifyProperties {
|
||||
os: "linux",
|
||||
browser: "huskies",
|
||||
device: "huskies",
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
let identify_json = serde_json::to_string(&identify)
|
||||
.map_err(|e| format!("Failed to serialise Identify: {e}"))?;
|
||||
write
|
||||
.send(WsMessage::Text(identify_json.into()))
|
||||
.await
|
||||
.map_err(|e| format!("Failed to send Identify: {e}"))?;
|
||||
|
||||
slog!("[discord] Identify sent, waiting for Ready…");
|
||||
|
||||
// Track sequence number for heartbeats.
|
||||
let sequence = Arc::new(std::sync::atomic::AtomicU64::new(0));
|
||||
let seq_for_heartbeat = Arc::clone(&sequence);
|
||||
|
||||
// Share the write half between the heartbeat task and the main event loop.
|
||||
let write = Arc::new(tokio::sync::Mutex::new(write));
|
||||
let write_for_heartbeat = Arc::clone(&write);
|
||||
|
||||
// Spawn heartbeat task.
|
||||
let heartbeat_handle = tokio::spawn(async move {
|
||||
// Jitter: wait a random fraction of the interval before the first beat.
|
||||
let jitter = heartbeat_interval.mul_f64(rand_fraction());
|
||||
tokio::time::sleep(jitter).await;
|
||||
|
||||
loop {
|
||||
let seq = seq_for_heartbeat.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let seq_val = if seq == 0 { None } else { Some(seq) };
|
||||
let hb = HeartbeatPayload {
|
||||
op: OP_HEARTBEAT,
|
||||
d: seq_val,
|
||||
};
|
||||
if let Ok(json) = serde_json::to_string(&hb) {
|
||||
let mut w = write_for_heartbeat.lock().await;
|
||||
if w.send(WsMessage::Text(json.into())).await.is_err() {
|
||||
slog!("[discord] Heartbeat send failed, stopping heartbeat");
|
||||
break;
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(heartbeat_interval).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Track our own bot user ID (set from READY event).
|
||||
let mut bot_user_id: Option<String> = None;
|
||||
|
||||
// Main event loop.
|
||||
while let Some(msg_result) = read.next().await {
|
||||
let msg = match msg_result {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
slog!("[discord] Gateway read error: {e}");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let payload: GatewayPayload = match parse_ws_message(&msg) {
|
||||
Some(p) => p,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
// Update sequence number.
|
||||
if let Some(s) = payload.s {
|
||||
sequence.store(s, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
match payload.op {
|
||||
OP_DISPATCH => {
|
||||
let event_name = payload.t.as_deref().unwrap_or("");
|
||||
match event_name {
|
||||
"READY" => {
|
||||
if let Some(d) = payload.d
|
||||
&& let Ok(ready) = serde_json::from_value::<ReadyData>(d)
|
||||
{
|
||||
bot_user_id = Some(ready.user.id.clone());
|
||||
slog!(
|
||||
"[discord] READY — bot user ID: {}",
|
||||
ready.user.id
|
||||
);
|
||||
}
|
||||
}
|
||||
"MESSAGE_CREATE" => {
|
||||
if let Some(d) = payload.d {
|
||||
dispatch_message(
|
||||
Arc::clone(&ctx),
|
||||
d,
|
||||
bot_user_id.clone(),
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
OP_HEARTBEAT => {
|
||||
// Server requested an immediate heartbeat.
|
||||
let seq = sequence.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let seq_val = if seq == 0 { None } else { Some(seq) };
|
||||
let hb = HeartbeatPayload {
|
||||
op: OP_HEARTBEAT,
|
||||
d: seq_val,
|
||||
};
|
||||
if let Ok(json) = serde_json::to_string(&hb) {
|
||||
let mut w = write.lock().await;
|
||||
let _ = w.send(WsMessage::Text(json.into())).await;
|
||||
}
|
||||
}
|
||||
OP_HEARTBEAT_ACK => {}
|
||||
OP_RECONNECT => {
|
||||
slog!("[discord] Gateway requested reconnect");
|
||||
break;
|
||||
}
|
||||
OP_INVALID_SESSION => {
|
||||
slog!("[discord] Invalid session — will reconnect");
|
||||
// Wait a bit before reconnecting for invalid sessions.
|
||||
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
heartbeat_handle.abort();
|
||||
slog!("[discord] Gateway event loop ended");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
fn parse_ws_message(msg: &WsMessage) -> Option<GatewayPayload> {
|
||||
match msg {
|
||||
WsMessage::Text(text) => serde_json::from_str(text).ok(),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Filter and dispatch an incoming MESSAGE_CREATE event on a background task.
|
||||
fn dispatch_message(
|
||||
ctx: Arc<DiscordContext>,
|
||||
data: serde_json::Value,
|
||||
bot_user_id: Option<String>,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
let msg: MessageCreateData = match serde_json::from_value(data) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
slog!("[discord] Failed to parse MESSAGE_CREATE: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let author = match msg.author.as_ref() {
|
||||
Some(a) => a,
|
||||
None => return,
|
||||
};
|
||||
|
||||
// Ignore bot messages (including our own).
|
||||
if author.bot.unwrap_or(false) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if the channel is in our configured channel list.
|
||||
if !ctx.channel_ids.contains(&msg.channel_id) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check allowed_users (if non-empty, only listed users can interact).
|
||||
if !ctx.allowed_users.is_empty() && !ctx.allowed_users.contains(&author.id) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if the bot was mentioned, or if we respond to all messages in
|
||||
// configured channels (ambient mode).
|
||||
let bot_mentioned = bot_user_id.as_ref().is_some_and(|bid| {
|
||||
msg.mentions.iter().any(|m| m.id == *bid)
|
||||
});
|
||||
|
||||
let in_ambient = ctx
|
||||
.ambient_rooms
|
||||
.lock()
|
||||
.unwrap()
|
||||
.contains(&msg.channel_id);
|
||||
|
||||
if !bot_mentioned && !in_ambient {
|
||||
return;
|
||||
}
|
||||
|
||||
// Strip the bot mention from the message content if present.
|
||||
let content = if let Some(bid) = bot_user_id.as_ref() {
|
||||
let mention_pat = format!("<@{bid}>");
|
||||
let mention_pat_nick = format!("<@!{bid}>");
|
||||
msg.content
|
||||
.replace(&mention_pat, "")
|
||||
.replace(&mention_pat_nick, "")
|
||||
.trim()
|
||||
.to_string()
|
||||
} else {
|
||||
msg.content.clone()
|
||||
};
|
||||
|
||||
if content.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
slog!(
|
||||
"[discord] Message from {} in {}: {content:.80}",
|
||||
author.id,
|
||||
msg.channel_id
|
||||
);
|
||||
|
||||
commands::handle_incoming_message(&ctx, &msg.channel_id, &author.id, &content)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
/// Cheap pseudo-random fraction [0.0, 1.0) for heartbeat jitter.
|
||||
fn rand_fraction() -> f64 {
|
||||
let nanos = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.subsec_nanos();
|
||||
(nanos as f64) / 1_000_000_000.0
|
||||
}
|
||||
|
||||
// ── Tests ───────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_hello_payload() {
|
||||
let json = r#"{"op": 10, "d": {"heartbeat_interval": 41250}}"#;
|
||||
let payload: GatewayPayload = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(payload.op, OP_HELLO);
|
||||
let hello: HelloData =
|
||||
serde_json::from_value(payload.d.unwrap()).unwrap();
|
||||
assert_eq!(hello.heartbeat_interval, 41250);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_message_create() {
|
||||
let json = r#"{
|
||||
"channel_id": "123456",
|
||||
"content": "hello bot",
|
||||
"author": {"id": "789", "bot": false},
|
||||
"mentions": [{"id": "111"}]
|
||||
}"#;
|
||||
let msg: MessageCreateData = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(msg.channel_id, "123456");
|
||||
assert_eq!(msg.content, "hello bot");
|
||||
assert_eq!(msg.author.as_ref().unwrap().id, "789");
|
||||
assert!(!msg.author.as_ref().unwrap().bot.unwrap_or(false));
|
||||
assert_eq!(msg.mentions.len(), 1);
|
||||
assert_eq!(msg.mentions[0].id, "111");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_ready_event() {
|
||||
let json = r#"{"user": {"id": "999888"}}"#;
|
||||
let ready: ReadyData = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(ready.user.id, "999888");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identify_payload_serializes_correctly() {
|
||||
let identify = IdentifyPayload {
|
||||
op: OP_IDENTIFY,
|
||||
d: IdentifyData {
|
||||
token: "test-token".to_string(),
|
||||
intents: GATEWAY_INTENTS,
|
||||
properties: IdentifyProperties {
|
||||
os: "linux",
|
||||
browser: "huskies",
|
||||
device: "huskies",
|
||||
},
|
||||
},
|
||||
};
|
||||
let json: serde_json::Value =
|
||||
serde_json::from_str(&serde_json::to_string(&identify).unwrap()).unwrap();
|
||||
assert_eq!(json["op"], 2);
|
||||
assert_eq!(json["d"]["token"], "test-token");
|
||||
assert_eq!(json["d"]["intents"], GATEWAY_INTENTS);
|
||||
assert_eq!(json["d"]["properties"]["browser"], "huskies");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_payload_serializes_with_sequence() {
|
||||
let hb = HeartbeatPayload {
|
||||
op: OP_HEARTBEAT,
|
||||
d: Some(42),
|
||||
};
|
||||
let json: serde_json::Value =
|
||||
serde_json::from_str(&serde_json::to_string(&hb).unwrap()).unwrap();
|
||||
assert_eq!(json["op"], 1);
|
||||
assert_eq!(json["d"], 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_payload_serializes_null_sequence() {
|
||||
let hb = HeartbeatPayload {
|
||||
op: OP_HEARTBEAT,
|
||||
d: None,
|
||||
};
|
||||
let json: serde_json::Value =
|
||||
serde_json::from_str(&serde_json::to_string(&hb).unwrap()).unwrap();
|
||||
assert_eq!(json["op"], 1);
|
||||
assert!(json["d"].is_null());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rand_fraction_in_range() {
|
||||
let f = rand_fraction();
|
||||
assert!((0.0..1.0).contains(&f));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gateway_intents_correct() {
|
||||
// GUILDS (1) | GUILD_MESSAGES (512) | MESSAGE_CONTENT (32768)
|
||||
assert_eq!(GATEWAY_INTENTS, 1 | 512 | 32768);
|
||||
assert_eq!(GATEWAY_INTENTS, 33281);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
//! Discord conversation history persistence.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex as TokioMutex;
|
||||
|
||||
use crate::chat::transport::matrix::RoomConversation;
|
||||
use crate::slog;
|
||||
|
||||
/// Per-channel conversation history, keyed by channel ID.
|
||||
pub type DiscordConversationHistory = Arc<TokioMutex<HashMap<String, RoomConversation>>>;
|
||||
|
||||
/// On-disk format for persisted Discord conversation history.
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct PersistedDiscordHistory {
|
||||
channels: HashMap<String, RoomConversation>,
|
||||
}
|
||||
|
||||
/// Path to the persisted Discord conversation history file.
|
||||
const DISCORD_HISTORY_FILE: &str = ".huskies/discord_history.json";
|
||||
|
||||
/// Load Discord conversation history from disk.
|
||||
pub fn load_discord_history(project_root: &std::path::Path) -> HashMap<String, RoomConversation> {
|
||||
let path = project_root.join(DISCORD_HISTORY_FILE);
|
||||
let data = match std::fs::read_to_string(&path) {
|
||||
Ok(d) => d,
|
||||
Err(_) => return HashMap::new(),
|
||||
};
|
||||
let persisted: PersistedDiscordHistory = match serde_json::from_str(&data) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
slog!("[discord] Failed to parse history file: {e}");
|
||||
return HashMap::new();
|
||||
}
|
||||
};
|
||||
persisted.channels
|
||||
}
|
||||
|
||||
/// Save Discord conversation history to disk.
|
||||
pub(super) fn save_discord_history(
|
||||
project_root: &std::path::Path,
|
||||
history: &HashMap<String, RoomConversation>,
|
||||
) {
|
||||
let persisted = PersistedDiscordHistory {
|
||||
channels: history.clone(),
|
||||
};
|
||||
let path = project_root.join(DISCORD_HISTORY_FILE);
|
||||
match serde_json::to_string_pretty(&persisted) {
|
||||
Ok(json) => {
|
||||
if let Err(e) = std::fs::write(&path, json) {
|
||||
slog!("[discord] Failed to write history file: {e}");
|
||||
}
|
||||
}
|
||||
Err(e) => slog!("[discord] Failed to serialise history: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ───────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::chat::transport::matrix::{ConversationEntry, ConversationRole};
|
||||
|
||||
#[test]
|
||||
fn save_and_load_discord_history_round_trips() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let sk = tmp.path().join(".huskies");
|
||||
std::fs::create_dir_all(&sk).unwrap();
|
||||
|
||||
let mut history = HashMap::new();
|
||||
history.insert(
|
||||
"123456789".to_string(),
|
||||
RoomConversation {
|
||||
session_id: Some("sess-abc".to_string()),
|
||||
entries: vec![
|
||||
ConversationEntry {
|
||||
role: ConversationRole::User,
|
||||
sender: "user123".to_string(),
|
||||
content: "hello".to_string(),
|
||||
},
|
||||
ConversationEntry {
|
||||
role: ConversationRole::Assistant,
|
||||
sender: String::new(),
|
||||
content: "hi there!".to_string(),
|
||||
},
|
||||
],
|
||||
},
|
||||
);
|
||||
|
||||
save_discord_history(tmp.path(), &history);
|
||||
let loaded = load_discord_history(tmp.path());
|
||||
|
||||
assert_eq!(loaded.len(), 1);
|
||||
let conv = loaded.get("123456789").unwrap();
|
||||
assert_eq!(conv.session_id.as_deref(), Some("sess-abc"));
|
||||
assert_eq!(conv.entries.len(), 2);
|
||||
assert_eq!(conv.entries[0].content, "hello");
|
||||
assert_eq!(conv.entries[1].content, "hi there!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_discord_history_returns_empty_when_file_missing() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let history = load_discord_history(tmp.path());
|
||||
assert!(history.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_discord_history_returns_empty_on_invalid_json() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let sk = tmp.path().join(".huskies");
|
||||
std::fs::create_dir_all(&sk).unwrap();
|
||||
std::fs::write(sk.join("discord_history.json"), "not json {{{").unwrap();
|
||||
let history = load_discord_history(tmp.path());
|
||||
assert!(history.is_empty());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
//! DiscordTransport — ChatTransport implementation for the Discord Bot API.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::chat::{ChatTransport, MessageId};
|
||||
use crate::slog;
|
||||
|
||||
// ── Discord API base URL (overridable for tests) ──────────────────────
|
||||
|
||||
const DISCORD_API_BASE: &str = "https://discord.com/api/v10";
|
||||
|
||||
// ── DiscordTransport ──────────────────────────────────────────────────
|
||||
|
||||
/// Discord Bot API transport.
|
||||
///
|
||||
/// Sends messages via `POST {API_BASE}/channels/{channel_id}/messages`,
|
||||
/// edits via `PATCH {API_BASE}/channels/{channel_id}/messages/{message_id}`,
|
||||
/// and typing indicators via `POST {API_BASE}/channels/{channel_id}/typing`.
|
||||
pub struct DiscordTransport {
|
||||
bot_token: String,
|
||||
client: reqwest::Client,
|
||||
/// Optional base URL override for tests.
|
||||
api_base: String,
|
||||
}
|
||||
|
||||
impl DiscordTransport {
|
||||
pub fn new(bot_token: String) -> Self {
|
||||
Self {
|
||||
bot_token,
|
||||
client: reqwest::Client::new(),
|
||||
api_base: DISCORD_API_BASE.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn with_api_base(bot_token: String, api_base: String) -> Self {
|
||||
Self {
|
||||
bot_token,
|
||||
client: reqwest::Client::new(),
|
||||
api_base,
|
||||
}
|
||||
}
|
||||
|
||||
/// Authorization header value: `Bot {token}`.
|
||||
fn auth_header(&self) -> String {
|
||||
format!("Bot {}", self.bot_token)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Discord API response types ────────────────────────────────────────
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct DiscordMessage {
|
||||
id: String,
|
||||
}
|
||||
|
||||
// ── ChatTransport implementation ──────────────────────────────────────
|
||||
|
||||
#[async_trait]
|
||||
impl ChatTransport for DiscordTransport {
|
||||
async fn send_message(
|
||||
&self,
|
||||
channel_id: &str,
|
||||
plain: &str,
|
||||
_html: &str,
|
||||
) -> Result<MessageId, String> {
|
||||
slog!("[discord] send_message to {channel_id}: {plain:.80}");
|
||||
let url = format!("{}/channels/{}/messages", self.api_base, channel_id);
|
||||
|
||||
// Discord messages have a 2000-char limit. Truncate if needed.
|
||||
let content = if plain.len() > 2000 {
|
||||
format!("{}…", &plain[..1999])
|
||||
} else {
|
||||
plain.to_string()
|
||||
};
|
||||
|
||||
let payload = serde_json::json!({ "content": content });
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("Authorization", self.auth_header())
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Discord API request failed: {e}"))?;
|
||||
|
||||
let status = resp.status();
|
||||
let resp_text = resp
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "<no body>".to_string());
|
||||
|
||||
if !status.is_success() {
|
||||
return Err(format!("Discord API returned {status}: {resp_text}"));
|
||||
}
|
||||
|
||||
let msg: DiscordMessage = serde_json::from_str(&resp_text).map_err(|e| {
|
||||
format!("Failed to parse Discord API response: {e} — body: {resp_text}")
|
||||
})?;
|
||||
|
||||
Ok(msg.id)
|
||||
}
|
||||
|
||||
async fn edit_message(
|
||||
&self,
|
||||
channel_id: &str,
|
||||
original_message_id: &str,
|
||||
plain: &str,
|
||||
_html: &str,
|
||||
) -> Result<(), String> {
|
||||
slog!("[discord] edit_message in {channel_id}: id={original_message_id}");
|
||||
let url = format!(
|
||||
"{}/channels/{}/messages/{}",
|
||||
self.api_base, channel_id, original_message_id
|
||||
);
|
||||
|
||||
let content = if plain.len() > 2000 {
|
||||
format!("{}…", &plain[..1999])
|
||||
} else {
|
||||
plain.to_string()
|
||||
};
|
||||
|
||||
let payload = serde_json::json!({ "content": content });
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.patch(&url)
|
||||
.header("Authorization", self.auth_header())
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Discord edit request failed: {e}"))?;
|
||||
|
||||
let status = resp.status();
|
||||
let resp_text = resp
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "<no body>".to_string());
|
||||
|
||||
if !status.is_success() {
|
||||
return Err(format!("Discord edit returned {status}: {resp_text}"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, channel_id: &str, typing: bool) -> Result<(), String> {
|
||||
if !typing {
|
||||
// Discord doesn't have a "stop typing" API — it times out after ~10s.
|
||||
return Ok(());
|
||||
}
|
||||
let url = format!("{}/channels/{}/typing", self.api_base, channel_id);
|
||||
|
||||
self.client
|
||||
.post(&url)
|
||||
.header("Authorization", self.auth_header())
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Discord typing request failed: {e}"))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ───────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_send_message_calls_discord_api() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
let mock = server
|
||||
.mock("POST", "/channels/123456/messages")
|
||||
.match_header("authorization", "Bot test-token")
|
||||
.with_body(r#"{"id": "999888777"}"#)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let transport =
|
||||
DiscordTransport::with_api_base("test-token".to_string(), server.url());
|
||||
|
||||
let result = transport
|
||||
.send_message("123456", "hello", "<p>hello</p>")
|
||||
.await;
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), "999888777");
|
||||
mock.assert_async().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_send_message_handles_api_error() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
server
|
||||
.mock("POST", "/channels/bad/messages")
|
||||
.with_status(404)
|
||||
.with_body(r#"{"message": "Unknown Channel", "code": 10003}"#)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let transport =
|
||||
DiscordTransport::with_api_base("test-token".to_string(), server.url());
|
||||
|
||||
let result = transport.send_message("bad", "hello", "").await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("404"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_edit_message_calls_patch() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
let mock = server
|
||||
.mock("PATCH", "/channels/123456/messages/999888777")
|
||||
.match_header("authorization", "Bot test-token")
|
||||
.with_body(r#"{"id": "999888777"}"#)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let transport =
|
||||
DiscordTransport::with_api_base("test-token".to_string(), server.url());
|
||||
|
||||
let result = transport
|
||||
.edit_message("123456", "999888777", "updated", "")
|
||||
.await;
|
||||
assert!(result.is_ok());
|
||||
mock.assert_async().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_edit_message_handles_error() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
server
|
||||
.mock("PATCH", "/channels/123456/messages/bad")
|
||||
.with_status(404)
|
||||
.with_body(r#"{"message": "Unknown Message", "code": 10008}"#)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let transport =
|
||||
DiscordTransport::with_api_base("test-token".to_string(), server.url());
|
||||
|
||||
let result = transport
|
||||
.edit_message("123456", "bad", "updated", "")
|
||||
.await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("404"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_send_typing_succeeds() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
server
|
||||
.mock("POST", "/channels/123456/typing")
|
||||
.with_status(204)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let transport =
|
||||
DiscordTransport::with_api_base("test-token".to_string(), server.url());
|
||||
|
||||
assert!(transport.send_typing("123456", true).await.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_send_typing_false_is_noop() {
|
||||
let transport = DiscordTransport::new("test-token".to_string());
|
||||
assert!(transport.send_typing("123456", false).await.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_handles_http_error() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
server
|
||||
.mock("POST", "/channels/123456/messages")
|
||||
.with_status(500)
|
||||
.with_body("Internal Server Error")
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let transport =
|
||||
DiscordTransport::with_api_base("test-token".to_string(), server.url());
|
||||
|
||||
let result = transport.send_message("123456", "hello", "").await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("500"));
|
||||
}
|
||||
|
||||
// ── ChatTransport trait satisfaction ─────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn discord_transport_satisfies_trait() {
|
||||
fn assert_transport<T: ChatTransport>() {}
|
||||
assert_transport::<DiscordTransport>();
|
||||
|
||||
let _: Arc<dyn ChatTransport> =
|
||||
Arc::new(DiscordTransport::new("test-token".to_string()));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
//! Discord Bot integration.
|
||||
//!
|
||||
//! Provides:
|
||||
//! - [`DiscordTransport`] — a [`ChatTransport`] that sends messages via the
|
||||
//! Discord REST API (`/channels/{id}/messages`).
|
||||
//! - [`gateway::spawn_gateway`] — connects to the Discord Gateway WebSocket,
|
||||
//! receives `MESSAGE_CREATE` events, and dispatches commands.
|
||||
//! - [`DiscordContext`] — shared context for the bot.
|
||||
|
||||
pub mod commands;
|
||||
pub mod format;
|
||||
pub mod gateway;
|
||||
pub mod history;
|
||||
pub mod meta;
|
||||
|
||||
pub use commands::DiscordContext;
|
||||
pub use history::load_discord_history;
|
||||
pub use meta::DiscordTransport;
|
||||
@@ -64,7 +64,8 @@ pub struct BotConfig {
|
||||
/// manually while the bot is running.
|
||||
#[serde(default)]
|
||||
pub ambient_rooms: Vec<String>,
|
||||
/// Chat transport to use: `"matrix"` (default) or `"whatsapp"`.
|
||||
/// Chat transport to use: `"matrix"` (default), `"whatsapp"`, `"slack"`,
|
||||
/// or `"discord"`.
|
||||
///
|
||||
/// Selects which [`ChatTransport`] implementation the bot uses for
|
||||
/// sending and editing messages. Currently only read during bot
|
||||
@@ -134,6 +135,20 @@ pub struct BotConfig {
|
||||
/// Slack channel IDs the bot should listen in.
|
||||
#[serde(default)]
|
||||
pub slack_channel_ids: Vec<String>,
|
||||
|
||||
// ── Discord Bot API fields ──────────────────────────────────────
|
||||
// These are only required when `transport = "discord"`.
|
||||
|
||||
/// Discord bot token from the Discord Developer Portal.
|
||||
#[serde(default)]
|
||||
pub discord_bot_token: Option<String>,
|
||||
/// Discord channel IDs the bot should listen in.
|
||||
#[serde(default)]
|
||||
pub discord_channel_ids: Vec<String>,
|
||||
/// Discord user IDs allowed to interact with the bot.
|
||||
/// When empty or absent, all users in configured channels are allowed.
|
||||
#[serde(default)]
|
||||
pub discord_allowed_users: Vec<String>,
|
||||
}
|
||||
|
||||
fn default_transport() -> String {
|
||||
@@ -241,6 +256,22 @@ impl BotConfig {
|
||||
);
|
||||
return None;
|
||||
}
|
||||
} else if config.transport == "discord" {
|
||||
// Validate Discord-specific fields.
|
||||
if config.discord_bot_token.as_ref().is_none_or(|s| s.is_empty()) {
|
||||
eprintln!(
|
||||
"[bot] bot.toml: transport=\"discord\" requires \
|
||||
discord_bot_token"
|
||||
);
|
||||
return None;
|
||||
}
|
||||
if config.discord_channel_ids.is_empty() {
|
||||
eprintln!(
|
||||
"[bot] bot.toml: transport=\"discord\" requires \
|
||||
at least one discord_channel_ids entry"
|
||||
);
|
||||
return None;
|
||||
}
|
||||
} else {
|
||||
// Default transport is Matrix — validate Matrix-specific fields.
|
||||
if config.homeserver.as_ref().is_none_or(|s| s.is_empty()) {
|
||||
@@ -1054,4 +1085,99 @@ whatsapp_allowed_phones = ["+15551234567", "+15559876543"]
|
||||
vec!["+15551234567", "+15559876543"]
|
||||
);
|
||||
}
|
||||
|
||||
// ── Discord config tests ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn load_discord_transport_reads_config() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let sk = tmp.path().join(".huskies");
|
||||
fs::create_dir_all(&sk).unwrap();
|
||||
fs::write(
|
||||
sk.join("bot.toml"),
|
||||
r#"
|
||||
enabled = true
|
||||
transport = "discord"
|
||||
discord_bot_token = "Bot.Token.Here"
|
||||
discord_channel_ids = ["123456789012345678"]
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
let config = BotConfig::load(tmp.path()).unwrap();
|
||||
assert_eq!(config.transport, "discord");
|
||||
assert_eq!(
|
||||
config.discord_bot_token.as_deref(),
|
||||
Some("Bot.Token.Here")
|
||||
);
|
||||
assert_eq!(
|
||||
config.discord_channel_ids,
|
||||
vec!["123456789012345678"]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_discord_returns_none_when_missing_bot_token() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let sk = tmp.path().join(".huskies");
|
||||
fs::create_dir_all(&sk).unwrap();
|
||||
fs::write(
|
||||
sk.join("bot.toml"),
|
||||
r#"
|
||||
enabled = true
|
||||
transport = "discord"
|
||||
discord_channel_ids = ["123456789012345678"]
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
assert!(BotConfig::load(tmp.path()).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_discord_returns_none_when_missing_channel_ids() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let sk = tmp.path().join(".huskies");
|
||||
fs::create_dir_all(&sk).unwrap();
|
||||
fs::write(
|
||||
sk.join("bot.toml"),
|
||||
r#"
|
||||
enabled = true
|
||||
transport = "discord"
|
||||
discord_bot_token = "Bot.Token.Here"
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
assert!(BotConfig::load(tmp.path()).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discord_allowed_users_defaults_to_empty_when_absent() {
|
||||
let config: BotConfig = toml::from_str(
|
||||
r#"
|
||||
enabled = true
|
||||
transport = "discord"
|
||||
discord_bot_token = "Bot.Token.Here"
|
||||
discord_channel_ids = ["123456789"]
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
assert!(config.discord_allowed_users.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discord_allowed_users_deserializes_list() {
|
||||
let config: BotConfig = toml::from_str(
|
||||
r#"
|
||||
enabled = true
|
||||
transport = "discord"
|
||||
discord_bot_token = "Bot.Token.Here"
|
||||
discord_channel_ids = ["123456789"]
|
||||
discord_allowed_users = ["111222333", "444555666"]
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
config.discord_allowed_users,
|
||||
vec!["111222333", "444555666"]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod discord;
|
||||
pub mod matrix;
|
||||
pub mod slack;
|
||||
pub mod whatsapp;
|
||||
|
||||
@@ -339,12 +339,14 @@ async fn main() -> Result<(), std::io::Error> {
|
||||
// before watcher_tx is moved into AppContext.
|
||||
let watcher_rx_for_whatsapp = watcher_tx.subscribe();
|
||||
let watcher_rx_for_slack = watcher_tx.subscribe();
|
||||
let watcher_rx_for_discord = watcher_tx.subscribe();
|
||||
// Wrap perm_rx in Arc<Mutex> so it can be shared with both the WebSocket
|
||||
// handler (via AppContext) and the Matrix bot.
|
||||
let perm_rx = Arc::new(tokio::sync::Mutex::new(perm_rx));
|
||||
let perm_rx_for_bot = Arc::clone(&perm_rx);
|
||||
let perm_rx_for_whatsapp = Arc::clone(&perm_rx);
|
||||
let perm_rx_for_slack = Arc::clone(&perm_rx);
|
||||
let perm_rx_for_discord = Arc::clone(&perm_rx);
|
||||
|
||||
// Capture project root, agents Arc, and reconciliation sender before ctx
|
||||
// is consumed by build_routes.
|
||||
@@ -441,9 +443,49 @@ async fn main() -> Result<(), std::io::Error> {
|
||||
})
|
||||
});
|
||||
|
||||
// Build Discord context if bot.toml configures transport = "discord".
|
||||
let discord_ctx: Option<Arc<chat::transport::discord::DiscordContext>> = startup_root
|
||||
.as_ref()
|
||||
.and_then(|root| chat::transport::matrix::BotConfig::load(root))
|
||||
.filter(|cfg| cfg.transport == "discord")
|
||||
.map(|cfg| {
|
||||
let transport = Arc::new(chat::transport::discord::DiscordTransport::new(
|
||||
cfg.discord_bot_token.clone().unwrap_or_default(),
|
||||
));
|
||||
let bot_name = cfg
|
||||
.display_name
|
||||
.clone()
|
||||
.unwrap_or_else(|| "Assistant".to_string());
|
||||
let root = startup_root.clone().unwrap();
|
||||
let history = chat::transport::discord::load_discord_history(&root);
|
||||
let channel_ids: std::collections::HashSet<String> =
|
||||
cfg.discord_channel_ids.iter().cloned().collect();
|
||||
let allowed_users: std::collections::HashSet<String> =
|
||||
cfg.discord_allowed_users.iter().cloned().collect();
|
||||
Arc::new(chat::transport::discord::DiscordContext {
|
||||
bot_token: cfg.discord_bot_token.clone().unwrap_or_default(),
|
||||
transport,
|
||||
project_root: root,
|
||||
agents: Arc::clone(&startup_agents),
|
||||
bot_name,
|
||||
bot_user_id: "discord-bot".to_string(),
|
||||
ambient_rooms: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
|
||||
history: std::sync::Arc::new(tokio::sync::Mutex::new(history)),
|
||||
history_size: cfg.history_size,
|
||||
channel_ids,
|
||||
allowed_users,
|
||||
perm_rx: perm_rx_for_discord,
|
||||
pending_perm_replies: Arc::new(tokio::sync::Mutex::new(
|
||||
std::collections::HashMap::new(),
|
||||
)),
|
||||
permission_timeout_secs: cfg.permission_timeout_secs,
|
||||
})
|
||||
});
|
||||
|
||||
// Build a best-effort shutdown notifier for webhook-based transports.
|
||||
//
|
||||
// • Slack: channels are fixed at startup (channel_ids from bot.toml).
|
||||
// • Discord: channels are fixed at startup (channel_ids from bot.toml).
|
||||
// • WhatsApp: active senders are tracked at runtime in ambient_rooms.
|
||||
// We keep the WhatsApp context Arc so we can read the rooms at shutdown.
|
||||
// • Matrix: the bot task manages its own announcement via matrix_shutdown_tx.
|
||||
@@ -454,6 +496,13 @@ async fn main() -> Result<(), std::io::Error> {
|
||||
channels,
|
||||
ctx.bot_name.clone(),
|
||||
)))
|
||||
} else if let Some(ref ctx) = discord_ctx {
|
||||
let channels: Vec<String> = ctx.channel_ids.iter().cloned().collect();
|
||||
Some(Arc::new(BotShutdownNotifier::new(
|
||||
Arc::clone(&ctx.transport) as Arc<dyn crate::chat::ChatTransport>,
|
||||
channels,
|
||||
ctx.bot_name.clone(),
|
||||
)))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -496,6 +545,18 @@ async fn main() -> Result<(), std::io::Error> {
|
||||
notifier.notify_startup().await;
|
||||
});
|
||||
}
|
||||
if let Some(ref ctx) = discord_ctx {
|
||||
let transport = Arc::clone(&ctx.transport) as Arc<dyn crate::chat::ChatTransport>;
|
||||
let bot_name = ctx.bot_name.clone();
|
||||
let channels: Vec<String> = ctx.channel_ids.iter().cloned().collect();
|
||||
tokio::spawn(async move {
|
||||
if channels.is_empty() {
|
||||
return;
|
||||
}
|
||||
let notifier = crate::rebuild::BotShutdownNotifier::new(transport, channels, bot_name);
|
||||
notifier.notify_startup().await;
|
||||
});
|
||||
}
|
||||
|
||||
// Watch channel: signals the Matrix bot task to send a shutdown announcement.
|
||||
// `None` initial value means "server is running".
|
||||
@@ -559,6 +620,21 @@ async fn main() -> Result<(), std::io::Error> {
|
||||
} else {
|
||||
drop(watcher_rx_for_slack);
|
||||
}
|
||||
if let (Some(ctx), Some(root)) = (&discord_ctx, &startup_root) {
|
||||
// Spawn the Discord Gateway WebSocket listener.
|
||||
chat::transport::discord::gateway::spawn_gateway(Arc::clone(ctx));
|
||||
|
||||
// Spawn stage-transition notification listener for Discord.
|
||||
let channel_ids: Vec<String> = ctx.channel_ids.iter().cloned().collect();
|
||||
chat::transport::matrix::notifications::spawn_notification_listener(
|
||||
Arc::clone(&ctx.transport) as Arc<dyn crate::chat::ChatTransport>,
|
||||
move || channel_ids.clone(),
|
||||
watcher_rx_for_discord,
|
||||
root.clone(),
|
||||
);
|
||||
} else {
|
||||
drop(watcher_rx_for_discord);
|
||||
}
|
||||
|
||||
// On startup:
|
||||
// 1. Reconcile any stories whose agent work was committed while the server was
|
||||
|
||||
Reference in New Issue
Block a user