// Source path: yoi/crates/llm-worker/src/worker.rs
// 2026-04-11 03:16:36 +09:00
//
// 1442 lines
// 49 KiB
// Rust

use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};
use futures::StreamExt;
use tokio::sync::mpsc;
use tracing::{debug, info, trace, warn};
use crate::{
Item,
hook::{
Hook, HookError, HookRegistry, OnAbort, OnPromptSubmit, OnPromptSubmitResult,
OnStreamChunk, OnStreamComplete, OnTextDelta, OnToolCallDelta, OnTurnEnd, OnTurnEndResult,
PostToolCall, PostToolCallContext, PostToolCallResult, PreLlmRequest, PreLlmRequestResult,
PreToolCall, PreToolCallResult, StreamChunkContext, StreamCompleteContext,
StreamHookResult, TextDeltaContext, ToolCall, ToolCallContext, ToolCallDeltaContext,
ToolResult,
},
llm_client::{ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ToolDefinition},
state::{CacheLocked, Mutable, WorkerState},
subscriber::{
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, WorkerSubscriber,
},
timeline::{TextBlockCollector, Timeline, ToolCallCollector},
tool::{ToolDefinition as WorkerToolDefinition, ToolError, ToolOutputProcessor},
tool_server::{ToolServer, ToolServerError, ToolServerHandle},
};
// =============================================================================
// Worker Error
// =============================================================================
/// Errors that can end a worker run.
///
/// `Client`, `Tool` and `Hook` wrap the underlying error and are built from it
/// through `#[from]`, so `?` converts those errors automatically.
#[derive(Debug, thiserror::Error)]
pub enum WorkerError {
/// Error reported by the LLM client
#[error("Client error: {0}")]
Client(#[from] ClientError),
/// Error reported by a tool
#[error("Tool error: {0}")]
Tool(#[from] ToolError),
/// Error reported by a hook
#[error("Hook error: {0}")]
Hook(#[from] HookError),
/// Execution was aborted; the payload is the reason text
#[error("Aborted: {0}")]
Aborted(String),
/// Execution was cancelled through the worker's cancel channel
#[error("Cancelled")]
Cancelled,
/// Config warnings (unsupported options); displayed as a ", "-joined list
#[error("Config warnings: {}", .0.iter().map(|w| w.to_string()).collect::<Vec<_>>().join(", "))]
ConfigWarnings(Vec<ConfigWarning>),
}
/// Error returned when registering tools on a worker fails.
#[derive(Debug, thiserror::Error)]
pub enum ToolRegistryError {
/// A tool with the same name is already registered; the payload is that name
#[error("Tool with name '{0}' already registered")]
DuplicateName(String),
}
// =============================================================================
// Worker Config
// =============================================================================
/// Worker configuration
///
/// Currently carries no settings. The private unit field keeps the struct from
/// being built with a struct literal outside this module, so construction goes
/// through `Default`.
#[derive(Debug, Clone, Default)]
pub struct WorkerConfig {
// Reserved for future extensions (currently empty)
_private: (),
}
// =============================================================================
// Worker Result Types
// =============================================================================
/// Worker execution result (status)
///
/// Returned by a successful `run()` / `resume()`. The enum carries no data, so
/// it is cheap to copy and can be compared directly (for example in tests:
/// `assert_eq!(result, WorkerResult::Finished)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkerResult {
    /// Completed (waiting for user input)
    Finished,
    /// Paused (can be resumed)
    Paused,
    /// Turn limit reached (the turn count reached `max_turns`)
    LimitReached,
}
/// Internal: tool execution result
///
/// Produced by `execute_tools` and consumed by the turn loop.
enum ToolExecutionResult {
// All approved tools ran; one result per executed call.
Completed(Vec<ToolResult>),
// A pre_tool_call hook requested a pause before any tool was executed.
Paused,
}
// =============================================================================
// Turn Control Callback Storage
// =============================================================================
/// Callback for notifying turn events (type-erased)
///
/// Lets the worker keep subscribers of different concrete types in one
/// `Vec<Box<dyn TurnNotifier>>`.
trait TurnNotifier: Send + Sync {
// Called with the current turn number before the LLM request of that turn.
fn on_turn_start(&self, turn: usize);
// Called with the same turn number once the turn is over.
fn on_turn_end(&self, turn: usize);
}
/// Adapter that forwards turn notifications to a shared [`WorkerSubscriber`].
struct SubscriberTurnNotifier<S: WorkerSubscriber + 'static> {
// Shared with the timeline adapters registered in `Worker::subscribe`.
subscriber: Arc<Mutex<S>>,
}
/// Forwards turn start/end notifications to the wrapped subscriber.
///
/// If the mutex is poisoned the notification is silently dropped.
impl<S: WorkerSubscriber + 'static> TurnNotifier for SubscriberTurnNotifier<S> {
    fn on_turn_start(&self, turn: usize) {
        let Ok(mut guard) = self.subscriber.lock() else {
            return;
        };
        guard.on_turn_start(turn);
    }

