use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use tokio::sync::broadcast; use crate::agent_log::AgentLogWriter; use crate::slog; use super::super::{AgentEvent, TokenUsage}; use super::{AgentRuntime, RuntimeContext, RuntimeResult, RuntimeStatus}; // ── Public runtime struct ──────────────────────────────────────────── /// Agent runtime that drives a Gemini model through the Google AI /// `generateContent` REST API. /// /// The runtime: /// 1. Fetches MCP tool definitions from storkit's MCP server. /// 2. Converts them to Gemini function-calling format. /// 3. Sends the agent prompt + tools to the Gemini API. /// 4. Executes any requested function calls via MCP `tools/call`. /// 5. Loops until the model produces a text-only response or an error. /// 6. Tracks token usage from the API response metadata. pub struct GeminiRuntime { /// Whether a stop has been requested. cancelled: Arc, } impl GeminiRuntime { pub fn new() -> Self { Self { cancelled: Arc::new(AtomicBool::new(false)), } } } impl AgentRuntime for GeminiRuntime { async fn start( &self, ctx: RuntimeContext, tx: broadcast::Sender, event_log: Arc>>, log_writer: Option>>, ) -> Result { let api_key = std::env::var("GOOGLE_AI_API_KEY").map_err(|_| { "GOOGLE_AI_API_KEY environment variable is not set. \ Set it to your Google AI API key to use the Gemini runtime." .to_string() })?; let model = if ctx.command.starts_with("gemini") { // The pool puts the model into `command` for non-CLI runtimes, // but also check args for a --model flag. 
ctx.command.clone() } else { // Fall back to args: look for --model ctx.args .iter() .position(|a| a == "--model") .and_then(|i| ctx.args.get(i + 1)) .cloned() .unwrap_or_else(|| "gemini-2.5-pro".to_string()) }; let mcp_port = ctx.mcp_port; let mcp_base = format!("http://localhost:{mcp_port}/mcp"); let client = Client::new(); let cancelled = Arc::clone(&self.cancelled); // Step 1: Fetch MCP tool definitions and convert to Gemini format. let gemini_tools = fetch_and_convert_mcp_tools(&client, &mcp_base).await?; // Step 2: Build the initial conversation contents. let system_instruction = build_system_instruction(&ctx); let mut contents: Vec = vec![json!({ "role": "user", "parts": [{ "text": ctx.prompt }] })]; let mut total_usage = TokenUsage { input_tokens: 0, output_tokens: 0, cache_creation_input_tokens: 0, cache_read_input_tokens: 0, total_cost_usd: 0.0, }; let emit = |event: AgentEvent| { super::super::pty::emit_event( event, &tx, &event_log, log_writer.as_ref().map(|w| w.as_ref()), ); }; emit(AgentEvent::Status { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), status: "running".to_string(), }); // Step 3: Conversation loop. 
let mut turn = 0u32; let max_turns = 200; // Safety limit loop { if cancelled.load(Ordering::Relaxed) { emit(AgentEvent::Error { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), message: "Agent was stopped by user".to_string(), }); return Ok(RuntimeResult { session_id: None, token_usage: Some(total_usage), }); } turn += 1; if turn > max_turns { emit(AgentEvent::Error { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), message: format!("Exceeded maximum turns ({max_turns})"), }); return Ok(RuntimeResult { session_id: None, token_usage: Some(total_usage), }); } slog!("[gemini] Turn {turn} for {}:{}", ctx.story_id, ctx.agent_name); let request_body = build_generate_content_request( &system_instruction, &contents, &gemini_tools, ); let url = format!( "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}" ); let response = client .post(&url) .json(&request_body) .send() .await .map_err(|e| format!("Gemini API request failed: {e}"))?; let status = response.status(); let body: Value = response .json() .await .map_err(|e| format!("Failed to parse Gemini API response: {e}"))?; if !status.is_success() { let error_msg = body["error"]["message"] .as_str() .unwrap_or("Unknown API error"); let err = format!("Gemini API error ({status}): {error_msg}"); emit(AgentEvent::Error { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), message: err.clone(), }); return Err(err); } // Accumulate token usage. if let Some(usage) = parse_usage_metadata(&body) { total_usage.input_tokens += usage.input_tokens; total_usage.output_tokens += usage.output_tokens; } // Extract the candidate response. let candidate = body["candidates"] .as_array() .and_then(|c| c.first()) .ok_or_else(|| "No candidates in Gemini response".to_string())?; let parts = candidate["content"]["parts"] .as_array() .ok_or_else(|| "No parts in Gemini response candidate".to_string())?; // Check finish reason. 
let finish_reason = candidate["finishReason"].as_str().unwrap_or(""); // Separate text parts and function call parts. let mut text_parts: Vec = Vec::new(); let mut function_calls: Vec = Vec::new(); for part in parts { if let Some(text) = part["text"].as_str() { text_parts.push(text.to_string()); } if let Some(fc) = part.get("functionCall") && let (Some(name), Some(args)) = (fc["name"].as_str(), fc.get("args")) { function_calls.push(GeminiFunctionCall { name: name.to_string(), args: args.clone(), }); } } // Emit any text output. for text in &text_parts { if !text.is_empty() { emit(AgentEvent::Output { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), text: text.clone(), }); } } // If no function calls, the model is done. if function_calls.is_empty() { emit(AgentEvent::Done { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), session_id: None, }); return Ok(RuntimeResult { session_id: None, token_usage: Some(total_usage), }); } // Add the model's response to the conversation. let model_parts: Vec = parts.to_vec(); contents.push(json!({ "role": "model", "parts": model_parts })); // Execute function calls via MCP and build response parts. 
let mut response_parts: Vec = Vec::new(); for fc in &function_calls { if cancelled.load(Ordering::Relaxed) { break; } slog!( "[gemini] Calling MCP tool '{}' for {}:{}", fc.name, ctx.story_id, ctx.agent_name ); emit(AgentEvent::Output { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), text: format!("\n[Tool call: {}]\n", fc.name), }); let tool_result = call_mcp_tool(&client, &mcp_base, &fc.name, &fc.args).await; let response_value = match &tool_result { Ok(result) => { emit(AgentEvent::Output { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), text: format!( "[Tool result: {} chars]\n", result.len() ), }); json!({ "result": result }) } Err(e) => { emit(AgentEvent::Output { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), text: format!("[Tool error: {e}]\n"), }); json!({ "error": e }) } }; response_parts.push(json!({ "functionResponse": { "name": fc.name, "response": response_value } })); } // Add function responses to the conversation. contents.push(json!({ "role": "user", "parts": response_parts })); // If the model indicated it's done despite having function calls, // respect the finish reason. 
if finish_reason == "STOP" && function_calls.is_empty() { break; } } emit(AgentEvent::Done { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), session_id: None, }); Ok(RuntimeResult { session_id: None, token_usage: Some(total_usage), }) } fn stop(&self) { self.cancelled.store(true, Ordering::Relaxed); } fn get_status(&self) -> RuntimeStatus { if self.cancelled.load(Ordering::Relaxed) { RuntimeStatus::Failed } else { RuntimeStatus::Idle } } } // ── Internal types ─────────────────────────────────────────────────── struct GeminiFunctionCall { name: String, args: Value, } // ── Gemini API types (for serde) ───────────────────────────────────── #[derive(Debug, Serialize, Deserialize)] struct GeminiFunctionDeclaration { name: String, description: String, #[serde(skip_serializing_if = "Option::is_none")] parameters: Option, } // ── Helper functions ───────────────────────────────────────────────── /// Build the system instruction content from the RuntimeContext. fn build_system_instruction(ctx: &RuntimeContext) -> Value { // Use system_prompt from args if provided via --append-system-prompt, // otherwise use a sensible default. let system_text = ctx .args .iter() .position(|a| a == "--append-system-prompt") .and_then(|i| ctx.args.get(i + 1)) .cloned() .unwrap_or_else(|| { format!( "You are an AI coding agent working on story {}. \ You have access to tools via function calling. \ Use them to complete the task. \ Work in the directory: {}", ctx.story_id, ctx.cwd ) }); json!({ "parts": [{ "text": system_text }] }) } /// Build the full `generateContent` request body. 
fn build_generate_content_request(
    system_instruction: &Value,
    contents: &[Value],
    gemini_tools: &[GeminiFunctionDeclaration],
) -> Value {
    let mut body = json!({
        "system_instruction": system_instruction,
        "contents": contents,
        "generationConfig": {
            "temperature": 0.2,
            "maxOutputTokens": 65536,
        }
    });

    // Only attach `tools` when there is at least one declaration; an empty
    // tools array is rejected by the API.
    if !gemini_tools.is_empty() {
        body["tools"] = json!([{ "functionDeclarations": gemini_tools }]);
    }

    body
}

