fix: add llm request lifecycle timeouts

This commit is contained in:
Keisuke Hirata 2026-05-28 02:41:15 +09:00
parent 647223eb32
commit 9cd776eaec
6 changed files with 231 additions and 56 deletions

View File

@ -12,7 +12,7 @@ thiserror = { workspace = true }
tracing = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time"] }
tokio-util = "0.7"
reqwest = { version = "0.13", default-features = false, features = ["stream", "json", "native-tls", "http2"] }
eventsource-stream = "0.2"

View File

@ -18,6 +18,11 @@ pub enum ClientError {
message: String,
retry_after: Option<Duration>,
},
/// A request lifecycle phase exceeded its hard timeout.
Timeout {
phase: &'static str,
timeout: Duration,
},
/// 設定エラー
Config(String),
}
@ -43,6 +48,9 @@ impl fmt::Display for ClientError {
}
write!(f, ": {}", message)
}
ClientError::Timeout { phase, timeout } => {
write!(f, "{phase} timed out after {}s", timeout.as_secs())
}
ClientError::Config(msg) => write!(f, "Config error: {}", msg),
}
}
@ -91,6 +99,7 @@ impl ClientError {
/// 対象:
/// - `Api { status }` のうち 408 / 425 / 429 / 500 / 502 / 503 / 504 / 529
/// - `Http(reqwest::Error)` のうち `is_connect()` または `is_timeout()`
/// - `Timeout { .. }` の lifecycle hard timeout
///
/// それ以外Json、Sse、Config、上記以外の Api ステータス)は false。
/// SSE 読み出し開始後の失敗は呼び出し側で `Sse` として上に流すため、
@ -101,6 +110,7 @@ pub fn is_retryable(error: &ClientError) -> bool {
status: Some(code), ..
} => matches!(*code, 408 | 425 | 429 | 500 | 502 | 503 | 504 | 529),
ClientError::Api { status: None, .. } => false,
ClientError::Timeout { .. } => true,
ClientError::Http(e) => e.is_connect() || e.is_timeout(),
ClientError::Json(_) | ClientError::Sse(_) | ClientError::Config(_) => false,
}
@ -144,6 +154,14 @@ mod tests {
assert!(!is_retryable(&api_err(None)));
}
#[test]
fn lifecycle_timeout_is_retryable() {
assert!(is_retryable(&ClientError::Timeout {
phase: "stream_open",
timeout: Duration::from_secs(30),
}));
}
#[test]
fn json_sse_config_not_retryable() {
let json_err = serde_json::from_str::<serde_json::Value>("not json").unwrap_err();

View File

@ -23,6 +23,9 @@ use super::event::Event;
use super::scheme::Scheme;
use super::types::{Request, RequestConfig};
pub const DEFAULT_STREAM_OPEN_TIMEOUT: Duration = Duration::from_secs(30);
pub const DEFAULT_FIRST_STREAM_EVENT_TIMEOUT: Duration = Duration::from_secs(30);
/// `AuthRef` を解決したランタイム表現。`crates/provider` が構築する。
///
/// - `None`: 認証ヘッダを送らないOllama 等の opt-out
@ -201,6 +204,17 @@ enum RequestBody {
CompressedJson(Vec<u8>),
}
async fn response_with_timeout(
future: impl std::future::Future<Output = Result<reqwest::Response, reqwest::Error>>,
timeout: Duration,
phase: &'static str,
) -> Result<reqwest::Response, ClientError> {
tokio::time::timeout(timeout, future)
.await
.map_err(|_| ClientError::Timeout { phase, timeout })?
.map_err(ClientError::Http)
}
impl<S: Scheme + Clone> Clone for HttpTransport<S> {
fn clone(&self) -> Self {
Self {
@ -272,7 +286,9 @@ impl<S: Scheme + Clone + 'static> LlmClient for HttpTransport<S> {
RequestBody::Json(body) => builder.json(&body),
RequestBody::CompressedJson(body) => builder.body(body),
};
let response = builder.send().await.map_err(ClientError::Http)?;
let response =
response_with_timeout(builder.send(), DEFAULT_STREAM_OPEN_TIMEOUT, "stream_open")
.await?;
if !response.status().is_success() {
return Err(classify_error_response(response).await);
@ -391,6 +407,26 @@ mod tests {
)
}
#[tokio::test]
async fn response_timeout_returns_retryable_lifecycle_timeout() {
let err = response_with_timeout(
std::future::pending::<Result<reqwest::Response, reqwest::Error>>(),
Duration::from_millis(5),
"stream_open",
)
.await
.unwrap_err();
assert!(crate::llm_client::error::is_retryable(&err));
assert!(matches!(
err,
ClientError::Timeout {
phase: "stream_open",
..
}
));
}
#[tokio::test]
async fn codex_backend_adds_conversation_headers_and_zstd_body() {
let transport = transport(ResolvedAuth::Custom(Arc::new(TestAuthProvider {

View File

@ -20,7 +20,7 @@ use crate::{
llm_client::{
ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ResponseStream,
ToolDefinition, error::is_retryable, event::Event, retry::RetryPolicy,
types::parse_tool_arguments,
transport::DEFAULT_FIRST_STREAM_EVENT_TIMEOUT, types::parse_tool_arguments,
},
state::{Locked, Mutable, WorkerState},
timeline::event::{ErrorEvent, StatusEvent, UsageEvent},
@ -1334,7 +1334,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
}
};
match stream_result {
let err = match stream_result {
Ok(stream) => {
self.emit_lifecycle_trace(
turn,
@ -1345,7 +1345,26 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
"elapsed_ms": stream_started.elapsed().as_millis() as u64,
}),
);
return Ok(stream);
match wait_for_first_stream_event(stream, DEFAULT_FIRST_STREAM_EVENT_TIMEOUT)
.await
{
Ok(FirstStreamEvent::Ready(stream)) => return Ok(stream),
Ok(FirstStreamEvent::Empty(stream)) => return Ok(stream),
Err(err) => {
self.emit_lifecycle_trace(
turn,
llm_call,
"stream_first_event_error",
json!({
"attempt": attempt,
"elapsed_ms": stream_started.elapsed().as_millis() as u64,
"retryable": is_retryable(&err),
"error": err.to_string(),
}),
);
err
}
}
}
Err(err) => {
self.emit_lifecycle_trace(
@ -1360,54 +1379,56 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
"error": err.to_string(),
}),
);
let next_failed_attempt = failed_attempt + 1;
if next_failed_attempt >= policy.max_attempts || !is_retryable(&err) {
self.last_run_interrupted = true;
return Err(WorkerError::Client(err));
}
err
}
};
let wait = err
.retry_after()
.unwrap_or_else(|| policy.backoff(failed_attempt));
let elapsed = started.elapsed();
if elapsed + wait > policy.total_timeout {
self.last_run_interrupted = true;
return Err(WorkerError::Client(err));
}
let next_failed_attempt = failed_attempt + 1;
if next_failed_attempt >= policy.max_attempts || !is_retryable(&err) {
self.last_run_interrupted = true;
return Err(WorkerError::Client(err));
}
warn!(
error = %err,
failed_attempt = next_failed_attempt,
wait_ms = wait.as_millis() as u64,
"transient LLM request error, retrying"
);
let notice = LlmRetryNotice {
failed_attempt: next_failed_attempt,
max_attempts: policy.max_attempts,
wait,
elapsed,
status: err.status(),
error: err.to_string(),
};
for cb in &self.llm_retry_cbs {
cb(llm_call, &notice);
}
let wait = err
.retry_after()
.unwrap_or_else(|| policy.backoff(failed_attempt));
let elapsed = started.elapsed();
if elapsed + wait > policy.total_timeout {
self.last_run_interrupted = true;
return Err(WorkerError::Client(err));
}
tokio::select! {
_ = tokio::time::sleep(wait) => {}
cancel = self.cancel_rx.recv() => {
if cancel.is_some() {
info!("Cancelled during LLM retry backoff");
}
self.timeline.abort_current_block();
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
}
}
warn!(
error = %err,
failed_attempt = next_failed_attempt,
wait_ms = wait.as_millis() as u64,
"transient LLM request error, retrying"
);
let notice = LlmRetryNotice {
failed_attempt: next_failed_attempt,
max_attempts: policy.max_attempts,
wait,
elapsed,
status: err.status(),
error: err.to_string(),
};
for cb in &self.llm_retry_cbs {
cb(llm_call, &notice);
}
failed_attempt = next_failed_attempt;
tokio::select! {
_ = tokio::time::sleep(wait) => {}
cancel = self.cancel_rx.recv() => {
if cancel.is_some() {
info!("Cancelled during LLM retry backoff");
}
self.timeline.abort_current_block();
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
}
}
failed_attempt = next_failed_attempt;
}
}
@ -1932,6 +1953,29 @@ impl<C: LlmClient> Worker<C, Locked> {
}
}
enum FirstStreamEvent {
Ready(ResponseStream),
Empty(ResponseStream),
}
async fn wait_for_first_stream_event(
mut stream: ResponseStream,
timeout: std::time::Duration,
) -> Result<FirstStreamEvent, ClientError> {
match tokio::time::timeout(timeout, stream.next()).await {
Ok(Some(first)) => {
let first = first?;
let stream = futures::stream::once(async move { Ok(first) }).chain(stream);
Ok(FirstStreamEvent::Ready(Box::pin(stream)))
}
Ok(None) => Ok(FirstStreamEvent::Empty(stream)),
Err(_) => Err(ClientError::Timeout {
phase: "stream_first_event",
timeout,
}),
}
}
fn items_trace_payload(
items: &[Item],
tools_len: usize,
@ -1990,5 +2034,46 @@ fn item_kind(item: &Item) -> &'static str {
#[cfg(test)]
mod tests {
// Basic tests only. Tests using LlmClient are done in integration tests.
use super::*;
use std::time::Duration;
#[tokio::test]
async fn first_stream_event_timeout_returns_retryable_timeout() {
let stream: ResponseStream = Box::pin(futures::stream::pending());
let err = match wait_for_first_stream_event(stream, Duration::from_millis(5)).await {
Ok(_) => panic!("expected first event timeout"),
Err(err) => err,
};
assert!(is_retryable(&err));
assert!(matches!(
err,
ClientError::Timeout {
phase: "stream_first_event",
..
}
));
}
#[tokio::test]
async fn first_stream_event_is_replayed_after_probe() {
let first = Event::Status(crate::llm_client::event::StatusEvent {
status: crate::llm_client::event::ResponseStatus::Started,
});
let stream: ResponseStream = Box::pin(futures::stream::once({
let first = first.clone();
async move { Ok(first) }
}));
let FirstStreamEvent::Ready(mut stream) =
wait_for_first_stream_event(stream, Duration::from_secs(1))
.await
.unwrap()
else {
panic!("expected first event to be buffered");
};
let replayed = stream.next().await.unwrap().unwrap();
assert_eq!(replayed, first);
}
}

View File

@ -14,7 +14,7 @@ reqwest = { version = "0.13", features = ["json", "native-tls"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["sync", "fs", "rt"] }
tokio = { workspace = true, features = ["sync", "fs", "rt", "time"] }
toml = { workspace = true }
tracing = { workspace = true }

View File

@ -4,11 +4,13 @@
//! 401 + `error.code` で永続失敗を分類する。
use serde::{Deserialize, Serialize};
use std::time::Duration;
use super::error::{CodexAuthError, PermanentReason};
pub const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
pub const REFRESH_URL: &str = "https://auth.openai.com/oauth/token";
pub const DEFAULT_REFRESH_TIMEOUT: Duration = Duration::from_secs(30);
#[derive(Serialize)]
struct RefreshRequest<'a> {
@ -41,13 +43,15 @@ pub async fn request_refresh(
grant_type: "refresh_token",
refresh_token,
};
let response = client
.post(endpoint)
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| CodexAuthError::RefreshTransient(format!("send: {e}")))?;
let response = response_with_timeout(
client
.post(endpoint)
.header("Content-Type", "application/json")
.json(&body)
.send(),
DEFAULT_REFRESH_TIMEOUT,
)
.await?;
let status = response.status();
if status.is_success() {
@ -68,6 +72,21 @@ pub async fn request_refresh(
}
}
async fn response_with_timeout(
future: impl std::future::Future<Output = Result<reqwest::Response, reqwest::Error>>,
timeout: Duration,
) -> Result<reqwest::Response, CodexAuthError> {
tokio::time::timeout(timeout, future)
.await
.map_err(|_| {
CodexAuthError::RefreshTransient(format!(
"codex_oauth_refresh timed out after {}s",
timeout.as_secs()
))
})?
.map_err(|e| CodexAuthError::RefreshTransient(format!("send: {e}")))
}
fn classify_permanent(body: &str) -> (PermanentReason, String) {
let code = extract_error_code(body);
let reason = match code.as_deref() {
@ -107,6 +126,23 @@ fn extract_error_code(body: &str) -> Option<String> {
mod tests {
use super::*;
#[tokio::test]
async fn refresh_response_timeout_is_transient() {
let err = match response_with_timeout(
std::future::pending::<Result<reqwest::Response, reqwest::Error>>(),
Duration::from_millis(5),
)
.await
{
Ok(_) => panic!("expected refresh timeout"),
Err(err) => err,
};
assert!(
matches!(err, CodexAuthError::RefreshTransient(message) if message.contains("timed out"))
);
}
#[test]
fn classify_expired() {
let body = r#"{"error":{"code":"refresh_token_expired"}}"#;