diff --git a/crates/pod/src/compact/token_counter.rs b/crates/pod/src/compact/token_counter.rs index c8787db3..dbf6cd79 100644 --- a/crates/pod/src/compact/token_counter.rs +++ b/crates/pod/src/compact/token_counter.rs @@ -47,22 +47,71 @@ fn split_for_retained_impl(history: &[Item], records: &[UsageRecord], retained: // (内部で毎回再計算すると O(n²) になる)。将来ボトルネックになれば // record 境界で二分探索に置き換える。 let mut chosen_source = current.source; + let mut cut_index = history.len(); for idx in 1..=history.len() { let est = tokens_at(history, records, idx, &prefix); if est.tokens >= target { chosen_source = est.source; - return SplitPoint { - index: idx, - source: chosen_source, - }; + cut_index = idx; + break; } } SplitPoint { - index: history.len(), + index: balance_to_pair_boundary(history, cut_index), source: chosen_source, } } +/// `history[cut..]` が `ToolCall` / `ToolResult` のペア境界を尊重するよう +/// `cut` を後退させる。 +/// +/// LLM API は「`ToolResult` を送るならその `ToolCall` も同じ request に +/// 含まれていなければならない」というバリデーションを持つ。トークン数 +/// だけで切った `cut` は並列 tool 呼び出しの途中に落ちうるので、retained +/// 側の先頭に対応 `ToolCall` を持たない `ToolResult`(orphan)が残ると +/// 次セッション初回 request が API バリデーションで弾かれる。 +/// +/// 対策は「retained に入る `ToolResult` について、対応 `ToolCall` も +/// retained に含まれる位置まで `cut` を引き下げる」こと。retained_tokens +/// 予算は超えうるが、ここでは直接 LLM に投げる訳ではなく次の +/// `pre_llm_request` で再評価されるだけなので safe。 +/// +/// アルゴリズム: history を末尾から走査し、retained 範囲内の `ToolResult` +/// に出会うたびに対応 `ToolCall` の位置で `cut` を min 更新する。`cut` が +/// 下がると以前は要約側だった位置が retained に入るので、後続走査で連鎖的 +/// に正しい位置まで引き下がる。`ToolCall` の `call_id` はユニークなので +/// 事前にマップ化して O(n) で済ます。 +fn balance_to_pair_boundary(history: &[Item], cut: usize) -> usize { + let mut idx = cut.min(history.len()); + if idx == 0 { + return 0; + } + + let call_positions: std::collections::HashMap<&str, usize> = history + .iter() + .enumerate() + .filter_map(|(i, item)| match item { + Item::ToolCall { call_id, .. } => Some((call_id.as_str(), i)), + _ => None, + }) + .collect(); + + let mut k = history.len(); + while k > 0 { + k -= 1; + if k >= idx { + if let Item::ToolResult { call_id, .. } = &history[k] { + if let Some(&call_pos) = call_positions.get(call_id.as_str()) { + if call_pos < idx { + idx = call_pos; + } + } + } + } + } + idx +} + /// 1 つの ToolResult 項目について、`content` を `None` に射影したとき /// 減少するシリアライズ後バイト数。ToolResult 以外や既に content=None /// の item は 0 を返す。 @@ -305,4 +354,109 @@ mod tests { let est = savings_for_prune_impl(&history, &[record(1, 100)], &[99]); assert_eq!(est.tokens, 0); } + + fn tc(call_id: &str) -> Item { + Item::tool_call(call_id, "Read", "{}") + } + + fn tr(call_id: &str) -> Item { + Item::tool_result(call_id, "summary") + } + + #[test] + fn balance_noop_on_clean_message_boundary() { + let history = vec![msg("a"), msg("b"), msg("c")]; + assert_eq!(balance_to_pair_boundary(&history, 2), 2); + assert_eq!(balance_to_pair_boundary(&history, 0), 0); + assert_eq!(balance_to_pair_boundary(&history, 3), 3); + } + + #[test] + fn balance_retreats_from_inside_parallel_tool_results() { + // [Msg, TC_a, TC_b, TC_c, TR_a, TR_b, TR_c] + // cut=5 → retained=[TR_b, TR_c]。TR_c の TC は idx=3、TR_b は idx=2 → + // idx=2 まで後退。だが retained に TR_a (idx=4) が新たに入り、その TC_a + // は idx=1 でまだ外 → 連鎖後退で最終的に idx=1。retained は + // [TC_a, TC_b, TC_c, TR_a, TR_b, TR_c]。 + let history = vec![ + msg("u"), + tc("a"), + tc("b"), + tc("c"), + tr("a"), + tr("b"), + tr("c"), + ]; + assert_eq!(balance_to_pair_boundary(&history, 5), 1); + } + + #[test] + fn balance_retreats_between_call_and_result() { + // [TC_a, TR_a, TC_b, TR_b]。cut=3 → retained=[TR_b] orphan。 + // TC_b は idx=2 → cut=2。retained=[TC_b, TR_b]。 + let history = vec![tc("a"), tr("a"), tc("b"), tr("b")]; + assert_eq!(balance_to_pair_boundary(&history, 3), 2); + } + + #[test] + fn balance_cascades_through_nested_pairs() { + // [TC_a, TC_b, TR_b, TR_a, TC_c, TR_c]。cut=3 → retained=[TR_a, TC_c, TR_c]。 + // TR_a の TC は idx=0 → cut=0。retained=full。 + let history = vec![tc("a"), tc("b"), tr("b"), tr("a"), tc("c"), tr("c")]; + assert_eq!(balance_to_pair_boundary(&history, 3), 0); + } + + #[test] + fn balance_noop_when_cut_at_pair_boundary() { + // [TC_a, TR_a, Msg, TC_b, TR_b]。cut=2 → retained=[Msg, TC_b, TR_b] balanced。 + let history = vec![tc("a"), tr("a"), msg("u"), tc("b"), tr("b")]; + assert_eq!(balance_to_pair_boundary(&history, 2), 2); + } + + #[test] + fn balance_handles_orphan_result_without_matching_call() { + // ToolCall がそもそも存在しない ToolResult は触らない(壊れた history は + // ここでは直しようがない)。cut=1 → そのまま 1 を返す。 + let history = vec![msg("u"), tr("zombie")]; + assert_eq!(balance_to_pair_boundary(&history, 1), 1); + } + + #[test] + fn balance_keeps_cut_when_call_is_inside_retained() { + // [Msg, TC_a, TR_a]。cut=1 → retained=[TC_a, TR_a]。TR_a の call_pos=1 >= idx=1。OK。 + let history = vec![msg("u"), tc("a"), tr("a")]; + assert_eq!(balance_to_pair_boundary(&history, 1), 1); + } + + #[test] + fn split_for_retained_aligns_to_pair_boundary() { + // 並列 TC*3 / TR*3 ターン後に Msg を 1 件足し、retained=Msg のサイズ相当に + // 設定。トークン的には cut=末尾近くだが、orphan を避けるため TC 群の手前 + // まで後退するはず。 + let history = vec![ + msg("user"), + tc("a"), + tc("b"), + tc("c"), + tr("a"), + tr("b"), + tr("c"), + msg("tail"), + ]; + let total_bytes: u64 = history.iter().map(item_bytes).sum(); + let records = vec![record(history.len(), total_bytes)]; // rate = 1 tok/byte + // tail の item_bytes 相当のみ retain したい。 + let tail_tokens = item_bytes(&history[7]); + let cut = split_for_retained_impl(&history, &records, tail_tokens); + // token 単独だと cut は 7(tail のみ retained)になるが、retained 先頭が + // Msg なら balance しなくて OK。balance helper の no-op を確認する意味も込めて + // index == 7 を期待する。 + assert_eq!(cut.index, 7); + + // 逆に retained をやや増やしてトークン的に cut=6(TR_c のみ retained)に + // させると、TR_c は orphan なので balance が 1 まで後退するはず。 + let big_retain = tail_tokens + item_bytes(&history[6]); + let cut = split_for_retained_impl(&history, &records, big_retain); + assert_eq!(cut.index, 1); + } }