yoi/crates/provider/src/codex_oauth/refresh.rs

174 lines
5.4 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! HTTP call that refreshes ChatGPT OAuth tokens.
//!
//! Uses the same `POST https://auth.openai.com/oauth/token` format as the Codex CLI.
//! Permanent failures are classified from a 401 status plus `error.code`.
use serde::{Deserialize, Serialize};
use std::time::Duration;
use super::error::{CodexAuthError, PermanentReason};
/// OAuth client identifier that `request_refresh` sends as `client_id`.
pub const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// OAuth token endpoint URL.
///
/// NOTE(review): `request_refresh` takes its endpoint as a parameter; presumably
/// callers pass this constant in production — confirm against the call sites.
pub const REFRESH_URL: &str = "https://auth.openai.com/oauth/token";
/// Timeout (30 seconds) that `request_refresh` applies while waiting for the
/// refresh endpoint to respond.
pub const DEFAULT_REFRESH_TIMEOUT: Duration = Duration::from_secs(30);
/// Body of the refresh request; `request_refresh` serializes it as JSON.
#[derive(Serialize)]
struct RefreshRequest<'a> {
/// OAuth client identifier; `request_refresh` fills this with `CLIENT_ID`.
client_id: &'static str,
/// OAuth grant type; `request_refresh` always sends `"refresh_token"`.
grant_type: &'static str,
/// The refresh token being exchanged, borrowed from the caller.
refresh_token: &'a str,
}
/// Tokens parsed from a successful refresh response.
///
/// Every field is optional: a key that is missing from the JSON body
/// deserializes to `None` instead of failing the whole parse.
/// NOTE(review): presumably a `None` field means "keep the current value" —
/// confirm against the code that consumes this struct.
#[derive(Deserialize, Debug, Default, Clone)]
pub struct RefreshResponse {
/// New ID token, if the endpoint returned one.
#[serde(default)]
pub id_token: Option<String>,
/// New access token, if the endpoint returned one.
#[serde(default)]
pub access_token: Option<String>,
/// Replacement refresh token, if the endpoint returned one.
#[serde(default)]
pub refresh_token: Option<String>,
}
/// Requests a fresh set of tokens using `refresh_token`.
///
/// Sends a JSON `refresh_token` grant to `endpoint` through `client`. Waiting
/// for the response headers and reading the response body are each bounded by
/// [`DEFAULT_REFRESH_TIMEOUT`], so a server that stalls after sending its
/// headers cannot hang the caller indefinitely.
///
/// # Errors
///
/// - [`CodexAuthError::RefreshPermanent`] when the endpoint answers HTTP 401.
///   The reason is derived from the error code in the body
///   (`refresh_token_expired`, `refresh_token_reused`,
///   `refresh_token_invalidated`; anything else maps to `Other`).
/// - [`CodexAuthError::RefreshTransient`] for everything else: send failures,
///   timeouts, a success body that cannot be parsed, and any other
///   non-success status.
pub async fn request_refresh(
    client: &reqwest::Client,
    endpoint: &str,
    refresh_token: &str,
) -> Result<RefreshResponse, CodexAuthError> {
    let payload = RefreshRequest {
        client_id: CLIENT_ID,
        grant_type: "refresh_token",
        refresh_token,
    };
    let response = response_with_timeout(
        client
            .post(endpoint)
            .header("Content-Type", "application/json")
            .json(&payload)
            .send(),
        DEFAULT_REFRESH_TIMEOUT,
    )
    .await?;
    let status = response.status();

    if status.is_success() {
        // `send()` resolves as soon as the headers arrive; the body is read
        // lazily, so it needs its own bound.
        return await_with_timeout(response.json::<RefreshResponse>(), DEFAULT_REFRESH_TIMEOUT)
            .await?
            .map_err(|e| CodexAuthError::RefreshTransient(format!("parse response: {e}")));
    }

    // The error body is only used for classification and diagnostics, so a
    // read failure or timeout deliberately degrades to an empty string.
    let error_body = await_with_timeout(response.text(), DEFAULT_REFRESH_TIMEOUT)
        .await
        .ok()
        .and_then(Result::ok)
        .unwrap_or_default();

    if status == reqwest::StatusCode::UNAUTHORIZED {
        let (reason, message) = classify_permanent(&error_body);
        Err(CodexAuthError::RefreshPermanent { reason, message })
    } else {
        Err(CodexAuthError::RefreshTransient(format!(
            "{status}: {error_body}"
        )))
    }
}

