Compare commits


No commits in common. "b206acc3d39d05cd480e80699bca1459eebfbe30" and "cc6bbe2a43cb8e660849f776bdc517e2bc7a6af3" have entirely different histories.

17 changed files with 642 additions and 997 deletions

Cargo.lock (generated)

@@ -26,16 +26,6 @@ version = "1.0.100"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
-[[package]]
-name = "assert-json-diff"
-version = "2.0.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12"
-dependencies = [
- "serde",
- "serde_json",
-]
 [[package]]
 name = "async-stream"
 version = "0.3.6"
@@ -69,12 +59,6 @@ dependencies = [
  "syn",
 ]
-[[package]]
-name = "atomic-waker"
-version = "1.1.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
 [[package]]
 name = "autocfg"
 version = "1.5.0"
@@ -87,12 +71,6 @@ version = "0.21.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
-[[package]]
-name = "base64"
-version = "0.22.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
 [[package]]
 name = "bitflags"
 version = "1.3.2"
@@ -163,24 +141,6 @@ version = "0.8.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
-[[package]]
-name = "deadpool"
-version = "0.12.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b"
-dependencies = [
- "deadpool-runtime",
- "lazy_static",
- "num_cpus",
- "tokio",
-]
-[[package]]
-name = "deadpool-runtime"
-version = "0.1.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b"
 [[package]]
 name = "dirs"
 version = "6.0.0"
@@ -405,26 +365,7 @@ dependencies = [
  "futures-core",
  "futures-sink",
  "futures-util",
- "http 0.2.12",
- "indexmap",
- "slab",
- "tokio",
- "tokio-util",
- "tracing",
-]
-[[package]]
-name = "h2"
-version = "0.4.12"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386"
-dependencies = [
- "atomic-waker",
- "bytes",
- "fnv",
- "futures-core",
- "futures-sink",
- "http 1.3.1",
+ "http",
  "indexmap",
  "slab",
  "tokio",
@@ -444,12 +385,6 @@ version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
-[[package]]
-name = "hermit-abi"
-version = "0.5.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
 [[package]]
 name = "http"
 version = "0.2.12"
@@ -461,17 +396,6 @@ dependencies = [
  "itoa",
 ]
-[[package]]
-name = "http"
-version = "1.3.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565"
-dependencies = [
- "bytes",
- "fnv",
- "itoa",
-]
 [[package]]
 name = "http-body"
 version = "0.4.6"
@@ -479,30 +403,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2"
 dependencies = [
  "bytes",
- "http 0.2.12",
- "pin-project-lite",
-]
-[[package]]
-name = "http-body"
-version = "1.0.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
-dependencies = [
- "bytes",
- "http 1.3.1",
-]
-[[package]]
-name = "http-body-util"
-version = "0.1.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
-dependencies = [
- "bytes",
- "futures-core",
- "http 1.3.1",
- "http-body 1.0.1",
+ "http",
  "pin-project-lite",
 ]
@@ -528,9 +429,9 @@ dependencies = [
  "futures-channel",
  "futures-core",
  "futures-util",
- "h2 0.3.27",
- "http 0.2.12",
- "http-body 0.4.6",
+ "h2",
+ "http",
+ "http-body",
  "httparse",
  "httpdate",
  "itoa",
@@ -542,29 +443,6 @@ dependencies = [
  "want",
 ]
-[[package]]
-name = "hyper"
-version = "1.7.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e"
-dependencies = [
- "atomic-waker",
- "bytes",
- "futures-channel",
- "futures-core",
- "h2 0.4.12",
- "http 1.3.1",
- "http-body 1.0.1",
- "httparse",
- "httpdate",
- "itoa",
- "pin-project-lite",
- "pin-utils",
- "smallvec",
- "tokio",
- "want",
-]
 [[package]]
 name = "hyper-rustls"
 version = "0.24.2"
@@ -572,28 +450,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590"
 dependencies = [
  "futures-util",
- "http 0.2.12",
- "hyper 0.14.32",
+ "http",
+ "hyper",
  "rustls",
  "tokio",
  "tokio-rustls",
 ]
-[[package]]
-name = "hyper-util"
-version = "0.1.17"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8"
-dependencies = [
- "bytes",
- "futures-core",
- "http 1.3.1",
- "http-body 1.0.1",
- "hyper 1.7.0",
- "pin-project-lite",
- "tokio",
-]
 [[package]]
 name = "iana-time-zone"
 version = "0.1.64"
@@ -873,16 +736,6 @@ dependencies = [
  "autocfg",
 ]
-[[package]]
-name = "num_cpus"
-version = "1.17.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b"
-dependencies = [
- "hermit-abi",
- "libc",
-]
 [[package]]
 name = "once_cell"
 version = "1.21.3"
@@ -1044,15 +897,15 @@ version = "0.11.27"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62"
 dependencies = [
- "base64 0.21.7",
+ "base64",
  "bytes",
  "encoding_rs",
  "futures-core",
  "futures-util",
- "h2 0.3.27",
- "http 0.2.12",
- "http-body 0.4.6",
- "hyper 0.14.32",
+ "h2",
+ "http",
+ "http-body",
+ "hyper",
  "hyper-rustls",
  "ipnet",
  "js-sys",
@@ -1126,7 +979,7 @@ version = "1.0.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c"
 dependencies = [
- "base64 0.21.7",
+ "base64",
 ]
 [[package]]
@@ -2104,29 +1957,6 @@ dependencies = [
  "windows-sys 0.48.0",
 ]
-[[package]]
-name = "wiremock"
-version = "0.6.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "08db1edfb05d9b3c1542e521aea074442088292f00b5f28e435c714a98f85031"
-dependencies = [
- "assert-json-diff",
- "base64 0.22.1",
- "deadpool",
- "futures",
- "http 1.3.1",
- "http-body-util",
- "hyper 1.7.0",
- "hyper-util",
- "log",
- "once_cell",
- "regex",
- "serde",
- "serde_json",
- "tokio",
- "url",
-]
 [[package]]
 name = "wit-bindgen"
 version = "0.46.0"
@@ -2164,7 +1994,6 @@ dependencies = [
  "tracing",
  "tracing-subscriber",
  "uuid",
- "wiremock",
  "worker-macros",
  "worker-types",
  "xdg",

======================================================================

@@ -23,9 +23,9 @@ use worker::{LlmProvider, SystemPromptContext, PromptError, Worker};
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let system_prompt = |ctx: &SystemPromptContext, _messages: &[worker_types::Message]| {
         Ok(format!(
-            "You are assisting with model {} from provider {}.",
+            "You are assisting with model {} from provider {:?}.",
             ctx.model.model_name,
-            ctx.model.provider_id
+            ctx.model.provider
         ))
     };

======================================================================

@@ -501,7 +501,7 @@ pub async fn configurable_hook(mut context: HookContext) -> HookResult {
     }
     // Load options from the config file
-    let config_path = format!("{}/hook_config.json", context.workspace_path);
+    let config_path = format!("{}/.nia/hook_config.json", context.workspace_path);
     if let Ok(config_content) = tokio::fs::read_to_string(&config_path).await {
         if let Ok(config) = serde_json::from_str::<HookConfig>(&config_content) {
             // Processing based on the configuration
@@ -715,3 +715,4 @@ A: Returning `HookResult::Error` aborts that turn; to continue…
 - [worker-macro.md](worker-macro.md) - Macro system
 - `worker/src/lib.rs` - Hook implementation code
 - `worker-types/src/lib.rs` - Hook type definitions
+- `nia-cli/src/tui/hooks/` - Example Hook implementations for the TUI
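For orientation, a complete `PostToolUse` hook in the shape used by the integration test removed at the bottom of this diff looks roughly like the sketch below; the struct name, matcher, and message text are illustrative only.

```rust
use async_trait::async_trait;
use worker::{HookContext, HookResult, WorkerHook};
use worker_types::Role;

// Illustrative hook: fires after a tool named "sample_tool" runs and
// appends an assistant-side note built from the hook context.
struct LoggingHook;

#[async_trait]
impl WorkerHook for LoggingHook {
    fn name(&self) -> &str {
        "logging_hook"
    }
    fn hook_type(&self) -> &str {
        "PostToolUse"
    }
    fn matcher(&self) -> &str {
        "sample_tool"
    }
    async fn execute(&self, context: HookContext) -> (HookContext, HookResult) {
        let note = format!("observed tool output: {}", context.content);
        (context, HookResult::AddMessage(note, Role::Assistant))
    }
}
```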

======================================================================

@@ -6,7 +6,7 @@ v0.3.0 moves responsibility for resolving prompt resources entirely to the call…
 ## Breaking Changes
-- Removed `ConfigParser::resolve_path`; prefix resolution for `#user/`, `#workspace/`, and similar is now delegated to a `ResourceLoader` implemented by the library user.
+- Removed `ConfigParser::resolve_path`; prefix resolution for `#nia/`, `#workspace/`, and similar is now delegated to a `ResourceLoader` implemented by the library user.
 - `WorkerBuilder::build()` now returns an error when `resource_loader(...)` has not been set. Always provide a loader before constructing a worker.
 ## New Features / Spec Changes
@@ -19,6 +19,7 @@ v0.3.0 moves responsibility for resolving prompt resources entirely to the call…
 ## Bug Fixes
 - Fixed an issue where the `include_file` helper accessed files without going through the custom loader.
+- Removed the behaviour where `ConfigParser` tried to statically resolve the nonexistent `#nia/` prefix, preventing wrong paths from silently getting through.
 ## Migration Guide
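A hypothetical sketch only: the actual `ResourceLoader` trait is not shown anywhere in this diff, so the type name, method signature, and error type below are assumptions made for illustration. What it illustrates is the v0.3.0 contract: the library user, not the parser, decides what a `#workspace/`-style reference resolves to.

```rust
use std::path::PathBuf;

// Assumed shape of a caller-provided loader; the real trait may differ.
struct FsResourceLoader {
    workspace_root: PathBuf,
}

impl FsResourceLoader {
    // Assumed signature: map a prefixed reference to its text content.
    fn load(&self, reference: &str) -> Result<String, String> {
        if let Some(rest) = reference.strip_prefix("#workspace/") {
            std::fs::read_to_string(self.workspace_root.join(rest))
                .map_err(|e| format!("failed to load {}: {}", reference, e))
        } else {
            Err(format!("unsupported prefix in {}", reference))
        }
    }
}
```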

======================================================================

@@ -16,7 +16,7 @@ v0.4.0: Worker no longer handles `Role` or YAML configuration; system prom…
 ## Bug Fixes
-- Removed the legacy project-specific configuration code from Worker, shrinking environment-dependent side effects.
+- Removed the NIA-specific configuration code from Worker, shrinking environment-dependent side effects.
 ## Migration Guide

======================================================================

@@ -2,7 +2,7 @@
 **Release Date**: 2025-10-25
-v0.5.0 introduces the Worker Blueprint API and removes the old type-state builder. Configuration now lives on the blueprint, while instantiated workers keep only the materialised system prompt, model metadata, and runtime state.
+v0.5.0 introduces the Worker Blueprint API and removes the old type-state builder. Configuration now lives on the blueprint, while instantiated workers keep only the materialised system prompt and runtime state.
 ## Breaking Changes
@@ -12,10 +12,9 @@ v0.5.0 introduces the Worker Blueprint API and removes the old type-state builde…
 ## New Features / Behaviour
-- `WorkerBlueprint` stores provider/model/api keys, tools, hooks, optional precomputed system prompt messages, and optional model feature flags. `instantiate()` evaluates the prompt (if not already cached) and hands the final string to the `Worker`.
-- Instantiated workers retain the composed system prompt, the original generator closure, and a `Model` struct describing provider/model/features; the generator only runs again if a new session requires it.
+- `WorkerBlueprint` stores provider/model/api keys, tools, hooks, and optional precomputed system prompt strings. `instantiate()` evaluates the prompt (if not already cached) and hands the final string to the `Worker`.
+- Instantiated workers retain only the composed system prompt string; the generator function lives solely on the blueprint and is dropped after instantiation.
 - System prompts are no longer recomputed per turn. Tool metadata is appended dynamically as plain text when native tool support is unavailable.
-- Worker now exposes a `Model` struct (`provider`, `name`, `features`) in place of the previous loose strings and `supports_native_tools` helper. Capability heuristics remain for built-in providers but applications can override them via `WorkerBlueprint::model_features`.
 ## Migration Guide
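A minimal sketch of the blueprint flow, mirroring the builder calls used in the integration test removed at the bottom of this diff; the model name and API key are placeholders.

```rust
use worker::{LlmProvider, PromptError, Worker};

fn main() {
    // Configure everything on the blueprint.
    let mut blueprint = Worker::blueprint();
    blueprint
        .provider(LlmProvider::OpenAI)
        .model("gpt-4o-mini")
        .api_key(LlmProvider::OpenAI.as_str(), "test-key")
        .system_prompt_fn(|_, _| Ok::<_, PromptError>("You are a helpful worker.".into()));
    // instantiate() evaluates the prompt (unless precomputed) and builds the Worker.
    let _worker = blueprint.instantiate().expect("worker to instantiate");
}
```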

======================================================================

@@ -2,11 +2,11 @@
   "nodes": {
     "flake-compat": {
       "locked": {
-        "lastModified": 1761588595,
-        "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=",
+        "lastModified": 1747046372,
+        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
         "owner": "edolstra",
         "repo": "flake-compat",
-        "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5",
+        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
         "type": "github"
       },
       "original": {
@@ -35,11 +35,11 @@
     },
     "nixpkgs": {
       "locked": {
-        "lastModified": 1761373498,
-        "narHash": "sha256-Q/uhWNvd7V7k1H1ZPMy/vkx3F8C13ZcdrKjO7Jv7v0c=",
+        "lastModified": 1751011381,
+        "narHash": "sha256-krGXKxvkBhnrSC/kGBmg5MyupUUT5R6IBCLEzx9jhMM=",
         "owner": "nixos",
         "repo": "nixpkgs",
-        "rev": "6a08e6bb4e46ff7fcbb53d409b253f6bad8a28ce",
+        "rev": "30e2e2857ba47844aa71991daa6ed1fc678bcbb7",
         "type": "github"
       },
       "original": {

======================================================================

@@ -45,4 +45,3 @@ dynamic-loading = ["libloading"]
 [dev-dependencies]
 tempfile = "3.10.1"
 tracing-subscriber = "0.3"
-wiremock = "0.6"

======================================================================

@@ -10,8 +10,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
         _messages: &[Message],
     ) -> Result<String, PromptError> {
         Ok(format!(
-            "You are helping with requests for model {} (provider {}).",
-            ctx.model.model_name, ctx.model.provider_id
+            "You are helping with requests for model {} (provider {:?}).",
+            ctx.model.model_name, ctx.model.provider
         ))
     }

======================================================================

@@ -1,7 +1,8 @@
+use crate::LlmProviderExt;
+use crate::Worker;
 use crate::plugin;
 use crate::prompt::{PromptError, SystemPromptContext, SystemPromptFn};
 use crate::types::{HookManager, Tool, WorkerError, WorkerHook};
-use crate::{LlmProviderExt, Model, ModelFeatures, ModelProvider, Worker};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 use worker_types::{LlmProvider, Message, Role};
@@ -23,7 +24,6 @@ pub struct WorkerBlueprint {
     pub(crate) tools: Vec<Box<dyn Tool>>,
     pub(crate) hooks: Vec<Box<dyn WorkerHook>>,
     pub(crate) prompt_cache: Option<Vec<Message>>,
-    pub(crate) model_features: Option<ModelFeatures>,
 }
 impl WorkerBlueprint {
@@ -36,7 +36,6 @@ impl WorkerBlueprint {
             tools: Vec::new(),
             hooks: Vec::new(),
             prompt_cache: None,
-            model_features: None,
         }
     }
@@ -62,11 +61,6 @@ impl WorkerBlueprint {
         self
     }
-    pub fn model_features(&mut self, features: ModelFeatures) -> &mut Self {
-        self.model_features = Some(features);
-        self
-    }
     pub fn api_key(&mut self, provider: impl Into<String>, key: impl Into<String>) -> &mut Self {
         self.api_keys.insert(provider.into(), key.into());
         self
@@ -124,27 +118,9 @@ impl WorkerBlueprint {
             .map(|tool| tool.name().to_string())
             .collect();
-        let features = self
-            .model_features
-            .clone()
-            .unwrap_or_else(|| match provider {
-                ProviderConfig::BuiltIn(p) => Worker::infer_model_features(Some(*p), &model_name),
-                ProviderConfig::Plugin { .. } => ModelFeatures::default(),
-            });
-        let preview_model = Model {
-            provider: match provider {
-                ProviderConfig::BuiltIn(p) => ModelProvider::BuiltIn(*p),
-                ProviderConfig::Plugin { id, .. } => ModelProvider::Plugin(id.clone()),
-            },
-            name: model_name.clone(),
-            features: features.clone(),
-        };
-        let context = Worker::create_system_prompt_context(&preview_model, &tool_names);
+        let context = self.build_system_prompt_context(provider, &model_name, &tool_names);
         let prompt = generator(&context, &[]).map_err(|e| WorkerError::config(e.to_string()))?;
         self.prompt_cache = Some(vec![Message::new(Role::System, prompt)]);
-        self.model_features = Some(features);
         Ok(self)
     }
@@ -162,49 +138,30 @@ impl WorkerBlueprint {
             .take()
             .ok_or_else(|| WorkerError::config("System prompt generator is not configured"))?;
-        let mut prompt_cache = self.prompt_cache.take();
-        let mut provided_features = self.model_features.take();
         let tools = std::mem::take(&mut self.tools);
         let hooks = std::mem::take(&mut self.hooks);
         let mut api_keys = self.api_keys;
         let tool_names: Vec<String> = tools.iter().map(|tool| tool.name().to_string()).collect();
+        let provider_hint = provider_config.provider_hint();
+        let prompt_context =
+            Worker::create_system_prompt_context(provider_hint, &model_name, &tool_names);
-        match provider_config {
-            ProviderConfig::BuiltIn(provider) => {
-                let features = provided_features
-                    .take()
-                    .unwrap_or_else(|| Worker::infer_model_features(Some(provider), &model_name));
-                let model = Model {
-                    provider: ModelProvider::BuiltIn(provider),
-                    name: model_name.clone(),
-                    features,
-                };
-                let prompt_context = Worker::create_system_prompt_context(&model, &tool_names);
-                let base_messages = if let Some(messages) = prompt_cache.take() {
-                    if messages.is_empty() {
-                        vec![Message::new(
-                            Role::System,
-                            system_prompt_fn(&prompt_context, &[])
-                                .map_err(|e| WorkerError::config(e.to_string()))?,
-                        )]
-                    } else {
-                        messages
-                    }
-                } else {
-                    vec![Message::new(
-                        Role::System,
-                        system_prompt_fn(&prompt_context, &[])
-                            .map_err(|e| WorkerError::config(e.to_string()))?,
-                    )]
-                };
+        let base_messages = match self.prompt_cache.take() {
+            Some(messages) if !messages.is_empty() => messages,
+            _ => {
+                let prompt = system_prompt_fn(&prompt_context, &[])
+                    .map_err(|e| WorkerError::config(e.to_string()))?;
+                vec![Message::new(Role::System, prompt)]
+            }
+        };
         let base_system_prompt = base_messages
             .first()
             .map(|msg| msg.content.clone())
-            .unwrap_or_default();
+            .unwrap_or_else(|| String::new());
+        match provider_config {
+            ProviderConfig::BuiltIn(provider) => {
                 let api_key = api_keys
                     .entry(provider.as_str().to_string())
                     .or_insert_with(String::new)
@@ -213,12 +170,13 @@ impl WorkerBlueprint {
                 let llm_client = provider.create_client(&model_name, &api_key)?;
                 let mut worker = Worker {
                     llm_client: Box::new(llm_client),
-                    system_prompt: base_system_prompt,
+                    system_prompt: base_system_prompt.clone(),
                     system_prompt_fn: Arc::clone(&system_prompt_fn),
                     tools,
                     api_key,
-                    model,
-                    message_history: base_messages,
+                    provider_str: provider.as_str().to_string(),
+                    model_name,
+                    message_history: base_messages.clone(),
                     hook_manager: HookManager::new(),
                     mcp_lazy_configs: Vec::new(),
                     plugin_registry: Arc::new(Mutex::new(plugin::PluginRegistry::new())),
@@ -227,40 +185,6 @@ impl WorkerBlueprint {
                 Ok(worker)
             }
             ProviderConfig::Plugin { id, registry } => {
-                let features = provided_features
-                    .take()
-                    .unwrap_or_else(ModelFeatures::default);
-                let model = Model {
-                    provider: ModelProvider::Plugin(id.clone()),
-                    name: model_name.clone(),
-                    features,
-                };
-                let prompt_context = Worker::create_system_prompt_context(&model, &tool_names);
-                let base_messages = if let Some(messages) = prompt_cache.take() {
-                    if messages.is_empty() {
-                        vec![Message::new(
-                            Role::System,
-                            system_prompt_fn(&prompt_context, &[])
-                                .map_err(|e| WorkerError::config(e.to_string()))?,
-                        )]
-                    } else {
-                        messages
-                    }
-                } else {
-                    vec![Message::new(
-                        Role::System,
-                        system_prompt_fn(&prompt_context, &[])
-                            .map_err(|e| WorkerError::config(e.to_string()))?,
-                    )]
-                };
-                let base_system_prompt = base_messages
-                    .first()
-                    .map(|msg| msg.content.clone())
-                    .unwrap_or_default();
                 let api_key = api_keys
                     .remove("__plugin__")
                     .or_else(|| api_keys.values().next().cloned())
@@ -284,7 +208,8 @@ impl WorkerBlueprint {
                     system_prompt_fn,
                     tools,
                     api_key,
-                    model,
+                    provider_str: id,
+                    model_name,
                     message_history: base_messages,
                     hook_manager: HookManager::new(),
                     mcp_lazy_configs: Vec::new(),
@@ -295,6 +220,15 @@ impl WorkerBlueprint {
             }
         }
     }
+    fn build_system_prompt_context(
+        &self,
+        provider: &ProviderConfig,
+        model_name: &str,
+        tool_names: &[String],
+    ) -> SystemPromptContext {
+        Worker::create_system_prompt_context(provider.provider_hint(), model_name, tool_names)
+    }
 }
 impl ProviderConfig {
@@ -311,4 +245,11 @@ impl ProviderConfig {
             ProviderConfig::Plugin { registry, .. } => Some(Arc::clone(registry)),
         }
     }
+    fn provider_hint(&self) -> LlmProvider {
+        match self {
+            ProviderConfig::BuiltIn(provider) => *provider,
+            ProviderConfig::Plugin { .. } => LlmProvider::OpenAI,
+        }
+    }
 }

======================================================================

@@ -8,6 +8,8 @@ use llm::{
 };
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
+use std::fs;
+use std::path::PathBuf;
 use std::sync::Arc;
 use tracing;
 use uuid;
@@ -322,67 +324,127 @@ pub async fn validate_api_key(
     }
 }
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct ModelsConfig {
+    pub models: Vec<ModelDefinition>,
+}
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct ModelDefinition {
+    pub model: String,
+    pub name: String,
+    pub meta: ModelMeta,
+}
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct ModelMeta {
+    pub tool_support: bool,
+    pub function_calling: bool,
+    pub vision: bool,
+    pub multimodal: bool,
+    pub context_length: Option<u32>,
+    pub description: Option<String>,
+}
+fn get_models_config_path() -> Result<PathBuf, WorkerError> {
+    let home_dir = dirs::home_dir()
+        .ok_or_else(|| WorkerError::config("Could not determine home directory"))?;
+    Ok(home_dir.join(".config").join("nia").join("models.yaml"))
+}
+fn load_models_config() -> Result<ModelsConfig, WorkerError> {
+    let config_path = get_models_config_path()?;
+    if !config_path.exists() {
+        tracing::warn!(
+            "Models config file not found at {:?}, using defaults",
+            config_path
+        );
+        return Ok(ModelsConfig { models: vec![] });
+    }
+    let content = fs::read_to_string(&config_path)
+        .map_err(|e| WorkerError::config(format!("Failed to read models config: {}", e)))?;
+    let config: ModelsConfig = serde_yaml::from_str(&content)
+        .map_err(|e| WorkerError::config(format!("Failed to parse models config: {}", e)))?;
+    Ok(config)
+}
+pub async fn supports_native_tools(
+    provider: &LlmProvider,
+    model_name: &str,
+    _api_key: &str,
+) -> Result<bool, WorkerError> {
+    let config = load_models_config()?;
+    let model_id = format!(
+        "{}/{}",
+        match provider {
+            LlmProvider::Claude => "anthropic",
+            LlmProvider::OpenAI => "openai",
+            LlmProvider::Gemini => "gemini",
+            LlmProvider::Ollama => "ollama",
+            LlmProvider::XAI => "xai",
+        },
+        model_name
+    );
+    for model_def in &config.models {
+        if model_def.model == model_id || model_def.model.contains(model_name) {
+            tracing::debug!(
+                "Found model config: model={}, function_calling={}",
+                model_def.model,
+                model_def.meta.function_calling
+            );
+            return Ok(model_def.meta.function_calling);
+        }
+    }
+    tracing::warn!(
+        "Model not found in config: {} ({}), using provider defaults",
+        model_id,
+        model_name
+    );
+    tracing::warn!(
+        "Using provider-based fallback - this should be configured in models.yaml: provider={:?}, model={}",
+        provider,
+        model_name
+    );
+    let supports_tools = match provider {
+        LlmProvider::Claude => true,
+        LlmProvider::OpenAI => !model_name.contains("gpt-3.5-turbo-instruct"),
+        LlmProvider::Gemini => !model_name.contains("gemini-pro-vision"),
+        LlmProvider::Ollama => false,
+        LlmProvider::XAI => true,
+    };
+    tracing::debug!(
+        "Fallback tool support check: provider={:?}, model={}, supports_tools={}",
+        provider,
+        model_name,
+        supports_tools
+    );
+    Ok(supports_tools)
+}
 pub struct Worker {
     pub(crate) llm_client: Box<dyn LlmClientTrait>,
     pub(crate) system_prompt: String,
     pub(crate) system_prompt_fn: Arc<SystemPromptFn>,
     pub(crate) tools: Vec<Box<dyn Tool>>,
     pub(crate) api_key: String,
-    pub(crate) model: Model,
+    pub(crate) provider_str: String,
+    pub(crate) model_name: String,
     pub(crate) message_history: Vec<Message>,
     pub(crate) hook_manager: crate::types::HookManager,
     pub(crate) mcp_lazy_configs: Vec<McpServerConfig>,
     pub(crate) plugin_registry: std::sync::Arc<std::sync::Mutex<plugin::PluginRegistry>>,
 }
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub enum ModelProvider {
-    BuiltIn(LlmProvider),
-    Plugin(String),
-}
-impl ModelProvider {
-    pub fn identifier(&self) -> String {
-        match self {
-            ModelProvider::BuiltIn(provider) => provider.as_str().to_string(),
-            ModelProvider::Plugin(id) => id.clone(),
-        }
-    }
-    pub fn as_llm_provider(&self) -> Option<LlmProvider> {
-        match self {
-            ModelProvider::BuiltIn(provider) => Some(*provider),
-            ModelProvider::Plugin(_) => None,
-        }
-    }
-}
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct Model {
-    pub provider: ModelProvider,
-    pub name: String,
-    pub features: ModelFeatures,
-}
-#[derive(Debug, Clone, Serialize, Deserialize, Default)]
-pub struct ModelFeatures {
-    pub supports_tools: bool,
-    pub supports_function_calling: bool,
-    pub supports_vision: bool,
-    pub supports_multimodal: bool,
-    pub context_length: Option<u64>,
-}
-impl Model {
-    pub fn provider_id(&self) -> String {
-        self.provider.identifier()
-    }
-    pub fn built_in_provider(&self) -> Option<LlmProvider> {
-        self.provider.as_llm_provider()
-    }
-}
 impl Worker {
     /// Create a new Worker blueprint
     ///
@@ -442,20 +504,21 @@ impl Worker {
             system_prompt_fn,
             tools,
             api_key,
-            model,
+            provider_str,
+            model_name,
             message_history,
             hook_manager,
             mcp_lazy_configs: _,
             plugin_registry,
         } = self;
-        let provider = match &model.provider {
-            ModelProvider::BuiltIn(p) => {
+        let provider = match LlmProvider::from_str(&provider_str) {
+            Some(p) => {
                 drop(plugin_registry);
-                ProviderConfig::BuiltIn(*p)
+                ProviderConfig::BuiltIn(p)
             }
-            ModelProvider::Plugin(id) => ProviderConfig::Plugin {
-                id: id.clone(),
+            None => ProviderConfig::Plugin {
+                id: provider_str.clone(),
                 registry: plugin_registry,
             },
         };
@@ -472,13 +535,12 @@ impl Worker {
         WorkerBlueprint {
             provider: Some(provider),
-            model_name: Some(model.name),
+            model_name: Some(model_name),
            api_keys,
             system_prompt_fn: Some(system_prompt_fn),
             tools,
             hooks: hook_manager.into_hooks(),
             prompt_cache: Some(message_history),
-            model_features: Some(model.features),
         }
     }
@@ -753,19 +815,27 @@
     /// Create the static prompt context (used at construction time)
     pub(crate) fn create_system_prompt_context(
-        model: &Model,
+        provider: LlmProvider,
+        model_name: &str,
         tools: &[String],
     ) -> crate::prompt::SystemPromptContext {
+        let supports_native_tools = match provider {
+            LlmProvider::Claude => true,
+            LlmProvider::OpenAI => !model_name.contains("gpt-3.5-turbo-instruct"),
+            LlmProvider::Gemini => !model_name.contains("gemini-pro-vision"),
+            LlmProvider::Ollama => model_name.contains("llama") || model_name.contains("mistral"),
+            LlmProvider::XAI => true,
+        };
         let model_context = crate::prompt::ModelContext {
-            provider: model.built_in_provider(),
-            provider_id: model.provider_id(),
-            model_name: model.name.clone(),
+            provider,
+            model_name: model_name.to_string(),
             capabilities: crate::prompt::ModelCapabilities {
-                supports_tools: model.features.supports_tools,
-                supports_function_calling: model.features.supports_function_calling,
-                supports_vision: model.features.supports_vision,
-                supports_multimodal: Some(model.features.supports_multimodal),
-                context_length: model.features.context_length,
+                supports_tools: supports_native_tools,
+                supports_function_calling: supports_native_tools,
+                supports_vision: false,
+                supports_multimodal: Some(false),
+                context_length: None,
                 capabilities: vec![],
                 needs_verification: Some(false),
             },
@@ -780,38 +850,23 @@
         }
     }
-    fn infer_model_features(provider: Option<LlmProvider>, model_name: &str) -> ModelFeatures {
-        let mut features = ModelFeatures::default();
-        if let Some(provider) = provider {
-            let supports_tools = match provider {
-                LlmProvider::Claude => true,
-                LlmProvider::OpenAI => !model_name.contains("gpt-3.5-turbo-instruct"),
-                LlmProvider::Gemini => !model_name.contains("gemini-pro-vision"),
-                LlmProvider::Ollama => {
-                    model_name.contains("llama") || model_name.contains("mistral")
-                }
-                LlmProvider::XAI => true,
-            };
-            features.supports_tools = supports_tools;
-            features.supports_function_calling = supports_tools;
-        }
-        features
-    }
     /// Change the model
-    pub fn change_model(&mut self, model: Model, api_key: &str) -> Result<(), WorkerError> {
-        if let Some(provider) = model.built_in_provider() {
-            let new_client = provider.create_client(&model.name, api_key)?;
-            self.llm_client = Box::new(new_client);
-        }
-        self.model = model;
+    pub fn change_model(
+        &mut self,
+        provider: LlmProvider,
+        model_name: &str,
+        api_key: &str,
+    ) -> Result<(), WorkerError> {
+        // Create a new LLM client
+        let new_client = provider.create_client(model_name, api_key)?;
+        // Replace the old client with the new one
+        self.llm_client = Box::new(new_client);
+        self.provider_str = provider.as_str().to_string();
+        self.model_name = model_name.to_string();
         self.api_key = api_key.to_string();
-        tracing::info!("Model changed to {}", self.model.provider_id());
+        tracing::info!("Model changed to {}/{}", provider.as_str(), model_name);
         Ok(())
     }
@@ -876,18 +931,16 @@
     /// Get the model name for tool support detection
     pub fn get_model_name(&self) -> String {
-        self.model.name.clone()
+        self.llm_client.get_model_name()
     }
     pub fn get_provider_name(&self) -> String {
-        self.model.provider_id()
+        self.llm_client.provider().to_string()
     }
     /// Get configuration information for task delegation
-    pub fn get_config(&self) -> Option<(LlmProvider, &str, &str)> {
-        self.model
-            .built_in_provider()
-            .map(|provider| (provider, self.model.name.as_str(), self.api_key.as_str()))
+    pub fn get_config(&self) -> (LlmProvider, &str, &str) {
+        (self.llm_client.provider(), &self.model_name, &self.api_key)
     }
     /// Get tool names (used to filter out specific tools)
@@ -983,37 +1036,25 @@
         };
         // Create a temporary worker for processing without holding the lock
-        let (llm_client, system_prompt, tool_definitions, model) = {
+        let (llm_client, system_prompt, tool_definitions, api_key, model_name) = {
             let w_locked = worker.lock().await;
-            let provider = match w_locked.model.built_in_provider() {
-                Some(provider) => provider,
-                None => {
-                    yield Err(WorkerError::config(
-                        "Delegated processing is not supported for plugin providers",
-                    ));
-                    return;
-                }
-            };
-            match provider.create_client(&w_locked.model.name, &w_locked.api_key) {
+            let llm_client = w_locked.llm_client.provider().create_client(&w_locked.model_name, &w_locked.api_key);
+            match llm_client {
                 Ok(client) => {
-                    let tool_defs = w_locked
-                        .tools
-                        .iter()
-                        .map(|tool| crate::types::DynamicToolDefinition {
+                    let tool_defs = w_locked.tools.iter().map(|tool| crate::types::DynamicToolDefinition {
                         name: tool.name().to_string(),
                         description: tool.description().to_string(),
                         parameters_schema: tool.parameters_schema(),
-                        })
-                        .collect::<Vec<_>>();
+                    }).collect::<Vec<_>>();
                     (
                         client,
                         w_locked.system_prompt.clone(),
                         tool_defs,
-                        w_locked.model.clone(),
+                        w_locked.api_key.clone(),
+                        w_locked.model_name.clone()
                     )
-                }
+                },
                 Err(e) => {
                     yield Err(e);
                     return;
@@ -1026,7 +1067,13 @@
         loop {
             let provider = llm_client.provider();
-            let supports_native = model.features.supports_tools;
+            let supports_native = match supports_native_tools(&provider, &model_name, &api_key).await {
+                Ok(supports) => supports,
+                Err(e) => {
+                    tracing::warn!("Failed to check native tool support: {}", e);
+                    false
+                }
+            };
             let (composed_messages, tools_for_llm) = if supports_native {
                 let messages =
@@ -1207,21 +1254,16 @@
             let tools = self.get_tools();
             let provider = self.llm_client.provider();
             let model_name = self.get_model_name();
-            tracing::debug!(
-                "Checking native tool support: provider={:?}, model_name={}, api_key_len={}, provider_id={}",
-                provider,
-                model_name,
-                self.api_key.len(),
-                self.model.provider_id()
-            );
-            let supports_native = self.model.features.supports_tools;
-            tracing::info!(
-                "Model {} supports native tools: {}",
-                model_name,
-                supports_native
-            );
+            tracing::debug!("Checking native tool support: provider={:?}, model_name={}, api_key_len={}, provider_str={}", provider, model_name, self.api_key.len(), self.provider_str);
+            let supports_native = match supports_native_tools(&provider, &model_name, &self.api_key).await {
+                Ok(supports) => supports,
+                Err(e) => {
+                    tracing::warn!("Failed to check native tool support: {}", e);
+                    false
+                }
+            };
+            tracing::info!("Model {} supports native tools: {}", model_name, supports_native);
             let (composed_messages, tools_for_llm) = if supports_native {
                 // Native tools - basic composition
@@ -1429,16 +1471,15 @@
         loop {
             let tools = self.get_tools();
             let provider = self.llm_client.provider();
-            let model_name = self.model.name.clone();
-            tracing::debug!(
-                "Checking native tool support: provider={:?}, model_name={}, api_key_len={}, provider_id={}",
-                provider,
-                model_name,
-                self.api_key.len(),
-                self.model.provider_id()
-            );
-            let supports_native = self.model.features.supports_tools;
+            let model_name = self.get_model_name();
+            tracing::debug!("Checking native tool support: provider={:?}, model_name={}, api_key_len={}, provider_str={}", provider, model_name, self.api_key.len(), self.provider_str);
+            let supports_native = match supports_native_tools(&provider, &model_name, &self.api_key).await {
+                Ok(supports) => supports,
+                Err(e) => {
+                    tracing::warn!("Failed to check native tool support: {}", e);
+                    false
+                }
+            };
             tracing::info!("Model {} supports native tools: {}", model_name, supports_native);
@@ -1632,8 +1673,8 @@
         let session_id = uuid::Uuid::new_v4().to_string();
         let mut session_data = SessionData::new(
             session_id,
-            self.model.provider_id(),
-            self.model.name.clone(),
+            self.provider_str.clone(),
+            self.model_name.clone(),
             workspace_path,
         );
         session_data.git_branch = git_branch;
@@ -1647,15 +1688,15 @@
     /// Restore history from session data
     pub fn load_session(&mut self, session_data: &SessionData) -> Result<(), WorkerError> {
         // Log a warning if the model differs
-        if session_data.model_provider != self.model.provider_id()
-            || session_data.model_name != self.model.name
+        if session_data.model_provider != self.provider_str
+            || session_data.model_name != self.model_name
         {
             tracing::warn!(
                 "Loading session with different model: session={}:{}, current={}:{}",
                 session_data.model_provider,
                 session_data.model_name,
-                self.model.provider_id(),
-                self.model.name
+                self.provider_str,
+                self.model_name
             );
         }
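Since `supports_native_tools` above now consults `~/.config/nia/models.yaml`, here is a sketch of what a minimal entry could look like, round-tripped through the `ModelsConfig` structs added in this diff. The concrete model id and numbers are illustrative assumptions; field names follow the serde defaults on those structs.

```rust
// Assumes the ModelsConfig / ModelDefinition / ModelMeta structs from this
// diff are in scope.
fn main() -> Result<(), serde_yaml::Error> {
    let yaml = r#"
models:
  - model: "anthropic/claude-3-opus-20240229"
    name: "Claude 3 Opus"
    meta:
      tool_support: true
      function_calling: true
      vision: true
      multimodal: true
      context_length: 200000
      description: "Example entry"
"#;
    // Parse the sample config and check the flag that gates native tool use.
    let config: ModelsConfig = serde_yaml::from_str(yaml)?;
    assert!(config.models[0].meta.function_calling);
    Ok(())
}
```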

======================================================================

@@ -2,7 +2,8 @@ use super::tool::McpServerConfig;
 use crate::types::WorkerError;
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
-use tracing::{debug, warn};
+use std::path::Path;
+use tracing::{debug, info, warn};
 /// Structure of the MCP configuration file
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -56,10 +57,121 @@ fn default_integration_mode() -> IntegrationMode {
 }
 impl McpConfig {
+    /// Load the config file
+    pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self, WorkerError> {
+        let path = path.as_ref();
+        if !path.exists() {
+            debug!(
+                "MCP config file not found at {:?}, returning empty config",
+                path
+            );
+            return Ok(Self::default());
+        }
+        info!("Loading MCP config from: {:?}", path);
+        let content = std::fs::read_to_string(path).map_err(|e| {
+            WorkerError::config(format!("Failed to read MCP config file {:?}: {}", path, e))
+        })?;
+        let config: McpConfig = serde_yaml::from_str(&content).map_err(|e| {
+            WorkerError::config(format!("Failed to parse MCP config file {:?}: {}", path, e))
+        })?;
+        info!("Loaded {} MCP server configurations", config.servers.len());
+        Ok(config)
+    }
+    /// Save to the config file
+    pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<(), WorkerError> {
+        let path = path.as_ref();
+        // Create the directory if it does not exist
+        if let Some(parent) = path.parent() {
+            std::fs::create_dir_all(parent).map_err(|e| {
+                WorkerError::config(format!(
+                    "Failed to create config directory {:?}: {}",
+                    parent, e
+                ))
+            })?;
+        }
+        let content = serde_yaml::to_string(self)
+            .map_err(|e| WorkerError::config(format!("Failed to serialize MCP config: {}", e)))?;
+        std::fs::write(path, content).map_err(|e| {
+            WorkerError::config(format!("Failed to write MCP config file {:?}: {}", path, e))
+        })?;
+        info!("Saved MCP config to: {:?}", path);
+        Ok(())
+    }
     /// Get the enabled server configurations
     pub fn get_enabled_servers(&self) -> Vec<(&String, &McpServerDefinition)> {
         self.servers.iter().filter(|(_, def)| def.enabled).collect()
     }
+    /// Generate a default configuration
+    pub fn create_default_config() -> Self {
+        let mut servers = HashMap::new();
+        // Example configuration for the Brave Search MCP server
+        servers.insert(
+            "brave_search".to_string(),
+            McpServerDefinition {
+                command: "npx".to_string(),
+                args: vec![
+                    "-y".to_string(),
+                    "@brave/brave-search-mcp-server".to_string(),
+                ],
+                env: {
+                    let mut env = HashMap::new();
+                    env.insert("BRAVE_API_KEY".to_string(), "${BRAVE_API_KEY}".to_string());
+                    env
+                },
+                description: Some("Brave Search API for web searching".to_string()),
+                enabled: false, // Disabled by default (an API key is required)
+                integration_mode: IntegrationMode::Individual,
+            },
+        );
+        // Example configuration for the filesystem MCP server
+        servers.insert(
+            "filesystem".to_string(),
+            McpServerDefinition {
+                command: "npx".to_string(),
+                args: vec![
+                    "-y".to_string(),
+                    "@modelcontextprotocol/server-filesystem".to_string(),
+                    "/tmp".to_string(),
+                ],
+                env: HashMap::new(),
+                description: Some("Filesystem operations in /tmp directory".to_string()),
+                enabled: false, // Disabled by default
+                integration_mode: IntegrationMode::Individual,
+            },
+        );
+        // Example configuration for the Git MCP server
+        servers.insert(
+            "git".to_string(),
+            McpServerDefinition {
+                command: "npx".to_string(),
+                args: vec![
+                    "-y".to_string(),
+                    "@modelcontextprotocol/server-git".to_string(),
+                    ".".to_string(),
+                ],
+                env: HashMap::new(),
+                description: Some("Git operations in current directory".to_string()),
+                enabled: false, // Disabled by default
+                integration_mode: IntegrationMode::Individual,
+            },
+        );
+        Self { servers }
+    }
 }
 impl Default for McpConfig {
@@ -126,3 +238,120 @@ fn expand_environment_variables(input: &str) -> Result<String, WorkerError> {
     Ok(result)
 }
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::fs;
+    use tempfile::tempdir;
+    #[test]
+    fn test_default_config_creation() {
+        let config = McpConfig::create_default_config();
+        assert!(!config.servers.is_empty());
+        assert!(config.servers.contains_key("brave_search"));
+        assert!(config.servers.contains_key("filesystem"));
+    }
+    #[test]
+    fn test_config_serialization() {
+        let config = McpConfig::create_default_config();
+        let yaml = serde_yaml::to_string(&config).unwrap();
+        // Verify it serializes correctly as YAML
+        assert!(yaml.contains("servers:"));
+        assert!(yaml.contains("brave_search:"));
+        assert!(yaml.contains("command:"));
+    }
+    #[test]
+    fn test_config_deserialization() {
+        let yaml_content = r#"
+servers:
+  test_server:
+    command: "python3"
+    args: ["test.py"]
+    env:
+      TEST_VAR: "test_value"
+    description: "Test server"
+    enabled: true
+    integration_mode: "individual"
+"#;
+        let config: McpConfig = serde_yaml::from_str(yaml_content).unwrap();
+        assert_eq!(config.servers.len(), 1);
+        let server = config.servers.get("test_server").unwrap();
+        assert_eq!(server.command, "python3");
+        assert_eq!(server.args, vec!["test.py"]);
+        assert_eq!(server.env.get("TEST_VAR").unwrap(), "test_value");
+        assert!(server.enabled);
+    }
+    #[test]
+    fn test_environment_variable_expansion() {
+        // SAFETY: Setting test environment variables in a single-threaded test context
+        unsafe {
+            std::env::set_var("TEST_VAR", "test_value");
+        }
+        let result = expand_environment_variables("prefix_${TEST_VAR}_suffix").unwrap();
+        assert_eq!(result, "prefix_test_value_suffix");
+        // Nonexistent environment variables are left as-is
+        let result = expand_environment_variables("${NON_EXISTENT_VAR}").unwrap();
+        assert_eq!(result, "${NON_EXISTENT_VAR}");
+    }
+    #[test]
+    fn test_config_file_operations() {
+        let dir = tempdir().unwrap();
+        let config_path = dir.path().join("mcp.yaml");
+        // Create and save a config
+        let config = McpConfig::create_default_config();
+        config.save_to_file(&config_path).unwrap();
+        // Verify the file was created
+        assert!(config_path.exists());
+        // Load the config back
+        let loaded_config = McpConfig::load_from_file(&config_path).unwrap();
+        assert_eq!(config.servers.len(), loaded_config.servers.len());
+    }
+    #[test]
+    fn test_enabled_servers_filter() {
+        let mut config = McpConfig::default();
+        // Add an enabled server
+        config.servers.insert(
+            "enabled_server".to_string(),
+            McpServerDefinition {
+                command: "test".to_string(),
+                args: vec![],
+                env: HashMap::new(),
+                description: None,
+                enabled: true,
+                integration_mode: IntegrationMode::Individual,
+            },
+        );
+        // Add a disabled server
+        config.servers.insert(
+            "disabled_server".to_string(),
+            McpServerDefinition {
+                command: "test".to_string(),
+                args: vec![],
+                env: HashMap::new(),
+                description: None,
+                enabled: false,
+                integration_mode: IntegrationMode::Individual,
+            },
+        );
+        let enabled_servers = config.get_enabled_servers();
+        assert_eq!(enabled_servers.len(), 1);
+        assert_eq!(enabled_servers[0].0, "enabled_server");
+    }
+}

======================================================================

@@ -184,7 +184,7 @@ impl McpClient {
             }),
         },
         client_info: ClientInfo {
-            name: "llm-worker-rs".to_string(),
+            name: "nia-worker".to_string(),
             version: "0.1.0".to_string(),
         },
     };
@@ -417,3 +417,33 @@ impl Drop for McpClient {
         }
     }
 }
+#[cfg(test)]
+mod tests {
+    use super::*;
+    #[test]
+    fn test_json_rpc_serialization() {
+        let request = JsonRpcRequest {
+            jsonrpc: "2.0".to_string(),
+            id: Value::Number(serde_json::Number::from(1)),
+            method: "tools/list".to_string(),
+            params: None,
+        };
+        let json = serde_json::to_string(&request).unwrap();
+        assert!(json.contains("\"jsonrpc\":\"2.0\""));
+        assert!(json.contains("\"id\":1"));
+        assert!(json.contains("\"method\":\"tools/list\""));
+    }
+    #[test]
+    fn test_error_response() {
+        let response_json =
+            r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"Method not found"}}"#;
+        let response: JsonRpcResponse = serde_json::from_str(response_json).unwrap();
+        assert!(response.error.is_some());
+        assert_eq!(response.error.unwrap().code, -32601);
+    }
+}

======================================================================

@@ -451,3 +451,36 @@ pub async fn test_mcp_connection(
         }
     }
 }
+#[cfg(test)]
+mod tests {
+    use super::*;
+    #[tokio::test]
+    async fn test_mcp_server_config() {
+        let config =
+            McpServerConfig::new("npx", vec!["-y", "@modelcontextprotocol/server-everything"]);
+        assert_eq!(config.command, "npx");
+        assert_eq!(
+            config.args,
+            vec!["-y", "@modelcontextprotocol/server-everything"]
+        );
+        assert_eq!(
+            config.name,
+            "npx(-y @modelcontextprotocol/server-everything)"
+        );
+    }
+    #[tokio::test]
+    async fn test_mcp_tool_creation() {
+        let config = McpServerConfig::new("echo", vec!["test"]);
+        let tool = McpDynamicTool::new(config);
+        assert_eq!(tool.name(), "mcp_proxy");
+        assert!(!tool.description().is_empty());
+        let schema = tool.parameters_schema();
+        assert!(schema.is_object());
+        assert!(schema.get("properties").is_some());
+    }
+}

======================================================================

@@ -206,3 +206,42 @@ impl LlmClientTrait for CustomLlmClient {
 pub extern "C" fn create_plugin() -> Box<dyn ProviderPlugin> {
     Box::new(CustomProviderPlugin::new())
 }
+#[cfg(test)]
+mod tests {
+    use super::*;
+    #[test]
+    fn test_plugin_metadata() {
+        let plugin = CustomProviderPlugin::new();
+        let metadata = plugin.metadata();
+        assert_eq!(metadata.id, "custom-provider");
+        assert_eq!(metadata.name, "Custom LLM Provider");
+        assert!(metadata.requires_api_key);
+        assert_eq!(metadata.supported_models.len(), 3);
+    }
+    #[test]
+    fn test_api_key_validation() {
+        let plugin = CustomProviderPlugin::new();
+        assert!(plugin.validate_api_key("custom-1234567890abcdefghij"));
+        assert!(!plugin.validate_api_key("invalid-key"));
+        assert!(!plugin.validate_api_key("custom-short"));
+        assert!(!plugin.validate_api_key(""));
+    }
+    #[tokio::test]
+    async fn test_plugin_initialization() {
+        let mut plugin = CustomProviderPlugin::new();
+        let mut config = HashMap::new();
+        config.insert(
+            "base_url".to_string(),
+            Value::String("https://api.example.com".to_string()),
+        );
+        let result = plugin.initialize(config).await;
+        assert!(result.is_ok());
+    }
+}

======================================================================

@@ -4,8 +4,7 @@ use std::collections::HashMap;
 /// Static information about the model
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct ModelContext {
-    pub provider: Option<crate::types::LlmProvider>,
-    pub provider_id: String,
+    pub provider: crate::types::LlmProvider,
     pub model_name: String,
     pub capabilities: ModelCapabilities,
 }

======================================================================

@@ -1,496 +0,0 @@
-use std::sync::{Arc, Mutex, OnceLock};
-use async_trait::async_trait;
-use futures::StreamExt;
-use serde_json::{json, Value};
-use wiremock::matchers::{method, path};
-use wiremock::{Mock, MockGuard, MockServer, ResponseTemplate};
-use worker::{
-    HookContext, HookResult, LlmProvider, PromptError, StreamEvent, Tool, ToolResult, Worker,
-    WorkerBlueprint, WorkerHook,
-};
-use worker_types::Role;
-const SAMPLE_TOOL_NAME: &str = "sample_tool";
-static ENV_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();
-struct ProviderCase {
-    name: &'static str,
-    provider: LlmProvider,
-    model: &'static str,
-    env_var: &'static str,
-    completion_path: &'static str,
-}
-struct SampleTool {
-    name: String,
-    description: String,
-    calls: Arc<Mutex<Vec<Value>>>,
-    response: Value,
-}
-impl SampleTool {
-    fn new(provider_label: &str, calls: Arc<Mutex<Vec<Value>>>) -> Self {
-        Self {
-            name: SAMPLE_TOOL_NAME.to_string(),
-            description: format!("Records invocations for {}", provider_label),
-            calls,
-            response: json!({
-                "status": "ok",
-                "provider": provider_label,
-            }),
-        }
-    }
-}
-#[async_trait]
-impl Tool for SampleTool {
-    fn name(&self) -> &str {
-        &self.name
-    }
-    fn description(&self) -> &str {
-        &self.description
-    }
-    fn parameters_schema(&self) -> Value {
-        json!({
-            "type": "object",
-            "properties": {
-                "provider": {"type": "string"},
-                "request_id": {"type": "integer"}
-            },
-            "required": ["provider", "request_id"]
-        })
-    }
-    async fn execute(&self, args: Value) -> ToolResult<Value> {
-        self.calls.lock().unwrap().push(args.clone());
-        Ok(self.response.clone())
-    }
-}
-struct RecordingHook {
-    tool_name: String,
-    provider_label: String,
-    events: Arc<Mutex<Vec<String>>>,
-}
-impl RecordingHook {
-    fn new(provider_label: &str, events: Arc<Mutex<Vec<String>>>) -> Self {
-        Self {
-            tool_name: SAMPLE_TOOL_NAME.to_string(),
-            provider_label: provider_label.to_string(),
-            events,
-        }
-    }
-}
-#[async_trait]
-impl WorkerHook for RecordingHook {
-    fn name(&self) -> &str {
-        "recording_hook"
-    }
-    fn hook_type(&self) -> &str {
-        "PostToolUse"
-    }
-    fn matcher(&self) -> &str {
-        &self.tool_name
-    }
-    async fn execute(&self, context: HookContext) -> (HookContext, HookResult) {
-        let tool = context
-            .get_variable("current_tool")
-            .cloned()
-            .unwrap_or_else(|| self.tool_name.clone());
-        let entry = format!(
-            "{}::{}::{}",
-            self.provider_label, tool, context.content
-        );
-        self.events.lock().unwrap().push(entry);
-        let message = format!("{} hook observed {}", self.provider_label, tool);
-        (
-            context,
-            HookResult::AddMessage(message, Role::Assistant),
-        )
-    }
-}
-struct EnvOverride {
-    key: String,
-    previous: Option<String>,
-}
-impl EnvOverride {
-    fn set(key: &str, value: String) -> Self {
-        let previous = std::env::var(key).ok();
-        std::env::set_var(key, &value);
-        Self {
-            key: key.to_string(),
-            previous,
-        }
-    }
-}
-impl Drop for EnvOverride {
-    fn drop(&mut self) {
-        if let Some(prev) = self.previous.take() {
-            std::env::set_var(&self.key, prev);
-        } else {
-            std::env::remove_var(&self.key);
-        }
-    }
-}
-fn build_blueprint(
-    provider: LlmProvider,
-    model: &str,
-    provider_label: &str,
-    tool_calls: Arc<Mutex<Vec<Value>>>,
-    hook_events: Arc<Mutex<Vec<String>>>,
-) -> WorkerBlueprint {
-    let mut blueprint = Worker::blueprint();
-    blueprint
-        .provider(provider)
-        .model(model)
-        .api_key(provider.as_str(), "test-key")
-        .system_prompt_fn(|_, _| Ok::<_, PromptError>("Integration test system prompt.".into()))
-        .add_tool(SampleTool::new(provider_label, Arc::clone(&tool_calls)))
-        .attach_hook(RecordingHook::new(provider_label, Arc::clone(&hook_events)));
-    blueprint
-}
-async fn setup_mock_response(
-    case: &ProviderCase,
-    server: &MockServer,
-    expected_args: &Value,
-) -> MockGuard {
-    match case.provider {
-        LlmProvider::OpenAI => {
-            let arguments = expected_args.to_string();
-            let event_body = json!({
-                "choices": [{
-                    "delta": {
-                        "tool_calls": [{
-                            "function": {
-                                "name": SAMPLE_TOOL_NAME,
-                                "arguments": arguments
-                            }
-                        }]
-                    }
-                }]
-            });
-            let sse = format!(
-                "data: {}\n\ndata: [DONE]\n\n",
-                event_body
-            );
-            Mock::given(method("POST"))
-                .and(path(case.completion_path))
-                .respond_with(
-                    ResponseTemplate::new(200)
-                        .set_body_raw(sse, "text/event-stream"),
-                )
-                .mount(server)
-                .await
-        }
-        LlmProvider::Claude => {
-            let event_tool = json!({
-                "type": "content_block_start",
-                "data": {
-                    "content_block": {
-                        "type": "tool_use",
-                        "name": SAMPLE_TOOL_NAME,
-                        "input": expected_args
-                    }
-                }
-            });
-            let event_stop = json!({
-                "type": "message_stop",
-                "data": {}
-            });
-            let sse = format!(
-                "data: {}\n\ndata: {}\n\ndata: [DONE]\n\n",
-                event_tool, event_stop
-            );
-            Mock::given(method("POST"))
-                .and(path(case.completion_path))
-                .respond_with(
-                    ResponseTemplate::new(200)
-                        .set_body_raw(sse, "text/event-stream"),
-                )
-                .mount(server)
-                .await
-        }
-        LlmProvider::Gemini => {
-            let body = json!({
-                "candidates": [{
-                    "content": {
-                        "role": "model",
-                        "parts": [{
-                            "functionCall": {
-                                "name": SAMPLE_TOOL_NAME,
-                                "args": expected_args
-                            }
-                        }]
-                    },
-                    "finishReason": "STOP"
-                }]
-            });
-            Mock::given(method("POST"))
-                .and(path(case.completion_path))
-                .respond_with(ResponseTemplate::new(200).set_body_json(body))
-                .mount(server)
-                .await
-        }
-        LlmProvider::Ollama => {
-            let first = json!({
-                "message": {
-                    "role": "assistant",
-                    "content": "",
-                    "tool_calls": [{
-                        "function": {
-                            "name": SAMPLE_TOOL_NAME,
-                            "arguments": expected_args
-                        }
-                    }]
-                },
-                "done": false
-            });
-            let second = json!({
-                "message": {
-                    "role": "assistant",
-                    "content": "finished"
-                },
-                "done": true
-            });
-            let body = format!("{}\n{}\n", first, second);
-            Mock::given(method("POST"))
-                .and(path(case.completion_path))
-                .respond_with(
-                    ResponseTemplate::new(200)
-                        .set_body_raw(body, "application/x-ndjson"),
-                )
-                .mount(server)
-                .await
-        }
-        other => panic!("Unsupported provider in test: {:?}", other),
-    }
-}
-fn provider_cases() -> Vec<ProviderCase> {
-    vec![
-        ProviderCase {
-            name: "openai",
-            provider: LlmProvider::OpenAI,
-            model: "gpt-4o-mini",
-            env_var: "OPENAI_BASE_URL",
-            completion_path: "/v1/chat/completions",
-        },
-        ProviderCase {
-            name: "gemini",
-            provider: LlmProvider::Gemini,
-            model: "gemini-1.5-flash",
-            env_var: "GEMINI_BASE_URL",
-            completion_path: "/v1beta/models/gemini-1.5-flash:streamGenerateContent",
-        },
-        ProviderCase {
-            name: "anthropic",
-            provider: LlmProvider::Claude,
-            model: "claude-3-opus-20240229",
-            env_var: "ANTHROPIC_BASE_URL",
-            completion_path: "/v1/messages",
-        },
-        ProviderCase {
-            name: "ollama",
-            provider: LlmProvider::Ollama,
-            model: "llama3",
-            env_var: "OLLAMA_BASE_URL",
-            completion_path: "/api/chat",
-        },
-    ]
-}
-#[tokio::test]
-async fn worker_executes_tools_and_hooks_across_mocked_providers() {
-    let env_lock = ENV_MUTEX.get_or_init(|| Mutex::new(()));
-    for case in provider_cases() {
-        let tool_calls = Arc::new(Mutex::new(Vec::<Value>::new()));
-        let hook_events = Arc::new(Mutex::new(Vec::<String>::new()));
-        let _env_guard = env_lock.lock().unwrap();
-        let server = MockServer::start().await;
-        let _env_override = EnvOverride::set(case.env_var, server.uri());
-        let expected_args = json!({
-            "provider": case.name,
-            "request_id": 1
-        });
-        let _mock = setup_mock_response(&case, &server, &expected_args).await;
-        let mut blueprint = build_blueprint(
-            case.provider,
-            case.model,
-            case.name,
-            Arc::clone(&tool_calls),
-            Arc::clone(&hook_events),
-        );
-        let mut worker = blueprint.instantiate().expect("worker to instantiate");
-        let mut stream = worker
-            .process_task_stream(
-                "Trigger the sample tool".to_string(),
-                None,
-            )
-            .await;
-        let mut events = Vec::new();
-        while let Some(event) = stream.next().await {
-            events.push(event.expect("stream event"));
-        }
-        let requests = server
-            .received_requests()
-            .await
-            .expect("to inspect received requests");
-        assert_eq!(
-            requests.len(),
-            1,
-            "expected exactly one request for provider {}",
-            case.name
-        );
-        let body: Value =
-            serde_json::from_slice(&requests[0].body).expect("request body to be JSON");
-        match case.provider {
-            LlmProvider::OpenAI => {
-                assert_eq!(body["model"], case.model);
-                assert_eq!(body["stream"], true);
-                assert_eq!(
-                    body["tools"][0]["function"]["name"],
-                    SAMPLE_TOOL_NAME
-                );
-            }
-            LlmProvider::Claude => {
-                assert_eq!(body["model"], case.model);
-                assert_eq!(body["stream"], true);
-                assert_eq!(body["tools"][0]["name"], SAMPLE_TOOL_NAME);
-            }
-            LlmProvider::Gemini => {
-                assert_eq!(
-                    body["contents"]
-                        .as_array()
-                        .expect("contents to be array")
-                        .len(),
-                    2,
-                    "system + user messages should be present"
-                );
-                let tools = body["tools"][0]["functionDeclarations"]
-                    .as_array()
-                    .expect("function declarations to exist");
-                assert_eq!(tools[0]["name"], SAMPLE_TOOL_NAME);
-            }
-            LlmProvider::Ollama => {
-                assert_eq!(body["model"], case.model);
-                assert_eq!(body["stream"], true);
-                assert_eq!(
-                    body["tools"][0]["function"]["name"],
-                    SAMPLE_TOOL_NAME
-                );
-            }
-            _ => unreachable!(),
-        }
-        let recorded_calls = tool_calls.lock().unwrap().clone();
-        assert_eq!(
-            recorded_calls.len(),
-            1,
-            "tool should execute exactly once for {}",
-            case.name
-        );
-        assert_eq!(
-            recorded_calls[0], expected_args,
-            "tool arguments should match for {}",
-            case.name
-        );
-        let recorded_hooks = hook_events.lock().unwrap().clone();
-        assert!(
-            recorded_hooks
-                .iter()
-                .any(|entry| entry.contains(case.name) && entry.contains(SAMPLE_TOOL_NAME)),
-            "hook should capture tool usage for {}: {:?}",
-            case.name,
-            recorded_hooks
-        );
-        let mut saw_tool_call = false;
-        let mut saw_tool_result = false;
-        let mut saw_hook_message = false;
-        let mut saw_completion = false;
-        for event in &events {
-            match event {
-                StreamEvent::ToolCall(call) => {
-                    saw_tool_call = true;
-                    assert_eq!(call.name, SAMPLE_TOOL_NAME);
-                    assert_eq!(
-                        serde_json::from_str::<Value>(&call.arguments)
-                            .expect("tool arguments to be JSON"),
-                        expected_args
-                    );
-                }
-                StreamEvent::ToolResult { tool_name, result } => {
-                    if tool_name == SAMPLE_TOOL_NAME {
-                        saw_tool_result = true;
-                        let value = result
-                            .as_ref()
-                            .expect("tool execution should succeed");
-                        assert_eq!(value["status"], "ok");
-                        assert_eq!(value["provider"], case.name);
-                    }
-                }
-                StreamEvent::HookMessage { hook_name, content, .. } => {
-                    if hook_name == "recording_hook" {
-                        saw_hook_message = true;
-                        assert!(
-                            content.contains(case.name),
-                            "hook content should mention provider"
-                        );
-                    }
-                }
-                StreamEvent::Completion(_) => {
-                    saw_completion = true;
-                }
-                _ => {}
-            }
-        }
-        assert!(saw_tool_call, "missing tool call event for {}", case.name);
-        assert!(
-            saw_tool_result,
-            "missing tool result event for {}",
-            case.name
-        );
-        assert!(
-            saw_hook_message,
-            "missing hook message for {}",
-            case.name
-        );
-        assert!(
-            saw_completion,
-            "missing completion event for {}",
-            case.name
-        );
-        std::env::remove_var(case.env_var);
-    }
-}