From 982e0d2dbb71b6766cc33835b3519dd758611653 Mon Sep 17 00:00:00 2001 From: Hare Date: Sat, 11 Apr 2026 19:47:34 +0900 Subject: [PATCH] =?UTF-8?q?Worker=E3=81=AE=E3=83=AA=E3=83=95=E3=82=A1?= =?UTF-8?q?=E3=82=AF=E3=82=BF=E3=83=AA=E3=83=B3=E3=82=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../llm-worker/examples/worker_cancel_demo.rs | 15 +- crates/llm-worker/examples/worker_cli.rs | 4 +- crates/llm-worker/src/lib.rs | 2 +- crates/llm-worker/src/worker.rs | 225 +++++++++--------- crates/llm-worker/tests/worker_state_test.rs | 9 +- 5 files changed, 124 insertions(+), 131 deletions(-) diff --git a/crates/llm-worker/examples/worker_cancel_demo.rs b/crates/llm-worker/examples/worker_cancel_demo.rs index 2f8afd87..44dc9850 100644 --- a/crates/llm-worker/examples/worker_cancel_demo.rs +++ b/crates/llm-worker/examples/worker_cancel_demo.rs @@ -40,17 +40,12 @@ async fn main() -> Result<(), Box> { println!("πŸ“‘ Sending request to LLM..."); - // Mutable::run consumes self β†’ (Locked, WorkerResult) match worker.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await { - Ok((_locked, WorkerResult::Finished)) => { - println!("βœ… Task completed normally"); - } - Ok((_locked, WorkerResult::Paused)) => { - println!("⏸️ Task paused"); - } - Ok((_locked, WorkerResult::LimitReached)) => { - println!("πŸ”’ Turn limit reached"); - } + Ok(out) => match out.result { + WorkerResult::Finished => println!("βœ… Task completed normally"), + WorkerResult::Paused => println!("⏸️ Task paused"), + WorkerResult::LimitReached => println!("πŸ”’ Turn limit reached"), + }, Err(e) => { println!("❌ Task error: {}", e); } diff --git a/crates/llm-worker/examples/worker_cli.rs b/crates/llm-worker/examples/worker_cli.rs index 68041b2c..fd61f84b 100644 --- a/crates/llm-worker/examples/worker_cli.rs +++ b/crates/llm-worker/examples/worker_cli.rs @@ -476,8 +476,8 @@ async fn main() -> Result<(), Box> { return Ok(()); } - let (mut locked, _) = match worker.run(first_input).await { - Ok(pair) => pair, + let mut locked = match worker.run(first_input).await { + Ok(out) => out.worker, Err(e) => { eprintln!("\n❌ Error: {}", e); return Ok(()); diff --git a/crates/llm-worker/src/lib.rs b/crates/llm-worker/src/lib.rs index 480e9513..48f80cf0 100644 --- a/crates/llm-worker/src/lib.rs +++ b/crates/llm-worker/src/lib.rs @@ -55,4 +55,4 @@ pub use handler::ToolUseBlockStart; pub use message::{ContentPart, Item, Message, Role}; pub use interceptor::Interceptor; pub use tool::{ToolCall, ToolResult}; -pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult}; +pub use worker::{RunOutput, ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult}; diff --git a/crates/llm-worker/src/worker.rs b/crates/llm-worker/src/worker.rs index e8aae5e5..54486617 100644 --- a/crates/llm-worker/src/worker.rs +++ b/crates/llm-worker/src/worker.rs @@ -74,6 +74,16 @@ pub enum WorkerResult { LimitReached, } +/// Result of [`Worker::run()`] / [`Worker::resume()`]. +/// +/// Contains the `Locked` Worker (ready for subsequent runs) and the outcome. +pub struct RunOutput { + /// The Worker, now in Locked state. + pub worker: Worker, + /// Outcome of the turn. + pub result: WorkerResult, +} + /// Internal: tool execution result enum ToolExecutionResult { Completed(Vec), @@ -100,8 +110,9 @@ enum ToolExecutionResult { /// .system_prompt("You are a helpful assistant."); /// worker.register_tool(my_tool); /// -/// // Mutable::run() consumes self β†’ Locked -/// let (mut worker, _result) = worker.run("Hello").await?; +/// // Mutable::run() consumes self β†’ RunOutput { worker: Locked, result } +/// let out = worker.run("Hello").await?; +/// let mut worker = out.worker; /// /// // Locked::run() borrows &mut self /// worker.run("Follow-up").await?; @@ -109,7 +120,8 @@ enum ToolExecutionResult { /// // To edit between turns, unlock back to Mutable /// let mut worker = worker.unlock(); /// worker.history_mut().truncate(5); -/// let (mut worker, _result) = worker.run("Continue").await?; +/// let out = worker.run("Continue").await?; +/// let mut worker = out.worker; /// ``` pub struct Worker { /// LLM client @@ -669,32 +681,15 @@ impl Worker { "Starting worker run" ); - // Resume check: Pending tool calls + // Resume pending tool calls from a previous Pause if let Some(tool_calls) = self.get_pending_tool_calls() { info!("Resuming pending tool calls"); - match self.execute_tools(tool_calls).await { - Ok(ToolExecutionResult::Paused) => { - self.last_run_interrupted = true; - return Ok(WorkerResult::Paused); - } - Ok(ToolExecutionResult::Completed(results)) => { - for result in results { - self.history.push(Item::tool_result( - &result.tool_use_id, - &result.content, - )); - } - // Continue to loop - } - Err(err) => { - self.last_run_interrupted = true; - return Err(err); - } + if let Some(result) = self.execute_and_commit_tools(tool_calls).await? { + return Ok(result); } } loop { - // Check for cancellation if self.try_cancelled() { info!("Execution cancelled"); self.timeline.abort_current_block(); @@ -702,7 +697,6 @@ impl Worker { return Err(WorkerError::Cancelled); } - // Notify turn start let current_turn = self.turn_count; debug!(turn = current_turn, "Turn start"); for cb in &self.turn_start_cbs { @@ -723,84 +717,22 @@ impl Worker { PreRequestAction::Continue => {} } - // Build request + // Stream LLM response let request = self.build_request(&tool_definitions, &request_context); - debug!( - item_count = request.items.len(), - tool_count = request.tools.len(), - has_system = request.system_prompt.is_some(), - "Sending request to LLM" - ); + self.stream_response(request).await?; - // Stream processing - debug!("Starting stream..."); - let mut event_count = 0; - - // Get stream (cancellable) - let mut stream = tokio::select! { - stream_result = self.client.stream(request) => stream_result - .inspect_err(|_| self.last_run_interrupted = true)?, - cancel = self.cancel_rx.recv() => { - if cancel.is_some() { - info!("Cancelled before stream started"); - } - self.timeline.abort_current_block(); - self.last_run_interrupted = true; - return Err(WorkerError::Cancelled); - } - }; - - loop { - tokio::select! { - // Receive event from stream - event_result = stream.next() => { - match event_result { - Some(result) => { - match &result { - Ok(event) => { - trace!(event = ?event, "Received event"); - event_count += 1; - } - Err(e) => { - warn!(error = %e, "Stream error"); - } - } - let event = result - .inspect_err(|_| self.last_run_interrupted = true)?; - self.timeline.dispatch(&event); - } - None => break, // Stream ended - } - } - // Wait for cancellation - cancel = self.cancel_rx.recv() => { - if cancel.is_some() { - info!("Stream cancelled"); - } - self.timeline.abort_current_block(); - self.last_run_interrupted = true; - return Err(WorkerError::Cancelled); - } - } - } - debug!(event_count = event_count, "Stream completed"); - - // Notify turn end for cb in &self.turn_end_cbs { cb(current_turn); } self.turn_count += 1; - // Get collected results + // Collect and commit assistant items let text_blocks = self.text_block_collector.take_collected(); let tool_calls = self.tool_call_collector.take_collected(); - - // Add assistant items to history let assistant_items = self.build_assistant_items(&text_blocks, &tool_calls); self.history.extend(assistant_items); if tool_calls.is_empty() { - // No tool calls β†’ determine turn end via interceptor match self.interceptor.on_turn_end(&self.history).await { TurnEndAction::Finish => { self.last_run_interrupted = false; @@ -817,27 +749,10 @@ impl Worker { } } - // Execute tools - match self.execute_tools(tool_calls).await { - Ok(ToolExecutionResult::Paused) => { - self.last_run_interrupted = true; - return Ok(WorkerResult::Paused); - } - Ok(ToolExecutionResult::Completed(results)) => { - for result in results { - self.history.push(Item::tool_result( - &result.tool_use_id, - &result.content, - )); - } - } - Err(err) => { - self.last_run_interrupted = true; - return Err(err); - } + if let Some(result) = self.execute_and_commit_tools(tool_calls).await? { + return Ok(result); } - // Check turn limit (after assistant items and tool results are in history) if let Some(max) = self.max_turns { if self.turn_count >= max as usize { info!(turn_count = self.turn_count, max_turns = max, "Turn limit reached"); @@ -848,6 +763,90 @@ impl Worker { } } + /// Open a stream, dispatch all events to the timeline, handle cancellation. + async fn stream_response(&mut self, request: Request) -> Result<(), WorkerError> { + debug!( + item_count = request.items.len(), + tool_count = request.tools.len(), + has_system = request.system_prompt.is_some(), + "Sending request to LLM" + ); + + let mut stream = tokio::select! { + stream_result = self.client.stream(request) => stream_result + .inspect_err(|_| self.last_run_interrupted = true)?, + cancel = self.cancel_rx.recv() => { + if cancel.is_some() { + info!("Cancelled before stream started"); + } + self.timeline.abort_current_block(); + self.last_run_interrupted = true; + return Err(WorkerError::Cancelled); + } + }; + + let mut event_count: usize = 0; + loop { + tokio::select! { + event_result = stream.next() => { + match event_result { + Some(result) => { + match &result { + Ok(event) => { + trace!(event = ?event, "Received event"); + event_count += 1; + } + Err(e) => { + warn!(error = %e, "Stream error"); + } + } + let event = result + .inspect_err(|_| self.last_run_interrupted = true)?; + self.timeline.dispatch(&event); + } + None => break, + } + } + cancel = self.cancel_rx.recv() => { + if cancel.is_some() { + info!("Stream cancelled"); + } + self.timeline.abort_current_block(); + self.last_run_interrupted = true; + return Err(WorkerError::Cancelled); + } + } + } + debug!(event_count = event_count, "Stream completed"); + Ok(()) + } + + /// Execute tools and push results to history. + /// Returns `Some(result)` if execution should stop (Paused), + /// `None` if the turn loop should continue. + async fn execute_and_commit_tools( + &mut self, + tool_calls: Vec, + ) -> Result, WorkerError> { + match self.execute_tools(tool_calls).await { + Ok(ToolExecutionResult::Paused) => { + self.last_run_interrupted = true; + Ok(Some(WorkerResult::Paused)) + } + Ok(ToolExecutionResult::Completed(results)) => { + for result in results { + self.history + .push(Item::tool_result(&result.tool_use_id, &result.content)); + } + Ok(None) + } + Err(err) => { + self.last_run_interrupted = true; + Err(err) + } + } + } + } @@ -1088,21 +1087,19 @@ impl Worker { pub async fn run( self, user_input: impl Into, - ) -> Result<(Worker, WorkerResult), WorkerError> { + ) -> Result, WorkerError> { let mut locked = self.lock(); let result = locked.run(user_input).await?; - Ok((locked, result)) + Ok(RunOutput { worker: locked, result }) } /// Resume from Paused, consuming self and transitioning to Locked. /// /// Used after `unlock()` β†’ edit β†’ resume. - pub async fn resume( - self, - ) -> Result<(Worker, WorkerResult), WorkerError> { + pub async fn resume(self) -> Result, WorkerError> { let mut locked = self.lock(); let result = locked.resume().await?; - Ok((locked, result)) + Ok(RunOutput { worker: locked, result }) } /// Lock and transition to Locked state diff --git a/crates/llm-worker/tests/worker_state_test.rs b/crates/llm-worker/tests/worker_state_test.rs index 8828767c..14ca1197 100644 --- a/crates/llm-worker/tests/worker_state_test.rs +++ b/crates/llm-worker/tests/worker_state_test.rs @@ -211,8 +211,9 @@ async fn test_mutable_run_updates_history() -> Result<(), WorkerError> { let client = MockLlmClient::new(events); let worker = Worker::new(client); - // Execute (Mutable::run consumes self, returns (Locked, WorkerResult)) - let (worker, _result) = worker.run("Hi there").await?; + // Execute (Mutable::run consumes self, returns RunOutput) + let out = worker.run("Hi there").await?; + let worker = out.worker; // History is updated let history = worker.history(); @@ -352,8 +353,8 @@ async fn test_turn_count_increment() -> Result<(), WorkerError> { assert_eq!(worker.turn_count(), 0); - // First run consumes Mutable, returns Locked - let (mut worker, _) = worker.run("First").await?; + // First run consumes Mutable, returns RunOutput + let mut worker = worker.run("First").await?.worker; assert_eq!(worker.turn_count(), 1); // Subsequent runs on Locked take &mut self