yoi/crates/llm-worker/src/llm_client/transport.rs

486 lines
17 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! `HttpTransport<S: Scheme>`: すべての LLM wire scheme を共通の 1 本の
//! HTTP クライアントで扱う。
//!
//! 旧 `providers/{anthropic,openai,gemini,ollama}.rs` を置き換える。
//! scheme 固有の差分は [`Scheme`] trait 実装に委譲する。
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use eventsource_stream::Eventsource;
use futures::{Stream, StreamExt, TryStreamExt};
use reqwest::header::{
ACCEPT, CONTENT_ENCODING, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, RETRY_AFTER,
};
use super::auth::{AuthProvider, AuthRequirement};
use super::capability::ModelCapability;
use super::client::{ConfigWarning, LlmClient, ResponseStream};
use super::error::ClientError;
use super::event::Event;
use super::scheme::Scheme;
use super::types::{Request, RequestConfig};
/// Upper bound on how long `HttpTransport::stream` waits for the HTTP response
/// to a streaming request; exceeding it is reported as a timeout in the
/// `"stream_open"` phase.
pub const DEFAULT_STREAM_OPEN_TIMEOUT: Duration = Duration::from_secs(30);
/// Default timeout for the first event of a stream.
///
/// NOTE(review): not referenced by the code visible in this file — presumably
/// applied by whoever consumes the stream; confirm against callers.
pub const DEFAULT_FIRST_STREAM_EVENT_TIMEOUT: Duration = Duration::from_secs(30);
/// Runtime representation of a resolved `AuthRef`; constructed by `crates/provider`.
///
/// - `None`: send no authentication header (opt-out, e.g. Ollama)
/// - `ApiKey`: a static API key string
/// - `Custom`: headers are assembled dynamically for every request
///   (e.g. Codex OAuth)
///
/// `Debug` is implemented by hand instead of derived so that the API key is
/// never written to logs or panic messages: the `ApiKey` payload is rendered
/// as `<redacted>`.
#[derive(Clone)]
pub enum ResolvedAuth {
    None,
    ApiKey(String),
    Custom(Arc<dyn AuthProvider>),
}

