diff --git a/README.md b/README.md index 7f7ee63..e606b41 100644 --- a/README.md +++ b/README.md @@ -23,9 +23,9 @@ use worker::{LlmProvider, SystemPromptContext, PromptError, Worker}; async fn main() -> Result<(), Box> { let system_prompt = |ctx: &SystemPromptContext, _messages: &[worker_types::Message]| { Ok(format!( - "You are assisting with model {} from provider {:?}.", + "You are assisting with model {} from provider {}.", ctx.model.model_name, - ctx.model.provider + ctx.model.provider_id )) }; diff --git a/docs/hooks.md b/docs/hooks.md index 5b68a67..0f3756f 100644 --- a/docs/hooks.md +++ b/docs/hooks.md @@ -31,7 +31,7 @@ pub struct HookManager { pub enum HookEvent { OnMessageSend, PreToolUse, - PostToolUse, + PostToolUse, OnTurnCompleted, } ``` @@ -181,10 +181,10 @@ impl HookContext { impl HookContext { // ストリーミング中にメッセージを送信 pub fn stream_message(&self, content: String, role: Role); - + // ストリーミング中にシステム通知を送信 pub fn stream_system_message(&self, content: String); - + // ストリーミング中にデバッグ情報を送信 pub fn stream_debug(&self, title: String, data: serde_json::Value); } @@ -198,22 +198,22 @@ Hook関数は以下のいずれかの結果を返す必要があります: pub enum HookResult { // 処理を続行 Continue, - + // コンテンツを変更して続行 ModifyContent(String), - + // システムメッセージを追加して続行 AddMessage(String, Role), - + // 複数のメッセージを追加して続行 AddMessages(Vec), - + // ターンを強制完了 Complete, - + // エラーでターンを終了 Error(String), - + // Hook処理をスキップ(デバッグ用) Skip, } @@ -282,7 +282,7 @@ pub async fn dangerous_command_hook(context: HookContext) -> HookResult { if let Some(args) = &context.tool_args { if let Some(command) = args.get("command").and_then(|v| v.as_str()) { let dangerous_commands = ["rm -rf", "format", "dd if="]; - + for dangerous in &dangerous_commands { if command.contains(dangerous) { return HookResult::Error(format!( @@ -293,7 +293,7 @@ pub async fn dangerous_command_hook(context: HookContext) -> HookResult { } } } - + HookResult::Continue } ``` @@ -326,7 +326,7 @@ pub async fn auto_read_hook(mut context: HookContext) -> HookResult { } } } - + HookResult::Continue } ``` @@ -361,7 +361,7 @@ worker.register_hooks(tui_hooks); ```rust // 実行順序の例 worker.register_hook(Box::new(TimestampHook)); // 1番目 -worker.register_hook(Box::new(ValidationHook)); // 2番目 +worker.register_hook(Box::new(ValidationHook)); // 2番目 worker.register_hook(Box::new(LoggingHook)); // 3番目 ``` @@ -375,13 +375,13 @@ worker.register_hook(Box::new(LoggingHook)); // 3番目 impl Worker { // Hook一覧を取得 pub fn list_hooks(&self) -> Vec<(&str, &str)>; // (name, hook_type) - + // 特定のHookを削除 pub fn remove_hook(&mut self, hook_name: &str) -> bool; - + // フェーズ別Hookを削除 pub fn remove_hooks_by_phase(&mut self, hook_type: &str); - + // すべてのHookをクリア pub fn clear_hooks(&mut self); } @@ -395,16 +395,16 @@ impl Worker { // worker/src/lib.rs の process_with_shared_state より stream! { // ... LLM応答処理中 ... - + // ツール呼び出し検出時 if let Some(tool_calls) = &response.tool_calls { for tool_call in tool_calls { // PreToolUse hooks 実行 let (context, hook_result) = execute_hooks( - HookEvent::PreToolUse, + HookEvent::PreToolUse, tool_call.name.clone() ).await; - + match hook_result { HookResult::Error(msg) => { yield Ok(StreamEvent::Error(msg)); @@ -413,16 +413,16 @@ stream! { HookResult::Complete => break, _ => {} } - + // ツール実行 let result = execute_tool(tool_call).await; - + // PostToolUse hooks 実行(ストリーミング中) let (context, hook_result) = execute_hooks( HookEvent::PostToolUse, tool_call.name.clone() ).await; - + // Hook結果を即座にストリーミング if let HookResult::AddMessage(msg, role) = hook_result { yield Ok(StreamEvent::HookMessage { @@ -468,13 +468,13 @@ pub async fn performance_aware_hook(context: HookContext) -> HookResult { // 大きなコンテンツの場合はスキップ return HookResult::Skip; } - + // 非同期処理は適切にawaitする let result = tokio::time::timeout( Duration::from_secs(5), expensive_operation(&context) ).await; - + match result { Ok(output) => HookResult::AddMessage(output, Role::System), Err(_) => { @@ -495,20 +495,20 @@ pub async fn configurable_hook(mut context: HookContext) -> HookResult { .unwrap_or_default() .parse::() .unwrap_or(false); - + if !enabled { return HookResult::Skip; } - + // 設定ファイルからオプション読み込み - let config_path = format!("{}/.nia/hook_config.json", context.workspace_path); + let config_path = format!("{}/hook_config.json", context.workspace_path); if let Ok(config_content) = tokio::fs::read_to_string(&config_path).await { if let Ok(config) = serde_json::from_str::(&config_content) { // 設定に基づく処理 return process_with_config(&mut context, &config).await; } } - + HookResult::Continue } ``` @@ -523,7 +523,7 @@ pub async fn conditional_hook(context: HookContext) -> HookResult { let is_rust_project = tokio::fs::metadata( format!("{}/Cargo.toml", context.workspace_path) ).await.is_ok(); - + match (is_git_repo, is_rust_project) { (true, true) => { // Rustプロジェクト + Git @@ -548,7 +548,7 @@ pub async fn conditional_hook(context: HookContext) -> HookResult { mod tests { use super::*; use worker::types::*; - + #[tokio::test] async fn test_timestamp_hook() { let mut context = HookContext { @@ -561,9 +561,9 @@ mod tests { tool_args: None, tool_result: None, }; - + let result = add_timestamp_hook(context).await; - + match result { HookResult::ModifyContent(content) => { assert!(content.contains("Hello, world!")); @@ -587,7 +587,7 @@ pub async fn debug_hook(context: HookContext) -> HookResult { context.tools.len(), context.message_history.len() ); - + // デバッグ情報をストリーミング context.stream_debug( "Hook Debug Info".to_string(), @@ -598,7 +598,7 @@ pub async fn debug_hook(context: HookContext) -> HookResult { "workspace": context.workspace_path }) ); - + HookResult::Continue } ``` @@ -625,13 +625,13 @@ impl WorkerHook for StatefulHook { fn name(&self) -> &str { "stateful_hook" } fn hook_type(&self) -> &str { "OnTurnCompleted" } fn matcher(&self) -> &str { "" } - + async fn execute(&self, mut context: HookContext) -> (HookContext, HookResult) { let mut count = self.counter.lock().unwrap(); *count += 1; - + context.set_variable("turn_count".to_string(), count.to_string()); - + if *count % 10 == 0 { ( context, @@ -658,7 +658,7 @@ impl HookChain { pub fn new() -> Self { Self { hooks: Vec::new() } } - + pub fn add_hook(mut self, hook: Box) -> Self { self.hooks.push(hook); self @@ -670,18 +670,18 @@ impl WorkerHook for HookChain { fn name(&self) -> &str { "hook_chain" } fn hook_type(&self) -> &str { "OnMessageSend" } fn matcher(&self) -> &str { "" } - + async fn execute(&self, mut context: HookContext) -> (HookContext, HookResult) { for hook in &self.hooks { let (new_context, result) = hook.execute(context).await; context = new_context; - + match result { HookResult::Continue | HookResult::Skip => continue, other => return (context, other), } } - + (context, HookResult::Continue) } } @@ -715,4 +715,3 @@ A: `HookResult::Error`を返すと、そのターンは中断されます。継 - [worker-macro.md](worker-macro.md) - マクロシステム - `worker/src/lib.rs` - Hook実装コード - `worker-types/src/lib.rs` - Hook型定義 -- `nia-cli/src/tui/hooks/` - TUI用Hook実装例 \ No newline at end of file diff --git a/docs/patch_note/v0.3.0.md b/docs/patch_note/v0.3.0.md index 37cce7d..73bd240 100644 --- a/docs/patch_note/v0.3.0.md +++ b/docs/patch_note/v0.3.0.md @@ -6,7 +6,7 @@ v0.3.0 はプロンプトリソースの解決責務を利用側へ完全に移 ## Breaking Changes -- `ConfigParser::resolve_path` を削除し、`#nia/` `#workspace/` 等のプレフィックス解決をライブラリ利用者実装の `ResourceLoader` に委譲しました。 +- `ConfigParser::resolve_path` を削除し、`#user/` `#workspace/` 等のプレフィックス解決をライブラリ利用者実装の `ResourceLoader` に委譲しました。 - `WorkerBuilder::build()` は `resource_loader(...)` が未指定の場合エラーを返すようになりました。ワーカー構築前に必ずローダーを提供してください。 ## 新機能 / 仕様変更 @@ -19,7 +19,6 @@ v0.3.0 はプロンプトリソースの解決責務を利用側へ完全に移 ## 不具合修正 - `include_file` ヘルパーがカスタムローダーを利用せずにファイルアクセスしていた問題を修正。 -- `ConfigParser` が存在しない `#nia/` プレフィックスを静的に解決しようとしていた挙動を除去し、誤ったパスが静かに通ることを防止。 ## 移行ガイド diff --git a/docs/patch_note/v0.4.0.md b/docs/patch_note/v0.4.0.md index 01b2610..8d2fb67 100644 --- a/docs/patch_note/v0.4.0.md +++ b/docs/patch_note/v0.4.0.md @@ -16,7 +16,7 @@ v0.4.0 は Worker が `Role` や YAML 設定を扱わず、システムプロン ## 不具合修正 -- Worker から NIA 固有の設定コードを除去し、環境依存の副作用を縮小。 +- Worker から旧プロジェクト固有の設定コードを除去し、環境依存の副作用を縮小。 ## 移行ガイド diff --git a/docs/patch_note/v0.5.0.md b/docs/patch_note/v0.5.0.md index b89e729..a1b4498 100644 --- a/docs/patch_note/v0.5.0.md +++ b/docs/patch_note/v0.5.0.md @@ -2,7 +2,7 @@ **Release Date**: 2025-10-25 -v0.5.0 introduces the Worker Blueprint API and removes the old type-state builder. Configuration now lives on the blueprint, while instantiated workers keep only the materialised system prompt and runtime state. +v0.5.0 introduces the Worker Blueprint API and removes the old type-state builder. Configuration now lives on the blueprint, while instantiated workers keep only the materialised system prompt, model metadata, and runtime state. ## Breaking Changes @@ -12,9 +12,10 @@ v0.5.0 introduces the Worker Blueprint API and removes the old type-state builde ## New Features / Behaviour -- `WorkerBlueprint` stores provider/model/api keys, tools, hooks, and optional precomputed system prompt strings. `instantiate()` evaluates the prompt (if not already cached) and hands the final string to the `Worker`. -- Instantiated workers retain only the composed system prompt string; the generator function lives solely on the blueprint and is dropped after instantiation. +- `WorkerBlueprint` stores provider/model/api keys, tools, hooks, optional precomputed system prompt messages, and optional model feature flags. `instantiate()` evaluates the prompt (if not already cached) and hands the final string to the `Worker`. +- Instantiated workers retain the composed system prompt, the original generator closure, and a `Model` struct describing provider/model/features; the generator only runs again if a new session requires it. - System prompts are no longer recomputed per turn. Tool metadata is appended dynamically as plain text when native tool support is unavailable. +- Worker now exposes a `Model` struct (`provider`, `name`, `features`) in place of the previous loose strings and `supports_native_tools` helper. Capability heuristics remain for built-in providers but applications can override them via `WorkerBlueprint::model_features`. ## Migration Guide diff --git a/worker/examples/builder_basic.rs b/worker/examples/builder_basic.rs index b49fb1a..25d2470 100644 --- a/worker/examples/builder_basic.rs +++ b/worker/examples/builder_basic.rs @@ -10,8 +10,8 @@ async fn main() -> Result<(), Box> { _messages: &[Message], ) -> Result { Ok(format!( - "You are helping with requests for model {} (provider {:?}).", - ctx.model.model_name, ctx.model.provider + "You are helping with requests for model {} (provider {}).", + ctx.model.model_name, ctx.model.provider_id )) } diff --git a/worker/src/blueprint.rs b/worker/src/blueprint.rs index 4b6dc7b..6b98ad5 100644 --- a/worker/src/blueprint.rs +++ b/worker/src/blueprint.rs @@ -1,8 +1,7 @@ -use crate::LlmProviderExt; -use crate::Worker; use crate::plugin; use crate::prompt::{PromptError, SystemPromptContext, SystemPromptFn}; use crate::types::{HookManager, Tool, WorkerError, WorkerHook}; +use crate::{LlmProviderExt, Model, ModelFeatures, ModelProvider, Worker}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; use worker_types::{LlmProvider, Message, Role}; @@ -24,6 +23,7 @@ pub struct WorkerBlueprint { pub(crate) tools: Vec>, pub(crate) hooks: Vec>, pub(crate) prompt_cache: Option>, + pub(crate) model_features: Option, } impl WorkerBlueprint { @@ -36,6 +36,7 @@ impl WorkerBlueprint { tools: Vec::new(), hooks: Vec::new(), prompt_cache: None, + model_features: None, } } @@ -61,6 +62,11 @@ impl WorkerBlueprint { self } + pub fn model_features(&mut self, features: ModelFeatures) -> &mut Self { + self.model_features = Some(features); + self + } + pub fn api_key(&mut self, provider: impl Into, key: impl Into) -> &mut Self { self.api_keys.insert(provider.into(), key.into()); self @@ -118,9 +124,27 @@ impl WorkerBlueprint { .map(|tool| tool.name().to_string()) .collect(); - let context = self.build_system_prompt_context(provider, &model_name, &tool_names); + let features = self + .model_features + .clone() + .unwrap_or_else(|| match provider { + ProviderConfig::BuiltIn(p) => Worker::infer_model_features(Some(*p), &model_name), + ProviderConfig::Plugin { .. } => ModelFeatures::default(), + }); + + let preview_model = Model { + provider: match provider { + ProviderConfig::BuiltIn(p) => ModelProvider::BuiltIn(*p), + ProviderConfig::Plugin { id, .. } => ModelProvider::Plugin(id.clone()), + }, + name: model_name.clone(), + features: features.clone(), + }; + + let context = Worker::create_system_prompt_context(&preview_model, &tool_names); let prompt = generator(&context, &[]).map_err(|e| WorkerError::config(e.to_string()))?; self.prompt_cache = Some(vec![Message::new(Role::System, prompt)]); + self.model_features = Some(features); Ok(self) } @@ -138,30 +162,49 @@ impl WorkerBlueprint { .take() .ok_or_else(|| WorkerError::config("System prompt generator is not configured"))?; + let mut prompt_cache = self.prompt_cache.take(); + let mut provided_features = self.model_features.take(); let tools = std::mem::take(&mut self.tools); let hooks = std::mem::take(&mut self.hooks); let mut api_keys = self.api_keys; - let tool_names: Vec = tools.iter().map(|tool| tool.name().to_string()).collect(); - let provider_hint = provider_config.provider_hint(); - let prompt_context = - Worker::create_system_prompt_context(provider_hint, &model_name, &tool_names); - - let base_messages = match self.prompt_cache.take() { - Some(messages) if !messages.is_empty() => messages, - _ => { - let prompt = system_prompt_fn(&prompt_context, &[]) - .map_err(|e| WorkerError::config(e.to_string()))?; - vec![Message::new(Role::System, prompt)] - } - }; - let base_system_prompt = base_messages - .first() - .map(|msg| msg.content.clone()) - .unwrap_or_else(|| String::new()); match provider_config { ProviderConfig::BuiltIn(provider) => { + let features = provided_features + .take() + .unwrap_or_else(|| Worker::infer_model_features(Some(provider), &model_name)); + + let model = Model { + provider: ModelProvider::BuiltIn(provider), + name: model_name.clone(), + features, + }; + + let prompt_context = Worker::create_system_prompt_context(&model, &tool_names); + let base_messages = if let Some(messages) = prompt_cache.take() { + if messages.is_empty() { + vec![Message::new( + Role::System, + system_prompt_fn(&prompt_context, &[]) + .map_err(|e| WorkerError::config(e.to_string()))?, + )] + } else { + messages + } + } else { + vec![Message::new( + Role::System, + system_prompt_fn(&prompt_context, &[]) + .map_err(|e| WorkerError::config(e.to_string()))?, + )] + }; + + let base_system_prompt = base_messages + .first() + .map(|msg| msg.content.clone()) + .unwrap_or_default(); + let api_key = api_keys .entry(provider.as_str().to_string()) .or_insert_with(String::new) @@ -170,13 +213,12 @@ impl WorkerBlueprint { let llm_client = provider.create_client(&model_name, &api_key)?; let mut worker = Worker { llm_client: Box::new(llm_client), - system_prompt: base_system_prompt.clone(), + system_prompt: base_system_prompt, system_prompt_fn: Arc::clone(&system_prompt_fn), tools, api_key, - provider_str: provider.as_str().to_string(), - model_name, - message_history: base_messages.clone(), + model, + message_history: base_messages, hook_manager: HookManager::new(), mcp_lazy_configs: Vec::new(), plugin_registry: Arc::new(Mutex::new(plugin::PluginRegistry::new())), @@ -185,6 +227,40 @@ impl WorkerBlueprint { Ok(worker) } ProviderConfig::Plugin { id, registry } => { + let features = provided_features + .take() + .unwrap_or_else(ModelFeatures::default); + + let model = Model { + provider: ModelProvider::Plugin(id.clone()), + name: model_name.clone(), + features, + }; + + let prompt_context = Worker::create_system_prompt_context(&model, &tool_names); + let base_messages = if let Some(messages) = prompt_cache.take() { + if messages.is_empty() { + vec![Message::new( + Role::System, + system_prompt_fn(&prompt_context, &[]) + .map_err(|e| WorkerError::config(e.to_string()))?, + )] + } else { + messages + } + } else { + vec![Message::new( + Role::System, + system_prompt_fn(&prompt_context, &[]) + .map_err(|e| WorkerError::config(e.to_string()))?, + )] + }; + + let base_system_prompt = base_messages + .first() + .map(|msg| msg.content.clone()) + .unwrap_or_default(); + let api_key = api_keys .remove("__plugin__") .or_else(|| api_keys.values().next().cloned()) @@ -208,8 +284,7 @@ impl WorkerBlueprint { system_prompt_fn, tools, api_key, - provider_str: id, - model_name, + model, message_history: base_messages, hook_manager: HookManager::new(), mcp_lazy_configs: Vec::new(), @@ -220,15 +295,6 @@ impl WorkerBlueprint { } } } - - fn build_system_prompt_context( - &self, - provider: &ProviderConfig, - model_name: &str, - tool_names: &[String], - ) -> SystemPromptContext { - Worker::create_system_prompt_context(provider.provider_hint(), model_name, tool_names) - } } impl ProviderConfig { @@ -245,11 +311,4 @@ impl ProviderConfig { ProviderConfig::Plugin { registry, .. } => Some(Arc::clone(registry)), } } - - fn provider_hint(&self) -> LlmProvider { - match self { - ProviderConfig::BuiltIn(provider) => *provider, - ProviderConfig::Plugin { .. } => LlmProvider::OpenAI, - } - } } diff --git a/worker/src/lib.rs b/worker/src/lib.rs index 066a19d..e153770 100644 --- a/worker/src/lib.rs +++ b/worker/src/lib.rs @@ -8,8 +8,6 @@ use llm::{ }; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use std::fs; -use std::path::PathBuf; use std::sync::Arc; use tracing; use uuid; @@ -324,127 +322,67 @@ pub async fn validate_api_key( } } -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct ModelsConfig { - pub models: Vec, -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct ModelDefinition { - pub model: String, - pub name: String, - pub meta: ModelMeta, -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct ModelMeta { - pub tool_support: bool, - pub function_calling: bool, - pub vision: bool, - pub multimodal: bool, - pub context_length: Option, - pub description: Option, -} - -fn get_models_config_path() -> Result { - let home_dir = dirs::home_dir() - .ok_or_else(|| WorkerError::config("Could not determine home directory"))?; - Ok(home_dir.join(".config").join("nia").join("models.yaml")) -} - -fn load_models_config() -> Result { - let config_path = get_models_config_path()?; - - if !config_path.exists() { - tracing::warn!( - "Models config file not found at {:?}, using defaults", - config_path - ); - return Ok(ModelsConfig { models: vec![] }); - } - - let content = fs::read_to_string(&config_path) - .map_err(|e| WorkerError::config(format!("Failed to read models config: {}", e)))?; - - let config: ModelsConfig = serde_yaml::from_str(&content) - .map_err(|e| WorkerError::config(format!("Failed to parse models config: {}", e)))?; - - Ok(config) -} - -pub async fn supports_native_tools( - provider: &LlmProvider, - model_name: &str, - _api_key: &str, -) -> Result { - let config = load_models_config()?; - - let model_id = format!( - "{}/{}", - match provider { - LlmProvider::Claude => "anthropic", - LlmProvider::OpenAI => "openai", - LlmProvider::Gemini => "gemini", - LlmProvider::Ollama => "ollama", - LlmProvider::XAI => "xai", - }, - model_name - ); - - for model_def in &config.models { - if model_def.model == model_id || model_def.model.contains(model_name) { - tracing::debug!( - "Found model config: model={}, function_calling={}", - model_def.model, - model_def.meta.function_calling - ); - return Ok(model_def.meta.function_calling); - } - } - - tracing::warn!( - "Model not found in config: {} ({}), using provider defaults", - model_id, - model_name - ); - - tracing::warn!( - "Using provider-based fallback - this should be configured in models.yaml: provider={:?}, model={}", - provider, - model_name - ); - - let supports_tools = match provider { - LlmProvider::Claude => true, - LlmProvider::OpenAI => !model_name.contains("gpt-3.5-turbo-instruct"), - LlmProvider::Gemini => !model_name.contains("gemini-pro-vision"), - LlmProvider::Ollama => false, - LlmProvider::XAI => true, - }; - - tracing::debug!( - "Fallback tool support check: provider={:?}, model={}, supports_tools={}", - provider, - model_name, - supports_tools - ); - Ok(supports_tools) -} - pub struct Worker { pub(crate) llm_client: Box, pub(crate) system_prompt: String, pub(crate) system_prompt_fn: Arc, pub(crate) tools: Vec>, pub(crate) api_key: String, - pub(crate) provider_str: String, - pub(crate) model_name: String, + pub(crate) model: Model, pub(crate) message_history: Vec, pub(crate) hook_manager: crate::types::HookManager, pub(crate) mcp_lazy_configs: Vec, pub(crate) plugin_registry: std::sync::Arc>, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ModelProvider { + BuiltIn(LlmProvider), + Plugin(String), +} + +impl ModelProvider { + pub fn identifier(&self) -> String { + match self { + ModelProvider::BuiltIn(provider) => provider.as_str().to_string(), + ModelProvider::Plugin(id) => id.clone(), + } + } + + pub fn as_llm_provider(&self) -> Option { + match self { + ModelProvider::BuiltIn(provider) => Some(*provider), + ModelProvider::Plugin(_) => None, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Model { + pub provider: ModelProvider, + pub name: String, + pub features: ModelFeatures, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ModelFeatures { + pub supports_tools: bool, + pub supports_function_calling: bool, + pub supports_vision: bool, + pub supports_multimodal: bool, + pub context_length: Option, +} + +impl Model { + pub fn provider_id(&self) -> String { + self.provider.identifier() + } + + pub fn built_in_provider(&self) -> Option { + self.provider.as_llm_provider() + } +} + impl Worker { /// Create a new Worker blueprint /// @@ -504,21 +442,20 @@ impl Worker { system_prompt_fn, tools, api_key, - provider_str, - model_name, + model, message_history, hook_manager, mcp_lazy_configs: _, plugin_registry, } = self; - let provider = match LlmProvider::from_str(&provider_str) { - Some(p) => { + let provider = match &model.provider { + ModelProvider::BuiltIn(p) => { drop(plugin_registry); - ProviderConfig::BuiltIn(p) + ProviderConfig::BuiltIn(*p) } - None => ProviderConfig::Plugin { - id: provider_str.clone(), + ModelProvider::Plugin(id) => ProviderConfig::Plugin { + id: id.clone(), registry: plugin_registry, }, }; @@ -535,12 +472,13 @@ impl Worker { WorkerBlueprint { provider: Some(provider), - model_name: Some(model_name), + model_name: Some(model.name), api_keys, system_prompt_fn: Some(system_prompt_fn), tools, hooks: hook_manager.into_hooks(), prompt_cache: Some(message_history), + model_features: Some(model.features), } } @@ -815,27 +753,19 @@ impl Worker { /// 静的プロンプトコンテキストを作成(構築時用) pub(crate) fn create_system_prompt_context( - provider: LlmProvider, - model_name: &str, + model: &Model, tools: &[String], ) -> crate::prompt::SystemPromptContext { - let supports_native_tools = match provider { - LlmProvider::Claude => true, - LlmProvider::OpenAI => !model_name.contains("gpt-3.5-turbo-instruct"), - LlmProvider::Gemini => !model_name.contains("gemini-pro-vision"), - LlmProvider::Ollama => model_name.contains("llama") || model_name.contains("mistral"), - LlmProvider::XAI => true, - }; - let model_context = crate::prompt::ModelContext { - provider, - model_name: model_name.to_string(), + provider: model.built_in_provider(), + provider_id: model.provider_id(), + model_name: model.name.clone(), capabilities: crate::prompt::ModelCapabilities { - supports_tools: supports_native_tools, - supports_function_calling: supports_native_tools, - supports_vision: false, - supports_multimodal: Some(false), - context_length: None, + supports_tools: model.features.supports_tools, + supports_function_calling: model.features.supports_function_calling, + supports_vision: model.features.supports_vision, + supports_multimodal: Some(model.features.supports_multimodal), + context_length: model.features.context_length, capabilities: vec![], needs_verification: Some(false), }, @@ -850,23 +780,38 @@ impl Worker { } } - /// モデルを変更する - pub fn change_model( - &mut self, - provider: LlmProvider, - model_name: &str, - api_key: &str, - ) -> Result<(), WorkerError> { - // 新しいLLMクライアントを作成 - let new_client = provider.create_client(model_name, api_key)?; + fn infer_model_features(provider: Option, model_name: &str) -> ModelFeatures { + let mut features = ModelFeatures::default(); - // 古いクライアントを新しいものに置き換え - self.llm_client = Box::new(new_client); - self.provider_str = provider.as_str().to_string(); - self.model_name = model_name.to_string(); + if let Some(provider) = provider { + let supports_tools = match provider { + LlmProvider::Claude => true, + LlmProvider::OpenAI => !model_name.contains("gpt-3.5-turbo-instruct"), + LlmProvider::Gemini => !model_name.contains("gemini-pro-vision"), + LlmProvider::Ollama => { + model_name.contains("llama") || model_name.contains("mistral") + } + LlmProvider::XAI => true, + }; + + features.supports_tools = supports_tools; + features.supports_function_calling = supports_tools; + } + + features + } + + /// モデルを変更する + pub fn change_model(&mut self, model: Model, api_key: &str) -> Result<(), WorkerError> { + if let Some(provider) = model.built_in_provider() { + let new_client = provider.create_client(&model.name, api_key)?; + self.llm_client = Box::new(new_client); + } + + self.model = model; self.api_key = api_key.to_string(); - tracing::info!("Model changed to {}/{}", provider.as_str(), model_name); + tracing::info!("Model changed to {}", self.model.provider_id()); Ok(()) } @@ -931,16 +876,18 @@ impl Worker { /// Get the model name for tool support detection pub fn get_model_name(&self) -> String { - self.llm_client.get_model_name() + self.model.name.clone() } pub fn get_provider_name(&self) -> String { - self.llm_client.provider().to_string() + self.model.provider_id() } /// Get configuration information for task delegation - pub fn get_config(&self) -> (LlmProvider, &str, &str) { - (self.llm_client.provider(), &self.model_name, &self.api_key) + pub fn get_config(&self) -> Option<(LlmProvider, &str, &str)> { + self.model + .built_in_provider() + .map(|provider| (provider, self.model.name.as_str(), self.api_key.as_str())) } /// Get tool names (used to filter out specific tools) @@ -1036,25 +983,37 @@ impl Worker { }; // Create a temporary worker for processing without holding the lock - let (llm_client, system_prompt, tool_definitions, api_key, model_name) = { + let (llm_client, system_prompt, tool_definitions, model) = { let w_locked = worker.lock().await; - let llm_client = w_locked.llm_client.provider().create_client(&w_locked.model_name, &w_locked.api_key); - match llm_client { + let provider = match w_locked.model.built_in_provider() { + Some(provider) => provider, + None => { + yield Err(WorkerError::config( + "Delegated processing is not supported for plugin providers", + )); + return; + } + }; + + match provider.create_client(&w_locked.model.name, &w_locked.api_key) { Ok(client) => { - let tool_defs = w_locked.tools.iter().map(|tool| crate::types::DynamicToolDefinition { - name: tool.name().to_string(), - description: tool.description().to_string(), - parameters_schema: tool.parameters_schema(), - }).collect::>(); + let tool_defs = w_locked + .tools + .iter() + .map(|tool| crate::types::DynamicToolDefinition { + name: tool.name().to_string(), + description: tool.description().to_string(), + parameters_schema: tool.parameters_schema(), + }) + .collect::>(); ( client, w_locked.system_prompt.clone(), tool_defs, - w_locked.api_key.clone(), - w_locked.model_name.clone() + w_locked.model.clone(), ) - }, + } Err(e) => { yield Err(e); return; @@ -1067,13 +1026,7 @@ impl Worker { loop { let provider = llm_client.provider(); - let supports_native = match supports_native_tools(&provider, &model_name, &api_key).await { - Ok(supports) => supports, - Err(e) => { - tracing::warn!("Failed to check native tool support: {}", e); - false - } - }; + let supports_native = model.features.supports_tools; let (composed_messages, tools_for_llm) = if supports_native { let messages = @@ -1254,16 +1207,21 @@ impl Worker { let tools = self.get_tools(); let provider = self.llm_client.provider(); let model_name = self.get_model_name(); - tracing::debug!("Checking native tool support: provider={:?}, model_name={}, api_key_len={}, provider_str={}", provider, model_name, self.api_key.len(), self.provider_str); - let supports_native = match supports_native_tools(&provider, &model_name, &self.api_key).await { - Ok(supports) => supports, - Err(e) => { - tracing::warn!("Failed to check native tool support: {}", e); - false - } - }; + tracing::debug!( + "Checking native tool support: provider={:?}, model_name={}, api_key_len={}, provider_id={}", + provider, + model_name, + self.api_key.len(), + self.model.provider_id() + ); - tracing::info!("Model {} supports native tools: {}", model_name, supports_native); + let supports_native = self.model.features.supports_tools; + + tracing::info!( + "Model {} supports native tools: {}", + model_name, + supports_native + ); let (composed_messages, tools_for_llm) = if supports_native { // Native tools - basic composition @@ -1471,15 +1429,16 @@ impl Worker { loop { let tools = self.get_tools(); let provider = self.llm_client.provider(); - let model_name = self.get_model_name(); - tracing::debug!("Checking native tool support: provider={:?}, model_name={}, api_key_len={}, provider_str={}", provider, model_name, self.api_key.len(), self.provider_str); - let supports_native = match supports_native_tools(&provider, &model_name, &self.api_key).await { - Ok(supports) => supports, - Err(e) => { - tracing::warn!("Failed to check native tool support: {}", e); - false - } - }; + let model_name = self.model.name.clone(); + tracing::debug!( + "Checking native tool support: provider={:?}, model_name={}, api_key_len={}, provider_id={}", + provider, + model_name, + self.api_key.len(), + self.model.provider_id() + ); + + let supports_native = self.model.features.supports_tools; tracing::info!("Model {} supports native tools: {}", model_name, supports_native); @@ -1673,8 +1632,8 @@ impl Worker { let session_id = uuid::Uuid::new_v4().to_string(); let mut session_data = SessionData::new( session_id, - self.provider_str.clone(), - self.model_name.clone(), + self.model.provider_id(), + self.model.name.clone(), workspace_path, ); session_data.git_branch = git_branch; @@ -1688,15 +1647,15 @@ impl Worker { /// セッションデータから履歴を復元する pub fn load_session(&mut self, session_data: &SessionData) -> Result<(), WorkerError> { // モデルが異なる場合は警告をログに出す - if session_data.model_provider != self.provider_str - || session_data.model_name != self.model_name + if session_data.model_provider != self.model.provider_id() + || session_data.model_name != self.model.name { tracing::warn!( "Loading session with different model: session={}:{}, current={}:{}", session_data.model_provider, session_data.model_name, - self.provider_str, - self.model_name + self.model.provider_id(), + self.model.name ); } diff --git a/worker/src/mcp/config.rs b/worker/src/mcp/config.rs index ae326ef..a81892f 100644 --- a/worker/src/mcp/config.rs +++ b/worker/src/mcp/config.rs @@ -2,8 +2,7 @@ use super::tool::McpServerConfig; use crate::types::WorkerError; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use std::path::Path; -use tracing::{debug, info, warn}; +use tracing::{debug, warn}; /// MCP設定ファイルの構造 #[derive(Debug, Clone, Serialize, Deserialize)] @@ -57,121 +56,10 @@ fn default_integration_mode() -> IntegrationMode { } impl McpConfig { - /// 設定ファイルを読み込む - pub fn load_from_file>(path: P) -> Result { - let path = path.as_ref(); - - if !path.exists() { - debug!( - "MCP config file not found at {:?}, returning empty config", - path - ); - return Ok(Self::default()); - } - - info!("Loading MCP config from: {:?}", path); - let content = std::fs::read_to_string(path).map_err(|e| { - WorkerError::config(format!("Failed to read MCP config file {:?}: {}", path, e)) - })?; - - let config: McpConfig = serde_yaml::from_str(&content).map_err(|e| { - WorkerError::config(format!("Failed to parse MCP config file {:?}: {}", path, e)) - })?; - - info!("Loaded {} MCP server configurations", config.servers.len()); - Ok(config) - } - - /// 設定ファイルに保存する - pub fn save_to_file>(&self, path: P) -> Result<(), WorkerError> { - let path = path.as_ref(); - - // ディレクトリが存在しない場合は作成 - if let Some(parent) = path.parent() { - std::fs::create_dir_all(parent).map_err(|e| { - WorkerError::config(format!( - "Failed to create config directory {:?}: {}", - parent, e - )) - })?; - } - - let content = serde_yaml::to_string(self) - .map_err(|e| WorkerError::config(format!("Failed to serialize MCP config: {}", e)))?; - - std::fs::write(path, content).map_err(|e| { - WorkerError::config(format!("Failed to write MCP config file {:?}: {}", path, e)) - })?; - - info!("Saved MCP config to: {:?}", path); - Ok(()) - } - /// 有効なサーバー設定を取得 pub fn get_enabled_servers(&self) -> Vec<(&String, &McpServerDefinition)> { self.servers.iter().filter(|(_, def)| def.enabled).collect() } - - /// デフォルト設定ファイルを生成 - pub fn create_default_config() -> Self { - let mut servers = HashMap::new(); - - // Brave Search MCP Server の設定例 - servers.insert( - "brave_search".to_string(), - McpServerDefinition { - command: "npx".to_string(), - args: vec![ - "-y".to_string(), - "@brave/brave-search-mcp-server".to_string(), - ], - env: { - let mut env = HashMap::new(); - env.insert("BRAVE_API_KEY".to_string(), "${BRAVE_API_KEY}".to_string()); - env - }, - description: Some("Brave Search API for web searching".to_string()), - enabled: false, // デフォルトでは無効(APIキーが必要なため) - integration_mode: IntegrationMode::Individual, - }, - ); - - // ファイルシステムMCPサーバーの設定例 - servers.insert( - "filesystem".to_string(), - McpServerDefinition { - command: "npx".to_string(), - args: vec![ - "-y".to_string(), - "@modelcontextprotocol/server-filesystem".to_string(), - "/tmp".to_string(), - ], - env: HashMap::new(), - description: Some("Filesystem operations in /tmp directory".to_string()), - enabled: false, // デフォルトでは無効 - integration_mode: IntegrationMode::Individual, - }, - ); - - // Git MCP サーバーの設定例 - servers.insert( - "git".to_string(), - McpServerDefinition { - command: "npx".to_string(), - args: vec![ - "-y".to_string(), - "@modelcontextprotocol/server-git".to_string(), - ".".to_string(), - ], - env: HashMap::new(), - description: Some("Git operations in current directory".to_string()), - enabled: false, // デフォルトでは無効 - integration_mode: IntegrationMode::Individual, - }, - ); - - Self { servers } - } } impl Default for McpConfig { @@ -242,25 +130,53 @@ fn expand_environment_variables(input: &str) -> Result { #[cfg(test)] mod tests { use super::*; - use std::fs; - use tempfile::tempdir; #[test] fn test_default_config_creation() { - let config = McpConfig::create_default_config(); + let mut config = McpConfig::default(); + config.servers.insert( + "brave_search".to_string(), + McpServerDefinition { + command: "npx".to_string(), + args: vec![ + "-y".to_string(), + "@brave/brave-search-mcp-server".to_string(), + ], + env: HashMap::new(), + description: None, + enabled: true, + integration_mode: IntegrationMode::Individual, + }, + ); + assert!(!config.servers.is_empty()); assert!(config.servers.contains_key("brave_search")); - assert!(config.servers.contains_key("filesystem")); } #[test] fn test_config_serialization() { - let config = McpConfig::create_default_config(); + let mut servers = HashMap::new(); + servers.insert( + "filesystem".to_string(), + McpServerDefinition { + command: "npx".to_string(), + args: vec![ + "-y".to_string(), + "@modelcontextprotocol/server-filesystem".to_string(), + ], + env: HashMap::new(), + description: Some("Filesystem operations".to_string()), + enabled: true, + integration_mode: IntegrationMode::Proxy, + }, + ); + + let config = McpConfig { servers }; let yaml = serde_yaml::to_string(&config).unwrap(); // YAML形式で正しくシリアライズされることを確認 assert!(yaml.contains("servers:")); - assert!(yaml.contains("brave_search:")); + assert!(yaml.contains("filesystem:")); assert!(yaml.contains("command:")); } @@ -303,23 +219,6 @@ servers: assert_eq!(result, "${NON_EXISTENT_VAR}"); } - #[test] - fn test_config_file_operations() { - let dir = tempdir().unwrap(); - let config_path = dir.path().join("mcp.yaml"); - - // 設定を作成して保存 - let config = McpConfig::create_default_config(); - config.save_to_file(&config_path).unwrap(); - - // ファイルが作成されたことを確認 - assert!(config_path.exists()); - - // 設定を読み込み - let loaded_config = McpConfig::load_from_file(&config_path).unwrap(); - assert_eq!(config.servers.len(), loaded_config.servers.len()); - } - #[test] fn test_enabled_servers_filter() { let mut config = McpConfig::default(); diff --git a/worker/src/mcp/protocol.rs b/worker/src/mcp/protocol.rs index 02e6220..e5271ad 100644 --- a/worker/src/mcp/protocol.rs +++ b/worker/src/mcp/protocol.rs @@ -184,7 +184,7 @@ impl McpClient { }), }, client_info: ClientInfo { - name: "nia-worker".to_string(), + name: "llm-worker-rs".to_string(), version: "0.1.0".to_string(), }, }; diff --git a/worker/src/prompt/types.rs b/worker/src/prompt/types.rs index dc11461..ba1c3e4 100644 --- a/worker/src/prompt/types.rs +++ b/worker/src/prompt/types.rs @@ -4,7 +4,8 @@ use std::collections::HashMap; /// モデルに関する静的な情報 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ModelContext { - pub provider: crate::types::LlmProvider, + pub provider: Option, + pub provider_id: String, pub model_name: String, pub capabilities: ModelCapabilities, }