fix: add llm request lifecycle timeouts

This commit is contained in:
Keisuke Hirata 2026-05-28 02:41:15 +09:00
parent 647223eb32
commit 9cd776eaec
6 changed files with 231 additions and 56 deletions

View File

@ -12,7 +12,7 @@ thiserror = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
async-trait = { workspace = true } async-trait = { workspace = true }
futures = { workspace = true } futures = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time"] }
tokio-util = "0.7" tokio-util = "0.7"
reqwest = { version = "0.13", default-features = false, features = ["stream", "json", "native-tls", "http2"] } reqwest = { version = "0.13", default-features = false, features = ["stream", "json", "native-tls", "http2"] }
eventsource-stream = "0.2" eventsource-stream = "0.2"

View File

@ -18,6 +18,11 @@ pub enum ClientError {
message: String, message: String,
retry_after: Option<Duration>, retry_after: Option<Duration>,
}, },
/// A request lifecycle phase exceeded its hard timeout.
Timeout {
phase: &'static str,
timeout: Duration,
},
/// 設定エラー /// 設定エラー
Config(String), Config(String),
} }
@ -43,6 +48,9 @@ impl fmt::Display for ClientError {
} }
write!(f, ": {}", message) write!(f, ": {}", message)
} }
ClientError::Timeout { phase, timeout } => {
write!(f, "{phase} timed out after {}s", timeout.as_secs())
}
ClientError::Config(msg) => write!(f, "Config error: {}", msg), ClientError::Config(msg) => write!(f, "Config error: {}", msg),
} }
} }
@ -91,6 +99,7 @@ impl ClientError {
/// 対象: /// 対象:
/// - `Api { status }` のうち 408 / 425 / 429 / 500 / 502 / 503 / 504 / 529 /// - `Api { status }` のうち 408 / 425 / 429 / 500 / 502 / 503 / 504 / 529
/// - `Http(reqwest::Error)` のうち `is_connect()` または `is_timeout()` /// - `Http(reqwest::Error)` のうち `is_connect()` または `is_timeout()`
/// - `Timeout { .. }` の lifecycle hard timeout
/// ///
/// それ以外Json、Sse、Config、上記以外の Api ステータス)は false。 /// それ以外Json、Sse、Config、上記以外の Api ステータス)は false。
/// SSE 読み出し開始後の失敗は呼び出し側で `Sse` として上に流すため、 /// SSE 読み出し開始後の失敗は呼び出し側で `Sse` として上に流すため、
@ -101,6 +110,7 @@ pub fn is_retryable(error: &ClientError) -> bool {
status: Some(code), .. status: Some(code), ..
} => matches!(*code, 408 | 425 | 429 | 500 | 502 | 503 | 504 | 529), } => matches!(*code, 408 | 425 | 429 | 500 | 502 | 503 | 504 | 529),
ClientError::Api { status: None, .. } => false, ClientError::Api { status: None, .. } => false,
ClientError::Timeout { .. } => true,
ClientError::Http(e) => e.is_connect() || e.is_timeout(), ClientError::Http(e) => e.is_connect() || e.is_timeout(),
ClientError::Json(_) | ClientError::Sse(_) | ClientError::Config(_) => false, ClientError::Json(_) | ClientError::Sse(_) | ClientError::Config(_) => false,
} }
@ -144,6 +154,14 @@ mod tests {
assert!(!is_retryable(&api_err(None))); assert!(!is_retryable(&api_err(None)));
} }
#[test]
fn lifecycle_timeout_is_retryable() {
assert!(is_retryable(&ClientError::Timeout {
phase: "stream_open",
timeout: Duration::from_secs(30),
}));
}
#[test] #[test]
fn json_sse_config_not_retryable() { fn json_sse_config_not_retryable() {
let json_err = serde_json::from_str::<serde_json::Value>("not json").unwrap_err(); let json_err = serde_json::from_str::<serde_json::Value>("not json").unwrap_err();

View File

@ -23,6 +23,9 @@ use super::event::Event;
use super::scheme::Scheme; use super::scheme::Scheme;
use super::types::{Request, RequestConfig}; use super::types::{Request, RequestConfig};
pub const DEFAULT_STREAM_OPEN_TIMEOUT: Duration = Duration::from_secs(30);
pub const DEFAULT_FIRST_STREAM_EVENT_TIMEOUT: Duration = Duration::from_secs(30);
/// `AuthRef` を解決したランタイム表現。`crates/provider` が構築する。 /// `AuthRef` を解決したランタイム表現。`crates/provider` が構築する。
/// ///
/// - `None`: 認証ヘッダを送らないOllama 等の opt-out /// - `None`: 認証ヘッダを送らないOllama 等の opt-out
@ -201,6 +204,17 @@ enum RequestBody {
CompressedJson(Vec<u8>), CompressedJson(Vec<u8>),
} }
async fn response_with_timeout(
future: impl std::future::Future<Output = Result<reqwest::Response, reqwest::Error>>,
timeout: Duration,
phase: &'static str,
) -> Result<reqwest::Response, ClientError> {
tokio::time::timeout(timeout, future)
.await
.map_err(|_| ClientError::Timeout { phase, timeout })?
.map_err(ClientError::Http)
}
impl<S: Scheme + Clone> Clone for HttpTransport<S> { impl<S: Scheme + Clone> Clone for HttpTransport<S> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
@ -272,7 +286,9 @@ impl<S: Scheme + Clone + 'static> LlmClient for HttpTransport<S> {
RequestBody::Json(body) => builder.json(&body), RequestBody::Json(body) => builder.json(&body),
RequestBody::CompressedJson(body) => builder.body(body), RequestBody::CompressedJson(body) => builder.body(body),
}; };
let response = builder.send().await.map_err(ClientError::Http)?; let response =
response_with_timeout(builder.send(), DEFAULT_STREAM_OPEN_TIMEOUT, "stream_open")
.await?;
if !response.status().is_success() { if !response.status().is_success() {
return Err(classify_error_response(response).await); return Err(classify_error_response(response).await);
@ -391,6 +407,26 @@ mod tests {
) )
} }
#[tokio::test]
async fn response_timeout_returns_retryable_lifecycle_timeout() {
let err = response_with_timeout(
std::future::pending::<Result<reqwest::Response, reqwest::Error>>(),
Duration::from_millis(5),
"stream_open",
)
.await
.unwrap_err();
assert!(crate::llm_client::error::is_retryable(&err));
assert!(matches!(
err,
ClientError::Timeout {
phase: "stream_open",
..
}
));
}
#[tokio::test] #[tokio::test]
async fn codex_backend_adds_conversation_headers_and_zstd_body() { async fn codex_backend_adds_conversation_headers_and_zstd_body() {
let transport = transport(ResolvedAuth::Custom(Arc::new(TestAuthProvider { let transport = transport(ResolvedAuth::Custom(Arc::new(TestAuthProvider {

View File

@ -20,7 +20,7 @@ use crate::{
llm_client::{ llm_client::{
ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ResponseStream, ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ResponseStream,
ToolDefinition, error::is_retryable, event::Event, retry::RetryPolicy, ToolDefinition, error::is_retryable, event::Event, retry::RetryPolicy,
types::parse_tool_arguments, transport::DEFAULT_FIRST_STREAM_EVENT_TIMEOUT, types::parse_tool_arguments,
}, },
state::{Locked, Mutable, WorkerState}, state::{Locked, Mutable, WorkerState},
timeline::event::{ErrorEvent, StatusEvent, UsageEvent}, timeline::event::{ErrorEvent, StatusEvent, UsageEvent},
@ -1334,7 +1334,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
} }
}; };
match stream_result { let err = match stream_result {
Ok(stream) => { Ok(stream) => {
self.emit_lifecycle_trace( self.emit_lifecycle_trace(
turn, turn,
@ -1345,7 +1345,26 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
"elapsed_ms": stream_started.elapsed().as_millis() as u64, "elapsed_ms": stream_started.elapsed().as_millis() as u64,
}), }),
); );
return Ok(stream); match wait_for_first_stream_event(stream, DEFAULT_FIRST_STREAM_EVENT_TIMEOUT)
.await
{
Ok(FirstStreamEvent::Ready(stream)) => return Ok(stream),
Ok(FirstStreamEvent::Empty(stream)) => return Ok(stream),
Err(err) => {
self.emit_lifecycle_trace(
turn,
llm_call,
"stream_first_event_error",
json!({
"attempt": attempt,
"elapsed_ms": stream_started.elapsed().as_millis() as u64,
"retryable": is_retryable(&err),
"error": err.to_string(),
}),
);
err
}
}
} }
Err(err) => { Err(err) => {
self.emit_lifecycle_trace( self.emit_lifecycle_trace(
@ -1360,54 +1379,56 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
"error": err.to_string(), "error": err.to_string(),
}), }),
); );
let next_failed_attempt = failed_attempt + 1; err
if next_failed_attempt >= policy.max_attempts || !is_retryable(&err) { }
self.last_run_interrupted = true; };
return Err(WorkerError::Client(err));
}
let wait = err let next_failed_attempt = failed_attempt + 1;
.retry_after() if next_failed_attempt >= policy.max_attempts || !is_retryable(&err) {
.unwrap_or_else(|| policy.backoff(failed_attempt)); self.last_run_interrupted = true;
let elapsed = started.elapsed(); return Err(WorkerError::Client(err));
if elapsed + wait > policy.total_timeout { }
self.last_run_interrupted = true;
return Err(WorkerError::Client(err));
}
warn!( let wait = err
error = %err, .retry_after()
failed_attempt = next_failed_attempt, .unwrap_or_else(|| policy.backoff(failed_attempt));
wait_ms = wait.as_millis() as u64, let elapsed = started.elapsed();
"transient LLM request error, retrying" if elapsed + wait > policy.total_timeout {
); self.last_run_interrupted = true;
let notice = LlmRetryNotice { return Err(WorkerError::Client(err));
failed_attempt: next_failed_attempt, }
max_attempts: policy.max_attempts,
wait,
elapsed,
status: err.status(),
error: err.to_string(),
};
for cb in &self.llm_retry_cbs {
cb(llm_call, &notice);
}
tokio::select! { warn!(
_ = tokio::time::sleep(wait) => {} error = %err,
cancel = self.cancel_rx.recv() => { failed_attempt = next_failed_attempt,
if cancel.is_some() { wait_ms = wait.as_millis() as u64,
info!("Cancelled during LLM retry backoff"); "transient LLM request error, retrying"
} );
self.timeline.abort_current_block(); let notice = LlmRetryNotice {
self.last_run_interrupted = true; failed_attempt: next_failed_attempt,
return Err(WorkerError::Cancelled); max_attempts: policy.max_attempts,
} wait,
} elapsed,
status: err.status(),
error: err.to_string(),
};
for cb in &self.llm_retry_cbs {
cb(llm_call, &notice);
}
failed_attempt = next_failed_attempt; tokio::select! {
_ = tokio::time::sleep(wait) => {}
cancel = self.cancel_rx.recv() => {
if cancel.is_some() {
info!("Cancelled during LLM retry backoff");
}
self.timeline.abort_current_block();
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
} }
} }
failed_attempt = next_failed_attempt;
} }
} }
@ -1932,6 +1953,29 @@ impl<C: LlmClient> Worker<C, Locked> {
} }
} }
enum FirstStreamEvent {
Ready(ResponseStream),
Empty(ResponseStream),
}
async fn wait_for_first_stream_event(
mut stream: ResponseStream,
timeout: std::time::Duration,
) -> Result<FirstStreamEvent, ClientError> {
match tokio::time::timeout(timeout, stream.next()).await {
Ok(Some(first)) => {
let first = first?;
let stream = futures::stream::once(async move { Ok(first) }).chain(stream);
Ok(FirstStreamEvent::Ready(Box::pin(stream)))
}
Ok(None) => Ok(FirstStreamEvent::Empty(stream)),
Err(_) => Err(ClientError::Timeout {
phase: "stream_first_event",
timeout,
}),
}
}
fn items_trace_payload( fn items_trace_payload(
items: &[Item], items: &[Item],
tools_len: usize, tools_len: usize,
@ -1990,5 +2034,46 @@ fn item_kind(item: &Item) -> &'static str {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
// Basic tests only. Tests using LlmClient are done in integration tests. use super::*;
use std::time::Duration;
#[tokio::test]
async fn first_stream_event_timeout_returns_retryable_timeout() {
let stream: ResponseStream = Box::pin(futures::stream::pending());
let err = match wait_for_first_stream_event(stream, Duration::from_millis(5)).await {
Ok(_) => panic!("expected first event timeout"),
Err(err) => err,
};
assert!(is_retryable(&err));
assert!(matches!(
err,
ClientError::Timeout {
phase: "stream_first_event",
..
}
));
}
#[tokio::test]
async fn first_stream_event_is_replayed_after_probe() {
let first = Event::Status(crate::llm_client::event::StatusEvent {
status: crate::llm_client::event::ResponseStatus::Started,
});
let stream: ResponseStream = Box::pin(futures::stream::once({
let first = first.clone();
async move { Ok(first) }
}));
let FirstStreamEvent::Ready(mut stream) =
wait_for_first_stream_event(stream, Duration::from_secs(1))
.await
.unwrap()
else {
panic!("expected first event to be buffered");
};
let replayed = stream.next().await.unwrap().unwrap();
assert_eq!(replayed, first);
}
} }