/// Fetch MCP tool definitions from storkit's MCP server and convert
/// them to Gemini function declaration format.
///
/// Sends a JSON-RPC `tools/list` request to `mcp_base`. Tools with an
/// empty/missing name are skipped; missing descriptions become "".
///
/// # Errors
/// Returns `Err` when the HTTP request fails, the response is not valid
/// JSON, or `result.tools` is absent.
async fn fetch_and_convert_mcp_tools(
    client: &Client,
    mcp_base: &str,
) -> Result<Vec<GeminiFunctionDeclaration>, String> {
    let request = json!({
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tools/list",
        "params": {}
    });

    let response = client
        .post(mcp_base)
        .json(&request)
        .send()
        .await
        .map_err(|e| format!("Failed to fetch MCP tools: {e}"))?;

    let body: Value = response
        .json()
        .await
        .map_err(|e| format!("Failed to parse MCP tools response: {e}"))?;

    let tools = body["result"]["tools"]
        .as_array()
        .ok_or_else(|| "No tools array in MCP response".to_string())?;

    let mut declarations = Vec::new();
    for tool in tools {
        let name = tool["name"].as_str().unwrap_or("").to_string();
        let description = tool["description"].as_str().unwrap_or("").to_string();
        if name.is_empty() {
            continue;
        }

        // Convert MCP inputSchema (JSON Schema) to Gemini parameters
        // (OpenAPI-subset schema). They are structurally compatible for
        // simple object schemas.
        let parameters = convert_mcp_schema_to_gemini(tool.get("inputSchema"));

        declarations.push(GeminiFunctionDeclaration {
            name,
            description,
            parameters,
        });
    }

    slog!("[gemini] Loaded {} MCP tools as function declarations", declarations.len());

    Ok(declarations)
}

