From 19df6340cd1f3c273e80365eeec845d56bf622cb Mon Sep 17 00:00:00 2001 From: Hare Date: Mon, 4 May 2026 12:45:33 +0900 Subject: [PATCH] =?UTF-8?q?feat(llm-worker):=20HTTP=20transient=20?= =?UTF-8?q?=E3=82=A8=E3=83=A9=E3=83=BC=E3=81=B8=E3=81=AE=E3=83=AA=E3=83=88?= =?UTF-8?q?=E3=83=A9=E3=82=A4=E3=82=92=E8=BF=BD=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `transport.rs` の HTTP 送信〜ステータスチェック区間に指数バックオフ + フルジッターのリトライループを追加する。SSE 読み出し開始後 ( `bytes_stream()` 以降) のエラーは従来どおりそのまま流す。 - `is_retryable(&ClientError)`: 408/425/429/500/502/503/504/529 と reqwest の connect/timeout のみ true - `RetryPolicy` (default: base 500ms / cap 10s / max_attempts 4 / total_timeout 30s) - `Retry-After` ヘッダ (秒数) があればバックオフを上書き - リトライ発火ごとに warn! でステータス・attempt・wait を出す ref: tickets/llm-worker-transient-retry.md --- Cargo.lock | 1 + crates/llm-worker/Cargo.toml | 1 + crates/llm-worker/src/llm_client/error.rs | 66 +++++ crates/llm-worker/src/llm_client/mod.rs | 1 + crates/llm-worker/src/llm_client/retry.rs | 104 ++++++++ crates/llm-worker/src/llm_client/transport.rs | 113 +++++--- .../llm-worker/tests/transport_retry_test.rs | 249 ++++++++++++++++++ 7 files changed, 504 insertions(+), 31 deletions(-) create mode 100644 crates/llm-worker/src/llm_client/retry.rs create mode 100644 crates/llm-worker/tests/transport_retry_test.rs diff --git a/Cargo.lock b/Cargo.lock index 2d18267b..36bc3719 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1650,6 +1650,7 @@ dependencies = [ "tracing", "tracing-subscriber", "trybuild", + "wiremock", ] [[package]] diff --git a/crates/llm-worker/Cargo.toml b/crates/llm-worker/Cargo.toml index 7eb09302..f40776f3 100644 --- a/crates/llm-worker/Cargo.toml +++ b/crates/llm-worker/Cargo.toml @@ -25,3 +25,4 @@ tempfile = { workspace = true } dotenv = "0.15" tracing-subscriber = { version = "0.3", features = ["env-filter"] } trybuild = "1.0.116" +wiremock = "0.6.5" diff --git a/crates/llm-worker/src/llm_client/error.rs b/crates/llm-worker/src/llm_client/error.rs index 02ecbf1b..819ed84e 100644 --- a/crates/llm-worker/src/llm_client/error.rs +++ b/crates/llm-worker/src/llm_client/error.rs @@ -67,3 +67,69 @@ impl From for ClientError { ClientError::Json(err) } } + +/// transient な失敗としてリトライ対象になるかを判定する。 +/// +/// 対象: +/// - `Api { status }` のうち 408 / 425 / 429 / 500 / 502 / 503 / 504 / 529 +/// - `Http(reqwest::Error)` のうち `is_connect()` または `is_timeout()` +/// +/// それ以外(Json、Sse、Config、上記以外の Api ステータス)は false。 +/// SSE 読み出し開始後の失敗は呼び出し側で `Sse` として上に流すため、 +/// ここで対象外にしておけば自動的に弾かれる。 +pub fn is_retryable(error: &ClientError) -> bool { + match error { + ClientError::Api { + status: Some(code), .. + } => matches!(*code, 408 | 425 | 429 | 500 | 502 | 503 | 504 | 529), + ClientError::Api { status: None, .. } => false, + ClientError::Http(e) => e.is_connect() || e.is_timeout(), + ClientError::Json(_) | ClientError::Sse(_) | ClientError::Config(_) => false, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn api_err(status: Option) -> ClientError { + ClientError::Api { + status, + code: None, + message: String::new(), + } + } + + #[test] + fn retryable_status_codes() { + for code in [408u16, 425, 429, 500, 502, 503, 504, 529] { + assert!( + is_retryable(&api_err(Some(code))), + "status {code} should be retryable", + ); + } + } + + #[test] + fn non_retryable_status_codes() { + for code in [400u16, 401, 403, 404, 409, 410, 422, 501] { + assert!( + !is_retryable(&api_err(Some(code))), + "status {code} should not be retryable", + ); + } + } + + #[test] + fn api_without_status_not_retryable() { + assert!(!is_retryable(&api_err(None))); + } + + #[test] + fn json_sse_config_not_retryable() { + let json_err = serde_json::from_str::("not json").unwrap_err(); + assert!(!is_retryable(&ClientError::Json(json_err))); + assert!(!is_retryable(&ClientError::Sse("boom".into()))); + assert!(!is_retryable(&ClientError::Config("boom".into()))); + } +} diff --git a/crates/llm-worker/src/llm_client/mod.rs b/crates/llm-worker/src/llm_client/mod.rs index c707f94f..3037820a 100644 --- a/crates/llm-worker/src/llm_client/mod.rs +++ b/crates/llm-worker/src/llm_client/mod.rs @@ -23,6 +23,7 @@ pub mod error; pub mod event; pub mod types; +pub mod retry; pub mod scheme; pub mod transport; diff --git a/crates/llm-worker/src/llm_client/retry.rs b/crates/llm-worker/src/llm_client/retry.rs new file mode 100644 index 00000000..8f4d766a --- /dev/null +++ b/crates/llm-worker/src/llm_client/retry.rs @@ -0,0 +1,104 @@ +//! HTTP transient エラー向けリトライポリシー。 +//! +//! `transport.rs` の HTTP 送信〜ステータスチェック区間で `is_retryable` +//! が true を返した失敗をリトライする際に、待ち時間と打ち切り条件を +//! 提供する。SSE 読み出し開始後の失敗は対象外。 + +use std::time::Duration; + +/// 指数バックオフ + ジッター + 累積タイムアウトを表すポリシー。 +/// +/// `Default` は llm-worker 全体の固定値を返す。manifest 経由の上書きが +/// 必要になったら拡張する(現状は不要 → `tickets/llm-worker-transient-retry.md`)。 +#[derive(Debug, Clone)] +pub struct RetryPolicy { + /// 指数の基準値。`base * 2^attempt` を `cap` で頭打ちにした上限から + /// フルジッターで実際の wait を抽選する。 + pub base: Duration, + /// 1 回あたりの wait の上限。 + pub cap: Duration, + /// 試行の合計回数(初回 + リトライ)。`1` ならリトライしない。 + pub max_attempts: u32, + /// 初回送信開始からの累積タイムアウト。これを超える wait は打ち切る。 + pub total_timeout: Duration, +} + +impl Default for RetryPolicy { + fn default() -> Self { + Self { + base: Duration::from_millis(500), + cap: Duration::from_secs(10), + max_attempts: 4, + total_timeout: Duration::from_secs(30), + } + } +} + +impl RetryPolicy { + /// `attempt` 回目の失敗(0-indexed)後に待つ時間を返す。 + /// `Retry-After` で上書きしたい場合は呼び出さず、その値をそのまま使う。 + pub fn backoff(&self, attempt: u32) -> Duration { + let shift = attempt.min(20); + let base_nanos = self.base.as_nanos() as u64; + let exp_nanos = base_nanos.saturating_mul(1u64 << shift); + let cap_nanos = self.cap.as_nanos() as u64; + let upper = exp_nanos.min(cap_nanos); + Duration::from_nanos(jitter_nanos(upper)) + } +} + +/// `[0, max_nanos]` から擬似乱数的に 1 つ取り出す。`SystemTime` の +/// 下位ビットを splitmix64 で攪拌するだけの軽量実装で、暗号的乱数性は +/// 持たないがフルジッターのぶつかり回避には十分。 +fn jitter_nanos(max_nanos: u64) -> u64 { + if max_nanos == 0 { + return 0; + } + let seed = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_nanos() as u64) + .unwrap_or(0); + let mut x = seed.wrapping_add(0x9E37_79B9_7F4A_7C15); + x = (x ^ (x >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); + x = (x ^ (x >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); + x ^= x >> 31; + x % (max_nanos + 1) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_policy_values() { + let p = RetryPolicy::default(); + assert_eq!(p.base, Duration::from_millis(500)); + assert_eq!(p.cap, Duration::from_secs(10)); + assert_eq!(p.max_attempts, 4); + assert_eq!(p.total_timeout, Duration::from_secs(30)); + } + + #[test] + fn backoff_respects_cap() { + let p = RetryPolicy::default(); + for attempt in 0..30u32 { + assert!( + p.backoff(attempt) <= p.cap, + "attempt {attempt} exceeded cap", + ); + } + } + + #[test] + fn backoff_zero_when_base_zero() { + let p = RetryPolicy { + base: Duration::ZERO, + cap: Duration::from_secs(10), + max_attempts: 4, + total_timeout: Duration::from_secs(30), + }; + for attempt in 0..5 { + assert_eq!(p.backoff(attempt), Duration::ZERO); + } + } +} diff --git a/crates/llm-worker/src/llm_client/transport.rs b/crates/llm-worker/src/llm_client/transport.rs index 42a0a3c4..45a5198e 100644 --- a/crates/llm-worker/src/llm_client/transport.rs +++ b/crates/llm-worker/src/llm_client/transport.rs @@ -6,17 +6,21 @@ use std::pin::Pin; use std::sync::Arc; +use std::time::Duration; use async_trait::async_trait; use eventsource_stream::Eventsource; use futures::{Stream, StreamExt, TryStreamExt}; -use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue}; +use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue, RETRY_AFTER}; +use tokio::time::Instant; +use tracing::warn; use super::auth::{AuthProvider, AuthRequirement}; use super::capability::ModelCapability; use super::client::{ConfigWarning, LlmClient}; -use super::error::ClientError; +use super::error::{ClientError, is_retryable}; use super::event::Event; +use super::retry::RetryPolicy; use super::scheme::Scheme; use super::types::{Request, RequestConfig}; @@ -63,6 +67,7 @@ pub struct HttpTransport { base_url: String, auth: ResolvedAuth, capability: ModelCapability, + retry_policy: RetryPolicy, } impl HttpTransport { @@ -84,6 +89,7 @@ impl HttpTransport { base_url, auth, capability, + retry_policy: RetryPolicy::default(), } } @@ -93,6 +99,12 @@ impl HttpTransport { self } + /// リトライポリシーを差し替える(テスト用 / 将来の manifest 化フック)。 + pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self { + self.retry_policy = policy; + self + } + fn build_url(&self) -> String { let path = self.scheme.path(&self.model_id); let url = format!("{}{}", self.base_url, path); @@ -159,10 +171,45 @@ impl Clone for HttpTransport { base_url: self.base_url.clone(), auth: self.auth.clone(), capability: self.capability.clone(), + retry_policy: self.retry_policy.clone(), } } } +/// エラーレスポンスを `ClientError::Api` に変換し、`Retry-After` の秒数を +/// 同時に取り出す。リトライループで wait の上書きに使う。 +async fn classify_error_response(resp: reqwest::Response) -> (ClientError, Option) { + let status = resp.status().as_u16(); + let retry_after = resp + .headers() + .get(RETRY_AFTER) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.trim().parse::().ok()) + .map(Duration::from_secs); + let text = resp.text().await.unwrap_or_default(); + let err = if let Ok(json) = serde_json::from_str::(&text) { + let error = json.get("error").unwrap_or(&json); + let code = error.get("type").and_then(|v| v.as_str()).map(String::from); + let message = error + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or(&text) + .to_string(); + ClientError::Api { + status: Some(status), + code, + message, + } + } else { + ClientError::Api { + status: Some(status), + code: None, + message: text, + } + }; + (err, retry_after) +} + #[async_trait] impl LlmClient for HttpTransport { fn clone_boxed(&self) -> Box { @@ -183,37 +230,41 @@ impl LlmClient for HttpTransport { .scheme .build_request_body(&self.model_id, &request, &self.capability); - let response = self - .http_client - .post(&url) - .headers(headers) - .json(&body) - .send() - .await?; + let policy = &self.retry_policy; + let started = Instant::now(); + let mut attempt: u32 = 0; + let response = loop { + let send_result = self + .http_client + .post(&url) + .headers(headers.clone()) + .json(&body) + .send() + .await; - if !response.status().is_success() { - let status = response.status().as_u16(); - let text = response.text().await.unwrap_or_default(); - if let Ok(json) = serde_json::from_str::(&text) { - let error = json.get("error").unwrap_or(&json); - let code = error.get("type").and_then(|v| v.as_str()).map(String::from); - let message = error - .get("message") - .and_then(|v| v.as_str()) - .unwrap_or(&text) - .to_string(); - return Err(ClientError::Api { - status: Some(status), - code, - message, - }); + let (err, retry_after) = match send_result { + Ok(resp) if resp.status().is_success() => break resp, + Ok(resp) => classify_error_response(resp).await, + Err(e) => (ClientError::Http(e), None), + }; + + let next_attempt = attempt + 1; + if next_attempt >= policy.max_attempts || !is_retryable(&err) { + return Err(err); } - return Err(ClientError::Api { - status: Some(status), - code: None, - message: text, - }); - } + let wait = retry_after.unwrap_or_else(|| policy.backoff(attempt)); + if started.elapsed() + wait > policy.total_timeout { + return Err(err); + } + warn!( + error = %err, + attempt = next_attempt, + wait_ms = wait.as_millis() as u64, + "transient HTTP error, retrying" + ); + tokio::time::sleep(wait).await; + attempt = next_attempt; + }; let scheme = self.scheme.clone(); let byte_stream = response.bytes_stream().map_err(std::io::Error::other); diff --git a/crates/llm-worker/tests/transport_retry_test.rs b/crates/llm-worker/tests/transport_retry_test.rs new file mode 100644 index 00000000..64108da1 --- /dev/null +++ b/crates/llm-worker/tests/transport_retry_test.rs @@ -0,0 +1,249 @@ +//! HTTP transport の transient エラーリトライ挙動の integration テスト。 +//! +//! 対応チケット: `tickets/llm-worker-transient-retry.md`。 +//! - 503 / 529 / connect refused でリトライ発火 +//! - max_attempts 上限到達でエラー +//! - `Retry-After` ヘッダで指数バックオフを上書き +//! - `parse_sse` 由来の `ClientError::Sse`(mid-stream 想定)はリトライしない + +use std::time::{Duration, Instant}; + +use futures::StreamExt; +use llm_worker::llm_client::LlmClient; +use llm_worker::llm_client::auth::AuthRequirement; +use llm_worker::llm_client::capability::ModelCapability; +use llm_worker::llm_client::error::ClientError; +use llm_worker::llm_client::event::Event; +use llm_worker::llm_client::retry::RetryPolicy; +use llm_worker::llm_client::scheme::Scheme; +use llm_worker::llm_client::transport::{HttpTransport, ResolvedAuth}; +use llm_worker::llm_client::types::Request; +use serde_json::Value; +use wiremock::matchers::{method, path}; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +/// SSE 本体は触らないテスト用 scheme。`parse_fail` を立てると +/// stream 消費中(= retry loop の外)で `ClientError::Sse` を返す。 +#[derive(Clone)] +struct DummyScheme { + parse_fail: bool, +} + +impl Scheme for DummyScheme { + type State = (); + fn default_base_url(&self) -> &'static str { + "" + } + fn path(&self, _: &str) -> String { + "/v1/chat".into() + } + fn required_auth(&self) -> AuthRequirement { + AuthRequirement::None + } + fn build_request_body(&self, _: &str, _: &Request, _: &ModelCapability) -> Value { + serde_json::json!({}) + } + fn parse_sse(&self, _: &str, _: &str, _: &mut ()) -> Result, ClientError> { + if self.parse_fail { + Err(ClientError::Sse("simulated mid-stream parse failure".into())) + } else { + Ok(vec![]) + } + } + fn default_capability(&self) -> ModelCapability { + ModelCapability::minimal() + } +} + +fn fast_policy(max_attempts: u32) -> RetryPolicy { + RetryPolicy { + base: Duration::from_millis(1), + cap: Duration::from_millis(1), + max_attempts, + total_timeout: Duration::from_secs(60), + } +} + +fn build_transport( + base_url: impl Into, + parse_fail: bool, + policy: RetryPolicy, +) -> HttpTransport { + HttpTransport::new( + DummyScheme { parse_fail }, + "test-model", + base_url, + ResolvedAuth::None, + ModelCapability::minimal(), + ) + .with_retry_policy(policy) +} + +fn ok_sse() -> ResponseTemplate { + ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(b"".to_vec(), "text/event-stream") +} + +#[tokio::test] +async fn retries_503_then_succeeds() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat")) + .respond_with(ResponseTemplate::new(503).set_body_string("upstream connect error")) + .up_to_n_times(2) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/v1/chat")) + .respond_with(ok_sse()) + .mount(&server) + .await; + + let transport = build_transport(server.uri(), false, fast_policy(5)); + let mut stream = transport + .stream(Request::default()) + .await + .expect("stream should succeed after retries"); + while stream.next().await.is_some() {} + + let received = server.received_requests().await.unwrap(); + assert_eq!(received.len(), 3, "two failures plus one success expected"); +} + +#[tokio::test] +async fn retries_529_then_exhausts() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat")) + .respond_with(ResponseTemplate::new(529).set_body_string("overloaded")) + .mount(&server) + .await; + + let transport = build_transport(server.uri(), false, fast_policy(3)); + match transport.stream(Request::default()).await { + Err(ClientError::Api { + status: Some(529), .. + }) => {} + Err(other) => panic!("expected Api(529), got {other:?}"), + Ok(_) => panic!("expected error after exhausting retries"), + } + + let received = server.received_requests().await.unwrap(); + assert_eq!(received.len(), 3, "should hit max_attempts and stop"); +} + +#[tokio::test] +async fn connect_refused_retries_then_fails() { + // 接続不能なローカルアドレスを使う。Linux では `Connection refused` で + // 即時失敗するため、`fast_policy` ならテストが秒以下で終わる。 + let unreachable = "http://127.0.0.1:1"; + + let transport = build_transport(unreachable, false, fast_policy(3)); + match transport.stream(Request::default()).await { + Err(ClientError::Http(e)) => { + assert!( + e.is_connect() || e.is_timeout(), + "expected connect/timeout, got {e:?}" + ); + } + Err(other) => panic!("expected Http error, got {other:?}"), + Ok(_) => panic!("expected error connecting to closed port"), + } +} + +#[tokio::test] +async fn retry_after_header_overrides_backoff() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat")) + .respond_with(ResponseTemplate::new(503).insert_header("retry-after", "1")) + .up_to_n_times(1) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/v1/chat")) + .respond_with(ok_sse()) + .mount(&server) + .await; + + // base/cap を 1ms に絞った policy で `Retry-After: 1` を観察すると、 + // 指数バックオフ単独なら 1ms 程度で終わるはずが Retry-After に従って + // 1 秒待つ → 経過時間で override を検証できる。 + let policy = RetryPolicy { + base: Duration::from_millis(1), + cap: Duration::from_millis(1), + max_attempts: 3, + total_timeout: Duration::from_secs(10), + }; + let transport = build_transport(server.uri(), false, policy); + + let start = Instant::now(); + let mut stream = transport.stream(Request::default()).await.expect("ok"); + while stream.next().await.is_some() {} + let elapsed = start.elapsed(); + + assert!( + elapsed >= Duration::from_secs(1), + "Retry-After=1 should make us wait >=1s, elapsed={elapsed:?}" + ); + assert!( + elapsed < Duration::from_secs(3), + "Retry-After=1 should not balloon, elapsed={elapsed:?}" + ); +} + +#[tokio::test] +async fn mid_stream_sse_error_does_not_retry() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw( + b"event: data\ndata: payload\n\n".to_vec(), + "text/event-stream", + ), + ) + .mount(&server) + .await; + + let transport = build_transport(server.uri(), true, fast_policy(5)); + let mut stream = transport + .stream(Request::default()) + .await + .expect("status 200 should bypass retry loop"); + let mut saw_sse_err = false; + while let Some(item) = stream.next().await { + if matches!(item, Err(ClientError::Sse(_))) { + saw_sse_err = true; + } + } + assert!(saw_sse_err, "expected Sse error from stream consumer"); + + let received = server.received_requests().await.unwrap(); + assert_eq!(received.len(), 1, "mid-stream Sse must not retry"); +} + +#[tokio::test] +async fn non_retryable_status_returns_immediately() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat")) + .respond_with(ResponseTemplate::new(401).set_body_string("unauthorized")) + .mount(&server) + .await; + + let transport = build_transport(server.uri(), false, fast_policy(5)); + match transport.stream(Request::default()).await { + Err(ClientError::Api { + status: Some(401), .. + }) => {} + Err(other) => panic!("expected Api(401), got {other:?}"), + Ok(_) => panic!("expected error"), + } + + let received = server.received_requests().await.unwrap(); + assert_eq!(received.len(), 1, "401 must not retry"); +}