Workerの自動キャッシュロック

This commit is contained in:
Keisuke Hirata 2026-04-11 18:47:33 +09:00
parent f241dafac8
commit 9b78c51d0a
23 changed files with 375 additions and 352 deletions

View File

@ -3,7 +3,7 @@
- [x] ツール出力の遅延読み込み設計 (ToolOutput / BlobStore / auto_summarize) - [x] ツール出力の遅延読み込み設計 (ToolOutput / BlobStore / auto_summarize)
- [ ] ツール設計 - [ ] ツール設計
- [ ] ツールの動的追加/削除 → [tickets/tool-dynamic-registry.md](tickets/tool-dynamic-registry.md) - [ ] ツールの動的追加/削除 → [tickets/tool-dynamic-registry.md](tickets/tool-dynamic-registry.md)
- [ ] run() 自動ロックとファクトリ遅延初期化 → [tickets/worker-auto-lock.md](tickets/worker-auto-lock.md) - [x] run() 自動ロックとファクトリ遅延初期化 → [tickets/worker-auto-lock.md](tickets/worker-auto-lock.md)
- [x] inspect ツール実装 - [x] inspect ツール実装
- [x] max_turns: マニフェストによるターン数制限 - [x] max_turns: マニフェストによるターン数制限
- [x] pod バイナリエントリポイント - [x] pod バイナリエントリポイント

View File

@ -13,7 +13,6 @@ use serde_json::json;
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta}; use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
use llm_worker::state::Mutable; use llm_worker::state::Mutable;
use llm_worker::ToolRegistryError;
use llm_worker::Worker; use llm_worker::Worker;
use llm_worker::llm_client::LlmClient; use llm_worker::llm_client::LlmClient;
@ -332,12 +331,11 @@ fn apply_selector(content: &Content, selector: &Selector) -> Result<String, Tool
pub fn register_inspect_tool<C, B>( pub fn register_inspect_tool<C, B>(
worker: &mut Worker<C, Mutable>, worker: &mut Worker<C, Mutable>,
blob_store: Arc<B>, blob_store: Arc<B>,
) -> Result<(), ToolRegistryError> ) where
where
C: LlmClient, C: LlmClient,
B: BlobStore + 'static, B: BlobStore + 'static,
{ {
worker.register_tool(InspectTool::<B>::tool_definition(blob_store)) worker.register_tool(InspectTool::<B>::tool_definition(blob_store));
} }
// ─── Tests ─────────────────────────────────────────────────────────────────── // ─── Tests ───────────────────────────────────────────────────────────────────

View File

