Files
storkit/server/src/llm/providers/claude_code.rs

390 lines
14 KiB
Rust
Raw Normal View History

use portable_pty::{CommandBuilder, PtySize, native_pty_system};
use std::io::{BufRead, BufReader};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::watch;
use crate::llm::types::CompletionResponse;
/// Manages a Claude Code session via a pseudo-terminal.
///
/// Spawns `claude -p` in a PTY so isatty() returns true (which may
/// influence billing), while using `--output-format stream-json` to
/// get clean, structured NDJSON output instead of TUI escape sequences.
///
/// Supports session resumption: if a `session_id` is provided, passes
/// `--resume <id>` so Claude Code loads the prior conversation transcript
/// from disk and continues with full context.
///
/// The type carries no state, so `ClaudeCodeProvider::default()` is the same
/// value as `ClaudeCodeProvider::new()`.
#[derive(Debug, Default)]
pub struct ClaudeCodeProvider;
impl ClaudeCodeProvider {
    /// Create a provider. It holds no state; each call to
    /// [`ClaudeCodeProvider::chat_stream`] runs its own `claude` process.
    pub fn new() -> Self {
        Self
    }

    /// Send `user_message` to Claude Code and stream the reply.
    ///
    /// * `project_root` — working directory for the `claude` process.
    /// * `session_id` — when `Some`, that session is resumed (`--resume <id>`).
    /// * `cancel_rx` — once the watched value is `true` the PTY session is
    ///   cancelled; a value that is already `true` when this is called is
    ///   honoured as well.
    /// * `on_token` — called, in order, with every chunk of text forwarded by
    ///   the PTY session.
    ///
    /// Returns all chunks concatenated as `content`, plus the session id
    /// captured from the event stream (if any event carried one).
    ///
    /// Dropping the returned future before completion also cancels the
    /// underlying PTY session.
    ///
    /// # Errors
    ///
    /// Propagates errors from `run_pty_session` (PTY/spawn failures,
    /// `"Cancelled"`), and returns `"PTY task panicked: …"` if the blocking
    /// task fails to join.
    pub async fn chat_stream<F>(
        &self,
        user_message: &str,
        project_root: &str,
        session_id: Option<&str>,
        cancel_rx: &mut watch::Receiver<bool>,
        mut on_token: F,
    ) -> Result<CompletionResponse, String>
    where
        F: FnMut(&str) + Send,
    {
        /// On drop, stops the cancel-watcher task and flags the PTY session
        /// as cancelled, so neither outlives this call — including when the
        /// caller drops the future early. On the normal path the PTY task has
        /// already finished, so setting the flag then is a no-op.
        struct SessionGuard {
            watcher: tokio::task::JoinHandle<()>,
            cancelled: Arc<AtomicBool>,
        }

        impl Drop for SessionGuard {
            fn drop(&mut self) {
                self.watcher.abort();
                self.cancelled.store(true, Ordering::Relaxed);
            }
        }

        let message = user_message.to_string();
        let cwd = project_root.to_string();
        let resume_id = session_id.map(String::from);

        let mut cancel_watch = cancel_rx.clone();
        // `changed()` does not fire for a value this receiver has already
        // seen, so seed the flag with the current value.
        let cancelled = Arc::new(AtomicBool::new(*cancel_watch.borrow()));
        let watcher_flag = Arc::clone(&cancelled);
        let _guard = SessionGuard {
            watcher: tokio::spawn(async move {
                while cancel_watch.changed().await.is_ok() {
                    if *cancel_watch.borrow() {
                        watcher_flag.store(true, Ordering::Relaxed);
                        break;
                    }
                }
            }),
            cancelled: Arc::clone(&cancelled),
        };

        let (token_tx, mut token_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
        let (sid_tx, sid_rx) = tokio::sync::oneshot::channel::<String>();
        let pty_handle = tokio::task::spawn_blocking(move || {
            run_pty_session(&message, &cwd, resume_id.as_deref(), cancelled, token_tx, sid_tx)
        });

        // The channel closes when the blocking task drops `token_tx`, i.e.
        // when `run_pty_session` returns.
        let mut full_output = String::new();
        while let Some(token) = token_rx.recv().await {
            full_output.push_str(&token);
            on_token(&token);
        }

        pty_handle
            .await
            .map_err(|e| format!("PTY task panicked: {e}"))??;
        // `Err` here means the sender was dropped without a session id.
        let captured_session_id = sid_rx.await.ok();

        Ok(CompletionResponse {
            content: Some(full_output),
            tool_calls: None,
            session_id: captured_session_id,
        })
    }
}
/// Run `claude -p` with stream-json output inside a PTY.
///
/// The PTY makes isatty() return true. The `-p` flag gives us
/// single-shot non-interactive mode with structured output.
///
/// Text extracted from the NDJSON events is sent on `token_tx`, and the first
/// `session_id` found on any event is sent on `sid_tx`. Reading stops at the
/// `result` event, at end of output, or once the child has exited and the
/// lines already buffered have been processed. The child then gets up to two
/// seconds to exit on its own before it is killed.
///
/// # Errors
///
/// Returns `Err` if the PTY cannot be opened, `claude` cannot be spawned, or
/// the PTY reader cannot be cloned. Returns `Err("Cancelled")` (after killing
/// the child) if `cancelled` is observed as `true` while reading output.
fn run_pty_session(
    user_message: &str,
    cwd: &str,
    resume_session_id: Option<&str>,
    cancelled: Arc<AtomicBool>,
    token_tx: tokio::sync::mpsc::UnboundedSender<String>,
    sid_tx: tokio::sync::oneshot::Sender<String>,
) -> Result<(), String> {
    let pty_system = native_pty_system();
    let pair = pty_system
        .openpty(PtySize {
            rows: 50,
            cols: 200,
            pixel_width: 0,
            pixel_height: 0,
        })
        .map_err(|e| format!("Failed to open PTY: {e}"))?;

    let mut cmd = CommandBuilder::new("claude");
    cmd.arg("-p");
    cmd.arg(user_message);
    if let Some(sid) = resume_session_id {
        cmd.arg("--resume");
        cmd.arg(sid);
    }
    cmd.arg("--output-format");
    cmd.arg("stream-json");
    cmd.arg("--verbose");
    cmd.cwd(cwd);
    // Keep TERM reasonable but disable color
    cmd.env("NO_COLOR", "1");
    // Allow nested spawning when the server itself runs inside Claude Code
    cmd.env("CLAUDECODE", "");

    eprintln!(
        "[pty-debug] Spawning: claude -p \"{}\" {} --output-format stream-json --verbose",
        user_message,
        resume_session_id
            .map(|s| format!("--resume {s}"))
            .unwrap_or_default()
    );
    let mut child = pair
        .slave
        .spawn_command(cmd)
        .map_err(|e| format!("Failed to spawn claude: {e}"))?;
    eprintln!(
        "[pty-debug] Process spawned, pid: {:?}",
        child.process_id()
    );
    // Release our handle on the slave side; the spawned child has its own.
    drop(pair.slave);

    let reader = pair
        .master
        .try_clone_reader()
        .map_err(|e| format!("Failed to clone PTY reader: {e}"))?;
    // We don't need to write anything — -p mode takes prompt as arg
    drop(pair.master);

    let line_rx = spawn_pty_line_reader(reader);
    let mut sid_tx = Some(sid_tx);

    loop {
        if cancelled.load(Ordering::Relaxed) {
            // Reap the killed child so it does not linger as a zombie.
            if child.kill().is_ok() {
                let _ = child.wait();
            }
            return Err("Cancelled".to_string());
        }
        match line_rx.recv_timeout(std::time::Duration::from_millis(500)) {
            Ok(Some(line)) => {
                if process_ndjson_line(&line, &token_tx, &mut sid_tx) {
                    break;
                }
            }
            // Reader reached EOF / a read error, or the reader thread is gone.
            Ok(None) | Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => break,
            Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
                if let Ok(Some(_)) = child.try_wait() {
                    // The child has exited: process the lines that are already
                    // buffered with the same handling as above (so a late
                    // `result` event or session id is not dropped), then stop.
                    while let Ok(Some(line)) = line_rx.try_recv() {
                        if process_ndjson_line(&line, &token_tx, &mut sid_tx) {
                            break;
                        }
                    }
                    break;
                }
            }
        }
    }