/// Awaits `future` for at most `timeout`.
///
/// Returns the future's output unchanged, or a
/// [`CodexAuthError::RefreshTransient`] describing the timeout if it elapses.
async fn await_with_timeout<T>(
    future: impl std::future::Future<Output = T>,
    timeout: Duration,
) -> Result<T, CodexAuthError> {
    tokio::time::timeout(timeout, future).await.map_err(|_| {
        CodexAuthError::RefreshTransient(format!(
            "codex_oauth_refresh timed out after {}s",
            timeout.as_secs()
        ))
    })
}

/// Awaits the in-flight request `future` for at most `timeout`.
///
/// Both an elapsed timeout and a transport error from the request are
/// reported as [`CodexAuthError::RefreshTransient`].
async fn response_with_timeout(
    future: impl std::future::Future<Output = Result<reqwest::Response, reqwest::Error>>,
    timeout: Duration,
) -> Result<reqwest::Response, CodexAuthError> {
    await_with_timeout(future, timeout)
        .await?
        .map_err(|e| CodexAuthError::RefreshTransient(format!("send: {e}")))
}
/// Maps the body of a 401 refresh response to a permanent-failure reason and
/// a human-readable message.
///
/// The error code is pulled out of `body` by `extract_error_code`. Codes that
/// are missing or unrecognised yield `PermanentReason::Other`, whose message
/// embeds the raw body for diagnostics.
fn classify_permanent(body: &str) -> (PermanentReason, String) {
    match extract_error_code(body).as_deref() {
        Some("refresh_token_expired") => (
            PermanentReason::Expired,
            String::from("Your refresh token has expired"),
        ),
        Some("refresh_token_reused") => (
            PermanentReason::Reused,
            String::from("Your refresh token was already used"),
        ),
        Some("refresh_token_invalidated") => (
            PermanentReason::Revoked,
            String::from("Your refresh token was revoked"),
        ),
        _ => (
            PermanentReason::Other,
            format!("Unknown 401 from refresh endpoint: {body}"),
        ),
    }
}
/// Extracts an error code string from a JSON error body.
///
/// Lookup order:
/// 1. `error.code`, when `error` is an object with a string `code`;
/// 2. `error` itself, when it is a plain string;
/// 3. a top-level string `code`.
///
/// Returns `None` when `body` is not valid JSON or none of these are present.
fn extract_error_code(body: &str) -> Option<String> {
    let parsed: serde_json::Value = serde_json::from_str(body).ok()?;

    let from_error = parsed.get("error").and_then(|err| match err {
        serde_json::Value::Object(fields) => {
            fields.get("code").and_then(serde_json::Value::as_str)
        }
        serde_json::Value::String(text) => Some(text.as_str()),
        _ => None,
    });

    from_error
        .or_else(|| parsed.get("code").and_then(serde_json::Value::as_str))
        .map(str::to_owned)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies `body` and returns only the resulting reason.
    fn reason_for(body: &str) -> PermanentReason {
        classify_permanent(body).0
    }

    /// A request future that never resolves must surface as a transient
    /// timeout error.
    #[tokio::test]
    async fn refresh_response_timeout_is_transient() {
        let never = std::future::pending::<Result<reqwest::Response, reqwest::Error>>();
        let err = response_with_timeout(never, Duration::from_millis(5))
            .await
            .expect_err("expected refresh timeout");
        assert!(
            matches!(err, CodexAuthError::RefreshTransient(message) if message.contains("timed out"))
        );
    }

    /// A nested `error.code` of `refresh_token_expired` maps to `Expired`.
    #[test]
    fn classify_expired() {
        assert_eq!(
            reason_for(r#"{"error":{"code":"refresh_token_expired"}}"#),
            PermanentReason::Expired
        );
    }

    /// A nested `error.code` of `refresh_token_reused` maps to `Reused`.
    #[test]
    fn classify_reused() {
        assert_eq!(
            reason_for(r#"{"error":{"code":"refresh_token_reused"}}"#),
            PermanentReason::Reused
        );
    }

    /// An unrecognised code falls back to `Other`.
    #[test]
    fn classify_unknown_falls_to_other() {
        assert_eq!(
            reason_for(r#"{"error":{"code":"weird"}}"#),
            PermanentReason::Other
        );
    }

    /// A top-level `code` is honoured when there is no `error` key.
    #[test]
    fn classify_top_level_code() {
        assert_eq!(
            reason_for(r#"{"code":"refresh_token_invalidated"}"#),
            PermanentReason::Revoked
        );
    }
}