//! `HttpTransport`: すべての LLM wire scheme を共通の 1 本の //! HTTP クライアントで扱う。 //! //! 旧 `providers/{anthropic,openai,gemini,ollama}.rs` を置き換える。 //! scheme 固有の差分は [`Scheme`] trait 実装に委譲する。 use std::pin::Pin; use std::sync::Arc; use std::time::{Duration, Instant}; use async_trait::async_trait; use eventsource_stream::Eventsource; use futures::{Stream, StreamExt, TryStreamExt}; use reqwest::header::{ ACCEPT, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, RETRY_AFTER, TRANSFER_ENCODING, }; use serde_json::{Map, Value, json}; use super::auth::{AuthProvider, AuthRequirement}; use super::capability::ModelCapability; use super::client::{ConfigWarning, LlmClient, ResponseStream}; use super::error::ClientError; use super::event::Event; use super::scheme::Scheme; use super::types::{Request, RequestConfig}; pub const DEFAULT_STREAM_OPEN_TIMEOUT: Duration = Duration::from_secs(20); pub const DEFAULT_FIRST_STREAM_EVENT_TIMEOUT: Duration = Duration::from_secs(30); /// `AuthRef` を解決したランタイム表現。`crates/provider` が構築する。 /// /// - `None`: 認証ヘッダを送らない(Ollama 等の opt-out) /// - `ApiKey`: 静的な API key 文字列 /// - `Custom`: リクエスト毎に動的にヘッダを組み立てる(Codex OAuth 等) #[derive(Debug, Clone)] pub enum ResolvedAuth { None, ApiKey(String), Custom(Arc), } impl ResolvedAuth { /// 認証要件と実際の解決値が噛み合うか検査する。構築時検証用。 /// /// - `ResolvedAuth::None` は認証を付けない宣言なので、どの /// `AuthRequirement` でも受け入れる(Ollama の Anthropic scheme /// 流用は `required_auth = XApiKey` だが認証ヘッダなしで動く) /// - `ResolvedAuth::Custom` は「ヘッダ組立を全部こちらで行う」 /// 宣言なので、scheme が要求する形式によらず受け入れる pub fn matches(&self, req: AuthRequirement) -> bool { match (self, req) { (Self::None, _) => true, (Self::Custom(_), _) => true, ( Self::ApiKey(_), AuthRequirement::Bearer | AuthRequirement::XApiKey | AuthRequirement::QueryParam { .. }, ) => true, _ => false, } } } fn header_value_for_diagnostics(headers: &HeaderMap, name: &'static HeaderName) -> Option { headers .get(name) .and_then(|value| value.to_str().ok()) .map(str::trim) .filter(|value| !value.is_empty()) .map(ToOwned::to_owned) } fn response_header_diagnostics(headers: &HeaderMap) -> serde_json::Value { serde_json::json!({ "content_type": header_value_for_diagnostics(headers, &CONTENT_TYPE), "content_encoding": header_value_for_diagnostics(headers, &CONTENT_ENCODING), "transfer_encoding": header_value_for_diagnostics(headers, &TRANSFER_ENCODING), "content_length": header_value_for_diagnostics(headers, &CONTENT_LENGTH), }) } fn sse_error_context(status: u16, headers: &serde_json::Value, source: &str) -> String { let field = |name: &str| { headers .get(name) .and_then(serde_json::Value::as_str) .unwrap_or("") }; format!( "SSE stream parse failed after HTTP {status}: {source}; content-type={}, content-encoding={}, transfer-encoding={}, content-length={}", field("content_type"), field("content_encoding"), field("transfer_encoding"), field("content_length") ) } /// scheme 共通の HTTP 通信層。 pub struct HttpTransport { http_client: reqwest::Client, scheme: S, model_id: String, base_url: String, auth: ResolvedAuth, capability: ModelCapability, } impl HttpTransport { /// 新しい transport を作る。`base_url` は末尾スラッシュの有無を /// どちらでも受け付ける(内部で正規化)。 pub fn new( scheme: S, model_id: impl Into, base_url: impl Into, auth: ResolvedAuth, capability: ModelCapability, ) -> Self { let base_url = base_url.into(); let base_url = base_url.trim_end_matches('/').to_string(); Self { http_client: reqwest::Client::new(), scheme, model_id: model_id.into(), base_url, auth, capability, } } /// カスタム HTTP クライアントを差し込む(テスト等)。 pub fn with_http_client(mut self, client: reqwest::Client) -> Self { self.http_client = client; self } fn build_url(&self) -> String { let path = self.scheme.path(&self.model_id); let url = format!("{}{}", self.base_url, path); // Gemini のようにクエリパラメータで認証する場合は URL にキーを追記する if let (AuthRequirement::QueryParam { name }, ResolvedAuth::ApiKey(key)) = (self.scheme.required_auth(), &self.auth) { let sep = if url.contains('?') { '&' } else { '?' }; format!("{url}{sep}{name}={key}") } else { url } } async fn build_headers(&self) -> Result { let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); match (&self.auth, self.scheme.required_auth()) { (ResolvedAuth::None, _) | (_, AuthRequirement::None) => {} (ResolvedAuth::Custom(provider), _) => { for (name, mut value) in provider.headers().await? { value.set_sensitive(true); headers.insert(name, value); } } (ResolvedAuth::ApiKey(key), AuthRequirement::Bearer) => { let mut val = HeaderValue::from_str(&format!("Bearer {key}")) .map_err(|e| ClientError::Config(format!("invalid api key: {e}")))?; val.set_sensitive(true); headers.insert("Authorization", val); } (ResolvedAuth::ApiKey(key), AuthRequirement::XApiKey) => { let mut val = HeaderValue::from_str(key.as_str()) .map_err(|e| ClientError::Config(format!("invalid api key: {e}")))?; val.set_sensitive(true); headers.insert("x-api-key", val); } (_, AuthRequirement::QueryParam { .. }) => { // クエリパラメータは `build_url` で付与済み } (ResolvedAuth::ApiKey(_), AuthRequirement::Custom) => { // scheme が Custom を要求する組合せに ApiKey は流れてこない想定 // (`matches()` で弾かれる)。安全側で何もしない } } for (name, value) in self.scheme.additional_headers() { let hv = HeaderValue::from_str(&value) .map_err(|e| ClientError::Config(format!("invalid header {name}: {e}")))?; headers.insert(name, hv); } Ok(headers) } fn is_codex_backend(&self) -> bool { match &self.auth { ResolvedAuth::Custom(provider) => provider.is_codex_backend(), _ => false, } } fn apply_stream_headers( &self, headers: &mut HeaderMap, request: &Request, ) -> Result<(), ClientError> { headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream")); if self.is_codex_backend() && let Some(cache_key) = request.cache_key.as_deref() { let value = HeaderValue::from_str(cache_key).map_err(|e| { ClientError::Config(format!("invalid Codex conversation header: {e}")) })?; headers.insert(HeaderName::from_static("session_id"), value.clone()); headers.insert(HeaderName::from_static("x-client-request-id"), value); } Ok(()) } fn encode_request_body( &self, body: &serde_json::Value, headers: &mut HeaderMap, ) -> Result { if !self.is_codex_backend() { return Ok(RequestBody::Json(body.clone())); } let raw = serde_json::to_vec(body)?; let raw_json_bytes = raw.len(); let compressed = zstd::stream::encode_all(std::io::Cursor::new(raw), 3) .map_err(|e| ClientError::Config(format!("failed to zstd-compress request: {e}")))?; headers.insert(CONTENT_ENCODING, HeaderValue::from_static("zstd")); Ok(RequestBody::CompressedJson { bytes: compressed, raw_json_bytes, }) } } enum RequestBody { Json(serde_json::Value), CompressedJson { bytes: Vec, raw_json_bytes: usize, }, } impl RequestBody { fn encoding(&self) -> &'static str { match self { Self::Json(_) => "json", Self::CompressedJson { .. } => "zstd", } } fn raw_json_bytes(&self) -> Option { match self { Self::Json(body) => serde_json::to_vec(body).ok().map(|bytes| bytes.len()), Self::CompressedJson { raw_json_bytes, .. } => Some(*raw_json_bytes), } } fn wire_bytes(&self) -> Option { match self { Self::Json(body) => serde_json::to_vec(body).ok().map(|bytes| bytes.len()), Self::CompressedJson { bytes, .. } => Some(bytes.len()), } } } fn auth_kind(auth: &ResolvedAuth) -> &'static str { match auth { ResolvedAuth::None => "none", ResolvedAuth::ApiKey(_) => "api_key", ResolvedAuth::Custom(_) => "custom", } } fn emit_transport_trace(request: &Request, label: &str, data: Value) { if let Some(trace) = &request.transport_trace { trace.emit(label, data); } } fn json_value_kind(value: &Value) -> &'static str { match value { Value::Null => "null", Value::Bool(_) => "bool", Value::Number(_) => "number", Value::String(_) => "string", Value::Array(_) => "array", Value::Object(_) => "object", } } fn request_body_shape_payload(body: &Value) -> Value { let mut map = Map::new(); if let Some(input) = body.get("input").and_then(Value::as_array) { let items_json_bytes = serde_json::to_vec(input).map(|bytes| bytes.len()).ok(); let mut reasoning_items = 0usize; let mut reasoning_encrypted_content_count = 0usize; let mut reasoning_encrypted_content_bytes = 0usize; for item in input { if item.get("type").and_then(Value::as_str) != Some("reasoning") { continue; } reasoning_items += 1; if let Some(encrypted) = item.get("encrypted_content").and_then(Value::as_str) { reasoning_encrypted_content_count += 1; reasoning_encrypted_content_bytes += encrypted.len(); } } map.insert("items_len".to_string(), json!(input.len())); map.insert("items_json_bytes".to_string(), json!(items_json_bytes)); map.insert("reasoning_items".to_string(), json!(reasoning_items)); map.insert( "reasoning_encrypted_content_count".to_string(), json!(reasoning_encrypted_content_count), ); map.insert( "reasoning_encrypted_content_bytes".to_string(), json!(reasoning_encrypted_content_bytes), ); } Value::Object(map) } fn api_error_code(error: &ClientError) -> Option<&str> { match error { ClientError::Api { code, .. } => code.as_deref(), _ => None, } } fn is_context_length_exceeded(error: &ClientError) -> bool { match error { ClientError::Api { code, message, .. } => { code.as_deref() == Some("context_length_exceeded") || message.contains("context_length_exceeded") } _ => false, } } async fn response_with_timeout( future: impl std::future::Future>, timeout: Duration, phase: &'static str, ) -> Result { tokio::time::timeout(timeout, future) .await .map_err(|_| ClientError::Timeout { phase, timeout })? .map_err(ClientError::Http) } impl Clone for HttpTransport { fn clone(&self) -> Self { Self { http_client: self.http_client.clone(), scheme: self.scheme.clone(), model_id: self.model_id.clone(), base_url: self.base_url.clone(), auth: self.auth.clone(), capability: self.capability.clone(), } } } /// エラーレスポンスを `ClientError::Api` に変換する。 async fn classify_error_response(resp: reqwest::Response) -> ClientError { let status = resp.status().as_u16(); let retry_after = resp .headers() .get(RETRY_AFTER) .and_then(|v| v.to_str().ok()) .and_then(|s| s.trim().parse::().ok()) .map(Duration::from_secs); let text = resp.text().await.unwrap_or_default(); if let Ok(json) = serde_json::from_str::(&text) { let error = json.get("error").unwrap_or(&json); let code = error .get("code") .and_then(|v| v.as_str()) .or_else(|| error.get("type").and_then(|v| v.as_str())) .map(String::from); let message = error .get("message") .and_then(|v| v.as_str()) .unwrap_or(&text) .to_string(); ClientError::Api { status: Some(status), code, message, retry_after, } } else { ClientError::Api { status: Some(status), code: None, message: text, retry_after, } } } #[async_trait] impl LlmClient for HttpTransport { fn clone_boxed(&self) -> Box { Box::new(self.clone()) } fn validate_config(&self, config: &RequestConfig) -> Vec { self.scheme.validate_config(config) } async fn stream(&self, request: Request) -> Result { let total_started = Instant::now(); let path = self.scheme.path(&self.model_id); emit_transport_trace( &request, "transport_start", json!({ "model": &self.model_id, "path": path, "auth_kind": auth_kind(&self.auth), "required_auth": format!("{:?}", self.scheme.required_auth()), "codex_backend": self.is_codex_backend(), "cache_key_present": request.cache_key.is_some(), "stream_open_timeout_ms": DEFAULT_STREAM_OPEN_TIMEOUT.as_millis() as u64, }), ); let url = self.build_url(); let headers_started = Instant::now(); emit_transport_trace( &request, "transport_headers_start", json!({ "auth_kind": auth_kind(&self.auth), "required_auth": format!("{:?}", self.scheme.required_auth()), }), ); let mut headers = match self.build_headers().await { Ok(headers) => { emit_transport_trace( &request, "transport_headers_done", json!({ "elapsed_ms": headers_started.elapsed().as_millis() as u64, "headers_len": headers.len(), }), ); headers } Err(error) => { emit_transport_trace( &request, "transport_headers_error", json!({ "elapsed_ms": headers_started.elapsed().as_millis() as u64, "error": error.to_string(), }), ); return Err(error); } }; let stream_headers_started = Instant::now(); if let Err(error) = self.apply_stream_headers(&mut headers, &request) { emit_transport_trace( &request, "transport_stream_headers_error", json!({ "elapsed_ms": stream_headers_started.elapsed().as_millis() as u64, "error": error.to_string(), }), ); return Err(error); } emit_transport_trace( &request, "transport_stream_headers_done", json!({ "elapsed_ms": stream_headers_started.elapsed().as_millis() as u64, "headers_len": headers.len(), }), ); let body_started = Instant::now(); emit_transport_trace(&request, "transport_body_build_start", json!({})); let body = self .scheme .build_request_body(&self.model_id, &request, &self.capability); let body_shape = request_body_shape_payload(&body); emit_transport_trace( &request, "transport_body_build_done", json!({ "elapsed_ms": body_started.elapsed().as_millis() as u64, "body_kind": json_value_kind(&body), "request_shape": body_shape.clone(), }), ); let encode_started = Instant::now(); let request_body = match self.encode_request_body(&body, &mut headers) { Ok(body) => body, Err(error) => { emit_transport_trace( &request, "transport_body_encode_error", json!({ "elapsed_ms": encode_started.elapsed().as_millis() as u64, "error": error.to_string(), }), ); return Err(error); } }; emit_transport_trace( &request, "transport_body_encode_done", json!({ "elapsed_ms": encode_started.elapsed().as_millis() as u64, "encoding": request_body.encoding(), "raw_json_bytes": request_body.raw_json_bytes(), "wire_bytes": request_body.wire_bytes(), "request_shape": body_shape.clone(), }), ); let builder = self.http_client.post(&url).headers(headers); let builder = match request_body { RequestBody::Json(body) => builder.json(&body), RequestBody::CompressedJson { bytes, .. } => builder.body(bytes), }; let send_started = Instant::now(); emit_transport_trace(&request, "transport_http_send_start", json!({})); let response = match response_with_timeout(builder.send(), DEFAULT_STREAM_OPEN_TIMEOUT, "stream_open") .await { Ok(response) => { let response_headers = response_header_diagnostics(response.headers()); emit_transport_trace( &request, "transport_http_headers_received", json!({ "elapsed_ms": send_started.elapsed().as_millis() as u64, "status": response.status().as_u16(), "success": response.status().is_success(), "headers": response_headers, }), ); response } Err(error) => { emit_transport_trace( &request, "transport_http_send_error", json!({ "elapsed_ms": send_started.elapsed().as_millis() as u64, "error": error.to_string(), }), ); return Err(error); } }; if !response.status().is_success() { let status = response.status().as_u16(); let retry_after_present = response.headers().get(RETRY_AFTER).is_some(); let error = classify_error_response(response).await; let context_length_exceeded = is_context_length_exceeded(&error); emit_transport_trace( &request, "transport_http_status_error", json!({ "status": status, "retry_after_present": retry_after_present, "api_error_code": api_error_code(&error), "context_length_exceeded": context_length_exceeded, "provider_usage_absent": context_length_exceeded, "request_shape": body_shape.clone(), }), ); return Err(error); } emit_transport_trace( &request, "transport_stream_ready", json!({ "elapsed_ms": total_started.elapsed().as_millis() as u64, }), ); let scheme = self.scheme.clone(); let status = response.status().as_u16(); let response_headers = response_header_diagnostics(response.headers()); let transport_trace = request.transport_trace.clone(); let byte_stream = response.bytes_stream().map_err(std::io::Error::other); let event_stream = byte_stream.eventsource(); // scheme 固有のパース状態をストリーム単位で保持する let mut state = ::default(); let stream = event_stream .map(move |result| match result { Ok(frame) => match scheme.parse_sse(&frame.event, &frame.data, &mut state) { Ok(events) => Ok(events), Err(e) => Err(e), }, Err(e) => { let source = e.to_string(); let message = sse_error_context(status, &response_headers, &source); if let Some(trace) = &transport_trace { trace.emit( "transport_sse_parse_error", json!({ "status": status, "headers": response_headers.clone(), "error": source, }), ); } Err(ClientError::Sse(message)) } }) .map(|res| { let s: Pin> + Send>> = match res { Ok(events) => Box::pin(futures::stream::iter(events.into_iter().map(Ok))), Err(e) => Box::pin(futures::stream::once(async move { Err(e) })), }; s }) .flatten(); Ok(Box::pin(stream)) } } #[cfg(test)] mod tests { use super::*; use serde_json::json; #[derive(Debug)] struct TestAuthProvider { codex: bool, } #[async_trait] impl AuthProvider for TestAuthProvider { async fn headers(&self) -> Result, ClientError> { Ok(vec![ ( HeaderName::from_static("authorization"), HeaderValue::from_static("Bearer test-token"), ), ( HeaderName::from_static("chatgpt-account-id"), HeaderValue::from_static("account-1"), ), ]) } fn is_codex_backend(&self) -> bool { self.codex } } #[derive(Clone)] struct TestScheme; impl Scheme for TestScheme { type State = (); fn default_base_url(&self) -> &'static str { "https://example.test" } fn path(&self, _model_id: &str) -> String { "/responses".to_string() } fn required_auth(&self) -> AuthRequirement { AuthRequirement::Bearer } fn build_request_body( &self, model_id: &str, request: &Request, _capability: &ModelCapability, ) -> serde_json::Value { json!({ "model": model_id, "input_len": request.items.len(), "prompt_cache_key": request.cache_key, }) } fn parse_sse( &self, _event_type: &str, _data: &str, _state: &mut Self::State, ) -> Result, ClientError> { Ok(Vec::new()) } fn default_capability(&self) -> ModelCapability { ModelCapability::minimal() } } fn transport(auth: ResolvedAuth) -> HttpTransport { HttpTransport::new( TestScheme, "gpt-test", "https://example.test", auth, ModelCapability::minimal(), ) } #[test] fn sse_error_context_includes_response_headers() { let headers = json!({ "content_type": "application/octet-stream", "content_encoding": "gzip", "transfer_encoding": "chunked", "content_length": "123", }); let message = sse_error_context(200, &headers, "stream did not contain valid UTF-8"); assert!(message.contains("HTTP 200")); assert!(message.contains("stream did not contain valid UTF-8")); assert!(message.contains("content-type=application/octet-stream")); assert!(message.contains("content-encoding=gzip")); assert!(message.contains("transfer-encoding=chunked")); assert!(message.contains("content-length=123")); } #[test] fn response_header_diagnostics_redacts_to_safe_header_subset() { let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); headers.insert(CONTENT_ENCODING, HeaderValue::from_static("identity")); headers.insert("authorization", HeaderValue::from_static("Bearer secret")); let diagnostics = response_header_diagnostics(&headers); assert_eq!(diagnostics["content_type"], "text/event-stream"); assert_eq!(diagnostics["content_encoding"], "identity"); assert!(diagnostics.get("authorization").is_none()); } #[test] fn request_body_shape_counts_reasoning_encrypted_content() { let payload = request_body_shape_payload(&json!({ "reasoning": { "summary": "auto" }, "input": [ { "type": "message", "role": "user", "content": [] }, { "type": "reasoning", "encrypted_content": "abc", "summary": [] }, { "type": "reasoning", "encrypted_content": "defgh", "summary": [] } ] })); assert_eq!(payload["items_len"], 3); assert_eq!(payload["reasoning_items"], 2); assert_eq!(payload["reasoning_encrypted_content_count"], 2); assert_eq!(payload["reasoning_encrypted_content_bytes"], 8); assert!(payload["items_json_bytes"].as_u64().unwrap() > 0); } #[tokio::test] async fn response_timeout_returns_retryable_lifecycle_timeout() { let err = response_with_timeout( std::future::pending::>(), Duration::from_millis(5), "stream_open", ) .await .unwrap_err(); assert!(crate::llm_client::error::is_retryable(&err)); assert!(matches!( err, ClientError::Timeout { phase: "stream_open", .. } )); } #[tokio::test] async fn codex_backend_adds_conversation_headers_and_zstd_body() { let transport = transport(ResolvedAuth::Custom(Arc::new(TestAuthProvider { codex: true, }))); let request = Request::new().user("hello").cache_key("segment-123"); let mut headers = transport.build_headers().await.unwrap(); transport .apply_stream_headers(&mut headers, &request) .unwrap(); let body = transport.scheme.build_request_body( &transport.model_id, &request, &transport.capability, ); let encoded = transport.encode_request_body(&body, &mut headers).unwrap(); assert_eq!(headers.get(ACCEPT).unwrap(), "text/event-stream"); assert_eq!(headers.get("session_id").unwrap(), "segment-123"); assert_eq!(headers.get("x-client-request-id").unwrap(), "segment-123"); assert_eq!(headers.get(CONTENT_ENCODING).unwrap(), "zstd"); let RequestBody::CompressedJson { bytes: compressed, raw_json_bytes, } = encoded else { panic!("Codex backend request body must be zstd-compressed"); }; assert!(raw_json_bytes > 0); let decoded = zstd::stream::decode_all(std::io::Cursor::new(compressed)).unwrap(); let decoded: serde_json::Value = serde_json::from_slice(&decoded).unwrap(); assert_eq!(decoded["prompt_cache_key"], "segment-123"); } #[tokio::test] async fn non_codex_request_does_not_get_codex_only_headers_or_compression() { let transport = transport(ResolvedAuth::ApiKey("api-key".to_string())); let request = Request::new().user("hello").cache_key("segment-123"); let mut headers = transport.build_headers().await.unwrap(); transport .apply_stream_headers(&mut headers, &request) .unwrap(); let body = transport.scheme.build_request_body( &transport.model_id, &request, &transport.capability, ); let encoded = transport.encode_request_body(&body, &mut headers).unwrap(); assert_eq!(headers.get(ACCEPT).unwrap(), "text/event-stream"); assert!(headers.get("session_id").is_none()); assert!(headers.get("x-client-request-id").is_none()); assert!(headers.get(CONTENT_ENCODING).is_none()); let RequestBody::Json(decoded) = encoded else { panic!("non-Codex request body must remain normal JSON"); }; assert_eq!(decoded["prompt_cache_key"], "segment-123"); } }