    fn on_turn_end(&self, turn: usize) {
        let Ok(mut guard) = self.subscriber.lock() else {
            return;
        };
        guard.on_turn_end(turn);
    }
}
// =============================================================================
// Worker
// =============================================================================
/// Central component for managing LLM interactions
///
/// Receives input from the user, sends requests to the LLM, and
/// automatically executes tool calls if any, advancing the turn.
///
/// # State Transitions (Type-state)
///
/// - [`Mutable`]: Initial state. System prompt and history can be freely edited.
/// - [`CacheLocked`]: Cache-protected state. Transition via `lock()`. Prefix context is immutable.
///
/// # Examples
///
/// ```ignore
/// use llm_worker::{Worker, WorkerResult};
///
/// // Create a Worker and register tools
/// let mut worker = Worker::new(client)
/// .system_prompt("You are a helpful assistant.");
/// worker.register_tool(my_tool)?;
///
/// // Run the interaction; `run` returns a status, the conversation itself
/// // is available through `history()`.
/// let result: WorkerResult = worker.run("Hello!").await?;
/// let history = worker.history();
/// ```
///
/// # When Cache Protection is Needed
///
/// ```ignore
/// let mut worker = Worker::new(client)
/// .system_prompt("...");
///
/// // After setting history, lock to protect cache
/// let mut locked = worker.lock();
/// locked.run("user input").await?;
/// ```
pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
/// LLM client
client: C,
/// Event timeline; every stream event is dispatched through it
timeline: Timeline,
/// Text block collector (Timeline handler)
text_block_collector: TextBlockCollector,
/// Tool call collector (Timeline handler)
tool_call_collector: ToolCallCollector,
/// Tool server handle
tool_server: ToolServerHandle,
/// Hook registry
hooks: HookRegistry,
/// System prompt
system_prompt: Option<String>,
/// Item history (owned by Worker)
history: Vec<Item>,
/// History length at lock time (only meaningful in CacheLocked state)
locked_prefix_len: usize,
/// Turn count; incremented once per completed LLM stream
turn_count: usize,
/// Maximum number of turns (None = unlimited)
max_turns: Option<u32>,
/// Turn notification callbacks
turn_notifiers: Vec<Box<dyn TurnNotifier>>,
/// Request configuration (max_tokens, temperature, etc.)
request_config: RequestConfig,
/// Whether the previous run was interrupted
last_run_interrupted: bool,
/// Optional processor for large tool outputs (stores externally, returns summary)
output_processor: Option<Arc<dyn ToolOutputProcessor>>,
/// Cancel notification channel (for interrupting execution)
cancel_tx: mpsc::Sender<()>,
/// Receiving half of the cancel channel, polled by the turn loop and tool execution
cancel_rx: mpsc::Receiver<()>,
/// State marker
_state: PhantomData<S>,
}
// =============================================================================
// Common Implementation (available in all states)
// =============================================================================
impl<C: LlmClient, S: WorkerState> Worker<C, S> {
/// Clear the "previous run was interrupted" marker before a new run starts.
fn reset_interruption_state(&mut self) {
    let interrupted = &mut self.last_run_interrupted;
    *interrupted = false;
}
/// Run the worker on a new user message.
///
/// The input is wrapped into a user item and handed to the `on_prompt_submit`
/// hooks first. If a hook fails or cancels, the item is NOT added to the
/// history and the error is routed through the abort hooks. Otherwise the item
/// is appended to the history and the turn loop runs (looping automatically
/// while the LLM keeps requesting tools).
///
/// # Errors
///
/// Returns `WorkerError::Aborted` when a prompt hook cancels, or whatever error
/// the hooks / turn loop produce. Any error marks the run as interrupted.
pub async fn run(
    &mut self,
    user_input: impl Into<String>,
) -> Result<WorkerResult, WorkerError> {
    self.reset_interruption_state();

    // Hook: on_prompt_submit
    let mut prompt_item = Item::user_message(user_input);
    let hook_outcome = self.run_on_prompt_submit_hooks(&mut prompt_item).await;
    let decision = match hook_outcome {
        Err(hook_err) => return self.finalize_interruption(Err(hook_err)).await,
        Ok(decision) => decision,
    };

    if let OnPromptSubmitResult::Cancel(reason) = decision {
        self.last_run_interrupted = true;
        let aborted = Err(WorkerError::Aborted(reason));
        return self.finalize_interruption(aborted).await;
    }

    self.history.push(prompt_item);
    let outcome = self.run_turn_loop().await;
    self.finalize_interruption(outcome).await
}
fn drain_cancel_queue(&mut self) {
use tokio::sync::mpsc::error::TryRecvError;
loop {
match self.cancel_rx.try_recv() {
Ok(()) => continue,
Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
}
}
}
/// Non-blocking cancellation check.
///
/// Takes at most one signal off the cancel channel. Returns `true` when a
/// signal was received or the channel is disconnected, `false` only when the
/// channel is empty.
fn try_cancelled(&mut self) -> bool {
    use tokio::sync::mpsc::error::TryRecvError;
    !matches!(self.cancel_rx.try_recv(), Err(TryRecvError::Empty))
}
/// Attach a subscriber to this worker.
///
/// The subscriber is wrapped in an `Arc<Mutex<_>>` and shared between five
/// timeline adapters (text block, tool-use block, usage, status, error) and a
/// turn notifier, so it sees streaming events as they are dispatched.
///
/// # Available Events
///
/// - **Block events**: `on_text_block`, `on_tool_use_block`
/// - **Meta events**: `on_usage`, `on_status`, `on_error`
/// - **Completion events**: `on_text_complete`, `on_tool_call_complete`
/// - **Turn control**: `on_turn_start`, `on_turn_end`
///
/// # Examples
///
/// ```ignore
/// use llm_worker::{Worker, WorkerSubscriber, TextBlockEvent};
///
/// struct MyPrinter;
/// impl WorkerSubscriber for MyPrinter {
///     type TextBlockScope = ();
///     type ToolUseBlockScope = ();
///
///     fn on_text_block(&mut self, _: &mut (), event: &TextBlockEvent) {
///         if let TextBlockEvent::Delta(text) = event {
///             print!("{}", text);
///         }
///     }
/// }
///
/// worker.subscribe(MyPrinter);
/// ```
pub fn subscribe<Sub: WorkerSubscriber + 'static>(&mut self, subscriber: Sub) {
    let shared = Arc::new(Mutex::new(subscriber));
    let timeline = &mut self.timeline;

    // Block handlers
    timeline.on_text_block(TextBlockSubscriberAdapter::new(Arc::clone(&shared)));
    timeline.on_tool_use_block(ToolUseBlockSubscriberAdapter::new(Arc::clone(&shared)));

    // Meta handlers
    timeline.on_usage(UsageSubscriberAdapter::new(Arc::clone(&shared)));
    timeline.on_status(StatusSubscriberAdapter::new(Arc::clone(&shared)));
    timeline.on_error(ErrorSubscriberAdapter::new(Arc::clone(&shared)));

