feat: protect prune tail by token budget
This commit is contained in:
parent
4072d35f81
commit
9ee7f04805
|
|
@ -11,12 +11,23 @@
|
||||||
//! 射影の適用は上位層(`pod::prune_hook` 等)が LLM に送る一時コンテキスト
|
//! 射影の適用は上位層(`pod::prune_hook` 等)が LLM に送る一時コンテキスト
|
||||||
//! に対してだけ行う。Worker の永続履歴は決して変更されない。
|
//! に対してだけ行う。Worker の永続履歴は決して変更されない。
|
||||||
//!
|
//!
|
||||||
//! `min_savings` 判定や savings 推定もこの crate には置かず、上位層が
|
//! 保護境界は末尾 token budget で決めるが、この crate は usage 履歴を
|
||||||
//! usage 履歴ベースのトークン会計と組み合わせて行う。
|
//! 所有しない。prefix ごとの token 推定値と savings 推定は上位層から
|
||||||
|
//! callback で注入される。
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::llm_client::types::Item;
|
use crate::llm_client::types::Item;
|
||||||
|
use crate::token_counter::{EstimateSource, TokenEstimate};
|
||||||
|
|
||||||
|
/// Callback that returns token estimates for every prefix boundary of the
|
||||||
|
/// supplied request history.
|
||||||
|
///
|
||||||
|
/// The returned vector must have `history.len() + 1` entries where entry `i`
|
||||||
|
/// estimates the token count of `history[..i]`. Returning a vector of any other length,
|
||||||
|
/// or estimates whose source is [`EstimateSource::NoData`], makes prune treat
|
||||||
|
/// the request as having no candidates.
|
||||||
|
pub type TokenEstimator = Box<dyn Fn(&[Item]) -> Vec<TokenEstimate> + Send + Sync>;
|
||||||
|
|
||||||
/// Callback that estimates the token savings for projecting the
|
/// Callback that estimates the token savings for projecting the
|
||||||
/// `ToolResult.content` out of `history[i]` for each `i` in `indices`.
|
/// `ToolResult.content` out of `history[i]` for each `i` in `indices`.
|
||||||
|
|
@ -35,16 +46,16 @@ pub type SavingsEstimator = Box<dyn Fn(&[Item], &[usize]) -> u64 + Send + Sync>;
|
||||||
///
|
///
|
||||||
/// Worker は LLM リクエストごとに 1 回 prune の評価をし、その結果を
|
/// Worker は LLM リクエストごとに 1 回 prune の評価をし、その結果を
|
||||||
/// (observer が登録されていれば)この値で通知する。fire/skip の判定
|
/// (observer が登録されていれば)この値で通知する。fire/skip の判定
|
||||||
/// 結果と、判定材料になった候補数 / 推定 savings / 境界ターン位置を持つ。
|
/// 結果と、判定材料になった候補数 / 推定 savings / 保護領域の先頭 index を持つ。
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct PruneEvaluation {
|
pub struct PruneEvaluation {
|
||||||
/// `prunable_indices` の長さ。`Skipped::NoCandidates` の時は 0。
|
/// `prunable_indices` の長さ。`Skipped::NoCandidates` の時は 0。
|
||||||
pub candidate_count: usize,
|
pub candidate_count: usize,
|
||||||
/// 推定された savings (tokens)。`NoCandidates` の時は 0。
|
/// 推定された savings (tokens)。`NoCandidates` の時は 0。
|
||||||
pub estimated_savings: u64,
|
pub estimated_savings: u64,
|
||||||
/// `protected_turns` 境界に当たる turn-start アイテムの index。
|
/// Token budget で保護される suffix の先頭 item index。
|
||||||
/// turn 数が `protected_turns` 以下で境界が決まらない場合は `None`。
|
/// usage 推定が `NoData`、または推定値ベクタの長さが不正で境界が決まらない場合は `None`。
|
||||||
pub border_turn: Option<usize>,
|
pub protected_start_index: Option<usize>,
|
||||||
/// 判定結果。
|
/// 判定結果。
|
||||||
pub decision: PruneDecision,
|
pub decision: PruneDecision,
|
||||||
}
|
}
|
||||||
|
|
@ -70,10 +81,9 @@ pub type PruneObserver = Box<dyn Fn(&PruneEvaluation) + Send + Sync>;
|
||||||
/// Configuration for the Prune algorithm.
|
/// Configuration for the Prune algorithm.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct PruneConfig {
|
pub struct PruneConfig {
|
||||||
/// Number of recent turns to protect from pruning.
|
/// Token budget at the history tail protected from pruning.
|
||||||
/// A "turn" starts at each user message.
|
#[serde(default = "default_protected_tokens")]
|
||||||
#[serde(default = "default_protected_turns")]
|
pub protected_tokens: u64,
|
||||||
pub protected_turns: usize,
|
|
||||||
|
|
||||||
/// Minimum token savings required to actually prune. If the prunable
|
/// Minimum token savings required to actually prune. If the prunable
|
||||||
/// content is smaller than this, the caller should skip to avoid
|
/// content is smaller than this, the caller should skip to avoid
|
||||||
|
|
@ -84,8 +94,8 @@ pub struct PruneConfig {
|
||||||
pub min_savings: u64,
|
pub min_savings: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_protected_turns() -> usize {
|
fn default_protected_tokens() -> u64 {
|
||||||
3
|
8000
|
||||||
}
|
}
|
||||||
fn default_min_savings() -> u64 {
|
fn default_min_savings() -> u64 {
|
||||||
4096
|
4096
|
||||||
|
|
@ -94,25 +104,12 @@ fn default_min_savings() -> u64 {
|
||||||
impl Default for PruneConfig {
|
impl Default for PruneConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
protected_turns: default_protected_turns(),
|
protected_tokens: default_protected_tokens(),
|
||||||
min_savings: default_min_savings(),
|
min_savings: default_min_savings(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Find indices where each "turn" begins.
|
|
||||||
///
|
|
||||||
/// A turn starts at every user message. Returns the indices of those
|
|
||||||
/// user messages in ascending order.
|
|
||||||
fn find_turn_starts(items: &[Item]) -> Vec<usize> {
|
|
||||||
items
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.filter(|(_, item)| item.is_user_message())
|
|
||||||
.map(|(i, _)| i)
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set `content = None` on each `Item::ToolResult` at the given indices.
|
/// Set `content = None` on each `Item::ToolResult` at the given indices.
|
||||||
///
|
///
|
||||||
/// Returns the number of items that were actually modified — items that
|
/// Returns the number of items that were actually modified — items that
|
||||||
|
|
@ -121,36 +118,43 @@ fn find_turn_starts(items: &[Item]) -> Vec<usize> {
|
||||||
pub fn project(items: &mut [Item], indices: &[usize]) -> usize {
|
pub fn project(items: &mut [Item], indices: &[usize]) -> usize {
|
||||||
let mut count = 0;
|
let mut count = 0;
|
||||||
for &i in indices {
|
for &i in indices {
|
||||||
if let Item::ToolResult { content, .. } = &mut items[i] {
|
if let Item::ToolResult { content, .. } = &mut items[i]
|
||||||
if content.is_some() {
|
&& content.is_some()
|
||||||
*content = None;
|
{
|
||||||
count += 1;
|
*content = None;
|
||||||
}
|
count += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
count
|
count
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Indices of `Item::ToolResult { content: Some(_), .. }` that lie outside
|
/// Indices of `Item::ToolResult { content: Some(_), .. }` that lie before
|
||||||
/// the last `protected_turns` turns. Pure: does not mutate `items`.
|
/// the suffix protected by `protected_tokens`. Pure: does not mutate `items`.
|
||||||
///
|
///
|
||||||
/// Returns an empty vector when there are too few turns or no prunable
|
/// Returns an empty vector when the estimates are malformed or `NoData`, or when
|
||||||
/// candidates.
|
/// no prunable candidates exist.
|
||||||
pub fn prunable_indices(items: &[Item], protected_turns: usize) -> Vec<usize> {
|
pub fn prunable_indices(
|
||||||
evaluate_candidates(items, protected_turns).0
|
items: &[Item],
|
||||||
|
protected_tokens: u64,
|
||||||
|
token_estimates: &[TokenEstimate],
|
||||||
|
) -> Vec<usize> {
|
||||||
|
evaluate_candidates(items, protected_tokens, token_estimates).0
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Same as [`prunable_indices`] but also returns the index of the
|
/// Same as [`prunable_indices`] but also returns the start index of the
|
||||||
/// `protected_turns` boundary (the turn-start item whose tail is
|
/// protected suffix. `None` means the token boundary could not be determined
|
||||||
/// protected). `None` when too few turns exist for a boundary to be
|
/// (currently because usage estimates were `NoData` or malformed).
|
||||||
/// defined.
|
pub fn evaluate_candidates(
|
||||||
pub fn evaluate_candidates(items: &[Item], protected_turns: usize) -> (Vec<usize>, Option<usize>) {
|
items: &[Item],
|
||||||
let turn_starts = find_turn_starts(items);
|
protected_tokens: u64,
|
||||||
if turn_starts.len() <= protected_turns {
|
token_estimates: &[TokenEstimate],
|
||||||
|
) -> (Vec<usize>, Option<usize>) {
|
||||||
|
let Some(protected_start) = protected_start_index(items, protected_tokens, token_estimates)
|
||||||
|
else {
|
||||||
return (Vec::new(), None);
|
return (Vec::new(), None);
|
||||||
}
|
};
|
||||||
let boundary = turn_starts[turn_starts.len() - protected_turns];
|
|
||||||
let candidates = items[..boundary]
|
let candidates = items[..protected_start]
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.filter_map(|(i, item)| match item {
|
.filter_map(|(i, item)| match item {
|
||||||
|
|
@ -160,7 +164,38 @@ pub fn evaluate_candidates(items: &[Item], protected_turns: usize) -> (Vec<usize
|
||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
(candidates, Some(boundary))
|
(candidates, Some(protected_start))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn protected_start_index(
|
||||||
|
items: &[Item],
|
||||||
|
protected_tokens: u64,
|
||||||
|
token_estimates: &[TokenEstimate],
|
||||||
|
) -> Option<usize> {
|
||||||
|
if token_estimates.len() != items.len() + 1 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let total = token_estimates[items.len()];
|
||||||
|
if total.source == EstimateSource::NoData {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
if protected_tokens == 0 {
|
||||||
|
return Some(items.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut protected_start = items.len();
|
||||||
|
for idx in (0..items.len()).rev() {
|
||||||
|
let prefix = token_estimates[idx];
|
||||||
|
if prefix.source == EstimateSource::NoData {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
protected_start = idx;
|
||||||
|
let tail_tokens = total.tokens.saturating_sub(prefix.tokens);
|
||||||
|
if tail_tokens >= protected_tokens {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(protected_start)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -185,17 +220,70 @@ mod tests {
|
||||||
items
|
items
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn measured_prefix(tokens: &[u64]) -> Vec<TokenEstimate> {
|
||||||
|
tokens
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
|
.map(|tokens| TokenEstimate {
|
||||||
|
tokens,
|
||||||
|
source: EstimateSource::Measured,
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn uniform_estimates(items: &[Item], item_tokens: u64) -> Vec<TokenEstimate> {
|
||||||
|
let mut tokens = Vec::with_capacity(items.len() + 1);
|
||||||
|
for i in 0..=items.len() {
|
||||||
|
tokens.push(i as u64 * item_tokens);
|
||||||
|
}
|
||||||
|
measured_prefix(&tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn estimates_from_item_tokens(item_tokens: &[u64]) -> Vec<TokenEstimate> {
|
||||||
|
let mut prefix = Vec::with_capacity(item_tokens.len() + 1);
|
||||||
|
let mut acc = 0;
|
||||||
|
prefix.push(acc);
|
||||||
|
for tokens in item_tokens {
|
||||||
|
acc += tokens;
|
||||||
|
prefix.push(acc);
|
||||||
|
}
|
||||||
|
measured_prefix(&prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn no_data_estimates(items: &[Item]) -> Vec<TokenEstimate> {
|
||||||
|
(0..=items.len())
|
||||||
|
.map(|i| TokenEstimate {
|
||||||
|
tokens: i as u64,
|
||||||
|
source: if i == 0 {
|
||||||
|
EstimateSource::Measured
|
||||||
|
} else {
|
||||||
|
EstimateSource::NoData
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn no_candidates_when_too_few_turns() {
|
fn no_candidates_when_estimate_has_no_data() {
|
||||||
|
let items = make_history(&[("turn1", vec![("summary1", Some("big content here"))])]);
|
||||||
|
let estimates = no_data_estimates(&items);
|
||||||
|
let (candidates, protected_start) = evaluate_candidates(&items, 10, &estimates);
|
||||||
|
assert!(candidates.is_empty());
|
||||||
|
assert_eq!(protected_start, None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn no_candidates_when_history_fits_in_protected_tokens() {
|
||||||
let items = make_history(&[
|
let items = make_history(&[
|
||||||
("turn1", vec![("summary1", Some("big content here"))]),
|
("turn1", vec![("summary1", Some("big content here"))]),
|
||||||
("turn2", vec![("summary2", Some("more content"))]),
|
("turn2", vec![("summary2", Some("more content"))]),
|
||||||
]);
|
]);
|
||||||
assert!(prunable_indices(&items, 3).is_empty());
|
let estimates = uniform_estimates(&items, 10);
|
||||||
|
assert!(prunable_indices(&items, 10_000, &estimates).is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn candidates_in_unprotected_turns() {
|
fn candidates_before_token_protected_suffix() {
|
||||||
let big = "x".repeat(4096 * 4);
|
let big = "x".repeat(4096 * 4);
|
||||||
let items = make_history(&[
|
let items = make_history(&[
|
||||||
("turn1", vec![("s1", Some(&big))]),
|
("turn1", vec![("s1", Some(&big))]),
|
||||||
|
|
@ -203,9 +291,39 @@ mod tests {
|
||||||
("turn3", vec![("s3", Some("keep me"))]),
|
("turn3", vec![("s3", Some("keep me"))]),
|
||||||
("turn4", vec![("s4", Some("keep me too"))]),
|
("turn4", vec![("s4", Some("keep me too"))]),
|
||||||
]);
|
]);
|
||||||
let candidates = prunable_indices(&items, 2);
|
let estimates = uniform_estimates(&items, 10);
|
||||||
|
let candidates = prunable_indices(&items, 80, &estimates);
|
||||||
|
assert_eq!(candidates.len(), 2);
|
||||||
|
// suffix budget 80 tokens protects turn3+turn4 (8 items), so only s1/s2 are candidates.
|
||||||
|
for &i in &candidates {
|
||||||
|
if let Item::ToolResult { summary, .. } = &items[i] {
|
||||||
|
assert!(summary == "s1" || summary == "s2");
|
||||||
|
} else {
|
||||||
|
panic!("non tool-result selected");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn single_long_task_gets_candidates_without_multiple_user_turns() {
|
||||||
|
let big = "x".repeat(4096 * 8);
|
||||||
|
let items = make_history(&[(
|
||||||
|
"one long task",
|
||||||
|
vec![
|
||||||
|
("s1", Some(&big)),
|
||||||
|
("s2", Some(&big)),
|
||||||
|
("s3", Some(&big)),
|
||||||
|
("s4", Some(&big)),
|
||||||
|
],
|
||||||
|
)]);
|
||||||
|
// user + assistant are cheap; every ToolCall is cheap; every ToolResult is heavy.
|
||||||
|
let item_tokens = vec![1, 1, 1, 5_000, 1, 5_000, 1, 5_000, 1, 5_000];
|
||||||
|
let estimates = estimates_from_item_tokens(&item_tokens);
|
||||||
|
|
||||||
|
let (candidates, protected_start) = evaluate_candidates(&items, 8_000, &estimates);
|
||||||
|
|
||||||
|
assert_eq!(protected_start, Some(7));
|
||||||
assert_eq!(candidates.len(), 2);
|
assert_eq!(candidates.len(), 2);
|
||||||
// 候補は turn1 と turn2 の ToolResult のみ
|
|
||||||
for &i in &candidates {
|
for &i in &candidates {
|
||||||
if let Item::ToolResult { summary, .. } = &items[i] {
|
if let Item::ToolResult { summary, .. } = &items[i] {
|
||||||
assert!(summary == "s1" || summary == "s2");
|
assert!(summary == "s1" || summary == "s2");
|
||||||
|
|
@ -223,7 +341,8 @@ mod tests {
|
||||||
("turn3", vec![]),
|
("turn3", vec![]),
|
||||||
("turn4", vec![]),
|
("turn4", vec![]),
|
||||||
]);
|
]);
|
||||||
assert!(prunable_indices(&items, 2).is_empty());
|
let estimates = uniform_estimates(&items, 10);
|
||||||
|
assert!(prunable_indices(&items, 20, &estimates).is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -235,7 +354,8 @@ mod tests {
|
||||||
("turn3", vec![("s3", Some("keep me"))]),
|
("turn3", vec![("s3", Some("keep me"))]),
|
||||||
("turn4", vec![("s4", Some("keep me too"))]),
|
("turn4", vec![("s4", Some("keep me too"))]),
|
||||||
]);
|
]);
|
||||||
let candidates = prunable_indices(&items, 2);
|
let estimates = uniform_estimates(&items, 10);
|
||||||
|
let candidates = prunable_indices(&items, 80, &estimates);
|
||||||
let count = project(&mut items, &candidates);
|
let count = project(&mut items, &candidates);
|
||||||
assert_eq!(count, 2);
|
assert_eq!(count, 2);
|
||||||
|
|
||||||
|
|
@ -261,7 +381,7 @@ mod tests {
|
||||||
("turn1", vec![("s1", None)]),
|
("turn1", vec![("s1", None)]),
|
||||||
("turn2", vec![("s2", Some("hello"))]),
|
("turn2", vec![("s2", Some("hello"))]),
|
||||||
]);
|
]);
|
||||||
// Manually target s1 (index 3) even though it's already None.
|
// Manually target s1 even though it's already None.
|
||||||
let target = items
|
let target = items
|
||||||
.iter()
|
.iter()
|
||||||
.position(|it| matches!(it, Item::ToolResult { summary, .. } if summary == "s1"))
|
.position(|it| matches!(it, Item::ToolResult { summary, .. } if summary == "s1"))
|
||||||
|
|
@ -279,14 +399,15 @@ mod tests {
|
||||||
("turn3", vec![]),
|
("turn3", vec![]),
|
||||||
("turn4", vec![]),
|
("turn4", vec![]),
|
||||||
]);
|
]);
|
||||||
let candidates = prunable_indices(&items, 2);
|
let estimates = uniform_estimates(&items, 10);
|
||||||
|
let candidates = prunable_indices(&items, 20, &estimates);
|
||||||
assert_eq!(project(&mut items, &candidates), 1);
|
assert_eq!(project(&mut items, &candidates), 1);
|
||||||
// 2 周目: 候補は一度の prunable_indices 結果を使い回しても 0 件。
|
// 2 周目: 候補は一度の prunable_indices 結果を使い回しても 0 件。
|
||||||
assert_eq!(project(&mut items, &candidates), 0);
|
assert_eq!(project(&mut items, &candidates), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn evaluate_candidates_returns_boundary_index() {
|
fn evaluate_candidates_returns_protected_start_index() {
|
||||||
let big = "x".repeat(64);
|
let big = "x".repeat(64);
|
||||||
let items = make_history(&[
|
let items = make_history(&[
|
||||||
("turn1", vec![("s1", Some(&big))]),
|
("turn1", vec![("s1", Some(&big))]),
|
||||||
|
|
@ -294,36 +415,37 @@ mod tests {
|
||||||
("turn3", vec![("s3", Some("keep"))]),
|
("turn3", vec![("s3", Some("keep"))]),
|
||||||
("turn4", vec![("s4", Some("keep too"))]),
|
("turn4", vec![("s4", Some("keep too"))]),
|
||||||
]);
|
]);
|
||||||
let (candidates, border) = evaluate_candidates(&items, 2);
|
let estimates = uniform_estimates(&items, 10);
|
||||||
|
let (candidates, protected_start) = evaluate_candidates(&items, 80, &estimates);
|
||||||
assert_eq!(candidates.len(), 2);
|
assert_eq!(candidates.len(), 2);
|
||||||
// protected_turns=2 → boundary は turn3 の user message 位置。
|
// protected_tokens=80 → protected suffix is turn3+turn4, starting at index 8.
|
||||||
// turn1: u/a/c/r (4) + turn2: u/a/c/r (4) = index 8 (turn3 の user)。
|
assert_eq!(protected_start, Some(8));
|
||||||
assert_eq!(border, Some(8));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn evaluate_candidates_no_boundary_when_too_few_turns() {
|
fn evaluate_candidates_reports_zero_start_when_everything_is_protected() {
|
||||||
let items = make_history(&[("only", vec![("s", Some("x"))])]);
|
let items = make_history(&[("only", vec![("s", Some("x"))])]);
|
||||||
let (candidates, border) = evaluate_candidates(&items, 2);
|
let estimates = uniform_estimates(&items, 10);
|
||||||
|
let (candidates, protected_start) = evaluate_candidates(&items, 10_000, &estimates);
|
||||||
assert!(candidates.is_empty());
|
assert!(candidates.is_empty());
|
||||||
assert!(border.is_none());
|
assert_eq!(protected_start, Some(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn protected_turns_boundary_exact() {
|
fn zero_protected_tokens_allows_all_tool_results_as_candidates() {
|
||||||
// 3 turns with protected_turns=2: only turn 1 is a candidate.
|
|
||||||
let big = "x".repeat(64);
|
let big = "x".repeat(64);
|
||||||
let items = make_history(&[
|
let items = make_history(&[("turn1", vec![("s1", Some(&big)), ("s2", Some(&big))])]);
|
||||||
("turn1", vec![("s1", Some(&big))]),
|
let estimates = uniform_estimates(&items, 10);
|
||||||
("turn2", vec![("s2", Some("protected"))]),
|
let (candidates, protected_start) = evaluate_candidates(&items, 0, &estimates);
|
||||||
("turn3", vec![("s3", Some("also protected"))]),
|
assert_eq!(protected_start, Some(items.len()));
|
||||||
]);
|
assert_eq!(candidates.len(), 2);
|
||||||
let candidates = prunable_indices(&items, 2);
|
}
|
||||||
assert_eq!(candidates.len(), 1);
|
|
||||||
if let Item::ToolResult { summary, .. } = &items[candidates[0]] {
|
#[test]
|
||||||
assert_eq!(summary, "s1");
|
fn malformed_estimate_vector_is_treated_as_no_boundary() {
|
||||||
} else {
|
let items = make_history(&[("turn1", vec![("s1", Some("x"))])]);
|
||||||
panic!("expected ToolResult at candidate index");
|
let (candidates, protected_start) = evaluate_candidates(&items, 10, &[]);
|
||||||
}
|
assert!(candidates.is_empty());
|
||||||
|
assert_eq!(protected_start, None);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -201,6 +201,10 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
||||||
tool_output_limits: Option<ToolOutputLimits>,
|
tool_output_limits: Option<ToolOutputLimits>,
|
||||||
/// Prune configuration. `None` disables the prune projection.
|
/// Prune configuration. `None` disables the prune projection.
|
||||||
prune_config: Option<crate::prune::PruneConfig>,
|
prune_config: Option<crate::prune::PruneConfig>,
|
||||||
|
/// Callback that estimates prefix token counts, injected by higher
|
||||||
|
/// layers that own usage measurements. `None` disables the prune
|
||||||
|
/// projection.
|
||||||
|
token_estimator: Option<crate::prune::TokenEstimator>,
|
||||||
/// Callback that estimates token savings for a drop range, injected
|
/// Callback that estimates token savings for a drop range, injected
|
||||||
/// by higher layers that own usage measurements. `None` disables
|
/// by higher layers that own usage measurements. `None` disables
|
||||||
/// the prune projection.
|
/// the prune projection.
|
||||||
|
|
@ -434,6 +438,17 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
self.prune_config = config;
|
self.prune_config = config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Inject the callback used to estimate prefix token counts for prune's
|
||||||
|
/// protected-token boundary.
|
||||||
|
///
|
||||||
|
/// The callback is invoked with the *request context* (a clone of
|
||||||
|
/// history). It must be pure/idempotent since it may be called once per
|
||||||
|
/// LLM request. Returning `NoData` estimates makes prune skip as if no
|
||||||
|
/// candidates existed.
|
||||||
|
pub fn set_token_estimator(&mut self, estimator: Option<crate::prune::TokenEstimator>) {
|
||||||
|
self.token_estimator = estimator;
|
||||||
|
}
|
||||||
|
|
||||||
/// Inject the callback used to estimate token savings for a prune
|
/// Inject the callback used to estimate token savings for a prune
|
||||||
/// candidate range.
|
/// candidate range.
|
||||||
///
|
///
|
||||||
|
|
@ -983,18 +998,26 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
// prunable candidates whose estimated savings meet the
|
// prunable candidates whose estimated savings meet the
|
||||||
// threshold. Worker does not own usage history itself; the
|
// threshold. Worker does not own usage history itself; the
|
||||||
// estimator is injected by the layer that does.
|
// estimator is injected by the layer that does.
|
||||||
if let (Some(config), Some(estimator)) = (&self.prune_config, &self.savings_estimator) {
|
if let (Some(config), Some(token_estimator), Some(savings_estimator)) = (
|
||||||
let (candidates, border_turn) =
|
&self.prune_config,
|
||||||
crate::prune::evaluate_candidates(&request_context, config.protected_turns);
|
&self.token_estimator,
|
||||||
|
&self.savings_estimator,
|
||||||
|
) {
|
||||||
|
let token_estimates = token_estimator(&request_context);
|
||||||
|
let (candidates, protected_start_index) = crate::prune::evaluate_candidates(
|
||||||
|
&request_context,
|
||||||
|
config.protected_tokens,
|
||||||
|
&token_estimates,
|
||||||
|
);
|
||||||
let evaluation = if candidates.is_empty() {
|
let evaluation = if candidates.is_empty() {
|
||||||
crate::prune::PruneEvaluation {
|
crate::prune::PruneEvaluation {
|
||||||
candidate_count: 0,
|
candidate_count: 0,
|
||||||
estimated_savings: 0,
|
estimated_savings: 0,
|
||||||
border_turn,
|
protected_start_index,
|
||||||
decision: crate::prune::PruneDecision::SkippedNoCandidates,
|
decision: crate::prune::PruneDecision::SkippedNoCandidates,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let savings = estimator(&request_context, &candidates);
|
let savings = savings_estimator(&request_context, &candidates);
|
||||||
if savings >= config.min_savings {
|
if savings >= config.min_savings {
|
||||||
let pruned = crate::prune::project(&mut request_context, &candidates);
|
let pruned = crate::prune::project(&mut request_context, &candidates);
|
||||||
if pruned > 0 {
|
if pruned > 0 {
|
||||||
|
|
@ -1007,7 +1030,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
crate::prune::PruneEvaluation {
|
crate::prune::PruneEvaluation {
|
||||||
candidate_count: candidates.len(),
|
candidate_count: candidates.len(),
|
||||||
estimated_savings: savings,
|
estimated_savings: savings,
|
||||||
border_turn,
|
protected_start_index,
|
||||||
decision: crate::prune::PruneDecision::Fired {
|
decision: crate::prune::PruneDecision::Fired {
|
||||||
pruned_count: pruned,
|
pruned_count: pruned,
|
||||||
},
|
},
|
||||||
|
|
@ -1016,7 +1039,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
crate::prune::PruneEvaluation {
|
crate::prune::PruneEvaluation {
|
||||||
candidate_count: candidates.len(),
|
candidate_count: candidates.len(),
|
||||||
estimated_savings: savings,
|
estimated_savings: savings,
|
||||||
border_turn,
|
protected_start_index,
|
||||||
decision: crate::prune::PruneDecision::SkippedBelowMinSavings,
|
decision: crate::prune::PruneDecision::SkippedBelowMinSavings,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1256,6 +1279,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
cancel_rx,
|
cancel_rx,
|
||||||
tool_output_limits: None,
|
tool_output_limits: None,
|
||||||
prune_config: None,
|
prune_config: None,
|
||||||
|
token_estimator: None,
|
||||||
savings_estimator: None,
|
savings_estimator: None,
|
||||||
prune_observer: None,
|
prune_observer: None,
|
||||||
cache_anchor: None,
|
cache_anchor: None,
|
||||||
|
|
@ -1519,6 +1543,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
cancel_rx: self.cancel_rx,
|
cancel_rx: self.cancel_rx,
|
||||||
tool_output_limits: self.tool_output_limits,
|
tool_output_limits: self.tool_output_limits,
|
||||||
prune_config: self.prune_config,
|
prune_config: self.prune_config,
|
||||||
|
token_estimator: self.token_estimator,
|
||||||
savings_estimator: self.savings_estimator,
|
savings_estimator: self.savings_estimator,
|
||||||
prune_observer: self.prune_observer,
|
prune_observer: self.prune_observer,
|
||||||
cache_anchor: self.cache_anchor,
|
cache_anchor: self.cache_anchor,
|
||||||
|
|
@ -1605,6 +1630,7 @@ impl<C: LlmClient> Worker<C, Locked> {
|
||||||
cancel_rx: self.cancel_rx,
|
cancel_rx: self.cancel_rx,
|
||||||
tool_output_limits: self.tool_output_limits,
|
tool_output_limits: self.tool_output_limits,
|
||||||
prune_config: self.prune_config,
|
prune_config: self.prune_config,
|
||||||
|
token_estimator: self.token_estimator,
|
||||||
savings_estimator: self.savings_estimator,
|
savings_estimator: self.savings_estimator,
|
||||||
prune_observer: self.prune_observer,
|
prune_observer: self.prune_observer,
|
||||||
cache_anchor: self.cache_anchor,
|
cache_anchor: self.cache_anchor,
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ use std::collections::HashMap;
|
||||||
use std::num::NonZeroU32;
|
use std::num::NonZeroU32;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
use serde::de::Error as _;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::defaults;
|
use crate::defaults;
|
||||||
|
|
@ -112,7 +113,7 @@ pub struct PermissionConfigPartial {
|
||||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||||
pub struct CompactionConfigPartial {
|
pub struct CompactionConfigPartial {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub prune_protected_turns: Option<usize>,
|
pub prune_protected_tokens: Option<u64>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub prune_min_savings: Option<u64>,
|
pub prune_min_savings: Option<u64>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
|
@ -141,12 +142,31 @@ pub enum ResolveError {
|
||||||
RelativePath { field: &'static str, path: PathBuf },
|
RelativePath { field: &'static str, path: PathBuf },
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Reject manifest fields that were intentionally removed and must not be
|
||||||
|
/// silently swallowed by the general warn-and-ignore unknown-field policy.
|
||||||
|
pub(crate) fn reject_removed_manifest_fields(s: &str) -> Result<(), toml::de::Error> {
|
||||||
|
let value: toml::Value = toml::from_str(s)?;
|
||||||
|
if value
|
||||||
|
.get("compaction")
|
||||||
|
.and_then(toml::Value::as_table)
|
||||||
|
.is_some_and(|table| table.contains_key("prune_protected_turns"))
|
||||||
|
{
|
||||||
|
return Err(toml::de::Error::custom(
|
||||||
|
"unknown field in manifest: compaction.prune_protected_turns \
|
||||||
|
(removed; use compaction.prune_protected_tokens)",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
impl PodManifestConfig {
|
impl PodManifestConfig {
|
||||||
/// Parse a partial manifest from a TOML string. Unknown top-level or
|
/// Parse a partial manifest from a TOML string. Unknown top-level or
|
||||||
/// nested fields emit a `tracing::warn!` and are ignored; use
|
/// nested fields emit a `tracing::warn!` and are ignored; use
|
||||||
/// `tracing_subscriber` with `WARN` enabled to surface them to the
|
/// `tracing_subscriber` with `WARN` enabled to surface them to the
|
||||||
/// operator.
|
/// operator. Removed fields that must not be silently ignored (currently
|
||||||
|
/// `compaction.prune_protected_turns`) are rejected before deserialization.
|
||||||
pub fn from_toml(s: &str) -> Result<Self, toml::de::Error> {
|
pub fn from_toml(s: &str) -> Result<Self, toml::de::Error> {
|
||||||
|
reject_removed_manifest_fields(s)?;
|
||||||
let de = toml::Deserializer::parse(s)?;
|
let de = toml::Deserializer::parse(s)?;
|
||||||
serde_ignored::deserialize(de, |path| {
|
serde_ignored::deserialize(de, |path| {
|
||||||
tracing::warn!("unknown field in manifest: {}", path);
|
tracing::warn!("unknown field in manifest: {}", path);
|
||||||
|
|
@ -339,7 +359,7 @@ impl PermissionConfigPartial {
|
||||||
impl CompactionConfigPartial {
|
impl CompactionConfigPartial {
|
||||||
fn merge(self, upper: Self) -> Self {
|
fn merge(self, upper: Self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
prune_protected_turns: upper.prune_protected_turns.or(self.prune_protected_turns),
|
prune_protected_tokens: upper.prune_protected_tokens.or(self.prune_protected_tokens),
|
||||||
prune_min_savings: upper.prune_min_savings.or(self.prune_min_savings),
|
prune_min_savings: upper.prune_min_savings.or(self.prune_min_savings),
|
||||||
compact_threshold: upper.compact_threshold.or(self.compact_threshold),
|
compact_threshold: upper.compact_threshold.or(self.compact_threshold),
|
||||||
compact_request_threshold: upper
|
compact_request_threshold: upper
|
||||||
|
|
@ -489,9 +509,9 @@ impl TryFrom<PodManifestConfig> for PodManifest {
|
||||||
validate_model_paths(cm, "compaction.model.auth.file")?;
|
validate_model_paths(cm, "compaction.model.auth.file")?;
|
||||||
}
|
}
|
||||||
Ok(CompactionConfig {
|
Ok(CompactionConfig {
|
||||||
prune_protected_turns: c
|
prune_protected_tokens: c
|
||||||
.prune_protected_turns
|
.prune_protected_tokens
|
||||||
.unwrap_or(defaults::PRUNE_PROTECTED_TURNS),
|
.unwrap_or(defaults::PRUNE_PROTECTED_TOKENS),
|
||||||
prune_min_savings: c.prune_min_savings.unwrap_or(defaults::PRUNE_MIN_SAVINGS),
|
prune_min_savings: c.prune_min_savings.unwrap_or(defaults::PRUNE_MIN_SAVINGS),
|
||||||
compact_threshold: c.compact_threshold,
|
compact_threshold: c.compact_threshold,
|
||||||
compact_request_threshold: c.compact_request_threshold,
|
compact_request_threshold: c.compact_request_threshold,
|
||||||
|
|
@ -921,7 +941,7 @@ mod tests {
|
||||||
let lower = PodManifestConfig {
|
let lower = PodManifestConfig {
|
||||||
compaction: Some(CompactionConfigPartial {
|
compaction: Some(CompactionConfigPartial {
|
||||||
compact_threshold: Some(50_000),
|
compact_threshold: Some(50_000),
|
||||||
prune_protected_turns: Some(5),
|
prune_protected_tokens: Some(5_000),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}),
|
}),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
|
|
@ -937,7 +957,7 @@ mod tests {
|
||||||
let c = merged.compaction.unwrap();
|
let c = merged.compaction.unwrap();
|
||||||
assert_eq!(c.compact_threshold, Some(80_000));
|
assert_eq!(c.compact_threshold, Some(80_000));
|
||||||
// field from lower retained when upper has None
|
// field from lower retained when upper has None
|
||||||
assert_eq!(c.prune_protected_turns, Some(5));
|
assert_eq!(c.prune_protected_tokens, Some(5_000));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -971,6 +991,19 @@ unknown_future_field = "tolerated"
|
||||||
assert_eq!(cfg.worker.max_tokens, Some(1000));
|
assert_eq!(cfg.worker.max_tokens, Some(1000));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn from_toml_rejects_removed_prune_protected_turns_field() {
|
||||||
|
let bad = r#"
|
||||||
|
[compaction]
|
||||||
|
prune_protected_turns = 3
|
||||||
|
"#;
|
||||||
|
let err = PodManifestConfig::from_toml(bad).unwrap_err();
|
||||||
|
assert!(
|
||||||
|
err.to_string().contains("compaction.prune_protected_turns"),
|
||||||
|
"unexpected error: {err}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn from_toml_accepts_worker_reasoning_string_or_integer() {
|
fn from_toml_accepts_worker_reasoning_string_or_integer() {
|
||||||
let effort = PodManifestConfig::from_toml(
|
let effort = PodManifestConfig::from_toml(
|
||||||
|
|
|
||||||
|
|
@ -14,9 +14,9 @@ pub const TOOL_OUTPUT_MAX_BYTES: usize = 64 * 1024;
|
||||||
/// See [`crate::FileUploadLimits`].
|
/// See [`crate::FileUploadLimits`].
|
||||||
pub const FILE_UPLOAD_MAX_BYTES: usize = 256 * 1024;
|
pub const FILE_UPLOAD_MAX_BYTES: usize = 256 * 1024;
|
||||||
|
|
||||||
/// Number of most-recent turns protected from pruning. See
|
/// Token budget at the history tail protected from pruning. See
|
||||||
/// [`crate::CompactionConfig::prune_protected_turns`].
|
/// [`crate::CompactionConfig::prune_protected_tokens`].
|
||||||
pub const PRUNE_PROTECTED_TURNS: usize = 3;
|
pub const PRUNE_PROTECTED_TOKENS: u64 = 8000;
|
||||||
|
|
||||||
/// Minimum estimated token savings required to trigger a prune. See
|
/// Minimum estimated token savings required to trigger a prune. See
|
||||||
/// [`crate::CompactionConfig::prune_min_savings`].
|
/// [`crate::CompactionConfig::prune_min_savings`].
|
||||||
|
|
|
||||||
|
|
@ -337,9 +337,9 @@ pub enum ToolPermissionAction {
|
||||||
/// (full history summarisation). Omitting `[compaction]` disables both.
|
/// (full history summarisation). Omitting `[compaction]` disables both.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct CompactionConfig {
|
pub struct CompactionConfig {
|
||||||
/// Number of recent turns protected from pruning.
|
/// Token budget at the history tail protected from pruning.
|
||||||
#[serde(default = "default_prune_protected_turns")]
|
#[serde(default = "default_prune_protected_tokens")]
|
||||||
pub prune_protected_turns: usize,
|
pub prune_protected_tokens: u64,
|
||||||
|
|
||||||
/// Minimum estimated token savings to trigger a prune.
|
/// Minimum estimated token savings to trigger a prune.
|
||||||
#[serde(default = "default_prune_min_savings")]
|
#[serde(default = "default_prune_min_savings")]
|
||||||
|
|
@ -393,8 +393,8 @@ pub struct CompactionConfig {
|
||||||
pub model: Option<ModelManifest>,
|
pub model: Option<ModelManifest>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_prune_protected_turns() -> usize {
|
fn default_prune_protected_tokens() -> u64 {
|
||||||
defaults::PRUNE_PROTECTED_TURNS
|
defaults::PRUNE_PROTECTED_TOKENS
|
||||||
}
|
}
|
||||||
fn default_prune_min_savings() -> u64 {
|
fn default_prune_min_savings() -> u64 {
|
||||||
defaults::PRUNE_MIN_SAVINGS
|
defaults::PRUNE_MIN_SAVINGS
|
||||||
|
|
@ -415,7 +415,7 @@ fn default_compact_worker_max_turns() -> Option<u32> {
|
||||||
impl Default for CompactionConfig {
|
impl Default for CompactionConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
prune_protected_turns: default_prune_protected_turns(),
|
prune_protected_tokens: default_prune_protected_tokens(),
|
||||||
prune_min_savings: default_prune_min_savings(),
|
prune_min_savings: default_prune_min_savings(),
|
||||||
compact_threshold: None,
|
compact_threshold: None,
|
||||||
compact_request_threshold: None,
|
compact_request_threshold: None,
|
||||||
|
|
@ -431,6 +431,7 @@ impl Default for CompactionConfig {
|
||||||
impl PodManifest {
|
impl PodManifest {
|
||||||
/// Parse a manifest from a TOML string.
|
/// Parse a manifest from a TOML string.
|
||||||
pub fn from_toml(s: &str) -> Result<Self, toml::de::Error> {
|
pub fn from_toml(s: &str) -> Result<Self, toml::de::Error> {
|
||||||
|
config::reject_removed_manifest_fields(s)?;
|
||||||
toml::from_str(s)
|
toml::from_str(s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -581,7 +582,7 @@ model_id = "claude-sonnet-4-20250514"
|
||||||
let toml = format!("{MINIMAL_REQUIRED}\n[compaction]\ncompact_threshold = 80000\n");
|
let toml = format!("{MINIMAL_REQUIRED}\n[compaction]\ncompact_threshold = 80000\n");
|
||||||
let manifest = PodManifest::from_toml(&toml).unwrap();
|
let manifest = PodManifest::from_toml(&toml).unwrap();
|
||||||
let c = manifest.compaction.unwrap();
|
let c = manifest.compaction.unwrap();
|
||||||
assert_eq!(c.prune_protected_turns, 3);
|
assert_eq!(c.prune_protected_tokens, 8000);
|
||||||
assert_eq!(c.prune_min_savings, 4096);
|
assert_eq!(c.prune_min_savings, 4096);
|
||||||
assert_eq!(c.compact_threshold, Some(80000));
|
assert_eq!(c.compact_threshold, Some(80000));
|
||||||
assert_eq!(c.compact_request_threshold, None);
|
assert_eq!(c.compact_request_threshold, None);
|
||||||
|
|
@ -589,6 +590,16 @@ model_id = "claude-sonnet-4-20250514"
|
||||||
assert_eq!(c.compact_worker_max_turns, Some(20));
|
assert_eq!(c.compact_worker_max_turns, Some(20));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn reject_removed_prune_protected_turns_field() {
|
||||||
|
let toml = format!("{MINIMAL_REQUIRED}\n[compaction]\nprune_protected_turns = 3\n");
|
||||||
|
let err = PodManifest::from_toml(&toml).unwrap_err();
|
||||||
|
assert!(
|
||||||
|
err.to_string().contains("compaction.prune_protected_turns"),
|
||||||
|
"unexpected error: {err}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn parse_compaction_worker_max_turns() {
|
fn parse_compaction_worker_max_turns() {
|
||||||
let toml = format!(
|
let toml = format!(
|
||||||
|
|
|
||||||
|
|
@ -13,19 +13,24 @@
|
||||||
|
|
||||||
use llm_worker::Item;
|
use llm_worker::Item;
|
||||||
use llm_worker::llm_client::client::LlmClient;
|
use llm_worker::llm_client::client::LlmClient;
|
||||||
use llm_worker::prune::{PruneConfig, PruneDecision, PruneObserver, SavingsEstimator};
|
use llm_worker::prune::{
|
||||||
|
PruneConfig, PruneDecision, PruneObserver, SavingsEstimator, TokenEstimator,
|
||||||
|
};
|
||||||
use session_metrics::Metric;
|
use session_metrics::Metric;
|
||||||
use session_store::Store;
|
use session_store::Store;
|
||||||
|
|
||||||
use crate::Pod;
|
use crate::Pod;
|
||||||
use crate::compact::token_counter::{EstimateSource, savings_for_prune_impl};
|
use crate::compact::token_counter::{
|
||||||
|
EstimateSource, savings_for_prune_impl, token_estimates_for_prune_impl,
|
||||||
|
};
|
||||||
|
|
||||||
impl<C: LlmClient, St: Store> Pod<C, St> {
|
impl<C: LlmClient, St: Store> Pod<C, St> {
|
||||||
/// Enable prune projection on the underlying Worker.
|
/// Enable prune projection on the underlying Worker.
|
||||||
///
|
///
|
||||||
/// Registers the config and a savings-estimator closure on the Worker.
|
/// Registers the config and token/savings-estimator closures on the Worker.
|
||||||
/// The estimator captures a shared handle to [`Pod::usage_history_handle`]
|
/// The estimators combine persisted [`Pod::usage_history_handle`] records
|
||||||
/// so that every LLM request sees the latest measurements.
|
/// with in-flight `UsageTracker` records so multi-request tool loops can
|
||||||
|
/// prune before the surrounding Pod run finishes.
|
||||||
///
|
///
|
||||||
/// Measurement-less estimates (before the first LLM call, or immediately
|
/// Measurement-less estimates (before the first LLM call, or immediately
|
||||||
/// after a compact) return `0` from the estimator, which naturally
|
/// after a compact) return `0` from the estimator, which naturally
|
||||||
|
|
@ -37,9 +42,25 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
||||||
/// [`UsageTracker`] so the next `LlmUsage` can be paired with a
|
/// [`UsageTracker`] so the next `LlmUsage` can be paired with a
|
||||||
/// `prune.post_request` metric carrying the same id.
|
/// `prune.post_request` metric carrying the same id.
|
||||||
pub fn attach_prune(&mut self, config: PruneConfig) {
|
pub fn attach_prune(&mut self, config: PruneConfig) {
|
||||||
let usage = self.usage_history_handle();
|
let usage_history_for_tokens = self.usage_history_handle();
|
||||||
|
let usage_tracker_for_tokens = self.usage_tracker_handle();
|
||||||
|
let token_estimator: TokenEstimator = Box::new(move |history: &[Item]| {
|
||||||
|
let mut snapshot = usage_history_for_tokens
|
||||||
|
.lock()
|
||||||
|
.expect("usage_history poisoned")
|
||||||
|
.clone();
|
||||||
|
snapshot.extend(usage_tracker_for_tokens.records());
|
||||||
|
token_estimates_for_prune_impl(history, &snapshot)
|
||||||
|
});
|
||||||
|
|
||||||
|
let usage_history_for_savings = self.usage_history_handle();
|
||||||
|
let usage_tracker_for_savings = self.usage_tracker_handle();
|
||||||
let estimator: SavingsEstimator = Box::new(move |history: &[Item], indices| {
|
let estimator: SavingsEstimator = Box::new(move |history: &[Item], indices| {
|
||||||
let snapshot = usage.lock().expect("usage_history poisoned").clone();
|
let mut snapshot = usage_history_for_savings
|
||||||
|
.lock()
|
||||||
|
.expect("usage_history poisoned")
|
||||||
|
.clone();
|
||||||
|
snapshot.extend(usage_tracker_for_savings.records());
|
||||||
let est = savings_for_prune_impl(history, &snapshot, indices);
|
let est = savings_for_prune_impl(history, &snapshot, indices);
|
||||||
match est.source {
|
match est.source {
|
||||||
EstimateSource::NoData => 0,
|
EstimateSource::NoData => 0,
|
||||||
|
|
@ -56,8 +77,9 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
||||||
.with_value(eval.estimated_savings as f64)
|
.with_value(eval.estimated_savings as f64)
|
||||||
.with_correlation_id(&correlation_id)
|
.with_correlation_id(&correlation_id)
|
||||||
.with_dimension("candidate_count", eval.candidate_count.to_string());
|
.with_dimension("candidate_count", eval.candidate_count.to_string());
|
||||||
if let Some(border) = eval.border_turn {
|
if let Some(protected_start) = eval.protected_start_index {
|
||||||
metric = metric.with_dimension("border_turn", border.to_string());
|
metric =
|
||||||
|
metric.with_dimension("protected_start_index", protected_start.to_string());
|
||||||
}
|
}
|
||||||
metrics.push(metric);
|
metrics.push(metric);
|
||||||
usage_tracker.note_correlation_id(correlation_id);
|
usage_tracker.note_correlation_id(correlation_id);
|
||||||
|
|
@ -66,17 +88,21 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
||||||
metrics.push(Metric::now("prune.skip").with_dimension("reason", "no_candidates"));
|
metrics.push(Metric::now("prune.skip").with_dimension("reason", "no_candidates"));
|
||||||
}
|
}
|
||||||
PruneDecision::SkippedBelowMinSavings => {
|
PruneDecision::SkippedBelowMinSavings => {
|
||||||
metrics.push(
|
let mut metric = Metric::now("prune.skip")
|
||||||
Metric::now("prune.skip")
|
.with_dimension("reason", "below_min_savings")
|
||||||
.with_dimension("reason", "below_min_savings")
|
.with_dimension("candidate_count", eval.candidate_count.to_string())
|
||||||
.with_dimension("candidate_count", eval.candidate_count.to_string())
|
.with_value(eval.estimated_savings as f64);
|
||||||
.with_value(eval.estimated_savings as f64),
|
if let Some(protected_start) = eval.protected_start_index {
|
||||||
);
|
metric =
|
||||||
|
metric.with_dimension("protected_start_index", protected_start.to_string());
|
||||||
|
}
|
||||||
|
metrics.push(metric);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let worker = self.worker_mut();
|
let worker = self.worker_mut();
|
||||||
worker.set_prune_config(Some(config));
|
worker.set_prune_config(Some(config));
|
||||||
|
worker.set_token_estimator(Some(token_estimator));
|
||||||
worker.set_savings_estimator(Some(estimator));
|
worker.set_savings_estimator(Some(estimator));
|
||||||
worker.set_prune_observer(Some(observer));
|
worker.set_prune_observer(Some(observer));
|
||||||
}
|
}
|
||||||
|
|
@ -90,7 +116,7 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
let config = PruneConfig {
|
let config = PruneConfig {
|
||||||
protected_turns: compaction.prune_protected_turns,
|
protected_tokens: compaction.prune_protected_tokens,
|
||||||
min_savings: compaction.prune_min_savings,
|
min_savings: compaction.prune_min_savings,
|
||||||
};
|
};
|
||||||
self.attach_prune(config);
|
self.attach_prune(config);
|
||||||
|
|
|
||||||
|
|
@ -132,6 +132,21 @@ fn tool_result_content_bytes(item: &Item) -> u64 {
|
||||||
item_bytes(item).saturating_sub(item_bytes(&cleared))
|
item_bytes(item).saturating_sub(item_bytes(&cleared))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Prefix-boundary token estimates used by Prune to find its protected suffix.
|
||||||
|
///
|
||||||
|
/// Returns `history.len() + 1` entries where entry `i` estimates
|
||||||
|
/// `history[..i]`. This shares the same [`tokens_at`] accounting as compact's
|
||||||
|
/// retained-tail split and prune's savings estimate.
|
||||||
|
pub(crate) fn token_estimates_for_prune_impl(
|
||||||
|
history: &[Item],
|
||||||
|
records: &[UsageRecord],
|
||||||
|
) -> Vec<TokenEstimate> {
|
||||||
|
let prefix = prefix_bytes(history);
|
||||||
|
(0..=history.len())
|
||||||
|
.map(|idx| tokens_at(history, records, idx, &prefix))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
/// Prune 射影(`ToolResult.content = None`)で節約されるトークン数の推定。
|
/// Prune 射影(`ToolResult.content = None`)で節約されるトークン数の推定。
|
||||||
///
|
///
|
||||||
/// `indices` は [`llm_worker::prune::prunable_indices`] が返す候補列を
|
/// `indices` は [`llm_worker::prune::prunable_indices`] が返す候補列を
|
||||||
|
|
@ -278,6 +293,26 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn token_estimates_for_prune_returns_every_prefix_boundary() {
|
||||||
|
let history = vec![msg("a"), msg("b"), msg("c")];
|
||||||
|
let estimates = token_estimates_for_prune_impl(&history, &[record(3, 300)]);
|
||||||
|
assert_eq!(estimates.len(), history.len() + 1);
|
||||||
|
assert_eq!(estimates[0].tokens, 0);
|
||||||
|
assert_eq!(estimates[3].tokens, 300);
|
||||||
|
assert_eq!(estimates[3].source, EstimateSource::Measured);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn token_estimates_for_prune_propagates_no_data() {
|
||||||
|
let history = vec![msg("a"), msg("b")];
|
||||||
|
let estimates = token_estimates_for_prune_impl(&history, &[]);
|
||||||
|
assert_eq!(estimates.len(), history.len() + 1);
|
||||||
|
assert_eq!(estimates[0].source, EstimateSource::Measured);
|
||||||
|
assert_eq!(estimates[1].source, EstimateSource::NoData);
|
||||||
|
assert_eq!(estimates[2].source, EstimateSource::NoData);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn savings_for_prune_skips_non_toolresult_indices() {
|
fn savings_for_prune_skips_non_toolresult_indices() {
|
||||||
let history = vec![msg("a"), msg("b"), msg("c")];
|
let history = vec![msg("a"), msg("b"), msg("c")];
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,10 @@
|
||||||
//! returns a long `ToolOutput.content`, then inspects the persisted
|
//! returns a long `ToolOutput.content`, then inspects the persisted
|
||||||
//! session log to verify:
|
//! session log to verify:
|
||||||
//!
|
//!
|
||||||
//! - `prune.skip { reason: "no_candidates" }` lands when the protected-turn
|
//! - `prune.skip { reason: "no_candidates" }` lands when usage estimates are
|
||||||
//! window covers the entire history.
|
//! unavailable or the protected-token window covers all tool results.
|
||||||
//! - `prune.fire` lands once enough turns + usage measurements exist for
|
//! - `prune.fire` lands once enough measured history exceeds the protected-token
|
||||||
//! the projection to actually apply.
|
//! budget for the projection to actually apply.
|
||||||
//! - The fire metric and the immediately-following `prune.post_request`
|
//! - The fire metric and the immediately-following `prune.post_request`
|
||||||
//! metric share the same `correlation_id`, so cache_read / cache_write
|
//! metric share the same `correlation_id`, so cache_read / cache_write
|
||||||
//! from the LlmUsage that triggered the projection can be joined back
|
//! from the LlmUsage that triggered the projection can be joined back
|
||||||
|
|
@ -136,7 +136,7 @@ fn text_response_with_cache(text: &str, cache_read: u64, cache_write: u64) -> Ve
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
fn manifest_toml(prune_protected_turns: usize, prune_min_savings: u64) -> String {
|
fn manifest_toml(prune_protected_tokens: u64, prune_min_savings: u64) -> String {
|
||||||
format!(
|
format!(
|
||||||
r#"
|
r#"
|
||||||
[pod]
|
[pod]
|
||||||
|
|
@ -151,7 +151,7 @@ model_id = "test-model"
|
||||||
max_tokens = 100
|
max_tokens = 100
|
||||||
|
|
||||||
[compaction]
|
[compaction]
|
||||||
prune_protected_turns = {prune_protected_turns}
|
prune_protected_tokens = {prune_protected_tokens}
|
||||||
prune_min_savings = {prune_min_savings}
|
prune_min_savings = {prune_min_savings}
|
||||||
|
|
||||||
[[scope.allow]]
|
[[scope.allow]]
|
||||||
|
|
@ -192,7 +192,7 @@ async fn prune_metrics_emit_skip_then_fire_with_post_request_join() {
|
||||||
// Run 1 (request 0): tool_use → triggers tool execution → request 1
|
// Run 1 (request 0): tool_use → triggers tool execution → request 1
|
||||||
// on the second iteration to produce the assistant reply.
|
// on the second iteration to produce the assistant reply.
|
||||||
// Run 2 (request 2): plain assistant text. Prune evaluation here
|
// Run 2 (request 2): plain assistant text. Prune evaluation here
|
||||||
// sees user1's tool_result outside the 1-protected-turn window and
|
// sees user1's tool_result outside the protected-token suffix and
|
||||||
// should fire.
|
// should fire.
|
||||||
let client = MockClient::new(vec![
|
let client = MockClient::new(vec![
|
||||||
tool_use_response("call-1", "big_tool"),
|
tool_use_response("call-1", "big_tool"),
|
||||||
|
|
@ -250,8 +250,8 @@ async fn prune_metrics_emit_skip_then_fire_with_post_request_join() {
|
||||||
"fire missing candidate_count: {fire:?}"
|
"fire missing candidate_count: {fire:?}"
|
||||||
);
|
);
|
||||||
assert!(
|
assert!(
|
||||||
fire.dimensions.contains_key("border_turn"),
|
fire.dimensions.contains_key("protected_start_index"),
|
||||||
"fire missing border_turn: {fire:?}"
|
"fire missing protected_start_index: {fire:?}"
|
||||||
);
|
);
|
||||||
assert!(fire.value.is_some(), "fire missing estimated_savings value");
|
assert!(fire.value.is_some(), "fire missing estimated_savings value");
|
||||||
let fire_id = fire
|
let fire_id = fire
|
||||||
|
|
@ -277,6 +277,36 @@ async fn prune_metrics_emit_skip_then_fire_with_post_request_join() {
|
||||||
assert!(post.dimensions.contains_key("history_len"));
|
assert!(post.dimensions.contains_key("history_len"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn prune_metrics_fire_during_single_long_task_without_multiple_user_turns() {
|
||||||
|
let client = MockClient::new(vec![
|
||||||
|
tool_use_response("call-1", "big_tool"),
|
||||||
|
tool_use_response("call-2", "big_tool"),
|
||||||
|
tool_use_response("call-3", "big_tool"),
|
||||||
|
tool_use_response("call-4", "big_tool"),
|
||||||
|
text_response_with_cache("done", 100, 20),
|
||||||
|
]);
|
||||||
|
let (mut pod, _store_tmp, _pwd_tmp) = make_pod(manifest_toml(1, 1), client, "big_tool").await;
|
||||||
|
let session_id = pod.session_id();
|
||||||
|
let segment_id = pod.segment_id();
|
||||||
|
let store = pod.store().clone();
|
||||||
|
|
||||||
|
pod.run_text("one long task").await.unwrap();
|
||||||
|
|
||||||
|
let state = session_store::restore(&store, session_id, segment_id).unwrap();
|
||||||
|
let metrics = metrics_from_extensions(&state.extensions);
|
||||||
|
let fire_count = metrics.iter().filter(|m| m.name == "prune.fire").count();
|
||||||
|
assert!(
|
||||||
|
fire_count > 0,
|
||||||
|
"single-turn tool loop should produce prune.fire once old heavy ToolResults fall outside the protected-token suffix: {metrics:?}"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
metrics.iter().any(|m| {
|
||||||
|
m.name == "prune.fire" && m.dimensions.contains_key("protected_start_index")
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
/// `min_savings` set high enough that candidates exist but the estimated
|
/// `min_savings` set high enough that candidates exist but the estimated
|
||||||
/// savings always fall short → the second run should record
|
/// savings always fall short → the second run should record
|
||||||
/// `prune.skip { reason: "below_min_savings" }`.
|
/// `prune.skip { reason: "below_min_savings" }`.
|
||||||
|
|
@ -288,7 +318,7 @@ async fn prune_metrics_record_below_min_savings_skip() {
|
||||||
text_response_with_cache("done", 0, 0),
|
text_response_with_cache("done", 0, 0),
|
||||||
]);
|
]);
|
||||||
let (mut pod, _store_tmp, _pwd_tmp) =
|
let (mut pod, _store_tmp, _pwd_tmp) =
|
||||||
make_pod(manifest_toml(1, u64::MAX), client, "big_tool").await;
|
make_pod(manifest_toml(1, 1_000_000), client, "big_tool").await;
|
||||||
let session_id = pod.session_id();
|
let session_id = pod.session_id();
|
||||||
let segment_id = pod.segment_id();
|
let segment_id = pod.segment_id();
|
||||||
let store = pod.store().clone();
|
let store = pod.store().clone();
|
||||||
|
|
@ -405,7 +435,7 @@ async fn metric_write_failure_emits_warn_alert_and_does_not_abort_run() {
|
||||||
|
|
||||||
// Even with a tool registered, this run will only emit
|
// Even with a tool registered, this run will only emit
|
||||||
// `prune.skip { reason: "no_candidates" }` (one user message,
|
// `prune.skip { reason: "no_candidates" }` (one user message,
|
||||||
// protected_turns=1 covers everything). That is enough to drive
|
// protected token budget covers the only user message). That is enough to drive
|
||||||
// the failure path: at least one metric attempts to write.
|
// the failure path: at least one metric attempts to write.
|
||||||
let client = MockClient::new(vec![text_response_with_cache("hi", 0, 0)]);
|
let client = MockClient::new(vec![text_response_with_cache("hi", 0, 0)]);
|
||||||
let worker = Worker::new(client);
|
let worker = Worker::new(client);
|
||||||
|
|
|
||||||
|
|
@ -104,7 +104,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn metric_round_trip_via_json() {
|
fn metric_round_trip_via_json() {
|
||||||
let metric = Metric::now("prune.fire")
|
let metric = Metric::now("prune.fire")
|
||||||
.with_dimension("border_turn", "3")
|
.with_dimension("protected_start_index", "3")
|
||||||
.with_dimension("candidate_count", "2")
|
.with_dimension("candidate_count", "2")
|
||||||
.with_value(4096.0)
|
.with_value(4096.0)
|
||||||
.with_correlation_id("abc-123");
|
.with_correlation_id("abc-123");
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ Pod::try_pre_run_compact ← proactive
|
||||||
|
|
||||||
- **条件付き実行**: 推定トークン節約量が `min_savings` を超えた場合のみ。KV キャッシュの無駄な無効化を避ける
|
- **条件付き実行**: 推定トークン節約量が `min_savings` を超えた場合のみ。KV キャッシュの無駄な無効化を避ける
|
||||||
- **リクエストコンテキストのみ操作**: history 本体は変更しない。Prune 状態を Pod が保持し、LLM リクエスト構築時に反映する
|
- **リクエストコンテキストのみ操作**: history 本体は変更しない。Prune 状態を Pod が保持し、LLM リクエスト構築時に反映する
|
||||||
|
- **保護境界**: 直近 `prune_protected_tokens` 相当の suffix は残す。turn 数ではなく usage history 由来の token estimate で境界を引くため、単発の長い tool loop でも古い `ToolResult.content` が候補になる
|
||||||
- **冪等**: `content: None` のアイテムはスキップ
|
- **冪等**: `content: None` のアイテムはスキップ
|
||||||
|
|
||||||
### ToolOutput の構造
|
### ToolOutput の構造
|
||||||
|
|
@ -138,8 +139,9 @@ compact は fork と同じ構造。旧セッションを保全し、新 SessionI
|
||||||
[compaction]
|
[compaction]
|
||||||
compact_threshold = 80000 # ターンの合間 (proactive)
|
compact_threshold = 80000 # ターンの合間 (proactive)
|
||||||
compact_request_threshold = 90000 # リクエストの合間 (safety net)
|
compact_request_threshold = 90000 # リクエストの合間 (safety net)
|
||||||
retained_tokens = 8000 # 直近保護トークン数 (Prune 済みで計測)
|
prune_protected_tokens = 8000 # prune から保護する末尾 token budget
|
||||||
auto_read_budget = 8000 # compact worker の mark_read_required 合計上限
|
compact_retained_tokens = 8000 # compact 後に生のまま残す末尾 token budget
|
||||||
|
compact_auto_read_budget = 8000 # compact worker の mark_read_required 合計上限
|
||||||
compact_worker_max_input_tokens = 50000 # compact worker 自身の現在占有トークン上限
|
compact_worker_max_input_tokens = 50000 # compact worker 自身の現在占有トークン上限
|
||||||
compact_worker_max_turns = 20 # compact worker 自身の tool loop 上限
|
compact_worker_max_turns = 20 # compact worker 自身の tool loop 上限
|
||||||
```
|
```
|
||||||
|
|
|
||||||
|
|
@ -191,9 +191,9 @@ permission = "write"
|
||||||
# セクションを書いた時点で Prune は有効化、Compact は閾値が None なら無効。
|
# セクションを書いた時点で Prune は有効化、Compact は閾値が None なら無効。
|
||||||
# [compaction]
|
# [compaction]
|
||||||
#
|
#
|
||||||
# # 任意。デフォルト: 3 (`defaults::PRUNE_PROTECTED_TURNS`)。
|
# # 任意。デフォルト: 8000 (`defaults::PRUNE_PROTECTED_TOKENS`)。
|
||||||
# # pruning から保護する末尾ターン数。
|
# # pruning から保護する末尾 token budget。turn 数ではなく usage estimate で境界を引く。
|
||||||
# prune_protected_turns = 3
|
# prune_protected_tokens = 8000
|
||||||
#
|
#
|
||||||
# # 任意。デフォルト: 4096 (`defaults::PRUNE_MIN_SAVINGS`)。
|
# # 任意。デフォルト: 4096 (`defaults::PRUNE_MIN_SAVINGS`)。
|
||||||
# # prune が発火するための最低節約 token 推定値。
|
# # prune が発火するための最低節約 token 推定値。
|
||||||
|
|
|
||||||
|
|
@ -179,7 +179,7 @@ pattern = "*.env"
|
||||||
action = "deny"
|
action = "deny"
|
||||||
|
|
||||||
[compaction]
|
[compaction]
|
||||||
prune_protected_turns = 3
|
prune_protected_tokens = 8000
|
||||||
prune_min_savings = 4096
|
prune_min_savings = 4096
|
||||||
compact_threshold = 80000
|
compact_threshold = 80000
|
||||||
compact_request_threshold = 90000
|
compact_request_threshold = 90000
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user