diff --git a/src-tauri/src/commands/chat.rs b/src-tauri/src/commands/chat.rs index 9298364..cbb5575 100644 --- a/src-tauri/src/commands/chat.rs +++ b/src-tauri/src/commands/chat.rs @@ -1,62 +1,54 @@ -use crate::commands::{fs, search, shell}; +use crate::commands::fs::{StoreOps, TauriStoreWrapper}; use crate::llm::prompts::SYSTEM_PROMPT; -use crate::llm::providers::anthropic::AnthropicProvider; -use crate::llm::providers::ollama::OllamaProvider; use crate::llm::types::{Message, Role, ToolCall, ToolDefinition, ToolFunctionDefinition}; use crate::state::SessionState; use serde::Deserialize; use serde_json::json; -use tauri::{AppHandle, Emitter, State}; -use tauri_plugin_store::StoreExt; +use tauri::{AppHandle, State}; -#[derive(Deserialize)] +// ----------------------------------------------------------------------------- +// Constants +// ----------------------------------------------------------------------------- + +const MAX_TURNS: usize = 30; +const STORE_PATH: &str = "store.json"; +const KEY_ANTHROPIC_API_KEY: &str = "anthropic_api_key"; + +// ----------------------------------------------------------------------------- +// Types +// ----------------------------------------------------------------------------- + +#[derive(Deserialize, Clone)] pub struct ProviderConfig { - pub provider: String, // "ollama" + pub provider: String, pub model: String, pub base_url: Option, pub enable_tools: Option, } -const MAX_TURNS: usize = 30; +// ----------------------------------------------------------------------------- +// Pure Implementation Functions +// ----------------------------------------------------------------------------- -#[tauri::command] -pub async fn get_ollama_models(base_url: Option) -> Result, String> { - let url = base_url.unwrap_or_else(|| "http://localhost:11434".to_string()); - OllamaProvider::get_models(&url).await -} - -#[tauri::command] -pub async fn get_anthropic_api_key_exists(app: AppHandle) -> Result { - let store = app - .store("store.json") - .map_err(|e| format!("Failed to access store: {}", e))?; - - match store.get("anthropic_api_key") { +fn get_anthropic_api_key_exists_impl(store: &dyn StoreOps) -> bool { + match store.get(KEY_ANTHROPIC_API_KEY) { Some(value) => { if let Some(key) = value.as_str() { - Ok(!key.is_empty()) + !key.is_empty() } else { - Ok(false) + false } } - None => Ok(false), + None => false, } } -#[tauri::command] -pub async fn set_anthropic_api_key(app: AppHandle, api_key: String) -> Result<(), String> { - let store = app - .store("store.json") - .map_err(|e| format!("Failed to access store: {}", e))?; - - store.set("anthropic_api_key", json!(api_key)); - - store - .save() - .map_err(|e| format!("Failed to save store: {}", e))?; +fn set_anthropic_api_key_impl(store: &dyn StoreOps, api_key: &str) -> Result<(), String> { + store.set(KEY_ANTHROPIC_API_KEY, json!(api_key)); + store.save()?; // Verify it was saved - match store.get("anthropic_api_key") { + match store.get(KEY_ANTHROPIC_API_KEY) { Some(value) => { if let Some(retrieved) = value.as_str() { if retrieved != api_key { @@ -74,12 +66,8 @@ pub async fn set_anthropic_api_key(app: AppHandle, api_key: String) -> Result<() Ok(()) } -fn get_anthropic_api_key(app: &AppHandle) -> Result { - let store = app - .store("store.json") - .map_err(|e| format!("Failed to access store: {}", e))?; - - match store.get("anthropic_api_key") { +fn get_anthropic_api_key_impl(store: &dyn StoreOps) -> Result { + match store.get(KEY_ANTHROPIC_API_KEY) { Some(value) => { if let Some(key) = value.as_str() { if key.is_empty() { @@ -95,222 +83,11 @@ fn get_anthropic_api_key(app: &AppHandle) -> Result { } } -#[tauri::command] -pub async fn chat( - app: AppHandle, - messages: Vec, - config: ProviderConfig, - state: State<'_, SessionState>, -) -> Result, String> { - // Reset cancel flag at start of new request - let _ = state.cancel_tx.send(false); - - // Get a clone of the cancellation receiver - let mut cancel_rx = state.cancel_rx.clone(); - - // Mark the receiver as having seen the current (false) value - // This prevents changed() from firing immediately due to stale state - cancel_rx.borrow_and_update(); - - // 1. Setup Provider - let base_url = config - .base_url - .clone() - .unwrap_or_else(|| "http://localhost:11434".to_string()); - - // Determine provider from model name - let is_claude = config.model.starts_with("claude-"); - - if !is_claude && config.provider.as_str() != "ollama" { - return Err(format!("Unsupported provider: {}", config.provider)); - } - - // 2. Define Tools - let tool_defs = get_tool_definitions(); - let tools = if config.enable_tools.unwrap_or(true) { - tool_defs.as_slice() - } else { - &[] - }; - - // 3. Agent Loop - let mut current_history = messages.clone(); - - // Inject System Prompt - current_history.insert( - 0, - Message { - role: Role::System, - content: SYSTEM_PROMPT.to_string(), - tool_calls: None, - tool_call_id: None, - }, - ); - - // Inject reminder as a second system message - current_history.insert( - 1, - Message { - role: Role::System, - content: "REMINDER: Distinguish between showing examples (use code blocks in chat) vs implementing changes (use write_file tool). Keywords like 'show me', 'example', 'how does' = chat response. Keywords like 'create', 'add', 'implement', 'fix' = use tools.".to_string(), - tool_calls: None, - tool_call_id: None, - }, - ); - - let mut new_messages: Vec = Vec::new(); - let mut turn_count = 0; - - loop { - // Check for cancellation at start of loop - if *cancel_rx.borrow() { - return Err("Chat cancelled by user".to_string()); - } - - if turn_count >= MAX_TURNS { - return Err("Max conversation turns reached.".to_string()); - } - turn_count += 1; - - // Call LLM with streaming - let response = if is_claude { - // Use Anthropic provider - let api_key = get_anthropic_api_key(&app)?; - let anthropic_provider = AnthropicProvider::new(api_key); - anthropic_provider - .chat_stream(&app, &config.model, ¤t_history, tools, &mut cancel_rx) - .await - .map_err(|e| format!("Anthropic Error: {}", e))? - } else { - // Use Ollama provider - let ollama_provider = OllamaProvider::new(base_url.clone()); - ollama_provider - .chat_stream(&app, &config.model, ¤t_history, tools, &mut cancel_rx) - .await - .map_err(|e| format!("Ollama Error: {}", e))? - }; - - // Process Response - if let Some(tool_calls) = response.tool_calls { - // The Assistant wants to run tools - let assistant_msg = Message { - role: Role::Assistant, - content: response.content.unwrap_or_default(), - tool_calls: Some(tool_calls.clone()), - tool_call_id: None, - }; - - current_history.push(assistant_msg.clone()); - new_messages.push(assistant_msg); - // Emit history excluding system prompts (indices 0 and 1) - app.emit("chat:update", ¤t_history[2..]) - .map_err(|e| e.to_string())?; - - // Execute Tools - for call in tool_calls { - // Check for cancellation before executing each tool - if *cancel_rx.borrow() { - return Err("Chat cancelled before tool execution".to_string()); - } - - let output = execute_tool(&call, &state).await; - - let tool_msg = Message { - role: Role::Tool, - content: output, - tool_calls: None, - // For Ollama/Simple flow, we just append. - // For OpenAI strict, this needs to match call.id. - tool_call_id: call.id, - }; - - current_history.push(tool_msg.clone()); - new_messages.push(tool_msg); - // Emit history excluding system prompts (indices 0 and 1) - app.emit("chat:update", ¤t_history[2..]) - .map_err(|e| e.to_string())?; - } - } else { - // Final text response - let assistant_msg = Message { - role: Role::Assistant, - content: response.content.unwrap_or_default(), - tool_calls: None, - tool_call_id: None, - }; - - // We don't push to current_history needed for next loop, because we are done. - new_messages.push(assistant_msg.clone()); - current_history.push(assistant_msg); - // Emit history excluding system prompts (indices 0 and 1) - app.emit("chat:update", ¤t_history[2..]) - .map_err(|e| e.to_string())?; - break; - } - } - - Ok(new_messages) +fn parse_tool_arguments(args_str: &str) -> Result { + serde_json::from_str(args_str).map_err(|e| format!("Error parsing arguments: {e}")) } -async fn execute_tool(call: &ToolCall, state: &State<'_, SessionState>) -> String { - let name = call.function.name.as_str(); - // Parse arguments. They come as a JSON string from the LLM abstraction. - let args: serde_json::Value = match serde_json::from_str(&call.function.arguments) { - Ok(v) => v, - Err(e) => return format!("Error parsing arguments: {}", e), - }; - - match name { - "read_file" => { - let path = args["path"].as_str().unwrap_or("").to_string(); - match fs::read_file(path, state.clone()).await { - Ok(content) => content, - Err(e) => format!("Error: {}", e), - } - } - "write_file" => { - let path = args["path"].as_str().unwrap_or("").to_string(); - let content = args["content"].as_str().unwrap_or("").to_string(); - match fs::write_file(path, content, state.clone()).await { - Ok(_) => "File written successfully.".to_string(), - Err(e) => format!("Error: {}", e), - } - } - "list_directory" => { - let path = args["path"].as_str().unwrap_or("").to_string(); - match fs::list_directory(path, state.clone()).await { - Ok(entries) => serde_json::to_string(&entries).unwrap_or_default(), - Err(e) => format!("Error: {}", e), - } - } - "search_files" => { - let query = args["query"].as_str().unwrap_or("").to_string(); - match search::search_files(query, state.clone()).await { - Ok(results) => serde_json::to_string(&results).unwrap_or_default(), - Err(e) => format!("Error: {}", e), - } - } - "exec_shell" => { - let command = args["command"].as_str().unwrap_or("").to_string(); - let args_vec: Vec = args["args"] - .as_array() - .map(|arr| { - arr.iter() - .map(|v| v.as_str().unwrap_or("").to_string()) - .collect() - }) - .unwrap_or_default(); - - match shell::exec_shell(command, args_vec, state.clone()).await { - Ok(output) => serde_json::to_string(&output).unwrap_or_default(), - Err(e) => format!("Error: {}", e), - } - } - _ => format!("Unknown tool: {}", name), - } -} - -fn get_tool_definitions() -> Vec { +pub fn get_tool_definitions() -> Vec { vec![ ToolDefinition { kind: "function".to_string(), @@ -395,8 +172,686 @@ fn get_tool_definitions() -> Vec { ] } +// ----------------------------------------------------------------------------- +// Tauri Commands (Thin Wrappers) +// ----------------------------------------------------------------------------- + +#[tauri::command] +pub async fn get_ollama_models(base_url: Option) -> Result, String> { + use crate::llm::providers::ollama::OllamaProvider; + let url = base_url.unwrap_or_else(|| "http://localhost:11434".to_string()); + OllamaProvider::get_models(&url).await +} + +#[tauri::command] +pub async fn get_anthropic_api_key_exists(app: AppHandle) -> Result { + use tauri_plugin_store::StoreExt; + + let store = app + .store(STORE_PATH) + .map_err(|e| format!("Failed to access store: {e}"))?; + + let wrapper = TauriStoreWrapper { store: &store }; + Ok(get_anthropic_api_key_exists_impl(&wrapper)) +} + +#[tauri::command] +pub async fn set_anthropic_api_key(app: AppHandle, api_key: String) -> Result<(), String> { + use tauri_plugin_store::StoreExt; + + let store = app + .store(STORE_PATH) + .map_err(|e| format!("Failed to access store: {e}"))?; + + let wrapper = TauriStoreWrapper { store: &store }; + set_anthropic_api_key_impl(&wrapper, &api_key) +} + +#[tauri::command] +pub async fn chat( + app: AppHandle, + messages: Vec, + config: ProviderConfig, + state: State<'_, SessionState>, +) -> Result, String> { + use crate::llm::providers::anthropic::AnthropicProvider; + use crate::llm::providers::ollama::OllamaProvider; + use tauri::Emitter; + use tauri_plugin_store::StoreExt; + + // Reset cancel flag at start of new request + let _ = state.cancel_tx.send(false); + + // Get a clone of the cancellation receiver + let mut cancel_rx = state.cancel_rx.clone(); + + // Mark the receiver as having seen the current (false) value + cancel_rx.borrow_and_update(); + + // Setup Provider + let base_url = config + .base_url + .clone() + .unwrap_or_else(|| "http://localhost:11434".to_string()); + + // Determine provider from model name + let is_claude = config.model.starts_with("claude-"); + + if !is_claude && config.provider.as_str() != "ollama" { + return Err(format!("Unsupported provider: {}", config.provider)); + } + + // Define Tools + let tool_defs = get_tool_definitions(); + let tools = if config.enable_tools.unwrap_or(true) { + tool_defs.as_slice() + } else { + &[] + }; + + // Agent Loop + let mut current_history = messages.clone(); + + // Inject System Prompt + current_history.insert( + 0, + Message { + role: Role::System, + content: SYSTEM_PROMPT.to_string(), + tool_calls: None, + tool_call_id: None, + }, + ); + + // Inject reminder as a second system message + current_history.insert( + 1, + Message { + role: Role::System, + content: "REMINDER: Distinguish between showing examples (use code blocks in chat) vs implementing changes (use write_file tool). Keywords like 'show me', 'example', 'how does' = chat response. Keywords like 'create', 'add', 'implement', 'fix' = use tools.".to_string(), + tool_calls: None, + tool_call_id: None, + }, + ); + + let mut new_messages: Vec = Vec::new(); + let mut turn_count = 0; + + loop { + // Check for cancellation at start of loop + if *cancel_rx.borrow() { + return Err("Chat cancelled by user".to_string()); + } + + if turn_count >= MAX_TURNS { + return Err("Max conversation turns reached.".to_string()); + } + turn_count += 1; + + // Call LLM with streaming + let response = if is_claude { + // Use Anthropic provider + let store = app + .store(STORE_PATH) + .map_err(|e| format!("Failed to access store: {e}"))?; + + let wrapper = TauriStoreWrapper { store: &store }; + let api_key = get_anthropic_api_key_impl(&wrapper)?; + let anthropic_provider = AnthropicProvider::new(api_key); + anthropic_provider + .chat_stream(&app, &config.model, ¤t_history, tools, &mut cancel_rx) + .await + .map_err(|e| format!("Anthropic Error: {e}"))? + } else { + // Use Ollama provider + let ollama_provider = OllamaProvider::new(base_url.clone()); + ollama_provider + .chat_stream(&app, &config.model, ¤t_history, tools, &mut cancel_rx) + .await + .map_err(|e| format!("Ollama Error: {e}"))? + }; + + // Process Response + if let Some(tool_calls) = response.tool_calls { + // The Assistant wants to run tools + let assistant_msg = Message { + role: Role::Assistant, + content: response.content.unwrap_or_default(), + tool_calls: Some(tool_calls.clone()), + tool_call_id: None, + }; + + current_history.push(assistant_msg.clone()); + new_messages.push(assistant_msg); + // Emit history excluding system prompts (indices 0 and 1) + app.emit("chat:update", ¤t_history[2..]) + .map_err(|e| e.to_string())?; + + // Execute Tools + for call in tool_calls { + // Check for cancellation before executing each tool + if *cancel_rx.borrow() { + return Err("Chat cancelled before tool execution".to_string()); + } + + let output = execute_tool(&call, &state).await; + + let tool_msg = Message { + role: Role::Tool, + content: output, + tool_calls: None, + tool_call_id: call.id, + }; + + current_history.push(tool_msg.clone()); + new_messages.push(tool_msg); + // Emit history excluding system prompts (indices 0 and 1) + app.emit("chat:update", ¤t_history[2..]) + .map_err(|e| e.to_string())?; + } + } else { + // Final text response + let assistant_msg = Message { + role: Role::Assistant, + content: response.content.unwrap_or_default(), + tool_calls: None, + tool_call_id: None, + }; + + new_messages.push(assistant_msg.clone()); + current_history.push(assistant_msg); + // Emit history excluding system prompts (indices 0 and 1) + app.emit("chat:update", ¤t_history[2..]) + .map_err(|e| e.to_string())?; + break; + } + } + + Ok(new_messages) +} + +async fn execute_tool(call: &ToolCall, state: &State<'_, SessionState>) -> String { + use crate::commands::{fs, search, shell}; + + let name = call.function.name.as_str(); + // Parse arguments. They come as a JSON string from the LLM abstraction. + let args: serde_json::Value = match parse_tool_arguments(&call.function.arguments) { + Ok(v) => v, + Err(e) => return e, + }; + + match name { + "read_file" => { + let path = args["path"].as_str().unwrap_or("").to_string(); + match fs::read_file(path, state.clone()).await { + Ok(content) => content, + Err(e) => format!("Error: {e}"), + } + } + "write_file" => { + let path = args["path"].as_str().unwrap_or("").to_string(); + let content = args["content"].as_str().unwrap_or("").to_string(); + match fs::write_file(path, content, state.clone()).await { + Ok(()) => "File written successfully.".to_string(), + Err(e) => format!("Error: {e}"), + } + } + "list_directory" => { + let path = args["path"].as_str().unwrap_or("").to_string(); + match fs::list_directory(path, state.clone()).await { + Ok(entries) => serde_json::to_string(&entries).unwrap_or_default(), + Err(e) => format!("Error: {e}"), + } + } + "search_files" => { + let query = args["query"].as_str().unwrap_or("").to_string(); + match search::search_files(query, state.clone()).await { + Ok(results) => serde_json::to_string(&results).unwrap_or_default(), + Err(e) => format!("Error: {e}"), + } + } + "exec_shell" => { + let command = args["command"].as_str().unwrap_or("").to_string(); + let args_vec: Vec = args["args"] + .as_array() + .map(|arr| { + arr.iter() + .map(|v| v.as_str().unwrap_or("").to_string()) + .collect() + }) + .unwrap_or_default(); + + match shell::exec_shell(command, args_vec, state.clone()).await { + Ok(output) => serde_json::to_string(&output).unwrap_or_default(), + Err(e) => format!("Error: {e}"), + } + } + _ => format!("Unknown tool: {name}"), + } +} + #[tauri::command] pub async fn cancel_chat(state: State<'_, SessionState>) -> Result<(), String> { state.cancel_tx.send(true).map_err(|e| e.to_string())?; Ok(()) } + +// ----------------------------------------------------------------------------- +// Tests +// ----------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::MockStore; + use std::collections::HashMap; + + // Tests for get_anthropic_api_key_exists_impl + mod get_anthropic_api_key_exists_tests { + use super::*; + + #[test] + fn test_key_exists_and_not_empty() { + let mut data = HashMap::new(); + data.insert(KEY_ANTHROPIC_API_KEY.to_string(), json!("sk-ant-test123")); + let store = MockStore::with_data(data); + + let result = get_anthropic_api_key_exists_impl(&store); + + assert!(result); + } + + #[test] + fn test_key_exists_but_empty() { + let mut data = HashMap::new(); + data.insert(KEY_ANTHROPIC_API_KEY.to_string(), json!("")); + let store = MockStore::with_data(data); + + let result = get_anthropic_api_key_exists_impl(&store); + + assert!(!result); + } + + #[test] + fn test_key_not_exists() { + let store = MockStore::new(); + + let result = get_anthropic_api_key_exists_impl(&store); + + assert!(!result); + } + + #[test] + fn test_key_exists_but_not_string() { + let mut data = HashMap::new(); + data.insert(KEY_ANTHROPIC_API_KEY.to_string(), json!(123)); + let store = MockStore::with_data(data); + + let result = get_anthropic_api_key_exists_impl(&store); + + assert!(!result); + } + } + + // Tests for set_anthropic_api_key_impl + mod set_anthropic_api_key_tests { + use super::*; + + #[test] + fn test_set_new_key() { + let store = MockStore::new(); + let api_key = "sk-ant-new-key".to_string(); + + let result = set_anthropic_api_key_impl(&store, &api_key); + + assert!(result.is_ok()); + assert_eq!( + store.get(KEY_ANTHROPIC_API_KEY), + Some(json!("sk-ant-new-key")) + ); + } + + #[test] + fn test_set_overwrites_existing_key() { + let mut data = HashMap::new(); + data.insert(KEY_ANTHROPIC_API_KEY.to_string(), json!("old-key")); + let store = MockStore::with_data(data); + let api_key = "sk-ant-new-key".to_string(); + + let result = set_anthropic_api_key_impl(&store, &api_key); + + assert!(result.is_ok()); + assert_eq!( + store.get(KEY_ANTHROPIC_API_KEY), + Some(json!("sk-ant-new-key")) + ); + } + + #[test] + fn test_set_empty_string() { + let store = MockStore::new(); + let api_key = "".to_string(); + + let result = set_anthropic_api_key_impl(&store, &api_key); + + assert!(result.is_ok()); + assert_eq!(store.get(KEY_ANTHROPIC_API_KEY), Some(json!(""))); + } + } + + // Tests for get_anthropic_api_key_impl + mod get_anthropic_api_key_tests { + use super::*; + + #[test] + fn test_get_existing_key() { + let mut data = HashMap::new(); + data.insert(KEY_ANTHROPIC_API_KEY.to_string(), json!("sk-ant-test-key")); + let store = MockStore::with_data(data); + + let result = get_anthropic_api_key_impl(&store); + + assert_eq!(result, Ok("sk-ant-test-key".to_string())); + } + + #[test] + fn test_get_key_not_found() { + let store = MockStore::new(); + + let result = get_anthropic_api_key_impl(&store); + + assert_eq!( + result, + Err("Anthropic API key not found. Please set your API key.".to_string()) + ); + } + + #[test] + fn test_get_empty_key() { + let mut data = HashMap::new(); + data.insert(KEY_ANTHROPIC_API_KEY.to_string(), json!("")); + let store = MockStore::with_data(data); + + let result = get_anthropic_api_key_impl(&store); + + assert_eq!( + result, + Err("Anthropic API key is empty. Please set your API key.".to_string()) + ); + } + + #[test] + fn test_get_key_not_string() { + let mut data = HashMap::new(); + data.insert(KEY_ANTHROPIC_API_KEY.to_string(), json!(12345)); + let store = MockStore::with_data(data); + + let result = get_anthropic_api_key_impl(&store); + + assert_eq!(result, Err("Stored API key is not a string".to_string())); + } + } + + // Tests for parse_tool_arguments + mod parse_tool_arguments_tests { + use super::*; + + #[test] + fn test_parse_valid_json() { + let args_str = r#"{"path": "test.txt"}"#; + + let result = parse_tool_arguments(args_str); + + assert!(result.is_ok()); + let value = result.unwrap(); + assert_eq!(value["path"], "test.txt"); + } + + #[test] + fn test_parse_complex_json() { + let args_str = r#"{"command": "git", "args": ["status", "--short"]}"#; + + let result = parse_tool_arguments(args_str); + + assert!(result.is_ok()); + let value = result.unwrap(); + assert_eq!(value["command"], "git"); + assert_eq!(value["args"][0], "status"); + assert_eq!(value["args"][1], "--short"); + } + + #[test] + fn test_parse_invalid_json() { + let args_str = r#"{"path": "test.txt"#; // Missing closing brace + + let result = parse_tool_arguments(args_str); + + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Error parsing arguments")); + } + + #[test] + fn test_parse_empty_json() { + let args_str = "{}"; + + let result = parse_tool_arguments(args_str); + + assert!(result.is_ok()); + let value = result.unwrap(); + assert_eq!(value, json!({})); + } + + #[test] + fn test_parse_empty_string() { + let args_str = ""; + + let result = parse_tool_arguments(args_str); + + assert!(result.is_err()); + } + } + + // Tests for get_tool_definitions + mod get_tool_definitions_tests { + use super::*; + + #[test] + fn test_returns_all_tools() { + let tools = get_tool_definitions(); + + assert_eq!(tools.len(), 5); + } + + #[test] + fn test_read_file_tool_exists() { + let tools = get_tool_definitions(); + + let read_file = tools + .iter() + .find(|t| t.function.name == "read_file") + .expect("read_file tool should exist"); + + assert_eq!(read_file.kind, "function"); + assert!( + read_file + .function + .description + .contains("Reads the complete content") + ); + } + + #[test] + fn test_write_file_tool_exists() { + let tools = get_tool_definitions(); + + let write_file = tools + .iter() + .find(|t| t.function.name == "write_file") + .expect("write_file tool should exist"); + + assert_eq!(write_file.kind, "function"); + assert!( + write_file + .function + .description + .contains("Creates or completely overwrites") + ); + } + + #[test] + fn test_list_directory_tool_exists() { + let tools = get_tool_definitions(); + + let list_directory = tools + .iter() + .find(|t| t.function.name == "list_directory") + .expect("list_directory tool should exist"); + + assert_eq!(list_directory.kind, "function"); + assert!( + list_directory + .function + .description + .contains("Lists all files and directories") + ); + } + + #[test] + fn test_search_files_tool_exists() { + let tools = get_tool_definitions(); + + let search_files = tools + .iter() + .find(|t| t.function.name == "search_files") + .expect("search_files tool should exist"); + + assert_eq!(search_files.kind, "function"); + assert!( + search_files + .function + .description + .contains("Searches for text patterns") + ); + } + + #[test] + fn test_exec_shell_tool_exists() { + let tools = get_tool_definitions(); + + let exec_shell = tools + .iter() + .find(|t| t.function.name == "exec_shell") + .expect("exec_shell tool should exist"); + + assert_eq!(exec_shell.kind, "function"); + assert!( + exec_shell + .function + .description + .contains("Executes a shell command") + ); + } + + #[test] + fn test_all_tools_have_function_kind() { + let tools = get_tool_definitions(); + + for tool in tools { + assert_eq!(tool.kind, "function"); + } + } + + #[test] + fn test_all_tools_have_non_empty_descriptions() { + let tools = get_tool_definitions(); + + for tool in tools { + assert!(!tool.function.description.is_empty()); + assert!(!tool.function.name.is_empty()); + } + } + + #[test] + fn test_read_file_parameters() { + let tools = get_tool_definitions(); + let read_file = tools + .iter() + .find(|t| t.function.name == "read_file") + .unwrap(); + + let params = &read_file.function.parameters; + assert_eq!(params["type"], "object"); + assert!(params["properties"]["path"].is_object()); + assert_eq!(params["required"][0], "path"); + } + + #[test] + fn test_write_file_parameters() { + let tools = get_tool_definitions(); + let write_file = tools + .iter() + .find(|t| t.function.name == "write_file") + .unwrap(); + + let params = &write_file.function.parameters; + assert_eq!(params["type"], "object"); + assert!(params["properties"]["path"].is_object()); + assert!(params["properties"]["content"].is_object()); + assert_eq!(params["required"][0], "path"); + assert_eq!(params["required"][1], "content"); + } + + #[test] + fn test_exec_shell_parameters() { + let tools = get_tool_definitions(); + let exec_shell = tools + .iter() + .find(|t| t.function.name == "exec_shell") + .unwrap(); + + let params = &exec_shell.function.parameters; + assert_eq!(params["type"], "object"); + assert!(params["properties"]["command"].is_object()); + assert!(params["properties"]["args"].is_object()); + assert_eq!(params["required"][0], "command"); + assert_eq!(params["required"][1], "args"); + } + } + + // Tests for get_project_root helper + mod get_project_root_tests { + use super::*; + use std::sync::Mutex; + use tempfile::TempDir; + + #[test] + fn test_get_project_root_no_project() { + let state = SessionState { + project_root: Mutex::new(None), + cancel_tx: tokio::sync::watch::channel(false).0, + cancel_rx: tokio::sync::watch::channel(false).1, + }; + + let result = state.get_project_root(); + + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "No project is currently open."); + } + + #[test] + fn test_get_project_root_success() { + let temp_dir = TempDir::new().unwrap(); + let path = temp_dir.path().to_path_buf(); + let state = SessionState { + project_root: Mutex::new(Some(path.clone())), + cancel_tx: tokio::sync::watch::channel(false).0, + cancel_rx: tokio::sync::watch::channel(false).1, + }; + + let result = state.get_project_root(); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), path); + } + } +} diff --git a/src-tauri/src/commands/fs.rs b/src-tauri/src/commands/fs.rs index 9a3b3cc..a3b3879 100644 --- a/src-tauri/src/commands/fs.rs +++ b/src-tauri/src/commands/fs.rs @@ -29,8 +29,8 @@ pub trait StoreOps: Send + Sync { // ----------------------------------------------------------------------------- /// Wrapper for Tauri Store that implements StoreOps -struct TauriStoreWrapper<'a> { - store: &'a tauri_plugin_store::Store, +pub struct TauriStoreWrapper<'a> { + pub store: &'a tauri_plugin_store::Store, } impl<'a> StoreOps for TauriStoreWrapper<'a> {