huskies: merge 472_story_discord_chat_transport

This commit is contained in:
dave
2026-04-04 12:08:39 +00:00
parent ee86e4a3d3
commit c56e462340
14 changed files with 1960 additions and 3 deletions
@@ -0,0 +1,641 @@
//! Discord incoming message dispatch and command handling.
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use tokio::sync::{Mutex as TokioMutex, oneshot};
use crate::agents::AgentPool;
use crate::chat::transport::matrix::{ConversationEntry, ConversationRole, RoomConversation};
use crate::chat::util::is_permission_approval;
use crate::chat::ChatTransport;
use crate::http::context::{PermissionDecision, PermissionForward};
use crate::slog;
use super::format::markdown_to_discord;
use super::history::{DiscordConversationHistory, save_discord_history};
use super::meta::DiscordTransport;
// ── Shared context ────────────────────────────────────────────────────
/// Shared context for the Discord bot, used by both the Gateway listener
/// and any future webhook handlers.
pub struct DiscordContext {
pub bot_token: String,
pub transport: Arc<DiscordTransport>,
pub project_root: PathBuf,
pub agents: Arc<AgentPool>,
pub bot_name: String,
/// The bot's Discord user ID (set dynamically from READY event, but
/// also stored here for command dispatch).
pub bot_user_id: String,
pub ambient_rooms: Arc<Mutex<HashSet<String>>>,
/// Per-channel conversation history for LLM passthrough.
pub history: DiscordConversationHistory,
/// Maximum number of conversation entries to keep per channel.
pub history_size: usize,
/// Allowed channel IDs (messages from other channels are ignored).
pub channel_ids: HashSet<String>,
/// Allowed Discord user IDs. When non-empty, only listed users can
/// interact with the bot. When empty, all users are allowed.
pub allowed_users: HashSet<String>,
/// Permission requests from the MCP `prompt_permission` tool arrive here.
pub perm_rx: Arc<TokioMutex<tokio::sync::mpsc::UnboundedReceiver<PermissionForward>>>,
/// Pending permission replies keyed by channel ID.
pub pending_perm_replies:
Arc<TokioMutex<HashMap<String, oneshot::Sender<PermissionDecision>>>>,
/// Seconds before an unanswered permission prompt is auto-denied.
pub permission_timeout_secs: u64,
}
// ── Incoming message dispatch ───────────────────────────────────────────
pub(super) async fn handle_incoming_message(
ctx: &DiscordContext,
channel: &str,
user: &str,
message: &str,
) {
use crate::chat::commands::{CommandDispatch, try_handle_command};
// If there is a pending permission prompt for this channel, interpret the
// message as a yes/no response.
{
let mut pending = ctx.pending_perm_replies.lock().await;
if let Some(tx) = pending.remove(channel) {
let decision = if is_permission_approval(message) {
PermissionDecision::Approve
} else {
PermissionDecision::Deny
};
let _ = tx.send(decision);
let confirmation = if decision == PermissionDecision::Approve {
"Permission approved."
} else {
"Permission denied."
};
let formatted = markdown_to_discord(confirmation);
let _ = ctx.transport.send_message(channel, &formatted, "").await;
return;
}
}
let dispatch = CommandDispatch {
bot_name: &ctx.bot_name,
bot_user_id: &ctx.bot_user_id,
project_root: &ctx.project_root,
agents: &ctx.agents,
ambient_rooms: &ctx.ambient_rooms,
room_id: channel,
};
if let Some(response) = try_handle_command(&dispatch, message) {
slog!("[discord] Sending command response to {channel}");
let response = markdown_to_discord(&response);
if let Err(e) = ctx.transport.send_message(channel, &response, "").await {
slog!("[discord] Failed to send reply to {channel}: {e}");
}
return;
}
// Check for async commands (htop, delete).
if let Some(htop_cmd) = crate::chat::transport::matrix::htop::extract_htop_command(
message,
&ctx.bot_name,
&ctx.bot_user_id,
) {
use crate::chat::transport::matrix::htop::HtopCommand;
slog!("[discord] Handling htop command from {user} in {channel}");
match htop_cmd {
HtopCommand::Stop => {
let _ = ctx
.transport
.send_message(channel, "htop stopped.", "")
.await;
}
HtopCommand::Start { duration_secs } => {
let snapshot = crate::chat::transport::matrix::htop::build_htop_message(
&ctx.agents,
0,
duration_secs,
);
let snapshot = markdown_to_discord(&snapshot);
let msg_id = match ctx.transport.send_message(channel, &snapshot, "").await {
Ok(id) => id,
Err(e) => {
slog!("[discord] Failed to send htop message: {e}");
return;
}
};
let transport = Arc::clone(&ctx.transport) as Arc<dyn ChatTransport>;
let agents = Arc::clone(&ctx.agents);
let ch = channel.to_string();
tokio::spawn(async move {
let interval = std::time::Duration::from_secs(2);
let total_ticks = (duration_secs as usize) / 2;
for tick in 1..=total_ticks {
tokio::time::sleep(interval).await;
let updated =
crate::chat::transport::matrix::htop::build_htop_message(
&agents,
(tick * 2) as u32,
duration_secs,
);
let updated = markdown_to_discord(&updated);
if let Err(e) =
transport.edit_message(&ch, &msg_id, &updated, "").await
{
slog!("[discord] Failed to edit htop message: {e}");
break;
}
}
});
}
}
return;
}
if let Some(del_cmd) = crate::chat::transport::matrix::delete::extract_delete_command(
message,
&ctx.bot_name,
&ctx.bot_user_id,
) {
let response = match del_cmd {
crate::chat::transport::matrix::delete::DeleteCommand::Delete { story_number } => {
slog!("[discord] Handling delete command from {user}: story {story_number}");
crate::chat::transport::matrix::delete::handle_delete(
&ctx.bot_name,
&story_number,
&ctx.project_root,
&ctx.agents,
)
.await
}
crate::chat::transport::matrix::delete::DeleteCommand::BadArgs => {
format!("Usage: `{} delete <number>`", ctx.bot_name)
}
};
let response = markdown_to_discord(&response);
let _ = ctx.transport.send_message(channel, &response, "").await;
return;
}
if crate::chat::transport::matrix::rebuild::extract_rebuild_command(
message,
&ctx.bot_name,
&ctx.bot_user_id,
)
.is_some()
{
slog!("[discord] Handling rebuild command from {user} in {channel}");
let ack = "Rebuilding server… this may take a moment.";
let _ = ctx.transport.send_message(channel, ack, "").await;
let response = crate::chat::transport::matrix::rebuild::handle_rebuild(
&ctx.bot_name,
&ctx.project_root,
&ctx.agents,
)
.await;
let response = markdown_to_discord(&response);
let _ = ctx.transport.send_message(channel, &response, "").await;
return;
}
if let Some(rmtree_cmd) = crate::chat::transport::matrix::rmtree::extract_rmtree_command(
message,
&ctx.bot_name,
&ctx.bot_user_id,
) {
let response = match rmtree_cmd {
crate::chat::transport::matrix::rmtree::RmtreeCommand::Rmtree { story_number } => {
slog!(
"[discord] Handling rmtree command from {user} in {channel}: story {story_number}"
);
crate::chat::transport::matrix::rmtree::handle_rmtree(
&ctx.bot_name,
&story_number,
&ctx.project_root,
&ctx.agents,
)
.await
}
crate::chat::transport::matrix::rmtree::RmtreeCommand::BadArgs => {
format!("Usage: `{} rmtree <number>`", ctx.bot_name)
}
};
let response = markdown_to_discord(&response);
let _ = ctx.transport.send_message(channel, &response, "").await;
return;
}
if crate::chat::transport::matrix::reset::extract_reset_command(
message,
&ctx.bot_name,
&ctx.bot_user_id,
)
.is_some()
{
slog!("[discord] Handling reset command from {user} in {channel}");
{
let mut guard = ctx.history.lock().await;
let conv = guard
.entry(channel.to_string())
.or_insert_with(RoomConversation::default);
conv.session_id = None;
conv.entries.clear();
save_discord_history(&ctx.project_root, &guard);
}
let _ = ctx
.transport
.send_message(channel, "Session cleared.", "")
.await;
return;
}
if let Some(start_cmd) = crate::chat::transport::matrix::start::extract_start_command(
message,
&ctx.bot_name,
&ctx.bot_user_id,
) {
let response = match start_cmd {
crate::chat::transport::matrix::start::StartCommand::Start {
story_number,
agent_hint,
} => {
slog!(
"[discord] Handling start command from {user} in {channel}: story {story_number}"
);
crate::chat::transport::matrix::start::handle_start(
&ctx.bot_name,
&story_number,
agent_hint.as_deref(),
&ctx.project_root,
&ctx.agents,
)
.await
}
crate::chat::transport::matrix::start::StartCommand::BadArgs => {
format!("Usage: `{} start <number>`", ctx.bot_name)
}
};
let response = markdown_to_discord(&response);
let _ = ctx.transport.send_message(channel, &response, "").await;
return;
}
if let Some(assign_cmd) = crate::chat::transport::matrix::assign::extract_assign_command(
message,
&ctx.bot_name,
&ctx.bot_user_id,
) {
let response = match assign_cmd {
crate::chat::transport::matrix::assign::AssignCommand::Assign {
story_number,
model,
} => {
slog!(
"[discord] Handling assign command from {user} in {channel}: story {story_number} model {model}"
);
crate::chat::transport::matrix::assign::handle_assign(
&ctx.bot_name,
&story_number,
&model,
&ctx.project_root,
&ctx.agents,
)
.await
}
crate::chat::transport::matrix::assign::AssignCommand::BadArgs => {
format!("Usage: `{} assign <number> <model>`", ctx.bot_name)
}
};
let response = markdown_to_discord(&response);
let _ = ctx.transport.send_message(channel, &response, "").await;
return;
}
// No command matched — forward to LLM for conversational response.
slog!("[discord] No command matched, forwarding to LLM for {user} in {channel}");
handle_llm_message(ctx, channel, user, message).await;
}
/// Forward a message to Claude Code and send the response back via Discord.
async fn handle_llm_message(
ctx: &DiscordContext,
channel: &str,
user: &str,
user_message: &str,
) {
use crate::chat::util::drain_complete_paragraphs;
use crate::llm::providers::claude_code::{ClaudeCodeProvider, ClaudeCodeResult};
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::watch;
// Look up existing session ID for this channel.
let resume_session_id: Option<String> = {
let guard = ctx.history.lock().await;
guard
.get(channel)
.and_then(|conv| conv.session_id.clone())
};
let bot_name = &ctx.bot_name;
let prompt = format!(
"[Your name is {bot_name}. Refer to yourself as {bot_name}, not Claude.]\n\n{user}: {user_message}"
);
let provider = ClaudeCodeProvider::new();
let (_cancel_tx, mut cancel_rx) = watch::channel(false);
// Channel for sending complete chunks to the Discord posting task.
let (msg_tx, mut msg_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
let msg_tx_for_callback = msg_tx.clone();
// Spawn a task to post messages as they arrive.
let post_transport = Arc::clone(&ctx.transport) as Arc<dyn ChatTransport>;
let post_channel = channel.to_string();
let post_task = tokio::spawn(async move {
while let Some(chunk) = msg_rx.recv().await {
let formatted = markdown_to_discord(&chunk);
let _ = post_transport
.send_message(&post_channel, &formatted, "")
.await;
}
});
// Shared buffer between the sync token callback and the async scope.
let buffer = Arc::new(std::sync::Mutex::new(String::new()));
let buffer_for_callback = Arc::clone(&buffer);
let sent_any_chunk = Arc::new(AtomicBool::new(false));
let sent_any_chunk_for_callback = Arc::clone(&sent_any_chunk);
let project_root_str = ctx.project_root.to_string_lossy().to_string();
let chat_fut = provider.chat_stream(
&prompt,
&project_root_str,
resume_session_id.as_deref(),
None,
&mut cancel_rx,
move |token| {
let mut buf = buffer_for_callback.lock().unwrap();
buf.push_str(token);
let paragraphs = drain_complete_paragraphs(&mut buf);
for chunk in paragraphs {
sent_any_chunk_for_callback.store(true, Ordering::Relaxed);
let _ = msg_tx_for_callback.send(chunk);
}
},
|_thinking| {},
|_activity| {},
);
tokio::pin!(chat_fut);
// Lock the permission receiver for the duration of this chat session.
let mut perm_rx_guard = ctx.perm_rx.lock().await;
let result = loop {
tokio::select! {
r = &mut chat_fut => break r,
Some(perm_fwd) = perm_rx_guard.recv() => {
let prompt_msg = format!(
"**Permission Request**\n\nTool: `{}`\n```json\n{}\n```\n\nReply **yes** to approve or **no** to deny.",
perm_fwd.tool_name,
serde_json::to_string_pretty(&perm_fwd.tool_input)
.unwrap_or_else(|_| perm_fwd.tool_input.to_string()),
);
let formatted = markdown_to_discord(&prompt_msg);
let _ = ctx.transport.send_message(channel, &formatted, "").await;
ctx.pending_perm_replies
.lock()
.await
.insert(channel.to_string(), perm_fwd.response_tx);
// Spawn a timeout task: auto-deny if the user does not respond.
let pending = Arc::clone(&ctx.pending_perm_replies);
let timeout_channel = channel.to_string();
let timeout_transport = Arc::clone(&ctx.transport) as Arc<dyn ChatTransport>;
let timeout_secs = ctx.permission_timeout_secs;
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_secs(timeout_secs)).await;
if let Some(tx) = pending.lock().await.remove(&timeout_channel) {
let _ = tx.send(PermissionDecision::Deny);
let msg = "Permission request timed out — denied (fail-closed).";
let _ = timeout_transport.send_message(&timeout_channel, msg, "").await;
}
});
}
}
};
drop(perm_rx_guard);
// Flush remaining text.
let remaining = buffer.lock().unwrap().trim().to_string();
let did_send_any = sent_any_chunk.load(Ordering::Relaxed);
let (assistant_reply, new_session_id) = match result {
Ok(ClaudeCodeResult {
messages,
session_id,
}) => {
let reply = if !remaining.is_empty() {
let _ = msg_tx.send(remaining.clone());
remaining
} else if !did_send_any {
let last_text = messages
.iter()
.rev()
.find(|m| {
m.role == crate::llm::types::Role::Assistant && !m.content.is_empty()
})
.map(|m| m.content.clone())
.unwrap_or_default();
if !last_text.is_empty() {
let _ = msg_tx.send(last_text.clone());
}
last_text
} else {
remaining
};
slog!("[discord] session_id from chat_stream: {:?}", session_id);
(reply, session_id)
}
Err(e) => {
slog!("[discord] LLM error: {e}");
let err_msg = format!("Error processing your request: {e}");
let _ = msg_tx.send(err_msg.clone());
(err_msg, None)
}
};
// Signal the posting task to finish and wait for it.
drop(msg_tx);
let _ = post_task.await;
// Record this exchange in conversation history.
if !assistant_reply.starts_with("Error processing") {
let mut guard = ctx.history.lock().await;
let conv = guard.entry(channel.to_string()).or_default();
if new_session_id.is_some() {
conv.session_id = new_session_id;
}
conv.entries.push(ConversationEntry {
role: ConversationRole::User,
sender: user.to_string(),
content: user_message.to_string(),
});
conv.entries.push(ConversationEntry {
role: ConversationRole::Assistant,
sender: String::new(),
content: assistant_reply,
});
// Trim to configured maximum.
if conv.entries.len() > ctx.history_size {
let excess = conv.entries.len() - ctx.history_size;
conv.entries.drain(..excess);
}
save_discord_history(&ctx.project_root, &guard);
}
}
// ── Tests ───────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
fn test_agents() -> Arc<crate::agents::AgentPool> {
Arc::new(crate::agents::AgentPool::new_test(3000))
}
fn test_ambient_rooms() -> Arc<Mutex<HashSet<String>>> {
Arc::new(Mutex::new(HashSet::new()))
}
#[test]
fn command_dispatches_through_command_registry() {
use crate::chat::commands::{CommandDispatch, try_handle_command};
let agents = test_agents();
let ambient_rooms = test_ambient_rooms();
let room_id = "123456789".to_string();
let bot_name = "Huskies";
let synthetic = format!("{bot_name} status");
let dispatch = CommandDispatch {
bot_name,
bot_user_id: "discord-bot",
project_root: std::path::Path::new("/tmp"),
agents: &agents,
ambient_rooms: &ambient_rooms,
room_id: &room_id,
};
let result = try_handle_command(&dispatch, &synthetic);
assert!(
result.is_some(),
"status command should produce output via registry"
);
assert!(result.unwrap().contains("Pipeline Status"));
}
#[test]
fn rebuild_command_extracted_from_discord_message() {
let result = crate::chat::transport::matrix::rebuild::extract_rebuild_command(
"Huskies rebuild",
"Huskies",
"discord-bot",
);
assert!(result.is_some(), "'Huskies rebuild' should be recognised");
}
#[test]
fn reset_command_extracted_from_discord_message() {
let result = crate::chat::transport::matrix::reset::extract_reset_command(
"Huskies reset",
"Huskies",
"discord-bot",
);
assert!(result.is_some(), "'Huskies reset' should be recognised");
}
#[test]
fn start_command_extracted_from_discord_message() {
let result = crate::chat::transport::matrix::start::extract_start_command(
"start 42",
"Huskies",
"discord-bot",
);
assert!(result.is_some(), "plain 'start 42' should be recognised");
assert_eq!(
result,
Some(crate::chat::transport::matrix::start::StartCommand::Start {
story_number: "42".to_string(),
agent_hint: None,
})
);
}
#[test]
fn assign_command_extracted_from_discord_message() {
let result = crate::chat::transport::matrix::assign::extract_assign_command(
"assign 42 opus",
"Huskies",
"discord-bot",
);
assert!(
matches!(
result,
Some(crate::chat::transport::matrix::assign::AssignCommand::Assign { .. })
),
"plain 'assign 42 opus' should be recognised on Discord"
);
}
#[tokio::test]
async fn reset_command_clears_discord_session() {
use crate::chat::transport::matrix::RoomConversation;
let channel = "123456789";
let history: DiscordConversationHistory = Arc::new(TokioMutex::new({
let mut m = HashMap::new();
m.insert(
channel.to_string(),
RoomConversation {
session_id: Some("old-session".to_string()),
entries: vec![ConversationEntry {
role: ConversationRole::User,
sender: "user123".to_string(),
content: "previous message".to_string(),
}],
},
);
m
}));
let tmp = tempfile::tempdir().unwrap();
let sk = tmp.path().join(".huskies");
std::fs::create_dir_all(&sk).unwrap();
{
let mut guard = history.lock().await;
let conv = guard
.entry(channel.to_string())
.or_insert_with(RoomConversation::default);
conv.session_id = None;
conv.entries.clear();
save_discord_history(tmp.path(), &guard);
}
let guard = history.lock().await;
let conv = guard.get(channel).unwrap();
assert!(conv.session_id.is_none(), "session_id should be cleared");
assert!(conv.entries.is_empty(), "entries should be cleared");
}
}
@@ -0,0 +1,55 @@
//! Markdown to Discord format conversion.
//!
//! Discord supports standard Markdown natively, so unlike Slack we only need
//! minimal adjustments (e.g. truncating messages that exceed Discord's limit).
/// Convert Markdown text to Discord-compatible format.
///
/// Discord supports most standard Markdown:
/// - `**bold**`, `*italic*`, `~~strikethrough~~`
/// - `` `code` `` and ````code blocks````
/// - `> blockquotes`
/// - `[text](url)` links
///
/// The main adjustment is ensuring messages stay within Discord's 2000-char
/// limit. Actual truncation is handled by the transport layer; this function
/// performs any lightweight text normalization needed.
pub fn markdown_to_discord(text: &str) -> String {
use crate::chat::util::normalize_line_breaks;
normalize_line_breaks(text)
}
// ── Tests ───────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn discord_plain_text_unchanged() {
let plain = "Hello, this is a plain message.";
assert_eq!(markdown_to_discord(plain), plain);
}
#[test]
fn discord_bold_preserved() {
assert_eq!(markdown_to_discord("**bold text**"), "**bold text**");
}
#[test]
fn discord_code_block_preserved() {
let input = "```rust\nlet x = 1;\n```";
assert_eq!(markdown_to_discord(input), input);
}
#[test]
fn discord_empty_string_unchanged() {
assert_eq!(markdown_to_discord(""), "");
}
#[test]
fn discord_link_preserved() {
let input = "[click here](https://example.com)";
assert_eq!(markdown_to_discord(input), input);
}
}
@@ -0,0 +1,507 @@
//! Minimal Discord Gateway WebSocket client.
//!
//! Connects to the Discord Gateway, authenticates with a bot token, maintains
//! the heartbeat keepalive, and dispatches incoming `MESSAGE_CREATE` events to
//! the command handler.
use std::sync::Arc;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use tokio_tungstenite::tungstenite::Message as WsMessage;
use crate::slog;
use super::commands::{self, DiscordContext};
// ── Gateway opcodes ─────────────────────────────────────────────────────
const OP_DISPATCH: u8 = 0;
const OP_HEARTBEAT: u8 = 1;
const OP_IDENTIFY: u8 = 2;
const OP_RECONNECT: u8 = 7;
const OP_INVALID_SESSION: u8 = 9;
const OP_HELLO: u8 = 10;
const OP_HEARTBEAT_ACK: u8 = 11;
// ── Gateway intents ─────────────────────────────────────────────────────
/// GUILDS | GUILD_MESSAGES | MESSAGE_CONTENT (privileged)
const GATEWAY_INTENTS: u64 = (1 << 0) | (1 << 9) | (1 << 15);
// ── Gateway payload types ───────────────────────────────────────────────
#[derive(Deserialize, Debug)]
struct GatewayPayload {
op: u8,
#[serde(default)]
d: Option<serde_json::Value>,
#[serde(default)]
s: Option<u64>,
#[serde(default)]
t: Option<String>,
}
#[derive(Deserialize, Debug)]
struct HelloData {
heartbeat_interval: u64,
}
#[derive(Deserialize, Debug)]
pub(super) struct MessageCreateData {
#[serde(default)]
pub channel_id: String,
#[serde(default)]
pub content: String,
#[serde(default)]
pub author: Option<MessageAuthor>,
#[serde(default)]
pub mentions: Vec<MentionUser>,
}
#[derive(Deserialize, Debug)]
pub(super) struct MessageAuthor {
pub id: String,
#[serde(default)]
pub bot: Option<bool>,
}
#[derive(Deserialize, Debug)]
pub(super) struct MentionUser {
pub id: String,
}
#[derive(Deserialize, Debug)]
struct ReadyData {
user: ReadyUser,
}
#[derive(Deserialize, Debug)]
struct ReadyUser {
id: String,
}
// ── Identify payload ────────────────────────────────────────────────────
#[derive(Serialize)]
struct IdentifyPayload {
op: u8,
d: IdentifyData,
}
#[derive(Serialize)]
struct IdentifyData {
token: String,
intents: u64,
properties: IdentifyProperties,
}
#[derive(Serialize)]
struct IdentifyProperties {
os: &'static str,
browser: &'static str,
device: &'static str,
}
#[derive(Serialize)]
struct HeartbeatPayload {
op: u8,
d: Option<u64>,
}
// ── Gateway URL ─────────────────────────────────────────────────────────
const GATEWAY_URL: &str = "wss://gateway.discord.gg/?v=10&encoding=json";
// ── Public entry point ──────────────────────────────────────────────────
/// Spawn a background task that connects to the Discord Gateway, maintains
/// the heartbeat, and dispatches incoming messages.
pub fn spawn_gateway(ctx: Arc<DiscordContext>) {
tokio::spawn(async move {
loop {
if let Err(e) = run_gateway(Arc::clone(&ctx)).await {
slog!("[discord] Gateway connection ended: {e}");
}
// Wait before reconnecting.
slog!("[discord] Reconnecting in 5 seconds…");
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
}
});
}
/// Connect to the Discord Gateway and process events until the connection
/// closes or an unrecoverable error occurs.
async fn run_gateway(ctx: Arc<DiscordContext>) -> Result<(), String> {
slog!("[discord] Connecting to Discord Gateway…");
let (ws_stream, _) = tokio_tungstenite::connect_async(GATEWAY_URL)
.await
.map_err(|e| format!("Gateway connect failed: {e}"))?;
slog!("[discord] Gateway WebSocket connected");
let (mut write, mut read) = ws_stream.split();
// Wait for Hello (op 10) to get the heartbeat interval.
let hello = read
.next()
.await
.ok_or("Gateway closed before Hello")?
.map_err(|e| format!("Gateway read error: {e}"))?;
let hello_payload: GatewayPayload =
parse_ws_message(&hello).ok_or("Failed to parse Hello")?;
if hello_payload.op != OP_HELLO {
return Err(format!(
"Expected Hello (op 10), got op {}",
hello_payload.op
));
}
let hello_data: HelloData =
serde_json::from_value(hello_payload.d.ok_or("Hello missing data")?)
.map_err(|e| format!("Failed to parse Hello data: {e}"))?;
let heartbeat_interval =
std::time::Duration::from_millis(hello_data.heartbeat_interval);
slog!(
"[discord] Heartbeat interval: {}ms",
hello_data.heartbeat_interval
);
// Send Identify.
let identify = IdentifyPayload {
op: OP_IDENTIFY,
d: IdentifyData {
token: ctx.bot_token.clone(),
intents: GATEWAY_INTENTS,
properties: IdentifyProperties {
os: "linux",
browser: "huskies",
device: "huskies",
},
},
};
let identify_json = serde_json::to_string(&identify)
.map_err(|e| format!("Failed to serialise Identify: {e}"))?;
write
.send(WsMessage::Text(identify_json.into()))
.await
.map_err(|e| format!("Failed to send Identify: {e}"))?;
slog!("[discord] Identify sent, waiting for Ready…");
// Track sequence number for heartbeats.
let sequence = Arc::new(std::sync::atomic::AtomicU64::new(0));
let seq_for_heartbeat = Arc::clone(&sequence);
// Share the write half between the heartbeat task and the main event loop.
let write = Arc::new(tokio::sync::Mutex::new(write));
let write_for_heartbeat = Arc::clone(&write);
// Spawn heartbeat task.
let heartbeat_handle = tokio::spawn(async move {
// Jitter: wait a random fraction of the interval before the first beat.
let jitter = heartbeat_interval.mul_f64(rand_fraction());
tokio::time::sleep(jitter).await;
loop {
let seq = seq_for_heartbeat.load(std::sync::atomic::Ordering::Relaxed);
let seq_val = if seq == 0 { None } else { Some(seq) };
let hb = HeartbeatPayload {
op: OP_HEARTBEAT,
d: seq_val,
};
if let Ok(json) = serde_json::to_string(&hb) {
let mut w = write_for_heartbeat.lock().await;
if w.send(WsMessage::Text(json.into())).await.is_err() {
slog!("[discord] Heartbeat send failed, stopping heartbeat");
break;
}
}
tokio::time::sleep(heartbeat_interval).await;
}
});
// Track our own bot user ID (set from READY event).
let mut bot_user_id: Option<String> = None;
// Main event loop.
while let Some(msg_result) = read.next().await {
let msg = match msg_result {
Ok(m) => m,
Err(e) => {
slog!("[discord] Gateway read error: {e}");
break;
}
};
let payload: GatewayPayload = match parse_ws_message(&msg) {
Some(p) => p,
None => continue,
};
// Update sequence number.
if let Some(s) = payload.s {
sequence.store(s, std::sync::atomic::Ordering::Relaxed);
}
match payload.op {
OP_DISPATCH => {
let event_name = payload.t.as_deref().unwrap_or("");
match event_name {
"READY" => {
if let Some(d) = payload.d
&& let Ok(ready) = serde_json::from_value::<ReadyData>(d)
{
bot_user_id = Some(ready.user.id.clone());
slog!(
"[discord] READY — bot user ID: {}",
ready.user.id
);
}
}
"MESSAGE_CREATE" => {
if let Some(d) = payload.d {
dispatch_message(
Arc::clone(&ctx),
d,
bot_user_id.clone(),
);
}
}
_ => {}
}
}
OP_HEARTBEAT => {
// Server requested an immediate heartbeat.
let seq = sequence.load(std::sync::atomic::Ordering::Relaxed);
let seq_val = if seq == 0 { None } else { Some(seq) };
let hb = HeartbeatPayload {
op: OP_HEARTBEAT,
d: seq_val,
};
if let Ok(json) = serde_json::to_string(&hb) {
let mut w = write.lock().await;
let _ = w.send(WsMessage::Text(json.into())).await;
}
}
OP_HEARTBEAT_ACK => {}
OP_RECONNECT => {
slog!("[discord] Gateway requested reconnect");
break;
}
OP_INVALID_SESSION => {
slog!("[discord] Invalid session — will reconnect");
// Wait a bit before reconnecting for invalid sessions.
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
break;
}
_ => {}
}
}
heartbeat_handle.abort();
slog!("[discord] Gateway event loop ended");
Ok(())
}
// ── Helpers ─────────────────────────────────────────────────────────────
fn parse_ws_message(msg: &WsMessage) -> Option<GatewayPayload> {
match msg {
WsMessage::Text(text) => serde_json::from_str(text).ok(),
_ => None,
}
}
/// Filter and dispatch an incoming MESSAGE_CREATE event on a background task.
fn dispatch_message(
ctx: Arc<DiscordContext>,
data: serde_json::Value,
bot_user_id: Option<String>,
) {
tokio::spawn(async move {
let msg: MessageCreateData = match serde_json::from_value(data) {
Ok(m) => m,
Err(e) => {
slog!("[discord] Failed to parse MESSAGE_CREATE: {e}");
return;
}
};
let author = match msg.author.as_ref() {
Some(a) => a,
None => return,
};
// Ignore bot messages (including our own).
if author.bot.unwrap_or(false) {
return;
}
// Check if the channel is in our configured channel list.
if !ctx.channel_ids.contains(&msg.channel_id) {
return;
}
// Check allowed_users (if non-empty, only listed users can interact).
if !ctx.allowed_users.is_empty() && !ctx.allowed_users.contains(&author.id) {
return;
}
// Check if the bot was mentioned, or if we respond to all messages in
// configured channels (ambient mode).
let bot_mentioned = bot_user_id.as_ref().is_some_and(|bid| {
msg.mentions.iter().any(|m| m.id == *bid)
});
let in_ambient = ctx
.ambient_rooms
.lock()
.unwrap()
.contains(&msg.channel_id);
if !bot_mentioned && !in_ambient {
return;
}
// Strip the bot mention from the message content if present.
let content = if let Some(bid) = bot_user_id.as_ref() {
let mention_pat = format!("<@{bid}>");
let mention_pat_nick = format!("<@!{bid}>");
msg.content
.replace(&mention_pat, "")
.replace(&mention_pat_nick, "")
.trim()
.to_string()
} else {
msg.content.clone()
};
if content.is_empty() {
return;
}
slog!(
"[discord] Message from {} in {}: {content:.80}",
author.id,
msg.channel_id
);
commands::handle_incoming_message(&ctx, &msg.channel_id, &author.id, &content)
.await;
});
}
/// Cheap pseudo-random fraction [0.0, 1.0) for heartbeat jitter.
fn rand_fraction() -> f64 {
let nanos = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.subsec_nanos();
(nanos as f64) / 1_000_000_000.0
}
// ── Tests ───────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_hello_payload() {
let json = r#"{"op": 10, "d": {"heartbeat_interval": 41250}}"#;
let payload: GatewayPayload = serde_json::from_str(json).unwrap();
assert_eq!(payload.op, OP_HELLO);
let hello: HelloData =
serde_json::from_value(payload.d.unwrap()).unwrap();
assert_eq!(hello.heartbeat_interval, 41250);
}
#[test]
fn parse_message_create() {
let json = r#"{
"channel_id": "123456",
"content": "hello bot",
"author": {"id": "789", "bot": false},
"mentions": [{"id": "111"}]
}"#;
let msg: MessageCreateData = serde_json::from_str(json).unwrap();
assert_eq!(msg.channel_id, "123456");
assert_eq!(msg.content, "hello bot");
assert_eq!(msg.author.as_ref().unwrap().id, "789");
assert!(!msg.author.as_ref().unwrap().bot.unwrap_or(false));
assert_eq!(msg.mentions.len(), 1);
assert_eq!(msg.mentions[0].id, "111");
}
#[test]
fn parse_ready_event() {
let json = r#"{"user": {"id": "999888"}}"#;
let ready: ReadyData = serde_json::from_str(json).unwrap();
assert_eq!(ready.user.id, "999888");
}
#[test]
fn identify_payload_serializes_correctly() {
let identify = IdentifyPayload {
op: OP_IDENTIFY,
d: IdentifyData {
token: "test-token".to_string(),
intents: GATEWAY_INTENTS,
properties: IdentifyProperties {
os: "linux",
browser: "huskies",
device: "huskies",
},
},
};
let json: serde_json::Value =
serde_json::from_str(&serde_json::to_string(&identify).unwrap()).unwrap();
assert_eq!(json["op"], 2);
assert_eq!(json["d"]["token"], "test-token");
assert_eq!(json["d"]["intents"], GATEWAY_INTENTS);
assert_eq!(json["d"]["properties"]["browser"], "huskies");
}
#[test]
fn heartbeat_payload_serializes_with_sequence() {
let hb = HeartbeatPayload {
op: OP_HEARTBEAT,
d: Some(42),
};
let json: serde_json::Value =
serde_json::from_str(&serde_json::to_string(&hb).unwrap()).unwrap();
assert_eq!(json["op"], 1);
assert_eq!(json["d"], 42);
}
#[test]
fn heartbeat_payload_serializes_null_sequence() {
let hb = HeartbeatPayload {
op: OP_HEARTBEAT,
d: None,
};
let json: serde_json::Value =
serde_json::from_str(&serde_json::to_string(&hb).unwrap()).unwrap();
assert_eq!(json["op"], 1);
assert!(json["d"].is_null());
}
#[test]
fn rand_fraction_in_range() {
let f = rand_fraction();
assert!((0.0..1.0).contains(&f));
}
#[test]
fn gateway_intents_correct() {
// GUILDS (1) | GUILD_MESSAGES (512) | MESSAGE_CONTENT (32768)
assert_eq!(GATEWAY_INTENTS, 1 | 512 | 32768);
assert_eq!(GATEWAY_INTENTS, 33281);
}
}
@@ -0,0 +1,119 @@
//! Discord conversation history persistence.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex as TokioMutex;
use crate::chat::transport::matrix::RoomConversation;
use crate::slog;
/// Per-channel conversation history, keyed by channel ID.
pub type DiscordConversationHistory = Arc<TokioMutex<HashMap<String, RoomConversation>>>;
/// On-disk format for persisted Discord conversation history.
#[derive(Serialize, Deserialize)]
struct PersistedDiscordHistory {
channels: HashMap<String, RoomConversation>,
}
/// Path to the persisted Discord conversation history file.
const DISCORD_HISTORY_FILE: &str = ".huskies/discord_history.json";
/// Load Discord conversation history from disk.
pub fn load_discord_history(project_root: &std::path::Path) -> HashMap<String, RoomConversation> {
let path = project_root.join(DISCORD_HISTORY_FILE);
let data = match std::fs::read_to_string(&path) {
Ok(d) => d,
Err(_) => return HashMap::new(),
};
let persisted: PersistedDiscordHistory = match serde_json::from_str(&data) {
Ok(p) => p,
Err(e) => {
slog!("[discord] Failed to parse history file: {e}");
return HashMap::new();
}
};
persisted.channels
}
/// Save Discord conversation history to disk.
pub(super) fn save_discord_history(
project_root: &std::path::Path,
history: &HashMap<String, RoomConversation>,
) {
let persisted = PersistedDiscordHistory {
channels: history.clone(),
};
let path = project_root.join(DISCORD_HISTORY_FILE);
match serde_json::to_string_pretty(&persisted) {
Ok(json) => {
if let Err(e) = std::fs::write(&path, json) {
slog!("[discord] Failed to write history file: {e}");
}
}
Err(e) => slog!("[discord] Failed to serialise history: {e}"),
}
}
// ── Tests ───────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use crate::chat::transport::matrix::{ConversationEntry, ConversationRole};
#[test]
fn save_and_load_discord_history_round_trips() {
let tmp = tempfile::tempdir().unwrap();
let sk = tmp.path().join(".huskies");
std::fs::create_dir_all(&sk).unwrap();
let mut history = HashMap::new();
history.insert(
"123456789".to_string(),
RoomConversation {
session_id: Some("sess-abc".to_string()),
entries: vec![
ConversationEntry {
role: ConversationRole::User,
sender: "user123".to_string(),
content: "hello".to_string(),
},
ConversationEntry {
role: ConversationRole::Assistant,
sender: String::new(),
content: "hi there!".to_string(),
},
],
},
);
save_discord_history(tmp.path(), &history);
let loaded = load_discord_history(tmp.path());
assert_eq!(loaded.len(), 1);
let conv = loaded.get("123456789").unwrap();
assert_eq!(conv.session_id.as_deref(), Some("sess-abc"));
assert_eq!(conv.entries.len(), 2);
assert_eq!(conv.entries[0].content, "hello");
assert_eq!(conv.entries[1].content, "hi there!");
}
#[test]
fn load_discord_history_returns_empty_when_file_missing() {
let tmp = tempfile::tempdir().unwrap();
let history = load_discord_history(tmp.path());
assert!(history.is_empty());
}
#[test]
fn load_discord_history_returns_empty_on_invalid_json() {
let tmp = tempfile::tempdir().unwrap();
let sk = tmp.path().join(".huskies");
std::fs::create_dir_all(&sk).unwrap();
std::fs::write(sk.join("discord_history.json"), "not json {{{").unwrap();
let history = load_discord_history(tmp.path());
assert!(history.is_empty());
}
}
+302
View File
@@ -0,0 +1,302 @@
//! DiscordTransport — ChatTransport implementation for the Discord Bot API.
use async_trait::async_trait;
use serde::Deserialize;
use crate::chat::{ChatTransport, MessageId};
use crate::slog;
// ── Discord API base URL (overridable for tests) ──────────────────────
const DISCORD_API_BASE: &str = "https://discord.com/api/v10";
// ── DiscordTransport ──────────────────────────────────────────────────
/// Discord Bot API transport.
///
/// Sends messages via `POST {API_BASE}/channels/{channel_id}/messages`,
/// edits via `PATCH {API_BASE}/channels/{channel_id}/messages/{message_id}`,
/// and typing indicators via `POST {API_BASE}/channels/{channel_id}/typing`.
pub struct DiscordTransport {
bot_token: String,
client: reqwest::Client,
/// Optional base URL override for tests.
api_base: String,
}
impl DiscordTransport {
pub fn new(bot_token: String) -> Self {
Self {
bot_token,
client: reqwest::Client::new(),
api_base: DISCORD_API_BASE.to_string(),
}
}
#[cfg(test)]
fn with_api_base(bot_token: String, api_base: String) -> Self {
Self {
bot_token,
client: reqwest::Client::new(),
api_base,
}
}
/// Authorization header value: `Bot {token}`.
fn auth_header(&self) -> String {
format!("Bot {}", self.bot_token)
}
}
// ── Discord API response types ────────────────────────────────────────
#[derive(Deserialize, Debug)]
struct DiscordMessage {
id: String,
}
// ── ChatTransport implementation ──────────────────────────────────────
#[async_trait]
impl ChatTransport for DiscordTransport {
async fn send_message(
&self,
channel_id: &str,
plain: &str,
_html: &str,
) -> Result<MessageId, String> {
slog!("[discord] send_message to {channel_id}: {plain:.80}");
let url = format!("{}/channels/{}/messages", self.api_base, channel_id);
// Discord messages have a 2000-char limit. Truncate if needed.
let content = if plain.len() > 2000 {
format!("{}", &plain[..1999])
} else {
plain.to_string()
};
let payload = serde_json::json!({ "content": content });
let resp = self
.client
.post(&url)
.header("Authorization", self.auth_header())
.json(&payload)
.send()
.await
.map_err(|e| format!("Discord API request failed: {e}"))?;
let status = resp.status();
let resp_text = resp
.text()
.await
.unwrap_or_else(|_| "<no body>".to_string());
if !status.is_success() {
return Err(format!("Discord API returned {status}: {resp_text}"));
}
let msg: DiscordMessage = serde_json::from_str(&resp_text).map_err(|e| {
format!("Failed to parse Discord API response: {e} — body: {resp_text}")
})?;
Ok(msg.id)
}
async fn edit_message(
&self,
channel_id: &str,
original_message_id: &str,
plain: &str,
_html: &str,
) -> Result<(), String> {
slog!("[discord] edit_message in {channel_id}: id={original_message_id}");
let url = format!(
"{}/channels/{}/messages/{}",
self.api_base, channel_id, original_message_id
);
let content = if plain.len() > 2000 {
format!("{}", &plain[..1999])
} else {
plain.to_string()
};
let payload = serde_json::json!({ "content": content });
let resp = self
.client
.patch(&url)
.header("Authorization", self.auth_header())
.json(&payload)
.send()
.await
.map_err(|e| format!("Discord edit request failed: {e}"))?;
let status = resp.status();
let resp_text = resp
.text()
.await
.unwrap_or_else(|_| "<no body>".to_string());
if !status.is_success() {
return Err(format!("Discord edit returned {status}: {resp_text}"));
}
Ok(())
}
async fn send_typing(&self, channel_id: &str, typing: bool) -> Result<(), String> {
if !typing {
// Discord doesn't have a "stop typing" API — it times out after ~10s.
return Ok(());
}
let url = format!("{}/channels/{}/typing", self.api_base, channel_id);
self.client
.post(&url)
.header("Authorization", self.auth_header())
.send()
.await
.map_err(|e| format!("Discord typing request failed: {e}"))?;
Ok(())
}
}
// ── Tests ───────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
#[tokio::test]
async fn transport_send_message_calls_discord_api() {
let mut server = mockito::Server::new_async().await;
let mock = server
.mock("POST", "/channels/123456/messages")
.match_header("authorization", "Bot test-token")
.with_body(r#"{"id": "999888777"}"#)
.create_async()
.await;
let transport =
DiscordTransport::with_api_base("test-token".to_string(), server.url());
let result = transport
.send_message("123456", "hello", "<p>hello</p>")
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "999888777");
mock.assert_async().await;
}
#[tokio::test]
async fn transport_send_message_handles_api_error() {
let mut server = mockito::Server::new_async().await;
server
.mock("POST", "/channels/bad/messages")
.with_status(404)
.with_body(r#"{"message": "Unknown Channel", "code": 10003}"#)
.create_async()
.await;
let transport =
DiscordTransport::with_api_base("test-token".to_string(), server.url());
let result = transport.send_message("bad", "hello", "").await;
assert!(result.is_err());
assert!(result.unwrap_err().contains("404"));
}
#[tokio::test]
async fn transport_edit_message_calls_patch() {
let mut server = mockito::Server::new_async().await;
let mock = server
.mock("PATCH", "/channels/123456/messages/999888777")
.match_header("authorization", "Bot test-token")
.with_body(r#"{"id": "999888777"}"#)
.create_async()
.await;
let transport =
DiscordTransport::with_api_base("test-token".to_string(), server.url());
let result = transport
.edit_message("123456", "999888777", "updated", "")
.await;
assert!(result.is_ok());
mock.assert_async().await;
}
#[tokio::test]
async fn transport_edit_message_handles_error() {
let mut server = mockito::Server::new_async().await;
server
.mock("PATCH", "/channels/123456/messages/bad")
.with_status(404)
.with_body(r#"{"message": "Unknown Message", "code": 10008}"#)
.create_async()
.await;
let transport =
DiscordTransport::with_api_base("test-token".to_string(), server.url());
let result = transport
.edit_message("123456", "bad", "updated", "")
.await;
assert!(result.is_err());
assert!(result.unwrap_err().contains("404"));
}
#[tokio::test]
async fn transport_send_typing_succeeds() {
let mut server = mockito::Server::new_async().await;
server
.mock("POST", "/channels/123456/typing")
.with_status(204)
.create_async()
.await;
let transport =
DiscordTransport::with_api_base("test-token".to_string(), server.url());
assert!(transport.send_typing("123456", true).await.is_ok());
}
#[tokio::test]
async fn transport_send_typing_false_is_noop() {
let transport = DiscordTransport::new("test-token".to_string());
assert!(transport.send_typing("123456", false).await.is_ok());
}
#[tokio::test]
async fn transport_handles_http_error() {
let mut server = mockito::Server::new_async().await;
server
.mock("POST", "/channels/123456/messages")
.with_status(500)
.with_body("Internal Server Error")
.create_async()
.await;
let transport =
DiscordTransport::with_api_base("test-token".to_string(), server.url());
let result = transport.send_message("123456", "hello", "").await;
assert!(result.is_err());
assert!(result.unwrap_err().contains("500"));
}
// ── ChatTransport trait satisfaction ─────────────────────────────────
#[test]
fn discord_transport_satisfies_trait() {
fn assert_transport<T: ChatTransport>() {}
assert_transport::<DiscordTransport>();
let _: Arc<dyn ChatTransport> =
Arc::new(DiscordTransport::new("test-token".to_string()));
}
}
+18
View File
@@ -0,0 +1,18 @@
//! Discord Bot integration.
//!
//! Provides:
//! - [`DiscordTransport`] — a [`ChatTransport`] that sends messages via the
//! Discord REST API (`/channels/{id}/messages`).
//! - [`gateway::spawn_gateway`] — connects to the Discord Gateway WebSocket,
//! receives `MESSAGE_CREATE` events, and dispatches commands.
//! - [`DiscordContext`] — shared context for the bot.
pub mod commands;
pub mod format;
pub mod gateway;
pub mod history;
pub mod meta;
pub use commands::DiscordContext;
pub use history::load_discord_history;
pub use meta::DiscordTransport;
+127 -1
View File
@@ -64,7 +64,8 @@ pub struct BotConfig {
/// manually while the bot is running.
#[serde(default)]
pub ambient_rooms: Vec<String>,
/// Chat transport to use: `"matrix"` (default) or `"whatsapp"`.
/// Chat transport to use: `"matrix"` (default), `"whatsapp"`, `"slack"`,
/// or `"discord"`.
///
/// Selects which [`ChatTransport`] implementation the bot uses for
/// sending and editing messages. Currently only read during bot
@@ -134,6 +135,20 @@ pub struct BotConfig {
/// Slack channel IDs the bot should listen in.
#[serde(default)]
pub slack_channel_ids: Vec<String>,
// ── Discord Bot API fields ──────────────────────────────────────
// These are only required when `transport = "discord"`.
/// Discord bot token from the Discord Developer Portal.
#[serde(default)]
pub discord_bot_token: Option<String>,
/// Discord channel IDs the bot should listen in.
#[serde(default)]
pub discord_channel_ids: Vec<String>,
/// Discord user IDs allowed to interact with the bot.
/// When empty or absent, all users in configured channels are allowed.
#[serde(default)]
pub discord_allowed_users: Vec<String>,
}
fn default_transport() -> String {
@@ -241,6 +256,22 @@ impl BotConfig {
);
return None;
}
} else if config.transport == "discord" {
// Validate Discord-specific fields.
if config.discord_bot_token.as_ref().is_none_or(|s| s.is_empty()) {
eprintln!(
"[bot] bot.toml: transport=\"discord\" requires \
discord_bot_token"
);
return None;
}
if config.discord_channel_ids.is_empty() {
eprintln!(
"[bot] bot.toml: transport=\"discord\" requires \
at least one discord_channel_ids entry"
);
return None;
}
} else {
// Default transport is Matrix — validate Matrix-specific fields.
if config.homeserver.as_ref().is_none_or(|s| s.is_empty()) {
@@ -1054,4 +1085,99 @@ whatsapp_allowed_phones = ["+15551234567", "+15559876543"]
vec!["+15551234567", "+15559876543"]
);
}
// ── Discord config tests ──────────────────────────────────────────
#[test]
fn load_discord_transport_reads_config() {
let tmp = tempfile::tempdir().unwrap();
let sk = tmp.path().join(".huskies");
fs::create_dir_all(&sk).unwrap();
fs::write(
sk.join("bot.toml"),
r#"
enabled = true
transport = "discord"
discord_bot_token = "Bot.Token.Here"
discord_channel_ids = ["123456789012345678"]
"#,
)
.unwrap();
let config = BotConfig::load(tmp.path()).unwrap();
assert_eq!(config.transport, "discord");
assert_eq!(
config.discord_bot_token.as_deref(),
Some("Bot.Token.Here")
);
assert_eq!(
config.discord_channel_ids,
vec!["123456789012345678"]
);
}
#[test]
fn load_discord_returns_none_when_missing_bot_token() {
let tmp = tempfile::tempdir().unwrap();
let sk = tmp.path().join(".huskies");
fs::create_dir_all(&sk).unwrap();
fs::write(
sk.join("bot.toml"),
r#"
enabled = true
transport = "discord"
discord_channel_ids = ["123456789012345678"]
"#,
)
.unwrap();
assert!(BotConfig::load(tmp.path()).is_none());
}
#[test]
fn load_discord_returns_none_when_missing_channel_ids() {
let tmp = tempfile::tempdir().unwrap();
let sk = tmp.path().join(".huskies");
fs::create_dir_all(&sk).unwrap();
fs::write(
sk.join("bot.toml"),
r#"
enabled = true
transport = "discord"
discord_bot_token = "Bot.Token.Here"
"#,
)
.unwrap();
assert!(BotConfig::load(tmp.path()).is_none());
}
#[test]
fn discord_allowed_users_defaults_to_empty_when_absent() {
let config: BotConfig = toml::from_str(
r#"
enabled = true
transport = "discord"
discord_bot_token = "Bot.Token.Here"
discord_channel_ids = ["123456789"]
"#,
)
.unwrap();
assert!(config.discord_allowed_users.is_empty());
}
#[test]
fn discord_allowed_users_deserializes_list() {
let config: BotConfig = toml::from_str(
r#"
enabled = true
transport = "discord"
discord_bot_token = "Bot.Token.Here"
discord_channel_ids = ["123456789"]
discord_allowed_users = ["111222333", "444555666"]
"#,
)
.unwrap();
assert_eq!(
config.discord_allowed_users,
vec!["111222333", "444555666"]
);
}
}
+1
View File
@@ -1,3 +1,4 @@
pub mod discord;
pub mod matrix;
pub mod slack;
pub mod whatsapp;
+76
View File
@@ -339,12 +339,14 @@ async fn main() -> Result<(), std::io::Error> {
// before watcher_tx is moved into AppContext.
let watcher_rx_for_whatsapp = watcher_tx.subscribe();
let watcher_rx_for_slack = watcher_tx.subscribe();
let watcher_rx_for_discord = watcher_tx.subscribe();
// Wrap perm_rx in Arc<Mutex> so it can be shared with both the WebSocket
// handler (via AppContext) and the Matrix bot.
let perm_rx = Arc::new(tokio::sync::Mutex::new(perm_rx));
let perm_rx_for_bot = Arc::clone(&perm_rx);
let perm_rx_for_whatsapp = Arc::clone(&perm_rx);
let perm_rx_for_slack = Arc::clone(&perm_rx);
let perm_rx_for_discord = Arc::clone(&perm_rx);
// Capture project root, agents Arc, and reconciliation sender before ctx
// is consumed by build_routes.
@@ -441,9 +443,49 @@ async fn main() -> Result<(), std::io::Error> {
})
});
// Build Discord context if bot.toml configures transport = "discord".
let discord_ctx: Option<Arc<chat::transport::discord::DiscordContext>> = startup_root
.as_ref()
.and_then(|root| chat::transport::matrix::BotConfig::load(root))
.filter(|cfg| cfg.transport == "discord")
.map(|cfg| {
let transport = Arc::new(chat::transport::discord::DiscordTransport::new(
cfg.discord_bot_token.clone().unwrap_or_default(),
));
let bot_name = cfg
.display_name
.clone()
.unwrap_or_else(|| "Assistant".to_string());
let root = startup_root.clone().unwrap();
let history = chat::transport::discord::load_discord_history(&root);
let channel_ids: std::collections::HashSet<String> =
cfg.discord_channel_ids.iter().cloned().collect();
let allowed_users: std::collections::HashSet<String> =
cfg.discord_allowed_users.iter().cloned().collect();
Arc::new(chat::transport::discord::DiscordContext {
bot_token: cfg.discord_bot_token.clone().unwrap_or_default(),
transport,
project_root: root,
agents: Arc::clone(&startup_agents),
bot_name,
bot_user_id: "discord-bot".to_string(),
ambient_rooms: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
history: std::sync::Arc::new(tokio::sync::Mutex::new(history)),
history_size: cfg.history_size,
channel_ids,
allowed_users,
perm_rx: perm_rx_for_discord,
pending_perm_replies: Arc::new(tokio::sync::Mutex::new(
std::collections::HashMap::new(),
)),
permission_timeout_secs: cfg.permission_timeout_secs,
})
});
// Build a best-effort shutdown notifier for webhook-based transports.
//
// • Slack: channels are fixed at startup (channel_ids from bot.toml).
// • Discord: channels are fixed at startup (channel_ids from bot.toml).
// • WhatsApp: active senders are tracked at runtime in ambient_rooms.
// We keep the WhatsApp context Arc so we can read the rooms at shutdown.
// • Matrix: the bot task manages its own announcement via matrix_shutdown_tx.
@@ -454,6 +496,13 @@ async fn main() -> Result<(), std::io::Error> {
channels,
ctx.bot_name.clone(),
)))
} else if let Some(ref ctx) = discord_ctx {
let channels: Vec<String> = ctx.channel_ids.iter().cloned().collect();
Some(Arc::new(BotShutdownNotifier::new(
Arc::clone(&ctx.transport) as Arc<dyn crate::chat::ChatTransport>,
channels,
ctx.bot_name.clone(),
)))
} else {
None
};
@@ -496,6 +545,18 @@ async fn main() -> Result<(), std::io::Error> {
notifier.notify_startup().await;
});
}
if let Some(ref ctx) = discord_ctx {
let transport = Arc::clone(&ctx.transport) as Arc<dyn crate::chat::ChatTransport>;
let bot_name = ctx.bot_name.clone();
let channels: Vec<String> = ctx.channel_ids.iter().cloned().collect();
tokio::spawn(async move {
if channels.is_empty() {
return;
}
let notifier = crate::rebuild::BotShutdownNotifier::new(transport, channels, bot_name);
notifier.notify_startup().await;
});
}
// Watch channel: signals the Matrix bot task to send a shutdown announcement.
// `None` initial value means "server is running".
@@ -559,6 +620,21 @@ async fn main() -> Result<(), std::io::Error> {
} else {
drop(watcher_rx_for_slack);
}
if let (Some(ctx), Some(root)) = (&discord_ctx, &startup_root) {
// Spawn the Discord Gateway WebSocket listener.
chat::transport::discord::gateway::spawn_gateway(Arc::clone(ctx));
// Spawn stage-transition notification listener for Discord.
let channel_ids: Vec<String> = ctx.channel_ids.iter().cloned().collect();
chat::transport::matrix::notifications::spawn_notification_listener(
Arc::clone(&ctx.transport) as Arc<dyn crate::chat::ChatTransport>,
move || channel_ids.clone(),
watcher_rx_for_discord,
root.clone(),
);
} else {
drop(watcher_rx_for_discord);
}
// On startup:
// 1. Reconcile any stories whose agent work was committed while the server was