feat: pause and resume

This commit is contained in:
Keisuke Hirata 2026-01-09 14:50:34 +09:00
parent 81107c6f5c
commit 33f1c218f2
2 changed files with 117 additions and 28 deletions

View File

@ -20,6 +20,8 @@ pub enum ControlFlow {
Skip, Skip,
/// 処理を中断 /// 処理を中断
Abort(String), Abort(String),
/// 処理を一時停止(再開可能)
Pause,
} }
/// ターン終了時の判定結果 /// ターン終了時の判定結果
@ -29,6 +31,8 @@ pub enum TurnResult {
Finish, Finish,
/// メッセージを追加してターン継続(自己修正など) /// メッセージを追加してターン継続(自己修正など)
ContinueWithMessages(Vec<crate::Message>), ContinueWithMessages(Vec<crate::Message>),
/// ターンを一時停止
Paused,
} }
// ============================================================================= // =============================================================================

View File

@ -53,6 +53,25 @@ pub struct WorkerConfig {
_private: (), _private: (),
} }
// =============================================================================
// Worker Result Types
// =============================================================================
/// Workerの実行結果ステータス
#[derive(Debug)]
pub enum WorkerResult<'a> {
/// 完了(ユーザー入力待ち状態)
Finished(&'a [Message]),
/// 一時停止(再開可能)
Paused(&'a [Message]),
}
/// 内部用: ツール実行結果
enum ToolExecutionResult {
Completed(Vec<ToolResult>),
Paused,
}
// ============================================================================= // =============================================================================
// ターン制御用コールバック保持 // ターン制御用コールバック保持
// ============================================================================= // =============================================================================
@ -487,6 +506,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
ControlFlow::Continue => continue, ControlFlow::Continue => continue,
ControlFlow::Skip => return Ok(ControlFlow::Skip), ControlFlow::Skip => return Ok(ControlFlow::Skip),
ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)), ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)),
ControlFlow::Pause => return Ok(ControlFlow::Pause),
} }
} }
Ok(ControlFlow::Continue) Ok(ControlFlow::Continue)
@ -501,11 +521,39 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
TurnResult::ContinueWithMessages(msgs) => { TurnResult::ContinueWithMessages(msgs) => {
return Ok(TurnResult::ContinueWithMessages(msgs)); return Ok(TurnResult::ContinueWithMessages(msgs));
} }
TurnResult::Paused => return Ok(TurnResult::Paused),
} }
} }
Ok(TurnResult::Finish) Ok(TurnResult::Finish)
} }
/// 未実行のツール呼び出しがあるかチェックPauseからの復帰用
fn get_pending_tool_calls(&self) -> Option<Vec<ToolCall>> {
let last_msg = self.history.last()?;
if last_msg.role != Role::Assistant {
return None;
}
let mut calls = Vec::new();
if let MessageContent::Parts(parts) = &last_msg.content {
for part in parts {
if let ContentPart::ToolUse { id, name, input } = part {
calls.push(ToolCall {
id: id.clone(),
name: name.clone(),
input: input.clone(),
});
}
}
}
if calls.is_empty() {
None
} else {
Some(calls)
}
}
/// ツールを並列実行 /// ツールを並列実行
/// ///
/// 全てのツールに対してbefore_tool_callフックを実行後、 /// 全てのツールに対してbefore_tool_callフックを実行後、
@ -513,7 +561,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
async fn execute_tools( async fn execute_tools(
&self, &self,
tool_calls: Vec<ToolCall>, tool_calls: Vec<ToolCall>,
) -> Result<Vec<ToolResult>, WorkerError> { ) -> Result<ToolExecutionResult, WorkerError> {
use futures::future::join_all; use futures::future::join_all;
// Phase 1: before_tool_call フックを適用(スキップ/中断を判定) // Phase 1: before_tool_call フックを適用(スキップ/中断を判定)
@ -531,6 +579,9 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
ControlFlow::Abort(reason) => { ControlFlow::Abort(reason) => {
return Err(WorkerError::Aborted(reason)); return Err(WorkerError::Aborted(reason));
} }
ControlFlow::Pause => {
return Ok(ToolExecutionResult::Paused);
}
} }
} }
if !skip { if !skip {
@ -573,15 +624,22 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
ControlFlow::Abort(reason) => { ControlFlow::Abort(reason) => {
return Err(WorkerError::Aborted(reason)); return Err(WorkerError::Aborted(reason));
} }
ControlFlow::Pause => {
// after_tool_callでのPauseは結果を受け入れた後、次の処理前に止まる動作とする
// ここではContinue扱いとし、on_message_send等でPauseすることを期待する
// あるいはここでのPauseをサポートする場合は戻り値を調整する必要がある
// 現状はログを出してContinue
warn!("ControlFlow::Pause in after_tool_call is treated as Continue");
}
} }
} }
} }
Ok(results) Ok(ToolExecutionResult::Completed(results))
} }
/// 内部で使用するターン実行ロジック /// 内部で使用するターン実行ロジック
async fn run_turn_loop(&mut self) -> Result<(), WorkerError> { async fn run_turn_loop(&mut self) -> Result<WorkerResult<'_>, WorkerError> {
let tool_definitions = self.build_tool_definitions(); let tool_definitions = self.build_tool_definitions();
info!( info!(
@ -590,6 +648,20 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
"Starting worker run" "Starting worker run"
); );
// Resume check: Pending tool calls
if let Some(tool_calls) = self.get_pending_tool_calls() {
info!("Resuming pending tool calls");
match self.execute_tools(tool_calls).await? {
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)),
ToolExecutionResult::Completed(results) => {
for result in results {
self.history.push(Message::tool_result(&result.tool_use_id, &result.content));
}
// Continue to loop
}
}
}
loop { loop {
// ターン開始を通知 // ターン開始を通知
let current_turn = self.turn_count; let current_turn = self.turn_count;
@ -600,13 +672,19 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
// Hook: on_message_send // Hook: on_message_send
let control = self.run_on_message_send_hooks().await?; let control = self.run_on_message_send_hooks().await?;
if let ControlFlow::Abort(reason) = control { match control {
warn!(reason = %reason, "Aborted by hook"); ControlFlow::Abort(reason) => {
// ターン終了を通知(異常終了) warn!(reason = %reason, "Aborted by hook");
for notifier in &self.turn_notifiers { for notifier in &self.turn_notifiers {
notifier.on_turn_end(current_turn); notifier.on_turn_end(current_turn);
}
return Err(WorkerError::Aborted(reason));
} }
return Err(WorkerError::Aborted(reason)); ControlFlow::Pause | ControlFlow::Skip => {
// Skip or Pause -> Pause the worker
return Ok(WorkerResult::Paused(&self.history));
}
ControlFlow::Continue => {}
} }
// リクエスト構築 // リクエスト構築
@ -659,25 +737,36 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
let turn_result = self.run_on_turn_end_hooks().await?; let turn_result = self.run_on_turn_end_hooks().await?;
match turn_result { match turn_result {
TurnResult::Finish => { TurnResult::Finish => {
return Ok(()); return Ok(WorkerResult::Finished(&self.history));
} }
TurnResult::ContinueWithMessages(additional) => { TurnResult::ContinueWithMessages(additional) => {
self.history.extend(additional); self.history.extend(additional);
continue; continue;
} }
TurnResult::Paused => {
return Ok(WorkerResult::Paused(&self.history));
}
} }
} }
// ツール実行 // ツール実行
let tool_results = self.execute_tools(tool_calls).await?; match self.execute_tools(tool_calls).await? {
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)),
// ツール結果を履歴に追加 ToolExecutionResult::Completed(results) => {
for result in tool_results { for result in results {
self.history self.history.push(Message::tool_result(&result.tool_use_id, &result.content));
.push(Message::tool_result(&result.tool_use_id, &result.content)); }
}
} }
} }
} }
/// 実行を再開Pause状態からの復帰
///
/// 新しいユーザーメッセージを履歴に追加せず、現在の状態からターン処理を再開する。
pub async fn resume(&mut self) -> Result<WorkerResult<'_>, WorkerError> {
self.run_turn_loop().await
}
} }
// ============================================================================= // =============================================================================
@ -887,10 +976,9 @@ impl<C: LlmClient> Worker<C, Mutable> {
/// ///
/// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は /// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は
/// `lock()` を呼んでからLocked状態で `run` を使用すること。 /// `lock()` を呼んでからLocked状態で `run` を使用すること。
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<&[Message], WorkerError> { pub async fn run(&mut self, user_input: impl Into<String>) -> Result<WorkerResult<'_>, WorkerError> {
self.history.push(Message::user(user_input)); self.history.push(Message::user(user_input));
self.run_turn_loop().await?; self.run_turn_loop().await
Ok(&self.history)
} }
/// 複数メッセージでターンを実行Mutable状態 /// 複数メッセージでターンを実行Mutable状態
@ -899,10 +987,9 @@ impl<C: LlmClient> Worker<C, Mutable> {
pub async fn run_with_messages( pub async fn run_with_messages(
&mut self, &mut self,
messages: Vec<Message>, messages: Vec<Message>,
) -> Result<&[Message], WorkerError> { ) -> Result<WorkerResult<'_>, WorkerError> {
self.history.extend(messages); self.history.extend(messages);
self.run_turn_loop().await?; self.run_turn_loop().await
Ok(&self.history)
} }
} }
@ -915,20 +1002,18 @@ impl<C: LlmClient> Worker<C, Locked> {
/// ///
/// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。 /// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。
/// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。 /// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<&[Message], WorkerError> { pub async fn run(&mut self, user_input: impl Into<String>) -> Result<WorkerResult<'_>, WorkerError> {
self.history.push(Message::user(user_input)); self.history.push(Message::user(user_input));
self.run_turn_loop().await?; self.run_turn_loop().await
Ok(&self.history)
} }
/// 複数メッセージでターンを実行Locked状態 /// 複数メッセージでターンを実行Locked状態
pub async fn run_with_messages( pub async fn run_with_messages(
&mut self, &mut self,
messages: Vec<Message>, messages: Vec<Message>,
) -> Result<&[Message], WorkerError> { ) -> Result<WorkerResult<'_>, WorkerError> {
self.history.extend(messages); self.history.extend(messages);
self.run_turn_loop().await?; self.run_turn_loop().await
Ok(&self.history)
} }
/// ロック時点のプレフィックス長を取得 /// ロック時点のプレフィックス長を取得