diff --git a/docs/spec/cancellation.md b/docs/spec/cancellation.md index 54f5dc7..611f612 100644 --- a/docs/spec/cancellation.md +++ b/docs/spec/cancellation.md @@ -4,7 +4,7 @@ Workerの非同期キャンセル機構についての設計ドキュメント ## 概要 -`tokio_util::sync::CancellationToken`を用いて、別タスクからWorkerの実行を安全にキャンセルできる。 +`tokio::sync::mpsc`の通知チャネルを用いて、別タスクからWorkerの実行を安全にキャンセルできる。 ```rust let worker = Arc::new(Mutex::new(Worker::new(client))); @@ -19,15 +19,6 @@ let handle = tokio::spawn(async move { worker.lock().await.cancel(); ``` -## キャンセルポイント - -キャンセルは以下のタイミングでチェックされる: - -1. **ターンループ先頭** — `is_cancelled()`で即座にチェック -2. **ストリーム開始前** — `client.stream()`呼び出し時 -3. **ストリーム受信中** — `tokio::select!`で各イベント受信と並行監視 -4. **ツール実行中** — `join_all()`と並行監視 - ## キャンセル時の処理フロー ``` @@ -42,11 +33,10 @@ Err(WorkerError::Cancelled) // エラー返却 ## API -| メソッド | 説明 | -| ---------------------- | --------------------------------------------------------- | -| `cancel()` | キャンセルをトリガー | -| `is_cancelled()` | キャンセル状態を確認 | -| `cancellation_token()` | トークンへの参照を取得(`clone()`してタスク間で共有可能) | +| メソッド | 説明 | +| ----------------- | ------------------------------ | +| `cancel()` | キャンセルをトリガー | +| `cancel_sender()` | キャンセル通知用のSenderを取得 | ## on_abort フック @@ -69,22 +59,12 @@ async fn on_abort(&self, reason: &str) -> Result<(), HookError> { ## 既知の問題 -### 1. キャンセルトークンの再利用不可 +### on_abort の発火基準 -`CancellationToken`は一度キャンセルされると永続的にキャンセル状態になる。 -同じWorkerインスタンスで再度`run()`を呼ぶと即座に`Cancelled`エラーになる。 +`on_abort` は **interrupt(中断)** された場合に必ず発火する。 -**対応案:** +interrupt の例: -- `run()`開始時に新しいトークンを生成する -- `reset_cancellation()`メソッドを提供する - -### 2. Sync バウンドの追加(破壊的変更) - -`tokio::select!`使用のため、Handler/Scope型に`Sync`バウンドを追加した。 -既存のユーザーコードで`Sync`未実装の型を使用している場合、コンパイルエラーになる。 - -### 3. エラー時のon_abort呼び出し - -現在、`on_abort`はキャンセルとフックAbort時のみ呼ばれる。 -ストリームエラー等のその他エラー時には呼ばれないため、一貫性に欠ける可能性がある。 +- `WorkerError::Cancelled`(キャンセル) +- `WorkerError::Aborted`(フックによるAbort) +- ストリーム/ツール/クライアント/Hook の各種エラーで処理が中断された場合 diff --git a/llm-worker/examples/worker_cancel_demo.rs b/llm-worker/examples/worker_cancel_demo.rs index b4b7114..9d6e6d2 100644 --- a/llm-worker/examples/worker_cancel_demo.rs +++ b/llm-worker/examples/worker_cancel_demo.rs @@ -30,10 +30,10 @@ async fn main() -> Result<(), Box> { println!("🚀 Starting Worker..."); println!("💡 Will cancel after 2 seconds\n"); - // キャンセルトークンを先に取得(ロックを保持しない) - let cancel_token = { + // キャンセルSenderを先に取得(ロックを保持しない) + let cancel_tx = { let w = worker.lock().await; - w.cancellation_token().clone() + w.cancel_sender() }; // タスク1: Workerを実行 @@ -43,10 +43,10 @@ async fn main() -> Result<(), Box> { println!("📡 Sending request to LLM..."); match w.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await { - Ok(WorkerResult::Finished(_)) => { + Ok(WorkerResult::Finished) => { println!("✅ Task completed normally"); } - Ok(WorkerResult::Paused(_)) => { + Ok(WorkerResult::Paused) => { println!("⏸️ Task paused"); } Err(e) => { @@ -59,7 +59,7 @@ async fn main() -> Result<(), Box> { tokio::spawn(async move { tokio::time::sleep(Duration::from_secs(2)).await; println!("\n🛑 Cancelling worker..."); - cancel_token.cancel(); + let _ = cancel_tx.send(()).await; }); // タスク完了を待つ diff --git a/llm-worker/src/worker.rs b/llm-worker/src/worker.rs index d8c32a0..396b7b1 100644 --- a/llm-worker/src/worker.rs +++ b/llm-worker/src/worker.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use std::sync::{Arc, Mutex}; use futures::StreamExt; -use tokio_util::sync::CancellationToken; +use tokio::sync::mpsc; use tracing::{debug, info, trace, warn}; use crate::{ @@ -78,11 +78,11 @@ pub struct WorkerConfig { /// Workerの実行結果(ステータス) #[derive(Debug)] -pub enum WorkerResult<'a> { +pub enum WorkerResult { /// 完了(ユーザー入力待ち状態) - Finished(&'a [Message]), + Finished, /// 一時停止(再開可能) - Paused(&'a [Message]), + Paused, } /// 内部用: ツール実行結果 @@ -182,8 +182,11 @@ pub struct Worker { turn_notifiers: Vec>, /// リクエスト設定(max_tokens, temperature等) request_config: RequestConfig, - /// キャンセレーショントークン(実行中断用) - cancellation_token: CancellationToken, + /// 前回の実行が中断されたかどうか + last_run_interrupted: bool, + /// キャンセル通知用チャネル(実行中断用) + cancel_tx: mpsc::Sender<()>, + cancel_rx: mpsc::Receiver<()>, /// 状態マーカー _state: PhantomData, } @@ -193,6 +196,57 @@ pub struct Worker { // ============================================================================= impl Worker { + fn reset_interruption_state(&mut self) { + self.last_run_interrupted = false; + } + + /// ターンを実行 + /// + /// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。 + /// ツール呼び出しがある場合は自動的にループする。 + pub async fn run( + &mut self, + user_input: impl Into, + ) -> Result { + self.reset_interruption_state(); + // Hook: on_prompt_submit + let mut user_message = Message::user(user_input); + let result = self.run_on_prompt_submit_hooks(&mut user_message).await; + let result = match result { + Ok(value) => value, + Err(err) => return self.finalize_interruption(Err(err)).await, + }; + match result { + OnPromptSubmitResult::Cancel(reason) => { + self.last_run_interrupted = true; + return self.finalize_interruption(Err(WorkerError::Aborted(reason))).await; + } + OnPromptSubmitResult::Continue => {} + } + self.history.push(user_message); + let result = self.run_turn_loop().await; + self.finalize_interruption(result).await + } + + fn drain_cancel_queue(&mut self) { + use tokio::sync::mpsc::error::TryRecvError; + loop { + match self.cancel_rx.try_recv() { + Ok(()) => continue, + Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break, + } + } + } + + fn try_cancelled(&mut self) -> bool { + use tokio::sync::mpsc::error::TryRecvError; + match self.cancel_rx.try_recv() { + Ok(()) => true, + Err(TryRecvError::Empty) => false, + Err(TryRecvError::Disconnected) => true, + } + } + /// イベント購読者を登録する /// /// 登録したSubscriberは、LLMからのストリーミングイベントを @@ -410,6 +464,11 @@ impl Worker { self.request_config.stop_sequences.clear(); } + /// キャンセル通知用Senderを取得する + pub fn cancel_sender(&self) -> mpsc::Sender<()> { + self.cancel_tx.clone() + } + /// リクエスト設定を一括で設定 pub fn set_request_config(&mut self, config: RequestConfig) { self.request_config = config; @@ -437,17 +496,17 @@ impl Worker { /// worker.lock().unwrap().cancel(); /// ``` pub fn cancel(&self) { - self.cancellation_token.cancel(); + let _ = self.cancel_tx.try_send(()); } /// キャンセルされているかチェック - pub fn is_cancelled(&self) -> bool { - self.cancellation_token.is_cancelled() + pub fn is_cancelled(&mut self) -> bool { + self.try_cancelled() } - /// キャンセレーショントークンへの参照を取得 - pub fn cancellation_token(&self) -> &CancellationToken { - &self.cancellation_token + /// 前回の実行が中断されたかどうか + pub fn last_run_interrupted(&self) -> bool { + self.last_run_interrupted } /// 登録されたツールからLLM用ToolDefinitionのリストを生成 @@ -636,6 +695,28 @@ impl Worker { Ok(()) } + async fn finalize_interruption( + &mut self, + result: Result, + ) -> Result { + match result { + Ok(value) => Ok(value), + Err(err) => { + self.last_run_interrupted = true; + let reason = match &err { + WorkerError::Aborted(reason) => reason.clone(), + WorkerError::Cancelled => "Cancelled".to_string(), + _ => err.to_string(), + }; + if let Err(hook_err) = self.run_on_abort_hooks(&reason).await { + self.last_run_interrupted = true; + return Err(hook_err); + } + Err(err) + } + } + } + /// 未実行のツール呼び出しがあるかチェック(Pauseからの復帰用) fn get_pending_tool_calls(&self) -> Option> { let last_msg = self.history.last()?; @@ -687,7 +768,10 @@ impl Worker { let mut skip = false; for hook in &self.hooks.pre_tool_call { - let result = hook.call(&mut context).await?; + let result = hook + .call(&mut context) + .await + .inspect_err(|_| self.last_run_interrupted = true)?; match result { PreToolCallResult::Continue => {} PreToolCallResult::Skip => { @@ -695,9 +779,11 @@ impl Worker { break; } PreToolCallResult::Abort(reason) => { + self.last_run_interrupted = true; return Err(WorkerError::Aborted(reason)); } PreToolCallResult::Pause => { + self.last_run_interrupted = true; return Ok(ToolExecutionResult::Paused); } } @@ -747,10 +833,12 @@ impl Worker { // ツール実行をキャンセル可能にする let mut results = tokio::select! { results = join_all(futures) => results, - _ = self.cancellation_token.cancelled() => { - info!("Tool execution cancelled"); + cancel = self.cancel_rx.recv() => { + if cancel.is_some() { + info!("Tool execution cancelled"); + } self.timeline.abort_current_block(); - self.run_on_abort_hooks("Cancelled").await?; + self.last_run_interrupted = true; return Err(WorkerError::Cancelled); } }; @@ -767,10 +855,14 @@ impl Worker { }; for hook in &self.hooks.post_tool_call { - let result = hook.call(&mut context).await?; + let result = hook + .call(&mut context) + .await + .inspect_err(|_| self.last_run_interrupted = true)?; match result { PostToolCallResult::Continue => {} PostToolCallResult::Abort(reason) => { + self.last_run_interrupted = true; return Err(WorkerError::Aborted(reason)); } } @@ -784,7 +876,9 @@ impl Worker { } /// 内部で使用するターン実行ロジック - async fn run_turn_loop(&mut self) -> Result, WorkerError> { + async fn run_turn_loop(&mut self) -> Result { + self.reset_interruption_state(); + self.drain_cancel_queue(); let tool_definitions = self.build_tool_definitions(); info!( @@ -796,24 +890,31 @@ impl Worker { // Resume check: Pending tool calls if let Some(tool_calls) = self.get_pending_tool_calls() { info!("Resuming pending tool calls"); - match self.execute_tools(tool_calls).await? { - ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)), - ToolExecutionResult::Completed(results) => { + match self.execute_tools(tool_calls).await { + Ok(ToolExecutionResult::Paused) => { + self.last_run_interrupted = true; + return Ok(WorkerResult::Paused); + } + Ok(ToolExecutionResult::Completed(results)) => { for result in results { self.history .push(Message::tool_result(&result.tool_use_id, &result.content)); } // Continue to loop } + Err(err) => { + self.last_run_interrupted = true; + return Err(err); + } } } loop { // キャンセルチェック - if self.cancellation_token.is_cancelled() { + if self.try_cancelled() { info!("Execution cancelled"); self.timeline.abort_current_block(); - self.run_on_abort_hooks("Cancelled").await?; + self.last_run_interrupted = true; return Err(WorkerError::Cancelled); } @@ -825,14 +926,17 @@ impl Worker { } // Hook: pre_llm_request - let (control, request_context) = self.run_pre_llm_request_hooks().await?; + let (control, request_context) = self + .run_pre_llm_request_hooks() + .await + .inspect_err(|_| self.last_run_interrupted = true)?; match control { PreLlmRequestResult::Cancel(reason) => { info!(reason = %reason, "Aborted by hook"); for notifier in &self.turn_notifiers { notifier.on_turn_end(current_turn); } - self.run_on_abort_hooks(&reason).await?; + self.last_run_interrupted = true; return Err(WorkerError::Aborted(reason)); } PreLlmRequestResult::Continue => {} @@ -853,11 +957,14 @@ impl Worker { // ストリームを取得(キャンセル可能) let mut stream = tokio::select! { - stream_result = self.client.stream(request) => stream_result?, - _ = self.cancellation_token.cancelled() => { - info!("Cancelled before stream started"); + stream_result = self.client.stream(request) => stream_result + .inspect_err(|_| self.last_run_interrupted = true)?, + cancel = self.cancel_rx.recv() => { + if cancel.is_some() { + info!("Cancelled before stream started"); + } self.timeline.abort_current_block(); - self.run_on_abort_hooks("Cancelled").await?; + self.last_run_interrupted = true; return Err(WorkerError::Cancelled); } }; @@ -877,7 +984,8 @@ impl Worker { warn!(error = %e, "Stream error"); } } - let event = result?; + let event = result + .inspect_err(|_| self.last_run_interrupted = true)?; let timeline_event: crate::timeline::event::Event = event.into(); self.timeline.dispatch(&timeline_event); } @@ -885,10 +993,12 @@ impl Worker { } } // キャンセル待機 - _ = self.cancellation_token.cancelled() => { - info!("Stream cancelled"); + cancel = self.cancel_rx.recv() => { + if cancel.is_some() { + info!("Stream cancelled"); + } self.timeline.abort_current_block(); - self.run_on_abort_hooks("Cancelled").await?; + self.last_run_interrupted = true; return Err(WorkerError::Cancelled); } } @@ -913,30 +1023,42 @@ impl Worker { if tool_calls.is_empty() { // ツール呼び出しなし → ターン終了判定 - let turn_result = self.run_on_turn_end_hooks().await?; + let turn_result = self + .run_on_turn_end_hooks() + .await + .inspect_err(|_| self.last_run_interrupted = true)?; match turn_result { OnTurnEndResult::Finish => { - return Ok(WorkerResult::Finished(&self.history)); + self.last_run_interrupted = false; + return Ok(WorkerResult::Finished); } OnTurnEndResult::ContinueWithMessages(additional) => { self.history.extend(additional); continue; } OnTurnEndResult::Paused => { - return Ok(WorkerResult::Paused(&self.history)); + self.last_run_interrupted = true; + return Ok(WorkerResult::Paused); } } } // ツール実行 - match self.execute_tools(tool_calls).await? { - ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)), - ToolExecutionResult::Completed(results) => { + match self.execute_tools(tool_calls).await { + Ok(ToolExecutionResult::Paused) => { + self.last_run_interrupted = true; + return Ok(WorkerResult::Paused); + } + Ok(ToolExecutionResult::Completed(results)) => { for result in results { self.history .push(Message::tool_result(&result.tool_use_id, &result.content)); } } + Err(err) => { + self.last_run_interrupted = true; + return Err(err); + } } } } @@ -944,8 +1066,10 @@ impl Worker { /// 実行を再開(Pause状態からの復帰) /// /// 新しいユーザーメッセージを履歴に追加せず、現在の状態からターン処理を再開する。 - pub async fn resume(&mut self) -> Result, WorkerError> { - self.run_turn_loop().await + pub async fn resume(&mut self) -> Result { + self.reset_interruption_state(); + let result = self.run_turn_loop().await; + self.finalize_interruption(result).await } } @@ -959,6 +1083,7 @@ impl Worker { let text_block_collector = TextBlockCollector::new(); let tool_call_collector = ToolCallCollector::new(); let mut timeline = Timeline::new(); + let (cancel_tx, cancel_rx) = mpsc::channel(1); // コレクターをTimelineに登録 timeline.on_text_block(text_block_collector.clone()); @@ -977,7 +1102,9 @@ impl Worker { turn_count: 0, turn_notifiers: Vec::new(), request_config: RequestConfig::default(), - cancellation_token: CancellationToken::new(), + last_run_interrupted: false, + cancel_tx, + cancel_rx, _state: PhantomData, } } @@ -1146,46 +1273,13 @@ impl Worker { turn_count: self.turn_count, turn_notifiers: self.turn_notifiers, request_config: self.request_config, - cancellation_token: self.cancellation_token, + last_run_interrupted: self.last_run_interrupted, + cancel_tx: self.cancel_tx, + cancel_rx: self.cancel_rx, _state: PhantomData, } } - /// ターンを実行(Mutable状態) - /// - /// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。 - /// ツール呼び出しがある場合は自動的にループする。 - /// - /// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は - /// `lock()` を呼んでからLocked状態で `run` を使用すること。 - pub async fn run( - &mut self, - user_input: impl Into, - ) -> Result, WorkerError> { - // Hook: on_prompt_submit - let mut user_message = Message::user(user_input); - let result = self.run_on_prompt_submit_hooks(&mut user_message).await?; - match result { - OnPromptSubmitResult::Cancel(reason) => { - self.run_on_abort_hooks(&reason).await?; - return Err(WorkerError::Aborted(reason)); - } - OnPromptSubmitResult::Continue => {} - } - self.history.push(user_message); - self.run_turn_loop().await - } - - /// 複数メッセージでターンを実行(Mutable状態) - /// - /// 指定されたメッセージを履歴に追加してから実行する。 - pub async fn run_with_messages( - &mut self, - messages: Vec, - ) -> Result, WorkerError> { - self.history.extend(messages); - self.run_turn_loop().await - } } // ============================================================================= @@ -1193,37 +1287,6 @@ impl Worker { // ============================================================================= impl Worker { - /// ターンを実行(Locked状態) - /// - /// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。 - /// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。 - pub async fn run( - &mut self, - user_input: impl Into, - ) -> Result, WorkerError> { - // Hook: on_prompt_submit - let mut user_message = Message::user(user_input); - let result = self.run_on_prompt_submit_hooks(&mut user_message).await?; - match result { - OnPromptSubmitResult::Cancel(reason) => { - self.run_on_abort_hooks(&reason).await?; - return Err(WorkerError::Aborted(reason)); - } - OnPromptSubmitResult::Continue => {} - } - self.history.push(user_message); - self.run_turn_loop().await - } - - /// 複数メッセージでターンを実行(Locked状態) - pub async fn run_with_messages( - &mut self, - messages: Vec, - ) -> Result, WorkerError> { - self.history.extend(messages); - self.run_turn_loop().await - } - /// ロック時点のプレフィックス長を取得 pub fn locked_prefix_len(&self) -> usize { self.locked_prefix_len @@ -1247,7 +1310,9 @@ impl Worker { turn_count: self.turn_count, turn_notifiers: self.turn_notifiers, request_config: self.request_config, - cancellation_token: self.cancellation_token, + last_run_interrupted: self.last_run_interrupted, + cancel_tx: self.cancel_tx, + cancel_rx: self.cancel_rx, _state: PhantomData, } } diff --git a/llm-worker/tests/tool_macro_test.rs b/llm-worker/tests/tool_macro_test.rs index 7b04cfe..4676852 100644 --- a/llm-worker/tests/tool_macro_test.rs +++ b/llm-worker/tests/tool_macro_test.rs @@ -9,7 +9,6 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use schemars; use serde; -use llm_worker::tool::{Tool, ToolMeta}; use llm_worker_macros::tool_registry; // =============================================================================