diff --git a/crates/llm-worker/src/token_counter.rs b/crates/llm-worker/src/token_counter.rs index 5679154f..f7d4482e 100644 --- a/crates/llm-worker/src/token_counter.rs +++ b/crates/llm-worker/src/token_counter.rs @@ -6,7 +6,8 @@ //! # 方針 //! //! - ローカルトークナイザは持たない。実測値があればそれを採用し、 -//! measurement 間はバイト数で按分、最新 measurement より先は最終 rate で外挿する +//! measurement 間はバイト数で按分、最新 measurement より先は測定済みの増分 rate +//! または byte/4 fallback で外挿する //! - 推定の出どころは [`EstimateSource`] で呼び出し側に明示する。 //! 課金判断には使えないが、compact / prune / memory extract trigger 等の //! 閾値判定には十分な精度 @@ -119,17 +120,35 @@ pub fn tokens_at( (Some(lo), None) => { let lo_bytes = prefix[lo.history_len.min(cap)]; let at_bytes = prefix[index]; - if lo_bytes == 0 || lo.input_total_tokens == 0 { - return TokenEstimate { - tokens: lo.input_total_tokens, - source: EstimateSource::Extrapolated, - }; - } let delta_bytes = at_bytes.saturating_sub(lo_bytes); - let delta_tokens = - (delta_bytes as u128 * lo.input_total_tokens as u128 / lo_bytes as u128) as u64; + + let mut measured_span = None; + for pair in records.windows(2) { + let older = &pair[0]; + let newer = &pair[1]; + if newer.history_len > lo.history_len { + break; + } + + let older_bytes = prefix[older.history_len.min(cap)]; + let newer_bytes = prefix[newer.history_len.min(cap)]; + let span_bytes = newer_bytes.saturating_sub(older_bytes); + let span_tokens = newer + .input_total_tokens + .saturating_sub(older.input_total_tokens); + if span_bytes > 0 && span_tokens > 0 { + measured_span = Some((span_tokens, span_bytes)); + } + } + + let delta_tokens = if let Some((span_tokens, span_bytes)) = measured_span { + (delta_bytes as u128 * span_tokens as u128 / span_bytes as u128) as u64 + } else { + delta_bytes / 4 + }; + TokenEstimate { - tokens: lo.input_total_tokens + delta_tokens, + tokens: lo.input_total_tokens.saturating_add(delta_tokens), source: EstimateSource::Extrapolated, } } @@ -214,6 +233,47 @@ mod tests { assert!(est.tokens > 100); } + #[test] + fn extrapolation_after_single_measurement_uses_byte_fallback_not_total_prompt_rate() { + let history = vec![msg("first"), msg(&"tool output ".repeat(400))]; + let records = vec![record(1, 11_124)]; + let prefix = prefix_bytes(&history); + let delta_bytes = prefix[2].saturating_sub(prefix[1]); + + let est = total_tokens(&history, &records); + + assert_eq!(est.source, EstimateSource::Extrapolated); + assert_eq!(est.tokens, 11_124 + delta_bytes / 4); + + let old_projection = + 11_124 + (delta_bytes as u128 * 11_124_u128 / prefix[1] as u128) as u64; + assert!( + old_projection > est.tokens.saturating_mul(10), + "old_projection={old_projection}, corrected={}", + est.tokens + ); + } + + #[test] + fn extrapolation_prefers_latest_measured_incremental_span_rate() { + let history = vec![ + msg("first"), + msg(&"measured increment ".repeat(20)), + msg(&"unmeasured increment ".repeat(30)), + ]; + let records = vec![record(1, 10_000), record(2, 10_200)]; + let prefix = prefix_bytes(&history); + let measured_bytes = prefix[2].saturating_sub(prefix[1]); + let delta_bytes = prefix[3].saturating_sub(prefix[2]); + let expected_delta = (delta_bytes as u128 * 200_u128 / measured_bytes as u128) as u64; + + let est = total_tokens(&history, &records); + + assert_eq!(est.source, EstimateSource::Extrapolated); + assert_eq!(est.tokens, 10_200 + expected_delta); + assert_ne!(est.tokens, 10_200 + delta_bytes / 4); + } + #[test] fn total_zero_history_is_zero() { let est = total_tokens(&[], &[]); diff --git a/crates/pod/src/ipc/interceptor.rs b/crates/pod/src/ipc/interceptor.rs index 8f2fce0f..90e2d614 100644 --- a/crates/pod/src/ipc/interceptor.rs +++ b/crates/pod/src/ipc/interceptor.rs @@ -640,6 +640,49 @@ mod tests { assert_eq!(count.load(Ordering::Relaxed), 1); } + #[tokio::test] + async fn pre_llm_request_does_not_yield_from_single_measurement_history_rate_projection() { + let count = Arc::new(AtomicUsize::new(0)); + let registry = registry_with_pre_llm_hook(count.clone()); + let ctx_items = vec![ + Item::user_message("first"), + Item::user_message("tool output ".repeat(400)), + ]; + let record = UsageRecord { + history_len: 1, + input_total_tokens: 11_124, + cache_read_tokens: 0, + cache_write_tokens: 0, + output_tokens: 0, + }; + let prefix = llm_worker::token_counter::prefix_bytes(&ctx_items); + let delta_bytes = prefix[2].saturating_sub(prefix[1]); + let old_projection = + 11_124 + (delta_bytes as u128 * 11_124_u128 / prefix[1] as u128) as u64; + let corrected = total_tokens(&ctx_items, std::slice::from_ref(&record)).tokens; + let threshold = corrected + 100; + assert!(old_projection > threshold); + + let state = Arc::new(CompactState::new(None, Some(threshold), 2)); + let history = Arc::new(Mutex::new(vec![record])); + let interceptor = PodInterceptor::new( + registry, + Some(state), + Some(history), + NotifyBuffer::new(), + Arc::new(Mutex::new(Vec::new())), + TaskStore::new(), + Arc::new(TaskReminderState::new()), + PromptCatalog::builtins_only().unwrap(), + None, + ); + let mut ctx = ctx_items; + let action = interceptor.pre_llm_request(&mut ctx).await; + + assert!(matches!(action, PreRequestAction::Continue)); + assert_eq!(count.load(Ordering::Relaxed), 1); + } + #[tokio::test] async fn pre_llm_request_does_not_yield_when_only_post_run_threshold_set() { // request_threshold = None → safety-net check is inert inside the turn