From 876d75a7471279f3a572a0fcbb51690735ba34a0 Mon Sep 17 00:00:00 2001 From: Hare Date: Thu, 28 May 2026 01:57:04 +0900 Subject: [PATCH] fix: align codex oauth wire behavior --- Cargo.lock | 29 +++ crates/llm-worker/Cargo.toml | 1 + crates/llm-worker/src/llm_client/auth.rs | 9 + crates/llm-worker/src/llm_client/transport.rs | 212 +++++++++++++++++- crates/provider/src/codex_oauth/mod.rs | 4 + 5 files changed, 245 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4a1cbde9..3b5f1708 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1671,6 +1671,7 @@ dependencies = [ "tracing-subscriber", "trybuild", "wiremock", + "zstd", ] [[package]] @@ -4542,3 +4543,31 @@ name = "zmij" version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.16+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/crates/llm-worker/Cargo.toml b/crates/llm-worker/Cargo.toml index f40776f3..f3bc5940 100644 --- a/crates/llm-worker/Cargo.toml +++ b/crates/llm-worker/Cargo.toml @@ -16,6 +16,7 @@ tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } tokio-util = "0.7" reqwest = { version = "0.13", default-features = false, features = ["stream", "json", "native-tls", "http2"] } eventsource-stream = "0.2" +zstd = "0.13" llm-worker-macros = { workspace = true } [dev-dependencies] diff --git a/crates/llm-worker/src/llm_client/auth.rs b/crates/llm-worker/src/llm_client/auth.rs index 7d022f7f..a94b06a0 100644 --- a/crates/llm-worker/src/llm_client/auth.rs +++ b/crates/llm-worker/src/llm_client/auth.rs @@ -45,4 +45,13 @@ pub enum AuthRequirement { pub trait AuthProvider: Send + Sync + std::fmt::Debug { /// 1 リクエスト分の認証ヘッダを返す。refresh が必要なら内部で行う。 async fn headers(&self) -> Result, ClientError>; + + /// ChatGPT Codex backend 向けの複合認証かどうか。 + /// + /// transport は provider crate の具象型を知らないため、この hook だけで + /// Codex CLI 互換の wire behavior(conversation header / request compression 等) + /// を切り替える。 + fn is_codex_backend(&self) -> bool { + false + } } diff --git a/crates/llm-worker/src/llm_client/transport.rs b/crates/llm-worker/src/llm_client/transport.rs index 3d32c3cb..334d883e 100644 --- a/crates/llm-worker/src/llm_client/transport.rs +++ b/crates/llm-worker/src/llm_client/transport.rs @@ -11,7 +11,9 @@ use std::time::Duration; use async_trait::async_trait; use eventsource_stream::Eventsource; use futures::{Stream, StreamExt, TryStreamExt}; -use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue, RETRY_AFTER}; +use reqwest::header::{ + ACCEPT, CONTENT_ENCODING, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, RETRY_AFTER, +}; use super::auth::{AuthProvider, AuthRequirement}; use super::capability::ModelCapability; @@ -149,6 +151,54 @@ impl HttpTransport { Ok(headers) } + + fn is_codex_backend(&self) -> bool { + match &self.auth { + ResolvedAuth::Custom(provider) => provider.is_codex_backend(), + _ => false, + } + } + + fn apply_stream_headers( + &self, + headers: &mut HeaderMap, + request: &Request, + ) -> Result<(), ClientError> { + headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream")); + + if self.is_codex_backend() + && let Some(cache_key) = request.cache_key.as_deref() + { + let value = HeaderValue::from_str(cache_key).map_err(|e| { + ClientError::Config(format!("invalid Codex conversation header: {e}")) + })?; + headers.insert(HeaderName::from_static("session_id"), value.clone()); + headers.insert(HeaderName::from_static("x-client-request-id"), value); + } + + Ok(()) + } + + fn encode_request_body( + &self, + body: &serde_json::Value, + headers: &mut HeaderMap, + ) -> Result { + if !self.is_codex_backend() { + return Ok(RequestBody::Json(body.clone())); + } + + let raw = serde_json::to_vec(body)?; + let compressed = zstd::stream::encode_all(std::io::Cursor::new(raw), 3) + .map_err(|e| ClientError::Config(format!("failed to zstd-compress request: {e}")))?; + headers.insert(CONTENT_ENCODING, HeaderValue::from_static("zstd")); + Ok(RequestBody::CompressedJson(compressed)) + } +} + +enum RequestBody { + Json(serde_json::Value), + CompressedJson(Vec), } impl Clone for HttpTransport { @@ -210,19 +260,19 @@ impl LlmClient for HttpTransport { async fn stream(&self, request: Request) -> Result { let url = self.build_url(); - let headers = self.build_headers().await?; + let mut headers = self.build_headers().await?; + self.apply_stream_headers(&mut headers, &request)?; let body = self .scheme .build_request_body(&self.model_id, &request, &self.capability); + let request_body = self.encode_request_body(&body, &mut headers)?; - let response = self - .http_client - .post(&url) - .headers(headers) - .json(&body) - .send() - .await - .map_err(ClientError::Http)?; + let builder = self.http_client.post(&url).headers(headers); + let builder = match request_body { + RequestBody::Json(body) => builder.json(&body), + RequestBody::CompressedJson(body) => builder.body(body), + }; + let response = builder.send().await.map_err(ClientError::Http)?; if !response.status().is_success() { return Err(classify_error_response(response).await); @@ -255,3 +305,145 @@ impl LlmClient for HttpTransport { Ok(Box::pin(stream)) } } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[derive(Debug)] + struct TestAuthProvider { + codex: bool, + } + + #[async_trait] + impl AuthProvider for TestAuthProvider { + async fn headers(&self) -> Result, ClientError> { + Ok(vec![ + ( + HeaderName::from_static("authorization"), + HeaderValue::from_static("Bearer test-token"), + ), + ( + HeaderName::from_static("chatgpt-account-id"), + HeaderValue::from_static("account-1"), + ), + ]) + } + + fn is_codex_backend(&self) -> bool { + self.codex + } + } + + #[derive(Clone)] + struct TestScheme; + + impl Scheme for TestScheme { + type State = (); + + fn default_base_url(&self) -> &'static str { + "https://example.test" + } + + fn path(&self, _model_id: &str) -> String { + "/responses".to_string() + } + + fn required_auth(&self) -> AuthRequirement { + AuthRequirement::Bearer + } + + fn build_request_body( + &self, + model_id: &str, + request: &Request, + _capability: &ModelCapability, + ) -> serde_json::Value { + json!({ + "model": model_id, + "input_len": request.items.len(), + "prompt_cache_key": request.cache_key, + }) + } + + fn parse_sse( + &self, + _event_type: &str, + _data: &str, + _state: &mut Self::State, + ) -> Result, ClientError> { + Ok(Vec::new()) + } + + fn default_capability(&self) -> ModelCapability { + ModelCapability::minimal() + } + } + + fn transport(auth: ResolvedAuth) -> HttpTransport { + HttpTransport::new( + TestScheme, + "gpt-test", + "https://example.test", + auth, + ModelCapability::minimal(), + ) + } + + #[tokio::test] + async fn codex_backend_adds_conversation_headers_and_zstd_body() { + let transport = transport(ResolvedAuth::Custom(Arc::new(TestAuthProvider { + codex: true, + }))); + let request = Request::new().user("hello").cache_key("segment-123"); + let mut headers = transport.build_headers().await.unwrap(); + transport + .apply_stream_headers(&mut headers, &request) + .unwrap(); + let body = transport.scheme.build_request_body( + &transport.model_id, + &request, + &transport.capability, + ); + let encoded = transport.encode_request_body(&body, &mut headers).unwrap(); + + assert_eq!(headers.get(ACCEPT).unwrap(), "text/event-stream"); + assert_eq!(headers.get("session_id").unwrap(), "segment-123"); + assert_eq!(headers.get("x-client-request-id").unwrap(), "segment-123"); + assert_eq!(headers.get(CONTENT_ENCODING).unwrap(), "zstd"); + + let RequestBody::CompressedJson(compressed) = encoded else { + panic!("Codex backend request body must be zstd-compressed"); + }; + let decoded = zstd::stream::decode_all(std::io::Cursor::new(compressed)).unwrap(); + let decoded: serde_json::Value = serde_json::from_slice(&decoded).unwrap(); + assert_eq!(decoded["prompt_cache_key"], "segment-123"); + } + + #[tokio::test] + async fn non_codex_request_does_not_get_codex_only_headers_or_compression() { + let transport = transport(ResolvedAuth::ApiKey("api-key".to_string())); + let request = Request::new().user("hello").cache_key("segment-123"); + let mut headers = transport.build_headers().await.unwrap(); + transport + .apply_stream_headers(&mut headers, &request) + .unwrap(); + let body = transport.scheme.build_request_body( + &transport.model_id, + &request, + &transport.capability, + ); + let encoded = transport.encode_request_body(&body, &mut headers).unwrap(); + + assert_eq!(headers.get(ACCEPT).unwrap(), "text/event-stream"); + assert!(headers.get("session_id").is_none()); + assert!(headers.get("x-client-request-id").is_none()); + assert!(headers.get(CONTENT_ENCODING).is_none()); + + let RequestBody::Json(decoded) = encoded else { + panic!("non-Codex request body must remain normal JSON"); + }; + assert_eq!(decoded["prompt_cache_key"], "segment-123"); + } +} diff --git a/crates/provider/src/codex_oauth/mod.rs b/crates/provider/src/codex_oauth/mod.rs index 3a593c47..fbcee8db 100644 --- a/crates/provider/src/codex_oauth/mod.rs +++ b/crates/provider/src/codex_oauth/mod.rs @@ -188,6 +188,10 @@ impl AuthProvider for CodexAuthProvider { .map_err(CodexAuthError::to_client_error)?; Self::build_headers(&snap).map_err(CodexAuthError::to_client_error) } + + fn is_codex_backend(&self) -> bool { + true + } } /// `access_token` の JWT `exp` を見て、期限切れなら true。