yoi/crates/pod/src/token_counter.rs
Hare 967acd23ee compact: 閾値を個別指定化し占有量ソースを UsageRecord に一本化
- manifest に compact_request_threshold を追加 (proactive と safety net を個別指定)
- CompactState の両閾値を Option<u64> 化、last_input_tokens を撤去
- 閾値判定は Pod::total_tokens() / usage_history 経由の実測値ベースに切替
- turn_threshold → request_threshold にリネーム、Between-requests のログへ
2026-04-19 08:49:25 +09:00

473 lines
17 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! Usage 履歴ベースのトークン会計。
//!
//! `UsageRecord` の列(プロバイダ実測値)と現在の history から、
//! 「末尾 N トークン残すための split 位置」「prune 射影で節約される
//! トークン数」などを pure に計算する。
//!
//! # 方針
//!
//! - ローカルトークナイザは持たない。実測値があればそれを採用し、
//! measurement 間はバイト数で按分、最新 measurement より先は最終 rate で外挿する
//! - 推定の出どころは [`EstimateSource`] で呼び出し側に明示する。
//! 課金判断には使えないが、compact/prune の閾値判定には十分な精度
//! - `records` は `history_len` 昇順を仮定する(`collect_state` と
//! `UsageTracker` がそのように積む)
//!
//! 公開 API は本ファイル内の `impl Pod` で [`Pod`](crate::Pod) のメソッドとして
//! 生やしている。pure な補助関数はこのモジュール内に閉じ、crate 内の他箇所から
//! 必要なもの(`total_tokens_impl` / `savings_for_prune_impl`)のみ `pub(crate)` で公開する。
use llm_worker::Item;
use llm_worker::llm_client::client::LlmClient;
use session_store::{Store, UsageRecord};
use crate::Pod;
/// Where a token estimate came from.
///
/// Reported alongside every estimate so callers can judge how much to trust it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EstimateSource {
    /// The index coincides exactly with a measurement boundary (the value is the measured figure itself).
    Measured,
    /// Apportioned by byte size between measurements (or, when there is no
    /// earlier measurement, between the start of history and the first one).
    Interpolated,
    /// Beyond the newest measurement; extrapolated with that measurement's tokens/byte rate.
    Extrapolated,
    /// No measurements exist at all; fallback based on byte count alone.
    NoData,
}
/// An estimated token count together with how it was derived.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenEstimate {
    /// Estimated number of tokens.
    pub tokens: u64,
    /// How `tokens` was obtained (measured, interpolated, extrapolated, or byte fallback).
    pub source: EstimateSource,
}
/// A position at which to split the history.
///
/// `items[..index]` is the side to discard/summarize; `items[index..]` is the side that is kept.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SplitPoint {
    /// Split index into the history (`0` means nothing is cut).
    pub index: usize,
    /// Source of the estimate the split decision was based on.
    pub source: EstimateSource,
}
/// Cumulative serialized sizes: entry `i` is the total byte size of `items[..i]`.
///
/// The result always has `items.len() + 1` entries and starts with `0`.
/// Accumulation saturates at `u64::MAX` instead of overflowing.
fn prefix_bytes(items: &[Item]) -> Vec<u64> {
    let running_totals = items.iter().scan(0u64, |total, item| {
        *total = total.saturating_add(item_bytes(item));
        Some(*total)
    });
    let mut prefix = Vec::with_capacity(items.len() + 1);
    prefix.push(0);
    prefix.extend(running_totals);
    prefix
}
/// Rough size of one `Item`: the length of its JSON serialization.
///
/// This is never converted to tokens directly. It is only used to apportion
/// measured token counts across ranges, so per-provider overhead cancels out
/// in the ratios. Returns `0` if serialization fails.
fn item_bytes(item: &Item) -> u64 {
    match serde_json::to_string(item) {
        Ok(json) => json.len() as u64,
        Err(_) => 0,
    }
}
/// Estimates the token count of `history[..index]`.
///
/// `prefix` is the cumulative byte array from [`prefix_bytes`], of length
/// `history.len() + 1`. The caller computes it once and reuses it, so a linear
/// search or repeated estimates cost O(n) serialization in total. Recomputing
/// it on every call would be O(n²).
///
/// Resolution order:
/// 1. `index == 0` → 0 tokens, [`EstimateSource::Measured`].
/// 2. No records → `prefix[index] / 4` (a rough bytes/4 heuristic), [`EstimateSource::NoData`].
/// 3. A record with `history_len == index` → its `input_total_tokens`; the
///    newest such record wins. Source is [`EstimateSource::Measured`].
/// 4. Otherwise, apportion by bytes relative to the nearest record below
///    (`lo`) and/or above (`up`) `index`:
///    - both exist: linear interpolation between them ([`EstimateSource::Interpolated`]);
///    - only `lo`: extrapolate with `lo`'s tokens/byte rate ([`EstimateSource::Extrapolated`]);
///    - only `up`: interpolate from (0 bytes, 0 tokens) to `up` ([`EstimateSource::Interpolated`]).
///
/// Record `history_len` values beyond `history.len()` are clamped when
/// indexing `prefix`. Assumes `records` is sorted by `history_len` ascending
/// (see module docs): `lo` is found scanning from the back, `up` from the front.
fn tokens_at(
    history: &[Item],
    records: &[UsageRecord],
    index: usize,
    prefix: &[u64],
) -> TokenEstimate {
    debug_assert!(index <= history.len());
    debug_assert_eq!(prefix.len(), history.len() + 1);
    if index == 0 {
        return TokenEstimate {
            tokens: 0,
            source: EstimateSource::Measured,
        };
    }
    if records.is_empty() {
        return TokenEstimate {
            tokens: prefix[index] / 4,
            source: EstimateSource::NoData,
        };
    }
    // Exact match (scan in reverse so the newest record wins).
    if let Some(r) = records.iter().rev().find(|r| r.history_len == index) {
        return TokenEstimate {
            tokens: r.input_total_tokens,
            source: EstimateSource::Measured,
        };
    }
    let lower = records.iter().rev().find(|r| r.history_len < index);
    let upper = records.iter().find(|r| r.history_len > index);
    // Records may reference a longer history than we currently have; clamp.
    let cap = history.len();
    match (lower, upper) {
        (Some(lo), Some(up)) => {
            let lo_bytes = prefix[lo.history_len.min(cap)];
            let up_bytes = prefix[up.history_len.min(cap)];
            let at_bytes = prefix[index];
            let span_bytes = up_bytes.saturating_sub(lo_bytes);
            let span_tokens = up.input_total_tokens.saturating_sub(lo.input_total_tokens);
            // Degenerate span (no byte growth, or token count did not grow):
            // there is no slope to interpolate, so report the lower measurement.
            if span_bytes == 0 || span_tokens == 0 {
                return TokenEstimate {
                    tokens: lo.input_total_tokens,
                    source: EstimateSource::Interpolated,
                };
            }
            let delta_bytes = at_bytes.saturating_sub(lo_bytes);
            // u128 intermediate so the multiplication cannot overflow.
            let delta_tokens =
                (delta_bytes as u128 * span_tokens as u128 / span_bytes as u128) as u64;
            TokenEstimate {
                tokens: lo.input_total_tokens + delta_tokens,
                source: EstimateSource::Interpolated,
            }
        }
        (Some(lo), None) => {
            let lo_bytes = prefix[lo.history_len.min(cap)];
            let at_bytes = prefix[index];
            // No usable tokens/byte rate: report the last measurement unchanged.
            if lo_bytes == 0 || lo.input_total_tokens == 0 {
                return TokenEstimate {
                    tokens: lo.input_total_tokens,
                    source: EstimateSource::Extrapolated,
                };
            }
            let delta_bytes = at_bytes.saturating_sub(lo_bytes);
            // Extend past the last measurement at that measurement's tokens/byte rate.
            let delta_tokens =
                (delta_bytes as u128 * lo.input_total_tokens as u128 / lo_bytes as u128) as u64;
            TokenEstimate {
                tokens: lo.input_total_tokens + delta_tokens,
                source: EstimateSource::Extrapolated,
            }
        }
        (None, Some(up)) => {
            let up_bytes = prefix[up.history_len.min(cap)];
            let at_bytes = prefix[index];
            if up_bytes == 0 {
                return TokenEstimate {
                    tokens: 0,
                    source: EstimateSource::Interpolated,
                };
            }
            // Interpolate between the origin (0 bytes, 0 tokens) and `up`.
            let t = (at_bytes as u128 * up.input_total_tokens as u128 / up_bytes as u128) as u64;
            TokenEstimate {
                tokens: t,
                source: EstimateSource::Interpolated,
            }
        }
        // Every record is ==, < or > `index`, and the == case already returned above.
        (None, None) => unreachable!("records non-empty but neither lower nor upper matched"),
    }
}
/// Estimated token count of the entire `history`.
///
/// Equivalent to estimating the prefix that ends at `history.len()`.
pub(crate) fn total_tokens_impl(history: &[Item], records: &[UsageRecord]) -> TokenEstimate {
    let end = history.len();
    tokens_at(history, records, end, &prefix_bytes(history))
}
/// Finds where to cut `history` so the estimated tail `history[index..]` is
/// at most `retained` tokens.
///
/// If the whole history is already within `retained`, returns index `0`
/// (keep everything) with the total's source. Otherwise returns the smallest
/// `index >= 1` whose prefix estimate reaches `total - retained`. If estimates
/// grow with the index, that is the cut keeping the largest tail that still
/// fits. The returned `source` is that of the estimate the decision was
/// based on.
fn split_for_retained_impl(history: &[Item], records: &[UsageRecord], retained: u64) -> SplitPoint {
    let prefix = prefix_bytes(history);
    let total = tokens_at(history, records, history.len(), &prefix);
    if total.tokens <= retained {
        return SplitPoint {
            index: 0,
            source: total.source,
        };
    }
    let must_drop = total.tokens - retained;
    // Linear search that reuses `prefix`, so one call costs O(n)
    // serialization; recomputing it per probe would be O(n²). If this becomes
    // a bottleneck, switch to a binary search over record boundaries.
    (1..=history.len())
        .find_map(|index| {
            let est = tokens_at(history, records, index, &prefix);
            (est.tokens >= must_drop).then_some(SplitPoint {
                index,
                source: est.source,
            })
        })
        .unwrap_or(SplitPoint {
            index: history.len(),
            source: total.source,
        })
}
/// Serialized bytes saved by projecting one `ToolResult` item's `content` to `None`.
///
/// Returns `0` for items that are not `ToolResult`, and for `ToolResult`s
/// whose `content` is already `None`.
fn tool_result_content_bytes(item: &Item) -> u64 {
    let Item::ToolResult {
        content: Some(_), ..
    } = item
    else {
        return 0;
    };
    let mut stripped = item.clone();
    if let Item::ToolResult { content, .. } = &mut stripped {
        *content = None;
    }
    let before = item_bytes(item);
    let after = item_bytes(&stripped);
    before.saturating_sub(after)
}
/// Prune 射影(`ToolResult.content = None`)で節約されるトークン数の推定。
///
/// `indices` は [`llm_worker::prune::prunable_indices`] が返す候補列を
/// 想定する。各候補の content バイト差分を合算し、usage 履歴由来の
/// tokens/byte レートでトークン数に換算する。範囲を「丸ごと drop」する
/// のではなく、item 自体summary 等)は残したままの値を返す点が
/// `tokens_at` ベースの計算と異なる。
pub(crate) fn savings_for_prune_impl(
history: &[Item],
records: &[UsageRecord],
indices: &[usize],
) -> TokenEstimate {
let removed_bytes: u64 = indices
.iter()
.filter_map(|&i| history.get(i))
.map(tool_result_content_bytes)
.sum();
if removed_bytes == 0 {
return TokenEstimate {
tokens: 0,
source: EstimateSource::Measured,
};
}
if records.is_empty() {
return TokenEstimate {
tokens: removed_bytes / 4,
source: EstimateSource::NoData,
};
}
// 最新の measurement を使って tokens/byte を求め、バイト差分を換算する。
// 実測値そのものではなく比率しか使わないので、history_len と
// record.history_len が一致しなくても rate は正しい。
let prefix = prefix_bytes(history);
let last = records.last().expect("records non-empty");
let ref_bytes = prefix[last.history_len.min(history.len())];
if ref_bytes == 0 || last.input_total_tokens == 0 {
return TokenEstimate {
tokens: 0,
source: EstimateSource::Extrapolated,
};
}
let tokens =
(removed_bytes as u128 * last.input_total_tokens as u128 / ref_bytes as u128) as u64;
let source = if last.history_len == history.len() {
EstimateSource::Measured
} else {
EstimateSource::Extrapolated
};
TokenEstimate { tokens, source }
}
// ── Pod に生やす公開 API ───────────────────────────────────────────────
impl<C: LlmClient, St: Store> Pod<C, St> {
    /// Estimated token count of the entire current history.
    ///
    /// Derived from the recorded usage measurements. Where the history does
    /// not line up with a measurement (e.g. items were appended after the last
    /// one), the estimate is apportioned or extrapolated by byte size.
    /// [`TokenEstimate::source`] says which.
    pub fn total_tokens(&self) -> TokenEstimate {
        let usage = self.usage_history();
        total_tokens_impl(self.history(), &usage)
    }