/// Convert an MCP inputSchema (JSON Schema) to a Gemini-compatible
/// OpenAPI-subset parameter schema.
///
/// Gemini function calling expects parameters in OpenAPI format, which
/// is structurally similar to JSON Schema for simple object types.
/// We strip unsupported fields and ensure the type is "object".
///
/// Returns `None` when there is no schema or when the schema declares no
/// properties (parameterless tool) — in that case the `parameters` field is
/// omitted from the declaration entirely.
fn convert_mcp_schema_to_gemini(schema: Option<&Value>) -> Option<Value> {
    let schema = schema?;

    // If the schema has no properties (empty tool), return None.
    let properties = schema.get("properties")?;
    if properties.as_object().is_some_and(|p| p.is_empty()) {
        return None;
    }

    // The top-level type is forced to "object" regardless of what the MCP
    // schema declared; Gemini only accepts object parameter schemas.
    let mut result = json!({
        "type": "object",
        "properties": clean_schema_properties(properties),
    });

    // Preserve required fields if present.
    if let Some(required) = schema.get("required") {
        result["required"] = required.clone();
    }

    Some(result)
}

/// Recursively clean schema properties to be Gemini-compatible.
/// Removes unsupported JSON Schema keywords.
///
/// Non-object inputs are returned unchanged. Nested `properties` maps are
/// cleaned recursively; array `items` schemas are cleaned one level deep
/// only (NOTE(review): `$schema`/`additionalProperties` nested inside an
/// items schema's own `properties` are not stripped — confirm whether
/// deeper recursion is needed).
fn clean_schema_properties(properties: &Value) -> Value {
    let Some(obj) = properties.as_object() else {
        return properties.clone();
    };

    let mut cleaned = serde_json::Map::new();
    for (key, value) in obj {
        let mut prop = value.clone();

        // Remove JSON Schema keywords not supported by Gemini
        if let Some(p) = prop.as_object_mut() {
            p.remove("$schema");
            p.remove("additionalProperties");

            // Recursively clean nested object properties
            if let Some(nested_props) = p.get("properties").cloned() {
                p.insert(
                    "properties".to_string(),
                    clean_schema_properties(&nested_props),
                );
            }

            // Clean items schema for arrays
            if let Some(items) = p.get("items").cloned()
                && let Some(items_obj) = items.as_object()
            {
                let mut cleaned_items = items_obj.clone();
                cleaned_items.remove("$schema");
                cleaned_items.remove("additionalProperties");
                p.insert("items".to_string(), Value::Object(cleaned_items));
            }
        }

        cleaned.insert(key.clone(), prop);
    }

    Value::Object(cleaned)
}

