merge: main trace diagnostics

This commit is contained in:
Keisuke Hirata 2026-05-28 12:32:24 +09:00
commit 11d1dcffb6
4 changed files with 282 additions and 16 deletions

View File

@ -29,7 +29,7 @@ impl Default for RetryPolicy {
base: Duration::from_millis(500),
cap: Duration::from_secs(10),
max_attempts: 4,
total_timeout: Duration::from_secs(30),
total_timeout: Duration::from_secs(40),
}
}
}
@ -75,7 +75,7 @@ mod tests {
assert_eq!(p.base, Duration::from_millis(500));
assert_eq!(p.cap, Duration::from_secs(10));
assert_eq!(p.max_attempts, 4);
assert_eq!(p.total_timeout, Duration::from_secs(30));
assert_eq!(p.total_timeout, Duration::from_secs(40));
}
#[test]

View File

@ -6,7 +6,7 @@
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use eventsource_stream::Eventsource;
@ -14,6 +14,7 @@ use futures::{Stream, StreamExt, TryStreamExt};
use reqwest::header::{
ACCEPT, CONTENT_ENCODING, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, RETRY_AFTER,
};
use serde_json::{Value, json};
use super::auth::{AuthProvider, AuthRequirement};
use super::capability::ModelCapability;
@ -23,7 +24,7 @@ use super::event::Event;
use super::scheme::Scheme;
use super::types::{Request, RequestConfig};
pub const DEFAULT_STREAM_OPEN_TIMEOUT: Duration = Duration::from_secs(30);
pub const DEFAULT_STREAM_OPEN_TIMEOUT: Duration = Duration::from_secs(20);
pub const DEFAULT_FIRST_STREAM_EVENT_TIMEOUT: Duration = Duration::from_secs(30);
/// `AuthRef` を解決したランタイム表現。`crates/provider` が構築する。
@ -192,16 +193,71 @@ impl<S: Scheme> HttpTransport<S> {
}
let raw = serde_json::to_vec(body)?;
let raw_json_bytes = raw.len();
let compressed = zstd::stream::encode_all(std::io::Cursor::new(raw), 3)
.map_err(|e| ClientError::Config(format!("failed to zstd-compress request: {e}")))?;
headers.insert(CONTENT_ENCODING, HeaderValue::from_static("zstd"));
Ok(RequestBody::CompressedJson(compressed))
Ok(RequestBody::CompressedJson {
bytes: compressed,
raw_json_bytes,
})
}
}
enum RequestBody {
Json(serde_json::Value),
CompressedJson(Vec<u8>),
CompressedJson {
bytes: Vec<u8>,
raw_json_bytes: usize,
},
}
/// `io::Write` sink that discards its input and only tallies how many bytes
/// were written, so a JSON body can be measured without buffering a copy.
struct ByteCounter(usize);

impl std::io::Write for ByteCounter {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.0 += buf.len();
        Ok(buf.len())
    }

    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}

/// Length in bytes of `body` when serialized as compact JSON, or `None` if
/// serialization fails.
///
/// Streams the serializer output into a [`ByteCounter`] instead of
/// materializing a `Vec<u8>`; the diagnostics only need the size.
fn json_serialized_len(body: &serde_json::Value) -> Option<usize> {
    let mut counter = ByteCounter(0);
    serde_json::to_writer(&mut counter, body).ok()?;
    Some(counter.0)
}

