diff --git a/.story_kit/stories/current/43_unified_chat_ui_for_claude_code_and_regular_chat.md b/.story_kit/stories/current/43_unified_chat_ui_for_claude_code_and_regular_chat.md index 35643ed..12043e7 100644 --- a/.story_kit/stories/current/43_unified_chat_ui_for_claude_code_and_regular_chat.md +++ b/.story_kit/stories/current/43_unified_chat_ui_for_claude_code_and_regular_chat.md @@ -1,6 +1,6 @@ --- name: Unified Chat UI for Claude Code and Regular Chat -test_plan: pending +test_plan: approved --- # Story 43: Unified Chat UI for Claude Code and Regular Chat diff --git a/frontend/src/components/Chat.test.tsx b/frontend/src/components/Chat.test.tsx index 6493e59..87ce0c9 100644 --- a/frontend/src/components/Chat.test.tsx +++ b/frontend/src/components/Chat.test.tsx @@ -1,12 +1,23 @@ -import { render, screen, waitFor } from "@testing-library/react"; +import { act, render, screen, waitFor } from "@testing-library/react"; import userEvent from "@testing-library/user-event"; import { beforeEach, describe, expect, it, vi } from "vitest"; import { api } from "../api/client"; import type { ReviewStory } from "../api/workflow"; import { workflowApi } from "../api/workflow"; +import type { Message } from "../types"; import { Chat } from "./Chat"; +// Module-level store for the WebSocket handlers captured during connect(). +// Tests in the "message rendering" suite use this to simulate incoming messages. +type WsHandlers = { + onToken: (content: string) => void; + onUpdate: (history: Message[]) => void; + onSessionId: (sessionId: string) => void; + onError: (message: string) => void; +}; +let capturedWsHandlers: WsHandlers | null = null; + vi.mock("../api/client", () => { const api = { getOllamaModels: vi.fn(), @@ -18,7 +29,9 @@ vi.mock("../api/client", () => { setAnthropicApiKey: vi.fn(), }; class ChatWebSocket { - connect() {} + connect(handlers: WsHandlers) { + capturedWsHandlers = handlers; + } close() {} sendChat() {} cancel() {} @@ -609,3 +622,151 @@ describe("Chat review panel", () => { expect(mockedApi.getAnthropicModels).not.toHaveBeenCalled(); }); }); + +describe("Chat message rendering — unified tool call UI", () => { + beforeEach(() => { + capturedWsHandlers = null; + + mockedApi.getOllamaModels.mockResolvedValue(["llama3.1"]); + mockedApi.getAnthropicApiKeyExists.mockResolvedValue(true); + mockedApi.getAnthropicModels.mockResolvedValue([]); + mockedApi.getModelPreference.mockResolvedValue(null); + mockedApi.setModelPreference.mockResolvedValue(true); + mockedApi.cancelChat.mockResolvedValue(true); + + mockedWorkflow.getAcceptance.mockResolvedValue({ + can_accept: true, + reasons: [], + warning: null, + summary: { total: 0, passed: 0, failed: 0 }, + missing_categories: [], + }); + mockedWorkflow.getReviewQueueAll.mockResolvedValue({ stories: [] }); + mockedWorkflow.ensureAcceptance.mockResolvedValue(true); + mockedWorkflow.getStoryTodos.mockResolvedValue({ stories: [] }); + mockedWorkflow.getUpcomingStories.mockResolvedValue({ stories: [] }); + }); + + it("renders tool call badge for assistant message with tool_calls (AC3)", async () => { + render(); + + await waitFor(() => expect(capturedWsHandlers).not.toBeNull()); + + const messages: Message[] = [ + { role: "user", content: "Read src/main.rs" }, + { + role: "assistant", + content: "I'll read that file.", + tool_calls: [ + { + id: "toolu_abc", + type: "function", + function: { + name: "Read", + arguments: '{"file_path":"src/main.rs"}', + }, + }, + ], + }, + ]; + + act(() => { + capturedWsHandlers?.onUpdate(messages); + }); + + expect(await screen.findByText("I'll read that file.")).toBeInTheDocument(); + // Tool call badge should appear showing the function name + expect(await screen.findByText(/Read/)).toBeInTheDocument(); + }); + + it("renders collapsible tool output for tool role messages (AC3)", async () => { + render(); + + await waitFor(() => expect(capturedWsHandlers).not.toBeNull()); + + const messages: Message[] = [ + { role: "user", content: "Check the file" }, + { + role: "assistant", + content: "", + tool_calls: [ + { + id: "toolu_1", + type: "function", + function: { name: "Read", arguments: '{"file_path":"foo.rs"}' }, + }, + ], + }, + { + role: "tool", + content: 'fn main() { println!("hello"); }', + tool_call_id: "toolu_1", + }, + { role: "assistant", content: "The file contains a main function." }, + ]; + + act(() => { + capturedWsHandlers?.onUpdate(messages); + }); + + // Tool output section should be collapsible + expect(await screen.findByText(/Tool Output/)).toBeInTheDocument(); + expect( + await screen.findByText("The file contains a main function."), + ).toBeInTheDocument(); + }); + + it("renders plain assistant message without tool call badges (AC5)", async () => { + render(); + + await waitFor(() => expect(capturedWsHandlers).not.toBeNull()); + + const messages: Message[] = [ + { role: "user", content: "Hello" }, + { role: "assistant", content: "Hi there! How can I help?" }, + ]; + + act(() => { + capturedWsHandlers?.onUpdate(messages); + }); + + expect( + await screen.findByText("Hi there! How can I help?"), + ).toBeInTheDocument(); + // No tool call badges should appear + expect(screen.queryByText(/Tool Output/)).toBeNull(); + }); + + it("renders multiple tool calls in a single assistant turn (AC3)", async () => { + render(); + + await waitFor(() => expect(capturedWsHandlers).not.toBeNull()); + + const messages: Message[] = [ + { role: "user", content: "Do some work" }, + { + role: "assistant", + content: "I'll do multiple things.", + tool_calls: [ + { + id: "id1", + type: "function", + function: { name: "Bash", arguments: '{"command":"cargo test"}' }, + }, + { + id: "id2", + type: "function", + function: { name: "Read", arguments: '{"file_path":"Cargo.toml"}' }, + }, + ], + }, + ]; + + act(() => { + capturedWsHandlers?.onUpdate(messages); + }); + + expect(await screen.findByText(/Bash/)).toBeInTheDocument(); + expect(await screen.findByText(/Read/)).toBeInTheDocument(); + }); +}); diff --git a/server/src/llm/chat.rs b/server/src/llm/chat.rs index 243c40e..5605587 100644 --- a/server/src/llm/chat.rs +++ b/server/src/llm/chat.rs @@ -1,4 +1,5 @@ use crate::llm::prompts::SYSTEM_PROMPT; +use crate::llm::providers::claude_code::ClaudeCodeResult; use crate::llm::types::{Message, Role, ToolCall, ToolDefinition, ToolFunctionDefinition}; use crate::state::SessionState; use crate::store::StoreOps; @@ -209,7 +210,9 @@ where // Claude Code provider: bypasses our tool loop entirely. // Claude Code has its own agent loop, tools, and context management. - // We just pipe the user message in and stream raw output back. + // We pipe the user message in, stream text tokens for live display, and + // collect the structured messages (assistant turns + tool results) from + // the stream-json output for the final message history. if is_claude_code { use crate::llm::providers::claude_code::ClaudeCodeProvider; @@ -225,7 +228,10 @@ where .unwrap_or_else(|_| std::path::PathBuf::from(".")); let provider = ClaudeCodeProvider::new(); - let response = provider + let ClaudeCodeResult { + messages: cc_messages, + session_id, + } = provider .chat_stream( &user_message, &project_root.to_string_lossy(), @@ -236,19 +242,24 @@ where .await .map_err(|e| format!("Claude Code Error: {e}"))?; - let assistant_msg = Message { - role: Role::Assistant, - content: response.content.unwrap_or_default(), - tool_calls: None, - tool_call_id: None, - }; - + // Build the final message history: user messages + Claude Code's turns. + // If the session produced no structured messages (e.g. empty response), + // fall back to an empty assistant message so the UI stops loading. let mut result = messages.clone(); - result.push(assistant_msg); + if cc_messages.is_empty() { + result.push(Message { + role: Role::Assistant, + content: String::new(), + tool_calls: None, + tool_call_id: None, + }); + } else { + result.extend(cc_messages); + } on_update(&result); return Ok(ChatResult { messages: result, - session_id: response.session_id, + session_id, }); } diff --git a/server/src/llm/providers/claude_code.rs b/server/src/llm/providers/claude_code.rs index 9eb81d5..5dee6ce 100644 --- a/server/src/llm/providers/claude_code.rs +++ b/server/src/llm/providers/claude_code.rs @@ -4,7 +4,16 @@ use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use tokio::sync::watch; -use crate::llm::types::CompletionResponse; +use crate::llm::types::{FunctionCall, Message, Role, ToolCall}; + +/// Result from a Claude Code session containing structured messages. +pub struct ClaudeCodeResult { + /// The conversation messages produced by Claude Code, including assistant + /// turns (with optional tool_calls) and tool result turns. + pub messages: Vec, + /// Session ID for conversation resumption on subsequent requests. + pub session_id: Option, +} /// Manages a Claude Code session via a pseudo-terminal. /// @@ -29,7 +38,7 @@ impl ClaudeCodeProvider { session_id: Option<&str>, cancel_rx: &mut watch::Receiver, mut on_token: F, - ) -> Result + ) -> Result where F: FnMut(&str) + Send, { @@ -50,15 +59,22 @@ impl ClaudeCodeProvider { }); let (token_tx, mut token_rx) = tokio::sync::mpsc::unbounded_channel::(); + let (msg_tx, msg_rx) = std::sync::mpsc::channel::(); let (sid_tx, sid_rx) = tokio::sync::oneshot::channel::(); let pty_handle = tokio::task::spawn_blocking(move || { - run_pty_session(&message, &cwd, resume_id.as_deref(), cancelled, token_tx, sid_tx) + run_pty_session( + &message, + &cwd, + resume_id.as_deref(), + cancelled, + token_tx, + msg_tx, + sid_tx, + ) }); - let mut full_output = String::new(); while let Some(token) = token_rx.recv().await { - full_output.push_str(&token); on_token(&token); } @@ -67,10 +83,10 @@ impl ClaudeCodeProvider { .map_err(|e| format!("PTY task panicked: {e}"))??; let captured_session_id = sid_rx.await.ok(); + let structured_messages: Vec = msg_rx.try_iter().collect(); - Ok(CompletionResponse { - content: Some(full_output), - tool_calls: None, + Ok(ClaudeCodeResult { + messages: structured_messages, session_id: captured_session_id, }) } @@ -80,12 +96,17 @@ impl ClaudeCodeProvider { /// /// The PTY makes isatty() return true. The `-p` flag gives us /// single-shot non-interactive mode with structured output. +/// +/// Sends streaming text tokens via `token_tx` for real-time display, and +/// complete structured `Message` values via `msg_tx` for the final message +/// history (assistant turns with tool_calls, and tool result turns). fn run_pty_session( user_message: &str, cwd: &str, resume_session_id: Option<&str>, cancelled: Arc, token_tx: tokio::sync::mpsc::UnboundedSender, + msg_tx: std::sync::mpsc::Sender, sid_tx: tokio::sync::oneshot::Sender, ) -> Result<(), String> { let pty_system = native_pty_system(); @@ -203,55 +224,35 @@ fn run_pty_session( } match event_type { - // Streaming deltas (when --include-partial-messages is used) + // Streaming deltas — used for real-time text display only "stream_event" => { if let Some(event) = json.get("event") { handle_stream_event(event, &token_tx); } } - // Complete assistant message + // Complete assistant message — extract text and tool_use blocks "assistant" => { if let Some(message) = json.get("message") && let Some(content) = message.get("content").and_then(|c| c.as_array()) { - for block in content { - if let Some(text) = - block.get("text").and_then(|t| t.as_str()) - { - let _ = token_tx.send(text.to_string()); - } - } + parse_assistant_message(content, &msg_tx); + } + } + // User message containing tool results from Claude Code's own execution + "user" => { + if let Some(message) = json.get("message") + && let Some(content) = + message.get("content").and_then(|c| c.as_array()) + { + parse_tool_results(content, &msg_tx); } } // Final result with usage stats "result" => { - if let Some(cost) = - json.get("total_cost_usd").and_then(|c| c.as_f64()) - { - let _ = - token_tx.send(format!("\n\n---\n_Cost: ${cost:.4}_\n")); - } - if let Some(usage) = json.get("usage") { - let input = usage - .get("input_tokens") - .and_then(|t| t.as_u64()) - .unwrap_or(0); - let output = usage - .get("output_tokens") - .and_then(|t| t.as_u64()) - .unwrap_or(0); - let cached = usage - .get("cache_read_input_tokens") - .and_then(|t| t.as_u64()) - .unwrap_or(0); - let _ = token_tx.send(format!( - "_Tokens: {input} in / {output} out / {cached} cached_\n" - )); - } got_result = true; } - // System init — log billing info + // System init — log billing info via streaming display "system" => { let api_source = json .get("apiKeySource") @@ -264,7 +265,7 @@ fn run_pty_session( let _ = token_tx .send(format!("_[{model} | apiKey: {api_source}]_\n\n")); } - // Rate limit info + // Rate limit info — surface briefly in streaming display "rate_limit_event" => { if let Some(info) = json.get("rate_limit_info") { let status = info @@ -297,12 +298,35 @@ fn run_pty_session( while let Ok(Some(line)) = line_rx.try_recv() { let trimmed = line.trim(); if let Ok(json) = serde_json::from_str::(trimmed) - && let Some(event) = json - .get("type") - .filter(|t| t.as_str() == Some("stream_event")) - .and_then(|_| json.get("event")) + && let Some(event_type) = + json.get("type").and_then(|t| t.as_str()) { - handle_stream_event(event, &token_tx); + match event_type { + "stream_event" => { + if let Some(event) = json.get("event") { + handle_stream_event(event, &token_tx); + } + } + "assistant" => { + if let Some(message) = json.get("message") + && let Some(content) = message + .get("content") + .and_then(|c| c.as_array()) + { + parse_assistant_message(content, &msg_tx); + } + } + "user" => { + if let Some(message) = json.get("message") + && let Some(content) = message + .get("content") + .and_then(|c| c.as_array()) + { + parse_tool_results(content, &msg_tx); + } + } + _ => {} + } } } break; @@ -335,7 +359,115 @@ fn run_pty_session( Ok(()) } -/// Extract text from a stream event and send to the token channel. +/// Parse a complete `assistant` message content array. +/// +/// Extracts text blocks into `content` and tool_use blocks into `tool_calls`, +/// then sends a single `Message { role: Assistant }` via `msg_tx`. +/// This is the authoritative source for the final message structure — streaming +/// text deltas (via `handle_stream_event`) are only used for the live display. +fn parse_assistant_message( + content: &[serde_json::Value], + msg_tx: &std::sync::mpsc::Sender, +) { + let mut text = String::new(); + let mut tool_calls: Vec = Vec::new(); + + for block in content { + match block.get("type").and_then(|t| t.as_str()) { + Some("text") => { + if let Some(t) = block.get("text").and_then(|t| t.as_str()) { + text.push_str(t); + } + } + Some("tool_use") => { + let id = block + .get("id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let name = block + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let input = block + .get("input") + .cloned() + .unwrap_or(serde_json::Value::Object(serde_json::Map::new())); + let arguments = serde_json::to_string(&input).unwrap_or_default(); + tool_calls.push(ToolCall { + id, + function: FunctionCall { name, arguments }, + kind: "function".to_string(), + }); + } + _ => {} + } + } + + let msg = Message { + role: Role::Assistant, + content: text, + tool_calls: if tool_calls.is_empty() { + None + } else { + Some(tool_calls) + }, + tool_call_id: None, + }; + let _ = msg_tx.send(msg); +} + +/// Parse a `user` message containing tool_result blocks. +/// +/// Claude Code injects tool results into the conversation as `user` role +/// messages. Each `tool_result` block becomes a separate `Message { role: Tool }`. +fn parse_tool_results( + content: &[serde_json::Value], + msg_tx: &std::sync::mpsc::Sender, +) { + for block in content { + if block.get("type").and_then(|t| t.as_str()) != Some("tool_result") { + continue; + } + + let tool_use_id = block + .get("tool_use_id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + // `content` in a tool_result can be a plain string or an array of content blocks + let content_str = match block.get("content") { + Some(serde_json::Value::String(s)) => s.clone(), + Some(serde_json::Value::Array(arr)) => { + // Extract text from content block array + arr.iter() + .filter_map(|b| { + if b.get("type").and_then(|t| t.as_str()) == Some("text") { + b.get("text").and_then(|t| t.as_str()).map(|s| s.to_string()) + } else { + None + } + }) + .collect::>() + .join("\n") + } + Some(other) => serde_json::to_string(other).unwrap_or_default(), + None => String::new(), + }; + + let _ = msg_tx.send(Message { + role: Role::Tool, + content: content_str, + tool_calls: None, + tool_call_id: tool_use_id, + }); + } +} + +/// Extract text from a stream event and send to the token channel for live display. +/// +/// Stream events provide incremental text deltas for real-time rendering. +/// The authoritative final message content comes from complete `assistant` events. fn handle_stream_event( event: &serde_json::Value, token_tx: &tokio::sync::mpsc::UnboundedSender, @@ -343,7 +475,7 @@ fn handle_stream_event( let event_type = event.get("type").and_then(|t| t.as_str()).unwrap_or(""); match event_type { - // Text content streaming + // Text content streaming — only text_delta, not input_json_delta (tool args) "content_block_delta" => { if let Some(delta) = event.get("delta") { let delta_type = delta.get("type").and_then(|t| t.as_str()).unwrap_or(""); @@ -364,17 +496,7 @@ fn handle_stream_event( } } } - // Message complete — log usage info - "message_delta" => { - if let Some(usage) = event.get("usage") { - let output_tokens = usage - .get("output_tokens") - .and_then(|t| t.as_u64()) - .unwrap_or(0); - let _ = token_tx.send(format!("\n[tokens: {output_tokens} output]\n")); - } - } - // Log errors + // Log errors via streaming display "error" => { if let Some(error) = event.get("error") { let msg = error @@ -387,3 +509,174 @@ fn handle_stream_event( _ => {} } } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + fn collect_messages( + f: impl Fn(&std::sync::mpsc::Sender), + ) -> Vec { + let (tx, rx) = std::sync::mpsc::channel(); + f(&tx); + drop(tx); + rx.try_iter().collect() + } + + #[test] + fn parse_assistant_message_text_only() { + let content = vec![json!({"type": "text", "text": "Hello, world!"})]; + let msgs = collect_messages(|tx| parse_assistant_message(&content, tx)); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].role, Role::Assistant); + assert_eq!(msgs[0].content, "Hello, world!"); + assert!(msgs[0].tool_calls.is_none()); + } + + #[test] + fn parse_assistant_message_with_tool_use() { + let content = vec![ + json!({"type": "text", "text": "I'll read that file."}), + json!({ + "type": "tool_use", + "id": "toolu_abc123", + "name": "Read", + "input": {"file_path": "/src/main.rs"} + }), + ]; + let msgs = collect_messages(|tx| parse_assistant_message(&content, tx)); + assert_eq!(msgs.len(), 1); + let msg = &msgs[0]; + assert_eq!(msg.role, Role::Assistant); + assert_eq!(msg.content, "I'll read that file."); + let tool_calls = msg.tool_calls.as_ref().expect("should have tool calls"); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id.as_deref(), Some("toolu_abc123")); + assert_eq!(tool_calls[0].function.name, "Read"); + let args: serde_json::Value = + serde_json::from_str(&tool_calls[0].function.arguments).unwrap(); + assert_eq!(args["file_path"], "/src/main.rs"); + } + + #[test] + fn parse_assistant_message_multiple_tool_uses() { + let content = vec![ + json!({"type": "tool_use", "id": "id1", "name": "Glob", "input": {"pattern": "*.rs"}}), + json!({"type": "tool_use", "id": "id2", "name": "Bash", "input": {"command": "cargo test"}}), + ]; + let msgs = collect_messages(|tx| parse_assistant_message(&content, tx)); + assert_eq!(msgs.len(), 1); + let tool_calls = msgs[0].tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 2); + assert_eq!(tool_calls[0].function.name, "Glob"); + assert_eq!(tool_calls[1].function.name, "Bash"); + } + + #[test] + fn parse_tool_results_string_content() { + let content = vec![json!({ + "type": "tool_result", + "tool_use_id": "toolu_abc123", + "content": "fn main() { println!(\"hello\"); }" + })]; + let msgs = collect_messages(|tx| parse_tool_results(&content, tx)); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].role, Role::Tool); + assert_eq!(msgs[0].content, "fn main() { println!(\"hello\"); }"); + assert_eq!(msgs[0].tool_call_id.as_deref(), Some("toolu_abc123")); + } + + #[test] + fn parse_tool_results_array_content() { + let content = vec![json!({ + "type": "tool_result", + "tool_use_id": "toolu_xyz", + "content": [ + {"type": "text", "text": "Line 1"}, + {"type": "text", "text": "Line 2"} + ] + })]; + let msgs = collect_messages(|tx| parse_tool_results(&content, tx)); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].content, "Line 1\nLine 2"); + assert_eq!(msgs[0].tool_call_id.as_deref(), Some("toolu_xyz")); + } + + #[test] + fn parse_tool_results_multiple_results() { + let content = vec![ + json!({"type": "tool_result", "tool_use_id": "id1", "content": "result1"}), + json!({"type": "tool_result", "tool_use_id": "id2", "content": "result2"}), + ]; + let msgs = collect_messages(|tx| parse_tool_results(&content, tx)); + assert_eq!(msgs.len(), 2); + assert_eq!(msgs[0].tool_call_id.as_deref(), Some("id1")); + assert_eq!(msgs[1].tool_call_id.as_deref(), Some("id2")); + } + + #[test] + fn parse_tool_results_ignores_non_tool_result_blocks() { + let content = vec![ + json!({"type": "text", "text": "not a tool result"}), + json!({"type": "tool_result", "tool_use_id": "id1", "content": "actual result"}), + ]; + let msgs = collect_messages(|tx| parse_tool_results(&content, tx)); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].tool_call_id.as_deref(), Some("id1")); + } + + #[test] + fn parse_assistant_message_empty_text_with_tool_use() { + // When a message has only tool_use (no text), content should be empty string + let content = vec![json!({ + "type": "tool_use", + "id": "toolu_1", + "name": "Write", + "input": {"file_path": "foo.rs", "content": "fn foo() {}"} + })]; + let msgs = collect_messages(|tx| parse_assistant_message(&content, tx)); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].content, ""); + assert!(msgs[0].tool_calls.is_some()); + } + + #[test] + fn handle_stream_event_text_delta_sends_token() { + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + let event = json!({ + "type": "content_block_delta", + "delta": {"type": "text_delta", "text": "hello "} + }); + handle_stream_event(&event, &tx); + drop(tx); + let tokens: Vec<_> = { + let mut v = vec![]; + while let Ok(t) = rx.try_recv() { + v.push(t); + } + v + }; + assert_eq!(tokens, vec!["hello "]); + } + + #[test] + fn handle_stream_event_input_json_delta_not_sent() { + // Tool argument JSON deltas should NOT be sent as text tokens + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + let event = json!({ + "type": "content_block_delta", + "delta": {"type": "input_json_delta", "partial_json": "{\"path\":"} + }); + handle_stream_event(&event, &tx); + drop(tx); + let tokens: Vec = { + let mut v = vec![]; + while let Ok(t) = rx.try_recv() { + v.push(t); + } + v + }; + assert!(tokens.is_empty()); + } +}