/// Call an MCP tool via storkit's MCP server.
async fn call_mcp_tool( client: &Client, mcp_base: &str, tool_name: &str, args: &Value, ) -> Result { let request = json!({ "jsonrpc": "2.0", "id": 1, "method": "tools/call", "params": { "name": tool_name, "arguments": args } }); let response = client .post(mcp_base) .json(&request) .send() .await .map_err(|e| format!("MCP tool call failed: {e}"))?; let body: Value = response .json() .await .map_err(|e| format!("Failed to parse MCP tool response: {e}"))?; if let Some(error) = body.get("error") { let msg = error["message"].as_str().unwrap_or("Unknown MCP error"); return Err(format!("MCP tool '{tool_name}' error: {msg}")); } // MCP tools/call returns { result: { content: [{ type: "text", text: "..." }] } } let content = &body["result"]["content"]; if let Some(arr) = content.as_array() { let texts: Vec<&str> = arr .iter() .filter_map(|c| c["text"].as_str()) .collect(); if !texts.is_empty() { return Ok(texts.join("\n")); } } // Fall back to serializing the entire result. Ok(body["result"].to_string()) } /// Parse token usage metadata from a Gemini API response. fn parse_usage_metadata(response: &Value) -> Option { let metadata = response.get("usageMetadata")?; Some(TokenUsage { input_tokens: metadata .get("promptTokenCount") .and_then(|v| v.as_u64()) .unwrap_or(0), output_tokens: metadata .get("candidatesTokenCount") .and_then(|v| v.as_u64()) .unwrap_or(0), // Gemini doesn't have cache token fields, but we keep the struct uniform. cache_creation_input_tokens: 0, cache_read_input_tokens: 0, // Google AI API doesn't report cost; leave at 0. 
total_cost_usd: 0.0, }) } // ── Tests ──────────────────────────────────────────────────────────── #[cfg(test)] mod tests { use super::*; #[test] fn convert_mcp_schema_simple_object() { let schema = json!({ "type": "object", "properties": { "story_id": { "type": "string", "description": "Story identifier" } }, "required": ["story_id"] }); let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap(); assert_eq!(result["type"], "object"); assert!(result["properties"]["story_id"].is_object()); assert_eq!(result["required"][0], "story_id"); } #[test] fn convert_mcp_schema_empty_properties_returns_none() { let schema = json!({ "type": "object", "properties": {} }); assert!(convert_mcp_schema_to_gemini(Some(&schema)).is_none()); } #[test] fn convert_mcp_schema_none_returns_none() { assert!(convert_mcp_schema_to_gemini(None).is_none()); } #[test] fn convert_mcp_schema_strips_additional_properties() { let schema = json!({ "type": "object", "properties": { "name": { "type": "string", "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#" } } }); let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap(); let name_prop = &result["properties"]["name"]; assert!(name_prop.get("additionalProperties").is_none()); assert!(name_prop.get("$schema").is_none()); assert_eq!(name_prop["type"], "string"); } #[test] fn convert_mcp_schema_with_nested_objects() { let schema = json!({ "type": "object", "properties": { "config": { "type": "object", "properties": { "key": { "type": "string" } } } } }); let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap(); assert!(result["properties"]["config"]["properties"]["key"].is_object()); } #[test] fn convert_mcp_schema_with_array_items() { let schema = json!({ "type": "object", "properties": { "items": { "type": "array", "items": { "type": "object", "properties": { "name": { "type": "string" } }, "additionalProperties": false } } } }); let result = 
convert_mcp_schema_to_gemini(Some(&schema)).unwrap(); let items_schema = &result["properties"]["items"]["items"]; assert!(items_schema.get("additionalProperties").is_none()); } #[test] fn build_system_instruction_uses_args() { let ctx = RuntimeContext { story_id: "42_story_test".to_string(), agent_name: "coder-1".to_string(), command: "gemini-2.5-pro".to_string(), args: vec![ "--append-system-prompt".to_string(), "Custom system prompt".to_string(), ], prompt: "Do the thing".to_string(), cwd: "/tmp/wt".to_string(), inactivity_timeout_secs: 300, mcp_port: 3001, }; let instruction = build_system_instruction(&ctx); assert_eq!(instruction["parts"][0]["text"], "Custom system prompt"); } #[test] fn build_system_instruction_default() { let ctx = RuntimeContext { story_id: "42_story_test".to_string(), agent_name: "coder-1".to_string(), command: "gemini-2.5-pro".to_string(), args: vec![], prompt: "Do the thing".to_string(), cwd: "/tmp/wt".to_string(), inactivity_timeout_secs: 300, mcp_port: 3001, }; let instruction = build_system_instruction(&ctx); let text = instruction["parts"][0]["text"].as_str().unwrap(); assert!(text.contains("42_story_test")); assert!(text.contains("/tmp/wt")); } #[test] fn build_generate_content_request_includes_tools() { let system = json!({"parts": [{"text": "system"}]}); let contents = vec![json!({"role": "user", "parts": [{"text": "hello"}]})]; let tools = vec![GeminiFunctionDeclaration { name: "my_tool".to_string(), description: "A tool".to_string(), parameters: Some(json!({"type": "object", "properties": {"x": {"type": "string"}}})), }]; let body = build_generate_content_request(&system, &contents, &tools); assert!(body["tools"][0]["functionDeclarations"].is_array()); assert_eq!(body["tools"][0]["functionDeclarations"][0]["name"], "my_tool"); } #[test] fn build_generate_content_request_no_tools() { let system = json!({"parts": [{"text": "system"}]}); let contents = vec![json!({"role": "user", "parts": [{"text": "hello"}]})]; let tools: Vec = 
vec![]; let body = build_generate_content_request(&system, &contents, &tools); assert!(body.get("tools").is_none()); } #[test] fn parse_usage_metadata_valid() { let response = json!({ "usageMetadata": { "promptTokenCount": 100, "candidatesTokenCount": 50, "totalTokenCount": 150 } }); let usage = parse_usage_metadata(&response).unwrap(); assert_eq!(usage.input_tokens, 100); assert_eq!(usage.output_tokens, 50); assert_eq!(usage.cache_creation_input_tokens, 0); assert_eq!(usage.total_cost_usd, 0.0); } #[test] fn parse_usage_metadata_missing() { let response = json!({"candidates": []}); assert!(parse_usage_metadata(&response).is_none()); } #[test] fn gemini_runtime_stop_sets_cancelled() { let runtime = GeminiRuntime::new(); assert_eq!(runtime.get_status(), RuntimeStatus::Idle); runtime.stop(); assert_eq!(runtime.get_status(), RuntimeStatus::Failed); } #[test] fn model_extraction_from_command() { // When command starts with "gemini", use it as model name let ctx = RuntimeContext { story_id: "1".to_string(), agent_name: "coder".to_string(), command: "gemini-2.5-pro".to_string(), args: vec![], prompt: "test".to_string(), cwd: "/tmp".to_string(), inactivity_timeout_secs: 300, mcp_port: 3001, }; // The model extraction logic is inside start(), but we test the // condition here. assert!(ctx.command.starts_with("gemini")); } }