huskies: merge 472_story_discord_chat_transport
This commit is contained in:
@@ -246,6 +246,7 @@ Story Kit includes a chat bot that can be connected to one messaging platform at
|
|||||||
| WhatsApp (Meta Cloud API) | `bot.toml.whatsapp-meta.example` | `/webhook/whatsapp` |
|
| WhatsApp (Meta Cloud API) | `bot.toml.whatsapp-meta.example` | `/webhook/whatsapp` |
|
||||||
| WhatsApp (Twilio) | `bot.toml.whatsapp-twilio.example` | `/webhook/whatsapp` |
|
| WhatsApp (Twilio) | `bot.toml.whatsapp-twilio.example` | `/webhook/whatsapp` |
|
||||||
| Slack | `bot.toml.slack.example` | `/webhook/slack` |
|
| Slack | `bot.toml.slack.example` | `/webhook/slack` |
|
||||||
|
| Discord | `bot.toml.discord.example` | *(uses Discord Gateway WebSocket)* |
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cp .huskies/bot.toml.matrix.example .huskies/bot.toml
|
cp .huskies/bot.toml.matrix.example .huskies/bot.toml
|
||||||
|
|||||||
@@ -0,0 +1,28 @@
|
|||||||
|
# Discord Transport
|
||||||
|
# Copy this file to bot.toml and fill in your values.
|
||||||
|
# Only one transport can be active at a time.
|
||||||
|
#
|
||||||
|
# Setup:
|
||||||
|
# 1. Create a Discord Application at discord.com/developers/applications
|
||||||
|
# 2. Go to Bot → create a bot and copy the token
|
||||||
|
# 3. Enable "Message Content Intent" under Privileged Gateway Intents
|
||||||
|
# 4. Go to OAuth2 → URL Generator, select "bot" scope with permissions:
|
||||||
|
# Send Messages, Read Message History, Manage Messages
|
||||||
|
# 5. Use the generated URL to invite the bot to your server
|
||||||
|
# 6. Right-click the channel(s) → Copy Channel ID (enable Developer Mode in settings)
|
||||||
|
|
||||||
|
enabled = true
|
||||||
|
transport = "discord"
|
||||||
|
|
||||||
|
discord_bot_token = "your-bot-token-here"
|
||||||
|
discord_channel_ids = ["123456789012345678"]
|
||||||
|
|
||||||
|
# Discord user IDs allowed to interact with the bot.
|
||||||
|
# When empty, all users in configured channels can interact.
|
||||||
|
# discord_allowed_users = ["111222333444555666"]
|
||||||
|
|
||||||
|
# Bot display name (used in formatted messages).
|
||||||
|
# display_name = "Assistant"
|
||||||
|
|
||||||
|
# Maximum conversation turns to remember per channel (default: 20).
|
||||||
|
# history_size = 20
|
||||||
Generated
+83
@@ -1096,6 +1096,21 @@ version = "0.1.5"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
|
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "foreign-types"
|
||||||
|
version = "0.3.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
|
||||||
|
dependencies = [
|
||||||
|
"foreign-types-shared",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "foreign-types-shared"
|
||||||
|
version = "0.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "form_urlencoded"
|
name = "form_urlencoded"
|
||||||
version = "1.2.2"
|
version = "1.2.2"
|
||||||
@@ -2588,6 +2603,23 @@ dependencies = [
|
|||||||
"version_check",
|
"version_check",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "native-tls"
|
||||||
|
version = "0.2.18"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"log",
|
||||||
|
"openssl",
|
||||||
|
"openssl-probe",
|
||||||
|
"openssl-sys",
|
||||||
|
"schannel",
|
||||||
|
"security-framework",
|
||||||
|
"security-framework-sys",
|
||||||
|
"tempfile",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "new_debug_unreachable"
|
name = "new_debug_unreachable"
|
||||||
version = "1.0.6"
|
version = "1.0.6"
|
||||||
@@ -2721,12 +2753,50 @@ version = "0.3.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
|
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "openssl"
|
||||||
|
version = "0.10.76"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.11.0",
|
||||||
|
"cfg-if",
|
||||||
|
"foreign-types",
|
||||||
|
"libc",
|
||||||
|
"once_cell",
|
||||||
|
"openssl-macros",
|
||||||
|
"openssl-sys",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "openssl-macros"
|
||||||
|
version = "0.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.117",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "openssl-probe"
|
name = "openssl-probe"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
|
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "openssl-sys"
|
||||||
|
version = "0.9.112"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb"
|
||||||
|
dependencies = [
|
||||||
|
"cc",
|
||||||
|
"libc",
|
||||||
|
"pkg-config",
|
||||||
|
"vcpkg",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "parking"
|
name = "parking"
|
||||||
version = "2.2.1"
|
version = "2.2.1"
|
||||||
@@ -4411,6 +4481,16 @@ dependencies = [
|
|||||||
"syn 2.0.117",
|
"syn 2.0.117",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-native-tls"
|
||||||
|
version = "0.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
|
||||||
|
dependencies = [
|
||||||
|
"native-tls",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokio-rustls"
|
name = "tokio-rustls"
|
||||||
version = "0.26.4"
|
version = "0.26.4"
|
||||||
@@ -4453,7 +4533,9 @@ checksum = "8f72a05e828585856dacd553fba484c242c46e391fb0e58917c942ee9202915c"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"log",
|
"log",
|
||||||
|
"native-tls",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-native-tls",
|
||||||
"tungstenite 0.29.0",
|
"tungstenite 0.29.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -4685,6 +4767,7 @@ dependencies = [
|
|||||||
"http",
|
"http",
|
||||||
"httparse",
|
"httparse",
|
||||||
"log",
|
"log",
|
||||||
|
"native-tls",
|
||||||
"rand 0.9.2",
|
"rand 0.9.2",
|
||||||
"sha1",
|
"sha1",
|
||||||
"thiserror 2.0.18",
|
"thiserror 2.0.18",
|
||||||
|
|||||||
+1
-1
@@ -29,7 +29,7 @@ tempfile = "3"
|
|||||||
tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync"] }
|
tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync"] }
|
||||||
toml = "1.1.0"
|
toml = "1.1.0"
|
||||||
uuid = { version = "1.22.0", features = ["v4", "serde"] }
|
uuid = { version = "1.22.0", features = ["v4", "serde"] }
|
||||||
tokio-tungstenite = "0.29.0"
|
tokio-tungstenite = { version = "0.29.0", features = ["connect", "native-tls"] }
|
||||||
walkdir = "2.5.0"
|
walkdir = "2.5.0"
|
||||||
filetime = "0.2"
|
filetime = "0.2"
|
||||||
matrix-sdk = { version = "0.16.0", default-features = false, features = [
|
matrix-sdk = { version = "0.16.0", default-features = false, features = [
|
||||||
|
|||||||
+1
-1
@@ -34,6 +34,7 @@ walkdir = { workspace = true }
|
|||||||
matrix-sdk = { workspace = true }
|
matrix-sdk = { workspace = true }
|
||||||
pulldown-cmark = { workspace = true }
|
pulldown-cmark = { workspace = true }
|
||||||
regex = { workspace = true }
|
regex = { workspace = true }
|
||||||
|
tokio-tungstenite = { workspace = true }
|
||||||
|
|
||||||
# Force bundled SQLite so static musl builds don't need a system libsqlite3
|
# Force bundled SQLite so static musl builds don't need a system libsqlite3
|
||||||
libsqlite3-sys = { version = "0.35.0", features = ["bundled"] }
|
libsqlite3-sys = { version = "0.35.0", features = ["bundled"] }
|
||||||
@@ -44,6 +45,5 @@ libc = { workspace = true }
|
|||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tempfile = { workspace = true }
|
tempfile = { workspace = true }
|
||||||
tokio-tungstenite = { workspace = true }
|
|
||||||
mockito = "1"
|
mockito = "1"
|
||||||
filetime = { workspace = true }
|
filetime = { workspace = true }
|
||||||
|
|||||||
@@ -0,0 +1,641 @@
|
|||||||
|
//! Discord incoming message dispatch and command handling.
|
||||||
|
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use tokio::sync::{Mutex as TokioMutex, oneshot};
|
||||||
|
|
||||||
|
use crate::agents::AgentPool;
|
||||||
|
use crate::chat::transport::matrix::{ConversationEntry, ConversationRole, RoomConversation};
|
||||||
|
use crate::chat::util::is_permission_approval;
|
||||||
|
use crate::chat::ChatTransport;
|
||||||
|
use crate::http::context::{PermissionDecision, PermissionForward};
|
||||||
|
use crate::slog;
|
||||||
|
|
||||||
|
use super::format::markdown_to_discord;
|
||||||
|
use super::history::{DiscordConversationHistory, save_discord_history};
|
||||||
|
use super::meta::DiscordTransport;
|
||||||
|
|
||||||
|
// ── Shared context ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Shared context for the Discord bot, used by both the Gateway listener
|
||||||
|
/// and any future webhook handlers.
|
||||||
|
pub struct DiscordContext {
|
||||||
|
pub bot_token: String,
|
||||||
|
pub transport: Arc<DiscordTransport>,
|
||||||
|
pub project_root: PathBuf,
|
||||||
|
pub agents: Arc<AgentPool>,
|
||||||
|
pub bot_name: String,
|
||||||
|
/// The bot's Discord user ID (set dynamically from READY event, but
|
||||||
|
/// also stored here for command dispatch).
|
||||||
|
pub bot_user_id: String,
|
||||||
|
pub ambient_rooms: Arc<Mutex<HashSet<String>>>,
|
||||||
|
/// Per-channel conversation history for LLM passthrough.
|
||||||
|
pub history: DiscordConversationHistory,
|
||||||
|
/// Maximum number of conversation entries to keep per channel.
|
||||||
|
pub history_size: usize,
|
||||||
|
/// Allowed channel IDs (messages from other channels are ignored).
|
||||||
|
pub channel_ids: HashSet<String>,
|
||||||
|
/// Allowed Discord user IDs. When non-empty, only listed users can
|
||||||
|
/// interact with the bot. When empty, all users are allowed.
|
||||||
|
pub allowed_users: HashSet<String>,
|
||||||
|
/// Permission requests from the MCP `prompt_permission` tool arrive here.
|
||||||
|
pub perm_rx: Arc<TokioMutex<tokio::sync::mpsc::UnboundedReceiver<PermissionForward>>>,
|
||||||
|
/// Pending permission replies keyed by channel ID.
|
||||||
|
pub pending_perm_replies:
|
||||||
|
Arc<TokioMutex<HashMap<String, oneshot::Sender<PermissionDecision>>>>,
|
||||||
|
/// Seconds before an unanswered permission prompt is auto-denied.
|
||||||
|
pub permission_timeout_secs: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Incoming message dispatch ───────────────────────────────────────────
|
||||||
|
|
||||||
|
pub(super) async fn handle_incoming_message(
|
||||||
|
ctx: &DiscordContext,
|
||||||
|
channel: &str,
|
||||||
|
user: &str,
|
||||||
|
message: &str,
|
||||||
|
) {
|
||||||
|
use crate::chat::commands::{CommandDispatch, try_handle_command};
|
||||||
|
|
||||||
|
// If there is a pending permission prompt for this channel, interpret the
|
||||||
|
// message as a yes/no response.
|
||||||
|
{
|
||||||
|
let mut pending = ctx.pending_perm_replies.lock().await;
|
||||||
|
if let Some(tx) = pending.remove(channel) {
|
||||||
|
let decision = if is_permission_approval(message) {
|
||||||
|
PermissionDecision::Approve
|
||||||
|
} else {
|
||||||
|
PermissionDecision::Deny
|
||||||
|
};
|
||||||
|
let _ = tx.send(decision);
|
||||||
|
let confirmation = if decision == PermissionDecision::Approve {
|
||||||
|
"Permission approved."
|
||||||
|
} else {
|
||||||
|
"Permission denied."
|
||||||
|
};
|
||||||
|
let formatted = markdown_to_discord(confirmation);
|
||||||
|
let _ = ctx.transport.send_message(channel, &formatted, "").await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let dispatch = CommandDispatch {
|
||||||
|
bot_name: &ctx.bot_name,
|
||||||
|
bot_user_id: &ctx.bot_user_id,
|
||||||
|
project_root: &ctx.project_root,
|
||||||
|
agents: &ctx.agents,
|
||||||
|
ambient_rooms: &ctx.ambient_rooms,
|
||||||
|
room_id: channel,
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(response) = try_handle_command(&dispatch, message) {
|
||||||
|
slog!("[discord] Sending command response to {channel}");
|
||||||
|
let response = markdown_to_discord(&response);
|
||||||
|
if let Err(e) = ctx.transport.send_message(channel, &response, "").await {
|
||||||
|
slog!("[discord] Failed to send reply to {channel}: {e}");
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for async commands (htop, delete).
|
||||||
|
if let Some(htop_cmd) = crate::chat::transport::matrix::htop::extract_htop_command(
|
||||||
|
message,
|
||||||
|
&ctx.bot_name,
|
||||||
|
&ctx.bot_user_id,
|
||||||
|
) {
|
||||||
|
use crate::chat::transport::matrix::htop::HtopCommand;
|
||||||
|
slog!("[discord] Handling htop command from {user} in {channel}");
|
||||||
|
match htop_cmd {
|
||||||
|
HtopCommand::Stop => {
|
||||||
|
let _ = ctx
|
||||||
|
.transport
|
||||||
|
.send_message(channel, "htop stopped.", "")
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
HtopCommand::Start { duration_secs } => {
|
||||||
|
let snapshot = crate::chat::transport::matrix::htop::build_htop_message(
|
||||||
|
&ctx.agents,
|
||||||
|
0,
|
||||||
|
duration_secs,
|
||||||
|
);
|
||||||
|
let snapshot = markdown_to_discord(&snapshot);
|
||||||
|
let msg_id = match ctx.transport.send_message(channel, &snapshot, "").await {
|
||||||
|
Ok(id) => id,
|
||||||
|
Err(e) => {
|
||||||
|
slog!("[discord] Failed to send htop message: {e}");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let transport = Arc::clone(&ctx.transport) as Arc<dyn ChatTransport>;
|
||||||
|
let agents = Arc::clone(&ctx.agents);
|
||||||
|
let ch = channel.to_string();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let interval = std::time::Duration::from_secs(2);
|
||||||
|
let total_ticks = (duration_secs as usize) / 2;
|
||||||
|
for tick in 1..=total_ticks {
|
||||||
|
tokio::time::sleep(interval).await;
|
||||||
|
let updated =
|
||||||
|
crate::chat::transport::matrix::htop::build_htop_message(
|
||||||
|
&agents,
|
||||||
|
(tick * 2) as u32,
|
||||||
|
duration_secs,
|
||||||
|
);
|
||||||
|
let updated = markdown_to_discord(&updated);
|
||||||
|
if let Err(e) =
|
||||||
|
transport.edit_message(&ch, &msg_id, &updated, "").await
|
||||||
|
{
|
||||||
|
slog!("[discord] Failed to edit htop message: {e}");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(del_cmd) = crate::chat::transport::matrix::delete::extract_delete_command(
|
||||||
|
message,
|
||||||
|
&ctx.bot_name,
|
||||||
|
&ctx.bot_user_id,
|
||||||
|
) {
|
||||||
|
let response = match del_cmd {
|
||||||
|
crate::chat::transport::matrix::delete::DeleteCommand::Delete { story_number } => {
|
||||||
|
slog!("[discord] Handling delete command from {user}: story {story_number}");
|
||||||
|
crate::chat::transport::matrix::delete::handle_delete(
|
||||||
|
&ctx.bot_name,
|
||||||
|
&story_number,
|
||||||
|
&ctx.project_root,
|
||||||
|
&ctx.agents,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
crate::chat::transport::matrix::delete::DeleteCommand::BadArgs => {
|
||||||
|
format!("Usage: `{} delete <number>`", ctx.bot_name)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let response = markdown_to_discord(&response);
|
||||||
|
let _ = ctx.transport.send_message(channel, &response, "").await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if crate::chat::transport::matrix::rebuild::extract_rebuild_command(
|
||||||
|
message,
|
||||||
|
&ctx.bot_name,
|
||||||
|
&ctx.bot_user_id,
|
||||||
|
)
|
||||||
|
.is_some()
|
||||||
|
{
|
||||||
|
slog!("[discord] Handling rebuild command from {user} in {channel}");
|
||||||
|
let ack = "Rebuilding server… this may take a moment.";
|
||||||
|
let _ = ctx.transport.send_message(channel, ack, "").await;
|
||||||
|
let response = crate::chat::transport::matrix::rebuild::handle_rebuild(
|
||||||
|
&ctx.bot_name,
|
||||||
|
&ctx.project_root,
|
||||||
|
&ctx.agents,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let response = markdown_to_discord(&response);
|
||||||
|
let _ = ctx.transport.send_message(channel, &response, "").await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(rmtree_cmd) = crate::chat::transport::matrix::rmtree::extract_rmtree_command(
|
||||||
|
message,
|
||||||
|
&ctx.bot_name,
|
||||||
|
&ctx.bot_user_id,
|
||||||
|
) {
|
||||||
|
let response = match rmtree_cmd {
|
||||||
|
crate::chat::transport::matrix::rmtree::RmtreeCommand::Rmtree { story_number } => {
|
||||||
|
slog!(
|
||||||
|
"[discord] Handling rmtree command from {user} in {channel}: story {story_number}"
|
||||||
|
);
|
||||||
|
crate::chat::transport::matrix::rmtree::handle_rmtree(
|
||||||
|
&ctx.bot_name,
|
||||||
|
&story_number,
|
||||||
|
&ctx.project_root,
|
||||||
|
&ctx.agents,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
crate::chat::transport::matrix::rmtree::RmtreeCommand::BadArgs => {
|
||||||
|
format!("Usage: `{} rmtree <number>`", ctx.bot_name)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let response = markdown_to_discord(&response);
|
||||||
|
let _ = ctx.transport.send_message(channel, &response, "").await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if crate::chat::transport::matrix::reset::extract_reset_command(
|
||||||
|
message,
|
||||||
|
&ctx.bot_name,
|
||||||
|
&ctx.bot_user_id,
|
||||||
|
)
|
||||||
|
.is_some()
|
||||||
|
{
|
||||||
|
slog!("[discord] Handling reset command from {user} in {channel}");
|
||||||
|
{
|
||||||
|
let mut guard = ctx.history.lock().await;
|
||||||
|
let conv = guard
|
||||||
|
.entry(channel.to_string())
|
||||||
|
.or_insert_with(RoomConversation::default);
|
||||||
|
conv.session_id = None;
|
||||||
|
conv.entries.clear();
|
||||||
|
save_discord_history(&ctx.project_root, &guard);
|
||||||
|
}
|
||||||
|
let _ = ctx
|
||||||
|
.transport
|
||||||
|
.send_message(channel, "Session cleared.", "")
|
||||||
|
.await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(start_cmd) = crate::chat::transport::matrix::start::extract_start_command(
|
||||||
|
message,
|
||||||
|
&ctx.bot_name,
|
||||||
|
&ctx.bot_user_id,
|
||||||
|
) {
|
||||||
|
let response = match start_cmd {
|
||||||
|
crate::chat::transport::matrix::start::StartCommand::Start {
|
||||||
|
story_number,
|
||||||
|
agent_hint,
|
||||||
|
} => {
|
||||||
|
slog!(
|
||||||
|
"[discord] Handling start command from {user} in {channel}: story {story_number}"
|
||||||
|
);
|
||||||
|
crate::chat::transport::matrix::start::handle_start(
|
||||||
|
&ctx.bot_name,
|
||||||
|
&story_number,
|
||||||
|
agent_hint.as_deref(),
|
||||||
|
&ctx.project_root,
|
||||||
|
&ctx.agents,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
crate::chat::transport::matrix::start::StartCommand::BadArgs => {
|
||||||
|
format!("Usage: `{} start <number>`", ctx.bot_name)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let response = markdown_to_discord(&response);
|
||||||
|
let _ = ctx.transport.send_message(channel, &response, "").await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(assign_cmd) = crate::chat::transport::matrix::assign::extract_assign_command(
|
||||||
|
message,
|
||||||
|
&ctx.bot_name,
|
||||||
|
&ctx.bot_user_id,
|
||||||
|
) {
|
||||||
|
let response = match assign_cmd {
|
||||||
|
crate::chat::transport::matrix::assign::AssignCommand::Assign {
|
||||||
|
story_number,
|
||||||
|
model,
|
||||||
|
} => {
|
||||||
|
slog!(
|
||||||
|
"[discord] Handling assign command from {user} in {channel}: story {story_number} model {model}"
|
||||||
|
);
|
||||||
|
crate::chat::transport::matrix::assign::handle_assign(
|
||||||
|
&ctx.bot_name,
|
||||||
|
&story_number,
|
||||||
|
&model,
|
||||||
|
&ctx.project_root,
|
||||||
|
&ctx.agents,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
crate::chat::transport::matrix::assign::AssignCommand::BadArgs => {
|
||||||
|
format!("Usage: `{} assign <number> <model>`", ctx.bot_name)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let response = markdown_to_discord(&response);
|
||||||
|
let _ = ctx.transport.send_message(channel, &response, "").await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// No command matched — forward to LLM for conversational response.
|
||||||
|
slog!("[discord] No command matched, forwarding to LLM for {user} in {channel}");
|
||||||
|
handle_llm_message(ctx, channel, user, message).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Forward a message to Claude Code and send the response back via Discord.
|
||||||
|
async fn handle_llm_message(
|
||||||
|
ctx: &DiscordContext,
|
||||||
|
channel: &str,
|
||||||
|
user: &str,
|
||||||
|
user_message: &str,
|
||||||
|
) {
|
||||||
|
use crate::chat::util::drain_complete_paragraphs;
|
||||||
|
use crate::llm::providers::claude_code::{ClaudeCodeProvider, ClaudeCodeResult};
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use tokio::sync::watch;
|
||||||
|
|
||||||
|
// Look up existing session ID for this channel.
|
||||||
|
let resume_session_id: Option<String> = {
|
||||||
|
let guard = ctx.history.lock().await;
|
||||||
|
guard
|
||||||
|
.get(channel)
|
||||||
|
.and_then(|conv| conv.session_id.clone())
|
||||||
|
};
|
||||||
|
|
||||||
|
let bot_name = &ctx.bot_name;
|
||||||
|
let prompt = format!(
|
||||||
|
"[Your name is {bot_name}. Refer to yourself as {bot_name}, not Claude.]\n\n{user}: {user_message}"
|
||||||
|
);
|
||||||
|
|
||||||
|
let provider = ClaudeCodeProvider::new();
|
||||||
|
let (_cancel_tx, mut cancel_rx) = watch::channel(false);
|
||||||
|
|
||||||
|
// Channel for sending complete chunks to the Discord posting task.
|
||||||
|
let (msg_tx, mut msg_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
|
||||||
|
let msg_tx_for_callback = msg_tx.clone();
|
||||||
|
|
||||||
|
// Spawn a task to post messages as they arrive.
|
||||||
|
let post_transport = Arc::clone(&ctx.transport) as Arc<dyn ChatTransport>;
|
||||||
|
let post_channel = channel.to_string();
|
||||||
|
let post_task = tokio::spawn(async move {
|
||||||
|
while let Some(chunk) = msg_rx.recv().await {
|
||||||
|
let formatted = markdown_to_discord(&chunk);
|
||||||
|
let _ = post_transport
|
||||||
|
.send_message(&post_channel, &formatted, "")
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Shared buffer between the sync token callback and the async scope.
|
||||||
|
let buffer = Arc::new(std::sync::Mutex::new(String::new()));
|
||||||
|
let buffer_for_callback = Arc::clone(&buffer);
|
||||||
|
let sent_any_chunk = Arc::new(AtomicBool::new(false));
|
||||||
|
let sent_any_chunk_for_callback = Arc::clone(&sent_any_chunk);
|
||||||
|
|
||||||
|
let project_root_str = ctx.project_root.to_string_lossy().to_string();
|
||||||
|
let chat_fut = provider.chat_stream(
|
||||||
|
&prompt,
|
||||||
|
&project_root_str,
|
||||||
|
resume_session_id.as_deref(),
|
||||||
|
None,
|
||||||
|
&mut cancel_rx,
|
||||||
|
move |token| {
|
||||||
|
let mut buf = buffer_for_callback.lock().unwrap();
|
||||||
|
buf.push_str(token);
|
||||||
|
let paragraphs = drain_complete_paragraphs(&mut buf);
|
||||||
|
for chunk in paragraphs {
|
||||||
|
sent_any_chunk_for_callback.store(true, Ordering::Relaxed);
|
||||||
|
let _ = msg_tx_for_callback.send(chunk);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|_thinking| {},
|
||||||
|
|_activity| {},
|
||||||
|
);
|
||||||
|
tokio::pin!(chat_fut);
|
||||||
|
|
||||||
|
// Lock the permission receiver for the duration of this chat session.
|
||||||
|
let mut perm_rx_guard = ctx.perm_rx.lock().await;
|
||||||
|
|
||||||
|
let result = loop {
|
||||||
|
tokio::select! {
|
||||||
|
r = &mut chat_fut => break r,
|
||||||
|
|
||||||
|
Some(perm_fwd) = perm_rx_guard.recv() => {
|
||||||
|
let prompt_msg = format!(
|
||||||
|
"**Permission Request**\n\nTool: `{}`\n```json\n{}\n```\n\nReply **yes** to approve or **no** to deny.",
|
||||||
|
perm_fwd.tool_name,
|
||||||
|
serde_json::to_string_pretty(&perm_fwd.tool_input)
|
||||||
|
.unwrap_or_else(|_| perm_fwd.tool_input.to_string()),
|
||||||
|
);
|
||||||
|
let formatted = markdown_to_discord(&prompt_msg);
|
||||||
|
let _ = ctx.transport.send_message(channel, &formatted, "").await;
|
||||||
|
|
||||||
|
ctx.pending_perm_replies
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.insert(channel.to_string(), perm_fwd.response_tx);
|
||||||
|
|
||||||
|
// Spawn a timeout task: auto-deny if the user does not respond.
|
||||||
|
let pending = Arc::clone(&ctx.pending_perm_replies);
|
||||||
|
let timeout_channel = channel.to_string();
|
||||||
|
let timeout_transport = Arc::clone(&ctx.transport) as Arc<dyn ChatTransport>;
|
||||||
|
let timeout_secs = ctx.permission_timeout_secs;
|
||||||
|
tokio::spawn(async move {
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(timeout_secs)).await;
|
||||||
|
if let Some(tx) = pending.lock().await.remove(&timeout_channel) {
|
||||||
|
let _ = tx.send(PermissionDecision::Deny);
|
||||||
|
let msg = "Permission request timed out — denied (fail-closed).";
|
||||||
|
let _ = timeout_transport.send_message(&timeout_channel, msg, "").await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
drop(perm_rx_guard);
|
||||||
|
|
||||||
|
// Flush remaining text.
|
||||||
|
let remaining = buffer.lock().unwrap().trim().to_string();
|
||||||
|
let did_send_any = sent_any_chunk.load(Ordering::Relaxed);
|
||||||
|
|
||||||
|
let (assistant_reply, new_session_id) = match result {
|
||||||
|
Ok(ClaudeCodeResult {
|
||||||
|
messages,
|
||||||
|
session_id,
|
||||||
|
}) => {
|
||||||
|
let reply = if !remaining.is_empty() {
|
||||||
|
let _ = msg_tx.send(remaining.clone());
|
||||||
|
remaining
|
||||||
|
} else if !did_send_any {
|
||||||
|
let last_text = messages
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.find(|m| {
|
||||||
|
m.role == crate::llm::types::Role::Assistant && !m.content.is_empty()
|
||||||
|
})
|
||||||
|
.map(|m| m.content.clone())
|
||||||
|
.unwrap_or_default();
|
||||||
|
if !last_text.is_empty() {
|
||||||
|
let _ = msg_tx.send(last_text.clone());
|
||||||
|
}
|
||||||
|
last_text
|
||||||
|
} else {
|
||||||
|
remaining
|
||||||
|
};
|
||||||
|
slog!("[discord] session_id from chat_stream: {:?}", session_id);
|
||||||
|
(reply, session_id)
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
slog!("[discord] LLM error: {e}");
|
||||||
|
let err_msg = format!("Error processing your request: {e}");
|
||||||
|
let _ = msg_tx.send(err_msg.clone());
|
||||||
|
(err_msg, None)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Signal the posting task to finish and wait for it.
|
||||||
|
drop(msg_tx);
|
||||||
|
let _ = post_task.await;
|
||||||
|
|
||||||
|
// Record this exchange in conversation history.
|
||||||
|
if !assistant_reply.starts_with("Error processing") {
|
||||||
|
let mut guard = ctx.history.lock().await;
|
||||||
|
let conv = guard.entry(channel.to_string()).or_default();
|
||||||
|
|
||||||
|
if new_session_id.is_some() {
|
||||||
|
conv.session_id = new_session_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
conv.entries.push(ConversationEntry {
|
||||||
|
role: ConversationRole::User,
|
||||||
|
sender: user.to_string(),
|
||||||
|
content: user_message.to_string(),
|
||||||
|
});
|
||||||
|
conv.entries.push(ConversationEntry {
|
||||||
|
role: ConversationRole::Assistant,
|
||||||
|
sender: String::new(),
|
||||||
|
content: assistant_reply,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Trim to configured maximum.
|
||||||
|
if conv.entries.len() > ctx.history_size {
|
||||||
|
let excess = conv.entries.len() - ctx.history_size;
|
||||||
|
conv.entries.drain(..excess);
|
||||||
|
}
|
||||||
|
|
||||||
|
save_discord_history(&ctx.project_root, &guard);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn test_agents() -> Arc<crate::agents::AgentPool> {
|
||||||
|
Arc::new(crate::agents::AgentPool::new_test(3000))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn test_ambient_rooms() -> Arc<Mutex<HashSet<String>>> {
|
||||||
|
Arc::new(Mutex::new(HashSet::new()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn command_dispatches_through_command_registry() {
|
||||||
|
use crate::chat::commands::{CommandDispatch, try_handle_command};
|
||||||
|
|
||||||
|
let agents = test_agents();
|
||||||
|
let ambient_rooms = test_ambient_rooms();
|
||||||
|
let room_id = "123456789".to_string();
|
||||||
|
|
||||||
|
let bot_name = "Huskies";
|
||||||
|
let synthetic = format!("{bot_name} status");
|
||||||
|
|
||||||
|
let dispatch = CommandDispatch {
|
||||||
|
bot_name,
|
||||||
|
bot_user_id: "discord-bot",
|
||||||
|
project_root: std::path::Path::new("/tmp"),
|
||||||
|
agents: &agents,
|
||||||
|
ambient_rooms: &ambient_rooms,
|
||||||
|
room_id: &room_id,
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = try_handle_command(&dispatch, &synthetic);
|
||||||
|
assert!(
|
||||||
|
result.is_some(),
|
||||||
|
"status command should produce output via registry"
|
||||||
|
);
|
||||||
|
assert!(result.unwrap().contains("Pipeline Status"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rebuild_command_extracted_from_discord_message() {
|
||||||
|
let result = crate::chat::transport::matrix::rebuild::extract_rebuild_command(
|
||||||
|
"Huskies rebuild",
|
||||||
|
"Huskies",
|
||||||
|
"discord-bot",
|
||||||
|
);
|
||||||
|
assert!(result.is_some(), "'Huskies rebuild' should be recognised");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn reset_command_extracted_from_discord_message() {
|
||||||
|
let result = crate::chat::transport::matrix::reset::extract_reset_command(
|
||||||
|
"Huskies reset",
|
||||||
|
"Huskies",
|
||||||
|
"discord-bot",
|
||||||
|
);
|
||||||
|
assert!(result.is_some(), "'Huskies reset' should be recognised");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn start_command_extracted_from_discord_message() {
|
||||||
|
let result = crate::chat::transport::matrix::start::extract_start_command(
|
||||||
|
"start 42",
|
||||||
|
"Huskies",
|
||||||
|
"discord-bot",
|
||||||
|
);
|
||||||
|
assert!(result.is_some(), "plain 'start 42' should be recognised");
|
||||||
|
assert_eq!(
|
||||||
|
result,
|
||||||
|
Some(crate::chat::transport::matrix::start::StartCommand::Start {
|
||||||
|
story_number: "42".to_string(),
|
||||||
|
agent_hint: None,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn assign_command_extracted_from_discord_message() {
|
||||||
|
let result = crate::chat::transport::matrix::assign::extract_assign_command(
|
||||||
|
"assign 42 opus",
|
||||||
|
"Huskies",
|
||||||
|
"discord-bot",
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
matches!(
|
||||||
|
result,
|
||||||
|
Some(crate::chat::transport::matrix::assign::AssignCommand::Assign { .. })
|
||||||
|
),
|
||||||
|
"plain 'assign 42 opus' should be recognised on Discord"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reset_command_clears_discord_session() {
|
||||||
|
use crate::chat::transport::matrix::RoomConversation;
|
||||||
|
|
||||||
|
let channel = "123456789";
|
||||||
|
let history: DiscordConversationHistory = Arc::new(TokioMutex::new({
|
||||||
|
let mut m = HashMap::new();
|
||||||
|
m.insert(
|
||||||
|
channel.to_string(),
|
||||||
|
RoomConversation {
|
||||||
|
session_id: Some("old-session".to_string()),
|
||||||
|
entries: vec![ConversationEntry {
|
||||||
|
role: ConversationRole::User,
|
||||||
|
sender: "user123".to_string(),
|
||||||
|
content: "previous message".to_string(),
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
);
|
||||||
|
m
|
||||||
|
}));
|
||||||
|
|
||||||
|
let tmp = tempfile::tempdir().unwrap();
|
||||||
|
let sk = tmp.path().join(".huskies");
|
||||||
|
std::fs::create_dir_all(&sk).unwrap();
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut guard = history.lock().await;
|
||||||
|
let conv = guard
|
||||||
|
.entry(channel.to_string())
|
||||||
|
.or_insert_with(RoomConversation::default);
|
||||||
|
conv.session_id = None;
|
||||||
|
conv.entries.clear();
|
||||||
|
save_discord_history(tmp.path(), &guard);
|
||||||
|
}
|
||||||
|
|
||||||
|
let guard = history.lock().await;
|
||||||
|
let conv = guard.get(channel).unwrap();
|
||||||
|
assert!(conv.session_id.is_none(), "session_id should be cleared");
|
||||||
|
assert!(conv.entries.is_empty(), "entries should be cleared");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
//! Markdown to Discord format conversion.
|
||||||
|
//!
|
||||||
|
//! Discord supports standard Markdown natively, so unlike Slack we only need
|
||||||
|
//! minimal adjustments (e.g. truncating messages that exceed Discord's limit).
|
||||||
|
|
||||||
|
/// Convert Markdown text to Discord-compatible format.
|
||||||
|
///
|
||||||
|
/// Discord supports most standard Markdown:
|
||||||
|
/// - `**bold**`, `*italic*`, `~~strikethrough~~`
|
||||||
|
/// - `` `code` `` and ````code blocks````
|
||||||
|
/// - `> blockquotes`
|
||||||
|
/// - `[text](url)` links
|
||||||
|
///
|
||||||
|
/// The main adjustment is ensuring messages stay within Discord's 2000-char
|
||||||
|
/// limit. Actual truncation is handled by the transport layer; this function
|
||||||
|
/// performs any lightweight text normalization needed.
|
||||||
|
pub fn markdown_to_discord(text: &str) -> String {
|
||||||
|
use crate::chat::util::normalize_line_breaks;
|
||||||
|
normalize_line_breaks(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn discord_plain_text_unchanged() {
|
||||||
|
let plain = "Hello, this is a plain message.";
|
||||||
|
assert_eq!(markdown_to_discord(plain), plain);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn discord_bold_preserved() {
|
||||||
|
assert_eq!(markdown_to_discord("**bold text**"), "**bold text**");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn discord_code_block_preserved() {
|
||||||
|
let input = "```rust\nlet x = 1;\n```";
|
||||||
|
assert_eq!(markdown_to_discord(input), input);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn discord_empty_string_unchanged() {
|
||||||
|
assert_eq!(markdown_to_discord(""), "");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn discord_link_preserved() {
|
||||||
|
let input = "[click here](https://example.com)";
|
||||||
|
assert_eq!(markdown_to_discord(input), input);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,507 @@
|
|||||||
|
//! Minimal Discord Gateway WebSocket client.
|
||||||
|
//!
|
||||||
|
//! Connects to the Discord Gateway, authenticates with a bot token, maintains
|
||||||
|
//! the heartbeat keepalive, and dispatches incoming `MESSAGE_CREATE` events to
|
||||||
|
//! the command handler.
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use futures::{SinkExt, StreamExt};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tokio_tungstenite::tungstenite::Message as WsMessage;
|
||||||
|
|
||||||
|
use crate::slog;
|
||||||
|
|
||||||
|
use super::commands::{self, DiscordContext};
|
||||||
|
|
||||||
|
// ── Gateway opcodes ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const OP_DISPATCH: u8 = 0;
|
||||||
|
const OP_HEARTBEAT: u8 = 1;
|
||||||
|
const OP_IDENTIFY: u8 = 2;
|
||||||
|
const OP_RECONNECT: u8 = 7;
|
||||||
|
const OP_INVALID_SESSION: u8 = 9;
|
||||||
|
const OP_HELLO: u8 = 10;
|
||||||
|
const OP_HEARTBEAT_ACK: u8 = 11;
|
||||||
|
|
||||||
|
// ── Gateway intents ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// GUILDS | GUILD_MESSAGES | MESSAGE_CONTENT (privileged)
|
||||||
|
const GATEWAY_INTENTS: u64 = (1 << 0) | (1 << 9) | (1 << 15);
|
||||||
|
|
||||||
|
// ── Gateway payload types ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct GatewayPayload {
|
||||||
|
op: u8,
|
||||||
|
#[serde(default)]
|
||||||
|
d: Option<serde_json::Value>,
|
||||||
|
#[serde(default)]
|
||||||
|
s: Option<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
t: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct HelloData {
|
||||||
|
heartbeat_interval: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
pub(super) struct MessageCreateData {
|
||||||
|
#[serde(default)]
|
||||||
|
pub channel_id: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub content: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub author: Option<MessageAuthor>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub mentions: Vec<MentionUser>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
pub(super) struct MessageAuthor {
|
||||||
|
pub id: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub bot: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
pub(super) struct MentionUser {
|
||||||
|
pub id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct ReadyData {
|
||||||
|
user: ReadyUser,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct ReadyUser {
|
||||||
|
id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Identify payload ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct IdentifyPayload {
|
||||||
|
op: u8,
|
||||||
|
d: IdentifyData,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct IdentifyData {
|
||||||
|
token: String,
|
||||||
|
intents: u64,
|
||||||
|
properties: IdentifyProperties,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct IdentifyProperties {
|
||||||
|
os: &'static str,
|
||||||
|
browser: &'static str,
|
||||||
|
device: &'static str,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct HeartbeatPayload {
|
||||||
|
op: u8,
|
||||||
|
d: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Gateway URL ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const GATEWAY_URL: &str = "wss://gateway.discord.gg/?v=10&encoding=json";
|
||||||
|
|
||||||
|
// ── Public entry point ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Spawn a background task that connects to the Discord Gateway, maintains
|
||||||
|
/// the heartbeat, and dispatches incoming messages.
|
||||||
|
pub fn spawn_gateway(ctx: Arc<DiscordContext>) {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
if let Err(e) = run_gateway(Arc::clone(&ctx)).await {
|
||||||
|
slog!("[discord] Gateway connection ended: {e}");
|
||||||
|
}
|
||||||
|
// Wait before reconnecting.
|
||||||
|
slog!("[discord] Reconnecting in 5 seconds…");
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Connect to the Discord Gateway and process events until the connection
|
||||||
|
/// closes or an unrecoverable error occurs.
|
||||||
|
async fn run_gateway(ctx: Arc<DiscordContext>) -> Result<(), String> {
|
||||||
|
slog!("[discord] Connecting to Discord Gateway…");
|
||||||
|
|
||||||
|
let (ws_stream, _) = tokio_tungstenite::connect_async(GATEWAY_URL)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Gateway connect failed: {e}"))?;
|
||||||
|
|
||||||
|
slog!("[discord] Gateway WebSocket connected");
|
||||||
|
|
||||||
|
let (mut write, mut read) = ws_stream.split();
|
||||||
|
|
||||||
|
// Wait for Hello (op 10) to get the heartbeat interval.
|
||||||
|
let hello = read
|
||||||
|
.next()
|
||||||
|
.await
|
||||||
|
.ok_or("Gateway closed before Hello")?
|
||||||
|
.map_err(|e| format!("Gateway read error: {e}"))?;
|
||||||
|
|
||||||
|
let hello_payload: GatewayPayload =
|
||||||
|
parse_ws_message(&hello).ok_or("Failed to parse Hello")?;
|
||||||
|
|
||||||
|
if hello_payload.op != OP_HELLO {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected Hello (op 10), got op {}",
|
||||||
|
hello_payload.op
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let hello_data: HelloData =
|
||||||
|
serde_json::from_value(hello_payload.d.ok_or("Hello missing data")?)
|
||||||
|
.map_err(|e| format!("Failed to parse Hello data: {e}"))?;
|
||||||
|
|
||||||
|
let heartbeat_interval =
|
||||||
|
std::time::Duration::from_millis(hello_data.heartbeat_interval);
|
||||||
|
slog!(
|
||||||
|
"[discord] Heartbeat interval: {}ms",
|
||||||
|
hello_data.heartbeat_interval
|
||||||
|
);
|
||||||
|
|
||||||
|
// Send Identify.
|
||||||
|
let identify = IdentifyPayload {
|
||||||
|
op: OP_IDENTIFY,
|
||||||
|
d: IdentifyData {
|
||||||
|
token: ctx.bot_token.clone(),
|
||||||
|
intents: GATEWAY_INTENTS,
|
||||||
|
properties: IdentifyProperties {
|
||||||
|
os: "linux",
|
||||||
|
browser: "huskies",
|
||||||
|
device: "huskies",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let identify_json = serde_json::to_string(&identify)
|
||||||
|
.map_err(|e| format!("Failed to serialise Identify: {e}"))?;
|
||||||
|
write
|
||||||
|
.send(WsMessage::Text(identify_json.into()))
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to send Identify: {e}"))?;
|
||||||
|
|
||||||
|
slog!("[discord] Identify sent, waiting for Ready…");
|
||||||
|
|
||||||
|
// Track sequence number for heartbeats.
|
||||||
|
let sequence = Arc::new(std::sync::atomic::AtomicU64::new(0));
|
||||||
|
let seq_for_heartbeat = Arc::clone(&sequence);
|
||||||
|
|
||||||
|
// Share the write half between the heartbeat task and the main event loop.
|
||||||
|
let write = Arc::new(tokio::sync::Mutex::new(write));
|
||||||
|
let write_for_heartbeat = Arc::clone(&write);
|
||||||
|
|
||||||
|
// Spawn heartbeat task.
|
||||||
|
let heartbeat_handle = tokio::spawn(async move {
|
||||||
|
// Jitter: wait a random fraction of the interval before the first beat.
|
||||||
|
let jitter = heartbeat_interval.mul_f64(rand_fraction());
|
||||||
|
tokio::time::sleep(jitter).await;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let seq = seq_for_heartbeat.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
|
let seq_val = if seq == 0 { None } else { Some(seq) };
|
||||||
|
let hb = HeartbeatPayload {
|
||||||
|
op: OP_HEARTBEAT,
|
||||||
|
d: seq_val,
|
||||||
|
};
|
||||||
|
if let Ok(json) = serde_json::to_string(&hb) {
|
||||||
|
let mut w = write_for_heartbeat.lock().await;
|
||||||
|
if w.send(WsMessage::Text(json.into())).await.is_err() {
|
||||||
|
slog!("[discord] Heartbeat send failed, stopping heartbeat");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tokio::time::sleep(heartbeat_interval).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Track our own bot user ID (set from READY event).
|
||||||
|
let mut bot_user_id: Option<String> = None;
|
||||||
|
|
||||||
|
// Main event loop.
|
||||||
|
while let Some(msg_result) = read.next().await {
|
||||||
|
let msg = match msg_result {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(e) => {
|
||||||
|
slog!("[discord] Gateway read error: {e}");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let payload: GatewayPayload = match parse_ws_message(&msg) {
|
||||||
|
Some(p) => p,
|
||||||
|
None => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Update sequence number.
|
||||||
|
if let Some(s) = payload.s {
|
||||||
|
sequence.store(s, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
match payload.op {
|
||||||
|
OP_DISPATCH => {
|
||||||
|
let event_name = payload.t.as_deref().unwrap_or("");
|
||||||
|
match event_name {
|
||||||
|
"READY" => {
|
||||||
|
if let Some(d) = payload.d
|
||||||
|
&& let Ok(ready) = serde_json::from_value::<ReadyData>(d)
|
||||||
|
{
|
||||||
|
bot_user_id = Some(ready.user.id.clone());
|
||||||
|
slog!(
|
||||||
|
"[discord] READY — bot user ID: {}",
|
||||||
|
ready.user.id
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"MESSAGE_CREATE" => {
|
||||||
|
if let Some(d) = payload.d {
|
||||||
|
dispatch_message(
|
||||||
|
Arc::clone(&ctx),
|
||||||
|
d,
|
||||||
|
bot_user_id.clone(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
OP_HEARTBEAT => {
|
||||||
|
// Server requested an immediate heartbeat.
|
||||||
|
let seq = sequence.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
|
let seq_val = if seq == 0 { None } else { Some(seq) };
|
||||||
|
let hb = HeartbeatPayload {
|
||||||
|
op: OP_HEARTBEAT,
|
||||||
|
d: seq_val,
|
||||||
|
};
|
||||||
|
if let Ok(json) = serde_json::to_string(&hb) {
|
||||||
|
let mut w = write.lock().await;
|
||||||
|
let _ = w.send(WsMessage::Text(json.into())).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
OP_HEARTBEAT_ACK => {}
|
||||||
|
OP_RECONNECT => {
|
||||||
|
slog!("[discord] Gateway requested reconnect");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
OP_INVALID_SESSION => {
|
||||||
|
slog!("[discord] Invalid session — will reconnect");
|
||||||
|
// Wait a bit before reconnecting for invalid sessions.
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
heartbeat_handle.abort();
|
||||||
|
slog!("[discord] Gateway event loop ended");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Helpers ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
fn parse_ws_message(msg: &WsMessage) -> Option<GatewayPayload> {
|
||||||
|
match msg {
|
||||||
|
WsMessage::Text(text) => serde_json::from_str(text).ok(),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter and dispatch an incoming MESSAGE_CREATE event on a background task.
|
||||||
|
fn dispatch_message(
|
||||||
|
ctx: Arc<DiscordContext>,
|
||||||
|
data: serde_json::Value,
|
||||||
|
bot_user_id: Option<String>,
|
||||||
|
) {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let msg: MessageCreateData = match serde_json::from_value(data) {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(e) => {
|
||||||
|
slog!("[discord] Failed to parse MESSAGE_CREATE: {e}");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let author = match msg.author.as_ref() {
|
||||||
|
Some(a) => a,
|
||||||
|
None => return,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Ignore bot messages (including our own).
|
||||||
|
if author.bot.unwrap_or(false) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the channel is in our configured channel list.
|
||||||
|
if !ctx.channel_ids.contains(&msg.channel_id) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check allowed_users (if non-empty, only listed users can interact).
|
||||||
|
if !ctx.allowed_users.is_empty() && !ctx.allowed_users.contains(&author.id) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the bot was mentioned, or if we respond to all messages in
|
||||||
|
// configured channels (ambient mode).
|
||||||
|
let bot_mentioned = bot_user_id.as_ref().is_some_and(|bid| {
|
||||||
|
msg.mentions.iter().any(|m| m.id == *bid)
|
||||||
|
});
|
||||||
|
|
||||||
|
let in_ambient = ctx
|
||||||
|
.ambient_rooms
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.contains(&msg.channel_id);
|
||||||
|
|
||||||
|
if !bot_mentioned && !in_ambient {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip the bot mention from the message content if present.
|
||||||
|
let content = if let Some(bid) = bot_user_id.as_ref() {
|
||||||
|
let mention_pat = format!("<@{bid}>");
|
||||||
|
let mention_pat_nick = format!("<@!{bid}>");
|
||||||
|
msg.content
|
||||||
|
.replace(&mention_pat, "")
|
||||||
|
.replace(&mention_pat_nick, "")
|
||||||
|
.trim()
|
||||||
|
.to_string()
|
||||||
|
} else {
|
||||||
|
msg.content.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
if content.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
slog!(
|
||||||
|
"[discord] Message from {} in {}: {content:.80}",
|
||||||
|
author.id,
|
||||||
|
msg.channel_id
|
||||||
|
);
|
||||||
|
|
||||||
|
commands::handle_incoming_message(&ctx, &msg.channel_id, &author.id, &content)
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Cheap pseudo-random fraction [0.0, 1.0) for heartbeat jitter.
|
||||||
|
fn rand_fraction() -> f64 {
|
||||||
|
let nanos = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.subsec_nanos();
|
||||||
|
(nanos as f64) / 1_000_000_000.0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_hello_payload() {
|
||||||
|
let json = r#"{"op": 10, "d": {"heartbeat_interval": 41250}}"#;
|
||||||
|
let payload: GatewayPayload = serde_json::from_str(json).unwrap();
|
||||||
|
assert_eq!(payload.op, OP_HELLO);
|
||||||
|
let hello: HelloData =
|
||||||
|
serde_json::from_value(payload.d.unwrap()).unwrap();
|
||||||
|
assert_eq!(hello.heartbeat_interval, 41250);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_message_create() {
|
||||||
|
let json = r#"{
|
||||||
|
"channel_id": "123456",
|
||||||
|
"content": "hello bot",
|
||||||
|
"author": {"id": "789", "bot": false},
|
||||||
|
"mentions": [{"id": "111"}]
|
||||||
|
}"#;
|
||||||
|
let msg: MessageCreateData = serde_json::from_str(json).unwrap();
|
||||||
|
assert_eq!(msg.channel_id, "123456");
|
||||||
|
assert_eq!(msg.content, "hello bot");
|
||||||
|
assert_eq!(msg.author.as_ref().unwrap().id, "789");
|
||||||
|
assert!(!msg.author.as_ref().unwrap().bot.unwrap_or(false));
|
||||||
|
assert_eq!(msg.mentions.len(), 1);
|
||||||
|
assert_eq!(msg.mentions[0].id, "111");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_ready_event() {
|
||||||
|
let json = r#"{"user": {"id": "999888"}}"#;
|
||||||
|
let ready: ReadyData = serde_json::from_str(json).unwrap();
|
||||||
|
assert_eq!(ready.user.id, "999888");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn identify_payload_serializes_correctly() {
|
||||||
|
let identify = IdentifyPayload {
|
||||||
|
op: OP_IDENTIFY,
|
||||||
|
d: IdentifyData {
|
||||||
|
token: "test-token".to_string(),
|
||||||
|
intents: GATEWAY_INTENTS,
|
||||||
|
properties: IdentifyProperties {
|
||||||
|
os: "linux",
|
||||||
|
browser: "huskies",
|
||||||
|
device: "huskies",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let json: serde_json::Value =
|
||||||
|
serde_json::from_str(&serde_json::to_string(&identify).unwrap()).unwrap();
|
||||||
|
assert_eq!(json["op"], 2);
|
||||||
|
assert_eq!(json["d"]["token"], "test-token");
|
||||||
|
assert_eq!(json["d"]["intents"], GATEWAY_INTENTS);
|
||||||
|
assert_eq!(json["d"]["properties"]["browser"], "huskies");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn heartbeat_payload_serializes_with_sequence() {
|
||||||
|
let hb = HeartbeatPayload {
|
||||||
|
op: OP_HEARTBEAT,
|
||||||
|
d: Some(42),
|
||||||
|
};
|
||||||
|
let json: serde_json::Value =
|
||||||
|
serde_json::from_str(&serde_json::to_string(&hb).unwrap()).unwrap();
|
||||||
|
assert_eq!(json["op"], 1);
|
||||||
|
assert_eq!(json["d"], 42);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn heartbeat_payload_serializes_null_sequence() {
|
||||||
|
let hb = HeartbeatPayload {
|
||||||
|
op: OP_HEARTBEAT,
|
||||||
|
d: None,
|
||||||
|
};
|
||||||
|
let json: serde_json::Value =
|
||||||
|
serde_json::from_str(&serde_json::to_string(&hb).unwrap()).unwrap();
|
||||||
|
assert_eq!(json["op"], 1);
|
||||||
|
assert!(json["d"].is_null());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rand_fraction_in_range() {
|
||||||
|
let f = rand_fraction();
|
||||||
|
assert!((0.0..1.0).contains(&f));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn gateway_intents_correct() {
|
||||||
|
// GUILDS (1) | GUILD_MESSAGES (512) | MESSAGE_CONTENT (32768)
|
||||||
|
assert_eq!(GATEWAY_INTENTS, 1 | 512 | 32768);
|
||||||
|
assert_eq!(GATEWAY_INTENTS, 33281);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,119 @@
|
|||||||
|
//! Discord conversation history persistence.
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::Mutex as TokioMutex;
|
||||||
|
|
||||||
|
use crate::chat::transport::matrix::RoomConversation;
|
||||||
|
use crate::slog;
|
||||||
|
|
||||||
|
/// Per-channel conversation history, keyed by channel ID.
|
||||||
|
pub type DiscordConversationHistory = Arc<TokioMutex<HashMap<String, RoomConversation>>>;
|
||||||
|
|
||||||
|
/// On-disk format for persisted Discord conversation history.
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
struct PersistedDiscordHistory {
|
||||||
|
channels: HashMap<String, RoomConversation>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Path to the persisted Discord conversation history file.
|
||||||
|
const DISCORD_HISTORY_FILE: &str = ".huskies/discord_history.json";
|
||||||
|
|
||||||
|
/// Load Discord conversation history from disk.
|
||||||
|
pub fn load_discord_history(project_root: &std::path::Path) -> HashMap<String, RoomConversation> {
|
||||||
|
let path = project_root.join(DISCORD_HISTORY_FILE);
|
||||||
|
let data = match std::fs::read_to_string(&path) {
|
||||||
|
Ok(d) => d,
|
||||||
|
Err(_) => return HashMap::new(),
|
||||||
|
};
|
||||||
|
let persisted: PersistedDiscordHistory = match serde_json::from_str(&data) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => {
|
||||||
|
slog!("[discord] Failed to parse history file: {e}");
|
||||||
|
return HashMap::new();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
persisted.channels
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Save Discord conversation history to disk.
|
||||||
|
pub(super) fn save_discord_history(
|
||||||
|
project_root: &std::path::Path,
|
||||||
|
history: &HashMap<String, RoomConversation>,
|
||||||
|
) {
|
||||||
|
let persisted = PersistedDiscordHistory {
|
||||||
|
channels: history.clone(),
|
||||||
|
};
|
||||||
|
let path = project_root.join(DISCORD_HISTORY_FILE);
|
||||||
|
match serde_json::to_string_pretty(&persisted) {
|
||||||
|
Ok(json) => {
|
||||||
|
if let Err(e) = std::fs::write(&path, json) {
|
||||||
|
slog!("[discord] Failed to write history file: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => slog!("[discord] Failed to serialise history: {e}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::chat::transport::matrix::{ConversationEntry, ConversationRole};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn save_and_load_discord_history_round_trips() {
|
||||||
|
let tmp = tempfile::tempdir().unwrap();
|
||||||
|
let sk = tmp.path().join(".huskies");
|
||||||
|
std::fs::create_dir_all(&sk).unwrap();
|
||||||
|
|
||||||
|
let mut history = HashMap::new();
|
||||||
|
history.insert(
|
||||||
|
"123456789".to_string(),
|
||||||
|
RoomConversation {
|
||||||
|
session_id: Some("sess-abc".to_string()),
|
||||||
|
entries: vec![
|
||||||
|
ConversationEntry {
|
||||||
|
role: ConversationRole::User,
|
||||||
|
sender: "user123".to_string(),
|
||||||
|
content: "hello".to_string(),
|
||||||
|
},
|
||||||
|
ConversationEntry {
|
||||||
|
role: ConversationRole::Assistant,
|
||||||
|
sender: String::new(),
|
||||||
|
content: "hi there!".to_string(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
save_discord_history(tmp.path(), &history);
|
||||||
|
let loaded = load_discord_history(tmp.path());
|
||||||
|
|
||||||
|
assert_eq!(loaded.len(), 1);
|
||||||
|
let conv = loaded.get("123456789").unwrap();
|
||||||
|
assert_eq!(conv.session_id.as_deref(), Some("sess-abc"));
|
||||||
|
assert_eq!(conv.entries.len(), 2);
|
||||||
|
assert_eq!(conv.entries[0].content, "hello");
|
||||||
|
assert_eq!(conv.entries[1].content, "hi there!");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn load_discord_history_returns_empty_when_file_missing() {
|
||||||
|
let tmp = tempfile::tempdir().unwrap();
|
||||||
|
let history = load_discord_history(tmp.path());
|
||||||
|
assert!(history.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn load_discord_history_returns_empty_on_invalid_json() {
|
||||||
|
let tmp = tempfile::tempdir().unwrap();
|
||||||
|
let sk = tmp.path().join(".huskies");
|
||||||
|
std::fs::create_dir_all(&sk).unwrap();
|
||||||
|
std::fs::write(sk.join("discord_history.json"), "not json {{{").unwrap();
|
||||||
|
let history = load_discord_history(tmp.path());
|
||||||
|
assert!(history.is_empty());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,302 @@
|
|||||||
|
//! DiscordTransport — ChatTransport implementation for the Discord Bot API.
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
use crate::chat::{ChatTransport, MessageId};
|
||||||
|
use crate::slog;
|
||||||
|
|
||||||
|
// ── Discord API base URL (overridable for tests) ──────────────────────
|
||||||
|
|
||||||
|
const DISCORD_API_BASE: &str = "https://discord.com/api/v10";
|
||||||
|
|
||||||
|
// ── DiscordTransport ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Discord Bot API transport.
|
||||||
|
///
|
||||||
|
/// Sends messages via `POST {API_BASE}/channels/{channel_id}/messages`,
|
||||||
|
/// edits via `PATCH {API_BASE}/channels/{channel_id}/messages/{message_id}`,
|
||||||
|
/// and typing indicators via `POST {API_BASE}/channels/{channel_id}/typing`.
|
||||||
|
pub struct DiscordTransport {
|
||||||
|
bot_token: String,
|
||||||
|
client: reqwest::Client,
|
||||||
|
/// Optional base URL override for tests.
|
||||||
|
api_base: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DiscordTransport {
|
||||||
|
pub fn new(bot_token: String) -> Self {
|
||||||
|
Self {
|
||||||
|
bot_token,
|
||||||
|
client: reqwest::Client::new(),
|
||||||
|
api_base: DISCORD_API_BASE.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn with_api_base(bot_token: String, api_base: String) -> Self {
|
||||||
|
Self {
|
||||||
|
bot_token,
|
||||||
|
client: reqwest::Client::new(),
|
||||||
|
api_base,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Authorization header value: `Bot {token}`.
|
||||||
|
fn auth_header(&self) -> String {
|
||||||
|
format!("Bot {}", self.bot_token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Discord API response types ────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct DiscordMessage {
|
||||||
|
id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── ChatTransport implementation ──────────────────────────────────────
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl ChatTransport for DiscordTransport {
|
||||||
|
async fn send_message(
|
||||||
|
&self,
|
||||||
|
channel_id: &str,
|
||||||
|
plain: &str,
|
||||||
|
_html: &str,
|
||||||
|
) -> Result<MessageId, String> {
|
||||||
|
slog!("[discord] send_message to {channel_id}: {plain:.80}");
|
||||||
|
let url = format!("{}/channels/{}/messages", self.api_base, channel_id);
|
||||||
|
|
||||||
|
// Discord messages have a 2000-char limit. Truncate if needed.
|
||||||
|
let content = if plain.len() > 2000 {
|
||||||
|
format!("{}…", &plain[..1999])
|
||||||
|
} else {
|
||||||
|
plain.to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
let payload = serde_json::json!({ "content": content });
|
||||||
|
|
||||||
|
let resp = self
|
||||||
|
.client
|
||||||
|
.post(&url)
|
||||||
|
.header("Authorization", self.auth_header())
|
||||||
|
.json(&payload)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Discord API request failed: {e}"))?;
|
||||||
|
|
||||||
|
let status = resp.status();
|
||||||
|
let resp_text = resp
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| "<no body>".to_string());
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
return Err(format!("Discord API returned {status}: {resp_text}"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let msg: DiscordMessage = serde_json::from_str(&resp_text).map_err(|e| {
|
||||||
|
format!("Failed to parse Discord API response: {e} — body: {resp_text}")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(msg.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn edit_message(
|
||||||
|
&self,
|
||||||
|
channel_id: &str,
|
||||||
|
original_message_id: &str,
|
||||||
|
plain: &str,
|
||||||
|
_html: &str,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
slog!("[discord] edit_message in {channel_id}: id={original_message_id}");
|
||||||
|
let url = format!(
|
||||||
|
"{}/channels/{}/messages/{}",
|
||||||
|
self.api_base, channel_id, original_message_id
|
||||||
|
);
|
||||||
|
|
||||||
|
let content = if plain.len() > 2000 {
|
||||||
|
format!("{}…", &plain[..1999])
|
||||||
|
} else {
|
||||||
|
plain.to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
let payload = serde_json::json!({ "content": content });
|
||||||
|
|
||||||
|
let resp = self
|
||||||
|
.client
|
||||||
|
.patch(&url)
|
||||||
|
.header("Authorization", self.auth_header())
|
||||||
|
.json(&payload)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Discord edit request failed: {e}"))?;
|
||||||
|
|
||||||
|
let status = resp.status();
|
||||||
|
let resp_text = resp
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| "<no body>".to_string());
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
return Err(format!("Discord edit returned {status}: {resp_text}"));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_typing(&self, channel_id: &str, typing: bool) -> Result<(), String> {
|
||||||
|
if !typing {
|
||||||
|
// Discord doesn't have a "stop typing" API — it times out after ~10s.
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
let url = format!("{}/channels/{}/typing", self.api_base, channel_id);
|
||||||
|
|
||||||
|
self.client
|
||||||
|
.post(&url)
|
||||||
|
.header("Authorization", self.auth_header())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Discord typing request failed: {e}"))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn transport_send_message_calls_discord_api() {
|
||||||
|
let mut server = mockito::Server::new_async().await;
|
||||||
|
let mock = server
|
||||||
|
.mock("POST", "/channels/123456/messages")
|
||||||
|
.match_header("authorization", "Bot test-token")
|
||||||
|
.with_body(r#"{"id": "999888777"}"#)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let transport =
|
||||||
|
DiscordTransport::with_api_base("test-token".to_string(), server.url());
|
||||||
|
|
||||||
|
let result = transport
|
||||||
|
.send_message("123456", "hello", "<p>hello</p>")
|
||||||
|
.await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
assert_eq!(result.unwrap(), "999888777");
|
||||||
|
mock.assert_async().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn transport_send_message_handles_api_error() {
|
||||||
|
let mut server = mockito::Server::new_async().await;
|
||||||
|
server
|
||||||
|
.mock("POST", "/channels/bad/messages")
|
||||||
|
.with_status(404)
|
||||||
|
.with_body(r#"{"message": "Unknown Channel", "code": 10003}"#)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let transport =
|
||||||
|
DiscordTransport::with_api_base("test-token".to_string(), server.url());
|
||||||
|
|
||||||
|
let result = transport.send_message("bad", "hello", "").await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result.unwrap_err().contains("404"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn transport_edit_message_calls_patch() {
|
||||||
|
let mut server = mockito::Server::new_async().await;
|
||||||
|
let mock = server
|
||||||
|
.mock("PATCH", "/channels/123456/messages/999888777")
|
||||||
|
.match_header("authorization", "Bot test-token")
|
||||||
|
.with_body(r#"{"id": "999888777"}"#)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let transport =
|
||||||
|
DiscordTransport::with_api_base("test-token".to_string(), server.url());
|
||||||
|
|
||||||
|
let result = transport
|
||||||
|
.edit_message("123456", "999888777", "updated", "")
|
||||||
|
.await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
mock.assert_async().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn transport_edit_message_handles_error() {
|
||||||
|
let mut server = mockito::Server::new_async().await;
|
||||||
|
server
|
||||||
|
.mock("PATCH", "/channels/123456/messages/bad")
|
||||||
|
.with_status(404)
|
||||||
|
.with_body(r#"{"message": "Unknown Message", "code": 10008}"#)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let transport =
|
||||||
|
DiscordTransport::with_api_base("test-token".to_string(), server.url());
|
||||||
|
|
||||||
|
let result = transport
|
||||||
|
.edit_message("123456", "bad", "updated", "")
|
||||||
|
.await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result.unwrap_err().contains("404"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn transport_send_typing_succeeds() {
|
||||||
|
let mut server = mockito::Server::new_async().await;
|
||||||
|
server
|
||||||
|
.mock("POST", "/channels/123456/typing")
|
||||||
|
.with_status(204)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let transport =
|
||||||
|
DiscordTransport::with_api_base("test-token".to_string(), server.url());
|
||||||
|
|
||||||
|
assert!(transport.send_typing("123456", true).await.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn transport_send_typing_false_is_noop() {
|
||||||
|
let transport = DiscordTransport::new("test-token".to_string());
|
||||||
|
assert!(transport.send_typing("123456", false).await.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn transport_handles_http_error() {
|
||||||
|
let mut server = mockito::Server::new_async().await;
|
||||||
|
server
|
||||||
|
.mock("POST", "/channels/123456/messages")
|
||||||
|
.with_status(500)
|
||||||
|
.with_body("Internal Server Error")
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let transport =
|
||||||
|
DiscordTransport::with_api_base("test-token".to_string(), server.url());
|
||||||
|
|
||||||
|
let result = transport.send_message("123456", "hello", "").await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result.unwrap_err().contains("500"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── ChatTransport trait satisfaction ─────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn discord_transport_satisfies_trait() {
|
||||||
|
fn assert_transport<T: ChatTransport>() {}
|
||||||
|
assert_transport::<DiscordTransport>();
|
||||||
|
|
||||||
|
let _: Arc<dyn ChatTransport> =
|
||||||
|
Arc::new(DiscordTransport::new("test-token".to_string()));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
//! Discord Bot integration.
|
||||||
|
//!
|
||||||
|
//! Provides:
|
||||||
|
//! - [`DiscordTransport`] — a [`ChatTransport`] that sends messages via the
|
||||||
|
//! Discord REST API (`/channels/{id}/messages`).
|
||||||
|
//! - [`gateway::spawn_gateway`] — connects to the Discord Gateway WebSocket,
|
||||||
|
//! receives `MESSAGE_CREATE` events, and dispatches commands.
|
||||||
|
//! - [`DiscordContext`] — shared context for the bot.
|
||||||
|
|
||||||
|
pub mod commands;
|
||||||
|
pub mod format;
|
||||||
|
pub mod gateway;
|
||||||
|
pub mod history;
|
||||||
|
pub mod meta;
|
||||||
|
|
||||||
|
pub use commands::DiscordContext;
|
||||||
|
pub use history::load_discord_history;
|
||||||
|
pub use meta::DiscordTransport;
|
||||||
@@ -64,7 +64,8 @@ pub struct BotConfig {
|
|||||||
/// manually while the bot is running.
|
/// manually while the bot is running.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub ambient_rooms: Vec<String>,
|
pub ambient_rooms: Vec<String>,
|
||||||
/// Chat transport to use: `"matrix"` (default) or `"whatsapp"`.
|
/// Chat transport to use: `"matrix"` (default), `"whatsapp"`, `"slack"`,
|
||||||
|
/// or `"discord"`.
|
||||||
///
|
///
|
||||||
/// Selects which [`ChatTransport`] implementation the bot uses for
|
/// Selects which [`ChatTransport`] implementation the bot uses for
|
||||||
/// sending and editing messages. Currently only read during bot
|
/// sending and editing messages. Currently only read during bot
|
||||||
@@ -134,6 +135,20 @@ pub struct BotConfig {
|
|||||||
/// Slack channel IDs the bot should listen in.
|
/// Slack channel IDs the bot should listen in.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub slack_channel_ids: Vec<String>,
|
pub slack_channel_ids: Vec<String>,
|
||||||
|
|
||||||
|
// ── Discord Bot API fields ──────────────────────────────────────
|
||||||
|
// These are only required when `transport = "discord"`.
|
||||||
|
|
||||||
|
/// Discord bot token from the Discord Developer Portal.
|
||||||
|
#[serde(default)]
|
||||||
|
pub discord_bot_token: Option<String>,
|
||||||
|
/// Discord channel IDs the bot should listen in.
|
||||||
|
#[serde(default)]
|
||||||
|
pub discord_channel_ids: Vec<String>,
|
||||||
|
/// Discord user IDs allowed to interact with the bot.
|
||||||
|
/// When empty or absent, all users in configured channels are allowed.
|
||||||
|
#[serde(default)]
|
||||||
|
pub discord_allowed_users: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_transport() -> String {
|
fn default_transport() -> String {
|
||||||
@@ -241,6 +256,22 @@ impl BotConfig {
|
|||||||
);
|
);
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
} else if config.transport == "discord" {
|
||||||
|
// Validate Discord-specific fields.
|
||||||
|
if config.discord_bot_token.as_ref().is_none_or(|s| s.is_empty()) {
|
||||||
|
eprintln!(
|
||||||
|
"[bot] bot.toml: transport=\"discord\" requires \
|
||||||
|
discord_bot_token"
|
||||||
|
);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
if config.discord_channel_ids.is_empty() {
|
||||||
|
eprintln!(
|
||||||
|
"[bot] bot.toml: transport=\"discord\" requires \
|
||||||
|
at least one discord_channel_ids entry"
|
||||||
|
);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// Default transport is Matrix — validate Matrix-specific fields.
|
// Default transport is Matrix — validate Matrix-specific fields.
|
||||||
if config.homeserver.as_ref().is_none_or(|s| s.is_empty()) {
|
if config.homeserver.as_ref().is_none_or(|s| s.is_empty()) {
|
||||||
@@ -1054,4 +1085,99 @@ whatsapp_allowed_phones = ["+15551234567", "+15559876543"]
|
|||||||
vec!["+15551234567", "+15559876543"]
|
vec!["+15551234567", "+15559876543"]
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Discord config tests ──────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn load_discord_transport_reads_config() {
|
||||||
|
let tmp = tempfile::tempdir().unwrap();
|
||||||
|
let sk = tmp.path().join(".huskies");
|
||||||
|
fs::create_dir_all(&sk).unwrap();
|
||||||
|
fs::write(
|
||||||
|
sk.join("bot.toml"),
|
||||||
|
r#"
|
||||||
|
enabled = true
|
||||||
|
transport = "discord"
|
||||||
|
discord_bot_token = "Bot.Token.Here"
|
||||||
|
discord_channel_ids = ["123456789012345678"]
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let config = BotConfig::load(tmp.path()).unwrap();
|
||||||
|
assert_eq!(config.transport, "discord");
|
||||||
|
assert_eq!(
|
||||||
|
config.discord_bot_token.as_deref(),
|
||||||
|
Some("Bot.Token.Here")
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.discord_channel_ids,
|
||||||
|
vec!["123456789012345678"]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn load_discord_returns_none_when_missing_bot_token() {
|
||||||
|
let tmp = tempfile::tempdir().unwrap();
|
||||||
|
let sk = tmp.path().join(".huskies");
|
||||||
|
fs::create_dir_all(&sk).unwrap();
|
||||||
|
fs::write(
|
||||||
|
sk.join("bot.toml"),
|
||||||
|
r#"
|
||||||
|
enabled = true
|
||||||
|
transport = "discord"
|
||||||
|
discord_channel_ids = ["123456789012345678"]
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
assert!(BotConfig::load(tmp.path()).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn load_discord_returns_none_when_missing_channel_ids() {
|
||||||
|
let tmp = tempfile::tempdir().unwrap();
|
||||||
|
let sk = tmp.path().join(".huskies");
|
||||||
|
fs::create_dir_all(&sk).unwrap();
|
||||||
|
fs::write(
|
||||||
|
sk.join("bot.toml"),
|
||||||
|
r#"
|
||||||
|
enabled = true
|
||||||
|
transport = "discord"
|
||||||
|
discord_bot_token = "Bot.Token.Here"
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
assert!(BotConfig::load(tmp.path()).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn discord_allowed_users_defaults_to_empty_when_absent() {
|
||||||
|
let config: BotConfig = toml::from_str(
|
||||||
|
r#"
|
||||||
|
enabled = true
|
||||||
|
transport = "discord"
|
||||||
|
discord_bot_token = "Bot.Token.Here"
|
||||||
|
discord_channel_ids = ["123456789"]
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
assert!(config.discord_allowed_users.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn discord_allowed_users_deserializes_list() {
|
||||||
|
let config: BotConfig = toml::from_str(
|
||||||
|
r#"
|
||||||
|
enabled = true
|
||||||
|
transport = "discord"
|
||||||
|
discord_bot_token = "Bot.Token.Here"
|
||||||
|
discord_channel_ids = ["123456789"]
|
||||||
|
discord_allowed_users = ["111222333", "444555666"]
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
config.discord_allowed_users,
|
||||||
|
vec!["111222333", "444555666"]
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
pub mod discord;
|
||||||
pub mod matrix;
|
pub mod matrix;
|
||||||
pub mod slack;
|
pub mod slack;
|
||||||
pub mod whatsapp;
|
pub mod whatsapp;
|
||||||
|
|||||||
@@ -339,12 +339,14 @@ async fn main() -> Result<(), std::io::Error> {
|
|||||||
// before watcher_tx is moved into AppContext.
|
// before watcher_tx is moved into AppContext.
|
||||||
let watcher_rx_for_whatsapp = watcher_tx.subscribe();
|
let watcher_rx_for_whatsapp = watcher_tx.subscribe();
|
||||||
let watcher_rx_for_slack = watcher_tx.subscribe();
|
let watcher_rx_for_slack = watcher_tx.subscribe();
|
||||||
|
let watcher_rx_for_discord = watcher_tx.subscribe();
|
||||||
// Wrap perm_rx in Arc<Mutex> so it can be shared with both the WebSocket
|
// Wrap perm_rx in Arc<Mutex> so it can be shared with both the WebSocket
|
||||||
// handler (via AppContext) and the Matrix bot.
|
// handler (via AppContext) and the Matrix bot.
|
||||||
let perm_rx = Arc::new(tokio::sync::Mutex::new(perm_rx));
|
let perm_rx = Arc::new(tokio::sync::Mutex::new(perm_rx));
|
||||||
let perm_rx_for_bot = Arc::clone(&perm_rx);
|
let perm_rx_for_bot = Arc::clone(&perm_rx);
|
||||||
let perm_rx_for_whatsapp = Arc::clone(&perm_rx);
|
let perm_rx_for_whatsapp = Arc::clone(&perm_rx);
|
||||||
let perm_rx_for_slack = Arc::clone(&perm_rx);
|
let perm_rx_for_slack = Arc::clone(&perm_rx);
|
||||||
|
let perm_rx_for_discord = Arc::clone(&perm_rx);
|
||||||
|
|
||||||
// Capture project root, agents Arc, and reconciliation sender before ctx
|
// Capture project root, agents Arc, and reconciliation sender before ctx
|
||||||
// is consumed by build_routes.
|
// is consumed by build_routes.
|
||||||
@@ -441,9 +443,49 @@ async fn main() -> Result<(), std::io::Error> {
|
|||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Build Discord context if bot.toml configures transport = "discord".
|
||||||
|
let discord_ctx: Option<Arc<chat::transport::discord::DiscordContext>> = startup_root
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|root| chat::transport::matrix::BotConfig::load(root))
|
||||||
|
.filter(|cfg| cfg.transport == "discord")
|
||||||
|
.map(|cfg| {
|
||||||
|
let transport = Arc::new(chat::transport::discord::DiscordTransport::new(
|
||||||
|
cfg.discord_bot_token.clone().unwrap_or_default(),
|
||||||
|
));
|
||||||
|
let bot_name = cfg
|
||||||
|
.display_name
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| "Assistant".to_string());
|
||||||
|
let root = startup_root.clone().unwrap();
|
||||||
|
let history = chat::transport::discord::load_discord_history(&root);
|
||||||
|
let channel_ids: std::collections::HashSet<String> =
|
||||||
|
cfg.discord_channel_ids.iter().cloned().collect();
|
||||||
|
let allowed_users: std::collections::HashSet<String> =
|
||||||
|
cfg.discord_allowed_users.iter().cloned().collect();
|
||||||
|
Arc::new(chat::transport::discord::DiscordContext {
|
||||||
|
bot_token: cfg.discord_bot_token.clone().unwrap_or_default(),
|
||||||
|
transport,
|
||||||
|
project_root: root,
|
||||||
|
agents: Arc::clone(&startup_agents),
|
||||||
|
bot_name,
|
||||||
|
bot_user_id: "discord-bot".to_string(),
|
||||||
|
ambient_rooms: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
|
||||||
|
history: std::sync::Arc::new(tokio::sync::Mutex::new(history)),
|
||||||
|
history_size: cfg.history_size,
|
||||||
|
channel_ids,
|
||||||
|
allowed_users,
|
||||||
|
perm_rx: perm_rx_for_discord,
|
||||||
|
pending_perm_replies: Arc::new(tokio::sync::Mutex::new(
|
||||||
|
std::collections::HashMap::new(),
|
||||||
|
)),
|
||||||
|
permission_timeout_secs: cfg.permission_timeout_secs,
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
// Build a best-effort shutdown notifier for webhook-based transports.
|
// Build a best-effort shutdown notifier for webhook-based transports.
|
||||||
//
|
//
|
||||||
// • Slack: channels are fixed at startup (channel_ids from bot.toml).
|
// • Slack: channels are fixed at startup (channel_ids from bot.toml).
|
||||||
|
// • Discord: channels are fixed at startup (channel_ids from bot.toml).
|
||||||
// • WhatsApp: active senders are tracked at runtime in ambient_rooms.
|
// • WhatsApp: active senders are tracked at runtime in ambient_rooms.
|
||||||
// We keep the WhatsApp context Arc so we can read the rooms at shutdown.
|
// We keep the WhatsApp context Arc so we can read the rooms at shutdown.
|
||||||
// • Matrix: the bot task manages its own announcement via matrix_shutdown_tx.
|
// • Matrix: the bot task manages its own announcement via matrix_shutdown_tx.
|
||||||
@@ -454,6 +496,13 @@ async fn main() -> Result<(), std::io::Error> {
|
|||||||
channels,
|
channels,
|
||||||
ctx.bot_name.clone(),
|
ctx.bot_name.clone(),
|
||||||
)))
|
)))
|
||||||
|
} else if let Some(ref ctx) = discord_ctx {
|
||||||
|
let channels: Vec<String> = ctx.channel_ids.iter().cloned().collect();
|
||||||
|
Some(Arc::new(BotShutdownNotifier::new(
|
||||||
|
Arc::clone(&ctx.transport) as Arc<dyn crate::chat::ChatTransport>,
|
||||||
|
channels,
|
||||||
|
ctx.bot_name.clone(),
|
||||||
|
)))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
@@ -496,6 +545,18 @@ async fn main() -> Result<(), std::io::Error> {
|
|||||||
notifier.notify_startup().await;
|
notifier.notify_startup().await;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
if let Some(ref ctx) = discord_ctx {
|
||||||
|
let transport = Arc::clone(&ctx.transport) as Arc<dyn crate::chat::ChatTransport>;
|
||||||
|
let bot_name = ctx.bot_name.clone();
|
||||||
|
let channels: Vec<String> = ctx.channel_ids.iter().cloned().collect();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if channels.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let notifier = crate::rebuild::BotShutdownNotifier::new(transport, channels, bot_name);
|
||||||
|
notifier.notify_startup().await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Watch channel: signals the Matrix bot task to send a shutdown announcement.
|
// Watch channel: signals the Matrix bot task to send a shutdown announcement.
|
||||||
// `None` initial value means "server is running".
|
// `None` initial value means "server is running".
|
||||||
@@ -559,6 +620,21 @@ async fn main() -> Result<(), std::io::Error> {
|
|||||||
} else {
|
} else {
|
||||||
drop(watcher_rx_for_slack);
|
drop(watcher_rx_for_slack);
|
||||||
}
|
}
|
||||||
|
if let (Some(ctx), Some(root)) = (&discord_ctx, &startup_root) {
|
||||||
|
// Spawn the Discord Gateway WebSocket listener.
|
||||||
|
chat::transport::discord::gateway::spawn_gateway(Arc::clone(ctx));
|
||||||
|
|
||||||
|
// Spawn stage-transition notification listener for Discord.
|
||||||
|
let channel_ids: Vec<String> = ctx.channel_ids.iter().cloned().collect();
|
||||||
|
chat::transport::matrix::notifications::spawn_notification_listener(
|
||||||
|
Arc::clone(&ctx.transport) as Arc<dyn crate::chat::ChatTransport>,
|
||||||
|
move || channel_ids.clone(),
|
||||||
|
watcher_rx_for_discord,
|
||||||
|
root.clone(),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
drop(watcher_rx_for_discord);
|
||||||
|
}
|
||||||
|
|
||||||
// On startup:
|
// On startup:
|
||||||
// 1. Reconcile any stories whose agent work was committed while the server was
|
// 1. Reconcile any stories whose agent work was committed while the server was
|
||||||
|
|||||||
Reference in New Issue
Block a user