Workerのリファクタリング
This commit is contained in:
parent
7249a8ee6a
commit
02b266dce7
|
|
@ -40,17 +40,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
||||
println!("📡 Sending request to LLM...");
|
||||
|
||||
// Mutable::run consumes self → (Locked, WorkerResult)
|
||||
match worker.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await {
|
||||
Ok((_locked, WorkerResult::Finished)) => {
|
||||
println!("✅ Task completed normally");
|
||||
}
|
||||
Ok((_locked, WorkerResult::Paused)) => {
|
||||
println!("⏸️ Task paused");
|
||||
}
|
||||
Ok((_locked, WorkerResult::LimitReached)) => {
|
||||
println!("🔒 Turn limit reached");
|
||||
}
|
||||
Ok(out) => match out.result {
|
||||
WorkerResult::Finished => println!("✅ Task completed normally"),
|
||||
WorkerResult::Paused => println!("⏸️ Task paused"),
|
||||
WorkerResult::LimitReached => println!("🔒 Turn limit reached"),
|
||||
},
|
||||
Err(e) => {
|
||||
println!("❌ Task error: {}", e);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -476,8 +476,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
return Ok(());
|
||||
}
|
||||
|
||||
let (mut locked, _) = match worker.run(first_input).await {
|
||||
Ok(pair) => pair,
|
||||
let mut locked = match worker.run(first_input).await {
|
||||
Ok(out) => out.worker,
|
||||
Err(e) => {
|
||||
eprintln!("\n❌ Error: {}", e);
|
||||
return Ok(());
|
||||
|
|
|
|||
|
|
@ -55,4 +55,4 @@ pub use handler::ToolUseBlockStart;
|
|||
pub use message::{ContentPart, Item, Message, Role};
|
||||
pub use interceptor::Interceptor;
|
||||
pub use tool::{ToolCall, ToolResult};
|
||||
pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult};
|
||||
pub use worker::{RunOutput, ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult};
|
||||
|
|
|
|||
|
|
@ -74,6 +74,16 @@ pub enum WorkerResult {
|
|||
LimitReached,
|
||||
}
|
||||
|
||||
/// Result of [`Worker<C, Mutable>::run()`] / [`Worker<C, Mutable>::resume()`].
|
||||
///
|
||||
/// Contains the `Locked` Worker (ready for subsequent runs) and the outcome.
|
||||
pub struct RunOutput<C: LlmClient> {
|
||||
/// The Worker, now in Locked state.
|
||||
pub worker: Worker<C, Locked>,
|
||||
/// Outcome of the turn.
|
||||
pub result: WorkerResult,
|
||||
}
|
||||
|
||||
/// Internal: tool execution result
|
||||
enum ToolExecutionResult {
|
||||
Completed(Vec<ToolResult>),
|
||||
|
|
@ -100,8 +110,9 @@ enum ToolExecutionResult {
|
|||
/// .system_prompt("You are a helpful assistant.");
|
||||
/// worker.register_tool(my_tool);
|
||||
///
|
||||
/// // Mutable::run() consumes self → Locked
|
||||
/// let (mut worker, _result) = worker.run("Hello").await?;
|
||||
/// // Mutable::run() consumes self → RunOutput { worker: Locked, result }
|
||||
/// let out = worker.run("Hello").await?;
|
||||
/// let mut worker = out.worker;
|
||||
///
|
||||
/// // Locked::run() borrows &mut self
|
||||
/// worker.run("Follow-up").await?;
|
||||
|
|
@ -109,7 +120,8 @@ enum ToolExecutionResult {
|
|||
/// // To edit between turns, unlock back to Mutable
|
||||
/// let mut worker = worker.unlock();
|
||||
/// worker.history_mut().truncate(5);
|
||||
/// let (mut worker, _result) = worker.run("Continue").await?;
|
||||
/// let out = worker.run("Continue").await?;
|
||||
/// let mut worker = out.worker;
|
||||
/// ```
|
||||
pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
||||
/// LLM client
|
||||
|
|
@ -669,32 +681,15 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
"Starting worker run"
|
||||
);
|
||||
|
||||
// Resume check: Pending tool calls
|
||||
// Resume pending tool calls from a previous Pause
|
||||
if let Some(tool_calls) = self.get_pending_tool_calls() {
|
||||
info!("Resuming pending tool calls");
|
||||
match self.execute_tools(tool_calls).await {
|
||||
Ok(ToolExecutionResult::Paused) => {
|
||||
self.last_run_interrupted = true;
|
||||
return Ok(WorkerResult::Paused);
|
||||
}
|
||||
Ok(ToolExecutionResult::Completed(results)) => {
|
||||
for result in results {
|
||||
self.history.push(Item::tool_result(
|
||||
&result.tool_use_id,
|
||||
&result.content,
|
||||
));
|
||||
}
|
||||
// Continue to loop
|
||||
}
|
||||
Err(err) => {
|
||||
self.last_run_interrupted = true;
|
||||
return Err(err);
|
||||
}
|
||||
if let Some(result) = self.execute_and_commit_tools(tool_calls).await? {
|
||||
return Ok(result);
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
// Check for cancellation
|
||||
if self.try_cancelled() {
|
||||
info!("Execution cancelled");
|
||||
self.timeline.abort_current_block();
|
||||
|
|
@ -702,7 +697,6 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
return Err(WorkerError::Cancelled);
|
||||
}
|
||||
|
||||
// Notify turn start
|
||||
let current_turn = self.turn_count;
|
||||
debug!(turn = current_turn, "Turn start");
|
||||
for cb in &self.turn_start_cbs {
|
||||
|
|
@ -723,84 +717,22 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
PreRequestAction::Continue => {}
|
||||
}
|
||||
|
||||
// Build request
|
||||
// Stream LLM response
|
||||
let request = self.build_request(&tool_definitions, &request_context);
|
||||
debug!(
|
||||
item_count = request.items.len(),
|
||||
tool_count = request.tools.len(),
|
||||
has_system = request.system_prompt.is_some(),
|
||||
"Sending request to LLM"
|
||||
);
|
||||
self.stream_response(request).await?;
|
||||
|
||||
// Stream processing
|
||||
debug!("Starting stream...");
|
||||
let mut event_count = 0;
|
||||
|
||||
// Get stream (cancellable)
|
||||
let mut stream = tokio::select! {
|
||||
stream_result = self.client.stream(request) => stream_result
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?,
|
||||
cancel = self.cancel_rx.recv() => {
|
||||
if cancel.is_some() {
|
||||
info!("Cancelled before stream started");
|
||||
}
|
||||
self.timeline.abort_current_block();
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Cancelled);
|
||||
}
|
||||
};
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Receive event from stream
|
||||
event_result = stream.next() => {
|
||||
match event_result {
|
||||
Some(result) => {
|
||||
match &result {
|
||||
Ok(event) => {
|
||||
trace!(event = ?event, "Received event");
|
||||
event_count += 1;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Stream error");
|
||||
}
|
||||
}
|
||||
let event = result
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||
self.timeline.dispatch(&event);
|
||||
}
|
||||
None => break, // Stream ended
|
||||
}
|
||||
}
|
||||
// Wait for cancellation
|
||||
cancel = self.cancel_rx.recv() => {
|
||||
if cancel.is_some() {
|
||||
info!("Stream cancelled");
|
||||
}
|
||||
self.timeline.abort_current_block();
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Cancelled);
|
||||
}
|
||||
}
|
||||
}
|
||||
debug!(event_count = event_count, "Stream completed");
|
||||
|
||||
// Notify turn end
|
||||
for cb in &self.turn_end_cbs {
|
||||
cb(current_turn);
|
||||
}
|
||||
self.turn_count += 1;
|
||||
|
||||
// Get collected results
|
||||
// Collect and commit assistant items
|
||||
let text_blocks = self.text_block_collector.take_collected();
|
||||
let tool_calls = self.tool_call_collector.take_collected();
|
||||
|
||||
// Add assistant items to history
|
||||
let assistant_items = self.build_assistant_items(&text_blocks, &tool_calls);
|
||||
self.history.extend(assistant_items);
|
||||
|
||||
if tool_calls.is_empty() {
|
||||
// No tool calls → determine turn end via interceptor
|
||||
match self.interceptor.on_turn_end(&self.history).await {
|
||||
TurnEndAction::Finish => {
|
||||
self.last_run_interrupted = false;
|
||||
|
|
@ -817,27 +749,10 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
}
|
||||
|
||||
// Execute tools
|
||||
match self.execute_tools(tool_calls).await {
|
||||
Ok(ToolExecutionResult::Paused) => {
|
||||
self.last_run_interrupted = true;
|
||||
return Ok(WorkerResult::Paused);
|
||||
}
|
||||
Ok(ToolExecutionResult::Completed(results)) => {
|
||||
for result in results {
|
||||
self.history.push(Item::tool_result(
|
||||
&result.tool_use_id,
|
||||
&result.content,
|
||||
));
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
self.last_run_interrupted = true;
|
||||
return Err(err);
|
||||
}
|
||||
if let Some(result) = self.execute_and_commit_tools(tool_calls).await? {
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
// Check turn limit (after assistant items and tool results are in history)
|
||||
if let Some(max) = self.max_turns {
|
||||
if self.turn_count >= max as usize {
|
||||
info!(turn_count = self.turn_count, max_turns = max, "Turn limit reached");
|
||||
|
|
@ -848,6 +763,90 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Open a stream, dispatch all events to the timeline, handle cancellation.
|
||||
async fn stream_response(&mut self, request: Request) -> Result<(), WorkerError> {
|
||||
debug!(
|
||||
item_count = request.items.len(),
|
||||
tool_count = request.tools.len(),
|
||||
has_system = request.system_prompt.is_some(),
|
||||
"Sending request to LLM"
|
||||
);
|
||||
|
||||
let mut stream = tokio::select! {
|
||||
stream_result = self.client.stream(request) => stream_result
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?,
|
||||
cancel = self.cancel_rx.recv() => {
|
||||
if cancel.is_some() {
|
||||
info!("Cancelled before stream started");
|
||||
}
|
||||
self.timeline.abort_current_block();
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Cancelled);
|
||||
}
|
||||
};
|
||||
|
||||
let mut event_count: usize = 0;
|
||||
loop {
|
||||
tokio::select! {
|
||||
event_result = stream.next() => {
|
||||
match event_result {
|
||||
Some(result) => {
|
||||
match &result {
|
||||
Ok(event) => {
|
||||
trace!(event = ?event, "Received event");
|
||||
event_count += 1;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Stream error");
|
||||
}
|
||||
}
|
||||
let event = result
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||
self.timeline.dispatch(&event);
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
cancel = self.cancel_rx.recv() => {
|
||||
if cancel.is_some() {
|
||||
info!("Stream cancelled");
|
||||
}
|
||||
self.timeline.abort_current_block();
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Cancelled);
|
||||
}
|
||||
}
|
||||
}
|
||||
debug!(event_count = event_count, "Stream completed");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Execute tools and push results to history.
|
||||
/// Returns `Some(result)` if execution should stop (Paused),
|
||||
/// `None` if the turn loop should continue.
|
||||
async fn execute_and_commit_tools(
|
||||
&mut self,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
) -> Result<Option<WorkerResult>, WorkerError> {
|
||||
match self.execute_tools(tool_calls).await {
|
||||
Ok(ToolExecutionResult::Paused) => {
|
||||
self.last_run_interrupted = true;
|
||||
Ok(Some(WorkerResult::Paused))
|
||||
}
|
||||
Ok(ToolExecutionResult::Completed(results)) => {
|
||||
for result in results {
|
||||
self.history
|
||||
.push(Item::tool_result(&result.tool_use_id, &result.content));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
Err(err) => {
|
||||
self.last_run_interrupted = true;
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -1088,21 +1087,19 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
pub async fn run(
|
||||
self,
|
||||
user_input: impl Into<String>,
|
||||
) -> Result<(Worker<C, Locked>, WorkerResult), WorkerError> {
|
||||
) -> Result<RunOutput<C>, WorkerError> {
|
||||
let mut locked = self.lock();
|
||||
let result = locked.run(user_input).await?;
|
||||
Ok((locked, result))
|
||||
Ok(RunOutput { worker: locked, result })
|
||||
}
|
||||
|
||||
/// Resume from Paused, consuming self and transitioning to Locked.
|
||||
///
|
||||
/// Used after `unlock()` → edit → resume.
|
||||
pub async fn resume(
|
||||
self,
|
||||
) -> Result<(Worker<C, Locked>, WorkerResult), WorkerError> {
|
||||
pub async fn resume(self) -> Result<RunOutput<C>, WorkerError> {
|
||||
let mut locked = self.lock();
|
||||
let result = locked.resume().await?;
|
||||
Ok((locked, result))
|
||||
Ok(RunOutput { worker: locked, result })
|
||||
}
|
||||
|
||||
/// Lock and transition to Locked state
|
||||
|
|
|
|||
|
|
@ -211,8 +211,9 @@ async fn test_mutable_run_updates_history() -> Result<(), WorkerError> {
|
|||
let client = MockLlmClient::new(events);
|
||||
let worker = Worker::new(client);
|
||||
|
||||
// Execute (Mutable::run consumes self, returns (Locked, WorkerResult))
|
||||
let (worker, _result) = worker.run("Hi there").await?;
|
||||
// Execute (Mutable::run consumes self, returns RunOutput)
|
||||
let out = worker.run("Hi there").await?;
|
||||
let worker = out.worker;
|
||||
|
||||
// History is updated
|
||||
let history = worker.history();
|
||||
|
|
@ -352,8 +353,8 @@ async fn test_turn_count_increment() -> Result<(), WorkerError> {
|
|||
|
||||
assert_eq!(worker.turn_count(), 0);
|
||||
|
||||
// First run consumes Mutable, returns Locked
|
||||
let (mut worker, _) = worker.run("First").await?;
|
||||
// First run consumes Mutable, returns RunOutput
|
||||
let mut worker = worker.run("First").await?.worker;
|
||||
assert_eq!(worker.turn_count(), 1);
|
||||
|
||||
// Subsequent runs on Locked take &mut self
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user