feat: rollback empty interrupted turns
This commit is contained in:
parent
8813d966bb
commit
03e7795130
|
|
@ -1,5 +1,6 @@
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::Ordering;
|
||||||
|
|
||||||
use llm_worker::WorkerError;
|
use llm_worker::WorkerError;
|
||||||
use llm_worker::llm_client::client::LlmClient;
|
use llm_worker::llm_client::client::LlmClient;
|
||||||
|
|
@ -303,6 +304,7 @@ fn wire_event_bridges_on_worker<C, St>(
|
||||||
C: LlmClient + Clone + 'static,
|
C: LlmClient + Clone + 'static,
|
||||||
St: Store + PodMetadataStore + Clone + 'static,
|
St: Store + PodMetadataStore + Clone + 'static,
|
||||||
{
|
{
|
||||||
|
let ai_activity = pod.ai_activity_counter();
|
||||||
let worker = pod.worker_mut();
|
let worker = pod.worker_mut();
|
||||||
|
|
||||||
let tx = event_tx.clone();
|
let tx = event_tx.clone();
|
||||||
|
|
@ -329,15 +331,22 @@ fn wire_event_bridges_on_worker<C, St>(
|
||||||
});
|
});
|
||||||
|
|
||||||
let tx = event_tx.clone();
|
let tx = event_tx.clone();
|
||||||
|
let activity = ai_activity.clone();
|
||||||
worker.on_text_block(move |block| {
|
worker.on_text_block(move |block| {
|
||||||
let tx_d = tx.clone();
|
let tx_d = tx.clone();
|
||||||
|
let activity_d = activity.clone();
|
||||||
block.on_delta(move |text| {
|
block.on_delta(move |text| {
|
||||||
|
activity_d.fetch_add(1, Ordering::SeqCst);
|
||||||
let _ = tx_d.send(Event::TextDelta {
|
let _ = tx_d.send(Event::TextDelta {
|
||||||
text: text.to_owned(),
|
text: text.to_owned(),
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
let tx_s = tx.clone();
|
let tx_s = tx.clone();
|
||||||
|
let activity_s = activity.clone();
|
||||||
block.on_stop(move |text| {
|
block.on_stop(move |text| {
|
||||||
|
if !text.is_empty() {
|
||||||
|
activity_s.fetch_add(1, Ordering::SeqCst);
|
||||||
|
}
|
||||||
let _ = tx_s.send(Event::TextDone {
|
let _ = tx_s.send(Event::TextDone {
|
||||||
text: text.to_owned(),
|
text: text.to_owned(),
|
||||||
});
|
});
|
||||||
|
|
@ -345,18 +354,26 @@ fn wire_event_bridges_on_worker<C, St>(
|
||||||
});
|
});
|
||||||
|
|
||||||
let tx = event_tx.clone();
|
let tx = event_tx.clone();
|
||||||
|
let activity = ai_activity.clone();
|
||||||
worker.on_thinking_block(move |block| {
|
worker.on_thinking_block(move |block| {
|
||||||
// Start fires unconditionally so the TUI can show "Thinking..."
|
// Start fires unconditionally so the TUI can show "Thinking..."
|
||||||
// even when the provider doesn't emit plaintext deltas.
|
// even when the provider doesn't emit plaintext deltas.
|
||||||
|
activity.fetch_add(1, Ordering::SeqCst);
|
||||||
let _ = tx.send(Event::ThinkingStart);
|
let _ = tx.send(Event::ThinkingStart);
|
||||||
let tx_d = tx.clone();
|
let tx_d = tx.clone();
|
||||||
|
let activity_d = activity.clone();
|
||||||
block.on_delta(move |text| {
|
block.on_delta(move |text| {
|
||||||
|
activity_d.fetch_add(1, Ordering::SeqCst);
|
||||||
let _ = tx_d.send(Event::ThinkingDelta {
|
let _ = tx_d.send(Event::ThinkingDelta {
|
||||||
text: text.to_owned(),
|
text: text.to_owned(),
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
let tx_s = tx.clone();
|
let tx_s = tx.clone();
|
||||||
|
let activity_s = activity.clone();
|
||||||
block.on_stop(move |text| {
|
block.on_stop(move |text| {
|
||||||
|
if !text.is_empty() {
|
||||||
|
activity_s.fetch_add(1, Ordering::SeqCst);
|
||||||
|
}
|
||||||
let _ = tx_s.send(Event::ThinkingDone {
|
let _ = tx_s.send(Event::ThinkingDone {
|
||||||
text: text.to_owned(),
|
text: text.to_owned(),
|
||||||
});
|
});
|
||||||
|
|
@ -364,21 +381,27 @@ fn wire_event_bridges_on_worker<C, St>(
|
||||||
});
|
});
|
||||||
|
|
||||||
let tx = event_tx.clone();
|
let tx = event_tx.clone();
|
||||||
|
let activity = ai_activity.clone();
|
||||||
worker.on_tool_use_block(move |start, block| {
|
worker.on_tool_use_block(move |start, block| {
|
||||||
|
activity.fetch_add(1, Ordering::SeqCst);
|
||||||
let _ = tx.send(Event::ToolCallStart {
|
let _ = tx.send(Event::ToolCallStart {
|
||||||
id: start.id.clone(),
|
id: start.id.clone(),
|
||||||
name: start.name.clone(),
|
name: start.name.clone(),
|
||||||
});
|
});
|
||||||
let id_for_delta = start.id.clone();
|
let id_for_delta = start.id.clone();
|
||||||
let tx_d = tx.clone();
|
let tx_d = tx.clone();
|
||||||
|
let activity_d = activity.clone();
|
||||||
block.on_delta(move |json| {
|
block.on_delta(move |json| {
|
||||||
|
activity_d.fetch_add(1, Ordering::SeqCst);
|
||||||
let _ = tx_d.send(Event::ToolCallArgsDelta {
|
let _ = tx_d.send(Event::ToolCallArgsDelta {
|
||||||
id: id_for_delta.clone(),
|
id: id_for_delta.clone(),
|
||||||
json: json.to_owned(),
|
json: json.to_owned(),
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
let tx_s = tx.clone();
|
let tx_s = tx.clone();
|
||||||
|
let activity_s = activity.clone();
|
||||||
block.on_stop(move |call| {
|
block.on_stop(move |call| {
|
||||||
|
activity_s.fetch_add(1, Ordering::SeqCst);
|
||||||
let _ = tx_s.send(Event::ToolCallDone {
|
let _ = tx_s.send(Event::ToolCallDone {
|
||||||
id: call.id.clone(),
|
id: call.id.clone(),
|
||||||
name: call.name.clone(),
|
name: call.name.clone(),
|
||||||
|
|
@ -388,7 +411,9 @@ fn wire_event_bridges_on_worker<C, St>(
|
||||||
});
|
});
|
||||||
|
|
||||||
let tx = event_tx.clone();
|
let tx = event_tx.clone();
|
||||||
|
let activity = ai_activity.clone();
|
||||||
worker.on_tool_result(move |result| {
|
worker.on_tool_result(move |result| {
|
||||||
|
activity.fetch_add(1, Ordering::SeqCst);
|
||||||
let _ = tx.send(Event::ToolResult {
|
let _ = tx.send(Event::ToolResult {
|
||||||
id: result.tool_use_id.clone(),
|
id: result.tool_use_id.clone(),
|
||||||
summary: result.summary.clone(),
|
summary: result.summary.clone(),
|
||||||
|
|
@ -879,6 +904,7 @@ where
|
||||||
PodRunResult::Finished => (PodStatus::Idle, RunResult::Finished),
|
PodRunResult::Finished => (PodStatus::Idle, RunResult::Finished),
|
||||||
PodRunResult::Paused => (PodStatus::Paused, RunResult::Paused),
|
PodRunResult::Paused => (PodStatus::Paused, RunResult::Paused),
|
||||||
PodRunResult::LimitReached => (PodStatus::Idle, RunResult::LimitReached),
|
PodRunResult::LimitReached => (PodStatus::Idle, RunResult::LimitReached),
|
||||||
|
PodRunResult::RolledBack => (PodStatus::Idle, RunResult::RolledBack),
|
||||||
};
|
};
|
||||||
let _ = event_tx.send(Event::RunEnd { result: run_result });
|
let _ = event_tx.send(Event::RunEnd { result: run_result });
|
||||||
if parent_originated && matches!(run_result, RunResult::Finished) {
|
if parent_originated && matches!(run_result, RunResult::Finished) {
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ use arc_swap::ArcSwap;
|
||||||
use llm_worker::Item;
|
use llm_worker::Item;
|
||||||
use llm_worker::llm_client::RequestConfig;
|
use llm_worker::llm_client::RequestConfig;
|
||||||
use llm_worker::llm_client::client::LlmClient;
|
use llm_worker::llm_client::client::LlmClient;
|
||||||
|
use llm_worker::llm_client::types::Role;
|
||||||
use llm_worker::state::Mutable;
|
use llm_worker::state::Mutable;
|
||||||
use llm_worker::{ToolOutputLimits, UsageRecord, Worker, WorkerError, WorkerResult};
|
use llm_worker::{ToolOutputLimits, UsageRecord, Worker, WorkerError, WorkerResult};
|
||||||
use session_store::{
|
use session_store::{
|
||||||
|
|
@ -121,6 +122,24 @@ impl SegmentState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct EmptyTurnRollbackSnapshot {
|
||||||
|
history_len: usize,
|
||||||
|
user_segments_len: usize,
|
||||||
|
entries_written: usize,
|
||||||
|
sink_len: usize,
|
||||||
|
pending_attachments: Vec<SystemItem>,
|
||||||
|
usage_history_len: usize,
|
||||||
|
ai_activity_count: usize,
|
||||||
|
last_run_interrupted: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_ai_materialized_item(item: &Item) -> bool {
|
||||||
|
match item {
|
||||||
|
Item::Message { role, .. } => *role == Role::Assistant,
|
||||||
|
Item::ToolCall { .. } | Item::ToolResult { .. } | Item::Reasoning { .. } => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Cheap-cloneable bundle of (store + shared session pointer + sink)
|
/// Cheap-cloneable bundle of (store + shared session pointer + sink)
|
||||||
/// handed to the worker callback and the interceptor so they can
|
/// handed to the worker callback and the interceptor so they can
|
||||||
/// commit `LogEntry` values directly without going through an mpsc
|
/// commit `LogEntry` values directly without going through an mpsc
|
||||||
|
|
@ -254,6 +273,12 @@ pub struct Pod<C: LlmClient, St: Store> {
|
||||||
/// notifications, events sent here are NOT replayed to clients that
|
/// notifications, events sent here are NOT replayed to clients that
|
||||||
/// connect after the fact — they are fire-and-forget broadcasts.
|
/// connect after the fact — they are fire-and-forget broadcasts.
|
||||||
event_tx: Option<broadcast::Sender<Event>>,
|
event_tx: Option<broadcast::Sender<Event>>,
|
||||||
|
/// Monotonic counter incremented by worker event bridges when an
|
||||||
|
/// assistant-side execution artifact becomes visible to clients before
|
||||||
|
/// it is necessarily committed to history (e.g. streaming text deltas).
|
||||||
|
/// `Pod::run` uses it to avoid rolling back a turn after the UI has
|
||||||
|
/// already observed AI output.
|
||||||
|
ai_activity_counter: Arc<AtomicUsize>,
|
||||||
/// Queue of pending `Method::Notify` notifications awaiting
|
/// Queue of pending `Method::Notify` notifications awaiting
|
||||||
/// injection into the next LLM request. Shared with the
|
/// injection into the next LLM request. Shared with the
|
||||||
/// PodInterceptor installed in `ensure_interceptor_installed`.
|
/// PodInterceptor installed in `ensure_interceptor_installed`.
|
||||||
|
|
@ -392,6 +417,7 @@ impl<C: LlmClient + Clone + 'static, St: Store + Clone + 'static> Pod<C, St> {
|
||||||
system_prompt_template: None,
|
system_prompt_template: None,
|
||||||
alerter: self.alerter.clone(),
|
alerter: self.alerter.clone(),
|
||||||
event_tx: self.event_tx.clone(),
|
event_tx: self.event_tx.clone(),
|
||||||
|
ai_activity_counter: self.ai_activity_counter.clone(),
|
||||||
pending_notifies: NotifyBuffer::new(),
|
pending_notifies: NotifyBuffer::new(),
|
||||||
pending_attachments: Arc::new(Mutex::new(Vec::<SystemItem>::new())),
|
pending_attachments: Arc::new(Mutex::new(Vec::<SystemItem>::new())),
|
||||||
scope_allocation: None,
|
scope_allocation: None,
|
||||||
|
|
@ -534,6 +560,7 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
||||||
system_prompt_template: None,
|
system_prompt_template: None,
|
||||||
alerter: None,
|
alerter: None,
|
||||||
event_tx: None,
|
event_tx: None,
|
||||||
|
ai_activity_counter: Arc::new(AtomicUsize::new(0)),
|
||||||
pending_notifies: NotifyBuffer::new(),
|
pending_notifies: NotifyBuffer::new(),
|
||||||
pending_attachments: Arc::new(Mutex::new(Vec::<SystemItem>::new())),
|
pending_attachments: Arc::new(Mutex::new(Vec::<SystemItem>::new())),
|
||||||
scope_allocation: None,
|
scope_allocation: None,
|
||||||
|
|
@ -901,6 +928,12 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
||||||
self.event_tx = Some(event_tx);
|
self.event_tx = Some(event_tx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Shared activity counter incremented by worker event bridges when any
|
||||||
|
/// assistant-side output is surfaced before history persistence.
|
||||||
|
pub fn ai_activity_counter(&self) -> Arc<AtomicUsize> {
|
||||||
|
self.ai_activity_counter.clone()
|
||||||
|
}
|
||||||
|
|
||||||
fn alert(&self, level: AlertLevel, source: AlertSource, message: String) {
|
fn alert(&self, level: AlertLevel, source: AlertSource, message: String) {
|
||||||
if let Some(n) = self.alerter.as_ref() {
|
if let Some(n) = self.alerter.as_ref() {
|
||||||
n.alert(level, source, message);
|
n.alert(level, source, message);
|
||||||
|
|
@ -1236,6 +1269,75 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn capture_empty_turn_rollback_snapshot(&self) -> EmptyTurnRollbackSnapshot {
|
||||||
|
let pending_attachments = self
|
||||||
|
.pending_attachments
|
||||||
|
.lock()
|
||||||
|
.expect("pending_attachments poisoned")
|
||||||
|
.clone();
|
||||||
|
let usage_history_len = self
|
||||||
|
.usage_history
|
||||||
|
.lock()
|
||||||
|
.expect("usage_history poisoned")
|
||||||
|
.len();
|
||||||
|
EmptyTurnRollbackSnapshot {
|
||||||
|
history_len: self.worker().history().len(),
|
||||||
|
user_segments_len: self.user_segments.len(),
|
||||||
|
entries_written: self.segment_state.entries_written(),
|
||||||
|
sink_len: self.sink.len(),
|
||||||
|
pending_attachments,
|
||||||
|
usage_history_len,
|
||||||
|
ai_activity_count: self.ai_activity_counter.load(Ordering::SeqCst),
|
||||||
|
last_run_interrupted: self.worker().last_run_interrupted(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn should_rollback_empty_turn(
|
||||||
|
&self,
|
||||||
|
result: &Result<WorkerResult, WorkerError>,
|
||||||
|
snapshot: &EmptyTurnRollbackSnapshot,
|
||||||
|
) -> bool {
|
||||||
|
if !matches!(result, Err(WorkerError::Cancelled)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if self.ai_activity_counter.load(Ordering::SeqCst) != snapshot.ai_activity_count {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
!self.worker().history()[snapshot.history_len..]
|
||||||
|
.iter()
|
||||||
|
.any(is_ai_materialized_item)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rollback_empty_turn(
|
||||||
|
&mut self,
|
||||||
|
snapshot: EmptyTurnRollbackSnapshot,
|
||||||
|
) -> Result<(), StoreError> {
|
||||||
|
self.worker_mut()
|
||||||
|
.history_mut()
|
||||||
|
.truncate(snapshot.history_len);
|
||||||
|
self.worker_mut()
|
||||||
|
.set_last_run_interrupted(snapshot.last_run_interrupted);
|
||||||
|
self.user_segments.truncate(snapshot.user_segments_len);
|
||||||
|
*self
|
||||||
|
.pending_attachments
|
||||||
|
.lock()
|
||||||
|
.expect("pending_attachments poisoned") = snapshot.pending_attachments;
|
||||||
|
self.usage_history
|
||||||
|
.lock()
|
||||||
|
.expect("usage_history poisoned")
|
||||||
|
.truncate(snapshot.usage_history_len);
|
||||||
|
let _ = self.usage_tracker.drain();
|
||||||
|
let _ = self.metrics_tracker.drain();
|
||||||
|
|
||||||
|
let loc = self.segment_state.location();
|
||||||
|
self.store
|
||||||
|
.truncate(loc.session_id, loc.segment_id, snapshot.entries_written)?;
|
||||||
|
self.segment_state
|
||||||
|
.set_entries_written(snapshot.entries_written);
|
||||||
|
self.sink.truncate_silent(snapshot.sink_len);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Send user input and run until the LLM turn completes.
|
/// Send user input and run until the LLM turn completes.
|
||||||
///
|
///
|
||||||
/// `input` is a typed segment list (see [`protocol::Segment`]). The
|
/// `input` is a typed segment list (see [`protocol::Segment`]). The
|
||||||
|
|
@ -1270,6 +1372,8 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
||||||
|
|
||||||
self.prepare_for_run().await?;
|
self.prepare_for_run().await?;
|
||||||
|
|
||||||
|
let rollback_snapshot = self.capture_empty_turn_rollback_snapshot();
|
||||||
|
|
||||||
// IDLE → active marker. Commits first so the next UserInput entry
|
// IDLE → active marker. Commits first so the next UserInput entry
|
||||||
// is contained inside this Invoke range. See `tickets/invoke-turn-llmcall-semantics.md`.
|
// is contained inside this Invoke range. See `tickets/invoke-turn-llmcall-semantics.md`.
|
||||||
self.commit_entry(LogEntry::Invoke {
|
self.commit_entry(LogEntry::Invoke {
|
||||||
|
|
@ -1311,6 +1415,11 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
||||||
let result = locked.run(flattened).await;
|
let result = locked.run(flattened).await;
|
||||||
self.worker = Some(locked.unlock());
|
self.worker = Some(locked.unlock());
|
||||||
|
|
||||||
|
if self.should_rollback_empty_turn(&result, &rollback_snapshot) {
|
||||||
|
self.rollback_empty_turn(rollback_snapshot)?;
|
||||||
|
return Ok(PodRunResult::RolledBack);
|
||||||
|
}
|
||||||
|
|
||||||
self.handle_worker_result(result, history_before).await
|
self.handle_worker_result(result, history_before).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -2847,6 +2956,7 @@ where
|
||||||
system_prompt_template: common.system_prompt_template,
|
system_prompt_template: common.system_prompt_template,
|
||||||
alerter: None,
|
alerter: None,
|
||||||
event_tx: None,
|
event_tx: None,
|
||||||
|
ai_activity_counter: Arc::new(AtomicUsize::new(0)),
|
||||||
pending_notifies: NotifyBuffer::new(),
|
pending_notifies: NotifyBuffer::new(),
|
||||||
pending_attachments: Arc::new(Mutex::new(Vec::<SystemItem>::new())),
|
pending_attachments: Arc::new(Mutex::new(Vec::<SystemItem>::new())),
|
||||||
scope_allocation: Some(scope_allocation),
|
scope_allocation: Some(scope_allocation),
|
||||||
|
|
@ -2923,6 +3033,7 @@ where
|
||||||
system_prompt_template: common.system_prompt_template,
|
system_prompt_template: common.system_prompt_template,
|
||||||
alerter: None,
|
alerter: None,
|
||||||
event_tx: None,
|
event_tx: None,
|
||||||
|
ai_activity_counter: Arc::new(AtomicUsize::new(0)),
|
||||||
pending_notifies: NotifyBuffer::new(),
|
pending_notifies: NotifyBuffer::new(),
|
||||||
pending_attachments: Arc::new(Mutex::new(Vec::<SystemItem>::new())),
|
pending_attachments: Arc::new(Mutex::new(Vec::<SystemItem>::new())),
|
||||||
scope_allocation: Some(scope_allocation),
|
scope_allocation: Some(scope_allocation),
|
||||||
|
|
@ -3098,6 +3209,7 @@ where
|
||||||
system_prompt_template: None,
|
system_prompt_template: None,
|
||||||
alerter: None,
|
alerter: None,
|
||||||
event_tx: None,
|
event_tx: None,
|
||||||
|
ai_activity_counter: Arc::new(AtomicUsize::new(0)),
|
||||||
pending_notifies: NotifyBuffer::new(),
|
pending_notifies: NotifyBuffer::new(),
|
||||||
pending_attachments: Arc::new(Mutex::new(Vec::<SystemItem>::new())),
|
pending_attachments: Arc::new(Mutex::new(Vec::<SystemItem>::new())),
|
||||||
scope_allocation: Some(scope_allocation),
|
scope_allocation: Some(scope_allocation),
|
||||||
|
|
@ -3184,6 +3296,8 @@ pub enum PodRunResult {
|
||||||
Paused,
|
Paused,
|
||||||
/// The worker reached its configured max_turns limit.
|
/// The worker reached its configured max_turns limit.
|
||||||
LimitReached,
|
LimitReached,
|
||||||
|
/// The submit-time user turn was rolled back because no AI output was materialized.
|
||||||
|
RolledBack,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<WorkerResult> for PodRunResult {
|
impl From<WorkerResult> for PodRunResult {
|
||||||
|
|
|
||||||
|
|
@ -1470,3 +1470,204 @@ async fn paused_then_run_closes_orphan_tool_use_for_next_request() {
|
||||||
"system note must precede new user message"
|
"system note must precede new user message"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn item_text_contains(item: &Item, needle: &str) -> bool {
|
||||||
|
item.as_text().unwrap_or_default().contains(needle)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn snapshot_contains_user_input(handle: &PodHandle, needle: &str) -> bool {
|
||||||
|
let stream = tokio::net::UnixStream::connect(handle.runtime_dir.socket_path())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let (reader, _writer) = stream.into_split();
|
||||||
|
let mut reader = protocol::stream::JsonLineReader::new(reader);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let event = reader.next::<Event>().await.unwrap().unwrap();
|
||||||
|
match event {
|
||||||
|
Event::Snapshot { entries, .. } => {
|
||||||
|
return entries.into_iter().any(|value| {
|
||||||
|
let entry: session_store::LogEntry =
|
||||||
|
serde_json::from_value(value).expect("LogEntry deserialise");
|
||||||
|
match entry {
|
||||||
|
session_store::LogEntry::UserInput { segments, .. } => {
|
||||||
|
protocol::Segment::flatten_to_text(&segments).contains(needle)
|
||||||
|
}
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Event::Alert(_) => continue,
|
||||||
|
other => panic!("expected Snapshot first, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn empty_turn_cancel_rolls_back_submit_entries_and_emits_signal() {
|
||||||
|
let client = MockClient::sequential(vec![MockResponse::Hang(vec![])]);
|
||||||
|
let pod = make_pod(client).await;
|
||||||
|
let handle = spawn_controller(pod).await;
|
||||||
|
let mut rx = handle.subscribe();
|
||||||
|
|
||||||
|
handle.send(Method::run_text("rollback me")).await.unwrap();
|
||||||
|
wait_for_status(&handle, PodStatus::Running).await;
|
||||||
|
handle.send(Method::Cancel).await.unwrap();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
drain_until(&mut rx, std::time::Duration::from_secs(2), |e| matches!(
|
||||||
|
e,
|
||||||
|
Event::RunEnd {
|
||||||
|
result: protocol::RunResult::RolledBack
|
||||||
|
}
|
||||||
|
))
|
||||||
|
.await,
|
||||||
|
"expected RunEnd::RolledBack after empty cancel"
|
||||||
|
);
|
||||||
|
wait_for_status(&handle, PodStatus::Idle).await;
|
||||||
|
|
||||||
|
let history = history_from_sink(&handle);
|
||||||
|
assert!(
|
||||||
|
!history
|
||||||
|
.iter()
|
||||||
|
.any(|item| item_text_contains(item, "rollback me")),
|
||||||
|
"rolled-back user input must not remain in history: {history:?}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn empty_turn_pause_rolls_back_and_snapshot_does_not_restore_input() {
|
||||||
|
let client = MockClient::sequential(vec![MockResponse::Hang(vec![])]);
|
||||||
|
let pod = make_pod(client).await;
|
||||||
|
let handle = spawn_controller(pod).await;
|
||||||
|
let mut rx = handle.subscribe();
|
||||||
|
|
||||||
|
handle
|
||||||
|
.send(Method::run_text("pause rollback"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
wait_for_status(&handle, PodStatus::Running).await;
|
||||||
|
handle.send(Method::Pause).await.unwrap();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
drain_until(&mut rx, std::time::Duration::from_secs(2), |e| matches!(
|
||||||
|
e,
|
||||||
|
Event::RunEnd {
|
||||||
|
result: protocol::RunResult::RolledBack
|
||||||
|
}
|
||||||
|
))
|
||||||
|
.await,
|
||||||
|
"expected RunEnd::RolledBack after empty pause"
|
||||||
|
);
|
||||||
|
wait_for_status(&handle, PodStatus::Idle).await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!snapshot_contains_user_input(&handle, "pause rollback").await,
|
||||||
|
"attach snapshot must not resurrect rolled-back empty turn input"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn empty_turn_rollback_removes_only_the_most_recent_turn() {
|
||||||
|
let client = MockClient::sequential(vec![
|
||||||
|
MockResponse::Complete(simple_text_events()),
|
||||||
|
MockResponse::Hang(vec![]),
|
||||||
|
]);
|
||||||
|
let pod = make_pod(client).await;
|
||||||
|
let handle = spawn_controller(pod).await;
|
||||||
|
let mut rx = handle.subscribe();
|
||||||
|
|
||||||
|
handle.send(Method::run_text("first kept")).await.unwrap();
|
||||||
|
assert!(
|
||||||
|
drain_until(&mut rx, std::time::Duration::from_secs(2), |e| matches!(
|
||||||
|
e,
|
||||||
|
Event::RunEnd {
|
||||||
|
result: protocol::RunResult::Finished
|
||||||
|
}
|
||||||
|
))
|
||||||
|
.await,
|
||||||
|
"expected first run to finish"
|
||||||
|
);
|
||||||
|
wait_for_status(&handle, PodStatus::Idle).await;
|
||||||
|
|
||||||
|
handle
|
||||||
|
.send(Method::run_text("second rolled back"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
wait_for_status(&handle, PodStatus::Running).await;
|
||||||
|
handle.send(Method::Cancel).await.unwrap();
|
||||||
|
assert!(
|
||||||
|
drain_until(&mut rx, std::time::Duration::from_secs(2), |e| matches!(
|
||||||
|
e,
|
||||||
|
Event::RunEnd {
|
||||||
|
result: protocol::RunResult::RolledBack
|
||||||
|
}
|
||||||
|
))
|
||||||
|
.await,
|
||||||
|
"expected empty second run to roll back"
|
||||||
|
);
|
||||||
|
|
||||||
|
let history = history_from_sink(&handle);
|
||||||
|
assert!(
|
||||||
|
history
|
||||||
|
.iter()
|
||||||
|
.any(|item| item_text_contains(item, "first kept"))
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
history
|
||||||
|
.iter()
|
||||||
|
.any(|item| item_text_contains(item, "Hello World"))
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
!history
|
||||||
|
.iter()
|
||||||
|
.any(|item| item_text_contains(item, "second rolled back")),
|
||||||
|
"rollback must affect only the most recent empty turn: {history:?}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn pause_after_assistant_token_does_not_rollback() {
|
||||||
|
let client = MockClient::sequential(vec![MockResponse::Hang(vec![
|
||||||
|
LlmEvent::text_block_start(0),
|
||||||
|
LlmEvent::text_delta(0, "committed before pause"),
|
||||||
|
LlmEvent::text_block_stop(0, None),
|
||||||
|
])]);
|
||||||
|
let pod = make_pod(client).await;
|
||||||
|
let handle = spawn_controller(pod).await;
|
||||||
|
let mut rx = handle.subscribe();
|
||||||
|
|
||||||
|
handle
|
||||||
|
.send(Method::run_text("keep this turn"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(
|
||||||
|
drain_until(&mut rx, std::time::Duration::from_secs(2), |e| matches!(
|
||||||
|
e,
|
||||||
|
Event::TextDone { .. }
|
||||||
|
))
|
||||||
|
.await,
|
||||||
|
"assistant token should be visible before pause"
|
||||||
|
);
|
||||||
|
handle.send(Method::Pause).await.unwrap();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
drain_until(&mut rx, std::time::Duration::from_secs(2), |e| matches!(
|
||||||
|
e,
|
||||||
|
Event::RunEnd {
|
||||||
|
result: protocol::RunResult::Paused
|
||||||
|
}
|
||||||
|
))
|
||||||
|
.await,
|
||||||
|
"pause after assistant output must keep the existing Paused path"
|
||||||
|
);
|
||||||
|
wait_for_status(&handle, PodStatus::Paused).await;
|
||||||
|
|
||||||
|
let history = history_from_sink(&handle);
|
||||||
|
assert!(
|
||||||
|
history
|
||||||
|
.iter()
|
||||||
|
.any(|item| item_text_contains(item, "keep this turn")),
|
||||||
|
"token-visible turn must keep its UserInput entry: {history:?}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -571,6 +571,11 @@ pub enum RunResult {
|
||||||
Finished,
|
Finished,
|
||||||
Paused,
|
Paused,
|
||||||
LimitReached,
|
LimitReached,
|
||||||
|
/// The accepted Method::Run produced no assistant/tool output before
|
||||||
|
/// user interruption, so the Pod rolled the submit-time turn state back
|
||||||
|
/// to its pre-submit snapshot. Clients should treat the Pod as Idle and
|
||||||
|
/// restore the just-submitted input into the editable composer if desired.
|
||||||
|
RolledBack,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
|
|
||||||
|
|
@ -82,6 +82,33 @@ pub trait Store: Send + Sync {
|
||||||
/// Check if a segment exists.
|
/// Check if a segment exists.
|
||||||
fn exists(&self, session_id: SessionId, segment_id: SegmentId) -> Result<bool, StoreError>;
|
fn exists(&self, session_id: SessionId, segment_id: SegmentId) -> Result<bool, StoreError>;
|
||||||
|
|
||||||
|
/// Truncate a segment log to `entries_len` entries.
|
||||||
|
///
|
||||||
|
/// Used by Pod's submit-time empty-turn rollback after it has proven
|
||||||
|
/// that no LLM output from the accepted turn was materialized. The
|
||||||
|
/// default implementation rewrites the retained prefix through
|
||||||
|
/// `create_segment`, matching the append-only logical model while still
|
||||||
|
/// allowing concrete stores to provide a more direct truncate.
|
||||||
|
fn truncate(
|
||||||
|
&self,
|
||||||
|
session_id: SessionId,
|
||||||
|
segment_id: SegmentId,
|
||||||
|
entries_len: usize,
|
||||||
|
) -> Result<(), StoreError> {
|
||||||
|
let mut entries = self.read_all(session_id, segment_id)?;
|
||||||
|
if entries_len > entries.len() {
|
||||||
|
return Err(StoreError::Corrupt {
|
||||||
|
line: entries_len,
|
||||||
|
message: format!(
|
||||||
|
"cannot truncate segment {segment_id} to {entries_len} entries; only {} entries stored",
|
||||||
|
entries.len()
|
||||||
|
),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
entries.truncate(entries_len);
|
||||||
|
self.create_segment(session_id, segment_id, &entries)
|
||||||
|
}
|
||||||
|
|
||||||
/// Count entries currently stored for a segment.
|
/// Count entries currently stored for a segment.
|
||||||
///
|
///
|
||||||
/// Used by `ensure_head_or_fork` to detect concurrent writers:
|
/// Used by `ensure_head_or_fork` to detect concurrent writers:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user