From 9b78c51d0abfc3b9d9c51884d1e8b3dc74889f7d Mon Sep 17 00:00:00 2001 From: Hare Date: Sat, 11 Apr 2026 18:47:33 +0900 Subject: [PATCH] =?UTF-8?q?Worker=E3=81=AE=E8=87=AA=E5=8B=95=E3=82=AD?= =?UTF-8?q?=E3=83=A3=E3=83=83=E3=82=B7=E3=83=A5=E3=83=AD=E3=83=83=E3=82=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- TODO.md | 2 +- .../src/inspect_tool.rs | 6 +- crates/llm-worker-persistence/src/session.rs | 83 ++++--- .../llm-worker-persistence/src/session_log.rs | 9 +- .../tests/session_test.rs | 24 +- .../llm-worker/examples/worker_cancel_demo.rs | 54 ++--- crates/llm-worker/examples/worker_cli.rs | 31 ++- crates/llm-worker/src/lib.rs | 10 +- crates/llm-worker/src/state.rs | 10 +- crates/llm-worker/src/tool_server.rs | 107 ++++++--- crates/llm-worker/src/worker.rs | 220 +++++++++--------- crates/llm-worker/tests/callback_test.rs | 6 +- crates/llm-worker/tests/compile_fail.rs | 2 +- .../tests/parallel_execution_test.rs | 17 +- ...gister_tool.rs => locked_register_tool.rs} | 0 ...ool.stderr => locked_register_tool.stderr} | 6 +- .../tool_server_handle_register_tool.stderr | 11 +- crates/llm-worker/tests/worker_fixtures.rs | 11 +- crates/llm-worker/tests/worker_state_test.rs | 49 ++-- crates/pod/examples/pod_cli.rs | 2 +- crates/pod/src/controller.rs | 8 +- crates/pod/src/pod.rs | 2 +- tickets/worker-auto-lock.md | 57 ----- 23 files changed, 375 insertions(+), 352 deletions(-) rename crates/llm-worker/tests/ui/{cache_locked_register_tool.rs => locked_register_tool.rs} (100%) rename crates/llm-worker/tests/ui/{cache_locked_register_tool.stderr => locked_register_tool.stderr} (63%) delete mode 100644 tickets/worker-auto-lock.md diff --git a/TODO.md b/TODO.md index d7ad572f..41f251a7 100644 --- a/TODO.md +++ b/TODO.md @@ -3,7 +3,7 @@ - [x] ツール出力の遅延読み込み設計 (ToolOutput / BlobStore / auto_summarize) - [ ] ツール設計 - [ ] ツールの動的追加/削除 → [tickets/tool-dynamic-registry.md](tickets/tool-dynamic-registry.md) - - [ ] run() 自動ロックとファクトリ遅延初期化 
→ [tickets/worker-auto-lock.md](tickets/worker-auto-lock.md) + - [x] run() 自動ロックとファクトリ遅延初期化 → [tickets/worker-auto-lock.md](tickets/worker-auto-lock.md) - [x] inspect ツール実装 - [x] max_turns: マニフェストによるターン数制限 - [x] pod バイナリエントリポイント diff --git a/crates/llm-worker-persistence/src/inspect_tool.rs b/crates/llm-worker-persistence/src/inspect_tool.rs index 91d01083..961d57f5 100644 --- a/crates/llm-worker-persistence/src/inspect_tool.rs +++ b/crates/llm-worker-persistence/src/inspect_tool.rs @@ -13,7 +13,6 @@ use serde_json::json; use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta}; use llm_worker::state::Mutable; -use llm_worker::ToolRegistryError; use llm_worker::Worker; use llm_worker::llm_client::LlmClient; @@ -332,12 +331,11 @@ fn apply_selector(content: &Content, selector: &Selector) -> Result( worker: &mut Worker, blob_store: Arc, -) -> Result<(), ToolRegistryError> -where +) where C: LlmClient, B: BlobStore + 'static, { - worker.register_tool(InspectTool::::tool_definition(blob_store)) + worker.register_tool(InspectTool::::tool_definition(blob_store)); } // ─── Tests ─────────────────────────────────────────────────────────────────── diff --git a/crates/llm-worker-persistence/src/session.rs b/crates/llm-worker-persistence/src/session.rs index 9d2ea1b7..e7516ae4 100644 --- a/crates/llm-worker-persistence/src/session.rs +++ b/crates/llm-worker-persistence/src/session.rs @@ -43,12 +43,13 @@ pub enum SessionError { /// Persistent session wrapping a [`Worker`]. /// -/// The `worker` field is public for direct access to Worker APIs -/// (tool registration, hook setup, subscriber management, etc.). -/// State-mutating operations (`run`, `resume`) should go through -/// Session methods to ensure proper logging. +/// Use [`worker()`](Self::worker) / [`worker_mut()`](Self::worker_mut) to +/// access the underlying Worker for configuration (tool registration, etc.). 
+/// State-mutating operations (`run`, `resume`) should go through Session +/// methods to ensure proper logging. pub struct Session { - pub worker: Worker, + /// Always `Some` outside of `run()` / `resume()`. + worker: Option>, store: St, session_id: SessionId, head_hash: Option, @@ -78,7 +79,7 @@ impl Session { store.append(session_id, &hashed_entry).await?; Ok(Self { - worker, + worker: Some(worker), store, session_id, head_hash: Some(hashed), @@ -87,9 +88,6 @@ impl Session { } /// Restore a session from a stored log. - /// - /// Reads all log entries, collects state from them, - /// and returns a `Session` ready for `resume()`. pub async fn restore( client: C, store: St, @@ -109,7 +107,7 @@ impl Session { worker.set_last_run_interrupted(state.last_run_interrupted); Ok(Self { - worker, + worker: Some(worker), store, session_id, head_hash: state.head_hash, @@ -117,6 +115,20 @@ impl Session { }) } + fn w(&self) -> &Worker { + self.worker.as_ref().expect("worker taken during run") + } + + /// Reference to the underlying Worker. + pub fn worker(&self) -> &Worker { + self.w() + } + + /// Mutable reference to the underlying Worker. + pub fn worker_mut(&mut self) -> &mut Worker { + self.worker.as_mut().expect("worker taken during run") + } + /// The session ID. pub fn session_id(&self) -> SessionId { self.session_id @@ -133,15 +145,23 @@ impl Session { } /// Run a user turn, logging all state changes. + /// + /// Internally locks the Worker (flushing pending tools), runs the turn, + /// then unlocks back to Mutable state. 
pub async fn run( &mut self, user_input: impl Into, ) -> Result { + let input = user_input.into(); self.ensure_head_or_fork().await?; - let history_before = self.worker.history().len(); + let history_before = self.w().history().len(); - let result = self.worker.run(user_input).await; + // lock → run → unlock (use lock() directly to keep worker on error) + let worker = self.worker.take().expect("worker taken during run"); + let mut locked = worker.lock(); + let result = locked.run(input).await; + self.worker = Some(locked.unlock()); self.log_history_delta(history_before).await?; self.log_turn_end().await?; @@ -154,9 +174,13 @@ impl Session { pub async fn resume(&mut self) -> Result { self.ensure_head_or_fork().await?; - let history_before = self.worker.history().len(); + let history_before = self.w().history().len(); - let result = self.worker.resume().await; + // lock → resume → unlock + let worker = self.worker.take().expect("worker taken during run"); + let mut locked = worker.lock(); + let result = locked.resume().await; + self.worker = Some(locked.unlock()); self.log_history_delta(history_before).await?; self.log_turn_end().await?; @@ -166,15 +190,13 @@ impl Session { } /// Fork this session at its current state. - /// Returns the new session ID. The new log contains a `SessionStart` - /// seeded with the current history. pub async fn fork(&self) -> Result { let fork_id = crate::new_session_id(); let entry = LogEntry::SessionStart { ts: session_log::now_millis(), - system_prompt: self.worker.get_system_prompt().map(String::from), - config: self.worker.request_config().clone(), - history: self.worker.history().to_vec(), + system_prompt: self.w().get_system_prompt().map(String::from), + config: self.w().request_config().clone(), + history: self.w().history().to_vec(), }; let hashed = session_log::compute_hash(None, &entry); let hashed_entry = HashedEntry { @@ -189,8 +211,6 @@ impl Session { } /// Fork from an arbitrary point in a stored session's log. 
- /// Finds the entry matching `at_hash` and creates a new session - /// with state reconstructed up to that point. pub async fn fork_at( store: &St, source_id: SessionId, @@ -221,12 +241,12 @@ impl Session { Ok(fork_id) } - /// Log a `CacheLocked` entry. + /// Log a `Locked` entry. pub async fn log_cache_locked( &mut self, locked_prefix_len: usize, ) -> Result<(), StoreError> { - let entry = LogEntry::CacheLocked { + let entry = LogEntry::Locked { ts: session_log::now_millis(), locked_prefix_len, }; @@ -245,14 +265,13 @@ impl Session { pub async fn log_config_changed(&mut self) -> Result<(), StoreError> { let entry = LogEntry::ConfigChanged { ts: session_log::now_millis(), - config: self.worker.request_config().clone(), + config: self.w().request_config().clone(), }; self.append_entry(entry).await } // ── Private helpers ────────────────────────────────────────────────── - /// Append a `LogEntry`, computing its hash and updating `head_hash`. async fn append_entry(&mut self, entry: LogEntry) -> Result<(), StoreError> { let hash = session_log::compute_hash(self.head_hash.as_ref(), &entry); let hashed_entry = HashedEntry { @@ -267,19 +286,17 @@ impl Session { Ok(()) } - /// Check that the store's head still matches ours. If not, auto-fork. async fn ensure_head_or_fork(&mut self) -> Result<(), StoreError> { let store_head = self.store.read_head_hash(self.session_id).await?; if store_head == self.head_hash { return Ok(()); } - // Another writer advanced this session — fork from our known state. 
let fork_id = crate::new_session_id(); let entry = LogEntry::SessionStart { ts: session_log::now_millis(), - system_prompt: self.worker.get_system_prompt().map(String::from), - config: self.worker.request_config().clone(), - history: self.worker.history().to_vec(), + system_prompt: self.w().get_system_prompt().map(String::from), + config: self.w().request_config().clone(), + history: self.w().history().to_vec(), }; let hash = session_log::compute_hash(None, &entry); let hashed_entry = HashedEntry { @@ -296,7 +313,7 @@ impl Session { } async fn log_history_delta(&mut self, before_len: usize) -> Result<(), StoreError> { - let history = self.worker.history(); + let history = self.w().history(); if history.len() <= before_len { return Ok(()); } @@ -356,7 +373,7 @@ impl Session { async fn log_turn_end(&mut self) -> Result<(), StoreError> { self.append_entry(LogEntry::TurnEnd { ts: session_log::now_millis(), - turn_count: self.worker.turn_count(), + turn_count: self.w().turn_count(), }) .await } @@ -376,7 +393,7 @@ impl Session { self.append_entry(LogEntry::RunOutcome { ts: session_log::now_millis(), outcome, - interrupted: self.worker.last_run_interrupted(), + interrupted: self.w().last_run_interrupted(), }) .await } diff --git a/crates/llm-worker-persistence/src/session_log.rs b/crates/llm-worker-persistence/src/session_log.rs index be907877..476ab316 100644 --- a/crates/llm-worker-persistence/src/session_log.rs +++ b/crates/llm-worker-persistence/src/session_log.rs @@ -88,7 +88,7 @@ pub struct HashedEntry { /// - `SessionStart` — always the first entry; captures initial state /// - `UserInput` / `AssistantItems` / `ToolResults` / `HookInjectedItems` — history appends /// - `TurnEnd` — turn boundary marker -/// - `CacheLocked` / `CacheUnlocked` — KV cache state transitions +/// - `Locked` / `CacheUnlocked` — KV cache state transitions /// - `RunOutcome` — marks end of a `run()` or `resume()` call /// - `ConfigChanged` — `RequestConfig` mutation #[derive(Debug, Clone, 
Serialize, Deserialize)] @@ -119,7 +119,8 @@ pub enum LogEntry { TurnEnd { ts: u64, turn_count: usize }, /// KV cache locked. Records the history prefix length that is now immutable. - CacheLocked { ts: u64, locked_prefix_len: usize }, + #[serde(alias = "cache_locked")] + Locked { ts: u64, locked_prefix_len: usize }, /// KV cache unlocked. CacheUnlocked { ts: u64 }, @@ -200,7 +201,7 @@ pub fn collect_state(entries: &[HashedEntry]) -> RestoredState { LogEntry::TurnEnd { turn_count, .. } => { state.turn_count = *turn_count; } - LogEntry::CacheLocked { + LogEntry::Locked { locked_prefix_len, .. } => { state.locked_prefix_len = *locked_prefix_len; @@ -354,7 +355,7 @@ mod tests { config: RequestConfig::default(), history: vec![Item::user_message("a"), Item::assistant_message("b")], }, - LogEntry::CacheLocked { + LogEntry::Locked { ts: 2000, locked_prefix_len: 2, }, diff --git a/crates/llm-worker-persistence/tests/session_test.rs b/crates/llm-worker-persistence/tests/session_test.rs index 25aece74..c77b7cff 100644 --- a/crates/llm-worker-persistence/tests/session_test.rs +++ b/crates/llm-worker-persistence/tests/session_test.rs @@ -159,8 +159,8 @@ async fn session_restore_round_trip() { session.run("Hi").await.unwrap(); - let original_history = session.worker.history().to_vec(); - let original_turn_count = session.worker.turn_count(); + let original_history = session.worker().history().to_vec(); + let original_turn_count = session.worker().turn_count(); let original_head_hash = session.head_hash().cloned(); // Restore @@ -170,10 +170,10 @@ async fn session_restore_round_trip() { .await .unwrap(); - assert_eq!(restored.worker.history().len(), original_history.len()); - assert_eq!(restored.worker.turn_count(), original_turn_count); + assert_eq!(restored.worker().history().len(), original_history.len()); + assert_eq!(restored.worker().turn_count(), original_turn_count); assert_eq!( - restored.worker.get_system_prompt().map(String::from), + 
restored.worker().get_system_prompt().map(String::from), Some("You are helpful.".to_string()) ); assert_eq!(restored.head_hash(), original_head_hash.as_ref()); @@ -184,7 +184,7 @@ async fn session_run_with_tool_call() { let (_dir, store) = make_store().await; let client = MockLlmClient::with_responses(tool_call_events()); let mut worker = Worker::new(client); - worker.register_tool(weather_tool_definition()).unwrap(); + worker.register_tool(weather_tool_definition()); let mut session = Session::new(worker, store.clone(), SessionConfig::default()) .await @@ -213,7 +213,7 @@ async fn session_resume_after_pause() { // First run: tool call with pause hook → Paused let client = MockLlmClient::with_responses(tool_call_events()); let mut worker = Worker::new(client); - worker.register_tool(weather_tool_definition()).unwrap(); + worker.register_tool(weather_tool_definition()); worker.set_interceptor(PausePolicy); let mut session = Session::new(worker, store.clone(), SessionConfig::default()) @@ -251,7 +251,7 @@ async fn session_resume_after_pause() { .await .unwrap(); - assert!(restored.worker.last_run_interrupted()); + assert!(restored.worker().last_run_interrupted()); // resume may or may not succeed depending on Worker internal state, // but the restore itself should work @@ -271,7 +271,7 @@ async fn session_fork_preserves_state() { session.run("Hello").await.unwrap(); - let original_history_len = session.worker.history().len(); + let original_history_len = session.worker().history().len(); let fork_id = session.fork().await.unwrap(); // Fork should have a SessionStart with the current history @@ -334,7 +334,7 @@ async fn session_config_changed_logged() { // Modify config via worker and log it session - .worker + .worker_mut() .set_request_config(RequestConfig::default().with_temperature(0.7)); session.log_config_changed().await.unwrap(); @@ -367,13 +367,13 @@ async fn session_cache_lock_unlock_logged() { let has_locked = entries.iter().any(|e| { matches!( &e.entry, - 
LogEntry::CacheLocked { + LogEntry::Locked { locked_prefix_len: 5, .. } ) }); - assert!(has_locked, "should have CacheLocked entry"); + assert!(has_locked, "should have Locked entry"); let has_unlocked = entries .iter() diff --git a/crates/llm-worker/examples/worker_cancel_demo.rs b/crates/llm-worker/examples/worker_cancel_demo.rs index 0d0dec00..2f8afd87 100644 --- a/crates/llm-worker/examples/worker_cancel_demo.rs +++ b/crates/llm-worker/examples/worker_cancel_demo.rs @@ -4,9 +4,7 @@ use llm_worker::llm_client::providers::anthropic::AnthropicClient; use llm_worker::{Worker, WorkerResult}; -use std::sync::Arc; use std::time::Duration; -use tokio::sync::Mutex; #[tokio::main] async fn main() -> Result<(), Box> { @@ -25,48 +23,38 @@ async fn main() -> Result<(), Box> { std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY environment variable not set"); let client = AnthropicClient::new(&api_key, "claude-sonnet-4-20250514"); - let worker = Arc::new(Mutex::new(Worker::new(client))); + let worker = Worker::new(client); println!("🚀 Starting Worker..."); println!("💡 Will cancel after 2 seconds\n"); - // Get cancel sender first (without holding lock) - let cancel_tx = { - let w = worker.lock().await; - w.cancel_sender() - }; + // Get cancel sender before run (Mutable state) + let cancel_tx = worker.cancel_sender(); - // Task 1: Run Worker - let worker_clone = worker.clone(); - let task = tokio::spawn(async move { - let mut w = worker_clone.lock().await; - println!("📡 Sending request to LLM..."); - - match w.run("Tell me a very long story about a brave knight. 
Make it as detailed as possible with many paragraphs.").await { - Ok(WorkerResult::Finished) => { - println!("✅ Task completed normally"); - } - Ok(WorkerResult::Paused) => { - println!("⏸️ Task paused"); - } - Ok(WorkerResult::LimitReached) => { - println!("🔒 Turn limit reached"); - } - Err(e) => { - println!("❌ Task error: {}", e); - } - } - }); - - // Task 2: Cancel after 2 seconds + // Task: Cancel after 2 seconds tokio::spawn(async move { tokio::time::sleep(Duration::from_secs(2)).await; println!("\n🛑 Cancelling worker..."); let _ = cancel_tx.send(()).await; }); - // Wait for task completion - task.await?; + println!("📡 Sending request to LLM..."); + + // Mutable::run consumes self → (Locked, WorkerResult) + match worker.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await { + Ok((_locked, WorkerResult::Finished)) => { + println!("✅ Task completed normally"); + } + Ok((_locked, WorkerResult::Paused)) => { + println!("⏸️ Task paused"); + } + Ok((_locked, WorkerResult::LimitReached)) => { + println!("🔒 Turn limit reached"); + } + Err(e) => { + println!("❌ Task error: {}", e); + } + } println!("\n✨ Demo complete!"); diff --git a/crates/llm-worker/examples/worker_cli.rs b/crates/llm-worker/examples/worker_cli.rs index ed6c606d..68041b2c 100644 --- a/crates/llm-worker/examples/worker_cli.rs +++ b/crates/llm-worker/examples/worker_cli.rs @@ -438,10 +438,8 @@ async fn main() -> Result<(), Box> { // Register tools (unless --no-tools) if !args.no_tools { let app = AppContext; - worker - .register_tool(app.get_current_time_definition()) - .unwrap(); - worker.register_tool(app.calculate_definition()).unwrap(); + worker.register_tool(app.get_current_time_definition()); + worker.register_tool(app.calculate_definition()); } // Register streaming display handlers @@ -465,7 +463,27 @@ async fn main() -> Result<(), Box> { return Ok(()); } - // Interactive loop + // Interactive loop — first input transitions Mutable 
→ Locked + print!("\n👤 You: "); + io::stdout().flush()?; + + let mut first_input = String::new(); + io::stdin().read_line(&mut first_input)?; + let first_input = first_input.trim(); + + if first_input == "quit" || first_input == "exit" || first_input.is_empty() { + println!("\n👋 Goodbye!"); + return Ok(()); + } + + let (mut locked, _) = match worker.run(first_input).await { + Ok(pair) => pair, + Err(e) => { + eprintln!("\n❌ Error: {}", e); + return Ok(()); + } + }; + loop { print!("\n👤 You: "); io::stdout().flush()?; @@ -483,8 +501,7 @@ async fn main() -> Result<(), Box> { break; } - // Run Worker (Worker manages history) - match worker.run(input).await { + match locked.run(input).await { Ok(_) => {} Err(e) => { eprintln!("\n❌ Error: {}", e); diff --git a/crates/llm-worker/src/lib.rs b/crates/llm-worker/src/lib.rs index 771da62b..480e9513 100644 --- a/crates/llm-worker/src/lib.rs +++ b/crates/llm-worker/src/lib.rs @@ -27,12 +27,14 @@ //! //! # Cache Protection //! -//! To maximize KV cache hit rate, transition to the locked state -//! with [`Worker::lock()`] before execution. +//! `run()` automatically locks the cache. To edit state between turns, +//! call `unlock()` first; the next `run()` re-locks automatically. //! //! ```ignore -//! let mut locked = worker.lock(); -//! locked.run("user input").await?; +//! let (worker, _result) = worker.run("user input").await?; +//! let mut worker = worker.unlock(); +//! worker.set_system_prompt("new prompt"); +//! let (worker, _result) = worker.run("next input").await?; //! ``` mod handler; diff --git a/crates/llm-worker/src/state.rs b/crates/llm-worker/src/state.rs index e1127b09..2f67f0d6 100644 --- a/crates/llm-worker/src/state.rs +++ b/crates/llm-worker/src/state.rs @@ -1,7 +1,7 @@ //! Worker State //! //! State marker types for cache protection using the Type-state pattern. -//! Worker has state transitions from `Mutable` → `CacheLocked`. +//! Worker has state transitions from `Mutable` → `Locked`.
/// Marker trait representing Worker state /// @@ -19,7 +19,7 @@ mod private { /// - Editing message history (add, delete, clear) /// - Registering tools and hooks /// -/// Can transition to [`CacheLocked`] state via `Worker::lock()`. +/// Can transition to [`Locked`] state via `Worker::lock()`. /// /// # Examples /// @@ -54,7 +54,7 @@ impl WorkerState for Mutable {} /// Can return to [`Mutable`] state via `Worker::unlock()`, /// but note that cache protection will be released. #[derive(Debug, Clone, Copy, Default)] -pub struct CacheLocked; +pub struct Locked; -impl private::Sealed for CacheLocked {} -impl WorkerState for CacheLocked {} +impl private::Sealed for Locked {} +impl WorkerState for Locked {} diff --git a/crates/llm-worker/src/tool_server.rs b/crates/llm-worker/src/tool_server.rs index 343c018b..82d580bc 100644 --- a/crates/llm-worker/src/tool_server.rs +++ b/crates/llm-worker/src/tool_server.rs @@ -26,6 +26,7 @@ pub enum ToolServerError { #[derive(Clone, Default)] pub struct ToolServer { tools: Arc>, + pending: Arc>>, } impl ToolServer { @@ -38,6 +39,7 @@ impl ToolServer { pub fn handle(&self) -> ToolServerHandle { ToolServerHandle { tools: Arc::clone(&self.tools), + pending: Arc::clone(&self.pending), } } } @@ -46,32 +48,57 @@ impl ToolServer { #[derive(Clone, Default)] pub struct ToolServerHandle { tools: Arc>, + pending: Arc>>, } impl ToolServerHandle { - /// Register one tool. - pub(crate) fn register_tool( - &self, - factory: WorkerToolDefinition, - ) -> Result<(), ToolServerError> { - let (meta, instance) = factory(); - let mut guard = self.tools.lock().unwrap_or_else(|e| e.into_inner()); - if guard.contains_key(&meta.name) { - return Err(ToolServerError::DuplicateName(meta.name)); - } - guard.insert(meta.name.clone(), (meta, instance)); - Ok(()) + /// Queue a tool factory for deferred initialization. 
+ /// + /// The factory is **not** called here; it is stored and executed + /// when [`flush_pending`](Self::flush_pending) is called (typically + /// at the start of `Worker::run()`). + pub(crate) fn register_tool(&self, factory: WorkerToolDefinition) { + self.pending + .lock() + .unwrap_or_else(|e| e.into_inner()) + .push(factory); } - /// Register many tools. + /// Queue many tool factories for deferred initialization. pub(crate) fn register_tools( &self, factories: impl IntoIterator, - ) -> Result<(), ToolServerError> { - for factory in factories { - self.register_tool(factory)?; + ) { + let mut guard = self.pending.lock().unwrap_or_else(|e| e.into_inner()); + guard.extend(factories); + } + + /// Execute all pending factories and register the resulting tools. + /// + /// # Panics + /// + /// Panics if any factory produces a tool whose name collides with + /// an already-registered tool. Duplicate names are a programming + /// error and should be caught during development. + pub(crate) fn flush_pending(&self) { + let pending: Vec<_> = { + let mut guard = self.pending.lock().unwrap_or_else(|e| e.into_inner()); + std::mem::take(&mut *guard) + }; + if pending.is_empty() { + return; + } + // Execute all factories first, then validate and insert atomically. + let materialized: Vec<_> = pending.into_iter().map(|f| f()).collect(); + let mut tools = self.tools.lock().unwrap_or_else(|e| e.into_inner()); + for (meta, instance) in materialized { + assert!( + !tools.contains_key(&meta.name), + "duplicate tool name: '{}'", + meta.name, + ); + tools.insert(meta.name.clone(), (meta, instance)); } - Ok(()) } /// Get a tool by name for hook contexts. 
@@ -143,19 +170,37 @@ mod tests { } #[test] - fn register_duplicate_name_fails() { + fn flush_pending_registers_tools() { let handle = ToolServer::new().handle(); - handle.register_tool(def("alpha")).expect("first register"); - let err = handle - .register_tool(def("alpha")) - .expect_err("duplicate should fail"); - assert_eq!(err, ToolServerError::DuplicateName("alpha".to_string())); + handle.register_tool(def("alpha")); + handle.register_tool(def("beta")); + + // Before flush, no tools are available + assert!(handle.get_tool("alpha").is_none()); + + handle.flush_pending(); + + // After flush, tools are available + assert!(handle.get_tool("alpha").is_some()); + assert!(handle.get_tool("beta").is_some()); + } + + #[test] + #[should_panic(expected = "duplicate tool name: 'alpha'")] + fn flush_pending_duplicate_name_panics() { + let handle = ToolServer::new().handle(); + handle.register_tool(def("alpha")); + handle.flush_pending(); + + handle.register_tool(def("alpha")); + handle.flush_pending(); // panics } #[tokio::test] async fn call_tool_success_and_not_found() { let handle = ToolServer::new().handle(); - handle.register_tool(def("echo")).expect("register"); + handle.register_tool(def("echo")); + handle.flush_pending(); let out = handle.call_tool("echo", r#"{"x":1}"#).await.expect("call"); assert_eq!(out, r#"{"x":1}"#); @@ -170,9 +215,10 @@ mod tests { #[test] fn tool_definitions_are_sorted() { let handle = ToolServer::new().handle(); - handle.register_tool(def("zeta")).expect("register zeta"); - handle.register_tool(def("alpha")).expect("register alpha"); - handle.register_tool(def("beta")).expect("register beta"); + handle.register_tool(def("zeta")); + handle.register_tool(def("alpha")); + handle.register_tool(def("beta")); + handle.flush_pending(); let names: Vec<_> = handle .tool_definitions_sorted() @@ -181,4 +227,11 @@ mod tests { .collect(); assert_eq!(names, vec!["alpha", "beta", "zeta"]); } + + #[test] + fn flush_pending_is_noop_when_empty() { + let 
handle = ToolServer::new().handle(); + handle.flush_pending(); + handle.flush_pending(); + } } diff --git a/crates/llm-worker/src/worker.rs b/crates/llm-worker/src/worker.rs index 79f232cc..e8aae5e5 100644 --- a/crates/llm-worker/src/worker.rs +++ b/crates/llm-worker/src/worker.rs @@ -13,7 +13,7 @@ use crate::{ DefaultInterceptor, Interceptor, PostToolAction, PreRequestAction, PreToolAction, PromptAction, ToolCallInfo, ToolResultInfo, TurnEndAction, }, - state::{CacheLocked, Mutable, WorkerState}, + state::{Locked, Mutable, WorkerState}, callback::{ ClosureMetaHandler, ClosureTextBlockHandler, ClosureToolUseBlockHandler, TextBlockScope, ToolUseBlockScope, @@ -22,12 +22,9 @@ use crate::{ timeline::{TextBlockCollector, Timeline, ToolCallCollector}, timeline::event::{ErrorEvent, StatusEvent, UsageEvent}, tool::{ToolCall, ToolDefinition as WorkerToolDefinition, ToolError, ToolOutputProcessor, ToolResult}, - tool_server::{ToolServer, ToolServerError, ToolServerHandle}, + tool_server::{ToolServer, ToolServerHandle}, }; -// ============================================================================= -// Worker Error -// ============================================================================= /// Worker errors #[derive(Debug, thiserror::Error)] @@ -57,9 +54,6 @@ pub enum ToolRegistryError { DuplicateName(String), } -// ============================================================================= -// Worker Config -// ============================================================================= /// Worker configuration #[derive(Debug, Clone, Default)] @@ -68,9 +62,6 @@ pub struct WorkerConfig { _private: (), } -// ============================================================================= -// Worker Result Types -// ============================================================================= /// Worker execution result (status) #[derive(Debug)] @@ -89,9 +80,6 @@ enum ToolExecutionResult { Paused, } -// 
============================================================================= -// Worker -// ============================================================================= /// Central component for managing LLM interactions /// @@ -100,32 +88,28 @@ enum ToolExecutionResult { /// /// # State Transitions (Type-state) /// -/// - [`Mutable`]: Initial state. System prompt and history can be freely edited. -/// - [`CacheLocked`]: Cache-protected state. Transition via `lock()`. Prefix context is immutable. +/// - [`Mutable`]: Initial state. System prompt, history, and tools can be freely edited. +/// - [`Locked`]: Cache-protected state. Prefix context is immutable; only `run()` / `resume()` are available. /// -/// # Examples +/// Calling `run()` on a `Mutable` Worker consumes it and returns a +/// `Locked` Worker together with the result. This ensures the +/// cache prefix is fixed for optimal KV cache hit rate. /// /// ```ignore -/// use llm_worker::{Worker, Item}; -/// -/// // Create a Worker and register tools /// let mut worker = Worker::new(client) /// .system_prompt("You are a helpful assistant."); /// worker.register_tool(my_tool); /// -/// // Run the interaction -/// let history = worker.run("Hello!").await?; -/// ``` +/// // Mutable::run() consumes self → Locked +/// let (mut worker, _result) = worker.run("Hello").await?; /// -/// # When Cache Protection is Needed +/// // Locked::run() borrows &mut self +/// worker.run("Follow-up").await?; /// -/// ```ignore -/// let mut worker = Worker::new(client) -/// .system_prompt("..."); -/// -/// // After setting history, lock to protect cache -/// let mut locked = worker.lock(); -/// locked.run("user input").await?; +/// // To edit between turns, unlock back to Mutable +/// let mut worker = worker.unlock(); +/// worker.history_mut().truncate(5); +/// let (mut worker, _result) = worker.run("Continue").await?; /// ``` pub struct Worker { /// LLM client @@ -144,7 +128,7 @@ pub struct Worker { system_prompt: Option, /// Item 
history (owned by Worker) history: Vec, - /// History length at lock time (only meaningful in CacheLocked state) + /// History length at lock time (only meaningful in Locked state) locked_prefix_len: usize, /// Turn count turn_count: usize, @@ -167,40 +151,12 @@ pub struct Worker { _state: PhantomData, } -// ============================================================================= -// Common Implementation (available in all states) -// ============================================================================= impl Worker { fn reset_interruption_state(&mut self) { self.last_run_interrupted = false; } - /// Execute a turn - /// - /// Adds a new user message to history and sends a request to the LLM. - /// Automatically loops if there are tool calls. - pub async fn run( - &mut self, - user_input: impl Into, - ) -> Result { - self.reset_interruption_state(); - // Interceptor: on_prompt_submit - let mut user_item = Item::user_message(user_input); - match self.interceptor.on_prompt_submit(&mut user_item).await { - PromptAction::Cancel(reason) => { - self.last_run_interrupted = true; - return self - .finalize_interruption(Err(WorkerError::Aborted(reason))) - .await; - } - PromptAction::Continue => {} - } - self.history.push(user_item); - let result = self.run_turn_loop().await; - self.finalize_interruption(result).await - } - fn drain_cancel_queue(&mut self) { use tokio::sync::mpsc::error::TryRecvError; loop { @@ -892,19 +848,8 @@ impl Worker { } } - /// Resume execution (from Paused state) - /// - /// Resumes turn processing from current state without adding a new user message to history. 
- pub async fn resume(&mut self) -> Result { - self.reset_interruption_state(); - let result = self.run_turn_loop().await; - self.finalize_interruption(result).await - } } -// ============================================================================= -// Mutable State-Specific Implementation -// ============================================================================= impl Worker { /// Create a new Worker (in Mutable state) @@ -941,43 +886,21 @@ impl Worker { } } - /// Register a tool + /// Register a tool factory for deferred initialization. /// - /// Registered tools are automatically executed when called by the LLM. - /// Registering a tool with the same name will result in an error. - /// - /// Available only in Mutable state. - pub fn register_tool( - &mut self, - factory: WorkerToolDefinition, - ) -> Result<(), ToolRegistryError> { - match self.tool_server.register_tool(factory) { - Ok(()) => Ok(()), - Err(ToolServerError::DuplicateName(name)) => { - Err(ToolRegistryError::DuplicateName(name)) - } - Err(ToolServerError::ToolNotFound(_) | ToolServerError::ToolExecution(_)) => { - unreachable!("register_tool should only fail with DuplicateName") - } - } + /// The factory is queued and executed at the next `run()` or `resume()` call. + /// Duplicate name detection occurs at that point; a duplicate name is a + /// programming error and panics (see `lock()`). + pub fn register_tool(&mut self, factory: WorkerToolDefinition) { + self.tool_server.register_tool(factory); } - /// Register multiple tools - /// - /// Available only in Mutable state. + /// Register multiple tool factories for deferred initialization.
pub fn register_tools( &mut self, factories: impl IntoIterator, - ) -> Result<(), ToolRegistryError> { - match self.tool_server.register_tools(factories) { - Ok(()) => Ok(()), - Err(ToolServerError::DuplicateName(name)) => { - Err(ToolRegistryError::DuplicateName(name)) - } - Err(ToolServerError::ToolNotFound(_) | ToolServerError::ToolExecution(_)) => { - unreachable!("register_tools should only fail with DuplicateName") - } - } + ) { + self.tool_server.register_tools(factories); } /// Set system prompt (builder pattern) @@ -1082,40 +1005,47 @@ impl Worker { /// Get a mutable reference to history /// - /// Available only in Mutable state. History can be freely edited. + /// Available only in Mutable state. pub fn history_mut(&mut self) -> &mut Vec { + &mut self.history } /// Set history pub fn set_history(&mut self, items: Vec) { + self.history = items; } /// Add an item to history (builder pattern) pub fn with_item(mut self, item: Item) -> Self { + self.history.push(item); self } /// Add an item to history pub fn push_item(&mut self, item: Item) { + self.history.push(item); } /// Add multiple items to history (builder pattern) pub fn with_items(mut self, items: impl IntoIterator) -> Self { + self.history.extend(items); self } /// Add multiple items to history pub fn extend_history(&mut self, items: impl IntoIterator) { + self.history.extend(items); } /// Clear history pub fn clear_history(&mut self) { + self.history.clear(); } @@ -1148,11 +1078,48 @@ impl Worker { self } - /// Lock and transition to CacheLocked state + /// Execute a turn, consuming self and transitioning to Locked. /// - /// This operation fixes the current system prompt and history as a "committed prefix". - /// After this, only appending to history is allowed, ensuring cache hits. - pub fn lock(self) -> Worker { + /// This is the primary entry point for first use. Equivalent to + /// `self.lock()` followed by `locked.run(user_input)`. 
+ /// + /// Subsequent runs can use [`Worker::run()`] directly. + /// To edit state between turns, call [`unlock()`](Worker::unlock) first. + pub async fn run( + self, + user_input: impl Into, + ) -> Result<(Worker, WorkerResult), WorkerError> { + let mut locked = self.lock(); + let result = locked.run(user_input).await?; + Ok((locked, result)) + } + + /// Resume from Paused, consuming self and transitioning to Locked. + /// + /// Used after `unlock()` → edit → resume. + pub async fn resume( + self, + ) -> Result<(Worker, WorkerResult), WorkerError> { + let mut locked = self.lock(); + let result = locked.resume().await?; + Ok((locked, result)) + } + + /// Lock and transition to Locked state + /// + /// Flushes pending tool factories, then fixes the current system prompt + /// and history as a "committed prefix". After this, only `run()` / `resume()` + /// may append to history, ensuring cache hits. + /// + /// Most callers should use [`run()`](Self::run) instead, which calls + /// this internally. Use `lock()` directly only when you need the + /// `Locked` worker back on error (e.g. in a persistence layer). + /// + /// # Panics + /// + /// Panics if a pending tool factory produces a duplicate name. + pub fn lock(self) -> Worker { + self.tool_server.flush_pending(); let locked_prefix_len = self.history.len(); Worker { client: self.client, @@ -1178,11 +1145,42 @@ impl Worker { } } -// ============================================================================= -// CacheLocked State-Specific Implementation -// ============================================================================= -impl Worker { +impl Worker { + /// Execute a turn + /// + /// Adds a new user message to history and sends a request to the LLM. + /// Automatically loops if there are tool calls. 
+ pub async fn run( + &mut self, + user_input: impl Into, + ) -> Result { + self.reset_interruption_state(); + // Interceptor: on_prompt_submit + let mut user_item = Item::user_message(user_input); + match self.interceptor.on_prompt_submit(&mut user_item).await { + PromptAction::Cancel(reason) => { + self.last_run_interrupted = true; + return self + .finalize_interruption(Err(WorkerError::Aborted(reason))) + .await; + } + PromptAction::Continue => {} + } + self.history.push(user_item); + let result = self.run_turn_loop().await; + self.finalize_interruption(result).await + } + + /// Resume execution (from Paused state) + /// + /// Resumes turn processing from current state without adding a new user message. + pub async fn resume(&mut self) -> Result { + self.reset_interruption_state(); + let result = self.run_turn_loop().await; + self.finalize_interruption(result).await + } + /// Get the prefix length at lock time pub fn locked_prefix_len(&self) -> usize { self.locked_prefix_len diff --git a/crates/llm-worker/tests/callback_test.rs b/crates/llm-worker/tests/callback_test.rs index c1c4b71a..1bb972a7 100644 --- a/crates/llm-worker/tests/callback_test.rs +++ b/crates/llm-worker/tests/callback_test.rs @@ -46,8 +46,9 @@ async fn test_callback_text_block_events() { }); }); + // Mutable::run consumes self, returns (Locked, WorkerResult) let result = worker.run("Greet me").await; - assert!(result.is_ok(), "Worker should complete: {:?}", result); + assert!(result.is_ok(), "Worker should complete"); let deltas = text_deltas.lock().unwrap(); assert_eq!(deltas.len(), 2); @@ -91,6 +92,7 @@ async fn test_callback_tool_call_complete() { }); }); + // Mutable::run consumes self, returns (Locked, WorkerResult) let _ = worker.run("Weather please").await; let starts = tool_starts.lock().unwrap(); @@ -133,6 +135,7 @@ async fn test_callback_turn_events() { ends.lock().unwrap().push(turn); }); + // Mutable::run consumes self, returns (Locked, WorkerResult) let result = worker.run("Do 
something").await; assert!(result.is_ok()); @@ -169,6 +172,7 @@ async fn test_callback_usage_events() { usages.lock().unwrap().push(event.clone()); }); + // Mutable::run consumes self, returns (Locked, WorkerResult) let _ = worker.run("Hello").await; let usages = usage_events.lock().unwrap(); diff --git a/crates/llm-worker/tests/compile_fail.rs b/crates/llm-worker/tests/compile_fail.rs index fa876c29..5d788296 100644 --- a/crates/llm-worker/tests/compile_fail.rs +++ b/crates/llm-worker/tests/compile_fail.rs @@ -1,6 +1,6 @@ #[test] fn compile_fail_state_constraints() { let t = trybuild::TestCases::new(); - t.compile_fail("tests/ui/cache_locked_register_tool.rs"); + t.compile_fail("tests/ui/locked_register_tool.rs"); t.compile_fail("tests/ui/tool_server_handle_register_tool.rs"); } diff --git a/crates/llm-worker/tests/parallel_execution_test.rs b/crates/llm-worker/tests/parallel_execution_test.rs index 65bff220..c63c35ea 100644 --- a/crates/llm-worker/tests/parallel_execution_test.rs +++ b/crates/llm-worker/tests/parallel_execution_test.rs @@ -102,11 +102,12 @@ async fn test_parallel_tool_execution() { let tool2_clone = tool2.clone(); let tool3_clone = tool3.clone(); - worker.register_tool(tool1.definition()).unwrap(); - worker.register_tool(tool2.definition()).unwrap(); - worker.register_tool(tool3.definition()).unwrap(); + worker.register_tool(tool1.definition()); + worker.register_tool(tool2.definition()); + worker.register_tool(tool3.definition()); let start = Instant::now(); + // Mutable::run consumes self, returns (Locked, WorkerResult) let _result = worker.run("Run all tools").await; let elapsed = start.elapsed(); @@ -150,8 +151,8 @@ async fn test_before_tool_call_skip() { let allowed_clone = allowed_tool.clone(); let blocked_clone = blocked_tool.clone(); - worker.register_tool(allowed_tool.definition()).unwrap(); - worker.register_tool(blocked_tool.definition()).unwrap(); + worker.register_tool(allowed_tool.definition()); + 
worker.register_tool(blocked_tool.definition()); // Policy to skip "blocked_tool" struct BlockingPolicy; @@ -169,6 +170,7 @@ async fn test_before_tool_call_skip() { worker.set_interceptor(BlockingPolicy); + // Mutable::run consumes self, returns (Locked, WorkerResult) let _result = worker.run("Test hook").await; // allowed_tool is called, but blocked_tool is not @@ -230,7 +232,7 @@ async fn test_post_tool_call_modification() { }) } - worker.register_tool(simple_tool_definition()).unwrap(); + worker.register_tool(simple_tool_definition()); // Policy to modify results struct ModifyingPolicy { @@ -251,9 +253,10 @@ async fn test_post_tool_call_modification() { modified_content: modified_content.clone(), }); + // Mutable::run consumes self, returns (Locked, WorkerResult) let result = worker.run("Test modification").await; - assert!(result.is_ok(), "Worker should complete: {:?}", result); + assert!(result.is_ok(), "Worker should complete"); // Verify hook was called and content was modified let content = modified_content.lock().unwrap().clone(); diff --git a/crates/llm-worker/tests/ui/cache_locked_register_tool.rs b/crates/llm-worker/tests/ui/locked_register_tool.rs similarity index 100% rename from crates/llm-worker/tests/ui/cache_locked_register_tool.rs rename to crates/llm-worker/tests/ui/locked_register_tool.rs diff --git a/crates/llm-worker/tests/ui/cache_locked_register_tool.stderr b/crates/llm-worker/tests/ui/locked_register_tool.stderr similarity index 63% rename from crates/llm-worker/tests/ui/cache_locked_register_tool.stderr rename to crates/llm-worker/tests/ui/locked_register_tool.stderr index 0c3b097f..accf7c59 100644 --- a/crates/llm-worker/tests/ui/cache_locked_register_tool.stderr +++ b/crates/llm-worker/tests/ui/locked_register_tool.stderr @@ -1,8 +1,8 @@ -error[E0599]: no method named `register_tool` found for struct `Worker` in the current scope - --> tests/ui/cache_locked_register_tool.rs:10:20 +error[E0599]: no method named `register_tool` found for 
struct `Worker` in the current scope + --> tests/ui/locked_register_tool.rs:10:20 | 10 | let _ = locked.register_tool(def); - | ^^^^^^^^^^^^^ method not found in `Worker` + | ^^^^^^^^^^^^^ method not found in `Worker` | = note: the method was found for - `Worker` diff --git a/crates/llm-worker/tests/ui/tool_server_handle_register_tool.stderr b/crates/llm-worker/tests/ui/tool_server_handle_register_tool.stderr index 57d71392..9410713d 100644 --- a/crates/llm-worker/tests/ui/tool_server_handle_register_tool.stderr +++ b/crates/llm-worker/tests/ui/tool_server_handle_register_tool.stderr @@ -1,13 +1,10 @@ error[E0624]: method `register_tool` is private --> tests/ui/tool_server_handle_register_tool.rs:10:20 | -10 | let _ = handle.register_tool(def); - | ^^^^^^^^^^^^^ private method +10 | let _ = handle.register_tool(def); + | ^^^^^^^^^^^^^ private method | ::: src/tool_server.rs | - | / pub(crate) fn register_tool( - | | &self, - | | factory: WorkerToolDefinition, - | | ) -> Result<(), ToolServerError> { - | |____________________________________- private method defined here + | pub(crate) fn register_tool(&self, factory: WorkerToolDefinition) { + | ----------------------------------------------------------------- private method defined here diff --git a/crates/llm-worker/tests/worker_fixtures.rs b/crates/llm-worker/tests/worker_fixtures.rs index 6f28752a..3f999068 100644 --- a/crates/llm-worker/tests/worker_fixtures.rs +++ b/crates/llm-worker/tests/worker_fixtures.rs @@ -129,9 +129,9 @@ async fn test_worker_simple_text_response() { } let client = MockLlmClient::from_fixture(&fixture_path).unwrap(); - let mut worker = Worker::new(client); + let worker = Worker::new(client); - // Send a simple message + // Send a simple message (Mutable::run consumes self, returns tuple) let result = worker.run("Hello").await; assert!(result.is_ok(), "Worker should complete successfully"); @@ -156,9 +156,9 @@ async fn test_worker_tool_call() { // Register tool let weather_tool = 
MockWeatherTool::new(); let tool_for_check = weather_tool.clone(); - worker.register_tool(weather_tool.definition()).unwrap(); + worker.register_tool(weather_tool.definition()); - // Send message + // Send message (Mutable::run consumes self, returns tuple) let _result = worker.run("What's the weather in Tokyo?").await; // Verify tool was called @@ -190,8 +190,9 @@ async fn test_worker_with_programmatic_events() { ]; let client = MockLlmClient::new(events); - let mut worker = Worker::new(client); + let worker = Worker::new(client); + // Mutable::run consumes self, returns tuple let result = worker.run("Greet me").await; assert!(result.is_ok(), "Worker should complete successfully"); diff --git a/crates/llm-worker/tests/worker_state_test.rs b/crates/llm-worker/tests/worker_state_test.rs index 44de83b5..8828767c 100644 --- a/crates/llm-worker/tests/worker_state_test.rs +++ b/crates/llm-worker/tests/worker_state_test.rs @@ -1,6 +1,6 @@ //! Worker state management tests //! -//! Tests for state transitions using the Type-state pattern (Mutable/CacheLocked) +//! Tests for state transitions using the Type-state pattern (Mutable/Locked) //! and state preservation between turns. 
mod common; @@ -11,7 +11,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use async_trait::async_trait; use common::MockLlmClient; use llm_worker::Item; -use llm_worker::Worker; +use llm_worker::{Worker, WorkerError}; use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent}; use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta}; @@ -147,15 +147,15 @@ fn test_mutable_can_register_tool() { let mut worker = Worker::new(client); let tool = CountingTool::new("count_tool"); - let result = worker.register_tool(tool.definition()); - assert!(result.is_ok(), "Mutable should allow tool registration"); + // register_tool is infallible (factory deferred to run-time flush) + worker.register_tool(tool.definition()); } // ============================================================================= // State Transition Tests // ============================================================================= -/// Verify that lock() transitions from Mutable -> CacheLocked state +/// Verify that lock() transitions from Mutable -> Locked state #[test] fn test_lock_transition() { let client = MockLlmClient::new(vec![]); @@ -168,13 +168,13 @@ fn test_lock_transition() { // Lock let locked_worker = worker.lock(); - // History and system prompt are still accessible in CacheLocked state + // History and system prompt are still accessible in Locked state assert_eq!(locked_worker.get_system_prompt(), Some("System")); assert_eq!(locked_worker.history().len(), 2); assert_eq!(locked_worker.locked_prefix_len(), 2); } -/// Verify that unlock() transitions from CacheLocked -> Mutable state +/// Verify that unlock() transitions from Locked -> Mutable state #[test] fn test_unlock_transition() { let client = MockLlmClient::new(vec![]); @@ -198,7 +198,7 @@ fn test_unlock_transition() { /// Verify that history is correctly updated after running a turn in Mutable state #[tokio::test] -async fn test_mutable_run_updates_history() { +async fn test_mutable_run_updates_history() -> 
Result<(), WorkerError> { let events = vec![ Event::text_block_start(0), Event::text_delta(0, "Hello, I'm an assistant!"), @@ -209,11 +209,10 @@ async fn test_mutable_run_updates_history() { ]; let client = MockLlmClient::new(events); - let mut worker = Worker::new(client); + let worker = Worker::new(client); - // Execute - let result = worker.run("Hi there").await; - assert!(result.is_ok()); + // Execute (Mutable::run consumes self, returns (Locked, WorkerResult)) + let (worker, _result) = worker.run("Hi there").await?; // History is updated let history = worker.history(); @@ -224,9 +223,11 @@ async fn test_mutable_run_updates_history() { // Assistant message assert_eq!(history[1].as_text(), Some("Hello, I'm an assistant!")); + + Ok(()) } -/// Verify that history accumulates correctly over multiple turns in CacheLocked state +/// Verify that history accumulates correctly over multiple turns in Locked state #[tokio::test] async fn test_locked_multi_turn_history_accumulation() { // Prepare responses for 2 requests @@ -327,7 +328,7 @@ async fn test_locked_prefix_len_tracking() { /// Verify that turn count is correctly incremented #[tokio::test] -async fn test_turn_count_increment() { +async fn test_turn_count_increment() -> Result<(), WorkerError> { let client = MockLlmClient::with_responses(vec![ vec![ Event::text_block_start(0), @@ -347,15 +348,19 @@ async fn test_turn_count_increment() { ], ]); - let mut worker = Worker::new(client); + let worker = Worker::new(client); assert_eq!(worker.turn_count(), 0); - worker.run("First").await.unwrap(); + // First run consumes Mutable, returns Locked + let (mut worker, _) = worker.run("First").await?; assert_eq!(worker.turn_count(), 1); - worker.run("Second").await.unwrap(); + // Subsequent runs on Locked take &mut self + worker.run("Second").await?; assert_eq!(worker.turn_count(), 2); + + Ok(()) } /// Verify that history can be edited after unlock and re-locked @@ -430,9 +435,7 @@ async fn 
test_lock_unlock_relock_tools_remain_effective() { let mut worker = Worker::new(client); let tool_a = CountingTool::new("tool_a"); - worker - .register_tool(tool_a.definition()) - .expect("register tool_a should succeed"); + worker.register_tool(tool_a.definition()); let mut locked = worker.lock(); locked.run("first").await.expect("first run"); @@ -440,9 +443,7 @@ async fn test_lock_unlock_relock_tools_remain_effective() { let mut unlocked = locked.unlock(); let tool_b = CountingTool::new("tool_b"); - unlocked - .register_tool(tool_b.definition()) - .expect("register tool_b after unlock should succeed"); + unlocked.register_tool(tool_b.definition()); let mut relocked = unlocked.lock(); relocked.run("second").await.expect("second run"); @@ -455,7 +456,7 @@ async fn test_lock_unlock_relock_tools_remain_effective() { // System Prompt Preservation Tests // ============================================================================= -/// Verify that system prompt is preserved in CacheLocked state +/// Verify that system prompt is preserved in Locked state #[test] fn test_system_prompt_preserved_in_locked_state() { let client = MockLlmClient::new(vec![]); diff --git a/crates/pod/examples/pod_cli.rs b/crates/pod/examples/pod_cli.rs index d7cf457f..8b4e4b82 100644 --- a/crates/pod/examples/pod_cli.rs +++ b/crates/pod/examples/pod_cli.rs @@ -53,7 +53,7 @@ async fn main() -> Result<(), Box> { } // 5. 
Extract the assistant's reply from history - let history = pod.session_mut().worker.history(); + let history = pod.session_mut().worker().history(); if let Some(text) = history .iter() .rev() diff --git a/crates/pod/src/controller.rs b/crates/pod/src/controller.rs index bd8478d8..72af0182 100644 --- a/crates/pod/src/controller.rs +++ b/crates/pod/src/controller.rs @@ -85,7 +85,7 @@ impl PodController { // Register event bridge callbacks on the worker { - let worker = &mut pod.session_mut().worker; + let worker = pod.session_mut().worker_mut(); let tx = event_tx.clone(); worker.on_turn_start(move |turn| { @@ -158,7 +158,7 @@ impl PodController { } // Clone cancel sender before moving pod - let cancel_tx = pod.session_mut().worker.cancel_sender(); + let cancel_tx = pod.session_mut().worker_mut().cancel_sender(); tokio::spawn(async move { // Hold socket server alive for the lifetime of the controller task @@ -191,7 +191,7 @@ impl PodController { ) .await; - let items = pod.session_mut().worker.history().to_vec(); + let items = pod.session_mut().worker_mut().history().to_vec(); shared_state.update_history(items); shared_state.set_status(new_status); let _ = runtime_dir.write_status(&shared_state).await; @@ -218,7 +218,7 @@ impl PodController { ) .await; - let items = pod.session_mut().worker.history().to_vec(); + let items = pod.session_mut().worker_mut().history().to_vec(); shared_state.update_history(items); shared_state.set_status(new_status); let _ = runtime_dir.write_status(&shared_state).await; diff --git a/crates/pod/src/pod.rs b/crates/pod/src/pod.rs index f593416e..508e38f4 100644 --- a/crates/pod/src/pod.rs +++ b/crates/pod/src/pod.rs @@ -144,7 +144,7 @@ impl Pod { let builder = std::mem::take(&mut self.hook_builder); let registry = Arc::new(builder.build()); let interceptor = HookInterceptor::new(registry); - self.session.worker.set_interceptor(interceptor); + self.session.worker_mut().set_interceptor(interceptor); self.interceptor_installed = true; } } diff 
--git a/tickets/worker-auto-lock.md b/tickets/worker-auto-lock.md deleted file mode 100644 index 290ffbf6..00000000 --- a/tickets/worker-auto-lock.md +++ /dev/null @@ -1,57 +0,0 @@ -# Worker: run() 時の自動キャッシュロックと ToolDefinition ファクトリ遅延初期化 - -## 背景 - -現状の `Worker` は Type-state パターンで `Mutable` / `CacheLocked` の2状態を持つが、 -`lock()` を呼ばなくても `run()` できてしまうため、キャッシュ保護の存在を知らないユーザーは -常に非最適パスを通ることになる。 - -また `register_tool()` 時に `ToolDefinition` のファクトリクロージャが即時呼び出しされており、 -本来の意図である遅延初期化になっていない。 - -## 方針 - -### run() 時の自動ロック - -`run()` の冒頭で、`Mutable` 状態なら自動的に `CacheLocked` へ遷移する。 -これにより lock を知らないユーザーでも嫌でもキャッシュ保護される。 - -ターンの合間に history や system prompt を編集したい場合は、明示的に `unlock()` を挟む。 -次の `run()` で再び自動ロックされる。 - -```rust -let mut worker = Worker::new(client); -worker.set_system_prompt("..."); -worker.register_tool(my_tool)?; - -// Mutable のまま run() → 自動で lock される -worker.run("Hello").await?; - -// ターン間で内容を弄りたい場合 -worker.unlock(); -worker.history_mut().truncate(5); -// 次の run() で再 lock -worker.run("Continue").await?; -``` - -#### 設計ポイント - -- `run()` が `&mut self` を取る以上、内部で状態遷移しても外部の型は変わらない。 - 実装は内部フラグ(`is_locked: bool`)で管理し、`Mutable` / `CacheLocked` の - type-state はそのまま維持する。`Worker` の `run()` が内部で lock 相当の - 処理を行い、`unlock()` が呼ばれるまでキャッシュ破壊的な操作を(ランタイムで)ブロックする -- `Worker` は従来どおり。明示的に `lock()` してから使うパスも残る -- interceptor, max_turns, callbacks 等キャッシュに影響しない設定は lock 状態でも自由に変更可能 - -### ToolDefinition ファクトリの遅延初期化 - -`register_tool()` は定義を蓄積するだけにし、ファクトリの実行を初回 `run()` まで遅延させる。 - -- `register_tool()` → `ToolDefinition` を Vec に push するだけ -- 初回 `run()` の自動 lock 時にファクトリを一括実行し、ToolServer を構築 -- `unlock()` 後に追加登録された tool は次の `run()` で初期化 - -## 移行 - -- `lock()` / `unlock()` は引き続き使える(明示的なキャッシュ管理用) -- `Worker::new()` → `run()` のパスが自動保護されるため、既存コードは変更不要