//! `HttpTransport`: すべての LLM wire scheme を共通の 1 本の //! HTTP クライアントで扱う。 //! //! 旧 `providers/{anthropic,openai,gemini,ollama}.rs` を置き換える。 //! scheme 固有の差分は [`Scheme`] trait 実装に委譲する。 use std::pin::Pin; use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; use eventsource_stream::Eventsource; use futures::{Stream, StreamExt, TryStreamExt}; use reqwest::header::{ ACCEPT, CONTENT_ENCODING, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, RETRY_AFTER, }; use super::auth::{AuthProvider, AuthRequirement}; use super::capability::ModelCapability; use super::client::{ConfigWarning, LlmClient, ResponseStream}; use super::error::ClientError; use super::event::Event; use super::scheme::Scheme; use super::types::{Request, RequestConfig}; pub const DEFAULT_STREAM_OPEN_TIMEOUT: Duration = Duration::from_secs(30); pub const DEFAULT_FIRST_STREAM_EVENT_TIMEOUT: Duration = Duration::from_secs(30); /// `AuthRef` を解決したランタイム表現。`crates/provider` が構築する。 /// /// - `None`: 認証ヘッダを送らない(Ollama 等の opt-out) /// - `ApiKey`: 静的な API key 文字列 /// - `Custom`: リクエスト毎に動的にヘッダを組み立てる(Codex OAuth 等) #[derive(Debug, Clone)] pub enum ResolvedAuth { None, ApiKey(String), Custom(Arc), } impl ResolvedAuth { /// 認証要件と実際の解決値が噛み合うか検査する。構築時検証用。 /// /// - `ResolvedAuth::None` は認証を付けない宣言なので、どの /// `AuthRequirement` でも受け入れる(Ollama の Anthropic scheme /// 流用は `required_auth = XApiKey` だが認証ヘッダなしで動く) /// - `ResolvedAuth::Custom` は「ヘッダ組立を全部こちらで行う」 /// 宣言なので、scheme が要求する形式によらず受け入れる pub fn matches(&self, req: AuthRequirement) -> bool { match (self, req) { (Self::None, _) => true, (Self::Custom(_), _) => true, ( Self::ApiKey(_), AuthRequirement::Bearer | AuthRequirement::XApiKey | AuthRequirement::QueryParam { .. }, ) => true, _ => false, } } } /// scheme 共通の HTTP 通信層。 pub struct HttpTransport { http_client: reqwest::Client, scheme: S, model_id: String, base_url: String, auth: ResolvedAuth, capability: ModelCapability, } impl HttpTransport { /// 新しい transport を作る。`base_url` は末尾スラッシュの有無を /// どちらでも受け付ける(内部で正規化)。 pub fn new( scheme: S, model_id: impl Into, base_url: impl Into, auth: ResolvedAuth, capability: ModelCapability, ) -> Self { let base_url = base_url.into(); let base_url = base_url.trim_end_matches('/').to_string(); Self { http_client: reqwest::Client::new(), scheme, model_id: model_id.into(), base_url, auth, capability, } } /// カスタム HTTP クライアントを差し込む(テスト等)。 pub fn with_http_client(mut self, client: reqwest::Client) -> Self { self.http_client = client; self } fn build_url(&self) -> String { let path = self.scheme.path(&self.model_id); let url = format!("{}{}", self.base_url, path); // Gemini のようにクエリパラメータで認証する場合は URL にキーを追記する if let (AuthRequirement::QueryParam { name }, ResolvedAuth::ApiKey(key)) = (self.scheme.required_auth(), &self.auth) { let sep = if url.contains('?') { '&' } else { '?' }; format!("{url}{sep}{name}={key}") } else { url } } async fn build_headers(&self) -> Result { let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); match (&self.auth, self.scheme.required_auth()) { (ResolvedAuth::None, _) | (_, AuthRequirement::None) => {} (ResolvedAuth::Custom(provider), _) => { for (name, mut value) in provider.headers().await? { value.set_sensitive(true); headers.insert(name, value); } } (ResolvedAuth::ApiKey(key), AuthRequirement::Bearer) => { let mut val = HeaderValue::from_str(&format!("Bearer {key}")) .map_err(|e| ClientError::Config(format!("invalid api key: {e}")))?; val.set_sensitive(true); headers.insert("Authorization", val); } (ResolvedAuth::ApiKey(key), AuthRequirement::XApiKey) => { let mut val = HeaderValue::from_str(key.as_str()) .map_err(|e| ClientError::Config(format!("invalid api key: {e}")))?; val.set_sensitive(true); headers.insert("x-api-key", val); } (_, AuthRequirement::QueryParam { .. }) => { // クエリパラメータは `build_url` で付与済み } (ResolvedAuth::ApiKey(_), AuthRequirement::Custom) => { // scheme が Custom を要求する組合せに ApiKey は流れてこない想定 // (`matches()` で弾かれる)。安全側で何もしない } } for (name, value) in self.scheme.additional_headers() { let hv = HeaderValue::from_str(&value) .map_err(|e| ClientError::Config(format!("invalid header {name}: {e}")))?; headers.insert(name, hv); } Ok(headers) } fn is_codex_backend(&self) -> bool { match &self.auth { ResolvedAuth::Custom(provider) => provider.is_codex_backend(), _ => false, } } fn apply_stream_headers( &self, headers: &mut HeaderMap, request: &Request, ) -> Result<(), ClientError> { headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream")); if self.is_codex_backend() && let Some(cache_key) = request.cache_key.as_deref() { let value = HeaderValue::from_str(cache_key).map_err(|e| { ClientError::Config(format!("invalid Codex conversation header: {e}")) })?; headers.insert(HeaderName::from_static("session_id"), value.clone()); headers.insert(HeaderName::from_static("x-client-request-id"), value); } Ok(()) } fn encode_request_body( &self, body: &serde_json::Value, headers: &mut HeaderMap, ) -> Result { if !self.is_codex_backend() { return Ok(RequestBody::Json(body.clone())); } let raw = serde_json::to_vec(body)?; let compressed = zstd::stream::encode_all(std::io::Cursor::new(raw), 3) .map_err(|e| ClientError::Config(format!("failed to zstd-compress request: {e}")))?; headers.insert(CONTENT_ENCODING, HeaderValue::from_static("zstd")); Ok(RequestBody::CompressedJson(compressed)) } } enum RequestBody { Json(serde_json::Value), CompressedJson(Vec), } async fn response_with_timeout( future: impl std::future::Future>, timeout: Duration, phase: &'static str, ) -> Result { tokio::time::timeout(timeout, future) .await .map_err(|_| ClientError::Timeout { phase, timeout })? .map_err(ClientError::Http) } impl Clone for HttpTransport { fn clone(&self) -> Self { Self { http_client: self.http_client.clone(), scheme: self.scheme.clone(), model_id: self.model_id.clone(), base_url: self.base_url.clone(), auth: self.auth.clone(), capability: self.capability.clone(), } } } /// エラーレスポンスを `ClientError::Api` に変換する。 async fn classify_error_response(resp: reqwest::Response) -> ClientError { let status = resp.status().as_u16(); let retry_after = resp .headers() .get(RETRY_AFTER) .and_then(|v| v.to_str().ok()) .and_then(|s| s.trim().parse::().ok()) .map(Duration::from_secs); let text = resp.text().await.unwrap_or_default(); if let Ok(json) = serde_json::from_str::(&text) { let error = json.get("error").unwrap_or(&json); let code = error.get("type").and_then(|v| v.as_str()).map(String::from); let message = error .get("message") .and_then(|v| v.as_str()) .unwrap_or(&text) .to_string(); ClientError::Api { status: Some(status), code, message, retry_after, } } else { ClientError::Api { status: Some(status), code: None, message: text, retry_after, } } } #[async_trait] impl LlmClient for HttpTransport { fn clone_boxed(&self) -> Box { Box::new(self.clone()) } fn validate_config(&self, config: &RequestConfig) -> Vec { self.scheme.validate_config(config) } async fn stream(&self, request: Request) -> Result { let url = self.build_url(); let mut headers = self.build_headers().await?; self.apply_stream_headers(&mut headers, &request)?; let body = self .scheme .build_request_body(&self.model_id, &request, &self.capability); let request_body = self.encode_request_body(&body, &mut headers)?; let builder = self.http_client.post(&url).headers(headers); let builder = match request_body { RequestBody::Json(body) => builder.json(&body), RequestBody::CompressedJson(body) => builder.body(body), }; let response = response_with_timeout(builder.send(), DEFAULT_STREAM_OPEN_TIMEOUT, "stream_open") .await?; if !response.status().is_success() { return Err(classify_error_response(response).await); } let scheme = self.scheme.clone(); let byte_stream = response.bytes_stream().map_err(std::io::Error::other); let event_stream = byte_stream.eventsource(); // scheme 固有のパース状態をストリーム単位で保持する let mut state = ::default(); let stream = event_stream .map(move |result| match result { Ok(frame) => match scheme.parse_sse(&frame.event, &frame.data, &mut state) { Ok(events) => Ok(events), Err(e) => Err(e), }, Err(e) => Err(ClientError::Sse(e.to_string())), }) .map(|res| { let s: Pin> + Send>> = match res { Ok(events) => Box::pin(futures::stream::iter(events.into_iter().map(Ok))), Err(e) => Box::pin(futures::stream::once(async move { Err(e) })), }; s }) .flatten(); Ok(Box::pin(stream)) } } #[cfg(test)] mod tests { use super::*; use serde_json::json; #[derive(Debug)] struct TestAuthProvider { codex: bool, } #[async_trait] impl AuthProvider for TestAuthProvider { async fn headers(&self) -> Result, ClientError> { Ok(vec![ ( HeaderName::from_static("authorization"), HeaderValue::from_static("Bearer test-token"), ), ( HeaderName::from_static("chatgpt-account-id"), HeaderValue::from_static("account-1"), ), ]) } fn is_codex_backend(&self) -> bool { self.codex } } #[derive(Clone)] struct TestScheme; impl Scheme for TestScheme { type State = (); fn default_base_url(&self) -> &'static str { "https://example.test" } fn path(&self, _model_id: &str) -> String { "/responses".to_string() } fn required_auth(&self) -> AuthRequirement { AuthRequirement::Bearer } fn build_request_body( &self, model_id: &str, request: &Request, _capability: &ModelCapability, ) -> serde_json::Value { json!({ "model": model_id, "input_len": request.items.len(), "prompt_cache_key": request.cache_key, }) } fn parse_sse( &self, _event_type: &str, _data: &str, _state: &mut Self::State, ) -> Result, ClientError> { Ok(Vec::new()) } fn default_capability(&self) -> ModelCapability { ModelCapability::minimal() } } fn transport(auth: ResolvedAuth) -> HttpTransport { HttpTransport::new( TestScheme, "gpt-test", "https://example.test", auth, ModelCapability::minimal(), ) } #[tokio::test] async fn response_timeout_returns_retryable_lifecycle_timeout() { let err = response_with_timeout( std::future::pending::>(), Duration::from_millis(5), "stream_open", ) .await .unwrap_err(); assert!(crate::llm_client::error::is_retryable(&err)); assert!(matches!( err, ClientError::Timeout { phase: "stream_open", .. } )); } #[tokio::test] async fn codex_backend_adds_conversation_headers_and_zstd_body() { let transport = transport(ResolvedAuth::Custom(Arc::new(TestAuthProvider { codex: true, }))); let request = Request::new().user("hello").cache_key("segment-123"); let mut headers = transport.build_headers().await.unwrap(); transport .apply_stream_headers(&mut headers, &request) .unwrap(); let body = transport.scheme.build_request_body( &transport.model_id, &request, &transport.capability, ); let encoded = transport.encode_request_body(&body, &mut headers).unwrap(); assert_eq!(headers.get(ACCEPT).unwrap(), "text/event-stream"); assert_eq!(headers.get("session_id").unwrap(), "segment-123"); assert_eq!(headers.get("x-client-request-id").unwrap(), "segment-123"); assert_eq!(headers.get(CONTENT_ENCODING).unwrap(), "zstd"); let RequestBody::CompressedJson(compressed) = encoded else { panic!("Codex backend request body must be zstd-compressed"); }; let decoded = zstd::stream::decode_all(std::io::Cursor::new(compressed)).unwrap(); let decoded: serde_json::Value = serde_json::from_slice(&decoded).unwrap(); assert_eq!(decoded["prompt_cache_key"], "segment-123"); } #[tokio::test] async fn non_codex_request_does_not_get_codex_only_headers_or_compression() { let transport = transport(ResolvedAuth::ApiKey("api-key".to_string())); let request = Request::new().user("hello").cache_key("segment-123"); let mut headers = transport.build_headers().await.unwrap(); transport .apply_stream_headers(&mut headers, &request) .unwrap(); let body = transport.scheme.build_request_body( &transport.model_id, &request, &transport.capability, ); let encoded = transport.encode_request_body(&body, &mut headers).unwrap(); assert_eq!(headers.get(ACCEPT).unwrap(), "text/event-stream"); assert!(headers.get("session_id").is_none()); assert!(headers.get("x-client-request-id").is_none()); assert!(headers.get(CONTENT_ENCODING).is_none()); let RequestBody::Json(decoded) = encoded else { panic!("non-Codex request body must remain normal JSON"); }; assert_eq!(decoded["prompt_cache_key"], "segment-123"); } }