prune-savings-estimation完了

This commit is contained in:
Keisuke Hirata 2026-04-14 03:42:04 +09:00
parent a0a9df11c0
commit 3c58b5dde4
6 changed files with 142 additions and 115 deletions

View File

@ -2,7 +2,6 @@
- [ ] ツール設計
- [ ] Bash ツール (Permission 層と統合) → [tickets/bash-tool.md](tickets/bash-tool.md)
- [ ] Scope の再設計 (pwd + writable、必須化) → [tickets/scope-redesign.md](tickets/scope-redesign.md)
- [ ] Prune の savings 推定を正確にする → [tickets/prune-savings-estimation.md](tickets/prune-savings-estimation.md)
- [ ] Compact の改善(要約品質 + 挙動詳細) → [tickets/compact-improvements.md](tickets/compact-improvements.md)
- [ ] Protocol の設計 → [tickets/protocol-design.md](tickets/protocol-design.md)
- [ ] パーミッション: パターンベースのツール実行制御 → [tickets/permission-extension-point.md](tickets/permission-extension-point.md)

View File

@ -14,18 +14,21 @@
//! `min_savings` 判定や savings 推定もこの crate には置かず、上位層が
//! usage 履歴ベースのトークン会計と組み合わせて行う。
use std::ops::Range;
use serde::{Deserialize, Serialize};
use crate::llm_client::types::Item;
/// Callback that estimates the token savings for dropping `history[range]`.
/// Callback that estimates the token savings for projecting the
/// `ToolResult.content` out of `history[i]` for each `i` in `indices`.
///
/// Injected into [`crate::Worker`] via `set_savings_estimator` so the
/// Worker can make `min_savings` decisions without knowing about usage
/// measurement sources. Return `0` to signal "no data / refuse to prune".
pub type SavingsEstimator = Box<dyn Fn(&[Item], Range<usize>) -> u64 + Send + Sync>;
///
/// Note that what is estimated is not "the whole dropped range" but "the
/// delta from setting content to `None`". The items themselves (summary etc.)
/// remain in place, so this callback must return savings that match the
/// actual projection.
pub type SavingsEstimator = Box<dyn Fn(&[Item], &[usize]) -> u64 + Send + Sync>;
/// Configuration for the Prune algorithm.
#[derive(Debug, Clone, Serialize, Deserialize)]

View File

