feat: rollback empty interrupted turns

This commit is contained in:
Keisuke Hirata 2026-05-23 12:50:46 +09:00
parent df629b4dc6
commit 55dedd173c
5 changed files with 373 additions and 0 deletions

View File

@ -1,5 +1,6 @@
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::Ordering;
use llm_worker::WorkerError;
use llm_worker::llm_client::client::LlmClient;
@ -303,6 +304,7 @@ fn wire_event_bridges_on_worker<C, St>(
C: LlmClient + Clone + 'static,
St: Store + PodMetadataStore + Clone + 'static,
{
let ai_activity = pod.ai_activity_counter();
let worker = pod.worker_mut();
let tx = event_tx.clone();
@ -329,15 +331,22 @@ fn wire_event_bridges_on_worker<C, St>(
});
let tx = event_tx.clone();
let activity = ai_activity.clone();
worker.on_text_block(move |block| {
let tx_d = tx.clone();
let activity_d = activity.clone();
block.on_delta(move |text| {
activity_d.fetch_add(1, Ordering::SeqCst);
let _ = tx_d.send(Event::TextDelta {
text: text.to_owned(),
});
});
let tx_s = tx.clone();
let activity_s = activity.clone();
block.on_stop(move |text| {
if !text.is_empty() {
activity_s.fetch_add(1, Ordering::SeqCst);
}
let _ = tx_s.send(Event::TextDone {
text: text.to_owned(),
});
@ -345,18 +354,26 @@ fn wire_event_bridges_on_worker<C, St>(
});
let tx = event_tx.clone();
let activity = ai_activity.clone();
worker.on_thinking_block(move |block| {
// Start fires unconditionally so the TUI can show "Thinking..."
// even when the provider doesn't emit plaintext deltas.
activity.fetch_add(1, Ordering::SeqCst);
let _ = tx.send(Event::ThinkingStart);
let tx_d = tx.clone();
let activity_d = activity.clone();
block.on_delta(move |text| {
activity_d.fetch_add(1, Ordering::SeqCst);
let _ = tx_d.send(Event::ThinkingDelta {
text: text.to_owned(),
});
});
let tx_s = tx.clone();
let activity_s = activity.clone();
block.on_stop(move |text| {
if !text.is_empty() {
activity_s.fetch_add(1, Ordering::SeqCst);
}
let _ = tx_s.send(Event::ThinkingDone {
text: text.to_owned(),
});
@ -364,21 +381,27 @@ fn wire_event_bridges_on_worker<C, St>(
});
let tx = event_tx.clone();
let activity = ai_activity.clone();
worker.on_tool_use_block(move |start, block| {
activity.fetch_add(1, Ordering::SeqCst);
let _ = tx.send(Event::ToolCallStart {
id: start.id.clone(),
name: start.name.clone(),
});
let id_for_delta = start.id.clone();
let tx_d = tx.clone();
let activity_d = activity.clone();
block.on_delta(move |json| {
activity_d.fetch_add(1, Ordering::SeqCst);
let _ = tx_d.send(Event::ToolCallArgsDelta {
id: id_for_delta.clone(),
json: json.to_owned(),
});
});
let tx_s = tx.clone();
let activity_s = activity.clone();
block.on_stop(move |call| {
activity_s.fetch_add(1, Ordering::SeqCst);
let _ = tx_s.send(Event::ToolCallDone {
id: call.id.clone(),
name: call.name.clone(),
@ -388,7 +411,9 @@ fn wire_event_bridges_on_worker<C, St>(
});
let tx = event_tx.clone();
let activity = ai_activity.clone();
worker.on_tool_result(move |result| {
activity.fetch_add(1, Ordering::SeqCst);
let _ = tx.send(Event::ToolResult {
id: result.tool_use_id.clone(),
summary: result.summary.clone(),
@ -879,6 +904,7 @@ where
PodRunResult::Finished => (PodStatus::Idle, RunResult::Finished),
PodRunResult::Paused => (PodStatus::Paused, RunResult::Paused),
PodRunResult::LimitReached => (PodStatus::Idle, RunResult::LimitReached),
PodRunResult::RolledBack => (PodStatus::Idle, RunResult::RolledBack),
};
let _ = event_tx.send(Event::RunEnd { result: run_result });
if parent_originated && matches!(run_result, RunResult::Finished) {

View File

@ -6,6 +6,7 @@ use arc_swap::ArcSwap;
use llm_worker::Item;
use llm_worker::llm_client::RequestConfig;
use llm_worker::llm_client::client::LlmClient;
use llm_worker::llm_client::types::Role;
use llm_worker::state::Mutable;
use llm_worker::{ToolOutputLimits, UsageRecord, Worker, WorkerError, WorkerResult};
use session_store::{
@ -121,6 +122,24 @@ impl SegmentState {
}
}
/// Pre-submit state captured by `capture_empty_turn_rollback_snapshot` so
/// that `rollback_empty_turn` can rewind an interrupted turn that produced
/// no AI output. Most fields are lengths: rollback truncates the matching
/// collection back to the recorded size.
struct EmptyTurnRollbackSnapshot {
/// Length of the worker history at capture time.
history_len: usize,
/// Length of `user_segments` at capture time.
user_segments_len: usize,
/// Value of `segment_state.entries_written()` at capture time; also the
/// entry count the store's segment log is truncated back to.
entries_written: usize,
/// Length of the sink at capture time.
sink_len: usize,
/// Full copy of the pending attachments, restored wholesale on rollback.
pending_attachments: Vec<SystemItem>,
/// Length of the usage history at capture time.
usage_history_len: usize,
/// Value of the AI activity counter at capture time; any later change
/// vetoes the rollback.
ai_activity_count: usize,
/// The worker's `last_run_interrupted` flag at capture time, written back
/// on rollback.
last_run_interrupted: bool,
}
/// Reports whether `item` is something the AI side of a turn put into the
/// history: a tool call, a tool result, a reasoning item, or a message whose
/// role is `Assistant`. Messages with any other role do not count.
fn is_ai_materialized_item(item: &Item) -> bool {
    // Kept as an exhaustive match (no wildcard arm) so that adding a new
    // `Item` variant forces a decision here.
    match item {
        Item::ToolCall { .. } | Item::ToolResult { .. } | Item::Reasoning { .. } => true,
        Item::Message { role, .. } => matches!(role, Role::Assistant),
    }
}
/// Cheap-cloneable bundle of (store + shared session pointer + sink)
/// handed to the worker callback and the interceptor so they can
/// commit `LogEntry` values directly without going through an mpsc
@ -254,6 +273,12 @@ pub struct Pod<C: LlmClient, St: Store> {
/// notifications, events sent here are NOT replayed to clients that
/// connect after the fact — they are fire-and-forget broadcasts.
event_tx: Option<broadcast::Sender<Event>>,
/// Monotonic counter incremented by worker event bridges when an
/// assistant-side execution artifact becomes visible to clients before
/// it is necessarily committed to history (e.g. streaming text deltas).
/// `Pod::run` uses it to avoid rolling back a turn after the UI has
/// already observed AI output.
ai_activity_counter: Arc<AtomicUsize>,
/// Queue of pending `Method::Notify` notifications awaiting
/// injection into the next LLM request. Shared with the
/// PodInterceptor installed in `ensure_interceptor_installed`.
@ -392,6 +417,7 @@ impl<C: LlmClient + Clone + 'static, St: Store + Clone + 'static> Pod<C, St> {
system_prompt_template: None,
alerter: self.alerter.clone(),
event_tx: self.event_tx.clone(),
ai_activity_counter: self.ai_activity_counter.clone(),
pending_notifies: NotifyBuffer::new(),
pending_attachments: Arc::new(Mutex::new(Vec::<SystemItem>::new())),
scope_allocation: None,
@ -534,6 +560,7 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
system_prompt_template: None,
alerter: None,
event_tx: None,
ai_activity_counter: Arc::new(AtomicUsize::new(0)),
pending_notifies: NotifyBuffer::new(),
pending_attachments: Arc::new(Mutex::new(Vec::<SystemItem>::new())),
scope_allocation: None,
@ -901,6 +928,12 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
self.event_tx = Some(event_tx);
}
/// Hands out another handle to the shared activity counter. Worker event
/// bridges bump this counter whenever assistant-side output is surfaced
/// before it has been persisted to history.
pub fn ai_activity_counter(&self) -> Arc<AtomicUsize> {
    // Only the reference count is bumped; every handle observes the same
    // underlying counter.
    Arc::clone(&self.ai_activity_counter)
}
fn alert(&self, level: AlertLevel, source: AlertSource, message: String) {
if let Some(n) = self.alerter.as_ref() {
n.alert(level, source, message);
@ -1236,6 +1269,75 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
Ok(())
}
/// Records everything `rollback_empty_turn` needs in order to rewind the
/// pod to its current state: collection lengths, a copy of the pending
/// attachments, the AI activity counter value and the worker's
/// interrupted flag.
///
/// # Panics
///
/// Panics if the `pending_attachments` or `usage_history` mutex is
/// poisoned.
fn capture_empty_turn_rollback_snapshot(&self) -> EmptyTurnRollbackSnapshot {
    // Each guard is a statement-scoped temporary, so the two locks are
    // never held at the same time.
    let pending_attachments = self
        .pending_attachments
        .lock()
        .expect("pending_attachments poisoned")
        .clone();
    let usage_history_len = self
        .usage_history
        .lock()
        .expect("usage_history poisoned")
        .len();

    let history_len = self.worker().history().len();
    let user_segments_len = self.user_segments.len();
    let entries_written = self.segment_state.entries_written();
    let sink_len = self.sink.len();
    let ai_activity_count = self.ai_activity_counter.load(Ordering::SeqCst);
    let last_run_interrupted = self.worker().last_run_interrupted();

    EmptyTurnRollbackSnapshot {
        history_len,
        user_segments_len,
        entries_written,
        sink_len,
        pending_attachments,
        usage_history_len,
        ai_activity_count,
        last_run_interrupted,
    }
}
/// Decides whether the turn that just ended may be rolled back to
/// `snapshot`.
///
/// Returns `true` only when all of the following hold:
///
/// * the worker result is `Err(WorkerError::Cancelled)`;
/// * the AI activity counter still has the value recorded in the snapshot,
///   i.e. no event bridge surfaced assistant-side output during the run;
/// * none of the history items appended since the snapshot is an
///   AI-materialized item (see `is_ai_materialized_item`).
///
/// If the history is now *shorter* than it was at snapshot time, the
/// history was rewritten during the run and truncating back to the
/// snapshot length would not restore the pre-submit state, so the rollback
/// is declined instead of panicking on an out-of-range slice.
fn should_rollback_empty_turn(
    &self,
    result: &Result<WorkerResult, WorkerError>,
    snapshot: &EmptyTurnRollbackSnapshot,
) -> bool {
    if !matches!(result, Err(WorkerError::Cancelled)) {
        return false;
    }
    if self.ai_activity_counter.load(Ordering::SeqCst) != snapshot.ai_activity_count {
        return false;
    }
    // Checked range access: `get` yields `None` when the history shrank
    // below the snapshot length, where plain indexing would panic.
    match self.worker().history().get(snapshot.history_len..) {
        Some(appended) => !appended.iter().any(is_ai_materialized_item),
        None => false,
    }
}
/// Rewinds the pod to the state recorded in `snapshot`, discarding the
/// submit-time entries of a turn that produced no AI output.
///
/// The persisted segment log is truncated first because it is the only
/// step that can fail. If it does fail, nothing in memory has been touched
/// yet, so the in-memory state and the stored log still agree with each
/// other. Once the store has been truncated, the remaining steps are
/// infallible: the written-entries cursor and the sink are rewound, then
/// the worker history, interrupted flag, user segments, pending
/// attachments and usage history are restored, and any usage/metrics
/// accumulated by the abandoned turn are drained and dropped.
///
/// # Errors
///
/// Returns the `StoreError` produced by the store's `truncate` call.
///
/// # Panics
///
/// Panics if the `pending_attachments` or `usage_history` mutex is
/// poisoned.
fn rollback_empty_turn(
    &mut self,
    snapshot: EmptyTurnRollbackSnapshot,
) -> Result<(), StoreError> {
    // Fallible step first: bail out before mutating any in-memory state.
    let loc = self.segment_state.location();
    self.store
        .truncate(loc.session_id, loc.segment_id, snapshot.entries_written)?;
    self.segment_state
        .set_entries_written(snapshot.entries_written);
    self.sink.truncate_silent(snapshot.sink_len);

    self.worker_mut()
        .history_mut()
        .truncate(snapshot.history_len);
    self.worker_mut()
        .set_last_run_interrupted(snapshot.last_run_interrupted);
    self.user_segments.truncate(snapshot.user_segments_len);
    *self
        .pending_attachments
        .lock()
        .expect("pending_attachments poisoned") = snapshot.pending_attachments;
    self.usage_history
        .lock()
        .expect("usage_history poisoned")
        .truncate(snapshot.usage_history_len);

    // Whatever the trackers collected belongs to the abandoned turn;
    // the drained values are intentionally discarded.
    let _ = self.usage_tracker.drain();
    let _ = self.metrics_tracker.drain();
    Ok(())
}
/// Send user input and run until the LLM turn completes.
///
/// `input` is a typed segment list (see [`protocol::Segment`]). The
@ -1270,6 +1372,8 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
self.prepare_for_run().await?;
let rollback_snapshot = self.capture_empty_turn_rollback_snapshot();
// IDLE → active marker. Commits first so the next UserInput entry
// is contained inside this Invoke range. See `tickets/invoke-turn-llmcall-semantics.md`.
self.commit_entry(LogEntry::Invoke {
@ -1311,6 +1415,11 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
let result = locked.run(flattened).await;
self.worker = Some(locked.unlock());
if self.should_rollback_empty_turn(&result, &rollback_snapshot) {
self.rollback_empty_turn(rollback_snapshot)?;
return Ok(PodRunResult::RolledBack);
}
self.handle_worker_result(result, history_before).await
}
@ -2847,6 +2956,7 @@ where
system_prompt_template: common.system_prompt_template,
alerter: None,
event_tx: None,
ai_activity_counter: Arc::new(AtomicUsize::new(0)),
pending_notifies: NotifyBuffer::new(),
pending_attachments: Arc::new(Mutex::new(Vec::<SystemItem>::new())),
scope_allocation: Some(scope_allocation),
@ -2923,6 +3033,7 @@ where
system_prompt_template: common.system_prompt_template,
alerter: None,
event_tx: None,
ai_activity_counter: Arc::new(AtomicUsize::new(0)),
pending_notifies: NotifyBuffer::new(),
pending_attachments: Arc::new(Mutex::new(Vec::<SystemItem>::new())),
scope_allocation: Some(scope_allocation),
@ -3098,6 +3209,7 @@ where
system_prompt_template: None,
alerter: None,
event_tx: None,
ai_activity_counter: Arc::new(AtomicUsize::new(0)),
pending_notifies: NotifyBuffer::new(),
pending_attachments: Arc::new(Mutex::new(Vec::<SystemItem>::new())),
scope_allocation: Some(scope_allocation),
@ -3184,6 +3296,8 @@ pub enum PodRunResult {
Paused,
/// The worker reached its configured max_turns limit.
LimitReached,
/// The submit-time user turn was rolled back because no AI output was materialized.
RolledBack,
}
impl From<WorkerResult> for PodRunResult {

View File

@ -1470,3 +1470,204 @@ async fn paused_then_run_closes_orphan_tool_use_for_next_request() {
"system note must precede new user message"
);
}
/// Test helper: does the text of `item` contain `needle`? An item without
/// text is treated as having the default (empty) text.
fn item_text_contains(item: &Item, needle: &str) -> bool {
    let text = item.as_text().unwrap_or_default();
    text.contains(needle)
}
/// Test helper: opens a fresh connection to the pod's socket, reads events
/// until the `Snapshot` arrives (skipping any `Alert`s before it), and
/// reports whether one of the snapshot's `UserInput` log entries contains
/// `needle` in its flattened text.
///
/// Panics if the connection or a read fails, if any other event precedes
/// the snapshot, or if a snapshot entry does not deserialise as a
/// `LogEntry`.
async fn snapshot_contains_user_input(handle: &PodHandle, needle: &str) -> bool {
    let socket = tokio::net::UnixStream::connect(handle.runtime_dir.socket_path())
        .await
        .unwrap();
    // The write half is kept bound so it is not dropped while reading.
    let (read_half, _write_half) = socket.into_split();
    let mut events = protocol::stream::JsonLineReader::new(read_half);

    let snapshot_entries = loop {
        match events.next::<Event>().await.unwrap().unwrap() {
            Event::Snapshot { entries, .. } => break entries,
            Event::Alert(_) => continue,
            other => panic!("expected Snapshot first, got {other:?}"),
        }
    };

    snapshot_entries.into_iter().any(|raw| {
        let parsed: session_store::LogEntry =
            serde_json::from_value(raw).expect("LogEntry deserialise");
        if let session_store::LogEntry::UserInput { segments, .. } = parsed {
            protocol::Segment::flatten_to_text(&segments).contains(needle)
        } else {
            false
        }
    })
}
/// Cancelling a run whose LLM call hangs without emitting anything must end
/// with `RunEnd::RolledBack`, return the pod to `Idle`, and leave no trace
/// of the submitted input in the history.
#[tokio::test]
async fn empty_turn_cancel_rolls_back_submit_entries_and_emits_signal() {
    let timeout = std::time::Duration::from_secs(2);
    let client = MockClient::sequential(vec![MockResponse::Hang(vec![])]);
    let handle = spawn_controller(make_pod(client).await).await;
    let mut rx = handle.subscribe();

    // Submit, wait until the run is actually in flight, then cancel it.
    handle.send(Method::run_text("rollback me")).await.unwrap();
    wait_for_status(&handle, PodStatus::Running).await;
    handle.send(Method::Cancel).await.unwrap();

    let saw_rollback = drain_until(&mut rx, timeout, |event| {
        matches!(
            event,
            Event::RunEnd {
                result: protocol::RunResult::RolledBack
            }
        )
    })
    .await;
    assert!(saw_rollback, "expected RunEnd::RolledBack after empty cancel");
    wait_for_status(&handle, PodStatus::Idle).await;

    let history = history_from_sink(&handle);
    let input_survived = history
        .iter()
        .any(|item| item_text_contains(item, "rollback me"));
    assert!(
        !input_survived,
        "rolled-back user input must not remain in history: {history:?}"
    );
}
/// Pausing a run whose LLM call hangs without emitting anything must also
/// roll back, and a client attaching afterwards must not find the
/// rolled-back input in the snapshot it receives.
#[tokio::test]
async fn empty_turn_pause_rolls_back_and_snapshot_does_not_restore_input() {
    let timeout = std::time::Duration::from_secs(2);
    let client = MockClient::sequential(vec![MockResponse::Hang(vec![])]);
    let handle = spawn_controller(make_pod(client).await).await;
    let mut rx = handle.subscribe();

    // Submit, wait until the run is actually in flight, then pause it.
    handle
        .send(Method::run_text("pause rollback"))
        .await
        .unwrap();
    wait_for_status(&handle, PodStatus::Running).await;
    handle.send(Method::Pause).await.unwrap();

    let saw_rollback = drain_until(&mut rx, timeout, |event| {
        matches!(
            event,
            Event::RunEnd {
                result: protocol::RunResult::RolledBack
            }
        )
    })
    .await;
    assert!(saw_rollback, "expected RunEnd::RolledBack after empty pause");
    wait_for_status(&handle, PodStatus::Idle).await;

    let resurrected = snapshot_contains_user_input(&handle, "pause rollback").await;
    assert!(
        !resurrected,
        "attach snapshot must not resurrect rolled-back empty turn input"
    );
}
/// A completed first turn must survive when a second, empty turn is
/// cancelled and rolled back: the first input and its reply stay in the
/// history, and only the second input disappears.
#[tokio::test]
async fn empty_turn_rollback_removes_only_the_most_recent_turn() {
    let timeout = std::time::Duration::from_secs(2);
    let scripted = vec![
        MockResponse::Complete(simple_text_events()),
        MockResponse::Hang(vec![]),
    ];
    let handle = spawn_controller(make_pod(MockClient::sequential(scripted)).await).await;
    let mut rx = handle.subscribe();

    // Turn one: runs to completion.
    handle.send(Method::run_text("first kept")).await.unwrap();
    let first_finished = drain_until(&mut rx, timeout, |event| {
        matches!(
            event,
            Event::RunEnd {
                result: protocol::RunResult::Finished
            }
        )
    })
    .await;
    assert!(first_finished, "expected first run to finish");
    wait_for_status(&handle, PodStatus::Idle).await;

    // Turn two: hangs with no output and is cancelled.
    handle
        .send(Method::run_text("second rolled back"))
        .await
        .unwrap();
    wait_for_status(&handle, PodStatus::Running).await;
    handle.send(Method::Cancel).await.unwrap();
    let second_rolled_back = drain_until(&mut rx, timeout, |event| {
        matches!(
            event,
            Event::RunEnd {
                result: protocol::RunResult::RolledBack
            }
        )
    })
    .await;
    assert!(second_rolled_back, "expected empty second run to roll back");

    let history = history_from_sink(&handle);
    assert!(
        history
            .iter()
            .any(|item| item_text_contains(item, "first kept"))
    );
    assert!(
        history
            .iter()
            .any(|item| item_text_contains(item, "Hello World"))
    );
    let second_input_survived = history
        .iter()
        .any(|item| item_text_contains(item, "second rolled back"));
    assert!(
        !second_input_survived,
        "rollback must affect only the most recent empty turn: {history:?}"
    );
}
/// Once assistant text has been streamed to clients, pausing must follow
/// the ordinary `Paused` path: no rollback, pod status `Paused`, and the
/// user input still present in the history.
#[tokio::test]
async fn pause_after_assistant_token_does_not_rollback() {
    let timeout = std::time::Duration::from_secs(2);
    // The mock emits one complete text block and then hangs.
    let streamed = vec![
        LlmEvent::text_block_start(0),
        LlmEvent::text_delta(0, "committed before pause"),
        LlmEvent::text_block_stop(0, None),
    ];
    let client = MockClient::sequential(vec![MockResponse::Hang(streamed)]);
    let handle = spawn_controller(make_pod(client).await).await;
    let mut rx = handle.subscribe();

    handle
        .send(Method::run_text("keep this turn"))
        .await
        .unwrap();
    let text_seen = drain_until(&mut rx, timeout, |event| {
        matches!(event, Event::TextDone { .. })
    })
    .await;
    assert!(text_seen, "assistant token should be visible before pause");

    handle.send(Method::Pause).await.unwrap();
    let paused = drain_until(&mut rx, timeout, |event| {
        matches!(
            event,
            Event::RunEnd {
                result: protocol::RunResult::Paused
            }
        )
    })
    .await;
    assert!(
        paused,
        "pause after assistant output must keep the existing Paused path"
    );
    wait_for_status(&handle, PodStatus::Paused).await;

    let history = history_from_sink(&handle);
    let input_kept = history
        .iter()
        .any(|item| item_text_contains(item, "keep this turn"));
    assert!(
        input_kept,
        "token-visible turn must keep its UserInput entry: {history:?}"
    );
}

View File

@ -571,6 +571,11 @@ pub enum RunResult {
Finished,
Paused,
LimitReached,
/// The accepted Method::Run produced no assistant/tool output before
/// user interruption, so the Pod rolled the submit-time turn state back
/// to its pre-submit snapshot. Clients should treat the Pod as Idle and
/// restore the just-submitted input into the editable composer if desired.
RolledBack,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]

View File

@ -82,6 +82,33 @@ pub trait Store: Send + Sync {
/// Check if a segment exists.
fn exists(&self, session_id: SessionId, segment_id: SegmentId) -> Result<bool, StoreError>;
/// Truncate a segment log to `entries_len` entries.
///
/// Used by Pod's submit-time empty-turn rollback after it has proven
/// that no LLM output from the accepted turn was materialized. The
/// default implementation rewrites the retained prefix through
/// `create_segment`, matching the append-only logical model while still
/// allowing concrete stores to provide a more direct truncate.
///
/// When the segment already holds exactly `entries_len` entries there is
/// nothing to drop, so the default implementation returns without
/// rewriting the segment.
///
/// # Errors
///
/// Returns `StoreError::Corrupt` when `entries_len` is larger than the
/// number of stored entries, and propagates any error from `read_all` or
/// `create_segment`.
fn truncate(
    &self,
    session_id: SessionId,
    segment_id: SegmentId,
    entries_len: usize,
) -> Result<(), StoreError> {
    let mut entries = self.read_all(session_id, segment_id)?;
    let stored = entries.len();
    if entries_len > stored {
        return Err(StoreError::Corrupt {
            line: entries_len,
            message: format!(
                "cannot truncate segment {segment_id} to {entries_len} entries; only {stored} entries stored"
            ),
        });
    }
    // Already at the requested length: skip the full rewrite.
    if entries_len == stored {
        return Ok(());
    }
    entries.truncate(entries_len);
    self.create_segment(session_id, segment_id, &entries)
}
/// Count entries currently stored for a segment.
///
/// Used by `ensure_head_or_fork` to detect concurrent writers: