282 lines
8.8 KiB
Rust
282 lines
8.8 KiB
Rust
//! WebSocket request messages sent by the client.
|
|
|
|
use crate::llm::chat;
|
|
use crate::llm::types::Message;
|
|
use serde::Deserialize;
|
|
|
|
/// WebSocket request messages sent by the client.
///
/// Each request is a JSON object discriminated by a `"type"` field whose value
/// is the snake_case form of the variant name (e.g. `{"type": "cancel"}`). For
/// struct variants, the variant's fields sit alongside `"type"` in the same
/// object (internally tagged representation).
///
/// - `chat` starts a streaming chat session.
/// - `cancel` stops the active session.
/// - `permission_response` approves or denies a pending permission request.
/// - `ping` is a client heartbeat.
/// - `side_question` asks a one-off question answered from context.
///
/// A payload whose `"type"` is missing or not one of the values above fails to
/// deserialize, as does a struct variant that omits a field without a
/// `#[serde(default)]`.
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WsRequest {
    /// Start a streaming chat session (`"type": "chat"`).
    Chat {
        /// Conversation messages supplied by the client; may be empty.
        messages: Vec<Message>,
        /// Provider settings for the session, e.g. `provider` and `model`, plus
        /// optional fields such as `base_url`, `enable_tools`, `session_id`.
        config: chat::ProviderConfig,
    },
    /// Stop the active session (`"type": "cancel"`); carries no other fields.
    Cancel,
    /// The client's answer to a pending permission request
    /// (`"type": "permission_response"`).
    PermissionResponse {
        /// Identifier of the pending permission request being answered.
        request_id: String,
        /// `true` approves the request, `false` denies it. Required.
        approved: bool,
        /// Optional; deserializes as `false` when the field is omitted.
        // NOTE(review): presumably asks the server not to prompt again for
        // similar requests — confirm against the handler. Nothing here rejects
        // `approved: false` combined with `always_allow: true`; verify how the
        // handler treats that combination.
        #[serde(default)]
        always_allow: bool,
    },
    /// Heartbeat ping from the client. The server responds with `Pong` so the
    /// client can detect stale (half-closed) connections.
    Ping,
    /// A quick side question answered from current conversation context.
    /// The question and response are NOT added to the conversation history
    /// and no tool calls are made.
    SideQuestion {
        /// The text of the side question.
        question: String,
        /// Conversation messages the answer should be based on.
        context_messages: Vec<Message>,
        /// Provider settings used to answer the question.
        config: chat::ProviderConfig,
    },
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::Role;

    /// Parses `json` as a [`WsRequest`], panicking with the serde error on failure.
    fn parse(json: &str) -> WsRequest {
        serde_json::from_str(json).expect("request should deserialize")
    }

    /// Asserts that `json` is rejected by the `WsRequest` deserializer.
    fn assert_rejected(json: &str) {
        assert!(
            serde_json::from_str::<WsRequest>(json).is_err(),
            "expected deserialization to fail for: {}",
            json
        );
    }

    /// Unpacks a `PermissionResponse` into `(request_id, approved, always_allow)`.
    fn permission_fields(req: WsRequest) -> (String, bool, bool) {
        match req {
            WsRequest::PermissionResponse {
                request_id,
                approved,
                always_allow,
            } => (request_id, approved, always_allow),
            _ => panic!("expected PermissionResponse variant"),
        }
    }

    #[test]
    fn deserialize_chat_request() {
        let req = parse(
            r#"{
                "type": "chat",
                "messages": [
                    {"role": "user", "content": "hello"}
                ],
                "config": {
                    "provider": "ollama",
                    "model": "llama3"
                }
            }"#,
        );
        match req {
            WsRequest::Chat { messages, config } => {
                assert_eq!(messages.len(), 1);
                assert_eq!(messages[0].content, "hello");
                assert_eq!(config.provider, "ollama");
                assert_eq!(config.model, "llama3");
            }
            _ => panic!("expected Chat variant"),
        }
    }

    #[test]
    fn deserialize_chat_request_with_optional_fields() {
        let req = parse(
            r#"{
                "type": "chat",
                "messages": [],
                "config": {
                    "provider": "anthropic",
                    "model": "claude-3-5-sonnet",
                    "base_url": "https://api.anthropic.com",
                    "enable_tools": true,
                    "session_id": "sess-123"
                }
            }"#,
        );
        match req {
            WsRequest::Chat { messages, config } => {
                assert!(messages.is_empty());
                assert_eq!(
                    config.base_url.as_deref(),
                    Some("https://api.anthropic.com")
                );
                assert_eq!(config.enable_tools, Some(true));
                assert_eq!(config.session_id.as_deref(), Some("sess-123"));
            }
            _ => panic!("expected Chat variant"),
        }
    }

    #[test]
    fn deserialize_chat_missing_config_fails() {
        // `config` has no `#[serde(default)]`, so omitting it must be an error.
        assert_rejected(r#"{"type": "chat", "messages": []}"#);
    }

    #[test]
    fn deserialize_cancel_request() {
        assert!(matches!(parse(r#"{"type": "cancel"}"#), WsRequest::Cancel));
    }

    #[test]
    fn deserialize_ping_request() {
        assert!(matches!(parse(r#"{"type": "ping"}"#), WsRequest::Ping));
    }

    #[test]
    fn deserialize_permission_response_approved() {
        let req = parse(
            r#"{
                "type": "permission_response",
                "request_id": "req-42",
                "approved": true
            }"#,
        );
        // `always_allow` is omitted and must default to false.
        assert_eq!(permission_fields(req), ("req-42".to_string(), true, false));
    }

    #[test]
    fn deserialize_permission_response_denied() {
        let req = parse(
            r#"{
                "type": "permission_response",
                "request_id": "req-99",
                "approved": false
            }"#,
        );
        assert_eq!(permission_fields(req), ("req-99".to_string(), false, false));
    }

    #[test]
    fn deserialize_permission_response_always_allow() {
        let req = parse(
            r#"{
                "type": "permission_response",
                "request_id": "req-100",
                "approved": true,
                "always_allow": true
            }"#,
        );
        assert_eq!(permission_fields(req), ("req-100".to_string(), true, true));
    }

    #[test]
    fn deserialize_permission_response_missing_approved_fails() {
        // Unlike `always_allow`, `approved` is required: a client must never be
        // treated as having answered when it did not say yes or no.
        assert_rejected(r#"{"type": "permission_response", "request_id": "req-1"}"#);
    }

    #[test]
    fn deserialize_permission_response_missing_request_id_fails() {
        assert_rejected(r#"{"type": "permission_response", "approved": true}"#);
    }

    #[test]
    fn deserialize_permission_response_non_bool_approved_fails() {
        assert_rejected(
            r#"{"type": "permission_response", "request_id": "req-1", "approved": "yes"}"#,
        );
    }

    #[test]
    fn deserialize_unknown_type_fails() {
        assert_rejected(r#"{"type": "unknown_type"}"#);
    }

    #[test]
    fn deserialize_invalid_json_fails() {
        assert_rejected("not json");
    }

    #[test]
    fn deserialize_missing_type_tag_fails() {
        assert_rejected(r#"{"messages": [], "config": {"provider": "x", "model": "y"}}"#);
    }

    #[test]
    fn deserialize_side_question() {
        let req = parse(
            r#"{
                "type": "side_question",
                "question": "what is this?",
                "context_messages": [{"role": "user", "content": "hi"}],
                "config": {"provider": "ollama", "model": "llama3"}
            }"#,
        );
        match req {
            WsRequest::SideQuestion {
                question,
                context_messages,
                config,
            } => {
                assert_eq!(question, "what is this?");
                assert_eq!(context_messages.len(), 1);
                assert_eq!(context_messages[0].role, Role::User);
                assert_eq!(context_messages[0].content, "hi");
                assert_eq!(config.provider, "ollama");
                assert_eq!(config.model, "llama3");
            }
            _ => panic!("expected SideQuestion variant"),
        }
    }

    #[test]
    fn deserialize_side_question_missing_question_fails() {
        assert_rejected(
            r#"{
                "type": "side_question",
                "context_messages": [],
                "config": {"provider": "ollama", "model": "llama3"}
            }"#,
        );
    }

    #[test]
    fn deserialize_chat_with_multiple_messages() {
        let req = parse(
            r#"{
                "type": "chat",
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Hello"},
                    {"role": "assistant", "content": "Hi there!"},
                    {"role": "user", "content": "How are you?"}
                ],
                "config": {
                    "provider": "ollama",
                    "model": "llama3"
                }
            }"#,
        );
        match req {
            WsRequest::Chat { messages, .. } => {
                assert_eq!(messages.len(), 4);
                assert_eq!(messages[0].role, Role::System);
                assert_eq!(messages[3].role, Role::User);
            }
            _ => panic!("expected Chat variant"),
        }
    }

    #[test]
    fn deserialize_chat_with_tool_call_message() {
        let req = parse(
            r#"{
                "type": "chat",
                "messages": [
                    {
                        "role": "assistant",
                        "content": "",
                        "tool_calls": [
                            {
                                "id": "call_1",
                                "type": "function",
                                "function": {
                                    "name": "read_file",
                                    "arguments": "{\"path\": \"/tmp/test.rs\"}"
                                }
                            }
                        ]
                    }
                ],
                "config": {
                    "provider": "anthropic",
                    "model": "claude-3-5-sonnet"
                }
            }"#,
        );
        match req {
            WsRequest::Chat { messages, .. } => {
                assert_eq!(messages.len(), 1);
                let tool_calls = messages[0]
                    .tool_calls
                    .as_ref()
                    .expect("assistant message should carry tool_calls");
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].function.name, "read_file");
            }
            _ => panic!("expected Chat variant"),
        }
    }
}
|