    // Wait briefly for Claude Code to flush its session transcript to disk.
    // The `result` event means the API response is done, but the process
    // still needs to write the conversation to the JSONL session file.
    let mut exited = matches!(child.try_wait(), Ok(Some(_)));
    // Give it up to 2 seconds (20 x 100 ms) to exit cleanly.
    for _ in 0..20 {
        if exited {
            break;
        }
        std::thread::sleep(std::time::Duration::from_millis(100));
        exited = matches!(child.try_wait(), Ok(Some(_)));
    }
    // Kill only a child that is still running; one that exited has already
    // been reaped by `try_wait` and must not be signalled. A killed child is
    // waited on so it does not linger as a zombie.
    if !exited && child.kill().is_ok() {
        let _ = child.wait();
    }
    Ok(())
}

/// Maximum number of characters of each NDJSON line echoed to the debug log.
const LOG_PREVIEW_CHARS: usize = 120;

/// Spawn a thread that reads `reader` line by line.
///
/// Each line is forwarded as `Some(line)`. A single `None` is sent when
/// reading stops (EOF or read error). The thread also exits early once the
/// returned receiver has been dropped.
fn spawn_pty_line_reader<R>(reader: R) -> std::sync::mpsc::Receiver<Option<String>>
where
    R: std::io::Read + Send + 'static,
{
    let (line_tx, line_rx) = std::sync::mpsc::channel::<Option<String>>();
    std::thread::spawn(move || {
        let buf_reader = BufReader::new(reader);
        eprintln!("[pty-debug] Reader thread started");
        for line in buf_reader.lines() {
            match line {
                Ok(l) => {
                    eprintln!("[pty-debug] raw line: {}", l);
                    if line_tx.send(Some(l)).is_err() {
                        break;
                    }
                }
                Err(e) => {
                    eprintln!("[pty-debug] read error: {e}");
                    break;
                }
            }
        }
        eprintln!("[pty-debug] Reader thread done");
        let _ = line_tx.send(None);
    });
    line_rx
}

