use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use reqwest::Client; use serde_json::{json, Value}; use tokio::sync::broadcast; use crate::agent_log::AgentLogWriter; use crate::slog; use super::super::{AgentEvent, TokenUsage}; use super::{AgentRuntime, RuntimeContext, RuntimeResult, RuntimeStatus}; // ── Public runtime struct ──────────────────────────────────────────── /// Agent runtime that drives an OpenAI model (GPT-4o, o3, etc.) through /// the OpenAI Chat Completions API. /// /// The runtime: /// 1. Fetches MCP tool definitions from storkit's MCP server. /// 2. Converts them to OpenAI function-calling format. /// 3. Sends the agent prompt + tools to the Chat Completions API. /// 4. Executes any requested tool calls via MCP `tools/call`. /// 5. Loops until the model produces a response with no tool calls. /// 6. Tracks token usage from the API response. pub struct OpenAiRuntime { /// Whether a stop has been requested. cancelled: Arc, } impl OpenAiRuntime { pub fn new() -> Self { Self { cancelled: Arc::new(AtomicBool::new(false)), } } } impl AgentRuntime for OpenAiRuntime { async fn start( &self, ctx: RuntimeContext, tx: broadcast::Sender, event_log: Arc>>, log_writer: Option>>, ) -> Result { let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| { "OPENAI_API_KEY environment variable is not set. \ Set it to your OpenAI API key to use the OpenAI runtime." .to_string() })?; let model = if ctx.command.starts_with("gpt") || ctx.command.starts_with("o") { // The pool puts the model into `command` for non-CLI runtimes. ctx.command.clone() } else { // Fall back to args: look for --model ctx.args .iter() .position(|a| a == "--model") .and_then(|i| ctx.args.get(i + 1)) .cloned() .unwrap_or_else(|| "gpt-4o".to_string()) }; let mcp_port = ctx.mcp_port; let mcp_base = format!("http://localhost:{mcp_port}/mcp"); let client = Client::new(); let cancelled = Arc::clone(&self.cancelled); // Step 1: Fetch MCP tool definitions and convert to OpenAI format. let openai_tools = fetch_and_convert_mcp_tools(&client, &mcp_base).await?; // Step 2: Build the initial conversation messages. let system_text = build_system_text(&ctx); let mut messages: Vec = vec![ json!({ "role": "system", "content": system_text }), json!({ "role": "user", "content": ctx.prompt }), ]; let mut total_usage = TokenUsage { input_tokens: 0, output_tokens: 0, cache_creation_input_tokens: 0, cache_read_input_tokens: 0, total_cost_usd: 0.0, }; let emit = |event: AgentEvent| { super::super::pty::emit_event( event, &tx, &event_log, log_writer.as_ref().map(|w| w.as_ref()), ); }; emit(AgentEvent::Status { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), status: "running".to_string(), }); // Step 3: Conversation loop. let mut turn = 0u32; let max_turns = 200; // Safety limit loop { if cancelled.load(Ordering::Relaxed) { emit(AgentEvent::Error { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), message: "Agent was stopped by user".to_string(), }); return Ok(RuntimeResult { session_id: None, token_usage: Some(total_usage), }); } turn += 1; if turn > max_turns { emit(AgentEvent::Error { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), message: format!("Exceeded maximum turns ({max_turns})"), }); return Ok(RuntimeResult { session_id: None, token_usage: Some(total_usage), }); } slog!( "[openai] Turn {turn} for {}:{}", ctx.story_id, ctx.agent_name ); let mut request_body = json!({ "model": model, "messages": messages, "temperature": 0.2, }); if !openai_tools.is_empty() { request_body["tools"] = json!(openai_tools); } let response = client .post("https://api.openai.com/v1/chat/completions") .bearer_auth(&api_key) .json(&request_body) .send() .await .map_err(|e| format!("OpenAI API request failed: {e}"))?; let status = response.status(); let body: Value = response .json() .await .map_err(|e| format!("Failed to parse OpenAI API response: {e}"))?; if !status.is_success() { let error_msg = body["error"]["message"] .as_str() .unwrap_or("Unknown API error"); let err = format!("OpenAI API error ({status}): {error_msg}"); emit(AgentEvent::Error { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), message: err.clone(), }); return Err(err); } // Accumulate token usage. if let Some(usage) = parse_usage(&body) { total_usage.input_tokens += usage.input_tokens; total_usage.output_tokens += usage.output_tokens; } // Extract the first choice. let choice = body["choices"] .as_array() .and_then(|c| c.first()) .ok_or_else(|| "No choices in OpenAI response".to_string())?; let message = &choice["message"]; let content = message["content"].as_str().unwrap_or(""); // Emit any text content. if !content.is_empty() { emit(AgentEvent::Output { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), text: content.to_string(), }); } // Check for tool calls. let tool_calls = message["tool_calls"].as_array(); if tool_calls.is_none() || tool_calls.is_some_and(|tc| tc.is_empty()) { // No tool calls — model is done. emit(AgentEvent::Done { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), session_id: None, }); return Ok(RuntimeResult { session_id: None, token_usage: Some(total_usage), }); } let tool_calls = tool_calls.unwrap(); // Add the assistant message (with tool_calls) to the conversation. messages.push(message.clone()); // Execute each tool call via MCP and add results. for tc in tool_calls { if cancelled.load(Ordering::Relaxed) { break; } let call_id = tc["id"].as_str().unwrap_or(""); let function = &tc["function"]; let tool_name = function["name"].as_str().unwrap_or(""); let arguments_str = function["arguments"].as_str().unwrap_or("{}"); let args: Value = serde_json::from_str(arguments_str).unwrap_or(json!({})); slog!( "[openai] Calling MCP tool '{}' for {}:{}", tool_name, ctx.story_id, ctx.agent_name ); emit(AgentEvent::Output { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), text: format!("\n[Tool call: {tool_name}]\n"), }); let tool_result = call_mcp_tool(&client, &mcp_base, tool_name, &args).await; let result_content = match &tool_result { Ok(result) => { emit(AgentEvent::Output { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), text: format!("[Tool result: {} chars]\n", result.len()), }); result.clone() } Err(e) => { emit(AgentEvent::Output { story_id: ctx.story_id.clone(), agent_name: ctx.agent_name.clone(), text: format!("[Tool error: {e}]\n"), }); format!("Error: {e}") } }; // OpenAI expects tool results as role=tool messages with // the matching tool_call_id. messages.push(json!({ "role": "tool", "tool_call_id": call_id, "content": result_content, })); } } } fn stop(&self) { self.cancelled.store(true, Ordering::Relaxed); } fn get_status(&self) -> RuntimeStatus { if self.cancelled.load(Ordering::Relaxed) { RuntimeStatus::Failed } else { RuntimeStatus::Idle } } } // ── Helper functions ───────────────────────────────────────────────── /// Build the system message text from the RuntimeContext. fn build_system_text(ctx: &RuntimeContext) -> String { ctx.args .iter() .position(|a| a == "--append-system-prompt") .and_then(|i| ctx.args.get(i + 1)) .cloned() .unwrap_or_else(|| { format!( "You are an AI coding agent working on story {}. \ You have access to tools via function calling. \ Use them to complete the task. \ Work in the directory: {}", ctx.story_id, ctx.cwd ) }) } /// Fetch MCP tool definitions from storkit's MCP server and convert /// them to OpenAI function-calling format. async fn fetch_and_convert_mcp_tools( client: &Client, mcp_base: &str, ) -> Result, String> { let request = json!({ "jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {} }); let response = client .post(mcp_base) .json(&request) .send() .await .map_err(|e| format!("Failed to fetch MCP tools: {e}"))?; let body: Value = response .json() .await .map_err(|e| format!("Failed to parse MCP tools response: {e}"))?; let tools = body["result"]["tools"] .as_array() .ok_or_else(|| "No tools array in MCP response".to_string())?; let mut openai_tools = Vec::new(); for tool in tools { let name = tool["name"].as_str().unwrap_or("").to_string(); let description = tool["description"].as_str().unwrap_or("").to_string(); if name.is_empty() { continue; } // OpenAI function calling uses JSON Schema natively for parameters, // so the MCP inputSchema can be used with minimal cleanup. let parameters = convert_mcp_schema_to_openai(tool.get("inputSchema")); openai_tools.push(json!({ "type": "function", "function": { "name": name, "description": description, "parameters": parameters.unwrap_or_else(|| json!({"type": "object", "properties": {}})), } })); } slog!( "[openai] Loaded {} MCP tools as function definitions", openai_tools.len() ); Ok(openai_tools) } /// Convert an MCP inputSchema (JSON Schema) to OpenAI-compatible /// function parameters. /// /// OpenAI uses JSON Schema natively, so less transformation is needed /// compared to Gemini. We still strip `$schema` to keep payloads clean. fn convert_mcp_schema_to_openai(schema: Option<&Value>) -> Option { let schema = schema?; let mut result = json!({ "type": "object", }); if let Some(properties) = schema.get("properties") { result["properties"] = clean_schema_properties(properties); } else { result["properties"] = json!({}); } if let Some(required) = schema.get("required") { result["required"] = required.clone(); } // OpenAI recommends additionalProperties: false for strict mode. result["additionalProperties"] = json!(false); Some(result) } /// Recursively clean schema properties, removing unsupported keywords. fn clean_schema_properties(properties: &Value) -> Value { let Some(obj) = properties.as_object() else { return properties.clone(); }; let mut cleaned = serde_json::Map::new(); for (key, value) in obj { let mut prop = value.clone(); if let Some(p) = prop.as_object_mut() { p.remove("$schema"); // Recursively clean nested object properties. if let Some(nested_props) = p.get("properties").cloned() { p.insert( "properties".to_string(), clean_schema_properties(&nested_props), ); } // Clean items schema for arrays. if let Some(items) = p.get("items").cloned() && let Some(items_obj) = items.as_object() { let mut cleaned_items = items_obj.clone(); cleaned_items.remove("$schema"); p.insert("items".to_string(), Value::Object(cleaned_items)); } } cleaned.insert(key.clone(), prop); } Value::Object(cleaned) } /// Call an MCP tool via storkit's MCP server. async fn call_mcp_tool( client: &Client, mcp_base: &str, tool_name: &str, args: &Value, ) -> Result { let request = json!({ "jsonrpc": "2.0", "id": 1, "method": "tools/call", "params": { "name": tool_name, "arguments": args } }); let response = client .post(mcp_base) .json(&request) .send() .await .map_err(|e| format!("MCP tool call failed: {e}"))?; let body: Value = response .json() .await .map_err(|e| format!("Failed to parse MCP tool response: {e}"))?; if let Some(error) = body.get("error") { let msg = error["message"].as_str().unwrap_or("Unknown MCP error"); return Err(format!("MCP tool '{tool_name}' error: {msg}")); } // MCP tools/call returns { result: { content: [{ type: "text", text: "..." }] } } let content = &body["result"]["content"]; if let Some(arr) = content.as_array() { let texts: Vec<&str> = arr .iter() .filter_map(|c| c["text"].as_str()) .collect(); if !texts.is_empty() { return Ok(texts.join("\n")); } } // Fall back to serializing the entire result. Ok(body["result"].to_string()) } /// Parse token usage from an OpenAI API response. fn parse_usage(response: &Value) -> Option { let usage = response.get("usage")?; Some(TokenUsage { input_tokens: usage .get("prompt_tokens") .and_then(|v| v.as_u64()) .unwrap_or(0), output_tokens: usage .get("completion_tokens") .and_then(|v| v.as_u64()) .unwrap_or(0), cache_creation_input_tokens: 0, cache_read_input_tokens: 0, // OpenAI API doesn't report cost directly; leave at 0. total_cost_usd: 0.0, }) } // ── Tests ──────────────────────────────────────────────────────────── #[cfg(test)] mod tests { use super::*; #[test] fn convert_mcp_schema_simple_object() { let schema = json!({ "type": "object", "properties": { "story_id": { "type": "string", "description": "Story identifier" } }, "required": ["story_id"] }); let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap(); assert_eq!(result["type"], "object"); assert!(result["properties"]["story_id"].is_object()); assert_eq!(result["required"][0], "story_id"); assert_eq!(result["additionalProperties"], false); } #[test] fn convert_mcp_schema_empty_properties() { let schema = json!({ "type": "object", "properties": {} }); let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap(); assert_eq!(result["type"], "object"); assert!(result["properties"].as_object().unwrap().is_empty()); } #[test] fn convert_mcp_schema_none_returns_none() { assert!(convert_mcp_schema_to_openai(None).is_none()); } #[test] fn convert_mcp_schema_strips_dollar_schema() { let schema = json!({ "type": "object", "properties": { "name": { "type": "string", "$schema": "http://json-schema.org/draft-07/schema#" } } }); let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap(); let name_prop = &result["properties"]["name"]; assert!(name_prop.get("$schema").is_none()); assert_eq!(name_prop["type"], "string"); } #[test] fn convert_mcp_schema_with_nested_objects() { let schema = json!({ "type": "object", "properties": { "config": { "type": "object", "properties": { "key": { "type": "string" } } } } }); let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap(); assert!(result["properties"]["config"]["properties"]["key"].is_object()); } #[test] fn convert_mcp_schema_with_array_items() { let schema = json!({ "type": "object", "properties": { "items": { "type": "array", "items": { "type": "object", "properties": { "name": { "type": "string" } }, "$schema": "http://json-schema.org/draft-07/schema#" } } } }); let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap(); let items_schema = &result["properties"]["items"]["items"]; assert!(items_schema.get("$schema").is_none()); } #[test] fn build_system_text_uses_args() { let ctx = RuntimeContext { story_id: "42_story_test".to_string(), agent_name: "coder-1".to_string(), command: "gpt-4o".to_string(), args: vec![ "--append-system-prompt".to_string(), "Custom system prompt".to_string(), ], prompt: "Do the thing".to_string(), cwd: "/tmp/wt".to_string(), inactivity_timeout_secs: 300, mcp_port: 3001, }; assert_eq!(build_system_text(&ctx), "Custom system prompt"); } #[test] fn build_system_text_default() { let ctx = RuntimeContext { story_id: "42_story_test".to_string(), agent_name: "coder-1".to_string(), command: "gpt-4o".to_string(), args: vec![], prompt: "Do the thing".to_string(), cwd: "/tmp/wt".to_string(), inactivity_timeout_secs: 300, mcp_port: 3001, }; let text = build_system_text(&ctx); assert!(text.contains("42_story_test")); assert!(text.contains("/tmp/wt")); } #[test] fn parse_usage_valid() { let response = json!({ "usage": { "prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150 } }); let usage = parse_usage(&response).unwrap(); assert_eq!(usage.input_tokens, 100); assert_eq!(usage.output_tokens, 50); assert_eq!(usage.cache_creation_input_tokens, 0); assert_eq!(usage.total_cost_usd, 0.0); } #[test] fn parse_usage_missing() { let response = json!({"choices": []}); assert!(parse_usage(&response).is_none()); } #[test] fn openai_runtime_stop_sets_cancelled() { let runtime = OpenAiRuntime::new(); assert_eq!(runtime.get_status(), RuntimeStatus::Idle); runtime.stop(); assert_eq!(runtime.get_status(), RuntimeStatus::Failed); } #[test] fn model_extraction_from_command_gpt() { let ctx = RuntimeContext { story_id: "1".to_string(), agent_name: "coder".to_string(), command: "gpt-4o".to_string(), args: vec![], prompt: "test".to_string(), cwd: "/tmp".to_string(), inactivity_timeout_secs: 300, mcp_port: 3001, }; assert!(ctx.command.starts_with("gpt")); } #[test] fn model_extraction_from_command_o3() { let ctx = RuntimeContext { story_id: "1".to_string(), agent_name: "coder".to_string(), command: "o3".to_string(), args: vec![], prompt: "test".to_string(), cwd: "/tmp".to_string(), inactivity_timeout_secs: 300, mcp_port: 3001, }; assert!(ctx.command.starts_with("o")); } }