From 960f2a305eb5f66d9a2650d79cd3a94eff9af4f6 Mon Sep 17 00:00:00 2001 From: Hare Date: Fri, 5 Jun 2026 11:17:16 +0900 Subject: [PATCH] fix: account hook system item appends --- crates/llm-worker/src/interceptor.rs | 6 + crates/llm-worker/src/worker.rs | 10 ++ crates/pod/src/ipc/interceptor.rs | 189 +++++++++++++++++++++++---- crates/pod/src/pod.rs | 27 +--- 4 files changed, 179 insertions(+), 53 deletions(-) diff --git a/crates/llm-worker/src/interceptor.rs b/crates/llm-worker/src/interceptor.rs index 1c058fb4..5b4b2e25 100644 --- a/crates/llm-worker/src/interceptor.rs +++ b/crates/llm-worker/src/interceptor.rs @@ -42,6 +42,12 @@ pub enum PreRequestAction { /// to: the items are committed before the request so later turns can see /// why the worker changed course. ContinueWith(Vec), + /// Yield after appending these items to durable worker history. + /// + /// This is for host-mediated pre-request appends that must be visible to + /// usage accounting and compaction checks before the current LLM request is + /// allowed to proceed. + YieldWith(Vec), /// Cancel with a reason (treated as an error). Cancel(String), /// Yield control to the caller for external processing. diff --git a/crates/llm-worker/src/worker.rs b/crates/llm-worker/src/worker.rs index d3e95fed..7c0ccc80 100644 --- a/crates/llm-worker/src/worker.rs +++ b/crates/llm-worker/src/worker.rs @@ -1177,6 +1177,16 @@ impl Worker { self.last_run_interrupted = true; return Err(WorkerError::Aborted(reason)); } + PreRequestAction::YieldWith(items) => { + self.append_history_items(items.clone()); + request_context.extend(items); + info!("Yielded by interceptor after pre-request history append"); + for cb in &self.turn_end_cbs { + cb(current_turn); + } + self.last_run_interrupted = true; + return Ok(WorkerResult::Yielded); + } PreRequestAction::Yield => { info!("Yielded by interceptor"); for cb in &self.turn_end_cbs { diff --git a/crates/pod/src/ipc/interceptor.rs b/crates/pod/src/ipc/interceptor.rs index 59669edb..74ba4280 100644 --- a/crates/pod/src/ipc/interceptor.rs +++ b/crates/pod/src/ipc/interceptor.rs @@ -7,6 +7,7 @@ //! event-specific read-only contexts and only return control-flow //! decisions (continue / skip / abort / pause). +use std::borrow::Cow; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Mutex}; @@ -137,6 +138,29 @@ impl PodInterceptor { } Some(total_tokens(context, &records).tokens) } + + fn request_threshold_exceeded(&self, current_tokens: Option, context: &[Item]) -> bool { + if let Some(state) = self.compact_state.as_ref() { + if !state.is_disabled() && !state.just_compacted() { + let current = current_tokens.unwrap_or(0); + if state.exceeds_request(current) { + let shape = context_shape(context); + info!( + input_tokens = current, + threshold = state.request_threshold().unwrap_or(0), + items_len = shape.items_len, + items_json_bytes = shape.items_json_bytes, + reasoning_items = shape.reasoning_items, + reasoning_encrypted_content_count = shape.reasoning_encrypted_content_count, + reasoning_encrypted_content_bytes = shape.reasoning_encrypted_content_bytes, + "Between-requests compaction threshold exceeded, yielding" + ); + return true; + } + } + } + false + } } #[async_trait] @@ -210,32 +234,13 @@ impl Interceptor for PodInterceptor { } async fn pre_llm_request(&self, context: &mut Vec) -> PreRequestAction { - let current_tokens = self.estimated_tokens(context); - - // Internal mechanism: between-requests compaction trigger (safety net). - if let Some(state) = self.compact_state.as_ref() { - if !state.is_disabled() && !state.just_compacted() { - let current = current_tokens.unwrap_or(0); - if state.exceeds_request(current) { - let shape = context_shape(context); - info!( - input_tokens = current, - threshold = state.request_threshold().unwrap_or(0), - items_len = shape.items_len, - items_json_bytes = shape.items_json_bytes, - reasoning_items = shape.reasoning_items, - reasoning_encrypted_content_count = shape.reasoning_encrypted_content_count, - reasoning_encrypted_content_bytes = shape.reasoning_encrypted_content_bytes, - "Between-requests compaction threshold exceeded, yielding" - ); - return PreRequestAction::Yield; - } - } + let initial_tokens = self.estimated_tokens(context); + if self.request_threshold_exceeded(initial_tokens, context) { + return PreRequestAction::Yield; } - let info = PreRequestInfo { item_count: context.len(), - estimated_tokens: current_tokens, + estimated_tokens: initial_tokens, turn_index: self.current_turn_index(), tool_calls_this_turn: self.tool_calls_this_turn.load(Ordering::Relaxed), }; @@ -257,16 +262,36 @@ impl Interceptor for PodInterceptor { .lock() .expect("pending hook system-item queue poisoned"), ); + let appended_items: Vec = system_items + .iter() + .map(SystemItem::to_history_item) + .collect(); + let effective_context = if appended_items.is_empty() { + Cow::Borrowed(context.as_slice()) + } else { + let mut effective = context.clone(); + effective.extend(appended_items.clone()); + Cow::Owned(effective) + }; + let current_tokens = self.estimated_tokens(effective_context.as_ref()); + + if self.request_threshold_exceeded(current_tokens, effective_context.as_ref()) { + self.commit_system_items(&system_items); + return if appended_items.is_empty() { + PreRequestAction::Yield + } else { + PreRequestAction::YieldWith(appended_items) + }; + } + + if let Some(usage_tracker) = self.usage_tracker.as_ref() { + usage_tracker.note_request(effective_context.len()); + } if system_items.is_empty() { return PreRequestAction::Continue; } self.commit_system_items(&system_items); - PreRequestAction::ContinueWith( - system_items - .into_iter() - .map(|item| item.to_history_item()) - .collect(), - ) + PreRequestAction::ContinueWith(appended_items) } async fn pre_tool_call(&self, info: &mut ToolCallInfo) -> PreToolAction { @@ -396,6 +421,8 @@ mod tests { use std::sync::atomic::{AtomicBool, AtomicUsize}; use super::*; + use crate::feature::FeatureRegistryBuilder; + use crate::feature::builtin::TaskFeature; use crate::hook::{ Hook, HookPostToolAction, HookPreRequestAction, HookPreToolAction, HookRegistryBuilder, HookTurnEndAction, OnTurnEnd, PostToolCall, PreLlmRequest, PreToolCall, @@ -504,6 +531,41 @@ mod tests { assert_eq!(count.load(Ordering::Relaxed), 0); } + #[tokio::test] + async fn pre_llm_request_yields_with_hook_appends_when_post_append_threshold_exceeded() { + let saw_handle = Arc::new(AtomicBool::new(false)); + let mut builder = HookRegistryBuilder::new(); + builder.add_pre_llm_request(AppendingPreRequestHook { + saw_handle: Arc::clone(&saw_handle), + }); + let registry = Arc::new(builder.build()); + let state = Arc::new(CompactState::new(None, Some(50), 2)); + let ctx_items = vec![Item::user_message("hi")]; + let history = usage_handle_with(ctx_items.len(), 50); + let committed = Arc::new(Mutex::new(Vec::new())); + + let interceptor = PodInterceptor::new( + registry, + Some(state), + Some(history), + NotifyBuffer::new(), + Arc::new(Mutex::new(Vec::new())), + PromptCatalog::builtins_only().unwrap(), + Some(Arc::new(RecordingSystemItemCommitter { + committed: Arc::clone(&committed), + })), + ); + let mut ctx = ctx_items; + let action = interceptor.pre_llm_request(&mut ctx).await; + + match action { + PreRequestAction::YieldWith(items) => assert_eq!(items.len(), 1), + other => panic!("expected YieldWith queued system item, got {other:?}"), + } + assert!(saw_handle.load(Ordering::Relaxed)); + assert_eq!(committed.lock().expect("committed system items").len(), 1); + } + #[tokio::test] async fn pre_llm_request_counts_in_flight_usage_records() { let registry = Arc::new(HookRegistryBuilder::new().build()); @@ -878,6 +940,75 @@ mod tests { assert_eq!(count.load(Ordering::Relaxed), 1); } + #[tokio::test] + async fn task_reminder_hook_append_is_counted_in_usage_request_len() { + let feature = TaskFeature::from_history(&[Item::tool_call( + "task-create-call", + "TaskCreate", + r#"{"subject":"track active work","description":"exercise reminder path"}"#, + )]); + let mut hook_builder = HookRegistryBuilder::new(); + let mut pending_tools = Vec::new(); + FeatureRegistryBuilder::new() + .with_module(feature) + .install_into_pending(&mut pending_tools, &mut hook_builder); + let registry = Arc::new(hook_builder.build()); + let usage_tracker = Arc::new(UsageTracker::new()); + let committed = Arc::new(Mutex::new(Vec::new())); + let interceptor = PodInterceptor::new( + registry, + None, + None, + NotifyBuffer::new(), + Arc::new(Mutex::new(Vec::new())), + PromptCatalog::builtins_only().unwrap(), + Some(Arc::new(RecordingSystemItemCommitter { + committed: Arc::clone(&committed), + })), + ) + .with_usage_tracker(Arc::clone(&usage_tracker)); + + let ctx_items = vec![Item::user_message("hi")]; + for _ in 0..23 { + let mut ctx = ctx_items.clone(); + let action = interceptor.pre_llm_request(&mut ctx).await; + assert!(matches!(action, PreRequestAction::Continue)); + usage_tracker.record_usage(&llm_worker::event::UsageEvent { + input_tokens: Some(10), + output_tokens: Some(0), + total_tokens: Some(10), + cache_read_input_tokens: Some(0), + cache_creation_input_tokens: Some(0), + }); + } + + let mut ctx = ctx_items.clone(); + let action = interceptor.pre_llm_request(&mut ctx).await; + let appended_len = match action { + PreRequestAction::ContinueWith(items) => items.len(), + other => panic!("expected reminder append, got {other:?}"), + }; + assert_eq!(appended_len, 1); + usage_tracker.record_usage(&llm_worker::event::UsageEvent { + input_tokens: Some(11), + output_tokens: Some(0), + total_tokens: Some(11), + cache_read_input_tokens: Some(0), + cache_creation_input_tokens: Some(0), + }); + + let records = usage_tracker.records(); + assert_eq!(records.last().expect("usage record").history_len, 2); + let committed = committed + .lock() + .expect("committed system-item list poisoned"); + assert_eq!(committed.len(), 1); + let SystemItem::TaskReminder { body, .. } = &committed[0] else { + panic!("expected task reminder, got {:?}", committed[0]); + }; + assert!(body.contains("track active work")); + } + #[tokio::test] async fn pending_history_appends_drains_buffer_into_items() { let registry = Arc::new(HookRegistryBuilder::new().build()); diff --git a/crates/pod/src/pod.rs b/crates/pod/src/pod.rs index de231c73..8a741e83 100644 --- a/crates/pod/src/pod.rs +++ b/crates/pod/src/pod.rs @@ -31,8 +31,8 @@ use crate::compact::usage_tracker::UsageTracker; use crate::feature::builtin::TaskFeature; use crate::feature::{FeatureRegistryBuilder, FeatureRegistryInstallReport}; use crate::hook::{ - Hook, HookPreRequestAction, HookRegistryBuilder, OnAbort, OnPromptSubmit, OnTurnEnd, - PostToolCall, PreLlmRequest, PreRequestContext, PreToolCall, + Hook, HookRegistryBuilder, OnAbort, OnPromptSubmit, OnTurnEnd, PostToolCall, PreLlmRequest, + PreToolCall, }; use crate::ipc::alerter::Alerter; use crate::ipc::interceptor::PodInterceptor; @@ -44,6 +44,7 @@ use crate::prompt::system::{SystemPromptContext, SystemPromptError, SystemPrompt use crate::runtime::dir; use crate::runtime::pod_registry::{self, ScopeAllocationGuard, ScopeLockError}; use crate::workflow::WorkflowResolveError; +#[cfg(test)] use async_trait::async_trait; use protocol::{ AlertLevel, AlertSource, Event, RewindSummary, RewindTarget, RewindTargetId, Segment, @@ -213,21 +214,6 @@ where } } -/// Pre-LLM-request hook that records `history.len()` at send time into a -/// shared `UsageTracker`. The on_usage callback later pairs this with the -/// aggregated UsageEvent to produce one `UsageRecord` per LLM call. -struct UsageTrackingHook { - tracker: Arc, -} - -#[async_trait] -impl Hook for UsageTrackingHook { - async fn call(&self, info: &PreRequestContext) -> HookPreRequestAction { - self.tracker.note_request(info.item_count); - HookPreRequestAction::Continue - } -} - /// An independent agent execution unit. /// /// Holds a [`Worker`] directly and persists session state via @@ -1167,13 +1153,6 @@ impl Pod { /// occupancy through the `UsageRecord` timeline. fn ensure_interceptor_installed(&mut self) { if !self.interceptor_installed { - // Pre-LLM-request hook: record the item count at send time - // so the on_usage callback can pair it with the measured - // input_tokens. - self.hook_builder.add_pre_llm_request(UsageTrackingHook { - tracker: self.usage_tracker.clone(), - }); - let builder = std::mem::take(&mut self.hook_builder); let registry = Arc::new(builder.build());