feat: agent brain (ollama) and chat ui

This commit is contained in:
Dave
2025-12-24 17:17:35 +00:00
parent 76e03bc1a2
commit d9cd16601b
18 changed files with 1712 additions and 14 deletions

338
src-tauri/Cargo.lock generated
View File

@@ -473,8 +473,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2"
dependencies = [
"iana-time-zone",
"js-sys",
"num-traits",
"serde",
"wasm-bindgen",
"windows-link 0.2.1",
]
@@ -513,6 +515,16 @@ dependencies = [
"version_check",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation"
version = "0.10.1"
@@ -536,9 +548,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1"
dependencies = [
"bitflags 2.10.0",
"core-foundation",
"core-foundation 0.10.1",
"core-graphics-types",
"foreign-types",
"foreign-types 0.5.0",
"libc",
]
@@ -549,7 +561,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb"
dependencies = [
"bitflags 2.10.0",
"core-foundation",
"core-foundation 0.10.1",
"libc",
]
@@ -864,6 +876,15 @@ version = "1.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ef6b89e5b37196644d8796de5268852ff179b44e96276cf4290264843743bb7"
[[package]]
name = "encoding_rs"
version = "0.8.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3"
dependencies = [
"cfg-if",
]
[[package]]
name = "endi"
version = "1.1.1"
@@ -986,6 +1007,15 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared 0.1.1",
]
[[package]]
name = "foreign-types"
version = "0.5.0"
@@ -993,7 +1023,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965"
dependencies = [
"foreign-types-macros",
"foreign-types-shared",
"foreign-types-shared 0.3.1",
]
[[package]]
@@ -1007,6 +1037,12 @@ dependencies = [
"syn 2.0.111",
]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]]
name = "foreign-types-shared"
version = "0.3.1"
@@ -1039,6 +1075,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
@@ -1430,6 +1467,25 @@ dependencies = [
"syn 2.0.111",
]
[[package]]
name = "h2"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386"
dependencies = [
"atomic-waker",
"bytes",
"fnv",
"futures-core",
"futures-sink",
"http",
"indexmap 2.12.1",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "hashbrown"
version = "0.12.3"
@@ -1527,6 +1583,7 @@ dependencies = [
"bytes",
"futures-channel",
"futures-core",
"h2",
"http",
"http-body",
"httparse",
@@ -1538,6 +1595,38 @@ dependencies = [
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.27.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58"
dependencies = [
"http",
"hyper",
"hyper-util",
"rustls",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]]
name = "hyper-tls"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes",
"http-body-util",
"hyper",
"hyper-util",
"native-tls",
"tokio",
"tokio-native-tls",
"tower-service",
]
[[package]]
name = "hyper-util"
version = "0.1.19"
@@ -1557,9 +1646,11 @@ dependencies = [
"percent-encoding",
"pin-project-lite",
"socket2",
"system-configuration",
"tokio",
"tower-service",
"tracing",
"windows-registry",
]
[[package]]
@@ -1965,13 +2056,16 @@ checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
name = "living-spec-standalone"
version = "0.1.0"
dependencies = [
"chrono",
"ignore",
"reqwest",
"serde",
"serde_json",
"tauri",
"tauri-build",
"tauri-plugin-dialog",
"tauri-plugin-opener",
"uuid",
"walkdir",
]
@@ -2090,6 +2184,23 @@ dependencies = [
"windows-sys 0.60.2",
]
[[package]]
name = "native-tls"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "ndk"
version = "0.9.0"
@@ -2412,6 +2523,50 @@ dependencies = [
"pathdiff",
]
[[package]]
name = "openssl"
version = "0.10.75"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328"
dependencies = [
"bitflags 2.10.0",
"cfg-if",
"foreign-types 0.3.2",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
]
[[package]]
name = "openssl-probe"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-sys"
version = "0.9.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "option-ext"
version = "0.2.0"
@@ -3021,22 +3176,31 @@ checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147"
dependencies = [
"base64 0.22.1",
"bytes",
"encoding_rs",
"futures-channel",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-tls",
"hyper-util",
"js-sys",
"log",
"mime",
"native-tls",
"percent-encoding",
"pin-project-lite",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-native-tls",
"tokio-util",
"tower",
"tower-http",
@@ -3073,6 +3237,20 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "ring"
version = "0.17.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7"
dependencies = [
"cc",
"cfg-if",
"getrandom 0.2.16",
"libc",
"untrusted",
"windows-sys 0.52.0",
]
[[package]]
name = "rustc_version"
version = "0.4.1"
@@ -3095,6 +3273,39 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "rustls"
version = "0.23.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "533f54bc6a7d4f647e46ad909549eda97bf5afc1585190ef692b4286b198bd8f"
dependencies = [
"once_cell",
"rustls-pki-types",
"rustls-webpki",
"subtle",
"zeroize",
]
[[package]]
name = "rustls-pki-types"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282"
dependencies = [
"zeroize",
]
[[package]]
name = "rustls-webpki"
version = "0.103.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52"
dependencies = [
"ring",
"rustls-pki-types",
"untrusted",
]
[[package]]
name = "rustversion"
version = "1.0.22"
@@ -3116,6 +3327,15 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "schannel"
version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1"
dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "schemars"
version = "0.8.22"
@@ -3179,6 +3399,29 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "security-framework"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags 2.10.0",
"core-foundation 0.9.4",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "selectors"
version = "0.24.0"
@@ -3534,6 +3777,12 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "subtle"
version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
[[package]]
name = "swift-rs"
version = "1.0.7"
@@ -3587,6 +3836,27 @@ dependencies = [
"syn 2.0.111",
]
[[package]]
name = "system-configuration"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b"
dependencies = [
"bitflags 2.10.0",
"core-foundation 0.9.4",
"system-configuration-sys",
]
[[package]]
name = "system-configuration-sys"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "system-deps"
version = "6.2.2"
@@ -3608,7 +3878,7 @@ checksum = "f3a753bdc39c07b192151523a3f77cd0394aa75413802c883a0f6f6a0e5ee2e7"
dependencies = [
"bitflags 2.10.0",
"block2",
"core-foundation",
"core-foundation 0.10.1",
"core-graphics",
"crossbeam-channel",
"dispatch",
@@ -4072,6 +4342,26 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61"
dependencies = [
"rustls",
"tokio",
]
[[package]]
name = "tokio-util"
version = "0.7.17"
@@ -4361,6 +4651,12 @@ version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "untrusted"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "url"
version = "2.5.7"
@@ -4409,6 +4705,12 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "version-compare"
version = "0.2.1"
@@ -4851,6 +5153,17 @@ dependencies = [
"windows-link 0.1.3",
]
[[package]]
name = "windows-registry"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720"
dependencies = [
"windows-link 0.2.1",
"windows-result 0.4.1",
"windows-strings 0.5.1",
]
[[package]]
name = "windows-result"
version = "0.3.4"
@@ -4896,6 +5209,15 @@ dependencies = [
"windows-targets 0.42.2",
]
[[package]]
name = "windows-sys"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.59.0"
@@ -5359,6 +5681,12 @@ dependencies = [
"synstructure",
]
[[package]]
name = "zeroize"
version = "1.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0"
[[package]]
name = "zerotrie"
version = "0.2.3"

View File

@@ -25,4 +25,7 @@ serde_json = "1"
tauri-plugin-dialog = "2.4.2"
ignore = "0.4.25"
walkdir = "2.5.0"
reqwest = { version = "0.12.28", features = ["json", "blocking"] }
uuid = { version = "1.19.0", features = ["v4", "serde"] }
chrono = { version = "0.4.42", features = ["serde"] }

View File

@@ -0,0 +1,243 @@
use crate::commands::{fs, search, shell};
use crate::llm::ollama::OllamaProvider;
use crate::llm::types::{
Message, ModelProvider, Role, ToolCall, ToolDefinition, ToolFunctionDefinition,
};
use crate::state::SessionState;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tauri::State;
#[derive(Deserialize)]
pub struct ProviderConfig {
pub provider: String, // "ollama"
pub model: String,
pub base_url: Option<String>,
}
const MAX_TURNS: usize = 10;
#[tauri::command]
pub async fn chat(
messages: Vec<Message>,
config: ProviderConfig,
state: State<'_, SessionState>,
) -> Result<Vec<Message>, String> {
// 1. Setup Provider
let provider: Box<dyn ModelProvider> = match config.provider.as_str() {
"ollama" => Box::new(OllamaProvider::new(
config
.base_url
.unwrap_or_else(|| "http://localhost:11434".to_string()),
)),
_ => return Err(format!("Unsupported provider: {}", config.provider)),
};
// 2. Define Tools
let tools = get_tool_definitions();
// 3. Agent Loop
let mut current_history = messages.clone();
let mut new_messages: Vec<Message> = Vec::new();
let mut turn_count = 0;
loop {
if turn_count >= MAX_TURNS {
return Err("Max conversation turns reached.".to_string());
}
turn_count += 1;
// Call LLM
let response = provider
.chat(&config.model, &current_history, &tools)
.map_err(|e| format!("LLM Error: {}", e))?;
// Process Response
if let Some(tool_calls) = response.tool_calls {
// The Assistant wants to run tools
let assistant_msg = Message {
role: Role::Assistant,
content: response.content.unwrap_or_default(),
tool_calls: Some(tool_calls.clone()),
tool_call_id: None,
};
current_history.push(assistant_msg.clone());
new_messages.push(assistant_msg);
// Execute Tools
for call in tool_calls {
let output = execute_tool(&call, &state).await;
let tool_msg = Message {
role: Role::Tool,
content: output,
tool_calls: None,
// For Ollama/Simple flow, we just append.
// For OpenAI strict, this needs to match call.id.
tool_call_id: call.id,
};
current_history.push(tool_msg.clone());
new_messages.push(tool_msg);
}
} else {
// Final text response
let assistant_msg = Message {
role: Role::Assistant,
content: response.content.unwrap_or_default(),
tool_calls: None,
tool_call_id: None,
};
// We don't push to current_history needed for next loop, because we are done.
new_messages.push(assistant_msg);
break;
}
}
Ok(new_messages)
}
async fn execute_tool(call: &ToolCall, state: &State<'_, SessionState>) -> String {
let name = call.function.name.as_str();
// Parse arguments. They come as a JSON string from the LLM abstraction.
let args: serde_json::Value = match serde_json::from_str(&call.function.arguments) {
Ok(v) => v,
Err(e) => return format!("Error parsing arguments: {}", e),
};
match name {
"read_file" => {
let path = args["path"].as_str().unwrap_or("").to_string();
match fs::read_file(path, state.clone()).await {
Ok(content) => content,
Err(e) => format!("Error: {}", e),
}
}
"write_file" => {
let path = args["path"].as_str().unwrap_or("").to_string();
let content = args["content"].as_str().unwrap_or("").to_string();
match fs::write_file(path, content, state.clone()).await {
Ok(_) => "File written successfully.".to_string(),
Err(e) => format!("Error: {}", e),
}
}
"list_directory" => {
let path = args["path"].as_str().unwrap_or("").to_string();
match fs::list_directory(path, state.clone()).await {
Ok(entries) => serde_json::to_string(&entries).unwrap_or_default(),
Err(e) => format!("Error: {}", e),
}
}
"search_files" => {
let query = args["query"].as_str().unwrap_or("").to_string();
match search::search_files(query, state.clone()).await {
Ok(results) => serde_json::to_string(&results).unwrap_or_default(),
Err(e) => format!("Error: {}", e),
}
}
"exec_shell" => {
let command = args["command"].as_str().unwrap_or("").to_string();
let args_vec: Vec<String> = args["args"]
.as_array()
.map(|arr| {
arr.iter()
.map(|v| v.as_str().unwrap_or("").to_string())
.collect()
})
.unwrap_or_default();
match shell::exec_shell(command, args_vec, state.clone()).await {
Ok(output) => serde_json::to_string(&output).unwrap_or_default(),
Err(e) => format!("Error: {}", e),
}
}
_ => format!("Unknown tool: {}", name),
}
}
fn get_tool_definitions() -> Vec<ToolDefinition> {
vec![
ToolDefinition {
kind: "function".to_string(),
function: ToolFunctionDefinition {
name: "read_file".to_string(),
description: "Reads the content of a file in the project.".to_string(),
parameters: json!({
"type": "object",
"properties": {
"path": { "type": "string", "description": "Relative path to the file" }
},
"required": ["path"]
}),
},
},
ToolDefinition {
kind: "function".to_string(),
function: ToolFunctionDefinition {
name: "write_file".to_string(),
description: "Writes content to a file. Overwrites if exists.".to_string(),
parameters: json!({
"type": "object",
"properties": {
"path": { "type": "string", "description": "Relative path to the file" },
"content": { "type": "string", "description": "The full content to write" }
},
"required": ["path", "content"]
}),
},
},
ToolDefinition {
kind: "function".to_string(),
function: ToolFunctionDefinition {
name: "list_directory".to_string(),
description: "Lists files and directories at a path.".to_string(),
parameters: json!({
"type": "object",
"properties": {
"path": { "type": "string", "description": "Relative path to list (use '.' for root)" }
},
"required": ["path"]
}),
},
},
ToolDefinition {
kind: "function".to_string(),
function: ToolFunctionDefinition {
name: "search_files".to_string(),
description: "Searches for text content across all files in the project."
.to_string(),
parameters: json!({
"type": "object",
"properties": {
"query": { "type": "string", "description": "The string to search for" }
},
"required": ["query"]
}),
},
},
ToolDefinition {
kind: "function".to_string(),
function: ToolFunctionDefinition {
name: "exec_shell".to_string(),
description: "Executes a shell command in the project root.".to_string(),
parameters: json!({
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The command to run (e.g., 'git', 'cargo', 'ls')"
},
"args": {
"type": "array",
"items": { "type": "string" },
"description": "Arguments for the command"
}
},
"required": ["command", "args"]
}),
},
},
]
}

View File

@@ -1,3 +1,4 @@
pub mod chat;
pub mod fs;
pub mod search;
pub mod shell;

View File

@@ -1,4 +1,5 @@
mod commands;
mod llm;
mod state;
use state::SessionState;
@@ -15,7 +16,8 @@ pub fn run() {
commands::fs::write_file,
commands::fs::list_directory,
commands::search::search_files,
commands::shell::exec_shell
commands::shell::exec_shell,
commands::chat::chat
])
.run(tauri::generate_context!())
.expect("error while running tauri application");

