feat: Implement HookEventKind
This commit is contained in:
parent
33f1c218f2
commit
5691b09fc8
239
Cargo.lock
generated
239
Cargo.lock
generated
|
|
@ -104,9 +104,9 @@ checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3"
|
|||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.2.51"
|
||||
version = "1.2.52"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a0aeaff4ff1a90589618835a598e545176939b97874f7abc7851caa0618f203"
|
||||
checksum = "cd4932aefd12402b36c60956a4fe0035421f544799057659ff86f923657aada3"
|
||||
dependencies = [
|
||||
"find-msvc-tools",
|
||||
"shlex",
|
||||
|
|
@ -203,6 +203,12 @@ version = "1.0.20"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555"
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.14"
|
||||
|
|
@ -232,9 +238,15 @@ checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
|||
|
||||
[[package]]
|
||||
name = "find-msvc-tools"
|
||||
version = "0.1.6"
|
||||
version = "0.1.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "645cbb3a84e60b7531617d5ae4e57f7e27308f6445f5abf653209ea76dec8dff"
|
||||
checksum = "f449e6c6c08c865631d4890cfacf252b3d396c9bcc83adb6623cdb02a8336c41"
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
||||
|
||||
[[package]]
|
||||
name = "foreign-types"
|
||||
|
|
@ -349,6 +361,17 @@ dependencies = [
|
|||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.3.4"
|
||||
|
|
@ -361,6 +384,31 @@ dependencies = [
|
|||
"wasip2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.4.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54"
|
||||
dependencies = [
|
||||
"atomic-waker",
|
||||
"bytes",
|
||||
"fnv",
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"http",
|
||||
"indexmap",
|
||||
"slab",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.16.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.5.0"
|
||||
|
|
@ -416,6 +464,7 @@ dependencies = [
|
|||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"h2",
|
||||
"http",
|
||||
"http-body",
|
||||
"httparse",
|
||||
|
|
@ -427,6 +476,22 @@ dependencies = [
|
|||
"want",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-rustls"
|
||||
version = "0.27.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58"
|
||||
dependencies = [
|
||||
"http",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-tls"
|
||||
version = "0.6.0"
|
||||
|
|
@ -569,6 +634,16 @@ dependencies = [
|
|||
"icu_properties",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ipnet"
|
||||
version = "2.11.0"
|
||||
|
|
@ -648,6 +723,7 @@ dependencies = [
|
|||
"tempfile",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
|
@ -895,10 +971,12 @@ dependencies = [
|
|||
"bytes",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2",
|
||||
"http",
|
||||
"http-body",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-rustls",
|
||||
"hyper-tls",
|
||||
"hyper-util",
|
||||
"js-sys",
|
||||
|
|
@ -923,6 +1001,20 @@ dependencies = [
|
|||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
version = "0.17.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"cfg-if",
|
||||
"getrandom 0.2.16",
|
||||
"libc",
|
||||
"untrusted",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "1.1.3"
|
||||
|
|
@ -936,6 +1028,19 @@ dependencies = [
|
|||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.23.36"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pki-types"
|
||||
version = "1.13.2"
|
||||
|
|
@ -945,6 +1050,17 @@ dependencies = [
|
|||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.103.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.22"
|
||||
|
|
@ -1111,6 +1227,12 @@ version = "0.11.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
||||
|
||||
[[package]]
|
||||
name = "subtle"
|
||||
version = "2.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.114"
|
||||
|
|
@ -1149,7 +1271,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c"
|
||||
dependencies = [
|
||||
"fastrand",
|
||||
"getrandom",
|
||||
"getrandom 0.3.4",
|
||||
"once_cell",
|
||||
"rustix",
|
||||
"windows-sys 0.61.2",
|
||||
|
|
@ -1230,6 +1352,16 @@ dependencies = [
|
|||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.26.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61"
|
||||
dependencies = [
|
||||
"rustls",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.18"
|
||||
|
|
@ -1361,6 +1493,12 @@ version = "1.0.22"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
|
||||
|
||||
[[package]]
|
||||
name = "url"
|
||||
version = "2.5.8"
|
||||
|
|
@ -1508,13 +1646,22 @@ version = "0.2.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.52.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
|
||||
dependencies = [
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.60.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb"
|
||||
dependencies = [
|
||||
"windows-targets",
|
||||
"windows-targets 0.53.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1526,6 +1673,22 @@ dependencies = [
|
|||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm 0.52.6",
|
||||
"windows_aarch64_msvc 0.52.6",
|
||||
"windows_i686_gnu 0.52.6",
|
||||
"windows_i686_gnullvm 0.52.6",
|
||||
"windows_i686_msvc 0.52.6",
|
||||
"windows_x86_64_gnu 0.52.6",
|
||||
"windows_x86_64_gnullvm 0.52.6",
|
||||
"windows_x86_64_msvc 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.53.5"
|
||||
|
|
@ -1533,58 +1696,106 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3"
|
||||
dependencies = [
|
||||
"windows-link",
|
||||
"windows_aarch64_gnullvm",
|
||||
"windows_aarch64_msvc",
|
||||
"windows_i686_gnu",
|
||||
"windows_i686_gnullvm",
|
||||
"windows_i686_msvc",
|
||||
"windows_x86_64_gnu",
|
||||
"windows_x86_64_gnullvm",
|
||||
"windows_x86_64_msvc",
|
||||
"windows_aarch64_gnullvm 0.53.1",
|
||||
"windows_aarch64_msvc 0.53.1",
|
||||
"windows_i686_gnu 0.53.1",
|
||||
"windows_i686_gnullvm 0.53.1",
|
||||
"windows_i686_msvc 0.53.1",
|
||||
"windows_x86_64_gnu 0.53.1",
|
||||
"windows_x86_64_gnullvm 0.53.1",
|
||||
"windows_x86_64_msvc 0.53.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnullvm"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnullvm"
|
||||
version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.53.1"
|
||||
|
|
|
|||
|
|
@ -33,4 +33,3 @@ let history = worker.run("What is 2+2?").await?;
|
|||
## License
|
||||
|
||||
MIT
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ LLMを用いたワーカーを作成する小型のSDK・ライブラリ。
|
|||
|
||||
module構成概念図
|
||||
|
||||
```
|
||||
```plaintext
|
||||
worker
|
||||
├── context
|
||||
├── llm_client
|
||||
|
|
|
|||
90
docs/spec/cancellation.md
Normal file
90
docs/spec/cancellation.md
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
# 非同期キャンセル設計
|
||||
|
||||
Workerの非同期キャンセル機構についての設計ドキュメント。
|
||||
|
||||
## 概要
|
||||
|
||||
`tokio_util::sync::CancellationToken`を用いて、別タスクからWorkerの実行を安全にキャンセルできる。
|
||||
|
||||
```rust
|
||||
let worker = Arc::new(Mutex::new(Worker::new(client)));
|
||||
|
||||
// 実行タスク
|
||||
let w = worker.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
w.lock().await.run("prompt").await
|
||||
});
|
||||
|
||||
// キャンセル
|
||||
worker.lock().await.cancel();
|
||||
```
|
||||
|
||||
## キャンセルポイント
|
||||
|
||||
キャンセルは以下のタイミングでチェックされる:
|
||||
|
||||
1. **ターンループ先頭** — `is_cancelled()`で即座にチェック
|
||||
2. **ストリーム開始前** — `client.stream()`呼び出し時
|
||||
3. **ストリーム受信中** — `tokio::select!`で各イベント受信と並行監視
|
||||
4. **ツール実行中** — `join_all()`と並行監視
|
||||
|
||||
## キャンセル時の処理フロー
|
||||
|
||||
```
|
||||
キャンセル検知
|
||||
↓
|
||||
timeline.abort_current_block() // 進行中ブロックの終端処理
|
||||
↓
|
||||
run_on_abort_hooks("Cancelled") // on_abort フック呼び出し
|
||||
↓
|
||||
Err(WorkerError::Cancelled) // エラー返却
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
| メソッド | 説明 |
|
||||
| ---------------------- | --------------------------------------------------------- |
|
||||
| `cancel()` | キャンセルをトリガー |
|
||||
| `is_cancelled()` | キャンセル状態を確認 |
|
||||
| `cancellation_token()` | トークンへの参照を取得(`clone()`してタスク間で共有可能) |
|
||||
|
||||
## on_abort フック
|
||||
|
||||
`Hook::on_abort(&self, reason: &str)`がキャンセル時に呼ばれる。
|
||||
クリーンアップ処理やログ記録に使用できる。
|
||||
|
||||
```rust
|
||||
async fn on_abort(&self, reason: &str) -> Result<(), HookError> {
|
||||
log::info!("Aborted: {}", reason);
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
呼び出しタイミング:
|
||||
|
||||
- `WorkerError::Cancelled` — reason: `"Cancelled"`
|
||||
- `ControlFlow::Abort(reason)` — reason: フックが指定した理由
|
||||
|
||||
---
|
||||
|
||||
## 既知の問題
|
||||
|
||||
### 1. キャンセルトークンの再利用不可
|
||||
|
||||
`CancellationToken`は一度キャンセルされると永続的にキャンセル状態になる。
|
||||
同じWorkerインスタンスで再度`run()`を呼ぶと即座に`Cancelled`エラーになる。
|
||||
|
||||
**対応案:**
|
||||
|
||||
- `run()`開始時に新しいトークンを生成する
|
||||
- `reset_cancellation()`メソッドを提供する
|
||||
|
||||
### 2. Sync バウンドの追加(破壊的変更)
|
||||
|
||||
`tokio::select!`使用のため、Handler/Scope型に`Sync`バウンドを追加した。
|
||||
既存のユーザーコードで`Sync`未実装の型を使用している場合、コンパイルエラーになる。
|
||||
|
||||
### 3. エラー時のon_abort呼び出し
|
||||
|
||||
現在、`on_abort`はキャンセルとフックAbort時のみ呼ばれる。
|
||||
ストリームエラー等のその他エラー時には呼ばれないため、一貫性に欠ける可能性がある。
|
||||
|
|
@ -3,7 +3,8 @@
|
|||
## 概要
|
||||
|
||||
HookはWorker層でのターン制御に介入するためのメカニズムです。
|
||||
Claude CodeのHooks機能に着想を得ており、メッセージ送信・ツール実行・ターン終了の各ポイントで処理を差し込むことができます。
|
||||
|
||||
メッセージ送信・ツール実行・ターン終了等の各ポイントで処理を差し込むことができます。
|
||||
|
||||
## コンセプト
|
||||
|
||||
|
|
@ -11,76 +12,82 @@ Claude CodeのHooks機能に着想を得ており、メッセージ送信・ツ
|
|||
- **Contextへのアクセス**: メッセージ履歴を読み書き可能
|
||||
- **非破壊的チェーン**: 複数のHookを登録順に実行、後続Hookへの影響を制御
|
||||
|
||||
## Hook一覧
|
||||
|
||||
| Hook | タイミング | 主な用途 | 戻り値 |
|
||||
| ------------------ | -------------------------- | --------------------- | ---------------------- |
|
||||
| `on_message_send` | LLM送信前 | コンテキスト改変/検証 | `OnMessageSendResult` |
|
||||
| `before_tool_call` | ツール実行前 | 実行許可/引数改変 | `BeforeToolCallResult` |
|
||||
| `after_tool_call` | ツール実行後 | 結果加工/マスキング | `AfterToolCallResult` |
|
||||
| `on_turn_end` | ツールなしでターン終了直前 | 検証/リトライ指示 | `OnTurnEndResult` |
|
||||
| `on_abort` | 中断時 | クリーンアップ/通知 | `()` |
|
||||
|
||||
## Hook Trait
|
||||
|
||||
```rust
|
||||
#[async_trait]
|
||||
pub trait WorkerHook: Send + Sync {
|
||||
/// メッセージ送信前
|
||||
/// リクエストに含まれるメッセージリストを改変できる
|
||||
async fn on_message_send(
|
||||
&self,
|
||||
context: &mut Vec<Message>,
|
||||
) -> Result<ControlFlow, HookError> {
|
||||
Ok(ControlFlow::Continue)
|
||||
}
|
||||
|
||||
/// ツール実行前
|
||||
/// 実行をキャンセルしたり、引数を書き換えることができる
|
||||
async fn before_tool_call(
|
||||
&self,
|
||||
tool_call: &mut ToolCall,
|
||||
) -> Result<ControlFlow, HookError> {
|
||||
Ok(ControlFlow::Continue)
|
||||
}
|
||||
|
||||
/// ツール実行後
|
||||
/// 結果を書き換えたり、隠蔽したりできる
|
||||
async fn after_tool_call(
|
||||
&self,
|
||||
tool_result: &mut ToolResult,
|
||||
) -> Result<ControlFlow, HookError> {
|
||||
Ok(ControlFlow::Continue)
|
||||
}
|
||||
|
||||
/// ターン終了時
|
||||
/// 生成されたメッセージを検査し、必要ならリトライを指示できる
|
||||
async fn on_turn_end(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
) -> Result<TurnResult, HookError> {
|
||||
Ok(TurnResult::Finish)
|
||||
}
|
||||
pub trait Hook<E: HookEventKind>: Send + Sync {
|
||||
async fn call(&self, input: &mut E::Input) -> Result<E::Output, HookError>;
|
||||
}
|
||||
```
|
||||
|
||||
## 制御フロー型
|
||||
|
||||
### ControlFlow
|
||||
### HookEventKind / Result
|
||||
|
||||
Hook処理の継続/中断を制御する列挙型。
|
||||
Hookイベントごとに入力/出力型を分離し、意味のない制御フローを排除する。
|
||||
|
||||
```rust
|
||||
pub enum ControlFlow {
|
||||
/// 処理を続行(後続Hookも実行)
|
||||
pub trait HookEventKind {
|
||||
type Input;
|
||||
type Output;
|
||||
}
|
||||
|
||||
pub struct OnMessageSend;
|
||||
pub struct BeforeToolCall;
|
||||
pub struct AfterToolCall;
|
||||
pub struct OnTurnEnd;
|
||||
pub struct OnAbort;
|
||||
|
||||
pub enum OnMessageSendResult {
|
||||
Continue,
|
||||
Cancel(String),
|
||||
}
|
||||
|
||||
pub enum BeforeToolCallResult {
|
||||
Continue,
|
||||
/// 現在の処理をスキップ(ツール実行をスキップ等)
|
||||
Skip,
|
||||
/// 処理全体を中断(エラーとして扱う)
|
||||
Abort(String),
|
||||
Pause,
|
||||
}
|
||||
|
||||
pub enum AfterToolCallResult {
|
||||
Continue,
|
||||
Abort(String),
|
||||
}
|
||||
|
||||
pub enum OnTurnEndResult {
|
||||
Finish,
|
||||
ContinueWithMessages(Vec<Message>),
|
||||
Paused,
|
||||
}
|
||||
```
|
||||
|
||||
### TurnResult
|
||||
### Tool Call Context
|
||||
|
||||
ターン終了時の判定結果を表す列挙型。
|
||||
`before_tool_call` / `after_tool_call` は、ツール実行の文脈を含む入力を受け取る。
|
||||
|
||||
```rust
|
||||
pub enum TurnResult {
|
||||
/// ターンを正常終了
|
||||
Finish,
|
||||
/// メッセージを追加してターン継続(自己修正など)
|
||||
ContinueWithMessages(Vec<Message>),
|
||||
pub struct ToolCallContext {
|
||||
pub call: ToolCall,
|
||||
pub meta: ToolMeta, // 不変メタデータ
|
||||
pub tool: Arc<dyn Tool>, // 状態アクセス用
|
||||
}
|
||||
|
||||
pub struct ToolResultContext {
|
||||
pub result: ToolResult,
|
||||
pub meta: ToolMeta,
|
||||
pub tool: Arc<dyn Tool>,
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -112,6 +119,8 @@ Worker::run() ループ
|
|||
└─▶ on_turn_end ─────────────────────────────┘
|
||||
最終応答のチェック(Lint/Fmt等)
|
||||
エラーがあればContinueWithMessagesでリトライ
|
||||
|
||||
※ 中断時は on_abort が呼ばれる
|
||||
```
|
||||
|
||||
## 各Hookの詳細
|
||||
|
|
@ -121,10 +130,12 @@ Worker::run() ループ
|
|||
**呼び出しタイミング**: LLMへリクエスト送信前(ターンループの冒頭)
|
||||
|
||||
**用途**:
|
||||
|
||||
- コンテキストへのシステムメッセージ注入
|
||||
- メッセージのバリデーション
|
||||
- 機密情報のフィルタリング
|
||||
- リクエスト内容のログ出力
|
||||
- `OnMessageSendResult::Cancel` による送信キャンセル
|
||||
|
||||
**例**: メッセージにタイムスタンプを追加
|
||||
|
||||
|
|
@ -132,14 +143,14 @@ Worker::run() ループ
|
|||
struct TimestampHook;
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerHook for TimestampHook {
|
||||
async fn on_message_send(
|
||||
impl Hook<OnMessageSend> for TimestampHook {
|
||||
async fn call(
|
||||
&self,
|
||||
context: &mut Vec<Message>,
|
||||
) -> Result<ControlFlow, HookError> {
|
||||
) -> Result<OnMessageSendResult, HookError> {
|
||||
let timestamp = chrono::Local::now().to_rfc3339();
|
||||
context.insert(0, Message::user(format!("[{}]", timestamp)));
|
||||
Ok(ControlFlow::Continue)
|
||||
Ok(OnMessageSendResult::Continue)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
|
@ -149,10 +160,15 @@ impl WorkerHook for TimestampHook {
|
|||
**呼び出しタイミング**: 各ツール実行前(並列実行フェーズの前)
|
||||
|
||||
**用途**:
|
||||
|
||||
- 危険なツールのブロック
|
||||
- 引数のサニタイズ
|
||||
- 確認プロンプトの表示(UIとの連携)
|
||||
- 実行ログの記録
|
||||
- `BeforeToolCallResult::Pause` による一時停止
|
||||
|
||||
**入力**:
|
||||
- `ToolCallContext`(`ToolCall` + `ToolMeta` + `Arc<dyn Tool>`)
|
||||
|
||||
**例**: 特定ツールをブロック
|
||||
|
||||
|
|
@ -162,16 +178,16 @@ struct ToolBlocker {
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerHook for ToolBlocker {
|
||||
async fn before_tool_call(
|
||||
impl Hook<BeforeToolCall> for ToolBlocker {
|
||||
async fn call(
|
||||
&self,
|
||||
tool_call: &mut ToolCall,
|
||||
) -> Result<ControlFlow, HookError> {
|
||||
if self.blocked_tools.contains(&tool_call.name) {
|
||||
println!("Blocked tool: {}", tool_call.name);
|
||||
Ok(ControlFlow::Skip)
|
||||
ctx: &mut ToolCallContext,
|
||||
) -> Result<BeforeToolCallResult, HookError> {
|
||||
if self.blocked_tools.contains(&ctx.call.name) {
|
||||
println!("Blocked tool: {}", ctx.call.name);
|
||||
Ok(BeforeToolCallResult::Skip)
|
||||
} else {
|
||||
Ok(ControlFlow::Continue)
|
||||
Ok(BeforeToolCallResult::Continue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -182,10 +198,13 @@ impl WorkerHook for ToolBlocker {
|
|||
**呼び出しタイミング**: 各ツール実行後(並列実行フェーズの後)
|
||||
|
||||
**用途**:
|
||||
|
||||
- 結果の加工・フォーマット
|
||||
- 機密情報のマスキング
|
||||
- 結果のキャッシュ
|
||||
- 実行結果のログ出力
|
||||
**入力**:
|
||||
- `ToolResultContext`(`ToolResult` + `ToolMeta` + `Arc<dyn Tool>`)
|
||||
|
||||
**例**: 結果にプレフィックスを追加
|
||||
|
||||
|
|
@ -193,15 +212,15 @@ impl WorkerHook for ToolBlocker {
|
|||
struct ResultFormatter;
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerHook for ResultFormatter {
|
||||
async fn after_tool_call(
|
||||
impl Hook<AfterToolCall> for ResultFormatter {
|
||||
async fn call(
|
||||
&self,
|
||||
tool_result: &mut ToolResult,
|
||||
) -> Result<ControlFlow, HookError> {
|
||||
if !tool_result.is_error {
|
||||
tool_result.content = format!("[OK] {}", tool_result.content);
|
||||
ctx: &mut ToolResultContext,
|
||||
) -> Result<AfterToolCallResult, HookError> {
|
||||
if !ctx.result.is_error {
|
||||
ctx.result.content = format!("[OK] {}", ctx.result.content);
|
||||
}
|
||||
Ok(ControlFlow::Continue)
|
||||
Ok(AfterToolCallResult::Continue)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
|
@ -211,10 +230,22 @@ impl WorkerHook for ResultFormatter {
|
|||
**呼び出しタイミング**: ツール呼び出しなしでターンが終了する直前
|
||||
|
||||
**用途**:
|
||||
|
||||
- 生成されたコードのLint/Fmt
|
||||
- 出力形式のバリデーション
|
||||
- 自己修正のためのリトライ指示
|
||||
- 最終結果のログ出力
|
||||
- `OnTurnEndResult::Paused` による一時停止
|
||||
|
||||
### on_abort
|
||||
|
||||
**呼び出しタイミング**: キャンセル/エラー/AbortなどでWorkerが中断された時
|
||||
|
||||
**用途**:
|
||||
|
||||
- クリーンアップ処理
|
||||
- 中断理由のログ出力
|
||||
- 外部システムへの通知
|
||||
|
||||
**例**: JSON形式のバリデーション
|
||||
|
||||
|
|
@ -222,11 +253,11 @@ impl WorkerHook for ResultFormatter {
|
|||
struct JsonValidator;
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerHook for JsonValidator {
|
||||
async fn on_turn_end(
|
||||
impl Hook<OnTurnEnd> for JsonValidator {
|
||||
async fn call(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
) -> Result<TurnResult, HookError> {
|
||||
messages: &mut Vec<Message>,
|
||||
) -> Result<OnTurnEndResult, HookError> {
|
||||
// 最後のアシスタントメッセージを取得
|
||||
let last = messages.iter().rev()
|
||||
.find(|m| m.role == Role::Assistant);
|
||||
|
|
@ -236,25 +267,25 @@ impl WorkerHook for JsonValidator {
|
|||
// JSONとしてパースを試みる
|
||||
if serde_json::from_str::<serde_json::Value>(text).is_err() {
|
||||
// 失敗したらリトライ指示
|
||||
return Ok(TurnResult::ContinueWithMessages(vec![
|
||||
return Ok(OnTurnEndResult::ContinueWithMessages(vec![
|
||||
Message::user("Invalid JSON. Please fix and try again.")
|
||||
]));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(TurnResult::Finish)
|
||||
Ok(OnTurnEndResult::Finish)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 複数Hookの実行順序
|
||||
|
||||
Hookは**登録順**に実行されます。
|
||||
Hookは**イベントごとに登録順**に実行されます。
|
||||
|
||||
```rust
|
||||
worker.add_hook(HookA); // 1番目に実行
|
||||
worker.add_hook(HookB); // 2番目に実行
|
||||
worker.add_hook(HookC); // 3番目に実行
|
||||
worker.add_before_tool_call_hook(HookA); // 1番目に実行
|
||||
worker.add_before_tool_call_hook(HookB); // 2番目に実行
|
||||
worker.add_before_tool_call_hook(HookC); // 3番目に実行
|
||||
```
|
||||
|
||||
### 制御フローの伝播
|
||||
|
|
@ -262,6 +293,7 @@ worker.add_hook(HookC); // 3番目に実行
|
|||
- `Continue`: 後続Hookも実行
|
||||
- `Skip`: 現在の処理をスキップし、後続Hookは実行しない
|
||||
- `Abort`: 即座にエラーを返し、処理全体を中断
|
||||
- `Pause`: Workerを一時停止(再開は`resume`)
|
||||
|
||||
```
|
||||
Hook A: Continue → Hook B: Skip → (Hook Cは実行されない)
|
||||
|
|
@ -271,40 +303,27 @@ Hook A: Continue → Hook B: Skip → (Hook Cは実行されない)
|
|||
Hook A: Continue → Hook B: Abort("reason")
|
||||
↓
|
||||
WorkerError::Aborted
|
||||
|
||||
Hook A: Continue → Hook B: Pause
|
||||
↓
|
||||
WorkerResult::Paused
|
||||
```
|
||||
|
||||
## 設計上のポイント
|
||||
|
||||
### 1. デフォルト実装
|
||||
### 1. イベントごとの実装
|
||||
|
||||
全メソッドにデフォルト実装があるため、必要なメソッドだけオーバーライドすれば良い。
|
||||
|
||||
```rust
|
||||
struct SimpleLogger;
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerHook for SimpleLogger {
|
||||
// on_message_send だけ実装
|
||||
async fn on_message_send(
|
||||
&self,
|
||||
context: &mut Vec<Message>,
|
||||
) -> Result<ControlFlow, HookError> {
|
||||
println!("Sending {} messages", context.len());
|
||||
Ok(ControlFlow::Continue)
|
||||
}
|
||||
// 他のメソッドはデフォルト(Continue/Finish)
|
||||
}
|
||||
```
|
||||
必要なイベントのみ `Hook<Event>` を実装する。
|
||||
|
||||
### 2. 可変参照による改変
|
||||
|
||||
`&mut`で引数を受け取るため、直接改変が可能。
|
||||
|
||||
```rust
|
||||
async fn before_tool_call(&self, tool_call: &mut ToolCall) -> ... {
|
||||
async fn call(&self, ctx: &mut ToolCallContext) -> ... {
|
||||
// 引数を直接書き換え
|
||||
tool_call.input["sanitized"] = json!(true);
|
||||
Ok(ControlFlow::Continue)
|
||||
ctx.call.input["sanitized"] = json!(true);
|
||||
Ok(BeforeToolCallResult::Continue)
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -316,7 +335,7 @@ async fn before_tool_call(&self, tool_call: &mut ToolCall) -> ... {
|
|||
|
||||
### 4. Send + Sync 要件
|
||||
|
||||
`WorkerHook`は`Send + Sync`を要求するため、スレッドセーフな実装が必要。
|
||||
`Hook`は`Send + Sync`を要求するため、スレッドセーフな実装が必要。
|
||||
状態を持つ場合は`Arc<Mutex<T>>`や`AtomicUsize`などを使用する。
|
||||
|
||||
```rust
|
||||
|
|
@ -325,10 +344,10 @@ struct CountingHook {
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerHook for CountingHook {
|
||||
async fn before_tool_call(&self, _: &mut ToolCall) -> Result<ControlFlow, HookError> {
|
||||
impl Hook<BeforeToolCall> for CountingHook {
|
||||
async fn call(&self, _: &mut ToolCallContext) -> Result<BeforeToolCallResult, HookError> {
|
||||
self.count.fetch_add(1, Ordering::SeqCst);
|
||||
Ok(ControlFlow::Continue)
|
||||
Ok(BeforeToolCallResult::Continue)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
|
@ -336,7 +355,7 @@ impl WorkerHook for CountingHook {
|
|||
## 典型的なユースケース
|
||||
|
||||
| ユースケース | 使用Hook | 処理内容 |
|
||||
|-------------|----------|----------|
|
||||
| ------------------ | ------------------------ | -------------------------- |
|
||||
| ツール許可制御 | `before_tool_call` | 危険なツールをSkip |
|
||||
| 実行ログ | `before/after_tool_call` | 呼び出しと結果を記録 |
|
||||
| 出力バリデーション | `on_turn_end` | 形式チェック、リトライ指示 |
|
||||
|
|
@ -350,11 +369,14 @@ impl WorkerHook for CountingHook {
|
|||
|
||||
現在のHooks実装は基本的なユースケースをカバーしているが、以下の点について将来的に厳密な仕様を定義する必要がある:
|
||||
|
||||
- **エラーハンドリングの明確化**: `HookError`発生時のリカバリー戦略、部分的な失敗の扱い
|
||||
- **エラーハンドリングの明確化**:
|
||||
`HookError`発生時のリカバリー戦略、部分的な失敗の扱い
|
||||
- **Hook間の依存関係**: 複数Hookの実行順序が結果に影響する場合のセマンティクス
|
||||
- **非同期キャンセル**: Hook実行中のキャンセル(タイムアウト等)の振る舞い
|
||||
- **状態の一貫性**: `on_message_send`で改変されたコンテキストが後続処理で期待通りに反映される保証
|
||||
- **リトライ制限**: `on_turn_end`での`ContinueWithMessages`による無限ループ防止策
|
||||
- **状態の一貫性**:
|
||||
`on_message_send`で改変されたコンテキストが後続処理で期待通りに反映される保証
|
||||
- **リトライ制限**:
|
||||
`on_turn_end`での`ContinueWithMessages`による無限ループ防止策
|
||||
- **Hook優先度**: 登録順以外の優先度指定メカニズムの必要性
|
||||
- **条件付きHook**: 特定条件でのみ有効化されるHookパターン
|
||||
- **テスト容易性**: Hookのモック/スタブ作成のためのユーティリティ
|
||||
|
|
|
|||
|
|
@ -178,41 +178,60 @@ Workerは生成されたラッパー構造体を `Box<dyn Tool>` として保持
|
|||
|
||||
```rust
|
||||
#[async_trait]
|
||||
pub trait WorkerHook: Send + Sync {
|
||||
/// メッセージ送信前。
|
||||
/// リクエストに含まれるメッセージリストを改変できる。
|
||||
async fn on_message_send(&self, context: &mut Vec<Message>) -> Result<ControlFlow, Error> {
|
||||
Ok(ControlFlow::Continue)
|
||||
}
|
||||
|
||||
/// ツール実行前。
|
||||
/// 実行をキャンセルしたり、引数を書き換えることができる。
|
||||
async fn before_tool_call(&self, tool_call: &mut ToolCall) -> Result<ControlFlow, Error> {
|
||||
Ok(ControlFlow::Continue)
|
||||
}
|
||||
|
||||
/// ツール実行後。
|
||||
/// 結果を書き換えたり、隠蔽したりできる。
|
||||
async fn after_tool_call(&self, tool_result: &mut ToolResult) -> Result<ControlFlow, Error> {
|
||||
Ok(ControlFlow::Continue)
|
||||
}
|
||||
|
||||
/// ターン終了時。
|
||||
/// 生成されたメッセージを検査し、必要ならリトライ(ContinueWithMessages)を指示できる。
|
||||
async fn on_turn_end(&self, messages: &[Message]) -> Result<TurnResult, Error> {
|
||||
Ok(TurnResult::Finish)
|
||||
}
|
||||
pub trait Hook<E: HookEventKind>: Send + Sync {
|
||||
async fn call(&self, input: &mut E::Input) -> Result<E::Output, Error>;
|
||||
}
|
||||
|
||||
pub enum ControlFlow {
|
||||
pub trait HookEventKind {
|
||||
type Input;
|
||||
type Output;
|
||||
}
|
||||
|
||||
pub struct OnMessageSend;
|
||||
pub struct BeforeToolCall;
|
||||
pub struct AfterToolCall;
|
||||
pub struct OnTurnEnd;
|
||||
pub struct OnAbort;
|
||||
|
||||
pub enum OnMessageSendResult {
|
||||
Continue,
|
||||
Cancel(String),
|
||||
}
|
||||
|
||||
pub enum BeforeToolCallResult {
|
||||
Continue,
|
||||
Skip, // Tool実行などをスキップ
|
||||
Abort(String), // 処理中断
|
||||
Pause,
|
||||
}
|
||||
|
||||
pub enum TurnResult {
|
||||
pub enum AfterToolCallResult {
|
||||
Continue,
|
||||
Abort(String),
|
||||
}
|
||||
|
||||
pub enum OnTurnEndResult {
|
||||
Finish,
|
||||
ContinueWithMessages(Vec<Message>), // メッセージを追加してターン継続(自己修正など)
|
||||
Paused,
|
||||
}
|
||||
```
|
||||
|
||||
### Tool Call Context
|
||||
|
||||
`before_tool_call` / `after_tool_call` は、ツール実行の文脈を含む入力を受け取る。
|
||||
|
||||
```rust
|
||||
pub struct ToolCallContext {
|
||||
pub call: ToolCall,
|
||||
pub meta: ToolMeta, // 不変メタデータ
|
||||
pub tool: Arc<dyn Tool>, // 状態アクセス用
|
||||
}
|
||||
|
||||
pub struct ToolResultContext {
|
||||
pub result: ToolResult,
|
||||
pub meta: ToolMeta,
|
||||
pub tool: Arc<dyn Tool>,
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -433,4 +452,3 @@ impl<C: LlmClient> Worker<C> {
|
|||
3. **選択的購読**: on_*で必要なイベントだけ、またはSubscriberで一括
|
||||
4. **累積イベントの追加**: Worker層でComplete系イベントを追加提供
|
||||
5. **後方互換性**: 従来の`run()`も引き続き使用可能
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ tracing = "0.1"
|
|||
async-trait = "0.1"
|
||||
futures = "0.3"
|
||||
tokio = { version = "1.49", features = ["macros", "rt-multi-thread"] }
|
||||
reqwest = { version = "0.13.1", default-features = false, features = ["stream", "json", "native-tls"] }
|
||||
tokio-util = "0.7"
|
||||
reqwest = { version = "0.13.1", default-features = false, features = ["stream", "json", "native-tls", "http2"] }
|
||||
eventsource-stream = "0.2"
|
||||
llm-worker-macros = { path = "../llm-worker-macros", version = "0.1" }
|
||||
|
||||
|
|
|
|||
71
llm-worker/examples/worker_cancel_demo.rs
Normal file
71
llm-worker/examples/worker_cancel_demo.rs
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
//! Worker のキャンセル機能のデモンストレーション
|
||||
//!
|
||||
//! ストリーミング受信中に別スレッドからキャンセルする例
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use llm_worker::{Worker, WorkerResult};
|
||||
use llm_worker::llm_client::providers::anthropic::AnthropicClient;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// .envファイルを読み込む
|
||||
dotenv::dotenv().ok();
|
||||
|
||||
// ロギング初期化
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
|
||||
)
|
||||
.init();
|
||||
|
||||
let api_key = std::env::var("ANTHROPIC_API_KEY")
|
||||
.expect("ANTHROPIC_API_KEY environment variable not set");
|
||||
|
||||
let client = AnthropicClient::new(&api_key, "claude-sonnet-4-20250514");
|
||||
let worker = Arc::new(Mutex::new(Worker::new(client)));
|
||||
|
||||
println!("🚀 Starting Worker...");
|
||||
println!("💡 Will cancel after 2 seconds\n");
|
||||
|
||||
// キャンセルトークンを先に取得(ロックを保持しない)
|
||||
let cancel_token = {
|
||||
let w = worker.lock().await;
|
||||
w.cancellation_token().clone()
|
||||
};
|
||||
|
||||
// タスク1: Workerを実行
|
||||
let worker_clone = worker.clone();
|
||||
let task = tokio::spawn(async move {
|
||||
let mut w = worker_clone.lock().await;
|
||||
println!("📡 Sending request to LLM...");
|
||||
|
||||
match w.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await {
|
||||
Ok(WorkerResult::Finished(_)) => {
|
||||
println!("✅ Task completed normally");
|
||||
}
|
||||
Ok(WorkerResult::Paused(_)) => {
|
||||
println!("⏸️ Task paused");
|
||||
}
|
||||
Err(e) => {
|
||||
println!("❌ Task error: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// タスク2: 2秒後にキャンセル
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||
println!("\n🛑 Cancelling worker...");
|
||||
cancel_token.cancel();
|
||||
});
|
||||
|
||||
// タスク完了を待つ
|
||||
task.await?;
|
||||
|
||||
println!("\n✨ Demo complete!");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -41,7 +41,7 @@ use tracing_subscriber::EnvFilter;
|
|||
use clap::{Parser, ValueEnum};
|
||||
use llm_worker::{
|
||||
Worker,
|
||||
hook::{ControlFlow, HookError, ToolResult, WorkerHook},
|
||||
hook::{AfterToolCall, AfterToolCallResult, Hook, HookError, ToolResult},
|
||||
llm_client::{
|
||||
LlmClient,
|
||||
providers::{
|
||||
|
|
@ -282,11 +282,11 @@ impl ToolResultPrinterHook {
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerHook for ToolResultPrinterHook {
|
||||
async fn after_tool_call(
|
||||
impl Hook<AfterToolCall> for ToolResultPrinterHook {
|
||||
async fn call(
|
||||
&self,
|
||||
tool_result: &mut ToolResult,
|
||||
) -> Result<ControlFlow, HookError> {
|
||||
) -> Result<AfterToolCallResult, HookError> {
|
||||
let name = self
|
||||
.call_names
|
||||
.lock()
|
||||
|
|
@ -300,7 +300,7 @@ impl WorkerHook for ToolResultPrinterHook {
|
|||
println!(" Result ({}): ✅ {}", name, tool_result.content);
|
||||
}
|
||||
|
||||
Ok(ControlFlow::Continue)
|
||||
Ok(AfterToolCallResult::Continue)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -451,7 +451,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
.on_text_block(StreamingPrinter::new())
|
||||
.on_tool_use_block(ToolCallPrinter::new(tool_call_names.clone()));
|
||||
|
||||
worker.add_hook(ToolResultPrinterHook::new(tool_call_names));
|
||||
worker.add_after_tool_call_hook(ToolResultPrinterHook::new(tool_call_names));
|
||||
|
||||
// ワンショットモード
|
||||
if let Some(prompt) = args.prompt {
|
||||
|
|
|
|||
|
|
@ -8,33 +8,72 @@ use serde_json::Value;
|
|||
use thiserror::Error;
|
||||
|
||||
// =============================================================================
|
||||
// Control Flow Types
|
||||
// Hook Event Kinds
|
||||
// =============================================================================
|
||||
|
||||
/// Hook処理の制御フロー
|
||||
pub trait HookEventKind: Send + Sync + 'static {
|
||||
type Input;
|
||||
type Output;
|
||||
}
|
||||
|
||||
pub struct OnMessageSend;
|
||||
pub struct BeforeToolCall;
|
||||
pub struct AfterToolCall;
|
||||
pub struct OnTurnEnd;
|
||||
pub struct OnAbort;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ControlFlow {
|
||||
/// 処理を続行
|
||||
pub enum OnMessageSendResult {
|
||||
Continue,
|
||||
Cancel(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum BeforeToolCallResult {
|
||||
Continue,
|
||||
/// 現在の処理をスキップ(Tool実行など)
|
||||
Skip,
|
||||
/// 処理を中断
|
||||
Abort(String),
|
||||
/// 処理を一時停止(再開可能)
|
||||
Pause,
|
||||
}
|
||||
|
||||
/// ターン終了時の判定結果
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum AfterToolCallResult {
|
||||
Continue,
|
||||
Abort(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TurnResult {
|
||||
/// ターンを終了
|
||||
pub enum OnTurnEndResult {
|
||||
Finish,
|
||||
/// メッセージを追加してターン継続(自己修正など)
|
||||
ContinueWithMessages(Vec<crate::Message>),
|
||||
/// ターンを一時停止
|
||||
Paused,
|
||||
}
|
||||
|
||||
impl HookEventKind for OnMessageSend {
|
||||
type Input = Vec<crate::Message>;
|
||||
type Output = OnMessageSendResult;
|
||||
}
|
||||
|
||||
impl HookEventKind for BeforeToolCall {
|
||||
type Input = ToolCall;
|
||||
type Output = BeforeToolCallResult;
|
||||
}
|
||||
|
||||
impl HookEventKind for AfterToolCall {
|
||||
type Input = ToolResult;
|
||||
type Output = AfterToolCallResult;
|
||||
}
|
||||
|
||||
impl HookEventKind for OnTurnEnd {
|
||||
type Input = Vec<crate::Message>;
|
||||
type Output = OnTurnEndResult;
|
||||
}
|
||||
|
||||
impl HookEventKind for OnAbort {
|
||||
type Input = String;
|
||||
type Output = ();
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Tool Call / Result Types
|
||||
// =============================================================================
|
||||
|
|
@ -102,85 +141,13 @@ pub enum HookError {
|
|||
}
|
||||
|
||||
// =============================================================================
|
||||
// WorkerHook Trait
|
||||
// Hook Trait
|
||||
// =============================================================================
|
||||
|
||||
/// ターンの進行・ツール実行に介入するためのトレイト
|
||||
/// Hookイベントの処理を行うトレイト
|
||||
///
|
||||
/// Hookを使うと、メッセージ送信前、ツール実行前後、ターン終了時に
|
||||
/// 処理を挟んだり、実行をキャンセルしたりできます。
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// use llm_worker::hook::{ControlFlow, HookError, ToolCall, TurnResult, WorkerHook};
|
||||
/// use llm_worker::Message;
|
||||
///
|
||||
/// struct ValidationHook;
|
||||
///
|
||||
/// #[async_trait::async_trait]
|
||||
/// impl WorkerHook for ValidationHook {
|
||||
/// async fn before_tool_call(&self, call: &mut ToolCall) -> Result<ControlFlow, HookError> {
|
||||
/// // 危険なツールをブロック
|
||||
/// if call.name == "delete_all" {
|
||||
/// return Ok(ControlFlow::Skip);
|
||||
/// }
|
||||
/// Ok(ControlFlow::Continue)
|
||||
/// }
|
||||
///
|
||||
/// async fn on_turn_end(&self, messages: &[Message]) -> Result<TurnResult, HookError> {
|
||||
/// // 条件を満たさなければ追加メッセージで継続
|
||||
/// if messages.len() < 3 {
|
||||
/// return Ok(TurnResult::ContinueWithMessages(vec![
|
||||
/// Message::user("Please elaborate.")
|
||||
/// ]));
|
||||
/// }
|
||||
/// Ok(TurnResult::Finish)
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// # デフォルト実装
|
||||
///
|
||||
/// すべてのメソッドにはデフォルト実装があり、何も行わず`Continue`を返します。
|
||||
/// 必要なメソッドのみオーバーライドしてください。
|
||||
/// 各イベント種別は戻り値型が異なるため、`HookEventKind`を介して型を制約する。
|
||||
#[async_trait]
|
||||
pub trait WorkerHook: Send + Sync {
|
||||
/// メッセージ送信前に呼ばれる
|
||||
///
|
||||
/// リクエストに含まれるメッセージリストを参照・改変できます。
|
||||
/// `ControlFlow::Abort`を返すとターンが中断されます。
|
||||
async fn on_message_send(
|
||||
&self,
|
||||
_context: &mut Vec<crate::Message>,
|
||||
) -> Result<ControlFlow, HookError> {
|
||||
Ok(ControlFlow::Continue)
|
||||
}
|
||||
|
||||
/// ツール実行前に呼ばれる
|
||||
///
|
||||
/// ツール呼び出しの引数を書き換えたり、実行をスキップしたりできます。
|
||||
/// `ControlFlow::Skip`を返すとこのツールの実行がスキップされます。
|
||||
async fn before_tool_call(&self, _tool_call: &mut ToolCall) -> Result<ControlFlow, HookError> {
|
||||
Ok(ControlFlow::Continue)
|
||||
}
|
||||
|
||||
/// ツール実行後に呼ばれる
|
||||
///
|
||||
/// ツールの実行結果を書き換えたり、隠蔽したりできます。
|
||||
async fn after_tool_call(
|
||||
&self,
|
||||
_tool_result: &mut ToolResult,
|
||||
) -> Result<ControlFlow, HookError> {
|
||||
Ok(ControlFlow::Continue)
|
||||
}
|
||||
|
||||
/// ターン終了時に呼ばれる
|
||||
///
|
||||
/// 生成されたメッセージを検査し、必要なら追加メッセージで継続を指示できます。
|
||||
/// `TurnResult::ContinueWithMessages`を返すと、指定したメッセージを追加して
|
||||
/// 次のターンに進みます。
|
||||
async fn on_turn_end(&self, _messages: &[crate::Message]) -> Result<TurnResult, HookError> {
|
||||
Ok(TurnResult::Finish)
|
||||
}
|
||||
pub trait Hook<E: HookEventKind>: Send + Sync {
|
||||
async fn call(&self, input: &mut E::Input) -> Result<E::Output, HookError>;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
//!
|
||||
//! - [`Worker`] - LLMとの対話を管理する中心コンポーネント
|
||||
//! - [`tool::Tool`] - LLMから呼び出し可能なツール
|
||||
//! - [`hook::WorkerHook`] - ターン進行への介入
|
||||
//! - [`hook::Hook`] - ターン進行への介入
|
||||
//! - [`subscriber::WorkerSubscriber`] - ストリーミングイベントの購読
|
||||
//!
|
||||
//! # Quick Start
|
||||
|
|
@ -48,9 +48,5 @@ pub mod subscriber;
|
|||
pub mod timeline;
|
||||
pub mod tool;
|
||||
|
||||
// =============================================================================
|
||||
// トップレベル公開(最も頻繁に使う型)
|
||||
// =============================================================================
|
||||
|
||||
pub use message::{ContentPart, Message, MessageContent, Role};
|
||||
pub use worker::{Worker, WorkerConfig, WorkerError};
|
||||
pub use worker::{Worker, WorkerConfig, WorkerError, WorkerResult};
|
||||
|
|
|
|||
|
|
@ -65,10 +65,10 @@ pub trait WorkerSubscriber: Send {
|
|||
///
|
||||
/// ブロック開始時にDefault::default()で生成され、
|
||||
/// ブロック終了時に破棄される。
|
||||
type TextBlockScope: Default + Send;
|
||||
type TextBlockScope: Default + Send + Sync;
|
||||
|
||||
/// ツール使用ブロック処理用のスコープ型
|
||||
type ToolUseBlockScope: Default + Send;
|
||||
type ToolUseBlockScope: Default + Send + Sync;
|
||||
|
||||
// =========================================================================
|
||||
// ブロックイベント(スコープ管理あり)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ use crate::handler::*;
|
|||
/// 各Handlerは独自のScope型を持つため、Timelineで保持するには型消去が必要です。
|
||||
/// 通常は直接使用せず、`Timeline::on_text_block()`などのメソッド経由で
|
||||
/// 自動的にラップされます。
|
||||
pub trait ErasedHandler<K: Kind>: Send {
|
||||
pub trait ErasedHandler<K: Kind>: Send + Sync {
|
||||
/// イベントをディスパッチ
|
||||
fn dispatch(&mut self, event: &K::Event);
|
||||
/// スコープを開始(Block開始時)
|
||||
|
|
@ -54,9 +54,9 @@ where
|
|||
|
||||
impl<H, K> ErasedHandler<K> for HandlerWrapper<H, K>
|
||||
where
|
||||
H: Handler<K> + Send,
|
||||
H: Handler<K> + Send + Sync,
|
||||
K: Kind,
|
||||
H::Scope: Send,
|
||||
H::Scope: Send + Sync,
|
||||
{
|
||||
fn dispatch(&mut self, event: &K::Event) {
|
||||
if let Some(scope) = &mut self.scope {
|
||||
|
|
@ -78,7 +78,7 @@ where
|
|||
// =============================================================================
|
||||
|
||||
/// ブロックハンドラーの型消去trait
|
||||
trait ErasedBlockHandler: Send {
|
||||
trait ErasedBlockHandler: Send + Sync {
|
||||
fn dispatch_start(&mut self, start: &BlockStart);
|
||||
fn dispatch_delta(&mut self, delta: &BlockDelta);
|
||||
fn dispatch_stop(&mut self, stop: &BlockStop);
|
||||
|
|
@ -112,8 +112,8 @@ where
|
|||
|
||||
impl<H> ErasedBlockHandler for TextBlockHandlerWrapper<H>
|
||||
where
|
||||
H: Handler<TextBlockKind> + Send,
|
||||
H::Scope: Send,
|
||||
H: Handler<TextBlockKind> + Send + Sync,
|
||||
H::Scope: Send + Sync,
|
||||
{
|
||||
fn dispatch_start(&mut self, start: &BlockStart) {
|
||||
if let Some(scope) = &mut self.scope {
|
||||
|
|
@ -185,8 +185,8 @@ where
|
|||
|
||||
impl<H> ErasedBlockHandler for ThinkingBlockHandlerWrapper<H>
|
||||
where
|
||||
H: Handler<ThinkingBlockKind> + Send,
|
||||
H::Scope: Send,
|
||||
H: Handler<ThinkingBlockKind> + Send + Sync,
|
||||
H::Scope: Send + Sync,
|
||||
{
|
||||
fn dispatch_start(&mut self, start: &BlockStart) {
|
||||
if let Some(scope) = &mut self.scope {
|
||||
|
|
@ -255,8 +255,8 @@ where
|
|||
|
||||
impl<H> ErasedBlockHandler for ToolUseBlockHandlerWrapper<H>
|
||||
where
|
||||
H: Handler<ToolUseBlockKind> + Send,
|
||||
H::Scope: Send,
|
||||
H: Handler<ToolUseBlockKind> + Send + Sync,
|
||||
H::Scope: Send + Sync,
|
||||
{
|
||||
fn dispatch_start(&mut self, start: &BlockStart) {
|
||||
if let Some(scope) = &mut self.scope {
|
||||
|
|
@ -391,8 +391,8 @@ impl Timeline {
|
|||
/// UsageKind用のHandlerを登録
|
||||
pub fn on_usage<H>(&mut self, handler: H) -> &mut Self
|
||||
where
|
||||
H: Handler<UsageKind> + Send + 'static,
|
||||
H::Scope: Send,
|
||||
H: Handler<UsageKind> + Send + Sync + 'static,
|
||||
H::Scope: Send + Sync,
|
||||
{
|
||||
// Meta系はデフォルトでスコープを開始しておく
|
||||
let mut wrapper = HandlerWrapper::new(handler);
|
||||
|
|
@ -404,8 +404,8 @@ impl Timeline {
|
|||
/// PingKind用のHandlerを登録
|
||||
pub fn on_ping<H>(&mut self, handler: H) -> &mut Self
|
||||
where
|
||||
H: Handler<PingKind> + Send + 'static,
|
||||
H::Scope: Send,
|
||||
H: Handler<PingKind> + Send + Sync + 'static,
|
||||
H::Scope: Send + Sync,
|
||||
{
|
||||
let mut wrapper = HandlerWrapper::new(handler);
|
||||
wrapper.start_scope();
|
||||
|
|
@ -416,8 +416,8 @@ impl Timeline {
|
|||
/// StatusKind用のHandlerを登録
|
||||
pub fn on_status<H>(&mut self, handler: H) -> &mut Self
|
||||
where
|
||||
H: Handler<StatusKind> + Send + 'static,
|
||||
H::Scope: Send,
|
||||
H: Handler<StatusKind> + Send + Sync + 'static,
|
||||
H::Scope: Send + Sync,
|
||||
{
|
||||
let mut wrapper = HandlerWrapper::new(handler);
|
||||
wrapper.start_scope();
|
||||
|
|
@ -428,8 +428,8 @@ impl Timeline {
|
|||
/// ErrorKind用のHandlerを登録
|
||||
pub fn on_error<H>(&mut self, handler: H) -> &mut Self
|
||||
where
|
||||
H: Handler<ErrorKind> + Send + 'static,
|
||||
H::Scope: Send,
|
||||
H: Handler<ErrorKind> + Send + Sync + 'static,
|
||||
H::Scope: Send + Sync,
|
||||
{
|
||||
let mut wrapper = HandlerWrapper::new(handler);
|
||||
wrapper.start_scope();
|
||||
|
|
@ -440,8 +440,8 @@ impl Timeline {
|
|||
/// TextBlockKind用のHandlerを登録
|
||||
pub fn on_text_block<H>(&mut self, handler: H) -> &mut Self
|
||||
where
|
||||
H: Handler<TextBlockKind> + Send + 'static,
|
||||
H::Scope: Send,
|
||||
H: Handler<TextBlockKind> + Send + Sync + 'static,
|
||||
H::Scope: Send + Sync,
|
||||
{
|
||||
self.text_block_handlers
|
||||
.push(Box::new(TextBlockHandlerWrapper::new(handler)));
|
||||
|
|
@ -451,8 +451,8 @@ impl Timeline {
|
|||
/// ThinkingBlockKind用のHandlerを登録
|
||||
pub fn on_thinking_block<H>(&mut self, handler: H) -> &mut Self
|
||||
where
|
||||
H: Handler<ThinkingBlockKind> + Send + 'static,
|
||||
H::Scope: Send,
|
||||
H: Handler<ThinkingBlockKind> + Send + Sync + 'static,
|
||||
H::Scope: Send + Sync,
|
||||
{
|
||||
self.thinking_block_handlers
|
||||
.push(Box::new(ThinkingBlockHandlerWrapper::new(handler)));
|
||||
|
|
@ -462,8 +462,8 @@ impl Timeline {
|
|||
/// ToolUseBlockKind用のHandlerを登録
|
||||
pub fn on_tool_use_block<H>(&mut self, handler: H) -> &mut Self
|
||||
where
|
||||
H: Handler<ToolUseBlockKind> + Send + 'static,
|
||||
H::Scope: Send,
|
||||
H: Handler<ToolUseBlockKind> + Send + Sync + 'static,
|
||||
H::Scope: Send + Sync,
|
||||
{
|
||||
self.tool_use_block_handlers
|
||||
.push(Box::new(ToolUseBlockHandlerWrapper::new(handler)));
|
||||
|
|
@ -578,6 +578,21 @@ impl Timeline {
|
|||
pub fn current_block(&self) -> Option<BlockType> {
|
||||
self.current_block
|
||||
}
|
||||
|
||||
/// 現在アクティブなブロックを中断する
|
||||
///
|
||||
/// キャンセルやエラー時に呼び出し、進行中のブロックに対して
|
||||
/// BlockAbortイベントを発火してスコープをクリーンアップする。
|
||||
pub fn abort_current_block(&mut self) {
|
||||
if let Some(block_type) = self.current_block {
|
||||
let abort = crate::timeline::event::BlockAbort {
|
||||
index: 0, // インデックスは不明なので0
|
||||
block_type,
|
||||
reason: "Cancelled".to_string(),
|
||||
};
|
||||
self.handle_block_abort(&abort);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
|||
|
|
@ -3,11 +3,16 @@ use std::marker::PhantomData;
|
|||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use futures::StreamExt;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info, trace, warn};
|
||||
|
||||
use crate::{
|
||||
ContentPart, Message, MessageContent, Role,
|
||||
hook::{ControlFlow, HookError, ToolCall, ToolResult, TurnResult, WorkerHook},
|
||||
hook::{
|
||||
AfterToolCall, AfterToolCallResult, BeforeToolCall, BeforeToolCallResult, Hook, HookError,
|
||||
OnAbort, OnMessageSend, OnMessageSendResult, OnTurnEnd, OnTurnEndResult, ToolCall,
|
||||
ToolResult,
|
||||
},
|
||||
llm_client::{ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ToolDefinition},
|
||||
state::{Locked, Mutable, WorkerState},
|
||||
subscriber::{
|
||||
|
|
@ -37,6 +42,9 @@ pub enum WorkerError {
|
|||
/// 処理が中断された
|
||||
#[error("Aborted: {0}")]
|
||||
Aborted(String),
|
||||
/// Cancellation Tokenによって中断された
|
||||
#[error("Cancelled")]
|
||||
Cancelled,
|
||||
/// 設定に関する警告(未サポートのオプション)
|
||||
#[error("Config warnings: {}", .0.iter().map(|w| w.to_string()).collect::<Vec<_>>().join(", "))]
|
||||
ConfigWarnings(Vec<ConfigWarning>),
|
||||
|
|
@ -77,7 +85,7 @@ enum ToolExecutionResult {
|
|||
// =============================================================================
|
||||
|
||||
/// ターンイベントを通知するためのコールバック (型消去)
|
||||
trait TurnNotifier: Send {
|
||||
trait TurnNotifier: Send + Sync {
|
||||
fn on_turn_start(&self, turn: usize);
|
||||
fn on_turn_end(&self, turn: usize);
|
||||
}
|
||||
|
|
@ -149,8 +157,16 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
|||
tool_call_collector: ToolCallCollector,
|
||||
/// 登録されたツール
|
||||
tools: HashMap<String, Arc<dyn Tool>>,
|
||||
/// 登録されたHook
|
||||
hooks: Vec<Box<dyn WorkerHook>>,
|
||||
/// on_message_send Hook
|
||||
hooks_on_message_send: Vec<Box<dyn Hook<OnMessageSend>>>,
|
||||
/// before_tool_call Hook
|
||||
hooks_before_tool_call: Vec<Box<dyn Hook<BeforeToolCall>>>,
|
||||
/// after_tool_call Hook
|
||||
hooks_after_tool_call: Vec<Box<dyn Hook<AfterToolCall>>>,
|
||||
/// on_turn_end Hook
|
||||
hooks_on_turn_end: Vec<Box<dyn Hook<OnTurnEnd>>>,
|
||||
/// on_abort Hook
|
||||
hooks_on_abort: Vec<Box<dyn Hook<OnAbort>>>,
|
||||
/// システムプロンプト
|
||||
system_prompt: Option<String>,
|
||||
/// メッセージ履歴(Workerが所有)
|
||||
|
|
@ -163,6 +179,8 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
|||
turn_notifiers: Vec<Box<dyn TurnNotifier>>,
|
||||
/// リクエスト設定(max_tokens, temperature等)
|
||||
request_config: RequestConfig,
|
||||
/// キャンセレーショントークン(実行中断用)
|
||||
cancellation_token: CancellationToken,
|
||||
/// 状態マーカー
|
||||
_state: PhantomData<S>,
|
||||
}
|
||||
|
|
@ -252,30 +270,29 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Hookを追加する
|
||||
///
|
||||
/// Hookはターンの進行・ツール実行に介入できます。
|
||||
/// 複数のHookを登録した場合、登録順に実行されます。
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// use llm_worker::{Worker, WorkerHook, ControlFlow, ToolCall};
|
||||
///
|
||||
/// struct LoggingHook;
|
||||
///
|
||||
/// #[async_trait::async_trait]
|
||||
/// impl WorkerHook for LoggingHook {
|
||||
/// async fn before_tool_call(&self, call: &mut ToolCall) -> Result<ControlFlow, HookError> {
|
||||
/// println!("Calling tool: {}", call.name);
|
||||
/// Ok(ControlFlow::Continue)
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// worker.add_hook(LoggingHook);
|
||||
/// ```
|
||||
pub fn add_hook(&mut self, hook: impl WorkerHook + 'static) {
|
||||
self.hooks.push(Box::new(hook));
|
||||
/// on_message_send Hookを追加する
|
||||
pub fn add_on_message_send_hook(&mut self, hook: impl Hook<OnMessageSend> + 'static) {
|
||||
self.hooks_on_message_send.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// before_tool_call Hookを追加する
|
||||
pub fn add_before_tool_call_hook(&mut self, hook: impl Hook<BeforeToolCall> + 'static) {
|
||||
self.hooks_before_tool_call.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// after_tool_call Hookを追加する
|
||||
pub fn add_after_tool_call_hook(&mut self, hook: impl Hook<AfterToolCall> + 'static) {
|
||||
self.hooks_after_tool_call.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// on_turn_end Hookを追加する
|
||||
pub fn add_on_turn_end_hook(&mut self, hook: impl Hook<OnTurnEnd> + 'static) {
|
||||
self.hooks_on_turn_end.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// on_abort Hookを追加する
|
||||
pub fn add_on_abort_hook(&mut self, hook: impl Hook<OnAbort> + 'static) {
|
||||
self.hooks_on_abort.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// タイムラインへの可変参照を取得(追加ハンドラ登録用)
|
||||
|
|
@ -375,6 +392,41 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
self.request_config = config;
|
||||
}
|
||||
|
||||
/// 実行をキャンセルする
|
||||
///
|
||||
/// 現在実行中のストリーミングやツール実行を中断します。
|
||||
/// 次のイベントループのチェックポイントでWorkerError::Cancelledが返されます。
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// use std::sync::Arc;
|
||||
/// let worker = Arc::new(Mutex::new(Worker::new(client)));
|
||||
///
|
||||
/// // 別スレッドで実行
|
||||
/// let worker_clone = worker.clone();
|
||||
/// tokio::spawn(async move {
|
||||
/// let mut w = worker_clone.lock().unwrap();
|
||||
/// w.run("Long task...").await
|
||||
/// });
|
||||
///
|
||||
/// // キャンセル
|
||||
/// worker.lock().unwrap().cancel();
|
||||
/// ```
|
||||
pub fn cancel(&self) {
|
||||
self.cancellation_token.cancel();
|
||||
}
|
||||
|
||||
/// キャンセルされているかチェック
|
||||
pub fn is_cancelled(&self) -> bool {
|
||||
self.cancellation_token.is_cancelled()
|
||||
}
|
||||
|
||||
/// キャンセレーショントークンへの参照を取得
|
||||
pub fn cancellation_token(&self) -> &CancellationToken {
|
||||
&self.cancellation_token
|
||||
}
|
||||
|
||||
/// 登録されたツールからToolDefinitionのリストを生成
|
||||
fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
|
||||
self.tools
|
||||
|
|
@ -430,7 +482,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
|
||||
/// リクエストを構築
|
||||
fn build_request(&self, tool_definitions: &[ToolDefinition]) -> Request {
|
||||
fn build_request(&self, tool_definitions: &[ToolDefinition], context: &[Message]) -> Request {
|
||||
let mut request = Request::new();
|
||||
|
||||
// システムプロンプトを設定
|
||||
|
|
@ -439,7 +491,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
|
||||
// メッセージを追加
|
||||
for msg in &self.history {
|
||||
for msg in context {
|
||||
// Message から llm_client::Message への変換
|
||||
request = request.message(crate::llm_client::Message {
|
||||
role: match msg.role {
|
||||
|
|
@ -495,36 +547,45 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
|
||||
/// Hooks: on_message_send
|
||||
async fn run_on_message_send_hooks(&self) -> Result<ControlFlow, WorkerError> {
|
||||
for hook in &self.hooks {
|
||||
// Note: Locked状態でも履歴全体を参照として渡す(変更は不可)
|
||||
// HookのAPIを変更し、immutable参照のみを渡すようにする必要があるかもしれない
|
||||
// 現在は空のVecを渡して回避(要検討)
|
||||
async fn run_on_message_send_hooks(
|
||||
&self,
|
||||
) -> Result<(OnMessageSendResult, Vec<Message>), WorkerError> {
|
||||
let mut temp_context = self.history.clone();
|
||||
let result = hook.on_message_send(&mut temp_context).await?;
|
||||
for hook in &self.hooks_on_message_send {
|
||||
let result = hook.call(&mut temp_context).await?;
|
||||
match result {
|
||||
ControlFlow::Continue => continue,
|
||||
ControlFlow::Skip => return Ok(ControlFlow::Skip),
|
||||
ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)),
|
||||
ControlFlow::Pause => return Ok(ControlFlow::Pause),
|
||||
OnMessageSendResult::Continue => continue,
|
||||
OnMessageSendResult::Cancel(reason) => {
|
||||
return Ok((OnMessageSendResult::Cancel(reason), temp_context));
|
||||
}
|
||||
}
|
||||
Ok(ControlFlow::Continue)
|
||||
}
|
||||
Ok((OnMessageSendResult::Continue, temp_context))
|
||||
}
|
||||
|
||||
/// Hooks: on_turn_end
|
||||
async fn run_on_turn_end_hooks(&self) -> Result<TurnResult, WorkerError> {
|
||||
for hook in &self.hooks {
|
||||
let result = hook.on_turn_end(&self.history).await?;
|
||||
async fn run_on_turn_end_hooks(&self) -> Result<OnTurnEndResult, WorkerError> {
|
||||
let mut temp_messages = self.history.clone();
|
||||
for hook in &self.hooks_on_turn_end {
|
||||
let result = hook.call(&mut temp_messages).await?;
|
||||
match result {
|
||||
TurnResult::Finish => continue,
|
||||
TurnResult::ContinueWithMessages(msgs) => {
|
||||
return Ok(TurnResult::ContinueWithMessages(msgs));
|
||||
OnTurnEndResult::Finish => continue,
|
||||
OnTurnEndResult::ContinueWithMessages(msgs) => {
|
||||
return Ok(OnTurnEndResult::ContinueWithMessages(msgs));
|
||||
}
|
||||
TurnResult::Paused => return Ok(TurnResult::Paused),
|
||||
OnTurnEndResult::Paused => return Ok(OnTurnEndResult::Paused),
|
||||
}
|
||||
}
|
||||
Ok(TurnResult::Finish)
|
||||
Ok(OnTurnEndResult::Finish)
|
||||
}
|
||||
|
||||
/// Hooks: on_abort
|
||||
async fn run_on_abort_hooks(&self, reason: &str) -> Result<(), WorkerError> {
|
||||
let mut reason = reason.to_string();
|
||||
for hook in &self.hooks_on_abort {
|
||||
hook.call(&mut reason).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 未実行のツール呼び出しがあるかチェック(Pauseからの復帰用)
|
||||
|
|
@ -559,7 +620,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
/// 全てのツールに対してbefore_tool_callフックを実行後、
|
||||
/// 許可されたツールを並列に実行し、結果にafter_tool_callフックを適用する。
|
||||
async fn execute_tools(
|
||||
&self,
|
||||
&mut self,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
) -> Result<ToolExecutionResult, WorkerError> {
|
||||
use futures::future::join_all;
|
||||
|
|
@ -568,18 +629,18 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
let mut approved_calls = Vec::new();
|
||||
for mut tool_call in tool_calls {
|
||||
let mut skip = false;
|
||||
for hook in &self.hooks {
|
||||
let result = hook.before_tool_call(&mut tool_call).await?;
|
||||
for hook in &self.hooks_before_tool_call {
|
||||
let result = hook.call(&mut tool_call).await?;
|
||||
match result {
|
||||
ControlFlow::Continue => {}
|
||||
ControlFlow::Skip => {
|
||||
BeforeToolCallResult::Continue => {}
|
||||
BeforeToolCallResult::Skip => {
|
||||
skip = true;
|
||||
break;
|
||||
}
|
||||
ControlFlow::Abort(reason) => {
|
||||
BeforeToolCallResult::Abort(reason) => {
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
ControlFlow::Pause => {
|
||||
BeforeToolCallResult::Pause => {
|
||||
return Ok(ToolExecutionResult::Paused);
|
||||
}
|
||||
}
|
||||
|
|
@ -589,7 +650,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
}
|
||||
|
||||
// Phase 2: 許可されたツールを並列実行
|
||||
// Phase 2: 許可されたツールを並列実行(キャンセル可能)
|
||||
let futures: Vec<_> = approved_calls
|
||||
.into_iter()
|
||||
.map(|tool_call| {
|
||||
|
|
@ -612,25 +673,26 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
})
|
||||
.collect();
|
||||
|
||||
let mut results = join_all(futures).await;
|
||||
// ツール実行をキャンセル可能にする
|
||||
let mut results = tokio::select! {
|
||||
results = join_all(futures) => results,
|
||||
_ = self.cancellation_token.cancelled() => {
|
||||
info!("Tool execution cancelled");
|
||||
self.timeline.abort_current_block();
|
||||
self.run_on_abort_hooks("Cancelled").await?;
|
||||
return Err(WorkerError::Cancelled);
|
||||
}
|
||||
};
|
||||
|
||||
// Phase 3: after_tool_call フックを適用
|
||||
for tool_result in &mut results {
|
||||
for hook in &self.hooks {
|
||||
let result = hook.after_tool_call(tool_result).await?;
|
||||
for hook in &self.hooks_after_tool_call {
|
||||
let result = hook.call(tool_result).await?;
|
||||
match result {
|
||||
ControlFlow::Continue => {}
|
||||
ControlFlow::Skip => break,
|
||||
ControlFlow::Abort(reason) => {
|
||||
AfterToolCallResult::Continue => {}
|
||||
AfterToolCallResult::Abort(reason) => {
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
ControlFlow::Pause => {
|
||||
// after_tool_callでのPauseは結果を受け入れた後、次の処理前に止まる動作とする
|
||||
// ここではContinue扱いとし、on_message_send等でPauseすることを期待する
|
||||
// あるいはここでのPauseをサポートする場合は戻り値を調整する必要がある
|
||||
// 現状はログを出してContinue
|
||||
warn!("ControlFlow::Pause in after_tool_call is treated as Continue");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -663,6 +725,14 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
|
||||
loop {
|
||||
// キャンセルチェック
|
||||
if self.cancellation_token.is_cancelled() {
|
||||
info!("Execution cancelled");
|
||||
self.timeline.abort_current_block();
|
||||
self.run_on_abort_hooks("Cancelled").await?;
|
||||
return Err(WorkerError::Cancelled);
|
||||
}
|
||||
|
||||
// ターン開始を通知
|
||||
let current_turn = self.turn_count;
|
||||
debug!(turn = current_turn, "Turn start");
|
||||
|
|
@ -671,24 +741,21 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
|
||||
// Hook: on_message_send
|
||||
let control = self.run_on_message_send_hooks().await?;
|
||||
let (control, request_context) = self.run_on_message_send_hooks().await?;
|
||||
match control {
|
||||
ControlFlow::Abort(reason) => {
|
||||
warn!(reason = %reason, "Aborted by hook");
|
||||
OnMessageSendResult::Cancel(reason) => {
|
||||
info!(reason = %reason, "Aborted by hook");
|
||||
for notifier in &self.turn_notifiers {
|
||||
notifier.on_turn_end(current_turn);
|
||||
}
|
||||
self.run_on_abort_hooks(&reason).await?;
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
ControlFlow::Pause | ControlFlow::Skip => {
|
||||
// Skip or Pause -> Pause the worker
|
||||
return Ok(WorkerResult::Paused(&self.history));
|
||||
}
|
||||
ControlFlow::Continue => {}
|
||||
OnMessageSendResult::Continue => {}
|
||||
}
|
||||
|
||||
// リクエスト構築
|
||||
let request = self.build_request(&tool_definitions);
|
||||
let request = self.build_request(&tool_definitions, &request_context);
|
||||
debug!(
|
||||
message_count = request.messages.len(),
|
||||
tool_count = request.tools.len(),
|
||||
|
|
@ -698,10 +765,26 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
|
||||
// ストリーム処理
|
||||
debug!("Starting stream...");
|
||||
let mut stream = self.client.stream(request).await?;
|
||||
let mut event_count = 0;
|
||||
while let Some(event_result) = stream.next().await {
|
||||
match &event_result {
|
||||
|
||||
// ストリームを取得(キャンセル可能)
|
||||
let mut stream = tokio::select! {
|
||||
stream_result = self.client.stream(request) => stream_result?,
|
||||
_ = self.cancellation_token.cancelled() => {
|
||||
info!("Cancelled before stream started");
|
||||
self.timeline.abort_current_block();
|
||||
self.run_on_abort_hooks("Cancelled").await?;
|
||||
return Err(WorkerError::Cancelled);
|
||||
}
|
||||
};
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// ストリームからイベントを受信
|
||||
event_result = stream.next() => {
|
||||
match event_result {
|
||||
Some(result) => {
|
||||
match &result {
|
||||
Ok(event) => {
|
||||
trace!(event = ?event, "Received event");
|
||||
event_count += 1;
|
||||
|
|
@ -710,10 +793,22 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
warn!(error = %e, "Stream error");
|
||||
}
|
||||
}
|
||||
let event = event_result?;
|
||||
let event = result?;
|
||||
let timeline_event: crate::timeline::event::Event = event.into();
|
||||
self.timeline.dispatch(&timeline_event);
|
||||
}
|
||||
None => break, // ストリーム終了
|
||||
}
|
||||
}
|
||||
// キャンセル待機
|
||||
_ = self.cancellation_token.cancelled() => {
|
||||
info!("Stream cancelled");
|
||||
self.timeline.abort_current_block();
|
||||
self.run_on_abort_hooks("Cancelled").await?;
|
||||
return Err(WorkerError::Cancelled);
|
||||
}
|
||||
}
|
||||
}
|
||||
debug!(event_count = event_count, "Stream completed");
|
||||
|
||||
// ターン終了を通知
|
||||
|
|
@ -736,14 +831,14 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
// ツール呼び出しなし → ターン終了判定
|
||||
let turn_result = self.run_on_turn_end_hooks().await?;
|
||||
match turn_result {
|
||||
TurnResult::Finish => {
|
||||
OnTurnEndResult::Finish => {
|
||||
return Ok(WorkerResult::Finished(&self.history));
|
||||
}
|
||||
TurnResult::ContinueWithMessages(additional) => {
|
||||
OnTurnEndResult::ContinueWithMessages(additional) => {
|
||||
self.history.extend(additional);
|
||||
continue;
|
||||
}
|
||||
TurnResult::Paused => {
|
||||
OnTurnEndResult::Paused => {
|
||||
return Ok(WorkerResult::Paused(&self.history));
|
||||
}
|
||||
}
|
||||
|
|
@ -790,13 +885,18 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
text_block_collector,
|
||||
tool_call_collector,
|
||||
tools: HashMap::new(),
|
||||
hooks: Vec::new(),
|
||||
hooks_on_message_send: Vec::new(),
|
||||
hooks_before_tool_call: Vec::new(),
|
||||
hooks_after_tool_call: Vec::new(),
|
||||
hooks_on_turn_end: Vec::new(),
|
||||
hooks_on_abort: Vec::new(),
|
||||
system_prompt: None,
|
||||
history: Vec::new(),
|
||||
locked_prefix_len: 0,
|
||||
turn_count: 0,
|
||||
turn_notifiers: Vec::new(),
|
||||
request_config: RequestConfig::default(),
|
||||
cancellation_token: CancellationToken::new(),
|
||||
_state: PhantomData,
|
||||
}
|
||||
}
|
||||
|
|
@ -958,13 +1058,18 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
text_block_collector: self.text_block_collector,
|
||||
tool_call_collector: self.tool_call_collector,
|
||||
tools: self.tools,
|
||||
hooks: self.hooks,
|
||||
hooks_on_message_send: self.hooks_on_message_send,
|
||||
hooks_before_tool_call: self.hooks_before_tool_call,
|
||||
hooks_after_tool_call: self.hooks_after_tool_call,
|
||||
hooks_on_turn_end: self.hooks_on_turn_end,
|
||||
hooks_on_abort: self.hooks_on_abort,
|
||||
system_prompt: self.system_prompt,
|
||||
history: self.history,
|
||||
locked_prefix_len,
|
||||
turn_count: self.turn_count,
|
||||
turn_notifiers: self.turn_notifiers,
|
||||
request_config: self.request_config,
|
||||
cancellation_token: self.cancellation_token,
|
||||
_state: PhantomData,
|
||||
}
|
||||
}
|
||||
|
|
@ -1032,13 +1137,18 @@ impl<C: LlmClient> Worker<C, Locked> {
|
|||
text_block_collector: self.text_block_collector,
|
||||
tool_call_collector: self.tool_call_collector,
|
||||
tools: self.tools,
|
||||
hooks: self.hooks,
|
||||
hooks_on_message_send: self.hooks_on_message_send,
|
||||
hooks_before_tool_call: self.hooks_before_tool_call,
|
||||
hooks_after_tool_call: self.hooks_after_tool_call,
|
||||
hooks_on_turn_end: self.hooks_on_turn_end,
|
||||
hooks_on_abort: self.hooks_on_abort,
|
||||
system_prompt: self.system_prompt,
|
||||
history: self.history,
|
||||
locked_prefix_len: 0,
|
||||
turn_count: self.turn_count,
|
||||
turn_notifiers: self.turn_notifiers,
|
||||
request_config: self.request_config,
|
||||
cancellation_token: self.cancellation_token,
|
||||
_state: PhantomData,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,10 @@ use std::time::{Duration, Instant};
|
|||
|
||||
use async_trait::async_trait;
|
||||
use llm_worker::Worker;
|
||||
use llm_worker::hook::{ControlFlow, HookError, ToolCall, ToolResult, WorkerHook};
|
||||
use llm_worker::hook::{
|
||||
AfterToolCall, AfterToolCallResult, BeforeToolCall, BeforeToolCallResult, Hook, HookError,
|
||||
ToolCall, ToolResult,
|
||||
};
|
||||
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
|
||||
use llm_worker::tool::{Tool, ToolError};
|
||||
|
||||
|
|
@ -158,20 +161,17 @@ async fn test_before_tool_call_skip() {
|
|||
struct BlockingHook;
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerHook for BlockingHook {
|
||||
async fn before_tool_call(
|
||||
&self,
|
||||
tool_call: &mut ToolCall,
|
||||
) -> Result<ControlFlow, HookError> {
|
||||
impl Hook<BeforeToolCall> for BlockingHook {
|
||||
async fn call(&self, tool_call: &mut ToolCall) -> Result<BeforeToolCallResult, HookError> {
|
||||
if tool_call.name == "blocked_tool" {
|
||||
Ok(ControlFlow::Skip)
|
||||
Ok(BeforeToolCallResult::Skip)
|
||||
} else {
|
||||
Ok(ControlFlow::Continue)
|
||||
Ok(BeforeToolCallResult::Continue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
worker.add_hook(BlockingHook);
|
||||
worker.add_before_tool_call_hook(BlockingHook);
|
||||
|
||||
let _result = worker.run("Test hook").await;
|
||||
|
||||
|
|
@ -242,19 +242,19 @@ async fn test_after_tool_call_modification() {
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerHook for ModifyingHook {
|
||||
async fn after_tool_call(
|
||||
impl Hook<AfterToolCall> for ModifyingHook {
|
||||
async fn call(
|
||||
&self,
|
||||
tool_result: &mut ToolResult,
|
||||
) -> Result<ControlFlow, HookError> {
|
||||
) -> Result<AfterToolCallResult, HookError> {
|
||||
tool_result.content = format!("[Modified] {}", tool_result.content);
|
||||
*self.modified_content.lock().unwrap() = Some(tool_result.content.clone());
|
||||
Ok(ControlFlow::Continue)
|
||||
Ok(AfterToolCallResult::Continue)
|
||||
}
|
||||
}
|
||||
|
||||
let modified_content = Arc::new(std::sync::Mutex::new(None));
|
||||
worker.add_hook(ModifyingHook {
|
||||
worker.add_after_tool_call_hook(ModifyingHook {
|
||||
modified_content: modified_content.clone(),
|
||||
});
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user