feat: pause and resume
This commit is contained in:
parent
81107c6f5c
commit
33f1c218f2
|
|
@ -20,6 +20,8 @@ pub enum ControlFlow {
|
||||||
Skip,
|
Skip,
|
||||||
/// 処理を中断
|
/// 処理を中断
|
||||||
Abort(String),
|
Abort(String),
|
||||||
|
/// 処理を一時停止(再開可能)
|
||||||
|
Pause,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ターン終了時の判定結果
|
/// ターン終了時の判定結果
|
||||||
|
|
@ -29,6 +31,8 @@ pub enum TurnResult {
|
||||||
Finish,
|
Finish,
|
||||||
/// メッセージを追加してターン継続(自己修正など)
|
/// メッセージを追加してターン継続(自己修正など)
|
||||||
ContinueWithMessages(Vec<crate::Message>),
|
ContinueWithMessages(Vec<crate::Message>),
|
||||||
|
/// ターンを一時停止
|
||||||
|
Paused,
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,25 @@ pub struct WorkerConfig {
|
||||||
_private: (),
|
_private: (),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Worker Result Types
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
/// Workerの実行結果(ステータス)
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum WorkerResult<'a> {
|
||||||
|
/// 完了(ユーザー入力待ち状態)
|
||||||
|
Finished(&'a [Message]),
|
||||||
|
/// 一時停止(再開可能)
|
||||||
|
Paused(&'a [Message]),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 内部用: ツール実行結果
|
||||||
|
enum ToolExecutionResult {
|
||||||
|
Completed(Vec<ToolResult>),
|
||||||
|
Paused,
|
||||||
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// ターン制御用コールバック保持
|
// ターン制御用コールバック保持
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
@ -487,6 +506,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
ControlFlow::Continue => continue,
|
ControlFlow::Continue => continue,
|
||||||
ControlFlow::Skip => return Ok(ControlFlow::Skip),
|
ControlFlow::Skip => return Ok(ControlFlow::Skip),
|
||||||
ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)),
|
ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)),
|
||||||
|
ControlFlow::Pause => return Ok(ControlFlow::Pause),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(ControlFlow::Continue)
|
Ok(ControlFlow::Continue)
|
||||||
|
|
@ -501,11 +521,39 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
TurnResult::ContinueWithMessages(msgs) => {
|
TurnResult::ContinueWithMessages(msgs) => {
|
||||||
return Ok(TurnResult::ContinueWithMessages(msgs));
|
return Ok(TurnResult::ContinueWithMessages(msgs));
|
||||||
}
|
}
|
||||||
|
TurnResult::Paused => return Ok(TurnResult::Paused),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(TurnResult::Finish)
|
Ok(TurnResult::Finish)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 未実行のツール呼び出しがあるかチェック(Pauseからの復帰用)
|
||||||
|
fn get_pending_tool_calls(&self) -> Option<Vec<ToolCall>> {
|
||||||
|
let last_msg = self.history.last()?;
|
||||||
|
if last_msg.role != Role::Assistant {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut calls = Vec::new();
|
||||||
|
if let MessageContent::Parts(parts) = &last_msg.content {
|
||||||
|
for part in parts {
|
||||||
|
if let ContentPart::ToolUse { id, name, input } = part {
|
||||||
|
calls.push(ToolCall {
|
||||||
|
id: id.clone(),
|
||||||
|
name: name.clone(),
|
||||||
|
input: input.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if calls.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// ツールを並列実行
|
/// ツールを並列実行
|
||||||
///
|
///
|
||||||
/// 全てのツールに対してbefore_tool_callフックを実行後、
|
/// 全てのツールに対してbefore_tool_callフックを実行後、
|
||||||
|
|
@ -513,7 +561,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
async fn execute_tools(
|
async fn execute_tools(
|
||||||
&self,
|
&self,
|
||||||
tool_calls: Vec<ToolCall>,
|
tool_calls: Vec<ToolCall>,
|
||||||
) -> Result<Vec<ToolResult>, WorkerError> {
|
) -> Result<ToolExecutionResult, WorkerError> {
|
||||||
use futures::future::join_all;
|
use futures::future::join_all;
|
||||||
|
|
||||||
// Phase 1: before_tool_call フックを適用(スキップ/中断を判定)
|
// Phase 1: before_tool_call フックを適用(スキップ/中断を判定)
|
||||||
|
|
@ -531,6 +579,9 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
ControlFlow::Abort(reason) => {
|
ControlFlow::Abort(reason) => {
|
||||||
return Err(WorkerError::Aborted(reason));
|
return Err(WorkerError::Aborted(reason));
|
||||||
}
|
}
|
||||||
|
ControlFlow::Pause => {
|
||||||
|
return Ok(ToolExecutionResult::Paused);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !skip {
|
if !skip {
|
||||||
|
|
@ -573,15 +624,22 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
ControlFlow::Abort(reason) => {
|
ControlFlow::Abort(reason) => {
|
||||||
return Err(WorkerError::Aborted(reason));
|
return Err(WorkerError::Aborted(reason));
|
||||||
}
|
}
|
||||||
|
ControlFlow::Pause => {
|
||||||
|
// after_tool_callでのPauseは結果を受け入れた後、次の処理前に止まる動作とする
|
||||||
|
// ここではContinue扱いとし、on_message_send等でPauseすることを期待する
|
||||||
|
// あるいはここでのPauseをサポートする場合は戻り値を調整する必要がある
|
||||||
|
// 現状はログを出してContinue
|
||||||
|
warn!("ControlFlow::Pause in after_tool_call is treated as Continue");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(results)
|
Ok(ToolExecutionResult::Completed(results))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 内部で使用するターン実行ロジック
|
/// 内部で使用するターン実行ロジック
|
||||||
async fn run_turn_loop(&mut self) -> Result<(), WorkerError> {
|
async fn run_turn_loop(&mut self) -> Result<WorkerResult<'_>, WorkerError> {
|
||||||
let tool_definitions = self.build_tool_definitions();
|
let tool_definitions = self.build_tool_definitions();
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
|
|
@ -590,6 +648,20 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
"Starting worker run"
|
"Starting worker run"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Resume check: Pending tool calls
|
||||||
|
if let Some(tool_calls) = self.get_pending_tool_calls() {
|
||||||
|
info!("Resuming pending tool calls");
|
||||||
|
match self.execute_tools(tool_calls).await? {
|
||||||
|
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)),
|
||||||
|
ToolExecutionResult::Completed(results) => {
|
||||||
|
for result in results {
|
||||||
|
self.history.push(Message::tool_result(&result.tool_use_id, &result.content));
|
||||||
|
}
|
||||||
|
// Continue to loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
// ターン開始を通知
|
// ターン開始を通知
|
||||||
let current_turn = self.turn_count;
|
let current_turn = self.turn_count;
|
||||||
|
|
@ -600,14 +672,20 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
|
|
||||||
// Hook: on_message_send
|
// Hook: on_message_send
|
||||||
let control = self.run_on_message_send_hooks().await?;
|
let control = self.run_on_message_send_hooks().await?;
|
||||||
if let ControlFlow::Abort(reason) = control {
|
match control {
|
||||||
|
ControlFlow::Abort(reason) => {
|
||||||
warn!(reason = %reason, "Aborted by hook");
|
warn!(reason = %reason, "Aborted by hook");
|
||||||
// ターン終了を通知(異常終了)
|
|
||||||
for notifier in &self.turn_notifiers {
|
for notifier in &self.turn_notifiers {
|
||||||
notifier.on_turn_end(current_turn);
|
notifier.on_turn_end(current_turn);
|
||||||
}
|
}
|
||||||
return Err(WorkerError::Aborted(reason));
|
return Err(WorkerError::Aborted(reason));
|
||||||
}
|
}
|
||||||
|
ControlFlow::Pause | ControlFlow::Skip => {
|
||||||
|
// Skip or Pause -> Pause the worker
|
||||||
|
return Ok(WorkerResult::Paused(&self.history));
|
||||||
|
}
|
||||||
|
ControlFlow::Continue => {}
|
||||||
|
}
|
||||||
|
|
||||||
// リクエスト構築
|
// リクエスト構築
|
||||||
let request = self.build_request(&tool_definitions);
|
let request = self.build_request(&tool_definitions);
|
||||||
|
|
@ -659,24 +737,35 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
let turn_result = self.run_on_turn_end_hooks().await?;
|
let turn_result = self.run_on_turn_end_hooks().await?;
|
||||||
match turn_result {
|
match turn_result {
|
||||||
TurnResult::Finish => {
|
TurnResult::Finish => {
|
||||||
return Ok(());
|
return Ok(WorkerResult::Finished(&self.history));
|
||||||
}
|
}
|
||||||
TurnResult::ContinueWithMessages(additional) => {
|
TurnResult::ContinueWithMessages(additional) => {
|
||||||
self.history.extend(additional);
|
self.history.extend(additional);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
TurnResult::Paused => {
|
||||||
|
return Ok(WorkerResult::Paused(&self.history));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ツール実行
|
// ツール実行
|
||||||
let tool_results = self.execute_tools(tool_calls).await?;
|
match self.execute_tools(tool_calls).await? {
|
||||||
|
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)),
|
||||||
|
ToolExecutionResult::Completed(results) => {
|
||||||
|
for result in results {
|
||||||
|
self.history.push(Message::tool_result(&result.tool_use_id, &result.content));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ツール結果を履歴に追加
|
/// 実行を再開(Pause状態からの復帰)
|
||||||
for result in tool_results {
|
///
|
||||||
self.history
|
/// 新しいユーザーメッセージを履歴に追加せず、現在の状態からターン処理を再開する。
|
||||||
.push(Message::tool_result(&result.tool_use_id, &result.content));
|
pub async fn resume(&mut self) -> Result<WorkerResult<'_>, WorkerError> {
|
||||||
}
|
self.run_turn_loop().await
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -887,10 +976,9 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
///
|
///
|
||||||
/// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は
|
/// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は
|
||||||
/// `lock()` を呼んでからLocked状態で `run` を使用すること。
|
/// `lock()` を呼んでからLocked状態で `run` を使用すること。
|
||||||
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<&[Message], WorkerError> {
|
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<WorkerResult<'_>, WorkerError> {
|
||||||
self.history.push(Message::user(user_input));
|
self.history.push(Message::user(user_input));
|
||||||
self.run_turn_loop().await?;
|
self.run_turn_loop().await
|
||||||
Ok(&self.history)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 複数メッセージでターンを実行(Mutable状態)
|
/// 複数メッセージでターンを実行(Mutable状態)
|
||||||
|
|
@ -899,10 +987,9 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
pub async fn run_with_messages(
|
pub async fn run_with_messages(
|
||||||
&mut self,
|
&mut self,
|
||||||
messages: Vec<Message>,
|
messages: Vec<Message>,
|
||||||
) -> Result<&[Message], WorkerError> {
|
) -> Result<WorkerResult<'_>, WorkerError> {
|
||||||
self.history.extend(messages);
|
self.history.extend(messages);
|
||||||
self.run_turn_loop().await?;
|
self.run_turn_loop().await
|
||||||
Ok(&self.history)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -915,20 +1002,18 @@ impl<C: LlmClient> Worker<C, Locked> {
|
||||||
///
|
///
|
||||||
/// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。
|
/// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。
|
||||||
/// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。
|
/// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。
|
||||||
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<&[Message], WorkerError> {
|
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<WorkerResult<'_>, WorkerError> {
|
||||||
self.history.push(Message::user(user_input));
|
self.history.push(Message::user(user_input));
|
||||||
self.run_turn_loop().await?;
|
self.run_turn_loop().await
|
||||||
Ok(&self.history)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 複数メッセージでターンを実行(Locked状態)
|
/// 複数メッセージでターンを実行(Locked状態)
|
||||||
pub async fn run_with_messages(
|
pub async fn run_with_messages(
|
||||||
&mut self,
|
&mut self,
|
||||||
messages: Vec<Message>,
|
messages: Vec<Message>,
|
||||||
) -> Result<&[Message], WorkerError> {
|
) -> Result<WorkerResult<'_>, WorkerError> {
|
||||||
self.history.extend(messages);
|
self.history.extend(messages);
|
||||||
self.run_turn_loop().await?;
|
self.run_turn_loop().await
|
||||||
Ok(&self.history)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ロック時点のプレフィックス長を取得
|
/// ロック時点のプレフィックス長を取得
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user