huskies: merge 554_story_multi_project_gateway_that_proxies_mcp_calls_to_per_project_docker_containers

This commit is contained in:
dave
2026-04-13 13:02:41 +00:00
parent 5806156af3
commit 69dab063a8
2 changed files with 719 additions and 1 deletions
+672
View File
@@ -0,0 +1,672 @@
//! Multi-project gateway — proxies MCP calls to per-project Docker containers.
//!
//! When `huskies --gateway` is used, the server starts in gateway mode: it reads
//! a `projects.toml` config that maps project names to container URLs, maintains
//! an "active project" selection, and proxies all MCP tool calls to the active
//! project's container. Gateway-specific tools allow switching projects, querying
//! status, and aggregating health checks across all registered projects.
use poem::EndpointExt;
use poem::handler;
use poem::http::StatusCode;
use poem::web::Data;
use poem::{Body, Request, Response};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::BTreeMap;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::RwLock;
// ── Config ───────────────────────────────────────────────────────────
/// A single project entry in `projects.toml`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ProjectEntry {
/// Base URL of the project's huskies container (e.g. `http://localhost:3001`).
pub url: String,
}
/// Top-level `projects.toml` config.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GatewayConfig {
/// Map of project name → container URL.
#[serde(default)]
pub projects: BTreeMap<String, ProjectEntry>,
}
impl GatewayConfig {
/// Load gateway config from a `projects.toml` file.
pub fn load(path: &Path) -> Result<Self, String> {
let contents = std::fs::read_to_string(path)
.map_err(|e| format!("cannot read {}: {e}", path.display()))?;
toml::from_str(&contents)
.map_err(|e| format!("invalid projects.toml: {e}"))
}
}
// ── Gateway state ────────────────────────────────────────────────────
/// Shared gateway state threaded through HTTP handlers.
#[derive(Clone)]
pub struct GatewayState {
/// The parsed gateway config with all registered projects.
pub config: GatewayConfig,
/// The currently active project name.
pub active_project: Arc<RwLock<String>>,
/// HTTP client for proxying requests to project containers.
pub client: Client,
}
impl GatewayState {
/// Create a new gateway state from a config. The first project in the config
/// becomes the active project by default.
pub fn new(config: GatewayConfig) -> Result<Self, String> {
if config.projects.is_empty() {
return Err("projects.toml must define at least one project".to_string());
}
let first = config.projects.keys().next().unwrap().clone();
Ok(Self {
config,
active_project: Arc::new(RwLock::new(first)),
client: Client::new(),
})
}
/// Get the URL of the currently active project.
async fn active_url(&self) -> Result<String, String> {
let name = self.active_project.read().await.clone();
self.config
.projects
.get(&name)
.map(|p| p.url.clone())
.ok_or_else(|| format!("active project '{name}' not found in config"))
}
}
// ── MCP proxy handler ────────────────────────────────────────────────
/// JSON-RPC request (duplicated here to keep the gateway self-contained).
#[derive(Deserialize)]
struct JsonRpcRequest {
jsonrpc: String,
id: Option<Value>,
method: String,
#[serde(default)]
params: Value,
}
/// JSON-RPC response.
#[derive(Serialize)]
struct JsonRpcResponse {
jsonrpc: &'static str,
#[serde(skip_serializing_if = "Option::is_none")]
id: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
result: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<JsonRpcError>,
}
#[derive(Serialize)]
struct JsonRpcError {
code: i64,
message: String,
}
impl JsonRpcResponse {
fn success(id: Option<Value>, result: Value) -> Self {
Self { jsonrpc: "2.0", id, result: Some(result), error: None }
}
fn error(id: Option<Value>, code: i64, message: String) -> Self {
Self { jsonrpc: "2.0", id, result: None, error: Some(JsonRpcError { code, message }) }
}
}
fn to_json_response(resp: JsonRpcResponse) -> Response {
let body = serde_json::to_vec(&resp).unwrap_or_default();
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(Body::from(body))
}
/// Gateway-specific MCP tools exposed alongside the proxied tools.
const GATEWAY_TOOLS: &[&str] = &["switch_project", "gateway_status", "gateway_health"];
/// Main MCP POST handler for the gateway. Intercepts gateway-specific tools and
/// proxies everything else to the active project's container.
#[handler]
pub async fn gateway_mcp_post_handler(
req: &Request,
body: Body,
state: Data<&Arc<GatewayState>>,
) -> Response {
let content_type = req.header("content-type").unwrap_or("");
if !content_type.is_empty() && !content_type.contains("application/json") {
return to_json_response(JsonRpcResponse::error(
None, -32700, "Unsupported Content-Type; expected application/json".into(),
));
}
let bytes = match body.into_bytes().await {
Ok(b) => b,
Err(_) => return to_json_response(JsonRpcResponse::error(None, -32700, "Parse error".into())),
};
let rpc: JsonRpcRequest = match serde_json::from_slice(&bytes) {
Ok(r) => r,
Err(_) => return to_json_response(JsonRpcResponse::error(None, -32700, "Parse error".into())),
};
if rpc.jsonrpc != "2.0" {
return to_json_response(JsonRpcResponse::error(rpc.id, -32600, "Invalid JSON-RPC version".into()));
}
// Accept notifications silently.
if rpc.id.is_none() || rpc.id.as_ref() == Some(&Value::Null) {
if rpc.method.starts_with("notifications/") {
return Response::builder()
.status(StatusCode::ACCEPTED)
.body(Body::empty());
}
return to_json_response(JsonRpcResponse::error(None, -32600, "Missing id".into()));
}
match rpc.method.as_str() {
"initialize" => to_json_response(handle_initialize(rpc.id)),
"tools/list" => {
// Merge gateway tools with proxied tools from the active project.
match merge_tools_list(&state, rpc.id.clone()).await {
Ok(resp) => to_json_response(resp),
Err(e) => to_json_response(JsonRpcResponse::error(rpc.id, -32603, e)),
}
}
"tools/call" => {
let tool_name = rpc.params
.get("name")
.and_then(|v| v.as_str())
.unwrap_or("");
if GATEWAY_TOOLS.contains(&tool_name) {
to_json_response(handle_gateway_tool(tool_name, &rpc.params, &state).await)
} else {
// Proxy to active project's container.
match proxy_mcp_call(&state, &bytes).await {
Ok(resp_body) => Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(Body::from(resp_body)),
Err(e) => to_json_response(JsonRpcResponse::error(
rpc.id, -32603, format!("proxy error: {e}"),
)),
}
}
}
_ => {
// Proxy unknown methods too.
match proxy_mcp_call(&state, &bytes).await {
Ok(resp_body) => Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(Body::from(resp_body)),
Err(e) => to_json_response(JsonRpcResponse::error(
rpc.id, -32603, format!("proxy error: {e}"),
)),
}
}
}
}
/// GET handler — method not allowed (matches the regular MCP endpoint behavior).
#[handler]
pub async fn gateway_mcp_get_handler() -> Response {
Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.body(Body::empty())
}
// ── Protocol handlers ────────────────────────────────────────────────
fn handle_initialize(id: Option<Value>) -> JsonRpcResponse {
JsonRpcResponse::success(
id,
json!({
"protocolVersion": "2025-03-26",
"capabilities": { "tools": {} },
"serverInfo": {
"name": "huskies-gateway",
"version": "1.0.0"
}
}),
)
}
/// Gateway tool definitions.
fn gateway_tool_definitions() -> Vec<Value> {
vec![
json!({
"name": "switch_project",
"description": "Switch the active project. All subsequent MCP tool calls will be proxied to this project's container.",
"inputSchema": {
"type": "object",
"properties": {
"project": {
"type": "string",
"description": "Name of the project to switch to (must exist in projects.toml)"
}
},
"required": ["project"]
}
}),
json!({
"name": "gateway_status",
"description": "Show pipeline status for the active project by proxying the get_pipeline_status tool call.",
"inputSchema": {
"type": "object",
"properties": {}
}
}),
json!({
"name": "gateway_health",
"description": "Health check aggregation across all registered projects. Returns the health status of every project container.",
"inputSchema": {
"type": "object",
"properties": {}
}
}),
]
}
/// Fetch tools/list from the active project and merge in gateway tools.
async fn merge_tools_list(
state: &GatewayState,
id: Option<Value>,
) -> Result<JsonRpcResponse, String> {
let url = state.active_url().await?;
let mcp_url = format!("{}/mcp", url.trim_end_matches('/'));
let rpc_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list",
"params": {}
});
let resp = state.client
.post(&mcp_url)
.json(&rpc_body)
.send()
.await
.map_err(|e| format!("failed to reach {mcp_url}: {e}"))?;
let resp_json: Value = resp.json().await
.map_err(|e| format!("invalid JSON from upstream: {e}"))?;
let mut tools: Vec<Value> = resp_json
.get("result")
.and_then(|r| r.get("tools"))
.and_then(|t| t.as_array())
.cloned()
.unwrap_or_default();
// Prepend gateway-specific tools.
let mut all_tools = gateway_tool_definitions();
all_tools.append(&mut tools);
Ok(JsonRpcResponse::success(id, json!({ "tools": all_tools })))
}
/// Proxy a raw MCP request body to the active project's container.
async fn proxy_mcp_call(
state: &GatewayState,
request_bytes: &[u8],
) -> Result<Vec<u8>, String> {
let url = state.active_url().await?;
let mcp_url = format!("{}/mcp", url.trim_end_matches('/'));
let resp = state.client
.post(&mcp_url)
.header("Content-Type", "application/json")
.body(request_bytes.to_vec())
.send()
.await
.map_err(|e| format!("failed to reach {mcp_url}: {e}"))?;
resp.bytes()
.await
.map(|b| b.to_vec())
.map_err(|e| format!("failed to read response from {mcp_url}: {e}"))
}
// ── Gateway-specific tools ───────────────────────────────────────────
/// Dispatch a gateway-specific tool call.
async fn handle_gateway_tool(
tool_name: &str,
params: &Value,
state: &GatewayState,
) -> JsonRpcResponse {
let id = None; // The caller wraps this in a proper response.
match tool_name {
"switch_project" => handle_switch_project(params, state).await,
"gateway_status" => handle_gateway_status(state).await,
"gateway_health" => handle_gateway_health(state).await,
_ => JsonRpcResponse::error(id, -32601, format!("Unknown gateway tool: {tool_name}")),
}
}
/// Switch the active project.
async fn handle_switch_project(params: &Value, state: &GatewayState) -> JsonRpcResponse {
let project = params
.get("arguments")
.and_then(|a| a.get("project"))
.or_else(|| params.get("project"))
.and_then(|v| v.as_str())
.unwrap_or("");
if project.is_empty() {
return JsonRpcResponse::error(None, -32602, "missing required parameter: project".into());
}
if !state.config.projects.contains_key(project) {
let available: Vec<&str> = state.config.projects.keys().map(|s| s.as_str()).collect();
return JsonRpcResponse::error(
None, -32602,
format!("unknown project '{project}'. Available: {}", available.join(", ")),
);
}
*state.active_project.write().await = project.to_string();
let url = &state.config.projects[project].url;
JsonRpcResponse::success(
None,
json!({
"content": [{
"type": "text",
"text": format!("Switched to project '{project}' ({})", url)
}]
}),
)
}
/// Show pipeline status for the active project by proxying `get_pipeline_status`.
async fn handle_gateway_status(state: &GatewayState) -> JsonRpcResponse {
let active = state.active_project.read().await.clone();
let url = match state.active_url().await {
Ok(u) => u,
Err(e) => return JsonRpcResponse::error(None, -32603, e),
};
let mcp_url = format!("{}/mcp", url.trim_end_matches('/'));
let rpc_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": {
"name": "get_pipeline_status",
"arguments": {}
}
});
match state.client.post(&mcp_url).json(&rpc_body).send().await {
Ok(resp) => {
match resp.json::<Value>().await {
Ok(upstream) => {
// Extract the result from the upstream response and wrap it.
let pipeline = upstream.get("result").cloned().unwrap_or(json!(null));
JsonRpcResponse::success(
None,
json!({
"content": [{
"type": "text",
"text": format!(
"Pipeline status for '{active}':\n{}",
serde_json::to_string_pretty(&pipeline).unwrap_or_default()
)
}]
}),
)
}
Err(e) => JsonRpcResponse::error(None, -32603, format!("invalid upstream response: {e}")),
}
}
Err(e) => JsonRpcResponse::error(None, -32603, format!("failed to reach {mcp_url}: {e}")),
}
}
/// Aggregate health checks across all registered projects.
async fn handle_gateway_health(state: &GatewayState) -> JsonRpcResponse {
let mut results = BTreeMap::new();
for (name, entry) in &state.config.projects {
let health_url = format!("{}/health", entry.url.trim_end_matches('/'));
let status = match state.client.get(&health_url).send().await {
Ok(resp) => {
if resp.status().is_success() {
"healthy".to_string()
} else {
format!("unhealthy (HTTP {})", resp.status().as_u16())
}
}
Err(e) => format!("unreachable: {e}"),
};
results.insert(name.clone(), status);
}
let active = state.active_project.read().await.clone();
JsonRpcResponse::success(
None,
json!({
"content": [{
"type": "text",
"text": format!(
"Health check (active: '{active}'):\n{}",
results.iter()
.map(|(name, status)| format!(" {name}: {status}"))
.collect::<Vec<_>>()
.join("\n")
)
}]
}),
)
}
// ── Health aggregation endpoint ──────────────────────────────────────
/// HTTP GET `/health` handler for the gateway — aggregates health from all projects.
#[handler]
pub async fn gateway_health_handler(state: Data<&Arc<GatewayState>>) -> Response {
let mut all_healthy = true;
let mut statuses = BTreeMap::new();
for (name, entry) in &state.config.projects {
let health_url = format!("{}/health", entry.url.trim_end_matches('/'));
let healthy = match state.client.get(&health_url).send().await {
Ok(resp) => resp.status().is_success(),
Err(_) => false,
};
if !healthy {
all_healthy = false;
}
statuses.insert(name.clone(), if healthy { "ok" } else { "error" });
}
let body = json!({
"status": if all_healthy { "ok" } else { "degraded" },
"projects": statuses,
});
let status = if all_healthy { StatusCode::OK } else { StatusCode::SERVICE_UNAVAILABLE };
Response::builder()
.status(status)
.header("Content-Type", "application/json")
.body(Body::from(serde_json::to_vec(&body).unwrap_or_default()))
}
// ── Gateway server startup ───────────────────────────────────────────
/// Start the gateway HTTP server. This is the entry point when `--gateway` is used.
pub async fn run(config_path: &Path, port: u16) -> Result<(), std::io::Error> {
let config = GatewayConfig::load(config_path).map_err(std::io::Error::other)?;
let state = GatewayState::new(config).map_err(std::io::Error::other)?;
let state_arc = Arc::new(state);
let active = state_arc.active_project.read().await.clone();
crate::slog!("[gateway] Starting gateway on port {port}, active project: {active}");
crate::slog!(
"[gateway] Registered projects: {}",
state_arc.config.projects.keys().cloned().collect::<Vec<_>>().join(", ")
);
let route = poem::Route::new()
.at(
"/mcp",
poem::post(gateway_mcp_post_handler).get(gateway_mcp_get_handler),
)
.at("/health", poem::get(gateway_health_handler))
.data(state_arc);
let host = std::env::var("HUSKIES_HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
let addr = format!("{host}:{port}");
crate::slog!("[gateway] Listening on {addr}");
poem::Server::new(poem::listener::TcpListener::bind(&addr))
.run(route)
.await
}
// ── Tests ────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_valid_projects_toml() {
let toml_str = r#"
[projects.huskies]
url = "http://localhost:3001"
[projects.robot-studio]
url = "http://localhost:3002"
"#;
let config: GatewayConfig = toml::from_str(toml_str).unwrap();
assert_eq!(config.projects.len(), 2);
assert_eq!(config.projects["huskies"].url, "http://localhost:3001");
assert_eq!(config.projects["robot-studio"].url, "http://localhost:3002");
}
#[test]
fn parse_empty_projects_toml() {
let toml_str = "[projects]\n";
let config: GatewayConfig = toml::from_str(toml_str).unwrap();
assert!(config.projects.is_empty());
}
#[test]
fn gateway_state_rejects_empty_config() {
let config = GatewayConfig { projects: BTreeMap::new() };
assert!(GatewayState::new(config).is_err());
}
#[test]
fn gateway_state_sets_first_project_active() {
let mut projects = BTreeMap::new();
projects.insert("alpha".into(), ProjectEntry { url: "http://a:3001".into() });
projects.insert("beta".into(), ProjectEntry { url: "http://b:3002".into() });
let config = GatewayConfig { projects };
let state = GatewayState::new(config).unwrap();
let active = state.active_project.blocking_read().clone();
assert_eq!(active, "alpha"); // BTreeMap sorts alphabetically.
}
#[test]
fn gateway_tool_definitions_has_expected_tools() {
let defs = gateway_tool_definitions();
let names: Vec<&str> = defs.iter()
.filter_map(|d| d.get("name").and_then(|n| n.as_str()))
.collect();
assert!(names.contains(&"switch_project"));
assert!(names.contains(&"gateway_status"));
assert!(names.contains(&"gateway_health"));
}
#[tokio::test]
async fn switch_project_to_known_project() {
let mut projects = BTreeMap::new();
projects.insert("alpha".into(), ProjectEntry { url: "http://a:3001".into() });
projects.insert("beta".into(), ProjectEntry { url: "http://b:3002".into() });
let config = GatewayConfig { projects };
let state = GatewayState::new(config).unwrap();
let params = json!({ "arguments": { "project": "beta" } });
let resp = handle_switch_project(&params, &state).await;
assert!(resp.result.is_some());
let active = state.active_project.read().await.clone();
assert_eq!(active, "beta");
}
#[tokio::test]
async fn switch_project_to_unknown_project_fails() {
let mut projects = BTreeMap::new();
projects.insert("alpha".into(), ProjectEntry { url: "http://a:3001".into() });
let config = GatewayConfig { projects };
let state = GatewayState::new(config).unwrap();
let params = json!({ "arguments": { "project": "nonexistent" } });
let resp = handle_switch_project(&params, &state).await;
assert!(resp.error.is_some());
}
#[tokio::test]
async fn active_url_returns_correct_url() {
let mut projects = BTreeMap::new();
projects.insert("myproj".into(), ProjectEntry { url: "http://my:3001".into() });
let config = GatewayConfig { projects };
let state = GatewayState::new(config).unwrap();
let url = state.active_url().await.unwrap();
assert_eq!(url, "http://my:3001");
}
#[test]
fn json_rpc_response_success_serializes() {
let resp = JsonRpcResponse::success(Some(json!(1)), json!({"ok": true}));
let s = serde_json::to_string(&resp).unwrap();
assert!(s.contains("\"result\""));
assert!(!s.contains("\"error\""));
}
#[test]
fn json_rpc_response_error_serializes() {
let resp = JsonRpcResponse::error(Some(json!(1)), -32600, "bad".into());
let s = serde_json::to_string(&resp).unwrap();
assert!(s.contains("\"error\""));
assert!(!s.contains("\"result\""));
}
#[test]
fn load_config_from_file() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("projects.toml");
std::fs::write(&path, r#"
[projects.test]
url = "http://localhost:9999"
"#).unwrap();
let config = GatewayConfig::load(&path).unwrap();
assert_eq!(config.projects.len(), 1);
assert_eq!(config.projects["test"].url, "http://localhost:9999");
}
#[test]
fn load_config_missing_file_fails() {
let result = GatewayConfig::load(Path::new("/nonexistent/projects.toml"));
assert!(result.is_err());
}
}