Workerの自動キャッシュロック
This commit is contained in:
parent
f241dafac8
commit
9b78c51d0a
2
TODO.md
2
TODO.md
|
|
@ -3,7 +3,7 @@
|
||||||
- [x] ツール出力の遅延読み込み設計 (ToolOutput / BlobStore / auto_summarize)
|
- [x] ツール出力の遅延読み込み設計 (ToolOutput / BlobStore / auto_summarize)
|
||||||
- [ ] ツール設計
|
- [ ] ツール設計
|
||||||
- [ ] ツールの動的追加/削除 → [tickets/tool-dynamic-registry.md](tickets/tool-dynamic-registry.md)
|
- [ ] ツールの動的追加/削除 → [tickets/tool-dynamic-registry.md](tickets/tool-dynamic-registry.md)
|
||||||
- [ ] run() 自動ロックとファクトリ遅延初期化 → [tickets/worker-auto-lock.md](tickets/worker-auto-lock.md)
|
- [x] run() 自動ロックとファクトリ遅延初期化 → [tickets/worker-auto-lock.md](tickets/worker-auto-lock.md)
|
||||||
- [x] inspect ツール実装
|
- [x] inspect ツール実装
|
||||||
- [x] max_turns: マニフェストによるターン数制限
|
- [x] max_turns: マニフェストによるターン数制限
|
||||||
- [x] pod バイナリエントリポイント
|
- [x] pod バイナリエントリポイント
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,6 @@ use serde_json::json;
|
||||||
|
|
||||||
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
||||||
use llm_worker::state::Mutable;
|
use llm_worker::state::Mutable;
|
||||||
use llm_worker::ToolRegistryError;
|
|
||||||
use llm_worker::Worker;
|
use llm_worker::Worker;
|
||||||
use llm_worker::llm_client::LlmClient;
|
use llm_worker::llm_client::LlmClient;
|
||||||
|
|
||||||
|
|
@ -332,12 +331,11 @@ fn apply_selector(content: &Content, selector: &Selector) -> Result<String, Tool
|
||||||
pub fn register_inspect_tool<C, B>(
|
pub fn register_inspect_tool<C, B>(
|
||||||
worker: &mut Worker<C, Mutable>,
|
worker: &mut Worker<C, Mutable>,
|
||||||
blob_store: Arc<B>,
|
blob_store: Arc<B>,
|
||||||
) -> Result<(), ToolRegistryError>
|
) where
|
||||||
where
|
|
||||||
C: LlmClient,
|
C: LlmClient,
|
||||||
B: BlobStore + 'static,
|
B: BlobStore + 'static,
|
||||||
{
|
{
|
||||||
worker.register_tool(InspectTool::<B>::tool_definition(blob_store))
|
worker.register_tool(InspectTool::<B>::tool_definition(blob_store));
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── Tests ───────────────────────────────────────────────────────────────────
|
// ─── Tests ───────────────────────────────────────────────────────────────────
|
||||||
|
|
|
||||||
|
|
@ -43,12 +43,13 @@ pub enum SessionError {
|
||||||
|
|
||||||
/// Persistent session wrapping a [`Worker`].
|
/// Persistent session wrapping a [`Worker`].
|
||||||
///
|
///
|
||||||
/// The `worker` field is public for direct access to Worker APIs
|
/// Use [`worker()`](Self::worker) / [`worker_mut()`](Self::worker_mut) to
|
||||||
/// (tool registration, hook setup, subscriber management, etc.).
|
/// access the underlying Worker for configuration (tool registration, etc.).
|
||||||
/// State-mutating operations (`run`, `resume`) should go through
|
/// State-mutating operations (`run`, `resume`) should go through Session
|
||||||
/// Session methods to ensure proper logging.
|
/// methods to ensure proper logging.
|
||||||
pub struct Session<C: LlmClient, St: Store> {
|
pub struct Session<C: LlmClient, St: Store> {
|
||||||
pub worker: Worker<C, Mutable>,
|
/// Always `Some` outside of `run()` / `resume()`.
|
||||||
|
worker: Option<Worker<C, Mutable>>,
|
||||||
store: St,
|
store: St,
|
||||||
session_id: SessionId,
|
session_id: SessionId,
|
||||||
head_hash: Option<EntryHash>,
|
head_hash: Option<EntryHash>,
|
||||||
|
|
@ -78,7 +79,7 @@ impl<C: LlmClient, St: Store> Session<C, St> {
|
||||||
store.append(session_id, &hashed_entry).await?;
|
store.append(session_id, &hashed_entry).await?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
worker,
|
worker: Some(worker),
|
||||||
store,
|
store,
|
||||||
session_id,
|
session_id,
|
||||||
head_hash: Some(hashed),
|
head_hash: Some(hashed),
|
||||||
|
|
@ -87,9 +88,6 @@ impl<C: LlmClient, St: Store> Session<C, St> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Restore a session from a stored log.
|
/// Restore a session from a stored log.
|
||||||
///
|
|
||||||
/// Reads all log entries, collects state from them,
|
|
||||||
/// and returns a `Session` ready for `resume()`.
|
|
||||||
pub async fn restore(
|
pub async fn restore(
|
||||||
client: C,
|
client: C,
|
||||||
store: St,
|
store: St,
|
||||||
|
|
@ -109,7 +107,7 @@ impl<C: LlmClient, St: Store> Session<C, St> {
|
||||||
worker.set_last_run_interrupted(state.last_run_interrupted);
|
worker.set_last_run_interrupted(state.last_run_interrupted);
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
worker,
|
worker: Some(worker),
|
||||||
store,
|
store,
|
||||||
session_id,
|
session_id,
|
||||||
head_hash: state.head_hash,
|
head_hash: state.head_hash,
|
||||||
|
|
@ -117,6 +115,20 @@ impl<C: LlmClient, St: Store> Session<C, St> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn w(&self) -> &Worker<C, Mutable> {
|
||||||
|
self.worker.as_ref().expect("worker taken during run")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reference to the underlying Worker.
|
||||||
|
pub fn worker(&self) -> &Worker<C, Mutable> {
|
||||||
|
self.w()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mutable reference to the underlying Worker.
|
||||||
|
pub fn worker_mut(&mut self) -> &mut Worker<C, Mutable> {
|
||||||
|
self.worker.as_mut().expect("worker taken during run")
|
||||||
|
}
|
||||||
|
|
||||||
/// The session ID.
|
/// The session ID.
|
||||||
pub fn session_id(&self) -> SessionId {
|
pub fn session_id(&self) -> SessionId {
|
||||||
self.session_id
|
self.session_id
|
||||||
|
|
@ -133,15 +145,23 @@ impl<C: LlmClient, St: Store> Session<C, St> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run a user turn, logging all state changes.
|
/// Run a user turn, logging all state changes.
|
||||||
|
///
|
||||||
|
/// Internally locks the Worker (flushing pending tools), runs the turn,
|
||||||
|
/// then unlocks back to Mutable state.
|
||||||
pub async fn run(
|
pub async fn run(
|
||||||
&mut self,
|
&mut self,
|
||||||
user_input: impl Into<String>,
|
user_input: impl Into<String>,
|
||||||
) -> Result<WorkerResult, SessionError> {
|
) -> Result<WorkerResult, SessionError> {
|
||||||
|
let input = user_input.into();
|
||||||
self.ensure_head_or_fork().await?;
|
self.ensure_head_or_fork().await?;
|
||||||
|
|
||||||
let history_before = self.worker.history().len();
|
let history_before = self.w().history().len();
|
||||||
|
|
||||||
let result = self.worker.run(user_input).await;
|
// lock → run → unlock (use lock() directly to keep worker on error)
|
||||||
|
let worker = self.worker.take().expect("worker taken during run");
|
||||||
|
let mut locked = worker.lock();
|
||||||
|
let result = locked.run(input).await;
|
||||||
|
self.worker = Some(locked.unlock());
|
||||||
|
|
||||||
self.log_history_delta(history_before).await?;
|
self.log_history_delta(history_before).await?;
|
||||||
self.log_turn_end().await?;
|
self.log_turn_end().await?;
|
||||||
|
|
@ -154,9 +174,13 @@ impl<C: LlmClient, St: Store> Session<C, St> {
|
||||||
pub async fn resume(&mut self) -> Result<WorkerResult, SessionError> {
|
pub async fn resume(&mut self) -> Result<WorkerResult, SessionError> {
|
||||||
self.ensure_head_or_fork().await?;
|
self.ensure_head_or_fork().await?;
|
||||||
|
|
||||||
let history_before = self.worker.history().len();
|
let history_before = self.w().history().len();
|
||||||
|
|
||||||
let result = self.worker.resume().await;
|
// lock → resume → unlock
|
||||||
|
let worker = self.worker.take().expect("worker taken during run");
|
||||||
|
let mut locked = worker.lock();
|
||||||
|
let result = locked.resume().await;
|
||||||
|
self.worker = Some(locked.unlock());
|
||||||
|
|
||||||
self.log_history_delta(history_before).await?;
|
self.log_history_delta(history_before).await?;
|
||||||
self.log_turn_end().await?;
|
self.log_turn_end().await?;
|
||||||
|
|
@ -166,15 +190,13 @@ impl<C: LlmClient, St: Store> Session<C, St> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Fork this session at its current state.
|
/// Fork this session at its current state.
|
||||||
/// Returns the new session ID. The new log contains a `SessionStart`
|
|
||||||
/// seeded with the current history.
|
|
||||||
pub async fn fork(&self) -> Result<SessionId, StoreError> {
|
pub async fn fork(&self) -> Result<SessionId, StoreError> {
|
||||||
let fork_id = crate::new_session_id();
|
let fork_id = crate::new_session_id();
|
||||||
let entry = LogEntry::SessionStart {
|
let entry = LogEntry::SessionStart {
|
||||||
ts: session_log::now_millis(),
|
ts: session_log::now_millis(),
|
||||||
system_prompt: self.worker.get_system_prompt().map(String::from),
|
system_prompt: self.w().get_system_prompt().map(String::from),
|
||||||
config: self.worker.request_config().clone(),
|
config: self.w().request_config().clone(),
|
||||||
history: self.worker.history().to_vec(),
|
history: self.w().history().to_vec(),
|
||||||
};
|
};
|
||||||
let hashed = session_log::compute_hash(None, &entry);
|
let hashed = session_log::compute_hash(None, &entry);
|
||||||
let hashed_entry = HashedEntry {
|
let hashed_entry = HashedEntry {
|
||||||
|
|
@ -189,8 +211,6 @@ impl<C: LlmClient, St: Store> Session<C, St> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Fork from an arbitrary point in a stored session's log.
|
/// Fork from an arbitrary point in a stored session's log.
|
||||||
/// Finds the entry matching `at_hash` and creates a new session
|
|
||||||
/// with state reconstructed up to that point.
|
|
||||||
pub async fn fork_at(
|
pub async fn fork_at(
|
||||||
store: &St,
|
store: &St,
|
||||||
source_id: SessionId,
|
source_id: SessionId,
|
||||||
|
|
@ -221,12 +241,12 @@ impl<C: LlmClient, St: Store> Session<C, St> {
|
||||||
Ok(fork_id)
|
Ok(fork_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Log a `CacheLocked` entry.
|
/// Log a `Locked` entry.
|
||||||
pub async fn log_cache_locked(
|
pub async fn log_cache_locked(
|
||||||
&mut self,
|
&mut self,
|
||||||
locked_prefix_len: usize,
|
locked_prefix_len: usize,
|
||||||
) -> Result<(), StoreError> {
|
) -> Result<(), StoreError> {
|
||||||
let entry = LogEntry::CacheLocked {
|
let entry = LogEntry::Locked {
|
||||||
ts: session_log::now_millis(),
|
ts: session_log::now_millis(),
|
||||||
locked_prefix_len,
|
locked_prefix_len,
|
||||||
};
|
};
|
||||||
|
|
@ -245,14 +265,13 @@ impl<C: LlmClient, St: Store> Session<C, St> {
|
||||||
pub async fn log_config_changed(&mut self) -> Result<(), StoreError> {
|
pub async fn log_config_changed(&mut self) -> Result<(), StoreError> {
|
||||||
let entry = LogEntry::ConfigChanged {
|
let entry = LogEntry::ConfigChanged {
|
||||||
ts: session_log::now_millis(),
|
ts: session_log::now_millis(),
|
||||||
config: self.worker.request_config().clone(),
|
config: self.w().request_config().clone(),
|
||||||
};
|
};
|
||||||
self.append_entry(entry).await
|
self.append_entry(entry).await
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Private helpers ──────────────────────────────────────────────────
|
// ── Private helpers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
/// Append a `LogEntry`, computing its hash and updating `head_hash`.
|
|
||||||
async fn append_entry(&mut self, entry: LogEntry) -> Result<(), StoreError> {
|
async fn append_entry(&mut self, entry: LogEntry) -> Result<(), StoreError> {
|
||||||
let hash = session_log::compute_hash(self.head_hash.as_ref(), &entry);
|
let hash = session_log::compute_hash(self.head_hash.as_ref(), &entry);
|
||||||
let hashed_entry = HashedEntry {
|
let hashed_entry = HashedEntry {
|
||||||
|
|
@ -267,19 +286,17 @@ impl<C: LlmClient, St: Store> Session<C, St> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check that the store's head still matches ours. If not, auto-fork.
|
|
||||||
async fn ensure_head_or_fork(&mut self) -> Result<(), StoreError> {
|
async fn ensure_head_or_fork(&mut self) -> Result<(), StoreError> {
|
||||||
let store_head = self.store.read_head_hash(self.session_id).await?;
|
let store_head = self.store.read_head_hash(self.session_id).await?;
|
||||||
if store_head == self.head_hash {
|
if store_head == self.head_hash {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
// Another writer advanced this session — fork from our known state.
|
|
||||||
let fork_id = crate::new_session_id();
|
let fork_id = crate::new_session_id();
|
||||||
let entry = LogEntry::SessionStart {
|
let entry = LogEntry::SessionStart {
|
||||||
ts: session_log::now_millis(),
|
ts: session_log::now_millis(),
|
||||||
system_prompt: self.worker.get_system_prompt().map(String::from),
|
system_prompt: self.w().get_system_prompt().map(String::from),
|
||||||
config: self.worker.request_config().clone(),
|
config: self.w().request_config().clone(),
|
||||||
history: self.worker.history().to_vec(),
|
history: self.w().history().to_vec(),
|
||||||
};
|
};
|
||||||
let hash = session_log::compute_hash(None, &entry);
|
let hash = session_log::compute_hash(None, &entry);
|
||||||
let hashed_entry = HashedEntry {
|
let hashed_entry = HashedEntry {
|
||||||
|
|
@ -296,7 +313,7 @@ impl<C: LlmClient, St: Store> Session<C, St> {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn log_history_delta(&mut self, before_len: usize) -> Result<(), StoreError> {
|
async fn log_history_delta(&mut self, before_len: usize) -> Result<(), StoreError> {
|
||||||
let history = self.worker.history();
|
let history = self.w().history();
|
||||||
if history.len() <= before_len {
|
if history.len() <= before_len {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
@ -356,7 +373,7 @@ impl<C: LlmClient, St: Store> Session<C, St> {
|
||||||
async fn log_turn_end(&mut self) -> Result<(), StoreError> {
|
async fn log_turn_end(&mut self) -> Result<(), StoreError> {
|
||||||
self.append_entry(LogEntry::TurnEnd {
|
self.append_entry(LogEntry::TurnEnd {
|
||||||
ts: session_log::now_millis(),
|
ts: session_log::now_millis(),
|
||||||
turn_count: self.worker.turn_count(),
|
turn_count: self.w().turn_count(),
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
@ -376,7 +393,7 @@ impl<C: LlmClient, St: Store> Session<C, St> {
|
||||||
self.append_entry(LogEntry::RunOutcome {
|
self.append_entry(LogEntry::RunOutcome {
|
||||||
ts: session_log::now_millis(),
|
ts: session_log::now_millis(),
|
||||||
outcome,
|
outcome,
|
||||||
interrupted: self.worker.last_run_interrupted(),
|
interrupted: self.w().last_run_interrupted(),
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,7 @@ pub struct HashedEntry {
|
||||||
/// - `SessionStart` — always the first entry; captures initial state
|
/// - `SessionStart` — always the first entry; captures initial state
|
||||||
/// - `UserInput` / `AssistantItems` / `ToolResults` / `HookInjectedItems` — history appends
|
/// - `UserInput` / `AssistantItems` / `ToolResults` / `HookInjectedItems` — history appends
|
||||||
/// - `TurnEnd` — turn boundary marker
|
/// - `TurnEnd` — turn boundary marker
|
||||||
/// - `CacheLocked` / `CacheUnlocked` — KV cache state transitions
|
/// - `Locked` / `CacheUnlocked` — KV cache state transitions
|
||||||
/// - `RunOutcome` — marks end of a `run()` or `resume()` call
|
/// - `RunOutcome` — marks end of a `run()` or `resume()` call
|
||||||
/// - `ConfigChanged` — `RequestConfig` mutation
|
/// - `ConfigChanged` — `RequestConfig` mutation
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
@ -119,7 +119,8 @@ pub enum LogEntry {
|
||||||
TurnEnd { ts: u64, turn_count: usize },
|
TurnEnd { ts: u64, turn_count: usize },
|
||||||
|
|
||||||
/// KV cache locked. Records the history prefix length that is now immutable.
|
/// KV cache locked. Records the history prefix length that is now immutable.
|
||||||
CacheLocked { ts: u64, locked_prefix_len: usize },
|
#[serde(alias = "cache_locked")]
|
||||||
|
Locked { ts: u64, locked_prefix_len: usize },
|
||||||
|
|
||||||
/// KV cache unlocked.
|
/// KV cache unlocked.
|
||||||
CacheUnlocked { ts: u64 },
|
CacheUnlocked { ts: u64 },
|
||||||
|
|
@ -200,7 +201,7 @@ pub fn collect_state(entries: &[HashedEntry]) -> RestoredState {
|
||||||
LogEntry::TurnEnd { turn_count, .. } => {
|
LogEntry::TurnEnd { turn_count, .. } => {
|
||||||
state.turn_count = *turn_count;
|
state.turn_count = *turn_count;
|
||||||
}
|
}
|
||||||
LogEntry::CacheLocked {
|
LogEntry::Locked {
|
||||||
locked_prefix_len, ..
|
locked_prefix_len, ..
|
||||||
} => {
|
} => {
|
||||||
state.locked_prefix_len = *locked_prefix_len;
|
state.locked_prefix_len = *locked_prefix_len;
|
||||||
|
|
@ -354,7 +355,7 @@ mod tests {
|
||||||
config: RequestConfig::default(),
|
config: RequestConfig::default(),
|
||||||
history: vec![Item::user_message("a"), Item::assistant_message("b")],
|
history: vec![Item::user_message("a"), Item::assistant_message("b")],
|
||||||
},
|
},
|
||||||
LogEntry::CacheLocked {
|
LogEntry::Locked {
|
||||||
ts: 2000,
|
ts: 2000,
|
||||||
locked_prefix_len: 2,
|
locked_prefix_len: 2,
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -159,8 +159,8 @@ async fn session_restore_round_trip() {
|
||||||
|
|
||||||
session.run("Hi").await.unwrap();
|
session.run("Hi").await.unwrap();
|
||||||
|
|
||||||
let original_history = session.worker.history().to_vec();
|
let original_history = session.worker().history().to_vec();
|
||||||
let original_turn_count = session.worker.turn_count();
|
let original_turn_count = session.worker().turn_count();
|
||||||
let original_head_hash = session.head_hash().cloned();
|
let original_head_hash = session.head_hash().cloned();
|
||||||
|
|
||||||
// Restore
|
// Restore
|
||||||
|
|
@ -170,10 +170,10 @@ async fn session_restore_round_trip() {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(restored.worker.history().len(), original_history.len());
|
assert_eq!(restored.worker().history().len(), original_history.len());
|
||||||
assert_eq!(restored.worker.turn_count(), original_turn_count);
|
assert_eq!(restored.worker().turn_count(), original_turn_count);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
restored.worker.get_system_prompt().map(String::from),
|
restored.worker().get_system_prompt().map(String::from),
|
||||||
Some("You are helpful.".to_string())
|
Some("You are helpful.".to_string())
|
||||||
);
|
);
|
||||||
assert_eq!(restored.head_hash(), original_head_hash.as_ref());
|
assert_eq!(restored.head_hash(), original_head_hash.as_ref());
|
||||||
|
|
@ -184,7 +184,7 @@ async fn session_run_with_tool_call() {
|
||||||
let (_dir, store) = make_store().await;
|
let (_dir, store) = make_store().await;
|
||||||
let client = MockLlmClient::with_responses(tool_call_events());
|
let client = MockLlmClient::with_responses(tool_call_events());
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
worker.register_tool(weather_tool_definition()).unwrap();
|
worker.register_tool(weather_tool_definition());
|
||||||
|
|
||||||
let mut session = Session::new(worker, store.clone(), SessionConfig::default())
|
let mut session = Session::new(worker, store.clone(), SessionConfig::default())
|
||||||
.await
|
.await
|
||||||
|
|
@ -213,7 +213,7 @@ async fn session_resume_after_pause() {
|
||||||
// First run: tool call with pause hook → Paused
|
// First run: tool call with pause hook → Paused
|
||||||
let client = MockLlmClient::with_responses(tool_call_events());
|
let client = MockLlmClient::with_responses(tool_call_events());
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
worker.register_tool(weather_tool_definition()).unwrap();
|
worker.register_tool(weather_tool_definition());
|
||||||
worker.set_interceptor(PausePolicy);
|
worker.set_interceptor(PausePolicy);
|
||||||
|
|
||||||
let mut session = Session::new(worker, store.clone(), SessionConfig::default())
|
let mut session = Session::new(worker, store.clone(), SessionConfig::default())
|
||||||
|
|
@ -251,7 +251,7 @@ async fn session_resume_after_pause() {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(restored.worker.last_run_interrupted());
|
assert!(restored.worker().last_run_interrupted());
|
||||||
|
|
||||||
// resume may or may not succeed depending on Worker internal state,
|
// resume may or may not succeed depending on Worker internal state,
|
||||||
// but the restore itself should work
|
// but the restore itself should work
|
||||||
|
|
@ -271,7 +271,7 @@ async fn session_fork_preserves_state() {
|
||||||
|
|
||||||
session.run("Hello").await.unwrap();
|
session.run("Hello").await.unwrap();
|
||||||
|
|
||||||
let original_history_len = session.worker.history().len();
|
let original_history_len = session.worker().history().len();
|
||||||
let fork_id = session.fork().await.unwrap();
|
let fork_id = session.fork().await.unwrap();
|
||||||
|
|
||||||
// Fork should have a SessionStart with the current history
|
// Fork should have a SessionStart with the current history
|
||||||
|
|
@ -334,7 +334,7 @@ async fn session_config_changed_logged() {
|
||||||
|
|
||||||
// Modify config via worker and log it
|
// Modify config via worker and log it
|
||||||
session
|
session
|
||||||
.worker
|
.worker_mut()
|
||||||
.set_request_config(RequestConfig::default().with_temperature(0.7));
|
.set_request_config(RequestConfig::default().with_temperature(0.7));
|
||||||
session.log_config_changed().await.unwrap();
|
session.log_config_changed().await.unwrap();
|
||||||
|
|
||||||
|
|
@ -367,13 +367,13 @@ async fn session_cache_lock_unlock_logged() {
|
||||||
let has_locked = entries.iter().any(|e| {
|
let has_locked = entries.iter().any(|e| {
|
||||||
matches!(
|
matches!(
|
||||||
&e.entry,
|
&e.entry,
|
||||||
LogEntry::CacheLocked {
|
LogEntry::Locked {
|
||||||
locked_prefix_len: 5,
|
locked_prefix_len: 5,
|
||||||
..
|
..
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
assert!(has_locked, "should have CacheLocked entry");
|
assert!(has_locked, "should have Locked entry");
|
||||||
|
|
||||||
let has_unlocked = entries
|
let has_unlocked = entries
|
||||||
.iter()
|
.iter()
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,7 @@
|
||||||
|
|
||||||
use llm_worker::llm_client::providers::anthropic::AnthropicClient;
|
use llm_worker::llm_client::providers::anthropic::AnthropicClient;
|
||||||
use llm_worker::{Worker, WorkerResult};
|
use llm_worker::{Worker, WorkerResult};
|
||||||
use std::sync::Arc;
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::sync::Mutex;
|
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
|
@ -25,48 +23,38 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY environment variable not set");
|
std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY environment variable not set");
|
||||||
|
|
||||||
let client = AnthropicClient::new(&api_key, "claude-sonnet-4-20250514");
|
let client = AnthropicClient::new(&api_key, "claude-sonnet-4-20250514");
|
||||||
let worker = Arc::new(Mutex::new(Worker::new(client)));
|
let worker = Worker::new(client);
|
||||||
|
|
||||||
println!("🚀 Starting Worker...");
|
println!("🚀 Starting Worker...");
|
||||||
println!("💡 Will cancel after 2 seconds\n");
|
println!("💡 Will cancel after 2 seconds\n");
|
||||||
|
|
||||||
// Get cancel sender first (without holding lock)
|
// Get cancel sender before run (Mutable state)
|
||||||
let cancel_tx = {
|
let cancel_tx = worker.cancel_sender();
|
||||||
let w = worker.lock().await;
|
|
||||||
w.cancel_sender()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Task 1: Run Worker
|
// Task: Cancel after 2 seconds
|
||||||
let worker_clone = worker.clone();
|
|
||||||
let task = tokio::spawn(async move {
|
|
||||||
let mut w = worker_clone.lock().await;
|
|
||||||
println!("📡 Sending request to LLM...");
|
|
||||||
|
|
||||||
match w.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await {
|
|
||||||
Ok(WorkerResult::Finished) => {
|
|
||||||
println!("✅ Task completed normally");
|
|
||||||
}
|
|
||||||
Ok(WorkerResult::Paused) => {
|
|
||||||
println!("⏸️ Task paused");
|
|
||||||
}
|
|
||||||
Ok(WorkerResult::LimitReached) => {
|
|
||||||
println!("🔒 Turn limit reached");
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
println!("❌ Task error: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Task 2: Cancel after 2 seconds
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||||
println!("\n🛑 Cancelling worker...");
|
println!("\n🛑 Cancelling worker...");
|
||||||
let _ = cancel_tx.send(()).await;
|
let _ = cancel_tx.send(()).await;
|
||||||
});
|
});
|
||||||
|
|
||||||
// Wait for task completion
|
println!("📡 Sending request to LLM...");
|
||||||
task.await?;
|
|
||||||
|
// Mutable::run consumes self → (Locked, WorkerResult)
|
||||||
|
match worker.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await {
|
||||||
|
Ok((_locked, WorkerResult::Finished)) => {
|
||||||
|
println!("✅ Task completed normally");
|
||||||
|
}
|
||||||
|
Ok((_locked, WorkerResult::Paused)) => {
|
||||||
|
println!("⏸️ Task paused");
|
||||||
|
}
|
||||||
|
Ok((_locked, WorkerResult::LimitReached)) => {
|
||||||
|
println!("🔒 Turn limit reached");
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
println!("❌ Task error: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
println!("\n✨ Demo complete!");
|
println!("\n✨ Demo complete!");
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -438,10 +438,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// Register tools (unless --no-tools)
|
// Register tools (unless --no-tools)
|
||||||
if !args.no_tools {
|
if !args.no_tools {
|
||||||
let app = AppContext;
|
let app = AppContext;
|
||||||
worker
|
worker.register_tool(app.get_current_time_definition());
|
||||||
.register_tool(app.get_current_time_definition())
|
worker.register_tool(app.calculate_definition());
|
||||||
.unwrap();
|
|
||||||
worker.register_tool(app.calculate_definition()).unwrap();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register streaming display handlers
|
// Register streaming display handlers
|
||||||
|
|
@ -465,7 +463,27 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Interactive loop
|
// Interactive loop — first input transitions Mutable → Locked
|
||||||
|
print!("\n👤 You: ");
|
||||||
|
io::stdout().flush()?;
|
||||||
|
|
||||||
|
let mut first_input = String::new();
|
||||||
|
io::stdin().read_line(&mut first_input)?;
|
||||||
|
let first_input = first_input.trim();
|
||||||
|
|
||||||
|
if first_input == "quit" || first_input == "exit" || first_input.is_empty() {
|
||||||
|
println!("\n👋 Goodbye!");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let (mut locked, _) = match worker.run(first_input).await {
|
||||||
|
Ok(pair) => pair,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("\n❌ Error: {}", e);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
print!("\n👤 You: ");
|
print!("\n👤 You: ");
|
||||||
io::stdout().flush()?;
|
io::stdout().flush()?;
|
||||||
|
|
@ -483,8 +501,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run Worker (Worker manages history)
|
match locked.run(input).await {
|
||||||
match worker.run(input).await {
|
|
||||||
Ok(_) => {}
|
Ok(_) => {}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("\n❌ Error: {}", e);
|
eprintln!("\n❌ Error: {}", e);
|
||||||
|
|
|
||||||
|
|
@ -27,12 +27,14 @@
|
||||||
//!
|
//!
|
||||||
//! # Cache Protection
|
//! # Cache Protection
|
||||||
//!
|
//!
|
||||||
//! To maximize KV cache hit rate, transition to the locked state
|
//! `run()` automatically locks the cache. To edit state between turns,
|
||||||
//! with [`Worker::lock()`] before execution.
|
//! call `unlock_cache()` first; the next `run()` re-locks automatically.
|
||||||
//!
|
//!
|
||||||
//! ```ignore
|
//! ```ignore
|
||||||
//! let mut locked = worker.lock();
|
//! worker.run("user input").await?;
|
||||||
//! locked.run("user input").await?;
|
//! worker.unlock_cache();
|
||||||
|
//! worker.set_system_prompt("new prompt");
|
||||||
|
//! worker.run("next input").await?;
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
mod handler;
|
mod handler;
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
//! Worker State
|
//! Worker State
|
||||||
//!
|
//!
|
||||||
//! State marker types for cache protection using the Type-state pattern.
|
//! State marker types for cache protection using the Type-state pattern.
|
||||||
//! Worker has state transitions from `Mutable` → `CacheLocked`.
|
//! Worker has state transitions from `Mutable` → `Locked`.
|
||||||
|
|
||||||
/// Marker trait representing Worker state
|
/// Marker trait representing Worker state
|
||||||
///
|
///
|
||||||
|
|
@ -19,7 +19,7 @@ mod private {
|
||||||
/// - Editing message history (add, delete, clear)
|
/// - Editing message history (add, delete, clear)
|
||||||
/// - Registering tools and hooks
|
/// - Registering tools and hooks
|
||||||
///
|
///
|
||||||
/// Can transition to [`CacheLocked`] state via `Worker::lock()`.
|
/// Can transition to [`Locked`] state via `Worker::lock()`.
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
|
|
@ -54,7 +54,7 @@ impl WorkerState for Mutable {}
|
||||||
/// Can return to [`Mutable`] state via `Worker::unlock()`,
|
/// Can return to [`Mutable`] state via `Worker::unlock()`,
|
||||||
/// but note that cache protection will be released.
|
/// but note that cache protection will be released.
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
pub struct CacheLocked;
|
pub struct Locked;
|
||||||
|
|
||||||
impl private::Sealed for CacheLocked {}
|
impl private::Sealed for Locked {}
|
||||||
impl WorkerState for CacheLocked {}
|
impl WorkerState for Locked {}
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ pub enum ToolServerError {
|
||||||
#[derive(Clone, Default)]
|
#[derive(Clone, Default)]
|
||||||
pub struct ToolServer {
|
pub struct ToolServer {
|
||||||
tools: Arc<Mutex<ToolMap>>,
|
tools: Arc<Mutex<ToolMap>>,
|
||||||
|
pending: Arc<Mutex<Vec<WorkerToolDefinition>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolServer {
|
impl ToolServer {
|
||||||
|
|
@ -38,6 +39,7 @@ impl ToolServer {
|
||||||
pub fn handle(&self) -> ToolServerHandle {
|
pub fn handle(&self) -> ToolServerHandle {
|
||||||
ToolServerHandle {
|
ToolServerHandle {
|
||||||
tools: Arc::clone(&self.tools),
|
tools: Arc::clone(&self.tools),
|
||||||
|
pending: Arc::clone(&self.pending),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -46,32 +48,57 @@ impl ToolServer {
|
||||||
#[derive(Clone, Default)]
|
#[derive(Clone, Default)]
|
||||||
pub struct ToolServerHandle {
|
pub struct ToolServerHandle {
|
||||||
tools: Arc<Mutex<ToolMap>>,
|
tools: Arc<Mutex<ToolMap>>,
|
||||||
|
pending: Arc<Mutex<Vec<WorkerToolDefinition>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolServerHandle {
|
impl ToolServerHandle {
|
||||||
/// Register one tool.
|
/// Queue a tool factory for deferred initialization.
|
||||||
pub(crate) fn register_tool(
|
///
|
||||||
&self,
|
/// The factory is **not** called here; it is stored and executed
|
||||||
factory: WorkerToolDefinition,
|
/// when [`flush_pending`](Self::flush_pending) is called (typically
|
||||||
) -> Result<(), ToolServerError> {
|
/// at the start of `Worker::run()`).
|
||||||
let (meta, instance) = factory();
|
pub(crate) fn register_tool(&self, factory: WorkerToolDefinition) {
|
||||||
let mut guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
|
self.pending
|
||||||
if guard.contains_key(&meta.name) {
|
.lock()
|
||||||
return Err(ToolServerError::DuplicateName(meta.name));
|
.unwrap_or_else(|e| e.into_inner())
|
||||||
}
|
.push(factory);
|
||||||
guard.insert(meta.name.clone(), (meta, instance));
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register many tools.
|
/// Queue many tool factories for deferred initialization.
|
||||||
pub(crate) fn register_tools(
|
pub(crate) fn register_tools(
|
||||||
&self,
|
&self,
|
||||||
factories: impl IntoIterator<Item = WorkerToolDefinition>,
|
factories: impl IntoIterator<Item = WorkerToolDefinition>,
|
||||||
) -> Result<(), ToolServerError> {
|
) {
|
||||||
for factory in factories {
|
let mut guard = self.pending.lock().unwrap_or_else(|e| e.into_inner());
|
||||||
self.register_tool(factory)?;
|
guard.extend(factories);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute all pending factories and register the resulting tools.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if any factory produces a tool whose name collides with
|
||||||
|
/// an already-registered tool. Duplicate names are a programming
|
||||||
|
/// error and should be caught during development.
|
||||||
|
pub(crate) fn flush_pending(&self) {
|
||||||
|
let pending: Vec<_> = {
|
||||||
|
let mut guard = self.pending.lock().unwrap_or_else(|e| e.into_inner());
|
||||||
|
std::mem::take(&mut *guard)
|
||||||
|
};
|
||||||
|
if pending.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Execute all factories first, then validate and insert atomically.
|
||||||
|
let materialized: Vec<_> = pending.into_iter().map(|f| f()).collect();
|
||||||
|
let mut tools = self.tools.lock().unwrap_or_else(|e| e.into_inner());
|
||||||
|
for (meta, instance) in materialized {
|
||||||
|
assert!(
|
||||||
|
!tools.contains_key(&meta.name),
|
||||||
|
"duplicate tool name: '{}'",
|
||||||
|
meta.name,
|
||||||
|
);
|
||||||
|
tools.insert(meta.name.clone(), (meta, instance));
|
||||||
}
|
}
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a tool by name for hook contexts.
|
/// Get a tool by name for hook contexts.
|
||||||
|
|
@ -143,19 +170,37 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn register_duplicate_name_fails() {
|
fn flush_pending_registers_tools() {
|
||||||
let handle = ToolServer::new().handle();
|
let handle = ToolServer::new().handle();
|
||||||
handle.register_tool(def("alpha")).expect("first register");
|
handle.register_tool(def("alpha"));
|
||||||
let err = handle
|
handle.register_tool(def("beta"));
|
||||||
.register_tool(def("alpha"))
|
|
||||||
.expect_err("duplicate should fail");
|
// Before flush, no tools are available
|
||||||
assert_eq!(err, ToolServerError::DuplicateName("alpha".to_string()));
|
assert!(handle.get_tool("alpha").is_none());
|
||||||
|
|
||||||
|
handle.flush_pending();
|
||||||
|
|
||||||
|
// After flush, tools are available
|
||||||
|
assert!(handle.get_tool("alpha").is_some());
|
||||||
|
assert!(handle.get_tool("beta").is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "duplicate tool name: 'alpha'")]
|
||||||
|
fn flush_pending_duplicate_name_panics() {
|
||||||
|
let handle = ToolServer::new().handle();
|
||||||
|
handle.register_tool(def("alpha"));
|
||||||
|
handle.flush_pending();
|
||||||
|
|
||||||
|
handle.register_tool(def("alpha"));
|
||||||
|
handle.flush_pending(); // panics
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn call_tool_success_and_not_found() {
|
async fn call_tool_success_and_not_found() {
|
||||||
let handle = ToolServer::new().handle();
|
let handle = ToolServer::new().handle();
|
||||||
handle.register_tool(def("echo")).expect("register");
|
handle.register_tool(def("echo"));
|
||||||
|
handle.flush_pending();
|
||||||
|
|
||||||
let out = handle.call_tool("echo", r#"{"x":1}"#).await.expect("call");
|
let out = handle.call_tool("echo", r#"{"x":1}"#).await.expect("call");
|
||||||
assert_eq!(out, r#"{"x":1}"#);
|
assert_eq!(out, r#"{"x":1}"#);
|
||||||
|
|
@ -170,9 +215,10 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn tool_definitions_are_sorted() {
|
fn tool_definitions_are_sorted() {
|
||||||
let handle = ToolServer::new().handle();
|
let handle = ToolServer::new().handle();
|
||||||
handle.register_tool(def("zeta")).expect("register zeta");
|
handle.register_tool(def("zeta"));
|
||||||
handle.register_tool(def("alpha")).expect("register alpha");
|
handle.register_tool(def("alpha"));
|
||||||
handle.register_tool(def("beta")).expect("register beta");
|
handle.register_tool(def("beta"));
|
||||||
|
handle.flush_pending();
|
||||||
|
|
||||||
let names: Vec<_> = handle
|
let names: Vec<_> = handle
|
||||||
.tool_definitions_sorted()
|
.tool_definitions_sorted()
|
||||||
|
|
@ -181,4 +227,11 @@ mod tests {
|
||||||
.collect();
|
.collect();
|
||||||
assert_eq!(names, vec!["alpha", "beta", "zeta"]);
|
assert_eq!(names, vec!["alpha", "beta", "zeta"]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn flush_pending_is_noop_when_empty() {
|
||||||
|
let handle = ToolServer::new().handle();
|
||||||
|
handle.flush_pending();
|
||||||
|
handle.flush_pending();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ use crate::{
|
||||||
DefaultInterceptor, Interceptor, PostToolAction, PreRequestAction, PreToolAction,
|
DefaultInterceptor, Interceptor, PostToolAction, PreRequestAction, PreToolAction,
|
||||||
PromptAction, ToolCallInfo, ToolResultInfo, TurnEndAction,
|
PromptAction, ToolCallInfo, ToolResultInfo, TurnEndAction,
|
||||||
},
|
},
|
||||||
state::{CacheLocked, Mutable, WorkerState},
|
state::{Locked, Mutable, WorkerState},
|
||||||
callback::{
|
callback::{
|
||||||
ClosureMetaHandler, ClosureTextBlockHandler, ClosureToolUseBlockHandler, TextBlockScope,
|
ClosureMetaHandler, ClosureTextBlockHandler, ClosureToolUseBlockHandler, TextBlockScope,
|
||||||
ToolUseBlockScope,
|
ToolUseBlockScope,
|
||||||
|
|
@ -22,12 +22,9 @@ use crate::{
|
||||||
timeline::{TextBlockCollector, Timeline, ToolCallCollector},
|
timeline::{TextBlockCollector, Timeline, ToolCallCollector},
|
||||||
timeline::event::{ErrorEvent, StatusEvent, UsageEvent},
|
timeline::event::{ErrorEvent, StatusEvent, UsageEvent},
|
||||||
tool::{ToolCall, ToolDefinition as WorkerToolDefinition, ToolError, ToolOutputProcessor, ToolResult},
|
tool::{ToolCall, ToolDefinition as WorkerToolDefinition, ToolError, ToolOutputProcessor, ToolResult},
|
||||||
tool_server::{ToolServer, ToolServerError, ToolServerHandle},
|
tool_server::{ToolServer, ToolServerHandle},
|
||||||
};
|
};
|
||||||
|
|
||||||
// =============================================================================
|
|
||||||
// Worker Error
|
|
||||||
// =============================================================================
|
|
||||||
|
|
||||||
/// Worker errors
|
/// Worker errors
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
|
@ -57,9 +54,6 @@ pub enum ToolRegistryError {
|
||||||
DuplicateName(String),
|
DuplicateName(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
|
||||||
// Worker Config
|
|
||||||
// =============================================================================
|
|
||||||
|
|
||||||
/// Worker configuration
|
/// Worker configuration
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
|
|
@ -68,9 +62,6 @@ pub struct WorkerConfig {
|
||||||
_private: (),
|
_private: (),
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
|
||||||
// Worker Result Types
|
|
||||||
// =============================================================================
|
|
||||||
|
|
||||||
/// Worker execution result (status)
|
/// Worker execution result (status)
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|
@ -89,9 +80,6 @@ enum ToolExecutionResult {
|
||||||
Paused,
|
Paused,
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
|
||||||
// Worker
|
|
||||||
// =============================================================================
|
|
||||||
|
|
||||||
/// Central component for managing LLM interactions
|
/// Central component for managing LLM interactions
|
||||||
///
|
///
|
||||||
|
|
@ -100,32 +88,28 @@ enum ToolExecutionResult {
|
||||||
///
|
///
|
||||||
/// # State Transitions (Type-state)
|
/// # State Transitions (Type-state)
|
||||||
///
|
///
|
||||||
/// - [`Mutable`]: Initial state. System prompt and history can be freely edited.
|
/// - [`Mutable`]: Initial state. System prompt, history, and tools can be freely edited.
|
||||||
/// - [`CacheLocked`]: Cache-protected state. Transition via `lock()`. Prefix context is immutable.
|
/// - [`Locked`]: Cache-protected state. Prefix context is immutable; only `run()` / `resume()` are available.
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// Calling `run()` on a `Mutable` Worker consumes it and returns a
|
||||||
|
/// `Locked` Worker together with the result. This ensures the
|
||||||
|
/// cache prefix is fixed for optimal KV cache hit rate.
|
||||||
///
|
///
|
||||||
/// ```ignore
|
/// ```ignore
|
||||||
/// use llm_worker::{Worker, Item};
|
|
||||||
///
|
|
||||||
/// // Create a Worker and register tools
|
|
||||||
/// let mut worker = Worker::new(client)
|
/// let mut worker = Worker::new(client)
|
||||||
/// .system_prompt("You are a helpful assistant.");
|
/// .system_prompt("You are a helpful assistant.");
|
||||||
/// worker.register_tool(my_tool);
|
/// worker.register_tool(my_tool);
|
||||||
///
|
///
|
||||||
/// // Run the interaction
|
/// // Mutable::run() consumes self → Locked
|
||||||
/// let history = worker.run("Hello!").await?;
|
/// let (mut worker, _result) = worker.run("Hello").await?;
|
||||||
/// ```
|
|
||||||
///
|
///
|
||||||
/// # When Cache Protection is Needed
|
/// // Locked::run() borrows &mut self
|
||||||
|
/// worker.run("Follow-up").await?;
|
||||||
///
|
///
|
||||||
/// ```ignore
|
/// // To edit between turns, unlock back to Mutable
|
||||||
/// let mut worker = Worker::new(client)
|
/// let mut worker = worker.unlock();
|
||||||
/// .system_prompt("...");
|
/// worker.history_mut().truncate(5);
|
||||||
///
|
/// let (mut worker, _result) = worker.run("Continue").await?;
|
||||||
/// // After setting history, lock to protect cache
|
|
||||||
/// let mut locked = worker.lock();
|
|
||||||
/// locked.run("user input").await?;
|
|
||||||
/// ```
|
/// ```
|
||||||
pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
||||||
/// LLM client
|
/// LLM client
|
||||||
|
|
@ -144,7 +128,7 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
||||||
system_prompt: Option<String>,
|
system_prompt: Option<String>,
|
||||||
/// Item history (owned by Worker)
|
/// Item history (owned by Worker)
|
||||||
history: Vec<Item>,
|
history: Vec<Item>,
|
||||||
/// History length at lock time (only meaningful in CacheLocked state)
|
/// History length at lock time (only meaningful in Locked state)
|
||||||
locked_prefix_len: usize,
|
locked_prefix_len: usize,
|
||||||
/// Turn count
|
/// Turn count
|
||||||
turn_count: usize,
|
turn_count: usize,
|
||||||
|
|
@ -167,40 +151,12 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
||||||
_state: PhantomData<S>,
|
_state: PhantomData<S>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
|
||||||
// Common Implementation (available in all states)
|
|
||||||
// =============================================================================
|
|
||||||
|
|
||||||
impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
fn reset_interruption_state(&mut self) {
|
fn reset_interruption_state(&mut self) {
|
||||||
self.last_run_interrupted = false;
|
self.last_run_interrupted = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Execute a turn
|
|
||||||
///
|
|
||||||
/// Adds a new user message to history and sends a request to the LLM.
|
|
||||||
/// Automatically loops if there are tool calls.
|
|
||||||
pub async fn run(
|
|
||||||
&mut self,
|
|
||||||
user_input: impl Into<String>,
|
|
||||||
) -> Result<WorkerResult, WorkerError> {
|
|
||||||
self.reset_interruption_state();
|
|
||||||
// Interceptor: on_prompt_submit
|
|
||||||
let mut user_item = Item::user_message(user_input);
|
|
||||||
match self.interceptor.on_prompt_submit(&mut user_item).await {
|
|
||||||
PromptAction::Cancel(reason) => {
|
|
||||||
self.last_run_interrupted = true;
|
|
||||||
return self
|
|
||||||
.finalize_interruption(Err(WorkerError::Aborted(reason)))
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
PromptAction::Continue => {}
|
|
||||||
}
|
|
||||||
self.history.push(user_item);
|
|
||||||
let result = self.run_turn_loop().await;
|
|
||||||
self.finalize_interruption(result).await
|
|
||||||
}
|
|
||||||
|
|
||||||
fn drain_cancel_queue(&mut self) {
|
fn drain_cancel_queue(&mut self) {
|
||||||
use tokio::sync::mpsc::error::TryRecvError;
|
use tokio::sync::mpsc::error::TryRecvError;
|
||||||
loop {
|
loop {
|
||||||
|
|
@ -892,19 +848,8 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Resume execution (from Paused state)
|
|
||||||
///
|
|
||||||
/// Resumes turn processing from current state without adding a new user message to history.
|
|
||||||
pub async fn resume(&mut self) -> Result<WorkerResult, WorkerError> {
|
|
||||||
self.reset_interruption_state();
|
|
||||||
let result = self.run_turn_loop().await;
|
|
||||||
self.finalize_interruption(result).await
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
|
||||||
// Mutable State-Specific Implementation
|
|
||||||
// =============================================================================
|
|
||||||
|
|
||||||
impl<C: LlmClient> Worker<C, Mutable> {
|
impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
/// Create a new Worker (in Mutable state)
|
/// Create a new Worker (in Mutable state)
|
||||||
|
|
@ -941,43 +886,21 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register a tool
|
/// Register a tool factory for deferred initialization.
|
||||||
///
|
///
|
||||||
/// Registered tools are automatically executed when called by the LLM.
|
/// The factory is queued and executed at the next `run()` or `resume()` call.
|
||||||
/// Registering a tool with the same name will result in an error.
|
/// Duplicate name detection occurs at that point and surfaces as
|
||||||
///
|
/// [`WorkerError::ToolRegistry`].
|
||||||
/// Available only in Mutable state.
|
pub fn register_tool(&mut self, factory: WorkerToolDefinition) {
|
||||||
pub fn register_tool(
|
self.tool_server.register_tool(factory);
|
||||||
&mut self,
|
|
||||||
factory: WorkerToolDefinition,
|
|
||||||
) -> Result<(), ToolRegistryError> {
|
|
||||||
match self.tool_server.register_tool(factory) {
|
|
||||||
Ok(()) => Ok(()),
|
|
||||||
Err(ToolServerError::DuplicateName(name)) => {
|
|
||||||
Err(ToolRegistryError::DuplicateName(name))
|
|
||||||
}
|
|
||||||
Err(ToolServerError::ToolNotFound(_) | ToolServerError::ToolExecution(_)) => {
|
|
||||||
unreachable!("register_tool should only fail with DuplicateName")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register multiple tools
|
/// Register multiple tool factories for deferred initialization.
|
||||||
///
|
|
||||||
/// Available only in Mutable state.
|
|
||||||
pub fn register_tools(
|
pub fn register_tools(
|
||||||
&mut self,
|
&mut self,
|
||||||
factories: impl IntoIterator<Item = WorkerToolDefinition>,
|
factories: impl IntoIterator<Item = WorkerToolDefinition>,
|
||||||
) -> Result<(), ToolRegistryError> {
|
) {
|
||||||
match self.tool_server.register_tools(factories) {
|
self.tool_server.register_tools(factories);
|
||||||
Ok(()) => Ok(()),
|
|
||||||
Err(ToolServerError::DuplicateName(name)) => {
|
|
||||||
Err(ToolRegistryError::DuplicateName(name))
|
|
||||||
}
|
|
||||||
Err(ToolServerError::ToolNotFound(_) | ToolServerError::ToolExecution(_)) => {
|
|
||||||
unreachable!("register_tools should only fail with DuplicateName")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set system prompt (builder pattern)
|
/// Set system prompt (builder pattern)
|
||||||
|
|
@ -1082,40 +1005,47 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
|
|
||||||
/// Get a mutable reference to history
|
/// Get a mutable reference to history
|
||||||
///
|
///
|
||||||
/// Available only in Mutable state. History can be freely edited.
|
/// Available only in Mutable state.
|
||||||
pub fn history_mut(&mut self) -> &mut Vec<Item> {
|
pub fn history_mut(&mut self) -> &mut Vec<Item> {
|
||||||
|
|
||||||
&mut self.history
|
&mut self.history
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set history
|
/// Set history
|
||||||
pub fn set_history(&mut self, items: Vec<Item>) {
|
pub fn set_history(&mut self, items: Vec<Item>) {
|
||||||
|
|
||||||
self.history = items;
|
self.history = items;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add an item to history (builder pattern)
|
/// Add an item to history (builder pattern)
|
||||||
pub fn with_item(mut self, item: Item) -> Self {
|
pub fn with_item(mut self, item: Item) -> Self {
|
||||||
|
|
||||||
self.history.push(item);
|
self.history.push(item);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add an item to history
|
/// Add an item to history
|
||||||
pub fn push_item(&mut self, item: Item) {
|
pub fn push_item(&mut self, item: Item) {
|
||||||
|
|
||||||
self.history.push(item);
|
self.history.push(item);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add multiple items to history (builder pattern)
|
/// Add multiple items to history (builder pattern)
|
||||||
pub fn with_items(mut self, items: impl IntoIterator<Item = Item>) -> Self {
|
pub fn with_items(mut self, items: impl IntoIterator<Item = Item>) -> Self {
|
||||||
|
|
||||||
self.history.extend(items);
|
self.history.extend(items);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add multiple items to history
|
/// Add multiple items to history
|
||||||
pub fn extend_history(&mut self, items: impl IntoIterator<Item = Item>) {
|
pub fn extend_history(&mut self, items: impl IntoIterator<Item = Item>) {
|
||||||
|
|
||||||
self.history.extend(items);
|
self.history.extend(items);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Clear history
|
/// Clear history
|
||||||
pub fn clear_history(&mut self) {
|
pub fn clear_history(&mut self) {
|
||||||
|
|
||||||
self.history.clear();
|
self.history.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1148,11 +1078,48 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lock and transition to CacheLocked state
|
/// Execute a turn, consuming self and transitioning to Locked.
|
||||||
///
|
///
|
||||||
/// This operation fixes the current system prompt and history as a "committed prefix".
|
/// This is the primary entry point for first use. Equivalent to
|
||||||
/// After this, only appending to history is allowed, ensuring cache hits.
|
/// `self.lock()` followed by `locked.run(user_input)`.
|
||||||
pub fn lock(self) -> Worker<C, CacheLocked> {
|
///
|
||||||
|
/// Subsequent runs can use [`Worker<C, Locked>::run()`] directly.
|
||||||
|
/// To edit state between turns, call [`unlock()`](Worker::unlock) first.
|
||||||
|
pub async fn run(
|
||||||
|
self,
|
||||||
|
user_input: impl Into<String>,
|
||||||
|
) -> Result<(Worker<C, Locked>, WorkerResult), WorkerError> {
|
||||||
|
let mut locked = self.lock();
|
||||||
|
let result = locked.run(user_input).await?;
|
||||||
|
Ok((locked, result))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resume from Paused, consuming self and transitioning to Locked.
|
||||||
|
///
|
||||||
|
/// Used after `unlock()` → edit → resume.
|
||||||
|
pub async fn resume(
|
||||||
|
self,
|
||||||
|
) -> Result<(Worker<C, Locked>, WorkerResult), WorkerError> {
|
||||||
|
let mut locked = self.lock();
|
||||||
|
let result = locked.resume().await?;
|
||||||
|
Ok((locked, result))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Lock and transition to Locked state
|
||||||
|
///
|
||||||
|
/// Flushes pending tool factories, then fixes the current system prompt
|
||||||
|
/// and history as a "committed prefix". After this, only `run()` / `resume()`
|
||||||
|
/// may append to history, ensuring cache hits.
|
||||||
|
///
|
||||||
|
/// Most callers should use [`run()`](Self::run) instead, which calls
|
||||||
|
/// this internally. Use `lock()` directly only when you need the
|
||||||
|
/// `Locked` worker back on error (e.g. in a persistence layer).
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if a pending tool factory produces a duplicate name.
|
||||||
|
pub fn lock(self) -> Worker<C, Locked> {
|
||||||
|
self.tool_server.flush_pending();
|
||||||
let locked_prefix_len = self.history.len();
|
let locked_prefix_len = self.history.len();
|
||||||
Worker {
|
Worker {
|
||||||
client: self.client,
|
client: self.client,
|
||||||
|
|
@ -1178,11 +1145,42 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
|
||||||
// CacheLocked State-Specific Implementation
|
|
||||||
// =============================================================================
|
|
||||||
|
|
||||||
impl<C: LlmClient> Worker<C, CacheLocked> {
|
impl<C: LlmClient> Worker<C, Locked> {
|
||||||
|
/// Execute a turn
|
||||||
|
///
|
||||||
|
/// Adds a new user message to history and sends a request to the LLM.
|
||||||
|
/// Automatically loops if there are tool calls.
|
||||||
|
pub async fn run(
|
||||||
|
&mut self,
|
||||||
|
user_input: impl Into<String>,
|
||||||
|
) -> Result<WorkerResult, WorkerError> {
|
||||||
|
self.reset_interruption_state();
|
||||||
|
// Interceptor: on_prompt_submit
|
||||||
|
let mut user_item = Item::user_message(user_input);
|
||||||
|
match self.interceptor.on_prompt_submit(&mut user_item).await {
|
||||||
|
PromptAction::Cancel(reason) => {
|
||||||
|
self.last_run_interrupted = true;
|
||||||
|
return self
|
||||||
|
.finalize_interruption(Err(WorkerError::Aborted(reason)))
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
PromptAction::Continue => {}
|
||||||
|
}
|
||||||
|
self.history.push(user_item);
|
||||||
|
let result = self.run_turn_loop().await;
|
||||||
|
self.finalize_interruption(result).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resume execution (from Paused state)
|
||||||
|
///
|
||||||
|
/// Resumes turn processing from current state without adding a new user message.
|
||||||
|
pub async fn resume(&mut self) -> Result<WorkerResult, WorkerError> {
|
||||||
|
self.reset_interruption_state();
|
||||||
|
let result = self.run_turn_loop().await;
|
||||||
|
self.finalize_interruption(result).await
|
||||||
|
}
|
||||||
|
|
||||||
/// Get the prefix length at lock time
|
/// Get the prefix length at lock time
|
||||||
pub fn locked_prefix_len(&self) -> usize {
|
pub fn locked_prefix_len(&self) -> usize {
|
||||||
self.locked_prefix_len
|
self.locked_prefix_len
|
||||||
|
|
|
||||||
|
|
@ -46,8 +46,9 @@ async fn test_callback_text_block_events() {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Mutable::run consumes self, returns (Locked, WorkerResult)
|
||||||
let result = worker.run("Greet me").await;
|
let result = worker.run("Greet me").await;
|
||||||
assert!(result.is_ok(), "Worker should complete: {:?}", result);
|
assert!(result.is_ok(), "Worker should complete");
|
||||||
|
|
||||||
let deltas = text_deltas.lock().unwrap();
|
let deltas = text_deltas.lock().unwrap();
|
||||||
assert_eq!(deltas.len(), 2);
|
assert_eq!(deltas.len(), 2);
|
||||||
|
|
@ -91,6 +92,7 @@ async fn test_callback_tool_call_complete() {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Mutable::run consumes self, returns (Locked, WorkerResult)
|
||||||
let _ = worker.run("Weather please").await;
|
let _ = worker.run("Weather please").await;
|
||||||
|
|
||||||
let starts = tool_starts.lock().unwrap();
|
let starts = tool_starts.lock().unwrap();
|
||||||
|
|
@ -133,6 +135,7 @@ async fn test_callback_turn_events() {
|
||||||
ends.lock().unwrap().push(turn);
|
ends.lock().unwrap().push(turn);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Mutable::run consumes self, returns (Locked, WorkerResult)
|
||||||
let result = worker.run("Do something").await;
|
let result = worker.run("Do something").await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
|
@ -169,6 +172,7 @@ async fn test_callback_usage_events() {
|
||||||
usages.lock().unwrap().push(event.clone());
|
usages.lock().unwrap().push(event.clone());
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Mutable::run consumes self, returns (Locked, WorkerResult)
|
||||||
let _ = worker.run("Hello").await;
|
let _ = worker.run("Hello").await;
|
||||||
|
|
||||||
let usages = usage_events.lock().unwrap();
|
let usages = usage_events.lock().unwrap();
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
#[test]
|
#[test]
|
||||||
fn compile_fail_state_constraints() {
|
fn compile_fail_state_constraints() {
|
||||||
let t = trybuild::TestCases::new();
|
let t = trybuild::TestCases::new();
|
||||||
t.compile_fail("tests/ui/cache_locked_register_tool.rs");
|
t.compile_fail("tests/ui/locked_register_tool.rs");
|
||||||
t.compile_fail("tests/ui/tool_server_handle_register_tool.rs");
|
t.compile_fail("tests/ui/tool_server_handle_register_tool.rs");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -102,11 +102,12 @@ async fn test_parallel_tool_execution() {
|
||||||
let tool2_clone = tool2.clone();
|
let tool2_clone = tool2.clone();
|
||||||
let tool3_clone = tool3.clone();
|
let tool3_clone = tool3.clone();
|
||||||
|
|
||||||
worker.register_tool(tool1.definition()).unwrap();
|
worker.register_tool(tool1.definition());
|
||||||
worker.register_tool(tool2.definition()).unwrap();
|
worker.register_tool(tool2.definition());
|
||||||
worker.register_tool(tool3.definition()).unwrap();
|
worker.register_tool(tool3.definition());
|
||||||
|
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
// Mutable::run consumes self, returns (Locked, WorkerResult)
|
||||||
let _result = worker.run("Run all tools").await;
|
let _result = worker.run("Run all tools").await;
|
||||||
let elapsed = start.elapsed();
|
let elapsed = start.elapsed();
|
||||||
|
|
||||||
|
|
@ -150,8 +151,8 @@ async fn test_before_tool_call_skip() {
|
||||||
let allowed_clone = allowed_tool.clone();
|
let allowed_clone = allowed_tool.clone();
|
||||||
let blocked_clone = blocked_tool.clone();
|
let blocked_clone = blocked_tool.clone();
|
||||||
|
|
||||||
worker.register_tool(allowed_tool.definition()).unwrap();
|
worker.register_tool(allowed_tool.definition());
|
||||||
worker.register_tool(blocked_tool.definition()).unwrap();
|
worker.register_tool(blocked_tool.definition());
|
||||||
|
|
||||||
// Policy to skip "blocked_tool"
|
// Policy to skip "blocked_tool"
|
||||||
struct BlockingPolicy;
|
struct BlockingPolicy;
|
||||||
|
|
@ -169,6 +170,7 @@ async fn test_before_tool_call_skip() {
|
||||||
|
|
||||||
worker.set_interceptor(BlockingPolicy);
|
worker.set_interceptor(BlockingPolicy);
|
||||||
|
|
||||||
|
// Mutable::run consumes self, returns (Locked, WorkerResult)
|
||||||
let _result = worker.run("Test hook").await;
|
let _result = worker.run("Test hook").await;
|
||||||
|
|
||||||
// allowed_tool is called, but blocked_tool is not
|
// allowed_tool is called, but blocked_tool is not
|
||||||
|
|
@ -230,7 +232,7 @@ async fn test_post_tool_call_modification() {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
worker.register_tool(simple_tool_definition()).unwrap();
|
worker.register_tool(simple_tool_definition());
|
||||||
|
|
||||||
// Policy to modify results
|
// Policy to modify results
|
||||||
struct ModifyingPolicy {
|
struct ModifyingPolicy {
|
||||||
|
|
@ -251,9 +253,10 @@ async fn test_post_tool_call_modification() {
|
||||||
modified_content: modified_content.clone(),
|
modified_content: modified_content.clone(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Mutable::run consumes self, returns (Locked, WorkerResult)
|
||||||
let result = worker.run("Test modification").await;
|
let result = worker.run("Test modification").await;
|
||||||
|
|
||||||
assert!(result.is_ok(), "Worker should complete: {:?}", result);
|
assert!(result.is_ok(), "Worker should complete");
|
||||||
|
|
||||||
// Verify hook was called and content was modified
|
// Verify hook was called and content was modified
|
||||||
let content = modified_content.lock().unwrap().clone();
|
let content = modified_content.lock().unwrap().clone();
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
error[E0599]: no method named `register_tool` found for struct `Worker<OllamaClient, CacheLocked>` in the current scope
|
error[E0599]: no method named `register_tool` found for struct `Worker<OllamaClient, Locked>` in the current scope
|
||||||
--> tests/ui/cache_locked_register_tool.rs:10:20
|
--> tests/ui/locked_register_tool.rs:10:20
|
||||||
|
|
|
|
||||||
10 | let _ = locked.register_tool(def);
|
10 | let _ = locked.register_tool(def);
|
||||||
| ^^^^^^^^^^^^^ method not found in `Worker<OllamaClient, CacheLocked>`
|
| ^^^^^^^^^^^^^ method not found in `Worker<OllamaClient, Locked>`
|
||||||
|
|
|
|
||||||
= note: the method was found for
|
= note: the method was found for
|
||||||
- `Worker<C>`
|
- `Worker<C>`
|
||||||
|
|
@ -1,13 +1,10 @@
|
||||||
error[E0624]: method `register_tool` is private
|
error[E0624]: method `register_tool` is private
|
||||||
--> tests/ui/tool_server_handle_register_tool.rs:10:20
|
--> tests/ui/tool_server_handle_register_tool.rs:10:20
|
||||||
|
|
|
|
||||||
10 | let _ = handle.register_tool(def);
|
10 | let _ = handle.register_tool(def);
|
||||||
| ^^^^^^^^^^^^^ private method
|
| ^^^^^^^^^^^^^ private method
|
||||||
|
|
|
|
||||||
::: src/tool_server.rs
|
::: src/tool_server.rs
|
||||||
|
|
|
|
||||||
| / pub(crate) fn register_tool(
|
| pub(crate) fn register_tool(&self, factory: WorkerToolDefinition) {
|
||||||
| | &self,
|
| ----------------------------------------------------------------- private method defined here
|
||||||
| | factory: WorkerToolDefinition,
|
|
||||||
| | ) -> Result<(), ToolServerError> {
|
|
||||||
| |____________________________________- private method defined here
|
|
||||||
|
|
|
||||||
|
|
@ -129,9 +129,9 @@ async fn test_worker_simple_text_response() {
|
||||||
}
|
}
|
||||||
|
|
||||||
let client = MockLlmClient::from_fixture(&fixture_path).unwrap();
|
let client = MockLlmClient::from_fixture(&fixture_path).unwrap();
|
||||||
let mut worker = Worker::new(client);
|
let worker = Worker::new(client);
|
||||||
|
|
||||||
// Send a simple message
|
// Send a simple message (Mutable::run consumes self, returns tuple)
|
||||||
let result = worker.run("Hello").await;
|
let result = worker.run("Hello").await;
|
||||||
|
|
||||||
assert!(result.is_ok(), "Worker should complete successfully");
|
assert!(result.is_ok(), "Worker should complete successfully");
|
||||||
|
|
@ -156,9 +156,9 @@ async fn test_worker_tool_call() {
|
||||||
// Register tool
|
// Register tool
|
||||||
let weather_tool = MockWeatherTool::new();
|
let weather_tool = MockWeatherTool::new();
|
||||||
let tool_for_check = weather_tool.clone();
|
let tool_for_check = weather_tool.clone();
|
||||||
worker.register_tool(weather_tool.definition()).unwrap();
|
worker.register_tool(weather_tool.definition());
|
||||||
|
|
||||||
// Send message
|
// Send message (Mutable::run consumes self, returns tuple)
|
||||||
let _result = worker.run("What's the weather in Tokyo?").await;
|
let _result = worker.run("What's the weather in Tokyo?").await;
|
||||||
|
|
||||||
// Verify tool was called
|
// Verify tool was called
|
||||||
|
|
@ -190,8 +190,9 @@ async fn test_worker_with_programmatic_events() {
|
||||||
];
|
];
|
||||||
|
|
||||||
let client = MockLlmClient::new(events);
|
let client = MockLlmClient::new(events);
|
||||||
let mut worker = Worker::new(client);
|
let worker = Worker::new(client);
|
||||||
|
|
||||||
|
// Mutable::run consumes self, returns tuple
|
||||||
let result = worker.run("Greet me").await;
|
let result = worker.run("Greet me").await;
|
||||||
|
|
||||||
assert!(result.is_ok(), "Worker should complete successfully");
|
assert!(result.is_ok(), "Worker should complete successfully");
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
//! Worker state management tests
|
//! Worker state management tests
|
||||||
//!
|
//!
|
||||||
//! Tests for state transitions using the Type-state pattern (Mutable/CacheLocked)
|
//! Tests for state transitions using the Type-state pattern (Mutable/Locked)
|
||||||
//! and state preservation between turns.
|
//! and state preservation between turns.
|
||||||
|
|
||||||
mod common;
|
mod common;
|
||||||
|
|
@ -11,7 +11,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use common::MockLlmClient;
|
use common::MockLlmClient;
|
||||||
use llm_worker::Item;
|
use llm_worker::Item;
|
||||||
use llm_worker::Worker;
|
use llm_worker::{Worker, WorkerError};
|
||||||
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
|
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
|
||||||
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
||||||
|
|
||||||
|
|
@ -147,15 +147,15 @@ fn test_mutable_can_register_tool() {
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
let tool = CountingTool::new("count_tool");
|
let tool = CountingTool::new("count_tool");
|
||||||
|
|
||||||
let result = worker.register_tool(tool.definition());
|
// register_tool is infallible (factory deferred to run-time flush)
|
||||||
assert!(result.is_ok(), "Mutable should allow tool registration");
|
worker.register_tool(tool.definition());
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// State Transition Tests
|
// State Transition Tests
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// Verify that lock() transitions from Mutable -> CacheLocked state
|
/// Verify that lock() transitions from Mutable -> Locked state
|
||||||
#[test]
|
#[test]
|
||||||
fn test_lock_transition() {
|
fn test_lock_transition() {
|
||||||
let client = MockLlmClient::new(vec![]);
|
let client = MockLlmClient::new(vec![]);
|
||||||
|
|
@ -168,13 +168,13 @@ fn test_lock_transition() {
|
||||||
// Lock
|
// Lock
|
||||||
let locked_worker = worker.lock();
|
let locked_worker = worker.lock();
|
||||||
|
|
||||||
// History and system prompt are still accessible in CacheLocked state
|
// History and system prompt are still accessible in Locked state
|
||||||
assert_eq!(locked_worker.get_system_prompt(), Some("System"));
|
assert_eq!(locked_worker.get_system_prompt(), Some("System"));
|
||||||
assert_eq!(locked_worker.history().len(), 2);
|
assert_eq!(locked_worker.history().len(), 2);
|
||||||
assert_eq!(locked_worker.locked_prefix_len(), 2);
|
assert_eq!(locked_worker.locked_prefix_len(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verify that unlock() transitions from CacheLocked -> Mutable state
|
/// Verify that unlock() transitions from Locked -> Mutable state
|
||||||
#[test]
|
#[test]
|
||||||
fn test_unlock_transition() {
|
fn test_unlock_transition() {
|
||||||
let client = MockLlmClient::new(vec![]);
|
let client = MockLlmClient::new(vec![]);
|
||||||
|
|
@ -198,7 +198,7 @@ fn test_unlock_transition() {
|
||||||
|
|
||||||
/// Verify that history is correctly updated after running a turn in Mutable state
|
/// Verify that history is correctly updated after running a turn in Mutable state
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_mutable_run_updates_history() {
|
async fn test_mutable_run_updates_history() -> Result<(), WorkerError> {
|
||||||
let events = vec![
|
let events = vec![
|
||||||
Event::text_block_start(0),
|
Event::text_block_start(0),
|
||||||
Event::text_delta(0, "Hello, I'm an assistant!"),
|
Event::text_delta(0, "Hello, I'm an assistant!"),
|
||||||
|
|
@ -209,11 +209,10 @@ async fn test_mutable_run_updates_history() {
|
||||||
];
|
];
|
||||||
|
|
||||||
let client = MockLlmClient::new(events);
|
let client = MockLlmClient::new(events);
|
||||||
let mut worker = Worker::new(client);
|
let worker = Worker::new(client);
|
||||||
|
|
||||||
// Execute
|
// Execute (Mutable::run consumes self, returns (Locked, WorkerResult))
|
||||||
let result = worker.run("Hi there").await;
|
let (worker, _result) = worker.run("Hi there").await?;
|
||||||
assert!(result.is_ok());
|
|
||||||
|
|
||||||
// History is updated
|
// History is updated
|
||||||
let history = worker.history();
|
let history = worker.history();
|
||||||
|
|
@ -224,9 +223,11 @@ async fn test_mutable_run_updates_history() {
|
||||||
|
|
||||||
// Assistant message
|
// Assistant message
|
||||||
assert_eq!(history[1].as_text(), Some("Hello, I'm an assistant!"));
|
assert_eq!(history[1].as_text(), Some("Hello, I'm an assistant!"));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verify that history accumulates correctly over multiple turns in CacheLocked state
|
/// Verify that history accumulates correctly over multiple turns in Locked state
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_locked_multi_turn_history_accumulation() {
|
async fn test_locked_multi_turn_history_accumulation() {
|
||||||
// Prepare responses for 2 requests
|
// Prepare responses for 2 requests
|
||||||
|
|
@ -327,7 +328,7 @@ async fn test_locked_prefix_len_tracking() {
|
||||||
|
|
||||||
/// Verify that turn count is correctly incremented
|
/// Verify that turn count is correctly incremented
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_turn_count_increment() {
|
async fn test_turn_count_increment() -> Result<(), WorkerError> {
|
||||||
let client = MockLlmClient::with_responses(vec![
|
let client = MockLlmClient::with_responses(vec![
|
||||||
vec![
|
vec![
|
||||||
Event::text_block_start(0),
|
Event::text_block_start(0),
|
||||||
|
|
@ -347,15 +348,19 @@ async fn test_turn_count_increment() {
|
||||||
],
|
],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let mut worker = Worker::new(client);
|
let worker = Worker::new(client);
|
||||||
|
|
||||||
assert_eq!(worker.turn_count(), 0);
|
assert_eq!(worker.turn_count(), 0);
|
||||||
|
|
||||||
worker.run("First").await.unwrap();
|
// First run consumes Mutable, returns Locked
|
||||||
|
let (mut worker, _) = worker.run("First").await?;
|
||||||
assert_eq!(worker.turn_count(), 1);
|
assert_eq!(worker.turn_count(), 1);
|
||||||
|
|
||||||
worker.run("Second").await.unwrap();
|
// Subsequent runs on Locked take &mut self
|
||||||
|
worker.run("Second").await?;
|
||||||
assert_eq!(worker.turn_count(), 2);
|
assert_eq!(worker.turn_count(), 2);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verify that history can be edited after unlock and re-locked
|
/// Verify that history can be edited after unlock and re-locked
|
||||||
|
|
@ -430,9 +435,7 @@ async fn test_lock_unlock_relock_tools_remain_effective() {
|
||||||
|
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
let tool_a = CountingTool::new("tool_a");
|
let tool_a = CountingTool::new("tool_a");
|
||||||
worker
|
worker.register_tool(tool_a.definition());
|
||||||
.register_tool(tool_a.definition())
|
|
||||||
.expect("register tool_a should succeed");
|
|
||||||
|
|
||||||
let mut locked = worker.lock();
|
let mut locked = worker.lock();
|
||||||
locked.run("first").await.expect("first run");
|
locked.run("first").await.expect("first run");
|
||||||
|
|
@ -440,9 +443,7 @@ async fn test_lock_unlock_relock_tools_remain_effective() {
|
||||||
|
|
||||||
let mut unlocked = locked.unlock();
|
let mut unlocked = locked.unlock();
|
||||||
let tool_b = CountingTool::new("tool_b");
|
let tool_b = CountingTool::new("tool_b");
|
||||||
unlocked
|
unlocked.register_tool(tool_b.definition());
|
||||||
.register_tool(tool_b.definition())
|
|
||||||
.expect("register tool_b after unlock should succeed");
|
|
||||||
|
|
||||||
let mut relocked = unlocked.lock();
|
let mut relocked = unlocked.lock();
|
||||||
relocked.run("second").await.expect("second run");
|
relocked.run("second").await.expect("second run");
|
||||||
|
|
@ -455,7 +456,7 @@ async fn test_lock_unlock_relock_tools_remain_effective() {
|
||||||
// System Prompt Preservation Tests
|
// System Prompt Preservation Tests
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// Verify that system prompt is preserved in CacheLocked state
|
/// Verify that system prompt is preserved in Locked state
|
||||||
#[test]
|
#[test]
|
||||||
fn test_system_prompt_preserved_in_locked_state() {
|
fn test_system_prompt_preserved_in_locked_state() {
|
||||||
let client = MockLlmClient::new(vec![]);
|
let client = MockLlmClient::new(vec![]);
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. Extract the assistant's reply from history
|
// 5. Extract the assistant's reply from history
|
||||||
let history = pod.session_mut().worker.history();
|
let history = pod.session_mut().worker().history();
|
||||||
if let Some(text) = history
|
if let Some(text) = history
|
||||||
.iter()
|
.iter()
|
||||||
.rev()
|
.rev()
|
||||||
|
|
|
||||||
|
|
@ -85,7 +85,7 @@ impl PodController {
|
||||||
|
|
||||||
// Register event bridge callbacks on the worker
|
// Register event bridge callbacks on the worker
|
||||||
{
|
{
|
||||||
let worker = &mut pod.session_mut().worker;
|
let worker = pod.session_mut().worker_mut();
|
||||||
|
|
||||||
let tx = event_tx.clone();
|
let tx = event_tx.clone();
|
||||||
worker.on_turn_start(move |turn| {
|
worker.on_turn_start(move |turn| {
|
||||||
|
|
@ -158,7 +158,7 @@ impl PodController {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clone cancel sender before moving pod
|
// Clone cancel sender before moving pod
|
||||||
let cancel_tx = pod.session_mut().worker.cancel_sender();
|
let cancel_tx = pod.session_mut().worker_mut().cancel_sender();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
// Hold socket server alive for the lifetime of the controller task
|
// Hold socket server alive for the lifetime of the controller task
|
||||||
|
|
@ -191,7 +191,7 @@ impl PodController {
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let items = pod.session_mut().worker.history().to_vec();
|
let items = pod.session_mut().worker_mut().history().to_vec();
|
||||||
shared_state.update_history(items);
|
shared_state.update_history(items);
|
||||||
shared_state.set_status(new_status);
|
shared_state.set_status(new_status);
|
||||||
let _ = runtime_dir.write_status(&shared_state).await;
|
let _ = runtime_dir.write_status(&shared_state).await;
|
||||||
|
|
@ -218,7 +218,7 @@ impl PodController {
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let items = pod.session_mut().worker.history().to_vec();
|
let items = pod.session_mut().worker_mut().history().to_vec();
|
||||||
shared_state.update_history(items);
|
shared_state.update_history(items);
|
||||||
shared_state.set_status(new_status);
|
shared_state.set_status(new_status);
|
||||||
let _ = runtime_dir.write_status(&shared_state).await;
|
let _ = runtime_dir.write_status(&shared_state).await;
|
||||||
|
|
|
||||||
|
|
@ -144,7 +144,7 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
||||||
let builder = std::mem::take(&mut self.hook_builder);
|
let builder = std::mem::take(&mut self.hook_builder);
|
||||||
let registry = Arc::new(builder.build());
|
let registry = Arc::new(builder.build());
|
||||||
let interceptor = HookInterceptor::new(registry);
|
let interceptor = HookInterceptor::new(registry);
|
||||||
self.session.worker.set_interceptor(interceptor);
|
self.session.worker_mut().set_interceptor(interceptor);
|
||||||
self.interceptor_installed = true;
|
self.interceptor_installed = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,57 +0,0 @@
|
||||||
# Worker: run() 時の自動キャッシュロックと ToolDefinition ファクトリ遅延初期化
|
|
||||||
|
|
||||||
## 背景
|
|
||||||
|
|
||||||
現状の `Worker` は Type-state パターンで `Mutable` / `CacheLocked` の2状態を持つが、
|
|
||||||
`lock()` を呼ばなくても `run()` できてしまうため、キャッシュ保護の存在を知らないユーザーは
|
|
||||||
常に非最適パスを通ることになる。
|
|
||||||
|
|
||||||
また `register_tool()` 時に `ToolDefinition` のファクトリクロージャが即時呼び出しされており、
|
|
||||||
本来の意図である遅延初期化になっていない。
|
|
||||||
|
|
||||||
## 方針
|
|
||||||
|
|
||||||
### run() 時の自動ロック
|
|
||||||
|
|
||||||
`run()` の冒頭で、`Mutable` 状態なら自動的に `CacheLocked` へ遷移する。
|
|
||||||
これにより lock を知らないユーザーでも嫌でもキャッシュ保護される。
|
|
||||||
|
|
||||||
ターンの合間に history や system prompt を編集したい場合は、明示的に `unlock()` を挟む。
|
|
||||||
次の `run()` で再び自動ロックされる。
|
|
||||||
|
|
||||||
```rust
|
|
||||||
let mut worker = Worker::new(client);
|
|
||||||
worker.set_system_prompt("...");
|
|
||||||
worker.register_tool(my_tool)?;
|
|
||||||
|
|
||||||
// Mutable のまま run() → 自動で lock される
|
|
||||||
worker.run("Hello").await?;
|
|
||||||
|
|
||||||
// ターン間で内容を弄りたい場合
|
|
||||||
worker.unlock();
|
|
||||||
worker.history_mut().truncate(5);
|
|
||||||
// 次の run() で再 lock
|
|
||||||
worker.run("Continue").await?;
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 設計ポイント
|
|
||||||
|
|
||||||
- `run()` が `&mut self` を取る以上、内部で状態遷移しても外部の型は変わらない。
|
|
||||||
実装は内部フラグ(`is_locked: bool`)で管理し、`Mutable` / `CacheLocked` の
|
|
||||||
type-state はそのまま維持する。`Worker<C, Mutable>` の `run()` が内部で lock 相当の
|
|
||||||
処理を行い、`unlock()` が呼ばれるまでキャッシュ破壊的な操作を(ランタイムで)ブロックする
|
|
||||||
- `Worker<C, CacheLocked>` は従来どおり。明示的に `lock()` してから使うパスも残る
|
|
||||||
- interceptor, max_turns, callbacks 等キャッシュに影響しない設定は lock 状態でも自由に変更可能
|
|
||||||
|
|
||||||
### ToolDefinition ファクトリの遅延初期化
|
|
||||||
|
|
||||||
`register_tool()` は定義を蓄積するだけにし、ファクトリの実行を初回 `run()` まで遅延させる。
|
|
||||||
|
|
||||||
- `register_tool()` → `ToolDefinition` を Vec に push するだけ
|
|
||||||
- 初回 `run()` の自動 lock 時にファクトリを一括実行し、ToolServer を構築
|
|
||||||
- `unlock()` 後に追加登録された tool は次の `run()` で初期化
|
|
||||||
|
|
||||||
## 移行
|
|
||||||
|
|
||||||
- `lock()` / `unlock()` は引き続き使える(明示的なキャッシュ管理用)
|
|
||||||
- `Worker::new()` → `run()` のパスが自動保護されるため、既存コードは変更不要
|
|
||||||
Loading…
Reference in New Issue
Block a user