impl RequestBody {
    /// Short label for the body encoding, used in transport trace payloads.
    fn encoding(&self) -> &'static str {
        match self {
            Self::Json(_) => "json",
            Self::CompressedJson { .. } => "zstd",
        }
    }

    /// Size of the uncompressed JSON representation in bytes.
    ///
    /// For `Json` the size is measured on demand (`None` if the value cannot
    /// be serialized); for `CompressedJson` it is the size recorded when the
    /// body was compressed.
    fn raw_json_bytes(&self) -> Option<usize> {
        match self {
            Self::Json(body) => json_serialized_len(body),
            Self::CompressedJson { raw_json_bytes, .. } => Some(*raw_json_bytes),
        }
    }

    /// Size of the payload handed to the HTTP client, in bytes.
    ///
    /// `Json` bodies are not compressed here, so this equals the serialized
    /// JSON length; `CompressedJson` reports the compressed buffer length.
    fn wire_bytes(&self) -> Option<usize> {
        match self {
            Self::Json(body) => json_serialized_len(body),
            Self::CompressedJson { bytes, .. } => Some(bytes.len()),
        }
    }
}
/// Maps a resolved auth value to a stable, secret-free label for trace
/// payloads. Only the variant is reported; the credential itself is never
/// inspected.
fn auth_kind(auth: &ResolvedAuth) -> &'static str {
    match auth {
        ResolvedAuth::ApiKey(..) => "api_key",
        ResolvedAuth::Custom(..) => "custom",
        ResolvedAuth::None => "none",
    }
}
/// Forwards one transport lifecycle event to the request's trace sink.
///
/// Does nothing when the request carries no `transport_trace`.
fn emit_transport_trace(request: &Request, label: &str, data: Value) {
    let Some(sink) = request.transport_trace.as_ref() else {
        return;
    };
    sink.emit(label, data);
}
/// Names the top-level JSON type of `value` without exposing its contents,
/// so the shape of a request body can be traced safely.
fn json_value_kind(value: &Value) -> &'static str {
    match value {
        // Containers first, then scalars; the arms are disjoint so order is
        // purely cosmetic.
        Value::Object(..) => "object",
        Value::Array(..) => "array",
        Value::String(..) => "string",
        Value::Number(..) => "number",
        Value::Bool(..) => "bool",
        Value::Null => "null",
    }
}
async fn response_with_timeout(
@ -273,27 +329,175 @@ impl<S: Scheme + Clone + 'static> LlmClient for HttpTransport<S> {
}
async fn stream(&self, request: Request) -> Result<ResponseStream, ClientError> {
let total_started = Instant::now();
let path = self.scheme.path(&self.model_id);
emit_transport_trace(
&request,
"transport_start",
json!({
"model": &self.model_id,
"path": path,
"auth_kind": auth_kind(&self.auth),
"required_auth": format!("{:?}", self.scheme.required_auth()),
"codex_backend": self.is_codex_backend(),
"cache_key_present": request.cache_key.is_some(),
"stream_open_timeout_ms": DEFAULT_STREAM_OPEN_TIMEOUT.as_millis() as u64,
}),
);
let url = self.build_url();
let mut headers = self.build_headers().await?;
self.apply_stream_headers(&mut headers, &request)?;
let headers_started = Instant::now();
emit_transport_trace(
&request,
"transport_headers_start",
json!({
"auth_kind": auth_kind(&self.auth),
"required_auth": format!("{:?}", self.scheme.required_auth()),
}),
);
let mut headers = match self.build_headers().await {
Ok(headers) => {
emit_transport_trace(
&request,
"transport_headers_done",
json!({
"elapsed_ms": headers_started.elapsed().as_millis() as u64,
"headers_len": headers.len(),
}),
);
headers
}
Err(error) => {
emit_transport_trace(
&request,
"transport_headers_error",
json!({
"elapsed_ms": headers_started.elapsed().as_millis() as u64,
"error": error.to_string(),
}),
);
return Err(error);
}
};
let stream_headers_started = Instant::now();
if let Err(error) = self.apply_stream_headers(&mut headers, &request) {
emit_transport_trace(
&request,
"transport_stream_headers_error",
json!({
"elapsed_ms": stream_headers_started.elapsed().as_millis() as u64,
"error": error.to_string(),
}),
);
return Err(error);
}
emit_transport_trace(
&request,
"transport_stream_headers_done",
json!({
"elapsed_ms": stream_headers_started.elapsed().as_millis() as u64,
"headers_len": headers.len(),
}),
);
let body_started = Instant::now();
emit_transport_trace(&request, "transport_body_build_start", json!({}));
let body = self
.scheme
.build_request_body(&self.model_id, &request, &self.capability);
let request_body = self.encode_request_body(&body, &mut headers)?;
emit_transport_trace(
&request,
"transport_body_build_done",
json!({
"elapsed_ms": body_started.elapsed().as_millis() as u64,
"body_kind": json_value_kind(&body),
}),
);
let encode_started = Instant::now();
let request_body = match self.encode_request_body(&body, &mut headers) {
Ok(body) => body,
Err(error) => {
emit_transport_trace(
&request,
"transport_body_encode_error",
json!({
"elapsed_ms": encode_started.elapsed().as_millis() as u64,
"error": error.to_string(),
}),
);
return Err(error);
}
};
emit_transport_trace(
&request,
"transport_body_encode_done",
json!({
"elapsed_ms": encode_started.elapsed().as_millis() as u64,
"encoding": request_body.encoding(),
"raw_json_bytes": request_body.raw_json_bytes(),
"wire_bytes": request_body.wire_bytes(),
}),
);
let builder = self.http_client.post(&url).headers(headers);
let builder = match request_body {
RequestBody::Json(body) => builder.json(&body),
RequestBody::CompressedJson(body) => builder.body(body),
RequestBody::CompressedJson { bytes, .. } => builder.body(bytes),
};
let send_started = Instant::now();
emit_transport_trace(&request, "transport_http_send_start", json!({}));
let response =
response_with_timeout(builder.send(), DEFAULT_STREAM_OPEN_TIMEOUT, "stream_open")
.await?;
match response_with_timeout(builder.send(), DEFAULT_STREAM_OPEN_TIMEOUT, "stream_open")
.await
{
Ok(response) => {
emit_transport_trace(
&request,
"transport_http_headers_received",
json!({
"elapsed_ms": send_started.elapsed().as_millis() as u64,
"status": response.status().as_u16(),
"success": response.status().is_success(),
}),
);
response
}
Err(error) => {
emit_transport_trace(
&request,
"transport_http_send_error",
json!({
"elapsed_ms": send_started.elapsed().as_millis() as u64,
"error": error.to_string(),
}),
);
return Err(error);
}
};
if !response.status().is_success() {
emit_transport_trace(
&request,
"transport_http_status_error",
json!({
"status": response.status().as_u16(),
"retry_after_present": response.headers().get(RETRY_AFTER).is_some(),
}),
);
return Err(classify_error_response(response).await);
}
emit_transport_trace(
&request,
"transport_stream_ready",
json!({
"elapsed_ms": total_started.elapsed().as_millis() as u64,
}),
);
let scheme = self.scheme.clone();
let byte_stream = response.bytes_stream().map_err(std::io::Error::other);
let event_stream = byte_stream.eventsource();
@ -449,9 +653,14 @@ mod tests {
assert_eq!(headers.get("x-client-request-id").unwrap(), "segment-123");
assert_eq!(headers.get(CONTENT_ENCODING).unwrap(), "zstd");
let RequestBody::CompressedJson(compressed) = encoded else {
let RequestBody::CompressedJson {
bytes: compressed,
raw_json_bytes,
} = encoded
else {
panic!("Codex backend request body must be zstd-compressed");
};
assert!(raw_json_bytes > 0);
let decoded = zstd::stream::decode_all(std::io::Cursor::new(compressed)).unwrap();
let decoded: serde_json::Value = serde_json::from_slice(&decoded).unwrap();
assert_eq!(decoded["prompt_cache_key"], "segment-123");

View File

@ -7,6 +7,8 @@
//! - ToolResult items (tool results)
//! - Reasoning items (extended thinking)
use std::{fmt, sync::Arc};
use serde::{Deserialize, Serialize};
fn is_false(value: &bool) -> bool {
@ -23,6 +25,35 @@ pub type ItemId = String;
/// Call ID type for linking function calls to their outputs
pub type CallId = String;
/// Callback sink for request-local transport lifecycle diagnostics.
///
/// This is carried on [`Request`] so generic [`crate::llm_client::LlmClient`]
/// implementations can emit fine-grained transport milestones without widening
/// the trait method signature. The callback must never receive request body
/// contents or secret header values.
#[derive(Clone)]
pub struct RequestTrace {
callback: Arc<dyn Fn(&str, serde_json::Value) + Send + Sync>,
}
impl RequestTrace {
pub fn new(callback: impl Fn(&str, serde_json::Value) + Send + Sync + 'static) -> Self {
Self {
callback: Arc::new(callback),
}
}
pub fn emit(&self, label: &str, data: serde_json::Value) {
(self.callback)(label, data);
}
}
impl fmt::Debug for RequestTrace {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RequestTrace").finish_non_exhaustive()
}
}
/// Conversation item - the primary unit of conversation history
///
/// Items represent discrete elements in a conversation. Tool calls and reasoning
@ -497,6 +528,9 @@ pub struct Request {
/// 別の概念。`cache_anchor` を読まない provider と同じく、
/// `prompt_cache_key` を持たない provider は無視する。
pub cache_key: Option<String>,
/// Request-local diagnostics sink for transport lifecycle tracing.
#[doc(hidden)]
pub transport_trace: Option<RequestTrace>,
}
impl Request {
@ -547,6 +581,15 @@ impl Request {
self
}
/// Builder step: installs `callback` as this request's transport trace sink,
/// replacing any sink set earlier.
pub fn transport_trace(
    mut self,
    callback: impl Fn(&str, serde_json::Value) + Send + Sync + 'static,
) -> Self {
    let sink = RequestTrace::new(callback);
    self.transport_trace = Some(sink);
    self
}
/// Set max tokens
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
self.config.max_tokens = Some(max_tokens);

View File

@ -1,5 +1,5 @@
use std::collections::HashMap;
use std::{marker::PhantomData, time::Instant};
use std::{marker::PhantomData, sync::Arc, time::Instant};
use futures::StreamExt;
use serde_json::{Value, json};
@ -207,7 +207,7 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
stream_event_cbs: Vec<Box<dyn Fn(usize, usize, &Event) + Send + Sync>>,
/// Pre-stream lifecycle callbacks for debugging stalls before provider
/// stream events become visible.
lifecycle_trace_cbs: Vec<Box<dyn Fn(usize, usize, &str, &Value) + Send + Sync>>,
lifecycle_trace_cbs: Vec<Arc<dyn Fn(usize, usize, &str, &Value) + Send + Sync>>,
/// Non-fatal warning callbacks. Invoked when the Worker wants to
/// surface an advisory message to the upper layer (e.g. Pod) so it
/// can be forwarded to the user — distinct from `tracing::warn!`,
@ -435,7 +435,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
&mut self,
callback: impl Fn(usize, usize, &str, &Value) + Send + Sync + 'static,
) {
self.lifecycle_trace_cbs.push(Box::new(callback));
self.lifecycle_trace_cbs.push(Arc::new(callback));
}
fn emit_lifecycle_trace(&self, turn: usize, llm_call: usize, label: &str, data: Value) {
@ -444,6 +444,19 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
}
}
/// Bridges transport-level trace events into the worker's lifecycle trace
/// callbacks, tagging each event with `turn` and `llm_call`.
///
/// When no lifecycle callbacks are registered the request is returned
/// unchanged, so no trace sink is attached.
fn attach_transport_trace(&self, request: Request, turn: usize, llm_call: usize) -> Request {
    let registered = &self.lifecycle_trace_cbs;
    if registered.is_empty() {
        return request;
    }

    // Snapshot the callback handles so the closure owns them independently
    // of `self`.
    let sinks = registered.clone();
    request.transport_trace(move |label, data| {
        sinks
            .iter()
            .for_each(|sink| sink(turn, llm_call, label, &data));
    })
}
/// Register a non-fatal warning callback.
///
/// The callback is invoked with a short human-readable message
@ -1198,6 +1211,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
"build_request_done",
self.request_trace_payload(&request),
);
let request = self.attach_transport_trace(request, current_turn, current_llm_call);
let stream_outcome = self
.stream_response(request, current_turn, current_llm_call)
.await?;