/// Handle one line of Claude Code's `stream-json` output.
///
/// Lines that are blank, not JSON (e.g. terminal escape sequences), or JSON
/// without a string `type` are ignored. The first `session_id` seen is sent
/// through `sid_tx` (which is then left as `None`). User-visible text is
/// forwarded on `token_tx`; send failures are ignored.
///
/// Returns `true` when the line was the final `result` event.
fn process_ndjson_line(
    line: &str,
    token_tx: &tokio::sync::mpsc::UnboundedSender<String>,
    sid_tx: &mut Option<tokio::sync::oneshot::Sender<String>>,
) -> bool {
    let trimmed = line.trim();
    if trimmed.is_empty() {
        return false;
    }
    eprintln!(
        "[pty-debug] processing: {}...",
        log_preview(trimmed, LOG_PREVIEW_CHARS)
    );
    let Ok(json) = serde_json::from_str::<serde_json::Value>(trimmed) else {
        return false;
    };
    let Some(event_type) = json.get("type").and_then(|t| t.as_str()) else {
        return false;
    };

    // Capture session_id from the first event that has it
    if let Some(sid) = json.get("session_id").and_then(|s| s.as_str())
        && let Some(tx) = sid_tx.take()
    {
        let _ = tx.send(sid.to_string());
    }

    match event_type {
        // Streaming deltas (when --include-partial-messages is used)
        "stream_event" => {
            if let Some(event) = json.get("event") {
                handle_stream_event(event, token_tx);
            }
        }
        // Complete assistant message
        "assistant" => {
            if let Some(content) = json
                .get("message")
                .and_then(|m| m.get("content"))
                .and_then(|c| c.as_array())
            {
                for text in content
                    .iter()
                    .filter_map(|block| block.get("text").and_then(|t| t.as_str()))
                {
                    let _ = token_tx.send(text.to_string());
                }
            }
        }
        // Final result with usage stats
        "result" => {
            if let Some(cost) = json.get("total_cost_usd").and_then(|c| c.as_f64()) {
                let _ = token_tx.send(format!("\n\n---\n_Cost: ${cost:.4}_\n"));
            }
            if let Some(usage) = json.get("usage") {
                let count = |key: &str| usage.get(key).and_then(|t| t.as_u64()).unwrap_or(0);
                let input = count("input_tokens");
                let output = count("output_tokens");
                let cached = count("cache_read_input_tokens");
                let _ = token_tx.send(format!(
                    "_Tokens: {input} in / {output} out / {cached} cached_\n"
                ));
            }
            return true;
        }
        // System init — report model and API key source
        "system" => {
            let api_source = json
                .get("apiKeySource")
                .and_then(|s| s.as_str())
                .unwrap_or("unknown");
            let model = json
                .get("model")
                .and_then(|s| s.as_str())
                .unwrap_or("unknown");
            let _ = token_tx.send(format!("_[{model} | apiKey: {api_source}]_\n\n"));
        }
        // Rate limit info
        "rate_limit_event" => {
            if let Some(info) = json.get("rate_limit_info") {
                let status = info
                    .get("status")
                    .and_then(|s| s.as_str())
                    .unwrap_or("unknown");
                let limit_type = info
                    .get("rateLimitType")
                    .and_then(|s| s.as_str())
                    .unwrap_or("unknown");
                let _ = token_tx.send(format!("_[rate limit: {status} ({limit_type})]_\n\n"));
            }
        }
        _ => {}
    }
    false
}

