storkit: merge 344_story_chatgpt_agent_backend_via_openai_api
This commit is contained in:
@@ -17,7 +17,7 @@ use super::{
|
||||
AgentEvent, AgentInfo, AgentStatus, CompletionReport, PipelineStage, agent_config_stage,
|
||||
pipeline_stage,
|
||||
};
|
||||
use super::runtime::{AgentRuntime, ClaudeCodeRuntime, GeminiRuntime, RuntimeContext};
|
||||
use super::runtime::{AgentRuntime, ClaudeCodeRuntime, GeminiRuntime, OpenAiRuntime, RuntimeContext};
|
||||
|
||||
/// Build the composite key used to track agents in the pool.
|
||||
fn composite_key(story_id: &str, agent_name: &str) -> String {
|
||||
@@ -553,9 +553,25 @@ impl AgentPool {
|
||||
.start(ctx, tx_clone.clone(), log_clone.clone(), log_writer_clone)
|
||||
.await
|
||||
}
|
||||
"openai" => {
|
||||
let runtime = OpenAiRuntime::new();
|
||||
let ctx = RuntimeContext {
|
||||
story_id: sid.clone(),
|
||||
agent_name: aname.clone(),
|
||||
command,
|
||||
args,
|
||||
prompt,
|
||||
cwd: wt_path_str,
|
||||
inactivity_timeout_secs,
|
||||
mcp_port: port_for_task,
|
||||
};
|
||||
runtime
|
||||
.start(ctx, tx_clone.clone(), log_clone.clone(), log_writer_clone)
|
||||
.await
|
||||
}
|
||||
other => Err(format!(
|
||||
"Unknown agent runtime '{other}'; check the 'runtime' field in project.toml. \
|
||||
Supported: 'claude-code', 'gemini'"
|
||||
Supported: 'claude-code', 'gemini', 'openai'"
|
||||
)),
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
mod claude_code;
|
||||
mod gemini;
|
||||
mod openai;
|
||||
|
||||
pub use claude_code::ClaudeCodeRuntime;
|
||||
pub use gemini::GeminiRuntime;
|
||||
pub use openai::OpenAiRuntime;
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
704
server/src/agents/runtime/openai.rs
Normal file
704
server/src/agents/runtime/openai.rs
Normal file
@@ -0,0 +1,704 @@
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use reqwest::Client;
|
||||
use serde_json::{json, Value};
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use crate::agent_log::AgentLogWriter;
|
||||
use crate::slog;
|
||||
|
||||
use super::super::{AgentEvent, TokenUsage};
|
||||
use super::{AgentRuntime, RuntimeContext, RuntimeResult, RuntimeStatus};
|
||||
|
||||
// ── Public runtime struct ────────────────────────────────────────────
|
||||
|
||||
/// Agent runtime that drives an OpenAI model (GPT-4o, o3, etc.) through
|
||||
/// the OpenAI Chat Completions API.
|
||||
///
|
||||
/// The runtime:
|
||||
/// 1. Fetches MCP tool definitions from storkit's MCP server.
|
||||
/// 2. Converts them to OpenAI function-calling format.
|
||||
/// 3. Sends the agent prompt + tools to the Chat Completions API.
|
||||
/// 4. Executes any requested tool calls via MCP `tools/call`.
|
||||
/// 5. Loops until the model produces a response with no tool calls.
|
||||
/// 6. Tracks token usage from the API response.
|
||||
pub struct OpenAiRuntime {
|
||||
/// Whether a stop has been requested.
|
||||
cancelled: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl OpenAiRuntime {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
cancelled: Arc::new(AtomicBool::new(false)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentRuntime for OpenAiRuntime {
|
||||
async fn start(
|
||||
&self,
|
||||
ctx: RuntimeContext,
|
||||
tx: broadcast::Sender<AgentEvent>,
|
||||
event_log: Arc<Mutex<Vec<AgentEvent>>>,
|
||||
log_writer: Option<Arc<Mutex<AgentLogWriter>>>,
|
||||
) -> Result<RuntimeResult, String> {
|
||||
let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
|
||||
"OPENAI_API_KEY environment variable is not set. \
|
||||
Set it to your OpenAI API key to use the OpenAI runtime."
|
||||
.to_string()
|
||||
})?;
|
||||
|
||||
let model = if ctx.command.starts_with("gpt") || ctx.command.starts_with("o") {
|
||||
// The pool puts the model into `command` for non-CLI runtimes.
|
||||
ctx.command.clone()
|
||||
} else {
|
||||
// Fall back to args: look for --model <value>
|
||||
ctx.args
|
||||
.iter()
|
||||
.position(|a| a == "--model")
|
||||
.and_then(|i| ctx.args.get(i + 1))
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "gpt-4o".to_string())
|
||||
};
|
||||
|
||||
let mcp_port = ctx.mcp_port;
|
||||
let mcp_base = format!("http://localhost:{mcp_port}/mcp");
|
||||
|
||||
let client = Client::new();
|
||||
let cancelled = Arc::clone(&self.cancelled);
|
||||
|
||||
// Step 1: Fetch MCP tool definitions and convert to OpenAI format.
|
||||
let openai_tools = fetch_and_convert_mcp_tools(&client, &mcp_base).await?;
|
||||
|
||||
// Step 2: Build the initial conversation messages.
|
||||
let system_text = build_system_text(&ctx);
|
||||
let mut messages: Vec<Value> = vec![
|
||||
json!({ "role": "system", "content": system_text }),
|
||||
json!({ "role": "user", "content": ctx.prompt }),
|
||||
];
|
||||
|
||||
let mut total_usage = TokenUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
total_cost_usd: 0.0,
|
||||
};
|
||||
|
||||
let emit = |event: AgentEvent| {
|
||||
super::super::pty::emit_event(
|
||||
event,
|
||||
&tx,
|
||||
&event_log,
|
||||
log_writer.as_ref().map(|w| w.as_ref()),
|
||||
);
|
||||
};
|
||||
|
||||
emit(AgentEvent::Status {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
status: "running".to_string(),
|
||||
});
|
||||
|
||||
// Step 3: Conversation loop.
|
||||
let mut turn = 0u32;
|
||||
let max_turns = 200; // Safety limit
|
||||
|
||||
loop {
|
||||
if cancelled.load(Ordering::Relaxed) {
|
||||
emit(AgentEvent::Error {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
message: "Agent was stopped by user".to_string(),
|
||||
});
|
||||
return Ok(RuntimeResult {
|
||||
session_id: None,
|
||||
token_usage: Some(total_usage),
|
||||
});
|
||||
}
|
||||
|
||||
turn += 1;
|
||||
if turn > max_turns {
|
||||
emit(AgentEvent::Error {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
message: format!("Exceeded maximum turns ({max_turns})"),
|
||||
});
|
||||
return Ok(RuntimeResult {
|
||||
session_id: None,
|
||||
token_usage: Some(total_usage),
|
||||
});
|
||||
}
|
||||
|
||||
slog!(
|
||||
"[openai] Turn {turn} for {}:{}",
|
||||
ctx.story_id,
|
||||
ctx.agent_name
|
||||
);
|
||||
|
||||
let mut request_body = json!({
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": 0.2,
|
||||
});
|
||||
|
||||
if !openai_tools.is_empty() {
|
||||
request_body["tools"] = json!(openai_tools);
|
||||
}
|
||||
|
||||
let response = client
|
||||
.post("https://api.openai.com/v1/chat/completions")
|
||||
.bearer_auth(&api_key)
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("OpenAI API request failed: {e}"))?;
|
||||
|
||||
let status = response.status();
|
||||
let body: Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse OpenAI API response: {e}"))?;
|
||||
|
||||
if !status.is_success() {
|
||||
let error_msg = body["error"]["message"]
|
||||
.as_str()
|
||||
.unwrap_or("Unknown API error");
|
||||
let err = format!("OpenAI API error ({status}): {error_msg}");
|
||||
emit(AgentEvent::Error {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
message: err.clone(),
|
||||
});
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
// Accumulate token usage.
|
||||
if let Some(usage) = parse_usage(&body) {
|
||||
total_usage.input_tokens += usage.input_tokens;
|
||||
total_usage.output_tokens += usage.output_tokens;
|
||||
}
|
||||
|
||||
// Extract the first choice.
|
||||
let choice = body["choices"]
|
||||
.as_array()
|
||||
.and_then(|c| c.first())
|
||||
.ok_or_else(|| "No choices in OpenAI response".to_string())?;
|
||||
|
||||
let message = &choice["message"];
|
||||
let content = message["content"].as_str().unwrap_or("");
|
||||
|
||||
// Emit any text content.
|
||||
if !content.is_empty() {
|
||||
emit(AgentEvent::Output {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
text: content.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check for tool calls.
|
||||
let tool_calls = message["tool_calls"].as_array();
|
||||
|
||||
if tool_calls.is_none() || tool_calls.is_some_and(|tc| tc.is_empty()) {
|
||||
// No tool calls — model is done.
|
||||
emit(AgentEvent::Done {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
session_id: None,
|
||||
});
|
||||
return Ok(RuntimeResult {
|
||||
session_id: None,
|
||||
token_usage: Some(total_usage),
|
||||
});
|
||||
}
|
||||
|
||||
let tool_calls = tool_calls.unwrap();
|
||||
|
||||
// Add the assistant message (with tool_calls) to the conversation.
|
||||
messages.push(message.clone());
|
||||
|
||||
// Execute each tool call via MCP and add results.
|
||||
for tc in tool_calls {
|
||||
if cancelled.load(Ordering::Relaxed) {
|
||||
break;
|
||||
}
|
||||
|
||||
let call_id = tc["id"].as_str().unwrap_or("");
|
||||
let function = &tc["function"];
|
||||
let tool_name = function["name"].as_str().unwrap_or("");
|
||||
let arguments_str = function["arguments"].as_str().unwrap_or("{}");
|
||||
|
||||
let args: Value = serde_json::from_str(arguments_str).unwrap_or(json!({}));
|
||||
|
||||
slog!(
|
||||
"[openai] Calling MCP tool '{}' for {}:{}",
|
||||
tool_name,
|
||||
ctx.story_id,
|
||||
ctx.agent_name
|
||||
);
|
||||
|
||||
emit(AgentEvent::Output {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
text: format!("\n[Tool call: {tool_name}]\n"),
|
||||
});
|
||||
|
||||
let tool_result = call_mcp_tool(&client, &mcp_base, tool_name, &args).await;
|
||||
|
||||
let result_content = match &tool_result {
|
||||
Ok(result) => {
|
||||
emit(AgentEvent::Output {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
text: format!("[Tool result: {} chars]\n", result.len()),
|
||||
});
|
||||
result.clone()
|
||||
}
|
||||
Err(e) => {
|
||||
emit(AgentEvent::Output {
|
||||
story_id: ctx.story_id.clone(),
|
||||
agent_name: ctx.agent_name.clone(),
|
||||
text: format!("[Tool error: {e}]\n"),
|
||||
});
|
||||
format!("Error: {e}")
|
||||
}
|
||||
};
|
||||
|
||||
// OpenAI expects tool results as role=tool messages with
|
||||
// the matching tool_call_id.
|
||||
messages.push(json!({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": result_content,
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn stop(&self) {
|
||||
self.cancelled.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn get_status(&self) -> RuntimeStatus {
|
||||
if self.cancelled.load(Ordering::Relaxed) {
|
||||
RuntimeStatus::Failed
|
||||
} else {
|
||||
RuntimeStatus::Idle
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helper functions ─────────────────────────────────────────────────
|
||||
|
||||
/// Build the system message text from the RuntimeContext.
|
||||
fn build_system_text(ctx: &RuntimeContext) -> String {
|
||||
ctx.args
|
||||
.iter()
|
||||
.position(|a| a == "--append-system-prompt")
|
||||
.and_then(|i| ctx.args.get(i + 1))
|
||||
.cloned()
|
||||
.unwrap_or_else(|| {
|
||||
format!(
|
||||
"You are an AI coding agent working on story {}. \
|
||||
You have access to tools via function calling. \
|
||||
Use them to complete the task. \
|
||||
Work in the directory: {}",
|
||||
ctx.story_id, ctx.cwd
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Fetch MCP tool definitions from storkit's MCP server and convert
|
||||
/// them to OpenAI function-calling format.
|
||||
async fn fetch_and_convert_mcp_tools(
|
||||
client: &Client,
|
||||
mcp_base: &str,
|
||||
) -> Result<Vec<Value>, String> {
|
||||
let request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
});
|
||||
|
||||
let response = client
|
||||
.post(mcp_base)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to fetch MCP tools: {e}"))?;
|
||||
|
||||
let body: Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse MCP tools response: {e}"))?;
|
||||
|
||||
let tools = body["result"]["tools"]
|
||||
.as_array()
|
||||
.ok_or_else(|| "No tools array in MCP response".to_string())?;
|
||||
|
||||
let mut openai_tools = Vec::new();
|
||||
|
||||
for tool in tools {
|
||||
let name = tool["name"].as_str().unwrap_or("").to_string();
|
||||
let description = tool["description"].as_str().unwrap_or("").to_string();
|
||||
|
||||
if name.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// OpenAI function calling uses JSON Schema natively for parameters,
|
||||
// so the MCP inputSchema can be used with minimal cleanup.
|
||||
let parameters = convert_mcp_schema_to_openai(tool.get("inputSchema"));
|
||||
|
||||
openai_tools.push(json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": parameters.unwrap_or_else(|| json!({"type": "object", "properties": {}})),
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
slog!(
|
||||
"[openai] Loaded {} MCP tools as function definitions",
|
||||
openai_tools.len()
|
||||
);
|
||||
Ok(openai_tools)
|
||||
}
|
||||
|
||||
/// Convert an MCP inputSchema (JSON Schema) to OpenAI-compatible
|
||||
/// function parameters.
|
||||
///
|
||||
/// OpenAI uses JSON Schema natively, so less transformation is needed
|
||||
/// compared to Gemini. We still strip `$schema` to keep payloads clean.
|
||||
fn convert_mcp_schema_to_openai(schema: Option<&Value>) -> Option<Value> {
|
||||
let schema = schema?;
|
||||
|
||||
let mut result = json!({
|
||||
"type": "object",
|
||||
});
|
||||
|
||||
if let Some(properties) = schema.get("properties") {
|
||||
result["properties"] = clean_schema_properties(properties);
|
||||
} else {
|
||||
result["properties"] = json!({});
|
||||
}
|
||||
|
||||
if let Some(required) = schema.get("required") {
|
||||
result["required"] = required.clone();
|
||||
}
|
||||
|
||||
// OpenAI recommends additionalProperties: false for strict mode.
|
||||
result["additionalProperties"] = json!(false);
|
||||
|
||||
Some(result)
|
||||
}
|
||||
|
||||
/// Recursively clean schema properties, removing unsupported keywords.
|
||||
fn clean_schema_properties(properties: &Value) -> Value {
|
||||
let Some(obj) = properties.as_object() else {
|
||||
return properties.clone();
|
||||
};
|
||||
|
||||
let mut cleaned = serde_json::Map::new();
|
||||
for (key, value) in obj {
|
||||
let mut prop = value.clone();
|
||||
if let Some(p) = prop.as_object_mut() {
|
||||
p.remove("$schema");
|
||||
|
||||
// Recursively clean nested object properties.
|
||||
if let Some(nested_props) = p.get("properties").cloned() {
|
||||
p.insert(
|
||||
"properties".to_string(),
|
||||
clean_schema_properties(&nested_props),
|
||||
);
|
||||
}
|
||||
|
||||
// Clean items schema for arrays.
|
||||
if let Some(items) = p.get("items").cloned()
|
||||
&& let Some(items_obj) = items.as_object()
|
||||
{
|
||||
let mut cleaned_items = items_obj.clone();
|
||||
cleaned_items.remove("$schema");
|
||||
p.insert("items".to_string(), Value::Object(cleaned_items));
|
||||
}
|
||||
}
|
||||
cleaned.insert(key.clone(), prop);
|
||||
}
|
||||
Value::Object(cleaned)
|
||||
}
|
||||
|
||||
/// Call an MCP tool via storkit's MCP server.
|
||||
async fn call_mcp_tool(
|
||||
client: &Client,
|
||||
mcp_base: &str,
|
||||
tool_name: &str,
|
||||
args: &Value,
|
||||
) -> Result<String, String> {
|
||||
let request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": args
|
||||
}
|
||||
});
|
||||
|
||||
let response = client
|
||||
.post(mcp_base)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("MCP tool call failed: {e}"))?;
|
||||
|
||||
let body: Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse MCP tool response: {e}"))?;
|
||||
|
||||
if let Some(error) = body.get("error") {
|
||||
let msg = error["message"].as_str().unwrap_or("Unknown MCP error");
|
||||
return Err(format!("MCP tool '{tool_name}' error: {msg}"));
|
||||
}
|
||||
|
||||
// MCP tools/call returns { result: { content: [{ type: "text", text: "..." }] } }
|
||||
let content = &body["result"]["content"];
|
||||
if let Some(arr) = content.as_array() {
|
||||
let texts: Vec<&str> = arr
|
||||
.iter()
|
||||
.filter_map(|c| c["text"].as_str())
|
||||
.collect();
|
||||
if !texts.is_empty() {
|
||||
return Ok(texts.join("\n"));
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to serializing the entire result.
|
||||
Ok(body["result"].to_string())
|
||||
}
|
||||
|
||||
/// Parse token usage from an OpenAI API response.
|
||||
fn parse_usage(response: &Value) -> Option<TokenUsage> {
|
||||
let usage = response.get("usage")?;
|
||||
Some(TokenUsage {
|
||||
input_tokens: usage
|
||||
.get("prompt_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
output_tokens: usage
|
||||
.get("completion_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
// OpenAI API doesn't report cost directly; leave at 0.
|
||||
total_cost_usd: 0.0,
|
||||
})
|
||||
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn convert_mcp_schema_simple_object() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"story_id": {
|
||||
"type": "string",
|
||||
"description": "Story identifier"
|
||||
}
|
||||
},
|
||||
"required": ["story_id"]
|
||||
});
|
||||
|
||||
let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
|
||||
assert_eq!(result["type"], "object");
|
||||
assert!(result["properties"]["story_id"].is_object());
|
||||
assert_eq!(result["required"][0], "story_id");
|
||||
assert_eq!(result["additionalProperties"], false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_mcp_schema_empty_properties() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
});
|
||||
|
||||
let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
|
||||
assert_eq!(result["type"], "object");
|
||||
assert!(result["properties"].as_object().unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_mcp_schema_none_returns_none() {
|
||||
assert!(convert_mcp_schema_to_openai(None).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_mcp_schema_strips_dollar_schema() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"$schema": "http://json-schema.org/draft-07/schema#"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
|
||||
let name_prop = &result["properties"]["name"];
|
||||
assert!(name_prop.get("$schema").is_none());
|
||||
assert_eq!(name_prop["type"], "string");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_mcp_schema_with_nested_objects() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": { "type": "string" }
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
|
||||
assert!(result["properties"]["config"]["properties"]["key"].is_object());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_mcp_schema_with_array_items() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" }
|
||||
},
|
||||
"$schema": "http://json-schema.org/draft-07/schema#"
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let result = convert_mcp_schema_to_openai(Some(&schema)).unwrap();
|
||||
let items_schema = &result["properties"]["items"]["items"];
|
||||
assert!(items_schema.get("$schema").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_system_text_uses_args() {
|
||||
let ctx = RuntimeContext {
|
||||
story_id: "42_story_test".to_string(),
|
||||
agent_name: "coder-1".to_string(),
|
||||
command: "gpt-4o".to_string(),
|
||||
args: vec![
|
||||
"--append-system-prompt".to_string(),
|
||||
"Custom system prompt".to_string(),
|
||||
],
|
||||
prompt: "Do the thing".to_string(),
|
||||
cwd: "/tmp/wt".to_string(),
|
||||
inactivity_timeout_secs: 300,
|
||||
mcp_port: 3001,
|
||||
};
|
||||
|
||||
assert_eq!(build_system_text(&ctx), "Custom system prompt");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_system_text_default() {
|
||||
let ctx = RuntimeContext {
|
||||
story_id: "42_story_test".to_string(),
|
||||
agent_name: "coder-1".to_string(),
|
||||
command: "gpt-4o".to_string(),
|
||||
args: vec![],
|
||||
prompt: "Do the thing".to_string(),
|
||||
cwd: "/tmp/wt".to_string(),
|
||||
inactivity_timeout_secs: 300,
|
||||
mcp_port: 3001,
|
||||
};
|
||||
|
||||
let text = build_system_text(&ctx);
|
||||
assert!(text.contains("42_story_test"));
|
||||
assert!(text.contains("/tmp/wt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_usage_valid() {
|
||||
let response = json!({
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150
|
||||
}
|
||||
});
|
||||
|
||||
let usage = parse_usage(&response).unwrap();
|
||||
assert_eq!(usage.input_tokens, 100);
|
||||
assert_eq!(usage.output_tokens, 50);
|
||||
assert_eq!(usage.cache_creation_input_tokens, 0);
|
||||
assert_eq!(usage.total_cost_usd, 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_usage_missing() {
|
||||
let response = json!({"choices": []});
|
||||
assert!(parse_usage(&response).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn openai_runtime_stop_sets_cancelled() {
|
||||
let runtime = OpenAiRuntime::new();
|
||||
assert_eq!(runtime.get_status(), RuntimeStatus::Idle);
|
||||
runtime.stop();
|
||||
assert_eq!(runtime.get_status(), RuntimeStatus::Failed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn model_extraction_from_command_gpt() {
|
||||
let ctx = RuntimeContext {
|
||||
story_id: "1".to_string(),
|
||||
agent_name: "coder".to_string(),
|
||||
command: "gpt-4o".to_string(),
|
||||
args: vec![],
|
||||
prompt: "test".to_string(),
|
||||
cwd: "/tmp".to_string(),
|
||||
inactivity_timeout_secs: 300,
|
||||
mcp_port: 3001,
|
||||
};
|
||||
assert!(ctx.command.starts_with("gpt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn model_extraction_from_command_o3() {
|
||||
let ctx = RuntimeContext {
|
||||
story_id: "1".to_string(),
|
||||
agent_name: "coder".to_string(),
|
||||
command: "o3".to_string(),
|
||||
args: vec![],
|
||||
prompt: "test".to_string(),
|
||||
cwd: "/tmp".to_string(),
|
||||
inactivity_timeout_secs: 300,
|
||||
mcp_port: 3001,
|
||||
};
|
||||
assert!(ctx.command.starts_with("o"));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user