0.5.1: MCPの設定読み込みの削除
This commit is contained in:
parent
cc6bbe2a43
commit
90edd3828b
|
|
@ -23,9 +23,9 @@ use worker::{LlmProvider, SystemPromptContext, PromptError, Worker};
|
|||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let system_prompt = |ctx: &SystemPromptContext, _messages: &[worker_types::Message]| {
|
||||
Ok(format!(
|
||||
"You are assisting with model {} from provider {:?}.",
|
||||
"You are assisting with model {} from provider {}.",
|
||||
ctx.model.model_name,
|
||||
ctx.model.provider
|
||||
ctx.model.provider_id
|
||||
))
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ pub struct HookManager {
|
|||
pub enum HookEvent {
|
||||
OnMessageSend,
|
||||
PreToolUse,
|
||||
PostToolUse,
|
||||
PostToolUse,
|
||||
OnTurnCompleted,
|
||||
}
|
||||
```
|
||||
|
|
@ -181,10 +181,10 @@ impl HookContext {
|
|||
impl HookContext {
|
||||
// ストリーミング中にメッセージを送信
|
||||
pub fn stream_message(&self, content: String, role: Role);
|
||||
|
||||
|
||||
// ストリーミング中にシステム通知を送信
|
||||
pub fn stream_system_message(&self, content: String);
|
||||
|
||||
|
||||
// ストリーミング中にデバッグ情報を送信
|
||||
pub fn stream_debug(&self, title: String, data: serde_json::Value);
|
||||
}
|
||||
|
|
@ -198,22 +198,22 @@ Hook関数は以下のいずれかの結果を返す必要があります:
|
|||
pub enum HookResult {
|
||||
// 処理を続行
|
||||
Continue,
|
||||
|
||||
|
||||
// コンテンツを変更して続行
|
||||
ModifyContent(String),
|
||||
|
||||
|
||||
// システムメッセージを追加して続行
|
||||
AddMessage(String, Role),
|
||||
|
||||
|
||||
// 複数のメッセージを追加して続行
|
||||
AddMessages(Vec<Message>),
|
||||
|
||||
|
||||
// ターンを強制完了
|
||||
Complete,
|
||||
|
||||
|
||||
// エラーでターンを終了
|
||||
Error(String),
|
||||
|
||||
|
||||
// Hook処理をスキップ(デバッグ用)
|
||||
Skip,
|
||||
}
|
||||
|
|
@ -282,7 +282,7 @@ pub async fn dangerous_command_hook(context: HookContext) -> HookResult {
|
|||
if let Some(args) = &context.tool_args {
|
||||
if let Some(command) = args.get("command").and_then(|v| v.as_str()) {
|
||||
let dangerous_commands = ["rm -rf", "format", "dd if="];
|
||||
|
||||
|
||||
for dangerous in &dangerous_commands {
|
||||
if command.contains(dangerous) {
|
||||
return HookResult::Error(format!(
|
||||
|
|
@ -293,7 +293,7 @@ pub async fn dangerous_command_hook(context: HookContext) -> HookResult {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
HookResult::Continue
|
||||
}
|
||||
```
|
||||
|
|
@ -326,7 +326,7 @@ pub async fn auto_read_hook(mut context: HookContext) -> HookResult {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
HookResult::Continue
|
||||
}
|
||||
```
|
||||
|
|
@ -361,7 +361,7 @@ worker.register_hooks(tui_hooks);
|
|||
```rust
|
||||
// 実行順序の例
|
||||
worker.register_hook(Box::new(TimestampHook)); // 1番目
|
||||
worker.register_hook(Box::new(ValidationHook)); // 2番目
|
||||
worker.register_hook(Box::new(ValidationHook)); // 2番目
|
||||
worker.register_hook(Box::new(LoggingHook)); // 3番目
|
||||
```
|
||||
|
||||
|
|
@ -375,13 +375,13 @@ worker.register_hook(Box::new(LoggingHook)); // 3番目
|
|||
impl Worker {
|
||||
// Hook一覧を取得
|
||||
pub fn list_hooks(&self) -> Vec<(&str, &str)>; // (name, hook_type)
|
||||
|
||||
|
||||
// 特定のHookを削除
|
||||
pub fn remove_hook(&mut self, hook_name: &str) -> bool;
|
||||
|
||||
|
||||
// フェーズ別Hookを削除
|
||||
pub fn remove_hooks_by_phase(&mut self, hook_type: &str);
|
||||
|
||||
|
||||
// すべてのHookをクリア
|
||||
pub fn clear_hooks(&mut self);
|
||||
}
|
||||
|
|
@ -395,16 +395,16 @@ impl Worker {
|
|||
// worker/src/lib.rs の process_with_shared_state より
|
||||
stream! {
|
||||
// ... LLM応答処理中 ...
|
||||
|
||||
|
||||
// ツール呼び出し検出時
|
||||
if let Some(tool_calls) = &response.tool_calls {
|
||||
for tool_call in tool_calls {
|
||||
// PreToolUse hooks 実行
|
||||
let (context, hook_result) = execute_hooks(
|
||||
HookEvent::PreToolUse,
|
||||
HookEvent::PreToolUse,
|
||||
tool_call.name.clone()
|
||||
).await;
|
||||
|
||||
|
||||
match hook_result {
|
||||
HookResult::Error(msg) => {
|
||||
yield Ok(StreamEvent::Error(msg));
|
||||
|
|
@ -413,16 +413,16 @@ stream! {
|
|||
HookResult::Complete => break,
|
||||
_ => {}
|
||||
}
|
||||
|
||||
|
||||
// ツール実行
|
||||
let result = execute_tool(tool_call).await;
|
||||
|
||||
|
||||
// PostToolUse hooks 実行(ストリーミング中)
|
||||
let (context, hook_result) = execute_hooks(
|
||||
HookEvent::PostToolUse,
|
||||
tool_call.name.clone()
|
||||
).await;
|
||||
|
||||
|
||||
// Hook結果を即座にストリーミング
|
||||
if let HookResult::AddMessage(msg, role) = hook_result {
|
||||
yield Ok(StreamEvent::HookMessage {
|
||||
|
|
@ -468,13 +468,13 @@ pub async fn performance_aware_hook(context: HookContext) -> HookResult {
|
|||
// 大きなコンテンツの場合はスキップ
|
||||
return HookResult::Skip;
|
||||
}
|
||||
|
||||
|
||||
// 非同期処理は適切にawaitする
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(5),
|
||||
expensive_operation(&context)
|
||||
).await;
|
||||
|
||||
|
||||
match result {
|
||||
Ok(output) => HookResult::AddMessage(output, Role::System),
|
||||
Err(_) => {
|
||||
|
|
@ -495,20 +495,20 @@ pub async fn configurable_hook(mut context: HookContext) -> HookResult {
|
|||
.unwrap_or_default()
|
||||
.parse::<bool>()
|
||||
.unwrap_or(false);
|
||||
|
||||
|
||||
if !enabled {
|
||||
return HookResult::Skip;
|
||||
}
|
||||
|
||||
|
||||
// 設定ファイルからオプション読み込み
|
||||
let config_path = format!("{}/.nia/hook_config.json", context.workspace_path);
|
||||
let config_path = format!("{}/hook_config.json", context.workspace_path);
|
||||
if let Ok(config_content) = tokio::fs::read_to_string(&config_path).await {
|
||||
if let Ok(config) = serde_json::from_str::<HookConfig>(&config_content) {
|
||||
// 設定に基づく処理
|
||||
return process_with_config(&mut context, &config).await;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
HookResult::Continue
|
||||
}
|
||||
```
|
||||
|
|
@ -523,7 +523,7 @@ pub async fn conditional_hook(context: HookContext) -> HookResult {
|
|||
let is_rust_project = tokio::fs::metadata(
|
||||
format!("{}/Cargo.toml", context.workspace_path)
|
||||
).await.is_ok();
|
||||
|
||||
|
||||
match (is_git_repo, is_rust_project) {
|
||||
(true, true) => {
|
||||
// Rustプロジェクト + Git
|
||||
|
|
@ -548,7 +548,7 @@ pub async fn conditional_hook(context: HookContext) -> HookResult {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use worker::types::*;
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timestamp_hook() {
|
||||
let mut context = HookContext {
|
||||
|
|
@ -561,9 +561,9 @@ mod tests {
|
|||
tool_args: None,
|
||||
tool_result: None,
|
||||
};
|
||||
|
||||
|
||||
let result = add_timestamp_hook(context).await;
|
||||
|
||||
|
||||
match result {
|
||||
HookResult::ModifyContent(content) => {
|
||||
assert!(content.contains("Hello, world!"));
|
||||
|
|
@ -587,7 +587,7 @@ pub async fn debug_hook(context: HookContext) -> HookResult {
|
|||
context.tools.len(),
|
||||
context.message_history.len()
|
||||
);
|
||||
|
||||
|
||||
// デバッグ情報をストリーミング
|
||||
context.stream_debug(
|
||||
"Hook Debug Info".to_string(),
|
||||
|
|
@ -598,7 +598,7 @@ pub async fn debug_hook(context: HookContext) -> HookResult {
|
|||
"workspace": context.workspace_path
|
||||
})
|
||||
);
|
||||
|
||||
|
||||
HookResult::Continue
|
||||
}
|
||||
```
|
||||
|
|
@ -625,13 +625,13 @@ impl WorkerHook for StatefulHook {
|
|||
fn name(&self) -> &str { "stateful_hook" }
|
||||
fn hook_type(&self) -> &str { "OnTurnCompleted" }
|
||||
fn matcher(&self) -> &str { "" }
|
||||
|
||||
|
||||
async fn execute(&self, mut context: HookContext) -> (HookContext, HookResult) {
|
||||
let mut count = self.counter.lock().unwrap();
|
||||
*count += 1;
|
||||
|
||||
|
||||
context.set_variable("turn_count".to_string(), count.to_string());
|
||||
|
||||
|
||||
if *count % 10 == 0 {
|
||||
(
|
||||
context,
|
||||
|
|
@ -658,7 +658,7 @@ impl HookChain {
|
|||
pub fn new() -> Self {
|
||||
Self { hooks: Vec::new() }
|
||||
}
|
||||
|
||||
|
||||
pub fn add_hook(mut self, hook: Box<dyn WorkerHook>) -> Self {
|
||||
self.hooks.push(hook);
|
||||
self
|
||||
|
|
@ -670,18 +670,18 @@ impl WorkerHook for HookChain {
|
|||
fn name(&self) -> &str { "hook_chain" }
|
||||
fn hook_type(&self) -> &str { "OnMessageSend" }
|
||||
fn matcher(&self) -> &str { "" }
|
||||
|
||||
|
||||
async fn execute(&self, mut context: HookContext) -> (HookContext, HookResult) {
|
||||
for hook in &self.hooks {
|
||||
let (new_context, result) = hook.execute(context).await;
|
||||
context = new_context;
|
||||
|
||||
|
||||
match result {
|
||||
HookResult::Continue | HookResult::Skip => continue,
|
||||
other => return (context, other),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
(context, HookResult::Continue)
|
||||
}
|
||||
}
|
||||
|
|
@ -715,4 +715,3 @@ A: `HookResult::Error`を返すと、そのターンは中断されます。継
|
|||
- [worker-macro.md](worker-macro.md) - マクロシステム
|
||||
- `worker/src/lib.rs` - Hook実装コード
|
||||
- `worker-types/src/lib.rs` - Hook型定義
|
||||
- `nia-cli/src/tui/hooks/` - TUI用Hook実装例
|
||||
|
|
@ -6,7 +6,7 @@ v0.3.0 はプロンプトリソースの解決責務を利用側へ完全に移
|
|||
|
||||
## Breaking Changes
|
||||
|
||||
- `ConfigParser::resolve_path` を削除し、`#nia/` `#workspace/` 等のプレフィックス解決をライブラリ利用者実装の `ResourceLoader` に委譲しました。
|
||||
- `ConfigParser::resolve_path` を削除し、`#user/` `#workspace/` 等のプレフィックス解決をライブラリ利用者実装の `ResourceLoader` に委譲しました。
|
||||
- `WorkerBuilder::build()` は `resource_loader(...)` が未指定の場合エラーを返すようになりました。ワーカー構築前に必ずローダーを提供してください。
|
||||
|
||||
## 新機能 / 仕様変更
|
||||
|
|
@ -19,7 +19,6 @@ v0.3.0 はプロンプトリソースの解決責務を利用側へ完全に移
|
|||
## 不具合修正
|
||||
|
||||
- `include_file` ヘルパーがカスタムローダーを利用せずにファイルアクセスしていた問題を修正。
|
||||
- `ConfigParser` が存在しない `#nia/` プレフィックスを静的に解決しようとしていた挙動を除去し、誤ったパスが静かに通ることを防止。
|
||||
|
||||
## 移行ガイド
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ v0.4.0 は Worker が `Role` や YAML 設定を扱わず、システムプロン
|
|||
|
||||
## 不具合修正
|
||||
|
||||
- Worker から NIA 固有の設定コードを除去し、環境依存の副作用を縮小。
|
||||
- Worker から旧プロジェクト固有の設定コードを除去し、環境依存の副作用を縮小。
|
||||
|
||||
## 移行ガイド
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
**Release Date**: 2025-10-25
|
||||
|
||||
v0.5.0 introduces the Worker Blueprint API and removes the old type-state builder. Configuration now lives on the blueprint, while instantiated workers keep only the materialised system prompt and runtime state.
|
||||
v0.5.0 introduces the Worker Blueprint API and removes the old type-state builder. Configuration now lives on the blueprint, while instantiated workers keep only the materialised system prompt, model metadata, and runtime state.
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
|
|
@ -12,9 +12,10 @@ v0.5.0 introduces the Worker Blueprint API and removes the old type-state builde
|
|||
|
||||
## New Features / Behaviour
|
||||
|
||||
- `WorkerBlueprint` stores provider/model/api keys, tools, hooks, and optional precomputed system prompt strings. `instantiate()` evaluates the prompt (if not already cached) and hands the final string to the `Worker`.
|
||||
- Instantiated workers retain only the composed system prompt string; the generator function lives solely on the blueprint and is dropped after instantiation.
|
||||
- `WorkerBlueprint` stores provider/model/api keys, tools, hooks, optional precomputed system prompt messages, and optional model feature flags. `instantiate()` evaluates the prompt (if not already cached) and hands the final string to the `Worker`.
|
||||
- Instantiated workers retain the composed system prompt, the original generator closure, and a `Model` struct describing provider/model/features; the generator only runs again if a new session requires it.
|
||||
- System prompts are no longer recomputed per turn. Tool metadata is appended dynamically as plain text when native tool support is unavailable.
|
||||
- Worker now exposes a `Model` struct (`provider`, `name`, `features`) in place of the previous loose strings and `supports_native_tools` helper. Capability heuristics remain for built-in providers but applications can override them via `WorkerBlueprint::model_features`.
|
||||
|
||||
## Migration Guide
|
||||
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
_messages: &[Message],
|
||||
) -> Result<String, PromptError> {
|
||||
Ok(format!(
|
||||
"You are helping with requests for model {} (provider {:?}).",
|
||||
ctx.model.model_name, ctx.model.provider
|
||||
"You are helping with requests for model {} (provider {}).",
|
||||
ctx.model.model_name, ctx.model.provider_id
|
||||
))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
use crate::LlmProviderExt;
|
||||
use crate::Worker;
|
||||
use crate::plugin;
|
||||
use crate::prompt::{PromptError, SystemPromptContext, SystemPromptFn};
|
||||
use crate::types::{HookManager, Tool, WorkerError, WorkerHook};
|
||||
use crate::{LlmProviderExt, Model, ModelFeatures, ModelProvider, Worker};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use worker_types::{LlmProvider, Message, Role};
|
||||
|
|
@ -24,6 +23,7 @@ pub struct WorkerBlueprint {
|
|||
pub(crate) tools: Vec<Box<dyn Tool>>,
|
||||
pub(crate) hooks: Vec<Box<dyn WorkerHook>>,
|
||||
pub(crate) prompt_cache: Option<Vec<Message>>,
|
||||
pub(crate) model_features: Option<ModelFeatures>,
|
||||
}
|
||||
|
||||
impl WorkerBlueprint {
|
||||
|
|
@ -36,6 +36,7 @@ impl WorkerBlueprint {
|
|||
tools: Vec::new(),
|
||||
hooks: Vec::new(),
|
||||
prompt_cache: None,
|
||||
model_features: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -61,6 +62,11 @@ impl WorkerBlueprint {
|
|||
self
|
||||
}
|
||||
|
||||
pub fn model_features(&mut self, features: ModelFeatures) -> &mut Self {
|
||||
self.model_features = Some(features);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn api_key(&mut self, provider: impl Into<String>, key: impl Into<String>) -> &mut Self {
|
||||
self.api_keys.insert(provider.into(), key.into());
|
||||
self
|
||||
|
|
@ -118,9 +124,27 @@ impl WorkerBlueprint {
|
|||
.map(|tool| tool.name().to_string())
|
||||
.collect();
|
||||
|
||||
let context = self.build_system_prompt_context(provider, &model_name, &tool_names);
|
||||
let features = self
|
||||
.model_features
|
||||
.clone()
|
||||
.unwrap_or_else(|| match provider {
|
||||
ProviderConfig::BuiltIn(p) => Worker::infer_model_features(Some(*p), &model_name),
|
||||
ProviderConfig::Plugin { .. } => ModelFeatures::default(),
|
||||
});
|
||||
|
||||
let preview_model = Model {
|
||||
provider: match provider {
|
||||
ProviderConfig::BuiltIn(p) => ModelProvider::BuiltIn(*p),
|
||||
ProviderConfig::Plugin { id, .. } => ModelProvider::Plugin(id.clone()),
|
||||
},
|
||||
name: model_name.clone(),
|
||||
features: features.clone(),
|
||||
};
|
||||
|
||||
let context = Worker::create_system_prompt_context(&preview_model, &tool_names);
|
||||
let prompt = generator(&context, &[]).map_err(|e| WorkerError::config(e.to_string()))?;
|
||||
self.prompt_cache = Some(vec![Message::new(Role::System, prompt)]);
|
||||
self.model_features = Some(features);
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
|
|
@ -138,30 +162,49 @@ impl WorkerBlueprint {
|
|||
.take()
|
||||
.ok_or_else(|| WorkerError::config("System prompt generator is not configured"))?;
|
||||
|
||||
let mut prompt_cache = self.prompt_cache.take();
|
||||
let mut provided_features = self.model_features.take();
|
||||
let tools = std::mem::take(&mut self.tools);
|
||||
let hooks = std::mem::take(&mut self.hooks);
|
||||
let mut api_keys = self.api_keys;
|
||||
|
||||
let tool_names: Vec<String> = tools.iter().map(|tool| tool.name().to_string()).collect();
|
||||
let provider_hint = provider_config.provider_hint();
|
||||
let prompt_context =
|
||||
Worker::create_system_prompt_context(provider_hint, &model_name, &tool_names);
|
||||
|
||||
let base_messages = match self.prompt_cache.take() {
|
||||
Some(messages) if !messages.is_empty() => messages,
|
||||
_ => {
|
||||
let prompt = system_prompt_fn(&prompt_context, &[])
|
||||
.map_err(|e| WorkerError::config(e.to_string()))?;
|
||||
vec![Message::new(Role::System, prompt)]
|
||||
}
|
||||
};
|
||||
let base_system_prompt = base_messages
|
||||
.first()
|
||||
.map(|msg| msg.content.clone())
|
||||
.unwrap_or_else(|| String::new());
|
||||
|
||||
match provider_config {
|
||||
ProviderConfig::BuiltIn(provider) => {
|
||||
let features = provided_features
|
||||
.take()
|
||||
.unwrap_or_else(|| Worker::infer_model_features(Some(provider), &model_name));
|
||||
|
||||
let model = Model {
|
||||
provider: ModelProvider::BuiltIn(provider),
|
||||
name: model_name.clone(),
|
||||
features,
|
||||
};
|
||||
|
||||
let prompt_context = Worker::create_system_prompt_context(&model, &tool_names);
|
||||
let base_messages = if let Some(messages) = prompt_cache.take() {
|
||||
if messages.is_empty() {
|
||||
vec![Message::new(
|
||||
Role::System,
|
||||
system_prompt_fn(&prompt_context, &[])
|
||||
.map_err(|e| WorkerError::config(e.to_string()))?,
|
||||
)]
|
||||
} else {
|
||||
messages
|
||||
}
|
||||
} else {
|
||||
vec![Message::new(
|
||||
Role::System,
|
||||
system_prompt_fn(&prompt_context, &[])
|
||||
.map_err(|e| WorkerError::config(e.to_string()))?,
|
||||
)]
|
||||
};
|
||||
|
||||
let base_system_prompt = base_messages
|
||||
.first()
|
||||
.map(|msg| msg.content.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
let api_key = api_keys
|
||||
.entry(provider.as_str().to_string())
|
||||
.or_insert_with(String::new)
|
||||
|
|
@ -170,13 +213,12 @@ impl WorkerBlueprint {
|
|||
let llm_client = provider.create_client(&model_name, &api_key)?;
|
||||
let mut worker = Worker {
|
||||
llm_client: Box::new(llm_client),
|
||||
system_prompt: base_system_prompt.clone(),
|
||||
system_prompt: base_system_prompt,
|
||||
system_prompt_fn: Arc::clone(&system_prompt_fn),
|
||||
tools,
|
||||
api_key,
|
||||
provider_str: provider.as_str().to_string(),
|
||||
model_name,
|
||||
message_history: base_messages.clone(),
|
||||
model,
|
||||
message_history: base_messages,
|
||||
hook_manager: HookManager::new(),
|
||||
mcp_lazy_configs: Vec::new(),
|
||||
plugin_registry: Arc::new(Mutex::new(plugin::PluginRegistry::new())),
|
||||
|
|
@ -185,6 +227,40 @@ impl WorkerBlueprint {
|
|||
Ok(worker)
|
||||
}
|
||||
ProviderConfig::Plugin { id, registry } => {
|
||||
let features = provided_features
|
||||
.take()
|
||||
.unwrap_or_else(ModelFeatures::default);
|
||||
|
||||
let model = Model {
|
||||
provider: ModelProvider::Plugin(id.clone()),
|
||||
name: model_name.clone(),
|
||||
features,
|
||||
};
|
||||
|
||||
let prompt_context = Worker::create_system_prompt_context(&model, &tool_names);
|
||||
let base_messages = if let Some(messages) = prompt_cache.take() {
|
||||
if messages.is_empty() {
|
||||
vec![Message::new(
|
||||
Role::System,
|
||||
system_prompt_fn(&prompt_context, &[])
|
||||
.map_err(|e| WorkerError::config(e.to_string()))?,
|
||||
)]
|
||||
} else {
|
||||
messages
|
||||
}
|
||||
} else {
|
||||
vec![Message::new(
|
||||
Role::System,
|
||||
system_prompt_fn(&prompt_context, &[])
|
||||
.map_err(|e| WorkerError::config(e.to_string()))?,
|
||||
)]
|
||||
};
|
||||
|
||||
let base_system_prompt = base_messages
|
||||
.first()
|
||||
.map(|msg| msg.content.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
let api_key = api_keys
|
||||
.remove("__plugin__")
|
||||
.or_else(|| api_keys.values().next().cloned())
|
||||
|
|
@ -208,8 +284,7 @@ impl WorkerBlueprint {
|
|||
system_prompt_fn,
|
||||
tools,
|
||||
api_key,
|
||||
provider_str: id,
|
||||
model_name,
|
||||
model,
|
||||
message_history: base_messages,
|
||||
hook_manager: HookManager::new(),
|
||||
mcp_lazy_configs: Vec::new(),
|
||||
|
|
@ -220,15 +295,6 @@ impl WorkerBlueprint {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_system_prompt_context(
|
||||
&self,
|
||||
provider: &ProviderConfig,
|
||||
model_name: &str,
|
||||
tool_names: &[String],
|
||||
) -> SystemPromptContext {
|
||||
Worker::create_system_prompt_context(provider.provider_hint(), model_name, tool_names)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderConfig {
|
||||
|
|
@ -245,11 +311,4 @@ impl ProviderConfig {
|
|||
ProviderConfig::Plugin { registry, .. } => Some(Arc::clone(registry)),
|
||||
}
|
||||
}
|
||||
|
||||
fn provider_hint(&self) -> LlmProvider {
|
||||
match self {
|
||||
ProviderConfig::BuiltIn(provider) => *provider,
|
||||
ProviderConfig::Plugin { .. } => LlmProvider::OpenAI,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,8 +8,6 @@ use llm::{
|
|||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tracing;
|
||||
use uuid;
|
||||
|
|
@ -324,127 +322,67 @@ pub async fn validate_api_key(
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ModelsConfig {
|
||||
pub models: Vec<ModelDefinition>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ModelDefinition {
|
||||
pub model: String,
|
||||
pub name: String,
|
||||
pub meta: ModelMeta,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ModelMeta {
|
||||
pub tool_support: bool,
|
||||
pub function_calling: bool,
|
||||
pub vision: bool,
|
||||
pub multimodal: bool,
|
||||
pub context_length: Option<u32>,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
fn get_models_config_path() -> Result<PathBuf, WorkerError> {
|
||||
let home_dir = dirs::home_dir()
|
||||
.ok_or_else(|| WorkerError::config("Could not determine home directory"))?;
|
||||
Ok(home_dir.join(".config").join("nia").join("models.yaml"))
|
||||
}
|
||||
|
||||
fn load_models_config() -> Result<ModelsConfig, WorkerError> {
|
||||
let config_path = get_models_config_path()?;
|
||||
|
||||
if !config_path.exists() {
|
||||
tracing::warn!(
|
||||
"Models config file not found at {:?}, using defaults",
|
||||
config_path
|
||||
);
|
||||
return Ok(ModelsConfig { models: vec![] });
|
||||
}
|
||||
|
||||
let content = fs::read_to_string(&config_path)
|
||||
.map_err(|e| WorkerError::config(format!("Failed to read models config: {}", e)))?;
|
||||
|
||||
let config: ModelsConfig = serde_yaml::from_str(&content)
|
||||
.map_err(|e| WorkerError::config(format!("Failed to parse models config: {}", e)))?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub async fn supports_native_tools(
|
||||
provider: &LlmProvider,
|
||||
model_name: &str,
|
||||
_api_key: &str,
|
||||
) -> Result<bool, WorkerError> {
|
||||
let config = load_models_config()?;
|
||||
|
||||
let model_id = format!(
|
||||
"{}/{}",
|
||||
match provider {
|
||||
LlmProvider::Claude => "anthropic",
|
||||
LlmProvider::OpenAI => "openai",
|
||||
LlmProvider::Gemini => "gemini",
|
||||
LlmProvider::Ollama => "ollama",
|
||||
LlmProvider::XAI => "xai",
|
||||
},
|
||||
model_name
|
||||
);
|
||||
|
||||
for model_def in &config.models {
|
||||
if model_def.model == model_id || model_def.model.contains(model_name) {
|
||||
tracing::debug!(
|
||||
"Found model config: model={}, function_calling={}",
|
||||
model_def.model,
|
||||
model_def.meta.function_calling
|
||||
);
|
||||
return Ok(model_def.meta.function_calling);
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!(
|
||||
"Model not found in config: {} ({}), using provider defaults",
|
||||
model_id,
|
||||
model_name
|
||||
);
|
||||
|
||||
tracing::warn!(
|
||||
"Using provider-based fallback - this should be configured in models.yaml: provider={:?}, model={}",
|
||||
provider,
|
||||
model_name
|
||||
);
|
||||
|
||||
let supports_tools = match provider {
|
||||
LlmProvider::Claude => true,
|
||||
LlmProvider::OpenAI => !model_name.contains("gpt-3.5-turbo-instruct"),
|
||||
LlmProvider::Gemini => !model_name.contains("gemini-pro-vision"),
|
||||
LlmProvider::Ollama => false,
|
||||
LlmProvider::XAI => true,
|
||||
};
|
||||
|
||||
tracing::debug!(
|
||||
"Fallback tool support check: provider={:?}, model={}, supports_tools={}",
|
||||
provider,
|
||||
model_name,
|
||||
supports_tools
|
||||
);
|
||||
Ok(supports_tools)
|
||||
}
|
||||
|
||||
pub struct Worker {
|
||||
pub(crate) llm_client: Box<dyn LlmClientTrait>,
|
||||
pub(crate) system_prompt: String,
|
||||
pub(crate) system_prompt_fn: Arc<SystemPromptFn>,
|
||||
pub(crate) tools: Vec<Box<dyn Tool>>,
|
||||
pub(crate) api_key: String,
|
||||
pub(crate) provider_str: String,
|
||||
pub(crate) model_name: String,
|
||||
pub(crate) model: Model,
|
||||
pub(crate) message_history: Vec<Message>,
|
||||
pub(crate) hook_manager: crate::types::HookManager,
|
||||
pub(crate) mcp_lazy_configs: Vec<McpServerConfig>,
|
||||
pub(crate) plugin_registry: std::sync::Arc<std::sync::Mutex<plugin::PluginRegistry>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ModelProvider {
|
||||
BuiltIn(LlmProvider),
|
||||
Plugin(String),
|
||||
}
|
||||
|
||||
impl ModelProvider {
|
||||
pub fn identifier(&self) -> String {
|
||||
match self {
|
||||
ModelProvider::BuiltIn(provider) => provider.as_str().to_string(),
|
||||
ModelProvider::Plugin(id) => id.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_llm_provider(&self) -> Option<LlmProvider> {
|
||||
match self {
|
||||
ModelProvider::BuiltIn(provider) => Some(*provider),
|
||||
ModelProvider::Plugin(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Model {
|
||||
pub provider: ModelProvider,
|
||||
pub name: String,
|
||||
pub features: ModelFeatures,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ModelFeatures {
|
||||
pub supports_tools: bool,
|
||||
pub supports_function_calling: bool,
|
||||
pub supports_vision: bool,
|
||||
pub supports_multimodal: bool,
|
||||
pub context_length: Option<u64>,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn provider_id(&self) -> String {
|
||||
self.provider.identifier()
|
||||
}
|
||||
|
||||
pub fn built_in_provider(&self) -> Option<LlmProvider> {
|
||||
self.provider.as_llm_provider()
|
||||
}
|
||||
}
|
||||
|
||||
impl Worker {
|
||||
/// Create a new Worker blueprint
|
||||
///
|
||||
|
|
@ -504,21 +442,20 @@ impl Worker {
|
|||
system_prompt_fn,
|
||||
tools,
|
||||
api_key,
|
||||
provider_str,
|
||||
model_name,
|
||||
model,
|
||||
message_history,
|
||||
hook_manager,
|
||||
mcp_lazy_configs: _,
|
||||
plugin_registry,
|
||||
} = self;
|
||||
|
||||
let provider = match LlmProvider::from_str(&provider_str) {
|
||||
Some(p) => {
|
||||
let provider = match &model.provider {
|
||||
ModelProvider::BuiltIn(p) => {
|
||||
drop(plugin_registry);
|
||||
ProviderConfig::BuiltIn(p)
|
||||
ProviderConfig::BuiltIn(*p)
|
||||
}
|
||||
None => ProviderConfig::Plugin {
|
||||
id: provider_str.clone(),
|
||||
ModelProvider::Plugin(id) => ProviderConfig::Plugin {
|
||||
id: id.clone(),
|
||||
registry: plugin_registry,
|
||||
},
|
||||
};
|
||||
|
|
@ -535,12 +472,13 @@ impl Worker {
|
|||
|
||||
WorkerBlueprint {
|
||||
provider: Some(provider),
|
||||
model_name: Some(model_name),
|
||||
model_name: Some(model.name),
|
||||
api_keys,
|
||||
system_prompt_fn: Some(system_prompt_fn),
|
||||
tools,
|
||||
hooks: hook_manager.into_hooks(),
|
||||
prompt_cache: Some(message_history),
|
||||
model_features: Some(model.features),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -815,27 +753,19 @@ impl Worker {
|
|||
|
||||
/// 静的プロンプトコンテキストを作成(構築時用)
|
||||
pub(crate) fn create_system_prompt_context(
|
||||
provider: LlmProvider,
|
||||
model_name: &str,
|
||||
model: &Model,
|
||||
tools: &[String],
|
||||
) -> crate::prompt::SystemPromptContext {
|
||||
let supports_native_tools = match provider {
|
||||
LlmProvider::Claude => true,
|
||||
LlmProvider::OpenAI => !model_name.contains("gpt-3.5-turbo-instruct"),
|
||||
LlmProvider::Gemini => !model_name.contains("gemini-pro-vision"),
|
||||
LlmProvider::Ollama => model_name.contains("llama") || model_name.contains("mistral"),
|
||||
LlmProvider::XAI => true,
|
||||
};
|
||||
|
||||
let model_context = crate::prompt::ModelContext {
|
||||
provider,
|
||||
model_name: model_name.to_string(),
|
||||
provider: model.built_in_provider(),
|
||||
provider_id: model.provider_id(),
|
||||
model_name: model.name.clone(),
|
||||
capabilities: crate::prompt::ModelCapabilities {
|
||||
supports_tools: supports_native_tools,
|
||||
supports_function_calling: supports_native_tools,
|
||||
supports_vision: false,
|
||||
supports_multimodal: Some(false),
|
||||
context_length: None,
|
||||
supports_tools: model.features.supports_tools,
|
||||
supports_function_calling: model.features.supports_function_calling,
|
||||
supports_vision: model.features.supports_vision,
|
||||
supports_multimodal: Some(model.features.supports_multimodal),
|
||||
context_length: model.features.context_length,
|
||||
capabilities: vec![],
|
||||
needs_verification: Some(false),
|
||||
},
|
||||
|
|
@ -850,23 +780,38 @@ impl Worker {
|
|||
}
|
||||
}
|
||||
|
||||
/// モデルを変更する
|
||||
pub fn change_model(
|
||||
&mut self,
|
||||
provider: LlmProvider,
|
||||
model_name: &str,
|
||||
api_key: &str,
|
||||
) -> Result<(), WorkerError> {
|
||||
// 新しいLLMクライアントを作成
|
||||
let new_client = provider.create_client(model_name, api_key)?;
|
||||
fn infer_model_features(provider: Option<LlmProvider>, model_name: &str) -> ModelFeatures {
|
||||
let mut features = ModelFeatures::default();
|
||||
|
||||
// 古いクライアントを新しいものに置き換え
|
||||
self.llm_client = Box::new(new_client);
|
||||
self.provider_str = provider.as_str().to_string();
|
||||
self.model_name = model_name.to_string();
|
||||
if let Some(provider) = provider {
|
||||
let supports_tools = match provider {
|
||||
LlmProvider::Claude => true,
|
||||
LlmProvider::OpenAI => !model_name.contains("gpt-3.5-turbo-instruct"),
|
||||
LlmProvider::Gemini => !model_name.contains("gemini-pro-vision"),
|
||||
LlmProvider::Ollama => {
|
||||
model_name.contains("llama") || model_name.contains("mistral")
|
||||
}
|
||||
LlmProvider::XAI => true,
|
||||
};
|
||||
|
||||
features.supports_tools = supports_tools;
|
||||
features.supports_function_calling = supports_tools;
|
||||
}
|
||||
|
||||
features
|
||||
}
|
||||
|
||||
/// モデルを変更する
|
||||
pub fn change_model(&mut self, model: Model, api_key: &str) -> Result<(), WorkerError> {
|
||||
if let Some(provider) = model.built_in_provider() {
|
||||
let new_client = provider.create_client(&model.name, api_key)?;
|
||||
self.llm_client = Box::new(new_client);
|
||||
}
|
||||
|
||||
self.model = model;
|
||||
self.api_key = api_key.to_string();
|
||||
|
||||
tracing::info!("Model changed to {}/{}", provider.as_str(), model_name);
|
||||
tracing::info!("Model changed to {}", self.model.provider_id());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -931,16 +876,18 @@ impl Worker {
|
|||
|
||||
/// Get the model name for tool support detection
|
||||
pub fn get_model_name(&self) -> String {
|
||||
self.llm_client.get_model_name()
|
||||
self.model.name.clone()
|
||||
}
|
||||
|
||||
pub fn get_provider_name(&self) -> String {
|
||||
self.llm_client.provider().to_string()
|
||||
self.model.provider_id()
|
||||
}
|
||||
|
||||
/// Get configuration information for task delegation
|
||||
pub fn get_config(&self) -> (LlmProvider, &str, &str) {
|
||||
(self.llm_client.provider(), &self.model_name, &self.api_key)
|
||||
pub fn get_config(&self) -> Option<(LlmProvider, &str, &str)> {
|
||||
self.model
|
||||
.built_in_provider()
|
||||
.map(|provider| (provider, self.model.name.as_str(), self.api_key.as_str()))
|
||||
}
|
||||
|
||||
/// Get tool names (used to filter out specific tools)
|
||||
|
|
@ -1036,25 +983,37 @@ impl Worker {
|
|||
};
|
||||
|
||||
// Create a temporary worker for processing without holding the lock
|
||||
let (llm_client, system_prompt, tool_definitions, api_key, model_name) = {
|
||||
let (llm_client, system_prompt, tool_definitions, model) = {
|
||||
let w_locked = worker.lock().await;
|
||||
let llm_client = w_locked.llm_client.provider().create_client(&w_locked.model_name, &w_locked.api_key);
|
||||
match llm_client {
|
||||
let provider = match w_locked.model.built_in_provider() {
|
||||
Some(provider) => provider,
|
||||
None => {
|
||||
yield Err(WorkerError::config(
|
||||
"Delegated processing is not supported for plugin providers",
|
||||
));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match provider.create_client(&w_locked.model.name, &w_locked.api_key) {
|
||||
Ok(client) => {
|
||||
let tool_defs = w_locked.tools.iter().map(|tool| crate::types::DynamicToolDefinition {
|
||||
name: tool.name().to_string(),
|
||||
description: tool.description().to_string(),
|
||||
parameters_schema: tool.parameters_schema(),
|
||||
}).collect::<Vec<_>>();
|
||||
let tool_defs = w_locked
|
||||
.tools
|
||||
.iter()
|
||||
.map(|tool| crate::types::DynamicToolDefinition {
|
||||
name: tool.name().to_string(),
|
||||
description: tool.description().to_string(),
|
||||
parameters_schema: tool.parameters_schema(),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
(
|
||||
client,
|
||||
w_locked.system_prompt.clone(),
|
||||
tool_defs,
|
||||
w_locked.api_key.clone(),
|
||||
w_locked.model_name.clone()
|
||||
w_locked.model.clone(),
|
||||
)
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
yield Err(e);
|
||||
return;
|
||||
|
|
@ -1067,13 +1026,7 @@ impl Worker {
|
|||
|
||||
loop {
|
||||
let provider = llm_client.provider();
|
||||
let supports_native = match supports_native_tools(&provider, &model_name, &api_key).await {
|
||||
Ok(supports) => supports,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to check native tool support: {}", e);
|
||||
false
|
||||
}
|
||||
};
|
||||
let supports_native = model.features.supports_tools;
|
||||
|
||||
let (composed_messages, tools_for_llm) = if supports_native {
|
||||
let messages =
|
||||
|
|
@ -1254,16 +1207,21 @@ impl Worker {
|
|||
let tools = self.get_tools();
|
||||
let provider = self.llm_client.provider();
|
||||
let model_name = self.get_model_name();
|
||||
tracing::debug!("Checking native tool support: provider={:?}, model_name={}, api_key_len={}, provider_str={}", provider, model_name, self.api_key.len(), self.provider_str);
|
||||
let supports_native = match supports_native_tools(&provider, &model_name, &self.api_key).await {
|
||||
Ok(supports) => supports,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to check native tool support: {}", e);
|
||||
false
|
||||
}
|
||||
};
|
||||
tracing::debug!(
|
||||
"Checking native tool support: provider={:?}, model_name={}, api_key_len={}, provider_id={}",
|
||||
provider,
|
||||
model_name,
|
||||
self.api_key.len(),
|
||||
self.model.provider_id()
|
||||
);
|
||||
|
||||
tracing::info!("Model {} supports native tools: {}", model_name, supports_native);
|
||||
let supports_native = self.model.features.supports_tools;
|
||||
|
||||
tracing::info!(
|
||||
"Model {} supports native tools: {}",
|
||||
model_name,
|
||||
supports_native
|
||||
);
|
||||
|
||||
let (composed_messages, tools_for_llm) = if supports_native {
|
||||
// Native tools - basic composition
|
||||
|
|
@ -1471,15 +1429,16 @@ impl Worker {
|
|||
loop {
|
||||
let tools = self.get_tools();
|
||||
let provider = self.llm_client.provider();
|
||||
let model_name = self.get_model_name();
|
||||
tracing::debug!("Checking native tool support: provider={:?}, model_name={}, api_key_len={}, provider_str={}", provider, model_name, self.api_key.len(), self.provider_str);
|
||||
let supports_native = match supports_native_tools(&provider, &model_name, &self.api_key).await {
|
||||
Ok(supports) => supports,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to check native tool support: {}", e);
|
||||
false
|
||||
}
|
||||
};
|
||||
let model_name = self.model.name.clone();
|
||||
tracing::debug!(
|
||||
"Checking native tool support: provider={:?}, model_name={}, api_key_len={}, provider_id={}",
|
||||
provider,
|
||||
model_name,
|
||||
self.api_key.len(),
|
||||
self.model.provider_id()
|
||||
);
|
||||
|
||||
let supports_native = self.model.features.supports_tools;
|
||||
|
||||
tracing::info!("Model {} supports native tools: {}", model_name, supports_native);
|
||||
|
||||
|
|
@ -1673,8 +1632,8 @@ impl Worker {
|
|||
let session_id = uuid::Uuid::new_v4().to_string();
|
||||
let mut session_data = SessionData::new(
|
||||
session_id,
|
||||
self.provider_str.clone(),
|
||||
self.model_name.clone(),
|
||||
self.model.provider_id(),
|
||||
self.model.name.clone(),
|
||||
workspace_path,
|
||||
);
|
||||
session_data.git_branch = git_branch;
|
||||
|
|
@ -1688,15 +1647,15 @@ impl Worker {
|
|||
/// セッションデータから履歴を復元する
|
||||
pub fn load_session(&mut self, session_data: &SessionData) -> Result<(), WorkerError> {
|
||||
// モデルが異なる場合は警告をログに出す
|
||||
if session_data.model_provider != self.provider_str
|
||||
|| session_data.model_name != self.model_name
|
||||
if session_data.model_provider != self.model.provider_id()
|
||||
|| session_data.model_name != self.model.name
|
||||
{
|
||||
tracing::warn!(
|
||||
"Loading session with different model: session={}:{}, current={}:{}",
|
||||
session_data.model_provider,
|
||||
session_data.model_name,
|
||||
self.provider_str,
|
||||
self.model_name
|
||||
self.model.provider_id(),
|
||||
self.model.name
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,7 @@ use super::tool::McpServerConfig;
|
|||
use crate::types::WorkerError;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use tracing::{debug, info, warn};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
/// MCP設定ファイルの構造
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -57,121 +56,10 @@ fn default_integration_mode() -> IntegrationMode {
|
|||
}
|
||||
|
||||
impl McpConfig {
|
||||
/// 設定ファイルを読み込む
|
||||
pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self, WorkerError> {
|
||||
let path = path.as_ref();
|
||||
|
||||
if !path.exists() {
|
||||
debug!(
|
||||
"MCP config file not found at {:?}, returning empty config",
|
||||
path
|
||||
);
|
||||
return Ok(Self::default());
|
||||
}
|
||||
|
||||
info!("Loading MCP config from: {:?}", path);
|
||||
let content = std::fs::read_to_string(path).map_err(|e| {
|
||||
WorkerError::config(format!("Failed to read MCP config file {:?}: {}", path, e))
|
||||
})?;
|
||||
|
||||
let config: McpConfig = serde_yaml::from_str(&content).map_err(|e| {
|
||||
WorkerError::config(format!("Failed to parse MCP config file {:?}: {}", path, e))
|
||||
})?;
|
||||
|
||||
info!("Loaded {} MCP server configurations", config.servers.len());
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// 設定ファイルに保存する
|
||||
pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<(), WorkerError> {
|
||||
let path = path.as_ref();
|
||||
|
||||
// ディレクトリが存在しない場合は作成
|
||||
if let Some(parent) = path.parent() {
|
||||
std::fs::create_dir_all(parent).map_err(|e| {
|
||||
WorkerError::config(format!(
|
||||
"Failed to create config directory {:?}: {}",
|
||||
parent, e
|
||||
))
|
||||
})?;
|
||||
}
|
||||
|
||||
let content = serde_yaml::to_string(self)
|
||||
.map_err(|e| WorkerError::config(format!("Failed to serialize MCP config: {}", e)))?;
|
||||
|
||||
std::fs::write(path, content).map_err(|e| {
|
||||
WorkerError::config(format!("Failed to write MCP config file {:?}: {}", path, e))
|
||||
})?;
|
||||
|
||||
info!("Saved MCP config to: {:?}", path);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 有効なサーバー設定を取得
|
||||
pub fn get_enabled_servers(&self) -> Vec<(&String, &McpServerDefinition)> {
|
||||
self.servers.iter().filter(|(_, def)| def.enabled).collect()
|
||||
}
|
||||
|
||||
/// デフォルト設定ファイルを生成
|
||||
pub fn create_default_config() -> Self {
|
||||
let mut servers = HashMap::new();
|
||||
|
||||
// Brave Search MCP Server の設定例
|
||||
servers.insert(
|
||||
"brave_search".to_string(),
|
||||
McpServerDefinition {
|
||||
command: "npx".to_string(),
|
||||
args: vec![
|
||||
"-y".to_string(),
|
||||
"@brave/brave-search-mcp-server".to_string(),
|
||||
],
|
||||
env: {
|
||||
let mut env = HashMap::new();
|
||||
env.insert("BRAVE_API_KEY".to_string(), "${BRAVE_API_KEY}".to_string());
|
||||
env
|
||||
},
|
||||
description: Some("Brave Search API for web searching".to_string()),
|
||||
enabled: false, // デフォルトでは無効(APIキーが必要なため)
|
||||
integration_mode: IntegrationMode::Individual,
|
||||
},
|
||||
);
|
||||
|
||||
// ファイルシステムMCPサーバーの設定例
|
||||
servers.insert(
|
||||
"filesystem".to_string(),
|
||||
McpServerDefinition {
|
||||
command: "npx".to_string(),
|
||||
args: vec![
|
||||
"-y".to_string(),
|
||||
"@modelcontextprotocol/server-filesystem".to_string(),
|
||||
"/tmp".to_string(),
|
||||
],
|
||||
env: HashMap::new(),
|
||||
description: Some("Filesystem operations in /tmp directory".to_string()),
|
||||
enabled: false, // デフォルトでは無効
|
||||
integration_mode: IntegrationMode::Individual,
|
||||
},
|
||||
);
|
||||
|
||||
// Git MCP サーバーの設定例
|
||||
servers.insert(
|
||||
"git".to_string(),
|
||||
McpServerDefinition {
|
||||
command: "npx".to_string(),
|
||||
args: vec![
|
||||
"-y".to_string(),
|
||||
"@modelcontextprotocol/server-git".to_string(),
|
||||
".".to_string(),
|
||||
],
|
||||
env: HashMap::new(),
|
||||
description: Some("Git operations in current directory".to_string()),
|
||||
enabled: false, // デフォルトでは無効
|
||||
integration_mode: IntegrationMode::Individual,
|
||||
},
|
||||
);
|
||||
|
||||
Self { servers }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for McpConfig {
|
||||
|
|
@ -242,25 +130,53 @@ fn expand_environment_variables(input: &str) -> Result<String, WorkerError> {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::fs;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn test_default_config_creation() {
|
||||
let config = McpConfig::create_default_config();
|
||||
let mut config = McpConfig::default();
|
||||
config.servers.insert(
|
||||
"brave_search".to_string(),
|
||||
McpServerDefinition {
|
||||
command: "npx".to_string(),
|
||||
args: vec![
|
||||
"-y".to_string(),
|
||||
"@brave/brave-search-mcp-server".to_string(),
|
||||
],
|
||||
env: HashMap::new(),
|
||||
description: None,
|
||||
enabled: true,
|
||||
integration_mode: IntegrationMode::Individual,
|
||||
},
|
||||
);
|
||||
|
||||
assert!(!config.servers.is_empty());
|
||||
assert!(config.servers.contains_key("brave_search"));
|
||||
assert!(config.servers.contains_key("filesystem"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_serialization() {
|
||||
let config = McpConfig::create_default_config();
|
||||
let mut servers = HashMap::new();
|
||||
servers.insert(
|
||||
"filesystem".to_string(),
|
||||
McpServerDefinition {
|
||||
command: "npx".to_string(),
|
||||
args: vec![
|
||||
"-y".to_string(),
|
||||
"@modelcontextprotocol/server-filesystem".to_string(),
|
||||
],
|
||||
env: HashMap::new(),
|
||||
description: Some("Filesystem operations".to_string()),
|
||||
enabled: true,
|
||||
integration_mode: IntegrationMode::Proxy,
|
||||
},
|
||||
);
|
||||
|
||||
let config = McpConfig { servers };
|
||||
let yaml = serde_yaml::to_string(&config).unwrap();
|
||||
|
||||
// YAML形式で正しくシリアライズされることを確認
|
||||
assert!(yaml.contains("servers:"));
|
||||
assert!(yaml.contains("brave_search:"));
|
||||
assert!(yaml.contains("filesystem:"));
|
||||
assert!(yaml.contains("command:"));
|
||||
}
|
||||
|
||||
|
|
@ -303,23 +219,6 @@ servers:
|
|||
assert_eq!(result, "${NON_EXISTENT_VAR}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_file_operations() {
|
||||
let dir = tempdir().unwrap();
|
||||
let config_path = dir.path().join("mcp.yaml");
|
||||
|
||||
// 設定を作成して保存
|
||||
let config = McpConfig::create_default_config();
|
||||
config.save_to_file(&config_path).unwrap();
|
||||
|
||||
// ファイルが作成されたことを確認
|
||||
assert!(config_path.exists());
|
||||
|
||||
// 設定を読み込み
|
||||
let loaded_config = McpConfig::load_from_file(&config_path).unwrap();
|
||||
assert_eq!(config.servers.len(), loaded_config.servers.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_enabled_servers_filter() {
|
||||
let mut config = McpConfig::default();
|
||||
|
|
|
|||
|
|
@ -184,7 +184,7 @@ impl McpClient {
|
|||
}),
|
||||
},
|
||||
client_info: ClientInfo {
|
||||
name: "nia-worker".to_string(),
|
||||
name: "llm-worker-rs".to_string(),
|
||||
version: "0.1.0".to_string(),
|
||||
},
|
||||
};
|
||||
|
|
|
|||
|
|
@ -4,7 +4,8 @@ use std::collections::HashMap;
|
|||
/// モデルに関する静的な情報
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelContext {
|
||||
pub provider: crate::types::LlmProvider,
|
||||
pub provider: Option<crate::types::LlmProvider>,
|
||||
pub provider_id: String,
|
||||
pub model_name: String,
|
||||
pub capabilities: ModelCapabilities,
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user