storkit: merge 344_story_chatgpt_agent_backend_via_openai_api

This commit is contained in:
Dave
2026-03-20 23:49:57 +00:00
parent fbf391684a
commit be3b5b0b60
3 changed files with 724 additions and 2 deletions

View File

@@ -17,7 +17,7 @@ use super::{
AgentEvent, AgentInfo, AgentStatus, CompletionReport, PipelineStage, agent_config_stage, AgentEvent, AgentInfo, AgentStatus, CompletionReport, PipelineStage, agent_config_stage,
pipeline_stage, pipeline_stage,
}; };
use super::runtime::{AgentRuntime, ClaudeCodeRuntime, GeminiRuntime, RuntimeContext}; use super::runtime::{AgentRuntime, ClaudeCodeRuntime, GeminiRuntime, OpenAiRuntime, RuntimeContext};
/// Build the composite key used to track agents in the pool. /// Build the composite key used to track agents in the pool.
fn composite_key(story_id: &str, agent_name: &str) -> String { fn composite_key(story_id: &str, agent_name: &str) -> String {
@@ -553,9 +553,25 @@ impl AgentPool {
.start(ctx, tx_clone.clone(), log_clone.clone(), log_writer_clone) .start(ctx, tx_clone.clone(), log_clone.clone(), log_writer_clone)
.await .await
} }
"openai" => {
let runtime = OpenAiRuntime::new();
let ctx = RuntimeContext {
story_id: sid.clone(),
agent_name: aname.clone(),
command,
args,
prompt,
cwd: wt_path_str,
inactivity_timeout_secs,
mcp_port: port_for_task,
};
runtime
.start(ctx, tx_clone.clone(), log_clone.clone(), log_writer_clone)
.await
}
other => Err(format!( other => Err(format!(
"Unknown agent runtime '{other}'; check the 'runtime' field in project.toml. \ "Unknown agent runtime '{other}'; check the 'runtime' field in project.toml. \
Supported: 'claude-code', 'gemini'" Supported: 'claude-code', 'gemini', 'openai'"
)), )),
}; };

View File

@@ -1,8 +1,10 @@
mod claude_code; mod claude_code;
mod gemini; mod gemini;
mod openai;
pub use claude_code::ClaudeCodeRuntime; pub use claude_code::ClaudeCodeRuntime;
pub use gemini::GeminiRuntime; pub use gemini::GeminiRuntime;
pub use openai::OpenAiRuntime;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use tokio::sync::broadcast; use tokio::sync::broadcast;

View File

