merge: prompt occupancy estimator

This commit is contained in:
Keisuke Hirata 2026-06-01 10:09:30 +09:00
commit e51944f045
No known key found for this signature in database
2 changed files with 113 additions and 10 deletions

View File

@ -6,7 +6,8 @@
//! # 方針
//!
//! - ローカルトークナイザは持たない。実測値があればそれを採用し、
//! measurement 間はバイト数で按分、最新 measurement より先は最終 rate で外挿する
//! measurement 間はバイト数で按分、最新 measurement より先は測定済みの増分 rate
//! または byte/4 fallback で外挿する
//! - 推定の出どころは [`EstimateSource`] で呼び出し側に明示する。
//! 課金判断には使えないが、compact / prune / memory extract trigger 等の
//! 閾値判定には十分な精度
@ -119,17 +120,35 @@ pub fn tokens_at(
(Some(lo), None) => {
let lo_bytes = prefix[lo.history_len.min(cap)];
let at_bytes = prefix[index];
if lo_bytes == 0 || lo.input_total_tokens == 0 {
return TokenEstimate {
tokens: lo.input_total_tokens,
source: EstimateSource::Extrapolated,
};
}
let delta_bytes = at_bytes.saturating_sub(lo_bytes);
let delta_tokens =
(delta_bytes as u128 * lo.input_total_tokens as u128 / lo_bytes as u128) as u64;
let mut measured_span = None;
for pair in records.windows(2) {
let older = &pair[0];
let newer = &pair[1];
if newer.history_len > lo.history_len {
break;
}
let older_bytes = prefix[older.history_len.min(cap)];
let newer_bytes = prefix[newer.history_len.min(cap)];
let span_bytes = newer_bytes.saturating_sub(older_bytes);
let span_tokens = newer
.input_total_tokens
.saturating_sub(older.input_total_tokens);
if span_bytes > 0 && span_tokens > 0 {
measured_span = Some((span_tokens, span_bytes));
}
}
let delta_tokens = if let Some((span_tokens, span_bytes)) = measured_span {
(delta_bytes as u128 * span_tokens as u128 / span_bytes as u128) as u64
} else {
delta_bytes / 4
};
TokenEstimate {
tokens: lo.input_total_tokens + delta_tokens,
tokens: lo.input_total_tokens.saturating_add(delta_tokens),
source: EstimateSource::Extrapolated,
}
}
@ -214,6 +233,47 @@ mod tests {
assert!(est.tokens > 100);
}
#[test]
fn extrapolation_after_single_measurement_uses_byte_fallback_not_total_prompt_rate() {
let history = vec![msg("first"), msg(&"tool output ".repeat(400))];
let records = vec![record(1, 11_124)];
let prefix = prefix_bytes(&history);
let delta_bytes = prefix[2].saturating_sub(prefix[1]);
let est = total_tokens(&history, &records);
assert_eq!(est.source, EstimateSource::Extrapolated);
assert_eq!(est.tokens, 11_124 + delta_bytes / 4);
let old_projection =
11_124 + (delta_bytes as u128 * 11_124_u128 / prefix[1] as u128) as u64;
assert!(
old_projection > est.tokens.saturating_mul(10),
"old_projection={old_projection}, corrected={}",
est.tokens
);
}
#[test]
fn extrapolation_prefers_latest_measured_incremental_span_rate() {
let history = vec![
msg("first"),
msg(&"measured increment ".repeat(20)),
msg(&"unmeasured increment ".repeat(30)),
];
let records = vec![record(1, 10_000), record(2, 10_200)];
let prefix = prefix_bytes(&history);
let measured_bytes = prefix[2].saturating_sub(prefix[1]);
let delta_bytes = prefix[3].saturating_sub(prefix[2]);
let expected_delta = (delta_bytes as u128 * 200_u128 / measured_bytes as u128) as u64;
let est = total_tokens(&history, &records);
assert_eq!(est.source, EstimateSource::Extrapolated);
assert_eq!(est.tokens, 10_200 + expected_delta);
assert_ne!(est.tokens, 10_200 + delta_bytes / 4);
}
#[test]
fn total_zero_history_is_zero() {
let est = total_tokens(&[], &[]);

View File

@ -640,6 +640,49 @@ mod tests {
assert_eq!(count.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn pre_llm_request_does_not_yield_from_single_measurement_history_rate_projection() {
let count = Arc::new(AtomicUsize::new(0));
let registry = registry_with_pre_llm_hook(count.clone());
let ctx_items = vec![
Item::user_message("first"),
Item::user_message("tool output ".repeat(400)),
];
let record = UsageRecord {
history_len: 1,
input_total_tokens: 11_124,
cache_read_tokens: 0,
cache_write_tokens: 0,
output_tokens: 0,
};
let prefix = llm_worker::token_counter::prefix_bytes(&ctx_items);
let delta_bytes = prefix[2].saturating_sub(prefix[1]);
let old_projection =
11_124 + (delta_bytes as u128 * 11_124_u128 / prefix[1] as u128) as u64;
let corrected = total_tokens(&ctx_items, std::slice::from_ref(&record)).tokens;
let threshold = corrected + 100;
assert!(old_projection > threshold);
let state = Arc::new(CompactState::new(None, Some(threshold), 2));
let history = Arc::new(Mutex::new(vec![record]));
let interceptor = PodInterceptor::new(
registry,
Some(state),
Some(history),
NotifyBuffer::new(),
Arc::new(Mutex::new(Vec::new())),
TaskStore::new(),
Arc::new(TaskReminderState::new()),
PromptCatalog::builtins_only().unwrap(),
None,
);
let mut ctx = ctx_items;
let action = interceptor.pre_llm_request(&mut ctx).await;
assert!(matches!(action, PreRequestAction::Continue));
assert_eq!(count.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn pre_llm_request_does_not_yield_when_only_post_run_threshold_set() {
// request_threshold = None → safety-net check is inert inside the turn