2026-02-16 16:24:21 +00:00
|
|
|
use crate::http::context::AppContext;
|
|
|
|
|
use crate::llm::chat;
|
|
|
|
|
use crate::llm::types::Message;
|
|
|
|
|
use futures::{SinkExt, StreamExt};
|
|
|
|
|
use poem::handler;
|
|
|
|
|
use poem::web::Data;
|
|
|
|
|
use poem::web::websocket::{Message as WsMessage, WebSocket};
|
|
|
|
|
use serde::{Deserialize, Serialize};
|
2026-02-16 18:57:39 +00:00
|
|
|
use std::sync::Arc;
|
2026-02-16 16:24:21 +00:00
|
|
|
use tokio::sync::mpsc;
|
|
|
|
|
|
|
|
|
|
#[derive(Deserialize)]
|
|
|
|
|
#[serde(tag = "type", rename_all = "snake_case")]
|
2026-02-16 16:55:59 +00:00
|
|
|
/// WebSocket request messages sent by the client.
|
|
|
|
|
///
|
|
|
|
|
/// - `chat` starts a streaming chat session.
|
|
|
|
|
/// - `cancel` stops the active session.
|
2026-02-16 16:24:21 +00:00
|
|
|
enum WsRequest {
|
|
|
|
|
Chat {
|
|
|
|
|
messages: Vec<Message>,
|
|
|
|
|
config: chat::ProviderConfig,
|
|
|
|
|
},
|
|
|
|
|
Cancel,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Serialize)]
|
|
|
|
|
#[serde(tag = "type", rename_all = "snake_case")]
|
2026-02-16 16:55:59 +00:00
|
|
|
/// WebSocket response messages sent by the server.
|
|
|
|
|
///
|
|
|
|
|
/// - `token` streams partial model output.
|
|
|
|
|
/// - `update` pushes the updated message history.
|
|
|
|
|
/// - `error` reports a request or processing failure.
|
2026-02-16 16:24:21 +00:00
|
|
|
enum WsResponse {
|
|
|
|
|
Token { content: String },
|
|
|
|
|
Update { messages: Vec<Message> },
|
2026-02-20 11:51:19 +00:00
|
|
|
/// Session ID for Claude Code conversation resumption.
|
|
|
|
|
SessionId { session_id: String },
|
2026-02-16 16:24:21 +00:00
|
|
|
Error { message: String },
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[handler]
|
2026-02-16 16:55:59 +00:00
|
|
|
/// WebSocket endpoint for streaming chat responses and cancellation.
|
|
|
|
|
///
|
|
|
|
|
/// Accepts JSON `WsRequest` messages and streams `WsResponse` messages.
|
2026-02-16 18:57:39 +00:00
|
|
|
pub async fn ws_handler(ws: WebSocket, ctx: Data<&Arc<AppContext>>) -> impl poem::IntoResponse {
|
2026-02-16 16:24:21 +00:00
|
|
|
let ctx = ctx.0.clone();
|
|
|
|
|
ws.on_upgrade(move |socket| async move {
|
|
|
|
|
let (mut sink, mut stream) = socket.split();
|
|
|
|
|
let (tx, mut rx) = mpsc::unbounded_channel::<WsResponse>();
|
|
|
|
|
|
|
|
|
|
let forward = tokio::spawn(async move {
|
|
|
|
|
while let Some(msg) = rx.recv().await {
|
|
|
|
|
if let Ok(text) = serde_json::to_string(&msg)
|
|
|
|
|
&& sink.send(WsMessage::Text(text)).await.is_err()
|
|
|
|
|
{
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
while let Some(Ok(msg)) = stream.next().await {
|
|
|
|
|
if let WsMessage::Text(text) = msg {
|
|
|
|
|
let parsed: Result<WsRequest, _> = serde_json::from_str(&text);
|
|
|
|
|
match parsed {
|
|
|
|
|
Ok(WsRequest::Chat { messages, config }) => {
|
|
|
|
|
let tx_updates = tx.clone();
|
|
|
|
|
let tx_tokens = tx.clone();
|
|
|
|
|
let ctx_clone = ctx.clone();
|
|
|
|
|
|
|
|
|
|
let result = chat::chat(
|
|
|
|
|
messages,
|
|
|
|
|
config,
|
|
|
|
|
&ctx_clone.state,
|
|
|
|
|
ctx_clone.store.as_ref(),
|
|
|
|
|
|history| {
|
|
|
|
|
let _ = tx_updates.send(WsResponse::Update {
|
|
|
|
|
messages: history.to_vec(),
|
|
|
|
|
});
|
|
|
|
|
},
|
|
|
|
|
|token| {
|
|
|
|
|
let _ = tx_tokens.send(WsResponse::Token {
|
|
|
|
|
content: token.to_string(),
|
|
|
|
|
});
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
.await;
|
|
|
|
|
|
2026-02-20 11:51:19 +00:00
|
|
|
match result {
|
|
|
|
|
Ok(chat_result) => {
|
|
|
|
|
if let Some(sid) = chat_result.session_id {
|
|
|
|
|
let _ = tx.send(WsResponse::SessionId { session_id: sid });
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Err(err) => {
|
|
|
|
|
let _ = tx.send(WsResponse::Error { message: err });
|
|
|
|
|
}
|
2026-02-16 16:24:21 +00:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Ok(WsRequest::Cancel) => {
|
|
|
|
|
let _ = chat::cancel_chat(&ctx.state);
|
|
|
|
|
}
|
|
|
|
|
Err(err) => {
|
|
|
|
|
let _ = tx.send(WsResponse::Error {
|
|
|
|
|
message: format!("Invalid request: {err}"),
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
drop(tx);
|
|
|
|
|
let _ = forward.await;
|
|
|
|
|
})
|
|
|
|
|
}
|