From 02667f5396493cc42dede516123d2db30bff6b3a Mon Sep 17 00:00:00 2001 From: Hare Date: Fri, 24 Oct 2025 07:37:47 +0900 Subject: [PATCH] =?UTF-8?q?0.3.0:=20=E3=83=86=E3=83=B3=E3=83=97=E3=83=AC?= =?UTF-8?q?=E3=83=BC=E3=83=88=E3=82=A8=E3=83=B3=E3=82=B8=E3=83=B3=E3=81=AE?= =?UTF-8?q?=E3=83=95=E3=82=A1=E3=82=A4=E3=83=AB=E3=81=AE=E3=83=AD=E3=83=BC?= =?UTF-8?q?=E3=83=89=E3=82=92=E5=A4=96=E9=83=A8=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 4 +- README.md | 442 ++++---------------------- docs/patch_note/v0.3.0.md | 33 ++ docs/prompt-composer.md | 39 ++- worker/Cargo.toml | 2 +- worker/examples/builder_basic.rs | 13 +- worker/examples/plugin_usage.rs | 33 +- worker/src/builder.rs | 77 ++++- worker/src/client.rs | 2 +- worker/src/config/parser.rs | 56 ---- worker/src/core.rs | 2 +- worker/src/lib.rs | 108 ++++--- worker/src/llm/anthropic.rs | 4 +- worker/src/llm/gemini.rs | 4 +- worker/src/llm/ollama.rs | 8 +- worker/src/llm/openai.rs | 4 +- worker/src/llm/xai.rs | 4 +- worker/src/mcp/config.rs | 22 +- worker/src/mcp/mod.rs | 4 +- worker/src/mcp/tool.rs | 10 +- worker/src/plugin/example_provider.rs | 34 +- worker/src/plugin/mod.rs | 66 ++-- worker/src/prompt/composer.rs | 111 ++++--- worker/src/prompt/mod.rs | 4 +- worker/src/prompt/types.rs | 11 +- worker/src/tests/config_tests.rs | 122 +++++-- worker/src/types.rs | 10 +- worker/src/workspace/detector.rs | 2 +- 28 files changed, 569 insertions(+), 662 deletions(-) create mode 100644 docs/patch_note/v0.3.0.md diff --git a/Cargo.lock b/Cargo.lock index 4aa809e..f79a2db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2212,7 +2212,7 @@ checksum = "052283831dbae3d879dc7f51f3d92703a316ca49f91540417d38591826127814" [[package]] name = "worker" -version = "0.1.0" +version = "0.3.0" dependencies = [ "anyhow", "async-stream", @@ -2249,7 +2249,7 @@ dependencies = [ [[package]] name = "worker-macros" -version = "0.1.0" +version = "0.2.1" dependencies = [ "proc-macro2", "quote", diff --git a/README.md b/README.md index 2b6fcb1..2c3f7f3 100644 --- a/README.md +++ b/README.md @@ -1,399 +1,97 @@ # `worker` -`worker` クレートは、大規模言語モデル (LLM) を利用したアプリケーションのバックエンド機能を提供するクレートです。LLM プロバイダーの抽象化、ツール利用、柔軟なプロンプト管理、フックシステムなど、高度な機能をカプセル化し、アプリケーション開発を簡素化します。 +`worker` は、複数の LLM プロバイダーを横断して扱える統合ワーカーです。モデル呼び出し、ツール実行、プロンプト構築、フック連携など、対話アプリに必要な機能を 1 つの API にまとめます。 -## 主な機能 +## 特徴 +- 主要 LLM プロバイダーを単一のインターフェースで利用 +- プロンプト/パーシャル読み込みを利用者実装の `ResourceLoader` へ委譲 +- ツール連携とストリーミング応答、フックによるカスタマイズに対応 -- **マルチプロバイダー対応**: Gemini, Claude, OpenAI, Ollama, XAI など、複数の LLM プロバイダーを統一されたインターフェースで利用できます。 -- **プラグインシステム**: カスタムプロバイダーをプラグインとして動的に追加できます。独自の LLM API や実験的なプロバイダーをサポートします。 -- **ツール利用 (Function Calling)**: LLM が外部ツールを呼び出す機能をサポートします。独自のツールをマクロを用いて定義し、`Worker` に登録できます。 -- **ストリーミング処理**: LLM の応答やツール実行結果を `StreamEvent` として非同期に受け取ることができます。これにより、リアルタイムな UI 更新が可能になります。 -- **フックシステム**: `Worker` の処理フローの特定のタイミング(例: メッセージ送信前、ツール使用後)にカスタムロジックを介入させることができます。 -- **セッション管理**: 会話履歴やワークスペースの状態を管理し、永続化する機能を提供します。 -- **柔軟なプロンプト管理**: 設定ファイルを用いて、ロールやコンテキストに応じたシステムプロンプトを動的に構築します。 +## 利用手順 -## 主な概念 +`worker` のプロンプト/パーシャル解決は利用者側に委譲されています。以下の流れで組み込みます。 -### `Worker` - -このクレートの中心的な構造体です。LLM との対話、ツールの登録と実行、セッション管理など、すべての主要な機能を担当します。 - -### `LlmProvider` - -サポートしている LLM プロバイダー(`Gemini`, `Claude`, `OpenAI` など)を表す enum です。 - -### `Tool` トレイト - -`Worker` が利用できるツールを定義するためのインターフェースです。このトレイトを実装することで、任意の機能をツールとして `Worker` に追加できます。 - -```rust -pub trait Tool: Send + Sync { - fn name(&self) -> &str; - fn description(&self) -> &str; - fn parameters_schema(&self) -> serde_json::Value; - async fn execute(&self, args: serde_json::Value) -> ToolResult; -} -``` - -### `WorkerHook` トレイト - -`Worker` のライフサイクルイベントに介入するためのフックを定義するインターフェースです。特定のイベント(例: `OnMessageSend`, `PostToolUse`)に対して処理を追加できます。 - -### `StreamEvent` - -`Worker` の処理結果を非同期ストリームで受け取るための enum です。LLM の応答チャンク、ツール呼び出し、エラーなど、さまざまなイベントを表します。 - -## アプリケーションへの組み込み方法 - -### 1. Worker の初期化 - -Builder patternを使用してWorkerを作成します。 - -```rust -use worker::{Worker, LlmProvider, Role}; -use std::collections::HashMap; - -// ロールを定義(必須) -let role = Role::new( - "assistant", - "AI Assistant", - "You are a helpful AI assistant." -); - -// APIキーを準備 -let mut api_keys = HashMap::new(); -api_keys.insert("openai".to_string(), "your_openai_api_key".to_string()); -api_keys.insert("claude".to_string(), "your_claude_api_key".to_string()); - -// Workerを作成(builder pattern) -let mut worker = Worker::builder() - .provider(LlmProvider::OpenAI) - .model("gpt-4o") - .api_keys(api_keys) - .role(role) - .build() - .expect("Workerの作成に失敗しました"); - -// または、個別にAPIキーを設定 -let worker = Worker::builder() - .provider(LlmProvider::Claude) - .model("claude-3-sonnet-20240229") - .api_key("claude", "sk-ant-...") - .role(role) - .build()?; -``` - -### 2. ツールの定義と登録 - -`Tool` トレイトを実装してカスタムツールを作成し、`Worker` に登録します。 - -```rust -use worker::{Tool, ToolResult}; -use worker::schemars::{self, JsonSchema}; -use worker::serde_json::{self, json, Value}; -use async_trait::async_trait; - -// ツールの引数を定義 -#[derive(Debug, serde::Deserialize, JsonSchema)] -struct FileSystemToolArgs { - path: String, -} - -// カスタムツールを定義 -struct ListFilesTool; - -#[async_trait] -impl Tool for ListFilesTool { - fn name(&self) -> &str { "list_files" } - fn description(&self) -> &str { "指定されたパスのファイル一覧を表示します" } - - fn parameters_schema(&self) -> Value { - serde_json::to_value(schemars::schema_for!(FileSystemToolArgs)).unwrap() - } - - async fn execute(&self, args: Value) -> ToolResult { - let tool_args: FileSystemToolArgs = serde_json::from_value(args)?; - // ここで実際のファイル一覧取得処理を実装 - let files = vec!["file1.txt", "file2.txt"]; - Ok(json!({ "files": files })) - } -} - -// 作成したツールをWorkerに登録 -worker.register_tool(Box::new(ListFilesTool)).unwrap(); -``` - -#### マクロを使ったツール定義(推奨) - -`worker-macros` クレートの `#[tool]` マクロを使用すると、ツールの定義がより簡潔になります: - -```rust -use worker_macros::tool; -use worker::ToolResult; -use serde::{Deserialize, Serialize}; -use schemars::JsonSchema; - -#[derive(Debug, Deserialize, Serialize, JsonSchema)] -struct ListFilesArgs { - path: String, -} - -#[tool] -async fn list_files(args: ListFilesArgs) -> ToolResult { - // ファイル一覧取得処理 - let files = vec!["file1.txt", "file2.txt"]; - Ok(serde_json::json!({ "files": files })) -} - -// マクロで生成されたツールを登録 -worker.register_tool(Box::new(ListFilesTool))?; -``` - -### 3. 対話処理の実行 - -`process_task_with_history` メソッドを呼び出して、ユーザーメッセージを処理します。このメソッドはイベントのストリームを返します。 +1. `ResourceLoader` を実装して、テンプレートやパーシャルが参照する識別子から文字列を返す。 +2. `Worker::builder()` にプロバイダー・モデル・ロールと合わせて `resource_loader` を渡し、`Worker` を生成。 +3. セッションを初期化し、`process_task_with_history` などの API でイベントストリームを処理。 ```rust use futures_util::StreamExt; +use worker::{LlmProvider, PromptError, ResourceLoader, Role, Worker}; -let user_message = "カレントディレクトリのファイルを教えて".to_string(); +struct FsLoader; -let mut stream = worker.process_task_with_history(user_message, None).await; - -while let Some(event_result) = stream.next().await { - match event_result { - Ok(event) => { - // StreamEventに応じた処理 - match event { - worker::StreamEvent::Chunk(chunk) => { - print!("{}", chunk); - } - worker::StreamEvent::ToolCall(tool_call) => { - println!("\n[Tool Call: {} with args {}]", tool_call.name, tool_call.arguments); - } - worker::StreamEvent::ToolResult { tool_name, result } => { - println!("\n[Tool Result: {} -> {:?}]", tool_name, result); - } - _ => {} - } - } - Err(e) => { - eprintln!("\n[Error: {}]", e); - break; - } - } -} -``` - -### 4. (オプション) フックの登録 - -`WorkerHook` トレイトを実装してカスタムフックを作成し、`Worker` に登録することで、処理フローをカスタマイズできます。 - -#### 手動実装 - -```rust -use worker::{WorkerHook, HookContext, HookResult}; -use async_trait::async_trait; - -struct LoggingHook; - -#[async_trait] -impl WorkerHook for LoggingHook { - fn name(&self) -> &str { "logging_hook" } - fn hook_type(&self) -> &str { "OnMessageSend" } - fn matcher(&self) -> &str { "" } - - async fn execute(&self, context: HookContext) -> (HookContext, HookResult) { - println!("User message: {}", context.content); - (context, HookResult::Continue) +impl ResourceLoader for FsLoader { + fn load(&self, id: &str) -> Result { + std::fs::read_to_string(id) + .map_err(|e| PromptError::FileNotFound(format!("{}: {}", id, e))) } } -// フックを登録 -worker.register_hook(Box::new(LoggingHook)); +#[tokio::main] +async fn main() -> Result<(), Box> { + let role = Role::new("assistant", "Helper", "You are a helpful assistant."); + + let mut worker = Worker::builder() + .provider(LlmProvider::Claude) + .model("claude-3-sonnet-20240229") + .resource_loader(FsLoader) + .role(role) + .build()?; + + worker.initialize_session()?; + let events = worker + .process_task_with_history("こんにちは!".into(), None) + .await; + + futures_util::pin_mut!(events); + while let Some(event) = events.next().await { + println!("{event:?}"); + } + Ok(()) +} ``` -#### マクロを使ったフック定義(推奨) +### ツールを登録する -`worker-macros` クレートの `#[hook]` マクロを使用すると、フックの定義がより簡潔になります: +`#[worker::tool]` マクロで関数を装飾すると、`Tool` 実装を自動生成できます。 + +```rust +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use worker::ToolResult; + +#[derive(Debug, Deserialize, Serialize, JsonSchema)] +struct EchoArgs { + text: String, +} + +#[worker::tool(name = "echo")] +async fn echo(args: EchoArgs) -> ToolResult { + Ok(args.text) +} + +worker.register_tool(Box::new(EchoTool::new()))?; +``` + +### フックを登録する + +`#[worker::hook]` マクロで非同期関数を装飾すると、`WorkerHook` 実装が生成されます。 ```rust -use worker_macros::hook; use worker::{HookContext, HookResult}; -#[hook(OnMessageSend)] -async fn logging_hook(context: HookContext) -> (HookContext, HookResult) { - println!("User message: {}", context.content); +#[worker::hook(hook_type = "OnMessageSend")] +async fn log(context: HookContext) -> (HookContext, HookResult) { + println!("sending: {}", context.content); (context, HookResult::Continue) } -// マクロで生成されたフックを登録 -worker.register_hook(Box::new(LoggingHook)); +worker.register_hook(Box::new(LogHook::new())); ``` -**利用可能なフックタイプ:** -- `OnMessageSend`: ユーザーメッセージ送信前 -- `PreToolUse`: ツール実行前 -- `PostToolUse`: ツール実行後 -- `OnTurnCompleted`: ターン完了時 +## ドキュメント -これで、アプリケーションの要件に応じて `Worker` を中心とした強力な LLM 連携機能を構築できます。 - -## サンプルコード - -完全な動作例は `worker/examples/` ディレクトリを参照してください: - -- [`builder_basic.rs`](worker/examples/builder_basic.rs) - Builder patternの基本的な使用方法 -- [`plugin_usage.rs`](worker/examples/plugin_usage.rs) - プラグインシステムの使用方法 - -## プラグインシステム - -### プラグインの作成 - -`ProviderPlugin` トレイトを実装してカスタムプロバイダーを作成できます: - -```rust -use worker::plugin::{ProviderPlugin, PluginMetadata}; -use async_trait::async_trait; - -pub struct MyCustomProvider { - // プロバイダーの状態 -} - -#[async_trait] -impl ProviderPlugin for MyCustomProvider { - fn metadata(&self) -> PluginMetadata { - PluginMetadata { - id: "my-provider".to_string(), - name: "My Custom Provider".to_string(), - version: "1.0.0".to_string(), - author: "Your Name".to_string(), - description: "カスタムプロバイダーの説明".to_string(), - supported_models: vec!["model-1".to_string()], - requires_api_key: true, - config_schema: None, - } - } - - async fn initialize(&mut self, config: HashMap) -> Result<(), WorkerError> { - // プロバイダーの初期化 - Ok(()) - } - - fn create_client( - &self, - model_name: &str, - api_key: Option<&str>, - config: Option>, - ) -> Result, WorkerError> { - // LLMクライアントを作成して返す - } - - fn as_any(&self) -> &dyn Any { - self - } -} -``` - -### プラグインの使用 - -```rust -use worker::{Worker, Role, plugin::PluginRegistry}; -use std::sync::{Arc, Mutex}; - -// プラグインレジストリを作成 -let plugin_registry = Arc::new(Mutex::new(PluginRegistry::new())); - -// プラグインを作成して登録 -let my_plugin = Arc::new(MyCustomProvider::new()); -{ - let mut registry = plugin_registry.lock().unwrap(); - registry.register(my_plugin)?; -} - -// ロールを定義 -let role = Role::new( - "assistant", - "AI Assistant", - "You are a helpful AI assistant." -); - -let worker = Worker::builder() - .plugin("my-provider", plugin_registry.clone()) - .model("model-1") - .api_key("__plugin__", "api-key") - .role(role) - .build()?; -``` - -### 動的プラグイン読み込み - -`dynamic-loading` フィーチャーを有効にすることで、共有ライブラリからプラグインを動的に読み込むことができます: - -```toml -[dependencies] -worker = { path = "../worker", features = ["dynamic-loading"] } -``` - -```rust -// ディレクトリからプラグインを読み込み -worker.load_plugins_from_directory(Path::new("./plugins")).await?; -``` - -完全な例は `worker/src/plugin/example_provider.rs` と `worker/examples/plugin_usage.rs` を参照してください。 - -## エラーハンドリング - -構造化されたエラー型により、詳細なエラー情報を取得できます。 - -### WorkerError の種類 - -```rust -use worker::WorkerError; - -// ツール実行エラー -let error = WorkerError::tool_execution("my_tool", "Connection failed"); -let error = WorkerError::tool_execution_with_source("my_tool", "Failed", source_error); - -// 設定エラー -let error = WorkerError::config("Invalid configuration"); -let error = WorkerError::config_with_context("Parse error", "config.yaml line 10"); -let error = WorkerError::config_with_source("Failed to load", io_error); - -// LLM APIエラー -let error = WorkerError::llm_api("openai", "Rate limit exceeded"); -let error = WorkerError::llm_api_with_details("claude", "Invalid request", Some(400), None); - -// モデルエラー -let error = WorkerError::model_not_found("openai", "gpt-5"); - -// ネットワークエラー -let error = WorkerError::network("Connection timeout"); -let error = WorkerError::network_with_source("Request failed", reqwest_error); -``` - -### エラーのパターンマッチング - -```rust -match worker.build() { - Ok(worker) => { /* ... */ }, - Err(WorkerError::ConfigurationError { message, context, .. }) => { - eprintln!("Configuration error: {}", message); - if let Some(ctx) = context { - eprintln!("Context: {}", ctx); - } - }, - Err(WorkerError::ModelNotFound { provider, model_name }) => { - eprintln!("Model '{}' not found for provider '{}'", model_name, provider); - }, - Err(WorkerError::LlmApiError { provider, message, status_code, .. }) => { - eprintln!("API error from {}: {}", provider, message); - if let Some(code) = status_code { - eprintln!("Status code: {}", code); - } - }, - Err(e) => { - eprintln!("Error: {}", e); - } -} -``` +- API の詳細は `cargo doc --open` で参照できます。 +- プロンプトシステムの概要: `docs/prompt-composer.md` +- サンプルコード: `worker/examples/` diff --git a/docs/patch_note/v0.3.0.md b/docs/patch_note/v0.3.0.md new file mode 100644 index 0000000..4508229 --- /dev/null +++ b/docs/patch_note/v0.3.0.md @@ -0,0 +1,33 @@ +# Release Notes - v0.3.0 + +**Release Date**: 2025-??-?? + +v0.3.0 はプロンプトリソースの解決責務を利用側へ完全に移し、ツール/フック登録の推奨フローを明確化するアップデートです。これにより、ワーカーの動作を環境ごとに柔軟に制御できるようになりました。 + +## Breaking Changes + +- `ConfigParser::resolve_path` を削除し、`#nia/` `#workspace/` 等のプレフィックス解決をライブラリ利用者実装の `ResourceLoader` に委譲しました。 +- `WorkerBuilder::build()` は `resource_loader(...)` が未指定の場合エラーを返すようになりました。ワーカー構築前に必ずローダーを提供してください。 + +## 新機能 / 仕様変更 + +- `PromptComposer` が `ResourceLoader` を必須依存として受け取り、partials や `{{include_file}}` の読み込みをすべてローダー経由で行うようになりました。 +- パーシャル読み込み時にフォールバックが失敗した場合、一次/二次エラー内容を含むメッセージを返すよう改善しました。 +- README とドキュメントを刷新し、推奨ワークフロー(ResourceLoader 実装 → Worker 構築 → イベント処理)を明示。`#[worker::tool]` / `#[worker::hook]` マクロを用いた登録例を追加しました。 +- ユニットテスト `test_prompt_composer_uses_resource_loader` を追加し、注入されたローダーがパーシャル/include の解決に使われることを保証。 + +## 不具合修正 + +- `include_file` ヘルパーがカスタムローダーを利用せずにファイルアクセスしていた問題を修正。 +- `ConfigParser` が存在しない `#nia/` プレフィックスを静的に解決しようとしていた挙動を除去し、誤ったパスが静かに通ることを防止。 + +## 移行ガイド + +1. 既存コードで `ConfigParser::resolve_path` を直接利用していた場合、代わりにアプリケーション側で `ResourceLoader` を実装し、その中で必要なプレフィックス処理を行ってください。 +2. `Worker::builder()` を使用する箇所で、新たに `.resource_loader(...)` を必ず指定してください。 +3. ツール・フック登録は `#[worker::tool]` / `#[worker::hook]` マクロを利用する実装に移行することを推奨します。 + +## 開発者向けメモ + +- README を簡潔化し、RustDocs で確認できる内容の重複を削除しました。 +- `worker/examples/` を更新し、`ResourceLoader` の実装例とマクロベースのツール登録を採用しました。 diff --git a/docs/prompt-composer.md b/docs/prompt-composer.md index feccd6e..5778f00 100644 --- a/docs/prompt-composer.md +++ b/docs/prompt-composer.md @@ -5,14 +5,49 @@ ## 基本使用方法 ```rust +use std::sync::Arc; +use worker::prompt::{PromptComposer, PromptContext, PromptError, ResourceLoader}; + +struct FsLoader; + +impl ResourceLoader for FsLoader { + fn load(&self, identifier: &str) -> Result { + std::fs::read_to_string(identifier) + .map_err(|e| PromptError::FileNotFound(format!("{}: {}", identifier, e))) + } +} + // 初期化 -let composer = PromptComposer::from_config_file("role.yaml", context)?; +let loader = Arc::new(FsLoader); +let mut composer = PromptComposer::from_config_file("role.yaml", context, loader.clone())?; composer.initialize_session(&messages)?; // プロンプト構築 let messages = composer.compose(&user_messages)?; ``` +## リソースローダー + +`PromptComposer` はテンプレート内で参照されるパーシャルや `{{include_file}}` の解決をクレート利用者に委ねています。 +`ResourceLoader` トレイトを実装して、任意のストレージや命名規則に基づいて文字列を返してください。 + +```rust +struct MyLoader; + +impl ResourceLoader for MyLoader { + fn load(&self, identifier: &str) -> Result { + match identifier.strip_prefix("#workspace/") { + Some(rest) => { + let path = std::env::current_dir()?.join(".nia").join(rest); + std::fs::read_to_string(path).map_err(|e| PromptError::FileNotFound(e.to_string())) + } + None => std::fs::read_to_string(identifier) + .map_err(|e| PromptError::FileNotFound(e.to_string())), + } + } +} +``` + ## テンプレート構文 ### 変数展開 @@ -68,4 +103,4 @@ Focus on memory safety and performance. - `model`: LLMモデル情報(provider、model_name、capabilities) - `session`: セッション情報(conversation_id、message_count) - `user_input`: ユーザー入力内容 -- `tools_schema`: ツール定義JSON \ No newline at end of file +- `tools_schema`: ツール定義JSON diff --git a/worker/Cargo.toml b/worker/Cargo.toml index da58170..5fd15f7 100644 --- a/worker/Cargo.toml +++ b/worker/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "worker" -version = "0.2.1" +version = "0.3.0" edition = "2024" [dependencies] diff --git a/worker/examples/builder_basic.rs b/worker/examples/builder_basic.rs index cde9863..7b041cc 100644 --- a/worker/examples/builder_basic.rs +++ b/worker/examples/builder_basic.rs @@ -1,5 +1,14 @@ -use worker::{LlmProvider, Worker, Role}; use std::collections::HashMap; +use worker::{LlmProvider, PromptError, ResourceLoader, Role, Worker}; + +struct FsLoader; + +impl ResourceLoader for FsLoader { + fn load(&self, identifier: &str) -> Result { + std::fs::read_to_string(identifier) + .map_err(|e| PromptError::FileNotFound(format!("{}: {}", identifier, e))) + } +} #[tokio::main] async fn main() -> Result<(), Box> { @@ -17,6 +26,7 @@ async fn main() -> Result<(), Box> { .provider(LlmProvider::Claude) .model("claude-3-sonnet-20240229") .api_keys(api_keys) + .resource_loader(FsLoader) .role(role) .build()?; @@ -35,6 +45,7 @@ async fn main() -> Result<(), Box> { .provider(LlmProvider::Claude) .model("claude-3-sonnet-20240229") .api_key("claude", std::env::var("ANTHROPIC_API_KEY")?) + .resource_loader(FsLoader) .role(code_reviewer_role) .build()?; diff --git a/worker/examples/plugin_usage.rs b/worker/examples/plugin_usage.rs index 861e32f..636cbb7 100644 --- a/worker/examples/plugin_usage.rs +++ b/worker/examples/plugin_usage.rs @@ -1,6 +1,18 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; -use worker::{Worker, Role, plugin::{PluginRegistry, ProviderPlugin, example_provider::CustomProviderPlugin}}; +use worker::{ + PromptError, ResourceLoader, Role, Worker, + plugin::{PluginRegistry, ProviderPlugin, example_provider::CustomProviderPlugin}, +}; + +struct FsLoader; + +impl ResourceLoader for FsLoader { + fn load(&self, identifier: &str) -> Result { + std::fs::read_to_string(identifier) + .map_err(|e| PromptError::FileNotFound(format!("{}: {}", identifier, e))) + } +} #[tokio::main] async fn main() -> Result<(), Box> { @@ -37,7 +49,10 @@ async fn main() -> Result<(), Box> { let plugins = registry.list(); println!("Available plugins:"); for plugin in plugins { - println!(" - {} ({}): {}", plugin.name, plugin.id, plugin.description); + println!( + " - {} ({}): {}", + plugin.name, plugin.id, plugin.description + ); println!(" Supported models: {:?}", plugin.supported_models); } } @@ -46,30 +61,34 @@ async fn main() -> Result<(), Box> { let role = Role::new( "assistant", "A helpful AI assistant", - "You are a helpful, harmless, and honest AI assistant powered by a custom LLM provider." + "You are a helpful, harmless, and honest AI assistant powered by a custom LLM provider.", ); let worker = Worker::builder() .plugin("custom-provider", plugin_registry.clone()) .model("custom-turbo") .api_key("__plugin__", "custom-1234567890abcdefghijklmnop") + .resource_loader(FsLoader) .role(role) .build()?; println!("\nWorker created successfully with custom provider plugin!"); - + // Example: List plugins from the worker let plugin_list = worker.list_plugins()?; println!("\nPlugins registered in worker:"); for metadata in plugin_list { - println!(" - {}: v{} by {}", metadata.name, metadata.version, metadata.author); + println!( + " - {}: v{} by {}", + metadata.name, metadata.version, metadata.author + ); } // Load plugins from directory (if dynamic loading is enabled) #[cfg(feature = "dynamic-loading")] { use std::path::Path; - + let plugin_dir = Path::new("./plugins"); if plugin_dir.exists() { let mut worker = worker; @@ -79,4 +98,4 @@ async fn main() -> Result<(), Box> { } Ok(()) -} \ No newline at end of file +} diff --git a/worker/src/builder.rs b/worker/src/builder.rs index 87eee62..afd5591 100644 --- a/worker/src/builder.rs +++ b/worker/src/builder.rs @@ -1,10 +1,10 @@ use crate::Worker; -use crate::prompt::Role; +use crate::prompt::{ResourceLoader, Role}; use crate::types::WorkerError; -use worker_types::LlmProvider; use std::collections::HashMap; use std::marker::PhantomData; use std::sync::{Arc, Mutex}; +use worker_types::LlmProvider; // Type-state markers pub struct NoProvider; @@ -20,13 +20,23 @@ pub struct WithRole; /// /// # Example /// ```no_run -/// use worker::{Worker, LlmProvider, Role}; +/// use worker::{Worker, LlmProvider, Role, ResourceLoader, PromptError}; +/// +/// struct FsLoader; +/// +/// impl ResourceLoader for FsLoader { +/// fn load(&self, identifier: &str) -> Result { +/// std::fs::read_to_string(identifier) +/// .map_err(|e| PromptError::FileNotFound(format!("{}: {}", identifier, e))) +/// } +/// } /// /// let role = Role::new("assistant", "AI Assistant", "You are a helpful assistant."); /// let worker = Worker::builder() /// .provider(LlmProvider::Claude) /// .model("claude-3-sonnet-20240229") /// .api_key("claude", "sk-ant-...") +/// .resource_loader(FsLoader) /// .role(role) /// .build()?; /// # Ok::<(), worker::WorkerError>(()) @@ -38,6 +48,7 @@ pub struct WorkerBuilder { // Role configuration (required) role: Option, + resource_loader: Option>, // Plugin configuration plugin_id: Option, @@ -53,6 +64,7 @@ impl Default for WorkerBuilder { model_name: None, api_keys: HashMap::new(), role: None, + resource_loader: None, plugin_id: None, plugin_registry: None, _phantom: PhantomData, @@ -76,6 +88,7 @@ impl WorkerBuilder { model_name: self.model_name, api_keys: self.api_keys, role: self.role, + resource_loader: self.resource_loader, plugin_id: self.plugin_id, plugin_registry: self.plugin_registry, _phantom: PhantomData, @@ -95,6 +108,7 @@ impl WorkerBuilder { model_name: self.model_name, api_keys: self.api_keys, role: self.role, + resource_loader: self.resource_loader, plugin_id: self.plugin_id, plugin_registry: self.plugin_registry, _phantom: PhantomData, @@ -104,13 +118,17 @@ impl WorkerBuilder { // Step 2: Set model impl WorkerBuilder { - pub fn model(mut self, model_name: impl Into) -> WorkerBuilder { + pub fn model( + mut self, + model_name: impl Into, + ) -> WorkerBuilder { self.model_name = Some(model_name.into()); WorkerBuilder { provider: self.provider, model_name: self.model_name, api_keys: self.api_keys, role: self.role, + resource_loader: self.resource_loader, plugin_id: self.plugin_id, plugin_registry: self.plugin_registry, _phantom: PhantomData, @@ -127,6 +145,7 @@ impl WorkerBuilder { model_name: self.model_name, api_keys: self.api_keys, role: self.role, + resource_loader: self.resource_loader, plugin_id: self.plugin_id, plugin_registry: self.plugin_registry, _phantom: PhantomData, @@ -147,26 +166,44 @@ impl WorkerBuilder { self.api_keys = keys; self } + + /// Provide a resource loader implementation for partial/include resolution + pub fn resource_loader(mut self, loader: L) -> Self + where + L: ResourceLoader + 'static, + { + self.resource_loader = Some(Arc::new(loader)); + self + } } // Build impl WorkerBuilder { pub fn build(self) -> Result { - use crate::{LlmProviderExt, WorkspaceDetector, PromptComposer, plugin}; + use crate::{LlmProviderExt, PromptComposer, WorkspaceDetector, plugin}; + + let resource_loader = self.resource_loader.clone().ok_or_else(|| { + WorkerError::config( + "Resource loader is required. Call resource_loader(...) before build.", + ) + })?; let role = self.role.unwrap(); let model_name = self.model_name.unwrap(); // Plugin provider if let (Some(plugin_id), Some(plugin_registry)) = (self.plugin_id, self.plugin_registry) { - let api_key_opt = self.api_keys.get("__plugin__").or_else(|| { - self.api_keys.values().next() - }); + let api_key_opt = self + .api_keys + .get("__plugin__") + .or_else(|| self.api_keys.values().next()); - let registry = plugin_registry.lock() - .map_err(|e| WorkerError::config(format!("Failed to lock plugin registry: {}", e)))?; + let registry = plugin_registry.lock().map_err(|e| { + WorkerError::config(format!("Failed to lock plugin registry: {}", e)) + })?; - let plugin = registry.get(&plugin_id) + let plugin = registry + .get(&plugin_id) .ok_or_else(|| WorkerError::config(format!("Plugin not found: {}", plugin_id)))?; let llm_client = plugin::PluginClient::new( @@ -190,14 +227,16 @@ impl WorkerBuilder { tracing::info!("Creating worker with plugin and role: {}", role.name); - let composer = PromptComposer::from_config(role.clone(), prompt_context) - .map_err(|e| WorkerError::config(e.to_string()))?; + let composer = + PromptComposer::from_config(role.clone(), prompt_context, resource_loader.clone()) + .map_err(|e| WorkerError::config(e.to_string()))?; drop(registry); let mut worker = Worker { llm_client: Box::new(llm_client), composer, + resource_loader: resource_loader.clone(), tools: Vec::new(), api_key, provider_str, @@ -210,7 +249,8 @@ impl WorkerBuilder { plugin_registry: plugin_registry.clone(), }; - worker.initialize_session() + worker + .initialize_session() .map_err(|e| WorkerError::config(e.to_string()))?; return Ok(worker); @@ -234,12 +274,14 @@ impl WorkerBuilder { tracing::info!("Creating worker with role: {}", role.name); - let composer = PromptComposer::from_config(role.clone(), prompt_context) - .map_err(|e| WorkerError::config(e.to_string()))?; + let composer = + PromptComposer::from_config(role.clone(), prompt_context, resource_loader.clone()) + .map_err(|e| WorkerError::config(e.to_string()))?; let mut worker = Worker { llm_client: Box::new(llm_client), composer, + resource_loader, tools: Vec::new(), api_key, provider_str: provider_str.to_string(), @@ -252,7 +294,8 @@ impl WorkerBuilder { plugin_registry, }; - worker.initialize_session() + worker + .initialize_session() .map_err(|e| WorkerError::config(e.to_string()))?; Ok(worker) diff --git a/worker/src/client.rs b/worker/src/client.rs index 1d035c6..f1cd684 100644 --- a/worker/src/client.rs +++ b/worker/src/client.rs @@ -6,8 +6,8 @@ use crate::llm::{ xai::XAIClient, }; use crate::types::WorkerError; -use worker_types::{LlmProvider, Message, StreamEvent}; use futures_util::Stream; +use worker_types::{LlmProvider, Message, StreamEvent}; // LlmClient enumを使用してdyn互換性の問題を解決 pub enum LlmClient { diff --git a/worker/src/config/parser.rs b/worker/src/config/parser.rs index df9b5bd..eb36ace 100644 --- a/worker/src/config/parser.rs +++ b/worker/src/config/parser.rs @@ -53,60 +53,4 @@ impl ConfigParser { Ok(()) } - - /// パスプレフィックスを解決する - pub fn resolve_path(path_str: &str) -> Result { - if path_str.starts_with("#nia/") { - // 組み込みリソース - let relative_path = path_str.strip_prefix("#nia/").unwrap(); - let project_root = std::env::current_dir() - .map_err(|e| PromptError::WorkspaceDetection(e.to_string()))?; - - // 優先順位: ./resources > ./cli/resources > ./nia-core/resources > ./nia-pod/resources - let possible_paths = [ - project_root.join("resources").join(relative_path), - project_root - .join("cli") - .join("resources") - .join(relative_path), - project_root - .join("nia-core") - .join("resources") - .join(relative_path), - project_root - .join("nia-pod") - .join("resources") - .join(relative_path), - ]; - - for path in &possible_paths { - if path.exists() { - return Ok(path.clone()); - } - } - - // 見つからない場合はデフォルトのパスを返す - Ok(project_root - .join("nia-cli") - .join("resources") - .join(relative_path)) - } else if path_str.starts_with("#workspace/") { - // ワークスペース固有 - let relative_path = path_str.strip_prefix("#workspace/").unwrap(); - let project_root = std::env::current_dir() - .map_err(|e| PromptError::WorkspaceDetection(e.to_string()))?; - Ok(project_root.join(".nia").join(relative_path)) - } else if path_str.starts_with("#user/") { - // ユーザー設定 - let relative_path = path_str.strip_prefix("#user/").unwrap(); - let base_dirs = xdg::BaseDirectories::with_prefix("nia"); - let config_home = base_dirs.get_config_home().ok_or_else(|| { - PromptError::WorkspaceDetection("Could not determine XDG config home".to_string()) - })?; - Ok(config_home.join(relative_path)) - } else { - // 相対パスまたは絶対パス - Ok(std::path::PathBuf::from(path_str)) - } - } } diff --git a/worker/src/core.rs b/worker/src/core.rs index f28635d..d51d4dd 100644 --- a/worker/src/core.rs +++ b/worker/src/core.rs @@ -1,9 +1,9 @@ // Core types and traits for the worker crate // This module contains the primary abstractions used throughout the crate -use worker_types::{DynamicToolDefinition, LlmProvider, Message, StreamEvent}; use crate::types::WorkerError; use futures_util::Stream; +use worker_types::{DynamicToolDefinition, LlmProvider, Message, StreamEvent}; /// LlmClient trait - common interface for all LLM clients /// diff --git a/worker/src/lib.rs b/worker/src/lib.rs index d09263e..b2d14fa 100644 --- a/worker/src/lib.rs +++ b/worker/src/lib.rs @@ -1,6 +1,6 @@ use crate::prompt::{ - PromptComposer, PromptContext, WorkspaceContext, ModelContext, ModelCapabilities, - SessionContext + ModelCapabilities, ModelContext, PromptComposer, PromptContext, SessionContext, + WorkspaceContext, }; use crate::workspace::WorkspaceDetector; use async_stream::stream; @@ -15,29 +15,29 @@ use std::fs; use std::path::PathBuf; use tracing; use uuid; +pub use worker_macros::{hook, tool}; pub use worker_types::{ DynamicToolDefinition, HookContext, HookEvent, HookManager, HookResult, LlmDebug, LlmProvider, LlmResponse, Message, ModelInfo, SessionData, StreamEvent, Task, Tool, ToolCall, ToolResult, WorkerHook, WorkspaceConfig, WorkspaceData, }; -pub use worker_macros::{hook, tool}; -pub mod core; -pub mod types; -pub mod client; pub mod builder; +pub mod client; pub mod config; +pub mod core; pub mod llm; pub mod mcp; pub mod plugin; pub mod prompt; +pub mod types; pub mod workspace; -pub use core::LlmClientTrait; -pub use client::LlmClient; -pub use crate::prompt::Role; -pub use builder::WorkerBuilder; +pub use crate::prompt::{PromptError, ResourceLoader, Role}; pub use crate::types::WorkerError; +pub use builder::WorkerBuilder; +pub use client::LlmClient; +pub use core::LlmClientTrait; #[cfg(test)] mod tests { @@ -172,7 +172,8 @@ impl WorkerError { || error_msg.contains("billing") || error_msg.contains("upgrade") || error_msg.contains("purchase credits") - || (error_msg.contains("invalid_request_error") && error_msg.contains("credit balance")); + || (error_msg.contains("invalid_request_error") + && error_msg.contains("credit balance")); let has_provider_patterns = match provider { LlmProvider::OpenAI => { @@ -182,10 +183,10 @@ impl WorkerError { } LlmProvider::Claude => { (error_msg.contains("invalid_x_api_key") || error_msg.contains("x-api-key")) - && !error_msg.contains("credit balance") - && !error_msg.contains("billing") - && !error_msg.contains("upgrade") - && !error_msg.contains("purchase credits") + && !error_msg.contains("credit balance") + && !error_msg.contains("billing") + && !error_msg.contains("upgrade") + && !error_msg.contains("purchase credits") } LlmProvider::Gemini => { error_msg.contains("invalid_argument") || error_msg.contains("credentials") @@ -319,9 +320,7 @@ pub async fn validate_api_key( Ok(Some(false)) } } - LlmProvider::Ollama => { - Ok(Some(true)) - } + LlmProvider::Ollama => Ok(Some(true)), LlmProvider::XAI => { if api_key.starts_with("xai-") && api_key.len() > 20 { tracing::debug!("validate_api_key: xAI API key format appears valid"); @@ -357,9 +356,8 @@ pub struct ModelMeta { } fn get_models_config_path() -> Result { - let home_dir = dirs::home_dir().ok_or_else(|| { - WorkerError::config("Could not determine home directory") - })?; + let home_dir = dirs::home_dir() + .ok_or_else(|| WorkerError::config("Could not determine home directory"))?; Ok(home_dir.join(".config").join("nia").join("models.yaml")) } @@ -374,13 +372,11 @@ fn load_models_config() -> Result { return Ok(ModelsConfig { models: vec![] }); } - let content = fs::read_to_string(&config_path).map_err(|e| { - WorkerError::config(format!("Failed to read models config: {}", e)) - })?; + let content = fs::read_to_string(&config_path) + .map_err(|e| WorkerError::config(format!("Failed to read models config: {}", e)))?; - let config: ModelsConfig = serde_yaml::from_str(&content).map_err(|e| { - WorkerError::config(format!("Failed to parse models config: {}", e)) - })?; + let config: ModelsConfig = serde_yaml::from_str(&content) + .map_err(|e| WorkerError::config(format!("Failed to parse models config: {}", e)))?; Ok(config) } @@ -447,6 +443,7 @@ pub async fn supports_native_tools( pub struct Worker { pub(crate) llm_client: Box, pub(crate) composer: PromptComposer, + pub(crate) resource_loader: std::sync::Arc, pub(crate) tools: Vec>, pub(crate) api_key: String, pub(crate) provider_str: String, @@ -464,27 +461,42 @@ impl Worker { /// /// # Example /// ```no_run - /// use worker::{Worker, LlmProvider, Role}; + /// use worker::{Worker, LlmProvider, Role, PromptError, ResourceLoader}; + /// + /// struct FsLoader; + /// + /// impl ResourceLoader for FsLoader { + /// fn load(&self, identifier: &str) -> Result { + /// std::fs::read_to_string(identifier) + /// .map_err(|e| PromptError::FileNotFound(format!("{}: {}", identifier, e))) + /// } + /// } /// /// let role = Role::new("assistant", "AI Assistant", "You are a helpful assistant."); /// let worker = Worker::builder() /// .provider(LlmProvider::Claude) /// .model("claude-3-sonnet-20240229") /// .api_key("claude", "sk-ant-...") + /// .resource_loader(FsLoader) /// .role(role) /// .build()?; /// # Ok::<(), worker::WorkerError>(()) /// ``` - pub fn builder() -> builder::WorkerBuilder { + pub fn builder() + -> builder::WorkerBuilder { builder::WorkerBuilder::new() } - /// Load plugins from a directory #[cfg(feature = "dynamic-loading")] - pub async fn load_plugins_from_directory(&mut self, dir: &std::path::Path) -> Result<(), WorkerError> { + pub async fn load_plugins_from_directory( + &mut self, + dir: &std::path::Path, + ) -> Result<(), WorkerError> { let plugins = plugin::PluginLoader::load_from_directory(dir)?; - let mut registry = self.plugin_registry.lock() + let mut registry = self + .plugin_registry + .lock() .map_err(|e| WorkerError::config(format!("Failed to lock plugin registry: {}", e)))?; for plugin in plugins { @@ -496,7 +508,9 @@ impl Worker { /// List all registered plugins pub fn list_plugins(&self) -> Result, WorkerError> { - let registry = self.plugin_registry.lock() + let registry = self + .plugin_registry + .lock() .map_err(|e| WorkerError::config(format!("Failed to lock plugin registry: {}", e)))?; Ok(registry.list()) } @@ -515,10 +529,7 @@ impl Worker { } /// MCPサーバーをツールとして登録する - pub fn register_mcp_server( - &mut self, - config: McpServerConfig, - ) -> Result<(), WorkerError> { + pub fn register_mcp_server(&mut self, config: McpServerConfig) -> Result<(), WorkerError> { let mcp_tool = McpDynamicTool::new(config.clone()); self.register_tool(Box::new(mcp_tool))?; tracing::info!("Registered MCP server as tool: {}", config.name); @@ -526,10 +537,7 @@ impl Worker { } /// MCPサーバーから個別のツールを登録する - pub async fn register_mcp_tools( - &mut self, - config: McpServerConfig, - ) -> Result<(), WorkerError> { + pub async fn register_mcp_tools(&mut self, config: McpServerConfig) -> Result<(), WorkerError> { let tools = create_single_mcp_tools(&config).await?; let tool_count = tools.len(); @@ -791,8 +799,9 @@ impl Worker { let prompt_context = self.create_prompt_context()?; // DynamicPromptComposerを作成 - let composer = PromptComposer::from_config(role.clone(), prompt_context) - .map_err(|e| WorkerError::config(e.to_string()))?; + let composer = + PromptComposer::from_config(role.clone(), prompt_context, self.resource_loader.clone()) + .map_err(|e| WorkerError::config(e.to_string()))?; self.role = role; self.composer = composer; @@ -927,7 +936,7 @@ impl Worker { if self.tools.iter().any(|t| t.name() == tool.name()) { return Err(WorkerError::tool_execution( tool.name(), - format!("Tool '{}' is already registered", tool.name()) + format!("Tool '{}' is already registered", tool.name()), )); } @@ -961,7 +970,7 @@ impl Worker { Some(tool) => tool.execute(args).await.map_err(WorkerError::from), None => Err(WorkerError::tool_execution( tool_name, - format!("Tool '{}' not found", tool_name) + format!("Tool '{}' not found", tool_name), )), } } @@ -1716,9 +1725,7 @@ impl Worker { /// セッションデータを取得する pub fn get_session_data(&self) -> Result { let workspace_path = std::env::current_dir() - .map_err(|e| { - WorkerError::config(format!("Failed to get current directory: {}", e)) - })? + .map_err(|e| WorkerError::config(format!("Failed to get current directory: {}", e)))? .to_string_lossy() .to_string(); @@ -1811,7 +1818,6 @@ impl Worker { } } - /// セッション初期化(Worker内部用) fn initialize_session(&mut self) -> Result<(), crate::prompt::PromptError> { // 空のメッセージでセッション初期化 @@ -1819,9 +1825,7 @@ impl Worker { } /// 履歴付きセッション再初期化(Worker内部用) - fn reinitialize_session_with_history( - &mut self, - ) -> Result<(), crate::prompt::PromptError> { + fn reinitialize_session_with_history(&mut self) -> Result<(), crate::prompt::PromptError> { // 現在の履歴を使ってセッション初期化 self.composer.initialize_session(&self.message_history) } diff --git a/worker/src/llm/anthropic.rs b/worker/src/llm/anthropic.rs index a6b1a88..1a0d75b 100644 --- a/worker/src/llm/anthropic.rs +++ b/worker/src/llm/anthropic.rs @@ -1,12 +1,12 @@ +use crate::config::UrlConfig; use crate::core::LlmClientTrait; use crate::types::WorkerError; -use worker_types::{LlmProvider, Message, Role, StreamEvent, ToolCall}; -use crate::config::UrlConfig; use async_stream::stream; use futures_util::{Stream, StreamExt}; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::Value; +use worker_types::{LlmProvider, Message, Role, StreamEvent, ToolCall}; #[derive(Debug, Serialize)] struct AnthropicRequest { diff --git a/worker/src/llm/gemini.rs b/worker/src/llm/gemini.rs index e8a890e..328ed22 100644 --- a/worker/src/llm/gemini.rs +++ b/worker/src/llm/gemini.rs @@ -1,11 +1,11 @@ +use crate::config::UrlConfig; use crate::core::LlmClientTrait; use crate::types::WorkerError; -use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}; -use crate::config::UrlConfig; use futures_util::{Stream, StreamExt, TryStreamExt}; use reqwest::Client; use serde::{Deserialize, Serialize}; use tracing; +use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}; /// Extract tool name from Tool message content fn extract_tool_name_from_content(content: &str) -> Option { diff --git a/worker/src/llm/ollama.rs b/worker/src/llm/ollama.rs index 53810da..8f1874c 100644 --- a/worker/src/llm/ollama.rs +++ b/worker/src/llm/ollama.rs @@ -1,11 +1,11 @@ +use crate::config::UrlConfig; use crate::core::LlmClientTrait; use crate::types::WorkerError; -use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}; -use crate::config::UrlConfig; use futures_util::{Stream, StreamExt}; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::Value; +use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}; // --- Request & Response Structures --- #[derive(Debug, Serialize, Clone)] @@ -670,7 +670,9 @@ impl OllamaClient { self.add_auth_header(client.get(&url)) .send() .await - .map_err(|e| WorkerError::llm_api("ollama", format!("Failed to connect to Ollama: {}", e)))?; + .map_err(|e| { + WorkerError::llm_api("ollama", format!("Failed to connect to Ollama: {}", e)) + })?; Ok(()) } } diff --git a/worker/src/llm/openai.rs b/worker/src/llm/openai.rs index e77df62..9b86393 100644 --- a/worker/src/llm/openai.rs +++ b/worker/src/llm/openai.rs @@ -1,11 +1,11 @@ +use crate::config::UrlConfig; use crate::core::LlmClientTrait; use crate::types::WorkerError; -use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}; -use crate::config::UrlConfig; use futures_util::{Stream, StreamExt}; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::Value; +use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}; // --- Request & Response Structures --- #[derive(Debug, Serialize)] diff --git a/worker/src/llm/xai.rs b/worker/src/llm/xai.rs index fc6239c..f6e0edb 100644 --- a/worker/src/llm/xai.rs +++ b/worker/src/llm/xai.rs @@ -1,11 +1,11 @@ +use crate::config::UrlConfig; use crate::core::LlmClientTrait; use crate::types::WorkerError; -use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}; -use crate::config::UrlConfig; use futures_util::{Stream, StreamExt}; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::Value; +use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}; #[derive(Debug, Serialize)] pub(crate) struct XAIRequest { diff --git a/worker/src/mcp/config.rs b/worker/src/mcp/config.rs index db11479..ae326ef 100644 --- a/worker/src/mcp/config.rs +++ b/worker/src/mcp/config.rs @@ -1,5 +1,5 @@ -use crate::types::WorkerError; use super::tool::McpServerConfig; +use crate::types::WorkerError; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::path::Path; @@ -71,17 +71,11 @@ impl McpConfig { info!("Loading MCP config from: {:?}", path); let content = std::fs::read_to_string(path).map_err(|e| { - WorkerError::config(format!( - "Failed to read MCP config file {:?}: {}", - path, e - )) + WorkerError::config(format!("Failed to read MCP config file {:?}: {}", path, e)) })?; let config: McpConfig = serde_yaml::from_str(&content).map_err(|e| { - WorkerError::config(format!( - "Failed to parse MCP config file {:?}: {}", - path, e - )) + WorkerError::config(format!("Failed to parse MCP config file {:?}: {}", path, e)) })?; info!("Loaded {} MCP server configurations", config.servers.len()); @@ -102,15 +96,11 @@ impl McpConfig { })?; } - let content = serde_yaml::to_string(self).map_err(|e| { - WorkerError::config(format!("Failed to serialize MCP config: {}", e)) - })?; + let content = serde_yaml::to_string(self) + .map_err(|e| WorkerError::config(format!("Failed to serialize MCP config: {}", e)))?; std::fs::write(path, content).map_err(|e| { - WorkerError::config(format!( - "Failed to write MCP config file {:?}: {}", - path, e - )) + WorkerError::config(format!("Failed to write MCP config file {:?}: {}", path, e)) })?; info!("Saved MCP config to: {:?}", path); diff --git a/worker/src/mcp/mod.rs b/worker/src/mcp/mod.rs index 420e3d0..6bc6f5e 100644 --- a/worker/src/mcp/mod.rs +++ b/worker/src/mcp/mod.rs @@ -5,6 +5,6 @@ mod tool; pub use config::{IntegrationMode, McpConfig, McpServerDefinition}; pub use protocol::McpClient; pub use tool::{ - create_single_mcp_tools, get_mcp_tools_as_definitions, test_mcp_connection, McpDynamicTool, - McpServerConfig, SingleMcpTool, + McpDynamicTool, McpServerConfig, SingleMcpTool, create_single_mcp_tools, + get_mcp_tools_as_definitions, test_mcp_connection, }; diff --git a/worker/src/mcp/tool.rs b/worker/src/mcp/tool.rs index 8d2b85b..62d25d3 100644 --- a/worker/src/mcp/tool.rs +++ b/worker/src/mcp/tool.rs @@ -118,7 +118,10 @@ impl McpDynamicTool { let tools = client.list_tools().await.map_err(|e| { crate::WorkerError::tool_execution_with_source( &self.config.name, - format!("Failed to list tools from MCP server '{}'", self.config.name), + format!( + "Failed to list tools from MCP server '{}'", + self.config.name + ), e, ) })?; @@ -326,7 +329,10 @@ impl Tool for McpDynamicTool { } None => Err(Box::new(crate::WorkerError::tool_execution( tool_name, - format!("Tool '{}' not found in MCP server '{}'", tool_name, self.config.name), + format!( + "Tool '{}' not found in MCP server '{}'", + tool_name, self.config.name + ), )) as Box), } diff --git a/worker/src/plugin/example_provider.rs b/worker/src/plugin/example_provider.rs index 723ff95..296e601 100644 --- a/worker/src/plugin/example_provider.rs +++ b/worker/src/plugin/example_provider.rs @@ -4,8 +4,8 @@ use serde_json::Value; use std::any::Any; use std::collections::HashMap; -use crate::plugin::{PluginMetadata, ProviderPlugin}; use crate::core::LlmClientTrait; +use crate::plugin::{PluginMetadata, ProviderPlugin}; use crate::types::WorkerError; use worker_types::{DynamicToolDefinition, LlmDebug, LlmProvider, Message, Role, StreamEvent}; @@ -78,8 +78,8 @@ impl ProviderPlugin for CustomProviderPlugin { return Err(WorkerError::config("Plugin not initialized")); } - let api_key = api_key - .ok_or_else(|| WorkerError::config("API key required for custom provider"))?; + let api_key = + api_key.ok_or_else(|| WorkerError::config("API key required for custom provider"))?; let client = CustomLlmClient::new( api_key.to_string(), @@ -138,9 +138,12 @@ impl LlmClientTrait for CustomLlmClient { messages: Vec, _tools: Option<&[DynamicToolDefinition]>, _llm_debug: Option, - ) -> Result> + Unpin + Send + 'a>, WorkerError> { + ) -> Result< + Box> + Unpin + Send + 'a>, + WorkerError, + > { use async_stream::stream; - + // Example implementation that echoes the last user message let last_user_message = messages .iter() @@ -148,12 +151,12 @@ impl LlmClientTrait for CustomLlmClient { .find(|m| m.role == Role::User) .map(|m| m.content.clone()) .unwrap_or_else(|| "Hello from custom provider!".to_string()); - + let response = format!( "This is a response from the custom provider using model '{}'. You said: {}", self.model, last_user_message ); - + // Create a stream that yields the response in chunks let stream = stream! { // Simulate streaming response @@ -162,7 +165,7 @@ impl LlmClientTrait for CustomLlmClient { // Small delay to simulate streaming tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; } - + // Send completion event with the full message let completion_message = Message { role: Role::Model, @@ -174,7 +177,7 @@ impl LlmClientTrait for CustomLlmClient { }; yield Ok(StreamEvent::Completion(completion_message)); }; - + Ok(Box::new(Box::pin(stream))) } @@ -212,7 +215,7 @@ mod tests { fn test_plugin_metadata() { let plugin = CustomProviderPlugin::new(); let metadata = plugin.metadata(); - + assert_eq!(metadata.id, "custom-provider"); assert_eq!(metadata.name, "Custom LLM Provider"); assert!(metadata.requires_api_key); @@ -222,7 +225,7 @@ mod tests { #[test] fn test_api_key_validation() { let plugin = CustomProviderPlugin::new(); - + assert!(plugin.validate_api_key("custom-1234567890abcdefghij")); assert!(!plugin.validate_api_key("invalid-key")); assert!(!plugin.validate_api_key("custom-short")); @@ -233,9 +236,12 @@ mod tests { async fn test_plugin_initialization() { let mut plugin = CustomProviderPlugin::new(); let mut config = HashMap::new(); - config.insert("base_url".to_string(), Value::String("https://api.example.com".to_string())); - + config.insert( + "base_url".to_string(), + Value::String("https://api.example.com".to_string()), + ); + let result = plugin.initialize(config).await; assert!(result.is_ok()); } -} \ No newline at end of file +} diff --git a/worker/src/plugin/mod.rs b/worker/src/plugin/mod.rs index dad4d0f..ab80d1f 100644 --- a/worker/src/plugin/mod.rs +++ b/worker/src/plugin/mod.rs @@ -109,9 +109,15 @@ impl PluginRegistry { /// Find plugin by model name pub fn find_by_model(&self, model_name: &str) -> Option> { - self.plugins.values().find(|p| { - p.metadata().supported_models.iter().any(|m| m == model_name) - }).cloned() + self.plugins + .values() + .find(|p| { + p.metadata() + .supported_models + .iter() + .any(|m| m == model_name) + }) + .cloned() } /// Unregister a plugin @@ -128,41 +134,50 @@ impl PluginLoader { #[cfg(feature = "dynamic-loading")] pub fn load_dynamic(path: &std::path::Path) -> Result, WorkerError> { use libloading::{Library, Symbol}; - + unsafe { let lib = Library::new(path) .map_err(|e| WorkerError::config(format!("Failed to load plugin: {}", e)))?; - - let create_plugin: Symbol Box> = lib - .get(b"create_plugin") - .map_err(|e| WorkerError::config(format!("Plugin missing create_plugin function: {}", e)))?; - + + let create_plugin: Symbol Box> = + lib.get(b"create_plugin").map_err(|e| { + WorkerError::config(format!("Plugin missing create_plugin function: {}", e)) + })?; + Ok(create_plugin()) } } - + /// Load all plugins from a directory #[cfg(feature = "dynamic-loading")] - pub fn load_from_directory(dir: &std::path::Path) -> Result>, WorkerError> { + pub fn load_from_directory( + dir: &std::path::Path, + ) -> Result>, WorkerError> { use std::fs; - + let mut plugins = Vec::new(); - + if !dir.is_dir() { - return Err(WorkerError::config(format!("Plugin directory does not exist: {:?}", dir))); + return Err(WorkerError::config(format!( + "Plugin directory does not exist: {:?}", + dir + ))); } - + for entry in fs::read_dir(dir) .map_err(|e| WorkerError::Config(format!("Failed to read plugin directory: {}", e)))? { - let entry = entry - .map_err(|e| WorkerError::config(format!("Failed to read directory entry: {}", e)))?; + let entry = entry.map_err(|e| { + WorkerError::config(format!("Failed to read directory entry: {}", e)) + })?; let path = entry.path(); - + // Check for plugin files (.so on Linux, .dll on Windows, .dylib on macOS) - if path.extension().and_then(|s| s.to_str()).map_or(false, |ext| { - ext == "so" || ext == "dll" || ext == "dylib" - }) { + if path + .extension() + .and_then(|s| s.to_str()) + .map_or(false, |ext| ext == "so" || ext == "dll" || ext == "dylib") + { match Self::load_dynamic(&path) { Ok(plugin) => plugins.push(plugin), Err(e) => { @@ -171,7 +186,7 @@ impl PluginLoader { } } } - + Ok(plugins) } } @@ -201,7 +216,10 @@ impl LlmClientTrait for PluginClient { messages: Vec, tools: Option<&[DynamicToolDefinition]>, llm_debug: Option, - ) -> Result> + Unpin + Send + 'a>, WorkerError> { + ) -> Result< + Box> + Unpin + Send + 'a>, + WorkerError, + > { self.inner.chat_stream(messages, tools, llm_debug).await } @@ -216,4 +234,4 @@ impl LlmClientTrait for PluginClient { fn get_model_name(&self) -> String { self.inner.get_model_name() } -} \ No newline at end of file +} diff --git a/worker/src/prompt/composer.rs b/worker/src/prompt/composer.rs index 961af51..d4484ed 100644 --- a/worker/src/prompt/composer.rs +++ b/worker/src/prompt/composer.rs @@ -1,8 +1,7 @@ -use crate::config::ConfigParser; use super::types::*; -use handlebars::{Context, Handlebars, Helper, HelperResult, Output, RenderContext}; -use std::fs; +use handlebars::{Context, Handlebars, Helper, HelperDef, HelperResult, Output, RenderContext}; use std::path::Path; +use std::sync::Arc; // Import Message and Role enum from worker_types use worker_types::{Message, Role as MessageRole}; @@ -14,6 +13,7 @@ pub struct PromptComposer { handlebars: Handlebars<'static>, context: PromptContext, system_prompt: Option, + resource_loader: Arc, } impl PromptComposer { @@ -21,26 +21,29 @@ impl PromptComposer { pub fn from_config_file>( config_path: P, context: PromptContext, + resource_loader: Arc, ) -> Result { - let config = ConfigParser::parse_from_file(config_path)?; - Self::from_config(config, context) + let config = crate::config::ConfigParser::parse_from_file(config_path)?; + Self::from_config(config, context, resource_loader) } /// 設定オブジェクトから新しいインスタンスを作成 pub fn from_config( config: Role, context: PromptContext, + resource_loader: Arc, ) -> Result { let mut handlebars = Handlebars::new(); // カスタムヘルパー関数を登録 - Self::register_custom_helpers(&mut handlebars)?; + Self::register_custom_helpers(&mut handlebars, resource_loader.clone())?; let mut composer = Self { config, handlebars, context, system_prompt: None, + resource_loader, }; // パーシャルテンプレートを読み込み・登録 @@ -60,7 +63,8 @@ impl PromptComposer { pub fn compose(&self, messages: &[Message]) -> Result, PromptError> { if let Some(system_prompt) = &self.system_prompt { // システムプロンプトが既に構築済みの場合、それを使用 - let mut result_messages = vec![Message::new(MessageRole::System, system_prompt.clone())]; + let mut result_messages = + vec![Message::new(MessageRole::System, system_prompt.clone())]; // ユーザーメッセージを追加 for msg in messages { @@ -102,7 +106,8 @@ impl PromptComposer { ) -> Result, PromptError> { if let Some(system_prompt) = &self.system_prompt { // システムプロンプトが既に構築済みの場合、それを使用 - let mut result_messages = vec![Message::new(MessageRole::System, system_prompt.clone())]; + let mut result_messages = + vec![Message::new(MessageRole::System, system_prompt.clone())]; // ユーザーメッセージを追加 for msg in messages { @@ -171,9 +176,16 @@ impl PromptComposer { } /// カスタムヘルパー関数を登録 - fn register_custom_helpers(handlebars: &mut Handlebars) -> Result<(), PromptError> { - // 基本的なヘルパーのみ実装(複雑なライフタイム問題を回避) - handlebars.register_helper("include_file", Box::new(include_file_helper)); + fn register_custom_helpers( + handlebars: &mut Handlebars<'static>, + resource_loader: Arc, + ) -> Result<(), PromptError> { + handlebars.register_helper( + "include_file", + Box::new(IncludeFileHelper { + loader: resource_loader.clone(), + }), + ); handlebars.register_helper("workspace_content", Box::new(workspace_content_helper)); Ok(()) @@ -194,27 +206,22 @@ impl PromptComposer { /// パーシャルの内容を読み込み(フォールバック対応) fn load_partial_content(&self, partial_config: &PartialConfig) -> Result { - let primary_path = ConfigParser::resolve_path(&partial_config.path)?; - - // メインパスを試行 - if let Ok(content) = fs::read_to_string(&primary_path) { - return Ok(content); - } - - // フォールバックパスを試行 - if let Some(fallback) = &partial_config.fallback { - let fallback_path = ConfigParser::resolve_path(fallback)?; - if let Ok(content) = fs::read_to_string(&fallback_path) { - return Ok(content); + match self.resource_loader.load(&partial_config.path) { + Ok(content) => Ok(content), + Err(primary_err) => { + if let Some(fallback) = &partial_config.fallback { + match self.resource_loader.load(fallback) { + Ok(content) => Ok(content), + Err(fallback_err) => Err(PromptError::PartialLoading(format!( + "Could not load partial '{}' (fallback: {:?}): primary error={}, fallback error={}", + partial_config.path, partial_config.fallback, primary_err, fallback_err + ))), + } + } else { + Err(primary_err) + } } } - - Err(PromptError::FileNotFound(format!( - "Could not load partial '{}' from {} (fallback: {:?})", - partial_config.path, - primary_path.display(), - partial_config.fallback - ))) } /// コンテキストを指定してテンプレート用のデータを準備 @@ -289,32 +296,32 @@ impl PromptComposer { // カスタムヘルパー関数の実装 -fn include_file_helper( - h: &Helper, - _hbs: &Handlebars, - _ctx: &Context, - _rc: &mut RenderContext, - out: &mut dyn Output, -) -> HelperResult { - let file_path = h.param(0).and_then(|v| v.value().as_str()).unwrap_or(""); +struct IncludeFileHelper { + loader: Arc, +} - match ConfigParser::resolve_path(file_path) { - Ok(path) => { - match fs::read_to_string(&path) { - Ok(content) => { - out.write(&content)?; - } - Err(_) => { - // ファイルが見つからない場合は空文字を出力 - out.write("")?; - } +impl HelperDef for IncludeFileHelper { + fn call<'reg: 'rc, 'rc>( + &self, + h: &Helper<'rc>, + _handlebars: &Handlebars<'reg>, + _context: &Context, + _rc: &mut RenderContext<'reg, 'rc>, + out: &mut dyn Output, + ) -> HelperResult { + let file_path = h.param(0).and_then(|v| v.value().as_str()).unwrap_or(""); + + match self.loader.load(file_path) { + Ok(content) => { + out.write(&content)?; + } + Err(_) => { + // ファイルが見つからない場合は空文字を出力 + out.write("")?; } } - Err(_) => { - out.write("")?; - } + Ok(()) } - Ok(()) } fn workspace_content_helper( diff --git a/worker/src/prompt/mod.rs b/worker/src/prompt/mod.rs index 4e9c5eb..5d612cf 100644 --- a/worker/src/prompt/mod.rs +++ b/worker/src/prompt/mod.rs @@ -3,6 +3,6 @@ mod types; pub use composer::PromptComposer; pub use types::{ - ConditionConfig, GitInfo, ModelCapabilities, ModelContext, PartialConfig, PromptContext, - PromptError, ProjectType, Role, SessionContext, SystemInfo, WorkspaceContext, + ConditionConfig, GitInfo, ModelCapabilities, ModelContext, PartialConfig, ProjectType, + PromptContext, PromptError, ResourceLoader, Role, SessionContext, SystemInfo, WorkspaceContext, }; diff --git a/worker/src/prompt/types.rs b/worker/src/prompt/types.rs index 2d406b4..6f8a820 100644 --- a/worker/src/prompt/types.rs +++ b/worker/src/prompt/types.rs @@ -21,6 +21,11 @@ pub struct PartialConfig { pub description: Option, } +/// External resource loader used to resolve template includes/partials +pub trait ResourceLoader: Send + Sync { + fn load(&self, identifier: &str) -> Result; +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConditionConfig { pub when: String, @@ -354,7 +359,11 @@ impl Default for SessionContext { impl Role { /// Create a new Role with name, description, and template - pub fn new(name: impl Into, description: impl Into, template: impl Into) -> Self { + pub fn new( + name: impl Into, + description: impl Into, + template: impl Into, + ) -> Self { Self { name: name.into(), description: description.into(), diff --git a/worker/src/tests/config_tests.rs b/worker/src/tests/config_tests.rs index e3d7a4a..4c9eb31 100644 --- a/worker/src/tests/config_tests.rs +++ b/worker/src/tests/config_tests.rs @@ -1,5 +1,13 @@ use crate::config::ConfigParser; +use crate::prompt::{ + ModelCapabilities, ModelContext, PromptComposer, PromptContext, PromptError, ResourceLoader, + SessionContext, SystemInfo, WorkspaceContext, +}; +use crate::types::LlmProvider; +use std::collections::HashMap; use std::io::Write; +use std::path::PathBuf; +use std::sync::Arc; use tempfile::NamedTempFile; #[test] @@ -166,30 +174,104 @@ template: "File content {{user_input}}" assert_eq!(config.template, "File content {{user_input}}"); } +struct InMemoryLoader { + data: HashMap, +} + +impl InMemoryLoader { + fn new(data: HashMap) -> Self { + Self { data } + } +} + +impl ResourceLoader for InMemoryLoader { + fn load(&self, identifier: &str) -> Result { + self.data + .get(identifier) + .cloned() + .ok_or_else(|| PromptError::FileNotFound(format!("not found: {}", identifier))) + } +} + +fn build_prompt_context() -> PromptContext { + let workspace = WorkspaceContext { + root_path: PathBuf::from("."), + nia_md_content: None, + project_type: None, + git_info: None, + has_nia_md: false, + project_name: None, + system_info: SystemInfo::collect(), + }; + + let capabilities = ModelCapabilities { + supports_tools: false, + supports_function_calling: false, + supports_vision: false, + supports_multimodal: None, + context_length: None, + capabilities: vec![], + needs_verification: None, + }; + + let model_context = ModelContext { + provider: LlmProvider::Claude, + model_name: "test-model".to_string(), + capabilities, + supports_native_tools: false, + }; + + let session_context = SessionContext { + conversation_id: None, + message_count: 0, + active_tools: vec![], + user_preferences: None, + }; + + PromptContext { + workspace, + model: model_context, + session: session_context, + variables: HashMap::new(), + } +} + #[test] -fn test_resolve_path() { - // #nia/ prefix - let path = - ConfigParser::resolve_path("#nia/prompts/test.md").expect("Failed to resolve nia path"); - assert!( - path.to_string_lossy() - .contains("nia-cli/resources/prompts/test.md") - ); +fn test_prompt_composer_uses_resource_loader() { + let yaml_content = r##" +name: "Loader Test" +description: "Ensure resource loader is used" +template: | + {{>header}} + {{include_file "include.md"}} - // #workspace/ prefix - let path = ConfigParser::resolve_path("#workspace/config.md") - .expect("Failed to resolve workspace path"); - assert!(path.to_string_lossy().contains(".nia/config.md")); +partials: + header: + path: "missing.md" + fallback: "fallback.md" +"##; - // #user/ prefix - let path = - ConfigParser::resolve_path("#user/settings.md").expect("Failed to resolve user path"); - assert!(path.to_string_lossy().contains("settings.md")); + let role = + ConfigParser::parse_from_string(yaml_content).expect("Failed to parse loader test config"); - // Regular path - let path = - ConfigParser::resolve_path("regular/path.md").expect("Failed to resolve regular path"); - assert_eq!(path.to_string_lossy(), "regular/path.md"); + let loader = Arc::new(InMemoryLoader::new(HashMap::from([ + ("fallback.md".to_string(), "Fallback Partial".to_string()), + ("include.md".to_string(), "Included Content".to_string()), + ]))); + + let prompt_context = build_prompt_context(); + + let composer = PromptComposer::from_config(role, prompt_context, loader) + .expect("Composer should use provided loader"); + + let messages = composer + .compose(&[]) + .expect("Composer should build system prompt"); + + assert!(!messages.is_empty()); + let system_message = &messages[0]; + assert!(system_message.content.contains("Fallback Partial")); + assert!(system_message.content.contains("Included Content")); } #[test] diff --git a/worker/src/types.rs b/worker/src/types.rs index 3876643..933173e 100644 --- a/worker/src/types.rs +++ b/worker/src/types.rs @@ -25,7 +25,10 @@ pub enum WorkerError { /// Model not found for the specified provider #[error("Model not found: {model_name} for provider {provider}")] - ModelNotFound { provider: String, model_name: String }, + ModelNotFound { + provider: String, + model_name: String, + }, /// JSON serialization/deserialization error #[error("JSON error: {0}")] @@ -138,10 +141,7 @@ impl WorkerError { } /// Create a configuration error with context - pub fn config_with_context( - message: impl Into, - context: impl Into, - ) -> Self { + pub fn config_with_context(message: impl Into, context: impl Into) -> Self { Self::ConfigurationError { message: message.into(), context: Some(context.into()), diff --git a/worker/src/workspace/detector.rs b/worker/src/workspace/detector.rs index 899350a..d48d81d 100644 --- a/worker/src/workspace/detector.rs +++ b/worker/src/workspace/detector.rs @@ -1,4 +1,4 @@ -use crate::prompt::{WorkspaceContext, PromptError, ProjectType, GitInfo}; +use crate::prompt::{GitInfo, ProjectType, PromptError, WorkspaceContext}; use std::fs; use std::path::{Path, PathBuf}; use std::process::Command;