use crate::http::context::AppContext;
use crate::llm::chat;
use crate::llm::types::Message;
use futures::{SinkExt, StreamExt};
use poem::handler;
use poem::web::Data;
use poem::web::websocket::{Message as WsMessage, WebSocket};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;

/// A frame sent by the client over the chat WebSocket.
///
/// Frames are JSON objects whose `type` field (snake_case) selects the
/// variant: `"chat"` starts a streaming chat session and `"cancel"` stops
/// the active one.
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum WsRequest {
    /// Start a streaming chat over `messages` using the provider `config`.
    Chat {
        messages: Vec<Message>,
        config: chat::ProviderConfig,
    },
    /// Stop the active chat session.
    Cancel,
}

/// A frame pushed by the server over the chat WebSocket.
///
/// Serialized as a JSON object tagged by `type` (snake_case): `"token"`
/// carries partial model output, `"update"` carries the updated message
/// history, and `"error"` reports a request or processing failure.
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum WsResponse {
    /// A chunk of partial model output.
    Token { content: String },
    /// The updated message history.
    Update { messages: Vec<Message> },
    /// A human-readable description of a failure.
    Error { message: String },
}

#[handler]
/// WebSocket endpoint that streams chat output and accepts cancellation.
///
/// After the upgrade, text frames from the client are decoded as JSON
/// `WsRequest` values and replies are sent back as JSON `WsResponse` frames.
pub async fn ws_handler(ws: WebSocket, ctx: Data<&AppContext>) -> impl poem::IntoResponse { let ctx = ctx.0.clone(); ws.on_upgrade(move |socket| async move { let (mut sink, mut stream) = socket.split(); let (tx, mut rx) = mpsc::unbounded_channel::(); let forward = tokio::spawn(async move { while let Some(msg) = rx.recv().await { if let Ok(text) = serde_json::to_string(&msg) && sink.send(WsMessage::Text(text)).await.is_err() { break; } } }); while let Some(Ok(msg)) = stream.next().await { if let WsMessage::Text(text) = msg { let parsed: Result = serde_json::from_str(&text); match parsed { Ok(WsRequest::Chat { messages, config }) => { let tx_updates = tx.clone(); let tx_tokens = tx.clone(); let ctx_clone = ctx.clone(); let result = chat::chat( messages, config, &ctx_clone.state, ctx_clone.store.as_ref(), |history| { let _ = tx_updates.send(WsResponse::Update { messages: history.to_vec(), }); }, |token| { let _ = tx_tokens.send(WsResponse::Token { content: token.to_string(), }); }, ) .await; if let Err(err) = result { let _ = tx.send(WsResponse::Error { message: err }); } } Ok(WsRequest::Cancel) => { let _ = chat::cancel_chat(&ctx.state); } Err(err) => { let _ = tx.send(WsResponse::Error { message: format!("Invalid request: {err}"), }); } } } } drop(tx); let _ = forward.await; }) }