feat: protect prune tail by token budget

This commit is contained in:
Keisuke Hirata 2026-05-23 05:00:06 +09:00
parent d18e3a0256
commit 820dea1873
No known key found for this signature in database
12 changed files with 423 additions and 138 deletions

View File

@ -11,12 +11,23 @@
//! 射影の適用は上位層(`pod::prune_hook` 等)が LLM に送る一時コンテキスト
//! に対してだけ行う。Worker の永続履歴は決して変更されない。
//!
//! `min_savings` 判定や savings 推定もこの crate には置かず、上位層が
//! usage 履歴ベースのトークン会計と組み合わせて行う。
//! 保護境界は末尾 token budget で決めるが、この crate は usage 履歴を
//! 所有しない。prefix ごとの token 推定値と savings 推定は上位層から
//! callback で注入される。
use serde::{Deserialize, Serialize};
use crate::llm_client::types::Item;
use crate::token_counter::{EstimateSource, TokenEstimate};
/// Callback that returns token estimates for every prefix boundary of the
/// supplied request history.
///
/// The returned slice must have `history.len() + 1` entries where entry `i`
/// estimates the token count of `history[..i]`. Returning a malformed vector,
/// or estimates whose source is [`EstimateSource::NoData`], makes prune treat
/// the request as having no candidates.
pub type TokenEstimator = Box<dyn Fn(&[Item]) -> Vec<TokenEstimate> + Send + Sync>;
/// Callback that estimates the token savings for projecting the
/// `ToolResult.content` out of `history[i]` for each `i` in `indices`.
@ -35,16 +46,16 @@ pub type SavingsEstimator = Box<dyn Fn(&[Item], &[usize]) -> u64 + Send + Sync>;
///
/// Worker は LLM リクエストごとに 1 回 prune の評価をし、その結果を
/// observer が登録されていればこの値で通知する。fire/skip の判定
/// 結果と、判定材料になった候補数 / 推定 savings / 境界ターン位置を持つ。
/// 結果と、判定材料になった候補数 / 推定 savings / 保護領域の先頭 index を持つ。
#[derive(Debug, Clone)]
pub struct PruneEvaluation {
/// `prunable_indices` の長さ。`Skipped::NoCandidates` の時は 0。
pub candidate_count: usize,
/// 推定された savings (tokens)。`NoCandidates` の時は 0。
pub estimated_savings: u64,
/// `protected_turns` 境界に当たる turn-start アイテムの index。
/// turn 数が `protected_turns` 以下で境界が決まらない場合は `None`。
pub border_turn: Option<usize>,
/// Token budget で保護される suffix の先頭 item index。
/// usage 推定が `NoData` で境界が決まらない場合は `None`。
pub protected_start_index: Option<usize>,
/// 判定結果。
pub decision: PruneDecision,
}
@ -70,10 +81,9 @@ pub type PruneObserver = Box<dyn Fn(&PruneEvaluation) + Send + Sync>;
/// Configuration for the Prune algorithm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruneConfig {
/// Number of recent turns to protect from pruning.
/// A "turn" starts at each user message.
#[serde(default = "default_protected_turns")]
pub protected_turns: usize,
/// Token budget at the history tail protected from pruning.
#[serde(default = "default_protected_tokens")]
pub protected_tokens: u64,
/// Minimum token savings required to actually prune. If the prunable
/// content is smaller than this, the caller should skip to avoid
@ -84,8 +94,8 @@ pub struct PruneConfig {
pub min_savings: u64,
}
fn default_protected_turns() -> usize {
3
fn default_protected_tokens() -> u64 {
8000
}
fn default_min_savings() -> u64 {
4096
@ -94,25 +104,12 @@ fn default_min_savings() -> u64 {
impl Default for PruneConfig {
fn default() -> Self {
Self {
protected_turns: default_protected_turns(),
protected_tokens: default_protected_tokens(),
min_savings: default_min_savings(),
}
}
}
/// Find indices where each "turn" begins.
///
/// A turn starts at every user message. Returns the indices of those
/// user messages in ascending order.
fn find_turn_starts(items: &[Item]) -> Vec<usize> {
items
.iter()
.enumerate()
.filter(|(_, item)| item.is_user_message())
.map(|(i, _)| i)
.collect()
}
/// Set `content = None` on each `Item::ToolResult` at the given indices.
///
/// Returns the number of items that were actually modified — items that
@ -121,36 +118,43 @@ fn find_turn_starts(items: &[Item]) -> Vec<usize> {
pub fn project(items: &mut [Item], indices: &[usize]) -> usize {
let mut count = 0;
for &i in indices {
if let Item::ToolResult { content, .. } = &mut items[i] {
if content.is_some() {
*content = None;
count += 1;
}
if let Item::ToolResult { content, .. } = &mut items[i]
&& content.is_some()
{
*content = None;
count += 1;
}
}
count
}
/// Indices of `Item::ToolResult { content: Some(_), .. }` that lie outside
/// the last `protected_turns` turns. Pure: does not mutate `items`.
/// Indices of `Item::ToolResult { content: Some(_), .. }` that lie before
/// the suffix protected by `protected_tokens`. Pure: does not mutate `items`.
///
/// Returns an empty vector when there are too few turns or no prunable
/// candidates.
pub fn prunable_indices(items: &[Item], protected_turns: usize) -> Vec<usize> {
evaluate_candidates(items, protected_turns).0
/// Returns an empty vector when token estimates are unavailable (`NoData`) or
/// no prunable candidates exist.
pub fn prunable_indices(
items: &[Item],
protected_tokens: u64,
token_estimates: &[TokenEstimate],
) -> Vec<usize> {
evaluate_candidates(items, protected_tokens, token_estimates).0
}
/// Same as [`prunable_indices`] but also returns the index of the
/// `protected_turns` boundary (the turn-start item whose tail is
/// protected). `None` when too few turns exist for a boundary to be
/// defined.
pub fn evaluate_candidates(items: &[Item], protected_turns: usize) -> (Vec<usize>, Option<usize>) {
let turn_starts = find_turn_starts(items);
if turn_starts.len() <= protected_turns {
/// Same as [`prunable_indices`] but also returns the start index of the
/// protected suffix. `None` means the token boundary could not be determined
/// (currently because usage estimates were `NoData` or malformed).
pub fn evaluate_candidates(
items: &[Item],
protected_tokens: u64,
token_estimates: &[TokenEstimate],
) -> (Vec<usize>, Option<usize>) {
let Some(protected_start) = protected_start_index(items, protected_tokens, token_estimates)
else {
return (Vec::new(), None);
}
let boundary = turn_starts[turn_starts.len() - protected_turns];
let candidates = items[..boundary]
};
let candidates = items[..protected_start]
.iter()
.enumerate()
.filter_map(|(i, item)| match item {
@ -160,7 +164,38 @@ pub fn evaluate_candidates(items: &[Item], protected_turns: usize) -> (Vec<usize
_ => None,
})
.collect();
(candidates, Some(boundary))
(candidates, Some(protected_start))
}
fn protected_start_index(
items: &[Item],
protected_tokens: u64,
token_estimates: &[TokenEstimate],
) -> Option<usize> {
if token_estimates.len() != items.len() + 1 {
return None;
}
let total = token_estimates[items.len()];
if total.source == EstimateSource::NoData {
return None;
}
if protected_tokens == 0 {
return Some(items.len());
}
let mut protected_start = items.len();
for idx in (0..items.len()).rev() {
let prefix = token_estimates[idx];
if prefix.source == EstimateSource::NoData {
return None;
}
protected_start = idx;
let tail_tokens = total.tokens.saturating_sub(prefix.tokens);
if tail_tokens >= protected_tokens {
break;
}
}
Some(protected_start)
}
#[cfg(test)]
@ -185,17 +220,70 @@ mod tests {
items
}
fn measured_prefix(tokens: &[u64]) -> Vec<TokenEstimate> {
tokens
.iter()
.copied()
.map(|tokens| TokenEstimate {
tokens,
source: EstimateSource::Measured,
})
.collect()
}
fn uniform_estimates(items: &[Item], item_tokens: u64) -> Vec<TokenEstimate> {
let mut tokens = Vec::with_capacity(items.len() + 1);
for i in 0..=items.len() {
tokens.push(i as u64 * item_tokens);
}
measured_prefix(&tokens)
}
fn estimates_from_item_tokens(item_tokens: &[u64]) -> Vec<TokenEstimate> {
let mut prefix = Vec::with_capacity(item_tokens.len() + 1);
let mut acc = 0;
prefix.push(acc);
for tokens in item_tokens {
acc += tokens;
prefix.push(acc);
}
measured_prefix(&prefix)
}
fn no_data_estimates(items: &[Item]) -> Vec<TokenEstimate> {
(0..=items.len())
.map(|i| TokenEstimate {
tokens: i as u64,
source: if i == 0 {
EstimateSource::Measured
} else {
EstimateSource::NoData
},
})
.collect()
}
#[test]
fn no_candidates_when_too_few_turns() {
fn no_candidates_when_estimate_has_no_data() {
let items = make_history(&[("turn1", vec![("summary1", Some("big content here"))])]);
let estimates = no_data_estimates(&items);
let (candidates, protected_start) = evaluate_candidates(&items, 10, &estimates);
assert!(candidates.is_empty());
assert_eq!(protected_start, None);
}
#[test]
fn no_candidates_when_history_fits_in_protected_tokens() {
let items = make_history(&[
("turn1", vec![("summary1", Some("big content here"))]),
("turn2", vec![("summary2", Some("more content"))]),
]);
assert!(prunable_indices(&items, 3).is_empty());
let estimates = uniform_estimates(&items, 10);
assert!(prunable_indices(&items, 10_000, &estimates).is_empty());
}
#[test]
fn candidates_in_unprotected_turns() {
fn candidates_before_token_protected_suffix() {
let big = "x".repeat(4096 * 4);
let items = make_history(&[
("turn1", vec![("s1", Some(&big))]),
@ -203,9 +291,39 @@ mod tests {
("turn3", vec![("s3", Some("keep me"))]),
("turn4", vec![("s4", Some("keep me too"))]),
]);
let candidates = prunable_indices(&items, 2);
let estimates = uniform_estimates(&items, 10);
let candidates = prunable_indices(&items, 80, &estimates);
assert_eq!(candidates.len(), 2);
// suffix budget 80 tokens protects turn3+turn4 (8 items), so only s1/s2 are candidates.
for &i in &candidates {
if let Item::ToolResult { summary, .. } = &items[i] {
assert!(summary == "s1" || summary == "s2");
} else {
panic!("non tool-result selected");
}
}
}
#[test]
fn single_long_task_gets_candidates_without_multiple_user_turns() {
let big = "x".repeat(4096 * 8);
let items = make_history(&[(
"one long task",
vec![
("s1", Some(&big)),
("s2", Some(&big)),
("s3", Some(&big)),
("s4", Some(&big)),
],
)]);
// user + assistant are cheap; every ToolCall is cheap; every ToolResult is heavy.
let item_tokens = vec![1, 1, 1, 5_000, 1, 5_000, 1, 5_000, 1, 5_000];
let estimates = estimates_from_item_tokens(&item_tokens);
let (candidates, protected_start) = evaluate_candidates(&items, 8_000, &estimates);
assert_eq!(protected_start, Some(7));
assert_eq!(candidates.len(), 2);
// 候補は turn1 と turn2 の ToolResult のみ
for &i in &candidates {
if let Item::ToolResult { summary, .. } = &items[i] {
assert!(summary == "s1" || summary == "s2");
@ -223,7 +341,8 @@ mod tests {
("turn3", vec![]),
("turn4", vec![]),
]);
assert!(prunable_indices(&items, 2).is_empty());
let estimates = uniform_estimates(&items, 10);
assert!(prunable_indices(&items, 20, &estimates).is_empty());
}
#[test]
@ -235,7 +354,8 @@ mod tests {
("turn3", vec![("s3", Some("keep me"))]),
("turn4", vec![("s4", Some("keep me too"))]),
]);
let candidates = prunable_indices(&items, 2);
let estimates = uniform_estimates(&items, 10);
let candidates = prunable_indices(&items, 80, &estimates);
let count = project(&mut items, &candidates);
assert_eq!(count, 2);
@ -261,7 +381,7 @@ mod tests {
("turn1", vec![("s1", None)]),
("turn2", vec![("s2", Some("hello"))]),
]);
// Manually target s1 (index 3) even though it's already None.
// Manually target s1 even though it's already None.
let target = items
.iter()
.position(|it| matches!(it, Item::ToolResult { summary, .. } if summary == "s1"))
@ -279,14 +399,15 @@ mod tests {
("turn3", vec![]),
("turn4", vec![]),
]);
let candidates = prunable_indices(&items, 2);
let estimates = uniform_estimates(&items, 10);
let candidates = prunable_indices(&items, 20, &estimates);
assert_eq!(project(&mut items, &candidates), 1);
// 2 周目: 候補は一度の prunable_indices 結果を使い回しても 0 件。
assert_eq!(project(&mut items, &candidates), 0);
}
#[test]
fn evaluate_candidates_returns_boundary_index() {
fn evaluate_candidates_returns_protected_start_index() {
let big = "x".repeat(64);
let items = make_history(&[
("turn1", vec![("s1", Some(&big))]),
@ -294,36 +415,37 @@ mod tests {
("turn3", vec![("s3", Some("keep"))]),
("turn4", vec![("s4", Some("keep too"))]),
]);
let (candidates, border) = evaluate_candidates(&items, 2);
let estimates = uniform_estimates(&items, 10);
let (candidates, protected_start) = evaluate_candidates(&items, 80, &estimates);
assert_eq!(candidates.len(), 2);
// protected_turns=2 → boundary は turn3 の user message 位置。
// turn1: u/a/c/r (4) + turn2: u/a/c/r (4) = index 8 (turn3 の user)。
assert_eq!(border, Some(8));
// protected_tokens=80 → protected suffix is turn3+turn4, starting at index 8.
assert_eq!(protected_start, Some(8));
}
#[test]
fn evaluate_candidates_no_boundary_when_too_few_turns() {
fn evaluate_candidates_reports_zero_start_when_everything_is_protected() {
let items = make_history(&[("only", vec![("s", Some("x"))])]);
let (candidates, border) = evaluate_candidates(&items, 2);
let estimates = uniform_estimates(&items, 10);
let (candidates, protected_start) = evaluate_candidates(&items, 10_000, &estimates);
assert!(candidates.is_empty());
assert!(border.is_none());
assert_eq!(protected_start, Some(0));
}
#[test]
fn protected_turns_boundary_exact() {
// 3 turns with protected_turns=2: only turn 1 is a candidate.
fn zero_protected_tokens_allows_all_tool_results_as_candidates() {
let big = "x".repeat(64);
let items = make_history(&[
("turn1", vec![("s1", Some(&big))]),
("turn2", vec![("s2", Some("protected"))]),
("turn3", vec![("s3", Some("also protected"))]),
]);
let candidates = prunable_indices(&items, 2);
assert_eq!(candidates.len(), 1);
if let Item::ToolResult { summary, .. } = &items[candidates[0]] {
assert_eq!(summary, "s1");
} else {
panic!("expected ToolResult at candidate index");
}
let items = make_history(&[("turn1", vec![("s1", Some(&big)), ("s2", Some(&big))])]);
let estimates = uniform_estimates(&items, 10);
let (candidates, protected_start) = evaluate_candidates(&items, 0, &estimates);
assert_eq!(protected_start, Some(items.len()));
assert_eq!(candidates.len(), 2);
}
#[test]
fn malformed_estimate_vector_is_treated_as_no_boundary() {
let items = make_history(&[("turn1", vec![("s1", Some("x"))])]);
let (candidates, protected_start) = evaluate_candidates(&items, 10, &[]);
assert!(candidates.is_empty());
assert_eq!(protected_start, None);
}
}

View File

@ -201,6 +201,10 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
tool_output_limits: Option<ToolOutputLimits>,
/// Prune configuration. `None` disables the prune projection.
prune_config: Option<crate::prune::PruneConfig>,
/// Callback that estimates prefix token counts, injected by higher
/// layers that own usage measurements. `None` disables the prune
/// projection.
token_estimator: Option<crate::prune::TokenEstimator>,
/// Callback that estimates token savings for a drop range, injected
/// by higher layers that own usage measurements. `None` disables
/// the prune projection.
@ -434,6 +438,17 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
self.prune_config = config;
}
/// Inject the callback used to estimate prefix token counts for prune's
/// protected-token boundary.
///
/// The callback is invoked with the *request context* (a clone of
/// history). It must be pure/idempotent since it may be called once per
/// LLM request. Returning `NoData` estimates makes prune skip as if no
/// candidates existed.
pub fn set_token_estimator(&mut self, estimator: Option<crate::prune::TokenEstimator>) {
self.token_estimator = estimator;
}
/// Inject the callback used to estimate token savings for a prune
/// candidate range.
///
@ -983,18 +998,26 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
// prunable candidates whose estimated savings meet the
// threshold. Worker does not own usage history itself; the
// estimator is injected by the layer that does.
if let (Some(config), Some(estimator)) = (&self.prune_config, &self.savings_estimator) {
let (candidates, border_turn) =
crate::prune::evaluate_candidates(&request_context, config.protected_turns);
if let (Some(config), Some(token_estimator), Some(savings_estimator)) = (
&self.prune_config,
&self.token_estimator,
&self.savings_estimator,
) {
let token_estimates = token_estimator(&request_context);
let (candidates, protected_start_index) = crate::prune::evaluate_candidates(
&request_context,
config.protected_tokens,
&token_estimates,
);
let evaluation = if candidates.is_empty() {
crate::prune::PruneEvaluation {
candidate_count: 0,
estimated_savings: 0,
border_turn,
protected_start_index,
decision: crate::prune::PruneDecision::SkippedNoCandidates,
}
} else {
let savings = estimator(&request_context, &candidates);
let savings = savings_estimator(&request_context, &candidates);
if savings >= config.min_savings {
let pruned = crate::prune::project(&mut request_context, &candidates);
if pruned > 0 {
@ -1007,7 +1030,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
crate::prune::PruneEvaluation {
candidate_count: candidates.len(),
estimated_savings: savings,
border_turn,
protected_start_index,
decision: crate::prune::PruneDecision::Fired {
pruned_count: pruned,
},
@ -1016,7 +1039,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
crate::prune::PruneEvaluation {
candidate_count: candidates.len(),
estimated_savings: savings,
border_turn,
protected_start_index,
decision: crate::prune::PruneDecision::SkippedBelowMinSavings,
}
}
@ -1256,6 +1279,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
cancel_rx,
tool_output_limits: None,
prune_config: None,
token_estimator: None,
savings_estimator: None,
prune_observer: None,
cache_anchor: None,
@ -1519,6 +1543,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
cancel_rx: self.cancel_rx,
tool_output_limits: self.tool_output_limits,
prune_config: self.prune_config,
token_estimator: self.token_estimator,
savings_estimator: self.savings_estimator,
prune_observer: self.prune_observer,
cache_anchor: self.cache_anchor,
@ -1605,6 +1630,7 @@ impl<C: LlmClient> Worker<C, Locked> {
cancel_rx: self.cancel_rx,
tool_output_limits: self.tool_output_limits,
prune_config: self.prune_config,
token_estimator: self.token_estimator,
savings_estimator: self.savings_estimator,
prune_observer: self.prune_observer,
cache_anchor: self.cache_anchor,

View File

@ -10,6 +10,7 @@ use std::collections::HashMap;
use std::num::NonZeroU32;
use std::path::{Path, PathBuf};
use serde::de::Error as _;
use serde::{Deserialize, Serialize};
use crate::defaults;
@ -112,7 +113,7 @@ pub struct PermissionConfigPartial {
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CompactionConfigPartial {
#[serde(default)]
pub prune_protected_turns: Option<usize>,
pub prune_protected_tokens: Option<u64>,
#[serde(default)]
pub prune_min_savings: Option<u64>,
#[serde(default)]
@ -141,12 +142,31 @@ pub enum ResolveError {
RelativePath { field: &'static str, path: PathBuf },
}
/// Reject manifest fields that were intentionally removed and must not be
/// silently swallowed by the general warn-and-ignore unknown-field policy.
pub(crate) fn reject_removed_manifest_fields(s: &str) -> Result<(), toml::de::Error> {
let value: toml::Value = toml::from_str(s)?;
if value
.get("compaction")
.and_then(toml::Value::as_table)
.is_some_and(|table| table.contains_key("prune_protected_turns"))
{
return Err(toml::de::Error::custom(
"unknown field in manifest: compaction.prune_protected_turns \
(removed; use compaction.prune_protected_tokens)",
));
}
Ok(())
}
impl PodManifestConfig {
/// Parse a partial manifest from a TOML string. Unknown top-level or
/// nested fields emit a `tracing::warn!` and are ignored; use
/// `tracing_subscriber` with `WARN` enabled to surface them to the
/// operator.
/// operator. Removed fields that must not be silently ignored (currently
/// `compaction.prune_protected_turns`) are rejected before deserialization.
pub fn from_toml(s: &str) -> Result<Self, toml::de::Error> {
reject_removed_manifest_fields(s)?;
let de = toml::Deserializer::parse(s)?;
serde_ignored::deserialize(de, |path| {
tracing::warn!("unknown field in manifest: {}", path);
@ -339,7 +359,7 @@ impl PermissionConfigPartial {
impl CompactionConfigPartial {
fn merge(self, upper: Self) -> Self {
Self {
prune_protected_turns: upper.prune_protected_turns.or(self.prune_protected_turns),
prune_protected_tokens: upper.prune_protected_tokens.or(self.prune_protected_tokens),
prune_min_savings: upper.prune_min_savings.or(self.prune_min_savings),
compact_threshold: upper.compact_threshold.or(self.compact_threshold),
compact_request_threshold: upper
@ -489,9 +509,9 @@ impl TryFrom<PodManifestConfig> for PodManifest {
validate_model_paths(cm, "compaction.model.auth.file")?;
}
Ok(CompactionConfig {
prune_protected_turns: c
.prune_protected_turns
.unwrap_or(defaults::PRUNE_PROTECTED_TURNS),
prune_protected_tokens: c
.prune_protected_tokens
.unwrap_or(defaults::PRUNE_PROTECTED_TOKENS),
prune_min_savings: c.prune_min_savings.unwrap_or(defaults::PRUNE_MIN_SAVINGS),
compact_threshold: c.compact_threshold,
compact_request_threshold: c.compact_request_threshold,
@ -921,7 +941,7 @@ mod tests {
let lower = PodManifestConfig {
compaction: Some(CompactionConfigPartial {
compact_threshold: Some(50_000),
prune_protected_turns: Some(5),
prune_protected_tokens: Some(5_000),
..Default::default()
}),
..Default::default()
@ -937,7 +957,7 @@ mod tests {
let c = merged.compaction.unwrap();
assert_eq!(c.compact_threshold, Some(80_000));
// field from lower retained when upper has None
assert_eq!(c.prune_protected_turns, Some(5));
assert_eq!(c.prune_protected_tokens, Some(5_000));
}
#[test]
@ -971,6 +991,19 @@ unknown_future_field = "tolerated"
assert_eq!(cfg.worker.max_tokens, Some(1000));
}
#[test]
fn from_toml_rejects_removed_prune_protected_turns_field() {
let bad = r#"
[compaction]
prune_protected_turns = 3
"#;
let err = PodManifestConfig::from_toml(bad).unwrap_err();
assert!(
err.to_string().contains("compaction.prune_protected_turns"),
"unexpected error: {err}"
);
}
#[test]
fn from_toml_accepts_worker_reasoning_string_or_integer() {
let effort = PodManifestConfig::from_toml(

View File

@ -14,9 +14,9 @@ pub const TOOL_OUTPUT_MAX_BYTES: usize = 64 * 1024;
/// See [`crate::FileUploadLimits`].
pub const FILE_UPLOAD_MAX_BYTES: usize = 256 * 1024;
/// Number of most-recent turns protected from pruning. See
/// [`crate::CompactionConfig::prune_protected_turns`].
pub const PRUNE_PROTECTED_TURNS: usize = 3;
/// Token budget at the history tail protected from pruning. See
/// [`crate::CompactionConfig::prune_protected_tokens`].
pub const PRUNE_PROTECTED_TOKENS: u64 = 8000;
/// Minimum estimated token savings required to trigger a prune. See
/// [`crate::CompactionConfig::prune_min_savings`].

View File

@ -337,9 +337,9 @@ pub enum ToolPermissionAction {
/// (full history summarisation). Omitting `[compaction]` disables both.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionConfig {
/// Number of recent turns protected from pruning.
#[serde(default = "default_prune_protected_turns")]
pub prune_protected_turns: usize,
/// Token budget at the history tail protected from pruning.
#[serde(default = "default_prune_protected_tokens")]
pub prune_protected_tokens: u64,
/// Minimum estimated token savings to trigger a prune.
#[serde(default = "default_prune_min_savings")]
@ -393,8 +393,8 @@ pub struct CompactionConfig {
pub model: Option<ModelManifest>,
}
fn default_prune_protected_turns() -> usize {
defaults::PRUNE_PROTECTED_TURNS
fn default_prune_protected_tokens() -> u64 {
defaults::PRUNE_PROTECTED_TOKENS
}
fn default_prune_min_savings() -> u64 {
defaults::PRUNE_MIN_SAVINGS
@ -415,7 +415,7 @@ fn default_compact_worker_max_turns() -> Option<u32> {
impl Default for CompactionConfig {
fn default() -> Self {
Self {
prune_protected_turns: default_prune_protected_turns(),
prune_protected_tokens: default_prune_protected_tokens(),
prune_min_savings: default_prune_min_savings(),
compact_threshold: None,
compact_request_threshold: None,
@ -431,6 +431,7 @@ impl Default for CompactionConfig {
impl PodManifest {
/// Parse a manifest from a TOML string.
pub fn from_toml(s: &str) -> Result<Self, toml::de::Error> {
config::reject_removed_manifest_fields(s)?;
toml::from_str(s)
}
}
@ -581,7 +582,7 @@ model_id = "claude-sonnet-4-20250514"
let toml = format!("{MINIMAL_REQUIRED}\n[compaction]\ncompact_threshold = 80000\n");
let manifest = PodManifest::from_toml(&toml).unwrap();
let c = manifest.compaction.unwrap();
assert_eq!(c.prune_protected_turns, 3);
assert_eq!(c.prune_protected_tokens, 8000);
assert_eq!(c.prune_min_savings, 4096);
assert_eq!(c.compact_threshold, Some(80000));
assert_eq!(c.compact_request_threshold, None);
@ -589,6 +590,16 @@ model_id = "claude-sonnet-4-20250514"
assert_eq!(c.compact_worker_max_turns, Some(20));
}
#[test]
fn reject_removed_prune_protected_turns_field() {
let toml = format!("{MINIMAL_REQUIRED}\n[compaction]\nprune_protected_turns = 3\n");
let err = PodManifest::from_toml(&toml).unwrap_err();
assert!(
err.to_string().contains("compaction.prune_protected_turns"),
"unexpected error: {err}"
);
}
#[test]
fn parse_compaction_worker_max_turns() {
let toml = format!(

View File

@ -13,19 +13,24 @@
use llm_worker::Item;
use llm_worker::llm_client::client::LlmClient;
use llm_worker::prune::{PruneConfig, PruneDecision, PruneObserver, SavingsEstimator};
use llm_worker::prune::{
PruneConfig, PruneDecision, PruneObserver, SavingsEstimator, TokenEstimator,
};
use session_metrics::Metric;
use session_store::Store;
use crate::Pod;
use crate::compact::token_counter::{EstimateSource, savings_for_prune_impl};
use crate::compact::token_counter::{
EstimateSource, savings_for_prune_impl, token_estimates_for_prune_impl,
};
impl<C: LlmClient, St: Store> Pod<C, St> {
/// Enable prune projection on the underlying Worker.
///
/// Registers the config and a savings-estimator closure on the Worker.
/// The estimator captures a shared handle to [`Pod::usage_history_handle`]
/// so that every LLM request sees the latest measurements.
/// Registers the config and token/savings-estimator closures on the Worker.
/// The estimators combine persisted [`Pod::usage_history_handle`] records
/// with in-flight `UsageTracker` records so multi-request tool loops can
/// prune before the surrounding Pod run finishes.
///
/// Measurement-less estimates (before the first LLM call, or immediately
/// after a compact) return `0` from the estimator, which naturally
@ -37,9 +42,25 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
/// [`UsageTracker`] so the next `LlmUsage` can be paired with a
/// `prune.post_request` metric carrying the same id.
pub fn attach_prune(&mut self, config: PruneConfig) {
let usage = self.usage_history_handle();
let usage_history_for_tokens = self.usage_history_handle();
let usage_tracker_for_tokens = self.usage_tracker_handle();
let token_estimator: TokenEstimator = Box::new(move |history: &[Item]| {
let mut snapshot = usage_history_for_tokens
.lock()
.expect("usage_history poisoned")
.clone();
snapshot.extend(usage_tracker_for_tokens.records());
token_estimates_for_prune_impl(history, &snapshot)
});
let usage_history_for_savings = self.usage_history_handle();
let usage_tracker_for_savings = self.usage_tracker_handle();
let estimator: SavingsEstimator = Box::new(move |history: &[Item], indices| {
let snapshot = usage.lock().expect("usage_history poisoned").clone();
let mut snapshot = usage_history_for_savings
.lock()
.expect("usage_history poisoned")
.clone();
snapshot.extend(usage_tracker_for_savings.records());
let est = savings_for_prune_impl(history, &snapshot, indices);
match est.source {
EstimateSource::NoData => 0,
@ -56,8 +77,9 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
.with_value(eval.estimated_savings as f64)
.with_correlation_id(&correlation_id)
.with_dimension("candidate_count", eval.candidate_count.to_string());
if let Some(border) = eval.border_turn {
metric = metric.with_dimension("border_turn", border.to_string());
if let Some(protected_start) = eval.protected_start_index {
metric =
metric.with_dimension("protected_start_index", protected_start.to_string());
}
metrics.push(metric);
usage_tracker.note_correlation_id(correlation_id);
@ -66,17 +88,21 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
metrics.push(Metric::now("prune.skip").with_dimension("reason", "no_candidates"));
}
PruneDecision::SkippedBelowMinSavings => {
metrics.push(
Metric::now("prune.skip")
.with_dimension("reason", "below_min_savings")
.with_dimension("candidate_count", eval.candidate_count.to_string())
.with_value(eval.estimated_savings as f64),
);
let mut metric = Metric::now("prune.skip")
.with_dimension("reason", "below_min_savings")
.with_dimension("candidate_count", eval.candidate_count.to_string())
.with_value(eval.estimated_savings as f64);
if let Some(protected_start) = eval.protected_start_index {
metric =
metric.with_dimension("protected_start_index", protected_start.to_string());
}
metrics.push(metric);
}
});
let worker = self.worker_mut();
worker.set_prune_config(Some(config));
worker.set_token_estimator(Some(token_estimator));
worker.set_savings_estimator(Some(estimator));
worker.set_prune_observer(Some(observer));
}
@ -90,7 +116,7 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
return;
};
let config = PruneConfig {
protected_turns: compaction.prune_protected_turns,
protected_tokens: compaction.prune_protected_tokens,
min_savings: compaction.prune_min_savings,
};
self.attach_prune(config);

View File

@ -132,6 +132,21 @@ fn tool_result_content_bytes(item: &Item) -> u64 {
item_bytes(item).saturating_sub(item_bytes(&cleared))
}
/// Prefix-boundary token estimates used by Prune to find its protected suffix.
///
/// Returns `history.len() + 1` entries where entry `i` estimates
/// `history[..i]`. This shares the same [`tokens_at`] accounting as compact's
/// retained-tail split and prune's savings estimate.
pub(crate) fn token_estimates_for_prune_impl(
history: &[Item],
records: &[UsageRecord],
) -> Vec<TokenEstimate> {
let prefix = prefix_bytes(history);
(0..=history.len())
.map(|idx| tokens_at(history, records, idx, &prefix))
.collect()
}
/// Prune 射影(`ToolResult.content = None`)で節約されるトークン数の推定。
///
/// `indices` は [`llm_worker::prune::prunable_indices`] が返す候補列を
@ -278,6 +293,26 @@ mod tests {
}
}
#[test]
fn token_estimates_for_prune_returns_every_prefix_boundary() {
let history = vec![msg("a"), msg("b"), msg("c")];
let estimates = token_estimates_for_prune_impl(&history, &[record(3, 300)]);
assert_eq!(estimates.len(), history.len() + 1);
assert_eq!(estimates[0].tokens, 0);
assert_eq!(estimates[3].tokens, 300);
assert_eq!(estimates[3].source, EstimateSource::Measured);
}
#[test]
fn token_estimates_for_prune_propagates_no_data() {
let history = vec![msg("a"), msg("b")];
let estimates = token_estimates_for_prune_impl(&history, &[]);
assert_eq!(estimates.len(), history.len() + 1);
assert_eq!(estimates[0].source, EstimateSource::Measured);
assert_eq!(estimates[1].source, EstimateSource::NoData);
assert_eq!(estimates[2].source, EstimateSource::NoData);
}
#[test]
fn savings_for_prune_skips_non_toolresult_indices() {
let history = vec![msg("a"), msg("b"), msg("c")];

View File

@ -4,10 +4,10 @@
//! returns a long `ToolOutput.content`, then inspects the persisted
//! session log to verify:
//!
//! - `prune.skip { reason: "no_candidates" }` lands when the protected-turn
//! window covers the entire history.
//! - `prune.fire` lands once enough turns + usage measurements exist for
//! the projection to actually apply.
//! - `prune.skip { reason: "no_candidates" }` lands when usage estimates are
//! unavailable or the protected-token window covers all tool results.
//! - `prune.fire` lands once enough measured history exceeds the protected-token
//! budget for the projection to actually apply.
//! - The fire metric and the immediately-following `prune.post_request`
//! metric share the same `correlation_id`, so cache_read / cache_write
//! from the LlmUsage that triggered the projection can be joined back
@ -136,7 +136,7 @@ fn text_response_with_cache(text: &str, cache_read: u64, cache_write: u64) -> Ve
]
}
fn manifest_toml(prune_protected_turns: usize, prune_min_savings: u64) -> String {
fn manifest_toml(prune_protected_tokens: u64, prune_min_savings: u64) -> String {
format!(
r#"
[pod]
@ -151,7 +151,7 @@ model_id = "test-model"
max_tokens = 100
[compaction]
prune_protected_turns = {prune_protected_turns}
prune_protected_tokens = {prune_protected_tokens}
prune_min_savings = {prune_min_savings}
[[scope.allow]]
@ -192,7 +192,7 @@ async fn prune_metrics_emit_skip_then_fire_with_post_request_join() {
// Run 1 (request 0): tool_use → triggers tool execution → request 1
// on the second iteration to produce the assistant reply.
// Run 2 (request 2): plain assistant text. Prune evaluation here
// sees user1's tool_result outside the 1-protected-turn window and
// sees user1's tool_result outside the protected-token suffix and
// should fire.
let client = MockClient::new(vec![
tool_use_response("call-1", "big_tool"),
@ -250,8 +250,8 @@ async fn prune_metrics_emit_skip_then_fire_with_post_request_join() {
"fire missing candidate_count: {fire:?}"
);
assert!(
fire.dimensions.contains_key("border_turn"),
"fire missing border_turn: {fire:?}"
fire.dimensions.contains_key("protected_start_index"),
"fire missing protected_start_index: {fire:?}"
);
assert!(fire.value.is_some(), "fire missing estimated_savings value");
let fire_id = fire
@ -277,6 +277,36 @@ async fn prune_metrics_emit_skip_then_fire_with_post_request_join() {
assert!(post.dimensions.contains_key("history_len"));
}
#[tokio::test]
async fn prune_metrics_fire_during_single_long_task_without_multiple_user_turns() {
let client = MockClient::new(vec![
tool_use_response("call-1", "big_tool"),
tool_use_response("call-2", "big_tool"),
tool_use_response("call-3", "big_tool"),
tool_use_response("call-4", "big_tool"),
text_response_with_cache("done", 100, 20),
]);
let (mut pod, _store_tmp, _pwd_tmp) = make_pod(manifest_toml(1, 1), client, "big_tool").await;
let session_id = pod.session_id();
let segment_id = pod.segment_id();
let store = pod.store().clone();
pod.run_text("one long task").await.unwrap();
let state = session_store::restore(&store, session_id, segment_id).unwrap();
let metrics = metrics_from_extensions(&state.extensions);
let fire_count = metrics.iter().filter(|m| m.name == "prune.fire").count();
assert!(
fire_count > 0,
"single-turn tool loop should produce prune.fire once old heavy ToolResults fall outside the protected-token suffix: {metrics:?}"
);
assert!(
metrics.iter().any(|m| {
m.name == "prune.fire" && m.dimensions.contains_key("protected_start_index")
})
);
}
/// `min_savings` set high enough that candidates exist but the estimated
/// savings always fall short → the second run should record
/// `prune.skip { reason: "below_min_savings" }`.
@ -288,7 +318,7 @@ async fn prune_metrics_record_below_min_savings_skip() {
text_response_with_cache("done", 0, 0),
]);
let (mut pod, _store_tmp, _pwd_tmp) =
make_pod(manifest_toml(1, u64::MAX), client, "big_tool").await;
make_pod(manifest_toml(1, 1_000_000), client, "big_tool").await;
let session_id = pod.session_id();
let segment_id = pod.segment_id();
let store = pod.store().clone();
@ -405,7 +435,7 @@ async fn metric_write_failure_emits_warn_alert_and_does_not_abort_run() {
// Even with a tool registered, this run will only emit
// `prune.skip { reason: "no_candidates" }` (one user message,
// protected_turns=1 covers everything). That is enough to drive
// protected token budget covers the only user message). That is enough to drive
// the failure path: at least one metric attempts to write.
let client = MockClient::new(vec![text_response_with_cache("hi", 0, 0)]);
let worker = Worker::new(client);

View File

@ -104,7 +104,7 @@ mod tests {
#[test]
fn metric_round_trip_via_json() {
let metric = Metric::now("prune.fire")
.with_dimension("border_turn", "3")
.with_dimension("protected_start_index", "3")
.with_dimension("candidate_count", "2")
.with_value(4096.0)
.with_correlation_id("abc-123");

View File

@ -38,6 +38,7 @@ Pod::try_pre_run_compact ← proactive
- **条件付き実行**: 推定トークン節約量が `min_savings` を超えた場合のみ。KV キャッシュの無駄な無効化を避ける
- **リクエストコンテキストのみ操作**: history 本体は変更しない。Prune 状態を Pod が保持し、LLM リクエスト構築時に反映する
- **保護境界**: 直近 `prune_protected_tokens` 相当の suffix は残す。turn 数ではなく usage history 由来の token estimate で境界を引くため、単発の長い tool loop でも古い `ToolResult.content` が候補になる
- **冪等**: `content: None` のアイテムはスキップ
### ToolOutput の構造
@ -138,8 +139,9 @@ compact は fork と同じ構造。旧セッションを保全し、新 SessionI
[compaction]
compact_threshold = 80000 # ターンの合間 (proactive)
compact_request_threshold = 90000 # リクエストの合間 (safety net)
retained_tokens = 8000 # 直近保護トークン数 (Prune 済みで計測)
auto_read_budget = 8000 # compact worker の mark_read_required 合計上限
prune_protected_tokens = 8000 # prune から保護する末尾 token budget
compact_retained_tokens = 8000 # compact 後に生のまま残す末尾 token budget
compact_auto_read_budget = 8000 # compact worker の mark_read_required 合計上限
compact_worker_max_input_tokens = 50000 # compact worker 自身の現在占有トークン上限
compact_worker_max_turns = 20 # compact worker 自身の tool loop 上限
```

View File

@ -191,9 +191,9 @@ permission = "write"
# セクションを書いた時点で Prune は有効化、Compact は閾値が None なら無効。
# [compaction]
#
# # 任意。デフォルト: 3 (`defaults::PRUNE_PROTECTED_TURNS`)。
# # pruning から保護する末尾ターン数
# prune_protected_turns = 3
# # 任意。デフォルト: 8000 (`defaults::PRUNE_PROTECTED_TOKENS`)。
# # pruning から保護する末尾 token budget。turn 数ではなく usage estimate で境界を引く
# prune_protected_tokens = 8000
#
# # 任意。デフォルト: 4096 (`defaults::PRUNE_MIN_SAVINGS`)。
# # prune が発火するための最低節約 token 推定値。

View File

@ -179,7 +179,7 @@ pattern = "*.env"
action = "deny"
[compaction]
prune_protected_turns = 3
prune_protected_tokens = 8000
prune_min_savings = 4096
compact_threshold = 80000
compact_request_threshold = 90000