mod common; use std::sync::Arc; use async_trait::async_trait; use common::MockLlmClient; use llm_worker::Worker; use llm_worker::interceptor::{Interceptor, TurnEndAction}; use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent}; use llm_worker::llm_client::types::{Item, RequestConfig}; use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta, ToolOutput}; use session_store::{FsStore, LogEntry, SegmentStartState, Store, collect_state}; // ============================================================================= // Helpers // ============================================================================= fn simple_text_events() -> Vec { vec![ Event::text_block_start(0), Event::text_delta(0, "Hello!"), Event::text_block_stop(0, None), Event::Status(StatusEvent { status: ResponseStatus::Completed, }), ] } fn tool_call_events() -> Vec> { vec![ // 1st response: tool call vec![ Event::tool_use_start(0, "call_1", "get_weather"), Event::tool_input_delta(0, r#"{"city":"Tokyo"}"#), Event::tool_use_stop(0), Event::Status(StatusEvent { status: ResponseStatus::Completed, }), ], // 2nd response: final text vec![ Event::text_block_start(0), Event::text_delta(0, "It's sunny in Tokyo!"), Event::text_block_stop(0, None), Event::Status(StatusEvent { status: ResponseStatus::Completed, }), ], ] } #[derive(Clone)] struct MockWeatherTool; #[async_trait] impl Tool for MockWeatherTool { async fn execute(&self, _input_json: &str) -> Result { Ok("Sunny, 25C".to_string().into()) } } fn weather_tool_definition() -> ToolDefinition { Arc::new(|| { let meta = ToolMeta::new("get_weather") .description("Get weather") .input_schema(serde_json::json!({ "type": "object", "properties": { "city": { "type": "string" } }, "required": ["city"] })); (meta, Arc::new(MockWeatherTool) as Arc) }) } /// Policy that forces Pause on every turn end. struct PausePolicy; #[async_trait] impl Interceptor for PausePolicy { async fn on_turn_end(&self, _history: &[Item]) -> TurnEndAction { TurnEndAction::Pause } } fn make_store() -> (tempfile::TempDir, FsStore) { let dir = tempfile::tempdir().unwrap(); let store = FsStore::new(dir.path()).unwrap(); (dir, store) } /// Run a worker turn and persist via session-store functions. /// Takes ownership of the worker (needed for lock/unlock) and returns it. async fn run_and_persist( worker: Worker, store: &FsStore, session_id: session_store::SessionId, segment_id: session_store::SegmentId, input: &str, ) -> (Worker, llm_worker::WorkerResult) { // Mirror Pod's run-entry contract: log the user input as segments // before the worker pushes its flattened user_message; save_delta // skips the resulting user_message item to avoid double-write. session_store::save_user_input( store, session_id, segment_id, vec![protocol::Segment::text(input)], ) .unwrap(); let history_before = worker.history().len(); let mut locked = worker.lock(); let result = locked.run(input).await; let worker = locked.unlock(); let new_items = &worker.history()[history_before..]; session_store::save_delta(store, session_id, segment_id, new_items).unwrap(); session_store::save_turn_end(store, session_id, segment_id, worker.turn_count()).unwrap(); match &result { Ok(r) => { session_store::save_run_completed( store, session_id, segment_id, r.clone(), worker.last_run_interrupted(), ) .unwrap(); } Err(e) => { session_store::save_run_errored( store, session_id, segment_id, e.to_string(), worker.last_run_interrupted(), ) .unwrap(); } } let r = result.unwrap(); (worker, r) } // ============================================================================= // Tests // ============================================================================= #[tokio::test] async fn session_run_logs_entries() { let (_dir, store) = make_store(); let client = MockLlmClient::new(simple_text_events()); let worker = Worker::new(client); let (sid, segid) = session_store::create_segment( &store, SegmentStartState { system_prompt: worker.get_system_prompt(), config: worker.request_config(), history: worker.history(), }, ) .unwrap(); let (worker, _) = run_and_persist(worker, &store, sid, segid, "Hi").await; let _ = &worker; let entries = store.read_all(sid, segid).unwrap(); // SegmentStart, UserInput, AssistantItem, TurnEnd, RunCompleted (at minimum) assert!( entries.len() >= 4, "expected at least 4 entries, got {}", entries.len() ); // First entry is SegmentStart assert!(matches!(&entries[0], LogEntry::SegmentStart { .. })); // Has a RunCompleted with Finished let has_finished = entries.iter().any(|e| { matches!( e, LogEntry::RunCompleted { result: llm_worker::WorkerResult::Finished, .. } ) }); assert!(has_finished, "should have a Finished outcome"); } #[tokio::test] async fn session_restore_round_trip() { let (_dir, store) = make_store(); let client = MockLlmClient::new(simple_text_events()); let mut worker = Worker::new(client); worker.set_system_prompt("You are helpful."); let (sid, segid) = session_store::create_segment( &store, SegmentStartState { system_prompt: worker.get_system_prompt(), config: worker.request_config(), history: worker.history(), }, ) .unwrap(); let (worker, _) = run_and_persist(worker, &store, sid, segid, "Hi").await; let original_history_len = worker.history().len(); let original_turn_count = worker.turn_count(); // Restore let state = session_store::restore(&store, sid, segid).unwrap(); assert_eq!(state.session_id, Some(sid)); assert_eq!(state.history.len(), original_history_len); assert_eq!(state.turn_count, original_turn_count); assert_eq!(state.system_prompt.as_deref(), Some("You are helpful.")); assert_eq!( state.entries_count, store.read_entry_count(sid, segid).unwrap() ); // Shim by segment ID alone. let by_segment = session_store::restore_by_segment(&store, segid).unwrap(); assert_eq!(by_segment.session_id, Some(sid)); } #[tokio::test] async fn session_run_with_tool_call() { let (_dir, store) = make_store(); let client = MockLlmClient::with_responses(tool_call_events()); let mut worker = Worker::new(client); worker.register_tool(weather_tool_definition()); let (sid, segid) = session_store::create_segment( &store, SegmentStartState { system_prompt: worker.get_system_prompt(), config: worker.request_config(), history: worker.history(), }, ) .unwrap(); let (_worker, _) = run_and_persist(worker, &store, sid, segid, "What's the weather?").await; let entries = store.read_all(sid, segid).unwrap(); let has_tool_results = entries .iter() .any(|e| matches!(e, LogEntry::ToolResult { .. })); assert!(has_tool_results, "should have ToolResult entry"); let has_assistant = entries .iter() .any(|e| matches!(e, LogEntry::AssistantItem { .. })); assert!(has_assistant, "should have AssistantItem entry"); } #[tokio::test] async fn session_resume_after_pause() { let (_dir, store) = make_store(); // First run: tool call with pause policy → Paused let client = MockLlmClient::with_responses(tool_call_events()); let mut worker = Worker::new(client); worker.register_tool(weather_tool_definition()); worker.set_interceptor(PausePolicy); let (sid, segid) = session_store::create_segment( &store, SegmentStartState { system_prompt: worker.get_system_prompt(), config: worker.request_config(), history: worker.history(), }, ) .unwrap(); let (_worker, result) = run_and_persist(worker, &store, sid, segid, "Weather?").await; assert!(matches!(result, llm_worker::WorkerResult::Paused)); // Check RunCompleted is Paused let entries = store.read_all(sid, segid).unwrap(); let has_paused = entries.iter().any(|e| { matches!( e, LogEntry::RunCompleted { result: llm_worker::WorkerResult::Paused, .. } ) }); assert!(has_paused, "should have Paused outcome"); // Restore state and verify let state = session_store::restore(&store, sid, segid).unwrap(); assert!(state.last_run_interrupted); } #[tokio::test] async fn session_fork_creates_new_session() { let (_dir, store) = make_store(); let client = MockLlmClient::new(simple_text_events()); let mut worker = Worker::new(client); worker.set_system_prompt("System prompt"); let (sid, segid) = session_store::create_segment( &store, SegmentStartState { system_prompt: worker.get_system_prompt(), config: worker.request_config(), history: worker.history(), }, ) .unwrap(); let (worker, _) = run_and_persist(worker, &store, sid, segid, "Hello").await; let original_history_len = worker.history().len(); let (fork_sid, fork_segid) = session_store::fork( &store, SegmentStartState { system_prompt: worker.get_system_prompt(), config: worker.request_config(), history: worker.history(), }, ) .unwrap(); assert_ne!(fork_sid, sid, "`fork` mints a fresh Session"); // Fork should have a SegmentStart with the current history let fork_entries = store.read_all(fork_sid, fork_segid).unwrap(); assert_eq!(fork_entries.len(), 1); assert!(matches!(&fork_entries[0], LogEntry::SegmentStart { .. })); let fork_state = collect_state(&fork_entries); assert_eq!(fork_state.session_id, Some(fork_sid)); assert_eq!(fork_state.history.len(), original_history_len); assert_eq!(fork_state.system_prompt.as_deref(), Some("System prompt")); } #[tokio::test] async fn session_fork_at_truncates_within_session() { let (_dir, store) = make_store(); let client = MockLlmClient::new(simple_text_events()); let worker = Worker::new(client); let (sid, segid) = session_store::create_segment( &store, SegmentStartState { system_prompt: worker.get_system_prompt(), config: worker.request_config(), history: worker.history(), }, ) .unwrap(); let (worker, _) = run_and_persist(worker, &store, sid, segid, "Hello").await; let all_entries = store.read_all(sid, segid).unwrap(); assert!(all_entries.len() > 2); // Fork at turn 1 (one completed turn). Stays in same Session. let fork_segid = session_store::fork_at(&store, sid, segid, worker.turn_count()).unwrap(); let fork_entries = store.read_all(sid, fork_segid).unwrap(); assert_eq!(fork_entries.len(), 1); // Just the new SegmentStart let fork_state = collect_state(&fork_entries); assert_eq!(fork_state.session_id, Some(sid), "fork_at inherits Session"); // History at fork point should match history right after the TurnEnd in // the source segment. let turn_end_pos = all_entries .iter() .position(|e| matches!(e, LogEntry::TurnEnd { turn_count, .. } if *turn_count == worker.turn_count())) .expect("source segment has the matching TurnEnd"); let source_state_at_fork = collect_state(&all_entries[..=turn_end_pos]); assert_eq!(fork_state.history.len(), source_state_at_fork.history.len()); // list_segments should show both source and fork in the same Session. let segs = store.list_segments(sid).unwrap(); assert!(segs.contains(&segid)); assert!(segs.contains(&fork_segid)); } #[tokio::test] async fn session_config_changed_logged() { let (_dir, store) = make_store(); let client = MockLlmClient::new(vec![]); let mut worker = Worker::new(client); let (sid, segid) = session_store::create_segment( &store, SegmentStartState { system_prompt: worker.get_system_prompt(), config: worker.request_config(), history: worker.history(), }, ) .unwrap(); // Modify config and log it let new_config = RequestConfig::default().with_temperature(0.7); worker.set_request_config(new_config.clone()); session_store::save_config_changed(&store, sid, segid, &new_config).unwrap(); let entries = store.read_all(sid, segid).unwrap(); let has_config_changed = entries.iter().any(|e| { matches!( e, LogEntry::ConfigChanged { config, .. } if config.temperature == Some(0.7) ) }); assert!(has_config_changed, "should have ConfigChanged entry"); } #[tokio::test] async fn session_auto_forks_on_conflict() { let (_dir, store) = make_store(); // Create a segment let client_a = MockLlmClient::new(simple_text_events()); let worker_a = Worker::new(client_a); let (sid, original_segid) = session_store::create_segment( &store, SegmentStartState { system_prompt: worker_a.get_system_prompt(), config: worker_a.request_config(), history: worker_a.history(), }, ) .unwrap(); let mut segment_id = original_segid; // Writer tracked: just the SegmentStart we wrote. let mut entries_written: usize = 1; // Simulate another Pod writing to the same segment behind our back. let extra_entry = LogEntry::UserInput { ts: 9999, segments: vec![protocol::Segment::text("Interloper")], }; store.append(sid, original_segid, &extra_entry).unwrap(); // Now the on-disk count exceeds our tally — ensure_head_or_fork should auto-fork. session_store::ensure_head_or_fork( &store, sid, &mut segment_id, &mut entries_written, /* at_turn_index */ 0, SegmentStartState { system_prompt: worker_a.get_system_prompt(), config: worker_a.request_config(), history: worker_a.history(), }, ) .unwrap(); // segment_id should now be different but live in the same Session. assert_ne!(segment_id, original_segid); // The fork segment should exist and have entries let fork_entries = store.read_all(sid, segment_id).unwrap(); assert!(!fork_entries.is_empty()); let fork_state = collect_state(&fork_entries); assert_eq!( fork_state.session_id, Some(sid), "auto-fork inherits Session" ); // The new segment records its lineage forward via forked_from; the // source segment is left immutable (no terminal marker written back). match &fork_entries[0] { LogEntry::SegmentStart { forked_from: Some(origin), .. } => { assert_eq!(origin.segment_id, original_segid); assert_eq!(origin.at_turn_index, 0); } other => panic!("expected SegmentStart with forked_from, got {other:?}"), } // Original segment should still have the interloper entry and NO // terminal fork marker — it is byte-for-byte unchanged. let original_entries = store.read_all(sid, original_segid).unwrap(); assert_eq!( original_entries.len(), 2, "source segment holds only SegmentStart + interloper UserInput" ); let has_interloper = original_entries .iter() .any(|e| matches!(e, LogEntry::UserInput { .. })); assert!(has_interloper); } /// Nested past-fork: forking a segment that is itself a fork must not /// require touching any ancestor. Each `fork_at` only reads its direct /// source and seeds a new segment, so a chain of forks composes cleanly. #[tokio::test] async fn nested_past_fork_leaves_ancestors_immutable() { let (_dir, store) = make_store(); let client = MockLlmClient::new(simple_text_events()); let worker = Worker::new(client); let (sid, root_segid) = session_store::create_segment( &store, SegmentStartState { system_prompt: worker.get_system_prompt(), config: worker.request_config(), history: worker.history(), }, ) .unwrap(); let (worker, _) = run_and_persist(worker, &store, sid, root_segid, "Hello").await; let root_before = store.read_all(sid, root_segid).unwrap(); // First past-fork at the completed turn. let fork1 = session_store::fork_at(&store, sid, root_segid, worker.turn_count()).unwrap(); // Fork the fork (turn 0 = right after its SegmentStart seed). let fork2 = session_store::fork_at(&store, sid, fork1, 0).unwrap(); // All three are distinct, all in the same Session. assert_ne!(fork1, root_segid); assert_ne!(fork2, fork1); for seg in [root_segid, fork1, fork2] { assert_eq!( collect_state(&store.read_all(sid, seg).unwrap()).session_id, Some(sid) ); } // The root and fork1 are untouched by forking their descendants. assert_eq!( store.read_all(sid, root_segid).unwrap().len(), root_before.len() ); let fork1_entries = store.read_all(sid, fork1).unwrap(); assert_eq!( fork1_entries.len(), 1, "fork1 is just its SegmentStart seed" ); // fork2's lineage points at fork1, not the root. match &store.read_all(sid, fork2).unwrap()[0] { LogEntry::SegmentStart { forked_from: Some(origin), .. } => assert_eq!(origin.segment_id, fork1), other => panic!("expected SegmentStart with forked_from, got {other:?}"), } }