/// Return at most the first `max_chars` characters of `s`.
///
/// Always cuts on a UTF-8 character boundary; slicing by byte count can panic
/// when the cut falls inside a multi-byte character.
fn log_preview(s: &str, max_chars: usize) -> &str {
    match s.char_indices().nth(max_chars) {
        Some((end, _)) => &s[..end],
        None => s,
    }
}
/// Forward the user-visible part of one streaming API event to `token_tx`.
///
/// * `content_block_delta` — a `text_delta` is sent verbatim; a
///   `thinking_delta` is sent prefixed with `[thinking] `.
/// * `message_delta` — the `usage.output_tokens` count (0 when absent) is sent.
/// * `error` — `error.message` (or `unknown error`) is sent.
///
/// Every other event type, and any event missing the expected fields,
/// produces no output. Send failures (receiver dropped) are ignored.
fn handle_stream_event(
    event: &serde_json::Value,
    token_tx: &tokio::sync::mpsc::UnboundedSender<String>,
) {
    /// Look up `key` on `value`, returning it only if it is a JSON string.
    fn str_field<'a>(value: &'a serde_json::Value, key: &str) -> Option<&'a str> {
        value.get(key).and_then(serde_json::Value::as_str)
    }

    let outgoing: Option<String> = match str_field(event, "type").unwrap_or("") {
        // Text content streaming
        "content_block_delta" => event.get("delta").and_then(|delta| {
            match str_field(delta, "type").unwrap_or("") {
                "text_delta" => str_field(delta, "text").map(String::from),
                "thinking_delta" => {
                    str_field(delta, "thinking").map(|thinking| format!("[thinking] {thinking}"))
                }
                _ => None,
            }
        }),
        // Message complete — report output token usage
        "message_delta" => event.get("usage").map(|usage| {
            let output_tokens = usage
                .get("output_tokens")
                .and_then(serde_json::Value::as_u64)
                .unwrap_or(0);
            format!("\n[tokens: {output_tokens} output]\n")
        }),
        // Surface API errors
        "error" => event.get("error").map(|error| {
            let msg = str_field(error, "message").unwrap_or("unknown error");
            format!("\n[error: {msg}]\n")
        }),
        _ => None,
    };

    if let Some(text) = outgoing {
        let _ = token_tx.send(text);
    }
}