From 710b604b7c164883f568e0418be5ff8384fbc7d5 Mon Sep 17 00:00:00 2001 From: Timmy Date: Thu, 26 Mar 2026 19:58:04 +0000 Subject: [PATCH] feat: auto-refresh expired OAuth token for Claude Code PTY (story 405) Detect authentication_failed errors from the Claude Code PTY stream and automatically refresh the OAuth access token using the stored refresh token in ~/.claude/.credentials.json. - New module server/src/llm/oauth.rs: reads credentials, calls platform.claude.com/v1/oauth/token with JSON body, writes back - PTY provider detects "error":"authentication_failed" via AtomicBool - chat_stream retries once after successful refresh - Clear error message if refresh also fails On success the retry is transparent. On failure the user sees: "OAuth session expired. Please run claude login to re-authenticate." Co-Authored-By: Claude Opus 4.6 (1M context) --- ...expired_oauth_token_for_claude_code_pty.md | 30 +++ server/src/llm/mod.rs | 1 + server/src/llm/oauth.rs | 215 +++++++++++++++ server/src/llm/providers/claude_code.rs | 250 ++++++++++++------ 4 files changed, 421 insertions(+), 75 deletions(-) create mode 100644 .storkit/work/5_done/405_story_auto_refresh_expired_oauth_token_for_claude_code_pty.md create mode 100644 server/src/llm/oauth.rs diff --git a/.storkit/work/5_done/405_story_auto_refresh_expired_oauth_token_for_claude_code_pty.md b/.storkit/work/5_done/405_story_auto_refresh_expired_oauth_token_for_claude_code_pty.md new file mode 100644 index 00000000..636dd2ab --- /dev/null +++ b/.storkit/work/5_done/405_story_auto_refresh_expired_oauth_token_for_claude_code_pty.md @@ -0,0 +1,30 @@ +--- +name: "Auto-refresh expired OAuth token for Claude Code PTY" +--- + +# Story 405: Auto-refresh expired OAuth token for Claude Code PTY + +## User Story + +As a storkit user with a Claude Max subscription, I want the server to automatically refresh my expired OAuth token so that chat, Matrix, and WhatsApp integrations don't stop working when the token expires. + +## Acceptance Criteria + +### Detection +- [ ] When the Claude Code PTY returns an `authentication_failed` error, storkit detects it instead of passing the raw 401 JSON to the user + +### Auto-refresh (credentials exist, refresh token valid) +- [ ] Storkit reads the OAuth refresh token from `~/.claude/.credentials.json` +- [ ] Storkit calls the Anthropic OAuth token refresh endpoint (`https://console.anthropic.com/v1/oauth/token` with `grant_type=refresh_token`) to obtain a new access token +- [ ] Storkit writes the refreshed access token (and new expiresAt) back to `~/.claude/.credentials.json` +- [ ] After a successful refresh, storkit automatically retries the original chat request +- [ ] The refresh+retry is transparent to the user — they see no error + +### Full login required (no credentials, or refresh token also expired) +- [ ] If `.credentials.json` doesn't exist or the refresh call itself fails, storkit surfaces a clear error: "OAuth session expired. Please run `claude login` to re-authenticate." +- [ ] The error message is surfaced through the normal chat stream (not just server logs) + +## Out of Scope + +- Implementing the full interactive `claude login` browser OAuth flow inside storkit +- Proactive token refresh before expiry (refreshing on demand when the error occurs is sufficient) diff --git a/server/src/llm/mod.rs b/server/src/llm/mod.rs index 431529bd..fb658d7b 100644 --- a/server/src/llm/mod.rs +++ b/server/src/llm/mod.rs @@ -1,4 +1,5 @@ pub mod chat; +pub mod oauth; pub mod prompts; pub mod providers; pub mod types; diff --git a/server/src/llm/oauth.rs b/server/src/llm/oauth.rs new file mode 100644 index 00000000..f986c863 --- /dev/null +++ b/server/src/llm/oauth.rs @@ -0,0 +1,215 @@ +use crate::slog; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +/// The client ID used by Claude Code for OAuth. +const CLAUDE_CODE_CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"; +const TOKEN_ENDPOINT: &str = "https://platform.claude.com/v1/oauth/token"; + +/// OAuth credentials as stored in `~/.claude/.credentials.json`. +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct OAuthCredentials { + pub access_token: String, + pub refresh_token: String, + pub expires_at: u64, + #[serde(default)] + pub scopes: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub subscription_type: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub rate_limit_tier: Option, +} + +/// Top-level structure of `~/.claude/.credentials.json`. +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CredentialsFile { + pub claude_ai_oauth: OAuthCredentials, +} + +/// Response from the Anthropic OAuth token refresh endpoint. +#[derive(Debug, Deserialize)] +struct TokenRefreshResponse { + access_token: String, + expires_in: u64, + #[allow(dead_code)] + token_type: Option, +} + +/// Error from the Anthropic OAuth token refresh endpoint. +#[derive(Debug, Deserialize)] +struct TokenRefreshError { + #[allow(dead_code)] + error: String, + error_description: Option, +} + +/// Returns the path to `~/.claude/.credentials.json`. +fn credentials_path() -> Result { + let home = std::env::var("HOME").map_err(|_| "HOME not set".to_string())?; + Ok(PathBuf::from(home).join(".claude").join(".credentials.json")) +} + +/// Read OAuth credentials from disk. +pub fn read_credentials() -> Result { + let path = credentials_path()?; + let data = std::fs::read_to_string(&path).map_err(|e| { + format!( + "Cannot read {}: {e}. Run `claude login` to authenticate.", + path.display() + ) + })?; + serde_json::from_str(&data).map_err(|e| { + format!( + "Failed to parse {}: {e}", + path.display() + ) + }) +} + +/// Write updated credentials back to disk with 0600 permissions. +pub fn write_credentials(creds: &CredentialsFile) -> Result<(), String> { + let path = credentials_path()?; + let data = serde_json::to_string_pretty(creds) + .map_err(|e| format!("Failed to serialize credentials: {e}"))?; + std::fs::write(&path, &data) + .map_err(|e| format!("Failed to write {}: {e}", path.display()))?; + + // Restore 0600 permissions + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let perms = std::fs::Permissions::from_mode(0o600); + std::fs::set_permissions(&path, perms) + .map_err(|e| format!("Failed to set permissions on {}: {e}", path.display()))?; + } + + Ok(()) +} + +/// Refresh the OAuth access token using the stored refresh token. +/// +/// On success, updates `~/.claude/.credentials.json` with the new access +/// token and expiry, then returns `Ok(())`. +/// +/// On failure (e.g. refresh token expired), returns an error string. +pub async fn refresh_access_token() -> Result<(), String> { + slog!("[oauth] Attempting to refresh OAuth access token"); + + let mut creds = read_credentials()?; + let refresh_token = creds.claude_ai_oauth.refresh_token.clone(); + + if refresh_token.is_empty() { + return Err( + "No refresh token found. Run `claude login` to authenticate.".to_string(), + ); + } + + let client = reqwest::Client::new(); + let resp = client + .post(TOKEN_ENDPOINT) + .json(&serde_json::json!({ + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": CLAUDE_CODE_CLIENT_ID, + })) + .send() + .await + .map_err(|e| format!("OAuth refresh request failed: {e}"))?; + + let status = resp.status(); + let body = resp + .text() + .await + .map_err(|e| format!("Failed to read refresh response: {e}"))?; + + if !status.is_success() { + // Try to parse a structured error + if let Ok(err) = serde_json::from_str::(&body) { + let desc = err + .error_description + .unwrap_or_else(|| "unknown error".to_string()); + slog!("[oauth] Refresh failed: {desc} (full body: {body})"); + return Err(format!( + "OAuth session expired. Please run `claude login` to re-authenticate. ({desc})" + )); + } + return Err(format!( + "OAuth session expired. Please run `claude login` to re-authenticate. (HTTP {status})" + )); + } + + let token_resp: TokenRefreshResponse = serde_json::from_str(&body) + .map_err(|e| format!("Failed to parse refresh response: {e}"))?; + + let now_ms = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0); + + creds.claude_ai_oauth.access_token = token_resp.access_token; + creds.claude_ai_oauth.expires_at = now_ms + (token_resp.expires_in * 1000); + + write_credentials(&creds)?; + + slog!( + "[oauth] Successfully refreshed access token, expires at {}", + creds.claude_ai_oauth.expires_at + ); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_credentials_file() { + let json = r#"{ + "claudeAiOauth": { + "accessToken": "sk-ant-oat01-test", + "refreshToken": "sk-ant-ort01-test", + "expiresAt": 1774466144677, + "scopes": ["user:inference"], + "subscriptionType": "max", + "rateLimitTier": "default_claude_max_20x" + } + }"#; + let creds: CredentialsFile = serde_json::from_str(json).unwrap(); + assert_eq!(creds.claude_ai_oauth.access_token, "sk-ant-oat01-test"); + assert_eq!(creds.claude_ai_oauth.refresh_token, "sk-ant-ort01-test"); + assert_eq!(creds.claude_ai_oauth.expires_at, 1774466144677); + assert_eq!(creds.claude_ai_oauth.subscription_type.as_deref(), Some("max")); + } + + #[test] + fn serialize_credentials_roundtrip() { + let creds = CredentialsFile { + claude_ai_oauth: OAuthCredentials { + access_token: "access".to_string(), + refresh_token: "refresh".to_string(), + expires_at: 12345, + scopes: vec!["user:inference".to_string()], + subscription_type: Some("max".to_string()), + rate_limit_tier: None, + }, + }; + let json = serde_json::to_string(&creds).unwrap(); + let parsed: CredentialsFile = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.claude_ai_oauth.access_token, "access"); + assert_eq!(parsed.claude_ai_oauth.refresh_token, "refresh"); + // rate_limit_tier should be omitted from JSON (skip_serializing_if) + assert!(!json.contains("rateLimitTier")); + } + + #[test] + fn credentials_path_uses_home() { + // Just verify it doesn't panic and returns a path ending in .credentials.json + if std::env::var("HOME").is_ok() { + let path = credentials_path().unwrap(); + assert!(path.ends_with(".claude/.credentials.json")); + } + } +} diff --git a/server/src/llm/providers/claude_code.rs b/server/src/llm/providers/claude_code.rs index eeaed417..32734b52 100644 --- a/server/src/llm/providers/claude_code.rs +++ b/server/src/llm/providers/claude_code.rs @@ -53,10 +53,6 @@ impl ClaudeCodeProvider { T: FnMut(&str) + Send, A: FnMut(&str) + Send, { - let message = user_message.to_string(); - let cwd = project_root.to_string(); - let resume_id = session_id.map(|s| s.to_string()); - let sys_prompt = system_prompt.map(|s| s.to_string()); let cancelled = Arc::new(AtomicBool::new(false)); let cancelled_clone = cancelled.clone(); @@ -70,66 +66,98 @@ impl ClaudeCodeProvider { } }); - let (token_tx, mut token_rx) = tokio::sync::mpsc::unbounded_channel::(); - let (thinking_tx, mut thinking_rx) = tokio::sync::mpsc::unbounded_channel::(); - let (activity_tx, mut activity_rx) = tokio::sync::mpsc::unbounded_channel::(); - let (msg_tx, msg_rx) = std::sync::mpsc::channel::(); - let (sid_tx, sid_rx) = tokio::sync::oneshot::channel::(); + // Attempt up to 2 times: first try, then retry after OAuth refresh. + for attempt in 0..2 { + let message = user_message.to_string(); + let cwd = project_root.to_string(); + let resume_id = session_id.map(|s| s.to_string()); + let sys_prompt = system_prompt.map(|s| s.to_string()); + let cancelled_inner = cancelled.clone(); + let auth_failed = Arc::new(AtomicBool::new(false)); + let auth_failed_clone = auth_failed.clone(); - let pty_handle = tokio::task::spawn_blocking(move || { - run_pty_session( - &message, - &cwd, - resume_id.as_deref(), - sys_prompt.as_deref(), - cancelled, - token_tx, - thinking_tx, - activity_tx, - msg_tx, - sid_tx, - ) - }); + let (token_tx, mut token_rx) = + tokio::sync::mpsc::unbounded_channel::(); + let (thinking_tx, mut thinking_rx) = + tokio::sync::mpsc::unbounded_channel::(); + let (activity_tx, mut activity_rx) = + tokio::sync::mpsc::unbounded_channel::(); + let (msg_tx, msg_rx) = std::sync::mpsc::channel::(); + let (sid_tx, sid_rx) = tokio::sync::oneshot::channel::(); - loop { - tokio::select! { - msg = token_rx.recv() => match msg { - Some(t) => on_token(&t), - None => break, - }, - msg = thinking_rx.recv() => if let Some(t) = msg { - on_thinking(&t); - }, - msg = activity_rx.recv() => if let Some(name) = msg { - on_activity(&name); - }, + let pty_handle = tokio::task::spawn_blocking(move || { + run_pty_session( + &message, + &cwd, + resume_id.as_deref(), + sys_prompt.as_deref(), + cancelled_inner, + auth_failed_clone, + token_tx, + thinking_tx, + activity_tx, + msg_tx, + sid_tx, + ) + }); + + loop { + tokio::select! { + msg = token_rx.recv() => match msg { + Some(t) => on_token(&t), + None => break, + }, + msg = thinking_rx.recv() => if let Some(t) = msg { + on_thinking(&t); + }, + msg = activity_rx.recv() => if let Some(name) = msg { + on_activity(&name); + }, + } } + + // Drain any remaining activity/thinking messages that were buffered + // when the token channel closed. + while let Ok(t) = thinking_rx.try_recv() { + on_thinking(&t); + } + while let Ok(name) = activity_rx.try_recv() { + on_activity(&name); + } + + pty_handle + .await + .map_err(|e| format!("PTY task panicked: {e}"))??; + + // Check if the PTY session failed due to expired OAuth token. + if auth_failed.load(Ordering::Relaxed) && attempt == 0 { + slog!("[oauth] Authentication failed, attempting token refresh"); + match crate::llm::oauth::refresh_access_token().await { + Ok(()) => { + slog!("[oauth] Token refreshed, retrying request"); + on_token("\n*Refreshing authentication token...*\n"); + continue; + } + Err(e) => { + return Err(format!( + "OAuth session expired. Please run `claude login` to re-authenticate. ({e})" + )); + } + } + } + + let captured_session_id = sid_rx.await.ok(); + slog!("[pty-debug] RECEIVED session_id: {:?}", captured_session_id); + let structured_messages: Vec = msg_rx.try_iter().collect(); + + return Ok(ClaudeCodeResult { + messages: structured_messages, + session_id: captured_session_id, + }); } - // Drain any remaining activity/thinking messages that were buffered when - // the token channel closed. - while let Ok(t) = thinking_rx.try_recv() { - on_thinking(&t); - } - // Drain any remaining activity messages that were buffered when the - // token channel closed. The select! loop breaks on token_rx → None, - // but activity_rx may still hold signals sent in the same instant. - while let Ok(name) = activity_rx.try_recv() { - on_activity(&name); - } - - pty_handle - .await - .map_err(|e| format!("PTY task panicked: {e}"))??; - - let captured_session_id = sid_rx.await.ok(); - slog!("[pty-debug] RECEIVED session_id: {:?}", captured_session_id); - let structured_messages: Vec = msg_rx.try_iter().collect(); - - Ok(ClaudeCodeResult { - messages: structured_messages, - session_id: captured_session_id, - }) + // Should never reach here, but just in case. + Err("Authentication failed after retry".to_string()) } } @@ -152,6 +180,7 @@ fn run_pty_session( resume_session_id: Option<&str>, _system_prompt: Option<&str>, cancelled: Arc, + auth_failed: Arc, token_tx: tokio::sync::mpsc::UnboundedSender, thinking_tx: tokio::sync::mpsc::UnboundedSender, activity_tx: tokio::sync::mpsc::UnboundedSender, @@ -278,6 +307,7 @@ fn run_pty_session( &activity_tx, &msg_tx, &mut sid_tx, + &auth_failed, ) { got_result = true; @@ -304,6 +334,7 @@ fn run_pty_session( &activity_tx, &msg_tx, &mut sid_tx, + &auth_failed, ); } } @@ -341,6 +372,7 @@ fn run_pty_session( /// /// Returns `true` if a `result` event was received (signals session completion). /// Captures the session ID from the first event that carries it. +/// Sets `auth_failed` to `true` if an `authentication_failed` error is detected. fn process_json_event( json: &serde_json::Value, token_tx: &tokio::sync::mpsc::UnboundedSender, @@ -348,6 +380,7 @@ fn process_json_event( activity_tx: &tokio::sync::mpsc::UnboundedSender, msg_tx: &std::sync::mpsc::Sender, sid_tx: &mut Option>, + auth_failed: &AtomicBool, ) -> bool { let event_type = match json.get("type").and_then(|t| t.as_str()) { Some(t) => t, @@ -364,6 +397,12 @@ fn process_json_event( } } + // Detect authentication_failed at the top level of any event. + if json.get("error").and_then(|e| e.as_str()) == Some("authentication_failed") { + slog!("[pty-debug] Detected authentication_failed error"); + auth_failed.store(true, Ordering::Relaxed); + } + match event_type { "stream_event" => { if let Some(event) = json.get("event") { @@ -916,7 +955,8 @@ mod tests { &thi_tx, &act_tx, &msg_tx, - &mut sid_tx_opt + &mut sid_tx_opt, + &AtomicBool::new(false), )); } @@ -931,7 +971,8 @@ mod tests { &thi_tx, &act_tx, &msg_tx, - &mut sid_tx + &mut sid_tx, + &AtomicBool::new(false), )); } @@ -946,7 +987,8 @@ mod tests { &thi_tx, &act_tx, &msg_tx, - &mut sid_tx + &mut sid_tx, + &AtomicBool::new(false), )); } @@ -961,7 +1003,8 @@ mod tests { &thi_tx, &act_tx, &msg_tx, - &mut sid_tx + &mut sid_tx, + &AtomicBool::new(false), )); } @@ -976,7 +1019,8 @@ mod tests { &thi_tx, &act_tx, &msg_tx, - &mut sid_tx + &mut sid_tx, + &AtomicBool::new(false), )); } @@ -986,7 +1030,7 @@ mod tests { let (sid_tx, mut sid_rx) = tokio::sync::oneshot::channel::(); let mut sid_tx_opt = Some(sid_tx); let json = json!({"type": "system", "session_id": "sess-abc-123"}); - process_json_event(&json, &tok_tx, &thi_tx, &act_tx, &msg_tx, &mut sid_tx_opt); + process_json_event(&json, &tok_tx, &thi_tx, &act_tx, &msg_tx, &mut sid_tx_opt, &AtomicBool::new(false)); // sid_tx should have been consumed assert!(sid_tx_opt.is_none()); let received = sid_rx.try_recv().unwrap(); @@ -999,7 +1043,7 @@ mod tests { let (sid_tx, _sid_rx) = tokio::sync::oneshot::channel::(); let mut sid_tx_opt = Some(sid_tx); let json = json!({"type": "system"}); - process_json_event(&json, &tok_tx, &thi_tx, &act_tx, &msg_tx, &mut sid_tx_opt); + process_json_event(&json, &tok_tx, &thi_tx, &act_tx, &msg_tx, &mut sid_tx_opt, &AtomicBool::new(false)); // sid_tx should still be present since no session_id in event assert!(sid_tx_opt.is_some()); } @@ -1023,7 +1067,8 @@ mod tests { &thi_tx, &act_tx, &msg_tx, - &mut sid_tx + &mut sid_tx, + &AtomicBool::new(false), )); drop(tok_tx); let tokens: Vec = { @@ -1058,7 +1103,8 @@ mod tests { &thi_tx, &act_tx, &msg_tx, - &mut sid_tx + &mut sid_tx, + &AtomicBool::new(false), )); drop(act_tx); let activities: Vec = { @@ -1091,7 +1137,8 @@ mod tests { &thi_tx, &act_tx, &msg_tx, - &mut sid_tx + &mut sid_tx, + &AtomicBool::new(false), )); drop(act_tx); let activities: Vec = { @@ -1124,7 +1171,8 @@ mod tests { &thi_tx, &act_tx, &msg_tx, - &mut sid_tx + &mut sid_tx, + &AtomicBool::new(false), )); drop(act_tx); let activities: Vec = { @@ -1154,7 +1202,8 @@ mod tests { &thi_tx, &act_tx, &msg_tx, - &mut sid_tx + &mut sid_tx, + &AtomicBool::new(false), )); drop(act_tx); let activities: Vec = { @@ -1183,7 +1232,8 @@ mod tests { &thi_tx, &act_tx, &msg_tx, - &mut sid_tx + &mut sid_tx, + &AtomicBool::new(false), )); drop(msg_tx); let msgs: Vec = msg_rx.try_iter().collect(); @@ -1207,7 +1257,8 @@ mod tests { &thi_tx, &act_tx, &msg_tx, - &mut sid_tx + &mut sid_tx, + &AtomicBool::new(false), )); drop(msg_tx); let msgs: Vec = msg_rx.try_iter().collect(); @@ -1230,7 +1281,8 @@ mod tests { &thi_tx, &act_tx, &msg_tx, - &mut sid_tx + &mut sid_tx, + &AtomicBool::new(false), )); drop(msg_tx); let msgs: Vec = msg_rx.try_iter().collect(); @@ -1248,13 +1300,61 @@ mod tests { &thi_tx, &act_tx, &msg_tx, - &mut sid_tx + &mut sid_tx, + &AtomicBool::new(false), )); drop(msg_tx); let msgs: Vec = msg_rx.try_iter().collect(); assert!(msgs.is_empty()); } + #[test] + fn process_json_event_detects_authentication_failed() { + let (tok_tx, _tok_rx, thi_tx, _thi_rx, act_tx, _act_rx, msg_tx, _msg_rx) = make_channels(); + let mut sid_tx = None::>; + let auth_failed = AtomicBool::new(false); + let json = json!({ + "type": "assistant", + "error": "authentication_failed", + "message": { + "content": [{"type": "text", "text": "Failed to authenticate."}] + } + }); + assert!(!process_json_event( + &json, + &tok_tx, + &thi_tx, + &act_tx, + &msg_tx, + &mut sid_tx, + &auth_failed, + )); + assert!(auth_failed.load(Ordering::Relaxed)); + } + + #[test] + fn process_json_event_no_auth_failed_for_normal_events() { + let (tok_tx, _tok_rx, thi_tx, _thi_rx, act_tx, _act_rx, msg_tx, _msg_rx) = make_channels(); + let mut sid_tx = None::>; + let auth_failed = AtomicBool::new(false); + let json = json!({ + "type": "assistant", + "message": { + "content": [{"type": "text", "text": "Hello!"}] + } + }); + assert!(!process_json_event( + &json, + &tok_tx, + &thi_tx, + &act_tx, + &msg_tx, + &mut sid_tx, + &auth_failed, + )); + assert!(!auth_failed.load(Ordering::Relaxed)); + } + #[test] fn claude_code_provider_new() { let _provider = ClaudeCodeProvider::new();