@@ -0,0 +1,704 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use reqwest::Client;
use serde_json::{json, Value};
use tokio::sync::broadcast;
use crate::agent_log::AgentLogWriter;
use crate::slog;
use super::super::{AgentEvent, TokenUsage};
use super::{AgentRuntime, RuntimeContext, RuntimeResult, RuntimeStatus};
// ── Public runtime struct ────────────────────────────────────────────
/// Agent runtime that drives an OpenAI model (GPT-4o, o3, etc.) through
/// the OpenAI Chat Completions API.
///
/// The runtime:
/// 1. Fetches MCP tool definitions from storkit's MCP server.
/// 2. Converts them to OpenAI function-calling format.
/// 3. Sends the agent prompt + tools to the Chat Completions API.
/// 4. Executes any requested tool calls via MCP `tools/call`.
/// 5. Loops until the model produces a response with no tool calls.
/// 6. Tracks token usage from the API response.
pub struct OpenAiRuntime {
/// Whether a stop has been requested.
cancelled: Arc<AtomicBool>,
}
impl OpenAiRuntime {
pub fn new() -> Self {
Self {
cancelled: Arc::new(AtomicBool::new(false)),
}
}
}
impl AgentRuntime for OpenAiRuntime {
async fn start(
&self,
ctx: RuntimeContext,
tx: broadcast::Sender<AgentEvent>,
event_log: Arc<Mutex<Vec<AgentEvent>>>,
log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
) -> Result<RuntimeResult, String> {
let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
"OPENAI_API_KEY environment variable is not set. \
Set it to your OpenAI API key to use the OpenAI runtime."
.to_string()
})?;
let model = if ctx.command.starts_with("gpt") || ctx.command.starts_with("o") {
// The pool puts the model into `command` for non-CLI runtimes.
ctx.command.clone()
} else {
// Fall back to args: look for --model <value>
ctx.args
.iter()
.position(|a| a == "--model")
.and_then(|i| ctx.args.get(i + 1))
.cloned()
.unwrap_or_else(|| "gpt-4o".to_string())
};
let mcp_port = ctx.mcp_port;
let mcp_base = format!("http://localhost:{mcp_port}/mcp");
let client = Client::new();
let cancelled = Arc::clone(&self.cancelled);
// Step 1: Fetch MCP tool definitions and convert to OpenAI format.
let openai_tools = fetch_and_convert_mcp_tools(&client, &mcp_base).await?;
// Step 2: Build the initial conversation messages.
let system_text = build_system_text(&ctx);
let mut messages: Vec<Value> = vec![
json!({ "role": "system", "content": system_text }),
json!({ "role": "user", "content": ctx.prompt }),
];
let mut total_usage = TokenUsage {
input_tokens: 0,
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
total_cost_usd: 0.0,
};
let emit = |event: AgentEvent| {
super::super::pty::emit_event(
event,
&tx,
&event_log,
log_writer.as_ref().map(|w| w.as_ref()),
);
};
emit(AgentEvent::Status {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
status: "running".to_string(),
});
// Step 3: Conversation loop.
let mut turn = 0u32;
let max_turns = 200; // Safety limit
loop {
if cancelled.load(Ordering::Relaxed) {
emit(AgentEvent::Error {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
message: "Agent was stopped by user".to_string(),
});
return Ok(RuntimeResult {
session_id: None,
token_usage: Some(total_usage),
});
}
turn += 1;
if turn > max_turns {
emit(AgentEvent::Error {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
message: format!("Exceeded maximum turns ({max_turns})"),
});
return Ok(RuntimeResult {
session_id: None,
token_usage: Some(total_usage),
});
}
slog!(
"[openai] Turn {turn} for {}:{}",
ctx.story_id,
ctx.agent_name
);
let mut request_body = json!({
"model": model,
"messages": messages,
"temperature": 0.2,
});
if !openai_tools.is_empty() {
request_body["tools"] = json!(openai_tools);
}
let response = client
.post("https://api.openai.com/v1/chat/completions")
.bearer_auth(&api_key)
.json(&request_body)
.send()
.await
.map_err(|e| format!("OpenAI API request failed: {e}"))?;
let status = response.status();
let body: Value = response
.json()
.await
.map_err(|e| format!("Failed to parse OpenAI API response: {e}"))?;
if !status.is_success() {
let error_msg = body["error"]["message"]
.as_str()
.unwrap_or("Unknown API error");
let err = format!("OpenAI API error ({status}): {error_msg}");
emit(AgentEvent::Error {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
message: err.clone(),
});
return Err(err);
}
// Accumulate token usage.
if let Some(usage) = parse_usage(&body) {
total_usage.input_tokens += usage.input_tokens;
total_usage.output_tokens += usage.output_tokens;
}
// Extract the first choice.
let choice = body["choices"]
.as_array()
.and_then(|c| c.first())
.ok_or_else(|| "No choices in OpenAI response".to_string())?;
let message = &choice["message"];
let content = message["content"].as_str().unwrap_or("");
// Emit any text content.
if !content.is_empty() {
emit(AgentEvent::Output {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
text: content.to_string(),
});
}
// Check for tool calls.
let tool_calls = message["tool_calls"].as_array();
if tool_calls.is_none() || tool_calls.is_some_and(|tc| tc.is_empty()) {
// No tool calls — model is done.
emit(AgentEvent::Done {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
session_id: None,
});
return Ok(RuntimeResult {
session_id: None,
token_usage: Some(total_usage),
});
}
let tool_calls = tool_calls.unwrap();
// Add the assistant message (with tool_calls) to the conversation.
messages.push(message.clone());
// Execute each tool call via MCP and add results.
for tc in tool_calls {
if cancelled.load(Ordering::Relaxed) {
break;
}
let call_id = tc["id"].as_str().unwrap_or("");
let function = &tc["function"];
let tool_name = function["name"].as_str().unwrap_or("");
let arguments_str = function["arguments"].as_str().unwrap_or("{}");
let args: Value = serde_json::from_str(arguments_str).unwrap_or(json!({}));
slog!(
"[openai] Calling MCP tool '{}' for {}:{}",
tool_name,
ctx.story_id,
ctx.agent_name
);
emit(AgentEvent::Output {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
text: format!("\n[Tool call: {tool_name}]\n"),
});
let tool_result = call_mcp_tool(&client, &mcp_base, tool_name, &args).await;
let result_content = match &tool_result {
Ok(result) => {
emit(AgentEvent::Output {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
text: format!("[Tool result: {} chars]\n", result.len()),
});
result.clone()
}
Err(e) => {
emit(AgentEvent::Output {
story_id: ctx.story_id.clone(),
agent_name: ctx.agent_name.clone(),
text: format!("[Tool error: {e}]\n"),
});
format!("Error: {e}")
}
};
// OpenAI expects tool results as role=tool messages with
// the matching tool_call_id.
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": result_content,
}));
}
}
}
fn stop(&self) {
self.cancelled.store(true, Ordering::Relaxed);
}
fn get_status(&self) -> RuntimeStatus {
if self.cancelled.load(Ordering::Relaxed) {
RuntimeStatus::Failed
} else {
RuntimeStatus::Idle
}
}
}
// ── Helper functions ─────────────────────────────────────────────────
/// Build the system message text from the RuntimeContext.
fn build_system_text(ctx: &RuntimeContext) -> String {
ctx.args
.iter()
.position(|a| a == "--append-system-prompt")
.and_then(|i| ctx.args.get(i + 1))
.cloned()
.unwrap_or_else(|| {
format!(
"You are an AI coding agent working on story {}. \
You have access to tools via function calling. \
Use them to complete the task. \
Work in the directory: {}",
ctx.story_id, ctx.cwd
)
})
}
/// Fetch MCP tool definitions from storkit's MCP server and convert
/// them to OpenAI function-calling format.
async fn fetch_and_convert_mcp_tools(
client: &Client,
mcp_base: &str,
) -> Result<Vec<Value>, String> {
let request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list",
"params": {}
});
let response = client
.post(mcp_base)
.json(&request)
.send()
.await
.map_err(|e| format!("Failed to fetch MCP tools: {e}"))?;
let body: Value = response
.json()
.await
.map_err(|e| format!("Failed to parse MCP tools response: {e}"))?;
let tools = body["result"]["tools"]
.as_array()
.ok_or_else(|| "No tools array in MCP response".to_string())?;
let mut openai_tools = Vec::new();
for tool in tools {
let name = tool["name"].as_str().unwrap_or("").to_string();
let description = tool["description"].as_str().unwrap_or("").to_string();
if name.is_empty() {
continue;
}
// OpenAI function calling uses JSON Schema natively for parameters,
// so the MCP inputSchema can be used with minimal cleanup.
let parameters = convert_mcp_schema_to_openai(tool.get("inputSchema"));
openai_tools.push(json!({
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": parameters.unwrap_or_else(|| json!({"type": "object", "properties": {}})),
}
}));
}
slog!(
"[openai] Loaded {} MCP tools as function definitions",
openai_tools.len()
);
Ok(openai_tools)
}
/// Convert an MCP inputSchema (JSON Schema) to OpenAI-compatible
/// function parameters.
///
/// OpenAI uses JSON Schema natively, so less transformation is needed
/// compared to Gemini. We still strip `$schema` to keep payloads clean.
fn convert_mcp_schema_to_openai(schema: Option<&Value>) -> Option<Value> {
let schema = schema?;
let mut result = json!({
"type": "object",
});
if let Some(properties) = schema.get("properties") {
result["properties"] = clean_schema_properties(properties);
} else {
result["properties"] = json!({});
}
if let Some(required) = schema.get("required") {
result["required"] = required.clone();
}
// OpenAI recommends additionalProperties: false for strict mode.
result["additionalProperties"] = json!(false);
Some(result)
}
/// Recursively clean schema properties, removing unsupported keywords.
fn clean_schema_properties(properties: &Value) -> Value {
let Some(obj) = properties.as_object() else {
return properties.clone();
};
let mut cleaned = serde_json::Map::new();
for (key, value) in obj {
let mut prop = value.clone();
if let Some(p) = prop.as_object_mut() {
p.remove("$schema");
// Recursively clean nested object properties.
if let Some(nested_props) = p.get("properties").cloned() {
p.insert(
"properties".to_string(),
clean_schema_properties(&nested_props),
);
}
// Clean items schema for arrays.
if let Some(items) = p.get("items").cloned()
&& let Some(items_obj) = items.as_object()
{
let mut cleaned_items = items_obj.clone();
cleaned_items.remove("$schema");
p.insert("items".to_string(), Value::Object(cleaned_items));
}
}
cleaned.insert(key.clone(), prop);
}
Value::Object(cleaned)
}
/// Call an MCP tool via storkit's MCP server.
async fn call_mcp_tool(
client: &Client,
mcp_base: &str,
tool_name: &str,
args: &Value,
) -> Result<String, String> {
let request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": {
"name": tool_name,
"arguments": args
}
});
let response = client
.post(mcp_base)
.json(&request)
.send()
.await
.map_err(|e| format!("MCP tool call failed: {e}"))?;
let body: Value = response
.json()
.await
.map_err(|e| format!("Failed to parse MCP tool response: {e}"))?;
if let Some(error) = body.get("error") {
let msg = error["message"].as_str().unwrap_or("Unknown MCP error");
return Err(format!("MCP tool '{tool_name}' error: {msg}"));
}
// MCP tools/call returns { result: { content: [{ type: "text", text: "..." }] } }
let content = &body["result"]["content"];
if let Some(arr) = content.as_array() {
let texts: Vec<&str> = arr
.iter()
.filter_map(|c| c["text"].as_str())
.collect();
if !texts.is_empty() {
return Ok(texts.join("\n"));
}
}
// Fall back to serializing the entire result.
Ok(body["result"].to_string())
}
/// Parse token usage from an OpenAI API response.
fn parse_usage(response: &Value) -> Option<TokenUsage> {
let usage = response.get("usage")?;
Some(TokenUsage {
input_tokens: usage
.get("prompt_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0),
output_tokens: usage
.get("completion_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0),
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
// OpenAI API doesn't report cost directly; leave at 0.
total_cost_usd: 0.0,
})
}
// ── Tests ────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn convert_mcp_schema_simple_object() {
let schema = json!({
"type": "object",
"properties": {
"story_id": {
"type": "string",
"description": "Story identifier"
}
},
"required": ["story_id"]
});
let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
assert_eq!(result["type"], "object");
assert!(result["properties"]["story_id"].is_object());
assert_eq!(result["required"][0], "story_id");
assert_eq!(result["additionalProperties"], false);
}
#[test]
fn convert_mcp_schema_empty_properties() {
let schema = json!({
"type": "object",
"properties": {}
});
let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
assert_eq!(result["type"], "object");
assert!(result["properties"].as_object().unwrap().is_empty());
}
#[test]
fn convert_mcp_schema_none_returns_none() {
assert!(convert_mcp_schema_to_openai(None).is_none());
}
#[test]
fn convert_mcp_schema_strips_dollar_schema() {
let schema = json!({
"type": "object",
"properties": {
"name": {
"type": "string",
"$schema": "http://json-schema.org/draft-07/schema#"
}
}
});
let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
let name_prop = &result["properties"]["name"];
assert!(name_prop.get("$schema").is_none());
assert_eq!(name_prop["type"], "string");
}
#[test]
fn convert_mcp_schema_with_nested_objects() {
let schema = json!({
"type": "object",
"properties": {
"config": {
"type": "object",
"properties": {
"key": { "type": "string" }
}
}
}
});
let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
assert!(result["properties"]["config"]["properties"]["key"].is_object());
}
#[test]
fn convert_mcp_schema_with_array_items() {
let schema = json!({
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": { "type": "string" }
},
"$schema": "http://json-schema.org/draft-07/schema#"
}
}
}
});
let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
let items_schema = &result["properties"]["items"]["items"];
assert!(items_schema.get("$schema").is_none());
}
#[test]
fn build_system_text_uses_args() {
let ctx = RuntimeContext {
story_id: "42_story_test".to_string(),
agent_name: "coder-1".to_string(),
command: "gpt-4o".to_string(),
args: vec![
"--append-system-prompt".to_string(),
"Custom system prompt".to_string(),
],
prompt: "Do the thing".to_string(),
cwd: "/tmp/wt".to_string(),
inactivity_timeout_secs: 300,
mcp_port: 3001,
};
assert_eq!(build_system_text(&ctx), "Custom system prompt");
}
#[test]
fn build_system_text_default() {
let ctx = RuntimeContext {
story_id: "42_story_test".to_string(),
agent_name: "coder-1".to_string(),
command: "gpt-4o".to_string(),
args: vec![],
prompt: "Do the thing".to_string(),
cwd: "/tmp/wt".to_string(),
inactivity_timeout_secs: 300,
mcp_port: 3001,
};
let text = build_system_text(&ctx);
assert!(text.contains("42_story_test"));
assert!(text.contains("/tmp/wt"));
}
#[test]
fn parse_usage_valid() {
let response = json!({
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150
}
});
let usage = parse_usage(&response).unwrap();
assert_eq!(usage.input_tokens, 100);
assert_eq!(usage.output_tokens, 50);
assert_eq!(usage.cache_creation_input_tokens, 0);
assert_eq!(usage.total_cost_usd, 0.0);
}
#[test]
fn parse_usage_missing() {
let response = json!({"choices": []});
assert!(parse_usage(&response).is_none());
}
#[test]
fn openai_runtime_stop_sets_cancelled() {
let runtime = OpenAiRuntime::new();
assert_eq!(runtime.get_status(), RuntimeStatus::Idle);
runtime.stop();
assert_eq!(runtime.get_status(), RuntimeStatus::Failed);
}
#[test]
fn model_extraction_from_command_gpt() {
let ctx = RuntimeContext {
story_id: "1".to_string(),
agent_name: "coder".to_string(),
command: "gpt-4o".to_string(),
args: vec![],
prompt: "test".to_string(),
cwd: "/tmp".to_string(),
inactivity_timeout_secs: 300,
mcp_port: 3001,
};
assert!(ctx.command.starts_with("gpt"));
}
#[test]
fn model_extraction_from_command_o3() {
let ctx = RuntimeContext {
story_id: "1".to_string(),
agent_name: "coder".to_string(),
command: "o3".to_string(),
args: vec![],
prompt: "test".to_string(),
cwd: "/tmp".to_string(),
inactivity_timeout_secs: 300,
mcp_port: 3001,
};
assert!(ctx.command.starts_with("o"));
}
}