Files
huskies/server/src/service/ws/message/request.rs
T

282 lines
8.8 KiB
Rust
Raw Normal View History

2026-04-28 20:31:05 +00:00
//! WebSocket request messages sent by the client.
use crate::llm::chat;
use crate::llm::types::Message;
use serde::Deserialize;
/// WebSocket request messages sent by the client.
///
/// Deserialized with serde's internally tagged representation: the JSON
/// `type` field selects the variant, and variant names are matched in
/// `snake_case` (for example `PermissionResponse` is `"permission_response"`).
///
/// - `chat` starts a streaming chat session.
/// - `cancel` stops the active session.
/// - `permission_response` approves or denies a pending permission request.
/// - `ping` is a heartbeat that the server answers with `Pong`.
/// - `side_question` asks a one-off question without touching the history.
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WsRequest {
/// Starts a streaming chat session.
Chat {
/// Conversation messages supplied by the client.
messages: Vec<Message>,
/// Provider configuration for this request (see `chat::ProviderConfig`).
config: chat::ProviderConfig,
},
/// Stops the active session. Carries no payload beyond the `type` tag.
Cancel,
/// Approves or denies a pending permission request.
PermissionResponse {
/// Identifier of the permission request being answered.
/// NOTE(review): presumably echoes an id issued by the server — confirm
/// against the handler.
request_id: String,
/// `true` approves the request, `false` denies it.
approved: bool,
/// Optional in the payload; `#[serde(default)]` makes it `false` when the
/// field is omitted.
/// NOTE(review): presumably asks the server to remember the approval —
/// confirm against the handler.
#[serde(default)]
always_allow: bool,
},
/// Heartbeat ping from the client. The server responds with `Pong` so the
/// client can detect stale (half-closed) connections.
Ping,
/// A quick side question answered from current conversation context.
/// The question and response are NOT added to the conversation history
/// and no tool calls are made.
SideQuestion {
/// The question text to answer.
question: String,
/// Conversation messages supplied as context for the answer. Required:
/// there is no `#[serde(default)]` on this field.
context_messages: Vec<Message>,
/// Provider configuration for this request (see `chat::ProviderConfig`).
config: chat::ProviderConfig,
},
}
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: the JSON fixtures below are raw strings; their lines are kept
    // flush-left so the literal contents stay exactly as they were.

    /// Deserializes `payload` into a [`WsRequest`], panicking (at the caller's
    /// location) when the payload is rejected.
    #[track_caller]
    fn parse(payload: &str) -> WsRequest {
        serde_json::from_str(payload).unwrap()
    }

    /// Deserializes `payload` and unpacks a `PermissionResponse` into
    /// `(request_id, approved, always_allow)`, panicking on any other variant.
    #[track_caller]
    fn parse_permission_response(payload: &str) -> (String, bool, bool) {
        let WsRequest::PermissionResponse {
            request_id,
            approved,
            always_allow,
        } = parse(payload)
        else {
            panic!("expected PermissionResponse variant");
        };
        (request_id, approved, always_allow)
    }

    /// A minimal `chat` payload produces `Chat` with its message and config.
    #[test]
    fn deserialize_chat_request() {
        let payload = r#"{
"type": "chat",
"messages": [
{"role": "user", "content": "hello"}
],
"config": {
"provider": "ollama",
"model": "llama3"
}
}"#;
        let WsRequest::Chat { messages, config } = parse(payload) else {
            panic!("expected Chat variant");
        };
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content, "hello");
        assert_eq!(config.provider, "ollama");
        assert_eq!(config.model, "llama3");
    }

    /// Optional config fields are carried through when present.
    #[test]
    fn deserialize_chat_request_with_optional_fields() {
        let payload = r#"{
"type": "chat",
"messages": [],
"config": {
"provider": "anthropic",
"model": "claude-3-5-sonnet",
"base_url": "https://api.anthropic.com",
"enable_tools": true,
"session_id": "sess-123"
}
}"#;
        let WsRequest::Chat { messages, config } = parse(payload) else {
            panic!("expected Chat variant");
        };
        assert!(messages.is_empty());
        assert_eq!(
            config.base_url.as_deref(),
            Some("https://api.anthropic.com")
        );
        assert_eq!(config.enable_tools, Some(true));
        assert_eq!(config.session_id.as_deref(), Some("sess-123"));
    }

    /// A bare `cancel` tag maps to the unit variant `Cancel`.
    #[test]
    fn deserialize_cancel_request() {
        let parsed = parse(r#"{"type": "cancel"}"#);
        assert!(matches!(parsed, WsRequest::Cancel));
    }

    /// A bare `ping` tag maps to the unit variant `Ping`.
    #[test]
    fn deserialize_ping_request() {
        let parsed = parse(r#"{"type": "ping"}"#);
        assert!(matches!(parsed, WsRequest::Ping));
    }

    /// An approval without `always_allow` leaves that flag at `false`.
    #[test]
    fn deserialize_permission_response_approved() {
        let payload = r#"{
"type": "permission_response",
"request_id": "req-42",
"approved": true
}"#;
        let (request_id, approved, always_allow) = parse_permission_response(payload);
        assert_eq!(request_id, "req-42");
        assert!(approved);
        assert!(!always_allow);
    }

    /// A denial without `always_allow` leaves that flag at `false`.
    #[test]
    fn deserialize_permission_response_denied() {
        let payload = r#"{
"type": "permission_response",
"request_id": "req-99",
"approved": false
}"#;
        let (request_id, approved, always_allow) = parse_permission_response(payload);
        assert_eq!(request_id, "req-99");
        assert!(!approved);
        assert!(!always_allow);
    }

    /// An explicit `always_allow: true` is preserved.
    #[test]
    fn deserialize_permission_response_always_allow() {
        let payload = r#"{
"type": "permission_response",
"request_id": "req-100",
"approved": true,
"always_allow": true
}"#;
        let (request_id, approved, always_allow) = parse_permission_response(payload);
        assert_eq!(request_id, "req-100");
        assert!(approved);
        assert!(always_allow);
    }

    /// A `type` tag that names no variant is rejected.
    #[test]
    fn deserialize_unknown_type_fails() {
        let outcome = serde_json::from_str::<WsRequest>(r#"{"type": "unknown_type"}"#);
        assert!(outcome.is_err());
    }

    /// Input that is not JSON at all is rejected.
    #[test]
    fn deserialize_invalid_json_fails() {
        let outcome = serde_json::from_str::<WsRequest>("not json");
        assert!(outcome.is_err());
    }

    /// An object without the `type` tag is rejected even if its other fields
    /// would fit a variant.
    #[test]
    fn deserialize_missing_type_tag_fails() {
        let payload = r#"{"messages": [], "config": {"provider": "x", "model": "y"}}"#;
        let outcome = serde_json::from_str::<WsRequest>(payload);
        assert!(outcome.is_err());
    }

    /// A `side_question` payload produces `SideQuestion` with all three fields.
    #[test]
    fn deserialize_side_question() {
        let payload = r#"{
"type": "side_question",
"question": "what is this?",
"context_messages": [{"role": "user", "content": "hi"}],
"config": {"provider": "ollama", "model": "llama3"}
}"#;
        let WsRequest::SideQuestion {
            question,
            context_messages,
            config,
        } = parse(payload)
        else {
            panic!("expected SideQuestion variant");
        };
        assert_eq!(question, "what is this?");
        assert_eq!(context_messages.len(), 1);
        assert_eq!(config.model, "llama3");
    }

    /// Message order and roles survive deserialization of a longer history.
    #[test]
    fn deserialize_chat_with_multiple_messages() {
        use crate::llm::types::Role;

        let payload = r#"{
"type": "chat",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
],
"config": {
"provider": "ollama",
"model": "llama3"
}
}"#;
        let WsRequest::Chat { messages, .. } = parse(payload) else {
            panic!("expected Chat variant");
        };
        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[3].role, Role::User);
    }

    /// An assistant message carrying `tool_calls` keeps the call details.
    #[test]
    fn deserialize_chat_with_tool_call_message() {
        let payload = r#"{
"type": "chat",
"messages": [
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "read_file",
"arguments": "{\"path\": \"/tmp/test.rs\"}"
}
}
]
}
],
"config": {
"provider": "anthropic",
"model": "claude-3-5-sonnet"
}
}"#;
        let WsRequest::Chat { messages, .. } = parse(payload) else {
            panic!("expected Chat variant");
        };
        assert_eq!(messages.len(), 1);
        let calls = messages[0].tool_calls.as_ref().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "read_file");
    }
}