Story 13: Implement Stop button with backend cancellation
- Add tokio watch channel for cancellation signaling - Implement cancel_chat command - Add cancellation checks in streaming loop and before tool execution - Stop button (■) replaces Send button (↑) during generation - Preserve partial streaming content when cancelled - Clean UX: no error messages on cancellation - Backend properly stops streaming and prevents tool execution Closes Story 13
This commit is contained in:
@@ -30,6 +30,16 @@ pub async fn chat(
|
||||
config: ProviderConfig,
|
||||
state: State<'_, SessionState>,
|
||||
) -> Result<Vec<Message>, String> {
|
||||
// Reset cancel flag at start of new request
|
||||
let _ = state.cancel_tx.send(false);
|
||||
|
||||
// Get a clone of the cancellation receiver
|
||||
let mut cancel_rx = state.cancel_rx.clone();
|
||||
|
||||
// Mark the receiver as having seen the current (false) value
|
||||
// This prevents changed() from firing immediately due to stale state
|
||||
cancel_rx.borrow_and_update();
|
||||
|
||||
// 1. Setup Provider
|
||||
let base_url = config
|
||||
.base_url
|
||||
@@ -79,6 +89,11 @@ pub async fn chat(
|
||||
let mut turn_count = 0;
|
||||
|
||||
loop {
|
||||
// Check for cancellation at start of loop
|
||||
if *cancel_rx.borrow() {
|
||||
return Err("Chat cancelled by user".to_string());
|
||||
}
|
||||
|
||||
if turn_count >= MAX_TURNS {
|
||||
return Err("Max conversation turns reached.".to_string());
|
||||
}
|
||||
@@ -86,7 +101,7 @@ pub async fn chat(
|
||||
|
||||
// Call LLM with streaming
|
||||
let response = provider
|
||||
.chat_stream(&app, &config.model, ¤t_history, tools)
|
||||
.chat_stream(&app, &config.model, ¤t_history, tools, &mut cancel_rx)
|
||||
.await
|
||||
.map_err(|e| format!("LLM Error: {}", e))?;
|
||||
|
||||
@@ -108,6 +123,11 @@ pub async fn chat(
|
||||
|
||||
// Execute Tools
|
||||
for call in tool_calls {
|
||||
// Check for cancellation before executing each tool
|
||||
if *cancel_rx.borrow() {
|
||||
return Err("Chat cancelled before tool execution".to_string());
|
||||
}
|
||||
|
||||
let output = execute_tool(&call, &state).await;
|
||||
|
||||
let tool_msg = Message {
|
||||
@@ -289,3 +309,9 @@ fn get_tool_definitions() -> Vec<ToolDefinition> {
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn cancel_chat(state: State<'_, SessionState>) -> Result<(), String> {
|
||||
state.cancel_tx.send(true).map_err(|e| e.to_string())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -23,7 +23,8 @@ pub fn run() {
|
||||
commands::search::search_files,
|
||||
commands::shell::exec_shell,
|
||||
commands::chat::chat,
|
||||
commands::chat::get_ollama_models
|
||||
commands::chat::get_ollama_models,
|
||||
commands::chat::cancel_chat
|
||||
])
|
||||
.run(tauri::generate_context!())
|
||||
.expect("error while running tauri application");
|
||||
|
||||
@@ -47,6 +47,7 @@ impl OllamaProvider {
|
||||
model: &str,
|
||||
messages: &[Message],
|
||||
tools: &[ToolDefinition],
|
||||
cancel_rx: &mut tokio::sync::watch::Receiver<bool>,
|
||||
) -> Result<CompletionResponse, String> {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("{}/api/chat", self.base_url.trim_end_matches('/'));
|
||||
@@ -108,7 +109,29 @@ impl OllamaProvider {
|
||||
let mut accumulated_content = String::new();
|
||||
let mut final_tool_calls: Option<Vec<ToolCall>> = None;
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
loop {
|
||||
// Check for cancellation
|
||||
if *cancel_rx.borrow() {
|
||||
return Err("Chat cancelled by user".to_string());
|
||||
}
|
||||
|
||||
let chunk_result = tokio::select! {
|
||||
chunk = stream.next() => {
|
||||
match chunk {
|
||||
Some(c) => c,
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
_ = cancel_rx.changed() => {
|
||||
// changed() fires on any change, check if it's actually true
|
||||
if *cancel_rx.borrow() {
|
||||
return Err("Chat cancelled by user".to_string());
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let chunk = chunk_result.map_err(|e| format!("Stream error: {}", e))?;
|
||||
buffer.push_str(&String::from_utf8_lossy(&chunk));
|
||||
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Mutex;
|
||||
use tokio::sync::watch;
|
||||
|
||||
pub struct SessionState {
|
||||
pub project_root: Mutex<Option<PathBuf>>,
|
||||
pub cancel_tx: watch::Sender<bool>,
|
||||
pub cancel_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl Default for SessionState {
|
||||
fn default() -> Self {
|
||||
let (cancel_tx, cancel_rx) = watch::channel(false);
|
||||
Self {
|
||||
project_root: Mutex::new(None),
|
||||
cancel_tx,
|
||||
cancel_rx,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user