feat: pause and resume

This commit is contained in:
Keisuke Hirata 2026-01-09 14:50:34 +09:00
parent 81107c6f5c
commit 33f1c218f2
2 changed files with 117 additions and 28 deletions

View File

@ -20,6 +20,8 @@ pub enum ControlFlow {
Skip,
/// 処理を中断
Abort(String),
/// 処理を一時停止(再開可能)
Pause,
}
/// ターン終了時の判定結果
@ -29,6 +31,8 @@ pub enum TurnResult {
Finish,
/// メッセージを追加してターン継続(自己修正など)
ContinueWithMessages(Vec<crate::Message>),
/// ターンを一時停止
Paused,
}
// =============================================================================

View File

@ -53,6 +53,25 @@ pub struct WorkerConfig {
_private: (),
}
// =============================================================================
// Worker Result Types
// =============================================================================
/// Workerの実行結果ステータス
#[derive(Debug)]
pub enum WorkerResult<'a> {
/// 完了(ユーザー入力待ち状態)
Finished(&'a [Message]),
/// 一時停止(再開可能)
Paused(&'a [Message]),
}
/// 内部用: ツール実行結果
enum ToolExecutionResult {
Completed(Vec<ToolResult>),
Paused,
}
// =============================================================================
// ターン制御用コールバック保持
// =============================================================================
@ -487,6 +506,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
ControlFlow::Continue => continue,
ControlFlow::Skip => return Ok(ControlFlow::Skip),
ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)),
ControlFlow::Pause => return Ok(ControlFlow::Pause),
}
}
Ok(ControlFlow::Continue)
@ -501,11 +521,39 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
TurnResult::ContinueWithMessages(msgs) => {
return Ok(TurnResult::ContinueWithMessages(msgs));
}
TurnResult::Paused => return Ok(TurnResult::Paused),
}
}
Ok(TurnResult::Finish)
}
/// 未実行のツール呼び出しがあるかチェックPauseからの復帰用
fn get_pending_tool_calls(&self) -> Option<Vec<ToolCall>> {
let last_msg = self.history.last()?;
if last_msg.role != Role::Assistant {
return None;
}
let mut calls = Vec::new();
if let MessageContent::Parts(parts) = &last_msg.content {
for part in parts {
if let ContentPart::ToolUse { id, name, input } = part {
calls.push(ToolCall {
id: id.clone(),
name: name.clone(),
input: input.clone(),
});
}
}
}
if calls.is_empty() {
None
} else {
Some(calls)
}
}
/// ツールを並列実行
///
/// 全てのツールに対してbefore_tool_callフックを実行後、
@ -513,7 +561,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
async fn execute_tools(
&self,
tool_calls: Vec<ToolCall>,
) -> Result<Vec<ToolResult>, WorkerError> {
) -> Result<ToolExecutionResult, WorkerError> {
use futures::future::join_all;
// Phase 1: before_tool_call フックを適用(スキップ/中断を判定)
@ -531,6 +579,9 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
ControlFlow::Abort(reason) => {
return Err(WorkerError::Aborted(reason));
}
ControlFlow::Pause => {
return Ok(ToolExecutionResult::Paused);
}
}
}
if !skip {
@ -573,15 +624,22 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
ControlFlow::Abort(reason) => {
return Err(WorkerError::Aborted(reason));
}
ControlFlow::Pause => {
// after_tool_callでのPauseは結果を受け入れた後、次の処理前に止まる動作とする
// ここではContinue扱いとし、on_message_send等でPauseすることを期待する
// あるいはここでのPauseをサポートする場合は戻り値を調整する必要がある
// 現状はログを出してContinue
warn!("ControlFlow::Pause in after_tool_call is treated as Continue");
}
}
}
}
Ok(results)
Ok(ToolExecutionResult::Completed(results))
}
/// 内部で使用するターン実行ロジック
async fn run_turn_loop(&mut self) -> Result<(), WorkerError> {
async fn run_turn_loop(&mut self) -> Result<WorkerResult<'_>, WorkerError> {
let tool_definitions = self.build_tool_definitions();
info!(
@ -590,6 +648,20 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
"Starting worker run"
);
// Resume check: Pending tool calls
if let Some(tool_calls) = self.get_pending_tool_calls() {
info!("Resuming pending tool calls");
match self.execute_tools(tool_calls).await? {
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)),
ToolExecutionResult::Completed(results) => {
for result in results {
self.history.push(Message::tool_result(&result.tool_use_id, &result.content));
}
// Continue to loop
}
}
}
loop {
// ターン開始を通知
let current_turn = self.turn_count;
@ -600,14 +672,20 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
// Hook: on_message_send
let control = self.run_on_message_send_hooks().await?;
if let ControlFlow::Abort(reason) = control {
match control {
ControlFlow::Abort(reason) => {
warn!(reason = %reason, "Aborted by hook");
// ターン終了を通知(異常終了)
for notifier in &self.turn_notifiers {
notifier.on_turn_end(current_turn);
}
return Err(WorkerError::Aborted(reason));
}
ControlFlow::Pause | ControlFlow::Skip => {
// Skip or Pause -> Pause the worker
return Ok(WorkerResult::Paused(&self.history));
}
ControlFlow::Continue => {}
}
// リクエスト構築
let request = self.build_request(&tool_definitions);
@ -659,24 +737,35 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
let turn_result = self.run_on_turn_end_hooks().await?;
match turn_result {
TurnResult::Finish => {
return Ok(());
return Ok(WorkerResult::Finished(&self.history));
}
TurnResult::ContinueWithMessages(additional) => {
self.history.extend(additional);
continue;
}
TurnResult::Paused => {
return Ok(WorkerResult::Paused(&self.history));
}
}
}
// ツール実行
let tool_results = self.execute_tools(tool_calls).await?;
match self.execute_tools(tool_calls).await? {
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)),
ToolExecutionResult::Completed(results) => {
for result in results {
self.history.push(Message::tool_result(&result.tool_use_id, &result.content));
}
}
}
}
}
// ツール結果を履歴に追加
for result in tool_results {
self.history
.push(Message::tool_result(&result.tool_use_id, &result.content));
}
}
/// 実行を再開Pause状態からの復帰
///
/// 新しいユーザーメッセージを履歴に追加せず、現在の状態からターン処理を再開する。
pub async fn resume(&mut self) -> Result<WorkerResult<'_>, WorkerError> {
self.run_turn_loop().await
}
}
@ -887,10 +976,9 @@ impl<C: LlmClient> Worker<C, Mutable> {
///
/// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は
/// `lock()` を呼んでからLocked状態で `run` を使用すること。
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<&[Message], WorkerError> {
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<WorkerResult<'_>, WorkerError> {
self.history.push(Message::user(user_input));
self.run_turn_loop().await?;
Ok(&self.history)
self.run_turn_loop().await
}
/// 複数メッセージでターンを実行Mutable状態
@ -899,10 +987,9 @@ impl<C: LlmClient> Worker<C, Mutable> {
pub async fn run_with_messages(
&mut self,
messages: Vec<Message>,
) -> Result<&[Message], WorkerError> {
) -> Result<WorkerResult<'_>, WorkerError> {
self.history.extend(messages);
self.run_turn_loop().await?;
Ok(&self.history)
self.run_turn_loop().await
}
}
@ -915,20 +1002,18 @@ impl<C: LlmClient> Worker<C, Locked> {
///
/// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。
/// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<&[Message], WorkerError> {
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<WorkerResult<'_>, WorkerError> {
self.history.push(Message::user(user_input));
self.run_turn_loop().await?;
Ok(&self.history)
self.run_turn_loop().await
}
/// 複数メッセージでターンを実行Locked状態
pub async fn run_with_messages(
&mut self,
messages: Vec<Message>,
) -> Result<&[Message], WorkerError> {
) -> Result<WorkerResult<'_>, WorkerError> {
self.history.extend(messages);
self.run_turn_loop().await?;
Ok(&self.history)
self.run_turn_loop().await
}
/// ロック時点のプレフィックス長を取得