use crate::llm::types::{ CompletionResponse, FunctionCall, Message, ModelProvider, Role, ToolCall, ToolDefinition, }; use async_trait::async_trait; use futures::StreamExt; use serde::{Deserialize, Serialize}; use serde_json::Value; use tauri::{AppHandle, Emitter}; pub struct OllamaProvider { base_url: String, } impl OllamaProvider { pub fn new(base_url: String) -> Self { Self { base_url } } pub async fn get_models(base_url: &str) -> Result, String> { let client = reqwest::Client::new(); let url = format!("{}/api/tags", base_url.trim_end_matches('/')); let res = client .get(&url) .send() .await .map_err(|e| format!("Request failed: {}", e))?; if !res.status().is_success() { let status = res.status(); let text = res.text().await.unwrap_or_default(); return Err(format!("Ollama API error {}: {}", status, text)); } let body: OllamaTagsResponse = res .json() .await .map_err(|e| format!("Failed to parse response: {}", e))?; Ok(body.models.into_iter().map(|m| m.name).collect()) } /// Streaming chat that emits tokens via Tauri events pub async fn chat_stream( &self, app: &AppHandle, model: &str, messages: &[Message], tools: &[ToolDefinition], cancel_rx: &mut tokio::sync::watch::Receiver, ) -> Result { let client = reqwest::Client::new(); let url = format!("{}/api/chat", self.base_url.trim_end_matches('/')); // Convert domain Messages to Ollama Messages let ollama_messages: Vec = messages .iter() .map(|m| { let tool_calls = m.tool_calls.as_ref().map(|calls| { calls .iter() .map(|tc| { let args_val: Value = serde_json::from_str(&tc.function.arguments) .unwrap_or(Value::String(tc.function.arguments.clone())); OllamaRequestToolCall { kind: tc.kind.clone(), function: OllamaRequestFunctionCall { name: tc.function.name.clone(), arguments: args_val, }, } }) .collect() }); OllamaRequestMessage { role: m.role.clone(), content: m.content.clone(), tool_calls, tool_call_id: m.tool_call_id.clone(), } }) .collect(); let request_body = OllamaRequest { model, messages: ollama_messages, stream: true, // Enable streaming tools, }; let res = client .post(&url) .json(&request_body) .send() .await .map_err(|e| format!("Request failed: {}", e))?; if !res.status().is_success() { let status = res.status(); let text = res.text().await.unwrap_or_default(); return Err(format!("Ollama API error {}: {}", status, text)); } // Process streaming response let mut stream = res.bytes_stream(); let mut buffer = String::new(); let mut accumulated_content = String::new(); let mut final_tool_calls: Option> = None; loop { // Check for cancellation if *cancel_rx.borrow() { return Err("Chat cancelled by user".to_string()); } let chunk_result = tokio::select! { chunk = stream.next() => { match chunk { Some(c) => c, None => break, } } _ = cancel_rx.changed() => { // changed() fires on any change, check if it's actually true if *cancel_rx.borrow() { return Err("Chat cancelled by user".to_string()); } else { continue; } } }; let chunk = chunk_result.map_err(|e| format!("Stream error: {}", e))?; buffer.push_str(&String::from_utf8_lossy(&chunk)); // Process complete lines (newline-delimited JSON) while let Some(newline_pos) = buffer.find('\n') { let line = buffer[..newline_pos].trim().to_string(); buffer = buffer[newline_pos + 1..].to_string(); if line.is_empty() { continue; } // Parse the streaming response let stream_msg: OllamaStreamResponse = serde_json::from_str(&line).map_err(|e| format!("JSON parse error: {}", e))?; // Emit token if there's content if !stream_msg.message.content.is_empty() { accumulated_content.push_str(&stream_msg.message.content); // Emit chat:token event app.emit("chat:token", &stream_msg.message.content) .map_err(|e| e.to_string())?; } // Check for tool calls if let Some(tool_calls) = stream_msg.message.tool_calls { final_tool_calls = Some( tool_calls .into_iter() .map(|tc| ToolCall { id: None, kind: "function".to_string(), function: FunctionCall { name: tc.function.name, arguments: tc.function.arguments.to_string(), }, }) .collect(), ); } // If done, break if stream_msg.done { break; } } } Ok(CompletionResponse { content: if accumulated_content.is_empty() { None } else { Some(accumulated_content) }, tool_calls: final_tool_calls, }) } } #[derive(Deserialize)] struct OllamaTagsResponse { models: Vec, } #[derive(Deserialize)] struct OllamaModelTag { name: String, } // --- Request Types --- #[derive(Serialize)] struct OllamaRequest<'a> { model: &'a str, messages: Vec, stream: bool, #[serde(skip_serializing_if = "is_empty_tools")] tools: &'a [ToolDefinition], } fn is_empty_tools(tools: &&[ToolDefinition]) -> bool { tools.is_empty() } #[derive(Serialize)] struct OllamaRequestMessage { role: Role, content: String, #[serde(skip_serializing_if = "Option::is_none")] tool_calls: Option>, #[serde(skip_serializing_if = "Option::is_none")] tool_call_id: Option, } #[derive(Serialize)] struct OllamaRequestToolCall { function: OllamaRequestFunctionCall, #[serde(rename = "type")] kind: String, } #[derive(Serialize)] struct OllamaRequestFunctionCall { name: String, arguments: Value, } // --- Response Types --- #[derive(Deserialize)] #[allow(dead_code)] struct OllamaResponse { message: OllamaResponseMessage, } #[derive(Deserialize)] #[allow(dead_code)] struct OllamaResponseMessage { content: String, tool_calls: Option>, } #[derive(Deserialize)] struct OllamaResponseToolCall { function: OllamaResponseFunctionCall, } #[derive(Deserialize)] struct OllamaResponseFunctionCall { name: String, arguments: Value, // Ollama returns Object, we convert to String for internal storage } // --- Streaming Response Types --- #[derive(Deserialize)] struct OllamaStreamResponse { message: OllamaStreamMessage, done: bool, } #[derive(Deserialize)] struct OllamaStreamMessage { #[serde(default)] content: String, #[serde(default)] tool_calls: Option>, } #[async_trait] impl ModelProvider for OllamaProvider { async fn chat( &self, model: &str, messages: &[Message], tools: &[ToolDefinition], ) -> Result { let client = reqwest::Client::new(); let url = format!("{}/api/chat", self.base_url.trim_end_matches('/')); // Convert domain Messages to Ollama Messages (handling String -> Object args mismatch) let ollama_messages: Vec = messages .iter() .map(|m| { let tool_calls = m.tool_calls.as_ref().map(|calls| { calls .iter() .map(|tc| { // Try to parse string args as JSON, fallback to string value if fails let args_val: Value = serde_json::from_str(&tc.function.arguments) .unwrap_or(Value::String(tc.function.arguments.clone())); OllamaRequestToolCall { kind: tc.kind.clone(), function: OllamaRequestFunctionCall { name: tc.function.name.clone(), arguments: args_val, }, } }) .collect() }); OllamaRequestMessage { role: m.role.clone(), content: m.content.clone(), tool_calls, tool_call_id: m.tool_call_id.clone(), } }) .collect(); let request_body = OllamaRequest { model, messages: ollama_messages, stream: false, tools, }; // Debug: Log the request body if let Ok(json_str) = serde_json::to_string_pretty(&request_body) { eprintln!("=== Ollama Request ===\n{}\n===================", json_str); } let res = client .post(&url) .json(&request_body) .send() .await .map_err(|e| format!("Request failed: {}", e))?; if !res.status().is_success() { let status = res.status(); let text = res.text().await.unwrap_or_default(); eprintln!( "=== Ollama Error Response ===\n{}\n========================", text ); return Err(format!("Ollama API error {}: {}", status, text)); } let response_body: OllamaResponse = res .json() .await .map_err(|e| format!("Failed to parse response: {}", e))?; // Convert Response back to Domain types let content = if response_body.message.content.is_empty() { None } else { Some(response_body.message.content) }; let tool_calls = response_body.message.tool_calls.map(|calls| { calls .into_iter() .map(|tc| ToolCall { id: None, // Ollama doesn't typically send IDs kind: "function".to_string(), function: FunctionCall { name: tc.function.name, arguments: tc.function.arguments.to_string(), // Convert Object -> String }, }) .collect() }); Ok(CompletionResponse { content, tool_calls, }) } }