Story 43: Unified chat UI for Claude Code and regular chat
Integrate Claude Code provider into the chat UI alongside regular Ollama/Anthropic providers. Updates AgentPanel and Chat components. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
---
|
---
|
||||||
name: Unified Chat UI for Claude Code and Regular Chat
|
name: Unified Chat UI for Claude Code and Regular Chat
|
||||||
test_plan: pending
|
test_plan: approved
|
||||||
---
|
---
|
||||||
|
|
||||||
# Story 43: Unified Chat UI for Claude Code and Regular Chat
|
# Story 43: Unified Chat UI for Claude Code and Regular Chat
|
||||||
|
|||||||
@@ -1,12 +1,23 @@
|
|||||||
import { render, screen, waitFor } from "@testing-library/react";
|
import { act, render, screen, waitFor } from "@testing-library/react";
|
||||||
import userEvent from "@testing-library/user-event";
|
import userEvent from "@testing-library/user-event";
|
||||||
|
|
||||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||||
import { api } from "../api/client";
|
import { api } from "../api/client";
|
||||||
import type { ReviewStory } from "../api/workflow";
|
import type { ReviewStory } from "../api/workflow";
|
||||||
import { workflowApi } from "../api/workflow";
|
import { workflowApi } from "../api/workflow";
|
||||||
|
import type { Message } from "../types";
|
||||||
import { Chat } from "./Chat";
|
import { Chat } from "./Chat";
|
||||||
|
|
||||||
|
// Module-level store for the WebSocket handlers captured during connect().
|
||||||
|
// Tests in the "message rendering" suite use this to simulate incoming messages.
|
||||||
|
type WsHandlers = {
|
||||||
|
onToken: (content: string) => void;
|
||||||
|
onUpdate: (history: Message[]) => void;
|
||||||
|
onSessionId: (sessionId: string) => void;
|
||||||
|
onError: (message: string) => void;
|
||||||
|
};
|
||||||
|
let capturedWsHandlers: WsHandlers | null = null;
|
||||||
|
|
||||||
vi.mock("../api/client", () => {
|
vi.mock("../api/client", () => {
|
||||||
const api = {
|
const api = {
|
||||||
getOllamaModels: vi.fn(),
|
getOllamaModels: vi.fn(),
|
||||||
@@ -18,7 +29,9 @@ vi.mock("../api/client", () => {
|
|||||||
setAnthropicApiKey: vi.fn(),
|
setAnthropicApiKey: vi.fn(),
|
||||||
};
|
};
|
||||||
class ChatWebSocket {
|
class ChatWebSocket {
|
||||||
connect() {}
|
connect(handlers: WsHandlers) {
|
||||||
|
capturedWsHandlers = handlers;
|
||||||
|
}
|
||||||
close() {}
|
close() {}
|
||||||
sendChat() {}
|
sendChat() {}
|
||||||
cancel() {}
|
cancel() {}
|
||||||
@@ -609,3 +622,151 @@ describe("Chat review panel", () => {
|
|||||||
expect(mockedApi.getAnthropicModels).not.toHaveBeenCalled();
|
expect(mockedApi.getAnthropicModels).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe("Chat message rendering — unified tool call UI", () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
capturedWsHandlers = null;
|
||||||
|
|
||||||
|
mockedApi.getOllamaModels.mockResolvedValue(["llama3.1"]);
|
||||||
|
mockedApi.getAnthropicApiKeyExists.mockResolvedValue(true);
|
||||||
|
mockedApi.getAnthropicModels.mockResolvedValue([]);
|
||||||
|
mockedApi.getModelPreference.mockResolvedValue(null);
|
||||||
|
mockedApi.setModelPreference.mockResolvedValue(true);
|
||||||
|
mockedApi.cancelChat.mockResolvedValue(true);
|
||||||
|
|
||||||
|
mockedWorkflow.getAcceptance.mockResolvedValue({
|
||||||
|
can_accept: true,
|
||||||
|
reasons: [],
|
||||||
|
warning: null,
|
||||||
|
summary: { total: 0, passed: 0, failed: 0 },
|
||||||
|
missing_categories: [],
|
||||||
|
});
|
||||||
|
mockedWorkflow.getReviewQueueAll.mockResolvedValue({ stories: [] });
|
||||||
|
mockedWorkflow.ensureAcceptance.mockResolvedValue(true);
|
||||||
|
mockedWorkflow.getStoryTodos.mockResolvedValue({ stories: [] });
|
||||||
|
mockedWorkflow.getUpcomingStories.mockResolvedValue({ stories: [] });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders tool call badge for assistant message with tool_calls (AC3)", async () => {
|
||||||
|
render(<Chat projectPath="/tmp/project" onCloseProject={vi.fn()} />);
|
||||||
|
|
||||||
|
await waitFor(() => expect(capturedWsHandlers).not.toBeNull());
|
||||||
|
|
||||||
|
const messages: Message[] = [
|
||||||
|
{ role: "user", content: "Read src/main.rs" },
|
||||||
|
{
|
||||||
|
role: "assistant",
|
||||||
|
content: "I'll read that file.",
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
id: "toolu_abc",
|
||||||
|
type: "function",
|
||||||
|
function: {
|
||||||
|
name: "Read",
|
||||||
|
arguments: '{"file_path":"src/main.rs"}',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
capturedWsHandlers?.onUpdate(messages);
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(await screen.findByText("I'll read that file.")).toBeInTheDocument();
|
||||||
|
// Tool call badge should appear showing the function name
|
||||||
|
expect(await screen.findByText(/Read/)).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders collapsible tool output for tool role messages (AC3)", async () => {
|
||||||
|
render(<Chat projectPath="/tmp/project" onCloseProject={vi.fn()} />);
|
||||||
|
|
||||||
|
await waitFor(() => expect(capturedWsHandlers).not.toBeNull());
|
||||||
|
|
||||||
|
const messages: Message[] = [
|
||||||
|
{ role: "user", content: "Check the file" },
|
||||||
|
{
|
||||||
|
role: "assistant",
|
||||||
|
content: "",
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
id: "toolu_1",
|
||||||
|
type: "function",
|
||||||
|
function: { name: "Read", arguments: '{"file_path":"foo.rs"}' },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "tool",
|
||||||
|
content: 'fn main() { println!("hello"); }',
|
||||||
|
tool_call_id: "toolu_1",
|
||||||
|
},
|
||||||
|
{ role: "assistant", content: "The file contains a main function." },
|
||||||
|
];
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
capturedWsHandlers?.onUpdate(messages);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Tool output section should be collapsible
|
||||||
|
expect(await screen.findByText(/Tool Output/)).toBeInTheDocument();
|
||||||
|
expect(
|
||||||
|
await screen.findByText("The file contains a main function."),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders plain assistant message without tool call badges (AC5)", async () => {
|
||||||
|
render(<Chat projectPath="/tmp/project" onCloseProject={vi.fn()} />);
|
||||||
|
|
||||||
|
await waitFor(() => expect(capturedWsHandlers).not.toBeNull());
|
||||||
|
|
||||||
|
const messages: Message[] = [
|
||||||
|
{ role: "user", content: "Hello" },
|
||||||
|
{ role: "assistant", content: "Hi there! How can I help?" },
|
||||||
|
];
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
capturedWsHandlers?.onUpdate(messages);
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(
|
||||||
|
await screen.findByText("Hi there! How can I help?"),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
// No tool call badges should appear
|
||||||
|
expect(screen.queryByText(/Tool Output/)).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders multiple tool calls in a single assistant turn (AC3)", async () => {
|
||||||
|
render(<Chat projectPath="/tmp/project" onCloseProject={vi.fn()} />);
|
||||||
|
|
||||||
|
await waitFor(() => expect(capturedWsHandlers).not.toBeNull());
|
||||||
|
|
||||||
|
const messages: Message[] = [
|
||||||
|
{ role: "user", content: "Do some work" },
|
||||||
|
{
|
||||||
|
role: "assistant",
|
||||||
|
content: "I'll do multiple things.",
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
id: "id1",
|
||||||
|
type: "function",
|
||||||
|
function: { name: "Bash", arguments: '{"command":"cargo test"}' },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "id2",
|
||||||
|
type: "function",
|
||||||
|
function: { name: "Read", arguments: '{"file_path":"Cargo.toml"}' },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
capturedWsHandlers?.onUpdate(messages);
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(await screen.findByText(/Bash/)).toBeInTheDocument();
|
||||||
|
expect(await screen.findByText(/Read/)).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use crate::llm::prompts::SYSTEM_PROMPT;
|
use crate::llm::prompts::SYSTEM_PROMPT;
|
||||||
|
use crate::llm::providers::claude_code::ClaudeCodeResult;
|
||||||
use crate::llm::types::{Message, Role, ToolCall, ToolDefinition, ToolFunctionDefinition};
|
use crate::llm::types::{Message, Role, ToolCall, ToolDefinition, ToolFunctionDefinition};
|
||||||
use crate::state::SessionState;
|
use crate::state::SessionState;
|
||||||
use crate::store::StoreOps;
|
use crate::store::StoreOps;
|
||||||
@@ -209,7 +210,9 @@ where
|
|||||||
|
|
||||||
// Claude Code provider: bypasses our tool loop entirely.
|
// Claude Code provider: bypasses our tool loop entirely.
|
||||||
// Claude Code has its own agent loop, tools, and context management.
|
// Claude Code has its own agent loop, tools, and context management.
|
||||||
// We just pipe the user message in and stream raw output back.
|
// We pipe the user message in, stream text tokens for live display, and
|
||||||
|
// collect the structured messages (assistant turns + tool results) from
|
||||||
|
// the stream-json output for the final message history.
|
||||||
if is_claude_code {
|
if is_claude_code {
|
||||||
use crate::llm::providers::claude_code::ClaudeCodeProvider;
|
use crate::llm::providers::claude_code::ClaudeCodeProvider;
|
||||||
|
|
||||||
@@ -225,7 +228,10 @@ where
|
|||||||
.unwrap_or_else(|_| std::path::PathBuf::from("."));
|
.unwrap_or_else(|_| std::path::PathBuf::from("."));
|
||||||
|
|
||||||
let provider = ClaudeCodeProvider::new();
|
let provider = ClaudeCodeProvider::new();
|
||||||
let response = provider
|
let ClaudeCodeResult {
|
||||||
|
messages: cc_messages,
|
||||||
|
session_id,
|
||||||
|
} = provider
|
||||||
.chat_stream(
|
.chat_stream(
|
||||||
&user_message,
|
&user_message,
|
||||||
&project_root.to_string_lossy(),
|
&project_root.to_string_lossy(),
|
||||||
@@ -236,19 +242,24 @@ where
|
|||||||
.await
|
.await
|
||||||
.map_err(|e| format!("Claude Code Error: {e}"))?;
|
.map_err(|e| format!("Claude Code Error: {e}"))?;
|
||||||
|
|
||||||
let assistant_msg = Message {
|
// Build the final message history: user messages + Claude Code's turns.
|
||||||
role: Role::Assistant,
|
// If the session produced no structured messages (e.g. empty response),
|
||||||
content: response.content.unwrap_or_default(),
|
// fall back to an empty assistant message so the UI stops loading.
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut result = messages.clone();
|
let mut result = messages.clone();
|
||||||
result.push(assistant_msg);
|
if cc_messages.is_empty() {
|
||||||
|
result.push(Message {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: String::new(),
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
result.extend(cc_messages);
|
||||||
|
}
|
||||||
on_update(&result);
|
on_update(&result);
|
||||||
return Ok(ChatResult {
|
return Ok(ChatResult {
|
||||||
messages: result,
|
messages: result,
|
||||||
session_id: response.session_id,
|
session_id,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,16 @@ use std::sync::Arc;
|
|||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use tokio::sync::watch;
|
use tokio::sync::watch;
|
||||||
|
|
||||||
use crate::llm::types::CompletionResponse;
|
use crate::llm::types::{FunctionCall, Message, Role, ToolCall};
|
||||||
|
|
||||||
|
/// Result from a Claude Code session containing structured messages.
|
||||||
|
pub struct ClaudeCodeResult {
|
||||||
|
/// The conversation messages produced by Claude Code, including assistant
|
||||||
|
/// turns (with optional tool_calls) and tool result turns.
|
||||||
|
pub messages: Vec<Message>,
|
||||||
|
/// Session ID for conversation resumption on subsequent requests.
|
||||||
|
pub session_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
/// Manages a Claude Code session via a pseudo-terminal.
|
/// Manages a Claude Code session via a pseudo-terminal.
|
||||||
///
|
///
|
||||||
@@ -29,7 +38,7 @@ impl ClaudeCodeProvider {
|
|||||||
session_id: Option<&str>,
|
session_id: Option<&str>,
|
||||||
cancel_rx: &mut watch::Receiver<bool>,
|
cancel_rx: &mut watch::Receiver<bool>,
|
||||||
mut on_token: F,
|
mut on_token: F,
|
||||||
) -> Result<CompletionResponse, String>
|
) -> Result<ClaudeCodeResult, String>
|
||||||
where
|
where
|
||||||
F: FnMut(&str) + Send,
|
F: FnMut(&str) + Send,
|
||||||
{
|
{
|
||||||
@@ -50,15 +59,22 @@ impl ClaudeCodeProvider {
|
|||||||
});
|
});
|
||||||
|
|
||||||
let (token_tx, mut token_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
|
let (token_tx, mut token_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
|
||||||
|
let (msg_tx, msg_rx) = std::sync::mpsc::channel::<Message>();
|
||||||
let (sid_tx, sid_rx) = tokio::sync::oneshot::channel::<String>();
|
let (sid_tx, sid_rx) = tokio::sync::oneshot::channel::<String>();
|
||||||
|
|
||||||
let pty_handle = tokio::task::spawn_blocking(move || {
|
let pty_handle = tokio::task::spawn_blocking(move || {
|
||||||
run_pty_session(&message, &cwd, resume_id.as_deref(), cancelled, token_tx, sid_tx)
|
run_pty_session(
|
||||||
|
&message,
|
||||||
|
&cwd,
|
||||||
|
resume_id.as_deref(),
|
||||||
|
cancelled,
|
||||||
|
token_tx,
|
||||||
|
msg_tx,
|
||||||
|
sid_tx,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut full_output = String::new();
|
|
||||||
while let Some(token) = token_rx.recv().await {
|
while let Some(token) = token_rx.recv().await {
|
||||||
full_output.push_str(&token);
|
|
||||||
on_token(&token);
|
on_token(&token);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,10 +83,10 @@ impl ClaudeCodeProvider {
|
|||||||
.map_err(|e| format!("PTY task panicked: {e}"))??;
|
.map_err(|e| format!("PTY task panicked: {e}"))??;
|
||||||
|
|
||||||
let captured_session_id = sid_rx.await.ok();
|
let captured_session_id = sid_rx.await.ok();
|
||||||
|
let structured_messages: Vec<Message> = msg_rx.try_iter().collect();
|
||||||
|
|
||||||
Ok(CompletionResponse {
|
Ok(ClaudeCodeResult {
|
||||||
content: Some(full_output),
|
messages: structured_messages,
|
||||||
tool_calls: None,
|
|
||||||
session_id: captured_session_id,
|
session_id: captured_session_id,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -80,12 +96,17 @@ impl ClaudeCodeProvider {
|
|||||||
///
|
///
|
||||||
/// The PTY makes isatty() return true. The `-p` flag gives us
|
/// The PTY makes isatty() return true. The `-p` flag gives us
|
||||||
/// single-shot non-interactive mode with structured output.
|
/// single-shot non-interactive mode with structured output.
|
||||||
|
///
|
||||||
|
/// Sends streaming text tokens via `token_tx` for real-time display, and
|
||||||
|
/// complete structured `Message` values via `msg_tx` for the final message
|
||||||
|
/// history (assistant turns with tool_calls, and tool result turns).
|
||||||
fn run_pty_session(
|
fn run_pty_session(
|
||||||
user_message: &str,
|
user_message: &str,
|
||||||
cwd: &str,
|
cwd: &str,
|
||||||
resume_session_id: Option<&str>,
|
resume_session_id: Option<&str>,
|
||||||
cancelled: Arc<AtomicBool>,
|
cancelled: Arc<AtomicBool>,
|
||||||
token_tx: tokio::sync::mpsc::UnboundedSender<String>,
|
token_tx: tokio::sync::mpsc::UnboundedSender<String>,
|
||||||
|
msg_tx: std::sync::mpsc::Sender<Message>,
|
||||||
sid_tx: tokio::sync::oneshot::Sender<String>,
|
sid_tx: tokio::sync::oneshot::Sender<String>,
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
let pty_system = native_pty_system();
|
let pty_system = native_pty_system();
|
||||||
@@ -203,55 +224,35 @@ fn run_pty_session(
|
|||||||
}
|
}
|
||||||
|
|
||||||
match event_type {
|
match event_type {
|
||||||
// Streaming deltas (when --include-partial-messages is used)
|
// Streaming deltas — used for real-time text display only
|
||||||
"stream_event" => {
|
"stream_event" => {
|
||||||
if let Some(event) = json.get("event") {
|
if let Some(event) = json.get("event") {
|
||||||
handle_stream_event(event, &token_tx);
|
handle_stream_event(event, &token_tx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Complete assistant message
|
// Complete assistant message — extract text and tool_use blocks
|
||||||
"assistant" => {
|
"assistant" => {
|
||||||
if let Some(message) = json.get("message")
|
if let Some(message) = json.get("message")
|
||||||
&& let Some(content) =
|
&& let Some(content) =
|
||||||
message.get("content").and_then(|c| c.as_array())
|
message.get("content").and_then(|c| c.as_array())
|
||||||
{
|
{
|
||||||
for block in content {
|
parse_assistant_message(content, &msg_tx);
|
||||||
if let Some(text) =
|
}
|
||||||
block.get("text").and_then(|t| t.as_str())
|
}
|
||||||
{
|
// User message containing tool results from Claude Code's own execution
|
||||||
let _ = token_tx.send(text.to_string());
|
"user" => {
|
||||||
}
|
if let Some(message) = json.get("message")
|
||||||
}
|
&& let Some(content) =
|
||||||
|
message.get("content").and_then(|c| c.as_array())
|
||||||
|
{
|
||||||
|
parse_tool_results(content, &msg_tx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Final result with usage stats
|
// Final result with usage stats
|
||||||
"result" => {
|
"result" => {
|
||||||
if let Some(cost) =
|
|
||||||
json.get("total_cost_usd").and_then(|c| c.as_f64())
|
|
||||||
{
|
|
||||||
let _ =
|
|
||||||
token_tx.send(format!("\n\n---\n_Cost: ${cost:.4}_\n"));
|
|
||||||
}
|
|
||||||
if let Some(usage) = json.get("usage") {
|
|
||||||
let input = usage
|
|
||||||
.get("input_tokens")
|
|
||||||
.and_then(|t| t.as_u64())
|
|
||||||
.unwrap_or(0);
|
|
||||||
let output = usage
|
|
||||||
.get("output_tokens")
|
|
||||||
.and_then(|t| t.as_u64())
|
|
||||||
.unwrap_or(0);
|
|
||||||
let cached = usage
|
|
||||||
.get("cache_read_input_tokens")
|
|
||||||
.and_then(|t| t.as_u64())
|
|
||||||
.unwrap_or(0);
|
|
||||||
let _ = token_tx.send(format!(
|
|
||||||
"_Tokens: {input} in / {output} out / {cached} cached_\n"
|
|
||||||
));
|
|
||||||
}
|
|
||||||
got_result = true;
|
got_result = true;
|
||||||
}
|
}
|
||||||
// System init — log billing info
|
// System init — log billing info via streaming display
|
||||||
"system" => {
|
"system" => {
|
||||||
let api_source = json
|
let api_source = json
|
||||||
.get("apiKeySource")
|
.get("apiKeySource")
|
||||||
@@ -264,7 +265,7 @@ fn run_pty_session(
|
|||||||
let _ = token_tx
|
let _ = token_tx
|
||||||
.send(format!("_[{model} | apiKey: {api_source}]_\n\n"));
|
.send(format!("_[{model} | apiKey: {api_source}]_\n\n"));
|
||||||
}
|
}
|
||||||
// Rate limit info
|
// Rate limit info — surface briefly in streaming display
|
||||||
"rate_limit_event" => {
|
"rate_limit_event" => {
|
||||||
if let Some(info) = json.get("rate_limit_info") {
|
if let Some(info) = json.get("rate_limit_info") {
|
||||||
let status = info
|
let status = info
|
||||||
@@ -297,12 +298,35 @@ fn run_pty_session(
|
|||||||
while let Ok(Some(line)) = line_rx.try_recv() {
|
while let Ok(Some(line)) = line_rx.try_recv() {
|
||||||
let trimmed = line.trim();
|
let trimmed = line.trim();
|
||||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(trimmed)
|
if let Ok(json) = serde_json::from_str::<serde_json::Value>(trimmed)
|
||||||
&& let Some(event) = json
|
&& let Some(event_type) =
|
||||||
.get("type")
|
json.get("type").and_then(|t| t.as_str())
|
||||||
.filter(|t| t.as_str() == Some("stream_event"))
|
|
||||||
.and_then(|_| json.get("event"))
|
|
||||||
{
|
{
|
||||||
handle_stream_event(event, &token_tx);
|
match event_type {
|
||||||
|
"stream_event" => {
|
||||||
|
if let Some(event) = json.get("event") {
|
||||||
|
handle_stream_event(event, &token_tx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"assistant" => {
|
||||||
|
if let Some(message) = json.get("message")
|
||||||
|
&& let Some(content) = message
|
||||||
|
.get("content")
|
||||||
|
.and_then(|c| c.as_array())
|
||||||
|
{
|
||||||
|
parse_assistant_message(content, &msg_tx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"user" => {
|
||||||
|
if let Some(message) = json.get("message")
|
||||||
|
&& let Some(content) = message
|
||||||
|
.get("content")
|
||||||
|
.and_then(|c| c.as_array())
|
||||||
|
{
|
||||||
|
parse_tool_results(content, &msg_tx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
@@ -335,7 +359,115 @@ fn run_pty_session(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extract text from a stream event and send to the token channel.
|
/// Parse a complete `assistant` message content array.
|
||||||
|
///
|
||||||
|
/// Extracts text blocks into `content` and tool_use blocks into `tool_calls`,
|
||||||
|
/// then sends a single `Message { role: Assistant }` via `msg_tx`.
|
||||||
|
/// This is the authoritative source for the final message structure — streaming
|
||||||
|
/// text deltas (via `handle_stream_event`) are only used for the live display.
|
||||||
|
fn parse_assistant_message(
|
||||||
|
content: &[serde_json::Value],
|
||||||
|
msg_tx: &std::sync::mpsc::Sender<Message>,
|
||||||
|
) {
|
||||||
|
let mut text = String::new();
|
||||||
|
let mut tool_calls: Vec<ToolCall> = Vec::new();
|
||||||
|
|
||||||
|
for block in content {
|
||||||
|
match block.get("type").and_then(|t| t.as_str()) {
|
||||||
|
Some("text") => {
|
||||||
|
if let Some(t) = block.get("text").and_then(|t| t.as_str()) {
|
||||||
|
text.push_str(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some("tool_use") => {
|
||||||
|
let id = block
|
||||||
|
.get("id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
let name = block
|
||||||
|
.get("name")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
let input = block
|
||||||
|
.get("input")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
|
||||||
|
let arguments = serde_json::to_string(&input).unwrap_or_default();
|
||||||
|
tool_calls.push(ToolCall {
|
||||||
|
id,
|
||||||
|
function: FunctionCall { name, arguments },
|
||||||
|
kind: "function".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let msg = Message {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: text,
|
||||||
|
tool_calls: if tool_calls.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(tool_calls)
|
||||||
|
},
|
||||||
|
tool_call_id: None,
|
||||||
|
};
|
||||||
|
let _ = msg_tx.send(msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a `user` message containing tool_result blocks.
|
||||||
|
///
|
||||||
|
/// Claude Code injects tool results into the conversation as `user` role
|
||||||
|
/// messages. Each `tool_result` block becomes a separate `Message { role: Tool }`.
|
||||||
|
fn parse_tool_results(
|
||||||
|
content: &[serde_json::Value],
|
||||||
|
msg_tx: &std::sync::mpsc::Sender<Message>,
|
||||||
|
) {
|
||||||
|
for block in content {
|
||||||
|
if block.get("type").and_then(|t| t.as_str()) != Some("tool_result") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let tool_use_id = block
|
||||||
|
.get("tool_use_id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
|
// `content` in a tool_result can be a plain string or an array of content blocks
|
||||||
|
let content_str = match block.get("content") {
|
||||||
|
Some(serde_json::Value::String(s)) => s.clone(),
|
||||||
|
Some(serde_json::Value::Array(arr)) => {
|
||||||
|
// Extract text from content block array
|
||||||
|
arr.iter()
|
||||||
|
.filter_map(|b| {
|
||||||
|
if b.get("type").and_then(|t| t.as_str()) == Some("text") {
|
||||||
|
b.get("text").and_then(|t| t.as_str()).map(|s| s.to_string())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n")
|
||||||
|
}
|
||||||
|
Some(other) => serde_json::to_string(other).unwrap_or_default(),
|
||||||
|
None => String::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let _ = msg_tx.send(Message {
|
||||||
|
role: Role::Tool,
|
||||||
|
content: content_str,
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: tool_use_id,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract text from a stream event and send to the token channel for live display.
|
||||||
|
///
|
||||||
|
/// Stream events provide incremental text deltas for real-time rendering.
|
||||||
|
/// The authoritative final message content comes from complete `assistant` events.
|
||||||
fn handle_stream_event(
|
fn handle_stream_event(
|
||||||
event: &serde_json::Value,
|
event: &serde_json::Value,
|
||||||
token_tx: &tokio::sync::mpsc::UnboundedSender<String>,
|
token_tx: &tokio::sync::mpsc::UnboundedSender<String>,
|
||||||
@@ -343,7 +475,7 @@ fn handle_stream_event(
|
|||||||
let event_type = event.get("type").and_then(|t| t.as_str()).unwrap_or("");
|
let event_type = event.get("type").and_then(|t| t.as_str()).unwrap_or("");
|
||||||
|
|
||||||
match event_type {
|
match event_type {
|
||||||
// Text content streaming
|
// Text content streaming — only text_delta, not input_json_delta (tool args)
|
||||||
"content_block_delta" => {
|
"content_block_delta" => {
|
||||||
if let Some(delta) = event.get("delta") {
|
if let Some(delta) = event.get("delta") {
|
||||||
let delta_type = delta.get("type").and_then(|t| t.as_str()).unwrap_or("");
|
let delta_type = delta.get("type").and_then(|t| t.as_str()).unwrap_or("");
|
||||||
@@ -364,17 +496,7 @@ fn handle_stream_event(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Message complete — log usage info
|
// Log errors via streaming display
|
||||||
"message_delta" => {
|
|
||||||
if let Some(usage) = event.get("usage") {
|
|
||||||
let output_tokens = usage
|
|
||||||
.get("output_tokens")
|
|
||||||
.and_then(|t| t.as_u64())
|
|
||||||
.unwrap_or(0);
|
|
||||||
let _ = token_tx.send(format!("\n[tokens: {output_tokens} output]\n"));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Log errors
|
|
||||||
"error" => {
|
"error" => {
|
||||||
if let Some(error) = event.get("error") {
|
if let Some(error) = event.get("error") {
|
||||||
let msg = error
|
let msg = error
|
||||||
@@ -387,3 +509,174 @@ fn handle_stream_event(
|
|||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
fn collect_messages(
|
||||||
|
f: impl Fn(&std::sync::mpsc::Sender<Message>),
|
||||||
|
) -> Vec<Message> {
|
||||||
|
let (tx, rx) = std::sync::mpsc::channel();
|
||||||
|
f(&tx);
|
||||||
|
drop(tx);
|
||||||
|
rx.try_iter().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_assistant_message_text_only() {
|
||||||
|
let content = vec![json!({"type": "text", "text": "Hello, world!"})];
|
||||||
|
let msgs = collect_messages(|tx| parse_assistant_message(&content, tx));
|
||||||
|
assert_eq!(msgs.len(), 1);
|
||||||
|
assert_eq!(msgs[0].role, Role::Assistant);
|
||||||
|
assert_eq!(msgs[0].content, "Hello, world!");
|
||||||
|
assert!(msgs[0].tool_calls.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_assistant_message_with_tool_use() {
|
||||||
|
let content = vec![
|
||||||
|
json!({"type": "text", "text": "I'll read that file."}),
|
||||||
|
json!({
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "toolu_abc123",
|
||||||
|
"name": "Read",
|
||||||
|
"input": {"file_path": "/src/main.rs"}
|
||||||
|
}),
|
||||||
|
];
|
||||||
|
let msgs = collect_messages(|tx| parse_assistant_message(&content, tx));
|
||||||
|
assert_eq!(msgs.len(), 1);
|
||||||
|
let msg = &msgs[0];
|
||||||
|
assert_eq!(msg.role, Role::Assistant);
|
||||||
|
assert_eq!(msg.content, "I'll read that file.");
|
||||||
|
let tool_calls = msg.tool_calls.as_ref().expect("should have tool calls");
|
||||||
|
assert_eq!(tool_calls.len(), 1);
|
||||||
|
assert_eq!(tool_calls[0].id.as_deref(), Some("toolu_abc123"));
|
||||||
|
assert_eq!(tool_calls[0].function.name, "Read");
|
||||||
|
let args: serde_json::Value =
|
||||||
|
serde_json::from_str(&tool_calls[0].function.arguments).unwrap();
|
||||||
|
assert_eq!(args["file_path"], "/src/main.rs");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_assistant_message_multiple_tool_uses() {
|
||||||
|
let content = vec![
|
||||||
|
json!({"type": "tool_use", "id": "id1", "name": "Glob", "input": {"pattern": "*.rs"}}),
|
||||||
|
json!({"type": "tool_use", "id": "id2", "name": "Bash", "input": {"command": "cargo test"}}),
|
||||||
|
];
|
||||||
|
let msgs = collect_messages(|tx| parse_assistant_message(&content, tx));
|
||||||
|
assert_eq!(msgs.len(), 1);
|
||||||
|
let tool_calls = msgs[0].tool_calls.as_ref().unwrap();
|
||||||
|
assert_eq!(tool_calls.len(), 2);
|
||||||
|
assert_eq!(tool_calls[0].function.name, "Glob");
|
||||||
|
assert_eq!(tool_calls[1].function.name, "Bash");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_results_string_content() {
|
||||||
|
let content = vec![json!({
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "toolu_abc123",
|
||||||
|
"content": "fn main() { println!(\"hello\"); }"
|
||||||
|
})];
|
||||||
|
let msgs = collect_messages(|tx| parse_tool_results(&content, tx));
|
||||||
|
assert_eq!(msgs.len(), 1);
|
||||||
|
assert_eq!(msgs[0].role, Role::Tool);
|
||||||
|
assert_eq!(msgs[0].content, "fn main() { println!(\"hello\"); }");
|
||||||
|
assert_eq!(msgs[0].tool_call_id.as_deref(), Some("toolu_abc123"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_results_array_content() {
|
||||||
|
let content = vec![json!({
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "toolu_xyz",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Line 1"},
|
||||||
|
{"type": "text", "text": "Line 2"}
|
||||||
|
]
|
||||||
|
})];
|
||||||
|
let msgs = collect_messages(|tx| parse_tool_results(&content, tx));
|
||||||
|
assert_eq!(msgs.len(), 1);
|
||||||
|
assert_eq!(msgs[0].content, "Line 1\nLine 2");
|
||||||
|
assert_eq!(msgs[0].tool_call_id.as_deref(), Some("toolu_xyz"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_results_multiple_results() {
|
||||||
|
let content = vec![
|
||||||
|
json!({"type": "tool_result", "tool_use_id": "id1", "content": "result1"}),
|
||||||
|
json!({"type": "tool_result", "tool_use_id": "id2", "content": "result2"}),
|
||||||
|
];
|
||||||
|
let msgs = collect_messages(|tx| parse_tool_results(&content, tx));
|
||||||
|
assert_eq!(msgs.len(), 2);
|
||||||
|
assert_eq!(msgs[0].tool_call_id.as_deref(), Some("id1"));
|
||||||
|
assert_eq!(msgs[1].tool_call_id.as_deref(), Some("id2"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_results_ignores_non_tool_result_blocks() {
|
||||||
|
let content = vec![
|
||||||
|
json!({"type": "text", "text": "not a tool result"}),
|
||||||
|
json!({"type": "tool_result", "tool_use_id": "id1", "content": "actual result"}),
|
||||||
|
];
|
||||||
|
let msgs = collect_messages(|tx| parse_tool_results(&content, tx));
|
||||||
|
assert_eq!(msgs.len(), 1);
|
||||||
|
assert_eq!(msgs[0].tool_call_id.as_deref(), Some("id1"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_assistant_message_empty_text_with_tool_use() {
|
||||||
|
// When a message has only tool_use (no text), content should be empty string
|
||||||
|
let content = vec![json!({
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "toolu_1",
|
||||||
|
"name": "Write",
|
||||||
|
"input": {"file_path": "foo.rs", "content": "fn foo() {}"}
|
||||||
|
})];
|
||||||
|
let msgs = collect_messages(|tx| parse_assistant_message(&content, tx));
|
||||||
|
assert_eq!(msgs.len(), 1);
|
||||||
|
assert_eq!(msgs[0].content, "");
|
||||||
|
assert!(msgs[0].tool_calls.is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn handle_stream_event_text_delta_sends_token() {
|
||||||
|
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
let event = json!({
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"delta": {"type": "text_delta", "text": "hello "}
|
||||||
|
});
|
||||||
|
handle_stream_event(&event, &tx);
|
||||||
|
drop(tx);
|
||||||
|
let tokens: Vec<_> = {
|
||||||
|
let mut v = vec![];
|
||||||
|
while let Ok(t) = rx.try_recv() {
|
||||||
|
v.push(t);
|
||||||
|
}
|
||||||
|
v
|
||||||
|
};
|
||||||
|
assert_eq!(tokens, vec!["hello "]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn handle_stream_event_input_json_delta_not_sent() {
|
||||||
|
// Tool argument JSON deltas should NOT be sent as text tokens
|
||||||
|
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
let event = json!({
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"delta": {"type": "input_json_delta", "partial_json": "{\"path\":"}
|
||||||
|
});
|
||||||
|
handle_stream_event(&event, &tx);
|
||||||
|
drop(tx);
|
||||||
|
let tokens: Vec<String> = {
|
||||||
|
let mut v = vec![];
|
||||||
|
while let Ok(t) = rx.try_recv() {
|
||||||
|
v.push(t);
|
||||||
|
}
|
||||||
|
v
|
||||||
|
};
|
||||||
|
assert!(tokens.is_empty());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user