diff --git a/crates/pod/src/controller.rs b/crates/pod/src/controller.rs index d2fc319b..ae0c339e 100644 --- a/crates/pod/src/controller.rs +++ b/crates/pod/src/controller.rs @@ -1,5 +1,6 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; +use std::sync::atomic::Ordering; use llm_worker::WorkerError; use llm_worker::llm_client::client::LlmClient; @@ -303,6 +304,7 @@ fn wire_event_bridges_on_worker( C: LlmClient + Clone + 'static, St: Store + PodMetadataStore + Clone + 'static, { + let ai_activity = pod.ai_activity_counter(); let worker = pod.worker_mut(); let tx = event_tx.clone(); @@ -329,15 +331,22 @@ fn wire_event_bridges_on_worker( }); let tx = event_tx.clone(); + let activity = ai_activity.clone(); worker.on_text_block(move |block| { let tx_d = tx.clone(); + let activity_d = activity.clone(); block.on_delta(move |text| { + activity_d.fetch_add(1, Ordering::SeqCst); let _ = tx_d.send(Event::TextDelta { text: text.to_owned(), }); }); let tx_s = tx.clone(); + let activity_s = activity.clone(); block.on_stop(move |text| { + if !text.is_empty() { + activity_s.fetch_add(1, Ordering::SeqCst); + } let _ = tx_s.send(Event::TextDone { text: text.to_owned(), }); @@ -345,18 +354,26 @@ fn wire_event_bridges_on_worker( }); let tx = event_tx.clone(); + let activity = ai_activity.clone(); worker.on_thinking_block(move |block| { // Start fires unconditionally so the TUI can show "Thinking..." // even when the provider doesn't emit plaintext deltas. 
+ activity.fetch_add(1, Ordering::SeqCst); let _ = tx.send(Event::ThinkingStart); let tx_d = tx.clone(); + let activity_d = activity.clone(); block.on_delta(move |text| { + activity_d.fetch_add(1, Ordering::SeqCst); let _ = tx_d.send(Event::ThinkingDelta { text: text.to_owned(), }); }); let tx_s = tx.clone(); + let activity_s = activity.clone(); block.on_stop(move |text| { + if !text.is_empty() { + activity_s.fetch_add(1, Ordering::SeqCst); + } let _ = tx_s.send(Event::ThinkingDone { text: text.to_owned(), }); @@ -364,21 +381,27 @@ fn wire_event_bridges_on_worker( }); let tx = event_tx.clone(); + let activity = ai_activity.clone(); worker.on_tool_use_block(move |start, block| { + activity.fetch_add(1, Ordering::SeqCst); let _ = tx.send(Event::ToolCallStart { id: start.id.clone(), name: start.name.clone(), }); let id_for_delta = start.id.clone(); let tx_d = tx.clone(); + let activity_d = activity.clone(); block.on_delta(move |json| { + activity_d.fetch_add(1, Ordering::SeqCst); let _ = tx_d.send(Event::ToolCallArgsDelta { id: id_for_delta.clone(), json: json.to_owned(), }); }); let tx_s = tx.clone(); + let activity_s = activity.clone(); block.on_stop(move |call| { + activity_s.fetch_add(1, Ordering::SeqCst); let _ = tx_s.send(Event::ToolCallDone { id: call.id.clone(), name: call.name.clone(), @@ -388,7 +411,9 @@ fn wire_event_bridges_on_worker( }); let tx = event_tx.clone(); + let activity = ai_activity.clone(); worker.on_tool_result(move |result| { + activity.fetch_add(1, Ordering::SeqCst); let _ = tx.send(Event::ToolResult { id: result.tool_use_id.clone(), summary: result.summary.clone(), @@ -879,6 +904,7 @@ where PodRunResult::Finished => (PodStatus::Idle, RunResult::Finished), PodRunResult::Paused => (PodStatus::Paused, RunResult::Paused), PodRunResult::LimitReached => (PodStatus::Idle, RunResult::LimitReached), + PodRunResult::RolledBack => (PodStatus::Idle, RunResult::RolledBack), }; let _ = event_tx.send(Event::RunEnd { result: run_result }); if 
parent_originated && matches!(run_result, RunResult::Finished) { diff --git a/crates/pod/src/pod.rs b/crates/pod/src/pod.rs index 4c7b61ff..ec8bc52c 100644 --- a/crates/pod/src/pod.rs +++ b/crates/pod/src/pod.rs @@ -6,6 +6,7 @@ use arc_swap::ArcSwap; use llm_worker::Item; use llm_worker::llm_client::RequestConfig; use llm_worker::llm_client::client::LlmClient; +use llm_worker::llm_client::types::Role; use llm_worker::state::Mutable; use llm_worker::{ToolOutputLimits, UsageRecord, Worker, WorkerError, WorkerResult}; use session_store::{ @@ -121,6 +122,24 @@ impl SegmentState { } } +struct EmptyTurnRollbackSnapshot { + history_len: usize, + user_segments_len: usize, + entries_written: usize, + sink_len: usize, + pending_attachments: Vec, + usage_history_len: usize, + ai_activity_count: usize, + last_run_interrupted: bool, +} + +fn is_ai_materialized_item(item: &Item) -> bool { + match item { + Item::Message { role, .. } => *role == Role::Assistant, + Item::ToolCall { .. } | Item::ToolResult { .. } | Item::Reasoning { .. } => true, + } +} + /// Cheap-cloneable bundle of (store + shared session pointer + sink) /// handed to the worker callback and the interceptor so they can /// commit `LogEntry` values directly without going through an mpsc @@ -254,6 +273,12 @@ pub struct Pod { /// notifications, events sent here are NOT replayed to clients that /// connect after the fact — they are fire-and-forget broadcasts. event_tx: Option>, + /// Monotonic counter incremented by worker event bridges when an + /// assistant-side execution artifact becomes visible to clients before + /// it is necessarily committed to history (e.g. streaming text deltas). + /// `Pod::run` uses it to avoid rolling back a turn after the UI has + /// already observed AI output. + ai_activity_counter: Arc, /// Queue of pending `Method::Notify` notifications awaiting /// injection into the next LLM request. Shared with the /// PodInterceptor installed in `ensure_interceptor_installed`. 
@@ -392,6 +417,7 @@ impl Pod {
             system_prompt_template: None,
             alerter: self.alerter.clone(),
             event_tx: self.event_tx.clone(),
+            ai_activity_counter: self.ai_activity_counter.clone(),
             pending_notifies: NotifyBuffer::new(),
             pending_attachments: Arc::new(Mutex::new(Vec::::new())),
             scope_allocation: None,
@@ -534,6 +560,7 @@ impl Pod {
             system_prompt_template: None,
             alerter: None,
             event_tx: None,
+            ai_activity_counter: Arc::new(AtomicUsize::new(0)),
             pending_notifies: NotifyBuffer::new(),
             pending_attachments: Arc::new(Mutex::new(Vec::::new())),
             scope_allocation: None,
@@ -901,6 +928,12 @@ impl Pod {
         self.event_tx = Some(event_tx);
     }
 
+    /// Shared activity counter incremented by worker event bridges when any
+    /// assistant-side output is surfaced before history persistence.
+    pub fn ai_activity_counter(&self) -> Arc<AtomicUsize> {
+        self.ai_activity_counter.clone()
+    }
+
     fn alert(&self, level: AlertLevel, source: AlertSource, message: String) {
         if let Some(n) = self.alerter.as_ref() {
             n.alert(level, source, message);
@@ -1236,6 +1269,96 @@ impl Pod {
         Ok(())
     }
 
+    /// Capture the state that `Pod::run` is about to mutate so an empty,
+    /// cancelled turn can later be undone by `rollback_empty_turn`.
+    fn capture_empty_turn_rollback_snapshot(&self) -> EmptyTurnRollbackSnapshot {
+        let pending_attachments = self
+            .pending_attachments
+            .lock()
+            .expect("pending_attachments poisoned")
+            .clone();
+        let usage_history_len = self
+            .usage_history
+            .lock()
+            .expect("usage_history poisoned")
+            .len();
+        EmptyTurnRollbackSnapshot {
+            history_len: self.worker().history().len(),
+            user_segments_len: self.user_segments.len(),
+            entries_written: self.segment_state.entries_written(),
+            sink_len: self.sink.len(),
+            pending_attachments,
+            usage_history_len,
+            ai_activity_count: self.ai_activity_counter.load(Ordering::SeqCst),
+            last_run_interrupted: self.worker().last_run_interrupted(),
+        }
+    }
+
+    /// Decide whether a finished run qualifies for empty-turn rollback.
+    ///
+    /// Returns `true` only when the worker ended with
+    /// `WorkerError::Cancelled`, the AI activity counter has not moved
+    /// since `snapshot` was taken, and no item appended to history after
+    /// the snapshot is AI-materialized.
+    fn should_rollback_empty_turn(
+        &self,
+        result: &Result,
+        snapshot: &EmptyTurnRollbackSnapshot,
+    ) -> bool {
+        if !matches!(result, Err(WorkerError::Cancelled)) {
+            return false;
+        }
+        if self.ai_activity_counter.load(Ordering::SeqCst) != snapshot.ai_activity_count {
+            return false;
+        }
+        // `get` instead of indexing: if history is now shorter than the
+        // snapshot, truncation could not restore it, so decline rather
+        // than panic on an out-of-range slice.
+        self.worker()
+            .history()
+            .get(snapshot.history_len..)
+            .map_or(false, |items| !items.iter().any(is_ai_materialized_item))
+    }
+
+    /// Undo the side effects of an empty, cancelled turn by restoring the
+    /// values recorded in `snapshot`.
+    ///
+    /// # Errors
+    ///
+    /// Returns the `StoreError` from `Store::truncate`. The store is
+    /// truncated before any in-memory state is touched, so on failure no
+    /// in-memory state has been modified.
+    fn rollback_empty_turn(
+        &mut self,
+        snapshot: EmptyTurnRollbackSnapshot,
+    ) -> Result<(), StoreError> {
+        // Fallible step first: nothing after it returns an error.
+        let loc = self.segment_state.location();
+        self.store
+            .truncate(loc.session_id, loc.segment_id, snapshot.entries_written)?;
+        self.segment_state
+            .set_entries_written(snapshot.entries_written);
+        self.sink.truncate_silent(snapshot.sink_len);
+
+        self.worker_mut()
+            .history_mut()
+            .truncate(snapshot.history_len);
+        self.worker_mut()
+            .set_last_run_interrupted(snapshot.last_run_interrupted);
+        self.user_segments.truncate(snapshot.user_segments_len);
+        *self
+            .pending_attachments
+            .lock()
+            .expect("pending_attachments poisoned") = snapshot.pending_attachments;
+        self.usage_history
+            .lock()
+            .expect("usage_history poisoned")
+            .truncate(snapshot.usage_history_len);
+        let _ = self.usage_tracker.drain();
+        let _ = self.metrics_tracker.drain();
+        Ok(())
+    }
+
     /// Send user input and run until the LLM turn completes.
     ///
     /// `input` is a typed segment list (see [`protocol::Segment`]). The
@@ -1270,6 +1393,8 @@ impl Pod {
 
         self.prepare_for_run().await?;
 
+        let rollback_snapshot = self.capture_empty_turn_rollback_snapshot();
+
         // IDLE → active marker. Commits first so the next UserInput entry
         // is contained inside this Invoke range. See `tickets/invoke-turn-llmcall-semantics.md`.
self.commit_entry(LogEntry::Invoke { @@ -1311,6 +1415,11 @@ impl Pod { let result = locked.run(flattened).await; self.worker = Some(locked.unlock()); + if self.should_rollback_empty_turn(&result, &rollback_snapshot) { + self.rollback_empty_turn(rollback_snapshot)?; + return Ok(PodRunResult::RolledBack); + } + self.handle_worker_result(result, history_before).await } @@ -2847,6 +2956,7 @@ where system_prompt_template: common.system_prompt_template, alerter: None, event_tx: None, + ai_activity_counter: Arc::new(AtomicUsize::new(0)), pending_notifies: NotifyBuffer::new(), pending_attachments: Arc::new(Mutex::new(Vec::::new())), scope_allocation: Some(scope_allocation), @@ -2923,6 +3033,7 @@ where system_prompt_template: common.system_prompt_template, alerter: None, event_tx: None, + ai_activity_counter: Arc::new(AtomicUsize::new(0)), pending_notifies: NotifyBuffer::new(), pending_attachments: Arc::new(Mutex::new(Vec::::new())), scope_allocation: Some(scope_allocation), @@ -3098,6 +3209,7 @@ where system_prompt_template: None, alerter: None, event_tx: None, + ai_activity_counter: Arc::new(AtomicUsize::new(0)), pending_notifies: NotifyBuffer::new(), pending_attachments: Arc::new(Mutex::new(Vec::::new())), scope_allocation: Some(scope_allocation), @@ -3184,6 +3296,8 @@ pub enum PodRunResult { Paused, /// The worker reached its configured max_turns limit. LimitReached, + /// The submit-time user turn was rolled back because no AI output was materialized. 
+ RolledBack, } impl From for PodRunResult { diff --git a/crates/pod/tests/controller_test.rs b/crates/pod/tests/controller_test.rs index b422a119..6e5851e0 100644 --- a/crates/pod/tests/controller_test.rs +++ b/crates/pod/tests/controller_test.rs @@ -1470,3 +1470,204 @@ async fn paused_then_run_closes_orphan_tool_use_for_next_request() { "system note must precede new user message" ); } + +fn item_text_contains(item: &Item, needle: &str) -> bool { + item.as_text().unwrap_or_default().contains(needle) +} + +async fn snapshot_contains_user_input(handle: &PodHandle, needle: &str) -> bool { + let stream = tokio::net::UnixStream::connect(handle.runtime_dir.socket_path()) + .await + .unwrap(); + let (reader, _writer) = stream.into_split(); + let mut reader = protocol::stream::JsonLineReader::new(reader); + + loop { + let event = reader.next::().await.unwrap().unwrap(); + match event { + Event::Snapshot { entries, .. } => { + return entries.into_iter().any(|value| { + let entry: session_store::LogEntry = + serde_json::from_value(value).expect("LogEntry deserialise"); + match entry { + session_store::LogEntry::UserInput { segments, .. 
} => { + protocol::Segment::flatten_to_text(&segments).contains(needle) + } + _ => false, + } + }); + } + Event::Alert(_) => continue, + other => panic!("expected Snapshot first, got {other:?}"), + } + } +} + +#[tokio::test] +async fn empty_turn_cancel_rolls_back_submit_entries_and_emits_signal() { + let client = MockClient::sequential(vec![MockResponse::Hang(vec![])]); + let pod = make_pod(client).await; + let handle = spawn_controller(pod).await; + let mut rx = handle.subscribe(); + + handle.send(Method::run_text("rollback me")).await.unwrap(); + wait_for_status(&handle, PodStatus::Running).await; + handle.send(Method::Cancel).await.unwrap(); + + assert!( + drain_until(&mut rx, std::time::Duration::from_secs(2), |e| matches!( + e, + Event::RunEnd { + result: protocol::RunResult::RolledBack + } + )) + .await, + "expected RunEnd::RolledBack after empty cancel" + ); + wait_for_status(&handle, PodStatus::Idle).await; + + let history = history_from_sink(&handle); + assert!( + !history + .iter() + .any(|item| item_text_contains(item, "rollback me")), + "rolled-back user input must not remain in history: {history:?}" + ); +} + +#[tokio::test] +async fn empty_turn_pause_rolls_back_and_snapshot_does_not_restore_input() { + let client = MockClient::sequential(vec![MockResponse::Hang(vec![])]); + let pod = make_pod(client).await; + let handle = spawn_controller(pod).await; + let mut rx = handle.subscribe(); + + handle + .send(Method::run_text("pause rollback")) + .await + .unwrap(); + wait_for_status(&handle, PodStatus::Running).await; + handle.send(Method::Pause).await.unwrap(); + + assert!( + drain_until(&mut rx, std::time::Duration::from_secs(2), |e| matches!( + e, + Event::RunEnd { + result: protocol::RunResult::RolledBack + } + )) + .await, + "expected RunEnd::RolledBack after empty pause" + ); + wait_for_status(&handle, PodStatus::Idle).await; + + assert!( + !snapshot_contains_user_input(&handle, "pause rollback").await, + "attach snapshot must not resurrect 
rolled-back empty turn input" + ); +} + +#[tokio::test] +async fn empty_turn_rollback_removes_only_the_most_recent_turn() { + let client = MockClient::sequential(vec![ + MockResponse::Complete(simple_text_events()), + MockResponse::Hang(vec![]), + ]); + let pod = make_pod(client).await; + let handle = spawn_controller(pod).await; + let mut rx = handle.subscribe(); + + handle.send(Method::run_text("first kept")).await.unwrap(); + assert!( + drain_until(&mut rx, std::time::Duration::from_secs(2), |e| matches!( + e, + Event::RunEnd { + result: protocol::RunResult::Finished + } + )) + .await, + "expected first run to finish" + ); + wait_for_status(&handle, PodStatus::Idle).await; + + handle + .send(Method::run_text("second rolled back")) + .await + .unwrap(); + wait_for_status(&handle, PodStatus::Running).await; + handle.send(Method::Cancel).await.unwrap(); + assert!( + drain_until(&mut rx, std::time::Duration::from_secs(2), |e| matches!( + e, + Event::RunEnd { + result: protocol::RunResult::RolledBack + } + )) + .await, + "expected empty second run to roll back" + ); + + let history = history_from_sink(&handle); + assert!( + history + .iter() + .any(|item| item_text_contains(item, "first kept")) + ); + assert!( + history + .iter() + .any(|item| item_text_contains(item, "Hello World")) + ); + assert!( + !history + .iter() + .any(|item| item_text_contains(item, "second rolled back")), + "rollback must affect only the most recent empty turn: {history:?}" + ); +} + +#[tokio::test] +async fn pause_after_assistant_token_does_not_rollback() { + let client = MockClient::sequential(vec![MockResponse::Hang(vec![ + LlmEvent::text_block_start(0), + LlmEvent::text_delta(0, "committed before pause"), + LlmEvent::text_block_stop(0, None), + ])]); + let pod = make_pod(client).await; + let handle = spawn_controller(pod).await; + let mut rx = handle.subscribe(); + + handle + .send(Method::run_text("keep this turn")) + .await + .unwrap(); + assert!( + drain_until(&mut rx, 
std::time::Duration::from_secs(2), |e| matches!( + e, + Event::TextDone { .. } + )) + .await, + "assistant token should be visible before pause" + ); + handle.send(Method::Pause).await.unwrap(); + + assert!( + drain_until(&mut rx, std::time::Duration::from_secs(2), |e| matches!( + e, + Event::RunEnd { + result: protocol::RunResult::Paused + } + )) + .await, + "pause after assistant output must keep the existing Paused path" + ); + wait_for_status(&handle, PodStatus::Paused).await; + + let history = history_from_sink(&handle); + assert!( + history + .iter() + .any(|item| item_text_contains(item, "keep this turn")), + "token-visible turn must keep its UserInput entry: {history:?}" + ); +} diff --git a/crates/protocol/src/lib.rs b/crates/protocol/src/lib.rs index 7d46faa4..c383d495 100644 --- a/crates/protocol/src/lib.rs +++ b/crates/protocol/src/lib.rs @@ -571,6 +571,11 @@ pub enum RunResult { Finished, Paused, LimitReached, + /// The accepted Method::Run produced no assistant/tool output before + /// user interruption, so the Pod rolled the submit-time turn state back + /// to its pre-submit snapshot. Clients should treat the Pod as Idle and + /// restore the just-submitted input into the editable composer if desired. + RolledBack, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] diff --git a/crates/session-store/src/store.rs b/crates/session-store/src/store.rs index 7e181794..e2c5b6b8 100644 --- a/crates/session-store/src/store.rs +++ b/crates/session-store/src/store.rs @@ -82,6 +82,33 @@ pub trait Store: Send + Sync { /// Check if a segment exists. fn exists(&self, session_id: SessionId, segment_id: SegmentId) -> Result; + /// Truncate a segment log to `entries_len` entries. + /// + /// Used by Pod's submit-time empty-turn rollback after it has proven + /// that no LLM output from the accepted turn was materialized. 
The + /// default implementation rewrites the retained prefix through + /// `create_segment`, matching the append-only logical model while still + /// allowing concrete stores to provide a more direct truncate. + fn truncate( + &self, + session_id: SessionId, + segment_id: SegmentId, + entries_len: usize, + ) -> Result<(), StoreError> { + let mut entries = self.read_all(session_id, segment_id)?; + if entries_len > entries.len() { + return Err(StoreError::Corrupt { + line: entries_len, + message: format!( + "cannot truncate segment {segment_id} to {entries_len} entries; only {} entries stored", + entries.len() + ), + }); + } + entries.truncate(entries_len); + self.create_segment(session_id, segment_id, &entries) + } + /// Count entries currently stored for a segment. /// /// Used by `ensure_head_or_fork` to detect concurrent writers: