Files
storkit/server/src/http/agents_sse.rs

209 lines
7.3 KiB
Rust
Raw Normal View History

use crate::http::context::AppContext;
use poem::handler;
use poem::http::StatusCode;
use poem::web::{Data, Path};
use poem::{Body, IntoResponse, Response};
use std::sync::Arc;
/// SSE endpoint: `GET /agents/:story_id/:agent_name/stream`
///
/// Streams `AgentEvent`s as Server-Sent Events. Each event is JSON-encoded
/// with `data:` prefix and double newline terminator per the SSE spec.
///
/// `AgentEvent::Thinking` events are intentionally excluded — thinking traces
/// are internal model state and must never be displayed in the UI.
#[handler]
pub async fn agent_stream(
Path((story_id, agent_name)): Path<(String, String)>,
ctx: Data<&Arc<AppContext>>,
) -> impl IntoResponse {
let mut rx = match ctx.agents.subscribe(&story_id, &agent_name) {
Ok(rx) => rx,
Err(e) => {
return Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from_string(e));
}
};
let stream = async_stream::stream! {
loop {
match rx.recv().await {
Ok(event) => {
// Never forward thinking traces to the UI — they are
// internal model state and must not be displayed.
if matches!(event, crate::agents::AgentEvent::Thinking { .. }) {
continue;
}
if let Ok(json) = serde_json::to_string(&event) {
yield Ok::<_, std::io::Error>(format!("data: {json}\n\n"));
}
// Check for terminal events
match &event {
crate::agents::AgentEvent::Done { .. }
| crate::agents::AgentEvent::Error { .. } => break,
crate::agents::AgentEvent::Status { status, .. }
if status == "stopped" => break,
_ => {}
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
let msg = format!("{{\"type\":\"warning\",\"message\":\"Skipped {n} events\"}}");
yield Ok::<_, std::io::Error>(format!("data: {msg}\n\n"));
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
}
}
};
Response::builder()
.header("Content-Type", "text/event-stream")
.header("Cache-Control", "no-cache")
.header("Connection", "keep-alive")
.body(Body::from_bytes_stream(
futures::StreamExt::map(stream, |r| r.map(bytes::Bytes::from)),
))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agents::{AgentEvent, AgentStatus};
    use crate::http::context::AppContext;
    use poem::{EndpointExt, Route, get};
    use std::sync::Arc;
    use tempfile::tempdir;

    /// Builds a router exposing only the SSE endpoint, with `ctx` attached as
    /// shared data for the handler's `Data` extractor.
    fn test_app(ctx: Arc<AppContext>) -> impl poem::Endpoint {
        let router = Route::new().at("/agents/:story_id/:agent_name/stream", get(agent_stream));
        router.data(ctx)
    }

    /// Creates a test `AppContext` rooted in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned so the caller keeps the directory
    /// alive for the whole test.
    fn fresh_context() -> (tempfile::TempDir, Arc<AppContext>) {
        let dir = tempdir().unwrap();
        let ctx = Arc::new(AppContext::new_test(dir.path().to_path_buf()));
        (dir, ctx)
    }

    /// Sends a GET for `uri` and reads the entire response body as text.
    ///
    /// The body is produced by the handler's event stream, so this only
    /// returns once that stream ends (e.g. after a terminal event).
    async fn collect_body(ctx: Arc<AppContext>, uri: &str) -> String {
        let client = poem::test::TestClient::new(test_app(ctx));
        let response = client.get(uri).send().await;
        response.0.into_body().into_string().await.unwrap()
    }

    #[tokio::test]
    async fn thinking_events_are_not_forwarded_via_sse() {
        let (_dir, ctx) = fresh_context();
        // Register a running agent; `tx` is its event sender.
        let tx = ctx
            .agents
            .inject_test_agent("1_story", "coder-1", AgentStatus::Running);

        // Emit from a separate task after a short delay so the SSE handler
        // has subscribed first: broadcast receivers only see later sends.
        let sender = tx.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
            let script = [
                // Must be filtered out by the handler.
                AgentEvent::Thinking {
                    story_id: "1_story".to_string(),
                    agent_name: "coder-1".to_string(),
                    text: "secret thinking text".to_string(),
                },
                // Must be forwarded.
                AgentEvent::Output {
                    story_id: "1_story".to_string(),
                    agent_name: "coder-1".to_string(),
                    text: "visible output".to_string(),
                },
                // Terminal event: ends the SSE stream.
                AgentEvent::Done {
                    story_id: "1_story".to_string(),
                    agent_name: "coder-1".to_string(),
                    session_id: None,
                },
            ];
            for event in script {
                let _ = sender.send(event);
            }
        });

        let body = collect_body(ctx, "/agents/1_story/coder-1/stream").await;

        // No trace of the thinking event may reach the client.
        assert!(
            !body.contains("secret thinking text"),
            "Thinking text must not be forwarded via SSE: {body}"
        );
        assert!(
            !body.contains("\"type\":\"thinking\""),
            "Thinking event type must not appear in SSE output: {body}"
        );

        // The ordinary output event must still come through.
        assert!(
            body.contains("visible output"),
            "Output event must be forwarded via SSE: {body}"
        );
        assert!(
            body.contains("\"type\":\"output\""),
            "Output event type must appear in SSE output: {body}"
        );
    }

    #[tokio::test]
    async fn output_and_done_events_are_forwarded_via_sse() {
        let (_dir, ctx) = fresh_context();
        let tx = ctx
            .agents
            .inject_test_agent("2_story", "coder-1", AgentStatus::Running);

        let sender = tx.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
            let script = [
                AgentEvent::Output {
                    story_id: "2_story".to_string(),
                    agent_name: "coder-1".to_string(),
                    text: "step 1 output".to_string(),
                },
                AgentEvent::Done {
                    story_id: "2_story".to_string(),
                    agent_name: "coder-1".to_string(),
                    session_id: Some("sess-abc".to_string()),
                },
            ];
            for event in script {
                let _ = sender.send(event);
            }
        });

        let body = collect_body(ctx, "/agents/2_story/coder-1/stream").await;

        assert!(body.contains("step 1 output"), "Output must be forwarded: {body}");
        assert!(body.contains("\"type\":\"done\""), "Done event must be forwarded: {body}");
    }

    #[tokio::test]
    async fn unknown_agent_returns_404() {
        let (_dir, ctx) = fresh_context();
        let client = poem::test::TestClient::new(test_app(ctx));

        let response = client.get("/agents/nonexistent/coder-1/stream").send().await;

        assert_eq!(
            response.0.status(),
            poem::http::StatusCode::NOT_FOUND,
            "Unknown agent must return 404"
        );
    }
}