storkit: merge 345_story_gemini_agent_backend_via_google_ai_api
This commit is contained in:
@@ -17,7 +17,7 @@ use super::{
|
||||
AgentEvent, AgentInfo, AgentStatus, CompletionReport, PipelineStage, agent_config_stage,
|
||||
pipeline_stage,
|
||||
};
|
||||
use super::runtime::{AgentRuntime, ClaudeCodeRuntime, RuntimeContext};
|
||||
use super::runtime::{AgentRuntime, ClaudeCodeRuntime, GeminiRuntime, RuntimeContext};
|
||||
|
||||
/// Build the composite key used to track agents in the pool.
|
||||
fn composite_key(story_id: &str, agent_name: &str) -> String {
|
||||
@@ -531,6 +531,23 @@ impl AgentPool {
|
||||
prompt,
|
||||
cwd: wt_path_str,
|
||||
inactivity_timeout_secs,
|
||||
mcp_port: port_for_task,
|
||||
};
|
||||
runtime
|
||||
.start(ctx, tx_clone.clone(), log_clone.clone(), log_writer_clone)
|
||||
.await
|
||||
}
|
||||
"gemini" => {
|
||||
let runtime = GeminiRuntime::new();
|
||||
let ctx = RuntimeContext {
|
||||
story_id: sid.clone(),
|
||||
agent_name: aname.clone(),
|
||||
command,
|
||||
args,
|
||||
prompt,
|
||||
cwd: wt_path_str,
|
||||
inactivity_timeout_secs,
|
||||
mcp_port: port_for_task,
|
||||
};
|
||||
runtime
|
||||
.start(ctx, tx_clone.clone(), log_clone.clone(), log_writer_clone)
|
||||
@@ -538,7 +555,7 @@ impl AgentPool {
|
||||
}
|
||||
other => Err(format!(
|
||||
"Unknown agent runtime '{other}'; check the 'runtime' field in project.toml. \
|
||||
Supported: 'claude-code'"
|
||||
Supported: 'claude-code', 'gemini'"
|
||||
)),
|
||||
};
|
||||
|
||||
|
||||
809
server/src/agents/runtime/gemini.rs
Normal file
809
server/src/agents/runtime/gemini.rs
Normal file
@@ -0,0 +1,809 @@
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use crate::agent_log::AgentLogWriter;
|
||||
use crate::slog;
|
||||
|
||||
use super::super::{AgentEvent, TokenUsage};
|
||||
use super::{AgentRuntime, RuntimeContext, RuntimeResult, RuntimeStatus};
|
||||
|
||||
// ── Public runtime struct ────────────────────────────────────────────
|
||||
|
||||
/// Agent runtime that drives a Gemini model through the Google AI
/// `generateContent` REST API.
///
/// The runtime:
/// 1. Fetches MCP tool definitions from storkit's MCP server.
/// 2. Converts them to Gemini function-calling format.
/// 3. Sends the agent prompt + tools to the Gemini API.
/// 4. Executes any requested function calls via MCP `tools/call`.
/// 5. Loops until the model produces a text-only response or an error.
/// 6. Tracks token usage from the API response metadata.
pub struct GeminiRuntime {
    /// Whether a stop has been requested (set via `stop()`, polled by the
    /// conversation loop in `start()`).
    cancelled: Arc<AtomicBool>,
}

impl GeminiRuntime {
    /// Create a runtime with no stop requested yet.
    pub fn new() -> Self {
        Self {
            cancelled: Arc::new(AtomicBool::new(false)),
        }
    }
}