    // Turn control callback takes the last handle
    let notifier = SubscriberTurnNotifier { subscriber: shared };
    self.turn_notifiers.push(Box::new(notifier));
}
/// Return a clone of the worker's tool server handle.
pub fn tool_server_handle(&self) -> ToolServerHandle {
    let handle = &self.tool_server;
    handle.clone()
}
/// Register an `on_prompt_submit` hook.
///
/// These hooks are run by `run()` on the new user item before it is appended
/// to the history.
pub fn add_on_prompt_submit_hook(&mut self, hook: impl Hook<OnPromptSubmit> + 'static) {
    let boxed = Box::new(hook);
    self.hooks.on_prompt_submit.push(boxed);
}
/// Register a `pre_llm_request` hook.
///
/// These hooks run at the start of every turn, on a copy of the history,
/// before the request is built and sent to the LLM.
pub fn add_pre_llm_request_hook(&mut self, hook: impl Hook<PreLlmRequest> + 'static) {
    let boxed = Box::new(hook);
    self.hooks.pre_llm_request.push(boxed);
}
/// Register a `pre_tool_call` hook.
///
/// These hooks run for each requested call to a registered tool before it is
/// executed; calls to unknown tools bypass them.
pub fn add_pre_tool_call_hook(&mut self, hook: impl Hook<PreToolCall> + 'static) {
    let boxed = Box::new(hook);
    self.hooks.pre_tool_call.push(boxed);
}
/// Register a `post_tool_call` hook.
///
/// These hooks run on the result of each executed, registered tool and may
/// modify that result.
pub fn add_post_tool_call_hook(&mut self, hook: impl Hook<PostToolCall> + 'static) {
    let boxed = Box::new(hook);
    self.hooks.post_tool_call.push(boxed);
}
/// Register an `on_turn_end` hook.
///
/// These hooks run when an LLM response contains no tool calls and decide
/// whether the run finishes, continues with extra messages, or pauses.
pub fn add_on_turn_end_hook(&mut self, hook: impl Hook<OnTurnEnd> + 'static) {
    let boxed = Box::new(hook);
    self.hooks.on_turn_end.push(boxed);
}
/// Register an `on_abort` hook.
///
/// These hooks receive the reason text whenever `run()` / `resume()` end with
/// an error.
pub fn add_on_abort_hook(&mut self, hook: impl Hook<OnAbort> + 'static) {
    let boxed = Box::new(hook);
    self.hooks.on_abort.push(boxed);
}
/// Register an `on_text_delta` hook.
///
/// These hooks run for every text delta received from the stream.
pub fn add_on_text_delta_hook(&mut self, hook: impl Hook<OnTextDelta> + 'static) {
    let boxed = Box::new(hook);
    self.hooks.on_text_delta.push(boxed);
}
/// Register an `on_tool_call_delta` hook.
///
/// These hooks run for every tool-input JSON fragment received from the stream.
pub fn add_on_tool_call_delta_hook(&mut self, hook: impl Hook<OnToolCallDelta> + 'static) {
    let boxed = Box::new(hook);
    self.hooks.on_tool_call_delta.push(boxed);
}
/// Register an `on_stream_chunk` hook.
///
/// These hooks run for every stream event, after it was dispatched to the
/// timeline.
pub fn add_on_stream_chunk_hook(&mut self, hook: impl Hook<OnStreamChunk> + 'static) {
    let boxed = Box::new(hook);
    self.hooks.on_stream_chunk.push(boxed);
}
/// Register an `on_stream_complete` hook.
///
/// These hooks run once per turn after the stream has ended, with the turn
/// number and the number of events received.
pub fn add_on_stream_complete_hook(&mut self, hook: impl Hook<OnStreamComplete> + 'static) {
    let boxed = Box::new(hook);
    self.hooks.on_stream_complete.push(boxed);
}
/// Mutable access to the timeline, e.g. to register additional handlers.
pub fn timeline_mut(&mut self) -> &mut Timeline {
    let timeline = &mut self.timeline;
    timeline
}
/// The conversation items accumulated so far, oldest first.
pub fn history(&self) -> &[Item] {
    self.history.as_slice()
}
/// The system prompt, or `None` when none has been set.
pub fn get_system_prompt(&self) -> Option<&str> {
    let prompt = self.system_prompt.as_ref()?;
    Some(prompt.as_str())
}
/// Get the current turn count
pub fn turn_count(&self) -> usize {
self.turn_count
}
/// The request configuration that is applied to every LLM request.
pub fn request_config(&self) -> &RequestConfig {
    let config = &self.request_config;
    config
}
/// Set the maximum number of tokens for subsequent requests.
///
/// Available in every worker state; the value is stored in the request
/// configuration and applied to each request built afterwards.
///
/// # Examples
///
/// ```ignore
/// worker.set_max_tokens(4096);
/// ```
pub fn set_max_tokens(&mut self, max_tokens: u32) {
    let config = &mut self.request_config;
    config.max_tokens = Some(max_tokens);
}
/// Set the sampling temperature for subsequent requests.
///
/// Lower values produce more deterministic output, higher values more diverse
/// output. The value is stored as-is and not range-checked here; the accepted
/// range depends on the provider (commonly 0.0 to 1.0, or up to 2.0).
///
/// # Examples
///
/// ```ignore
/// worker.set_temperature(0.7);
/// ```
pub fn set_temperature(&mut self, temperature: f32) {
    let config = &mut self.request_config;
    config.temperature = Some(temperature);
}
/// Set `top_p` (nucleus sampling) for subsequent requests.
///
/// # Examples
///
/// ```ignore
/// worker.set_top_p(0.9);
/// ```
pub fn set_top_p(&mut self, top_p: f32) {
    let config = &mut self.request_config;
    config.top_p = Some(top_p);
}
/// Set `top_k` for subsequent requests.
///
/// Limits token selection to the `top_k` most likely candidates.
///
/// # Examples
///
/// ```ignore
/// worker.set_top_k(40);
/// ```
pub fn set_top_k(&mut self, top_k: u32) {
    let config = &mut self.request_config;
    config.top_k = Some(top_k);
}
/// Append one stop sequence to the request configuration.
///
/// # Examples
///
/// ```ignore
/// worker.add_stop_sequence("\n\n");
/// ```
pub fn add_stop_sequence(&mut self, sequence: impl Into<String>) {
    let owned = sequence.into();
    self.request_config.stop_sequences.push(owned);
}
/// Remove every configured stop sequence.
pub fn clear_stop_sequences(&mut self) {
    let sequences = &mut self.request_config.stop_sequences;
    sequences.clear();
}
/// Return a clone of the cancel channel's sending half.
///
/// The sender can be moved to another task and used to request cancellation
/// while the worker itself is mutably borrowed by `run()` / `resume()`.
pub fn cancel_sender(&self) -> mpsc::Sender<()> {
    let sender = &self.cancel_tx;
    sender.clone()
}
/// Replace the whole request configuration in one step.
pub fn set_request_config(&mut self, config: RequestConfig) {
    let _previous = std::mem::replace(&mut self.request_config, config);
}
/// Request cancellation of the current run.
///
/// Posts a signal on the cancel channel without blocking; if the signal cannot
/// be queued, the failure is ignored. A running worker observes the signal at
/// its next checkpoint and returns `WorkerError::Cancelled`.
///
/// `run()` borrows the worker mutably for its whole duration, so this method
/// cannot be called on the same worker while it is running. To cancel from
/// another task, take a sender beforehand with `cancel_sender()`.
///
/// # Examples
///
/// ```ignore
/// // Grab a sender before the worker moves into the task.
/// let cancel_tx = worker.cancel_sender();
/// let task = tokio::spawn(async move { worker.run("Long task...").await });
///
/// // Later, from anywhere:
/// let _ = cancel_tx.try_send(());
/// ```
pub fn cancel(&self) {
    // Best effort: both outcomes are deliberately ignored.
    match self.cancel_tx.try_send(()) {
        Ok(()) | Err(_) => {}
    }
}
/// Check if cancelled
pub fn is_cancelled(&mut self) -> bool {
self.try_cancelled()
}
/// Whether the previous run was interrupted
pub fn last_run_interrupted(&self) -> bool {
self.last_run_interrupted
}
/// Collect the LLM-facing definitions of all registered tools, as returned by
/// the tool server's sorted listing.
fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
    let definitions: Vec<ToolDefinition> = self.tool_server.tool_definitions_sorted();
    definitions
}
/// Turn the collected output of one LLM response into history items.
///
/// All text blocks are concatenated into a single assistant message (omitted
/// when the concatenation is empty), followed by one tool-call item per tool
/// call, in the order given.
fn build_assistant_items(&self, text_blocks: &[String], tool_calls: &[ToolCall]) -> Vec<Item> {
    let text = text_blocks.concat();
    let message = (!text.is_empty()).then(|| Item::assistant_message(text));

    let calls = tool_calls
        .iter()
        .map(|call| Item::tool_call_json(&call.id, &call.name, call.input.clone()));

    message.into_iter().chain(calls).collect()
}
/// Assemble the LLM request for one turn.
///
/// Applies, in order: the system prompt (if set), a copy of `context` as the
/// request items, every tool definition, and finally a copy of the worker's
/// request configuration.
fn build_request(&self, tool_definitions: &[ToolDefinition], context: &[Item]) -> Request {
    let seeded = match self.system_prompt.as_ref() {
        Some(system) => Request::new().system(system),
        None => Request::new(),
    };

    let with_items = seeded.items(context.iter().cloned());

    let with_tools = tool_definitions
        .iter()
        .fold(with_items, |request, tool_def| request.tool(tool_def.clone()));

    with_tools.config(self.request_config.clone())
}
/// Run the `on_prompt_submit` hooks in registration order.
///
/// Each hook may edit `item` in place. The first hook that cancels ends the
/// chain and its `Cancel` result is returned; a hook error is propagated.
async fn run_on_prompt_submit_hooks(
    &self,
    item: &mut Item,
) -> Result<OnPromptSubmitResult, WorkerError> {
    for hook in self.hooks.on_prompt_submit.iter() {
        let verdict = hook.call(item).await?;
        if let OnPromptSubmitResult::Cancel(reason) = verdict {
            return Ok(OnPromptSubmitResult::Cancel(reason));
        }
    }
    Ok(OnPromptSubmitResult::Continue)
}
/// Run the `pre_llm_request` hooks in registration order.
///
/// The hooks operate on a clone of the history, so their edits shape only the
/// outgoing request and never the stored history. Returns the control result
/// together with that (possibly edited) context; the first `Cancel` ends the
/// chain.
async fn run_pre_llm_request_hooks(
    &self,
) -> Result<(PreLlmRequestResult, Vec<Item>), WorkerError> {
    let mut working_context = self.history.clone();
    for hook in self.hooks.pre_llm_request.iter() {
        let verdict = hook.call(&mut working_context).await?;
        if let PreLlmRequestResult::Cancel(reason) = verdict {
            return Ok((PreLlmRequestResult::Cancel(reason), working_context));
        }
    }
    Ok((PreLlmRequestResult::Continue, working_context))
}
/// Run the `on_turn_end` hooks in registration order.
///
/// The hooks see a clone of the history. The first hook that answers anything
/// other than `Finish` decides the outcome; if all hooks answer `Finish` (or
/// there are none) the result is `Finish`.
async fn run_on_turn_end_hooks(&self) -> Result<OnTurnEndResult, WorkerError> {
    let mut snapshot = self.history.clone();
    for hook in self.hooks.on_turn_end.iter() {
        let verdict = hook.call(&mut snapshot).await?;
        match verdict {
            OnTurnEndResult::Finish => {}
            decisive @ (OnTurnEndResult::ContinueWithMessages(_) | OnTurnEndResult::Paused) => {
                return Ok(decisive);
            }
        }
    }
    Ok(OnTurnEndResult::Finish)
}
/// Run the `on_abort` hooks in registration order.
///
/// Every hook receives a mutable, owned copy of the reason text. The first
/// hook error stops the chain and is returned.
async fn run_on_abort_hooks(&self, reason: &str) -> Result<(), WorkerError> {
    let mut owned_reason = String::from(reason);
    for hook in self.hooks.on_abort.iter() {
        hook.call(&mut owned_reason).await?;
    }
    Ok(())
}
/// Translate a stream hook's verdict into control flow.
///
/// `Continue` maps to `Ok(())`. Both `Abort` and `Pause` map to
/// `WorkerError::Aborted`; a pause is reported with a fixed reason text.
fn apply_stream_hook_result(result: StreamHookResult) -> Result<(), WorkerError> {
    let reason = match result {
        StreamHookResult::Continue => return Ok(()),
        StreamHookResult::Abort(reason) => reason,
        StreamHookResult::Pause => "Paused by stream hook".to_string(),
    };
    Err(WorkerError::Aborted(reason))
}
/// Run the `on_stream_chunk` hooks for one stream event.
///
/// Stops at the first hook that errors or asks to abort/pause.
async fn run_on_stream_chunk_hooks(
    &self,
    event: crate::event::Event,
) -> Result<(), WorkerError> {
    let mut chunk = StreamChunkContext { event };
    for hook in self.hooks.on_stream_chunk.iter() {
        let verdict = hook.call(&mut chunk).await?;
        Self::apply_stream_hook_result(verdict)?;
    }
    Ok(())
}
/// Run the `on_text_delta` hooks for one text delta of block `index`.
///
/// Stops at the first hook that errors or asks to abort/pause.
async fn run_on_text_delta_hooks(
    &self,
    index: usize,
    delta: String,
) -> Result<(), WorkerError> {
    let mut text_delta = TextDeltaContext { index, delta };
    for hook in self.hooks.on_text_delta.iter() {
        let verdict = hook.call(&mut text_delta).await?;
        Self::apply_stream_hook_result(verdict)?;
    }
    Ok(())
}
/// Run the `on_tool_call_delta` hooks for one tool-input JSON fragment of
/// block `index`.
///
/// Stops at the first hook that errors or asks to abort/pause.
async fn run_on_tool_call_delta_hooks(
    &self,
    index: usize,
    delta_json_fragment: String,
) -> Result<(), WorkerError> {
    let mut fragment = ToolCallDeltaContext {
        index,
        delta_json_fragment,
    };
    for hook in self.hooks.on_tool_call_delta.iter() {
        let verdict = hook.call(&mut fragment).await?;
        Self::apply_stream_hook_result(verdict)?;
    }
    Ok(())
}
/// Run the `on_stream_complete` hooks after a turn's stream has ended.
///
/// `turn` is the turn number and `event_count` the number of events received.
/// Stops at the first hook that errors or asks to abort/pause.
async fn run_on_stream_complete_hooks(
    &self,
    turn: usize,
    event_count: usize,
) -> Result<(), WorkerError> {
    let mut summary = StreamCompleteContext { turn, event_count };
    for hook in self.hooks.on_stream_complete.iter() {
        let verdict = hook.call(&mut summary).await?;
        Self::apply_stream_hook_result(verdict)?;
    }
    Ok(())
}
/// Post-process the outcome of a run.
///
/// Successful results pass through untouched. On error the run is marked as
/// interrupted and the `on_abort` hooks are invoked with a reason text: the
/// abort reason, `"Cancelled"`, or the error's display text. If an abort hook
/// itself fails, that hook error is returned instead of the original error.
async fn finalize_interruption<T>(
    &mut self,
    result: Result<T, WorkerError>,
) -> Result<T, WorkerError> {
    let err = match result {
        Ok(value) => return Ok(value),
        Err(err) => err,
    };

    self.last_run_interrupted = true;

    let reason = match &err {
        WorkerError::Aborted(reason) => reason.clone(),
        WorkerError::Cancelled => "Cancelled".to_string(),
        _ => err.to_string(),
    };

    // A failing abort hook takes precedence over the original error.
    self.run_on_abort_hooks(&reason).await?;
    Err(err)
}
/// Find tool calls in the history that have no matching tool result
/// (used when resuming from a pause).
///
/// The whole history is scanned; a call counts as answered when any
/// `ToolResult` item carries the same call id. Stored arguments that fail to
/// parse as JSON are replaced by an empty JSON object. Returns `None` when
/// nothing is pending.
fn get_pending_tool_calls(&self) -> Option<Vec<ToolCall>> {
    // Ids of every call that already has a result.
    let answered: std::collections::HashSet<_> = self
        .history
        .iter()
        .filter_map(|item| match item {
            Item::ToolResult { call_id, .. } => Some(call_id),
            _ => None,
        })
        .collect();

    // Calls whose id is not in the answered set, in history order.
    let pending: Vec<ToolCall> = self
        .history
        .iter()
        .filter_map(|item| {
            let Item::ToolCall {
                call_id,
                name,
                arguments,
                ..
            } = item
            else {
                return None;
            };
            if answered.contains(call_id) {
                return None;
            }
            let input = serde_json::from_str(arguments)
                .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
            Some(ToolCall {
                id: call_id.clone(),
                name: name.clone(),
                input,
            })
        })
        .collect();

    (!pending.is_empty()).then_some(pending)
}
/// Execute tools in parallel
///
/// After running pre_tool_call hooks on all tools,
/// executes approved tools in parallel and applies post_tool_call hooks to results.
///
/// Phases:
/// 1. pre_tool_call hooks, sequentially per call (may skip, abort, or pause),
/// 2. concurrent execution of the approved calls, cancellable via the cancel channel,
/// 2.5. optional output processor on successful results,
/// 3. post_tool_call hooks on the results of registered tools.
///
/// Returns `Completed` with one result per executed call, or `Paused` when a
/// pre hook asks for a pause (no tool has run at that point).
///
/// # Errors
///
/// `WorkerError::Aborted` when a hook aborts, `WorkerError::Cancelled` when a
/// cancel signal arrives during execution, or a hook's own error. Tool
/// failures are not errors here: they become error-flagged results.
async fn execute_tools(
&mut self,
tool_calls: Vec<ToolCall>,
) -> Result<ToolExecutionResult, WorkerError> {
use futures::future::join_all;
// Map from tool call ID to (ToolCall, Meta, Tool)
// Retained because it's needed for PostToolCall hooks
let mut call_info_map = HashMap::new();
// Phase 1: Apply pre_tool_call hooks (determine skip/abort)
let mut approved_calls = Vec::new();
for mut tool_call in tool_calls {
// Get tool definition
if let Some((meta, tool)) = self.tool_server.get_tool(&tool_call.name) {
// Create context
let mut context = ToolCallContext {
call: tool_call.clone(),
meta,
tool,
};
let mut skip = false;
for hook in &self.hooks.pre_tool_call {
let result = hook
.call(&mut context)
.await
.inspect_err(|_| self.last_run_interrupted = true)?;
match result {
PreToolCallResult::Continue => {}
PreToolCallResult::Skip => {
// The first Skip wins; remaining hooks are not consulted for this call.
skip = true;
break;
}
PreToolCallResult::Abort(reason) => {
self.last_run_interrupted = true;
return Err(WorkerError::Aborted(reason));
}
PreToolCallResult::Pause => {
self.last_run_interrupted = true;
return Ok(ToolExecutionResult::Paused);
}
}
}
// Reflect changes made by hooks
tool_call = context.call;
// Save to map (only if executing)
// NOTE(review): a skipped call produces no ToolResult at all, so its ToolCall
// item stays unanswered in the history and would be picked up again by
// get_pending_tool_calls — confirm this is intended.
if !skip {
call_info_map.insert(
tool_call.id.clone(),
(
tool_call.clone(),
context.meta.clone(),
context.tool.clone(),
),
);
approved_calls.push(tool_call);
}
} else {
// Unknown tools go into approved list as-is (will error at execution)
// Hooks are not applied (no Meta available)
approved_calls.push(tool_call);
}
}
// Phase 2: Execute approved tools in parallel (cancellable)
let futures: Vec<_> = approved_calls
.into_iter()
.map(|tool_call| {
let tool_server = self.tool_server.clone();
async move {
// Input that cannot be serialized is sent as an empty string.
let input_json = serde_json::to_string(&tool_call.input).unwrap_or_default();
// A failing tool yields an error result instead of failing the whole batch.
match tool_server.call_tool(&tool_call.name, &input_json).await {
Ok(content) => ToolResult::success(&tool_call.id, content),
Err(e) => ToolResult::error(&tool_call.id, e.to_string()),
}
}
})
.collect();
// Make tool execution cancellable
let mut results = tokio::select! {
results = join_all(futures) => results,
cancel = self.cancel_rx.recv() => {
// `recv` resolving at all (signal or closed channel) ends execution;
// the log line is only emitted for a real signal.
if cancel.is_some() {
info!("Tool execution cancelled");
}
self.timeline.abort_current_block();
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
}
};
// Phase 2.5: Apply output processor (store large results externally)
if let Some(ref processor) = self.output_processor {
for tool_result in &mut results {
// Error results are passed through unprocessed.
if !tool_result.is_error {
match processor.process(tool_result.content.clone()).await {
Ok(processed) => tool_result.content = processed,
Err(e) => {
warn!(error = %e, "Output processor failed, keeping original content");
}
}
}
}
}
// Phase 3: Apply post_tool_call hooks
for tool_result in &mut results {
// Get saved information
// Results of unknown tools have no map entry and therefore skip the post hooks.
if let Some((tool_call, meta, tool)) = call_info_map.get(&tool_result.tool_use_id) {
let mut context = PostToolCallContext {
call: tool_call.clone(),
result: tool_result.clone(),
meta: meta.clone(),
tool: tool.clone(),
};
for hook in &self.hooks.post_tool_call {
let result = hook
.call(&mut context)
.await
.inspect_err(|_| self.last_run_interrupted = true)?;
match result {
PostToolCallResult::Continue => {}
PostToolCallResult::Abort(reason) => {
self.last_run_interrupted = true;
return Err(WorkerError::Aborted(reason));
}
}
}
// Reflect hook-modified results
*tool_result = context.result;
}
}
Ok(ToolExecutionResult::Completed(results))
}
/// Internal turn execution logic
///
/// Shared by `run()` and `resume()`. Steps:
///
/// 1. Reset the interruption flag and discard stale cancel signals.
/// 2. If the history contains unanswered tool calls (resume case), execute
///    them first and append their results.
/// 3. Loop: check for cancellation, notify turn start, run the
///    `pre_llm_request` hooks, stream the LLM response (dispatching every
///    event to the timeline and the stream hooks), append the assistant items
///    to the history, then either consult the `on_turn_end` hooks (no tool
///    calls) or execute the requested tools and append their results.
///
/// The `max_turns` limit is checked whenever the loop is about to start
/// another turn: after tool execution, and also when an `on_turn_end` hook
/// asks to continue with extra messages. Without the second check a hook that
/// keeps continuing could run an unbounded number of turns despite the limit.
///
/// # Errors
///
/// `WorkerError::Cancelled` on a cancel signal, `WorkerError::Aborted` when a
/// hook aborts, or the client / hook error that occurred. Every error path
/// marks the run as interrupted.
async fn run_turn_loop(&mut self) -> Result<WorkerResult, WorkerError> {
    self.reset_interruption_state();
    // Cancel signals sent before this run started must not abort it.
    self.drain_cancel_queue();
    let tool_definitions = self.build_tool_definitions();
    info!(
        item_count = self.history.len(),
        tool_count = tool_definitions.len(),
        "Starting worker run"
    );

    // Resume check: Pending tool calls
    if let Some(tool_calls) = self.get_pending_tool_calls() {
        info!("Resuming pending tool calls");
        match self.execute_tools(tool_calls).await {
            Ok(ToolExecutionResult::Paused) => {
                self.last_run_interrupted = true;
                return Ok(WorkerResult::Paused);
            }
            Ok(ToolExecutionResult::Completed(results)) => {
                for result in results {
                    self.history.push(Item::tool_result(
                        &result.tool_use_id,
                        &result.content,
                    ));
                }
                // Continue to loop
            }
            Err(err) => {
                self.last_run_interrupted = true;
                return Err(err);
            }
        }
    }

    loop {
        // Check for cancellation
        if self.try_cancelled() {
            info!("Execution cancelled");
            self.timeline.abort_current_block();
            self.last_run_interrupted = true;
            return Err(WorkerError::Cancelled);
        }

        // Notify turn start
        let current_turn = self.turn_count;
        debug!(turn = current_turn, "Turn start");
        for notifier in &self.turn_notifiers {
            notifier.on_turn_start(current_turn);
        }

        // Hook: pre_llm_request
        let (control, request_context) = self
            .run_pre_llm_request_hooks()
            .await
            .inspect_err(|_| self.last_run_interrupted = true)?;
        match control {
            PreLlmRequestResult::Cancel(reason) => {
                info!(reason = %reason, "Aborted by hook");
                for notifier in &self.turn_notifiers {
                    notifier.on_turn_end(current_turn);
                }
                self.last_run_interrupted = true;
                return Err(WorkerError::Aborted(reason));
            }
            PreLlmRequestResult::Continue => {}
        }

        // Build request
        let request = self.build_request(&tool_definitions, &request_context);
        debug!(
            item_count = request.items.len(),
            tool_count = request.tools.len(),
            has_system = request.system_prompt.is_some(),
            "Sending request to LLM"
        );

        // Stream processing
        debug!("Starting stream...");
        let mut event_count = 0;

        // Get stream (cancellable)
        let mut stream = tokio::select! {
            stream_result = self.client.stream(request) => stream_result
                .inspect_err(|_| self.last_run_interrupted = true)?,
            cancel = self.cancel_rx.recv() => {
                if cancel.is_some() {
                    info!("Cancelled before stream started");
                }
                self.timeline.abort_current_block();
                self.last_run_interrupted = true;
                return Err(WorkerError::Cancelled);
            }
        };

        loop {
            tokio::select! {
                // Receive event from stream
                event_result = stream.next() => {
                    match event_result {
                        Some(result) => {
                            match &result {
                                Ok(event) => {
                                    trace!(event = ?event, "Received event");
                                    event_count += 1;
                                }
                                Err(e) => {
                                    warn!(error = %e, "Stream error");
                                }
                            }
                            let event = result
                                .inspect_err(|_| self.last_run_interrupted = true)?;
                            // Timeline handlers (collectors, subscribers) see the
                            // event before the hooks do.
                            self.timeline.dispatch(&event);
                            self.run_on_stream_chunk_hooks(event.clone())
                                .await
                                .inspect_err(|_| self.last_run_interrupted = true)?;
                            if let crate::llm_client::event::Event::BlockDelta(delta) = &event {
                                match &delta.delta {
                                    crate::llm_client::event::DeltaContent::Text(text) => {
                                        self.run_on_text_delta_hooks(delta.index, text.clone())
                                            .await
                                            .inspect_err(|_| self.last_run_interrupted = true)?;
                                    }
                                    crate::llm_client::event::DeltaContent::InputJson(json_fragment) => {
                                        self.run_on_tool_call_delta_hooks(delta.index, json_fragment.clone())
                                            .await
                                            .inspect_err(|_| self.last_run_interrupted = true)?;
                                    }
                                    crate::llm_client::event::DeltaContent::Thinking(_) => {}
                                }
                            }
                        }
                        None => break, // Stream ended
                    }
                }
                // Wait for cancellation
                cancel = self.cancel_rx.recv() => {
                    if cancel.is_some() {
                        info!("Stream cancelled");
                    }
                    self.timeline.abort_current_block();
                    self.last_run_interrupted = true;
                    return Err(WorkerError::Cancelled);
                }
            }
        }

        self.run_on_stream_complete_hooks(current_turn, event_count)
            .await
            .inspect_err(|_| self.last_run_interrupted = true)?;
        debug!(event_count = event_count, "Stream completed");

        // Notify turn end
        for notifier in &self.turn_notifiers {
            notifier.on_turn_end(current_turn);
        }
        self.turn_count += 1;

        // Get collected results
        let text_blocks = self.text_block_collector.take_collected();
        let tool_calls = self.tool_call_collector.take_collected();

        // Add assistant items to history
        let assistant_items = self.build_assistant_items(&text_blocks, &tool_calls);
        self.history.extend(assistant_items);

        if tool_calls.is_empty() {
            // No tool calls → determine turn end
            let turn_result = self
                .run_on_turn_end_hooks()
                .await
                .inspect_err(|_| self.last_run_interrupted = true)?;
            match turn_result {
                OnTurnEndResult::Finish => {
                    self.last_run_interrupted = false;
                    return Ok(WorkerResult::Finished);
                }
                OnTurnEndResult::ContinueWithMessages(additional) => {
                    self.history.extend(additional);
                    // A hook-driven continuation starts another LLM turn, so it
                    // must honour max_turns just like a tool-driven one.
                    if let Some(max) = self.turn_limit_if_reached() {
                        info!(turn_count = self.turn_count, max_turns = max, "Turn limit reached");
                        self.last_run_interrupted = false;
                        return Ok(WorkerResult::LimitReached);
                    }
                    continue;
                }
                OnTurnEndResult::Paused => {
                    self.last_run_interrupted = true;
                    return Ok(WorkerResult::Paused);
                }
            }
        }

        // Execute tools
        match self.execute_tools(tool_calls).await {
            Ok(ToolExecutionResult::Paused) => {
                self.last_run_interrupted = true;
                return Ok(WorkerResult::Paused);
            }
            Ok(ToolExecutionResult::Completed(results)) => {
                for result in results {
                    self.history.push(Item::tool_result(
                        &result.tool_use_id,
                        &result.content,
                    ));
                }
            }
            Err(err) => {
                self.last_run_interrupted = true;
                return Err(err);
            }
        }

        // Check turn limit (after assistant items and tool results are in history)
        if let Some(max) = self.turn_limit_if_reached() {
            info!(turn_count = self.turn_count, max_turns = max, "Turn limit reached");
            self.last_run_interrupted = false;
            return Ok(WorkerResult::LimitReached);
        }
    }
}

