fix: correct prompt occupancy extrapolation
This commit is contained in:
parent
3ea005822e
commit
375d0216d1
|
|
@ -6,7 +6,8 @@
|
||||||
//! # 方針
|
//! # 方針
|
||||||
//!
|
//!
|
||||||
//! - ローカルトークナイザは持たない。実測値があればそれを採用し、
|
//! - ローカルトークナイザは持たない。実測値があればそれを採用し、
|
||||||
//! measurement 間はバイト数で按分、最新 measurement より先は最終 rate で外挿する
|
//! measurement 間はバイト数で按分、最新 measurement より先は測定済みの増分 rate
|
||||||
|
//! または byte/4 fallback で外挿する
|
||||||
//! - 推定の出どころは [`EstimateSource`] で呼び出し側に明示する。
|
//! - 推定の出どころは [`EstimateSource`] で呼び出し側に明示する。
|
||||||
//! 課金判断には使えないが、compact / prune / memory extract trigger 等の
|
//! 課金判断には使えないが、compact / prune / memory extract trigger 等の
|
||||||
//! 閾値判定には十分な精度
|
//! 閾値判定には十分な精度
|
||||||
|
|
@ -119,17 +120,35 @@ pub fn tokens_at(
|
||||||
(Some(lo), None) => {
|
(Some(lo), None) => {
|
||||||
let lo_bytes = prefix[lo.history_len.min(cap)];
|
let lo_bytes = prefix[lo.history_len.min(cap)];
|
||||||
let at_bytes = prefix[index];
|
let at_bytes = prefix[index];
|
||||||
if lo_bytes == 0 || lo.input_total_tokens == 0 {
|
|
||||||
return TokenEstimate {
|
|
||||||
tokens: lo.input_total_tokens,
|
|
||||||
source: EstimateSource::Extrapolated,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
let delta_bytes = at_bytes.saturating_sub(lo_bytes);
|
let delta_bytes = at_bytes.saturating_sub(lo_bytes);
|
||||||
let delta_tokens =
|
|
||||||
(delta_bytes as u128 * lo.input_total_tokens as u128 / lo_bytes as u128) as u64;
|
let mut measured_span = None;
|
||||||
|
for pair in records.windows(2) {
|
||||||
|
let older = &pair[0];
|
||||||
|
let newer = &pair[1];
|
||||||
|
if newer.history_len > lo.history_len {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let older_bytes = prefix[older.history_len.min(cap)];
|
||||||
|
let newer_bytes = prefix[newer.history_len.min(cap)];
|
||||||
|
let span_bytes = newer_bytes.saturating_sub(older_bytes);
|
||||||
|
let span_tokens = newer
|
||||||
|
.input_total_tokens
|
||||||
|
.saturating_sub(older.input_total_tokens);
|
||||||
|
if span_bytes > 0 && span_tokens > 0 {
|
||||||
|
measured_span = Some((span_tokens, span_bytes));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let delta_tokens = if let Some((span_tokens, span_bytes)) = measured_span {
|
||||||
|
(delta_bytes as u128 * span_tokens as u128 / span_bytes as u128) as u64
|
||||||
|
} else {
|
||||||
|
delta_bytes / 4
|
||||||
|
};
|
||||||
|
|
||||||
TokenEstimate {
|
TokenEstimate {
|
||||||
tokens: lo.input_total_tokens + delta_tokens,
|
tokens: lo.input_total_tokens.saturating_add(delta_tokens),
|
||||||
source: EstimateSource::Extrapolated,
|
source: EstimateSource::Extrapolated,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -214,6 +233,47 @@ mod tests {
|
||||||
assert!(est.tokens > 100);
|
assert!(est.tokens > 100);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Regression test: extrapolating past a single usage measurement.
//
// The history holds a short first message followed by one large message; only one usage
// record exists. `msg`, `record`, `prefix_bytes` and `total_tokens` are defined outside
// this view. NOTE(review): `record(1, 11_124)` presumably means "11_124 input tokens
// measured at history_len 1" - confirm against the helper.
//
// Asserts that the estimate is tagged `EstimateSource::Extrapolated` and equals
// 11_124 + delta_bytes / 4, where delta_bytes = prefix[2] - prefix[1] (saturating).
//
// `old_projection` re-derives the alternative projection this test guards against
// (delta_bytes scaled by 11_124 tokens per prefix[1] bytes) and the final assertion
// checks that it is more than ten times the corrected estimate; the assertion message
// prints both values on failure.
#[test]
|
||||||
|
fn extrapolation_after_single_measurement_uses_byte_fallback_not_total_prompt_rate() {
|
||||||
|
let history = vec![msg("first"), msg(&"tool output ".repeat(400))];
|
||||||
|
let records = vec![record(1, 11_124)];
|
||||||
|
let prefix = prefix_bytes(&history);
|
||||||
|
let delta_bytes = prefix[2].saturating_sub(prefix[1]);
|
||||||
|

||||||
|
let est = total_tokens(&history, &records);
|
||||||
|

||||||
|
assert_eq!(est.source, EstimateSource::Extrapolated);
|
||||||
|
assert_eq!(est.tokens, 11_124 + delta_bytes / 4);
|
||||||
|

||||||
|
let old_projection =
|
||||||
|
11_124 + (delta_bytes as u128 * 11_124_u128 / prefix[1] as u128) as u64;
|
||||||
|
assert!(
|
||||||
|
old_projection > est.tokens.saturating_mul(10),
|
||||||
|
"old_projection={old_projection}, corrected={}",
|
||||||
|
est.tokens
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks extrapolation when two usage records exist.
//
// The history has three messages and two records. NOTE(review): `record(1, 10_000)` and
// `record(2, 10_200)` presumably mean measurements at history_len 1 and 2 - the helper is
// defined outside this view, confirm there.
//
// measured_bytes = prefix[2] - prefix[1] is the byte span between the two records, and
// the 200-token difference (10_200 - 10_000) is the token growth over that span.
// delta_bytes = prefix[3] - prefix[2] is the span past the last record.
//
// Asserts that the estimate is tagged `EstimateSource::Extrapolated`, that it equals
// 10_200 + delta_bytes * 200 / measured_bytes (computed in u128, truncated to u64), and
// that it differs from the 10_200 + delta_bytes / 4 value.
#[test]
|
||||||
|
fn extrapolation_prefers_latest_measured_incremental_span_rate() {
|
||||||
|
let history = vec![
|
||||||
|
msg("first"),
|
||||||
|
msg(&"measured increment ".repeat(20)),
|
||||||
|
msg(&"unmeasured increment ".repeat(30)),
|
||||||
|
];
|
||||||
|
let records = vec![record(1, 10_000), record(2, 10_200)];
|
||||||
|
let prefix = prefix_bytes(&history);
|
||||||
|
let measured_bytes = prefix[2].saturating_sub(prefix[1]);
|
||||||
|
let delta_bytes = prefix[3].saturating_sub(prefix[2]);
|
||||||
|
let expected_delta = (delta_bytes as u128 * 200_u128 / measured_bytes as u128) as u64;
|
||||||
|

||||||
|
let est = total_tokens(&history, &records);
|
||||||
|

||||||
|
assert_eq!(est.source, EstimateSource::Extrapolated);
|
||||||
|
assert_eq!(est.tokens, 10_200 + expected_delta);
|
||||||
|
assert_ne!(est.tokens, 10_200 + delta_bytes / 4);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn total_zero_history_is_zero() {
|
fn total_zero_history_is_zero() {
|
||||||
let est = total_tokens(&[], &[]);
|
let est = total_tokens(&[], &[]);
|
||||||
|
|
|
||||||
|
|
@ -640,6 +640,49 @@ mod tests {
|
||||||
assert_eq!(count.load(Ordering::Relaxed), 1);
|
assert_eq!(count.load(Ordering::Relaxed), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Interceptor-level check built on the same scenario as the single-measurement
// extrapolation case: one short message, one large message, and a single `UsageRecord`
// at history_len 1 with 11_124 input tokens.
//
// `old_projection` scales delta_bytes by 11_124 tokens per prefix[1] bytes; `corrected`
// is whatever `total_tokens` returns for the same inputs. The threshold is set to
// corrected + 100, and the test first asserts that `old_projection` is above it, so the
// two estimates fall on opposite sides of the threshold.
//
// NOTE(review): the second argument of `CompactState::new` is presumably the
// in-turn request threshold, and `count` is presumably incremented by the hook that
// `registry_with_pre_llm_hook` registers - both are defined outside this view, confirm.
//
// Asserts that `pre_llm_request` returns `PreRequestAction::Continue` and that the
// counter reads 1 afterwards.
#[tokio::test]
|
||||||
|
async fn pre_llm_request_does_not_yield_from_single_measurement_history_rate_projection() {
|
||||||
|
let count = Arc::new(AtomicUsize::new(0));
|
||||||
|
let registry = registry_with_pre_llm_hook(count.clone());
|
||||||
|
let ctx_items = vec![
|
||||||
|
Item::user_message("first"),
|
||||||
|
Item::user_message("tool output ".repeat(400)),
|
||||||
|
];
|
||||||
|
let record = UsageRecord {
|
||||||
|
history_len: 1,
|
||||||
|
input_total_tokens: 11_124,
|
||||||
|
cache_read_tokens: 0,
|
||||||
|
cache_write_tokens: 0,
|
||||||
|
output_tokens: 0,
|
||||||
|
};
|
||||||
|
let prefix = llm_worker::token_counter::prefix_bytes(&ctx_items);
|
||||||
|
let delta_bytes = prefix[2].saturating_sub(prefix[1]);
|
||||||
|
let old_projection =
|
||||||
|
11_124 + (delta_bytes as u128 * 11_124_u128 / prefix[1] as u128) as u64;
|
||||||
|
let corrected = total_tokens(&ctx_items, std::slice::from_ref(&record)).tokens;
|
||||||
|
let threshold = corrected + 100;
|
||||||
|
assert!(old_projection > threshold);
|
||||||
|

||||||
|
let state = Arc::new(CompactState::new(None, Some(threshold), 2));
|
||||||
|
let history = Arc::new(Mutex::new(vec![record]));
|
||||||
|
let interceptor = PodInterceptor::new(
|
||||||
|
registry,
|
||||||
|
Some(state),
|
||||||
|
Some(history),
|
||||||
|
NotifyBuffer::new(),
|
||||||
|
Arc::new(Mutex::new(Vec::new())),
|
||||||
|
TaskStore::new(),
|
||||||
|
Arc::new(TaskReminderState::new()),
|
||||||
|
PromptCatalog::builtins_only().unwrap(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
let mut ctx = ctx_items;
|
||||||
|
let action = interceptor.pre_llm_request(&mut ctx).await;
|
||||||
|

||||||
|
assert!(matches!(action, PreRequestAction::Continue));
|
||||||
|
assert_eq!(count.load(Ordering::Relaxed), 1);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn pre_llm_request_does_not_yield_when_only_post_run_threshold_set() {
|
async fn pre_llm_request_does_not_yield_when_only_post_run_threshold_set() {
|
||||||
// request_threshold = None → safety-net check is inert inside the turn
|
// request_threshold = None → safety-net check is inert inside the turn
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user