feat: auto-refresh expired OAuth token for Claude Code PTY (story 405)
Detect authentication_failed errors from the Claude Code PTY stream and automatically refresh the OAuth access token using the stored refresh token in ~/.claude/.credentials.json. - New module server/src/llm/oauth.rs: reads credentials, calls platform.claude.com/v1/oauth/token with JSON body, writes back - PTY provider detects "error":"authentication_failed" via AtomicBool - chat_stream retries once after successful refresh - Clear error message if refresh also fails On success the retry is transparent. On failure the user sees: "OAuth session expired. Please run claude login to re-authenticate." Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
+30
@@ -0,0 +1,30 @@
|
||||
---
|
||||
name: "Auto-refresh expired OAuth token for Claude Code PTY"
|
||||
---
|
||||
|
||||
# Story 405: Auto-refresh expired OAuth token for Claude Code PTY
|
||||
|
||||
## User Story
|
||||
|
||||
As a storkit user with a Claude Max subscription, I want the server to automatically refresh my expired OAuth token so that chat, Matrix, and WhatsApp integrations don't stop working when the token expires.
|
||||
|
||||
## Acceptance Criteria
|
||||
|
||||
### Detection
|
||||
- [ ] When the Claude Code PTY returns an `authentication_failed` error, storkit detects it instead of passing the raw 401 JSON to the user
|
||||
|
||||
### Auto-refresh (credentials exist, refresh token valid)
|
||||
- [ ] Storkit reads the OAuth refresh token from `~/.claude/.credentials.json`
|
||||
- [ ] Storkit calls the Anthropic OAuth token refresh endpoint (`https://console.anthropic.com/v1/oauth/token` with `grant_type=refresh_token`) to obtain a new access token
|
||||
- [ ] Storkit writes the refreshed access token (and new expiresAt) back to `~/.claude/.credentials.json`
|
||||
- [ ] After a successful refresh, storkit automatically retries the original chat request
|
||||
- [ ] The refresh+retry is transparent to the user — they see no error
|
||||
|
||||
### Full login required (no credentials, or refresh token also expired)
|
||||
- [ ] If `.credentials.json` doesn't exist or the refresh call itself fails, storkit surfaces a clear error: "OAuth session expired. Please run `claude login` to re-authenticate."
|
||||
- [ ] The error message is surfaced through the normal chat stream (not just server logs)
|
||||
|
||||
## Out of Scope
|
||||
|
||||
- Implementing the full interactive `claude login` browser OAuth flow inside storkit
|
||||
- Proactive token refresh before expiry (refreshing on demand when the error occurs is sufficient)
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod chat;
|
||||
pub mod oauth;
|
||||
pub mod prompts;
|
||||
pub mod providers;
|
||||
pub mod types;
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
use crate::slog;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// The client ID used by Claude Code for OAuth.
|
||||
const CLAUDE_CODE_CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
|
||||
const TOKEN_ENDPOINT: &str = "https://platform.claude.com/v1/oauth/token";
|
||||
|
||||
/// OAuth credentials as stored in `~/.claude/.credentials.json`.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct OAuthCredentials {
|
||||
pub access_token: String,
|
||||
pub refresh_token: String,
|
||||
pub expires_at: u64,
|
||||
#[serde(default)]
|
||||
pub scopes: Vec<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub subscription_type: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub rate_limit_tier: Option<String>,
|
||||
}
|
||||
|
||||
/// Top-level structure of `~/.claude/.credentials.json`.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CredentialsFile {
|
||||
pub claude_ai_oauth: OAuthCredentials,
|
||||
}
|
||||
|
||||
/// Response from the Anthropic OAuth token refresh endpoint.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenRefreshResponse {
|
||||
access_token: String,
|
||||
expires_in: u64,
|
||||
#[allow(dead_code)]
|
||||
token_type: Option<String>,
|
||||
}
|
||||
|
||||
/// Error from the Anthropic OAuth token refresh endpoint.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenRefreshError {
|
||||
#[allow(dead_code)]
|
||||
error: String,
|
||||
error_description: Option<String>,
|
||||
}
|
||||
|
||||
/// Returns the path to `~/.claude/.credentials.json`.
|
||||
fn credentials_path() -> Result<PathBuf, String> {
|
||||
let home = std::env::var("HOME").map_err(|_| "HOME not set".to_string())?;
|
||||
Ok(PathBuf::from(home).join(".claude").join(".credentials.json"))
|
||||
}
|
||||
|
||||
/// Read OAuth credentials from disk.
|
||||
pub fn read_credentials() -> Result<CredentialsFile, String> {
|
||||
let path = credentials_path()?;
|
||||
let data = std::fs::read_to_string(&path).map_err(|e| {
|
||||
format!(
|
||||
"Cannot read {}: {e}. Run `claude login` to authenticate.",
|
||||
path.display()
|
||||
)
|
||||
})?;
|
||||
serde_json::from_str(&data).map_err(|e| {
|
||||
format!(
|
||||
"Failed to parse {}: {e}",
|
||||
path.display()
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Write updated credentials back to disk with 0600 permissions.
|
||||
pub fn write_credentials(creds: &CredentialsFile) -> Result<(), String> {
|
||||
let path = credentials_path()?;
|
||||
let data = serde_json::to_string_pretty(creds)
|
||||
.map_err(|e| format!("Failed to serialize credentials: {e}"))?;
|
||||
std::fs::write(&path, &data)
|
||||
.map_err(|e| format!("Failed to write {}: {e}", path.display()))?;
|
||||
|
||||
// Restore 0600 permissions
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let perms = std::fs::Permissions::from_mode(0o600);
|
||||
std::fs::set_permissions(&path, perms)
|
||||
.map_err(|e| format!("Failed to set permissions on {}: {e}", path.display()))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Refresh the OAuth access token using the stored refresh token.
|
||||
///
|
||||
/// On success, updates `~/.claude/.credentials.json` with the new access
|
||||
/// token and expiry, then returns `Ok(())`.
|
||||
///
|
||||
/// On failure (e.g. refresh token expired), returns an error string.
|
||||
pub async fn refresh_access_token() -> Result<(), String> {
|
||||
slog!("[oauth] Attempting to refresh OAuth access token");
|
||||
|
||||
let mut creds = read_credentials()?;
|
||||
let refresh_token = creds.claude_ai_oauth.refresh_token.clone();
|
||||
|
||||
if refresh_token.is_empty() {
|
||||
return Err(
|
||||
"No refresh token found. Run `claude login` to authenticate.".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(TOKEN_ENDPOINT)
|
||||
.json(&serde_json::json!({
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": CLAUDE_CODE_CLIENT_ID,
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("OAuth refresh request failed: {e}"))?;
|
||||
|
||||
let status = resp.status();
|
||||
let body = resp
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to read refresh response: {e}"))?;
|
||||
|
||||
if !status.is_success() {
|
||||
// Try to parse a structured error
|
||||
if let Ok(err) = serde_json::from_str::<TokenRefreshError>(&body) {
|
||||
let desc = err
|
||||
.error_description
|
||||
.unwrap_or_else(|| "unknown error".to_string());
|
||||
slog!("[oauth] Refresh failed: {desc} (full body: {body})");
|
||||
return Err(format!(
|
||||
"OAuth session expired. Please run `claude login` to re-authenticate. ({desc})"
|
||||
));
|
||||
}
|
||||
return Err(format!(
|
||||
"OAuth session expired. Please run `claude login` to re-authenticate. (HTTP {status})"
|
||||
));
|
||||
}
|
||||
|
||||
let token_resp: TokenRefreshResponse = serde_json::from_str(&body)
|
||||
.map_err(|e| format!("Failed to parse refresh response: {e}"))?;
|
||||
|
||||
let now_ms = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_millis() as u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
creds.claude_ai_oauth.access_token = token_resp.access_token;
|
||||
creds.claude_ai_oauth.expires_at = now_ms + (token_resp.expires_in * 1000);
|
||||
|
||||
write_credentials(&creds)?;
|
||||
|
||||
slog!(
|
||||
"[oauth] Successfully refreshed access token, expires at {}",
|
||||
creds.claude_ai_oauth.expires_at
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_credentials_file() {
|
||||
let json = r#"{
|
||||
"claudeAiOauth": {
|
||||
"accessToken": "sk-ant-oat01-test",
|
||||
"refreshToken": "sk-ant-ort01-test",
|
||||
"expiresAt": 1774466144677,
|
||||
"scopes": ["user:inference"],
|
||||
"subscriptionType": "max",
|
||||
"rateLimitTier": "default_claude_max_20x"
|
||||
}
|
||||
}"#;
|
||||
let creds: CredentialsFile = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(creds.claude_ai_oauth.access_token, "sk-ant-oat01-test");
|
||||
assert_eq!(creds.claude_ai_oauth.refresh_token, "sk-ant-ort01-test");
|
||||
assert_eq!(creds.claude_ai_oauth.expires_at, 1774466144677);
|
||||
assert_eq!(creds.claude_ai_oauth.subscription_type.as_deref(), Some("max"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_credentials_roundtrip() {
|
||||
let creds = CredentialsFile {
|
||||
claude_ai_oauth: OAuthCredentials {
|
||||
access_token: "access".to_string(),
|
||||
refresh_token: "refresh".to_string(),
|
||||
expires_at: 12345,
|
||||
scopes: vec!["user:inference".to_string()],
|
||||
subscription_type: Some("max".to_string()),
|
||||
rate_limit_tier: None,
|
||||
},
|
||||
};
|
||||
let json = serde_json::to_string(&creds).unwrap();
|
||||
let parsed: CredentialsFile = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.claude_ai_oauth.access_token, "access");
|
||||
assert_eq!(parsed.claude_ai_oauth.refresh_token, "refresh");
|
||||
// rate_limit_tier should be omitted from JSON (skip_serializing_if)
|
||||
assert!(!json.contains("rateLimitTier"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credentials_path_uses_home() {
|
||||
// Just verify it doesn't panic and returns a path ending in .credentials.json
|
||||
if std::env::var("HOME").is_ok() {
|
||||
let path = credentials_path().unwrap();
|
||||
assert!(path.ends_with(".claude/.credentials.json"));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -53,10 +53,6 @@ impl ClaudeCodeProvider {
|
||||
T: FnMut(&str) + Send,
|
||||
A: FnMut(&str) + Send,
|
||||
{
|
||||
let message = user_message.to_string();
|
||||
let cwd = project_root.to_string();
|
||||
let resume_id = session_id.map(|s| s.to_string());
|
||||
let sys_prompt = system_prompt.map(|s| s.to_string());
|
||||
let cancelled = Arc::new(AtomicBool::new(false));
|
||||
let cancelled_clone = cancelled.clone();
|
||||
|
||||
@@ -70,66 +66,98 @@ impl ClaudeCodeProvider {
|
||||
}
|
||||
});
|
||||
|
||||
let (token_tx, mut token_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
|
||||
let (thinking_tx, mut thinking_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
|
||||
let (activity_tx, mut activity_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
|
||||
let (msg_tx, msg_rx) = std::sync::mpsc::channel::<Message>();
|
||||
let (sid_tx, sid_rx) = tokio::sync::oneshot::channel::<String>();
|
||||
// Attempt up to 2 times: first try, then retry after OAuth refresh.
|
||||
for attempt in 0..2 {
|
||||
let message = user_message.to_string();
|
||||
let cwd = project_root.to_string();
|
||||
let resume_id = session_id.map(|s| s.to_string());
|
||||
let sys_prompt = system_prompt.map(|s| s.to_string());
|
||||
let cancelled_inner = cancelled.clone();
|
||||
let auth_failed = Arc::new(AtomicBool::new(false));
|
||||
let auth_failed_clone = auth_failed.clone();
|
||||
|
||||
let pty_handle = tokio::task::spawn_blocking(move || {
|
||||
run_pty_session(
|
||||
&message,
|
||||
&cwd,
|
||||
resume_id.as_deref(),
|
||||
sys_prompt.as_deref(),
|
||||
cancelled,
|
||||
token_tx,
|
||||
thinking_tx,
|
||||
activity_tx,
|
||||
msg_tx,
|
||||
sid_tx,
|
||||
)
|
||||
});
|
||||
let (token_tx, mut token_rx) =
|
||||
tokio::sync::mpsc::unbounded_channel::<String>();
|
||||
let (thinking_tx, mut thinking_rx) =
|
||||
tokio::sync::mpsc::unbounded_channel::<String>();
|
||||
let (activity_tx, mut activity_rx) =
|
||||
tokio::sync::mpsc::unbounded_channel::<String>();
|
||||
let (msg_tx, msg_rx) = std::sync::mpsc::channel::<Message>();
|
||||
let (sid_tx, sid_rx) = tokio::sync::oneshot::channel::<String>();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = token_rx.recv() => match msg {
|
||||
Some(t) => on_token(&t),
|
||||
None => break,
|
||||
},
|
||||
msg = thinking_rx.recv() => if let Some(t) = msg {
|
||||
on_thinking(&t);
|
||||
},
|
||||
msg = activity_rx.recv() => if let Some(name) = msg {
|
||||
on_activity(&name);
|
||||
},
|
||||
let pty_handle = tokio::task::spawn_blocking(move || {
|
||||
run_pty_session(
|
||||
&message,
|
||||
&cwd,
|
||||
resume_id.as_deref(),
|
||||
sys_prompt.as_deref(),
|
||||
cancelled_inner,
|
||||
auth_failed_clone,
|
||||
token_tx,
|
||||
thinking_tx,
|
||||
activity_tx,
|
||||
msg_tx,
|
||||
sid_tx,
|
||||
)
|
||||
});
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = token_rx.recv() => match msg {
|
||||
Some(t) => on_token(&t),
|
||||
None => break,
|
||||
},
|
||||
msg = thinking_rx.recv() => if let Some(t) = msg {
|
||||
on_thinking(&t);
|
||||
},
|
||||
msg = activity_rx.recv() => if let Some(name) = msg {
|
||||
on_activity(&name);
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Drain any remaining activity/thinking messages that were buffered
|
||||
// when the token channel closed.
|
||||
while let Ok(t) = thinking_rx.try_recv() {
|
||||
on_thinking(&t);
|
||||
}
|
||||
while let Ok(name) = activity_rx.try_recv() {
|
||||
on_activity(&name);
|
||||
}
|
||||
|
||||
pty_handle
|
||||
.await
|
||||
.map_err(|e| format!("PTY task panicked: {e}"))??;
|
||||
|
||||
// Check if the PTY session failed due to expired OAuth token.
|
||||
if auth_failed.load(Ordering::Relaxed) && attempt == 0 {
|
||||
slog!("[oauth] Authentication failed, attempting token refresh");
|
||||
match crate::llm::oauth::refresh_access_token().await {
|
||||
Ok(()) => {
|
||||
slog!("[oauth] Token refreshed, retrying request");
|
||||
on_token("\n*Refreshing authentication token...*\n");
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(format!(
|
||||
"OAuth session expired. Please run `claude login` to re-authenticate. ({e})"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let captured_session_id = sid_rx.await.ok();
|
||||
slog!("[pty-debug] RECEIVED session_id: {:?}", captured_session_id);
|
||||
let structured_messages: Vec<Message> = msg_rx.try_iter().collect();
|
||||
|
||||
return Ok(ClaudeCodeResult {
|
||||
messages: structured_messages,
|
||||
session_id: captured_session_id,
|
||||
});
|
||||
}
|
||||
|
||||
// Drain any remaining activity/thinking messages that were buffered when
|
||||
// the token channel closed.
|
||||
while let Ok(t) = thinking_rx.try_recv() {
|
||||
on_thinking(&t);
|
||||
}
|
||||
// Drain any remaining activity messages that were buffered when the
|
||||
// token channel closed. The select! loop breaks on token_rx → None,
|
||||
// but activity_rx may still hold signals sent in the same instant.
|
||||
while let Ok(name) = activity_rx.try_recv() {
|
||||
on_activity(&name);
|
||||
}
|
||||
|
||||
pty_handle
|
||||
.await
|
||||
.map_err(|e| format!("PTY task panicked: {e}"))??;
|
||||
|
||||
let captured_session_id = sid_rx.await.ok();
|
||||
slog!("[pty-debug] RECEIVED session_id: {:?}", captured_session_id);
|
||||
let structured_messages: Vec<Message> = msg_rx.try_iter().collect();
|
||||
|
||||
Ok(ClaudeCodeResult {
|
||||
messages: structured_messages,
|
||||
session_id: captured_session_id,
|
||||
})
|
||||
// Should never reach here, but just in case.
|
||||
Err("Authentication failed after retry".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,6 +180,7 @@ fn run_pty_session(
|
||||
resume_session_id: Option<&str>,
|
||||
_system_prompt: Option<&str>,
|
||||
cancelled: Arc<AtomicBool>,
|
||||
auth_failed: Arc<AtomicBool>,
|
||||
token_tx: tokio::sync::mpsc::UnboundedSender<String>,
|
||||
thinking_tx: tokio::sync::mpsc::UnboundedSender<String>,
|
||||
activity_tx: tokio::sync::mpsc::UnboundedSender<String>,
|
||||
@@ -278,6 +307,7 @@ fn run_pty_session(
|
||||
&activity_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx,
|
||||
&auth_failed,
|
||||
)
|
||||
{
|
||||
got_result = true;
|
||||
@@ -304,6 +334,7 @@ fn run_pty_session(
|
||||
&activity_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx,
|
||||
&auth_failed,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -341,6 +372,7 @@ fn run_pty_session(
|
||||
///
|
||||
/// Returns `true` if a `result` event was received (signals session completion).
|
||||
/// Captures the session ID from the first event that carries it.
|
||||
/// Sets `auth_failed` to `true` if an `authentication_failed` error is detected.
|
||||
fn process_json_event(
|
||||
json: &serde_json::Value,
|
||||
token_tx: &tokio::sync::mpsc::UnboundedSender<String>,
|
||||
@@ -348,6 +380,7 @@ fn process_json_event(
|
||||
activity_tx: &tokio::sync::mpsc::UnboundedSender<String>,
|
||||
msg_tx: &std::sync::mpsc::Sender<Message>,
|
||||
sid_tx: &mut Option<tokio::sync::oneshot::Sender<String>>,
|
||||
auth_failed: &AtomicBool,
|
||||
) -> bool {
|
||||
let event_type = match json.get("type").and_then(|t| t.as_str()) {
|
||||
Some(t) => t,
|
||||
@@ -364,6 +397,12 @@ fn process_json_event(
|
||||
}
|
||||
}
|
||||
|
||||
// Detect authentication_failed at the top level of any event.
|
||||
if json.get("error").and_then(|e| e.as_str()) == Some("authentication_failed") {
|
||||
slog!("[pty-debug] Detected authentication_failed error");
|
||||
auth_failed.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
match event_type {
|
||||
"stream_event" => {
|
||||
if let Some(event) = json.get("event") {
|
||||
@@ -916,7 +955,8 @@ mod tests {
|
||||
&thi_tx,
|
||||
&act_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx_opt
|
||||
&mut sid_tx_opt,
|
||||
&AtomicBool::new(false),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -931,7 +971,8 @@ mod tests {
|
||||
&thi_tx,
|
||||
&act_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx
|
||||
&mut sid_tx,
|
||||
&AtomicBool::new(false),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -946,7 +987,8 @@ mod tests {
|
||||
&thi_tx,
|
||||
&act_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx
|
||||
&mut sid_tx,
|
||||
&AtomicBool::new(false),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -961,7 +1003,8 @@ mod tests {
|
||||
&thi_tx,
|
||||
&act_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx
|
||||
&mut sid_tx,
|
||||
&AtomicBool::new(false),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -976,7 +1019,8 @@ mod tests {
|
||||
&thi_tx,
|
||||
&act_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx
|
||||
&mut sid_tx,
|
||||
&AtomicBool::new(false),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -986,7 +1030,7 @@ mod tests {
|
||||
let (sid_tx, mut sid_rx) = tokio::sync::oneshot::channel::<String>();
|
||||
let mut sid_tx_opt = Some(sid_tx);
|
||||
let json = json!({"type": "system", "session_id": "sess-abc-123"});
|
||||
process_json_event(&json, &tok_tx, &thi_tx, &act_tx, &msg_tx, &mut sid_tx_opt);
|
||||
process_json_event(&json, &tok_tx, &thi_tx, &act_tx, &msg_tx, &mut sid_tx_opt, &AtomicBool::new(false));
|
||||
// sid_tx should have been consumed
|
||||
assert!(sid_tx_opt.is_none());
|
||||
let received = sid_rx.try_recv().unwrap();
|
||||
@@ -999,7 +1043,7 @@ mod tests {
|
||||
let (sid_tx, _sid_rx) = tokio::sync::oneshot::channel::<String>();
|
||||
let mut sid_tx_opt = Some(sid_tx);
|
||||
let json = json!({"type": "system"});
|
||||
process_json_event(&json, &tok_tx, &thi_tx, &act_tx, &msg_tx, &mut sid_tx_opt);
|
||||
process_json_event(&json, &tok_tx, &thi_tx, &act_tx, &msg_tx, &mut sid_tx_opt, &AtomicBool::new(false));
|
||||
// sid_tx should still be present since no session_id in event
|
||||
assert!(sid_tx_opt.is_some());
|
||||
}
|
||||
@@ -1023,7 +1067,8 @@ mod tests {
|
||||
&thi_tx,
|
||||
&act_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx
|
||||
&mut sid_tx,
|
||||
&AtomicBool::new(false),
|
||||
));
|
||||
drop(tok_tx);
|
||||
let tokens: Vec<String> = {
|
||||
@@ -1058,7 +1103,8 @@ mod tests {
|
||||
&thi_tx,
|
||||
&act_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx
|
||||
&mut sid_tx,
|
||||
&AtomicBool::new(false),
|
||||
));
|
||||
drop(act_tx);
|
||||
let activities: Vec<String> = {
|
||||
@@ -1091,7 +1137,8 @@ mod tests {
|
||||
&thi_tx,
|
||||
&act_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx
|
||||
&mut sid_tx,
|
||||
&AtomicBool::new(false),
|
||||
));
|
||||
drop(act_tx);
|
||||
let activities: Vec<String> = {
|
||||
@@ -1124,7 +1171,8 @@ mod tests {
|
||||
&thi_tx,
|
||||
&act_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx
|
||||
&mut sid_tx,
|
||||
&AtomicBool::new(false),
|
||||
));
|
||||
drop(act_tx);
|
||||
let activities: Vec<String> = {
|
||||
@@ -1154,7 +1202,8 @@ mod tests {
|
||||
&thi_tx,
|
||||
&act_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx
|
||||
&mut sid_tx,
|
||||
&AtomicBool::new(false),
|
||||
));
|
||||
drop(act_tx);
|
||||
let activities: Vec<String> = {
|
||||
@@ -1183,7 +1232,8 @@ mod tests {
|
||||
&thi_tx,
|
||||
&act_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx
|
||||
&mut sid_tx,
|
||||
&AtomicBool::new(false),
|
||||
));
|
||||
drop(msg_tx);
|
||||
let msgs: Vec<Message> = msg_rx.try_iter().collect();
|
||||
@@ -1207,7 +1257,8 @@ mod tests {
|
||||
&thi_tx,
|
||||
&act_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx
|
||||
&mut sid_tx,
|
||||
&AtomicBool::new(false),
|
||||
));
|
||||
drop(msg_tx);
|
||||
let msgs: Vec<Message> = msg_rx.try_iter().collect();
|
||||
@@ -1230,7 +1281,8 @@ mod tests {
|
||||
&thi_tx,
|
||||
&act_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx
|
||||
&mut sid_tx,
|
||||
&AtomicBool::new(false),
|
||||
));
|
||||
drop(msg_tx);
|
||||
let msgs: Vec<Message> = msg_rx.try_iter().collect();
|
||||
@@ -1248,13 +1300,61 @@ mod tests {
|
||||
&thi_tx,
|
||||
&act_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx
|
||||
&mut sid_tx,
|
||||
&AtomicBool::new(false),
|
||||
));
|
||||
drop(msg_tx);
|
||||
let msgs: Vec<Message> = msg_rx.try_iter().collect();
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn process_json_event_detects_authentication_failed() {
|
||||
let (tok_tx, _tok_rx, thi_tx, _thi_rx, act_tx, _act_rx, msg_tx, _msg_rx) = make_channels();
|
||||
let mut sid_tx = None::<tokio::sync::oneshot::Sender<String>>;
|
||||
let auth_failed = AtomicBool::new(false);
|
||||
let json = json!({
|
||||
"type": "assistant",
|
||||
"error": "authentication_failed",
|
||||
"message": {
|
||||
"content": [{"type": "text", "text": "Failed to authenticate."}]
|
||||
}
|
||||
});
|
||||
assert!(!process_json_event(
|
||||
&json,
|
||||
&tok_tx,
|
||||
&thi_tx,
|
||||
&act_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx,
|
||||
&auth_failed,
|
||||
));
|
||||
assert!(auth_failed.load(Ordering::Relaxed));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn process_json_event_no_auth_failed_for_normal_events() {
|
||||
let (tok_tx, _tok_rx, thi_tx, _thi_rx, act_tx, _act_rx, msg_tx, _msg_rx) = make_channels();
|
||||
let mut sid_tx = None::<tokio::sync::oneshot::Sender<String>>;
|
||||
let auth_failed = AtomicBool::new(false);
|
||||
let json = json!({
|
||||
"type": "assistant",
|
||||
"message": {
|
||||
"content": [{"type": "text", "text": "Hello!"}]
|
||||
}
|
||||
});
|
||||
assert!(!process_json_event(
|
||||
&json,
|
||||
&tok_tx,
|
||||
&thi_tx,
|
||||
&act_tx,
|
||||
&msg_tx,
|
||||
&mut sid_tx,
|
||||
&auth_failed,
|
||||
));
|
||||
assert!(!auth_failed.load(Ordering::Relaxed));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn claude_code_provider_new() {
|
||||
let _provider = ClaudeCodeProvider::new();
|
||||
|
||||
Reference in New Issue
Block a user