493 lines
18 KiB
Rust
493 lines
18 KiB
Rust
//! Pod-owned `Interceptor` implementation.
//!
//! Bridges Pod's internal mechanisms (compaction trigger and
//! pending-notification injection today; output truncation in the
//! future) and the public `HookRegistry`. Internal mechanisms run first
//! and have full mutable access via the `Interceptor` trait. Hooks then
//! receive read-only summary information and only return control-flow
//! decisions (continue / skip / abort / pause).
|
|
|
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
use async_trait::async_trait;
|
|
use llm_worker::Item;
|
|
use llm_worker::UsageRecord;
|
|
use llm_worker::interceptor::{
|
|
Interceptor, PostToolAction, PreRequestAction, PreToolAction, PromptAction, ToolCallInfo,
|
|
ToolResultInfo, TurnEndAction,
|
|
};
|
|
use llm_worker::tool::ToolOutput;
|
|
use tracing::info;
|
|
|
|
use crate::compact::state::CompactState;
|
|
use crate::hook::{
|
|
AbortInfo, HookPromptAction, HookRegistry, PreRequestInfo, PromptSubmitInfo, ToolCallSummary,
|
|
ToolResultSummary, TurnEndInfo,
|
|
};
|
|
use crate::ipc::notify_buffer::{NotifyBuffer, format_notify};
|
|
use crate::prompt::catalog::PromptCatalog;
|
|
use llm_worker::token_counter::total_tokens;
|
|
use tracing::warn;
|
|
|
|
/// Upper bound, in bytes, on the text stored in
/// `TurnEndInfo::final_text_preview`.
///
/// `on_turn_end` passes this to `preview`, which truncates on a UTF-8
/// character boundary, so the stored preview may be a few bytes shorter
/// than this limit but never longer.
const FINAL_TEXT_PREVIEW_LIMIT: usize = 512;
|
|
|
|
pub(crate) struct PodInterceptor {
|
|
registry: Arc<HookRegistry>,
|
|
compact_state: Option<Arc<CompactState>>,
|
|
/// Shared view of the cumulative UsageRecord timeline. Used with the
|
|
/// per-request `context` to estimate current occupancy for threshold
|
|
/// checks. `None` when compaction is disabled (both thresholds unset).
|
|
usage_history: Option<Arc<Mutex<Vec<UsageRecord>>>>,
|
|
/// Pending-notification buffer drained into the per-request
|
|
/// context at the head of `pre_llm_request`.
|
|
pending_notifies: NotifyBuffer,
|
|
/// Submit-scoped stash of resolver-produced system messages.
|
|
/// Drained inside `on_prompt_submit` and returned via
|
|
/// `PromptAction::ContinueWith`. Populated by `Pod::run` immediately
|
|
/// before handing off to the worker.
|
|
pending_attachments: Arc<Mutex<Vec<Item>>>,
|
|
/// Prompt catalog used to render the injected notification wrapper.
|
|
prompts: Arc<PromptCatalog>,
|
|
/// Next turn index assigned by `on_prompt_submit`.
|
|
next_turn_index: AtomicUsize,
|
|
/// Tool calls observed in the current turn (reset on each new prompt).
|
|
tool_calls_this_turn: AtomicUsize,
|
|
}
|
|
|
|
impl PodInterceptor {
|
|
pub(crate) fn new(
|
|
registry: Arc<HookRegistry>,
|
|
compact_state: Option<Arc<CompactState>>,
|
|
usage_history: Option<Arc<Mutex<Vec<UsageRecord>>>>,
|
|
pending_notifies: NotifyBuffer,
|
|
pending_attachments: Arc<Mutex<Vec<Item>>>,
|
|
prompts: Arc<PromptCatalog>,
|
|
) -> Self {
|
|
Self {
|
|
registry,
|
|
compact_state,
|
|
usage_history,
|
|
pending_notifies,
|
|
pending_attachments,
|
|
prompts,
|
|
next_turn_index: AtomicUsize::new(0),
|
|
tool_calls_this_turn: AtomicUsize::new(0),
|
|
}
|
|
}
|
|
|
|
fn current_turn_index(&self) -> usize {
|
|
self.next_turn_index
|
|
.load(Ordering::Relaxed)
|
|
.saturating_sub(1)
|
|
}
|
|
|
|
/// Estimate current input-token occupancy for `context`, projected
|
|
/// through the shared UsageRecord timeline. Returns `None` when
|
|
/// `usage_history` is not attached (compaction fully disabled).
|
|
fn estimated_tokens(&self, context: &[Item]) -> Option<u64> {
|
|
let handle = self.usage_history.as_ref()?;
|
|
let records = handle.lock().expect("usage_history poisoned").clone();
|
|
Some(total_tokens(context, &records).tokens)
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Interceptor for PodInterceptor {
|
|
async fn on_prompt_submit(&self, item: &mut Item) -> PromptAction {
|
|
let turn_index = self.next_turn_index.fetch_add(1, Ordering::Relaxed);
|
|
self.tool_calls_this_turn.store(0, Ordering::Relaxed);
|
|
|
|
let info = PromptSubmitInfo {
|
|
input_text: extract_message_text(item).unwrap_or_default(),
|
|
turn_index,
|
|
};
|
|
for hook in &self.registry.on_prompt_submit {
|
|
let action = hook.call(&info).await;
|
|
if !matches!(action, HookPromptAction::Continue) {
|
|
return action.into();
|
|
}
|
|
}
|
|
let extras = std::mem::take(
|
|
&mut *self
|
|
.pending_attachments
|
|
.lock()
|
|
.expect("pending_attachments poisoned"),
|
|
);
|
|
if extras.is_empty() {
|
|
PromptAction::Continue
|
|
} else {
|
|
PromptAction::ContinueWith(extras)
|
|
}
|
|
}
|
|
|
|
async fn pre_llm_request(&self, context: &mut Vec<Item>) -> PreRequestAction {
|
|
let current_tokens = self.estimated_tokens(context);
|
|
|
|
// Internal mechanism: between-requests compaction trigger (safety net).
|
|
if let Some(state) = self.compact_state.as_ref() {
|
|
if !state.is_disabled() && !state.just_compacted() {
|
|
let current = current_tokens.unwrap_or(0);
|
|
if state.exceeds_request(current) {
|
|
info!(
|
|
input_tokens = current,
|
|
threshold = state.request_threshold().unwrap_or(0),
|
|
"Between-requests compaction threshold exceeded, yielding"
|
|
);
|
|
return PreRequestAction::Yield;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Internal mechanism: drain pending `Method::Notify` notifications
|
|
// into the per-request context as transient system messages.
|
|
// These are not persisted to the Worker history; they exist only
|
|
// for this single LLM request.
|
|
for n in self.pending_notifies.drain() {
|
|
match format_notify(&n, &self.prompts) {
|
|
Ok(item) => context.push(item),
|
|
Err(e) => {
|
|
// A render failure here would starve the LLM of the
|
|
// notify text. Fall back to the raw message —
|
|
// it still carries the intent, just without the
|
|
// wrapper phrasing.
|
|
warn!(error = %e, "failed to render notify_wrapper; using raw message");
|
|
context.push(Item::system_message(n.message.clone()));
|
|
}
|
|
}
|
|
}
|
|
|
|
let info = PreRequestInfo {
|
|
item_count: context.len(),
|
|
estimated_tokens: current_tokens,
|
|
turn_index: self.current_turn_index(),
|
|
tool_calls_this_turn: self.tool_calls_this_turn.load(Ordering::Relaxed),
|
|
};
|
|
for hook in &self.registry.pre_llm_request {
|
|
let action = hook.call(&info).await;
|
|
if !matches!(action, PreRequestAction::Continue) {
|
|
return action;
|
|
}
|
|
}
|
|
PreRequestAction::Continue
|
|
}
|
|
|
|
async fn pre_tool_call(&self, info: &mut ToolCallInfo) -> PreToolAction {
|
|
let summary = ToolCallSummary {
|
|
call_id: info.call.id.clone(),
|
|
tool_name: info.call.name.clone(),
|
|
arguments: info.call.input.clone(),
|
|
};
|
|
for hook in &self.registry.pre_tool_call {
|
|
let action = hook.call(&summary).await;
|
|
if !matches!(action, PreToolAction::Continue) {
|
|
return action;
|
|
}
|
|
}
|
|
self.tool_calls_this_turn.fetch_add(1, Ordering::Relaxed);
|
|
PreToolAction::Continue
|
|
}
|
|
|
|
async fn post_tool_call(&self, info: &mut ToolResultInfo) -> PostToolAction {
|
|
let summary = ToolResultSummary {
|
|
call_id: info.result.tool_use_id.clone(),
|
|
tool_name: info.call.name.clone(),
|
|
is_error: info.result.is_error,
|
|
output: ToolOutput {
|
|
summary: info.result.summary.clone(),
|
|
content: info.result.content.clone(),
|
|
},
|
|
};
|
|
for hook in &self.registry.post_tool_call {
|
|
let action = hook.call(&summary).await;
|
|
if !matches!(action, PostToolAction::Continue) {
|
|
return action;
|
|
}
|
|
}
|
|
PostToolAction::Continue
|
|
}
|
|
|
|
async fn on_turn_end(&self, history: &[Item]) -> TurnEndAction {
|
|
let final_text_preview = history
|
|
.iter()
|
|
.rev()
|
|
.find(|i| i.is_assistant_message())
|
|
.and_then(extract_message_text)
|
|
.map(|t| preview(&t, FINAL_TEXT_PREVIEW_LIMIT))
|
|
.unwrap_or_default();
|
|
let info = TurnEndInfo {
|
|
turn_index: self.current_turn_index(),
|
|
tool_calls_count: self.tool_calls_this_turn.load(Ordering::Relaxed),
|
|
final_text_preview,
|
|
};
|
|
for hook in &self.registry.on_turn_end {
|
|
let action = hook.call(&info).await;
|
|
if !matches!(action, TurnEndAction::Finish) {
|
|
return action;
|
|
}
|
|
}
|
|
TurnEndAction::Finish
|
|
}
|
|
|
|
async fn on_abort(&self, reason: &str) {
|
|
let info = AbortInfo {
|
|
reason: reason.to_string(),
|
|
};
|
|
for hook in &self.registry.on_abort {
|
|
hook.call(&info).await;
|
|
}
|
|
}
|
|
}
|
|
|
|
fn extract_message_text(item: &Item) -> Option<String> {
|
|
match item {
|
|
Item::Message { content, .. } => Some(
|
|
content
|
|
.iter()
|
|
.map(|p| p.as_text())
|
|
.collect::<Vec<_>>()
|
|
.join(""),
|
|
),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
/// Returns at most `limit` bytes of `text` as an owned string.
///
/// Text that already fits is returned unchanged. Otherwise the cut
/// point is moved back from `limit` to the nearest UTF-8 character
/// boundary, so the result is always valid UTF-8 and may be shorter
/// than `limit` (down to empty).
fn preview(text: &str, limit: usize) -> String {
    if text.len() <= limit {
        return text.to_owned();
    }
    // `limit < text.len()` here, so every probed index is in range, and
    // index 0 is always a boundary, so the search cannot come up empty.
    let cut = (0..=limit)
        .rev()
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(0);
    text[..cut].to_owned()
}
|
|
|
|
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, AtomicUsize};

    use super::*;
    use crate::hook::{Hook, HookRegistryBuilder, PreLlmRequest};

    /// Hook that counts its invocations and always lets the request
    /// continue.
    struct CountingHook(Arc<AtomicUsize>);

    #[async_trait]
    impl Hook<PreLlmRequest> for CountingHook {
        async fn call(&self, _info: &PreRequestInfo) -> PreRequestAction {
            self.0.fetch_add(1, Ordering::Relaxed);
            PreRequestAction::Continue
        }
    }

    /// Hook that records that it ran and then cancels the request.
    struct AbortingHook(Arc<AtomicBool>);

    #[async_trait]
    impl Hook<PreLlmRequest> for AbortingHook {
        async fn call(&self, _info: &PreRequestInfo) -> PreRequestAction {
            self.0.store(true, Ordering::Relaxed);
            PreRequestAction::Cancel("nope".into())
        }
    }

    /// Registry holding a single `CountingHook` on `pre_llm_request`.
    fn registry_with_pre_llm_hook(counter: Arc<AtomicUsize>) -> Arc<HookRegistry> {
        let mut builder = HookRegistryBuilder::new();
        builder.add_pre_llm_request(CountingHook(counter));
        Arc::new(builder.build())
    }

    /// Registry with no hooks at all.
    fn empty_registry() -> Arc<HookRegistry> {
        Arc::new(HookRegistryBuilder::new().build())
    }

    /// Build a usage_history handle with a single record pinned at the
    /// current `context_len` so that `total_tokens` returns exactly
    /// `tokens` (Measured, no interpolation or byte-based fallback).
    fn usage_handle_with(context_len: usize, tokens: u64) -> Arc<Mutex<Vec<UsageRecord>>> {
        let record = UsageRecord {
            history_len: context_len,
            input_total_tokens: tokens,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
            output_tokens: 0,
        };
        Arc::new(Mutex::new(vec![record]))
    }

    /// Assembles a `PodInterceptor` with an empty attachment stash and
    /// the built-in prompt catalog. `compaction` supplies the compact
    /// state together with its usage timeline, or `None` for neither.
    fn interceptor_with(
        registry: Arc<HookRegistry>,
        compaction: Option<(CompactState, Arc<Mutex<Vec<UsageRecord>>>)>,
        notifies: NotifyBuffer,
    ) -> PodInterceptor {
        let (compact_state, usage_history) = match compaction {
            Some((state, usage)) => (Some(Arc::new(state)), Some(usage)),
            None => (None, None),
        };
        PodInterceptor::new(
            registry,
            compact_state,
            usage_history,
            notifies,
            Arc::new(Mutex::new(Vec::new())),
            PromptCatalog::builtins_only().unwrap(),
        )
    }

    #[tokio::test]
    async fn pre_llm_request_yields_and_skips_hooks_when_request_threshold_exceeded() {
        let hook_runs = Arc::new(AtomicUsize::new(0));
        let mut ctx = vec![Item::user_message("hi")];
        let usage = usage_handle_with(ctx.len(), 200);
        let interceptor = interceptor_with(
            registry_with_pre_llm_hook(hook_runs.clone()),
            Some((CompactState::new(None, Some(100), 2), usage)),
            NotifyBuffer::new(),
        );

        let action = interceptor.pre_llm_request(&mut ctx).await;

        assert!(matches!(action, PreRequestAction::Yield));
        // Hook must not run when an internal mechanism short-circuits first.
        assert_eq!(hook_runs.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn pre_llm_request_runs_hooks_when_under_threshold() {
        let hook_runs = Arc::new(AtomicUsize::new(0));
        let mut ctx = vec![Item::user_message("hi")];
        let usage = usage_handle_with(ctx.len(), 50);
        let interceptor = interceptor_with(
            registry_with_pre_llm_hook(hook_runs.clone()),
            Some((CompactState::new(None, Some(100), 2), usage)),
            NotifyBuffer::new(),
        );

        let action = interceptor.pre_llm_request(&mut ctx).await;

        assert!(matches!(action, PreRequestAction::Continue));
        assert_eq!(hook_runs.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn pre_llm_request_does_not_yield_when_only_post_run_threshold_set() {
        // request_threshold = None → safety-net check is inert inside the turn
        // even if current occupancy is huge. Post-run check runs elsewhere.
        let hook_runs = Arc::new(AtomicUsize::new(0));
        let mut ctx = vec![Item::user_message("hi")];
        let usage = usage_handle_with(ctx.len(), 10_000);
        let interceptor = interceptor_with(
            registry_with_pre_llm_hook(hook_runs.clone()),
            Some((CompactState::new(Some(100), None, 2), usage)),
            NotifyBuffer::new(),
        );

        let action = interceptor.pre_llm_request(&mut ctx).await;

        assert!(matches!(action, PreRequestAction::Continue));
        assert_eq!(hook_runs.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn pre_llm_request_runs_hooks_when_no_compact_state() {
        let hook_runs = Arc::new(AtomicUsize::new(0));
        let interceptor = interceptor_with(
            registry_with_pre_llm_hook(hook_runs.clone()),
            None,
            NotifyBuffer::new(),
        );
        let mut ctx: Vec<Item> = Vec::new();

        let action = interceptor.pre_llm_request(&mut ctx).await;

        assert!(matches!(action, PreRequestAction::Continue));
        assert_eq!(hook_runs.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn pre_llm_request_drains_pending_notifies_into_context() {
        let buffer = NotifyBuffer::new();
        buffer.push("first".into());
        buffer.push("second".into());
        let interceptor = interceptor_with(empty_registry(), None, buffer.clone());
        let mut ctx: Vec<Item> = vec![Item::user_message("hi")];

        let action = interceptor.pre_llm_request(&mut ctx).await;

        assert!(matches!(action, PreRequestAction::Continue));
        // Original user message preserved, two notifications appended in order.
        assert_eq!(ctx.len(), 3);
        for (item, body) in ctx.iter().skip(1).zip(["first", "second"]) {
            let text = item.as_text().unwrap_or_default();
            assert!(text.contains("[Notification]"));
            assert!(text.contains(body));
        }
        // Buffer is drained after a single pre_llm_request call.
        assert!(buffer.is_empty());
    }

    #[tokio::test]
    async fn pre_llm_request_skips_notification_injection_when_yielding() {
        // When compaction yields, notifications remain in the buffer for
        // the next pre_llm_request (after compaction + resume).
        let buffer = NotifyBuffer::new();
        buffer.push("msg".into());
        let mut ctx = vec![Item::user_message("hi")];
        let usage = usage_handle_with(ctx.len(), 200);
        let interceptor = interceptor_with(
            empty_registry(),
            Some((CompactState::new(None, Some(100), 2), usage)),
            buffer.clone(),
        );

        let action = interceptor.pre_llm_request(&mut ctx).await;

        assert!(matches!(action, PreRequestAction::Yield));
        // Notifications were not drained (still held for post-compact resume).
        assert_eq!(ctx.len(), 1);
        assert_eq!(buffer.len(), 1);
    }

    #[tokio::test]
    async fn pre_llm_request_short_circuits_on_first_non_continue() {
        let first_called = Arc::new(AtomicBool::new(false));
        let second_count = Arc::new(AtomicUsize::new(0));
        let mut builder = HookRegistryBuilder::new();
        builder.add_pre_llm_request(AbortingHook(first_called.clone()));
        builder.add_pre_llm_request(CountingHook(second_count.clone()));
        let interceptor =
            interceptor_with(Arc::new(builder.build()), None, NotifyBuffer::new());
        let mut ctx: Vec<Item> = Vec::new();

        let action = interceptor.pre_llm_request(&mut ctx).await;

        assert!(matches!(action, PreRequestAction::Cancel(_)));
        assert!(first_called.load(Ordering::Relaxed));
        assert_eq!(second_count.load(Ordering::Relaxed), 0);
    }
}
|