527 lines
19 KiB
Rust
527 lines
19 KiB
Rust
//! Compact worker state and the four tools that drive it.
|
|
//!
|
|
//! The compact worker is a disposable `Worker` instance spun up by
|
|
//! [`Pod::compact`]. It receives the history to summarise plus a list of
|
|
//! default reference files (from the session-lifetime `Tracker`) and runs
|
|
//! a tool-driven LLM loop. The tools here let it:
|
|
//!
|
|
//! - `read_file` — inspect referenced files (reuses `tools::read_tool`)
|
|
//! - `mark_read_required(file_path, offset?, limit?)` — nominate a file whose
//!   contents should be injected into the compacted context as an
//!   auto-read system message
//! - `add_reference(file_path)` — nominate a file the next session should
//!   know about by name only (contents not included)
|
|
//! - `write_summary(text)` — deliver (or overwrite) the structured summary
|
|
//!
|
|
//! Everything the worker decides ends up in [`CompactWorkerContext`],
|
|
//! which `Pod::compact` drains after the loop and turns into the
|
|
//! compacted session's opening system messages.
|
|
|
|
use std::path::PathBuf;
|
|
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
use async_trait::async_trait;
|
|
use llm_worker::Item;
|
|
use llm_worker::interceptor::{Interceptor, PreRequestAction, PreToolAction, ToolCallInfo};
|
|
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta, ToolOutput, ToolResult};
|
|
use serde::Deserialize;
|
|
use tools::ScopedFs;
|
|
|
|
use crate::compact::usage_tracker::UsageTracker;
|
|
use crate::fs_view::{ReadRequirement, slice_lines};
|
|
|
|
/// Aggregated output of a compact worker run.
///
/// The compact-worker tools in this module each hold an
/// `Arc<Mutex<CompactWorkerContext>>` and write their decisions into it;
/// `Pod::compact` drains it after the worker loop.
#[derive(Debug, Default, Clone)]
pub(crate) struct CompactWorkerContext {
    /// Files nominated via `mark_read_required`, in the order they were recorded.
    pub read_required: Vec<ReadRequirement>,
    /// Paths nominated via `add_reference`; that tool skips paths already present.
    pub references: Vec<PathBuf>,
    /// Text from the most recent `write_summary` call, if any.
    pub summary: Option<String>,
    /// Tokens already consumed by `mark_read_required` calls.
    // Estimated tokens (bytes / 4, rounded up — see `estimate_tokens`).
    pub auto_read_consumed: u64,
    /// Aggregate cap. `0` treats the budget as disabled.
    pub auto_read_budget: u64,
}
|
|
|
|
impl CompactWorkerContext {
|
|
pub(crate) fn with_budget(auto_read_budget: u64) -> Self {
|
|
Self {
|
|
auto_read_budget,
|
|
..Self::default()
|
|
}
|
|
}
|
|
|
|
fn remaining_budget(&self) -> u64 {
|
|
self.auto_read_budget
|
|
.saturating_sub(self.auto_read_consumed)
|
|
}
|
|
}
|
|
|
|
/// Input to `mark_read_required`.
// NOTE: the `///` comments on this struct and its fields are read by
// `schemars` and become the JSON-schema descriptions attached to the tool's
// `ToolMeta`; changing them changes what the model sees.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
struct MarkParams {
    /// Absolute path to the file.
    pub file_path: PathBuf,
    /// 0-based line offset.
    // `None` is treated as offset 0 by `MarkReadRequiredTool::execute`.
    #[serde(default)]
    pub offset: Option<usize>,
    /// Maximum number of lines to inject.
    // Passed through to `slice_lines` as-is; `None` returns everything from
    // `offset` onward (see the `slice_lines` test).
    #[serde(default)]
    pub limit: Option<usize>,
}
|
|
|
|
/// Input to `add_reference`.
// NOTE: `///` comments here are read by `schemars` and become the tool's
// JSON-schema descriptions; edit them as user-facing text.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
struct ReferenceParams {
    /// Absolute path to the file.
    // Not checked for absoluteness or existence by `AddReferenceTool`.
    pub file_path: PathBuf,
}
|
|
|
|
/// Input to `write_summary`.
// NOTE: `///` comments here are read by `schemars` and become the tool's
// JSON-schema descriptions; edit them as user-facing text.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
struct SummaryParams {
    /// Full structured summary text (overwrites any previous call).
    pub text: String,
}
|
|
|
|
const MARK_DESCRIPTION: &str = "Inject a file's contents into the compacted context so the \
|
|
next session starts with it already read. Use this for files the next task needs in full. \
|
|
Optionally specify `offset` (0-based line) and `limit` (line count) to inject only a slice. \
|
|
Counts against `auto_read_budget`; overflow returns an error and the mark is not recorded. \
|
|
Paths must be absolute.";
|
|
|
|
const REFERENCE_DESCRIPTION: &str = "Record a file path as a named reference in the compacted \
|
|
context without injecting its contents. Use for files that are contextually relevant but \
|
|
whose current content the next session can fetch on demand.";
|
|
|
|
const SUMMARY_DESCRIPTION: &str = "Provide the final structured summary text. Subsequent calls \
|
|
replace the previous content; only the last call is used. Must be called before the compact run \
|
|
ends or compaction fails.";
|
|
|
|
/// Implementation of the `mark_read_required` tool.
///
/// Reads the nominated file through `fs` to estimate its token cost, then
/// records a [`ReadRequirement`] in the shared `ctx`.
struct MarkReadRequiredTool {
    /// Scoped filesystem used to read (and size) the nominated file.
    fs: ScopedFs,
    /// Context shared with the other compact-worker tools.
    ctx: Arc<Mutex<CompactWorkerContext>>,
}
|
|
|
|
#[async_trait]
|
|
impl Tool for MarkReadRequiredTool {
|
|
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
|
let params: MarkParams = serde_json::from_str(input_json).map_err(|e| {
|
|
ToolError::InvalidArgument(format!("invalid mark_read_required input: {e}"))
|
|
})?;
|
|
|
|
// Read the file through the shared ScopedFs so scope and I/O
|
|
// errors surface the same way the regular `read_file` tool does.
|
|
let bytes = self
|
|
.fs
|
|
.read_bytes(¶ms.file_path)
|
|
.map_err(|e| ToolError::ExecutionFailed(format!("read failed: {e}")))?;
|
|
let text = String::from_utf8_lossy(&bytes);
|
|
let slice = slice_lines(&text, params.offset.unwrap_or(0), params.limit);
|
|
let estimated_tokens = estimate_tokens(slice.len());
|
|
|
|
let mut guard = self.ctx.lock().expect("compact worker context poisoned");
|
|
let budget = guard.auto_read_budget;
|
|
let would_consume = guard.auto_read_consumed.saturating_add(estimated_tokens);
|
|
if budget > 0 && would_consume > budget {
|
|
return Err(ToolError::ExecutionFailed(format!(
|
|
"auto-read budget exhausted ({budget} tokens). Remove an existing mark or use \
|
|
add_reference instead."
|
|
)));
|
|
}
|
|
guard.read_required.push(ReadRequirement {
|
|
path: params.file_path.clone(),
|
|
offset: params.offset,
|
|
limit: params.limit,
|
|
});
|
|
guard.auto_read_consumed = would_consume;
|
|
let remaining = guard.remaining_budget();
|
|
drop(guard);
|
|
|
|
let mut summary = format!(
|
|
"Marked {} for auto-read (≈{estimated_tokens} tokens). \
|
|
Budget: {remaining}/{budget} tokens remaining.",
|
|
params.file_path.display()
|
|
);
|
|
if budget > 0 && remaining * 2 <= budget {
|
|
summary.push_str(
|
|
"\nNote: auto-read budget is at least half consumed. \
|
|
Consider calling write_summary and finishing up soon.",
|
|
);
|
|
}
|
|
Ok(ToolOutput {
|
|
summary,
|
|
content: None,
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Implementation of the `add_reference` tool: records file paths (name
/// only, no contents) into the shared context.
struct AddReferenceTool {
    /// Context shared with the other compact-worker tools.
    ctx: Arc<Mutex<CompactWorkerContext>>,
}
|
|
|
|
#[async_trait]
|
|
impl Tool for AddReferenceTool {
|
|
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
|
let params: ReferenceParams = serde_json::from_str(input_json)
|
|
.map_err(|e| ToolError::InvalidArgument(format!("invalid add_reference input: {e}")))?;
|
|
let mut guard = self.ctx.lock().expect("compact worker context poisoned");
|
|
if !guard
|
|
.references
|
|
.iter()
|
|
.any(|p| p.as_path() == params.file_path.as_path())
|
|
{
|
|
guard.references.push(params.file_path.clone());
|
|
}
|
|
Ok(ToolOutput {
|
|
summary: format!("Added reference {}", params.file_path.display()),
|
|
content: None,
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Implementation of the `write_summary` tool: stores the summary text in
/// the shared context, replacing any previous value.
struct WriteSummaryTool {
    /// Context shared with the other compact-worker tools.
    ctx: Arc<Mutex<CompactWorkerContext>>,
}
|
|
|
|
#[async_trait]
|
|
impl Tool for WriteSummaryTool {
|
|
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
|
let params: SummaryParams = serde_json::from_str(input_json)
|
|
.map_err(|e| ToolError::InvalidArgument(format!("invalid write_summary input: {e}")))?;
|
|
let mut guard = self.ctx.lock().expect("compact worker context poisoned");
|
|
let overwritten = guard.summary.is_some();
|
|
guard.summary = Some(params.text);
|
|
drop(guard);
|
|
let note = if overwritten {
|
|
"Summary replaced."
|
|
} else {
|
|
"Summary recorded."
|
|
};
|
|
Ok(ToolOutput {
|
|
summary: note.to_string(),
|
|
content: None,
|
|
})
|
|
}
|
|
}
|
|
|
|
pub(crate) fn mark_read_required_tool(
|
|
fs: ScopedFs,
|
|
ctx: Arc<Mutex<CompactWorkerContext>>,
|
|
) -> ToolDefinition {
|
|
Arc::new(move || {
|
|
let schema = schemars::schema_for!(MarkParams);
|
|
let schema_value = serde_json::to_value(schema).unwrap_or(serde_json::json!({}));
|
|
let meta = ToolMeta::new("mark_read_required")
|
|
.description(MARK_DESCRIPTION)
|
|
.input_schema(schema_value);
|
|
let tool: Arc<dyn Tool> = Arc::new(MarkReadRequiredTool {
|
|
fs: fs.clone(),
|
|
ctx: ctx.clone(),
|
|
});
|
|
(meta, tool)
|
|
})
|
|
}
|
|
|
|
pub(crate) fn add_reference_tool(ctx: Arc<Mutex<CompactWorkerContext>>) -> ToolDefinition {
|
|
Arc::new(move || {
|
|
let schema = schemars::schema_for!(ReferenceParams);
|
|
let schema_value = serde_json::to_value(schema).unwrap_or(serde_json::json!({}));
|
|
let meta = ToolMeta::new("add_reference")
|
|
.description(REFERENCE_DESCRIPTION)
|
|
.input_schema(schema_value);
|
|
let tool: Arc<dyn Tool> = Arc::new(AddReferenceTool { ctx: ctx.clone() });
|
|
(meta, tool)
|
|
})
|
|
}
|
|
|
|
pub(crate) fn write_summary_tool(ctx: Arc<Mutex<CompactWorkerContext>>) -> ToolDefinition {
|
|
Arc::new(move || {
|
|
let schema = schemars::schema_for!(SummaryParams);
|
|
let schema_value = serde_json::to_value(schema).unwrap_or(serde_json::json!({}));
|
|
let meta = ToolMeta::new("write_summary")
|
|
.description(SUMMARY_DESCRIPTION)
|
|
.input_schema(schema_value);
|
|
let tool: Arc<dyn Tool> = Arc::new(WriteSummaryTool { ctx: ctx.clone() });
|
|
(meta, tool)
|
|
})
|
|
}
|
|
|
|
/// Interceptor that monitors compact-worker context occupancy.
///
/// `max_input_tokens` remains the hard circuit breaker. Before that point,
/// the interceptor can persist a system warning into worker history telling
/// the model to stop broad exploration and call `write_summary`, and can block
/// additional exploratory tool calls once the final reserve is reached.
pub(crate) struct CompactWorkerInterceptor {
    /// Usage records combined with the request context to estimate occupancy;
    /// also told about each request via `note_request`.
    pub usage_tracker: Arc<UsageTracker>,
    /// Hard cap: a request whose estimated occupancy exceeds this is cancelled.
    pub max_input_tokens: u64,
    /// Emit the one-shot warning once remaining headroom is at or below this
    /// many tokens (`0` disables this trigger).
    pub finish_warning_remaining_tokens: u64,
    /// Headroom at or below which exploratory tool calls get a synthetic error
    /// (see `pre_tool_call`); also triggers the warning. `0` disables both uses.
    pub final_reserve_tokens: u64,
    /// Optional callback that receives the warning text when it is emitted.
    pub on_warning: Option<Arc<dyn Fn(String) + Send + Sync>>,
    /// Set the first time the warning is emitted, so it is sent at most once.
    warning_sent: AtomicBool,
    /// Headroom measured by the most recent `pre_llm_request`; starts at
    /// `max_input_tokens` before any request has been measured.
    last_remaining_tokens: AtomicU64,
}
|
|
|
|
impl CompactWorkerInterceptor {
|
|
pub(crate) fn new(
|
|
usage_tracker: Arc<UsageTracker>,
|
|
max_input_tokens: u64,
|
|
finish_warning_remaining_tokens: u64,
|
|
final_reserve_tokens: u64,
|
|
on_warning: Option<Arc<dyn Fn(String) + Send + Sync>>,
|
|
) -> Self {
|
|
Self {
|
|
usage_tracker,
|
|
max_input_tokens,
|
|
finish_warning_remaining_tokens,
|
|
final_reserve_tokens,
|
|
on_warning,
|
|
warning_sent: AtomicBool::new(false),
|
|
last_remaining_tokens: AtomicU64::new(max_input_tokens),
|
|
}
|
|
}
|
|
|
|
fn maybe_emit_warning(&self, remaining: u64) -> Option<Item> {
|
|
let warning_threshold = self.finish_warning_remaining_tokens;
|
|
let reserve_threshold = self.final_reserve_tokens;
|
|
let should_warn = (warning_threshold > 0 && remaining <= warning_threshold)
|
|
|| (reserve_threshold > 0 && remaining <= reserve_threshold);
|
|
if !should_warn || self.warning_sent.swap(true, Ordering::AcqRel) {
|
|
return None;
|
|
}
|
|
|
|
let message = format!(
|
|
"compact worker context budget is low ({remaining}/{} tokens remaining). \
|
|
Stop broad exploration now, read only if absolutely necessary, then call \
|
|
`write_summary` with the final structured summary.",
|
|
self.max_input_tokens
|
|
);
|
|
if let Some(cb) = self.on_warning.as_ref() {
|
|
cb(message.clone());
|
|
}
|
|
Some(Item::system_message(format!(
|
|
"[Compact worker budget warning]\n\n{message}"
|
|
)))
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Interceptor for CompactWorkerInterceptor {
|
|
async fn pre_llm_request(&self, context: &mut Vec<Item>) -> PreRequestAction {
|
|
let records = self.usage_tracker.records();
|
|
let estimate = llm_worker::token_counter::total_tokens(context, &records);
|
|
if estimate.tokens > self.max_input_tokens {
|
|
return PreRequestAction::Cancel(format!(
|
|
"compact worker input occupancy exceeded {} tokens",
|
|
self.max_input_tokens
|
|
));
|
|
}
|
|
|
|
let remaining = self.max_input_tokens.saturating_sub(estimate.tokens);
|
|
self.last_remaining_tokens
|
|
.store(remaining, Ordering::Release);
|
|
if let Some(item) = self.maybe_emit_warning(remaining) {
|
|
self.usage_tracker.note_request(context.len() + 1);
|
|
return PreRequestAction::ContinueWith(vec![item]);
|
|
}
|
|
|
|
self.usage_tracker.note_request(context.len());
|
|
PreRequestAction::Continue
|
|
}
|
|
|
|
async fn pre_tool_call(&self, info: &mut ToolCallInfo) -> PreToolAction {
|
|
if self.final_reserve_tokens == 0 || info.call.name == "write_summary" {
|
|
return PreToolAction::Continue;
|
|
}
|
|
let remaining = self.last_remaining_tokens.load(Ordering::Acquire);
|
|
if remaining > self.final_reserve_tokens {
|
|
return PreToolAction::Continue;
|
|
}
|
|
PreToolAction::SyntheticResult(ToolResult::error(
|
|
info.call.id.clone(),
|
|
"compact worker final reserve reached; do not perform more exploratory tool reads. Call `write_summary` now.",
|
|
))
|
|
}
|
|
}
|
|
|
|
/// Rough bytes→tokens conversion: about four bytes per token, rounded up.
///
/// Deliberately crude; in this module it only sizes `mark_read_required`
/// slices for auto-read budget accounting.
fn estimate_tokens(bytes: usize) -> u64 {
    const BYTES_PER_TOKEN: u64 = 4;
    let bytes = bytes as u64;
    let whole = bytes / BYTES_PER_TOKEN;
    // Any partial chunk still costs a full token.
    whole + u64::from(bytes % BYTES_PER_TOKEN != 0)
}
|
|
|
|
// Unit tests for the compact-worker tools, the occupancy interceptor, and
// the `slice_lines` helper used by `mark_read_required`.
#[cfg(test)]
mod tests {
    use super::*;
    use manifest::Scope;

    /// Builds a writable `ScopedFs` rooted at `tmp`.
    fn make_fs(tmp: &std::path::Path) -> ScopedFs {
        let scope = Scope::writable(tmp.to_path_buf()).unwrap();
        ScopedFs::new(scope, tmp.to_path_buf())
    }

    /// Usage event reporting `input` input tokens and no output tokens.
    fn make_usage(input: u64) -> llm_worker::timeline::event::UsageEvent {
        llm_worker::timeline::event::UsageEvent {
            input_tokens: Some(input),
            output_tokens: Some(0),
            total_tokens: Some(input),
            cache_read_input_tokens: None,
            cache_creation_input_tokens: None,
        }
    }

    // Repeated 100-token requests stay under a 150-token cap because the cap
    // is compared against current occupancy, not a running total.
    #[tokio::test]
    async fn compact_worker_interceptor_uses_occupancy_not_cumulative_usage() {
        let tracker = Arc::new(UsageTracker::new());
        let interceptor = CompactWorkerInterceptor::new(tracker.clone(), 150, 0, 0, None);
        let mut context = vec![Item::user_message("hello")];

        assert!(matches!(
            interceptor.pre_llm_request(&mut context).await,
            PreRequestAction::Continue
        ));
        tracker.record_usage(&make_usage(100));

        assert!(matches!(
            interceptor.pre_llm_request(&mut context).await,
            PreRequestAction::Continue
        ));
        tracker.record_usage(&make_usage(100));

        // Two 100-token requests would exceed a cumulative 150-token cap, but
        // current occupancy is still the latest 100-token measurement.
        assert!(matches!(
            interceptor.pre_llm_request(&mut context).await,
            PreRequestAction::Continue
        ));
    }

    // With a 60-token warning threshold, the request after a 100-token usage
    // record (≤ 50 tokens headroom) gets exactly one warning item and one
    // callback invocation.
    #[tokio::test]
    async fn compact_worker_interceptor_warns_before_hard_cap() {
        let tracker = Arc::new(UsageTracker::new());
        let warnings = Arc::new(Mutex::new(Vec::new()));
        let captured = warnings.clone();
        let interceptor = CompactWorkerInterceptor::new(
            tracker.clone(),
            150,
            60,
            20,
            Some(Arc::new(move |message| {
                captured.lock().unwrap().push(message);
            })),
        );
        let mut context = vec![Item::user_message("hello")];

        assert!(matches!(
            interceptor.pre_llm_request(&mut context).await,
            PreRequestAction::Continue
        ));
        tracker.record_usage(&make_usage(100));

        assert!(matches!(
            interceptor.pre_llm_request(&mut context).await,
            PreRequestAction::ContinueWith(items)
                if items.len() == 1 && items[0].as_text().unwrap_or_default().contains("write_summary")
        ));
        assert_eq!(warnings.lock().unwrap().len(), 1);
    }

    // Occupancy above `max_input_tokens` cancels the request outright.
    #[tokio::test]
    async fn compact_worker_interceptor_cancels_when_occupancy_exceeds_cap() {
        let tracker = Arc::new(UsageTracker::new());
        let interceptor = CompactWorkerInterceptor::new(tracker.clone(), 99, 0, 0, None);
        let mut context = vec![Item::user_message("hello")];

        assert!(matches!(
            interceptor.pre_llm_request(&mut context).await,
            PreRequestAction::Continue
        ));
        tracker.record_usage(&make_usage(100));

        assert!(matches!(
            interceptor.pre_llm_request(&mut context).await,
            PreRequestAction::Cancel(message) if message.contains("occupancy")
        ));
    }

    // A small file fits the budget: it is recorded and its estimate charged.
    #[tokio::test]
    async fn mark_read_required_records_and_deducts_budget() {
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("hello.txt");
        std::fs::write(&path, "hello world\n").unwrap();

        let ctx = Arc::new(Mutex::new(CompactWorkerContext::with_budget(1_000)));
        let tool: Arc<dyn Tool> = Arc::new(MarkReadRequiredTool {
            fs: make_fs(tmp.path()),
            ctx: ctx.clone(),
        });
        let input = serde_json::json!({ "file_path": path.to_str().unwrap() }).to_string();
        let out = tool.execute(&input).await.unwrap();

        assert!(out.summary.starts_with("Marked"));
        let guard = ctx.lock().unwrap();
        assert_eq!(guard.read_required.len(), 1);
        assert!(guard.auto_read_consumed > 0);
        assert!(guard.auto_read_consumed <= 1_000);
    }

    // An over-budget mark fails and leaves the context untouched.
    #[tokio::test]
    async fn mark_read_required_rejects_over_budget() {
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("big.txt");
        std::fs::write(&path, "x".repeat(4_096)).unwrap(); // ≈1024 tokens

        let ctx = Arc::new(Mutex::new(CompactWorkerContext::with_budget(100)));
        let tool: Arc<dyn Tool> = Arc::new(MarkReadRequiredTool {
            fs: make_fs(tmp.path()),
            ctx: ctx.clone(),
        });
        let input = serde_json::json!({ "file_path": path.to_str().unwrap() }).to_string();
        let res = tool.execute(&input).await;

        assert!(matches!(res, Err(ToolError::ExecutionFailed(_))));
        let guard = ctx.lock().unwrap();
        assert!(guard.read_required.is_empty());
        assert_eq!(guard.auto_read_consumed, 0);
    }

    // The second `write_summary` call replaces the first one's text.
    #[tokio::test]
    async fn write_summary_overwrites_previous_call() {
        let ctx = Arc::new(Mutex::new(CompactWorkerContext::with_budget(0)));
        let tool: Arc<dyn Tool> = Arc::new(WriteSummaryTool { ctx: ctx.clone() });

        let first = serde_json::json!({ "text": "first" }).to_string();
        let out1 = tool.execute(&first).await.unwrap();
        assert!(out1.summary.contains("recorded"));

        let second = serde_json::json!({ "text": "second" }).to_string();
        let out2 = tool.execute(&second).await.unwrap();
        assert!(out2.summary.contains("replaced"));

        assert_eq!(ctx.lock().unwrap().summary.as_deref(), Some("second"));
    }

    // Adding the same path twice stores it once.
    #[tokio::test]
    async fn add_reference_deduplicates() {
        let ctx = Arc::new(Mutex::new(CompactWorkerContext::with_budget(0)));
        let tool: Arc<dyn Tool> = Arc::new(AddReferenceTool { ctx: ctx.clone() });

        let p = "/abs/path.rs";
        let input = serde_json::json!({ "file_path": p }).to_string();
        tool.execute(&input).await.unwrap();
        tool.execute(&input).await.unwrap();

        let guard = ctx.lock().unwrap();
        assert_eq!(guard.references.len(), 1);
        assert_eq!(guard.references[0], PathBuf::from(p));
    }

    // `slice_lines`: full text with no limit, a middle window, and an
    // out-of-range offset yielding an empty string.
    #[test]
    fn slice_lines_handles_offset_and_limit() {
        let text = "a\nb\nc\nd";
        assert_eq!(slice_lines(text, 0, None), "a\nb\nc\nd");
        assert_eq!(slice_lines(text, 1, Some(2)), "b\nc");
        assert_eq!(slice_lines(text, 10, None), "");
    }
}
|