/// Mirrors [`GeminiRuntime::new`]; provided so the type works with generic
/// code expecting `Default` and satisfies clippy's `new_without_default`.
impl Default for GeminiRuntime {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
impl AgentRuntime for GeminiRuntime {
|
||||
async fn start(
|
||||
&self,
|
||||
ctx: RuntimeContext,
|
||||
tx: broadcast::Sender<AgentEvent>,
|
||||
event_log: Arc<Mutex<Vec<AgentEvent>>>,
|
||||
log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
|
||||
) -> Result<RuntimeResult, String> {
|
||||
let api_key = std::env::var("GOOGLE_AI_API_KEY").map_err(|_| {
|
||||
"GOOGLE_AI_API_KEY environment variable is not set. \
|
||||
Set it to your Google AI API key to use the Gemini runtime."
|
||||
.to_string()
|
||||
})?;
|
||||
|
||||
let model = if ctx.command.starts_with("gemini") {
|
||||
// The pool puts the model into `command` for non-CLI runtimes,
|
||||
// but also check args for a --model flag.
|
||||
ctx.command.clone()
|
||||
} else {
|
||||
// Fall back to args: look for --model <value>
|
||||
ctx.args
|
||||
.iter()
|
||||
.position(|a| a == "--model")
|
||||
.and_then(|i| ctx.args.get(i + 1))
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "gemini-2.5-pro".to_string())
|
||||
};
|
||||
|
||||
let mcp_port = ctx.mcp_port;
|
||||
let mcp_base = format!("http://localhost:{mcp_port}/mcp");
|
||||
|
||||
let client = Client::new();
|
||||
let cancelled = Arc::clone(&self.cancelled);
|
||||
|
||||
// Step 1: Fetch MCP tool definitions and convert to Gemini format.
|
||||
let gemini_tools = fetch_and_convert_mcp_tools(&client, &mcp_base).await?;
|
||||
|
||||
// Step 2: Build the initial conversation contents.
|
||||
let system_instruction = build_system_instruction(&ctx);
|
||||
let mut contents: Vec<Value> = vec![json!({
|
||||
"role": "user",
|
||||
"parts": [{ "text": ctx.prompt }]
|
||||
})];
|
||||
|
||||
let mut total_usage = TokenUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
total_cost_usd: 0.0,
|
||||
};
|
||||
|
||||
let emit = |event: AgentEvent| {
|
||||
super::super::pty::emit_event(
|
||||
event,
|
||||
&tx,
|
||||
&event_log,
|
||||
log_writer.as_ref().map(|w| w.as_ref()),
|
||||
);
|
||||
};
|
||||
|
||||
emit(AgentEvent::Status {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
status: "running".to_string(),
|
||||
});
|
||||
|
||||
// Step 3: Conversation loop.
|
||||
let mut turn = 0u32;
|
||||
let max_turns = 200; // Safety limit
|
||||
|
||||
loop {
|
||||
if cancelled.load(Ordering::Relaxed) {
|
||||
emit(AgentEvent::Error {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
message: "Agent was stopped by user".to_string(),
|
||||
});
|
||||
return Ok(RuntimeResult {
|
||||
session_id: None,
|
||||
token_usage: Some(total_usage),
|
||||
});
|
||||
}
|
||||
|
||||
turn += 1;
|
||||
if turn > max_turns {
|
||||
emit(AgentEvent::Error {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
message: format!("Exceeded maximum turns ({max_turns})"),
|
||||
});
|
||||
return Ok(RuntimeResult {
|
||||
session_id: None,
|
||||
token_usage: Some(total_usage),
|
||||
});
|
||||
}
|
||||
|
||||
slog!("[gemini] Turn {turn} for {}:{}", ctx.story_id, ctx.agent_name);
|
||||
|
||||
let request_body = build_generate_content_request(
|
||||
&system_instruction,
|
||||
&contents,
|
||||
&gemini_tools,
|
||||
);
|
||||
|
||||
let url = format!(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}"
|
||||
);
|
||||
|
||||
let response = client
|
||||
.post(&url)
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Gemini API request failed: {e}"))?;
|
||||
|
||||
let status = response.status();
|
||||
let body: Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse Gemini API response: {e}"))?;
|
||||
|
||||
if !status.is_success() {
|
||||
let error_msg = body["error"]["message"]
|
||||
.as_str()
|
||||
.unwrap_or("Unknown API error");
|
||||
let err = format!("Gemini API error ({status}): {error_msg}");
|
||||
emit(AgentEvent::Error {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
message: err.clone(),
|
||||
});
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
// Accumulate token usage.
|
||||
if let Some(usage) = parse_usage_metadata(&body) {
|
||||
total_usage.input_tokens += usage.input_tokens;
|
||||
total_usage.output_tokens += usage.output_tokens;
|
||||
}
|
||||
|
||||
// Extract the candidate response.
|
||||
let candidate = body["candidates"]
|
||||
.as_array()
|
||||
.and_then(|c| c.first())
|
||||
.ok_or_else(|| "No candidates in Gemini response".to_string())?;
|
||||
|
||||
let parts = candidate["content"]["parts"]
|
||||
.as_array()
|
||||
.ok_or_else(|| "No parts in Gemini response candidate".to_string())?;
|
||||
|
||||
// Check finish reason.
|
||||
let finish_reason = candidate["finishReason"].as_str().unwrap_or("");
|
||||
|
||||
// Separate text parts and function call parts.
|
||||
let mut text_parts: Vec<String> = Vec::new();
|
||||
let mut function_calls: Vec<GeminiFunctionCall> = Vec::new();
|
||||
|
||||
for part in parts {
|
||||
if let Some(text) = part["text"].as_str() {
|
||||
text_parts.push(text.to_string());
|
||||
}
|
||||
if let Some(fc) = part.get("functionCall")
|
||||
&& let (Some(name), Some(args)) =
|
||||
(fc["name"].as_str(), fc.get("args"))
|
||||
{
|
||||
function_calls.push(GeminiFunctionCall {
|
||||
name: name.to_string(),
|
||||
args: args.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Emit any text output.
|
||||
for text in &text_parts {
|
||||
if !text.is_empty() {
|
||||
emit(AgentEvent::Output {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
text: text.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// If no function calls, the model is done.
|
||||
if function_calls.is_empty() {
|
||||
emit(AgentEvent::Done {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
session_id: None,
|
||||
});
|
||||
return Ok(RuntimeResult {
|
||||
session_id: None,
|
||||
token_usage: Some(total_usage),
|
||||
});
|
||||
}
|
||||
|
||||
// Add the model's response to the conversation.
|
||||
let model_parts: Vec<Value> = parts.to_vec();
|
||||
contents.push(json!({
|
||||
"role": "model",
|
||||
"parts": model_parts
|
||||
}));
|
||||
|
||||
// Execute function calls via MCP and build response parts.
|
||||
let mut response_parts: Vec<Value> = Vec::new();
|
||||
|
||||
for fc in &function_calls {
|
||||
if cancelled.load(Ordering::Relaxed) {
|
||||
break;
|
||||
}
|
||||
|
||||
slog!(
|
||||
"[gemini] Calling MCP tool '{}' for {}:{}",
|
||||
fc.name,
|
||||
ctx.story_id,
|
||||
ctx.agent_name
|
||||
);
|
||||
|
||||
emit(AgentEvent::Output {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
text: format!("\n[Tool call: {}]\n", fc.name),
|
||||
});
|
||||
|
||||
let tool_result =
|
||||
call_mcp_tool(&client, &mcp_base, &fc.name, &fc.args).await;
|
||||
|
||||
let response_value = match &tool_result {
|
||||
Ok(result) => {
|
||||
emit(AgentEvent::Output {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
text: format!(
|
||||
"[Tool result: {} chars]\n",
|
||||
result.len()
|
||||
),
|
||||
});
|
||||
json!({ "result": result })
|
||||
}
|
||||
Err(e) => {
|
||||
emit(AgentEvent::Output {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
text: format!("[Tool error: {e}]\n"),
|
||||
});
|
||||
json!({ "error": e })
|
||||
}
|
||||
};
|
||||
|
||||
response_parts.push(json!({
|
||||
"functionResponse": {
|
||||
"name": fc.name,
|
||||
"response": response_value
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
// Add function responses to the conversation.
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
"parts": response_parts
|
||||
}));
|
||||
|
||||
// If the model indicated it's done despite having function calls,
|
||||
// respect the finish reason.
|
||||
if finish_reason == "STOP" && function_calls.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
emit(AgentEvent::Done {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
session_id: None,
|
||||
});
|
||||
|
||||
Ok(RuntimeResult {
|
||||
session_id: None,
|
||||
token_usage: Some(total_usage),
|
||||
})
|
||||
}
|
||||
|
||||
fn stop(&self) {
|
||||
self.cancelled.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn get_status(&self) -> RuntimeStatus {
|
||||
if self.cancelled.load(Ordering::Relaxed) {
|
||||
RuntimeStatus::Failed
|
||||
} else {
|
||||
RuntimeStatus::Idle
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Internal types ───────────────────────────────────────────────────
|
||||
|
||||
struct GeminiFunctionCall {
|
||||
name: String,
|
||||
args: Value,
|
||||
}
|
||||
|
||||
// ── Gemini API types (for serde) ─────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct GeminiFunctionDeclaration {
|
||||
name: String,
|
||||
description: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
parameters: Option<Value>,
|
||||
}
|
||||
|
||||
// ── Helper functions ─────────────────────────────────────────────────
|
||||
|
||||
/// Build the system instruction content from the RuntimeContext.
|
||||
fn build_system_instruction(ctx: &RuntimeContext) -> Value {
|
||||
// Use system_prompt from args if provided via --append-system-prompt,
|
||||
// otherwise use a sensible default.
|
||||
let system_text = ctx
|
||||
.args
|
||||
.iter()
|
||||
.position(|a| a == "--append-system-prompt")
|
||||
.and_then(|i| ctx.args.get(i + 1))
|
||||
.cloned()
|
||||
.unwrap_or_else(|| {
|
||||
format!(
|
||||
"You are an AI coding agent working on story {}. \
|
||||
You have access to tools via function calling. \
|
||||
Use them to complete the task. \
|
||||
Work in the directory: {}",
|
||||
ctx.story_id, ctx.cwd
|
||||
)
|
||||
});
|
||||
|
||||
json!({
|
||||
"parts": [{ "text": system_text }]
|
||||
})
|
||||
}
|
||||
|
||||
/// Build the full `generateContent` request body.
|
||||
fn build_generate_content_request(
|
||||
system_instruction: &Value,
|
||||
contents: &[Value],
|
||||
gemini_tools: &[GeminiFunctionDeclaration],
|
||||
) -> Value {
|
||||
let mut body = json!({
|
||||
"system_instruction": system_instruction,
|
||||
"contents": contents,
|
||||
"generationConfig": {
|
||||
"temperature": 0.2,
|
||||
"maxOutputTokens": 65536,
|
||||
}
|
||||
});
|
||||
|
||||
if !gemini_tools.is_empty() {
|
||||
body["tools"] = json!([{
|
||||
"functionDeclarations": gemini_tools
|
||||
}]);
|
||||
}
|
||||
|
||||
body
|
||||
}
|
||||
|
||||
/// Fetch MCP tool definitions from storkit's MCP server and convert
|
||||
/// them to Gemini function declaration format.
|
||||
async fn fetch_and_convert_mcp_tools(
|
||||
client: &Client,
|
||||
mcp_base: &str,
|
||||
) -> Result<Vec<GeminiFunctionDeclaration>, String> {
|
||||
let request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
});
|
||||
|
||||
let response = client
|
||||
.post(mcp_base)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to fetch MCP tools: {e}"))?;
|
||||
|
||||
let body: Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse MCP tools response: {e}"))?;
|
||||
|
||||
let tools = body["result"]["tools"]
|
||||
.as_array()
|
||||
.ok_or_else(|| "No tools array in MCP response".to_string())?;
|
||||
|
||||
let mut declarations = Vec::new();
|
||||
|
||||
for tool in tools {
|
||||
let name = tool["name"].as_str().unwrap_or("").to_string();
|
||||
let description = tool["description"].as_str().unwrap_or("").to_string();
|
||||
|
||||
if name.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Convert MCP inputSchema (JSON Schema) to Gemini parameters
|
||||
// (OpenAPI-subset schema). They are structurally compatible for
|
||||
// simple object schemas.
|
||||
let parameters = convert_mcp_schema_to_gemini(tool.get("inputSchema"));
|
||||
|
||||
declarations.push(GeminiFunctionDeclaration {
|
||||
name,
|
||||
description,
|
||||
parameters,
|
||||
});
|
||||
}
|
||||
|
||||
slog!("[gemini] Loaded {} MCP tools as function declarations", declarations.len());
|
||||
Ok(declarations)
|
||||
}
|
||||
|
||||
/// Convert an MCP inputSchema (JSON Schema) to a Gemini-compatible
|
||||
/// OpenAPI-subset parameter schema.
|
||||
///
|
||||
/// Gemini function calling expects parameters in OpenAPI format, which
|
||||
/// is structurally similar to JSON Schema for simple object types.
|
||||
/// We strip unsupported fields and ensure the type is "object".
|
||||
fn convert_mcp_schema_to_gemini(schema: Option<&Value>) -> Option<Value> {
|
||||
let schema = schema?;
|
||||
|
||||
// If the schema has no properties (empty tool), return None.
|
||||
let properties = schema.get("properties")?;
|
||||
if properties.as_object().is_some_and(|p| p.is_empty()) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut result = json!({
|
||||
"type": "object",
|
||||
"properties": clean_schema_properties(properties),
|
||||
});
|
||||
|
||||
// Preserve required fields if present.
|
||||
if let Some(required) = schema.get("required") {
|
||||
result["required"] = required.clone();
|
||||
}
|
||||
|
||||
Some(result)
|
||||
}
|
||||
|
||||
/// Recursively clean schema properties to be Gemini-compatible.
|
||||
/// Removes unsupported JSON Schema keywords.
|
||||
fn clean_schema_properties(properties: &Value) -> Value {
|
||||
let Some(obj) = properties.as_object() else {
|
||||
return properties.clone();
|
||||
};
|
||||
|
||||
let mut cleaned = serde_json::Map::new();
|
||||
for (key, value) in obj {
|
||||
let mut prop = value.clone();
|
||||
// Remove JSON Schema keywords not supported by Gemini
|
||||
if let Some(p) = prop.as_object_mut() {
|
||||
p.remove("$schema");
|
||||
p.remove("additionalProperties");
|
||||
|
||||
// Recursively clean nested object properties
|
||||
if let Some(nested_props) = p.get("properties").cloned() {
|
||||
p.insert(
|
||||
"properties".to_string(),
|
||||
clean_schema_properties(&nested_props),
|
||||
);
|
||||
}
|
||||
|
||||
// Clean items schema for arrays
|
||||
if let Some(items) = p.get("items").cloned()
|
||||
&& let Some(items_obj) = items.as_object()
|
||||
{
|
||||
let mut cleaned_items = items_obj.clone();
|
||||
cleaned_items.remove("$schema");
|
||||
cleaned_items.remove("additionalProperties");
|
||||
p.insert("items".to_string(), Value::Object(cleaned_items));
|
||||
}
|
||||
}
|
||||
cleaned.insert(key.clone(), prop);
|
||||
}
|
||||
Value::Object(cleaned)
|
||||
}
|
||||
|
||||
/// Call an MCP tool via storkit's MCP server.
|
||||
async fn call_mcp_tool(
|
||||
client: &Client,
|
||||
mcp_base: &str,
|
||||
tool_name: &str,
|
||||
args: &Value,
|
||||
) -> Result<String, String> {
|
||||
let request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": args
|
||||
}
|
||||
});
|
||||
|
||||
let response = client
|
||||
.post(mcp_base)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("MCP tool call failed: {e}"))?;
|
||||
|
||||
let body: Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse MCP tool response: {e}"))?;
|
||||
|
||||
if let Some(error) = body.get("error") {
|
||||
let msg = error["message"].as_str().unwrap_or("Unknown MCP error");
|
||||
return Err(format!("MCP tool '{tool_name}' error: {msg}"));
|
||||
}
|
||||
|
||||
// MCP tools/call returns { result: { content: [{ type: "text", text: "..." }] } }
|
||||
let content = &body["result"]["content"];
|
||||
if let Some(arr) = content.as_array() {
|
||||
let texts: Vec<&str> = arr
|
||||
.iter()
|
||||
.filter_map(|c| c["text"].as_str())
|
||||
.collect();
|
||||
if !texts.is_empty() {
|
||||
return Ok(texts.join("\n"));
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to serializing the entire result.
|
||||
Ok(body["result"].to_string())
|
||||
}
|
||||
|
||||
/// Parse token usage metadata from a Gemini API response.
|
||||
fn parse_usage_metadata(response: &Value) -> Option<TokenUsage> {
|
||||
let metadata = response.get("usageMetadata")?;
|
||||
Some(TokenUsage {
|
||||
input_tokens: metadata
|
||||
.get("promptTokenCount")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
output_tokens: metadata
|
||||
.get("candidatesTokenCount")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
// Gemini doesn't have cache token fields, but we keep the struct uniform.
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
// Google AI API doesn't report cost; leave at 0.
|
||||
total_cost_usd: 0.0,
|
||||
})
|
||||
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: a RuntimeContext with the given args.
    fn test_ctx(args: Vec<String>) -> RuntimeContext {
        RuntimeContext {
            story_id: "42_story_test".to_string(),
            agent_name: "coder-1".to_string(),
            command: "gemini-2.5-pro".to_string(),
            args,
            prompt: "Do the thing".to_string(),
            cwd: "/tmp/wt".to_string(),
            inactivity_timeout_secs: 300,
            mcp_port: 3001,
        }
    }

    #[test]
    fn convert_mcp_schema_simple_object() {
        let input = json!({
            "type": "object",
            "properties": {
                "story_id": {
                    "type": "string",
                    "description": "Story identifier"
                }
            },
            "required": ["story_id"]
        });

        let converted = convert_mcp_schema_to_gemini(Some(&input)).unwrap();
        assert_eq!(converted["type"], "object");
        assert!(converted["properties"]["story_id"].is_object());
        assert_eq!(converted["required"][0], "story_id");
    }

    #[test]
    fn convert_mcp_schema_empty_properties_returns_none() {
        let input = json!({ "type": "object", "properties": {} });
        assert!(convert_mcp_schema_to_gemini(Some(&input)).is_none());
    }

    #[test]
    fn convert_mcp_schema_none_returns_none() {
        assert!(convert_mcp_schema_to_gemini(None).is_none());
    }

    #[test]
    fn convert_mcp_schema_strips_additional_properties() {
        let input = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "additionalProperties": false,
                    "$schema": "http://json-schema.org/draft-07/schema#"
                }
            }
        });

        let converted = convert_mcp_schema_to_gemini(Some(&input)).unwrap();
        let name_prop = &converted["properties"]["name"];
        assert!(name_prop.get("additionalProperties").is_none());
        assert!(name_prop.get("$schema").is_none());
        assert_eq!(name_prop["type"], "string");
    }

    #[test]
    fn convert_mcp_schema_with_nested_objects() {
        let input = json!({
            "type": "object",
            "properties": {
                "config": {
                    "type": "object",
                    "properties": {
                        "key": { "type": "string" }
                    }
                }
            }
        });

        let converted = convert_mcp_schema_to_gemini(Some(&input)).unwrap();
        assert!(converted["properties"]["config"]["properties"]["key"].is_object());
    }

    #[test]
    fn convert_mcp_schema_with_array_items() {
        let input = json!({
            "type": "object",
            "properties": {
                "items": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "name": { "type": "string" }
                        },
                        "additionalProperties": false
                    }
                }
            }
        });

        let converted = convert_mcp_schema_to_gemini(Some(&input)).unwrap();
        let item_schema = &converted["properties"]["items"]["items"];
        assert!(item_schema.get("additionalProperties").is_none());
    }

    #[test]
    fn build_system_instruction_uses_args() {
        let ctx = test_ctx(vec![
            "--append-system-prompt".to_string(),
            "Custom system prompt".to_string(),
        ]);

        let instruction = build_system_instruction(&ctx);
        assert_eq!(instruction["parts"][0]["text"], "Custom system prompt");
    }

    #[test]
    fn build_system_instruction_default() {
        let instruction = build_system_instruction(&test_ctx(vec![]));
        let text = instruction["parts"][0]["text"].as_str().unwrap();
        assert!(text.contains("42_story_test"));
        assert!(text.contains("/tmp/wt"));
    }

    #[test]
    fn build_generate_content_request_includes_tools() {
        let system = json!({"parts": [{"text": "system"}]});
        let contents = vec![json!({"role": "user", "parts": [{"text": "hello"}]})];
        let tools = vec![GeminiFunctionDeclaration {
            name: "my_tool".to_string(),
            description: "A tool".to_string(),
            parameters: Some(json!({"type": "object", "properties": {"x": {"type": "string"}}})),
        }];

        let body = build_generate_content_request(&system, &contents, &tools);
        assert!(body["tools"][0]["functionDeclarations"].is_array());
        assert_eq!(body["tools"][0]["functionDeclarations"][0]["name"], "my_tool");
    }

    #[test]
    fn build_generate_content_request_no_tools() {
        let system = json!({"parts": [{"text": "system"}]});
        let contents = vec![json!({"role": "user", "parts": [{"text": "hello"}]})];

        let body = build_generate_content_request(&system, &contents, &[]);
        assert!(body.get("tools").is_none());
    }

    #[test]
    fn parse_usage_metadata_valid() {
        let response = json!({
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "totalTokenCount": 150
            }
        });

        let usage = parse_usage_metadata(&response).unwrap();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.cache_creation_input_tokens, 0);
        assert_eq!(usage.total_cost_usd, 0.0);
    }

