use crate::llm::types::{ CompletionResponse, FunctionCall, Message, Role, ToolCall, ToolDefinition, }; use futures::StreamExt; use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue}; use serde::{Deserialize, Serialize}; use serde_json::json; use tokio::sync::watch::Receiver; const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages"; const ANTHROPIC_VERSION: &str = "2023-06-01"; pub struct AnthropicProvider { api_key: String, client: reqwest::Client, } #[derive(Debug, Serialize, Deserialize)] struct AnthropicMessage { role: String, // "user" or "assistant" content: AnthropicContent, } #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] enum AnthropicContent { Text(String), Blocks(Vec), } #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] enum AnthropicContentBlock { #[serde(rename = "text")] Text { text: String }, #[serde(rename = "tool_use")] ToolUse { id: String, name: String, input: serde_json::Value, }, #[serde(rename = "tool_result")] ToolResult { tool_use_id: String, content: String, }, } #[derive(Debug, Serialize)] struct AnthropicTool { name: String, description: String, input_schema: serde_json::Value, } #[derive(Debug, Deserialize)] struct StreamEvent { #[serde(rename = "type")] event_type: String, #[serde(flatten)] data: serde_json::Value, } impl AnthropicProvider { pub fn new(api_key: String) -> Self { Self { api_key, client: reqwest::Client::new(), } } fn convert_tools(tools: &[ToolDefinition]) -> Vec { tools .iter() .map(|tool| AnthropicTool { name: tool.function.name.clone(), description: tool.function.description.clone(), input_schema: tool.function.parameters.clone(), }) .collect() } fn convert_messages(messages: &[Message]) -> Vec { let mut anthropic_messages: Vec = Vec::new(); for msg in messages { match msg.role { Role::System => { continue; } Role::User => { anthropic_messages.push(AnthropicMessage { role: "user".to_string(), content: AnthropicContent::Text(msg.content.clone()), }); } Role::Assistant => { if let Some(tool_calls) = &msg.tool_calls { let mut blocks = Vec::new(); if !msg.content.is_empty() { blocks.push(AnthropicContentBlock::Text { text: msg.content.clone(), }); } for call in tool_calls { let input: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap_or(json!({})); blocks.push(AnthropicContentBlock::ToolUse { id: call .id .clone() .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), name: call.function.name.clone(), input, }); } anthropic_messages.push(AnthropicMessage { role: "assistant".to_string(), content: AnthropicContent::Blocks(blocks), }); } else { anthropic_messages.push(AnthropicMessage { role: "assistant".to_string(), content: AnthropicContent::Text(msg.content.clone()), }); } } Role::Tool => { let tool_use_id = msg.tool_call_id.clone().unwrap_or_default(); anthropic_messages.push(AnthropicMessage { role: "user".to_string(), content: AnthropicContent::Blocks(vec![ AnthropicContentBlock::ToolResult { tool_use_id, content: msg.content.clone(), }, ]), }); } } } anthropic_messages } fn extract_system_prompt(messages: &[Message]) -> String { messages .iter() .filter(|m| matches!(m.role, Role::System)) .map(|m| m.content.as_str()) .collect::>() .join("\n\n") } pub async fn chat_stream( &self, model: &str, messages: &[Message], tools: &[ToolDefinition], cancel_rx: &mut Receiver, mut on_token: F, ) -> Result where F: FnMut(&str), { let anthropic_messages = Self::convert_messages(messages); let anthropic_tools = Self::convert_tools(tools); let system_prompt = Self::extract_system_prompt(messages); let mut request_body = json!({ "model": model, "max_tokens": 4096, "messages": anthropic_messages, "stream": true, }); if !system_prompt.is_empty() { request_body["system"] = json!(system_prompt); } if !anthropic_tools.is_empty() { request_body["tools"] = json!(anthropic_tools); } let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); headers.insert( "x-api-key", HeaderValue::from_str(&self.api_key).map_err(|e| e.to_string())?, ); headers.insert( "anthropic-version", HeaderValue::from_static(ANTHROPIC_VERSION), ); let response = self .client .post(ANTHROPIC_API_URL) .headers(headers) .json(&request_body) .send() .await .map_err(|e| format!("Failed to send request to Anthropic: {e}"))?; if !response.status().is_success() { let status = response.status(); let error_text = response .text() .await .unwrap_or_else(|_| "Unknown error".to_string()); return Err(format!("Anthropic API error {status}: {error_text}")); } let mut stream = response.bytes_stream(); let mut accumulated_text = String::new(); let mut tool_calls: Vec = Vec::new(); let mut current_tool_use: Option<(String, String, String)> = None; loop { let chunk = tokio::select! { result = stream.next() => { match result { Some(c) => c, None => break, } } _ = cancel_rx.changed() => { if *cancel_rx.borrow() { return Err("Chat cancelled by user".to_string()); } continue; } }; let bytes = chunk.map_err(|e| format!("Stream error: {e}"))?; let text = String::from_utf8_lossy(&bytes); for line in text.lines() { if let Some(json_str) = line.strip_prefix("data: ") { if json_str == "[DONE]" { break; } let event: StreamEvent = match serde_json::from_str(json_str) { Ok(e) => e, Err(_) => continue, }; match event.event_type.as_str() { "content_block_start" => { if let Some(content_block) = event.data.get("content_block") && content_block.get("type") == Some(&json!("tool_use")) { let id = content_block["id"].as_str().unwrap_or("").to_string(); let name = content_block["name"].as_str().unwrap_or("").to_string(); current_tool_use = Some((id, name, String::new())); } } "content_block_delta" => { if let Some(delta) = event.data.get("delta") { if delta.get("type") == Some(&json!("text_delta")) { if let Some(text) = delta.get("text").and_then(|t| t.as_str()) { accumulated_text.push_str(text); on_token(text); } } else if delta.get("type") == Some(&json!("input_json_delta")) && let Some((_, _, input_json)) = &mut current_tool_use && let Some(partial) = delta.get("partial_json").and_then(|p| p.as_str()) { input_json.push_str(partial); } } } "content_block_stop" => { if let Some((id, name, input_json)) = current_tool_use.take() { tool_calls.push(ToolCall { id: Some(id), kind: "function".to_string(), function: FunctionCall { name, arguments: input_json, }, }); } } _ => {} } } } } Ok(CompletionResponse { content: if accumulated_text.is_empty() { None } else { Some(accumulated_text) }, tool_calls: if tool_calls.is_empty() { None } else { Some(tool_calls) }, session_id: None, }) } }