From 5691b09fc810dc966dbbe8d8002915b4b9fb64cb Mon Sep 17 00:00:00 2001 From: Hare Date: Fri, 9 Jan 2026 19:18:20 +0900 Subject: [PATCH 1/4] feat: Implement HookEventKind --- Cargo.lock | 239 ++++++++++++++- README.md | 1 - docs/spec/basis.md | 2 +- docs/spec/cancellation.md | 90 ++++++ docs/spec/hooks_design.md | 286 +++++++++--------- docs/spec/worker_design.md | 76 +++-- llm-worker/Cargo.toml | 3 +- llm-worker/examples/worker_cancel_demo.rs | 71 +++++ llm-worker/examples/worker_cli.rs | 12 +- llm-worker/src/hook.rs | 145 ++++------ llm-worker/src/lib.rs | 8 +- llm-worker/src/subscriber.rs | 4 +- llm-worker/src/timeline/timeline.rs | 63 ++-- llm-worker/src/worker.rs | 304 +++++++++++++------- llm-worker/tests/parallel_execution_test.rs | 28 +- 15 files changed, 916 insertions(+), 416 deletions(-) create mode 100644 docs/spec/cancellation.md create mode 100644 llm-worker/examples/worker_cancel_demo.rs diff --git a/Cargo.lock b/Cargo.lock index 0a959f7..76cefe8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -104,9 +104,9 @@ checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" [[package]] name = "cc" -version = "1.2.51" +version = "1.2.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a0aeaff4ff1a90589618835a598e545176939b97874f7abc7851caa0618f203" +checksum = "cd4932aefd12402b36c60956a4fe0035421f544799057659ff86f923657aada3" dependencies = [ "find-msvc-tools", "shlex", @@ -203,6 +203,12 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + [[package]] name = "errno" version = "0.3.14" @@ -232,9 +238,15 @@ checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "find-msvc-tools" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645cbb3a84e60b7531617d5ae4e57f7e27308f6445f5abf653209ea76dec8dff" +checksum = "f449e6c6c08c865631d4890cfacf252b3d396c9bcc83adb6623cdb02a8336c41" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "foreign-types" @@ -349,6 +361,17 @@ dependencies = [ "slab", ] +[[package]] +name = "getrandom" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "getrandom" version = "0.3.4" @@ -361,6 +384,31 @@ dependencies = [ "wasip2", ] +[[package]] +name = "h2" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + [[package]] name = "heck" version = "0.5.0" @@ -416,6 +464,7 @@ dependencies = [ "bytes", "futures-channel", "futures-core", + "h2", "http", "http-body", "httparse", @@ -427,6 +476,22 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + [[package]] name = "hyper-tls" version = "0.6.0" @@ -569,6 +634,16 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -648,6 +723,7 @@ dependencies = [ "tempfile", "thiserror", "tokio", + "tokio-util", "tracing", "tracing-subscriber", ] @@ -895,10 +971,12 @@ dependencies = [ "bytes", "futures-core", "futures-util", + "h2", "http", "http-body", "http-body-util", "hyper", + "hyper-rustls", "hyper-tls", "hyper-util", "js-sys", @@ -923,6 +1001,20 @@ dependencies = [ "web-sys", ] +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.16", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rustix" version = "1.1.3" @@ -936,6 +1028,19 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "rustls" +version = "0.23.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b" +dependencies = [ + "once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + [[package]] name = "rustls-pki-types" version = "1.13.2" @@ -945,6 +1050,17 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-webpki" +version = "0.103.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -1111,6 +1227,12 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.114" @@ -1149,7 +1271,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" dependencies = [ "fastrand", - "getrandom", + "getrandom 0.3.4", "once_cell", "rustix", "windows-sys 0.61.2", @@ -1230,6 +1352,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -1361,6 +1493,12 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" version = "2.5.8" @@ -1508,13 +1646,22 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" dependencies = [ - "windows-targets", + "windows-targets 0.53.5", ] [[package]] @@ -1526,6 +1673,22 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + [[package]] name = "windows-targets" version = "0.53.5" @@ -1533,58 +1696,106 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" dependencies = [ "windows-link", - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + [[package]] name = "windows_aarch64_gnullvm" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + [[package]] name = "windows_aarch64_msvc" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + [[package]] name = "windows_i686_gnu" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + [[package]] name = "windows_i686_gnullvm" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + [[package]] name = "windows_i686_msvc" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + [[package]] name = "windows_x86_64_gnu" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + [[package]] name = "windows_x86_64_gnullvm" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + [[package]] name = "windows_x86_64_msvc" version = "0.53.1" diff --git a/README.md b/README.md index b60d501..e6b13db 100644 --- a/README.md +++ b/README.md @@ -33,4 +33,3 @@ let history = worker.run("What is 2+2?").await?; ## License MIT - diff --git a/docs/spec/basis.md b/docs/spec/basis.md index 903658e..d3fcb99 100644 --- a/docs/spec/basis.md +++ b/docs/spec/basis.md @@ -17,7 +17,7 @@ LLMを用いたワーカーを作成する小型のSDK・ライブラリ。 module構成概念図 -``` +```plaintext worker ├── context ├── llm_client diff --git a/docs/spec/cancellation.md b/docs/spec/cancellation.md new file mode 100644 index 0000000..54f5dc7 --- /dev/null +++ b/docs/spec/cancellation.md @@ -0,0 +1,90 @@ +# 非同期キャンセル設計 + +Workerの非同期キャンセル機構についての設計ドキュメント。 + +## 概要 + +`tokio_util::sync::CancellationToken`を用いて、別タスクからWorkerの実行を安全にキャンセルできる。 + +```rust +let worker = Arc::new(Mutex::new(Worker::new(client))); + +// 実行タスク +let w = worker.clone(); +let handle = tokio::spawn(async move { + w.lock().await.run("prompt").await +}); + +// キャンセル +worker.lock().await.cancel(); +``` + +## キャンセルポイント + +キャンセルは以下のタイミングでチェックされる: + +1. **ターンループ先頭** — `is_cancelled()`で即座にチェック +2. **ストリーム開始前** — `client.stream()`呼び出し時 +3. **ストリーム受信中** — `tokio::select!`で各イベント受信と並行監視 +4. **ツール実行中** — `join_all()`と並行監視 + +## キャンセル時の処理フロー + +``` +キャンセル検知 + ↓ +timeline.abort_current_block() // 進行中ブロックの終端処理 + ↓ +run_on_abort_hooks("Cancelled") // on_abort フック呼び出し + ↓ +Err(WorkerError::Cancelled) // エラー返却 +``` + +## API + +| メソッド | 説明 | +| ---------------------- | --------------------------------------------------------- | +| `cancel()` | キャンセルをトリガー | +| `is_cancelled()` | キャンセル状態を確認 | +| `cancellation_token()` | トークンへの参照を取得(`clone()`してタスク間で共有可能) | + +## on_abort フック + +`Hook::on_abort(&self, reason: &str)`がキャンセル時に呼ばれる。 +クリーンアップ処理やログ記録に使用できる。 + +```rust +async fn on_abort(&self, reason: &str) -> Result<(), HookError> { + log::info!("Aborted: {}", reason); + Ok(()) +} +``` + +呼び出しタイミング: + +- `WorkerError::Cancelled` — reason: `"Cancelled"` +- `ControlFlow::Abort(reason)` — reason: フックが指定した理由 + +--- + +## 既知の問題 + +### 1. キャンセルトークンの再利用不可 + +`CancellationToken`は一度キャンセルされると永続的にキャンセル状態になる。 +同じWorkerインスタンスで再度`run()`を呼ぶと即座に`Cancelled`エラーになる。 + +**対応案:** + +- `run()`開始時に新しいトークンを生成する +- `reset_cancellation()`メソッドを提供する + +### 2. Sync バウンドの追加(破壊的変更) + +`tokio::select!`使用のため、Handler/Scope型に`Sync`バウンドを追加した。 +既存のユーザーコードで`Sync`未実装の型を使用している場合、コンパイルエラーになる。 + +### 3. エラー時のon_abort呼び出し + +現在、`on_abort`はキャンセルとフックAbort時のみ呼ばれる。 +ストリームエラー等のその他エラー時には呼ばれないため、一貫性に欠ける可能性がある。 diff --git a/docs/spec/hooks_design.md b/docs/spec/hooks_design.md index b56357b..08354a6 100644 --- a/docs/spec/hooks_design.md +++ b/docs/spec/hooks_design.md @@ -3,7 +3,8 @@ ## 概要 HookはWorker層でのターン制御に介入するためのメカニズムです。 -Claude CodeのHooks機能に着想を得ており、メッセージ送信・ツール実行・ターン終了の各ポイントで処理を差し込むことができます。 + +メッセージ送信・ツール実行・ターン終了等の各ポイントで処理を差し込むことができます。 ## コンセプト @@ -11,76 +12,82 @@ Claude CodeのHooks機能に着想を得ており、メッセージ送信・ツ - **Contextへのアクセス**: メッセージ履歴を読み書き可能 - **非破壊的チェーン**: 複数のHookを登録順に実行、後続Hookへの影響を制御 +## Hook一覧 + +| Hook | タイミング | 主な用途 | 戻り値 | +| ------------------ | -------------------------- | --------------------- | ---------------------- | +| `on_message_send` | LLM送信前 | コンテキスト改変/検証 | `OnMessageSendResult` | +| `before_tool_call` | ツール実行前 | 実行許可/引数改変 | `BeforeToolCallResult` | +| `after_tool_call` | ツール実行後 | 結果加工/マスキング | `AfterToolCallResult` | +| `on_turn_end` | ツールなしでターン終了直前 | 検証/リトライ指示 | `OnTurnEndResult` | +| `on_abort` | 中断時 | クリーンアップ/通知 | `()` | + ## Hook Trait ```rust #[async_trait] -pub trait WorkerHook: Send + Sync { - /// メッセージ送信前 - /// リクエストに含まれるメッセージリストを改変できる - async fn on_message_send( - &self, - context: &mut Vec, - ) -> Result { - Ok(ControlFlow::Continue) - } - - /// ツール実行前 - /// 実行をキャンセルしたり、引数を書き換えることができる - async fn before_tool_call( - &self, - tool_call: &mut ToolCall, - ) -> Result { - Ok(ControlFlow::Continue) - } - - /// ツール実行後 - /// 結果を書き換えたり、隠蔽したりできる - async fn after_tool_call( - &self, - tool_result: &mut ToolResult, - ) -> Result { - Ok(ControlFlow::Continue) - } - - /// ターン終了時 - /// 生成されたメッセージを検査し、必要ならリトライを指示できる - async fn on_turn_end( - &self, - messages: &[Message], - ) -> Result { - Ok(TurnResult::Finish) - } +pub trait Hook: Send + Sync { + async fn call(&self, input: &mut E::Input) -> Result; } ``` ## 制御フロー型 -### ControlFlow +### HookEventKind / Result -Hook処理の継続/中断を制御する列挙型。 +Hookイベントごとに入力/出力型を分離し、意味のない制御フローを排除する。 ```rust -pub enum ControlFlow { - /// 処理を続行(後続Hookも実行) +pub trait HookEventKind { + type Input; + type Output; +} + +pub struct OnMessageSend; +pub struct BeforeToolCall; +pub struct AfterToolCall; +pub struct OnTurnEnd; +pub struct OnAbort; + +pub enum OnMessageSendResult { + Continue, + Cancel(String), +} + +pub enum BeforeToolCallResult { Continue, - /// 現在の処理をスキップ(ツール実行をスキップ等) Skip, - /// 処理全体を中断(エラーとして扱う) Abort(String), + Pause, +} + +pub enum AfterToolCallResult { + Continue, + Abort(String), +} + +pub enum OnTurnEndResult { + Finish, + ContinueWithMessages(Vec), + Paused, } ``` -### TurnResult +### Tool Call Context -ターン終了時の判定結果を表す列挙型。 +`before_tool_call` / `after_tool_call` は、ツール実行の文脈を含む入力を受け取る。 ```rust -pub enum TurnResult { - /// ターンを正常終了 - Finish, - /// メッセージを追加してターン継続(自己修正など) - ContinueWithMessages(Vec), +pub struct ToolCallContext { + pub call: ToolCall, + pub meta: ToolMeta, // 不変メタデータ + pub tool: Arc, // 状態アクセス用 +} + +pub struct ToolResultContext { + pub result: ToolResult, + pub meta: ToolMeta, + pub tool: Arc, } ``` @@ -90,28 +97,30 @@ pub enum TurnResult { Worker::run() ループ │ ├─▶ on_message_send ──────────────────────────────┐ -│ コンテキストの改変、バリデーション、 │ -│ システムプロンプト注入などが可能 │ -│ │ -├─▶ LLMリクエスト送信 & ストリーム処理 │ -│ │ -├─▶ ツール呼び出しがある場合: │ -│ │ │ +│ コンテキストの改変、バリデーション、 │ +│ システムプロンプト注入などが可能 │ +│ │ +├─▶ LLMリクエスト送信 & ストリーム処理 │ +│ │ +├─▶ ツール呼び出しがある場合: │ +│ │ │ │ ├─▶ before_tool_call (各ツールごと・逐次) │ -│ │ 実行可否の判定、引数の改変 │ -│ │ │ +│ │ 実行可否の判定、引数の改変 │ +│ │ │ │ ├─▶ ツール並列実行 (join_all) │ -│ │ │ +│ │ │ │ └─▶ after_tool_call (各結果ごと・逐次) │ -│ 結果の確認、加工、ログ出力 │ -│ │ +│ 結果の確認、加工、ログ出力 │ +│ │ ├─▶ ツール結果をコンテキストに追加 → ループ先頭へ │ -│ │ -└─▶ ツールなしの場合: │ - │ │ - └─▶ on_turn_end ─────────────────────────────┘ +│ │ +└─▶ ツールなしの場合: │ + │ │ + └─▶ on_turn_end ─────────────────────────────┘ 最終応答のチェック(Lint/Fmt等) エラーがあればContinueWithMessagesでリトライ + +※ 中断時は on_abort が呼ばれる ``` ## 各Hookの詳細 @@ -121,10 +130,12 @@ Worker::run() ループ **呼び出しタイミング**: LLMへリクエスト送信前(ターンループの冒頭) **用途**: + - コンテキストへのシステムメッセージ注入 - メッセージのバリデーション - 機密情報のフィルタリング - リクエスト内容のログ出力 +- `OnMessageSendResult::Cancel` による送信キャンセル **例**: メッセージにタイムスタンプを追加 @@ -132,14 +143,14 @@ Worker::run() ループ struct TimestampHook; #[async_trait] -impl WorkerHook for TimestampHook { - async fn on_message_send( +impl Hook for TimestampHook { + async fn call( &self, context: &mut Vec, - ) -> Result { + ) -> Result { let timestamp = chrono::Local::now().to_rfc3339(); context.insert(0, Message::user(format!("[{}]", timestamp))); - Ok(ControlFlow::Continue) + Ok(OnMessageSendResult::Continue) } } ``` @@ -149,10 +160,15 @@ impl WorkerHook for TimestampHook { **呼び出しタイミング**: 各ツール実行前(並列実行フェーズの前) **用途**: + - 危険なツールのブロック - 引数のサニタイズ - 確認プロンプトの表示(UIとの連携) - 実行ログの記録 +- `BeforeToolCallResult::Pause` による一時停止 + +**入力**: +- `ToolCallContext`(`ToolCall` + `ToolMeta` + `Arc`) **例**: 特定ツールをブロック @@ -162,16 +178,16 @@ struct ToolBlocker { } #[async_trait] -impl WorkerHook for ToolBlocker { - async fn before_tool_call( +impl Hook for ToolBlocker { + async fn call( &self, - tool_call: &mut ToolCall, - ) -> Result { - if self.blocked_tools.contains(&tool_call.name) { - println!("Blocked tool: {}", tool_call.name); - Ok(ControlFlow::Skip) + ctx: &mut ToolCallContext, + ) -> Result { + if self.blocked_tools.contains(&ctx.call.name) { + println!("Blocked tool: {}", ctx.call.name); + Ok(BeforeToolCallResult::Skip) } else { - Ok(ControlFlow::Continue) + Ok(BeforeToolCallResult::Continue) } } } @@ -182,10 +198,13 @@ impl WorkerHook for ToolBlocker { **呼び出しタイミング**: 各ツール実行後(並列実行フェーズの後) **用途**: + - 結果の加工・フォーマット - 機密情報のマスキング - 結果のキャッシュ - 実行結果のログ出力 +**入力**: +- `ToolResultContext`(`ToolResult` + `ToolMeta` + `Arc`) **例**: 結果にプレフィックスを追加 @@ -193,15 +212,15 @@ impl WorkerHook for ToolBlocker { struct ResultFormatter; #[async_trait] -impl WorkerHook for ResultFormatter { - async fn after_tool_call( +impl Hook for ResultFormatter { + async fn call( &self, - tool_result: &mut ToolResult, - ) -> Result { - if !tool_result.is_error { - tool_result.content = format!("[OK] {}", tool_result.content); + ctx: &mut ToolResultContext, + ) -> Result { + if !ctx.result.is_error { + ctx.result.content = format!("[OK] {}", ctx.result.content); } - Ok(ControlFlow::Continue) + Ok(AfterToolCallResult::Continue) } } ``` @@ -211,10 +230,22 @@ impl WorkerHook for ResultFormatter { **呼び出しタイミング**: ツール呼び出しなしでターンが終了する直前 **用途**: + - 生成されたコードのLint/Fmt - 出力形式のバリデーション - 自己修正のためのリトライ指示 - 最終結果のログ出力 +- `OnTurnEndResult::Paused` による一時停止 + +### on_abort + +**呼び出しタイミング**: キャンセル/エラー/AbortなどでWorkerが中断された時 + +**用途**: + +- クリーンアップ処理 +- 中断理由のログ出力 +- 外部システムへの通知 **例**: JSON形式のバリデーション @@ -222,11 +253,11 @@ impl WorkerHook for ResultFormatter { struct JsonValidator; #[async_trait] -impl WorkerHook for JsonValidator { - async fn on_turn_end( +impl Hook for JsonValidator { + async fn call( &self, - messages: &[Message], - ) -> Result { + messages: &mut Vec, + ) -> Result { // 最後のアシスタントメッセージを取得 let last = messages.iter().rev() .find(|m| m.role == Role::Assistant); @@ -236,25 +267,25 @@ impl WorkerHook for JsonValidator { // JSONとしてパースを試みる if serde_json::from_str::(text).is_err() { // 失敗したらリトライ指示 - return Ok(TurnResult::ContinueWithMessages(vec![ + return Ok(OnTurnEndResult::ContinueWithMessages(vec![ Message::user("Invalid JSON. Please fix and try again.") ])); } } } - Ok(TurnResult::Finish) + Ok(OnTurnEndResult::Finish) } } ``` ## 複数Hookの実行順序 -Hookは**登録順**に実行されます。 +Hookは**イベントごとに登録順**に実行されます。 ```rust -worker.add_hook(HookA); // 1番目に実行 -worker.add_hook(HookB); // 2番目に実行 -worker.add_hook(HookC); // 3番目に実行 +worker.add_before_tool_call_hook(HookA); // 1番目に実行 +worker.add_before_tool_call_hook(HookB); // 2番目に実行 +worker.add_before_tool_call_hook(HookC); // 3番目に実行 ``` ### 制御フローの伝播 @@ -262,6 +293,7 @@ worker.add_hook(HookC); // 3番目に実行 - `Continue`: 後続Hookも実行 - `Skip`: 現在の処理をスキップし、後続Hookは実行しない - `Abort`: 即座にエラーを返し、処理全体を中断 +- `Pause`: Workerを一時停止(再開は`resume`) ``` Hook A: Continue → Hook B: Skip → (Hook Cは実行されない) @@ -271,40 +303,27 @@ Hook A: Continue → Hook B: Skip → (Hook Cは実行されない) Hook A: Continue → Hook B: Abort("reason") ↓ WorkerError::Aborted + +Hook A: Continue → Hook B: Pause + ↓ + WorkerResult::Paused ``` ## 設計上のポイント -### 1. デフォルト実装 +### 1. イベントごとの実装 -全メソッドにデフォルト実装があるため、必要なメソッドだけオーバーライドすれば良い。 - -```rust -struct SimpleLogger; - -#[async_trait] -impl WorkerHook for SimpleLogger { - // on_message_send だけ実装 - async fn on_message_send( - &self, - context: &mut Vec, - ) -> Result { - println!("Sending {} messages", context.len()); - Ok(ControlFlow::Continue) - } - // 他のメソッドはデフォルト(Continue/Finish) -} -``` +必要なイベントのみ `Hook` を実装する。 ### 2. 可変参照による改変 `&mut`で引数を受け取るため、直接改変が可能。 ```rust -async fn before_tool_call(&self, tool_call: &mut ToolCall) -> ... { +async fn call(&self, ctx: &mut ToolCallContext) -> ... { // 引数を直接書き換え - tool_call.input["sanitized"] = json!(true); - Ok(ControlFlow::Continue) + ctx.call.input["sanitized"] = json!(true); + Ok(BeforeToolCallResult::Continue) } ``` @@ -316,7 +335,7 @@ async fn before_tool_call(&self, tool_call: &mut ToolCall) -> ... { ### 4. Send + Sync 要件 -`WorkerHook`は`Send + Sync`を要求するため、スレッドセーフな実装が必要。 +`Hook`は`Send + Sync`を要求するため、スレッドセーフな実装が必要。 状態を持つ場合は`Arc>`や`AtomicUsize`などを使用する。 ```rust @@ -325,24 +344,24 @@ struct CountingHook { } #[async_trait] -impl WorkerHook for CountingHook { - async fn before_tool_call(&self, _: &mut ToolCall) -> Result { +impl Hook for CountingHook { + async fn call(&self, _: &mut ToolCallContext) -> Result { self.count.fetch_add(1, Ordering::SeqCst); - Ok(ControlFlow::Continue) + Ok(BeforeToolCallResult::Continue) } } ``` ## 典型的なユースケース -| ユースケース | 使用Hook | 処理内容 | -|-------------|----------|----------| -| ツール許可制御 | `before_tool_call` | 危険なツールをSkip | -| 実行ログ | `before/after_tool_call` | 呼び出しと結果を記録 | -| 出力バリデーション | `on_turn_end` | 形式チェック、リトライ指示 | -| コンテキスト注入 | `on_message_send` | システムメッセージ追加 | -| 結果のサニタイズ | `after_tool_call` | 機密情報のマスキング | -| レート制限 | `before_tool_call` | 呼び出し頻度の制御 | +| ユースケース | 使用Hook | 処理内容 | +| ------------------ | ------------------------ | -------------------------- | +| ツール許可制御 | `before_tool_call` | 危険なツールをSkip | +| 実行ログ | `before/after_tool_call` | 呼び出しと結果を記録 | +| 出力バリデーション | `on_turn_end` | 形式チェック、リトライ指示 | +| コンテキスト注入 | `on_message_send` | システムメッセージ追加 | +| 結果のサニタイズ | `after_tool_call` | 機密情報のマスキング | +| レート制限 | `before_tool_call` | 呼び出し頻度の制御 | ## TODO @@ -350,11 +369,14 @@ impl WorkerHook for CountingHook { 現在のHooks実装は基本的なユースケースをカバーしているが、以下の点について将来的に厳密な仕様を定義する必要がある: -- **エラーハンドリングの明確化**: `HookError`発生時のリカバリー戦略、部分的な失敗の扱い +- **エラーハンドリングの明確化**: + `HookError`発生時のリカバリー戦略、部分的な失敗の扱い - **Hook間の依存関係**: 複数Hookの実行順序が結果に影響する場合のセマンティクス - **非同期キャンセル**: Hook実行中のキャンセル(タイムアウト等)の振る舞い -- **状態の一貫性**: `on_message_send`で改変されたコンテキストが後続処理で期待通りに反映される保証 -- **リトライ制限**: `on_turn_end`での`ContinueWithMessages`による無限ループ防止策 +- **状態の一貫性**: + `on_message_send`で改変されたコンテキストが後続処理で期待通りに反映される保証 +- **リトライ制限**: + `on_turn_end`での`ContinueWithMessages`による無限ループ防止策 - **Hook優先度**: 登録順以外の優先度指定メカニズムの必要性 - **条件付きHook**: 特定条件でのみ有効化されるHookパターン - **テスト容易性**: Hookのモック/スタブ作成のためのユーティリティ diff --git a/docs/spec/worker_design.md b/docs/spec/worker_design.md index 300f7ce..767b3eb 100644 --- a/docs/spec/worker_design.md +++ b/docs/spec/worker_design.md @@ -178,41 +178,60 @@ Workerは生成されたラッパー構造体を `Box` として保持 ```rust #[async_trait] -pub trait WorkerHook: Send + Sync { - /// メッセージ送信前。 - /// リクエストに含まれるメッセージリストを改変できる。 - async fn on_message_send(&self, context: &mut Vec) -> Result { - Ok(ControlFlow::Continue) - } - - /// ツール実行前。 - /// 実行をキャンセルしたり、引数を書き換えることができる。 - async fn before_tool_call(&self, tool_call: &mut ToolCall) -> Result { - Ok(ControlFlow::Continue) - } - - /// ツール実行後。 - /// 結果を書き換えたり、隠蔽したりできる。 - async fn after_tool_call(&self, tool_result: &mut ToolResult) -> Result { - Ok(ControlFlow::Continue) - } - - /// ターン終了時。 - /// 生成されたメッセージを検査し、必要ならリトライ(ContinueWithMessages)を指示できる。 - async fn on_turn_end(&self, messages: &[Message]) -> Result { - Ok(TurnResult::Finish) - } +pub trait Hook: Send + Sync { + async fn call(&self, input: &mut E::Input) -> Result; } -pub enum ControlFlow { +pub trait HookEventKind { + type Input; + type Output; +} + +pub struct OnMessageSend; +pub struct BeforeToolCall; +pub struct AfterToolCall; +pub struct OnTurnEnd; +pub struct OnAbort; + +pub enum OnMessageSendResult { Continue, - Skip, // Tool実行などをスキップ - Abort(String), // 処理中断 + Cancel(String), } -pub enum TurnResult { +pub enum BeforeToolCallResult { + Continue, + Skip, // Tool実行などをスキップ + Abort(String), // 処理中断 + Pause, +} + +pub enum AfterToolCallResult { + Continue, + Abort(String), +} + +pub enum OnTurnEndResult { Finish, ContinueWithMessages(Vec), // メッセージを追加してターン継続(自己修正など) + Paused, +} +``` + +### Tool Call Context + +`before_tool_call` / `after_tool_call` は、ツール実行の文脈を含む入力を受け取る。 + +```rust +pub struct ToolCallContext { + pub call: ToolCall, + pub meta: ToolMeta, // 不変メタデータ + pub tool: Arc, // 状態アクセス用 +} + +pub struct ToolResultContext { + pub result: ToolResult, + pub meta: ToolMeta, + pub tool: Arc, } ``` @@ -433,4 +452,3 @@ impl Worker { 3. **選択的購読**: on_*で必要なイベントだけ、またはSubscriberで一括 4. **累積イベントの追加**: Worker層でComplete系イベントを追加提供 5. **後方互換性**: 従来の`run()`も引き続き使用可能 - diff --git a/llm-worker/Cargo.toml b/llm-worker/Cargo.toml index a5e8d9f..529327c 100644 --- a/llm-worker/Cargo.toml +++ b/llm-worker/Cargo.toml @@ -15,7 +15,8 @@ tracing = "0.1" async-trait = "0.1" futures = "0.3" tokio = { version = "1.49", features = ["macros", "rt-multi-thread"] } -reqwest = { version = "0.13.1", default-features = false, features = ["stream", "json", "native-tls"] } +tokio-util = "0.7" +reqwest = { version = "0.13.1", default-features = false, features = ["stream", "json", "native-tls", "http2"] } eventsource-stream = "0.2" llm-worker-macros = { path = "../llm-worker-macros", version = "0.1" } diff --git a/llm-worker/examples/worker_cancel_demo.rs b/llm-worker/examples/worker_cancel_demo.rs new file mode 100644 index 0000000..ad9173c --- /dev/null +++ b/llm-worker/examples/worker_cancel_demo.rs @@ -0,0 +1,71 @@ +//! Worker のキャンセル機能のデモンストレーション +//! +//! ストリーミング受信中に別スレッドからキャンセルする例 + +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; +use llm_worker::{Worker, WorkerResult}; +use llm_worker::llm_client::providers::anthropic::AnthropicClient; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // .envファイルを読み込む + dotenv::dotenv().ok(); + + // ロギング初期化 + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")), + ) + .init(); + + let api_key = std::env::var("ANTHROPIC_API_KEY") + .expect("ANTHROPIC_API_KEY environment variable not set"); + + let client = AnthropicClient::new(&api_key, "claude-sonnet-4-20250514"); + let worker = Arc::new(Mutex::new(Worker::new(client))); + + println!("🚀 Starting Worker..."); + println!("💡 Will cancel after 2 seconds\n"); + + // キャンセルトークンを先に取得(ロックを保持しない) + let cancel_token = { + let w = worker.lock().await; + w.cancellation_token().clone() + }; + + // タスク1: Workerを実行 + let worker_clone = worker.clone(); + let task = tokio::spawn(async move { + let mut w = worker_clone.lock().await; + println!("📡 Sending request to LLM..."); + + match w.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await { + Ok(WorkerResult::Finished(_)) => { + println!("✅ Task completed normally"); + } + Ok(WorkerResult::Paused(_)) => { + println!("⏸️ Task paused"); + } + Err(e) => { + println!("❌ Task error: {}", e); + } + } + }); + + // タスク2: 2秒後にキャンセル + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(2)).await; + println!("\n🛑 Cancelling worker..."); + cancel_token.cancel(); + }); + + // タスク完了を待つ + task.await?; + + println!("\n✨ Demo complete!"); + + Ok(()) +} diff --git a/llm-worker/examples/worker_cli.rs b/llm-worker/examples/worker_cli.rs index 856af11..a576871 100644 --- a/llm-worker/examples/worker_cli.rs +++ b/llm-worker/examples/worker_cli.rs @@ -41,7 +41,7 @@ use tracing_subscriber::EnvFilter; use clap::{Parser, ValueEnum}; use llm_worker::{ Worker, - hook::{ControlFlow, HookError, ToolResult, WorkerHook}, + hook::{AfterToolCall, AfterToolCallResult, Hook, HookError, ToolResult}, llm_client::{ LlmClient, providers::{ @@ -282,11 +282,11 @@ impl ToolResultPrinterHook { } #[async_trait] -impl WorkerHook for ToolResultPrinterHook { - async fn after_tool_call( +impl Hook for ToolResultPrinterHook { + async fn call( &self, tool_result: &mut ToolResult, - ) -> Result { + ) -> Result { let name = self .call_names .lock() @@ -300,7 +300,7 @@ impl WorkerHook for ToolResultPrinterHook { println!(" Result ({}): ✅ {}", name, tool_result.content); } - Ok(ControlFlow::Continue) + Ok(AfterToolCallResult::Continue) } } @@ -451,7 +451,7 @@ async fn main() -> Result<(), Box> { .on_text_block(StreamingPrinter::new()) .on_tool_use_block(ToolCallPrinter::new(tool_call_names.clone())); - worker.add_hook(ToolResultPrinterHook::new(tool_call_names)); + worker.add_after_tool_call_hook(ToolResultPrinterHook::new(tool_call_names)); // ワンショットモード if let Some(prompt) = args.prompt { diff --git a/llm-worker/src/hook.rs b/llm-worker/src/hook.rs index 25b0fc5..1007dec 100644 --- a/llm-worker/src/hook.rs +++ b/llm-worker/src/hook.rs @@ -8,33 +8,72 @@ use serde_json::Value; use thiserror::Error; // ============================================================================= -// Control Flow Types +// Hook Event Kinds // ============================================================================= -/// Hook処理の制御フロー +pub trait HookEventKind: Send + Sync + 'static { + type Input; + type Output; +} + +pub struct OnMessageSend; +pub struct BeforeToolCall; +pub struct AfterToolCall; +pub struct OnTurnEnd; +pub struct OnAbort; + #[derive(Debug, Clone, PartialEq, Eq)] -pub enum ControlFlow { - /// 処理を続行 +pub enum OnMessageSendResult { + Continue, + Cancel(String), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum BeforeToolCallResult { Continue, - /// 現在の処理をスキップ(Tool実行など) Skip, - /// 処理を中断 Abort(String), - /// 処理を一時停止(再開可能) Pause, } -/// ターン終了時の判定結果 +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AfterToolCallResult { + Continue, + Abort(String), +} + #[derive(Debug, Clone)] -pub enum TurnResult { - /// ターンを終了 +pub enum OnTurnEndResult { Finish, - /// メッセージを追加してターン継続(自己修正など) ContinueWithMessages(Vec), - /// ターンを一時停止 Paused, } +impl HookEventKind for OnMessageSend { + type Input = Vec; + type Output = OnMessageSendResult; +} + +impl HookEventKind for BeforeToolCall { + type Input = ToolCall; + type Output = BeforeToolCallResult; +} + +impl HookEventKind for AfterToolCall { + type Input = ToolResult; + type Output = AfterToolCallResult; +} + +impl HookEventKind for OnTurnEnd { + type Input = Vec; + type Output = OnTurnEndResult; +} + +impl HookEventKind for OnAbort { + type Input = String; + type Output = (); +} + // ============================================================================= // Tool Call / Result Types // ============================================================================= @@ -102,85 +141,13 @@ pub enum HookError { } // ============================================================================= -// WorkerHook Trait +// Hook Trait // ============================================================================= -/// ターンの進行・ツール実行に介入するためのトレイト +/// Hookイベントの処理を行うトレイト /// -/// Hookを使うと、メッセージ送信前、ツール実行前後、ターン終了時に -/// 処理を挟んだり、実行をキャンセルしたりできます。 -/// -/// # Examples -/// -/// ```ignore -/// use llm_worker::hook::{ControlFlow, HookError, ToolCall, TurnResult, WorkerHook}; -/// use llm_worker::Message; -/// -/// struct ValidationHook; -/// -/// #[async_trait::async_trait] -/// impl WorkerHook for ValidationHook { -/// async fn before_tool_call(&self, call: &mut ToolCall) -> Result { -/// // 危険なツールをブロック -/// if call.name == "delete_all" { -/// return Ok(ControlFlow::Skip); -/// } -/// Ok(ControlFlow::Continue) -/// } -/// -/// async fn on_turn_end(&self, messages: &[Message]) -> Result { -/// // 条件を満たさなければ追加メッセージで継続 -/// if messages.len() < 3 { -/// return Ok(TurnResult::ContinueWithMessages(vec![ -/// Message::user("Please elaborate.") -/// ])); -/// } -/// Ok(TurnResult::Finish) -/// } -/// } -/// ``` -/// -/// # デフォルト実装 -/// -/// すべてのメソッドにはデフォルト実装があり、何も行わず`Continue`を返します。 -/// 必要なメソッドのみオーバーライドしてください。 +/// 各イベント種別は戻り値型が異なるため、`HookEventKind`を介して型を制約する。 #[async_trait] -pub trait WorkerHook: Send + Sync { - /// メッセージ送信前に呼ばれる - /// - /// リクエストに含まれるメッセージリストを参照・改変できます。 - /// `ControlFlow::Abort`を返すとターンが中断されます。 - async fn on_message_send( - &self, - _context: &mut Vec, - ) -> Result { - Ok(ControlFlow::Continue) - } - - /// ツール実行前に呼ばれる - /// - /// ツール呼び出しの引数を書き換えたり、実行をスキップしたりできます。 - /// `ControlFlow::Skip`を返すとこのツールの実行がスキップされます。 - async fn before_tool_call(&self, _tool_call: &mut ToolCall) -> Result { - Ok(ControlFlow::Continue) - } - - /// ツール実行後に呼ばれる - /// - /// ツールの実行結果を書き換えたり、隠蔽したりできます。 - async fn after_tool_call( - &self, - _tool_result: &mut ToolResult, - ) -> Result { - Ok(ControlFlow::Continue) - } - - /// ターン終了時に呼ばれる - /// - /// 生成されたメッセージを検査し、必要なら追加メッセージで継続を指示できます。 - /// `TurnResult::ContinueWithMessages`を返すと、指定したメッセージを追加して - /// 次のターンに進みます。 - async fn on_turn_end(&self, _messages: &[crate::Message]) -> Result { - Ok(TurnResult::Finish) - } +pub trait Hook: Send + Sync { + async fn call(&self, input: &mut E::Input) -> Result; } diff --git a/llm-worker/src/lib.rs b/llm-worker/src/lib.rs index 5668503..ef743b4 100644 --- a/llm-worker/src/lib.rs +++ b/llm-worker/src/lib.rs @@ -6,7 +6,7 @@ //! //! - [`Worker`] - LLMとの対話を管理する中心コンポーネント //! - [`tool::Tool`] - LLMから呼び出し可能なツール -//! - [`hook::WorkerHook`] - ターン進行への介入 +//! - [`hook::Hook`] - ターン進行への介入 //! - [`subscriber::WorkerSubscriber`] - ストリーミングイベントの購読 //! //! # Quick Start @@ -48,9 +48,5 @@ pub mod subscriber; pub mod timeline; pub mod tool; -// ============================================================================= -// トップレベル公開(最も頻繁に使う型) -// ============================================================================= - pub use message::{ContentPart, Message, MessageContent, Role}; -pub use worker::{Worker, WorkerConfig, WorkerError}; +pub use worker::{Worker, WorkerConfig, WorkerError, WorkerResult}; diff --git a/llm-worker/src/subscriber.rs b/llm-worker/src/subscriber.rs index d750e63..0e7cdf6 100644 --- a/llm-worker/src/subscriber.rs +++ b/llm-worker/src/subscriber.rs @@ -65,10 +65,10 @@ pub trait WorkerSubscriber: Send { /// /// ブロック開始時にDefault::default()で生成され、 /// ブロック終了時に破棄される。 - type TextBlockScope: Default + Send; + type TextBlockScope: Default + Send + Sync; /// ツール使用ブロック処理用のスコープ型 - type ToolUseBlockScope: Default + Send; + type ToolUseBlockScope: Default + Send + Sync; // ========================================================================= // ブロックイベント(スコープ管理あり) diff --git a/llm-worker/src/timeline/timeline.rs b/llm-worker/src/timeline/timeline.rs index 8ce92f0..8bb18f8 100644 --- a/llm-worker/src/timeline/timeline.rs +++ b/llm-worker/src/timeline/timeline.rs @@ -17,7 +17,7 @@ use crate::handler::*; /// 各Handlerは独自のScope型を持つため、Timelineで保持するには型消去が必要です。 /// 通常は直接使用せず、`Timeline::on_text_block()`などのメソッド経由で /// 自動的にラップされます。 -pub trait ErasedHandler: Send { +pub trait ErasedHandler: Send + Sync { /// イベントをディスパッチ fn dispatch(&mut self, event: &K::Event); /// スコープを開始(Block開始時) @@ -54,9 +54,9 @@ where impl ErasedHandler for HandlerWrapper where - H: Handler + Send, + H: Handler + Send + Sync, K: Kind, - H::Scope: Send, + H::Scope: Send + Sync, { fn dispatch(&mut self, event: &K::Event) { if let Some(scope) = &mut self.scope { @@ -78,7 +78,7 @@ where // ============================================================================= /// ブロックハンドラーの型消去trait -trait ErasedBlockHandler: Send { +trait ErasedBlockHandler: Send + Sync { fn dispatch_start(&mut self, start: &BlockStart); fn dispatch_delta(&mut self, delta: &BlockDelta); fn dispatch_stop(&mut self, stop: &BlockStop); @@ -112,8 +112,8 @@ where impl ErasedBlockHandler for TextBlockHandlerWrapper where - H: Handler + Send, - H::Scope: Send, + H: Handler + Send + Sync, + H::Scope: Send + Sync, { fn dispatch_start(&mut self, start: &BlockStart) { if let Some(scope) = &mut self.scope { @@ -185,8 +185,8 @@ where impl ErasedBlockHandler for ThinkingBlockHandlerWrapper where - H: Handler + Send, - H::Scope: Send, + H: Handler + Send + Sync, + H::Scope: Send + Sync, { fn dispatch_start(&mut self, start: &BlockStart) { if let Some(scope) = &mut self.scope { @@ -255,8 +255,8 @@ where impl ErasedBlockHandler for ToolUseBlockHandlerWrapper where - H: Handler + Send, - H::Scope: Send, + H: Handler + Send + Sync, + H::Scope: Send + Sync, { fn dispatch_start(&mut self, start: &BlockStart) { if let Some(scope) = &mut self.scope { @@ -391,8 +391,8 @@ impl Timeline { /// UsageKind用のHandlerを登録 pub fn on_usage(&mut self, handler: H) -> &mut Self where - H: Handler + Send + 'static, - H::Scope: Send, + H: Handler + Send + Sync + 'static, + H::Scope: Send + Sync, { // Meta系はデフォルトでスコープを開始しておく let mut wrapper = HandlerWrapper::new(handler); @@ -404,8 +404,8 @@ impl Timeline { /// PingKind用のHandlerを登録 pub fn on_ping(&mut self, handler: H) -> &mut Self where - H: Handler + Send + 'static, - H::Scope: Send, + H: Handler + Send + Sync + 'static, + H::Scope: Send + Sync, { let mut wrapper = HandlerWrapper::new(handler); wrapper.start_scope(); @@ -416,8 +416,8 @@ impl Timeline { /// StatusKind用のHandlerを登録 pub fn on_status(&mut self, handler: H) -> &mut Self where - H: Handler + Send + 'static, - H::Scope: Send, + H: Handler + Send + Sync + 'static, + H::Scope: Send + Sync, { let mut wrapper = HandlerWrapper::new(handler); wrapper.start_scope(); @@ -428,8 +428,8 @@ impl Timeline { /// ErrorKind用のHandlerを登録 pub fn on_error(&mut self, handler: H) -> &mut Self where - H: Handler + Send + 'static, - H::Scope: Send, + H: Handler + Send + Sync + 'static, + H::Scope: Send + Sync, { let mut wrapper = HandlerWrapper::new(handler); wrapper.start_scope(); @@ -440,8 +440,8 @@ impl Timeline { /// TextBlockKind用のHandlerを登録 pub fn on_text_block(&mut self, handler: H) -> &mut Self where - H: Handler + Send + 'static, - H::Scope: Send, + H: Handler + Send + Sync + 'static, + H::Scope: Send + Sync, { self.text_block_handlers .push(Box::new(TextBlockHandlerWrapper::new(handler))); @@ -451,8 +451,8 @@ impl Timeline { /// ThinkingBlockKind用のHandlerを登録 pub fn on_thinking_block(&mut self, handler: H) -> &mut Self where - H: Handler + Send + 'static, - H::Scope: Send, + H: Handler + Send + Sync + 'static, + H::Scope: Send + Sync, { self.thinking_block_handlers .push(Box::new(ThinkingBlockHandlerWrapper::new(handler))); @@ -462,8 +462,8 @@ impl Timeline { /// ToolUseBlockKind用のHandlerを登録 pub fn on_tool_use_block(&mut self, handler: H) -> &mut Self where - H: Handler + Send + 'static, - H::Scope: Send, + H: Handler + Send + Sync + 'static, + H::Scope: Send + Sync, { self.tool_use_block_handlers .push(Box::new(ToolUseBlockHandlerWrapper::new(handler))); @@ -578,6 +578,21 @@ impl Timeline { pub fn current_block(&self) -> Option { self.current_block } + + /// 現在アクティブなブロックを中断する + /// + /// キャンセルやエラー時に呼び出し、進行中のブロックに対して + /// BlockAbortイベントを発火してスコープをクリーンアップする。 + pub fn abort_current_block(&mut self) { + if let Some(block_type) = self.current_block { + let abort = crate::timeline::event::BlockAbort { + index: 0, // インデックスは不明なので0 + block_type, + reason: "Cancelled".to_string(), + }; + self.handle_block_abort(&abort); + } + } } #[cfg(test)] diff --git a/llm-worker/src/worker.rs b/llm-worker/src/worker.rs index 5425f71..0b413c0 100644 --- a/llm-worker/src/worker.rs +++ b/llm-worker/src/worker.rs @@ -3,11 +3,16 @@ use std::marker::PhantomData; use std::sync::{Arc, Mutex}; use futures::StreamExt; +use tokio_util::sync::CancellationToken; use tracing::{debug, info, trace, warn}; use crate::{ ContentPart, Message, MessageContent, Role, - hook::{ControlFlow, HookError, ToolCall, ToolResult, TurnResult, WorkerHook}, + hook::{ + AfterToolCall, AfterToolCallResult, BeforeToolCall, BeforeToolCallResult, Hook, HookError, + OnAbort, OnMessageSend, OnMessageSendResult, OnTurnEnd, OnTurnEndResult, ToolCall, + ToolResult, + }, llm_client::{ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ToolDefinition}, state::{Locked, Mutable, WorkerState}, subscriber::{ @@ -37,6 +42,9 @@ pub enum WorkerError { /// 処理が中断された #[error("Aborted: {0}")] Aborted(String), + /// Cancellation Tokenによって中断された + #[error("Cancelled")] + Cancelled, /// 設定に関する警告(未サポートのオプション) #[error("Config warnings: {}", .0.iter().map(|w| w.to_string()).collect::>().join(", "))] ConfigWarnings(Vec), @@ -77,7 +85,7 @@ enum ToolExecutionResult { // ============================================================================= /// ターンイベントを通知するためのコールバック (型消去) -trait TurnNotifier: Send { +trait TurnNotifier: Send + Sync { fn on_turn_start(&self, turn: usize); fn on_turn_end(&self, turn: usize); } @@ -149,8 +157,16 @@ pub struct Worker { tool_call_collector: ToolCallCollector, /// 登録されたツール tools: HashMap>, - /// 登録されたHook - hooks: Vec>, + /// on_message_send Hook + hooks_on_message_send: Vec>>, + /// before_tool_call Hook + hooks_before_tool_call: Vec>>, + /// after_tool_call Hook + hooks_after_tool_call: Vec>>, + /// on_turn_end Hook + hooks_on_turn_end: Vec>>, + /// on_abort Hook + hooks_on_abort: Vec>>, /// システムプロンプト system_prompt: Option, /// メッセージ履歴(Workerが所有) @@ -163,6 +179,8 @@ pub struct Worker { turn_notifiers: Vec>, /// リクエスト設定(max_tokens, temperature等) request_config: RequestConfig, + /// キャンセレーショントークン(実行中断用) + cancellation_token: CancellationToken, /// 状態マーカー _state: PhantomData, } @@ -252,30 +270,29 @@ impl Worker { } } - /// Hookを追加する - /// - /// Hookはターンの進行・ツール実行に介入できます。 - /// 複数のHookを登録した場合、登録順に実行されます。 - /// - /// # Examples - /// - /// ```ignore - /// use llm_worker::{Worker, WorkerHook, ControlFlow, ToolCall}; - /// - /// struct LoggingHook; - /// - /// #[async_trait::async_trait] - /// impl WorkerHook for LoggingHook { - /// async fn before_tool_call(&self, call: &mut ToolCall) -> Result { - /// println!("Calling tool: {}", call.name); - /// Ok(ControlFlow::Continue) - /// } - /// } - /// - /// worker.add_hook(LoggingHook); - /// ``` - pub fn add_hook(&mut self, hook: impl WorkerHook + 'static) { - self.hooks.push(Box::new(hook)); + /// on_message_send Hookを追加する + pub fn add_on_message_send_hook(&mut self, hook: impl Hook + 'static) { + self.hooks_on_message_send.push(Box::new(hook)); + } + + /// before_tool_call Hookを追加する + pub fn add_before_tool_call_hook(&mut self, hook: impl Hook + 'static) { + self.hooks_before_tool_call.push(Box::new(hook)); + } + + /// after_tool_call Hookを追加する + pub fn add_after_tool_call_hook(&mut self, hook: impl Hook + 'static) { + self.hooks_after_tool_call.push(Box::new(hook)); + } + + /// on_turn_end Hookを追加する + pub fn add_on_turn_end_hook(&mut self, hook: impl Hook + 'static) { + self.hooks_on_turn_end.push(Box::new(hook)); + } + + /// on_abort Hookを追加する + pub fn add_on_abort_hook(&mut self, hook: impl Hook + 'static) { + self.hooks_on_abort.push(Box::new(hook)); } /// タイムラインへの可変参照を取得(追加ハンドラ登録用) @@ -375,6 +392,41 @@ impl Worker { self.request_config = config; } + /// 実行をキャンセルする + /// + /// 現在実行中のストリーミングやツール実行を中断します。 + /// 次のイベントループのチェックポイントでWorkerError::Cancelledが返されます。 + /// + /// # Examples + /// + /// ```ignore + /// use std::sync::Arc; + /// let worker = Arc::new(Mutex::new(Worker::new(client))); + /// + /// // 別スレッドで実行 + /// let worker_clone = worker.clone(); + /// tokio::spawn(async move { + /// let mut w = worker_clone.lock().unwrap(); + /// w.run("Long task...").await + /// }); + /// + /// // キャンセル + /// worker.lock().unwrap().cancel(); + /// ``` + pub fn cancel(&self) { + self.cancellation_token.cancel(); + } + + /// キャンセルされているかチェック + pub fn is_cancelled(&self) -> bool { + self.cancellation_token.is_cancelled() + } + + /// キャンセレーショントークンへの参照を取得 + pub fn cancellation_token(&self) -> &CancellationToken { + &self.cancellation_token + } + /// 登録されたツールからToolDefinitionのリストを生成 fn build_tool_definitions(&self) -> Vec { self.tools @@ -430,7 +482,7 @@ impl Worker { } /// リクエストを構築 - fn build_request(&self, tool_definitions: &[ToolDefinition]) -> Request { + fn build_request(&self, tool_definitions: &[ToolDefinition], context: &[Message]) -> Request { let mut request = Request::new(); // システムプロンプトを設定 @@ -439,7 +491,7 @@ impl Worker { } // メッセージを追加 - for msg in &self.history { + for msg in context { // Message から llm_client::Message への変換 request = request.message(crate::llm_client::Message { role: match msg.role { @@ -495,36 +547,45 @@ impl Worker { } /// Hooks: on_message_send - async fn run_on_message_send_hooks(&self) -> Result { - for hook in &self.hooks { - // Note: Locked状態でも履歴全体を参照として渡す(変更は不可) - // HookのAPIを変更し、immutable参照のみを渡すようにする必要があるかもしれない - // 現在は空のVecを渡して回避(要検討) - let mut temp_context = self.history.clone(); - let result = hook.on_message_send(&mut temp_context).await?; + async fn run_on_message_send_hooks( + &self, + ) -> Result<(OnMessageSendResult, Vec), WorkerError> { + let mut temp_context = self.history.clone(); + for hook in &self.hooks_on_message_send { + let result = hook.call(&mut temp_context).await?; match result { - ControlFlow::Continue => continue, - ControlFlow::Skip => return Ok(ControlFlow::Skip), - ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)), - ControlFlow::Pause => return Ok(ControlFlow::Pause), + OnMessageSendResult::Continue => continue, + OnMessageSendResult::Cancel(reason) => { + return Ok((OnMessageSendResult::Cancel(reason), temp_context)); + } } } - Ok(ControlFlow::Continue) + Ok((OnMessageSendResult::Continue, temp_context)) } /// Hooks: on_turn_end - async fn run_on_turn_end_hooks(&self) -> Result { - for hook in &self.hooks { - let result = hook.on_turn_end(&self.history).await?; + async fn run_on_turn_end_hooks(&self) -> Result { + let mut temp_messages = self.history.clone(); + for hook in &self.hooks_on_turn_end { + let result = hook.call(&mut temp_messages).await?; match result { - TurnResult::Finish => continue, - TurnResult::ContinueWithMessages(msgs) => { - return Ok(TurnResult::ContinueWithMessages(msgs)); + OnTurnEndResult::Finish => continue, + OnTurnEndResult::ContinueWithMessages(msgs) => { + return Ok(OnTurnEndResult::ContinueWithMessages(msgs)); } - TurnResult::Paused => return Ok(TurnResult::Paused), + OnTurnEndResult::Paused => return Ok(OnTurnEndResult::Paused), } } - Ok(TurnResult::Finish) + Ok(OnTurnEndResult::Finish) + } + + /// Hooks: on_abort + async fn run_on_abort_hooks(&self, reason: &str) -> Result<(), WorkerError> { + let mut reason = reason.to_string(); + for hook in &self.hooks_on_abort { + hook.call(&mut reason).await?; + } + Ok(()) } /// 未実行のツール呼び出しがあるかチェック(Pauseからの復帰用) @@ -559,7 +620,7 @@ impl Worker { /// 全てのツールに対してbefore_tool_callフックを実行後、 /// 許可されたツールを並列に実行し、結果にafter_tool_callフックを適用する。 async fn execute_tools( - &self, + &mut self, tool_calls: Vec, ) -> Result { use futures::future::join_all; @@ -568,18 +629,18 @@ impl Worker { let mut approved_calls = Vec::new(); for mut tool_call in tool_calls { let mut skip = false; - for hook in &self.hooks { - let result = hook.before_tool_call(&mut tool_call).await?; + for hook in &self.hooks_before_tool_call { + let result = hook.call(&mut tool_call).await?; match result { - ControlFlow::Continue => {} - ControlFlow::Skip => { + BeforeToolCallResult::Continue => {} + BeforeToolCallResult::Skip => { skip = true; break; } - ControlFlow::Abort(reason) => { + BeforeToolCallResult::Abort(reason) => { return Err(WorkerError::Aborted(reason)); } - ControlFlow::Pause => { + BeforeToolCallResult::Pause => { return Ok(ToolExecutionResult::Paused); } } @@ -589,7 +650,7 @@ impl Worker { } } - // Phase 2: 許可されたツールを並列実行 + // Phase 2: 許可されたツールを並列実行(キャンセル可能) let futures: Vec<_> = approved_calls .into_iter() .map(|tool_call| { @@ -612,25 +673,26 @@ impl Worker { }) .collect(); - let mut results = join_all(futures).await; + // ツール実行をキャンセル可能にする + let mut results = tokio::select! { + results = join_all(futures) => results, + _ = self.cancellation_token.cancelled() => { + info!("Tool execution cancelled"); + self.timeline.abort_current_block(); + self.run_on_abort_hooks("Cancelled").await?; + return Err(WorkerError::Cancelled); + } + }; // Phase 3: after_tool_call フックを適用 for tool_result in &mut results { - for hook in &self.hooks { - let result = hook.after_tool_call(tool_result).await?; + for hook in &self.hooks_after_tool_call { + let result = hook.call(tool_result).await?; match result { - ControlFlow::Continue => {} - ControlFlow::Skip => break, - ControlFlow::Abort(reason) => { + AfterToolCallResult::Continue => {} + AfterToolCallResult::Abort(reason) => { return Err(WorkerError::Aborted(reason)); } - ControlFlow::Pause => { - // after_tool_callでのPauseは結果を受け入れた後、次の処理前に止まる動作とする - // ここではContinue扱いとし、on_message_send等でPauseすることを期待する - // あるいはここでのPauseをサポートする場合は戻り値を調整する必要がある - // 現状はログを出してContinue - warn!("ControlFlow::Pause in after_tool_call is treated as Continue"); - } } } } @@ -663,6 +725,14 @@ impl Worker { } loop { + // キャンセルチェック + if self.cancellation_token.is_cancelled() { + info!("Execution cancelled"); + self.timeline.abort_current_block(); + self.run_on_abort_hooks("Cancelled").await?; + return Err(WorkerError::Cancelled); + } + // ターン開始を通知 let current_turn = self.turn_count; debug!(turn = current_turn, "Turn start"); @@ -671,24 +741,21 @@ impl Worker { } // Hook: on_message_send - let control = self.run_on_message_send_hooks().await?; + let (control, request_context) = self.run_on_message_send_hooks().await?; match control { - ControlFlow::Abort(reason) => { - warn!(reason = %reason, "Aborted by hook"); + OnMessageSendResult::Cancel(reason) => { + info!(reason = %reason, "Aborted by hook"); for notifier in &self.turn_notifiers { notifier.on_turn_end(current_turn); } + self.run_on_abort_hooks(&reason).await?; return Err(WorkerError::Aborted(reason)); } - ControlFlow::Pause | ControlFlow::Skip => { - // Skip or Pause -> Pause the worker - return Ok(WorkerResult::Paused(&self.history)); - } - ControlFlow::Continue => {} + OnMessageSendResult::Continue => {} } // リクエスト構築 - let request = self.build_request(&tool_definitions); + let request = self.build_request(&tool_definitions, &request_context); debug!( message_count = request.messages.len(), tool_count = request.tools.len(), @@ -698,21 +765,49 @@ impl Worker { // ストリーム処理 debug!("Starting stream..."); - let mut stream = self.client.stream(request).await?; let mut event_count = 0; - while let Some(event_result) = stream.next().await { - match &event_result { - Ok(event) => { - trace!(event = ?event, "Received event"); - event_count += 1; + + // ストリームを取得(キャンセル可能) + let mut stream = tokio::select! { + stream_result = self.client.stream(request) => stream_result?, + _ = self.cancellation_token.cancelled() => { + info!("Cancelled before stream started"); + self.timeline.abort_current_block(); + self.run_on_abort_hooks("Cancelled").await?; + return Err(WorkerError::Cancelled); + } + }; + + loop { + tokio::select! { + // ストリームからイベントを受信 + event_result = stream.next() => { + match event_result { + Some(result) => { + match &result { + Ok(event) => { + trace!(event = ?event, "Received event"); + event_count += 1; + } + Err(e) => { + warn!(error = %e, "Stream error"); + } + } + let event = result?; + let timeline_event: crate::timeline::event::Event = event.into(); + self.timeline.dispatch(&timeline_event); + } + None => break, // ストリーム終了 + } } - Err(e) => { - warn!(error = %e, "Stream error"); + // キャンセル待機 + _ = self.cancellation_token.cancelled() => { + info!("Stream cancelled"); + self.timeline.abort_current_block(); + self.run_on_abort_hooks("Cancelled").await?; + return Err(WorkerError::Cancelled); } } - let event = event_result?; - let timeline_event: crate::timeline::event::Event = event.into(); - self.timeline.dispatch(&timeline_event); } debug!(event_count = event_count, "Stream completed"); @@ -736,14 +831,14 @@ impl Worker { // ツール呼び出しなし → ターン終了判定 let turn_result = self.run_on_turn_end_hooks().await?; match turn_result { - TurnResult::Finish => { + OnTurnEndResult::Finish => { return Ok(WorkerResult::Finished(&self.history)); } - TurnResult::ContinueWithMessages(additional) => { + OnTurnEndResult::ContinueWithMessages(additional) => { self.history.extend(additional); continue; } - TurnResult::Paused => { + OnTurnEndResult::Paused => { return Ok(WorkerResult::Paused(&self.history)); } } @@ -790,13 +885,18 @@ impl Worker { text_block_collector, tool_call_collector, tools: HashMap::new(), - hooks: Vec::new(), + hooks_on_message_send: Vec::new(), + hooks_before_tool_call: Vec::new(), + hooks_after_tool_call: Vec::new(), + hooks_on_turn_end: Vec::new(), + hooks_on_abort: Vec::new(), system_prompt: None, history: Vec::new(), locked_prefix_len: 0, turn_count: 0, turn_notifiers: Vec::new(), request_config: RequestConfig::default(), + cancellation_token: CancellationToken::new(), _state: PhantomData, } } @@ -958,13 +1058,18 @@ impl Worker { text_block_collector: self.text_block_collector, tool_call_collector: self.tool_call_collector, tools: self.tools, - hooks: self.hooks, + hooks_on_message_send: self.hooks_on_message_send, + hooks_before_tool_call: self.hooks_before_tool_call, + hooks_after_tool_call: self.hooks_after_tool_call, + hooks_on_turn_end: self.hooks_on_turn_end, + hooks_on_abort: self.hooks_on_abort, system_prompt: self.system_prompt, history: self.history, locked_prefix_len, turn_count: self.turn_count, turn_notifiers: self.turn_notifiers, request_config: self.request_config, + cancellation_token: self.cancellation_token, _state: PhantomData, } } @@ -1032,13 +1137,18 @@ impl Worker { text_block_collector: self.text_block_collector, tool_call_collector: self.tool_call_collector, tools: self.tools, - hooks: self.hooks, + hooks_on_message_send: self.hooks_on_message_send, + hooks_before_tool_call: self.hooks_before_tool_call, + hooks_after_tool_call: self.hooks_after_tool_call, + hooks_on_turn_end: self.hooks_on_turn_end, + hooks_on_abort: self.hooks_on_abort, system_prompt: self.system_prompt, history: self.history, locked_prefix_len: 0, turn_count: self.turn_count, turn_notifiers: self.turn_notifiers, request_config: self.request_config, + cancellation_token: self.cancellation_token, _state: PhantomData, } } diff --git a/llm-worker/tests/parallel_execution_test.rs b/llm-worker/tests/parallel_execution_test.rs index ac92315..485df23 100644 --- a/llm-worker/tests/parallel_execution_test.rs +++ b/llm-worker/tests/parallel_execution_test.rs @@ -8,7 +8,10 @@ use std::time::{Duration, Instant}; use async_trait::async_trait; use llm_worker::Worker; -use llm_worker::hook::{ControlFlow, HookError, ToolCall, ToolResult, WorkerHook}; +use llm_worker::hook::{ + AfterToolCall, AfterToolCallResult, BeforeToolCall, BeforeToolCallResult, Hook, HookError, + ToolCall, ToolResult, +}; use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent}; use llm_worker::tool::{Tool, ToolError}; @@ -158,20 +161,17 @@ async fn test_before_tool_call_skip() { struct BlockingHook; #[async_trait] - impl WorkerHook for BlockingHook { - async fn before_tool_call( - &self, - tool_call: &mut ToolCall, - ) -> Result { + impl Hook for BlockingHook { + async fn call(&self, tool_call: &mut ToolCall) -> Result { if tool_call.name == "blocked_tool" { - Ok(ControlFlow::Skip) + Ok(BeforeToolCallResult::Skip) } else { - Ok(ControlFlow::Continue) + Ok(BeforeToolCallResult::Continue) } } } - worker.add_hook(BlockingHook); + worker.add_before_tool_call_hook(BlockingHook); let _result = worker.run("Test hook").await; @@ -242,19 +242,19 @@ async fn test_after_tool_call_modification() { } #[async_trait] - impl WorkerHook for ModifyingHook { - async fn after_tool_call( + impl Hook for ModifyingHook { + async fn call( &self, tool_result: &mut ToolResult, - ) -> Result { + ) -> Result { tool_result.content = format!("[Modified] {}", tool_result.content); *self.modified_content.lock().unwrap() = Some(tool_result.content.clone()); - Ok(ControlFlow::Continue) + Ok(AfterToolCallResult::Continue) } } let modified_content = Arc::new(std::sync::Mutex::new(None)); - worker.add_hook(ModifyingHook { + worker.add_after_tool_call_hook(ModifyingHook { modified_content: modified_content.clone(), }); -- 2.43.0 From 16fda38039c2b14f2a6b08147b140e98e798696b Mon Sep 17 00:00:00 2001 From: Hare Date: Sat, 10 Jan 2026 00:31:14 +0900 Subject: [PATCH 2/4] feat: Redesign the tool system --- docs/spec/hooks_design.md | 206 +++++++---- docs/spec/tools_design.md | 191 +++++++++++ llm-worker-macros/src/lib.rs | 31 +- llm-worker/examples/worker_cancel_demo.rs | 12 +- llm-worker/examples/worker_cli.rs | 29 +- llm-worker/src/hook.rs | 108 +++++- llm-worker/src/lib.rs | 5 +- llm-worker/src/tool.rs | 128 +++++-- llm-worker/src/worker.rs | 359 +++++++++++++------- llm-worker/tests/parallel_execution_test.rs | 97 +++--- llm-worker/tests/tool_macro_test.rs | 82 +++-- llm-worker/tests/validation_test.rs | 1 - llm-worker/tests/worker_fixtures.rs | 44 ++- 13 files changed, 897 insertions(+), 396 deletions(-) create mode 100644 docs/spec/tools_design.md diff --git a/docs/spec/hooks_design.md b/docs/spec/hooks_design.md index 08354a6..2d8b64b 100644 --- a/docs/spec/hooks_design.md +++ b/docs/spec/hooks_design.md @@ -14,13 +14,14 @@ HookはWorker層でのターン制御に介入するためのメカニズムで ## Hook一覧 -| Hook | タイミング | 主な用途 | 戻り値 | -| ------------------ | -------------------------- | --------------------- | ---------------------- | -| `on_message_send` | LLM送信前 | コンテキスト改変/検証 | `OnMessageSendResult` | -| `before_tool_call` | ツール実行前 | 実行許可/引数改変 | `BeforeToolCallResult` | -| `after_tool_call` | ツール実行後 | 結果加工/マスキング | `AfterToolCallResult` | -| `on_turn_end` | ツールなしでターン終了直前 | 検証/リトライ指示 | `OnTurnEndResult` | -| `on_abort` | 中断時 | クリーンアップ/通知 | `()` | +| Hook | タイミング | 主な用途 | 戻り値 | +| ------------------ | -------------------------- | -------------------------- | ---------------------- | +| `on_prompt_submit` | `run()` 呼び出し時 | ユーザーメッセージの前処理 | `OnPromptSubmitResult` | +| `pre_llm_request` | 各ターンのLLM送信前 | コンテキスト改変/検証 | `PreLlmRequestResult` | +| `pre_tool_call` | ツール実行前 | 実行許可/引数改変 | `PreToolCallResult` | +| `post_tool_call` | ツール実行後 | 結果加工/マスキング | `PostToolCallResult` | +| `on_turn_end` | ツールなしでターン終了直前 | 検証/リトライ指示 | `OnTurnEndResult` | +| `on_abort` | 中断時 | クリーンアップ/通知 | `()` | ## Hook Trait @@ -43,25 +44,31 @@ pub trait HookEventKind { type Output; } -pub struct OnMessageSend; -pub struct BeforeToolCall; -pub struct AfterToolCall; +pub struct OnPromptSubmit; +pub struct PreLlmRequest; +pub struct PreToolCall; +pub struct PostToolCall; pub struct OnTurnEnd; pub struct OnAbort; -pub enum OnMessageSendResult { +pub enum OnPromptSubmitResult { Continue, Cancel(String), } -pub enum BeforeToolCallResult { +pub enum PreLlmRequestResult { + Continue, + Cancel(String), +} + +pub enum PreToolCallResult { Continue, Skip, Abort(String), Pause, } -pub enum AfterToolCallResult { +pub enum PostToolCallResult { Continue, Abort(String), } @@ -75,7 +82,7 @@ pub enum OnTurnEndResult { ### Tool Call Context -`before_tool_call` / `after_tool_call` は、ツール実行の文脈を含む入力を受け取る。 +`pre_tool_call` / `post_tool_call` は、ツール実行の文脈を含む入力を受け取る。 ```rust pub struct ToolCallContext { @@ -84,7 +91,8 @@ pub struct ToolCallContext { pub tool: Arc, // 状態アクセス用 } -pub struct ToolResultContext { +pub struct PostToolCallContext { + pub call: ToolCall, pub result: ToolResult, pub meta: ToolMeta, pub tool: Arc, @@ -94,40 +102,84 @@ pub struct ToolResultContext { ## 呼び出しタイミング ``` -Worker::run() ループ +Worker::run(user_input) │ -├─▶ on_message_send ──────────────────────────────┐ -│ コンテキストの改変、バリデーション、 │ -│ システムプロンプト注入などが可能 │ -│ │ -├─▶ LLMリクエスト送信 & ストリーム処理 │ -│ │ -├─▶ ツール呼び出しがある場合: │ -│ │ │ -│ ├─▶ before_tool_call (各ツールごと・逐次) │ -│ │ 実行可否の判定、引数の改変 │ -│ │ │ -│ ├─▶ ツール並列実行 (join_all) │ -│ │ │ -│ └─▶ after_tool_call (各結果ごと・逐次) │ -│ 結果の確認、加工、ログ出力 │ -│ │ -├─▶ ツール結果をコンテキストに追加 → ループ先頭へ │ -│ │ -└─▶ ツールなしの場合: │ - │ │ - └─▶ on_turn_end ─────────────────────────────┘ - 最終応答のチェック(Lint/Fmt等) - エラーがあればContinueWithMessagesでリトライ +├─▶ on_prompt_submit ───────────────────────────┐ +│ ユーザーメッセージの前処理・検証 │ +│ (最初の1回のみ) │ +│ │ +└─▶ loop { + │ + ├─▶ pre_llm_request ──────────────────────│ + │ コンテキストの改変、バリデーション、 │ + │ システムプロンプト注入などが可能 │ + │ (毎ターン実行) │ + │ │ + ├─▶ LLMリクエスト送信 & ストリーム処理 │ + │ │ + ├─▶ ツール呼び出しがある場合: │ + │ │ │ + │ ├─▶ pre_tool_call (各ツールごと・逐次) │ + │ │ 実行可否の判定、引数の改変 │ + │ │ │ + │ ├─▶ ツール並列実行 (join_all) │ + │ │ │ + │ └─▶ post_tool_call (各結果ごと・逐次) │ + │ 結果の確認、加工、ログ出力 │ + │ │ + ├─▶ ツール結果をコンテキストに追加 │ + │ → ループ先頭へ │ + │ │ + └─▶ ツールなしの場合: │ + │ │ + └─▶ on_turn_end ───────────────────┘ + 最終応答のチェック(Lint/Fmt等) + エラーがあればContinueWithMessagesでリトライ +} ※ 中断時は on_abort が呼ばれる ``` ## 各Hookの詳細 -### on_message_send +### on_prompt_submit -**呼び出しタイミング**: LLMへリクエスト送信前(ターンループの冒頭) +**呼び出しタイミング**: `run()` +でユーザーメッセージを受け取った直後(最初の1回のみ) + +**用途**: + +- ユーザー入力のバリデーション +- 入力のサニタイズ・フィルタリング +- ログ出力 +- `OnPromptSubmitResult::Cancel` による実行キャンセル + +**入力**: `&mut Message` - ユーザーメッセージ(改変可能) + +**例**: 入力のバリデーション + +```rust +struct InputValidator; + +#[async_trait] +impl Hook for InputValidator { + async fn call( + &self, + message: &mut Message, + ) -> Result { + if let MessageContent::Text(text) = &message.content { + if text.trim().is_empty() { + return Ok(OnPromptSubmitResult::Cancel("Empty input".to_string())); + } + } + Ok(OnPromptSubmitResult::Continue) + } +} +``` + +### pre_llm_request + +**呼び出しタイミング**: 各ターンのLLMリクエスト送信前(ループの毎回) **用途**: @@ -135,7 +187,9 @@ Worker::run() ループ - メッセージのバリデーション - 機密情報のフィルタリング - リクエスト内容のログ出力 -- `OnMessageSendResult::Cancel` による送信キャンセル +- `PreLlmRequestResult::Cancel` による送信キャンセル + +**入力**: `&mut Vec` - コンテキスト全体(改変可能) **例**: メッセージにタイムスタンプを追加 @@ -143,19 +197,19 @@ Worker::run() ループ struct TimestampHook; #[async_trait] -impl Hook for TimestampHook { +impl Hook for TimestampHook { async fn call( &self, context: &mut Vec, - ) -> Result { + ) -> Result { let timestamp = chrono::Local::now().to_rfc3339(); context.insert(0, Message::user(format!("[{}]", timestamp))); - Ok(OnMessageSendResult::Continue) + Ok(PreLlmRequestResult::Continue) } } ``` -### before_tool_call +### pre_tool_call **呼び出しタイミング**: 各ツール実行前(並列実行フェーズの前) @@ -165,9 +219,10 @@ impl Hook for TimestampHook { - 引数のサニタイズ - 確認プロンプトの表示(UIとの連携) - 実行ログの記録 -- `BeforeToolCallResult::Pause` による一時停止 +- `PreToolCallResult::Pause` による一時停止 **入力**: + - `ToolCallContext`(`ToolCall` + `ToolMeta` + `Arc`) **例**: 特定ツールをブロック @@ -178,22 +233,22 @@ struct ToolBlocker { } #[async_trait] -impl Hook for ToolBlocker { +impl Hook for ToolBlocker { async fn call( &self, ctx: &mut ToolCallContext, - ) -> Result { + ) -> Result { if self.blocked_tools.contains(&ctx.call.name) { println!("Blocked tool: {}", ctx.call.name); - Ok(BeforeToolCallResult::Skip) + Ok(PreToolCallResult::Skip) } else { - Ok(BeforeToolCallResult::Continue) + Ok(PreToolCallResult::Continue) } } } ``` -### after_tool_call +### post_tool_call **呼び出しタイミング**: 各ツール実行後(並列実行フェーズの後) @@ -203,8 +258,11 @@ impl Hook for ToolBlocker { - 機密情報のマスキング - 結果のキャッシュ - 実行結果のログ出力 + **入力**: -- `ToolResultContext`(`ToolResult` + `ToolMeta` + `Arc`) + +- `PostToolCallContext`(`ToolCall` + `ToolResult` + `ToolMeta` + + `Arc`) **例**: 結果にプレフィックスを追加 @@ -212,15 +270,15 @@ impl Hook for ToolBlocker { struct ResultFormatter; #[async_trait] -impl Hook for ResultFormatter { +impl Hook for ResultFormatter { async fn call( &self, - ctx: &mut ToolResultContext, - ) -> Result { + ctx: &mut PostToolCallContext, + ) -> Result { if !ctx.result.is_error { ctx.result.content = format!("[OK] {}", ctx.result.content); } - Ok(AfterToolCallResult::Continue) + Ok(PostToolCallResult::Continue) } } ``` @@ -283,9 +341,9 @@ impl Hook for JsonValidator { Hookは**イベントごとに登録順**に実行されます。 ```rust -worker.add_before_tool_call_hook(HookA); // 1番目に実行 -worker.add_before_tool_call_hook(HookB); // 2番目に実行 -worker.add_before_tool_call_hook(HookC); // 3番目に実行 +worker.add_pre_tool_call_hook(HookA); // 1番目に実行 +worker.add_pre_tool_call_hook(HookB); // 2番目に実行 +worker.add_pre_tool_call_hook(HookC); // 3番目に実行 ``` ### 制御フローの伝播 @@ -323,15 +381,15 @@ Hook A: Continue → Hook B: Pause async fn call(&self, ctx: &mut ToolCallContext) -> ... { // 引数を直接書き換え ctx.call.input["sanitized"] = json!(true); - Ok(BeforeToolCallResult::Continue) + Ok(PreToolCallResult::Continue) } ``` ### 3. 並列実行との統合 -- `before_tool_call`: 並列実行**前**に逐次実行(許可判定のため) +- `pre_tool_call`: 並列実行**前**に逐次実行(許可判定のため) - ツール実行: `join_all`で**並列**実行 -- `after_tool_call`: 並列実行**後**に逐次実行(結果加工のため) +- `post_tool_call`: 並列実行**後**に逐次実行(結果加工のため) ### 4. Send + Sync 要件 @@ -344,24 +402,24 @@ struct CountingHook { } #[async_trait] -impl Hook for CountingHook { - async fn call(&self, _: &mut ToolCallContext) -> Result { +impl Hook for CountingHook { + async fn call(&self, _: &mut ToolCallContext) -> Result { self.count.fetch_add(1, Ordering::SeqCst); - Ok(BeforeToolCallResult::Continue) + Ok(PreToolCallResult::Continue) } } ``` ## 典型的なユースケース -| ユースケース | 使用Hook | 処理内容 | -| ------------------ | ------------------------ | -------------------------- | -| ツール許可制御 | `before_tool_call` | 危険なツールをSkip | -| 実行ログ | `before/after_tool_call` | 呼び出しと結果を記録 | -| 出力バリデーション | `on_turn_end` | 形式チェック、リトライ指示 | -| コンテキスト注入 | `on_message_send` | システムメッセージ追加 | -| 結果のサニタイズ | `after_tool_call` | 機密情報のマスキング | -| レート制限 | `before_tool_call` | 呼び出し頻度の制御 | +| ユースケース | 使用Hook | 処理内容 | +| ------------------ | -------------------- | -------------------------- | +| ツール許可制御 | `pre_tool_call` | 危険なツールをSkip | +| 実行ログ | `pre/post_tool_call` | 呼び出しと結果を記録 | +| 出力バリデーション | `on_turn_end` | 形式チェック、リトライ指示 | +| コンテキスト注入 | `on_message_send` | システムメッセージ追加 | +| 結果のサニタイズ | `post_tool_call` | 機密情報のマスキング | +| レート制限 | `pre_tool_call` | 呼び出し頻度の制御 | ## TODO diff --git a/docs/spec/tools_design.md b/docs/spec/tools_design.md new file mode 100644 index 0000000..80af202 --- /dev/null +++ b/docs/spec/tools_design.md @@ -0,0 +1,191 @@ +# Tool 設計 + +## 概要 + +`llm-worker`のツールシステムは、LLMが外部リソースにアクセスしたり計算を実行するための仕組みを提供する。 +メタ情報の不変性とセッションスコープの状態管理を両立させる設計となっている。 + +## 主要な型 + +``` +type ToolDefinition + Fn() -> (ToolMeta, Arc) + +worker.register_tool() で呼び出し + + ▼ + +- struct ToolMeta (name, desc, schema) + 不変・登録時固定 +- trait Tool (executer) + 登録時生成・セッション中再利用 +``` + +### ToolMeta + +ツールのメタ情報を保持する不変構造体。登録時に固定され、Worker内で変更されない。 + +```rust +pub struct ToolMeta { + pub name: String, + pub description: String, + pub input_schema: Value, +} +``` + +**目的:** + +- LLM へのツール定義として送信 +- Hook からの参照(読み取り専用) +- 登録後の不変性を保証 + +### Tool trait + +ツールの実行ロジックのみを定義するトレイト。 + +```rust +#[async_trait] +pub trait Tool: Send + Sync { + async fn execute(&self, input_json: &str) -> Result; +} +``` + +**設計方針:** + +- メタ情報(name, description, schema)は含まない +- 状態を持つことが可能(セッション中のカウンターなど) +- `Send + Sync` で並列実行に対応 + +**インスタンスのライフサイクル:** + +1. `register_tool()` 呼び出し時にファクトリが実行され、インスタンスが生成される +2. LLM がツールを呼び出すと、既存インスタンスの `execute()` が実行される +3. 同じセッション中は同一インスタンスが再利用される + +※ 「最初に呼ばれたとき」の遅延初期化ではなく、**登録時の即時初期化**である。 + +### ToolDefinition + +メタ情報とツールインスタンスを生成するファクトリ。 + +```rust +pub type ToolDefinition = Arc (ToolMeta, Arc) + Send + Sync>; +``` + +**なぜファクトリか:** + +- Worker への登録時に一度だけ呼び出される +- メタ情報とインスタンスを同時に生成し、整合性を保証 +- クロージャでコンテキスト(`self.clone()`)をキャプチャ可能 + +## Worker でのツール管理 + +```rust +// Worker 内部 +tools: HashMap)> + +// 登録 API +pub fn register_tool(&mut self, factory: ToolDefinition) -> Result<(), ToolRegistryError> +``` + +登録時の処理: + +1. ファクトリを呼び出し `(meta, instance)` を取得 +2. 同名ツールが既に登録されていればエラー +3. HashMap に `(meta, instance)` を保存 + +## マクロによる自動生成 + +`#[tool_registry]` マクロは `{method}_definition()` メソッドを生成する。 + +```rust +#[tool_registry] +impl MyApp { + /// 検索を実行する + #[tool] + async fn search(&self, query: String) -> String { + // 実装 + } +} + +// 生成されるコード: +impl MyApp { + pub fn search_definition(&self) -> ToolDefinition { + let ctx = self.clone(); + Arc::new(move || { + let meta = ToolMeta::new("search") + .description("検索を実行する") + .input_schema(/* schemars で生成 */); + let tool = Arc::new(ToolSearch { ctx: ctx.clone() }); + (meta, tool) + }) + } +} +``` + +## Hook との連携 + +Hook は `ToolCallContext` / `AfterToolCallContext` +を通じてメタ情報とインスタンスにアクセスできる。 + +```rust +pub struct ToolCallContext { + pub call: ToolCall, // 呼び出し情報(改変可能) + pub meta: ToolMeta, // メタ情報(読み取り専用) + pub tool: Arc, // インスタンス(状態アクセス用) +} +``` + +**用途:** + +- `meta` で名前やスキーマを確認 +- `tool` でツールの内部状態を読み取り(ダウンキャスト必要) +- `call` の引数を改変してツールに渡す + +## 使用例 + +### 手動実装 + +```rust +struct Counter { count: AtomicUsize } + +impl Tool for Counter { + async fn execute(&self, _: &str) -> Result { + let n = self.count.fetch_add(1, Ordering::SeqCst); + Ok(format!("count: {}", n)) + } +} + +let def: ToolDefinition = Arc::new(|| { + let meta = ToolMeta::new("counter") + .description("カウンターを増加") + .input_schema(json!({"type": "object"})); + (meta, Arc::new(Counter { count: AtomicUsize::new(0) })) +}); + +worker.register_tool(def)?; +``` + +### マクロ使用(推奨) + +```rust +#[tool_registry] +impl App { + #[tool] + async fn greet(&self, name: String) -> String { + format!("Hello, {}!", name) + } +} + +let app = App; +worker.register_tool(app.greet_definition())?; +``` + +## 設計上の決定 + +| 問題 | 決定 | 理由 | +| -------------------- | ------------------------------ | ---------------------------------------------- | +| メタ情報の変更可能性 | ToolMeta を分離・不変化 | 登録後の整合性を保証 | +| 状態管理 | 登録時にインスタンス生成 | セッション中の状態保持、同一インスタンス再利用 | +| Factory vs Instance | Factory + 登録時即時呼び出し | コンテキストキャプチャと登録時検証 | +| Hook からのアクセス | Context に meta と tool を含む | 柔軟な介入を可能に | diff --git a/llm-worker-macros/src/lib.rs b/llm-worker-macros/src/lib.rs index 5c9b73f..b1a2e45 100644 --- a/llm-worker-macros/src/lib.rs +++ b/llm-worker-macros/src/lib.rs @@ -113,7 +113,7 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2:: let pascal_name = to_pascal_case(&method_name.to_string()); let tool_struct_name = format_ident!("Tool{}", pascal_name); let args_struct_name = format_ident!("{}Args", pascal_name); - let factory_name = format_ident!("{}_tool", method_name); + let definition_name = format_ident!("{}_definition", method_name); // ドキュメントコメントから説明を取得 let description = extract_doc_comment(&method.attrs); @@ -247,29 +247,24 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2:: #[async_trait::async_trait] impl ::llm_worker::tool::Tool for #tool_struct_name { - fn name(&self) -> &str { - #tool_name - } - - fn description(&self) -> &str { - #description - } - - fn input_schema(&self) -> serde_json::Value { - let schema = schemars::schema_for!(#args_struct_name); - serde_json::to_value(schema).unwrap_or(serde_json::json!({})) - } - async fn execute(&self, input_json: &str) -> Result { #execute_body } } impl #self_ty { - pub fn #factory_name(&self) -> #tool_struct_name { - #tool_struct_name { - ctx: self.clone() - } + /// ToolDefinition を取得(Worker への登録用) + pub fn #definition_name(&self) -> ::llm_worker::tool::ToolDefinition { + let ctx = self.clone(); + ::std::sync::Arc::new(move || { + let schema = schemars::schema_for!(#args_struct_name); + let meta = ::llm_worker::tool::ToolMeta::new(#tool_name) + .description(#description) + .input_schema(serde_json::to_value(schema).unwrap_or(serde_json::json!({}))); + let tool: ::std::sync::Arc = + ::std::sync::Arc::new(#tool_struct_name { ctx: ctx.clone() }); + (meta, tool) + }) } } } diff --git a/llm-worker/examples/worker_cancel_demo.rs b/llm-worker/examples/worker_cancel_demo.rs index ad9173c..b4b7114 100644 --- a/llm-worker/examples/worker_cancel_demo.rs +++ b/llm-worker/examples/worker_cancel_demo.rs @@ -2,11 +2,11 @@ //! //! ストリーミング受信中に別スレッドからキャンセルする例 +use llm_worker::llm_client::providers::anthropic::AnthropicClient; +use llm_worker::{Worker, WorkerResult}; use std::sync::Arc; use std::time::Duration; use tokio::sync::Mutex; -use llm_worker::{Worker, WorkerResult}; -use llm_worker::llm_client::providers::anthropic::AnthropicClient; #[tokio::main] async fn main() -> Result<(), Box> { @@ -21,8 +21,8 @@ async fn main() -> Result<(), Box> { ) .init(); - let api_key = std::env::var("ANTHROPIC_API_KEY") - .expect("ANTHROPIC_API_KEY environment variable not set"); + let api_key = + std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY environment variable not set"); let client = AnthropicClient::new(&api_key, "claude-sonnet-4-20250514"); let worker = Arc::new(Mutex::new(Worker::new(client))); @@ -41,7 +41,7 @@ async fn main() -> Result<(), Box> { let task = tokio::spawn(async move { let mut w = worker_clone.lock().await; println!("📡 Sending request to LLM..."); - + match w.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await { Ok(WorkerResult::Finished(_)) => { println!("✅ Task completed normally"); @@ -66,6 +66,6 @@ async fn main() -> Result<(), Box> { task.await?; println!("\n✨ Demo complete!"); - + Ok(()) } diff --git a/llm-worker/examples/worker_cli.rs b/llm-worker/examples/worker_cli.rs index a576871..fbc03e4 100644 --- a/llm-worker/examples/worker_cli.rs +++ b/llm-worker/examples/worker_cli.rs @@ -41,7 +41,7 @@ use tracing_subscriber::EnvFilter; use clap::{Parser, ValueEnum}; use llm_worker::{ Worker, - hook::{AfterToolCall, AfterToolCallResult, Hook, HookError, ToolResult}, + hook::{Hook, HookError, PostToolCall, PostToolCallContext, PostToolCallResult}, llm_client::{ LlmClient, providers::{ @@ -282,25 +282,22 @@ impl ToolResultPrinterHook { } #[async_trait] -impl Hook for ToolResultPrinterHook { - async fn call( - &self, - tool_result: &mut ToolResult, - ) -> Result { +impl Hook for ToolResultPrinterHook { + async fn call(&self, ctx: &mut PostToolCallContext) -> Result { let name = self .call_names .lock() .unwrap() - .remove(&tool_result.tool_use_id) - .unwrap_or_else(|| tool_result.tool_use_id.clone()); + .remove(&ctx.result.tool_use_id) + .unwrap_or_else(|| ctx.result.tool_use_id.clone()); - if tool_result.is_error { - println!(" Result ({}): ❌ {}", name, tool_result.content); + if ctx.result.is_error { + println!(" Result ({}): ❌ {}", name, ctx.result.content); } else { - println!(" Result ({}): ✅ {}", name, tool_result.content); + println!(" Result ({}): ✅ {}", name, ctx.result.content); } - Ok(AfterToolCallResult::Continue) + Ok(PostToolCallResult::Continue) } } @@ -441,8 +438,10 @@ async fn main() -> Result<(), Box> { // ツール登録(--no-tools でなければ) if !args.no_tools { let app = AppContext; - worker.register_tool(app.get_current_time_tool()); - worker.register_tool(app.calculate_tool()); + worker + .register_tool(app.get_current_time_definition()) + .unwrap(); + worker.register_tool(app.calculate_definition()).unwrap(); } // ストリーミング表示用ハンドラーを登録 @@ -451,7 +450,7 @@ async fn main() -> Result<(), Box> { .on_text_block(StreamingPrinter::new()) .on_tool_use_block(ToolCallPrinter::new(tool_call_names.clone())); - worker.add_after_tool_call_hook(ToolResultPrinterHook::new(tool_call_names)); + worker.add_post_tool_call_hook(ToolResultPrinterHook::new(tool_call_names)); // ワンショットモード if let Some(prompt) = args.prompt { diff --git a/llm-worker/src/hook.rs b/llm-worker/src/hook.rs index 1007dec..9fed3af 100644 --- a/llm-worker/src/hook.rs +++ b/llm-worker/src/hook.rs @@ -16,20 +16,27 @@ pub trait HookEventKind: Send + Sync + 'static { type Output; } -pub struct OnMessageSend; -pub struct BeforeToolCall; -pub struct AfterToolCall; +pub struct OnPromptSubmit; +pub struct PreLlmRequest; +pub struct PreToolCall; +pub struct PostToolCall; pub struct OnTurnEnd; pub struct OnAbort; #[derive(Debug, Clone, PartialEq, Eq)] -pub enum OnMessageSendResult { +pub enum OnPromptSubmitResult { Continue, Cancel(String), } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum BeforeToolCallResult { +pub enum PreLlmRequestResult { + Continue, + Cancel(String), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PreToolCallResult { Continue, Skip, Abort(String), @@ -37,7 +44,7 @@ pub enum BeforeToolCallResult { } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum AfterToolCallResult { +pub enum PostToolCallResult { Continue, Abort(String), } @@ -49,19 +56,50 @@ pub enum OnTurnEndResult { Paused, } -impl HookEventKind for OnMessageSend { +use std::sync::Arc; + +use crate::tool::{Tool, ToolMeta}; + +/// PreToolCall の入力コンテキスト +pub struct ToolCallContext { + /// ツール呼び出し情報(改変可能) + pub call: ToolCall, + /// ツールメタ情報(不変) + pub meta: ToolMeta, + /// ツールインスタンス(状態アクセス用) + pub tool: Arc, +} + +/// PostToolCall の入力コンテキスト +pub struct PostToolCallContext { + /// ツール呼び出し情報 + pub call: ToolCall, + /// ツール実行結果(改変可能) + pub result: ToolResult, + /// ツールメタ情報(不変) + pub meta: ToolMeta, + /// ツールインスタンス(状態アクセス用) + pub tool: Arc, +} + +impl HookEventKind for OnPromptSubmit { + type Input = crate::Message; + type Output = OnPromptSubmitResult; +} + +impl HookEventKind for PreLlmRequest { type Input = Vec; - type Output = OnMessageSendResult; + type Output = PreLlmRequestResult; } -impl HookEventKind for BeforeToolCall { - type Input = ToolCall; - type Output = BeforeToolCallResult; +impl HookEventKind for PreToolCall { + type Input = ToolCallContext; + type Output = PreToolCallResult; } -impl HookEventKind for AfterToolCall { - type Input = ToolResult; - type Output = AfterToolCallResult; +impl HookEventKind for PostToolCall { + type Input = PostToolCallContext; + type Output = PostToolCallResult; } impl HookEventKind for OnTurnEnd { @@ -151,3 +189,45 @@ pub enum HookError { pub trait Hook: Send + Sync { async fn call(&self, input: &mut E::Input) -> Result; } + +// ============================================================================= +// Hook Registry +// ============================================================================= + +/// 全 Hook を保持するレジストリ +/// +/// Worker 内部で使用され、各種 Hook を一括管理する。 +pub struct HookRegistry { + /// on_prompt_submit Hook + pub(crate) on_prompt_submit: Vec>>, + /// pre_llm_request Hook + pub(crate) pre_llm_request: Vec>>, + /// pre_tool_call Hook + pub(crate) pre_tool_call: Vec>>, + /// post_tool_call Hook + pub(crate) post_tool_call: Vec>>, + /// on_turn_end Hook + pub(crate) on_turn_end: Vec>>, + /// on_abort Hook + pub(crate) on_abort: Vec>>, +} + +impl Default for HookRegistry { + fn default() -> Self { + Self::new() + } +} + +impl HookRegistry { + /// 空の HookRegistry を作成 + pub fn new() -> Self { + Self { + on_prompt_submit: Vec::new(), + pre_llm_request: Vec::new(), + pre_tool_call: Vec::new(), + post_tool_call: Vec::new(), + on_turn_end: Vec::new(), + on_abort: Vec::new(), + } + } +} diff --git a/llm-worker/src/lib.rs b/llm-worker/src/lib.rs index ef743b4..2a311ab 100644 --- a/llm-worker/src/lib.rs +++ b/llm-worker/src/lib.rs @@ -19,8 +19,7 @@ //! .system_prompt("You are a helpful assistant."); //! //! // ツールを登録(オプション) -//! use llm_worker::tool::Tool; -//! worker.register_tool(my_tool); +//! // worker.register_tool(my_tool_definition)?; //! //! // 対話を実行 //! let history = worker.run("Hello!").await?; @@ -49,4 +48,4 @@ pub mod timeline; pub mod tool; pub use message::{ContentPart, Message, MessageContent, Role}; -pub use worker::{Worker, WorkerConfig, WorkerError, WorkerResult}; +pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult}; diff --git a/llm-worker/src/tool.rs b/llm-worker/src/tool.rs index 9585eca..eabe57f 100644 --- a/llm-worker/src/tool.rs +++ b/llm-worker/src/tool.rs @@ -3,6 +3,8 @@ //! LLMから呼び出し可能なツールを定義するためのトレイト。 //! 通常は`#[tool]`マクロを使用して自動実装します。 +use std::sync::Arc; + use async_trait::async_trait; use serde_json::Value; use thiserror::Error; @@ -21,64 +23,126 @@ pub enum ToolError { Internal(String), } +// ============================================================================= +// ToolMeta - 不変のメタ情報 +// ============================================================================= + +/// ツールのメタ情報(登録時に固定、不変) +/// +/// `ToolDefinition` ファクトリから生成され、Worker に登録後は変更されません。 +/// LLM へのツール定義送信に使用されます。 +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ToolMeta { + /// ツール名(LLMが識別に使用) + pub name: String, + /// ツールの説明(LLMへのプロンプトに含まれる) + pub description: String, + /// 引数のJSON Schema + pub input_schema: Value, +} + +impl ToolMeta { + /// 新しい ToolMeta を作成 + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + description: String::new(), + input_schema: Value::Object(Default::default()), + } + } + + /// 説明を設定 + pub fn description(mut self, desc: impl Into) -> Self { + self.description = desc.into(); + self + } + + /// 引数スキーマを設定 + pub fn input_schema(mut self, schema: Value) -> Self { + self.input_schema = schema; + self + } +} + +// ============================================================================= +// ToolDefinition - ファクトリ型 +// ============================================================================= + +/// ツール定義ファクトリ +/// +/// 呼び出すと `(ToolMeta, Arc)` を返します。 +/// Worker への登録時に一度だけ呼び出され、メタ情報とインスタンスが +/// セッションスコープでキャッシュされます。 +/// +/// # Examples +/// +/// ```ignore +/// let def: ToolDefinition = Arc::new(|| { +/// ( +/// ToolMeta::new("my_tool") +/// .description("My tool description") +/// .input_schema(json!({"type": "object"})), +/// Arc::new(MyToolImpl { state: 0 }) as Arc, +/// ) +/// }); +/// worker.register_tool(def)?; +/// ``` +pub type ToolDefinition = Arc (ToolMeta, Arc) + Send + Sync>; + +// ============================================================================= +// Tool trait +// ============================================================================= + /// LLMから呼び出し可能なツールを定義するトレイト /// /// ツールはLLMが外部リソースにアクセスしたり、 /// 計算を実行したりするために使用します。 +/// セッション中の状態を保持できます。 /// /// # 実装方法 /// -/// 通常は`#[tool]`マクロを使用して自動実装します: +/// 通常は`#[tool_registry]`マクロを使用して自動実装します: /// /// ```ignore -/// use llm_worker::tool; -/// -/// #[tool(description = "Search the web for information")] -/// async fn search(query: String) -> String { -/// // 検索処理 -/// format!("Results for: {}", query) +/// #[tool_registry] +/// impl MyApp { +/// #[tool] +/// async fn search(&self, query: String) -> String { +/// format!("Results for: {}", query) +/// } /// } +/// +/// // 登録 +/// worker.register_tool(app.search_definition())?; /// ``` /// /// # 手動実装 /// /// ```ignore -/// use llm_worker::tool::{Tool, ToolError}; -/// use serde_json::{json, Value}; +/// use llm_worker::tool::{Tool, ToolError, ToolMeta, ToolDefinition}; +/// use std::sync::Arc; /// -/// struct MyTool; +/// struct MyTool { counter: std::sync::atomic::AtomicUsize } /// /// #[async_trait::async_trait] /// impl Tool for MyTool { -/// fn name(&self) -> &str { "my_tool" } -/// fn description(&self) -> &str { "My custom tool" } -/// fn input_schema(&self) -> Value { -/// json!({ -/// "type": "object", -/// "properties": { -/// "query": { "type": "string" } -/// }, -/// "required": ["query"] -/// }) -/// } /// async fn execute(&self, input: &str) -> Result { +/// self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst); /// Ok("result".to_string()) /// } /// } +/// +/// let def: ToolDefinition = Arc::new(|| { +/// ( +/// ToolMeta::new("my_tool") +/// .description("My custom tool") +/// .input_schema(serde_json::json!({"type": "object"})), +/// Arc::new(MyTool { counter: Default::default() }) as Arc, +/// ) +/// }); /// ``` #[async_trait] pub trait Tool: Send + Sync { - /// ツール名(LLMが識別に使用) - fn name(&self) -> &str; - - /// ツールの説明(LLMへのプロンプトに含まれる) - fn description(&self) -> &str; - - /// 引数のJSON Schema - /// - /// LLMはこのスキーマに従って引数を生成します。 - fn input_schema(&self) -> Value; - /// ツールを実行する /// /// # Arguments diff --git a/llm-worker/src/worker.rs b/llm-worker/src/worker.rs index 0b413c0..d8c32a0 100644 --- a/llm-worker/src/worker.rs +++ b/llm-worker/src/worker.rs @@ -9,18 +9,21 @@ use tracing::{debug, info, trace, warn}; use crate::{ ContentPart, Message, MessageContent, Role, hook::{ - AfterToolCall, AfterToolCallResult, BeforeToolCall, BeforeToolCallResult, Hook, HookError, - OnAbort, OnMessageSend, OnMessageSendResult, OnTurnEnd, OnTurnEndResult, ToolCall, - ToolResult, + Hook, HookError, HookRegistry, OnAbort, OnPromptSubmit, OnPromptSubmitResult, OnTurnEnd, + OnTurnEndResult, PostToolCall, PostToolCallContext, PostToolCallResult, PreLlmRequest, + PreLlmRequestResult, PreToolCall, PreToolCallResult, ToolCall, ToolCallContext, ToolResult, + }, + llm_client::{ + ClientError, ConfigWarning, LlmClient, Request, RequestConfig, + ToolDefinition as LlmToolDefinition, }, - llm_client::{ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ToolDefinition}, state::{Locked, Mutable, WorkerState}, subscriber::{ ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter, ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, WorkerSubscriber, }, timeline::{TextBlockCollector, Timeline, ToolCallCollector}, - tool::{Tool, ToolError}, + tool::{Tool, ToolDefinition, ToolError, ToolMeta}, }; // ============================================================================= @@ -50,6 +53,14 @@ pub enum WorkerError { ConfigWarnings(Vec), } +/// ツール登録エラー +#[derive(Debug, thiserror::Error)] +pub enum ToolRegistryError { + /// 同名のツールが既に登録されている + #[error("Tool with name '{0}' already registered")] + DuplicateName(String), +} + // ============================================================================= // Worker Config // ============================================================================= @@ -155,18 +166,10 @@ pub struct Worker { text_block_collector: TextBlockCollector, /// ツールコールコレクター(Timeline用ハンドラ) tool_call_collector: ToolCallCollector, - /// 登録されたツール - tools: HashMap>, - /// on_message_send Hook - hooks_on_message_send: Vec>>, - /// before_tool_call Hook - hooks_before_tool_call: Vec>>, - /// after_tool_call Hook - hooks_after_tool_call: Vec>>, - /// on_turn_end Hook - hooks_on_turn_end: Vec>>, - /// on_abort Hook - hooks_on_abort: Vec>>, + /// 登録されたツール (meta, instance) + tools: HashMap)>, + /// Hook レジストリ + hooks: HookRegistry, /// システムプロンプト system_prompt: Option, /// メッセージ履歴(Workerが所有) @@ -248,51 +251,71 @@ impl Worker { /// ツールを登録する /// /// 登録されたツールはLLMからの呼び出しで自動的に実行されます。 - /// 同名のツールを登録した場合、後から登録したものが優先されます。 + /// 同名のツールを登録するとエラーになります。 /// /// # Examples /// /// ```ignore - /// use llm_worker::Worker; - /// use my_tools::SearchTool; + /// use llm_worker::tool::{ToolMeta, ToolDefinition, Tool}; + /// use std::sync::Arc; /// - /// worker.register_tool(SearchTool::new()); + /// let def: ToolDefinition = Arc::new(|| { + /// (ToolMeta::new("search").description("..."), Arc::new(MyTool) as Arc) + /// }); + /// worker.register_tool(def)?; /// ``` - pub fn register_tool(&mut self, tool: impl Tool + 'static) { - let name = tool.name().to_string(); - self.tools.insert(name, Arc::new(tool)); + pub fn register_tool(&mut self, factory: ToolDefinition) -> Result<(), ToolRegistryError> { + let (meta, instance) = factory(); + if self.tools.contains_key(&meta.name) { + return Err(ToolRegistryError::DuplicateName(meta.name.clone())); + } + self.tools.insert(meta.name.clone(), (meta, instance)); + Ok(()) } /// 複数のツールを登録 - pub fn register_tools(&mut self, tools: impl IntoIterator) { - for tool in tools { - self.register_tool(tool); + pub fn register_tools( + &mut self, + factories: impl IntoIterator, + ) -> Result<(), ToolRegistryError> { + for factory in factories { + self.register_tool(factory)?; } + Ok(()) } - /// on_message_send Hookを追加する - pub fn add_on_message_send_hook(&mut self, hook: impl Hook + 'static) { - self.hooks_on_message_send.push(Box::new(hook)); + /// on_prompt_submit Hookを追加する + /// + /// `run()` でユーザーメッセージを受け取った直後に呼び出される。 + pub fn add_on_prompt_submit_hook(&mut self, hook: impl Hook + 'static) { + self.hooks.on_prompt_submit.push(Box::new(hook)); } - /// before_tool_call Hookを追加する - pub fn add_before_tool_call_hook(&mut self, hook: impl Hook + 'static) { - self.hooks_before_tool_call.push(Box::new(hook)); + /// pre_llm_request Hookを追加する + /// + /// 各ターンのLLMリクエスト送信前に呼び出される。 + pub fn add_pre_llm_request_hook(&mut self, hook: impl Hook + 'static) { + self.hooks.pre_llm_request.push(Box::new(hook)); } - /// after_tool_call Hookを追加する - pub fn add_after_tool_call_hook(&mut self, hook: impl Hook + 'static) { - self.hooks_after_tool_call.push(Box::new(hook)); + /// pre_tool_call Hookを追加する + pub fn add_pre_tool_call_hook(&mut self, hook: impl Hook + 'static) { + self.hooks.pre_tool_call.push(Box::new(hook)); + } + + /// post_tool_call Hookを追加する + pub fn add_post_tool_call_hook(&mut self, hook: impl Hook + 'static) { + self.hooks.post_tool_call.push(Box::new(hook)); } /// on_turn_end Hookを追加する pub fn add_on_turn_end_hook(&mut self, hook: impl Hook + 'static) { - self.hooks_on_turn_end.push(Box::new(hook)); + self.hooks.on_turn_end.push(Box::new(hook)); } /// on_abort Hookを追加する pub fn add_on_abort_hook(&mut self, hook: impl Hook + 'static) { - self.hooks_on_abort.push(Box::new(hook)); + self.hooks.on_abort.push(Box::new(hook)); } /// タイムラインへの可変参照を取得(追加ハンドラ登録用) @@ -427,14 +450,14 @@ impl Worker { &self.cancellation_token } - /// 登録されたツールからToolDefinitionのリストを生成 - fn build_tool_definitions(&self) -> Vec { + /// 登録されたツールからLLM用ToolDefinitionのリストを生成 + fn build_tool_definitions(&self) -> Vec { self.tools .values() - .map(|tool| { - ToolDefinition::new(tool.name()) - .description(tool.description()) - .input_schema(tool.input_schema()) + .map(|(meta, _)| { + LlmToolDefinition::new(&meta.name) + .description(&meta.description) + .input_schema(meta.input_schema.clone()) }) .collect() } @@ -482,7 +505,11 @@ impl Worker { } /// リクエストを構築 - fn build_request(&self, tool_definitions: &[ToolDefinition], context: &[Message]) -> Request { + fn build_request( + &self, + tool_definitions: &[LlmToolDefinition], + context: &[Message], + ) -> Request { let mut request = Request::new(); // システムプロンプトを設定 @@ -546,27 +573,48 @@ impl Worker { request } - /// Hooks: on_message_send - async fn run_on_message_send_hooks( + /// Hooks: on_prompt_submit + /// + /// `run()` でユーザーメッセージを受け取った直後に呼び出される(最初だけ)。 + async fn run_on_prompt_submit_hooks( &self, - ) -> Result<(OnMessageSendResult, Vec), WorkerError> { - let mut temp_context = self.history.clone(); - for hook in &self.hooks_on_message_send { - let result = hook.call(&mut temp_context).await?; + message: &mut Message, + ) -> Result { + for hook in &self.hooks.on_prompt_submit { + let result = hook.call(message).await?; match result { - OnMessageSendResult::Continue => continue, - OnMessageSendResult::Cancel(reason) => { - return Ok((OnMessageSendResult::Cancel(reason), temp_context)); + OnPromptSubmitResult::Continue => continue, + OnPromptSubmitResult::Cancel(reason) => { + return Ok(OnPromptSubmitResult::Cancel(reason)); } } } - Ok((OnMessageSendResult::Continue, temp_context)) + Ok(OnPromptSubmitResult::Continue) + } + + /// Hooks: pre_llm_request + /// + /// 各ターンのLLMリクエスト送信前に呼び出される(毎ターン)。 + async fn run_pre_llm_request_hooks( + &self, + ) -> Result<(PreLlmRequestResult, Vec), WorkerError> { + let mut temp_context = self.history.clone(); + for hook in &self.hooks.pre_llm_request { + let result = hook.call(&mut temp_context).await?; + match result { + PreLlmRequestResult::Continue => continue, + PreLlmRequestResult::Cancel(reason) => { + return Ok((PreLlmRequestResult::Cancel(reason), temp_context)); + } + } + } + Ok((PreLlmRequestResult::Continue, temp_context)) } /// Hooks: on_turn_end async fn run_on_turn_end_hooks(&self) -> Result { let mut temp_messages = self.history.clone(); - for hook in &self.hooks_on_turn_end { + for hook in &self.hooks.on_turn_end { let result = hook.call(&mut temp_messages).await?; match result { OnTurnEndResult::Finish => continue, @@ -582,7 +630,7 @@ impl Worker { /// Hooks: on_abort async fn run_on_abort_hooks(&self, reason: &str) -> Result<(), WorkerError> { let mut reason = reason.to_string(); - for hook in &self.hooks_on_abort { + for hook in &self.hooks.on_abort { hook.call(&mut reason).await?; } Ok(()) @@ -608,44 +656,67 @@ impl Worker { } } - if calls.is_empty() { - None - } else { - Some(calls) - } + if calls.is_empty() { None } else { Some(calls) } } /// ツールを並列実行 /// - /// 全てのツールに対してbefore_tool_callフックを実行後、 - /// 許可されたツールを並列に実行し、結果にafter_tool_callフックを適用する。 + /// 全てのツールに対してpre_tool_callフックを実行後、 + /// 許可されたツールを並列に実行し、結果にpost_tool_callフックを適用する。 async fn execute_tools( &mut self, tool_calls: Vec, ) -> Result { use futures::future::join_all; - // Phase 1: before_tool_call フックを適用(スキップ/中断を判定) + // ツール呼び出しIDから (ToolCall, Meta, Tool) へのマップ + // PostToolCallフックで必要になるため保持する + let mut call_info_map = HashMap::new(); + + // Phase 1: pre_tool_call フックを適用(スキップ/中断を判定) let mut approved_calls = Vec::new(); for mut tool_call in tool_calls { - let mut skip = false; - for hook in &self.hooks_before_tool_call { - let result = hook.call(&mut tool_call).await?; - match result { - BeforeToolCallResult::Continue => {} - BeforeToolCallResult::Skip => { - skip = true; - break; - } - BeforeToolCallResult::Abort(reason) => { - return Err(WorkerError::Aborted(reason)); - } - BeforeToolCallResult::Pause => { - return Ok(ToolExecutionResult::Paused); + // ツール定義を取得 + if let Some((meta, tool)) = self.tools.get(&tool_call.name) { + // コンテキストを作成 + let mut context = ToolCallContext { + call: tool_call.clone(), + meta: meta.clone(), + tool: tool.clone(), + }; + + let mut skip = false; + for hook in &self.hooks.pre_tool_call { + let result = hook.call(&mut context).await?; + match result { + PreToolCallResult::Continue => {} + PreToolCallResult::Skip => { + skip = true; + break; + } + PreToolCallResult::Abort(reason) => { + return Err(WorkerError::Aborted(reason)); + } + PreToolCallResult::Pause => { + return Ok(ToolExecutionResult::Paused); + } } } - } - if !skip { + + // フックで変更された内容を反映 + tool_call = context.call; + + // マップに保存(実行する場合のみ) + if !skip { + call_info_map.insert( + tool_call.id.clone(), + (tool_call.clone(), meta.clone(), tool.clone()), + ); + approved_calls.push(tool_call); + } + } else { + // 未知のツールはそのまま承認リストに入れる(実行時にエラーになる) + // Hookは適用しない(Metaがないため) approved_calls.push(tool_call); } } @@ -656,7 +727,7 @@ impl Worker { .map(|tool_call| { let tools = &self.tools; async move { - if let Some(tool) = tools.get(&tool_call.name) { + if let Some((_, tool)) = tools.get(&tool_call.name) { let input_json = serde_json::to_string(&tool_call.input).unwrap_or_default(); match tool.execute(&input_json).await { @@ -684,16 +755,28 @@ impl Worker { } }; - // Phase 3: after_tool_call フックを適用 + // Phase 3: post_tool_call フックを適用 for tool_result in &mut results { - for hook in &self.hooks_after_tool_call { - let result = hook.call(tool_result).await?; - match result { - AfterToolCallResult::Continue => {} - AfterToolCallResult::Abort(reason) => { - return Err(WorkerError::Aborted(reason)); + // 保存しておいた情報を取得 + if let Some((tool_call, meta, tool)) = call_info_map.get(&tool_result.tool_use_id) { + let mut context = PostToolCallContext { + call: tool_call.clone(), + result: tool_result.clone(), + meta: meta.clone(), + tool: tool.clone(), + }; + + for hook in &self.hooks.post_tool_call { + let result = hook.call(&mut context).await?; + match result { + PostToolCallResult::Continue => {} + PostToolCallResult::Abort(reason) => { + return Err(WorkerError::Aborted(reason)); + } } } + // フックで変更された結果を反映 + *tool_result = context.result; } } @@ -712,16 +795,17 @@ impl Worker { // Resume check: Pending tool calls if let Some(tool_calls) = self.get_pending_tool_calls() { - info!("Resuming pending tool calls"); - match self.execute_tools(tool_calls).await? { - ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)), - ToolExecutionResult::Completed(results) => { - for result in results { - self.history.push(Message::tool_result(&result.tool_use_id, &result.content)); - } - // Continue to loop - } - } + info!("Resuming pending tool calls"); + match self.execute_tools(tool_calls).await? { + ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)), + ToolExecutionResult::Completed(results) => { + for result in results { + self.history + .push(Message::tool_result(&result.tool_use_id, &result.content)); + } + // Continue to loop + } + } } loop { @@ -740,10 +824,10 @@ impl Worker { notifier.on_turn_start(current_turn); } - // Hook: on_message_send - let (control, request_context) = self.run_on_message_send_hooks().await?; + // Hook: pre_llm_request + let (control, request_context) = self.run_pre_llm_request_hooks().await?; match control { - OnMessageSendResult::Cancel(reason) => { + PreLlmRequestResult::Cancel(reason) => { info!(reason = %reason, "Aborted by hook"); for notifier in &self.turn_notifiers { notifier.on_turn_end(current_turn); @@ -751,7 +835,7 @@ impl Worker { self.run_on_abort_hooks(&reason).await?; return Err(WorkerError::Aborted(reason)); } - OnMessageSendResult::Continue => {} + PreLlmRequestResult::Continue => {} } // リクエスト構築 @@ -766,7 +850,7 @@ impl Worker { // ストリーム処理 debug!("Starting stream..."); let mut event_count = 0; - + // ストリームを取得(キャンセル可能) let mut stream = tokio::select! { stream_result = self.client.stream(request) => stream_result?, @@ -777,7 +861,7 @@ impl Worker { return Err(WorkerError::Cancelled); } }; - + loop { tokio::select! { // ストリームからイベントを受信 @@ -846,12 +930,13 @@ impl Worker { // ツール実行 match self.execute_tools(tool_calls).await? { - ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)), - ToolExecutionResult::Completed(results) => { - for result in results { - self.history.push(Message::tool_result(&result.tool_use_id, &result.content)); - } - } + ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)), + ToolExecutionResult::Completed(results) => { + for result in results { + self.history + .push(Message::tool_result(&result.tool_use_id, &result.content)); + } + } } } } @@ -885,11 +970,7 @@ impl Worker { text_block_collector, tool_call_collector, tools: HashMap::new(), - hooks_on_message_send: Vec::new(), - hooks_before_tool_call: Vec::new(), - hooks_after_tool_call: Vec::new(), - hooks_on_turn_end: Vec::new(), - hooks_on_abort: Vec::new(), + hooks: HookRegistry::new(), system_prompt: None, history: Vec::new(), locked_prefix_len: 0, @@ -1058,11 +1139,7 @@ impl Worker { text_block_collector: self.text_block_collector, tool_call_collector: self.tool_call_collector, tools: self.tools, - hooks_on_message_send: self.hooks_on_message_send, - hooks_before_tool_call: self.hooks_before_tool_call, - hooks_after_tool_call: self.hooks_after_tool_call, - hooks_on_turn_end: self.hooks_on_turn_end, - hooks_on_abort: self.hooks_on_abort, + hooks: self.hooks, system_prompt: self.system_prompt, history: self.history, locked_prefix_len, @@ -1081,8 +1158,21 @@ impl Worker { /// /// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は /// `lock()` を呼んでからLocked状態で `run` を使用すること。 - pub async fn run(&mut self, user_input: impl Into) -> Result, WorkerError> { - self.history.push(Message::user(user_input)); + pub async fn run( + &mut self, + user_input: impl Into, + ) -> Result, WorkerError> { + // Hook: on_prompt_submit + let mut user_message = Message::user(user_input); + let result = self.run_on_prompt_submit_hooks(&mut user_message).await?; + match result { + OnPromptSubmitResult::Cancel(reason) => { + self.run_on_abort_hooks(&reason).await?; + return Err(WorkerError::Aborted(reason)); + } + OnPromptSubmitResult::Continue => {} + } + self.history.push(user_message); self.run_turn_loop().await } @@ -1107,8 +1197,21 @@ impl Worker { /// /// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。 /// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。 - pub async fn run(&mut self, user_input: impl Into) -> Result, WorkerError> { - self.history.push(Message::user(user_input)); + pub async fn run( + &mut self, + user_input: impl Into, + ) -> Result, WorkerError> { + // Hook: on_prompt_submit + let mut user_message = Message::user(user_input); + let result = self.run_on_prompt_submit_hooks(&mut user_message).await?; + match result { + OnPromptSubmitResult::Cancel(reason) => { + self.run_on_abort_hooks(&reason).await?; + return Err(WorkerError::Aborted(reason)); + } + OnPromptSubmitResult::Continue => {} + } + self.history.push(user_message); self.run_turn_loop().await } @@ -1137,11 +1240,7 @@ impl Worker { text_block_collector: self.text_block_collector, tool_call_collector: self.tool_call_collector, tools: self.tools, - hooks_on_message_send: self.hooks_on_message_send, - hooks_before_tool_call: self.hooks_before_tool_call, - hooks_after_tool_call: self.hooks_after_tool_call, - hooks_on_turn_end: self.hooks_on_turn_end, - hooks_on_abort: self.hooks_on_abort, + hooks: self.hooks, system_prompt: self.system_prompt, history: self.history, locked_prefix_len: 0, diff --git a/llm-worker/tests/parallel_execution_test.rs b/llm-worker/tests/parallel_execution_test.rs index 485df23..f3c2769 100644 --- a/llm-worker/tests/parallel_execution_test.rs +++ b/llm-worker/tests/parallel_execution_test.rs @@ -9,11 +9,11 @@ use std::time::{Duration, Instant}; use async_trait::async_trait; use llm_worker::Worker; use llm_worker::hook::{ - AfterToolCall, AfterToolCallResult, BeforeToolCall, BeforeToolCallResult, Hook, HookError, - ToolCall, ToolResult, + Hook, HookError, PostToolCall, PostToolCallContext, PostToolCallResult, PreToolCall, + PreToolCallResult, ToolCallContext, }; use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent}; -use llm_worker::tool::{Tool, ToolError}; +use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta}; mod common; use common::MockLlmClient; @@ -42,25 +42,24 @@ impl SlowTool { fn call_count(&self) -> usize { self.call_count.load(Ordering::SeqCst) } + + /// ToolDefinition を作成 + fn definition(&self) -> ToolDefinition { + let tool = self.clone(); + Arc::new(move || { + let meta = ToolMeta::new(&tool.name) + .description("A tool that waits before responding") + .input_schema(serde_json::json!({ + "type": "object", + "properties": {} + })); + (meta, Arc::new(tool.clone()) as Arc) + }) + } } #[async_trait] impl Tool for SlowTool { - fn name(&self) -> &str { - &self.name - } - - fn description(&self) -> &str { - "A tool that waits before responding" - } - - fn input_schema(&self) -> serde_json::Value { - serde_json::json!({ - "type": "object", - "properties": {} - }) - } - async fn execute(&self, _input_json: &str) -> Result { self.call_count.fetch_add(1, Ordering::SeqCst); tokio::time::sleep(Duration::from_millis(self.delay_ms)).await; @@ -106,9 +105,9 @@ async fn test_parallel_tool_execution() { let tool2_clone = tool2.clone(); let tool3_clone = tool3.clone(); - worker.register_tool(tool1); - worker.register_tool(tool2); - worker.register_tool(tool3); + worker.register_tool(tool1.definition()).unwrap(); + worker.register_tool(tool2.definition()).unwrap(); + worker.register_tool(tool3.definition()).unwrap(); let start = Instant::now(); let _result = worker.run("Run all tools").await; @@ -130,7 +129,7 @@ async fn test_parallel_tool_execution() { println!("Parallel execution completed in {:?}", elapsed); } -/// Hook: before_tool_call でスキップされたツールは実行されないことを確認 +/// Hook: pre_tool_call でスキップされたツールは実行されないことを確認 #[tokio::test] async fn test_before_tool_call_skip() { let events = vec![ @@ -154,24 +153,24 @@ async fn test_before_tool_call_skip() { let allowed_clone = allowed_tool.clone(); let blocked_clone = blocked_tool.clone(); - worker.register_tool(allowed_tool); - worker.register_tool(blocked_tool); + worker.register_tool(allowed_tool.definition()).unwrap(); + worker.register_tool(blocked_tool.definition()).unwrap(); // "blocked_tool" をスキップするHook struct BlockingHook; #[async_trait] - impl Hook for BlockingHook { - async fn call(&self, tool_call: &mut ToolCall) -> Result { - if tool_call.name == "blocked_tool" { - Ok(BeforeToolCallResult::Skip) + impl Hook for BlockingHook { + async fn call(&self, ctx: &mut ToolCallContext) -> Result { + if ctx.call.name == "blocked_tool" { + Ok(PreToolCallResult::Skip) } else { - Ok(BeforeToolCallResult::Continue) + Ok(PreToolCallResult::Continue) } } } - worker.add_before_tool_call_hook(BlockingHook); + worker.add_pre_tool_call_hook(BlockingHook); let _result = worker.run("Test hook").await; @@ -188,9 +187,9 @@ async fn test_before_tool_call_skip() { ); } -/// Hook: after_tool_call で結果が改変されることを確認 +/// Hook: post_tool_call で結果が改変されることを確認 #[tokio::test] -async fn test_after_tool_call_modification() { +async fn test_post_tool_call_modification() { // 複数リクエストに対応するレスポンスを準備 let client = MockLlmClient::with_responses(vec![ // 1回目のリクエスト: ツール呼び出し @@ -220,21 +219,21 @@ async fn test_after_tool_call_modification() { #[async_trait] impl Tool for SimpleTool { - fn name(&self) -> &str { - "test_tool" - } - fn description(&self) -> &str { - "Test" - } - fn input_schema(&self) -> serde_json::Value { - serde_json::json!({}) - } async fn execute(&self, _: &str) -> Result { Ok("Original Result".to_string()) } } - worker.register_tool(SimpleTool); + fn simple_tool_definition() -> ToolDefinition { + Arc::new(|| { + let meta = ToolMeta::new("test_tool") + .description("Test") + .input_schema(serde_json::json!({})); + (meta, Arc::new(SimpleTool) as Arc) + }) + } + + worker.register_tool(simple_tool_definition()).unwrap(); // 結果を改変するHook struct ModifyingHook { @@ -242,19 +241,19 @@ async fn test_after_tool_call_modification() { } #[async_trait] - impl Hook for ModifyingHook { + impl Hook for ModifyingHook { async fn call( &self, - tool_result: &mut ToolResult, - ) -> Result { - tool_result.content = format!("[Modified] {}", tool_result.content); - *self.modified_content.lock().unwrap() = Some(tool_result.content.clone()); - Ok(AfterToolCallResult::Continue) + ctx: &mut PostToolCallContext, + ) -> Result { + ctx.result.content = format!("[Modified] {}", ctx.result.content); + *self.modified_content.lock().unwrap() = Some(ctx.result.content.clone()); + Ok(PostToolCallResult::Continue) } } let modified_content = Arc::new(std::sync::Mutex::new(None)); - worker.add_after_tool_call_hook(ModifyingHook { + worker.add_post_tool_call_hook(ModifyingHook { modified_content: modified_content.clone(), }); diff --git a/llm-worker/tests/tool_macro_test.rs b/llm-worker/tests/tool_macro_test.rs index 04f8f64..7b04cfe 100644 --- a/llm-worker/tests/tool_macro_test.rs +++ b/llm-worker/tests/tool_macro_test.rs @@ -9,7 +9,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use schemars; use serde; -use llm_worker::tool::Tool; +use llm_worker::tool::{Tool, ToolMeta}; use llm_worker_macros::tool_registry; // ============================================================================= @@ -51,30 +51,31 @@ async fn test_basic_tool_generation() { prefix: "Hello".to_string(), }; - // ファクトリメソッドでツールを取得 - let greet_tool = ctx.greet_tool(); + // ファクトリメソッドでToolDefinitionを取得 + let greet_definition = ctx.greet_definition(); - // 名前の確認 - assert_eq!(greet_tool.name(), "greet"); + // ファクトリを呼び出してMetaとToolを取得 + let (meta, tool) = greet_definition(); - // 説明の確認(docコメントから取得) - let desc = greet_tool.description(); + // メタ情報の確認 + assert_eq!(meta.name, "greet"); assert!( - desc.contains("メッセージに挨拶を追加する"), + meta.description.contains("メッセージに挨拶を追加する"), "Description should contain doc comment: {}", - desc + meta.description ); - - // スキーマの確認 - let schema = greet_tool.input_schema(); - println!("Schema: {}", serde_json::to_string_pretty(&schema).unwrap()); assert!( - schema.get("properties").is_some(), + meta.input_schema.get("properties").is_some(), "Schema should have properties" ); + println!( + "Schema: {}", + serde_json::to_string_pretty(&meta.input_schema).unwrap() + ); + // 実行テスト - let result = greet_tool.execute(r#"{"message": "World"}"#).await; + let result = tool.execute(r#"{"message": "World"}"#).await; assert!(result.is_ok(), "Should execute successfully"); let output = result.unwrap(); assert!(output.contains("Hello"), "Output should contain prefix"); @@ -87,11 +88,11 @@ async fn test_multiple_arguments() { prefix: "".to_string(), }; - let add_tool = ctx.add_tool(); + let (meta, tool) = ctx.add_definition()(); - assert_eq!(add_tool.name(), "add"); + assert_eq!(meta.name, "add"); - let result = add_tool.execute(r#"{"a": 10, "b": 20}"#).await; + let result = tool.execute(r#"{"a": 10, "b": 20}"#).await; assert!(result.is_ok()); let output = result.unwrap(); assert!(output.contains("30"), "Should contain sum: {}", output); @@ -103,12 +104,12 @@ async fn test_no_arguments() { prefix: "TestPrefix".to_string(), }; - let get_prefix_tool = ctx.get_prefix_tool(); + let (meta, tool) = ctx.get_prefix_definition()(); - assert_eq!(get_prefix_tool.name(), "get_prefix"); + assert_eq!(meta.name, "get_prefix"); // 空のJSONオブジェクトで呼び出し - let result = get_prefix_tool.execute(r#"{}"#).await; + let result = tool.execute(r#"{}"#).await; assert!(result.is_ok()); let output = result.unwrap(); assert!( @@ -124,10 +125,10 @@ async fn test_invalid_arguments() { prefix: "".to_string(), }; - let greet_tool = ctx.greet_tool(); + let (_, tool) = ctx.greet_definition()(); // 不正なJSON - let result = greet_tool.execute(r#"{"wrong_field": "value"}"#).await; + let result = tool.execute(r#"{"wrong_field": "value"}"#).await; assert!(result.is_err(), "Should fail with invalid arguments"); } @@ -163,9 +164,9 @@ impl FallibleContext { #[tokio::test] async fn test_result_return_type_success() { let ctx = FallibleContext; - let validate_tool = ctx.validate_tool(); + let (_, tool) = ctx.validate_definition()(); - let result = validate_tool.execute(r#"{"value": 42}"#).await; + let result = tool.execute(r#"{"value": 42}"#).await; assert!(result.is_ok(), "Should succeed for positive value"); let output = result.unwrap(); assert!(output.contains("Valid"), "Should contain Valid: {}", output); @@ -174,9 +175,9 @@ async fn test_result_return_type_success() { #[tokio::test] async fn test_result_return_type_error() { let ctx = FallibleContext; - let validate_tool = ctx.validate_tool(); + let (_, tool) = ctx.validate_definition()(); - let result = validate_tool.execute(r#"{"value": -1}"#).await; + let result = tool.execute(r#"{"value": -1}"#).await; assert!(result.is_err(), "Should fail for negative value"); let err = result.unwrap_err(); @@ -211,12 +212,12 @@ async fn test_sync_method() { counter: Arc::new(AtomicUsize::new(0)), }; - let increment_tool = ctx.increment_tool(); + let (_, tool) = ctx.increment_definition()(); // 3回実行 - let result1 = increment_tool.execute(r#"{}"#).await; - let result2 = increment_tool.execute(r#"{}"#).await; - let result3 = increment_tool.execute(r#"{}"#).await; + let result1 = tool.execute(r#"{}"#).await; + let result2 = tool.execute(r#"{}"#).await; + let result3 = tool.execute(r#"{}"#).await; assert!(result1.is_ok()); assert!(result2.is_ok()); @@ -225,3 +226,22 @@ async fn test_sync_method() { // カウンターは3になっているはず assert_eq!(ctx.counter.load(Ordering::SeqCst), 3); } + +// ============================================================================= +// Test: ToolMeta Immutability +// ============================================================================= + +#[tokio::test] +async fn test_tool_meta_immutability() { + let ctx = SimpleContext { + prefix: "Test".to_string(), + }; + + // 2回取得しても同じメタ情報が得られることを確認 + let (meta1, _) = ctx.greet_definition()(); + let (meta2, _) = ctx.greet_definition()(); + + assert_eq!(meta1.name, meta2.name); + assert_eq!(meta1.description, meta2.description); + assert_eq!(meta1.input_schema, meta2.input_schema); +} diff --git a/llm-worker/tests/validation_test.rs b/llm-worker/tests/validation_test.rs index 9ba3017..71e08e7 100644 --- a/llm-worker/tests/validation_test.rs +++ b/llm-worker/tests/validation_test.rs @@ -1,4 +1,3 @@ -use llm_worker::llm_client::LlmClient; use llm_worker::llm_client::providers::openai::OpenAIClient; use llm_worker::{Worker, WorkerError}; diff --git a/llm-worker/tests/worker_fixtures.rs b/llm-worker/tests/worker_fixtures.rs index edb26b5..ce777fa 100644 --- a/llm-worker/tests/worker_fixtures.rs +++ b/llm-worker/tests/worker_fixtures.rs @@ -12,7 +12,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use async_trait::async_trait; use common::MockLlmClient; use llm_worker::Worker; -use llm_worker::tool::{Tool, ToolError}; +use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta}; /// フィクスチャディレクトリのパス fn fixtures_dir() -> std::path::PathBuf { @@ -35,31 +35,29 @@ impl MockWeatherTool { fn get_call_count(&self) -> usize { self.call_count.load(Ordering::SeqCst) } + + fn definition(&self) -> ToolDefinition { + let tool = self.clone(); + Arc::new(move || { + let meta = ToolMeta::new("get_weather") + .description("Get the current weather for a city") + .input_schema(serde_json::json!({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name" + } + }, + "required": ["city"] + })); + (meta, Arc::new(tool.clone()) as Arc) + }) + } } #[async_trait] impl Tool for MockWeatherTool { - fn name(&self) -> &str { - "get_weather" - } - - fn description(&self) -> &str { - "Get the current weather for a city" - } - - fn input_schema(&self) -> serde_json::Value { - serde_json::json!({ - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "The city name" - } - }, - "required": ["city"] - }) - } - async fn execute(&self, input_json: &str) -> Result { self.call_count.fetch_add(1, Ordering::SeqCst); @@ -158,7 +156,7 @@ async fn test_worker_tool_call() { // ツールを登録 let weather_tool = MockWeatherTool::new(); let tool_for_check = weather_tool.clone(); - worker.register_tool(weather_tool); + worker.register_tool(weather_tool.definition()).unwrap(); // メッセージを送信 let _result = worker.run("What's the weather in Tokyo?").await; -- 2.43.0 From a2f53d787933872891433d4a4e41972dcf0801db Mon Sep 17 00:00:00 2001 From: Hare Date: Sat, 10 Jan 2026 22:45:01 +0900 Subject: [PATCH 3/4] update: Implement cancellation notification using mpsc::channel --- docs/spec/cancellation.md | 42 +--- llm-worker/examples/worker_cancel_demo.rs | 12 +- llm-worker/src/worker.rs | 285 +++++++++++++--------- llm-worker/tests/tool_macro_test.rs | 1 - 4 files changed, 192 insertions(+), 148 deletions(-) diff --git a/docs/spec/cancellation.md b/docs/spec/cancellation.md index 54f5dc7..611f612 100644 --- a/docs/spec/cancellation.md +++ b/docs/spec/cancellation.md @@ -4,7 +4,7 @@ Workerの非同期キャンセル機構についての設計ドキュメント ## 概要 -`tokio_util::sync::CancellationToken`を用いて、別タスクからWorkerの実行を安全にキャンセルできる。 +`tokio::sync::mpsc`の通知チャネルを用いて、別タスクからWorkerの実行を安全にキャンセルできる。 ```rust let worker = Arc::new(Mutex::new(Worker::new(client))); @@ -19,15 +19,6 @@ let handle = tokio::spawn(async move { worker.lock().await.cancel(); ``` -## キャンセルポイント - -キャンセルは以下のタイミングでチェックされる: - -1. **ターンループ先頭** — `is_cancelled()`で即座にチェック -2. **ストリーム開始前** — `client.stream()`呼び出し時 -3. **ストリーム受信中** — `tokio::select!`で各イベント受信と並行監視 -4. **ツール実行中** — `join_all()`と並行監視 - ## キャンセル時の処理フロー ``` @@ -42,11 +33,10 @@ Err(WorkerError::Cancelled) // エラー返却 ## API -| メソッド | 説明 | -| ---------------------- | --------------------------------------------------------- | -| `cancel()` | キャンセルをトリガー | -| `is_cancelled()` | キャンセル状態を確認 | -| `cancellation_token()` | トークンへの参照を取得(`clone()`してタスク間で共有可能) | +| メソッド | 説明 | +| ----------------- | ------------------------------ | +| `cancel()` | キャンセルをトリガー | +| `cancel_sender()` | キャンセル通知用のSenderを取得 | ## on_abort フック @@ -69,22 +59,12 @@ async fn on_abort(&self, reason: &str) -> Result<(), HookError> { ## 既知の問題 -### 1. キャンセルトークンの再利用不可 +### on_abort の発火基準 -`CancellationToken`は一度キャンセルされると永続的にキャンセル状態になる。 -同じWorkerインスタンスで再度`run()`を呼ぶと即座に`Cancelled`エラーになる。 +`on_abort` は **interrupt(中断)** された場合に必ず発火する。 -**対応案:** +interrupt の例: -- `run()`開始時に新しいトークンを生成する -- `reset_cancellation()`メソッドを提供する - -### 2. Sync バウンドの追加(破壊的変更) - -`tokio::select!`使用のため、Handler/Scope型に`Sync`バウンドを追加した。 -既存のユーザーコードで`Sync`未実装の型を使用している場合、コンパイルエラーになる。 - -### 3. エラー時のon_abort呼び出し - -現在、`on_abort`はキャンセルとフックAbort時のみ呼ばれる。 -ストリームエラー等のその他エラー時には呼ばれないため、一貫性に欠ける可能性がある。 +- `WorkerError::Cancelled`(キャンセル) +- `WorkerError::Aborted`(フックによるAbort) +- ストリーム/ツール/クライアント/Hook の各種エラーで処理が中断された場合 diff --git a/llm-worker/examples/worker_cancel_demo.rs b/llm-worker/examples/worker_cancel_demo.rs index b4b7114..9d6e6d2 100644 --- a/llm-worker/examples/worker_cancel_demo.rs +++ b/llm-worker/examples/worker_cancel_demo.rs @@ -30,10 +30,10 @@ async fn main() -> Result<(), Box> { println!("🚀 Starting Worker..."); println!("💡 Will cancel after 2 seconds\n"); - // キャンセルトークンを先に取得(ロックを保持しない) - let cancel_token = { + // キャンセルSenderを先に取得(ロックを保持しない) + let cancel_tx = { let w = worker.lock().await; - w.cancellation_token().clone() + w.cancel_sender() }; // タスク1: Workerを実行 @@ -43,10 +43,10 @@ async fn main() -> Result<(), Box> { println!("📡 Sending request to LLM..."); match w.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await { - Ok(WorkerResult::Finished(_)) => { + Ok(WorkerResult::Finished) => { println!("✅ Task completed normally"); } - Ok(WorkerResult::Paused(_)) => { + Ok(WorkerResult::Paused) => { println!("⏸️ Task paused"); } Err(e) => { @@ -59,7 +59,7 @@ async fn main() -> Result<(), Box> { tokio::spawn(async move { tokio::time::sleep(Duration::from_secs(2)).await; println!("\n🛑 Cancelling worker..."); - cancel_token.cancel(); + let _ = cancel_tx.send(()).await; }); // タスク完了を待つ diff --git a/llm-worker/src/worker.rs b/llm-worker/src/worker.rs index d8c32a0..396b7b1 100644 --- a/llm-worker/src/worker.rs +++ b/llm-worker/src/worker.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use std::sync::{Arc, Mutex}; use futures::StreamExt; -use tokio_util::sync::CancellationToken; +use tokio::sync::mpsc; use tracing::{debug, info, trace, warn}; use crate::{ @@ -78,11 +78,11 @@ pub struct WorkerConfig { /// Workerの実行結果(ステータス) #[derive(Debug)] -pub enum WorkerResult<'a> { +pub enum WorkerResult { /// 完了(ユーザー入力待ち状態) - Finished(&'a [Message]), + Finished, /// 一時停止(再開可能) - Paused(&'a [Message]), + Paused, } /// 内部用: ツール実行結果 @@ -182,8 +182,11 @@ pub struct Worker { turn_notifiers: Vec>, /// リクエスト設定(max_tokens, temperature等) request_config: RequestConfig, - /// キャンセレーショントークン(実行中断用) - cancellation_token: CancellationToken, + /// 前回の実行が中断されたかどうか + last_run_interrupted: bool, + /// キャンセル通知用チャネル(実行中断用) + cancel_tx: mpsc::Sender<()>, + cancel_rx: mpsc::Receiver<()>, /// 状態マーカー _state: PhantomData, } @@ -193,6 +196,57 @@ pub struct Worker { // ============================================================================= impl Worker { + fn reset_interruption_state(&mut self) { + self.last_run_interrupted = false; + } + + /// ターンを実行 + /// + /// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。 + /// ツール呼び出しがある場合は自動的にループする。 + pub async fn run( + &mut self, + user_input: impl Into, + ) -> Result { + self.reset_interruption_state(); + // Hook: on_prompt_submit + let mut user_message = Message::user(user_input); + let result = self.run_on_prompt_submit_hooks(&mut user_message).await; + let result = match result { + Ok(value) => value, + Err(err) => return self.finalize_interruption(Err(err)).await, + }; + match result { + OnPromptSubmitResult::Cancel(reason) => { + self.last_run_interrupted = true; + return self.finalize_interruption(Err(WorkerError::Aborted(reason))).await; + } + OnPromptSubmitResult::Continue => {} + } + self.history.push(user_message); + let result = self.run_turn_loop().await; + self.finalize_interruption(result).await + } + + fn drain_cancel_queue(&mut self) { + use tokio::sync::mpsc::error::TryRecvError; + loop { + match self.cancel_rx.try_recv() { + Ok(()) => continue, + Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break, + } + } + } + + fn try_cancelled(&mut self) -> bool { + use tokio::sync::mpsc::error::TryRecvError; + match self.cancel_rx.try_recv() { + Ok(()) => true, + Err(TryRecvError::Empty) => false, + Err(TryRecvError::Disconnected) => true, + } + } + /// イベント購読者を登録する /// /// 登録したSubscriberは、LLMからのストリーミングイベントを @@ -410,6 +464,11 @@ impl Worker { self.request_config.stop_sequences.clear(); } + /// キャンセル通知用Senderを取得する + pub fn cancel_sender(&self) -> mpsc::Sender<()> { + self.cancel_tx.clone() + } + /// リクエスト設定を一括で設定 pub fn set_request_config(&mut self, config: RequestConfig) { self.request_config = config; @@ -437,17 +496,17 @@ impl Worker { /// worker.lock().unwrap().cancel(); /// ``` pub fn cancel(&self) { - self.cancellation_token.cancel(); + let _ = self.cancel_tx.try_send(()); } /// キャンセルされているかチェック - pub fn is_cancelled(&self) -> bool { - self.cancellation_token.is_cancelled() + pub fn is_cancelled(&mut self) -> bool { + self.try_cancelled() } - /// キャンセレーショントークンへの参照を取得 - pub fn cancellation_token(&self) -> &CancellationToken { - &self.cancellation_token + /// 前回の実行が中断されたかどうか + pub fn last_run_interrupted(&self) -> bool { + self.last_run_interrupted } /// 登録されたツールからLLM用ToolDefinitionのリストを生成 @@ -636,6 +695,28 @@ impl Worker { Ok(()) } + async fn finalize_interruption( + &mut self, + result: Result, + ) -> Result { + match result { + Ok(value) => Ok(value), + Err(err) => { + self.last_run_interrupted = true; + let reason = match &err { + WorkerError::Aborted(reason) => reason.clone(), + WorkerError::Cancelled => "Cancelled".to_string(), + _ => err.to_string(), + }; + if let Err(hook_err) = self.run_on_abort_hooks(&reason).await { + self.last_run_interrupted = true; + return Err(hook_err); + } + Err(err) + } + } + } + /// 未実行のツール呼び出しがあるかチェック(Pauseからの復帰用) fn get_pending_tool_calls(&self) -> Option> { let last_msg = self.history.last()?; @@ -687,7 +768,10 @@ impl Worker { let mut skip = false; for hook in &self.hooks.pre_tool_call { - let result = hook.call(&mut context).await?; + let result = hook + .call(&mut context) + .await + .inspect_err(|_| self.last_run_interrupted = true)?; match result { PreToolCallResult::Continue => {} PreToolCallResult::Skip => { @@ -695,9 +779,11 @@ impl Worker { break; } PreToolCallResult::Abort(reason) => { + self.last_run_interrupted = true; return Err(WorkerError::Aborted(reason)); } PreToolCallResult::Pause => { + self.last_run_interrupted = true; return Ok(ToolExecutionResult::Paused); } } @@ -747,10 +833,12 @@ impl Worker { // ツール実行をキャンセル可能にする let mut results = tokio::select! { results = join_all(futures) => results, - _ = self.cancellation_token.cancelled() => { - info!("Tool execution cancelled"); + cancel = self.cancel_rx.recv() => { + if cancel.is_some() { + info!("Tool execution cancelled"); + } self.timeline.abort_current_block(); - self.run_on_abort_hooks("Cancelled").await?; + self.last_run_interrupted = true; return Err(WorkerError::Cancelled); } }; @@ -767,10 +855,14 @@ impl Worker { }; for hook in &self.hooks.post_tool_call { - let result = hook.call(&mut context).await?; + let result = hook + .call(&mut context) + .await + .inspect_err(|_| self.last_run_interrupted = true)?; match result { PostToolCallResult::Continue => {} PostToolCallResult::Abort(reason) => { + self.last_run_interrupted = true; return Err(WorkerError::Aborted(reason)); } } @@ -784,7 +876,9 @@ impl Worker { } /// 内部で使用するターン実行ロジック - async fn run_turn_loop(&mut self) -> Result, WorkerError> { + async fn run_turn_loop(&mut self) -> Result { + self.reset_interruption_state(); + self.drain_cancel_queue(); let tool_definitions = self.build_tool_definitions(); info!( @@ -796,24 +890,31 @@ impl Worker { // Resume check: Pending tool calls if let Some(tool_calls) = self.get_pending_tool_calls() { info!("Resuming pending tool calls"); - match self.execute_tools(tool_calls).await? { - ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)), - ToolExecutionResult::Completed(results) => { + match self.execute_tools(tool_calls).await { + Ok(ToolExecutionResult::Paused) => { + self.last_run_interrupted = true; + return Ok(WorkerResult::Paused); + } + Ok(ToolExecutionResult::Completed(results)) => { for result in results { self.history .push(Message::tool_result(&result.tool_use_id, &result.content)); } // Continue to loop } + Err(err) => { + self.last_run_interrupted = true; + return Err(err); + } } } loop { // キャンセルチェック - if self.cancellation_token.is_cancelled() { + if self.try_cancelled() { info!("Execution cancelled"); self.timeline.abort_current_block(); - self.run_on_abort_hooks("Cancelled").await?; + self.last_run_interrupted = true; return Err(WorkerError::Cancelled); } @@ -825,14 +926,17 @@ impl Worker { } // Hook: pre_llm_request - let (control, request_context) = self.run_pre_llm_request_hooks().await?; + let (control, request_context) = self + .run_pre_llm_request_hooks() + .await + .inspect_err(|_| self.last_run_interrupted = true)?; match control { PreLlmRequestResult::Cancel(reason) => { info!(reason = %reason, "Aborted by hook"); for notifier in &self.turn_notifiers { notifier.on_turn_end(current_turn); } - self.run_on_abort_hooks(&reason).await?; + self.last_run_interrupted = true; return Err(WorkerError::Aborted(reason)); } PreLlmRequestResult::Continue => {} @@ -853,11 +957,14 @@ impl Worker { // ストリームを取得(キャンセル可能) let mut stream = tokio::select! { - stream_result = self.client.stream(request) => stream_result?, - _ = self.cancellation_token.cancelled() => { - info!("Cancelled before stream started"); + stream_result = self.client.stream(request) => stream_result + .inspect_err(|_| self.last_run_interrupted = true)?, + cancel = self.cancel_rx.recv() => { + if cancel.is_some() { + info!("Cancelled before stream started"); + } self.timeline.abort_current_block(); - self.run_on_abort_hooks("Cancelled").await?; + self.last_run_interrupted = true; return Err(WorkerError::Cancelled); } }; @@ -877,7 +984,8 @@ impl Worker { warn!(error = %e, "Stream error"); } } - let event = result?; + let event = result + .inspect_err(|_| self.last_run_interrupted = true)?; let timeline_event: crate::timeline::event::Event = event.into(); self.timeline.dispatch(&timeline_event); } @@ -885,10 +993,12 @@ impl Worker { } } // キャンセル待機 - _ = self.cancellation_token.cancelled() => { - info!("Stream cancelled"); + cancel = self.cancel_rx.recv() => { + if cancel.is_some() { + info!("Stream cancelled"); + } self.timeline.abort_current_block(); - self.run_on_abort_hooks("Cancelled").await?; + self.last_run_interrupted = true; return Err(WorkerError::Cancelled); } } @@ -913,30 +1023,42 @@ impl Worker { if tool_calls.is_empty() { // ツール呼び出しなし → ターン終了判定 - let turn_result = self.run_on_turn_end_hooks().await?; + let turn_result = self + .run_on_turn_end_hooks() + .await + .inspect_err(|_| self.last_run_interrupted = true)?; match turn_result { OnTurnEndResult::Finish => { - return Ok(WorkerResult::Finished(&self.history)); + self.last_run_interrupted = false; + return Ok(WorkerResult::Finished); } OnTurnEndResult::ContinueWithMessages(additional) => { self.history.extend(additional); continue; } OnTurnEndResult::Paused => { - return Ok(WorkerResult::Paused(&self.history)); + self.last_run_interrupted = true; + return Ok(WorkerResult::Paused); } } } // ツール実行 - match self.execute_tools(tool_calls).await? { - ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)), - ToolExecutionResult::Completed(results) => { + match self.execute_tools(tool_calls).await { + Ok(ToolExecutionResult::Paused) => { + self.last_run_interrupted = true; + return Ok(WorkerResult::Paused); + } + Ok(ToolExecutionResult::Completed(results)) => { for result in results { self.history .push(Message::tool_result(&result.tool_use_id, &result.content)); } } + Err(err) => { + self.last_run_interrupted = true; + return Err(err); + } } } } @@ -944,8 +1066,10 @@ impl Worker { /// 実行を再開(Pause状態からの復帰) /// /// 新しいユーザーメッセージを履歴に追加せず、現在の状態からターン処理を再開する。 - pub async fn resume(&mut self) -> Result, WorkerError> { - self.run_turn_loop().await + pub async fn resume(&mut self) -> Result { + self.reset_interruption_state(); + let result = self.run_turn_loop().await; + self.finalize_interruption(result).await } } @@ -959,6 +1083,7 @@ impl Worker { let text_block_collector = TextBlockCollector::new(); let tool_call_collector = ToolCallCollector::new(); let mut timeline = Timeline::new(); + let (cancel_tx, cancel_rx) = mpsc::channel(1); // コレクターをTimelineに登録 timeline.on_text_block(text_block_collector.clone()); @@ -977,7 +1102,9 @@ impl Worker { turn_count: 0, turn_notifiers: Vec::new(), request_config: RequestConfig::default(), - cancellation_token: CancellationToken::new(), + last_run_interrupted: false, + cancel_tx, + cancel_rx, _state: PhantomData, } } @@ -1146,46 +1273,13 @@ impl Worker { turn_count: self.turn_count, turn_notifiers: self.turn_notifiers, request_config: self.request_config, - cancellation_token: self.cancellation_token, + last_run_interrupted: self.last_run_interrupted, + cancel_tx: self.cancel_tx, + cancel_rx: self.cancel_rx, _state: PhantomData, } } - /// ターンを実行(Mutable状態) - /// - /// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。 - /// ツール呼び出しがある場合は自動的にループする。 - /// - /// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は - /// `lock()` を呼んでからLocked状態で `run` を使用すること。 - pub async fn run( - &mut self, - user_input: impl Into, - ) -> Result, WorkerError> { - // Hook: on_prompt_submit - let mut user_message = Message::user(user_input); - let result = self.run_on_prompt_submit_hooks(&mut user_message).await?; - match result { - OnPromptSubmitResult::Cancel(reason) => { - self.run_on_abort_hooks(&reason).await?; - return Err(WorkerError::Aborted(reason)); - } - OnPromptSubmitResult::Continue => {} - } - self.history.push(user_message); - self.run_turn_loop().await - } - - /// 複数メッセージでターンを実行(Mutable状態) - /// - /// 指定されたメッセージを履歴に追加してから実行する。 - pub async fn run_with_messages( - &mut self, - messages: Vec, - ) -> Result, WorkerError> { - self.history.extend(messages); - self.run_turn_loop().await - } } // ============================================================================= @@ -1193,37 +1287,6 @@ impl Worker { // ============================================================================= impl Worker { - /// ターンを実行(Locked状態) - /// - /// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。 - /// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。 - pub async fn run( - &mut self, - user_input: impl Into, - ) -> Result, WorkerError> { - // Hook: on_prompt_submit - let mut user_message = Message::user(user_input); - let result = self.run_on_prompt_submit_hooks(&mut user_message).await?; - match result { - OnPromptSubmitResult::Cancel(reason) => { - self.run_on_abort_hooks(&reason).await?; - return Err(WorkerError::Aborted(reason)); - } - OnPromptSubmitResult::Continue => {} - } - self.history.push(user_message); - self.run_turn_loop().await - } - - /// 複数メッセージでターンを実行(Locked状態) - pub async fn run_with_messages( - &mut self, - messages: Vec, - ) -> Result, WorkerError> { - self.history.extend(messages); - self.run_turn_loop().await - } - /// ロック時点のプレフィックス長を取得 pub fn locked_prefix_len(&self) -> usize { self.locked_prefix_len @@ -1247,7 +1310,9 @@ impl Worker { turn_count: self.turn_count, turn_notifiers: self.turn_notifiers, request_config: self.request_config, - cancellation_token: self.cancellation_token, + last_run_interrupted: self.last_run_interrupted, + cancel_tx: self.cancel_tx, + cancel_rx: self.cancel_rx, _state: PhantomData, } } diff --git a/llm-worker/tests/tool_macro_test.rs b/llm-worker/tests/tool_macro_test.rs index 7b04cfe..4676852 100644 --- a/llm-worker/tests/tool_macro_test.rs +++ b/llm-worker/tests/tool_macro_test.rs @@ -9,7 +9,6 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use schemars; use serde; -use llm_worker::tool::{Tool, ToolMeta}; use llm_worker_macros::tool_registry; // ============================================================================= -- 2.43.0 From c281248bf8ab5db1db6e8f95f277142fe7449b92 Mon Sep 17 00:00:00 2001 From: Hare Date: Sat, 10 Jan 2026 22:46:08 +0900 Subject: [PATCH 4/4] update: Locked -> CacheLocked --- README.md | 2 +- docs/spec/cache_lock.md | 6 +++--- llm-worker/src/llm_client/mod.rs | 2 +- llm-worker/src/state.rs | 12 ++++++------ llm-worker/src/worker.rs | 14 +++++++------- llm-worker/tests/worker_state_test.rs | 12 ++++++------ 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index e6b13db..2929f5f 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Rusty, Efficient, and Agentic LLM Client Library - Tool System: Define tools as async functions. The Worker automatically parses LLM tool calls, executes them in parallel, and feeds results back. - Hook System: Intercept execution flow with `before_tool_call`, `after_tool_call`, and `on_turn_end` hooks for validation, logging, or self-correction. - Event-Driven Streaming: Subscribe to real-time events (text deltas, tool calls, usage) for responsive UIs. -- Cache-Aware State Management: Type-state pattern (`Mutable` → `Locked`) ensures KV cache efficiency by protecting the conversation prefix. +- Cache-Aware State Management: Type-state pattern (`Mutable` → `CacheLocked`) ensures KV cache efficiency by protecting the conversation prefix. ## Quick Start diff --git a/docs/spec/cache_lock.md b/docs/spec/cache_lock.md index f2e8a00..3454f6b 100644 --- a/docs/spec/cache_lock.md +++ b/docs/spec/cache_lock.md @@ -27,7 +27,7 @@ RustのType-stateパターンを利用し、Workerの状態によって利用可 * 自由な編集が可能な状態。 * システムプロンプトの設定・変更が可能。 * メッセージ履歴の初期構築(ロード、編集)が可能。 -* **`Locked` (キャッシュ保護状態)** +* **`CacheLocked` (キャッシュ保護状態)** * キャッシュの有効活用を目的とした、前方不変状態。 * **システムプロンプトの変更不可**。 * **既存メッセージ履歴の変更不可**(追記のみ許可)。 @@ -47,7 +47,7 @@ worker.history_mut().push(initial_message); // 3. ロックしてLocked状態へ遷移 // これにより、ここまでのコンテキストが "Fixed Prefix" として扱われる -let mut locked_worker: Worker = worker.lock(); +let mut locked_worker: Worker = worker.lock(); // 4. 利用 (Locked状態) // 実行は可能。新しいメッセージは履歴の末尾に追記される。 @@ -65,4 +65,4 @@ locked_worker.run(new_user_input).await?; * **状態パラメータの導入**: `Worker` の導入。 * **コンテキスト所有権の委譲**: `run` メソッドの引数でコンテキストを受け取るのではなく、`Worker` 内部に `history: Vec` を保持し管理する形へ移行する。 -* **APIの分離**: `Mutable` 特有のメソッド(setter等)と、`Locked` でも使えるメソッド(実行、参照等)をトレイト境界で分離する。 +* **APIの分離**: `Mutable` 特有のメソッド(setter等)と、`CacheLocked` でも使えるメソッド(実行、参照等)をトレイト境界で分離する。 diff --git a/llm-worker/src/llm_client/mod.rs b/llm-worker/src/llm_client/mod.rs index 7dfc6c7..1a74527 100644 --- a/llm-worker/src/llm_client/mod.rs +++ b/llm-worker/src/llm_client/mod.rs @@ -1,6 +1,6 @@ //! LLMクライアント層 //! -//! 各LLMプロバイダと通信し、統一された[`Event`](crate::llm_client::event::Event) +//! 各LLMプロバイダと通信し、統一された[`Event`] //! ストリームを出力します。 //! //! # サポートするプロバイダ diff --git a/llm-worker/src/state.rs b/llm-worker/src/state.rs index 85dbb19..ca04507 100644 --- a/llm-worker/src/state.rs +++ b/llm-worker/src/state.rs @@ -1,7 +1,7 @@ //! Worker状態 //! //! Type-stateパターンによるキャッシュ保護のための状態マーカー型。 -//! Workerは`Mutable` → `Locked`の状態遷移を持ちます。 +//! Workerは`Mutable` → `CacheLocked`の状態遷移を持ちます。 /// Worker状態を表すマーカートレイト /// @@ -19,7 +19,7 @@ mod private { /// - メッセージ履歴の編集(追加、削除、クリア) /// - ツール・Hookの登録 /// -/// `Worker::lock()`により[`Locked`]状態へ遷移できます。 +/// `Worker::lock()`により[`CacheLocked`]状態へ遷移できます。 /// /// # Examples /// @@ -42,7 +42,7 @@ pub struct Mutable; impl private::Sealed for Mutable {} impl WorkerState for Mutable {} -/// ロック状態(キャッシュ保護) +/// キャッシュロック状態(キャッシュ保護) /// /// この状態では以下の制限があります: /// - システムプロンプトの変更不可 @@ -54,7 +54,7 @@ impl WorkerState for Mutable {} /// `Worker::unlock()`により[`Mutable`]状態へ戻せますが、 /// キャッシュ保護が解除されることに注意してください。 #[derive(Debug, Clone, Copy, Default)] -pub struct Locked; +pub struct CacheLocked; -impl private::Sealed for Locked {} -impl WorkerState for Locked {} +impl private::Sealed for CacheLocked {} +impl WorkerState for CacheLocked {} diff --git a/llm-worker/src/worker.rs b/llm-worker/src/worker.rs index 396b7b1..18f4c34 100644 --- a/llm-worker/src/worker.rs +++ b/llm-worker/src/worker.rs @@ -17,7 +17,7 @@ use crate::{ ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ToolDefinition as LlmToolDefinition, }, - state::{Locked, Mutable, WorkerState}, + state::{CacheLocked, Mutable, WorkerState}, subscriber::{ ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter, ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, WorkerSubscriber, @@ -131,7 +131,7 @@ impl TurnNotifier for SubscriberTurnNotifier { /// # 状態遷移(Type-state) /// /// - [`Mutable`]: 初期状態。システムプロンプトや履歴を自由に編集可能。 -/// - [`Locked`]: キャッシュ保護状態。`lock()`で遷移。前方コンテキストは不変。 +/// - [`CacheLocked`]: キャッシュ保護状態。`lock()`で遷移。前方コンテキストは不変。 /// /// # Examples /// @@ -174,7 +174,7 @@ pub struct Worker { system_prompt: Option, /// メッセージ履歴(Workerが所有) history: Vec, - /// ロック時点での履歴長(Locked状態でのみ意味を持つ) + /// ロック時点での履歴長(CacheLocked状態でのみ意味を持つ) locked_prefix_len: usize, /// ターンカウント turn_count: usize, @@ -1254,11 +1254,11 @@ impl Worker { self } - /// ロックしてLocked状態へ遷移 + /// ロックしてCacheLocked状態へ遷移 /// /// この操作により、現在のシステムプロンプトと履歴が「確定済みプレフィックス」として /// 固定される。以降は履歴への追記のみが可能となり、キャッシュヒットが保証される。 - pub fn lock(self) -> Worker { + pub fn lock(self) -> Worker { let locked_prefix_len = self.history.len(); Worker { client: self.client, @@ -1283,10 +1283,10 @@ impl Worker { } // ============================================================================= -// Locked状態専用の実装 +// CacheLocked状態専用の実装 // ============================================================================= -impl Worker { +impl Worker { /// ロック時点のプレフィックス長を取得 pub fn locked_prefix_len(&self) -> usize { self.locked_prefix_len diff --git a/llm-worker/tests/worker_state_test.rs b/llm-worker/tests/worker_state_test.rs index 035e3be..e98135c 100644 --- a/llm-worker/tests/worker_state_test.rs +++ b/llm-worker/tests/worker_state_test.rs @@ -1,6 +1,6 @@ //! Worker状態管理のテスト //! -//! Type-stateパターン(Mutable/Locked)による状態遷移と +//! Type-stateパターン(Mutable/CacheLocked)による状態遷移と //! ターン間の状態保持をテストする。 mod common; @@ -95,7 +95,7 @@ fn test_mutable_extend_history() { // 状態遷移テスト // ============================================================================= -/// lock()でMutable -> Locked状態に遷移することを確認 +/// lock()でMutable -> CacheLocked状態に遷移することを確認 #[test] fn test_lock_transition() { let client = MockLlmClient::new(vec![]); @@ -108,13 +108,13 @@ fn test_lock_transition() { // ロック let locked_worker = worker.lock(); - // Locked状態でも履歴とシステムプロンプトにアクセス可能 + // CacheLocked状態でも履歴とシステムプロンプトにアクセス可能 assert_eq!(locked_worker.get_system_prompt(), Some("System")); assert_eq!(locked_worker.history().len(), 2); assert_eq!(locked_worker.locked_prefix_len(), 2); } -/// unlock()でLocked -> Mutable状態に遷移することを確認 +/// unlock()でCacheLocked -> Mutable状態に遷移することを確認 #[test] fn test_unlock_transition() { let client = MockLlmClient::new(vec![]); @@ -172,7 +172,7 @@ async fn test_mutable_run_updates_history() { )); } -/// Locked状態で複数ターンを実行し、履歴が正しく累積することを確認 +/// CacheLocked状態で複数ターンを実行し、履歴が正しく累積することを確認 #[tokio::test] async fn test_locked_multi_turn_history_accumulation() { // 2回のリクエストに対応するレスポンスを準備 @@ -340,7 +340,7 @@ async fn test_unlock_edit_relock() { // システムプロンプト保持のテスト // ============================================================================= -/// Locked状態でもシステムプロンプトが保持されることを確認 +/// CacheLocked状態でもシステムプロンプトが保持されることを確認 #[test] fn test_system_prompt_preserved_in_locked_state() { let client = MockLlmClient::new(vec![]); -- 2.43.0