From 9ee7f048051ff8266da0b1f63c312c699e2bcc26 Mon Sep 17 00:00:00 2001 From: Hare Date: Sat, 23 May 2026 05:00:06 +0900 Subject: [PATCH] feat: protect prune tail by token budget --- crates/llm-worker/src/prune.rs | 280 ++++++++++++++++------- crates/llm-worker/src/worker.rs | 40 +++- crates/manifest/src/config.rs | 49 +++- crates/manifest/src/defaults.rs | 6 +- crates/manifest/src/lib.rs | 25 +- crates/pod/src/compact/prune.rs | 58 +++-- crates/pod/src/compact/token_counter.rs | 35 +++ crates/pod/tests/session_metrics_test.rs | 52 ++++- crates/session-metrics/src/lib.rs | 2 +- docs/compaction.md | 6 +- docs/manifest.toml | 6 +- docs/pod-factory.md | 2 +- 12 files changed, 423 insertions(+), 138 deletions(-) diff --git a/crates/llm-worker/src/prune.rs b/crates/llm-worker/src/prune.rs index e05001b5..4b2307f2 100644 --- a/crates/llm-worker/src/prune.rs +++ b/crates/llm-worker/src/prune.rs @@ -11,12 +11,23 @@ //! 射影の適用は上位層(`pod::prune_hook` 等)が LLM に送る一時コンテキスト //! に対してだけ行う。Worker の永続履歴は決して変更されない。 //! -//! `min_savings` 判定や savings 推定もこの crate には置かず、上位層が -//! usage 履歴ベースのトークン会計と組み合わせて行う。 +//! 保護境界は末尾 token budget で決めるが、この crate は usage 履歴を +//! 所有しない。prefix ごとの token 推定値と savings 推定は上位層から +//! callback で注入される。 use serde::{Deserialize, Serialize}; use crate::llm_client::types::Item; +use crate::token_counter::{EstimateSource, TokenEstimate}; + +/// Callback that returns token estimates for every prefix boundary of the +/// supplied request history. +/// +/// The returned slice must have `history.len() + 1` entries where entry `i` +/// estimates the token count of `history[..i]`. Returning a malformed vector, +/// or estimates whose source is [`EstimateSource::NoData`], makes prune treat +/// the request as having no candidates. +pub type TokenEstimator = Box Vec + Send + Sync>; /// Callback that estimates the token savings for projecting the /// `ToolResult.content` out of `history[i]` for each `i` in `indices`. @@ -35,16 +46,16 @@ pub type SavingsEstimator = Box u64 + Send + Sync>; /// /// Worker は LLM リクエストごとに 1 回 prune の評価をし、その結果を /// (observer が登録されていれば)この値で通知する。fire/skip の判定 -/// 結果と、判定材料になった候補数 / 推定 savings / 境界ターン位置を持つ。 +/// 結果と、判定材料になった候補数 / 推定 savings / 保護領域の先頭 index を持つ。 #[derive(Debug, Clone)] pub struct PruneEvaluation { /// `prunable_indices` の長さ。`Skipped::NoCandidates` の時は 0。 pub candidate_count: usize, /// 推定された savings (tokens)。`NoCandidates` の時は 0。 pub estimated_savings: u64, - /// `protected_turns` 境界に当たる turn-start アイテムの index。 - /// turn 数が `protected_turns` 以下で境界が決まらない場合は `None`。 - pub border_turn: Option, + /// Token budget で保護される suffix の先頭 item index。 + /// usage 推定が `NoData` で境界が決まらない場合は `None`。 + pub protected_start_index: Option, /// 判定結果。 pub decision: PruneDecision, } @@ -70,10 +81,9 @@ pub type PruneObserver = Box; /// Configuration for the Prune algorithm. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PruneConfig { - /// Number of recent turns to protect from pruning. - /// A "turn" starts at each user message. - #[serde(default = "default_protected_turns")] - pub protected_turns: usize, + /// Token budget at the history tail protected from pruning. + #[serde(default = "default_protected_tokens")] + pub protected_tokens: u64, /// Minimum token savings required to actually prune. If the prunable /// content is smaller than this, the caller should skip to avoid @@ -84,8 +94,8 @@ pub struct PruneConfig { pub min_savings: u64, } -fn default_protected_turns() -> usize { - 3 +fn default_protected_tokens() -> u64 { + 8000 } fn default_min_savings() -> u64 { 4096 @@ -94,25 +104,12 @@ fn default_min_savings() -> u64 { impl Default for PruneConfig { fn default() -> Self { Self { - protected_turns: default_protected_turns(), + protected_tokens: default_protected_tokens(), min_savings: default_min_savings(), } } } -/// Find indices where each "turn" begins. -/// -/// A turn starts at every user message. Returns the indices of those -/// user messages in ascending order. -fn find_turn_starts(items: &[Item]) -> Vec { - items - .iter() - .enumerate() - .filter(|(_, item)| item.is_user_message()) - .map(|(i, _)| i) - .collect() -} - /// Set `content = None` on each `Item::ToolResult` at the given indices. /// /// Returns the number of items that were actually modified — items that @@ -121,36 +118,43 @@ fn find_turn_starts(items: &[Item]) -> Vec { pub fn project(items: &mut [Item], indices: &[usize]) -> usize { let mut count = 0; for &i in indices { - if let Item::ToolResult { content, .. } = &mut items[i] { - if content.is_some() { - *content = None; - count += 1; - } + if let Item::ToolResult { content, .. } = &mut items[i] + && content.is_some() + { + *content = None; + count += 1; } } count } -/// Indices of `Item::ToolResult { content: Some(_), .. }` that lie outside -/// the last `protected_turns` turns. Pure: does not mutate `items`. +/// Indices of `Item::ToolResult { content: Some(_), .. }` that lie before +/// the suffix protected by `protected_tokens`. Pure: does not mutate `items`. /// -/// Returns an empty vector when there are too few turns or no prunable -/// candidates. -pub fn prunable_indices(items: &[Item], protected_turns: usize) -> Vec { - evaluate_candidates(items, protected_turns).0 +/// Returns an empty vector when token estimates are unavailable (`NoData`) or +/// no prunable candidates exist. +pub fn prunable_indices( + items: &[Item], + protected_tokens: u64, + token_estimates: &[TokenEstimate], +) -> Vec { + evaluate_candidates(items, protected_tokens, token_estimates).0 } -/// Same as [`prunable_indices`] but also returns the index of the -/// `protected_turns` boundary (the turn-start item whose tail is -/// protected). `None` when too few turns exist for a boundary to be -/// defined. -pub fn evaluate_candidates(items: &[Item], protected_turns: usize) -> (Vec, Option) { - let turn_starts = find_turn_starts(items); - if turn_starts.len() <= protected_turns { +/// Same as [`prunable_indices`] but also returns the start index of the +/// protected suffix. `None` means the token boundary could not be determined +/// (currently because usage estimates were `NoData` or malformed). +pub fn evaluate_candidates( + items: &[Item], + protected_tokens: u64, + token_estimates: &[TokenEstimate], +) -> (Vec, Option) { + let Some(protected_start) = protected_start_index(items, protected_tokens, token_estimates) + else { return (Vec::new(), None); - } - let boundary = turn_starts[turn_starts.len() - protected_turns]; - let candidates = items[..boundary] + }; + + let candidates = items[..protected_start] .iter() .enumerate() .filter_map(|(i, item)| match item { @@ -160,7 +164,38 @@ pub fn evaluate_candidates(items: &[Item], protected_turns: usize) -> (Vec None, }) .collect(); - (candidates, Some(boundary)) + (candidates, Some(protected_start)) +} + +fn protected_start_index( + items: &[Item], + protected_tokens: u64, + token_estimates: &[TokenEstimate], +) -> Option { + if token_estimates.len() != items.len() + 1 { + return None; + } + let total = token_estimates[items.len()]; + if total.source == EstimateSource::NoData { + return None; + } + if protected_tokens == 0 { + return Some(items.len()); + } + + let mut protected_start = items.len(); + for idx in (0..items.len()).rev() { + let prefix = token_estimates[idx]; + if prefix.source == EstimateSource::NoData { + return None; + } + protected_start = idx; + let tail_tokens = total.tokens.saturating_sub(prefix.tokens); + if tail_tokens >= protected_tokens { + break; + } + } + Some(protected_start) } #[cfg(test)] @@ -185,17 +220,70 @@ mod tests { items } + fn measured_prefix(tokens: &[u64]) -> Vec { + tokens + .iter() + .copied() + .map(|tokens| TokenEstimate { + tokens, + source: EstimateSource::Measured, + }) + .collect() + } + + fn uniform_estimates(items: &[Item], item_tokens: u64) -> Vec { + let mut tokens = Vec::with_capacity(items.len() + 1); + for i in 0..=items.len() { + tokens.push(i as u64 * item_tokens); + } + measured_prefix(&tokens) + } + + fn estimates_from_item_tokens(item_tokens: &[u64]) -> Vec { + let mut prefix = Vec::with_capacity(item_tokens.len() + 1); + let mut acc = 0; + prefix.push(acc); + for tokens in item_tokens { + acc += tokens; + prefix.push(acc); + } + measured_prefix(&prefix) + } + + fn no_data_estimates(items: &[Item]) -> Vec { + (0..=items.len()) + .map(|i| TokenEstimate { + tokens: i as u64, + source: if i == 0 { + EstimateSource::Measured + } else { + EstimateSource::NoData + }, + }) + .collect() + } + #[test] - fn no_candidates_when_too_few_turns() { + fn no_candidates_when_estimate_has_no_data() { + let items = make_history(&[("turn1", vec![("summary1", Some("big content here"))])]); + let estimates = no_data_estimates(&items); + let (candidates, protected_start) = evaluate_candidates(&items, 10, &estimates); + assert!(candidates.is_empty()); + assert_eq!(protected_start, None); + } + + #[test] + fn no_candidates_when_history_fits_in_protected_tokens() { let items = make_history(&[ ("turn1", vec![("summary1", Some("big content here"))]), ("turn2", vec![("summary2", Some("more content"))]), ]); - assert!(prunable_indices(&items, 3).is_empty()); + let estimates = uniform_estimates(&items, 10); + assert!(prunable_indices(&items, 10_000, &estimates).is_empty()); } #[test] - fn candidates_in_unprotected_turns() { + fn candidates_before_token_protected_suffix() { let big = "x".repeat(4096 * 4); let items = make_history(&[ ("turn1", vec![("s1", Some(&big))]), @@ -203,9 +291,39 @@ mod tests { ("turn3", vec![("s3", Some("keep me"))]), ("turn4", vec![("s4", Some("keep me too"))]), ]); - let candidates = prunable_indices(&items, 2); + let estimates = uniform_estimates(&items, 10); + let candidates = prunable_indices(&items, 80, &estimates); + assert_eq!(candidates.len(), 2); + // suffix budget 80 tokens protects turn3+turn4 (8 items), so only s1/s2 are candidates. + for &i in &candidates { + if let Item::ToolResult { summary, .. } = &items[i] { + assert!(summary == "s1" || summary == "s2"); + } else { + panic!("non tool-result selected"); + } + } + } + + #[test] + fn single_long_task_gets_candidates_without_multiple_user_turns() { + let big = "x".repeat(4096 * 8); + let items = make_history(&[( + "one long task", + vec![ + ("s1", Some(&big)), + ("s2", Some(&big)), + ("s3", Some(&big)), + ("s4", Some(&big)), + ], + )]); + // user + assistant are cheap; every ToolCall is cheap; every ToolResult is heavy. + let item_tokens = vec![1, 1, 1, 5_000, 1, 5_000, 1, 5_000, 1, 5_000]; + let estimates = estimates_from_item_tokens(&item_tokens); + + let (candidates, protected_start) = evaluate_candidates(&items, 8_000, &estimates); + + assert_eq!(protected_start, Some(7)); assert_eq!(candidates.len(), 2); - // 候補は turn1 と turn2 の ToolResult のみ for &i in &candidates { if let Item::ToolResult { summary, .. } = &items[i] { assert!(summary == "s1" || summary == "s2"); @@ -223,7 +341,8 @@ mod tests { ("turn3", vec![]), ("turn4", vec![]), ]); - assert!(prunable_indices(&items, 2).is_empty()); + let estimates = uniform_estimates(&items, 10); + assert!(prunable_indices(&items, 20, &estimates).is_empty()); } #[test] @@ -235,7 +354,8 @@ mod tests { ("turn3", vec![("s3", Some("keep me"))]), ("turn4", vec![("s4", Some("keep me too"))]), ]); - let candidates = prunable_indices(&items, 2); + let estimates = uniform_estimates(&items, 10); + let candidates = prunable_indices(&items, 80, &estimates); let count = project(&mut items, &candidates); assert_eq!(count, 2); @@ -261,7 +381,7 @@ mod tests { ("turn1", vec![("s1", None)]), ("turn2", vec![("s2", Some("hello"))]), ]); - // Manually target s1 (index 3) even though it's already None. + // Manually target s1 even though it's already None. let target = items .iter() .position(|it| matches!(it, Item::ToolResult { summary, .. } if summary == "s1")) @@ -279,14 +399,15 @@ mod tests { ("turn3", vec![]), ("turn4", vec![]), ]); - let candidates = prunable_indices(&items, 2); + let estimates = uniform_estimates(&items, 10); + let candidates = prunable_indices(&items, 20, &estimates); assert_eq!(project(&mut items, &candidates), 1); // 2 周目: 候補は一度の prunable_indices 結果を使い回しても 0 件。 assert_eq!(project(&mut items, &candidates), 0); } #[test] - fn evaluate_candidates_returns_boundary_index() { + fn evaluate_candidates_returns_protected_start_index() { let big = "x".repeat(64); let items = make_history(&[ ("turn1", vec![("s1", Some(&big))]), @@ -294,36 +415,37 @@ mod tests { ("turn3", vec![("s3", Some("keep"))]), ("turn4", vec![("s4", Some("keep too"))]), ]); - let (candidates, border) = evaluate_candidates(&items, 2); + let estimates = uniform_estimates(&items, 10); + let (candidates, protected_start) = evaluate_candidates(&items, 80, &estimates); assert_eq!(candidates.len(), 2); - // protected_turns=2 → boundary は turn3 の user message 位置。 - // turn1: u/a/c/r (4) + turn2: u/a/c/r (4) = index 8 (turn3 の user)。 - assert_eq!(border, Some(8)); + // protected_tokens=80 → protected suffix is turn3+turn4, starting at index 8. + assert_eq!(protected_start, Some(8)); } #[test] - fn evaluate_candidates_no_boundary_when_too_few_turns() { + fn evaluate_candidates_reports_zero_start_when_everything_is_protected() { let items = make_history(&[("only", vec![("s", Some("x"))])]); - let (candidates, border) = evaluate_candidates(&items, 2); + let estimates = uniform_estimates(&items, 10); + let (candidates, protected_start) = evaluate_candidates(&items, 10_000, &estimates); assert!(candidates.is_empty()); - assert!(border.is_none()); + assert_eq!(protected_start, Some(0)); } #[test] - fn protected_turns_boundary_exact() { - // 3 turns with protected_turns=2: only turn 1 is a candidate. + fn zero_protected_tokens_allows_all_tool_results_as_candidates() { let big = "x".repeat(64); - let items = make_history(&[ - ("turn1", vec![("s1", Some(&big))]), - ("turn2", vec![("s2", Some("protected"))]), - ("turn3", vec![("s3", Some("also protected"))]), - ]); - let candidates = prunable_indices(&items, 2); - assert_eq!(candidates.len(), 1); - if let Item::ToolResult { summary, .. } = &items[candidates[0]] { - assert_eq!(summary, "s1"); - } else { - panic!("expected ToolResult at candidate index"); - } + let items = make_history(&[("turn1", vec![("s1", Some(&big)), ("s2", Some(&big))])]); + let estimates = uniform_estimates(&items, 10); + let (candidates, protected_start) = evaluate_candidates(&items, 0, &estimates); + assert_eq!(protected_start, Some(items.len())); + assert_eq!(candidates.len(), 2); + } + + #[test] + fn malformed_estimate_vector_is_treated_as_no_boundary() { + let items = make_history(&[("turn1", vec![("s1", Some("x"))])]); + let (candidates, protected_start) = evaluate_candidates(&items, 10, &[]); + assert!(candidates.is_empty()); + assert_eq!(protected_start, None); } } diff --git a/crates/llm-worker/src/worker.rs b/crates/llm-worker/src/worker.rs index 529f2108..d5d00bbc 100644 --- a/crates/llm-worker/src/worker.rs +++ b/crates/llm-worker/src/worker.rs @@ -201,6 +201,10 @@ pub struct Worker { tool_output_limits: Option, /// Prune configuration. `None` disables the prune projection. prune_config: Option, + /// Callback that estimates prefix token counts, injected by higher + /// layers that own usage measurements. `None` disables the prune + /// projection. + token_estimator: Option, /// Callback that estimates token savings for a drop range, injected /// by higher layers that own usage measurements. `None` disables /// the prune projection. @@ -434,6 +438,17 @@ impl Worker { self.prune_config = config; } + /// Inject the callback used to estimate prefix token counts for prune's + /// protected-token boundary. + /// + /// The callback is invoked with the *request context* (a clone of + /// history). It must be pure/idempotent since it may be called once per + /// LLM request. Returning `NoData` estimates makes prune skip as if no + /// candidates existed. + pub fn set_token_estimator(&mut self, estimator: Option) { + self.token_estimator = estimator; + } + /// Inject the callback used to estimate token savings for a prune /// candidate range. /// @@ -983,18 +998,26 @@ impl Worker { // prunable candidates whose estimated savings meet the // threshold. Worker does not own usage history itself; the // estimator is injected by the layer that does. - if let (Some(config), Some(estimator)) = (&self.prune_config, &self.savings_estimator) { - let (candidates, border_turn) = - crate::prune::evaluate_candidates(&request_context, config.protected_turns); + if let (Some(config), Some(token_estimator), Some(savings_estimator)) = ( + &self.prune_config, + &self.token_estimator, + &self.savings_estimator, + ) { + let token_estimates = token_estimator(&request_context); + let (candidates, protected_start_index) = crate::prune::evaluate_candidates( + &request_context, + config.protected_tokens, + &token_estimates, + ); let evaluation = if candidates.is_empty() { crate::prune::PruneEvaluation { candidate_count: 0, estimated_savings: 0, - border_turn, + protected_start_index, decision: crate::prune::PruneDecision::SkippedNoCandidates, } } else { - let savings = estimator(&request_context, &candidates); + let savings = savings_estimator(&request_context, &candidates); if savings >= config.min_savings { let pruned = crate::prune::project(&mut request_context, &candidates); if pruned > 0 { @@ -1007,7 +1030,7 @@ impl Worker { crate::prune::PruneEvaluation { candidate_count: candidates.len(), estimated_savings: savings, - border_turn, + protected_start_index, decision: crate::prune::PruneDecision::Fired { pruned_count: pruned, }, @@ -1016,7 +1039,7 @@ impl Worker { crate::prune::PruneEvaluation { candidate_count: candidates.len(), estimated_savings: savings, - border_turn, + protected_start_index, decision: crate::prune::PruneDecision::SkippedBelowMinSavings, } } @@ -1256,6 +1279,7 @@ impl Worker { cancel_rx, tool_output_limits: None, prune_config: None, + token_estimator: None, savings_estimator: None, prune_observer: None, cache_anchor: None, @@ -1519,6 +1543,7 @@ impl Worker { cancel_rx: self.cancel_rx, tool_output_limits: self.tool_output_limits, prune_config: self.prune_config, + token_estimator: self.token_estimator, savings_estimator: self.savings_estimator, prune_observer: self.prune_observer, cache_anchor: self.cache_anchor, @@ -1605,6 +1630,7 @@ impl Worker { cancel_rx: self.cancel_rx, tool_output_limits: self.tool_output_limits, prune_config: self.prune_config, + token_estimator: self.token_estimator, savings_estimator: self.savings_estimator, prune_observer: self.prune_observer, cache_anchor: self.cache_anchor, diff --git a/crates/manifest/src/config.rs b/crates/manifest/src/config.rs index ab266182..059ba3f9 100644 --- a/crates/manifest/src/config.rs +++ b/crates/manifest/src/config.rs @@ -10,6 +10,7 @@ use std::collections::HashMap; use std::num::NonZeroU32; use std::path::{Path, PathBuf}; +use serde::de::Error as _; use serde::{Deserialize, Serialize}; use crate::defaults; @@ -112,7 +113,7 @@ pub struct PermissionConfigPartial { #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct CompactionConfigPartial { #[serde(default)] - pub prune_protected_turns: Option, + pub prune_protected_tokens: Option, #[serde(default)] pub prune_min_savings: Option, #[serde(default)] @@ -141,12 +142,31 @@ pub enum ResolveError { RelativePath { field: &'static str, path: PathBuf }, } +/// Reject manifest fields that were intentionally removed and must not be +/// silently swallowed by the general warn-and-ignore unknown-field policy. +pub(crate) fn reject_removed_manifest_fields(s: &str) -> Result<(), toml::de::Error> { + let value: toml::Value = toml::from_str(s)?; + if value + .get("compaction") + .and_then(toml::Value::as_table) + .is_some_and(|table| table.contains_key("prune_protected_turns")) + { + return Err(toml::de::Error::custom( + "unknown field in manifest: compaction.prune_protected_turns \ + (removed; use compaction.prune_protected_tokens)", + )); + } + Ok(()) +} + impl PodManifestConfig { /// Parse a partial manifest from a TOML string. Unknown top-level or /// nested fields emit a `tracing::warn!` and are ignored; use /// `tracing_subscriber` with `WARN` enabled to surface them to the - /// operator. + /// operator. Removed fields that must not be silently ignored (currently + /// `compaction.prune_protected_turns`) are rejected before deserialization. pub fn from_toml(s: &str) -> Result { + reject_removed_manifest_fields(s)?; let de = toml::Deserializer::parse(s)?; serde_ignored::deserialize(de, |path| { tracing::warn!("unknown field in manifest: {}", path); @@ -339,7 +359,7 @@ impl PermissionConfigPartial { impl CompactionConfigPartial { fn merge(self, upper: Self) -> Self { Self { - prune_protected_turns: upper.prune_protected_turns.or(self.prune_protected_turns), + prune_protected_tokens: upper.prune_protected_tokens.or(self.prune_protected_tokens), prune_min_savings: upper.prune_min_savings.or(self.prune_min_savings), compact_threshold: upper.compact_threshold.or(self.compact_threshold), compact_request_threshold: upper @@ -489,9 +509,9 @@ impl TryFrom for PodManifest { validate_model_paths(cm, "compaction.model.auth.file")?; } Ok(CompactionConfig { - prune_protected_turns: c - .prune_protected_turns - .unwrap_or(defaults::PRUNE_PROTECTED_TURNS), + prune_protected_tokens: c + .prune_protected_tokens + .unwrap_or(defaults::PRUNE_PROTECTED_TOKENS), prune_min_savings: c.prune_min_savings.unwrap_or(defaults::PRUNE_MIN_SAVINGS), compact_threshold: c.compact_threshold, compact_request_threshold: c.compact_request_threshold, @@ -921,7 +941,7 @@ mod tests { let lower = PodManifestConfig { compaction: Some(CompactionConfigPartial { compact_threshold: Some(50_000), - prune_protected_turns: Some(5), + prune_protected_tokens: Some(5_000), ..Default::default() }), ..Default::default() @@ -937,7 +957,7 @@ mod tests { let c = merged.compaction.unwrap(); assert_eq!(c.compact_threshold, Some(80_000)); // field from lower retained when upper has None - assert_eq!(c.prune_protected_turns, Some(5)); + assert_eq!(c.prune_protected_tokens, Some(5_000)); } #[test] @@ -971,6 +991,19 @@ unknown_future_field = "tolerated" assert_eq!(cfg.worker.max_tokens, Some(1000)); } + #[test] + fn from_toml_rejects_removed_prune_protected_turns_field() { + let bad = r#" +[compaction] +prune_protected_turns = 3 +"#; + let err = PodManifestConfig::from_toml(bad).unwrap_err(); + assert!( + err.to_string().contains("compaction.prune_protected_turns"), + "unexpected error: {err}" + ); + } + #[test] fn from_toml_accepts_worker_reasoning_string_or_integer() { let effort = PodManifestConfig::from_toml( diff --git a/crates/manifest/src/defaults.rs b/crates/manifest/src/defaults.rs index 68123e92..81f38ec4 100644 --- a/crates/manifest/src/defaults.rs +++ b/crates/manifest/src/defaults.rs @@ -14,9 +14,9 @@ pub const TOOL_OUTPUT_MAX_BYTES: usize = 64 * 1024; /// See [`crate::FileUploadLimits`]. pub const FILE_UPLOAD_MAX_BYTES: usize = 256 * 1024; -/// Number of most-recent turns protected from pruning. See -/// [`crate::CompactionConfig::prune_protected_turns`]. -pub const PRUNE_PROTECTED_TURNS: usize = 3; +/// Token budget at the history tail protected from pruning. See +/// [`crate::CompactionConfig::prune_protected_tokens`]. +pub const PRUNE_PROTECTED_TOKENS: u64 = 8000; /// Minimum estimated token savings required to trigger a prune. See /// [`crate::CompactionConfig::prune_min_savings`]. diff --git a/crates/manifest/src/lib.rs b/crates/manifest/src/lib.rs index df78ad34..3a868860 100644 --- a/crates/manifest/src/lib.rs +++ b/crates/manifest/src/lib.rs @@ -337,9 +337,9 @@ pub enum ToolPermissionAction { /// (full history summarisation). Omitting `[compaction]` disables both. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CompactionConfig { - /// Number of recent turns protected from pruning. - #[serde(default = "default_prune_protected_turns")] - pub prune_protected_turns: usize, + /// Token budget at the history tail protected from pruning. + #[serde(default = "default_prune_protected_tokens")] + pub prune_protected_tokens: u64, /// Minimum estimated token savings to trigger a prune. #[serde(default = "default_prune_min_savings")] @@ -393,8 +393,8 @@ pub struct CompactionConfig { pub model: Option, } -fn default_prune_protected_turns() -> usize { - defaults::PRUNE_PROTECTED_TURNS +fn default_prune_protected_tokens() -> u64 { + defaults::PRUNE_PROTECTED_TOKENS } fn default_prune_min_savings() -> u64 { defaults::PRUNE_MIN_SAVINGS @@ -415,7 +415,7 @@ fn default_compact_worker_max_turns() -> Option { impl Default for CompactionConfig { fn default() -> Self { Self { - prune_protected_turns: default_prune_protected_turns(), + prune_protected_tokens: default_prune_protected_tokens(), prune_min_savings: default_prune_min_savings(), compact_threshold: None, compact_request_threshold: None, @@ -431,6 +431,7 @@ impl Default for CompactionConfig { impl PodManifest { /// Parse a manifest from a TOML string. pub fn from_toml(s: &str) -> Result { + config::reject_removed_manifest_fields(s)?; toml::from_str(s) } } @@ -581,7 +582,7 @@ model_id = "claude-sonnet-4-20250514" let toml = format!("{MINIMAL_REQUIRED}\n[compaction]\ncompact_threshold = 80000\n"); let manifest = PodManifest::from_toml(&toml).unwrap(); let c = manifest.compaction.unwrap(); - assert_eq!(c.prune_protected_turns, 3); + assert_eq!(c.prune_protected_tokens, 8000); assert_eq!(c.prune_min_savings, 4096); assert_eq!(c.compact_threshold, Some(80000)); assert_eq!(c.compact_request_threshold, None); @@ -589,6 +590,16 @@ model_id = "claude-sonnet-4-20250514" assert_eq!(c.compact_worker_max_turns, Some(20)); } + #[test] + fn reject_removed_prune_protected_turns_field() { + let toml = format!("{MINIMAL_REQUIRED}\n[compaction]\nprune_protected_turns = 3\n"); + let err = PodManifest::from_toml(&toml).unwrap_err(); + assert!( + err.to_string().contains("compaction.prune_protected_turns"), + "unexpected error: {err}" + ); + } + #[test] fn parse_compaction_worker_max_turns() { let toml = format!( diff --git a/crates/pod/src/compact/prune.rs b/crates/pod/src/compact/prune.rs index 2c62b50b..f4beb679 100644 --- a/crates/pod/src/compact/prune.rs +++ b/crates/pod/src/compact/prune.rs @@ -13,19 +13,24 @@ use llm_worker::Item; use llm_worker::llm_client::client::LlmClient; -use llm_worker::prune::{PruneConfig, PruneDecision, PruneObserver, SavingsEstimator}; +use llm_worker::prune::{ + PruneConfig, PruneDecision, PruneObserver, SavingsEstimator, TokenEstimator, +}; use session_metrics::Metric; use session_store::Store; use crate::Pod; -use crate::compact::token_counter::{EstimateSource, savings_for_prune_impl}; +use crate::compact::token_counter::{ + EstimateSource, savings_for_prune_impl, token_estimates_for_prune_impl, +}; impl Pod { /// Enable prune projection on the underlying Worker. /// - /// Registers the config and a savings-estimator closure on the Worker. - /// The estimator captures a shared handle to [`Pod::usage_history_handle`] - /// so that every LLM request sees the latest measurements. + /// Registers the config and token/savings-estimator closures on the Worker. + /// The estimators combine persisted [`Pod::usage_history_handle`] records + /// with in-flight `UsageTracker` records so multi-request tool loops can + /// prune before the surrounding Pod run finishes. /// /// Measurement-less estimates (before the first LLM call, or immediately /// after a compact) return `0` from the estimator, which naturally @@ -37,9 +42,25 @@ impl Pod { /// [`UsageTracker`] so the next `LlmUsage` can be paired with a /// `prune.post_request` metric carrying the same id. pub fn attach_prune(&mut self, config: PruneConfig) { - let usage = self.usage_history_handle(); + let usage_history_for_tokens = self.usage_history_handle(); + let usage_tracker_for_tokens = self.usage_tracker_handle(); + let token_estimator: TokenEstimator = Box::new(move |history: &[Item]| { + let mut snapshot = usage_history_for_tokens + .lock() + .expect("usage_history poisoned") + .clone(); + snapshot.extend(usage_tracker_for_tokens.records()); + token_estimates_for_prune_impl(history, &snapshot) + }); + + let usage_history_for_savings = self.usage_history_handle(); + let usage_tracker_for_savings = self.usage_tracker_handle(); let estimator: SavingsEstimator = Box::new(move |history: &[Item], indices| { - let snapshot = usage.lock().expect("usage_history poisoned").clone(); + let mut snapshot = usage_history_for_savings + .lock() + .expect("usage_history poisoned") + .clone(); + snapshot.extend(usage_tracker_for_savings.records()); let est = savings_for_prune_impl(history, &snapshot, indices); match est.source { EstimateSource::NoData => 0, @@ -56,8 +77,9 @@ impl Pod { .with_value(eval.estimated_savings as f64) .with_correlation_id(&correlation_id) .with_dimension("candidate_count", eval.candidate_count.to_string()); - if let Some(border) = eval.border_turn { - metric = metric.with_dimension("border_turn", border.to_string()); + if let Some(protected_start) = eval.protected_start_index { + metric = + metric.with_dimension("protected_start_index", protected_start.to_string()); } metrics.push(metric); usage_tracker.note_correlation_id(correlation_id); @@ -66,17 +88,21 @@ impl Pod { metrics.push(Metric::now("prune.skip").with_dimension("reason", "no_candidates")); } PruneDecision::SkippedBelowMinSavings => { - metrics.push( - Metric::now("prune.skip") - .with_dimension("reason", "below_min_savings") - .with_dimension("candidate_count", eval.candidate_count.to_string()) - .with_value(eval.estimated_savings as f64), - ); + let mut metric = Metric::now("prune.skip") + .with_dimension("reason", "below_min_savings") + .with_dimension("candidate_count", eval.candidate_count.to_string()) + .with_value(eval.estimated_savings as f64); + if let Some(protected_start) = eval.protected_start_index { + metric = + metric.with_dimension("protected_start_index", protected_start.to_string()); + } + metrics.push(metric); } }); let worker = self.worker_mut(); worker.set_prune_config(Some(config)); + worker.set_token_estimator(Some(token_estimator)); worker.set_savings_estimator(Some(estimator)); worker.set_prune_observer(Some(observer)); } @@ -90,7 +116,7 @@ impl Pod { return; }; let config = PruneConfig { - protected_turns: compaction.prune_protected_turns, + protected_tokens: compaction.prune_protected_tokens, min_savings: compaction.prune_min_savings, }; self.attach_prune(config); diff --git a/crates/pod/src/compact/token_counter.rs b/crates/pod/src/compact/token_counter.rs index dbf6cd79..02507d39 100644 --- a/crates/pod/src/compact/token_counter.rs +++ b/crates/pod/src/compact/token_counter.rs @@ -132,6 +132,21 @@ fn tool_result_content_bytes(item: &Item) -> u64 { item_bytes(item).saturating_sub(item_bytes(&cleared)) } +/// Prefix-boundary token estimates used by Prune to find its protected suffix. +/// +/// Returns `history.len() + 1` entries where entry `i` estimates +/// `history[..i]`. This shares the same [`tokens_at`] accounting as compact's +/// retained-tail split and prune's savings estimate. +pub(crate) fn token_estimates_for_prune_impl( + history: &[Item], + records: &[UsageRecord], +) -> Vec { + let prefix = prefix_bytes(history); + (0..=history.len()) + .map(|idx| tokens_at(history, records, idx, &prefix)) + .collect() +} + /// Prune 射影(`ToolResult.content = None`)で節約されるトークン数の推定。 /// /// `indices` は [`llm_worker::prune::prunable_indices`] が返す候補列を @@ -278,6 +293,26 @@ mod tests { } } + #[test] + fn token_estimates_for_prune_returns_every_prefix_boundary() { + let history = vec![msg("a"), msg("b"), msg("c")]; + let estimates = token_estimates_for_prune_impl(&history, &[record(3, 300)]); + assert_eq!(estimates.len(), history.len() + 1); + assert_eq!(estimates[0].tokens, 0); + assert_eq!(estimates[3].tokens, 300); + assert_eq!(estimates[3].source, EstimateSource::Measured); + } + + #[test] + fn token_estimates_for_prune_propagates_no_data() { + let history = vec![msg("a"), msg("b")]; + let estimates = token_estimates_for_prune_impl(&history, &[]); + assert_eq!(estimates.len(), history.len() + 1); + assert_eq!(estimates[0].source, EstimateSource::Measured); + assert_eq!(estimates[1].source, EstimateSource::NoData); + assert_eq!(estimates[2].source, EstimateSource::NoData); + } + #[test] fn savings_for_prune_skips_non_toolresult_indices() { let history = vec![msg("a"), msg("b"), msg("c")]; diff --git a/crates/pod/tests/session_metrics_test.rs b/crates/pod/tests/session_metrics_test.rs index 4b2ba875..75add356 100644 --- a/crates/pod/tests/session_metrics_test.rs +++ b/crates/pod/tests/session_metrics_test.rs @@ -4,10 +4,10 @@ //! returns a long `ToolOutput.content`, then inspects the persisted //! session log to verify: //! -//! - `prune.skip { reason: "no_candidates" }` lands when the protected-turn -//! window covers the entire history. -//! - `prune.fire` lands once enough turns + usage measurements exist for -//! the projection to actually apply. +//! - `prune.skip { reason: "no_candidates" }` lands when usage estimates are +//! unavailable or the protected-token window covers all tool results. +//! - `prune.fire` lands once enough measured history exceeds the protected-token +//! budget for the projection to actually apply. //! - The fire metric and the immediately-following `prune.post_request` //! metric share the same `correlation_id`, so cache_read / cache_write //! from the LlmUsage that triggered the projection can be joined back @@ -136,7 +136,7 @@ fn text_response_with_cache(text: &str, cache_read: u64, cache_write: u64) -> Ve ] } -fn manifest_toml(prune_protected_turns: usize, prune_min_savings: u64) -> String { +fn manifest_toml(prune_protected_tokens: u64, prune_min_savings: u64) -> String { format!( r#" [pod] @@ -151,7 +151,7 @@ model_id = "test-model" max_tokens = 100 [compaction] -prune_protected_turns = {prune_protected_turns} +prune_protected_tokens = {prune_protected_tokens} prune_min_savings = {prune_min_savings} [[scope.allow]] @@ -192,7 +192,7 @@ async fn prune_metrics_emit_skip_then_fire_with_post_request_join() { // Run 1 (request 0): tool_use → triggers tool execution → request 1 // on the second iteration to produce the assistant reply. // Run 2 (request 2): plain assistant text. Prune evaluation here - // sees user1's tool_result outside the 1-protected-turn window and + // sees user1's tool_result outside the protected-token suffix and // should fire. let client = MockClient::new(vec![ tool_use_response("call-1", "big_tool"), @@ -250,8 +250,8 @@ async fn prune_metrics_emit_skip_then_fire_with_post_request_join() { "fire missing candidate_count: {fire:?}" ); assert!( - fire.dimensions.contains_key("border_turn"), - "fire missing border_turn: {fire:?}" + fire.dimensions.contains_key("protected_start_index"), + "fire missing protected_start_index: {fire:?}" ); assert!(fire.value.is_some(), "fire missing estimated_savings value"); let fire_id = fire @@ -277,6 +277,36 @@ async fn prune_metrics_emit_skip_then_fire_with_post_request_join() { assert!(post.dimensions.contains_key("history_len")); } +#[tokio::test] +async fn prune_metrics_fire_during_single_long_task_without_multiple_user_turns() { + let client = MockClient::new(vec![ + tool_use_response("call-1", "big_tool"), + tool_use_response("call-2", "big_tool"), + tool_use_response("call-3", "big_tool"), + tool_use_response("call-4", "big_tool"), + text_response_with_cache("done", 100, 20), + ]); + let (mut pod, _store_tmp, _pwd_tmp) = make_pod(manifest_toml(1, 1), client, "big_tool").await; + let session_id = pod.session_id(); + let segment_id = pod.segment_id(); + let store = pod.store().clone(); + + pod.run_text("one long task").await.unwrap(); + + let state = session_store::restore(&store, session_id, segment_id).unwrap(); + let metrics = metrics_from_extensions(&state.extensions); + let fire_count = metrics.iter().filter(|m| m.name == "prune.fire").count(); + assert!( + fire_count > 0, + "single-turn tool loop should produce prune.fire once old heavy ToolResults fall outside the protected-token suffix: {metrics:?}" + ); + assert!( + metrics.iter().any(|m| { + m.name == "prune.fire" && m.dimensions.contains_key("protected_start_index") + }) + ); +} + /// `min_savings` set high enough that candidates exist but the estimated /// savings always fall short → the second run should record /// `prune.skip { reason: "below_min_savings" }`. @@ -288,7 +318,7 @@ async fn prune_metrics_record_below_min_savings_skip() { text_response_with_cache("done", 0, 0), ]); let (mut pod, _store_tmp, _pwd_tmp) = - make_pod(manifest_toml(1, u64::MAX), client, "big_tool").await; + make_pod(manifest_toml(1, 1_000_000), client, "big_tool").await; let session_id = pod.session_id(); let segment_id = pod.segment_id(); let store = pod.store().clone(); @@ -405,7 +435,7 @@ async fn metric_write_failure_emits_warn_alert_and_does_not_abort_run() { // Even with a tool registered, this run will only emit // `prune.skip { reason: "no_candidates" }` (one user message, - // protected_turns=1 covers everything). That is enough to drive + // protected token budget covers the only user message). That is enough to drive // the failure path: at least one metric attempts to write. let client = MockClient::new(vec![text_response_with_cache("hi", 0, 0)]); let worker = Worker::new(client); diff --git a/crates/session-metrics/src/lib.rs b/crates/session-metrics/src/lib.rs index f7deb9db..1e0c1a0c 100644 --- a/crates/session-metrics/src/lib.rs +++ b/crates/session-metrics/src/lib.rs @@ -104,7 +104,7 @@ mod tests { #[test] fn metric_round_trip_via_json() { let metric = Metric::now("prune.fire") - .with_dimension("border_turn", "3") + .with_dimension("protected_start_index", "3") .with_dimension("candidate_count", "2") .with_value(4096.0) .with_correlation_id("abc-123"); diff --git a/docs/compaction.md b/docs/compaction.md index c81498ff..63da259f 100644 --- a/docs/compaction.md +++ b/docs/compaction.md @@ -38,6 +38,7 @@ Pod::try_pre_run_compact ← proactive - **条件付き実行**: 推定トークン節約量が `min_savings` を超えた場合のみ。KV キャッシュの無駄な無効化を避ける - **リクエストコンテキストのみ操作**: history 本体は変更しない。Prune 状態を Pod が保持し、LLM リクエスト構築時に反映する +- **保護境界**: 直近 `prune_protected_tokens` 相当の suffix は残す。turn 数ではなく usage history 由来の token estimate で境界を引くため、単発の長い tool loop でも古い `ToolResult.content` が候補になる - **冪等**: `content: None` のアイテムはスキップ ### ToolOutput の構造 @@ -138,8 +139,9 @@ compact は fork と同じ構造。旧セッションを保全し、新 SessionI [compaction] compact_threshold = 80000 # ターンの合間 (proactive) compact_request_threshold = 90000 # リクエストの合間 (safety net) -retained_tokens = 8000 # 直近保護トークン数 (Prune 済みで計測) -auto_read_budget = 8000 # compact worker の mark_read_required 合計上限 +prune_protected_tokens = 8000 # prune から保護する末尾 token budget +compact_retained_tokens = 8000 # compact 後に生のまま残す末尾 token budget +compact_auto_read_budget = 8000 # compact worker の mark_read_required 合計上限 compact_worker_max_input_tokens = 50000 # compact worker 自身の現在占有トークン上限 compact_worker_max_turns = 20 # compact worker 自身の tool loop 上限 ``` diff --git a/docs/manifest.toml b/docs/manifest.toml index e874eed8..ff739af5 100644 --- a/docs/manifest.toml +++ b/docs/manifest.toml @@ -191,9 +191,9 @@ permission = "write" # セクションを書いた時点で Prune は有効化、Compact は閾値が None なら無効。 # [compaction] # -# # 任意。デフォルト: 3 (`defaults::PRUNE_PROTECTED_TURNS`)。 -# # pruning から保護する末尾ターン数。 -# prune_protected_turns = 3 +# # 任意。デフォルト: 8000 (`defaults::PRUNE_PROTECTED_TOKENS`)。 +# # pruning から保護する末尾 token budget。turn 数ではなく usage estimate で境界を引く。 +# prune_protected_tokens = 8000 # # # 任意。デフォルト: 4096 (`defaults::PRUNE_MIN_SAVINGS`)。 # # prune が発火するための最低節約 token 推定値。 diff --git a/docs/pod-factory.md b/docs/pod-factory.md index d739373c..6c133f3b 100644 --- a/docs/pod-factory.md +++ b/docs/pod-factory.md @@ -179,7 +179,7 @@ pattern = "*.env" action = "deny" [compaction] -prune_protected_turns = 3 +prune_protected_tokens = 8000 prune_min_savings = 4096 compact_threshold = 80000 compact_request_threshold = 90000