@ -43,12 +43,13 @@ pub enum SessionError {
/// Persistent session wrapping a [`Worker`]. /// Persistent session wrapping a [`Worker`].
/// ///
/// The `worker` field is public for direct access to Worker APIs /// Use [`worker()`](Self::worker) / [`worker_mut()`](Self::worker_mut) to
/// (tool registration, hook setup, subscriber management, etc.). /// access the underlying Worker for configuration (tool registration, etc.).
/// State-mutating operations (`run`, `resume`) should go through /// State-mutating operations (`run`, `resume`) should go through Session
/// Session methods to ensure proper logging. /// methods to ensure proper logging.
pub struct Session<C: LlmClient, St: Store> { pub struct Session<C: LlmClient, St: Store> {
pub worker: Worker<C, Mutable>, /// Always `Some` outside of `run()` / `resume()`.
worker: Option<Worker<C, Mutable>>,
store: St, store: St,
session_id: SessionId, session_id: SessionId,
head_hash: Option<EntryHash>, head_hash: Option<EntryHash>,
@ -78,7 +79,7 @@ impl<C: LlmClient, St: Store> Session<C, St> {
store.append(session_id, &hashed_entry).await?; store.append(session_id, &hashed_entry).await?;
Ok(Self { Ok(Self {
worker, worker: Some(worker),
store, store,
session_id, session_id,
head_hash: Some(hashed), head_hash: Some(hashed),
@ -87,9 +88,6 @@ impl<C: LlmClient, St: Store> Session<C, St> {
} }
/// Restore a session from a stored log. /// Restore a session from a stored log.
///
/// Reads all log entries, collects state from them,
/// and returns a `Session` ready for `resume()`.
pub async fn restore( pub async fn restore(
client: C, client: C,
store: St, store: St,
@ -109,7 +107,7 @@ impl<C: LlmClient, St: Store> Session<C, St> {
worker.set_last_run_interrupted(state.last_run_interrupted); worker.set_last_run_interrupted(state.last_run_interrupted);
Ok(Self { Ok(Self {
worker, worker: Some(worker),
store, store,
session_id, session_id,
head_hash: state.head_hash, head_hash: state.head_hash,
@ -117,6 +115,20 @@ impl<C: LlmClient, St: Store> Session<C, St> {
}) })
} }
/// Shared borrow of the wrapped Worker for internal read-only use.
///
/// # Panics
///
/// Panics with "worker taken during run" if the `worker` slot is `None`.
fn w(&self) -> &Worker<C, Mutable> {
self.worker.as_ref().expect("worker taken during run")
}
/// Reference to the underlying Worker.
///
/// # Panics
///
/// Panics with "worker taken during run" if the `worker` slot is `None`
/// (the check is performed by the private `w()` helper this delegates to).
pub fn worker(&self) -> &Worker<C, Mutable> {
self.w()
}
/// Mutable reference to the underlying Worker.
///
/// # Panics
///
/// Panics with "worker taken during run" if the `worker` slot is `None`.
pub fn worker_mut(&mut self) -> &mut Worker<C, Mutable> {
self.worker.as_mut().expect("worker taken during run")
}
/// The session ID. /// The session ID.
pub fn session_id(&self) -> SessionId { pub fn session_id(&self) -> SessionId {
self.session_id self.session_id
@ -133,15 +145,23 @@ impl<C: LlmClient, St: Store> Session<C, St> {
} }
/// Run a user turn, logging all state changes. /// Run a user turn, logging all state changes.
///
/// Internally locks the Worker (flushing pending tools), runs the turn,
/// then unlocks back to Mutable state.
pub async fn run( pub async fn run(
&mut self, &mut self,
user_input: impl Into<String>, user_input: impl Into<String>,
) -> Result<WorkerResult, SessionError> { ) -> Result<WorkerResult, SessionError> {
let input = user_input.into();
self.ensure_head_or_fork().await?; self.ensure_head_or_fork().await?;
let history_before = self.worker.history().len(); let history_before = self.w().history().len();
let result = self.worker.run(user_input).await; // lock → run → unlock (use lock() directly to keep worker on error)
let worker = self.worker.take().expect("worker taken during run");
let mut locked = worker.lock();
let result = locked.run(input).await;
self.worker = Some(locked.unlock());
self.log_history_delta(history_before).await?; self.log_history_delta(history_before).await?;
self.log_turn_end().await?; self.log_turn_end().await?;
@ -154,9 +174,13 @@ impl<C: LlmClient, St: Store> Session<C, St> {
pub async fn resume(&mut self) -> Result<WorkerResult, SessionError> { pub async fn resume(&mut self) -> Result<WorkerResult, SessionError> {
self.ensure_head_or_fork().await?; self.ensure_head_or_fork().await?;
let history_before = self.worker.history().len(); let history_before = self.w().history().len();
let result = self.worker.resume().await; // lock → resume → unlock
let worker = self.worker.take().expect("worker taken during run");
let mut locked = worker.lock();
let result = locked.resume().await;
self.worker = Some(locked.unlock());
self.log_history_delta(history_before).await?; self.log_history_delta(history_before).await?;
self.log_turn_end().await?; self.log_turn_end().await?;
@ -166,15 +190,13 @@ impl<C: LlmClient, St: Store> Session<C, St> {
} }
/// Fork this session at its current state. /// Fork this session at its current state.
/// Returns the new session ID. The new log contains a `SessionStart`
/// seeded with the current history.
pub async fn fork(&self) -> Result<SessionId, StoreError> { pub async fn fork(&self) -> Result<SessionId, StoreError> {
let fork_id = crate::new_session_id(); let fork_id = crate::new_session_id();
let entry = LogEntry::SessionStart { let entry = LogEntry::SessionStart {
ts: session_log::now_millis(), ts: session_log::now_millis(),
system_prompt: self.worker.get_system_prompt().map(String::from), system_prompt: self.w().get_system_prompt().map(String::from),
config: self.worker.request_config().clone(), config: self.w().request_config().clone(),
history: self.worker.history().to_vec(), history: self.w().history().to_vec(),
}; };
let hashed = session_log::compute_hash(None, &entry); let hashed = session_log::compute_hash(None, &entry);
let hashed_entry = HashedEntry { let hashed_entry = HashedEntry {
@ -189,8 +211,6 @@ impl<C: LlmClient, St: Store> Session<C, St> {
} }
/// Fork from an arbitrary point in a stored session's log. /// Fork from an arbitrary point in a stored session's log.
/// Finds the entry matching `at_hash` and creates a new session
/// with state reconstructed up to that point.
pub async fn fork_at( pub async fn fork_at(
store: &St, store: &St,
source_id: SessionId, source_id: SessionId,
@ -221,12 +241,12 @@ impl<C: LlmClient, St: Store> Session<C, St> {
Ok(fork_id) Ok(fork_id)
} }
/// Log a `CacheLocked` entry. /// Log a `Locked` entry.
pub async fn log_cache_locked( pub async fn log_cache_locked(
&mut self, &mut self,
locked_prefix_len: usize, locked_prefix_len: usize,
) -> Result<(), StoreError> { ) -> Result<(), StoreError> {
let entry = LogEntry::CacheLocked { let entry = LogEntry::Locked {
ts: session_log::now_millis(), ts: session_log::now_millis(),
locked_prefix_len, locked_prefix_len,
}; };
@ -245,14 +265,13 @@ impl<C: LlmClient, St: Store> Session<C, St> {
pub async fn log_config_changed(&mut self) -> Result<(), StoreError> { pub async fn log_config_changed(&mut self) -> Result<(), StoreError> {
let entry = LogEntry::ConfigChanged { let entry = LogEntry::ConfigChanged {
ts: session_log::now_millis(), ts: session_log::now_millis(),
config: self.worker.request_config().clone(), config: self.w().request_config().clone(),
}; };
self.append_entry(entry).await self.append_entry(entry).await
} }
// ── Private helpers ────────────────────────────────────────────────── // ── Private helpers ──────────────────────────────────────────────────
/// Append a `LogEntry`, computing its hash and updating `head_hash`.
async fn append_entry(&mut self, entry: LogEntry) -> Result<(), StoreError> { async fn append_entry(&mut self, entry: LogEntry) -> Result<(), StoreError> {
let hash = session_log::compute_hash(self.head_hash.as_ref(), &entry); let hash = session_log::compute_hash(self.head_hash.as_ref(), &entry);
let hashed_entry = HashedEntry { let hashed_entry = HashedEntry {
@ -267,19 +286,17 @@ impl<C: LlmClient, St: Store> Session<C, St> {
Ok(()) Ok(())
} }
/// Check that the store's head still matches ours. If not, auto-fork.
async fn ensure_head_or_fork(&mut self) -> Result<(), StoreError> { async fn ensure_head_or_fork(&mut self) -> Result<(), StoreError> {
let store_head = self.store.read_head_hash(self.session_id).await?; let store_head = self.store.read_head_hash(self.session_id).await?;
if store_head == self.head_hash { if store_head == self.head_hash {
return Ok(()); return Ok(());
} }
// Another writer advanced this session — fork from our known state.
let fork_id = crate::new_session_id(); let fork_id = crate::new_session_id();
let entry = LogEntry::SessionStart { let entry = LogEntry::SessionStart {
ts: session_log::now_millis(), ts: session_log::now_millis(),
system_prompt: self.worker.get_system_prompt().map(String::from), system_prompt: self.w().get_system_prompt().map(String::from),
config: self.worker.request_config().clone(), config: self.w().request_config().clone(),
history: self.worker.history().to_vec(), history: self.w().history().to_vec(),
}; };
let hash = session_log::compute_hash(None, &entry); let hash = session_log::compute_hash(None, &entry);
let hashed_entry = HashedEntry { let hashed_entry = HashedEntry {
@ -296,7 +313,7 @@ impl<C: LlmClient, St: Store> Session<C, St> {
} }
async fn log_history_delta(&mut self, before_len: usize) -> Result<(), StoreError> { async fn log_history_delta(&mut self, before_len: usize) -> Result<(), StoreError> {
let history = self.worker.history(); let history = self.w().history();
if history.len() <= before_len { if history.len() <= before_len {
return Ok(()); return Ok(());
} }
@ -356,7 +373,7 @@ impl<C: LlmClient, St: Store> Session<C, St> {
async fn log_turn_end(&mut self) -> Result<(), StoreError> { async fn log_turn_end(&mut self) -> Result<(), StoreError> {
self.append_entry(LogEntry::TurnEnd { self.append_entry(LogEntry::TurnEnd {
ts: session_log::now_millis(), ts: session_log::now_millis(),
turn_count: self.worker.turn_count(), turn_count: self.w().turn_count(),
}) })
.await .await
} }
@ -376,7 +393,7 @@ impl<C: LlmClient, St: Store> Session<C, St> {
self.append_entry(LogEntry::RunOutcome { self.append_entry(LogEntry::RunOutcome {
ts: session_log::now_millis(), ts: session_log::now_millis(),
outcome, outcome,
interrupted: self.worker.last_run_interrupted(), interrupted: self.w().last_run_interrupted(),
}) })
.await .await
} }

View File

@ -88,7 +88,7 @@ pub struct HashedEntry {
/// - `SessionStart` — always the first entry; captures initial state /// - `SessionStart` — always the first entry; captures initial state
/// - `UserInput` / `AssistantItems` / `ToolResults` / `HookInjectedItems` — history appends /// - `UserInput` / `AssistantItems` / `ToolResults` / `HookInjectedItems` — history appends
/// - `TurnEnd` — turn boundary marker /// - `TurnEnd` — turn boundary marker
/// - `CacheLocked` / `CacheUnlocked` — KV cache state transitions /// - `Locked` / `CacheUnlocked` — KV cache state transitions
/// - `RunOutcome` — marks end of a `run()` or `resume()` call /// - `RunOutcome` — marks end of a `run()` or `resume()` call
/// - `ConfigChanged` — `RequestConfig` mutation /// - `ConfigChanged` — `RequestConfig` mutation
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -119,7 +119,8 @@ pub enum LogEntry {
TurnEnd { ts: u64, turn_count: usize }, TurnEnd { ts: u64, turn_count: usize },
/// KV cache locked. Records the history prefix length that is now immutable. /// KV cache locked. Records the history prefix length that is now immutable.
CacheLocked { ts: u64, locked_prefix_len: usize }, #[serde(alias = "cache_locked")]
Locked { ts: u64, locked_prefix_len: usize },
/// KV cache unlocked. /// KV cache unlocked.
CacheUnlocked { ts: u64 }, CacheUnlocked { ts: u64 },
@ -200,7 +201,7 @@ pub fn collect_state(entries: &[HashedEntry]) -> RestoredState {
LogEntry::TurnEnd { turn_count, .. } => { LogEntry::TurnEnd { turn_count, .. } => {
state.turn_count = *turn_count; state.turn_count = *turn_count;
} }
LogEntry::CacheLocked { LogEntry::Locked {
locked_prefix_len, .. locked_prefix_len, ..
} => { } => {
state.locked_prefix_len = *locked_prefix_len; state.locked_prefix_len = *locked_prefix_len;
@ -354,7 +355,7 @@ mod tests {
config: RequestConfig::default(), config: RequestConfig::default(),
history: vec![Item::user_message("a"), Item::assistant_message("b")], history: vec![Item::user_message("a"), Item::assistant_message("b")],
}, },
LogEntry::CacheLocked { LogEntry::Locked {
ts: 2000, ts: 2000,
locked_prefix_len: 2, locked_prefix_len: 2,
}, },

View File

@ -159,8 +159,8 @@ async fn session_restore_round_trip() {
session.run("Hi").await.unwrap(); session.run("Hi").await.unwrap();
let original_history = session.worker.history().to_vec(); let original_history = session.worker().history().to_vec();
let original_turn_count = session.worker.turn_count(); let original_turn_count = session.worker().turn_count();
let original_head_hash = session.head_hash().cloned(); let original_head_hash = session.head_hash().cloned();
// Restore // Restore
@ -170,10 +170,10 @@ async fn session_restore_round_trip() {
.await .await
.unwrap(); .unwrap();
assert_eq!(restored.worker.history().len(), original_history.len()); assert_eq!(restored.worker().history().len(), original_history.len());
assert_eq!(restored.worker.turn_count(), original_turn_count); assert_eq!(restored.worker().turn_count(), original_turn_count);
assert_eq!( assert_eq!(
restored.worker.get_system_prompt().map(String::from), restored.worker().get_system_prompt().map(String::from),
Some("You are helpful.".to_string()) Some("You are helpful.".to_string())
); );
assert_eq!(restored.head_hash(), original_head_hash.as_ref()); assert_eq!(restored.head_hash(), original_head_hash.as_ref());
@ -184,7 +184,7 @@ async fn session_run_with_tool_call() {
let (_dir, store) = make_store().await; let (_dir, store) = make_store().await;
let client = MockLlmClient::with_responses(tool_call_events()); let client = MockLlmClient::with_responses(tool_call_events());
let mut worker = Worker::new(client); let mut worker = Worker::new(client);
worker.register_tool(weather_tool_definition()).unwrap(); worker.register_tool(weather_tool_definition());
let mut session = Session::new(worker, store.clone(), SessionConfig::default()) let mut session = Session::new(worker, store.clone(), SessionConfig::default())
.await .await
@ -213,7 +213,7 @@ async fn session_resume_after_pause() {
// First run: tool call with pause hook → Paused // First run: tool call with pause hook → Paused
let client = MockLlmClient::with_responses(tool_call_events()); let client = MockLlmClient::with_responses(tool_call_events());
let mut worker = Worker::new(client); let mut worker = Worker::new(client);
worker.register_tool(weather_tool_definition()).unwrap(); worker.register_tool(weather_tool_definition());
worker.set_interceptor(PausePolicy); worker.set_interceptor(PausePolicy);
let mut session = Session::new(worker, store.clone(), SessionConfig::default()) let mut session = Session::new(worker, store.clone(), SessionConfig::default())
@ -251,7 +251,7 @@ async fn session_resume_after_pause() {
.await .await
.unwrap(); .unwrap();
assert!(restored.worker.last_run_interrupted()); assert!(restored.worker().last_run_interrupted());
// resume may or may not succeed depending on Worker internal state, // resume may or may not succeed depending on Worker internal state,
// but the restore itself should work // but the restore itself should work
@ -271,7 +271,7 @@ async fn session_fork_preserves_state() {
session.run("Hello").await.unwrap(); session.run("Hello").await.unwrap();
let original_history_len = session.worker.history().len(); let original_history_len = session.worker().history().len();
let fork_id = session.fork().await.unwrap(); let fork_id = session.fork().await.unwrap();
// Fork should have a SessionStart with the current history // Fork should have a SessionStart with the current history
@ -334,7 +334,7 @@ async fn session_config_changed_logged() {
// Modify config via worker and log it // Modify config via worker and log it
session session
.worker .worker_mut()
.set_request_config(RequestConfig::default().with_temperature(0.7)); .set_request_config(RequestConfig::default().with_temperature(0.7));
session.log_config_changed().await.unwrap(); session.log_config_changed().await.unwrap();
@ -367,13 +367,13 @@ async fn session_cache_lock_unlock_logged() {
let has_locked = entries.iter().any(|e| { let has_locked = entries.iter().any(|e| {
matches!( matches!(
&e.entry, &e.entry,
LogEntry::CacheLocked { LogEntry::Locked {
locked_prefix_len: 5, locked_prefix_len: 5,
.. ..
} }
) )
}); });
assert!(has_locked, "should have CacheLocked entry"); assert!(has_locked, "should have Locked entry");
let has_unlocked = entries let has_unlocked = entries
.iter() .iter()

View File

@ -4,9 +4,7 @@
use llm_worker::llm_client::providers::anthropic::AnthropicClient; use llm_worker::llm_client::providers::anthropic::AnthropicClient;
use llm_worker::{Worker, WorkerResult}; use llm_worker::{Worker, WorkerResult};
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::sync::Mutex;
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
@ -25,48 +23,38 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY environment variable not set"); std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY environment variable not set");
let client = AnthropicClient::new(&api_key, "claude-sonnet-4-20250514"); let client = AnthropicClient::new(&api_key, "claude-sonnet-4-20250514");
let worker = Arc::new(Mutex::new(Worker::new(client))); let worker = Worker::new(client);
println!("🚀 Starting Worker..."); println!("🚀 Starting Worker...");
println!("💡 Will cancel after 2 seconds\n"); println!("💡 Will cancel after 2 seconds\n");
// Get cancel sender first (without holding lock) // Get cancel sender before run (Mutable state)
let cancel_tx = { let cancel_tx = worker.cancel_sender();
let w = worker.lock().await;
w.cancel_sender()
};
// Task 1: Run Worker // Task: Cancel after 2 seconds
let worker_clone = worker.clone();
let task = tokio::spawn(async move {
let mut w = worker_clone.lock().await;
println!("📡 Sending request to LLM...");
match w.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await {
Ok(WorkerResult::Finished) => {
println!("✅ Task completed normally");
}
Ok(WorkerResult::Paused) => {
println!("⏸️ Task paused");
}
Ok(WorkerResult::LimitReached) => {
println!("🔒 Turn limit reached");
}
Err(e) => {
println!("❌ Task error: {}", e);
}
}
});
// Task 2: Cancel after 2 seconds
tokio::spawn(async move { tokio::spawn(async move {
tokio::time::sleep(Duration::from_secs(2)).await; tokio::time::sleep(Duration::from_secs(2)).await;
println!("\n🛑 Cancelling worker..."); println!("\n🛑 Cancelling worker...");
let _ = cancel_tx.send(()).await; let _ = cancel_tx.send(()).await;
}); });
// Wait for task completion println!("📡 Sending request to LLM...");
task.await?;
// Mutable::run consumes self → (Locked, WorkerResult)
match worker.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await {
Ok((_locked, WorkerResult::Finished)) => {
println!("✅ Task completed normally");
}
Ok((_locked, WorkerResult::Paused)) => {
println!("⏸️ Task paused");
}
Ok((_locked, WorkerResult::LimitReached)) => {
println!("🔒 Turn limit reached");
}
Err(e) => {
println!("❌ Task error: {}", e);
}
}
println!("\n✨ Demo complete!"); println!("\n✨ Demo complete!");

View File

@ -438,10 +438,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Register tools (unless --no-tools) // Register tools (unless --no-tools)
if !args.no_tools { if !args.no_tools {
let app = AppContext; let app = AppContext;
worker worker.register_tool(app.get_current_time_definition());
.register_tool(app.get_current_time_definition()) worker.register_tool(app.calculate_definition());
.unwrap();
worker.register_tool(app.calculate_definition()).unwrap();
} }
// Register streaming display handlers // Register streaming display handlers
@ -465,7 +463,27 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
return Ok(()); return Ok(());
} }
// Interactive loop // Interactive loop — first input transitions Mutable → Locked
print!("\n👤 You: ");
io::stdout().flush()?;
let mut first_input = String::new();
io::stdin().read_line(&mut first_input)?;
let first_input = first_input.trim();
if first_input == "quit" || first_input == "exit" || first_input.is_empty() {
println!("\n👋 Goodbye!");
return Ok(());
}
let (mut locked, _) = match worker.run(first_input).await {
Ok(pair) => pair,
Err(e) => {
eprintln!("\n❌ Error: {}", e);
return Ok(());
}
};
loop { loop {
print!("\n👤 You: "); print!("\n👤 You: ");
io::stdout().flush()?; io::stdout().flush()?;
@ -483,8 +501,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
break; break;
} }
// Run Worker (Worker manages history) match locked.run(input).await {
match worker.run(input).await {
Ok(_) => {} Ok(_) => {}
Err(e) => { Err(e) => {
eprintln!("\n❌ Error: {}", e); eprintln!("\n❌ Error: {}", e);

View File

@ -27,12 +27,14 @@
//! //!
//! # Cache Protection //! # Cache Protection
//! //!
//! To maximize KV cache hit rate, transition to the locked state //! `run()` automatically locks the cache. To edit state between turns,
//! with [`Worker::lock()`] before execution. //! call `unlock_cache()` first; the next `run()` re-locks automatically.
//! //!
//! ```ignore //! ```ignore
//! let mut locked = worker.lock(); //! worker.run("user input").await?;
//! locked.run("user input").await?; //! worker.unlock_cache();
//! worker.set_system_prompt("new prompt");
//! worker.run("next input").await?;
//! ``` //! ```
mod handler; mod handler;

View File

@ -1,7 +1,7 @@
//! Worker State //! Worker State
//! //!
//! State marker types for cache protection using the Type-state pattern. //! State marker types for cache protection using the Type-state pattern.
//! Worker has state transitions from `Mutable` → `CacheLocked`. //! Worker has state transitions from `Mutable` → `Locked`.
/// Marker trait representing Worker state /// Marker trait representing Worker state
/// ///
@ -19,7 +19,7 @@ mod private {
/// - Editing message history (add, delete, clear) /// - Editing message history (add, delete, clear)
/// - Registering tools and hooks /// - Registering tools and hooks
/// ///
/// Can transition to [`CacheLocked`] state via `Worker::lock()`. /// Can transition to [`Locked`] state via `Worker::lock()`.
/// ///
/// # Examples /// # Examples
/// ///
@ -54,7 +54,7 @@ impl WorkerState for Mutable {}
/// Can return to [`Mutable`] state via `Worker::unlock()`, /// Can return to [`Mutable`] state via `Worker::unlock()`,
/// but note that cache protection will be released. /// but note that cache protection will be released.
#[derive(Debug, Clone, Copy, Default)] #[derive(Debug, Clone, Copy, Default)]
pub struct CacheLocked; pub struct Locked;
impl private::Sealed for CacheLocked {} impl private::Sealed for Locked {}
impl WorkerState for CacheLocked {} impl WorkerState for Locked {}

View File

@ -26,6 +26,7 @@ pub enum ToolServerError {
#[derive(Clone, Default)] #[derive(Clone, Default)]
pub struct ToolServer { pub struct ToolServer {
tools: Arc<Mutex<ToolMap>>, tools: Arc<Mutex<ToolMap>>,
pending: Arc<Mutex<Vec<WorkerToolDefinition>>>,
} }
impl ToolServer { impl ToolServer {
@ -38,6 +39,7 @@ impl ToolServer {
pub fn handle(&self) -> ToolServerHandle { pub fn handle(&self) -> ToolServerHandle {
ToolServerHandle { ToolServerHandle {
tools: Arc::clone(&self.tools), tools: Arc::clone(&self.tools),
pending: Arc::clone(&self.pending),
} }
} }
} }
@ -46,32 +48,57 @@ impl ToolServer {
#[derive(Clone, Default)] #[derive(Clone, Default)]
pub struct ToolServerHandle { pub struct ToolServerHandle {
tools: Arc<Mutex<ToolMap>>, tools: Arc<Mutex<ToolMap>>,
pending: Arc<Mutex<Vec<WorkerToolDefinition>>>,
} }
impl ToolServerHandle { impl ToolServerHandle {
/// Register one tool. /// Queue a tool factory for deferred initialization.
pub(crate) fn register_tool( ///
&self, /// The factory is **not** called here; it is stored and executed
factory: WorkerToolDefinition, /// when [`flush_pending`](Self::flush_pending) is called (typically
) -> Result<(), ToolServerError> { /// at the start of `Worker::run()`).
let (meta, instance) = factory(); pub(crate) fn register_tool(&self, factory: WorkerToolDefinition) {
let mut guard = self.tools.lock().unwrap_or_else(|e| e.into_inner()); self.pending
if guard.contains_key(&meta.name) { .lock()
return Err(ToolServerError::DuplicateName(meta.name)); .unwrap_or_else(|e| e.into_inner())
} .push(factory);
guard.insert(meta.name.clone(), (meta, instance));
Ok(())
} }
/// Register many tools. /// Queue many tool factories for deferred initialization.
pub(crate) fn register_tools( pub(crate) fn register_tools(
&self, &self,
factories: impl IntoIterator<Item = WorkerToolDefinition>, factories: impl IntoIterator<Item = WorkerToolDefinition>,
) -> Result<(), ToolServerError> { ) {
for factory in factories { let mut guard = self.pending.lock().unwrap_or_else(|e| e.into_inner());
self.register_tool(factory)?; guard.extend(factories);
}
/// Execute all pending factories and register the resulting tools.
///
/// # Panics
///
/// Panics if any factory produces a tool whose name collides with
/// an already-registered tool. Duplicate names are a programming
/// error and should be caught during development.
pub(crate) fn flush_pending(&self) {
let pending: Vec<_> = {
let mut guard = self.pending.lock().unwrap_or_else(|e| e.into_inner());
std::mem::take(&mut *guard)
};
if pending.is_empty() {
return;
}
// Execute all factories first, then validate and insert atomically.
let materialized: Vec<_> = pending.into_iter().map(|f| f()).collect();
let mut tools = self.tools.lock().unwrap_or_else(|e| e.into_inner());
for (meta, instance) in materialized {
assert!(
!tools.contains_key(&meta.name),
"duplicate tool name: '{}'",
meta.name,
);
tools.insert(meta.name.clone(), (meta, instance));
} }
Ok(())
} }
/// Get a tool by name for hook contexts. /// Get a tool by name for hook contexts.
@ -143,19 +170,37 @@ mod tests {
} }
#[test] #[test]
fn register_duplicate_name_fails() { fn flush_pending_registers_tools() {
let handle = ToolServer::new().handle(); let handle = ToolServer::new().handle();
handle.register_tool(def("alpha")).expect("first register"); handle.register_tool(def("alpha"));
let err = handle handle.register_tool(def("beta"));
.register_tool(def("alpha"))
.expect_err("duplicate should fail"); // Before flush, no tools are available
assert_eq!(err, ToolServerError::DuplicateName("alpha".to_string())); assert!(handle.get_tool("alpha").is_none());
handle.flush_pending();
// After flush, tools are available
assert!(handle.get_tool("alpha").is_some());
assert!(handle.get_tool("beta").is_some());
}
#[test]
#[should_panic(expected = "duplicate tool name: 'alpha'")]
fn flush_pending_duplicate_name_panics() {
let handle = ToolServer::new().handle();
handle.register_tool(def("alpha"));
handle.flush_pending();
handle.register_tool(def("alpha"));
handle.flush_pending(); // panics
} }
#[tokio::test] #[tokio::test]
async fn call_tool_success_and_not_found() { async fn call_tool_success_and_not_found() {
let handle = ToolServer::new().handle(); let handle = ToolServer::new().handle();
handle.register_tool(def("echo")).expect("register"); handle.register_tool(def("echo"));
handle.flush_pending();
let out = handle.call_tool("echo", r#"{"x":1}"#).await.expect("call"); let out = handle.call_tool("echo", r#"{"x":1}"#).await.expect("call");
assert_eq!(out, r#"{"x":1}"#); assert_eq!(out, r#"{"x":1}"#);
@ -170,9 +215,10 @@ mod tests {
#[test] #[test]
fn tool_definitions_are_sorted() { fn tool_definitions_are_sorted() {
let handle = ToolServer::new().handle(); let handle = ToolServer::new().handle();
handle.register_tool(def("zeta")).expect("register zeta"); handle.register_tool(def("zeta"));
handle.register_tool(def("alpha")).expect("register alpha"); handle.register_tool(def("alpha"));
handle.register_tool(def("beta")).expect("register beta"); handle.register_tool(def("beta"));
handle.flush_pending();
let names: Vec<_> = handle let names: Vec<_> = handle
.tool_definitions_sorted() .tool_definitions_sorted()
@ -181,4 +227,11 @@ mod tests {
.collect(); .collect();
assert_eq!(names, vec!["alpha", "beta", "zeta"]); assert_eq!(names, vec!["alpha", "beta", "zeta"]);
} }
// Flushing a handle with nothing queued must return without panicking,
// including when it is called twice in a row. The test has no assertions:
// it passes as long as neither call panics.
#[test]
fn flush_pending_is_noop_when_empty() {
let handle = ToolServer::new().handle();
handle.flush_pending();
handle.flush_pending();
}
} }

View File

@ -13,7 +13,7 @@ use crate::{
DefaultInterceptor, Interceptor, PostToolAction, PreRequestAction, PreToolAction, DefaultInterceptor, Interceptor, PostToolAction, PreRequestAction, PreToolAction,
PromptAction, ToolCallInfo, ToolResultInfo, TurnEndAction, PromptAction, ToolCallInfo, ToolResultInfo, TurnEndAction,
}, },
state::{CacheLocked, Mutable, WorkerState}, state::{Locked, Mutable, WorkerState},
callback::{ callback::{
ClosureMetaHandler, ClosureTextBlockHandler, ClosureToolUseBlockHandler, TextBlockScope, ClosureMetaHandler, ClosureTextBlockHandler, ClosureToolUseBlockHandler, TextBlockScope,
ToolUseBlockScope, ToolUseBlockScope,
@ -22,12 +22,9 @@ use crate::{
timeline::{TextBlockCollector, Timeline, ToolCallCollector}, timeline::{TextBlockCollector, Timeline, ToolCallCollector},
timeline::event::{ErrorEvent, StatusEvent, UsageEvent}, timeline::event::{ErrorEvent, StatusEvent, UsageEvent},
tool::{ToolCall, ToolDefinition as WorkerToolDefinition, ToolError, ToolOutputProcessor, ToolResult}, tool::{ToolCall, ToolDefinition as WorkerToolDefinition, ToolError, ToolOutputProcessor, ToolResult},
tool_server::{ToolServer, ToolServerError, ToolServerHandle}, tool_server::{ToolServer, ToolServerHandle},
}; };
// =============================================================================
// Worker Error
// =============================================================================
/// Worker errors /// Worker errors
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
@ -57,9 +54,6 @@ pub enum ToolRegistryError {
DuplicateName(String), DuplicateName(String),
} }
// =============================================================================
// Worker Config
// =============================================================================
/// Worker configuration /// Worker configuration
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
@ -68,9 +62,6 @@ pub struct WorkerConfig {
_private: (), _private: (),
} }
// =============================================================================
// Worker Result Types
// =============================================================================
/// Worker execution result (status) /// Worker execution result (status)
#[derive(Debug)] #[derive(Debug)]
@ -89,9 +80,6 @@ enum ToolExecutionResult {
Paused, Paused,
} }
// =============================================================================
// Worker
// =============================================================================
/// Central component for managing LLM interactions /// Central component for managing LLM interactions
/// ///
@ -100,32 +88,28 @@ enum ToolExecutionResult {
/// ///
/// # State Transitions (Type-state) /// # State Transitions (Type-state)
/// ///
/// - [`Mutable`]: Initial state. System prompt and history can be freely edited. /// - [`Mutable`]: Initial state. System prompt, history, and tools can be freely edited.
/// - [`CacheLocked`]: Cache-protected state. Transition via `lock()`. Prefix context is immutable. /// - [`Locked`]: Cache-protected state. Prefix context is immutable; only `run()` / `resume()` are available.
/// ///
/// # Examples /// Calling `run()` on a `Mutable` Worker consumes it and returns a
/// `Locked` Worker together with the result. This ensures the
/// cache prefix is fixed for optimal KV cache hit rate.
/// ///
/// ```ignore /// ```ignore
/// use llm_worker::{Worker, Item};
///
/// // Create a Worker and register tools
/// let mut worker = Worker::new(client) /// let mut worker = Worker::new(client)
/// .system_prompt("You are a helpful assistant."); /// .system_prompt("You are a helpful assistant.");
/// worker.register_tool(my_tool); /// worker.register_tool(my_tool);
/// ///
/// // Run the interaction /// // Mutable::run() consumes self → Locked
/// let history = worker.run("Hello!").await?; /// let (mut worker, _result) = worker.run("Hello").await?;
/// ```
/// ///
/// # When Cache Protection is Needed /// // Locked::run() borrows &mut self
/// worker.run("Follow-up").await?;
/// ///
/// ```ignore /// // To edit between turns, unlock back to Mutable
/// let mut worker = Worker::new(client) /// let mut worker = worker.unlock();
/// .system_prompt("..."); /// worker.history_mut().truncate(5);
/// /// let (mut worker, _result) = worker.run("Continue").await?;
/// // After setting history, lock to protect cache
/// let mut locked = worker.lock();
/// locked.run("user input").await?;
/// ``` /// ```
pub struct Worker<C: LlmClient, S: WorkerState = Mutable> { pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
/// LLM client /// LLM client
@ -144,7 +128,7 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
system_prompt: Option<String>, system_prompt: Option<String>,
/// Item history (owned by Worker) /// Item history (owned by Worker)
history: Vec<Item>, history: Vec<Item>,
/// History length at lock time (only meaningful in CacheLocked state) /// History length at lock time (only meaningful in Locked state)
locked_prefix_len: usize, locked_prefix_len: usize,
/// Turn count /// Turn count
turn_count: usize, turn_count: usize,
@ -167,40 +151,12 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
_state: PhantomData<S>, _state: PhantomData<S>,
} }
// =============================================================================
// Common Implementation (available in all states)
// =============================================================================
impl<C: LlmClient, S: WorkerState> Worker<C, S> { impl<C: LlmClient, S: WorkerState> Worker<C, S> {
fn reset_interruption_state(&mut self) { fn reset_interruption_state(&mut self) {
self.last_run_interrupted = false; self.last_run_interrupted = false;
} }
/// Run one turn starting from a new user message.
///
/// `user_input` is wrapped in a user [`Item`] and handed (mutably) to the
/// interceptor's `on_prompt_submit` hook before it is appended to history.
/// The turn loop then drives the LLM, looping automatically on tool calls.
///
/// If the interceptor cancels the prompt, nothing is appended to history;
/// the run is flagged as interrupted and `WorkerError::Aborted` is routed
/// through `finalize_interruption`, whose result is returned.
pub async fn run(
    &mut self,
    user_input: impl Into<String>,
) -> Result<WorkerResult, WorkerError> {
    self.reset_interruption_state();

    // Interceptor: on_prompt_submit
    let mut prompt = Item::user_message(user_input);
    let verdict = self.interceptor.on_prompt_submit(&mut prompt).await;
    match verdict {
        PromptAction::Continue => {}
        PromptAction::Cancel(reason) => {
            self.last_run_interrupted = true;
            return self
                .finalize_interruption(Err(WorkerError::Aborted(reason)))
                .await;
        }
    }

    self.history.push(prompt);
    let outcome = self.run_turn_loop().await;
    self.finalize_interruption(outcome).await
}
fn drain_cancel_queue(&mut self) { fn drain_cancel_queue(&mut self) {
use tokio::sync::mpsc::error::TryRecvError; use tokio::sync::mpsc::error::TryRecvError;
loop { loop {
@ -892,19 +848,8 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
} }
} }
/// Continue a paused run.
///
/// Re-enters the turn loop from the current state; unlike `run()`, no new
/// user message is pushed onto history first.
pub async fn resume(&mut self) -> Result<WorkerResult, WorkerError> {
    self.reset_interruption_state();
    let outcome = self.run_turn_loop().await;
    self.finalize_interruption(outcome).await
}
} }
// =============================================================================
// Mutable State-Specific Implementation
// =============================================================================
impl<C: LlmClient> Worker<C, Mutable> { impl<C: LlmClient> Worker<C, Mutable> {
/// Create a new Worker (in Mutable state) /// Create a new Worker (in Mutable state)
@ -941,43 +886,21 @@ impl<C: LlmClient> Worker<C, Mutable> {
} }
} }
/// Register a tool /// Register a tool factory for deferred initialization.
/// ///
/// Registered tools are automatically executed when called by the LLM. /// The factory is queued and executed at the next `run()` or `resume()` call.
/// Registering a tool with the same name will result in an error. /// Duplicate name detection occurs at that point and surfaces as
/// /// [`WorkerError::ToolRegistry`].
/// Available only in Mutable state. pub fn register_tool(&mut self, factory: WorkerToolDefinition) {
pub fn register_tool( self.tool_server.register_tool(factory);
&mut self,
factory: WorkerToolDefinition,
) -> Result<(), ToolRegistryError> {
match self.tool_server.register_tool(factory) {
Ok(()) => Ok(()),
Err(ToolServerError::DuplicateName(name)) => {
Err(ToolRegistryError::DuplicateName(name))
}
Err(ToolServerError::ToolNotFound(_) | ToolServerError::ToolExecution(_)) => {
unreachable!("register_tool should only fail with DuplicateName")
}
}
} }
/// Register multiple tools /// Register multiple tool factories for deferred initialization.
///
/// Available only in Mutable state.
pub fn register_tools( pub fn register_tools(
&mut self, &mut self,
factories: impl IntoIterator<Item = WorkerToolDefinition>, factories: impl IntoIterator<Item = WorkerToolDefinition>,
) -> Result<(), ToolRegistryError> { ) {
match self.tool_server.register_tools(factories) { self.tool_server.register_tools(factories);
Ok(()) => Ok(()),
Err(ToolServerError::DuplicateName(name)) => {
Err(ToolRegistryError::DuplicateName(name))
}
Err(ToolServerError::ToolNotFound(_) | ToolServerError::ToolExecution(_)) => {
unreachable!("register_tools should only fail with DuplicateName")
}
}
} }
/// Set system prompt (builder pattern) /// Set system prompt (builder pattern)
@ -1082,40 +1005,47 @@ impl<C: LlmClient> Worker<C, Mutable> {
/// Get a mutable reference to history /// Get a mutable reference to history
/// ///
/// Available only in Mutable state. History can be freely edited. /// Available only in Mutable state.
pub fn history_mut(&mut self) -> &mut Vec<Item> { pub fn history_mut(&mut self) -> &mut Vec<Item> {
&mut self.history &mut self.history
} }
/// Set history /// Set history
pub fn set_history(&mut self, items: Vec<Item>) { pub fn set_history(&mut self, items: Vec<Item>) {
self.history = items; self.history = items;
} }
/// Add an item to history (builder pattern) /// Add an item to history (builder pattern)
pub fn with_item(mut self, item: Item) -> Self { pub fn with_item(mut self, item: Item) -> Self {
self.history.push(item); self.history.push(item);
self self
} }
/// Add an item to history /// Add an item to history
pub fn push_item(&mut self, item: Item) { pub fn push_item(&mut self, item: Item) {
self.history.push(item); self.history.push(item);
} }
/// Add multiple items to history (builder pattern) /// Add multiple items to history (builder pattern)
pub fn with_items(mut self, items: impl IntoIterator<Item = Item>) -> Self { pub fn with_items(mut self, items: impl IntoIterator<Item = Item>) -> Self {
self.history.extend(items); self.history.extend(items);
self self
} }
/// Add multiple items to history /// Add multiple items to history
pub fn extend_history(&mut self, items: impl IntoIterator<Item = Item>) { pub fn extend_history(&mut self, items: impl IntoIterator<Item = Item>) {
self.history.extend(items); self.history.extend(items);
} }
/// Clear history /// Clear history
pub fn clear_history(&mut self) { pub fn clear_history(&mut self) {
self.history.clear(); self.history.clear();
} }
@ -1148,11 +1078,48 @@ impl<C: LlmClient> Worker<C, Mutable> {
self self
} }
/// Lock and transition to CacheLocked state /// Execute a turn, consuming self and transitioning to Locked.
/// ///
/// This operation fixes the current system prompt and history as a "committed prefix". /// This is the primary entry point for first use. Equivalent to
/// After this, only appending to history is allowed, ensuring cache hits. /// `self.lock()` followed by `locked.run(user_input)`.
pub fn lock(self) -> Worker<C, CacheLocked> { ///
/// Subsequent runs can use [`Worker<C, Locked>::run()`] directly.
/// To edit state between turns, call [`unlock()`](Worker::unlock) first.
pub async fn run(
    self,
    user_input: impl Into<String>,
) -> Result<(Worker<C, Locked>, WorkerResult), WorkerError> {
    // Lock first, then delegate to the `&mut self` run on the locked worker.
    let mut locked = self.lock();
    let outcome = locked.run(user_input).await;
    // On success the locked worker is handed back alongside the result.
    // On failure only the error is returned and the worker is dropped here.
    outcome.map(|result| (locked, result))
}
/// Resume from Paused, consuming self and transitioning to Locked.
///
/// Used after `unlock()` → edit → resume.
pub async fn resume(
    self,
) -> Result<(Worker<C, Locked>, WorkerResult), WorkerError> {
    // Lock first, then delegate to the `&mut self` resume on the locked worker.
    let mut locked = self.lock();
    let outcome = locked.resume().await;
    // On success the locked worker is handed back alongside the result.
    // On failure only the error is returned and the worker is dropped here.
    outcome.map(|result| (locked, result))
}
/// Lock and transition to Locked state
///
/// Flushes pending tool factories, then fixes the current system prompt
/// and history as a "committed prefix". After this, only `run()` / `resume()`
/// may append to history, ensuring cache hits.
///
/// Most callers should use [`run()`](Self::run) instead, which calls
/// this internally. Use `lock()` directly only when you need the
/// `Locked` worker back on error (e.g. in a persistence layer).
///
/// # Panics
///
/// Panics if a pending tool factory produces a duplicate name.
pub fn lock(self) -> Worker<C, Locked> {
self.tool_server.flush_pending();
let locked_prefix_len = self.history.len(); let locked_prefix_len = self.history.len();
Worker { Worker {
client: self.client, client: self.client,
@ -1178,11 +1145,42 @@ impl<C: LlmClient> Worker<C, Mutable> {
} }
} }
// =============================================================================
// CacheLocked State-Specific Implementation
// =============================================================================
impl<C: LlmClient> Worker<C, CacheLocked> { impl<C: LlmClient> Worker<C, Locked> {
/// Run one turn starting from a new user message.
///
/// `user_input` is wrapped in a user [`Item`] and handed (mutably) to the
/// interceptor's `on_prompt_submit` hook before it is appended to history.
/// The turn loop then drives the LLM, looping automatically on tool calls.
///
/// If the interceptor cancels the prompt, nothing is appended to history;
/// the run is flagged as interrupted and `WorkerError::Aborted` is routed
/// through `finalize_interruption`, whose result is returned.
pub async fn run(
    &mut self,
    user_input: impl Into<String>,
) -> Result<WorkerResult, WorkerError> {
    self.reset_interruption_state();

    // Interceptor: on_prompt_submit
    let mut prompt = Item::user_message(user_input);
    let verdict = self.interceptor.on_prompt_submit(&mut prompt).await;
    match verdict {
        PromptAction::Continue => {}
        PromptAction::Cancel(reason) => {
            self.last_run_interrupted = true;
            return self
                .finalize_interruption(Err(WorkerError::Aborted(reason)))
                .await;
        }
    }

    self.history.push(prompt);
    let outcome = self.run_turn_loop().await;
    self.finalize_interruption(outcome).await
}
/// Continue a paused run.
///
/// Re-enters the turn loop from the current state; unlike `run()`, no new
/// user message is pushed onto history first.
pub async fn resume(&mut self) -> Result<WorkerResult, WorkerError> {
    self.reset_interruption_state();
    let outcome = self.run_turn_loop().await;
    self.finalize_interruption(outcome).await
}
/// Get the prefix length at lock time /// Get the prefix length at lock time
pub fn locked_prefix_len(&self) -> usize { pub fn locked_prefix_len(&self) -> usize {
self.locked_prefix_len self.locked_prefix_len

View File

@ -46,8 +46,9 @@ async fn test_callback_text_block_events() {
}); });
}); });
// Mutable::run consumes self, returns (Locked, WorkerResult)
let result = worker.run("Greet me").await; let result = worker.run("Greet me").await;
assert!(result.is_ok(), "Worker should complete: {:?}", result); assert!(result.is_ok(), "Worker should complete");
let deltas = text_deltas.lock().unwrap(); let deltas = text_deltas.lock().unwrap();
assert_eq!(deltas.len(), 2); assert_eq!(deltas.len(), 2);
@ -91,6 +92,7 @@ async fn test_callback_tool_call_complete() {
}); });
}); });
// Mutable::run consumes self, returns (Locked, WorkerResult)
let _ = worker.run("Weather please").await; let _ = worker.run("Weather please").await;
let starts = tool_starts.lock().unwrap(); let starts = tool_starts.lock().unwrap();
@ -133,6 +135,7 @@ async fn test_callback_turn_events() {
ends.lock().unwrap().push(turn); ends.lock().unwrap().push(turn);
}); });
// Mutable::run consumes self, returns (Locked, WorkerResult)
let result = worker.run("Do something").await; let result = worker.run("Do something").await;
assert!(result.is_ok()); assert!(result.is_ok());
@ -169,6 +172,7 @@ async fn test_callback_usage_events() {
usages.lock().unwrap().push(event.clone()); usages.lock().unwrap().push(event.clone());
}); });
// Mutable::run consumes self, returns (Locked, WorkerResult)
let _ = worker.run("Hello").await; let _ = worker.run("Hello").await;
let usages = usage_events.lock().unwrap(); let usages = usage_events.lock().unwrap();

View File

@ -1,6 +1,6 @@
#[test] #[test]
fn compile_fail_state_constraints() { fn compile_fail_state_constraints() {
let t = trybuild::TestCases::new(); let t = trybuild::TestCases::new();
t.compile_fail("tests/ui/cache_locked_register_tool.rs"); t.compile_fail("tests/ui/locked_register_tool.rs");
t.compile_fail("tests/ui/tool_server_handle_register_tool.rs"); t.compile_fail("tests/ui/tool_server_handle_register_tool.rs");
} }

View File

@ -102,11 +102,12 @@ async fn test_parallel_tool_execution() {
let tool2_clone = tool2.clone(); let tool2_clone = tool2.clone();
let tool3_clone = tool3.clone(); let tool3_clone = tool3.clone();
worker.register_tool(tool1.definition()).unwrap(); worker.register_tool(tool1.definition());
worker.register_tool(tool2.definition()).unwrap(); worker.register_tool(tool2.definition());
worker.register_tool(tool3.definition()).unwrap(); worker.register_tool(tool3.definition());
let start = Instant::now(); let start = Instant::now();
// Mutable::run consumes self, returns (Locked, WorkerResult)
let _result = worker.run("Run all tools").await; let _result = worker.run("Run all tools").await;
let elapsed = start.elapsed(); let elapsed = start.elapsed();
@ -150,8 +151,8 @@ async fn test_before_tool_call_skip() {
let allowed_clone = allowed_tool.clone(); let allowed_clone = allowed_tool.clone();
let blocked_clone = blocked_tool.clone(); let blocked_clone = blocked_tool.clone();
worker.register_tool(allowed_tool.definition()).unwrap(); worker.register_tool(allowed_tool.definition());
worker.register_tool(blocked_tool.definition()).unwrap(); worker.register_tool(blocked_tool.definition());
// Policy to skip "blocked_tool" // Policy to skip "blocked_tool"
struct BlockingPolicy; struct BlockingPolicy;
@ -169,6 +170,7 @@ async fn test_before_tool_call_skip() {
worker.set_interceptor(BlockingPolicy); worker.set_interceptor(BlockingPolicy);
// Mutable::run consumes self, returns (Locked, WorkerResult)
let _result = worker.run("Test hook").await; let _result = worker.run("Test hook").await;
// allowed_tool is called, but blocked_tool is not // allowed_tool is called, but blocked_tool is not
@ -230,7 +232,7 @@ async fn test_post_tool_call_modification() {
}) })
} }
worker.register_tool(simple_tool_definition()).unwrap(); worker.register_tool(simple_tool_definition());
// Policy to modify results // Policy to modify results
struct ModifyingPolicy { struct ModifyingPolicy {
@ -251,9 +253,10 @@ async fn test_post_tool_call_modification() {
modified_content: modified_content.clone(), modified_content: modified_content.clone(),
}); });
// Mutable::run consumes self, returns (Locked, WorkerResult)
let result = worker.run("Test modification").await; let result = worker.run("Test modification").await;
assert!(result.is_ok(), "Worker should complete: {:?}", result); assert!(result.is_ok(), "Worker should complete");
// Verify hook was called and content was modified // Verify hook was called and content was modified
let content = modified_content.lock().unwrap().clone(); let content = modified_content.lock().unwrap().clone();

View File

@ -1,8 +1,8 @@
error[E0599]: no method named `register_tool` found for struct `Worker<OllamaClient, CacheLocked>` in the current scope error[E0599]: no method named `register_tool` found for struct `Worker<OllamaClient, Locked>` in the current scope
--> tests/ui/cache_locked_register_tool.rs:10:20 --> tests/ui/locked_register_tool.rs:10:20
| |
10 | let _ = locked.register_tool(def); 10 | let _ = locked.register_tool(def);
| ^^^^^^^^^^^^^ method not found in `Worker<OllamaClient, CacheLocked>` | ^^^^^^^^^^^^^ method not found in `Worker<OllamaClient, Locked>`
| |
= note: the method was found for = note: the method was found for
- `Worker<C>` - `Worker<C>`

View File

@ -1,13 +1,10 @@
error[E0624]: method `register_tool` is private error[E0624]: method `register_tool` is private
--> tests/ui/tool_server_handle_register_tool.rs:10:20 --> tests/ui/tool_server_handle_register_tool.rs:10:20
| |
10 | let _ = handle.register_tool(def); 10 | let _ = handle.register_tool(def);
| ^^^^^^^^^^^^^ private method | ^^^^^^^^^^^^^ private method
| |
::: src/tool_server.rs ::: src/tool_server.rs
| |
| / pub(crate) fn register_tool( | pub(crate) fn register_tool(&self, factory: WorkerToolDefinition) {
| | &self, | ----------------------------------------------------------------- private method defined here
| | factory: WorkerToolDefinition,
| | ) -> Result<(), ToolServerError> {
| |____________________________________- private method defined here

View File

@ -129,9 +129,9 @@ async fn test_worker_simple_text_response() {
} }
let client = MockLlmClient::from_fixture(&fixture_path).unwrap(); let client = MockLlmClient::from_fixture(&fixture_path).unwrap();
let mut worker = Worker::new(client); let worker = Worker::new(client);
// Send a simple message // Send a simple message (Mutable::run consumes self, returns tuple)
let result = worker.run("Hello").await; let result = worker.run("Hello").await;
assert!(result.is_ok(), "Worker should complete successfully"); assert!(result.is_ok(), "Worker should complete successfully");
@ -156,9 +156,9 @@ async fn test_worker_tool_call() {
// Register tool // Register tool
let weather_tool = MockWeatherTool::new(); let weather_tool = MockWeatherTool::new();
let tool_for_check = weather_tool.clone(); let tool_for_check = weather_tool.clone();
worker.register_tool(weather_tool.definition()).unwrap(); worker.register_tool(weather_tool.definition());
// Send message // Send message (Mutable::run consumes self, returns tuple)
let _result = worker.run("What's the weather in Tokyo?").await; let _result = worker.run("What's the weather in Tokyo?").await;
// Verify tool was called // Verify tool was called
@ -190,8 +190,9 @@ async fn test_worker_with_programmatic_events() {
]; ];
let client = MockLlmClient::new(events); let client = MockLlmClient::new(events);
let mut worker = Worker::new(client); let worker = Worker::new(client);
// Mutable::run consumes self, returns tuple
let result = worker.run("Greet me").await; let result = worker.run("Greet me").await;
assert!(result.is_ok(), "Worker should complete successfully"); assert!(result.is_ok(), "Worker should complete successfully");

View File

@ -1,6 +1,6 @@
//! Worker state management tests //! Worker state management tests
//! //!
//! Tests for state transitions using the Type-state pattern (Mutable/CacheLocked) //! Tests for state transitions using the Type-state pattern (Mutable/Locked)
//! and state preservation between turns. //! and state preservation between turns.
mod common; mod common;
@ -11,7 +11,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use async_trait::async_trait; use async_trait::async_trait;
use common::MockLlmClient; use common::MockLlmClient;
use llm_worker::Item; use llm_worker::Item;
use llm_worker::Worker; use llm_worker::{Worker, WorkerError};
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent}; use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta}; use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
@ -147,15 +147,15 @@ fn test_mutable_can_register_tool() {
let mut worker = Worker::new(client); let mut worker = Worker::new(client);
let tool = CountingTool::new("count_tool"); let tool = CountingTool::new("count_tool");
let result = worker.register_tool(tool.definition()); // register_tool is infallible (factory deferred to run-time flush)
assert!(result.is_ok(), "Mutable should allow tool registration"); worker.register_tool(tool.definition());
} }
// ============================================================================= // =============================================================================
// State Transition Tests // State Transition Tests
// ============================================================================= // =============================================================================
/// Verify that lock() transitions from Mutable -> CacheLocked state /// Verify that lock() transitions from Mutable -> Locked state
#[test] #[test]
fn test_lock_transition() { fn test_lock_transition() {
let client = MockLlmClient::new(vec![]); let client = MockLlmClient::new(vec![]);
@ -168,13 +168,13 @@ fn test_lock_transition() {
// Lock // Lock
let locked_worker = worker.lock(); let locked_worker = worker.lock();
// History and system prompt are still accessible in CacheLocked state // History and system prompt are still accessible in Locked state
assert_eq!(locked_worker.get_system_prompt(), Some("System")); assert_eq!(locked_worker.get_system_prompt(), Some("System"));
assert_eq!(locked_worker.history().len(), 2); assert_eq!(locked_worker.history().len(), 2);
assert_eq!(locked_worker.locked_prefix_len(), 2); assert_eq!(locked_worker.locked_prefix_len(), 2);
} }
/// Verify that unlock() transitions from CacheLocked -> Mutable state /// Verify that unlock() transitions from Locked -> Mutable state
#[test] #[test]
fn test_unlock_transition() { fn test_unlock_transition() {
let client = MockLlmClient::new(vec![]); let client = MockLlmClient::new(vec![]);
@ -198,7 +198,7 @@ fn test_unlock_transition() {
/// Verify that history is correctly updated after running a turn in Mutable state /// Verify that history is correctly updated after running a turn in Mutable state
#[tokio::test] #[tokio::test]
async fn test_mutable_run_updates_history() { async fn test_mutable_run_updates_history() -> Result<(), WorkerError> {
let events = vec![ let events = vec![
Event::text_block_start(0), Event::text_block_start(0),
Event::text_delta(0, "Hello, I'm an assistant!"), Event::text_delta(0, "Hello, I'm an assistant!"),
@ -209,11 +209,10 @@ async fn test_mutable_run_updates_history() {
]; ];
let client = MockLlmClient::new(events); let client = MockLlmClient::new(events);
let mut worker = Worker::new(client); let worker = Worker::new(client);
// Execute // Execute (Mutable::run consumes self, returns (Locked, WorkerResult))
let result = worker.run("Hi there").await; let (worker, _result) = worker.run("Hi there").await?;
assert!(result.is_ok());
// History is updated // History is updated
let history = worker.history(); let history = worker.history();
@ -224,9 +223,11 @@ async fn test_mutable_run_updates_history() {
// Assistant message // Assistant message
assert_eq!(history[1].as_text(), Some("Hello, I'm an assistant!")); assert_eq!(history[1].as_text(), Some("Hello, I'm an assistant!"));
Ok(())
} }
/// Verify that history accumulates correctly over multiple turns in CacheLocked state /// Verify that history accumulates correctly over multiple turns in Locked state
#[tokio::test] #[tokio::test]
async fn test_locked_multi_turn_history_accumulation() { async fn test_locked_multi_turn_history_accumulation() {
// Prepare responses for 2 requests // Prepare responses for 2 requests
@ -327,7 +328,7 @@ async fn test_locked_prefix_len_tracking() {
/// Verify that turn count is correctly incremented /// Verify that turn count is correctly incremented
#[tokio::test] #[tokio::test]
async fn test_turn_count_increment() { async fn test_turn_count_increment() -> Result<(), WorkerError> {
let client = MockLlmClient::with_responses(vec![ let client = MockLlmClient::with_responses(vec![
vec![ vec![
Event::text_block_start(0), Event::text_block_start(0),
@ -347,15 +348,19 @@ async fn test_turn_count_increment() {
], ],
]); ]);
let mut worker = Worker::new(client); let worker = Worker::new(client);
assert_eq!(worker.turn_count(), 0); assert_eq!(worker.turn_count(), 0);
worker.run("First").await.unwrap(); // First run consumes Mutable, returns Locked
let (mut worker, _) = worker.run("First").await?;
assert_eq!(worker.turn_count(), 1); assert_eq!(worker.turn_count(), 1);
worker.run("Second").await.unwrap(); // Subsequent runs on Locked take &mut self
worker.run("Second").await?;
assert_eq!(worker.turn_count(), 2); assert_eq!(worker.turn_count(), 2);
Ok(())
} }
/// Verify that history can be edited after unlock and re-locked /// Verify that history can be edited after unlock and re-locked
@ -430,9 +435,7 @@ async fn test_lock_unlock_relock_tools_remain_effective() {
let mut worker = Worker::new(client); let mut worker = Worker::new(client);
let tool_a = CountingTool::new("tool_a"); let tool_a = CountingTool::new("tool_a");
worker worker.register_tool(tool_a.definition());
.register_tool(tool_a.definition())
.expect("register tool_a should succeed");
let mut locked = worker.lock(); let mut locked = worker.lock();
locked.run("first").await.expect("first run"); locked.run("first").await.expect("first run");
@ -440,9 +443,7 @@ async fn test_lock_unlock_relock_tools_remain_effective() {
let mut unlocked = locked.unlock(); let mut unlocked = locked.unlock();
let tool_b = CountingTool::new("tool_b"); let tool_b = CountingTool::new("tool_b");
unlocked unlocked.register_tool(tool_b.definition());
.register_tool(tool_b.definition())
.expect("register tool_b after unlock should succeed");
let mut relocked = unlocked.lock(); let mut relocked = unlocked.lock();
relocked.run("second").await.expect("second run"); relocked.run("second").await.expect("second run");
@ -455,7 +456,7 @@ async fn test_lock_unlock_relock_tools_remain_effective() {
// System Prompt Preservation Tests // System Prompt Preservation Tests
// ============================================================================= // =============================================================================
/// Verify that system prompt is preserved in CacheLocked state /// Verify that system prompt is preserved in Locked state
#[test] #[test]
fn test_system_prompt_preserved_in_locked_state() { fn test_system_prompt_preserved_in_locked_state() {
let client = MockLlmClient::new(vec![]); let client = MockLlmClient::new(vec![]);

View File

@ -53,7 +53,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
// 5. Extract the assistant's reply from history // 5. Extract the assistant's reply from history
let history = pod.session_mut().worker.history(); let history = pod.session_mut().worker().history();
if let Some(text) = history if let Some(text) = history
.iter() .iter()
.rev() .rev()

View File

@ -85,7 +85,7 @@ impl PodController {
// Register event bridge callbacks on the worker // Register event bridge callbacks on the worker
{ {
let worker = &mut pod.session_mut().worker; let worker = pod.session_mut().worker_mut();
let tx = event_tx.clone(); let tx = event_tx.clone();
worker.on_turn_start(move |turn| { worker.on_turn_start(move |turn| {
@ -158,7 +158,7 @@ impl PodController {
} }
// Clone cancel sender before moving pod // Clone cancel sender before moving pod
let cancel_tx = pod.session_mut().worker.cancel_sender(); let cancel_tx = pod.session_mut().worker_mut().cancel_sender();
tokio::spawn(async move { tokio::spawn(async move {
// Hold socket server alive for the lifetime of the controller task // Hold socket server alive for the lifetime of the controller task
@ -191,7 +191,7 @@ impl PodController {
) )
.await; .await;
let items = pod.session_mut().worker.history().to_vec(); let items = pod.session_mut().worker_mut().history().to_vec();
shared_state.update_history(items); shared_state.update_history(items);
shared_state.set_status(new_status); shared_state.set_status(new_status);
let _ = runtime_dir.write_status(&shared_state).await; let _ = runtime_dir.write_status(&shared_state).await;
@ -218,7 +218,7 @@ impl PodController {
) )
.await; .await;
let items = pod.session_mut().worker.history().to_vec(); let items = pod.session_mut().worker_mut().history().to_vec();
shared_state.update_history(items); shared_state.update_history(items);
shared_state.set_status(new_status); shared_state.set_status(new_status);
let _ = runtime_dir.write_status(&shared_state).await; let _ = runtime_dir.write_status(&shared_state).await;

View File

@ -144,7 +144,7 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
let builder = std::mem::take(&mut self.hook_builder); let builder = std::mem::take(&mut self.hook_builder);
let registry = Arc::new(builder.build()); let registry = Arc::new(builder.build());
let interceptor = HookInterceptor::new(registry); let interceptor = HookInterceptor::new(registry);
self.session.worker.set_interceptor(interceptor); self.session.worker_mut().set_interceptor(interceptor);
self.interceptor_installed = true; self.interceptor_installed = true;
} }
} }

View File

@ -1,57 +0,0 @@
# Worker: run() 時の自動キャッシュロックと ToolDefinition ファクトリ遅延初期化
## 背景
現状の `Worker` は Type-state パターンで `Mutable` / `CacheLocked` の2状態を持つが、
`lock()` を呼ばなくても `run()` できてしまうため、キャッシュ保護の存在を知らないユーザーは
常に非最適パスを通ることになる。
また `register_tool()` 時に `ToolDefinition` のファクトリクロージャが即時呼び出しされており、
本来の意図である遅延初期化になっていない。
## 方針
### run() 時の自動ロック
`run()` の冒頭で、`Mutable` 状態なら自動的に `CacheLocked` へ遷移する。
これにより lock を知らないユーザーでも嫌でもキャッシュ保護される。
ターンの合間に history や system prompt を編集したい場合は、明示的に `unlock()` を挟む。
次の `run()` で再び自動ロックされる。
```rust
let mut worker = Worker::new(client);
worker.set_system_prompt("...");
worker.register_tool(my_tool)?;
// Mutable のまま run() → 自動で lock される
worker.run("Hello").await?;
// ターン間で内容を弄りたい場合
worker.unlock();
worker.history_mut().truncate(5);
// 次の run() で再 lock
worker.run("Continue").await?;
```
#### 設計ポイント
- `run()` は `&mut self` を取る以上、内部で状態遷移しても外部の型は変わらない。
実装は内部フラグ(`is_locked: bool`)で管理し、`Mutable` / `CacheLocked` の
type-state はそのまま維持する。`Worker<C, Mutable>` の `run()` が内部で lock 相当の
処理を行い、`unlock()` が呼ばれるまでキャッシュ破壊的な操作を(ランタイムで)ブロックする
- `Worker<C, CacheLocked>` は従来どおり。明示的に `lock()` してから使うパスも残る
- interceptor, max_turns, callbacks 等キャッシュに影響しない設定は lock 状態でも自由に変更可能
### ToolDefinition ファクトリの遅延初期化
`register_tool()` は定義を蓄積するだけにし、ファクトリの実行を初回 `run()` まで遅延させる。
- `register_tool()` は `ToolDefinition` を Vec に push するだけ
- 初回 `run()` の自動 lock 時にファクトリを一括実行し、ToolServer を構築
- `unlock()` 後に追加登録された tool は次の `run()` で初期化
## 移行
- `lock()` / `unlock()` は引き続き使える(明示的なキャッシュ管理用)
- `Worker::new()` → `run()` のパスが自動保護されるため、既存コードは変更不要