From 33f1c218f265f5090128a0407b7b564eeb178249 Mon Sep 17 00:00:00 2001 From: Hare Date: Fri, 9 Jan 2026 14:50:34 +0900 Subject: [PATCH] feat: pause and resume --- llm-worker/src/hook.rs | 4 ++ llm-worker/src/worker.rs | 141 +++++++++++++++++++++++++++++++-------- 2 files changed, 117 insertions(+), 28 deletions(-) diff --git a/llm-worker/src/hook.rs b/llm-worker/src/hook.rs index 8991d33..25b0fc5 100644 --- a/llm-worker/src/hook.rs +++ b/llm-worker/src/hook.rs @@ -20,6 +20,8 @@ pub enum ControlFlow { Skip, /// 処理を中断 Abort(String), + /// 処理を一時停止(再開可能) + Pause, } /// ターン終了時の判定結果 @@ -29,6 +31,8 @@ pub enum TurnResult { Finish, /// メッセージを追加してターン継続(自己修正など) ContinueWithMessages(Vec), + /// ターンを一時停止 + Paused, } // ============================================================================= diff --git a/llm-worker/src/worker.rs b/llm-worker/src/worker.rs index 3b87429..5425f71 100644 --- a/llm-worker/src/worker.rs +++ b/llm-worker/src/worker.rs @@ -53,6 +53,25 @@ pub struct WorkerConfig { _private: (), } +// ============================================================================= +// Worker Result Types +// ============================================================================= + +/// Workerの実行結果(ステータス) +#[derive(Debug)] +pub enum WorkerResult<'a> { + /// 完了(ユーザー入力待ち状態) + Finished(&'a [Message]), + /// 一時停止(再開可能) + Paused(&'a [Message]), +} + +/// 内部用: ツール実行結果 +enum ToolExecutionResult { + Completed(Vec), + Paused, +} + // ============================================================================= // ターン制御用コールバック保持 // ============================================================================= @@ -487,6 +506,7 @@ impl Worker { ControlFlow::Continue => continue, ControlFlow::Skip => return Ok(ControlFlow::Skip), ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)), + ControlFlow::Pause => return Ok(ControlFlow::Pause), } } Ok(ControlFlow::Continue) @@ -501,11 +521,39 @@ impl Worker { TurnResult::ContinueWithMessages(msgs) => { return Ok(TurnResult::ContinueWithMessages(msgs)); } + TurnResult::Paused => return Ok(TurnResult::Paused), } } Ok(TurnResult::Finish) } + /// 未実行のツール呼び出しがあるかチェック(Pauseからの復帰用) + fn get_pending_tool_calls(&self) -> Option> { + let last_msg = self.history.last()?; + if last_msg.role != Role::Assistant { + return None; + } + + let mut calls = Vec::new(); + if let MessageContent::Parts(parts) = &last_msg.content { + for part in parts { + if let ContentPart::ToolUse { id, name, input } = part { + calls.push(ToolCall { + id: id.clone(), + name: name.clone(), + input: input.clone(), + }); + } + } + } + + if calls.is_empty() { + None + } else { + Some(calls) + } + } + /// ツールを並列実行 /// /// 全てのツールに対してbefore_tool_callフックを実行後、 @@ -513,7 +561,7 @@ impl Worker { async fn execute_tools( &self, tool_calls: Vec, - ) -> Result, WorkerError> { + ) -> Result { use futures::future::join_all; // Phase 1: before_tool_call フックを適用(スキップ/中断を判定) @@ -531,6 +579,9 @@ impl Worker { ControlFlow::Abort(reason) => { return Err(WorkerError::Aborted(reason)); } + ControlFlow::Pause => { + return Ok(ToolExecutionResult::Paused); + } } } if !skip { @@ -573,15 +624,22 @@ impl Worker { ControlFlow::Abort(reason) => { return Err(WorkerError::Aborted(reason)); } + ControlFlow::Pause => { + // after_tool_callでのPauseは結果を受け入れた後、次の処理前に止まる動作とする + // ここではContinue扱いとし、on_message_send等でPauseすることを期待する + // あるいはここでのPauseをサポートする場合は戻り値を調整する必要がある + // 現状はログを出してContinue + warn!("ControlFlow::Pause in after_tool_call is treated as Continue"); + } } } } - Ok(results) + Ok(ToolExecutionResult::Completed(results)) } /// 内部で使用するターン実行ロジック - async fn run_turn_loop(&mut self) -> Result<(), WorkerError> { + async fn run_turn_loop(&mut self) -> Result, WorkerError> { let tool_definitions = self.build_tool_definitions(); info!( @@ -590,6 +648,20 @@ impl Worker { "Starting worker run" ); + // Resume check: Pending tool calls + if let Some(tool_calls) = self.get_pending_tool_calls() { + info!("Resuming pending tool calls"); + match self.execute_tools(tool_calls).await? { + ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)), + ToolExecutionResult::Completed(results) => { + for result in results { + self.history.push(Message::tool_result(&result.tool_use_id, &result.content)); + } + // Continue to loop + } + } + } + loop { // ターン開始を通知 let current_turn = self.turn_count; @@ -600,13 +672,19 @@ impl Worker { // Hook: on_message_send let control = self.run_on_message_send_hooks().await?; - if let ControlFlow::Abort(reason) = control { - warn!(reason = %reason, "Aborted by hook"); - // ターン終了を通知(異常終了) - for notifier in &self.turn_notifiers { - notifier.on_turn_end(current_turn); + match control { + ControlFlow::Abort(reason) => { + warn!(reason = %reason, "Aborted by hook"); + for notifier in &self.turn_notifiers { + notifier.on_turn_end(current_turn); + } + return Err(WorkerError::Aborted(reason)); } - return Err(WorkerError::Aborted(reason)); + ControlFlow::Pause | ControlFlow::Skip => { + // Skip or Pause -> Pause the worker + return Ok(WorkerResult::Paused(&self.history)); + } + ControlFlow::Continue => {} } // リクエスト構築 @@ -659,25 +737,36 @@ impl Worker { let turn_result = self.run_on_turn_end_hooks().await?; match turn_result { TurnResult::Finish => { - return Ok(()); + return Ok(WorkerResult::Finished(&self.history)); } TurnResult::ContinueWithMessages(additional) => { self.history.extend(additional); continue; } + TurnResult::Paused => { + return Ok(WorkerResult::Paused(&self.history)); + } } } // ツール実行 - let tool_results = self.execute_tools(tool_calls).await?; - - // ツール結果を履歴に追加 - for result in tool_results { - self.history - .push(Message::tool_result(&result.tool_use_id, &result.content)); + match self.execute_tools(tool_calls).await? { + ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)), + ToolExecutionResult::Completed(results) => { + for result in results { + self.history.push(Message::tool_result(&result.tool_use_id, &result.content)); + } + } } } } + + /// 実行を再開(Pause状態からの復帰) + /// + /// 新しいユーザーメッセージを履歴に追加せず、現在の状態からターン処理を再開する。 + pub async fn resume(&mut self) -> Result, WorkerError> { + self.run_turn_loop().await + } } // ============================================================================= @@ -887,10 +976,9 @@ impl Worker { /// /// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は /// `lock()` を呼んでからLocked状態で `run` を使用すること。 - pub async fn run(&mut self, user_input: impl Into) -> Result<&[Message], WorkerError> { + pub async fn run(&mut self, user_input: impl Into) -> Result, WorkerError> { self.history.push(Message::user(user_input)); - self.run_turn_loop().await?; - Ok(&self.history) + self.run_turn_loop().await } /// 複数メッセージでターンを実行(Mutable状態) @@ -899,10 +987,9 @@ impl Worker { pub async fn run_with_messages( &mut self, messages: Vec, - ) -> Result<&[Message], WorkerError> { + ) -> Result, WorkerError> { self.history.extend(messages); - self.run_turn_loop().await?; - Ok(&self.history) + self.run_turn_loop().await } } @@ -915,20 +1002,18 @@ impl Worker { /// /// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。 /// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。 - pub async fn run(&mut self, user_input: impl Into) -> Result<&[Message], WorkerError> { + pub async fn run(&mut self, user_input: impl Into) -> Result, WorkerError> { self.history.push(Message::user(user_input)); - self.run_turn_loop().await?; - Ok(&self.history) + self.run_turn_loop().await } /// 複数メッセージでターンを実行(Locked状態) pub async fn run_with_messages( &mut self, messages: Vec, - ) -> Result<&[Message], WorkerError> { + ) -> Result, WorkerError> { self.history.extend(messages); - self.run_turn_loop().await?; - Ok(&self.history) + self.run_turn_loop().await } /// ロック時点のプレフィックス長を取得