From 5691b09fc810dc966dbbe8d8002915b4b9fb64cb Mon Sep 17 00:00:00 2001 From: Hare Date: Fri, 9 Jan 2026 19:18:20 +0900 Subject: [PATCH] feat: Implement HookEventKind --- Cargo.lock | 239 ++++++++++++++- README.md | 1 - docs/spec/basis.md | 2 +- docs/spec/cancellation.md | 90 ++++++ docs/spec/hooks_design.md | 286 +++++++++--------- docs/spec/worker_design.md | 76 +++-- llm-worker/Cargo.toml | 3 +- llm-worker/examples/worker_cancel_demo.rs | 71 +++++ llm-worker/examples/worker_cli.rs | 12 +- llm-worker/src/hook.rs | 145 ++++------ llm-worker/src/lib.rs | 8 +- llm-worker/src/subscriber.rs | 4 +- llm-worker/src/timeline/timeline.rs | 63 ++-- llm-worker/src/worker.rs | 304 +++++++++++++------- llm-worker/tests/parallel_execution_test.rs | 28 +- 15 files changed, 916 insertions(+), 416 deletions(-) create mode 100644 docs/spec/cancellation.md create mode 100644 llm-worker/examples/worker_cancel_demo.rs diff --git a/Cargo.lock b/Cargo.lock index 0a959f7..76cefe8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -104,9 +104,9 @@ checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" [[package]] name = "cc" -version = "1.2.51" +version = "1.2.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a0aeaff4ff1a90589618835a598e545176939b97874f7abc7851caa0618f203" +checksum = "cd4932aefd12402b36c60956a4fe0035421f544799057659ff86f923657aada3" dependencies = [ "find-msvc-tools", "shlex", @@ -203,6 +203,12 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + [[package]] name = "errno" version = "0.3.14" @@ -232,9 +238,15 @@ checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "find-msvc-tools" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645cbb3a84e60b7531617d5ae4e57f7e27308f6445f5abf653209ea76dec8dff" +checksum = "f449e6c6c08c865631d4890cfacf252b3d396c9bcc83adb6623cdb02a8336c41" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "foreign-types" @@ -349,6 +361,17 @@ dependencies = [ "slab", ] +[[package]] +name = "getrandom" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "getrandom" version = "0.3.4" @@ -361,6 +384,31 @@ dependencies = [ "wasip2", ] +[[package]] +name = "h2" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + [[package]] name = "heck" version = "0.5.0" @@ -416,6 +464,7 @@ dependencies = [ "bytes", "futures-channel", "futures-core", + "h2", "http", "http-body", "httparse", @@ -427,6 +476,22 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + [[package]] name = "hyper-tls" version = "0.6.0" @@ -569,6 +634,16 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -648,6 +723,7 @@ dependencies = [ "tempfile", "thiserror", "tokio", + "tokio-util", "tracing", "tracing-subscriber", ] @@ -895,10 +971,12 @@ dependencies = [ "bytes", "futures-core", "futures-util", + "h2", "http", "http-body", "http-body-util", "hyper", + "hyper-rustls", "hyper-tls", "hyper-util", "js-sys", @@ -923,6 +1001,20 @@ dependencies = [ "web-sys", ] +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.16", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rustix" version = "1.1.3" @@ -936,6 +1028,19 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "rustls" +version = "0.23.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b" +dependencies = [ + "once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + [[package]] name = "rustls-pki-types" version = "1.13.2" @@ -945,6 +1050,17 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-webpki" +version = "0.103.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -1111,6 +1227,12 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.114" @@ -1149,7 +1271,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" dependencies = [ "fastrand", - "getrandom", + "getrandom 0.3.4", "once_cell", "rustix", "windows-sys 0.61.2", @@ -1230,6 +1352,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -1361,6 +1493,12 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" version = "2.5.8" @@ -1508,13 +1646,22 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" dependencies = [ - "windows-targets", + "windows-targets 0.53.5", ] [[package]] @@ -1526,6 +1673,22 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + [[package]] name = "windows-targets" version = "0.53.5" @@ -1533,58 +1696,106 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" dependencies = [ "windows-link", - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + [[package]] name = "windows_aarch64_gnullvm" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + [[package]] name = "windows_aarch64_msvc" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + [[package]] name = "windows_i686_gnu" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + [[package]] name = "windows_i686_gnullvm" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + [[package]] name = "windows_i686_msvc" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + [[package]] name = "windows_x86_64_gnu" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + [[package]] name = "windows_x86_64_gnullvm" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + [[package]] name = "windows_x86_64_msvc" version = "0.53.1" diff --git a/README.md b/README.md index b60d501..e6b13db 100644 --- a/README.md +++ b/README.md @@ -33,4 +33,3 @@ let history = worker.run("What is 2+2?").await?; ## License MIT - diff --git a/docs/spec/basis.md b/docs/spec/basis.md index 903658e..d3fcb99 100644 --- a/docs/spec/basis.md +++ b/docs/spec/basis.md @@ -17,7 +17,7 @@ LLMを用いたワーカーを作成する小型のSDK・ライブラリ。 module構成概念図 -``` +```plaintext worker ├── context ├── llm_client diff --git a/docs/spec/cancellation.md b/docs/spec/cancellation.md new file mode 100644 index 0000000..54f5dc7 --- /dev/null +++ b/docs/spec/cancellation.md @@ -0,0 +1,90 @@ +# 非同期キャンセル設計 + +Workerの非同期キャンセル機構についての設計ドキュメント。 + +## 概要 + +`tokio_util::sync::CancellationToken`を用いて、別タスクからWorkerの実行を安全にキャンセルできる。 + +```rust +let worker = Arc::new(Mutex::new(Worker::new(client))); + +// 実行タスク +let w = worker.clone(); +let handle = tokio::spawn(async move { + w.lock().await.run("prompt").await +}); + +// キャンセル +worker.lock().await.cancel(); +``` + +## キャンセルポイント + +キャンセルは以下のタイミングでチェックされる: + +1. **ターンループ先頭** — `is_cancelled()`で即座にチェック +2. **ストリーム開始前** — `client.stream()`呼び出し時 +3. **ストリーム受信中** — `tokio::select!`で各イベント受信と並行監視 +4. **ツール実行中** — `join_all()`と並行監視 + +## キャンセル時の処理フロー + +``` +キャンセル検知 + ↓ +timeline.abort_current_block() // 進行中ブロックの終端処理 + ↓ +run_on_abort_hooks("Cancelled") // on_abort フック呼び出し + ↓ +Err(WorkerError::Cancelled) // エラー返却 +``` + +## API + +| メソッド | 説明 | +| ---------------------- | --------------------------------------------------------- | +| `cancel()` | キャンセルをトリガー | +| `is_cancelled()` | キャンセル状態を確認 | +| `cancellation_token()` | トークンへの参照を取得(`clone()`してタスク間で共有可能) | + +## on_abort フック + +`Hook::on_abort(&self, reason: &str)`がキャンセル時に呼ばれる。 +クリーンアップ処理やログ記録に使用できる。 + +```rust +async fn on_abort(&self, reason: &str) -> Result<(), HookError> { + log::info!("Aborted: {}", reason); + Ok(()) +} +``` + +呼び出しタイミング: + +- `WorkerError::Cancelled` — reason: `"Cancelled"` +- `ControlFlow::Abort(reason)` — reason: フックが指定した理由 + +--- + +## 既知の問題 + +### 1. キャンセルトークンの再利用不可 + +`CancellationToken`は一度キャンセルされると永続的にキャンセル状態になる。 +同じWorkerインスタンスで再度`run()`を呼ぶと即座に`Cancelled`エラーになる。 + +**対応案:** + +- `run()`開始時に新しいトークンを生成する +- `reset_cancellation()`メソッドを提供する + +### 2. Sync バウンドの追加(破壊的変更) + +`tokio::select!`使用のため、Handler/Scope型に`Sync`バウンドを追加した。 +既存のユーザーコードで`Sync`未実装の型を使用している場合、コンパイルエラーになる。 + +### 3. エラー時のon_abort呼び出し + +現在、`on_abort`はキャンセルとフックAbort時のみ呼ばれる。 +ストリームエラー等のその他エラー時には呼ばれないため、一貫性に欠ける可能性がある。 diff --git a/docs/spec/hooks_design.md b/docs/spec/hooks_design.md index b56357b..08354a6 100644 --- a/docs/spec/hooks_design.md +++ b/docs/spec/hooks_design.md @@ -3,7 +3,8 @@ ## 概要 HookはWorker層でのターン制御に介入するためのメカニズムです。 -Claude CodeのHooks機能に着想を得ており、メッセージ送信・ツール実行・ターン終了の各ポイントで処理を差し込むことができます。 + +メッセージ送信・ツール実行・ターン終了等の各ポイントで処理を差し込むことができます。 ## コンセプト @@ -11,76 +12,82 @@ Claude CodeのHooks機能に着想を得ており、メッセージ送信・ツ - **Contextへのアクセス**: メッセージ履歴を読み書き可能 - **非破壊的チェーン**: 複数のHookを登録順に実行、後続Hookへの影響を制御 +## Hook一覧 + +| Hook | タイミング | 主な用途 | 戻り値 | +| ------------------ | -------------------------- | --------------------- | ---------------------- | +| `on_message_send` | LLM送信前 | コンテキスト改変/検証 | `OnMessageSendResult` | +| `before_tool_call` | ツール実行前 | 実行許可/引数改変 | `BeforeToolCallResult` | +| `after_tool_call` | ツール実行後 | 結果加工/マスキング | `AfterToolCallResult` | +| `on_turn_end` | ツールなしでターン終了直前 | 検証/リトライ指示 | `OnTurnEndResult` | +| `on_abort` | 中断時 | クリーンアップ/通知 | `()` | + ## Hook Trait ```rust #[async_trait] -pub trait WorkerHook: Send + Sync { - /// メッセージ送信前 - /// リクエストに含まれるメッセージリストを改変できる - async fn on_message_send( - &self, - context: &mut Vec, - ) -> Result { - Ok(ControlFlow::Continue) - } - - /// ツール実行前 - /// 実行をキャンセルしたり、引数を書き換えることができる - async fn before_tool_call( - &self, - tool_call: &mut ToolCall, - ) -> Result { - Ok(ControlFlow::Continue) - } - - /// ツール実行後 - /// 結果を書き換えたり、隠蔽したりできる - async fn after_tool_call( - &self, - tool_result: &mut ToolResult, - ) -> Result { - Ok(ControlFlow::Continue) - } - - /// ターン終了時 - /// 生成されたメッセージを検査し、必要ならリトライを指示できる - async fn on_turn_end( - &self, - messages: &[Message], - ) -> Result { - Ok(TurnResult::Finish) - } +pub trait Hook: Send + Sync { + async fn call(&self, input: &mut E::Input) -> Result; } ``` ## 制御フロー型 -### ControlFlow +### HookEventKind / Result -Hook処理の継続/中断を制御する列挙型。 +Hookイベントごとに入力/出力型を分離し、意味のない制御フローを排除する。 ```rust -pub enum ControlFlow { - /// 処理を続行(後続Hookも実行) +pub trait HookEventKind { + type Input; + type Output; +} + +pub struct OnMessageSend; +pub struct BeforeToolCall; +pub struct AfterToolCall; +pub struct OnTurnEnd; +pub struct OnAbort; + +pub enum OnMessageSendResult { + Continue, + Cancel(String), +} + +pub enum BeforeToolCallResult { Continue, - /// 現在の処理をスキップ(ツール実行をスキップ等) Skip, - /// 処理全体を中断(エラーとして扱う) Abort(String), + Pause, +} + +pub enum AfterToolCallResult { + Continue, + Abort(String), +} + +pub enum OnTurnEndResult { + Finish, + ContinueWithMessages(Vec), + Paused, } ``` -### TurnResult +### Tool Call Context -ターン終了時の判定結果を表す列挙型。 +`before_tool_call` / `after_tool_call` は、ツール実行の文脈を含む入力を受け取る。 ```rust -pub enum TurnResult { - /// ターンを正常終了 - Finish, - /// メッセージを追加してターン継続(自己修正など) - ContinueWithMessages(Vec), +pub struct ToolCallContext { + pub call: ToolCall, + pub meta: ToolMeta, // 不変メタデータ + pub tool: Arc, // 状態アクセス用 +} + +pub struct ToolResultContext { + pub result: ToolResult, + pub meta: ToolMeta, + pub tool: Arc, } ``` @@ -90,28 +97,30 @@ pub enum TurnResult { Worker::run() ループ │ ├─▶ on_message_send ──────────────────────────────┐ -│ コンテキストの改変、バリデーション、 │ -│ システムプロンプト注入などが可能 │ -│ │ -├─▶ LLMリクエスト送信 & ストリーム処理 │ -│ │ -├─▶ ツール呼び出しがある場合: │ -│ │ │ +│ コンテキストの改変、バリデーション、 │ +│ システムプロンプト注入などが可能 │ +│ │ +├─▶ LLMリクエスト送信 & ストリーム処理 │ +│ │ +├─▶ ツール呼び出しがある場合: │ +│ │ │ │ ├─▶ before_tool_call (各ツールごと・逐次) │ -│ │ 実行可否の判定、引数の改変 │ -│ │ │ +│ │ 実行可否の判定、引数の改変 │ +│ │ │ │ ├─▶ ツール並列実行 (join_all) │ -│ │ │ +│ │ │ │ └─▶ after_tool_call (各結果ごと・逐次) │ -│ 結果の確認、加工、ログ出力 │ -│ │ +│ 結果の確認、加工、ログ出力 │ +│ │ ├─▶ ツール結果をコンテキストに追加 → ループ先頭へ │ -│ │ -└─▶ ツールなしの場合: │ - │ │ - └─▶ on_turn_end ─────────────────────────────┘ +│ │ +└─▶ ツールなしの場合: │ + │ │ + └─▶ on_turn_end ─────────────────────────────┘ 最終応答のチェック(Lint/Fmt等) エラーがあればContinueWithMessagesでリトライ + +※ 中断時は on_abort が呼ばれる ``` ## 各Hookの詳細 @@ -121,10 +130,12 @@ Worker::run() ループ **呼び出しタイミング**: LLMへリクエスト送信前(ターンループの冒頭) **用途**: + - コンテキストへのシステムメッセージ注入 - メッセージのバリデーション - 機密情報のフィルタリング - リクエスト内容のログ出力 +- `OnMessageSendResult::Cancel` による送信キャンセル **例**: メッセージにタイムスタンプを追加 @@ -132,14 +143,14 @@ Worker::run() ループ struct TimestampHook; #[async_trait] -impl WorkerHook for TimestampHook { - async fn on_message_send( +impl Hook for TimestampHook { + async fn call( &self, context: &mut Vec, - ) -> Result { + ) -> Result { let timestamp = chrono::Local::now().to_rfc3339(); context.insert(0, Message::user(format!("[{}]", timestamp))); - Ok(ControlFlow::Continue) + Ok(OnMessageSendResult::Continue) } } ``` @@ -149,10 +160,15 @@ impl WorkerHook for TimestampHook { **呼び出しタイミング**: 各ツール実行前(並列実行フェーズの前) **用途**: + - 危険なツールのブロック - 引数のサニタイズ - 確認プロンプトの表示(UIとの連携) - 実行ログの記録 +- `BeforeToolCallResult::Pause` による一時停止 + +**入力**: +- `ToolCallContext`(`ToolCall` + `ToolMeta` + `Arc`) **例**: 特定ツールをブロック @@ -162,16 +178,16 @@ struct ToolBlocker { } #[async_trait] -impl WorkerHook for ToolBlocker { - async fn before_tool_call( +impl Hook for ToolBlocker { + async fn call( &self, - tool_call: &mut ToolCall, - ) -> Result { - if self.blocked_tools.contains(&tool_call.name) { - println!("Blocked tool: {}", tool_call.name); - Ok(ControlFlow::Skip) + ctx: &mut ToolCallContext, + ) -> Result { + if self.blocked_tools.contains(&ctx.call.name) { + println!("Blocked tool: {}", ctx.call.name); + Ok(BeforeToolCallResult::Skip) } else { - Ok(ControlFlow::Continue) + Ok(BeforeToolCallResult::Continue) } } } @@ -182,10 +198,13 @@ impl WorkerHook for ToolBlocker { **呼び出しタイミング**: 各ツール実行後(並列実行フェーズの後) **用途**: + - 結果の加工・フォーマット - 機密情報のマスキング - 結果のキャッシュ - 実行結果のログ出力 +**入力**: +- `ToolResultContext`(`ToolResult` + `ToolMeta` + `Arc`) **例**: 結果にプレフィックスを追加 @@ -193,15 +212,15 @@ impl WorkerHook for ToolBlocker { struct ResultFormatter; #[async_trait] -impl WorkerHook for ResultFormatter { - async fn after_tool_call( +impl Hook for ResultFormatter { + async fn call( &self, - tool_result: &mut ToolResult, - ) -> Result { - if !tool_result.is_error { - tool_result.content = format!("[OK] {}", tool_result.content); + ctx: &mut ToolResultContext, + ) -> Result { + if !ctx.result.is_error { + ctx.result.content = format!("[OK] {}", ctx.result.content); } - Ok(ControlFlow::Continue) + Ok(AfterToolCallResult::Continue) } } ``` @@ -211,10 +230,22 @@ impl WorkerHook for ResultFormatter { **呼び出しタイミング**: ツール呼び出しなしでターンが終了する直前 **用途**: + - 生成されたコードのLint/Fmt - 出力形式のバリデーション - 自己修正のためのリトライ指示 - 最終結果のログ出力 +- `OnTurnEndResult::Paused` による一時停止 + +### on_abort + +**呼び出しタイミング**: キャンセル/エラー/AbortなどでWorkerが中断された時 + +**用途**: + +- クリーンアップ処理 +- 中断理由のログ出力 +- 外部システムへの通知 **例**: JSON形式のバリデーション @@ -222,11 +253,11 @@ impl WorkerHook for ResultFormatter { struct JsonValidator; #[async_trait] -impl WorkerHook for JsonValidator { - async fn on_turn_end( +impl Hook for JsonValidator { + async fn call( &self, - messages: &[Message], - ) -> Result { + messages: &mut Vec, + ) -> Result { // 最後のアシスタントメッセージを取得 let last = messages.iter().rev() .find(|m| m.role == Role::Assistant); @@ -236,25 +267,25 @@ impl WorkerHook for JsonValidator { // JSONとしてパースを試みる if serde_json::from_str::(text).is_err() { // 失敗したらリトライ指示 - return Ok(TurnResult::ContinueWithMessages(vec![ + return Ok(OnTurnEndResult::ContinueWithMessages(vec![ Message::user("Invalid JSON. Please fix and try again.") ])); } } } - Ok(TurnResult::Finish) + Ok(OnTurnEndResult::Finish) } } ``` ## 複数Hookの実行順序 -Hookは**登録順**に実行されます。 +Hookは**イベントごとに登録順**に実行されます。 ```rust -worker.add_hook(HookA); // 1番目に実行 -worker.add_hook(HookB); // 2番目に実行 -worker.add_hook(HookC); // 3番目に実行 +worker.add_before_tool_call_hook(HookA); // 1番目に実行 +worker.add_before_tool_call_hook(HookB); // 2番目に実行 +worker.add_before_tool_call_hook(HookC); // 3番目に実行 ``` ### 制御フローの伝播 @@ -262,6 +293,7 @@ worker.add_hook(HookC); // 3番目に実行 - `Continue`: 後続Hookも実行 - `Skip`: 現在の処理をスキップし、後続Hookは実行しない - `Abort`: 即座にエラーを返し、処理全体を中断 +- `Pause`: Workerを一時停止(再開は`resume`) ``` Hook A: Continue → Hook B: Skip → (Hook Cは実行されない) @@ -271,40 +303,27 @@ Hook A: Continue → Hook B: Skip → (Hook Cは実行されない) Hook A: Continue → Hook B: Abort("reason") ↓ WorkerError::Aborted + +Hook A: Continue → Hook B: Pause + ↓ + WorkerResult::Paused ``` ## 設計上のポイント -### 1. デフォルト実装 +### 1. イベントごとの実装 -全メソッドにデフォルト実装があるため、必要なメソッドだけオーバーライドすれば良い。 - -```rust -struct SimpleLogger; - -#[async_trait] -impl WorkerHook for SimpleLogger { - // on_message_send だけ実装 - async fn on_message_send( - &self, - context: &mut Vec, - ) -> Result { - println!("Sending {} messages", context.len()); - Ok(ControlFlow::Continue) - } - // 他のメソッドはデフォルト(Continue/Finish) -} -``` +必要なイベントのみ `Hook` を実装する。 ### 2. 可変参照による改変 `&mut`で引数を受け取るため、直接改変が可能。 ```rust -async fn before_tool_call(&self, tool_call: &mut ToolCall) -> ... { +async fn call(&self, ctx: &mut ToolCallContext) -> ... { // 引数を直接書き換え - tool_call.input["sanitized"] = json!(true); - Ok(ControlFlow::Continue) + ctx.call.input["sanitized"] = json!(true); + Ok(BeforeToolCallResult::Continue) } ``` @@ -316,7 +335,7 @@ async fn before_tool_call(&self, tool_call: &mut ToolCall) -> ... { ### 4. Send + Sync 要件 -`WorkerHook`は`Send + Sync`を要求するため、スレッドセーフな実装が必要。 +`Hook`は`Send + Sync`を要求するため、スレッドセーフな実装が必要。 状態を持つ場合は`Arc>`や`AtomicUsize`などを使用する。 ```rust @@ -325,24 +344,24 @@ struct CountingHook { } #[async_trait] -impl WorkerHook for CountingHook { - async fn before_tool_call(&self, _: &mut ToolCall) -> Result { +impl Hook for CountingHook { + async fn call(&self, _: &mut ToolCallContext) -> Result { self.count.fetch_add(1, Ordering::SeqCst); - Ok(ControlFlow::Continue) + Ok(BeforeToolCallResult::Continue) } } ``` ## 典型的なユースケース -| ユースケース | 使用Hook | 処理内容 | -|-------------|----------|----------| -| ツール許可制御 | `before_tool_call` | 危険なツールをSkip | -| 実行ログ | `before/after_tool_call` | 呼び出しと結果を記録 | -| 出力バリデーション | `on_turn_end` | 形式チェック、リトライ指示 | -| コンテキスト注入 | `on_message_send` | システムメッセージ追加 | -| 結果のサニタイズ | `after_tool_call` | 機密情報のマスキング | -| レート制限 | `before_tool_call` | 呼び出し頻度の制御 | +| ユースケース | 使用Hook | 処理内容 | +| ------------------ | ------------------------ | -------------------------- | +| ツール許可制御 | `before_tool_call` | 危険なツールをSkip | +| 実行ログ | `before/after_tool_call` | 呼び出しと結果を記録 | +| 出力バリデーション | `on_turn_end` | 形式チェック、リトライ指示 | +| コンテキスト注入 | `on_message_send` | システムメッセージ追加 | +| 結果のサニタイズ | `after_tool_call` | 機密情報のマスキング | +| レート制限 | `before_tool_call` | 呼び出し頻度の制御 | ## TODO @@ -350,11 +369,14 @@ impl WorkerHook for CountingHook { 現在のHooks実装は基本的なユースケースをカバーしているが、以下の点について将来的に厳密な仕様を定義する必要がある: -- **エラーハンドリングの明確化**: `HookError`発生時のリカバリー戦略、部分的な失敗の扱い +- **エラーハンドリングの明確化**: + `HookError`発生時のリカバリー戦略、部分的な失敗の扱い - **Hook間の依存関係**: 複数Hookの実行順序が結果に影響する場合のセマンティクス - **非同期キャンセル**: Hook実行中のキャンセル(タイムアウト等)の振る舞い -- **状態の一貫性**: `on_message_send`で改変されたコンテキストが後続処理で期待通りに反映される保証 -- **リトライ制限**: `on_turn_end`での`ContinueWithMessages`による無限ループ防止策 +- **状態の一貫性**: + `on_message_send`で改変されたコンテキストが後続処理で期待通りに反映される保証 +- **リトライ制限**: + `on_turn_end`での`ContinueWithMessages`による無限ループ防止策 - **Hook優先度**: 登録順以外の優先度指定メカニズムの必要性 - **条件付きHook**: 特定条件でのみ有効化されるHookパターン - **テスト容易性**: Hookのモック/スタブ作成のためのユーティリティ diff --git a/docs/spec/worker_design.md b/docs/spec/worker_design.md index 300f7ce..767b3eb 100644 --- a/docs/spec/worker_design.md +++ b/docs/spec/worker_design.md @@ -178,41 +178,60 @@ Workerは生成されたラッパー構造体を `Box` として保持 ```rust #[async_trait] -pub trait WorkerHook: Send + Sync { - /// メッセージ送信前。 - /// リクエストに含まれるメッセージリストを改変できる。 - async fn on_message_send(&self, context: &mut Vec) -> Result { - Ok(ControlFlow::Continue) - } - - /// ツール実行前。 - /// 実行をキャンセルしたり、引数を書き換えることができる。 - async fn before_tool_call(&self, tool_call: &mut ToolCall) -> Result { - Ok(ControlFlow::Continue) - } - - /// ツール実行後。 - /// 結果を書き換えたり、隠蔽したりできる。 - async fn after_tool_call(&self, tool_result: &mut ToolResult) -> Result { - Ok(ControlFlow::Continue) - } - - /// ターン終了時。 - /// 生成されたメッセージを検査し、必要ならリトライ(ContinueWithMessages)を指示できる。 - async fn on_turn_end(&self, messages: &[Message]) -> Result { - Ok(TurnResult::Finish) - } +pub trait Hook: Send + Sync { + async fn call(&self, input: &mut E::Input) -> Result; } -pub enum ControlFlow { +pub trait HookEventKind { + type Input; + type Output; +} + +pub struct OnMessageSend; +pub struct BeforeToolCall; +pub struct AfterToolCall; +pub struct OnTurnEnd; +pub struct OnAbort; + +pub enum OnMessageSendResult { Continue, - Skip, // Tool実行などをスキップ - Abort(String), // 処理中断 + Cancel(String), } -pub enum TurnResult { +pub enum BeforeToolCallResult { + Continue, + Skip, // Tool実行などをスキップ + Abort(String), // 処理中断 + Pause, +} + +pub enum AfterToolCallResult { + Continue, + Abort(String), +} + +pub enum OnTurnEndResult { Finish, ContinueWithMessages(Vec), // メッセージを追加してターン継続(自己修正など) + Paused, +} +``` + +### Tool Call Context + +`before_tool_call` / `after_tool_call` は、ツール実行の文脈を含む入力を受け取る。 + +```rust +pub struct ToolCallContext { + pub call: ToolCall, + pub meta: ToolMeta, // 不変メタデータ + pub tool: Arc, // 状態アクセス用 +} + +pub struct ToolResultContext { + pub result: ToolResult, + pub meta: ToolMeta, + pub tool: Arc, } ``` @@ -433,4 +452,3 @@ impl Worker { 3. **選択的購読**: on_*で必要なイベントだけ、またはSubscriberで一括 4. **累積イベントの追加**: Worker層でComplete系イベントを追加提供 5. **後方互換性**: 従来の`run()`も引き続き使用可能 - diff --git a/llm-worker/Cargo.toml b/llm-worker/Cargo.toml index a5e8d9f..529327c 100644 --- a/llm-worker/Cargo.toml +++ b/llm-worker/Cargo.toml @@ -15,7 +15,8 @@ tracing = "0.1" async-trait = "0.1" futures = "0.3" tokio = { version = "1.49", features = ["macros", "rt-multi-thread"] } -reqwest = { version = "0.13.1", default-features = false, features = ["stream", "json", "native-tls"] } +tokio-util = "0.7" +reqwest = { version = "0.13.1", default-features = false, features = ["stream", "json", "native-tls", "http2"] } eventsource-stream = "0.2" llm-worker-macros = { path = "../llm-worker-macros", version = "0.1" } diff --git a/llm-worker/examples/worker_cancel_demo.rs b/llm-worker/examples/worker_cancel_demo.rs new file mode 100644 index 0000000..ad9173c --- /dev/null +++ b/llm-worker/examples/worker_cancel_demo.rs @@ -0,0 +1,71 @@ +//! Worker のキャンセル機能のデモンストレーション +//! +//! ストリーミング受信中に別スレッドからキャンセルする例 + +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; +use llm_worker::{Worker, WorkerResult}; +use llm_worker::llm_client::providers::anthropic::AnthropicClient; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // .envファイルを読み込む + dotenv::dotenv().ok(); + + // ロギング初期化 + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")), + ) + .init(); + + let api_key = std::env::var("ANTHROPIC_API_KEY") + .expect("ANTHROPIC_API_KEY environment variable not set"); + + let client = AnthropicClient::new(&api_key, "claude-sonnet-4-20250514"); + let worker = Arc::new(Mutex::new(Worker::new(client))); + + println!("🚀 Starting Worker..."); + println!("💡 Will cancel after 2 seconds\n"); + + // キャンセルトークンを先に取得(ロックを保持しない) + let cancel_token = { + let w = worker.lock().await; + w.cancellation_token().clone() + }; + + // タスク1: Workerを実行 + let worker_clone = worker.clone(); + let task = tokio::spawn(async move { + let mut w = worker_clone.lock().await; + println!("📡 Sending request to LLM..."); + + match w.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await { + Ok(WorkerResult::Finished(_)) => { + println!("✅ Task completed normally"); + } + Ok(WorkerResult::Paused(_)) => { + println!("⏸️ Task paused"); + } + Err(e) => { + println!("❌ Task error: {}", e); + } + } + }); + + // タスク2: 2秒後にキャンセル + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(2)).await; + println!("\n🛑 Cancelling worker..."); + cancel_token.cancel(); + }); + + // タスク完了を待つ + task.await?; + + println!("\n✨ Demo complete!"); + + Ok(()) +} diff --git a/llm-worker/examples/worker_cli.rs b/llm-worker/examples/worker_cli.rs index 856af11..a576871 100644 --- a/llm-worker/examples/worker_cli.rs +++ b/llm-worker/examples/worker_cli.rs @@ -41,7 +41,7 @@ use tracing_subscriber::EnvFilter; use clap::{Parser, ValueEnum}; use llm_worker::{ Worker, - hook::{ControlFlow, HookError, ToolResult, WorkerHook}, + hook::{AfterToolCall, AfterToolCallResult, Hook, HookError, ToolResult}, llm_client::{ LlmClient, providers::{ @@ -282,11 +282,11 @@ impl ToolResultPrinterHook { } #[async_trait] -impl WorkerHook for ToolResultPrinterHook { - async fn after_tool_call( +impl Hook for ToolResultPrinterHook { + async fn call( &self, tool_result: &mut ToolResult, - ) -> Result { + ) -> Result { let name = self .call_names .lock() @@ -300,7 +300,7 @@ impl WorkerHook for ToolResultPrinterHook { println!(" Result ({}): ✅ {}", name, tool_result.content); } - Ok(ControlFlow::Continue) + Ok(AfterToolCallResult::Continue) } } @@ -451,7 +451,7 @@ async fn main() -> Result<(), Box> { .on_text_block(StreamingPrinter::new()) .on_tool_use_block(ToolCallPrinter::new(tool_call_names.clone())); - worker.add_hook(ToolResultPrinterHook::new(tool_call_names)); + worker.add_after_tool_call_hook(ToolResultPrinterHook::new(tool_call_names)); // ワンショットモード if let Some(prompt) = args.prompt { diff --git a/llm-worker/src/hook.rs b/llm-worker/src/hook.rs index 25b0fc5..1007dec 100644 --- a/llm-worker/src/hook.rs +++ b/llm-worker/src/hook.rs @@ -8,33 +8,72 @@ use serde_json::Value; use thiserror::Error; // ============================================================================= -// Control Flow Types +// Hook Event Kinds // ============================================================================= -/// Hook処理の制御フロー +pub trait HookEventKind: Send + Sync + 'static { + type Input; + type Output; +} + +pub struct OnMessageSend; +pub struct BeforeToolCall; +pub struct AfterToolCall; +pub struct OnTurnEnd; +pub struct OnAbort; + #[derive(Debug, Clone, PartialEq, Eq)] -pub enum ControlFlow { - /// 処理を続行 +pub enum OnMessageSendResult { + Continue, + Cancel(String), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum BeforeToolCallResult { Continue, - /// 現在の処理をスキップ(Tool実行など) Skip, - /// 処理を中断 Abort(String), - /// 処理を一時停止(再開可能) Pause, } -/// ターン終了時の判定結果 +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AfterToolCallResult { + Continue, + Abort(String), +} + #[derive(Debug, Clone)] -pub enum TurnResult { - /// ターンを終了 +pub enum OnTurnEndResult { Finish, - /// メッセージを追加してターン継続(自己修正など) ContinueWithMessages(Vec), - /// ターンを一時停止 Paused, } +impl HookEventKind for OnMessageSend { + type Input = Vec; + type Output = OnMessageSendResult; +} + +impl HookEventKind for BeforeToolCall { + type Input = ToolCall; + type Output = BeforeToolCallResult; +} + +impl HookEventKind for AfterToolCall { + type Input = ToolResult; + type Output = AfterToolCallResult; +} + +impl HookEventKind for OnTurnEnd { + type Input = Vec; + type Output = OnTurnEndResult; +} + +impl HookEventKind for OnAbort { + type Input = String; + type Output = (); +} + // ============================================================================= // Tool Call / Result Types // ============================================================================= @@ -102,85 +141,13 @@ pub enum HookError { } // ============================================================================= -// WorkerHook Trait +// Hook Trait // ============================================================================= -/// ターンの進行・ツール実行に介入するためのトレイト +/// Hookイベントの処理を行うトレイト /// -/// Hookを使うと、メッセージ送信前、ツール実行前後、ターン終了時に -/// 処理を挟んだり、実行をキャンセルしたりできます。 -/// -/// # Examples -/// -/// ```ignore -/// use llm_worker::hook::{ControlFlow, HookError, ToolCall, TurnResult, WorkerHook}; -/// use llm_worker::Message; -/// -/// struct ValidationHook; -/// -/// #[async_trait::async_trait] -/// impl WorkerHook for ValidationHook { -/// async fn before_tool_call(&self, call: &mut ToolCall) -> Result { -/// // 危険なツールをブロック -/// if call.name == "delete_all" { -/// return Ok(ControlFlow::Skip); -/// } -/// Ok(ControlFlow::Continue) -/// } -/// -/// async fn on_turn_end(&self, messages: &[Message]) -> Result { -/// // 条件を満たさなければ追加メッセージで継続 -/// if messages.len() < 3 { -/// return Ok(TurnResult::ContinueWithMessages(vec![ -/// Message::user("Please elaborate.") -/// ])); -/// } -/// Ok(TurnResult::Finish) -/// } -/// } -/// ``` -/// -/// # デフォルト実装 -/// -/// すべてのメソッドにはデフォルト実装があり、何も行わず`Continue`を返します。 -/// 必要なメソッドのみオーバーライドしてください。 +/// 各イベント種別は戻り値型が異なるため、`HookEventKind`を介して型を制約する。 #[async_trait] -pub trait WorkerHook: Send + Sync { - /// メッセージ送信前に呼ばれる - /// - /// リクエストに含まれるメッセージリストを参照・改変できます。 - /// `ControlFlow::Abort`を返すとターンが中断されます。 - async fn on_message_send( - &self, - _context: &mut Vec, - ) -> Result { - Ok(ControlFlow::Continue) - } - - /// ツール実行前に呼ばれる - /// - /// ツール呼び出しの引数を書き換えたり、実行をスキップしたりできます。 - /// `ControlFlow::Skip`を返すとこのツールの実行がスキップされます。 - async fn before_tool_call(&self, _tool_call: &mut ToolCall) -> Result { - Ok(ControlFlow::Continue) - } - - /// ツール実行後に呼ばれる - /// - /// ツールの実行結果を書き換えたり、隠蔽したりできます。 - async fn after_tool_call( - &self, - _tool_result: &mut ToolResult, - ) -> Result { - Ok(ControlFlow::Continue) - } - - /// ターン終了時に呼ばれる - /// - /// 生成されたメッセージを検査し、必要なら追加メッセージで継続を指示できます。 - /// `TurnResult::ContinueWithMessages`を返すと、指定したメッセージを追加して - /// 次のターンに進みます。 - async fn on_turn_end(&self, _messages: &[crate::Message]) -> Result { - Ok(TurnResult::Finish) - } +pub trait Hook: Send + Sync { + async fn call(&self, input: &mut E::Input) -> Result; } diff --git a/llm-worker/src/lib.rs b/llm-worker/src/lib.rs index 5668503..ef743b4 100644 --- a/llm-worker/src/lib.rs +++ b/llm-worker/src/lib.rs @@ -6,7 +6,7 @@ //! //! - [`Worker`] - LLMとの対話を管理する中心コンポーネント //! - [`tool::Tool`] - LLMから呼び出し可能なツール -//! - [`hook::WorkerHook`] - ターン進行への介入 +//! - [`hook::Hook`] - ターン進行への介入 //! - [`subscriber::WorkerSubscriber`] - ストリーミングイベントの購読 //! //! # Quick Start @@ -48,9 +48,5 @@ pub mod subscriber; pub mod timeline; pub mod tool; -// ============================================================================= -// トップレベル公開(最も頻繁に使う型) -// ============================================================================= - pub use message::{ContentPart, Message, MessageContent, Role}; -pub use worker::{Worker, WorkerConfig, WorkerError}; +pub use worker::{Worker, WorkerConfig, WorkerError, WorkerResult}; diff --git a/llm-worker/src/subscriber.rs b/llm-worker/src/subscriber.rs index d750e63..0e7cdf6 100644 --- a/llm-worker/src/subscriber.rs +++ b/llm-worker/src/subscriber.rs @@ -65,10 +65,10 @@ pub trait WorkerSubscriber: Send { /// /// ブロック開始時にDefault::default()で生成され、 /// ブロック終了時に破棄される。 - type TextBlockScope: Default + Send; + type TextBlockScope: Default + Send + Sync; /// ツール使用ブロック処理用のスコープ型 - type ToolUseBlockScope: Default + Send; + type ToolUseBlockScope: Default + Send + Sync; // ========================================================================= // ブロックイベント(スコープ管理あり) diff --git a/llm-worker/src/timeline/timeline.rs b/llm-worker/src/timeline/timeline.rs index 8ce92f0..8bb18f8 100644 --- a/llm-worker/src/timeline/timeline.rs +++ b/llm-worker/src/timeline/timeline.rs @@ -17,7 +17,7 @@ use crate::handler::*; /// 各Handlerは独自のScope型を持つため、Timelineで保持するには型消去が必要です。 /// 通常は直接使用せず、`Timeline::on_text_block()`などのメソッド経由で /// 自動的にラップされます。 -pub trait ErasedHandler: Send { +pub trait ErasedHandler: Send + Sync { /// イベントをディスパッチ fn dispatch(&mut self, event: &K::Event); /// スコープを開始(Block開始時) @@ -54,9 +54,9 @@ where impl ErasedHandler for HandlerWrapper where - H: Handler + Send, + H: Handler + Send + Sync, K: Kind, - H::Scope: Send, + H::Scope: Send + Sync, { fn dispatch(&mut self, event: &K::Event) { if let Some(scope) = &mut self.scope { @@ -78,7 +78,7 @@ where // ============================================================================= /// ブロックハンドラーの型消去trait -trait ErasedBlockHandler: Send { +trait ErasedBlockHandler: Send + Sync { fn dispatch_start(&mut self, start: &BlockStart); fn dispatch_delta(&mut self, delta: &BlockDelta); fn dispatch_stop(&mut self, stop: &BlockStop); @@ -112,8 +112,8 @@ where impl ErasedBlockHandler for TextBlockHandlerWrapper where - H: Handler + Send, - H::Scope: Send, + H: Handler + Send + Sync, + H::Scope: Send + Sync, { fn dispatch_start(&mut self, start: &BlockStart) { if let Some(scope) = &mut self.scope { @@ -185,8 +185,8 @@ where impl ErasedBlockHandler for ThinkingBlockHandlerWrapper where - H: Handler + Send, - H::Scope: Send, + H: Handler + Send + Sync, + H::Scope: Send + Sync, { fn dispatch_start(&mut self, start: &BlockStart) { if let Some(scope) = &mut self.scope { @@ -255,8 +255,8 @@ where impl ErasedBlockHandler for ToolUseBlockHandlerWrapper where - H: Handler + Send, - H::Scope: Send, + H: Handler + Send + Sync, + H::Scope: Send + Sync, { fn dispatch_start(&mut self, start: &BlockStart) { if let Some(scope) = &mut self.scope { @@ -391,8 +391,8 @@ impl Timeline { /// UsageKind用のHandlerを登録 pub fn on_usage(&mut self, handler: H) -> &mut Self where - H: Handler + Send + 'static, - H::Scope: Send, + H: Handler + Send + Sync + 'static, + H::Scope: Send + Sync, { // Meta系はデフォルトでスコープを開始しておく let mut wrapper = HandlerWrapper::new(handler); @@ -404,8 +404,8 @@ impl Timeline { /// PingKind用のHandlerを登録 pub fn on_ping(&mut self, handler: H) -> &mut Self where - H: Handler + Send + 'static, - H::Scope: Send, + H: Handler + Send + Sync + 'static, + H::Scope: Send + Sync, { let mut wrapper = HandlerWrapper::new(handler); wrapper.start_scope(); @@ -416,8 +416,8 @@ impl Timeline { /// StatusKind用のHandlerを登録 pub fn on_status(&mut self, handler: H) -> &mut Self where - H: Handler + Send + 'static, - H::Scope: Send, + H: Handler + Send + Sync + 'static, + H::Scope: Send + Sync, { let mut wrapper = HandlerWrapper::new(handler); wrapper.start_scope(); @@ -428,8 +428,8 @@ impl Timeline { /// ErrorKind用のHandlerを登録 pub fn on_error(&mut self, handler: H) -> &mut Self where - H: Handler + Send + 'static, - H::Scope: Send, + H: Handler + Send + Sync + 'static, + H::Scope: Send + Sync, { let mut wrapper = HandlerWrapper::new(handler); wrapper.start_scope(); @@ -440,8 +440,8 @@ impl Timeline { /// TextBlockKind用のHandlerを登録 pub fn on_text_block(&mut self, handler: H) -> &mut Self where - H: Handler + Send + 'static, - H::Scope: Send, + H: Handler + Send + Sync + 'static, + H::Scope: Send + Sync, { self.text_block_handlers .push(Box::new(TextBlockHandlerWrapper::new(handler))); @@ -451,8 +451,8 @@ impl Timeline { /// ThinkingBlockKind用のHandlerを登録 pub fn on_thinking_block(&mut self, handler: H) -> &mut Self where - H: Handler + Send + 'static, - H::Scope: Send, + H: Handler + Send + Sync + 'static, + H::Scope: Send + Sync, { self.thinking_block_handlers .push(Box::new(ThinkingBlockHandlerWrapper::new(handler))); @@ -462,8 +462,8 @@ impl Timeline { /// ToolUseBlockKind用のHandlerを登録 pub fn on_tool_use_block(&mut self, handler: H) -> &mut Self where - H: Handler + Send + 'static, - H::Scope: Send, + H: Handler + Send + Sync + 'static, + H::Scope: Send + Sync, { self.tool_use_block_handlers .push(Box::new(ToolUseBlockHandlerWrapper::new(handler))); @@ -578,6 +578,21 @@ impl Timeline { pub fn current_block(&self) -> Option { self.current_block } + + /// 現在アクティブなブロックを中断する + /// + /// キャンセルやエラー時に呼び出し、進行中のブロックに対して + /// BlockAbortイベントを発火してスコープをクリーンアップする。 + pub fn abort_current_block(&mut self) { + if let Some(block_type) = self.current_block { + let abort = crate::timeline::event::BlockAbort { + index: 0, // インデックスは不明なので0 + block_type, + reason: "Cancelled".to_string(), + }; + self.handle_block_abort(&abort); + } + } } #[cfg(test)] diff --git a/llm-worker/src/worker.rs b/llm-worker/src/worker.rs index 5425f71..0b413c0 100644 --- a/llm-worker/src/worker.rs +++ b/llm-worker/src/worker.rs @@ -3,11 +3,16 @@ use std::marker::PhantomData; use std::sync::{Arc, Mutex}; use futures::StreamExt; +use tokio_util::sync::CancellationToken; use tracing::{debug, info, trace, warn}; use crate::{ ContentPart, Message, MessageContent, Role, - hook::{ControlFlow, HookError, ToolCall, ToolResult, TurnResult, WorkerHook}, + hook::{ + AfterToolCall, AfterToolCallResult, BeforeToolCall, BeforeToolCallResult, Hook, HookError, + OnAbort, OnMessageSend, OnMessageSendResult, OnTurnEnd, OnTurnEndResult, ToolCall, + ToolResult, + }, llm_client::{ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ToolDefinition}, state::{Locked, Mutable, WorkerState}, subscriber::{ @@ -37,6 +42,9 @@ pub enum WorkerError { /// 処理が中断された #[error("Aborted: {0}")] Aborted(String), + /// Cancellation Tokenによって中断された + #[error("Cancelled")] + Cancelled, /// 設定に関する警告(未サポートのオプション) #[error("Config warnings: {}", .0.iter().map(|w| w.to_string()).collect::>().join(", "))] ConfigWarnings(Vec), @@ -77,7 +85,7 @@ enum ToolExecutionResult { // ============================================================================= /// ターンイベントを通知するためのコールバック (型消去) -trait TurnNotifier: Send { +trait TurnNotifier: Send + Sync { fn on_turn_start(&self, turn: usize); fn on_turn_end(&self, turn: usize); } @@ -149,8 +157,16 @@ pub struct Worker { tool_call_collector: ToolCallCollector, /// 登録されたツール tools: HashMap>, - /// 登録されたHook - hooks: Vec>, + /// on_message_send Hook + hooks_on_message_send: Vec>>, + /// before_tool_call Hook + hooks_before_tool_call: Vec>>, + /// after_tool_call Hook + hooks_after_tool_call: Vec>>, + /// on_turn_end Hook + hooks_on_turn_end: Vec>>, + /// on_abort Hook + hooks_on_abort: Vec>>, /// システムプロンプト system_prompt: Option, /// メッセージ履歴(Workerが所有) @@ -163,6 +179,8 @@ pub struct Worker { turn_notifiers: Vec>, /// リクエスト設定(max_tokens, temperature等) request_config: RequestConfig, + /// キャンセレーショントークン(実行中断用) + cancellation_token: CancellationToken, /// 状態マーカー _state: PhantomData, } @@ -252,30 +270,29 @@ impl Worker { } } - /// Hookを追加する - /// - /// Hookはターンの進行・ツール実行に介入できます。 - /// 複数のHookを登録した場合、登録順に実行されます。 - /// - /// # Examples - /// - /// ```ignore - /// use llm_worker::{Worker, WorkerHook, ControlFlow, ToolCall}; - /// - /// struct LoggingHook; - /// - /// #[async_trait::async_trait] - /// impl WorkerHook for LoggingHook { - /// async fn before_tool_call(&self, call: &mut ToolCall) -> Result { - /// println!("Calling tool: {}", call.name); - /// Ok(ControlFlow::Continue) - /// } - /// } - /// - /// worker.add_hook(LoggingHook); - /// ``` - pub fn add_hook(&mut self, hook: impl WorkerHook + 'static) { - self.hooks.push(Box::new(hook)); + /// on_message_send Hookを追加する + pub fn add_on_message_send_hook(&mut self, hook: impl Hook + 'static) { + self.hooks_on_message_send.push(Box::new(hook)); + } + + /// before_tool_call Hookを追加する + pub fn add_before_tool_call_hook(&mut self, hook: impl Hook + 'static) { + self.hooks_before_tool_call.push(Box::new(hook)); + } + + /// after_tool_call Hookを追加する + pub fn add_after_tool_call_hook(&mut self, hook: impl Hook + 'static) { + self.hooks_after_tool_call.push(Box::new(hook)); + } + + /// on_turn_end Hookを追加する + pub fn add_on_turn_end_hook(&mut self, hook: impl Hook + 'static) { + self.hooks_on_turn_end.push(Box::new(hook)); + } + + /// on_abort Hookを追加する + pub fn add_on_abort_hook(&mut self, hook: impl Hook + 'static) { + self.hooks_on_abort.push(Box::new(hook)); } /// タイムラインへの可変参照を取得(追加ハンドラ登録用) @@ -375,6 +392,41 @@ impl Worker { self.request_config = config; } + /// 実行をキャンセルする + /// + /// 現在実行中のストリーミングやツール実行を中断します。 + /// 次のイベントループのチェックポイントでWorkerError::Cancelledが返されます。 + /// + /// # Examples + /// + /// ```ignore + /// use std::sync::Arc; + /// let worker = Arc::new(Mutex::new(Worker::new(client))); + /// + /// // 別スレッドで実行 + /// let worker_clone = worker.clone(); + /// tokio::spawn(async move { + /// let mut w = worker_clone.lock().unwrap(); + /// w.run("Long task...").await + /// }); + /// + /// // キャンセル + /// worker.lock().unwrap().cancel(); + /// ``` + pub fn cancel(&self) { + self.cancellation_token.cancel(); + } + + /// キャンセルされているかチェック + pub fn is_cancelled(&self) -> bool { + self.cancellation_token.is_cancelled() + } + + /// キャンセレーショントークンへの参照を取得 + pub fn cancellation_token(&self) -> &CancellationToken { + &self.cancellation_token + } + /// 登録されたツールからToolDefinitionのリストを生成 fn build_tool_definitions(&self) -> Vec { self.tools @@ -430,7 +482,7 @@ impl Worker { } /// リクエストを構築 - fn build_request(&self, tool_definitions: &[ToolDefinition]) -> Request { + fn build_request(&self, tool_definitions: &[ToolDefinition], context: &[Message]) -> Request { let mut request = Request::new(); // システムプロンプトを設定 @@ -439,7 +491,7 @@ impl Worker { } // メッセージを追加 - for msg in &self.history { + for msg in context { // Message から llm_client::Message への変換 request = request.message(crate::llm_client::Message { role: match msg.role { @@ -495,36 +547,45 @@ impl Worker { } /// Hooks: on_message_send - async fn run_on_message_send_hooks(&self) -> Result { - for hook in &self.hooks { - // Note: Locked状態でも履歴全体を参照として渡す(変更は不可) - // HookのAPIを変更し、immutable参照のみを渡すようにする必要があるかもしれない - // 現在は空のVecを渡して回避(要検討) - let mut temp_context = self.history.clone(); - let result = hook.on_message_send(&mut temp_context).await?; + async fn run_on_message_send_hooks( + &self, + ) -> Result<(OnMessageSendResult, Vec), WorkerError> { + let mut temp_context = self.history.clone(); + for hook in &self.hooks_on_message_send { + let result = hook.call(&mut temp_context).await?; match result { - ControlFlow::Continue => continue, - ControlFlow::Skip => return Ok(ControlFlow::Skip), - ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)), - ControlFlow::Pause => return Ok(ControlFlow::Pause), + OnMessageSendResult::Continue => continue, + OnMessageSendResult::Cancel(reason) => { + return Ok((OnMessageSendResult::Cancel(reason), temp_context)); + } } } - Ok(ControlFlow::Continue) + Ok((OnMessageSendResult::Continue, temp_context)) } /// Hooks: on_turn_end - async fn run_on_turn_end_hooks(&self) -> Result { - for hook in &self.hooks { - let result = hook.on_turn_end(&self.history).await?; + async fn run_on_turn_end_hooks(&self) -> Result { + let mut temp_messages = self.history.clone(); + for hook in &self.hooks_on_turn_end { + let result = hook.call(&mut temp_messages).await?; match result { - TurnResult::Finish => continue, - TurnResult::ContinueWithMessages(msgs) => { - return Ok(TurnResult::ContinueWithMessages(msgs)); + OnTurnEndResult::Finish => continue, + OnTurnEndResult::ContinueWithMessages(msgs) => { + return Ok(OnTurnEndResult::ContinueWithMessages(msgs)); } - TurnResult::Paused => return Ok(TurnResult::Paused), + OnTurnEndResult::Paused => return Ok(OnTurnEndResult::Paused), } } - Ok(TurnResult::Finish) + Ok(OnTurnEndResult::Finish) + } + + /// Hooks: on_abort + async fn run_on_abort_hooks(&self, reason: &str) -> Result<(), WorkerError> { + let mut reason = reason.to_string(); + for hook in &self.hooks_on_abort { + hook.call(&mut reason).await?; + } + Ok(()) } /// 未実行のツール呼び出しがあるかチェック(Pauseからの復帰用) @@ -559,7 +620,7 @@ impl Worker { /// 全てのツールに対してbefore_tool_callフックを実行後、 /// 許可されたツールを並列に実行し、結果にafter_tool_callフックを適用する。 async fn execute_tools( - &self, + &mut self, tool_calls: Vec, ) -> Result { use futures::future::join_all; @@ -568,18 +629,18 @@ impl Worker { let mut approved_calls = Vec::new(); for mut tool_call in tool_calls { let mut skip = false; - for hook in &self.hooks { - let result = hook.before_tool_call(&mut tool_call).await?; + for hook in &self.hooks_before_tool_call { + let result = hook.call(&mut tool_call).await?; match result { - ControlFlow::Continue => {} - ControlFlow::Skip => { + BeforeToolCallResult::Continue => {} + BeforeToolCallResult::Skip => { skip = true; break; } - ControlFlow::Abort(reason) => { + BeforeToolCallResult::Abort(reason) => { return Err(WorkerError::Aborted(reason)); } - ControlFlow::Pause => { + BeforeToolCallResult::Pause => { return Ok(ToolExecutionResult::Paused); } } @@ -589,7 +650,7 @@ impl Worker { } } - // Phase 2: 許可されたツールを並列実行 + // Phase 2: 許可されたツールを並列実行(キャンセル可能) let futures: Vec<_> = approved_calls .into_iter() .map(|tool_call| { @@ -612,25 +673,26 @@ impl Worker { }) .collect(); - let mut results = join_all(futures).await; + // ツール実行をキャンセル可能にする + let mut results = tokio::select! { + results = join_all(futures) => results, + _ = self.cancellation_token.cancelled() => { + info!("Tool execution cancelled"); + self.timeline.abort_current_block(); + self.run_on_abort_hooks("Cancelled").await?; + return Err(WorkerError::Cancelled); + } + }; // Phase 3: after_tool_call フックを適用 for tool_result in &mut results { - for hook in &self.hooks { - let result = hook.after_tool_call(tool_result).await?; + for hook in &self.hooks_after_tool_call { + let result = hook.call(tool_result).await?; match result { - ControlFlow::Continue => {} - ControlFlow::Skip => break, - ControlFlow::Abort(reason) => { + AfterToolCallResult::Continue => {} + AfterToolCallResult::Abort(reason) => { return Err(WorkerError::Aborted(reason)); } - ControlFlow::Pause => { - // after_tool_callでのPauseは結果を受け入れた後、次の処理前に止まる動作とする - // ここではContinue扱いとし、on_message_send等でPauseすることを期待する - // あるいはここでのPauseをサポートする場合は戻り値を調整する必要がある - // 現状はログを出してContinue - warn!("ControlFlow::Pause in after_tool_call is treated as Continue"); - } } } } @@ -663,6 +725,14 @@ impl Worker { } loop { + // キャンセルチェック + if self.cancellation_token.is_cancelled() { + info!("Execution cancelled"); + self.timeline.abort_current_block(); + self.run_on_abort_hooks("Cancelled").await?; + return Err(WorkerError::Cancelled); + } + // ターン開始を通知 let current_turn = self.turn_count; debug!(turn = current_turn, "Turn start"); @@ -671,24 +741,21 @@ impl Worker { } // Hook: on_message_send - let control = self.run_on_message_send_hooks().await?; + let (control, request_context) = self.run_on_message_send_hooks().await?; match control { - ControlFlow::Abort(reason) => { - warn!(reason = %reason, "Aborted by hook"); + OnMessageSendResult::Cancel(reason) => { + info!(reason = %reason, "Aborted by hook"); for notifier in &self.turn_notifiers { notifier.on_turn_end(current_turn); } + self.run_on_abort_hooks(&reason).await?; return Err(WorkerError::Aborted(reason)); } - ControlFlow::Pause | ControlFlow::Skip => { - // Skip or Pause -> Pause the worker - return Ok(WorkerResult::Paused(&self.history)); - } - ControlFlow::Continue => {} + OnMessageSendResult::Continue => {} } // リクエスト構築 - let request = self.build_request(&tool_definitions); + let request = self.build_request(&tool_definitions, &request_context); debug!( message_count = request.messages.len(), tool_count = request.tools.len(), @@ -698,21 +765,49 @@ impl Worker { // ストリーム処理 debug!("Starting stream..."); - let mut stream = self.client.stream(request).await?; let mut event_count = 0; - while let Some(event_result) = stream.next().await { - match &event_result { - Ok(event) => { - trace!(event = ?event, "Received event"); - event_count += 1; + + // ストリームを取得(キャンセル可能) + let mut stream = tokio::select! { + stream_result = self.client.stream(request) => stream_result?, + _ = self.cancellation_token.cancelled() => { + info!("Cancelled before stream started"); + self.timeline.abort_current_block(); + self.run_on_abort_hooks("Cancelled").await?; + return Err(WorkerError::Cancelled); + } + }; + + loop { + tokio::select! { + // ストリームからイベントを受信 + event_result = stream.next() => { + match event_result { + Some(result) => { + match &result { + Ok(event) => { + trace!(event = ?event, "Received event"); + event_count += 1; + } + Err(e) => { + warn!(error = %e, "Stream error"); + } + } + let event = result?; + let timeline_event: crate::timeline::event::Event = event.into(); + self.timeline.dispatch(&timeline_event); + } + None => break, // ストリーム終了 + } } - Err(e) => { - warn!(error = %e, "Stream error"); + // キャンセル待機 + _ = self.cancellation_token.cancelled() => { + info!("Stream cancelled"); + self.timeline.abort_current_block(); + self.run_on_abort_hooks("Cancelled").await?; + return Err(WorkerError::Cancelled); } } - let event = event_result?; - let timeline_event: crate::timeline::event::Event = event.into(); - self.timeline.dispatch(&timeline_event); } debug!(event_count = event_count, "Stream completed"); @@ -736,14 +831,14 @@ impl Worker { // ツール呼び出しなし → ターン終了判定 let turn_result = self.run_on_turn_end_hooks().await?; match turn_result { - TurnResult::Finish => { + OnTurnEndResult::Finish => { return Ok(WorkerResult::Finished(&self.history)); } - TurnResult::ContinueWithMessages(additional) => { + OnTurnEndResult::ContinueWithMessages(additional) => { self.history.extend(additional); continue; } - TurnResult::Paused => { + OnTurnEndResult::Paused => { return Ok(WorkerResult::Paused(&self.history)); } } @@ -790,13 +885,18 @@ impl Worker { text_block_collector, tool_call_collector, tools: HashMap::new(), - hooks: Vec::new(), + hooks_on_message_send: Vec::new(), + hooks_before_tool_call: Vec::new(), + hooks_after_tool_call: Vec::new(), + hooks_on_turn_end: Vec::new(), + hooks_on_abort: Vec::new(), system_prompt: None, history: Vec::new(), locked_prefix_len: 0, turn_count: 0, turn_notifiers: Vec::new(), request_config: RequestConfig::default(), + cancellation_token: CancellationToken::new(), _state: PhantomData, } } @@ -958,13 +1058,18 @@ impl Worker { text_block_collector: self.text_block_collector, tool_call_collector: self.tool_call_collector, tools: self.tools, - hooks: self.hooks, + hooks_on_message_send: self.hooks_on_message_send, + hooks_before_tool_call: self.hooks_before_tool_call, + hooks_after_tool_call: self.hooks_after_tool_call, + hooks_on_turn_end: self.hooks_on_turn_end, + hooks_on_abort: self.hooks_on_abort, system_prompt: self.system_prompt, history: self.history, locked_prefix_len, turn_count: self.turn_count, turn_notifiers: self.turn_notifiers, request_config: self.request_config, + cancellation_token: self.cancellation_token, _state: PhantomData, } } @@ -1032,13 +1137,18 @@ impl Worker { text_block_collector: self.text_block_collector, tool_call_collector: self.tool_call_collector, tools: self.tools, - hooks: self.hooks, + hooks_on_message_send: self.hooks_on_message_send, + hooks_before_tool_call: self.hooks_before_tool_call, + hooks_after_tool_call: self.hooks_after_tool_call, + hooks_on_turn_end: self.hooks_on_turn_end, + hooks_on_abort: self.hooks_on_abort, system_prompt: self.system_prompt, history: self.history, locked_prefix_len: 0, turn_count: self.turn_count, turn_notifiers: self.turn_notifiers, request_config: self.request_config, + cancellation_token: self.cancellation_token, _state: PhantomData, } } diff --git a/llm-worker/tests/parallel_execution_test.rs b/llm-worker/tests/parallel_execution_test.rs index ac92315..485df23 100644 --- a/llm-worker/tests/parallel_execution_test.rs +++ b/llm-worker/tests/parallel_execution_test.rs @@ -8,7 +8,10 @@ use std::time::{Duration, Instant}; use async_trait::async_trait; use llm_worker::Worker; -use llm_worker::hook::{ControlFlow, HookError, ToolCall, ToolResult, WorkerHook}; +use llm_worker::hook::{ + AfterToolCall, AfterToolCallResult, BeforeToolCall, BeforeToolCallResult, Hook, HookError, + ToolCall, ToolResult, +}; use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent}; use llm_worker::tool::{Tool, ToolError}; @@ -158,20 +161,17 @@ async fn test_before_tool_call_skip() { struct BlockingHook; #[async_trait] - impl WorkerHook for BlockingHook { - async fn before_tool_call( - &self, - tool_call: &mut ToolCall, - ) -> Result { + impl Hook for BlockingHook { + async fn call(&self, tool_call: &mut ToolCall) -> Result { if tool_call.name == "blocked_tool" { - Ok(ControlFlow::Skip) + Ok(BeforeToolCallResult::Skip) } else { - Ok(ControlFlow::Continue) + Ok(BeforeToolCallResult::Continue) } } } - worker.add_hook(BlockingHook); + worker.add_before_tool_call_hook(BlockingHook); let _result = worker.run("Test hook").await; @@ -242,19 +242,19 @@ async fn test_after_tool_call_modification() { } #[async_trait] - impl WorkerHook for ModifyingHook { - async fn after_tool_call( + impl Hook for ModifyingHook { + async fn call( &self, tool_result: &mut ToolResult, - ) -> Result { + ) -> Result { tool_result.content = format!("[Modified] {}", tool_result.content); *self.modified_content.lock().unwrap() = Some(tool_result.content.clone()); - Ok(ControlFlow::Continue) + Ok(AfterToolCallResult::Continue) } } let modified_content = Arc::new(std::sync::Mutex::new(None)); - worker.add_hook(ModifyingHook { + worker.add_after_tool_call_hook(ModifyingHook { modified_content: modified_content.clone(), });