/// Return the configured `max_turns` when the turn count has reached it,
/// `None` when no limit is set or the limit has not been reached yet.
fn turn_limit_if_reached(&self) -> Option<u32> {
    self.max_turns
        .filter(|&max| self.turn_count >= max as usize)
}
/// Continue a paused run.
///
/// Re-enters the turn loop from the current history without adding a new user
/// message. Errors are routed through the abort hooks exactly as in `run()`.
pub async fn resume(&mut self) -> Result<WorkerResult, WorkerError> {
    self.reset_interruption_state();
    let outcome = self.run_turn_loop().await;
    self.finalize_interruption(outcome).await
}
}
// =============================================================================
// Mutable State-Specific Implementation
// =============================================================================
impl<C: LlmClient> Worker<C, Mutable> {
/// Build a worker in the `Mutable` state around the given client.
///
/// Starts with an empty history, no system prompt, no turn limit, a default
/// request configuration, an empty tool server and hook registry, and a
/// cancel channel of capacity 1. The text-block and tool-call collectors are
/// registered on the timeline so each response can be collected per turn.
pub fn new(client: C) -> Self {
    let (cancel_tx, cancel_rx) = mpsc::channel(1);

    let text_block_collector = TextBlockCollector::new();
    let tool_call_collector = ToolCallCollector::new();

    // The timeline gets clones; the worker keeps the originals to read from.
    let mut timeline = Timeline::new();
    timeline.on_text_block(text_block_collector.clone());
    timeline.on_tool_use_block(tool_call_collector.clone());

    let tool_server = ToolServer::new().handle();
    let hooks = HookRegistry::new();
    let request_config = RequestConfig::default();

    Self {
        client,
        timeline,
        text_block_collector,
        tool_call_collector,
        tool_server,
        hooks,
        system_prompt: None,
        history: Vec::new(),
        locked_prefix_len: 0,
        turn_count: 0,
        max_turns: None,
        turn_notifiers: Vec::new(),
        request_config,
        last_run_interrupted: false,
        output_processor: None,
        cancel_tx,
        cancel_rx,
        _state: PhantomData,
    }
}
/// Register a single tool with the worker's tool server.
///
/// Registered tools are executed automatically when the LLM calls them.
/// Available only in Mutable state.
///
/// # Errors
///
/// `ToolRegistryError::DuplicateName` when a tool with the same name is
/// already registered.
///
/// # Panics
///
/// Panics if the tool server reports any error other than a duplicate name.
pub fn register_tool(
    &mut self,
    factory: WorkerToolDefinition,
) -> Result<(), ToolRegistryError> {
    self.tool_server
        .register_tool(factory)
        .map_err(|server_err| match server_err {
            ToolServerError::DuplicateName(name) => ToolRegistryError::DuplicateName(name),
            ToolServerError::ToolNotFound(_) | ToolServerError::ToolExecution(_) => {
                unreachable!("register_tool should only fail with DuplicateName")
            }
        })
}
/// Register several tools with the worker's tool server in one call.
///
/// Available only in Mutable state.
///
/// # Errors
///
/// `ToolRegistryError::DuplicateName` when the tool server rejects a name as
/// already registered.
///
/// # Panics
///
/// Panics if the tool server reports any error other than a duplicate name.
pub fn register_tools(
    &mut self,
    factories: impl IntoIterator<Item = WorkerToolDefinition>,
) -> Result<(), ToolRegistryError> {
    self.tool_server
        .register_tools(factories)
        .map_err(|server_err| match server_err {
            ToolServerError::DuplicateName(name) => ToolRegistryError::DuplicateName(name),
            ToolServerError::ToolNotFound(_) | ToolServerError::ToolExecution(_) => {
                unreachable!("register_tools should only fail with DuplicateName")
            }
        })
}
/// Builder-style setter for the system prompt; consumes and returns the worker.
pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
    let text = prompt.into();
    self.system_prompt = Some(text);
    self
}
/// Set the system prompt in place (mutable reference version).
///
/// `prompt` is converted into an owned `String` and replaces any previously
/// stored system prompt.
pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
    let text: String = prompt.into();
    self.system_prompt = Some(text);
}
/// Set the maximum token count on the request configuration (builder pattern).
///
/// # Examples
///
/// ```ignore
/// let worker = Worker::new(client)
///     .system_prompt("You are a helpful assistant.")
///     .max_tokens(4096);
/// ```
pub fn max_tokens(self, max_tokens: u32) -> Self {
    let mut worker = self;
    worker.request_config.max_tokens = Some(max_tokens);
    worker
}
/// Set the sampling temperature on the request configuration (builder pattern).
///
/// # Examples
///
/// ```ignore
/// let worker = Worker::new(client)
///     .temperature(0.7);
/// ```
pub fn temperature(mut self, temperature: f32) -> Self {
    let cfg = &mut self.request_config;
    cfg.temperature = Some(temperature);
    self
}
/// Set `top_p` on the request configuration (builder pattern).
///
/// Consumes the worker and returns it with `request_config.top_p` set to
/// `Some(top_p)`.
pub fn top_p(self, top_p: f32) -> Self {
    let mut worker = self;
    worker.request_config.top_p = Some(top_p);
    worker
}
/// Set `top_k` on the request configuration (builder pattern).
///
/// Consumes the worker and returns it with `request_config.top_k` set to
/// `Some(top_k)`.
pub fn top_k(mut self, top_k: u32) -> Self {
    let cfg = &mut self.request_config;
    cfg.top_k = Some(top_k);
    self
}
/// Append one stop sequence to the request configuration (builder pattern).
///
/// Each call adds to the existing list; call it repeatedly to register
/// several stop sequences.
pub fn stop_sequence(self, sequence: impl Into<String>) -> Self {
    let mut worker = self;
    let sequence: String = sequence.into();
    worker.request_config.stop_sequences.push(sequence);
    worker
}
/// Set request configuration at once (builder pattern)
///
/// # Examples
///
/// ```ignore
/// let config = RequestConfig::new()
/// .with_max_tokens(4096)
/// .with_temperature(0.7);
///
/// let worker = Worker::new(client)
/// .system_prompt("...")
/// .with_config(config);
/// ```
pub fn with_config(mut self, config: RequestConfig) -> Self {
self.request_config = config;
self
}
/// Check the current request configuration against the provider.
///
/// Asks the client to validate `request_config` and fails if it reports any
/// warnings. Call at the end of a builder chain to detect configuration
/// issues early.
///
/// # Examples
///
/// ```ignore
/// let worker = Worker::new(client)
///     .temperature(0.7)
///     .top_k(40)
///     .validate()?; // Error if using OpenAI since top_k is not supported
/// ```
///
/// # Errors
///
/// Returns [`WorkerError::ConfigWarnings`] carrying every warning the client
/// reported when the list is non-empty.
///
/// # Returns
/// * `Ok(Self)` - Validation successful
/// * `Err(WorkerError::ConfigWarnings)` - Has unsupported settings
pub fn validate(self) -> Result<Self, WorkerError> {
    let warnings = self.client.validate_config(&self.request_config);
    if !warnings.is_empty() {
        return Err(WorkerError::ConfigWarnings(warnings));
    }
    Ok(self)
}
/// Get a mutable reference to history
///
/// Available only in Mutable state. History can be freely edited.
pub fn history_mut(&mut self) -> &mut Vec<Item> {
&mut self.history
}
/// Set history
pub fn set_history(&mut self, items: Vec<Item>) {
self.history = items;
}
/// Append one item to the end of history (builder pattern).
///
/// Consumes the worker and returns it so calls can be chained.
pub fn with_item(mut self, item: Item) -> Self {
    self.push_item(item);
    self
}
/// Append one item to the end of history in place.
pub fn push_item(&mut self, item: Item) {
    Vec::push(&mut self.history, item);
}
/// Append every item yielded by `items` to history (builder pattern).
///
/// Items are added in iteration order after the existing history. Consumes
/// the worker and returns it so calls can be chained.
pub fn with_items(mut self, items: impl IntoIterator<Item = Item>) -> Self {
    self.extend_history(items);
    self
}
/// Append every item yielded by `items` to history in place.
///
/// Items are added in iteration order after the existing history.
pub fn extend_history(&mut self, items: impl IntoIterator<Item = Item>) {
    Extend::extend(&mut self.history, items);
}
/// Remove every item from history.
pub fn clear_history(&mut self) {
    // Truncating to length zero is equivalent to `Vec::clear`.
    self.history.truncate(0);
}
/// Set the turn count (for session restoration)
pub fn set_turn_count(&mut self, count: usize) {
self.turn_count = count;
}
/// Set the maximum number of turns. None means unlimited.
pub fn set_max_turns(&mut self, max_turns: Option<u32>) {
self.max_turns = max_turns;
}
/// Set the last_run_interrupted flag (for session restoration)
pub fn set_last_run_interrupted(&mut self, interrupted: bool) {
self.last_run_interrupted = interrupted;
}
/// Set a tool output processor for handling large tool results.
///
/// When set, tool execution results are passed through this processor
/// before being placed into conversation history.
pub fn set_output_processor(&mut self, processor: Arc<dyn ToolOutputProcessor>) {
self.output_processor = Some(processor);
}
/// Apply configuration (reserved for future extensions)
///
/// Currently a no-op: `_config` is ignored and the worker is returned
/// unchanged.
// NOTE(review): `dead_code` is presumably allowed because nothing calls this
// placeholder yet — confirm before removing the attribute.
#[allow(dead_code)]
pub fn config(self, _config: WorkerConfig) -> Self {
self
}
/// Lock the worker and transition it to the CacheLocked state.
///
/// This operation fixes the current system prompt and history as a
/// "committed prefix": the history length at the moment of the call is
/// recorded as `locked_prefix_len`. After this, only appending to history is
/// allowed, ensuring cache hits.
///
/// Every other field is moved into the new worker unchanged.
pub fn lock(self) -> Worker<C, CacheLocked> {
    // Take the worker apart so each field can be moved into the new state.
    let Worker {
        client,
        timeline,
        text_block_collector,
        tool_call_collector,
        tool_server,
        hooks,
        system_prompt,
        history,
        turn_count,
        max_turns,
        turn_notifiers,
        request_config,
        last_run_interrupted,
        output_processor,
        cancel_tx,
        cancel_rx,
        // Recomputed below from the current history length.
        locked_prefix_len: _,
        // The state marker is replaced with one for the new state.
        _state: _,
    } = self;

    let locked_prefix_len = history.len();

    Worker {
        client,
        timeline,
        text_block_collector,
        tool_call_collector,
        tool_server,
        hooks,
        system_prompt,
        history,
        locked_prefix_len,
        turn_count,
        max_turns,
        turn_notifiers,
        request_config,
        last_run_interrupted,
        output_processor,
        cancel_tx,
        cancel_rx,
        _state: PhantomData,
    }
}
}
// =============================================================================
// CacheLocked State-Specific Implementation
// =============================================================================
impl<C: LlmClient> Worker<C, CacheLocked> {
    /// Length of the history prefix that was recorded when the worker was locked.
    pub fn locked_prefix_len(&self) -> usize {
        let Self {
            locked_prefix_len, ..
        } = self;
        *locked_prefix_len
    }

