diff --git a/.huskies/README.md b/.huskies/README.md index 125faa0d..b8cd255a 100644 --- a/.huskies/README.md +++ b/.huskies/README.md @@ -246,6 +246,7 @@ Story Kit includes a chat bot that can be connected to one messaging platform at | WhatsApp (Meta Cloud API) | `bot.toml.whatsapp-meta.example` | `/webhook/whatsapp` | | WhatsApp (Twilio) | `bot.toml.whatsapp-twilio.example` | `/webhook/whatsapp` | | Slack | `bot.toml.slack.example` | `/webhook/slack` | +| Discord | `bot.toml.discord.example` | *(uses Discord Gateway WebSocket)* | ```bash cp .huskies/bot.toml.matrix.example .huskies/bot.toml diff --git a/.huskies/bot.toml.discord.example b/.huskies/bot.toml.discord.example new file mode 100644 index 00000000..bdc5d04f --- /dev/null +++ b/.huskies/bot.toml.discord.example @@ -0,0 +1,28 @@ +# Discord Transport +# Copy this file to bot.toml and fill in your values. +# Only one transport can be active at a time. +# +# Setup: +# 1. Create a Discord Application at discord.com/developers/applications +# 2. Go to Bot → create a bot and copy the token +# 3. Enable "Message Content Intent" under Privileged Gateway Intents +# 4. Go to OAuth2 → URL Generator, select "bot" scope with permissions: +# Send Messages, Read Message History, Manage Messages +# 5. Use the generated URL to invite the bot to your server +# 6. Right-click the channel(s) → Copy Channel ID (enable Developer Mode in settings) + +enabled = true +transport = "discord" + +discord_bot_token = "your-bot-token-here" +discord_channel_ids = ["123456789012345678"] + +# Discord user IDs allowed to interact with the bot. +# When empty, all users in configured channels can interact. +# discord_allowed_users = ["111222333444555666"] + +# Bot display name (used in formatted messages). +# display_name = "Assistant" + +# Maximum conversation turns to remember per channel (default: 20). +# history_size = 20 diff --git a/Cargo.lock b/Cargo.lock index ff926db2..bece483e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1096,6 +1096,21 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -2588,6 +2603,23 @@ dependencies = [ "version_check", ] +[[package]] +name = "native-tls" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "new_debug_unreachable" version = "1.0.6" @@ -2721,12 +2753,50 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" +[[package]] +name = "openssl" +version = "0.10.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf" +dependencies = [ + "bitflags 2.11.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "openssl-probe" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" +[[package]] +name = "openssl-sys" +version = "0.9.112" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "parking" version = "2.2.1" @@ -4411,6 +4481,16 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.4" @@ -4453,7 +4533,9 @@ checksum = "8f72a05e828585856dacd553fba484c242c46e391fb0e58917c942ee9202915c" dependencies = [ "futures-util", "log", + "native-tls", "tokio", + "tokio-native-tls", "tungstenite 0.29.0", ] @@ -4685,6 +4767,7 @@ dependencies = [ "http", "httparse", "log", + "native-tls", "rand 0.9.2", "sha1", "thiserror 2.0.18", diff --git a/Cargo.toml b/Cargo.toml index 02fc2a0e..9ce86d36 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ tempfile = "3" tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync"] } toml = "1.1.0" uuid = { version = "1.22.0", features = ["v4", "serde"] } -tokio-tungstenite = "0.29.0" +tokio-tungstenite = { version = "0.29.0", features = ["connect", "native-tls"] } walkdir = "2.5.0" filetime = "0.2" matrix-sdk = { version = "0.16.0", default-features = false, features = [ diff --git a/server/Cargo.toml b/server/Cargo.toml index 6b9d08a2..7c84e772 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -34,6 +34,7 @@ walkdir = { workspace = true } matrix-sdk = { workspace = true } pulldown-cmark = { workspace = true } regex = { workspace = true } +tokio-tungstenite = { workspace = true } # Force bundled SQLite so static musl builds don't need a system libsqlite3 libsqlite3-sys = { version = "0.35.0", features = ["bundled"] } @@ -44,6 +45,5 @@ libc = { workspace = true } [dev-dependencies] tempfile = { workspace = true } -tokio-tungstenite = { workspace = true } mockito = "1" filetime = { workspace = true } diff --git a/server/src/chat/transport/discord/commands.rs b/server/src/chat/transport/discord/commands.rs new file mode 100644 index 00000000..125beb7a --- /dev/null +++ b/server/src/chat/transport/discord/commands.rs @@ -0,0 +1,641 @@ +//! Discord incoming message dispatch and command handling. + +use std::collections::{HashMap, HashSet}; +use std::path::PathBuf; +use std::sync::{Arc, Mutex}; +use tokio::sync::{Mutex as TokioMutex, oneshot}; + +use crate::agents::AgentPool; +use crate::chat::transport::matrix::{ConversationEntry, ConversationRole, RoomConversation}; +use crate::chat::util::is_permission_approval; +use crate::chat::ChatTransport; +use crate::http::context::{PermissionDecision, PermissionForward}; +use crate::slog; + +use super::format::markdown_to_discord; +use super::history::{DiscordConversationHistory, save_discord_history}; +use super::meta::DiscordTransport; + +// ── Shared context ──────────────────────────────────────────────────── + +/// Shared context for the Discord bot, used by both the Gateway listener +/// and any future webhook handlers. +pub struct DiscordContext { + pub bot_token: String, + pub transport: Arc, + pub project_root: PathBuf, + pub agents: Arc, + pub bot_name: String, + /// The bot's Discord user ID (set dynamically from READY event, but + /// also stored here for command dispatch). + pub bot_user_id: String, + pub ambient_rooms: Arc>>, + /// Per-channel conversation history for LLM passthrough. + pub history: DiscordConversationHistory, + /// Maximum number of conversation entries to keep per channel. + pub history_size: usize, + /// Allowed channel IDs (messages from other channels are ignored). + pub channel_ids: HashSet, + /// Allowed Discord user IDs. When non-empty, only listed users can + /// interact with the bot. When empty, all users are allowed. + pub allowed_users: HashSet, + /// Permission requests from the MCP `prompt_permission` tool arrive here. + pub perm_rx: Arc>>, + /// Pending permission replies keyed by channel ID. + pub pending_perm_replies: + Arc>>>, + /// Seconds before an unanswered permission prompt is auto-denied. + pub permission_timeout_secs: u64, +} + +// ── Incoming message dispatch ─────────────────────────────────────────── + +pub(super) async fn handle_incoming_message( + ctx: &DiscordContext, + channel: &str, + user: &str, + message: &str, +) { + use crate::chat::commands::{CommandDispatch, try_handle_command}; + + // If there is a pending permission prompt for this channel, interpret the + // message as a yes/no response. + { + let mut pending = ctx.pending_perm_replies.lock().await; + if let Some(tx) = pending.remove(channel) { + let decision = if is_permission_approval(message) { + PermissionDecision::Approve + } else { + PermissionDecision::Deny + }; + let _ = tx.send(decision); + let confirmation = if decision == PermissionDecision::Approve { + "Permission approved." + } else { + "Permission denied." + }; + let formatted = markdown_to_discord(confirmation); + let _ = ctx.transport.send_message(channel, &formatted, "").await; + return; + } + } + + let dispatch = CommandDispatch { + bot_name: &ctx.bot_name, + bot_user_id: &ctx.bot_user_id, + project_root: &ctx.project_root, + agents: &ctx.agents, + ambient_rooms: &ctx.ambient_rooms, + room_id: channel, + }; + + if let Some(response) = try_handle_command(&dispatch, message) { + slog!("[discord] Sending command response to {channel}"); + let response = markdown_to_discord(&response); + if let Err(e) = ctx.transport.send_message(channel, &response, "").await { + slog!("[discord] Failed to send reply to {channel}: {e}"); + } + return; + } + + // Check for async commands (htop, delete). + if let Some(htop_cmd) = crate::chat::transport::matrix::htop::extract_htop_command( + message, + &ctx.bot_name, + &ctx.bot_user_id, + ) { + use crate::chat::transport::matrix::htop::HtopCommand; + slog!("[discord] Handling htop command from {user} in {channel}"); + match htop_cmd { + HtopCommand::Stop => { + let _ = ctx + .transport + .send_message(channel, "htop stopped.", "") + .await; + } + HtopCommand::Start { duration_secs } => { + let snapshot = crate::chat::transport::matrix::htop::build_htop_message( + &ctx.agents, + 0, + duration_secs, + ); + let snapshot = markdown_to_discord(&snapshot); + let msg_id = match ctx.transport.send_message(channel, &snapshot, "").await { + Ok(id) => id, + Err(e) => { + slog!("[discord] Failed to send htop message: {e}"); + return; + } + }; + let transport = Arc::clone(&ctx.transport) as Arc; + let agents = Arc::clone(&ctx.agents); + let ch = channel.to_string(); + tokio::spawn(async move { + let interval = std::time::Duration::from_secs(2); + let total_ticks = (duration_secs as usize) / 2; + for tick in 1..=total_ticks { + tokio::time::sleep(interval).await; + let updated = + crate::chat::transport::matrix::htop::build_htop_message( + &agents, + (tick * 2) as u32, + duration_secs, + ); + let updated = markdown_to_discord(&updated); + if let Err(e) = + transport.edit_message(&ch, &msg_id, &updated, "").await + { + slog!("[discord] Failed to edit htop message: {e}"); + break; + } + } + }); + } + } + return; + } + + if let Some(del_cmd) = crate::chat::transport::matrix::delete::extract_delete_command( + message, + &ctx.bot_name, + &ctx.bot_user_id, + ) { + let response = match del_cmd { + crate::chat::transport::matrix::delete::DeleteCommand::Delete { story_number } => { + slog!("[discord] Handling delete command from {user}: story {story_number}"); + crate::chat::transport::matrix::delete::handle_delete( + &ctx.bot_name, + &story_number, + &ctx.project_root, + &ctx.agents, + ) + .await + } + crate::chat::transport::matrix::delete::DeleteCommand::BadArgs => { + format!("Usage: `{} delete `", ctx.bot_name) + } + }; + let response = markdown_to_discord(&response); + let _ = ctx.transport.send_message(channel, &response, "").await; + return; + } + + if crate::chat::transport::matrix::rebuild::extract_rebuild_command( + message, + &ctx.bot_name, + &ctx.bot_user_id, + ) + .is_some() + { + slog!("[discord] Handling rebuild command from {user} in {channel}"); + let ack = "Rebuilding server… this may take a moment."; + let _ = ctx.transport.send_message(channel, ack, "").await; + let response = crate::chat::transport::matrix::rebuild::handle_rebuild( + &ctx.bot_name, + &ctx.project_root, + &ctx.agents, + ) + .await; + let response = markdown_to_discord(&response); + let _ = ctx.transport.send_message(channel, &response, "").await; + return; + } + + if let Some(rmtree_cmd) = crate::chat::transport::matrix::rmtree::extract_rmtree_command( + message, + &ctx.bot_name, + &ctx.bot_user_id, + ) { + let response = match rmtree_cmd { + crate::chat::transport::matrix::rmtree::RmtreeCommand::Rmtree { story_number } => { + slog!( + "[discord] Handling rmtree command from {user} in {channel}: story {story_number}" + ); + crate::chat::transport::matrix::rmtree::handle_rmtree( + &ctx.bot_name, + &story_number, + &ctx.project_root, + &ctx.agents, + ) + .await + } + crate::chat::transport::matrix::rmtree::RmtreeCommand::BadArgs => { + format!("Usage: `{} rmtree `", ctx.bot_name) + } + }; + let response = markdown_to_discord(&response); + let _ = ctx.transport.send_message(channel, &response, "").await; + return; + } + + if crate::chat::transport::matrix::reset::extract_reset_command( + message, + &ctx.bot_name, + &ctx.bot_user_id, + ) + .is_some() + { + slog!("[discord] Handling reset command from {user} in {channel}"); + { + let mut guard = ctx.history.lock().await; + let conv = guard + .entry(channel.to_string()) + .or_insert_with(RoomConversation::default); + conv.session_id = None; + conv.entries.clear(); + save_discord_history(&ctx.project_root, &guard); + } + let _ = ctx + .transport + .send_message(channel, "Session cleared.", "") + .await; + return; + } + + if let Some(start_cmd) = crate::chat::transport::matrix::start::extract_start_command( + message, + &ctx.bot_name, + &ctx.bot_user_id, + ) { + let response = match start_cmd { + crate::chat::transport::matrix::start::StartCommand::Start { + story_number, + agent_hint, + } => { + slog!( + "[discord] Handling start command from {user} in {channel}: story {story_number}" + ); + crate::chat::transport::matrix::start::handle_start( + &ctx.bot_name, + &story_number, + agent_hint.as_deref(), + &ctx.project_root, + &ctx.agents, + ) + .await + } + crate::chat::transport::matrix::start::StartCommand::BadArgs => { + format!("Usage: `{} start `", ctx.bot_name) + } + }; + let response = markdown_to_discord(&response); + let _ = ctx.transport.send_message(channel, &response, "").await; + return; + } + + if let Some(assign_cmd) = crate::chat::transport::matrix::assign::extract_assign_command( + message, + &ctx.bot_name, + &ctx.bot_user_id, + ) { + let response = match assign_cmd { + crate::chat::transport::matrix::assign::AssignCommand::Assign { + story_number, + model, + } => { + slog!( + "[discord] Handling assign command from {user} in {channel}: story {story_number} model {model}" + ); + crate::chat::transport::matrix::assign::handle_assign( + &ctx.bot_name, + &story_number, + &model, + &ctx.project_root, + &ctx.agents, + ) + .await + } + crate::chat::transport::matrix::assign::AssignCommand::BadArgs => { + format!("Usage: `{} assign `", ctx.bot_name) + } + }; + let response = markdown_to_discord(&response); + let _ = ctx.transport.send_message(channel, &response, "").await; + return; + } + + // No command matched — forward to LLM for conversational response. + slog!("[discord] No command matched, forwarding to LLM for {user} in {channel}"); + handle_llm_message(ctx, channel, user, message).await; +} + +/// Forward a message to Claude Code and send the response back via Discord. +async fn handle_llm_message( + ctx: &DiscordContext, + channel: &str, + user: &str, + user_message: &str, +) { + use crate::chat::util::drain_complete_paragraphs; + use crate::llm::providers::claude_code::{ClaudeCodeProvider, ClaudeCodeResult}; + use std::sync::atomic::{AtomicBool, Ordering}; + use tokio::sync::watch; + + // Look up existing session ID for this channel. + let resume_session_id: Option = { + let guard = ctx.history.lock().await; + guard + .get(channel) + .and_then(|conv| conv.session_id.clone()) + }; + + let bot_name = &ctx.bot_name; + let prompt = format!( + "[Your name is {bot_name}. Refer to yourself as {bot_name}, not Claude.]\n\n{user}: {user_message}" + ); + + let provider = ClaudeCodeProvider::new(); + let (_cancel_tx, mut cancel_rx) = watch::channel(false); + + // Channel for sending complete chunks to the Discord posting task. + let (msg_tx, mut msg_rx) = tokio::sync::mpsc::unbounded_channel::(); + let msg_tx_for_callback = msg_tx.clone(); + + // Spawn a task to post messages as they arrive. + let post_transport = Arc::clone(&ctx.transport) as Arc; + let post_channel = channel.to_string(); + let post_task = tokio::spawn(async move { + while let Some(chunk) = msg_rx.recv().await { + let formatted = markdown_to_discord(&chunk); + let _ = post_transport + .send_message(&post_channel, &formatted, "") + .await; + } + }); + + // Shared buffer between the sync token callback and the async scope. + let buffer = Arc::new(std::sync::Mutex::new(String::new())); + let buffer_for_callback = Arc::clone(&buffer); + let sent_any_chunk = Arc::new(AtomicBool::new(false)); + let sent_any_chunk_for_callback = Arc::clone(&sent_any_chunk); + + let project_root_str = ctx.project_root.to_string_lossy().to_string(); + let chat_fut = provider.chat_stream( + &prompt, + &project_root_str, + resume_session_id.as_deref(), + None, + &mut cancel_rx, + move |token| { + let mut buf = buffer_for_callback.lock().unwrap(); + buf.push_str(token); + let paragraphs = drain_complete_paragraphs(&mut buf); + for chunk in paragraphs { + sent_any_chunk_for_callback.store(true, Ordering::Relaxed); + let _ = msg_tx_for_callback.send(chunk); + } + }, + |_thinking| {}, + |_activity| {}, + ); + tokio::pin!(chat_fut); + + // Lock the permission receiver for the duration of this chat session. + let mut perm_rx_guard = ctx.perm_rx.lock().await; + + let result = loop { + tokio::select! { + r = &mut chat_fut => break r, + + Some(perm_fwd) = perm_rx_guard.recv() => { + let prompt_msg = format!( + "**Permission Request**\n\nTool: `{}`\n```json\n{}\n```\n\nReply **yes** to approve or **no** to deny.", + perm_fwd.tool_name, + serde_json::to_string_pretty(&perm_fwd.tool_input) + .unwrap_or_else(|_| perm_fwd.tool_input.to_string()), + ); + let formatted = markdown_to_discord(&prompt_msg); + let _ = ctx.transport.send_message(channel, &formatted, "").await; + + ctx.pending_perm_replies + .lock() + .await + .insert(channel.to_string(), perm_fwd.response_tx); + + // Spawn a timeout task: auto-deny if the user does not respond. + let pending = Arc::clone(&ctx.pending_perm_replies); + let timeout_channel = channel.to_string(); + let timeout_transport = Arc::clone(&ctx.transport) as Arc; + let timeout_secs = ctx.permission_timeout_secs; + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_secs(timeout_secs)).await; + if let Some(tx) = pending.lock().await.remove(&timeout_channel) { + let _ = tx.send(PermissionDecision::Deny); + let msg = "Permission request timed out — denied (fail-closed)."; + let _ = timeout_transport.send_message(&timeout_channel, msg, "").await; + } + }); + } + } + }; + drop(perm_rx_guard); + + // Flush remaining text. + let remaining = buffer.lock().unwrap().trim().to_string(); + let did_send_any = sent_any_chunk.load(Ordering::Relaxed); + + let (assistant_reply, new_session_id) = match result { + Ok(ClaudeCodeResult { + messages, + session_id, + }) => { + let reply = if !remaining.is_empty() { + let _ = msg_tx.send(remaining.clone()); + remaining + } else if !did_send_any { + let last_text = messages + .iter() + .rev() + .find(|m| { + m.role == crate::llm::types::Role::Assistant && !m.content.is_empty() + }) + .map(|m| m.content.clone()) + .unwrap_or_default(); + if !last_text.is_empty() { + let _ = msg_tx.send(last_text.clone()); + } + last_text + } else { + remaining + }; + slog!("[discord] session_id from chat_stream: {:?}", session_id); + (reply, session_id) + } + Err(e) => { + slog!("[discord] LLM error: {e}"); + let err_msg = format!("Error processing your request: {e}"); + let _ = msg_tx.send(err_msg.clone()); + (err_msg, None) + } + }; + + // Signal the posting task to finish and wait for it. + drop(msg_tx); + let _ = post_task.await; + + // Record this exchange in conversation history. + if !assistant_reply.starts_with("Error processing") { + let mut guard = ctx.history.lock().await; + let conv = guard.entry(channel.to_string()).or_default(); + + if new_session_id.is_some() { + conv.session_id = new_session_id; + } + + conv.entries.push(ConversationEntry { + role: ConversationRole::User, + sender: user.to_string(), + content: user_message.to_string(), + }); + conv.entries.push(ConversationEntry { + role: ConversationRole::Assistant, + sender: String::new(), + content: assistant_reply, + }); + + // Trim to configured maximum. + if conv.entries.len() > ctx.history_size { + let excess = conv.entries.len() - ctx.history_size; + conv.entries.drain(..excess); + } + + save_discord_history(&ctx.project_root, &guard); + } +} + +// ── Tests ─────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + fn test_agents() -> Arc { + Arc::new(crate::agents::AgentPool::new_test(3000)) + } + + fn test_ambient_rooms() -> Arc>> { + Arc::new(Mutex::new(HashSet::new())) + } + + #[test] + fn command_dispatches_through_command_registry() { + use crate::chat::commands::{CommandDispatch, try_handle_command}; + + let agents = test_agents(); + let ambient_rooms = test_ambient_rooms(); + let room_id = "123456789".to_string(); + + let bot_name = "Huskies"; + let synthetic = format!("{bot_name} status"); + + let dispatch = CommandDispatch { + bot_name, + bot_user_id: "discord-bot", + project_root: std::path::Path::new("/tmp"), + agents: &agents, + ambient_rooms: &ambient_rooms, + room_id: &room_id, + }; + + let result = try_handle_command(&dispatch, &synthetic); + assert!( + result.is_some(), + "status command should produce output via registry" + ); + assert!(result.unwrap().contains("Pipeline Status")); + } + + #[test] + fn rebuild_command_extracted_from_discord_message() { + let result = crate::chat::transport::matrix::rebuild::extract_rebuild_command( + "Huskies rebuild", + "Huskies", + "discord-bot", + ); + assert!(result.is_some(), "'Huskies rebuild' should be recognised"); + } + + #[test] + fn reset_command_extracted_from_discord_message() { + let result = crate::chat::transport::matrix::reset::extract_reset_command( + "Huskies reset", + "Huskies", + "discord-bot", + ); + assert!(result.is_some(), "'Huskies reset' should be recognised"); + } + + #[test] + fn start_command_extracted_from_discord_message() { + let result = crate::chat::transport::matrix::start::extract_start_command( + "start 42", + "Huskies", + "discord-bot", + ); + assert!(result.is_some(), "plain 'start 42' should be recognised"); + assert_eq!( + result, + Some(crate::chat::transport::matrix::start::StartCommand::Start { + story_number: "42".to_string(), + agent_hint: None, + }) + ); + } + + #[test] + fn assign_command_extracted_from_discord_message() { + let result = crate::chat::transport::matrix::assign::extract_assign_command( + "assign 42 opus", + "Huskies", + "discord-bot", + ); + assert!( + matches!( + result, + Some(crate::chat::transport::matrix::assign::AssignCommand::Assign { .. }) + ), + "plain 'assign 42 opus' should be recognised on Discord" + ); + } + + #[tokio::test] + async fn reset_command_clears_discord_session() { + use crate::chat::transport::matrix::RoomConversation; + + let channel = "123456789"; + let history: DiscordConversationHistory = Arc::new(TokioMutex::new({ + let mut m = HashMap::new(); + m.insert( + channel.to_string(), + RoomConversation { + session_id: Some("old-session".to_string()), + entries: vec![ConversationEntry { + role: ConversationRole::User, + sender: "user123".to_string(), + content: "previous message".to_string(), + }], + }, + ); + m + })); + + let tmp = tempfile::tempdir().unwrap(); + let sk = tmp.path().join(".huskies"); + std::fs::create_dir_all(&sk).unwrap(); + + { + let mut guard = history.lock().await; + let conv = guard + .entry(channel.to_string()) + .or_insert_with(RoomConversation::default); + conv.session_id = None; + conv.entries.clear(); + save_discord_history(tmp.path(), &guard); + } + + let guard = history.lock().await; + let conv = guard.get(channel).unwrap(); + assert!(conv.session_id.is_none(), "session_id should be cleared"); + assert!(conv.entries.is_empty(), "entries should be cleared"); + } +} diff --git a/server/src/chat/transport/discord/format.rs b/server/src/chat/transport/discord/format.rs new file mode 100644 index 00000000..eb5d30ab --- /dev/null +++ b/server/src/chat/transport/discord/format.rs @@ -0,0 +1,55 @@ +//! Markdown to Discord format conversion. +//! +//! Discord supports standard Markdown natively, so unlike Slack we only need +//! minimal adjustments (e.g. truncating messages that exceed Discord's limit). + +/// Convert Markdown text to Discord-compatible format. +/// +/// Discord supports most standard Markdown: +/// - `**bold**`, `*italic*`, `~~strikethrough~~` +/// - `` `code` `` and ````code blocks```` +/// - `> blockquotes` +/// - `[text](url)` links +/// +/// The main adjustment is ensuring messages stay within Discord's 2000-char +/// limit. Actual truncation is handled by the transport layer; this function +/// performs any lightweight text normalization needed. +pub fn markdown_to_discord(text: &str) -> String { + use crate::chat::util::normalize_line_breaks; + normalize_line_breaks(text) +} + +// ── Tests ─────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn discord_plain_text_unchanged() { + let plain = "Hello, this is a plain message."; + assert_eq!(markdown_to_discord(plain), plain); + } + + #[test] + fn discord_bold_preserved() { + assert_eq!(markdown_to_discord("**bold text**"), "**bold text**"); + } + + #[test] + fn discord_code_block_preserved() { + let input = "```rust\nlet x = 1;\n```"; + assert_eq!(markdown_to_discord(input), input); + } + + #[test] + fn discord_empty_string_unchanged() { + assert_eq!(markdown_to_discord(""), ""); + } + + #[test] + fn discord_link_preserved() { + let input = "[click here](https://example.com)"; + assert_eq!(markdown_to_discord(input), input); + } +} diff --git a/server/src/chat/transport/discord/gateway.rs b/server/src/chat/transport/discord/gateway.rs new file mode 100644 index 00000000..0bab7227 --- /dev/null +++ b/server/src/chat/transport/discord/gateway.rs @@ -0,0 +1,507 @@ +//! Minimal Discord Gateway WebSocket client. +//! +//! Connects to the Discord Gateway, authenticates with a bot token, maintains +//! the heartbeat keepalive, and dispatches incoming `MESSAGE_CREATE` events to +//! the command handler. + +use std::sync::Arc; + +use futures::{SinkExt, StreamExt}; +use serde::{Deserialize, Serialize}; +use tokio_tungstenite::tungstenite::Message as WsMessage; + +use crate::slog; + +use super::commands::{self, DiscordContext}; + +// ── Gateway opcodes ───────────────────────────────────────────────────── + +const OP_DISPATCH: u8 = 0; +const OP_HEARTBEAT: u8 = 1; +const OP_IDENTIFY: u8 = 2; +const OP_RECONNECT: u8 = 7; +const OP_INVALID_SESSION: u8 = 9; +const OP_HELLO: u8 = 10; +const OP_HEARTBEAT_ACK: u8 = 11; + +// ── Gateway intents ───────────────────────────────────────────────────── + +/// GUILDS | GUILD_MESSAGES | MESSAGE_CONTENT (privileged) +const GATEWAY_INTENTS: u64 = (1 << 0) | (1 << 9) | (1 << 15); + +// ── Gateway payload types ─────────────────────────────────────────────── + +#[derive(Deserialize, Debug)] +struct GatewayPayload { + op: u8, + #[serde(default)] + d: Option, + #[serde(default)] + s: Option, + #[serde(default)] + t: Option, +} + +#[derive(Deserialize, Debug)] +struct HelloData { + heartbeat_interval: u64, +} + +#[derive(Deserialize, Debug)] +pub(super) struct MessageCreateData { + #[serde(default)] + pub channel_id: String, + #[serde(default)] + pub content: String, + #[serde(default)] + pub author: Option, + #[serde(default)] + pub mentions: Vec, +} + +#[derive(Deserialize, Debug)] +pub(super) struct MessageAuthor { + pub id: String, + #[serde(default)] + pub bot: Option, +} + +#[derive(Deserialize, Debug)] +pub(super) struct MentionUser { + pub id: String, +} + +#[derive(Deserialize, Debug)] +struct ReadyData { + user: ReadyUser, +} + +#[derive(Deserialize, Debug)] +struct ReadyUser { + id: String, +} + +// ── Identify payload ──────────────────────────────────────────────────── + +#[derive(Serialize)] +struct IdentifyPayload { + op: u8, + d: IdentifyData, +} + +#[derive(Serialize)] +struct IdentifyData { + token: String, + intents: u64, + properties: IdentifyProperties, +} + +#[derive(Serialize)] +struct IdentifyProperties { + os: &'static str, + browser: &'static str, + device: &'static str, +} + +#[derive(Serialize)] +struct HeartbeatPayload { + op: u8, + d: Option, +} + +// ── Gateway URL ───────────────────────────────────────────────────────── + +const GATEWAY_URL: &str = "wss://gateway.discord.gg/?v=10&encoding=json"; + +// ── Public entry point ────────────────────────────────────────────────── + +/// Spawn a background task that connects to the Discord Gateway, maintains +/// the heartbeat, and dispatches incoming messages. +pub fn spawn_gateway(ctx: Arc) { + tokio::spawn(async move { + loop { + if let Err(e) = run_gateway(Arc::clone(&ctx)).await { + slog!("[discord] Gateway connection ended: {e}"); + } + // Wait before reconnecting. + slog!("[discord] Reconnecting in 5 seconds…"); + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + } + }); +} + +/// Connect to the Discord Gateway and process events until the connection +/// closes or an unrecoverable error occurs. +async fn run_gateway(ctx: Arc) -> Result<(), String> { + slog!("[discord] Connecting to Discord Gateway…"); + + let (ws_stream, _) = tokio_tungstenite::connect_async(GATEWAY_URL) + .await + .map_err(|e| format!("Gateway connect failed: {e}"))?; + + slog!("[discord] Gateway WebSocket connected"); + + let (mut write, mut read) = ws_stream.split(); + + // Wait for Hello (op 10) to get the heartbeat interval. + let hello = read + .next() + .await + .ok_or("Gateway closed before Hello")? + .map_err(|e| format!("Gateway read error: {e}"))?; + + let hello_payload: GatewayPayload = + parse_ws_message(&hello).ok_or("Failed to parse Hello")?; + + if hello_payload.op != OP_HELLO { + return Err(format!( + "Expected Hello (op 10), got op {}", + hello_payload.op + )); + } + + let hello_data: HelloData = + serde_json::from_value(hello_payload.d.ok_or("Hello missing data")?) + .map_err(|e| format!("Failed to parse Hello data: {e}"))?; + + let heartbeat_interval = + std::time::Duration::from_millis(hello_data.heartbeat_interval); + slog!( + "[discord] Heartbeat interval: {}ms", + hello_data.heartbeat_interval + ); + + // Send Identify. + let identify = IdentifyPayload { + op: OP_IDENTIFY, + d: IdentifyData { + token: ctx.bot_token.clone(), + intents: GATEWAY_INTENTS, + properties: IdentifyProperties { + os: "linux", + browser: "huskies", + device: "huskies", + }, + }, + }; + + let identify_json = serde_json::to_string(&identify) + .map_err(|e| format!("Failed to serialise Identify: {e}"))?; + write + .send(WsMessage::Text(identify_json.into())) + .await + .map_err(|e| format!("Failed to send Identify: {e}"))?; + + slog!("[discord] Identify sent, waiting for Ready…"); + + // Track sequence number for heartbeats. + let sequence = Arc::new(std::sync::atomic::AtomicU64::new(0)); + let seq_for_heartbeat = Arc::clone(&sequence); + + // Share the write half between the heartbeat task and the main event loop. + let write = Arc::new(tokio::sync::Mutex::new(write)); + let write_for_heartbeat = Arc::clone(&write); + + // Spawn heartbeat task. + let heartbeat_handle = tokio::spawn(async move { + // Jitter: wait a random fraction of the interval before the first beat. + let jitter = heartbeat_interval.mul_f64(rand_fraction()); + tokio::time::sleep(jitter).await; + + loop { + let seq = seq_for_heartbeat.load(std::sync::atomic::Ordering::Relaxed); + let seq_val = if seq == 0 { None } else { Some(seq) }; + let hb = HeartbeatPayload { + op: OP_HEARTBEAT, + d: seq_val, + }; + if let Ok(json) = serde_json::to_string(&hb) { + let mut w = write_for_heartbeat.lock().await; + if w.send(WsMessage::Text(json.into())).await.is_err() { + slog!("[discord] Heartbeat send failed, stopping heartbeat"); + break; + } + } + tokio::time::sleep(heartbeat_interval).await; + } + }); + + // Track our own bot user ID (set from READY event). + let mut bot_user_id: Option = None; + + // Main event loop. + while let Some(msg_result) = read.next().await { + let msg = match msg_result { + Ok(m) => m, + Err(e) => { + slog!("[discord] Gateway read error: {e}"); + break; + } + }; + + let payload: GatewayPayload = match parse_ws_message(&msg) { + Some(p) => p, + None => continue, + }; + + // Update sequence number. + if let Some(s) = payload.s { + sequence.store(s, std::sync::atomic::Ordering::Relaxed); + } + + match payload.op { + OP_DISPATCH => { + let event_name = payload.t.as_deref().unwrap_or(""); + match event_name { + "READY" => { + if let Some(d) = payload.d + && let Ok(ready) = serde_json::from_value::(d) + { + bot_user_id = Some(ready.user.id.clone()); + slog!( + "[discord] READY — bot user ID: {}", + ready.user.id + ); + } + } + "MESSAGE_CREATE" => { + if let Some(d) = payload.d { + dispatch_message( + Arc::clone(&ctx), + d, + bot_user_id.clone(), + ); + } + } + _ => {} + } + } + OP_HEARTBEAT => { + // Server requested an immediate heartbeat. + let seq = sequence.load(std::sync::atomic::Ordering::Relaxed); + let seq_val = if seq == 0 { None } else { Some(seq) }; + let hb = HeartbeatPayload { + op: OP_HEARTBEAT, + d: seq_val, + }; + if let Ok(json) = serde_json::to_string(&hb) { + let mut w = write.lock().await; + let _ = w.send(WsMessage::Text(json.into())).await; + } + } + OP_HEARTBEAT_ACK => {} + OP_RECONNECT => { + slog!("[discord] Gateway requested reconnect"); + break; + } + OP_INVALID_SESSION => { + slog!("[discord] Invalid session — will reconnect"); + // Wait a bit before reconnecting for invalid sessions. + tokio::time::sleep(std::time::Duration::from_secs(3)).await; + break; + } + _ => {} + } + } + + heartbeat_handle.abort(); + slog!("[discord] Gateway event loop ended"); + Ok(()) +} + +// ── Helpers ───────────────────────────────────────────────────────────── + +fn parse_ws_message(msg: &WsMessage) -> Option { + match msg { + WsMessage::Text(text) => serde_json::from_str(text).ok(), + _ => None, + } +} + +/// Filter and dispatch an incoming MESSAGE_CREATE event on a background task. +fn dispatch_message( + ctx: Arc, + data: serde_json::Value, + bot_user_id: Option, +) { + tokio::spawn(async move { + let msg: MessageCreateData = match serde_json::from_value(data) { + Ok(m) => m, + Err(e) => { + slog!("[discord] Failed to parse MESSAGE_CREATE: {e}"); + return; + } + }; + + let author = match msg.author.as_ref() { + Some(a) => a, + None => return, + }; + + // Ignore bot messages (including our own). + if author.bot.unwrap_or(false) { + return; + } + + // Check if the channel is in our configured channel list. + if !ctx.channel_ids.contains(&msg.channel_id) { + return; + } + + // Check allowed_users (if non-empty, only listed users can interact). + if !ctx.allowed_users.is_empty() && !ctx.allowed_users.contains(&author.id) { + return; + } + + // Check if the bot was mentioned, or if we respond to all messages in + // configured channels (ambient mode). + let bot_mentioned = bot_user_id.as_ref().is_some_and(|bid| { + msg.mentions.iter().any(|m| m.id == *bid) + }); + + let in_ambient = ctx + .ambient_rooms + .lock() + .unwrap() + .contains(&msg.channel_id); + + if !bot_mentioned && !in_ambient { + return; + } + + // Strip the bot mention from the message content if present. + let content = if let Some(bid) = bot_user_id.as_ref() { + let mention_pat = format!("<@{bid}>"); + let mention_pat_nick = format!("<@!{bid}>"); + msg.content + .replace(&mention_pat, "") + .replace(&mention_pat_nick, "") + .trim() + .to_string() + } else { + msg.content.clone() + }; + + if content.is_empty() { + return; + } + + slog!( + "[discord] Message from {} in {}: {content:.80}", + author.id, + msg.channel_id + ); + + commands::handle_incoming_message(&ctx, &msg.channel_id, &author.id, &content) + .await; + }); +} + +/// Cheap pseudo-random fraction [0.0, 1.0) for heartbeat jitter. +fn rand_fraction() -> f64 { + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .subsec_nanos(); + (nanos as f64) / 1_000_000_000.0 +} + +// ── Tests ─────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_hello_payload() { + let json = r#"{"op": 10, "d": {"heartbeat_interval": 41250}}"#; + let payload: GatewayPayload = serde_json::from_str(json).unwrap(); + assert_eq!(payload.op, OP_HELLO); + let hello: HelloData = + serde_json::from_value(payload.d.unwrap()).unwrap(); + assert_eq!(hello.heartbeat_interval, 41250); + } + + #[test] + fn parse_message_create() { + let json = r#"{ + "channel_id": "123456", + "content": "hello bot", + "author": {"id": "789", "bot": false}, + "mentions": [{"id": "111"}] + }"#; + let msg: MessageCreateData = serde_json::from_str(json).unwrap(); + assert_eq!(msg.channel_id, "123456"); + assert_eq!(msg.content, "hello bot"); + assert_eq!(msg.author.as_ref().unwrap().id, "789"); + assert!(!msg.author.as_ref().unwrap().bot.unwrap_or(false)); + assert_eq!(msg.mentions.len(), 1); + assert_eq!(msg.mentions[0].id, "111"); + } + + #[test] + fn parse_ready_event() { + let json = r#"{"user": {"id": "999888"}}"#; + let ready: ReadyData = serde_json::from_str(json).unwrap(); + assert_eq!(ready.user.id, "999888"); + } + + #[test] + fn identify_payload_serializes_correctly() { + let identify = IdentifyPayload { + op: OP_IDENTIFY, + d: IdentifyData { + token: "test-token".to_string(), + intents: GATEWAY_INTENTS, + properties: IdentifyProperties { + os: "linux", + browser: "huskies", + device: "huskies", + }, + }, + }; + let json: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&identify).unwrap()).unwrap(); + assert_eq!(json["op"], 2); + assert_eq!(json["d"]["token"], "test-token"); + assert_eq!(json["d"]["intents"], GATEWAY_INTENTS); + assert_eq!(json["d"]["properties"]["browser"], "huskies"); + } + + #[test] + fn heartbeat_payload_serializes_with_sequence() { + let hb = HeartbeatPayload { + op: OP_HEARTBEAT, + d: Some(42), + }; + let json: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&hb).unwrap()).unwrap(); + assert_eq!(json["op"], 1); + assert_eq!(json["d"], 42); + } + + #[test] + fn heartbeat_payload_serializes_null_sequence() { + let hb = HeartbeatPayload { + op: OP_HEARTBEAT, + d: None, + }; + let json: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&hb).unwrap()).unwrap(); + assert_eq!(json["op"], 1); + assert!(json["d"].is_null()); + } + + #[test] + fn rand_fraction_in_range() { + let f = rand_fraction(); + assert!((0.0..1.0).contains(&f)); + } + + #[test] + fn gateway_intents_correct() { + // GUILDS (1) | GUILD_MESSAGES (512) | MESSAGE_CONTENT (32768) + assert_eq!(GATEWAY_INTENTS, 1 | 512 | 32768); + assert_eq!(GATEWAY_INTENTS, 33281); + } +} diff --git a/server/src/chat/transport/discord/history.rs b/server/src/chat/transport/discord/history.rs new file mode 100644 index 00000000..9b8828aa --- /dev/null +++ b/server/src/chat/transport/discord/history.rs @@ -0,0 +1,119 @@ +//! Discord conversation history persistence. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex as TokioMutex; + +use crate::chat::transport::matrix::RoomConversation; +use crate::slog; + +/// Per-channel conversation history, keyed by channel ID. +pub type DiscordConversationHistory = Arc>>; + +/// On-disk format for persisted Discord conversation history. +#[derive(Serialize, Deserialize)] +struct PersistedDiscordHistory { + channels: HashMap, +} + +/// Path to the persisted Discord conversation history file. +const DISCORD_HISTORY_FILE: &str = ".huskies/discord_history.json"; + +/// Load Discord conversation history from disk. +pub fn load_discord_history(project_root: &std::path::Path) -> HashMap { + let path = project_root.join(DISCORD_HISTORY_FILE); + let data = match std::fs::read_to_string(&path) { + Ok(d) => d, + Err(_) => return HashMap::new(), + }; + let persisted: PersistedDiscordHistory = match serde_json::from_str(&data) { + Ok(p) => p, + Err(e) => { + slog!("[discord] Failed to parse history file: {e}"); + return HashMap::new(); + } + }; + persisted.channels +} + +/// Save Discord conversation history to disk. +pub(super) fn save_discord_history( + project_root: &std::path::Path, + history: &HashMap, +) { + let persisted = PersistedDiscordHistory { + channels: history.clone(), + }; + let path = project_root.join(DISCORD_HISTORY_FILE); + match serde_json::to_string_pretty(&persisted) { + Ok(json) => { + if let Err(e) = std::fs::write(&path, json) { + slog!("[discord] Failed to write history file: {e}"); + } + } + Err(e) => slog!("[discord] Failed to serialise history: {e}"), + } +} + +// ── Tests ─────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::chat::transport::matrix::{ConversationEntry, ConversationRole}; + + #[test] + fn save_and_load_discord_history_round_trips() { + let tmp = tempfile::tempdir().unwrap(); + let sk = tmp.path().join(".huskies"); + std::fs::create_dir_all(&sk).unwrap(); + + let mut history = HashMap::new(); + history.insert( + "123456789".to_string(), + RoomConversation { + session_id: Some("sess-abc".to_string()), + entries: vec![ + ConversationEntry { + role: ConversationRole::User, + sender: "user123".to_string(), + content: "hello".to_string(), + }, + ConversationEntry { + role: ConversationRole::Assistant, + sender: String::new(), + content: "hi there!".to_string(), + }, + ], + }, + ); + + save_discord_history(tmp.path(), &history); + let loaded = load_discord_history(tmp.path()); + + assert_eq!(loaded.len(), 1); + let conv = loaded.get("123456789").unwrap(); + assert_eq!(conv.session_id.as_deref(), Some("sess-abc")); + assert_eq!(conv.entries.len(), 2); + assert_eq!(conv.entries[0].content, "hello"); + assert_eq!(conv.entries[1].content, "hi there!"); + } + + #[test] + fn load_discord_history_returns_empty_when_file_missing() { + let tmp = tempfile::tempdir().unwrap(); + let history = load_discord_history(tmp.path()); + assert!(history.is_empty()); + } + + #[test] + fn load_discord_history_returns_empty_on_invalid_json() { + let tmp = tempfile::tempdir().unwrap(); + let sk = tmp.path().join(".huskies"); + std::fs::create_dir_all(&sk).unwrap(); + std::fs::write(sk.join("discord_history.json"), "not json {{{").unwrap(); + let history = load_discord_history(tmp.path()); + assert!(history.is_empty()); + } +} diff --git a/server/src/chat/transport/discord/meta.rs b/server/src/chat/transport/discord/meta.rs new file mode 100644 index 00000000..c72668d0 --- /dev/null +++ b/server/src/chat/transport/discord/meta.rs @@ -0,0 +1,302 @@ +//! DiscordTransport — ChatTransport implementation for the Discord Bot API. + +use async_trait::async_trait; +use serde::Deserialize; + +use crate::chat::{ChatTransport, MessageId}; +use crate::slog; + +// ── Discord API base URL (overridable for tests) ────────────────────── + +const DISCORD_API_BASE: &str = "https://discord.com/api/v10"; + +// ── DiscordTransport ────────────────────────────────────────────────── + +/// Discord Bot API transport. +/// +/// Sends messages via `POST {API_BASE}/channels/{channel_id}/messages`, +/// edits via `PATCH {API_BASE}/channels/{channel_id}/messages/{message_id}`, +/// and typing indicators via `POST {API_BASE}/channels/{channel_id}/typing`. +pub struct DiscordTransport { + bot_token: String, + client: reqwest::Client, + /// Optional base URL override for tests. + api_base: String, +} + +impl DiscordTransport { + pub fn new(bot_token: String) -> Self { + Self { + bot_token, + client: reqwest::Client::new(), + api_base: DISCORD_API_BASE.to_string(), + } + } + + #[cfg(test)] + fn with_api_base(bot_token: String, api_base: String) -> Self { + Self { + bot_token, + client: reqwest::Client::new(), + api_base, + } + } + + /// Authorization header value: `Bot {token}`. + fn auth_header(&self) -> String { + format!("Bot {}", self.bot_token) + } +} + +// ── Discord API response types ──────────────────────────────────────── + +#[derive(Deserialize, Debug)] +struct DiscordMessage { + id: String, +} + +// ── ChatTransport implementation ────────────────────────────────────── + +#[async_trait] +impl ChatTransport for DiscordTransport { + async fn send_message( + &self, + channel_id: &str, + plain: &str, + _html: &str, + ) -> Result { + slog!("[discord] send_message to {channel_id}: {plain:.80}"); + let url = format!("{}/channels/{}/messages", self.api_base, channel_id); + + // Discord messages have a 2000-char limit. Truncate if needed. + let content = if plain.len() > 2000 { + format!("{}…", &plain[..1999]) + } else { + plain.to_string() + }; + + let payload = serde_json::json!({ "content": content }); + + let resp = self + .client + .post(&url) + .header("Authorization", self.auth_header()) + .json(&payload) + .send() + .await + .map_err(|e| format!("Discord API request failed: {e}"))?; + + let status = resp.status(); + let resp_text = resp + .text() + .await + .unwrap_or_else(|_| "".to_string()); + + if !status.is_success() { + return Err(format!("Discord API returned {status}: {resp_text}")); + } + + let msg: DiscordMessage = serde_json::from_str(&resp_text).map_err(|e| { + format!("Failed to parse Discord API response: {e} — body: {resp_text}") + })?; + + Ok(msg.id) + } + + async fn edit_message( + &self, + channel_id: &str, + original_message_id: &str, + plain: &str, + _html: &str, + ) -> Result<(), String> { + slog!("[discord] edit_message in {channel_id}: id={original_message_id}"); + let url = format!( + "{}/channels/{}/messages/{}", + self.api_base, channel_id, original_message_id + ); + + let content = if plain.len() > 2000 { + format!("{}…", &plain[..1999]) + } else { + plain.to_string() + }; + + let payload = serde_json::json!({ "content": content }); + + let resp = self + .client + .patch(&url) + .header("Authorization", self.auth_header()) + .json(&payload) + .send() + .await + .map_err(|e| format!("Discord edit request failed: {e}"))?; + + let status = resp.status(); + let resp_text = resp + .text() + .await + .unwrap_or_else(|_| "".to_string()); + + if !status.is_success() { + return Err(format!("Discord edit returned {status}: {resp_text}")); + } + + Ok(()) + } + + async fn send_typing(&self, channel_id: &str, typing: bool) -> Result<(), String> { + if !typing { + // Discord doesn't have a "stop typing" API — it times out after ~10s. + return Ok(()); + } + let url = format!("{}/channels/{}/typing", self.api_base, channel_id); + + self.client + .post(&url) + .header("Authorization", self.auth_header()) + .send() + .await + .map_err(|e| format!("Discord typing request failed: {e}"))?; + + Ok(()) + } +} + +// ── Tests ─────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[tokio::test] + async fn transport_send_message_calls_discord_api() { + let mut server = mockito::Server::new_async().await; + let mock = server + .mock("POST", "/channels/123456/messages") + .match_header("authorization", "Bot test-token") + .with_body(r#"{"id": "999888777"}"#) + .create_async() + .await; + + let transport = + DiscordTransport::with_api_base("test-token".to_string(), server.url()); + + let result = transport + .send_message("123456", "hello", "

hello

") + .await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "999888777"); + mock.assert_async().await; + } + + #[tokio::test] + async fn transport_send_message_handles_api_error() { + let mut server = mockito::Server::new_async().await; + server + .mock("POST", "/channels/bad/messages") + .with_status(404) + .with_body(r#"{"message": "Unknown Channel", "code": 10003}"#) + .create_async() + .await; + + let transport = + DiscordTransport::with_api_base("test-token".to_string(), server.url()); + + let result = transport.send_message("bad", "hello", "").await; + assert!(result.is_err()); + assert!(result.unwrap_err().contains("404")); + } + + #[tokio::test] + async fn transport_edit_message_calls_patch() { + let mut server = mockito::Server::new_async().await; + let mock = server + .mock("PATCH", "/channels/123456/messages/999888777") + .match_header("authorization", "Bot test-token") + .with_body(r#"{"id": "999888777"}"#) + .create_async() + .await; + + let transport = + DiscordTransport::with_api_base("test-token".to_string(), server.url()); + + let result = transport + .edit_message("123456", "999888777", "updated", "") + .await; + assert!(result.is_ok()); + mock.assert_async().await; + } + + #[tokio::test] + async fn transport_edit_message_handles_error() { + let mut server = mockito::Server::new_async().await; + server + .mock("PATCH", "/channels/123456/messages/bad") + .with_status(404) + .with_body(r#"{"message": "Unknown Message", "code": 10008}"#) + .create_async() + .await; + + let transport = + DiscordTransport::with_api_base("test-token".to_string(), server.url()); + + let result = transport + .edit_message("123456", "bad", "updated", "") + .await; + assert!(result.is_err()); + assert!(result.unwrap_err().contains("404")); + } + + #[tokio::test] + async fn transport_send_typing_succeeds() { + let mut server = mockito::Server::new_async().await; + server + .mock("POST", "/channels/123456/typing") + .with_status(204) + .create_async() + .await; + + let transport = + DiscordTransport::with_api_base("test-token".to_string(), server.url()); + + assert!(transport.send_typing("123456", true).await.is_ok()); + } + + #[tokio::test] + async fn transport_send_typing_false_is_noop() { + let transport = DiscordTransport::new("test-token".to_string()); + assert!(transport.send_typing("123456", false).await.is_ok()); + } + + #[tokio::test] + async fn transport_handles_http_error() { + let mut server = mockito::Server::new_async().await; + server + .mock("POST", "/channels/123456/messages") + .with_status(500) + .with_body("Internal Server Error") + .create_async() + .await; + + let transport = + DiscordTransport::with_api_base("test-token".to_string(), server.url()); + + let result = transport.send_message("123456", "hello", "").await; + assert!(result.is_err()); + assert!(result.unwrap_err().contains("500")); + } + + // ── ChatTransport trait satisfaction ───────────────────────────────── + + #[test] + fn discord_transport_satisfies_trait() { + fn assert_transport() {} + assert_transport::(); + + let _: Arc = + Arc::new(DiscordTransport::new("test-token".to_string())); + } +} diff --git a/server/src/chat/transport/discord/mod.rs b/server/src/chat/transport/discord/mod.rs new file mode 100644 index 00000000..04503a82 --- /dev/null +++ b/server/src/chat/transport/discord/mod.rs @@ -0,0 +1,18 @@ +//! Discord Bot integration. +//! +//! Provides: +//! - [`DiscordTransport`] — a [`ChatTransport`] that sends messages via the +//! Discord REST API (`/channels/{id}/messages`). +//! - [`gateway::spawn_gateway`] — connects to the Discord Gateway WebSocket, +//! receives `MESSAGE_CREATE` events, and dispatches commands. +//! - [`DiscordContext`] — shared context for the bot. + +pub mod commands; +pub mod format; +pub mod gateway; +pub mod history; +pub mod meta; + +pub use commands::DiscordContext; +pub use history::load_discord_history; +pub use meta::DiscordTransport; diff --git a/server/src/chat/transport/matrix/config.rs b/server/src/chat/transport/matrix/config.rs index 7c4b211f..b8591814 100644 --- a/server/src/chat/transport/matrix/config.rs +++ b/server/src/chat/transport/matrix/config.rs @@ -64,7 +64,8 @@ pub struct BotConfig { /// manually while the bot is running. #[serde(default)] pub ambient_rooms: Vec, - /// Chat transport to use: `"matrix"` (default) or `"whatsapp"`. + /// Chat transport to use: `"matrix"` (default), `"whatsapp"`, `"slack"`, + /// or `"discord"`. /// /// Selects which [`ChatTransport`] implementation the bot uses for /// sending and editing messages. Currently only read during bot @@ -134,6 +135,20 @@ pub struct BotConfig { /// Slack channel IDs the bot should listen in. #[serde(default)] pub slack_channel_ids: Vec, + + // ── Discord Bot API fields ────────────────────────────────────── + // These are only required when `transport = "discord"`. + + /// Discord bot token from the Discord Developer Portal. + #[serde(default)] + pub discord_bot_token: Option, + /// Discord channel IDs the bot should listen in. + #[serde(default)] + pub discord_channel_ids: Vec, + /// Discord user IDs allowed to interact with the bot. + /// When empty or absent, all users in configured channels are allowed. + #[serde(default)] + pub discord_allowed_users: Vec, } fn default_transport() -> String { @@ -241,6 +256,22 @@ impl BotConfig { ); return None; } + } else if config.transport == "discord" { + // Validate Discord-specific fields. + if config.discord_bot_token.as_ref().is_none_or(|s| s.is_empty()) { + eprintln!( + "[bot] bot.toml: transport=\"discord\" requires \ + discord_bot_token" + ); + return None; + } + if config.discord_channel_ids.is_empty() { + eprintln!( + "[bot] bot.toml: transport=\"discord\" requires \ + at least one discord_channel_ids entry" + ); + return None; + } } else { // Default transport is Matrix — validate Matrix-specific fields. if config.homeserver.as_ref().is_none_or(|s| s.is_empty()) { @@ -1054,4 +1085,99 @@ whatsapp_allowed_phones = ["+15551234567", "+15559876543"] vec!["+15551234567", "+15559876543"] ); } + + // ── Discord config tests ────────────────────────────────────────── + + #[test] + fn load_discord_transport_reads_config() { + let tmp = tempfile::tempdir().unwrap(); + let sk = tmp.path().join(".huskies"); + fs::create_dir_all(&sk).unwrap(); + fs::write( + sk.join("bot.toml"), + r#" +enabled = true +transport = "discord" +discord_bot_token = "Bot.Token.Here" +discord_channel_ids = ["123456789012345678"] +"#, + ) + .unwrap(); + let config = BotConfig::load(tmp.path()).unwrap(); + assert_eq!(config.transport, "discord"); + assert_eq!( + config.discord_bot_token.as_deref(), + Some("Bot.Token.Here") + ); + assert_eq!( + config.discord_channel_ids, + vec!["123456789012345678"] + ); + } + + #[test] + fn load_discord_returns_none_when_missing_bot_token() { + let tmp = tempfile::tempdir().unwrap(); + let sk = tmp.path().join(".huskies"); + fs::create_dir_all(&sk).unwrap(); + fs::write( + sk.join("bot.toml"), + r#" +enabled = true +transport = "discord" +discord_channel_ids = ["123456789012345678"] +"#, + ) + .unwrap(); + assert!(BotConfig::load(tmp.path()).is_none()); + } + + #[test] + fn load_discord_returns_none_when_missing_channel_ids() { + let tmp = tempfile::tempdir().unwrap(); + let sk = tmp.path().join(".huskies"); + fs::create_dir_all(&sk).unwrap(); + fs::write( + sk.join("bot.toml"), + r#" +enabled = true +transport = "discord" +discord_bot_token = "Bot.Token.Here" +"#, + ) + .unwrap(); + assert!(BotConfig::load(tmp.path()).is_none()); + } + + #[test] + fn discord_allowed_users_defaults_to_empty_when_absent() { + let config: BotConfig = toml::from_str( + r#" +enabled = true +transport = "discord" +discord_bot_token = "Bot.Token.Here" +discord_channel_ids = ["123456789"] +"#, + ) + .unwrap(); + assert!(config.discord_allowed_users.is_empty()); + } + + #[test] + fn discord_allowed_users_deserializes_list() { + let config: BotConfig = toml::from_str( + r#" +enabled = true +transport = "discord" +discord_bot_token = "Bot.Token.Here" +discord_channel_ids = ["123456789"] +discord_allowed_users = ["111222333", "444555666"] +"#, + ) + .unwrap(); + assert_eq!( + config.discord_allowed_users, + vec!["111222333", "444555666"] + ); + } } diff --git a/server/src/chat/transport/mod.rs b/server/src/chat/transport/mod.rs index a863087d..c4c0eb6e 100644 --- a/server/src/chat/transport/mod.rs +++ b/server/src/chat/transport/mod.rs @@ -1,3 +1,4 @@ +pub mod discord; pub mod matrix; pub mod slack; pub mod whatsapp; diff --git a/server/src/main.rs b/server/src/main.rs index 138258db..59787c8b 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -339,12 +339,14 @@ async fn main() -> Result<(), std::io::Error> { // before watcher_tx is moved into AppContext. let watcher_rx_for_whatsapp = watcher_tx.subscribe(); let watcher_rx_for_slack = watcher_tx.subscribe(); + let watcher_rx_for_discord = watcher_tx.subscribe(); // Wrap perm_rx in Arc so it can be shared with both the WebSocket // handler (via AppContext) and the Matrix bot. let perm_rx = Arc::new(tokio::sync::Mutex::new(perm_rx)); let perm_rx_for_bot = Arc::clone(&perm_rx); let perm_rx_for_whatsapp = Arc::clone(&perm_rx); let perm_rx_for_slack = Arc::clone(&perm_rx); + let perm_rx_for_discord = Arc::clone(&perm_rx); // Capture project root, agents Arc, and reconciliation sender before ctx // is consumed by build_routes. @@ -441,9 +443,49 @@ async fn main() -> Result<(), std::io::Error> { }) }); + // Build Discord context if bot.toml configures transport = "discord". + let discord_ctx: Option> = startup_root + .as_ref() + .and_then(|root| chat::transport::matrix::BotConfig::load(root)) + .filter(|cfg| cfg.transport == "discord") + .map(|cfg| { + let transport = Arc::new(chat::transport::discord::DiscordTransport::new( + cfg.discord_bot_token.clone().unwrap_or_default(), + )); + let bot_name = cfg + .display_name + .clone() + .unwrap_or_else(|| "Assistant".to_string()); + let root = startup_root.clone().unwrap(); + let history = chat::transport::discord::load_discord_history(&root); + let channel_ids: std::collections::HashSet = + cfg.discord_channel_ids.iter().cloned().collect(); + let allowed_users: std::collections::HashSet = + cfg.discord_allowed_users.iter().cloned().collect(); + Arc::new(chat::transport::discord::DiscordContext { + bot_token: cfg.discord_bot_token.clone().unwrap_or_default(), + transport, + project_root: root, + agents: Arc::clone(&startup_agents), + bot_name, + bot_user_id: "discord-bot".to_string(), + ambient_rooms: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())), + history: std::sync::Arc::new(tokio::sync::Mutex::new(history)), + history_size: cfg.history_size, + channel_ids, + allowed_users, + perm_rx: perm_rx_for_discord, + pending_perm_replies: Arc::new(tokio::sync::Mutex::new( + std::collections::HashMap::new(), + )), + permission_timeout_secs: cfg.permission_timeout_secs, + }) + }); + // Build a best-effort shutdown notifier for webhook-based transports. // // • Slack: channels are fixed at startup (channel_ids from bot.toml). + // • Discord: channels are fixed at startup (channel_ids from bot.toml). // • WhatsApp: active senders are tracked at runtime in ambient_rooms. // We keep the WhatsApp context Arc so we can read the rooms at shutdown. // • Matrix: the bot task manages its own announcement via matrix_shutdown_tx. @@ -454,6 +496,13 @@ async fn main() -> Result<(), std::io::Error> { channels, ctx.bot_name.clone(), ))) + } else if let Some(ref ctx) = discord_ctx { + let channels: Vec = ctx.channel_ids.iter().cloned().collect(); + Some(Arc::new(BotShutdownNotifier::new( + Arc::clone(&ctx.transport) as Arc, + channels, + ctx.bot_name.clone(), + ))) } else { None }; @@ -496,6 +545,18 @@ async fn main() -> Result<(), std::io::Error> { notifier.notify_startup().await; }); } + if let Some(ref ctx) = discord_ctx { + let transport = Arc::clone(&ctx.transport) as Arc; + let bot_name = ctx.bot_name.clone(); + let channels: Vec = ctx.channel_ids.iter().cloned().collect(); + tokio::spawn(async move { + if channels.is_empty() { + return; + } + let notifier = crate::rebuild::BotShutdownNotifier::new(transport, channels, bot_name); + notifier.notify_startup().await; + }); + } // Watch channel: signals the Matrix bot task to send a shutdown announcement. // `None` initial value means "server is running". @@ -559,6 +620,21 @@ async fn main() -> Result<(), std::io::Error> { } else { drop(watcher_rx_for_slack); } + if let (Some(ctx), Some(root)) = (&discord_ctx, &startup_root) { + // Spawn the Discord Gateway WebSocket listener. + chat::transport::discord::gateway::spawn_gateway(Arc::clone(ctx)); + + // Spawn stage-transition notification listener for Discord. + let channel_ids: Vec = ctx.channel_ids.iter().cloned().collect(); + chat::transport::matrix::notifications::spawn_notification_listener( + Arc::clone(&ctx.transport) as Arc, + move || channel_ids.clone(), + watcher_rx_for_discord, + root.clone(), + ); + } else { + drop(watcher_rx_for_discord); + } // On startup: // 1. Reconcile any stories whose agent work was committed while the server was