Compare commits
4 Commits
33f1c218f2
...
c281248bf8
| Author | SHA1 | Date | |
|---|---|---|---|
| c281248bf8 | |||
| a2f53d7879 | |||
| 16fda38039 | |||
| 5691b09fc8 |
239
Cargo.lock
generated
239
Cargo.lock
generated
|
|
@ -104,9 +104,9 @@ checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cc"
|
name = "cc"
|
||||||
version = "1.2.51"
|
version = "1.2.52"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7a0aeaff4ff1a90589618835a598e545176939b97874f7abc7851caa0618f203"
|
checksum = "cd4932aefd12402b36c60956a4fe0035421f544799057659ff86f923657aada3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"find-msvc-tools",
|
"find-msvc-tools",
|
||||||
"shlex",
|
"shlex",
|
||||||
|
|
@ -203,6 +203,12 @@ version = "1.0.20"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555"
|
checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "equivalent"
|
||||||
|
version = "1.0.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "errno"
|
name = "errno"
|
||||||
version = "0.3.14"
|
version = "0.3.14"
|
||||||
|
|
@ -232,9 +238,15 @@ checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "find-msvc-tools"
|
name = "find-msvc-tools"
|
||||||
version = "0.1.6"
|
version = "0.1.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "645cbb3a84e60b7531617d5ae4e57f7e27308f6445f5abf653209ea76dec8dff"
|
checksum = "f449e6c6c08c865631d4890cfacf252b3d396c9bcc83adb6623cdb02a8336c41"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fnv"
|
||||||
|
version = "1.0.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "foreign-types"
|
name = "foreign-types"
|
||||||
|
|
@ -349,6 +361,17 @@ dependencies = [
|
||||||
"slab",
|
"slab",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "getrandom"
|
||||||
|
version = "0.2.16"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"libc",
|
||||||
|
"wasi",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "getrandom"
|
name = "getrandom"
|
||||||
version = "0.3.4"
|
version = "0.3.4"
|
||||||
|
|
@ -361,6 +384,31 @@ dependencies = [
|
||||||
"wasip2",
|
"wasip2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "h2"
|
||||||
|
version = "0.4.13"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54"
|
||||||
|
dependencies = [
|
||||||
|
"atomic-waker",
|
||||||
|
"bytes",
|
||||||
|
"fnv",
|
||||||
|
"futures-core",
|
||||||
|
"futures-sink",
|
||||||
|
"http",
|
||||||
|
"indexmap",
|
||||||
|
"slab",
|
||||||
|
"tokio",
|
||||||
|
"tokio-util",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hashbrown"
|
||||||
|
version = "0.16.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "heck"
|
name = "heck"
|
||||||
version = "0.5.0"
|
version = "0.5.0"
|
||||||
|
|
@ -416,6 +464,7 @@ dependencies = [
|
||||||
"bytes",
|
"bytes",
|
||||||
"futures-channel",
|
"futures-channel",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
|
"h2",
|
||||||
"http",
|
"http",
|
||||||
"http-body",
|
"http-body",
|
||||||
"httparse",
|
"httparse",
|
||||||
|
|
@ -427,6 +476,22 @@ dependencies = [
|
||||||
"want",
|
"want",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hyper-rustls"
|
||||||
|
version = "0.27.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58"
|
||||||
|
dependencies = [
|
||||||
|
"http",
|
||||||
|
"hyper",
|
||||||
|
"hyper-util",
|
||||||
|
"rustls",
|
||||||
|
"rustls-pki-types",
|
||||||
|
"tokio",
|
||||||
|
"tokio-rustls",
|
||||||
|
"tower-service",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hyper-tls"
|
name = "hyper-tls"
|
||||||
version = "0.6.0"
|
version = "0.6.0"
|
||||||
|
|
@ -569,6 +634,16 @@ dependencies = [
|
||||||
"icu_properties",
|
"icu_properties",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "indexmap"
|
||||||
|
version = "2.13.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017"
|
||||||
|
dependencies = [
|
||||||
|
"equivalent",
|
||||||
|
"hashbrown",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ipnet"
|
name = "ipnet"
|
||||||
version = "2.11.0"
|
version = "2.11.0"
|
||||||
|
|
@ -648,6 +723,7 @@ dependencies = [
|
||||||
"tempfile",
|
"tempfile",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-util",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
]
|
]
|
||||||
|
|
@ -895,10 +971,12 @@ dependencies = [
|
||||||
"bytes",
|
"bytes",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
|
"h2",
|
||||||
"http",
|
"http",
|
||||||
"http-body",
|
"http-body",
|
||||||
"http-body-util",
|
"http-body-util",
|
||||||
"hyper",
|
"hyper",
|
||||||
|
"hyper-rustls",
|
||||||
"hyper-tls",
|
"hyper-tls",
|
||||||
"hyper-util",
|
"hyper-util",
|
||||||
"js-sys",
|
"js-sys",
|
||||||
|
|
@ -923,6 +1001,20 @@ dependencies = [
|
||||||
"web-sys",
|
"web-sys",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ring"
|
||||||
|
version = "0.17.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7"
|
||||||
|
dependencies = [
|
||||||
|
"cc",
|
||||||
|
"cfg-if",
|
||||||
|
"getrandom 0.2.16",
|
||||||
|
"libc",
|
||||||
|
"untrusted",
|
||||||
|
"windows-sys 0.52.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustix"
|
name = "rustix"
|
||||||
version = "1.1.3"
|
version = "1.1.3"
|
||||||
|
|
@ -936,6 +1028,19 @@ dependencies = [
|
||||||
"windows-sys 0.61.2",
|
"windows-sys 0.61.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustls"
|
||||||
|
version = "0.23.36"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b"
|
||||||
|
dependencies = [
|
||||||
|
"once_cell",
|
||||||
|
"rustls-pki-types",
|
||||||
|
"rustls-webpki",
|
||||||
|
"subtle",
|
||||||
|
"zeroize",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustls-pki-types"
|
name = "rustls-pki-types"
|
||||||
version = "1.13.2"
|
version = "1.13.2"
|
||||||
|
|
@ -945,6 +1050,17 @@ dependencies = [
|
||||||
"zeroize",
|
"zeroize",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustls-webpki"
|
||||||
|
version = "0.103.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52"
|
||||||
|
dependencies = [
|
||||||
|
"ring",
|
||||||
|
"rustls-pki-types",
|
||||||
|
"untrusted",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustversion"
|
name = "rustversion"
|
||||||
version = "1.0.22"
|
version = "1.0.22"
|
||||||
|
|
@ -1111,6 +1227,12 @@ version = "0.11.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "subtle"
|
||||||
|
version = "2.6.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "syn"
|
name = "syn"
|
||||||
version = "2.0.114"
|
version = "2.0.114"
|
||||||
|
|
@ -1149,7 +1271,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c"
|
checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"fastrand",
|
"fastrand",
|
||||||
"getrandom",
|
"getrandom 0.3.4",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"rustix",
|
"rustix",
|
||||||
"windows-sys 0.61.2",
|
"windows-sys 0.61.2",
|
||||||
|
|
@ -1230,6 +1352,16 @@ dependencies = [
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-rustls"
|
||||||
|
version = "0.26.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61"
|
||||||
|
dependencies = [
|
||||||
|
"rustls",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokio-util"
|
name = "tokio-util"
|
||||||
version = "0.7.18"
|
version = "0.7.18"
|
||||||
|
|
@ -1361,6 +1493,12 @@ version = "1.0.22"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "untrusted"
|
||||||
|
version = "0.9.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "url"
|
name = "url"
|
||||||
version = "2.5.8"
|
version = "2.5.8"
|
||||||
|
|
@ -1508,13 +1646,22 @@ version = "0.2.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
|
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows-sys"
|
||||||
|
version = "0.52.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
|
||||||
|
dependencies = [
|
||||||
|
"windows-targets 0.52.6",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows-sys"
|
name = "windows-sys"
|
||||||
version = "0.60.2"
|
version = "0.60.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb"
|
checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"windows-targets",
|
"windows-targets 0.53.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -1526,6 +1673,22 @@ dependencies = [
|
||||||
"windows-link",
|
"windows-link",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows-targets"
|
||||||
|
version = "0.52.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
|
||||||
|
dependencies = [
|
||||||
|
"windows_aarch64_gnullvm 0.52.6",
|
||||||
|
"windows_aarch64_msvc 0.52.6",
|
||||||
|
"windows_i686_gnu 0.52.6",
|
||||||
|
"windows_i686_gnullvm 0.52.6",
|
||||||
|
"windows_i686_msvc 0.52.6",
|
||||||
|
"windows_x86_64_gnu 0.52.6",
|
||||||
|
"windows_x86_64_gnullvm 0.52.6",
|
||||||
|
"windows_x86_64_msvc 0.52.6",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows-targets"
|
name = "windows-targets"
|
||||||
version = "0.53.5"
|
version = "0.53.5"
|
||||||
|
|
@ -1533,58 +1696,106 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3"
|
checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"windows-link",
|
"windows-link",
|
||||||
"windows_aarch64_gnullvm",
|
"windows_aarch64_gnullvm 0.53.1",
|
||||||
"windows_aarch64_msvc",
|
"windows_aarch64_msvc 0.53.1",
|
||||||
"windows_i686_gnu",
|
"windows_i686_gnu 0.53.1",
|
||||||
"windows_i686_gnullvm",
|
"windows_i686_gnullvm 0.53.1",
|
||||||
"windows_i686_msvc",
|
"windows_i686_msvc 0.53.1",
|
||||||
"windows_x86_64_gnu",
|
"windows_x86_64_gnu 0.53.1",
|
||||||
"windows_x86_64_gnullvm",
|
"windows_x86_64_gnullvm 0.53.1",
|
||||||
"windows_x86_64_msvc",
|
"windows_x86_64_msvc 0.53.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_aarch64_gnullvm"
|
||||||
|
version = "0.52.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_aarch64_gnullvm"
|
name = "windows_aarch64_gnullvm"
|
||||||
version = "0.53.1"
|
version = "0.53.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53"
|
checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_aarch64_msvc"
|
||||||
|
version = "0.52.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_aarch64_msvc"
|
name = "windows_aarch64_msvc"
|
||||||
version = "0.53.1"
|
version = "0.53.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006"
|
checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_i686_gnu"
|
||||||
|
version = "0.52.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_i686_gnu"
|
name = "windows_i686_gnu"
|
||||||
version = "0.53.1"
|
version = "0.53.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3"
|
checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_i686_gnullvm"
|
||||||
|
version = "0.52.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_i686_gnullvm"
|
name = "windows_i686_gnullvm"
|
||||||
version = "0.53.1"
|
version = "0.53.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c"
|
checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_i686_msvc"
|
||||||
|
version = "0.52.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_i686_msvc"
|
name = "windows_i686_msvc"
|
||||||
version = "0.53.1"
|
version = "0.53.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2"
|
checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_x86_64_gnu"
|
||||||
|
version = "0.52.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_x86_64_gnu"
|
name = "windows_x86_64_gnu"
|
||||||
version = "0.53.1"
|
version = "0.53.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499"
|
checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_x86_64_gnullvm"
|
||||||
|
version = "0.52.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_x86_64_gnullvm"
|
name = "windows_x86_64_gnullvm"
|
||||||
version = "0.53.1"
|
version = "0.53.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1"
|
checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_x86_64_msvc"
|
||||||
|
version = "0.52.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_x86_64_msvc"
|
name = "windows_x86_64_msvc"
|
||||||
version = "0.53.1"
|
version = "0.53.1"
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ Rusty, Efficient, and Agentic LLM Client Library
|
||||||
- Tool System: Define tools as async functions. The Worker automatically parses LLM tool calls, executes them in parallel, and feeds results back.
|
- Tool System: Define tools as async functions. The Worker automatically parses LLM tool calls, executes them in parallel, and feeds results back.
|
||||||
- Hook System: Intercept execution flow with `before_tool_call`, `after_tool_call`, and `on_turn_end` hooks for validation, logging, or self-correction.
|
- Hook System: Intercept execution flow with `before_tool_call`, `after_tool_call`, and `on_turn_end` hooks for validation, logging, or self-correction.
|
||||||
- Event-Driven Streaming: Subscribe to real-time events (text deltas, tool calls, usage) for responsive UIs.
|
- Event-Driven Streaming: Subscribe to real-time events (text deltas, tool calls, usage) for responsive UIs.
|
||||||
- Cache-Aware State Management: Type-state pattern (`Mutable` → `Locked`) ensures KV cache efficiency by protecting the conversation prefix.
|
- Cache-Aware State Management: Type-state pattern (`Mutable` → `CacheLocked`) ensures KV cache efficiency by protecting the conversation prefix.
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
|
|
@ -33,4 +33,3 @@ let history = worker.run("What is 2+2?").await?;
|
||||||
## License
|
## License
|
||||||
|
|
||||||
MIT
|
MIT
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ LLMを用いたワーカーを作成する小型のSDK・ライブラリ。
|
||||||
|
|
||||||
module構成概念図
|
module構成概念図
|
||||||
|
|
||||||
```
|
```plaintext
|
||||||
worker
|
worker
|
||||||
├── context
|
├── context
|
||||||
├── llm_client
|
├── llm_client
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ RustのType-stateパターンを利用し、Workerの状態によって利用可
|
||||||
* 自由な編集が可能な状態。
|
* 自由な編集が可能な状態。
|
||||||
* システムプロンプトの設定・変更が可能。
|
* システムプロンプトの設定・変更が可能。
|
||||||
* メッセージ履歴の初期構築(ロード、編集)が可能。
|
* メッセージ履歴の初期構築(ロード、編集)が可能。
|
||||||
* **`Locked` (キャッシュ保護状態)**
|
* **`CacheLocked` (キャッシュ保護状態)**
|
||||||
* キャッシュの有効活用を目的とした、前方不変状態。
|
* キャッシュの有効活用を目的とした、前方不変状態。
|
||||||
* **システムプロンプトの変更不可**。
|
* **システムプロンプトの変更不可**。
|
||||||
* **既存メッセージ履歴の変更不可**(追記のみ許可)。
|
* **既存メッセージ履歴の変更不可**(追記のみ許可)。
|
||||||
|
|
@ -47,7 +47,7 @@ worker.history_mut().push(initial_message);
|
||||||
|
|
||||||
// 3. ロックしてLocked状態へ遷移
|
// 3. ロックしてLocked状態へ遷移
|
||||||
// これにより、ここまでのコンテキストが "Fixed Prefix" として扱われる
|
// これにより、ここまでのコンテキストが "Fixed Prefix" として扱われる
|
||||||
let mut locked_worker: Worker<Locked> = worker.lock();
|
let mut locked_worker: Worker<CacheLocked> = worker.lock();
|
||||||
|
|
||||||
// 4. 利用 (Locked状態)
|
// 4. 利用 (Locked状態)
|
||||||
// 実行は可能。新しいメッセージは履歴の末尾に追記される。
|
// 実行は可能。新しいメッセージは履歴の末尾に追記される。
|
||||||
|
|
@ -65,4 +65,4 @@ locked_worker.run(new_user_input).await?;
|
||||||
|
|
||||||
* **状態パラメータの導入**: `Worker<S: WorkerState>` の導入。
|
* **状態パラメータの導入**: `Worker<S: WorkerState>` の導入。
|
||||||
* **コンテキスト所有権の委譲**: `run` メソッドの引数でコンテキストを受け取るのではなく、`Worker` 内部に `history: Vec<Message>` を保持し管理する形へ移行する。
|
* **コンテキスト所有権の委譲**: `run` メソッドの引数でコンテキストを受け取るのではなく、`Worker` 内部に `history: Vec<Message>` を保持し管理する形へ移行する。
|
||||||
* **APIの分離**: `Mutable` 特有のメソッド(setter等)と、`Locked` でも使えるメソッド(実行、参照等)をトレイト境界で分離する。
|
* **APIの分離**: `Mutable` 特有のメソッド(setter等)と、`CacheLocked` でも使えるメソッド(実行、参照等)をトレイト境界で分離する。
|
||||||
|
|
|
||||||
70
docs/spec/cancellation.md
Normal file
70
docs/spec/cancellation.md
Normal file
|
|
@ -0,0 +1,70 @@
|
||||||
|
# 非同期キャンセル設計
|
||||||
|
|
||||||
|
Workerの非同期キャンセル機構についての設計ドキュメント。
|
||||||
|
|
||||||
|
## 概要
|
||||||
|
|
||||||
|
`tokio::sync::mpsc`の通知チャネルを用いて、別タスクからWorkerの実行を安全にキャンセルできる。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
let worker = Arc::new(Mutex::new(Worker::new(client)));
|
||||||
|
|
||||||
|
// 実行タスク
|
||||||
|
let w = worker.clone();
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
w.lock().await.run("prompt").await
|
||||||
|
});
|
||||||
|
|
||||||
|
// キャンセル
|
||||||
|
worker.lock().await.cancel();
|
||||||
|
```
|
||||||
|
|
||||||
|
## キャンセル時の処理フロー
|
||||||
|
|
||||||
|
```
|
||||||
|
キャンセル検知
|
||||||
|
↓
|
||||||
|
timeline.abort_current_block() // 進行中ブロックの終端処理
|
||||||
|
↓
|
||||||
|
run_on_abort_hooks("Cancelled") // on_abort フック呼び出し
|
||||||
|
↓
|
||||||
|
Err(WorkerError::Cancelled) // エラー返却
|
||||||
|
```
|
||||||
|
|
||||||
|
## API
|
||||||
|
|
||||||
|
| メソッド | 説明 |
|
||||||
|
| ----------------- | ------------------------------ |
|
||||||
|
| `cancel()` | キャンセルをトリガー |
|
||||||
|
| `cancel_sender()` | キャンセル通知用のSenderを取得 |
|
||||||
|
|
||||||
|
## on_abort フック
|
||||||
|
|
||||||
|
`Hook::on_abort(&self, reason: &str)`がキャンセル時に呼ばれる。
|
||||||
|
クリーンアップ処理やログ記録に使用できる。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
async fn on_abort(&self, reason: &str) -> Result<(), HookError> {
|
||||||
|
log::info!("Aborted: {}", reason);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
呼び出しタイミング:
|
||||||
|
|
||||||
|
- `WorkerError::Cancelled` — reason: `"Cancelled"`
|
||||||
|
- `ControlFlow::Abort(reason)` — reason: フックが指定した理由
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 既知の問題
|
||||||
|
|
||||||
|
### on_abort の発火基準
|
||||||
|
|
||||||
|
`on_abort` は **interrupt(中断)** された場合に必ず発火する。
|
||||||
|
|
||||||
|
interrupt の例:
|
||||||
|
|
||||||
|
- `WorkerError::Cancelled`(キャンセル)
|
||||||
|
- `WorkerError::Aborted`(フックによるAbort)
|
||||||
|
- ストリーム/ツール/クライアント/Hook の各種エラーで処理が中断された場合
|
||||||
|
|
@ -3,7 +3,8 @@
|
||||||
## 概要
|
## 概要
|
||||||
|
|
||||||
HookはWorker層でのターン制御に介入するためのメカニズムです。
|
HookはWorker層でのターン制御に介入するためのメカニズムです。
|
||||||
Claude CodeのHooks機能に着想を得ており、メッセージ送信・ツール実行・ターン終了の各ポイントで処理を差し込むことができます。
|
|
||||||
|
メッセージ送信・ツール実行・ターン終了等の各ポイントで処理を差し込むことができます。
|
||||||
|
|
||||||
## コンセプト
|
## コンセプト
|
||||||
|
|
||||||
|
|
@ -11,120 +12,184 @@ Claude CodeのHooks機能に着想を得ており、メッセージ送信・ツ
|
||||||
- **Contextへのアクセス**: メッセージ履歴を読み書き可能
|
- **Contextへのアクセス**: メッセージ履歴を読み書き可能
|
||||||
- **非破壊的チェーン**: 複数のHookを登録順に実行、後続Hookへの影響を制御
|
- **非破壊的チェーン**: 複数のHookを登録順に実行、後続Hookへの影響を制御
|
||||||
|
|
||||||
|
## Hook一覧
|
||||||
|
|
||||||
|
| Hook | タイミング | 主な用途 | 戻り値 |
|
||||||
|
| ------------------ | -------------------------- | -------------------------- | ---------------------- |
|
||||||
|
| `on_prompt_submit` | `run()` 呼び出し時 | ユーザーメッセージの前処理 | `OnPromptSubmitResult` |
|
||||||
|
| `pre_llm_request` | 各ターンのLLM送信前 | コンテキスト改変/検証 | `PreLlmRequestResult` |
|
||||||
|
| `pre_tool_call` | ツール実行前 | 実行許可/引数改変 | `PreToolCallResult` |
|
||||||
|
| `post_tool_call` | ツール実行後 | 結果加工/マスキング | `PostToolCallResult` |
|
||||||
|
| `on_turn_end` | ツールなしでターン終了直前 | 検証/リトライ指示 | `OnTurnEndResult` |
|
||||||
|
| `on_abort` | 中断時 | クリーンアップ/通知 | `()` |
|
||||||
|
|
||||||
## Hook Trait
|
## Hook Trait
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait WorkerHook: Send + Sync {
|
pub trait Hook<E: HookEventKind>: Send + Sync {
|
||||||
/// メッセージ送信前
|
async fn call(&self, input: &mut E::Input) -> Result<E::Output, HookError>;
|
||||||
/// リクエストに含まれるメッセージリストを改変できる
|
|
||||||
async fn on_message_send(
|
|
||||||
&self,
|
|
||||||
context: &mut Vec<Message>,
|
|
||||||
) -> Result<ControlFlow, HookError> {
|
|
||||||
Ok(ControlFlow::Continue)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// ツール実行前
|
|
||||||
/// 実行をキャンセルしたり、引数を書き換えることができる
|
|
||||||
async fn before_tool_call(
|
|
||||||
&self,
|
|
||||||
tool_call: &mut ToolCall,
|
|
||||||
) -> Result<ControlFlow, HookError> {
|
|
||||||
Ok(ControlFlow::Continue)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// ツール実行後
|
|
||||||
/// 結果を書き換えたり、隠蔽したりできる
|
|
||||||
async fn after_tool_call(
|
|
||||||
&self,
|
|
||||||
tool_result: &mut ToolResult,
|
|
||||||
) -> Result<ControlFlow, HookError> {
|
|
||||||
Ok(ControlFlow::Continue)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// ターン終了時
|
|
||||||
/// 生成されたメッセージを検査し、必要ならリトライを指示できる
|
|
||||||
async fn on_turn_end(
|
|
||||||
&self,
|
|
||||||
messages: &[Message],
|
|
||||||
) -> Result<TurnResult, HookError> {
|
|
||||||
Ok(TurnResult::Finish)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## 制御フロー型
|
## 制御フロー型
|
||||||
|
|
||||||
### ControlFlow
|
### HookEventKind / Result
|
||||||
|
|
||||||
Hook処理の継続/中断を制御する列挙型。
|
Hookイベントごとに入力/出力型を分離し、意味のない制御フローを排除する。
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
pub enum ControlFlow {
|
pub trait HookEventKind {
|
||||||
/// 処理を続行(後続Hookも実行)
|
type Input;
|
||||||
|
type Output;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct OnPromptSubmit;
|
||||||
|
pub struct PreLlmRequest;
|
||||||
|
pub struct PreToolCall;
|
||||||
|
pub struct PostToolCall;
|
||||||
|
pub struct OnTurnEnd;
|
||||||
|
pub struct OnAbort;
|
||||||
|
|
||||||
|
pub enum OnPromptSubmitResult {
|
||||||
|
Continue,
|
||||||
|
Cancel(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum PreLlmRequestResult {
|
||||||
|
Continue,
|
||||||
|
Cancel(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum PreToolCallResult {
|
||||||
Continue,
|
Continue,
|
||||||
/// 現在の処理をスキップ(ツール実行をスキップ等)
|
|
||||||
Skip,
|
Skip,
|
||||||
/// 処理全体を中断(エラーとして扱う)
|
|
||||||
Abort(String),
|
Abort(String),
|
||||||
|
Pause,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum PostToolCallResult {
|
||||||
|
Continue,
|
||||||
|
Abort(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum OnTurnEndResult {
|
||||||
|
Finish,
|
||||||
|
ContinueWithMessages(Vec<Message>),
|
||||||
|
Paused,
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### TurnResult
|
### Tool Call Context
|
||||||
|
|
||||||
ターン終了時の判定結果を表す列挙型。
|
`pre_tool_call` / `post_tool_call` は、ツール実行の文脈を含む入力を受け取る。
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
pub enum TurnResult {
|
pub struct ToolCallContext {
|
||||||
/// ターンを正常終了
|
pub call: ToolCall,
|
||||||
Finish,
|
pub meta: ToolMeta, // 不変メタデータ
|
||||||
/// メッセージを追加してターン継続(自己修正など)
|
pub tool: Arc<dyn Tool>, // 状態アクセス用
|
||||||
ContinueWithMessages(Vec<Message>),
|
}
|
||||||
|
|
||||||
|
pub struct PostToolCallContext {
|
||||||
|
pub call: ToolCall,
|
||||||
|
pub result: ToolResult,
|
||||||
|
pub meta: ToolMeta,
|
||||||
|
pub tool: Arc<dyn Tool>,
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## 呼び出しタイミング
|
## 呼び出しタイミング
|
||||||
|
|
||||||
```
|
```
|
||||||
Worker::run() ループ
|
Worker::run(user_input)
|
||||||
│
|
│
|
||||||
├─▶ on_message_send ──────────────────────────────┐
|
├─▶ on_prompt_submit ───────────────────────────┐
|
||||||
|
│ ユーザーメッセージの前処理・検証 │
|
||||||
|
│ (最初の1回のみ) │
|
||||||
|
│ │
|
||||||
|
└─▶ loop {
|
||||||
|
│
|
||||||
|
├─▶ pre_llm_request ──────────────────────│
|
||||||
│ コンテキストの改変、バリデーション、 │
|
│ コンテキストの改変、バリデーション、 │
|
||||||
│ システムプロンプト注入などが可能 │
|
│ システムプロンプト注入などが可能 │
|
||||||
|
│ (毎ターン実行) │
|
||||||
│ │
|
│ │
|
||||||
├─▶ LLMリクエスト送信 & ストリーム処理 │
|
├─▶ LLMリクエスト送信 & ストリーム処理 │
|
||||||
│ │
|
│ │
|
||||||
├─▶ ツール呼び出しがある場合: │
|
├─▶ ツール呼び出しがある場合: │
|
||||||
│ │ │
|
│ │ │
|
||||||
│ ├─▶ before_tool_call (各ツールごと・逐次) │
|
│ ├─▶ pre_tool_call (各ツールごと・逐次) │
|
||||||
│ │ 実行可否の判定、引数の改変 │
|
│ │ 実行可否の判定、引数の改変 │
|
||||||
│ │ │
|
│ │ │
|
||||||
│ ├─▶ ツール並列実行 (join_all) │
|
│ ├─▶ ツール並列実行 (join_all) │
|
||||||
│ │ │
|
│ │ │
|
||||||
│ └─▶ after_tool_call (各結果ごと・逐次) │
|
│ └─▶ post_tool_call (各結果ごと・逐次) │
|
||||||
│ 結果の確認、加工、ログ出力 │
|
│ 結果の確認、加工、ログ出力 │
|
||||||
│ │
|
│ │
|
||||||
├─▶ ツール結果をコンテキストに追加 → ループ先頭へ │
|
├─▶ ツール結果をコンテキストに追加 │
|
||||||
|
│ → ループ先頭へ │
|
||||||
│ │
|
│ │
|
||||||
└─▶ ツールなしの場合: │
|
└─▶ ツールなしの場合: │
|
||||||
│ │
|
│ │
|
||||||
└─▶ on_turn_end ─────────────────────────────┘
|
└─▶ on_turn_end ───────────────────┘
|
||||||
最終応答のチェック(Lint/Fmt等)
|
最終応答のチェック(Lint/Fmt等)
|
||||||
エラーがあればContinueWithMessagesでリトライ
|
エラーがあればContinueWithMessagesでリトライ
|
||||||
|
}
|
||||||
|
|
||||||
|
※ 中断時は on_abort が呼ばれる
|
||||||
```
|
```
|
||||||
|
|
||||||
## 各Hookの詳細
|
## 各Hookの詳細
|
||||||
|
|
||||||
### on_message_send
|
### on_prompt_submit
|
||||||
|
|
||||||
**呼び出しタイミング**: LLMへリクエスト送信前(ターンループの冒頭)
|
**呼び出しタイミング**: `run()`
|
||||||
|
でユーザーメッセージを受け取った直後(最初の1回のみ)
|
||||||
|
|
||||||
**用途**:
|
**用途**:
|
||||||
|
|
||||||
|
- ユーザー入力のバリデーション
|
||||||
|
- 入力のサニタイズ・フィルタリング
|
||||||
|
- ログ出力
|
||||||
|
- `OnPromptSubmitResult::Cancel` による実行キャンセル
|
||||||
|
|
||||||
|
**入力**: `&mut Message` - ユーザーメッセージ(改変可能)
|
||||||
|
|
||||||
|
**例**: 入力のバリデーション
|
||||||
|
|
||||||
|
```rust
|
||||||
|
struct InputValidator;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Hook<OnPromptSubmit> for InputValidator {
|
||||||
|
async fn call(
|
||||||
|
&self,
|
||||||
|
message: &mut Message,
|
||||||
|
) -> Result<OnPromptSubmitResult, HookError> {
|
||||||
|
if let MessageContent::Text(text) = &message.content {
|
||||||
|
if text.trim().is_empty() {
|
||||||
|
return Ok(OnPromptSubmitResult::Cancel("Empty input".to_string()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(OnPromptSubmitResult::Continue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### pre_llm_request
|
||||||
|
|
||||||
|
**呼び出しタイミング**: 各ターンのLLMリクエスト送信前(ループの毎回)
|
||||||
|
|
||||||
|
**用途**:
|
||||||
|
|
||||||
- コンテキストへのシステムメッセージ注入
|
- コンテキストへのシステムメッセージ注入
|
||||||
- メッセージのバリデーション
|
- メッセージのバリデーション
|
||||||
- 機密情報のフィルタリング
|
- 機密情報のフィルタリング
|
||||||
- リクエスト内容のログ出力
|
- リクエスト内容のログ出力
|
||||||
|
- `PreLlmRequestResult::Cancel` による送信キャンセル
|
||||||
|
|
||||||
|
**入力**: `&mut Vec<Message>` - コンテキスト全体(改変可能)
|
||||||
|
|
||||||
**例**: メッセージにタイムスタンプを追加
|
**例**: メッセージにタイムスタンプを追加
|
||||||
|
|
||||||
|
|
@ -132,27 +197,33 @@ Worker::run() ループ
|
||||||
struct TimestampHook;
|
struct TimestampHook;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl WorkerHook for TimestampHook {
|
impl Hook<PreLlmRequest> for TimestampHook {
|
||||||
async fn on_message_send(
|
async fn call(
|
||||||
&self,
|
&self,
|
||||||
context: &mut Vec<Message>,
|
context: &mut Vec<Message>,
|
||||||
) -> Result<ControlFlow, HookError> {
|
) -> Result<PreLlmRequestResult, HookError> {
|
||||||
let timestamp = chrono::Local::now().to_rfc3339();
|
let timestamp = chrono::Local::now().to_rfc3339();
|
||||||
context.insert(0, Message::user(format!("[{}]", timestamp)));
|
context.insert(0, Message::user(format!("[{}]", timestamp)));
|
||||||
Ok(ControlFlow::Continue)
|
Ok(PreLlmRequestResult::Continue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### before_tool_call
|
### pre_tool_call
|
||||||
|
|
||||||
**呼び出しタイミング**: 各ツール実行前(並列実行フェーズの前)
|
**呼び出しタイミング**: 各ツール実行前(並列実行フェーズの前)
|
||||||
|
|
||||||
**用途**:
|
**用途**:
|
||||||
|
|
||||||
- 危険なツールのブロック
|
- 危険なツールのブロック
|
||||||
- 引数のサニタイズ
|
- 引数のサニタイズ
|
||||||
- 確認プロンプトの表示(UIとの連携)
|
- 確認プロンプトの表示(UIとの連携)
|
||||||
- 実行ログの記録
|
- 実行ログの記録
|
||||||
|
- `PreToolCallResult::Pause` による一時停止
|
||||||
|
|
||||||
|
**入力**:
|
||||||
|
|
||||||
|
- `ToolCallContext`(`ToolCall` + `ToolMeta` + `Arc<dyn Tool>`)
|
||||||
|
|
||||||
**例**: 特定ツールをブロック
|
**例**: 特定ツールをブロック
|
||||||
|
|
||||||
|
|
@ -162,46 +233,52 @@ struct ToolBlocker {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl WorkerHook for ToolBlocker {
|
impl Hook<PreToolCall> for ToolBlocker {
|
||||||
async fn before_tool_call(
|
async fn call(
|
||||||
&self,
|
&self,
|
||||||
tool_call: &mut ToolCall,
|
ctx: &mut ToolCallContext,
|
||||||
) -> Result<ControlFlow, HookError> {
|
) -> Result<PreToolCallResult, HookError> {
|
||||||
if self.blocked_tools.contains(&tool_call.name) {
|
if self.blocked_tools.contains(&ctx.call.name) {
|
||||||
println!("Blocked tool: {}", tool_call.name);
|
println!("Blocked tool: {}", ctx.call.name);
|
||||||
Ok(ControlFlow::Skip)
|
Ok(PreToolCallResult::Skip)
|
||||||
} else {
|
} else {
|
||||||
Ok(ControlFlow::Continue)
|
Ok(PreToolCallResult::Continue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### after_tool_call
|
### post_tool_call
|
||||||
|
|
||||||
**呼び出しタイミング**: 各ツール実行後(並列実行フェーズの後)
|
**呼び出しタイミング**: 各ツール実行後(並列実行フェーズの後)
|
||||||
|
|
||||||
**用途**:
|
**用途**:
|
||||||
|
|
||||||
- 結果の加工・フォーマット
|
- 結果の加工・フォーマット
|
||||||
- 機密情報のマスキング
|
- 機密情報のマスキング
|
||||||
- 結果のキャッシュ
|
- 結果のキャッシュ
|
||||||
- 実行結果のログ出力
|
- 実行結果のログ出力
|
||||||
|
|
||||||
|
**入力**:
|
||||||
|
|
||||||
|
- `PostToolCallContext`(`ToolCall` + `ToolResult` + `ToolMeta` +
|
||||||
|
`Arc<dyn Tool>`)
|
||||||
|
|
||||||
**例**: 結果にプレフィックスを追加
|
**例**: 結果にプレフィックスを追加
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
struct ResultFormatter;
|
struct ResultFormatter;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl WorkerHook for ResultFormatter {
|
impl Hook<PostToolCall> for ResultFormatter {
|
||||||
async fn after_tool_call(
|
async fn call(
|
||||||
&self,
|
&self,
|
||||||
tool_result: &mut ToolResult,
|
ctx: &mut PostToolCallContext,
|
||||||
) -> Result<ControlFlow, HookError> {
|
) -> Result<PostToolCallResult, HookError> {
|
||||||
if !tool_result.is_error {
|
if !ctx.result.is_error {
|
||||||
tool_result.content = format!("[OK] {}", tool_result.content);
|
ctx.result.content = format!("[OK] {}", ctx.result.content);
|
||||||
}
|
}
|
||||||
Ok(ControlFlow::Continue)
|
Ok(PostToolCallResult::Continue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
@ -211,10 +288,22 @@ impl WorkerHook for ResultFormatter {
|
||||||
**呼び出しタイミング**: ツール呼び出しなしでターンが終了する直前
|
**呼び出しタイミング**: ツール呼び出しなしでターンが終了する直前
|
||||||
|
|
||||||
**用途**:
|
**用途**:
|
||||||
|
|
||||||
- 生成されたコードのLint/Fmt
|
- 生成されたコードのLint/Fmt
|
||||||
- 出力形式のバリデーション
|
- 出力形式のバリデーション
|
||||||
- 自己修正のためのリトライ指示
|
- 自己修正のためのリトライ指示
|
||||||
- 最終結果のログ出力
|
- 最終結果のログ出力
|
||||||
|
- `OnTurnEndResult::Paused` による一時停止
|
||||||
|
|
||||||
|
### on_abort
|
||||||
|
|
||||||
|
**呼び出しタイミング**: キャンセル/エラー/AbortなどでWorkerが中断された時
|
||||||
|
|
||||||
|
**用途**:
|
||||||
|
|
||||||
|
- クリーンアップ処理
|
||||||
|
- 中断理由のログ出力
|
||||||
|
- 外部システムへの通知
|
||||||
|
|
||||||
**例**: JSON形式のバリデーション
|
**例**: JSON形式のバリデーション
|
||||||
|
|
||||||
|
|
@ -222,11 +311,11 @@ impl WorkerHook for ResultFormatter {
|
||||||
struct JsonValidator;
|
struct JsonValidator;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl WorkerHook for JsonValidator {
|
impl Hook<OnTurnEnd> for JsonValidator {
|
||||||
async fn on_turn_end(
|
async fn call(
|
||||||
&self,
|
&self,
|
||||||
messages: &[Message],
|
messages: &mut Vec<Message>,
|
||||||
) -> Result<TurnResult, HookError> {
|
) -> Result<OnTurnEndResult, HookError> {
|
||||||
// 最後のアシスタントメッセージを取得
|
// 最後のアシスタントメッセージを取得
|
||||||
let last = messages.iter().rev()
|
let last = messages.iter().rev()
|
||||||
.find(|m| m.role == Role::Assistant);
|
.find(|m| m.role == Role::Assistant);
|
||||||
|
|
@ -236,25 +325,25 @@ impl WorkerHook for JsonValidator {
|
||||||
// JSONとしてパースを試みる
|
// JSONとしてパースを試みる
|
||||||
if serde_json::from_str::<serde_json::Value>(text).is_err() {
|
if serde_json::from_str::<serde_json::Value>(text).is_err() {
|
||||||
// 失敗したらリトライ指示
|
// 失敗したらリトライ指示
|
||||||
return Ok(TurnResult::ContinueWithMessages(vec![
|
return Ok(OnTurnEndResult::ContinueWithMessages(vec![
|
||||||
Message::user("Invalid JSON. Please fix and try again.")
|
Message::user("Invalid JSON. Please fix and try again.")
|
||||||
]));
|
]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(TurnResult::Finish)
|
Ok(OnTurnEndResult::Finish)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## 複数Hookの実行順序
|
## 複数Hookの実行順序
|
||||||
|
|
||||||
Hookは**登録順**に実行されます。
|
Hookは**イベントごとに登録順**に実行されます。
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
worker.add_hook(HookA); // 1番目に実行
|
worker.add_pre_tool_call_hook(HookA); // 1番目に実行
|
||||||
worker.add_hook(HookB); // 2番目に実行
|
worker.add_pre_tool_call_hook(HookB); // 2番目に実行
|
||||||
worker.add_hook(HookC); // 3番目に実行
|
worker.add_pre_tool_call_hook(HookC); // 3番目に実行
|
||||||
```
|
```
|
||||||
|
|
||||||
### 制御フローの伝播
|
### 制御フローの伝播
|
||||||
|
|
@ -262,6 +351,7 @@ worker.add_hook(HookC); // 3番目に実行
|
||||||
- `Continue`: 後続Hookも実行
|
- `Continue`: 後続Hookも実行
|
||||||
- `Skip`: 現在の処理をスキップし、後続Hookは実行しない
|
- `Skip`: 現在の処理をスキップし、後続Hookは実行しない
|
||||||
- `Abort`: 即座にエラーを返し、処理全体を中断
|
- `Abort`: 即座にエラーを返し、処理全体を中断
|
||||||
|
- `Pause`: Workerを一時停止(再開は`resume`)
|
||||||
|
|
||||||
```
|
```
|
||||||
Hook A: Continue → Hook B: Skip → (Hook Cは実行されない)
|
Hook A: Continue → Hook B: Skip → (Hook Cは実行されない)
|
||||||
|
|
@ -271,52 +361,39 @@ Hook A: Continue → Hook B: Skip → (Hook Cは実行されない)
|
||||||
Hook A: Continue → Hook B: Abort("reason")
|
Hook A: Continue → Hook B: Abort("reason")
|
||||||
↓
|
↓
|
||||||
WorkerError::Aborted
|
WorkerError::Aborted
|
||||||
|
|
||||||
|
Hook A: Continue → Hook B: Pause
|
||||||
|
↓
|
||||||
|
WorkerResult::Paused
|
||||||
```
|
```
|
||||||
|
|
||||||
## 設計上のポイント
|
## 設計上のポイント
|
||||||
|
|
||||||
### 1. デフォルト実装
|
### 1. イベントごとの実装
|
||||||
|
|
||||||
全メソッドにデフォルト実装があるため、必要なメソッドだけオーバーライドすれば良い。
|
必要なイベントのみ `Hook<Event>` を実装する。
|
||||||
|
|
||||||
```rust
|
|
||||||
struct SimpleLogger;
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl WorkerHook for SimpleLogger {
|
|
||||||
// on_message_send だけ実装
|
|
||||||
async fn on_message_send(
|
|
||||||
&self,
|
|
||||||
context: &mut Vec<Message>,
|
|
||||||
) -> Result<ControlFlow, HookError> {
|
|
||||||
println!("Sending {} messages", context.len());
|
|
||||||
Ok(ControlFlow::Continue)
|
|
||||||
}
|
|
||||||
// 他のメソッドはデフォルト(Continue/Finish)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. 可変参照による改変
|
### 2. 可変参照による改変
|
||||||
|
|
||||||
`&mut`で引数を受け取るため、直接改変が可能。
|
`&mut`で引数を受け取るため、直接改変が可能。
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
async fn before_tool_call(&self, tool_call: &mut ToolCall) -> ... {
|
async fn call(&self, ctx: &mut ToolCallContext) -> ... {
|
||||||
// 引数を直接書き換え
|
// 引数を直接書き換え
|
||||||
tool_call.input["sanitized"] = json!(true);
|
ctx.call.input["sanitized"] = json!(true);
|
||||||
Ok(ControlFlow::Continue)
|
Ok(PreToolCallResult::Continue)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3. 並列実行との統合
|
### 3. 並列実行との統合
|
||||||
|
|
||||||
- `before_tool_call`: 並列実行**前**に逐次実行(許可判定のため)
|
- `pre_tool_call`: 並列実行**前**に逐次実行(許可判定のため)
|
||||||
- ツール実行: `join_all`で**並列**実行
|
- ツール実行: `join_all`で**並列**実行
|
||||||
- `after_tool_call`: 並列実行**後**に逐次実行(結果加工のため)
|
- `post_tool_call`: 並列実行**後**に逐次実行(結果加工のため)
|
||||||
|
|
||||||
### 4. Send + Sync 要件
|
### 4. Send + Sync 要件
|
||||||
|
|
||||||
`WorkerHook`は`Send + Sync`を要求するため、スレッドセーフな実装が必要。
|
`Hook`は`Send + Sync`を要求するため、スレッドセーフな実装が必要。
|
||||||
状態を持つ場合は`Arc<Mutex<T>>`や`AtomicUsize`などを使用する。
|
状態を持つ場合は`Arc<Mutex<T>>`や`AtomicUsize`などを使用する。
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
|
|
@ -325,10 +402,10 @@ struct CountingHook {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl WorkerHook for CountingHook {
|
impl Hook<PreToolCall> for CountingHook {
|
||||||
async fn before_tool_call(&self, _: &mut ToolCall) -> Result<ControlFlow, HookError> {
|
async fn call(&self, _: &mut ToolCallContext) -> Result<PreToolCallResult, HookError> {
|
||||||
self.count.fetch_add(1, Ordering::SeqCst);
|
self.count.fetch_add(1, Ordering::SeqCst);
|
||||||
Ok(ControlFlow::Continue)
|
Ok(PreToolCallResult::Continue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
@ -336,13 +413,13 @@ impl WorkerHook for CountingHook {
|
||||||
## 典型的なユースケース
|
## 典型的なユースケース
|
||||||
|
|
||||||
| ユースケース | 使用Hook | 処理内容 |
|
| ユースケース | 使用Hook | 処理内容 |
|
||||||
|-------------|----------|----------|
|
| ------------------ | -------------------- | -------------------------- |
|
||||||
| ツール許可制御 | `before_tool_call` | 危険なツールをSkip |
|
| ツール許可制御 | `pre_tool_call` | 危険なツールをSkip |
|
||||||
| 実行ログ | `before/after_tool_call` | 呼び出しと結果を記録 |
|
| 実行ログ | `pre/post_tool_call` | 呼び出しと結果を記録 |
|
||||||
| 出力バリデーション | `on_turn_end` | 形式チェック、リトライ指示 |
|
| 出力バリデーション | `on_turn_end` | 形式チェック、リトライ指示 |
|
||||||
| コンテキスト注入 | `on_message_send` | システムメッセージ追加 |
|
| コンテキスト注入 | `on_message_send` | システムメッセージ追加 |
|
||||||
| 結果のサニタイズ | `after_tool_call` | 機密情報のマスキング |
|
| 結果のサニタイズ | `post_tool_call` | 機密情報のマスキング |
|
||||||
| レート制限 | `before_tool_call` | 呼び出し頻度の制御 |
|
| レート制限 | `pre_tool_call` | 呼び出し頻度の制御 |
|
||||||
|
|
||||||
## TODO
|
## TODO
|
||||||
|
|
||||||
|
|
@ -350,11 +427,14 @@ impl WorkerHook for CountingHook {
|
||||||
|
|
||||||
現在のHooks実装は基本的なユースケースをカバーしているが、以下の点について将来的に厳密な仕様を定義する必要がある:
|
現在のHooks実装は基本的なユースケースをカバーしているが、以下の点について将来的に厳密な仕様を定義する必要がある:
|
||||||
|
|
||||||
- **エラーハンドリングの明確化**: `HookError`発生時のリカバリー戦略、部分的な失敗の扱い
|
- **エラーハンドリングの明確化**:
|
||||||
|
`HookError`発生時のリカバリー戦略、部分的な失敗の扱い
|
||||||
- **Hook間の依存関係**: 複数Hookの実行順序が結果に影響する場合のセマンティクス
|
- **Hook間の依存関係**: 複数Hookの実行順序が結果に影響する場合のセマンティクス
|
||||||
- **非同期キャンセル**: Hook実行中のキャンセル(タイムアウト等)の振る舞い
|
- **非同期キャンセル**: Hook実行中のキャンセル(タイムアウト等)の振る舞い
|
||||||
- **状態の一貫性**: `on_message_send`で改変されたコンテキストが後続処理で期待通りに反映される保証
|
- **状態の一貫性**:
|
||||||
- **リトライ制限**: `on_turn_end`での`ContinueWithMessages`による無限ループ防止策
|
`on_message_send`で改変されたコンテキストが後続処理で期待通りに反映される保証
|
||||||
|
- **リトライ制限**:
|
||||||
|
`on_turn_end`での`ContinueWithMessages`による無限ループ防止策
|
||||||
- **Hook優先度**: 登録順以外の優先度指定メカニズムの必要性
|
- **Hook優先度**: 登録順以外の優先度指定メカニズムの必要性
|
||||||
- **条件付きHook**: 特定条件でのみ有効化されるHookパターン
|
- **条件付きHook**: 特定条件でのみ有効化されるHookパターン
|
||||||
- **テスト容易性**: Hookのモック/スタブ作成のためのユーティリティ
|
- **テスト容易性**: Hookのモック/スタブ作成のためのユーティリティ
|
||||||
|
|
|
||||||
191
docs/spec/tools_design.md
Normal file
191
docs/spec/tools_design.md
Normal file
|
|
@ -0,0 +1,191 @@
|
||||||
|
# Tool 設計
|
||||||
|
|
||||||
|
## 概要
|
||||||
|
|
||||||
|
`llm-worker`のツールシステムは、LLMが外部リソースにアクセスしたり計算を実行するための仕組みを提供する。
|
||||||
|
メタ情報の不変性とセッションスコープの状態管理を両立させる設計となっている。
|
||||||
|
|
||||||
|
## 主要な型
|
||||||
|
|
||||||
|
```
|
||||||
|
type ToolDefinition
|
||||||
|
Fn() -> (ToolMeta, Arc<dyn Tool>)
|
||||||
|
|
||||||
|
worker.register_tool() で呼び出し
|
||||||
|
|
||||||
|
▼
|
||||||
|
|
||||||
|
- struct ToolMeta (name, desc, schema)
|
||||||
|
不変・登録時固定
|
||||||
|
- trait Tool (executer)
|
||||||
|
登録時生成・セッション中再利用
|
||||||
|
```
|
||||||
|
|
||||||
|
### ToolMeta
|
||||||
|
|
||||||
|
ツールのメタ情報を保持する不変構造体。登録時に固定され、Worker内で変更されない。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct ToolMeta {
|
||||||
|
pub name: String,
|
||||||
|
pub description: String,
|
||||||
|
pub input_schema: Value,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**目的:**
|
||||||
|
|
||||||
|
- LLM へのツール定義として送信
|
||||||
|
- Hook からの参照(読み取り専用)
|
||||||
|
- 登録後の不変性を保証
|
||||||
|
|
||||||
|
### Tool trait
|
||||||
|
|
||||||
|
ツールの実行ロジックのみを定義するトレイト。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Tool: Send + Sync {
|
||||||
|
async fn execute(&self, input_json: &str) -> Result<String, ToolError>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**設計方針:**
|
||||||
|
|
||||||
|
- メタ情報(name, description, schema)は含まない
|
||||||
|
- 状態を持つことが可能(セッション中のカウンターなど)
|
||||||
|
- `Send + Sync` で並列実行に対応
|
||||||
|
|
||||||
|
**インスタンスのライフサイクル:**
|
||||||
|
|
||||||
|
1. `register_tool()` 呼び出し時にファクトリが実行され、インスタンスが生成される
|
||||||
|
2. LLM がツールを呼び出すと、既存インスタンスの `execute()` が実行される
|
||||||
|
3. 同じセッション中は同一インスタンスが再利用される
|
||||||
|
|
||||||
|
※ 「最初に呼ばれたとき」の遅延初期化ではなく、**登録時の即時初期化**である。
|
||||||
|
|
||||||
|
### ToolDefinition
|
||||||
|
|
||||||
|
メタ情報とツールインスタンスを生成するファクトリ。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub type ToolDefinition = Arc<dyn Fn() -> (ToolMeta, Arc<dyn Tool>) + Send + Sync>;
|
||||||
|
```
|
||||||
|
|
||||||
|
**なぜファクトリか:**
|
||||||
|
|
||||||
|
- Worker への登録時に一度だけ呼び出される
|
||||||
|
- メタ情報とインスタンスを同時に生成し、整合性を保証
|
||||||
|
- クロージャでコンテキスト(`self.clone()`)をキャプチャ可能
|
||||||
|
|
||||||
|
## Worker でのツール管理
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Worker 内部
|
||||||
|
tools: HashMap<String, (ToolMeta, Arc<dyn Tool>)>
|
||||||
|
|
||||||
|
// 登録 API
|
||||||
|
pub fn register_tool(&mut self, factory: ToolDefinition) -> Result<(), ToolRegistryError>
|
||||||
|
```
|
||||||
|
|
||||||
|
登録時の処理:
|
||||||
|
|
||||||
|
1. ファクトリを呼び出し `(meta, instance)` を取得
|
||||||
|
2. 同名ツールが既に登録されていればエラー
|
||||||
|
3. HashMap に `(meta, instance)` を保存
|
||||||
|
|
||||||
|
## マクロによる自動生成
|
||||||
|
|
||||||
|
`#[tool_registry]` マクロは `{method}_definition()` メソッドを生成する。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
#[tool_registry]
|
||||||
|
impl MyApp {
|
||||||
|
/// 検索を実行する
|
||||||
|
#[tool]
|
||||||
|
async fn search(&self, query: String) -> String {
|
||||||
|
// 実装
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成されるコード:
|
||||||
|
impl MyApp {
|
||||||
|
pub fn search_definition(&self) -> ToolDefinition {
|
||||||
|
let ctx = self.clone();
|
||||||
|
Arc::new(move || {
|
||||||
|
let meta = ToolMeta::new("search")
|
||||||
|
.description("検索を実行する")
|
||||||
|
.input_schema(/* schemars で生成 */);
|
||||||
|
let tool = Arc::new(ToolSearch { ctx: ctx.clone() });
|
||||||
|
(meta, tool)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Hook との連携
|
||||||
|
|
||||||
|
Hook は `ToolCallContext` / `AfterToolCallContext`
|
||||||
|
を通じてメタ情報とインスタンスにアクセスできる。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct ToolCallContext {
|
||||||
|
pub call: ToolCall, // 呼び出し情報(改変可能)
|
||||||
|
pub meta: ToolMeta, // メタ情報(読み取り専用)
|
||||||
|
pub tool: Arc<dyn Tool>, // インスタンス(状態アクセス用)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**用途:**
|
||||||
|
|
||||||
|
- `meta` で名前やスキーマを確認
|
||||||
|
- `tool` でツールの内部状態を読み取り(ダウンキャスト必要)
|
||||||
|
- `call` の引数を改変してツールに渡す
|
||||||
|
|
||||||
|
## 使用例
|
||||||
|
|
||||||
|
### 手動実装
|
||||||
|
|
||||||
|
```rust
|
||||||
|
struct Counter { count: AtomicUsize }
|
||||||
|
|
||||||
|
impl Tool for Counter {
|
||||||
|
async fn execute(&self, _: &str) -> Result<String, ToolError> {
|
||||||
|
let n = self.count.fetch_add(1, Ordering::SeqCst);
|
||||||
|
Ok(format!("count: {}", n))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let def: ToolDefinition = Arc::new(|| {
|
||||||
|
let meta = ToolMeta::new("counter")
|
||||||
|
.description("カウンターを増加")
|
||||||
|
.input_schema(json!({"type": "object"}));
|
||||||
|
(meta, Arc::new(Counter { count: AtomicUsize::new(0) }))
|
||||||
|
});
|
||||||
|
|
||||||
|
worker.register_tool(def)?;
|
||||||
|
```
|
||||||
|
|
||||||
|
### マクロ使用(推奨)
|
||||||
|
|
||||||
|
```rust
|
||||||
|
#[tool_registry]
|
||||||
|
impl App {
|
||||||
|
#[tool]
|
||||||
|
async fn greet(&self, name: String) -> String {
|
||||||
|
format!("Hello, {}!", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let app = App;
|
||||||
|
worker.register_tool(app.greet_definition())?;
|
||||||
|
```
|
||||||
|
|
||||||
|
## 設計上の決定
|
||||||
|
|
||||||
|
| 問題 | 決定 | 理由 |
|
||||||
|
| -------------------- | ------------------------------ | ---------------------------------------------- |
|
||||||
|
| メタ情報の変更可能性 | ToolMeta を分離・不変化 | 登録後の整合性を保証 |
|
||||||
|
| 状態管理 | 登録時にインスタンス生成 | セッション中の状態保持、同一インスタンス再利用 |
|
||||||
|
| Factory vs Instance | Factory + 登録時即時呼び出し | コンテキストキャプチャと登録時検証 |
|
||||||
|
| Hook からのアクセス | Context に meta と tool を含む | 柔軟な介入を可能に |
|
||||||
|
|
@ -178,41 +178,60 @@ Workerは生成されたラッパー構造体を `Box<dyn Tool>` として保持
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait WorkerHook: Send + Sync {
|
pub trait Hook<E: HookEventKind>: Send + Sync {
|
||||||
/// メッセージ送信前。
|
async fn call(&self, input: &mut E::Input) -> Result<E::Output, Error>;
|
||||||
/// リクエストに含まれるメッセージリストを改変できる。
|
|
||||||
async fn on_message_send(&self, context: &mut Vec<Message>) -> Result<ControlFlow, Error> {
|
|
||||||
Ok(ControlFlow::Continue)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ツール実行前。
|
pub trait HookEventKind {
|
||||||
/// 実行をキャンセルしたり、引数を書き換えることができる。
|
type Input;
|
||||||
async fn before_tool_call(&self, tool_call: &mut ToolCall) -> Result<ControlFlow, Error> {
|
type Output;
|
||||||
Ok(ControlFlow::Continue)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ツール実行後。
|
pub struct OnMessageSend;
|
||||||
/// 結果を書き換えたり、隠蔽したりできる。
|
pub struct BeforeToolCall;
|
||||||
async fn after_tool_call(&self, tool_result: &mut ToolResult) -> Result<ControlFlow, Error> {
|
pub struct AfterToolCall;
|
||||||
Ok(ControlFlow::Continue)
|
pub struct OnTurnEnd;
|
||||||
|
pub struct OnAbort;
|
||||||
|
|
||||||
|
pub enum OnMessageSendResult {
|
||||||
|
Continue,
|
||||||
|
Cancel(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ターン終了時。
|
pub enum BeforeToolCallResult {
|
||||||
/// 生成されたメッセージを検査し、必要ならリトライ(ContinueWithMessages)を指示できる。
|
|
||||||
async fn on_turn_end(&self, messages: &[Message]) -> Result<TurnResult, Error> {
|
|
||||||
Ok(TurnResult::Finish)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub enum ControlFlow {
|
|
||||||
Continue,
|
Continue,
|
||||||
Skip, // Tool実行などをスキップ
|
Skip, // Tool実行などをスキップ
|
||||||
Abort(String), // 処理中断
|
Abort(String), // 処理中断
|
||||||
|
Pause,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum TurnResult {
|
pub enum AfterToolCallResult {
|
||||||
|
Continue,
|
||||||
|
Abort(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum OnTurnEndResult {
|
||||||
Finish,
|
Finish,
|
||||||
ContinueWithMessages(Vec<Message>), // メッセージを追加してターン継続(自己修正など)
|
ContinueWithMessages(Vec<Message>), // メッセージを追加してターン継続(自己修正など)
|
||||||
|
Paused,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tool Call Context
|
||||||
|
|
||||||
|
`before_tool_call` / `after_tool_call` は、ツール実行の文脈を含む入力を受け取る。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct ToolCallContext {
|
||||||
|
pub call: ToolCall,
|
||||||
|
pub meta: ToolMeta, // 不変メタデータ
|
||||||
|
pub tool: Arc<dyn Tool>, // 状態アクセス用
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ToolResultContext {
|
||||||
|
pub result: ToolResult,
|
||||||
|
pub meta: ToolMeta,
|
||||||
|
pub tool: Arc<dyn Tool>,
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -433,4 +452,3 @@ impl<C: LlmClient> Worker<C> {
|
||||||
3. **選択的購読**: on_*で必要なイベントだけ、またはSubscriberで一括
|
3. **選択的購読**: on_*で必要なイベントだけ、またはSubscriberで一括
|
||||||
4. **累積イベントの追加**: Worker層でComplete系イベントを追加提供
|
4. **累積イベントの追加**: Worker層でComplete系イベントを追加提供
|
||||||
5. **後方互換性**: 従来の`run()`も引き続き使用可能
|
5. **後方互換性**: 従来の`run()`も引き続き使用可能
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -113,7 +113,7 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
||||||
let pascal_name = to_pascal_case(&method_name.to_string());
|
let pascal_name = to_pascal_case(&method_name.to_string());
|
||||||
let tool_struct_name = format_ident!("Tool{}", pascal_name);
|
let tool_struct_name = format_ident!("Tool{}", pascal_name);
|
||||||
let args_struct_name = format_ident!("{}Args", pascal_name);
|
let args_struct_name = format_ident!("{}Args", pascal_name);
|
||||||
let factory_name = format_ident!("{}_tool", method_name);
|
let definition_name = format_ident!("{}_definition", method_name);
|
||||||
|
|
||||||
// ドキュメントコメントから説明を取得
|
// ドキュメントコメントから説明を取得
|
||||||
let description = extract_doc_comment(&method.attrs);
|
let description = extract_doc_comment(&method.attrs);
|
||||||
|
|
@ -247,29 +247,24 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl ::llm_worker::tool::Tool for #tool_struct_name {
|
impl ::llm_worker::tool::Tool for #tool_struct_name {
|
||||||
fn name(&self) -> &str {
|
|
||||||
#tool_name
|
|
||||||
}
|
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
|
||||||
#description
|
|
||||||
}
|
|
||||||
|
|
||||||
fn input_schema(&self) -> serde_json::Value {
|
|
||||||
let schema = schemars::schema_for!(#args_struct_name);
|
|
||||||
serde_json::to_value(schema).unwrap_or(serde_json::json!({}))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(&self, input_json: &str) -> Result<String, ::llm_worker::tool::ToolError> {
|
async fn execute(&self, input_json: &str) -> Result<String, ::llm_worker::tool::ToolError> {
|
||||||
#execute_body
|
#execute_body
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl #self_ty {
|
impl #self_ty {
|
||||||
pub fn #factory_name(&self) -> #tool_struct_name {
|
/// ToolDefinition を取得(Worker への登録用)
|
||||||
#tool_struct_name {
|
pub fn #definition_name(&self) -> ::llm_worker::tool::ToolDefinition {
|
||||||
ctx: self.clone()
|
let ctx = self.clone();
|
||||||
}
|
::std::sync::Arc::new(move || {
|
||||||
|
let schema = schemars::schema_for!(#args_struct_name);
|
||||||
|
let meta = ::llm_worker::tool::ToolMeta::new(#tool_name)
|
||||||
|
.description(#description)
|
||||||
|
.input_schema(serde_json::to_value(schema).unwrap_or(serde_json::json!({})));
|
||||||
|
let tool: ::std::sync::Arc<dyn ::llm_worker::tool::Tool> =
|
||||||
|
::std::sync::Arc::new(#tool_struct_name { ctx: ctx.clone() });
|
||||||
|
(meta, tool)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,8 @@ tracing = "0.1"
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
tokio = { version = "1.49", features = ["macros", "rt-multi-thread"] }
|
tokio = { version = "1.49", features = ["macros", "rt-multi-thread"] }
|
||||||
reqwest = { version = "0.13.1", default-features = false, features = ["stream", "json", "native-tls"] }
|
tokio-util = "0.7"
|
||||||
|
reqwest = { version = "0.13.1", default-features = false, features = ["stream", "json", "native-tls", "http2"] }
|
||||||
eventsource-stream = "0.2"
|
eventsource-stream = "0.2"
|
||||||
llm-worker-macros = { path = "../llm-worker-macros", version = "0.1" }
|
llm-worker-macros = { path = "../llm-worker-macros", version = "0.1" }
|
||||||
|
|
||||||
|
|
|
||||||
71
llm-worker/examples/worker_cancel_demo.rs
Normal file
71
llm-worker/examples/worker_cancel_demo.rs
Normal file
|
|
@ -0,0 +1,71 @@
|
||||||
|
//! Worker のキャンセル機能のデモンストレーション
|
||||||
|
//!
|
||||||
|
//! ストリーミング受信中に別スレッドからキャンセルする例
|
||||||
|
|
||||||
|
use llm_worker::llm_client::providers::anthropic::AnthropicClient;
|
||||||
|
use llm_worker::{Worker, WorkerResult};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
// .envファイルを読み込む
|
||||||
|
dotenv::dotenv().ok();
|
||||||
|
|
||||||
|
// ロギング初期化
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_env_filter(
|
||||||
|
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||||
|
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
|
||||||
|
)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
let api_key =
|
||||||
|
std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY environment variable not set");
|
||||||
|
|
||||||
|
let client = AnthropicClient::new(&api_key, "claude-sonnet-4-20250514");
|
||||||
|
let worker = Arc::new(Mutex::new(Worker::new(client)));
|
||||||
|
|
||||||
|
println!("🚀 Starting Worker...");
|
||||||
|
println!("💡 Will cancel after 2 seconds\n");
|
||||||
|
|
||||||
|
// キャンセルSenderを先に取得(ロックを保持しない)
|
||||||
|
let cancel_tx = {
|
||||||
|
let w = worker.lock().await;
|
||||||
|
w.cancel_sender()
|
||||||
|
};
|
||||||
|
|
||||||
|
// タスク1: Workerを実行
|
||||||
|
let worker_clone = worker.clone();
|
||||||
|
let task = tokio::spawn(async move {
|
||||||
|
let mut w = worker_clone.lock().await;
|
||||||
|
println!("📡 Sending request to LLM...");
|
||||||
|
|
||||||
|
match w.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await {
|
||||||
|
Ok(WorkerResult::Finished) => {
|
||||||
|
println!("✅ Task completed normally");
|
||||||
|
}
|
||||||
|
Ok(WorkerResult::Paused) => {
|
||||||
|
println!("⏸️ Task paused");
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
println!("❌ Task error: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// タスク2: 2秒後にキャンセル
|
||||||
|
tokio::spawn(async move {
|
||||||
|
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||||
|
println!("\n🛑 Cancelling worker...");
|
||||||
|
let _ = cancel_tx.send(()).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
// タスク完了を待つ
|
||||||
|
task.await?;
|
||||||
|
|
||||||
|
println!("\n✨ Demo complete!");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
@ -41,7 +41,7 @@ use tracing_subscriber::EnvFilter;
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
use llm_worker::{
|
use llm_worker::{
|
||||||
Worker,
|
Worker,
|
||||||
hook::{ControlFlow, HookError, ToolResult, WorkerHook},
|
hook::{Hook, HookError, PostToolCall, PostToolCallContext, PostToolCallResult},
|
||||||
llm_client::{
|
llm_client::{
|
||||||
LlmClient,
|
LlmClient,
|
||||||
providers::{
|
providers::{
|
||||||
|
|
@ -282,25 +282,22 @@ impl ToolResultPrinterHook {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl WorkerHook for ToolResultPrinterHook {
|
impl Hook<PostToolCall> for ToolResultPrinterHook {
|
||||||
async fn after_tool_call(
|
async fn call(&self, ctx: &mut PostToolCallContext) -> Result<PostToolCallResult, HookError> {
|
||||||
&self,
|
|
||||||
tool_result: &mut ToolResult,
|
|
||||||
) -> Result<ControlFlow, HookError> {
|
|
||||||
let name = self
|
let name = self
|
||||||
.call_names
|
.call_names
|
||||||
.lock()
|
.lock()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.remove(&tool_result.tool_use_id)
|
.remove(&ctx.result.tool_use_id)
|
||||||
.unwrap_or_else(|| tool_result.tool_use_id.clone());
|
.unwrap_or_else(|| ctx.result.tool_use_id.clone());
|
||||||
|
|
||||||
if tool_result.is_error {
|
if ctx.result.is_error {
|
||||||
println!(" Result ({}): ❌ {}", name, tool_result.content);
|
println!(" Result ({}): ❌ {}", name, ctx.result.content);
|
||||||
} else {
|
} else {
|
||||||
println!(" Result ({}): ✅ {}", name, tool_result.content);
|
println!(" Result ({}): ✅ {}", name, ctx.result.content);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(ControlFlow::Continue)
|
Ok(PostToolCallResult::Continue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -441,8 +438,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// ツール登録(--no-tools でなければ)
|
// ツール登録(--no-tools でなければ)
|
||||||
if !args.no_tools {
|
if !args.no_tools {
|
||||||
let app = AppContext;
|
let app = AppContext;
|
||||||
worker.register_tool(app.get_current_time_tool());
|
worker
|
||||||
worker.register_tool(app.calculate_tool());
|
.register_tool(app.get_current_time_definition())
|
||||||
|
.unwrap();
|
||||||
|
worker.register_tool(app.calculate_definition()).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// ストリーミング表示用ハンドラーを登録
|
// ストリーミング表示用ハンドラーを登録
|
||||||
|
|
@ -451,7 +450,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
.on_text_block(StreamingPrinter::new())
|
.on_text_block(StreamingPrinter::new())
|
||||||
.on_tool_use_block(ToolCallPrinter::new(tool_call_names.clone()));
|
.on_tool_use_block(ToolCallPrinter::new(tool_call_names.clone()));
|
||||||
|
|
||||||
worker.add_hook(ToolResultPrinterHook::new(tool_call_names));
|
worker.add_post_tool_call_hook(ToolResultPrinterHook::new(tool_call_names));
|
||||||
|
|
||||||
// ワンショットモード
|
// ワンショットモード
|
||||||
if let Some(prompt) = args.prompt {
|
if let Some(prompt) = args.prompt {
|
||||||
|
|
|
||||||
|
|
@ -8,33 +8,110 @@ use serde_json::Value;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// Control Flow Types
|
// Hook Event Kinds
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// Hook処理の制御フロー
|
pub trait HookEventKind: Send + Sync + 'static {
|
||||||
|
type Input;
|
||||||
|
type Output;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct OnPromptSubmit;
|
||||||
|
pub struct PreLlmRequest;
|
||||||
|
pub struct PreToolCall;
|
||||||
|
pub struct PostToolCall;
|
||||||
|
pub struct OnTurnEnd;
|
||||||
|
pub struct OnAbort;
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub enum ControlFlow {
|
pub enum OnPromptSubmitResult {
|
||||||
/// 処理を続行
|
Continue,
|
||||||
|
Cancel(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum PreLlmRequestResult {
|
||||||
|
Continue,
|
||||||
|
Cancel(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum PreToolCallResult {
|
||||||
Continue,
|
Continue,
|
||||||
/// 現在の処理をスキップ(Tool実行など)
|
|
||||||
Skip,
|
Skip,
|
||||||
/// 処理を中断
|
|
||||||
Abort(String),
|
Abort(String),
|
||||||
/// 処理を一時停止(再開可能)
|
|
||||||
Pause,
|
Pause,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ターン終了時の判定結果
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum PostToolCallResult {
|
||||||
|
Continue,
|
||||||
|
Abort(String),
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum TurnResult {
|
pub enum OnTurnEndResult {
|
||||||
/// ターンを終了
|
|
||||||
Finish,
|
Finish,
|
||||||
/// メッセージを追加してターン継続(自己修正など)
|
|
||||||
ContinueWithMessages(Vec<crate::Message>),
|
ContinueWithMessages(Vec<crate::Message>),
|
||||||
/// ターンを一時停止
|
|
||||||
Paused,
|
Paused,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::tool::{Tool, ToolMeta};
|
||||||
|
|
||||||
|
/// PreToolCall の入力コンテキスト
|
||||||
|
pub struct ToolCallContext {
|
||||||
|
/// ツール呼び出し情報(改変可能)
|
||||||
|
pub call: ToolCall,
|
||||||
|
/// ツールメタ情報(不変)
|
||||||
|
pub meta: ToolMeta,
|
||||||
|
/// ツールインスタンス(状態アクセス用)
|
||||||
|
pub tool: Arc<dyn Tool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// PostToolCall の入力コンテキスト
|
||||||
|
pub struct PostToolCallContext {
|
||||||
|
/// ツール呼び出し情報
|
||||||
|
pub call: ToolCall,
|
||||||
|
/// ツール実行結果(改変可能)
|
||||||
|
pub result: ToolResult,
|
||||||
|
/// ツールメタ情報(不変)
|
||||||
|
pub meta: ToolMeta,
|
||||||
|
/// ツールインスタンス(状態アクセス用)
|
||||||
|
pub tool: Arc<dyn Tool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HookEventKind for OnPromptSubmit {
|
||||||
|
type Input = crate::Message;
|
||||||
|
type Output = OnPromptSubmitResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HookEventKind for PreLlmRequest {
|
||||||
|
type Input = Vec<crate::Message>;
|
||||||
|
type Output = PreLlmRequestResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HookEventKind for PreToolCall {
|
||||||
|
type Input = ToolCallContext;
|
||||||
|
type Output = PreToolCallResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HookEventKind for PostToolCall {
|
||||||
|
type Input = PostToolCallContext;
|
||||||
|
type Output = PostToolCallResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HookEventKind for OnTurnEnd {
|
||||||
|
type Input = Vec<crate::Message>;
|
||||||
|
type Output = OnTurnEndResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HookEventKind for OnAbort {
|
||||||
|
type Input = String;
|
||||||
|
type Output = ();
|
||||||
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// Tool Call / Result Types
|
// Tool Call / Result Types
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
@ -102,85 +179,55 @@ pub enum HookError {
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// WorkerHook Trait
|
// Hook Trait
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// ターンの進行・ツール実行に介入するためのトレイト
|
/// Hookイベントの処理を行うトレイト
|
||||||
///
|
///
|
||||||
/// Hookを使うと、メッセージ送信前、ツール実行前後、ターン終了時に
|
/// 各イベント種別は戻り値型が異なるため、`HookEventKind`を介して型を制約する。
|
||||||
/// 処理を挟んだり、実行をキャンセルしたりできます。
|
|
||||||
///
|
|
||||||
/// # Examples
|
|
||||||
///
|
|
||||||
/// ```ignore
|
|
||||||
/// use llm_worker::hook::{ControlFlow, HookError, ToolCall, TurnResult, WorkerHook};
|
|
||||||
/// use llm_worker::Message;
|
|
||||||
///
|
|
||||||
/// struct ValidationHook;
|
|
||||||
///
|
|
||||||
/// #[async_trait::async_trait]
|
|
||||||
/// impl WorkerHook for ValidationHook {
|
|
||||||
/// async fn before_tool_call(&self, call: &mut ToolCall) -> Result<ControlFlow, HookError> {
|
|
||||||
/// // 危険なツールをブロック
|
|
||||||
/// if call.name == "delete_all" {
|
|
||||||
/// return Ok(ControlFlow::Skip);
|
|
||||||
/// }
|
|
||||||
/// Ok(ControlFlow::Continue)
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// async fn on_turn_end(&self, messages: &[Message]) -> Result<TurnResult, HookError> {
|
|
||||||
/// // 条件を満たさなければ追加メッセージで継続
|
|
||||||
/// if messages.len() < 3 {
|
|
||||||
/// return Ok(TurnResult::ContinueWithMessages(vec![
|
|
||||||
/// Message::user("Please elaborate.")
|
|
||||||
/// ]));
|
|
||||||
/// }
|
|
||||||
/// Ok(TurnResult::Finish)
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// # デフォルト実装
|
|
||||||
///
|
|
||||||
/// すべてのメソッドにはデフォルト実装があり、何も行わず`Continue`を返します。
|
|
||||||
/// 必要なメソッドのみオーバーライドしてください。
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait WorkerHook: Send + Sync {
|
pub trait Hook<E: HookEventKind>: Send + Sync {
|
||||||
/// メッセージ送信前に呼ばれる
|
async fn call(&self, input: &mut E::Input) -> Result<E::Output, HookError>;
|
||||||
///
|
|
||||||
/// リクエストに含まれるメッセージリストを参照・改変できます。
|
|
||||||
/// `ControlFlow::Abort`を返すとターンが中断されます。
|
|
||||||
async fn on_message_send(
|
|
||||||
&self,
|
|
||||||
_context: &mut Vec<crate::Message>,
|
|
||||||
) -> Result<ControlFlow, HookError> {
|
|
||||||
Ok(ControlFlow::Continue)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ツール実行前に呼ばれる
|
// =============================================================================
|
||||||
|
// Hook Registry
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
/// 全 Hook を保持するレジストリ
|
||||||
///
|
///
|
||||||
/// ツール呼び出しの引数を書き換えたり、実行をスキップしたりできます。
|
/// Worker 内部で使用され、各種 Hook を一括管理する。
|
||||||
/// `ControlFlow::Skip`を返すとこのツールの実行がスキップされます。
|
pub struct HookRegistry {
|
||||||
async fn before_tool_call(&self, _tool_call: &mut ToolCall) -> Result<ControlFlow, HookError> {
|
/// on_prompt_submit Hook
|
||||||
Ok(ControlFlow::Continue)
|
pub(crate) on_prompt_submit: Vec<Box<dyn Hook<OnPromptSubmit>>>,
|
||||||
|
/// pre_llm_request Hook
|
||||||
|
pub(crate) pre_llm_request: Vec<Box<dyn Hook<PreLlmRequest>>>,
|
||||||
|
/// pre_tool_call Hook
|
||||||
|
pub(crate) pre_tool_call: Vec<Box<dyn Hook<PreToolCall>>>,
|
||||||
|
/// post_tool_call Hook
|
||||||
|
pub(crate) post_tool_call: Vec<Box<dyn Hook<PostToolCall>>>,
|
||||||
|
/// on_turn_end Hook
|
||||||
|
pub(crate) on_turn_end: Vec<Box<dyn Hook<OnTurnEnd>>>,
|
||||||
|
/// on_abort Hook
|
||||||
|
pub(crate) on_abort: Vec<Box<dyn Hook<OnAbort>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ツール実行後に呼ばれる
|
impl Default for HookRegistry {
|
||||||
///
|
fn default() -> Self {
|
||||||
/// ツールの実行結果を書き換えたり、隠蔽したりできます。
|
Self::new()
|
||||||
async fn after_tool_call(
|
}
|
||||||
&self,
|
|
||||||
_tool_result: &mut ToolResult,
|
|
||||||
) -> Result<ControlFlow, HookError> {
|
|
||||||
Ok(ControlFlow::Continue)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ターン終了時に呼ばれる
|
impl HookRegistry {
|
||||||
///
|
/// 空の HookRegistry を作成
|
||||||
/// 生成されたメッセージを検査し、必要なら追加メッセージで継続を指示できます。
|
pub fn new() -> Self {
|
||||||
/// `TurnResult::ContinueWithMessages`を返すと、指定したメッセージを追加して
|
Self {
|
||||||
/// 次のターンに進みます。
|
on_prompt_submit: Vec::new(),
|
||||||
async fn on_turn_end(&self, _messages: &[crate::Message]) -> Result<TurnResult, HookError> {
|
pre_llm_request: Vec::new(),
|
||||||
Ok(TurnResult::Finish)
|
pre_tool_call: Vec::new(),
|
||||||
|
post_tool_call: Vec::new(),
|
||||||
|
on_turn_end: Vec::new(),
|
||||||
|
on_abort: Vec::new(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@
|
||||||
//!
|
//!
|
||||||
//! - [`Worker`] - LLMとの対話を管理する中心コンポーネント
|
//! - [`Worker`] - LLMとの対話を管理する中心コンポーネント
|
||||||
//! - [`tool::Tool`] - LLMから呼び出し可能なツール
|
//! - [`tool::Tool`] - LLMから呼び出し可能なツール
|
||||||
//! - [`hook::WorkerHook`] - ターン進行への介入
|
//! - [`hook::Hook`] - ターン進行への介入
|
||||||
//! - [`subscriber::WorkerSubscriber`] - ストリーミングイベントの購読
|
//! - [`subscriber::WorkerSubscriber`] - ストリーミングイベントの購読
|
||||||
//!
|
//!
|
||||||
//! # Quick Start
|
//! # Quick Start
|
||||||
|
|
@ -19,8 +19,7 @@
|
||||||
//! .system_prompt("You are a helpful assistant.");
|
//! .system_prompt("You are a helpful assistant.");
|
||||||
//!
|
//!
|
||||||
//! // ツールを登録(オプション)
|
//! // ツールを登録(オプション)
|
||||||
//! use llm_worker::tool::Tool;
|
//! // worker.register_tool(my_tool_definition)?;
|
||||||
//! worker.register_tool(my_tool);
|
|
||||||
//!
|
//!
|
||||||
//! // 対話を実行
|
//! // 対話を実行
|
||||||
//! let history = worker.run("Hello!").await?;
|
//! let history = worker.run("Hello!").await?;
|
||||||
|
|
@ -48,9 +47,5 @@ pub mod subscriber;
|
||||||
pub mod timeline;
|
pub mod timeline;
|
||||||
pub mod tool;
|
pub mod tool;
|
||||||
|
|
||||||
// =============================================================================
|
|
||||||
// トップレベル公開(最も頻繁に使う型)
|
|
||||||
// =============================================================================
|
|
||||||
|
|
||||||
pub use message::{ContentPart, Message, MessageContent, Role};
|
pub use message::{ContentPart, Message, MessageContent, Role};
|
||||||
pub use worker::{Worker, WorkerConfig, WorkerError};
|
pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult};
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
//! LLMクライアント層
|
//! LLMクライアント層
|
||||||
//!
|
//!
|
||||||
//! 各LLMプロバイダと通信し、統一された[`Event`](crate::llm_client::event::Event)
|
//! 各LLMプロバイダと通信し、統一された[`Event`]
|
||||||
//! ストリームを出力します。
|
//! ストリームを出力します。
|
||||||
//!
|
//!
|
||||||
//! # サポートするプロバイダ
|
//! # サポートするプロバイダ
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
//! Worker状態
|
//! Worker状態
|
||||||
//!
|
//!
|
||||||
//! Type-stateパターンによるキャッシュ保護のための状態マーカー型。
|
//! Type-stateパターンによるキャッシュ保護のための状態マーカー型。
|
||||||
//! Workerは`Mutable` → `Locked`の状態遷移を持ちます。
|
//! Workerは`Mutable` → `CacheLocked`の状態遷移を持ちます。
|
||||||
|
|
||||||
/// Worker状態を表すマーカートレイト
|
/// Worker状態を表すマーカートレイト
|
||||||
///
|
///
|
||||||
|
|
@ -19,7 +19,7 @@ mod private {
|
||||||
/// - メッセージ履歴の編集(追加、削除、クリア)
|
/// - メッセージ履歴の編集(追加、削除、クリア)
|
||||||
/// - ツール・Hookの登録
|
/// - ツール・Hookの登録
|
||||||
///
|
///
|
||||||
/// `Worker::lock()`により[`Locked`]状態へ遷移できます。
|
/// `Worker::lock()`により[`CacheLocked`]状態へ遷移できます。
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
|
|
@ -42,7 +42,7 @@ pub struct Mutable;
|
||||||
impl private::Sealed for Mutable {}
|
impl private::Sealed for Mutable {}
|
||||||
impl WorkerState for Mutable {}
|
impl WorkerState for Mutable {}
|
||||||
|
|
||||||
/// ロック状態(キャッシュ保護)
|
/// キャッシュロック状態(キャッシュ保護)
|
||||||
///
|
///
|
||||||
/// この状態では以下の制限があります:
|
/// この状態では以下の制限があります:
|
||||||
/// - システムプロンプトの変更不可
|
/// - システムプロンプトの変更不可
|
||||||
|
|
@ -54,7 +54,7 @@ impl WorkerState for Mutable {}
|
||||||
/// `Worker::unlock()`により[`Mutable`]状態へ戻せますが、
|
/// `Worker::unlock()`により[`Mutable`]状態へ戻せますが、
|
||||||
/// キャッシュ保護が解除されることに注意してください。
|
/// キャッシュ保護が解除されることに注意してください。
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
pub struct Locked;
|
pub struct CacheLocked;
|
||||||
|
|
||||||
impl private::Sealed for Locked {}
|
impl private::Sealed for CacheLocked {}
|
||||||
impl WorkerState for Locked {}
|
impl WorkerState for CacheLocked {}
|
||||||
|
|
|
||||||
|
|
@ -65,10 +65,10 @@ pub trait WorkerSubscriber: Send {
|
||||||
///
|
///
|
||||||
/// ブロック開始時にDefault::default()で生成され、
|
/// ブロック開始時にDefault::default()で生成され、
|
||||||
/// ブロック終了時に破棄される。
|
/// ブロック終了時に破棄される。
|
||||||
type TextBlockScope: Default + Send;
|
type TextBlockScope: Default + Send + Sync;
|
||||||
|
|
||||||
/// ツール使用ブロック処理用のスコープ型
|
/// ツール使用ブロック処理用のスコープ型
|
||||||
type ToolUseBlockScope: Default + Send;
|
type ToolUseBlockScope: Default + Send + Sync;
|
||||||
|
|
||||||
// =========================================================================
|
// =========================================================================
|
||||||
// ブロックイベント(スコープ管理あり)
|
// ブロックイベント(スコープ管理あり)
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ use crate::handler::*;
|
||||||
/// 各Handlerは独自のScope型を持つため、Timelineで保持するには型消去が必要です。
|
/// 各Handlerは独自のScope型を持つため、Timelineで保持するには型消去が必要です。
|
||||||
/// 通常は直接使用せず、`Timeline::on_text_block()`などのメソッド経由で
|
/// 通常は直接使用せず、`Timeline::on_text_block()`などのメソッド経由で
|
||||||
/// 自動的にラップされます。
|
/// 自動的にラップされます。
|
||||||
pub trait ErasedHandler<K: Kind>: Send {
|
pub trait ErasedHandler<K: Kind>: Send + Sync {
|
||||||
/// イベントをディスパッチ
|
/// イベントをディスパッチ
|
||||||
fn dispatch(&mut self, event: &K::Event);
|
fn dispatch(&mut self, event: &K::Event);
|
||||||
/// スコープを開始(Block開始時)
|
/// スコープを開始(Block開始時)
|
||||||
|
|
@ -54,9 +54,9 @@ where
|
||||||
|
|
||||||
impl<H, K> ErasedHandler<K> for HandlerWrapper<H, K>
|
impl<H, K> ErasedHandler<K> for HandlerWrapper<H, K>
|
||||||
where
|
where
|
||||||
H: Handler<K> + Send,
|
H: Handler<K> + Send + Sync,
|
||||||
K: Kind,
|
K: Kind,
|
||||||
H::Scope: Send,
|
H::Scope: Send + Sync,
|
||||||
{
|
{
|
||||||
fn dispatch(&mut self, event: &K::Event) {
|
fn dispatch(&mut self, event: &K::Event) {
|
||||||
if let Some(scope) = &mut self.scope {
|
if let Some(scope) = &mut self.scope {
|
||||||
|
|
@ -78,7 +78,7 @@ where
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// ブロックハンドラーの型消去trait
|
/// ブロックハンドラーの型消去trait
|
||||||
trait ErasedBlockHandler: Send {
|
trait ErasedBlockHandler: Send + Sync {
|
||||||
fn dispatch_start(&mut self, start: &BlockStart);
|
fn dispatch_start(&mut self, start: &BlockStart);
|
||||||
fn dispatch_delta(&mut self, delta: &BlockDelta);
|
fn dispatch_delta(&mut self, delta: &BlockDelta);
|
||||||
fn dispatch_stop(&mut self, stop: &BlockStop);
|
fn dispatch_stop(&mut self, stop: &BlockStop);
|
||||||
|
|
@ -112,8 +112,8 @@ where
|
||||||
|
|
||||||
impl<H> ErasedBlockHandler for TextBlockHandlerWrapper<H>
|
impl<H> ErasedBlockHandler for TextBlockHandlerWrapper<H>
|
||||||
where
|
where
|
||||||
H: Handler<TextBlockKind> + Send,
|
H: Handler<TextBlockKind> + Send + Sync,
|
||||||
H::Scope: Send,
|
H::Scope: Send + Sync,
|
||||||
{
|
{
|
||||||
fn dispatch_start(&mut self, start: &BlockStart) {
|
fn dispatch_start(&mut self, start: &BlockStart) {
|
||||||
if let Some(scope) = &mut self.scope {
|
if let Some(scope) = &mut self.scope {
|
||||||
|
|
@ -185,8 +185,8 @@ where
|
||||||
|
|
||||||
impl<H> ErasedBlockHandler for ThinkingBlockHandlerWrapper<H>
|
impl<H> ErasedBlockHandler for ThinkingBlockHandlerWrapper<H>
|
||||||
where
|
where
|
||||||
H: Handler<ThinkingBlockKind> + Send,
|
H: Handler<ThinkingBlockKind> + Send + Sync,
|
||||||
H::Scope: Send,
|
H::Scope: Send + Sync,
|
||||||
{
|
{
|
||||||
fn dispatch_start(&mut self, start: &BlockStart) {
|
fn dispatch_start(&mut self, start: &BlockStart) {
|
||||||
if let Some(scope) = &mut self.scope {
|
if let Some(scope) = &mut self.scope {
|
||||||
|
|
@ -255,8 +255,8 @@ where
|
||||||
|
|
||||||
impl<H> ErasedBlockHandler for ToolUseBlockHandlerWrapper<H>
|
impl<H> ErasedBlockHandler for ToolUseBlockHandlerWrapper<H>
|
||||||
where
|
where
|
||||||
H: Handler<ToolUseBlockKind> + Send,
|
H: Handler<ToolUseBlockKind> + Send + Sync,
|
||||||
H::Scope: Send,
|
H::Scope: Send + Sync,
|
||||||
{
|
{
|
||||||
fn dispatch_start(&mut self, start: &BlockStart) {
|
fn dispatch_start(&mut self, start: &BlockStart) {
|
||||||
if let Some(scope) = &mut self.scope {
|
if let Some(scope) = &mut self.scope {
|
||||||
|
|
@ -391,8 +391,8 @@ impl Timeline {
|
||||||
/// UsageKind用のHandlerを登録
|
/// UsageKind用のHandlerを登録
|
||||||
pub fn on_usage<H>(&mut self, handler: H) -> &mut Self
|
pub fn on_usage<H>(&mut self, handler: H) -> &mut Self
|
||||||
where
|
where
|
||||||
H: Handler<UsageKind> + Send + 'static,
|
H: Handler<UsageKind> + Send + Sync + 'static,
|
||||||
H::Scope: Send,
|
H::Scope: Send + Sync,
|
||||||
{
|
{
|
||||||
// Meta系はデフォルトでスコープを開始しておく
|
// Meta系はデフォルトでスコープを開始しておく
|
||||||
let mut wrapper = HandlerWrapper::new(handler);
|
let mut wrapper = HandlerWrapper::new(handler);
|
||||||
|
|
@ -404,8 +404,8 @@ impl Timeline {
|
||||||
/// PingKind用のHandlerを登録
|
/// PingKind用のHandlerを登録
|
||||||
pub fn on_ping<H>(&mut self, handler: H) -> &mut Self
|
pub fn on_ping<H>(&mut self, handler: H) -> &mut Self
|
||||||
where
|
where
|
||||||
H: Handler<PingKind> + Send + 'static,
|
H: Handler<PingKind> + Send + Sync + 'static,
|
||||||
H::Scope: Send,
|
H::Scope: Send + Sync,
|
||||||
{
|
{
|
||||||
let mut wrapper = HandlerWrapper::new(handler);
|
let mut wrapper = HandlerWrapper::new(handler);
|
||||||
wrapper.start_scope();
|
wrapper.start_scope();
|
||||||
|
|
@ -416,8 +416,8 @@ impl Timeline {
|
||||||
/// StatusKind用のHandlerを登録
|
/// StatusKind用のHandlerを登録
|
||||||
pub fn on_status<H>(&mut self, handler: H) -> &mut Self
|
pub fn on_status<H>(&mut self, handler: H) -> &mut Self
|
||||||
where
|
where
|
||||||
H: Handler<StatusKind> + Send + 'static,
|
H: Handler<StatusKind> + Send + Sync + 'static,
|
||||||
H::Scope: Send,
|
H::Scope: Send + Sync,
|
||||||
{
|
{
|
||||||
let mut wrapper = HandlerWrapper::new(handler);
|
let mut wrapper = HandlerWrapper::new(handler);
|
||||||
wrapper.start_scope();
|
wrapper.start_scope();
|
||||||
|
|
@ -428,8 +428,8 @@ impl Timeline {
|
||||||
/// ErrorKind用のHandlerを登録
|
/// ErrorKind用のHandlerを登録
|
||||||
pub fn on_error<H>(&mut self, handler: H) -> &mut Self
|
pub fn on_error<H>(&mut self, handler: H) -> &mut Self
|
||||||
where
|
where
|
||||||
H: Handler<ErrorKind> + Send + 'static,
|
H: Handler<ErrorKind> + Send + Sync + 'static,
|
||||||
H::Scope: Send,
|
H::Scope: Send + Sync,
|
||||||
{
|
{
|
||||||
let mut wrapper = HandlerWrapper::new(handler);
|
let mut wrapper = HandlerWrapper::new(handler);
|
||||||
wrapper.start_scope();
|
wrapper.start_scope();
|
||||||
|
|
@ -440,8 +440,8 @@ impl Timeline {
|
||||||
/// TextBlockKind用のHandlerを登録
|
/// TextBlockKind用のHandlerを登録
|
||||||
pub fn on_text_block<H>(&mut self, handler: H) -> &mut Self
|
pub fn on_text_block<H>(&mut self, handler: H) -> &mut Self
|
||||||
where
|
where
|
||||||
H: Handler<TextBlockKind> + Send + 'static,
|
H: Handler<TextBlockKind> + Send + Sync + 'static,
|
||||||
H::Scope: Send,
|
H::Scope: Send + Sync,
|
||||||
{
|
{
|
||||||
self.text_block_handlers
|
self.text_block_handlers
|
||||||
.push(Box::new(TextBlockHandlerWrapper::new(handler)));
|
.push(Box::new(TextBlockHandlerWrapper::new(handler)));
|
||||||
|
|
@ -451,8 +451,8 @@ impl Timeline {
|
||||||
/// ThinkingBlockKind用のHandlerを登録
|
/// ThinkingBlockKind用のHandlerを登録
|
||||||
pub fn on_thinking_block<H>(&mut self, handler: H) -> &mut Self
|
pub fn on_thinking_block<H>(&mut self, handler: H) -> &mut Self
|
||||||
where
|
where
|
||||||
H: Handler<ThinkingBlockKind> + Send + 'static,
|
H: Handler<ThinkingBlockKind> + Send + Sync + 'static,
|
||||||
H::Scope: Send,
|
H::Scope: Send + Sync,
|
||||||
{
|
{
|
||||||
self.thinking_block_handlers
|
self.thinking_block_handlers
|
||||||
.push(Box::new(ThinkingBlockHandlerWrapper::new(handler)));
|
.push(Box::new(ThinkingBlockHandlerWrapper::new(handler)));
|
||||||
|
|
@ -462,8 +462,8 @@ impl Timeline {
|
||||||
/// ToolUseBlockKind用のHandlerを登録
|
/// ToolUseBlockKind用のHandlerを登録
|
||||||
pub fn on_tool_use_block<H>(&mut self, handler: H) -> &mut Self
|
pub fn on_tool_use_block<H>(&mut self, handler: H) -> &mut Self
|
||||||
where
|
where
|
||||||
H: Handler<ToolUseBlockKind> + Send + 'static,
|
H: Handler<ToolUseBlockKind> + Send + Sync + 'static,
|
||||||
H::Scope: Send,
|
H::Scope: Send + Sync,
|
||||||
{
|
{
|
||||||
self.tool_use_block_handlers
|
self.tool_use_block_handlers
|
||||||
.push(Box::new(ToolUseBlockHandlerWrapper::new(handler)));
|
.push(Box::new(ToolUseBlockHandlerWrapper::new(handler)));
|
||||||
|
|
@ -578,6 +578,21 @@ impl Timeline {
|
||||||
pub fn current_block(&self) -> Option<BlockType> {
|
pub fn current_block(&self) -> Option<BlockType> {
|
||||||
self.current_block
|
self.current_block
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 現在アクティブなブロックを中断する
|
||||||
|
///
|
||||||
|
/// キャンセルやエラー時に呼び出し、進行中のブロックに対して
|
||||||
|
/// BlockAbortイベントを発火してスコープをクリーンアップする。
|
||||||
|
pub fn abort_current_block(&mut self) {
|
||||||
|
if let Some(block_type) = self.current_block {
|
||||||
|
let abort = crate::timeline::event::BlockAbort {
|
||||||
|
index: 0, // インデックスは不明なので0
|
||||||
|
block_type,
|
||||||
|
reason: "Cancelled".to_string(),
|
||||||
|
};
|
||||||
|
self.handle_block_abort(&abort);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,8 @@
|
||||||
//! LLMから呼び出し可能なツールを定義するためのトレイト。
|
//! LLMから呼び出し可能なツールを定義するためのトレイト。
|
||||||
//! 通常は`#[tool]`マクロを使用して自動実装します。
|
//! 通常は`#[tool]`マクロを使用して自動実装します。
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
@ -21,64 +23,126 @@ pub enum ToolError {
|
||||||
Internal(String),
|
Internal(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// ToolMeta - 不変のメタ情報
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
/// ツールのメタ情報(登録時に固定、不変)
|
||||||
|
///
|
||||||
|
/// `ToolDefinition` ファクトリから生成され、Worker に登録後は変更されません。
|
||||||
|
/// LLM へのツール定義送信に使用されます。
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct ToolMeta {
|
||||||
|
/// ツール名(LLMが識別に使用)
|
||||||
|
pub name: String,
|
||||||
|
/// ツールの説明(LLMへのプロンプトに含まれる)
|
||||||
|
pub description: String,
|
||||||
|
/// 引数のJSON Schema
|
||||||
|
pub input_schema: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolMeta {
|
||||||
|
/// 新しい ToolMeta を作成
|
||||||
|
pub fn new(name: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
name: name.into(),
|
||||||
|
description: String::new(),
|
||||||
|
input_schema: Value::Object(Default::default()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 説明を設定
|
||||||
|
pub fn description(mut self, desc: impl Into<String>) -> Self {
|
||||||
|
self.description = desc.into();
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 引数スキーマを設定
|
||||||
|
pub fn input_schema(mut self, schema: Value) -> Self {
|
||||||
|
self.input_schema = schema;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// ToolDefinition - ファクトリ型
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
/// ツール定義ファクトリ
|
||||||
|
///
|
||||||
|
/// 呼び出すと `(ToolMeta, Arc<dyn Tool>)` を返します。
|
||||||
|
/// Worker への登録時に一度だけ呼び出され、メタ情報とインスタンスが
|
||||||
|
/// セッションスコープでキャッシュされます。
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```ignore
|
||||||
|
/// let def: ToolDefinition = Arc::new(|| {
|
||||||
|
/// (
|
||||||
|
/// ToolMeta::new("my_tool")
|
||||||
|
/// .description("My tool description")
|
||||||
|
/// .input_schema(json!({"type": "object"})),
|
||||||
|
/// Arc::new(MyToolImpl { state: 0 }) as Arc<dyn Tool>,
|
||||||
|
/// )
|
||||||
|
/// });
|
||||||
|
/// worker.register_tool(def)?;
|
||||||
|
/// ```
|
||||||
|
pub type ToolDefinition = Arc<dyn Fn() -> (ToolMeta, Arc<dyn Tool>) + Send + Sync>;
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Tool trait
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
/// LLMから呼び出し可能なツールを定義するトレイト
|
/// LLMから呼び出し可能なツールを定義するトレイト
|
||||||
///
|
///
|
||||||
/// ツールはLLMが外部リソースにアクセスしたり、
|
/// ツールはLLMが外部リソースにアクセスしたり、
|
||||||
/// 計算を実行したりするために使用します。
|
/// 計算を実行したりするために使用します。
|
||||||
|
/// セッション中の状態を保持できます。
|
||||||
///
|
///
|
||||||
/// # 実装方法
|
/// # 実装方法
|
||||||
///
|
///
|
||||||
/// 通常は`#[tool]`マクロを使用して自動実装します:
|
/// 通常は`#[tool_registry]`マクロを使用して自動実装します:
|
||||||
///
|
///
|
||||||
/// ```ignore
|
/// ```ignore
|
||||||
/// use llm_worker::tool;
|
/// #[tool_registry]
|
||||||
///
|
/// impl MyApp {
|
||||||
/// #[tool(description = "Search the web for information")]
|
/// #[tool]
|
||||||
/// async fn search(query: String) -> String {
|
/// async fn search(&self, query: String) -> String {
|
||||||
/// // 検索処理
|
|
||||||
/// format!("Results for: {}", query)
|
/// format!("Results for: {}", query)
|
||||||
/// }
|
/// }
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// // 登録
|
||||||
|
/// worker.register_tool(app.search_definition())?;
|
||||||
/// ```
|
/// ```
|
||||||
///
|
///
|
||||||
/// # 手動実装
|
/// # 手動実装
|
||||||
///
|
///
|
||||||
/// ```ignore
|
/// ```ignore
|
||||||
/// use llm_worker::tool::{Tool, ToolError};
|
/// use llm_worker::tool::{Tool, ToolError, ToolMeta, ToolDefinition};
|
||||||
/// use serde_json::{json, Value};
|
/// use std::sync::Arc;
|
||||||
///
|
///
|
||||||
/// struct MyTool;
|
/// struct MyTool { counter: std::sync::atomic::AtomicUsize }
|
||||||
///
|
///
|
||||||
/// #[async_trait::async_trait]
|
/// #[async_trait::async_trait]
|
||||||
/// impl Tool for MyTool {
|
/// impl Tool for MyTool {
|
||||||
/// fn name(&self) -> &str { "my_tool" }
|
|
||||||
/// fn description(&self) -> &str { "My custom tool" }
|
|
||||||
/// fn input_schema(&self) -> Value {
|
|
||||||
/// json!({
|
|
||||||
/// "type": "object",
|
|
||||||
/// "properties": {
|
|
||||||
/// "query": { "type": "string" }
|
|
||||||
/// },
|
|
||||||
/// "required": ["query"]
|
|
||||||
/// })
|
|
||||||
/// }
|
|
||||||
/// async fn execute(&self, input: &str) -> Result<String, ToolError> {
|
/// async fn execute(&self, input: &str) -> Result<String, ToolError> {
|
||||||
|
/// self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||||
/// Ok("result".to_string())
|
/// Ok("result".to_string())
|
||||||
/// }
|
/// }
|
||||||
/// }
|
/// }
|
||||||
|
///
|
||||||
|
/// let def: ToolDefinition = Arc::new(|| {
|
||||||
|
/// (
|
||||||
|
/// ToolMeta::new("my_tool")
|
||||||
|
/// .description("My custom tool")
|
||||||
|
/// .input_schema(serde_json::json!({"type": "object"})),
|
||||||
|
/// Arc::new(MyTool { counter: Default::default() }) as Arc<dyn Tool>,
|
||||||
|
/// )
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Tool: Send + Sync {
|
pub trait Tool: Send + Sync {
|
||||||
/// ツール名(LLMが識別に使用)
|
|
||||||
fn name(&self) -> &str;
|
|
||||||
|
|
||||||
/// ツールの説明(LLMへのプロンプトに含まれる)
|
|
||||||
fn description(&self) -> &str;
|
|
||||||
|
|
||||||
/// 引数のJSON Schema
|
|
||||||
///
|
|
||||||
/// LLMはこのスキーマに従って引数を生成します。
|
|
||||||
fn input_schema(&self) -> Value;
|
|
||||||
|
|
||||||
/// ツールを実行する
|
/// ツールを実行する
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
|
|
|
||||||
|
|
@ -3,19 +3,27 @@ use std::marker::PhantomData;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
use tracing::{debug, info, trace, warn};
|
use tracing::{debug, info, trace, warn};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
ContentPart, Message, MessageContent, Role,
|
ContentPart, Message, MessageContent, Role,
|
||||||
hook::{ControlFlow, HookError, ToolCall, ToolResult, TurnResult, WorkerHook},
|
hook::{
|
||||||
llm_client::{ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ToolDefinition},
|
Hook, HookError, HookRegistry, OnAbort, OnPromptSubmit, OnPromptSubmitResult, OnTurnEnd,
|
||||||
state::{Locked, Mutable, WorkerState},
|
OnTurnEndResult, PostToolCall, PostToolCallContext, PostToolCallResult, PreLlmRequest,
|
||||||
|
PreLlmRequestResult, PreToolCall, PreToolCallResult, ToolCall, ToolCallContext, ToolResult,
|
||||||
|
},
|
||||||
|
llm_client::{
|
||||||
|
ClientError, ConfigWarning, LlmClient, Request, RequestConfig,
|
||||||
|
ToolDefinition as LlmToolDefinition,
|
||||||
|
},
|
||||||
|
state::{CacheLocked, Mutable, WorkerState},
|
||||||
subscriber::{
|
subscriber::{
|
||||||
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
|
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
|
||||||
ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, WorkerSubscriber,
|
ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, WorkerSubscriber,
|
||||||
},
|
},
|
||||||
timeline::{TextBlockCollector, Timeline, ToolCallCollector},
|
timeline::{TextBlockCollector, Timeline, ToolCallCollector},
|
||||||
tool::{Tool, ToolError},
|
tool::{Tool, ToolDefinition, ToolError, ToolMeta},
|
||||||
};
|
};
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
@ -37,11 +45,22 @@ pub enum WorkerError {
|
||||||
/// 処理が中断された
|
/// 処理が中断された
|
||||||
#[error("Aborted: {0}")]
|
#[error("Aborted: {0}")]
|
||||||
Aborted(String),
|
Aborted(String),
|
||||||
|
/// Cancellation Tokenによって中断された
|
||||||
|
#[error("Cancelled")]
|
||||||
|
Cancelled,
|
||||||
/// 設定に関する警告(未サポートのオプション)
|
/// 設定に関する警告(未サポートのオプション)
|
||||||
#[error("Config warnings: {}", .0.iter().map(|w| w.to_string()).collect::<Vec<_>>().join(", "))]
|
#[error("Config warnings: {}", .0.iter().map(|w| w.to_string()).collect::<Vec<_>>().join(", "))]
|
||||||
ConfigWarnings(Vec<ConfigWarning>),
|
ConfigWarnings(Vec<ConfigWarning>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// ツール登録エラー
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ToolRegistryError {
|
||||||
|
/// 同名のツールが既に登録されている
|
||||||
|
#[error("Tool with name '{0}' already registered")]
|
||||||
|
DuplicateName(String),
|
||||||
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// Worker Config
|
// Worker Config
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
@ -59,11 +78,11 @@ pub struct WorkerConfig {
|
||||||
|
|
||||||
/// Workerの実行結果(ステータス)
|
/// Workerの実行結果(ステータス)
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum WorkerResult<'a> {
|
pub enum WorkerResult {
|
||||||
/// 完了(ユーザー入力待ち状態)
|
/// 完了(ユーザー入力待ち状態)
|
||||||
Finished(&'a [Message]),
|
Finished,
|
||||||
/// 一時停止(再開可能)
|
/// 一時停止(再開可能)
|
||||||
Paused(&'a [Message]),
|
Paused,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 内部用: ツール実行結果
|
/// 内部用: ツール実行結果
|
||||||
|
|
@ -77,7 +96,7 @@ enum ToolExecutionResult {
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// ターンイベントを通知するためのコールバック (型消去)
|
/// ターンイベントを通知するためのコールバック (型消去)
|
||||||
trait TurnNotifier: Send {
|
trait TurnNotifier: Send + Sync {
|
||||||
fn on_turn_start(&self, turn: usize);
|
fn on_turn_start(&self, turn: usize);
|
||||||
fn on_turn_end(&self, turn: usize);
|
fn on_turn_end(&self, turn: usize);
|
||||||
}
|
}
|
||||||
|
|
@ -112,7 +131,7 @@ impl<S: WorkerSubscriber + 'static> TurnNotifier for SubscriberTurnNotifier<S> {
|
||||||
/// # 状態遷移(Type-state)
|
/// # 状態遷移(Type-state)
|
||||||
///
|
///
|
||||||
/// - [`Mutable`]: 初期状態。システムプロンプトや履歴を自由に編集可能。
|
/// - [`Mutable`]: 初期状態。システムプロンプトや履歴を自由に編集可能。
|
||||||
/// - [`Locked`]: キャッシュ保護状態。`lock()`で遷移。前方コンテキストは不変。
|
/// - [`CacheLocked`]: キャッシュ保護状態。`lock()`で遷移。前方コンテキストは不変。
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
|
|
@ -147,15 +166,15 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
||||||
text_block_collector: TextBlockCollector,
|
text_block_collector: TextBlockCollector,
|
||||||
/// ツールコールコレクター(Timeline用ハンドラ)
|
/// ツールコールコレクター(Timeline用ハンドラ)
|
||||||
tool_call_collector: ToolCallCollector,
|
tool_call_collector: ToolCallCollector,
|
||||||
/// 登録されたツール
|
/// 登録されたツール (meta, instance)
|
||||||
tools: HashMap<String, Arc<dyn Tool>>,
|
tools: HashMap<String, (ToolMeta, Arc<dyn Tool>)>,
|
||||||
/// 登録されたHook
|
/// Hook レジストリ
|
||||||
hooks: Vec<Box<dyn WorkerHook>>,
|
hooks: HookRegistry,
|
||||||
/// システムプロンプト
|
/// システムプロンプト
|
||||||
system_prompt: Option<String>,
|
system_prompt: Option<String>,
|
||||||
/// メッセージ履歴(Workerが所有)
|
/// メッセージ履歴(Workerが所有)
|
||||||
history: Vec<Message>,
|
history: Vec<Message>,
|
||||||
/// ロック時点での履歴長(Locked状態でのみ意味を持つ)
|
/// ロック時点での履歴長(CacheLocked状態でのみ意味を持つ)
|
||||||
locked_prefix_len: usize,
|
locked_prefix_len: usize,
|
||||||
/// ターンカウント
|
/// ターンカウント
|
||||||
turn_count: usize,
|
turn_count: usize,
|
||||||
|
|
@ -163,6 +182,11 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
||||||
turn_notifiers: Vec<Box<dyn TurnNotifier>>,
|
turn_notifiers: Vec<Box<dyn TurnNotifier>>,
|
||||||
/// リクエスト設定(max_tokens, temperature等)
|
/// リクエスト設定(max_tokens, temperature等)
|
||||||
request_config: RequestConfig,
|
request_config: RequestConfig,
|
||||||
|
/// 前回の実行が中断されたかどうか
|
||||||
|
last_run_interrupted: bool,
|
||||||
|
/// キャンセル通知用チャネル(実行中断用)
|
||||||
|
cancel_tx: mpsc::Sender<()>,
|
||||||
|
cancel_rx: mpsc::Receiver<()>,
|
||||||
/// 状態マーカー
|
/// 状態マーカー
|
||||||
_state: PhantomData<S>,
|
_state: PhantomData<S>,
|
||||||
}
|
}
|
||||||
|
|
@ -172,6 +196,57 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
|
fn reset_interruption_state(&mut self) {
|
||||||
|
self.last_run_interrupted = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// ターンを実行
|
||||||
|
///
|
||||||
|
/// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。
|
||||||
|
/// ツール呼び出しがある場合は自動的にループする。
|
||||||
|
pub async fn run(
|
||||||
|
&mut self,
|
||||||
|
user_input: impl Into<String>,
|
||||||
|
) -> Result<WorkerResult, WorkerError> {
|
||||||
|
self.reset_interruption_state();
|
||||||
|
// Hook: on_prompt_submit
|
||||||
|
let mut user_message = Message::user(user_input);
|
||||||
|
let result = self.run_on_prompt_submit_hooks(&mut user_message).await;
|
||||||
|
let result = match result {
|
||||||
|
Ok(value) => value,
|
||||||
|
Err(err) => return self.finalize_interruption(Err(err)).await,
|
||||||
|
};
|
||||||
|
match result {
|
||||||
|
OnPromptSubmitResult::Cancel(reason) => {
|
||||||
|
self.last_run_interrupted = true;
|
||||||
|
return self.finalize_interruption(Err(WorkerError::Aborted(reason))).await;
|
||||||
|
}
|
||||||
|
OnPromptSubmitResult::Continue => {}
|
||||||
|
}
|
||||||
|
self.history.push(user_message);
|
||||||
|
let result = self.run_turn_loop().await;
|
||||||
|
self.finalize_interruption(result).await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn drain_cancel_queue(&mut self) {
|
||||||
|
use tokio::sync::mpsc::error::TryRecvError;
|
||||||
|
loop {
|
||||||
|
match self.cancel_rx.try_recv() {
|
||||||
|
Ok(()) => continue,
|
||||||
|
Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_cancelled(&mut self) -> bool {
|
||||||
|
use tokio::sync::mpsc::error::TryRecvError;
|
||||||
|
match self.cancel_rx.try_recv() {
|
||||||
|
Ok(()) => true,
|
||||||
|
Err(TryRecvError::Empty) => false,
|
||||||
|
Err(TryRecvError::Disconnected) => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// イベント購読者を登録する
|
/// イベント購読者を登録する
|
||||||
///
|
///
|
||||||
/// 登録したSubscriberは、LLMからのストリーミングイベントを
|
/// 登録したSubscriberは、LLMからのストリーミングイベントを
|
||||||
|
|
@ -230,52 +305,71 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
/// ツールを登録する
|
/// ツールを登録する
|
||||||
///
|
///
|
||||||
/// 登録されたツールはLLMからの呼び出しで自動的に実行されます。
|
/// 登録されたツールはLLMからの呼び出しで自動的に実行されます。
|
||||||
/// 同名のツールを登録した場合、後から登録したものが優先されます。
|
/// 同名のツールを登録するとエラーになります。
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
/// ```ignore
|
/// ```ignore
|
||||||
/// use llm_worker::Worker;
|
/// use llm_worker::tool::{ToolMeta, ToolDefinition, Tool};
|
||||||
/// use my_tools::SearchTool;
|
/// use std::sync::Arc;
|
||||||
///
|
///
|
||||||
/// worker.register_tool(SearchTool::new());
|
/// let def: ToolDefinition = Arc::new(|| {
|
||||||
|
/// (ToolMeta::new("search").description("..."), Arc::new(MyTool) as Arc<dyn Tool>)
|
||||||
|
/// });
|
||||||
|
/// worker.register_tool(def)?;
|
||||||
/// ```
|
/// ```
|
||||||
pub fn register_tool(&mut self, tool: impl Tool + 'static) {
|
pub fn register_tool(&mut self, factory: ToolDefinition) -> Result<(), ToolRegistryError> {
|
||||||
let name = tool.name().to_string();
|
let (meta, instance) = factory();
|
||||||
self.tools.insert(name, Arc::new(tool));
|
if self.tools.contains_key(&meta.name) {
|
||||||
|
return Err(ToolRegistryError::DuplicateName(meta.name.clone()));
|
||||||
|
}
|
||||||
|
self.tools.insert(meta.name.clone(), (meta, instance));
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 複数のツールを登録
|
/// 複数のツールを登録
|
||||||
pub fn register_tools(&mut self, tools: impl IntoIterator<Item = impl Tool + 'static>) {
|
pub fn register_tools(
|
||||||
for tool in tools {
|
&mut self,
|
||||||
self.register_tool(tool);
|
factories: impl IntoIterator<Item = ToolDefinition>,
|
||||||
|
) -> Result<(), ToolRegistryError> {
|
||||||
|
for factory in factories {
|
||||||
|
self.register_tool(factory)?;
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Hookを追加する
|
/// on_prompt_submit Hookを追加する
|
||||||
///
|
///
|
||||||
/// Hookはターンの進行・ツール実行に介入できます。
|
/// `run()` でユーザーメッセージを受け取った直後に呼び出される。
|
||||||
/// 複数のHookを登録した場合、登録順に実行されます。
|
pub fn add_on_prompt_submit_hook(&mut self, hook: impl Hook<OnPromptSubmit> + 'static) {
|
||||||
|
self.hooks.on_prompt_submit.push(Box::new(hook));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// pre_llm_request Hookを追加する
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// 各ターンのLLMリクエスト送信前に呼び出される。
|
||||||
///
|
pub fn add_pre_llm_request_hook(&mut self, hook: impl Hook<PreLlmRequest> + 'static) {
|
||||||
/// ```ignore
|
self.hooks.pre_llm_request.push(Box::new(hook));
|
||||||
/// use llm_worker::{Worker, WorkerHook, ControlFlow, ToolCall};
|
}
|
||||||
///
|
|
||||||
/// struct LoggingHook;
|
/// pre_tool_call Hookを追加する
|
||||||
///
|
pub fn add_pre_tool_call_hook(&mut self, hook: impl Hook<PreToolCall> + 'static) {
|
||||||
/// #[async_trait::async_trait]
|
self.hooks.pre_tool_call.push(Box::new(hook));
|
||||||
/// impl WorkerHook for LoggingHook {
|
}
|
||||||
/// async fn before_tool_call(&self, call: &mut ToolCall) -> Result<ControlFlow, HookError> {
|
|
||||||
/// println!("Calling tool: {}", call.name);
|
/// post_tool_call Hookを追加する
|
||||||
/// Ok(ControlFlow::Continue)
|
pub fn add_post_tool_call_hook(&mut self, hook: impl Hook<PostToolCall> + 'static) {
|
||||||
/// }
|
self.hooks.post_tool_call.push(Box::new(hook));
|
||||||
/// }
|
}
|
||||||
///
|
|
||||||
/// worker.add_hook(LoggingHook);
|
/// on_turn_end Hookを追加する
|
||||||
/// ```
|
pub fn add_on_turn_end_hook(&mut self, hook: impl Hook<OnTurnEnd> + 'static) {
|
||||||
pub fn add_hook(&mut self, hook: impl WorkerHook + 'static) {
|
self.hooks.on_turn_end.push(Box::new(hook));
|
||||||
self.hooks.push(Box::new(hook));
|
}
|
||||||
|
|
||||||
|
/// on_abort Hookを追加する
|
||||||
|
pub fn add_on_abort_hook(&mut self, hook: impl Hook<OnAbort> + 'static) {
|
||||||
|
self.hooks.on_abort.push(Box::new(hook));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// タイムラインへの可変参照を取得(追加ハンドラ登録用)
|
/// タイムラインへの可変参照を取得(追加ハンドラ登録用)
|
||||||
|
|
@ -370,19 +464,59 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
self.request_config.stop_sequences.clear();
|
self.request_config.stop_sequences.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// キャンセル通知用Senderを取得する
|
||||||
|
pub fn cancel_sender(&self) -> mpsc::Sender<()> {
|
||||||
|
self.cancel_tx.clone()
|
||||||
|
}
|
||||||
|
|
||||||
/// リクエスト設定を一括で設定
|
/// リクエスト設定を一括で設定
|
||||||
pub fn set_request_config(&mut self, config: RequestConfig) {
|
pub fn set_request_config(&mut self, config: RequestConfig) {
|
||||||
self.request_config = config;
|
self.request_config = config;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 登録されたツールからToolDefinitionのリストを生成
|
/// 実行をキャンセルする
|
||||||
fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
|
///
|
||||||
|
/// 現在実行中のストリーミングやツール実行を中断します。
|
||||||
|
/// 次のイベントループのチェックポイントでWorkerError::Cancelledが返されます。
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```ignore
|
||||||
|
/// use std::sync::Arc;
|
||||||
|
/// let worker = Arc::new(Mutex::new(Worker::new(client)));
|
||||||
|
///
|
||||||
|
/// // 別スレッドで実行
|
||||||
|
/// let worker_clone = worker.clone();
|
||||||
|
/// tokio::spawn(async move {
|
||||||
|
/// let mut w = worker_clone.lock().unwrap();
|
||||||
|
/// w.run("Long task...").await
|
||||||
|
/// });
|
||||||
|
///
|
||||||
|
/// // キャンセル
|
||||||
|
/// worker.lock().unwrap().cancel();
|
||||||
|
/// ```
|
||||||
|
pub fn cancel(&self) {
|
||||||
|
let _ = self.cancel_tx.try_send(());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// キャンセルされているかチェック
|
||||||
|
pub fn is_cancelled(&mut self) -> bool {
|
||||||
|
self.try_cancelled()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 前回の実行が中断されたかどうか
|
||||||
|
pub fn last_run_interrupted(&self) -> bool {
|
||||||
|
self.last_run_interrupted
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 登録されたツールからLLM用ToolDefinitionのリストを生成
|
||||||
|
fn build_tool_definitions(&self) -> Vec<LlmToolDefinition> {
|
||||||
self.tools
|
self.tools
|
||||||
.values()
|
.values()
|
||||||
.map(|tool| {
|
.map(|(meta, _)| {
|
||||||
ToolDefinition::new(tool.name())
|
LlmToolDefinition::new(&meta.name)
|
||||||
.description(tool.description())
|
.description(&meta.description)
|
||||||
.input_schema(tool.input_schema())
|
.input_schema(meta.input_schema.clone())
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
@ -430,7 +564,11 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// リクエストを構築
|
/// リクエストを構築
|
||||||
fn build_request(&self, tool_definitions: &[ToolDefinition]) -> Request {
|
fn build_request(
|
||||||
|
&self,
|
||||||
|
tool_definitions: &[LlmToolDefinition],
|
||||||
|
context: &[Message],
|
||||||
|
) -> Request {
|
||||||
let mut request = Request::new();
|
let mut request = Request::new();
|
||||||
|
|
||||||
// システムプロンプトを設定
|
// システムプロンプトを設定
|
||||||
|
|
@ -439,7 +577,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// メッセージを追加
|
// メッセージを追加
|
||||||
for msg in &self.history {
|
for msg in context {
|
||||||
// Message から llm_client::Message への変換
|
// Message から llm_client::Message への変換
|
||||||
request = request.message(crate::llm_client::Message {
|
request = request.message(crate::llm_client::Message {
|
||||||
role: match msg.role {
|
role: match msg.role {
|
||||||
|
|
@ -494,37 +632,89 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
request
|
request
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Hooks: on_message_send
|
/// Hooks: on_prompt_submit
|
||||||
async fn run_on_message_send_hooks(&self) -> Result<ControlFlow, WorkerError> {
|
///
|
||||||
for hook in &self.hooks {
|
/// `run()` でユーザーメッセージを受け取った直後に呼び出される(最初だけ)。
|
||||||
// Note: Locked状態でも履歴全体を参照として渡す(変更は不可)
|
async fn run_on_prompt_submit_hooks(
|
||||||
// HookのAPIを変更し、immutable参照のみを渡すようにする必要があるかもしれない
|
&self,
|
||||||
// 現在は空のVecを渡して回避(要検討)
|
message: &mut Message,
|
||||||
let mut temp_context = self.history.clone();
|
) -> Result<OnPromptSubmitResult, WorkerError> {
|
||||||
let result = hook.on_message_send(&mut temp_context).await?;
|
for hook in &self.hooks.on_prompt_submit {
|
||||||
|
let result = hook.call(message).await?;
|
||||||
match result {
|
match result {
|
||||||
ControlFlow::Continue => continue,
|
OnPromptSubmitResult::Continue => continue,
|
||||||
ControlFlow::Skip => return Ok(ControlFlow::Skip),
|
OnPromptSubmitResult::Cancel(reason) => {
|
||||||
ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)),
|
return Ok(OnPromptSubmitResult::Cancel(reason));
|
||||||
ControlFlow::Pause => return Ok(ControlFlow::Pause),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(ControlFlow::Continue)
|
}
|
||||||
|
Ok(OnPromptSubmitResult::Continue)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Hooks: pre_llm_request
|
||||||
|
///
|
||||||
|
/// 各ターンのLLMリクエスト送信前に呼び出される(毎ターン)。
|
||||||
|
async fn run_pre_llm_request_hooks(
|
||||||
|
&self,
|
||||||
|
) -> Result<(PreLlmRequestResult, Vec<Message>), WorkerError> {
|
||||||
|
let mut temp_context = self.history.clone();
|
||||||
|
for hook in &self.hooks.pre_llm_request {
|
||||||
|
let result = hook.call(&mut temp_context).await?;
|
||||||
|
match result {
|
||||||
|
PreLlmRequestResult::Continue => continue,
|
||||||
|
PreLlmRequestResult::Cancel(reason) => {
|
||||||
|
return Ok((PreLlmRequestResult::Cancel(reason), temp_context));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((PreLlmRequestResult::Continue, temp_context))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Hooks: on_turn_end
|
/// Hooks: on_turn_end
|
||||||
async fn run_on_turn_end_hooks(&self) -> Result<TurnResult, WorkerError> {
|
async fn run_on_turn_end_hooks(&self) -> Result<OnTurnEndResult, WorkerError> {
|
||||||
for hook in &self.hooks {
|
let mut temp_messages = self.history.clone();
|
||||||
let result = hook.on_turn_end(&self.history).await?;
|
for hook in &self.hooks.on_turn_end {
|
||||||
|
let result = hook.call(&mut temp_messages).await?;
|
||||||
match result {
|
match result {
|
||||||
TurnResult::Finish => continue,
|
OnTurnEndResult::Finish => continue,
|
||||||
TurnResult::ContinueWithMessages(msgs) => {
|
OnTurnEndResult::ContinueWithMessages(msgs) => {
|
||||||
return Ok(TurnResult::ContinueWithMessages(msgs));
|
return Ok(OnTurnEndResult::ContinueWithMessages(msgs));
|
||||||
}
|
}
|
||||||
TurnResult::Paused => return Ok(TurnResult::Paused),
|
OnTurnEndResult::Paused => return Ok(OnTurnEndResult::Paused),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(OnTurnEndResult::Finish)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Hooks: on_abort
|
||||||
|
async fn run_on_abort_hooks(&self, reason: &str) -> Result<(), WorkerError> {
|
||||||
|
let mut reason = reason.to_string();
|
||||||
|
for hook in &self.hooks.on_abort {
|
||||||
|
hook.call(&mut reason).await?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn finalize_interruption<T>(
|
||||||
|
&mut self,
|
||||||
|
result: Result<T, WorkerError>,
|
||||||
|
) -> Result<T, WorkerError> {
|
||||||
|
match result {
|
||||||
|
Ok(value) => Ok(value),
|
||||||
|
Err(err) => {
|
||||||
|
self.last_run_interrupted = true;
|
||||||
|
let reason = match &err {
|
||||||
|
WorkerError::Aborted(reason) => reason.clone(),
|
||||||
|
WorkerError::Cancelled => "Cancelled".to_string(),
|
||||||
|
_ => err.to_string(),
|
||||||
|
};
|
||||||
|
if let Err(hook_err) = self.run_on_abort_hooks(&reason).await {
|
||||||
|
self.last_run_interrupted = true;
|
||||||
|
return Err(hook_err);
|
||||||
|
}
|
||||||
|
Err(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(TurnResult::Finish)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 未実行のツール呼び出しがあるかチェック(Pauseからの復帰用)
|
/// 未実行のツール呼び出しがあるかチェック(Pauseからの復帰用)
|
||||||
|
|
@ -547,55 +737,83 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if calls.is_empty() {
|
if calls.is_empty() { None } else { Some(calls) }
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(calls)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ツールを並列実行
|
/// ツールを並列実行
|
||||||
///
|
///
|
||||||
/// 全てのツールに対してbefore_tool_callフックを実行後、
|
/// 全てのツールに対してpre_tool_callフックを実行後、
|
||||||
/// 許可されたツールを並列に実行し、結果にafter_tool_callフックを適用する。
|
/// 許可されたツールを並列に実行し、結果にpost_tool_callフックを適用する。
|
||||||
async fn execute_tools(
|
async fn execute_tools(
|
||||||
&self,
|
&mut self,
|
||||||
tool_calls: Vec<ToolCall>,
|
tool_calls: Vec<ToolCall>,
|
||||||
) -> Result<ToolExecutionResult, WorkerError> {
|
) -> Result<ToolExecutionResult, WorkerError> {
|
||||||
use futures::future::join_all;
|
use futures::future::join_all;
|
||||||
|
|
||||||
// Phase 1: before_tool_call フックを適用(スキップ/中断を判定)
|
// ツール呼び出しIDから (ToolCall, Meta, Tool) へのマップ
|
||||||
|
// PostToolCallフックで必要になるため保持する
|
||||||
|
let mut call_info_map = HashMap::new();
|
||||||
|
|
||||||
|
// Phase 1: pre_tool_call フックを適用(スキップ/中断を判定)
|
||||||
let mut approved_calls = Vec::new();
|
let mut approved_calls = Vec::new();
|
||||||
for mut tool_call in tool_calls {
|
for mut tool_call in tool_calls {
|
||||||
|
// ツール定義を取得
|
||||||
|
if let Some((meta, tool)) = self.tools.get(&tool_call.name) {
|
||||||
|
// コンテキストを作成
|
||||||
|
let mut context = ToolCallContext {
|
||||||
|
call: tool_call.clone(),
|
||||||
|
meta: meta.clone(),
|
||||||
|
tool: tool.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
let mut skip = false;
|
let mut skip = false;
|
||||||
for hook in &self.hooks {
|
for hook in &self.hooks.pre_tool_call {
|
||||||
let result = hook.before_tool_call(&mut tool_call).await?;
|
let result = hook
|
||||||
|
.call(&mut context)
|
||||||
|
.await
|
||||||
|
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||||
match result {
|
match result {
|
||||||
ControlFlow::Continue => {}
|
PreToolCallResult::Continue => {}
|
||||||
ControlFlow::Skip => {
|
PreToolCallResult::Skip => {
|
||||||
skip = true;
|
skip = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
ControlFlow::Abort(reason) => {
|
PreToolCallResult::Abort(reason) => {
|
||||||
|
self.last_run_interrupted = true;
|
||||||
return Err(WorkerError::Aborted(reason));
|
return Err(WorkerError::Aborted(reason));
|
||||||
}
|
}
|
||||||
ControlFlow::Pause => {
|
PreToolCallResult::Pause => {
|
||||||
|
self.last_run_interrupted = true;
|
||||||
return Ok(ToolExecutionResult::Paused);
|
return Ok(ToolExecutionResult::Paused);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// フックで変更された内容を反映
|
||||||
|
tool_call = context.call;
|
||||||
|
|
||||||
|
// マップに保存(実行する場合のみ)
|
||||||
if !skip {
|
if !skip {
|
||||||
|
call_info_map.insert(
|
||||||
|
tool_call.id.clone(),
|
||||||
|
(tool_call.clone(), meta.clone(), tool.clone()),
|
||||||
|
);
|
||||||
|
approved_calls.push(tool_call);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 未知のツールはそのまま承認リストに入れる(実行時にエラーになる)
|
||||||
|
// Hookは適用しない(Metaがないため)
|
||||||
approved_calls.push(tool_call);
|
approved_calls.push(tool_call);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Phase 2: 許可されたツールを並列実行
|
// Phase 2: 許可されたツールを並列実行(キャンセル可能)
|
||||||
let futures: Vec<_> = approved_calls
|
let futures: Vec<_> = approved_calls
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|tool_call| {
|
.map(|tool_call| {
|
||||||
let tools = &self.tools;
|
let tools = &self.tools;
|
||||||
async move {
|
async move {
|
||||||
if let Some(tool) = tools.get(&tool_call.name) {
|
if let Some((_, tool)) = tools.get(&tool_call.name) {
|
||||||
let input_json =
|
let input_json =
|
||||||
serde_json::to_string(&tool_call.input).unwrap_or_default();
|
serde_json::to_string(&tool_call.input).unwrap_or_default();
|
||||||
match tool.execute(&input_json).await {
|
match tool.execute(&input_json).await {
|
||||||
|
|
@ -612,26 +830,45 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let mut results = join_all(futures).await;
|
// ツール実行をキャンセル可能にする
|
||||||
|
let mut results = tokio::select! {
|
||||||
|
results = join_all(futures) => results,
|
||||||
|
cancel = self.cancel_rx.recv() => {
|
||||||
|
if cancel.is_some() {
|
||||||
|
info!("Tool execution cancelled");
|
||||||
|
}
|
||||||
|
self.timeline.abort_current_block();
|
||||||
|
self.last_run_interrupted = true;
|
||||||
|
return Err(WorkerError::Cancelled);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Phase 3: after_tool_call フックを適用
|
// Phase 3: post_tool_call フックを適用
|
||||||
for tool_result in &mut results {
|
for tool_result in &mut results {
|
||||||
for hook in &self.hooks {
|
// 保存しておいた情報を取得
|
||||||
let result = hook.after_tool_call(tool_result).await?;
|
if let Some((tool_call, meta, tool)) = call_info_map.get(&tool_result.tool_use_id) {
|
||||||
|
let mut context = PostToolCallContext {
|
||||||
|
call: tool_call.clone(),
|
||||||
|
result: tool_result.clone(),
|
||||||
|
meta: meta.clone(),
|
||||||
|
tool: tool.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
for hook in &self.hooks.post_tool_call {
|
||||||
|
let result = hook
|
||||||
|
.call(&mut context)
|
||||||
|
.await
|
||||||
|
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||||
match result {
|
match result {
|
||||||
ControlFlow::Continue => {}
|
PostToolCallResult::Continue => {}
|
||||||
ControlFlow::Skip => break,
|
PostToolCallResult::Abort(reason) => {
|
||||||
ControlFlow::Abort(reason) => {
|
self.last_run_interrupted = true;
|
||||||
return Err(WorkerError::Aborted(reason));
|
return Err(WorkerError::Aborted(reason));
|
||||||
}
|
}
|
||||||
ControlFlow::Pause => {
|
|
||||||
// after_tool_callでのPauseは結果を受け入れた後、次の処理前に止まる動作とする
|
|
||||||
// ここではContinue扱いとし、on_message_send等でPauseすることを期待する
|
|
||||||
// あるいはここでのPauseをサポートする場合は戻り値を調整する必要がある
|
|
||||||
// 現状はログを出してContinue
|
|
||||||
warn!("ControlFlow::Pause in after_tool_call is treated as Continue");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// フックで変更された結果を反映
|
||||||
|
*tool_result = context.result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -639,7 +876,9 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 内部で使用するターン実行ロジック
|
/// 内部で使用するターン実行ロジック
|
||||||
async fn run_turn_loop(&mut self) -> Result<WorkerResult<'_>, WorkerError> {
|
async fn run_turn_loop(&mut self) -> Result<WorkerResult, WorkerError> {
|
||||||
|
self.reset_interruption_state();
|
||||||
|
self.drain_cancel_queue();
|
||||||
let tool_definitions = self.build_tool_definitions();
|
let tool_definitions = self.build_tool_definitions();
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
|
|
@ -651,18 +890,34 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
// Resume check: Pending tool calls
|
// Resume check: Pending tool calls
|
||||||
if let Some(tool_calls) = self.get_pending_tool_calls() {
|
if let Some(tool_calls) = self.get_pending_tool_calls() {
|
||||||
info!("Resuming pending tool calls");
|
info!("Resuming pending tool calls");
|
||||||
match self.execute_tools(tool_calls).await? {
|
match self.execute_tools(tool_calls).await {
|
||||||
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)),
|
Ok(ToolExecutionResult::Paused) => {
|
||||||
ToolExecutionResult::Completed(results) => {
|
self.last_run_interrupted = true;
|
||||||
|
return Ok(WorkerResult::Paused);
|
||||||
|
}
|
||||||
|
Ok(ToolExecutionResult::Completed(results)) => {
|
||||||
for result in results {
|
for result in results {
|
||||||
self.history.push(Message::tool_result(&result.tool_use_id, &result.content));
|
self.history
|
||||||
|
.push(Message::tool_result(&result.tool_use_id, &result.content));
|
||||||
}
|
}
|
||||||
// Continue to loop
|
// Continue to loop
|
||||||
}
|
}
|
||||||
|
Err(err) => {
|
||||||
|
self.last_run_interrupted = true;
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
|
// キャンセルチェック
|
||||||
|
if self.try_cancelled() {
|
||||||
|
info!("Execution cancelled");
|
||||||
|
self.timeline.abort_current_block();
|
||||||
|
self.last_run_interrupted = true;
|
||||||
|
return Err(WorkerError::Cancelled);
|
||||||
|
}
|
||||||
|
|
||||||
// ターン開始を通知
|
// ターン開始を通知
|
||||||
let current_turn = self.turn_count;
|
let current_turn = self.turn_count;
|
||||||
debug!(turn = current_turn, "Turn start");
|
debug!(turn = current_turn, "Turn start");
|
||||||
|
|
@ -670,25 +925,25 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
notifier.on_turn_start(current_turn);
|
notifier.on_turn_start(current_turn);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hook: on_message_send
|
// Hook: pre_llm_request
|
||||||
let control = self.run_on_message_send_hooks().await?;
|
let (control, request_context) = self
|
||||||
|
.run_pre_llm_request_hooks()
|
||||||
|
.await
|
||||||
|
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||||
match control {
|
match control {
|
||||||
ControlFlow::Abort(reason) => {
|
PreLlmRequestResult::Cancel(reason) => {
|
||||||
warn!(reason = %reason, "Aborted by hook");
|
info!(reason = %reason, "Aborted by hook");
|
||||||
for notifier in &self.turn_notifiers {
|
for notifier in &self.turn_notifiers {
|
||||||
notifier.on_turn_end(current_turn);
|
notifier.on_turn_end(current_turn);
|
||||||
}
|
}
|
||||||
|
self.last_run_interrupted = true;
|
||||||
return Err(WorkerError::Aborted(reason));
|
return Err(WorkerError::Aborted(reason));
|
||||||
}
|
}
|
||||||
ControlFlow::Pause | ControlFlow::Skip => {
|
PreLlmRequestResult::Continue => {}
|
||||||
// Skip or Pause -> Pause the worker
|
|
||||||
return Ok(WorkerResult::Paused(&self.history));
|
|
||||||
}
|
|
||||||
ControlFlow::Continue => {}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// リクエスト構築
|
// リクエスト構築
|
||||||
let request = self.build_request(&tool_definitions);
|
let request = self.build_request(&tool_definitions, &request_context);
|
||||||
debug!(
|
debug!(
|
||||||
message_count = request.messages.len(),
|
message_count = request.messages.len(),
|
||||||
tool_count = request.tools.len(),
|
tool_count = request.tools.len(),
|
||||||
|
|
@ -698,10 +953,29 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
|
|
||||||
// ストリーム処理
|
// ストリーム処理
|
||||||
debug!("Starting stream...");
|
debug!("Starting stream...");
|
||||||
let mut stream = self.client.stream(request).await?;
|
|
||||||
let mut event_count = 0;
|
let mut event_count = 0;
|
||||||
while let Some(event_result) = stream.next().await {
|
|
||||||
match &event_result {
|
// ストリームを取得(キャンセル可能)
|
||||||
|
let mut stream = tokio::select! {
|
||||||
|
stream_result = self.client.stream(request) => stream_result
|
||||||
|
.inspect_err(|_| self.last_run_interrupted = true)?,
|
||||||
|
cancel = self.cancel_rx.recv() => {
|
||||||
|
if cancel.is_some() {
|
||||||
|
info!("Cancelled before stream started");
|
||||||
|
}
|
||||||
|
self.timeline.abort_current_block();
|
||||||
|
self.last_run_interrupted = true;
|
||||||
|
return Err(WorkerError::Cancelled);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
// ストリームからイベントを受信
|
||||||
|
event_result = stream.next() => {
|
||||||
|
match event_result {
|
||||||
|
Some(result) => {
|
||||||
|
match &result {
|
||||||
Ok(event) => {
|
Ok(event) => {
|
||||||
trace!(event = ?event, "Received event");
|
trace!(event = ?event, "Received event");
|
||||||
event_count += 1;
|
event_count += 1;
|
||||||
|
|
@ -710,10 +984,25 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
warn!(error = %e, "Stream error");
|
warn!(error = %e, "Stream error");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let event = event_result?;
|
let event = result
|
||||||
|
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||||
let timeline_event: crate::timeline::event::Event = event.into();
|
let timeline_event: crate::timeline::event::Event = event.into();
|
||||||
self.timeline.dispatch(&timeline_event);
|
self.timeline.dispatch(&timeline_event);
|
||||||
}
|
}
|
||||||
|
None => break, // ストリーム終了
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// キャンセル待機
|
||||||
|
cancel = self.cancel_rx.recv() => {
|
||||||
|
if cancel.is_some() {
|
||||||
|
info!("Stream cancelled");
|
||||||
|
}
|
||||||
|
self.timeline.abort_current_block();
|
||||||
|
self.last_run_interrupted = true;
|
||||||
|
return Err(WorkerError::Cancelled);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
debug!(event_count = event_count, "Stream completed");
|
debug!(event_count = event_count, "Stream completed");
|
||||||
|
|
||||||
// ターン終了を通知
|
// ターン終了を通知
|
||||||
|
|
@ -734,28 +1023,41 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
|
|
||||||
if tool_calls.is_empty() {
|
if tool_calls.is_empty() {
|
||||||
// ツール呼び出しなし → ターン終了判定
|
// ツール呼び出しなし → ターン終了判定
|
||||||
let turn_result = self.run_on_turn_end_hooks().await?;
|
let turn_result = self
|
||||||
|
.run_on_turn_end_hooks()
|
||||||
|
.await
|
||||||
|
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||||
match turn_result {
|
match turn_result {
|
||||||
TurnResult::Finish => {
|
OnTurnEndResult::Finish => {
|
||||||
return Ok(WorkerResult::Finished(&self.history));
|
self.last_run_interrupted = false;
|
||||||
|
return Ok(WorkerResult::Finished);
|
||||||
}
|
}
|
||||||
TurnResult::ContinueWithMessages(additional) => {
|
OnTurnEndResult::ContinueWithMessages(additional) => {
|
||||||
self.history.extend(additional);
|
self.history.extend(additional);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
TurnResult::Paused => {
|
OnTurnEndResult::Paused => {
|
||||||
return Ok(WorkerResult::Paused(&self.history));
|
self.last_run_interrupted = true;
|
||||||
|
return Ok(WorkerResult::Paused);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ツール実行
|
// ツール実行
|
||||||
match self.execute_tools(tool_calls).await? {
|
match self.execute_tools(tool_calls).await {
|
||||||
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)),
|
Ok(ToolExecutionResult::Paused) => {
|
||||||
ToolExecutionResult::Completed(results) => {
|
self.last_run_interrupted = true;
|
||||||
for result in results {
|
return Ok(WorkerResult::Paused);
|
||||||
self.history.push(Message::tool_result(&result.tool_use_id, &result.content));
|
|
||||||
}
|
}
|
||||||
|
Ok(ToolExecutionResult::Completed(results)) => {
|
||||||
|
for result in results {
|
||||||
|
self.history
|
||||||
|
.push(Message::tool_result(&result.tool_use_id, &result.content));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
self.last_run_interrupted = true;
|
||||||
|
return Err(err);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -764,8 +1066,10 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
/// 実行を再開(Pause状態からの復帰)
|
/// 実行を再開(Pause状態からの復帰)
|
||||||
///
|
///
|
||||||
/// 新しいユーザーメッセージを履歴に追加せず、現在の状態からターン処理を再開する。
|
/// 新しいユーザーメッセージを履歴に追加せず、現在の状態からターン処理を再開する。
|
||||||
pub async fn resume(&mut self) -> Result<WorkerResult<'_>, WorkerError> {
|
pub async fn resume(&mut self) -> Result<WorkerResult, WorkerError> {
|
||||||
self.run_turn_loop().await
|
self.reset_interruption_state();
|
||||||
|
let result = self.run_turn_loop().await;
|
||||||
|
self.finalize_interruption(result).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -779,6 +1083,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
let text_block_collector = TextBlockCollector::new();
|
let text_block_collector = TextBlockCollector::new();
|
||||||
let tool_call_collector = ToolCallCollector::new();
|
let tool_call_collector = ToolCallCollector::new();
|
||||||
let mut timeline = Timeline::new();
|
let mut timeline = Timeline::new();
|
||||||
|
let (cancel_tx, cancel_rx) = mpsc::channel(1);
|
||||||
|
|
||||||
// コレクターをTimelineに登録
|
// コレクターをTimelineに登録
|
||||||
timeline.on_text_block(text_block_collector.clone());
|
timeline.on_text_block(text_block_collector.clone());
|
||||||
|
|
@ -790,13 +1095,16 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
text_block_collector,
|
text_block_collector,
|
||||||
tool_call_collector,
|
tool_call_collector,
|
||||||
tools: HashMap::new(),
|
tools: HashMap::new(),
|
||||||
hooks: Vec::new(),
|
hooks: HookRegistry::new(),
|
||||||
system_prompt: None,
|
system_prompt: None,
|
||||||
history: Vec::new(),
|
history: Vec::new(),
|
||||||
locked_prefix_len: 0,
|
locked_prefix_len: 0,
|
||||||
turn_count: 0,
|
turn_count: 0,
|
||||||
turn_notifiers: Vec::new(),
|
turn_notifiers: Vec::new(),
|
||||||
request_config: RequestConfig::default(),
|
request_config: RequestConfig::default(),
|
||||||
|
last_run_interrupted: false,
|
||||||
|
cancel_tx,
|
||||||
|
cancel_rx,
|
||||||
_state: PhantomData,
|
_state: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -946,11 +1254,11 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ロックしてLocked状態へ遷移
|
/// ロックしてCacheLocked状態へ遷移
|
||||||
///
|
///
|
||||||
/// この操作により、現在のシステムプロンプトと履歴が「確定済みプレフィックス」として
|
/// この操作により、現在のシステムプロンプトと履歴が「確定済みプレフィックス」として
|
||||||
/// 固定される。以降は履歴への追記のみが可能となり、キャッシュヒットが保証される。
|
/// 固定される。以降は履歴への追記のみが可能となり、キャッシュヒットが保証される。
|
||||||
pub fn lock(self) -> Worker<C, Locked> {
|
pub fn lock(self) -> Worker<C, CacheLocked> {
|
||||||
let locked_prefix_len = self.history.len();
|
let locked_prefix_len = self.history.len();
|
||||||
Worker {
|
Worker {
|
||||||
client: self.client,
|
client: self.client,
|
||||||
|
|
@ -965,57 +1273,20 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
turn_count: self.turn_count,
|
turn_count: self.turn_count,
|
||||||
turn_notifiers: self.turn_notifiers,
|
turn_notifiers: self.turn_notifiers,
|
||||||
request_config: self.request_config,
|
request_config: self.request_config,
|
||||||
|
last_run_interrupted: self.last_run_interrupted,
|
||||||
|
cancel_tx: self.cancel_tx,
|
||||||
|
cancel_rx: self.cancel_rx,
|
||||||
_state: PhantomData,
|
_state: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ターンを実行(Mutable状態)
|
|
||||||
///
|
|
||||||
/// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。
|
|
||||||
/// ツール呼び出しがある場合は自動的にループする。
|
|
||||||
///
|
|
||||||
/// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は
|
|
||||||
/// `lock()` を呼んでからLocked状態で `run` を使用すること。
|
|
||||||
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<WorkerResult<'_>, WorkerError> {
|
|
||||||
self.history.push(Message::user(user_input));
|
|
||||||
self.run_turn_loop().await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 複数メッセージでターンを実行(Mutable状態)
|
|
||||||
///
|
|
||||||
/// 指定されたメッセージを履歴に追加してから実行する。
|
|
||||||
pub async fn run_with_messages(
|
|
||||||
&mut self,
|
|
||||||
messages: Vec<Message>,
|
|
||||||
) -> Result<WorkerResult<'_>, WorkerError> {
|
|
||||||
self.history.extend(messages);
|
|
||||||
self.run_turn_loop().await
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// Locked状態専用の実装
|
// CacheLocked状態専用の実装
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
impl<C: LlmClient> Worker<C, Locked> {
|
impl<C: LlmClient> Worker<C, CacheLocked> {
|
||||||
/// ターンを実行(Locked状態)
|
|
||||||
///
|
|
||||||
/// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。
|
|
||||||
/// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。
|
|
||||||
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<WorkerResult<'_>, WorkerError> {
|
|
||||||
self.history.push(Message::user(user_input));
|
|
||||||
self.run_turn_loop().await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 複数メッセージでターンを実行(Locked状態)
|
|
||||||
pub async fn run_with_messages(
|
|
||||||
&mut self,
|
|
||||||
messages: Vec<Message>,
|
|
||||||
) -> Result<WorkerResult<'_>, WorkerError> {
|
|
||||||
self.history.extend(messages);
|
|
||||||
self.run_turn_loop().await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// ロック時点のプレフィックス長を取得
|
/// ロック時点のプレフィックス長を取得
|
||||||
pub fn locked_prefix_len(&self) -> usize {
|
pub fn locked_prefix_len(&self) -> usize {
|
||||||
self.locked_prefix_len
|
self.locked_prefix_len
|
||||||
|
|
@ -1039,6 +1310,9 @@ impl<C: LlmClient> Worker<C, Locked> {
|
||||||
turn_count: self.turn_count,
|
turn_count: self.turn_count,
|
||||||
turn_notifiers: self.turn_notifiers,
|
turn_notifiers: self.turn_notifiers,
|
||||||
request_config: self.request_config,
|
request_config: self.request_config,
|
||||||
|
last_run_interrupted: self.last_run_interrupted,
|
||||||
|
cancel_tx: self.cancel_tx,
|
||||||
|
cancel_rx: self.cancel_rx,
|
||||||
_state: PhantomData,
|
_state: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,9 +8,12 @@ use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use llm_worker::Worker;
|
use llm_worker::Worker;
|
||||||
use llm_worker::hook::{ControlFlow, HookError, ToolCall, ToolResult, WorkerHook};
|
use llm_worker::hook::{
|
||||||
|
Hook, HookError, PostToolCall, PostToolCallContext, PostToolCallResult, PreToolCall,
|
||||||
|
PreToolCallResult, ToolCallContext,
|
||||||
|
};
|
||||||
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
|
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
|
||||||
use llm_worker::tool::{Tool, ToolError};
|
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
||||||
|
|
||||||
mod common;
|
mod common;
|
||||||
use common::MockLlmClient;
|
use common::MockLlmClient;
|
||||||
|
|
@ -39,25 +42,24 @@ impl SlowTool {
|
||||||
fn call_count(&self) -> usize {
|
fn call_count(&self) -> usize {
|
||||||
self.call_count.load(Ordering::SeqCst)
|
self.call_count.load(Ordering::SeqCst)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// ToolDefinition を作成
|
||||||
|
fn definition(&self) -> ToolDefinition {
|
||||||
|
let tool = self.clone();
|
||||||
|
Arc::new(move || {
|
||||||
|
let meta = ToolMeta::new(&tool.name)
|
||||||
|
.description("A tool that waits before responding")
|
||||||
|
.input_schema(serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {}
|
||||||
|
}));
|
||||||
|
(meta, Arc::new(tool.clone()) as Arc<dyn Tool>)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Tool for SlowTool {
|
impl Tool for SlowTool {
|
||||||
fn name(&self) -> &str {
|
|
||||||
&self.name
|
|
||||||
}
|
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
|
||||||
"A tool that waits before responding"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn input_schema(&self) -> serde_json::Value {
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(&self, _input_json: &str) -> Result<String, ToolError> {
|
async fn execute(&self, _input_json: &str) -> Result<String, ToolError> {
|
||||||
self.call_count.fetch_add(1, Ordering::SeqCst);
|
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||||
tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
|
tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
|
||||||
|
|
@ -103,9 +105,9 @@ async fn test_parallel_tool_execution() {
|
||||||
let tool2_clone = tool2.clone();
|
let tool2_clone = tool2.clone();
|
||||||
let tool3_clone = tool3.clone();
|
let tool3_clone = tool3.clone();
|
||||||
|
|
||||||
worker.register_tool(tool1);
|
worker.register_tool(tool1.definition()).unwrap();
|
||||||
worker.register_tool(tool2);
|
worker.register_tool(tool2.definition()).unwrap();
|
||||||
worker.register_tool(tool3);
|
worker.register_tool(tool3.definition()).unwrap();
|
||||||
|
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let _result = worker.run("Run all tools").await;
|
let _result = worker.run("Run all tools").await;
|
||||||
|
|
@ -127,7 +129,7 @@ async fn test_parallel_tool_execution() {
|
||||||
println!("Parallel execution completed in {:?}", elapsed);
|
println!("Parallel execution completed in {:?}", elapsed);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Hook: before_tool_call でスキップされたツールは実行されないことを確認
|
/// Hook: pre_tool_call でスキップされたツールは実行されないことを確認
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_before_tool_call_skip() {
|
async fn test_before_tool_call_skip() {
|
||||||
let events = vec![
|
let events = vec![
|
||||||
|
|
@ -151,27 +153,24 @@ async fn test_before_tool_call_skip() {
|
||||||
let allowed_clone = allowed_tool.clone();
|
let allowed_clone = allowed_tool.clone();
|
||||||
let blocked_clone = blocked_tool.clone();
|
let blocked_clone = blocked_tool.clone();
|
||||||
|
|
||||||
worker.register_tool(allowed_tool);
|
worker.register_tool(allowed_tool.definition()).unwrap();
|
||||||
worker.register_tool(blocked_tool);
|
worker.register_tool(blocked_tool.definition()).unwrap();
|
||||||
|
|
||||||
// "blocked_tool" をスキップするHook
|
// "blocked_tool" をスキップするHook
|
||||||
struct BlockingHook;
|
struct BlockingHook;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl WorkerHook for BlockingHook {
|
impl Hook<PreToolCall> for BlockingHook {
|
||||||
async fn before_tool_call(
|
async fn call(&self, ctx: &mut ToolCallContext) -> Result<PreToolCallResult, HookError> {
|
||||||
&self,
|
if ctx.call.name == "blocked_tool" {
|
||||||
tool_call: &mut ToolCall,
|
Ok(PreToolCallResult::Skip)
|
||||||
) -> Result<ControlFlow, HookError> {
|
|
||||||
if tool_call.name == "blocked_tool" {
|
|
||||||
Ok(ControlFlow::Skip)
|
|
||||||
} else {
|
} else {
|
||||||
Ok(ControlFlow::Continue)
|
Ok(PreToolCallResult::Continue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
worker.add_hook(BlockingHook);
|
worker.add_pre_tool_call_hook(BlockingHook);
|
||||||
|
|
||||||
let _result = worker.run("Test hook").await;
|
let _result = worker.run("Test hook").await;
|
||||||
|
|
||||||
|
|
@ -188,9 +187,9 @@ async fn test_before_tool_call_skip() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Hook: after_tool_call で結果が改変されることを確認
|
/// Hook: post_tool_call で結果が改変されることを確認
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_after_tool_call_modification() {
|
async fn test_post_tool_call_modification() {
|
||||||
// 複数リクエストに対応するレスポンスを準備
|
// 複数リクエストに対応するレスポンスを準備
|
||||||
let client = MockLlmClient::with_responses(vec![
|
let client = MockLlmClient::with_responses(vec![
|
||||||
// 1回目のリクエスト: ツール呼び出し
|
// 1回目のリクエスト: ツール呼び出し
|
||||||
|
|
@ -220,21 +219,21 @@ async fn test_after_tool_call_modification() {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Tool for SimpleTool {
|
impl Tool for SimpleTool {
|
||||||
fn name(&self) -> &str {
|
|
||||||
"test_tool"
|
|
||||||
}
|
|
||||||
fn description(&self) -> &str {
|
|
||||||
"Test"
|
|
||||||
}
|
|
||||||
fn input_schema(&self) -> serde_json::Value {
|
|
||||||
serde_json::json!({})
|
|
||||||
}
|
|
||||||
async fn execute(&self, _: &str) -> Result<String, ToolError> {
|
async fn execute(&self, _: &str) -> Result<String, ToolError> {
|
||||||
Ok("Original Result".to_string())
|
Ok("Original Result".to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
worker.register_tool(SimpleTool);
|
fn simple_tool_definition() -> ToolDefinition {
|
||||||
|
Arc::new(|| {
|
||||||
|
let meta = ToolMeta::new("test_tool")
|
||||||
|
.description("Test")
|
||||||
|
.input_schema(serde_json::json!({}));
|
||||||
|
(meta, Arc::new(SimpleTool) as Arc<dyn Tool>)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
worker.register_tool(simple_tool_definition()).unwrap();
|
||||||
|
|
||||||
// 結果を改変するHook
|
// 結果を改変するHook
|
||||||
struct ModifyingHook {
|
struct ModifyingHook {
|
||||||
|
|
@ -242,19 +241,19 @@ async fn test_after_tool_call_modification() {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl WorkerHook for ModifyingHook {
|
impl Hook<PostToolCall> for ModifyingHook {
|
||||||
async fn after_tool_call(
|
async fn call(
|
||||||
&self,
|
&self,
|
||||||
tool_result: &mut ToolResult,
|
ctx: &mut PostToolCallContext,
|
||||||
) -> Result<ControlFlow, HookError> {
|
) -> Result<PostToolCallResult, HookError> {
|
||||||
tool_result.content = format!("[Modified] {}", tool_result.content);
|
ctx.result.content = format!("[Modified] {}", ctx.result.content);
|
||||||
*self.modified_content.lock().unwrap() = Some(tool_result.content.clone());
|
*self.modified_content.lock().unwrap() = Some(ctx.result.content.clone());
|
||||||
Ok(ControlFlow::Continue)
|
Ok(PostToolCallResult::Continue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let modified_content = Arc::new(std::sync::Mutex::new(None));
|
let modified_content = Arc::new(std::sync::Mutex::new(None));
|
||||||
worker.add_hook(ModifyingHook {
|
worker.add_post_tool_call_hook(ModifyingHook {
|
||||||
modified_content: modified_content.clone(),
|
modified_content: modified_content.clone(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
use schemars;
|
use schemars;
|
||||||
use serde;
|
use serde;
|
||||||
|
|
||||||
use llm_worker::tool::Tool;
|
|
||||||
use llm_worker_macros::tool_registry;
|
use llm_worker_macros::tool_registry;
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
@ -51,30 +50,31 @@ async fn test_basic_tool_generation() {
|
||||||
prefix: "Hello".to_string(),
|
prefix: "Hello".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// ファクトリメソッドでツールを取得
|
// ファクトリメソッドでToolDefinitionを取得
|
||||||
let greet_tool = ctx.greet_tool();
|
let greet_definition = ctx.greet_definition();
|
||||||
|
|
||||||
// 名前の確認
|
// ファクトリを呼び出してMetaとToolを取得
|
||||||
assert_eq!(greet_tool.name(), "greet");
|
let (meta, tool) = greet_definition();
|
||||||
|
|
||||||
// 説明の確認(docコメントから取得)
|
// メタ情報の確認
|
||||||
let desc = greet_tool.description();
|
assert_eq!(meta.name, "greet");
|
||||||
assert!(
|
assert!(
|
||||||
desc.contains("メッセージに挨拶を追加する"),
|
meta.description.contains("メッセージに挨拶を追加する"),
|
||||||
"Description should contain doc comment: {}",
|
"Description should contain doc comment: {}",
|
||||||
desc
|
meta.description
|
||||||
);
|
);
|
||||||
|
|
||||||
// スキーマの確認
|
|
||||||
let schema = greet_tool.input_schema();
|
|
||||||
println!("Schema: {}", serde_json::to_string_pretty(&schema).unwrap());
|
|
||||||
assert!(
|
assert!(
|
||||||
schema.get("properties").is_some(),
|
meta.input_schema.get("properties").is_some(),
|
||||||
"Schema should have properties"
|
"Schema should have properties"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"Schema: {}",
|
||||||
|
serde_json::to_string_pretty(&meta.input_schema).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
// 実行テスト
|
// 実行テスト
|
||||||
let result = greet_tool.execute(r#"{"message": "World"}"#).await;
|
let result = tool.execute(r#"{"message": "World"}"#).await;
|
||||||
assert!(result.is_ok(), "Should execute successfully");
|
assert!(result.is_ok(), "Should execute successfully");
|
||||||
let output = result.unwrap();
|
let output = result.unwrap();
|
||||||
assert!(output.contains("Hello"), "Output should contain prefix");
|
assert!(output.contains("Hello"), "Output should contain prefix");
|
||||||
|
|
@ -87,11 +87,11 @@ async fn test_multiple_arguments() {
|
||||||
prefix: "".to_string(),
|
prefix: "".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let add_tool = ctx.add_tool();
|
let (meta, tool) = ctx.add_definition()();
|
||||||
|
|
||||||
assert_eq!(add_tool.name(), "add");
|
assert_eq!(meta.name, "add");
|
||||||
|
|
||||||
let result = add_tool.execute(r#"{"a": 10, "b": 20}"#).await;
|
let result = tool.execute(r#"{"a": 10, "b": 20}"#).await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
let output = result.unwrap();
|
let output = result.unwrap();
|
||||||
assert!(output.contains("30"), "Should contain sum: {}", output);
|
assert!(output.contains("30"), "Should contain sum: {}", output);
|
||||||
|
|
@ -103,12 +103,12 @@ async fn test_no_arguments() {
|
||||||
prefix: "TestPrefix".to_string(),
|
prefix: "TestPrefix".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let get_prefix_tool = ctx.get_prefix_tool();
|
let (meta, tool) = ctx.get_prefix_definition()();
|
||||||
|
|
||||||
assert_eq!(get_prefix_tool.name(), "get_prefix");
|
assert_eq!(meta.name, "get_prefix");
|
||||||
|
|
||||||
// 空のJSONオブジェクトで呼び出し
|
// 空のJSONオブジェクトで呼び出し
|
||||||
let result = get_prefix_tool.execute(r#"{}"#).await;
|
let result = tool.execute(r#"{}"#).await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
let output = result.unwrap();
|
let output = result.unwrap();
|
||||||
assert!(
|
assert!(
|
||||||
|
|
@ -124,10 +124,10 @@ async fn test_invalid_arguments() {
|
||||||
prefix: "".to_string(),
|
prefix: "".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let greet_tool = ctx.greet_tool();
|
let (_, tool) = ctx.greet_definition()();
|
||||||
|
|
||||||
// 不正なJSON
|
// 不正なJSON
|
||||||
let result = greet_tool.execute(r#"{"wrong_field": "value"}"#).await;
|
let result = tool.execute(r#"{"wrong_field": "value"}"#).await;
|
||||||
assert!(result.is_err(), "Should fail with invalid arguments");
|
assert!(result.is_err(), "Should fail with invalid arguments");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -163,9 +163,9 @@ impl FallibleContext {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_result_return_type_success() {
|
async fn test_result_return_type_success() {
|
||||||
let ctx = FallibleContext;
|
let ctx = FallibleContext;
|
||||||
let validate_tool = ctx.validate_tool();
|
let (_, tool) = ctx.validate_definition()();
|
||||||
|
|
||||||
let result = validate_tool.execute(r#"{"value": 42}"#).await;
|
let result = tool.execute(r#"{"value": 42}"#).await;
|
||||||
assert!(result.is_ok(), "Should succeed for positive value");
|
assert!(result.is_ok(), "Should succeed for positive value");
|
||||||
let output = result.unwrap();
|
let output = result.unwrap();
|
||||||
assert!(output.contains("Valid"), "Should contain Valid: {}", output);
|
assert!(output.contains("Valid"), "Should contain Valid: {}", output);
|
||||||
|
|
@ -174,9 +174,9 @@ async fn test_result_return_type_success() {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_result_return_type_error() {
|
async fn test_result_return_type_error() {
|
||||||
let ctx = FallibleContext;
|
let ctx = FallibleContext;
|
||||||
let validate_tool = ctx.validate_tool();
|
let (_, tool) = ctx.validate_definition()();
|
||||||
|
|
||||||
let result = validate_tool.execute(r#"{"value": -1}"#).await;
|
let result = tool.execute(r#"{"value": -1}"#).await;
|
||||||
assert!(result.is_err(), "Should fail for negative value");
|
assert!(result.is_err(), "Should fail for negative value");
|
||||||
|
|
||||||
let err = result.unwrap_err();
|
let err = result.unwrap_err();
|
||||||
|
|
@ -211,12 +211,12 @@ async fn test_sync_method() {
|
||||||
counter: Arc::new(AtomicUsize::new(0)),
|
counter: Arc::new(AtomicUsize::new(0)),
|
||||||
};
|
};
|
||||||
|
|
||||||
let increment_tool = ctx.increment_tool();
|
let (_, tool) = ctx.increment_definition()();
|
||||||
|
|
||||||
// 3回実行
|
// 3回実行
|
||||||
let result1 = increment_tool.execute(r#"{}"#).await;
|
let result1 = tool.execute(r#"{}"#).await;
|
||||||
let result2 = increment_tool.execute(r#"{}"#).await;
|
let result2 = tool.execute(r#"{}"#).await;
|
||||||
let result3 = increment_tool.execute(r#"{}"#).await;
|
let result3 = tool.execute(r#"{}"#).await;
|
||||||
|
|
||||||
assert!(result1.is_ok());
|
assert!(result1.is_ok());
|
||||||
assert!(result2.is_ok());
|
assert!(result2.is_ok());
|
||||||
|
|
@ -225,3 +225,22 @@ async fn test_sync_method() {
|
||||||
// カウンターは3になっているはず
|
// カウンターは3になっているはず
|
||||||
assert_eq!(ctx.counter.load(Ordering::SeqCst), 3);
|
assert_eq!(ctx.counter.load(Ordering::SeqCst), 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Test: ToolMeta Immutability
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_tool_meta_immutability() {
|
||||||
|
let ctx = SimpleContext {
|
||||||
|
prefix: "Test".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// 2回取得しても同じメタ情報が得られることを確認
|
||||||
|
let (meta1, _) = ctx.greet_definition()();
|
||||||
|
let (meta2, _) = ctx.greet_definition()();
|
||||||
|
|
||||||
|
assert_eq!(meta1.name, meta2.name);
|
||||||
|
assert_eq!(meta1.description, meta2.description);
|
||||||
|
assert_eq!(meta1.input_schema, meta2.input_schema);
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
use llm_worker::llm_client::LlmClient;
|
|
||||||
use llm_worker::llm_client::providers::openai::OpenAIClient;
|
use llm_worker::llm_client::providers::openai::OpenAIClient;
|
||||||
use llm_worker::{Worker, WorkerError};
|
use llm_worker::{Worker, WorkerError};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use common::MockLlmClient;
|
use common::MockLlmClient;
|
||||||
use llm_worker::Worker;
|
use llm_worker::Worker;
|
||||||
use llm_worker::tool::{Tool, ToolError};
|
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
||||||
|
|
||||||
/// フィクスチャディレクトリのパス
|
/// フィクスチャディレクトリのパス
|
||||||
fn fixtures_dir() -> std::path::PathBuf {
|
fn fixtures_dir() -> std::path::PathBuf {
|
||||||
|
|
@ -35,20 +35,13 @@ impl MockWeatherTool {
|
||||||
fn get_call_count(&self) -> usize {
|
fn get_call_count(&self) -> usize {
|
||||||
self.call_count.load(Ordering::SeqCst)
|
self.call_count.load(Ordering::SeqCst)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
fn definition(&self) -> ToolDefinition {
|
||||||
impl Tool for MockWeatherTool {
|
let tool = self.clone();
|
||||||
fn name(&self) -> &str {
|
Arc::new(move || {
|
||||||
"get_weather"
|
let meta = ToolMeta::new("get_weather")
|
||||||
}
|
.description("Get the current weather for a city")
|
||||||
|
.input_schema(serde_json::json!({
|
||||||
fn description(&self) -> &str {
|
|
||||||
"Get the current weather for a city"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn input_schema(&self) -> serde_json::Value {
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"city": {
|
"city": {
|
||||||
|
|
@ -57,9 +50,14 @@ impl Tool for MockWeatherTool {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["city"]
|
"required": ["city"]
|
||||||
|
}));
|
||||||
|
(meta, Arc::new(tool.clone()) as Arc<dyn Tool>)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for MockWeatherTool {
|
||||||
async fn execute(&self, input_json: &str) -> Result<String, ToolError> {
|
async fn execute(&self, input_json: &str) -> Result<String, ToolError> {
|
||||||
self.call_count.fetch_add(1, Ordering::SeqCst);
|
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||||
|
|
||||||
|
|
@ -158,7 +156,7 @@ async fn test_worker_tool_call() {
|
||||||
// ツールを登録
|
// ツールを登録
|
||||||
let weather_tool = MockWeatherTool::new();
|
let weather_tool = MockWeatherTool::new();
|
||||||
let tool_for_check = weather_tool.clone();
|
let tool_for_check = weather_tool.clone();
|
||||||
worker.register_tool(weather_tool);
|
worker.register_tool(weather_tool.definition()).unwrap();
|
||||||
|
|
||||||
// メッセージを送信
|
// メッセージを送信
|
||||||
let _result = worker.run("What's the weather in Tokyo?").await;
|
let _result = worker.run("What's the weather in Tokyo?").await;
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
//! Worker状態管理のテスト
|
//! Worker状態管理のテスト
|
||||||
//!
|
//!
|
||||||
//! Type-stateパターン(Mutable/Locked)による状態遷移と
|
//! Type-stateパターン(Mutable/CacheLocked)による状態遷移と
|
||||||
//! ターン間の状態保持をテストする。
|
//! ターン間の状態保持をテストする。
|
||||||
|
|
||||||
mod common;
|
mod common;
|
||||||
|
|
@ -95,7 +95,7 @@ fn test_mutable_extend_history() {
|
||||||
// 状態遷移テスト
|
// 状態遷移テスト
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// lock()でMutable -> Locked状態に遷移することを確認
|
/// lock()でMutable -> CacheLocked状態に遷移することを確認
|
||||||
#[test]
|
#[test]
|
||||||
fn test_lock_transition() {
|
fn test_lock_transition() {
|
||||||
let client = MockLlmClient::new(vec![]);
|
let client = MockLlmClient::new(vec![]);
|
||||||
|
|
@ -108,13 +108,13 @@ fn test_lock_transition() {
|
||||||
// ロック
|
// ロック
|
||||||
let locked_worker = worker.lock();
|
let locked_worker = worker.lock();
|
||||||
|
|
||||||
// Locked状態でも履歴とシステムプロンプトにアクセス可能
|
// CacheLocked状態でも履歴とシステムプロンプトにアクセス可能
|
||||||
assert_eq!(locked_worker.get_system_prompt(), Some("System"));
|
assert_eq!(locked_worker.get_system_prompt(), Some("System"));
|
||||||
assert_eq!(locked_worker.history().len(), 2);
|
assert_eq!(locked_worker.history().len(), 2);
|
||||||
assert_eq!(locked_worker.locked_prefix_len(), 2);
|
assert_eq!(locked_worker.locked_prefix_len(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// unlock()でLocked -> Mutable状態に遷移することを確認
|
/// unlock()でCacheLocked -> Mutable状態に遷移することを確認
|
||||||
#[test]
|
#[test]
|
||||||
fn test_unlock_transition() {
|
fn test_unlock_transition() {
|
||||||
let client = MockLlmClient::new(vec![]);
|
let client = MockLlmClient::new(vec![]);
|
||||||
|
|
@ -172,7 +172,7 @@ async fn test_mutable_run_updates_history() {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Locked状態で複数ターンを実行し、履歴が正しく累積することを確認
|
/// CacheLocked状態で複数ターンを実行し、履歴が正しく累積することを確認
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_locked_multi_turn_history_accumulation() {
|
async fn test_locked_multi_turn_history_accumulation() {
|
||||||
// 2回のリクエストに対応するレスポンスを準備
|
// 2回のリクエストに対応するレスポンスを準備
|
||||||
|
|
@ -340,7 +340,7 @@ async fn test_unlock_edit_relock() {
|
||||||
// システムプロンプト保持のテスト
|
// システムプロンプト保持のテスト
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// Locked状態でもシステムプロンプトが保持されることを確認
|
/// CacheLocked状態でもシステムプロンプトが保持されることを確認
|
||||||
#[test]
|
#[test]
|
||||||
fn test_system_prompt_preserved_in_locked_state() {
|
fn test_system_prompt_preserved_in_locked_state() {
|
||||||
let client = MockLlmClient::new(vec![]);
|
let client = MockLlmClient::new(vec![]);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user