    /// Unlock the worker and return it to the Mutable state.
    ///
    /// The recorded prefix length is reset to `0`; every other field is moved
    /// into the new worker unchanged.
    ///
    /// Note: After this operation, subsequent requests may not hit the cache.
    /// Use only when you need to edit history.
    pub fn unlock(self) -> Worker<C, Mutable> {
        // Take the worker apart so each field can be moved into the new state.
        let Worker {
            client,
            timeline,
            text_block_collector,
            tool_call_collector,
            tool_server,
            hooks,
            system_prompt,
            history,
            turn_count,
            max_turns,
            turn_notifiers,
            request_config,
            last_run_interrupted,
            output_processor,
            cancel_tx,
            cancel_rx,
            // The locked prefix no longer applies once the worker is mutable.
            locked_prefix_len: _,
            // The state marker is replaced with one for the new state.
            _state: _,
        } = self;

        Worker {
            client,
            timeline,
            text_block_collector,
            tool_call_collector,
            tool_server,
            hooks,
            system_prompt,
            history,
            locked_prefix_len: 0,
            turn_count,
            max_turns,
            turn_notifiers,
            request_config,
            last_run_interrupted,
            output_processor,
            cancel_tx,
            cancel_rx,
            _state: PhantomData,
        }
    }
}
#[cfg(test)]
mod tests {
// Placeholder for basic unit tests only; currently empty.
// Tests that need an LlmClient are done in integration tests.
}