View File

@ -14,7 +14,7 @@ reqwest = { version = "0.13", features = ["json", "native-tls"] }
serde = { workspace = true, features = ["derive"] } serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true } serde_json = { workspace = true }
thiserror = { workspace = true } thiserror = { workspace = true }
tokio = { workspace = true, features = ["sync", "fs", "rt"] } tokio = { workspace = true, features = ["sync", "fs", "rt", "time"] }
toml = { workspace = true } toml = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }

View File

@ -4,11 +4,13 @@
//! 401 + `error.code` で永続失敗を分類する。 //! 401 + `error.code` で永続失敗を分類する。
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::time::Duration;
use super::error::{CodexAuthError, PermanentReason}; use super::error::{CodexAuthError, PermanentReason};
pub const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; pub const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
pub const REFRESH_URL: &str = "https://auth.openai.com/oauth/token"; pub const REFRESH_URL: &str = "https://auth.openai.com/oauth/token";
pub const DEFAULT_REFRESH_TIMEOUT: Duration = Duration::from_secs(30);
#[derive(Serialize)] #[derive(Serialize)]
struct RefreshRequest<'a> { struct RefreshRequest<'a> {
@ -41,13 +43,15 @@ pub async fn request_refresh(
grant_type: "refresh_token", grant_type: "refresh_token",
refresh_token, refresh_token,
}; };
let response = client let response = response_with_timeout(
.post(endpoint) client
.header("Content-Type", "application/json") .post(endpoint)
.json(&body) .header("Content-Type", "application/json")
.send() .json(&body)
.await .send(),
.map_err(|e| CodexAuthError::RefreshTransient(format!("send: {e}")))?; DEFAULT_REFRESH_TIMEOUT,
)
.await?;
let status = response.status(); let status = response.status();
if status.is_success() { if status.is_success() {
@ -68,6 +72,21 @@ pub async fn request_refresh(
} }
} }
async fn response_with_timeout(
future: impl std::future::Future<Output = Result<reqwest::Response, reqwest::Error>>,
timeout: Duration,
) -> Result<reqwest::Response, CodexAuthError> {
tokio::time::timeout(timeout, future)
.await
.map_err(|_| {
CodexAuthError::RefreshTransient(format!(
"codex_oauth_refresh timed out after {}s",
timeout.as_secs()
))
})?
.map_err(|e| CodexAuthError::RefreshTransient(format!("send: {e}")))
}
fn classify_permanent(body: &str) -> (PermanentReason, String) { fn classify_permanent(body: &str) -> (PermanentReason, String) {
let code = extract_error_code(body); let code = extract_error_code(body);
let reason = match code.as_deref() { let reason = match code.as_deref() {
@ -107,6 +126,23 @@ fn extract_error_code(body: &str) -> Option<String> {
mod tests { mod tests {
use super::*; use super::*;
#[tokio::test]
async fn refresh_response_timeout_is_transient() {
let err = match response_with_timeout(
std::future::pending::<Result<reqwest::Response, reqwest::Error>>(),
Duration::from_millis(5),
)
.await
{
Ok(_) => panic!("expected refresh timeout"),
Err(err) => err,
};
assert!(
matches!(err, CodexAuthError::RefreshTransient(message) if message.contains("timed out"))
);
}
#[test] #[test]
fn classify_expired() { fn classify_expired() {
let body = r#"{"error":{"code":"refresh_token_expired"}}"#; let body = r#"{"error":{"code":"refresh_token_expired"}}"#;