From 9cd776eaeceb4249b171d84982529fd464994a8f Mon Sep 17 00:00:00 2001 From: Hare Date: Thu, 28 May 2026 02:41:15 +0900 Subject: [PATCH] fix: add llm request lifecycle timeouts --- crates/llm-worker/Cargo.toml | 2 +- crates/llm-worker/src/llm_client/error.rs | 18 ++ crates/llm-worker/src/llm_client/transport.rs | 38 +++- crates/llm-worker/src/worker.rs | 177 +++++++++++++----- crates/provider/Cargo.toml | 2 +- crates/provider/src/codex_oauth/refresh.rs | 50 ++++- 6 files changed, 231 insertions(+), 56 deletions(-) diff --git a/crates/llm-worker/Cargo.toml b/crates/llm-worker/Cargo.toml index f3bc5940..fcc29439 100644 --- a/crates/llm-worker/Cargo.toml +++ b/crates/llm-worker/Cargo.toml @@ -12,7 +12,7 @@ thiserror = { workspace = true } tracing = { workspace = true } async-trait = { workspace = true } futures = { workspace = true } -tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } +tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time"] } tokio-util = "0.7" reqwest = { version = "0.13", default-features = false, features = ["stream", "json", "native-tls", "http2"] } eventsource-stream = "0.2" diff --git a/crates/llm-worker/src/llm_client/error.rs b/crates/llm-worker/src/llm_client/error.rs index 50f39e03..d5723646 100644 --- a/crates/llm-worker/src/llm_client/error.rs +++ b/crates/llm-worker/src/llm_client/error.rs @@ -18,6 +18,11 @@ pub enum ClientError { message: String, retry_after: Option, }, + /// A request lifecycle phase exceeded its hard timeout. + Timeout { + phase: &'static str, + timeout: Duration, + }, /// 設定エラー Config(String), } @@ -43,6 +48,9 @@ impl fmt::Display for ClientError { } write!(f, ": {}", message) } + ClientError::Timeout { phase, timeout } => { + write!(f, "{phase} timed out after {}s", timeout.as_secs()) + } ClientError::Config(msg) => write!(f, "Config error: {}", msg), } } @@ -91,6 +99,7 @@ impl ClientError { /// 対象: /// - `Api { status }` のうち 408 / 425 / 429 / 500 / 502 / 503 / 504 / 529 /// - `Http(reqwest::Error)` のうち `is_connect()` または `is_timeout()` +/// - `Timeout { .. }` の lifecycle hard timeout /// /// それ以外(Json、Sse、Config、上記以外の Api ステータス)は false。 /// SSE 読み出し開始後の失敗は呼び出し側で `Sse` として上に流すため、 @@ -101,6 +110,7 @@ pub fn is_retryable(error: &ClientError) -> bool { status: Some(code), .. } => matches!(*code, 408 | 425 | 429 | 500 | 502 | 503 | 504 | 529), ClientError::Api { status: None, .. } => false, + ClientError::Timeout { .. } => true, ClientError::Http(e) => e.is_connect() || e.is_timeout(), ClientError::Json(_) | ClientError::Sse(_) | ClientError::Config(_) => false, } @@ -144,6 +154,14 @@ mod tests { assert!(!is_retryable(&api_err(None))); } + #[test] + fn lifecycle_timeout_is_retryable() { + assert!(is_retryable(&ClientError::Timeout { + phase: "stream_open", + timeout: Duration::from_secs(30), + })); + } + #[test] fn json_sse_config_not_retryable() { let json_err = serde_json::from_str::("not json").unwrap_err(); diff --git a/crates/llm-worker/src/llm_client/transport.rs b/crates/llm-worker/src/llm_client/transport.rs index 334d883e..ae9f7b6c 100644 --- a/crates/llm-worker/src/llm_client/transport.rs +++ b/crates/llm-worker/src/llm_client/transport.rs @@ -23,6 +23,9 @@ use super::event::Event; use super::scheme::Scheme; use super::types::{Request, RequestConfig}; +pub const DEFAULT_STREAM_OPEN_TIMEOUT: Duration = Duration::from_secs(30); +pub const DEFAULT_FIRST_STREAM_EVENT_TIMEOUT: Duration = Duration::from_secs(30); + /// `AuthRef` を解決したランタイム表現。`crates/provider` が構築する。 /// /// - `None`: 認証ヘッダを送らない(Ollama 等の opt-out) @@ -201,6 +204,17 @@ enum RequestBody { CompressedJson(Vec), } +async fn response_with_timeout( + future: impl std::future::Future>, + timeout: Duration, + phase: &'static str, +) -> Result { + tokio::time::timeout(timeout, future) + .await + .map_err(|_| ClientError::Timeout { phase, timeout })? + .map_err(ClientError::Http) +} + impl Clone for HttpTransport { fn clone(&self) -> Self { Self { @@ -272,7 +286,9 @@ impl LlmClient for HttpTransport { RequestBody::Json(body) => builder.json(&body), RequestBody::CompressedJson(body) => builder.body(body), }; - let response = builder.send().await.map_err(ClientError::Http)?; + let response = + response_with_timeout(builder.send(), DEFAULT_STREAM_OPEN_TIMEOUT, "stream_open") + .await?; if !response.status().is_success() { return Err(classify_error_response(response).await); @@ -391,6 +407,26 @@ mod tests { ) } + #[tokio::test] + async fn response_timeout_returns_retryable_lifecycle_timeout() { + let err = response_with_timeout( + std::future::pending::>(), + Duration::from_millis(5), + "stream_open", + ) + .await + .unwrap_err(); + + assert!(crate::llm_client::error::is_retryable(&err)); + assert!(matches!( + err, + ClientError::Timeout { + phase: "stream_open", + .. + } + )); + } + #[tokio::test] async fn codex_backend_adds_conversation_headers_and_zstd_body() { let transport = transport(ResolvedAuth::Custom(Arc::new(TestAuthProvider { diff --git a/crates/llm-worker/src/worker.rs b/crates/llm-worker/src/worker.rs index b65036be..c4313d99 100644 --- a/crates/llm-worker/src/worker.rs +++ b/crates/llm-worker/src/worker.rs @@ -20,7 +20,7 @@ use crate::{ llm_client::{ ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ResponseStream, ToolDefinition, error::is_retryable, event::Event, retry::RetryPolicy, - types::parse_tool_arguments, + transport::DEFAULT_FIRST_STREAM_EVENT_TIMEOUT, types::parse_tool_arguments, }, state::{Locked, Mutable, WorkerState}, timeline::event::{ErrorEvent, StatusEvent, UsageEvent}, @@ -1334,7 +1334,7 @@ impl Worker { } }; - match stream_result { + let err = match stream_result { Ok(stream) => { self.emit_lifecycle_trace( turn, @@ -1345,7 +1345,26 @@ impl Worker { "elapsed_ms": stream_started.elapsed().as_millis() as u64, }), ); - return Ok(stream); + match wait_for_first_stream_event(stream, DEFAULT_FIRST_STREAM_EVENT_TIMEOUT) + .await + { + Ok(FirstStreamEvent::Ready(stream)) => return Ok(stream), + Ok(FirstStreamEvent::Empty(stream)) => return Ok(stream), + Err(err) => { + self.emit_lifecycle_trace( + turn, + llm_call, + "stream_first_event_error", + json!({ + "attempt": attempt, + "elapsed_ms": stream_started.elapsed().as_millis() as u64, + "retryable": is_retryable(&err), + "error": err.to_string(), + }), + ); + err + } + } } Err(err) => { self.emit_lifecycle_trace( @@ -1360,54 +1379,56 @@ impl Worker { "error": err.to_string(), }), ); - let next_failed_attempt = failed_attempt + 1; - if next_failed_attempt >= policy.max_attempts || !is_retryable(&err) { - self.last_run_interrupted = true; - return Err(WorkerError::Client(err)); - } + err + } + }; - let wait = err - .retry_after() - .unwrap_or_else(|| policy.backoff(failed_attempt)); - let elapsed = started.elapsed(); - if elapsed + wait > policy.total_timeout { - self.last_run_interrupted = true; - return Err(WorkerError::Client(err)); - } + let next_failed_attempt = failed_attempt + 1; + if next_failed_attempt >= policy.max_attempts || !is_retryable(&err) { + self.last_run_interrupted = true; + return Err(WorkerError::Client(err)); + } - warn!( - error = %err, - failed_attempt = next_failed_attempt, - wait_ms = wait.as_millis() as u64, - "transient LLM request error, retrying" - ); - let notice = LlmRetryNotice { - failed_attempt: next_failed_attempt, - max_attempts: policy.max_attempts, - wait, - elapsed, - status: err.status(), - error: err.to_string(), - }; - for cb in &self.llm_retry_cbs { - cb(llm_call, ¬ice); - } + let wait = err + .retry_after() + .unwrap_or_else(|| policy.backoff(failed_attempt)); + let elapsed = started.elapsed(); + if elapsed + wait > policy.total_timeout { + self.last_run_interrupted = true; + return Err(WorkerError::Client(err)); + } - tokio::select! { - _ = tokio::time::sleep(wait) => {} - cancel = self.cancel_rx.recv() => { - if cancel.is_some() { - info!("Cancelled during LLM retry backoff"); - } - self.timeline.abort_current_block(); - self.last_run_interrupted = true; - return Err(WorkerError::Cancelled); - } - } + warn!( + error = %err, + failed_attempt = next_failed_attempt, + wait_ms = wait.as_millis() as u64, + "transient LLM request error, retrying" + ); + let notice = LlmRetryNotice { + failed_attempt: next_failed_attempt, + max_attempts: policy.max_attempts, + wait, + elapsed, + status: err.status(), + error: err.to_string(), + }; + for cb in &self.llm_retry_cbs { + cb(llm_call, ¬ice); + } - failed_attempt = next_failed_attempt; + tokio::select! { + _ = tokio::time::sleep(wait) => {} + cancel = self.cancel_rx.recv() => { + if cancel.is_some() { + info!("Cancelled during LLM retry backoff"); + } + self.timeline.abort_current_block(); + self.last_run_interrupted = true; + return Err(WorkerError::Cancelled); } } + + failed_attempt = next_failed_attempt; } } @@ -1932,6 +1953,29 @@ impl Worker { } } +enum FirstStreamEvent { + Ready(ResponseStream), + Empty(ResponseStream), +} + +async fn wait_for_first_stream_event( + mut stream: ResponseStream, + timeout: std::time::Duration, +) -> Result { + match tokio::time::timeout(timeout, stream.next()).await { + Ok(Some(first)) => { + let first = first?; + let stream = futures::stream::once(async move { Ok(first) }).chain(stream); + Ok(FirstStreamEvent::Ready(Box::pin(stream))) + } + Ok(None) => Ok(FirstStreamEvent::Empty(stream)), + Err(_) => Err(ClientError::Timeout { + phase: "stream_first_event", + timeout, + }), + } +} + fn items_trace_payload( items: &[Item], tools_len: usize, @@ -1990,5 +2034,46 @@ fn item_kind(item: &Item) -> &'static str { #[cfg(test)] mod tests { - // Basic tests only. Tests using LlmClient are done in integration tests. + use super::*; + use std::time::Duration; + + #[tokio::test] + async fn first_stream_event_timeout_returns_retryable_timeout() { + let stream: ResponseStream = Box::pin(futures::stream::pending()); + let err = match wait_for_first_stream_event(stream, Duration::from_millis(5)).await { + Ok(_) => panic!("expected first event timeout"), + Err(err) => err, + }; + + assert!(is_retryable(&err)); + assert!(matches!( + err, + ClientError::Timeout { + phase: "stream_first_event", + .. + } + )); + } + + #[tokio::test] + async fn first_stream_event_is_replayed_after_probe() { + let first = Event::Status(crate::llm_client::event::StatusEvent { + status: crate::llm_client::event::ResponseStatus::Started, + }); + let stream: ResponseStream = Box::pin(futures::stream::once({ + let first = first.clone(); + async move { Ok(first) } + })); + + let FirstStreamEvent::Ready(mut stream) = + wait_for_first_stream_event(stream, Duration::from_secs(1)) + .await + .unwrap() + else { + panic!("expected first event to be buffered"); + }; + + let replayed = stream.next().await.unwrap().unwrap(); + assert_eq!(replayed, first); + } } diff --git a/crates/provider/Cargo.toml b/crates/provider/Cargo.toml index 3e83ae24..8ae8939c 100644 --- a/crates/provider/Cargo.toml +++ b/crates/provider/Cargo.toml @@ -14,7 +14,7 @@ reqwest = { version = "0.13", features = ["json", "native-tls"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } thiserror = { workspace = true } -tokio = { workspace = true, features = ["sync", "fs", "rt"] } +tokio = { workspace = true, features = ["sync", "fs", "rt", "time"] } toml = { workspace = true } tracing = { workspace = true } diff --git a/crates/provider/src/codex_oauth/refresh.rs b/crates/provider/src/codex_oauth/refresh.rs index ff15844e..1d94f8c9 100644 --- a/crates/provider/src/codex_oauth/refresh.rs +++ b/crates/provider/src/codex_oauth/refresh.rs @@ -4,11 +4,13 @@ //! 401 + `error.code` で永続失敗を分類する。 use serde::{Deserialize, Serialize}; +use std::time::Duration; use super::error::{CodexAuthError, PermanentReason}; pub const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; pub const REFRESH_URL: &str = "https://auth.openai.com/oauth/token"; +pub const DEFAULT_REFRESH_TIMEOUT: Duration = Duration::from_secs(30); #[derive(Serialize)] struct RefreshRequest<'a> { @@ -41,13 +43,15 @@ pub async fn request_refresh( grant_type: "refresh_token", refresh_token, }; - let response = client - .post(endpoint) - .header("Content-Type", "application/json") - .json(&body) - .send() - .await - .map_err(|e| CodexAuthError::RefreshTransient(format!("send: {e}")))?; + let response = response_with_timeout( + client + .post(endpoint) + .header("Content-Type", "application/json") + .json(&body) + .send(), + DEFAULT_REFRESH_TIMEOUT, + ) + .await?; let status = response.status(); if status.is_success() { @@ -68,6 +72,21 @@ pub async fn request_refresh( } } +async fn response_with_timeout( + future: impl std::future::Future>, + timeout: Duration, +) -> Result { + tokio::time::timeout(timeout, future) + .await + .map_err(|_| { + CodexAuthError::RefreshTransient(format!( + "codex_oauth_refresh timed out after {}s", + timeout.as_secs() + )) + })? + .map_err(|e| CodexAuthError::RefreshTransient(format!("send: {e}"))) +} + fn classify_permanent(body: &str) -> (PermanentReason, String) { let code = extract_error_code(body); let reason = match code.as_deref() { @@ -107,6 +126,23 @@ fn extract_error_code(body: &str) -> Option { mod tests { use super::*; + #[tokio::test] + async fn refresh_response_timeout_is_transient() { + let err = match response_with_timeout( + std::future::pending::>(), + Duration::from_millis(5), + ) + .await + { + Ok(_) => panic!("expected refresh timeout"), + Err(err) => err, + }; + + assert!( + matches!(err, CodexAuthError::RefreshTransient(message) if message.contains("timed out")) + ); + } + #[test] fn classify_expired() { let body = r#"{"error":{"code":"refresh_token_expired"}}"#;