diff --git a/server/src/agents/pool/mod.rs b/server/src/agents/pool/mod.rs
index bb3e7b7..0ff5436 100644
--- a/server/src/agents/pool/mod.rs
+++ b/server/src/agents/pool/mod.rs
@@ -17,7 +17,7 @@ use super::{
     AgentEvent, AgentInfo, AgentStatus, CompletionReport, PipelineStage, agent_config_stage,
     pipeline_stage,
 };
-use super::runtime::{AgentRuntime, ClaudeCodeRuntime, GeminiRuntime, RuntimeContext};
+use super::runtime::{AgentRuntime, ClaudeCodeRuntime, GeminiRuntime, OpenAiRuntime, RuntimeContext};
 
 /// Build the composite key used to track agents in the pool.
 fn composite_key(story_id: &str, agent_name: &str) -> String {
@@ -553,9 +553,25 @@ impl AgentPool {
                 runtime
                     .start(ctx, tx_clone.clone(), log_clone.clone(), log_writer_clone)
                     .await
             }
+            "openai" => {
+                let runtime = OpenAiRuntime::new();
+                let ctx = RuntimeContext {
+                    story_id: sid.clone(),
+                    agent_name: aname.clone(),
+                    command,
+                    args,
+                    prompt,
+                    cwd: wt_path_str,
+                    inactivity_timeout_secs,
+                    mcp_port: port_for_task,
+                };
+                runtime
+                    .start(ctx, tx_clone.clone(), log_clone.clone(), log_writer_clone)
+                    .await
+            }
             other => Err(format!(
                 "Unknown agent runtime '{other}'; check the 'runtime' field in project.toml. \
-                 Supported: 'claude-code', 'gemini'"
+                 Supported: 'claude-code', 'gemini', 'openai'"
             )),
         };
diff --git a/server/src/agents/runtime/mod.rs b/server/src/agents/runtime/mod.rs
index 08c0a51..0b05933 100644
--- a/server/src/agents/runtime/mod.rs
+++ b/server/src/agents/runtime/mod.rs
@@ -1,8 +1,10 @@
 mod claude_code;
 mod gemini;
+mod openai;
 
 pub use claude_code::ClaudeCodeRuntime;
 pub use gemini::GeminiRuntime;
+pub use openai::OpenAiRuntime;
 
 use std::sync::{Arc, Mutex};
 use tokio::sync::broadcast;
