fix: add llm request lifecycle timeouts
This commit is contained in:
parent
bdabe789e3
commit
1babd021b0
|
|
@ -12,7 +12,7 @@ thiserror = { workspace = true }
|
|||
tracing = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
|
||||
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time"] }
|
||||
tokio-util = "0.7"
|
||||
reqwest = { version = "0.13", default-features = false, features = ["stream", "json", "native-tls", "http2"] }
|
||||
eventsource-stream = "0.2"
|
||||
|
|
|
|||
|
|
@ -18,6 +18,11 @@ pub enum ClientError {
|
|||
message: String,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
/// A request lifecycle phase exceeded its hard timeout.
|
||||
Timeout {
|
||||
phase: &'static str,
|
||||
timeout: Duration,
|
||||
},
|
||||
/// 設定エラー
|
||||
Config(String),
|
||||
}
|
||||
|
|
@ -43,6 +48,9 @@ impl fmt::Display for ClientError {
|
|||
}
|
||||
write!(f, ": {}", message)
|
||||
}
|
||||
ClientError::Timeout { phase, timeout } => {
|
||||
write!(f, "{phase} timed out after {}s", timeout.as_secs())
|
||||
}
|
||||
ClientError::Config(msg) => write!(f, "Config error: {}", msg),
|
||||
}
|
||||
}
|
||||
|
|
@ -91,6 +99,7 @@ impl ClientError {
|
|||
/// 対象:
|
||||
/// - `Api { status }` のうち 408 / 425 / 429 / 500 / 502 / 503 / 504 / 529
|
||||
/// - `Http(reqwest::Error)` のうち `is_connect()` または `is_timeout()`
|
||||
/// - `Timeout { .. }` の lifecycle hard timeout
|
||||
///
|
||||
/// それ以外(Json、Sse、Config、上記以外の Api ステータス)は false。
|
||||
/// SSE 読み出し開始後の失敗は呼び出し側で `Sse` として上に流すため、
|
||||
|
|
@ -101,6 +110,7 @@ pub fn is_retryable(error: &ClientError) -> bool {
|
|||
status: Some(code), ..
|
||||
} => matches!(*code, 408 | 425 | 429 | 500 | 502 | 503 | 504 | 529),
|
||||
ClientError::Api { status: None, .. } => false,
|
||||
ClientError::Timeout { .. } => true,
|
||||
ClientError::Http(e) => e.is_connect() || e.is_timeout(),
|
||||
ClientError::Json(_) | ClientError::Sse(_) | ClientError::Config(_) => false,
|
||||
}
|
||||
|
|
@ -144,6 +154,14 @@ mod tests {
|
|||
assert!(!is_retryable(&api_err(None)));
|
||||
}
|
||||
|
||||
/// A lifecycle hard timeout must be classified as a retryable error.
#[test]
fn lifecycle_timeout_is_retryable() {
    let timeout_err = ClientError::Timeout {
        phase: "stream_open",
        timeout: Duration::from_secs(30),
    };

    assert!(is_retryable(&timeout_err));
}
|
||||
|
||||
#[test]
|
||||
fn json_sse_config_not_retryable() {
|
||||
let json_err = serde_json::from_str::<serde_json::Value>("not json").unwrap_err();
|
||||
|
|
|
|||
|
|
@ -23,6 +23,9 @@ use super::event::Event;
|
|||
use super::scheme::Scheme;
|
||||
use super::types::{Request, RequestConfig};
|
||||
|
||||
/// Hard limit for the "stream_open" phase: how long the transport waits for
/// `send()` to yield an HTTP response before giving up.
pub const DEFAULT_STREAM_OPEN_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
/// Hard limit the worker passes to `wait_for_first_stream_event` when waiting
/// for the first item of an opened response stream.
pub const DEFAULT_FIRST_STREAM_EVENT_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
/// `AuthRef` を解決したランタイム表現。`crates/provider` が構築する。
|
||||
///
|
||||
/// - `None`: 認証ヘッダを送らない(Ollama 等の opt-out)
|
||||
|
|
@ -201,6 +204,17 @@ enum RequestBody {
|
|||
CompressedJson(Vec<u8>),
|
||||
}
|
||||
|
||||
async fn response_with_timeout(
|
||||
future: impl std::future::Future<Output = Result<reqwest::Response, reqwest::Error>>,
|
||||
timeout: Duration,
|
||||
phase: &'static str,
|
||||
) -> Result<reqwest::Response, ClientError> {
|
||||
tokio::time::timeout(timeout, future)
|
||||
.await
|
||||
.map_err(|_| ClientError::Timeout { phase, timeout })?
|
||||
.map_err(ClientError::Http)
|
||||
}
|
||||
|
||||
impl<S: Scheme + Clone> Clone for HttpTransport<S> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
|
|
@ -272,7 +286,9 @@ impl<S: Scheme + Clone + 'static> LlmClient for HttpTransport<S> {
|
|||
RequestBody::Json(body) => builder.json(&body),
|
||||
RequestBody::CompressedJson(body) => builder.body(body),
|
||||
};
|
||||
let response = builder.send().await.map_err(ClientError::Http)?;
|
||||
let response =
|
||||
response_with_timeout(builder.send(), DEFAULT_STREAM_OPEN_TIMEOUT, "stream_open")
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(classify_error_response(response).await);
|
||||
|
|
@ -391,6 +407,26 @@ mod tests {
|
|||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn response_timeout_returns_retryable_lifecycle_timeout() {
|
||||
let err = response_with_timeout(
|
||||
std::future::pending::<Result<reqwest::Response, reqwest::Error>>(),
|
||||
Duration::from_millis(5),
|
||||
"stream_open",
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
assert!(crate::llm_client::error::is_retryable(&err));
|
||||
assert!(matches!(
|
||||
err,
|
||||
ClientError::Timeout {
|
||||
phase: "stream_open",
|
||||
..
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn codex_backend_adds_conversation_headers_and_zstd_body() {
|
||||
let transport = transport(ResolvedAuth::Custom(Arc::new(TestAuthProvider {
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ use crate::{
|
|||
llm_client::{
|
||||
ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ResponseStream,
|
||||
ToolDefinition, error::is_retryable, event::Event, retry::RetryPolicy,
|
||||
types::parse_tool_arguments,
|
||||
transport::DEFAULT_FIRST_STREAM_EVENT_TIMEOUT, types::parse_tool_arguments,
|
||||
},
|
||||
state::{Locked, Mutable, WorkerState},
|
||||
timeline::event::{ErrorEvent, StatusEvent, UsageEvent},
|
||||
|
|
@ -1334,7 +1334,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
};
|
||||
|
||||
match stream_result {
|
||||
let err = match stream_result {
|
||||
Ok(stream) => {
|
||||
self.emit_lifecycle_trace(
|
||||
turn,
|
||||
|
|
@ -1345,7 +1345,26 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
"elapsed_ms": stream_started.elapsed().as_millis() as u64,
|
||||
}),
|
||||
);
|
||||
return Ok(stream);
|
||||
match wait_for_first_stream_event(stream, DEFAULT_FIRST_STREAM_EVENT_TIMEOUT)
|
||||
.await
|
||||
{
|
||||
Ok(FirstStreamEvent::Ready(stream)) => return Ok(stream),
|
||||
Ok(FirstStreamEvent::Empty(stream)) => return Ok(stream),
|
||||
Err(err) => {
|
||||
self.emit_lifecycle_trace(
|
||||
turn,
|
||||
llm_call,
|
||||
"stream_first_event_error",
|
||||
json!({
|
||||
"attempt": attempt,
|
||||
"elapsed_ms": stream_started.elapsed().as_millis() as u64,
|
||||
"retryable": is_retryable(&err),
|
||||
"error": err.to_string(),
|
||||
}),
|
||||
);
|
||||
err
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
self.emit_lifecycle_trace(
|
||||
|
|
@ -1360,54 +1379,56 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
"error": err.to_string(),
|
||||
}),
|
||||
);
|
||||
let next_failed_attempt = failed_attempt + 1;
|
||||
if next_failed_attempt >= policy.max_attempts || !is_retryable(&err) {
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Client(err));
|
||||
}
|
||||
err
|
||||
}
|
||||
};
|
||||
|
||||
let wait = err
|
||||
.retry_after()
|
||||
.unwrap_or_else(|| policy.backoff(failed_attempt));
|
||||
let elapsed = started.elapsed();
|
||||
if elapsed + wait > policy.total_timeout {
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Client(err));
|
||||
}
|
||||
let next_failed_attempt = failed_attempt + 1;
|
||||
if next_failed_attempt >= policy.max_attempts || !is_retryable(&err) {
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Client(err));
|
||||
}
|
||||
|
||||
warn!(
|
||||
error = %err,
|
||||
failed_attempt = next_failed_attempt,
|
||||
wait_ms = wait.as_millis() as u64,
|
||||
"transient LLM request error, retrying"
|
||||
);
|
||||
let notice = LlmRetryNotice {
|
||||
failed_attempt: next_failed_attempt,
|
||||
max_attempts: policy.max_attempts,
|
||||
wait,
|
||||
elapsed,
|
||||
status: err.status(),
|
||||
error: err.to_string(),
|
||||
};
|
||||
for cb in &self.llm_retry_cbs {
|
||||
cb(llm_call, &notice);
|
||||
}
|
||||
let wait = err
|
||||
.retry_after()
|
||||
.unwrap_or_else(|| policy.backoff(failed_attempt));
|
||||
let elapsed = started.elapsed();
|
||||
if elapsed + wait > policy.total_timeout {
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Client(err));
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
_ = tokio::time::sleep(wait) => {}
|
||||
cancel = self.cancel_rx.recv() => {
|
||||
if cancel.is_some() {
|
||||
info!("Cancelled during LLM retry backoff");
|
||||
}
|
||||
self.timeline.abort_current_block();
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Cancelled);
|
||||
}
|
||||
}
|
||||
warn!(
|
||||
error = %err,
|
||||
failed_attempt = next_failed_attempt,
|
||||
wait_ms = wait.as_millis() as u64,
|
||||
"transient LLM request error, retrying"
|
||||
);
|
||||
let notice = LlmRetryNotice {
|
||||
failed_attempt: next_failed_attempt,
|
||||
max_attempts: policy.max_attempts,
|
||||
wait,
|
||||
elapsed,
|
||||
status: err.status(),
|
||||
error: err.to_string(),
|
||||
};
|
||||
for cb in &self.llm_retry_cbs {
|
||||
cb(llm_call, &notice);
|
||||
}
|
||||
|
||||
failed_attempt = next_failed_attempt;
|
||||
tokio::select! {
|
||||
_ = tokio::time::sleep(wait) => {}
|
||||
cancel = self.cancel_rx.recv() => {
|
||||
if cancel.is_some() {
|
||||
info!("Cancelled during LLM retry backoff");
|
||||
}
|
||||
self.timeline.abort_current_block();
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Cancelled);
|
||||
}
|
||||
}
|
||||
|
||||
failed_attempt = next_failed_attempt;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1932,6 +1953,29 @@ impl<C: LlmClient> Worker<C, Locked> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Result of probing a response stream for its first item, as produced by
/// `wait_for_first_stream_event`.
enum FirstStreamEvent {
|
||||
/// The first item arrived successfully; the wrapped stream yields that
/// item again before continuing with the rest of the original stream.
Ready(ResponseStream),
|
||||
/// The stream ended before yielding any item; the original stream is
/// handed back unchanged.
Empty(ResponseStream),
|
||||
}
|
||||
|
||||
async fn wait_for_first_stream_event(
|
||||
mut stream: ResponseStream,
|
||||
timeout: std::time::Duration,
|
||||
) -> Result<FirstStreamEvent, ClientError> {
|
||||
match tokio::time::timeout(timeout, stream.next()).await {
|
||||
Ok(Some(first)) => {
|
||||
let first = first?;
|
||||
let stream = futures::stream::once(async move { Ok(first) }).chain(stream);
|
||||
Ok(FirstStreamEvent::Ready(Box::pin(stream)))
|
||||
}
|
||||
Ok(None) => Ok(FirstStreamEvent::Empty(stream)),
|
||||
Err(_) => Err(ClientError::Timeout {
|
||||
phase: "stream_first_event",
|
||||
timeout,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn items_trace_payload(
|
||||
items: &[Item],
|
||||
tools_len: usize,
|
||||
|
|
@ -1990,5 +2034,46 @@ fn item_kind(item: &Item) -> &'static str {
|
|||
|
||||
#[cfg(test)]
mod tests {
    // Only self-contained unit tests live here; anything that needs an
    // LlmClient is exercised by the integration tests instead.
    use super::*;
    use std::time::Duration;

    /// A stream that never yields must produce a retryable
    /// `stream_first_event` timeout.
    #[tokio::test]
    async fn first_stream_event_timeout_returns_retryable_timeout() {
        let never_yields: ResponseStream = Box::pin(futures::stream::pending());
        let err = wait_for_first_stream_event(never_yields, Duration::from_millis(5))
            .await
            .err()
            .expect("expected first event timeout");

        assert!(is_retryable(&err));

        let tagged_first_event = matches!(
            err,
            ClientError::Timeout {
                phase: "stream_first_event",
                ..
            }
        );
        assert!(tagged_first_event);
    }

    /// The event consumed by the probe must be the first one the returned
    /// stream yields.
    #[tokio::test]
    async fn first_stream_event_is_replayed_after_probe() {
        let expected = Event::Status(crate::llm_client::event::StatusEvent {
            status: crate::llm_client::event::ResponseStatus::Started,
        });
        let queued = expected.clone();
        let source: ResponseStream =
            Box::pin(futures::stream::once(async move { Ok(queued) }));

        let probed = wait_for_first_stream_event(source, Duration::from_secs(1))
            .await
            .unwrap();
        let mut replay = match probed {
            FirstStreamEvent::Ready(stream) => stream,
            FirstStreamEvent::Empty(_) => panic!("expected first event to be buffered"),
        };

        let replayed = replay.next().await.unwrap().unwrap();
        assert_eq!(replayed, expected);
    }
}
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ reqwest = { version = "0.13", features = ["json", "native-tls"] }
|
|||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true, features = ["sync", "fs", "rt"] }
|
||||
tokio = { workspace = true, features = ["sync", "fs", "rt", "time"] }
|
||||
toml = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
|
|
|
|||
|
|
@ -4,11 +4,13 @@
|
|||
//! 401 + `error.code` で永続失敗を分類する。
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
use super::error::{CodexAuthError, PermanentReason};
|
||||
|
||||
pub const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
|
||||
pub const REFRESH_URL: &str = "https://auth.openai.com/oauth/token";
|
||||
/// Hard limit applied to the token-refresh HTTP request issued by
/// `request_refresh`.
pub const DEFAULT_REFRESH_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct RefreshRequest<'a> {
|
||||
|
|
@ -41,13 +43,15 @@ pub async fn request_refresh(
|
|||
grant_type: "refresh_token",
|
||||
refresh_token,
|
||||
};
|
||||
let response = client
|
||||
.post(endpoint)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| CodexAuthError::RefreshTransient(format!("send: {e}")))?;
|
||||
let response = response_with_timeout(
|
||||
client
|
||||
.post(endpoint)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send(),
|
||||
DEFAULT_REFRESH_TIMEOUT,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
|
|
@ -68,6 +72,21 @@ pub async fn request_refresh(
|
|||
}
|
||||
}
|
||||
|
||||
async fn response_with_timeout(
|
||||
future: impl std::future::Future<Output = Result<reqwest::Response, reqwest::Error>>,
|
||||
timeout: Duration,
|
||||
) -> Result<reqwest::Response, CodexAuthError> {
|
||||
tokio::time::timeout(timeout, future)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
CodexAuthError::RefreshTransient(format!(
|
||||
"codex_oauth_refresh timed out after {}s",
|
||||
timeout.as_secs()
|
||||
))
|
||||
})?
|
||||
.map_err(|e| CodexAuthError::RefreshTransient(format!("send: {e}")))
|
||||
}
|
||||
|
||||
fn classify_permanent(body: &str) -> (PermanentReason, String) {
|
||||
let code = extract_error_code(body);
|
||||
let reason = match code.as_deref() {
|
||||
|
|
@ -107,6 +126,23 @@ fn extract_error_code(body: &str) -> Option<String> {
|
|||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn refresh_response_timeout_is_transient() {
|
||||
let err = match response_with_timeout(
|
||||
std::future::pending::<Result<reqwest::Response, reqwest::Error>>(),
|
||||
Duration::from_millis(5),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => panic!("expected refresh timeout"),
|
||||
Err(err) => err,
|
||||
};
|
||||
|
||||
assert!(
|
||||
matches!(err, CodexAuthError::RefreshTransient(message) if message.contains("timed out"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_expired() {
|
||||
let body = r#"{"error":{"code":"refresh_token_expired"}}"#;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user