use crate::llm::types::{ CompletionResponse, FunctionCall, Message, Role, ToolCall, ToolDefinition, }; use futures::StreamExt; use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue}; use serde::{Deserialize, Serialize}; use serde_json::json; use tokio::sync::watch::Receiver; const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages"; const ANTHROPIC_VERSION: &str = "2023-06-01"; pub struct AnthropicProvider { api_key: String, client: reqwest::Client, api_url: String, } #[derive(Debug, Serialize, Deserialize)] struct AnthropicMessage { role: String, // "user" or "assistant" content: AnthropicContent, } #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] enum AnthropicContent { Text(String), Blocks(Vec), } #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] enum AnthropicContentBlock { #[serde(rename = "text")] Text { text: String }, #[serde(rename = "tool_use")] ToolUse { id: String, name: String, input: serde_json::Value, }, #[serde(rename = "tool_result")] ToolResult { tool_use_id: String, content: String, }, } #[derive(Debug, Serialize)] struct AnthropicTool { name: String, description: String, input_schema: serde_json::Value, } #[derive(Debug, Deserialize)] struct StreamEvent { #[serde(rename = "type")] event_type: String, #[serde(flatten)] data: serde_json::Value, } impl AnthropicProvider { pub fn new(api_key: String) -> Self { Self { api_key, client: reqwest::Client::new(), api_url: ANTHROPIC_API_URL.to_string(), } } #[cfg(test)] fn new_with_url(api_key: String, api_url: String) -> Self { Self { api_key, client: reqwest::Client::new(), api_url, } } fn convert_tools(tools: &[ToolDefinition]) -> Vec { tools .iter() .map(|tool| AnthropicTool { name: tool.function.name.clone(), description: tool.function.description.clone(), input_schema: tool.function.parameters.clone(), }) .collect() } fn convert_messages(messages: &[Message]) -> Vec { let mut anthropic_messages: Vec = Vec::new(); for msg in messages { match msg.role { Role::System => { continue; } Role::User => { anthropic_messages.push(AnthropicMessage { role: "user".to_string(), content: AnthropicContent::Text(msg.content.clone()), }); } Role::Assistant => { if let Some(tool_calls) = &msg.tool_calls { let mut blocks = Vec::new(); if !msg.content.is_empty() { blocks.push(AnthropicContentBlock::Text { text: msg.content.clone(), }); } for call in tool_calls { let input: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap_or(json!({})); blocks.push(AnthropicContentBlock::ToolUse { id: call .id .clone() .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), name: call.function.name.clone(), input, }); } anthropic_messages.push(AnthropicMessage { role: "assistant".to_string(), content: AnthropicContent::Blocks(blocks), }); } else { anthropic_messages.push(AnthropicMessage { role: "assistant".to_string(), content: AnthropicContent::Text(msg.content.clone()), }); } } Role::Tool => { let tool_use_id = msg.tool_call_id.clone().unwrap_or_default(); anthropic_messages.push(AnthropicMessage { role: "user".to_string(), content: AnthropicContent::Blocks(vec![ AnthropicContentBlock::ToolResult { tool_use_id, content: msg.content.clone(), }, ]), }); } } } anthropic_messages } fn extract_system_prompt(messages: &[Message]) -> String { messages .iter() .filter(|m| matches!(m.role, Role::System)) .map(|m| m.content.as_str()) .collect::>() .join("\n\n") } pub async fn chat_stream( &self, model: &str, messages: &[Message], tools: &[ToolDefinition], cancel_rx: &mut Receiver, mut on_token: F, mut on_activity: A, ) -> Result where F: FnMut(&str), A: FnMut(&str), { let anthropic_messages = Self::convert_messages(messages); let anthropic_tools = Self::convert_tools(tools); let system_prompt = Self::extract_system_prompt(messages); let mut request_body = json!({ "model": model, "max_tokens": 4096, "messages": anthropic_messages, "stream": true, }); if !system_prompt.is_empty() { request_body["system"] = json!(system_prompt); } if !anthropic_tools.is_empty() { request_body["tools"] = json!(anthropic_tools); } let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); headers.insert( "x-api-key", HeaderValue::from_str(&self.api_key).map_err(|e| e.to_string())?, ); headers.insert( "anthropic-version", HeaderValue::from_static(ANTHROPIC_VERSION), ); let response = self .client .post(&self.api_url) .headers(headers) .json(&request_body) .send() .await .map_err(|e| format!("Failed to send request to Anthropic: {e}"))?; if !response.status().is_success() { let status = response.status(); let error_text = response .text() .await .unwrap_or_else(|_| "Unknown error".to_string()); return Err(format!("Anthropic API error {status}: {error_text}")); } let mut stream = response.bytes_stream(); let mut accumulated_text = String::new(); let mut tool_calls: Vec = Vec::new(); let mut current_tool_use: Option<(String, String, String)> = None; loop { let chunk = tokio::select! { result = stream.next() => { match result { Some(c) => c, None => break, } } _ = cancel_rx.changed() => { if *cancel_rx.borrow() { return Err("Chat cancelled by user".to_string()); } continue; } }; let bytes = chunk.map_err(|e| format!("Stream error: {e}"))?; let text = String::from_utf8_lossy(&bytes); for line in text.lines() { if let Some(json_str) = line.strip_prefix("data: ") { if json_str == "[DONE]" { break; } let event: StreamEvent = match serde_json::from_str(json_str) { Ok(e) => e, Err(_) => continue, }; match event.event_type.as_str() { "content_block_start" => { if let Some(content_block) = event.data.get("content_block") && content_block.get("type") == Some(&json!("tool_use")) { let id = content_block["id"].as_str().unwrap_or("").to_string(); let name = content_block["name"].as_str().unwrap_or("").to_string(); on_activity(&name); current_tool_use = Some((id, name, String::new())); } } "content_block_delta" => { if let Some(delta) = event.data.get("delta") { if delta.get("type") == Some(&json!("text_delta")) { if let Some(text) = delta.get("text").and_then(|t| t.as_str()) { accumulated_text.push_str(text); on_token(text); } } else if delta.get("type") == Some(&json!("input_json_delta")) && let Some((_, _, input_json)) = &mut current_tool_use && let Some(partial) = delta.get("partial_json").and_then(|p| p.as_str()) { input_json.push_str(partial); } } } "content_block_stop" => { if let Some((id, name, input_json)) = current_tool_use.take() { tool_calls.push(ToolCall { id: Some(id), kind: "function".to_string(), function: FunctionCall { name, arguments: input_json, }, }); } } _ => {} } } } } Ok(CompletionResponse { content: if accumulated_text.is_empty() { None } else { Some(accumulated_text) }, tool_calls: if tool_calls.is_empty() { None } else { Some(tool_calls) }, session_id: None, }) } } #[cfg(test)] mod tests { use super::{AnthropicContent, AnthropicContentBlock, AnthropicProvider}; use crate::llm::types::{ FunctionCall, Message, Role, ToolCall, ToolDefinition, ToolFunctionDefinition, }; use serde_json::json; fn user_msg(content: &str) -> Message { Message { role: Role::User, content: content.to_string(), tool_calls: None, tool_call_id: None, } } fn system_msg(content: &str) -> Message { Message { role: Role::System, content: content.to_string(), tool_calls: None, tool_call_id: None, } } fn assistant_msg(content: &str) -> Message { Message { role: Role::Assistant, content: content.to_string(), tool_calls: None, tool_call_id: None, } } fn make_tool_def(name: &str) -> ToolDefinition { ToolDefinition { kind: "function".to_string(), function: ToolFunctionDefinition { name: name.to_string(), description: format!("{name} description"), parameters: json!({"type": "object", "properties": {}}), }, } } // ── convert_tools ──────────────────────────────────────────────────────── #[test] fn test_convert_tools_empty() { let result = AnthropicProvider::convert_tools(&[]); assert!(result.is_empty()); } #[test] fn test_convert_tools_single() { let tool = make_tool_def("search_files"); let result = AnthropicProvider::convert_tools(&[tool]); assert_eq!(result.len(), 1); assert_eq!(result[0].name, "search_files"); assert_eq!(result[0].description, "search_files description"); assert_eq!( result[0].input_schema, json!({"type": "object", "properties": {}}) ); } #[test] fn test_convert_tools_multiple() { let tools = vec![make_tool_def("read_file"), make_tool_def("write_file")]; let result = AnthropicProvider::convert_tools(&tools); assert_eq!(result.len(), 2); assert_eq!(result[0].name, "read_file"); assert_eq!(result[1].name, "write_file"); } // ── convert_messages ───────────────────────────────────────────────────── #[test] fn test_convert_messages_user() { let msgs = vec![user_msg("Hello")]; let result = AnthropicProvider::convert_messages(&msgs); assert_eq!(result.len(), 1); assert_eq!(result[0].role, "user"); match &result[0].content { AnthropicContent::Text(t) => assert_eq!(t, "Hello"), _ => panic!("Expected text content"), } } #[test] fn test_convert_messages_system_skipped() { let msgs = vec![system_msg("You are helpful"), user_msg("Hi")]; let result = AnthropicProvider::convert_messages(&msgs); assert_eq!(result.len(), 1); assert_eq!(result[0].role, "user"); } #[test] fn test_convert_messages_assistant_text() { let msgs = vec![assistant_msg("I can help with that")]; let result = AnthropicProvider::convert_messages(&msgs); assert_eq!(result.len(), 1); assert_eq!(result[0].role, "assistant"); match &result[0].content { AnthropicContent::Text(t) => assert_eq!(t, "I can help with that"), _ => panic!("Expected text content"), } } #[test] fn test_convert_messages_assistant_with_tool_calls_no_content() { let msgs = vec![Message { role: Role::Assistant, content: String::new(), tool_calls: Some(vec![ToolCall { id: Some("toolu_abc".to_string()), kind: "function".to_string(), function: FunctionCall { name: "search_files".to_string(), arguments: r#"{"pattern": "*.rs"}"#.to_string(), }, }]), tool_call_id: None, }]; let result = AnthropicProvider::convert_messages(&msgs); assert_eq!(result.len(), 1); assert_eq!(result[0].role, "assistant"); match &result[0].content { AnthropicContent::Blocks(blocks) => { assert_eq!(blocks.len(), 1); match &blocks[0] { AnthropicContentBlock::ToolUse { id, name, .. } => { assert_eq!(id, "toolu_abc"); assert_eq!(name, "search_files"); } _ => panic!("Expected ToolUse block"), } } _ => panic!("Expected blocks content"), } } #[test] fn test_convert_messages_assistant_with_tool_calls_and_content() { let msgs = vec![Message { role: Role::Assistant, content: "Let me search for that".to_string(), tool_calls: Some(vec![ToolCall { id: Some("toolu_xyz".to_string()), kind: "function".to_string(), function: FunctionCall { name: "read_file".to_string(), arguments: r#"{"path": "main.rs"}"#.to_string(), }, }]), tool_call_id: None, }]; let result = AnthropicProvider::convert_messages(&msgs); assert_eq!(result.len(), 1); match &result[0].content { AnthropicContent::Blocks(blocks) => { assert_eq!(blocks.len(), 2); match &blocks[0] { AnthropicContentBlock::Text { text } => { assert_eq!(text, "Let me search for that"); } _ => panic!("Expected Text block first"), } match &blocks[1] { AnthropicContentBlock::ToolUse { id, name, .. } => { assert_eq!(id, "toolu_xyz"); assert_eq!(name, "read_file"); } _ => panic!("Expected ToolUse block second"), } } _ => panic!("Expected blocks content"), } } #[test] fn test_convert_messages_assistant_tool_call_invalid_json_args() { // Invalid JSON args fall back to {} let msgs = vec![Message { role: Role::Assistant, content: String::new(), tool_calls: Some(vec![ToolCall { id: None, kind: "function".to_string(), function: FunctionCall { name: "my_tool".to_string(), arguments: "not valid json".to_string(), }, }]), tool_call_id: None, }]; let result = AnthropicProvider::convert_messages(&msgs); match &result[0].content { AnthropicContent::Blocks(blocks) => match &blocks[0] { AnthropicContentBlock::ToolUse { input, .. } => { assert_eq!(*input, json!({})); } _ => panic!("Expected ToolUse block"), }, _ => panic!("Expected blocks"), } } #[test] fn test_convert_messages_assistant_tool_call_no_id_generates_uuid() { let msgs = vec![Message { role: Role::Assistant, content: String::new(), tool_calls: Some(vec![ToolCall { id: None, // no id provided kind: "function".to_string(), function: FunctionCall { name: "my_tool".to_string(), arguments: "{}".to_string(), }, }]), tool_call_id: None, }]; let result = AnthropicProvider::convert_messages(&msgs); match &result[0].content { AnthropicContent::Blocks(blocks) => match &blocks[0] { AnthropicContentBlock::ToolUse { id, .. } => { assert!(!id.is_empty(), "Should have generated a UUID"); } _ => panic!("Expected ToolUse block"), }, _ => panic!("Expected blocks"), } } #[test] fn test_convert_messages_tool_role() { let msgs = vec![Message { role: Role::Tool, content: "file content here".to_string(), tool_calls: None, tool_call_id: Some("toolu_123".to_string()), }]; let result = AnthropicProvider::convert_messages(&msgs); assert_eq!(result.len(), 1); assert_eq!(result[0].role, "user"); match &result[0].content { AnthropicContent::Blocks(blocks) => { assert_eq!(blocks.len(), 1); match &blocks[0] { AnthropicContentBlock::ToolResult { tool_use_id, content, } => { assert_eq!(tool_use_id, "toolu_123"); assert_eq!(content, "file content here"); } _ => panic!("Expected ToolResult block"), } } _ => panic!("Expected blocks content"), } } #[test] fn test_convert_messages_tool_role_no_id_defaults_empty() { let msgs = vec![Message { role: Role::Tool, content: "result".to_string(), tool_calls: None, tool_call_id: None, }]; let result = AnthropicProvider::convert_messages(&msgs); match &result[0].content { AnthropicContent::Blocks(blocks) => match &blocks[0] { AnthropicContentBlock::ToolResult { tool_use_id, .. } => { assert_eq!(tool_use_id, ""); } _ => panic!("Expected ToolResult block"), }, _ => panic!("Expected blocks"), } } #[test] fn test_convert_messages_mixed_roles() { let msgs = vec![ system_msg("Be helpful"), user_msg("What is the time?"), assistant_msg("I can check that."), ]; let result = AnthropicProvider::convert_messages(&msgs); // System is skipped assert_eq!(result.len(), 2); assert_eq!(result[0].role, "user"); assert_eq!(result[1].role, "assistant"); } // ── extract_system_prompt ───────────────────────────────────────────────── #[test] fn test_extract_system_prompt_no_messages() { let msgs: Vec = vec![]; let prompt = AnthropicProvider::extract_system_prompt(&msgs); assert!(prompt.is_empty()); } #[test] fn test_extract_system_prompt_no_system_messages() { let msgs = vec![user_msg("Hello"), assistant_msg("Hi there")]; let prompt = AnthropicProvider::extract_system_prompt(&msgs); assert!(prompt.is_empty()); } #[test] fn test_extract_system_prompt_single() { let msgs = vec![system_msg("You are a helpful assistant"), user_msg("Hi")]; let prompt = AnthropicProvider::extract_system_prompt(&msgs); assert_eq!(prompt, "You are a helpful assistant"); } #[test] fn test_extract_system_prompt_multiple_joined() { let msgs = vec![ system_msg("First instruction"), system_msg("Second instruction"), user_msg("Hello"), ]; let prompt = AnthropicProvider::extract_system_prompt(&msgs); assert_eq!(prompt, "First instruction\n\nSecond instruction"); } // ── chat_stream (HTTP mocked) ───────────────────────────────────────────── #[tokio::test] async fn test_chat_stream_text_response() { let mut server = mockito::Server::new_async().await; let delta1 = json!({ "type": "content_block_delta", "delta": {"type": "text_delta", "text": "Hello"} }); let delta2 = json!({ "type": "content_block_delta", "delta": {"type": "text_delta", "text": " world"} }); let body = format!("data: {delta1}\ndata: {delta2}\ndata: [DONE]\n"); let _m = server .mock("POST", "/v1/messages") .with_status(200) .with_header("content-type", "text/event-stream") .with_body(body) .create_async() .await; let provider = AnthropicProvider::new_with_url( "test-key".to_string(), format!("{}/v1/messages", server.url()), ); let (_tx, mut cancel_rx) = tokio::sync::watch::channel(false); let mut tokens = Vec::::new(); let result = provider .chat_stream( "claude-3-5-sonnet-20241022", &[user_msg("Hello")], &[], &mut cancel_rx, |t| tokens.push(t.to_string()), |_| {}, ) .await; assert!(result.is_ok()); let response = result.unwrap(); assert_eq!(response.content, Some("Hello world".to_string())); assert!(response.tool_calls.is_none()); assert_eq!(tokens, vec!["Hello", " world"]); } #[tokio::test] async fn test_chat_stream_error_response() { let mut server = mockito::Server::new_async().await; let _m = server .mock("POST", "/v1/messages") .with_status(401) .with_body(r#"{"error":{"type":"authentication_error","message":"Invalid API key"}}"#) .create_async() .await; let provider = AnthropicProvider::new_with_url( "bad-key".to_string(), format!("{}/v1/messages", server.url()), ); let (_tx, mut cancel_rx) = tokio::sync::watch::channel(false); let result = provider .chat_stream( "claude-3-5-sonnet-20241022", &[user_msg("Hello")], &[], &mut cancel_rx, |_| {}, |_| {}, ) .await; assert!(result.is_err()); assert!(result.unwrap_err().contains("401")); } #[tokio::test] async fn test_chat_stream_tool_use_response() { let mut server = mockito::Server::new_async().await; let start_event = json!({ "type": "content_block_start", "content_block": {"type": "tool_use", "id": "toolu_abc", "name": "search_files"} }); let delta_event = json!({ "type": "content_block_delta", "delta": {"type": "input_json_delta", "partial_json": "{}"} }); let stop_event = json!({"type": "content_block_stop"}); let body = format!( "data: {start_event}\ndata: {delta_event}\ndata: {stop_event}\ndata: [DONE]\n" ); let _m = server .mock("POST", "/v1/messages") .with_status(200) .with_header("content-type", "text/event-stream") .with_body(body) .create_async() .await; let provider = AnthropicProvider::new_with_url( "test-key".to_string(), format!("{}/v1/messages", server.url()), ); let (_tx, mut cancel_rx) = tokio::sync::watch::channel(false); let mut activities = Vec::::new(); let result = provider .chat_stream( "claude-3-5-sonnet-20241022", &[user_msg("Find Rust files")], &[make_tool_def("search_files")], &mut cancel_rx, |_| {}, |a| activities.push(a.to_string()), ) .await; assert!(result.is_ok()); let response = result.unwrap(); assert!(response.content.is_none()); let tool_calls = response.tool_calls.expect("Expected tool calls"); assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls[0].id, Some("toolu_abc".to_string())); assert_eq!(tool_calls[0].function.name, "search_files"); assert_eq!(activities, vec!["search_files"]); } #[tokio::test] async fn test_chat_stream_includes_system_prompt() { let mut server = mockito::Server::new_async().await; let delta = json!({ "type": "content_block_delta", "delta": {"type": "text_delta", "text": "ok"} }); let body = format!("data: {delta}\ndata: [DONE]\n"); let _m = server .mock("POST", "/v1/messages") .with_status(200) .with_header("content-type", "text/event-stream") .with_body(body) .create_async() .await; let provider = AnthropicProvider::new_with_url( "test-key".to_string(), format!("{}/v1/messages", server.url()), ); let (_tx, mut cancel_rx) = tokio::sync::watch::channel(false); let messages = vec![system_msg("Be concise"), user_msg("Hello")]; let result = provider .chat_stream( "claude-3-5-sonnet-20241022", &messages, &[], &mut cancel_rx, |_| {}, |_| {}, ) .await; assert!(result.is_ok()); assert_eq!(result.unwrap().content, Some("ok".to_string())); } #[tokio::test] async fn test_chat_stream_empty_response_gives_none_content() { let mut server = mockito::Server::new_async().await; let _m = server .mock("POST", "/v1/messages") .with_status(200) .with_header("content-type", "text/event-stream") .with_body("data: [DONE]\n") .create_async() .await; let provider = AnthropicProvider::new_with_url( "test-key".to_string(), format!("{}/v1/messages", server.url()), ); let (_tx, mut cancel_rx) = tokio::sync::watch::channel(false); let result = provider .chat_stream( "claude-3-5-sonnet-20241022", &[user_msg("Hello")], &[], &mut cancel_rx, |_| {}, |_| {}, ) .await; assert!(result.is_ok()); let response = result.unwrap(); assert!(response.content.is_none()); assert!(response.tool_calls.is_none()); } }