impl std::fmt::Debug for ResolvedAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::None => f.write_str("None"),
            // Never print the secret itself.
            Self::ApiKey(_) => f.write_str("ApiKey(<redacted>)"),
            // Delegates to the provider's own `Debug`, exactly as the derive did.
            Self::Custom(provider) => f.debug_tuple("Custom").field(provider).finish(),
        }
    }
}
impl ResolvedAuth {
    /// Checks whether this resolved value is compatible with the auth
    /// requirement declared by a scheme. Intended for validation at
    /// construction time.
    ///
    /// - `ResolvedAuth::None` declares "attach no auth", so it is accepted for
    ///   every `AuthRequirement` (e.g. Ollama reusing the Anthropic scheme has
    ///   `required_auth = XApiKey` but works without an auth header).
    /// - `ResolvedAuth::Custom` declares "header assembly is entirely handled
    ///   by the provider", so it is accepted regardless of what the scheme asks for.
    /// - `ResolvedAuth::ApiKey` is accepted only where a plain key can be
    ///   placed: bearer token, `x-api-key` header, or a query parameter.
    pub fn matches(&self, req: AuthRequirement) -> bool {
        match self {
            Self::None | Self::Custom(_) => true,
            Self::ApiKey(_) => matches!(
                req,
                AuthRequirement::Bearer
                    | AuthRequirement::XApiKey
                    | AuthRequirement::QueryParam { .. }
            ),
        }
    }
}
/// HTTP communication layer shared by every scheme.
///
/// Everything scheme-specific (request path, required auth, request body
/// layout, SSE parsing) is delegated to the `S: Scheme` implementation.
pub struct HttpTransport<S: Scheme> {
// HTTP client used to send requests; replaceable via `with_http_client`.
http_client: reqwest::Client,
// Wire-scheme implementation the transport delegates to.
scheme: S,
// Model identifier handed to the scheme when building the path and the body.
model_id: String,
// Base URL; `new` stores it with trailing slashes removed.
base_url: String,
// Resolved authentication used when building the URL and the headers.
auth: ResolvedAuth,
// Model capability handed to the scheme when building the request body.
capability: ModelCapability,
}
impl<S: Scheme> HttpTransport<S> {
/// Creates a new transport backed by a fresh `reqwest::Client`.
///
/// `base_url` may be given with or without trailing slashes; they are
/// stripped here so that the scheme path can simply be appended.
pub fn new(
    scheme: S,
    model_id: impl Into<String>,
    base_url: impl Into<String>,
    auth: ResolvedAuth,
    capability: ModelCapability,
) -> Self {
    // Normalize in place: cut the string at the end of its slash-free prefix.
    let mut base_url: String = base_url.into();
    let kept_len = base_url.trim_end_matches('/').len();
    base_url.truncate(kept_len);

    Self {
        http_client: reqwest::Client::new(),
        scheme,
        model_id: model_id.into(),
        base_url,
        auth,
        capability,
    }
}
/// カスタム HTTP クライアントを差し込む(テスト等)。
pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
self.http_client = client;
self
}
/// Builds the request URL: normalized base URL followed by the scheme's path
/// for this model.
///
/// When the scheme authenticates through a query parameter (Gemini-style)
/// and a static API key is configured, the key is appended as
/// `name=<key>`, using `&` if the path already carries a query string and
/// `?` otherwise. The key is percent-encoded so that characters such as
/// `&`, `#`, `+`, `%` or spaces cannot corrupt the query string; keys made
/// only of unreserved characters are emitted unchanged.
fn build_url(&self) -> String {
    // Percent-encodes every byte outside the RFC 3986 "unreserved" set
    // (ALPHA / DIGIT / "-" / "." / "_" / "~").
    fn encode_query_value(raw: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut out = String::with_capacity(raw.len());
        for &byte in raw.as_bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                out.push(char::from(byte));
            } else {
                out.push('%');
                out.push(char::from(HEX[usize::from(byte >> 4)]));
                out.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
        out
    }

    let path = self.scheme.path(&self.model_id);
    let url = format!("{}{}", self.base_url, path);
    // Schemes that authenticate via query parameter get the key appended here.
    if let (AuthRequirement::QueryParam { name }, ResolvedAuth::ApiKey(key)) =
        (self.scheme.required_auth(), &self.auth)
    {
        let sep = if url.contains('?') { '&' } else { '?' };
        let key = encode_query_value(key);
        format!("{url}{sep}{name}={key}")
    } else {
        url
    }
}
async fn build_headers(&self) -> Result<HeaderMap, ClientError> {
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
match (&self.auth, self.scheme.required_auth()) {
(ResolvedAuth::None, _) | (_, AuthRequirement::None) => {}
(ResolvedAuth::Custom(provider), _) => {
for (name, mut value) in provider.headers().await? {
value.set_sensitive(true);
headers.insert(name, value);
}
}
(ResolvedAuth::ApiKey(key), AuthRequirement::Bearer) => {
let mut val = HeaderValue::from_str(&format!("Bearer {key}"))
.map_err(|e| ClientError::Config(format!("invalid api key: {e}")))?;
val.set_sensitive(true);
headers.insert("Authorization", val);
}
(ResolvedAuth::ApiKey(key), AuthRequirement::XApiKey) => {
let mut val = HeaderValue::from_str(key.as_str())
.map_err(|e| ClientError::Config(format!("invalid api key: {e}")))?;
val.set_sensitive(true);
headers.insert("x-api-key", val);
}
(_, AuthRequirement::QueryParam { .. }) => {
// クエリパラメータは `build_url` で付与済み
}
(ResolvedAuth::ApiKey(_), AuthRequirement::Custom) => {
// scheme が Custom を要求する組合せに ApiKey は流れてこない想定
// `matches()` で弾かれる)。安全側で何もしない
}
}
for (name, value) in self.scheme.additional_headers() {
let hv = HeaderValue::from_str(&value)
.map_err(|e| ClientError::Config(format!("invalid header {name}: {e}")))?;
headers.insert(name, hv);
}
Ok(headers)
}
/// Reports whether requests go to a Codex backend: true only when auth is a
/// custom provider that identifies itself as one.
fn is_codex_backend(&self) -> bool {
    matches!(&self.auth, ResolvedAuth::Custom(provider) if provider.is_codex_backend())
}
/// Adds the headers specific to a streaming request.
///
/// `Accept: text/event-stream` is always set. For a Codex backend whose
/// request carries a cache key, that key is additionally sent as both
/// `session_id` and `x-client-request-id`.
///
/// # Errors
/// Returns `ClientError::Config` if the cache key is not a valid header value.
fn apply_stream_headers(
    &self,
    headers: &mut HeaderMap,
    request: &Request,
) -> Result<(), ClientError> {
    headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));

    // Conversation headers only apply to Codex backends...
    if !self.is_codex_backend() {
        return Ok(());
    }
    // ...and only when the request actually has a cache key.
    let Some(cache_key) = request.cache_key.as_deref() else {
        return Ok(());
    };

    let conversation_id = HeaderValue::from_str(cache_key)
        .map_err(|e| ClientError::Config(format!("invalid Codex conversation header: {e}")))?;
    headers.insert(
        HeaderName::from_static("session_id"),
        conversation_id.clone(),
    );
    headers.insert(
        HeaderName::from_static("x-client-request-id"),
        conversation_id,
    );
    Ok(())
}
/// Chooses the wire encoding of the request body.
///
/// Codex backends receive the JSON serialized and zstd-compressed (level 3),
/// with `Content-Encoding: zstd` added to `headers`; every other backend
/// gets the JSON value untouched and `headers` is left as is.
///
/// # Errors
/// Propagates JSON serialization failures and returns `ClientError::Config`
/// when compression fails.
fn encode_request_body(
    &self,
    body: &serde_json::Value,
    headers: &mut HeaderMap,
) -> Result<RequestBody, ClientError> {
    if self.is_codex_backend() {
        let serialized = serde_json::to_vec(body)?;
        let compressed = zstd::stream::encode_all(serialized.as_slice(), 3)
            .map_err(|e| ClientError::Config(format!("failed to zstd-compress request: {e}")))?;
        headers.insert(CONTENT_ENCODING, HeaderValue::from_static("zstd"));
        Ok(RequestBody::CompressedJson(compressed))
    } else {
        Ok(RequestBody::Json(body.clone()))
    }
}
}
/// Request payload in the form it will be handed to the HTTP client,
/// as produced by `encode_request_body`.
enum RequestBody {
// Plain JSON value; sent through the request builder's `json` method.
Json(serde_json::Value),
// zstd-compressed JSON bytes; sent as a raw body (Codex backends).
CompressedJson(Vec<u8>),
}
/// Awaits an HTTP response future, giving up after `timeout`.
///
/// # Errors
/// - `ClientError::Timeout { phase, timeout }` when the deadline elapses first
/// - `ClientError::Http` when the request itself fails
async fn response_with_timeout(
    future: impl std::future::Future<Output = Result<reqwest::Response, reqwest::Error>>,
    timeout: Duration,
    phase: &'static str,
) -> Result<reqwest::Response, ClientError> {
    match tokio::time::timeout(timeout, future).await {
        Ok(Ok(response)) => Ok(response),
        Ok(Err(http_error)) => Err(ClientError::Http(http_error)),
        Err(_elapsed) => Err(ClientError::Timeout { phase, timeout }),
    }
}
/// Field-by-field clone, available whenever the scheme itself is `Clone`.
impl<S: Scheme + Clone> Clone for HttpTransport<S> {
    fn clone(&self) -> Self {
        // Destructure so that adding a field forces this impl to be revisited.
        let Self {
            http_client,
            scheme,
            model_id,
            base_url,
            auth,
            capability,
        } = self;
        Self {
            http_client: http_client.clone(),
            scheme: scheme.clone(),
            model_id: model_id.clone(),
            base_url: base_url.clone(),
            auth: auth.clone(),
            capability: capability.clone(),
        }
    }
}
/// Converts a non-success HTTP response into `ClientError::Api`.
///
/// - `retry_after` is taken from the `Retry-After` header when it holds a
///   whole number of seconds; any other form is ignored.
/// - If the body is JSON, the error object is `body["error"]` when present,
///   otherwise the body itself; `code` comes from its `"type"` string and
///   `message` from its `"message"` string, falling back to the raw body text.
/// - If the body is not JSON (or could not be read), the raw text becomes the
///   message and `code` is `None`.
async fn classify_error_response(resp: reqwest::Response) -> ClientError {
    let status = Some(resp.status().as_u16());
    let retry_after = resp
        .headers()
        .get(RETRY_AFTER)
        .and_then(|raw| raw.to_str().ok())
        .and_then(|seconds| seconds.trim().parse::<u64>().ok())
        .map(Duration::from_secs);

    // Reading the body consumes the response, so headers were inspected first.
    let body = resp.text().await.unwrap_or_default();

    let (code, message) = match serde_json::from_str::<serde_json::Value>(&body) {
        Ok(json) => {
            let error = json.get("error").unwrap_or(&json);
            let code = error
                .get("type")
                .and_then(serde_json::Value::as_str)
                .map(str::to_owned);
            let message = error
                .get("message")
                .and_then(serde_json::Value::as_str)
                .unwrap_or(body.as_str())
                .to_owned();
            (code, message)
        }
        Err(_) => (None, body),
    };

    ClientError::Api {
        status,
        code,
        message,
        retry_after,
    }
}
#[async_trait]
impl<S: Scheme + Clone + 'static> LlmClient for HttpTransport<S> {
    /// Returns a boxed copy of this transport.
    fn clone_boxed(&self) -> Box<dyn LlmClient> {
        Box::new(self.clone())
    }

    /// Delegates configuration validation to the scheme.
    fn validate_config(&self, config: &RequestConfig) -> Vec<ConfigWarning> {
        self.scheme.validate_config(config)
    }

    /// Sends `request` and returns the scheme-parsed event stream.
    ///
    /// The request is POSTed with the headers and body prepared by the helper
    /// methods; opening the response is bounded by
    /// `DEFAULT_STREAM_OPEN_TIMEOUT`. A non-success status is turned into an
    /// error via `classify_error_response`. Otherwise the body is decoded as
    /// SSE and every frame is handed to the scheme's `parse_sse`; the events
    /// it yields are emitted in order, and an SSE or parse failure is emitted
    /// as a single `Err` item without ending the stream.
    async fn stream(&self, request: Request) -> Result<ResponseStream, ClientError> {
        let mut headers = self.build_headers().await?;
        self.apply_stream_headers(&mut headers, &request)?;

        let payload = self
            .scheme
            .build_request_body(&self.model_id, &request, &self.capability);
        let encoded = self.encode_request_body(&payload, &mut headers)?;

        let url = self.build_url();
        let pending = self.http_client.post(&url).headers(headers);
        let pending = match encoded {
            RequestBody::CompressedJson(bytes) => pending.body(bytes),
            RequestBody::Json(value) => pending.json(&value),
        };

        let response =
            response_with_timeout(pending.send(), DEFAULT_STREAM_OPEN_TIMEOUT, "stream_open")
                .await?;
        if !response.status().is_success() {
            return Err(classify_error_response(response).await);
        }

        let scheme = self.scheme.clone();
        // Scheme-specific parser state, kept per stream.
        let mut state = <S::State as Default>::default();

        let events = response
            .bytes_stream()
            .map_err(std::io::Error::other)
            .eventsource()
            .flat_map(move |frame| {
                let parsed = frame
                    .map_err(|e| ClientError::Sse(e.to_string()))
                    .and_then(|sse| scheme.parse_sse(&sse.event, &sse.data, &mut state));
                // One frame expands to zero or more events, or to one error.
                let items: Vec<Result<Event, ClientError>> = match parsed {
                    Ok(events) => events.into_iter().map(Ok).collect(),
                    Err(e) => vec![Err(e)],
                };
                futures::stream::iter(items)
            });
        Ok(Box::pin(events))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Cache key attached to every request built by these tests.
    const CACHE_KEY: &str = "segment-123";

    /// Auth provider stub returning two fixed headers; `codex` controls what
    /// `is_codex_backend` reports.
    #[derive(Debug)]
    struct TestAuthProvider {
        codex: bool,
    }

    #[async_trait]
    impl AuthProvider for TestAuthProvider {
        async fn headers(&self) -> Result<Vec<(HeaderName, HeaderValue)>, ClientError> {
            let authorization = (
                HeaderName::from_static("authorization"),
                HeaderValue::from_static("Bearer test-token"),
            );
            let account = (
                HeaderName::from_static("chatgpt-account-id"),
                HeaderValue::from_static("account-1"),
            );
            Ok(vec![authorization, account])
        }

        fn is_codex_backend(&self) -> bool {
            self.codex
        }
    }

    /// Minimal bearer-auth scheme whose body echoes the model id, the number
    /// of request items and the cache key, and whose SSE parser yields nothing.
    #[derive(Clone)]
    struct TestScheme;

    impl Scheme for TestScheme {
        type State = ();

        fn default_base_url(&self) -> &'static str {
            "https://example.test"
        }

        fn path(&self, _model_id: &str) -> String {
            String::from("/responses")
        }

        fn required_auth(&self) -> AuthRequirement {
            AuthRequirement::Bearer
        }

        fn build_request_body(
            &self,
            model_id: &str,
            request: &Request,
            _capability: &ModelCapability,
        ) -> serde_json::Value {
            json!({
                "model": model_id,
                "input_len": request.items.len(),
                "prompt_cache_key": request.cache_key,
            })
        }

        fn parse_sse(
            &self,
            _event_type: &str,
            _data: &str,
            _state: &mut Self::State,
        ) -> Result<Vec<Event>, ClientError> {
            Ok(vec![])
        }

        fn default_capability(&self) -> ModelCapability {
            ModelCapability::minimal()
        }
    }

    /// Builds a transport over `TestScheme` with the given auth.
    fn transport(auth: ResolvedAuth) -> HttpTransport<TestScheme> {
        HttpTransport::new(
            TestScheme,
            "gpt-test",
            "https://example.test",
            auth,
            ModelCapability::minimal(),
        )
    }

    /// Runs the header/body preparation steps of `stream` (everything up to,
    /// but not including, sending) for a fixed request.
    async fn prepare(client: &HttpTransport<TestScheme>) -> (HeaderMap, RequestBody) {
        let request = Request::new().user("hello").cache_key(CACHE_KEY);
        let mut headers = client.build_headers().await.unwrap();
        client
            .apply_stream_headers(&mut headers, &request)
            .unwrap();
        let body = client
            .scheme
            .build_request_body(&client.model_id, &request, &client.capability);
        let encoded = client.encode_request_body(&body, &mut headers).unwrap();
        (headers, encoded)
    }

    #[tokio::test]
    async fn response_timeout_returns_retryable_lifecycle_timeout() {
        // A future that never resolves forces the timeout branch.
        let never = std::future::pending::<Result<reqwest::Response, reqwest::Error>>();
        let err = response_with_timeout(never, Duration::from_millis(5), "stream_open")
            .await
            .unwrap_err();

        assert!(crate::llm_client::error::is_retryable(&err));
        assert!(matches!(
            err,
            ClientError::Timeout {
                phase: "stream_open",
                ..
            }
        ));
    }

    #[tokio::test]
    async fn codex_backend_adds_conversation_headers_and_zstd_body() {
        let provider = TestAuthProvider { codex: true };
        let client = transport(ResolvedAuth::Custom(Arc::new(provider)));

        let (headers, encoded) = prepare(&client).await;

        assert_eq!(headers.get(ACCEPT).unwrap(), "text/event-stream");
        assert_eq!(headers.get("session_id").unwrap(), CACHE_KEY);
        assert_eq!(headers.get("x-client-request-id").unwrap(), CACHE_KEY);
        assert_eq!(headers.get(CONTENT_ENCODING).unwrap(), "zstd");

        let RequestBody::CompressedJson(compressed) = encoded else {
            panic!("Codex backend request body must be zstd-compressed");
        };
        // Round-trip through zstd to confirm the payload survived intact.
        let raw = zstd::stream::decode_all(std::io::Cursor::new(compressed)).unwrap();
        let decoded: serde_json::Value = serde_json::from_slice(&raw).unwrap();
        assert_eq!(decoded["prompt_cache_key"], CACHE_KEY);
    }

    #[tokio::test]
    async fn non_codex_request_does_not_get_codex_only_headers_or_compression() {
        let client = transport(ResolvedAuth::ApiKey("api-key".to_string()));

        let (headers, encoded) = prepare(&client).await;

        assert_eq!(headers.get(ACCEPT).unwrap(), "text/event-stream");
        assert!(headers.get("session_id").is_none());
        assert!(headers.get("x-client-request-id").is_none());
        assert!(headers.get(CONTENT_ENCODING).is_none());

        let RequestBody::Json(decoded) = encoded else {
            panic!("non-Codex request body must remain normal JSON");
        };
        assert_eq!(decoded["prompt_cache_key"], CACHE_KEY);
    }
}