diff --git a/server/src/llm/chat.rs b/server/src/llm/chat.rs index 7a2a69b..c951fff 100644 --- a/server/src/llm/chat.rs +++ b/server/src/llm/chat.rs @@ -22,6 +22,7 @@ pub struct ProviderConfig { /// Result of a chat call, including messages and optional metadata. #[allow(dead_code)] +#[derive(Debug)] pub struct ChatResult { pub messages: Vec, /// Session ID returned by Claude Code for resumption. @@ -451,3 +452,547 @@ pub fn cancel_chat(state: &SessionState) -> Result<(), String> { state.cancel_tx.send(true).map_err(|e| e.to_string())?; Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::llm::types::{FunctionCall, ToolCall}; + use crate::state::SessionState; + use serde_json::json; + use std::collections::HashMap; + use std::sync::Mutex; + + // --------------------------------------------------------------------------- + // Minimal in-memory StoreOps mock + // --------------------------------------------------------------------------- + + struct MockStore { + data: Mutex>, + save_should_fail: bool, + } + + impl MockStore { + fn new() -> Self { + Self { + data: Mutex::new(HashMap::new()), + save_should_fail: false, + } + } + + fn with_save_error() -> Self { + Self { + data: Mutex::new(HashMap::new()), + save_should_fail: true, + } + } + + fn with_entry(key: &str, value: serde_json::Value) -> Self { + let mut map = HashMap::new(); + map.insert(key.to_string(), value); + Self { + data: Mutex::new(map), + save_should_fail: false, + } + } + } + + impl StoreOps for MockStore { + fn get(&self, key: &str) -> Option { + self.data.lock().ok().and_then(|m| m.get(key).cloned()) + } + + fn set(&self, key: &str, value: serde_json::Value) { + if let Ok(mut m) = self.data.lock() { + m.insert(key.to_string(), value); + } + } + + fn delete(&self, key: &str) { + if let Ok(mut m) = self.data.lock() { + m.remove(key); + } + } + + fn save(&self) -> Result<(), String> { + if self.save_should_fail { + Err("mock save error".to_string()) + } else { + Ok(()) + } + } + } + + // --------------------------------------------------------------------------- + // parse_tool_arguments + // --------------------------------------------------------------------------- + + #[test] + fn parse_tool_arguments_valid_json() { + let result = parse_tool_arguments(r#"{"path": "src/main.rs"}"#); + assert!(result.is_ok()); + assert_eq!(result.unwrap()["path"], json!("src/main.rs")); + } + + #[test] + fn parse_tool_arguments_invalid_json() { + let result = parse_tool_arguments("not json {{{"); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Error parsing arguments:")); + } + + #[test] + fn parse_tool_arguments_empty_object() { + let result = parse_tool_arguments("{}"); + assert!(result.is_ok()); + } + + // --------------------------------------------------------------------------- + // get_anthropic_api_key_exists_impl + // --------------------------------------------------------------------------- + + #[test] + fn api_key_exists_when_key_is_present_and_non_empty() { + let store = MockStore::with_entry("anthropic_api_key", json!("sk-test-key")); + assert!(get_anthropic_api_key_exists_impl(&store)); + } + + #[test] + fn api_key_exists_returns_false_when_key_is_empty_string() { + let store = MockStore::with_entry("anthropic_api_key", json!("")); + assert!(!get_anthropic_api_key_exists_impl(&store)); + } + + #[test] + fn api_key_exists_returns_false_when_key_absent() { + let store = MockStore::new(); + assert!(!get_anthropic_api_key_exists_impl(&store)); + } + + #[test] + fn api_key_exists_returns_false_when_value_is_not_string() { + let store = MockStore::with_entry("anthropic_api_key", json!(42)); + assert!(!get_anthropic_api_key_exists_impl(&store)); + } + + // --------------------------------------------------------------------------- + // get_anthropic_api_key_impl + // --------------------------------------------------------------------------- + + #[test] + fn get_api_key_returns_key_when_present() { + let store = MockStore::with_entry("anthropic_api_key", json!("sk-test-key")); + let result = get_anthropic_api_key_impl(&store); + assert_eq!(result.unwrap(), "sk-test-key"); + } + + #[test] + fn get_api_key_errors_when_empty() { + let store = MockStore::with_entry("anthropic_api_key", json!("")); + let result = get_anthropic_api_key_impl(&store); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("empty")); + } + + #[test] + fn get_api_key_errors_when_absent() { + let store = MockStore::new(); + let result = get_anthropic_api_key_impl(&store); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("not found")); + } + + #[test] + fn get_api_key_errors_when_value_not_string() { + let store = MockStore::with_entry("anthropic_api_key", json!(123)); + let result = get_anthropic_api_key_impl(&store); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("not a string")); + } + + // --------------------------------------------------------------------------- + // set_anthropic_api_key_impl + // --------------------------------------------------------------------------- + + #[test] + fn set_api_key_stores_and_returns_ok() { + let store = MockStore::new(); + let result = set_anthropic_api_key_impl(&store, "sk-my-key"); + assert!(result.is_ok()); + assert_eq!( + store.get("anthropic_api_key"), + Some(json!("sk-my-key")) + ); + } + + #[test] + fn set_api_key_returns_error_when_save_fails() { + let store = MockStore::with_save_error(); + let result = set_anthropic_api_key_impl(&store, "sk-my-key"); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("mock save error")); + } + + // --------------------------------------------------------------------------- + // Public wrappers: get_anthropic_api_key_exists / set_anthropic_api_key + // --------------------------------------------------------------------------- + + #[test] + fn public_api_key_exists_returns_ok_bool() { + let store = MockStore::with_entry("anthropic_api_key", json!("sk-abc")); + let result = get_anthropic_api_key_exists(&store); + assert_eq!(result, Ok(true)); + } + + #[test] + fn public_api_key_exists_false_when_absent() { + let store = MockStore::new(); + let result = get_anthropic_api_key_exists(&store); + assert_eq!(result, Ok(false)); + } + + #[test] + fn public_set_api_key_succeeds() { + let store = MockStore::new(); + let result = set_anthropic_api_key(&store, "sk-xyz".to_string()); + assert!(result.is_ok()); + } + + #[test] + fn public_set_api_key_propagates_save_error() { + let store = MockStore::with_save_error(); + let result = set_anthropic_api_key(&store, "sk-xyz".to_string()); + assert!(result.is_err()); + } + + // --------------------------------------------------------------------------- + // get_tool_definitions + // --------------------------------------------------------------------------- + + #[test] + fn tool_definitions_returns_five_tools() { + let tools = get_tool_definitions(); + assert_eq!(tools.len(), 5); + } + + #[test] + fn tool_definitions_has_expected_names() { + let tools = get_tool_definitions(); + let names: Vec<&str> = tools.iter().map(|t| t.function.name.as_str()).collect(); + assert!(names.contains(&"read_file")); + assert!(names.contains(&"write_file")); + assert!(names.contains(&"list_directory")); + assert!(names.contains(&"search_files")); + assert!(names.contains(&"exec_shell")); + } + + #[test] + fn tool_definitions_each_has_function_kind() { + let tools = get_tool_definitions(); + for tool in &tools { + assert_eq!(tool.kind, "function"); + } + } + + #[test] + fn tool_definitions_each_has_non_empty_description() { + let tools = get_tool_definitions(); + for tool in &tools { + assert!(!tool.function.description.is_empty()); + } + } + + #[test] + fn tool_definitions_parameters_have_object_type() { + let tools = get_tool_definitions(); + for tool in &tools { + assert_eq!(tool.function.parameters["type"], json!("object")); + } + } + + #[test] + fn read_file_tool_requires_path_parameter() { + let tools = get_tool_definitions(); + let read_file = tools + .iter() + .find(|t| t.function.name == "read_file") + .unwrap(); + let required = read_file.function.parameters["required"] + .as_array() + .unwrap(); + let required_names: Vec<&str> = + required.iter().map(|v| v.as_str().unwrap()).collect(); + assert!(required_names.contains(&"path")); + } + + #[test] + fn exec_shell_tool_requires_command_and_args() { + let tools = get_tool_definitions(); + let exec_shell = tools + .iter() + .find(|t| t.function.name == "exec_shell") + .unwrap(); + let required = exec_shell.function.parameters["required"] + .as_array() + .unwrap(); + let required_names: Vec<&str> = + required.iter().map(|v| v.as_str().unwrap()).collect(); + assert!(required_names.contains(&"command")); + assert!(required_names.contains(&"args")); + } + + // --------------------------------------------------------------------------- + // cancel_chat + // --------------------------------------------------------------------------- + + #[test] + fn cancel_chat_sends_true_to_channel() { + let state = SessionState::default(); + let result = cancel_chat(&state); + assert!(result.is_ok()); + assert!(*state.cancel_rx.borrow()); + } + + // --------------------------------------------------------------------------- + // chat — unsupported provider early return (no network calls) + // --------------------------------------------------------------------------- + + #[tokio::test] + async fn chat_rejects_unknown_provider() { + let state = SessionState::default(); + let store = MockStore::new(); + let config = ProviderConfig { + provider: "unsupported-provider".to_string(), + model: "some-model".to_string(), + base_url: None, + enable_tools: None, + session_id: None, + }; + let messages = vec![Message { + role: Role::User, + content: "hello".to_string(), + tool_calls: None, + tool_call_id: None, + }]; + + let result = chat( + messages, + config, + &state, + &store, + |_| {}, + |_| {}, + |_| {}, + ) + .await; + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .contains("Unsupported provider: unsupported-provider")); + } + + // --------------------------------------------------------------------------- + // chat — ollama path exercises system prompt insertion + tool setup + // (connection to a non-existent port fails fast) + // --------------------------------------------------------------------------- + + #[tokio::test] + async fn chat_ollama_bad_url_fails_with_ollama_error() { + let state = SessionState::default(); + let store = MockStore::new(); + let config = ProviderConfig { + provider: "ollama".to_string(), + model: "llama3".to_string(), + // Port 1 is reserved / closed — connection fails immediately. + base_url: Some("http://127.0.0.1:1".to_string()), + enable_tools: Some(false), + session_id: None, + }; + let messages = vec![Message { + role: Role::User, + content: "ping".to_string(), + tool_calls: None, + tool_call_id: None, + }]; + + let result = chat( + messages, + config, + &state, + &store, + |_| {}, + |_| {}, + |_| {}, + ) + .await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.contains("Ollama Error:"), "unexpected error: {err}"); + } + + // --------------------------------------------------------------------------- + // chat — Anthropic model prefix detection (fails without API key) + // --------------------------------------------------------------------------- + + #[tokio::test] + async fn chat_claude_model_without_api_key_returns_error() { + let state = SessionState::default(); + // No API key in store → should fail with "API key not found" + let store = MockStore::new(); + let config = ProviderConfig { + provider: "anthropic".to_string(), + model: "claude-3-haiku-20240307".to_string(), + base_url: None, + enable_tools: Some(false), + session_id: None, + }; + let messages = vec![Message { + role: Role::User, + content: "hello".to_string(), + tool_calls: None, + tool_call_id: None, + }]; + + let result = chat( + messages, + config, + &state, + &store, + |_| {}, + |_| {}, + |_| {}, + ) + .await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.contains("API key"), + "expected API key error, got: {err}" + ); + } + + // --------------------------------------------------------------------------- + // execute_tool — unknown tool name (no I/O, no network) + // --------------------------------------------------------------------------- + + #[tokio::test] + async fn execute_tool_returns_error_for_unknown_tool() { + let state = SessionState::default(); + let call = ToolCall { + id: Some("call-1".to_string()), + kind: "function".to_string(), + function: FunctionCall { + name: "nonexistent_tool".to_string(), + arguments: "{}".to_string(), + }, + }; + + let result = execute_tool(&call, &state).await; + assert_eq!(result, "Unknown tool: nonexistent_tool"); + } + + #[tokio::test] + async fn execute_tool_returns_parse_error_for_invalid_json_args() { + let state = SessionState::default(); + let call = ToolCall { + id: Some("call-2".to_string()), + kind: "function".to_string(), + function: FunctionCall { + name: "read_file".to_string(), + arguments: "INVALID JSON".to_string(), + }, + }; + + let result = execute_tool(&call, &state).await; + assert!( + result.contains("Error parsing arguments:"), + "unexpected result: {result}" + ); + } + + // --------------------------------------------------------------------------- + // execute_tool — tools that use SessionState (no project root → errors) + // --------------------------------------------------------------------------- + + #[tokio::test] + async fn execute_tool_read_file_no_project_root_returns_error() { + let state = SessionState::default(); // no project_root set + let call = ToolCall { + id: None, + kind: "function".to_string(), + function: FunctionCall { + name: "read_file".to_string(), + arguments: r#"{"path": "some_file.txt"}"#.to_string(), + }, + }; + + let result = execute_tool(&call, &state).await; + assert!(result.starts_with("Error:"), "unexpected result: {result}"); + } + + #[tokio::test] + async fn execute_tool_write_file_no_project_root_returns_error() { + let state = SessionState::default(); + let call = ToolCall { + id: None, + kind: "function".to_string(), + function: FunctionCall { + name: "write_file".to_string(), + arguments: r#"{"path": "out.txt", "content": "hello"}"#.to_string(), + }, + }; + + let result = execute_tool(&call, &state).await; + assert!(result.starts_with("Error:"), "unexpected result: {result}"); + } + + #[tokio::test] + async fn execute_tool_list_directory_no_project_root_returns_error() { + let state = SessionState::default(); + let call = ToolCall { + id: None, + kind: "function".to_string(), + function: FunctionCall { + name: "list_directory".to_string(), + arguments: r#"{"path": "."}"#.to_string(), + }, + }; + + let result = execute_tool(&call, &state).await; + assert!(result.starts_with("Error:"), "unexpected result: {result}"); + } + + #[tokio::test] + async fn execute_tool_search_files_no_project_root_returns_error() { + let state = SessionState::default(); + let call = ToolCall { + id: None, + kind: "function".to_string(), + function: FunctionCall { + name: "search_files".to_string(), + arguments: r#"{"query": "fn main"}"#.to_string(), + }, + }; + + let result = execute_tool(&call, &state).await; + assert!(result.starts_with("Error:"), "unexpected result: {result}"); + } + + #[tokio::test] + async fn execute_tool_exec_shell_no_project_root_returns_error() { + let state = SessionState::default(); + let call = ToolCall { + id: None, + kind: "function".to_string(), + function: FunctionCall { + name: "exec_shell".to_string(), + arguments: r#"{"command": "ls", "args": []}"#.to_string(), + }, + }; + + let result = execute_tool(&call, &state).await; + assert!(result.starts_with("Error:"), "unexpected result: {result}"); + } +}