    #[test]
    fn parse_usage_metadata_missing() {
        assert!(parse_usage_metadata(&json!({"candidates": []})).is_none());
    }

    #[test]
    fn gemini_runtime_stop_sets_cancelled() {
        let runtime = GeminiRuntime::new();
        assert_eq!(runtime.get_status(), RuntimeStatus::Idle);
        runtime.stop();
        assert_eq!(runtime.get_status(), RuntimeStatus::Failed);
    }

    #[test]
    fn model_extraction_from_command() {
        // start() uses `command` as the model name when it begins with
        // "gemini"; verify the condition it branches on.
        let ctx = test_ctx(vec![]);
        assert!(ctx.command.starts_with("gemini"));
    }
}
|
||||
@@ -1,6 +1,8 @@
|
||||
mod claude_code;
|
||||
mod gemini;
|
||||
|
||||
pub use claude_code::ClaudeCodeRuntime;
|
||||
pub use gemini::GeminiRuntime;
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::sync::broadcast;
|
||||
@@ -18,6 +20,9 @@ pub struct RuntimeContext {
|
||||
pub prompt: String,
|
||||
pub cwd: String,
|
||||
pub inactivity_timeout_secs: u64,
|
||||
/// Port of the storkit MCP server, used by API-based runtimes (Gemini, OpenAI)
|
||||
/// to call back for tool execution.
|
||||
pub mcp_port: u16,
|
||||
}
|
||||
|
||||
/// Result returned by a runtime after the agent session completes.
|
||||
@@ -85,6 +90,7 @@ mod tests {
|
||||
prompt: "Do the thing".to_string(),
|
||||
cwd: "/tmp/wt".to_string(),
|
||||
inactivity_timeout_secs: 300,
|
||||
mcp_port: 3001,
|
||||
};
|
||||
assert_eq!(ctx.story_id, "42_story_foo");
|
||||
assert_eq!(ctx.agent_name, "coder-1");
|
||||
@@ -93,6 +99,7 @@ mod tests {
|
||||
assert_eq!(ctx.prompt, "Do the thing");
|
||||
assert_eq!(ctx.cwd, "/tmp/wt");
|
||||
assert_eq!(ctx.inactivity_timeout_secs, 300);
|
||||
assert_eq!(ctx.mcp_port, 3001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -378,10 +378,10 @@ fn validate_agents(agents: &[AgentConfig]) -> Result<(), String> {
|
||||
}
|
||||
if let Some(ref runtime) = agent.runtime {
|
||||
match runtime.as_str() {
|
||||
"claude-code" => {}
|
||||
"claude-code" | "gemini" => {}
|
||||
other => {
|
||||
return Err(format!(
|
||||
"Agent '{}': unknown runtime '{other}'. Supported: 'claude-code'",
|
||||
"Agent '{}': unknown runtime '{other}'. Supported: 'claude-code', 'gemini'",
|
||||
agent.name
|
||||
));
|
||||
}
|
||||
@@ -835,6 +835,18 @@ runtime = "claude-code"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn runtime_gemini_accepted() {
|
||||
let toml_str = r#"
|
||||
[[agent]]
|
||||
name = "coder"
|
||||
runtime = "gemini"
|
||||
model = "gemini-2.5-pro"
|
||||
"#;
|
||||
let config = ProjectConfig::parse(toml_str).unwrap();
|
||||
assert_eq!(config.agent[0].runtime, Some("gemini".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn runtime_unknown_rejected() {
|
||||
let toml_str = r#"
|
||||
|
||||
Reference in New Issue
Block a user