token-counter実装
This commit is contained in:
parent
7fb2e4bc6c
commit
f607a52fbb
|
|
@ -4,9 +4,9 @@
|
|||
//! their `summary`. This reclaims tokens while preserving the "what
|
||||
//! happened" trail.
|
||||
//!
|
||||
//! Pruning is **conditional**: it only fires when the estimated token
|
||||
//! savings exceed [`PruneConfig::min_savings`], avoiding unnecessary
|
||||
//! KV-cache invalidation.
|
||||
//! このモジュールは pure な「候補抽出」と「適用」だけを提供する。
|
||||
//! `min_savings` 判定や savings 推定はこの crate には置かず、上位層
|
||||
//! (`pod::prune_hook` など)が usage 履歴ベースのトークン会計と組み合わせて行う。
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
|
@ -20,17 +20,19 @@ pub struct PruneConfig {
|
|||
#[serde(default = "default_protected_turns")]
|
||||
pub protected_turns: usize,
|
||||
|
||||
/// Minimum estimated token savings required to actually prune.
|
||||
/// If the prunable content is smaller than this, we skip to
|
||||
/// avoid pointless KV-cache invalidation.
|
||||
/// Minimum token savings required to actually prune. If the prunable
|
||||
/// content is smaller than this, the caller should skip to avoid
|
||||
/// pointless KV-cache invalidation. The unit is tokens; the caller
|
||||
/// is responsible for measuring savings via a usage-history-aware
|
||||
/// estimator and comparing against this threshold.
|
||||
#[serde(default = "default_min_savings")]
|
||||
pub min_savings: usize,
|
||||
pub min_savings: u64,
|
||||
}
|
||||
|
||||
fn default_protected_turns() -> usize {
|
||||
3
|
||||
}
|
||||
fn default_min_savings() -> usize {
|
||||
fn default_min_savings() -> u64 {
|
||||
4096
|
||||
}
|
||||
|
||||
|
|
@ -43,18 +45,11 @@ impl Default for PruneConfig {
|
|||
}
|
||||
}
|
||||
|
||||
/// Result of a prune operation.
|
||||
/// Result of [`apply_prune`].
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PruneResult {
|
||||
/// Number of items whose `content` was set to `None`.
|
||||
pub pruned_count: usize,
|
||||
/// Estimated tokens reclaimed.
|
||||
pub estimated_savings: usize,
|
||||
}
|
||||
|
||||
/// Estimate the token count of a string (rough: chars / 4).
|
||||
fn estimate_tokens(s: &str) -> usize {
|
||||
s.len() / 4
|
||||
}
|
||||
|
||||
/// Find indices where each "turn" begins.
|
||||
|
|
@ -70,59 +65,45 @@ fn find_turn_starts(items: &[Item]) -> Vec<usize> {
|
|||
.collect()
|
||||
}
|
||||
|
||||
/// Conditionally prune old tool-result content from `items`.
|
||||
/// Indices of `Item::ToolResult { content: Some(_), .. }` that lie outside
|
||||
/// the last `protected_turns` turns. Pure: does not mutate `items`.
|
||||
///
|
||||
/// Returns `None` if pruning was skipped (not enough savings or not
|
||||
/// enough turns). Returns `Some(PruneResult)` if items were modified.
|
||||
///
|
||||
/// # Algorithm
|
||||
///
|
||||
/// 1. Identify turn boundaries (user-message positions).
|
||||
/// 2. Compute the protection boundary: items before the last
|
||||
/// `protected_turns` turns are candidates.
|
||||
/// 3. Sum the estimated token savings from prunable `content` fields.
|
||||
/// 4. If savings < `min_savings`, skip.
|
||||
/// 5. Otherwise, set `content = None` on each candidate.
|
||||
pub fn prune(items: &mut [Item], config: &PruneConfig) -> Option<PruneResult> {
|
||||
/// Returns an empty vector when there are too few turns or no prunable
|
||||
/// candidates.
|
||||
pub fn prunable_indices(items: &[Item], protected_turns: usize) -> Vec<usize> {
|
||||
let turn_starts = find_turn_starts(items);
|
||||
|
||||
// Not enough turns to have anything outside the protected window.
|
||||
if turn_starts.len() <= config.protected_turns {
|
||||
return None;
|
||||
if turn_starts.len() <= protected_turns {
|
||||
return Vec::new();
|
||||
}
|
||||
let boundary = turn_starts[turn_starts.len() - protected_turns];
|
||||
items[..boundary]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, item)| match item {
|
||||
Item::ToolResult {
|
||||
content: Some(_), ..
|
||||
} => Some(i),
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// Everything before this index is a prune candidate.
|
||||
let boundary = turn_starts[turn_starts.len() - config.protected_turns];
|
||||
|
||||
// Collect prunable indices and total savings.
|
||||
let mut total_savings: usize = 0;
|
||||
let mut prunable: Vec<usize> = Vec::new();
|
||||
|
||||
for (i, item) in items[..boundary].iter().enumerate() {
|
||||
if let Item::ToolResult {
|
||||
content: Some(c), ..
|
||||
} = item
|
||||
{
|
||||
total_savings += estimate_tokens(c);
|
||||
prunable.push(i);
|
||||
}
|
||||
}
|
||||
|
||||
if prunable.is_empty() || total_savings < config.min_savings {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Apply: drop content, keep summary.
|
||||
for &i in &prunable {
|
||||
/// Set `content = None` on each item at `indices`. Returns the number
|
||||
/// of items that were actually modified (already-pruned items are
|
||||
/// counted as 0).
|
||||
pub fn apply_prune(items: &mut [Item], indices: &[usize]) -> PruneResult {
|
||||
let mut count = 0;
|
||||
for &i in indices {
|
||||
if let Item::ToolResult { content, .. } = &mut items[i] {
|
||||
*content = None;
|
||||
if content.is_some() {
|
||||
*content = None;
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(PruneResult {
|
||||
pruned_count: prunable.len(),
|
||||
estimated_savings: total_savings,
|
||||
})
|
||||
PruneResult {
|
||||
pruned_count: count,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -148,53 +129,48 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn no_prune_when_too_few_turns() {
|
||||
let mut items = make_history(&[
|
||||
fn no_candidates_when_too_few_turns() {
|
||||
let items = make_history(&[
|
||||
("turn1", vec![("summary1", Some("big content here"))]),
|
||||
("turn2", vec![("summary2", Some("more content"))]),
|
||||
]);
|
||||
let config = PruneConfig {
|
||||
protected_turns: 3,
|
||||
min_savings: 0,
|
||||
};
|
||||
assert!(prune(&mut items, &config).is_none());
|
||||
assert!(prunable_indices(&items, 3).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_prune_when_savings_below_threshold() {
|
||||
let mut items = make_history(&[
|
||||
("turn1", vec![("s", Some("tiny"))]), // ~1 token
|
||||
("turn2", vec![]),
|
||||
("turn3", vec![]),
|
||||
("turn4", vec![]),
|
||||
fn candidates_in_unprotected_turns() {
|
||||
let big = "x".repeat(4096 * 4);
|
||||
let items = make_history(&[
|
||||
("turn1", vec![("s1", Some(&big))]),
|
||||
("turn2", vec![("s2", Some(&big))]),
|
||||
("turn3", vec![("s3", Some("keep me"))]),
|
||||
("turn4", vec![("s4", Some("keep me too"))]),
|
||||
]);
|
||||
let config = PruneConfig {
|
||||
protected_turns: 2,
|
||||
min_savings: 9999,
|
||||
};
|
||||
assert!(prune(&mut items, &config).is_none());
|
||||
let candidates = prunable_indices(&items, 2);
|
||||
assert_eq!(candidates.len(), 2);
|
||||
// 候補は turn1 と turn2 の ToolResult のみ
|
||||
for &i in &candidates {
|
||||
if let Item::ToolResult { summary, .. } = &items[i] {
|
||||
assert!(summary == "s1" || summary == "s2");
|
||||
} else {
|
||||
panic!("non tool-result selected");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prune_old_content() {
|
||||
// 4 turns. protected_turns=2 → turns 1-2 are candidates.
|
||||
let big = "x".repeat(4096 * 4); // ~4096 tokens
|
||||
fn apply_drops_content_only() {
|
||||
let big = "x".repeat(64);
|
||||
let mut items = make_history(&[
|
||||
("turn1", vec![("s1", Some(&big))]),
|
||||
("turn2", vec![("s2", Some(&big))]),
|
||||
("turn3", vec![("s3", Some("keep me"))]),
|
||||
("turn4", vec![("s4", Some("keep me too"))]),
|
||||
]);
|
||||
let config = PruneConfig {
|
||||
protected_turns: 2,
|
||||
min_savings: 1000,
|
||||
};
|
||||
|
||||
let result = prune(&mut items, &config).expect("should prune");
|
||||
let candidates = prunable_indices(&items, 2);
|
||||
let result = apply_prune(&mut items, &candidates);
|
||||
assert_eq!(result.pruned_count, 2);
|
||||
assert!(result.estimated_savings >= 8000);
|
||||
|
||||
// Verify: pruned items have content=None, protected items keep content.
|
||||
for item in &items {
|
||||
if let Item::ToolResult {
|
||||
summary, content, ..
|
||||
|
|
@ -210,73 +186,49 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn idempotent() {
|
||||
let big = "x".repeat(4096 * 4);
|
||||
fn apply_is_idempotent() {
|
||||
let big = "x".repeat(64);
|
||||
let mut items = make_history(&[
|
||||
("turn1", vec![("s1", Some(&big))]),
|
||||
("turn2", vec![]),
|
||||
("turn3", vec![]),
|
||||
("turn4", vec![]),
|
||||
]);
|
||||
let config = PruneConfig {
|
||||
protected_turns: 2,
|
||||
min_savings: 100,
|
||||
};
|
||||
let first_indices = prunable_indices(&items, 2);
|
||||
assert_eq!(apply_prune(&mut items, &first_indices).pruned_count, 1);
|
||||
|
||||
let first = prune(&mut items, &config).expect("first prune");
|
||||
assert_eq!(first.pruned_count, 1);
|
||||
|
||||
// Second call: nothing left to prune.
|
||||
assert!(prune(&mut items, &config).is_none());
|
||||
// 2 周目: 候補は (まだ) いるかもしれないが、すでに content=None なので
|
||||
// apply_prune は 0 件と数える。
|
||||
let second_indices = prunable_indices(&items, 2);
|
||||
assert!(second_indices.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn already_pruned_items_skipped() {
|
||||
// Items that already have content=None are not counted as savings.
|
||||
let mut items = make_history(&[
|
||||
("turn1", vec![("s1", None)]), // already pruned
|
||||
fn already_pruned_items_excluded_from_candidates() {
|
||||
let items = make_history(&[
|
||||
("turn1", vec![("s1", None)]), // already pruned (content=None)
|
||||
("turn2", vec![]),
|
||||
("turn3", vec![]),
|
||||
("turn4", vec![]),
|
||||
]);
|
||||
let config = PruneConfig {
|
||||
protected_turns: 2,
|
||||
min_savings: 0, // Even with threshold 0, no savings means no prune
|
||||
};
|
||||
|
||||
assert!(prune(&mut items, &config).is_none());
|
||||
assert!(prunable_indices(&items, 2).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn protected_turns_boundary_exact() {
|
||||
// 3 turns with protected_turns=2:
|
||||
// Turn 1 content should be pruned, turns 2-3 protected.
|
||||
let big = "x".repeat(4096 * 4);
|
||||
let mut items = make_history(&[
|
||||
// 3 turns with protected_turns=2: only turn 1 is a candidate.
|
||||
let big = "x".repeat(64);
|
||||
let items = make_history(&[
|
||||
("turn1", vec![("s1", Some(&big))]),
|
||||
("turn2", vec![("s2", Some("protected"))]),
|
||||
("turn3", vec![("s3", Some("also protected"))]),
|
||||
]);
|
||||
let config = PruneConfig {
|
||||
protected_turns: 2,
|
||||
min_savings: 100,
|
||||
};
|
||||
|
||||
let result = prune(&mut items, &config).expect("should prune turn1");
|
||||
assert_eq!(result.pruned_count, 1);
|
||||
|
||||
// Verify s1 pruned, s2 and s3 intact.
|
||||
for item in &items {
|
||||
if let Item::ToolResult {
|
||||
summary, content, ..
|
||||
} = item
|
||||
{
|
||||
match summary.as_str() {
|
||||
"s1" => assert!(content.is_none()),
|
||||
"s2" | "s3" => assert!(content.is_some()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
let candidates = prunable_indices(&items, 2);
|
||||
assert_eq!(candidates.len(), 1);
|
||||
if let Item::ToolResult { summary, .. } = &items[candidates[0]] {
|
||||
assert_eq!(summary, "s1");
|
||||
} else {
|
||||
panic!("expected ToolResult at candidate index");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,8 +10,11 @@ mod compact_interceptor;
|
|||
mod compact_state;
|
||||
mod hook_interceptor;
|
||||
mod pod;
|
||||
mod token_counter;
|
||||
mod usage_tracker;
|
||||
|
||||
pub use token_counter::{EstimateSource, SplitPoint, TokenEstimate};
|
||||
|
||||
pub use controller::{PodController, PodHandle};
|
||||
pub use manifest::{PodManifest, ProviderConfig, ProviderKind, Scope};
|
||||
pub use hook::{Hook, HookEventKind, HookRegistryBuilder};
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use llm_worker::Item;
|
||||
use llm_worker::llm_client::client::LlmClient;
|
||||
|
|
@ -7,7 +7,7 @@ use llm_worker::llm_client::RequestConfig;
|
|||
use llm_worker::state::Mutable;
|
||||
use llm_worker::{Worker, WorkerError, WorkerResult};
|
||||
use session_store::{
|
||||
EntryHash, Outcome, SessionId, SessionStartState, Store, StoreError,
|
||||
EntryHash, Outcome, SessionId, SessionStartState, Store, StoreError, UsageRecord,
|
||||
};
|
||||
use tracing::{info, warn};
|
||||
|
||||
|
|
@ -75,6 +75,14 @@ pub struct Pod<C: LlmClient, St: Store> {
|
|||
/// Captures `(history_len, UsageEvent)` pairs during a run; drained
|
||||
/// in `persist_turn` and persisted as `LogEntry::LlmUsage` entries.
|
||||
usage_tracker: Arc<UsageTracker>,
|
||||
/// Cumulative Usage measurement timeline, one entry per LLM call.
|
||||
/// Restored from session log on `restore`, appended on each persist.
|
||||
/// Read by token-accounting APIs (`Pod::total_tokens`, etc.).
|
||||
///
|
||||
/// Wrapped in `Arc<Mutex>` so that hooks living on the Worker
|
||||
/// (e.g. `PruneHook`) can share the same view via
|
||||
/// [`Pod::usage_history_handle`].
|
||||
usage_history: Arc<Mutex<Vec<UsageRecord>>>,
|
||||
/// Session-lifetime file-operation tracker from the builtin `tools`
|
||||
/// crate. Populated by the Controller when it registers the builtin
|
||||
/// tools so that Pod-owned operations (e.g. compaction) can consult
|
||||
|
|
@ -108,6 +116,7 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
|||
manifest_dir: None,
|
||||
compact_state: None,
|
||||
usage_tracker: Arc::new(UsageTracker::new()),
|
||||
usage_history: Arc::new(Mutex::new(Vec::<UsageRecord>::new())),
|
||||
tracker: None,
|
||||
})
|
||||
}
|
||||
|
|
@ -142,6 +151,7 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
|||
manifest_dir: None,
|
||||
compact_state: None,
|
||||
usage_tracker: Arc::new(UsageTracker::new()),
|
||||
usage_history: Arc::new(Mutex::new(state.usage_history)),
|
||||
tracker: None,
|
||||
})
|
||||
}
|
||||
|
|
@ -179,6 +189,30 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
|||
&self.store
|
||||
}
|
||||
|
||||
/// Current history items held by the underlying Worker.
|
||||
pub fn history(&self) -> &[Item] {
|
||||
self.worker().history()
|
||||
}
|
||||
|
||||
/// Snapshot of the cumulative LLM Usage measurement timeline.
|
||||
///
|
||||
/// One entry per LLM call. Restored on `restore` and appended in
|
||||
/// `persist_turn`. Used by token-accounting APIs in [`token_counter`].
|
||||
/// Returns a clone since the underlying vector is shared with hooks
|
||||
/// running on the Worker.
|
||||
pub fn usage_history(&self) -> Vec<UsageRecord> {
|
||||
self.usage_history.lock().expect("usage_history poisoned").clone()
|
||||
}
|
||||
|
||||
/// Shared handle to the cumulative Usage history.
|
||||
///
|
||||
/// Hooks (e.g. `PruneHook`) take a clone of this `Arc` so they can
|
||||
/// read the latest measurements at request time. The handle outlives
|
||||
/// any individual run.
|
||||
pub fn usage_history_handle(&self) -> Arc<Mutex<Vec<UsageRecord>>> {
|
||||
self.usage_history.clone()
|
||||
}
|
||||
|
||||
/// Attach the session-scoped file-operation tracker from the builtin
|
||||
/// `tools` crate. Called by the Controller immediately after it
|
||||
/// registers the builtin tools on the Worker. Overwrites any
|
||||
|
|
@ -483,7 +517,9 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
|||
|
||||
// Persist any LLM Usage measurements collected during this run.
|
||||
// One LogEntry::LlmUsage per LLM call (the tool loop may have run
|
||||
// many calls within a single Pod::run).
|
||||
// many calls within a single Pod::run). Each is also appended to
|
||||
// the in-memory `usage_history` so token-accounting APIs see it
|
||||
// before the next run.
|
||||
let usage_records = self.usage_tracker.drain();
|
||||
for record in usage_records {
|
||||
session_store::save_usage(
|
||||
|
|
@ -497,6 +533,7 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
|||
record.output_tokens,
|
||||
)
|
||||
.await?;
|
||||
self.usage_history.lock().expect("usage_history poisoned").push(record);
|
||||
}
|
||||
|
||||
let interrupted = self.worker.as_ref().unwrap().last_run_interrupted();
|
||||
|
|
@ -601,10 +638,13 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
|||
)
|
||||
.await?;
|
||||
|
||||
// Swap in the new session state.
|
||||
// Swap in the new session state. usage_history belongs to the old
|
||||
// session — the new compacted session starts with no measurements
|
||||
// until its first LLM call.
|
||||
self.session_id = new_session_id;
|
||||
self.head_hash = Some(new_head_hash);
|
||||
self.worker.as_mut().unwrap().set_history(new_history);
|
||||
self.usage_history.lock().expect("usage_history poisoned").clear();
|
||||
|
||||
Ok(new_session_id)
|
||||
}
|
||||
|
|
@ -658,6 +698,7 @@ impl<St: Store> Pod<Box<dyn LlmClient>, St> {
|
|||
manifest_dir,
|
||||
compact_state: None,
|
||||
usage_tracker: Arc::new(UsageTracker::new()),
|
||||
usage_history: Arc::new(Mutex::new(Vec::new())),
|
||||
tracker: None,
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,35 +1,77 @@
|
|||
//! PruneHook — applies conditional pruning before each LLM request.
|
||||
//!
|
||||
//! Wraps [`llm_worker::prune::prune()`] as a [`Hook<PreLlmRequest>`] so
|
||||
//! that Pod can register it in the hook pipeline.
|
||||
//! Wraps the pure `prune` API from `llm-worker` as a [`Hook<PreLlmRequest>`].
|
||||
//! `min_savings` の判定は usage 履歴ベースのトークン会計
|
||||
//! ([`crate::token_counter::savings_for_drop_impl`]) で行う。
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use llm_worker::interceptor::PreRequestAction;
|
||||
use llm_worker::prune::{PruneConfig, prune};
|
||||
use llm_worker::Item;
|
||||
use llm_worker::interceptor::PreRequestAction;
|
||||
use llm_worker::prune::{PruneConfig, apply_prune, prunable_indices};
|
||||
use session_store::UsageRecord;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::hook::{Hook, PreLlmRequest};
|
||||
use crate::token_counter::{EstimateSource, savings_for_drop_impl};
|
||||
|
||||
/// Hook that conditionally prunes old tool-result content before each
|
||||
/// LLM request, reclaiming context-window tokens.
|
||||
///
|
||||
/// `usage_history` は [`crate::Pod::usage_history_handle`] から共有された
|
||||
/// `Arc<Mutex<_>>`。リクエスト直前に snapshot を取って savings を見積もる。
|
||||
pub struct PruneHook {
|
||||
config: PruneConfig,
|
||||
usage_history: Arc<Mutex<Vec<UsageRecord>>>,
|
||||
}
|
||||
|
||||
impl PruneHook {
|
||||
pub fn new(config: PruneConfig) -> Self {
|
||||
Self { config }
|
||||
pub fn new(config: PruneConfig, usage_history: Arc<Mutex<Vec<UsageRecord>>>) -> Self {
|
||||
Self {
|
||||
config,
|
||||
usage_history,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<PreLlmRequest> for PruneHook {
|
||||
async fn call(&self, context: &mut Vec<Item>) -> PreRequestAction {
|
||||
if let Some(result) = prune(context, &self.config) {
|
||||
let candidates = prunable_indices(context, self.config.protected_turns);
|
||||
if candidates.is_empty() {
|
||||
return PreRequestAction::Continue;
|
||||
}
|
||||
|
||||
// 候補範囲のトークン節約量を usage 履歴ベースで見積もる。
|
||||
// content だけ削除する場合の上限値(範囲全体を消した場合の savings)として
|
||||
// 近似する。実際の content drop は items 数を変えないので、本来の savings
|
||||
// はこの値以下。閾値判定は上振れ方向=「やや prune を発動しやすい」側で安全。
|
||||
let first = *candidates.first().unwrap();
|
||||
let last = *candidates.last().unwrap() + 1;
|
||||
let snapshot = self
|
||||
.usage_history
|
||||
.lock()
|
||||
.expect("usage_history poisoned")
|
||||
.clone();
|
||||
let savings = savings_for_drop_impl(context, &snapshot, first..last);
|
||||
|
||||
// measurement が無い場合 (NoData) は判定材料がないので prune を見送る。
|
||||
// 最初の LLM call が走るまでは usage_history が空なのでこのパスを通る。
|
||||
if matches!(savings.source, EstimateSource::NoData) {
|
||||
return PreRequestAction::Continue;
|
||||
}
|
||||
|
||||
if savings.tokens < self.config.min_savings {
|
||||
return PreRequestAction::Continue;
|
||||
}
|
||||
|
||||
let result = apply_prune(context, &candidates);
|
||||
if result.pruned_count > 0 {
|
||||
debug!(
|
||||
pruned = result.pruned_count,
|
||||
estimated_savings = result.estimated_savings,
|
||||
estimated_savings_tokens = savings.tokens,
|
||||
source = ?savings.source,
|
||||
"Pruned old tool-result content"
|
||||
);
|
||||
}
|
||||
|
|
|
|||
404
crates/pod/src/token_counter.rs
Normal file
404
crates/pod/src/token_counter.rs
Normal file
|
|
@ -0,0 +1,404 @@
|
|||
//! Usage 履歴ベースのトークン会計。
|
||||
//!
|
||||
//! `UsageRecord` の列(プロバイダ実測値)と現在の history から、
|
||||
//! 「末尾 N トークン残すための split 位置」「指定範囲を drop したときの
|
||||
//! 節約トークン数」などを pure に計算する。
|
||||
//!
|
||||
//! # 方針
|
||||
//!
|
||||
//! - ローカルトークナイザは持たない。実測値があればそれを採用し、
|
||||
//! measurement 間はバイト数で按分、最新 measurement より先は最終 rate で外挿する
|
||||
//! - 推定の出どころは [`EstimateSource`] で呼び出し側に明示する。
|
||||
//! 課金判断には使えないが、compact/prune の閾値判定には十分な精度
|
||||
//! - `records` は `history_len` 昇順を仮定する(`collect_state` と
|
||||
//! `UsageTracker` がそのように積む)
|
||||
//!
|
||||
//! 公開 API は本ファイル内の `impl Pod` で [`Pod`](crate::Pod) のメソッドとして
|
||||
//! 生やしている。pure な補助関数はこのモジュール内に private に閉じる。
|
||||
|
||||
use std::ops::Range;
|
||||
|
||||
use llm_worker::Item;
|
||||
use llm_worker::llm_client::client::LlmClient;
|
||||
use session_store::{Store, UsageRecord};
|
||||
|
||||
use crate::Pod;
|
||||
|
||||
/// Where a token estimate came from, ordered from most to least reliable.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EstimateSource {
    /// Falls exactly on a measurement boundary (the measured value itself).
    Measured,
    /// Computed by byte-proportional interpolation between two consecutive
    /// measurements.
    Interpolated,
    /// Lies past the last measurement; extrapolated with the final rate.
    Extrapolated,
    /// No measurement exists at all; byte-count-only fallback.
    NoData,
}

impl EstimateSource {
    /// Coarseness rank: 0 for the most reliable source, 3 for the least.
    fn rank(self) -> u8 {
        match self {
            EstimateSource::Measured => 0,
            EstimateSource::Interpolated => 1,
            EstimateSource::Extrapolated => 2,
            EstimateSource::NoData => 3,
        }
    }

    /// Combines two sources by keeping the coarser (higher-ranked) one.
    fn worst(self, other: Self) -> Self {
        match self.rank().cmp(&other.rank()) {
            std::cmp::Ordering::Less => other,
            std::cmp::Ordering::Equal | std::cmp::Ordering::Greater => self,
        }
    }
}

/// An estimated token count together with its provenance.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct TokenEstimate {
    pub tokens: u64,
    pub source: EstimateSource,
}

/// A position at which to split the history.
///
/// `items[..index]` is the side to drop or summarize; `items[index..]`
/// is the side that stays.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SplitPoint {
    pub index: usize,
    pub source: EstimateSource,
}
|
||||
|
||||
/// `items[..i]` までの累積バイト数(`prefix[i]`)を返す。長さは `items.len()+1`。
|
||||
fn prefix_bytes(items: &[Item]) -> Vec<u64> {
|
||||
let mut prefix = Vec::with_capacity(items.len() + 1);
|
||||
let mut acc: u64 = 0;
|
||||
prefix.push(0);
|
||||
for item in items {
|
||||
acc = acc.saturating_add(item_bytes(item));
|
||||
prefix.push(acc);
|
||||
}
|
||||
prefix
|
||||
}
|
||||
|
||||
/// 1 Item の大きさ。JSON シリアライズ長を使う粗い近似。
|
||||
/// トークン数との絶対変換ではなく区間の按分にしか使わないので、
|
||||
/// プロバイダごとの overhead は比率でキャンセルされる。
|
||||
fn item_bytes(item: &Item) -> u64 {
|
||||
serde_json::to_string(item)
|
||||
.map(|s| s.len() as u64)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// `history[..index]` までのトークン数を推定する。
|
||||
fn tokens_at(history: &[Item], records: &[UsageRecord], index: usize) -> TokenEstimate {
|
||||
debug_assert!(index <= history.len());
|
||||
|
||||
if index == 0 {
|
||||
return TokenEstimate {
|
||||
tokens: 0,
|
||||
source: EstimateSource::Measured,
|
||||
};
|
||||
}
|
||||
|
||||
if records.is_empty() {
|
||||
let prefix = prefix_bytes(history);
|
||||
return TokenEstimate {
|
||||
tokens: prefix[index] / 4,
|
||||
source: EstimateSource::NoData,
|
||||
};
|
||||
}
|
||||
|
||||
// exact match(rev 走査で一番新しい record を採用)
|
||||
if let Some(r) = records.iter().rev().find(|r| r.history_len == index) {
|
||||
return TokenEstimate {
|
||||
tokens: r.input_total_tokens,
|
||||
source: EstimateSource::Measured,
|
||||
};
|
||||
}
|
||||
|
||||
let lower = records.iter().rev().find(|r| r.history_len < index);
|
||||
let upper = records.iter().find(|r| r.history_len > index);
|
||||
let prefix = prefix_bytes(history);
|
||||
let cap = history.len();
|
||||
|
||||
match (lower, upper) {
|
||||
(Some(lo), Some(up)) => {
|
||||
let lo_bytes = prefix[lo.history_len.min(cap)];
|
||||
let up_bytes = prefix[up.history_len.min(cap)];
|
||||
let at_bytes = prefix[index];
|
||||
let span_bytes = up_bytes.saturating_sub(lo_bytes);
|
||||
let span_tokens = up
|
||||
.input_total_tokens
|
||||
.saturating_sub(lo.input_total_tokens);
|
||||
if span_bytes == 0 || span_tokens == 0 {
|
||||
return TokenEstimate {
|
||||
tokens: lo.input_total_tokens,
|
||||
source: EstimateSource::Interpolated,
|
||||
};
|
||||
}
|
||||
let delta_bytes = at_bytes.saturating_sub(lo_bytes);
|
||||
let delta_tokens =
|
||||
(delta_bytes as u128 * span_tokens as u128 / span_bytes as u128) as u64;
|
||||
TokenEstimate {
|
||||
tokens: lo.input_total_tokens + delta_tokens,
|
||||
source: EstimateSource::Interpolated,
|
||||
}
|
||||
}
|
||||
(Some(lo), None) => {
|
||||
let lo_bytes = prefix[lo.history_len.min(cap)];
|
||||
let at_bytes = prefix[index];
|
||||
if lo_bytes == 0 || lo.input_total_tokens == 0 {
|
||||
return TokenEstimate {
|
||||
tokens: lo.input_total_tokens,
|
||||
source: EstimateSource::Extrapolated,
|
||||
};
|
||||
}
|
||||
let delta_bytes = at_bytes.saturating_sub(lo_bytes);
|
||||
let delta_tokens =
|
||||
(delta_bytes as u128 * lo.input_total_tokens as u128 / lo_bytes as u128) as u64;
|
||||
TokenEstimate {
|
||||
tokens: lo.input_total_tokens + delta_tokens,
|
||||
source: EstimateSource::Extrapolated,
|
||||
}
|
||||
}
|
||||
(None, Some(up)) => {
|
||||
let up_bytes = prefix[up.history_len.min(cap)];
|
||||
let at_bytes = prefix[index];
|
||||
if up_bytes == 0 {
|
||||
return TokenEstimate {
|
||||
tokens: 0,
|
||||
source: EstimateSource::Interpolated,
|
||||
};
|
||||
}
|
||||
let t = (at_bytes as u128 * up.input_total_tokens as u128 / up_bytes as u128) as u64;
|
||||
TokenEstimate {
|
||||
tokens: t,
|
||||
source: EstimateSource::Interpolated,
|
||||
}
|
||||
}
|
||||
(None, None) => unreachable!("records non-empty but neither lower nor upper matched"),
|
||||
}
|
||||
}
|
||||
|
||||
fn total_tokens_impl(history: &[Item], records: &[UsageRecord]) -> TokenEstimate {
|
||||
tokens_at(history, records, history.len())
|
||||
}
|
||||
|
||||
fn split_for_retained_impl(
|
||||
history: &[Item],
|
||||
records: &[UsageRecord],
|
||||
retained: u64,
|
||||
) -> SplitPoint {
|
||||
let current = total_tokens_impl(history, records);
|
||||
if current.tokens <= retained {
|
||||
return SplitPoint {
|
||||
index: 0,
|
||||
source: current.source,
|
||||
};
|
||||
}
|
||||
let target = current.tokens - retained;
|
||||
|
||||
// `tokens_at` が target 以上になる最小の idx を線形探索。
|
||||
// history.len() は高々数百〜数千なので十分速い。将来ボトルネックになれば
|
||||
// record 境界で二分探索に置き換える。
|
||||
let mut chosen_source = current.source;
|
||||
for idx in 1..=history.len() {
|
||||
let est = tokens_at(history, records, idx);
|
||||
if est.tokens >= target {
|
||||
chosen_source = est.source;
|
||||
return SplitPoint {
|
||||
index: idx,
|
||||
source: chosen_source,
|
||||
};
|
||||
}
|
||||
}
|
||||
SplitPoint {
|
||||
index: history.len(),
|
||||
source: chosen_source,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn savings_for_drop_impl(
|
||||
history: &[Item],
|
||||
records: &[UsageRecord],
|
||||
range: Range<usize>,
|
||||
) -> TokenEstimate {
|
||||
if range.start >= range.end || range.end > history.len() {
|
||||
return TokenEstimate {
|
||||
tokens: 0,
|
||||
source: EstimateSource::Measured,
|
||||
};
|
||||
}
|
||||
let s = tokens_at(history, records, range.start);
|
||||
let e = tokens_at(history, records, range.end);
|
||||
TokenEstimate {
|
||||
tokens: e.tokens.saturating_sub(s.tokens),
|
||||
source: s.source.worst(e.source),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Pod に生やす公開 API ───────────────────────────────────────────────
|
||||
|
||||
impl<C: LlmClient, St: Store> Pod<C, St> {
|
||||
/// 現在の history 全体の推定トークン数。
|
||||
///
|
||||
/// 最後の measurement と、その後に追加された未測定分のバイト按分/外挿。
|
||||
pub fn total_tokens(&self) -> TokenEstimate {
|
||||
let usage = self.usage_history();
|
||||
total_tokens_impl(self.history(), &usage)
|
||||
}
|
||||
|
||||
/// 末尾から `retained` トークン以上を残すための分割位置。
|
||||
///
|
||||
/// `history[..cut.index]` が要約/破棄される側、`history[cut.index..]` が残る側。
|
||||
pub fn split_for_retained(&self, retained: u64) -> SplitPoint {
|
||||
let usage = self.usage_history();
|
||||
split_for_retained_impl(self.history(), &usage, retained)
|
||||
}
|
||||
|
||||
/// 指定範囲を drop したときの節約トークン数の推定。
|
||||
pub fn savings_for_drop(&self, range: Range<usize>) -> TokenEstimate {
|
||||
let usage = self.usage_history();
|
||||
savings_for_drop_impl(self.history(), &usage, range)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn msg(text: &str) -> Item {
|
||||
Item::user_message(text)
|
||||
}
|
||||
|
||||
fn record(history_len: usize, tokens: u64) -> UsageRecord {
|
||||
UsageRecord {
|
||||
history_len,
|
||||
input_total_tokens: tokens,
|
||||
cache_read_tokens: 0,
|
||||
cache_write_tokens: 0,
|
||||
output_tokens: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn total_no_data_falls_back_to_byte_estimate() {
|
||||
let history = vec![msg("hello world")];
|
||||
let est = total_tokens_impl(&history, &[]);
|
||||
assert_eq!(est.source, EstimateSource::NoData);
|
||||
assert!(est.tokens > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn total_measured_when_last_record_matches_history_len() {
|
||||
let history = vec![msg("a"), msg("b"), msg("c")];
|
||||
let records = vec![record(3, 120)];
|
||||
let est = total_tokens_impl(&history, &records);
|
||||
assert_eq!(est.source, EstimateSource::Measured);
|
||||
assert_eq!(est.tokens, 120);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn total_extrapolated_when_history_grew_past_last_measurement() {
|
||||
let history = vec![msg("a"), msg("b"), msg("c"), msg("d")];
|
||||
let records = vec![record(3, 100)];
|
||||
let est = total_tokens_impl(&history, &records);
|
||||
assert_eq!(est.source, EstimateSource::Extrapolated);
|
||||
assert!(est.tokens > 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn total_zero_history_is_zero() {
|
||||
let est = total_tokens_impl(&[], &[]);
|
||||
assert_eq!(est.tokens, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn split_returns_zero_when_current_below_retained() {
|
||||
let history = vec![msg("a"), msg("b")];
|
||||
let records = vec![record(2, 50)];
|
||||
let cut = split_for_retained_impl(&history, &records, 1000);
|
||||
assert_eq!(cut.index, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn split_at_exact_measurement_boundary() {
|
||||
// 4 items。measurements: len=2 → 100, len=4 → 300。
|
||||
// retained=200 → target_drop = 100 → record[0] にぴったり一致 → index=2。
|
||||
let history = vec![msg("a"), msg("b"), msg("c"), msg("d")];
|
||||
let records = vec![record(2, 100), record(4, 300)];
|
||||
let cut = split_for_retained_impl(&history, &records, 200);
|
||||
assert_eq!(cut.index, 2);
|
||||
assert_eq!(cut.source, EstimateSource::Measured);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn split_interpolated_between_measurements() {
|
||||
let history = vec![
|
||||
msg("aaaaaa"),
|
||||
msg("bbbbbb"),
|
||||
msg("cccccc"),
|
||||
msg("dddddd"),
|
||||
];
|
||||
let records = vec![record(1, 50), record(4, 400)];
|
||||
let cut = split_for_retained_impl(&history, &records, 250);
|
||||
assert!(cut.index > 1 && cut.index <= 4);
|
||||
assert_eq!(cut.source, EstimateSource::Interpolated);
|
||||
}
|
||||
|
||||
/// With a retained budget of 0 the split index equals the history length (2).
#[test]
|
||||
fn split_all_when_retained_zero() {
|
||||
let history = vec![msg("a"), msg("b")];
|
||||
let records = vec![record(2, 100)];
|
||||
let cut = split_for_retained_impl(&history, &records, 0);
|
||||
assert_eq!(cut.index, 2);
|
||||
}
|
||||
|
||||
/// Both ends of the range 1..3 coincide with measurements (len=1 → 100, len=3 → 300),
/// so the savings are the difference, 200, reported as `Measured`.
#[test]
|
||||
fn savings_for_drop_uses_measurement_difference() {
|
||||
let history = vec![msg("a"), msg("b"), msg("c")];
|
||||
let records = vec![record(1, 100), record(3, 300)];
|
||||
let est = savings_for_drop_impl(&history, &records, 1..3);
|
||||
assert_eq!(est.tokens, 200); // 300 - 100
|
||||
assert_eq!(est.source, EstimateSource::Measured);
|
||||
}
|
||||
|
||||
/// An empty range (1..1) saves zero tokens.
#[test]
|
||||
fn savings_for_drop_empty_range_is_zero() {
|
||||
let history = vec![msg("a"), msg("b")];
|
||||
let records = vec![record(2, 100)];
|
||||
let est = savings_for_drop_impl(&history, &records, 1..1);
|
||||
assert_eq!(est.tokens, 0);
|
||||
}
|
||||
|
||||
/// A range strictly inside the only measured span is reported as `Interpolated`
/// and saves more than 0 but less than the full 400 measured tokens.
#[test]
|
||||
fn savings_for_drop_interpolates_inside_measurement_span() {
|
||||
// Only len=4 → 400 is measured. Range 1..3 is prorated between the origin (0) and upper=4.
|
||||
let history = vec![msg("aa"), msg("aa"), msg("aa"), msg("aa")];
|
||||
let records = vec![record(4, 400)];
|
||||
let est = savings_for_drop_impl(&history, &records, 1..3);
|
||||
assert_eq!(est.source, EstimateSource::Interpolated);
|
||||
assert!(est.tokens > 0 && est.tokens < 400);
|
||||
}
|
||||
|
||||
/// With no usage records at all, the estimate is tagged `NoData` yet is still non-zero
/// for a non-empty range.
#[test]
|
||||
fn savings_for_drop_no_records_falls_back_to_bytes() {
|
||||
let history = vec![msg("hello hello"), msg("world world")];
|
||||
let est = savings_for_drop_impl(&history, &[], 0..1); // empty records slice: nothing measured
|
||||
assert_eq!(est.source, EstimateSource::NoData);
|
||||
assert!(est.tokens > 0);
|
||||
}
|
||||
|
||||
/// A range whose end (5) lies beyond the history length (1) saves zero tokens.
#[test]
|
||||
fn savings_for_drop_out_of_range_is_zero() {
|
||||
let history = vec![msg("a")];
|
||||
let records = vec![record(1, 100)];
|
||||
let est = savings_for_drop_impl(&history, &records, 0..5);
|
||||
assert_eq!(est.tokens, 0);
|
||||
}
|
||||
}
|
||||
|
|
@ -141,6 +141,10 @@ if saved.tokens >= min_savings {
|
|||
呼び出しに置き換え(呼び出し側で渡す)
|
||||
- prune の API シグネチャ調整は最小限に
|
||||
|
||||
## レビュー状態
|
||||
|
||||
Reviewed — [token-counter.review.md](token-counter.review.md)
|
||||
|
||||
## 依存
|
||||
|
||||
- [usage-history.md](usage-history.md) — Usage を session-store に積む基盤
|
||||
|
|
|
|||
60
tickets/token-counter.review.md
Normal file
60
tickets/token-counter.review.md
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
# token-counter レビュー
|
||||
|
||||
## 要件の充足
|
||||
|
||||
チケットが定義した 3 API・型・アルゴリズムは全て実装されている:
|
||||
|
||||
- `Pod::total_tokens()` → `TokenEstimate`
|
||||
- `Pod::split_for_retained(retained)` → `SplitPoint`
|
||||
- `Pod::savings_for_drop(range)` → `TokenEstimate`
|
||||
- `EstimateSource`: `Measured / Interpolated / Extrapolated / NoData`
|
||||
|
||||
設計方針(状態を持たない pure 関数、provider 非依存、ローカルトークナイザ不要)も
|
||||
満たされている。`_impl` 関数群は `(&[Item], &[UsageRecord])` だけを受け取り、
|
||||
Pod メソッドは history と usage_history を渡すだけの薄いラッパー。
|
||||
|
||||
## アーキテクチャ
|
||||
|
||||
| レイヤー | 変更内容 |
|
||||
|---------|---------|
|
||||
| pod::token_counter | pure な計算関数 + Pod のメソッドとして公開 |
|
||||
| pod::Pod | `usage_history: Arc<Mutex<Vec<UsageRecord>>>` を追加。restore で復元、persist_turn で追記、compact で clear |
|
||||
| pod::PruneHook | min_savings 判定を `savings_for_drop_impl` に委譲。usage_history の shared handle を保持 |
|
||||
| llm-worker::prune | `prune()` → `prunable_indices()` + `apply_prune()` に分解。min_savings 判定とトークン会計への依存を除去 |
|
||||
|
||||
prune の責務分離が適切。llm-worker 側は pure な候補抽出と適用のみ、
|
||||
トークン会計への依存は pod 層に閉じている。
|
||||
|
||||
## 指摘と対処
|
||||
|
||||
### 1. split_for_retained_impl の O(n²) シリアライズ(非ブロッカー、未対処)
|
||||
|
||||
`tokens_at` を `1..=history.len()` で毎回呼び、内部で `prefix_bytes`(history 全体の
|
||||
JSON シリアライズ)を都度計算。長大セッションでは item 数に対して二乗になる。
|
||||
`prefix_bytes` をループ外で 1 回だけ計算して渡す形にリファクタリングすべきだが、
|
||||
現時点の history サイズでは実害なし。パフォーマンスが問題になった段階で対処。
|
||||
|
||||
### 2. PruneHook の savings 過大評価(認識済み、未対処)
|
||||
|
||||
prune は content を None にするだけで item を消さないため、`savings_for_drop`
|
||||
(範囲全体の drop を仮定)は実際の節約量より大きい値を返す。閾値判定としては
|
||||
prune を発動しやすい方向=安全側。ログの `estimated_savings_tokens` が過大になる
|
||||
点はチューニング時に注意。
|
||||
|
||||
### 3. compact 後の usage_history.clear()(後続チケットで対処)
|
||||
|
||||
compact 直後は measurement が空になり `total_tokens()` が `NoData` を返す。
|
||||
compact-improvements で `last_input_tokens` を撤去して閾値判定を usage 経由に
|
||||
一本化する際、この NoData 期間の扱いを設計する必要がある。
|
||||
|
||||
## テスト
|
||||
|
||||
token_counter: 13 件(NoData / Measured / Extrapolated / Interpolated 各ケース、
|
||||
split の境界、savings の measurement 差分、空 range、out-of-range)。
|
||||
|
||||
prune (llm-worker): `prunable_indices` + `apply_prune` に分解後のテスト 5 件。
|
||||
候補抽出、適用、冪等性、既 prune 済み除外、境界。
|
||||
|
||||
## 判定
|
||||
|
||||
承認。
|
||||
Loading…
Reference in New Issue
Block a user