yoi/crates/llm-worker/src/prune.rs

452 lines
17 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! Prune — context projection for old tool-result content.
//!
//! LLM 送信時のコンテキストから古い [`Item::ToolResult`] の `content` を
//! 省略して、コンテキスト窓のトークンを回収する。`summary` は残すので
//! 「何が起きたか」の痕跡は保たれる。
//!
//! # 設計方針
//!
//! Prune は **コンテキスト射影** であり、history の変換ではない。
//! この crate が提供するのは pure な候補抽出 [`prunable_indices`] のみで、
//! 射影の適用は上位層(`pod::prune_hook` 等)が LLM に送る一時コンテキスト
//! に対してだけ行う。Worker の永続履歴は決して変更されない。
//!
//! 保護境界は末尾 token budget で決めるが、この crate は usage 履歴を
//! 所有しない。prefix ごとの token 推定値と savings 推定は上位層から
//! callback で注入される。
use serde::{Deserialize, Serialize};
use crate::llm_client::types::Item;
use crate::token_counter::{EstimateSource, TokenEstimate};
/// Callback that returns token estimates for every prefix boundary of the
/// supplied request history.
///
/// The returned vector must have `history.len() + 1` entries, where entry `i`
/// estimates the token count of `history[..i]`: entry `0` is the empty prefix
/// and the last entry is the estimate for the whole history.
///
/// Returning a malformed vector (wrong length), or estimates whose source is
/// [`EstimateSource::NoData`], makes prune treat the request as having no
/// candidates.
pub type TokenEstimator = Box<dyn Fn(&[Item]) -> Vec<TokenEstimate> + Send + Sync>;
/// Callback that estimates the token savings for projecting the
/// `ToolResult.content` out of `history[i]` for each `i` in `indices`.
///
/// Injected into [`crate::Worker`] via `set_savings_estimator` so the
/// Worker can make `min_savings` decisions without knowing about usage
/// measurement sources. Return `0` to signal "no data / refuse to prune".
///
/// Note that the quantity being estimated is the delta produced by setting
/// `content` to `None`, not the cost of dropping the whole range of items.
/// The item itself (its `summary`, etc.) stays in the context, so this
/// callback must return savings that agree with the actual projection.
pub type SavingsEstimator = Box<dyn Fn(&[Item], &[usize]) -> u64 + Send + Sync>;
/// Result of one prune evaluation pass, surfaced to the optional
/// [`PruneObserver`] for instrumentation.
///
/// The Worker evaluates prune once per LLM request and, when an observer is
/// registered, reports the outcome through this value. It carries the
/// fire/skip decision together with the inputs to that decision: the number
/// of candidates, the estimated savings, and the start index of the
/// protected region.
#[derive(Debug, Clone)]
pub struct PruneEvaluation {
/// Length of the `prunable_indices` result. `0` when the decision is
/// [`PruneDecision::SkippedNoCandidates`].
pub candidate_count: usize,
/// Estimated savings in tokens. `0` when the decision is
/// [`PruneDecision::SkippedNoCandidates`].
pub estimated_savings: u64,
/// Index of the first item in the suffix protected by the token budget.
/// `None` when the boundary could not be determined because the usage
/// estimates were `NoData`.
pub protected_start_index: Option<usize>,
/// The decision that was reached.
pub decision: PruneDecision,
}
/// Outcome of one prune evaluation. Each variant is one branch of the
/// "fire vs skip" decision tree the Worker walks before each LLM request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PruneDecision {
/// `prunable_indices` was empty, so nothing was done.
SkippedNoCandidates,
/// There were candidates, but the estimated savings were below
/// `min_savings`, so nothing was done.
SkippedBelowMinSavings,
/// There were candidates and savings >= `min_savings`, so the projection
/// was applied. `pruned_count` is the number of items `project()` actually
/// rewrote (candidates whose content was already `None` contribute 0).
Fired { pruned_count: usize },
}
/// Optional observer invoked after each prune evaluation, regardless of
/// branch. Upper layers such as Pod install it to emit metrics.
pub type PruneObserver = Box<dyn Fn(&PruneEvaluation) + Send + Sync>;
/// Configuration for the Prune algorithm.
///
/// Both fields fall back to their default when absent from the serialized
/// form, so a partially specified configuration deserializes successfully.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruneConfig {
/// Token budget at the history tail protected from pruning.
///
/// The protected suffix is grown backwards from the end of the history
/// until its estimated size reaches this budget, so the item that crosses
/// the budget is itself protected. `0` disables protection entirely.
/// Defaults to `8000`.
#[serde(default = "default_protected_tokens")]
pub protected_tokens: u64,
/// Minimum token savings required to actually prune. If the prunable
/// content is smaller than this, the caller should skip to avoid
/// pointless KV-cache invalidation. The unit is tokens; the caller
/// is responsible for measuring savings via a usage-history-aware
/// estimator and comparing against this threshold. Defaults to `4096`.
#[serde(default = "default_min_savings")]
pub min_savings: u64,
}
/// Serde fallback for `PruneConfig::protected_tokens`.
///
/// Returns the 8000-token tail budget used when the field is absent from the
/// serialized configuration.
fn default_protected_tokens() -> u64 {
    const FALLBACK_PROTECTED_TOKENS: u64 = 8_000;
    FALLBACK_PROTECTED_TOKENS
}
/// Serde fallback for `PruneConfig::min_savings`.
///
/// Returns the 4096-token savings threshold used when the field is absent
/// from the serialized configuration.
fn default_min_savings() -> u64 {
    const FALLBACK_MIN_SAVINGS: u64 = 4_096;
    FALLBACK_MIN_SAVINGS
}
impl Default for PruneConfig {
fn default() -> Self {
Self {
protected_tokens: default_protected_tokens(),
min_savings: default_min_savings(),
}
}
}
/// Set `content = None` on each `Item::ToolResult` at the given indices.
///
/// Intended for use on a request-context clone (never on a persistent
/// history).
///
/// # Arguments
///
/// * `items` - the context to project; modified in place.
/// * `indices` - positions in `items` to project. Entries that do not refer to
///   a `ToolResult`, and entries that are out of range for `items` (for
///   example stale indices computed against a longer history), are ignored
///   rather than causing a panic.
///
/// # Returns
///
/// The number of items that were actually modified. Items that are already
/// content-less count as 0, so a duplicated index is counted at most once and
/// calling this twice with the same indices returns 0 the second time.
pub fn project(items: &mut [Item], indices: &[usize]) -> usize {
    let mut cleared = 0;
    for &idx in indices {
        // `get_mut` instead of `items[idx]`: a bad index must not take down
        // the request path, it simply has nothing to project.
        let Some(Item::ToolResult { content, .. }) = items.get_mut(idx) else {
            continue;
        };
        // `take` leaves `None` behind and tells us whether anything was there.
        if content.take().is_some() {
            cleared += 1;
        }
    }
    cleared
}
/// Lists the positions of every `Item::ToolResult` that still carries content
/// and sits ahead of the tail protected by `protected_tokens`.
///
/// This is a pure query: `items` is never modified. It is a convenience
/// wrapper around [`evaluate_candidates`] that discards the protected-suffix
/// start index.
///
/// The result is empty when the token estimates are unusable (`NoData` or
/// malformed) or when nothing qualifies for pruning.
pub fn prunable_indices(
    items: &[Item],
    protected_tokens: u64,
    token_estimates: &[TokenEstimate],
) -> Vec<usize> {
    let (candidates, _protected_start) =
        evaluate_candidates(items, protected_tokens, token_estimates);
    candidates
}
/// Computes the prune candidates together with the boundary of the protected
/// suffix.
///
/// # Returns
///
/// A pair `(candidates, protected_start)`:
///
/// * `candidates` - indices, in ascending order, of every
///   `Item::ToolResult { content: Some(_), .. }` located before the protected
///   suffix. This is the same list [`prunable_indices`] returns.
/// * `protected_start` - index of the first protected item, or `None` when
///   the token boundary could not be determined (currently because the usage
///   estimates were `NoData` or malformed). In that case `candidates` is
///   empty.
pub fn evaluate_candidates(
    items: &[Item],
    protected_tokens: u64,
    token_estimates: &[TokenEstimate],
) -> (Vec<usize>, Option<usize>) {
    match protected_start_index(items, protected_tokens, token_estimates) {
        // No usable boundary: report nothing to prune.
        None => (Vec::new(), None),
        Some(boundary) => {
            let mut prunable = Vec::new();
            // Only the unprotected prefix is eligible.
            for (idx, item) in items[..boundary].iter().enumerate() {
                let has_content = matches!(
                    item,
                    Item::ToolResult {
                        content: Some(_),
                        ..
                    }
                );
                if has_content {
                    prunable.push(idx);
                }
            }
            (prunable, Some(boundary))
        }
    }
}
/// Locates the first item of the suffix protected by `protected_tokens`.
///
/// `token_estimates` must hold one entry per prefix boundary of `items`
/// (`items.len() + 1` entries, entry `i` covering `items[..i]`).
///
/// The suffix is extended backwards one item at a time until its estimated
/// size (whole-history estimate minus the prefix estimate) reaches the
/// budget; the item that crosses the budget is included in the protected
/// suffix. If the budget is never reached, the whole history is protected and
/// `Some(0)` is returned.
///
/// # Returns
///
/// * `None` when the estimate vector has the wrong length, when the
///   whole-history estimate is `NoData`, or when a prefix estimate consulted
///   during the backwards walk is `NoData`.
/// * `Some(items.len())` when `protected_tokens` is `0` (nothing protected).
/// * `Some(index)` of the first protected item otherwise.
fn protected_start_index(
    items: &[Item],
    protected_tokens: u64,
    token_estimates: &[TokenEstimate],
) -> Option<usize> {
    let item_count = items.len();
    if token_estimates.len() != item_count + 1 {
        return None;
    }

    // Last entry covers the whole history; the rest are the proper prefixes,
    // where `prefixes[i]` is the estimate for `items[..i]`.
    let (whole, prefixes) = token_estimates.split_last()?;
    if whole.source == EstimateSource::NoData {
        return None;
    }
    if protected_tokens == 0 {
        return Some(item_count);
    }

    let mut boundary = item_count;
    for (idx, estimate) in prefixes.iter().enumerate().rev() {
        if estimate.source == EstimateSource::NoData {
            return None;
        }
        // Item `idx` joins the protected suffix before the budget is tested,
        // so the item that tips the suffix over the budget stays protected.
        boundary = idx;
        let suffix_tokens = whole.tokens.saturating_sub(estimate.tokens);
        if suffix_tokens >= protected_tokens {
            break;
        }
    }
    Some(boundary)
}
// Unit tests for candidate selection (`prunable_indices` /
// `evaluate_candidates`) and for the `project` projection.
#[cfg(test)]
mod tests {
use super::*;
/// Helper: build a history with interleaved user messages and tool results.
///
/// Each turn contributes a user message, an assistant message ("ok"), and
/// then one tool call + tool result pair per `(summary, content)` entry.
/// A `None` content produces a tool result without content.
fn make_history(turns: &[(&str, Vec<(&str, Option<&str>)>)]) -> Vec<Item> {
let mut items = Vec::new();
for (user_msg, tool_results) in turns {
items.push(Item::user_message(*user_msg));
items.push(Item::assistant_message("ok"));
for (i, (summary, content)) in tool_results.iter().enumerate() {
let call_id = format!("call_{}", items.len() + i);
items.push(Item::tool_call(&call_id, "some_tool", "{}"));
match content {
Some(c) => items.push(Item::tool_result_with_content(&call_id, *summary, *c)),
None => items.push(Item::tool_result(&call_id, *summary)),
}
}
}
items
}
/// Helper: wrap raw prefix token counts in estimates marked `Measured`.
fn measured_prefix(tokens: &[u64]) -> Vec<TokenEstimate> {
tokens
.iter()
.copied()
.map(|tokens| TokenEstimate {
tokens,
source: EstimateSource::Measured,
})
.collect()
}
/// Helper: prefix estimates where every item costs `item_tokens` tokens.
/// Produces `items.len() + 1` entries (entry `i` is `i * item_tokens`).
fn uniform_estimates(items: &[Item], item_tokens: u64) -> Vec<TokenEstimate> {
let mut tokens = Vec::with_capacity(items.len() + 1);
for i in 0..=items.len() {
tokens.push(i as u64 * item_tokens);
}
measured_prefix(&tokens)
}
/// Helper: prefix estimates built as the running sum of per-item costs,
/// starting with 0 for the empty prefix.
fn estimates_from_item_tokens(item_tokens: &[u64]) -> Vec<TokenEstimate> {
let mut prefix = Vec::with_capacity(item_tokens.len() + 1);
let mut acc = 0;
prefix.push(acc);
for tokens in item_tokens {
acc += tokens;
prefix.push(acc);
}
measured_prefix(&prefix)
}
/// Helper: correctly sized estimates where only the empty prefix is
/// `Measured`; every other boundary (including the total) is `NoData`.
fn no_data_estimates(items: &[Item]) -> Vec<TokenEstimate> {
(0..=items.len())
.map(|i| TokenEstimate {
tokens: i as u64,
source: if i == 0 {
EstimateSource::Measured
} else {
EstimateSource::NoData
},
})
.collect()
}
/// A `NoData` total means no boundary and therefore no candidates.
#[test]
fn no_candidates_when_estimate_has_no_data() {
let items = make_history(&[("turn1", vec![("summary1", Some("big content here"))])]);
let estimates = no_data_estimates(&items);
let (candidates, protected_start) = evaluate_candidates(&items, 10, &estimates);
assert!(candidates.is_empty());
assert_eq!(protected_start, None);
}
/// A budget larger than the whole history protects everything.
#[test]
fn no_candidates_when_history_fits_in_protected_tokens() {
let items = make_history(&[
("turn1", vec![("summary1", Some("big content here"))]),
("turn2", vec![("summary2", Some("more content"))]),
]);
let estimates = uniform_estimates(&items, 10);
assert!(prunable_indices(&items, 10_000, &estimates).is_empty());
}
/// Only tool results ahead of the protected suffix are candidates.
#[test]
fn candidates_before_token_protected_suffix() {
let big = "x".repeat(4096 * 4);
let items = make_history(&[
("turn1", vec![("s1", Some(&big))]),
("turn2", vec![("s2", Some(&big))]),
("turn3", vec![("s3", Some("keep me"))]),
("turn4", vec![("s4", Some("keep me too"))]),
]);
let estimates = uniform_estimates(&items, 10);
let candidates = prunable_indices(&items, 80, &estimates);
assert_eq!(candidates.len(), 2);
// suffix budget 80 tokens protects turn3+turn4 (8 items), so only s1/s2 are candidates.
for &i in &candidates {
if let Item::ToolResult { summary, .. } = &items[i] {
assert!(summary == "s1" || summary == "s2");
} else {
panic!("non tool-result selected");
}
}
}
/// The boundary is token-based, so a single user turn with many heavy
/// tool results still yields candidates.
#[test]
fn single_long_task_gets_candidates_without_multiple_user_turns() {
let big = "x".repeat(4096 * 8);
let items = make_history(&[(
"one long task",
vec![
("s1", Some(&big)),
("s2", Some(&big)),
("s3", Some(&big)),
("s4", Some(&big)),
],
)]);
// user + assistant are cheap; every ToolCall is cheap; every ToolResult is heavy.
let item_tokens = vec![1, 1, 1, 5_000, 1, 5_000, 1, 5_000, 1, 5_000];
let estimates = estimates_from_item_tokens(&item_tokens);
let (candidates, protected_start) = evaluate_candidates(&items, 8_000, &estimates);
// Walking back from the end, the suffix first reaches 8_000 tokens at index 7.
assert_eq!(protected_start, Some(7));
assert_eq!(candidates.len(), 2);
for &i in &candidates {
if let Item::ToolResult { summary, .. } = &items[i] {
assert!(summary == "s1" || summary == "s2");
} else {
panic!("non tool-result selected");
}
}
}
/// Tool results whose content is already `None` are not candidates.
#[test]
fn already_pruned_items_excluded_from_candidates() {
let items = make_history(&[
("turn1", vec![("s1", None)]), // already pruned (content=None)
("turn2", vec![]),
("turn3", vec![]),
("turn4", vec![]),
]);
let estimates = uniform_estimates(&items, 10);
assert!(prunable_indices(&items, 20, &estimates).is_empty());
}
/// `project` clears the content of the candidates only and reports how
/// many items it changed.
#[test]
fn project_drops_content_and_counts_modifications() {
let big = "x".repeat(64);
let mut items = make_history(&[
("turn1", vec![("s1", Some(&big))]),
("turn2", vec![("s2", Some(&big))]),
("turn3", vec![("s3", Some("keep me"))]),
("turn4", vec![("s4", Some("keep me too"))]),
]);
let estimates = uniform_estimates(&items, 10);
let candidates = prunable_indices(&items, 80, &estimates);
let count = project(&mut items, &candidates);
assert_eq!(count, 2);
for item in &items {
if let Item::ToolResult {
summary, content, ..
} = item
{
if summary == "s1" || summary == "s2" {
assert!(content.is_none(), "old content should be projected out");
} else {
assert!(content.is_some(), "protected content should remain");
}
}
}
}
/// Targeting an already content-less item is a no-op that counts as 0.
#[test]
fn project_skips_already_pruned_items() {
// indices points at an item whose content is already None.
// project() should count it as 0 modifications.
let mut items = make_history(&[
("turn1", vec![("s1", None)]),
("turn2", vec![("s2", Some("hello"))]),
]);
// Manually target s1 even though it's already None.
let target = items
.iter()
.position(|it| matches!(it, Item::ToolResult { summary, .. } if summary == "s1"))
.unwrap();
let count = project(&mut items, &[target]);
assert_eq!(count, 0);
}
/// Applying the same projection twice changes nothing the second time.
#[test]
fn project_is_idempotent() {
let big = "x".repeat(64);
let mut items = make_history(&[
("turn1", vec![("s1", Some(&big))]),
("turn2", vec![]),
("turn3", vec![]),
("turn4", vec![]),
]);
let estimates = uniform_estimates(&items, 10);
let candidates = prunable_indices(&items, 20, &estimates);
assert_eq!(project(&mut items, &candidates), 1);
// Second pass: reusing the same `prunable_indices` result modifies 0 items.
assert_eq!(project(&mut items, &candidates), 0);
}
/// `evaluate_candidates` reports the protected-suffix start alongside the
/// candidates.
#[test]
fn evaluate_candidates_returns_protected_start_index() {
let big = "x".repeat(64);
let items = make_history(&[
("turn1", vec![("s1", Some(&big))]),
("turn2", vec![("s2", Some(&big))]),
("turn3", vec![("s3", Some("keep"))]),
("turn4", vec![("s4", Some("keep too"))]),
]);
let estimates = uniform_estimates(&items, 10);
let (candidates, protected_start) = evaluate_candidates(&items, 80, &estimates);
assert_eq!(candidates.len(), 2);
// protected_tokens=80 → protected suffix is turn3+turn4, starting at index 8.
assert_eq!(protected_start, Some(8));
}
/// When the budget exceeds the history, the boundary is index 0 (not `None`).
#[test]
fn evaluate_candidates_reports_zero_start_when_everything_is_protected() {
let items = make_history(&[("only", vec![("s", Some("x"))])]);
let estimates = uniform_estimates(&items, 10);
let (candidates, protected_start) = evaluate_candidates(&items, 10_000, &estimates);
assert!(candidates.is_empty());
assert_eq!(protected_start, Some(0));
}
/// A zero budget protects nothing: the boundary is `items.len()`.
#[test]
fn zero_protected_tokens_allows_all_tool_results_as_candidates() {
let big = "x".repeat(64);
let items = make_history(&[("turn1", vec![("s1", Some(&big)), ("s2", Some(&big))])]);
let estimates = uniform_estimates(&items, 10);
let (candidates, protected_start) = evaluate_candidates(&items, 0, &estimates);
assert_eq!(protected_start, Some(items.len()));
assert_eq!(candidates.len(), 2);
}
/// An estimate vector of the wrong length yields no boundary and no candidates.
#[test]
fn malformed_estimate_vector_is_treated_as_no_boundary() {
let items = make_history(&[("turn1", vec![("s1", Some("x"))])]);
let (candidates, protected_start) = evaluate_candidates(&items, 10, &[]);
assert!(candidates.is_empty());
assert_eq!(protected_start, None);
}
}