//! WebSocket request messages sent by the client. use crate::llm::chat; use crate::llm::types::Message; use serde::Deserialize; /// WebSocket request messages sent by the client. /// /// - `chat` starts a streaming chat session. /// - `cancel` stops the active session. /// - `permission_response` approves or denies a pending permission request. #[derive(Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum WsRequest { Chat { messages: Vec, config: chat::ProviderConfig, }, Cancel, PermissionResponse { request_id: String, approved: bool, #[serde(default)] always_allow: bool, }, /// Heartbeat ping from the client. The server responds with `Pong` so the /// client can detect stale (half-closed) connections. Ping, /// A quick side question answered from current conversation context. /// The question and response are NOT added to the conversation history /// and no tool calls are made. SideQuestion { question: String, context_messages: Vec, config: chat::ProviderConfig, }, } #[cfg(test)] mod tests { use super::*; #[test] fn deserialize_chat_request() { let json = r#"{ "type": "chat", "messages": [ {"role": "user", "content": "hello"} ], "config": { "provider": "ollama", "model": "llama3" } }"#; let req: WsRequest = serde_json::from_str(json).unwrap(); match req { WsRequest::Chat { messages, config } => { assert_eq!(messages.len(), 1); assert_eq!(messages[0].content, "hello"); assert_eq!(config.provider, "ollama"); assert_eq!(config.model, "llama3"); } _ => panic!("expected Chat variant"), } } #[test] fn deserialize_chat_request_with_optional_fields() { let json = r#"{ "type": "chat", "messages": [], "config": { "provider": "anthropic", "model": "claude-3-5-sonnet", "base_url": "https://api.anthropic.com", "enable_tools": true, "session_id": "sess-123" } }"#; let req: WsRequest = serde_json::from_str(json).unwrap(); match req { WsRequest::Chat { messages, config } => { assert!(messages.is_empty()); assert_eq!( config.base_url.as_deref(), Some("https://api.anthropic.com") ); assert_eq!(config.enable_tools, Some(true)); assert_eq!(config.session_id.as_deref(), Some("sess-123")); } _ => panic!("expected Chat variant"), } } #[test] fn deserialize_cancel_request() { let json = r#"{"type": "cancel"}"#; let req: WsRequest = serde_json::from_str(json).unwrap(); assert!(matches!(req, WsRequest::Cancel)); } #[test] fn deserialize_ping_request() { let json = r#"{"type": "ping"}"#; let req: WsRequest = serde_json::from_str(json).unwrap(); assert!(matches!(req, WsRequest::Ping)); } #[test] fn deserialize_permission_response_approved() { let json = r#"{ "type": "permission_response", "request_id": "req-42", "approved": true }"#; let req: WsRequest = serde_json::from_str(json).unwrap(); match req { WsRequest::PermissionResponse { request_id, approved, always_allow, } => { assert_eq!(request_id, "req-42"); assert!(approved); assert!(!always_allow); } _ => panic!("expected PermissionResponse variant"), } } #[test] fn deserialize_permission_response_denied() { let json = r#"{ "type": "permission_response", "request_id": "req-99", "approved": false }"#; let req: WsRequest = serde_json::from_str(json).unwrap(); match req { WsRequest::PermissionResponse { request_id, approved, always_allow, } => { assert_eq!(request_id, "req-99"); assert!(!approved); assert!(!always_allow); } _ => panic!("expected PermissionResponse variant"), } } #[test] fn deserialize_permission_response_always_allow() { let json = r#"{ "type": "permission_response", "request_id": "req-100", "approved": true, "always_allow": true }"#; let req: WsRequest = serde_json::from_str(json).unwrap(); match req { WsRequest::PermissionResponse { request_id, approved, always_allow, } => { assert_eq!(request_id, "req-100"); assert!(approved); assert!(always_allow); } _ => panic!("expected PermissionResponse variant"), } } #[test] fn deserialize_unknown_type_fails() { let json = r#"{"type": "unknown_type"}"#; let result: Result = serde_json::from_str(json); assert!(result.is_err()); } #[test] fn deserialize_invalid_json_fails() { let result: Result = serde_json::from_str("not json"); assert!(result.is_err()); } #[test] fn deserialize_missing_type_tag_fails() { let json = r#"{"messages": [], "config": {"provider": "x", "model": "y"}}"#; let result: Result = serde_json::from_str(json); assert!(result.is_err()); } #[test] fn deserialize_side_question() { let json = r#"{ "type": "side_question", "question": "what is this?", "context_messages": [{"role": "user", "content": "hi"}], "config": {"provider": "ollama", "model": "llama3"} }"#; let req: WsRequest = serde_json::from_str(json).unwrap(); match req { WsRequest::SideQuestion { question, context_messages, config, } => { assert_eq!(question, "what is this?"); assert_eq!(context_messages.len(), 1); assert_eq!(config.model, "llama3"); } _ => panic!("expected SideQuestion variant"), } } #[test] fn deserialize_chat_with_multiple_messages() { let json = r#"{ "type": "chat", "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}, {"role": "user", "content": "How are you?"} ], "config": { "provider": "ollama", "model": "llama3" } }"#; let req: WsRequest = serde_json::from_str(json).unwrap(); match req { WsRequest::Chat { messages, .. } => { assert_eq!(messages.len(), 4); assert_eq!(messages[0].role, crate::llm::types::Role::System); assert_eq!(messages[3].role, crate::llm::types::Role::User); } _ => panic!("expected Chat variant"), } } #[test] fn deserialize_chat_with_tool_call_message() { let json = r#"{ "type": "chat", "messages": [ { "role": "assistant", "content": "", "tool_calls": [ { "id": "call_1", "type": "function", "function": { "name": "read_file", "arguments": "{\"path\": \"/tmp/test.rs\"}" } } ] } ], "config": { "provider": "anthropic", "model": "claude-3-5-sonnet" } }"#; let req: WsRequest = serde_json::from_str(json).unwrap(); match req { WsRequest::Chat { messages, .. } => { assert_eq!(messages.len(), 1); let tc = messages[0].tool_calls.as_ref().unwrap(); assert_eq!(tc.len(), 1); assert_eq!(tc[0].function.name, "read_file"); } _ => panic!("expected Chat variant"), } } }