yoi/crates/llm-worker/src/llm_client/transport.rs

863 lines
30 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! `HttpTransport<S: Scheme>`: すべての LLM wire scheme を共通の 1 本の
//! HTTP クライアントで扱う。
//!
//! 旧 `providers/{anthropic,openai,gemini,ollama}.rs` を置き換える。
//! scheme 固有の差分は [`Scheme`] trait 実装に委譲する。
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use eventsource_stream::Eventsource;
use futures::{Stream, StreamExt, TryStreamExt};
use reqwest::header::{
ACCEPT, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue,
RETRY_AFTER, TRANSFER_ENCODING,
};
use serde_json::{Map, Value, json};
use super::auth::{AuthProvider, AuthRequirement};
use super::capability::ModelCapability;
use super::client::{ConfigWarning, LlmClient, ResponseStream};
use super::error::ClientError;
use super::event::Event;
use super::scheme::Scheme;
use super::types::{Request, RequestConfig};
/// Deadline for sending the request and receiving response headers
/// (the `"stream_open"` phase in `HttpTransport::stream`): 20 seconds.
pub const DEFAULT_STREAM_OPEN_TIMEOUT: Duration = Duration::from_millis(20_000);
/// Default deadline for the first stream event to arrive: 30 seconds.
pub const DEFAULT_FIRST_STREAM_EVENT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Runtime representation of a resolved `AuthRef`. Constructed by `crates/provider`.
///
/// - `None`: send no authentication header at all (opt-out, e.g. Ollama)
/// - `ApiKey`: a static API key string
/// - `Custom`: headers are assembled dynamically for each request (e.g. Codex OAuth)
#[derive(Debug, Clone)]
pub enum ResolvedAuth {
    /// No credentials: the transport adds no auth header and no key query parameter.
    None,
    /// Static key, sent as `Authorization: Bearer`, `x-api-key`, or a query parameter
    /// depending on the scheme's `AuthRequirement`.
    ApiKey(String),
    /// Provider that supplies auth headers per request via `AuthProvider::headers`.
    Custom(Arc<dyn AuthProvider>),
}
impl ResolvedAuth {
    /// Returns whether this resolved credential is compatible with the scheme's
    /// authentication requirement. Intended for validation at construction time.
    ///
    /// - `ResolvedAuth::None` declares "attach no credentials", so it is accepted for
    ///   every `AuthRequirement` (e.g. Ollama reusing the Anthropic scheme declares
    ///   `required_auth = XApiKey` yet works without any auth header).
    /// - `ResolvedAuth::Custom` declares "this provider assembles all auth headers
    ///   itself", so it is accepted regardless of the scheme's requested format.
    /// - `ResolvedAuth::ApiKey` is accepted only for requirements a static key can
    ///   satisfy: `Bearer`, `XApiKey`, or `QueryParam`.
    pub fn matches(&self, req: AuthRequirement) -> bool {
        match self {
            Self::None | Self::Custom(_) => true,
            Self::ApiKey(_) => matches!(
                req,
                AuthRequirement::Bearer
                    | AuthRequirement::XApiKey
                    | AuthRequirement::QueryParam { .. }
            ),
        }
    }
}
/// Reads header `name` as trimmed text for diagnostics.
///
/// Returns `None` when the header is missing, is not visible ASCII, or is blank after
/// trimming.
fn header_value_for_diagnostics(headers: &HeaderMap, name: &'static HeaderName) -> Option<String> {
    let text = headers.get(name)?.to_str().ok()?.trim();
    (!text.is_empty()).then(|| text.to_owned())
}
fn response_header_diagnostics(headers: &HeaderMap) -> serde_json::Value {
serde_json::json!({
"content_type": header_value_for_diagnostics(headers, &CONTENT_TYPE),
"content_encoding": header_value_for_diagnostics(headers, &CONTENT_ENCODING),
"transfer_encoding": header_value_for_diagnostics(headers, &TRANSFER_ENCODING),
"content_length": header_value_for_diagnostics(headers, &CONTENT_LENGTH),
})
}
/// Formats an SSE decoding failure together with the HTTP status and the response
/// header diagnostics produced by `response_header_diagnostics`.
///
/// Header fields that are missing or not strings are rendered as `<none>`.
fn sse_error_context(status: u16, headers: &serde_json::Value, source: &str) -> String {
    fn lookup<'a>(headers: &'a serde_json::Value, key: &str) -> &'a str {
        match headers.get(key).and_then(serde_json::Value::as_str) {
            Some(value) => value,
            None => "<none>",
        }
    }

    let content_type = lookup(headers, "content_type");
    let content_encoding = lookup(headers, "content_encoding");
    let transfer_encoding = lookup(headers, "transfer_encoding");
    let content_length = lookup(headers, "content_length");
    format!(
        "SSE stream parse failed after HTTP {status}: {source}; content-type={content_type}, content-encoding={content_encoding}, transfer-encoding={transfer_encoding}, content-length={content_length}"
    )
}
/// HTTP communication layer shared by every wire scheme.
///
/// Scheme-specific behavior (URL path, auth requirement, request body, SSE parsing) is
/// delegated to the `S: Scheme` implementation. This type owns the HTTP client, the
/// endpoint, the credentials, and the model capability.
pub struct HttpTransport<S: Scheme> {
    /// Client used for every request; replaceable via `with_http_client`.
    http_client: reqwest::Client,
    /// Wire-format implementation (path, auth requirement, body building, SSE parsing).
    scheme: S,
    /// Model identifier passed to `Scheme::path` and `Scheme::build_request_body`.
    model_id: String,
    /// Endpoint root with trailing slashes stripped (see `new`).
    base_url: String,
    /// Resolved credentials consumed by `build_url` and `build_headers`.
    auth: ResolvedAuth,
    /// Model capability passed to `Scheme::build_request_body`.
    capability: ModelCapability,
}
impl<S: Scheme> HttpTransport<S> {
    /// Creates a transport that uses a default `reqwest::Client`.
    ///
    /// `base_url` is accepted with or without trailing slashes. They are stripped here
    /// because `build_url` concatenates `base_url` and `Scheme::path` verbatim.
    pub fn new(
        scheme: S,
        model_id: impl Into<String>,
        base_url: impl Into<String>,
        auth: ResolvedAuth,
        capability: ModelCapability,
    ) -> Self {
        let base_url = base_url.into();
        let base_url = base_url.trim_end_matches('/').to_string();
        Self {
            http_client: reqwest::Client::new(),
            scheme,
            model_id: model_id.into(),
            base_url,
            auth,
            capability,
        }
    }