diff --git a/server/src/agents/runtime/openai.rs b/server/src/agents/runtime/openai.rs
new file mode 100644
index 0000000..b7ecdad
--- /dev/null
+++ b/server/src/agents/runtime/openai.rs
@@ -0,0 +1,704 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::{Arc, Mutex};
+
+use reqwest::Client;
+use serde_json::{json, Value};
+use tokio::sync::broadcast;
+
+use crate::agent_log::AgentLogWriter;
+use crate::slog;
+
+use super::super::{AgentEvent, TokenUsage};
+use super::{AgentRuntime, RuntimeContext, RuntimeResult, RuntimeStatus};
+
+// ── Public runtime struct ────────────────────────────────────────────
+
+/// Agent runtime that drives an OpenAI model (GPT-4o, o3, etc.) through
+/// the OpenAI Chat Completions API.
+///
+/// The runtime:
+/// 1. Fetches MCP tool definitions from storkit's MCP server.
+/// 2. Converts them to OpenAI function-calling format.
+/// 3. Sends the agent prompt + tools to the Chat Completions API.
+/// 4. Executes any requested tool calls via MCP `tools/call`.
+/// 5. Loops until the model produces a response with no tool calls.
+/// 6. Tracks token usage from the API response.
+pub struct OpenAiRuntime {
+    /// Whether a stop has been requested.
+    cancelled: Arc<AtomicBool>,
+}
+
+impl OpenAiRuntime {
+    pub fn new() -> Self {
+        Self {
+            cancelled: Arc::new(AtomicBool::new(false)),
+        }
+    }
+}
+
+impl AgentRuntime for OpenAiRuntime {
+    async fn start(
+        &self,
+        ctx: RuntimeContext,
+        tx: broadcast::Sender<AgentEvent>,
+        event_log: Arc<Mutex<Vec<AgentEvent>>>,
+        log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
+    ) -> Result<RuntimeResult, String> {
+        let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
+            "OPENAI_API_KEY environment variable is not set. \
+             Set it to your OpenAI API key to use the OpenAI runtime."
+                .to_string()
+        })?;
+
+        let model = if ctx.command.starts_with("gpt") || ctx.command.starts_with("o") {
+            // The pool puts the model into `command` for non-CLI runtimes.
+            ctx.command.clone()
+        } else {
+            // Fall back to args: look for --model
+            ctx.args
+                .iter()
+                .position(|a| a == "--model")
+                .and_then(|i| ctx.args.get(i + 1))
+                .cloned()
+                .unwrap_or_else(|| "gpt-4o".to_string())
+        };
+
+        let mcp_port = ctx.mcp_port;
+        let mcp_base = format!("http://localhost:{mcp_port}/mcp");
+
+        let client = Client::new();
+        let cancelled = Arc::clone(&self.cancelled);
+
+        // Step 1: Fetch MCP tool definitions and convert to OpenAI format.
+        let openai_tools = fetch_and_convert_mcp_tools(&client, &mcp_base).await?;
+
+        // Step 2: Build the initial conversation messages.
+        let system_text = build_system_text(&ctx);
+        let mut messages: Vec<Value> = vec![
+            json!({ "role": "system", "content": system_text }),
+            json!({ "role": "user", "content": ctx.prompt }),
+        ];
+
+        let mut total_usage = TokenUsage {
+            input_tokens: 0,
+            output_tokens: 0,
+            cache_creation_input_tokens: 0,
+            cache_read_input_tokens: 0,
+            total_cost_usd: 0.0,
+        };
+
+        let emit = |event: AgentEvent| {
+            super::super::pty::emit_event(
+                event,
+                &tx,
+                &event_log,
+                log_writer.as_ref().map(|w| w.as_ref()),
+            );
+        };
+
+        emit(AgentEvent::Status {
+            story_id: ctx.story_id.clone(),
+            agent_name: ctx.agent_name.clone(),
+            status: "running".to_string(),
+        });
+
+        // Step 3: Conversation loop.
+        let mut turn = 0u32;
+        let max_turns = 200; // Safety limit
+
+        loop {
+            if cancelled.load(Ordering::Relaxed) {
+                emit(AgentEvent::Error {
+                    story_id: ctx.story_id.clone(),
+                    agent_name: ctx.agent_name.clone(),
+                    message: "Agent was stopped by user".to_string(),
+                });
+                return Ok(RuntimeResult {
+                    session_id: None,
+                    token_usage: Some(total_usage),
+                });
+            }
+
+            turn += 1;
+            if turn > max_turns {
+                emit(AgentEvent::Error {
+                    story_id: ctx.story_id.clone(),
+                    agent_name: ctx.agent_name.clone(),
+                    message: format!("Exceeded maximum turns ({max_turns})"),
+                });
+                return Ok(RuntimeResult {
+                    session_id: None,
+                    token_usage: Some(total_usage),
+                });
+            }
+
+            slog!(
+                "[openai] Turn {turn} for {}:{}",
+                ctx.story_id,
+                ctx.agent_name
+            );
+
+            let mut request_body = json!({
+                "model": model,
+                "messages": messages,
+                "temperature": 0.2,
+            });
+
+            if !openai_tools.is_empty() {
+                request_body["tools"] = json!(openai_tools);
+            }
+
+            let response = client
+                .post("https://api.openai.com/v1/chat/completions")
+                .bearer_auth(&api_key)
+                .json(&request_body)
+                .send()
+                .await
+                .map_err(|e| format!("OpenAI API request failed: {e}"))?;
+
+            let status = response.status();
+            let body: Value = response
+                .json()
+                .await
+                .map_err(|e| format!("Failed to parse OpenAI API response: {e}"))?;
+
+            if !status.is_success() {
+                let error_msg = body["error"]["message"]
+                    .as_str()
+                    .unwrap_or("Unknown API error");
+                let err = format!("OpenAI API error ({status}): {error_msg}");
+                emit(AgentEvent::Error {
+                    story_id: ctx.story_id.clone(),
+                    agent_name: ctx.agent_name.clone(),
+                    message: err.clone(),
+                });
+                return Err(err);
+            }
+
+            // Accumulate token usage.
+            if let Some(usage) = parse_usage(&body) {
+                total_usage.input_tokens += usage.input_tokens;
+                total_usage.output_tokens += usage.output_tokens;
+            }
+
+            // Extract the first choice.
+            let choice = body["choices"]
+                .as_array()
+                .and_then(|c| c.first())
+                .ok_or_else(|| "No choices in OpenAI response".to_string())?;
+
+            let message = &choice["message"];
+            let content = message["content"].as_str().unwrap_or("");
+
+            // Emit any text content.
+            if !content.is_empty() {
+                emit(AgentEvent::Output {
+                    story_id: ctx.story_id.clone(),
+                    agent_name: ctx.agent_name.clone(),
+                    text: content.to_string(),
+                });
+            }
+
+            // Check for tool calls.
+            let tool_calls = message["tool_calls"].as_array();
+
+            if tool_calls.is_none() || tool_calls.is_some_and(|tc| tc.is_empty()) {
+                // No tool calls — model is done.
+                emit(AgentEvent::Done {
+                    story_id: ctx.story_id.clone(),
+                    agent_name: ctx.agent_name.clone(),
+                    session_id: None,
+                });
+                return Ok(RuntimeResult {
+                    session_id: None,
+                    token_usage: Some(total_usage),
+                });
+            }
+
+            let tool_calls = tool_calls.unwrap();
+
+            // Add the assistant message (with tool_calls) to the conversation.
+            messages.push(message.clone());
+
+            // Execute each tool call via MCP and add results.
+            for tc in tool_calls {
+                if cancelled.load(Ordering::Relaxed) {
+                    break;
+                }
+
+                let call_id = tc["id"].as_str().unwrap_or("");
+                let function = &tc["function"];
+                let tool_name = function["name"].as_str().unwrap_or("");
+                let arguments_str = function["arguments"].as_str().unwrap_or("{}");
+
+                let args: Value = serde_json::from_str(arguments_str).unwrap_or(json!({}));
+
+                slog!(
+                    "[openai] Calling MCP tool '{}' for {}:{}",
+                    tool_name,
+                    ctx.story_id,
+                    ctx.agent_name
+                );
+
+                emit(AgentEvent::Output {
+                    story_id: ctx.story_id.clone(),
+                    agent_name: ctx.agent_name.clone(),
+                    text: format!("\n[Tool call: {tool_name}]\n"),
+                });
+
+                let tool_result = call_mcp_tool(&client, &mcp_base, tool_name, &args).await;
+
+                let result_content = match &tool_result {
+                    Ok(result) => {
+                        emit(AgentEvent::Output {
+                            story_id: ctx.story_id.clone(),
+                            agent_name: ctx.agent_name.clone(),
+                            text: format!("[Tool result: {} chars]\n", result.len()),
+                        });
+                        result.clone()
+                    }
+                    Err(e) => {
+                        emit(AgentEvent::Output {
+                            story_id: ctx.story_id.clone(),
+                            agent_name: ctx.agent_name.clone(),
+                            text: format!("[Tool error: {e}]\n"),
+                        });
+                        format!("Error: {e}")
+                    }
+                };
+
+                // OpenAI expects tool results as role=tool messages with
+                // the matching tool_call_id.
+                messages.push(json!({
+                    "role": "tool",
+                    "tool_call_id": call_id,
+                    "content": result_content,
+                }));
+            }
+        }
+    }
+
+    fn stop(&self) {
+        self.cancelled.store(true, Ordering::Relaxed);
+    }
+
+    fn get_status(&self) -> RuntimeStatus {
+        if self.cancelled.load(Ordering::Relaxed) {
+            RuntimeStatus::Failed
+        } else {
+            RuntimeStatus::Idle
+        }
+    }
+}
+
+// ── Helper functions ─────────────────────────────────────────────────
+
+/// Build the system message text from the RuntimeContext.
+fn build_system_text(ctx: &RuntimeContext) -> String {
+    ctx.args
+        .iter()
+        .position(|a| a == "--append-system-prompt")
+        .and_then(|i| ctx.args.get(i + 1))
+        .cloned()
+        .unwrap_or_else(|| {
+            format!(
+                "You are an AI coding agent working on story {}. \
+                 You have access to tools via function calling. \
+                 Use them to complete the task. \
+                 Work in the directory: {}",
+                ctx.story_id, ctx.cwd
+            )
+        })
+}
+
+/// Fetch MCP tool definitions from storkit's MCP server and convert
+/// them to OpenAI function-calling format.
+async fn fetch_and_convert_mcp_tools(
+    client: &Client,
+    mcp_base: &str,
+) -> Result<Vec<Value>, String> {
+    let request = json!({
+        "jsonrpc": "2.0",
+        "id": 1,
+        "method": "tools/list",
+        "params": {}
+    });
+
+    let response = client
+        .post(mcp_base)
+        .json(&request)
+        .send()
+        .await
+        .map_err(|e| format!("Failed to fetch MCP tools: {e}"))?;
+
+    let body: Value = response
+        .json()
+        .await
+        .map_err(|e| format!("Failed to parse MCP tools response: {e}"))?;
+
+    let tools = body["result"]["tools"]
+        .as_array()
+        .ok_or_else(|| "No tools array in MCP response".to_string())?;
+
+    let mut openai_tools = Vec::new();
+
+    for tool in tools {
+        let name = tool["name"].as_str().unwrap_or("").to_string();
+        let description = tool["description"].as_str().unwrap_or("").to_string();
+
+        if name.is_empty() {
+            continue;
+        }
+
+        // OpenAI function calling uses JSON Schema natively for parameters,
+        // so the MCP inputSchema can be used with minimal cleanup.
+        let parameters = convert_mcp_schema_to_openai(tool.get("inputSchema"));
+
+        openai_tools.push(json!({
+            "type": "function",
+            "function": {
+                "name": name,
+                "description": description,
+                "parameters": parameters.unwrap_or_else(|| json!({"type": "object", "properties": {}})),
+            }
+        }));
+    }
+
+    slog!(
+        "[openai] Loaded {} MCP tools as function definitions",
+        openai_tools.len()
+    );
+    Ok(openai_tools)
+}
+
+/// Convert an MCP inputSchema (JSON Schema) to OpenAI-compatible
+/// function parameters.
+///
+/// OpenAI uses JSON Schema natively, so less transformation is needed
+/// compared to Gemini. We still strip `$schema` to keep payloads clean.
+fn convert_mcp_schema_to_openai(schema: Option<&Value>) -> Option<Value> {
+    let schema = schema?;
+
+    let mut result = json!({
+        "type": "object",
+    });
+
+    if let Some(properties) = schema.get("properties") {
+        result["properties"] = clean_schema_properties(properties);
+    } else {
+        result["properties"] = json!({});
+    }
+
+    if let Some(required) = schema.get("required") {
+        result["required"] = required.clone();
+    }
+
+    // OpenAI recommends additionalProperties: false for strict mode.
+    result["additionalProperties"] = json!(false);
+
+    Some(result)
+}
+
+/// Recursively clean schema properties, removing unsupported keywords.
+fn clean_schema_properties(properties: &Value) -> Value {
+    let Some(obj) = properties.as_object() else {
+        return properties.clone();
+    };
+
+    let mut cleaned = serde_json::Map::new();
+    for (key, value) in obj {
+        let mut prop = value.clone();
+        if let Some(p) = prop.as_object_mut() {
+            p.remove("$schema");
+
+            // Recursively clean nested object properties.
+            if let Some(nested_props) = p.get("properties").cloned() {
+                p.insert(
+                    "properties".to_string(),
+                    clean_schema_properties(&nested_props),
+                );
+            }
+
+            // Clean items schema for arrays.
+            if let Some(items) = p.get("items").cloned()
+                && let Some(items_obj) = items.as_object()
+            {
+                let mut cleaned_items = items_obj.clone();
+                cleaned_items.remove("$schema");
+                p.insert("items".to_string(), Value::Object(cleaned_items));
+            }
+        }
+        cleaned.insert(key.clone(), prop);
+    }
+    Value::Object(cleaned)
+}
+
+/// Call an MCP tool via storkit's MCP server.
+async fn call_mcp_tool(
+    client: &Client,
+    mcp_base: &str,
+    tool_name: &str,
+    args: &Value,
+) -> Result<String, String> {
+    let request = json!({
+        "jsonrpc": "2.0",
+        "id": 1,
+        "method": "tools/call",
+        "params": {
+            "name": tool_name,
+            "arguments": args
+        }
+    });
+
+    let response = client
+        .post(mcp_base)
+        .json(&request)
+        .send()
+        .await
+        .map_err(|e| format!("MCP tool call failed: {e}"))?;
+
+    let body: Value = response
+        .json()
+        .await
+        .map_err(|e| format!("Failed to parse MCP tool response: {e}"))?;
+
+    if let Some(error) = body.get("error") {
+        let msg = error["message"].as_str().unwrap_or("Unknown MCP error");
+        return Err(format!("MCP tool '{tool_name}' error: {msg}"));
+    }
+
+    // MCP tools/call returns { result: { content: [{ type: "text", text: "..." }] } }
+    let content = &body["result"]["content"];
+    if let Some(arr) = content.as_array() {
+        let texts: Vec<&str> = arr
+            .iter()
+            .filter_map(|c| c["text"].as_str())
+            .collect();
+        if !texts.is_empty() {
+            return Ok(texts.join("\n"));
+        }
+    }
+
+    // Fall back to serializing the entire result.
+    Ok(body["result"].to_string())
+}
+
+/// Parse token usage from an OpenAI API response.
+fn parse_usage(response: &Value) -> Option<TokenUsage> {
+    let usage = response.get("usage")?;
+    Some(TokenUsage {
+        input_tokens: usage
+            .get("prompt_tokens")
+            .and_then(|v| v.as_u64())
+            .unwrap_or(0),
+        output_tokens: usage
+            .get("completion_tokens")
+            .and_then(|v| v.as_u64())
+            .unwrap_or(0),
+        cache_creation_input_tokens: 0,
+        cache_read_input_tokens: 0,
+        // OpenAI API doesn't report cost directly; leave at 0.
+        total_cost_usd: 0.0,
+    })
+}
+
+// ── Tests ──────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn convert_mcp_schema_simple_object() {
+        let schema = json!({
+            "type": "object",
+            "properties": {
+                "story_id": {
+                    "type": "string",
+                    "description": "Story identifier"
+                }
+            },
+            "required": ["story_id"]
+        });
+
+        let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
+        assert_eq!(result["type"], "object");
+        assert!(result["properties"]["story_id"].is_object());
+        assert_eq!(result["required"][0], "story_id");
+        assert_eq!(result["additionalProperties"], false);
+    }
+
+    #[test]
+    fn convert_mcp_schema_empty_properties() {
+        let schema = json!({
+            "type": "object",
+            "properties": {}
+        });
+
+        let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
+        assert_eq!(result["type"], "object");
+        assert!(result["properties"].as_object().unwrap().is_empty());
+    }
+
+    #[test]
+    fn convert_mcp_schema_none_returns_none() {
+        assert!(convert_mcp_schema_to_openai(None).is_none());
+    }
+
+    #[test]
+    fn convert_mcp_schema_strips_dollar_schema() {
+        let schema = json!({
+            "type": "object",
+            "properties": {
+                "name": {
+                    "type": "string",
+                    "$schema": "http://json-schema.org/draft-07/schema#"
+                }
+            }
+        });
+
+        let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
+        let name_prop = &result["properties"]["name"];
+        assert!(name_prop.get("$schema").is_none());
+        assert_eq!(name_prop["type"], "string");
+    }
+
+    #[test]
+    fn convert_mcp_schema_with_nested_objects() {
+        let schema = json!({
+            "type": "object",
+            "properties": {
+                "config": {
+                    "type": "object",
+                    "properties": {
+                        "key": { "type": "string" }
+                    }
+                }
+            }
+        });
+
+        let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
+        assert!(result["properties"]["config"]["properties"]["key"].is_object());
+    }
+
+    #[test]
+    fn convert_mcp_schema_with_array_items() {
+        let schema = json!({
+            "type": "object",
+            "properties": {
+                "items": {
+                    "type": "array",
+                    "items": {
+                        "type": "object",
+                        "properties": {
+                            "name": { "type": "string" }
+                        },
+                        "$schema": "http://json-schema.org/draft-07/schema#"
+                    }
+                }
+            }
+        });
+
+        let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
+        let items_schema = &result["properties"]["items"]["items"];
+        assert!(items_schema.get("$schema").is_none());
+    }
+
+    #[test]
+    fn build_system_text_uses_args() {
+        let ctx = RuntimeContext {
+            story_id: "42_story_test".to_string(),
+            agent_name: "coder-1".to_string(),
+            command: "gpt-4o".to_string(),
+            args: vec![
+                "--append-system-prompt".to_string(),
+                "Custom system prompt".to_string(),
+            ],
+            prompt: "Do the thing".to_string(),
+            cwd: "/tmp/wt".to_string(),
+            inactivity_timeout_secs: 300,
+            mcp_port: 3001,
+        };
+
+        assert_eq!(build_system_text(&ctx), "Custom system prompt");
+    }
+
+    #[test]
+    fn build_system_text_default() {
+        let ctx = RuntimeContext {
+            story_id: "42_story_test".to_string(),
+            agent_name: "coder-1".to_string(),
+            command: "gpt-4o".to_string(),
+            args: vec![],
+            prompt: "Do the thing".to_string(),
+            cwd: "/tmp/wt".to_string(),
+            inactivity_timeout_secs: 300,
+            mcp_port: 3001,
+        };
+
+        let text = build_system_text(&ctx);
+        assert!(text.contains("42_story_test"));
+        assert!(text.contains("/tmp/wt"));
+    }
+
+    #[test]
+    fn parse_usage_valid() {
+        let response = json!({
+            "usage": {
+                "prompt_tokens": 100,
+                "completion_tokens": 50,
+                "total_tokens": 150
+            }
+        });
+
+        let usage = parse_usage(&response).unwrap();
+        assert_eq!(usage.input_tokens, 100);
+        assert_eq!(usage.output_tokens, 50);
+        assert_eq!(usage.cache_creation_input_tokens, 0);
+        assert_eq!(usage.total_cost_usd, 0.0);
+    }
+
+    #[test]
+    fn parse_usage_missing() {
+        let response = json!({"choices": []});
+        assert!(parse_usage(&response).is_none());
+    }
+
+    #[test]
+    fn openai_runtime_stop_sets_cancelled() {
+        let runtime = OpenAiRuntime::new();
+        assert_eq!(runtime.get_status(), RuntimeStatus::Idle);
+        runtime.stop();
+        assert_eq!(runtime.get_status(), RuntimeStatus::Failed);
+    }
+
+    #[test]
+    fn model_extraction_from_command_gpt() {
+        let ctx = RuntimeContext {
+            story_id: "1".to_string(),
+            agent_name: "coder".to_string(),
+            command: "gpt-4o".to_string(),
+            args: vec![],
+            prompt: "test".to_string(),
+            cwd: "/tmp".to_string(),
+            inactivity_timeout_secs: 300,
+            mcp_port: 3001,
+        };
+        assert!(ctx.command.starts_with("gpt"));
+    }
+
+    #[test]
+    fn model_extraction_from_command_o3() {
+        let ctx = RuntimeContext {
+            story_id: "1".to_string(),
+            agent_name: "coder".to_string(),
+            command: "o3".to_string(),
+            args: vec![],
+            prompt: "test".to_string(),
+            cwd: "/tmp".to_string(),
+            inactivity_timeout_secs: 300,
+            mcp_port: 3001,
+        };
+        assert!(ctx.command.starts_with("o"));
+    }
+}