0.2.0 builder model and crate re-export

This commit is contained in:
Keisuke Hirata 2025-10-24 02:44:26 +09:00
parent d03172610d
commit 48bbab0a69
35 changed files with 2491 additions and 773 deletions

View File

@ -1,55 +0,0 @@
# CLAUDE.md
This file provides guidance to Claude Code when working with code in this repository.
## Development Commands
- **Build**: `cargo build --workspace` or `cargo check --workspace` for quick validation
- **Tests**: `cargo test --workspace` (note: tests currently fail due to unsafe environment variable operations)
- **Fix warnings**: `cargo fix --lib -p worker`
- **Dev environment**: Uses Nix flake - run `nix develop` to enter dev shell with Rust toolchain
## Architecture Overview
This is a Rust workspace implementing a multi-LLM worker system with tool calling capabilities. The architecture follows the Core Crate Pattern to avoid circular dependencies:
### Workspace Structure
- **`worker-types`**: Core type definitions (Tool trait, Message, StreamEvent, hook types) - foundational types used by all other crates
- **`worker-macros`**: Procedural macros (`#[tool]`, `#[hook]`) for automatic Tool and WorkerHook trait implementations
- **`worker`**: Main library containing Worker struct, LLM clients, prompt composer, and session management
### Key Components
**Worker**: Central orchestrator that manages LLM interactions, tool execution, and session state. Supports streaming responses via async streams.
**LLM Providers**: Modular clients in `worker/src/llm/` for:
- Anthropic (Claude)
- Google (Gemini)
- OpenAI (GPT)
- xAI (Grok)
- Ollama (local models)
**Tool System**: Uses `#[tool]` macro to convert async functions into LLM-callable tools. Tools must return `ToolResult<T>` and take a single struct argument implementing `Deserialize + Serialize + JsonSchema`.
**Hook System**: Uses `#[hook]` macro to create lifecycle event handlers. Hook functions take `HookContext` and return `(HookContext, HookResult)` tuple, supporting events like OnMessageSend, PreToolUse, PostToolUse, and OnTurnCompleted with regex matcher support.
**Prompt Management**: Handlebars-based templating system for dynamic prompt generation based on roles and context.
**MCP Integration**: Model Context Protocol support for external tool server communication.
## Development Notes
- Edition 2024 Rust with async/await throughout
- Uses `tokio` for async runtime and `reqwest` for HTTP clients
- Environment-based configuration for API keys and base URLs
- Session persistence using JSON serialization
- Streaming responses via `futures-util` streams
- Current test suite has unsafe environment variable operations that need `unsafe` blocks to compile
## Important Patterns
- **Error Handling**: `ToolResult<T>` for tools, `WorkerError` for library operations with automatic conversions
- **Streaming**: All LLM interactions return `StreamEvent` streams for real-time UI updates
- **Tool Registration**: Dynamic tool registration at runtime using boxed trait objects
- **Hook Registration**: Dynamic hook registration with lifecycle event filtering and regex matching
- **Core Crate Pattern**: `worker-macros` references `worker-types` directly via complete paths to prevent circular dependencies

11
Cargo.lock generated
View File

