feature/hook-and-tool-types-improvements #2

Merged
Hare merged 4 commits from feature/hook-and-tool-types-improvements into develop 2026-01-10 22:47:06 +09:00
4 changed files with 192 additions and 148 deletions
Showing only changes of commit a2f53d7879 - Show all commits

View File

@ -4,7 +4,7 @@ Workerの非同期キャンセル機構についての設計ドキュメント
## 概要
`tokio_util::sync::CancellationToken`を用いて、別タスクからWorkerの実行を安全にキャンセルできる。
`tokio::sync::mpsc`の通知チャネルを用いて、別タスクからWorkerの実行を安全にキャンセルできる。
```rust
let worker = Arc::new(Mutex::new(Worker::new(client)));
@ -19,15 +19,6 @@ let handle = tokio::spawn(async move {
worker.lock().await.cancel();
```
## キャンセルポイント
キャンセルは以下のタイミングでチェックされる:
1. **ターンループ先頭**`is_cancelled()`で即座にチェック
2. **ストリーム開始前**`client.stream()`呼び出し時
3. **ストリーム受信中**`tokio::select!`で各イベント受信と並行監視
4. **ツール実行中**`join_all()`と並行監視
## キャンセル時の処理フロー
```
@ -42,11 +33,10 @@ Err(WorkerError::Cancelled) // エラー返却
## API
| メソッド | 説明 |
| ---------------------- | --------------------------------------------------------- |
| `cancel()` | キャンセルをトリガー |
| `is_cancelled()` | キャンセル状態を確認 |
| `cancellation_token()` | トークンへの参照を取得(`clone()`してタスク間で共有可能) |
| メソッド | 説明 |
| ----------------- | ------------------------------ |
| `cancel()` | キャンセルをトリガー |
| `cancel_sender()` | キャンセル通知用のSenderを取得 |
## on_abort フック
@ -69,22 +59,12 @@ async fn on_abort(&self, reason: &str) -> Result<(), HookError> {
## 既知の問題
### 1. キャンセルトークンの再利用不可
### on_abort の発火基準
`CancellationToken`は一度キャンセルされると永続的にキャンセル状態になる。
同じWorkerインスタンスで再度`run()`を呼ぶと即座に`Cancelled`エラーになる。
`on_abort`**interrupt中断** された場合に必ず発火する。
**対応案:**
interrupt の例:
- `run()`開始時に新しいトークンを生成する
- `reset_cancellation()`メソッドを提供する
### 2. Sync バウンドの追加(破壊的変更)
`tokio::select!`使用のため、Handler/Scope型に`Sync`バウンドを追加した。
既存のユーザーコードで`Sync`未実装の型を使用している場合、コンパイルエラーになる。
### 3. エラー時のon_abort呼び出し
現在、`on_abort`はキャンセルとフックAbort時のみ呼ばれる。
ストリームエラー等のその他エラー時には呼ばれないため、一貫性に欠ける可能性がある。
- `WorkerError::Cancelled`(キャンセル)
- `WorkerError::Aborted`フックによるAbort
- ストリーム/ツール/クライアント/Hook の各種エラーで処理が中断された場合

View File

@ -30,10 +30,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("🚀 Starting Worker...");
println!("💡 Will cancel after 2 seconds\n");
// キャンセルトークンを先に取得(ロックを保持しない)
let cancel_token = {
// キャンセルSenderを先に取得(ロックを保持しない)
let cancel_tx = {
let w = worker.lock().await;
w.cancellation_token().clone()
w.cancel_sender()
};
// タスク1: Workerを実行
@ -43,10 +43,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("📡 Sending request to LLM...");
match w.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await {
Ok(WorkerResult::Finished(_)) => {
Ok(WorkerResult::Finished) => {
println!("✅ Task completed normally");
}
Ok(WorkerResult::Paused(_)) => {
Ok(WorkerResult::Paused) => {
println!("⏸️ Task paused");
}
Err(e) => {
@ -59,7 +59,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
tokio::spawn(async move {
tokio::time::sleep(Duration::from_secs(2)).await;
println!("\n🛑 Cancelling worker...");
cancel_token.cancel();
let _ = cancel_tx.send(()).await;
});
// タスク完了を待つ

View File

@ -3,7 +3,7 @@ use std::marker::PhantomData;
use std::sync::{Arc, Mutex};
use futures::StreamExt;
use tokio_util::sync::CancellationToken;
use tokio::sync::mpsc;
use tracing::{debug, info, trace, warn};
use crate::{
@ -78,11 +78,11 @@ pub struct WorkerConfig {
/// Workerの実行結果ステータス
#[derive(Debug)]
pub enum WorkerResult<'a> {
pub enum WorkerResult {
/// 完了(ユーザー入力待ち状態)
Finished(&'a [Message]),
Finished,
/// 一時停止(再開可能)
Paused(&'a [Message]),
Paused,
}
/// 内部用: ツール実行結果
@ -182,8 +182,11 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
turn_notifiers: Vec<Box<dyn TurnNotifier>>,
/// リクエスト設定max_tokens, temperature等
request_config: RequestConfig,
/// キャンセレーショントークン(実行中断用)
cancellation_token: CancellationToken,
/// 前回の実行が中断されたかどうか
last_run_interrupted: bool,
/// キャンセル通知用チャネル(実行中断用)
cancel_tx: mpsc::Sender<()>,
cancel_rx: mpsc::Receiver<()>,
/// 状態マーカー
_state: PhantomData<S>,
}
@ -193,6 +196,57 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
// =============================================================================
impl<C: LlmClient, S: WorkerState> Worker<C, S> {
fn reset_interruption_state(&mut self) {
self.last_run_interrupted = false;
}
/// ターンを実行
///
/// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。
/// ツール呼び出しがある場合は自動的にループする。
pub async fn run(
&mut self,
user_input: impl Into<String>,
) -> Result<WorkerResult, WorkerError> {
self.reset_interruption_state();
// Hook: on_prompt_submit
let mut user_message = Message::user(user_input);
let result = self.run_on_prompt_submit_hooks(&mut user_message).await;
let result = match result {
Ok(value) => value,
Err(err) => return self.finalize_interruption(Err(err)).await,
};
match result {
OnPromptSubmitResult::Cancel(reason) => {
self.last_run_interrupted = true;
return self.finalize_interruption(Err(WorkerError::Aborted(reason))).await;
}
OnPromptSubmitResult::Continue => {}
}
self.history.push(user_message);
let result = self.run_turn_loop().await;
self.finalize_interruption(result).await
}
fn drain_cancel_queue(&mut self) {
use tokio::sync::mpsc::error::TryRecvError;
loop {
match self.cancel_rx.try_recv() {
Ok(()) => continue,
Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
}
}
}
fn try_cancelled(&mut self) -> bool {
use tokio::sync::mpsc::error::TryRecvError;
match self.cancel_rx.try_recv() {
Ok(()) => true,
Err(TryRecvError::Empty) => false,
Err(TryRecvError::Disconnected) => true,
}
}
/// イベント購読者を登録する
///
/// 登録したSubscriberは、LLMからのストリーミングイベントを
@ -410,6 +464,11 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
self.request_config.stop_sequences.clear();
}
/// キャンセル通知用Senderを取得する
pub fn cancel_sender(&self) -> mpsc::Sender<()> {
self.cancel_tx.clone()
}
/// リクエスト設定を一括で設定
pub fn set_request_config(&mut self, config: RequestConfig) {
self.request_config = config;
@ -437,17 +496,17 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
/// worker.lock().unwrap().cancel();
/// ```
pub fn cancel(&self) {
self.cancellation_token.cancel();
let _ = self.cancel_tx.try_send(());
}
/// キャンセルされているかチェック
pub fn is_cancelled(&self) -> bool {
self.cancellation_token.is_cancelled()
pub fn is_cancelled(&mut self) -> bool {
self.try_cancelled()
}
/// キャンセレーショントークンへの参照を取得
pub fn cancellation_token(&self) -> &CancellationToken {
&self.cancellation_token
/// 前回の実行が中断されたかどうか
pub fn last_run_interrupted(&self) -> bool {
self.last_run_interrupted
}
/// 登録されたツールからLLM用ToolDefinitionのリストを生成
@ -636,6 +695,28 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
Ok(())
}
async fn finalize_interruption<T>(
&mut self,
result: Result<T, WorkerError>,
) -> Result<T, WorkerError> {
match result {
Ok(value) => Ok(value),
Err(err) => {
self.last_run_interrupted = true;
let reason = match &err {
WorkerError::Aborted(reason) => reason.clone(),
WorkerError::Cancelled => "Cancelled".to_string(),
_ => err.to_string(),
};
if let Err(hook_err) = self.run_on_abort_hooks(&reason).await {
self.last_run_interrupted = true;
return Err(hook_err);
}
Err(err)
}
}
}
/// 未実行のツール呼び出しがあるかチェックPauseからの復帰用
fn get_pending_tool_calls(&self) -> Option<Vec<ToolCall>> {
let last_msg = self.history.last()?;
@ -687,7 +768,10 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
let mut skip = false;
for hook in &self.hooks.pre_tool_call {
let result = hook.call(&mut context).await?;
let result = hook
.call(&mut context)
.await
.inspect_err(|_| self.last_run_interrupted = true)?;
match result {
PreToolCallResult::Continue => {}
PreToolCallResult::Skip => {
@ -695,9 +779,11 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
break;
}
PreToolCallResult::Abort(reason) => {
self.last_run_interrupted = true;
return Err(WorkerError::Aborted(reason));
}
PreToolCallResult::Pause => {
self.last_run_interrupted = true;
return Ok(ToolExecutionResult::Paused);
}
}
@ -747,10 +833,12 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
// ツール実行をキャンセル可能にする
let mut results = tokio::select! {
results = join_all(futures) => results,
_ = self.cancellation_token.cancelled() => {
info!("Tool execution cancelled");
cancel = self.cancel_rx.recv() => {
if cancel.is_some() {
info!("Tool execution cancelled");
}
self.timeline.abort_current_block();
self.run_on_abort_hooks("Cancelled").await?;
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
}
};
@ -767,10 +855,14 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
};
for hook in &self.hooks.post_tool_call {
let result = hook.call(&mut context).await?;
let result = hook
.call(&mut context)
.await
.inspect_err(|_| self.last_run_interrupted = true)?;
match result {
PostToolCallResult::Continue => {}
PostToolCallResult::Abort(reason) => {
self.last_run_interrupted = true;
return Err(WorkerError::Aborted(reason));
}
}
@ -784,7 +876,9 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
}
/// 内部で使用するターン実行ロジック
async fn run_turn_loop(&mut self) -> Result<WorkerResult<'_>, WorkerError> {
async fn run_turn_loop(&mut self) -> Result<WorkerResult, WorkerError> {
self.reset_interruption_state();
self.drain_cancel_queue();
let tool_definitions = self.build_tool_definitions();
info!(
@ -796,24 +890,31 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
// Resume check: Pending tool calls
if let Some(tool_calls) = self.get_pending_tool_calls() {
info!("Resuming pending tool calls");
match self.execute_tools(tool_calls).await? {
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)),
ToolExecutionResult::Completed(results) => {
match self.execute_tools(tool_calls).await {
Ok(ToolExecutionResult::Paused) => {
self.last_run_interrupted = true;
return Ok(WorkerResult::Paused);
}
Ok(ToolExecutionResult::Completed(results)) => {
for result in results {
self.history
.push(Message::tool_result(&result.tool_use_id, &result.content));
}
// Continue to loop
}
Err(err) => {
self.last_run_interrupted = true;
return Err(err);
}
}
}
loop {
// キャンセルチェック
if self.cancellation_token.is_cancelled() {
if self.try_cancelled() {
info!("Execution cancelled");
self.timeline.abort_current_block();
self.run_on_abort_hooks("Cancelled").await?;
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
}
@ -825,14 +926,17 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
}
// Hook: pre_llm_request
let (control, request_context) = self.run_pre_llm_request_hooks().await?;
let (control, request_context) = self
.run_pre_llm_request_hooks()
.await
.inspect_err(|_| self.last_run_interrupted = true)?;
match control {
PreLlmRequestResult::Cancel(reason) => {
info!(reason = %reason, "Aborted by hook");
for notifier in &self.turn_notifiers {
notifier.on_turn_end(current_turn);
}
self.run_on_abort_hooks(&reason).await?;
self.last_run_interrupted = true;
return Err(WorkerError::Aborted(reason));
}
PreLlmRequestResult::Continue => {}
@ -853,11 +957,14 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
// ストリームを取得(キャンセル可能)
let mut stream = tokio::select! {
stream_result = self.client.stream(request) => stream_result?,
_ = self.cancellation_token.cancelled() => {
info!("Cancelled before stream started");
stream_result = self.client.stream(request) => stream_result
.inspect_err(|_| self.last_run_interrupted = true)?,
cancel = self.cancel_rx.recv() => {
if cancel.is_some() {
info!("Cancelled before stream started");
}
self.timeline.abort_current_block();
self.run_on_abort_hooks("Cancelled").await?;
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
}
};
@ -877,7 +984,8 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
warn!(error = %e, "Stream error");
}
}
let event = result?;
let event = result
.inspect_err(|_| self.last_run_interrupted = true)?;
let timeline_event: crate::timeline::event::Event = event.into();
self.timeline.dispatch(&timeline_event);
}
@ -885,10 +993,12 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
}
}
// キャンセル待機
_ = self.cancellation_token.cancelled() => {
info!("Stream cancelled");
cancel = self.cancel_rx.recv() => {
if cancel.is_some() {
info!("Stream cancelled");
}
self.timeline.abort_current_block();
self.run_on_abort_hooks("Cancelled").await?;
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
}
}
@ -913,30 +1023,42 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
if tool_calls.is_empty() {
// ツール呼び出しなし → ターン終了判定
let turn_result = self.run_on_turn_end_hooks().await?;
let turn_result = self
.run_on_turn_end_hooks()
.await
.inspect_err(|_| self.last_run_interrupted = true)?;
match turn_result {
OnTurnEndResult::Finish => {
return Ok(WorkerResult::Finished(&self.history));
self.last_run_interrupted = false;
return Ok(WorkerResult::Finished);
}
OnTurnEndResult::ContinueWithMessages(additional) => {
self.history.extend(additional);
continue;
}
OnTurnEndResult::Paused => {
return Ok(WorkerResult::Paused(&self.history));
self.last_run_interrupted = true;
return Ok(WorkerResult::Paused);
}
}
}
// ツール実行
match self.execute_tools(tool_calls).await? {
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)),
ToolExecutionResult::Completed(results) => {
match self.execute_tools(tool_calls).await {
Ok(ToolExecutionResult::Paused) => {
self.last_run_interrupted = true;
return Ok(WorkerResult::Paused);
}
Ok(ToolExecutionResult::Completed(results)) => {
for result in results {
self.history
.push(Message::tool_result(&result.tool_use_id, &result.content));
}
}
Err(err) => {
self.last_run_interrupted = true;
return Err(err);
}
}
}
}
@ -944,8 +1066,10 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
/// 実行を再開Pause状態からの復帰
///
/// 新しいユーザーメッセージを履歴に追加せず、現在の状態からターン処理を再開する。
pub async fn resume(&mut self) -> Result<WorkerResult<'_>, WorkerError> {
self.run_turn_loop().await
pub async fn resume(&mut self) -> Result<WorkerResult, WorkerError> {
self.reset_interruption_state();
let result = self.run_turn_loop().await;
self.finalize_interruption(result).await
}
}
@ -959,6 +1083,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
let text_block_collector = TextBlockCollector::new();
let tool_call_collector = ToolCallCollector::new();
let mut timeline = Timeline::new();
let (cancel_tx, cancel_rx) = mpsc::channel(1);
// コレクターをTimelineに登録
timeline.on_text_block(text_block_collector.clone());
@ -977,7 +1102,9 @@ impl<C: LlmClient> Worker<C, Mutable> {
turn_count: 0,
turn_notifiers: Vec::new(),
request_config: RequestConfig::default(),
cancellation_token: CancellationToken::new(),
last_run_interrupted: false,
cancel_tx,
cancel_rx,
_state: PhantomData,
}
}
@ -1146,46 +1273,13 @@ impl<C: LlmClient> Worker<C, Mutable> {
turn_count: self.turn_count,
turn_notifiers: self.turn_notifiers,
request_config: self.request_config,
cancellation_token: self.cancellation_token,
last_run_interrupted: self.last_run_interrupted,
cancel_tx: self.cancel_tx,
cancel_rx: self.cancel_rx,
_state: PhantomData,
}
}
/// ターンを実行Mutable状態
///
/// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。
/// ツール呼び出しがある場合は自動的にループする。
///
/// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は
/// `lock()` を呼んでからLocked状態で `run` を使用すること。
pub async fn run(
&mut self,
user_input: impl Into<String>,
) -> Result<WorkerResult<'_>, WorkerError> {
// Hook: on_prompt_submit
let mut user_message = Message::user(user_input);
let result = self.run_on_prompt_submit_hooks(&mut user_message).await?;
match result {
OnPromptSubmitResult::Cancel(reason) => {
self.run_on_abort_hooks(&reason).await?;
return Err(WorkerError::Aborted(reason));
}
OnPromptSubmitResult::Continue => {}
}
self.history.push(user_message);
self.run_turn_loop().await
}
/// 複数メッセージでターンを実行Mutable状態
///
/// 指定されたメッセージを履歴に追加してから実行する。
pub async fn run_with_messages(
&mut self,
messages: Vec<Message>,
) -> Result<WorkerResult<'_>, WorkerError> {
self.history.extend(messages);
self.run_turn_loop().await
}
}
// =============================================================================
@ -1193,37 +1287,6 @@ impl<C: LlmClient> Worker<C, Mutable> {
// =============================================================================
impl<C: LlmClient> Worker<C, Locked> {
/// ターンを実行Locked状態
///
/// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。
/// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。
pub async fn run(
&mut self,
user_input: impl Into<String>,
) -> Result<WorkerResult<'_>, WorkerError> {
// Hook: on_prompt_submit
let mut user_message = Message::user(user_input);
let result = self.run_on_prompt_submit_hooks(&mut user_message).await?;
match result {
OnPromptSubmitResult::Cancel(reason) => {
self.run_on_abort_hooks(&reason).await?;
return Err(WorkerError::Aborted(reason));
}
OnPromptSubmitResult::Continue => {}
}
self.history.push(user_message);
self.run_turn_loop().await
}
/// 複数メッセージでターンを実行Locked状態
pub async fn run_with_messages(
&mut self,
messages: Vec<Message>,
) -> Result<WorkerResult<'_>, WorkerError> {
self.history.extend(messages);
self.run_turn_loop().await
}
/// ロック時点のプレフィックス長を取得
pub fn locked_prefix_len(&self) -> usize {
self.locked_prefix_len
@ -1247,7 +1310,9 @@ impl<C: LlmClient> Worker<C, Locked> {
turn_count: self.turn_count,
turn_notifiers: self.turn_notifiers,
request_config: self.request_config,
cancellation_token: self.cancellation_token,
last_run_interrupted: self.last_run_interrupted,
cancel_tx: self.cancel_tx,
cancel_rx: self.cancel_rx,
_state: PhantomData,
}
}

View File

@ -9,7 +9,6 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use schemars;
use serde;
use llm_worker::tool::{Tool, ToolMeta};
use llm_worker_macros::tool_registry;
// =============================================================================