    /// Split position that keeps an estimated tail of **at most** `retained` tokens.
    ///
    /// `history[..cut.index]` is the side to summarize/discard, and
    /// `history[cut.index..]` is kept. `cut.index` is the smallest index whose
    /// prefix estimate reaches `total - retained`. The kept tail is therefore
    /// at most `retained` tokens, not "at least". If the whole history already
    /// fits in `retained`, `cut.index` is `0` and nothing is cut.
    pub fn split_for_retained(&self, retained: u64) -> SplitPoint {
        let usage = self.usage_history();
        split_for_retained_impl(self.history(), &usage, retained)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn msg(text: &str) -> Item {
        Item::user_message(text)
    }

    fn record(history_len: usize, tokens: u64) -> UsageRecord {
        UsageRecord {
            history_len,
            input_total_tokens: tokens,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
            output_tokens: 0,
        }
    }

    #[test]
    fn total_no_data_falls_back_to_byte_estimate() {
        let history = vec![msg("hello world")];
        let est = total_tokens_impl(&history, &[]);
        assert_eq!(est.source, EstimateSource::NoData);
        assert!(est.tokens > 0);
    }

    #[test]
    fn total_measured_when_last_record_matches_history_len() {
        let history = vec![msg("a"), msg("b"), msg("c")];
        let records = vec![record(3, 120)];
        let est = total_tokens_impl(&history, &records);
        assert_eq!(est.source, EstimateSource::Measured);
        assert_eq!(est.tokens, 120);
    }

    #[test]
    fn total_extrapolated_when_history_grew_past_last_measurement() {
        let history = vec![msg("a"), msg("b"), msg("c"), msg("d")];
        let records = vec![record(3, 100)];
        let est = total_tokens_impl(&history, &records);
        assert_eq!(est.source, EstimateSource::Extrapolated);
        assert!(est.tokens > 100);
    }

    #[test]
    fn total_zero_history_is_zero() {
        let est = total_tokens_impl(&[], &[]);
        assert_eq!(est.tokens, 0);
    }

    #[test]
    fn split_returns_zero_when_current_below_retained() {
        let history = vec![msg("a"), msg("b")];
        let records = vec![record(2, 50)];
        let cut = split_for_retained_impl(&history, &records, 1000);
        assert_eq!(cut.index, 0);
    }

    #[test]
    fn split_at_exact_measurement_boundary() {
        // 4 items; measurements: len=2 -> 100, len=4 -> 300.
        // retained=200 -> target_drop = 100 -> exactly record[0] -> index=2.
        let history = vec![msg("a"), msg("b"), msg("c"), msg("d")];
        let records = vec![record(2, 100), record(4, 300)];
        let cut = split_for_retained_impl(&history, &records, 200);
        assert_eq!(cut.index, 2);
        assert_eq!(cut.source, EstimateSource::Measured);
    }

    #[test]
    fn split_interpolated_between_measurements() {
        let history = vec![msg("aaaaaa"), msg("bbbbbb"), msg("cccccc"), msg("dddddd")];
        let records = vec![record(1, 50), record(4, 400)];
        let cut = split_for_retained_impl(&history, &records, 250);
        assert!(cut.index > 1 && cut.index <= 4);
        assert_eq!(cut.source, EstimateSource::Interpolated);
    }

    #[test]
    fn split_keeps_at_most_retained_and_is_minimal() {
        // Pins the split contract: the kept tail is at most `retained`
        // tokens, and cutting one item earlier would keep more than that.
        let history = vec![msg("aaaaaa"), msg("bbbbbb"), msg("cccccc"), msg("dddddd")];
        let records = vec![record(1, 50), record(4, 400)];
        let prefix = prefix_bytes(&history);
        let total = tokens_at(&history, &records, history.len(), &prefix).tokens;
        let retained: u64 = 250;
        let cut = split_for_retained_impl(&history, &records, retained);
        assert!(cut.index >= 1);
        let dropped = tokens_at(&history, &records, cut.index, &prefix).tokens;
        assert!(total - dropped <= retained);
        let dropped_before = tokens_at(&history, &records, cut.index - 1, &prefix).tokens;
        assert!(total - dropped_before > retained);
    }

    #[test]
    fn split_all_when_retained_zero() {
        let history = vec![msg("a"), msg("b")];
        let records = vec![record(2, 100)];
        let cut = split_for_retained_impl(&history, &records, 0);
        assert_eq!(cut.index, 2);
    }

    fn tool_result_with(summary: &str, content: Option<&str>) -> Item {
        match content {
            Some(c) => Item::tool_result_with_content("call", summary, c),
            None => Item::tool_result("call", summary),
        }
    }

    #[test]
    fn savings_for_prune_skips_non_toolresult_indices() {
        let history = vec![msg("a"), msg("b"), msg("c")];
        // indices point at plain messages, not ToolResult -> 0 savings.
        let est = savings_for_prune_impl(&history, &[record(3, 300)], &[0, 1, 2]);
        assert_eq!(est.tokens, 0);
    }

    #[test]
    fn savings_for_prune_skips_content_none_items() {
        let history = vec![
            msg("user"),
            tool_result_with("s1", None),
            tool_result_with("s2", None),
        ];
        let est = savings_for_prune_impl(&history, &[record(3, 300)], &[1, 2]);
        assert_eq!(est.tokens, 0);
    }

    #[test]
    fn savings_for_prune_counts_only_content_delta() {
        // 1 item with big content vs the same structure without content.
        let big = "x".repeat(400);
        let history = vec![
            msg("user"),
            tool_result_with("summary", Some(&big)),
            msg("tail"),
        ];
        // 1 record at end so rate = tokens / total_bytes
        let total_bytes: u64 = history.iter().map(item_bytes).sum();
        let records = vec![record(history.len(), total_bytes)]; // rate = 1 tok/byte
        let est = savings_for_prune_impl(&history, &records, &[1]);
        // saved bytes ~= size of the big content payload; with rate=1 it
        // should be close to 400 and far from the full item bytes.
        let full_item_bytes = item_bytes(&history[1]);
        assert!(est.tokens > 0);
        assert!(est.tokens < full_item_bytes);
        assert!(est.tokens >= 400);
        assert_eq!(est.source, EstimateSource::Measured);
    }

    #[test]
    fn savings_for_prune_no_records_falls_back_to_bytes() {
        let history = vec![msg("u"), tool_result_with("s", Some("hello world"))];
        let est = savings_for_prune_impl(&history, &[], &[1]);
        assert_eq!(est.source, EstimateSource::NoData);
        assert!(est.tokens > 0);
    }

    #[test]
    fn savings_for_prune_extrapolated_when_history_grew_past_measurement() {
        let big = "x".repeat(200);
        let history = vec![
            msg("u1"),
            tool_result_with("s", Some(&big)),
            msg("u2"), // added after the last measurement
        ];
        let records = vec![record(2, 100)];
        let est = savings_for_prune_impl(&history, &records, &[1]);
        assert_eq!(est.source, EstimateSource::Extrapolated);
        assert!(est.tokens > 0);
    }

    #[test]
    fn savings_for_prune_empty_indices_is_zero() {
        let history = vec![msg("a")];
        let est = savings_for_prune_impl(&history, &[record(1, 100)], &[]);
        assert_eq!(est.tokens, 0);
    }

    #[test]
    fn savings_for_prune_ignores_out_of_range_indices() {
        let history = vec![msg("a")];
        let est = savings_for_prune_impl(&history, &[record(1, 100)], &[99]);
        assert_eq!(est.tokens, 0);
    }
}