fix: align codex oauth wire behavior
This commit is contained in:
parent
00596d3f9a
commit
876d75a747
29
Cargo.lock
generated
29
Cargo.lock
generated
|
|
@ -1671,6 +1671,7 @@ dependencies = [
|
|||
"tracing-subscriber",
|
||||
"trybuild",
|
||||
"wiremock",
|
||||
"zstd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -4542,3 +4543,31 @@ name = "zmij"
|
|||
version = "1.0.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa"
|
||||
|
||||
[[package]]
|
||||
name = "zstd"
|
||||
version = "0.13.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a"
|
||||
dependencies = [
|
||||
"zstd-safe",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zstd-safe"
|
||||
version = "7.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d"
|
||||
dependencies = [
|
||||
"zstd-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zstd-sys"
|
||||
version = "2.0.16+zstd.1.5.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"pkg-config",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
|
|||
tokio-util = "0.7"
|
||||
reqwest = { version = "0.13", default-features = false, features = ["stream", "json", "native-tls", "http2"] }
|
||||
eventsource-stream = "0.2"
|
||||
zstd = "0.13"
|
||||
llm-worker-macros = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
|
|
|
|||
|
|
@ -45,4 +45,13 @@ pub enum AuthRequirement {
|
|||
pub trait AuthProvider: Send + Sync + std::fmt::Debug {
|
||||
/// 1 リクエスト分の認証ヘッダを返す。refresh が必要なら内部で行う。
|
||||
async fn headers(&self) -> Result<Vec<(HeaderName, HeaderValue)>, ClientError>;
|
||||
|
||||
/// ChatGPT Codex backend 向けの複合認証かどうか。
|
||||
///
|
||||
/// transport は provider crate の具象型を知らないため、この hook だけで
|
||||
/// Codex CLI 互換の wire behavior(conversation header / request compression 等)
|
||||
/// を切り替える。
|
||||
fn is_codex_backend(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,9 @@ use std::time::Duration;
|
|||
use async_trait::async_trait;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::{Stream, StreamExt, TryStreamExt};
|
||||
use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue, RETRY_AFTER};
|
||||
use reqwest::header::{
|
||||
ACCEPT, CONTENT_ENCODING, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, RETRY_AFTER,
|
||||
};
|
||||
|
||||
use super::auth::{AuthProvider, AuthRequirement};
|
||||
use super::capability::ModelCapability;
|
||||
|
|
@ -149,6 +151,54 @@ impl<S: Scheme> HttpTransport<S> {
|
|||
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
fn is_codex_backend(&self) -> bool {
|
||||
match &self.auth {
|
||||
ResolvedAuth::Custom(provider) => provider.is_codex_backend(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_stream_headers(
|
||||
&self,
|
||||
headers: &mut HeaderMap,
|
||||
request: &Request,
|
||||
) -> Result<(), ClientError> {
|
||||
headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
|
||||
|
||||
if self.is_codex_backend()
|
||||
&& let Some(cache_key) = request.cache_key.as_deref()
|
||||
{
|
||||
let value = HeaderValue::from_str(cache_key).map_err(|e| {
|
||||
ClientError::Config(format!("invalid Codex conversation header: {e}"))
|
||||
})?;
|
||||
headers.insert(HeaderName::from_static("session_id"), value.clone());
|
||||
headers.insert(HeaderName::from_static("x-client-request-id"), value);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn encode_request_body(
|
||||
&self,
|
||||
body: &serde_json::Value,
|
||||
headers: &mut HeaderMap,
|
||||
) -> Result<RequestBody, ClientError> {
|
||||
if !self.is_codex_backend() {
|
||||
return Ok(RequestBody::Json(body.clone()));
|
||||
}
|
||||
|
||||
let raw = serde_json::to_vec(body)?;
|
||||
let compressed = zstd::stream::encode_all(std::io::Cursor::new(raw), 3)
|
||||
.map_err(|e| ClientError::Config(format!("failed to zstd-compress request: {e}")))?;
|
||||
headers.insert(CONTENT_ENCODING, HeaderValue::from_static("zstd"));
|
||||
Ok(RequestBody::CompressedJson(compressed))
|
||||
}
|
||||
}
|
||||
|
||||
enum RequestBody {
|
||||
Json(serde_json::Value),
|
||||
CompressedJson(Vec<u8>),
|
||||
}
|
||||
|
||||
impl<S: Scheme + Clone> Clone for HttpTransport<S> {
|
||||
|
|
@ -210,19 +260,19 @@ impl<S: Scheme + Clone + 'static> LlmClient for HttpTransport<S> {
|
|||
|
||||
async fn stream(&self, request: Request) -> Result<ResponseStream, ClientError> {
|
||||
let url = self.build_url();
|
||||
let headers = self.build_headers().await?;
|
||||
let mut headers = self.build_headers().await?;
|
||||
self.apply_stream_headers(&mut headers, &request)?;
|
||||
let body = self
|
||||
.scheme
|
||||
.build_request_body(&self.model_id, &request, &self.capability);
|
||||
let request_body = self.encode_request_body(&body, &mut headers)?;
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.post(&url)
|
||||
.headers(headers)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(ClientError::Http)?;
|
||||
let builder = self.http_client.post(&url).headers(headers);
|
||||
let builder = match request_body {
|
||||
RequestBody::Json(body) => builder.json(&body),
|
||||
RequestBody::CompressedJson(body) => builder.body(body),
|
||||
};
|
||||
let response = builder.send().await.map_err(ClientError::Http)?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(classify_error_response(response).await);
|
||||
|
|
@ -255,3 +305,145 @@ impl<S: Scheme + Clone + 'static> LlmClient for HttpTransport<S> {
|
|||
Ok(Box::pin(stream))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TestAuthProvider {
|
||||
codex: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthProvider for TestAuthProvider {
|
||||
async fn headers(&self) -> Result<Vec<(HeaderName, HeaderValue)>, ClientError> {
|
||||
Ok(vec![
|
||||
(
|
||||
HeaderName::from_static("authorization"),
|
||||
HeaderValue::from_static("Bearer test-token"),
|
||||
),
|
||||
(
|
||||
HeaderName::from_static("chatgpt-account-id"),
|
||||
HeaderValue::from_static("account-1"),
|
||||
),
|
||||
])
|
||||
}
|
||||
|
||||
fn is_codex_backend(&self) -> bool {
|
||||
self.codex
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct TestScheme;
|
||||
|
||||
impl Scheme for TestScheme {
|
||||
type State = ();
|
||||
|
||||
fn default_base_url(&self) -> &'static str {
|
||||
"https://example.test"
|
||||
}
|
||||
|
||||
fn path(&self, _model_id: &str) -> String {
|
||||
"/responses".to_string()
|
||||
}
|
||||
|
||||
fn required_auth(&self) -> AuthRequirement {
|
||||
AuthRequirement::Bearer
|
||||
}
|
||||
|
||||
fn build_request_body(
|
||||
&self,
|
||||
model_id: &str,
|
||||
request: &Request,
|
||||
_capability: &ModelCapability,
|
||||
) -> serde_json::Value {
|
||||
json!({
|
||||
"model": model_id,
|
||||
"input_len": request.items.len(),
|
||||
"prompt_cache_key": request.cache_key,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_sse(
|
||||
&self,
|
||||
_event_type: &str,
|
||||
_data: &str,
|
||||
_state: &mut Self::State,
|
||||
) -> Result<Vec<Event>, ClientError> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
fn default_capability(&self) -> ModelCapability {
|
||||
ModelCapability::minimal()
|
||||
}
|
||||
}
|
||||
|
||||
fn transport(auth: ResolvedAuth) -> HttpTransport<TestScheme> {
|
||||
HttpTransport::new(
|
||||
TestScheme,
|
||||
"gpt-test",
|
||||
"https://example.test",
|
||||
auth,
|
||||
ModelCapability::minimal(),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn codex_backend_adds_conversation_headers_and_zstd_body() {
|
||||
let transport = transport(ResolvedAuth::Custom(Arc::new(TestAuthProvider {
|
||||
codex: true,
|
||||
})));
|
||||
let request = Request::new().user("hello").cache_key("segment-123");
|
||||
let mut headers = transport.build_headers().await.unwrap();
|
||||
transport
|
||||
.apply_stream_headers(&mut headers, &request)
|
||||
.unwrap();
|
||||
let body = transport.scheme.build_request_body(
|
||||
&transport.model_id,
|
||||
&request,
|
||||
&transport.capability,
|
||||
);
|
||||
let encoded = transport.encode_request_body(&body, &mut headers).unwrap();
|
||||
|
||||
assert_eq!(headers.get(ACCEPT).unwrap(), "text/event-stream");
|
||||
assert_eq!(headers.get("session_id").unwrap(), "segment-123");
|
||||
assert_eq!(headers.get("x-client-request-id").unwrap(), "segment-123");
|
||||
assert_eq!(headers.get(CONTENT_ENCODING).unwrap(), "zstd");
|
||||
|
||||
let RequestBody::CompressedJson(compressed) = encoded else {
|
||||
panic!("Codex backend request body must be zstd-compressed");
|
||||
};
|
||||
let decoded = zstd::stream::decode_all(std::io::Cursor::new(compressed)).unwrap();
|
||||
let decoded: serde_json::Value = serde_json::from_slice(&decoded).unwrap();
|
||||
assert_eq!(decoded["prompt_cache_key"], "segment-123");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn non_codex_request_does_not_get_codex_only_headers_or_compression() {
|
||||
let transport = transport(ResolvedAuth::ApiKey("api-key".to_string()));
|
||||
let request = Request::new().user("hello").cache_key("segment-123");
|
||||
let mut headers = transport.build_headers().await.unwrap();
|
||||
transport
|
||||
.apply_stream_headers(&mut headers, &request)
|
||||
.unwrap();
|
||||
let body = transport.scheme.build_request_body(
|
||||
&transport.model_id,
|
||||
&request,
|
||||
&transport.capability,
|
||||
);
|
||||
let encoded = transport.encode_request_body(&body, &mut headers).unwrap();
|
||||
|
||||
assert_eq!(headers.get(ACCEPT).unwrap(), "text/event-stream");
|
||||
assert!(headers.get("session_id").is_none());
|
||||
assert!(headers.get("x-client-request-id").is_none());
|
||||
assert!(headers.get(CONTENT_ENCODING).is_none());
|
||||
|
||||
let RequestBody::Json(decoded) = encoded else {
|
||||
panic!("non-Codex request body must remain normal JSON");
|
||||
};
|
||||
assert_eq!(decoded["prompt_cache_key"], "segment-123");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -188,6 +188,10 @@ impl AuthProvider for CodexAuthProvider {
|
|||
.map_err(CodexAuthError::to_client_error)?;
|
||||
Self::build_headers(&snap).map_err(CodexAuthError::to_client_error)
|
||||
}
|
||||
|
||||
fn is_codex_backend(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// `access_token` の JWT `exp` を見て、期限切れなら true。
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user