Compare commits
No commits in common. "b206acc3d39d05cd480e80699bca1459eebfbe30" and "cc6bbe2a43cb8e660849f776bdc517e2bc7a6af3" have entirely different histories.
b206acc3d3
...
cc6bbe2a43
197
Cargo.lock
generated
197
Cargo.lock
generated
|
|
@ -26,16 +26,6 @@ version = "1.0.100"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
|
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "assert-json-diff"
|
|
||||||
version = "2.0.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12"
|
|
||||||
dependencies = [
|
|
||||||
"serde",
|
|
||||||
"serde_json",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-stream"
|
name = "async-stream"
|
||||||
version = "0.3.6"
|
version = "0.3.6"
|
||||||
|
|
@ -69,12 +59,6 @@ dependencies = [
|
||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "atomic-waker"
|
|
||||||
version = "1.1.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "autocfg"
|
name = "autocfg"
|
||||||
version = "1.5.0"
|
version = "1.5.0"
|
||||||
|
|
@ -87,12 +71,6 @@ version = "0.21.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
|
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "base64"
|
|
||||||
version = "0.22.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bitflags"
|
name = "bitflags"
|
||||||
version = "1.3.2"
|
version = "1.3.2"
|
||||||
|
|
@ -163,24 +141,6 @@ version = "0.8.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
|
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "deadpool"
|
|
||||||
version = "0.12.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b"
|
|
||||||
dependencies = [
|
|
||||||
"deadpool-runtime",
|
|
||||||
"lazy_static",
|
|
||||||
"num_cpus",
|
|
||||||
"tokio",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "deadpool-runtime"
|
|
||||||
version = "0.1.4"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dirs"
|
name = "dirs"
|
||||||
version = "6.0.0"
|
version = "6.0.0"
|
||||||
|
|
@ -405,26 +365,7 @@ dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-sink",
|
"futures-sink",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"http 0.2.12",
|
"http",
|
||||||
"indexmap",
|
|
||||||
"slab",
|
|
||||||
"tokio",
|
|
||||||
"tokio-util",
|
|
||||||
"tracing",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "h2"
|
|
||||||
version = "0.4.12"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386"
|
|
||||||
dependencies = [
|
|
||||||
"atomic-waker",
|
|
||||||
"bytes",
|
|
||||||
"fnv",
|
|
||||||
"futures-core",
|
|
||||||
"futures-sink",
|
|
||||||
"http 1.3.1",
|
|
||||||
"indexmap",
|
"indexmap",
|
||||||
"slab",
|
"slab",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
|
@ -444,12 +385,6 @@ version = "0.5.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "hermit-abi"
|
|
||||||
version = "0.5.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "http"
|
name = "http"
|
||||||
version = "0.2.12"
|
version = "0.2.12"
|
||||||
|
|
@ -461,17 +396,6 @@ dependencies = [
|
||||||
"itoa",
|
"itoa",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "http"
|
|
||||||
version = "1.3.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565"
|
|
||||||
dependencies = [
|
|
||||||
"bytes",
|
|
||||||
"fnv",
|
|
||||||
"itoa",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "http-body"
|
name = "http-body"
|
||||||
version = "0.4.6"
|
version = "0.4.6"
|
||||||
|
|
@ -479,30 +403,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2"
|
checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bytes",
|
"bytes",
|
||||||
"http 0.2.12",
|
"http",
|
||||||
"pin-project-lite",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "http-body"
|
|
||||||
version = "1.0.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
|
|
||||||
dependencies = [
|
|
||||||
"bytes",
|
|
||||||
"http 1.3.1",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "http-body-util"
|
|
||||||
version = "0.1.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
|
|
||||||
dependencies = [
|
|
||||||
"bytes",
|
|
||||||
"futures-core",
|
|
||||||
"http 1.3.1",
|
|
||||||
"http-body 1.0.1",
|
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -528,9 +429,9 @@ dependencies = [
|
||||||
"futures-channel",
|
"futures-channel",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"h2 0.3.27",
|
"h2",
|
||||||
"http 0.2.12",
|
"http",
|
||||||
"http-body 0.4.6",
|
"http-body",
|
||||||
"httparse",
|
"httparse",
|
||||||
"httpdate",
|
"httpdate",
|
||||||
"itoa",
|
"itoa",
|
||||||
|
|
@ -542,29 +443,6 @@ dependencies = [
|
||||||
"want",
|
"want",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "hyper"
|
|
||||||
version = "1.7.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e"
|
|
||||||
dependencies = [
|
|
||||||
"atomic-waker",
|
|
||||||
"bytes",
|
|
||||||
"futures-channel",
|
|
||||||
"futures-core",
|
|
||||||
"h2 0.4.12",
|
|
||||||
"http 1.3.1",
|
|
||||||
"http-body 1.0.1",
|
|
||||||
"httparse",
|
|
||||||
"httpdate",
|
|
||||||
"itoa",
|
|
||||||
"pin-project-lite",
|
|
||||||
"pin-utils",
|
|
||||||
"smallvec",
|
|
||||||
"tokio",
|
|
||||||
"want",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hyper-rustls"
|
name = "hyper-rustls"
|
||||||
version = "0.24.2"
|
version = "0.24.2"
|
||||||
|
|
@ -572,28 +450,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590"
|
checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"http 0.2.12",
|
"http",
|
||||||
"hyper 0.14.32",
|
"hyper",
|
||||||
"rustls",
|
"rustls",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-rustls",
|
"tokio-rustls",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "hyper-util"
|
|
||||||
version = "0.1.17"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8"
|
|
||||||
dependencies = [
|
|
||||||
"bytes",
|
|
||||||
"futures-core",
|
|
||||||
"http 1.3.1",
|
|
||||||
"http-body 1.0.1",
|
|
||||||
"hyper 1.7.0",
|
|
||||||
"pin-project-lite",
|
|
||||||
"tokio",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "iana-time-zone"
|
name = "iana-time-zone"
|
||||||
version = "0.1.64"
|
version = "0.1.64"
|
||||||
|
|
@ -873,16 +736,6 @@ dependencies = [
|
||||||
"autocfg",
|
"autocfg",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "num_cpus"
|
|
||||||
version = "1.17.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b"
|
|
||||||
dependencies = [
|
|
||||||
"hermit-abi",
|
|
||||||
"libc",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "once_cell"
|
name = "once_cell"
|
||||||
version = "1.21.3"
|
version = "1.21.3"
|
||||||
|
|
@ -1044,15 +897,15 @@ version = "0.11.27"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62"
|
checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"base64 0.21.7",
|
"base64",
|
||||||
"bytes",
|
"bytes",
|
||||||
"encoding_rs",
|
"encoding_rs",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"h2 0.3.27",
|
"h2",
|
||||||
"http 0.2.12",
|
"http",
|
||||||
"http-body 0.4.6",
|
"http-body",
|
||||||
"hyper 0.14.32",
|
"hyper",
|
||||||
"hyper-rustls",
|
"hyper-rustls",
|
||||||
"ipnet",
|
"ipnet",
|
||||||
"js-sys",
|
"js-sys",
|
||||||
|
|
@ -1126,7 +979,7 @@ version = "1.0.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c"
|
checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"base64 0.21.7",
|
"base64",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -2104,29 +1957,6 @@ dependencies = [
|
||||||
"windows-sys 0.48.0",
|
"windows-sys 0.48.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "wiremock"
|
|
||||||
version = "0.6.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "08db1edfb05d9b3c1542e521aea074442088292f00b5f28e435c714a98f85031"
|
|
||||||
dependencies = [
|
|
||||||
"assert-json-diff",
|
|
||||||
"base64 0.22.1",
|
|
||||||
"deadpool",
|
|
||||||
"futures",
|
|
||||||
"http 1.3.1",
|
|
||||||
"http-body-util",
|
|
||||||
"hyper 1.7.0",
|
|
||||||
"hyper-util",
|
|
||||||
"log",
|
|
||||||
"once_cell",
|
|
||||||
"regex",
|
|
||||||
"serde",
|
|
||||||
"serde_json",
|
|
||||||
"tokio",
|
|
||||||
"url",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wit-bindgen"
|
name = "wit-bindgen"
|
||||||
version = "0.46.0"
|
version = "0.46.0"
|
||||||
|
|
@ -2164,7 +1994,6 @@ dependencies = [
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
"uuid",
|
"uuid",
|
||||||
"wiremock",
|
|
||||||
"worker-macros",
|
"worker-macros",
|
||||||
"worker-types",
|
"worker-types",
|
||||||
"xdg",
|
"xdg",
|
||||||
|
|
|
||||||
|
|
@ -23,9 +23,9 @@ use worker::{LlmProvider, SystemPromptContext, PromptError, Worker};
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let system_prompt = |ctx: &SystemPromptContext, _messages: &[worker_types::Message]| {
|
let system_prompt = |ctx: &SystemPromptContext, _messages: &[worker_types::Message]| {
|
||||||
Ok(format!(
|
Ok(format!(
|
||||||
"You are assisting with model {} from provider {}.",
|
"You are assisting with model {} from provider {:?}.",
|
||||||
ctx.model.model_name,
|
ctx.model.model_name,
|
||||||
ctx.model.provider_id
|
ctx.model.provider
|
||||||
))
|
))
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ pub struct HookManager {
|
||||||
pub enum HookEvent {
|
pub enum HookEvent {
|
||||||
OnMessageSend,
|
OnMessageSend,
|
||||||
PreToolUse,
|
PreToolUse,
|
||||||
PostToolUse,
|
PostToolUse,
|
||||||
OnTurnCompleted,
|
OnTurnCompleted,
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
@ -181,10 +181,10 @@ impl HookContext {
|
||||||
impl HookContext {
|
impl HookContext {
|
||||||
// ストリーミング中にメッセージを送信
|
// ストリーミング中にメッセージを送信
|
||||||
pub fn stream_message(&self, content: String, role: Role);
|
pub fn stream_message(&self, content: String, role: Role);
|
||||||
|
|
||||||
// ストリーミング中にシステム通知を送信
|
// ストリーミング中にシステム通知を送信
|
||||||
pub fn stream_system_message(&self, content: String);
|
pub fn stream_system_message(&self, content: String);
|
||||||
|
|
||||||
// ストリーミング中にデバッグ情報を送信
|
// ストリーミング中にデバッグ情報を送信
|
||||||
pub fn stream_debug(&self, title: String, data: serde_json::Value);
|
pub fn stream_debug(&self, title: String, data: serde_json::Value);
|
||||||
}
|
}
|
||||||
|
|
@ -198,22 +198,22 @@ Hook関数は以下のいずれかの結果を返す必要があります:
|
||||||
pub enum HookResult {
|
pub enum HookResult {
|
||||||
// 処理を続行
|
// 処理を続行
|
||||||
Continue,
|
Continue,
|
||||||
|
|
||||||
// コンテンツを変更して続行
|
// コンテンツを変更して続行
|
||||||
ModifyContent(String),
|
ModifyContent(String),
|
||||||
|
|
||||||
// システムメッセージを追加して続行
|
// システムメッセージを追加して続行
|
||||||
AddMessage(String, Role),
|
AddMessage(String, Role),
|
||||||
|
|
||||||
// 複数のメッセージを追加して続行
|
// 複数のメッセージを追加して続行
|
||||||
AddMessages(Vec<Message>),
|
AddMessages(Vec<Message>),
|
||||||
|
|
||||||
// ターンを強制完了
|
// ターンを強制完了
|
||||||
Complete,
|
Complete,
|
||||||
|
|
||||||
// エラーでターンを終了
|
// エラーでターンを終了
|
||||||
Error(String),
|
Error(String),
|
||||||
|
|
||||||
// Hook処理をスキップ(デバッグ用)
|
// Hook処理をスキップ(デバッグ用)
|
||||||
Skip,
|
Skip,
|
||||||
}
|
}
|
||||||
|
|
@ -282,7 +282,7 @@ pub async fn dangerous_command_hook(context: HookContext) -> HookResult {
|
||||||
if let Some(args) = &context.tool_args {
|
if let Some(args) = &context.tool_args {
|
||||||
if let Some(command) = args.get("command").and_then(|v| v.as_str()) {
|
if let Some(command) = args.get("command").and_then(|v| v.as_str()) {
|
||||||
let dangerous_commands = ["rm -rf", "format", "dd if="];
|
let dangerous_commands = ["rm -rf", "format", "dd if="];
|
||||||
|
|
||||||
for dangerous in &dangerous_commands {
|
for dangerous in &dangerous_commands {
|
||||||
if command.contains(dangerous) {
|
if command.contains(dangerous) {
|
||||||
return HookResult::Error(format!(
|
return HookResult::Error(format!(
|
||||||
|
|
@ -293,7 +293,7 @@ pub async fn dangerous_command_hook(context: HookContext) -> HookResult {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
HookResult::Continue
|
HookResult::Continue
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
@ -326,7 +326,7 @@ pub async fn auto_read_hook(mut context: HookContext) -> HookResult {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
HookResult::Continue
|
HookResult::Continue
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
@ -361,7 +361,7 @@ worker.register_hooks(tui_hooks);
|
||||||
```rust
|
```rust
|
||||||
// 実行順序の例
|
// 実行順序の例
|
||||||
worker.register_hook(Box::new(TimestampHook)); // 1番目
|
worker.register_hook(Box::new(TimestampHook)); // 1番目
|
||||||
worker.register_hook(Box::new(ValidationHook)); // 2番目
|
worker.register_hook(Box::new(ValidationHook)); // 2番目
|
||||||
worker.register_hook(Box::new(LoggingHook)); // 3番目
|
worker.register_hook(Box::new(LoggingHook)); // 3番目
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -375,13 +375,13 @@ worker.register_hook(Box::new(LoggingHook)); // 3番目
|
||||||
impl Worker {
|
impl Worker {
|
||||||
// Hook一覧を取得
|
// Hook一覧を取得
|
||||||
pub fn list_hooks(&self) -> Vec<(&str, &str)>; // (name, hook_type)
|
pub fn list_hooks(&self) -> Vec<(&str, &str)>; // (name, hook_type)
|
||||||
|
|
||||||
// 特定のHookを削除
|
// 特定のHookを削除
|
||||||
pub fn remove_hook(&mut self, hook_name: &str) -> bool;
|
pub fn remove_hook(&mut self, hook_name: &str) -> bool;
|
||||||
|
|
||||||
// フェーズ別Hookを削除
|
// フェーズ別Hookを削除
|
||||||
pub fn remove_hooks_by_phase(&mut self, hook_type: &str);
|
pub fn remove_hooks_by_phase(&mut self, hook_type: &str);
|
||||||
|
|
||||||
// すべてのHookをクリア
|
// すべてのHookをクリア
|
||||||
pub fn clear_hooks(&mut self);
|
pub fn clear_hooks(&mut self);
|
||||||
}
|
}
|
||||||
|
|
@ -395,16 +395,16 @@ impl Worker {
|
||||||
// worker/src/lib.rs の process_with_shared_state より
|
// worker/src/lib.rs の process_with_shared_state より
|
||||||
stream! {
|
stream! {
|
||||||
// ... LLM応答処理中 ...
|
// ... LLM応答処理中 ...
|
||||||
|
|
||||||
// ツール呼び出し検出時
|
// ツール呼び出し検出時
|
||||||
if let Some(tool_calls) = &response.tool_calls {
|
if let Some(tool_calls) = &response.tool_calls {
|
||||||
for tool_call in tool_calls {
|
for tool_call in tool_calls {
|
||||||
// PreToolUse hooks 実行
|
// PreToolUse hooks 実行
|
||||||
let (context, hook_result) = execute_hooks(
|
let (context, hook_result) = execute_hooks(
|
||||||
HookEvent::PreToolUse,
|
HookEvent::PreToolUse,
|
||||||
tool_call.name.clone()
|
tool_call.name.clone()
|
||||||
).await;
|
).await;
|
||||||
|
|
||||||
match hook_result {
|
match hook_result {
|
||||||
HookResult::Error(msg) => {
|
HookResult::Error(msg) => {
|
||||||
yield Ok(StreamEvent::Error(msg));
|
yield Ok(StreamEvent::Error(msg));
|
||||||
|
|
@ -413,16 +413,16 @@ stream! {
|
||||||
HookResult::Complete => break,
|
HookResult::Complete => break,
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ツール実行
|
// ツール実行
|
||||||
let result = execute_tool(tool_call).await;
|
let result = execute_tool(tool_call).await;
|
||||||
|
|
||||||
// PostToolUse hooks 実行(ストリーミング中)
|
// PostToolUse hooks 実行(ストリーミング中)
|
||||||
let (context, hook_result) = execute_hooks(
|
let (context, hook_result) = execute_hooks(
|
||||||
HookEvent::PostToolUse,
|
HookEvent::PostToolUse,
|
||||||
tool_call.name.clone()
|
tool_call.name.clone()
|
||||||
).await;
|
).await;
|
||||||
|
|
||||||
// Hook結果を即座にストリーミング
|
// Hook結果を即座にストリーミング
|
||||||
if let HookResult::AddMessage(msg, role) = hook_result {
|
if let HookResult::AddMessage(msg, role) = hook_result {
|
||||||
yield Ok(StreamEvent::HookMessage {
|
yield Ok(StreamEvent::HookMessage {
|
||||||
|
|
@ -468,13 +468,13 @@ pub async fn performance_aware_hook(context: HookContext) -> HookResult {
|
||||||
// 大きなコンテンツの場合はスキップ
|
// 大きなコンテンツの場合はスキップ
|
||||||
return HookResult::Skip;
|
return HookResult::Skip;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 非同期処理は適切にawaitする
|
// 非同期処理は適切にawaitする
|
||||||
let result = tokio::time::timeout(
|
let result = tokio::time::timeout(
|
||||||
Duration::from_secs(5),
|
Duration::from_secs(5),
|
||||||
expensive_operation(&context)
|
expensive_operation(&context)
|
||||||
).await;
|
).await;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(output) => HookResult::AddMessage(output, Role::System),
|
Ok(output) => HookResult::AddMessage(output, Role::System),
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
|
|
@ -495,20 +495,20 @@ pub async fn configurable_hook(mut context: HookContext) -> HookResult {
|
||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
.parse::<bool>()
|
.parse::<bool>()
|
||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
|
|
||||||
if !enabled {
|
if !enabled {
|
||||||
return HookResult::Skip;
|
return HookResult::Skip;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 設定ファイルからオプション読み込み
|
// 設定ファイルからオプション読み込み
|
||||||
let config_path = format!("{}/hook_config.json", context.workspace_path);
|
let config_path = format!("{}/.nia/hook_config.json", context.workspace_path);
|
||||||
if let Ok(config_content) = tokio::fs::read_to_string(&config_path).await {
|
if let Ok(config_content) = tokio::fs::read_to_string(&config_path).await {
|
||||||
if let Ok(config) = serde_json::from_str::<HookConfig>(&config_content) {
|
if let Ok(config) = serde_json::from_str::<HookConfig>(&config_content) {
|
||||||
// 設定に基づく処理
|
// 設定に基づく処理
|
||||||
return process_with_config(&mut context, &config).await;
|
return process_with_config(&mut context, &config).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
HookResult::Continue
|
HookResult::Continue
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
@ -523,7 +523,7 @@ pub async fn conditional_hook(context: HookContext) -> HookResult {
|
||||||
let is_rust_project = tokio::fs::metadata(
|
let is_rust_project = tokio::fs::metadata(
|
||||||
format!("{}/Cargo.toml", context.workspace_path)
|
format!("{}/Cargo.toml", context.workspace_path)
|
||||||
).await.is_ok();
|
).await.is_ok();
|
||||||
|
|
||||||
match (is_git_repo, is_rust_project) {
|
match (is_git_repo, is_rust_project) {
|
||||||
(true, true) => {
|
(true, true) => {
|
||||||
// Rustプロジェクト + Git
|
// Rustプロジェクト + Git
|
||||||
|
|
@ -548,7 +548,7 @@ pub async fn conditional_hook(context: HookContext) -> HookResult {
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use worker::types::*;
|
use worker::types::*;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_timestamp_hook() {
|
async fn test_timestamp_hook() {
|
||||||
let mut context = HookContext {
|
let mut context = HookContext {
|
||||||
|
|
@ -561,9 +561,9 @@ mod tests {
|
||||||
tool_args: None,
|
tool_args: None,
|
||||||
tool_result: None,
|
tool_result: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = add_timestamp_hook(context).await;
|
let result = add_timestamp_hook(context).await;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
HookResult::ModifyContent(content) => {
|
HookResult::ModifyContent(content) => {
|
||||||
assert!(content.contains("Hello, world!"));
|
assert!(content.contains("Hello, world!"));
|
||||||
|
|
@ -587,7 +587,7 @@ pub async fn debug_hook(context: HookContext) -> HookResult {
|
||||||
context.tools.len(),
|
context.tools.len(),
|
||||||
context.message_history.len()
|
context.message_history.len()
|
||||||
);
|
);
|
||||||
|
|
||||||
// デバッグ情報をストリーミング
|
// デバッグ情報をストリーミング
|
||||||
context.stream_debug(
|
context.stream_debug(
|
||||||
"Hook Debug Info".to_string(),
|
"Hook Debug Info".to_string(),
|
||||||
|
|
@ -598,7 +598,7 @@ pub async fn debug_hook(context: HookContext) -> HookResult {
|
||||||
"workspace": context.workspace_path
|
"workspace": context.workspace_path
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
HookResult::Continue
|
HookResult::Continue
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
@ -625,13 +625,13 @@ impl WorkerHook for StatefulHook {
|
||||||
fn name(&self) -> &str { "stateful_hook" }
|
fn name(&self) -> &str { "stateful_hook" }
|
||||||
fn hook_type(&self) -> &str { "OnTurnCompleted" }
|
fn hook_type(&self) -> &str { "OnTurnCompleted" }
|
||||||
fn matcher(&self) -> &str { "" }
|
fn matcher(&self) -> &str { "" }
|
||||||
|
|
||||||
async fn execute(&self, mut context: HookContext) -> (HookContext, HookResult) {
|
async fn execute(&self, mut context: HookContext) -> (HookContext, HookResult) {
|
||||||
let mut count = self.counter.lock().unwrap();
|
let mut count = self.counter.lock().unwrap();
|
||||||
*count += 1;
|
*count += 1;
|
||||||
|
|
||||||
context.set_variable("turn_count".to_string(), count.to_string());
|
context.set_variable("turn_count".to_string(), count.to_string());
|
||||||
|
|
||||||
if *count % 10 == 0 {
|
if *count % 10 == 0 {
|
||||||
(
|
(
|
||||||
context,
|
context,
|
||||||
|
|
@ -658,7 +658,7 @@ impl HookChain {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self { hooks: Vec::new() }
|
Self { hooks: Vec::new() }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_hook(mut self, hook: Box<dyn WorkerHook>) -> Self {
|
pub fn add_hook(mut self, hook: Box<dyn WorkerHook>) -> Self {
|
||||||
self.hooks.push(hook);
|
self.hooks.push(hook);
|
||||||
self
|
self
|
||||||
|
|
@ -670,18 +670,18 @@ impl WorkerHook for HookChain {
|
||||||
fn name(&self) -> &str { "hook_chain" }
|
fn name(&self) -> &str { "hook_chain" }
|
||||||
fn hook_type(&self) -> &str { "OnMessageSend" }
|
fn hook_type(&self) -> &str { "OnMessageSend" }
|
||||||
fn matcher(&self) -> &str { "" }
|
fn matcher(&self) -> &str { "" }
|
||||||
|
|
||||||
async fn execute(&self, mut context: HookContext) -> (HookContext, HookResult) {
|
async fn execute(&self, mut context: HookContext) -> (HookContext, HookResult) {
|
||||||
for hook in &self.hooks {
|
for hook in &self.hooks {
|
||||||
let (new_context, result) = hook.execute(context).await;
|
let (new_context, result) = hook.execute(context).await;
|
||||||
context = new_context;
|
context = new_context;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
HookResult::Continue | HookResult::Skip => continue,
|
HookResult::Continue | HookResult::Skip => continue,
|
||||||
other => return (context, other),
|
other => return (context, other),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
(context, HookResult::Continue)
|
(context, HookResult::Continue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -715,3 +715,4 @@ A: `HookResult::Error`を返すと、そのターンは中断されます。継
|
||||||
- [worker-macro.md](worker-macro.md) - マクロシステム
|
- [worker-macro.md](worker-macro.md) - マクロシステム
|
||||||
- `worker/src/lib.rs` - Hook実装コード
|
- `worker/src/lib.rs` - Hook実装コード
|
||||||
- `worker-types/src/lib.rs` - Hook型定義
|
- `worker-types/src/lib.rs` - Hook型定義
|
||||||
|
- `nia-cli/src/tui/hooks/` - TUI用Hook実装例
|
||||||
|
|
@ -6,7 +6,7 @@ v0.3.0 はプロンプトリソースの解決責務を利用側へ完全に移
|
||||||
|
|
||||||
## Breaking Changes
|
## Breaking Changes
|
||||||
|
|
||||||
- `ConfigParser::resolve_path` を削除し、`#user/` `#workspace/` 等のプレフィックス解決をライブラリ利用者実装の `ResourceLoader` に委譲しました。
|
- `ConfigParser::resolve_path` を削除し、`#nia/` `#workspace/` 等のプレフィックス解決をライブラリ利用者実装の `ResourceLoader` に委譲しました。
|
||||||
- `WorkerBuilder::build()` は `resource_loader(...)` が未指定の場合エラーを返すようになりました。ワーカー構築前に必ずローダーを提供してください。
|
- `WorkerBuilder::build()` は `resource_loader(...)` が未指定の場合エラーを返すようになりました。ワーカー構築前に必ずローダーを提供してください。
|
||||||
|
|
||||||
## 新機能 / 仕様変更
|
## 新機能 / 仕様変更
|
||||||
|
|
@ -19,6 +19,7 @@ v0.3.0 はプロンプトリソースの解決責務を利用側へ完全に移
|
||||||
## 不具合修正
|
## 不具合修正
|
||||||
|
|
||||||
- `include_file` ヘルパーがカスタムローダーを利用せずにファイルアクセスしていた問題を修正。
|
- `include_file` ヘルパーがカスタムローダーを利用せずにファイルアクセスしていた問題を修正。
|
||||||
|
- `ConfigParser` が存在しない `#nia/` プレフィックスを静的に解決しようとしていた挙動を除去し、誤ったパスが静かに通ることを防止。
|
||||||
|
|
||||||
## 移行ガイド
|
## 移行ガイド
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ v0.4.0 は Worker が `Role` や YAML 設定を扱わず、システムプロン
|
||||||
|
|
||||||
## 不具合修正
|
## 不具合修正
|
||||||
|
|
||||||
- Worker から旧プロジェクト固有の設定コードを除去し、環境依存の副作用を縮小。
|
- Worker から NIA 固有の設定コードを除去し、環境依存の副作用を縮小。
|
||||||
|
|
||||||
## 移行ガイド
|
## 移行ガイド
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
**Release Date**: 2025-10-25
|
**Release Date**: 2025-10-25
|
||||||
|
|
||||||
v0.5.0 introduces the Worker Blueprint API and removes the old type-state builder. Configuration now lives on the blueprint, while instantiated workers keep only the materialised system prompt, model metadata, and runtime state.
|
v0.5.0 introduces the Worker Blueprint API and removes the old type-state builder. Configuration now lives on the blueprint, while instantiated workers keep only the materialised system prompt and runtime state.
|
||||||
|
|
||||||
## Breaking Changes
|
## Breaking Changes
|
||||||
|
|
||||||
|
|
@ -12,10 +12,9 @@ v0.5.0 introduces the Worker Blueprint API and removes the old type-state builde
|
||||||
|
|
||||||
## New Features / Behaviour
|
## New Features / Behaviour
|
||||||
|
|
||||||
- `WorkerBlueprint` stores provider/model/api keys, tools, hooks, optional precomputed system prompt messages, and optional model feature flags. `instantiate()` evaluates the prompt (if not already cached) and hands the final string to the `Worker`.
|
- `WorkerBlueprint` stores provider/model/api keys, tools, hooks, and optional precomputed system prompt strings. `instantiate()` evaluates the prompt (if not already cached) and hands the final string to the `Worker`.
|
||||||
- Instantiated workers retain the composed system prompt, the original generator closure, and a `Model` struct describing provider/model/features; the generator only runs again if a new session requires it.
|
- Instantiated workers retain only the composed system prompt string; the generator function lives solely on the blueprint and is dropped after instantiation.
|
||||||
- System prompts are no longer recomputed per turn. Tool metadata is appended dynamically as plain text when native tool support is unavailable.
|
- System prompts are no longer recomputed per turn. Tool metadata is appended dynamically as plain text when native tool support is unavailable.
|
||||||
- Worker now exposes a `Model` struct (`provider`, `name`, `features`) in place of the previous loose strings and `supports_native_tools` helper. Capability heuristics remain for built-in providers but applications can override them via `WorkerBlueprint::model_features`.
|
|
||||||
|
|
||||||
## Migration Guide
|
## Migration Guide
|
||||||
|
|
||||||
|
|
|
||||||
12
flake.lock
12
flake.lock
|
|
@ -2,11 +2,11 @@
|
||||||
"nodes": {
|
"nodes": {
|
||||||
"flake-compat": {
|
"flake-compat": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1761588595,
|
"lastModified": 1747046372,
|
||||||
"narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=",
|
"narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
|
||||||
"owner": "edolstra",
|
"owner": "edolstra",
|
||||||
"repo": "flake-compat",
|
"repo": "flake-compat",
|
||||||
"rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5",
|
"rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
|
@ -35,11 +35,11 @@
|
||||||
},
|
},
|
||||||
"nixpkgs": {
|
"nixpkgs": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1761373498,
|
"lastModified": 1751011381,
|
||||||
"narHash": "sha256-Q/uhWNvd7V7k1H1ZPMy/vkx3F8C13ZcdrKjO7Jv7v0c=",
|
"narHash": "sha256-krGXKxvkBhnrSC/kGBmg5MyupUUT5R6IBCLEzx9jhMM=",
|
||||||
"owner": "nixos",
|
"owner": "nixos",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "6a08e6bb4e46ff7fcbb53d409b253f6bad8a28ce",
|
"rev": "30e2e2857ba47844aa71991daa6ed1fc678bcbb7",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
|
|
||||||
|
|
@ -45,4 +45,3 @@ dynamic-loading = ["libloading"]
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tempfile = "3.10.1"
|
tempfile = "3.10.1"
|
||||||
tracing-subscriber = "0.3"
|
tracing-subscriber = "0.3"
|
||||||
wiremock = "0.6"
|
|
||||||
|
|
|
||||||
|
|
@ -10,8 +10,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
_messages: &[Message],
|
_messages: &[Message],
|
||||||
) -> Result<String, PromptError> {
|
) -> Result<String, PromptError> {
|
||||||
Ok(format!(
|
Ok(format!(
|
||||||
"You are helping with requests for model {} (provider {}).",
|
"You are helping with requests for model {} (provider {:?}).",
|
||||||
ctx.model.model_name, ctx.model.provider_id
|
ctx.model.model_name, ctx.model.provider
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
|
use crate::LlmProviderExt;
|
||||||
|
use crate::Worker;
|
||||||
use crate::plugin;
|
use crate::plugin;
|
||||||
use crate::prompt::{PromptError, SystemPromptContext, SystemPromptFn};
|
use crate::prompt::{PromptError, SystemPromptContext, SystemPromptFn};
|
||||||
use crate::types::{HookManager, Tool, WorkerError, WorkerHook};
|
use crate::types::{HookManager, Tool, WorkerError, WorkerHook};
|
||||||
use crate::{LlmProviderExt, Model, ModelFeatures, ModelProvider, Worker};
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
use worker_types::{LlmProvider, Message, Role};
|
use worker_types::{LlmProvider, Message, Role};
|
||||||
|
|
@ -23,7 +24,6 @@ pub struct WorkerBlueprint {
|
||||||
pub(crate) tools: Vec<Box<dyn Tool>>,
|
pub(crate) tools: Vec<Box<dyn Tool>>,
|
||||||
pub(crate) hooks: Vec<Box<dyn WorkerHook>>,
|
pub(crate) hooks: Vec<Box<dyn WorkerHook>>,
|
||||||
pub(crate) prompt_cache: Option<Vec<Message>>,
|
pub(crate) prompt_cache: Option<Vec<Message>>,
|
||||||
pub(crate) model_features: Option<ModelFeatures>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WorkerBlueprint {
|
impl WorkerBlueprint {
|
||||||
|
|
@ -36,7 +36,6 @@ impl WorkerBlueprint {
|
||||||
tools: Vec::new(),
|
tools: Vec::new(),
|
||||||
hooks: Vec::new(),
|
hooks: Vec::new(),
|
||||||
prompt_cache: None,
|
prompt_cache: None,
|
||||||
model_features: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -62,11 +61,6 @@ impl WorkerBlueprint {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn model_features(&mut self, features: ModelFeatures) -> &mut Self {
|
|
||||||
self.model_features = Some(features);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn api_key(&mut self, provider: impl Into<String>, key: impl Into<String>) -> &mut Self {
|
pub fn api_key(&mut self, provider: impl Into<String>, key: impl Into<String>) -> &mut Self {
|
||||||
self.api_keys.insert(provider.into(), key.into());
|
self.api_keys.insert(provider.into(), key.into());
|
||||||
self
|
self
|
||||||
|
|
@ -124,27 +118,9 @@ impl WorkerBlueprint {
|
||||||
.map(|tool| tool.name().to_string())
|
.map(|tool| tool.name().to_string())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let features = self
|
let context = self.build_system_prompt_context(provider, &model_name, &tool_names);
|
||||||
.model_features
|
|
||||||
.clone()
|
|
||||||
.unwrap_or_else(|| match provider {
|
|
||||||
ProviderConfig::BuiltIn(p) => Worker::infer_model_features(Some(*p), &model_name),
|
|
||||||
ProviderConfig::Plugin { .. } => ModelFeatures::default(),
|
|
||||||
});
|
|
||||||
|
|
||||||
let preview_model = Model {
|
|
||||||
provider: match provider {
|
|
||||||
ProviderConfig::BuiltIn(p) => ModelProvider::BuiltIn(*p),
|
|
||||||
ProviderConfig::Plugin { id, .. } => ModelProvider::Plugin(id.clone()),
|
|
||||||
},
|
|
||||||
name: model_name.clone(),
|
|
||||||
features: features.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let context = Worker::create_system_prompt_context(&preview_model, &tool_names);
|
|
||||||
let prompt = generator(&context, &[]).map_err(|e| WorkerError::config(e.to_string()))?;
|
let prompt = generator(&context, &[]).map_err(|e| WorkerError::config(e.to_string()))?;
|
||||||
self.prompt_cache = Some(vec![Message::new(Role::System, prompt)]);
|
self.prompt_cache = Some(vec![Message::new(Role::System, prompt)]);
|
||||||
self.model_features = Some(features);
|
|
||||||
Ok(self)
|
Ok(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -162,49 +138,30 @@ impl WorkerBlueprint {
|
||||||
.take()
|
.take()
|
||||||
.ok_or_else(|| WorkerError::config("System prompt generator is not configured"))?;
|
.ok_or_else(|| WorkerError::config("System prompt generator is not configured"))?;
|
||||||
|
|
||||||
let mut prompt_cache = self.prompt_cache.take();
|
|
||||||
let mut provided_features = self.model_features.take();
|
|
||||||
let tools = std::mem::take(&mut self.tools);
|
let tools = std::mem::take(&mut self.tools);
|
||||||
let hooks = std::mem::take(&mut self.hooks);
|
let hooks = std::mem::take(&mut self.hooks);
|
||||||
let mut api_keys = self.api_keys;
|
let mut api_keys = self.api_keys;
|
||||||
|
|
||||||
let tool_names: Vec<String> = tools.iter().map(|tool| tool.name().to_string()).collect();
|
let tool_names: Vec<String> = tools.iter().map(|tool| tool.name().to_string()).collect();
|
||||||
|
let provider_hint = provider_config.provider_hint();
|
||||||
|
let prompt_context =
|
||||||
|
Worker::create_system_prompt_context(provider_hint, &model_name, &tool_names);
|
||||||
|
|
||||||
|
let base_messages = match self.prompt_cache.take() {
|
||||||
|
Some(messages) if !messages.is_empty() => messages,
|
||||||
|
_ => {
|
||||||
|
let prompt = system_prompt_fn(&prompt_context, &[])
|
||||||
|
.map_err(|e| WorkerError::config(e.to_string()))?;
|
||||||
|
vec![Message::new(Role::System, prompt)]
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let base_system_prompt = base_messages
|
||||||
|
.first()
|
||||||
|
.map(|msg| msg.content.clone())
|
||||||
|
.unwrap_or_else(|| String::new());
|
||||||
|
|
||||||
match provider_config {
|
match provider_config {
|
||||||
ProviderConfig::BuiltIn(provider) => {
|
ProviderConfig::BuiltIn(provider) => {
|
||||||
let features = provided_features
|
|
||||||
.take()
|
|
||||||
.unwrap_or_else(|| Worker::infer_model_features(Some(provider), &model_name));
|
|
||||||
|
|
||||||
let model = Model {
|
|
||||||
provider: ModelProvider::BuiltIn(provider),
|
|
||||||
name: model_name.clone(),
|
|
||||||
features,
|
|
||||||
};
|
|
||||||
|
|
||||||
let prompt_context = Worker::create_system_prompt_context(&model, &tool_names);
|
|
||||||
let base_messages = if let Some(messages) = prompt_cache.take() {
|
|
||||||
if messages.is_empty() {
|
|
||||||
vec![Message::new(
|
|
||||||
Role::System,
|
|
||||||
system_prompt_fn(&prompt_context, &[])
|
|
||||||
.map_err(|e| WorkerError::config(e.to_string()))?,
|
|
||||||
)]
|
|
||||||
} else {
|
|
||||||
messages
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
vec![Message::new(
|
|
||||||
Role::System,
|
|
||||||
system_prompt_fn(&prompt_context, &[])
|
|
||||||
.map_err(|e| WorkerError::config(e.to_string()))?,
|
|
||||||
)]
|
|
||||||
};
|
|
||||||
|
|
||||||
let base_system_prompt = base_messages
|
|
||||||
.first()
|
|
||||||
.map(|msg| msg.content.clone())
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
let api_key = api_keys
|
let api_key = api_keys
|
||||||
.entry(provider.as_str().to_string())
|
.entry(provider.as_str().to_string())
|
||||||
.or_insert_with(String::new)
|
.or_insert_with(String::new)
|
||||||
|
|
@ -213,12 +170,13 @@ impl WorkerBlueprint {
|
||||||
let llm_client = provider.create_client(&model_name, &api_key)?;
|
let llm_client = provider.create_client(&model_name, &api_key)?;
|
||||||
let mut worker = Worker {
|
let mut worker = Worker {
|
||||||
llm_client: Box::new(llm_client),
|
llm_client: Box::new(llm_client),
|
||||||
system_prompt: base_system_prompt,
|
system_prompt: base_system_prompt.clone(),
|
||||||
system_prompt_fn: Arc::clone(&system_prompt_fn),
|
system_prompt_fn: Arc::clone(&system_prompt_fn),
|
||||||
tools,
|
tools,
|
||||||
api_key,
|
api_key,
|
||||||
model,
|
provider_str: provider.as_str().to_string(),
|
||||||
message_history: base_messages,
|
model_name,
|
||||||
|
message_history: base_messages.clone(),
|
||||||
hook_manager: HookManager::new(),
|
hook_manager: HookManager::new(),
|
||||||
mcp_lazy_configs: Vec::new(),
|
mcp_lazy_configs: Vec::new(),
|
||||||
plugin_registry: Arc::new(Mutex::new(plugin::PluginRegistry::new())),
|
plugin_registry: Arc::new(Mutex::new(plugin::PluginRegistry::new())),
|
||||||
|
|
@ -227,40 +185,6 @@ impl WorkerBlueprint {
|
||||||
Ok(worker)
|
Ok(worker)
|
||||||
}
|
}
|
||||||
ProviderConfig::Plugin { id, registry } => {
|
ProviderConfig::Plugin { id, registry } => {
|
||||||
let features = provided_features
|
|
||||||
.take()
|
|
||||||
.unwrap_or_else(ModelFeatures::default);
|
|
||||||
|
|
||||||
let model = Model {
|
|
||||||
provider: ModelProvider::Plugin(id.clone()),
|
|
||||||
name: model_name.clone(),
|
|
||||||
features,
|
|
||||||
};
|
|
||||||
|
|
||||||
let prompt_context = Worker::create_system_prompt_context(&model, &tool_names);
|
|
||||||
let base_messages = if let Some(messages) = prompt_cache.take() {
|
|
||||||
if messages.is_empty() {
|
|
||||||
vec![Message::new(
|
|
||||||
Role::System,
|
|
||||||
system_prompt_fn(&prompt_context, &[])
|
|
||||||
.map_err(|e| WorkerError::config(e.to_string()))?,
|
|
||||||
)]
|
|
||||||
} else {
|
|
||||||
messages
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
vec![Message::new(
|
|
||||||
Role::System,
|
|
||||||
system_prompt_fn(&prompt_context, &[])
|
|
||||||
.map_err(|e| WorkerError::config(e.to_string()))?,
|
|
||||||
)]
|
|
||||||
};
|
|
||||||
|
|
||||||
let base_system_prompt = base_messages
|
|
||||||
.first()
|
|
||||||
.map(|msg| msg.content.clone())
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
let api_key = api_keys
|
let api_key = api_keys
|
||||||
.remove("__plugin__")
|
.remove("__plugin__")
|
||||||
.or_else(|| api_keys.values().next().cloned())
|
.or_else(|| api_keys.values().next().cloned())
|
||||||
|
|
@ -284,7 +208,8 @@ impl WorkerBlueprint {
|
||||||
system_prompt_fn,
|
system_prompt_fn,
|
||||||
tools,
|
tools,
|
||||||
api_key,
|
api_key,
|
||||||
model,
|
provider_str: id,
|
||||||
|
model_name,
|
||||||
message_history: base_messages,
|
message_history: base_messages,
|
||||||
hook_manager: HookManager::new(),
|
hook_manager: HookManager::new(),
|
||||||
mcp_lazy_configs: Vec::new(),
|
mcp_lazy_configs: Vec::new(),
|
||||||
|
|
@ -295,6 +220,15 @@ impl WorkerBlueprint {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn build_system_prompt_context(
|
||||||
|
&self,
|
||||||
|
provider: &ProviderConfig,
|
||||||
|
model_name: &str,
|
||||||
|
tool_names: &[String],
|
||||||
|
) -> SystemPromptContext {
|
||||||
|
Worker::create_system_prompt_context(provider.provider_hint(), model_name, tool_names)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ProviderConfig {
|
impl ProviderConfig {
|
||||||
|
|
@ -311,4 +245,11 @@ impl ProviderConfig {
|
||||||
ProviderConfig::Plugin { registry, .. } => Some(Arc::clone(registry)),
|
ProviderConfig::Plugin { registry, .. } => Some(Arc::clone(registry)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn provider_hint(&self) -> LlmProvider {
|
||||||
|
match self {
|
||||||
|
ProviderConfig::BuiltIn(provider) => *provider,
|
||||||
|
ProviderConfig::Plugin { .. } => LlmProvider::OpenAI,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,8 @@ use llm::{
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::fs;
|
||||||
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tracing;
|
use tracing;
|
||||||
use uuid;
|
use uuid;
|
||||||
|
|
@ -322,67 +324,127 @@ pub async fn validate_api_key(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct ModelsConfig {
|
||||||
|
pub models: Vec<ModelDefinition>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct ModelDefinition {
|
||||||
|
pub model: String,
|
||||||
|
pub name: String,
|
||||||
|
pub meta: ModelMeta,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct ModelMeta {
|
||||||
|
pub tool_support: bool,
|
||||||
|
pub function_calling: bool,
|
||||||
|
pub vision: bool,
|
||||||
|
pub multimodal: bool,
|
||||||
|
pub context_length: Option<u32>,
|
||||||
|
pub description: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_models_config_path() -> Result<PathBuf, WorkerError> {
|
||||||
|
let home_dir = dirs::home_dir()
|
||||||
|
.ok_or_else(|| WorkerError::config("Could not determine home directory"))?;
|
||||||
|
Ok(home_dir.join(".config").join("nia").join("models.yaml"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_models_config() -> Result<ModelsConfig, WorkerError> {
|
||||||
|
let config_path = get_models_config_path()?;
|
||||||
|
|
||||||
|
if !config_path.exists() {
|
||||||
|
tracing::warn!(
|
||||||
|
"Models config file not found at {:?}, using defaults",
|
||||||
|
config_path
|
||||||
|
);
|
||||||
|
return Ok(ModelsConfig { models: vec![] });
|
||||||
|
}
|
||||||
|
|
||||||
|
let content = fs::read_to_string(&config_path)
|
||||||
|
.map_err(|e| WorkerError::config(format!("Failed to read models config: {}", e)))?;
|
||||||
|
|
||||||
|
let config: ModelsConfig = serde_yaml::from_str(&content)
|
||||||
|
.map_err(|e| WorkerError::config(format!("Failed to parse models config: {}", e)))?;
|
||||||
|
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn supports_native_tools(
|
||||||
|
provider: &LlmProvider,
|
||||||
|
model_name: &str,
|
||||||
|
_api_key: &str,
|
||||||
|
) -> Result<bool, WorkerError> {
|
||||||
|
let config = load_models_config()?;
|
||||||
|
|
||||||
|
let model_id = format!(
|
||||||
|
"{}/{}",
|
||||||
|
match provider {
|
||||||
|
LlmProvider::Claude => "anthropic",
|
||||||
|
LlmProvider::OpenAI => "openai",
|
||||||
|
LlmProvider::Gemini => "gemini",
|
||||||
|
LlmProvider::Ollama => "ollama",
|
||||||
|
LlmProvider::XAI => "xai",
|
||||||
|
},
|
||||||
|
model_name
|
||||||
|
);
|
||||||
|
|
||||||
|
for model_def in &config.models {
|
||||||
|
if model_def.model == model_id || model_def.model.contains(model_name) {
|
||||||
|
tracing::debug!(
|
||||||
|
"Found model config: model={}, function_calling={}",
|
||||||
|
model_def.model,
|
||||||
|
model_def.meta.function_calling
|
||||||
|
);
|
||||||
|
return Ok(model_def.meta.function_calling);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::warn!(
|
||||||
|
"Model not found in config: {} ({}), using provider defaults",
|
||||||
|
model_id,
|
||||||
|
model_name
|
||||||
|
);
|
||||||
|
|
||||||
|
tracing::warn!(
|
||||||
|
"Using provider-based fallback - this should be configured in models.yaml: provider={:?}, model={}",
|
||||||
|
provider,
|
||||||
|
model_name
|
||||||
|
);
|
||||||
|
|
||||||
|
let supports_tools = match provider {
|
||||||
|
LlmProvider::Claude => true,
|
||||||
|
LlmProvider::OpenAI => !model_name.contains("gpt-3.5-turbo-instruct"),
|
||||||
|
LlmProvider::Gemini => !model_name.contains("gemini-pro-vision"),
|
||||||
|
LlmProvider::Ollama => false,
|
||||||
|
LlmProvider::XAI => true,
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
"Fallback tool support check: provider={:?}, model={}, supports_tools={}",
|
||||||
|
provider,
|
||||||
|
model_name,
|
||||||
|
supports_tools
|
||||||
|
);
|
||||||
|
Ok(supports_tools)
|
||||||
|
}
|
||||||
|
|
||||||
pub struct Worker {
|
pub struct Worker {
|
||||||
pub(crate) llm_client: Box<dyn LlmClientTrait>,
|
pub(crate) llm_client: Box<dyn LlmClientTrait>,
|
||||||
pub(crate) system_prompt: String,
|
pub(crate) system_prompt: String,
|
||||||
pub(crate) system_prompt_fn: Arc<SystemPromptFn>,
|
pub(crate) system_prompt_fn: Arc<SystemPromptFn>,
|
||||||
pub(crate) tools: Vec<Box<dyn Tool>>,
|
pub(crate) tools: Vec<Box<dyn Tool>>,
|
||||||
pub(crate) api_key: String,
|
pub(crate) api_key: String,
|
||||||
pub(crate) model: Model,
|
pub(crate) provider_str: String,
|
||||||
|
pub(crate) model_name: String,
|
||||||
pub(crate) message_history: Vec<Message>,
|
pub(crate) message_history: Vec<Message>,
|
||||||
pub(crate) hook_manager: crate::types::HookManager,
|
pub(crate) hook_manager: crate::types::HookManager,
|
||||||
pub(crate) mcp_lazy_configs: Vec<McpServerConfig>,
|
pub(crate) mcp_lazy_configs: Vec<McpServerConfig>,
|
||||||
pub(crate) plugin_registry: std::sync::Arc<std::sync::Mutex<plugin::PluginRegistry>>,
|
pub(crate) plugin_registry: std::sync::Arc<std::sync::Mutex<plugin::PluginRegistry>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub enum ModelProvider {
|
|
||||||
BuiltIn(LlmProvider),
|
|
||||||
Plugin(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ModelProvider {
|
|
||||||
pub fn identifier(&self) -> String {
|
|
||||||
match self {
|
|
||||||
ModelProvider::BuiltIn(provider) => provider.as_str().to_string(),
|
|
||||||
ModelProvider::Plugin(id) => id.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn as_llm_provider(&self) -> Option<LlmProvider> {
|
|
||||||
match self {
|
|
||||||
ModelProvider::BuiltIn(provider) => Some(*provider),
|
|
||||||
ModelProvider::Plugin(_) => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct Model {
|
|
||||||
pub provider: ModelProvider,
|
|
||||||
pub name: String,
|
|
||||||
pub features: ModelFeatures,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
|
||||||
pub struct ModelFeatures {
|
|
||||||
pub supports_tools: bool,
|
|
||||||
pub supports_function_calling: bool,
|
|
||||||
pub supports_vision: bool,
|
|
||||||
pub supports_multimodal: bool,
|
|
||||||
pub context_length: Option<u64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Model {
|
|
||||||
pub fn provider_id(&self) -> String {
|
|
||||||
self.provider.identifier()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn built_in_provider(&self) -> Option<LlmProvider> {
|
|
||||||
self.provider.as_llm_provider()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Worker {
|
impl Worker {
|
||||||
/// Create a new Worker blueprint
|
/// Create a new Worker blueprint
|
||||||
///
|
///
|
||||||
|
|
@ -442,20 +504,21 @@ impl Worker {
|
||||||
system_prompt_fn,
|
system_prompt_fn,
|
||||||
tools,
|
tools,
|
||||||
api_key,
|
api_key,
|
||||||
model,
|
provider_str,
|
||||||
|
model_name,
|
||||||
message_history,
|
message_history,
|
||||||
hook_manager,
|
hook_manager,
|
||||||
mcp_lazy_configs: _,
|
mcp_lazy_configs: _,
|
||||||
plugin_registry,
|
plugin_registry,
|
||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
let provider = match &model.provider {
|
let provider = match LlmProvider::from_str(&provider_str) {
|
||||||
ModelProvider::BuiltIn(p) => {
|
Some(p) => {
|
||||||
drop(plugin_registry);
|
drop(plugin_registry);
|
||||||
ProviderConfig::BuiltIn(*p)
|
ProviderConfig::BuiltIn(p)
|
||||||
}
|
}
|
||||||
ModelProvider::Plugin(id) => ProviderConfig::Plugin {
|
None => ProviderConfig::Plugin {
|
||||||
id: id.clone(),
|
id: provider_str.clone(),
|
||||||
registry: plugin_registry,
|
registry: plugin_registry,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
@ -472,13 +535,12 @@ impl Worker {
|
||||||
|
|
||||||
WorkerBlueprint {
|
WorkerBlueprint {
|
||||||
provider: Some(provider),
|
provider: Some(provider),
|
||||||
model_name: Some(model.name),
|
model_name: Some(model_name),
|
||||||
api_keys,
|
api_keys,
|
||||||
system_prompt_fn: Some(system_prompt_fn),
|
system_prompt_fn: Some(system_prompt_fn),
|
||||||
tools,
|
tools,
|
||||||
hooks: hook_manager.into_hooks(),
|
hooks: hook_manager.into_hooks(),
|
||||||
prompt_cache: Some(message_history),
|
prompt_cache: Some(message_history),
|
||||||
model_features: Some(model.features),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -753,19 +815,27 @@ impl Worker {
|
||||||
|
|
||||||
/// 静的プロンプトコンテキストを作成(構築時用)
|
/// 静的プロンプトコンテキストを作成(構築時用)
|
||||||
pub(crate) fn create_system_prompt_context(
|
pub(crate) fn create_system_prompt_context(
|
||||||
model: &Model,
|
provider: LlmProvider,
|
||||||
|
model_name: &str,
|
||||||
tools: &[String],
|
tools: &[String],
|
||||||
) -> crate::prompt::SystemPromptContext {
|
) -> crate::prompt::SystemPromptContext {
|
||||||
|
let supports_native_tools = match provider {
|
||||||
|
LlmProvider::Claude => true,
|
||||||
|
LlmProvider::OpenAI => !model_name.contains("gpt-3.5-turbo-instruct"),
|
||||||
|
LlmProvider::Gemini => !model_name.contains("gemini-pro-vision"),
|
||||||
|
LlmProvider::Ollama => model_name.contains("llama") || model_name.contains("mistral"),
|
||||||
|
LlmProvider::XAI => true,
|
||||||
|
};
|
||||||
|
|
||||||
let model_context = crate::prompt::ModelContext {
|
let model_context = crate::prompt::ModelContext {
|
||||||
provider: model.built_in_provider(),
|
provider,
|
||||||
provider_id: model.provider_id(),
|
model_name: model_name.to_string(),
|
||||||
model_name: model.name.clone(),
|
|
||||||
capabilities: crate::prompt::ModelCapabilities {
|
capabilities: crate::prompt::ModelCapabilities {
|
||||||
supports_tools: model.features.supports_tools,
|
supports_tools: supports_native_tools,
|
||||||
supports_function_calling: model.features.supports_function_calling,
|
supports_function_calling: supports_native_tools,
|
||||||
supports_vision: model.features.supports_vision,
|
supports_vision: false,
|
||||||
supports_multimodal: Some(model.features.supports_multimodal),
|
supports_multimodal: Some(false),
|
||||||
context_length: model.features.context_length,
|
context_length: None,
|
||||||
capabilities: vec![],
|
capabilities: vec![],
|
||||||
needs_verification: Some(false),
|
needs_verification: Some(false),
|
||||||
},
|
},
|
||||||
|
|
@ -780,38 +850,23 @@ impl Worker {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_model_features(provider: Option<LlmProvider>, model_name: &str) -> ModelFeatures {
|
|
||||||
let mut features = ModelFeatures::default();
|
|
||||||
|
|
||||||
if let Some(provider) = provider {
|
|
||||||
let supports_tools = match provider {
|
|
||||||
LlmProvider::Claude => true,
|
|
||||||
LlmProvider::OpenAI => !model_name.contains("gpt-3.5-turbo-instruct"),
|
|
||||||
LlmProvider::Gemini => !model_name.contains("gemini-pro-vision"),
|
|
||||||
LlmProvider::Ollama => {
|
|
||||||
model_name.contains("llama") || model_name.contains("mistral")
|
|
||||||
}
|
|
||||||
LlmProvider::XAI => true,
|
|
||||||
};
|
|
||||||
|
|
||||||
features.supports_tools = supports_tools;
|
|
||||||
features.supports_function_calling = supports_tools;
|
|
||||||
}
|
|
||||||
|
|
||||||
features
|
|
||||||
}
|
|
||||||
|
|
||||||
/// モデルを変更する
|
/// モデルを変更する
|
||||||
pub fn change_model(&mut self, model: Model, api_key: &str) -> Result<(), WorkerError> {
|
pub fn change_model(
|
||||||
if let Some(provider) = model.built_in_provider() {
|
&mut self,
|
||||||
let new_client = provider.create_client(&model.name, api_key)?;
|
provider: LlmProvider,
|
||||||
self.llm_client = Box::new(new_client);
|
model_name: &str,
|
||||||
}
|
api_key: &str,
|
||||||
|
) -> Result<(), WorkerError> {
|
||||||
|
// 新しいLLMクライアントを作成
|
||||||
|
let new_client = provider.create_client(model_name, api_key)?;
|
||||||
|
|
||||||
self.model = model;
|
// 古いクライアントを新しいものに置き換え
|
||||||
|
self.llm_client = Box::new(new_client);
|
||||||
|
self.provider_str = provider.as_str().to_string();
|
||||||
|
self.model_name = model_name.to_string();
|
||||||
self.api_key = api_key.to_string();
|
self.api_key = api_key.to_string();
|
||||||
|
|
||||||
tracing::info!("Model changed to {}", self.model.provider_id());
|
tracing::info!("Model changed to {}/{}", provider.as_str(), model_name);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -876,18 +931,16 @@ impl Worker {
|
||||||
|
|
||||||
/// Get the model name for tool support detection
|
/// Get the model name for tool support detection
|
||||||
pub fn get_model_name(&self) -> String {
|
pub fn get_model_name(&self) -> String {
|
||||||
self.model.name.clone()
|
self.llm_client.get_model_name()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_provider_name(&self) -> String {
|
pub fn get_provider_name(&self) -> String {
|
||||||
self.model.provider_id()
|
self.llm_client.provider().to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get configuration information for task delegation
|
/// Get configuration information for task delegation
|
||||||
pub fn get_config(&self) -> Option<(LlmProvider, &str, &str)> {
|
pub fn get_config(&self) -> (LlmProvider, &str, &str) {
|
||||||
self.model
|
(self.llm_client.provider(), &self.model_name, &self.api_key)
|
||||||
.built_in_provider()
|
|
||||||
.map(|provider| (provider, self.model.name.as_str(), self.api_key.as_str()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get tool names (used to filter out specific tools)
|
/// Get tool names (used to filter out specific tools)
|
||||||
|
|
@ -983,37 +1036,25 @@ impl Worker {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create a temporary worker for processing without holding the lock
|
// Create a temporary worker for processing without holding the lock
|
||||||
let (llm_client, system_prompt, tool_definitions, model) = {
|
let (llm_client, system_prompt, tool_definitions, api_key, model_name) = {
|
||||||
let w_locked = worker.lock().await;
|
let w_locked = worker.lock().await;
|
||||||
let provider = match w_locked.model.built_in_provider() {
|
let llm_client = w_locked.llm_client.provider().create_client(&w_locked.model_name, &w_locked.api_key);
|
||||||
Some(provider) => provider,
|
match llm_client {
|
||||||
None => {
|
|
||||||
yield Err(WorkerError::config(
|
|
||||||
"Delegated processing is not supported for plugin providers",
|
|
||||||
));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
match provider.create_client(&w_locked.model.name, &w_locked.api_key) {
|
|
||||||
Ok(client) => {
|
Ok(client) => {
|
||||||
let tool_defs = w_locked
|
let tool_defs = w_locked.tools.iter().map(|tool| crate::types::DynamicToolDefinition {
|
||||||
.tools
|
name: tool.name().to_string(),
|
||||||
.iter()
|
description: tool.description().to_string(),
|
||||||
.map(|tool| crate::types::DynamicToolDefinition {
|
parameters_schema: tool.parameters_schema(),
|
||||||
name: tool.name().to_string(),
|
}).collect::<Vec<_>>();
|
||||||
description: tool.description().to_string(),
|
|
||||||
parameters_schema: tool.parameters_schema(),
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
(
|
(
|
||||||
client,
|
client,
|
||||||
w_locked.system_prompt.clone(),
|
w_locked.system_prompt.clone(),
|
||||||
tool_defs,
|
tool_defs,
|
||||||
w_locked.model.clone(),
|
w_locked.api_key.clone(),
|
||||||
|
w_locked.model_name.clone()
|
||||||
)
|
)
|
||||||
}
|
},
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
yield Err(e);
|
yield Err(e);
|
||||||
return;
|
return;
|
||||||
|
|
@ -1026,7 +1067,13 @@ impl Worker {
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let provider = llm_client.provider();
|
let provider = llm_client.provider();
|
||||||
let supports_native = model.features.supports_tools;
|
let supports_native = match supports_native_tools(&provider, &model_name, &api_key).await {
|
||||||
|
Ok(supports) => supports,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!("Failed to check native tool support: {}", e);
|
||||||
|
false
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let (composed_messages, tools_for_llm) = if supports_native {
|
let (composed_messages, tools_for_llm) = if supports_native {
|
||||||
let messages =
|
let messages =
|
||||||
|
|
@ -1207,21 +1254,16 @@ impl Worker {
|
||||||
let tools = self.get_tools();
|
let tools = self.get_tools();
|
||||||
let provider = self.llm_client.provider();
|
let provider = self.llm_client.provider();
|
||||||
let model_name = self.get_model_name();
|
let model_name = self.get_model_name();
|
||||||
tracing::debug!(
|
tracing::debug!("Checking native tool support: provider={:?}, model_name={}, api_key_len={}, provider_str={}", provider, model_name, self.api_key.len(), self.provider_str);
|
||||||
"Checking native tool support: provider={:?}, model_name={}, api_key_len={}, provider_id={}",
|
let supports_native = match supports_native_tools(&provider, &model_name, &self.api_key).await {
|
||||||
provider,
|
Ok(supports) => supports,
|
||||||
model_name,
|
Err(e) => {
|
||||||
self.api_key.len(),
|
tracing::warn!("Failed to check native tool support: {}", e);
|
||||||
self.model.provider_id()
|
false
|
||||||
);
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let supports_native = self.model.features.supports_tools;
|
tracing::info!("Model {} supports native tools: {}", model_name, supports_native);
|
||||||
|
|
||||||
tracing::info!(
|
|
||||||
"Model {} supports native tools: {}",
|
|
||||||
model_name,
|
|
||||||
supports_native
|
|
||||||
);
|
|
||||||
|
|
||||||
let (composed_messages, tools_for_llm) = if supports_native {
|
let (composed_messages, tools_for_llm) = if supports_native {
|
||||||
// Native tools - basic composition
|
// Native tools - basic composition
|
||||||
|
|
@ -1429,16 +1471,15 @@ impl Worker {
|
||||||
loop {
|
loop {
|
||||||
let tools = self.get_tools();
|
let tools = self.get_tools();
|
||||||
let provider = self.llm_client.provider();
|
let provider = self.llm_client.provider();
|
||||||
let model_name = self.model.name.clone();
|
let model_name = self.get_model_name();
|
||||||
tracing::debug!(
|
tracing::debug!("Checking native tool support: provider={:?}, model_name={}, api_key_len={}, provider_str={}", provider, model_name, self.api_key.len(), self.provider_str);
|
||||||
"Checking native tool support: provider={:?}, model_name={}, api_key_len={}, provider_id={}",
|
let supports_native = match supports_native_tools(&provider, &model_name, &self.api_key).await {
|
||||||
provider,
|
Ok(supports) => supports,
|
||||||
model_name,
|
Err(e) => {
|
||||||
self.api_key.len(),
|
tracing::warn!("Failed to check native tool support: {}", e);
|
||||||
self.model.provider_id()
|
false
|
||||||
);
|
}
|
||||||
|
};
|
||||||
let supports_native = self.model.features.supports_tools;
|
|
||||||
|
|
||||||
tracing::info!("Model {} supports native tools: {}", model_name, supports_native);
|
tracing::info!("Model {} supports native tools: {}", model_name, supports_native);
|
||||||
|
|
||||||
|
|
@ -1632,8 +1673,8 @@ impl Worker {
|
||||||
let session_id = uuid::Uuid::new_v4().to_string();
|
let session_id = uuid::Uuid::new_v4().to_string();
|
||||||
let mut session_data = SessionData::new(
|
let mut session_data = SessionData::new(
|
||||||
session_id,
|
session_id,
|
||||||
self.model.provider_id(),
|
self.provider_str.clone(),
|
||||||
self.model.name.clone(),
|
self.model_name.clone(),
|
||||||
workspace_path,
|
workspace_path,
|
||||||
);
|
);
|
||||||
session_data.git_branch = git_branch;
|
session_data.git_branch = git_branch;
|
||||||
|
|
@ -1647,15 +1688,15 @@ impl Worker {
|
||||||
/// セッションデータから履歴を復元する
|
/// セッションデータから履歴を復元する
|
||||||
pub fn load_session(&mut self, session_data: &SessionData) -> Result<(), WorkerError> {
|
pub fn load_session(&mut self, session_data: &SessionData) -> Result<(), WorkerError> {
|
||||||
// モデルが異なる場合は警告をログに出す
|
// モデルが異なる場合は警告をログに出す
|
||||||
if session_data.model_provider != self.model.provider_id()
|
if session_data.model_provider != self.provider_str
|
||||||
|| session_data.model_name != self.model.name
|
|| session_data.model_name != self.model_name
|
||||||
{
|
{
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
"Loading session with different model: session={}:{}, current={}:{}",
|
"Loading session with different model: session={}:{}, current={}:{}",
|
||||||
session_data.model_provider,
|
session_data.model_provider,
|
||||||
session_data.model_name,
|
session_data.model_name,
|
||||||
self.model.provider_id(),
|
self.provider_str,
|
||||||
self.model.name
|
self.model_name
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,8 @@ use super::tool::McpServerConfig;
|
||||||
use crate::types::WorkerError;
|
use crate::types::WorkerError;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tracing::{debug, warn};
|
use std::path::Path;
|
||||||
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
/// MCP設定ファイルの構造
|
/// MCP設定ファイルの構造
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
@ -56,10 +57,121 @@ fn default_integration_mode() -> IntegrationMode {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl McpConfig {
|
impl McpConfig {
|
||||||
|
/// 設定ファイルを読み込む
|
||||||
|
pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self, WorkerError> {
|
||||||
|
let path = path.as_ref();
|
||||||
|
|
||||||
|
if !path.exists() {
|
||||||
|
debug!(
|
||||||
|
"MCP config file not found at {:?}, returning empty config",
|
||||||
|
path
|
||||||
|
);
|
||||||
|
return Ok(Self::default());
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("Loading MCP config from: {:?}", path);
|
||||||
|
let content = std::fs::read_to_string(path).map_err(|e| {
|
||||||
|
WorkerError::config(format!("Failed to read MCP config file {:?}: {}", path, e))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let config: McpConfig = serde_yaml::from_str(&content).map_err(|e| {
|
||||||
|
WorkerError::config(format!("Failed to parse MCP config file {:?}: {}", path, e))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
info!("Loaded {} MCP server configurations", config.servers.len());
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 設定ファイルに保存する
|
||||||
|
pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<(), WorkerError> {
|
||||||
|
let path = path.as_ref();
|
||||||
|
|
||||||
|
// ディレクトリが存在しない場合は作成
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
std::fs::create_dir_all(parent).map_err(|e| {
|
||||||
|
WorkerError::config(format!(
|
||||||
|
"Failed to create config directory {:?}: {}",
|
||||||
|
parent, e
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let content = serde_yaml::to_string(self)
|
||||||
|
.map_err(|e| WorkerError::config(format!("Failed to serialize MCP config: {}", e)))?;
|
||||||
|
|
||||||
|
std::fs::write(path, content).map_err(|e| {
|
||||||
|
WorkerError::config(format!("Failed to write MCP config file {:?}: {}", path, e))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
info!("Saved MCP config to: {:?}", path);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// 有効なサーバー設定を取得
|
/// 有効なサーバー設定を取得
|
||||||
pub fn get_enabled_servers(&self) -> Vec<(&String, &McpServerDefinition)> {
|
pub fn get_enabled_servers(&self) -> Vec<(&String, &McpServerDefinition)> {
|
||||||
self.servers.iter().filter(|(_, def)| def.enabled).collect()
|
self.servers.iter().filter(|(_, def)| def.enabled).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// デフォルト設定ファイルを生成
|
||||||
|
pub fn create_default_config() -> Self {
|
||||||
|
let mut servers = HashMap::new();
|
||||||
|
|
||||||
|
// Brave Search MCP Server の設定例
|
||||||
|
servers.insert(
|
||||||
|
"brave_search".to_string(),
|
||||||
|
McpServerDefinition {
|
||||||
|
command: "npx".to_string(),
|
||||||
|
args: vec![
|
||||||
|
"-y".to_string(),
|
||||||
|
"@brave/brave-search-mcp-server".to_string(),
|
||||||
|
],
|
||||||
|
env: {
|
||||||
|
let mut env = HashMap::new();
|
||||||
|
env.insert("BRAVE_API_KEY".to_string(), "${BRAVE_API_KEY}".to_string());
|
||||||
|
env
|
||||||
|
},
|
||||||
|
description: Some("Brave Search API for web searching".to_string()),
|
||||||
|
enabled: false, // デフォルトでは無効(APIキーが必要なため)
|
||||||
|
integration_mode: IntegrationMode::Individual,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
// ファイルシステムMCPサーバーの設定例
|
||||||
|
servers.insert(
|
||||||
|
"filesystem".to_string(),
|
||||||
|
McpServerDefinition {
|
||||||
|
command: "npx".to_string(),
|
||||||
|
args: vec![
|
||||||
|
"-y".to_string(),
|
||||||
|
"@modelcontextprotocol/server-filesystem".to_string(),
|
||||||
|
"/tmp".to_string(),
|
||||||
|
],
|
||||||
|
env: HashMap::new(),
|
||||||
|
description: Some("Filesystem operations in /tmp directory".to_string()),
|
||||||
|
enabled: false, // デフォルトでは無効
|
||||||
|
integration_mode: IntegrationMode::Individual,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
// Git MCP サーバーの設定例
|
||||||
|
servers.insert(
|
||||||
|
"git".to_string(),
|
||||||
|
McpServerDefinition {
|
||||||
|
command: "npx".to_string(),
|
||||||
|
args: vec![
|
||||||
|
"-y".to_string(),
|
||||||
|
"@modelcontextprotocol/server-git".to_string(),
|
||||||
|
".".to_string(),
|
||||||
|
],
|
||||||
|
env: HashMap::new(),
|
||||||
|
description: Some("Git operations in current directory".to_string()),
|
||||||
|
enabled: false, // デフォルトでは無効
|
||||||
|
integration_mode: IntegrationMode::Individual,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
Self { servers }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for McpConfig {
|
impl Default for McpConfig {
|
||||||
|
|
@ -126,3 +238,120 @@ fn expand_environment_variables(input: &str) -> Result<String, WorkerError> {
|
||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::fs;
|
||||||
|
use tempfile::tempdir;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_default_config_creation() {
|
||||||
|
let config = McpConfig::create_default_config();
|
||||||
|
assert!(!config.servers.is_empty());
|
||||||
|
assert!(config.servers.contains_key("brave_search"));
|
||||||
|
assert!(config.servers.contains_key("filesystem"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_config_serialization() {
|
||||||
|
let config = McpConfig::create_default_config();
|
||||||
|
let yaml = serde_yaml::to_string(&config).unwrap();
|
||||||
|
|
||||||
|
// YAML形式で正しくシリアライズされることを確認
|
||||||
|
assert!(yaml.contains("servers:"));
|
||||||
|
assert!(yaml.contains("brave_search:"));
|
||||||
|
assert!(yaml.contains("command:"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_config_deserialization() {
|
||||||
|
let yaml_content = r#"
|
||||||
|
servers:
|
||||||
|
test_server:
|
||||||
|
command: "python3"
|
||||||
|
args: ["test.py"]
|
||||||
|
env:
|
||||||
|
TEST_VAR: "test_value"
|
||||||
|
description: "Test server"
|
||||||
|
enabled: true
|
||||||
|
integration_mode: "individual"
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let config: McpConfig = serde_yaml::from_str(yaml_content).unwrap();
|
||||||
|
assert_eq!(config.servers.len(), 1);
|
||||||
|
|
||||||
|
let server = config.servers.get("test_server").unwrap();
|
||||||
|
assert_eq!(server.command, "python3");
|
||||||
|
assert_eq!(server.args, vec!["test.py"]);
|
||||||
|
assert_eq!(server.env.get("TEST_VAR").unwrap(), "test_value");
|
||||||
|
assert!(server.enabled);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_environment_variable_expansion() {
|
||||||
|
// SAFETY: Setting test environment variables in a single-threaded test context
|
||||||
|
unsafe {
|
||||||
|
std::env::set_var("TEST_VAR", "test_value");
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = expand_environment_variables("prefix_${TEST_VAR}_suffix").unwrap();
|
||||||
|
assert_eq!(result, "prefix_test_value_suffix");
|
||||||
|
|
||||||
|
// 存在しない環境変数の場合はそのまま残る
|
||||||
|
let result = expand_environment_variables("${NON_EXISTENT_VAR}").unwrap();
|
||||||
|
assert_eq!(result, "${NON_EXISTENT_VAR}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_config_file_operations() {
|
||||||
|
let dir = tempdir().unwrap();
|
||||||
|
let config_path = dir.path().join("mcp.yaml");
|
||||||
|
|
||||||
|
// 設定を作成して保存
|
||||||
|
let config = McpConfig::create_default_config();
|
||||||
|
config.save_to_file(&config_path).unwrap();
|
||||||
|
|
||||||
|
// ファイルが作成されたことを確認
|
||||||
|
assert!(config_path.exists());
|
||||||
|
|
||||||
|
// 設定を読み込み
|
||||||
|
let loaded_config = McpConfig::load_from_file(&config_path).unwrap();
|
||||||
|
assert_eq!(config.servers.len(), loaded_config.servers.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_enabled_servers_filter() {
|
||||||
|
let mut config = McpConfig::default();
|
||||||
|
|
||||||
|
// 有効なサーバーを追加
|
||||||
|
config.servers.insert(
|
||||||
|
"enabled_server".to_string(),
|
||||||
|
McpServerDefinition {
|
||||||
|
command: "test".to_string(),
|
||||||
|
args: vec![],
|
||||||
|
env: HashMap::new(),
|
||||||
|
description: None,
|
||||||
|
enabled: true,
|
||||||
|
integration_mode: IntegrationMode::Individual,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
// 無効なサーバーを追加
|
||||||
|
config.servers.insert(
|
||||||
|
"disabled_server".to_string(),
|
||||||
|
McpServerDefinition {
|
||||||
|
command: "test".to_string(),
|
||||||
|
args: vec![],
|
||||||
|
env: HashMap::new(),
|
||||||
|
description: None,
|
||||||
|
enabled: false,
|
||||||
|
integration_mode: IntegrationMode::Individual,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
let enabled_servers = config.get_enabled_servers();
|
||||||
|
assert_eq!(enabled_servers.len(), 1);
|
||||||
|
assert_eq!(enabled_servers[0].0, "enabled_server");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -184,7 +184,7 @@ impl McpClient {
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
client_info: ClientInfo {
|
client_info: ClientInfo {
|
||||||
name: "llm-worker-rs".to_string(),
|
name: "nia-worker".to_string(),
|
||||||
version: "0.1.0".to_string(),
|
version: "0.1.0".to_string(),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
@ -417,3 +417,33 @@ impl Drop for McpClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_json_rpc_serialization() {
|
||||||
|
let request = JsonRpcRequest {
|
||||||
|
jsonrpc: "2.0".to_string(),
|
||||||
|
id: Value::Number(serde_json::Number::from(1)),
|
||||||
|
method: "tools/list".to_string(),
|
||||||
|
params: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let json = serde_json::to_string(&request).unwrap();
|
||||||
|
assert!(json.contains("\"jsonrpc\":\"2.0\""));
|
||||||
|
assert!(json.contains("\"id\":1"));
|
||||||
|
assert!(json.contains("\"method\":\"tools/list\""));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_error_response() {
|
||||||
|
let response_json =
|
||||||
|
r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"Method not found"}}"#;
|
||||||
|
let response: JsonRpcResponse = serde_json::from_str(response_json).unwrap();
|
||||||
|
|
||||||
|
assert!(response.error.is_some());
|
||||||
|
assert_eq!(response.error.unwrap().code, -32601);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -451,3 +451,36 @@ pub async fn test_mcp_connection(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_mcp_server_config() {
|
||||||
|
let config =
|
||||||
|
McpServerConfig::new("npx", vec!["-y", "@modelcontextprotocol/server-everything"]);
|
||||||
|
assert_eq!(config.command, "npx");
|
||||||
|
assert_eq!(
|
||||||
|
config.args,
|
||||||
|
vec!["-y", "@modelcontextprotocol/server-everything"]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.name,
|
||||||
|
"npx(-y @modelcontextprotocol/server-everything)"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_mcp_tool_creation() {
|
||||||
|
let config = McpServerConfig::new("echo", vec!["test"]);
|
||||||
|
let tool = McpDynamicTool::new(config);
|
||||||
|
|
||||||
|
assert_eq!(tool.name(), "mcp_proxy");
|
||||||
|
assert!(!tool.description().is_empty());
|
||||||
|
|
||||||
|
let schema = tool.parameters_schema();
|
||||||
|
assert!(schema.is_object());
|
||||||
|
assert!(schema.get("properties").is_some());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -206,3 +206,42 @@ impl LlmClientTrait for CustomLlmClient {
|
||||||
pub extern "C" fn create_plugin() -> Box<dyn ProviderPlugin> {
|
pub extern "C" fn create_plugin() -> Box<dyn ProviderPlugin> {
|
||||||
Box::new(CustomProviderPlugin::new())
|
Box::new(CustomProviderPlugin::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_plugin_metadata() {
|
||||||
|
let plugin = CustomProviderPlugin::new();
|
||||||
|
let metadata = plugin.metadata();
|
||||||
|
|
||||||
|
assert_eq!(metadata.id, "custom-provider");
|
||||||
|
assert_eq!(metadata.name, "Custom LLM Provider");
|
||||||
|
assert!(metadata.requires_api_key);
|
||||||
|
assert_eq!(metadata.supported_models.len(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_api_key_validation() {
|
||||||
|
let plugin = CustomProviderPlugin::new();
|
||||||
|
|
||||||
|
assert!(plugin.validate_api_key("custom-1234567890abcdefghij"));
|
||||||
|
assert!(!plugin.validate_api_key("invalid-key"));
|
||||||
|
assert!(!plugin.validate_api_key("custom-short"));
|
||||||
|
assert!(!plugin.validate_api_key(""));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_plugin_initialization() {
|
||||||
|
let mut plugin = CustomProviderPlugin::new();
|
||||||
|
let mut config = HashMap::new();
|
||||||
|
config.insert(
|
||||||
|
"base_url".to_string(),
|
||||||
|
Value::String("https://api.example.com".to_string()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = plugin.initialize(config).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,7 @@ use std::collections::HashMap;
|
||||||
/// モデルに関する静的な情報
|
/// モデルに関する静的な情報
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ModelContext {
|
pub struct ModelContext {
|
||||||
pub provider: Option<crate::types::LlmProvider>,
|
pub provider: crate::types::LlmProvider,
|
||||||
pub provider_id: String,
|
|
||||||
pub model_name: String,
|
pub model_name: String,
|
||||||
pub capabilities: ModelCapabilities,
|
pub capabilities: ModelCapabilities,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,496 +0,0 @@
|
||||||
use std::sync::{Arc, Mutex, OnceLock};
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use futures::StreamExt;
|
|
||||||
use serde_json::{json, Value};
|
|
||||||
use wiremock::matchers::{method, path};
|
|
||||||
use wiremock::{Mock, MockGuard, MockServer, ResponseTemplate};
|
|
||||||
|
|
||||||
use worker::{
|
|
||||||
HookContext, HookResult, LlmProvider, PromptError, StreamEvent, Tool, ToolResult, Worker,
|
|
||||||
WorkerBlueprint, WorkerHook,
|
|
||||||
};
|
|
||||||
use worker_types::Role;
|
|
||||||
|
|
||||||
const SAMPLE_TOOL_NAME: &str = "sample_tool";
|
|
||||||
static ENV_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();
|
|
||||||
|
|
||||||
struct ProviderCase {
|
|
||||||
name: &'static str,
|
|
||||||
provider: LlmProvider,
|
|
||||||
model: &'static str,
|
|
||||||
env_var: &'static str,
|
|
||||||
completion_path: &'static str,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct SampleTool {
|
|
||||||
name: String,
|
|
||||||
description: String,
|
|
||||||
calls: Arc<Mutex<Vec<Value>>>,
|
|
||||||
response: Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SampleTool {
|
|
||||||
fn new(provider_label: &str, calls: Arc<Mutex<Vec<Value>>>) -> Self {
|
|
||||||
Self {
|
|
||||||
name: SAMPLE_TOOL_NAME.to_string(),
|
|
||||||
description: format!("Records invocations for {}", provider_label),
|
|
||||||
calls,
|
|
||||||
response: json!({
|
|
||||||
"status": "ok",
|
|
||||||
"provider": provider_label,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Tool for SampleTool {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
&self.name
|
|
||||||
}
|
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
|
||||||
&self.description
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parameters_schema(&self) -> Value {
|
|
||||||
json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"provider": {"type": "string"},
|
|
||||||
"request_id": {"type": "integer"}
|
|
||||||
},
|
|
||||||
"required": ["provider", "request_id"]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(&self, args: Value) -> ToolResult<Value> {
|
|
||||||
self.calls.lock().unwrap().push(args.clone());
|
|
||||||
Ok(self.response.clone())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct RecordingHook {
|
|
||||||
tool_name: String,
|
|
||||||
provider_label: String,
|
|
||||||
events: Arc<Mutex<Vec<String>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RecordingHook {
|
|
||||||
fn new(provider_label: &str, events: Arc<Mutex<Vec<String>>>) -> Self {
|
|
||||||
Self {
|
|
||||||
tool_name: SAMPLE_TOOL_NAME.to_string(),
|
|
||||||
provider_label: provider_label.to_string(),
|
|
||||||
events,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl WorkerHook for RecordingHook {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"recording_hook"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn hook_type(&self) -> &str {
|
|
||||||
"PostToolUse"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn matcher(&self) -> &str {
|
|
||||||
&self.tool_name
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(&self, context: HookContext) -> (HookContext, HookResult) {
|
|
||||||
let tool = context
|
|
||||||
.get_variable("current_tool")
|
|
||||||
.cloned()
|
|
||||||
.unwrap_or_else(|| self.tool_name.clone());
|
|
||||||
|
|
||||||
let entry = format!(
|
|
||||||
"{}::{}::{}",
|
|
||||||
self.provider_label, tool, context.content
|
|
||||||
);
|
|
||||||
self.events.lock().unwrap().push(entry);
|
|
||||||
|
|
||||||
let message = format!("{} hook observed {}", self.provider_label, tool);
|
|
||||||
(
|
|
||||||
context,
|
|
||||||
HookResult::AddMessage(message, Role::Assistant),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct EnvOverride {
|
|
||||||
key: String,
|
|
||||||
previous: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EnvOverride {
|
|
||||||
fn set(key: &str, value: String) -> Self {
|
|
||||||
let previous = std::env::var(key).ok();
|
|
||||||
std::env::set_var(key, &value);
|
|
||||||
Self {
|
|
||||||
key: key.to_string(),
|
|
||||||
previous,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Drop for EnvOverride {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
if let Some(prev) = self.previous.take() {
|
|
||||||
std::env::set_var(&self.key, prev);
|
|
||||||
} else {
|
|
||||||
std::env::remove_var(&self.key);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_blueprint(
|
|
||||||
provider: LlmProvider,
|
|
||||||
model: &str,
|
|
||||||
provider_label: &str,
|
|
||||||
tool_calls: Arc<Mutex<Vec<Value>>>,
|
|
||||||
hook_events: Arc<Mutex<Vec<String>>>,
|
|
||||||
) -> WorkerBlueprint {
|
|
||||||
let mut blueprint = Worker::blueprint();
|
|
||||||
blueprint
|
|
||||||
.provider(provider)
|
|
||||||
.model(model)
|
|
||||||
.api_key(provider.as_str(), "test-key")
|
|
||||||
.system_prompt_fn(|_, _| Ok::<_, PromptError>("Integration test system prompt.".into()))
|
|
||||||
.add_tool(SampleTool::new(provider_label, Arc::clone(&tool_calls)))
|
|
||||||
.attach_hook(RecordingHook::new(provider_label, Arc::clone(&hook_events)));
|
|
||||||
blueprint
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn setup_mock_response(
|
|
||||||
case: &ProviderCase,
|
|
||||||
server: &MockServer,
|
|
||||||
expected_args: &Value,
|
|
||||||
) -> MockGuard {
|
|
||||||
match case.provider {
|
|
||||||
LlmProvider::OpenAI => {
|
|
||||||
let arguments = expected_args.to_string();
|
|
||||||
let event_body = json!({
|
|
||||||
"choices": [{
|
|
||||||
"delta": {
|
|
||||||
"tool_calls": [{
|
|
||||||
"function": {
|
|
||||||
"name": SAMPLE_TOOL_NAME,
|
|
||||||
"arguments": arguments
|
|
||||||
}
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
}]
|
|
||||||
});
|
|
||||||
let sse = format!(
|
|
||||||
"data: {}\n\ndata: [DONE]\n\n",
|
|
||||||
event_body
|
|
||||||
);
|
|
||||||
Mock::given(method("POST"))
|
|
||||||
.and(path(case.completion_path))
|
|
||||||
.respond_with(
|
|
||||||
ResponseTemplate::new(200)
|
|
||||||
.set_body_raw(sse, "text/event-stream"),
|
|
||||||
)
|
|
||||||
.mount(server)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
LlmProvider::Claude => {
|
|
||||||
let event_tool = json!({
|
|
||||||
"type": "content_block_start",
|
|
||||||
"data": {
|
|
||||||
"content_block": {
|
|
||||||
"type": "tool_use",
|
|
||||||
"name": SAMPLE_TOOL_NAME,
|
|
||||||
"input": expected_args
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
let event_stop = json!({
|
|
||||||
"type": "message_stop",
|
|
||||||
"data": {}
|
|
||||||
});
|
|
||||||
let sse = format!(
|
|
||||||
"data: {}\n\ndata: {}\n\ndata: [DONE]\n\n",
|
|
||||||
event_tool, event_stop
|
|
||||||
);
|
|
||||||
Mock::given(method("POST"))
|
|
||||||
.and(path(case.completion_path))
|
|
||||||
.respond_with(
|
|
||||||
ResponseTemplate::new(200)
|
|
||||||
.set_body_raw(sse, "text/event-stream"),
|
|
||||||
)
|
|
||||||
.mount(server)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
LlmProvider::Gemini => {
|
|
||||||
let body = json!({
|
|
||||||
"candidates": [{
|
|
||||||
"content": {
|
|
||||||
"role": "model",
|
|
||||||
"parts": [{
|
|
||||||
"functionCall": {
|
|
||||||
"name": SAMPLE_TOOL_NAME,
|
|
||||||
"args": expected_args
|
|
||||||
}
|
|
||||||
}]
|
|
||||||
},
|
|
||||||
"finishReason": "STOP"
|
|
||||||
}]
|
|
||||||
});
|
|
||||||
Mock::given(method("POST"))
|
|
||||||
.and(path(case.completion_path))
|
|
||||||
.respond_with(ResponseTemplate::new(200).set_body_json(body))
|
|
||||||
.mount(server)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
LlmProvider::Ollama => {
|
|
||||||
let first = json!({
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "",
|
|
||||||
"tool_calls": [{
|
|
||||||
"function": {
|
|
||||||
"name": SAMPLE_TOOL_NAME,
|
|
||||||
"arguments": expected_args
|
|
||||||
}
|
|
||||||
}]
|
|
||||||
},
|
|
||||||
"done": false
|
|
||||||
});
|
|
||||||
let second = json!({
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "finished"
|
|
||||||
},
|
|
||||||
"done": true
|
|
||||||
});
|
|
||||||
let body = format!("{}\n{}\n", first, second);
|
|
||||||
Mock::given(method("POST"))
|
|
||||||
.and(path(case.completion_path))
|
|
||||||
.respond_with(
|
|
||||||
ResponseTemplate::new(200)
|
|
||||||
.set_body_raw(body, "application/x-ndjson"),
|
|
||||||
)
|
|
||||||
.mount(server)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
other => panic!("Unsupported provider in test: {:?}", other),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn provider_cases() -> Vec<ProviderCase> {
|
|
||||||
vec![
|
|
||||||
ProviderCase {
|
|
||||||
name: "openai",
|
|
||||||
provider: LlmProvider::OpenAI,
|
|
||||||
model: "gpt-4o-mini",
|
|
||||||
env_var: "OPENAI_BASE_URL",
|
|
||||||
completion_path: "/v1/chat/completions",
|
|
||||||
},
|
|
||||||
ProviderCase {
|
|
||||||
name: "gemini",
|
|
||||||
provider: LlmProvider::Gemini,
|
|
||||||
model: "gemini-1.5-flash",
|
|
||||||
env_var: "GEMINI_BASE_URL",
|
|
||||||
completion_path: "/v1beta/models/gemini-1.5-flash:streamGenerateContent",
|
|
||||||
},
|
|
||||||
ProviderCase {
|
|
||||||
name: "anthropic",
|
|
||||||
provider: LlmProvider::Claude,
|
|
||||||
model: "claude-3-opus-20240229",
|
|
||||||
env_var: "ANTHROPIC_BASE_URL",
|
|
||||||
completion_path: "/v1/messages",
|
|
||||||
},
|
|
||||||
ProviderCase {
|
|
||||||
name: "ollama",
|
|
||||||
provider: LlmProvider::Ollama,
|
|
||||||
model: "llama3",
|
|
||||||
env_var: "OLLAMA_BASE_URL",
|
|
||||||
completion_path: "/api/chat",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn worker_executes_tools_and_hooks_across_mocked_providers() {
|
|
||||||
let env_lock = ENV_MUTEX.get_or_init(|| Mutex::new(()));
|
|
||||||
|
|
||||||
for case in provider_cases() {
|
|
||||||
let tool_calls = Arc::new(Mutex::new(Vec::<Value>::new()));
|
|
||||||
let hook_events = Arc::new(Mutex::new(Vec::<String>::new()));
|
|
||||||
|
|
||||||
let _env_guard = env_lock.lock().unwrap();
|
|
||||||
|
|
||||||
let server = MockServer::start().await;
|
|
||||||
let _env_override = EnvOverride::set(case.env_var, server.uri());
|
|
||||||
|
|
||||||
let expected_args = json!({
|
|
||||||
"provider": case.name,
|
|
||||||
"request_id": 1
|
|
||||||
});
|
|
||||||
|
|
||||||
let _mock = setup_mock_response(&case, &server, &expected_args).await;
|
|
||||||
|
|
||||||
let mut blueprint = build_blueprint(
|
|
||||||
case.provider,
|
|
||||||
case.model,
|
|
||||||
case.name,
|
|
||||||
Arc::clone(&tool_calls),
|
|
||||||
Arc::clone(&hook_events),
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut worker = blueprint.instantiate().expect("worker to instantiate");
|
|
||||||
|
|
||||||
let mut stream = worker
|
|
||||||
.process_task_stream(
|
|
||||||
"Trigger the sample tool".to_string(),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let mut events = Vec::new();
|
|
||||||
while let Some(event) = stream.next().await {
|
|
||||||
events.push(event.expect("stream event"));
|
|
||||||
}
|
|
||||||
|
|
||||||
let requests = server
|
|
||||||
.received_requests()
|
|
||||||
.await
|
|
||||||
.expect("to inspect received requests");
|
|
||||||
assert_eq!(
|
|
||||||
requests.len(),
|
|
||||||
1,
|
|
||||||
"expected exactly one request for provider {}",
|
|
||||||
case.name
|
|
||||||
);
|
|
||||||
let body: Value =
|
|
||||||
serde_json::from_slice(&requests[0].body).expect("request body to be JSON");
|
|
||||||
|
|
||||||
match case.provider {
|
|
||||||
LlmProvider::OpenAI => {
|
|
||||||
assert_eq!(body["model"], case.model);
|
|
||||||
assert_eq!(body["stream"], true);
|
|
||||||
assert_eq!(
|
|
||||||
body["tools"][0]["function"]["name"],
|
|
||||||
SAMPLE_TOOL_NAME
|
|
||||||
);
|
|
||||||
}
|
|
||||||
LlmProvider::Claude => {
|
|
||||||
assert_eq!(body["model"], case.model);
|
|
||||||
assert_eq!(body["stream"], true);
|
|
||||||
assert_eq!(body["tools"][0]["name"], SAMPLE_TOOL_NAME);
|
|
||||||
}
|
|
||||||
LlmProvider::Gemini => {
|
|
||||||
assert_eq!(
|
|
||||||
body["contents"]
|
|
||||||
.as_array()
|
|
||||||
.expect("contents to be array")
|
|
||||||
.len(),
|
|
||||||
2,
|
|
||||||
"system + user messages should be present"
|
|
||||||
);
|
|
||||||
let tools = body["tools"][0]["functionDeclarations"]
|
|
||||||
.as_array()
|
|
||||||
.expect("function declarations to exist");
|
|
||||||
assert_eq!(tools[0]["name"], SAMPLE_TOOL_NAME);
|
|
||||||
}
|
|
||||||
LlmProvider::Ollama => {
|
|
||||||
assert_eq!(body["model"], case.model);
|
|
||||||
assert_eq!(body["stream"], true);
|
|
||||||
assert_eq!(
|
|
||||||
body["tools"][0]["function"]["name"],
|
|
||||||
SAMPLE_TOOL_NAME
|
|
||||||
);
|
|
||||||
}
|
|
||||||
_ => unreachable!(),
|
|
||||||
}
|
|
||||||
|
|
||||||
let recorded_calls = tool_calls.lock().unwrap().clone();
|
|
||||||
assert_eq!(
|
|
||||||
recorded_calls.len(),
|
|
||||||
1,
|
|
||||||
"tool should execute exactly once for {}",
|
|
||||||
case.name
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
recorded_calls[0], expected_args,
|
|
||||||
"tool arguments should match for {}",
|
|
||||||
case.name
|
|
||||||
);
|
|
||||||
|
|
||||||
let recorded_hooks = hook_events.lock().unwrap().clone();
|
|
||||||
assert!(
|
|
||||||
recorded_hooks
|
|
||||||
.iter()
|
|
||||||
.any(|entry| entry.contains(case.name) && entry.contains(SAMPLE_TOOL_NAME)),
|
|
||||||
"hook should capture tool usage for {}: {:?}",
|
|
||||||
case.name,
|
|
||||||
recorded_hooks
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut saw_tool_call = false;
|
|
||||||
let mut saw_tool_result = false;
|
|
||||||
let mut saw_hook_message = false;
|
|
||||||
let mut saw_completion = false;
|
|
||||||
|
|
||||||
for event in &events {
|
|
||||||
match event {
|
|
||||||
StreamEvent::ToolCall(call) => {
|
|
||||||
saw_tool_call = true;
|
|
||||||
assert_eq!(call.name, SAMPLE_TOOL_NAME);
|
|
||||||
assert_eq!(
|
|
||||||
serde_json::from_str::<Value>(&call.arguments)
|
|
||||||
.expect("tool arguments to be JSON"),
|
|
||||||
expected_args
|
|
||||||
);
|
|
||||||
}
|
|
||||||
StreamEvent::ToolResult { tool_name, result } => {
|
|
||||||
if tool_name == SAMPLE_TOOL_NAME {
|
|
||||||
saw_tool_result = true;
|
|
||||||
let value = result
|
|
||||||
.as_ref()
|
|
||||||
.expect("tool execution should succeed");
|
|
||||||
assert_eq!(value["status"], "ok");
|
|
||||||
assert_eq!(value["provider"], case.name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
StreamEvent::HookMessage { hook_name, content, .. } => {
|
|
||||||
if hook_name == "recording_hook" {
|
|
||||||
saw_hook_message = true;
|
|
||||||
assert!(
|
|
||||||
content.contains(case.name),
|
|
||||||
"hook content should mention provider"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
StreamEvent::Completion(_) => {
|
|
||||||
saw_completion = true;
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert!(saw_tool_call, "missing tool call event for {}", case.name);
|
|
||||||
assert!(
|
|
||||||
saw_tool_result,
|
|
||||||
"missing tool result event for {}",
|
|
||||||
case.name
|
|
||||||
);
|
|
||||||
assert!(
|
|
||||||
saw_hook_message,
|
|
||||||
"missing hook message for {}",
|
|
||||||
case.name
|
|
||||||
);
|
|
||||||
assert!(
|
|
||||||
saw_completion,
|
|
||||||
"missing completion event for {}",
|
|
||||||
case.name
|
|
||||||
);
|
|
||||||
|
|
||||||
std::env::remove_var(case.env_var);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Loading…
Reference in New Issue
Block a user