fix: correct prompt occupancy extrapolation
This commit is contained in:
parent
3ea005822e
commit
375d0216d1
|
|
@ -6,7 +6,8 @@
|
|||
//! # 方針
|
||||
//!
|
||||
//! - ローカルトークナイザは持たない。実測値があればそれを採用し、
|
||||
//! measurement 間はバイト数で按分、最新 measurement より先は最終 rate で外挿する
|
||||
//! measurement 間はバイト数で按分、最新 measurement より先は測定済みの増分 rate
|
||||
//! または byte/4 fallback で外挿する
|
||||
//! - 推定の出どころは [`EstimateSource`] で呼び出し側に明示する。
|
||||
//! 課金判断には使えないが、compact / prune / memory extract trigger 等の
|
||||
//! 閾値判定には十分な精度
|
||||
|
|
@ -119,17 +120,35 @@ pub fn tokens_at(
|
|||
(Some(lo), None) => {
|
||||
let lo_bytes = prefix[lo.history_len.min(cap)];
|
||||
let at_bytes = prefix[index];
|
||||
if lo_bytes == 0 || lo.input_total_tokens == 0 {
|
||||
return TokenEstimate {
|
||||
tokens: lo.input_total_tokens,
|
||||
source: EstimateSource::Extrapolated,
|
||||
};
|
||||
}
|
||||
let delta_bytes = at_bytes.saturating_sub(lo_bytes);
|
||||
let delta_tokens =
|
||||
(delta_bytes as u128 * lo.input_total_tokens as u128 / lo_bytes as u128) as u64;
|
||||
|
||||
let mut measured_span = None;
|
||||
for pair in records.windows(2) {
|
||||
let older = &pair[0];
|
||||
let newer = &pair[1];
|
||||
if newer.history_len > lo.history_len {
|
||||
break;
|
||||
}
|
||||
|
||||
let older_bytes = prefix[older.history_len.min(cap)];
|
||||
let newer_bytes = prefix[newer.history_len.min(cap)];
|
||||
let span_bytes = newer_bytes.saturating_sub(older_bytes);
|
||||
let span_tokens = newer
|
||||
.input_total_tokens
|
||||
.saturating_sub(older.input_total_tokens);
|
||||
if span_bytes > 0 && span_tokens > 0 {
|
||||
measured_span = Some((span_tokens, span_bytes));
|
||||
}
|
||||
}
|
||||
|
||||
let delta_tokens = if let Some((span_tokens, span_bytes)) = measured_span {
|
||||
(delta_bytes as u128 * span_tokens as u128 / span_bytes as u128) as u64
|
||||
} else {
|
||||
delta_bytes / 4
|
||||
};
|
||||
|
||||
TokenEstimate {
|
||||
tokens: lo.input_total_tokens + delta_tokens,
|
||||
tokens: lo.input_total_tokens.saturating_add(delta_tokens),
|
||||
source: EstimateSource::Extrapolated,
|
||||
}
|
||||
}
|
||||
|
|
@ -214,6 +233,47 @@ mod tests {
|
|||
assert!(est.tokens > 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extrapolation_after_single_measurement_uses_byte_fallback_not_total_prompt_rate() {
|
||||
let history = vec![msg("first"), msg(&"tool output ".repeat(400))];
|
||||
let records = vec![record(1, 11_124)];
|
||||
let prefix = prefix_bytes(&history);
|
||||
let delta_bytes = prefix[2].saturating_sub(prefix[1]);
|
||||
|
||||
let est = total_tokens(&history, &records);
|
||||
|
||||
assert_eq!(est.source, EstimateSource::Extrapolated);
|
||||
assert_eq!(est.tokens, 11_124 + delta_bytes / 4);
|
||||
|
||||
let old_projection =
|
||||
11_124 + (delta_bytes as u128 * 11_124_u128 / prefix[1] as u128) as u64;
|
||||
assert!(
|
||||
old_projection > est.tokens.saturating_mul(10),
|
||||
"old_projection={old_projection}, corrected={}",
|
||||
est.tokens
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extrapolation_prefers_latest_measured_incremental_span_rate() {
|
||||
let history = vec![
|
||||
msg("first"),
|
||||
msg(&"measured increment ".repeat(20)),
|
||||
msg(&"unmeasured increment ".repeat(30)),
|
||||
];
|
||||
let records = vec![record(1, 10_000), record(2, 10_200)];
|
||||
let prefix = prefix_bytes(&history);
|
||||
let measured_bytes = prefix[2].saturating_sub(prefix[1]);
|
||||
let delta_bytes = prefix[3].saturating_sub(prefix[2]);
|
||||
let expected_delta = (delta_bytes as u128 * 200_u128 / measured_bytes as u128) as u64;
|
||||
|
||||
let est = total_tokens(&history, &records);
|
||||
|
||||
assert_eq!(est.source, EstimateSource::Extrapolated);
|
||||
assert_eq!(est.tokens, 10_200 + expected_delta);
|
||||
assert_ne!(est.tokens, 10_200 + delta_bytes / 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn total_zero_history_is_zero() {
|
||||
let est = total_tokens(&[], &[]);
|
||||
|
|
|
|||
|
|
@ -640,6 +640,49 @@ mod tests {
|
|||
assert_eq!(count.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pre_llm_request_does_not_yield_from_single_measurement_history_rate_projection() {
|
||||
let count = Arc::new(AtomicUsize::new(0));
|
||||
let registry = registry_with_pre_llm_hook(count.clone());
|
||||
let ctx_items = vec![
|
||||
Item::user_message("first"),
|
||||
Item::user_message("tool output ".repeat(400)),
|
||||
];
|
||||
let record = UsageRecord {
|
||||
history_len: 1,
|
||||
input_total_tokens: 11_124,
|
||||
cache_read_tokens: 0,
|
||||
cache_write_tokens: 0,
|
||||
output_tokens: 0,
|
||||
};
|
||||
let prefix = llm_worker::token_counter::prefix_bytes(&ctx_items);
|
||||
let delta_bytes = prefix[2].saturating_sub(prefix[1]);
|
||||
let old_projection =
|
||||
11_124 + (delta_bytes as u128 * 11_124_u128 / prefix[1] as u128) as u64;
|
||||
let corrected = total_tokens(&ctx_items, std::slice::from_ref(&record)).tokens;
|
||||
let threshold = corrected + 100;
|
||||
assert!(old_projection > threshold);
|
||||
|
||||
let state = Arc::new(CompactState::new(None, Some(threshold), 2));
|
||||
let history = Arc::new(Mutex::new(vec![record]));
|
||||
let interceptor = PodInterceptor::new(
|
||||
registry,
|
||||
Some(state),
|
||||
Some(history),
|
||||
NotifyBuffer::new(),
|
||||
Arc::new(Mutex::new(Vec::new())),
|
||||
TaskStore::new(),
|
||||
Arc::new(TaskReminderState::new()),
|
||||
PromptCatalog::builtins_only().unwrap(),
|
||||
None,
|
||||
);
|
||||
let mut ctx = ctx_items;
|
||||
let action = interceptor.pre_llm_request(&mut ctx).await;
|
||||
|
||||
assert!(matches!(action, PreRequestAction::Continue));
|
||||
assert_eq!(count.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pre_llm_request_does_not_yield_when_only_post_run_threshold_set() {
|
||||
// request_threshold = None → safety-net check is inert inside the turn
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user