feat: surface llm retry and continuation state

This commit is contained in:
Keisuke Hirata 2026-05-26 07:13:59 +09:00
parent 3f750668ba
commit 3d3db8b6ac
No known key found for this signature in database
20 changed files with 626 additions and 245 deletions

View File

@ -59,4 +59,6 @@ pub use interceptor::Interceptor;
pub use message::{ContentPart, Item, Message, Role}; pub use message::{ContentPart, Item, Message, Role};
pub use tool::{ToolCall, ToolOutputLimits, ToolResult}; pub use tool::{ToolCall, ToolOutputLimits, ToolResult};
pub use usage_record::UsageRecord; pub use usage_record::UsageRecord;
pub use worker::{RunOutput, ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult}; pub use worker::{
LlmRetryNotice, RunOutput, ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult,
};

View File

@ -36,6 +36,8 @@ impl std::fmt::Display for ConfigWarning {
} }
} }
pub type ResponseStream = Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>;
/// LLMクライアントのtrait /// LLMクライアントのtrait
/// ///
/// 各プロバイダはこのtraitを実装し、統一されたインターフェースを提供する。 /// 各プロバイダはこのtraitを実装し、統一されたインターフェースを提供する。
@ -49,10 +51,7 @@ pub trait LlmClient: Send + Sync {
/// # Returns /// # Returns
/// * `Ok(Stream)` - イベントストリーム /// * `Ok(Stream)` - イベントストリーム
/// * `Err(ClientError)` - エラー /// * `Err(ClientError)` - エラー
async fn stream( async fn stream(&self, request: Request) -> Result<ResponseStream, ClientError>;
&self,
request: Request,
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError>;
/// Clone this client into a new `Box<dyn LlmClient>`. /// Clone this client into a new `Box<dyn LlmClient>`.
/// ///
@ -85,10 +84,7 @@ impl Clone for Box<dyn LlmClient> {
/// これにより、動的ディスパッチを使用するクライアントも `Worker` で利用可能になる。 /// これにより、動的ディスパッチを使用するクライアントも `Worker` で利用可能になる。
#[async_trait] #[async_trait]
impl LlmClient for Box<dyn LlmClient> { impl LlmClient for Box<dyn LlmClient> {
async fn stream( async fn stream(&self, request: Request) -> Result<ResponseStream, ClientError> {
&self,
request: Request,
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError> {
(**self).stream(request).await (**self).stream(request).await
} }

View File

@ -1,6 +1,6 @@
//! LLMクライアントエラー型 //! LLMクライアントエラー型
use std::fmt; use std::{fmt, time::Duration};
/// LLMクライアントのエラー /// LLMクライアントのエラー
#[derive(Debug)] #[derive(Debug)]
@ -16,6 +16,7 @@ pub enum ClientError {
status: Option<u16>, status: Option<u16>,
code: Option<String>, code: Option<String>,
message: String, message: String,
retry_after: Option<Duration>,
}, },
/// 設定エラー /// 設定エラー
Config(String), Config(String),
@ -31,6 +32,7 @@ impl fmt::Display for ClientError {
status, status,
code, code,
message, message,
..
} => { } => {
write!(f, "API error")?; write!(f, "API error")?;
if let Some(s) = status { if let Some(s) = status {
@ -68,6 +70,22 @@ impl From<serde_json::Error> for ClientError {
} }
} }
impl ClientError {
    /// HTTP status code carried by an `Api` error.
    ///
    /// Returns `None` for every other variant, and for `Api` errors that
    /// were built without a status.
    pub fn status(&self) -> Option<u16> {
        if let ClientError::Api { status, .. } = self {
            *status
        } else {
            None
        }
    }

    /// Retry delay carried by an `Api` error.
    ///
    /// Returns `None` for every other variant, and for `Api` errors that
    /// carry no `retry_after` value.
    pub fn retry_after(&self) -> Option<Duration> {
        if let ClientError::Api { retry_after, .. } = self {
            *retry_after
        } else {
            None
        }
    }
}
/// transient な失敗としてリトライ対象になるかを判定する。 /// transient な失敗としてリトライ対象になるかを判定する。
/// ///
/// 対象: /// 対象:
@ -97,6 +115,7 @@ mod tests {
status, status,
code: None, code: None,
message: String::new(), message: String::new(),
retry_after: None,
} }
} }

View File

@ -1,8 +1,8 @@
//! HTTP transient エラー向けリトライポリシー。 //! LLM response stream を開く前の transient error 向けリトライポリシー。
//! //!
//! `transport.rs` の HTTP 送信〜ステータスチェック区間で `is_retryable` //! Worker が `LlmClient::stream` の open error に対して `is_retryable` を見て
//! が true を返した失敗をリトライする際に、待ち時間と打ち切り条件を //! retry / backoff / TUI event / cancellation をまとめて管理する。
//! 提供する。SSE 読み出し開始後の失敗は対象外。 //! SSE 読み出し開始後の失敗は対象外。
use std::time::Duration; use std::time::Duration;

View File

@ -131,6 +131,7 @@ impl GeminiScheme {
status: None, status: None,
code: Some("parse_error".to_string()), code: Some("parse_error".to_string()),
message: format!("Failed to parse Gemini SSE data: {} -> {}", e, data), message: format!("Failed to parse Gemini SSE data: {} -> {}", e, data),
retry_after: None,
})?; })?;
let mut events = Vec::new(); let mut events = Vec::new();

View File

@ -75,6 +75,7 @@ impl OpenAIScheme {
status: None, status: None,
code: Some("parse_error".to_string()), code: Some("parse_error".to_string()),
message: format!("Failed to parse SSE data: {} -> {}", e, data), message: format!("Failed to parse SSE data: {} -> {}", e, data),
retry_after: None,
})?; })?;
let mut events = Vec::new(); let mut events = Vec::new();

View File

@ -597,6 +597,7 @@ fn from_json<T: for<'de> Deserialize<'de>>(data: &str) -> Result<T, ClientError>
status: None, status: None,
code: Some("parse_error".to_string()), code: Some("parse_error".to_string()),
message: format!("Failed to parse SSE data: {e}"), message: format!("Failed to parse SSE data: {e}"),
retry_after: None,
}) })
} }

View File

