mod commands; mod llm; mod state; mod store; use crate::commands::{chat, fs}; use crate::llm::types::Message; use crate::state::SessionState; use crate::store::JsonFileStore; use futures::{SinkExt, StreamExt}; use poem::web::websocket::{Message as WsMessage, WebSocket}; use poem::{ EndpointExt, Response, Route, Server, get, handler, http::{StatusCode, header}, listener::TcpListener, web::{Data, Path}, }; use poem_openapi::{Object, OpenApi, OpenApiService, param::Query, payload::Json}; use rust_embed::RustEmbed; use serde::{Deserialize, Serialize}; use std::path::PathBuf; use std::sync::Arc; use tokio::sync::mpsc; #[derive(Clone)] struct AppContext { state: Arc, store: Arc, } #[derive(RustEmbed)] #[folder = "../frontend/dist"] struct EmbeddedAssets; type OpenApiResult = poem::Result; fn bad_request(message: String) -> poem::Error { poem::Error::from_string(message, StatusCode::BAD_REQUEST) } #[handler] fn health() -> &'static str { "ok" } fn serve_embedded(path: &str) -> Response { let normalized = if path.is_empty() { "index.html" } else { path.trim_start_matches('/') }; let is_asset_request = normalized.starts_with("assets/"); let asset = if is_asset_request { EmbeddedAssets::get(normalized) } else { EmbeddedAssets::get(normalized).or_else(|| { if normalized == "index.html" { None } else { EmbeddedAssets::get("index.html") } }) }; match asset { Some(content) => { let body = content.data.into_owned(); let mime = mime_guess::from_path(normalized) .first_or_octet_stream() .to_string(); Response::builder() .status(StatusCode::OK) .header(header::CONTENT_TYPE, mime) .body(body) } None => Response::builder() .status(StatusCode::NOT_FOUND) .body("Not Found"), } } #[handler] fn embedded_asset(Path(path): Path) -> Response { let asset_path = format!("assets/{path}"); serve_embedded(&asset_path) } #[handler] fn embedded_file(Path(path): Path) -> Response { serve_embedded(&path) } #[handler] fn embedded_index() -> Response { serve_embedded("index.html") } 
// NOTE(review): in this copy of the file the generic arguments of `Json`, `OpenApiResult`,
// `Query`, `Vec`, `Arc`, `Result` and `unbounded_channel::` appear to be missing
// (e.g. `Arc,`, `OpenApiResult>>`, `payload: Json)`). Confirm against version control.
//
// Plain `//` comments are used on the OpenAPI items on purpose: poem-openapi turns `///`
// doc comments on operations and `Object` types into spec summaries/descriptions.