@ -717,9 +717,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
let candidates =
crate::prune::prunable_indices(&request_context, config.protected_turns);
if !candidates.is_empty() {
let first = *candidates.first().unwrap();
let last = *candidates.last().unwrap() + 1;
let savings = estimator(&request_context, first..last);
let savings = estimator(&request_context, &candidates);
if savings >= config.min_savings {
let pruned = crate::prune::project(&mut request_context, &candidates);
if pruned > 0 {

View File

@ -12,7 +12,7 @@ use llm_worker::prune::{PruneConfig, SavingsEstimator};
use session_store::Store;
use crate::Pod;
use crate::token_counter::{EstimateSource, savings_for_drop_impl};
use crate::token_counter::{EstimateSource, savings_for_prune_impl};
impl<C: LlmClient, St: Store> Pod<C, St> {
/// Enable prune projection on the underlying Worker.
@ -21,14 +21,14 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
/// The estimator captures a shared handle to [`Pod::usage_history_handle`]
/// so that every LLM request sees the latest measurements.
///
/// Measurement-less ranges (before the first LLM call, or immediately
/// Measurement-less estimates (before the first LLM call, or immediately
/// after a compact) return `0` from the estimator, which naturally
/// prevents the prune projection from firing until usage data exists.
pub fn attach_prune(&mut self, config: PruneConfig) {
let usage = self.usage_history_handle();
let estimator: SavingsEstimator = Box::new(move |history: &[Item], range| {
let estimator: SavingsEstimator = Box::new(move |history: &[Item], indices| {
let snapshot = usage.lock().expect("usage_history poisoned").clone();
let est = savings_for_drop_impl(history, &snapshot, range);
let est = savings_for_prune_impl(history, &snapshot, indices);
match est.source {
EstimateSource::NoData => 0,
_ => est.tokens,

View File

@ -1,8 +1,8 @@
//! Usage 履歴ベースのトークン会計。
//!
//! `UsageRecord` の列(プロバイダ実測値)と現在の history から、
//! 「末尾 N トークン残すための split 位置」「指定範囲を drop したときの
//! 節約トークン数」などを pure に計算する。
//! 「末尾 N トークン残すための split 位置」「prune 射影で節約される
//! トークン数」などを pure に計算する。
//!
//! # 方針
//!
@ -16,8 +16,6 @@
//! 公開 API は本ファイル内の `impl Pod` で [`Pod`](crate::Pod) のメソッドとして
//! 生やしている。pure な補助関数はこのモジュール内に private に閉じる。
use std::ops::Range;
use llm_worker::Item;
use llm_worker::llm_client::client::LlmClient;
use session_store::{Store, UsageRecord};
@ -37,26 +35,6 @@ pub enum EstimateSource {
NoData,
}
impl EstimateSource {
    /// Coarseness rank of this source: `Measured` is 0 and `NoData` is 3.
    fn rank(self) -> u8 {
        match self {
            Self::NoData => 3,
            Self::Extrapolated => 2,
            Self::Interpolated => 1,
            Self::Measured => 0,
        }
    }

    /// When several estimates are combined, settle on the coarsest one.
    /// On a tie `self` is returned.
    fn worst(self, other: Self) -> Self {
        if other.rank() > self.rank() {
            other
        } else {
            self
        }
    }
}
/// トークン数の推定値。
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenEstimate {
@ -228,24 +206,78 @@ fn split_for_retained_impl(history: &[Item], records: &[UsageRecord], retained:
}
}
pub(crate) fn savings_for_drop_impl(
/// For a single item, returns how many bytes (as reported by `item_bytes`)
/// disappear when its `ToolResult.content` is projected to `None`.
///
/// Items that are not a `ToolResult`, or whose content is already `None`,
/// yield 0.
fn tool_result_content_bytes(item: &Item) -> u64 {
    if let Item::ToolResult {
        content: Some(_), ..
    } = item
    {
        // Measure the same item with its content stripped and take the
        // difference; saturating so a larger "stripped" form cannot underflow.
        let mut stripped = item.clone();
        if let Item::ToolResult { content, .. } = &mut stripped {
            *content = None;
        }
        let with_content = item_bytes(item);
        let without_content = item_bytes(&stripped);
        with_content.saturating_sub(without_content)
    } else {
        0
    }
}
/// Estimates the number of tokens saved by the prune projection
/// (`ToolResult.content = None`).
///
/// `indices` is expected to be the candidate list returned by
/// [`llm_worker::prune::prunable_indices`]. The content byte delta of each
/// candidate is summed and converted to tokens using a tokens/byte rate
/// derived from the usage history. Unlike the `tokens_at`-based calculation,
/// the range is not dropped wholesale: the returned value reflects the items
/// themselves (summary etc.) staying in place.
pub(crate) fn savings_for_prune_impl(
history: &[Item],
records: &[UsageRecord],
range: Range<usize>,
indices: &[usize],
) -> TokenEstimate {
if range.start >= range.end || range.end > history.len() {
// Sum the content byte delta over the candidates; out-of-range indices
// are skipped by `history.get`.
let removed_bytes: u64 = indices
.iter()
.filter_map(|&i| history.get(i))
.map(tool_result_content_bytes)
.sum();
if removed_bytes == 0 {
return TokenEstimate {
tokens: 0,
source: EstimateSource::Measured,
};
}
let prefix = prefix_bytes(history);
let s = tokens_at(history, records, range.start, &prefix);
let e = tokens_at(history, records, range.end, &prefix);
TokenEstimate {
tokens: e.tokens.saturating_sub(s.tokens),
source: s.source.worst(e.source),
// Without any usage record, fall back to a bytes / 4 heuristic tagged NoData.
if records.is_empty() {
return TokenEstimate {
tokens: removed_bytes / 4,
source: EstimateSource::NoData,
};
}
// Use the latest measurement to derive tokens/byte and convert the byte
// delta. Only the ratio is used, not the measured value itself, so the rate
// is treated as valid even when history_len and record.history_len differ.
// NOTE(review): this assumes input_total_tokens scales with the history
// bytes up to record.history_len — confirm against the provider usage data.
let prefix = prefix_bytes(history);
let last = records.last().expect("records non-empty");
let ref_bytes = prefix[last.history_len.min(history.len())];
if ref_bytes == 0 || last.input_total_tokens == 0 {
return TokenEstimate {
tokens: 0,
source: EstimateSource::Extrapolated,
};
}
let tokens =
(removed_bytes as u128 * last.input_total_tokens as u128 / ref_bytes as u128) as u64;
let source = if last.history_len == history.len() {
EstimateSource::Measured
} else {
EstimateSource::Extrapolated
};
TokenEstimate { tokens, source }
}
// ── Pod に生やす公開 API ───────────────────────────────────────────────
@ -266,12 +298,6 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
let usage = self.usage_history();
split_for_retained_impl(self.history(), &usage, retained)
}
/// 指定範囲を drop したときの節約トークン数の推定。
pub fn savings_for_drop(&self, range: Range<usize>) -> TokenEstimate {
let usage = self.usage_history();
savings_for_drop_impl(self.history(), &usage, range)
}
}
#[cfg(test)]
@ -360,46 +386,87 @@ mod tests {
assert_eq!(cut.index, 2);
}
#[test]
fn savings_for_drop_uses_measurement_difference() {
let history = vec![msg("a"), msg("b"), msg("c")];
let records = vec![record(1, 100), record(3, 300)];
let est = savings_for_drop_impl(&history, &records, 1..3);
assert_eq!(est.tokens, 200);
assert_eq!(est.source, EstimateSource::Measured);
/// Test helper: builds a tool-result item with the fixed first argument
/// `"call"` and the given summary, attaching `content` when it is provided.
fn tool_result_with(summary: &str, content: Option<&str>) -> Item {
    if let Some(body) = content {
        Item::tool_result_with_content("call", summary, body)
    } else {
        Item::tool_result("call", summary)
    }
}
#[test]
fn savings_for_drop_empty_range_is_zero() {
let history = vec![msg("a"), msg("b")];
let records = vec![record(2, 100)];
let est = savings_for_drop_impl(&history, &records, 1..1);
fn savings_for_prune_skips_non_toolresult_indices() {
let history = vec![msg("a"), msg("b"), msg("c")];
// indices point at plain messages, not ToolResult → 0 savings.
let est = savings_for_prune_impl(&history, &[record(3, 300)], &[0, 1, 2]);
assert_eq!(est.tokens, 0);
}
#[test]
fn savings_for_drop_interpolates_inside_measurement_span() {
// len=4 → 400 のみ。range 1..3 は原点 0 と upper=4 の間で按分。
let history = vec![msg("aa"), msg("aa"), msg("aa"), msg("aa")];
let records = vec![record(4, 400)];
let est = savings_for_drop_impl(&history, &records, 1..3);
assert_eq!(est.source, EstimateSource::Interpolated);
assert!(est.tokens > 0 && est.tokens < 400);
fn savings_for_prune_skips_content_none_items() {
let history = vec![
msg("user"),
tool_result_with("s1", None),
tool_result_with("s2", None),
];
let est = savings_for_prune_impl(&history, &[record(3, 300)], &[1, 2]);
assert_eq!(est.tokens, 0);
}
#[test]
fn savings_for_drop_no_records_falls_back_to_bytes() {
let history = vec![msg("hello hello"), msg("world world")];
let est = savings_for_drop_impl(&history, &[], 0..1);
fn savings_for_prune_counts_only_content_delta() {
// 1 item with big content vs the same structure without content.
let big = "x".repeat(400);
let history = vec![
msg("user"),
tool_result_with("summary", Some(&big)),
msg("tail"),
];
// 1 record at end so rate = tokens / total_bytes
let total_bytes: u64 = history.iter().map(item_bytes).sum();
let records = vec![record(history.len(), total_bytes)]; // rate = 1 tok/byte
let est = savings_for_prune_impl(&history, &records, &[1]);
// saved bytes ≈ size of the big content payload; with rate=1 it
// should be close to 400 and far from the full item bytes.
let full_item_bytes = item_bytes(&history[1]);
assert!(est.tokens > 0);
assert!(est.tokens < full_item_bytes);
assert!(est.tokens >= 400);
assert_eq!(est.source, EstimateSource::Measured);
}
#[test]
fn savings_for_prune_no_records_falls_back_to_bytes() {
    // With no usage records the estimate is tagged NoData but is still
    // non-zero for a ToolResult that carries content.
    let items = vec![msg("u"), tool_result_with("s", Some("hello world"))];
    let no_records: [UsageRecord; 0] = [];
    let estimate = savings_for_prune_impl(&items, &no_records, &[1]);
    assert!(estimate.tokens > 0);
    assert_eq!(estimate.source, EstimateSource::NoData);
}
#[test]
fn savings_for_drop_out_of_range_is_zero() {
fn savings_for_prune_extrapolated_when_history_grew_past_measurement() {
let big = "x".repeat(200);
let history = vec![
msg("u1"),
tool_result_with("s", Some(&big)),
msg("u2"), // added after the last measurement
];
let records = vec![record(2, 100)];
let est = savings_for_prune_impl(&history, &records, &[1]);
assert_eq!(est.source, EstimateSource::Extrapolated);
assert!(est.tokens > 0);
}
#[test]
fn savings_for_prune_empty_indices_is_zero() {
let history = vec![msg("a")];
let records = vec![record(1, 100)];
let est = savings_for_drop_impl(&history, &records, 0..5);
let est = savings_for_prune_impl(&history, &[record(1, 100)], &[]);
assert_eq!(est.tokens, 0);
}
#[test]
fn savings_for_prune_ignores_out_of_range_indices() {
    // Index 99 lies past the end of a one-item history, so nothing is counted.
    let items = vec![msg("a")];
    let usage = [record(1, 100)];
    let estimate = savings_for_prune_impl(&items, &usage, &[99]);
    assert_eq!(estimate.tokens, 0);
}
}

View File

@ -1,40 +0,0 @@
# Prune の savings 推定を正確にする
## 背景
現在の PruneHook は `savings_for_drop_impl(context, &snapshot, first..last)` で
候補範囲を「丸ごと drop した場合」の savings を計算し、`min_savings` と比較している。
しかし prune が実際に行うのは ToolResult の content を省略するだけで、
item 自体(summary、メタデータ)は残る。そのため `savings_for_drop` は
実際の節約量を過大評価しており、本来 prune 不要な場面でも発動しうる。
savings の推定は prune 側の責務であり、token_counter の汎用 API に
prune 固有の挙動を押し込むべきではない。
## 方針
PruneHook が「content 部分だけの savings」を計算する。
### 計算方法
候補の各 ToolResult について:
- content ありの item のトークン推定
- content を None にした場合(summary のみ)のトークン推定
- 差分が prune による実際の savings
バイト数の差分を measurement 由来の rate で換算するか、
`tokens_at` を使った前後比較にするかは実装時判断。
### 影響範囲
- `crates/pod/src/prune_hook.rs`: savings 計算ロジックの置き換え
- `crates/pod/src/token_counter.rs`: 必要に応じて content-level の推定ヘルパーを追加
## 依存
- [prune-projection.md](prune-projection.md) — prune が射影ベースになった後の方が設計しやすい
## ブロックする後続
- なし(チューニングの精度改善)