fix: account hook system item appends

This commit is contained in:
Keisuke Hirata 2026-06-05 11:17:16 +09:00
parent c9cb2edc7e
commit 960f2a305e
No known key found for this signature in database
4 changed files with 179 additions and 53 deletions

View File

@ -42,6 +42,12 @@ pub enum PreRequestAction {
/// to: the items are committed before the request so later turns can see
/// why the worker changed course.
ContinueWith(Vec<Item>),
/// Yield after appending these items to durable worker history.
///
/// This is for host-mediated pre-request appends that must be visible to
/// usage accounting and compaction checks before the current LLM request is
/// allowed to proceed.
YieldWith(Vec<Item>),
/// Cancel with a reason (treated as an error).
Cancel(String),
/// Yield control to the caller for external processing.

View File

@ -1177,6 +1177,16 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
self.last_run_interrupted = true;
return Err(WorkerError::Aborted(reason));
}
PreRequestAction::YieldWith(items) => {
self.append_history_items(items.clone());
request_context.extend(items);
info!("Yielded by interceptor after pre-request history append");
for cb in &self.turn_end_cbs {
cb(current_turn);
}
self.last_run_interrupted = true;
return Ok(WorkerResult::Yielded);
}
PreRequestAction::Yield => {
info!("Yielded by interceptor");
for cb in &self.turn_end_cbs {

View File

@ -7,6 +7,7 @@
//! event-specific read-only contexts and only return control-flow
//! decisions (continue / skip / abort / pause).
use std::borrow::Cow;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
@ -137,6 +138,29 @@ impl PodInterceptor {
}
Some(total_tokens(context, &records).tokens)
}
/// Decides whether the between-requests compaction threshold has been crossed.
///
/// Returns `true` only when a compaction state is configured, it is neither
/// disabled nor freshly compacted, and `exceeds_request` reports the token
/// estimate as over the limit. A missing estimate (`None`) is treated as zero
/// tokens. When the threshold is exceeded, the shape of `context` is logged
/// before returning.
fn request_threshold_exceeded(&self, current_tokens: Option<u64>, context: &[Item]) -> bool {
    let Some(compact) = self.compact_state.as_ref() else {
        return false;
    };
    // `just_compacted` is only consulted when compaction is enabled, exactly
    // as a combined `&&` guard would short-circuit.
    if compact.is_disabled() || compact.just_compacted() {
        return false;
    }

    let tokens_now = current_tokens.unwrap_or(0);
    if !compact.exceeds_request(tokens_now) {
        return false;
    }

    // Log the context shape so the yield can be diagnosed afterwards.
    let ctx_shape = context_shape(context);
    info!(
        input_tokens = tokens_now,
        threshold = compact.request_threshold().unwrap_or(0),
        items_len = ctx_shape.items_len,
        items_json_bytes = ctx_shape.items_json_bytes,
        reasoning_items = ctx_shape.reasoning_items,
        reasoning_encrypted_content_count = ctx_shape.reasoning_encrypted_content_count,
        reasoning_encrypted_content_bytes = ctx_shape.reasoning_encrypted_content_bytes,
        "Between-requests compaction threshold exceeded, yielding"
    );
    true
}
}
#[async_trait]
@ -210,32 +234,13 @@ impl Interceptor for PodInterceptor {
}
async fn pre_llm_request(&self, context: &mut Vec<Item>) -> PreRequestAction {
let current_tokens = self.estimated_tokens(context);
// Internal mechanism: between-requests compaction trigger (safety net).
if let Some(state) = self.compact_state.as_ref() {
if !state.is_disabled() && !state.just_compacted() {
let current = current_tokens.unwrap_or(0);
if state.exceeds_request(current) {
let shape = context_shape(context);
info!(
input_tokens = current,
threshold = state.request_threshold().unwrap_or(0),
items_len = shape.items_len,
items_json_bytes = shape.items_json_bytes,
reasoning_items = shape.reasoning_items,
reasoning_encrypted_content_count = shape.reasoning_encrypted_content_count,
reasoning_encrypted_content_bytes = shape.reasoning_encrypted_content_bytes,
"Between-requests compaction threshold exceeded, yielding"
);
return PreRequestAction::Yield;
}
}
let initial_tokens = self.estimated_tokens(context);
if self.request_threshold_exceeded(initial_tokens, context) {
return PreRequestAction::Yield;
}
let info = PreRequestInfo {
item_count: context.len(),
estimated_tokens: current_tokens,
estimated_tokens: initial_tokens,
turn_index: self.current_turn_index(),
tool_calls_this_turn: self.tool_calls_this_turn.load(Ordering::Relaxed),
};
@ -257,16 +262,36 @@ impl Interceptor for PodInterceptor {
.lock()
.expect("pending hook system-item queue poisoned"),
);
let appended_items: Vec<Item> = system_items
.iter()
.map(SystemItem::to_history_item)
.collect();
let effective_context = if appended_items.is_empty() {
Cow::Borrowed(context.as_slice())
} else {
let mut effective = context.clone();
effective.extend(appended_items.clone());
Cow::Owned(effective)
};
let current_tokens = self.estimated_tokens(effective_context.as_ref());
if self.request_threshold_exceeded(current_tokens, effective_context.as_ref()) {
self.commit_system_items(&system_items);
return if appended_items.is_empty() {
PreRequestAction::Yield
} else {
PreRequestAction::YieldWith(appended_items)
};
}
if let Some(usage_tracker) = self.usage_tracker.as_ref() {
usage_tracker.note_request(effective_context.len());
}
if system_items.is_empty() {
return PreRequestAction::Continue;
}
self.commit_system_items(&system_items);
PreRequestAction::ContinueWith(
system_items
.into_iter()
.map(|item| item.to_history_item())
.collect(),
)
PreRequestAction::ContinueWith(appended_items)
}
async fn pre_tool_call(&self, info: &mut ToolCallInfo) -> PreToolAction {
@ -396,6 +421,8 @@ mod tests {
use std::sync::atomic::{AtomicBool, AtomicUsize};
use super::*;
use crate::feature::FeatureRegistryBuilder;
use crate::feature::builtin::TaskFeature;
use crate::hook::{
Hook, HookPostToolAction, HookPreRequestAction, HookPreToolAction, HookRegistryBuilder,
HookTurnEndAction, OnTurnEnd, PostToolCall, PreLlmRequest, PreToolCall,
@ -504,6 +531,41 @@ mod tests {
assert_eq!(count.load(Ordering::Relaxed), 0);
}
/// A pre-request hook queues a system item while the compaction threshold is
/// exceeded: the interceptor must answer `YieldWith` carrying that single item
/// and hand the item to the system-item committer.
#[tokio::test]
async fn pre_llm_request_yields_with_hook_appends_when_post_append_threshold_exceeded() {
    // Registry holding one appending hook; the flag records that the hook ran.
    let hook_ran = Arc::new(AtomicBool::new(false));
    let registry = {
        let mut hooks = HookRegistryBuilder::new();
        hooks.add_pre_llm_request(AppendingPreRequestHook {
            saw_handle: Arc::clone(&hook_ran),
        });
        Arc::new(hooks.build())
    };

    let compact_state = Arc::new(CompactState::new(None, Some(50), 2));
    let mut ctx = vec![Item::user_message("hi")];
    let usage_history = usage_handle_with(ctx.len(), 50);

    // Shared sink that lets the test observe what the interceptor committed.
    let recorded = Arc::new(Mutex::new(Vec::new()));
    let committer = RecordingSystemItemCommitter {
        committed: Arc::clone(&recorded),
    };
    let interceptor = PodInterceptor::new(
        registry,
        Some(compact_state),
        Some(usage_history),
        NotifyBuffer::new(),
        Arc::new(Mutex::new(Vec::new())),
        PromptCatalog::builtins_only().unwrap(),
        Some(Arc::new(committer)),
    );

    let yielded_items = match interceptor.pre_llm_request(&mut ctx).await {
        PreRequestAction::YieldWith(items) => items,
        other => panic!("expected YieldWith queued system item, got {other:?}"),
    };
    assert_eq!(yielded_items.len(), 1);

    assert!(hook_ran.load(Ordering::Relaxed));
    let committed_count = recorded.lock().expect("committed system items").len();
    assert_eq!(committed_count, 1);
}
#[tokio::test]
async fn pre_llm_request_counts_in_flight_usage_records() {
let registry = Arc::new(HookRegistryBuilder::new().build());
@ -878,6 +940,75 @@ mod tests {
assert_eq!(count.load(Ordering::Relaxed), 1);
}
/// Checks that a task reminder appended by a hook is included in the request
/// length that the usage tracker pairs with the next usage event.
#[tokio::test]
async fn task_reminder_hook_append_is_counted_in_usage_request_len() {
// Seed the task feature from a history holding a single `TaskCreate` call.
let feature = TaskFeature::from_history(&[Item::tool_call(
"task-create-call",
"TaskCreate",
r#"{"subject":"track active work","description":"exercise reminder path"}"#,
)]);
// Install the feature so its hooks end up in the registry under test.
let mut hook_builder = HookRegistryBuilder::new();
let mut pending_tools = Vec::new();
FeatureRegistryBuilder::new()
.with_module(feature)
.install_into_pending(&mut pending_tools, &mut hook_builder);
let registry = Arc::new(hook_builder.build());
let usage_tracker = Arc::new(UsageTracker::new());
let committed = Arc::new(Mutex::new(Vec::new()));
// No compaction state and no usage history are supplied (both `None`); only
// the usage tracker and the recording committer are wired in.
let interceptor = PodInterceptor::new(
registry,
None,
None,
NotifyBuffer::new(),
Arc::new(Mutex::new(Vec::new())),
PromptCatalog::builtins_only().unwrap(),
Some(Arc::new(RecordingSystemItemCommitter {
committed: Arc::clone(&committed),
})),
)
.with_usage_tracker(Arc::clone(&usage_tracker));
let ctx_items = vec![Item::user_message("hi")];
// Drive 23 requests that must all come back as plain `Continue`, recording a
// usage event after each one.
// NOTE(review): 23 presumably sits just below the task-reminder interval —
// confirm against `TaskFeature`.
for _ in 0..23 {
let mut ctx = ctx_items.clone();
let action = interceptor.pre_llm_request(&mut ctx).await;
assert!(matches!(action, PreRequestAction::Continue));
usage_tracker.record_usage(&llm_worker::event::UsageEvent {
input_tokens: Some(10),
output_tokens: Some(0),
total_tokens: Some(10),
cache_read_input_tokens: Some(0),
cache_creation_input_tokens: Some(0),
});
}
// The following request is expected to carry exactly one appended item.
let mut ctx = ctx_items.clone();
let action = interceptor.pre_llm_request(&mut ctx).await;
let appended_len = match action {
PreRequestAction::ContinueWith(items) => items.len(),
other => panic!("expected reminder append, got {other:?}"),
};
assert_eq!(appended_len, 1);
usage_tracker.record_usage(&llm_worker::event::UsageEvent {
input_tokens: Some(11),
output_tokens: Some(0),
total_tokens: Some(11),
cache_read_input_tokens: Some(0),
cache_creation_input_tokens: Some(0),
});
// One context item plus one appended item: the latest record should report a
// history length of 2.
let records = usage_tracker.records();
assert_eq!(records.last().expect("usage record").history_len, 2);
// The appended item must also have reached the committer, as a task reminder
// whose body mentions the seeded task subject.
let committed = committed
.lock()
.expect("committed system-item list poisoned");
assert_eq!(committed.len(), 1);
let SystemItem::TaskReminder { body, .. } = &committed[0] else {
panic!("expected task reminder, got {:?}", committed[0]);
};
assert!(body.contains("track active work"));
}
#[tokio::test]
async fn pending_history_appends_drains_buffer_into_items() {
let registry = Arc::new(HookRegistryBuilder::new().build());

View File

@ -31,8 +31,8 @@ use crate::compact::usage_tracker::UsageTracker;
use crate::feature::builtin::TaskFeature;
use crate::feature::{FeatureRegistryBuilder, FeatureRegistryInstallReport};
use crate::hook::{
Hook, HookPreRequestAction, HookRegistryBuilder, OnAbort, OnPromptSubmit, OnTurnEnd,
PostToolCall, PreLlmRequest, PreRequestContext, PreToolCall,
Hook, HookRegistryBuilder, OnAbort, OnPromptSubmit, OnTurnEnd, PostToolCall, PreLlmRequest,
PreToolCall,
};
use crate::ipc::alerter::Alerter;
use crate::ipc::interceptor::PodInterceptor;
@ -44,6 +44,7 @@ use crate::prompt::system::{SystemPromptContext, SystemPromptError, SystemPrompt
use crate::runtime::dir;
use crate::runtime::pod_registry::{self, ScopeAllocationGuard, ScopeLockError};
use crate::workflow::WorkflowResolveError;
#[cfg(test)]
use async_trait::async_trait;
use protocol::{
AlertLevel, AlertSource, Event, RewindSummary, RewindTarget, RewindTargetId, Segment,
@ -213,21 +214,6 @@ where
}
}
/// Pre-LLM-request hook that records `history.len()` at send time into a
/// shared `UsageTracker`. The on_usage callback later pairs this with the
/// aggregated UsageEvent to produce one `UsageRecord` per LLM call.
struct UsageTrackingHook {
/// Shared tracker that is told the item count of each outgoing request.
tracker: Arc<UsageTracker>,
}
#[async_trait]
impl Hook<PreLlmRequest> for UsageTrackingHook {
/// Passes `info.item_count` to the tracker's `note_request` and always
/// returns `HookPreRequestAction::Continue`.
async fn call(&self, info: &PreRequestContext) -> HookPreRequestAction {
self.tracker.note_request(info.item_count);
HookPreRequestAction::Continue
}
}
/// An independent agent execution unit.
///
/// Holds a [`Worker`] directly and persists session state via
@ -1167,13 +1153,6 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
/// occupancy through the `UsageRecord` timeline.
fn ensure_interceptor_installed(&mut self) {
if !self.interceptor_installed {
// Pre-LLM-request hook: record the item count at send time
// so the on_usage callback can pair it with the measured
// input_tokens.
self.hook_builder.add_pre_llm_request(UsageTrackingHook {
tracker: self.usage_tracker.clone(),
});
let builder = std::mem::take(&mut self.hook_builder);
let registry = Arc::new(builder.build());