use crate::llm::types::{ CompletionResponse, FunctionCall, Message, ModelProvider, Role, ToolCall, ToolDefinition, }; use async_trait::async_trait; use futures::StreamExt; use serde::{Deserialize, Serialize}; use serde_json::Value; pub struct OllamaProvider { base_url: String, } impl OllamaProvider { pub fn new(base_url: String) -> Self { Self { base_url } } pub async fn get_models(base_url: &str) -> Result, String> { let client = reqwest::Client::new(); let url = format!("{}/api/tags", base_url.trim_end_matches('/')); let res = client .get(&url) .send() .await .map_err(|e| format!("Request failed: {}", e))?; if !res.status().is_success() { let status = res.status(); let text = res.text().await.unwrap_or_default(); return Err(format!("Ollama API error {}: {}", status, text)); } let body: OllamaTagsResponse = res .json() .await .map_err(|e| format!("Failed to parse response: {}", e))?; Ok(body.models.into_iter().map(|m| m.name).collect()) } /// Streaming chat that calls `on_token` for each token chunk. pub async fn chat_stream( &self, model: &str, messages: &[Message], tools: &[ToolDefinition], cancel_rx: &mut tokio::sync::watch::Receiver, mut on_token: F, ) -> Result where F: FnMut(&str) + Send, { let client = reqwest::Client::new(); let url = format!("{}/api/chat", self.base_url.trim_end_matches('/')); let ollama_messages: Vec = messages .iter() .map(|m| { let tool_calls = m.tool_calls.as_ref().map(|calls| { calls .iter() .map(|tc| { let args_val: Value = serde_json::from_str(&tc.function.arguments) .unwrap_or(Value::String(tc.function.arguments.clone())); OllamaRequestToolCall { kind: tc.kind.clone(), function: OllamaRequestFunctionCall { name: tc.function.name.clone(), arguments: args_val, }, } }) .collect() }); OllamaRequestMessage { role: m.role.clone(), content: m.content.clone(), tool_calls, tool_call_id: m.tool_call_id.clone(), } }) .collect(); let request_body = OllamaRequest { model, messages: ollama_messages, stream: true, tools, }; let res = client .post(&url) .json(&request_body) .send() .await .map_err(|e| format!("Request failed: {}", e))?; if !res.status().is_success() { let status = res.status(); let text = res.text().await.unwrap_or_default(); return Err(format!("Ollama API error {}: {}", status, text)); } let mut stream = res.bytes_stream(); let mut buffer = String::new(); let mut accumulated_content = String::new(); let mut final_tool_calls: Option> = None; loop { if *cancel_rx.borrow() { return Err("Chat cancelled by user".to_string()); } let chunk_result = tokio::select! { chunk = stream.next() => { match chunk { Some(c) => c, None => break, } } _ = cancel_rx.changed() => { if *cancel_rx.borrow() { return Err("Chat cancelled by user".to_string()); } else { continue; } } }; let chunk = chunk_result.map_err(|e| format!("Stream error: {}", e))?; buffer.push_str(&String::from_utf8_lossy(&chunk)); while let Some(newline_pos) = buffer.find('\n') { let line = buffer[..newline_pos].trim().to_string(); buffer = buffer[newline_pos + 1..].to_string(); if line.is_empty() { continue; } let stream_msg: OllamaStreamResponse = serde_json::from_str(&line).map_err(|e| format!("JSON parse error: {}", e))?; if !stream_msg.message.content.is_empty() { accumulated_content.push_str(&stream_msg.message.content); on_token(&stream_msg.message.content); } if let Some(tool_calls) = stream_msg.message.tool_calls { final_tool_calls = Some( tool_calls .into_iter() .map(|tc| ToolCall { id: None, kind: "function".to_string(), function: FunctionCall { name: tc.function.name, arguments: tc.function.arguments.to_string(), }, }) .collect(), ); } if stream_msg.done { break; } } } Ok(CompletionResponse { content: if accumulated_content.is_empty() { None } else { Some(accumulated_content) }, tool_calls: final_tool_calls, }) } } #[derive(Deserialize)] struct OllamaTagsResponse { models: Vec, } #[derive(Deserialize)] struct OllamaModelTag { name: String, } #[derive(Serialize)] struct OllamaRequest<'a> { model: &'a str, messages: Vec, stream: bool, #[serde(skip_serializing_if = "is_empty_tools")] tools: &'a [ToolDefinition], } fn is_empty_tools(tools: &&[ToolDefinition]) -> bool { tools.is_empty() } #[derive(Serialize)] struct OllamaRequestMessage { role: Role, content: String, #[serde(skip_serializing_if = "Option::is_none")] tool_calls: Option>, #[serde(skip_serializing_if = "Option::is_none")] tool_call_id: Option, } #[derive(Serialize)] struct OllamaRequestToolCall { function: OllamaRequestFunctionCall, #[serde(rename = "type")] kind: String, } #[derive(Serialize)] struct OllamaRequestFunctionCall { name: String, arguments: Value, } #[derive(Deserialize)] struct OllamaStreamResponse { message: OllamaStreamMessage, done: bool, } #[derive(Deserialize)] struct OllamaStreamMessage { #[serde(default)] content: String, #[serde(default)] tool_calls: Option>, } #[derive(Deserialize)] struct OllamaResponseToolCall { function: OllamaResponseFunctionCall, } #[derive(Deserialize)] struct OllamaResponseFunctionCall { name: String, arguments: Value, } #[async_trait] impl ModelProvider for OllamaProvider { async fn chat( &self, _model: &str, _messages: &[Message], _tools: &[ToolDefinition], ) -> Result { Err("Non-streaming Ollama chat not implemented for server".to_string()) } }