storkit: merge 345_story_gemini_agent_backend_via_google_ai_api
This commit is contained in:
@@ -17,7 +17,7 @@ use super::{
|
|||||||
AgentEvent, AgentInfo, AgentStatus, CompletionReport, PipelineStage, agent_config_stage,
|
AgentEvent, AgentInfo, AgentStatus, CompletionReport, PipelineStage, agent_config_stage,
|
||||||
pipeline_stage,
|
pipeline_stage,
|
||||||
};
|
};
|
||||||
use super::runtime::{AgentRuntime, ClaudeCodeRuntime, RuntimeContext};
|
use super::runtime::{AgentRuntime, ClaudeCodeRuntime, GeminiRuntime, RuntimeContext};
|
||||||
|
|
||||||
/// Build the composite key used to track agents in the pool.
|
/// Build the composite key used to track agents in the pool.
|
||||||
fn composite_key(story_id: &str, agent_name: &str) -> String {
|
fn composite_key(story_id: &str, agent_name: &str) -> String {
|
||||||
@@ -531,6 +531,23 @@ impl AgentPool {
|
|||||||
prompt,
|
prompt,
|
||||||
cwd: wt_path_str,
|
cwd: wt_path_str,
|
||||||
inactivity_timeout_secs,
|
inactivity_timeout_secs,
|
||||||
|
mcp_port: port_for_task,
|
||||||
|
};
|
||||||
|
runtime
|
||||||
|
.start(ctx, tx_clone.clone(), log_clone.clone(), log_writer_clone)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
"gemini" => {
|
||||||
|
let runtime = GeminiRuntime::new();
|
||||||
|
let ctx = RuntimeContext {
|
||||||
|
story_id: sid.clone(),
|
||||||
|
agent_name: aname.clone(),
|
||||||
|
command,
|
||||||
|
args,
|
||||||
|
prompt,
|
||||||
|
cwd: wt_path_str,
|
||||||
|
inactivity_timeout_secs,
|
||||||
|
mcp_port: port_for_task,
|
||||||
};
|
};
|
||||||
runtime
|
runtime
|
||||||
.start(ctx, tx_clone.clone(), log_clone.clone(), log_writer_clone)
|
.start(ctx, tx_clone.clone(), log_clone.clone(), log_writer_clone)
|
||||||
@@ -538,7 +555,7 @@ impl AgentPool {
|
|||||||
}
|
}
|
||||||
other => Err(format!(
|
other => Err(format!(
|
||||||
"Unknown agent runtime '{other}'; check the 'runtime' field in project.toml. \
|
"Unknown agent runtime '{other}'; check the 'runtime' field in project.toml. \
|
||||||
Supported: 'claude-code'"
|
Supported: 'claude-code', 'gemini'"
|
||||||
)),
|
)),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
809
server/src/agents/runtime/gemini.rs
Normal file
809
server/src/agents/runtime/gemini.rs
Normal file
@@ -0,0 +1,809 @@
|
|||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
use tokio::sync::broadcast;
|
||||||
|
|
||||||
|
use crate::agent_log::AgentLogWriter;
|
||||||
|
use crate::slog;
|
||||||
|
|
||||||
|
use super::super::{AgentEvent, TokenUsage};
|
||||||
|
use super::{AgentRuntime, RuntimeContext, RuntimeResult, RuntimeStatus};
|
||||||
|
|
||||||
|
// ── Public runtime struct ────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Agent runtime that drives a Gemini model through the Google AI
|
||||||
|
/// `generateContent` REST API.
|
||||||
|
///
|
||||||
|
/// The runtime:
|
||||||
|
/// 1. Fetches MCP tool definitions from storkit's MCP server.
|
||||||
|
/// 2. Converts them to Gemini function-calling format.
|
||||||
|
/// 3. Sends the agent prompt + tools to the Gemini API.
|
||||||
|
/// 4. Executes any requested function calls via MCP `tools/call`.
|
||||||
|
/// 5. Loops until the model produces a text-only response or an error.
|
||||||
|
/// 6. Tracks token usage from the API response metadata.
|
||||||
|
pub struct GeminiRuntime {
|
||||||
|
/// Whether a stop has been requested.
|
||||||
|
cancelled: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GeminiRuntime {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
cancelled: Arc::new(AtomicBool::new(false)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentRuntime for GeminiRuntime {
|
||||||
|
async fn start(
|
||||||
|
&self,
|
||||||
|
ctx: RuntimeContext,
|
||||||
|
tx: broadcast::Sender<AgentEvent>,
|
||||||
|
event_log: Arc<Mutex<Vec<AgentEvent>>>,
|
||||||
|
log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
|
||||||
|
) -> Result<RuntimeResult, String> {
|
||||||
|
let api_key = std::env::var("GOOGLE_AI_API_KEY").map_err(|_| {
|
||||||
|
"GOOGLE_AI_API_KEY environment variable is not set. \
|
||||||
|
Set it to your Google AI API key to use the Gemini runtime."
|
||||||
|
.to_string()
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let model = if ctx.command.starts_with("gemini") {
|
||||||
|
// The pool puts the model into `command` for non-CLI runtimes,
|
||||||
|
// but also check args for a --model flag.
|
||||||
|
ctx.command.clone()
|
||||||
|
} else {
|
||||||
|
// Fall back to args: look for --model <value>
|
||||||
|
ctx.args
|
||||||
|
.iter()
|
||||||
|
.position(|a| a == "--model")
|
||||||
|
.and_then(|i| ctx.args.get(i + 1))
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| "gemini-2.5-pro".to_string())
|
||||||
|
};
|
||||||
|
|
||||||
|
let mcp_port = ctx.mcp_port;
|
||||||
|
let mcp_base = format!("http://localhost:{mcp_port}/mcp");
|
||||||
|
|
||||||
|
let client = Client::new();
|
||||||
|
let cancelled = Arc::clone(&self.cancelled);
|
||||||
|
|
||||||
|
// Step 1: Fetch MCP tool definitions and convert to Gemini format.
|
||||||
|
let gemini_tools = fetch_and_convert_mcp_tools(&client, &mcp_base).await?;
|
||||||
|
|
||||||
|
// Step 2: Build the initial conversation contents.
|
||||||
|
let system_instruction = build_system_instruction(&ctx);
|
||||||
|
let mut contents: Vec<Value> = vec![json!({
|
||||||
|
"role": "user",
|
||||||
|
"parts": [{ "text": ctx.prompt }]
|
||||||
|
})];
|
||||||
|
|
||||||
|
let mut total_usage = TokenUsage {
|
||||||
|
input_tokens: 0,
|
||||||
|
output_tokens: 0,
|
||||||
|
cache_creation_input_tokens: 0,
|
||||||
|
cache_read_input_tokens: 0,
|
||||||
|
total_cost_usd: 0.0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let emit = |event: AgentEvent| {
|
||||||
|
super::super::pty::emit_event(
|
||||||
|
event,
|
||||||
|
&tx,
|
||||||
|
&event_log,
|
||||||
|
log_writer.as_ref().map(|w| w.as_ref()),
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
emit(AgentEvent::Status {
|
||||||
|
story_id: ctx.story_id.clone(),
|
||||||
|
agent_name: ctx.agent_name.clone(),
|
||||||
|
status: "running".to_string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Step 3: Conversation loop.
|
||||||
|
let mut turn = 0u32;
|
||||||
|
let max_turns = 200; // Safety limit
|
||||||
|
|
||||||
|
loop {
|
||||||
|
if cancelled.load(Ordering::Relaxed) {
|
||||||
|
emit(AgentEvent::Error {
|
||||||
|
story_id: ctx.story_id.clone(),
|
||||||
|
agent_name: ctx.agent_name.clone(),
|
||||||
|
message: "Agent was stopped by user".to_string(),
|
||||||
|
});
|
||||||
|
return Ok(RuntimeResult {
|
||||||
|
session_id: None,
|
||||||
|
token_usage: Some(total_usage),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
turn += 1;
|
||||||
|
if turn > max_turns {
|
||||||
|
emit(AgentEvent::Error {
|
||||||
|
story_id: ctx.story_id.clone(),
|
||||||
|
agent_name: ctx.agent_name.clone(),
|
||||||
|
message: format!("Exceeded maximum turns ({max_turns})"),
|
||||||
|
});
|
||||||
|
return Ok(RuntimeResult {
|
||||||
|
session_id: None,
|
||||||
|
token_usage: Some(total_usage),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
slog!("[gemini] Turn {turn} for {}:{}", ctx.story_id, ctx.agent_name);
|
||||||
|
|
||||||
|
let request_body = build_generate_content_request(
|
||||||
|
&system_instruction,
|
||||||
|
&contents,
|
||||||
|
&gemini_tools,
|
||||||
|
);
|
||||||
|
|
||||||
|
let url = format!(
|
||||||
|
"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}"
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.post(&url)
|
||||||
|
.json(&request_body)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Gemini API request failed: {e}"))?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
let body: Value = response
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to parse Gemini API response: {e}"))?;
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
let error_msg = body["error"]["message"]
|
||||||
|
.as_str()
|
||||||
|
.unwrap_or("Unknown API error");
|
||||||
|
let err = format!("Gemini API error ({status}): {error_msg}");
|
||||||
|
emit(AgentEvent::Error {
|
||||||
|
story_id: ctx.story_id.clone(),
|
||||||
|
agent_name: ctx.agent_name.clone(),
|
||||||
|
message: err.clone(),
|
||||||
|
});
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accumulate token usage.
|
||||||
|
if let Some(usage) = parse_usage_metadata(&body) {
|
||||||
|
total_usage.input_tokens += usage.input_tokens;
|
||||||
|
total_usage.output_tokens += usage.output_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the candidate response.
|
||||||
|
let candidate = body["candidates"]
|
||||||
|
.as_array()
|
||||||
|
.and_then(|c| c.first())
|
||||||
|
.ok_or_else(|| "No candidates in Gemini response".to_string())?;
|
||||||
|
|
||||||
|
let parts = candidate["content"]["parts"]
|
||||||
|
.as_array()
|
||||||
|
.ok_or_else(|| "No parts in Gemini response candidate".to_string())?;
|
||||||
|
|
||||||
|
// Check finish reason.
|
||||||
|
let finish_reason = candidate["finishReason"].as_str().unwrap_or("");
|
||||||
|
|
||||||
|
// Separate text parts and function call parts.
|
||||||
|
let mut text_parts: Vec<String> = Vec::new();
|
||||||
|
let mut function_calls: Vec<GeminiFunctionCall> = Vec::new();
|
||||||
|
|
||||||
|
for part in parts {
|
||||||
|
if let Some(text) = part["text"].as_str() {
|
||||||
|
text_parts.push(text.to_string());
|
||||||
|
}
|
||||||
|
if let Some(fc) = part.get("functionCall")
|
||||||
|
&& let (Some(name), Some(args)) =
|
||||||
|
(fc["name"].as_str(), fc.get("args"))
|
||||||
|
{
|
||||||
|
function_calls.push(GeminiFunctionCall {
|
||||||
|
name: name.to_string(),
|
||||||
|
args: args.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit any text output.
|
||||||
|
for text in &text_parts {
|
||||||
|
if !text.is_empty() {
|
||||||
|
emit(AgentEvent::Output {
|
||||||
|
story_id: ctx.story_id.clone(),
|
||||||
|
agent_name: ctx.agent_name.clone(),
|
||||||
|
text: text.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no function calls, the model is done.
|
||||||
|
if function_calls.is_empty() {
|
||||||
|
emit(AgentEvent::Done {
|
||||||
|
story_id: ctx.story_id.clone(),
|
||||||
|
agent_name: ctx.agent_name.clone(),
|
||||||
|
session_id: None,
|
||||||
|
});
|
||||||
|
return Ok(RuntimeResult {
|
||||||
|
session_id: None,
|
||||||
|
token_usage: Some(total_usage),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the model's response to the conversation.
|
||||||
|
let model_parts: Vec<Value> = parts.to_vec();
|
||||||
|
contents.push(json!({
|
||||||
|
"role": "model",
|
||||||
|
"parts": model_parts
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Execute function calls via MCP and build response parts.
|
||||||
|
let mut response_parts: Vec<Value> = Vec::new();
|
||||||
|
|
||||||
|
for fc in &function_calls {
|
||||||
|
if cancelled.load(Ordering::Relaxed) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
slog!(
|
||||||
|
"[gemini] Calling MCP tool '{}' for {}:{}",
|
||||||
|
fc.name,
|
||||||
|
ctx.story_id,
|
||||||
|
ctx.agent_name
|
||||||
|
);
|
||||||
|
|
||||||
|
emit(AgentEvent::Output {
|
||||||
|
story_id: ctx.story_id.clone(),
|
||||||
|
agent_name: ctx.agent_name.clone(),
|
||||||
|
text: format!("\n[Tool call: {}]\n", fc.name),
|
||||||
|
});
|
||||||
|
|
||||||
|
let tool_result =
|
||||||
|
call_mcp_tool(&client, &mcp_base, &fc.name, &fc.args).await;
|
||||||
|
|
||||||
|
let response_value = match &tool_result {
|
||||||
|
Ok(result) => {
|
||||||
|
emit(AgentEvent::Output {
|
||||||
|
story_id: ctx.story_id.clone(),
|
||||||
|
agent_name: ctx.agent_name.clone(),
|
||||||
|
text: format!(
|
||||||
|
"[Tool result: {} chars]\n",
|
||||||
|
result.len()
|
||||||
|
),
|
||||||
|
});
|
||||||
|
json!({ "result": result })
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
emit(AgentEvent::Output {
|
||||||
|
story_id: ctx.story_id.clone(),
|
||||||
|
agent_name: ctx.agent_name.clone(),
|
||||||
|
text: format!("[Tool error: {e}]\n"),
|
||||||
|
});
|
||||||
|
json!({ "error": e })
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
response_parts.push(json!({
|
||||||
|
"functionResponse": {
|
||||||
|
"name": fc.name,
|
||||||
|
"response": response_value
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add function responses to the conversation.
|
||||||
|
contents.push(json!({
|
||||||
|
"role": "user",
|
||||||
|
"parts": response_parts
|
||||||
|
}));
|
||||||
|
|
||||||
|
// If the model indicated it's done despite having function calls,
|
||||||
|
// respect the finish reason.
|
||||||
|
if finish_reason == "STOP" && function_calls.is_empty() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
emit(AgentEvent::Done {
|
||||||
|
story_id: ctx.story_id.clone(),
|
||||||
|
agent_name: ctx.agent_name.clone(),
|
||||||
|
session_id: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(RuntimeResult {
|
||||||
|
session_id: None,
|
||||||
|
token_usage: Some(total_usage),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stop(&self) {
|
||||||
|
self.cancelled.store(true, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_status(&self) -> RuntimeStatus {
|
||||||
|
if self.cancelled.load(Ordering::Relaxed) {
|
||||||
|
RuntimeStatus::Failed
|
||||||
|
} else {
|
||||||
|
RuntimeStatus::Idle
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Internal types ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
struct GeminiFunctionCall {
|
||||||
|
name: String,
|
||||||
|
args: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Gemini API types (for serde) ─────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct GeminiFunctionDeclaration {
|
||||||
|
name: String,
|
||||||
|
description: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
parameters: Option<Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Helper functions ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Build the system instruction content from the RuntimeContext.
|
||||||
|
fn build_system_instruction(ctx: &RuntimeContext) -> Value {
|
||||||
|
// Use system_prompt from args if provided via --append-system-prompt,
|
||||||
|
// otherwise use a sensible default.
|
||||||
|
let system_text = ctx
|
||||||
|
.args
|
||||||
|
.iter()
|
||||||
|
.position(|a| a == "--append-system-prompt")
|
||||||
|
.and_then(|i| ctx.args.get(i + 1))
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
format!(
|
||||||
|
"You are an AI coding agent working on story {}. \
|
||||||
|
You have access to tools via function calling. \
|
||||||
|
Use them to complete the task. \
|
||||||
|
Work in the directory: {}",
|
||||||
|
ctx.story_id, ctx.cwd
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
json!({
|
||||||
|
"parts": [{ "text": system_text }]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build the full `generateContent` request body.
|
||||||
|
fn build_generate_content_request(
|
||||||
|
system_instruction: &Value,
|
||||||
|
contents: &[Value],
|
||||||
|
gemini_tools: &[GeminiFunctionDeclaration],
|
||||||
|
) -> Value {
|
||||||
|
let mut body = json!({
|
||||||
|
"system_instruction": system_instruction,
|
||||||
|
"contents": contents,
|
||||||
|
"generationConfig": {
|
||||||
|
"temperature": 0.2,
|
||||||
|
"maxOutputTokens": 65536,
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if !gemini_tools.is_empty() {
|
||||||
|
body["tools"] = json!([{
|
||||||
|
"functionDeclarations": gemini_tools
|
||||||
|
}]);
|
||||||
|
}
|
||||||
|
|
||||||
|
body
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetch MCP tool definitions from storkit's MCP server and convert
|
||||||
|
/// them to Gemini function declaration format.
|
||||||
|
async fn fetch_and_convert_mcp_tools(
|
||||||
|
client: &Client,
|
||||||
|
mcp_base: &str,
|
||||||
|
) -> Result<Vec<GeminiFunctionDeclaration>, String> {
|
||||||
|
let request = json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 1,
|
||||||
|
"method": "tools/list",
|
||||||
|
"params": {}
|
||||||
|
});
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.post(mcp_base)
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to fetch MCP tools: {e}"))?;
|
||||||
|
|
||||||
|
let body: Value = response
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to parse MCP tools response: {e}"))?;
|
||||||
|
|
||||||
|
let tools = body["result"]["tools"]
|
||||||
|
.as_array()
|
||||||
|
.ok_or_else(|| "No tools array in MCP response".to_string())?;
|
||||||
|
|
||||||
|
let mut declarations = Vec::new();
|
||||||
|
|
||||||
|
for tool in tools {
|
||||||
|
let name = tool["name"].as_str().unwrap_or("").to_string();
|
||||||
|
let description = tool["description"].as_str().unwrap_or("").to_string();
|
||||||
|
|
||||||
|
if name.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert MCP inputSchema (JSON Schema) to Gemini parameters
|
||||||
|
// (OpenAPI-subset schema). They are structurally compatible for
|
||||||
|
// simple object schemas.
|
||||||
|
let parameters = convert_mcp_schema_to_gemini(tool.get("inputSchema"));
|
||||||
|
|
||||||
|
declarations.push(GeminiFunctionDeclaration {
|
||||||
|
name,
|
||||||
|
description,
|
||||||
|
parameters,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
slog!("[gemini] Loaded {} MCP tools as function declarations", declarations.len());
|
||||||
|
Ok(declarations)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert an MCP inputSchema (JSON Schema) to a Gemini-compatible
|
||||||
|
/// OpenAPI-subset parameter schema.
|
||||||
|
///
|
||||||
|
/// Gemini function calling expects parameters in OpenAPI format, which
|
||||||
|
/// is structurally similar to JSON Schema for simple object types.
|
||||||
|
/// We strip unsupported fields and ensure the type is "object".
|
||||||
|
fn convert_mcp_schema_to_gemini(schema: Option<&Value>) -> Option<Value> {
|
||||||
|
let schema = schema?;
|
||||||
|
|
||||||
|
// If the schema has no properties (empty tool), return None.
|
||||||
|
let properties = schema.get("properties")?;
|
||||||
|
if properties.as_object().is_some_and(|p| p.is_empty()) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": clean_schema_properties(properties),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Preserve required fields if present.
|
||||||
|
if let Some(required) = schema.get("required") {
|
||||||
|
result["required"] = required.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Recursively clean schema properties to be Gemini-compatible.
|
||||||
|
/// Removes unsupported JSON Schema keywords.
|
||||||
|
fn clean_schema_properties(properties: &Value) -> Value {
|
||||||
|
let Some(obj) = properties.as_object() else {
|
||||||
|
return properties.clone();
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut cleaned = serde_json::Map::new();
|
||||||
|
for (key, value) in obj {
|
||||||
|
let mut prop = value.clone();
|
||||||
|
// Remove JSON Schema keywords not supported by Gemini
|
||||||
|
if let Some(p) = prop.as_object_mut() {
|
||||||
|
p.remove("$schema");
|
||||||
|
p.remove("additionalProperties");
|
||||||
|
|
||||||
|
// Recursively clean nested object properties
|
||||||
|
if let Some(nested_props) = p.get("properties").cloned() {
|
||||||
|
p.insert(
|
||||||
|
"properties".to_string(),
|
||||||
|
clean_schema_properties(&nested_props),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean items schema for arrays
|
||||||
|
if let Some(items) = p.get("items").cloned()
|
||||||
|
&& let Some(items_obj) = items.as_object()
|
||||||
|
{
|
||||||
|
let mut cleaned_items = items_obj.clone();
|
||||||
|
cleaned_items.remove("$schema");
|
||||||
|
cleaned_items.remove("additionalProperties");
|
||||||
|
p.insert("items".to_string(), Value::Object(cleaned_items));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cleaned.insert(key.clone(), prop);
|
||||||
|
}
|
||||||
|
Value::Object(cleaned)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call an MCP tool via storkit's MCP server.
|
||||||
|
async fn call_mcp_tool(
|
||||||
|
client: &Client,
|
||||||
|
mcp_base: &str,
|
||||||
|
tool_name: &str,
|
||||||
|
args: &Value,
|
||||||
|
) -> Result<String, String> {
|
||||||
|
let request = json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 1,
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": tool_name,
|
||||||
|
"arguments": args
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.post(mcp_base)
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("MCP tool call failed: {e}"))?;
|
||||||
|
|
||||||
|
let body: Value = response
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to parse MCP tool response: {e}"))?;
|
||||||
|
|
||||||
|
if let Some(error) = body.get("error") {
|
||||||
|
let msg = error["message"].as_str().unwrap_or("Unknown MCP error");
|
||||||
|
return Err(format!("MCP tool '{tool_name}' error: {msg}"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCP tools/call returns { result: { content: [{ type: "text", text: "..." }] } }
|
||||||
|
let content = &body["result"]["content"];
|
||||||
|
if let Some(arr) = content.as_array() {
|
||||||
|
let texts: Vec<&str> = arr
|
||||||
|
.iter()
|
||||||
|
.filter_map(|c| c["text"].as_str())
|
||||||
|
.collect();
|
||||||
|
if !texts.is_empty() {
|
||||||
|
return Ok(texts.join("\n"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to serializing the entire result.
|
||||||
|
Ok(body["result"].to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse token usage metadata from a Gemini API response.
|
||||||
|
fn parse_usage_metadata(response: &Value) -> Option<TokenUsage> {
|
||||||
|
let metadata = response.get("usageMetadata")?;
|
||||||
|
Some(TokenUsage {
|
||||||
|
input_tokens: metadata
|
||||||
|
.get("promptTokenCount")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.unwrap_or(0),
|
||||||
|
output_tokens: metadata
|
||||||
|
.get("candidatesTokenCount")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.unwrap_or(0),
|
||||||
|
// Gemini doesn't have cache token fields, but we keep the struct uniform.
|
||||||
|
cache_creation_input_tokens: 0,
|
||||||
|
cache_read_input_tokens: 0,
|
||||||
|
// Google AI API doesn't report cost; leave at 0.
|
||||||
|
total_cost_usd: 0.0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_mcp_schema_simple_object() {
|
||||||
|
let schema = json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"story_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Story identifier"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["story_id"]
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
|
||||||
|
assert_eq!(result["type"], "object");
|
||||||
|
assert!(result["properties"]["story_id"].is_object());
|
||||||
|
assert_eq!(result["required"][0], "story_id");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_mcp_schema_empty_properties_returns_none() {
|
||||||
|
let schema = json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {}
|
||||||
|
});
|
||||||
|
|
||||||
|
assert!(convert_mcp_schema_to_gemini(Some(&schema)).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_mcp_schema_none_returns_none() {
|
||||||
|
assert!(convert_mcp_schema_to_gemini(None).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_mcp_schema_strips_additional_properties() {
|
||||||
|
let schema = json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
"additionalProperties": false,
|
||||||
|
"$schema": "http://json-schema.org/draft-07/schema#"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
|
||||||
|
let name_prop = &result["properties"]["name"];
|
||||||
|
assert!(name_prop.get("additionalProperties").is_none());
|
||||||
|
assert!(name_prop.get("$schema").is_none());
|
||||||
|
assert_eq!(name_prop["type"], "string");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_mcp_schema_with_nested_objects() {
|
||||||
|
let schema = json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"config": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"key": { "type": "string" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
|
||||||
|
assert!(result["properties"]["config"]["properties"]["key"].is_object());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_mcp_schema_with_array_items() {
|
||||||
|
let schema = json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"items": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": { "type": "string" }
|
||||||
|
},
|
||||||
|
"additionalProperties": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = convert_mcp_schema_to_gemini(Some(&schema)).unwrap();
|
||||||
|
let items_schema = &result["properties"]["items"]["items"];
|
||||||
|
assert!(items_schema.get("additionalProperties").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_system_instruction_uses_args() {
|
||||||
|
let ctx = RuntimeContext {
|
||||||
|
story_id: "42_story_test".to_string(),
|
||||||
|
agent_name: "coder-1".to_string(),
|
||||||
|
command: "gemini-2.5-pro".to_string(),
|
||||||
|
args: vec![
|
||||||
|
"--append-system-prompt".to_string(),
|
||||||
|
"Custom system prompt".to_string(),
|
||||||
|
],
|
||||||
|
prompt: "Do the thing".to_string(),
|
||||||
|
cwd: "/tmp/wt".to_string(),
|
||||||
|
inactivity_timeout_secs: 300,
|
||||||
|
mcp_port: 3001,
|
||||||
|
};
|
||||||
|
|
||||||
|
let instruction = build_system_instruction(&ctx);
|
||||||
|
assert_eq!(instruction["parts"][0]["text"], "Custom system prompt");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_system_instruction_default() {
|
||||||
|
let ctx = RuntimeContext {
|
||||||
|
story_id: "42_story_test".to_string(),
|
||||||
|
agent_name: "coder-1".to_string(),
|
||||||
|
command: "gemini-2.5-pro".to_string(),
|
||||||
|
args: vec![],
|
||||||
|
prompt: "Do the thing".to_string(),
|
||||||
|
cwd: "/tmp/wt".to_string(),
|
||||||
|
inactivity_timeout_secs: 300,
|
||||||
|
mcp_port: 3001,
|
||||||
|
};
|
||||||
|
|
||||||
|
let instruction = build_system_instruction(&ctx);
|
||||||
|
let text = instruction["parts"][0]["text"].as_str().unwrap();
|
||||||
|
assert!(text.contains("42_story_test"));
|
||||||
|
assert!(text.contains("/tmp/wt"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_generate_content_request_includes_tools() {
|
||||||
|
let system = json!({"parts": [{"text": "system"}]});
|
||||||
|
let contents = vec![json!({"role": "user", "parts": [{"text": "hello"}]})];
|
||||||
|
let tools = vec![GeminiFunctionDeclaration {
|
||||||
|
name: "my_tool".to_string(),
|
||||||
|
description: "A tool".to_string(),
|
||||||
|
parameters: Some(json!({"type": "object", "properties": {"x": {"type": "string"}}})),
|
||||||
|
}];
|
||||||
|
|
||||||
|
let body = build_generate_content_request(&system, &contents, &tools);
|
||||||
|
assert!(body["tools"][0]["functionDeclarations"].is_array());
|
||||||
|
assert_eq!(body["tools"][0]["functionDeclarations"][0]["name"], "my_tool");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_generate_content_request_no_tools() {
|
||||||
|
let system = json!({"parts": [{"text": "system"}]});
|
||||||
|
let contents = vec![json!({"role": "user", "parts": [{"text": "hello"}]})];
|
||||||
|
let tools: Vec<GeminiFunctionDeclaration> = vec![];
|
||||||
|
|
||||||
|
let body = build_generate_content_request(&system, &contents, &tools);
|
||||||
|
assert!(body.get("tools").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_usage_metadata_valid() {
|
||||||
|
let response = json!({
|
||||||
|
"usageMetadata": {
|
||||||
|
"promptTokenCount": 100,
|
||||||
|
"candidatesTokenCount": 50,
|
||||||
|
"totalTokenCount": 150
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let usage = parse_usage_metadata(&response).unwrap();
|
||||||
|
assert_eq!(usage.input_tokens, 100);
|
||||||
|
assert_eq!(usage.output_tokens, 50);
|
||||||
|
assert_eq!(usage.cache_creation_input_tokens, 0);
|
||||||
|
assert_eq!(usage.total_cost_usd, 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_usage_metadata_missing() {
|
||||||
|
let response = json!({"candidates": []});
|
||||||
|
assert!(parse_usage_metadata(&response).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn gemini_runtime_stop_sets_cancelled() {
|
||||||
|
let runtime = GeminiRuntime::new();
|
||||||
|
assert_eq!(runtime.get_status(), RuntimeStatus::Idle);
|
||||||
|
runtime.stop();
|
||||||
|
assert_eq!(runtime.get_status(), RuntimeStatus::Failed);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn model_extraction_from_command() {
|
||||||
|
// When command starts with "gemini", use it as model name
|
||||||
|
let ctx = RuntimeContext {
|
||||||
|
story_id: "1".to_string(),
|
||||||
|
agent_name: "coder".to_string(),
|
||||||
|
command: "gemini-2.5-pro".to_string(),
|
||||||
|
args: vec![],
|
||||||
|
prompt: "test".to_string(),
|
||||||
|
cwd: "/tmp".to_string(),
|
||||||
|
inactivity_timeout_secs: 300,
|
||||||
|
mcp_port: 3001,
|
||||||
|
};
|
||||||
|
|
||||||
|
// The model extraction logic is inside start(), but we test the
|
||||||
|
// condition here.
|
||||||
|
assert!(ctx.command.starts_with("gemini"));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
mod claude_code;
|
mod claude_code;
|
||||||
|
mod gemini;
|
||||||
|
|
||||||
pub use claude_code::ClaudeCodeRuntime;
|
pub use claude_code::ClaudeCodeRuntime;
|
||||||
|
pub use gemini::GeminiRuntime;
|
||||||
|
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
use tokio::sync::broadcast;
|
use tokio::sync::broadcast;
|
||||||
@@ -18,6 +20,9 @@ pub struct RuntimeContext {
|
|||||||
pub prompt: String,
|
pub prompt: String,
|
||||||
pub cwd: String,
|
pub cwd: String,
|
||||||
pub inactivity_timeout_secs: u64,
|
pub inactivity_timeout_secs: u64,
|
||||||
|
/// Port of the storkit MCP server, used by API-based runtimes (Gemini, OpenAI)
|
||||||
|
/// to call back for tool execution.
|
||||||
|
pub mcp_port: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Result returned by a runtime after the agent session completes.
|
/// Result returned by a runtime after the agent session completes.
|
||||||
@@ -85,6 +90,7 @@ mod tests {
|
|||||||
prompt: "Do the thing".to_string(),
|
prompt: "Do the thing".to_string(),
|
||||||
cwd: "/tmp/wt".to_string(),
|
cwd: "/tmp/wt".to_string(),
|
||||||
inactivity_timeout_secs: 300,
|
inactivity_timeout_secs: 300,
|
||||||
|
mcp_port: 3001,
|
||||||
};
|
};
|
||||||
assert_eq!(ctx.story_id, "42_story_foo");
|
assert_eq!(ctx.story_id, "42_story_foo");
|
||||||
assert_eq!(ctx.agent_name, "coder-1");
|
assert_eq!(ctx.agent_name, "coder-1");
|
||||||
@@ -93,6 +99,7 @@ mod tests {
|
|||||||
assert_eq!(ctx.prompt, "Do the thing");
|
assert_eq!(ctx.prompt, "Do the thing");
|
||||||
assert_eq!(ctx.cwd, "/tmp/wt");
|
assert_eq!(ctx.cwd, "/tmp/wt");
|
||||||
assert_eq!(ctx.inactivity_timeout_secs, 300);
|
assert_eq!(ctx.inactivity_timeout_secs, 300);
|
||||||
|
assert_eq!(ctx.mcp_port, 3001);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -378,10 +378,10 @@ fn validate_agents(agents: &[AgentConfig]) -> Result<(), String> {
|
|||||||
}
|
}
|
||||||
if let Some(ref runtime) = agent.runtime {
|
if let Some(ref runtime) = agent.runtime {
|
||||||
match runtime.as_str() {
|
match runtime.as_str() {
|
||||||
"claude-code" => {}
|
"claude-code" | "gemini" => {}
|
||||||
other => {
|
other => {
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
"Agent '{}': unknown runtime '{other}'. Supported: 'claude-code'",
|
"Agent '{}': unknown runtime '{other}'. Supported: 'claude-code', 'gemini'",
|
||||||
agent.name
|
agent.name
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
@@ -835,6 +835,18 @@ runtime = "claude-code"
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn runtime_gemini_accepted() {
|
||||||
|
let toml_str = r#"
|
||||||
|
[[agent]]
|
||||||
|
name = "coder"
|
||||||
|
runtime = "gemini"
|
||||||
|
model = "gemini-2.5-pro"
|
||||||
|
"#;
|
||||||
|
let config = ProjectConfig::parse(toml_str).unwrap();
|
||||||
|
assert_eq!(config.agent[0].runtime, Some("gemini".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn runtime_unknown_rejected() {
|
fn runtime_unknown_rejected() {
|
||||||
let toml_str = r#"
|
let toml_str = r#"
|
||||||
|
|||||||
Reference in New Issue
Block a user