Workerのリファクタリング

This commit is contained in:
Keisuke Hirata 2026-04-11 19:47:34 +09:00
parent 7249a8ee6a
commit 02b266dce7
5 changed files with 124 additions and 131 deletions

View File

@ -40,17 +40,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("📡 Sending request to LLM...");
// Mutable::run consumes self → (Locked, WorkerResult)
match worker.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await {
Ok((_locked, WorkerResult::Finished)) => {
println!("✅ Task completed normally");
}
Ok((_locked, WorkerResult::Paused)) => {
println!("⏸️ Task paused");
}
Ok((_locked, WorkerResult::LimitReached)) => {
println!("🔒 Turn limit reached");
}
Ok(out) => match out.result {
WorkerResult::Finished => println!("✅ Task completed normally"),
WorkerResult::Paused => println!("⏸️ Task paused"),
WorkerResult::LimitReached => println!("🔒 Turn limit reached"),
},
Err(e) => {
println!("❌ Task error: {}", e);
}

View File

@ -476,8 +476,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
return Ok(());
}
let (mut locked, _) = match worker.run(first_input).await {
Ok(pair) => pair,
let mut locked = match worker.run(first_input).await {
Ok(out) => out.worker,
Err(e) => {
eprintln!("\n❌ Error: {}", e);
return Ok(());

View File

@ -55,4 +55,4 @@ pub use handler::ToolUseBlockStart;
pub use message::{ContentPart, Item, Message, Role};
pub use interceptor::Interceptor;
pub use tool::{ToolCall, ToolResult};
pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult};
pub use worker::{RunOutput, ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult};

View File

@ -74,6 +74,16 @@ pub enum WorkerResult {
LimitReached,
}
/// Result of [`Worker<C, Mutable>::run()`] / [`Worker<C, Mutable>::resume()`].
///
/// Contains the `Locked` Worker (ready for subsequent runs) and the outcome.
pub struct RunOutput<C: LlmClient> {
/// The Worker, now in Locked state.
pub worker: Worker<C, Locked>,
/// Outcome of the turn.
pub result: WorkerResult,
}
/// Internal: tool execution result
enum ToolExecutionResult {
Completed(Vec<ToolResult>),
@ -100,8 +110,9 @@ enum ToolExecutionResult {
/// .system_prompt("You are a helpful assistant.");
/// worker.register_tool(my_tool);
///
/// // Mutable::run() consumes self → Locked
/// let (mut worker, _result) = worker.run("Hello").await?;
/// // Mutable::run() consumes self → RunOutput { worker: Locked, result }
/// let out = worker.run("Hello").await?;
/// let mut worker = out.worker;
///
/// // Locked::run() borrows &mut self
/// worker.run("Follow-up").await?;
@ -109,7 +120,8 @@ enum ToolExecutionResult {
/// // To edit between turns, unlock back to Mutable
/// let mut worker = worker.unlock();
/// worker.history_mut().truncate(5);
/// let (mut worker, _result) = worker.run("Continue").await?;
/// let out = worker.run("Continue").await?;
/// let mut worker = out.worker;
/// ```
pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
/// LLM client
@ -669,32 +681,15 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
"Starting worker run"
);
// Resume check: Pending tool calls
// Resume pending tool calls from a previous Pause
if let Some(tool_calls) = self.get_pending_tool_calls() {
info!("Resuming pending tool calls");
match self.execute_tools(tool_calls).await {
Ok(ToolExecutionResult::Paused) => {
self.last_run_interrupted = true;
return Ok(WorkerResult::Paused);
}
Ok(ToolExecutionResult::Completed(results)) => {
for result in results {
self.history.push(Item::tool_result(
&result.tool_use_id,
&result.content,
));
}
// Continue to loop
}
Err(err) => {
self.last_run_interrupted = true;
return Err(err);
}
if let Some(result) = self.execute_and_commit_tools(tool_calls).await? {
return Ok(result);
}
}
loop {
// Check for cancellation
if self.try_cancelled() {
info!("Execution cancelled");
self.timeline.abort_current_block();
@ -702,7 +697,6 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
return Err(WorkerError::Cancelled);
}
// Notify turn start
let current_turn = self.turn_count;
debug!(turn = current_turn, "Turn start");
for cb in &self.turn_start_cbs {
@ -723,84 +717,22 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
PreRequestAction::Continue => {}
}
// Build request
// Stream LLM response
let request = self.build_request(&tool_definitions, &request_context);
debug!(
item_count = request.items.len(),
tool_count = request.tools.len(),
has_system = request.system_prompt.is_some(),
"Sending request to LLM"
);
self.stream_response(request).await?;
// Stream processing
debug!("Starting stream...");
let mut event_count = 0;
// Get stream (cancellable)
let mut stream = tokio::select! {
stream_result = self.client.stream(request) => stream_result
.inspect_err(|_| self.last_run_interrupted = true)?,
cancel = self.cancel_rx.recv() => {
if cancel.is_some() {
info!("Cancelled before stream started");
}
self.timeline.abort_current_block();
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
}
};
loop {
tokio::select! {
// Receive event from stream
event_result = stream.next() => {
match event_result {
Some(result) => {
match &result {
Ok(event) => {
trace!(event = ?event, "Received event");
event_count += 1;
}
Err(e) => {
warn!(error = %e, "Stream error");
}
}
let event = result
.inspect_err(|_| self.last_run_interrupted = true)?;
self.timeline.dispatch(&event);
}
None => break, // Stream ended
}
}
// Wait for cancellation
cancel = self.cancel_rx.recv() => {
if cancel.is_some() {
info!("Stream cancelled");
}
self.timeline.abort_current_block();
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
}
}
}
debug!(event_count = event_count, "Stream completed");
// Notify turn end
for cb in &self.turn_end_cbs {
cb(current_turn);
}
self.turn_count += 1;
// Get collected results
// Collect and commit assistant items
let text_blocks = self.text_block_collector.take_collected();
let tool_calls = self.tool_call_collector.take_collected();
// Add assistant items to history
let assistant_items = self.build_assistant_items(&text_blocks, &tool_calls);
self.history.extend(assistant_items);
if tool_calls.is_empty() {
// No tool calls → determine turn end via interceptor
match self.interceptor.on_turn_end(&self.history).await {
TurnEndAction::Finish => {
self.last_run_interrupted = false;
@ -817,27 +749,10 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
}
}
// Execute tools
match self.execute_tools(tool_calls).await {
Ok(ToolExecutionResult::Paused) => {
self.last_run_interrupted = true;
return Ok(WorkerResult::Paused);
}
Ok(ToolExecutionResult::Completed(results)) => {
for result in results {
self.history.push(Item::tool_result(
&result.tool_use_id,
&result.content,
));
}
}
Err(err) => {
self.last_run_interrupted = true;
return Err(err);
}
if let Some(result) = self.execute_and_commit_tools(tool_calls).await? {
return Ok(result);
}
// Check turn limit (after assistant items and tool results are in history)
if let Some(max) = self.max_turns {
if self.turn_count >= max as usize {
info!(turn_count = self.turn_count, max_turns = max, "Turn limit reached");
@ -848,6 +763,90 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
}
}
/// Open a stream, dispatch all events to the timeline, handle cancellation.
async fn stream_response(&mut self, request: Request) -> Result<(), WorkerError> {
debug!(
item_count = request.items.len(),
tool_count = request.tools.len(),
has_system = request.system_prompt.is_some(),
"Sending request to LLM"
);
let mut stream = tokio::select! {
stream_result = self.client.stream(request) => stream_result
.inspect_err(|_| self.last_run_interrupted = true)?,
cancel = self.cancel_rx.recv() => {
if cancel.is_some() {
info!("Cancelled before stream started");
}
self.timeline.abort_current_block();
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
}
};
let mut event_count: usize = 0;
loop {
tokio::select! {
event_result = stream.next() => {
match event_result {
Some(result) => {
match &result {
Ok(event) => {
trace!(event = ?event, "Received event");
event_count += 1;
}
Err(e) => {
warn!(error = %e, "Stream error");
}
}
let event = result
.inspect_err(|_| self.last_run_interrupted = true)?;
self.timeline.dispatch(&event);
}
None => break,
}
}
cancel = self.cancel_rx.recv() => {
if cancel.is_some() {
info!("Stream cancelled");
}
self.timeline.abort_current_block();
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
}
}
}
debug!(event_count = event_count, "Stream completed");
Ok(())
}
/// Execute tools and push results to history.
/// Returns `Some(result)` if execution should stop (Paused),
/// `None` if the turn loop should continue.
async fn execute_and_commit_tools(
&mut self,
tool_calls: Vec<ToolCall>,
) -> Result<Option<WorkerResult>, WorkerError> {
match self.execute_tools(tool_calls).await {
Ok(ToolExecutionResult::Paused) => {
self.last_run_interrupted = true;
Ok(Some(WorkerResult::Paused))
}
Ok(ToolExecutionResult::Completed(results)) => {
for result in results {
self.history
.push(Item::tool_result(&result.tool_use_id, &result.content));
}
Ok(None)
}
Err(err) => {
self.last_run_interrupted = true;
Err(err)
}
}
}
}
@ -1088,21 +1087,19 @@ impl<C: LlmClient> Worker<C, Mutable> {
pub async fn run(
self,
user_input: impl Into<String>,
) -> Result<(Worker<C, Locked>, WorkerResult), WorkerError> {
) -> Result<RunOutput<C>, WorkerError> {
let mut locked = self.lock();
let result = locked.run(user_input).await?;
Ok((locked, result))
Ok(RunOutput { worker: locked, result })
}
/// Resume from Paused, consuming self and transitioning to Locked.
///
/// Used after `unlock()` → edit → resume.
pub async fn resume(
self,
) -> Result<(Worker<C, Locked>, WorkerResult), WorkerError> {
pub async fn resume(self) -> Result<RunOutput<C>, WorkerError> {
let mut locked = self.lock();
let result = locked.resume().await?;
Ok((locked, result))
Ok(RunOutput { worker: locked, result })
}
/// Lock and transition to Locked state

View File

@ -211,8 +211,9 @@ async fn test_mutable_run_updates_history() -> Result<(), WorkerError> {
let client = MockLlmClient::new(events);
let worker = Worker::new(client);
// Execute (Mutable::run consumes self, returns (Locked, WorkerResult))
let (worker, _result) = worker.run("Hi there").await?;
// Execute (Mutable::run consumes self, returns RunOutput)
let out = worker.run("Hi there").await?;
let worker = out.worker;
// History is updated
let history = worker.history();
@ -352,8 +353,8 @@ async fn test_turn_count_increment() -> Result<(), WorkerError> {
assert_eq!(worker.turn_count(), 0);
// First run consumes Mutable, returns Locked
let (mut worker, _) = worker.run("First").await?;
// First run consumes Mutable, returns RunOutput
let mut worker = worker.run("First").await?.worker;
assert_eq!(worker.turn_count(), 1);
// Subsequent runs on Locked take &mut self