// Body with a single `path` field; presumably the `open_project` (`POST /project`)
// payload. Confirm.
#[derive(Deserialize, Object)]
struct PathPayload {
    path: String,
}
// Body carrying the model identifier read by `set_model_preference` (`POST /model`).
#[derive(Deserialize, Object)]
struct ModelPayload {
    model: String,
}
// Body carrying the Anthropic key read by `set_anthropic_api_key`.
#[derive(Deserialize, Object)]
struct ApiKeyPayload {
    api_key: String,
}
// Body with a single `path` field; presumably used by `/fs/read` and `/fs/list`. Confirm.
#[derive(Deserialize, Object)]
struct FilePathPayload {
    path: String,
}
// Body for `write_file`: target path plus the full new file content.
#[derive(Deserialize, Object)]
struct WriteFilePayload {
    path: String,
    content: String,
}
// Body for `search_files`: the search query string.
#[derive(Deserialize, Object)]
struct SearchPayload {
    query: String,
}
// Body for `exec_shell`: a command name plus its argument list, passed to
// `commands::shell::exec_shell` as separate values.
#[derive(Deserialize, Object)]
struct ExecShellPayload {
    command: String,
    args: Vec,
}
// OpenAPI endpoint collection mounted under `/api`. Each handler delegates to a function
// in `commands` and maps its error to `400 Bad Request` via `bad_request`.
struct Api {
    ctx: Arc,
}
#[OpenApi]
impl Api {
    // Returns the currently open project as reported by `fs::get_current_project`.
    #[oai(path = "/project", method = "get")]
    async fn get_current_project(&self) -> OpenApiResult>> {
        let ctx = self.ctx.clone();
        let result = fs::get_current_project(&ctx.state, ctx.store.as_ref()).map_err(bad_request)?;
        Ok(Json(result))
    }
    // Opens the project at `path` and returns the value `fs::open_project` confirms.
    #[oai(path = "/project", method = "post")]
    async fn open_project(&self, payload: Json) -> OpenApiResult> {
        let ctx = self.ctx.clone();
        let confirmed = fs::open_project(payload.0.path, &ctx.state, ctx.store.as_ref())
            .await
            .map_err(bad_request)?;
        Ok(Json(confirmed))
    }
    // Closes the current project; answers `true` on success.
    #[oai(path = "/project", method = "delete")]
    async fn close_project(&self) -> OpenApiResult> {
        let ctx = self.ctx.clone();
        fs::close_project(&ctx.state, ctx.store.as_ref()).map_err(bad_request)?;
        Ok(Json(true))
    }
    // Reads the stored model preference from the store.
    #[oai(path = "/model", method = "get")]
    async fn get_model_preference(&self) -> OpenApiResult>> {
        let ctx = self.ctx.clone();
        let result = fs::get_model_preference(ctx.store.as_ref()).map_err(bad_request)?;
        Ok(Json(result))
    }
    // Persists the model preference; answers `true` on success.
    #[oai(path = "/model", method = "post")]
    async fn set_model_preference(&self, payload: Json) -> OpenApiResult> {
        let ctx = self.ctx.clone();
        fs::set_model_preference(payload.0.model, ctx.store.as_ref()).map_err(bad_request)?;
        Ok(Json(true))
    }
    // Lists models from an Ollama server. `base_url` is an optional query parameter passed
    // straight through to `chat::get_ollama_models` (the default, if any, lives there).
    #[oai(path = "/ollama/models", method = "get")]
    async fn get_ollama_models(
        &self,
        base_url: Query>,
    ) -> OpenApiResult>> {
        let models = chat::get_ollama_models(base_url.0)
            .await
            .map_err(bad_request)?;
        Ok(Json(models))
    }
    // Reports whether an Anthropic API key is present in the store (the key is not returned).
    #[oai(path = "/anthropic/key/exists", method = "get")]
    async fn get_anthropic_api_key_exists(&self) -> OpenApiResult> {
        let ctx = self.ctx.clone();
        let exists = chat::get_anthropic_api_key_exists(ctx.store.as_ref()).map_err(bad_request)?;
        Ok(Json(exists))
    }
    // Stores the Anthropic API key; answers `true` on success.
    #[oai(path = "/anthropic/key", method = "post")]
    async fn set_anthropic_api_key(
        &self,
        payload: Json,
    ) -> OpenApiResult> {
        let ctx = self.ctx.clone();
        chat::set_anthropic_api_key(ctx.store.as_ref(), payload.0.api_key).map_err(bad_request)?;
        Ok(Json(true))
    }
    // Returns the content of `path`; path resolution/validation is done by `fs::read_file`.
    #[oai(path = "/fs/read", method = "post")]
    async fn read_file(&self, payload: Json) -> OpenApiResult> {
        let ctx = self.ctx.clone();
        let content = fs::read_file(payload.0.path, &ctx.state)
            .await
            .map_err(bad_request)?;
        Ok(Json(content))
    }
    // Writes `content` to `path` via `fs::write_file`; answers `true` on success.
    #[oai(path = "/fs/write", method = "post")]
    async fn write_file(&self, payload: Json) -> OpenApiResult> {
        let ctx = self.ctx.clone();
        fs::write_file(payload.0.path, payload.0.content, &ctx.state)
            .await
            .map_err(bad_request)?;
        Ok(Json(true))
    }
    // Lists the entries of the directory at `path`.
    #[oai(path = "/fs/list", method = "post")]
    async fn list_directory(
        &self,
        payload: Json,
    ) -> OpenApiResult>> {
        let ctx = self.ctx.clone();
        let entries = fs::list_directory(payload.0.path, &ctx.state)
            .await
            .map_err(bad_request)?;
        Ok(Json(entries))
    }
    // Runs `commands::search::search_files` with the given query and returns its results.
    #[oai(path = "/fs/search", method = "post")]
    async fn search_files(
        &self,
        payload: Json,
    ) -> OpenApiResult>> {
        let ctx = self.ctx.clone();
        let results = crate::commands::search::search_files(payload.0.query, &ctx.state)
            .await
            .map_err(bad_request)?;
        Ok(Json(results))
    }
    // Executes a client-supplied command through `commands::shell::exec_shell`.
    // NOTE(review): any restriction on which commands may run must live in that function;
    // nothing here filters `command` or `args`.
    #[oai(path = "/shell/exec", method = "post")]
    async fn exec_shell(
        &self,
        payload: Json,
    ) -> OpenApiResult> {
        let ctx = self.ctx.clone();
        let output = crate::commands::shell::exec_shell(payload.0.command, payload.0.args, &ctx.state)
            .await
            .map_err(bad_request)?;
        Ok(Json(output))
    }
    // Requests cancellation of the running chat; answers `true` on success.
    #[oai(path = "/chat/cancel", method = "post")]
    async fn cancel_chat(&self) -> OpenApiResult> {
        let ctx = self.ctx.clone();
        chat::cancel_chat(&ctx.state).map_err(bad_request)?;
        Ok(Json(true))
    }
}
// Client -> server WebSocket messages, internally tagged by `type`:
// `{"type":"chat","messages":[...],"config":{...}}` or `{"type":"cancel"}`.
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum WsRequest {
    Chat {
        messages: Vec,
        config: chat::ProviderConfig,
    },
    Cancel,
}
// Server -> client WebSocket messages, tagged by `type` (`token`, `update`, `error`).
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum WsResponse {
    Token { content: String },
    Update { messages: Vec },
    Error { message: String },
}
// WebSocket endpoint (`/ws`) for streaming chat.
//
// Outgoing messages go through an unbounded channel to a `forward` task that owns the
// socket sink, so the chat callbacks can send synchronously.
#[handler]
async fn ws_handler(ws: WebSocket, ctx: Data<&AppContext>) -> impl poem::IntoResponse {
    let ctx = ctx.0.clone();
    ws.on_upgrade(move |socket| async move {
        let (mut sink, mut stream) = socket.split();
        let (tx, mut rx) = mpsc::unbounded_channel::();
        // Serialize each response and write it to the socket. Responses that fail to
        // serialize are skipped; a failed send (client gone) stops the task.
        let forward = tokio::spawn(async move {
            while let Some(msg) = rx.recv().await {
                if let Ok(text) = serde_json::to_string(&msg)
                    && sink.send(WsMessage::Text(text)).await.is_err()
                {
                    break;
                }
            }
        });
        // Read loop: ends on close or on the first read error. Non-text frames are ignored.
        while let Some(Ok(msg)) = stream.next().await {
            if let WsMessage::Text(text) = msg {
                let parsed: Result = serde_json::from_str(&text);
                match parsed {
                    // NOTE(review): `chat::chat(...)` is awaited inline, so this loop does not
                    // read further socket messages (including a `Cancel`) until the chat call
                    // returns. Mid-stream cancellation presumably relies on `POST /chat/cancel`.
                    // Consider polling the chat future and `stream.next()` concurrently
                    // (e.g. with `tokio::select!`).
                    Ok(WsRequest::Chat { messages, config }) => {
                        let tx_updates = tx.clone();
                        let tx_tokens = tx.clone();
                        let ctx_clone = ctx.clone();
                        let result = chat::chat(
                            messages,
                            config,
                            &ctx_clone.state,
                            ctx_clone.store.as_ref(),
                            // Full history snapshot after each update.
                            |history| {
                                let _ = tx_updates.send(WsResponse::Update {
                                    messages: history.to_vec(),
                                });
                            },
                            // Individual streamed tokens.
                            |token| {
                                let _ = tx_tokens.send(WsResponse::Token {
                                    content: token.to_string(),
                                });
                            },
                        )
                        .await;
                        if let Err(err) = result {
                            let _ = tx.send(WsResponse::Error { message: err });
                        }
                    }
                    // Best effort: a cancel error is not reported back to the client.
                    Ok(WsRequest::Cancel) => {
                        let _ = chat::cancel_chat(&ctx.state);
                    }
                    Err(err) => {
                        let _ = tx.send(WsResponse::Error {
                            message: format!("Invalid request: {err}"),
                        });
                    }
                }
            }
        }
        // Dropping the last sender ends `forward`'s `recv` loop once queued messages are
        // flushed; awaiting it lets those final messages reach the client.
        drop(tx);
        let _ = forward.await;
    })
}
// Entry point: loads the JSON store and serves the API, Swagger UI, WebSocket and the
// embedded frontend on 127.0.0.1:3001.
#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
    let app_state = Arc::new(SessionState::default());
    // Relative path: `store.json` is resolved against the process working directory.
    let store = Arc::new(
        JsonFileStore::from_path(PathBuf::from("store.json")).map_err(std::io::Error::other)?,
    );
    let ctx = AppContext {
        state: app_state,
        store,
    };
    let ctx_arc = Arc::new(ctx.clone());
    // Two identically configured services: one is mounted as the API, and the other is
    // used only to produce the Swagger UI endpoint.
    // NOTE(review): the address 127.0.0.1:3001 is repeated in both `.server(...)` URLs
    // and in the listener below; keep them in sync.
    let api_service = OpenApiService::new(
        Api {
            ctx: ctx_arc.clone(),
        },
        "Living Spec API",
        "1.0",
    )
    .server("http://127.0.0.1:3001/api");
    let docs_service = OpenApiService::new(
        Api {
            ctx: ctx_arc.clone(),
        },
        "Living Spec API",
        "1.0",
    )
    .server("http://127.0.0.1:3001/api");
    // `/assets/*path` only serves embedded assets (404 on miss). The `/*path` catch-all
    // goes through `serve_embedded`, which falls back to `index.html` for unknown
    // non-asset paths.
    let app = Route::new()
        .nest("/api", api_service)
        .nest("/docs", docs_service.swagger_ui())
        .at("/ws", get(ws_handler))
        .at("/health", get(health))
        .at("/assets/*path", get(embedded_asset))
        .at("/", get(embedded_index))
        .at("/*path", get(embedded_file))
        .data(ctx);
    Server::new(TcpListener::bind("127.0.0.1:3001"))
        .run(app)
        .await
}