yoi/crates/llm-worker/tests/transport_retry_test.rs
Hare 19df6340cd feat(llm-worker): HTTP transient エラーへのリトライを追加
`transport.rs` の HTTP 送信〜ステータスチェック区間に指数バックオフ
+ フルジッターのリトライループを追加する。SSE 読み出し開始後 (
`bytes_stream()` 以降) のエラーは従来どおりそのまま流す。

- `is_retryable(&ClientError)`: 408/425/429/500/502/503/504/529 と
  reqwest の connect/timeout のみ true
- `RetryPolicy` (default: base 500ms / cap 10s / max_attempts 4 /
  total_timeout 30s)
- `Retry-After` ヘッダ (秒数) があればバックオフを上書き
- リトライ発火ごとに warn! でステータス・attempt・wait を出す

ref: tickets/llm-worker-transient-retry.md
2026-05-04 12:45:33 +09:00

250 lines
8.2 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! HTTP transport の transient エラーリトライ挙動の integration テスト。
//!
//! 対応チケット: `tickets/llm-worker-transient-retry.md`。
//! - 503 / 529 / connect refused でリトライ発火
//! - max_attempts 上限到達でエラー
//! - `Retry-After` ヘッダで指数バックオフを上書き
//! - `parse_sse` 由来の `ClientError::Sse`(mid-stream 想定)はリトライしない
use std::time::{Duration, Instant};
use futures::StreamExt;
use llm_worker::llm_client::LlmClient;
use llm_worker::llm_client::auth::AuthRequirement;
use llm_worker::llm_client::capability::ModelCapability;
use llm_worker::llm_client::error::ClientError;
use llm_worker::llm_client::event::Event;
use llm_worker::llm_client::retry::RetryPolicy;
use llm_worker::llm_client::scheme::Scheme;
use llm_worker::llm_client::transport::{HttpTransport, ResolvedAuth};
use llm_worker::llm_client::types::Request;
use serde_json::Value;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Test-only scheme that never inspects the SSE payload. When `parse_fail` is
/// set, `parse_sse` returns `ClientError::Sse`, so the failure surfaces while
/// the stream is being consumed (i.e. outside the retry loop).
#[derive(Clone)]
struct DummyScheme {
/// When `true`, `parse_sse` fails with `ClientError::Sse` instead of yielding no events.
parse_fail: bool,
}
impl Scheme for DummyScheme {
type State = ();
fn default_base_url(&self) -> &'static str {
""
}
fn path(&self, _: &str) -> String {
"/v1/chat".into()
}
fn required_auth(&self) -> AuthRequirement {
AuthRequirement::None
}
fn build_request_body(&self, _: &str, _: &Request, _: &ModelCapability) -> Value {
serde_json::json!({})
}
fn parse_sse(&self, _: &str, _: &str, _: &mut ()) -> Result<Vec<Event>, ClientError> {
if self.parse_fail {
Err(ClientError::Sse("simulated mid-stream parse failure".into()))
} else {
Ok(vec![])
}
}
fn default_capability(&self) -> ModelCapability {
ModelCapability::minimal()
}
}
/// Builds a retry policy whose backoff base and cap are both 1 ms, so retries
/// fire almost immediately. `max_attempts` is taken from the caller and the
/// total timeout is fixed at 60 seconds.
fn fast_policy(max_attempts: u32) -> RetryPolicy {
    // `Duration` is `Copy`, so one value serves both the base and the cap.
    let one_milli = Duration::from_millis(1);
    let overall_budget = Duration::from_secs(60);
    RetryPolicy {
        base: one_milli,
        cap: one_milli,
        max_attempts,
        total_timeout: overall_budget,
    }
}
/// Creates an `HttpTransport` over a `DummyScheme` for the model name
/// `"test-model"`, pointed at `base_url`, with no auth, the minimal
/// capability set, and the given retry `policy`.
///
/// `parse_fail` is forwarded to the scheme and controls whether SSE parsing
/// fails once the stream is being consumed.
fn build_transport(
    base_url: impl Into<String>,
    parse_fail: bool,
    policy: RetryPolicy,
) -> HttpTransport<DummyScheme> {
    let scheme = DummyScheme { parse_fail };
    let transport = HttpTransport::new(
        scheme,
        "test-model",
        base_url,
        ResolvedAuth::None,
        ModelCapability::minimal(),
    );
    transport.with_retry_policy(policy)
}
/// A `200` response declared as `text/event-stream` with an empty body,
/// standing in for a successful SSE reply that carries no events.
fn ok_sse() -> ResponseTemplate {
    const SSE_MIME: &str = "text/event-stream";
    let template = ResponseTemplate::new(200);
    template
        .insert_header("content-type", SSE_MIME)
        .set_body_raw(Vec::<u8>::new(), SSE_MIME)
}
#[tokio::test]
async fn retries_503_then_succeeds() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/chat"))
.respond_with(ResponseTemplate::new(503).set_body_string("upstream connect error"))
.up_to_n_times(2)
.mount(&server)
.await;
Mock::given(method("POST"))
.and(path("/v1/chat"))
.respond_with(ok_sse())
.mount(&server)
.await;
let transport = build_transport(server.uri(), false, fast_policy(5));
let mut stream = transport
.stream(Request::default())
.await
.expect("stream should succeed after retries");
while stream.next().await.is_some() {}
let received = server.received_requests().await.unwrap();
assert_eq!(received.len(), 3, "two failures plus one success expected");
}
/// The server always answers 529: with `max_attempts = 3` the transport is
/// expected to give up with `ClientError::Api { status: Some(529), .. }`
/// after exactly three requests.
#[tokio::test]
async fn retries_529_then_exhausts() {
    let server = MockServer::start().await;
    let overloaded = ResponseTemplate::new(529).set_body_string("overloaded");
    Mock::given(method("POST"))
        .and(path("/v1/chat"))
        .respond_with(overloaded)
        .mount(&server)
        .await;

    let transport = build_transport(server.uri(), false, fast_policy(3));
    let other = match transport.stream(Request::default()).await {
        Ok(_) => panic!("expected error after exhausting retries"),
        Err(other) => other,
    };
    if !matches!(
        other,
        ClientError::Api {
            status: Some(529),
            ..
        }
    ) {
        panic!("expected Api(529), got {other:?}");
    }

    let request_count = server.received_requests().await.unwrap().len();
    assert_eq!(request_count, 3, "should hit max_attempts and stop");
}
/// Connecting to a closed local port must end in `ClientError::Http` whose
/// inner error reports a connect or timeout failure.
#[tokio::test]
async fn connect_refused_retries_then_fails() {
    // Use an unreachable local address. On Linux this fails immediately with
    // `Connection refused`, so with `fast_policy` the test finishes in under
    // a second.
    let transport = build_transport("http://127.0.0.1:1", false, fast_policy(3));

    let failure = match transport.stream(Request::default()).await {
        Ok(_) => panic!("expected error connecting to closed port"),
        Err(failure) => failure,
    };
    let e = match failure {
        ClientError::Http(e) => e,
        other => panic!("expected Http error, got {other:?}"),
    };
    let connect_or_timeout = e.is_connect() || e.is_timeout();
    assert!(connect_or_timeout, "expected connect/timeout, got {e:?}");
}
/// A single 503 carrying `Retry-After: 1`, followed by a 200: the elapsed
/// time must show that the header, not the 1 ms backoff, set the wait.
#[tokio::test]
async fn retry_after_header_overrides_backoff() {
    let server = MockServer::start().await;

    // Exactly one 503 that asks the client to wait one second.
    let throttled = ResponseTemplate::new(503).insert_header("retry-after", "1");
    Mock::given(method("POST"))
        .and(path("/v1/chat"))
        .respond_with(throttled)
        .up_to_n_times(1)
        .mount(&server)
        .await;
    Mock::given(method("POST"))
        .and(path("/v1/chat"))
        .respond_with(ok_sse())
        .mount(&server)
        .await;

    // With base/cap squeezed to 1 ms, exponential backoff alone would finish
    // in roughly 1 ms. Honouring `Retry-After: 1` forces a one-second wait
    // instead, so the elapsed time is enough to verify the override.
    let one_milli = Duration::from_millis(1);
    let policy = RetryPolicy {
        base: one_milli,
        cap: one_milli,
        max_attempts: 3,
        total_timeout: Duration::from_secs(10),
    };
    let transport = build_transport(server.uri(), false, policy);

    let start = Instant::now();
    let mut events = transport.stream(Request::default()).await.expect("ok");
    while events.next().await.is_some() {}
    let elapsed = start.elapsed();

    let lower_bound = Duration::from_secs(1);
    let upper_bound = Duration::from_secs(3);
    assert!(
        elapsed >= lower_bound,
        "Retry-After=1 should make us wait >=1s, elapsed={elapsed:?}"
    );
    assert!(
        elapsed < upper_bound,
        "Retry-After=1 should not balloon, elapsed={elapsed:?}"
    );
}
/// The server answers 200 with one SSE event while the scheme's parser is
/// configured to fail: the `ClientError::Sse` must come out of the stream and
/// the server must have seen only a single request.
#[tokio::test]
async fn mid_stream_sse_error_does_not_retry() {
    let server = MockServer::start().await;
    let sse_reply = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(
            b"event: data\ndata: payload\n\n".to_vec(),
            "text/event-stream",
        );
    Mock::given(method("POST"))
        .and(path("/v1/chat"))
        .respond_with(sse_reply)
        .mount(&server)
        .await;

    // `parse_fail = true` makes the scheme reject the event during consumption.
    let transport = build_transport(server.uri(), true, fast_policy(5));
    let mut events = transport
        .stream(Request::default())
        .await
        .expect("status 200 should bypass retry loop");

    // Consume the whole stream, remembering whether any item was an Sse error.
    let mut saw_sse_err = false;
    while let Some(item) = events.next().await {
        saw_sse_err |= matches!(item, Err(ClientError::Sse(_)));
    }
    assert!(saw_sse_err, "expected Sse error from stream consumer");

    let request_count = server.received_requests().await.unwrap().len();
    assert_eq!(request_count, 1, "mid-stream Sse must not retry");
}
/// The server always answers 401: the transport is expected to return
/// `ClientError::Api { status: Some(401), .. }` after a single request.
#[tokio::test]
async fn non_retryable_status_returns_immediately() {
    let server = MockServer::start().await;
    let unauthorized = ResponseTemplate::new(401).set_body_string("unauthorized");
    Mock::given(method("POST"))
        .and(path("/v1/chat"))
        .respond_with(unauthorized)
        .mount(&server)
        .await;

    let transport = build_transport(server.uri(), false, fast_policy(5));
    let other = match transport.stream(Request::default()).await {
        Ok(_) => panic!("expected error"),
        Err(other) => other,
    };
    if !matches!(
        other,
        ClientError::Api {
            status: Some(401),
            ..
        }
    ) {
        panic!("expected Api(401), got {other:?}");
    }

    let request_count = server.received_requests().await.unwrap().len();
    assert_eq!(request_count, 1, "401 must not retry");
}