2
src-tauri/src/llm/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod ollama;
pub mod types;

170
src-tauri/src/llm/ollama.rs Normal file
View File

@@ -0,0 +1,170 @@
use crate::llm::types::{
CompletionResponse, FunctionCall, Message, ModelProvider, Role, ToolCall, ToolDefinition,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub struct OllamaProvider {
base_url: String,
}
impl OllamaProvider {
pub fn new(base_url: String) -> Self {
Self { base_url }
}
}
// --- Request Types ---
#[derive(Serialize)]
struct OllamaRequest<'a> {
model: &'a str,
messages: Vec<OllamaRequestMessage>,
stream: bool,
#[serde(skip_serializing_if = "is_empty_tools")]
tools: &'a [ToolDefinition],
}
fn is_empty_tools(tools: &&[ToolDefinition]) -> bool {
tools.is_empty()
}
#[derive(Serialize)]
struct OllamaRequestMessage {
role: Role,
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<OllamaRequestToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_call_id: Option<String>,
}
#[derive(Serialize)]
struct OllamaRequestToolCall {
function: OllamaRequestFunctionCall,
#[serde(rename = "type")]
kind: String,
}
#[derive(Serialize)]
struct OllamaRequestFunctionCall {
name: String,
arguments: Value,
}
// --- Response Types ---
#[derive(Deserialize)]
struct OllamaResponse {
message: OllamaResponseMessage,
}
#[derive(Deserialize)]
struct OllamaResponseMessage {
content: String,
tool_calls: Option<Vec<OllamaResponseToolCall>>,
}
#[derive(Deserialize)]
struct OllamaResponseToolCall {
function: OllamaResponseFunctionCall,
}
#[derive(Deserialize)]
struct OllamaResponseFunctionCall {
name: String,
arguments: Value, // Ollama returns Object, we convert to String for internal storage
}
impl ModelProvider for OllamaProvider {
fn chat(
&self,
model: &str,
messages: &[Message],
tools: &[ToolDefinition],
) -> Result<CompletionResponse, String> {
let client = reqwest::blocking::Client::new();
let url = format!("{}/api/chat", self.base_url.trim_end_matches('/'));
// Convert domain Messages to Ollama Messages (handling String -> Object args mismatch)
let ollama_messages: Vec<OllamaRequestMessage> = messages
.iter()
.map(|m| {
let tool_calls = m.tool_calls.as_ref().map(|calls| {
calls
.iter()
.map(|tc| {
// Try to parse string args as JSON, fallback to string value if fails
let args_val: Value = serde_json::from_str(&tc.function.arguments)
.unwrap_or(Value::String(tc.function.arguments.clone()));
OllamaRequestToolCall {
kind: tc.kind.clone(),
function: OllamaRequestFunctionCall {
name: tc.function.name.clone(),
arguments: args_val,
},
}
})
.collect()
});
OllamaRequestMessage {
role: m.role.clone(),
content: m.content.clone(),
tool_calls,
tool_call_id: m.tool_call_id.clone(),
}
})
.collect();
let request_body = OllamaRequest {
model,
messages: ollama_messages,
stream: false,
tools,
};
let res = client
.post(&url)
.json(&request_body)
.send()
.map_err(|e| format!("Request failed: {}", e))?;
if !res.status().is_success() {
let status = res.status();
let text = res.text().unwrap_or_default();
return Err(format!("Ollama API error {}: {}", status, text));
}
let response_body: OllamaResponse = res
.json()
.map_err(|e| format!("Failed to parse response: {}", e))?;
// Convert Response back to Domain types
let content = if response_body.message.content.is_empty() {
None
} else {
Some(response_body.message.content)
};
let tool_calls = response_body.message.tool_calls.map(|calls| {
calls
.into_iter()
.map(|tc| ToolCall {
id: None, // Ollama doesn't typically send IDs
kind: "function".to_string(),
function: FunctionCall {
name: tc.function.name,
arguments: tc.function.arguments.to_string(), // Convert Object -> String
},
})
.collect()
});
Ok(CompletionResponse {
content,
tool_calls,
})
}
}

View File

@@ -0,0 +1,72 @@
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
System,
User,
Assistant,
Tool,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Message {
pub role: Role,
pub content: String,
// For assistant messages that request tool execution
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
// For tool output messages, we need to link back to the call ID
// Note: OpenAI uses 'tool_call_id', Ollama sometimes just relies on sequence.
// We will include it for compatibility.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCall {
// ID is required by OpenAI, optional/generated for Ollama depending on version
pub id: Option<String>,
pub function: FunctionCall,
#[serde(rename = "type")]
pub kind: String, // usually "function"
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FunctionCall {
pub name: String,
pub arguments: String, // JSON string of arguments
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolDefinition {
#[serde(rename = "type")]
pub kind: String, // "function"
pub function: ToolFunctionDefinition,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolFunctionDefinition {
pub name: String,
pub description: String,
pub parameters: serde_json::Value, // JSON Schema object
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
pub content: Option<String>,
pub tool_calls: Option<Vec<ToolCall>>,
}
/// The abstraction for different LLM providers (Ollama, Anthropic, etc.)
pub trait ModelProvider: Send + Sync {
fn chat(
&self,
model: &str,
messages: &[Message],
tools: &[ToolDefinition],
) -> Result<CompletionResponse, String>;
}