Workerのリファクタリング
This commit is contained in:
parent
61a977779e
commit
982e0d2dbb
|
|
@ -40,17 +40,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
|
||||||
println!("📡 Sending request to LLM...");
|
println!("📡 Sending request to LLM...");
|
||||||
|
|
||||||
// Mutable::run consumes self → (Locked, WorkerResult)
|
|
||||||
match worker.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await {
|
match worker.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await {
|
||||||
Ok((_locked, WorkerResult::Finished)) => {
|
Ok(out) => match out.result {
|
||||||
println!("✅ Task completed normally");
|
WorkerResult::Finished => println!("✅ Task completed normally"),
|
||||||
}
|
WorkerResult::Paused => println!("⏸️ Task paused"),
|
||||||
Ok((_locked, WorkerResult::Paused)) => {
|
WorkerResult::LimitReached => println!("🔒 Turn limit reached"),
|
||||||
println!("⏸️ Task paused");
|
},
|
||||||
}
|
|
||||||
Ok((_locked, WorkerResult::LimitReached)) => {
|
|
||||||
println!("🔒 Turn limit reached");
|
|
||||||
}
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
println!("❌ Task error: {}", e);
|
println!("❌ Task error: {}", e);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -476,8 +476,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let (mut locked, _) = match worker.run(first_input).await {
|
let mut locked = match worker.run(first_input).await {
|
||||||
Ok(pair) => pair,
|
Ok(out) => out.worker,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("\n❌ Error: {}", e);
|
eprintln!("\n❌ Error: {}", e);
|
||||||
return Ok(());
|
return Ok(());
|
||||||
|
|
|
||||||
|
|
@ -55,4 +55,4 @@ pub use handler::ToolUseBlockStart;
|
||||||
pub use message::{ContentPart, Item, Message, Role};
|
pub use message::{ContentPart, Item, Message, Role};
|
||||||
pub use interceptor::Interceptor;
|
pub use interceptor::Interceptor;
|
||||||
pub use tool::{ToolCall, ToolResult};
|
pub use tool::{ToolCall, ToolResult};
|
||||||
pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult};
|
pub use worker::{RunOutput, ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult};
|
||||||
|
|
|
||||||
|
|
@ -74,6 +74,16 @@ pub enum WorkerResult {
|
||||||
LimitReached,
|
LimitReached,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Result of [`Worker<C, Mutable>::run()`] / [`Worker<C, Mutable>::resume()`].
|
||||||
|
///
|
||||||
|
/// Contains the `Locked` Worker (ready for subsequent runs) and the outcome.
|
||||||
|
pub struct RunOutput<C: LlmClient> {
|
||||||
|
/// The Worker, now in Locked state.
|
||||||
|
pub worker: Worker<C, Locked>,
|
||||||
|
/// Outcome of the turn.
|
||||||
|
pub result: WorkerResult,
|
||||||
|
}
|
||||||
|
|
||||||
/// Internal: tool execution result
|
/// Internal: tool execution result
|
||||||
enum ToolExecutionResult {
|
enum ToolExecutionResult {
|
||||||
Completed(Vec<ToolResult>),
|
Completed(Vec<ToolResult>),
|
||||||
|
|
@ -100,8 +110,9 @@ enum ToolExecutionResult {
|
||||||
/// .system_prompt("You are a helpful assistant.");
|
/// .system_prompt("You are a helpful assistant.");
|
||||||
/// worker.register_tool(my_tool);
|
/// worker.register_tool(my_tool);
|
||||||
///
|
///
|
||||||
/// // Mutable::run() consumes self → Locked
|
/// // Mutable::run() consumes self → RunOutput { worker: Locked, result }
|
||||||
/// let (mut worker, _result) = worker.run("Hello").await?;
|
/// let out = worker.run("Hello").await?;
|
||||||
|
/// let mut worker = out.worker;
|
||||||
///
|
///
|
||||||
/// // Locked::run() borrows &mut self
|
/// // Locked::run() borrows &mut self
|
||||||
/// worker.run("Follow-up").await?;
|
/// worker.run("Follow-up").await?;
|
||||||
|
|
@ -109,7 +120,8 @@ enum ToolExecutionResult {
|
||||||
/// // To edit between turns, unlock back to Mutable
|
/// // To edit between turns, unlock back to Mutable
|
||||||
/// let mut worker = worker.unlock();
|
/// let mut worker = worker.unlock();
|
||||||
/// worker.history_mut().truncate(5);
|
/// worker.history_mut().truncate(5);
|
||||||
/// let (mut worker, _result) = worker.run("Continue").await?;
|
/// let out = worker.run("Continue").await?;
|
||||||
|
/// let mut worker = out.worker;
|
||||||
/// ```
|
/// ```
|
||||||
pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
||||||
/// LLM client
|
/// LLM client
|
||||||
|
|
@ -669,32 +681,15 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
"Starting worker run"
|
"Starting worker run"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Resume check: Pending tool calls
|
// Resume pending tool calls from a previous Pause
|
||||||
if let Some(tool_calls) = self.get_pending_tool_calls() {
|
if let Some(tool_calls) = self.get_pending_tool_calls() {
|
||||||
info!("Resuming pending tool calls");
|
info!("Resuming pending tool calls");
|
||||||
match self.execute_tools(tool_calls).await {
|
if let Some(result) = self.execute_and_commit_tools(tool_calls).await? {
|
||||||
Ok(ToolExecutionResult::Paused) => {
|
return Ok(result);
|
||||||
self.last_run_interrupted = true;
|
|
||||||
return Ok(WorkerResult::Paused);
|
|
||||||
}
|
|
||||||
Ok(ToolExecutionResult::Completed(results)) => {
|
|
||||||
for result in results {
|
|
||||||
self.history.push(Item::tool_result(
|
|
||||||
&result.tool_use_id,
|
|
||||||
&result.content,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
// Continue to loop
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
self.last_run_interrupted = true;
|
|
||||||
return Err(err);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
// Check for cancellation
|
|
||||||
if self.try_cancelled() {
|
if self.try_cancelled() {
|
||||||
info!("Execution cancelled");
|
info!("Execution cancelled");
|
||||||
self.timeline.abort_current_block();
|
self.timeline.abort_current_block();
|
||||||
|
|
@ -702,7 +697,6 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
return Err(WorkerError::Cancelled);
|
return Err(WorkerError::Cancelled);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Notify turn start
|
|
||||||
let current_turn = self.turn_count;
|
let current_turn = self.turn_count;
|
||||||
debug!(turn = current_turn, "Turn start");
|
debug!(turn = current_turn, "Turn start");
|
||||||
for cb in &self.turn_start_cbs {
|
for cb in &self.turn_start_cbs {
|
||||||
|
|
@ -723,84 +717,22 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
PreRequestAction::Continue => {}
|
PreRequestAction::Continue => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build request
|
// Stream LLM response
|
||||||
let request = self.build_request(&tool_definitions, &request_context);
|
let request = self.build_request(&tool_definitions, &request_context);
|
||||||
debug!(
|
self.stream_response(request).await?;
|
||||||
item_count = request.items.len(),
|
|
||||||
tool_count = request.tools.len(),
|
|
||||||
has_system = request.system_prompt.is_some(),
|
|
||||||
"Sending request to LLM"
|
|
||||||
);
|
|
||||||
|
|
||||||
// Stream processing
|
|
||||||
debug!("Starting stream...");
|
|
||||||
let mut event_count = 0;
|
|
||||||
|
|
||||||
// Get stream (cancellable)
|
|
||||||
let mut stream = tokio::select! {
|
|
||||||
stream_result = self.client.stream(request) => stream_result
|
|
||||||
.inspect_err(|_| self.last_run_interrupted = true)?,
|
|
||||||
cancel = self.cancel_rx.recv() => {
|
|
||||||
if cancel.is_some() {
|
|
||||||
info!("Cancelled before stream started");
|
|
||||||
}
|
|
||||||
self.timeline.abort_current_block();
|
|
||||||
self.last_run_interrupted = true;
|
|
||||||
return Err(WorkerError::Cancelled);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
loop {
|
|
||||||
tokio::select! {
|
|
||||||
// Receive event from stream
|
|
||||||
event_result = stream.next() => {
|
|
||||||
match event_result {
|
|
||||||
Some(result) => {
|
|
||||||
match &result {
|
|
||||||
Ok(event) => {
|
|
||||||
trace!(event = ?event, "Received event");
|
|
||||||
event_count += 1;
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
warn!(error = %e, "Stream error");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let event = result
|
|
||||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
|
||||||
self.timeline.dispatch(&event);
|
|
||||||
}
|
|
||||||
None => break, // Stream ended
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Wait for cancellation
|
|
||||||
cancel = self.cancel_rx.recv() => {
|
|
||||||
if cancel.is_some() {
|
|
||||||
info!("Stream cancelled");
|
|
||||||
}
|
|
||||||
self.timeline.abort_current_block();
|
|
||||||
self.last_run_interrupted = true;
|
|
||||||
return Err(WorkerError::Cancelled);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
debug!(event_count = event_count, "Stream completed");
|
|
||||||
|
|
||||||
// Notify turn end
|
|
||||||
for cb in &self.turn_end_cbs {
|
for cb in &self.turn_end_cbs {
|
||||||
cb(current_turn);
|
cb(current_turn);
|
||||||
}
|
}
|
||||||
self.turn_count += 1;
|
self.turn_count += 1;
|
||||||
|
|
||||||
// Get collected results
|
// Collect and commit assistant items
|
||||||
let text_blocks = self.text_block_collector.take_collected();
|
let text_blocks = self.text_block_collector.take_collected();
|
||||||
let tool_calls = self.tool_call_collector.take_collected();
|
let tool_calls = self.tool_call_collector.take_collected();
|
||||||
|
|
||||||
// Add assistant items to history
|
|
||||||
let assistant_items = self.build_assistant_items(&text_blocks, &tool_calls);
|
let assistant_items = self.build_assistant_items(&text_blocks, &tool_calls);
|
||||||
self.history.extend(assistant_items);
|
self.history.extend(assistant_items);
|
||||||
|
|
||||||
if tool_calls.is_empty() {
|
if tool_calls.is_empty() {
|
||||||
// No tool calls → determine turn end via interceptor
|
|
||||||
match self.interceptor.on_turn_end(&self.history).await {
|
match self.interceptor.on_turn_end(&self.history).await {
|
||||||
TurnEndAction::Finish => {
|
TurnEndAction::Finish => {
|
||||||
self.last_run_interrupted = false;
|
self.last_run_interrupted = false;
|
||||||
|
|
@ -817,27 +749,10 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute tools
|
if let Some(result) = self.execute_and_commit_tools(tool_calls).await? {
|
||||||
match self.execute_tools(tool_calls).await {
|
return Ok(result);
|
||||||
Ok(ToolExecutionResult::Paused) => {
|
|
||||||
self.last_run_interrupted = true;
|
|
||||||
return Ok(WorkerResult::Paused);
|
|
||||||
}
|
|
||||||
Ok(ToolExecutionResult::Completed(results)) => {
|
|
||||||
for result in results {
|
|
||||||
self.history.push(Item::tool_result(
|
|
||||||
&result.tool_use_id,
|
|
||||||
&result.content,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
self.last_run_interrupted = true;
|
|
||||||
return Err(err);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check turn limit (after assistant items and tool results are in history)
|
|
||||||
if let Some(max) = self.max_turns {
|
if let Some(max) = self.max_turns {
|
||||||
if self.turn_count >= max as usize {
|
if self.turn_count >= max as usize {
|
||||||
info!(turn_count = self.turn_count, max_turns = max, "Turn limit reached");
|
info!(turn_count = self.turn_count, max_turns = max, "Turn limit reached");
|
||||||
|
|
@ -848,6 +763,90 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Open a stream, dispatch all events to the timeline, handle cancellation.
|
||||||
|
async fn stream_response(&mut self, request: Request) -> Result<(), WorkerError> {
|
||||||
|
debug!(
|
||||||
|
item_count = request.items.len(),
|
||||||
|
tool_count = request.tools.len(),
|
||||||
|
has_system = request.system_prompt.is_some(),
|
||||||
|
"Sending request to LLM"
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut stream = tokio::select! {
|
||||||
|
stream_result = self.client.stream(request) => stream_result
|
||||||
|
.inspect_err(|_| self.last_run_interrupted = true)?,
|
||||||
|
cancel = self.cancel_rx.recv() => {
|
||||||
|
if cancel.is_some() {
|
||||||
|
info!("Cancelled before stream started");
|
||||||
|
}
|
||||||
|
self.timeline.abort_current_block();
|
||||||
|
self.last_run_interrupted = true;
|
||||||
|
return Err(WorkerError::Cancelled);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut event_count: usize = 0;
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
event_result = stream.next() => {
|
||||||
|
match event_result {
|
||||||
|
Some(result) => {
|
||||||
|
match &result {
|
||||||
|
Ok(event) => {
|
||||||
|
trace!(event = ?event, "Received event");
|
||||||
|
event_count += 1;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, "Stream error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let event = result
|
||||||
|
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||||
|
self.timeline.dispatch(&event);
|
||||||
|
}
|
||||||
|
None => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cancel = self.cancel_rx.recv() => {
|
||||||
|
if cancel.is_some() {
|
||||||
|
info!("Stream cancelled");
|
||||||
|
}
|
||||||
|
self.timeline.abort_current_block();
|
||||||
|
self.last_run_interrupted = true;
|
||||||
|
return Err(WorkerError::Cancelled);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
debug!(event_count = event_count, "Stream completed");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute tools and push results to history.
|
||||||
|
/// Returns `Some(result)` if execution should stop (Paused),
|
||||||
|
/// `None` if the turn loop should continue.
|
||||||
|
async fn execute_and_commit_tools(
|
||||||
|
&mut self,
|
||||||
|
tool_calls: Vec<ToolCall>,
|
||||||
|
) -> Result<Option<WorkerResult>, WorkerError> {
|
||||||
|
match self.execute_tools(tool_calls).await {
|
||||||
|
Ok(ToolExecutionResult::Paused) => {
|
||||||
|
self.last_run_interrupted = true;
|
||||||
|
Ok(Some(WorkerResult::Paused))
|
||||||
|
}
|
||||||
|
Ok(ToolExecutionResult::Completed(results)) => {
|
||||||
|
for result in results {
|
||||||
|
self.history
|
||||||
|
.push(Item::tool_result(&result.tool_use_id, &result.content));
|
||||||
|
}
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
self.last_run_interrupted = true;
|
||||||
|
Err(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1088,21 +1087,19 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
pub async fn run(
|
pub async fn run(
|
||||||
self,
|
self,
|
||||||
user_input: impl Into<String>,
|
user_input: impl Into<String>,
|
||||||
) -> Result<(Worker<C, Locked>, WorkerResult), WorkerError> {
|
) -> Result<RunOutput<C>, WorkerError> {
|
||||||
let mut locked = self.lock();
|
let mut locked = self.lock();
|
||||||
let result = locked.run(user_input).await?;
|
let result = locked.run(user_input).await?;
|
||||||
Ok((locked, result))
|
Ok(RunOutput { worker: locked, result })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Resume from Paused, consuming self and transitioning to Locked.
|
/// Resume from Paused, consuming self and transitioning to Locked.
|
||||||
///
|
///
|
||||||
/// Used after `unlock()` → edit → resume.
|
/// Used after `unlock()` → edit → resume.
|
||||||
pub async fn resume(
|
pub async fn resume(self) -> Result<RunOutput<C>, WorkerError> {
|
||||||
self,
|
|
||||||
) -> Result<(Worker<C, Locked>, WorkerResult), WorkerError> {
|
|
||||||
let mut locked = self.lock();
|
let mut locked = self.lock();
|
||||||
let result = locked.resume().await?;
|
let result = locked.resume().await?;
|
||||||
Ok((locked, result))
|
Ok(RunOutput { worker: locked, result })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lock and transition to Locked state
|
/// Lock and transition to Locked state
|
||||||
|
|
|
||||||
|
|
@ -211,8 +211,9 @@ async fn test_mutable_run_updates_history() -> Result<(), WorkerError> {
|
||||||
let client = MockLlmClient::new(events);
|
let client = MockLlmClient::new(events);
|
||||||
let worker = Worker::new(client);
|
let worker = Worker::new(client);
|
||||||
|
|
||||||
// Execute (Mutable::run consumes self, returns (Locked, WorkerResult))
|
// Execute (Mutable::run consumes self, returns RunOutput)
|
||||||
let (worker, _result) = worker.run("Hi there").await?;
|
let out = worker.run("Hi there").await?;
|
||||||
|
let worker = out.worker;
|
||||||
|
|
||||||
// History is updated
|
// History is updated
|
||||||
let history = worker.history();
|
let history = worker.history();
|
||||||
|
|
@ -352,8 +353,8 @@ async fn test_turn_count_increment() -> Result<(), WorkerError> {
|
||||||
|
|
||||||
assert_eq!(worker.turn_count(), 0);
|
assert_eq!(worker.turn_count(), 0);
|
||||||
|
|
||||||
// First run consumes Mutable, returns Locked
|
// First run consumes Mutable, returns RunOutput
|
||||||
let (mut worker, _) = worker.run("First").await?;
|
let mut worker = worker.run("First").await?.worker;
|
||||||
assert_eq!(worker.turn_count(), 1);
|
assert_eq!(worker.turn_count(), 1);
|
||||||
|
|
||||||
// Subsequent runs on Locked take &mut self
|
// Subsequent runs on Locked take &mut self
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user