@ -741,6 +741,16 @@ version = "0.2.175"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
[[package]]
name = "libloading"
version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
dependencies = [
"cfg-if",
"windows-targets 0.53.3",
]
[[package]] [[package]]
name = "libredox" name = "libredox"
version = "0.1.9" version = "0.1.9"
@ -2214,6 +2224,7 @@ dependencies = [
"futures", "futures",
"futures-util", "futures-util",
"handlebars", "handlebars",
"libloading",
"log", "log",
"regex", "regex",
"reqwest", "reqwest",

399
README.md Normal file
View File

@ -0,0 +1,399 @@
# `worker`
`worker` クレートは、大規模言語モデル (LLM) を利用したアプリケーションのバックエンド機能を提供するクレートです。LLM プロバイダーの抽象化、ツール利用、柔軟なプロンプト管理、フックシステムなど、高度な機能をカプセル化し、アプリケーション開発を簡素化します。
## 主な機能
- **マルチプロバイダー対応**: Gemini, Claude, OpenAI, Ollama, XAI など、複数の LLM プロバイダーを統一されたインターフェースで利用できます。
- **プラグインシステム**: カスタムプロバイダーをプラグインとして動的に追加できます。独自の LLM API や実験的なプロバイダーをサポートします。
- **ツール利用 (Function Calling)**: LLM が外部ツールを呼び出す機能をサポートします。独自のツールをマクロを用いて定義し、`Worker` に登録できます。
- **ストリーミング処理**: LLM の応答やツール実行結果を `StreamEvent` として非同期に受け取ることができます。これにより、リアルタイムな UI 更新が可能になります。
- **フックシステム**: `Worker` の処理フローの特定のタイミング(例: メッセージ送信前、ツール使用後)にカスタムロジックを介入させることができます。
- **セッション管理**: 会話履歴やワークスペースの状態を管理し、永続化する機能を提供します。
- **柔軟なプロンプト管理**: 設定ファイルを用いて、ロールやコンテキストに応じたシステムプロンプトを動的に構築します。
## 主な概念
### `Worker`
このクレートの中心的な構造体です。LLM との対話、ツールの登録と実行、セッション管理など、すべての主要な機能を担当します。
### `LlmProvider`
サポートしている LLM プロバイダー(`Gemini`, `Claude`, `OpenAI` など)を表す enum です。
### `Tool` トレイト
`Worker` が利用できるツールを定義するためのインターフェースです。このトレイトを実装することで、任意の機能をツールとして `Worker` に追加できます。
```rust
pub trait Tool: Send + Sync {
fn name(&self) -> &str;
fn description(&self) -> &str;
fn parameters_schema(&self) -> serde_json::Value;
async fn execute(&self, args: serde_json::Value) -> ToolResult<serde_json::Value>;
}
```
### `WorkerHook` トレイト
`Worker` のライフサイクルイベントに介入するためのフックを定義するインターフェースです。特定のイベント(例: `OnMessageSend`, `PostToolUse`)に対して処理を追加できます。
### `StreamEvent`
`Worker` の処理結果を非同期ストリームで受け取るための enum です。LLM の応答チャンク、ツール呼び出し、エラーなど、さまざまなイベントを表します。
## アプリケーションへの組み込み方法
### 1. Worker の初期化
Builder patternを使用してWorkerを作成します。
```rust
use worker::{Worker, LlmProvider, Role};
use std::collections::HashMap;
// ロールを定義(必須)
let role = Role::new(
"assistant",
"AI Assistant",
"You are a helpful AI assistant."
);
// APIキーを準備
let mut api_keys = HashMap::new();
api_keys.insert("openai".to_string(), "your_openai_api_key".to_string());
api_keys.insert("claude".to_string(), "your_claude_api_key".to_string());
// Workerを作成builder pattern
let mut worker = Worker::builder()
.provider(LlmProvider::OpenAI)
.model("gpt-4o")
.api_keys(api_keys)
.role(role)
.build()
.expect("Workerの作成に失敗しました");
// または、個別にAPIキーを設定
let worker = Worker::builder()
.provider(LlmProvider::Claude)
.model("claude-3-sonnet-20240229")
.api_key("claude", "sk-ant-...")
.role(role)
.build()?;
```
### 2. ツールの定義と登録
`Tool` トレイトを実装してカスタムツールを作成し、`Worker` に登録します。
```rust
use worker::{Tool, ToolResult};
use worker::schemars::{self, JsonSchema};
use worker::serde_json::{self, json, Value};
use async_trait::async_trait;
// ツールの引数を定義
#[derive(Debug, serde::Deserialize, JsonSchema)]
struct FileSystemToolArgs {
path: String,
}
// カスタムツールを定義
struct ListFilesTool;
#[async_trait]
impl Tool for ListFilesTool {
fn name(&self) -> &str { "list_files" }
fn description(&self) -> &str { "指定されたパスのファイル一覧を表示します" }
fn parameters_schema(&self) -> Value {
serde_json::to_value(schemars::schema_for!(FileSystemToolArgs)).unwrap()
}
async fn execute(&self, args: Value) -> ToolResult<Value> {
let tool_args: FileSystemToolArgs = serde_json::from_value(args)?;
// ここで実際のファイル一覧取得処理を実装
let files = vec!["file1.txt", "file2.txt"];
Ok(json!({ "files": files }))
}
}
// 作成したツールをWorkerに登録
worker.register_tool(Box::new(ListFilesTool)).unwrap();
```
#### マクロを使ったツール定義(推奨)
`worker-macros` クレートの `#[tool]` マクロを使用すると、ツールの定義がより簡潔になります:
```rust
use worker_macros::tool;
use worker::ToolResult;
use serde::{Deserialize, Serialize};
use schemars::JsonSchema;
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
struct ListFilesArgs {
path: String,
}
#[tool]
async fn list_files(args: ListFilesArgs) -> ToolResult<serde_json::Value> {
// ファイル一覧取得処理
let files = vec!["file1.txt", "file2.txt"];
Ok(serde_json::json!({ "files": files }))
}
// マクロで生成されたツールを登録
worker.register_tool(Box::new(ListFilesTool))?;
```
### 3. 対話処理の実行
`process_task_with_history` メソッドを呼び出して、ユーザーメッセージを処理します。このメソッドはイベントのストリームを返します。
```rust
use futures_util::StreamExt;
let user_message = "カレントディレクトリのファイルを教えて".to_string();
let mut stream = worker.process_task_with_history(user_message, None).await;
while let Some(event_result) = stream.next().await {
match event_result {
Ok(event) => {
// StreamEventに応じた処理
match event {
worker::StreamEvent::Chunk(chunk) => {
print!("{}", chunk);
}
worker::StreamEvent::ToolCall(tool_call) => {
println!("\n[Tool Call: {} with args {}]", tool_call.name, tool_call.arguments);
}
worker::StreamEvent::ToolResult { tool_name, result } => {
println!("\n[Tool Result: {} -> {:?}]", tool_name, result);
}
_ => {}
}
}
Err(e) => {
eprintln!("\n[Error: {}]", e);
break;
}
}
}
```
### 4. (オプション) フックの登録
`WorkerHook` トレイトを実装してカスタムフックを作成し、`Worker` に登録することで、処理フローをカスタマイズできます。
#### 手動実装
```rust
use worker::{WorkerHook, HookContext, HookResult};
use async_trait::async_trait;
struct LoggingHook;
#[async_trait]
impl WorkerHook for LoggingHook {
fn name(&self) -> &str { "logging_hook" }
fn hook_type(&self) -> &str { "OnMessageSend" }
fn matcher(&self) -> &str { "" }
async fn execute(&self, context: HookContext) -> (HookContext, HookResult) {
println!("User message: {}", context.content);
(context, HookResult::Continue)
}
}
// フックを登録
worker.register_hook(Box::new(LoggingHook));
```
#### マクロを使ったフック定義(推奨)
`worker-macros` クレートの `#[hook]` マクロを使用すると、フックの定義がより簡潔になります:
```rust
use worker_macros::hook;
use worker::{HookContext, HookResult};
#[hook(OnMessageSend)]
async fn logging_hook(context: HookContext) -> (HookContext, HookResult) {
println!("User message: {}", context.content);
(context, HookResult::Continue)
}
// マクロで生成されたフックを登録
worker.register_hook(Box::new(LoggingHook));
```
**利用可能なフックタイプ:**
- `OnMessageSend`: ユーザーメッセージ送信前
- `PreToolUse`: ツール実行前
- `PostToolUse`: ツール実行後
- `OnTurnCompleted`: ターン完了時
これで、アプリケーションの要件に応じて `Worker` を中心とした強力な LLM 連携機能を構築できます。
## サンプルコード
完全な動作例は `worker/examples/` ディレクトリを参照してください:
- [`builder_basic.rs`](worker/examples/builder_basic.rs) - Builder patternの基本的な使用方法
- [`plugin_usage.rs`](worker/examples/plugin_usage.rs) - プラグインシステムの使用方法
## プラグインシステム
### プラグインの作成
`ProviderPlugin` トレイトを実装してカスタムプロバイダーを作成できます:
```rust
use worker::plugin::{ProviderPlugin, PluginMetadata};
use async_trait::async_trait;
pub struct MyCustomProvider {
// プロバイダーの状態
}
#[async_trait]
impl ProviderPlugin for MyCustomProvider {
fn metadata(&self) -> PluginMetadata {
PluginMetadata {
id: "my-provider".to_string(),
name: "My Custom Provider".to_string(),
version: "1.0.0".to_string(),
author: "Your Name".to_string(),
description: "カスタムプロバイダーの説明".to_string(),
supported_models: vec!["model-1".to_string()],
requires_api_key: true,
config_schema: None,
}
}
async fn initialize(&mut self, config: HashMap<String, Value>) -> Result<(), WorkerError> {
// プロバイダーの初期化
Ok(())
}
fn create_client(
&self,
model_name: &str,
api_key: Option<&str>,
config: Option<HashMap<String, Value>>,
) -> Result<Box<dyn LlmClientTrait>, WorkerError> {
// LLMクライアントを作成して返す
}
fn as_any(&self) -> &dyn Any {
self
}
}
```
### プラグインの使用
```rust
use worker::{Worker, Role, plugin::PluginRegistry};
use std::sync::{Arc, Mutex};
// プラグインレジストリを作成
let plugin_registry = Arc::new(Mutex::new(PluginRegistry::new()));
// プラグインを作成して登録
let my_plugin = Arc::new(MyCustomProvider::new());
{
let mut registry = plugin_registry.lock().unwrap();
registry.register(my_plugin)?;
}
// ロールを定義
let role = Role::new(
"assistant",
"AI Assistant",
"You are a helpful AI assistant."
);
let worker = Worker::builder()
.plugin("my-provider", plugin_registry.clone())
.model("model-1")
.api_key("__plugin__", "api-key")
.role(role)
.build()?;
```
### 動的プラグイン読み込み
`dynamic-loading` フィーチャーを有効にすることで、共有ライブラリからプラグインを動的に読み込むことができます:
```toml
[dependencies]
worker = { path = "../worker", features = ["dynamic-loading"] }
```
```rust
// ディレクトリからプラグインを読み込み
worker.load_plugins_from_directory(Path::new("./plugins")).await?;
```
完全な例は `worker/src/plugin/example_provider.rs``worker/examples/plugin_usage.rs` を参照してください。
## エラーハンドリング
構造化されたエラー型により、詳細なエラー情報を取得できます。
### WorkerError の種類
```rust
use worker::WorkerError;
// ツール実行エラー
let error = WorkerError::tool_execution("my_tool", "Connection failed");
let error = WorkerError::tool_execution_with_source("my_tool", "Failed", source_error);
// 設定エラー
let error = WorkerError::config("Invalid configuration");
let error = WorkerError::config_with_context("Parse error", "config.yaml line 10");
let error = WorkerError::config_with_source("Failed to load", io_error);
// LLM APIエラー
let error = WorkerError::llm_api("openai", "Rate limit exceeded");
let error = WorkerError::llm_api_with_details("claude", "Invalid request", Some(400), None);
// モデルエラー
let error = WorkerError::model_not_found("openai", "gpt-5");
// ネットワークエラー
let error = WorkerError::network("Connection timeout");
let error = WorkerError::network_with_source("Request failed", reqwest_error);
```
### エラーのパターンマッチング
```rust
match worker.build() {
Ok(worker) => { /* ... */ },
Err(WorkerError::ConfigurationError { message, context, .. }) => {
eprintln!("Configuration error: {}", message);
if let Some(ctx) = context {
eprintln!("Context: {}", ctx);
}
},
Err(WorkerError::ModelNotFound { provider, model_name }) => {
eprintln!("Model '{}' not found for provider '{}'", model_name, provider);
},
Err(WorkerError::LlmApiError { provider, message, status_code, .. }) => {
eprintln!("API error from {}: {}", provider, message);
if let Some(code) = status_code {
eprintln!("Status code: {}", code);
}
},
Err(e) => {
eprintln!("Error: {}", e);
}
}
```

632
docs/patch_note/v0.2.0.md Normal file
View File

@ -0,0 +1,632 @@
# Release Notes - v0.2.0
**Release Date**: 2025-10-14
## Overview
# Release Notes - v0.2.0
**Release Date**: 2025-10-14
## Overview
Version 0.2.0 introduces significant improvements to the external API, including a builder pattern for Worker construction, simplified Role management, comprehensive error handling improvements, and a new plugin system for extensible LLM providers.
## Breaking Changes
### 1. Role System Redesign
**Removed duplicate types:**
- `RoleConfig` has been removed from `worker-types`
- `PromptRoleConfig` has been renamed to `Role`
- All configuration is now unified under `worker::prompt_types::Role`
**Built-in roles removed:**
- `Role::default()` - removed
- `Role::assistant()` - removed
- `Role::code_assistant()` - removed
**Migration:**
```rust
// Before (v0.1.0)
let role = RoleConfig::default();
// After (v0.2.0)
use worker::prompt_types::Role;
let role = Role::new(
"assistant",
"A helpful AI assistant",
"You are a helpful, harmless, and honest AI assistant."
);
```
### 2. Worker Construction - Builder Pattern
**New Type-state Builder API:**
Worker construction now uses a type-state builder pattern that enforces required parameters at compile time.
```rust
// Before (v0.1.0)
let worker = Worker::new(
LlmProvider::Claude,
"claude-3-sonnet-20240229",
&api_keys,
role,
plugin_registry,
)?;
// After (v0.2.0)
let worker = Worker::builder()
.provider(LlmProvider::Claude)
.model("claude-3-sonnet-20240229")
.api_keys(api_keys)
.role(role)
.build()?;
```
**Required parameters (enforced at compile-time):**
- `provider(LlmProvider)` - LLM provider
- `model(&str)` - Model name
- `role(Role)` - System role configuration
**Optional parameters:**
- `api_keys(HashMap<String, String>)` - API keys for providers
- `plugin_id(&str)` - Plugin identifier for custom providers
- `plugin_registry(Arc<Mutex<PluginRegistry>>)` - Plugin registry
### 3. WorkerError Redesign
**Complete error type restructuring with structured error information:**
**Removed duplicate variants:**
- `ToolExecution` (use `ToolExecutionError`)
- `Config` (use `ConfigurationError`)
- `Serialization` (use `JsonError`)
**New structured error variants:**
```rust
pub enum WorkerError {
ToolExecutionError {
tool_name: String,
reason: String,
source: Option<Box<dyn Error>>,
},
LlmApiError {
provider: String,
message: String,
status_code: Option<u16>,
source: Option<Box<dyn Error>>,
},
ConfigurationError {
message: String,
context: Option<String>,
source: Option<Box<dyn Error>>,
},
ModelNotFound {
provider: String,
model_name: String,
},
Network {
message: String,
source: Option<Box<dyn Error>>,
},
JsonError(serde_json::Error),
Other(String),
}
```
**New helper methods for ergonomic error construction:**
```rust
// Tool execution errors
WorkerError::tool_execution(tool_name, reason)
WorkerError::tool_execution_with_source(tool_name, reason, source)
// Configuration errors
WorkerError::config(message)
WorkerError::config_with_context(message, context)
WorkerError::config_with_source(message, source)
// LLM API errors
WorkerError::llm_api(provider, message)
WorkerError::llm_api_with_details(provider, message, status_code, source)
// Model errors
WorkerError::model_not_found(provider, model_name)
// Network errors
WorkerError::network(message)
WorkerError::network_with_source(message, source)
```
**Migration:**
```rust
// Before (v0.1.0)
return Err(WorkerError::ToolExecutionError(
format!("Tool '{}' failed: {}", tool_name, reason)
));
// After (v0.2.0)
return Err(WorkerError::tool_execution(tool_name, reason));
// With source error
return Err(WorkerError::tool_execution_with_source(
tool_name,
reason,
Box::new(source_error)
));
```
## New Features
### 1. Plugin System
A new extensible plugin system for custom LLM providers:
**Core Components:**
- `ProviderPlugin` trait - Define custom LLM providers
- `PluginRegistry` - Manage plugin lifecycle
- `PluginClient` - Adapter for plugin-based clients
- `PluginMetadata` - Plugin identification and configuration
**Example:**
```rust
use worker::plugin::{ProviderPlugin, PluginRegistry};
// Register a plugin
let mut registry = PluginRegistry::new();
registry.register(Arc::new(CustomProviderPlugin::new()))?;
// Use plugin with Worker
let worker = Worker::builder()
.provider(LlmProvider::OpenAI) // Base provider type
.model("custom-model")
.plugin_id("custom-provider")
.plugin_registry(Arc::new(Mutex::new(registry)))
.role(role)
.build()?;
```
**Plugin Features:**
- Custom provider implementation
- API key validation
- Model listing
- Health checks
- Dynamic configuration schemas
- Optional dynamic library loading (`#[cfg(feature = "dynamic-loading")]`)
### 2. Examples
**New example files:**
1. **`worker/examples/builder_basic.rs`**
- Demonstrates the new builder pattern API
- Shows basic Worker construction
- Example tool registration
2. **`worker/examples/plugin_usage.rs`**
- Plugin system usage example
- Custom provider implementation
- Plugin registration and usage
## Improvements
### Error Handling
1. **Structured error information** - All errors now carry contextual information (provider, tool name, status codes, etc.)
2. **Error source tracking** - `#[source]` attribute enables full error chain tracing
3. **Automatic conversions** - Added `From` implementations for common error types:
- `From<reqwest::Error>``Network` error
- `From<std::io::Error>``ConfigurationError`
- `From<serde_json::Error>``JsonError`
- `From<Box<dyn Error>>``Other` error
### Code Organization
1. **Eliminated duplicate types** between `worker-types` and `worker` crates
2. **Clearer separation of concerns** - Role definition vs. PromptComposer execution
3. **Consistent error construction** - All error sites updated to use new helper methods
## Files Changed
**Modified (11 files):**
- `Cargo.lock` - Dependency updates
- `worker-types/src/lib.rs` - Removed duplicate types
- `worker/Cargo.toml` - Version and dependency updates
- `worker/src/config_parser.rs` - Role type updates
- `worker/src/lib.rs` - Builder pattern, error handling updates
- `worker/src/llm/gemini.rs` - Error handling improvements
- `worker/src/llm/ollama.rs` - Error handling improvements
- `worker/src/mcp_config.rs` - Error handling updates
- `worker/src/mcp_tool.rs` - Error handling improvements
- `worker/src/prompt_composer.rs` - Role type imports
- `worker/src/prompt_types.rs` - PromptRoleConfig → Role rename
- `worker/src/types.rs` - Complete WorkerError redesign
**Added (5 files):**
- `worker/src/builder.rs` - Type-state builder implementation
- `worker/src/plugin/mod.rs` - Plugin system core
- `worker/src/plugin/example_provider.rs` - Example plugin implementation
- `worker/examples/builder_basic.rs` - Builder pattern example
- `worker/examples/plugin_usage.rs` - Plugin usage example
**Renamed (1 file):**
- `worker/README.md``README.md` - Moved to root
## Migration Guide
### Step 1: Update Role Construction
Replace all `RoleConfig` and built-in role usage:
```rust
// Remove old imports
// use worker::RoleConfig;
// Add new import
use worker::prompt_types::Role;
// Replace role construction
let role = Role::new(
"your-role-name",
"Role description",
"Your system prompt template"
);
```
### Step 2: Update Worker Construction
Replace direct `Worker::new()` calls with builder pattern:
```rust
let worker = Worker::builder()
.provider(provider)
.model(model_name)
.api_keys(api_keys)
.role(role)
.build()?;
```
### Step 3: Update Error Handling
Replace old error construction with new helper methods:
```rust
// Configuration errors
WorkerError::ConfigurationError("message".to_string())
// becomes:
WorkerError::config("message")
// Tool errors
WorkerError::ToolExecutionError(format!("..."))
// becomes:
WorkerError::tool_execution(tool_name, reason)
// Pattern matching
match error {
WorkerError::ConfigurationError(msg) => { /* ... */ }
// becomes:
WorkerError::ConfigurationError { message, .. } => { /* ... */ }
}
```
## Acknowledgments
This release includes significant API improvements based on architectural review and refactoring to improve type safety, error handling, and code maintainability.
## Next Steps
See the examples in `worker/examples/` for complete usage demonstrations:
- `builder_basic.rs` - Basic Worker construction
- `plugin_usage.rs` - Plugin system usage
## Breaking Changes
### 1. Role System Redesign
**Removed duplicate types:**
- `RoleConfig` has been removed from `worker-types`
- `PromptRoleConfig` has been renamed to `Role`
- All configuration is now unified under `worker::prompt_types::Role`
**Built-in roles removed:**
- `Role::default()` - removed
- `Role::assistant()` - removed
- `Role::code_assistant()` - removed
**Migration:**
```rust
// Before (v0.1.0)
let role = RoleConfig::default();
// After (v0.2.0)
use worker::prompt_types::Role;
let role = Role::new(
"assistant",
"A helpful AI assistant",
"You are a helpful, harmless, and honest AI assistant."
);
```
### 2. Worker Construction - Builder Pattern
**New Type-state Builder API:**
Worker construction now uses a type-state builder pattern that enforces required parameters at compile time.
```rust
// Before (v0.1.0)
let worker = Worker::new(
LlmProvider::Claude,
"claude-3-sonnet-20240229",
&api_keys,
role,
plugin_registry,
)?;
// After (v0.2.0)
let worker = Worker::builder()
.provider(LlmProvider::Claude)
.model("claude-3-sonnet-20240229")
.api_keys(api_keys)
.role(role)
.build()?;
```
**Required parameters (enforced at compile-time):**
- `provider(LlmProvider)` - LLM provider
- `model(&str)` - Model name
- `role(Role)` - System role configuration
**Optional parameters:**
- `api_keys(HashMap<String, String>)` - API keys for providers
- `plugin_id(&str)` - Plugin identifier for custom providers
- `plugin_registry(Arc<Mutex<PluginRegistry>>)` - Plugin registry
### 3. WorkerError Redesign
**Complete error type restructuring with structured error information:**
**Removed duplicate variants:**
- `ToolExecution` (use `ToolExecutionError`)
- `Config` (use `ConfigurationError`)
- `Serialization` (use `JsonError`)
**New structured error variants:**
```rust
pub enum WorkerError {
ToolExecutionError {
tool_name: String,
reason: String,
source: Option<Box<dyn Error>>,
},
LlmApiError {
provider: String,
message: String,
status_code: Option<u16>,
source: Option<Box<dyn Error>>,
},
ConfigurationError {
message: String,
context: Option<String>,
source: Option<Box<dyn Error>>,
},
ModelNotFound {
provider: String,
model_name: String,
},
Network {
message: String,
source: Option<Box<dyn Error>>,
},
JsonError(serde_json::Error),
Other(String),
}
```
**New helper methods for ergonomic error construction:**
```rust
// Tool execution errors
WorkerError::tool_execution(tool_name, reason)
WorkerError::tool_execution_with_source(tool_name, reason, source)
// Configuration errors
WorkerError::config(message)
WorkerError::config_with_context(message, context)
WorkerError::config_with_source(message, source)
// LLM API errors
WorkerError::llm_api(provider, message)
WorkerError::llm_api_with_details(provider, message, status_code, source)
// Model errors
WorkerError::model_not_found(provider, model_name)
// Network errors
WorkerError::network(message)
WorkerError::network_with_source(message, source)
```
**Migration:**
```rust
// Before (v0.1.0)
return Err(WorkerError::ToolExecutionError(
format!("Tool '{}' failed: {}", tool_name, reason)
));
// After (v0.2.0)
return Err(WorkerError::tool_execution(tool_name, reason));
// With source error
return Err(WorkerError::tool_execution_with_source(
tool_name,
reason,
Box::new(source_error)
));
```
## New Features
### 1. Plugin System
A new extensible plugin system for custom LLM providers:
**Core Components:**
- `ProviderPlugin` trait - Define custom LLM providers
- `PluginRegistry` - Manage plugin lifecycle
- `PluginClient` - Adapter for plugin-based clients
- `PluginMetadata` - Plugin identification and configuration
**Example:**
```rust
use worker::plugin::{ProviderPlugin, PluginRegistry};
// Register a plugin
let mut registry = PluginRegistry::new();
registry.register(Arc::new(CustomProviderPlugin::new()))?;
// Use plugin with Worker
let worker = Worker::builder()
.provider(LlmProvider::OpenAI) // Base provider type
.model("custom-model")
.plugin_id("custom-provider")
.plugin_registry(Arc::new(Mutex::new(registry)))
.role(role)
.build()?;
```
**Plugin Features:**
- Custom provider implementation
- API key validation
- Model listing
- Health checks
- Dynamic configuration schemas
- Optional dynamic library loading (`#[cfg(feature = "dynamic-loading")]`)
### 2. Examples
**New example files:**
1. **`worker/examples/builder_basic.rs`**
- Demonstrates the new builder pattern API
- Shows basic Worker construction
- Example tool registration
2. **`worker/examples/plugin_usage.rs`**
- Plugin system usage example
- Custom provider implementation
- Plugin registration and usage
## Improvements
### Error Handling
1. **Structured error information** - All errors now carry contextual information (provider, tool name, status codes, etc.)
2. **Error source tracking** - `#[source]` attribute enables full error chain tracing
3. **Automatic conversions** - Added `From` implementations for common error types:
- `From<reqwest::Error>``Network` error
- `From<std::io::Error>``ConfigurationError`
- `From<serde_json::Error>``JsonError`
- `From<Box<dyn Error>>``Other` error
### Code Organization
1. **Eliminated duplicate types** between `worker-types` and `worker` crates
2. **Clearer separation of concerns** - Role definition vs. PromptComposer execution
3. **Consistent error construction** - All error sites updated to use new helper methods
## Files Changed
**Modified (11 files):**
- `Cargo.lock` - Dependency updates
- `worker-types/src/lib.rs` - Removed duplicate types
- `worker/Cargo.toml` - Version and dependency updates
- `worker/src/config_parser.rs` - Role type updates
- `worker/src/lib.rs` - Builder pattern, error handling updates
- `worker/src/llm/gemini.rs` - Error handling improvements
- `worker/src/llm/ollama.rs` - Error handling improvements
- `worker/src/mcp_config.rs` - Error handling updates
- `worker/src/mcp_tool.rs` - Error handling improvements
- `worker/src/prompt_composer.rs` - Role type imports
- `worker/src/prompt_types.rs` - PromptRoleConfig → Role rename
- `worker/src/types.rs` - Complete WorkerError redesign
**Added (5 files):**
- `worker/src/builder.rs` - Type-state builder implementation
- `worker/src/plugin/mod.rs` - Plugin system core
- `worker/src/plugin/example_provider.rs` - Example plugin implementation
- `worker/examples/builder_basic.rs` - Builder pattern example
- `worker/examples/plugin_usage.rs` - Plugin usage example
**Renamed (1 file):**
- `worker/README.md``README.md` - Moved to root
## Migration Guide
### Step 1: Update Role Construction
Replace all `RoleConfig` and built-in role usage:
```rust
// Remove old imports
// use worker::RoleConfig;
// Add new import
use worker::prompt_types::Role;
// Replace role construction
let role = Role::new(
"your-role-name",
"Role description",
"Your system prompt template"
);
```
### Step 2: Update Worker Construction
Replace direct `Worker::new()` calls with builder pattern:
```rust
let worker = Worker::builder()
.provider(provider)
.model(model_name)
.api_keys(api_keys)
.role(role)
.build()?;
```
### Step 3: Update Error Handling
Replace old error construction with new helper methods:
```rust
// Configuration errors
WorkerError::ConfigurationError("message".to_string())
// becomes:
WorkerError::config("message")
// Tool errors
WorkerError::ToolExecutionError(format!("..."))
// becomes:
WorkerError::tool_execution(tool_name, reason)
// Pattern matching
match error {
WorkerError::ConfigurationError(msg) => { /* ... */ }
// becomes:
WorkerError::ConfigurationError { message, .. } => { /* ... */ }
}
```
## Acknowledgments
This release includes significant API improvements based on architectural review and refactoring to improve type safety, error handling, and code maintainability.
## Next Steps
See the examples in `worker/examples/` for complete usage demonstrations:
- `builder_basic.rs` - Basic Worker construction
- `plugin_usage.rs` - Plugin system usage

View File

@ -95,8 +95,8 @@ pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
} }
// Implement Tool trait // Implement Tool trait
#[::worker_types::async_trait::async_trait] #[::worker::types::async_trait::async_trait]
impl ::worker_types::Tool for #tool_struct_name { impl ::worker::types::Tool for #tool_struct_name {
fn name(&self) -> &str { fn name(&self) -> &str {
#tool_name_str #tool_name_str
} }
@ -105,16 +105,16 @@ pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
#description #description
} }
fn parameters_schema(&self) -> ::worker_types::serde_json::Value { fn parameters_schema(&self) -> ::worker::types::serde_json::Value {
::worker_types::serde_json::to_value(::worker_types::schemars::schema_for!(#arg_type)).unwrap() ::worker::types::serde_json::to_value(::worker::types::schemars::schema_for!(#arg_type)).unwrap()
} }
async fn execute(&self, args: ::worker_types::serde_json::Value) -> ::worker_types::ToolResult<::worker_types::serde_json::Value> { async fn execute(&self, args: ::worker::types::serde_json::Value) -> ::worker::types::ToolResult<::worker::types::serde_json::Value> {
let typed_args: #arg_type = ::worker_types::serde_json::from_value(args)?; let typed_args: #arg_type = ::worker::types::serde_json::from_value(args)?;
let result = #fn_name(typed_args).await?; let result = #fn_name(typed_args).await?;
// Use Display formatting instead of JSON serialization // Use Display formatting instead of JSON serialization
let formatted_result = format!("{}", result); let formatted_result = format!("{}", result);
Ok(::worker_types::serde_json::Value::String(formatted_result)) Ok(::worker::types::serde_json::Value::String(formatted_result))
} }
} }
@ -270,8 +270,8 @@ pub fn hook(attr: TokenStream, item: TokenStream) -> TokenStream {
} }
// Implement WorkerHook trait // Implement WorkerHook trait
#[::worker_types::async_trait::async_trait] #[::worker::types::async_trait::async_trait]
impl ::worker_types::WorkerHook for #hook_struct_name { impl ::worker::types::WorkerHook for #hook_struct_name {
fn name(&self) -> &str { fn name(&self) -> &str {
#fn_name_str #fn_name_str
} }
@ -284,7 +284,7 @@ pub fn hook(attr: TokenStream, item: TokenStream) -> TokenStream {
#matcher #matcher
} }
async fn execute(&self, context: ::worker_types::HookContext) -> (::worker_types::HookContext, ::worker_types::HookResult) { async fn execute(&self, context: ::worker::types::HookContext) -> (::worker::types::HookContext, ::worker::types::HookResult) {
#fn_name(context).await #fn_name(context).await
} }
} }

View File

