fix: compact時にToolCallとOutputの間でCutしてしまう問題
This commit is contained in:
parent
2b5da965ca
commit
0141880b9d
|
|
@ -47,22 +47,71 @@ fn split_for_retained_impl(history: &[Item], records: &[UsageRecord], retained:
|
|||
// (内部で毎回再計算すると O(n²) になる)。将来ボトルネックになれば
|
||||
// record 境界で二分探索に置き換える。
|
||||
let mut chosen_source = current.source;
|
||||
let mut cut_index = history.len();
|
||||
for idx in 1..=history.len() {
|
||||
let est = tokens_at(history, records, idx, &prefix);
|
||||
if est.tokens >= target {
|
||||
chosen_source = est.source;
|
||||
return SplitPoint {
|
||||
index: idx,
|
||||
source: chosen_source,
|
||||
};
|
||||
cut_index = idx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
SplitPoint {
|
||||
index: history.len(),
|
||||
index: balance_to_pair_boundary(history, cut_index),
|
||||
source: chosen_source,
|
||||
}
|
||||
}
|
||||
|
||||
/// `history[cut..]` が `ToolCall` / `ToolResult` のペア境界を尊重するよう
|
||||
/// `cut` を後退させる。
|
||||
///
|
||||
/// LLM API は「`ToolResult` を送るならその `ToolCall` も同じ request に
|
||||
/// 含まれていなければならない」というバリデーションを持つ。トークン数
|
||||
/// だけで切った `cut` は並列 tool 呼び出しの途中に落ちうるので、retained
|
||||
/// 側の先頭に対応 `ToolCall` を持たない `ToolResult`(orphan)が残ると
|
||||
/// 次セッション初回 request が API バリデーションで弾かれる。
|
||||
///
|
||||
/// 対策は「retained に入る `ToolResult` について、対応 `ToolCall` も
|
||||
/// retained に含まれる位置まで `cut` を引き下げる」こと。retained_tokens
|
||||
/// 予算は超えうるが、ここでは直接 LLM に投げる訳ではなく次の
|
||||
/// `pre_llm_request` で再評価されるだけなので safe。
|
||||
///
|
||||
/// アルゴリズム: history を末尾から走査し、retained 範囲内の `ToolResult`
|
||||
/// に出会うたびに対応 `ToolCall` の位置で `cut` を min 更新する。`cut` が
|
||||
/// 下がると以前は要約側だった位置が retained に入るので、後続走査で連鎖的
|
||||
/// に正しい位置まで引き下がる。`ToolCall` の `call_id` はユニークなので
|
||||
/// 事前にマップ化して O(n) で済ます。
|
||||
fn balance_to_pair_boundary(history: &[Item], cut: usize) -> usize {
|
||||
let mut idx = cut.min(history.len());
|
||||
if idx == 0 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let call_positions: std::collections::HashMap<&str, usize> = history
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, item)| match item {
|
||||
Item::ToolCall { call_id, .. } => Some((call_id.as_str(), i)),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut k = history.len();
|
||||
while k > 0 {
|
||||
k -= 1;
|
||||
if k >= idx {
|
||||
if let Item::ToolResult { call_id, .. } = &history[k] {
|
||||
if let Some(&call_pos) = call_positions.get(call_id.as_str()) {
|
||||
if call_pos < idx {
|
||||
idx = call_pos;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
idx
|
||||
}
|
||||
|
||||
/// 1 つの ToolResult 項目について、`content` を `None` に射影したとき
|
||||
/// 減少するシリアライズ後バイト数。ToolResult 以外や既に content=None
|
||||
/// の item は 0 を返す。
|
||||
|
|
@ -305,4 +354,109 @@ mod tests {
|
|||
let est = savings_for_prune_impl(&history, &[record(1, 100)], &[99]);
|
||||
assert_eq!(est.tokens, 0);
|
||||
}
|
||||
|
||||
fn tc(call_id: &str) -> Item {
|
||||
Item::tool_call(call_id, "Read", "{}")
|
||||
}
|
||||
|
||||
fn tr(call_id: &str) -> Item {
|
||||
Item::tool_result(call_id, "summary")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn balance_noop_on_clean_message_boundary() {
|
||||
let history = vec![msg("a"), msg("b"), msg("c")];
|
||||
assert_eq!(balance_to_pair_boundary(&history, 2), 2);
|
||||
assert_eq!(balance_to_pair_boundary(&history, 0), 0);
|
||||
assert_eq!(balance_to_pair_boundary(&history, 3), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn balance_retreats_from_inside_parallel_tool_results() {
|
||||
// [Msg, TC_a, TC_b, TC_c, TR_a, TR_b, TR_c]
|
||||
// cut=5 → retained=[TR_b, TR_c]。TR_c の TC は idx=3、TR_b は idx=2 →
|
||||
// idx=2 まで後退。だが retained に TR_a (idx=4) が新たに入り、その TC_a
|
||||
// は idx=1 でまだ外 → 連鎖後退で最終的に idx=1。retained は
|
||||
// [TC_a, TC_b, TC_c, TR_a, TR_b, TR_c]。
|
||||
let history = vec![
|
||||
msg("u"),
|
||||
tc("a"),
|
||||
tc("b"),
|
||||
tc("c"),
|
||||
tr("a"),
|
||||
tr("b"),
|
||||
tr("c"),
|
||||
];
|
||||
assert_eq!(balance_to_pair_boundary(&history, 5), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn balance_retreats_between_call_and_result() {
|
||||
// [TC_a, TR_a, TC_b, TR_b]。cut=3 → retained=[TR_b] orphan。
|
||||
// TC_b は idx=2 → cut=2。retained=[TC_b, TR_b]。
|
||||
let history = vec![tc("a"), tr("a"), tc("b"), tr("b")];
|
||||
assert_eq!(balance_to_pair_boundary(&history, 3), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn balance_cascades_through_nested_pairs() {
|
||||
// [TC_a, TC_b, TR_b, TR_a, TC_c, TR_c]。cut=3 → retained=[TR_a, TC_c, TR_c]。
|
||||
// TR_a の TC は idx=0 → cut=0。retained=full。
|
||||
let history = vec![tc("a"), tc("b"), tr("b"), tr("a"), tc("c"), tr("c")];
|
||||
assert_eq!(balance_to_pair_boundary(&history, 3), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn balance_noop_when_cut_at_pair_boundary() {
|
||||
// [TC_a, TR_a, Msg, TC_b, TR_b]。cut=2 → retained=[Msg, TC_b, TR_b] balanced。
|
||||
let history = vec![tc("a"), tr("a"), msg("u"), tc("b"), tr("b")];
|
||||
assert_eq!(balance_to_pair_boundary(&history, 2), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn balance_handles_orphan_result_without_matching_call() {
|
||||
// ToolCall がそもそも存在しない ToolResult は触らない(壊れた history は
|
||||
// ここでは直しようがない)。cut=1 → そのまま 1 を返す。
|
||||
let history = vec![msg("u"), tr("zombie")];
|
||||
assert_eq!(balance_to_pair_boundary(&history, 1), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn balance_keeps_cut_when_call_is_inside_retained() {
|
||||
// [Msg, TC_a, TR_a]。cut=1 → retained=[TC_a, TR_a]。TR_a の call_pos=1 >= idx=1。OK。
|
||||
let history = vec![msg("u"), tc("a"), tr("a")];
|
||||
assert_eq!(balance_to_pair_boundary(&history, 1), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn split_for_retained_aligns_to_pair_boundary() {
|
||||
// 並列 TC*3 / TR*3 ターン後に Msg を 1 件足し、retained=Msg のサイズ相当に
|
||||
// 設定。トークン的には cut=末尾近くだが、orphan を避けるため TC 群の手前
|
||||
// まで後退するはず。
|
||||
let history = vec![
|
||||
msg("user"),
|
||||
tc("a"),
|
||||
tc("b"),
|
||||
tc("c"),
|
||||
tr("a"),
|
||||
tr("b"),
|
||||
tr("c"),
|
||||
msg("tail"),
|
||||
];
|
||||
let total_bytes: u64 = history.iter().map(item_bytes).sum();
|
||||
let records = vec![record(history.len(), total_bytes)]; // rate = 1 tok/byte
|
||||
// tail の item_bytes 相当のみ retain したい。
|
||||
let tail_tokens = item_bytes(&history[7]);
|
||||
let cut = split_for_retained_impl(&history, &records, tail_tokens);
|
||||
// token 単独だと cut は 7(tail のみ retained)になるが、retained 先頭が
|
||||
// Msg なら balance しなくて OK。balance helper の no-op を確認する意味も込めて
|
||||
// index == 7 を期待する。
|
||||
assert_eq!(cut.index, 7);
|
||||
|
||||
// 逆に retained をやや増やしてトークン的に cut=6(TR_c のみ retained)に
|
||||
// させると、TR_c は orphan なので balance が 1 まで後退するはず。
|
||||
let big_retain = tail_tokens + item_bytes(&history[6]);
|
||||
let cut = split_for_retained_impl(&history, &records, big_retain);
|
||||
assert_eq!(cut.index, 1);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user