From e4af2d5c08ea468525ce05530ec3ad71043ac463 Mon Sep 17 00:00:00 2001
From: dave
Date: Tue, 28 Apr 2026 19:05:14 +0000
Subject: [PATCH] huskies: merge 803

---
 server/src/agents/runtime/gemini.rs     | 815 ------------------------
 server/src/agents/runtime/gemini/api.rs | 190 ++++++
 server/src/agents/runtime/gemini/mcp.rs | 284 +++++++++
 server/src/agents/runtime/gemini/mod.rs | 379 +++++++++++
 4 files changed, 853 insertions(+), 815 deletions(-)
 delete mode 100644 server/src/agents/runtime/gemini.rs
 create mode 100644 server/src/agents/runtime/gemini/api.rs
 create mode 100644 server/src/agents/runtime/gemini/mcp.rs
 create mode 100644 server/src/agents/runtime/gemini/mod.rs

diff --git a/server/src/agents/runtime/gemini.rs b/server/src/agents/runtime/gemini.rs
deleted file mode 100644
index ad7385af..00000000
--- a/server/src/agents/runtime/gemini.rs
+++ /dev/null
@@ -1,815 +0,0 @@
-//! Gemini runtime — drives Google Gemini API sessions as agent backends.
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::{Arc, Mutex};
-
-use reqwest::Client;
-use serde::{Deserialize, Serialize};
-use serde_json::{Value, json};
-use tokio::sync::broadcast;
-
-use crate::agent_log::AgentLogWriter;
-use crate::slog;
-
-use super::super::{AgentEvent, TokenUsage};
-use super::{AgentRuntime, RuntimeContext, RuntimeResult, RuntimeStatus};
-
-// ── Public runtime struct ────────────────────────────────────────────
-
-/// Agent runtime that drives a Gemini model through the Google AI
-/// `generateContent` REST API.
-///
-/// The runtime:
-/// 1. Fetches MCP tool definitions from huskies' MCP server.
-/// 2. Converts them to Gemini function-calling format.
-/// 3. Sends the agent prompt + tools to the Gemini API.
-/// 4. Executes any requested function calls via MCP `tools/call`.
-/// 5. Loops until the model produces a text-only response or an error.
-/// 6. Tracks token usage from the API response metadata.
-pub struct GeminiRuntime {
-    /// Whether a stop has been requested.
-    cancelled: Arc<AtomicBool>,
-}
-
-impl GeminiRuntime {
-    pub fn new() -> Self {
-        Self {
-            cancelled: Arc::new(AtomicBool::new(false)),
-        }
-    }
-}
-
-impl AgentRuntime for GeminiRuntime {
-    async fn start(
-        &self,
-        ctx: RuntimeContext,
-        tx: broadcast::Sender<AgentEvent>,
-        event_log: Arc<Mutex<Vec<AgentEvent>>>,
-        log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
-    ) -> Result<RuntimeResult, String> {
-        let api_key = std::env::var("GOOGLE_AI_API_KEY").map_err(|_| {
-            "GOOGLE_AI_API_KEY environment variable is not set. \
-             Set it to your Google AI API key to use the Gemini runtime."
-                .to_string()
-        })?;
-
-        let model = if ctx.command.starts_with("gemini") {
-            // The pool puts the model into `command` for non-CLI runtimes,
-            // but also check args for a --model flag.
-            ctx.command.clone()
-        } else {
-            // Fall back to args: look for --model
-            ctx.args
-                .iter()
-                .position(|a| a == "--model")
-                .and_then(|i| ctx.args.get(i + 1))
-                .cloned()
-                .unwrap_or_else(|| "gemini-2.5-pro".to_string())
-        };
-
-        let mcp_port = ctx.mcp_port;
-        let mcp_base = format!("http://localhost:{mcp_port}/mcp");
-
-        let client = Client::new();
-        let cancelled = Arc::clone(&self.cancelled);
-
-        // Step 1: Fetch MCP tool definitions and convert to Gemini format.
-        let gemini_tools = fetch_and_convert_mcp_tools(&client, &mcp_base).await?;
-
-        // Step 2: Build the initial conversation contents.
-        let system_instruction = build_system_instruction(&ctx);
-        let mut contents: Vec<Value> = vec![json!({
-            "role": "user",
-            "parts": [{ "text": ctx.prompt }]
-        })];
-
-        let mut total_usage = TokenUsage {
-            input_tokens: 0,
-            output_tokens: 0,
-            cache_creation_input_tokens: 0,
-            cache_read_input_tokens: 0,
-            total_cost_usd: 0.0,
-        };
-
-        let emit = |event: AgentEvent| {
-            super::super::pty::emit_event(
-                event,
-                &tx,
-                &event_log,
-                log_writer.as_ref().map(|w| w.as_ref()),
-            );
-        };
-
-        emit(AgentEvent::Status {
-            story_id: ctx.story_id.clone(),
-            agent_name: ctx.agent_name.clone(),
-            status: "running".to_string(),
-        });
-
-        // Step 3: Conversation loop.
-        let mut turn = 0u32;
-        let max_turns = 200; // Safety limit
-
-        loop {
-            if cancelled.load(Ordering::Relaxed) {
-                emit(AgentEvent::Error {
-                    story_id: ctx.story_id.clone(),
-                    agent_name: ctx.agent_name.clone(),
-                    message: "Agent was stopped by user".to_string(),
-                });
-                return Ok(RuntimeResult {
-                    session_id: None,
-                    token_usage: Some(total_usage),
-                });
-            }
-
-            turn += 1;
-            if turn > max_turns {
-                emit(AgentEvent::Error {
-                    story_id: ctx.story_id.clone(),
-                    agent_name: ctx.agent_name.clone(),
-                    message: format!("Exceeded maximum turns ({max_turns})"),
-                });
-                return Ok(RuntimeResult {
-                    session_id: None,
-                    token_usage: Some(total_usage),
-                });
-            }
-
-            slog!(
-                "[gemini] Turn {turn} for {}:{}",
-                ctx.story_id,
-                ctx.agent_name
-            );
-
-            let request_body =
-                build_generate_content_request(&system_instruction, &contents, &gemini_tools);
-
-            let url = format!(
-                "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}"
-            );
-
-            let response = client
-                .post(&url)
-                .json(&request_body)
-                .send()
-                .await
-                .map_err(|e| format!("Gemini API request failed: {e}"))?;
-
-            let status = response.status();
-            let body: Value = response
-                .json()
-                .await
-                .map_err(|e| format!("Failed to parse Gemini API response: {e}"))?;
-
-            if !status.is_success() {
-                let error_msg = body["error"]["message"]
-                    .as_str()
-                    .unwrap_or("Unknown API error");
-                let err = format!("Gemini API error ({status}): {error_msg}");
-                emit(AgentEvent::Error {
-                    story_id: ctx.story_id.clone(),
-                    agent_name: ctx.agent_name.clone(),
-                    message: err.clone(),
-                });
-                return Err(err);
-            }
-
-            // Accumulate token usage.
-            if let Some(usage) = parse_usage_metadata(&body) {
-                total_usage.input_tokens += usage.input_tokens;
-                total_usage.output_tokens += usage.output_tokens;
-            }
-
-            // Extract the candidate response.
-            let candidate = body["candidates"]
-                .as_array()
-                .and_then(|c| c.first())
-                .ok_or_else(|| "No candidates in Gemini response".to_string())?;
-
-            let parts = candidate["content"]["parts"]
-                .as_array()
-                .ok_or_else(|| "No parts in Gemini response candidate".to_string())?;
-
-            // Check finish reason.
-            let finish_reason = candidate["finishReason"].as_str().unwrap_or("");
-
-            // Separate text parts and function call parts.
-            let mut text_parts: Vec<String> = Vec::new();
-            let mut function_calls: Vec<GeminiFunctionCall> = Vec::new();
-
-            for part in parts {
-                if let Some(text) = part["text"].as_str() {
-                    text_parts.push(text.to_string());
-                }
-                if let Some(fc) = part.get("functionCall")
-                    && let (Some(name), Some(args)) = (fc["name"].as_str(), fc.get("args"))
-                {
-                    function_calls.push(GeminiFunctionCall {
-                        name: name.to_string(),
-                        args: args.clone(),
-                    });
-                }
-            }
-
-            // Emit any text output.
-            for text in &text_parts {
-                if !text.is_empty() {
-                    emit(AgentEvent::Output {
-                        story_id: ctx.story_id.clone(),
-                        agent_name: ctx.agent_name.clone(),
-                        text: text.clone(),
-                    });
-                }
-            }
-
-            // If no function calls, the model is done.
-            if function_calls.is_empty() {
-                emit(AgentEvent::Done {
-                    story_id: ctx.story_id.clone(),
-                    agent_name: ctx.agent_name.clone(),
-                    session_id: None,
-                });
-                return Ok(RuntimeResult {
-                    session_id: None,
-                    token_usage: Some(total_usage),
-                });
-            }
-
-            // Add the model's response to the conversation.
-            let model_parts: Vec<Value> = parts.to_vec();
-            contents.push(json!({
-                "role": "model",
-                "parts": model_parts
-            }));
-
-            // Execute function calls via MCP and build response parts.
-            let mut response_parts: Vec<Value> = Vec::new();
-
-            for fc in &function_calls {
-                if cancelled.load(Ordering::Relaxed) {
-                    break;
-                }
-
-                slog!(
-                    "[gemini] Calling MCP tool '{}' for {}:{}",
-                    fc.name,
-                    ctx.story_id,
-                    ctx.agent_name
-                );
-
-                emit(AgentEvent::Output {
-                    story_id: ctx.story_id.clone(),
-                    agent_name: ctx.agent_name.clone(),
-                    text: format!("\n[Tool call: {}]\n", fc.name),
-                });
-
-                let tool_result = call_mcp_tool(&client, &mcp_base, &fc.name, &fc.args).await;
-
-                let response_value = match &tool_result {
-                    Ok(result) => {
-                        emit(AgentEvent::Output {
-                            story_id: ctx.story_id.clone(),
-                            agent_name: ctx.agent_name.clone(),
-                            text: format!("[Tool result: {} chars]\n", result.len()),
-                        });
-                        json!({ "result": result })
-                    }
-                    Err(e) => {
-                        emit(AgentEvent::Output {
-                            story_id: ctx.story_id.clone(),
-                            agent_name: ctx.agent_name.clone(),
-                            text: format!("[Tool error: {e}]\n"),
-                        });
-                        json!({ "error": e })
-                    }
-                };
-
-                response_parts.push(json!({
-                    "functionResponse": {
-                        "name": fc.name,
-                        "response": response_value
-                    }
-                }));
-            }
-
-            // Add function responses to the conversation.
-            contents.push(json!({
-                "role": "user",
-                "parts": response_parts
-            }));
-
-            // If the model indicated it's done despite having function calls,
-            // respect the finish reason.
-            if finish_reason == "STOP" && function_calls.is_empty() {
-                break;
-            }
-        }
-
-        emit(AgentEvent::Done {
-            story_id: ctx.story_id.clone(),
-            agent_name: ctx.agent_name.clone(),
-            session_id: None,
-        });
-
-        Ok(RuntimeResult {
-            session_id: None,
-            token_usage: Some(total_usage),
-        })
-    }
-
-    fn stop(&self) {
-        self.cancelled.store(true, Ordering::Relaxed);
-    }
-
-    fn get_status(&self) -> RuntimeStatus {
-        if self.cancelled.load(Ordering::Relaxed) {
-            RuntimeStatus::Failed
-        } else {
-            RuntimeStatus::Idle
-        }
-    }
-}
-
-// ── Internal types ───────────────────────────────────────────────────
-
-struct GeminiFunctionCall {
-    name: String,
-    args: Value,
-}
-
-// ── Gemini API types (for serde) ─────────────────────────────────────
-
-#[derive(Debug, Serialize, Deserialize)]
-struct GeminiFunctionDeclaration {
-    name: String,
-    description: String,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    parameters: Option<Value>,
-}
-
-// ── Helper functions ─────────────────────────────────────────────────
-
-/// Build the system instruction content from the RuntimeContext.
-fn build_system_instruction(ctx: &RuntimeContext) -> Value {
-    // Use system_prompt from args if provided via --append-system-prompt,
-    // otherwise use a sensible default.
-    let system_text = ctx
-        .args
-        .iter()
-        .position(|a| a == "--append-system-prompt")
-        .and_then(|i| ctx.args.get(i + 1))
-        .cloned()
-        .unwrap_or_else(|| {
-            format!(
-                "You are an AI coding agent working on story {}. \
-                 You have access to tools via function calling. \
-                 Use them to complete the task. \
-                 Work in the directory: {}",
-                ctx.story_id, ctx.cwd
-            )
-        });
-
-    json!({
-        "parts": [{ "text": system_text }]
-    })
-}
-
-/// Build the full `generateContent` request body.
-fn build_generate_content_request(
-    system_instruction: &Value,
-    contents: &[Value],
-    gemini_tools: &[GeminiFunctionDeclaration],
-) -> Value {
-    let mut body = json!({
-        "system_instruction": system_instruction,
-        "contents": contents,
-        "generationConfig": {
-            "temperature": 0.2,
-            "maxOutputTokens": 65536,
-        }
-    });
-
-    if !gemini_tools.is_empty() {
-        body["tools"] = json!([{
-            "functionDeclarations": gemini_tools
-        }]);
-    }
-
-    body
-}
-
-/// Fetch MCP tool definitions from huskies' MCP server and convert
-/// them to Gemini function declaration format.
-async fn fetch_and_convert_mcp_tools(
-    client: &Client,
-    mcp_base: &str,
-) -> Result<Vec<GeminiFunctionDeclaration>, String> {
-    let request = json!({
-        "jsonrpc": "2.0",
-        "id": 1,
-        "method": "tools/list",
-        "params": {}
-    });
-
-    let response = client
-        .post(mcp_base)
-        .json(&request)
-        .send()
-        .await
-        .map_err(|e| format!("Failed to fetch MCP tools: {e}"))?;
-
-    let body: Value = response
-        .json()
-        .await
-        .map_err(|e| format!("Failed to parse MCP tools response: {e}"))?;
-
-    let tools = body["result"]["tools"]
-        .as_array()
-        .ok_or_else(|| "No tools array in MCP response".to_string())?;
-
-    let mut declarations = Vec::new();
-
-    for tool in tools {
-        let name = tool["name"].as_str().unwrap_or("").to_string();
-        let description = tool["description"].as_str().unwrap_or("").to_string();
-
-        if name.is_empty() {
-            continue;
-        }
-
-        // Convert MCP inputSchema (JSON Schema) to Gemini parameters
-        // (OpenAPI-subset schema). They are structurally compatible for
-        // simple object schemas.
-        let parameters = convert_mcp_schema_to_gemini(tool.get("inputSchema"));
-
-        declarations.push(GeminiFunctionDeclaration {
-            name,
-            description,
-            parameters,
-        });
-    }
-
-    slog!(
-        "[gemini] Loaded {} MCP tools as function declarations",
-        declarations.len()
-    );
-    Ok(declarations)
-}
-
-/// Convert an MCP inputSchema (JSON Schema) to a Gemini-compatible
-/// OpenAPI-subset parameter schema.
-///
-/// Gemini function calling expects parameters in OpenAPI format, which
-/// is structurally similar to JSON Schema for simple object types.
-/// We strip unsupported fields and ensure the type is "object".
-fn convert_mcp_schema_to_gemini(schema: Option<&Value>) -> Option<Value> {
-    let schema = schema?;
-
-    // If the schema has no properties (empty tool), return None.
-    let properties = schema.get("properties")?;
-    if properties.as_object().is_some_and(|p| p.is_empty()) {
-        return None;
-    }
-
-    let mut result = json!({
-        "type": "object",
-        "properties": clean_schema_properties(properties),
-    });
-
-    // Preserve required fields if present.
-    if let Some(required) = schema.get("required") {
-        result["required"] = required.clone();
-    }
-
-    Some(result)
-}
-
-/// Recursively clean schema properties to be Gemini-compatible.
-/// Removes unsupported JSON Schema keywords.
-fn clean_schema_properties(properties: &Value) -> Value {
-    let Some(obj) = properties.as_object() else {
-        return properties.clone();
-    };
-
-    let mut cleaned = serde_json::Map::new();
-    for (key, value) in obj {
-        let mut prop = value.clone();
-        // Remove JSON Schema keywords not supported by Gemini
-        if let Some(p) = prop.as_object_mut() {
-            p.remove("$schema");
-            p.remove("additionalProperties");
-
-            // Recursively clean nested object properties
-            if let Some(nested_props) = p.get("properties").cloned() {
-                p.insert(
-                    "properties".to_string(),
-                    clean_schema_properties(&nested_props),
-                );
-            }
-
-            // Clean items schema for arrays
-            if let Some(items) = p.get("items").cloned()
-                && let Some(items_obj) = items.as_object()
-            {
-                let mut cleaned_items = items_obj.clone();
-                cleaned_items.remove("$schema");
-                cleaned_items.remove("additionalProperties");
-                p.insert("items".to_string(), Value::Object(cleaned_items));
-            }
-        }
-        cleaned.insert(key.clone(), prop);
-    }
-    Value::Object(cleaned)
-}
-
-/// Call an MCP tool via huskies' MCP server.
-async fn call_mcp_tool(
-    client: &Client,
-    mcp_base: &str,
-    tool_name: &str,
-    args: &Value,
-) -> Result<String, String> {
-    let request = json!({
-        "jsonrpc": "2.0",
-        "id": 1,
-        "method": "tools/call",
-        "params": {
-            "name": tool_name,
-            "arguments": args
-        }
-    });
-
-    let response = client
-        .post(mcp_base)
-        .json(&request)
-        .send()
-        .await
-        .map_err(|e| format!("MCP tool call failed: {e}"))?;
-
-    let body: Value = response
-        .json()
-        .await
-        .map_err(|e| format!("Failed to parse MCP tool response: {e}"))?;
-
-    if let Some(error) = body.get("error") {
-        let msg = error["message"].as_str().unwrap_or("Unknown MCP error");
-        return Err(format!("MCP tool '{tool_name}' error: {msg}"));
-    }
-
-    // MCP tools/call returns { result: { content: [{ type: "text", text: "..." }] } }
-    let content = &body["result"]["content"];
-    if let Some(arr) = content.as_array() {
-        let texts: Vec<&str> = arr.iter().filter_map(|c| c["text"].as_str()).collect();
-        if !texts.is_empty() {
-            return Ok(texts.join("\n"));
-        }
-    }
-
-    // Fall back to serializing the entire result.
-    Ok(body["result"].to_string())
-}
-
-/// Parse token usage metadata from a Gemini API response.
-fn parse_usage_metadata(response: &Value) -> Option<TokenUsage> {
-    let metadata = response.get("usageMetadata")?;
-    Some(TokenUsage {
-        input_tokens: metadata
-            .get("promptTokenCount")
-            .and_then(|v| v.as_u64())
-            .unwrap_or(0),
-        output_tokens: metadata
-            .get("candidatesTokenCount")
-            .and_then(|v| v.as_u64())
-            .unwrap_or(0),
-        // Gemini doesn't have cache token fields, but we keep the struct uniform.
-        cache_creation_input_tokens: 0,
-        cache_read_input_tokens: 0,
-        // Google AI API doesn't report cost; leave at 0.
-        total_cost_usd: 0.0,
-    })
-}
-
-// ── Tests ────────────────────────────────────────────────────────────
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn convert_mcp_schema_simple_object() {
-        let schema = json!({
-            "type": "object",
-            "properties": {
-                "story_id": {
-                    "type": "string",
-                    "description": "Story identifier"
-                }
-            },
-            "required": ["story_id"]
-        });
-
-        let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
-        assert_eq!(result["type"], "object");
-        assert!(result["properties"]["story_id"].is_object());
-        assert_eq!(result["required"][0], "story_id");
-    }
-
-    #[test]
-    fn convert_mcp_schema_empty_properties_returns_none() {
-        let schema = json!({
-            "type": "object",
-            "properties": {}
-        });
-
-        assert!(convert_mcp_schema_to_gemini(Some(&schema)).is_none());
-    }
-
-    #[test]
-    fn convert_mcp_schema_none_returns_none() {
-        assert!(convert_mcp_schema_to_gemini(None).is_none());
-    }
-
-    #[test]
-    fn convert_mcp_schema_strips_additional_properties() {
-        let schema = json!({
-            "type": "object",
-            "properties": {
-                "name": {
-                    "type": "string",
-                    "additionalProperties": false,
-                    "$schema": "http://json-schema.org/draft-07/schema#"
-                }
-            }
-        });
-
-        let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
-        let name_prop = &result["properties"]["name"];
-        assert!(name_prop.get("additionalProperties").is_none());
-        assert!(name_prop.get("$schema").is_none());
-        assert_eq!(name_prop["type"], "string");
-    }
-
-    #[test]
-    fn convert_mcp_schema_with_nested_objects() {
-        let schema = json!({
-            "type": "object",
-            "properties": {
-                "config": {
-                    "type": "object",
-                    "properties": {
-                        "key": { "type": "string" }
-                    }
-                }
-            }
-        });
-
-        let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
-        assert!(result["properties"]["config"]["properties"]["key"].is_object());
-    }
-
-    #[test]
-    fn convert_mcp_schema_with_array_items() {
-        let schema = json!({
-            "type": "object",
-            "properties": {
-                "items": {
-                    "type": "array",
-                    "items": {
-                        "type": "object",
-                        "properties": {
-                            "name": { "type": "string" }
-                        },
-                        "additionalProperties": false
-                    }
-                }
-            }
-        });
-
-        let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
-        let items_schema = &result["properties"]["items"]["items"];
-        assert!(items_schema.get("additionalProperties").is_none());
-    }
-
-    #[test]
-    fn build_system_instruction_uses_args() {
-        let ctx = RuntimeContext {
-            story_id: "42_story_test".to_string(),
-            agent_name: "coder-1".to_string(),
-            command: "gemini-2.5-pro".to_string(),
-            args: vec![
-                "--append-system-prompt".to_string(),
-                "Custom system prompt".to_string(),
-            ],
-            prompt: "Do the thing".to_string(),
-            cwd: "/tmp/wt".to_string(),
-            inactivity_timeout_secs: 300,
-            mcp_port: 3001,
-            session_id_to_resume: None,
-            fresh_prompt: None,
-        };
-
-        let instruction = build_system_instruction(&ctx);
-        assert_eq!(instruction["parts"][0]["text"], "Custom system prompt");
-    }
-
-    #[test]
-    fn build_system_instruction_default() {
-        let ctx = RuntimeContext {
-            story_id: "42_story_test".to_string(),
-            agent_name: "coder-1".to_string(),
-            command: "gemini-2.5-pro".to_string(),
-            args: vec![],
-            prompt: "Do the thing".to_string(),
-            cwd: "/tmp/wt".to_string(),
-            inactivity_timeout_secs: 300,
-            mcp_port: 3001,
-            session_id_to_resume: None,
-            fresh_prompt: None,
-        };
-
-        let instruction = build_system_instruction(&ctx);
-        let text = instruction["parts"][0]["text"].as_str().unwrap();
-        assert!(text.contains("42_story_test"));
-        assert!(text.contains("/tmp/wt"));
-    }
-
-    #[test]
-    fn build_generate_content_request_includes_tools() {
-        let system = json!({"parts": [{"text": "system"}]});
-        let contents = vec![json!({"role": "user", "parts": [{"text": "hello"}]})];
-        let tools = vec![GeminiFunctionDeclaration {
-            name: "my_tool".to_string(),
-            description: "A tool".to_string(),
-            parameters: Some(json!({"type": "object", "properties": {"x": {"type": "string"}}})),
-        }];
-
-        let body = build_generate_content_request(&system, &contents, &tools);
-        assert!(body["tools"][0]["functionDeclarations"].is_array());
-        assert_eq!(
-            body["tools"][0]["functionDeclarations"][0]["name"],
-            "my_tool"
-        );
-    }
-
-    #[test]
-    fn build_generate_content_request_no_tools() {
-        let system = json!({"parts": [{"text": "system"}]});
-        let contents = vec![json!({"role": "user", "parts": [{"text": "hello"}]})];
-        let tools: Vec<GeminiFunctionDeclaration> = vec![];
-
-        let body = build_generate_content_request(&system, &contents, &tools);
-        assert!(body.get("tools").is_none());
-    }
-
-    #[test]
-    fn parse_usage_metadata_valid() {
-        let response = json!({
-            "usageMetadata": {
-                "promptTokenCount": 100,
-                "candidatesTokenCount": 50,
-                "totalTokenCount": 150
-            }
-        });
-
-        let usage = parse_usage_metadata(&response).unwrap();
-        assert_eq!(usage.input_tokens, 100);
-        assert_eq!(usage.output_tokens, 50);
-        assert_eq!(usage.cache_creation_input_tokens, 0);
-        assert_eq!(usage.total_cost_usd, 0.0);
-    }
-
-    #[test]
-    fn parse_usage_metadata_missing() {
-        let response = json!({"candidates": []});
-        assert!(parse_usage_metadata(&response).is_none());
-    }
-
-    #[test]
-    fn gemini_runtime_stop_sets_cancelled() {
-        let runtime = GeminiRuntime::new();
-        assert_eq!(runtime.get_status(), RuntimeStatus::Idle);
-        runtime.stop();
-        assert_eq!(runtime.get_status(), RuntimeStatus::Failed);
-    }
-
-    #[test]
-    fn model_extraction_from_command() {
-        // When command starts with "gemini", use it as model name
-        let ctx = RuntimeContext {
-            story_id: "1".to_string(),
-            agent_name: "coder".to_string(),
-            command: "gemini-2.5-pro".to_string(),
-            args: vec![],
-            prompt: "test".to_string(),
-            cwd: "/tmp".to_string(),
-            inactivity_timeout_secs: 300,
-            mcp_port: 3001,
-            session_id_to_resume: None,
-            fresh_prompt: None,
-        };
-
-        // The model extraction logic is inside start(), but we test the
-        // condition here.
-        assert!(ctx.command.starts_with("gemini"));
-    }
-}
diff --git a/server/src/agents/runtime/gemini/api.rs b/server/src/agents/runtime/gemini/api.rs
new file mode 100644
index 00000000..a927f7ee
--- /dev/null
+++ b/server/src/agents/runtime/gemini/api.rs
@@ -0,0 +1,190 @@
+//! Gemini API types, request builders, and response parsers.
+use serde::{Deserialize, Serialize};
+use serde_json::{Value, json};
+
+use super::super::super::TokenUsage;
+use super::super::RuntimeContext;
+
+// ── Gemini API types ─────────────────────────────────────────────────
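+/// One tool offered to the model. Serializes to the JSON shape Gemini's
+/// function-calling API expects, roughly (tool name illustrative):
+/// `{ "name": "update_story", "description": "...", "parameters": { "type": "object", ... } }`.
+/// `parameters` is omitted entirely for tools that take no arguments.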
+#[derive(Debug, Serialize, Deserialize)]
+pub(super) struct GeminiFunctionDeclaration {
+    pub name: String,
+    pub description: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parameters: Option<Value>,
+}
+
+// ── Request builders ─────────────────────────────────────────────────
+
+/// Build the system instruction content from the RuntimeContext.
+pub(super) fn build_system_instruction(ctx: &RuntimeContext) -> Value {
+    // Use system_prompt from args if provided via --append-system-prompt,
+    // otherwise use a sensible default.
+    let system_text = ctx
+        .args
+        .iter()
+        .position(|a| a == "--append-system-prompt")
+        .and_then(|i| ctx.args.get(i + 1))
+        .cloned()
+        .unwrap_or_else(|| {
+            format!(
+                "You are an AI coding agent working on story {}. \
+                 You have access to tools via function calling. \
+                 Use them to complete the task. \
+                 Work in the directory: {}",
+                ctx.story_id, ctx.cwd
+            )
+        });
+
+    json!({
+        "parts": [{ "text": system_text }]
+    })
+}
+
+/// Build the full `generateContent` request body.
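+///
+/// For reference, the serialized body looks roughly like this (text values
+/// illustrative; `tools` is present only when declarations exist):
+///
+/// ```text
+/// {
+///   "system_instruction": { "parts": [{ "text": "..." }] },
+///   "contents": [{ "role": "user", "parts": [{ "text": "..." }] }],
+///   "generationConfig": { "temperature": 0.2, "maxOutputTokens": 65536 },
+///   "tools": [{ "functionDeclarations": [ ... ] }]
+/// }
+/// ```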
+pub(super) fn build_generate_content_request(
+    system_instruction: &Value,
+    contents: &[Value],
+    gemini_tools: &[GeminiFunctionDeclaration],
+) -> Value {
+    let mut body = json!({
+        "system_instruction": system_instruction,
+        "contents": contents,
+        "generationConfig": {
+            "temperature": 0.2,
+            "maxOutputTokens": 65536,
+        }
+    });
+
+    if !gemini_tools.is_empty() {
+        body["tools"] = json!([{
+            "functionDeclarations": gemini_tools
+        }]);
+    }
+
+    body
+}
+
+// ── Response parsing ─────────────────────────────────────────────────
+
+/// Parse token usage metadata from a Gemini API response.
+pub(super) fn parse_usage_metadata(response: &Value) -> Option<TokenUsage> {
+    let metadata = response.get("usageMetadata")?;
+    Some(TokenUsage {
+        input_tokens: metadata
+            .get("promptTokenCount")
+            .and_then(|v| v.as_u64())
+            .unwrap_or(0),
+        output_tokens: metadata
+            .get("candidatesTokenCount")
+            .and_then(|v| v.as_u64())
+            .unwrap_or(0),
+        // Gemini doesn't have cache token fields, but we keep the struct uniform.
+        cache_creation_input_tokens: 0,
+        cache_read_input_tokens: 0,
+        // Google AI API doesn't report cost; leave at 0.
+        total_cost_usd: 0.0,
+    })
+}
+
+// ── Tests ────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn build_system_instruction_uses_args() {
+        let ctx = RuntimeContext {
+            story_id: "42_story_test".to_string(),
+            agent_name: "coder-1".to_string(),
+            command: "gemini-2.5-pro".to_string(),
+            args: vec![
+                "--append-system-prompt".to_string(),
+                "Custom system prompt".to_string(),
+            ],
+            prompt: "Do the thing".to_string(),
+            cwd: "/tmp/wt".to_string(),
+            inactivity_timeout_secs: 300,
+            mcp_port: 3001,
+            session_id_to_resume: None,
+            fresh_prompt: None,
+        };
+
+        let instruction = build_system_instruction(&ctx);
+        assert_eq!(instruction["parts"][0]["text"], "Custom system prompt");
+    }
+
+    #[test]
+    fn build_system_instruction_default() {
+        let ctx = RuntimeContext {
+            story_id: "42_story_test".to_string(),
+            agent_name: "coder-1".to_string(),
+            command: "gemini-2.5-pro".to_string(),
+            args: vec![],
+            prompt: "Do the thing".to_string(),
+            cwd: "/tmp/wt".to_string(),
+            inactivity_timeout_secs: 300,
+            mcp_port: 3001,
+            session_id_to_resume: None,
+            fresh_prompt: None,
+        };
+
+        let instruction = build_system_instruction(&ctx);
+        let text = instruction["parts"][0]["text"].as_str().unwrap();
+        assert!(text.contains("42_story_test"));
+        assert!(text.contains("/tmp/wt"));
+    }
+
+    #[test]
+    fn build_generate_content_request_includes_tools() {
+        let system = json!({"parts": [{"text": "system"}]});
+        let contents = vec![json!({"role": "user", "parts": [{"text": "hello"}]})];
+        let tools = vec![GeminiFunctionDeclaration {
+            name: "my_tool".to_string(),
+            description: "A tool".to_string(),
+            parameters: Some(json!({"type": "object", "properties": {"x": {"type": "string"}}})),
+        }];
+
+        let body = build_generate_content_request(&system, &contents, &tools);
+        assert!(body["tools"][0]["functionDeclarations"].is_array());
+        assert_eq!(
+            body["tools"][0]["functionDeclarations"][0]["name"],
+            "my_tool"
+        );
+    }
+
+    #[test]
+    fn build_generate_content_request_no_tools() {
+        let system = json!({"parts": [{"text": "system"}]});
+        let contents = vec![json!({"role": "user", "parts": [{"text": "hello"}]})];
+        let tools: Vec<GeminiFunctionDeclaration> = vec![];
+
+        let body = build_generate_content_request(&system, &contents, &tools);
+        assert!(body.get("tools").is_none());
+    }
+
+    #[test]
+    fn parse_usage_metadata_valid() {
+        let response = json!({
+            "usageMetadata": {
+                "promptTokenCount": 100,
+                "candidatesTokenCount": 50,
+                "totalTokenCount": 150
+            }
+        });
+
+        let usage = parse_usage_metadata(&response).unwrap();
+        assert_eq!(usage.input_tokens, 100);
+        assert_eq!(usage.output_tokens, 50);
+        assert_eq!(usage.cache_creation_input_tokens, 0);
+        assert_eq!(usage.total_cost_usd, 0.0);
+    }
+
+    #[test]
+    fn parse_usage_metadata_missing() {
+        let response = json!({"candidates": []});
+        assert!(parse_usage_metadata(&response).is_none());
+    }
+}
diff --git a/server/src/agents/runtime/gemini/mcp.rs b/server/src/agents/runtime/gemini/mcp.rs
new file mode 100644
index 00000000..658d37da
--- /dev/null
+++ b/server/src/agents/runtime/gemini/mcp.rs
@@ -0,0 +1,284 @@
+//! MCP tool fetching, schema conversion, and tool invocation for the Gemini runtime.
+use reqwest::Client;
+use serde_json::{Value, json};
+
+use crate::slog;
+
+use super::api::GeminiFunctionDeclaration;
+
+// ── MCP tool fetching ────────────────────────────────────────────────
+
+/// Fetch MCP tool definitions from huskies' MCP server and convert
+/// them to Gemini function declaration format.
+pub(super) async fn fetch_and_convert_mcp_tools(
+    client: &Client,
+    mcp_base: &str,
+) -> Result<Vec<GeminiFunctionDeclaration>, String> {
+    let request = json!({
+        "jsonrpc": "2.0",
+        "id": 1,
+        "method": "tools/list",
+        "params": {}
+    });
+
+    let response = client
+        .post(mcp_base)
+        .json(&request)
+        .send()
+        .await
+        .map_err(|e| format!("Failed to fetch MCP tools: {e}"))?;
+
+    let body: Value = response
+        .json()
+        .await
+        .map_err(|e| format!("Failed to parse MCP tools response: {e}"))?;
+
+    let tools = body["result"]["tools"]
+        .as_array()
+        .ok_or_else(|| "No tools array in MCP response".to_string())?;
+
+    let mut declarations = Vec::new();
+
+    for tool in tools {
+        let name = tool["name"].as_str().unwrap_or("").to_string();
+        let description = tool["description"].as_str().unwrap_or("").to_string();
+
+        if name.is_empty() {
+            continue;
+        }
+
+        // Convert MCP inputSchema (JSON Schema) to Gemini parameters
+        // (OpenAPI-subset schema). They are structurally compatible for
+        // simple object schemas.
+        let parameters = convert_mcp_schema_to_gemini(tool.get("inputSchema"));
+
+        declarations.push(GeminiFunctionDeclaration {
+            name,
+            description,
+            parameters,
+        });
+    }
+
+    slog!(
+        "[gemini] Loaded {} MCP tools as function declarations",
+        declarations.len()
+    );
+    Ok(declarations)
+}
+
+/// Call an MCP tool via huskies' MCP server.
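+///
+/// For reference, the JSON-RPC exchange looks like this (tool name and
+/// arguments illustrative):
+///
+/// ```text
+/// -> { "jsonrpc": "2.0", "id": 1, "method": "tools/call",
+///      "params": { "name": "update_story", "arguments": { "status": "done" } } }
+/// <- { "jsonrpc": "2.0", "id": 1,
+///      "result": { "content": [{ "type": "text", "text": "ok" }] } }
+/// ```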
+pub(super) async fn call_mcp_tool(
+    client: &Client,
+    mcp_base: &str,
+    tool_name: &str,
+    args: &Value,
+) -> Result<String, String> {
+    let request = json!({
+        "jsonrpc": "2.0",
+        "id": 1,
+        "method": "tools/call",
+        "params": {
+            "name": tool_name,
+            "arguments": args
+        }
+    });
+
+    let response = client
+        .post(mcp_base)
+        .json(&request)
+        .send()
+        .await
+        .map_err(|e| format!("MCP tool call failed: {e}"))?;
+
+    let body: Value = response
+        .json()
+        .await
+        .map_err(|e| format!("Failed to parse MCP tool response: {e}"))?;
+
+    if let Some(error) = body.get("error") {
+        let msg = error["message"].as_str().unwrap_or("Unknown MCP error");
+        return Err(format!("MCP tool '{tool_name}' error: {msg}"));
+    }
+
+    // MCP tools/call returns { result: { content: [{ type: "text", text: "..." }] } }
+    let content = &body["result"]["content"];
+    if let Some(arr) = content.as_array() {
+        let texts: Vec<&str> = arr.iter().filter_map(|c| c["text"].as_str()).collect();
+        if !texts.is_empty() {
+            return Ok(texts.join("\n"));
+        }
+    }
+
+    // Fall back to serializing the entire result.
+    Ok(body["result"].to_string())
+}
+
+// ── Schema conversion ────────────────────────────────────────────────
+
+/// Convert an MCP inputSchema (JSON Schema) to a Gemini-compatible
+/// OpenAPI-subset parameter schema.
+///
+/// Gemini function calling expects parameters in OpenAPI format, which
+/// is structurally similar to JSON Schema for simple object types.
+/// We strip unsupported fields and ensure the type is "object".
+pub(super) fn convert_mcp_schema_to_gemini(schema: Option<&Value>) -> Option<Value> {
+    let schema = schema?;
+
+    // If the schema has no properties (empty tool), return None.
+    let properties = schema.get("properties")?;
+    if properties.as_object().is_some_and(|p| p.is_empty()) {
+        return None;
+    }
+
+    let mut result = json!({
+        "type": "object",
+        "properties": clean_schema_properties(properties),
+    });
+
+    // Preserve required fields if present.
+    if let Some(required) = schema.get("required") {
+        result["required"] = required.clone();
+    }
+
+    Some(result)
+}
+
+/// Recursively clean schema properties to be Gemini-compatible.
+/// Removes unsupported JSON Schema keywords.
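+///
+/// For example, a property value of
+/// `{ "type": "string", "$schema": "...", "additionalProperties": false }`
+/// comes back as `{ "type": "string" }`.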
+fn clean_schema_properties(properties: &Value) -> Value {
+    let Some(obj) = properties.as_object() else {
+        return properties.clone();
+    };
+
+    let mut cleaned = serde_json::Map::new();
+    for (key, value) in obj {
+        let mut prop = value.clone();
+        // Remove JSON Schema keywords not supported by Gemini
+        if let Some(p) = prop.as_object_mut() {
+            p.remove("$schema");
+            p.remove("additionalProperties");
+
+            // Recursively clean nested object properties
+            if let Some(nested_props) = p.get("properties").cloned() {
+                p.insert(
+                    "properties".to_string(),
+                    clean_schema_properties(&nested_props),
+                );
+            }
+
+            // Clean items schema for arrays
+            if let Some(items) = p.get("items").cloned()
+                && let Some(items_obj) = items.as_object()
+            {
+                let mut cleaned_items = items_obj.clone();
+                cleaned_items.remove("$schema");
+                cleaned_items.remove("additionalProperties");
+                p.insert("items".to_string(), Value::Object(cleaned_items));
+            }
+        }
+        cleaned.insert(key.clone(), prop);
+    }
+    Value::Object(cleaned)
+}
+
+// ── Tests ────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn convert_mcp_schema_simple_object() {
+        let schema = json!({
+            "type": "object",
+            "properties": {
+                "story_id": {
+                    "type": "string",
+                    "description": "Story identifier"
+                }
+            },
+            "required": ["story_id"]
+        });
+
+        let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
+        assert_eq!(result["type"], "object");
+        assert!(result["properties"]["story_id"].is_object());
+        assert_eq!(result["required"][0], "story_id");
+    }
+
+    #[test]
+    fn convert_mcp_schema_empty_properties_returns_none() {
+        let schema = json!({
+            "type": "object",
+            "properties": {}
+        });
+
+        assert!(convert_mcp_schema_to_gemini(Some(&schema)).is_none());
+    }
+
+    #[test]
+    fn convert_mcp_schema_none_returns_none() {
+        assert!(convert_mcp_schema_to_gemini(None).is_none());
+    }
+
+    #[test]
+    fn convert_mcp_schema_strips_additional_properties() {
+        let schema = json!({
+            "type": "object",
+            "properties": {
+                "name": {
+                    "type": "string",
+                    "additionalProperties": false,
+                    "$schema": "http://json-schema.org/draft-07/schema#"
+                }
+            }
+        });
+
+        let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
+        let name_prop = &result["properties"]["name"];
+        assert!(name_prop.get("additionalProperties").is_none());
+        assert!(name_prop.get("$schema").is_none());
+        assert_eq!(name_prop["type"], "string");
+    }
+
+    #[test]
+    fn convert_mcp_schema_with_nested_objects() {
+        let schema = json!({
+            "type": "object",
+            "properties": {
+                "config": {
+                    "type": "object",
+                    "properties": {
+                        "key": { "type": "string" }
+                    }
+                }
+            }
+        });
+
+        let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
+        assert!(result["properties"]["config"]["properties"]["key"].is_object());
+    }
+
+    #[test]
+    fn convert_mcp_schema_with_array_items() {
+        let schema = json!({
+            "type": "object",
+            "properties": {
+                "items": {
+                    "type": "array",
+                    "items": {
+                        "type": "object",
+                        "properties": {
+                            "name": { "type": "string" }
+                        },
+                        "additionalProperties": false
+                    }
+                }
+            }
+        });
+
+        let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
+        let items_schema = &result["properties"]["items"]["items"];
+        assert!(items_schema.get("additionalProperties").is_none());
+    }
+}
diff --git a/server/src/agents/runtime/gemini/mod.rs b/server/src/agents/runtime/gemini/mod.rs
new file mode 100644
index 00000000..924dc477
--- /dev/null
+++ b/server/src/agents/runtime/gemini/mod.rs
@@ -0,0 +1,379 @@
+//! Gemini runtime — drives Google Gemini API sessions as agent backends.
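+//!
+//! Split into submodules per the file list above:
+//! - `api`: request builders and response parsing for `generateContent`.
+//! - `mcp`: MCP tool discovery, schema conversion, and invocation.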
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::{Arc, Mutex};
+
+use reqwest::Client;
+use serde_json::json;
+use tokio::sync::broadcast;
+
+use crate::agent_log::AgentLogWriter;
+use crate::slog;
+
+use super::super::{AgentEvent, TokenUsage};
+use super::{AgentRuntime, RuntimeContext, RuntimeResult, RuntimeStatus};
+
+mod api;
+mod mcp;
+
+use api::{build_generate_content_request, build_system_instruction, parse_usage_metadata};
+use mcp::{call_mcp_tool, fetch_and_convert_mcp_tools};
+
+// ── Internal types ───────────────────────────────────────────────────
+
+struct GeminiFunctionCall {
+    name: String,
+    args: serde_json::Value,
+}
+
+// ── Public runtime struct ────────────────────────────────────────────
+
+/// Agent runtime that drives a Gemini model through the Google AI
+/// `generateContent` REST API.
+///
+/// The runtime:
+/// 1. Fetches MCP tool definitions from huskies' MCP server.
+/// 2. Converts them to Gemini function-calling format.
+/// 3. Sends the agent prompt + tools to the Gemini API.
+/// 4. Executes any requested function calls via MCP `tools/call`.
+/// 5. Loops until the model produces a text-only response or an error.
+/// 6. Tracks token usage from the API response metadata.
+pub struct GeminiRuntime {
+    /// Whether a stop has been requested.
+    cancelled: Arc<AtomicBool>,
+}
+
+impl GeminiRuntime {
+    pub fn new() -> Self {
+        Self {
+            cancelled: Arc::new(AtomicBool::new(false)),
+        }
+    }
+}
+
+impl AgentRuntime for GeminiRuntime {
+    async fn start(
+        &self,
+        ctx: RuntimeContext,
+        tx: broadcast::Sender<AgentEvent>,
+        event_log: Arc<Mutex<Vec<AgentEvent>>>,
+        log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
+    ) -> Result<RuntimeResult, String> {
+        let api_key = std::env::var("GOOGLE_AI_API_KEY").map_err(|_| {
+            "GOOGLE_AI_API_KEY environment variable is not set. \
+             Set it to your Google AI API key to use the Gemini runtime."
+                .to_string()
+        })?;
+
+        let model = if ctx.command.starts_with("gemini") {
+            // The pool puts the model into `command` for non-CLI runtimes,
+            // but also check args for a --model flag.
+            ctx.command.clone()
+        } else {
+            // Fall back to args: look for --model
+            ctx.args
+                .iter()
+                .position(|a| a == "--model")
+                .and_then(|i| ctx.args.get(i + 1))
+                .cloned()
+                .unwrap_or_else(|| "gemini-2.5-pro".to_string())
+        };
+
+        let mcp_port = ctx.mcp_port;
+        let mcp_base = format!("http://localhost:{mcp_port}/mcp");
+
+        let client = Client::new();
+        let cancelled = Arc::clone(&self.cancelled);
+
+        // Step 1: Fetch MCP tool definitions and convert to Gemini format.
+        let gemini_tools = fetch_and_convert_mcp_tools(&client, &mcp_base).await?;
+
+        // Step 2: Build the initial conversation contents.
+        let system_instruction = build_system_instruction(&ctx);
+        let mut contents: Vec<serde_json::Value> = vec![json!({
+            "role": "user",
+            "parts": [{ "text": ctx.prompt }]
+        })];
+
+        let mut total_usage = TokenUsage {
+            input_tokens: 0,
+            output_tokens: 0,
+            cache_creation_input_tokens: 0,
+            cache_read_input_tokens: 0,
+            total_cost_usd: 0.0,
+        };
+
+        let emit = |event: AgentEvent| {
+            super::super::pty::emit_event(
+                event,
+                &tx,
+                &event_log,
+                log_writer.as_ref().map(|w| w.as_ref()),
+            );
+        };
+
+        emit(AgentEvent::Status {
+            story_id: ctx.story_id.clone(),
+            agent_name: ctx.agent_name.clone(),
+            status: "running".to_string(),
+        });
+
+        // Step 3: Conversation loop.
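+        // Each turn POSTs the accumulated `contents` to generateContent. A
+        // reply carrying functionCall parts is answered with matching
+        // functionResponse parts (built below) before the next turn, so a
+        // typical session reads: user prompt -> model functionCall ->
+        // user functionResponse -> ... -> model text-only -> Done.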
+        let mut turn = 0u32;
+        let max_turns = 200; // Safety limit
+
+        loop {
+            if cancelled.load(Ordering::Relaxed) {
+                emit(AgentEvent::Error {
+                    story_id: ctx.story_id.clone(),
+                    agent_name: ctx.agent_name.clone(),
+                    message: "Agent was stopped by user".to_string(),
+                });
+                return Ok(RuntimeResult {
+                    session_id: None,
+                    token_usage: Some(total_usage),
+                });
+            }
+
+            turn += 1;
+            if turn > max_turns {
+                emit(AgentEvent::Error {
+                    story_id: ctx.story_id.clone(),
+                    agent_name: ctx.agent_name.clone(),
+                    message: format!("Exceeded maximum turns ({max_turns})"),
+                });
+                return Ok(RuntimeResult {
+                    session_id: None,
+                    token_usage: Some(total_usage),
+                });
+            }
+
+            slog!(
+                "[gemini] Turn {turn} for {}:{}",
+                ctx.story_id,
+                ctx.agent_name
+            );
+
+            let request_body =
+                build_generate_content_request(&system_instruction, &contents, &gemini_tools);
+
+            let url = format!(
+                "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}"
+            );
+
+            let response = client
+                .post(&url)
+                .json(&request_body)
+                .send()
+                .await
+                .map_err(|e| format!("Gemini API request failed: {e}"))?;
+
+            let status = response.status();
+            let body: serde_json::Value = response
+                .json()
+                .await
+                .map_err(|e| format!("Failed to parse Gemini API response: {e}"))?;
+
+            if !status.is_success() {
+                let error_msg = body["error"]["message"]
+                    .as_str()
+                    .unwrap_or("Unknown API error");
+                let err = format!("Gemini API error ({status}): {error_msg}");
+                emit(AgentEvent::Error {
+                    story_id: ctx.story_id.clone(),
+                    agent_name: ctx.agent_name.clone(),
+                    message: err.clone(),
+                });
+                return Err(err);
+            }
+
+            // Accumulate token usage.
+            if let Some(usage) = parse_usage_metadata(&body) {
+                total_usage.input_tokens += usage.input_tokens;
+                total_usage.output_tokens += usage.output_tokens;
+            }
+
+            // Extract the candidate response.
+            let candidate = body["candidates"]
+                .as_array()
+                .and_then(|c| c.first())
+                .ok_or_else(|| "No candidates in Gemini response".to_string())?;
+
+            let parts = candidate["content"]["parts"]
+                .as_array()
+                .ok_or_else(|| "No parts in Gemini response candidate".to_string())?;
+
+            // Check finish reason.
+            let finish_reason = candidate["finishReason"].as_str().unwrap_or("");
+
+            // Separate text parts and function call parts.
+            let mut text_parts: Vec<String> = Vec::new();
+            let mut function_calls: Vec<GeminiFunctionCall> = Vec::new();
+
+            for part in parts {
+                if let Some(text) = part["text"].as_str() {
+                    text_parts.push(text.to_string());
+                }
+                if let Some(fc) = part.get("functionCall")
+                    && let (Some(name), Some(args)) = (fc["name"].as_str(), fc.get("args"))
+                {
+                    function_calls.push(GeminiFunctionCall {
+                        name: name.to_string(),
+                        args: args.clone(),
+                    });
+                }
+            }
+
+            // Emit any text output.
+            for text in &text_parts {
+                if !text.is_empty() {
+                    emit(AgentEvent::Output {
+                        story_id: ctx.story_id.clone(),
+                        agent_name: ctx.agent_name.clone(),
+                        text: text.clone(),
+                    });
+                }
+            }
+
+            // If no function calls, the model is done.
+            if function_calls.is_empty() {
+                emit(AgentEvent::Done {
+                    story_id: ctx.story_id.clone(),
+                    agent_name: ctx.agent_name.clone(),
+                    session_id: None,
+                });
+                return Ok(RuntimeResult {
+                    session_id: None,
+                    token_usage: Some(total_usage),
+                });
+            }
+
+            // Add the model's response to the conversation.
+            let model_parts: Vec<serde_json::Value> = parts.to_vec();
+            contents.push(json!({
+                "role": "model",
+                "parts": model_parts
+            }));
+
+            // Execute function calls via MCP and build response parts.
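+            // Gemini expects one functionResponse part per functionCall,
+            // carrying the tool output (or error) as a JSON object, e.g.
+            // { "functionResponse": { "name": "...", "response": { "result": "..." } } }.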
+            let mut response_parts: Vec<serde_json::Value> = Vec::new();
+
+            for fc in &function_calls {
+                if cancelled.load(Ordering::Relaxed) {
+                    break;
+                }
+
+                slog!(
+                    "[gemini] Calling MCP tool '{}' for {}:{}",
+                    fc.name,
+                    ctx.story_id,
+                    ctx.agent_name
+                );
+
+                emit(AgentEvent::Output {
+                    story_id: ctx.story_id.clone(),
+                    agent_name: ctx.agent_name.clone(),
+                    text: format!("\n[Tool call: {}]\n", fc.name),
+                });
+
+                let tool_result = call_mcp_tool(&client, &mcp_base, &fc.name, &fc.args).await;
+
+                let response_value = match &tool_result {
+                    Ok(result) => {
+                        emit(AgentEvent::Output {
+                            story_id: ctx.story_id.clone(),
+                            agent_name: ctx.agent_name.clone(),
+                            text: format!("[Tool result: {} chars]\n", result.len()),
+                        });
+                        json!({ "result": result })
+                    }
+                    Err(e) => {
+                        emit(AgentEvent::Output {
+                            story_id: ctx.story_id.clone(),
+                            agent_name: ctx.agent_name.clone(),
+                            text: format!("[Tool error: {e}]\n"),
+                        });
+                        json!({ "error": e })
+                    }
+                };
+
+                response_parts.push(json!({
+                    "functionResponse": {
+                        "name": fc.name,
+                        "response": response_value
+                    }
+                }));
+            }
+
+            // Add function responses to the conversation.
+            contents.push(json!({
+                "role": "user",
+                "parts": response_parts
+            }));
+
+            // If the model indicated it's done despite having function calls,
+            // respect the finish reason.
+            if finish_reason == "STOP" && function_calls.is_empty() {
+                break;
+            }
+        }
+
+        emit(AgentEvent::Done {
+            story_id: ctx.story_id.clone(),
+            agent_name: ctx.agent_name.clone(),
+            session_id: None,
+        });
+
+        Ok(RuntimeResult {
+            session_id: None,
+            token_usage: Some(total_usage),
+        })
+    }
+
+    fn stop(&self) {
+        self.cancelled.store(true, Ordering::Relaxed);
+    }
+
+    fn get_status(&self) -> RuntimeStatus {
+        if self.cancelled.load(Ordering::Relaxed) {
+            RuntimeStatus::Failed
+        } else {
+            RuntimeStatus::Idle
+        }
+    }
+}
+
+// ── Tests ────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn gemini_runtime_stop_sets_cancelled() {
+        let runtime = GeminiRuntime::new();
+        assert_eq!(runtime.get_status(), RuntimeStatus::Idle);
+        runtime.stop();
+        assert_eq!(runtime.get_status(), RuntimeStatus::Failed);
+    }
+
+    #[test]
+    fn model_extraction_from_command() {
+        // When command starts with "gemini", use it as model name
+        let ctx = RuntimeContext {
+            story_id: "1".to_string(),
+            agent_name: "coder".to_string(),
+            command: "gemini-2.5-pro".to_string(),
+            args: vec![],
+            prompt: "test".to_string(),
+            cwd: "/tmp".to_string(),
+            inactivity_timeout_secs: 300,
+            mcp_port: 3001,
+            session_id_to_resume: None,
+            fresh_prompt: None,
+        };
+
+        // The model extraction logic is inside start(), but we test the
+        // condition here.
+        assert!(ctx.command.starts_with("gemini"));
+    }
+}
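
Reviewer note (not part of the patch): a minimal standalone sketch that speaks
the same MCP JSON-RPC shapes as mcp.rs, assuming a local server on port 3001
and a hypothetical `update_story` tool. Handy for poking the endpoint while
reviewing; needs tokio, serde_json, and reqwest with the "json" feature.

    use reqwest::Client;
    use serde_json::{Value, json};

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        let client = Client::new();
        let mcp_base = "http://localhost:3001/mcp"; // assumed port

        // Same request shape fetch_and_convert_mcp_tools sends.
        let list: Value = client
            .post(mcp_base)
            .json(&json!({ "jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {} }))
            .send()
            .await?
            .json()
            .await?;
        println!("tools: {}", list["result"]["tools"]);

        // Same shape call_mcp_tool sends; the tool name is hypothetical.
        let call: Value = client
            .post(mcp_base)
            .json(&json!({
                "jsonrpc": "2.0", "id": 2, "method": "tools/call",
                "params": { "name": "update_story", "arguments": { "story_id": "1" } }
            }))
            .send()
            .await?
            .json()
            .await?;
        println!("result: {}", call["result"]["content"]);
        Ok(())
    }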