@ -341,70 +341,6 @@ pub enum StreamEvent {
}, },
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoleConfig {
pub name: String,
pub description: String,
#[serde(default)]
pub version: Option<String>,
// 新しいテンプレートベースの設定
pub template: Option<String>,
pub partials: Option<HashMap<String, PartialConfig>>,
// 従来の prompt フィールドもサポート(後方互換性のため)
pub prompt: Option<PromptConfig>,
#[serde(skip)]
pub path: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartialConfig {
pub path: String,
#[serde(default)]
pub description: Option<String>,
}
impl Default for RoleConfig {
fn default() -> Self {
Self {
name: String::new(),
description: String::new(),
version: None,
template: None,
partials: None,
prompt: None,
path: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(untagged)]
pub enum PromptComponent {
#[default]
None,
Single(PromptComponentDetail),
Multiple(Vec<PromptComponentDetail>),
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PromptComponentDetail {
pub path: String,
#[serde(flatten)]
pub inner: Option<HashMap<String, PromptComponentDetail>>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PromptConfig {
#[serde(rename = "ROLE_DEFINE")]
pub role_define: Option<PromptComponent>,
#[serde(rename = "BASIS")]
pub basis: Option<PromptComponent>,
#[serde(rename = "TOOL_USE")]
pub tool_use: Option<PromptComponent>,
#[serde(rename = "SECTIONS")]
pub sections: Option<PromptComponent>,
}
// Session management types // Session management types
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionData { pub struct SessionData {

View File

@ -36,6 +36,12 @@ regex = "1.10.2"
uuid = { version = "1.10", features = ["v4", "serde"] } uuid = { version = "1.10", features = ["v4", "serde"] }
tokio-util = { version = "0.7", features = ["codec"] } tokio-util = { version = "0.7", features = ["codec"] }
futures = "0.3" futures = "0.3"
# Optional dependency for dynamic plugin loading
libloading = { version = "0.8", optional = true }
[features]
default = []
dynamic-loading = ["libloading"]
[dev-dependencies] [dev-dependencies]
tempfile = "3.10.1" tempfile = "3.10.1"

View File

@ -1,150 +0,0 @@
# `worker` クレート
`worker` クレートは、大規模言語モデル (LLM) を利用したアプリケーションのバックエンド機能を提供するコアコンポーネントです。LLMプロバイダーの抽象化、ツール利用、柔軟なプロンプト管理、フックシステムなど、高度な機能をカプセル化し、アプリケーション開発を簡素化します。
## 主な機能
- **マルチプロバイダー対応**: Gemini, Claude, OpenAI, Ollama, XAIなど、複数のLLMプロバイダーを統一されたインターフェースで利用できます。
- **ツール利用 (Function Calling)**: LLMが外部ツールを呼び出す機能をサポートします。独自のツールを簡単に定義して `Worker` に登録できます。
- **ストリーミング処理**: LLMの応答やツール実行結果を `StreamEvent` として非同期に受け取ることができます。これにより、リアルタイムなUI更新が可能になります。
- **フックシステム**: `Worker` の処理フローの特定のタイミング(例: メッセージ送信前、ツール使用後)にカスタムロジックを介入させることができます。
- **セッション管理**: 会話履歴やワークスペースの状態を管理し、永続化する機能を提供します。
- **柔軟なプロンプト管理**: 設定ファイルを用いて、ロールやコンテキストに応じたシステムプロンプトを動的に構築します。
## 主な概念
### `Worker`
このクレートの中心的な構造体です。LLMとの対話、ツールの登録と実行、セッション管理など、すべての主要な機能を担当します。
### `LlmProvider`
サポートしているLLMプロバイダー`Gemini`, `Claude`, `OpenAI` などを表すenumです。
### `Tool` トレイト
`Worker` が利用できるツールを定義するためのインターフェースです。このトレイトを実装することで、任意の機能をツールとして `Worker` に追加できます。
```rust
pub trait Tool: Send + Sync {
fn name(&self) -> &str;
fn description(&self) -> &str;
fn parameters_schema(&self) -> serde_json::Value;
async fn execute(&self, args: serde_json::Value) -> Result<serde_json::Value, crate::WorkerError>;
}
```
### `WorkerHook` トレイト
`Worker` のライフサイクルイベントに介入するためのフックを定義するインターフェースです。特定のイベント(例: `OnMessageSend`, `PostToolUse`)に対して処理を追加できます。
### `StreamEvent`
`Worker` の処理結果を非同期ストリームで受け取るためのenumです。LLMの応答チャンク、ツール呼び出し、エラーなど、さまざまなイベントを表します。
## アプリケーションへの組み込み方法
### 1. Workerの初期化
まず、`Worker` のインスタンスを作成します。これには `LlmProvider`、モデル名、APIキーが必要です。
```rust
use worker::{Worker, LlmProvider};
use std::collections::HashMap;
// APIキーを準備
let mut api_keys = HashMap::new();
api_keys.insert("openai".to_string(), "your_openai_api_key".to_string());
api_keys.insert("claude".to_string(), "your_claude_api_key".to_string());
// Workerを作成
let mut worker = Worker::new(
LlmProvider::OpenAI,
"gpt-4o",
&api_keys,
None // RoleConfigはオプション
).expect("Workerの作成に失敗しました");
```
### 2. ツールの定義と登録
`Tool` トレイトを実装してカスタムツールを作成し、`Worker` に登録します。
```rust
use worker::{Tool, ToolResult};
use worker::schemars::{self, JsonSchema};
use worker::serde_json::{self, json, Value};
use async_trait::async_trait;
// ツールの引数を定義
#[derive(Debug, serde::Deserialize, JsonSchema)]
struct FileSystemToolArgs {
path: String,
}
// カスタムツールを定義
struct ListFilesTool;
#[async_trait]
impl Tool for ListFilesTool {
fn name(&self) -> &str { "list_files" }
fn description(&self) -> &str { "指定されたパスのファイル一覧を表示します" }
fn parameters_schema(&self) -> Value {
serde_json::to_value(schemars::schema_for!(FileSystemToolArgs)).unwrap()
}
async fn execute(&self, args: Value) -> ToolResult<Value> {
let tool_args: FileSystemToolArgs = serde_json::from_value(args)?;
// ここで実際のファイル一覧取得処理を実装
let files = vec!["file1.txt", "file2.txt"];
Ok(json!({ "files": files }))
}
}
// 作成したツールをWorkerに登録
worker.register_tool(Box::new(ListFilesTool)).unwrap();
```
### 3. 対話処理の実行
`process_task_with_history` メソッドを呼び出して、ユーザーメッセージを処理します。このメソッドはイベントのストリームを返します。
```rust
use futures_util::StreamExt;
let user_message = "カレントディレクトリのファイルを教えて".to_string();
let mut stream = worker.process_task_with_history(user_message, None).await;
while let Some(event_result) = stream.next().await {
match event_result {
Ok(event) => {
// StreamEventに応じた処理
match event {
worker::StreamEvent::Chunk(chunk) => {
print!("{}", chunk);
}
worker::StreamEvent::ToolCall(tool_call) => {
println!("\n[Tool Call: {} with args {}]", tool_call.name, tool_call.arguments);
}
worker::StreamEvent::ToolResult { tool_name, result } => {
println!("\n[Tool Result: {} -> {:?}]", tool_name, result);
}
_ => {}
}
}
Err(e) => {
eprintln!("\n[Error: {}]", e);
break;
}
}
}
```
### 4. (オプション) フックの登録
`WorkerHook` トレイトを実装してカスタムフックを作成し、`Worker` に登録することで、処理フローをカスタマイズできます。
```rust
// (WorkerHookの実装は省略)
// let my_hook = MyCustomHook::new();
// worker.register_hook(Box::new(my_hook));
```
これで、アプリケーションの要件に応じて `Worker` を中心とした強力なLLM連携機能を構築できます。

View File

@ -0,0 +1,44 @@
use worker::{LlmProvider, Worker, Role};
use std::collections::HashMap;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Example 1: Basic role
let role = Role::new(
"assistant",
"A helpful AI assistant",
"You are a helpful, harmless, and honest AI assistant.",
);
let mut api_keys = HashMap::new();
api_keys.insert("claude".to_string(), std::env::var("ANTHROPIC_API_KEY")?);
let worker = Worker::builder()
.provider(LlmProvider::Claude)
.model("claude-3-sonnet-20240229")
.api_keys(api_keys)
.role(role)
.build()?;
println!("✅ Worker created with builder pattern");
println!(" Provider: {:?}", worker.get_provider_name());
println!(" Model: {}", worker.get_model_name());
// Example 2: Code reviewer role
let code_reviewer_role = Role::new(
"code-reviewer",
"An AI that reviews code for best practices",
"You are an expert code reviewer. Always provide constructive feedback.",
);
let _worker2 = Worker::builder()
.provider(LlmProvider::Claude)
.model("claude-3-sonnet-20240229")
.api_key("claude", std::env::var("ANTHROPIC_API_KEY")?)
.role(code_reviewer_role)
.build()?;
println!("✅ Worker created with custom role");
Ok(())
}

View File

@ -0,0 +1,82 @@
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use worker::{Worker, Role, plugin::{PluginRegistry, ProviderPlugin, example_provider::CustomProviderPlugin}};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Initialize tracing for debugging
tracing_subscriber::fmt::init();
// Create a plugin registry
let plugin_registry = Arc::new(Mutex::new(PluginRegistry::new()));
// Create and initialize a custom provider plugin
let mut custom_plugin = CustomProviderPlugin::new();
let mut config = HashMap::new();
config.insert(
"base_url".to_string(),
serde_json::Value::String("https://api.custom-provider.com".to_string()),
);
config.insert(
"timeout".to_string(),
serde_json::Value::Number(serde_json::Number::from(60)),
);
custom_plugin.initialize(config).await?;
// Register the plugin
{
let mut registry = plugin_registry.lock().unwrap();
registry.register(Arc::new(custom_plugin))?;
}
// List available plugins
{
let registry = plugin_registry.lock().unwrap();
let plugins = registry.list();
println!("Available plugins:");
for plugin in plugins {
println!(" - {} ({}): {}", plugin.name, plugin.id, plugin.description);
println!(" Supported models: {:?}", plugin.supported_models);
}
}
// Create a Worker instance using the plugin
let role = Role::new(
"assistant",
"A helpful AI assistant",
"You are a helpful, harmless, and honest AI assistant powered by a custom LLM provider."
);
let worker = Worker::builder()
.plugin("custom-provider", plugin_registry.clone())
.model("custom-turbo")
.api_key("__plugin__", "custom-1234567890abcdefghijklmnop")
.role(role)
.build()?;
println!("\nWorker created successfully with custom provider plugin!");
// Example: List plugins from the worker
let plugin_list = worker.list_plugins()?;
println!("\nPlugins registered in worker:");
for metadata in plugin_list {
println!(" - {}: v{} by {}", metadata.name, metadata.version, metadata.author);
}
// Load plugins from directory (if dynamic loading is enabled)
#[cfg(feature = "dynamic-loading")]
{
use std::path::Path;
let plugin_dir = Path::new("./plugins");
if plugin_dir.exists() {
let mut worker = worker;
worker.load_plugins_from_directory(plugin_dir).await?;
println!("\nLoaded plugins from directory: {:?}", plugin_dir);
}
}
Ok(())
}

260
worker/src/builder.rs Normal file
View File

@ -0,0 +1,260 @@
use crate::Worker;
use crate::prompt::Role;
use crate::types::WorkerError;
use worker_types::LlmProvider;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};
// Type-state markers
pub struct NoProvider;
pub struct WithProvider;
pub struct NoModel;
pub struct WithModel;
pub struct NoRole;
pub struct WithRole;
/// WorkerBuilder with type-state pattern
///
/// This ensures at compile-time that all required fields are set.
///
/// # Example
/// ```no_run
/// use worker::{Worker, LlmProvider, Role};
///
/// let role = Role::new("assistant", "AI Assistant", "You are a helpful assistant.");
/// let worker = Worker::builder()
/// .provider(LlmProvider::Claude)
/// .model("claude-3-sonnet-20240229")
/// .api_key("claude", "sk-ant-...")
/// .role(role)
/// .build()?;
/// # Ok::<(), worker::WorkerError>(())
/// ```
pub struct WorkerBuilder<P, M, R> {
provider: Option<LlmProvider>,
model_name: Option<String>,
api_keys: HashMap<String, String>,
// Role configuration (required)
role: Option<Role>,
// Plugin configuration
plugin_id: Option<String>,
plugin_registry: Option<Arc<Mutex<crate::plugin::PluginRegistry>>>,
_phantom: PhantomData<(P, M, R)>,
}
impl Default for WorkerBuilder<NoProvider, NoModel, NoRole> {
fn default() -> Self {
Self {
provider: None,
model_name: None,
api_keys: HashMap::new(),
role: None,
plugin_id: None,
plugin_registry: None,
_phantom: PhantomData,
}
}
}
impl WorkerBuilder<NoProvider, NoModel, NoRole> {
/// Create a new WorkerBuilder
pub fn new() -> Self {
Self::default()
}
}
// Step 1: Set provider
impl<M, R> WorkerBuilder<NoProvider, M, R> {
pub fn provider(mut self, provider: LlmProvider) -> WorkerBuilder<WithProvider, M, R> {
self.provider = Some(provider);
WorkerBuilder {
provider: self.provider,
model_name: self.model_name,
api_keys: self.api_keys,
role: self.role,
plugin_id: self.plugin_id,
plugin_registry: self.plugin_registry,
_phantom: PhantomData,
}
}
/// Use a plugin provider instead of built-in provider
pub fn plugin(
mut self,
plugin_id: impl Into<String>,
registry: Arc<Mutex<crate::plugin::PluginRegistry>>,
) -> WorkerBuilder<WithProvider, M, R> {
self.plugin_id = Some(plugin_id.into());
self.plugin_registry = Some(registry);
WorkerBuilder {
provider: None,
model_name: self.model_name,
api_keys: self.api_keys,
role: self.role,
plugin_id: self.plugin_id,
plugin_registry: self.plugin_registry,
_phantom: PhantomData,
}
}
}
// Step 2: Set model
impl<R> WorkerBuilder<WithProvider, NoModel, R> {
pub fn model(mut self, model_name: impl Into<String>) -> WorkerBuilder<WithProvider, WithModel, R> {
self.model_name = Some(model_name.into());
WorkerBuilder {
provider: self.provider,
model_name: self.model_name,
api_keys: self.api_keys,
role: self.role,
plugin_id: self.plugin_id,
plugin_registry: self.plugin_registry,
_phantom: PhantomData,
}
}
}
// Step 3: Set role
impl WorkerBuilder<WithProvider, WithModel, NoRole> {
pub fn role(mut self, role: Role) -> WorkerBuilder<WithProvider, WithModel, WithRole> {
self.role = Some(role);
WorkerBuilder {
provider: self.provider,
model_name: self.model_name,
api_keys: self.api_keys,
role: self.role,
plugin_id: self.plugin_id,
plugin_registry: self.plugin_registry,
_phantom: PhantomData,
}
}
}
// Optional configurations (available at any stage)
impl<P, M, R> WorkerBuilder<P, M, R> {
/// Add API key for a provider
pub fn api_key(mut self, provider: impl Into<String>, key: impl Into<String>) -> Self {
self.api_keys.insert(provider.into(), key.into());
self
}
/// Set multiple API keys at once
pub fn api_keys(mut self, keys: HashMap<String, String>) -> Self {
self.api_keys = keys;
self
}
}
// Build
impl WorkerBuilder<WithProvider, WithModel, WithRole> {
pub fn build(self) -> Result<Worker, WorkerError> {
use crate::{LlmProviderExt, WorkspaceDetector, PromptComposer, plugin};
let role = self.role.unwrap();
let model_name = self.model_name.unwrap();
// Plugin provider
if let (Some(plugin_id), Some(plugin_registry)) = (self.plugin_id, self.plugin_registry) {
let api_key_opt = self.api_keys.get("__plugin__").or_else(|| {
self.api_keys.values().next()
});
let registry = plugin_registry.lock()
.map_err(|e| WorkerError::config(format!("Failed to lock plugin registry: {}", e)))?;
let plugin = registry.get(&plugin_id)
.ok_or_else(|| WorkerError::config(format!("Plugin not found: {}", plugin_id)))?;
let llm_client = plugin::PluginClient::new(
plugin.clone(),
&model_name,
api_key_opt.map(|s| s.as_str()),
None,
)?;
let provider_str = plugin_id.clone();
let api_key = api_key_opt.map(|s| s.to_string()).unwrap_or_default();
let workspace_context = WorkspaceDetector::detect_workspace().ok();
let prompt_context = Worker::create_prompt_context_static(
&workspace_context,
worker_types::LlmProvider::OpenAI,
&model_name,
&[],
);
tracing::info!("Creating worker with plugin and role: {}", role.name);
let composer = PromptComposer::from_config(role.clone(), prompt_context)
.map_err(|e| WorkerError::config(e.to_string()))?;
drop(registry);
let mut worker = Worker {
llm_client: Box::new(llm_client),
composer,
tools: Vec::new(),
api_key,
provider_str,
model_name,
role,
workspace_context,
message_history: Vec::new(),
hook_manager: worker_types::HookManager::new(),
mcp_lazy_configs: Vec::new(),
plugin_registry: plugin_registry.clone(),
};
worker.initialize_session()
.map_err(|e| WorkerError::config(e.to_string()))?;
return Ok(worker);
}
// Standard provider
let provider = self.provider.unwrap();
let provider_str = provider.as_str();
let api_key = self.api_keys.get(provider_str).cloned().unwrap_or_default();
let llm_client = provider.create_client(&model_name, &api_key)?;
let plugin_registry = Arc::new(Mutex::new(plugin::PluginRegistry::new()));
let workspace_context = WorkspaceDetector::detect_workspace().ok();
let prompt_context = Worker::create_prompt_context_static(
&workspace_context,
provider.clone(),
&model_name,
&[],
);
tracing::info!("Creating worker with role: {}", role.name);
let composer = PromptComposer::from_config(role.clone(), prompt_context)
.map_err(|e| WorkerError::config(e.to_string()))?;
let mut worker = Worker {
llm_client: Box::new(llm_client),
composer,
tools: Vec::new(),
api_key,
provider_str: provider_str.to_string(),
model_name,
role,
workspace_context,
message_history: Vec::new(),
hook_manager: worker_types::HookManager::new(),
mcp_lazy_configs: Vec::new(),
plugin_registry,
};
worker.initialize_session()
.map_err(|e| WorkerError::config(e.to_string()))?;
Ok(worker)
}
}

86
worker/src/client.rs Normal file
View File

@ -0,0 +1,86 @@
// LLM client wrapper that provides a unified interface for all provider clients
use crate::core::LlmClientTrait;
use crate::llm::{
anthropic::AnthropicClient, gemini::GeminiClient, ollama::OllamaClient, openai::OpenAIClient,
xai::XAIClient,
};
use crate::types::WorkerError;
use worker_types::{LlmProvider, Message, StreamEvent};
use futures_util::Stream;
// LlmClient enumを使用してdyn互換性の問題を解決
pub enum LlmClient {
Anthropic(AnthropicClient),
Gemini(GeminiClient),
Ollama(OllamaClient),
OpenAI(OpenAIClient),
XAI(XAIClient),
}
// LlmClient enumに対するメソッド実装を削除
// 代わりにLlmClientTraitの実装のみを使用
// 委譲マクロでボイラープレートを削減
macro_rules! delegate_to_client {
// 引数ありの場合
($self:expr, $method:ident, $($arg:expr),+) => {
match $self {
LlmClient::Anthropic(client) => client.$method($($arg),*),
LlmClient::Gemini(client) => client.$method($($arg),*),
LlmClient::Ollama(client) => client.$method($($arg),*),
LlmClient::OpenAI(client) => client.$method($($arg),*),
LlmClient::XAI(client) => client.$method($($arg),*),
}
};
// 引数なしの場合
($self:expr, $method:ident) => {
match $self {
LlmClient::Anthropic(client) => client.$method(),
LlmClient::Gemini(client) => client.$method(),
LlmClient::Ollama(client) => client.$method(),
LlmClient::OpenAI(client) => client.$method(),
LlmClient::XAI(client) => client.$method(),
}
};
}
// LlmClient enum にtraitを実装 - マクロで簡潔な委譲
#[async_trait::async_trait]
impl LlmClientTrait for LlmClient {
async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[crate::types::DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
match self {
LlmClient::Anthropic(client) => client.chat_stream(messages, tools, llm_debug).await,
LlmClient::Gemini(client) => client.chat_stream(messages, tools, llm_debug).await,
LlmClient::Ollama(client) => client.chat_stream(messages, tools, llm_debug).await,
LlmClient::OpenAI(client) => client.chat_stream(messages, tools, llm_debug).await,
LlmClient::XAI(client) => client.chat_stream(messages, tools, llm_debug).await,
}
}
async fn check_connection(&self) -> Result<(), WorkerError> {
match self {
LlmClient::Anthropic(client) => client.check_connection().await,
LlmClient::Gemini(client) => client.check_connection().await,
LlmClient::Ollama(client) => client.check_connection().await,
LlmClient::OpenAI(client) => client.check_connection().await,
LlmClient::XAI(client) => client.check_connection().await,
}
}
fn provider(&self) -> LlmProvider {
delegate_to_client!(self, provider)
}
fn get_model_name(&self) -> String {
delegate_to_client!(self, get_model_name)
}
}

5
worker/src/config/mod.rs Normal file
View File

@ -0,0 +1,5 @@
mod parser;
mod url;
pub use parser::ConfigParser;
pub use url::UrlConfig;

View File

@ -1,4 +1,4 @@
use crate::prompt_types::*; use crate::prompt::{PromptError, Role};
use std::fs; use std::fs;
use std::path::Path; use std::path::Path;
@ -7,7 +7,7 @@ pub struct ConfigParser;
impl ConfigParser { impl ConfigParser {
/// YAML設定ファイルを読み込んでパースする /// YAML設定ファイルを読み込んでパースする
pub fn parse_from_file<P: AsRef<Path>>(path: P) -> Result<PromptRoleConfig, PromptError> { pub fn parse_from_file<P: AsRef<Path>>(path: P) -> Result<Role, PromptError> {
let content = fs::read_to_string(path.as_ref()).map_err(|e| { let content = fs::read_to_string(path.as_ref()).map_err(|e| {
PromptError::FileNotFound(format!("{}: {}", path.as_ref().display(), e)) PromptError::FileNotFound(format!("{}: {}", path.as_ref().display(), e))
})?; })?;
@ -15,9 +15,9 @@ impl ConfigParser {
Self::parse_from_string(&content) Self::parse_from_string(&content)
} }
/// YAML文字列をパースしてPromptRoleConfigに変換する /// YAML文字列をパースしてRoleに変換する
pub fn parse_from_string(content: &str) -> Result<PromptRoleConfig, PromptError> { pub fn parse_from_string(content: &str) -> Result<Role, PromptError> {
let config: PromptRoleConfig = serde_yaml::from_str(content)?; let config: Role = serde_yaml::from_str(content)?;
// 基本的なバリデーション // 基本的なバリデーション
Self::validate_config(&config)?; Self::validate_config(&config)?;
@ -26,7 +26,7 @@ impl ConfigParser {
} }
/// 設定ファイルの基本的なバリデーション /// 設定ファイルの基本的なバリデーション
fn validate_config(config: &PromptRoleConfig) -> Result<(), PromptError> { fn validate_config(config: &Role) -> Result<(), PromptError> {
if config.name.is_empty() { if config.name.is_empty() {
return Err(PromptError::VariableResolution( return Err(PromptError::VariableResolution(
"name field cannot be empty".to_string(), "name field cannot be empty".to_string(),
@ -62,17 +62,19 @@ impl ConfigParser {
let project_root = std::env::current_dir() let project_root = std::env::current_dir()
.map_err(|e| PromptError::WorkspaceDetection(e.to_string()))?; .map_err(|e| PromptError::WorkspaceDetection(e.to_string()))?;
// 優先順位: ./resources > ./nia-cli/resources > ../nia-cli/resources // 優先順位: ./resources > ./cli/resources > ./nia-core/resources > ./nia-pod/resources
let possible_paths = [ let possible_paths = [
project_root.join("resources").join(relative_path), project_root.join("resources").join(relative_path),
project_root project_root
.join("nia-cli") .join("cli")
.join("resources") .join("resources")
.join(relative_path), .join(relative_path),
project_root project_root
.parent() .join("nia-core")
.unwrap_or(&project_root) .join("resources")
.join("nia-cli") .join(relative_path),
project_root
.join("nia-pod")
.join("resources") .join("resources")
.join(relative_path), .join(relative_path),
]; ];

View File

@ -86,12 +86,15 @@ mod tests {
#[test] #[test]
fn test_default_urls() { fn test_default_urls() {
// SAFETY: Setting test environment variables in a single-threaded test context
// Clean up any existing env vars first // Clean up any existing env vars first
unsafe {
env::remove_var("OPENAI_BASE_URL"); env::remove_var("OPENAI_BASE_URL");
env::remove_var("ANTHROPIC_BASE_URL"); env::remove_var("ANTHROPIC_BASE_URL");
env::remove_var("GEMINI_BASE_URL"); env::remove_var("GEMINI_BASE_URL");
env::remove_var("XAI_BASE_URL"); env::remove_var("XAI_BASE_URL");
env::remove_var("OLLAMA_BASE_URL"); env::remove_var("OLLAMA_BASE_URL");
}
assert_eq!(UrlConfig::get_base_url("openai"), "https://api.openai.com"); assert_eq!(UrlConfig::get_base_url("openai"), "https://api.openai.com");
assert_eq!( assert_eq!(
@ -108,12 +111,15 @@ mod tests {
#[test] #[test]
fn test_env_override() { fn test_env_override() {
// SAFETY: Setting test environment variables in a single-threaded test context
// Clean up any existing env vars first // Clean up any existing env vars first
unsafe {
env::remove_var("OPENAI_BASE_URL"); env::remove_var("OPENAI_BASE_URL");
env::remove_var("ANTHROPIC_BASE_URL"); env::remove_var("ANTHROPIC_BASE_URL");
env::set_var("OPENAI_BASE_URL", "https://custom.openai.com"); env::set_var("OPENAI_BASE_URL", "https://custom.openai.com");
env::set_var("ANTHROPIC_BASE_URL", "https://custom.anthropic.com"); env::set_var("ANTHROPIC_BASE_URL", "https://custom.anthropic.com");
}
assert_eq!( assert_eq!(
UrlConfig::get_base_url("openai"), UrlConfig::get_base_url("openai"),
@ -125,16 +131,21 @@ mod tests {
); );
// Clean up // Clean up
unsafe {
env::remove_var("OPENAI_BASE_URL"); env::remove_var("OPENAI_BASE_URL");
env::remove_var("ANTHROPIC_BASE_URL"); env::remove_var("ANTHROPIC_BASE_URL");
} }
}
#[test] #[test]
fn test_models_url() { fn test_models_url() {
// SAFETY: Setting test environment variables in a single-threaded test context
// Clean up any existing env vars first // Clean up any existing env vars first
unsafe {
env::remove_var("OPENAI_BASE_URL"); env::remove_var("OPENAI_BASE_URL");
env::remove_var("ANTHROPIC_BASE_URL"); env::remove_var("ANTHROPIC_BASE_URL");
env::remove_var("OLLAMA_BASE_URL"); env::remove_var("OLLAMA_BASE_URL");
}
assert_eq!( assert_eq!(
UrlConfig::get_models_url("openai"), UrlConfig::get_models_url("openai"),
@ -152,10 +163,13 @@ mod tests {
#[test] #[test]
fn test_completion_url() { fn test_completion_url() {
// SAFETY: Setting test environment variables in a single-threaded test context
// Clean up any existing env vars first // Clean up any existing env vars first
unsafe {
env::remove_var("OPENAI_BASE_URL"); env::remove_var("OPENAI_BASE_URL");
env::remove_var("ANTHROPIC_BASE_URL"); env::remove_var("ANTHROPIC_BASE_URL");
env::remove_var("OLLAMA_BASE_URL"); env::remove_var("OLLAMA_BASE_URL");
}
assert_eq!( assert_eq!(
UrlConfig::get_completion_url("openai"), UrlConfig::get_completion_url("openai"),
@ -173,19 +187,24 @@ mod tests {
#[test] #[test]
fn test_get_active_overrides() { fn test_get_active_overrides() {
// SAFETY: Setting test environment variables in a single-threaded test context
// Clean up any existing env vars first // Clean up any existing env vars first
unsafe {
env::remove_var("OPENAI_BASE_URL"); env::remove_var("OPENAI_BASE_URL");
env::remove_var("ANTHROPIC_BASE_URL"); env::remove_var("ANTHROPIC_BASE_URL");
env::remove_var("GEMINI_BASE_URL"); env::remove_var("GEMINI_BASE_URL");
env::remove_var("XAI_BASE_URL"); env::remove_var("XAI_BASE_URL");
env::remove_var("OLLAMA_BASE_URL"); env::remove_var("OLLAMA_BASE_URL");
}
// Should return empty when no overrides are set // Should return empty when no overrides are set
assert_eq!(UrlConfig::get_active_overrides().len(), 0); assert_eq!(UrlConfig::get_active_overrides().len(), 0);
// Set some overrides // Set some overrides
unsafe {
env::set_var("OPENAI_BASE_URL", "https://custom-openai.example.com"); env::set_var("OPENAI_BASE_URL", "https://custom-openai.example.com");
env::set_var("ANTHROPIC_BASE_URL", "https://custom-anthropic.example.com"); env::set_var("ANTHROPIC_BASE_URL", "https://custom-anthropic.example.com");
}
let overrides = UrlConfig::get_active_overrides(); let overrides = UrlConfig::get_active_overrides();
assert_eq!(overrides.len(), 2); assert_eq!(overrides.len(), 2);
@ -203,7 +222,9 @@ mod tests {
assert_eq!(anthropic_override.1, "https://custom-anthropic.example.com"); assert_eq!(anthropic_override.1, "https://custom-anthropic.example.com");
// Clean up // Clean up
unsafe {
env::remove_var("OPENAI_BASE_URL"); env::remove_var("OPENAI_BASE_URL");
env::remove_var("ANTHROPIC_BASE_URL"); env::remove_var("ANTHROPIC_BASE_URL");
} }
}
} }

37
worker/src/core.rs Normal file
View File

@ -0,0 +1,37 @@
// Core types and traits for the worker crate
// This module contains the primary abstractions used throughout the crate
use worker_types::{DynamicToolDefinition, LlmProvider, Message, StreamEvent};
use crate::types::WorkerError;
use futures_util::Stream;
/// LlmClient trait - common interface for all LLM clients
///
/// This trait defines the common interface that all LLM client implementations must provide.
/// It's defined here in the core module so it can be imported early in the module hierarchy,
/// before the specific client implementations.
#[async_trait::async_trait]
pub trait LlmClientTrait: Send + Sync {
/// Send a chat message and get a stream of responses
async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[DynamicToolDefinition]>,
llm_debug: Option<worker_types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
>;
/// Check if the connection to the LLM provider is working
async fn check_connection(&self) -> Result<(), WorkerError>;
/// Get the provider type for this client
fn provider(&self) -> LlmProvider;
/// Get the model name being used
fn get_model_name(&self) -> String;
}
// The Worker struct and LlmClient enum are defined in lib.rs
// This file just provides the trait definition that's needed early in the module hierarchy

View File

@ -1,6 +1,8 @@
use crate::prompt_composer::PromptComposer; use crate::prompt::{
use crate::prompt_types::*; PromptComposer, PromptContext, WorkspaceContext, ModelContext, ModelCapabilities,
use crate::workspace_detector::WorkspaceDetector; SessionContext
};
use crate::workspace::WorkspaceDetector;
use async_stream::stream; use async_stream::stream;
use futures_util::{Stream, StreamExt}; use futures_util::{Stream, StreamExt};
use llm::{ use llm::{
@ -11,40 +13,42 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::fs; use std::fs;
use std::path::PathBuf; use std::path::PathBuf;
use thiserror::Error;
use tracing; use tracing;
use uuid; use uuid;
pub use worker_types::{ pub use worker_types::{
DynamicToolDefinition, HookContext, HookEvent, HookManager, HookResult, LlmDebug, LlmProvider, DynamicToolDefinition, HookContext, HookEvent, HookManager, HookResult, LlmDebug, LlmProvider,
LlmResponse, Message, ModelInfo, PartialConfig, PromptComponent, PromptComponentDetail, LlmResponse, Message, ModelInfo, SessionData, StreamEvent, Task, Tool, ToolCall, ToolResult,
PromptConfig, Role, RoleConfig, SessionData, StreamEvent, Task, Tool, ToolCall, ToolResult,
WorkerHook, WorkspaceConfig, WorkspaceData, WorkerHook, WorkspaceConfig, WorkspaceData,
}; };
pub use worker_macros::{hook, tool};
pub mod config_parser; pub mod core;
pub mod llm;
pub mod mcp_config;
pub mod mcp_protocol;
pub mod mcp_tool;
pub mod prompt_composer;
pub mod prompt_types;
pub mod types; pub mod types;
pub mod url_config; pub mod client;
pub mod workspace_detector; pub mod builder;
pub mod config;
pub mod llm;
pub mod mcp;
pub mod plugin;
pub mod prompt;
pub mod workspace;
pub use core::LlmClientTrait;
pub use client::LlmClient;
pub use crate::prompt::Role;
pub use builder::WorkerBuilder;
pub use crate::types::WorkerError;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
mod config_tests; mod config_tests;
// mod integration_tests; // Temporarily disabled due to missing dependencies
} }
// Re-export for tool macros
pub use schemars; pub use schemars;
pub use serde_json; pub use serde_json;
// Re-export MCP functionality pub use mcp::{IntegrationMode, McpConfig, McpServerDefinition};
pub use mcp_config::{IntegrationMode, McpConfig, McpServerDefinition}; pub use mcp::{
pub use mcp_tool::{
McpDynamicTool, McpServerConfig, SingleMcpTool, create_single_mcp_tools, McpDynamicTool, McpServerConfig, SingleMcpTool, create_single_mcp_tools,
get_mcp_tools_as_definitions, test_mcp_connection, get_mcp_tools_as_definitions, test_mcp_connection,
}; };
@ -63,7 +67,6 @@ pub fn generate_tools_schema(provider: &LlmProvider, tools: &[Box<dyn Tool>]) ->
) )
} }
/// ツール定義からスキーマを生成する
fn generate_tools_schema_from_definitions( fn generate_tools_schema_from_definitions(
provider: &LlmProvider, provider: &LlmProvider,
tool_definitions: &[DynamicToolDefinition], tool_definitions: &[DynamicToolDefinition],
@ -115,25 +118,21 @@ fn generate_tools_schema_from_definitions(
} }
} }
pub use crate::types::WorkerError;
impl WorkerError { impl WorkerError {
/// Check if this error is likely an authentication/API key error /// Check if this error is likely an authentication/API key error
pub fn is_authentication_error(&self) -> bool { pub fn is_authentication_error(&self) -> bool {
matches!(self, WorkerError::General(_)) matches!(self, WorkerError::ConfigurationError { .. })
} }
/// Convert a generic error to a WorkerError, detecting authentication issues /// Convert a generic error to a WorkerError, detecting authentication issues
pub fn from_api_error(error: String, provider: &LlmProvider) -> Self { pub fn from_api_error(error: String, provider: &LlmProvider) -> Self {
if Self::is_likely_auth_error(&error, provider) { if Self::is_likely_auth_error(&error, provider) {
WorkerError::Config(error) WorkerError::config(error)
} else { } else {
WorkerError::Network(error) WorkerError::network(error)
} }
} }
/// Comprehensive authentication error detection
/// Many APIs return 400 Bad Request for invalid API keys instead of proper 401/403
fn is_likely_auth_error(error_msg: &str, provider: &LlmProvider) -> bool { fn is_likely_auth_error(error_msg: &str, provider: &LlmProvider) -> bool {
let error_msg = error_msg.to_lowercase(); let error_msg = error_msg.to_lowercase();
tracing::debug!( tracing::debug!(
@ -142,22 +141,18 @@ impl WorkerError {
error_msg error_msg
); );
// Standard auth error codes
let has_auth_status = error_msg.contains("unauthorized") let has_auth_status = error_msg.contains("unauthorized")
|| error_msg.contains("forbidden") || error_msg.contains("forbidden")
|| error_msg.contains("401") || error_msg.contains("401")
|| error_msg.contains("403"); || error_msg.contains("403");
// API key related error messages
let has_api_key_error = error_msg.contains("api key") let has_api_key_error = error_msg.contains("api key")
|| error_msg.contains("invalid key") || error_msg.contains("invalid key")
|| error_msg.contains("authentication") || error_msg.contains("authentication")
|| error_msg.contains("token"); || error_msg.contains("token");
// Bad request that might be auth related
let has_bad_request = error_msg.contains("400") || error_msg.contains("bad request"); let has_bad_request = error_msg.contains("400") || error_msg.contains("bad request");
// Common API key error patterns (case insensitive)
let has_key_patterns = error_msg.contains("incorrect api key") let has_key_patterns = error_msg.contains("incorrect api key")
|| error_msg.contains("invalid api key") || error_msg.contains("invalid api key")
|| error_msg.contains("api key not found") || error_msg.contains("api key not found")
@ -171,18 +166,14 @@ impl WorkerError {
|| error_msg.contains("expired") || error_msg.contains("expired")
|| error_msg.contains("revoked") || error_msg.contains("revoked")
|| error_msg.contains("suspended") || error_msg.contains("suspended")
// Generic "invalid" but exclude credit balance specific messages
|| (error_msg.contains("invalid") && !error_msg.contains("credit balance")); || (error_msg.contains("invalid") && !error_msg.contains("credit balance"));
// Exclude credit balance errors - these are not authentication errors
let is_credit_balance_error = error_msg.contains("credit balance") let is_credit_balance_error = error_msg.contains("credit balance")
|| error_msg.contains("billing") || error_msg.contains("billing")
|| error_msg.contains("upgrade") || error_msg.contains("upgrade")
|| error_msg.contains("purchase credits") || error_msg.contains("purchase credits")
// Also exclude Anthropic's specific credit balance error pattern
|| (error_msg.contains("invalid_request_error") && error_msg.contains("credit balance")); || (error_msg.contains("invalid_request_error") && error_msg.contains("credit balance"));
// Provider-specific patterns
let has_provider_patterns = match provider { let has_provider_patterns = match provider {
LlmProvider::OpenAI => { LlmProvider::OpenAI => {
error_msg.contains("invalid_api_key") error_msg.contains("invalid_api_key")
@ -190,9 +181,7 @@ impl WorkerError {
&& !error_msg.contains("credit balance")) && !error_msg.contains("credit balance"))
} }
LlmProvider::Claude => { LlmProvider::Claude => {
// Anthropic specific auth error patterns
(error_msg.contains("invalid_x_api_key") || error_msg.contains("x-api-key")) (error_msg.contains("invalid_x_api_key") || error_msg.contains("x-api-key"))
// But exclude credit balance issues which are not auth errors
&& !error_msg.contains("credit balance") && !error_msg.contains("credit balance")
&& !error_msg.contains("billing") && !error_msg.contains("billing")
&& !error_msg.contains("upgrade") && !error_msg.contains("upgrade")
@ -201,18 +190,15 @@ impl WorkerError {
LlmProvider::Gemini => { LlmProvider::Gemini => {
error_msg.contains("invalid_argument") || error_msg.contains("credentials") error_msg.contains("invalid_argument") || error_msg.contains("credentials")
} }
LlmProvider::Ollama => false, // Ollama typically doesn't have API keys LlmProvider::Ollama => false,
LlmProvider::XAI => { LlmProvider::XAI => {
error_msg.contains("invalid_api_key") || error_msg.contains("unauthorized") error_msg.contains("invalid_api_key") || error_msg.contains("unauthorized")
} }
}; };
// Generic patterns
let has_generic_patterns = let has_generic_patterns =
error_msg.contains("credentials") || error_msg.contains("authorization"); error_msg.contains("credentials") || error_msg.contains("authorization");
// Provider-specific bad request handling
// Some providers return 400 for auth issues instead of proper status codes
let provider_specific_bad_request = match provider { let provider_specific_bad_request = match provider {
LlmProvider::OpenAI => { LlmProvider::OpenAI => {
has_bad_request has_bad_request
@ -242,9 +228,6 @@ impl WorkerError {
} }
}; };
// If it's a bad request with auth-related keywords, treat as auth error
// This handles APIs that incorrectly return 400 for auth issues
// But exclude credit balance errors which are not authentication issues
let result = (has_auth_status let result = (has_auth_status
|| has_api_key_error || has_api_key_error
|| has_key_patterns || has_key_patterns
@ -292,16 +275,10 @@ pub fn get_supported_providers() -> Vec<String> {
] ]
} }
/// Validate if a provider name is supported
pub fn is_provider_supported(provider_name: &str) -> bool { pub fn is_provider_supported(provider_name: &str) -> bool {
LlmProvider::from_str(provider_name).is_some() LlmProvider::from_str(provider_name).is_some()
} }
/// Validate API key for a specific provider
/// Returns:
/// - Some(true): API key is valid
/// - Some(false): API key is invalid
/// - None: Unable to validate (e.g., no official validation endpoint)
pub async fn validate_api_key( pub async fn validate_api_key(
provider: LlmProvider, provider: LlmProvider,
api_key: &str, api_key: &str,
@ -314,48 +291,41 @@ pub async fn validate_api_key(
return Ok(Some(false)); return Ok(Some(false));
} }
// Only perform validation if provider has a simple, official validation method
match provider { match provider {
LlmProvider::Claude => { LlmProvider::Claude => {
// Anthropic doesn't have a dedicated validation endpoint
// Simple format check: should start with "sk-ant-"
if api_key.starts_with("sk-ant-") && api_key.len() > 20 { if api_key.starts_with("sk-ant-") && api_key.len() > 20 {
tracing::debug!("validate_api_key: Anthropic API key format appears valid"); tracing::debug!("validate_api_key: Anthropic API key format appears valid");
Ok(None) // Cannot validate without making a request Ok(None)
} else { } else {
tracing::debug!("validate_api_key: Anthropic API key format is invalid"); tracing::debug!("validate_api_key: Anthropic API key format is invalid");
Ok(Some(false)) Ok(Some(false))
} }
} }
LlmProvider::OpenAI => { LlmProvider::OpenAI => {
// OpenAI: simple format check
if api_key.starts_with("sk-") && api_key.len() > 20 { if api_key.starts_with("sk-") && api_key.len() > 20 {
tracing::debug!("validate_api_key: OpenAI API key format appears valid"); tracing::debug!("validate_api_key: OpenAI API key format appears valid");
Ok(None) // Cannot validate without making a request Ok(None)
} else { } else {
tracing::debug!("validate_api_key: OpenAI API key format is invalid"); tracing::debug!("validate_api_key: OpenAI API key format is invalid");
Ok(Some(false)) Ok(Some(false))
} }
} }
LlmProvider::Gemini => { LlmProvider::Gemini => {
// Gemini: simple format check
if api_key.len() > 20 { if api_key.len() > 20 {
tracing::debug!("validate_api_key: Gemini API key format appears valid"); tracing::debug!("validate_api_key: Gemini API key format appears valid");
Ok(None) // Cannot validate without making a request Ok(None)
} else { } else {
tracing::debug!("validate_api_key: Gemini API key format is invalid"); tracing::debug!("validate_api_key: Gemini API key format is invalid");
Ok(Some(false)) Ok(Some(false))
} }
} }
LlmProvider::Ollama => { LlmProvider::Ollama => {
// Ollama typically doesn't require API keys
Ok(Some(true)) Ok(Some(true))
} }
LlmProvider::XAI => { LlmProvider::XAI => {
// xAI: simple format check
if api_key.starts_with("xai-") && api_key.len() > 20 { if api_key.starts_with("xai-") && api_key.len() > 20 {
tracing::debug!("validate_api_key: xAI API key format appears valid"); tracing::debug!("validate_api_key: xAI API key format appears valid");
Ok(None) // Cannot validate without making a request Ok(None)
} else { } else {
tracing::debug!("validate_api_key: xAI API key format is invalid"); tracing::debug!("validate_api_key: xAI API key format is invalid");
Ok(Some(false)) Ok(Some(false))
@ -364,7 +334,6 @@ pub async fn validate_api_key(
} }
} }
// Models configuration structures
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ModelsConfig { pub struct ModelsConfig {
pub models: Vec<ModelDefinition>, pub models: Vec<ModelDefinition>,
@ -387,15 +356,13 @@ pub struct ModelMeta {
pub description: Option<String>, pub description: Option<String>,
} }
// Get models config path
fn get_models_config_path() -> Result<PathBuf, WorkerError> { fn get_models_config_path() -> Result<PathBuf, WorkerError> {
let home_dir = dirs::home_dir().ok_or_else(|| { let home_dir = dirs::home_dir().ok_or_else(|| {
WorkerError::ConfigurationError("Could not determine home directory".to_string()) WorkerError::config("Could not determine home directory")
})?; })?;
Ok(home_dir.join(".config").join("nia").join("models.yaml")) Ok(home_dir.join(".config").join("nia").join("models.yaml"))
} }
// Load models configuration
fn load_models_config() -> Result<ModelsConfig, WorkerError> { fn load_models_config() -> Result<ModelsConfig, WorkerError> {
let config_path = get_models_config_path()?; let config_path = get_models_config_path()?;
@ -408,26 +375,23 @@ fn load_models_config() -> Result<ModelsConfig, WorkerError> {
} }
let content = fs::read_to_string(&config_path).map_err(|e| { let content = fs::read_to_string(&config_path).map_err(|e| {
WorkerError::ConfigurationError(format!("Failed to read models config: {}", e)) WorkerError::config(format!("Failed to read models config: {}", e))
})?; })?;
let config: ModelsConfig = serde_yaml::from_str(&content).map_err(|e| { let config: ModelsConfig = serde_yaml::from_str(&content).map_err(|e| {
WorkerError::ConfigurationError(format!("Failed to parse models config: {}", e)) WorkerError::config(format!("Failed to parse models config: {}", e))
})?; })?;
Ok(config) Ok(config)
} }
// Tool support detection using configuration
pub async fn supports_native_tools( pub async fn supports_native_tools(
provider: &LlmProvider, provider: &LlmProvider,
model_name: &str, model_name: &str,
_api_key: &str, _api_key: &str,
) -> Result<bool, WorkerError> { ) -> Result<bool, WorkerError> {
// Load models configuration
let config = load_models_config()?; let config = load_models_config()?;
// Look for the specific model in configuration
let model_id = format!( let model_id = format!(
"{}/{}", "{}/{}",
match provider { match provider {
@ -440,7 +404,6 @@ pub async fn supports_native_tools(
model_name model_name
); );
// Find model in config and check function_calling setting
for model_def in &config.models { for model_def in &config.models {
if model_def.model == model_id || model_def.model.contains(model_name) { if model_def.model == model_id || model_def.model.contains(model_name) {
tracing::debug!( tracing::debug!(
@ -458,8 +421,6 @@ pub async fn supports_native_tools(
model_name model_name
); );
// Fallback to provider-based detection if model not found in config
// But prioritize setting over provider defaults
tracing::warn!( tracing::warn!(
"Using provider-based fallback - this should be configured in models.yaml: provider={:?}, model={}", "Using provider-based fallback - this should be configured in models.yaml: provider={:?}, model={}",
provider, provider,
@ -470,7 +431,7 @@ pub async fn supports_native_tools(
LlmProvider::Claude => true, LlmProvider::Claude => true,
LlmProvider::OpenAI => !model_name.contains("gpt-3.5-turbo-instruct"), LlmProvider::OpenAI => !model_name.contains("gpt-3.5-turbo-instruct"),
LlmProvider::Gemini => !model_name.contains("gemini-pro-vision"), LlmProvider::Gemini => !model_name.contains("gemini-pro-vision"),
LlmProvider::Ollama => false, // Default to XML-based tools for Ollama LlmProvider::Ollama => false,
LlmProvider::XAI => true, LlmProvider::XAI => true,
}; };
@ -483,183 +444,63 @@ pub async fn supports_native_tools(
Ok(supports_tools) Ok(supports_tools)
} }
// LlmClient trait - 共通インターフェース
#[async_trait::async_trait]
pub trait LlmClientTrait: Send + Sync {
async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[crate::types::DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
>;
async fn check_connection(&self) -> Result<(), WorkerError>;
fn provider(&self) -> LlmProvider;
fn get_model_name(&self) -> String;
}
// LlmClient enumを使用してdyn互換性の問題を解決
pub enum LlmClient {
Anthropic(AnthropicClient),
Gemini(GeminiClient),
Ollama(OllamaClient),
OpenAI(OpenAIClient),
XAI(XAIClient),
}
// LlmClient enumに対するメソッド実装を削除
// 代わりにLlmClientTraitの実装のみを使用
// 委譲マクロでボイラープレートを削減
macro_rules! delegate_to_client {
// 引数ありの場合
($self:expr, $method:ident, $($arg:expr),+) => {
match $self {
LlmClient::Anthropic(client) => client.$method($($arg),*),
LlmClient::Gemini(client) => client.$method($($arg),*),
LlmClient::Ollama(client) => client.$method($($arg),*),
LlmClient::OpenAI(client) => client.$method($($arg),*),
LlmClient::XAI(client) => client.$method($($arg),*),
}
};
// 引数なしの場合
($self:expr, $method:ident) => {
match $self {
LlmClient::Anthropic(client) => client.$method(),
LlmClient::Gemini(client) => client.$method(),
LlmClient::Ollama(client) => client.$method(),
LlmClient::OpenAI(client) => client.$method(),
LlmClient::XAI(client) => client.$method(),
}
};
}
// LlmClient enum にtraitを実装 - マクロで簡潔な委譲
#[async_trait::async_trait]
impl LlmClientTrait for LlmClient {
async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[crate::types::DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
match self {
LlmClient::Anthropic(client) => client.chat_stream(messages, tools, llm_debug).await,
LlmClient::Gemini(client) => client.chat_stream(messages, tools, llm_debug).await,
LlmClient::Ollama(client) => client.chat_stream(messages, tools, llm_debug).await,
LlmClient::OpenAI(client) => client.chat_stream(messages, tools, llm_debug).await,
LlmClient::XAI(client) => client.chat_stream(messages, tools, llm_debug).await,
}
}
async fn check_connection(&self) -> Result<(), WorkerError> {
match self {
LlmClient::Anthropic(client) => client.check_connection().await,
LlmClient::Gemini(client) => client.check_connection().await,
LlmClient::Ollama(client) => client.check_connection().await,
LlmClient::OpenAI(client) => client.check_connection().await,
LlmClient::XAI(client) => client.check_connection().await,
}
}
fn provider(&self) -> LlmProvider {
delegate_to_client!(self, provider)
}
fn get_model_name(&self) -> String {
delegate_to_client!(self, get_model_name)
}
}
pub struct Worker { pub struct Worker {
llm_client: Box<dyn LlmClientTrait>, pub(crate) llm_client: Box<dyn LlmClientTrait>,
composer: PromptComposer, pub(crate) composer: PromptComposer,
tools: Vec<Box<dyn Tool>>, pub(crate) tools: Vec<Box<dyn Tool>>,
api_key: String, pub(crate) api_key: String,
provider_str: String, pub(crate) provider_str: String,
model_name: String, pub(crate) model_name: String,
role_config: Option<RoleConfig>, pub(crate) role: Role,
config: Option<PromptRoleConfig>, pub(crate) workspace_context: Option<WorkspaceContext>,
workspace_context: Option<WorkspaceContext>, pub(crate) message_history: Vec<Message>,
message_history: Vec<Message>, pub(crate) hook_manager: crate::types::HookManager,
hook_manager: crate::types::HookManager, pub(crate) mcp_lazy_configs: Vec<McpServerConfig>,
mcp_lazy_configs: Vec<mcp_tool::McpServerConfig>, pub(crate) plugin_registry: std::sync::Arc<std::sync::Mutex<plugin::PluginRegistry>>,
} }
impl Worker { impl Worker {
pub fn new( /// Create a new WorkerBuilder
provider: LlmProvider, ///
model_name: &str, /// # Example
api_keys: &HashMap<String, String>, /// ```no_run
role_config: Option<RoleConfig>, /// use worker::{Worker, LlmProvider, Role};
) -> Result<Self, WorkerError> { ///
let provider_str = provider.as_str(); /// let role = Role::new("assistant", "AI Assistant", "You are a helpful assistant.");
let api_key = api_keys.get(provider_str).cloned().unwrap_or_default(); /// let worker = Worker::builder()
let llm_client = provider.create_client(model_name, &api_key)?; /// .provider(LlmProvider::Claude)
/// .model("claude-3-sonnet-20240229")
// ワークスペースコンテキストを取得 /// .api_key("claude", "sk-ant-...")
let workspace_context = WorkspaceDetector::detect_workspace().ok(); /// .role(role)
/// .build()?;
// プロンプトコンテキストを作成 /// # Ok::<(), worker::WorkerError>(())
let prompt_context = Self::create_prompt_context_static( /// ```
&workspace_context, pub fn builder() -> builder::WorkerBuilder<builder::NoProvider, builder::NoModel, builder::NoRole> {
provider.clone(), builder::WorkerBuilder::new()
model_name, }
&[],
);
/// Load plugins from a directory
// デフォルト設定またはcli-assistant設定を使用 #[cfg(feature = "dynamic-loading")]
let composer = match Self::try_load_default_config(prompt_context.clone()) { pub async fn load_plugins_from_directory(&mut self, dir: &std::path::Path) -> Result<(), WorkerError> {
Ok(composer) => { let plugins = plugin::PluginLoader::load_from_directory(dir)?;
tracing::info!("Loaded default CLI assistant configuration"); let mut registry = self.plugin_registry.lock()
composer .map_err(|e| WorkerError::config(format!("Failed to lock plugin registry: {}", e)))?;
}
Err(e) => { for plugin in plugins {
tracing::warn!("Failed to load default config, using fallback: {}", e); registry.register(std::sync::Arc::from(plugin))?;
let default_config = PromptRoleConfig::default();
PromptComposer::from_config(default_config, prompt_context)
.map_err(|e| WorkerError::ConfigurationError(e.to_string()))?
}
};
let mut worker = Self {
llm_client: Box::new(llm_client),
composer,
tools: Vec::new(),
api_key,
provider_str: provider_str.to_string(),
model_name: model_name.to_string(),
role_config,
config: None,
workspace_context,
message_history: Vec::new(),
hook_manager: crate::types::HookManager::new(),
mcp_lazy_configs: Vec::new(),
};
// セッション開始時のシステムプロンプト初期化
worker
.initialize_session()
.map_err(|e| WorkerError::ConfigurationError(e.to_string()))?;
Ok(worker)
} }
/// ツールリストをロードする
pub fn load_tools(&mut self, tools: Vec<Box<dyn Tool>>) -> Result<(), WorkerError> {
self.tools.extend(tools);
tracing::info!("Loaded {} tools", self.tools.len());
Ok(()) Ok(())
} }
/// List all registered plugins
pub fn list_plugins(&self) -> Result<Vec<plugin::PluginMetadata>, WorkerError> {
let registry = self.plugin_registry.lock()
.map_err(|e| WorkerError::config(format!("Failed to lock plugin registry: {}", e)))?;
Ok(registry.list())
}
/// フックを登録する /// フックを登録する
pub fn register_hook(&mut self, hook: Box<dyn crate::types::WorkerHook>) { pub fn register_hook(&mut self, hook: Box<dyn crate::types::WorkerHook>) {
let hook_name = hook.name().to_string(); let hook_name = hook.name().to_string();
@ -676,9 +517,9 @@ impl Worker {
/// MCPサーバーをツールとして登録する /// MCPサーバーをツールとして登録する
pub fn register_mcp_server( pub fn register_mcp_server(
&mut self, &mut self,
config: mcp_tool::McpServerConfig, config: McpServerConfig,
) -> Result<(), WorkerError> { ) -> Result<(), WorkerError> {
let mcp_tool = mcp_tool::McpDynamicTool::new(config.clone()); let mcp_tool = McpDynamicTool::new(config.clone());
self.register_tool(Box::new(mcp_tool))?; self.register_tool(Box::new(mcp_tool))?;
tracing::info!("Registered MCP server as tool: {}", config.name); tracing::info!("Registered MCP server as tool: {}", config.name);
Ok(()) Ok(())
@ -687,9 +528,9 @@ impl Worker {
/// MCPサーバーから個別のツールを登録する /// MCPサーバーから個別のツールを登録する
pub async fn register_mcp_tools( pub async fn register_mcp_tools(
&mut self, &mut self,
config: mcp_tool::McpServerConfig, config: McpServerConfig,
) -> Result<(), WorkerError> { ) -> Result<(), WorkerError> {
let tools = mcp_tool::create_single_mcp_tools(&config).await?; let tools = create_single_mcp_tools(&config).await?;
let tool_count = tools.len(); let tool_count = tools.len();
for tool in tools { for tool in tools {
@ -705,7 +546,7 @@ impl Worker {
} }
/// MCP サーバーを並列初期化キューに追加 /// MCP サーバーを並列初期化キューに追加
pub fn queue_mcp_server(&mut self, config: mcp_tool::McpServerConfig) { pub fn queue_mcp_server(&mut self, config: McpServerConfig) {
tracing::info!("Queuing MCP server: {}", config.name); tracing::info!("Queuing MCP server: {}", config.name);
self.mcp_lazy_configs.push(config); self.mcp_lazy_configs.push(config);
} }
@ -735,7 +576,7 @@ impl Worker {
tokio::spawn(async move { tokio::spawn(async move {
tracing::info!("Parallel initializing MCP server: {}", config_name); tracing::info!("Parallel initializing MCP server: {}", config_name);
match mcp_tool::create_single_mcp_tools(&config).await { match create_single_mcp_tools(&config).await {
Ok(tools) => { Ok(tools) => {
tracing::info!( tracing::info!(
"Successfully initialized {} tools from MCP server: {}", "Successfully initialized {} tools from MCP server: {}",
@ -940,25 +781,25 @@ impl Worker {
&mut self, &mut self,
config_path: P, config_path: P,
) -> Result<(), WorkerError> { ) -> Result<(), WorkerError> {
use crate::config_parser::ConfigParser; use crate::config::ConfigParser;
// 設定ファイルを読み込み // 設定ファイルを読み込み
let config = ConfigParser::parse_from_file(config_path) let role = ConfigParser::parse_from_file(config_path)
.map_err(|e| WorkerError::ConfigurationError(e.to_string()))?; .map_err(|e| WorkerError::config(e.to_string()))?;
// プロンプトコンテキストを構築 // プロンプトコンテキストを構築
let prompt_context = self.create_prompt_context()?; let prompt_context = self.create_prompt_context()?;
// DynamicPromptComposerを作成 // DynamicPromptComposerを作成
let composer = PromptComposer::from_config(config.clone(), prompt_context) let composer = PromptComposer::from_config(role.clone(), prompt_context)
.map_err(|e| WorkerError::ConfigurationError(e.to_string()))?; .map_err(|e| WorkerError::config(e.to_string()))?;
self.config = Some(config); self.role = role;
self.composer = composer; self.composer = composer;
// 設定変更後にセッション再初期化 // 設定変更後にセッション再初期化
self.initialize_session() self.initialize_session()
.map_err(|e| WorkerError::ConfigurationError(e.to_string()))?; .map_err(|e| WorkerError::config(e.to_string()))?;
tracing::info!("Dynamic configuration loaded successfully"); tracing::info!("Dynamic configuration loaded successfully");
Ok(()) Ok(())
@ -1014,7 +855,7 @@ impl Worker {
/// プロンプトコンテキストを作成 /// プロンプトコンテキストを作成
fn create_prompt_context(&self) -> Result<PromptContext, WorkerError> { fn create_prompt_context(&self) -> Result<PromptContext, WorkerError> {
let provider = LlmProvider::from_str(&self.provider_str).ok_or_else(|| { let provider = LlmProvider::from_str(&self.provider_str).ok_or_else(|| {
WorkerError::ConfigurationError(format!("Unknown provider: {}", self.provider_str)) WorkerError::config(format!("Unknown provider: {}", self.provider_str))
})?; })?;
// モデル能力を静的に判定 // モデル能力を静的に判定
@ -1084,10 +925,10 @@ impl Worker {
pub fn register_tool(&mut self, tool: Box<dyn Tool>) -> Result<(), WorkerError> { pub fn register_tool(&mut self, tool: Box<dyn Tool>) -> Result<(), WorkerError> {
// 同名のツールが既に存在するかチェック // 同名のツールが既に存在するかチェック
if self.tools.iter().any(|t| t.name() == tool.name()) { if self.tools.iter().any(|t| t.name() == tool.name()) {
return Err(WorkerError::ToolExecutionError(format!( return Err(WorkerError::tool_execution(
"Tool '{}' is already registered", tool.name(),
tool.name() format!("Tool '{}' is already registered", tool.name())
))); ));
} }
self.tools.push(tool); self.tools.push(tool);
@ -1118,10 +959,10 @@ impl Worker {
) -> Result<serde_json::Value, WorkerError> { ) -> Result<serde_json::Value, WorkerError> {
match self.tools.iter().find(|tool| tool.name() == tool_name) { match self.tools.iter().find(|tool| tool.name() == tool_name) {
Some(tool) => tool.execute(args).await.map_err(WorkerError::from), Some(tool) => tool.execute(args).await.map_err(WorkerError::from),
None => Err(WorkerError::ToolExecutionError(format!( None => Err(WorkerError::tool_execution(
"Tool '{}' not found", tool_name,
tool_name format!("Tool '{}' not found", tool_name)
))), )),
} }
} }
@ -1149,12 +990,12 @@ impl Worker {
} }
/// Get configuration information for task delegation /// Get configuration information for task delegation
pub fn get_config(&self) -> (LlmProvider, &str, &str, &Option<RoleConfig>) { pub fn get_config(&self) -> (LlmProvider, &str, &str, &Role) {
( (
self.llm_client.provider(), self.llm_client.provider(),
&self.model_name, &self.model_name,
&self.api_key, &self.api_key,
&self.role_config, &self.role,
) )
} }
@ -1267,7 +1108,7 @@ impl Worker {
let messages = match composer.compose(&conversation_messages) { let messages = match composer.compose(&conversation_messages) {
Ok(m) => m, Ok(m) => m,
Err(e) => { Err(e) => {
yield Err(WorkerError::ConfigurationError(e.to_string())); yield Err(WorkerError::config(e.to_string()));
return; return;
} }
}; };
@ -1278,7 +1119,7 @@ impl Worker {
let messages = match composer.compose_with_tools(&conversation_messages, &tools_schema) { let messages = match composer.compose_with_tools(&conversation_messages, &tools_schema) {
Ok(m) => m, Ok(m) => m,
Err(e) => { Err(e) => {
yield Err(WorkerError::ConfigurationError(e.to_string())); yield Err(WorkerError::config(e.to_string()));
return; return;
} }
}; };
@ -1466,7 +1307,7 @@ impl Worker {
let messages = match self.composer.compose(&conversation_messages) { let messages = match self.composer.compose(&conversation_messages) {
Ok(m) => m, Ok(m) => m,
Err(e) => { Err(e) => {
yield Err(WorkerError::ConfigurationError(e.to_string())); yield Err(WorkerError::config(e.to_string()));
return; return;
} }
}; };
@ -1477,7 +1318,7 @@ impl Worker {
let messages = match self.composer.compose_with_tools(&conversation_messages, &tools_schema) { let messages = match self.composer.compose_with_tools(&conversation_messages, &tools_schema) {
Ok(m) => m, Ok(m) => m,
Err(e) => { Err(e) => {
yield Err(WorkerError::ConfigurationError(e.to_string())); yield Err(WorkerError::config(e.to_string()));
return; return;
} }
}; };
@ -1692,7 +1533,7 @@ impl Worker {
let messages = match self.composer.compose(&conversation_messages) { let messages = match self.composer.compose(&conversation_messages) {
Ok(m) => m, Ok(m) => m,
Err(e) => { Err(e) => {
yield Err(WorkerError::ConfigurationError(e.to_string())); yield Err(WorkerError::config(e.to_string()));
return; return;
} }
}; };
@ -1703,7 +1544,7 @@ impl Worker {
let messages = match self.composer.compose_with_tools(&conversation_messages, &tools_schema) { let messages = match self.composer.compose_with_tools(&conversation_messages, &tools_schema) {
Ok(m) => m, Ok(m) => m,
Err(e) => { Err(e) => {
yield Err(WorkerError::ConfigurationError(e.to_string())); yield Err(WorkerError::config(e.to_string()));
return; return;
} }
}; };
@ -1876,7 +1717,7 @@ impl Worker {
pub fn get_session_data(&self) -> Result<SessionData, WorkerError> { pub fn get_session_data(&self) -> Result<SessionData, WorkerError> {
let workspace_path = std::env::current_dir() let workspace_path = std::env::current_dir()
.map_err(|e| { .map_err(|e| {
WorkerError::ConfigurationError(format!("Failed to get current directory: {}", e)) WorkerError::config(format!("Failed to get current directory: {}", e))
})? })?
.to_string_lossy() .to_string_lossy()
.to_string(); .to_string();
@ -1918,7 +1759,7 @@ impl Worker {
// セッション復元時にプロンプトコンポーザーを再初期化 // セッション復元時にプロンプトコンポーザーを再初期化
self.reinitialize_session_with_history() self.reinitialize_session_with_history()
.map_err(|e| WorkerError::ConfigurationError(e.to_string()))?; .map_err(|e| WorkerError::config(e.to_string()))?;
Ok(()) Ok(())
} }
@ -1970,61 +1811,9 @@ impl Worker {
} }
} }
/// デフォルト設定ファイルの読み込みを試行
fn try_load_default_config(
prompt_context: PromptContext,
) -> Result<PromptComposer, WorkerError> {
use crate::config_parser::ConfigParser;
// デフォルト設定ファイルのパスを試行
let possible_paths = [
"#nia/config/roles/cli-assistant.yaml",
"./resources/config/roles/cli-assistant.yaml",
"./nia-cli/resources/config/roles/cli-assistant.yaml",
"../nia-cli/resources/config/roles/cli-assistant.yaml",
];
for path in &possible_paths {
if let Ok(resolved_path) = ConfigParser::resolve_path(path) {
if resolved_path.exists() {
match ConfigParser::parse_from_file(&resolved_path) {
Ok(config) => {
match PromptComposer::from_config(config, prompt_context.clone()) {
Ok(composer) => {
tracing::info!(
"Successfully loaded config from: {}",
resolved_path.display()
);
return Ok(composer);
}
Err(e) => {
tracing::warn!(
"Failed to create composer from {}: {}",
resolved_path.display(),
e
);
}
}
}
Err(e) => {
tracing::warn!(
"Failed to parse config from {}: {}",
resolved_path.display(),
e
);
}
}
}
}
}
Err(WorkerError::ConfigurationError(
"No default configuration found".to_string(),
))
}
/// セッション初期化Worker内部用 /// セッション初期化Worker内部用
fn initialize_session(&mut self) -> Result<(), crate::prompt_types::PromptError> { fn initialize_session(&mut self) -> Result<(), crate::prompt::PromptError> {
// 空のメッセージでセッション初期化 // 空のメッセージでセッション初期化
self.composer.initialize_session(&[]) self.composer.initialize_session(&[])
} }
@ -2032,7 +1821,7 @@ impl Worker {
/// 履歴付きセッション再初期化Worker内部用 /// 履歴付きセッション再初期化Worker内部用
fn reinitialize_session_with_history( fn reinitialize_session_with_history(
&mut self, &mut self,
) -> Result<(), crate::prompt_types::PromptError> { ) -> Result<(), crate::prompt::PromptError> {
// 現在の履歴を使ってセッション初期化 // 現在の履歴を使ってセッション初期化
self.composer.initialize_session(&self.message_history) self.composer.initialize_session(&self.message_history)
} }

View File

@ -1,8 +1,7 @@
use crate::{ use crate::core::LlmClientTrait;
LlmClientTrait, WorkerError, use crate::types::WorkerError;
types::{LlmProvider, Message, Role, StreamEvent, ToolCall}, use worker_types::{LlmProvider, Message, Role, StreamEvent, ToolCall};
url_config::UrlConfig, use crate::config::UrlConfig;
};
use async_stream::stream; use async_stream::stream;
use futures_util::{Stream, StreamExt}; use futures_util::{Stream, StreamExt};
use reqwest::Client; use reqwest::Client;

View File

@ -1,8 +1,7 @@
use crate::{ use crate::core::LlmClientTrait;
LlmClientTrait, WorkerError, use crate::types::WorkerError;
types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}, use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall};
url_config::UrlConfig, use crate::config::UrlConfig;
};
use futures_util::{Stream, StreamExt, TryStreamExt}; use futures_util::{Stream, StreamExt, TryStreamExt};
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -708,7 +707,7 @@ impl GeminiClient {
}; };
let stream = stream_events(&self.api_key, &self.model, request, llm_debug) let stream = stream_events(&self.api_key, &self.model, request, llm_debug)
.map_err(|e| WorkerError::LlmApiError(e.to_string())); .map_err(|e| WorkerError::llm_api("gemini", e.to_string()));
Ok(Box::new(Box::pin(stream))) Ok(Box::new(Box::pin(stream)))
} }
@ -805,7 +804,7 @@ impl GeminiClient {
// Simple connection check - try to call the API // Simple connection check - try to call the API
// For now, just return OK if model is not empty // For now, just return OK if model is not empty
if self.model.is_empty() { if self.model.is_empty() {
return Err(WorkerError::ModelNotFound("No model specified".to_string())); return Err(WorkerError::model_not_found("gemini", "No model specified"));
} }
Ok(()) Ok(())
} }

View File

@ -1,8 +1,7 @@
use crate::{ use crate::core::LlmClientTrait;
LlmClientTrait, WorkerError, use crate::types::WorkerError;
types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}, use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall};
url_config::UrlConfig, use crate::config::UrlConfig;
};
use futures_util::{Stream, StreamExt}; use futures_util::{Stream, StreamExt};
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -345,7 +344,7 @@ impl OllamaClient {
.and_then(|models| models.as_array()) .and_then(|models| models.as_array())
.ok_or_else(|| { .ok_or_else(|| {
tracing::error!("Invalid Ollama models response format - missing 'models' array"); tracing::error!("Invalid Ollama models response format - missing 'models' array");
WorkerError::LlmApiError("Invalid models response format".to_string()) WorkerError::llm_api("ollama", "Invalid models response format")
})? })?
.iter() .iter()
.filter_map(|model| { .filter_map(|model| {
@ -671,7 +670,7 @@ impl OllamaClient {
self.add_auth_header(client.get(&url)) self.add_auth_header(client.get(&url))
.send() .send()
.await .await
.map_err(|e| WorkerError::LlmApiError(format!("Failed to connect to Ollama: {}", e)))?; .map_err(|e| WorkerError::llm_api("ollama", format!("Failed to connect to Ollama: {}", e)))?;
Ok(()) Ok(())
} }
} }

View File

@ -1,8 +1,7 @@
use crate::{ use crate::core::LlmClientTrait;
LlmClientTrait, WorkerError, use crate::types::WorkerError;
types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}, use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall};
url_config::UrlConfig, use crate::config::UrlConfig;
};
use futures_util::{Stream, StreamExt}; use futures_util::{Stream, StreamExt};
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};

View File

@ -1,8 +1,7 @@
use crate::{ use crate::core::LlmClientTrait;
LlmClientTrait, WorkerError, use crate::types::WorkerError;
types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}, use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall};
url_config::UrlConfig, use crate::config::UrlConfig;
};
use futures_util::{Stream, StreamExt}; use futures_util::{Stream, StreamExt};
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};

View File

@ -1,5 +1,5 @@
use crate::WorkerError; use crate::types::WorkerError;
use crate::mcp_tool::McpServerConfig; use super::tool::McpServerConfig;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::path::Path; use std::path::Path;
@ -71,14 +71,14 @@ impl McpConfig {
info!("Loading MCP config from: {:?}", path); info!("Loading MCP config from: {:?}", path);
let content = std::fs::read_to_string(path).map_err(|e| { let content = std::fs::read_to_string(path).map_err(|e| {
WorkerError::ConfigurationError(format!( WorkerError::config(format!(
"Failed to read MCP config file {:?}: {}", "Failed to read MCP config file {:?}: {}",
path, e path, e
)) ))
})?; })?;
let config: McpConfig = serde_yaml::from_str(&content).map_err(|e| { let config: McpConfig = serde_yaml::from_str(&content).map_err(|e| {
WorkerError::ConfigurationError(format!( WorkerError::config(format!(
"Failed to parse MCP config file {:?}: {}", "Failed to parse MCP config file {:?}: {}",
path, e path, e
)) ))
@ -95,7 +95,7 @@ impl McpConfig {
// ディレクトリが存在しない場合は作成 // ディレクトリが存在しない場合は作成
if let Some(parent) = path.parent() { if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).map_err(|e| { std::fs::create_dir_all(parent).map_err(|e| {
WorkerError::ConfigurationError(format!( WorkerError::config(format!(
"Failed to create config directory {:?}: {}", "Failed to create config directory {:?}: {}",
parent, e parent, e
)) ))
@ -103,11 +103,11 @@ impl McpConfig {
} }
let content = serde_yaml::to_string(self).map_err(|e| { let content = serde_yaml::to_string(self).map_err(|e| {
WorkerError::ConfigurationError(format!("Failed to serialize MCP config: {}", e)) WorkerError::config(format!("Failed to serialize MCP config: {}", e))
})?; })?;
std::fs::write(path, content).map_err(|e| { std::fs::write(path, content).map_err(|e| {
WorkerError::ConfigurationError(format!( WorkerError::config(format!(
"Failed to write MCP config file {:?}: {}", "Failed to write MCP config file {:?}: {}",
path, e path, e
)) ))
@ -225,7 +225,7 @@ fn expand_environment_variables(input: &str) -> Result<String, WorkerError> {
// ${VAR_NAME} パターンを検索して置換 // ${VAR_NAME} パターンを検索して置換
let re = regex::Regex::new(r"\$\{([^}]+)\}") let re = regex::Regex::new(r"\$\{([^}]+)\}")
.map_err(|e| WorkerError::ConfigurationError(format!("Regex error: {}", e)))?; .map_err(|e| WorkerError::config(format!("Regex error: {}", e)))?;
for caps in re.captures_iter(input) { for caps in re.captures_iter(input) {
let full_match = &caps[0]; let full_match = &caps[0];
@ -300,7 +300,10 @@ servers:
#[test] #[test]
fn test_environment_variable_expansion() { fn test_environment_variable_expansion() {
// SAFETY: Setting test environment variables in a single-threaded test context
unsafe {
std::env::set_var("TEST_VAR", "test_value"); std::env::set_var("TEST_VAR", "test_value");
}
let result = expand_environment_variables("prefix_${TEST_VAR}_suffix").unwrap(); let result = expand_environment_variables("prefix_${TEST_VAR}_suffix").unwrap();
assert_eq!(result, "prefix_test_value_suffix"); assert_eq!(result, "prefix_test_value_suffix");

10
worker/src/mcp/mod.rs Normal file
View File

@ -0,0 +1,10 @@
mod config;
mod protocol;
mod tool;
pub use config::{IntegrationMode, McpConfig, McpServerDefinition};
pub use protocol::McpClient;
pub use tool::{
create_single_mcp_tools, get_mcp_tools_as_definitions, test_mcp_connection, McpDynamicTool,
McpServerConfig, SingleMcpTool,
};

View File

@ -1,4 +1,4 @@
use crate::mcp_protocol::{CallToolResult, McpClient, McpToolDefinition}; use super::protocol::{CallToolResult, McpClient, McpToolDefinition};
use crate::types::{Tool, ToolResult}; use crate::types::{Tool, ToolResult};
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::Value; use serde_json::Value;
@ -92,10 +92,11 @@ impl McpDynamicTool {
.connect(self.config.command.clone(), self.config.args.clone()) .connect(self.config.command.clone(), self.config.args.clone())
.await .await
.map_err(|e| { .map_err(|e| {
crate::WorkerError::ToolExecutionError(format!( crate::WorkerError::tool_execution_with_source(
"Failed to connect to MCP server '{}': {}", &self.config.name,
self.config.name, e format!("Failed to connect to MCP server '{}'", self.config.name),
)) e,
)
})?; })?;
*client_guard = Some(mcp_client); *client_guard = Some(mcp_client);
@ -111,14 +112,15 @@ impl McpDynamicTool {
let mut client_guard = self.client.lock().await; let mut client_guard = self.client.lock().await;
let client = client_guard.as_mut().ok_or_else(|| { let client = client_guard.as_mut().ok_or_else(|| {
crate::WorkerError::ToolExecutionError("MCP client not connected".to_string()) crate::WorkerError::tool_execution(&self.config.name, "MCP client not connected")
})?; })?;
let tools = client.list_tools().await.map_err(|e| { let tools = client.list_tools().await.map_err(|e| {
crate::WorkerError::ToolExecutionError(format!( crate::WorkerError::tool_execution_with_source(
"Failed to list tools from MCP server '{}': {}", &self.config.name,
self.config.name, e format!("Failed to list tools from MCP server '{}'", self.config.name),
)) e,
)
})?; })?;
debug!( debug!(
@ -166,16 +168,17 @@ impl McpDynamicTool {
let mut client_guard = self.client.lock().await; let mut client_guard = self.client.lock().await;
let client = client_guard.as_mut().ok_or_else(|| { let client = client_guard.as_mut().ok_or_else(|| {
crate::WorkerError::ToolExecutionError("MCP client not connected".to_string()) crate::WorkerError::tool_execution(tool_name, "MCP client not connected")
})?; })?;
debug!("Calling MCP tool '{}' with args: {}", tool_name, args); debug!("Calling MCP tool '{}' with args: {}", tool_name, args);
let result = client.call_tool(tool_name, Some(args)).await.map_err(|e| { let result = client.call_tool(tool_name, Some(args)).await.map_err(|e| {
crate::WorkerError::ToolExecutionError(format!( crate::WorkerError::tool_execution_with_source(
"Failed to call MCP tool '{}': {}", tool_name,
tool_name, e format!("Failed to call MCP tool '{}'", tool_name),
)) e,
)
})?; })?;
debug!("MCP tool '{}' returned: {:?}", tool_name, result); debug!("MCP tool '{}' returned: {:?}", tool_name, result);
@ -205,7 +208,7 @@ impl SingleMcpTool {
async fn call_mcp_tool(&self, args: Value) -> ToolResult<Value> { async fn call_mcp_tool(&self, args: Value) -> ToolResult<Value> {
let mut client_guard = self.client.lock().await; let mut client_guard = self.client.lock().await;
let client = client_guard.as_mut().ok_or_else(|| { let client = client_guard.as_mut().ok_or_else(|| {
crate::WorkerError::ToolExecutionError("MCP client not connected".to_string()) crate::WorkerError::tool_execution(&self.tool_name, "MCP client not connected")
})?; })?;
debug!("Calling MCP tool '{}' with args: {}", self.tool_name, args); debug!("Calling MCP tool '{}' with args: {}", self.tool_name, args);
@ -214,10 +217,11 @@ impl SingleMcpTool {
.call_tool(&self.tool_name, Some(args)) .call_tool(&self.tool_name, Some(args))
.await .await
.map_err(|e| { .map_err(|e| {
crate::WorkerError::ToolExecutionError(format!( crate::WorkerError::tool_execution_with_source(
"Failed to call MCP tool '{}': {}", &self.tool_name,
self.tool_name, e format!("Failed to call MCP tool '{}'", self.tool_name),
)) e,
)
})?; })?;
debug!("MCP tool '{}' returned: {:?}", self.tool_name, result); debug!("MCP tool '{}' returned: {:?}", self.tool_name, result);
@ -279,16 +283,18 @@ impl Tool for McpDynamicTool {
.get("tool_name") .get("tool_name")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.ok_or_else(|| { .ok_or_else(|| {
crate::WorkerError::ToolExecutionError( crate::WorkerError::tool_execution(
"Missing required parameter 'tool_name'".to_string(), "mcp_proxy",
"Missing required parameter 'tool_name'",
) )
})?; })?;
let tool_args = args let tool_args = args
.get("tool_args") .get("tool_args")
.ok_or_else(|| { .ok_or_else(|| {
crate::WorkerError::ToolExecutionError( crate::WorkerError::tool_execution(
"Missing required parameter 'tool_args'".to_string(), "mcp_proxy",
"Missing required parameter 'tool_args'",
) )
})? })?
.clone(); .clone();
@ -318,10 +324,10 @@ impl Tool for McpDynamicTool {
"result": result "result": result
})) }))
} }
None => Err(Box::new(crate::WorkerError::ToolExecutionError(format!( None => Err(Box::new(crate::WorkerError::tool_execution(
"Tool '{}' not found in MCP server '{}'", tool_name,
tool_name, self.config.name format!("Tool '{}' not found in MCP server '{}'", tool_name, self.config.name),
))) ))
as Box<dyn std::error::Error + Send + Sync>), as Box<dyn std::error::Error + Send + Sync>),
} }
} }

View File

@ -0,0 +1,241 @@
use async_trait::async_trait;
use futures_util::Stream;
use serde_json::Value;
use std::any::Any;
use std::collections::HashMap;
use crate::plugin::{PluginMetadata, ProviderPlugin};
use crate::core::LlmClientTrait;
use crate::types::WorkerError;
use worker_types::{DynamicToolDefinition, LlmDebug, LlmProvider, Message, Role, StreamEvent};
/// Example custom provider plugin implementation
pub struct CustomProviderPlugin {
initialized: bool,
config: HashMap<String, Value>,
}
impl CustomProviderPlugin {
pub fn new() -> Self {
Self {
initialized: false,
config: HashMap::new(),
}
}
}
#[async_trait]
impl ProviderPlugin for CustomProviderPlugin {
fn metadata(&self) -> PluginMetadata {
PluginMetadata {
id: "custom-provider".to_string(),
name: "Custom LLM Provider".to_string(),
version: "1.0.0".to_string(),
author: "Example Author".to_string(),
description: "An example custom LLM provider plugin".to_string(),
supported_models: vec![
"custom-small".to_string(),
"custom-large".to_string(),
"custom-turbo".to_string(),
],
requires_api_key: true,
config_schema: Some(serde_json::json!({
"type": "object",
"properties": {
"base_url": {
"type": "string",
"description": "Base URL for the API endpoint"
},
"timeout": {
"type": "integer",
"description": "Request timeout in seconds",
"default": 30
},
"max_retries": {
"type": "integer",
"description": "Maximum number of retries",
"default": 3
}
}
})),
}
}
async fn initialize(&mut self, config: HashMap<String, Value>) -> Result<(), WorkerError> {
self.config = config;
self.initialized = true;
tracing::info!("Custom provider plugin initialized");
Ok(())
}
fn create_client(
&self,
model_name: &str,
api_key: Option<&str>,
config: Option<HashMap<String, Value>>,
) -> Result<Box<dyn LlmClientTrait>, WorkerError> {
if !self.initialized {
return Err(WorkerError::config("Plugin not initialized"));
}
let api_key = api_key
.ok_or_else(|| WorkerError::config("API key required for custom provider"))?;
let client = CustomLlmClient::new(
api_key.to_string(),
model_name.to_string(),
config.unwrap_or_else(|| self.config.clone()),
);
Ok(Box::new(client))
}
fn validate_api_key(&self, api_key: &str) -> bool {
// Custom validation logic - in this example, check if key starts with "custom-"
api_key.starts_with("custom-") && api_key.len() > 20
}
async fn list_models(&self, _api_key: Option<&str>) -> Result<Vec<String>, WorkerError> {
// Could make an API call to fetch available models
Ok(self.metadata().supported_models)
}
async fn health_check(&self, api_key: Option<&str>) -> Result<bool, WorkerError> {
// Simple health check - validate API key format
if let Some(key) = api_key {
Ok(self.validate_api_key(key))
} else {
Ok(false)
}
}
fn as_any(&self) -> &dyn Any {
self
}
}
/// Custom LLM client implementation
struct CustomLlmClient {
api_key: String,
model: String,
config: HashMap<String, Value>,
}
impl CustomLlmClient {
fn new(api_key: String, model: String, config: HashMap<String, Value>) -> Self {
Self {
api_key,
model,
config,
}
}
}
#[async_trait]
impl LlmClientTrait for CustomLlmClient {
async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
_tools: Option<&[DynamicToolDefinition]>,
_llm_debug: Option<LlmDebug>,
) -> Result<Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>, WorkerError> {
use async_stream::stream;
// Example implementation that echoes the last user message
let last_user_message = messages
.iter()
.rev()
.find(|m| m.role == Role::User)
.map(|m| m.content.clone())
.unwrap_or_else(|| "Hello from custom provider!".to_string());
let response = format!(
"This is a response from the custom provider using model '{}'. You said: {}",
self.model, last_user_message
);
// Create a stream that yields the response in chunks
let stream = stream! {
// Simulate streaming response
for word in response.split_whitespace() {
yield Ok(StreamEvent::Chunk(format!("{} ", word)));
// Small delay to simulate streaming
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
}
// Send completion event with the full message
let completion_message = Message {
role: Role::Model,
content: response.clone(),
tool_calls: None,
tool_call_id: None,
metadata: None,
timestamp: Some(chrono::Utc::now()),
};
yield Ok(StreamEvent::Completion(completion_message));
};
Ok(Box::new(Box::pin(stream)))
}
async fn check_connection(&self) -> Result<(), WorkerError> {
// Simple check - could make an actual API call in a real implementation
if self.api_key.is_empty() {
Err(WorkerError::config("API key is required"))
} else {
Ok(())
}
}
fn provider(&self) -> LlmProvider {
// Return a default provider type for plugins
// In a real implementation, you might extend LlmProvider enum
LlmProvider::OpenAI
}
fn get_model_name(&self) -> String {
self.model.clone()
}
}
/// Factory function for dynamic loading
#[unsafe(no_mangle)]
pub extern "C" fn create_plugin() -> Box<dyn ProviderPlugin> {
Box::new(CustomProviderPlugin::new())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_plugin_metadata() {
let plugin = CustomProviderPlugin::new();
let metadata = plugin.metadata();
assert_eq!(metadata.id, "custom-provider");
assert_eq!(metadata.name, "Custom LLM Provider");
assert!(metadata.requires_api_key);
assert_eq!(metadata.supported_models.len(), 3);
}
#[test]
fn test_api_key_validation() {
let plugin = CustomProviderPlugin::new();
assert!(plugin.validate_api_key("custom-1234567890abcdefghij"));
assert!(!plugin.validate_api_key("invalid-key"));
assert!(!plugin.validate_api_key("custom-short"));
assert!(!plugin.validate_api_key(""));
}
#[tokio::test]
async fn test_plugin_initialization() {
let mut plugin = CustomProviderPlugin::new();
let mut config = HashMap::new();
config.insert("base_url".to_string(), Value::String("https://api.example.com".to_string()));
let result = plugin.initialize(config).await;
assert!(result.is_ok());
}
}

219
worker/src/plugin/mod.rs Normal file
View File

@ -0,0 +1,219 @@
pub mod example_provider;
use async_trait::async_trait;
use futures_util::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use crate::core::LlmClientTrait;
use crate::types::WorkerError;
use worker_types::{DynamicToolDefinition, LlmDebug, LlmProvider, Message, StreamEvent};
/// Plugin metadata for provider identification and configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginMetadata {
/// Unique identifier for the plugin
pub id: String,
/// Display name for the plugin
pub name: String,
/// Plugin version
pub version: String,
/// Plugin author
pub author: String,
/// Plugin description
pub description: String,
/// Supported models by this plugin
pub supported_models: Vec<String>,
/// Whether this plugin requires an API key
pub requires_api_key: bool,
/// Custom configuration schema (JSON Schema)
pub config_schema: Option<Value>,
}
/// Main plugin trait that all provider plugins must implement
#[async_trait]
pub trait ProviderPlugin: Send + Sync {
/// Get plugin metadata
fn metadata(&self) -> PluginMetadata;
/// Initialize the plugin with configuration
async fn initialize(&mut self, config: HashMap<String, Value>) -> Result<(), WorkerError>;
/// Create a new LLM client instance
fn create_client(
&self,
model_name: &str,
api_key: Option<&str>,
config: Option<HashMap<String, Value>>,
) -> Result<Box<dyn LlmClientTrait>, WorkerError>;
/// Validate API key format (if applicable)
fn validate_api_key(&self, api_key: &str) -> bool {
if !self.metadata().requires_api_key {
return true;
}
!api_key.is_empty()
}
/// Get available models from the provider
async fn list_models(&self, _api_key: Option<&str>) -> Result<Vec<String>, WorkerError> {
Ok(self.metadata().supported_models.clone())
}
/// Health check for the provider
async fn health_check(&self, _api_key: Option<&str>) -> Result<bool, WorkerError> {
Ok(true)
}
/// Return as Any for downcasting if needed
fn as_any(&self) -> &dyn Any;
}
/// Plugin loader and registry
pub struct PluginRegistry {
plugins: HashMap<String, Arc<dyn ProviderPlugin>>,
}
impl PluginRegistry {
pub fn new() -> Self {
Self {
plugins: HashMap::new(),
}
}
/// Register a new plugin
pub fn register(&mut self, plugin: Arc<dyn ProviderPlugin>) -> Result<(), WorkerError> {
let metadata = plugin.metadata();
if self.plugins.contains_key(&metadata.id) {
return Err(WorkerError::config(format!(
"Plugin with id '{}' already registered",
metadata.id
)));
}
self.plugins.insert(metadata.id.clone(), plugin);
Ok(())
}
/// Get a plugin by ID
pub fn get(&self, id: &str) -> Option<Arc<dyn ProviderPlugin>> {
self.plugins.get(id).cloned()
}
/// List all registered plugins
pub fn list(&self) -> Vec<PluginMetadata> {
self.plugins.values().map(|p| p.metadata()).collect()
}
/// Find plugin by model name
pub fn find_by_model(&self, model_name: &str) -> Option<Arc<dyn ProviderPlugin>> {
self.plugins.values().find(|p| {
p.metadata().supported_models.iter().any(|m| m == model_name)
}).cloned()
}
/// Unregister a plugin
pub fn unregister(&mut self, id: &str) -> Option<Arc<dyn ProviderPlugin>> {
self.plugins.remove(id)
}
}
/// Dynamic plugin loader for loading plugins at runtime
pub struct PluginLoader;
impl PluginLoader {
/// Load a plugin from a shared library (.so/.dll/.dylib)
#[cfg(feature = "dynamic-loading")]
pub fn load_dynamic(path: &std::path::Path) -> Result<Box<dyn ProviderPlugin>, WorkerError> {
use libloading::{Library, Symbol};
unsafe {
let lib = Library::new(path)
.map_err(|e| WorkerError::config(format!("Failed to load plugin: {}", e)))?;
let create_plugin: Symbol<fn() -> Box<dyn ProviderPlugin>> = lib
.get(b"create_plugin")
.map_err(|e| WorkerError::config(format!("Plugin missing create_plugin function: {}", e)))?;
Ok(create_plugin())
}
}
/// Load all plugins from a directory
#[cfg(feature = "dynamic-loading")]
pub fn load_from_directory(dir: &std::path::Path) -> Result<Vec<Box<dyn ProviderPlugin>>, WorkerError> {
use std::fs;
let mut plugins = Vec::new();
if !dir.is_dir() {
return Err(WorkerError::config(format!("Plugin directory does not exist: {:?}", dir)));
}
for entry in fs::read_dir(dir)
.map_err(|e| WorkerError::Config(format!("Failed to read plugin directory: {}", e)))?
{
let entry = entry
.map_err(|e| WorkerError::config(format!("Failed to read directory entry: {}", e)))?;
let path = entry.path();
// Check for plugin files (.so on Linux, .dll on Windows, .dylib on macOS)
if path.extension().and_then(|s| s.to_str()).map_or(false, |ext| {
ext == "so" || ext == "dll" || ext == "dylib"
}) {
match Self::load_dynamic(&path) {
Ok(plugin) => plugins.push(plugin),
Err(e) => {
tracing::warn!("Failed to load plugin from {:?}: {}", path, e);
}
}
}
}
Ok(plugins)
}
}
/// Wrapper to adapt plugin-based clients to LlmClientTrait
pub struct PluginClient {
plugin: Arc<dyn ProviderPlugin>,
inner: Box<dyn LlmClientTrait>,
}
impl PluginClient {
pub fn new(
plugin: Arc<dyn ProviderPlugin>,
model_name: &str,
api_key: Option<&str>,
config: Option<HashMap<String, Value>>,
) -> Result<Self, WorkerError> {
let inner = plugin.create_client(model_name, api_key, config)?;
Ok(Self { plugin, inner })
}
}
#[async_trait]
impl LlmClientTrait for PluginClient {
async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[DynamicToolDefinition]>,
llm_debug: Option<LlmDebug>,
) -> Result<Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>, WorkerError> {
self.inner.chat_stream(messages, tools, llm_debug).await
}
async fn check_connection(&self) -> Result<(), WorkerError> {
self.inner.check_connection().await
}
fn provider(&self) -> LlmProvider {
self.inner.provider()
}
fn get_model_name(&self) -> String {
self.inner.get_model_name()
}
}

View File

@ -1,14 +1,16 @@
use crate::config_parser::ConfigParser; use crate::config::ConfigParser;
use crate::prompt_types::*; use super::types::*;
use crate::types::{Message, Role};
use handlebars::{Context, Handlebars, Helper, HelperResult, Output, RenderContext}; use handlebars::{Context, Handlebars, Helper, HelperResult, Output, RenderContext};
use std::fs; use std::fs;
use std::path::Path; use std::path::Path;
// Import Message and Role enum from worker_types
use worker_types::{Message, Role as MessageRole};
/// プロンプト構築システム /// プロンプト構築システム
#[derive(Clone)] #[derive(Clone)]
pub struct PromptComposer { pub struct PromptComposer {
config: PromptRoleConfig, config: Role,
handlebars: Handlebars<'static>, handlebars: Handlebars<'static>,
context: PromptContext, context: PromptContext,
system_prompt: Option<String>, system_prompt: Option<String>,
@ -26,7 +28,7 @@ impl PromptComposer {
/// 設定オブジェクトから新しいインスタンスを作成 /// 設定オブジェクトから新しいインスタンスを作成
pub fn from_config( pub fn from_config(
config: PromptRoleConfig, config: Role,
context: PromptContext, context: PromptContext,
) -> Result<Self, PromptError> { ) -> Result<Self, PromptError> {
let mut handlebars = Handlebars::new(); let mut handlebars = Handlebars::new();
@ -58,11 +60,11 @@ impl PromptComposer {
pub fn compose(&self, messages: &[Message]) -> Result<Vec<Message>, PromptError> { pub fn compose(&self, messages: &[Message]) -> Result<Vec<Message>, PromptError> {
if let Some(system_prompt) = &self.system_prompt { if let Some(system_prompt) = &self.system_prompt {
// システムプロンプトが既に構築済みの場合、それを使用 // システムプロンプトが既に構築済みの場合、それを使用
let mut result_messages = vec![Message::new(Role::System, system_prompt.clone())]; let mut result_messages = vec![Message::new(MessageRole::System, system_prompt.clone())];
// ユーザーメッセージを追加 // ユーザーメッセージを追加
for msg in messages { for msg in messages {
if msg.role != Role::System { if msg.role != MessageRole::System {
result_messages.push(msg.clone()); result_messages.push(msg.clone());
} }
} }
@ -100,11 +102,11 @@ impl PromptComposer {
) -> Result<Vec<Message>, PromptError> { ) -> Result<Vec<Message>, PromptError> {
if let Some(system_prompt) = &self.system_prompt { if let Some(system_prompt) = &self.system_prompt {
// システムプロンプトが既に構築済みの場合、それを使用 // システムプロンプトが既に構築済みの場合、それを使用
let mut result_messages = vec![Message::new(Role::System, system_prompt.clone())]; let mut result_messages = vec![Message::new(MessageRole::System, system_prompt.clone())];
// ユーザーメッセージを追加 // ユーザーメッセージを追加
for msg in messages { for msg in messages {
if msg.role != Role::System { if msg.role != MessageRole::System {
result_messages.push(msg.clone()); result_messages.push(msg.clone());
} }
} }
@ -156,11 +158,11 @@ impl PromptComposer {
let system_prompt = self.compose_system_prompt_with_context(messages, context)?; let system_prompt = self.compose_system_prompt_with_context(messages, context)?;
// システムメッセージとユーザーメッセージを結合 // システムメッセージとユーザーメッセージを結合
let mut result_messages = vec![Message::new(Role::System, system_prompt)]; let mut result_messages = vec![Message::new(MessageRole::System, system_prompt)];
// ユーザーメッセージを追加 // ユーザーメッセージを追加
for msg in messages { for msg in messages {
if msg.role != Role::System { if msg.role != MessageRole::System {
result_messages.push(msg.clone()); result_messages.push(msg.clone());
} }
} }
@ -223,7 +225,7 @@ impl PromptComposer {
) -> Result<serde_json::Value, PromptError> { ) -> Result<serde_json::Value, PromptError> {
let user_input = messages let user_input = messages
.iter() .iter()
.filter(|m| m.role == Role::User) .filter(|m| m.role == MessageRole::User)
.map(|m| m.content.as_str()) .map(|m| m.content.as_str())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join("\n\n"); .join("\n\n");

8
worker/src/prompt/mod.rs Normal file
View File

@ -0,0 +1,8 @@
mod composer;
mod types;
pub use composer::PromptComposer;
pub use types::{
ConditionConfig, GitInfo, ModelCapabilities, ModelContext, PartialConfig, PromptContext,
PromptError, ProjectType, Role, SessionContext, SystemInfo, WorkspaceContext,
};

View File

@ -2,9 +2,9 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::path::PathBuf; use std::path::PathBuf;
/// ロール設定ファイルの型定義 /// Role configuration - defines the system instructions for the LLM
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptRoleConfig { pub struct Role {
pub name: String, pub name: String,
pub description: String, pub description: String,
pub version: Option<String>, pub version: Option<String>,
@ -352,24 +352,15 @@ impl Default for SessionContext {
} }
} }
impl Default for PromptRoleConfig { impl Role {
fn default() -> Self { /// Create a new Role with name, description, and template
let mut partials = HashMap::new(); pub fn new(name: impl Into<String>, description: impl Into<String>, template: impl Into<String>) -> Self {
partials.insert(
"role_definition".to_string(),
PartialConfig {
path: "./resources/prompts/cli-assistant.md".to_string(),
fallback: None,
description: Some("Default role definition".to_string()),
},
);
Self { Self {
name: "default".to_string(), name: name.into(),
description: "Default dynamic role configuration".to_string(), description: description.into(),
version: Some("1.0.0".to_string()), version: Some("1.0.0".to_string()),
template: "{{>role_definition}}".to_string(), template: template.into(),
partials: Some(partials), partials: None,
variables: None, variables: None,
conditions: None, conditions: None,
} }

View File

@ -1,4 +1,4 @@
use crate::config_parser::ConfigParser; use crate::config::ConfigParser;
use std::io::Write; use std::io::Write;
use tempfile::NamedTempFile; use tempfile::NamedTempFile;

View File

@ -1,48 +1,183 @@
// Re-export all types from worker-types for backwards compatibility // Re-export all types from worker-types for backwards compatibility
pub use worker_types::*; pub use worker_types::*;
// Worker-specific error type // Worker-specific error type with structured information
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum WorkerError { pub enum WorkerError {
#[error("Tool execution failed: {0}")] /// Tool execution failed
ToolExecution(String), #[error("Tool execution failed: {tool_name} - {reason}")]
#[error("Tool execution error: {0}")] ToolExecutionError {
ToolExecutionError(String), tool_name: String,
#[error("LLM API error: {0}")] reason: String,
LlmApiError(String), #[source]
#[error("Model not found: {0}")] source: Option<Box<dyn std::error::Error + Send + Sync>>,
ModelNotFound(String), },
#[error("JSON serialization/deserialization error: {0}")]
/// LLM API error with provider context
#[error("LLM API error ({provider}): {message}")]
LlmApiError {
provider: String,
message: String,
status_code: Option<u16>,
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync>>,
},
/// Model not found for the specified provider
#[error("Model not found: {model_name} for provider {provider}")]
ModelNotFound { provider: String, model_name: String },
/// JSON serialization/deserialization error
#[error("JSON error: {0}")]
JsonError(#[from] serde_json::Error), JsonError(#[from] serde_json::Error),
#[error("Serialization error: {0}")]
Serialization(serde_json::Error), /// Network communication error
#[error("Network error: {0}")] #[error("Network error: {message}")]
Network(String), Network {
#[error("Configuration error: {0}")] message: String,
Config(String), #[source]
#[error("Configuration error: {0}")] source: Option<Box<dyn std::error::Error + Send + Sync>>,
ConfigurationError(String), },
#[error("General error: {0}")]
General(#[from] anyhow::Error), /// Configuration error with optional context
#[error("Box error: {0}")] #[error("Configuration error: {message}")]
BoxError(Box<dyn std::error::Error + Send + Sync>), ConfigurationError {
message: String,
context: Option<String>,
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync>>,
},
/// Other errors that don't fit specific categories
#[error("{0}")]
Other(String),
} }
impl From<&str> for WorkerError { impl WorkerError {
fn from(s: &str) -> Self { /// Create a tool execution error
WorkerError::General(anyhow::anyhow!(s.to_string())) pub fn tool_execution(tool_name: impl Into<String>, reason: impl Into<String>) -> Self {
Self::ToolExecutionError {
tool_name: tool_name.into(),
reason: reason.into(),
source: None,
}
}
/// Create a tool execution error with source
pub fn tool_execution_with_source(
tool_name: impl Into<String>,
reason: impl Into<String>,
source: Box<dyn std::error::Error + Send + Sync>,
) -> Self {
Self::ToolExecutionError {
tool_name: tool_name.into(),
reason: reason.into(),
source: Some(source),
}
}
/// Create an LLM API error
pub fn llm_api(provider: impl Into<String>, message: impl Into<String>) -> Self {
Self::LlmApiError {
provider: provider.into(),
message: message.into(),
status_code: None,
source: None,
}
}
/// Create an LLM API error with status code and source
pub fn llm_api_with_details(
provider: impl Into<String>,
message: impl Into<String>,
status_code: Option<u16>,
source: Option<Box<dyn std::error::Error + Send + Sync>>,
) -> Self {
Self::LlmApiError {
provider: provider.into(),
message: message.into(),
status_code,
source,
}
}
/// Create a model not found error
pub fn model_not_found(provider: impl Into<String>, model_name: impl Into<String>) -> Self {
Self::ModelNotFound {
provider: provider.into(),
model_name: model_name.into(),
}
}
/// Create a network error
pub fn network(message: impl Into<String>) -> Self {
Self::Network {
message: message.into(),
source: None,
}
}
/// Create a network error with source
pub fn network_with_source(
message: impl Into<String>,
source: Box<dyn std::error::Error + Send + Sync>,
) -> Self {
Self::Network {
message: message.into(),
source: Some(source),
}
}
/// Create a configuration error
pub fn config(message: impl Into<String>) -> Self {
Self::ConfigurationError {
message: message.into(),
context: None,
source: None,
}
}
/// Create a configuration error with context
pub fn config_with_context(
message: impl Into<String>,
context: impl Into<String>,
) -> Self {
Self::ConfigurationError {
message: message.into(),
context: Some(context.into()),
source: None,
}
}
/// Create a configuration error with source
pub fn config_with_source(
message: impl Into<String>,
source: Box<dyn std::error::Error + Send + Sync>,
) -> Self {
Self::ConfigurationError {
message: message.into(),
context: None,
source: Some(source),
}
} }
} }
impl From<String> for WorkerError { // Explicit conversion from common error types
fn from(s: String) -> Self { impl From<reqwest::Error> for WorkerError {
WorkerError::General(anyhow::anyhow!(s)) fn from(e: reqwest::Error) -> Self {
Self::network_with_source("HTTP request failed", Box::new(e))
}
}
impl From<std::io::Error> for WorkerError {
fn from(e: std::io::Error) -> Self {
Self::config_with_source("I/O error", Box::new(e))
} }
} }
impl From<Box<dyn std::error::Error + Send + Sync>> for WorkerError { impl From<Box<dyn std::error::Error + Send + Sync>> for WorkerError {
fn from(e: Box<dyn std::error::Error + Send + Sync>) -> Self { fn from(e: Box<dyn std::error::Error + Send + Sync>) -> Self {
WorkerError::BoxError(e) Self::Other(format!("Error: {}", e))
} }
} }

View File

@ -1,4 +1,4 @@
use crate::prompt_types::*; use crate::prompt::{WorkspaceContext, PromptError, ProjectType, GitInfo};
use std::fs; use std::fs;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::process::Command; use std::process::Command;
@ -34,7 +34,7 @@ impl WorkspaceDetector {
let project_name = Self::determine_project_name(&root_path, &git_info); let project_name = Self::determine_project_name(&root_path, &git_info);
// 6. システム情報を生成 // 6. システム情報を生成
let system_info = crate::prompt_types::SystemInfo::default(); let system_info = crate::prompt::SystemInfo::default();
Ok(WorkspaceContext { Ok(WorkspaceContext {
root_path, root_path,

View File

@ -0,0 +1,3 @@
mod detector;
pub use detector::WorkspaceDetector;