feat(llm-worker): HTTP transient エラーへのリトライを追加
`transport.rs` の HTTP 送信〜ステータスチェック区間に指数バックオフ + フルジッターのリトライループを追加する。SSE 読み出し開始後 (`bytes_stream()` 以降) のエラーは従来どおりそのまま流す。

- `is_retryable(&ClientError)`: 408/425/429/500/502/503/504/529 と reqwest の connect/timeout のみ true
- `RetryPolicy` (default: base 500ms / cap 10s / max_attempts 4 / total_timeout 30s)
- `Retry-After` ヘッダ (秒数) があればバックオフを上書き
- リトライ発火ごとに warn! でステータス・attempt・wait を出す

ref: tickets/llm-worker-transient-retry.md
This commit is contained in:
parent
a0771608b1
commit
19df6340cd
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -1650,6 +1650,7 @@ dependencies = [
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
"trybuild",
|
"trybuild",
|
||||||
|
"wiremock",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
||||||
|
|
@ -25,3 +25,4 @@ tempfile = { workspace = true }
|
||||||
dotenv = "0.15"
|
dotenv = "0.15"
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
trybuild = "1.0.116"
|
trybuild = "1.0.116"
|
||||||
|
wiremock = "0.6.5"
|
||||||
|
|
|
||||||
|
|
@ -67,3 +67,69 @@ impl From<serde_json::Error> for ClientError {
|
||||||
ClientError::Json(err)
|
ClientError::Json(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// transient な失敗としてリトライ対象になるかを判定する。
|
||||||
|
///
|
||||||
|
/// 対象:
|
||||||
|
/// - `Api { status }` のうち 408 / 425 / 429 / 500 / 502 / 503 / 504 / 529
|
||||||
|
/// - `Http(reqwest::Error)` のうち `is_connect()` または `is_timeout()`
|
||||||
|
///
|
||||||
|
/// それ以外(Json、Sse、Config、上記以外の Api ステータス)は false。
|
||||||
|
/// SSE 読み出し開始後の失敗は呼び出し側で `Sse` として上に流すため、
|
||||||
|
/// ここで対象外にしておけば自動的に弾かれる。
|
||||||
|
pub fn is_retryable(error: &ClientError) -> bool {
|
||||||
|
match error {
|
||||||
|
ClientError::Api {
|
||||||
|
status: Some(code), ..
|
||||||
|
} => matches!(*code, 408 | 425 | 429 | 500 | 502 | 503 | 504 | 529),
|
||||||
|
ClientError::Api { status: None, .. } => false,
|
||||||
|
ClientError::Http(e) => e.is_connect() || e.is_timeout(),
|
||||||
|
ClientError::Json(_) | ClientError::Sse(_) | ClientError::Config(_) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Api` error that carries only the given status.
    fn api_err(status: Option<u16>) -> ClientError {
        let (code, message) = (None, String::new());
        ClientError::Api {
            status,
            code,
            message,
        }
    }

    /// Every status on the transient list is classified as retryable.
    #[test]
    fn retryable_status_codes() {
        let transient = [408u16, 425, 429, 500, 502, 503, 504, 529];
        transient.into_iter().for_each(|code| {
            let err = api_err(Some(code));
            assert!(is_retryable(&err), "status {code} should be retryable");
        });
    }

    /// A sample of other 4xx/5xx statuses is classified as non-retryable.
    #[test]
    fn non_retryable_status_codes() {
        let permanent = [400u16, 401, 403, 404, 409, 410, 422, 501];
        permanent.into_iter().for_each(|code| {
            let err = api_err(Some(code));
            assert!(!is_retryable(&err), "status {code} should not be retryable");
        });
    }

    /// An `Api` error without any status is never retried.
    #[test]
    fn api_without_status_not_retryable() {
        let err = api_err(None);
        assert!(!is_retryable(&err));
    }

    /// `Json`, `Sse` and `Config` errors are never retried.
    #[test]
    fn json_sse_config_not_retryable() {
        let parse_failure = serde_json::from_str::<serde_json::Value>("not json").unwrap_err();
        let samples = [
            ClientError::Json(parse_failure),
            ClientError::Sse("boom".into()),
            ClientError::Config("boom".into()),
        ];
        for err in &samples {
            assert!(!is_retryable(err));
        }
    }
}
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ pub mod error;
|
||||||
pub mod event;
|
pub mod event;
|
||||||
pub mod types;
|
pub mod types;
|
||||||
|
|
||||||
|
pub mod retry;
|
||||||
pub mod scheme;
|
pub mod scheme;
|
||||||
pub mod transport;
|
pub mod transport;
|
||||||
|
|
||||||
|
|
|
||||||
104
crates/llm-worker/src/llm_client/retry.rs
Normal file
104
crates/llm-worker/src/llm_client/retry.rs
Normal file
|
|
@ -0,0 +1,104 @@
|
||||||
|
//! HTTP transient エラー向けリトライポリシー。
|
||||||
|
//!
|
||||||
|
//! `transport.rs` の HTTP 送信〜ステータスチェック区間で `is_retryable`
|
||||||
|
//! が true を返した失敗をリトライする際に、待ち時間と打ち切り条件を
|
||||||
|
//! 提供する。SSE 読み出し開始後の失敗は対象外。
|
||||||
|
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
/// Retry policy combining exponential backoff, jitter and a cumulative timeout.
///
/// `Default` yields the fixed values used across llm-worker. Overriding them
/// through the manifest is not needed at the moment and can be added when it
/// is (see `tickets/llm-worker-transient-retry.md`).
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Base of the exponential term. `base * 2^attempt`, clamped to `cap`, is
    /// the upper bound from which the actual wait is drawn with full jitter.
    pub base: Duration,
    /// Upper bound for a single wait.
    pub cap: Duration,
    /// Total number of tries (the first try plus retries). `1` disables retrying.
    pub max_attempts: u32,
    /// Cumulative timeout measured from the start of the first send. A wait
    /// that would go past it is abandoned.
    pub total_timeout: Duration,
}
|
||||||
|
|
||||||
|
impl Default for RetryPolicy {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
base: Duration::from_millis(500),
|
||||||
|
cap: Duration::from_secs(10),
|
||||||
|
max_attempts: 4,
|
||||||
|
total_timeout: Duration::from_secs(30),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RetryPolicy {
|
||||||
|
/// `attempt` 回目の失敗(0-indexed)後に待つ時間を返す。
|
||||||
|
/// `Retry-After` で上書きしたい場合は呼び出さず、その値をそのまま使う。
|
||||||
|
pub fn backoff(&self, attempt: u32) -> Duration {
|
||||||
|
let shift = attempt.min(20);
|
||||||
|
let base_nanos = self.base.as_nanos() as u64;
|
||||||
|
let exp_nanos = base_nanos.saturating_mul(1u64 << shift);
|
||||||
|
let cap_nanos = self.cap.as_nanos() as u64;
|
||||||
|
let upper = exp_nanos.min(cap_nanos);
|
||||||
|
Duration::from_nanos(jitter_nanos(upper))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Draws one pseudo-random value from `[0, max_nanos]`.
///
/// The seed is the low 64 bits of the current `SystemTime`, in nanoseconds
/// since the Unix epoch (0 if the clock reads earlier than the epoch), and it
/// is scrambled with the splitmix64 mixing steps. This is a lightweight
/// generator with no cryptographic quality; it only has to spread retries
/// apart for full jitter.
///
/// `max_nanos == 0` returns 0 without reading the clock. `max_nanos ==
/// u64::MAX` is handled without forming `max_nanos + 1`, which would overflow.
fn jitter_nanos(max_nanos: u64) -> u64 {
    if max_nanos == 0 {
        return 0;
    }

    let seed = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_nanos() as u64) // deliberate truncation: only the low bits matter
        .unwrap_or(0);

    // splitmix64 mixing steps.
    let mut x = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
    x = (x ^ (x >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    x = (x ^ (x >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    x ^= x >> 31;

    match max_nanos.checked_add(1) {
        Some(modulus) => x % modulus,
        // The range already covers every u64, so no reduction is needed.
        None => x,
    }
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// The default policy exposes the documented fixed values.
    #[test]
    fn default_policy_values() {
        let RetryPolicy {
            base,
            cap,
            max_attempts,
            total_timeout,
        } = RetryPolicy::default();
        assert_eq!(base, Duration::from_millis(500));
        assert_eq!(cap, Duration::from_secs(10));
        assert_eq!(max_attempts, 4);
        assert_eq!(total_timeout, Duration::from_secs(30));
    }

    /// No attempt number, including ones past the exponent clamp, exceeds `cap`.
    #[test]
    fn backoff_respects_cap() {
        let policy = RetryPolicy::default();
        for attempt in 0u32..30 {
            let wait = policy.backoff(attempt);
            assert!(wait <= policy.cap, "attempt {attempt} exceeded cap");
        }
    }

    /// A zero base always produces a zero wait.
    #[test]
    fn backoff_zero_when_base_zero() {
        let policy = RetryPolicy {
            base: Duration::ZERO,
            cap: Duration::from_secs(10),
            max_attempts: 4,
            total_timeout: Duration::from_secs(30),
        };
        (0..5).for_each(|attempt| {
            assert_eq!(policy.backoff(attempt), Duration::ZERO);
        });
    }
}
|
||||||
|
|
@ -6,17 +6,21 @@
|
||||||
|
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use eventsource_stream::Eventsource;
|
use eventsource_stream::Eventsource;
|
||||||
use futures::{Stream, StreamExt, TryStreamExt};
|
use futures::{Stream, StreamExt, TryStreamExt};
|
||||||
use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
|
use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue, RETRY_AFTER};
|
||||||
|
use tokio::time::Instant;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
use super::auth::{AuthProvider, AuthRequirement};
|
use super::auth::{AuthProvider, AuthRequirement};
|
||||||
use super::capability::ModelCapability;
|
use super::capability::ModelCapability;
|
||||||
use super::client::{ConfigWarning, LlmClient};
|
use super::client::{ConfigWarning, LlmClient};
|
||||||
use super::error::ClientError;
|
use super::error::{ClientError, is_retryable};
|
||||||
use super::event::Event;
|
use super::event::Event;
|
||||||
|
use super::retry::RetryPolicy;
|
||||||
use super::scheme::Scheme;
|
use super::scheme::Scheme;
|
||||||
use super::types::{Request, RequestConfig};
|
use super::types::{Request, RequestConfig};
|
||||||
|
|
||||||
|
|
@ -63,6 +67,7 @@ pub struct HttpTransport<S: Scheme> {
|
||||||
base_url: String,
|
base_url: String,
|
||||||
auth: ResolvedAuth,
|
auth: ResolvedAuth,
|
||||||
capability: ModelCapability,
|
capability: ModelCapability,
|
||||||
|
retry_policy: RetryPolicy,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: Scheme> HttpTransport<S> {
|
impl<S: Scheme> HttpTransport<S> {
|
||||||
|
|
@ -84,6 +89,7 @@ impl<S: Scheme> HttpTransport<S> {
|
||||||
base_url,
|
base_url,
|
||||||
auth,
|
auth,
|
||||||
capability,
|
capability,
|
||||||
|
retry_policy: RetryPolicy::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -93,6 +99,12 @@ impl<S: Scheme> HttpTransport<S> {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// リトライポリシーを差し替える(テスト用 / 将来の manifest 化フック)。
|
||||||
|
pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
|
||||||
|
self.retry_policy = policy;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
fn build_url(&self) -> String {
|
fn build_url(&self) -> String {
|
||||||
let path = self.scheme.path(&self.model_id);
|
let path = self.scheme.path(&self.model_id);
|
||||||
let url = format!("{}{}", self.base_url, path);
|
let url = format!("{}{}", self.base_url, path);
|
||||||
|
|
@ -159,10 +171,45 @@ impl<S: Scheme + Clone> Clone for HttpTransport<S> {
|
||||||
base_url: self.base_url.clone(),
|
base_url: self.base_url.clone(),
|
||||||
auth: self.auth.clone(),
|
auth: self.auth.clone(),
|
||||||
capability: self.capability.clone(),
|
capability: self.capability.clone(),
|
||||||
|
retry_policy: self.retry_policy.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// エラーレスポンスを `ClientError::Api` に変換し、`Retry-After` の秒数を
|
||||||
|
/// 同時に取り出す。リトライループで wait の上書きに使う。
|
||||||
|
async fn classify_error_response(resp: reqwest::Response) -> (ClientError, Option<Duration>) {
|
||||||
|
let status = resp.status().as_u16();
|
||||||
|
let retry_after = resp
|
||||||
|
.headers()
|
||||||
|
.get(RETRY_AFTER)
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.and_then(|s| s.trim().parse::<u64>().ok())
|
||||||
|
.map(Duration::from_secs);
|
||||||
|
let text = resp.text().await.unwrap_or_default();
|
||||||
|
let err = if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
|
||||||
|
let error = json.get("error").unwrap_or(&json);
|
||||||
|
let code = error.get("type").and_then(|v| v.as_str()).map(String::from);
|
||||||
|
let message = error
|
||||||
|
.get("message")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or(&text)
|
||||||
|
.to_string();
|
||||||
|
ClientError::Api {
|
||||||
|
status: Some(status),
|
||||||
|
code,
|
||||||
|
message,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ClientError::Api {
|
||||||
|
status: Some(status),
|
||||||
|
code: None,
|
||||||
|
message: text,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
(err, retry_after)
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<S: Scheme + Clone + 'static> LlmClient for HttpTransport<S> {
|
impl<S: Scheme + Clone + 'static> LlmClient for HttpTransport<S> {
|
||||||
fn clone_boxed(&self) -> Box<dyn LlmClient> {
|
fn clone_boxed(&self) -> Box<dyn LlmClient> {
|
||||||
|
|
@ -183,37 +230,41 @@ impl<S: Scheme + Clone + 'static> LlmClient for HttpTransport<S> {
|
||||||
.scheme
|
.scheme
|
||||||
.build_request_body(&self.model_id, &request, &self.capability);
|
.build_request_body(&self.model_id, &request, &self.capability);
|
||||||
|
|
||||||
let response = self
|
let policy = &self.retry_policy;
|
||||||
.http_client
|
let started = Instant::now();
|
||||||
.post(&url)
|
let mut attempt: u32 = 0;
|
||||||
.headers(headers)
|
let response = loop {
|
||||||
.json(&body)
|
let send_result = self
|
||||||
.send()
|
.http_client
|
||||||
.await?;
|
.post(&url)
|
||||||
|
.headers(headers.clone())
|
||||||
|
.json(&body)
|
||||||
|
.send()
|
||||||
|
.await;
|
||||||
|
|
||||||
if !response.status().is_success() {
|
let (err, retry_after) = match send_result {
|
||||||
let status = response.status().as_u16();
|
Ok(resp) if resp.status().is_success() => break resp,
|
||||||
let text = response.text().await.unwrap_or_default();
|
Ok(resp) => classify_error_response(resp).await,
|
||||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
|
Err(e) => (ClientError::Http(e), None),
|
||||||
let error = json.get("error").unwrap_or(&json);
|
};
|
||||||
let code = error.get("type").and_then(|v| v.as_str()).map(String::from);
|
|
||||||
let message = error
|
let next_attempt = attempt + 1;
|
||||||
.get("message")
|
if next_attempt >= policy.max_attempts || !is_retryable(&err) {
|
||||||
.and_then(|v| v.as_str())
|
return Err(err);
|
||||||
.unwrap_or(&text)
|
|
||||||
.to_string();
|
|
||||||
return Err(ClientError::Api {
|
|
||||||
status: Some(status),
|
|
||||||
code,
|
|
||||||
message,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
return Err(ClientError::Api {
|
let wait = retry_after.unwrap_or_else(|| policy.backoff(attempt));
|
||||||
status: Some(status),
|
if started.elapsed() + wait > policy.total_timeout {
|
||||||
code: None,
|
return Err(err);
|
||||||
message: text,
|
}
|
||||||
});
|
warn!(
|
||||||
}
|
error = %err,
|
||||||
|
attempt = next_attempt,
|
||||||
|
wait_ms = wait.as_millis() as u64,
|
||||||
|
"transient HTTP error, retrying"
|
||||||
|
);
|
||||||
|
tokio::time::sleep(wait).await;
|
||||||
|
attempt = next_attempt;
|
||||||
|
};
|
||||||
|
|
||||||
let scheme = self.scheme.clone();
|
let scheme = self.scheme.clone();
|
||||||
let byte_stream = response.bytes_stream().map_err(std::io::Error::other);
|
let byte_stream = response.bytes_stream().map_err(std::io::Error::other);
|
||||||
|
|
|
||||||
249
crates/llm-worker/tests/transport_retry_test.rs
Normal file
249
crates/llm-worker/tests/transport_retry_test.rs
Normal file
|
|
@ -0,0 +1,249 @@
|
||||||
|
//! HTTP transport の transient エラーリトライ挙動の integration テスト。
|
||||||
|
//!
|
||||||
|
//! 対応チケット: `tickets/llm-worker-transient-retry.md`。
|
||||||
|
//! - 503 / 529 / connect refused でリトライ発火
|
||||||
|
//! - max_attempts 上限到達でエラー
|
||||||
|
//! - `Retry-After` ヘッダで指数バックオフを上書き
|
||||||
|
//! - `parse_sse` 由来の `ClientError::Sse`(mid-stream 想定)はリトライしない
|
||||||
|
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use futures::StreamExt;
|
||||||
|
use llm_worker::llm_client::LlmClient;
|
||||||
|
use llm_worker::llm_client::auth::AuthRequirement;
|
||||||
|
use llm_worker::llm_client::capability::ModelCapability;
|
||||||
|
use llm_worker::llm_client::error::ClientError;
|
||||||
|
use llm_worker::llm_client::event::Event;
|
||||||
|
use llm_worker::llm_client::retry::RetryPolicy;
|
||||||
|
use llm_worker::llm_client::scheme::Scheme;
|
||||||
|
use llm_worker::llm_client::transport::{HttpTransport, ResolvedAuth};
|
||||||
|
use llm_worker::llm_client::types::Request;
|
||||||
|
use serde_json::Value;
|
||||||
|
use wiremock::matchers::{method, path};
|
||||||
|
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||||
|
|
||||||
|
/// Test-only scheme that does not touch the SSE payload itself.
///
/// With `parse_fail` set, `parse_sse` returns `ClientError::Sse`, i.e. the
/// failure shows up while the stream is being consumed, outside the retry loop.
#[derive(Clone)]
struct DummyScheme {
    /// When `true`, every `parse_sse` call fails.
    parse_fail: bool,
}
|
||||||
|
|
||||||
|
impl Scheme for DummyScheme {
|
||||||
|
type State = ();
|
||||||
|
fn default_base_url(&self) -> &'static str {
|
||||||
|
""
|
||||||
|
}
|
||||||
|
fn path(&self, _: &str) -> String {
|
||||||
|
"/v1/chat".into()
|
||||||
|
}
|
||||||
|
fn required_auth(&self) -> AuthRequirement {
|
||||||
|
AuthRequirement::None
|
||||||
|
}
|
||||||
|
fn build_request_body(&self, _: &str, _: &Request, _: &ModelCapability) -> Value {
|
||||||
|
serde_json::json!({})
|
||||||
|
}
|
||||||
|
fn parse_sse(&self, _: &str, _: &str, _: &mut ()) -> Result<Vec<Event>, ClientError> {
|
||||||
|
if self.parse_fail {
|
||||||
|
Err(ClientError::Sse("simulated mid-stream parse failure".into()))
|
||||||
|
} else {
|
||||||
|
Ok(vec![])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn default_capability(&self) -> ModelCapability {
|
||||||
|
ModelCapability::minimal()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fast_policy(max_attempts: u32) -> RetryPolicy {
|
||||||
|
RetryPolicy {
|
||||||
|
base: Duration::from_millis(1),
|
||||||
|
cap: Duration::from_millis(1),
|
||||||
|
max_attempts,
|
||||||
|
total_timeout: Duration::from_secs(60),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_transport(
|
||||||
|
base_url: impl Into<String>,
|
||||||
|
parse_fail: bool,
|
||||||
|
policy: RetryPolicy,
|
||||||
|
) -> HttpTransport<DummyScheme> {
|
||||||
|
HttpTransport::new(
|
||||||
|
DummyScheme { parse_fail },
|
||||||
|
"test-model",
|
||||||
|
base_url,
|
||||||
|
ResolvedAuth::None,
|
||||||
|
ModelCapability::minimal(),
|
||||||
|
)
|
||||||
|
.with_retry_policy(policy)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ok_sse() -> ResponseTemplate {
|
||||||
|
ResponseTemplate::new(200)
|
||||||
|
.insert_header("content-type", "text/event-stream")
|
||||||
|
.set_body_raw(b"".to_vec(), "text/event-stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn retries_503_then_succeeds() {
|
||||||
|
let server = MockServer::start().await;
|
||||||
|
Mock::given(method("POST"))
|
||||||
|
.and(path("/v1/chat"))
|
||||||
|
.respond_with(ResponseTemplate::new(503).set_body_string("upstream connect error"))
|
||||||
|
.up_to_n_times(2)
|
||||||
|
.mount(&server)
|
||||||
|
.await;
|
||||||
|
Mock::given(method("POST"))
|
||||||
|
.and(path("/v1/chat"))
|
||||||
|
.respond_with(ok_sse())
|
||||||
|
.mount(&server)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let transport = build_transport(server.uri(), false, fast_policy(5));
|
||||||
|
let mut stream = transport
|
||||||
|
.stream(Request::default())
|
||||||
|
.await
|
||||||
|
.expect("stream should succeed after retries");
|
||||||
|
while stream.next().await.is_some() {}
|
||||||
|
|
||||||
|
let received = server.received_requests().await.unwrap();
|
||||||
|
assert_eq!(received.len(), 3, "two failures plus one success expected");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn retries_529_then_exhausts() {
|
||||||
|
let server = MockServer::start().await;
|
||||||
|
Mock::given(method("POST"))
|
||||||
|
.and(path("/v1/chat"))
|
||||||
|
.respond_with(ResponseTemplate::new(529).set_body_string("overloaded"))
|
||||||
|
.mount(&server)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let transport = build_transport(server.uri(), false, fast_policy(3));
|
||||||
|
match transport.stream(Request::default()).await {
|
||||||
|
Err(ClientError::Api {
|
||||||
|
status: Some(529), ..
|
||||||
|
}) => {}
|
||||||
|
Err(other) => panic!("expected Api(529), got {other:?}"),
|
||||||
|
Ok(_) => panic!("expected error after exhausting retries"),
|
||||||
|
}
|
||||||
|
|
||||||
|
let received = server.received_requests().await.unwrap();
|
||||||
|
assert_eq!(received.len(), 3, "should hit max_attempts and stop");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn connect_refused_retries_then_fails() {
|
||||||
|
// 接続不能なローカルアドレスを使う。Linux では `Connection refused` で
|
||||||
|
// 即時失敗するため、`fast_policy` ならテストが秒以下で終わる。
|
||||||
|
let unreachable = "http://127.0.0.1:1";
|
||||||
|
|
||||||
|
let transport = build_transport(unreachable, false, fast_policy(3));
|
||||||
|
match transport.stream(Request::default()).await {
|
||||||
|
Err(ClientError::Http(e)) => {
|
||||||
|
assert!(
|
||||||
|
e.is_connect() || e.is_timeout(),
|
||||||
|
"expected connect/timeout, got {e:?}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(other) => panic!("expected Http error, got {other:?}"),
|
||||||
|
Ok(_) => panic!("expected error connecting to closed port"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn retry_after_header_overrides_backoff() {
|
||||||
|
let server = MockServer::start().await;
|
||||||
|
Mock::given(method("POST"))
|
||||||
|
.and(path("/v1/chat"))
|
||||||
|
.respond_with(ResponseTemplate::new(503).insert_header("retry-after", "1"))
|
||||||
|
.up_to_n_times(1)
|
||||||
|
.mount(&server)
|
||||||
|
.await;
|
||||||
|
Mock::given(method("POST"))
|
||||||
|
.and(path("/v1/chat"))
|
||||||
|
.respond_with(ok_sse())
|
||||||
|
.mount(&server)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// base/cap を 1ms に絞った policy で `Retry-After: 1` を観察すると、
|
||||||
|
// 指数バックオフ単独なら 1ms 程度で終わるはずが Retry-After に従って
|
||||||
|
// 1 秒待つ → 経過時間で override を検証できる。
|
||||||
|
let policy = RetryPolicy {
|
||||||
|
base: Duration::from_millis(1),
|
||||||
|
cap: Duration::from_millis(1),
|
||||||
|
max_attempts: 3,
|
||||||
|
total_timeout: Duration::from_secs(10),
|
||||||
|
};
|
||||||
|
let transport = build_transport(server.uri(), false, policy);
|
||||||
|
|
||||||
|
let start = Instant::now();
|
||||||
|
let mut stream = transport.stream(Request::default()).await.expect("ok");
|
||||||
|
while stream.next().await.is_some() {}
|
||||||
|
let elapsed = start.elapsed();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
elapsed >= Duration::from_secs(1),
|
||||||
|
"Retry-After=1 should make us wait >=1s, elapsed={elapsed:?}"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
elapsed < Duration::from_secs(3),
|
||||||
|
"Retry-After=1 should not balloon, elapsed={elapsed:?}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn mid_stream_sse_error_does_not_retry() {
|
||||||
|
let server = MockServer::start().await;
|
||||||
|
Mock::given(method("POST"))
|
||||||
|
.and(path("/v1/chat"))
|
||||||
|
.respond_with(
|
||||||
|
ResponseTemplate::new(200)
|
||||||
|
.insert_header("content-type", "text/event-stream")
|
||||||
|
.set_body_raw(
|
||||||
|
b"event: data\ndata: payload\n\n".to_vec(),
|
||||||
|
"text/event-stream",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.mount(&server)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let transport = build_transport(server.uri(), true, fast_policy(5));
|
||||||
|
let mut stream = transport
|
||||||
|
.stream(Request::default())
|
||||||
|
.await
|
||||||
|
.expect("status 200 should bypass retry loop");
|
||||||
|
let mut saw_sse_err = false;
|
||||||
|
while let Some(item) = stream.next().await {
|
||||||
|
if matches!(item, Err(ClientError::Sse(_))) {
|
||||||
|
saw_sse_err = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert!(saw_sse_err, "expected Sse error from stream consumer");
|
||||||
|
|
||||||
|
let received = server.received_requests().await.unwrap();
|
||||||
|
assert_eq!(received.len(), 1, "mid-stream Sse must not retry");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn non_retryable_status_returns_immediately() {
|
||||||
|
let server = MockServer::start().await;
|
||||||
|
Mock::given(method("POST"))
|
||||||
|
.and(path("/v1/chat"))
|
||||||
|
.respond_with(ResponseTemplate::new(401).set_body_string("unauthorized"))
|
||||||
|
.mount(&server)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let transport = build_transport(server.uri(), false, fast_policy(5));
|
||||||
|
match transport.stream(Request::default()).await {
|
||||||
|
Err(ClientError::Api {
|
||||||
|
status: Some(401), ..
|
||||||
|
}) => {}
|
||||||
|
Err(other) => panic!("expected Api(401), got {other:?}"),
|
||||||
|
Ok(_) => panic!("expected error"),
|
||||||
|
}
|
||||||
|
|
||||||
|
let received = server.received_requests().await.unwrap();
|
||||||
|
assert_eq!(received.len(), 1, "401 must not retry");
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user