use std::sync::{Arc, Mutex, MutexGuard}; use protocol::{Event, InFlightBlock, InFlightSnapshot, InFlightToolCallState}; use session_store::{LoggedContentPart, LoggedItem}; use tokio::sync::broadcast; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct InFlightBlockId(u64); #[derive(Debug, Clone)] pub struct InFlightEvents { inner: Arc>, event_tx: broadcast::Sender, } #[derive(Debug)] pub(crate) struct InFlightInner { next_block_id: u64, blocks: Vec, } #[derive(Debug, Clone)] enum TrackedBlock { Text { block_id: InFlightBlockId, text: String, finished: bool, }, Thinking { block_id: InFlightBlockId, text: String, finished: bool, }, ToolCall { block_id: InFlightBlockId, id: String, name: String, args: String, state: InFlightToolCallState, }, } impl InFlightEvents { pub(crate) fn new(event_tx: broadcast::Sender) -> Self { Self { inner: Arc::new(Mutex::new(InFlightInner { next_block_id: 1, blocks: Vec::new(), })), event_tx, } } pub(crate) fn snapshot_guard(&self) -> MutexGuard<'_, InFlightInner> { self.inner.lock().expect("in-flight event mutex poisoned") } pub(crate) fn start_text_block(&self) -> InFlightBlockId { let mut inner = self.lock(); let block_id = inner.next_id(); inner.blocks.push(TrackedBlock::Text { block_id, text: String::new(), finished: false, }); block_id } pub(crate) fn text_delta(&self, block_id: InFlightBlockId, text: String) { let mut inner = self.lock(); if let Some(TrackedBlock::Text { text: current, finished, .. }) = inner.find_block_mut(block_id) { current.push_str(&text); *finished = false; } let _ = self.event_tx.send(Event::TextDelta { text }); } pub(crate) fn text_done(&self, block_id: InFlightBlockId, text: String) { let mut inner = self.lock(); if let Some(TrackedBlock::Text { text: current, finished, .. }) = inner.find_block_mut(block_id) { if current.is_empty() { *current = text.clone(); } *finished = true; } let _ = self.event_tx.send(Event::TextDone { text }); } pub(crate) fn thinking_start(&self) -> InFlightBlockId { let mut inner = self.lock(); let block_id = inner.next_id(); inner.blocks.push(TrackedBlock::Thinking { block_id, text: String::new(), finished: false, }); let _ = self.event_tx.send(Event::ThinkingStart); block_id } pub(crate) fn thinking_delta(&self, block_id: InFlightBlockId, text: String) { let mut inner = self.lock(); if let Some(TrackedBlock::Thinking { text: current, finished, .. }) = inner.find_block_mut(block_id) { current.push_str(&text); *finished = false; } let _ = self.event_tx.send(Event::ThinkingDelta { text }); } pub(crate) fn thinking_done(&self, block_id: InFlightBlockId, text: String) { let mut inner = self.lock(); if let Some(TrackedBlock::Thinking { text: current, finished, .. }) = inner.find_block_mut(block_id) { if current.is_empty() { *current = text.clone(); } *finished = true; } let _ = self.event_tx.send(Event::ThinkingDone { text }); } pub(crate) fn tool_call_start(&self, id: String, name: String) -> InFlightBlockId { let mut inner = self.lock(); let block_id = inner.next_id(); inner.blocks.push(TrackedBlock::ToolCall { block_id, id: id.clone(), name: name.clone(), args: String::new(), state: InFlightToolCallState::Pending, }); let _ = self.event_tx.send(Event::ToolCallStart { id, name }); block_id } pub(crate) fn tool_call_args_delta( &self, block_id: InFlightBlockId, id: String, delta: String, ) { let mut inner = self.lock(); if let Some(TrackedBlock::ToolCall { args, state, .. }) = inner.find_block_mut(block_id) { args.push_str(&delta); *state = InFlightToolCallState::StreamingArgs; } let _ = self .event_tx .send(Event::ToolCallArgsDelta { id, json: delta }); } pub(crate) fn tool_call_done(&self, block_id: InFlightBlockId, id: String, args: String) { let mut inner = self.lock(); let mut name = String::new(); if let Some(TrackedBlock::ToolCall { name: current_name, args: current, state, .. }) = inner.find_block_mut(block_id) { name = current_name.clone(); if current.is_empty() { *current = args.clone(); } *state = InFlightToolCallState::Done; } let _ = self.event_tx.send(Event::ToolCallDone { id, name, arguments: args, }); } pub(crate) fn clear_for_committed_item_then( &self, item: &LoggedItem, f: impl FnOnce() -> R, ) -> R { let mut inner = self.lock(); inner.clear_for_committed_item(item); f() } fn lock(&self) -> MutexGuard<'_, InFlightInner> { self.inner.lock().expect("in-flight event mutex poisoned") } } impl InFlightInner { fn next_id(&mut self) -> InFlightBlockId { let id = InFlightBlockId(self.next_block_id); self.next_block_id = self.next_block_id.saturating_add(1); id } fn find_block_mut(&mut self, block_id: InFlightBlockId) -> Option<&mut TrackedBlock> { self.blocks .iter_mut() .find(|block| block.block_id() == block_id) } fn clear_for_committed_item(&mut self, item: &LoggedItem) { match item { LoggedItem::Message { role, content } if matches!(role, session_store::LoggedRole::Assistant) => { let text = content .iter() .filter_map(|part| match part { LoggedContentPart::Text { text } => Some(text.as_str()), LoggedContentPart::Refusal { refusal } => Some(refusal.as_str()), }) .collect::(); if !text.is_empty() { self.remove_first_text_matching(&text); } } LoggedItem::Reasoning { text, .. } => { self.remove_first_thinking_matching(text); } LoggedItem::ToolCall { call_id, .. } => { self.remove_tool_call(call_id); } _ => {} } } fn snapshot(&self) -> InFlightSnapshot { InFlightSnapshot { blocks: self .blocks .iter() .filter_map(TrackedBlock::to_snapshot_block) .collect(), } } fn remove_first_text_matching(&mut self, committed: &str) { if let Some(index) = self.blocks.iter().position(|block| match block { TrackedBlock::Text { text, .. } => text == committed, _ => false, }) { self.blocks.remove(index); } } fn remove_first_thinking_matching(&mut self, committed: &str) { if let Some(index) = self.blocks.iter().position(|block| match block { TrackedBlock::Thinking { text, .. } => text == committed, _ => false, }) { self.blocks.remove(index); } } fn remove_tool_call(&mut self, call_id: &str) { if let Some(index) = self.blocks.iter().position(|block| match block { TrackedBlock::ToolCall { id, .. } => id == call_id, _ => false, }) { self.blocks.remove(index); } } } impl TrackedBlock { fn block_id(&self) -> InFlightBlockId { match self { TrackedBlock::Text { block_id, .. } | TrackedBlock::Thinking { block_id, .. } | TrackedBlock::ToolCall { block_id, .. } => *block_id, } } fn to_snapshot_block(&self) -> Option { match self { TrackedBlock::Text { text, finished, .. } => { if text.is_empty() { None } else { Some(InFlightBlock::Text { text: text.clone(), finished: *finished, }) } } TrackedBlock::Thinking { text, finished, .. } => Some(InFlightBlock::Thinking { text: text.clone(), finished: *finished, }), TrackedBlock::ToolCall { id, name, args, state, .. } => Some(InFlightBlock::ToolCall { id: id.clone(), name: name.clone(), args: args.clone(), state: *state, }), } } } pub(crate) fn snapshot_from_guard(guard: &MutexGuard<'_, InFlightInner>) -> InFlightSnapshot { guard.snapshot() } #[cfg(test)] mod tests { use super::*; #[test] fn snapshot_boundary_does_not_duplicate_or_gap_delta_sent_after_subscribe() { let (event_tx, _) = broadcast::channel(16); let in_flight = InFlightEvents::new(event_tx.clone()); let block_id = in_flight.start_text_block(); in_flight.text_delta(block_id, "hel".into()); let guard = in_flight.snapshot_guard(); let mut rx = event_tx.subscribe(); let snapshot = snapshot_from_guard(&guard); drop(guard); in_flight.text_delta(block_id, "lo".into()); assert_eq!( snapshot.blocks, vec![InFlightBlock::Text { text: "hel".into(), finished: false, }] ); assert!(matches!( rx.try_recv().unwrap(), Event::TextDelta { text } if text == "lo" )); assert!(rx.try_recv().is_err()); } #[test] fn session_log_and_in_flight_snapshot_prevents_mirror_only_assistant_gap() { use std::sync::mpsc; use std::thread; use crate::segment_log_sink::SegmentLogSink; use session_store::{LogEntry, LoggedRole}; let (event_tx, _) = broadcast::channel(16); let sink = SegmentLogSink::new(); let in_flight = InFlightEvents::new(event_tx); let block_id = in_flight.start_text_block(); in_flight.text_delta(block_id, "done".into()); in_flight.text_done(block_id, "done".into()); let assistant_item = LoggedItem::Message { role: LoggedRole::Assistant, content: vec![LoggedContentPart::Text { text: "done".into(), }], }; let assistant_entry = LogEntry::AssistantItem { ts: 1, item: assistant_item.clone(), }; let in_flight_guard = in_flight.snapshot_guard(); let in_flight_for_commit = in_flight.clone(); let sink_for_commit = sink.clone(); let (committed_tx, committed_rx) = mpsc::channel(); let commit_thread = thread::spawn(move || { // This mirrors Pod::append_entry ordering: clear in-flight first, // then publish the finalized AssistantItem. AssistantItem entries // are mirror-only and are not delivered as live entry events. in_flight_for_commit.clear_for_committed_item_then(&assistant_item, || { sink_for_commit.publish(assistant_entry); }); committed_tx.send(()).unwrap(); }); let (entries_snapshot, mut entry_rx) = sink.subscribe_with_snapshot(); let in_flight_snapshot = snapshot_from_guard(&in_flight_guard); drop(in_flight_guard); committed_rx.recv().unwrap(); commit_thread.join().unwrap(); assert!(entries_snapshot.is_empty()); assert!(matches!( in_flight_snapshot.blocks.as_slice(), [InFlightBlock::Text { text, finished: true }] if text == "done" )); assert!(entry_rx.try_recv().is_err()); let post_commit_guard = in_flight.snapshot_guard(); assert!(snapshot_from_guard(&post_commit_guard).is_empty()); } #[test] fn committed_assistant_snapshot_does_not_duplicate_in_flight_block() { use crate::segment_log_sink::SegmentLogSink; use session_store::{LogEntry, LoggedRole}; let (event_tx, _) = broadcast::channel(16); let sink = SegmentLogSink::new(); let in_flight = InFlightEvents::new(event_tx); let block_id = in_flight.start_text_block(); in_flight.text_delta(block_id, "done".into()); in_flight.text_done(block_id, "done".into()); let assistant_item = LoggedItem::Message { role: LoggedRole::Assistant, content: vec![LoggedContentPart::Text { text: "done".into(), }], }; let assistant_entry = LogEntry::AssistantItem { ts: 1, item: assistant_item.clone(), }; in_flight.clear_for_committed_item_then(&assistant_item, || { sink.publish(assistant_entry); }); let in_flight_guard = in_flight.snapshot_guard(); let (entries_snapshot, _entry_rx) = sink.subscribe_with_snapshot(); let in_flight_snapshot = snapshot_from_guard(&in_flight_guard); assert!(matches!( entries_snapshot.as_slice(), [LogEntry::AssistantItem { item, .. }] if item == &assistant_item )); assert!(in_flight_snapshot.is_empty()); } #[test] fn committed_item_clears_matching_in_flight_block() { let (event_tx, _) = broadcast::channel(16); let in_flight = InFlightEvents::new(event_tx); let block_id = in_flight.start_text_block(); in_flight.text_delta(block_id, "done".into()); in_flight.clear_for_committed_item_then( &LoggedItem::Message { role: session_store::LoggedRole::Assistant, content: vec![LoggedContentPart::Text { text: "done".into(), }], }, || (), ); let guard = in_flight.snapshot_guard(); assert!(snapshot_from_guard(&guard).is_empty()); } }