diff --git a/server/src/llm/providers/claude_code.rs b/server/src/llm/providers/claude_code.rs index 5347c13..ea8f4fe 100644 --- a/server/src/llm/providers/claude_code.rs +++ b/server/src/llm/providers/claude_code.rs @@ -226,53 +226,9 @@ fn run_pty_session( // Try to parse as JSON if let Ok(json) = serde_json::from_str::(trimmed) - && let Some(event_type) = json.get("type").and_then(|t| t.as_str()) + && process_json_event(&json, &token_tx, &msg_tx, &mut sid_tx) { - // Capture session_id from any event that has it - if let Some(tx) = sid_tx.take() { - if let Some(sid) = json.get("session_id").and_then(|s| s.as_str()) { - let _ = tx.send(sid.to_string()); - } else { - // Put it back if this event didn't have a session_id - sid_tx = Some(tx); - } - } - - match event_type { - // Streaming deltas — used for real-time text display only - "stream_event" => { - if let Some(event) = json.get("event") { - handle_stream_event(event, &token_tx); - } - } - // Complete assistant message — extract text and tool_use blocks - "assistant" => { - if let Some(message) = json.get("message") - && let Some(content) = - message.get("content").and_then(|c| c.as_array()) - { - parse_assistant_message(content, &msg_tx); - } - } - // User message containing tool results from Claude Code's own execution - "user" => { - if let Some(message) = json.get("message") - && let Some(content) = - message.get("content").and_then(|c| c.as_array()) - { - parse_tool_results(content, &msg_tx); - } - } - // Final result with usage stats - "result" => { - got_result = true; - } - // System init — suppress noisy model/apiKey notification - "system" => {} - // Rate limit info — suppress noisy notification - "rate_limit_event" => {} - _ => {} - } + got_result = true; } // Ignore non-JSON lines (terminal escape sequences) @@ -287,36 +243,8 @@ fn run_pty_session( // Drain remaining lines while let Ok(Some(line)) = line_rx.try_recv() { let trimmed = line.trim(); - if let Ok(json) = serde_json::from_str::(trimmed) - && let Some(event_type) = - json.get("type").and_then(|t| t.as_str()) - { - match event_type { - "stream_event" => { - if let Some(event) = json.get("event") { - handle_stream_event(event, &token_tx); - } - } - "assistant" => { - if let Some(message) = json.get("message") - && let Some(content) = message - .get("content") - .and_then(|c| c.as_array()) - { - parse_assistant_message(content, &msg_tx); - } - } - "user" => { - if let Some(message) = json.get("message") - && let Some(content) = message - .get("content") - .and_then(|c| c.as_array()) - { - parse_tool_results(content, &msg_tx); - } - } - _ => {} - } + if let Ok(json) = serde_json::from_str::(trimmed) { + process_json_event(&json, &token_tx, &msg_tx, &mut sid_tx); } } break; @@ -349,6 +277,59 @@ fn run_pty_session( Ok(()) } +/// Dispatch a single parsed JSON event to the appropriate handler. +/// +/// Returns `true` if a `result` event was received (signals session completion). +/// Captures the session ID from the first event that carries it. +fn process_json_event( + json: &serde_json::Value, + token_tx: &tokio::sync::mpsc::UnboundedSender, + msg_tx: &std::sync::mpsc::Sender, + sid_tx: &mut Option>, +) -> bool { + let event_type = match json.get("type").and_then(|t| t.as_str()) { + Some(t) => t, + None => return false, + }; + + // Capture session_id from the first event that carries it + if let Some(tx) = sid_tx.take() { + if let Some(sid) = json.get("session_id").and_then(|s| s.as_str()) { + let _ = tx.send(sid.to_string()); + } else { + *sid_tx = Some(tx); + } + } + + match event_type { + "stream_event" => { + if let Some(event) = json.get("event") { + handle_stream_event(event, token_tx); + } + false + } + "assistant" => { + if let Some(message) = json.get("message") + && let Some(content) = message.get("content").and_then(|c| c.as_array()) + { + parse_assistant_message(content, msg_tx); + } + false + } + "user" => { + if let Some(message) = json.get("message") + && let Some(content) = message.get("content").and_then(|c| c.as_array()) + { + parse_tool_results(content, msg_tx); + } + false + } + "result" => true, + // system, rate_limit_event, and unknown types are no-ops + _ => false, + } +} + /// Parse a complete `assistant` message content array. /// /// Extracts text blocks into `content` and tool_use blocks into `tool_calls`, @@ -669,4 +650,298 @@ mod tests { }; assert!(tokens.is_empty()); } + + #[test] + fn handle_stream_event_thinking_delta_sends_prefixed_token() { + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + let event = json!({ + "type": "content_block_delta", + "delta": {"type": "thinking_delta", "thinking": "I should check the file"} + }); + handle_stream_event(&event, &tx); + drop(tx); + let tokens: Vec = { + let mut v = vec![]; + while let Ok(t) = rx.try_recv() { + v.push(t); + } + v + }; + assert_eq!(tokens, vec!["[thinking] I should check the file"]); + } + + #[test] + fn handle_stream_event_error_sends_error_token() { + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + let event = json!({ + "type": "error", + "error": {"type": "overloaded_error", "message": "Overloaded"} + }); + handle_stream_event(&event, &tx); + drop(tx); + let tokens: Vec = { + let mut v = vec![]; + while let Ok(t) = rx.try_recv() { + v.push(t); + } + v + }; + assert_eq!(tokens, vec!["\n[error: Overloaded]\n"]); + } + + #[test] + fn handle_stream_event_unknown_type_is_noop() { + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + let event = json!({"type": "ping"}); + handle_stream_event(&event, &tx); + drop(tx); + let tokens: Vec = { + let mut v = vec![]; + while let Ok(t) = rx.try_recv() { + v.push(t); + } + v + }; + assert!(tokens.is_empty()); + } + + #[test] + fn parse_tool_results_no_tool_use_id() { + let content = vec![json!({"type": "tool_result", "content": "output"})]; + let msgs = collect_messages(|tx| parse_tool_results(&content, tx)); + assert_eq!(msgs.len(), 1); + assert!(msgs[0].tool_call_id.is_none()); + assert_eq!(msgs[0].content, "output"); + } + + #[test] + fn parse_tool_results_null_content() { + let content = vec![json!({"type": "tool_result", "tool_use_id": "id1"})]; + let msgs = collect_messages(|tx| parse_tool_results(&content, tx)); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].content, ""); + } + + #[test] + fn parse_tool_results_other_json_content() { + let content = vec![json!({ + "type": "tool_result", + "tool_use_id": "id1", + "content": {"nested": "object"} + })]; + let msgs = collect_messages(|tx| parse_tool_results(&content, tx)); + assert_eq!(msgs.len(), 1); + // Falls through to serde_json::to_string + assert!(!msgs[0].content.is_empty()); + } + + #[test] + fn parse_assistant_message_empty_content_array() { + let content: Vec = vec![]; + let msgs = collect_messages(|tx| parse_assistant_message(&content, tx)); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].content, ""); + assert!(msgs[0].tool_calls.is_none()); + } + + #[test] + fn parse_assistant_message_unknown_block_type() { + let content = vec![ + json!({"type": "image", "source": {"type": "base64"}}), + json!({"type": "text", "text": "done"}), + ]; + let msgs = collect_messages(|tx| parse_assistant_message(&content, tx)); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].content, "done"); + } + + #[test] + fn parse_assistant_message_tool_use_without_id() { + let content = vec![json!({ + "type": "tool_use", + "name": "Bash", + "input": {"command": "ls"} + })]; + let msgs = collect_messages(|tx| parse_assistant_message(&content, tx)); + assert_eq!(msgs.len(), 1); + let calls = msgs[0].tool_calls.as_ref().unwrap(); + assert_eq!(calls.len(), 1); + assert!(calls[0].id.is_none()); + assert_eq!(calls[0].function.name, "Bash"); + } + + #[test] + fn parse_assistant_message_tool_use_without_input_defaults_to_empty_object() { + let content = vec![json!({"type": "tool_use", "id": "tid", "name": "Read"})]; + let msgs = collect_messages(|tx| parse_assistant_message(&content, tx)); + let calls = msgs[0].tool_calls.as_ref().unwrap(); + let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap(); + assert!(args.is_object()); + assert!(args.as_object().unwrap().is_empty()); + } + + fn make_channels() -> ( + tokio::sync::mpsc::UnboundedSender, + tokio::sync::mpsc::UnboundedReceiver, + std::sync::mpsc::Sender, + std::sync::mpsc::Receiver, + ) { + let (tok_tx, tok_rx) = tokio::sync::mpsc::unbounded_channel(); + let (msg_tx, msg_rx) = std::sync::mpsc::channel(); + (tok_tx, tok_rx, msg_tx, msg_rx) + } + + #[test] + fn process_json_event_result_returns_true() { + let (tok_tx, _tok_rx, msg_tx, _msg_rx) = make_channels(); + let (sid_tx, _sid_rx) = tokio::sync::oneshot::channel::(); + let mut sid_tx_opt = Some(sid_tx); + let json = json!({"type": "result", "subtype": "success"}); + assert!(process_json_event(&json, &tok_tx, &msg_tx, &mut sid_tx_opt)); + } + + #[test] + fn process_json_event_system_returns_false() { + let (tok_tx, _tok_rx, msg_tx, _msg_rx) = make_channels(); + let mut sid_tx = None::>; + let json = json!({"type": "system", "subtype": "init", "apiKeySource": "env"}); + assert!(!process_json_event(&json, &tok_tx, &msg_tx, &mut sid_tx)); + } + + #[test] + fn process_json_event_rate_limit_returns_false() { + let (tok_tx, _tok_rx, msg_tx, _msg_rx) = make_channels(); + let mut sid_tx = None::>; + let json = json!({"type": "rate_limit_event"}); + assert!(!process_json_event(&json, &tok_tx, &msg_tx, &mut sid_tx)); + } + + #[test] + fn process_json_event_unknown_type_returns_false() { + let (tok_tx, _tok_rx, msg_tx, _msg_rx) = make_channels(); + let mut sid_tx = None::>; + let json = json!({"type": "some_future_event"}); + assert!(!process_json_event(&json, &tok_tx, &msg_tx, &mut sid_tx)); + } + + #[test] + fn process_json_event_no_type_returns_false() { + let (tok_tx, _tok_rx, msg_tx, _msg_rx) = make_channels(); + let mut sid_tx = None::>; + let json = json!({"content": "no type field"}); + assert!(!process_json_event(&json, &tok_tx, &msg_tx, &mut sid_tx)); + } + + #[test] + fn process_json_event_captures_session_id() { + let (tok_tx, _tok_rx, msg_tx, _msg_rx) = make_channels(); + let (sid_tx, mut sid_rx) = tokio::sync::oneshot::channel::(); + let mut sid_tx_opt = Some(sid_tx); + let json = json!({"type": "system", "session_id": "sess-abc-123"}); + process_json_event(&json, &tok_tx, &msg_tx, &mut sid_tx_opt); + // sid_tx should have been consumed + assert!(sid_tx_opt.is_none()); + let received = sid_rx.try_recv().unwrap(); + assert_eq!(received, "sess-abc-123"); + } + + #[test] + fn process_json_event_preserves_sid_tx_if_no_session_id() { + let (tok_tx, _tok_rx, msg_tx, _msg_rx) = make_channels(); + let (sid_tx, _sid_rx) = tokio::sync::oneshot::channel::(); + let mut sid_tx_opt = Some(sid_tx); + let json = json!({"type": "system"}); + process_json_event(&json, &tok_tx, &msg_tx, &mut sid_tx_opt); + // sid_tx should still be present since no session_id in event + assert!(sid_tx_opt.is_some()); + } + + #[test] + fn process_json_event_stream_event_forwards_token() { + let (tok_tx, mut tok_rx, msg_tx, _msg_rx) = make_channels(); + let mut sid_tx = None::>; + let json = json!({ + "type": "stream_event", + "session_id": "s1", + "event": { + "type": "content_block_delta", + "delta": {"type": "text_delta", "text": "word"} + } + }); + assert!(!process_json_event(&json, &tok_tx, &msg_tx, &mut sid_tx)); + drop(tok_tx); + let tokens: Vec = { + let mut v = vec![]; + while let Ok(t) = tok_rx.try_recv() { + v.push(t); + } + v + }; + assert_eq!(tokens, vec!["word"]); + } + + #[test] + fn process_json_event_assistant_event_parses_message() { + let (tok_tx, _tok_rx, msg_tx, msg_rx) = make_channels(); + let mut sid_tx = None::>; + let json = json!({ + "type": "assistant", + "message": { + "content": [{"type": "text", "text": "Hi!"}] + } + }); + assert!(!process_json_event(&json, &tok_tx, &msg_tx, &mut sid_tx)); + drop(msg_tx); + let msgs: Vec = msg_rx.try_iter().collect(); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].content, "Hi!"); + } + + #[test] + fn process_json_event_user_event_parses_tool_results() { + let (tok_tx, _tok_rx, msg_tx, msg_rx) = make_channels(); + let mut sid_tx = None::>; + let json = json!({ + "type": "user", + "message": { + "content": [{"type": "tool_result", "tool_use_id": "tid1", "content": "done"}] + } + }); + assert!(!process_json_event(&json, &tok_tx, &msg_tx, &mut sid_tx)); + drop(msg_tx); + let msgs: Vec = msg_rx.try_iter().collect(); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].role, Role::Tool); + assert_eq!(msgs[0].content, "done"); + } + + #[test] + fn process_json_event_assistant_without_content_array_is_noop() { + let (tok_tx, _tok_rx, msg_tx, msg_rx) = make_channels(); + let mut sid_tx = None::>; + let json = json!({ + "type": "assistant", + "message": {"content": "not an array"} + }); + assert!(!process_json_event(&json, &tok_tx, &msg_tx, &mut sid_tx)); + drop(msg_tx); + let msgs: Vec = msg_rx.try_iter().collect(); + assert!(msgs.is_empty()); + } + + #[test] + fn process_json_event_user_without_content_array_is_noop() { + let (tok_tx, _tok_rx, msg_tx, msg_rx) = make_channels(); + let mut sid_tx = None::>; + let json = json!({"type": "user", "message": {"content": null}}); + assert!(!process_json_event(&json, &tok_tx, &msg_tx, &mut sid_tx)); + drop(msg_tx); + let msgs: Vec = msg_rx.try_iter().collect(); + assert!(msgs.is_empty()); + } + + #[test] + fn claude_code_provider_new() { + let _provider = ClaudeCodeProvider::new(); + } }