486 lines
17 KiB
Rust
486 lines
17 KiB
Rust
//! `HttpTransport<S: Scheme>`: すべての LLM wire scheme を共通の 1 本の
|
||
//! HTTP クライアントで扱う。
|
||
//!
|
||
//! 旧 `providers/{anthropic,openai,gemini,ollama}.rs` を置き換える。
|
||
//! scheme 固有の差分は [`Scheme`] trait 実装に委譲する。
|
||
|
||
use std::pin::Pin;
|
||
use std::sync::Arc;
|
||
use std::time::Duration;
|
||
|
||
use async_trait::async_trait;
|
||
use eventsource_stream::Eventsource;
|
||
use futures::{Stream, StreamExt, TryStreamExt};
|
||
use reqwest::header::{
|
||
ACCEPT, CONTENT_ENCODING, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, RETRY_AFTER,
|
||
};
|
||
|
||
use super::auth::{AuthProvider, AuthRequirement};
|
||
use super::capability::ModelCapability;
|
||
use super::client::{ConfigWarning, LlmClient, ResponseStream};
|
||
use super::error::ClientError;
|
||
use super::event::Event;
|
||
use super::scheme::Scheme;
|
||
use super::types::{Request, RequestConfig};
|
||
|
||
/// Time limit that `stream` places on sending the request and receiving the
/// initial HTTP response (passed to `response_with_timeout` with the
/// `"stream_open"` phase).
pub const DEFAULT_STREAM_OPEN_TIMEOUT: Duration = Duration::from_secs(30);
|
||
/// NOTE(review): presumably the time limit for the first event of an opened
/// stream; it is not referenced anywhere in this file — confirm against callers.
pub const DEFAULT_FIRST_STREAM_EVENT_TIMEOUT: Duration = Duration::from_secs(30);
|
||
|
||
/// `AuthRef` を解決したランタイム表現。`crates/provider` が構築する。
|
||
///
|
||
/// - `None`: 認証ヘッダを送らない(Ollama 等の opt-out)
|
||
/// - `ApiKey`: 静的な API key 文字列
|
||
/// - `Custom`: リクエスト毎に動的にヘッダを組み立てる(Codex OAuth 等)
|
||
#[derive(Debug, Clone)]
|
||
pub enum ResolvedAuth {
|
||
None,
|
||
ApiKey(String),
|
||
Custom(Arc<dyn AuthProvider>),
|
||
}
|
||
|
||
impl ResolvedAuth {
|
||
/// 認証要件と実際の解決値が噛み合うか検査する。構築時検証用。
|
||
///
|
||
/// - `ResolvedAuth::None` は認証を付けない宣言なので、どの
|
||
/// `AuthRequirement` でも受け入れる(Ollama の Anthropic scheme
|
||
/// 流用は `required_auth = XApiKey` だが認証ヘッダなしで動く)
|
||
/// - `ResolvedAuth::Custom` は「ヘッダ組立を全部こちらで行う」
|
||
/// 宣言なので、scheme が要求する形式によらず受け入れる
|
||
pub fn matches(&self, req: AuthRequirement) -> bool {
|
||
match (self, req) {
|
||
(Self::None, _) => true,
|
||
(Self::Custom(_), _) => true,
|
||
(
|
||
Self::ApiKey(_),
|
||
AuthRequirement::Bearer
|
||
| AuthRequirement::XApiKey
|
||
| AuthRequirement::QueryParam { .. },
|
||
) => true,
|
||
_ => false,
|
||
}
|
||
}
|
||
}
|
||
|
||
/// HTTP communication layer shared by all schemes.
|
||
pub struct HttpTransport<S: Scheme> {
|
||
/// Client used by `stream` to send the POST request; replaceable through
/// `with_http_client`.
http_client: reqwest::Client,
|
||
/// Wire scheme that supplies the request path, the authentication
/// requirement, the request body and the SSE parsing.
scheme: S,
|
||
/// Model identifier handed to the scheme when building the path and body.
model_id: String,
|
||
/// Base URL with trailing slashes removed (normalized in `new`).
base_url: String,
|
||
/// Resolved credentials used when building the URL and headers.
auth: ResolvedAuth,
|
||
/// Model capability handed to the scheme when building the request body.
capability: ModelCapability,
|
||
}
|
||
|
||
impl<S: Scheme> HttpTransport<S> {
|
||
/// 新しい transport を作る。`base_url` は末尾スラッシュの有無を
|
||
/// どちらでも受け付ける(内部で正規化)。
|
||
pub fn new(
|
||
scheme: S,
|
||
model_id: impl Into<String>,
|
||
base_url: impl Into<String>,
|
||
auth: ResolvedAuth,
|
||
capability: ModelCapability,
|
||
) -> Self {
|
||
let base_url = base_url.into();
|
||
let base_url = base_url.trim_end_matches('/').to_string();
|
||
Self {
|
||
http_client: reqwest::Client::new(),
|
||
scheme,
|
||
model_id: model_id.into(),
|
||
base_url,
|
||
auth,
|
||
capability,
|
||
}
|
||
}
|
||
|
||
/// カスタム HTTP クライアントを差し込む(テスト等)。
|
||
pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
|
||
self.http_client = client;
|
||
self
|
||
}
|
||
|
||
fn build_url(&self) -> String {
|
||
let path = self.scheme.path(&self.model_id);
|
||
let url = format!("{}{}", self.base_url, path);
|
||
// Gemini のようにクエリパラメータで認証する場合は URL にキーを追記する
|
||
if let (AuthRequirement::QueryParam { name }, ResolvedAuth::ApiKey(key)) =
|
||
(self.scheme.required_auth(), &self.auth)
|
||
{
|
||
let sep = if url.contains('?') { '&' } else { '?' };
|
||
format!("{url}{sep}{name}={key}")
|
||
} else {
|
||
url
|
||
}
|
||
}
|
||
|
||
async fn build_headers(&self) -> Result<HeaderMap, ClientError> {
|
||
let mut headers = HeaderMap::new();
|
||
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
|
||
|
||
match (&self.auth, self.scheme.required_auth()) {
|
||
(ResolvedAuth::None, _) | (_, AuthRequirement::None) => {}
|
||
(ResolvedAuth::Custom(provider), _) => {
|
||
for (name, mut value) in provider.headers().await? {
|
||
value.set_sensitive(true);
|
||
headers.insert(name, value);
|
||
}
|
||
}
|
||
(ResolvedAuth::ApiKey(key), AuthRequirement::Bearer) => {
|
||
let mut val = HeaderValue::from_str(&format!("Bearer {key}"))
|
||
.map_err(|e| ClientError::Config(format!("invalid api key: {e}")))?;
|
||
val.set_sensitive(true);
|
||
headers.insert("Authorization", val);
|
||
}
|
||
(ResolvedAuth::ApiKey(key), AuthRequirement::XApiKey) => {
|
||
let mut val = HeaderValue::from_str(key.as_str())
|
||
.map_err(|e| ClientError::Config(format!("invalid api key: {e}")))?;
|
||
val.set_sensitive(true);
|
||
headers.insert("x-api-key", val);
|
||
}
|
||
(_, AuthRequirement::QueryParam { .. }) => {
|
||
// クエリパラメータは `build_url` で付与済み
|
||
}
|
||
(ResolvedAuth::ApiKey(_), AuthRequirement::Custom) => {
|
||
// scheme が Custom を要求する組合せに ApiKey は流れてこない想定
|
||
// (`matches()` で弾かれる)。安全側で何もしない
|
||
}
|
||
}
|
||
|
||
for (name, value) in self.scheme.additional_headers() {
|
||
let hv = HeaderValue::from_str(&value)
|
||
.map_err(|e| ClientError::Config(format!("invalid header {name}: {e}")))?;
|
||
headers.insert(name, hv);
|
||
}
|
||
|
||
Ok(headers)
|
||
}
|
||
|
||
fn is_codex_backend(&self) -> bool {
|
||
match &self.auth {
|
||
ResolvedAuth::Custom(provider) => provider.is_codex_backend(),
|
||
_ => false,
|
||
}
|
||
}
|
||
|
||
fn apply_stream_headers(
|
||
&self,
|
||
headers: &mut HeaderMap,
|
||
request: &Request,
|
||
) -> Result<(), ClientError> {
|
||
headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
|
||
|
||
if self.is_codex_backend()
|
||
&& let Some(cache_key) = request.cache_key.as_deref()
|
||
{
|
||
let value = HeaderValue::from_str(cache_key).map_err(|e| {
|
||
ClientError::Config(format!("invalid Codex conversation header: {e}"))
|
||
})?;
|
||
headers.insert(HeaderName::from_static("session_id"), value.clone());
|
||
headers.insert(HeaderName::from_static("x-client-request-id"), value);
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
fn encode_request_body(
|
||
&self,
|
||
body: &serde_json::Value,
|
||
headers: &mut HeaderMap,
|
||
) -> Result<RequestBody, ClientError> {
|
||
if !self.is_codex_backend() {
|
||
return Ok(RequestBody::Json(body.clone()));
|
||
}
|
||
|
||
let raw = serde_json::to_vec(body)?;
|
||
let compressed = zstd::stream::encode_all(std::io::Cursor::new(raw), 3)
|
||
.map_err(|e| ClientError::Config(format!("failed to zstd-compress request: {e}")))?;
|
||
headers.insert(CONTENT_ENCODING, HeaderValue::from_static("zstd"));
|
||
Ok(RequestBody::CompressedJson(compressed))
|
||
}
|
||
}
|
||
|
||
/// Request payload as produced by `encode_request_body`.
enum RequestBody {
|
||
/// Plain JSON value; `stream` sends it with `RequestBuilder::json`.
Json(serde_json::Value),
|
||
/// zstd-compressed serialized JSON; `stream` sends it as a raw body.
CompressedJson(Vec<u8>),
|
||
}
|
||
|
||
async fn response_with_timeout(
|
||
future: impl std::future::Future<Output = Result<reqwest::Response, reqwest::Error>>,
|
||
timeout: Duration,
|
||
phase: &'static str,
|
||
) -> Result<reqwest::Response, ClientError> {
|
||
tokio::time::timeout(timeout, future)
|
||
.await
|
||
.map_err(|_| ClientError::Timeout { phase, timeout })?
|
||
.map_err(ClientError::Http)
|
||
}
|
||
|
||
impl<S: Scheme + Clone> Clone for HttpTransport<S> {
    /// Clones every field individually; only `S: Clone` is required of the
    /// scheme.
    fn clone(&self) -> Self {
        let Self {
            http_client,
            scheme,
            model_id,
            base_url,
            auth,
            capability,
        } = self;

        Self {
            http_client: http_client.clone(),
            scheme: scheme.clone(),
            model_id: model_id.clone(),
            base_url: base_url.clone(),
            auth: auth.clone(),
            capability: capability.clone(),
        }
    }
}
|
||
|
||
/// エラーレスポンスを `ClientError::Api` に変換する。
|
||
async fn classify_error_response(resp: reqwest::Response) -> ClientError {
|
||
let status = resp.status().as_u16();
|
||
let retry_after = resp
|
||
.headers()
|
||
.get(RETRY_AFTER)
|
||
.and_then(|v| v.to_str().ok())
|
||
.and_then(|s| s.trim().parse::<u64>().ok())
|
||
.map(Duration::from_secs);
|
||
let text = resp.text().await.unwrap_or_default();
|
||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
|
||
let error = json.get("error").unwrap_or(&json);
|
||
let code = error.get("type").and_then(|v| v.as_str()).map(String::from);
|
||
let message = error
|
||
.get("message")
|
||
.and_then(|v| v.as_str())
|
||
.unwrap_or(&text)
|
||
.to_string();
|
||
ClientError::Api {
|
||
status: Some(status),
|
||
code,
|
||
message,
|
||
retry_after,
|
||
}
|
||
} else {
|
||
ClientError::Api {
|
||
status: Some(status),
|
||
code: None,
|
||
message: text,
|
||
retry_after,
|
||
}
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl<S: Scheme + Clone + 'static> LlmClient for HttpTransport<S> {
|
||
fn clone_boxed(&self) -> Box<dyn LlmClient> {
|
||
Box::new(self.clone())
|
||
}
|
||
|
||
fn validate_config(&self, config: &RequestConfig) -> Vec<ConfigWarning> {
|
||
self.scheme.validate_config(config)
|
||
}
|
||
|
||
async fn stream(&self, request: Request) -> Result<ResponseStream, ClientError> {
|
||
let url = self.build_url();
|
||
let mut headers = self.build_headers().await?;
|
||
self.apply_stream_headers(&mut headers, &request)?;
|
||
let body = self
|
||
.scheme
|
||
.build_request_body(&self.model_id, &request, &self.capability);
|
||
let request_body = self.encode_request_body(&body, &mut headers)?;
|
||
|
||
let builder = self.http_client.post(&url).headers(headers);
|
||
let builder = match request_body {
|
||
RequestBody::Json(body) => builder.json(&body),
|
||
RequestBody::CompressedJson(body) => builder.body(body),
|
||
};
|
||
let response =
|
||
response_with_timeout(builder.send(), DEFAULT_STREAM_OPEN_TIMEOUT, "stream_open")
|
||
.await?;
|
||
|
||
if !response.status().is_success() {
|
||
return Err(classify_error_response(response).await);
|
||
}
|
||
|
||
let scheme = self.scheme.clone();
|
||
let byte_stream = response.bytes_stream().map_err(std::io::Error::other);
|
||
let event_stream = byte_stream.eventsource();
|
||
|
||
// scheme 固有のパース状態をストリーム単位で保持する
|
||
let mut state = <S::State as Default>::default();
|
||
|
||
let stream = event_stream
|
||
.map(move |result| match result {
|
||
Ok(frame) => match scheme.parse_sse(&frame.event, &frame.data, &mut state) {
|
||
Ok(events) => Ok(events),
|
||
Err(e) => Err(e),
|
||
},
|
||
Err(e) => Err(ClientError::Sse(e.to_string())),
|
||
})
|
||
.map(|res| {
|
||
let s: Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>> = match res {
|
||
Ok(events) => Box::pin(futures::stream::iter(events.into_iter().map(Ok))),
|
||
Err(e) => Box::pin(futures::stream::once(async move { Err(e) })),
|
||
};
|
||
s
|
||
})
|
||
.flatten();
|
||
|
||
Ok(Box::pin(stream))
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Auth provider stub that returns two fixed headers; `codex` controls
    /// what `is_codex_backend` reports.
    #[derive(Debug)]
    struct TestAuthProvider {
        codex: bool,
    }

    #[async_trait]
    impl AuthProvider for TestAuthProvider {
        async fn headers(&self) -> Result<Vec<(HeaderName, HeaderValue)>, ClientError> {
            let fixed = [
                ("authorization", "Bearer test-token"),
                ("chatgpt-account-id", "account-1"),
            ];
            Ok(fixed
                .into_iter()
                .map(|(name, value)| {
                    (
                        HeaderName::from_static(name),
                        HeaderValue::from_static(value),
                    )
                })
                .collect())
        }

        fn is_codex_backend(&self) -> bool {
            self.codex
        }
    }

    /// Minimal scheme: Bearer auth, a fixed path, and a body that echoes the
    /// model id, item count and cache key. It parses no events.
    #[derive(Clone)]
    struct TestScheme;

    impl Scheme for TestScheme {
        type State = ();

        fn default_base_url(&self) -> &'static str {
            "https://example.test"
        }

        fn path(&self, _model_id: &str) -> String {
            String::from("/responses")
        }

        fn required_auth(&self) -> AuthRequirement {
            AuthRequirement::Bearer
        }

        fn build_request_body(
            &self,
            model_id: &str,
            request: &Request,
            _capability: &ModelCapability,
        ) -> serde_json::Value {
            json!({
                "model": model_id,
                "input_len": request.items.len(),
                "prompt_cache_key": request.cache_key,
            })
        }

        fn parse_sse(
            &self,
            _event_type: &str,
            _data: &str,
            _state: &mut Self::State,
        ) -> Result<Vec<Event>, ClientError> {
            Ok(vec![])
        }

        fn default_capability(&self) -> ModelCapability {
            ModelCapability::minimal()
        }
    }

    /// Builds a transport over `TestScheme` with the given auth.
    fn transport(auth: ResolvedAuth) -> HttpTransport<TestScheme> {
        HttpTransport::new(
            TestScheme,
            "gpt-test",
            "https://example.test",
            auth,
            ModelCapability::minimal(),
        )
    }

    /// Runs the header and body preparation steps that `stream` performs
    /// before sending, and returns the resulting headers and encoded body.
    async fn prepare(
        transport: &HttpTransport<TestScheme>,
        request: &Request,
    ) -> (HeaderMap, RequestBody) {
        let mut headers = transport.build_headers().await.unwrap();
        transport
            .apply_stream_headers(&mut headers, request)
            .unwrap();
        let body = transport.scheme.build_request_body(
            &transport.model_id,
            request,
            &transport.capability,
        );
        let encoded = transport.encode_request_body(&body, &mut headers).unwrap();
        (headers, encoded)
    }

    #[tokio::test]
    async fn response_timeout_returns_retryable_lifecycle_timeout() {
        let never = std::future::pending::<Result<reqwest::Response, reqwest::Error>>();
        let err = response_with_timeout(never, Duration::from_millis(5), "stream_open")
            .await
            .unwrap_err();

        assert!(crate::llm_client::error::is_retryable(&err));
        assert!(matches!(
            err,
            ClientError::Timeout {
                phase: "stream_open",
                ..
            }
        ));
    }

    #[tokio::test]
    async fn codex_backend_adds_conversation_headers_and_zstd_body() {
        let provider = Arc::new(TestAuthProvider { codex: true });
        let codex = transport(ResolvedAuth::Custom(provider));
        let request = Request::new().user("hello").cache_key("segment-123");

        let (headers, encoded) = prepare(&codex, &request).await;

        assert_eq!(headers.get(ACCEPT).unwrap(), "text/event-stream");
        for name in ["session_id", "x-client-request-id"] {
            assert_eq!(headers.get(name).unwrap(), "segment-123");
        }
        assert_eq!(headers.get(CONTENT_ENCODING).unwrap(), "zstd");

        let compressed = match encoded {
            RequestBody::CompressedJson(bytes) => bytes,
            RequestBody::Json(_) => panic!("Codex backend request body must be zstd-compressed"),
        };
        let raw = zstd::stream::decode_all(compressed.as_slice()).unwrap();
        let decoded: serde_json::Value = serde_json::from_slice(&raw).unwrap();
        assert_eq!(decoded["prompt_cache_key"], "segment-123");
    }

    #[tokio::test]
    async fn non_codex_request_does_not_get_codex_only_headers_or_compression() {
        let plain = transport(ResolvedAuth::ApiKey(String::from("api-key")));
        let request = Request::new().user("hello").cache_key("segment-123");

        let (headers, encoded) = prepare(&plain, &request).await;

        assert_eq!(headers.get(ACCEPT).unwrap(), "text/event-stream");
        for name in ["session_id", "x-client-request-id"] {
            assert!(!headers.contains_key(name));
        }
        assert!(!headers.contains_key(CONTENT_ENCODING));

        let decoded = match encoded {
            RequestBody::Json(value) => value,
            RequestBody::CompressedJson(_) => {
                panic!("non-Codex request body must remain normal JSON")
            }
        };
        assert_eq!(decoded["prompt_cache_key"], "segment-123");
    }
}
|