diff --git a/TODO.md b/TODO.md index 263a2da9..d0ae9570 100644 --- a/TODO.md +++ b/TODO.md @@ -2,7 +2,6 @@ - [ ] ツール設計 - [ ] Bash ツール (Permission 層と統合) → [tickets/bash-tool.md](tickets/bash-tool.md) - [ ] Scope の再設計 (pwd + writable、必須化) → [tickets/scope-redesign.md](tickets/scope-redesign.md) -- [ ] Prune の savings 推定を正確にする → [tickets/prune-savings-estimation.md](tickets/prune-savings-estimation.md) - [ ] Compact の改善(要約品質 + 挙動詳細) → [tickets/compact-improvements.md](tickets/compact-improvements.md) - [ ] Protocol の設計 → [tickets/protocol-design.md](tickets/protocol-design.md) - [ ] パーミッション: パターンベースのツール実行制御 → [tickets/permission-extension-point.md](tickets/permission-extension-point.md) diff --git a/crates/llm-worker/src/prune.rs b/crates/llm-worker/src/prune.rs index 7cfc819a..4acd6f69 100644 --- a/crates/llm-worker/src/prune.rs +++ b/crates/llm-worker/src/prune.rs @@ -14,18 +14,21 @@ //! `min_savings` 判定や savings 推定もこの crate には置かず、上位層が //! usage 履歴ベースのトークン会計と組み合わせて行う。 -use std::ops::Range; - use serde::{Deserialize, Serialize}; use crate::llm_client::types::Item; -/// Callback that estimates the token savings for dropping `history[range]`. +/// Callback that estimates the token savings for projecting the +/// `ToolResult.content` out of `history[i]` for each `i` in `indices`. /// /// Injected into [`crate::Worker`] via `set_savings_estimator` so the /// Worker can make `min_savings` decisions without knowing about usage /// measurement sources. Return `0` to signal "no data / refuse to prune". -pub type SavingsEstimator = Box) -> u64 + Send + Sync>; +/// +/// 推定対象は「drop する範囲全体」ではなく「content を None にする差分」 +/// であることに注意。item 自体(summary 等)は残るので、この callback は +/// 実際の projection と一致する savings を返す必要がある。 +pub type SavingsEstimator = Box u64 + Send + Sync>; /// Configuration for the Prune algorithm. #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/llm-worker/src/worker.rs b/crates/llm-worker/src/worker.rs index 7ae6c941..7191eacd 100644 --- a/crates/llm-worker/src/worker.rs +++ b/crates/llm-worker/src/worker.rs @@ -717,9 +717,7 @@ impl Worker { let candidates = crate::prune::prunable_indices(&request_context, config.protected_turns); if !candidates.is_empty() { - let first = *candidates.first().unwrap(); - let last = *candidates.last().unwrap() + 1; - let savings = estimator(&request_context, first..last); + let savings = estimator(&request_context, &candidates); if savings >= config.min_savings { let pruned = crate::prune::project(&mut request_context, &candidates); if pruned > 0 { diff --git a/crates/pod/src/prune.rs b/crates/pod/src/prune.rs index 03bce9e7..480f806a 100644 --- a/crates/pod/src/prune.rs +++ b/crates/pod/src/prune.rs @@ -12,7 +12,7 @@ use llm_worker::prune::{PruneConfig, SavingsEstimator}; use session_store::Store; use crate::Pod; -use crate::token_counter::{EstimateSource, savings_for_drop_impl}; +use crate::token_counter::{EstimateSource, savings_for_prune_impl}; impl Pod { /// Enable prune projection on the underlying Worker. @@ -21,14 +21,14 @@ impl Pod { /// The estimator captures a shared handle to [`Pod::usage_history_handle`] /// so that every LLM request sees the latest measurements. /// - /// Measurement-less ranges (before the first LLM call, or immediately + /// Measurement-less estimates (before the first LLM call, or immediately /// after a compact) return `0` from the estimator, which naturally /// prevents the prune projection from firing until usage data exists. pub fn attach_prune(&mut self, config: PruneConfig) { let usage = self.usage_history_handle(); - let estimator: SavingsEstimator = Box::new(move |history: &[Item], range| { + let estimator: SavingsEstimator = Box::new(move |history: &[Item], indices| { let snapshot = usage.lock().expect("usage_history poisoned").clone(); - let est = savings_for_drop_impl(history, &snapshot, range); + let est = savings_for_prune_impl(history, &snapshot, indices); match est.source { EstimateSource::NoData => 0, _ => est.tokens, diff --git a/crates/pod/src/token_counter.rs b/crates/pod/src/token_counter.rs index 876f39ca..3ca4dc76 100644 --- a/crates/pod/src/token_counter.rs +++ b/crates/pod/src/token_counter.rs @@ -1,8 +1,8 @@ //! Usage 履歴ベースのトークン会計。 //! //! `UsageRecord` の列(プロバイダ実測値)と現在の history から、 -//! 「末尾 N トークン残すための split 位置」「指定範囲を drop したときの -//! 節約トークン数」などを pure に計算する。 +//! 「末尾 N トークン残すための split 位置」「prune 射影で節約される +//! トークン数」などを pure に計算する。 //! //! # 方針 //! @@ -16,8 +16,6 @@ //! 公開 API は本ファイル内の `impl Pod` で [`Pod`](crate::Pod) のメソッドとして //! 生やしている。pure な補助関数はこのモジュール内に private に閉じる。 -use std::ops::Range; - use llm_worker::Item; use llm_worker::llm_client::client::LlmClient; use session_store::{Store, UsageRecord}; @@ -37,26 +35,6 @@ pub enum EstimateSource { NoData, } -impl EstimateSource { - fn rank(self) -> u8 { - match self { - Self::Measured => 0, - Self::Interpolated => 1, - Self::Extrapolated => 2, - Self::NoData => 3, - } - } - - /// 複数の推定を合成するときは一番「粗い」ものに揃える。 - fn worst(self, other: Self) -> Self { - if self.rank() >= other.rank() { - self - } else { - other - } - } -} - /// トークン数の推定値。 #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct TokenEstimate { @@ -228,24 +206,78 @@ fn split_for_retained_impl(history: &[Item], records: &[UsageRecord], retained: } } -pub(crate) fn savings_for_drop_impl( +/// 1 つの ToolResult 項目について、`content` を `None` に射影したとき +/// 減少するシリアライズ後バイト数。ToolResult 以外や既に content=None +/// の item は 0 を返す。 +fn tool_result_content_bytes(item: &Item) -> u64 { + if !matches!( + item, + Item::ToolResult { + content: Some(_), + .. + } + ) { + return 0; + } + let mut cleared = item.clone(); + if let Item::ToolResult { content, .. } = &mut cleared { + *content = None; + } + item_bytes(item).saturating_sub(item_bytes(&cleared)) +} + +/// Prune 射影(`ToolResult.content = None`)で節約されるトークン数の推定。 +/// +/// `indices` は [`llm_worker::prune::prunable_indices`] が返す候補列を +/// 想定する。各候補の content バイト差分を合算し、usage 履歴由来の +/// tokens/byte レートでトークン数に換算する。範囲を「丸ごと drop」する +/// のではなく、item 自体(summary 等)は残したままの値を返す点が +/// `tokens_at` ベースの計算と異なる。 +pub(crate) fn savings_for_prune_impl( history: &[Item], records: &[UsageRecord], - range: Range, + indices: &[usize], ) -> TokenEstimate { - if range.start >= range.end || range.end > history.len() { + let removed_bytes: u64 = indices + .iter() + .filter_map(|&i| history.get(i)) + .map(tool_result_content_bytes) + .sum(); + + if removed_bytes == 0 { return TokenEstimate { tokens: 0, source: EstimateSource::Measured, }; } - let prefix = prefix_bytes(history); - let s = tokens_at(history, records, range.start, &prefix); - let e = tokens_at(history, records, range.end, &prefix); - TokenEstimate { - tokens: e.tokens.saturating_sub(s.tokens), - source: s.source.worst(e.source), + + if records.is_empty() { + return TokenEstimate { + tokens: removed_bytes / 4, + source: EstimateSource::NoData, + }; } + + // 最新の measurement を使って tokens/byte を求め、バイト差分を換算する。 + // 実測値そのものではなく比率しか使わないので、history_len と + // record.history_len が一致しなくても rate は正しい。 + let prefix = prefix_bytes(history); + let last = records.last().expect("records non-empty"); + let ref_bytes = prefix[last.history_len.min(history.len())]; + if ref_bytes == 0 || last.input_total_tokens == 0 { + return TokenEstimate { + tokens: 0, + source: EstimateSource::Extrapolated, + }; + } + let tokens = + (removed_bytes as u128 * last.input_total_tokens as u128 / ref_bytes as u128) as u64; + let source = if last.history_len == history.len() { + EstimateSource::Measured + } else { + EstimateSource::Extrapolated + }; + TokenEstimate { tokens, source } } // ── Pod に生やす公開 API ─────────────────────────────────────────────── @@ -266,12 +298,6 @@ impl Pod { let usage = self.usage_history(); split_for_retained_impl(self.history(), &usage, retained) } - - /// 指定範囲を drop したときの節約トークン数の推定。 - pub fn savings_for_drop(&self, range: Range) -> TokenEstimate { - let usage = self.usage_history(); - savings_for_drop_impl(self.history(), &usage, range) - } } #[cfg(test)] @@ -360,46 +386,87 @@ mod tests { assert_eq!(cut.index, 2); } - #[test] - fn savings_for_drop_uses_measurement_difference() { - let history = vec![msg("a"), msg("b"), msg("c")]; - let records = vec![record(1, 100), record(3, 300)]; - let est = savings_for_drop_impl(&history, &records, 1..3); - assert_eq!(est.tokens, 200); - assert_eq!(est.source, EstimateSource::Measured); + fn tool_result_with(summary: &str, content: Option<&str>) -> Item { + match content { + Some(c) => Item::tool_result_with_content("call", summary, c), + None => Item::tool_result("call", summary), + } } #[test] - fn savings_for_drop_empty_range_is_zero() { - let history = vec![msg("a"), msg("b")]; - let records = vec![record(2, 100)]; - let est = savings_for_drop_impl(&history, &records, 1..1); + fn savings_for_prune_skips_non_toolresult_indices() { + let history = vec![msg("a"), msg("b"), msg("c")]; + // indices point at plain messages, not ToolResult → 0 savings. + let est = savings_for_prune_impl(&history, &[record(3, 300)], &[0, 1, 2]); assert_eq!(est.tokens, 0); } #[test] - fn savings_for_drop_interpolates_inside_measurement_span() { - // len=4 → 400 のみ。range 1..3 は原点 0 と upper=4 の間で按分。 - let history = vec![msg("aa"), msg("aa"), msg("aa"), msg("aa")]; - let records = vec![record(4, 400)]; - let est = savings_for_drop_impl(&history, &records, 1..3); - assert_eq!(est.source, EstimateSource::Interpolated); - assert!(est.tokens > 0 && est.tokens < 400); + fn savings_for_prune_skips_content_none_items() { + let history = vec![ + msg("user"), + tool_result_with("s1", None), + tool_result_with("s2", None), + ]; + let est = savings_for_prune_impl(&history, &[record(3, 300)], &[1, 2]); + assert_eq!(est.tokens, 0); } #[test] - fn savings_for_drop_no_records_falls_back_to_bytes() { - let history = vec![msg("hello hello"), msg("world world")]; - let est = savings_for_drop_impl(&history, &[], 0..1); + fn savings_for_prune_counts_only_content_delta() { + // 1 item with big content vs the same structure without content. + let big = "x".repeat(400); + let history = vec![ + msg("user"), + tool_result_with("summary", Some(&big)), + msg("tail"), + ]; + // 1 record at end so rate = tokens / total_bytes + let total_bytes: u64 = history.iter().map(item_bytes).sum(); + let records = vec![record(history.len(), total_bytes)]; // rate = 1 tok/byte + let est = savings_for_prune_impl(&history, &records, &[1]); + // saved bytes ≈ size of the big content payload; with rate=1 it + // should be close to 400 and far from the full item bytes. + let full_item_bytes = item_bytes(&history[1]); + assert!(est.tokens > 0); + assert!(est.tokens < full_item_bytes); + assert!(est.tokens >= 400); + assert_eq!(est.source, EstimateSource::Measured); + } + + #[test] + fn savings_for_prune_no_records_falls_back_to_bytes() { + let history = vec![msg("u"), tool_result_with("s", Some("hello world"))]; + let est = savings_for_prune_impl(&history, &[], &[1]); assert_eq!(est.source, EstimateSource::NoData); assert!(est.tokens > 0); } #[test] - fn savings_for_drop_out_of_range_is_zero() { + fn savings_for_prune_extrapolated_when_history_grew_past_measurement() { + let big = "x".repeat(200); + let history = vec![ + msg("u1"), + tool_result_with("s", Some(&big)), + msg("u2"), // added after the last measurement + ]; + let records = vec![record(2, 100)]; + let est = savings_for_prune_impl(&history, &records, &[1]); + assert_eq!(est.source, EstimateSource::Extrapolated); + assert!(est.tokens > 0); + } + + #[test] + fn savings_for_prune_empty_indices_is_zero() { let history = vec![msg("a")]; - let records = vec![record(1, 100)]; - let est = savings_for_drop_impl(&history, &records, 0..5); + let est = savings_for_prune_impl(&history, &[record(1, 100)], &[]); + assert_eq!(est.tokens, 0); + } + + #[test] + fn savings_for_prune_ignores_out_of_range_indices() { + let history = vec![msg("a")]; + let est = savings_for_prune_impl(&history, &[record(1, 100)], &[99]); assert_eq!(est.tokens, 0); } } diff --git a/tickets/prune-savings-estimation.md b/tickets/prune-savings-estimation.md deleted file mode 100644 index f4a3e8b8..00000000 --- a/tickets/prune-savings-estimation.md +++ /dev/null @@ -1,40 +0,0 @@ -# Prune の savings 推定を正確にする - -## 背景 - -現在の PruneHook は `savings_for_drop_impl(context, &snapshot, first..last)` で -候補範囲を「丸ごと drop した場合」の savings を計算し、`min_savings` と比較している。 - -しかし prune が実際に行うのは ToolResult の content を省略するだけで、 -item 自体(summary、メタデータ)は残る。そのため `savings_for_drop` は -実際の節約量を過大評価しており、本来 prune 不要な場面でも発動しうる。 - -savings の推定は prune 側の責務であり、token_counter の汎用 API に -prune 固有の挙動を押し込むべきではない。 - -## 方針 - -PruneHook が「content 部分だけの savings」を計算する。 - -### 計算方法 - -候補の各 ToolResult について: -- content ありの item のトークン推定 -- content を None にした場合(summary のみ)のトークン推定 -- 差分が prune による実際の savings - -バイト数の差分を measurement 由来の rate で換算するか、 -`tokens_at` を使った前後比較にするかは実装時判断。 - -### 影響範囲 - -- `crates/pod/src/prune_hook.rs`: savings 計算ロジックの置き換え -- `crates/pod/src/token_counter.rs`: 必要に応じて content-level の推定ヘルパーを追加 - -## 依存 - -- [prune-projection.md](prune-projection.md) — prune が射影ベースになった後の方が設計しやすい - -## ブロックする後続 - -- なし(チューニングの精度改善)