fix: account hook system item appends
This commit is contained in:
parent
c9cb2edc7e
commit
960f2a305e
|
|
@ -42,6 +42,12 @@ pub enum PreRequestAction {
|
||||||
/// to: the items are committed before the request so later turns can see
|
/// to: the items are committed before the request so later turns can see
|
||||||
/// why the worker changed course.
|
/// why the worker changed course.
|
||||||
ContinueWith(Vec<Item>),
|
ContinueWith(Vec<Item>),
|
||||||
|
/// Yield after appending these items to durable worker history.
|
||||||
|
///
|
||||||
|
/// This is for host-mediated pre-request appends that must be visible to
|
||||||
|
/// usage accounting and compaction checks before the current LLM request is
|
||||||
|
/// allowed to proceed.
|
||||||
|
YieldWith(Vec<Item>),
|
||||||
/// Cancel with a reason (treated as an error).
|
/// Cancel with a reason (treated as an error).
|
||||||
Cancel(String),
|
Cancel(String),
|
||||||
/// Yield control to the caller for external processing.
|
/// Yield control to the caller for external processing.
|
||||||
|
|
|
||||||
|
|
@ -1177,6 +1177,16 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
self.last_run_interrupted = true;
|
self.last_run_interrupted = true;
|
||||||
return Err(WorkerError::Aborted(reason));
|
return Err(WorkerError::Aborted(reason));
|
||||||
}
|
}
|
||||||
|
PreRequestAction::YieldWith(items) => {
|
||||||
|
self.append_history_items(items.clone());
|
||||||
|
request_context.extend(items);
|
||||||
|
info!("Yielded by interceptor after pre-request history append");
|
||||||
|
for cb in &self.turn_end_cbs {
|
||||||
|
cb(current_turn);
|
||||||
|
}
|
||||||
|
self.last_run_interrupted = true;
|
||||||
|
return Ok(WorkerResult::Yielded);
|
||||||
|
}
|
||||||
PreRequestAction::Yield => {
|
PreRequestAction::Yield => {
|
||||||
info!("Yielded by interceptor");
|
info!("Yielded by interceptor");
|
||||||
for cb in &self.turn_end_cbs {
|
for cb in &self.turn_end_cbs {
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@
|
||||||
//! event-specific read-only contexts and only return control-flow
|
//! event-specific read-only contexts and only return control-flow
|
||||||
//! decisions (continue / skip / abort / pause).
|
//! decisions (continue / skip / abort / pause).
|
||||||
|
|
||||||
|
use std::borrow::Cow;
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
|
@ -137,6 +138,29 @@ impl PodInterceptor {
|
||||||
}
|
}
|
||||||
Some(total_tokens(context, &records).tokens)
|
Some(total_tokens(context, &records).tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn request_threshold_exceeded(&self, current_tokens: Option<u64>, context: &[Item]) -> bool {
|
||||||
|
if let Some(state) = self.compact_state.as_ref() {
|
||||||
|
if !state.is_disabled() && !state.just_compacted() {
|
||||||
|
let current = current_tokens.unwrap_or(0);
|
||||||
|
if state.exceeds_request(current) {
|
||||||
|
let shape = context_shape(context);
|
||||||
|
info!(
|
||||||
|
input_tokens = current,
|
||||||
|
threshold = state.request_threshold().unwrap_or(0),
|
||||||
|
items_len = shape.items_len,
|
||||||
|
items_json_bytes = shape.items_json_bytes,
|
||||||
|
reasoning_items = shape.reasoning_items,
|
||||||
|
reasoning_encrypted_content_count = shape.reasoning_encrypted_content_count,
|
||||||
|
reasoning_encrypted_content_bytes = shape.reasoning_encrypted_content_bytes,
|
||||||
|
"Between-requests compaction threshold exceeded, yielding"
|
||||||
|
);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
@ -210,32 +234,13 @@ impl Interceptor for PodInterceptor {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn pre_llm_request(&self, context: &mut Vec<Item>) -> PreRequestAction {
|
async fn pre_llm_request(&self, context: &mut Vec<Item>) -> PreRequestAction {
|
||||||
let current_tokens = self.estimated_tokens(context);
|
let initial_tokens = self.estimated_tokens(context);
|
||||||
|
if self.request_threshold_exceeded(initial_tokens, context) {
|
||||||
// Internal mechanism: between-requests compaction trigger (safety net).
|
return PreRequestAction::Yield;
|
||||||
if let Some(state) = self.compact_state.as_ref() {
|
|
||||||
if !state.is_disabled() && !state.just_compacted() {
|
|
||||||
let current = current_tokens.unwrap_or(0);
|
|
||||||
if state.exceeds_request(current) {
|
|
||||||
let shape = context_shape(context);
|
|
||||||
info!(
|
|
||||||
input_tokens = current,
|
|
||||||
threshold = state.request_threshold().unwrap_or(0),
|
|
||||||
items_len = shape.items_len,
|
|
||||||
items_json_bytes = shape.items_json_bytes,
|
|
||||||
reasoning_items = shape.reasoning_items,
|
|
||||||
reasoning_encrypted_content_count = shape.reasoning_encrypted_content_count,
|
|
||||||
reasoning_encrypted_content_bytes = shape.reasoning_encrypted_content_bytes,
|
|
||||||
"Between-requests compaction threshold exceeded, yielding"
|
|
||||||
);
|
|
||||||
return PreRequestAction::Yield;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let info = PreRequestInfo {
|
let info = PreRequestInfo {
|
||||||
item_count: context.len(),
|
item_count: context.len(),
|
||||||
estimated_tokens: current_tokens,
|
estimated_tokens: initial_tokens,
|
||||||
turn_index: self.current_turn_index(),
|
turn_index: self.current_turn_index(),
|
||||||
tool_calls_this_turn: self.tool_calls_this_turn.load(Ordering::Relaxed),
|
tool_calls_this_turn: self.tool_calls_this_turn.load(Ordering::Relaxed),
|
||||||
};
|
};
|
||||||
|
|
@ -257,16 +262,36 @@ impl Interceptor for PodInterceptor {
|
||||||
.lock()
|
.lock()
|
||||||
.expect("pending hook system-item queue poisoned"),
|
.expect("pending hook system-item queue poisoned"),
|
||||||
);
|
);
|
||||||
|
let appended_items: Vec<Item> = system_items
|
||||||
|
.iter()
|
||||||
|
.map(SystemItem::to_history_item)
|
||||||
|
.collect();
|
||||||
|
let effective_context = if appended_items.is_empty() {
|
||||||
|
Cow::Borrowed(context.as_slice())
|
||||||
|
} else {
|
||||||
|
let mut effective = context.clone();
|
||||||
|
effective.extend(appended_items.clone());
|
||||||
|
Cow::Owned(effective)
|
||||||
|
};
|
||||||
|
let current_tokens = self.estimated_tokens(effective_context.as_ref());
|
||||||
|
|
||||||
|
if self.request_threshold_exceeded(current_tokens, effective_context.as_ref()) {
|
||||||
|
self.commit_system_items(&system_items);
|
||||||
|
return if appended_items.is_empty() {
|
||||||
|
PreRequestAction::Yield
|
||||||
|
} else {
|
||||||
|
PreRequestAction::YieldWith(appended_items)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(usage_tracker) = self.usage_tracker.as_ref() {
|
||||||
|
usage_tracker.note_request(effective_context.len());
|
||||||
|
}
|
||||||
if system_items.is_empty() {
|
if system_items.is_empty() {
|
||||||
return PreRequestAction::Continue;
|
return PreRequestAction::Continue;
|
||||||
}
|
}
|
||||||
self.commit_system_items(&system_items);
|
self.commit_system_items(&system_items);
|
||||||
PreRequestAction::ContinueWith(
|
PreRequestAction::ContinueWith(appended_items)
|
||||||
system_items
|
|
||||||
.into_iter()
|
|
||||||
.map(|item| item.to_history_item())
|
|
||||||
.collect(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn pre_tool_call(&self, info: &mut ToolCallInfo) -> PreToolAction {
|
async fn pre_tool_call(&self, info: &mut ToolCallInfo) -> PreToolAction {
|
||||||
|
|
@ -396,6 +421,8 @@ mod tests {
|
||||||
use std::sync::atomic::{AtomicBool, AtomicUsize};
|
use std::sync::atomic::{AtomicBool, AtomicUsize};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::feature::FeatureRegistryBuilder;
|
||||||
|
use crate::feature::builtin::TaskFeature;
|
||||||
use crate::hook::{
|
use crate::hook::{
|
||||||
Hook, HookPostToolAction, HookPreRequestAction, HookPreToolAction, HookRegistryBuilder,
|
Hook, HookPostToolAction, HookPreRequestAction, HookPreToolAction, HookRegistryBuilder,
|
||||||
HookTurnEndAction, OnTurnEnd, PostToolCall, PreLlmRequest, PreToolCall,
|
HookTurnEndAction, OnTurnEnd, PostToolCall, PreLlmRequest, PreToolCall,
|
||||||
|
|
@ -504,6 +531,41 @@ mod tests {
|
||||||
assert_eq!(count.load(Ordering::Relaxed), 0);
|
assert_eq!(count.load(Ordering::Relaxed), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn pre_llm_request_yields_with_hook_appends_when_post_append_threshold_exceeded() {
|
||||||
|
let saw_handle = Arc::new(AtomicBool::new(false));
|
||||||
|
let mut builder = HookRegistryBuilder::new();
|
||||||
|
builder.add_pre_llm_request(AppendingPreRequestHook {
|
||||||
|
saw_handle: Arc::clone(&saw_handle),
|
||||||
|
});
|
||||||
|
let registry = Arc::new(builder.build());
|
||||||
|
let state = Arc::new(CompactState::new(None, Some(50), 2));
|
||||||
|
let ctx_items = vec![Item::user_message("hi")];
|
||||||
|
let history = usage_handle_with(ctx_items.len(), 50);
|
||||||
|
let committed = Arc::new(Mutex::new(Vec::new()));
|
||||||
|
|
||||||
|
let interceptor = PodInterceptor::new(
|
||||||
|
registry,
|
||||||
|
Some(state),
|
||||||
|
Some(history),
|
||||||
|
NotifyBuffer::new(),
|
||||||
|
Arc::new(Mutex::new(Vec::new())),
|
||||||
|
PromptCatalog::builtins_only().unwrap(),
|
||||||
|
Some(Arc::new(RecordingSystemItemCommitter {
|
||||||
|
committed: Arc::clone(&committed),
|
||||||
|
})),
|
||||||
|
);
|
||||||
|
let mut ctx = ctx_items;
|
||||||
|
let action = interceptor.pre_llm_request(&mut ctx).await;
|
||||||
|
|
||||||
|
match action {
|
||||||
|
PreRequestAction::YieldWith(items) => assert_eq!(items.len(), 1),
|
||||||
|
other => panic!("expected YieldWith queued system item, got {other:?}"),
|
||||||
|
}
|
||||||
|
assert!(saw_handle.load(Ordering::Relaxed));
|
||||||
|
assert_eq!(committed.lock().expect("committed system items").len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn pre_llm_request_counts_in_flight_usage_records() {
|
async fn pre_llm_request_counts_in_flight_usage_records() {
|
||||||
let registry = Arc::new(HookRegistryBuilder::new().build());
|
let registry = Arc::new(HookRegistryBuilder::new().build());
|
||||||
|
|
@ -878,6 +940,75 @@ mod tests {
|
||||||
assert_eq!(count.load(Ordering::Relaxed), 1);
|
assert_eq!(count.load(Ordering::Relaxed), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn task_reminder_hook_append_is_counted_in_usage_request_len() {
|
||||||
|
let feature = TaskFeature::from_history(&[Item::tool_call(
|
||||||
|
"task-create-call",
|
||||||
|
"TaskCreate",
|
||||||
|
r#"{"subject":"track active work","description":"exercise reminder path"}"#,
|
||||||
|
)]);
|
||||||
|
let mut hook_builder = HookRegistryBuilder::new();
|
||||||
|
let mut pending_tools = Vec::new();
|
||||||
|
FeatureRegistryBuilder::new()
|
||||||
|
.with_module(feature)
|
||||||
|
.install_into_pending(&mut pending_tools, &mut hook_builder);
|
||||||
|
let registry = Arc::new(hook_builder.build());
|
||||||
|
let usage_tracker = Arc::new(UsageTracker::new());
|
||||||
|
let committed = Arc::new(Mutex::new(Vec::new()));
|
||||||
|
let interceptor = PodInterceptor::new(
|
||||||
|
registry,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
NotifyBuffer::new(),
|
||||||
|
Arc::new(Mutex::new(Vec::new())),
|
||||||
|
PromptCatalog::builtins_only().unwrap(),
|
||||||
|
Some(Arc::new(RecordingSystemItemCommitter {
|
||||||
|
committed: Arc::clone(&committed),
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.with_usage_tracker(Arc::clone(&usage_tracker));
|
||||||
|
|
||||||
|
let ctx_items = vec![Item::user_message("hi")];
|
||||||
|
for _ in 0..23 {
|
||||||
|
let mut ctx = ctx_items.clone();
|
||||||
|
let action = interceptor.pre_llm_request(&mut ctx).await;
|
||||||
|
assert!(matches!(action, PreRequestAction::Continue));
|
||||||
|
usage_tracker.record_usage(&llm_worker::event::UsageEvent {
|
||||||
|
input_tokens: Some(10),
|
||||||
|
output_tokens: Some(0),
|
||||||
|
total_tokens: Some(10),
|
||||||
|
cache_read_input_tokens: Some(0),
|
||||||
|
cache_creation_input_tokens: Some(0),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut ctx = ctx_items.clone();
|
||||||
|
let action = interceptor.pre_llm_request(&mut ctx).await;
|
||||||
|
let appended_len = match action {
|
||||||
|
PreRequestAction::ContinueWith(items) => items.len(),
|
||||||
|
other => panic!("expected reminder append, got {other:?}"),
|
||||||
|
};
|
||||||
|
assert_eq!(appended_len, 1);
|
||||||
|
usage_tracker.record_usage(&llm_worker::event::UsageEvent {
|
||||||
|
input_tokens: Some(11),
|
||||||
|
output_tokens: Some(0),
|
||||||
|
total_tokens: Some(11),
|
||||||
|
cache_read_input_tokens: Some(0),
|
||||||
|
cache_creation_input_tokens: Some(0),
|
||||||
|
});
|
||||||
|
|
||||||
|
let records = usage_tracker.records();
|
||||||
|
assert_eq!(records.last().expect("usage record").history_len, 2);
|
||||||
|
let committed = committed
|
||||||
|
.lock()
|
||||||
|
.expect("committed system-item list poisoned");
|
||||||
|
assert_eq!(committed.len(), 1);
|
||||||
|
let SystemItem::TaskReminder { body, .. } = &committed[0] else {
|
||||||
|
panic!("expected task reminder, got {:?}", committed[0]);
|
||||||
|
};
|
||||||
|
assert!(body.contains("track active work"));
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn pending_history_appends_drains_buffer_into_items() {
|
async fn pending_history_appends_drains_buffer_into_items() {
|
||||||
let registry = Arc::new(HookRegistryBuilder::new().build());
|
let registry = Arc::new(HookRegistryBuilder::new().build());
|
||||||
|
|
|
||||||
|
|
@ -31,8 +31,8 @@ use crate::compact::usage_tracker::UsageTracker;
|
||||||
use crate::feature::builtin::TaskFeature;
|
use crate::feature::builtin::TaskFeature;
|
||||||
use crate::feature::{FeatureRegistryBuilder, FeatureRegistryInstallReport};
|
use crate::feature::{FeatureRegistryBuilder, FeatureRegistryInstallReport};
|
||||||
use crate::hook::{
|
use crate::hook::{
|
||||||
Hook, HookPreRequestAction, HookRegistryBuilder, OnAbort, OnPromptSubmit, OnTurnEnd,
|
Hook, HookRegistryBuilder, OnAbort, OnPromptSubmit, OnTurnEnd, PostToolCall, PreLlmRequest,
|
||||||
PostToolCall, PreLlmRequest, PreRequestContext, PreToolCall,
|
PreToolCall,
|
||||||
};
|
};
|
||||||
use crate::ipc::alerter::Alerter;
|
use crate::ipc::alerter::Alerter;
|
||||||
use crate::ipc::interceptor::PodInterceptor;
|
use crate::ipc::interceptor::PodInterceptor;
|
||||||
|
|
@ -44,6 +44,7 @@ use crate::prompt::system::{SystemPromptContext, SystemPromptError, SystemPrompt
|
||||||
use crate::runtime::dir;
|
use crate::runtime::dir;
|
||||||
use crate::runtime::pod_registry::{self, ScopeAllocationGuard, ScopeLockError};
|
use crate::runtime::pod_registry::{self, ScopeAllocationGuard, ScopeLockError};
|
||||||
use crate::workflow::WorkflowResolveError;
|
use crate::workflow::WorkflowResolveError;
|
||||||
|
#[cfg(test)]
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use protocol::{
|
use protocol::{
|
||||||
AlertLevel, AlertSource, Event, RewindSummary, RewindTarget, RewindTargetId, Segment,
|
AlertLevel, AlertSource, Event, RewindSummary, RewindTarget, RewindTargetId, Segment,
|
||||||
|
|
@ -213,21 +214,6 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Pre-LLM-request hook that records `history.len()` at send time into a
|
|
||||||
/// shared `UsageTracker`. The on_usage callback later pairs this with the
|
|
||||||
/// aggregated UsageEvent to produce one `UsageRecord` per LLM call.
|
|
||||||
struct UsageTrackingHook {
|
|
||||||
tracker: Arc<UsageTracker>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Hook<PreLlmRequest> for UsageTrackingHook {
|
|
||||||
async fn call(&self, info: &PreRequestContext) -> HookPreRequestAction {
|
|
||||||
self.tracker.note_request(info.item_count);
|
|
||||||
HookPreRequestAction::Continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// An independent agent execution unit.
|
/// An independent agent execution unit.
|
||||||
///
|
///
|
||||||
/// Holds a [`Worker`] directly and persists session state via
|
/// Holds a [`Worker`] directly and persists session state via
|
||||||
|
|
@ -1167,13 +1153,6 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
||||||
/// occupancy through the `UsageRecord` timeline.
|
/// occupancy through the `UsageRecord` timeline.
|
||||||
fn ensure_interceptor_installed(&mut self) {
|
fn ensure_interceptor_installed(&mut self) {
|
||||||
if !self.interceptor_installed {
|
if !self.interceptor_installed {
|
||||||
// Pre-LLM-request hook: record the item count at send time
|
|
||||||
// so the on_usage callback can pair it with the measured
|
|
||||||
// input_tokens.
|
|
||||||
self.hook_builder.add_pre_llm_request(UsageTrackingHook {
|
|
||||||
tracker: self.usage_tracker.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
let builder = std::mem::take(&mut self.hook_builder);
|
let builder = std::mem::take(&mut self.hook_builder);
|
||||||
let registry = Arc::new(builder.build());
|
let registry = Arc::new(builder.build());
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user