fix: align codex oauth wire behavior

This commit is contained in:
Keisuke Hirata 2026-05-28 01:57:04 +09:00
parent 5ae886ea99
commit b3c739867e
5 changed files with 245 additions and 10 deletions

29
Cargo.lock generated
View File

@ -1671,6 +1671,7 @@ dependencies = [
"tracing-subscriber",
"trybuild",
"wiremock",
"zstd",
]
[[package]]
@ -4542,3 +4543,31 @@ name = "zmij"
version = "1.0.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa"
[[package]]
name = "zstd"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a"
dependencies = [
"zstd-safe",
]
[[package]]
name = "zstd-safe"
version = "7.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d"
dependencies = [
"zstd-sys",
]
[[package]]
name = "zstd-sys"
version = "2.0.16+zstd.1.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748"
dependencies = [
"cc",
"pkg-config",
]

View File

@ -16,6 +16,7 @@ tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
tokio-util = "0.7"
reqwest = { version = "0.13", default-features = false, features = ["stream", "json", "native-tls", "http2"] }
eventsource-stream = "0.2"
zstd = "0.13"
llm-worker-macros = { workspace = true }
[dev-dependencies]

View File

@ -45,4 +45,13 @@ pub enum AuthRequirement {
pub trait AuthProvider: Send + Sync + std::fmt::Debug {
/// 1 リクエスト分の認証ヘッダを返す。refresh が必要なら内部で行う。
async fn headers(&self) -> Result<Vec<(HeaderName, HeaderValue)>, ClientError>;
/// ChatGPT Codex backend 向けの複合認証かどうか。
///
/// transport は provider crate の具象型を知らないため、この hook だけで
/// Codex CLI 互換の wire behaviorconversation header / request compression 等)
/// を切り替える。
fn is_codex_backend(&self) -> bool {
false
}
}

View File

@ -11,7 +11,9 @@ use std::time::Duration;
use async_trait::async_trait;
use eventsource_stream::Eventsource;
use futures::{Stream, StreamExt, TryStreamExt};
use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue, RETRY_AFTER};
use reqwest::header::{
ACCEPT, CONTENT_ENCODING, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, RETRY_AFTER,
};
use super::auth::{AuthProvider, AuthRequirement};
use super::capability::ModelCapability;
@ -149,6 +151,54 @@ impl<S: Scheme> HttpTransport<S> {
Ok(headers)
}
fn is_codex_backend(&self) -> bool {
match &self.auth {
ResolvedAuth::Custom(provider) => provider.is_codex_backend(),
_ => false,
}
}
fn apply_stream_headers(
&self,
headers: &mut HeaderMap,
request: &Request,
) -> Result<(), ClientError> {
headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
if self.is_codex_backend()
&& let Some(cache_key) = request.cache_key.as_deref()
{
let value = HeaderValue::from_str(cache_key).map_err(|e| {
ClientError::Config(format!("invalid Codex conversation header: {e}"))
})?;
headers.insert(HeaderName::from_static("session_id"), value.clone());
headers.insert(HeaderName::from_static("x-client-request-id"), value);
}
Ok(())
}
fn encode_request_body(
&self,
body: &serde_json::Value,
headers: &mut HeaderMap,
) -> Result<RequestBody, ClientError> {
if !self.is_codex_backend() {
return Ok(RequestBody::Json(body.clone()));
}
let raw = serde_json::to_vec(body)?;
let compressed = zstd::stream::encode_all(std::io::Cursor::new(raw), 3)
.map_err(|e| ClientError::Config(format!("failed to zstd-compress request: {e}")))?;
headers.insert(CONTENT_ENCODING, HeaderValue::from_static("zstd"));
Ok(RequestBody::CompressedJson(compressed))
}
}
enum RequestBody {
Json(serde_json::Value),
CompressedJson(Vec<u8>),
}
impl<S: Scheme + Clone> Clone for HttpTransport<S> {
@ -210,19 +260,19 @@ impl<S: Scheme + Clone + 'static> LlmClient for HttpTransport<S> {
async fn stream(&self, request: Request) -> Result<ResponseStream, ClientError> {
let url = self.build_url();
let headers = self.build_headers().await?;
let mut headers = self.build_headers().await?;
self.apply_stream_headers(&mut headers, &request)?;
let body = self
.scheme
.build_request_body(&self.model_id, &request, &self.capability);
let request_body = self.encode_request_body(&body, &mut headers)?;
let response = self
.http_client
.post(&url)
.headers(headers)
.json(&body)
.send()
.await
.map_err(ClientError::Http)?;
let builder = self.http_client.post(&url).headers(headers);
let builder = match request_body {
RequestBody::Json(body) => builder.json(&body),
RequestBody::CompressedJson(body) => builder.body(body),
};
let response = builder.send().await.map_err(ClientError::Http)?;
if !response.status().is_success() {
return Err(classify_error_response(response).await);
@ -255,3 +305,145 @@ impl<S: Scheme + Clone + 'static> LlmClient for HttpTransport<S> {
Ok(Box::pin(stream))
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[derive(Debug)]
struct TestAuthProvider {
codex: bool,
}
#[async_trait]
impl AuthProvider for TestAuthProvider {
async fn headers(&self) -> Result<Vec<(HeaderName, HeaderValue)>, ClientError> {
Ok(vec![
(
HeaderName::from_static("authorization"),
HeaderValue::from_static("Bearer test-token"),
),
(
HeaderName::from_static("chatgpt-account-id"),
HeaderValue::from_static("account-1"),
),
])
}
fn is_codex_backend(&self) -> bool {
self.codex
}
}
#[derive(Clone)]
struct TestScheme;
impl Scheme for TestScheme {
type State = ();
fn default_base_url(&self) -> &'static str {
"https://example.test"
}
fn path(&self, _model_id: &str) -> String {
"/responses".to_string()
}
fn required_auth(&self) -> AuthRequirement {
AuthRequirement::Bearer
}
fn build_request_body(
&self,
model_id: &str,
request: &Request,
_capability: &ModelCapability,
) -> serde_json::Value {
json!({
"model": model_id,
"input_len": request.items.len(),
"prompt_cache_key": request.cache_key,
})
}
fn parse_sse(
&self,
_event_type: &str,
_data: &str,
_state: &mut Self::State,
) -> Result<Vec<Event>, ClientError> {
Ok(Vec::new())
}
fn default_capability(&self) -> ModelCapability {
ModelCapability::minimal()
}
}
fn transport(auth: ResolvedAuth) -> HttpTransport<TestScheme> {
HttpTransport::new(
TestScheme,
"gpt-test",
"https://example.test",
auth,
ModelCapability::minimal(),
)
}
#[tokio::test]
async fn codex_backend_adds_conversation_headers_and_zstd_body() {
let transport = transport(ResolvedAuth::Custom(Arc::new(TestAuthProvider {
codex: true,
})));
let request = Request::new().user("hello").cache_key("segment-123");
let mut headers = transport.build_headers().await.unwrap();
transport
.apply_stream_headers(&mut headers, &request)
.unwrap();
let body = transport.scheme.build_request_body(
&transport.model_id,
&request,
&transport.capability,
);
let encoded = transport.encode_request_body(&body, &mut headers).unwrap();
assert_eq!(headers.get(ACCEPT).unwrap(), "text/event-stream");
assert_eq!(headers.get("session_id").unwrap(), "segment-123");
assert_eq!(headers.get("x-client-request-id").unwrap(), "segment-123");
assert_eq!(headers.get(CONTENT_ENCODING).unwrap(), "zstd");
let RequestBody::CompressedJson(compressed) = encoded else {
panic!("Codex backend request body must be zstd-compressed");
};
let decoded = zstd::stream::decode_all(std::io::Cursor::new(compressed)).unwrap();
let decoded: serde_json::Value = serde_json::from_slice(&decoded).unwrap();
assert_eq!(decoded["prompt_cache_key"], "segment-123");
}
#[tokio::test]
async fn non_codex_request_does_not_get_codex_only_headers_or_compression() {
let transport = transport(ResolvedAuth::ApiKey("api-key".to_string()));
let request = Request::new().user("hello").cache_key("segment-123");
let mut headers = transport.build_headers().await.unwrap();
transport
.apply_stream_headers(&mut headers, &request)
.unwrap();
let body = transport.scheme.build_request_body(
&transport.model_id,
&request,
&transport.capability,
);
let encoded = transport.encode_request_body(&body, &mut headers).unwrap();
assert_eq!(headers.get(ACCEPT).unwrap(), "text/event-stream");
assert!(headers.get("session_id").is_none());
assert!(headers.get("x-client-request-id").is_none());
assert!(headers.get(CONTENT_ENCODING).is_none());
let RequestBody::Json(decoded) = encoded else {
panic!("non-Codex request body must remain normal JSON");
};
assert_eq!(decoded["prompt_cache_key"], "segment-123");
}
}

View File

@ -188,6 +188,10 @@ impl AuthProvider for CodexAuthProvider {
.map_err(CodexAuthError::to_client_error)?;
Self::build_headers(&snap).map_err(CodexAuthError::to_client_error)
}
fn is_codex_backend(&self) -> bool {
true
}
}
/// `access_token` の JWT `exp` を見て、期限切れなら true。