diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 7f45d64..0000000 --- a/CLAUDE.md +++ /dev/null @@ -1,55 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code when working with code in this repository. - -## Development Commands - -- **Build**: `cargo build --workspace` or `cargo check --workspace` for quick validation -- **Tests**: `cargo test --workspace` (note: tests currently fail due to unsafe environment variable operations) -- **Fix warnings**: `cargo fix --lib -p worker` -- **Dev environment**: Uses Nix flake - run `nix develop` to enter dev shell with Rust toolchain - -## Architecture Overview - -This is a Rust workspace implementing a multi-LLM worker system with tool calling capabilities. The architecture follows the Core Crate Pattern to avoid circular dependencies: - -### Workspace Structure -- **`worker-types`**: Core type definitions (Tool trait, Message, StreamEvent, hook types) - foundational types used by all other crates -- **`worker-macros`**: Procedural macros (`#[tool]`, `#[hook]`) for automatic Tool and WorkerHook trait implementations -- **`worker`**: Main library containing Worker struct, LLM clients, prompt composer, and session management - -### Key Components - -**Worker**: Central orchestrator that manages LLM interactions, tool execution, and session state. Supports streaming responses via async streams. - -**LLM Providers**: Modular clients in `worker/src/llm/` for: -- Anthropic (Claude) -- Google (Gemini) -- OpenAI (GPT) -- xAI (Grok) -- Ollama (local models) - -**Tool System**: Uses `#[tool]` macro to convert async functions into LLM-callable tools. Tools must return `ToolResult` and take a single struct argument implementing `Deserialize + Serialize + JsonSchema`. - -**Hook System**: Uses `#[hook]` macro to create lifecycle event handlers. Hook functions take `HookContext` and return `(HookContext, HookResult)` tuple, supporting events like OnMessageSend, PreToolUse, PostToolUse, and OnTurnCompleted with regex matcher support. - -**Prompt Management**: Handlebars-based templating system for dynamic prompt generation based on roles and context. - -**MCP Integration**: Model Context Protocol support for external tool server communication. - -## Development Notes - -- Edition 2024 Rust with async/await throughout -- Uses `tokio` for async runtime and `reqwest` for HTTP clients -- Environment-based configuration for API keys and base URLs -- Session persistence using JSON serialization -- Streaming responses via `futures-util` streams -- Current test suite has unsafe environment variable operations that need `unsafe` blocks to compile - -## Important Patterns - -- **Error Handling**: `ToolResult` for tools, `WorkerError` for library operations with automatic conversions -- **Streaming**: All LLM interactions return `StreamEvent` streams for real-time UI updates -- **Tool Registration**: Dynamic tool registration at runtime using boxed trait objects -- **Hook Registration**: Dynamic hook registration with lifecycle event filtering and regex matching -- **Core Crate Pattern**: `worker-macros` references `worker-types` directly via complete paths to prevent circular dependencies \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 8f1def3..4aa809e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -741,6 +741,16 @@ version = "0.2.175" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" +[[package]] +name = "libloading" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" +dependencies = [ + "cfg-if", + "windows-targets 0.53.3", +] + [[package]] name = "libredox" version = "0.1.9" @@ -2214,6 +2224,7 @@ dependencies = [ "futures", "futures-util", "handlebars", + "libloading", "log", "regex", "reqwest", diff --git a/README.md b/README.md new file mode 100644 index 0000000..2b6fcb1 --- /dev/null +++ b/README.md @@ -0,0 +1,399 @@ +# `worker` + +`worker` クレートは、大規模言語モデル (LLM) を利用したアプリケーションのバックエンド機能を提供するクレートです。LLM プロバイダーの抽象化、ツール利用、柔軟なプロンプト管理、フックシステムなど、高度な機能をカプセル化し、アプリケーション開発を簡素化します。 + +## 主な機能 + +- **マルチプロバイダー対応**: Gemini, Claude, OpenAI, Ollama, XAI など、複数の LLM プロバイダーを統一されたインターフェースで利用できます。 +- **プラグインシステム**: カスタムプロバイダーをプラグインとして動的に追加できます。独自の LLM API や実験的なプロバイダーをサポートします。 +- **ツール利用 (Function Calling)**: LLM が外部ツールを呼び出す機能をサポートします。独自のツールをマクロを用いて定義し、`Worker` に登録できます。 +- **ストリーミング処理**: LLM の応答やツール実行結果を `StreamEvent` として非同期に受け取ることができます。これにより、リアルタイムな UI 更新が可能になります。 +- **フックシステム**: `Worker` の処理フローの特定のタイミング(例: メッセージ送信前、ツール使用後)にカスタムロジックを介入させることができます。 +- **セッション管理**: 会話履歴やワークスペースの状態を管理し、永続化する機能を提供します。 +- **柔軟なプロンプト管理**: 設定ファイルを用いて、ロールやコンテキストに応じたシステムプロンプトを動的に構築します。 + +## 主な概念 + +### `Worker` + +このクレートの中心的な構造体です。LLM との対話、ツールの登録と実行、セッション管理など、すべての主要な機能を担当します。 + +### `LlmProvider` + +サポートしている LLM プロバイダー(`Gemini`, `Claude`, `OpenAI` など)を表す enum です。 + +### `Tool` トレイト + +`Worker` が利用できるツールを定義するためのインターフェースです。このトレイトを実装することで、任意の機能をツールとして `Worker` に追加できます。 + +```rust +pub trait Tool: Send + Sync { + fn name(&self) -> &str; + fn description(&self) -> &str; + fn parameters_schema(&self) -> serde_json::Value; + async fn execute(&self, args: serde_json::Value) -> ToolResult; +} +``` + +### `WorkerHook` トレイト + +`Worker` のライフサイクルイベントに介入するためのフックを定義するインターフェースです。特定のイベント(例: `OnMessageSend`, `PostToolUse`)に対して処理を追加できます。 + +### `StreamEvent` + +`Worker` の処理結果を非同期ストリームで受け取るための enum です。LLM の応答チャンク、ツール呼び出し、エラーなど、さまざまなイベントを表します。 + +## アプリケーションへの組み込み方法 + +### 1. Worker の初期化 + +Builder patternを使用してWorkerを作成します。 + +```rust +use worker::{Worker, LlmProvider, Role}; +use std::collections::HashMap; + +// ロールを定義(必須) +let role = Role::new( + "assistant", + "AI Assistant", + "You are a helpful AI assistant." +); + +// APIキーを準備 +let mut api_keys = HashMap::new(); +api_keys.insert("openai".to_string(), "your_openai_api_key".to_string()); +api_keys.insert("claude".to_string(), "your_claude_api_key".to_string()); + +// Workerを作成(builder pattern) +let mut worker = Worker::builder() + .provider(LlmProvider::OpenAI) + .model("gpt-4o") + .api_keys(api_keys) + .role(role) + .build() + .expect("Workerの作成に失敗しました"); + +// または、個別にAPIキーを設定 +let worker = Worker::builder() + .provider(LlmProvider::Claude) + .model("claude-3-sonnet-20240229") + .api_key("claude", "sk-ant-...") + .role(role) + .build()?; +``` + +### 2. ツールの定義と登録 + +`Tool` トレイトを実装してカスタムツールを作成し、`Worker` に登録します。 + +```rust +use worker::{Tool, ToolResult}; +use worker::schemars::{self, JsonSchema}; +use worker::serde_json::{self, json, Value}; +use async_trait::async_trait; + +// ツールの引数を定義 +#[derive(Debug, serde::Deserialize, JsonSchema)] +struct FileSystemToolArgs { + path: String, +} + +// カスタムツールを定義 +struct ListFilesTool; + +#[async_trait] +impl Tool for ListFilesTool { + fn name(&self) -> &str { "list_files" } + fn description(&self) -> &str { "指定されたパスのファイル一覧を表示します" } + + fn parameters_schema(&self) -> Value { + serde_json::to_value(schemars::schema_for!(FileSystemToolArgs)).unwrap() + } + + async fn execute(&self, args: Value) -> ToolResult { + let tool_args: FileSystemToolArgs = serde_json::from_value(args)?; + // ここで実際のファイル一覧取得処理を実装 + let files = vec!["file1.txt", "file2.txt"]; + Ok(json!({ "files": files })) + } +} + +// 作成したツールをWorkerに登録 +worker.register_tool(Box::new(ListFilesTool)).unwrap(); +``` + +#### マクロを使ったツール定義(推奨) + +`worker-macros` クレートの `#[tool]` マクロを使用すると、ツールの定義がより簡潔になります: + +```rust +use worker_macros::tool; +use worker::ToolResult; +use serde::{Deserialize, Serialize}; +use schemars::JsonSchema; + +#[derive(Debug, Deserialize, Serialize, JsonSchema)] +struct ListFilesArgs { + path: String, +} + +#[tool] +async fn list_files(args: ListFilesArgs) -> ToolResult { + // ファイル一覧取得処理 + let files = vec!["file1.txt", "file2.txt"]; + Ok(serde_json::json!({ "files": files })) +} + +// マクロで生成されたツールを登録 +worker.register_tool(Box::new(ListFilesTool))?; +``` + +### 3. 対話処理の実行 + +`process_task_with_history` メソッドを呼び出して、ユーザーメッセージを処理します。このメソッドはイベントのストリームを返します。 + +```rust +use futures_util::StreamExt; + +let user_message = "カレントディレクトリのファイルを教えて".to_string(); + +let mut stream = worker.process_task_with_history(user_message, None).await; + +while let Some(event_result) = stream.next().await { + match event_result { + Ok(event) => { + // StreamEventに応じた処理 + match event { + worker::StreamEvent::Chunk(chunk) => { + print!("{}", chunk); + } + worker::StreamEvent::ToolCall(tool_call) => { + println!("\n[Tool Call: {} with args {}]", tool_call.name, tool_call.arguments); + } + worker::StreamEvent::ToolResult { tool_name, result } => { + println!("\n[Tool Result: {} -> {:?}]", tool_name, result); + } + _ => {} + } + } + Err(e) => { + eprintln!("\n[Error: {}]", e); + break; + } + } +} +``` + +### 4. (オプション) フックの登録 + +`WorkerHook` トレイトを実装してカスタムフックを作成し、`Worker` に登録することで、処理フローをカスタマイズできます。 + +#### 手動実装 + +```rust +use worker::{WorkerHook, HookContext, HookResult}; +use async_trait::async_trait; + +struct LoggingHook; + +#[async_trait] +impl WorkerHook for LoggingHook { + fn name(&self) -> &str { "logging_hook" } + fn hook_type(&self) -> &str { "OnMessageSend" } + fn matcher(&self) -> &str { "" } + + async fn execute(&self, context: HookContext) -> (HookContext, HookResult) { + println!("User message: {}", context.content); + (context, HookResult::Continue) + } +} + +// フックを登録 +worker.register_hook(Box::new(LoggingHook)); +``` + +#### マクロを使ったフック定義(推奨) + +`worker-macros` クレートの `#[hook]` マクロを使用すると、フックの定義がより簡潔になります: + +```rust +use worker_macros::hook; +use worker::{HookContext, HookResult}; + +#[hook(OnMessageSend)] +async fn logging_hook(context: HookContext) -> (HookContext, HookResult) { + println!("User message: {}", context.content); + (context, HookResult::Continue) +} + +// マクロで生成されたフックを登録 +worker.register_hook(Box::new(LoggingHook)); +``` + +**利用可能なフックタイプ:** +- `OnMessageSend`: ユーザーメッセージ送信前 +- `PreToolUse`: ツール実行前 +- `PostToolUse`: ツール実行後 +- `OnTurnCompleted`: ターン完了時 + +これで、アプリケーションの要件に応じて `Worker` を中心とした強力な LLM 連携機能を構築できます。 + +## サンプルコード + +完全な動作例は `worker/examples/` ディレクトリを参照してください: + +- [`builder_basic.rs`](worker/examples/builder_basic.rs) - Builder patternの基本的な使用方法 +- [`plugin_usage.rs`](worker/examples/plugin_usage.rs) - プラグインシステムの使用方法 + +## プラグインシステム + +### プラグインの作成 + +`ProviderPlugin` トレイトを実装してカスタムプロバイダーを作成できます: + +```rust +use worker::plugin::{ProviderPlugin, PluginMetadata}; +use async_trait::async_trait; + +pub struct MyCustomProvider { + // プロバイダーの状態 +} + +#[async_trait] +impl ProviderPlugin for MyCustomProvider { + fn metadata(&self) -> PluginMetadata { + PluginMetadata { + id: "my-provider".to_string(), + name: "My Custom Provider".to_string(), + version: "1.0.0".to_string(), + author: "Your Name".to_string(), + description: "カスタムプロバイダーの説明".to_string(), + supported_models: vec!["model-1".to_string()], + requires_api_key: true, + config_schema: None, + } + } + + async fn initialize(&mut self, config: HashMap) -> Result<(), WorkerError> { + // プロバイダーの初期化 + Ok(()) + } + + fn create_client( + &self, + model_name: &str, + api_key: Option<&str>, + config: Option>, + ) -> Result, WorkerError> { + // LLMクライアントを作成して返す + } + + fn as_any(&self) -> &dyn Any { + self + } +} +``` + +### プラグインの使用 + +```rust +use worker::{Worker, Role, plugin::PluginRegistry}; +use std::sync::{Arc, Mutex}; + +// プラグインレジストリを作成 +let plugin_registry = Arc::new(Mutex::new(PluginRegistry::new())); + +// プラグインを作成して登録 +let my_plugin = Arc::new(MyCustomProvider::new()); +{ + let mut registry = plugin_registry.lock().unwrap(); + registry.register(my_plugin)?; +} + +// ロールを定義 +let role = Role::new( + "assistant", + "AI Assistant", + "You are a helpful AI assistant." +); + +let worker = Worker::builder() + .plugin("my-provider", plugin_registry.clone()) + .model("model-1") + .api_key("__plugin__", "api-key") + .role(role) + .build()?; +``` + +### 動的プラグイン読み込み + +`dynamic-loading` フィーチャーを有効にすることで、共有ライブラリからプラグインを動的に読み込むことができます: + +```toml +[dependencies] +worker = { path = "../worker", features = ["dynamic-loading"] } +``` + +```rust +// ディレクトリからプラグインを読み込み +worker.load_plugins_from_directory(Path::new("./plugins")).await?; +``` + +完全な例は `worker/src/plugin/example_provider.rs` と `worker/examples/plugin_usage.rs` を参照してください。 + +## エラーハンドリング + +構造化されたエラー型により、詳細なエラー情報を取得できます。 + +### WorkerError の種類 + +```rust +use worker::WorkerError; + +// ツール実行エラー +let error = WorkerError::tool_execution("my_tool", "Connection failed"); +let error = WorkerError::tool_execution_with_source("my_tool", "Failed", source_error); + +// 設定エラー +let error = WorkerError::config("Invalid configuration"); +let error = WorkerError::config_with_context("Parse error", "config.yaml line 10"); +let error = WorkerError::config_with_source("Failed to load", io_error); + +// LLM APIエラー +let error = WorkerError::llm_api("openai", "Rate limit exceeded"); +let error = WorkerError::llm_api_with_details("claude", "Invalid request", Some(400), None); + +// モデルエラー +let error = WorkerError::model_not_found("openai", "gpt-5"); + +// ネットワークエラー +let error = WorkerError::network("Connection timeout"); +let error = WorkerError::network_with_source("Request failed", reqwest_error); +``` + +### エラーのパターンマッチング + +```rust +match worker.build() { + Ok(worker) => { /* ... */ }, + Err(WorkerError::ConfigurationError { message, context, .. }) => { + eprintln!("Configuration error: {}", message); + if let Some(ctx) = context { + eprintln!("Context: {}", ctx); + } + }, + Err(WorkerError::ModelNotFound { provider, model_name }) => { + eprintln!("Model '{}' not found for provider '{}'", model_name, provider); + }, + Err(WorkerError::LlmApiError { provider, message, status_code, .. }) => { + eprintln!("API error from {}: {}", provider, message); + if let Some(code) = status_code { + eprintln!("Status code: {}", code); + } + }, + Err(e) => { + eprintln!("Error: {}", e); + } +} +``` diff --git a/docs/patch_note/v0.2.0.md b/docs/patch_note/v0.2.0.md new file mode 100644 index 0000000..fb3aa9a --- /dev/null +++ b/docs/patch_note/v0.2.0.md @@ -0,0 +1,632 @@ +# Release Notes - v0.2.0 + +**Release Date**: 2025-10-14 + +## Overview +# Release Notes - v0.2.0 + +**Release Date**: 2025-10-14 + +## Overview + +Version 0.2.0 introduces significant improvements to the external API, including a builder pattern for Worker construction, simplified Role management, comprehensive error handling improvements, and a new plugin system for extensible LLM providers. + +## Breaking Changes + +### 1. Role System Redesign + +**Removed duplicate types:** +- `RoleConfig` has been removed from `worker-types` +- `PromptRoleConfig` has been renamed to `Role` +- All configuration is now unified under `worker::prompt_types::Role` + +**Built-in roles removed:** +- `Role::default()` - removed +- `Role::assistant()` - removed +- `Role::code_assistant()` - removed + +**Migration:** +```rust +// Before (v0.1.0) +let role = RoleConfig::default(); + +// After (v0.2.0) +use worker::prompt_types::Role; +let role = Role::new( + "assistant", + "A helpful AI assistant", + "You are a helpful, harmless, and honest AI assistant." +); +``` + +### 2. Worker Construction - Builder Pattern + +**New Type-state Builder API:** + +Worker construction now uses a type-state builder pattern that enforces required parameters at compile time. + +```rust +// Before (v0.1.0) +let worker = Worker::new( + LlmProvider::Claude, + "claude-3-sonnet-20240229", + &api_keys, + role, + plugin_registry, +)?; + +// After (v0.2.0) +let worker = Worker::builder() + .provider(LlmProvider::Claude) + .model("claude-3-sonnet-20240229") + .api_keys(api_keys) + .role(role) + .build()?; +``` + +**Required parameters (enforced at compile-time):** +- `provider(LlmProvider)` - LLM provider +- `model(&str)` - Model name +- `role(Role)` - System role configuration + +**Optional parameters:** +- `api_keys(HashMap)` - API keys for providers +- `plugin_id(&str)` - Plugin identifier for custom providers +- `plugin_registry(Arc>)` - Plugin registry + +### 3. WorkerError Redesign + +**Complete error type restructuring with structured error information:** + +**Removed duplicate variants:** +- `ToolExecution` (use `ToolExecutionError`) +- `Config` (use `ConfigurationError`) +- `Serialization` (use `JsonError`) + +**New structured error variants:** + +```rust +pub enum WorkerError { + ToolExecutionError { + tool_name: String, + reason: String, + source: Option>, + }, + LlmApiError { + provider: String, + message: String, + status_code: Option, + source: Option>, + }, + ConfigurationError { + message: String, + context: Option, + source: Option>, + }, + ModelNotFound { + provider: String, + model_name: String, + }, + Network { + message: String, + source: Option>, + }, + JsonError(serde_json::Error), + Other(String), +} +``` + +**New helper methods for ergonomic error construction:** + +```rust +// Tool execution errors +WorkerError::tool_execution(tool_name, reason) +WorkerError::tool_execution_with_source(tool_name, reason, source) + +// Configuration errors +WorkerError::config(message) +WorkerError::config_with_context(message, context) +WorkerError::config_with_source(message, source) + +// LLM API errors +WorkerError::llm_api(provider, message) +WorkerError::llm_api_with_details(provider, message, status_code, source) + +// Model errors +WorkerError::model_not_found(provider, model_name) + +// Network errors +WorkerError::network(message) +WorkerError::network_with_source(message, source) +``` + +**Migration:** +```rust +// Before (v0.1.0) +return Err(WorkerError::ToolExecutionError( + format!("Tool '{}' failed: {}", tool_name, reason) +)); + +// After (v0.2.0) +return Err(WorkerError::tool_execution(tool_name, reason)); + +// With source error +return Err(WorkerError::tool_execution_with_source( + tool_name, + reason, + Box::new(source_error) +)); +``` + +## New Features + +### 1. Plugin System + +A new extensible plugin system for custom LLM providers: + +**Core Components:** +- `ProviderPlugin` trait - Define custom LLM providers +- `PluginRegistry` - Manage plugin lifecycle +- `PluginClient` - Adapter for plugin-based clients +- `PluginMetadata` - Plugin identification and configuration + +**Example:** +```rust +use worker::plugin::{ProviderPlugin, PluginRegistry}; + +// Register a plugin +let mut registry = PluginRegistry::new(); +registry.register(Arc::new(CustomProviderPlugin::new()))?; + +// Use plugin with Worker +let worker = Worker::builder() + .provider(LlmProvider::OpenAI) // Base provider type + .model("custom-model") + .plugin_id("custom-provider") + .plugin_registry(Arc::new(Mutex::new(registry))) + .role(role) + .build()?; +``` + +**Plugin Features:** +- Custom provider implementation +- API key validation +- Model listing +- Health checks +- Dynamic configuration schemas +- Optional dynamic library loading (`#[cfg(feature = "dynamic-loading")]`) + +### 2. Examples + +**New example files:** + +1. **`worker/examples/builder_basic.rs`** + - Demonstrates the new builder pattern API + - Shows basic Worker construction + - Example tool registration + +2. **`worker/examples/plugin_usage.rs`** + - Plugin system usage example + - Custom provider implementation + - Plugin registration and usage + +## Improvements + +### Error Handling + +1. **Structured error information** - All errors now carry contextual information (provider, tool name, status codes, etc.) +2. **Error source tracking** - `#[source]` attribute enables full error chain tracing +3. **Automatic conversions** - Added `From` implementations for common error types: + - `From` → `Network` error + - `From` → `ConfigurationError` + - `From` → `JsonError` + - `From>` → `Other` error + +### Code Organization + +1. **Eliminated duplicate types** between `worker-types` and `worker` crates +2. **Clearer separation of concerns** - Role definition vs. PromptComposer execution +3. **Consistent error construction** - All error sites updated to use new helper methods + +## Files Changed + +**Modified (11 files):** +- `Cargo.lock` - Dependency updates +- `worker-types/src/lib.rs` - Removed duplicate types +- `worker/Cargo.toml` - Version and dependency updates +- `worker/src/config_parser.rs` - Role type updates +- `worker/src/lib.rs` - Builder pattern, error handling updates +- `worker/src/llm/gemini.rs` - Error handling improvements +- `worker/src/llm/ollama.rs` - Error handling improvements +- `worker/src/mcp_config.rs` - Error handling updates +- `worker/src/mcp_tool.rs` - Error handling improvements +- `worker/src/prompt_composer.rs` - Role type imports +- `worker/src/prompt_types.rs` - PromptRoleConfig → Role rename +- `worker/src/types.rs` - Complete WorkerError redesign + +**Added (5 files):** +- `worker/src/builder.rs` - Type-state builder implementation +- `worker/src/plugin/mod.rs` - Plugin system core +- `worker/src/plugin/example_provider.rs` - Example plugin implementation +- `worker/examples/builder_basic.rs` - Builder pattern example +- `worker/examples/plugin_usage.rs` - Plugin usage example + +**Renamed (1 file):** +- `worker/README.md` → `README.md` - Moved to root + +## Migration Guide + +### Step 1: Update Role Construction + +Replace all `RoleConfig` and built-in role usage: + +```rust +// Remove old imports +// use worker::RoleConfig; + +// Add new import +use worker::prompt_types::Role; + +// Replace role construction +let role = Role::new( + "your-role-name", + "Role description", + "Your system prompt template" +); +``` + +### Step 2: Update Worker Construction + +Replace direct `Worker::new()` calls with builder pattern: + +```rust +let worker = Worker::builder() + .provider(provider) + .model(model_name) + .api_keys(api_keys) + .role(role) + .build()?; +``` + +### Step 3: Update Error Handling + +Replace old error construction with new helper methods: + +```rust +// Configuration errors +WorkerError::ConfigurationError("message".to_string()) +// becomes: +WorkerError::config("message") + +// Tool errors +WorkerError::ToolExecutionError(format!("...")) +// becomes: +WorkerError::tool_execution(tool_name, reason) + +// Pattern matching +match error { + WorkerError::ConfigurationError(msg) => { /* ... */ } + // becomes: + WorkerError::ConfigurationError { message, .. } => { /* ... */ } +} +``` + +## Acknowledgments + +This release includes significant API improvements based on architectural review and refactoring to improve type safety, error handling, and code maintainability. + +## Next Steps + +See the examples in `worker/examples/` for complete usage demonstrations: +- `builder_basic.rs` - Basic Worker construction +- `plugin_usage.rs` - Plugin system usage + +## Breaking Changes + +### 1. Role System Redesign + +**Removed duplicate types:** +- `RoleConfig` has been removed from `worker-types` +- `PromptRoleConfig` has been renamed to `Role` +- All configuration is now unified under `worker::prompt_types::Role` + +**Built-in roles removed:** +- `Role::default()` - removed +- `Role::assistant()` - removed +- `Role::code_assistant()` - removed + +**Migration:** +```rust +// Before (v0.1.0) +let role = RoleConfig::default(); + +// After (v0.2.0) +use worker::prompt_types::Role; +let role = Role::new( + "assistant", + "A helpful AI assistant", + "You are a helpful, harmless, and honest AI assistant." +); +``` + +### 2. Worker Construction - Builder Pattern + +**New Type-state Builder API:** + +Worker construction now uses a type-state builder pattern that enforces required parameters at compile time. + +```rust +// Before (v0.1.0) +let worker = Worker::new( + LlmProvider::Claude, + "claude-3-sonnet-20240229", + &api_keys, + role, + plugin_registry, +)?; + +// After (v0.2.0) +let worker = Worker::builder() + .provider(LlmProvider::Claude) + .model("claude-3-sonnet-20240229") + .api_keys(api_keys) + .role(role) + .build()?; +``` + +**Required parameters (enforced at compile-time):** +- `provider(LlmProvider)` - LLM provider +- `model(&str)` - Model name +- `role(Role)` - System role configuration + +**Optional parameters:** +- `api_keys(HashMap)` - API keys for providers +- `plugin_id(&str)` - Plugin identifier for custom providers +- `plugin_registry(Arc>)` - Plugin registry + +### 3. WorkerError Redesign + +**Complete error type restructuring with structured error information:** + +**Removed duplicate variants:** +- `ToolExecution` (use `ToolExecutionError`) +- `Config` (use `ConfigurationError`) +- `Serialization` (use `JsonError`) + +**New structured error variants:** + +```rust +pub enum WorkerError { + ToolExecutionError { + tool_name: String, + reason: String, + source: Option>, + }, + LlmApiError { + provider: String, + message: String, + status_code: Option, + source: Option>, + }, + ConfigurationError { + message: String, + context: Option, + source: Option>, + }, + ModelNotFound { + provider: String, + model_name: String, + }, + Network { + message: String, + source: Option>, + }, + JsonError(serde_json::Error), + Other(String), +} +``` + +**New helper methods for ergonomic error construction:** + +```rust +// Tool execution errors +WorkerError::tool_execution(tool_name, reason) +WorkerError::tool_execution_with_source(tool_name, reason, source) + +// Configuration errors +WorkerError::config(message) +WorkerError::config_with_context(message, context) +WorkerError::config_with_source(message, source) + +// LLM API errors +WorkerError::llm_api(provider, message) +WorkerError::llm_api_with_details(provider, message, status_code, source) + +// Model errors +WorkerError::model_not_found(provider, model_name) + +// Network errors +WorkerError::network(message) +WorkerError::network_with_source(message, source) +``` + +**Migration:** +```rust +// Before (v0.1.0) +return Err(WorkerError::ToolExecutionError( + format!("Tool '{}' failed: {}", tool_name, reason) +)); + +// After (v0.2.0) +return Err(WorkerError::tool_execution(tool_name, reason)); + +// With source error +return Err(WorkerError::tool_execution_with_source( + tool_name, + reason, + Box::new(source_error) +)); +``` + +## New Features + +### 1. Plugin System + +A new extensible plugin system for custom LLM providers: + +**Core Components:** +- `ProviderPlugin` trait - Define custom LLM providers +- `PluginRegistry` - Manage plugin lifecycle +- `PluginClient` - Adapter for plugin-based clients +- `PluginMetadata` - Plugin identification and configuration + +**Example:** +```rust +use worker::plugin::{ProviderPlugin, PluginRegistry}; + +// Register a plugin +let mut registry = PluginRegistry::new(); +registry.register(Arc::new(CustomProviderPlugin::new()))?; + +// Use plugin with Worker +let worker = Worker::builder() + .provider(LlmProvider::OpenAI) // Base provider type + .model("custom-model") + .plugin_id("custom-provider") + .plugin_registry(Arc::new(Mutex::new(registry))) + .role(role) + .build()?; +``` + +**Plugin Features:** +- Custom provider implementation +- API key validation +- Model listing +- Health checks +- Dynamic configuration schemas +- Optional dynamic library loading (`#[cfg(feature = "dynamic-loading")]`) + +### 2. Examples + +**New example files:** + +1. **`worker/examples/builder_basic.rs`** + - Demonstrates the new builder pattern API + - Shows basic Worker construction + - Example tool registration + +2. **`worker/examples/plugin_usage.rs`** + - Plugin system usage example + - Custom provider implementation + - Plugin registration and usage + +## Improvements + +### Error Handling + +1. **Structured error information** - All errors now carry contextual information (provider, tool name, status codes, etc.) +2. **Error source tracking** - `#[source]` attribute enables full error chain tracing +3. **Automatic conversions** - Added `From` implementations for common error types: + - `From` → `Network` error + - `From` → `ConfigurationError` + - `From` → `JsonError` + - `From>` → `Other` error + +### Code Organization + +1. **Eliminated duplicate types** between `worker-types` and `worker` crates +2. **Clearer separation of concerns** - Role definition vs. PromptComposer execution +3. **Consistent error construction** - All error sites updated to use new helper methods + +## Files Changed + +**Modified (11 files):** +- `Cargo.lock` - Dependency updates +- `worker-types/src/lib.rs` - Removed duplicate types +- `worker/Cargo.toml` - Version and dependency updates +- `worker/src/config_parser.rs` - Role type updates +- `worker/src/lib.rs` - Builder pattern, error handling updates +- `worker/src/llm/gemini.rs` - Error handling improvements +- `worker/src/llm/ollama.rs` - Error handling improvements +- `worker/src/mcp_config.rs` - Error handling updates +- `worker/src/mcp_tool.rs` - Error handling improvements +- `worker/src/prompt_composer.rs` - Role type imports +- `worker/src/prompt_types.rs` - PromptRoleConfig → Role rename +- `worker/src/types.rs` - Complete WorkerError redesign + +**Added (5 files):** +- `worker/src/builder.rs` - Type-state builder implementation +- `worker/src/plugin/mod.rs` - Plugin system core +- `worker/src/plugin/example_provider.rs` - Example plugin implementation +- `worker/examples/builder_basic.rs` - Builder pattern example +- `worker/examples/plugin_usage.rs` - Plugin usage example + +**Renamed (1 file):** +- `worker/README.md` → `README.md` - Moved to root + +## Migration Guide + +### Step 1: Update Role Construction + +Replace all `RoleConfig` and built-in role usage: + +```rust +// Remove old imports +// use worker::RoleConfig; + +// Add new import +use worker::prompt_types::Role; + +// Replace role construction +let role = Role::new( + "your-role-name", + "Role description", + "Your system prompt template" +); +``` + +### Step 2: Update Worker Construction + +Replace direct `Worker::new()` calls with builder pattern: + +```rust +let worker = Worker::builder() + .provider(provider) + .model(model_name) + .api_keys(api_keys) + .role(role) + .build()?; +``` + +### Step 3: Update Error Handling + +Replace old error construction with new helper methods: + +```rust +// Configuration errors +WorkerError::ConfigurationError("message".to_string()) +// becomes: +WorkerError::config("message") + +// Tool errors +WorkerError::ToolExecutionError(format!("...")) +// becomes: +WorkerError::tool_execution(tool_name, reason) + +// Pattern matching +match error { + WorkerError::ConfigurationError(msg) => { /* ... */ } + // becomes: + WorkerError::ConfigurationError { message, .. } => { /* ... */ } +} +``` + +## Acknowledgments + +This release includes significant API improvements based on architectural review and refactoring to improve type safety, error handling, and code maintainability. + +## Next Steps + +See the examples in `worker/examples/` for complete usage demonstrations: +- `builder_basic.rs` - Basic Worker construction +- `plugin_usage.rs` - Plugin system usage diff --git a/worker-macros/src/lib.rs b/worker-macros/src/lib.rs index 57b3aa5..478e8c5 100644 --- a/worker-macros/src/lib.rs +++ b/worker-macros/src/lib.rs @@ -95,8 +95,8 @@ pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream { } // Implement Tool trait - #[::worker_types::async_trait::async_trait] - impl ::worker_types::Tool for #tool_struct_name { + #[::worker::types::async_trait::async_trait] + impl ::worker::types::Tool for #tool_struct_name { fn name(&self) -> &str { #tool_name_str } @@ -105,16 +105,16 @@ pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream { #description } - fn parameters_schema(&self) -> ::worker_types::serde_json::Value { - ::worker_types::serde_json::to_value(::worker_types::schemars::schema_for!(#arg_type)).unwrap() + fn parameters_schema(&self) -> ::worker::types::serde_json::Value { + ::worker::types::serde_json::to_value(::worker::types::schemars::schema_for!(#arg_type)).unwrap() } - async fn execute(&self, args: ::worker_types::serde_json::Value) -> ::worker_types::ToolResult<::worker_types::serde_json::Value> { - let typed_args: #arg_type = ::worker_types::serde_json::from_value(args)?; + async fn execute(&self, args: ::worker::types::serde_json::Value) -> ::worker::types::ToolResult<::worker::types::serde_json::Value> { + let typed_args: #arg_type = ::worker::types::serde_json::from_value(args)?; let result = #fn_name(typed_args).await?; // Use Display formatting instead of JSON serialization let formatted_result = format!("{}", result); - Ok(::worker_types::serde_json::Value::String(formatted_result)) + Ok(::worker::types::serde_json::Value::String(formatted_result)) } } @@ -270,8 +270,8 @@ pub fn hook(attr: TokenStream, item: TokenStream) -> TokenStream { } // Implement WorkerHook trait - #[::worker_types::async_trait::async_trait] - impl ::worker_types::WorkerHook for #hook_struct_name { + #[::worker::types::async_trait::async_trait] + impl ::worker::types::WorkerHook for #hook_struct_name { fn name(&self) -> &str { #fn_name_str } @@ -284,7 +284,7 @@ pub fn hook(attr: TokenStream, item: TokenStream) -> TokenStream { #matcher } - async fn execute(&self, context: ::worker_types::HookContext) -> (::worker_types::HookContext, ::worker_types::HookResult) { + async fn execute(&self, context: ::worker::types::HookContext) -> (::worker::types::HookContext, ::worker::types::HookResult) { #fn_name(context).await } } diff --git a/worker-types/src/lib.rs b/worker-types/src/lib.rs index 5aef49a..b4e1fbf 100644 --- a/worker-types/src/lib.rs +++ b/worker-types/src/lib.rs @@ -341,70 +341,6 @@ pub enum StreamEvent { }, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RoleConfig { - pub name: String, - pub description: String, - #[serde(default)] - pub version: Option, - // 新しいテンプレートベースの設定 - pub template: Option, - pub partials: Option>, - // 従来の prompt フィールドもサポート(後方互換性のため) - pub prompt: Option, - #[serde(skip)] - pub path: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PartialConfig { - pub path: String, - #[serde(default)] - pub description: Option, -} - -impl Default for RoleConfig { - fn default() -> Self { - Self { - name: String::new(), - description: String::new(), - version: None, - template: None, - partials: None, - prompt: None, - path: None, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -#[serde(untagged)] -pub enum PromptComponent { - #[default] - None, - Single(PromptComponentDetail), - Multiple(Vec), -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct PromptComponentDetail { - pub path: String, - #[serde(flatten)] - pub inner: Option>, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct PromptConfig { - #[serde(rename = "ROLE_DEFINE")] - pub role_define: Option, - #[serde(rename = "BASIS")] - pub basis: Option, - #[serde(rename = "TOOL_USE")] - pub tool_use: Option, - #[serde(rename = "SECTIONS")] - pub sections: Option, -} - // Session management types #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SessionData { diff --git a/worker/Cargo.toml b/worker/Cargo.toml index dcc9ca5..efdfdbb 100644 --- a/worker/Cargo.toml +++ b/worker/Cargo.toml @@ -36,6 +36,12 @@ regex = "1.10.2" uuid = { version = "1.10", features = ["v4", "serde"] } tokio-util = { version = "0.7", features = ["codec"] } futures = "0.3" +# Optional dependency for dynamic plugin loading +libloading = { version = "0.8", optional = true } + +[features] +default = [] +dynamic-loading = ["libloading"] [dev-dependencies] tempfile = "3.10.1" diff --git a/worker/README.md b/worker/README.md deleted file mode 100644 index 94e5053..0000000 --- a/worker/README.md +++ /dev/null @@ -1,150 +0,0 @@ -# `worker` クレート - -`worker` クレートは、大規模言語モデル (LLM) を利用したアプリケーションのバックエンド機能を提供するコアコンポーネントです。LLMプロバイダーの抽象化、ツール利用、柔軟なプロンプト管理、フックシステムなど、高度な機能をカプセル化し、アプリケーション開発を簡素化します。 - -## 主な機能 - -- **マルチプロバイダー対応**: Gemini, Claude, OpenAI, Ollama, XAIなど、複数のLLMプロバイダーを統一されたインターフェースで利用できます。 -- **ツール利用 (Function Calling)**: LLMが外部ツールを呼び出す機能をサポートします。独自のツールを簡単に定義して `Worker` に登録できます。 -- **ストリーミング処理**: LLMの応答やツール実行結果を `StreamEvent` として非同期に受け取ることができます。これにより、リアルタイムなUI更新が可能になります。 -- **フックシステム**: `Worker` の処理フローの特定のタイミング(例: メッセージ送信前、ツール使用後)にカスタムロジックを介入させることができます。 -- **セッション管理**: 会話履歴やワークスペースの状態を管理し、永続化する機能を提供します。 -- **柔軟なプロンプト管理**: 設定ファイルを用いて、ロールやコンテキストに応じたシステムプロンプトを動的に構築します。 - -## 主な概念 - -### `Worker` -このクレートの中心的な構造体です。LLMとの対話、ツールの登録と実行、セッション管理など、すべての主要な機能を担当します。 - -### `LlmProvider` -サポートしているLLMプロバイダー(`Gemini`, `Claude`, `OpenAI` など)を表すenumです。 - -### `Tool` トレイト -`Worker` が利用できるツールを定義するためのインターフェースです。このトレイトを実装することで、任意の機能をツールとして `Worker` に追加できます。 - -```rust -pub trait Tool: Send + Sync { - fn name(&self) -> &str; - fn description(&self) -> &str; - fn parameters_schema(&self) -> serde_json::Value; - async fn execute(&self, args: serde_json::Value) -> Result; -} -``` - -### `WorkerHook` トレイト -`Worker` のライフサイクルイベントに介入するためのフックを定義するインターフェースです。特定のイベント(例: `OnMessageSend`, `PostToolUse`)に対して処理を追加できます。 - -### `StreamEvent` -`Worker` の処理結果を非同期ストリームで受け取るためのenumです。LLMの応答チャンク、ツール呼び出し、エラーなど、さまざまなイベントを表します。 - -## アプリケーションへの組み込み方法 - -### 1. Workerの初期化 - -まず、`Worker` のインスタンスを作成します。これには `LlmProvider`、モデル名、APIキーが必要です。 - -```rust -use worker::{Worker, LlmProvider}; -use std::collections::HashMap; - -// APIキーを準備 -let mut api_keys = HashMap::new(); -api_keys.insert("openai".to_string(), "your_openai_api_key".to_string()); -api_keys.insert("claude".to_string(), "your_claude_api_key".to_string()); - -// Workerを作成 -let mut worker = Worker::new( - LlmProvider::OpenAI, - "gpt-4o", - &api_keys, - None // RoleConfigはオプション -).expect("Workerの作成に失敗しました"); -``` - -### 2. ツールの定義と登録 - -`Tool` トレイトを実装してカスタムツールを作成し、`Worker` に登録します。 - -```rust -use worker::{Tool, ToolResult}; -use worker::schemars::{self, JsonSchema}; -use worker::serde_json::{self, json, Value}; -use async_trait::async_trait; - -// ツールの引数を定義 -#[derive(Debug, serde::Deserialize, JsonSchema)] -struct FileSystemToolArgs { - path: String, -} - -// カスタムツールを定義 -struct ListFilesTool; - -#[async_trait] -impl Tool for ListFilesTool { - fn name(&self) -> &str { "list_files" } - fn description(&self) -> &str { "指定されたパスのファイル一覧を表示します" } - - fn parameters_schema(&self) -> Value { - serde_json::to_value(schemars::schema_for!(FileSystemToolArgs)).unwrap() - } - - async fn execute(&self, args: Value) -> ToolResult { - let tool_args: FileSystemToolArgs = serde_json::from_value(args)?; - // ここで実際のファイル一覧取得処理を実装 - let files = vec!["file1.txt", "file2.txt"]; - Ok(json!({ "files": files })) - } -} - -// 作成したツールをWorkerに登録 -worker.register_tool(Box::new(ListFilesTool)).unwrap(); -``` - -### 3. 対話処理の実行 - -`process_task_with_history` メソッドを呼び出して、ユーザーメッセージを処理します。このメソッドはイベントのストリームを返します。 - -```rust -use futures_util::StreamExt; - -let user_message = "カレントディレクトリのファイルを教えて".to_string(); - -let mut stream = worker.process_task_with_history(user_message, None).await; - -while let Some(event_result) = stream.next().await { - match event_result { - Ok(event) => { - // StreamEventに応じた処理 - match event { - worker::StreamEvent::Chunk(chunk) => { - print!("{}", chunk); - } - worker::StreamEvent::ToolCall(tool_call) => { - println!("\n[Tool Call: {} with args {}]", tool_call.name, tool_call.arguments); - } - worker::StreamEvent::ToolResult { tool_name, result } => { - println!("\n[Tool Result: {} -> {:?}]", tool_name, result); - } - _ => {} - } - } - Err(e) => { - eprintln!("\n[Error: {}]", e); - break; - } - } -} -``` - -### 4. (オプション) フックの登録 - -`WorkerHook` トレイトを実装してカスタムフックを作成し、`Worker` に登録することで、処理フローをカスタマイズできます。 - -```rust -// (WorkerHookの実装は省略) -// let my_hook = MyCustomHook::new(); -// worker.register_hook(Box::new(my_hook)); -``` - -これで、アプリケーションの要件に応じて `Worker` を中心とした強力なLLM連携機能を構築できます。 diff --git a/worker/examples/builder_basic.rs b/worker/examples/builder_basic.rs new file mode 100644 index 0000000..cde9863 --- /dev/null +++ b/worker/examples/builder_basic.rs @@ -0,0 +1,44 @@ +use worker::{LlmProvider, Worker, Role}; +use std::collections::HashMap; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Example 1: Basic role + let role = Role::new( + "assistant", + "A helpful AI assistant", + "You are a helpful, harmless, and honest AI assistant.", + ); + + let mut api_keys = HashMap::new(); + api_keys.insert("claude".to_string(), std::env::var("ANTHROPIC_API_KEY")?); + + let worker = Worker::builder() + .provider(LlmProvider::Claude) + .model("claude-3-sonnet-20240229") + .api_keys(api_keys) + .role(role) + .build()?; + + println!("✅ Worker created with builder pattern"); + println!(" Provider: {:?}", worker.get_provider_name()); + println!(" Model: {}", worker.get_model_name()); + + // Example 2: Code reviewer role + let code_reviewer_role = Role::new( + "code-reviewer", + "An AI that reviews code for best practices", + "You are an expert code reviewer. Always provide constructive feedback.", + ); + + let _worker2 = Worker::builder() + .provider(LlmProvider::Claude) + .model("claude-3-sonnet-20240229") + .api_key("claude", std::env::var("ANTHROPIC_API_KEY")?) + .role(code_reviewer_role) + .build()?; + + println!("✅ Worker created with custom role"); + + Ok(()) +} diff --git a/worker/examples/plugin_usage.rs b/worker/examples/plugin_usage.rs new file mode 100644 index 0000000..861e32f --- /dev/null +++ b/worker/examples/plugin_usage.rs @@ -0,0 +1,82 @@ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use worker::{Worker, Role, plugin::{PluginRegistry, ProviderPlugin, example_provider::CustomProviderPlugin}}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Initialize tracing for debugging + tracing_subscriber::fmt::init(); + + // Create a plugin registry + let plugin_registry = Arc::new(Mutex::new(PluginRegistry::new())); + + // Create and initialize a custom provider plugin + let mut custom_plugin = CustomProviderPlugin::new(); + + let mut config = HashMap::new(); + config.insert( + "base_url".to_string(), + serde_json::Value::String("https://api.custom-provider.com".to_string()), + ); + config.insert( + "timeout".to_string(), + serde_json::Value::Number(serde_json::Number::from(60)), + ); + + custom_plugin.initialize(config).await?; + + // Register the plugin + { + let mut registry = plugin_registry.lock().unwrap(); + registry.register(Arc::new(custom_plugin))?; + } + + // List available plugins + { + let registry = plugin_registry.lock().unwrap(); + let plugins = registry.list(); + println!("Available plugins:"); + for plugin in plugins { + println!(" - {} ({}): {}", plugin.name, plugin.id, plugin.description); + println!(" Supported models: {:?}", plugin.supported_models); + } + } + + // Create a Worker instance using the plugin + let role = Role::new( + "assistant", + "A helpful AI assistant", + "You are a helpful, harmless, and honest AI assistant powered by a custom LLM provider." + ); + + let worker = Worker::builder() + .plugin("custom-provider", plugin_registry.clone()) + .model("custom-turbo") + .api_key("__plugin__", "custom-1234567890abcdefghijklmnop") + .role(role) + .build()?; + + println!("\nWorker created successfully with custom provider plugin!"); + + // Example: List plugins from the worker + let plugin_list = worker.list_plugins()?; + println!("\nPlugins registered in worker:"); + for metadata in plugin_list { + println!(" - {}: v{} by {}", metadata.name, metadata.version, metadata.author); + } + + // Load plugins from directory (if dynamic loading is enabled) + #[cfg(feature = "dynamic-loading")] + { + use std::path::Path; + + let plugin_dir = Path::new("./plugins"); + if plugin_dir.exists() { + let mut worker = worker; + worker.load_plugins_from_directory(plugin_dir).await?; + println!("\nLoaded plugins from directory: {:?}", plugin_dir); + } + } + + Ok(()) +} \ No newline at end of file diff --git a/worker/src/builder.rs b/worker/src/builder.rs new file mode 100644 index 0000000..87eee62 --- /dev/null +++ b/worker/src/builder.rs @@ -0,0 +1,260 @@ +use crate::Worker; +use crate::prompt::Role; +use crate::types::WorkerError; +use worker_types::LlmProvider; +use std::collections::HashMap; +use std::marker::PhantomData; +use std::sync::{Arc, Mutex}; + +// Type-state markers +pub struct NoProvider; +pub struct WithProvider; +pub struct NoModel; +pub struct WithModel; +pub struct NoRole; +pub struct WithRole; + +/// WorkerBuilder with type-state pattern +/// +/// This ensures at compile-time that all required fields are set. +/// +/// # Example +/// ```no_run +/// use worker::{Worker, LlmProvider, Role}; +/// +/// let role = Role::new("assistant", "AI Assistant", "You are a helpful assistant."); +/// let worker = Worker::builder() +/// .provider(LlmProvider::Claude) +/// .model("claude-3-sonnet-20240229") +/// .api_key("claude", "sk-ant-...") +/// .role(role) +/// .build()?; +/// # Ok::<(), worker::WorkerError>(()) +/// ``` +pub struct WorkerBuilder { + provider: Option, + model_name: Option, + api_keys: HashMap, + + // Role configuration (required) + role: Option, + + // Plugin configuration + plugin_id: Option, + plugin_registry: Option>>, + + _phantom: PhantomData<(P, M, R)>, +} + +impl Default for WorkerBuilder { + fn default() -> Self { + Self { + provider: None, + model_name: None, + api_keys: HashMap::new(), + role: None, + plugin_id: None, + plugin_registry: None, + _phantom: PhantomData, + } + } +} + +impl WorkerBuilder { + /// Create a new WorkerBuilder + pub fn new() -> Self { + Self::default() + } +} + +// Step 1: Set provider +impl WorkerBuilder { + pub fn provider(mut self, provider: LlmProvider) -> WorkerBuilder { + self.provider = Some(provider); + WorkerBuilder { + provider: self.provider, + model_name: self.model_name, + api_keys: self.api_keys, + role: self.role, + plugin_id: self.plugin_id, + plugin_registry: self.plugin_registry, + _phantom: PhantomData, + } + } + + /// Use a plugin provider instead of built-in provider + pub fn plugin( + mut self, + plugin_id: impl Into, + registry: Arc>, + ) -> WorkerBuilder { + self.plugin_id = Some(plugin_id.into()); + self.plugin_registry = Some(registry); + WorkerBuilder { + provider: None, + model_name: self.model_name, + api_keys: self.api_keys, + role: self.role, + plugin_id: self.plugin_id, + plugin_registry: self.plugin_registry, + _phantom: PhantomData, + } + } +} + +// Step 2: Set model +impl WorkerBuilder { + pub fn model(mut self, model_name: impl Into) -> WorkerBuilder { + self.model_name = Some(model_name.into()); + WorkerBuilder { + provider: self.provider, + model_name: self.model_name, + api_keys: self.api_keys, + role: self.role, + plugin_id: self.plugin_id, + plugin_registry: self.plugin_registry, + _phantom: PhantomData, + } + } +} + +// Step 3: Set role +impl WorkerBuilder { + pub fn role(mut self, role: Role) -> WorkerBuilder { + self.role = Some(role); + WorkerBuilder { + provider: self.provider, + model_name: self.model_name, + api_keys: self.api_keys, + role: self.role, + plugin_id: self.plugin_id, + plugin_registry: self.plugin_registry, + _phantom: PhantomData, + } + } +} + +// Optional configurations (available at any stage) +impl WorkerBuilder { + /// Add API key for a provider + pub fn api_key(mut self, provider: impl Into, key: impl Into) -> Self { + self.api_keys.insert(provider.into(), key.into()); + self + } + + /// Set multiple API keys at once + pub fn api_keys(mut self, keys: HashMap) -> Self { + self.api_keys = keys; + self + } +} + +// Build +impl WorkerBuilder { + pub fn build(self) -> Result { + use crate::{LlmProviderExt, WorkspaceDetector, PromptComposer, plugin}; + + let role = self.role.unwrap(); + let model_name = self.model_name.unwrap(); + + // Plugin provider + if let (Some(plugin_id), Some(plugin_registry)) = (self.plugin_id, self.plugin_registry) { + let api_key_opt = self.api_keys.get("__plugin__").or_else(|| { + self.api_keys.values().next() + }); + + let registry = plugin_registry.lock() + .map_err(|e| WorkerError::config(format!("Failed to lock plugin registry: {}", e)))?; + + let plugin = registry.get(&plugin_id) + .ok_or_else(|| WorkerError::config(format!("Plugin not found: {}", plugin_id)))?; + + let llm_client = plugin::PluginClient::new( + plugin.clone(), + &model_name, + api_key_opt.map(|s| s.as_str()), + None, + )?; + + let provider_str = plugin_id.clone(); + let api_key = api_key_opt.map(|s| s.to_string()).unwrap_or_default(); + + let workspace_context = WorkspaceDetector::detect_workspace().ok(); + + let prompt_context = Worker::create_prompt_context_static( + &workspace_context, + worker_types::LlmProvider::OpenAI, + &model_name, + &[], + ); + + tracing::info!("Creating worker with plugin and role: {}", role.name); + + let composer = PromptComposer::from_config(role.clone(), prompt_context) + .map_err(|e| WorkerError::config(e.to_string()))?; + + drop(registry); + + let mut worker = Worker { + llm_client: Box::new(llm_client), + composer, + tools: Vec::new(), + api_key, + provider_str, + model_name, + role, + workspace_context, + message_history: Vec::new(), + hook_manager: worker_types::HookManager::new(), + mcp_lazy_configs: Vec::new(), + plugin_registry: plugin_registry.clone(), + }; + + worker.initialize_session() + .map_err(|e| WorkerError::config(e.to_string()))?; + + return Ok(worker); + } + + // Standard provider + let provider = self.provider.unwrap(); + let provider_str = provider.as_str(); + let api_key = self.api_keys.get(provider_str).cloned().unwrap_or_default(); + let llm_client = provider.create_client(&model_name, &api_key)?; + let plugin_registry = Arc::new(Mutex::new(plugin::PluginRegistry::new())); + + let workspace_context = WorkspaceDetector::detect_workspace().ok(); + + let prompt_context = Worker::create_prompt_context_static( + &workspace_context, + provider.clone(), + &model_name, + &[], + ); + + tracing::info!("Creating worker with role: {}", role.name); + + let composer = PromptComposer::from_config(role.clone(), prompt_context) + .map_err(|e| WorkerError::config(e.to_string()))?; + + let mut worker = Worker { + llm_client: Box::new(llm_client), + composer, + tools: Vec::new(), + api_key, + provider_str: provider_str.to_string(), + model_name, + role, + workspace_context, + message_history: Vec::new(), + hook_manager: worker_types::HookManager::new(), + mcp_lazy_configs: Vec::new(), + plugin_registry, + }; + + worker.initialize_session() + .map_err(|e| WorkerError::config(e.to_string()))?; + + Ok(worker) + } +} diff --git a/worker/src/client.rs b/worker/src/client.rs new file mode 100644 index 0000000..1d035c6 --- /dev/null +++ b/worker/src/client.rs @@ -0,0 +1,86 @@ +// LLM client wrapper that provides a unified interface for all provider clients + +use crate::core::LlmClientTrait; +use crate::llm::{ + anthropic::AnthropicClient, gemini::GeminiClient, ollama::OllamaClient, openai::OpenAIClient, + xai::XAIClient, +}; +use crate::types::WorkerError; +use worker_types::{LlmProvider, Message, StreamEvent}; +use futures_util::Stream; + +// LlmClient enumを使用してdyn互換性の問題を解決 +pub enum LlmClient { + Anthropic(AnthropicClient), + Gemini(GeminiClient), + Ollama(OllamaClient), + OpenAI(OpenAIClient), + XAI(XAIClient), +} + +// LlmClient enumに対するメソッド実装を削除 +// 代わりにLlmClientTraitの実装のみを使用 + +// 委譲マクロでボイラープレートを削減 +macro_rules! delegate_to_client { + // 引数ありの場合 + ($self:expr, $method:ident, $($arg:expr),+) => { + match $self { + LlmClient::Anthropic(client) => client.$method($($arg),*), + LlmClient::Gemini(client) => client.$method($($arg),*), + LlmClient::Ollama(client) => client.$method($($arg),*), + LlmClient::OpenAI(client) => client.$method($($arg),*), + LlmClient::XAI(client) => client.$method($($arg),*), + } + }; + // 引数なしの場合 + ($self:expr, $method:ident) => { + match $self { + LlmClient::Anthropic(client) => client.$method(), + LlmClient::Gemini(client) => client.$method(), + LlmClient::Ollama(client) => client.$method(), + LlmClient::OpenAI(client) => client.$method(), + LlmClient::XAI(client) => client.$method(), + } + }; +} + +// LlmClient enum にtraitを実装 - マクロで簡潔な委譲 +#[async_trait::async_trait] +impl LlmClientTrait for LlmClient { + async fn chat_stream<'a>( + &'a self, + messages: Vec, + tools: Option<&[crate::types::DynamicToolDefinition]>, + llm_debug: Option, + ) -> Result< + Box> + Unpin + Send + 'a>, + WorkerError, + > { + match self { + LlmClient::Anthropic(client) => client.chat_stream(messages, tools, llm_debug).await, + LlmClient::Gemini(client) => client.chat_stream(messages, tools, llm_debug).await, + LlmClient::Ollama(client) => client.chat_stream(messages, tools, llm_debug).await, + LlmClient::OpenAI(client) => client.chat_stream(messages, tools, llm_debug).await, + LlmClient::XAI(client) => client.chat_stream(messages, tools, llm_debug).await, + } + } + + async fn check_connection(&self) -> Result<(), WorkerError> { + match self { + LlmClient::Anthropic(client) => client.check_connection().await, + LlmClient::Gemini(client) => client.check_connection().await, + LlmClient::Ollama(client) => client.check_connection().await, + LlmClient::OpenAI(client) => client.check_connection().await, + LlmClient::XAI(client) => client.check_connection().await, + } + } + + fn provider(&self) -> LlmProvider { + delegate_to_client!(self, provider) + } + + fn get_model_name(&self) -> String { + delegate_to_client!(self, get_model_name) + } +} diff --git a/worker/src/config/mod.rs b/worker/src/config/mod.rs new file mode 100644 index 0000000..d5d3926 --- /dev/null +++ b/worker/src/config/mod.rs @@ -0,0 +1,5 @@ +mod parser; +mod url; + +pub use parser::ConfigParser; +pub use url::UrlConfig; diff --git a/worker/src/config_parser.rs b/worker/src/config/parser.rs similarity index 83% rename from worker/src/config_parser.rs rename to worker/src/config/parser.rs index 7fd145d..df9b5bd 100644 --- a/worker/src/config_parser.rs +++ b/worker/src/config/parser.rs @@ -1,4 +1,4 @@ -use crate::prompt_types::*; +use crate::prompt::{PromptError, Role}; use std::fs; use std::path::Path; @@ -7,7 +7,7 @@ pub struct ConfigParser; impl ConfigParser { /// YAML設定ファイルを読み込んでパースする - pub fn parse_from_file>(path: P) -> Result { + pub fn parse_from_file>(path: P) -> Result { let content = fs::read_to_string(path.as_ref()).map_err(|e| { PromptError::FileNotFound(format!("{}: {}", path.as_ref().display(), e)) })?; @@ -15,9 +15,9 @@ impl ConfigParser { Self::parse_from_string(&content) } - /// YAML文字列をパースしてPromptRoleConfigに変換する - pub fn parse_from_string(content: &str) -> Result { - let config: PromptRoleConfig = serde_yaml::from_str(content)?; + /// YAML文字列をパースしてRoleに変換する + pub fn parse_from_string(content: &str) -> Result { + let config: Role = serde_yaml::from_str(content)?; // 基本的なバリデーション Self::validate_config(&config)?; @@ -26,7 +26,7 @@ impl ConfigParser { } /// 設定ファイルの基本的なバリデーション - fn validate_config(config: &PromptRoleConfig) -> Result<(), PromptError> { + fn validate_config(config: &Role) -> Result<(), PromptError> { if config.name.is_empty() { return Err(PromptError::VariableResolution( "name field cannot be empty".to_string(), @@ -62,17 +62,19 @@ impl ConfigParser { let project_root = std::env::current_dir() .map_err(|e| PromptError::WorkspaceDetection(e.to_string()))?; - // 優先順位: ./resources > ./nia-cli/resources > ../nia-cli/resources + // 優先順位: ./resources > ./cli/resources > ./nia-core/resources > ./nia-pod/resources let possible_paths = [ project_root.join("resources").join(relative_path), project_root - .join("nia-cli") + .join("cli") .join("resources") .join(relative_path), project_root - .parent() - .unwrap_or(&project_root) - .join("nia-cli") + .join("nia-core") + .join("resources") + .join(relative_path), + project_root + .join("nia-pod") .join("resources") .join(relative_path), ]; diff --git a/worker/src/url_config.rs b/worker/src/config/url.rs similarity index 75% rename from worker/src/url_config.rs rename to worker/src/config/url.rs index e2f4226..1bcdb58 100644 --- a/worker/src/url_config.rs +++ b/worker/src/config/url.rs @@ -86,12 +86,15 @@ mod tests { #[test] fn test_default_urls() { + // SAFETY: Setting test environment variables in a single-threaded test context // Clean up any existing env vars first - env::remove_var("OPENAI_BASE_URL"); - env::remove_var("ANTHROPIC_BASE_URL"); - env::remove_var("GEMINI_BASE_URL"); - env::remove_var("XAI_BASE_URL"); - env::remove_var("OLLAMA_BASE_URL"); + unsafe { + env::remove_var("OPENAI_BASE_URL"); + env::remove_var("ANTHROPIC_BASE_URL"); + env::remove_var("GEMINI_BASE_URL"); + env::remove_var("XAI_BASE_URL"); + env::remove_var("OLLAMA_BASE_URL"); + } assert_eq!(UrlConfig::get_base_url("openai"), "https://api.openai.com"); assert_eq!( @@ -108,12 +111,15 @@ mod tests { #[test] fn test_env_override() { + // SAFETY: Setting test environment variables in a single-threaded test context // Clean up any existing env vars first - env::remove_var("OPENAI_BASE_URL"); - env::remove_var("ANTHROPIC_BASE_URL"); + unsafe { + env::remove_var("OPENAI_BASE_URL"); + env::remove_var("ANTHROPIC_BASE_URL"); - env::set_var("OPENAI_BASE_URL", "https://custom.openai.com"); - env::set_var("ANTHROPIC_BASE_URL", "https://custom.anthropic.com"); + env::set_var("OPENAI_BASE_URL", "https://custom.openai.com"); + env::set_var("ANTHROPIC_BASE_URL", "https://custom.anthropic.com"); + } assert_eq!( UrlConfig::get_base_url("openai"), @@ -125,16 +131,21 @@ mod tests { ); // Clean up - env::remove_var("OPENAI_BASE_URL"); - env::remove_var("ANTHROPIC_BASE_URL"); + unsafe { + env::remove_var("OPENAI_BASE_URL"); + env::remove_var("ANTHROPIC_BASE_URL"); + } } #[test] fn test_models_url() { + // SAFETY: Setting test environment variables in a single-threaded test context // Clean up any existing env vars first - env::remove_var("OPENAI_BASE_URL"); - env::remove_var("ANTHROPIC_BASE_URL"); - env::remove_var("OLLAMA_BASE_URL"); + unsafe { + env::remove_var("OPENAI_BASE_URL"); + env::remove_var("ANTHROPIC_BASE_URL"); + env::remove_var("OLLAMA_BASE_URL"); + } assert_eq!( UrlConfig::get_models_url("openai"), @@ -152,10 +163,13 @@ mod tests { #[test] fn test_completion_url() { + // SAFETY: Setting test environment variables in a single-threaded test context // Clean up any existing env vars first - env::remove_var("OPENAI_BASE_URL"); - env::remove_var("ANTHROPIC_BASE_URL"); - env::remove_var("OLLAMA_BASE_URL"); + unsafe { + env::remove_var("OPENAI_BASE_URL"); + env::remove_var("ANTHROPIC_BASE_URL"); + env::remove_var("OLLAMA_BASE_URL"); + } assert_eq!( UrlConfig::get_completion_url("openai"), @@ -173,19 +187,24 @@ mod tests { #[test] fn test_get_active_overrides() { + // SAFETY: Setting test environment variables in a single-threaded test context // Clean up any existing env vars first - env::remove_var("OPENAI_BASE_URL"); - env::remove_var("ANTHROPIC_BASE_URL"); - env::remove_var("GEMINI_BASE_URL"); - env::remove_var("XAI_BASE_URL"); - env::remove_var("OLLAMA_BASE_URL"); + unsafe { + env::remove_var("OPENAI_BASE_URL"); + env::remove_var("ANTHROPIC_BASE_URL"); + env::remove_var("GEMINI_BASE_URL"); + env::remove_var("XAI_BASE_URL"); + env::remove_var("OLLAMA_BASE_URL"); + } // Should return empty when no overrides are set assert_eq!(UrlConfig::get_active_overrides().len(), 0); // Set some overrides - env::set_var("OPENAI_BASE_URL", "https://custom-openai.example.com"); - env::set_var("ANTHROPIC_BASE_URL", "https://custom-anthropic.example.com"); + unsafe { + env::set_var("OPENAI_BASE_URL", "https://custom-openai.example.com"); + env::set_var("ANTHROPIC_BASE_URL", "https://custom-anthropic.example.com"); + } let overrides = UrlConfig::get_active_overrides(); assert_eq!(overrides.len(), 2); @@ -203,7 +222,9 @@ mod tests { assert_eq!(anthropic_override.1, "https://custom-anthropic.example.com"); // Clean up - env::remove_var("OPENAI_BASE_URL"); - env::remove_var("ANTHROPIC_BASE_URL"); + unsafe { + env::remove_var("OPENAI_BASE_URL"); + env::remove_var("ANTHROPIC_BASE_URL"); + } } } diff --git a/worker/src/core.rs b/worker/src/core.rs new file mode 100644 index 0000000..f28635d --- /dev/null +++ b/worker/src/core.rs @@ -0,0 +1,37 @@ +// Core types and traits for the worker crate +// This module contains the primary abstractions used throughout the crate + +use worker_types::{DynamicToolDefinition, LlmProvider, Message, StreamEvent}; +use crate::types::WorkerError; +use futures_util::Stream; + +/// LlmClient trait - common interface for all LLM clients +/// +/// This trait defines the common interface that all LLM client implementations must provide. +/// It's defined here in the core module so it can be imported early in the module hierarchy, +/// before the specific client implementations. +#[async_trait::async_trait] +pub trait LlmClientTrait: Send + Sync { + /// Send a chat message and get a stream of responses + async fn chat_stream<'a>( + &'a self, + messages: Vec, + tools: Option<&[DynamicToolDefinition]>, + llm_debug: Option, + ) -> Result< + Box> + Unpin + Send + 'a>, + WorkerError, + >; + + /// Check if the connection to the LLM provider is working + async fn check_connection(&self) -> Result<(), WorkerError>; + + /// Get the provider type for this client + fn provider(&self) -> LlmProvider; + + /// Get the model name being used + fn get_model_name(&self) -> String; +} + +// The Worker struct and LlmClient enum are defined in lib.rs +// This file just provides the trait definition that's needed early in the module hierarchy diff --git a/worker/src/lib.rs b/worker/src/lib.rs index b21c081..d09263e 100644 --- a/worker/src/lib.rs +++ b/worker/src/lib.rs @@ -1,6 +1,8 @@ -use crate::prompt_composer::PromptComposer; -use crate::prompt_types::*; -use crate::workspace_detector::WorkspaceDetector; +use crate::prompt::{ + PromptComposer, PromptContext, WorkspaceContext, ModelContext, ModelCapabilities, + SessionContext +}; +use crate::workspace::WorkspaceDetector; use async_stream::stream; use futures_util::{Stream, StreamExt}; use llm::{ @@ -11,40 +13,42 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fs; use std::path::PathBuf; -use thiserror::Error; use tracing; use uuid; pub use worker_types::{ DynamicToolDefinition, HookContext, HookEvent, HookManager, HookResult, LlmDebug, LlmProvider, - LlmResponse, Message, ModelInfo, PartialConfig, PromptComponent, PromptComponentDetail, - PromptConfig, Role, RoleConfig, SessionData, StreamEvent, Task, Tool, ToolCall, ToolResult, + LlmResponse, Message, ModelInfo, SessionData, StreamEvent, Task, Tool, ToolCall, ToolResult, WorkerHook, WorkspaceConfig, WorkspaceData, }; +pub use worker_macros::{hook, tool}; -pub mod config_parser; -pub mod llm; -pub mod mcp_config; -pub mod mcp_protocol; -pub mod mcp_tool; -pub mod prompt_composer; -pub mod prompt_types; +pub mod core; pub mod types; -pub mod url_config; -pub mod workspace_detector; +pub mod client; +pub mod builder; +pub mod config; +pub mod llm; +pub mod mcp; +pub mod plugin; +pub mod prompt; +pub mod workspace; + +pub use core::LlmClientTrait; +pub use client::LlmClient; +pub use crate::prompt::Role; +pub use builder::WorkerBuilder; +pub use crate::types::WorkerError; #[cfg(test)] mod tests { mod config_tests; - // mod integration_tests; // Temporarily disabled due to missing dependencies } -// Re-export for tool macros pub use schemars; pub use serde_json; -// Re-export MCP functionality -pub use mcp_config::{IntegrationMode, McpConfig, McpServerDefinition}; -pub use mcp_tool::{ +pub use mcp::{IntegrationMode, McpConfig, McpServerDefinition}; +pub use mcp::{ McpDynamicTool, McpServerConfig, SingleMcpTool, create_single_mcp_tools, get_mcp_tools_as_definitions, test_mcp_connection, }; @@ -63,7 +67,6 @@ pub fn generate_tools_schema(provider: &LlmProvider, tools: &[Box]) -> ) } -/// ツール定義からスキーマを生成する fn generate_tools_schema_from_definitions( provider: &LlmProvider, tool_definitions: &[DynamicToolDefinition], @@ -115,25 +118,21 @@ fn generate_tools_schema_from_definitions( } } -pub use crate::types::WorkerError; - impl WorkerError { /// Check if this error is likely an authentication/API key error pub fn is_authentication_error(&self) -> bool { - matches!(self, WorkerError::General(_)) + matches!(self, WorkerError::ConfigurationError { .. }) } /// Convert a generic error to a WorkerError, detecting authentication issues pub fn from_api_error(error: String, provider: &LlmProvider) -> Self { if Self::is_likely_auth_error(&error, provider) { - WorkerError::Config(error) + WorkerError::config(error) } else { - WorkerError::Network(error) + WorkerError::network(error) } } - /// Comprehensive authentication error detection - /// Many APIs return 400 Bad Request for invalid API keys instead of proper 401/403 fn is_likely_auth_error(error_msg: &str, provider: &LlmProvider) -> bool { let error_msg = error_msg.to_lowercase(); tracing::debug!( @@ -142,22 +141,18 @@ impl WorkerError { error_msg ); - // Standard auth error codes let has_auth_status = error_msg.contains("unauthorized") || error_msg.contains("forbidden") || error_msg.contains("401") || error_msg.contains("403"); - // API key related error messages let has_api_key_error = error_msg.contains("api key") || error_msg.contains("invalid key") || error_msg.contains("authentication") || error_msg.contains("token"); - // Bad request that might be auth related let has_bad_request = error_msg.contains("400") || error_msg.contains("bad request"); - // Common API key error patterns (case insensitive) let has_key_patterns = error_msg.contains("incorrect api key") || error_msg.contains("invalid api key") || error_msg.contains("api key not found") @@ -171,18 +166,14 @@ impl WorkerError { || error_msg.contains("expired") || error_msg.contains("revoked") || error_msg.contains("suspended") - // Generic "invalid" but exclude credit balance specific messages || (error_msg.contains("invalid") && !error_msg.contains("credit balance")); - // Exclude credit balance errors - these are not authentication errors let is_credit_balance_error = error_msg.contains("credit balance") || error_msg.contains("billing") || error_msg.contains("upgrade") || error_msg.contains("purchase credits") - // Also exclude Anthropic's specific credit balance error pattern || (error_msg.contains("invalid_request_error") && error_msg.contains("credit balance")); - // Provider-specific patterns let has_provider_patterns = match provider { LlmProvider::OpenAI => { error_msg.contains("invalid_api_key") @@ -190,9 +181,7 @@ impl WorkerError { && !error_msg.contains("credit balance")) } LlmProvider::Claude => { - // Anthropic specific auth error patterns (error_msg.contains("invalid_x_api_key") || error_msg.contains("x-api-key")) - // But exclude credit balance issues which are not auth errors && !error_msg.contains("credit balance") && !error_msg.contains("billing") && !error_msg.contains("upgrade") @@ -201,18 +190,15 @@ impl WorkerError { LlmProvider::Gemini => { error_msg.contains("invalid_argument") || error_msg.contains("credentials") } - LlmProvider::Ollama => false, // Ollama typically doesn't have API keys + LlmProvider::Ollama => false, LlmProvider::XAI => { error_msg.contains("invalid_api_key") || error_msg.contains("unauthorized") } }; - // Generic patterns let has_generic_patterns = error_msg.contains("credentials") || error_msg.contains("authorization"); - // Provider-specific bad request handling - // Some providers return 400 for auth issues instead of proper status codes let provider_specific_bad_request = match provider { LlmProvider::OpenAI => { has_bad_request @@ -242,9 +228,6 @@ impl WorkerError { } }; - // If it's a bad request with auth-related keywords, treat as auth error - // This handles APIs that incorrectly return 400 for auth issues - // But exclude credit balance errors which are not authentication issues let result = (has_auth_status || has_api_key_error || has_key_patterns @@ -292,16 +275,10 @@ pub fn get_supported_providers() -> Vec { ] } -/// Validate if a provider name is supported pub fn is_provider_supported(provider_name: &str) -> bool { LlmProvider::from_str(provider_name).is_some() } -/// Validate API key for a specific provider -/// Returns: -/// - Some(true): API key is valid -/// - Some(false): API key is invalid -/// - None: Unable to validate (e.g., no official validation endpoint) pub async fn validate_api_key( provider: LlmProvider, api_key: &str, @@ -314,48 +291,41 @@ pub async fn validate_api_key( return Ok(Some(false)); } - // Only perform validation if provider has a simple, official validation method match provider { LlmProvider::Claude => { - // Anthropic doesn't have a dedicated validation endpoint - // Simple format check: should start with "sk-ant-" if api_key.starts_with("sk-ant-") && api_key.len() > 20 { tracing::debug!("validate_api_key: Anthropic API key format appears valid"); - Ok(None) // Cannot validate without making a request + Ok(None) } else { tracing::debug!("validate_api_key: Anthropic API key format is invalid"); Ok(Some(false)) } } LlmProvider::OpenAI => { - // OpenAI: simple format check if api_key.starts_with("sk-") && api_key.len() > 20 { tracing::debug!("validate_api_key: OpenAI API key format appears valid"); - Ok(None) // Cannot validate without making a request + Ok(None) } else { tracing::debug!("validate_api_key: OpenAI API key format is invalid"); Ok(Some(false)) } } LlmProvider::Gemini => { - // Gemini: simple format check if api_key.len() > 20 { tracing::debug!("validate_api_key: Gemini API key format appears valid"); - Ok(None) // Cannot validate without making a request + Ok(None) } else { tracing::debug!("validate_api_key: Gemini API key format is invalid"); Ok(Some(false)) } } LlmProvider::Ollama => { - // Ollama typically doesn't require API keys Ok(Some(true)) } LlmProvider::XAI => { - // xAI: simple format check if api_key.starts_with("xai-") && api_key.len() > 20 { tracing::debug!("validate_api_key: xAI API key format appears valid"); - Ok(None) // Cannot validate without making a request + Ok(None) } else { tracing::debug!("validate_api_key: xAI API key format is invalid"); Ok(Some(false)) @@ -364,7 +334,6 @@ pub async fn validate_api_key( } } -// Models configuration structures #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ModelsConfig { pub models: Vec, @@ -387,15 +356,13 @@ pub struct ModelMeta { pub description: Option, } -// Get models config path fn get_models_config_path() -> Result { let home_dir = dirs::home_dir().ok_or_else(|| { - WorkerError::ConfigurationError("Could not determine home directory".to_string()) + WorkerError::config("Could not determine home directory") })?; Ok(home_dir.join(".config").join("nia").join("models.yaml")) } -// Load models configuration fn load_models_config() -> Result { let config_path = get_models_config_path()?; @@ -408,26 +375,23 @@ fn load_models_config() -> Result { } let content = fs::read_to_string(&config_path).map_err(|e| { - WorkerError::ConfigurationError(format!("Failed to read models config: {}", e)) + WorkerError::config(format!("Failed to read models config: {}", e)) })?; let config: ModelsConfig = serde_yaml::from_str(&content).map_err(|e| { - WorkerError::ConfigurationError(format!("Failed to parse models config: {}", e)) + WorkerError::config(format!("Failed to parse models config: {}", e)) })?; Ok(config) } -// Tool support detection using configuration pub async fn supports_native_tools( provider: &LlmProvider, model_name: &str, _api_key: &str, ) -> Result { - // Load models configuration let config = load_models_config()?; - // Look for the specific model in configuration let model_id = format!( "{}/{}", match provider { @@ -440,7 +404,6 @@ pub async fn supports_native_tools( model_name ); - // Find model in config and check function_calling setting for model_def in &config.models { if model_def.model == model_id || model_def.model.contains(model_name) { tracing::debug!( @@ -458,8 +421,6 @@ pub async fn supports_native_tools( model_name ); - // Fallback to provider-based detection if model not found in config - // But prioritize setting over provider defaults tracing::warn!( "Using provider-based fallback - this should be configured in models.yaml: provider={:?}, model={}", provider, @@ -470,7 +431,7 @@ pub async fn supports_native_tools( LlmProvider::Claude => true, LlmProvider::OpenAI => !model_name.contains("gpt-3.5-turbo-instruct"), LlmProvider::Gemini => !model_name.contains("gemini-pro-vision"), - LlmProvider::Ollama => false, // Default to XML-based tools for Ollama + LlmProvider::Ollama => false, LlmProvider::XAI => true, }; @@ -483,183 +444,63 @@ pub async fn supports_native_tools( Ok(supports_tools) } -// LlmClient trait - 共通インターフェース -#[async_trait::async_trait] -pub trait LlmClientTrait: Send + Sync { - async fn chat_stream<'a>( - &'a self, - messages: Vec, - tools: Option<&[crate::types::DynamicToolDefinition]>, - llm_debug: Option, - ) -> Result< - Box> + Unpin + Send + 'a>, - WorkerError, - >; - - async fn check_connection(&self) -> Result<(), WorkerError>; - - fn provider(&self) -> LlmProvider; - - fn get_model_name(&self) -> String; -} - -// LlmClient enumを使用してdyn互換性の問題を解決 -pub enum LlmClient { - Anthropic(AnthropicClient), - Gemini(GeminiClient), - Ollama(OllamaClient), - OpenAI(OpenAIClient), - XAI(XAIClient), -} - -// LlmClient enumに対するメソッド実装を削除 -// 代わりにLlmClientTraitの実装のみを使用 - -// 委譲マクロでボイラープレートを削減 -macro_rules! delegate_to_client { - // 引数ありの場合 - ($self:expr, $method:ident, $($arg:expr),+) => { - match $self { - LlmClient::Anthropic(client) => client.$method($($arg),*), - LlmClient::Gemini(client) => client.$method($($arg),*), - LlmClient::Ollama(client) => client.$method($($arg),*), - LlmClient::OpenAI(client) => client.$method($($arg),*), - LlmClient::XAI(client) => client.$method($($arg),*), - } - }; - // 引数なしの場合 - ($self:expr, $method:ident) => { - match $self { - LlmClient::Anthropic(client) => client.$method(), - LlmClient::Gemini(client) => client.$method(), - LlmClient::Ollama(client) => client.$method(), - LlmClient::OpenAI(client) => client.$method(), - LlmClient::XAI(client) => client.$method(), - } - }; -} - -// LlmClient enum にtraitを実装 - マクロで簡潔な委譲 -#[async_trait::async_trait] -impl LlmClientTrait for LlmClient { - async fn chat_stream<'a>( - &'a self, - messages: Vec, - tools: Option<&[crate::types::DynamicToolDefinition]>, - llm_debug: Option, - ) -> Result< - Box> + Unpin + Send + 'a>, - WorkerError, - > { - match self { - LlmClient::Anthropic(client) => client.chat_stream(messages, tools, llm_debug).await, - LlmClient::Gemini(client) => client.chat_stream(messages, tools, llm_debug).await, - LlmClient::Ollama(client) => client.chat_stream(messages, tools, llm_debug).await, - LlmClient::OpenAI(client) => client.chat_stream(messages, tools, llm_debug).await, - LlmClient::XAI(client) => client.chat_stream(messages, tools, llm_debug).await, - } - } - - async fn check_connection(&self) -> Result<(), WorkerError> { - match self { - LlmClient::Anthropic(client) => client.check_connection().await, - LlmClient::Gemini(client) => client.check_connection().await, - LlmClient::Ollama(client) => client.check_connection().await, - LlmClient::OpenAI(client) => client.check_connection().await, - LlmClient::XAI(client) => client.check_connection().await, - } - } - - fn provider(&self) -> LlmProvider { - delegate_to_client!(self, provider) - } - - fn get_model_name(&self) -> String { - delegate_to_client!(self, get_model_name) - } -} - pub struct Worker { - llm_client: Box, - composer: PromptComposer, - tools: Vec>, - api_key: String, - provider_str: String, - model_name: String, - role_config: Option, - config: Option, - workspace_context: Option, - message_history: Vec, - hook_manager: crate::types::HookManager, - mcp_lazy_configs: Vec, + pub(crate) llm_client: Box, + pub(crate) composer: PromptComposer, + pub(crate) tools: Vec>, + pub(crate) api_key: String, + pub(crate) provider_str: String, + pub(crate) model_name: String, + pub(crate) role: Role, + pub(crate) workspace_context: Option, + pub(crate) message_history: Vec, + pub(crate) hook_manager: crate::types::HookManager, + pub(crate) mcp_lazy_configs: Vec, + pub(crate) plugin_registry: std::sync::Arc>, } impl Worker { - pub fn new( - provider: LlmProvider, - model_name: &str, - api_keys: &HashMap, - role_config: Option, - ) -> Result { - let provider_str = provider.as_str(); - let api_key = api_keys.get(provider_str).cloned().unwrap_or_default(); - let llm_client = provider.create_client(model_name, &api_key)?; - - // ワークスペースコンテキストを取得 - let workspace_context = WorkspaceDetector::detect_workspace().ok(); - - // プロンプトコンテキストを作成 - let prompt_context = Self::create_prompt_context_static( - &workspace_context, - provider.clone(), - model_name, - &[], - ); - - // デフォルト設定またはcli-assistant設定を使用 - let composer = match Self::try_load_default_config(prompt_context.clone()) { - Ok(composer) => { - tracing::info!("Loaded default CLI assistant configuration"); - composer - } - Err(e) => { - tracing::warn!("Failed to load default config, using fallback: {}", e); - let default_config = PromptRoleConfig::default(); - PromptComposer::from_config(default_config, prompt_context) - .map_err(|e| WorkerError::ConfigurationError(e.to_string()))? - } - }; - - let mut worker = Self { - llm_client: Box::new(llm_client), - composer, - tools: Vec::new(), - api_key, - provider_str: provider_str.to_string(), - model_name: model_name.to_string(), - role_config, - config: None, - workspace_context, - message_history: Vec::new(), - hook_manager: crate::types::HookManager::new(), - mcp_lazy_configs: Vec::new(), - }; - - // セッション開始時のシステムプロンプト初期化 - worker - .initialize_session() - .map_err(|e| WorkerError::ConfigurationError(e.to_string()))?; - - Ok(worker) + /// Create a new WorkerBuilder + /// + /// # Example + /// ```no_run + /// use worker::{Worker, LlmProvider, Role}; + /// + /// let role = Role::new("assistant", "AI Assistant", "You are a helpful assistant."); + /// let worker = Worker::builder() + /// .provider(LlmProvider::Claude) + /// .model("claude-3-sonnet-20240229") + /// .api_key("claude", "sk-ant-...") + /// .role(role) + /// .build()?; + /// # Ok::<(), worker::WorkerError>(()) + /// ``` + pub fn builder() -> builder::WorkerBuilder { + builder::WorkerBuilder::new() } - /// ツールリストをロードする - pub fn load_tools(&mut self, tools: Vec>) -> Result<(), WorkerError> { - self.tools.extend(tools); - tracing::info!("Loaded {} tools", self.tools.len()); + + /// Load plugins from a directory + #[cfg(feature = "dynamic-loading")] + pub async fn load_plugins_from_directory(&mut self, dir: &std::path::Path) -> Result<(), WorkerError> { + let plugins = plugin::PluginLoader::load_from_directory(dir)?; + let mut registry = self.plugin_registry.lock() + .map_err(|e| WorkerError::config(format!("Failed to lock plugin registry: {}", e)))?; + + for plugin in plugins { + registry.register(std::sync::Arc::from(plugin))?; + } + Ok(()) } + /// List all registered plugins + pub fn list_plugins(&self) -> Result, WorkerError> { + let registry = self.plugin_registry.lock() + .map_err(|e| WorkerError::config(format!("Failed to lock plugin registry: {}", e)))?; + Ok(registry.list()) + } + /// フックを登録する pub fn register_hook(&mut self, hook: Box) { let hook_name = hook.name().to_string(); @@ -676,9 +517,9 @@ impl Worker { /// MCPサーバーをツールとして登録する pub fn register_mcp_server( &mut self, - config: mcp_tool::McpServerConfig, + config: McpServerConfig, ) -> Result<(), WorkerError> { - let mcp_tool = mcp_tool::McpDynamicTool::new(config.clone()); + let mcp_tool = McpDynamicTool::new(config.clone()); self.register_tool(Box::new(mcp_tool))?; tracing::info!("Registered MCP server as tool: {}", config.name); Ok(()) @@ -687,9 +528,9 @@ impl Worker { /// MCPサーバーから個別のツールを登録する pub async fn register_mcp_tools( &mut self, - config: mcp_tool::McpServerConfig, + config: McpServerConfig, ) -> Result<(), WorkerError> { - let tools = mcp_tool::create_single_mcp_tools(&config).await?; + let tools = create_single_mcp_tools(&config).await?; let tool_count = tools.len(); for tool in tools { @@ -705,7 +546,7 @@ impl Worker { } /// MCP サーバーを並列初期化キューに追加 - pub fn queue_mcp_server(&mut self, config: mcp_tool::McpServerConfig) { + pub fn queue_mcp_server(&mut self, config: McpServerConfig) { tracing::info!("Queuing MCP server: {}", config.name); self.mcp_lazy_configs.push(config); } @@ -735,7 +576,7 @@ impl Worker { tokio::spawn(async move { tracing::info!("Parallel initializing MCP server: {}", config_name); - match mcp_tool::create_single_mcp_tools(&config).await { + match create_single_mcp_tools(&config).await { Ok(tools) => { tracing::info!( "Successfully initialized {} tools from MCP server: {}", @@ -940,25 +781,25 @@ impl Worker { &mut self, config_path: P, ) -> Result<(), WorkerError> { - use crate::config_parser::ConfigParser; + use crate::config::ConfigParser; // 設定ファイルを読み込み - let config = ConfigParser::parse_from_file(config_path) - .map_err(|e| WorkerError::ConfigurationError(e.to_string()))?; + let role = ConfigParser::parse_from_file(config_path) + .map_err(|e| WorkerError::config(e.to_string()))?; // プロンプトコンテキストを構築 let prompt_context = self.create_prompt_context()?; // DynamicPromptComposerを作成 - let composer = PromptComposer::from_config(config.clone(), prompt_context) - .map_err(|e| WorkerError::ConfigurationError(e.to_string()))?; + let composer = PromptComposer::from_config(role.clone(), prompt_context) + .map_err(|e| WorkerError::config(e.to_string()))?; - self.config = Some(config); + self.role = role; self.composer = composer; // 設定変更後にセッション再初期化 self.initialize_session() - .map_err(|e| WorkerError::ConfigurationError(e.to_string()))?; + .map_err(|e| WorkerError::config(e.to_string()))?; tracing::info!("Dynamic configuration loaded successfully"); Ok(()) @@ -1014,7 +855,7 @@ impl Worker { /// プロンプトコンテキストを作成 fn create_prompt_context(&self) -> Result { let provider = LlmProvider::from_str(&self.provider_str).ok_or_else(|| { - WorkerError::ConfigurationError(format!("Unknown provider: {}", self.provider_str)) + WorkerError::config(format!("Unknown provider: {}", self.provider_str)) })?; // モデル能力を静的に判定 @@ -1084,10 +925,10 @@ impl Worker { pub fn register_tool(&mut self, tool: Box) -> Result<(), WorkerError> { // 同名のツールが既に存在するかチェック if self.tools.iter().any(|t| t.name() == tool.name()) { - return Err(WorkerError::ToolExecutionError(format!( - "Tool '{}' is already registered", - tool.name() - ))); + return Err(WorkerError::tool_execution( + tool.name(), + format!("Tool '{}' is already registered", tool.name()) + )); } self.tools.push(tool); @@ -1118,10 +959,10 @@ impl Worker { ) -> Result { match self.tools.iter().find(|tool| tool.name() == tool_name) { Some(tool) => tool.execute(args).await.map_err(WorkerError::from), - None => Err(WorkerError::ToolExecutionError(format!( - "Tool '{}' not found", - tool_name - ))), + None => Err(WorkerError::tool_execution( + tool_name, + format!("Tool '{}' not found", tool_name) + )), } } @@ -1149,12 +990,12 @@ impl Worker { } /// Get configuration information for task delegation - pub fn get_config(&self) -> (LlmProvider, &str, &str, &Option) { + pub fn get_config(&self) -> (LlmProvider, &str, &str, &Role) { ( self.llm_client.provider(), &self.model_name, &self.api_key, - &self.role_config, + &self.role, ) } @@ -1267,7 +1108,7 @@ impl Worker { let messages = match composer.compose(&conversation_messages) { Ok(m) => m, Err(e) => { - yield Err(WorkerError::ConfigurationError(e.to_string())); + yield Err(WorkerError::config(e.to_string())); return; } }; @@ -1278,7 +1119,7 @@ impl Worker { let messages = match composer.compose_with_tools(&conversation_messages, &tools_schema) { Ok(m) => m, Err(e) => { - yield Err(WorkerError::ConfigurationError(e.to_string())); + yield Err(WorkerError::config(e.to_string())); return; } }; @@ -1466,7 +1307,7 @@ impl Worker { let messages = match self.composer.compose(&conversation_messages) { Ok(m) => m, Err(e) => { - yield Err(WorkerError::ConfigurationError(e.to_string())); + yield Err(WorkerError::config(e.to_string())); return; } }; @@ -1477,7 +1318,7 @@ impl Worker { let messages = match self.composer.compose_with_tools(&conversation_messages, &tools_schema) { Ok(m) => m, Err(e) => { - yield Err(WorkerError::ConfigurationError(e.to_string())); + yield Err(WorkerError::config(e.to_string())); return; } }; @@ -1692,7 +1533,7 @@ impl Worker { let messages = match self.composer.compose(&conversation_messages) { Ok(m) => m, Err(e) => { - yield Err(WorkerError::ConfigurationError(e.to_string())); + yield Err(WorkerError::config(e.to_string())); return; } }; @@ -1703,7 +1544,7 @@ impl Worker { let messages = match self.composer.compose_with_tools(&conversation_messages, &tools_schema) { Ok(m) => m, Err(e) => { - yield Err(WorkerError::ConfigurationError(e.to_string())); + yield Err(WorkerError::config(e.to_string())); return; } }; @@ -1876,7 +1717,7 @@ impl Worker { pub fn get_session_data(&self) -> Result { let workspace_path = std::env::current_dir() .map_err(|e| { - WorkerError::ConfigurationError(format!("Failed to get current directory: {}", e)) + WorkerError::config(format!("Failed to get current directory: {}", e)) })? .to_string_lossy() .to_string(); @@ -1918,7 +1759,7 @@ impl Worker { // セッション復元時にプロンプトコンポーザーを再初期化 self.reinitialize_session_with_history() - .map_err(|e| WorkerError::ConfigurationError(e.to_string()))?; + .map_err(|e| WorkerError::config(e.to_string()))?; Ok(()) } @@ -1970,61 +1811,9 @@ impl Worker { } } - /// デフォルト設定ファイルの読み込みを試行 - fn try_load_default_config( - prompt_context: PromptContext, - ) -> Result { - use crate::config_parser::ConfigParser; - - // デフォルト設定ファイルのパスを試行 - let possible_paths = [ - "#nia/config/roles/cli-assistant.yaml", - "./resources/config/roles/cli-assistant.yaml", - "./nia-cli/resources/config/roles/cli-assistant.yaml", - "../nia-cli/resources/config/roles/cli-assistant.yaml", - ]; - - for path in &possible_paths { - if let Ok(resolved_path) = ConfigParser::resolve_path(path) { - if resolved_path.exists() { - match ConfigParser::parse_from_file(&resolved_path) { - Ok(config) => { - match PromptComposer::from_config(config, prompt_context.clone()) { - Ok(composer) => { - tracing::info!( - "Successfully loaded config from: {}", - resolved_path.display() - ); - return Ok(composer); - } - Err(e) => { - tracing::warn!( - "Failed to create composer from {}: {}", - resolved_path.display(), - e - ); - } - } - } - Err(e) => { - tracing::warn!( - "Failed to parse config from {}: {}", - resolved_path.display(), - e - ); - } - } - } - } - } - - Err(WorkerError::ConfigurationError( - "No default configuration found".to_string(), - )) - } /// セッション初期化(Worker内部用) - fn initialize_session(&mut self) -> Result<(), crate::prompt_types::PromptError> { + fn initialize_session(&mut self) -> Result<(), crate::prompt::PromptError> { // 空のメッセージでセッション初期化 self.composer.initialize_session(&[]) } @@ -2032,7 +1821,7 @@ impl Worker { /// 履歴付きセッション再初期化(Worker内部用) fn reinitialize_session_with_history( &mut self, - ) -> Result<(), crate::prompt_types::PromptError> { + ) -> Result<(), crate::prompt::PromptError> { // 現在の履歴を使ってセッション初期化 self.composer.initialize_session(&self.message_history) } diff --git a/worker/src/llm/anthropic.rs b/worker/src/llm/anthropic.rs index 2f68b3b..a6b1a88 100644 --- a/worker/src/llm/anthropic.rs +++ b/worker/src/llm/anthropic.rs @@ -1,8 +1,7 @@ -use crate::{ - LlmClientTrait, WorkerError, - types::{LlmProvider, Message, Role, StreamEvent, ToolCall}, - url_config::UrlConfig, -}; +use crate::core::LlmClientTrait; +use crate::types::WorkerError; +use worker_types::{LlmProvider, Message, Role, StreamEvent, ToolCall}; +use crate::config::UrlConfig; use async_stream::stream; use futures_util::{Stream, StreamExt}; use reqwest::Client; diff --git a/worker/src/llm/gemini.rs b/worker/src/llm/gemini.rs index 47967b3..e8a890e 100644 --- a/worker/src/llm/gemini.rs +++ b/worker/src/llm/gemini.rs @@ -1,8 +1,7 @@ -use crate::{ - LlmClientTrait, WorkerError, - types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}, - url_config::UrlConfig, -}; +use crate::core::LlmClientTrait; +use crate::types::WorkerError; +use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}; +use crate::config::UrlConfig; use futures_util::{Stream, StreamExt, TryStreamExt}; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -708,7 +707,7 @@ impl GeminiClient { }; let stream = stream_events(&self.api_key, &self.model, request, llm_debug) - .map_err(|e| WorkerError::LlmApiError(e.to_string())); + .map_err(|e| WorkerError::llm_api("gemini", e.to_string())); Ok(Box::new(Box::pin(stream))) } @@ -805,7 +804,7 @@ impl GeminiClient { // Simple connection check - try to call the API // For now, just return OK if model is not empty if self.model.is_empty() { - return Err(WorkerError::ModelNotFound("No model specified".to_string())); + return Err(WorkerError::model_not_found("gemini", "No model specified")); } Ok(()) } diff --git a/worker/src/llm/ollama.rs b/worker/src/llm/ollama.rs index 222617b..53810da 100644 --- a/worker/src/llm/ollama.rs +++ b/worker/src/llm/ollama.rs @@ -1,8 +1,7 @@ -use crate::{ - LlmClientTrait, WorkerError, - types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}, - url_config::UrlConfig, -}; +use crate::core::LlmClientTrait; +use crate::types::WorkerError; +use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}; +use crate::config::UrlConfig; use futures_util::{Stream, StreamExt}; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -345,7 +344,7 @@ impl OllamaClient { .and_then(|models| models.as_array()) .ok_or_else(|| { tracing::error!("Invalid Ollama models response format - missing 'models' array"); - WorkerError::LlmApiError("Invalid models response format".to_string()) + WorkerError::llm_api("ollama", "Invalid models response format") })? .iter() .filter_map(|model| { @@ -671,7 +670,7 @@ impl OllamaClient { self.add_auth_header(client.get(&url)) .send() .await - .map_err(|e| WorkerError::LlmApiError(format!("Failed to connect to Ollama: {}", e)))?; + .map_err(|e| WorkerError::llm_api("ollama", format!("Failed to connect to Ollama: {}", e)))?; Ok(()) } } diff --git a/worker/src/llm/openai.rs b/worker/src/llm/openai.rs index 5eae402..e77df62 100644 --- a/worker/src/llm/openai.rs +++ b/worker/src/llm/openai.rs @@ -1,8 +1,7 @@ -use crate::{ - LlmClientTrait, WorkerError, - types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}, - url_config::UrlConfig, -}; +use crate::core::LlmClientTrait; +use crate::types::WorkerError; +use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}; +use crate::config::UrlConfig; use futures_util::{Stream, StreamExt}; use reqwest::Client; use serde::{Deserialize, Serialize}; diff --git a/worker/src/llm/xai.rs b/worker/src/llm/xai.rs index 0bd5e74..fc6239c 100644 --- a/worker/src/llm/xai.rs +++ b/worker/src/llm/xai.rs @@ -1,8 +1,7 @@ -use crate::{ - LlmClientTrait, WorkerError, - types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}, - url_config::UrlConfig, -}; +use crate::core::LlmClientTrait; +use crate::types::WorkerError; +use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}; +use crate::config::UrlConfig; use futures_util::{Stream, StreamExt}; use reqwest::Client; use serde::{Deserialize, Serialize}; diff --git a/worker/src/mcp_config.rs b/worker/src/mcp/config.rs similarity index 95% rename from worker/src/mcp_config.rs rename to worker/src/mcp/config.rs index 3056f05..db11479 100644 --- a/worker/src/mcp_config.rs +++ b/worker/src/mcp/config.rs @@ -1,5 +1,5 @@ -use crate::WorkerError; -use crate::mcp_tool::McpServerConfig; +use crate::types::WorkerError; +use super::tool::McpServerConfig; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::path::Path; @@ -71,14 +71,14 @@ impl McpConfig { info!("Loading MCP config from: {:?}", path); let content = std::fs::read_to_string(path).map_err(|e| { - WorkerError::ConfigurationError(format!( + WorkerError::config(format!( "Failed to read MCP config file {:?}: {}", path, e )) })?; let config: McpConfig = serde_yaml::from_str(&content).map_err(|e| { - WorkerError::ConfigurationError(format!( + WorkerError::config(format!( "Failed to parse MCP config file {:?}: {}", path, e )) @@ -95,7 +95,7 @@ impl McpConfig { // ディレクトリが存在しない場合は作成 if let Some(parent) = path.parent() { std::fs::create_dir_all(parent).map_err(|e| { - WorkerError::ConfigurationError(format!( + WorkerError::config(format!( "Failed to create config directory {:?}: {}", parent, e )) @@ -103,11 +103,11 @@ impl McpConfig { } let content = serde_yaml::to_string(self).map_err(|e| { - WorkerError::ConfigurationError(format!("Failed to serialize MCP config: {}", e)) + WorkerError::config(format!("Failed to serialize MCP config: {}", e)) })?; std::fs::write(path, content).map_err(|e| { - WorkerError::ConfigurationError(format!( + WorkerError::config(format!( "Failed to write MCP config file {:?}: {}", path, e )) @@ -225,7 +225,7 @@ fn expand_environment_variables(input: &str) -> Result { // ${VAR_NAME} パターンを検索して置換 let re = regex::Regex::new(r"\$\{([^}]+)\}") - .map_err(|e| WorkerError::ConfigurationError(format!("Regex error: {}", e)))?; + .map_err(|e| WorkerError::config(format!("Regex error: {}", e)))?; for caps in re.captures_iter(input) { let full_match = &caps[0]; @@ -300,7 +300,10 @@ servers: #[test] fn test_environment_variable_expansion() { - std::env::set_var("TEST_VAR", "test_value"); + // SAFETY: Setting test environment variables in a single-threaded test context + unsafe { + std::env::set_var("TEST_VAR", "test_value"); + } let result = expand_environment_variables("prefix_${TEST_VAR}_suffix").unwrap(); assert_eq!(result, "prefix_test_value_suffix"); diff --git a/worker/src/mcp/mod.rs b/worker/src/mcp/mod.rs new file mode 100644 index 0000000..420e3d0 --- /dev/null +++ b/worker/src/mcp/mod.rs @@ -0,0 +1,10 @@ +mod config; +mod protocol; +mod tool; + +pub use config::{IntegrationMode, McpConfig, McpServerDefinition}; +pub use protocol::McpClient; +pub use tool::{ + create_single_mcp_tools, get_mcp_tools_as_definitions, test_mcp_connection, McpDynamicTool, + McpServerConfig, SingleMcpTool, +}; diff --git a/worker/src/mcp_protocol.rs b/worker/src/mcp/protocol.rs similarity index 100% rename from worker/src/mcp_protocol.rs rename to worker/src/mcp/protocol.rs diff --git a/worker/src/mcp_tool.rs b/worker/src/mcp/tool.rs similarity index 88% rename from worker/src/mcp_tool.rs rename to worker/src/mcp/tool.rs index 76aff28..8d2b85b 100644 --- a/worker/src/mcp_tool.rs +++ b/worker/src/mcp/tool.rs @@ -1,4 +1,4 @@ -use crate::mcp_protocol::{CallToolResult, McpClient, McpToolDefinition}; +use super::protocol::{CallToolResult, McpClient, McpToolDefinition}; use crate::types::{Tool, ToolResult}; use async_trait::async_trait; use serde_json::Value; @@ -92,10 +92,11 @@ impl McpDynamicTool { .connect(self.config.command.clone(), self.config.args.clone()) .await .map_err(|e| { - crate::WorkerError::ToolExecutionError(format!( - "Failed to connect to MCP server '{}': {}", - self.config.name, e - )) + crate::WorkerError::tool_execution_with_source( + &self.config.name, + format!("Failed to connect to MCP server '{}'", self.config.name), + e, + ) })?; *client_guard = Some(mcp_client); @@ -111,14 +112,15 @@ impl McpDynamicTool { let mut client_guard = self.client.lock().await; let client = client_guard.as_mut().ok_or_else(|| { - crate::WorkerError::ToolExecutionError("MCP client not connected".to_string()) + crate::WorkerError::tool_execution(&self.config.name, "MCP client not connected") })?; let tools = client.list_tools().await.map_err(|e| { - crate::WorkerError::ToolExecutionError(format!( - "Failed to list tools from MCP server '{}': {}", - self.config.name, e - )) + crate::WorkerError::tool_execution_with_source( + &self.config.name, + format!("Failed to list tools from MCP server '{}'", self.config.name), + e, + ) })?; debug!( @@ -166,16 +168,17 @@ impl McpDynamicTool { let mut client_guard = self.client.lock().await; let client = client_guard.as_mut().ok_or_else(|| { - crate::WorkerError::ToolExecutionError("MCP client not connected".to_string()) + crate::WorkerError::tool_execution(tool_name, "MCP client not connected") })?; debug!("Calling MCP tool '{}' with args: {}", tool_name, args); let result = client.call_tool(tool_name, Some(args)).await.map_err(|e| { - crate::WorkerError::ToolExecutionError(format!( - "Failed to call MCP tool '{}': {}", - tool_name, e - )) + crate::WorkerError::tool_execution_with_source( + tool_name, + format!("Failed to call MCP tool '{}'", tool_name), + e, + ) })?; debug!("MCP tool '{}' returned: {:?}", tool_name, result); @@ -205,7 +208,7 @@ impl SingleMcpTool { async fn call_mcp_tool(&self, args: Value) -> ToolResult { let mut client_guard = self.client.lock().await; let client = client_guard.as_mut().ok_or_else(|| { - crate::WorkerError::ToolExecutionError("MCP client not connected".to_string()) + crate::WorkerError::tool_execution(&self.tool_name, "MCP client not connected") })?; debug!("Calling MCP tool '{}' with args: {}", self.tool_name, args); @@ -214,10 +217,11 @@ impl SingleMcpTool { .call_tool(&self.tool_name, Some(args)) .await .map_err(|e| { - crate::WorkerError::ToolExecutionError(format!( - "Failed to call MCP tool '{}': {}", - self.tool_name, e - )) + crate::WorkerError::tool_execution_with_source( + &self.tool_name, + format!("Failed to call MCP tool '{}'", self.tool_name), + e, + ) })?; debug!("MCP tool '{}' returned: {:?}", self.tool_name, result); @@ -279,16 +283,18 @@ impl Tool for McpDynamicTool { .get("tool_name") .and_then(|v| v.as_str()) .ok_or_else(|| { - crate::WorkerError::ToolExecutionError( - "Missing required parameter 'tool_name'".to_string(), + crate::WorkerError::tool_execution( + "mcp_proxy", + "Missing required parameter 'tool_name'", ) })?; let tool_args = args .get("tool_args") .ok_or_else(|| { - crate::WorkerError::ToolExecutionError( - "Missing required parameter 'tool_args'".to_string(), + crate::WorkerError::tool_execution( + "mcp_proxy", + "Missing required parameter 'tool_args'", ) })? .clone(); @@ -318,10 +324,10 @@ impl Tool for McpDynamicTool { "result": result })) } - None => Err(Box::new(crate::WorkerError::ToolExecutionError(format!( - "Tool '{}' not found in MCP server '{}'", - tool_name, self.config.name - ))) + None => Err(Box::new(crate::WorkerError::tool_execution( + tool_name, + format!("Tool '{}' not found in MCP server '{}'", tool_name, self.config.name), + )) as Box), } } diff --git a/worker/src/plugin/example_provider.rs b/worker/src/plugin/example_provider.rs new file mode 100644 index 0000000..723ff95 --- /dev/null +++ b/worker/src/plugin/example_provider.rs @@ -0,0 +1,241 @@ +use async_trait::async_trait; +use futures_util::Stream; +use serde_json::Value; +use std::any::Any; +use std::collections::HashMap; + +use crate::plugin::{PluginMetadata, ProviderPlugin}; +use crate::core::LlmClientTrait; +use crate::types::WorkerError; +use worker_types::{DynamicToolDefinition, LlmDebug, LlmProvider, Message, Role, StreamEvent}; + +/// Example custom provider plugin implementation +pub struct CustomProviderPlugin { + initialized: bool, + config: HashMap, +} + +impl CustomProviderPlugin { + pub fn new() -> Self { + Self { + initialized: false, + config: HashMap::new(), + } + } +} + +#[async_trait] +impl ProviderPlugin for CustomProviderPlugin { + fn metadata(&self) -> PluginMetadata { + PluginMetadata { + id: "custom-provider".to_string(), + name: "Custom LLM Provider".to_string(), + version: "1.0.0".to_string(), + author: "Example Author".to_string(), + description: "An example custom LLM provider plugin".to_string(), + supported_models: vec![ + "custom-small".to_string(), + "custom-large".to_string(), + "custom-turbo".to_string(), + ], + requires_api_key: true, + config_schema: Some(serde_json::json!({ + "type": "object", + "properties": { + "base_url": { + "type": "string", + "description": "Base URL for the API endpoint" + }, + "timeout": { + "type": "integer", + "description": "Request timeout in seconds", + "default": 30 + }, + "max_retries": { + "type": "integer", + "description": "Maximum number of retries", + "default": 3 + } + } + })), + } + } + + async fn initialize(&mut self, config: HashMap) -> Result<(), WorkerError> { + self.config = config; + self.initialized = true; + tracing::info!("Custom provider plugin initialized"); + Ok(()) + } + + fn create_client( + &self, + model_name: &str, + api_key: Option<&str>, + config: Option>, + ) -> Result, WorkerError> { + if !self.initialized { + return Err(WorkerError::config("Plugin not initialized")); + } + + let api_key = api_key + .ok_or_else(|| WorkerError::config("API key required for custom provider"))?; + + let client = CustomLlmClient::new( + api_key.to_string(), + model_name.to_string(), + config.unwrap_or_else(|| self.config.clone()), + ); + + Ok(Box::new(client)) + } + + fn validate_api_key(&self, api_key: &str) -> bool { + // Custom validation logic - in this example, check if key starts with "custom-" + api_key.starts_with("custom-") && api_key.len() > 20 + } + + async fn list_models(&self, _api_key: Option<&str>) -> Result, WorkerError> { + // Could make an API call to fetch available models + Ok(self.metadata().supported_models) + } + + async fn health_check(&self, api_key: Option<&str>) -> Result { + // Simple health check - validate API key format + if let Some(key) = api_key { + Ok(self.validate_api_key(key)) + } else { + Ok(false) + } + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +/// Custom LLM client implementation +struct CustomLlmClient { + api_key: String, + model: String, + config: HashMap, +} + +impl CustomLlmClient { + fn new(api_key: String, model: String, config: HashMap) -> Self { + Self { + api_key, + model, + config, + } + } +} + +#[async_trait] +impl LlmClientTrait for CustomLlmClient { + async fn chat_stream<'a>( + &'a self, + messages: Vec, + _tools: Option<&[DynamicToolDefinition]>, + _llm_debug: Option, + ) -> Result> + Unpin + Send + 'a>, WorkerError> { + use async_stream::stream; + + // Example implementation that echoes the last user message + let last_user_message = messages + .iter() + .rev() + .find(|m| m.role == Role::User) + .map(|m| m.content.clone()) + .unwrap_or_else(|| "Hello from custom provider!".to_string()); + + let response = format!( + "This is a response from the custom provider using model '{}'. You said: {}", + self.model, last_user_message + ); + + // Create a stream that yields the response in chunks + let stream = stream! { + // Simulate streaming response + for word in response.split_whitespace() { + yield Ok(StreamEvent::Chunk(format!("{} ", word))); + // Small delay to simulate streaming + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + } + + // Send completion event with the full message + let completion_message = Message { + role: Role::Model, + content: response.clone(), + tool_calls: None, + tool_call_id: None, + metadata: None, + timestamp: Some(chrono::Utc::now()), + }; + yield Ok(StreamEvent::Completion(completion_message)); + }; + + Ok(Box::new(Box::pin(stream))) + } + + async fn check_connection(&self) -> Result<(), WorkerError> { + // Simple check - could make an actual API call in a real implementation + if self.api_key.is_empty() { + Err(WorkerError::config("API key is required")) + } else { + Ok(()) + } + } + + fn provider(&self) -> LlmProvider { + // Return a default provider type for plugins + // In a real implementation, you might extend LlmProvider enum + LlmProvider::OpenAI + } + + fn get_model_name(&self) -> String { + self.model.clone() + } +} + +/// Factory function for dynamic loading +#[unsafe(no_mangle)] +pub extern "C" fn create_plugin() -> Box { + Box::new(CustomProviderPlugin::new()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_plugin_metadata() { + let plugin = CustomProviderPlugin::new(); + let metadata = plugin.metadata(); + + assert_eq!(metadata.id, "custom-provider"); + assert_eq!(metadata.name, "Custom LLM Provider"); + assert!(metadata.requires_api_key); + assert_eq!(metadata.supported_models.len(), 3); + } + + #[test] + fn test_api_key_validation() { + let plugin = CustomProviderPlugin::new(); + + assert!(plugin.validate_api_key("custom-1234567890abcdefghij")); + assert!(!plugin.validate_api_key("invalid-key")); + assert!(!plugin.validate_api_key("custom-short")); + assert!(!plugin.validate_api_key("")); + } + + #[tokio::test] + async fn test_plugin_initialization() { + let mut plugin = CustomProviderPlugin::new(); + let mut config = HashMap::new(); + config.insert("base_url".to_string(), Value::String("https://api.example.com".to_string())); + + let result = plugin.initialize(config).await; + assert!(result.is_ok()); + } +} \ No newline at end of file diff --git a/worker/src/plugin/mod.rs b/worker/src/plugin/mod.rs new file mode 100644 index 0000000..dad4d0f --- /dev/null +++ b/worker/src/plugin/mod.rs @@ -0,0 +1,219 @@ +pub mod example_provider; + +use async_trait::async_trait; +use futures_util::Stream; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + +use crate::core::LlmClientTrait; +use crate::types::WorkerError; +use worker_types::{DynamicToolDefinition, LlmDebug, LlmProvider, Message, StreamEvent}; + +/// Plugin metadata for provider identification and configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PluginMetadata { + /// Unique identifier for the plugin + pub id: String, + /// Display name for the plugin + pub name: String, + /// Plugin version + pub version: String, + /// Plugin author + pub author: String, + /// Plugin description + pub description: String, + /// Supported models by this plugin + pub supported_models: Vec, + /// Whether this plugin requires an API key + pub requires_api_key: bool, + /// Custom configuration schema (JSON Schema) + pub config_schema: Option, +} + +/// Main plugin trait that all provider plugins must implement +#[async_trait] +pub trait ProviderPlugin: Send + Sync { + /// Get plugin metadata + fn metadata(&self) -> PluginMetadata; + + /// Initialize the plugin with configuration + async fn initialize(&mut self, config: HashMap) -> Result<(), WorkerError>; + + /// Create a new LLM client instance + fn create_client( + &self, + model_name: &str, + api_key: Option<&str>, + config: Option>, + ) -> Result, WorkerError>; + + /// Validate API key format (if applicable) + fn validate_api_key(&self, api_key: &str) -> bool { + if !self.metadata().requires_api_key { + return true; + } + !api_key.is_empty() + } + + /// Get available models from the provider + async fn list_models(&self, _api_key: Option<&str>) -> Result, WorkerError> { + Ok(self.metadata().supported_models.clone()) + } + + /// Health check for the provider + async fn health_check(&self, _api_key: Option<&str>) -> Result { + Ok(true) + } + + /// Return as Any for downcasting if needed + fn as_any(&self) -> &dyn Any; +} + +/// Plugin loader and registry +pub struct PluginRegistry { + plugins: HashMap>, +} + +impl PluginRegistry { + pub fn new() -> Self { + Self { + plugins: HashMap::new(), + } + } + + /// Register a new plugin + pub fn register(&mut self, plugin: Arc) -> Result<(), WorkerError> { + let metadata = plugin.metadata(); + if self.plugins.contains_key(&metadata.id) { + return Err(WorkerError::config(format!( + "Plugin with id '{}' already registered", + metadata.id + ))); + } + self.plugins.insert(metadata.id.clone(), plugin); + Ok(()) + } + + /// Get a plugin by ID + pub fn get(&self, id: &str) -> Option> { + self.plugins.get(id).cloned() + } + + /// List all registered plugins + pub fn list(&self) -> Vec { + self.plugins.values().map(|p| p.metadata()).collect() + } + + /// Find plugin by model name + pub fn find_by_model(&self, model_name: &str) -> Option> { + self.plugins.values().find(|p| { + p.metadata().supported_models.iter().any(|m| m == model_name) + }).cloned() + } + + /// Unregister a plugin + pub fn unregister(&mut self, id: &str) -> Option> { + self.plugins.remove(id) + } +} + +/// Dynamic plugin loader for loading plugins at runtime +pub struct PluginLoader; + +impl PluginLoader { + /// Load a plugin from a shared library (.so/.dll/.dylib) + #[cfg(feature = "dynamic-loading")] + pub fn load_dynamic(path: &std::path::Path) -> Result, WorkerError> { + use libloading::{Library, Symbol}; + + unsafe { + let lib = Library::new(path) + .map_err(|e| WorkerError::config(format!("Failed to load plugin: {}", e)))?; + + let create_plugin: Symbol Box> = lib + .get(b"create_plugin") + .map_err(|e| WorkerError::config(format!("Plugin missing create_plugin function: {}", e)))?; + + Ok(create_plugin()) + } + } + + /// Load all plugins from a directory + #[cfg(feature = "dynamic-loading")] + pub fn load_from_directory(dir: &std::path::Path) -> Result>, WorkerError> { + use std::fs; + + let mut plugins = Vec::new(); + + if !dir.is_dir() { + return Err(WorkerError::config(format!("Plugin directory does not exist: {:?}", dir))); + } + + for entry in fs::read_dir(dir) + .map_err(|e| WorkerError::Config(format!("Failed to read plugin directory: {}", e)))? + { + let entry = entry + .map_err(|e| WorkerError::config(format!("Failed to read directory entry: {}", e)))?; + let path = entry.path(); + + // Check for plugin files (.so on Linux, .dll on Windows, .dylib on macOS) + if path.extension().and_then(|s| s.to_str()).map_or(false, |ext| { + ext == "so" || ext == "dll" || ext == "dylib" + }) { + match Self::load_dynamic(&path) { + Ok(plugin) => plugins.push(plugin), + Err(e) => { + tracing::warn!("Failed to load plugin from {:?}: {}", path, e); + } + } + } + } + + Ok(plugins) + } +} + +/// Wrapper to adapt plugin-based clients to LlmClientTrait +pub struct PluginClient { + plugin: Arc, + inner: Box, +} + +impl PluginClient { + pub fn new( + plugin: Arc, + model_name: &str, + api_key: Option<&str>, + config: Option>, + ) -> Result { + let inner = plugin.create_client(model_name, api_key, config)?; + Ok(Self { plugin, inner }) + } +} + +#[async_trait] +impl LlmClientTrait for PluginClient { + async fn chat_stream<'a>( + &'a self, + messages: Vec, + tools: Option<&[DynamicToolDefinition]>, + llm_debug: Option, + ) -> Result> + Unpin + Send + 'a>, WorkerError> { + self.inner.chat_stream(messages, tools, llm_debug).await + } + + async fn check_connection(&self) -> Result<(), WorkerError> { + self.inner.check_connection().await + } + + fn provider(&self) -> LlmProvider { + self.inner.provider() + } + + fn get_model_name(&self) -> String { + self.inner.get_model_name() + } +} \ No newline at end of file diff --git a/worker/src/prompt_composer.rs b/worker/src/prompt/composer.rs similarity index 94% rename from worker/src/prompt_composer.rs rename to worker/src/prompt/composer.rs index c9822a0..961af51 100644 --- a/worker/src/prompt_composer.rs +++ b/worker/src/prompt/composer.rs @@ -1,14 +1,16 @@ -use crate::config_parser::ConfigParser; -use crate::prompt_types::*; -use crate::types::{Message, Role}; +use crate::config::ConfigParser; +use super::types::*; use handlebars::{Context, Handlebars, Helper, HelperResult, Output, RenderContext}; use std::fs; use std::path::Path; +// Import Message and Role enum from worker_types +use worker_types::{Message, Role as MessageRole}; + /// プロンプト構築システム #[derive(Clone)] pub struct PromptComposer { - config: PromptRoleConfig, + config: Role, handlebars: Handlebars<'static>, context: PromptContext, system_prompt: Option, @@ -26,7 +28,7 @@ impl PromptComposer { /// 設定オブジェクトから新しいインスタンスを作成 pub fn from_config( - config: PromptRoleConfig, + config: Role, context: PromptContext, ) -> Result { let mut handlebars = Handlebars::new(); @@ -58,11 +60,11 @@ impl PromptComposer { pub fn compose(&self, messages: &[Message]) -> Result, PromptError> { if let Some(system_prompt) = &self.system_prompt { // システムプロンプトが既に構築済みの場合、それを使用 - let mut result_messages = vec![Message::new(Role::System, system_prompt.clone())]; + let mut result_messages = vec![Message::new(MessageRole::System, system_prompt.clone())]; // ユーザーメッセージを追加 for msg in messages { - if msg.role != Role::System { + if msg.role != MessageRole::System { result_messages.push(msg.clone()); } } @@ -100,11 +102,11 @@ impl PromptComposer { ) -> Result, PromptError> { if let Some(system_prompt) = &self.system_prompt { // システムプロンプトが既に構築済みの場合、それを使用 - let mut result_messages = vec![Message::new(Role::System, system_prompt.clone())]; + let mut result_messages = vec![Message::new(MessageRole::System, system_prompt.clone())]; // ユーザーメッセージを追加 for msg in messages { - if msg.role != Role::System { + if msg.role != MessageRole::System { result_messages.push(msg.clone()); } } @@ -156,11 +158,11 @@ impl PromptComposer { let system_prompt = self.compose_system_prompt_with_context(messages, context)?; // システムメッセージとユーザーメッセージを結合 - let mut result_messages = vec![Message::new(Role::System, system_prompt)]; + let mut result_messages = vec![Message::new(MessageRole::System, system_prompt)]; // ユーザーメッセージを追加 for msg in messages { - if msg.role != Role::System { + if msg.role != MessageRole::System { result_messages.push(msg.clone()); } } @@ -223,7 +225,7 @@ impl PromptComposer { ) -> Result { let user_input = messages .iter() - .filter(|m| m.role == Role::User) + .filter(|m| m.role == MessageRole::User) .map(|m| m.content.as_str()) .collect::>() .join("\n\n"); diff --git a/worker/src/prompt/mod.rs b/worker/src/prompt/mod.rs new file mode 100644 index 0000000..4e9c5eb --- /dev/null +++ b/worker/src/prompt/mod.rs @@ -0,0 +1,8 @@ +mod composer; +mod types; + +pub use composer::PromptComposer; +pub use types::{ + ConditionConfig, GitInfo, ModelCapabilities, ModelContext, PartialConfig, PromptContext, + PromptError, ProjectType, Role, SessionContext, SystemInfo, WorkspaceContext, +}; diff --git a/worker/src/prompt_types.rs b/worker/src/prompt/types.rs similarity index 93% rename from worker/src/prompt_types.rs rename to worker/src/prompt/types.rs index bb0ecd7..2d406b4 100644 --- a/worker/src/prompt_types.rs +++ b/worker/src/prompt/types.rs @@ -2,9 +2,9 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::path::PathBuf; -/// ロール設定ファイルの型定義 +/// Role configuration - defines the system instructions for the LLM #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PromptRoleConfig { +pub struct Role { pub name: String, pub description: String, pub version: Option, @@ -352,24 +352,15 @@ impl Default for SessionContext { } } -impl Default for PromptRoleConfig { - fn default() -> Self { - let mut partials = HashMap::new(); - partials.insert( - "role_definition".to_string(), - PartialConfig { - path: "./resources/prompts/cli-assistant.md".to_string(), - fallback: None, - description: Some("Default role definition".to_string()), - }, - ); - +impl Role { + /// Create a new Role with name, description, and template + pub fn new(name: impl Into, description: impl Into, template: impl Into) -> Self { Self { - name: "default".to_string(), - description: "Default dynamic role configuration".to_string(), + name: name.into(), + description: description.into(), version: Some("1.0.0".to_string()), - template: "{{>role_definition}}".to_string(), - partials: Some(partials), + template: template.into(), + partials: None, variables: None, conditions: None, } diff --git a/worker/src/tests/config_tests.rs b/worker/src/tests/config_tests.rs index 5ed2352..e3d7a4a 100644 --- a/worker/src/tests/config_tests.rs +++ b/worker/src/tests/config_tests.rs @@ -1,4 +1,4 @@ -use crate::config_parser::ConfigParser; +use crate::config::ConfigParser; use std::io::Write; use tempfile::NamedTempFile; diff --git a/worker/src/types.rs b/worker/src/types.rs index 726ccde..3876643 100644 --- a/worker/src/types.rs +++ b/worker/src/types.rs @@ -1,48 +1,183 @@ // Re-export all types from worker-types for backwards compatibility pub use worker_types::*; -// Worker-specific error type +// Worker-specific error type with structured information #[derive(Debug, thiserror::Error)] pub enum WorkerError { - #[error("Tool execution failed: {0}")] - ToolExecution(String), - #[error("Tool execution error: {0}")] - ToolExecutionError(String), - #[error("LLM API error: {0}")] - LlmApiError(String), - #[error("Model not found: {0}")] - ModelNotFound(String), - #[error("JSON serialization/deserialization error: {0}")] + /// Tool execution failed + #[error("Tool execution failed: {tool_name} - {reason}")] + ToolExecutionError { + tool_name: String, + reason: String, + #[source] + source: Option>, + }, + + /// LLM API error with provider context + #[error("LLM API error ({provider}): {message}")] + LlmApiError { + provider: String, + message: String, + status_code: Option, + #[source] + source: Option>, + }, + + /// Model not found for the specified provider + #[error("Model not found: {model_name} for provider {provider}")] + ModelNotFound { provider: String, model_name: String }, + + /// JSON serialization/deserialization error + #[error("JSON error: {0}")] JsonError(#[from] serde_json::Error), - #[error("Serialization error: {0}")] - Serialization(serde_json::Error), - #[error("Network error: {0}")] - Network(String), - #[error("Configuration error: {0}")] - Config(String), - #[error("Configuration error: {0}")] - ConfigurationError(String), - #[error("General error: {0}")] - General(#[from] anyhow::Error), - #[error("Box error: {0}")] - BoxError(Box), + + /// Network communication error + #[error("Network error: {message}")] + Network { + message: String, + #[source] + source: Option>, + }, + + /// Configuration error with optional context + #[error("Configuration error: {message}")] + ConfigurationError { + message: String, + context: Option, + #[source] + source: Option>, + }, + + /// Other errors that don't fit specific categories + #[error("{0}")] + Other(String), } -impl From<&str> for WorkerError { - fn from(s: &str) -> Self { - WorkerError::General(anyhow::anyhow!(s.to_string())) +impl WorkerError { + /// Create a tool execution error + pub fn tool_execution(tool_name: impl Into, reason: impl Into) -> Self { + Self::ToolExecutionError { + tool_name: tool_name.into(), + reason: reason.into(), + source: None, + } + } + + /// Create a tool execution error with source + pub fn tool_execution_with_source( + tool_name: impl Into, + reason: impl Into, + source: Box, + ) -> Self { + Self::ToolExecutionError { + tool_name: tool_name.into(), + reason: reason.into(), + source: Some(source), + } + } + + /// Create an LLM API error + pub fn llm_api(provider: impl Into, message: impl Into) -> Self { + Self::LlmApiError { + provider: provider.into(), + message: message.into(), + status_code: None, + source: None, + } + } + + /// Create an LLM API error with status code and source + pub fn llm_api_with_details( + provider: impl Into, + message: impl Into, + status_code: Option, + source: Option>, + ) -> Self { + Self::LlmApiError { + provider: provider.into(), + message: message.into(), + status_code, + source, + } + } + + /// Create a model not found error + pub fn model_not_found(provider: impl Into, model_name: impl Into) -> Self { + Self::ModelNotFound { + provider: provider.into(), + model_name: model_name.into(), + } + } + + /// Create a network error + pub fn network(message: impl Into) -> Self { + Self::Network { + message: message.into(), + source: None, + } + } + + /// Create a network error with source + pub fn network_with_source( + message: impl Into, + source: Box, + ) -> Self { + Self::Network { + message: message.into(), + source: Some(source), + } + } + + /// Create a configuration error + pub fn config(message: impl Into) -> Self { + Self::ConfigurationError { + message: message.into(), + context: None, + source: None, + } + } + + /// Create a configuration error with context + pub fn config_with_context( + message: impl Into, + context: impl Into, + ) -> Self { + Self::ConfigurationError { + message: message.into(), + context: Some(context.into()), + source: None, + } + } + + /// Create a configuration error with source + pub fn config_with_source( + message: impl Into, + source: Box, + ) -> Self { + Self::ConfigurationError { + message: message.into(), + context: None, + source: Some(source), + } } } -impl From for WorkerError { - fn from(s: String) -> Self { - WorkerError::General(anyhow::anyhow!(s)) +// Explicit conversion from common error types +impl From for WorkerError { + fn from(e: reqwest::Error) -> Self { + Self::network_with_source("HTTP request failed", Box::new(e)) + } +} + +impl From for WorkerError { + fn from(e: std::io::Error) -> Self { + Self::config_with_source("I/O error", Box::new(e)) } } impl From> for WorkerError { fn from(e: Box) -> Self { - WorkerError::BoxError(e) + Self::Other(format!("Error: {}", e)) } } diff --git a/worker/src/workspace_detector.rs b/worker/src/workspace/detector.rs similarity index 98% rename from worker/src/workspace_detector.rs rename to worker/src/workspace/detector.rs index 7c27707..899350a 100644 --- a/worker/src/workspace_detector.rs +++ b/worker/src/workspace/detector.rs @@ -1,4 +1,4 @@ -use crate::prompt_types::*; +use crate::prompt::{WorkspaceContext, PromptError, ProjectType, GitInfo}; use std::fs; use std::path::{Path, PathBuf}; use std::process::Command; @@ -34,7 +34,7 @@ impl WorkspaceDetector { let project_name = Self::determine_project_name(&root_path, &git_info); // 6. システム情報を生成 - let system_info = crate::prompt_types::SystemInfo::default(); + let system_info = crate::prompt::SystemInfo::default(); Ok(WorkspaceContext { root_path, diff --git a/worker/src/workspace/mod.rs b/worker/src/workspace/mod.rs new file mode 100644 index 0000000..a45053b --- /dev/null +++ b/worker/src/workspace/mod.rs @@ -0,0 +1,3 @@ +mod detector; + +pub use detector::WorkspaceDetector;