feat: pause and resume
This commit is contained in:
parent
81107c6f5c
commit
33f1c218f2
|
|
@ -20,6 +20,8 @@ pub enum ControlFlow {
|
|||
Skip,
|
||||
/// 処理を中断
|
||||
Abort(String),
|
||||
/// 処理を一時停止(再開可能)
|
||||
Pause,
|
||||
}
|
||||
|
||||
/// ターン終了時の判定結果
|
||||
|
|
@ -29,6 +31,8 @@ pub enum TurnResult {
|
|||
Finish,
|
||||
/// メッセージを追加してターン継続(自己修正など)
|
||||
ContinueWithMessages(Vec<crate::Message>),
|
||||
/// ターンを一時停止
|
||||
Paused,
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
|
|
|||
|
|
@ -53,6 +53,25 @@ pub struct WorkerConfig {
|
|||
_private: (),
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Worker Result Types
|
||||
// =============================================================================
|
||||
|
||||
/// Workerの実行結果(ステータス)
|
||||
#[derive(Debug)]
|
||||
pub enum WorkerResult<'a> {
|
||||
/// 完了(ユーザー入力待ち状態)
|
||||
Finished(&'a [Message]),
|
||||
/// 一時停止(再開可能)
|
||||
Paused(&'a [Message]),
|
||||
}
|
||||
|
||||
/// 内部用: ツール実行結果
|
||||
enum ToolExecutionResult {
|
||||
Completed(Vec<ToolResult>),
|
||||
Paused,
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// ターン制御用コールバック保持
|
||||
// =============================================================================
|
||||
|
|
@ -487,6 +506,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
ControlFlow::Continue => continue,
|
||||
ControlFlow::Skip => return Ok(ControlFlow::Skip),
|
||||
ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)),
|
||||
ControlFlow::Pause => return Ok(ControlFlow::Pause),
|
||||
}
|
||||
}
|
||||
Ok(ControlFlow::Continue)
|
||||
|
|
@ -501,11 +521,39 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
TurnResult::ContinueWithMessages(msgs) => {
|
||||
return Ok(TurnResult::ContinueWithMessages(msgs));
|
||||
}
|
||||
TurnResult::Paused => return Ok(TurnResult::Paused),
|
||||
}
|
||||
}
|
||||
Ok(TurnResult::Finish)
|
||||
}
|
||||
|
||||
/// 未実行のツール呼び出しがあるかチェック(Pauseからの復帰用)
|
||||
fn get_pending_tool_calls(&self) -> Option<Vec<ToolCall>> {
|
||||
let last_msg = self.history.last()?;
|
||||
if last_msg.role != Role::Assistant {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut calls = Vec::new();
|
||||
if let MessageContent::Parts(parts) = &last_msg.content {
|
||||
for part in parts {
|
||||
if let ContentPart::ToolUse { id, name, input } = part {
|
||||
calls.push(ToolCall {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
input: input.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(calls)
|
||||
}
|
||||
}
|
||||
|
||||
/// ツールを並列実行
|
||||
///
|
||||
/// 全てのツールに対してbefore_tool_callフックを実行後、
|
||||
|
|
@ -513,7 +561,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
async fn execute_tools(
|
||||
&self,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
) -> Result<Vec<ToolResult>, WorkerError> {
|
||||
) -> Result<ToolExecutionResult, WorkerError> {
|
||||
use futures::future::join_all;
|
||||
|
||||
// Phase 1: before_tool_call フックを適用(スキップ/中断を判定)
|
||||
|
|
@ -531,6 +579,9 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
ControlFlow::Abort(reason) => {
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
ControlFlow::Pause => {
|
||||
return Ok(ToolExecutionResult::Paused);
|
||||
}
|
||||
}
|
||||
}
|
||||
if !skip {
|
||||
|
|
@ -573,15 +624,22 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
ControlFlow::Abort(reason) => {
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
ControlFlow::Pause => {
|
||||
// after_tool_callでのPauseは結果を受け入れた後、次の処理前に止まる動作とする
|
||||
// ここではContinue扱いとし、on_message_send等でPauseすることを期待する
|
||||
// あるいはここでのPauseをサポートする場合は戻り値を調整する必要がある
|
||||
// 現状はログを出してContinue
|
||||
warn!("ControlFlow::Pause in after_tool_call is treated as Continue");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
Ok(ToolExecutionResult::Completed(results))
|
||||
}
|
||||
|
||||
/// 内部で使用するターン実行ロジック
|
||||
async fn run_turn_loop(&mut self) -> Result<(), WorkerError> {
|
||||
async fn run_turn_loop(&mut self) -> Result<WorkerResult<'_>, WorkerError> {
|
||||
let tool_definitions = self.build_tool_definitions();
|
||||
|
||||
info!(
|
||||
|
|
@ -590,6 +648,20 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
"Starting worker run"
|
||||
);
|
||||
|
||||
// Resume check: Pending tool calls
|
||||
if let Some(tool_calls) = self.get_pending_tool_calls() {
|
||||
info!("Resuming pending tool calls");
|
||||
match self.execute_tools(tool_calls).await? {
|
||||
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)),
|
||||
ToolExecutionResult::Completed(results) => {
|
||||
for result in results {
|
||||
self.history.push(Message::tool_result(&result.tool_use_id, &result.content));
|
||||
}
|
||||
// Continue to loop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
// ターン開始を通知
|
||||
let current_turn = self.turn_count;
|
||||
|
|
@ -600,13 +672,19 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
|
||||
// Hook: on_message_send
|
||||
let control = self.run_on_message_send_hooks().await?;
|
||||
if let ControlFlow::Abort(reason) = control {
|
||||
warn!(reason = %reason, "Aborted by hook");
|
||||
// ターン終了を通知(異常終了)
|
||||
for notifier in &self.turn_notifiers {
|
||||
notifier.on_turn_end(current_turn);
|
||||
match control {
|
||||
ControlFlow::Abort(reason) => {
|
||||
warn!(reason = %reason, "Aborted by hook");
|
||||
for notifier in &self.turn_notifiers {
|
||||
notifier.on_turn_end(current_turn);
|
||||
}
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
ControlFlow::Pause | ControlFlow::Skip => {
|
||||
// Skip or Pause -> Pause the worker
|
||||
return Ok(WorkerResult::Paused(&self.history));
|
||||
}
|
||||
ControlFlow::Continue => {}
|
||||
}
|
||||
|
||||
// リクエスト構築
|
||||
|
|
@ -659,25 +737,36 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
let turn_result = self.run_on_turn_end_hooks().await?;
|
||||
match turn_result {
|
||||
TurnResult::Finish => {
|
||||
return Ok(());
|
||||
return Ok(WorkerResult::Finished(&self.history));
|
||||
}
|
||||
TurnResult::ContinueWithMessages(additional) => {
|
||||
self.history.extend(additional);
|
||||
continue;
|
||||
}
|
||||
TurnResult::Paused => {
|
||||
return Ok(WorkerResult::Paused(&self.history));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ツール実行
|
||||
let tool_results = self.execute_tools(tool_calls).await?;
|
||||
|
||||
// ツール結果を履歴に追加
|
||||
for result in tool_results {
|
||||
self.history
|
||||
.push(Message::tool_result(&result.tool_use_id, &result.content));
|
||||
match self.execute_tools(tool_calls).await? {
|
||||
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)),
|
||||
ToolExecutionResult::Completed(results) => {
|
||||
for result in results {
|
||||
self.history.push(Message::tool_result(&result.tool_use_id, &result.content));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 実行を再開(Pause状態からの復帰)
|
||||
///
|
||||
/// 新しいユーザーメッセージを履歴に追加せず、現在の状態からターン処理を再開する。
|
||||
pub async fn resume(&mut self) -> Result<WorkerResult<'_>, WorkerError> {
|
||||
self.run_turn_loop().await
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
|
@ -887,10 +976,9 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
///
|
||||
/// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は
|
||||
/// `lock()` を呼んでからLocked状態で `run` を使用すること。
|
||||
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<&[Message], WorkerError> {
|
||||
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<WorkerResult<'_>, WorkerError> {
|
||||
self.history.push(Message::user(user_input));
|
||||
self.run_turn_loop().await?;
|
||||
Ok(&self.history)
|
||||
self.run_turn_loop().await
|
||||
}
|
||||
|
||||
/// 複数メッセージでターンを実行(Mutable状態)
|
||||
|
|
@ -899,10 +987,9 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
pub async fn run_with_messages(
|
||||
&mut self,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<&[Message], WorkerError> {
|
||||
) -> Result<WorkerResult<'_>, WorkerError> {
|
||||
self.history.extend(messages);
|
||||
self.run_turn_loop().await?;
|
||||
Ok(&self.history)
|
||||
self.run_turn_loop().await
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -915,20 +1002,18 @@ impl<C: LlmClient> Worker<C, Locked> {
|
|||
///
|
||||
/// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。
|
||||
/// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。
|
||||
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<&[Message], WorkerError> {
|
||||
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<WorkerResult<'_>, WorkerError> {
|
||||
self.history.push(Message::user(user_input));
|
||||
self.run_turn_loop().await?;
|
||||
Ok(&self.history)
|
||||
self.run_turn_loop().await
|
||||
}
|
||||
|
||||
/// 複数メッセージでターンを実行(Locked状態)
|
||||
pub async fn run_with_messages(
|
||||
&mut self,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<&[Message], WorkerError> {
|
||||
) -> Result<WorkerResult<'_>, WorkerError> {
|
||||
self.history.extend(messages);
|
||||
self.run_turn_loop().await?;
|
||||
Ok(&self.history)
|
||||
self.run_turn_loop().await
|
||||
}
|
||||
|
||||
/// ロック時点のプレフィックス長を取得
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user