    /// Replaces the HTTP client (e.g. in tests).
    pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
        self.http_client = client;
        self
    }

    /// Builds the request URL as `base_url + scheme.path(model_id)`.
    ///
    /// When the scheme authenticates through a query parameter (e.g. Gemini) and an API
    /// key is configured, `name=key` is appended. `&` is used if the path already has a
    /// query string, otherwise `?`.
    /// NOTE(review): the key is appended without percent-encoding.
    fn build_url(&self) -> String {
        let path = self.scheme.path(&self.model_id);
        let url = format!("{}{}", self.base_url, path);
        if let (AuthRequirement::QueryParam { name }, ResolvedAuth::ApiKey(key)) =
            (self.scheme.required_auth(), &self.auth)
        {
            let sep = if url.contains('?') { '&' } else { '?' };
            format!("{url}{sep}{name}={key}")
        } else {
            url
        }
    }

    /// Builds the common request headers in this order:
    /// 1. `Content-Type: application/json`.
    /// 2. The auth headers implied by `auth` and the scheme's `AuthRequirement`.
    /// 3. The scheme's `additional_headers()`, which may overwrite earlier values.
    ///
    /// Auth header values are marked sensitive (`HeaderValue::set_sensitive`).
    ///
    /// # Errors
    /// Propagates `AuthProvider::headers` failures. Returns `ClientError::Config` when an
    /// API key or an additional header value is not a valid header value.
    async fn build_headers(&self) -> Result<HeaderMap, ClientError> {
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        match (&self.auth, self.scheme.required_auth()) {
            // A custom provider declares that it assembles all auth headers itself.
            // `ResolvedAuth::matches` accepts it for every requirement, so its headers
            // must be sent even when the scheme itself declares `AuthRequirement::None`.
            // This arm therefore comes before the "no auth" arm.
            (ResolvedAuth::Custom(provider), _) => {
                for (name, mut value) in provider.headers().await? {
                    value.set_sensitive(true);
                    headers.insert(name, value);
                }
            }
            (ResolvedAuth::None, _) | (_, AuthRequirement::None) => {}
            (ResolvedAuth::ApiKey(key), AuthRequirement::Bearer) => {
                let mut val = HeaderValue::from_str(&format!("Bearer {key}"))
                    .map_err(|e| ClientError::Config(format!("invalid api key: {e}")))?;
                val.set_sensitive(true);
                headers.insert("Authorization", val);
            }
            (ResolvedAuth::ApiKey(key), AuthRequirement::XApiKey) => {
                let mut val = HeaderValue::from_str(key.as_str())
                    .map_err(|e| ClientError::Config(format!("invalid api key: {e}")))?;
                val.set_sensitive(true);
                headers.insert("x-api-key", val);
            }
            (_, AuthRequirement::QueryParam { .. }) => {
                // The key is carried in the query string, added by `build_url`.
            }
            (ResolvedAuth::ApiKey(_), AuthRequirement::Custom) => {
                // This combination is rejected by `matches()`, so it is not expected here.
                // Fail safe by adding nothing.
            }
        }
        for (name, value) in self.scheme.additional_headers() {
            let hv = HeaderValue::from_str(&value)
                .map_err(|e| ClientError::Config(format!("invalid header {name}: {e}")))?;
            headers.insert(name, hv);
        }
        Ok(headers)
    }

    /// True only for a `Custom` auth provider that reports itself as the Codex backend.
    fn is_codex_backend(&self) -> bool {
        match &self.auth {
            ResolvedAuth::Custom(provider) => provider.is_codex_backend(),
            _ => false,
        }
    }

    /// Adds streaming-specific headers: always `Accept: text/event-stream`.
    ///
    /// For the Codex backend with a `cache_key`, also mirrors the key into the
    /// `session_id` and `x-client-request-id` headers.
    ///
    /// # Errors
    /// `ClientError::Config` if the cache key is not a valid header value.
    fn apply_stream_headers(
        &self,
        headers: &mut HeaderMap,
        request: &Request,
    ) -> Result<(), ClientError> {
        headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
        if self.is_codex_backend()
            && let Some(cache_key) = request.cache_key.as_deref()
        {
            let value = HeaderValue::from_str(cache_key).map_err(|e| {
                ClientError::Config(format!("invalid Codex conversation header: {e}"))
            })?;
            headers.insert(HeaderName::from_static("session_id"), value.clone());
            headers.insert(HeaderName::from_static("x-client-request-id"), value);
        }
        Ok(())
    }

    /// Encodes the JSON body for the wire.
    ///
    /// - Non-Codex backends get `RequestBody::Json` (a clone of `body`).
    /// - The Codex backend gets the serialized JSON zstd-compressed at level 3. In that
    ///   case `Content-Encoding: zstd` is also added to `headers`.
    ///
    /// # Errors
    /// JSON serialization failures, and `ClientError::Config` if compression fails.
    fn encode_request_body(
        &self,
        body: &serde_json::Value,
        headers: &mut HeaderMap,
    ) -> Result<RequestBody, ClientError> {
        if !self.is_codex_backend() {
            return Ok(RequestBody::Json(body.clone()));
        }
        let raw = serde_json::to_vec(body)?;
        let raw_json_bytes = raw.len();
        let compressed = zstd::stream::encode_all(std::io::Cursor::new(raw), 3)
            .map_err(|e| ClientError::Config(format!("failed to zstd-compress request: {e}")))?;
        headers.insert(CONTENT_ENCODING, HeaderValue::from_static("zstd"));
        Ok(RequestBody::CompressedJson {
            bytes: compressed,
            raw_json_bytes,
        })
    }
}
/// Request payload after transport-level encoding, ready to attach to a request builder.
enum RequestBody {
    /// Uncompressed JSON; serialized by `reqwest` (`RequestBuilder::json`) when sent.
    Json(serde_json::Value),
    /// zstd-compressed JSON (produced for the Codex backend).
    CompressedJson {
        /// Compressed bytes, exactly as sent on the wire.
        bytes: Vec<u8>,
        /// Length of the JSON serialization before compression.
        raw_json_bytes: usize,
    },
}

