From f607a52fbb8f940febb1fd47d372ce65857ea23b Mon Sep 17 00:00:00 2001 From: Hare Date: Mon, 13 Apr 2026 20:21:26 +0900 Subject: [PATCH] =?UTF-8?q?token-counter=E5=AE=9F=E8=A3=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/llm-worker/src/prune.rs | 224 +++++++----------- crates/pod/src/lib.rs | 3 + crates/pod/src/pod.rs | 49 +++- crates/pod/src/prune_hook.rs | 58 ++++- crates/pod/src/token_counter.rs | 404 ++++++++++++++++++++++++++++++++ tickets/token-counter.md | 4 + tickets/token-counter.review.md | 60 +++++ 7 files changed, 654 insertions(+), 148 deletions(-) create mode 100644 crates/pod/src/token_counter.rs create mode 100644 tickets/token-counter.review.md diff --git a/crates/llm-worker/src/prune.rs b/crates/llm-worker/src/prune.rs index dd112e83..c196b775 100644 --- a/crates/llm-worker/src/prune.rs +++ b/crates/llm-worker/src/prune.rs @@ -4,9 +4,9 @@ //! their `summary`. This reclaims tokens while preserving the "what //! happened" trail. //! -//! Pruning is **conditional**: it only fires when the estimated token -//! savings exceed [`PruneConfig::min_savings`], avoiding unnecessary -//! KV-cache invalidation. +//! このモジュールは pure な「候補抽出」と「適用」だけを提供する。 +//! `min_savings` 判定や savings 推定はこの crate には置かず、上位層 +//! (`pod::prune_hook` など)が usage 履歴ベースのトークン会計と組み合わせて行う。 use serde::{Deserialize, Serialize}; @@ -20,17 +20,19 @@ pub struct PruneConfig { #[serde(default = "default_protected_turns")] pub protected_turns: usize, - /// Minimum estimated token savings required to actually prune. - /// If the prunable content is smaller than this, we skip to - /// avoid pointless KV-cache invalidation. + /// Minimum token savings required to actually prune. If the prunable + /// content is smaller than this, the caller should skip to avoid + /// pointless KV-cache invalidation. 
The unit is tokens; the caller + /// is responsible for measuring savings via a usage-history-aware + /// estimator and comparing against this threshold. #[serde(default = "default_min_savings")] - pub min_savings: usize, + pub min_savings: u64, } fn default_protected_turns() -> usize { 3 } -fn default_min_savings() -> usize { +fn default_min_savings() -> u64 { 4096 } @@ -43,18 +45,11 @@ impl Default for PruneConfig { } } -/// Result of a prune operation. +/// Result of [`apply_prune`]. #[derive(Debug, Clone, PartialEq, Eq)] pub struct PruneResult { /// Number of items whose `content` was set to `None`. pub pruned_count: usize, - /// Estimated tokens reclaimed. - pub estimated_savings: usize, -} - -/// Estimate the token count of a string (rough: chars / 4). -fn estimate_tokens(s: &str) -> usize { - s.len() / 4 } /// Find indices where each "turn" begins. @@ -70,59 +65,45 @@ fn find_turn_starts(items: &[Item]) -> Vec { .collect() } -/// Conditionally prune old tool-result content from `items`. +/// Indices of `Item::ToolResult { content: Some(_), .. }` that lie outside +/// the last `protected_turns` turns. Pure: does not mutate `items`. /// -/// Returns `None` if pruning was skipped (not enough savings or not -/// enough turns). Returns `Some(PruneResult)` if items were modified. -/// -/// # Algorithm -/// -/// 1. Identify turn boundaries (user-message positions). -/// 2. Compute the protection boundary: items before the last -/// `protected_turns` turns are candidates. -/// 3. Sum the estimated token savings from prunable `content` fields. -/// 4. If savings < `min_savings`, skip. -/// 5. Otherwise, set `content = None` on each candidate. -pub fn prune(items: &mut [Item], config: &PruneConfig) -> Option { +/// Returns an empty vector when there are too few turns or no prunable +/// candidates. 
+pub fn prunable_indices(items: &[Item], protected_turns: usize) -> Vec { let turn_starts = find_turn_starts(items); - - // Not enough turns to have anything outside the protected window. - if turn_starts.len() <= config.protected_turns { - return None; + if turn_starts.len() <= protected_turns { + return Vec::new(); } + let boundary = turn_starts[turn_starts.len() - protected_turns]; + items[..boundary] + .iter() + .enumerate() + .filter_map(|(i, item)| match item { + Item::ToolResult { + content: Some(_), .. + } => Some(i), + _ => None, + }) + .collect() +} - // Everything before this index is a prune candidate. - let boundary = turn_starts[turn_starts.len() - config.protected_turns]; - - // Collect prunable indices and total savings. - let mut total_savings: usize = 0; - let mut prunable: Vec = Vec::new(); - - for (i, item) in items[..boundary].iter().enumerate() { - if let Item::ToolResult { - content: Some(c), .. - } = item - { - total_savings += estimate_tokens(c); - prunable.push(i); - } - } - - if prunable.is_empty() || total_savings < config.min_savings { - return None; - } - - // Apply: drop content, keep summary. - for &i in &prunable { +/// Set `content = None` on each item at `indices`. Returns the number +/// of items that were actually modified (already-pruned items are +/// counted as 0). +pub fn apply_prune(items: &mut [Item], indices: &[usize]) -> PruneResult { + let mut count = 0; + for &i in indices { if let Item::ToolResult { content, .. 
} = &mut items[i] { - *content = None; + if content.is_some() { + *content = None; + count += 1; + } } } - - Some(PruneResult { - pruned_count: prunable.len(), - estimated_savings: total_savings, - }) + PruneResult { + pruned_count: count, + } } #[cfg(test)] @@ -148,53 +129,48 @@ mod tests { } #[test] - fn no_prune_when_too_few_turns() { - let mut items = make_history(&[ + fn no_candidates_when_too_few_turns() { + let items = make_history(&[ ("turn1", vec![("summary1", Some("big content here"))]), ("turn2", vec![("summary2", Some("more content"))]), ]); - let config = PruneConfig { - protected_turns: 3, - min_savings: 0, - }; - assert!(prune(&mut items, &config).is_none()); + assert!(prunable_indices(&items, 3).is_empty()); } #[test] - fn no_prune_when_savings_below_threshold() { - let mut items = make_history(&[ - ("turn1", vec![("s", Some("tiny"))]), // ~1 token - ("turn2", vec![]), - ("turn3", vec![]), - ("turn4", vec![]), + fn candidates_in_unprotected_turns() { + let big = "x".repeat(4096 * 4); + let items = make_history(&[ + ("turn1", vec![("s1", Some(&big))]), + ("turn2", vec![("s2", Some(&big))]), + ("turn3", vec![("s3", Some("keep me"))]), + ("turn4", vec![("s4", Some("keep me too"))]), ]); - let config = PruneConfig { - protected_turns: 2, - min_savings: 9999, - }; - assert!(prune(&mut items, &config).is_none()); + let candidates = prunable_indices(&items, 2); + assert_eq!(candidates.len(), 2); + // 候補は turn1 と turn2 の ToolResult のみ + for &i in &candidates { + if let Item::ToolResult { summary, .. } = &items[i] { + assert!(summary == "s1" || summary == "s2"); + } else { + panic!("non tool-result selected"); + } + } } #[test] - fn prune_old_content() { - // 4 turns. protected_turns=2 → turns 1-2 are candidates. 
- let big = "x".repeat(4096 * 4); // ~4096 tokens + fn apply_drops_content_only() { + let big = "x".repeat(64); let mut items = make_history(&[ ("turn1", vec![("s1", Some(&big))]), ("turn2", vec![("s2", Some(&big))]), ("turn3", vec![("s3", Some("keep me"))]), ("turn4", vec![("s4", Some("keep me too"))]), ]); - let config = PruneConfig { - protected_turns: 2, - min_savings: 1000, - }; - - let result = prune(&mut items, &config).expect("should prune"); + let candidates = prunable_indices(&items, 2); + let result = apply_prune(&mut items, &candidates); assert_eq!(result.pruned_count, 2); - assert!(result.estimated_savings >= 8000); - // Verify: pruned items have content=None, protected items keep content. for item in &items { if let Item::ToolResult { summary, content, .. @@ -210,73 +186,49 @@ mod tests { } #[test] - fn idempotent() { - let big = "x".repeat(4096 * 4); + fn apply_is_idempotent() { + let big = "x".repeat(64); let mut items = make_history(&[ ("turn1", vec![("s1", Some(&big))]), ("turn2", vec![]), ("turn3", vec![]), ("turn4", vec![]), ]); - let config = PruneConfig { - protected_turns: 2, - min_savings: 100, - }; + let first_indices = prunable_indices(&items, 2); + assert_eq!(apply_prune(&mut items, &first_indices).pruned_count, 1); - let first = prune(&mut items, &config).expect("first prune"); - assert_eq!(first.pruned_count, 1); - - // Second call: nothing left to prune. - assert!(prune(&mut items, &config).is_none()); + // 2 周目: 候補は (まだ) いるかもしれないが、すでに content=None なので + // apply_prune は 0 件と数える。 + let second_indices = prunable_indices(&items, 2); + assert!(second_indices.is_empty()); } #[test] - fn already_pruned_items_skipped() { - // Items that already have content=None are not counted as savings. 
- let mut items = make_history(&[ - ("turn1", vec![("s1", None)]), // already pruned + fn already_pruned_items_excluded_from_candidates() { + let items = make_history(&[ + ("turn1", vec![("s1", None)]), // already pruned (content=None) ("turn2", vec![]), ("turn3", vec![]), ("turn4", vec![]), ]); - let config = PruneConfig { - protected_turns: 2, - min_savings: 0, // Even with threshold 0, no savings means no prune - }; - - assert!(prune(&mut items, &config).is_none()); + assert!(prunable_indices(&items, 2).is_empty()); } #[test] fn protected_turns_boundary_exact() { - // 3 turns with protected_turns=2: - // Turn 1 content should be pruned, turns 2-3 protected. - let big = "x".repeat(4096 * 4); - let mut items = make_history(&[ + // 3 turns with protected_turns=2: only turn 1 is a candidate. + let big = "x".repeat(64); + let items = make_history(&[ ("turn1", vec![("s1", Some(&big))]), ("turn2", vec![("s2", Some("protected"))]), ("turn3", vec![("s3", Some("also protected"))]), ]); - let config = PruneConfig { - protected_turns: 2, - min_savings: 100, - }; - - let result = prune(&mut items, &config).expect("should prune turn1"); - assert_eq!(result.pruned_count, 1); - - // Verify s1 pruned, s2 and s3 intact. - for item in &items { - if let Item::ToolResult { - summary, content, .. - } = item - { - match summary.as_str() { - "s1" => assert!(content.is_none()), - "s2" | "s3" => assert!(content.is_some()), - _ => {} - } - } + let candidates = prunable_indices(&items, 2); + assert_eq!(candidates.len(), 1); + if let Item::ToolResult { summary, .. 
} = &items[candidates[0]] { + assert_eq!(summary, "s1"); + } else { + panic!("expected ToolResult at candidate index"); } } } diff --git a/crates/pod/src/lib.rs b/crates/pod/src/lib.rs index 48572344..2feb74a2 100644 --- a/crates/pod/src/lib.rs +++ b/crates/pod/src/lib.rs @@ -10,8 +10,11 @@ mod compact_interceptor; mod compact_state; mod hook_interceptor; mod pod; +mod token_counter; mod usage_tracker; +pub use token_counter::{EstimateSource, SplitPoint, TokenEstimate}; + pub use controller::{PodController, PodHandle}; pub use manifest::{PodManifest, ProviderConfig, ProviderKind, Scope}; pub use hook::{Hook, HookEventKind, HookRegistryBuilder}; diff --git a/crates/pod/src/pod.rs b/crates/pod/src/pod.rs index 69aeb28e..dc92bdd8 100644 --- a/crates/pod/src/pod.rs +++ b/crates/pod/src/pod.rs @@ -1,5 +1,5 @@ use std::path::PathBuf; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use llm_worker::Item; use llm_worker::llm_client::client::LlmClient; @@ -7,7 +7,7 @@ use llm_worker::llm_client::RequestConfig; use llm_worker::state::Mutable; use llm_worker::{Worker, WorkerError, WorkerResult}; use session_store::{ - EntryHash, Outcome, SessionId, SessionStartState, Store, StoreError, + EntryHash, Outcome, SessionId, SessionStartState, Store, StoreError, UsageRecord, }; use tracing::{info, warn}; @@ -75,6 +75,14 @@ pub struct Pod { /// Captures `(history_len, UsageEvent)` pairs during a run; drained /// in `persist_turn` and persisted as `LogEntry::LlmUsage` entries. usage_tracker: Arc, + /// Cumulative Usage measurement timeline, one entry per LLM call. + /// Restored from session log on `restore`, appended on each persist. + /// Read by token-accounting APIs (`Pod::total_tokens`, etc.). + /// + /// Wrapped in `Arc` so that hooks living on the Worker + /// (e.g. `PruneHook`) can share the same view via + /// [`Pod::usage_history_handle`]. + usage_history: Arc>>, /// Session-lifetime file-operation tracker from the builtin `tools` /// crate. 
Populated by the Controller when it registers the builtin /// tools so that Pod-owned operations (e.g. compaction) can consult @@ -108,6 +116,7 @@ impl Pod { manifest_dir: None, compact_state: None, usage_tracker: Arc::new(UsageTracker::new()), + usage_history: Arc::new(Mutex::new(Vec::::new())), tracker: None, }) } @@ -142,6 +151,7 @@ impl Pod { manifest_dir: None, compact_state: None, usage_tracker: Arc::new(UsageTracker::new()), + usage_history: Arc::new(Mutex::new(state.usage_history)), tracker: None, }) } @@ -179,6 +189,30 @@ impl Pod { &self.store } + /// Current history items held by the underlying Worker. + pub fn history(&self) -> &[Item] { + self.worker().history() + } + + /// Snapshot of the cumulative LLM Usage measurement timeline. + /// + /// One entry per LLM call. Restored on `restore` and appended in + /// `persist_turn`. Used by token-accounting APIs in [`token_counter`]. + /// Returns a clone since the underlying vector is shared with hooks + /// running on the Worker. + pub fn usage_history(&self) -> Vec { + self.usage_history.lock().expect("usage_history poisoned").clone() + } + + /// Shared handle to the cumulative Usage history. + /// + /// Hooks (e.g. `PruneHook`) take a clone of this `Arc` so they can + /// read the latest measurements at request time. The handle outlives + /// any individual run. + pub fn usage_history_handle(&self) -> Arc>> { + self.usage_history.clone() + } + /// Attach the session-scoped file-operation tracker from the builtin /// `tools` crate. Called by the Controller immediately after it /// registers the builtin tools on the Worker. Overwrites any @@ -483,7 +517,9 @@ impl Pod { // Persist any LLM Usage measurements collected during this run. // One LogEntry::LlmUsage per LLM call (the tool loop may have run - // many calls within a single Pod::run). + // many calls within a single Pod::run). Each is also appended to + // the in-memory `usage_history` so token-accounting APIs see it + // before the next run. 
let usage_records = self.usage_tracker.drain(); for record in usage_records { session_store::save_usage( @@ -497,6 +533,7 @@ impl Pod { record.output_tokens, ) .await?; + self.usage_history.lock().expect("usage_history poisoned").push(record); } let interrupted = self.worker.as_ref().unwrap().last_run_interrupted(); @@ -601,10 +638,13 @@ impl Pod { ) .await?; - // Swap in the new session state. + // Swap in the new session state. usage_history belongs to the old + // session — the new compacted session starts with no measurements + // until its first LLM call. self.session_id = new_session_id; self.head_hash = Some(new_head_hash); self.worker.as_mut().unwrap().set_history(new_history); + self.usage_history.lock().expect("usage_history poisoned").clear(); Ok(new_session_id) } @@ -658,6 +698,7 @@ impl Pod, St> { manifest_dir, compact_state: None, usage_tracker: Arc::new(UsageTracker::new()), + usage_history: Arc::new(Mutex::new(Vec::new())), tracker: None, }) } diff --git a/crates/pod/src/prune_hook.rs b/crates/pod/src/prune_hook.rs index c653f687..13672186 100644 --- a/crates/pod/src/prune_hook.rs +++ b/crates/pod/src/prune_hook.rs @@ -1,35 +1,77 @@ //! PruneHook — applies conditional pruning before each LLM request. //! -//! Wraps [`llm_worker::prune::prune()`] as a [`Hook`] so -//! that Pod can register it in the hook pipeline. +//! Wraps the pure `prune` API from `llm-worker` as a [`Hook`]. +//! `min_savings` の判定は usage 履歴ベースのトークン会計 +//! 
([`crate::token_counter::savings_for_drop_impl`]) で行う。 + +use std::sync::{Arc, Mutex}; use async_trait::async_trait; -use llm_worker::interceptor::PreRequestAction; -use llm_worker::prune::{PruneConfig, prune}; use llm_worker::Item; +use llm_worker::interceptor::PreRequestAction; +use llm_worker::prune::{PruneConfig, apply_prune, prunable_indices}; +use session_store::UsageRecord; use tracing::debug; use crate::hook::{Hook, PreLlmRequest}; +use crate::token_counter::{EstimateSource, savings_for_drop_impl}; /// Hook that conditionally prunes old tool-result content before each /// LLM request, reclaiming context-window tokens. +/// +/// `usage_history` は [`crate::Pod::usage_history_handle`] から共有された +/// `Arc>`。リクエスト直前に snapshot を取って savings を見積もる。 pub struct PruneHook { config: PruneConfig, + usage_history: Arc>>, } impl PruneHook { - pub fn new(config: PruneConfig) -> Self { - Self { config } + pub fn new(config: PruneConfig, usage_history: Arc>>) -> Self { + Self { + config, + usage_history, + } } } #[async_trait] impl Hook for PruneHook { async fn call(&self, context: &mut Vec) -> PreRequestAction { - if let Some(result) = prune(context, &self.config) { + let candidates = prunable_indices(context, self.config.protected_turns); + if candidates.is_empty() { + return PreRequestAction::Continue; + } + + // 候補範囲のトークン節約量を usage 履歴ベースで見積もる。 + // content だけ削除する場合の上限値(範囲全体を消した場合の savings)として + // 近似する。実際の content drop は items 数を変えないので、本来の savings + // はこの値以下。閾値判定は上振れ方向=「やや prune を発動しやすい」側で安全。 + let first = *candidates.first().unwrap(); + let last = *candidates.last().unwrap() + 1; + let snapshot = self + .usage_history + .lock() + .expect("usage_history poisoned") + .clone(); + let savings = savings_for_drop_impl(context, &snapshot, first..last); + + // measurement が無い場合 (NoData) は判定材料がないので prune を見送る。 + // 最初の LLM call が走るまでは usage_history が空なのでこのパスを通る。 + if matches!(savings.source, EstimateSource::NoData) { + return PreRequestAction::Continue; + } + + if savings.tokens 
< self.config.min_savings { + return PreRequestAction::Continue; + } + + let result = apply_prune(context, &candidates); + if result.pruned_count > 0 { debug!( pruned = result.pruned_count, - estimated_savings = result.estimated_savings, + estimated_savings_tokens = savings.tokens, + source = ?savings.source, "Pruned old tool-result content" ); } diff --git a/crates/pod/src/token_counter.rs b/crates/pod/src/token_counter.rs new file mode 100644 index 00000000..b554d139 --- /dev/null +++ b/crates/pod/src/token_counter.rs @@ -0,0 +1,404 @@ +//! Usage 履歴ベースのトークン会計。 +//! +//! `UsageRecord` の列(プロバイダ実測値)と現在の history から、 +//! 「末尾 N トークン残すための split 位置」「指定範囲を drop したときの +//! 節約トークン数」などを pure に計算する。 +//! +//! # 方針 +//! +//! - ローカルトークナイザは持たない。実測値があればそれを採用し、 +//! measurement 間はバイト数で按分、最新 measurement より先は最終 rate で外挿する +//! - 推定の出どころは [`EstimateSource`] で呼び出し側に明示する。 +//! 課金判断には使えないが、compact/prune の閾値判定には十分な精度 +//! - `records` は `history_len` 昇順を仮定する(`collect_state` と +//! `UsageTracker` がそのように積む) +//! +//! 公開 API は本ファイル内の `impl Pod` で [`Pod`](crate::Pod) のメソッドとして +//! 
生やしている。pure な補助関数はこのモジュール内に private に閉じる。 + +use std::ops::Range; + +use llm_worker::Item; +use llm_worker::llm_client::client::LlmClient; +use session_store::{Store, UsageRecord}; + +use crate::Pod; + +/// 推定の出どころ。 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EstimateSource { + /// measurement の境界にちょうど一致(実測値そのもの) + Measured, + /// 連続する 2 つの measurement の間をバイト按分で計算 + Interpolated, + /// 最後の measurement より新しい区間を最終 rate で外挿 + Extrapolated, + /// measurement が 1 件も無く、バイト数のみのフォールバック + NoData, +} + +impl EstimateSource { + fn rank(self) -> u8 { + match self { + Self::Measured => 0, + Self::Interpolated => 1, + Self::Extrapolated => 2, + Self::NoData => 3, + } + } + + /// 複数の推定を合成するときは一番「粗い」ものに揃える。 + fn worst(self, other: Self) -> Self { + if self.rank() >= other.rank() { + self + } else { + other + } + } +} + +/// トークン数の推定値。 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TokenEstimate { + pub tokens: u64, + pub source: EstimateSource, +} + +/// history を分割する位置。 +/// +/// `items[..index]` が捨てる/要約される側、`items[index..]` が残る側。 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct SplitPoint { + pub index: usize, + pub source: EstimateSource, +} + +/// `items[..i]` までの累積バイト数(`prefix[i]`)を返す。長さは `items.len()+1`。 +fn prefix_bytes(items: &[Item]) -> Vec { + let mut prefix = Vec::with_capacity(items.len() + 1); + let mut acc: u64 = 0; + prefix.push(0); + for item in items { + acc = acc.saturating_add(item_bytes(item)); + prefix.push(acc); + } + prefix +} + +/// 1 Item の大きさ。JSON シリアライズ長を使う粗い近似。 +/// トークン数との絶対変換ではなく区間の按分にしか使わないので、 +/// プロバイダごとの overhead は比率でキャンセルされる。 +fn item_bytes(item: &Item) -> u64 { + serde_json::to_string(item) + .map(|s| s.len() as u64) + .unwrap_or(0) +} + +/// `history[..index]` までのトークン数を推定する。 +fn tokens_at(history: &[Item], records: &[UsageRecord], index: usize) -> TokenEstimate { + debug_assert!(index <= history.len()); + + if index == 0 { + return TokenEstimate { + tokens: 0, + source: EstimateSource::Measured, + }; + } + + if 
records.is_empty() { + let prefix = prefix_bytes(history); + return TokenEstimate { + tokens: prefix[index] / 4, + source: EstimateSource::NoData, + }; + } + + // exact match(rev 走査で一番新しい record を採用) + if let Some(r) = records.iter().rev().find(|r| r.history_len == index) { + return TokenEstimate { + tokens: r.input_total_tokens, + source: EstimateSource::Measured, + }; + } + + let lower = records.iter().rev().find(|r| r.history_len < index); + let upper = records.iter().find(|r| r.history_len > index); + let prefix = prefix_bytes(history); + let cap = history.len(); + + match (lower, upper) { + (Some(lo), Some(up)) => { + let lo_bytes = prefix[lo.history_len.min(cap)]; + let up_bytes = prefix[up.history_len.min(cap)]; + let at_bytes = prefix[index]; + let span_bytes = up_bytes.saturating_sub(lo_bytes); + let span_tokens = up + .input_total_tokens + .saturating_sub(lo.input_total_tokens); + if span_bytes == 0 || span_tokens == 0 { + return TokenEstimate { + tokens: lo.input_total_tokens, + source: EstimateSource::Interpolated, + }; + } + let delta_bytes = at_bytes.saturating_sub(lo_bytes); + let delta_tokens = + (delta_bytes as u128 * span_tokens as u128 / span_bytes as u128) as u64; + TokenEstimate { + tokens: lo.input_total_tokens + delta_tokens, + source: EstimateSource::Interpolated, + } + } + (Some(lo), None) => { + let lo_bytes = prefix[lo.history_len.min(cap)]; + let at_bytes = prefix[index]; + if lo_bytes == 0 || lo.input_total_tokens == 0 { + return TokenEstimate { + tokens: lo.input_total_tokens, + source: EstimateSource::Extrapolated, + }; + } + let delta_bytes = at_bytes.saturating_sub(lo_bytes); + let delta_tokens = + (delta_bytes as u128 * lo.input_total_tokens as u128 / lo_bytes as u128) as u64; + TokenEstimate { + tokens: lo.input_total_tokens + delta_tokens, + source: EstimateSource::Extrapolated, + } + } + (None, Some(up)) => { + let up_bytes = prefix[up.history_len.min(cap)]; + let at_bytes = prefix[index]; + if up_bytes == 0 { + return 
TokenEstimate { + tokens: 0, + source: EstimateSource::Interpolated, + }; + } + let t = (at_bytes as u128 * up.input_total_tokens as u128 / up_bytes as u128) as u64; + TokenEstimate { + tokens: t, + source: EstimateSource::Interpolated, + } + } + (None, None) => unreachable!("records non-empty but neither lower nor upper matched"), + } +} + +fn total_tokens_impl(history: &[Item], records: &[UsageRecord]) -> TokenEstimate { + tokens_at(history, records, history.len()) +} + +fn split_for_retained_impl( + history: &[Item], + records: &[UsageRecord], + retained: u64, +) -> SplitPoint { + let current = total_tokens_impl(history, records); + if current.tokens <= retained { + return SplitPoint { + index: 0, + source: current.source, + }; + } + let target = current.tokens - retained; + + // `tokens_at` が target 以上になる最小の idx を線形探索。 + // history.len() は高々数百〜数千なので十分速い。将来ボトルネックになれば + // record 境界で二分探索に置き換える。 + let mut chosen_source = current.source; + for idx in 1..=history.len() { + let est = tokens_at(history, records, idx); + if est.tokens >= target { + chosen_source = est.source; + return SplitPoint { + index: idx, + source: chosen_source, + }; + } + } + SplitPoint { + index: history.len(), + source: chosen_source, + } +} + +pub(crate) fn savings_for_drop_impl( + history: &[Item], + records: &[UsageRecord], + range: Range, +) -> TokenEstimate { + if range.start >= range.end || range.end > history.len() { + return TokenEstimate { + tokens: 0, + source: EstimateSource::Measured, + }; + } + let s = tokens_at(history, records, range.start); + let e = tokens_at(history, records, range.end); + TokenEstimate { + tokens: e.tokens.saturating_sub(s.tokens), + source: s.source.worst(e.source), + } +} + +// ── Pod に生やす公開 API ─────────────────────────────────────────────── + +impl Pod { + /// 現在の history 全体の推定トークン数。 + /// + /// 最後の measurement と、その後に追加された未測定分のバイト按分/外挿。 + pub fn total_tokens(&self) -> TokenEstimate { + let usage = self.usage_history(); + 
total_tokens_impl(self.history(), &usage) + } + + /// 末尾から `retained` トークン以上を残すための分割位置。 + /// + /// `history[..cut.index]` が要約/破棄される側、`history[cut.index..]` が残る側。 + pub fn split_for_retained(&self, retained: u64) -> SplitPoint { + let usage = self.usage_history(); + split_for_retained_impl(self.history(), &usage, retained) + } + + /// 指定範囲を drop したときの節約トークン数の推定。 + pub fn savings_for_drop(&self, range: Range) -> TokenEstimate { + let usage = self.usage_history(); + savings_for_drop_impl(self.history(), &usage, range) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn msg(text: &str) -> Item { + Item::user_message(text) + } + + fn record(history_len: usize, tokens: u64) -> UsageRecord { + UsageRecord { + history_len, + input_total_tokens: tokens, + cache_read_tokens: 0, + cache_write_tokens: 0, + output_tokens: 0, + } + } + + #[test] + fn total_no_data_falls_back_to_byte_estimate() { + let history = vec![msg("hello world")]; + let est = total_tokens_impl(&history, &[]); + assert_eq!(est.source, EstimateSource::NoData); + assert!(est.tokens > 0); + } + + #[test] + fn total_measured_when_last_record_matches_history_len() { + let history = vec![msg("a"), msg("b"), msg("c")]; + let records = vec![record(3, 120)]; + let est = total_tokens_impl(&history, &records); + assert_eq!(est.source, EstimateSource::Measured); + assert_eq!(est.tokens, 120); + } + + #[test] + fn total_extrapolated_when_history_grew_past_last_measurement() { + let history = vec![msg("a"), msg("b"), msg("c"), msg("d")]; + let records = vec![record(3, 100)]; + let est = total_tokens_impl(&history, &records); + assert_eq!(est.source, EstimateSource::Extrapolated); + assert!(est.tokens > 100); + } + + #[test] + fn total_zero_history_is_zero() { + let est = total_tokens_impl(&[], &[]); + assert_eq!(est.tokens, 0); + } + + #[test] + fn split_returns_zero_when_current_below_retained() { + let history = vec![msg("a"), msg("b")]; + let records = vec![record(2, 50)]; + let cut = 
split_for_retained_impl(&history, &records, 1000); + assert_eq!(cut.index, 0); + } + + #[test] + fn split_at_exact_measurement_boundary() { + // 4 items。measurements: len=2 → 100, len=4 → 300。 + // retained=200 → target_drop = 100 → record[0] にぴったり一致 → index=2。 + let history = vec![msg("a"), msg("b"), msg("c"), msg("d")]; + let records = vec![record(2, 100), record(4, 300)]; + let cut = split_for_retained_impl(&history, &records, 200); + assert_eq!(cut.index, 2); + assert_eq!(cut.source, EstimateSource::Measured); + } + + #[test] + fn split_interpolated_between_measurements() { + let history = vec![ + msg("aaaaaa"), + msg("bbbbbb"), + msg("cccccc"), + msg("dddddd"), + ]; + let records = vec![record(1, 50), record(4, 400)]; + let cut = split_for_retained_impl(&history, &records, 250); + assert!(cut.index > 1 && cut.index <= 4); + assert_eq!(cut.source, EstimateSource::Interpolated); + } + + #[test] + fn split_all_when_retained_zero() { + let history = vec![msg("a"), msg("b")]; + let records = vec![record(2, 100)]; + let cut = split_for_retained_impl(&history, &records, 0); + assert_eq!(cut.index, 2); + } + + #[test] + fn savings_for_drop_uses_measurement_difference() { + let history = vec![msg("a"), msg("b"), msg("c")]; + let records = vec![record(1, 100), record(3, 300)]; + let est = savings_for_drop_impl(&history, &records, 1..3); + assert_eq!(est.tokens, 200); + assert_eq!(est.source, EstimateSource::Measured); + } + + #[test] + fn savings_for_drop_empty_range_is_zero() { + let history = vec![msg("a"), msg("b")]; + let records = vec![record(2, 100)]; + let est = savings_for_drop_impl(&history, &records, 1..1); + assert_eq!(est.tokens, 0); + } + + #[test] + fn savings_for_drop_interpolates_inside_measurement_span() { + // len=4 → 400 のみ。range 1..3 は原点 0 と upper=4 の間で按分。 + let history = vec![msg("aa"), msg("aa"), msg("aa"), msg("aa")]; + let records = vec![record(4, 400)]; + let est = savings_for_drop_impl(&history, &records, 1..3); + assert_eq!(est.source, 
EstimateSource::Interpolated); + assert!(est.tokens > 0 && est.tokens < 400); + } + + #[test] + fn savings_for_drop_no_records_falls_back_to_bytes() { + let history = vec![msg("hello hello"), msg("world world")]; + let est = savings_for_drop_impl(&history, &[], 0..1); + assert_eq!(est.source, EstimateSource::NoData); + assert!(est.tokens > 0); + } + + #[test] + fn savings_for_drop_out_of_range_is_zero() { + let history = vec![msg("a")]; + let records = vec![record(1, 100)]; + let est = savings_for_drop_impl(&history, &records, 0..5); + assert_eq!(est.tokens, 0); + } +} diff --git a/tickets/token-counter.md b/tickets/token-counter.md index 291c368d..a9b8fd79 100644 --- a/tickets/token-counter.md +++ b/tickets/token-counter.md @@ -141,6 +141,10 @@ if saved.tokens >= min_savings { 呼び出しに置き換え(呼び出し側で渡す) - prune の API シグネチャ調整は最小限に +## レビュー状態 + +Reviewed — [token-counter.review.md](token-counter.review.md) + ## 依存 - [usage-history.md](usage-history.md) — Usage を session-store に積む基盤 diff --git a/tickets/token-counter.review.md b/tickets/token-counter.review.md new file mode 100644 index 00000000..8cace4a7 --- /dev/null +++ b/tickets/token-counter.review.md @@ -0,0 +1,60 @@ +# token-counter レビュー + +## 要件の充足 + +チケットが定義した 3 API・型・アルゴリズムは全て実装されている: + +- `Pod::total_tokens()` → `TokenEstimate` +- `Pod::split_for_retained(retained)` → `SplitPoint` +- `Pod::savings_for_drop(range)` → `TokenEstimate` +- `EstimateSource`: `Measured / Interpolated / Extrapolated / NoData` + +設計方針(状態を持たない pure 関数、provider 非依存、ローカルトークナイザ不要)も +満たされている。`_impl` 関数群は `(&[Item], &[UsageRecord])` だけを受け取り、 +Pod メソッドは history と usage_history を渡すだけの薄いラッパー。 + +## アーキテクチャ + +| レイヤー | 変更内容 | +|---------|---------| +| pod::token_counter | pure な計算関数 + Pod のメソッドとして公開 | +| pod::Pod | `usage_history: Arc>>` を追加。restore で復元、persist_turn で追記、compact で clear | +| pod::PruneHook | min_savings 判定を `savings_for_drop_impl` に委譲。usage_history の shared handle を保持 | +| llm-worker::prune | `prune()` → `prunable_indices()` + 
`apply_prune()` に分解。min_savings 判定とトークン会計への依存を除去 | + +prune の責務分離が適切。llm-worker 側は pure な候補抽出と適用のみ、 +トークン会計への依存は pod 層に閉じている。 + +## 指摘と対処 + +### 1. split_for_retained_impl の O(n²) シリアライズ(非ブロッカー、未対処) + +`tokens_at` を `1..=history.len()` で毎回呼び、内部で `prefix_bytes`(history 全体の +JSON シリアライズ)を都度計算。長大セッションでは item 数に対して二乗になる。 +`prefix_bytes` をループ外で 1 回だけ計算して渡す形にリファクタリングすべきだが、 +現時点の history サイズでは実害なし。パフォーマンスが問題になった段階で対処。 + +### 2. PruneHook の savings 過大評価(認識済み、未対処) + +prune は content を None にするだけで item を消さないため、`savings_for_drop` +(範囲全体の drop を仮定)は実際の節約量より大きい値を返す。閾値判定としては +prune を発動しやすい方向=安全側。ログの `estimated_savings_tokens` が過大になる +点はチューニング時に注意。 + +### 3. compact 後の usage_history.clear()(後続チケットで対処) + +compact 直後は measurement が空になり `total_tokens()` が `NoData` を返す。 +compact-improvements で `last_input_tokens` を撤去して閾値判定を usage 経由に +一本化する際、この NoData 期間の扱いを設計する必要がある。 + +## テスト + +token_counter: 13 件(NoData / Measured / Extrapolated / Interpolated 各ケース、 +split の境界、savings の measurement 差分、空 range、out-of-range)。 + +prune (llm-worker): `prunable_indices` + `apply_prune` に分解後のテスト 6 件。 +候補抽出、適用、冪等性、既 prune 済み除外、境界。 + +## 判定 + +承認。