diff --git a/Cargo.lock b/Cargo.lock index 7183238..cc0d8cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,16 @@ version = "1.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea" +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "async-stream" version = "0.3.6" @@ -207,6 +217,15 @@ dependencies = [ "cc", ] +[[package]] +name = "colored" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "combine" version = "4.6.7" @@ -1212,6 +1231,31 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "mockito" +version = "1.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90820618712cab19cfc46b274c6c22546a82affcb3c3bdf0f29e3db8e1bb92c0" +dependencies = [ + "assert-json-diff", + "bytes 1.11.1", + "colored", + "futures-core", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "log", + "pin-project-lite", + "rand", + "regex", + "serde_json", + "serde_urlencoded", + "similar", + "tokio", +] + [[package]] name = "multer" version = "3.1.0" @@ -2085,6 +2129,12 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "similar" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" + [[package]] name = "slab" version = "0.4.12" @@ -2144,6 +2194,7 @@ dependencies = [ "homedir", "ignore", "mime_guess", + "mockito", "notify", "poem", "poem-openapi", @@ -2333,6 +2384,7 @@ dependencies = [ "bytes 1.11.1", "libc", "mio", + "parking_lot", "pin-project-lite", "socket2", "tokio-macros", diff --git a/server/Cargo.toml b/server/Cargo.toml index 16a1ea9..37e9fb0 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -32,3 +32,4 @@ walkdir = { workspace = true } [dev-dependencies] tempfile = { workspace = true } tokio-tungstenite = { workspace = true } +mockito = "1" diff --git a/server/src/llm/providers/anthropic.rs b/server/src/llm/providers/anthropic.rs index abd9ee1..8179d49 100644 --- a/server/src/llm/providers/anthropic.rs +++ b/server/src/llm/providers/anthropic.rs @@ -13,6 +13,7 @@ const ANTHROPIC_VERSION: &str = "2023-06-01"; pub struct AnthropicProvider { api_key: String, client: reqwest::Client, + api_url: String, } #[derive(Debug, Serialize, Deserialize)] @@ -66,6 +67,16 @@ impl AnthropicProvider { Self { api_key, client: reqwest::Client::new(), + api_url: ANTHROPIC_API_URL.to_string(), + } + } + + #[cfg(test)] + fn new_with_url(api_key: String, api_url: String) -> Self { + Self { + api_key, + client: reqwest::Client::new(), + api_url, } } @@ -201,7 +212,7 @@ impl AnthropicProvider { let response = self .client - .post(ANTHROPIC_API_URL) + .post(&self.api_url) .headers(headers) .json(&request_body) .send() @@ -312,3 +323,546 @@ impl AnthropicProvider { }) } } + +#[cfg(test)] +mod tests { + use super::{AnthropicContent, AnthropicContentBlock, AnthropicProvider}; + use crate::llm::types::{ + FunctionCall, Message, Role, ToolCall, ToolDefinition, ToolFunctionDefinition, + }; + use serde_json::json; + + fn user_msg(content: &str) -> Message { + Message { + role: Role::User, + content: content.to_string(), + tool_calls: None, + tool_call_id: None, + } + } + + fn system_msg(content: &str) -> Message { + Message { + role: Role::System, + content: content.to_string(), + tool_calls: None, + tool_call_id: None, + } + } + + fn assistant_msg(content: &str) -> Message { + Message { + role: Role::Assistant, + content: content.to_string(), + tool_calls: None, + tool_call_id: None, + } + } + + fn make_tool_def(name: &str) -> ToolDefinition { + ToolDefinition { + kind: "function".to_string(), + function: ToolFunctionDefinition { + name: name.to_string(), + description: format!("{name} description"), + parameters: json!({"type": "object", "properties": {}}), + }, + } + } + + // ── convert_tools ──────────────────────────────────────────────────────── + + #[test] + fn test_convert_tools_empty() { + let result = AnthropicProvider::convert_tools(&[]); + assert!(result.is_empty()); + } + + #[test] + fn test_convert_tools_single() { + let tool = make_tool_def("search_files"); + let result = AnthropicProvider::convert_tools(&[tool]); + assert_eq!(result.len(), 1); + assert_eq!(result[0].name, "search_files"); + assert_eq!(result[0].description, "search_files description"); + assert_eq!( + result[0].input_schema, + json!({"type": "object", "properties": {}}) + ); + } + + #[test] + fn test_convert_tools_multiple() { + let tools = vec![make_tool_def("read_file"), make_tool_def("write_file")]; + let result = AnthropicProvider::convert_tools(&tools); + assert_eq!(result.len(), 2); + assert_eq!(result[0].name, "read_file"); + assert_eq!(result[1].name, "write_file"); + } + + // ── convert_messages ───────────────────────────────────────────────────── + + #[test] + fn test_convert_messages_user() { + let msgs = vec![user_msg("Hello")]; + let result = AnthropicProvider::convert_messages(&msgs); + assert_eq!(result.len(), 1); + assert_eq!(result[0].role, "user"); + match &result[0].content { + AnthropicContent::Text(t) => assert_eq!(t, "Hello"), + _ => panic!("Expected text content"), + } + } + + #[test] + fn test_convert_messages_system_skipped() { + let msgs = vec![system_msg("You are helpful"), user_msg("Hi")]; + let result = AnthropicProvider::convert_messages(&msgs); + assert_eq!(result.len(), 1); + assert_eq!(result[0].role, "user"); + } + + #[test] + fn test_convert_messages_assistant_text() { + let msgs = vec![assistant_msg("I can help with that")]; + let result = AnthropicProvider::convert_messages(&msgs); + assert_eq!(result.len(), 1); + assert_eq!(result[0].role, "assistant"); + match &result[0].content { + AnthropicContent::Text(t) => assert_eq!(t, "I can help with that"), + _ => panic!("Expected text content"), + } + } + + #[test] + fn test_convert_messages_assistant_with_tool_calls_no_content() { + let msgs = vec![Message { + role: Role::Assistant, + content: String::new(), + tool_calls: Some(vec![ToolCall { + id: Some("toolu_abc".to_string()), + kind: "function".to_string(), + function: FunctionCall { + name: "search_files".to_string(), + arguments: r#"{"pattern": "*.rs"}"#.to_string(), + }, + }]), + tool_call_id: None, + }]; + let result = AnthropicProvider::convert_messages(&msgs); + assert_eq!(result.len(), 1); + assert_eq!(result[0].role, "assistant"); + match &result[0].content { + AnthropicContent::Blocks(blocks) => { + assert_eq!(blocks.len(), 1); + match &blocks[0] { + AnthropicContentBlock::ToolUse { id, name, .. } => { + assert_eq!(id, "toolu_abc"); + assert_eq!(name, "search_files"); + } + _ => panic!("Expected ToolUse block"), + } + } + _ => panic!("Expected blocks content"), + } + } + + #[test] + fn test_convert_messages_assistant_with_tool_calls_and_content() { + let msgs = vec![Message { + role: Role::Assistant, + content: "Let me search for that".to_string(), + tool_calls: Some(vec![ToolCall { + id: Some("toolu_xyz".to_string()), + kind: "function".to_string(), + function: FunctionCall { + name: "read_file".to_string(), + arguments: r#"{"path": "main.rs"}"#.to_string(), + }, + }]), + tool_call_id: None, + }]; + let result = AnthropicProvider::convert_messages(&msgs); + assert_eq!(result.len(), 1); + match &result[0].content { + AnthropicContent::Blocks(blocks) => { + assert_eq!(blocks.len(), 2); + match &blocks[0] { + AnthropicContentBlock::Text { text } => { + assert_eq!(text, "Let me search for that"); + } + _ => panic!("Expected Text block first"), + } + match &blocks[1] { + AnthropicContentBlock::ToolUse { id, name, .. } => { + assert_eq!(id, "toolu_xyz"); + assert_eq!(name, "read_file"); + } + _ => panic!("Expected ToolUse block second"), + } + } + _ => panic!("Expected blocks content"), + } + } + + #[test] + fn test_convert_messages_assistant_tool_call_invalid_json_args() { + // Invalid JSON args fall back to {} + let msgs = vec![Message { + role: Role::Assistant, + content: String::new(), + tool_calls: Some(vec![ToolCall { + id: None, + kind: "function".to_string(), + function: FunctionCall { + name: "my_tool".to_string(), + arguments: "not valid json".to_string(), + }, + }]), + tool_call_id: None, + }]; + let result = AnthropicProvider::convert_messages(&msgs); + match &result[0].content { + AnthropicContent::Blocks(blocks) => match &blocks[0] { + AnthropicContentBlock::ToolUse { input, .. } => { + assert_eq!(*input, json!({})); + } + _ => panic!("Expected ToolUse block"), + }, + _ => panic!("Expected blocks"), + } + } + + #[test] + fn test_convert_messages_assistant_tool_call_no_id_generates_uuid() { + let msgs = vec![Message { + role: Role::Assistant, + content: String::new(), + tool_calls: Some(vec![ToolCall { + id: None, // no id provided + kind: "function".to_string(), + function: FunctionCall { + name: "my_tool".to_string(), + arguments: "{}".to_string(), + }, + }]), + tool_call_id: None, + }]; + let result = AnthropicProvider::convert_messages(&msgs); + match &result[0].content { + AnthropicContent::Blocks(blocks) => match &blocks[0] { + AnthropicContentBlock::ToolUse { id, .. } => { + assert!(!id.is_empty(), "Should have generated a UUID"); + } + _ => panic!("Expected ToolUse block"), + }, + _ => panic!("Expected blocks"), + } + } + + #[test] + fn test_convert_messages_tool_role() { + let msgs = vec![Message { + role: Role::Tool, + content: "file content here".to_string(), + tool_calls: None, + tool_call_id: Some("toolu_123".to_string()), + }]; + let result = AnthropicProvider::convert_messages(&msgs); + assert_eq!(result.len(), 1); + assert_eq!(result[0].role, "user"); + match &result[0].content { + AnthropicContent::Blocks(blocks) => { + assert_eq!(blocks.len(), 1); + match &blocks[0] { + AnthropicContentBlock::ToolResult { + tool_use_id, + content, + } => { + assert_eq!(tool_use_id, "toolu_123"); + assert_eq!(content, "file content here"); + } + _ => panic!("Expected ToolResult block"), + } + } + _ => panic!("Expected blocks content"), + } + } + + #[test] + fn test_convert_messages_tool_role_no_id_defaults_empty() { + let msgs = vec![Message { + role: Role::Tool, + content: "result".to_string(), + tool_calls: None, + tool_call_id: None, + }]; + let result = AnthropicProvider::convert_messages(&msgs); + match &result[0].content { + AnthropicContent::Blocks(blocks) => match &blocks[0] { + AnthropicContentBlock::ToolResult { tool_use_id, .. } => { + assert_eq!(tool_use_id, ""); + } + _ => panic!("Expected ToolResult block"), + }, + _ => panic!("Expected blocks"), + } + } + + #[test] + fn test_convert_messages_mixed_roles() { + let msgs = vec![ + system_msg("Be helpful"), + user_msg("What is the time?"), + assistant_msg("I can check that."), + ]; + let result = AnthropicProvider::convert_messages(&msgs); + // System is skipped + assert_eq!(result.len(), 2); + assert_eq!(result[0].role, "user"); + assert_eq!(result[1].role, "assistant"); + } + + // ── extract_system_prompt ───────────────────────────────────────────────── + + #[test] + fn test_extract_system_prompt_no_messages() { + let msgs: Vec = vec![]; + let prompt = AnthropicProvider::extract_system_prompt(&msgs); + assert!(prompt.is_empty()); + } + + #[test] + fn test_extract_system_prompt_no_system_messages() { + let msgs = vec![user_msg("Hello"), assistant_msg("Hi there")]; + let prompt = AnthropicProvider::extract_system_prompt(&msgs); + assert!(prompt.is_empty()); + } + + #[test] + fn test_extract_system_prompt_single() { + let msgs = vec![system_msg("You are a helpful assistant"), user_msg("Hi")]; + let prompt = AnthropicProvider::extract_system_prompt(&msgs); + assert_eq!(prompt, "You are a helpful assistant"); + } + + #[test] + fn test_extract_system_prompt_multiple_joined() { + let msgs = vec![ + system_msg("First instruction"), + system_msg("Second instruction"), + user_msg("Hello"), + ]; + let prompt = AnthropicProvider::extract_system_prompt(&msgs); + assert_eq!(prompt, "First instruction\n\nSecond instruction"); + } + + // ── chat_stream (HTTP mocked) ───────────────────────────────────────────── + + #[tokio::test] + async fn test_chat_stream_text_response() { + let mut server = mockito::Server::new_async().await; + + let delta1 = json!({ + "type": "content_block_delta", + "delta": {"type": "text_delta", "text": "Hello"} + }); + let delta2 = json!({ + "type": "content_block_delta", + "delta": {"type": "text_delta", "text": " world"} + }); + let body = format!("data: {delta1}\ndata: {delta2}\ndata: [DONE]\n"); + + let _m = server + .mock("POST", "/v1/messages") + .with_status(200) + .with_header("content-type", "text/event-stream") + .with_body(body) + .create_async() + .await; + + let provider = AnthropicProvider::new_with_url( + "test-key".to_string(), + format!("{}/v1/messages", server.url()), + ); + let (_tx, mut cancel_rx) = tokio::sync::watch::channel(false); + let mut tokens = Vec::::new(); + + let result = provider + .chat_stream( + "claude-3-5-sonnet-20241022", + &[user_msg("Hello")], + &[], + &mut cancel_rx, + |t| tokens.push(t.to_string()), + |_| {}, + ) + .await; + + assert!(result.is_ok()); + let response = result.unwrap(); + assert_eq!(response.content, Some("Hello world".to_string())); + assert!(response.tool_calls.is_none()); + assert_eq!(tokens, vec!["Hello", " world"]); + } + + #[tokio::test] + async fn test_chat_stream_error_response() { + let mut server = mockito::Server::new_async().await; + + let _m = server + .mock("POST", "/v1/messages") + .with_status(401) + .with_body(r#"{"error":{"type":"authentication_error","message":"Invalid API key"}}"#) + .create_async() + .await; + + let provider = AnthropicProvider::new_with_url( + "bad-key".to_string(), + format!("{}/v1/messages", server.url()), + ); + let (_tx, mut cancel_rx) = tokio::sync::watch::channel(false); + + let result = provider + .chat_stream( + "claude-3-5-sonnet-20241022", + &[user_msg("Hello")], + &[], + &mut cancel_rx, + |_| {}, + |_| {}, + ) + .await; + + assert!(result.is_err()); + assert!(result.unwrap_err().contains("401")); + } + + #[tokio::test] + async fn test_chat_stream_tool_use_response() { + let mut server = mockito::Server::new_async().await; + + let start_event = json!({ + "type": "content_block_start", + "content_block": {"type": "tool_use", "id": "toolu_abc", "name": "search_files"} + }); + let delta_event = json!({ + "type": "content_block_delta", + "delta": {"type": "input_json_delta", "partial_json": "{}"} + }); + let stop_event = json!({"type": "content_block_stop"}); + let body = format!( + "data: {start_event}\ndata: {delta_event}\ndata: {stop_event}\ndata: [DONE]\n" + ); + + let _m = server + .mock("POST", "/v1/messages") + .with_status(200) + .with_header("content-type", "text/event-stream") + .with_body(body) + .create_async() + .await; + + let provider = AnthropicProvider::new_with_url( + "test-key".to_string(), + format!("{}/v1/messages", server.url()), + ); + let (_tx, mut cancel_rx) = tokio::sync::watch::channel(false); + let mut activities = Vec::::new(); + + let result = provider + .chat_stream( + "claude-3-5-sonnet-20241022", + &[user_msg("Find Rust files")], + &[make_tool_def("search_files")], + &mut cancel_rx, + |_| {}, + |a| activities.push(a.to_string()), + ) + .await; + + assert!(result.is_ok()); + let response = result.unwrap(); + assert!(response.content.is_none()); + let tool_calls = response.tool_calls.expect("Expected tool calls"); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id, Some("toolu_abc".to_string())); + assert_eq!(tool_calls[0].function.name, "search_files"); + assert_eq!(activities, vec!["search_files"]); + } + + #[tokio::test] + async fn test_chat_stream_includes_system_prompt() { + let mut server = mockito::Server::new_async().await; + + let delta = json!({ + "type": "content_block_delta", + "delta": {"type": "text_delta", "text": "ok"} + }); + let body = format!("data: {delta}\ndata: [DONE]\n"); + + let _m = server + .mock("POST", "/v1/messages") + .with_status(200) + .with_header("content-type", "text/event-stream") + .with_body(body) + .create_async() + .await; + + let provider = AnthropicProvider::new_with_url( + "test-key".to_string(), + format!("{}/v1/messages", server.url()), + ); + let (_tx, mut cancel_rx) = tokio::sync::watch::channel(false); + let messages = vec![system_msg("Be concise"), user_msg("Hello")]; + + let result = provider + .chat_stream( + "claude-3-5-sonnet-20241022", + &messages, + &[], + &mut cancel_rx, + |_| {}, + |_| {}, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap().content, Some("ok".to_string())); + } + + #[tokio::test] + async fn test_chat_stream_empty_response_gives_none_content() { + let mut server = mockito::Server::new_async().await; + + let _m = server + .mock("POST", "/v1/messages") + .with_status(200) + .with_header("content-type", "text/event-stream") + .with_body("data: [DONE]\n") + .create_async() + .await; + + let provider = AnthropicProvider::new_with_url( + "test-key".to_string(), + format!("{}/v1/messages", server.url()), + ); + let (_tx, mut cancel_rx) = tokio::sync::watch::channel(false); + + let result = provider + .chat_stream( + "claude-3-5-sonnet-20241022", + &[user_msg("Hello")], + &[], + &mut cancel_rx, + |_| {}, + |_| {}, + ) + .await; + + assert!(result.is_ok()); + let response = result.unwrap(); + assert!(response.content.is_none()); + assert!(response.tool_calls.is_none()); + } +}