impl RequestBody {
    /// Serialized JSON length of `value`, or `None` if serialization fails.
    fn serialized_len(value: &serde_json::Value) -> Option<usize> {
        serde_json::to_vec(value).ok().map(|encoded| encoded.len())
    }

    /// Short encoding label for traces: `"json"` or `"zstd"`.
    fn encoding(&self) -> &'static str {
        if matches!(self, Self::Json(_)) {
            "json"
        } else {
            "zstd"
        }
    }

    /// Size of the JSON document before any compression.
    fn raw_json_bytes(&self) -> Option<usize> {
        match self {
            Self::CompressedJson { raw_json_bytes, .. } => Some(*raw_json_bytes),
            Self::Json(value) => Self::serialized_len(value),
        }
    }

    /// Number of body bytes actually sent on the wire.
    fn wire_bytes(&self) -> Option<usize> {
        match self {
            Self::CompressedJson { bytes, .. } => Some(bytes.len()),
            Self::Json(value) => Self::serialized_len(value),
        }
    }
}
/// Trace label for the kind of credential in use. Never includes the secret itself.
fn auth_kind(auth: &ResolvedAuth) -> &'static str {
    match auth {
        ResolvedAuth::Custom(..) => "custom",
        ResolvedAuth::ApiKey(..) => "api_key",
        ResolvedAuth::None => "none",
    }
}
/// Forwards a `label`/`data` trace event to the request's transport trace sink.
/// Does nothing when the request has no sink attached.
fn emit_transport_trace(request: &Request, label: &str, data: Value) {
    let Some(trace) = request.transport_trace.as_ref() else {
        return;
    };
    trace.emit(label, data);
}
fn json_value_kind(value: &Value) -> &'static str {
match value {
Value::Null => "null",
Value::Bool(_) => "bool",
Value::Number(_) => "number",
Value::String(_) => "string",
Value::Array(_) => "array",
Value::Object(_) => "object",
}
}
/// Summarizes the `input` array of a request body for tracing.
///
/// Reports the item count, the serialized size of `input`, and how many
/// `"type": "reasoning"` items exist. It also reports how many of those carry a string
/// `encrypted_content` and the total length of those strings. Returns an empty object
/// when `body.input` is missing or not an array.
fn request_body_shape_payload(body: &Value) -> Value {
    let Some(input) = body.get("input").and_then(Value::as_array) else {
        return Value::Object(Map::new());
    };

    let items_json_bytes = serde_json::to_vec(input).ok().map(|encoded| encoded.len());
    let (reasoning_items, encrypted_count, encrypted_bytes) = input
        .iter()
        .filter(|item| item.get("type").and_then(Value::as_str) == Some("reasoning"))
        .fold((0usize, 0usize, 0usize), |(items, count, bytes), item| {
            match item.get("encrypted_content").and_then(Value::as_str) {
                Some(encrypted) => (items + 1, count + 1, bytes + encrypted.len()),
                None => (items + 1, count, bytes),
            }
        });

    let mut shape = Map::new();
    shape.insert("items_len".to_string(), json!(input.len()));
    shape.insert("items_json_bytes".to_string(), json!(items_json_bytes));
    shape.insert("reasoning_items".to_string(), json!(reasoning_items));
    shape.insert(
        "reasoning_encrypted_content_count".to_string(),
        json!(encrypted_count),
    );
    shape.insert(
        "reasoning_encrypted_content_bytes".to_string(),
        json!(encrypted_bytes),
    );
    Value::Object(shape)
}
fn api_error_code(error: &ClientError) -> Option<&str> {
match error {
ClientError::Api { code, .. } => code.as_deref(),
_ => None,
}
}
/// True when an `Api` error reports `context_length_exceeded`, either as its code or
/// anywhere in its message.
fn is_context_length_exceeded(error: &ClientError) -> bool {
    const MARKER: &str = "context_length_exceeded";
    matches!(
        error,
        ClientError::Api { code, message, .. }
            if code.as_deref() == Some(MARKER) || message.contains(MARKER)
    )
}
async fn response_with_timeout(
future: impl std::future::Future<Output = Result<reqwest::Response, reqwest::Error>>,
timeout: Duration,
phase: &'static str,
) -> Result<reqwest::Response, ClientError> {
tokio::time::timeout(timeout, future)
.await
.map_err(|_| ClientError::Timeout { phase, timeout })?
.map_err(ClientError::Http)
}
/// Field-by-field clone. `reqwest::Client` and `Arc`-backed auth providers are shared,
/// not deep-copied.
impl<S> Clone for HttpTransport<S>
where
    S: Scheme + Clone,
{
    fn clone(&self) -> Self {
        let Self {
            http_client,
            scheme,
            model_id,
            base_url,
            auth,
            capability,
        } = self;
        Self {
            http_client: http_client.clone(),
            scheme: scheme.clone(),
            model_id: model_id.clone(),
            base_url: base_url.clone(),
            auth: auth.clone(),
            capability: capability.clone(),
        }
    }
}
/// エラーレスポンスを `ClientError::Api` に変換する。
async fn classify_error_response(resp: reqwest::Response) -> ClientError {
let status = resp.status().as_u16();
let retry_after = resp
.headers()
.get(RETRY_AFTER)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.trim().parse::<u64>().ok())
.map(Duration::from_secs);
let text = resp.text().await.unwrap_or_default();
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
let error = json.get("error").unwrap_or(&json);
let code = error
.get("code")
.and_then(|v| v.as_str())
.or_else(|| error.get("type").and_then(|v| v.as_str()))
.map(String::from);
let message = error
.get("message")
.and_then(|v| v.as_str())
.unwrap_or(&text)
.to_string();
ClientError::Api {
status: Some(status),
code,
message,
retry_after,
}
} else {
ClientError::Api {
status: Some(status),
code: None,
message: text,
retry_after,
}
}
}
#[async_trait]
impl<S: Scheme + Clone + 'static> LlmClient for HttpTransport<S> {
    fn clone_boxed(&self) -> Box<dyn LlmClient> {
        Box::new(self.clone())
    }

    fn validate_config(&self, config: &RequestConfig) -> Vec<ConfigWarning> {
        self.scheme.validate_config(config)
    }

    /// Opens a streaming request and returns the decoded event stream.
    ///
    /// The request goes through these steps:
    /// 1. Build the auth headers, then add the streaming headers.
    /// 2. Build the scheme-specific JSON body and encode it (zstd for the Codex backend).
    /// 3. POST it with a `DEFAULT_STREAM_OPEN_TIMEOUT` deadline.
    /// 4. On a non-2xx status, return the `ClientError::Api` from
    ///    `classify_error_response`.
    /// 5. Otherwise decode the body as SSE with `Scheme::parse_sse` and flatten each
    ///    frame's events into the returned stream.
    ///
    /// Each step emits `transport_*` trace events to `request.transport_trace`. Payloads
    /// that re-serialize the request body (request shape, raw/wire byte counts) are only
    /// computed when a trace sink is attached, so untraced requests do not pay for
    /// extra full-body serializations.
    ///
    /// # Errors
    /// Header/body construction errors, `ClientError::Timeout` / `ClientError::Http`
    /// from sending, and `ClientError::Api` for non-success statuses. SSE decoding
    /// failures are yielded as `ClientError::Sse` items inside the stream.
    async fn stream(&self, request: Request) -> Result<ResponseStream, ClientError> {
        let total_started = Instant::now();
        // `emit_transport_trace` discards payloads when no sink is attached, so
        // expensive payloads are only built when this is true.
        let trace_enabled = request.transport_trace.is_some();
        let path = self.scheme.path(&self.model_id);
        emit_transport_trace(
            &request,
            "transport_start",
            json!({
                "model": &self.model_id,
                "path": path,
                "auth_kind": auth_kind(&self.auth),
                "required_auth": format!("{:?}", self.scheme.required_auth()),
                "codex_backend": self.is_codex_backend(),
                "cache_key_present": request.cache_key.is_some(),
                "stream_open_timeout_ms": DEFAULT_STREAM_OPEN_TIMEOUT.as_millis() as u64,
            }),
        );
        let url = self.build_url();

        // Auth and static headers.
        let headers_started = Instant::now();
        emit_transport_trace(
            &request,
            "transport_headers_start",
            json!({
                "auth_kind": auth_kind(&self.auth),
                "required_auth": format!("{:?}", self.scheme.required_auth()),
            }),
        );
        let mut headers = match self.build_headers().await {
            Ok(headers) => {
                emit_transport_trace(
                    &request,
                    "transport_headers_done",
                    json!({
                        "elapsed_ms": headers_started.elapsed().as_millis() as u64,
                        "headers_len": headers.len(),
                    }),
                );
                headers
            }
            Err(error) => {
                emit_transport_trace(
                    &request,
                    "transport_headers_error",
                    json!({
                        "elapsed_ms": headers_started.elapsed().as_millis() as u64,
                        "error": error.to_string(),
                    }),
                );
                return Err(error);
            }
        };

        // Streaming-specific headers.
        let stream_headers_started = Instant::now();
        if let Err(error) = self.apply_stream_headers(&mut headers, &request) {
            emit_transport_trace(
                &request,
                "transport_stream_headers_error",
                json!({
                    "elapsed_ms": stream_headers_started.elapsed().as_millis() as u64,
                    "error": error.to_string(),
                }),
            );
            return Err(error);
        }
        emit_transport_trace(
            &request,
            "transport_stream_headers_done",
            json!({
                "elapsed_ms": stream_headers_started.elapsed().as_millis() as u64,
                "headers_len": headers.len(),
            }),
        );

        // Request body.
        let body_started = Instant::now();
        emit_transport_trace(&request, "transport_body_build_start", json!({}));
        let body = self
            .scheme
            .build_request_body(&self.model_id, &request, &self.capability);
        // Walks and re-serializes `input`; only worth doing when someone is tracing.
        let body_shape = if trace_enabled {
            request_body_shape_payload(&body)
        } else {
            Value::Null
        };
        emit_transport_trace(
            &request,
            "transport_body_build_done",
            json!({
                "elapsed_ms": body_started.elapsed().as_millis() as u64,
                "body_kind": json_value_kind(&body),
                "request_shape": body_shape.clone(),
            }),
        );
        let encode_started = Instant::now();
        let request_body = match self.encode_request_body(&body, &mut headers) {
            Ok(body) => body,
            Err(error) => {
                emit_transport_trace(
                    &request,
                    "transport_body_encode_error",
                    json!({
                        "elapsed_ms": encode_started.elapsed().as_millis() as u64,
                        "error": error.to_string(),
                    }),
                );
                return Err(error);
            }
        };
        if trace_enabled {
            // For plain JSON, `raw_json_bytes` and `wire_bytes` each serialize the whole
            // body again, so they are skipped entirely for untraced requests.
            emit_transport_trace(
                &request,
                "transport_body_encode_done",
                json!({
                    "elapsed_ms": encode_started.elapsed().as_millis() as u64,
                    "encoding": request_body.encoding(),
                    "raw_json_bytes": request_body.raw_json_bytes(),
                    "wire_bytes": request_body.wire_bytes(),
                    "request_shape": body_shape.clone(),
                }),
            );
        }

        // Send and wait for response headers.
        let builder = self.http_client.post(&url).headers(headers);
        let builder = match request_body {
            RequestBody::Json(body) => builder.json(&body),
            RequestBody::CompressedJson { bytes, .. } => builder.body(bytes),
        };
        let send_started = Instant::now();
        emit_transport_trace(&request, "transport_http_send_start", json!({}));
        let response =
            match response_with_timeout(builder.send(), DEFAULT_STREAM_OPEN_TIMEOUT, "stream_open")
                .await
            {
                Ok(response) => {
                    let response_headers = response_header_diagnostics(response.headers());
                    emit_transport_trace(
                        &request,
                        "transport_http_headers_received",
                        json!({
                            "elapsed_ms": send_started.elapsed().as_millis() as u64,
                            "status": response.status().as_u16(),
                            "success": response.status().is_success(),
                            "headers": response_headers,
                        }),
                    );
                    response
                }
                Err(error) => {
                    emit_transport_trace(
                        &request,
                        "transport_http_send_error",
                        json!({
                            "elapsed_ms": send_started.elapsed().as_millis() as u64,
                            "error": error.to_string(),
                        }),
                    );
                    return Err(error);
                }
            };
        if !response.status().is_success() {
            let status = response.status().as_u16();
            let retry_after_present = response.headers().get(RETRY_AFTER).is_some();
            let error = classify_error_response(response).await;
            let context_length_exceeded = is_context_length_exceeded(&error);
            emit_transport_trace(
                &request,
                "transport_http_status_error",
                json!({
                    "status": status,
                    "retry_after_present": retry_after_present,
                    "api_error_code": api_error_code(&error),
                    "context_length_exceeded": context_length_exceeded,
                    "provider_usage_absent": context_length_exceeded,
                    "request_shape": body_shape,
                }),
            );
            return Err(error);
        }
        emit_transport_trace(
            &request,
            "transport_stream_ready",
            json!({
                "elapsed_ms": total_started.elapsed().as_millis() as u64,
            }),
        );

        // SSE decoding.
        let scheme = self.scheme.clone();
        let status = response.status().as_u16();
        let response_headers = response_header_diagnostics(response.headers());
        let transport_trace = request.transport_trace.clone();
        let byte_stream = response.bytes_stream().map_err(std::io::Error::other);
        let event_stream = byte_stream.eventsource();
        // Scheme-specific parse state lives for the lifetime of this one stream.
        let mut state = <S::State as Default>::default();
        let stream = event_stream
            .map(move |result| match result {
                Ok(frame) => scheme.parse_sse(&frame.event, &frame.data, &mut state),
                Err(e) => {
                    let source = e.to_string();
                    let message = sse_error_context(status, &response_headers, &source);
                    if let Some(trace) = &transport_trace {
                        trace.emit(
                            "transport_sse_parse_error",
                            json!({
                                "status": status,
                                "headers": response_headers.clone(),
                                "error": source,
                            }),
                        );
                    }
                    Err(ClientError::Sse(message))
                }
            })
            .map(|res| {
                // Each SSE frame yields zero or more events (or a single error);
                // flatten them into one event stream.
                let s: Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>> = match res {
                    Ok(events) => Box::pin(futures::stream::iter(events.into_iter().map(Ok))),
                    Err(e) => Box::pin(futures::stream::once(async move { Err(e) })),
                };
                s
            })
            .flatten();
        Ok(Box::pin(stream))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Auth provider stub that returns fixed `authorization` and `chatgpt-account-id`
    /// headers. `codex` controls what `is_codex_backend` reports.
    #[derive(Debug)]
    struct TestAuthProvider {
        codex: bool,
    }

    #[async_trait]
    impl AuthProvider for TestAuthProvider {
        async fn headers(&self) -> Result<Vec<(HeaderName, HeaderValue)>, ClientError> {
            Ok(vec![
                (
                    HeaderName::from_static("authorization"),
                    HeaderValue::from_static("Bearer test-token"),
                ),
                (
                    HeaderName::from_static("chatgpt-account-id"),
                    HeaderValue::from_static("account-1"),
                ),
            ])
        }
        fn is_codex_backend(&self) -> bool {
            self.codex
        }
    }

    /// Minimal scheme with Bearer auth and a `/responses` path. Its body echoes the
    /// model id, the item count, and the cache key (as `prompt_cache_key`). Its SSE
    /// parser yields no events.
    #[derive(Clone)]
    struct TestScheme;
    impl Scheme for TestScheme {
        type State = ();
        fn default_base_url(&self) -> &'static str {
            "https://example.test"
        }
        fn path(&self, _model_id: &str) -> String {
            "/responses".to_string()
        }
        fn required_auth(&self) -> AuthRequirement {
            AuthRequirement::Bearer
        }
        fn build_request_body(
            &self,
            model_id: &str,
            request: &Request,
            _capability: &ModelCapability,
        ) -> serde_json::Value {
            json!({
                "model": model_id,
                "input_len": request.items.len(),
                "prompt_cache_key": request.cache_key,
            })
        }
        fn parse_sse(
            &self,
            _event_type: &str,
            _data: &str,
            _state: &mut Self::State,
        ) -> Result<Vec<Event>, ClientError> {
            Ok(Vec::new())
        }
        fn default_capability(&self) -> ModelCapability {
            ModelCapability::minimal()
        }
    }

    /// Builds a `TestScheme` transport for model `gpt-test` at `https://example.test`.
    fn transport(auth: ResolvedAuth) -> HttpTransport<TestScheme> {
        HttpTransport::new(
            TestScheme,
            "gpt-test",
            "https://example.test",
            auth,
            ModelCapability::minimal(),
        )
    }

    // The SSE error message must carry the status, the source error, and every
    // diagnostic header field.
    #[test]
    fn sse_error_context_includes_response_headers() {
        let headers = json!({
            "content_type": "application/octet-stream",
            "content_encoding": "gzip",
            "transfer_encoding": "chunked",
            "content_length": "123",
        });
        let message = sse_error_context(200, &headers, "stream did not contain valid UTF-8");
        assert!(message.contains("HTTP 200"));
        assert!(message.contains("stream did not contain valid UTF-8"));
        assert!(message.contains("content-type=application/octet-stream"));
        assert!(message.contains("content-encoding=gzip"));
        assert!(message.contains("transfer-encoding=chunked"));
        assert!(message.contains("content-length=123"));
    }

    // Only the allow-listed headers are copied; `authorization` must never appear.
    #[test]
    fn response_header_diagnostics_redacts_to_safe_header_subset() {
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
        headers.insert(CONTENT_ENCODING, HeaderValue::from_static("identity"));
        headers.insert("authorization", HeaderValue::from_static("Bearer secret"));
        let diagnostics = response_header_diagnostics(&headers);
        assert_eq!(diagnostics["content_type"], "text/event-stream");
        assert_eq!(diagnostics["content_encoding"], "identity");
        assert!(diagnostics.get("authorization").is_none());
    }

    // Two reasoning items with 3 + 5 bytes of encrypted content; the message item is
    // counted in `items_len` only.
    #[test]
    fn request_body_shape_counts_reasoning_encrypted_content() {
        let payload = request_body_shape_payload(&json!({
            "reasoning": { "summary": "auto" },
            "input": [
                { "type": "message", "role": "user", "content": [] },
                { "type": "reasoning", "encrypted_content": "abc", "summary": [] },
                { "type": "reasoning", "encrypted_content": "defgh", "summary": [] }
            ]
        }));
        assert_eq!(payload["items_len"], 3);
        assert_eq!(payload["reasoning_items"], 2);
        assert_eq!(payload["reasoning_encrypted_content_count"], 2);
        assert_eq!(payload["reasoning_encrypted_content_bytes"], 8);
        assert!(payload["items_json_bytes"].as_u64().unwrap() > 0);
    }

    // A never-completing send must surface as a retryable `Timeout` for its phase.
    #[tokio::test]
    async fn response_timeout_returns_retryable_lifecycle_timeout() {
        let err = response_with_timeout(
            std::future::pending::<Result<reqwest::Response, reqwest::Error>>(),
            Duration::from_millis(5),
            "stream_open",
        )
        .await
        .unwrap_err();
        assert!(crate::llm_client::error::is_retryable(&err));
        assert!(matches!(
            err,
            ClientError::Timeout {
                phase: "stream_open",
                ..
            }
        ));
    }

    // Codex backend: cache key mirrored into conversation headers, and the body is
    // zstd-compressed but still decodes to the original JSON.
    #[tokio::test]
    async fn codex_backend_adds_conversation_headers_and_zstd_body() {
        let transport = transport(ResolvedAuth::Custom(Arc::new(TestAuthProvider {
            codex: true,
        })));
        let request = Request::new().user("hello").cache_key("segment-123");
        let mut headers = transport.build_headers().await.unwrap();
        transport
            .apply_stream_headers(&mut headers, &request)
            .unwrap();
        let body = transport.scheme.build_request_body(
            &transport.model_id,
            &request,
            &transport.capability,
        );
        let encoded = transport.encode_request_body(&body, &mut headers).unwrap();
        assert_eq!(headers.get(ACCEPT).unwrap(), "text/event-stream");
        assert_eq!(headers.get("session_id").unwrap(), "segment-123");
        assert_eq!(headers.get("x-client-request-id").unwrap(), "segment-123");
        assert_eq!(headers.get(CONTENT_ENCODING).unwrap(), "zstd");
        let RequestBody::CompressedJson {
            bytes: compressed,
            raw_json_bytes,
        } = encoded
        else {
            panic!("Codex backend request body must be zstd-compressed");
        };
        assert!(raw_json_bytes > 0);
        let decoded = zstd::stream::decode_all(std::io::Cursor::new(compressed)).unwrap();
        let decoded: serde_json::Value = serde_json::from_slice(&decoded).unwrap();
        assert_eq!(decoded["prompt_cache_key"], "segment-123");
    }

    // Non-Codex (API key) transport: no conversation headers, no compression.
    #[tokio::test]
    async fn non_codex_request_does_not_get_codex_only_headers_or_compression() {
        let transport = transport(ResolvedAuth::ApiKey("api-key".to_string()));
        let request = Request::new().user("hello").cache_key("segment-123");
        let mut headers = transport.build_headers().await.unwrap();
        transport
            .apply_stream_headers(&mut headers, &request)
            .unwrap();
        let body = transport.scheme.build_request_body(
            &transport.model_id,
            &request,
            &transport.capability,
        );
        let encoded = transport.encode_request_body(&body, &mut headers).unwrap();
        assert_eq!(headers.get(ACCEPT).unwrap(), "text/event-stream");
        assert!(headers.get("session_id").is_none());
        assert!(headers.get("x-client-request-id").is_none());
        assert!(headers.get(CONTENT_ENCODING).is_none());
        let RequestBody::Json(decoded) = encoded else {
            panic!("non-Codex request body must remain normal JSON");
        };
        assert_eq!(decoded["prompt_cache_key"], "segment-123");
    }
}