@ -12,15 +12,12 @@ use async_trait::async_trait;
use eventsource_stream::Eventsource; use eventsource_stream::Eventsource;
use futures::{Stream, StreamExt, TryStreamExt}; use futures::{Stream, StreamExt, TryStreamExt};
use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue, RETRY_AFTER}; use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue, RETRY_AFTER};
use tokio::time::Instant;
use tracing::warn;
use super::auth::{AuthProvider, AuthRequirement}; use super::auth::{AuthProvider, AuthRequirement};
use super::capability::ModelCapability; use super::capability::ModelCapability;
use super::client::{ConfigWarning, LlmClient}; use super::client::{ConfigWarning, LlmClient, ResponseStream};
use super::error::{ClientError, is_retryable}; use super::error::ClientError;
use super::event::Event; use super::event::Event;
use super::retry::RetryPolicy;
use super::scheme::Scheme; use super::scheme::Scheme;
use super::types::{Request, RequestConfig}; use super::types::{Request, RequestConfig};
@ -67,7 +64,6 @@ pub struct HttpTransport<S: Scheme> {
base_url: String, base_url: String,
auth: ResolvedAuth, auth: ResolvedAuth,
capability: ModelCapability, capability: ModelCapability,
retry_policy: RetryPolicy,
} }
impl<S: Scheme> HttpTransport<S> { impl<S: Scheme> HttpTransport<S> {
@ -89,7 +85,6 @@ impl<S: Scheme> HttpTransport<S> {
base_url, base_url,
auth, auth,
capability, capability,
retry_policy: RetryPolicy::default(),
} }
} }
@ -99,12 +94,6 @@ impl<S: Scheme> HttpTransport<S> {
self self
} }
/// リトライポリシーを差し替える(テスト用 / 将来の manifest 化フック)。
pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
self.retry_policy = policy;
self
}
fn build_url(&self) -> String { fn build_url(&self) -> String {
let path = self.scheme.path(&self.model_id); let path = self.scheme.path(&self.model_id);
let url = format!("{}{}", self.base_url, path); let url = format!("{}{}", self.base_url, path);
@ -171,14 +160,12 @@ impl<S: Scheme + Clone> Clone for HttpTransport<S> {
base_url: self.base_url.clone(), base_url: self.base_url.clone(),
auth: self.auth.clone(), auth: self.auth.clone(),
capability: self.capability.clone(), capability: self.capability.clone(),
retry_policy: self.retry_policy.clone(),
} }
} }
} }
/// エラーレスポンスを `ClientError::Api` に変換し、`Retry-After` の秒数を /// エラーレスポンスを `ClientError::Api` に変換する。
/// 同時に取り出す。リトライループで wait の上書きに使う。 async fn classify_error_response(resp: reqwest::Response) -> ClientError {
async fn classify_error_response(resp: reqwest::Response) -> (ClientError, Option<Duration>) {
let status = resp.status().as_u16(); let status = resp.status().as_u16();
let retry_after = resp let retry_after = resp
.headers() .headers()
@ -187,7 +174,7 @@ async fn classify_error_response(resp: reqwest::Response) -> (ClientError, Optio
.and_then(|s| s.trim().parse::<u64>().ok()) .and_then(|s| s.trim().parse::<u64>().ok())
.map(Duration::from_secs); .map(Duration::from_secs);
let text = resp.text().await.unwrap_or_default(); let text = resp.text().await.unwrap_or_default();
let err = if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) { if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
let error = json.get("error").unwrap_or(&json); let error = json.get("error").unwrap_or(&json);
let code = error.get("type").and_then(|v| v.as_str()).map(String::from); let code = error.get("type").and_then(|v| v.as_str()).map(String::from);
let message = error let message = error
@ -199,15 +186,16 @@ async fn classify_error_response(resp: reqwest::Response) -> (ClientError, Optio
status: Some(status), status: Some(status),
code, code,
message, message,
retry_after,
} }
} else { } else {
ClientError::Api { ClientError::Api {
status: Some(status), status: Some(status),
code: None, code: None,
message: text, message: text,
retry_after,
} }
}; }
(err, retry_after)
} }
#[async_trait] #[async_trait]
@ -220,51 +208,25 @@ impl<S: Scheme + Clone + 'static> LlmClient for HttpTransport<S> {
self.scheme.validate_config(config) self.scheme.validate_config(config)
} }
async fn stream( async fn stream(&self, request: Request) -> Result<ResponseStream, ClientError> {
&self,
request: Request,
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError> {
let url = self.build_url(); let url = self.build_url();
let headers = self.build_headers().await?; let headers = self.build_headers().await?;
let body = self let body = self
.scheme .scheme
.build_request_body(&self.model_id, &request, &self.capability); .build_request_body(&self.model_id, &request, &self.capability);
let policy = &self.retry_policy; let response = self
let started = Instant::now(); .http_client
let mut attempt: u32 = 0; .post(&url)
let response = loop { .headers(headers)
let send_result = self .json(&body)
.http_client .send()
.post(&url) .await
.headers(headers.clone()) .map_err(ClientError::Http)?;
.json(&body)
.send()
.await;
let (err, retry_after) = match send_result { if !response.status().is_success() {
Ok(resp) if resp.status().is_success() => break resp, return Err(classify_error_response(response).await);
Ok(resp) => classify_error_response(resp).await, }
Err(e) => (ClientError::Http(e), None),
};
let next_attempt = attempt + 1;
if next_attempt >= policy.max_attempts || !is_retryable(&err) {
return Err(err);
}
let wait = retry_after.unwrap_or_else(|| policy.backoff(attempt));
if started.elapsed() + wait > policy.total_timeout {
return Err(err);
}
warn!(
error = %err,
attempt = next_attempt,
wait_ms = wait.as_millis() as u64,
"transient HTTP error, retrying"
);
tokio::time::sleep(wait).await;
attempt = next_attempt;
};
let scheme = self.scheme.clone(); let scheme = self.scheme.clone();
let byte_stream = response.bytes_stream().map_err(std::io::Error::other); let byte_stream = response.bytes_stream().map_err(std::io::Error::other);

View File

@ -1,5 +1,5 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::marker::PhantomData; use std::{marker::PhantomData, time::Instant};
use futures::StreamExt; use futures::StreamExt;
use tokio::sync::mpsc; use tokio::sync::mpsc;
@ -17,8 +17,8 @@ use crate::{
PromptAction, ToolCallInfo, ToolResultInfo, TurnEndAction, PromptAction, ToolCallInfo, ToolResultInfo, TurnEndAction,
}, },
llm_client::{ llm_client::{
ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ToolDefinition, ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ResponseStream,
types::parse_tool_arguments, ToolDefinition, error::is_retryable, retry::RetryPolicy, types::parse_tool_arguments,
}, },
state::{Locked, Mutable, WorkerState}, state::{Locked, Mutable, WorkerState},
timeline::event::{ErrorEvent, StatusEvent, UsageEvent}, timeline::event::{ErrorEvent, StatusEvent, UsageEvent},
@ -99,6 +99,8 @@ enum ToolExecutionResult {
Paused, Paused,
} }
// Maximum number of consecutive continuation requests issued after an
// interrupted response stream. The run loop resets its counter once a stream
// completes, and fails the run with a client error when the counter exceeds
// this value.
const MAX_STREAM_CONTINUATIONS: u32 = 3;
/// Central component for managing LLM interactions /// Central component for managing LLM interactions
/// ///
/// Receives input from the user, sends requests to the LLM, and /// Receives input from the user, sends requests to the LLM, and
@ -131,9 +133,28 @@ enum ToolExecutionResult {
/// let out = worker.run("Continue").await?; /// let out = worker.run("Continue").await?;
/// let mut worker = out.worker; /// let mut worker = out.worker;
/// ``` /// ```
// NOTE(review): the `///` block directly above this item (the "Central
// component for managing LLM interactions" text with the `worker.run`
// example) reads as documentation for `Worker`, but as placed it attaches to
// this struct. Consider moving this struct above that comment — confirm
// against the full file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LlmRetryNotice {
/// Number of the attempt that failed most recently (1-origin).
pub failed_attempt: u32,
/// `max_attempts` of the retry policy in effect for this request.
pub max_attempts: u32,
/// Delay that will be waited before the next attempt is made.
pub wait: std::time::Duration,
/// Time elapsed since the first attempt started, measured before `wait`.
pub elapsed: std::time::Duration,
/// HTTP status of the failure, when the error carries one.
pub status: Option<u16>,
/// Display text of the error that caused the retry.
pub error: String,
}
/// Outcome of consuming one LLM response stream in `stream_response`.
#[derive(Debug)]
enum StreamCompletion {
/// The stream was consumed without yielding an error item.
Complete,
/// The stream yielded an error item; `reason` is that error's display text.
Interrupted { reason: String },
}
pub struct Worker<C: LlmClient, S: WorkerState = Mutable> { pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
/// LLM client /// LLM client
client: C, client: C,
/// Retry policy for opening an LLM response stream.
retry_policy: RetryPolicy,
/// Event timeline /// Event timeline
timeline: Timeline, timeline: Timeline,
/// Text block collector (Timeline handler) /// Text block collector (Timeline handler)
@ -175,6 +196,10 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
llm_call_start_cbs: Vec<Box<dyn Fn(usize) + Send + Sync>>, llm_call_start_cbs: Vec<Box<dyn Fn(usize) + Send + Sync>>,
/// LlmCall-end callbacks /// LlmCall-end callbacks
llm_call_end_cbs: Vec<Box<dyn Fn(usize) + Send + Sync>>, llm_call_end_cbs: Vec<Box<dyn Fn(usize) + Send + Sync>>,
/// Transport-level retry callbacks for a specific LlmCall.
llm_retry_cbs: Vec<Box<dyn Fn(usize, &LlmRetryNotice) + Send + Sync>>,
/// Stream continuation callbacks for a specific LlmCall.
llm_continuation_cbs: Vec<Box<dyn Fn(usize, u32, u32, &str) + Send + Sync>>,
/// Non-fatal warning callbacks. Invoked when the Worker wants to /// Non-fatal warning callbacks. Invoked when the Worker wants to
/// surface an advisory message to the upper layer (e.g. Pod) so it /// surface an advisory message to the upper layer (e.g. Pod) so it
/// can be forwarded to the user — distinct from `tracing::warn!`, /// can be forwarded to the user — distinct from `tracing::warn!`,
@ -355,6 +380,34 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
self.llm_call_end_cbs.push(Box::new(callback)); self.llm_call_end_cbs.push(Box::new(callback));
} }
/// Subscribe to retry notifications for LLM stream-open failures.
///
/// `callback` receives the LlmCall index and an [`LlmRetryNotice`]
/// describing the attempt that just failed. It is stored for the lifetime
/// of the worker.
pub fn on_llm_retry(
    &mut self,
    callback: impl Fn(usize, &LlmRetryNotice) + Send + Sync + 'static,
) {
    let handler: Box<dyn Fn(usize, &LlmRetryNotice) + Send + Sync> = Box::new(callback);
    self.llm_retry_cbs.push(handler);
}
/// Subscribe to stream-continuation notifications.
///
/// `callback` receives, in order: the LlmCall index, the continuation
/// attempt number, the maximum number of continuation attempts, and the
/// reason text of the interruption.
pub fn on_llm_continuation(
    &mut self,
    callback: impl Fn(usize, u32, u32, &str) + Send + Sync + 'static,
) {
    let handler: Box<dyn Fn(usize, u32, u32, &str) + Send + Sync> = Box::new(callback);
    self.llm_continuation_cbs.push(handler);
}
/// Invoke every registered continuation callback, in registration order,
/// with the given LlmCall index, attempt counters and interruption reason.
fn emit_llm_continuation(
    &self,
    llm_call: usize,
    attempt: u32,
    max_attempts: u32,
    reason: &str,
) {
    self.llm_continuation_cbs
        .iter()
        .for_each(|listener| listener(llm_call, attempt, max_attempts, reason));
}
/// Register a non-fatal warning callback. /// Register a non-fatal warning callback.
/// ///
/// The callback is invoked with a short human-readable message /// The callback is invoked with a short human-readable message
@ -964,6 +1017,8 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
} }
} }
let mut stream_continuations: u32 = 0;
let mut continuing_stream = false;
loop { loop {
if self.try_cancelled() { if self.try_cancelled() {
info!("Execution cancelled"); info!("Execution cancelled");
@ -973,9 +1028,11 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
} }
let current_turn = self.turn_count; let current_turn = self.turn_count;
debug!(turn = current_turn, "Turn start"); if !continuing_stream {
for cb in &self.turn_start_cbs { debug!(turn = current_turn, "Turn start");
cb(current_turn); for cb in &self.turn_start_cbs {
cb(current_turn);
}
} }
// Drain interceptor-side inputs that are meant to land in // Drain interceptor-side inputs that are meant to land in
@ -1080,13 +1137,50 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
// Stream LLM response // Stream LLM response
let request = self.build_request(&tool_definitions, &request_context); let request = self.build_request(&tool_definitions, &request_context);
self.stream_response(request).await?; let stream_outcome = self.stream_response(request, current_llm_call).await?;
for cb in &self.llm_call_end_cbs { for cb in &self.llm_call_end_cbs {
cb(current_llm_call); cb(current_llm_call);
} }
self.llm_call_count += 1; self.llm_call_count += 1;
if let StreamCompletion::Interrupted { reason } = stream_outcome {
stream_continuations += 1;
if stream_continuations > MAX_STREAM_CONTINUATIONS {
self.last_run_interrupted = true;
return Err(WorkerError::Client(ClientError::Api {
status: None,
code: None,
message: format!("LLM stream interrupted too many times: {reason}"),
retry_after: None,
}));
}
self.timeline.abort_current_block();
self.timeline.flush_usage();
let reasoning_items = self.reasoning_item_collector.take_collected();
let text_blocks = self.text_block_collector.take_collected();
// Do not recover tool calls from an interrupted stream. A completed
// tool_use is executable only when the provider finishes the stream.
let _dropped_tool_calls = self.tool_call_collector.take_collected();
let assistant_items =
self.build_assistant_items(&reasoning_items, &text_blocks, &[]);
if !assistant_items.is_empty() {
self.append_history_items(assistant_items);
}
self.emit_llm_continuation(
current_llm_call,
stream_continuations,
MAX_STREAM_CONTINUATIONS,
&reason,
);
continuing_stream = true;
continue;
}
stream_continuations = 0;
continuing_stream = false;
for cb in &self.turn_end_cbs { for cb in &self.turn_end_cbs {
cb(current_turn); cb(current_turn);
} }
@ -1138,8 +1232,88 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
} }
} }
/// Open the LLM response stream, retrying transient open failures.
///
/// Each attempt races `self.client.stream(..)` against the cancellation
/// channel. When an attempt fails with an error that `is_retryable`
/// accepts, the method waits (the error's `retry_after()` if present,
/// otherwise `policy.backoff(..)`), notifies every `llm_retry_cbs`
/// callback, and tries again. The wait is also raced against cancellation.
///
/// # Arguments
/// * `request` - request to send; cloned for every attempt.
/// * `llm_call` - LlmCall index forwarded to the retry callbacks.
///
/// # Errors
/// * `WorkerError::Cancelled` - the cancellation channel fired (or was
///   closed) while opening the stream or while waiting to retry.
/// * `WorkerError::Client` - the error is not retryable, the attempt limit
///   was reached, or the next wait would exceed `policy.total_timeout`.
///
/// Every error path sets `last_run_interrupted`.
async fn open_stream_with_retry(
    &mut self,
    request: Request,
    llm_call: usize,
) -> Result<ResponseStream, WorkerError> {
    let policy = self.retry_policy.clone();
    let started = Instant::now();
    let mut failed_attempt: u32 = 0;

    loop {
        let stream_result = tokio::select! {
            stream_result = self.client.stream(request.clone()) => stream_result,
            cancel = self.cancel_rx.recv() => {
                if cancel.is_some() {
                    info!("Cancelled before stream started");
                }
                self.timeline.abort_current_block();
                self.last_run_interrupted = true;
                return Err(WorkerError::Cancelled);
            }
        };

        let err = match stream_result {
            Ok(stream) => return Ok(stream),
            Err(err) => err,
        };

        let next_failed_attempt = failed_attempt + 1;
        if next_failed_attempt >= policy.max_attempts || !is_retryable(&err) {
            self.last_run_interrupted = true;
            return Err(WorkerError::Client(err));
        }

        let wait = err
            .retry_after()
            .unwrap_or_else(|| policy.backoff(failed_attempt));
        let elapsed = started.elapsed();
        // `wait` may come from a server-supplied Retry-After value, so it can
        // be arbitrarily large. `Duration`'s `+` panics on overflow; saturate
        // instead so an absurd value simply exceeds the budget.
        if elapsed.saturating_add(wait) > policy.total_timeout {
            self.last_run_interrupted = true;
            return Err(WorkerError::Client(err));
        }

        // `as_millis()` yields u128; saturate rather than truncate for the log field.
        let wait_ms = u64::try_from(wait.as_millis()).unwrap_or(u64::MAX);
        warn!(
            error = %err,
            failed_attempt = next_failed_attempt,
            wait_ms = wait_ms,
            "transient LLM request error, retrying"
        );

        let notice = LlmRetryNotice {
            failed_attempt: next_failed_attempt,
            max_attempts: policy.max_attempts,
            wait,
            elapsed,
            status: err.status(),
            error: err.to_string(),
        };
        for cb in &self.llm_retry_cbs {
            cb(llm_call, &notice);
        }

        tokio::select! {
            _ = tokio::time::sleep(wait) => {}
            cancel = self.cancel_rx.recv() => {
                if cancel.is_some() {
                    info!("Cancelled during LLM retry backoff");
                }
                self.timeline.abort_current_block();
                self.last_run_interrupted = true;
                return Err(WorkerError::Cancelled);
            }
        }

        failed_attempt = next_failed_attempt;
    }
}
/// Open a stream, dispatch all events to the timeline, handle cancellation. /// Open a stream, dispatch all events to the timeline, handle cancellation.
async fn stream_response(&mut self, request: Request) -> Result<(), WorkerError> { async fn stream_response(
&mut self,
request: Request,
llm_call: usize,
) -> Result<StreamCompletion, WorkerError> {
debug!( debug!(
item_count = request.items.len(), item_count = request.items.len(),
tool_count = request.tools.len(), tool_count = request.tools.len(),
@ -1147,18 +1321,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
"Sending request to LLM" "Sending request to LLM"
); );
let mut stream = tokio::select! { let mut stream = self.open_stream_with_retry(request, llm_call).await?;
stream_result = self.client.stream(request) => stream_result
.inspect_err(|_| self.last_run_interrupted = true)?,
cancel = self.cancel_rx.recv() => {
if cancel.is_some() {
info!("Cancelled before stream started");
}
self.timeline.abort_current_block();
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
}
};
let mut event_count: usize = 0; let mut event_count: usize = 0;
loop { loop {
@ -1175,12 +1338,17 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
warn!(error = %e, "Stream error"); warn!(error = %e, "Stream error");
} }
} }
let event = result let event = match result {
.inspect_err(|_| { Ok(event) => event,
Err(err) => {
self.last_run_interrupted = true; self.last_run_interrupted = true;
// 部分情報でも発火しておく(料金会計用) // 部分情報でも発火しておく(料金会計用)
self.timeline.flush_usage(); self.timeline.flush_usage();
})?; return Ok(StreamCompletion::Interrupted {
reason: err.to_string(),
});
}
};
self.timeline.dispatch(&event); self.timeline.dispatch(&event);
} }
None => break, None => break,
@ -1200,7 +1368,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
// ストリーム完了時に集約済み Usage を 1 度だけ発火 // ストリーム完了時に集約済み Usage を 1 度だけ発火
self.timeline.flush_usage(); self.timeline.flush_usage();
debug!(event_count = event_count, "Stream completed"); debug!(event_count = event_count, "Stream completed");
Ok(()) Ok(StreamCompletion::Complete)
} }
/// Execute tools and push results to history. /// Execute tools and push results to history.
@ -1254,6 +1422,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
Self { Self {
client, client,
retry_policy: RetryPolicy::default(),
timeline, timeline,
text_block_collector, text_block_collector,
tool_call_collector, tool_call_collector,
@ -1270,6 +1439,8 @@ impl<C: LlmClient> Worker<C, Mutable> {
turn_end_cbs: Vec::new(), turn_end_cbs: Vec::new(),
llm_call_start_cbs: Vec::new(), llm_call_start_cbs: Vec::new(),
llm_call_end_cbs: Vec::new(), llm_call_end_cbs: Vec::new(),
llm_retry_cbs: Vec::new(),
llm_continuation_cbs: Vec::new(),
warning_cbs: Vec::new(), warning_cbs: Vec::new(),
tool_result_cbs: Vec::new(), tool_result_cbs: Vec::new(),
history_append_cbs: Vec::new(), history_append_cbs: Vec::new(),
@ -1385,6 +1556,12 @@ impl<C: LlmClient> Worker<C, Mutable> {
self self
} }
/// Builder-style setter for the retry policy applied when opening an LLM
/// response stream. Replaces the policy currently stored on the worker and
/// returns the worker.
pub fn with_retry_policy(self, retry_policy: RetryPolicy) -> Self {
    let mut worker = self;
    worker.retry_policy = retry_policy;
    worker
}
/// Validate current configuration against the provider /// Validate current configuration against the provider
/// ///
/// Returns an error if there are unsupported settings. /// Returns an error if there are unsupported settings.
@ -1507,6 +1684,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
let locked_prefix_len = self.history.len(); let locked_prefix_len = self.history.len();
Worker { Worker {
client: self.client, client: self.client,
retry_policy: self.retry_policy,
timeline: self.timeline, timeline: self.timeline,
text_block_collector: self.text_block_collector, text_block_collector: self.text_block_collector,
tool_call_collector: self.tool_call_collector, tool_call_collector: self.tool_call_collector,
@ -1523,6 +1701,8 @@ impl<C: LlmClient> Worker<C, Mutable> {
turn_end_cbs: self.turn_end_cbs, turn_end_cbs: self.turn_end_cbs,
llm_call_start_cbs: self.llm_call_start_cbs, llm_call_start_cbs: self.llm_call_start_cbs,
llm_call_end_cbs: self.llm_call_end_cbs, llm_call_end_cbs: self.llm_call_end_cbs,
llm_retry_cbs: self.llm_retry_cbs,
llm_continuation_cbs: self.llm_continuation_cbs,
warning_cbs: self.warning_cbs, warning_cbs: self.warning_cbs,
tool_result_cbs: self.tool_result_cbs, tool_result_cbs: self.tool_result_cbs,
history_append_cbs: self.history_append_cbs, history_append_cbs: self.history_append_cbs,
@ -1594,6 +1774,7 @@ impl<C: LlmClient> Worker<C, Locked> {
pub fn unlock(self) -> Worker<C, Mutable> { pub fn unlock(self) -> Worker<C, Mutable> {
Worker { Worker {
client: self.client, client: self.client,
retry_policy: self.retry_policy,
timeline: self.timeline, timeline: self.timeline,
text_block_collector: self.text_block_collector, text_block_collector: self.text_block_collector,
tool_call_collector: self.tool_call_collector, tool_call_collector: self.tool_call_collector,
@ -1610,6 +1791,8 @@ impl<C: LlmClient> Worker<C, Locked> {
turn_end_cbs: self.turn_end_cbs, turn_end_cbs: self.turn_end_cbs,
llm_call_start_cbs: self.llm_call_start_cbs, llm_call_start_cbs: self.llm_call_start_cbs,
llm_call_end_cbs: self.llm_call_end_cbs, llm_call_end_cbs: self.llm_call_end_cbs,
llm_retry_cbs: self.llm_retry_cbs,
llm_continuation_cbs: self.llm_continuation_cbs,
warning_cbs: self.warning_cbs, warning_cbs: self.warning_cbs,
tool_result_cbs: self.tool_result_cbs, tool_result_cbs: self.tool_result_cbs,
history_append_cbs: self.history_append_cbs, history_append_cbs: self.history_append_cbs,

View File

@ -4,17 +4,77 @@
mod common; mod common;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::time::Duration;
use async_trait::async_trait; use async_trait::async_trait;
use common::MockLlmClient; use common::MockLlmClient;
use llm_worker::Worker; use llm_worker::Worker;
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent as ClientStatusEvent}; use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent as ClientStatusEvent};
use llm_worker::llm_client::retry::RetryPolicy;
use llm_worker::llm_client::{ClientError, LlmClient, Request, ResponseStream};
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta, ToolOutput}; use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta, ToolOutput};
// ============================================================================= #[derive(Clone)]
// Tests struct FailOnceClient {
// ============================================================================= calls: Arc<AtomicUsize>,
events: Vec<Event>,
}
#[async_trait]
impl LlmClient for FailOnceClient {
    /// Fails the first call with an `Api` error (status 504) and, on every
    /// later call, replays the configured events as a successful stream.
    async fn stream(&self, _request: Request) -> Result<ResponseStream, ClientError> {
        // `fetch_add` returns the counter value from before the increment.
        let earlier_calls = self.calls.fetch_add(1, Ordering::SeqCst);
        if earlier_calls > 0 {
            let replay = self.events.clone();
            return Ok(Box::pin(futures::stream::iter(
                replay.into_iter().map(Ok),
            )));
        }

        Err(ClientError::Api {
            status: Some(504),
            code: None,
            message: "gateway timeout".into(),
            retry_after: None,
        })
    }

    /// Boxes a clone of this client; the clone shares the same call counter.
    fn clone_boxed(&self) -> Box<dyn LlmClient> {
        Box::new(self.clone())
    }
}
/// A single retryable open failure must produce exactly one retry notice
/// and still let the run succeed.
#[tokio::test]
async fn test_callback_llm_retry_event() {
    let client = FailOnceClient {
        calls: Arc::new(AtomicUsize::new(0)),
        events: vec![Event::Status(ClientStatusEvent {
            status: ResponseStatus::Completed,
        })],
    };
    // Millisecond-scale delays keep the test fast.
    let policy = RetryPolicy {
        base: Duration::from_millis(1),
        cap: Duration::from_millis(1),
        max_attempts: 2,
        total_timeout: Duration::from_secs(1),
    };
    let mut worker = Worker::new(client).with_retry_policy(policy);

    let recorded = Arc::new(Mutex::new(Vec::new()));
    let recorder = Arc::clone(&recorded);
    worker.on_llm_retry(move |llm_call, notice| {
        recorder.lock().unwrap().push((llm_call, notice.clone()));
    });

    let outcome = worker.run("retry once").await;
    assert!(outcome.is_ok(), "worker should succeed after one retry");

    let recorded = recorded.lock().unwrap();
    assert_eq!(recorded.len(), 1);
    let (llm_call, notice) = &recorded[0];
    assert_eq!(*llm_call, 0);
    assert_eq!(notice.failed_attempt, 1);
    assert_eq!(notice.max_attempts, 2);
    assert_eq!(notice.status, Some(504));
}
/// Verify that on_text_block correctly receives delta and stop events /// Verify that on_text_block correctly receives delta and stop events
#[tokio::test] #[tokio::test]

View File

@ -59,6 +59,7 @@ impl LlmClient for MockLlmClient {
status: Some(500), status: Some(500),
code: Some("mock_error".to_string()), code: Some("mock_error".to_string()),
message: "No more mock responses".to_string(), message: "No more mock responses".to_string(),
retry_after: None,
}); });
} }
let events = self.responses[count].clone(); let events = self.responses[count].clone();

View File

@ -1,12 +1,7 @@
//! HTTP transport の transient エラーリトライ挙動の integration テスト。 //! HTTP transport の単発 request / error classification テスト。
//! //!
//! 対応チケット: `tickets/llm-worker-transient-retry.md`。 //! Retry/backoff は Worker の lifecycle 管理に属するため、transport は 1 回だけ
//! - 503 / 529 / connect refused でリトライ発火 //! request を送り、HTTP status / Retry-After を `ClientError` に載せて返す。
//! - max_attempts 上限到達でエラー
//! - `Retry-After` ヘッダで指数バックオフを上書き
//! - `parse_sse` 由来の `ClientError::Sse`mid-stream 想定)はリトライしない
use std::time::{Duration, Instant};
use futures::StreamExt; use futures::StreamExt;
use llm_worker::llm_client::LlmClient; use llm_worker::llm_client::LlmClient;
@ -14,16 +9,16 @@ use llm_worker::llm_client::auth::AuthRequirement;
use llm_worker::llm_client::capability::ModelCapability; use llm_worker::llm_client::capability::ModelCapability;
use llm_worker::llm_client::error::ClientError; use llm_worker::llm_client::error::ClientError;
use llm_worker::llm_client::event::Event; use llm_worker::llm_client::event::Event;
use llm_worker::llm_client::retry::RetryPolicy;
use llm_worker::llm_client::scheme::Scheme; use llm_worker::llm_client::scheme::Scheme;
use llm_worker::llm_client::transport::{HttpTransport, ResolvedAuth}; use llm_worker::llm_client::transport::{HttpTransport, ResolvedAuth};
use llm_worker::llm_client::types::Request; use llm_worker::llm_client::types::Request;
use serde_json::Value; use serde_json::Value;
use std::time::Duration;
use wiremock::matchers::{method, path}; use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate}; use wiremock::{Mock, MockServer, ResponseTemplate};
/// SSE 本体は触らないテスト用 scheme。`parse_fail` を立てると /// SSE 本体は触らないテスト用 scheme。`parse_fail` を立てると
/// stream 消費中= retry loop の外)で `ClientError::Sse` を返す。 /// stream 消費中で `ClientError::Sse` を返す。
#[derive(Clone)] #[derive(Clone)]
struct DummyScheme { struct DummyScheme {
parse_fail: bool, parse_fail: bool,
@ -31,18 +26,23 @@ struct DummyScheme {
impl Scheme for DummyScheme { impl Scheme for DummyScheme {
type State = (); type State = ();
fn default_base_url(&self) -> &'static str { fn default_base_url(&self) -> &'static str {
"" ""
} }
fn path(&self, _: &str) -> String { fn path(&self, _: &str) -> String {
"/v1/chat".into() "/v1/chat".into()
} }
fn required_auth(&self) -> AuthRequirement { fn required_auth(&self) -> AuthRequirement {
AuthRequirement::None AuthRequirement::None
} }
fn build_request_body(&self, _: &str, _: &Request, _: &ModelCapability) -> Value { fn build_request_body(&self, _: &str, _: &Request, _: &ModelCapability) -> Value {
serde_json::json!({}) serde_json::json!({})
} }
fn parse_sse(&self, _: &str, _: &str, _: &mut ()) -> Result<Vec<Event>, ClientError> { fn parse_sse(&self, _: &str, _: &str, _: &mut ()) -> Result<Vec<Event>, ClientError> {
if self.parse_fail { if self.parse_fail {
Err(ClientError::Sse( Err(ClientError::Sse(
@ -52,25 +52,13 @@ impl Scheme for DummyScheme {
Ok(vec![]) Ok(vec![])
} }
} }
fn default_capability(&self) -> ModelCapability { fn default_capability(&self) -> ModelCapability {
ModelCapability::minimal() ModelCapability::minimal()
} }
} }
fn fast_policy(max_attempts: u32) -> RetryPolicy { fn build_transport(base_url: impl Into<String>, parse_fail: bool) -> HttpTransport<DummyScheme> {
RetryPolicy {
base: Duration::from_millis(1),
cap: Duration::from_millis(1),
max_attempts,
total_timeout: Duration::from_secs(60),
}
}
fn build_transport(
base_url: impl Into<String>,
parse_fail: bool,
policy: RetryPolicy,
) -> HttpTransport<DummyScheme> {
HttpTransport::new( HttpTransport::new(
DummyScheme { parse_fail }, DummyScheme { parse_fail },
"test-model", "test-model",
@ -78,7 +66,6 @@ fn build_transport(
ResolvedAuth::None, ResolvedAuth::None,
ModelCapability::minimal(), ModelCapability::minimal(),
) )
.with_retry_policy(policy)
} }
fn ok_sse() -> ResponseTemplate { fn ok_sse() -> ResponseTemplate {
@ -88,78 +75,11 @@ fn ok_sse() -> ResponseTemplate {
} }
#[tokio::test] #[tokio::test]
async fn retries_503_then_succeeds() { async fn retryable_status_returns_api_error_without_retrying() {
let server = MockServer::start().await; let server = MockServer::start().await;
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/v1/chat")) .and(path("/v1/chat"))
.respond_with(ResponseTemplate::new(503).set_body_string("upstream connect error")) .respond_with(ResponseTemplate::new(503).set_body_string("upstream connect error"))
.up_to_n_times(2)
.mount(&server)
.await;
Mock::given(method("POST"))
.and(path("/v1/chat"))
.respond_with(ok_sse())
.mount(&server)
.await;
let transport = build_transport(server.uri(), false, fast_policy(5));
let mut stream = transport
.stream(Request::default())
.await
.expect("stream should succeed after retries");
while stream.next().await.is_some() {}
let received = server.received_requests().await.unwrap();
assert_eq!(received.len(), 3, "two failures plus one success expected");
}
#[tokio::test]
async fn retries_529_then_exhausts() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/chat"))
.respond_with(ResponseTemplate::new(529).set_body_string("overloaded"))
.mount(&server)
.await;
let transport = build_transport(server.uri(), false, fast_policy(3));
match transport.stream(Request::default()).await {
Err(ClientError::Api {
status: Some(529), ..
}) => {}
Err(other) => panic!("expected Api(529), got {other:?}"),
Ok(_) => panic!("expected error after exhausting retries"),
}
let received = server.received_requests().await.unwrap();
assert_eq!(received.len(), 3, "should hit max_attempts and stop");
}
#[tokio::test]
async fn connect_refused_retries_then_fails() {
// 接続不能なローカルアドレスを使う。Linux では `Connection refused` で
// 即時失敗するため、`fast_policy` ならテストが秒以下で終わる。
let unreachable = "http://127.0.0.1:1";
let transport = build_transport(unreachable, false, fast_policy(3));
match transport.stream(Request::default()).await {
Err(ClientError::Http(e)) => {
assert!(
e.is_connect() || e.is_timeout(),
"expected connect/timeout, got {e:?}"
);
}
Err(other) => panic!("expected Http error, got {other:?}"),
Ok(_) => panic!("expected error connecting to closed port"),
}
}
#[tokio::test]
async fn retry_after_header_overrides_backoff() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/chat"))
.respond_with(ResponseTemplate::new(503).insert_header("retry-after", "1"))
.up_to_n_times(1) .up_to_n_times(1)
.mount(&server) .mount(&server)
.await; .await;
@ -169,34 +89,48 @@ async fn retry_after_header_overrides_backoff() {
.mount(&server) .mount(&server)
.await; .await;
// base/cap を 1ms に絞った policy で `Retry-After: 1` を観察すると、 let transport = build_transport(server.uri(), false);
// 指数バックオフ単独なら 1ms 程度で終わるはずが Retry-After に従って match transport.stream(Request::default()).await {
// 1 秒待つ → 経過時間で override を検証できる。 Err(ClientError::Api {
let policy = RetryPolicy { status: Some(503), ..
base: Duration::from_millis(1), }) => {}
cap: Duration::from_millis(1), Err(other) => panic!("expected Api(503), got {other:?}"),
max_attempts: 3, Ok(_) => panic!("transport must not retry internally"),
total_timeout: Duration::from_secs(10), }
};
let transport = build_transport(server.uri(), false, policy);
let start = Instant::now(); let received = server.received_requests().await.unwrap();
let mut stream = transport.stream(Request::default()).await.expect("ok"); assert_eq!(
while stream.next().await.is_some() {} received.len(),
let elapsed = start.elapsed(); 1,
"transport should send exactly one request"
assert!(
elapsed >= Duration::from_secs(1),
"Retry-After=1 should make us wait >=1s, elapsed={elapsed:?}"
);
assert!(
elapsed < Duration::from_secs(3),
"Retry-After=1 should not balloon, elapsed={elapsed:?}"
); );
} }
#[tokio::test] #[tokio::test]
async fn mid_stream_sse_error_does_not_retry() { async fn retry_after_header_is_preserved_on_api_error() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/chat"))
.respond_with(ResponseTemplate::new(503).insert_header("retry-after", "1"))
.mount(&server)
.await;
let transport = build_transport(server.uri(), false);
match transport.stream(Request::default()).await {
Err(
err @ ClientError::Api {
status: Some(503), ..
},
) => {
assert_eq!(err.retry_after(), Some(Duration::from_secs(1)));
}
Err(other) => panic!("expected Api(503), got {other:?}"),
Ok(_) => panic!("expected error"),
}
}
#[tokio::test]
async fn mid_stream_sse_error_is_stream_item_error() {
let server = MockServer::start().await; let server = MockServer::start().await;
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/v1/chat")) .and(path("/v1/chat"))
@ -211,11 +145,11 @@ async fn mid_stream_sse_error_does_not_retry() {
.mount(&server) .mount(&server)
.await; .await;
let transport = build_transport(server.uri(), true, fast_policy(5)); let transport = build_transport(server.uri(), true);
let mut stream = transport let mut stream = transport
.stream(Request::default()) .stream(Request::default())
.await .await
.expect("status 200 should bypass retry loop"); .expect("status 200 should open stream");
let mut saw_sse_err = false; let mut saw_sse_err = false;
while let Some(item) = stream.next().await { while let Some(item) = stream.next().await {
if matches!(item, Err(ClientError::Sse(_))) { if matches!(item, Err(ClientError::Sse(_))) {
@ -225,11 +159,11 @@ async fn mid_stream_sse_error_does_not_retry() {
assert!(saw_sse_err, "expected Sse error from stream consumer"); assert!(saw_sse_err, "expected Sse error from stream consumer");
let received = server.received_requests().await.unwrap(); let received = server.received_requests().await.unwrap();
assert_eq!(received.len(), 1, "mid-stream Sse must not retry"); assert_eq!(received.len(), 1, "mid-stream Sse must not reopen stream");
} }
#[tokio::test] #[tokio::test]
async fn non_retryable_status_returns_immediately() { async fn non_retryable_status_returns_api_error() {
let server = MockServer::start().await; let server = MockServer::start().await;
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/v1/chat")) .and(path("/v1/chat"))
@ -237,7 +171,7 @@ async fn non_retryable_status_returns_immediately() {
.mount(&server) .mount(&server)
.await; .await;
let transport = build_transport(server.uri(), false, fast_policy(5)); let transport = build_transport(server.uri(), false);
match transport.stream(Request::default()).await { match transport.stream(Request::default()).await {
Err(ClientError::Api { Err(ClientError::Api {
status: Some(401), .. status: Some(401), ..
@ -247,5 +181,5 @@ async fn non_retryable_status_returns_immediately() {
} }
let received = server.received_requests().await.unwrap(); let received = server.received_requests().await.unwrap();
assert_eq!(received.len(), 1, "401 must not retry"); assert_eq!(received.len(), 1);
} }

View File

@ -330,6 +330,29 @@ fn wire_event_bridges_on_worker<C, St>(
let _ = tx.send(Event::LlmCallEnd { llm_call }); let _ = tx.send(Event::LlmCallEnd { llm_call });
}); });
let tx = event_tx.clone();
worker.on_llm_retry(move |llm_call, notice| {
let _ = tx.send(Event::LlmRetry {
llm_call,
failed_attempt: notice.failed_attempt,
max_attempts: notice.max_attempts,
wait_ms: notice.wait.as_millis() as u64,
elapsed_ms: notice.elapsed.as_millis() as u64,
status: notice.status,
error: notice.error.clone(),
});
});
let tx = event_tx.clone();
worker.on_llm_continuation(move |llm_call, attempt, max_attempts, reason| {
let _ = tx.send(Event::LlmContinuation {
llm_call,
attempt,
max_attempts,
reason: reason.to_owned(),
});
});
let tx = event_tx.clone(); let tx = event_tx.clone();
let activity = ai_activity.clone(); let activity = ai_activity.clone();
worker.on_text_block(move |block| { worker.on_text_block(move |block| {

View File

@ -101,6 +101,7 @@ impl LlmClient for MockClient {
status: Some(500), status: Some(500),
code: Some("mock".into()), code: Some("mock".into()),
message: "No more responses".into(), message: "No more responses".into(),
retry_after: None,
}); });
} }
let response = self.responses[count].clone(); let response = self.responses[count].clone();

View File

@ -298,6 +298,29 @@ pub enum Event {
LlmCallEnd { LlmCallEnd {
llm_call: usize, llm_call: usize,
}, },
/// A transport-level LLM request retry has been scheduled.
///
/// This is operational state for clients to render while the worker is
/// waiting in backoff. It is not part of conversation history.
LlmRetry {
llm_call: usize,
/// The attempt that just failed. 1 origin.
failed_attempt: u32,
max_attempts: u32,
wait_ms: u64,
elapsed_ms: u64,
#[serde(default, skip_serializing_if = "Option::is_none")]
status: Option<u16>,
error: String,
},
/// Stream generation was interrupted after events had begun and the worker
/// is continuing with a follow-up LLM request.
LlmContinuation {
llm_call: usize,
attempt: u32,
max_attempts: u32,
reason: String,
},
TextDelta { TextDelta {
text: String, text: String,
}, },
@ -867,6 +890,69 @@ mod tests {
assert_eq!(parsed["data"]["llm_call"], 3); assert_eq!(parsed["data"]["llm_call"], 3);
} }
/// `Event::LlmRetry` must survive a JSON round trip with every field intact.
///
/// Also covers the optional `status` field: it is declared with
/// `skip_serializing_if = "Option::is_none"` and `default`, so a retry caused
/// by a non-HTTP failure must omit the key on the wire and still decode back
/// to `None`.
#[test]
fn event_llm_retry_roundtrip() {
    let event = Event::LlmRetry {
        llm_call: 3,
        failed_attempt: 1,
        max_attempts: 4,
        wait_ms: 800,
        elapsed_ms: 120,
        status: Some(504),
        error: "API error (status: 504): gateway timeout".into(),
    };
    let json = serde_json::to_string(&event).unwrap();
    let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
    assert_eq!(parsed["event"], "llm_retry");
    assert_eq!(parsed["data"]["status"], 504);

    let decoded: Event = serde_json::from_str(&json).unwrap();
    match decoded {
        // Bind every field explicitly so a newly added or silently dropped
        // field cannot slip through behind a `..` pattern.
        Event::LlmRetry {
            llm_call,
            failed_attempt,
            max_attempts,
            wait_ms,
            elapsed_ms,
            status,
            error,
        } => {
            assert_eq!(llm_call, 3);
            assert_eq!(failed_attempt, 1);
            assert_eq!(max_attempts, 4);
            assert_eq!(wait_ms, 800);
            assert_eq!(elapsed_ms, 120);
            assert_eq!(status, Some(504));
            assert_eq!(error, "API error (status: 504): gateway timeout");
        }
        other => panic!("expected LlmRetry, got {other:?}"),
    }

    // No HTTP status available (e.g. a connect failure): the key is skipped
    // when serializing and defaults back to `None` when deserializing.
    let without_status = Event::LlmRetry {
        llm_call: 3,
        failed_attempt: 2,
        max_attempts: 4,
        wait_ms: 1_600,
        elapsed_ms: 950,
        status: None,
        error: "connection refused".into(),
    };
    let json = serde_json::to_string(&without_status).unwrap();
    let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
    assert!(
        parsed["data"].get("status").is_none(),
        "status must be omitted from the wire format when it is None"
    );
    let decoded: Event = serde_json::from_str(&json).unwrap();
    match decoded {
        Event::LlmRetry { status, error, .. } => {
            assert_eq!(status, None);
            assert_eq!(error, "connection refused");
        }
        other => panic!("expected LlmRetry, got {other:?}"),
    }
}
/// `Event::LlmContinuation` keeps its tag and all four fields across a JSON
/// serialize/deserialize cycle.
#[test]
fn event_llm_continuation_roundtrip() {
    let original = Event::LlmContinuation {
        llm_call: 4,
        attempt: 1,
        max_attempts: 3,
        reason: "SSE parse error: closed".into(),
    };

    let wire = serde_json::to_string(&original).unwrap();

    // Check the tag on the raw JSON before decoding back into the enum.
    let as_value: serde_json::Value = serde_json::from_str(&wire).unwrap();
    assert_eq!(as_value["event"], "llm_continuation");

    match serde_json::from_str::<Event>(&wire).unwrap() {
        Event::LlmContinuation {
            llm_call,
            attempt,
            max_attempts,
            reason,
        } => {
            assert_eq!((llm_call, attempt, max_attempts), (4, 1, 3));
            assert_eq!(reason, "SSE parse error: closed");
        }
        other => panic!("expected LlmContinuation, got {other:?}"),
    }
}
#[test] #[test]
fn method_notify_json_roundtrip() { fn method_notify_json_roundtrip() {
let json = r#"{"method":"notify","params":{"message":"turn done"}}"#; let json = r#"{"method":"notify","params":{"message":"turn done"}}"#;

View File

@ -56,6 +56,7 @@ impl CodexAuthError {
status: None, status: None,
code: Some("refresh_transient".into()), code: Some("refresh_transient".into()),
message: msg, message: msg,
retry_after: None,
}, },
CodexAuthError::RefreshPermanent { reason, message } => ClientError::Api { CodexAuthError::RefreshPermanent { reason, message } => ClientError::Api {
status: Some(401), status: Some(401),
@ -66,6 +67,7 @@ impl CodexAuthError {
PermanentReason::Other => "refresh_token_failed".into(), PermanentReason::Other => "refresh_token_failed".into(),
}), }),
message: format!("{message}. Please run `codex login` again."), message: format!("{message}. Please run `codex login` again."),
retry_after: None,
}, },
} }
} }

View File

@ -45,6 +45,7 @@ impl LlmClient for MockLlmClient {
status: Some(500), status: Some(500),
code: Some("mock_error".to_string()), code: Some("mock_error".to_string()),
message: "No more mock responses".to_string(), message: "No more mock responses".to_string(),
retry_after: None,
}); });
} }
let events = self.responses[count].clone(); let events = self.responses[count].clone();

View File

@ -89,6 +89,8 @@ pub struct App {
pub context_window: u64, pub context_window: u64,
pub turn_index: usize, pub turn_index: usize,
pub current_tool: Option<String>, pub current_tool: Option<String>,
/// Latest LLM wait/retry lifecycle event for actionbar observability.
pub latest_llm_wait_event: Option<String>,
/// Latest memory extract/consolidation lifecycle event for actionbar observability. /// Latest memory extract/consolidation lifecycle event for actionbar observability.
pub latest_memory_worker_event: Option<String>, pub latest_memory_worker_event: Option<String>,
/// Normal composer input that is submitted as `Method::Run`. /// Normal composer input that is submitted as `Method::Run`.
@ -150,6 +152,7 @@ impl App {
context_window: 0, context_window: 0,
turn_index: 0, turn_index: 0,
current_tool: None, current_tool: None,
latest_llm_wait_event: None,
latest_memory_worker_event: None, latest_memory_worker_event: None,
input: InputBuffer::new(), input: InputBuffer::new(),
command_input: InputBuffer::new(), command_input: InputBuffer::new(),
@ -608,20 +611,52 @@ impl App {
self.set_pod_status(PodStatus::Running); self.set_pod_status(PodStatus::Running);
self.run_requests += 1; self.run_requests += 1;
self.current_tool = None; self.current_tool = None;
self.latest_llm_wait_event = None;
self.assistant_streaming = false; self.assistant_streaming = false;
} }
// UI consumers of Invoke / LlmCall semantics are out of scope // UI consumers of Invoke / LlmCall semantics are out of scope
// for `tickets/invoke-turn-llmcall-semantics.md`; events flow // for `tickets/invoke-turn-llmcall-semantics.md`; events flow
// through to subscribers but the TUI currently derives its // through to subscribers but the TUI currently derives its
// turn header from `UserMessage` / `SystemItem` arrivals. // turn header from `UserMessage` / `SystemItem` arrivals.
Event::InvokeStart { .. } | Event::LlmCallStart { .. } | Event::LlmCallEnd { .. } => {} Event::InvokeStart { .. } | Event::LlmCallStart { .. } | Event::LlmCallEnd { .. } => {
self.latest_llm_wait_event = None;
}
Event::LlmRetry {
failed_attempt,
max_attempts,
wait_ms,
status,
error,
..
} => {
let next_attempt = failed_attempt.saturating_add(1).min(max_attempts);
let reason = status
.map(|code| format!("HTTP {code}"))
.unwrap_or_else(|| error);
self.latest_llm_wait_event = Some(format!(
"retrying LLM request after {reason} (attempt {next_attempt}/{max_attempts} in {})",
fmt_millis(wait_ms)
));
}
Event::LlmContinuation {
attempt,
max_attempts,
reason,
..
} => {
self.latest_llm_wait_event = Some(format!(
"LLM stream interrupted; continuing generation ({attempt}/{max_attempts}): {reason}"
));
}
Event::TextDelta { text } => { Event::TextDelta { text } => {
self.latest_llm_wait_event = None;
self.append_assistant_text(&text); self.append_assistant_text(&text);
} }
Event::TextDone { .. } => { Event::TextDone { .. } => {
self.assistant_streaming = false; self.assistant_streaming = false;
} }
Event::ThinkingStart => { Event::ThinkingStart => {
self.latest_llm_wait_event = None;
self.assistant_streaming = false; self.assistant_streaming = false;
self.blocks.push(Block::Thinking(ThinkingBlock { self.blocks.push(Block::Thinking(ThinkingBlock {
text: String::new(), text: String::new(),
@ -661,6 +696,7 @@ impl App {
self.current_tool = None; self.current_tool = None;
} }
Event::ToolCallStart { id, name } => { Event::ToolCallStart { id, name } => {
self.latest_llm_wait_event = None;
self.current_tool = Some(name.clone()); self.current_tool = Some(name.clone());
self.assistant_streaming = false; self.assistant_streaming = false;
self.blocks.push(Block::ToolCall(ToolCallBlock { self.blocks.push(Block::ToolCall(ToolCallBlock {
@ -702,6 +738,7 @@ impl App {
output, output,
is_error, is_error,
} => { } => {
self.latest_llm_wait_event = None;
// Pull the name / args out first so we can look at the // Pull the name / args out first so we can look at the
// (immutable) cache before taking the mutable block // (immutable) cache before taking the mutable block
// borrow below. // borrow below.
@ -776,6 +813,7 @@ impl App {
self.push_error(format!("[{code:?}] {message}")); self.push_error(format!("[{code:?}] {message}"));
} }
Event::RunEnd { result } => { Event::RunEnd { result } => {
self.latest_llm_wait_event = None;
if matches!(result, RunResult::RolledBack) { if matches!(result, RunResult::RolledBack) {
self.handle_rolled_back_run(); self.handle_rolled_back_run();
} else { } else {
@ -889,6 +927,7 @@ impl App {
self.run_upload_tokens = 0; self.run_upload_tokens = 0;
self.run_output_tokens = 0; self.run_output_tokens = 0;
self.current_tool = None; self.current_tool = None;
self.latest_llm_wait_event = None;
self.assistant_streaming = false; self.assistant_streaming = false;
} }
@ -1291,6 +1330,14 @@ pub fn fmt_tokens(n: u64) -> String {
} }
} }
/// Render a millisecond count as a short, human-readable duration.
///
/// Values below one second are shown as whole milliseconds (`"800ms"`);
/// values of 1000 ms and above are shown in seconds with one decimal place
/// (`"1.2s"`).
fn fmt_millis(ms: u64) -> String {
    match ms {
        0..=999 => format!("{ms}ms"),
        _ => {
            let seconds = ms as f64 / 1_000.0;
            format!("{seconds:.1}s")
        }
    }
}
fn message_text(item: &serde_json::Value) -> String { fn message_text(item: &serde_json::Value) -> String {
item["content"] item["content"]
.as_array() .as_array()
@ -1356,6 +1403,47 @@ pub fn alert_source_label(source: AlertSource) -> &'static str {
} }
} }
/// Tests for the transient "LLM is waiting" status kept in
/// `App::latest_llm_wait_event`.
#[cfg(test)]
mod llm_wait_event_tests {
    use super::*;

    /// A retry notice fills the transient status, and the next text delta
    /// clears it again.
    #[test]
    fn llm_retry_updates_and_progress_clears_transient_status() {
        let mut app = App::new("test".into());

        let retry = Event::LlmRetry {
            llm_call: 2,
            failed_attempt: 1,
            max_attempts: 4,
            wait_ms: 1_200,
            elapsed_ms: 50,
            status: Some(504),
            error: "gateway timeout".into(),
        };
        app.handle_pod_event(retry);

        let shown = app.latest_llm_wait_event.as_deref();
        assert_eq!(
            shown,
            Some("retrying LLM request after HTTP 504 (attempt 2/4 in 1.2s)")
        );

        // Streaming output means the request is making progress again.
        let progress = Event::TextDelta { text: "ok".into() };
        app.handle_pod_event(progress);
        assert!(app.latest_llm_wait_event.is_none());
    }

    /// A continuation notice is rendered with its attempt counter and reason.
    #[test]
    fn llm_continuation_updates_transient_status() {
        let mut app = App::new("test".into());

        let continuation = Event::LlmContinuation {
            llm_call: 3,
            attempt: 1,
            max_attempts: 3,
            reason: "SSE parse error: closed".into(),
        };
        app.handle_pod_event(continuation);

        let shown = app.latest_llm_wait_event.as_deref();
        assert_eq!(
            shown,
            Some("LLM stream interrupted; continuing generation (1/3): SSE parse error: closed")
        );
    }
}
#[cfg(test)] #[cfg(test)]
mod completion_flow_tests { mod completion_flow_tests {
use super::*; use super::*;

View File

@ -1158,7 +1158,14 @@ fn draw_status(frame: &mut Frame, app: &App, area: Rect) {
]; ];
if app.running { if app.running {
let status = if let Some(tool) = &app.current_tool { let status = if let Some(wait_event) = &app.latest_llm_wait_event {
format!(
"request: {} | ↑{}/↓{} | {wait_event}",
app.run_requests,
fmt_tokens(app.run_upload_tokens),
fmt_tokens(app.run_output_tokens),
)
} else if let Some(tool) = &app.current_tool {
format!( format!(
"request: {} | ↑{}/↓{} | tool: {tool}", "request: {} | ↑{}/↓{} | tool: {tool}",
app.run_requests, app.run_requests,
@ -1218,6 +1225,11 @@ fn draw_actionbar(frame: &mut Frame, app: &App, area: Rect) {
"Alt-q edit queued Alt-c clear queued", "Alt-q edit queued Alt-c clear queued",
Style::default().fg(Color::DarkGray), Style::default().fg(Color::DarkGray),
)); ));
} else if let Some(llm_event) = app.latest_llm_wait_event.as_deref() {
left.push(Span::styled(
truncate_with_ellipsis(llm_event, 96),
Style::default().fg(Color::Yellow),
));
} else if let Some(memory_event) = app.latest_memory_worker_event.as_deref() { } else if let Some(memory_event) = app.latest_memory_worker_event.as_deref() {
left.push(Span::styled( left.push(Span::styled(
truncate_with_ellipsis(memory_event, 72), truncate_with_ellipsis(memory_event, 72),

View File

@ -2,7 +2,7 @@
## 背景 ## 背景
`Read` などの tool 実行が完了した後、本来は tool result を含めた次の LLM request が走り、assistant 応答が続く。実運用ではこの次 request が provider / upstream gateway から HTTP 504 を返すことがあり、現行の `HttpTransport` は transient retry として最大 `RetryPolicy::default()` の範囲で再試行する。 `Read` などの tool 実行が完了した後、本来は tool result を含めた次の LLM request が走り、assistant 応答が続く。実運用ではこの次 request が provider / upstream gateway から HTTP 504 を返すことがあり、LLM response stream を開く前の transient failure として retry する必要がある。
現在の問題は、retry / backoff 中であることが TUI に表示されず、「tool は終わったのに、その後の LLM 応答がハングした」ように見えることにある。 現在の問題は、retry / backoff 中であることが TUI に表示されず、「tool は終わったのに、その後の LLM 応答がハングした」ように見えることにある。
@ -18,10 +18,12 @@
LLM request が待っている理由は user-visible operational state として扱う。history / LLM context には入れない。 LLM request が待っている理由は user-visible operational state として扱う。history / LLM context には入れない。
transport から protocol / TUI に直接依存させず、下から上へ typed event を渡す。 retry / continuation から protocol / TUI に直接依存させず、Worker の lifecycle event として下から上へ typed event を渡す。`HttpTransport` は 1 回の HTTP request と response classification を担当し、retry / backoff / cancellation / TUI 通知は Worker が管理する。
```text ```text
HttpTransport / stream consumer HttpTransport
-> ClientError { status, retry_after, ... }
-> Worker retry / continuation state
-> llm-worker callback -> llm-worker callback
-> Pod controller bridge -> Pod controller bridge
-> protocol::Event -> protocol::Event
@ -55,12 +57,17 @@ continuation は `Worker::stream_response` の error branch 周辺に閉じ込
### 実装方針 ### 実装方針
1. `llm_client::client::LlmClient` に retry observer 付き stream entrypoint を追加する。 1. `llm_client::client::LlmClient::stream(request)` は単発 request として維持する。
- 既存 `stream(request)` の意味は維持する。 - 成功時は `ResponseStream` を返す。
- default 実装は observer を無視して `stream(request)` に委譲する。 - stream open 前の失敗は `ClientError` として返す。
- `HttpTransport` だけが observer を利用する。 - retry observer 付き entrypoint は作らない。
2. `llm_client::transport::HttpTransport::stream` の retry 判定直後、`tokio::time::sleep(wait)` の直前で retry notice を発火する。 2. `llm_client::transport::HttpTransport::stream` は retry しない。
3. `Worker``on_llm_retry` callback を追加する。 - HTTP status / connect / timeout を `ClientError` に分類する。
- `Retry-After` がある場合は `ClientError` の metadata として保持する。
3. `Worker` に `open_stream_with_retry` 相当の helper を置く。
- `RetryPolicy` と `is_retryable(&ClientError)` に従って `client.stream(request.clone())` を再試行する。
- backoff sleep は cancel / abort より低優先にする。
- sleep 前に `on_llm_retry` callback を発火する。
4. `Pod` の `wire_event_bridges_on_worker` で protocol event に変換する。 4. `Pod` の `wire_event_bridges_on_worker` で protocol event に変換する。
5. `TUI` は retry state を transient に表示する。 5. `TUI` は retry state を transient に表示する。
@ -71,7 +78,7 @@ continuation は `Worker::stream_response` の error branch 周辺に閉じ込
```rust ```rust
Event::LlmRetry { Event::LlmRetry {
llm_call: usize, llm_call: usize,
attempt: u32, failed_attempt: u32,
max_attempts: u32, max_attempts: u32,
wait_ms: u64, wait_ms: u64,
elapsed_ms: u64, elapsed_ms: u64,
@ -80,7 +87,7 @@ Event::LlmRetry {
} }
``` ```
- `attempt` は「次に実行する attempt 番号」または「失敗した attempt 番号」のどちらかに統一し、protocol comment と TUI 表示で曖昧にならないようにする - `failed_attempt` は「直近で失敗した attempt 番号」として扱う。TUI 表示では次に実行される attempt を `failed_attempt + 1` として表示してよい
- `status` は HTTP status が取れる場合のみ入れる。504 の場合は `Some(504)` - `status` は HTTP status が取れる場合のみ入れる。504 の場合は `Some(504)`
- `error` は user-visible になり得るので、API key / Authorization header / request body を含めない。 - `error` は user-visible になり得るので、API key / Authorization header / request body を含めない。
- retry exhausted は既存の final error 経路で表示する。初期実装では sleep 前の retry notice に絞る。 - retry exhausted は既存の final error 経路で表示する。初期実装では sleep 前の retry notice に絞る。
@ -138,13 +145,13 @@ HTTP status 504 のような stream 開始前 error は Phase 1 の retry 表示
`stream_response` の成功 result は保ちつつ、stream が途中で切れたことだけを表せる型にする。 `stream_response` の成功 result は保ちつつ、stream が途中で切れたことだけを表せる型にする。
```rust ```rust
enum StreamResponseOutcome { enum StreamCompletion {
Completed(CompletedResponse), Complete,
Interrupted(StreamInterruption), Interrupted { reason: String },
} }
``` ```
`CompletedResponse` は現行成功経路で使っている情報を保持する。`Interrupted` には partial commit に必要な情報だけを入れる。 `Complete` は現行成功経路へ進むだけで、成功時の assistant item / tool call を別 result type へ包み直さない。`Interrupted` には continuation notice と partial commit 判断に必要な理由だけを入れる。
`TimelineDispatch` / collector に partial drain API を追加し、途中中断時に安全に history 化できるものだけを取り出す。 `TimelineDispatch` / collector に partial drain API を追加し、途中中断時に安全に history 化できるものだけを取り出す。
@ -158,19 +165,19 @@ enum StreamResponseOutcome {
```rust ```rust
match self.stream_response(request).await? { match self.stream_response(request).await? {
StreamResponseOutcome::Completed(response) => { StreamCompletion::Complete => {
self.handle_completed_response(response).await?; self.handle_completed_response().await?;
if self.execute_and_commit_tools(...).await? { if self.execute_and_commit_tools(...).await? {
continue; continue;
} }
break; break;
} }
StreamResponseOutcome::Interrupted(interruption) => { StreamCompletion::Interrupted { reason } => {
if continuation_budget.exhausted() { if continuation_budget.exhausted() {
return Err(...); return Err(...);
} }
self.commit_partial_assistant(interruption.safe_items).await?; self.commit_partial_assistant(...).await?;
self.emit_continuation_notice(...); self.emit_continuation_notice(reason);
continue; continue;
} }
} }
@ -197,8 +204,9 @@ match self.stream_response(request).await? {
## 完了条件 ## 完了条件
- `HttpTransport` の unit test で retryable 504 時に retry notice が発火する。 - `HttpTransport` の unit test で retryable 504/503 が transport 内部では retry されず、`ClientError` として返る。
- `Worker` の test で `on_llm_retry` callback が呼ばれる。 - `HttpTransport` の unit test で `Retry-After` が `ClientError` metadata として保持される。
- `Worker` の test で stream open 前の retryable error に対して `on_llm_retry` callback が呼ばれる。
- `protocol::Event` の retry / continuation event の serde roundtrip test がある。 - `protocol::Event` の retry / continuation event の serde roundtrip test がある。
- Pod controller bridge の test、または既存 bridge test への追加で retry / continuation event が流れることを確認する。 - Pod controller bridge の test、または既存 bridge test への追加で retry / continuation event が流れることを確認する。
- TUI app test で retry / continuation event が transient state を更新し、進行イベントで clear されることを確認する。 - TUI app test で retry / continuation event が transient state を更新し、進行イベントで clear されることを確認する。