0.3.0: テンプレートエンジンのファイルのロードを外部化

This commit is contained in:
Keisuke Hirata 2025-10-24 07:37:47 +09:00
parent b494529779
commit 02667f5396
28 changed files with 569 additions and 662 deletions

4
Cargo.lock generated
View File

@ -2212,7 +2212,7 @@ checksum = "052283831dbae3d879dc7f51f3d92703a316ca49f91540417d38591826127814"
[[package]] [[package]]
name = "worker" name = "worker"
version = "0.1.0" version = "0.3.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-stream", "async-stream",
@ -2249,7 +2249,7 @@ dependencies = [
[[package]] [[package]]
name = "worker-macros" name = "worker-macros"
version = "0.1.0" version = "0.2.1"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",

432
README.md
View File

@ -1,399 +1,97 @@
# `worker` # `worker`
`worker` クレートは、大規模言語モデル (LLM) を利用したアプリケーションのバックエンド機能を提供するクレートです。LLM プロバイダーの抽象化、ツール利用、柔軟なプロンプト管理、フックシステムなど、高度な機能をカプセル化し、アプリケーション開発を簡素化します。 `worker` は、複数の LLM プロバイダーを横断して扱える統合ワーカーです。モデル呼び出し、ツール実行、プロンプト構築、フック連携など、対話アプリに必要な機能を 1 つの API にまとめます。
## 主な機能 ## 特徴
- 主要 LLM プロバイダーを単一のインターフェースで利用
- プロンプト/パーシャル読み込みを利用者実装の `ResourceLoader` へ委譲
- ツール連携とストリーミング応答、フックによるカスタマイズに対応
- **マルチプロバイダー対応**: Gemini, Claude, OpenAI, Ollama, XAI など、複数の LLM プロバイダーを統一されたインターフェースで利用できます。 ## 利用手順
- **プラグインシステム**: カスタムプロバイダーをプラグインとして動的に追加できます。独自の LLM API や実験的なプロバイダーをサポートします。
- **ツール利用 (Function Calling)**: LLM が外部ツールを呼び出す機能をサポートします。独自のツールをマクロを用いて定義し、`Worker` に登録できます。
- **ストリーミング処理**: LLM の応答やツール実行結果を `StreamEvent` として非同期に受け取ることができます。これにより、リアルタイムな UI 更新が可能になります。
- **フックシステム**: `Worker` の処理フローの特定のタイミング(例: メッセージ送信前、ツール使用後)にカスタムロジックを介入させることができます。
- **セッション管理**: 会話履歴やワークスペースの状態を管理し、永続化する機能を提供します。
- **柔軟なプロンプト管理**: 設定ファイルを用いて、ロールやコンテキストに応じたシステムプロンプトを動的に構築します。
## 主な概念 `worker` のプロンプト/パーシャル解決は利用者側に委譲されています。以下の流れで組み込みます。
### `Worker` 1. `ResourceLoader` を実装して、テンプレートやパーシャルが参照する識別子から文字列を返す。
2. `Worker::builder()` にプロバイダー・モデル・ロールと合わせて `resource_loader` を渡し、`Worker` を生成。
このクレートの中心的な構造体です。LLM との対話、ツールの登録と実行、セッション管理など、すべての主要な機能を担当します。 3. セッションを初期化し、`process_task_with_history` などの API でイベントストリームを処理。
### `LlmProvider`
サポートしている LLM プロバイダー(`Gemini`, `Claude`, `OpenAI` など)を表す enum です。
### `Tool` トレイト
`Worker` が利用できるツールを定義するためのインターフェースです。このトレイトを実装することで、任意の機能をツールとして `Worker` に追加できます。
```rust
pub trait Tool: Send + Sync {
fn name(&self) -> &str;
fn description(&self) -> &str;
fn parameters_schema(&self) -> serde_json::Value;
async fn execute(&self, args: serde_json::Value) -> ToolResult<serde_json::Value>;
}
```
### `WorkerHook` トレイト
`Worker` のライフサイクルイベントに介入するためのフックを定義するインターフェースです。特定のイベント(例: `OnMessageSend`, `PostToolUse`)に対して処理を追加できます。
### `StreamEvent`
`Worker` の処理結果を非同期ストリームで受け取るための enum です。LLM の応答チャンク、ツール呼び出し、エラーなど、さまざまなイベントを表します。
## アプリケーションへの組み込み方法
### 1. Worker の初期化
Builder patternを使用してWorkerを作成します。
```rust
use worker::{Worker, LlmProvider, Role};
use std::collections::HashMap;
// ロールを定義(必須)
let role = Role::new(
"assistant",
"AI Assistant",
"You are a helpful AI assistant."
);
// APIキーを準備
let mut api_keys = HashMap::new();
api_keys.insert("openai".to_string(), "your_openai_api_key".to_string());
api_keys.insert("claude".to_string(), "your_claude_api_key".to_string());
// Workerを作成builder pattern
let mut worker = Worker::builder()
.provider(LlmProvider::OpenAI)
.model("gpt-4o")
.api_keys(api_keys)
.role(role)
.build()
.expect("Workerの作成に失敗しました");
// または、個別にAPIキーを設定
let worker = Worker::builder()
.provider(LlmProvider::Claude)
.model("claude-3-sonnet-20240229")
.api_key("claude", "sk-ant-...")
.role(role)
.build()?;
```
### 2. ツールの定義と登録
`Tool` トレイトを実装してカスタムツールを作成し、`Worker` に登録します。
```rust
use worker::{Tool, ToolResult};
use worker::schemars::{self, JsonSchema};
use worker::serde_json::{self, json, Value};
use async_trait::async_trait;
// ツールの引数を定義
#[derive(Debug, serde::Deserialize, JsonSchema)]
struct FileSystemToolArgs {
path: String,
}
// カスタムツールを定義
struct ListFilesTool;
#[async_trait]
impl Tool for ListFilesTool {
fn name(&self) -> &str { "list_files" }
fn description(&self) -> &str { "指定されたパスのファイル一覧を表示します" }
fn parameters_schema(&self) -> Value {
serde_json::to_value(schemars::schema_for!(FileSystemToolArgs)).unwrap()
}
async fn execute(&self, args: Value) -> ToolResult<Value> {
let tool_args: FileSystemToolArgs = serde_json::from_value(args)?;
// ここで実際のファイル一覧取得処理を実装
let files = vec!["file1.txt", "file2.txt"];
Ok(json!({ "files": files }))
}
}
// 作成したツールをWorkerに登録
worker.register_tool(Box::new(ListFilesTool)).unwrap();
```
#### マクロを使ったツール定義(推奨)
`worker-macros` クレートの `#[tool]` マクロを使用すると、ツールの定義がより簡潔になります:
```rust
use worker_macros::tool;
use worker::ToolResult;
use serde::{Deserialize, Serialize};
use schemars::JsonSchema;
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
struct ListFilesArgs {
path: String,
}
#[tool]
async fn list_files(args: ListFilesArgs) -> ToolResult<serde_json::Value> {
// ファイル一覧取得処理
let files = vec!["file1.txt", "file2.txt"];
Ok(serde_json::json!({ "files": files }))
}
// マクロで生成されたツールを登録
worker.register_tool(Box::new(ListFilesTool))?;
```
### 3. 対話処理の実行
`process_task_with_history` メソッドを呼び出して、ユーザーメッセージを処理します。このメソッドはイベントのストリームを返します。
```rust ```rust
use futures_util::StreamExt; use futures_util::StreamExt;
use worker::{LlmProvider, PromptError, ResourceLoader, Role, Worker};
let user_message = "カレントディレクトリのファイルを教えて".to_string(); struct FsLoader;
let mut stream = worker.process_task_with_history(user_message, None).await; impl ResourceLoader for FsLoader {
fn load(&self, id: &str) -> Result<String, PromptError> {
while let Some(event_result) = stream.next().await { std::fs::read_to_string(id)
match event_result { .map_err(|e| PromptError::FileNotFound(format!("{}: {}", id, e)))
Ok(event) => {
// StreamEventに応じた処理
match event {
worker::StreamEvent::Chunk(chunk) => {
print!("{}", chunk);
}
worker::StreamEvent::ToolCall(tool_call) => {
println!("\n[Tool Call: {} with args {}]", tool_call.name, tool_call.arguments);
}
worker::StreamEvent::ToolResult { tool_name, result } => {
println!("\n[Tool Result: {} -> {:?}]", tool_name, result);
}
_ => {}
}
}
Err(e) => {
eprintln!("\n[Error: {}]", e);
break;
}
}
}
```
### 4. (オプション) フックの登録
`WorkerHook` トレイトを実装してカスタムフックを作成し、`Worker` に登録することで、処理フローをカスタマイズできます。
#### 手動実装
```rust
use worker::{WorkerHook, HookContext, HookResult};
use async_trait::async_trait;
struct LoggingHook;
#[async_trait]
impl WorkerHook for LoggingHook {
fn name(&self) -> &str { "logging_hook" }
fn hook_type(&self) -> &str { "OnMessageSend" }
fn matcher(&self) -> &str { "" }
async fn execute(&self, context: HookContext) -> (HookContext, HookResult) {
println!("User message: {}", context.content);
(context, HookResult::Continue)
} }
} }
// フックを登録 #[tokio::main]
worker.register_hook(Box::new(LoggingHook)); async fn main() -> Result<(), Box<dyn std::error::Error>> {
``` let role = Role::new("assistant", "Helper", "You are a helpful assistant.");
#### マクロを使ったフック定義(推奨) let mut worker = Worker::builder()
.provider(LlmProvider::Claude)
`worker-macros` クレートの `#[hook]` マクロを使用すると、フックの定義がより簡潔になります: .model("claude-3-sonnet-20240229")
.resource_loader(FsLoader)
```rust
use worker_macros::hook;
use worker::{HookContext, HookResult};
#[hook(OnMessageSend)]
async fn logging_hook(context: HookContext) -> (HookContext, HookResult) {
println!("User message: {}", context.content);
(context, HookResult::Continue)
}
// マクロで生成されたフックを登録
worker.register_hook(Box::new(LoggingHook));
```
**利用可能なフックタイプ:**
- `OnMessageSend`: ユーザーメッセージ送信前
- `PreToolUse`: ツール実行前
- `PostToolUse`: ツール実行後
- `OnTurnCompleted`: ターン完了時
これで、アプリケーションの要件に応じて `Worker` を中心とした強力な LLM 連携機能を構築できます。
## サンプルコード
完全な動作例は `worker/examples/` ディレクトリを参照してください:
- [`builder_basic.rs`](worker/examples/builder_basic.rs) - Builder patternの基本的な使用方法
- [`plugin_usage.rs`](worker/examples/plugin_usage.rs) - プラグインシステムの使用方法
## プラグインシステム
### プラグインの作成
`ProviderPlugin` トレイトを実装してカスタムプロバイダーを作成できます:
```rust
use worker::plugin::{ProviderPlugin, PluginMetadata};
use async_trait::async_trait;
pub struct MyCustomProvider {
// プロバイダーの状態
}
#[async_trait]
impl ProviderPlugin for MyCustomProvider {
fn metadata(&self) -> PluginMetadata {
PluginMetadata {
id: "my-provider".to_string(),
name: "My Custom Provider".to_string(),
version: "1.0.0".to_string(),
author: "Your Name".to_string(),
description: "カスタムプロバイダーの説明".to_string(),
supported_models: vec!["model-1".to_string()],
requires_api_key: true,
config_schema: None,
}
}
async fn initialize(&mut self, config: HashMap<String, Value>) -> Result<(), WorkerError> {
// プロバイダーの初期化
Ok(())
}
fn create_client(
&self,
model_name: &str,
api_key: Option<&str>,
config: Option<HashMap<String, Value>>,
) -> Result<Box<dyn LlmClientTrait>, WorkerError> {
// LLMクライアントを作成して返す
}
fn as_any(&self) -> &dyn Any {
self
}
}
```
### プラグインの使用
```rust
use worker::{Worker, Role, plugin::PluginRegistry};
use std::sync::{Arc, Mutex};
// プラグインレジストリを作成
let plugin_registry = Arc::new(Mutex::new(PluginRegistry::new()));
// プラグインを作成して登録
let my_plugin = Arc::new(MyCustomProvider::new());
{
let mut registry = plugin_registry.lock().unwrap();
registry.register(my_plugin)?;
}
// ロールを定義
let role = Role::new(
"assistant",
"AI Assistant",
"You are a helpful AI assistant."
);
let worker = Worker::builder()
.plugin("my-provider", plugin_registry.clone())
.model("model-1")
.api_key("__plugin__", "api-key")
.role(role) .role(role)
.build()?; .build()?;
worker.initialize_session()?;
let events = worker
.process_task_with_history("こんにちは!".into(), None)
.await;
futures_util::pin_mut!(events);
while let Some(event) = events.next().await {
println!("{event:?}");
}
Ok(())
}
``` ```
### 動的プラグイン読み込み ### ツールを登録する
`dynamic-loading` フィーチャーを有効にすることで、共有ライブラリからプラグインを動的に読み込むことができます: `#[worker::tool]` マクロで関数を装飾すると、`Tool` 実装を自動生成できます。
```toml
[dependencies]
worker = { path = "../worker", features = ["dynamic-loading"] }
```
```rust ```rust
// ディレクトリからプラグインを読み込み use schemars::JsonSchema;
worker.load_plugins_from_directory(Path::new("./plugins")).await?; use serde::{Deserialize, Serialize};
use worker::ToolResult;
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
struct EchoArgs {
text: String,
}
#[worker::tool(name = "echo")]
async fn echo(args: EchoArgs) -> ToolResult<String> {
Ok(args.text)
}
worker.register_tool(Box::new(EchoTool::new()))?;
``` ```
完全な例は `worker/src/plugin/example_provider.rs``worker/examples/plugin_usage.rs` を参照してください。 ### フックを登録する
## エラーハンドリング `#[worker::hook]` マクロで非同期関数を装飾すると、`WorkerHook` 実装が生成されます。
構造化されたエラー型により、詳細なエラー情報を取得できます。
### WorkerError の種類
```rust ```rust
use worker::WorkerError; use worker::{HookContext, HookResult};
// ツール実行エラー #[worker::hook(hook_type = "OnMessageSend")]
let error = WorkerError::tool_execution("my_tool", "Connection failed"); async fn log(context: HookContext) -> (HookContext, HookResult) {
let error = WorkerError::tool_execution_with_source("my_tool", "Failed", source_error); println!("sending: {}", context.content);
(context, HookResult::Continue)
}
// 設定エラー worker.register_hook(Box::new(LogHook::new()));
let error = WorkerError::config("Invalid configuration");
let error = WorkerError::config_with_context("Parse error", "config.yaml line 10");
let error = WorkerError::config_with_source("Failed to load", io_error);
// LLM APIエラー
let error = WorkerError::llm_api("openai", "Rate limit exceeded");
let error = WorkerError::llm_api_with_details("claude", "Invalid request", Some(400), None);
// モデルエラー
let error = WorkerError::model_not_found("openai", "gpt-5");
// ネットワークエラー
let error = WorkerError::network("Connection timeout");
let error = WorkerError::network_with_source("Request failed", reqwest_error);
``` ```
### エラーのパターンマッチング ## ドキュメント
```rust - API の詳細は `cargo doc --open` で参照できます。
match worker.build() { - プロンプトシステムの概要: `docs/prompt-composer.md`
Ok(worker) => { /* ... */ }, - サンプルコード: `worker/examples/`
Err(WorkerError::ConfigurationError { message, context, .. }) => {
eprintln!("Configuration error: {}", message);
if let Some(ctx) = context {
eprintln!("Context: {}", ctx);
}
},
Err(WorkerError::ModelNotFound { provider, model_name }) => {
eprintln!("Model '{}' not found for provider '{}'", model_name, provider);
},
Err(WorkerError::LlmApiError { provider, message, status_code, .. }) => {
eprintln!("API error from {}: {}", provider, message);
if let Some(code) = status_code {
eprintln!("Status code: {}", code);
}
},
Err(e) => {
eprintln!("Error: {}", e);
}
}
```

33
docs/patch_note/v0.3.0.md Normal file
View File

@ -0,0 +1,33 @@
# Release Notes - v0.3.0
**Release Date**: 2025-??-??
v0.3.0 はプロンプトリソースの解決責務を利用側へ完全に移し、ツール/フック登録の推奨フローを明確化するアップデートです。これにより、ワーカーの動作を環境ごとに柔軟に制御できるようになりました。
## Breaking Changes
- `ConfigParser::resolve_path` を削除し、`#nia/` `#workspace/` 等のプレフィックス解決をライブラリ利用者実装の `ResourceLoader` に委譲しました。
- `WorkerBuilder::build()``resource_loader(...)` が未指定の場合エラーを返すようになりました。ワーカー構築前に必ずローダーを提供してください。
## 新機能 / 仕様変更
- `PromptComposer``ResourceLoader` を必須依存として受け取り、partials や `{{include_file}}` の読み込みをすべてローダー経由で行うようになりました。
- パーシャル読み込み時にフォールバックが失敗した場合、一次/二次エラー内容を含むメッセージを返すよう改善しました。
- README とドキュメントを刷新し、推奨ワークフローResourceLoader 実装 → Worker 構築 → イベント処理)を明示。`#[worker::tool]` / `#[worker::hook]` マクロを用いた登録例を追加しました。
- ユニットテスト `test_prompt_composer_uses_resource_loader` を追加し、注入されたローダーがパーシャルinclude の解決に使われることを保証。
## 不具合修正
- `include_file` ヘルパーがカスタムローダーを利用せずにファイルアクセスしていた問題を修正。
- `ConfigParser` が存在しない `#nia/` プレフィックスを静的に解決しようとしていた挙動を除去し、誤ったパスが静かに通ることを防止。
## 移行ガイド
1. 既存コードで `ConfigParser::resolve_path` を直接利用していた場合、代わりにアプリケーション側で `ResourceLoader` を実装し、その中で必要なプレフィックス処理を行ってください。
2. `Worker::builder()` を使用する箇所で、新たに `.resource_loader(...)` を必ず指定してください。
3. ツール・フック登録は `#[worker::tool]` / `#[worker::hook]` マクロを利用する実装に移行することを推奨します。
## 開発者向けメモ
- README を簡潔化し、RustDocs で確認できる内容の重複を削除しました。
- `worker/examples/` を更新し、`ResourceLoader` の実装例とマクロベースのツール登録を採用しました。

View File

@ -5,14 +5,49 @@
## 基本使用方法 ## 基本使用方法
```rust ```rust
use std::sync::Arc;
use worker::prompt::{PromptComposer, PromptContext, PromptError, ResourceLoader};
struct FsLoader;
impl ResourceLoader for FsLoader {
fn load(&self, identifier: &str) -> Result<String, PromptError> {
std::fs::read_to_string(identifier)
.map_err(|e| PromptError::FileNotFound(format!("{}: {}", identifier, e)))
}
}
// 初期化 // 初期化
let composer = PromptComposer::from_config_file("role.yaml", context)?; let loader = Arc::new(FsLoader);
let mut composer = PromptComposer::from_config_file("role.yaml", context, loader.clone())?;
composer.initialize_session(&messages)?; composer.initialize_session(&messages)?;
// プロンプト構築 // プロンプト構築
let messages = composer.compose(&user_messages)?; let messages = composer.compose(&user_messages)?;
``` ```
## リソースローダー
`PromptComposer` はテンプレート内で参照されるパーシャルや `{{include_file}}` の解決をクレート利用者に委ねています。
`ResourceLoader` トレイトを実装して、任意のストレージや命名規則に基づいて文字列を返してください。
```rust
struct MyLoader;
impl ResourceLoader for MyLoader {
fn load(&self, identifier: &str) -> Result<String, PromptError> {
match identifier.strip_prefix("#workspace/") {
Some(rest) => {
let path = std::env::current_dir()?.join(".nia").join(rest);
std::fs::read_to_string(path).map_err(|e| PromptError::FileNotFound(e.to_string()))
}
None => std::fs::read_to_string(identifier)
.map_err(|e| PromptError::FileNotFound(e.to_string())),
}
}
}
```
## テンプレート構文 ## テンプレート構文
### 変数展開 ### 変数展開

View File

@ -1,6 +1,6 @@
[package] [package]
name = "worker" name = "worker"
version = "0.2.1" version = "0.3.0"
edition = "2024" edition = "2024"
[dependencies] [dependencies]

View File

@ -1,5 +1,14 @@
use worker::{LlmProvider, Worker, Role};
use std::collections::HashMap; use std::collections::HashMap;
use worker::{LlmProvider, PromptError, ResourceLoader, Role, Worker};
struct FsLoader;
impl ResourceLoader for FsLoader {
fn load(&self, identifier: &str) -> Result<String, PromptError> {
std::fs::read_to_string(identifier)
.map_err(|e| PromptError::FileNotFound(format!("{}: {}", identifier, e)))
}
}
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
@ -17,6 +26,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.provider(LlmProvider::Claude) .provider(LlmProvider::Claude)
.model("claude-3-sonnet-20240229") .model("claude-3-sonnet-20240229")
.api_keys(api_keys) .api_keys(api_keys)
.resource_loader(FsLoader)
.role(role) .role(role)
.build()?; .build()?;
@ -35,6 +45,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.provider(LlmProvider::Claude) .provider(LlmProvider::Claude)
.model("claude-3-sonnet-20240229") .model("claude-3-sonnet-20240229")
.api_key("claude", std::env::var("ANTHROPIC_API_KEY")?) .api_key("claude", std::env::var("ANTHROPIC_API_KEY")?)
.resource_loader(FsLoader)
.role(code_reviewer_role) .role(code_reviewer_role)
.build()?; .build()?;

View File

@ -1,6 +1,18 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use worker::{Worker, Role, plugin::{PluginRegistry, ProviderPlugin, example_provider::CustomProviderPlugin}}; use worker::{
PromptError, ResourceLoader, Role, Worker,
plugin::{PluginRegistry, ProviderPlugin, example_provider::CustomProviderPlugin},
};
struct FsLoader;
impl ResourceLoader for FsLoader {
fn load(&self, identifier: &str) -> Result<String, PromptError> {
std::fs::read_to_string(identifier)
.map_err(|e| PromptError::FileNotFound(format!("{}: {}", identifier, e)))
}
}
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
@ -37,7 +49,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let plugins = registry.list(); let plugins = registry.list();
println!("Available plugins:"); println!("Available plugins:");
for plugin in plugins { for plugin in plugins {
println!(" - {} ({}): {}", plugin.name, plugin.id, plugin.description); println!(
" - {} ({}): {}",
plugin.name, plugin.id, plugin.description
);
println!(" Supported models: {:?}", plugin.supported_models); println!(" Supported models: {:?}", plugin.supported_models);
} }
} }
@ -46,13 +61,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let role = Role::new( let role = Role::new(
"assistant", "assistant",
"A helpful AI assistant", "A helpful AI assistant",
"You are a helpful, harmless, and honest AI assistant powered by a custom LLM provider." "You are a helpful, harmless, and honest AI assistant powered by a custom LLM provider.",
); );
let worker = Worker::builder() let worker = Worker::builder()
.plugin("custom-provider", plugin_registry.clone()) .plugin("custom-provider", plugin_registry.clone())
.model("custom-turbo") .model("custom-turbo")
.api_key("__plugin__", "custom-1234567890abcdefghijklmnop") .api_key("__plugin__", "custom-1234567890abcdefghijklmnop")
.resource_loader(FsLoader)
.role(role) .role(role)
.build()?; .build()?;
@ -62,7 +78,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let plugin_list = worker.list_plugins()?; let plugin_list = worker.list_plugins()?;
println!("\nPlugins registered in worker:"); println!("\nPlugins registered in worker:");
for metadata in plugin_list { for metadata in plugin_list {
println!(" - {}: v{} by {}", metadata.name, metadata.version, metadata.author); println!(
" - {}: v{} by {}",
metadata.name, metadata.version, metadata.author
);
} }
// Load plugins from directory (if dynamic loading is enabled) // Load plugins from directory (if dynamic loading is enabled)

View File

@ -1,10 +1,10 @@
use crate::Worker; use crate::Worker;
use crate::prompt::Role; use crate::prompt::{ResourceLoader, Role};
use crate::types::WorkerError; use crate::types::WorkerError;
use worker_types::LlmProvider;
use std::collections::HashMap; use std::collections::HashMap;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use worker_types::LlmProvider;
// Type-state markers // Type-state markers
pub struct NoProvider; pub struct NoProvider;
@ -20,13 +20,23 @@ pub struct WithRole;
/// ///
/// # Example /// # Example
/// ```no_run /// ```no_run
/// use worker::{Worker, LlmProvider, Role}; /// use worker::{Worker, LlmProvider, Role, ResourceLoader, PromptError};
///
/// struct FsLoader;
///
/// impl ResourceLoader for FsLoader {
/// fn load(&self, identifier: &str) -> Result<String, PromptError> {
/// std::fs::read_to_string(identifier)
/// .map_err(|e| PromptError::FileNotFound(format!("{}: {}", identifier, e)))
/// }
/// }
/// ///
/// let role = Role::new("assistant", "AI Assistant", "You are a helpful assistant."); /// let role = Role::new("assistant", "AI Assistant", "You are a helpful assistant.");
/// let worker = Worker::builder() /// let worker = Worker::builder()
/// .provider(LlmProvider::Claude) /// .provider(LlmProvider::Claude)
/// .model("claude-3-sonnet-20240229") /// .model("claude-3-sonnet-20240229")
/// .api_key("claude", "sk-ant-...") /// .api_key("claude", "sk-ant-...")
/// .resource_loader(FsLoader)
/// .role(role) /// .role(role)
/// .build()?; /// .build()?;
/// # Ok::<(), worker::WorkerError>(()) /// # Ok::<(), worker::WorkerError>(())
@ -38,6 +48,7 @@ pub struct WorkerBuilder<P, M, R> {
// Role configuration (required) // Role configuration (required)
role: Option<Role>, role: Option<Role>,
resource_loader: Option<Arc<dyn ResourceLoader>>,
// Plugin configuration // Plugin configuration
plugin_id: Option<String>, plugin_id: Option<String>,
@ -53,6 +64,7 @@ impl Default for WorkerBuilder<NoProvider, NoModel, NoRole> {
model_name: None, model_name: None,
api_keys: HashMap::new(), api_keys: HashMap::new(),
role: None, role: None,
resource_loader: None,
plugin_id: None, plugin_id: None,
plugin_registry: None, plugin_registry: None,
_phantom: PhantomData, _phantom: PhantomData,
@ -76,6 +88,7 @@ impl<M, R> WorkerBuilder<NoProvider, M, R> {
model_name: self.model_name, model_name: self.model_name,
api_keys: self.api_keys, api_keys: self.api_keys,
role: self.role, role: self.role,
resource_loader: self.resource_loader,
plugin_id: self.plugin_id, plugin_id: self.plugin_id,
plugin_registry: self.plugin_registry, plugin_registry: self.plugin_registry,
_phantom: PhantomData, _phantom: PhantomData,
@ -95,6 +108,7 @@ impl<M, R> WorkerBuilder<NoProvider, M, R> {
model_name: self.model_name, model_name: self.model_name,
api_keys: self.api_keys, api_keys: self.api_keys,
role: self.role, role: self.role,
resource_loader: self.resource_loader,
plugin_id: self.plugin_id, plugin_id: self.plugin_id,
plugin_registry: self.plugin_registry, plugin_registry: self.plugin_registry,
_phantom: PhantomData, _phantom: PhantomData,
@ -104,13 +118,17 @@ impl<M, R> WorkerBuilder<NoProvider, M, R> {
// Step 2: Set model // Step 2: Set model
impl<R> WorkerBuilder<WithProvider, NoModel, R> { impl<R> WorkerBuilder<WithProvider, NoModel, R> {
pub fn model(mut self, model_name: impl Into<String>) -> WorkerBuilder<WithProvider, WithModel, R> { pub fn model(
mut self,
model_name: impl Into<String>,
) -> WorkerBuilder<WithProvider, WithModel, R> {
self.model_name = Some(model_name.into()); self.model_name = Some(model_name.into());
WorkerBuilder { WorkerBuilder {
provider: self.provider, provider: self.provider,
model_name: self.model_name, model_name: self.model_name,
api_keys: self.api_keys, api_keys: self.api_keys,
role: self.role, role: self.role,
resource_loader: self.resource_loader,
plugin_id: self.plugin_id, plugin_id: self.plugin_id,
plugin_registry: self.plugin_registry, plugin_registry: self.plugin_registry,
_phantom: PhantomData, _phantom: PhantomData,
@ -127,6 +145,7 @@ impl WorkerBuilder<WithProvider, WithModel, NoRole> {
model_name: self.model_name, model_name: self.model_name,
api_keys: self.api_keys, api_keys: self.api_keys,
role: self.role, role: self.role,
resource_loader: self.resource_loader,
plugin_id: self.plugin_id, plugin_id: self.plugin_id,
plugin_registry: self.plugin_registry, plugin_registry: self.plugin_registry,
_phantom: PhantomData, _phantom: PhantomData,
@ -147,26 +166,44 @@ impl<P, M, R> WorkerBuilder<P, M, R> {
self.api_keys = keys; self.api_keys = keys;
self self
} }
/// Provide a resource loader implementation for partial/include resolution
pub fn resource_loader<L>(mut self, loader: L) -> Self
where
L: ResourceLoader + 'static,
{
self.resource_loader = Some(Arc::new(loader));
self
}
} }
// Build // Build
impl WorkerBuilder<WithProvider, WithModel, WithRole> { impl WorkerBuilder<WithProvider, WithModel, WithRole> {
pub fn build(self) -> Result<Worker, WorkerError> { pub fn build(self) -> Result<Worker, WorkerError> {
use crate::{LlmProviderExt, WorkspaceDetector, PromptComposer, plugin}; use crate::{LlmProviderExt, PromptComposer, WorkspaceDetector, plugin};
let resource_loader = self.resource_loader.clone().ok_or_else(|| {
WorkerError::config(
"Resource loader is required. Call resource_loader(...) before build.",
)
})?;
let role = self.role.unwrap(); let role = self.role.unwrap();
let model_name = self.model_name.unwrap(); let model_name = self.model_name.unwrap();
// Plugin provider // Plugin provider
if let (Some(plugin_id), Some(plugin_registry)) = (self.plugin_id, self.plugin_registry) { if let (Some(plugin_id), Some(plugin_registry)) = (self.plugin_id, self.plugin_registry) {
let api_key_opt = self.api_keys.get("__plugin__").or_else(|| { let api_key_opt = self
self.api_keys.values().next() .api_keys
}); .get("__plugin__")
.or_else(|| self.api_keys.values().next());
let registry = plugin_registry.lock() let registry = plugin_registry.lock().map_err(|e| {
.map_err(|e| WorkerError::config(format!("Failed to lock plugin registry: {}", e)))?; WorkerError::config(format!("Failed to lock plugin registry: {}", e))
})?;
let plugin = registry.get(&plugin_id) let plugin = registry
.get(&plugin_id)
.ok_or_else(|| WorkerError::config(format!("Plugin not found: {}", plugin_id)))?; .ok_or_else(|| WorkerError::config(format!("Plugin not found: {}", plugin_id)))?;
let llm_client = plugin::PluginClient::new( let llm_client = plugin::PluginClient::new(
@ -190,7 +227,8 @@ impl WorkerBuilder<WithProvider, WithModel, WithRole> {
tracing::info!("Creating worker with plugin and role: {}", role.name); tracing::info!("Creating worker with plugin and role: {}", role.name);
let composer = PromptComposer::from_config(role.clone(), prompt_context) let composer =
PromptComposer::from_config(role.clone(), prompt_context, resource_loader.clone())
.map_err(|e| WorkerError::config(e.to_string()))?; .map_err(|e| WorkerError::config(e.to_string()))?;
drop(registry); drop(registry);
@ -198,6 +236,7 @@ impl WorkerBuilder<WithProvider, WithModel, WithRole> {
let mut worker = Worker { let mut worker = Worker {
llm_client: Box::new(llm_client), llm_client: Box::new(llm_client),
composer, composer,
resource_loader: resource_loader.clone(),
tools: Vec::new(), tools: Vec::new(),
api_key, api_key,
provider_str, provider_str,
@ -210,7 +249,8 @@ impl WorkerBuilder<WithProvider, WithModel, WithRole> {
plugin_registry: plugin_registry.clone(), plugin_registry: plugin_registry.clone(),
}; };
worker.initialize_session() worker
.initialize_session()
.map_err(|e| WorkerError::config(e.to_string()))?; .map_err(|e| WorkerError::config(e.to_string()))?;
return Ok(worker); return Ok(worker);
@ -234,12 +274,14 @@ impl WorkerBuilder<WithProvider, WithModel, WithRole> {
tracing::info!("Creating worker with role: {}", role.name); tracing::info!("Creating worker with role: {}", role.name);
let composer = PromptComposer::from_config(role.clone(), prompt_context) let composer =
PromptComposer::from_config(role.clone(), prompt_context, resource_loader.clone())
.map_err(|e| WorkerError::config(e.to_string()))?; .map_err(|e| WorkerError::config(e.to_string()))?;
let mut worker = Worker { let mut worker = Worker {
llm_client: Box::new(llm_client), llm_client: Box::new(llm_client),
composer, composer,
resource_loader,
tools: Vec::new(), tools: Vec::new(),
api_key, api_key,
provider_str: provider_str.to_string(), provider_str: provider_str.to_string(),
@ -252,7 +294,8 @@ impl WorkerBuilder<WithProvider, WithModel, WithRole> {
plugin_registry, plugin_registry,
}; };
worker.initialize_session() worker
.initialize_session()
.map_err(|e| WorkerError::config(e.to_string()))?; .map_err(|e| WorkerError::config(e.to_string()))?;
Ok(worker) Ok(worker)

View File

@ -6,8 +6,8 @@ use crate::llm::{
xai::XAIClient, xai::XAIClient,
}; };
use crate::types::WorkerError; use crate::types::WorkerError;
use worker_types::{LlmProvider, Message, StreamEvent};
use futures_util::Stream; use futures_util::Stream;
use worker_types::{LlmProvider, Message, StreamEvent};
// LlmClient enumを使用してdyn互換性の問題を解決 // LlmClient enumを使用してdyn互換性の問題を解決
pub enum LlmClient { pub enum LlmClient {

View File

@ -53,60 +53,4 @@ impl ConfigParser {
Ok(()) Ok(())
} }
/// パスプレフィックスを解決する
pub fn resolve_path(path_str: &str) -> Result<std::path::PathBuf, PromptError> {
if path_str.starts_with("#nia/") {
// 組み込みリソース
let relative_path = path_str.strip_prefix("#nia/").unwrap();
let project_root = std::env::current_dir()
.map_err(|e| PromptError::WorkspaceDetection(e.to_string()))?;
// 優先順位: ./resources > ./cli/resources > ./nia-core/resources > ./nia-pod/resources
let possible_paths = [
project_root.join("resources").join(relative_path),
project_root
.join("cli")
.join("resources")
.join(relative_path),
project_root
.join("nia-core")
.join("resources")
.join(relative_path),
project_root
.join("nia-pod")
.join("resources")
.join(relative_path),
];
for path in &possible_paths {
if path.exists() {
return Ok(path.clone());
}
}
// 見つからない場合はデフォルトのパスを返す
Ok(project_root
.join("nia-cli")
.join("resources")
.join(relative_path))
} else if path_str.starts_with("#workspace/") {
// ワークスペース固有
let relative_path = path_str.strip_prefix("#workspace/").unwrap();
let project_root = std::env::current_dir()
.map_err(|e| PromptError::WorkspaceDetection(e.to_string()))?;
Ok(project_root.join(".nia").join(relative_path))
} else if path_str.starts_with("#user/") {
// ユーザー設定
let relative_path = path_str.strip_prefix("#user/").unwrap();
let base_dirs = xdg::BaseDirectories::with_prefix("nia");
let config_home = base_dirs.get_config_home().ok_or_else(|| {
PromptError::WorkspaceDetection("Could not determine XDG config home".to_string())
})?;
Ok(config_home.join(relative_path))
} else {
// 相対パスまたは絶対パス
Ok(std::path::PathBuf::from(path_str))
}
}
} }

View File

@ -1,9 +1,9 @@
// Core types and traits for the worker crate // Core types and traits for the worker crate
// This module contains the primary abstractions used throughout the crate // This module contains the primary abstractions used throughout the crate
use worker_types::{DynamicToolDefinition, LlmProvider, Message, StreamEvent};
use crate::types::WorkerError; use crate::types::WorkerError;
use futures_util::Stream; use futures_util::Stream;
use worker_types::{DynamicToolDefinition, LlmProvider, Message, StreamEvent};
/// LlmClient trait - common interface for all LLM clients /// LlmClient trait - common interface for all LLM clients
/// ///

View File

@ -1,6 +1,6 @@
use crate::prompt::{ use crate::prompt::{
PromptComposer, PromptContext, WorkspaceContext, ModelContext, ModelCapabilities, ModelCapabilities, ModelContext, PromptComposer, PromptContext, SessionContext,
SessionContext WorkspaceContext,
}; };
use crate::workspace::WorkspaceDetector; use crate::workspace::WorkspaceDetector;
use async_stream::stream; use async_stream::stream;
@ -15,29 +15,29 @@ use std::fs;
use std::path::PathBuf; use std::path::PathBuf;
use tracing; use tracing;
use uuid; use uuid;
pub use worker_macros::{hook, tool};
pub use worker_types::{ pub use worker_types::{
DynamicToolDefinition, HookContext, HookEvent, HookManager, HookResult, LlmDebug, LlmProvider, DynamicToolDefinition, HookContext, HookEvent, HookManager, HookResult, LlmDebug, LlmProvider,
LlmResponse, Message, ModelInfo, SessionData, StreamEvent, Task, Tool, ToolCall, ToolResult, LlmResponse, Message, ModelInfo, SessionData, StreamEvent, Task, Tool, ToolCall, ToolResult,
WorkerHook, WorkspaceConfig, WorkspaceData, WorkerHook, WorkspaceConfig, WorkspaceData,
}; };
pub use worker_macros::{hook, tool};
pub mod core;
pub mod types;
pub mod client;
pub mod builder; pub mod builder;
pub mod client;
pub mod config; pub mod config;
pub mod core;
pub mod llm; pub mod llm;
pub mod mcp; pub mod mcp;
pub mod plugin; pub mod plugin;
pub mod prompt; pub mod prompt;
pub mod types;
pub mod workspace; pub mod workspace;
pub use core::LlmClientTrait; pub use crate::prompt::{PromptError, ResourceLoader, Role};
pub use client::LlmClient;
pub use crate::prompt::Role;
pub use builder::WorkerBuilder;
pub use crate::types::WorkerError; pub use crate::types::WorkerError;
pub use builder::WorkerBuilder;
pub use client::LlmClient;
pub use core::LlmClientTrait;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@ -172,7 +172,8 @@ impl WorkerError {
|| error_msg.contains("billing") || error_msg.contains("billing")
|| error_msg.contains("upgrade") || error_msg.contains("upgrade")
|| error_msg.contains("purchase credits") || error_msg.contains("purchase credits")
|| (error_msg.contains("invalid_request_error") && error_msg.contains("credit balance")); || (error_msg.contains("invalid_request_error")
&& error_msg.contains("credit balance"));
let has_provider_patterns = match provider { let has_provider_patterns = match provider {
LlmProvider::OpenAI => { LlmProvider::OpenAI => {
@ -319,9 +320,7 @@ pub async fn validate_api_key(
Ok(Some(false)) Ok(Some(false))
} }
} }
LlmProvider::Ollama => { LlmProvider::Ollama => Ok(Some(true)),
Ok(Some(true))
}
LlmProvider::XAI => { LlmProvider::XAI => {
if api_key.starts_with("xai-") && api_key.len() > 20 { if api_key.starts_with("xai-") && api_key.len() > 20 {
tracing::debug!("validate_api_key: xAI API key format appears valid"); tracing::debug!("validate_api_key: xAI API key format appears valid");
@ -357,9 +356,8 @@ pub struct ModelMeta {
} }
fn get_models_config_path() -> Result<PathBuf, WorkerError> { fn get_models_config_path() -> Result<PathBuf, WorkerError> {
let home_dir = dirs::home_dir().ok_or_else(|| { let home_dir = dirs::home_dir()
WorkerError::config("Could not determine home directory") .ok_or_else(|| WorkerError::config("Could not determine home directory"))?;
})?;
Ok(home_dir.join(".config").join("nia").join("models.yaml")) Ok(home_dir.join(".config").join("nia").join("models.yaml"))
} }
@ -374,13 +372,11 @@ fn load_models_config() -> Result<ModelsConfig, WorkerError> {
return Ok(ModelsConfig { models: vec![] }); return Ok(ModelsConfig { models: vec![] });
} }
let content = fs::read_to_string(&config_path).map_err(|e| { let content = fs::read_to_string(&config_path)
WorkerError::config(format!("Failed to read models config: {}", e)) .map_err(|e| WorkerError::config(format!("Failed to read models config: {}", e)))?;
})?;
let config: ModelsConfig = serde_yaml::from_str(&content).map_err(|e| { let config: ModelsConfig = serde_yaml::from_str(&content)
WorkerError::config(format!("Failed to parse models config: {}", e)) .map_err(|e| WorkerError::config(format!("Failed to parse models config: {}", e)))?;
})?;
Ok(config) Ok(config)
} }
@ -447,6 +443,7 @@ pub async fn supports_native_tools(
pub struct Worker { pub struct Worker {
pub(crate) llm_client: Box<dyn LlmClientTrait>, pub(crate) llm_client: Box<dyn LlmClientTrait>,
pub(crate) composer: PromptComposer, pub(crate) composer: PromptComposer,
pub(crate) resource_loader: std::sync::Arc<dyn ResourceLoader>,
pub(crate) tools: Vec<Box<dyn Tool>>, pub(crate) tools: Vec<Box<dyn Tool>>,
pub(crate) api_key: String, pub(crate) api_key: String,
pub(crate) provider_str: String, pub(crate) provider_str: String,
@ -464,27 +461,42 @@ impl Worker {
/// ///
/// # Example /// # Example
/// ```no_run /// ```no_run
/// use worker::{Worker, LlmProvider, Role}; /// use worker::{Worker, LlmProvider, Role, PromptError, ResourceLoader};
///
/// struct FsLoader;
///
/// impl ResourceLoader for FsLoader {
/// fn load(&self, identifier: &str) -> Result<String, PromptError> {
/// std::fs::read_to_string(identifier)
/// .map_err(|e| PromptError::FileNotFound(format!("{}: {}", identifier, e)))
/// }
/// }
/// ///
/// let role = Role::new("assistant", "AI Assistant", "You are a helpful assistant."); /// let role = Role::new("assistant", "AI Assistant", "You are a helpful assistant.");
/// let worker = Worker::builder() /// let worker = Worker::builder()
/// .provider(LlmProvider::Claude) /// .provider(LlmProvider::Claude)
/// .model("claude-3-sonnet-20240229") /// .model("claude-3-sonnet-20240229")
/// .api_key("claude", "sk-ant-...") /// .api_key("claude", "sk-ant-...")
/// .resource_loader(FsLoader)
/// .role(role) /// .role(role)
/// .build()?; /// .build()?;
/// # Ok::<(), worker::WorkerError>(()) /// # Ok::<(), worker::WorkerError>(())
/// ``` /// ```
pub fn builder() -> builder::WorkerBuilder<builder::NoProvider, builder::NoModel, builder::NoRole> { pub fn builder()
-> builder::WorkerBuilder<builder::NoProvider, builder::NoModel, builder::NoRole> {
builder::WorkerBuilder::new() builder::WorkerBuilder::new()
} }
/// Load plugins from a directory /// Load plugins from a directory
#[cfg(feature = "dynamic-loading")] #[cfg(feature = "dynamic-loading")]
pub async fn load_plugins_from_directory(&mut self, dir: &std::path::Path) -> Result<(), WorkerError> { pub async fn load_plugins_from_directory(
&mut self,
dir: &std::path::Path,
) -> Result<(), WorkerError> {
let plugins = plugin::PluginLoader::load_from_directory(dir)?; let plugins = plugin::PluginLoader::load_from_directory(dir)?;
let mut registry = self.plugin_registry.lock() let mut registry = self
.plugin_registry
.lock()
.map_err(|e| WorkerError::config(format!("Failed to lock plugin registry: {}", e)))?; .map_err(|e| WorkerError::config(format!("Failed to lock plugin registry: {}", e)))?;
for plugin in plugins { for plugin in plugins {
@ -496,7 +508,9 @@ impl Worker {
/// List all registered plugins /// List all registered plugins
pub fn list_plugins(&self) -> Result<Vec<plugin::PluginMetadata>, WorkerError> { pub fn list_plugins(&self) -> Result<Vec<plugin::PluginMetadata>, WorkerError> {
let registry = self.plugin_registry.lock() let registry = self
.plugin_registry
.lock()
.map_err(|e| WorkerError::config(format!("Failed to lock plugin registry: {}", e)))?; .map_err(|e| WorkerError::config(format!("Failed to lock plugin registry: {}", e)))?;
Ok(registry.list()) Ok(registry.list())
} }
@ -515,10 +529,7 @@ impl Worker {
} }
/// MCPサーバーをツールとして登録する /// MCPサーバーをツールとして登録する
pub fn register_mcp_server( pub fn register_mcp_server(&mut self, config: McpServerConfig) -> Result<(), WorkerError> {
&mut self,
config: McpServerConfig,
) -> Result<(), WorkerError> {
let mcp_tool = McpDynamicTool::new(config.clone()); let mcp_tool = McpDynamicTool::new(config.clone());
self.register_tool(Box::new(mcp_tool))?; self.register_tool(Box::new(mcp_tool))?;
tracing::info!("Registered MCP server as tool: {}", config.name); tracing::info!("Registered MCP server as tool: {}", config.name);
@ -526,10 +537,7 @@ impl Worker {
} }
/// MCPサーバーから個別のツールを登録する /// MCPサーバーから個別のツールを登録する
pub async fn register_mcp_tools( pub async fn register_mcp_tools(&mut self, config: McpServerConfig) -> Result<(), WorkerError> {
&mut self,
config: McpServerConfig,
) -> Result<(), WorkerError> {
let tools = create_single_mcp_tools(&config).await?; let tools = create_single_mcp_tools(&config).await?;
let tool_count = tools.len(); let tool_count = tools.len();
@ -791,7 +799,8 @@ impl Worker {
let prompt_context = self.create_prompt_context()?; let prompt_context = self.create_prompt_context()?;
// DynamicPromptComposerを作成 // DynamicPromptComposerを作成
let composer = PromptComposer::from_config(role.clone(), prompt_context) let composer =
PromptComposer::from_config(role.clone(), prompt_context, self.resource_loader.clone())
.map_err(|e| WorkerError::config(e.to_string()))?; .map_err(|e| WorkerError::config(e.to_string()))?;
self.role = role; self.role = role;
@ -927,7 +936,7 @@ impl Worker {
if self.tools.iter().any(|t| t.name() == tool.name()) { if self.tools.iter().any(|t| t.name() == tool.name()) {
return Err(WorkerError::tool_execution( return Err(WorkerError::tool_execution(
tool.name(), tool.name(),
format!("Tool '{}' is already registered", tool.name()) format!("Tool '{}' is already registered", tool.name()),
)); ));
} }
@ -961,7 +970,7 @@ impl Worker {
Some(tool) => tool.execute(args).await.map_err(WorkerError::from), Some(tool) => tool.execute(args).await.map_err(WorkerError::from),
None => Err(WorkerError::tool_execution( None => Err(WorkerError::tool_execution(
tool_name, tool_name,
format!("Tool '{}' not found", tool_name) format!("Tool '{}' not found", tool_name),
)), )),
} }
} }
@ -1716,9 +1725,7 @@ impl Worker {
/// セッションデータを取得する /// セッションデータを取得する
pub fn get_session_data(&self) -> Result<SessionData, WorkerError> { pub fn get_session_data(&self) -> Result<SessionData, WorkerError> {
let workspace_path = std::env::current_dir() let workspace_path = std::env::current_dir()
.map_err(|e| { .map_err(|e| WorkerError::config(format!("Failed to get current directory: {}", e)))?
WorkerError::config(format!("Failed to get current directory: {}", e))
})?
.to_string_lossy() .to_string_lossy()
.to_string(); .to_string();
@ -1811,7 +1818,6 @@ impl Worker {
} }
} }
/// セッション初期化Worker内部用 /// セッション初期化Worker内部用
fn initialize_session(&mut self) -> Result<(), crate::prompt::PromptError> { fn initialize_session(&mut self) -> Result<(), crate::prompt::PromptError> {
// 空のメッセージでセッション初期化 // 空のメッセージでセッション初期化
@ -1819,9 +1825,7 @@ impl Worker {
} }
/// 履歴付きセッション再初期化Worker内部用 /// 履歴付きセッション再初期化Worker内部用
fn reinitialize_session_with_history( fn reinitialize_session_with_history(&mut self) -> Result<(), crate::prompt::PromptError> {
&mut self,
) -> Result<(), crate::prompt::PromptError> {
// 現在の履歴を使ってセッション初期化 // 現在の履歴を使ってセッション初期化
self.composer.initialize_session(&self.message_history) self.composer.initialize_session(&self.message_history)
} }

View File

@ -1,12 +1,12 @@
use crate::config::UrlConfig;
use crate::core::LlmClientTrait; use crate::core::LlmClientTrait;
use crate::types::WorkerError; use crate::types::WorkerError;
use worker_types::{LlmProvider, Message, Role, StreamEvent, ToolCall};
use crate::config::UrlConfig;
use async_stream::stream; use async_stream::stream;
use futures_util::{Stream, StreamExt}; use futures_util::{Stream, StreamExt};
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use worker_types::{LlmProvider, Message, Role, StreamEvent, ToolCall};
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
struct AnthropicRequest { struct AnthropicRequest {

View File

@ -1,11 +1,11 @@
use crate::config::UrlConfig;
use crate::core::LlmClientTrait; use crate::core::LlmClientTrait;
use crate::types::WorkerError; use crate::types::WorkerError;
use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall};
use crate::config::UrlConfig;
use futures_util::{Stream, StreamExt, TryStreamExt}; use futures_util::{Stream, StreamExt, TryStreamExt};
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing; use tracing;
use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall};
/// Extract tool name from Tool message content /// Extract tool name from Tool message content
fn extract_tool_name_from_content(content: &str) -> Option<String> { fn extract_tool_name_from_content(content: &str) -> Option<String> {

View File

@ -1,11 +1,11 @@
use crate::config::UrlConfig;
use crate::core::LlmClientTrait; use crate::core::LlmClientTrait;
use crate::types::WorkerError; use crate::types::WorkerError;
use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall};
use crate::config::UrlConfig;
use futures_util::{Stream, StreamExt}; use futures_util::{Stream, StreamExt};
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall};
// --- Request & Response Structures --- // --- Request & Response Structures ---
#[derive(Debug, Serialize, Clone)] #[derive(Debug, Serialize, Clone)]
@ -670,7 +670,9 @@ impl OllamaClient {
self.add_auth_header(client.get(&url)) self.add_auth_header(client.get(&url))
.send() .send()
.await .await
.map_err(|e| WorkerError::llm_api("ollama", format!("Failed to connect to Ollama: {}", e)))?; .map_err(|e| {
WorkerError::llm_api("ollama", format!("Failed to connect to Ollama: {}", e))
})?;
Ok(()) Ok(())
} }
} }

View File

@ -1,11 +1,11 @@
use crate::config::UrlConfig;
use crate::core::LlmClientTrait; use crate::core::LlmClientTrait;
use crate::types::WorkerError; use crate::types::WorkerError;
use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall};
use crate::config::UrlConfig;
use futures_util::{Stream, StreamExt}; use futures_util::{Stream, StreamExt};
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall};
// --- Request & Response Structures --- // --- Request & Response Structures ---
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]

View File

@ -1,11 +1,11 @@
use crate::config::UrlConfig;
use crate::core::LlmClientTrait; use crate::core::LlmClientTrait;
use crate::types::WorkerError; use crate::types::WorkerError;
use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall};
use crate::config::UrlConfig;
use futures_util::{Stream, StreamExt}; use futures_util::{Stream, StreamExt};
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall};
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub(crate) struct XAIRequest { pub(crate) struct XAIRequest {

View File

@ -1,5 +1,5 @@
use crate::types::WorkerError;
use super::tool::McpServerConfig; use super::tool::McpServerConfig;
use crate::types::WorkerError;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::path::Path; use std::path::Path;
@ -71,17 +71,11 @@ impl McpConfig {
info!("Loading MCP config from: {:?}", path); info!("Loading MCP config from: {:?}", path);
let content = std::fs::read_to_string(path).map_err(|e| { let content = std::fs::read_to_string(path).map_err(|e| {
WorkerError::config(format!( WorkerError::config(format!("Failed to read MCP config file {:?}: {}", path, e))
"Failed to read MCP config file {:?}: {}",
path, e
))
})?; })?;
let config: McpConfig = serde_yaml::from_str(&content).map_err(|e| { let config: McpConfig = serde_yaml::from_str(&content).map_err(|e| {
WorkerError::config(format!( WorkerError::config(format!("Failed to parse MCP config file {:?}: {}", path, e))
"Failed to parse MCP config file {:?}: {}",
path, e
))
})?; })?;
info!("Loaded {} MCP server configurations", config.servers.len()); info!("Loaded {} MCP server configurations", config.servers.len());
@ -102,15 +96,11 @@ impl McpConfig {
})?; })?;
} }
let content = serde_yaml::to_string(self).map_err(|e| { let content = serde_yaml::to_string(self)
WorkerError::config(format!("Failed to serialize MCP config: {}", e)) .map_err(|e| WorkerError::config(format!("Failed to serialize MCP config: {}", e)))?;
})?;
std::fs::write(path, content).map_err(|e| { std::fs::write(path, content).map_err(|e| {
WorkerError::config(format!( WorkerError::config(format!("Failed to write MCP config file {:?}: {}", path, e))
"Failed to write MCP config file {:?}: {}",
path, e
))
})?; })?;
info!("Saved MCP config to: {:?}", path); info!("Saved MCP config to: {:?}", path);

View File

@ -5,6 +5,6 @@ mod tool;
pub use config::{IntegrationMode, McpConfig, McpServerDefinition}; pub use config::{IntegrationMode, McpConfig, McpServerDefinition};
pub use protocol::McpClient; pub use protocol::McpClient;
pub use tool::{ pub use tool::{
create_single_mcp_tools, get_mcp_tools_as_definitions, test_mcp_connection, McpDynamicTool, McpDynamicTool, McpServerConfig, SingleMcpTool, create_single_mcp_tools,
McpServerConfig, SingleMcpTool, get_mcp_tools_as_definitions, test_mcp_connection,
}; };

View File

@ -118,7 +118,10 @@ impl McpDynamicTool {
let tools = client.list_tools().await.map_err(|e| { let tools = client.list_tools().await.map_err(|e| {
crate::WorkerError::tool_execution_with_source( crate::WorkerError::tool_execution_with_source(
&self.config.name, &self.config.name,
format!("Failed to list tools from MCP server '{}'", self.config.name), format!(
"Failed to list tools from MCP server '{}'",
self.config.name
),
e, e,
) )
})?; })?;
@ -326,7 +329,10 @@ impl Tool for McpDynamicTool {
} }
None => Err(Box::new(crate::WorkerError::tool_execution( None => Err(Box::new(crate::WorkerError::tool_execution(
tool_name, tool_name,
format!("Tool '{}' not found in MCP server '{}'", tool_name, self.config.name), format!(
"Tool '{}' not found in MCP server '{}'",
tool_name, self.config.name
),
)) ))
as Box<dyn std::error::Error + Send + Sync>), as Box<dyn std::error::Error + Send + Sync>),
} }

View File

@ -4,8 +4,8 @@ use serde_json::Value;
use std::any::Any; use std::any::Any;
use std::collections::HashMap; use std::collections::HashMap;
use crate::plugin::{PluginMetadata, ProviderPlugin};
use crate::core::LlmClientTrait; use crate::core::LlmClientTrait;
use crate::plugin::{PluginMetadata, ProviderPlugin};
use crate::types::WorkerError; use crate::types::WorkerError;
use worker_types::{DynamicToolDefinition, LlmDebug, LlmProvider, Message, Role, StreamEvent}; use worker_types::{DynamicToolDefinition, LlmDebug, LlmProvider, Message, Role, StreamEvent};
@ -78,8 +78,8 @@ impl ProviderPlugin for CustomProviderPlugin {
return Err(WorkerError::config("Plugin not initialized")); return Err(WorkerError::config("Plugin not initialized"));
} }
let api_key = api_key let api_key =
.ok_or_else(|| WorkerError::config("API key required for custom provider"))?; api_key.ok_or_else(|| WorkerError::config("API key required for custom provider"))?;
let client = CustomLlmClient::new( let client = CustomLlmClient::new(
api_key.to_string(), api_key.to_string(),
@ -138,7 +138,10 @@ impl LlmClientTrait for CustomLlmClient {
messages: Vec<Message>, messages: Vec<Message>,
_tools: Option<&[DynamicToolDefinition]>, _tools: Option<&[DynamicToolDefinition]>,
_llm_debug: Option<LlmDebug>, _llm_debug: Option<LlmDebug>,
) -> Result<Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>, WorkerError> { ) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
use async_stream::stream; use async_stream::stream;
// Example implementation that echoes the last user message // Example implementation that echoes the last user message
@ -233,7 +236,10 @@ mod tests {
async fn test_plugin_initialization() { async fn test_plugin_initialization() {
let mut plugin = CustomProviderPlugin::new(); let mut plugin = CustomProviderPlugin::new();
let mut config = HashMap::new(); let mut config = HashMap::new();
config.insert("base_url".to_string(), Value::String("https://api.example.com".to_string())); config.insert(
"base_url".to_string(),
Value::String("https://api.example.com".to_string()),
);
let result = plugin.initialize(config).await; let result = plugin.initialize(config).await;
assert!(result.is_ok()); assert!(result.is_ok());

View File

@ -109,9 +109,15 @@ impl PluginRegistry {
/// Find plugin by model name /// Find plugin by model name
pub fn find_by_model(&self, model_name: &str) -> Option<Arc<dyn ProviderPlugin>> { pub fn find_by_model(&self, model_name: &str) -> Option<Arc<dyn ProviderPlugin>> {
self.plugins.values().find(|p| { self.plugins
p.metadata().supported_models.iter().any(|m| m == model_name) .values()
}).cloned() .find(|p| {
p.metadata()
.supported_models
.iter()
.any(|m| m == model_name)
})
.cloned()
} }
/// Unregister a plugin /// Unregister a plugin
@ -133,9 +139,10 @@ impl PluginLoader {
let lib = Library::new(path) let lib = Library::new(path)
.map_err(|e| WorkerError::config(format!("Failed to load plugin: {}", e)))?; .map_err(|e| WorkerError::config(format!("Failed to load plugin: {}", e)))?;
let create_plugin: Symbol<fn() -> Box<dyn ProviderPlugin>> = lib let create_plugin: Symbol<fn() -> Box<dyn ProviderPlugin>> =
.get(b"create_plugin") lib.get(b"create_plugin").map_err(|e| {
.map_err(|e| WorkerError::config(format!("Plugin missing create_plugin function: {}", e)))?; WorkerError::config(format!("Plugin missing create_plugin function: {}", e))
})?;
Ok(create_plugin()) Ok(create_plugin())
} }
@ -143,26 +150,34 @@ impl PluginLoader {
/// Load all plugins from a directory /// Load all plugins from a directory
#[cfg(feature = "dynamic-loading")] #[cfg(feature = "dynamic-loading")]
pub fn load_from_directory(dir: &std::path::Path) -> Result<Vec<Box<dyn ProviderPlugin>>, WorkerError> { pub fn load_from_directory(
dir: &std::path::Path,
) -> Result<Vec<Box<dyn ProviderPlugin>>, WorkerError> {
use std::fs; use std::fs;
let mut plugins = Vec::new(); let mut plugins = Vec::new();
if !dir.is_dir() { if !dir.is_dir() {
return Err(WorkerError::config(format!("Plugin directory does not exist: {:?}", dir))); return Err(WorkerError::config(format!(
"Plugin directory does not exist: {:?}",
dir
)));
} }
for entry in fs::read_dir(dir) for entry in fs::read_dir(dir)
.map_err(|e| WorkerError::Config(format!("Failed to read plugin directory: {}", e)))? .map_err(|e| WorkerError::Config(format!("Failed to read plugin directory: {}", e)))?
{ {
let entry = entry let entry = entry.map_err(|e| {
.map_err(|e| WorkerError::config(format!("Failed to read directory entry: {}", e)))?; WorkerError::config(format!("Failed to read directory entry: {}", e))
})?;
let path = entry.path(); let path = entry.path();
// Check for plugin files (.so on Linux, .dll on Windows, .dylib on macOS) // Check for plugin files (.so on Linux, .dll on Windows, .dylib on macOS)
if path.extension().and_then(|s| s.to_str()).map_or(false, |ext| { if path
ext == "so" || ext == "dll" || ext == "dylib" .extension()
}) { .and_then(|s| s.to_str())
.map_or(false, |ext| ext == "so" || ext == "dll" || ext == "dylib")
{
match Self::load_dynamic(&path) { match Self::load_dynamic(&path) {
Ok(plugin) => plugins.push(plugin), Ok(plugin) => plugins.push(plugin),
Err(e) => { Err(e) => {
@ -201,7 +216,10 @@ impl LlmClientTrait for PluginClient {
messages: Vec<Message>, messages: Vec<Message>,
tools: Option<&[DynamicToolDefinition]>, tools: Option<&[DynamicToolDefinition]>,
llm_debug: Option<LlmDebug>, llm_debug: Option<LlmDebug>,
) -> Result<Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>, WorkerError> { ) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
self.inner.chat_stream(messages, tools, llm_debug).await self.inner.chat_stream(messages, tools, llm_debug).await
} }

View File

@ -1,8 +1,7 @@
use crate::config::ConfigParser;
use super::types::*; use super::types::*;
use handlebars::{Context, Handlebars, Helper, HelperResult, Output, RenderContext}; use handlebars::{Context, Handlebars, Helper, HelperDef, HelperResult, Output, RenderContext};
use std::fs;
use std::path::Path; use std::path::Path;
use std::sync::Arc;
// Import Message and Role enum from worker_types // Import Message and Role enum from worker_types
use worker_types::{Message, Role as MessageRole}; use worker_types::{Message, Role as MessageRole};
@ -14,6 +13,7 @@ pub struct PromptComposer {
handlebars: Handlebars<'static>, handlebars: Handlebars<'static>,
context: PromptContext, context: PromptContext,
system_prompt: Option<String>, system_prompt: Option<String>,
resource_loader: Arc<dyn ResourceLoader>,
} }
impl PromptComposer { impl PromptComposer {
@ -21,26 +21,29 @@ impl PromptComposer {
pub fn from_config_file<P: AsRef<Path>>( pub fn from_config_file<P: AsRef<Path>>(
config_path: P, config_path: P,
context: PromptContext, context: PromptContext,
resource_loader: Arc<dyn ResourceLoader>,
) -> Result<Self, PromptError> { ) -> Result<Self, PromptError> {
let config = ConfigParser::parse_from_file(config_path)?; let config = crate::config::ConfigParser::parse_from_file(config_path)?;
Self::from_config(config, context) Self::from_config(config, context, resource_loader)
} }
/// 設定オブジェクトから新しいインスタンスを作成 /// 設定オブジェクトから新しいインスタンスを作成
pub fn from_config( pub fn from_config(
config: Role, config: Role,
context: PromptContext, context: PromptContext,
resource_loader: Arc<dyn ResourceLoader>,
) -> Result<Self, PromptError> { ) -> Result<Self, PromptError> {
let mut handlebars = Handlebars::new(); let mut handlebars = Handlebars::new();
// カスタムヘルパー関数を登録 // カスタムヘルパー関数を登録
Self::register_custom_helpers(&mut handlebars)?; Self::register_custom_helpers(&mut handlebars, resource_loader.clone())?;
let mut composer = Self { let mut composer = Self {
config, config,
handlebars, handlebars,
context, context,
system_prompt: None, system_prompt: None,
resource_loader,
}; };
// パーシャルテンプレートを読み込み・登録 // パーシャルテンプレートを読み込み・登録
@ -60,7 +63,8 @@ impl PromptComposer {
pub fn compose(&self, messages: &[Message]) -> Result<Vec<Message>, PromptError> { pub fn compose(&self, messages: &[Message]) -> Result<Vec<Message>, PromptError> {
if let Some(system_prompt) = &self.system_prompt { if let Some(system_prompt) = &self.system_prompt {
// システムプロンプトが既に構築済みの場合、それを使用 // システムプロンプトが既に構築済みの場合、それを使用
let mut result_messages = vec![Message::new(MessageRole::System, system_prompt.clone())]; let mut result_messages =
vec![Message::new(MessageRole::System, system_prompt.clone())];
// ユーザーメッセージを追加 // ユーザーメッセージを追加
for msg in messages { for msg in messages {
@ -102,7 +106,8 @@ impl PromptComposer {
) -> Result<Vec<Message>, PromptError> { ) -> Result<Vec<Message>, PromptError> {
if let Some(system_prompt) = &self.system_prompt { if let Some(system_prompt) = &self.system_prompt {
// システムプロンプトが既に構築済みの場合、それを使用 // システムプロンプトが既に構築済みの場合、それを使用
let mut result_messages = vec![Message::new(MessageRole::System, system_prompt.clone())]; let mut result_messages =
vec![Message::new(MessageRole::System, system_prompt.clone())];
// ユーザーメッセージを追加 // ユーザーメッセージを追加
for msg in messages { for msg in messages {
@ -171,9 +176,16 @@ impl PromptComposer {
} }
/// カスタムヘルパー関数を登録 /// カスタムヘルパー関数を登録
fn register_custom_helpers(handlebars: &mut Handlebars) -> Result<(), PromptError> { fn register_custom_helpers(
// 基本的なヘルパーのみ実装(複雑なライフタイム問題を回避) handlebars: &mut Handlebars<'static>,
handlebars.register_helper("include_file", Box::new(include_file_helper)); resource_loader: Arc<dyn ResourceLoader>,
) -> Result<(), PromptError> {
handlebars.register_helper(
"include_file",
Box::new(IncludeFileHelper {
loader: resource_loader.clone(),
}),
);
handlebars.register_helper("workspace_content", Box::new(workspace_content_helper)); handlebars.register_helper("workspace_content", Box::new(workspace_content_helper));
Ok(()) Ok(())
@ -194,27 +206,22 @@ impl PromptComposer {
/// パーシャルの内容を読み込み(フォールバック対応) /// パーシャルの内容を読み込み(フォールバック対応)
fn load_partial_content(&self, partial_config: &PartialConfig) -> Result<String, PromptError> { fn load_partial_content(&self, partial_config: &PartialConfig) -> Result<String, PromptError> {
let primary_path = ConfigParser::resolve_path(&partial_config.path)?; match self.resource_loader.load(&partial_config.path) {
Ok(content) => Ok(content),
// メインパスを試行 Err(primary_err) => {
if let Ok(content) = fs::read_to_string(&primary_path) {
return Ok(content);
}
// フォールバックパスを試行
if let Some(fallback) = &partial_config.fallback { if let Some(fallback) = &partial_config.fallback {
let fallback_path = ConfigParser::resolve_path(fallback)?; match self.resource_loader.load(fallback) {
if let Ok(content) = fs::read_to_string(&fallback_path) { Ok(content) => Ok(content),
return Ok(content); Err(fallback_err) => Err(PromptError::PartialLoading(format!(
"Could not load partial '{}' (fallback: {:?}): primary error={}, fallback error={}",
partial_config.path, partial_config.fallback, primary_err, fallback_err
))),
}
} else {
Err(primary_err)
}
} }
} }
Err(PromptError::FileNotFound(format!(
"Could not load partial '{}' from {} (fallback: {:?})",
partial_config.path,
primary_path.display(),
partial_config.fallback
)))
} }
/// コンテキストを指定してテンプレート用のデータを準備 /// コンテキストを指定してテンプレート用のデータを準備
@ -289,18 +296,22 @@ impl PromptComposer {
// カスタムヘルパー関数の実装 // カスタムヘルパー関数の実装
fn include_file_helper( struct IncludeFileHelper {
h: &Helper, loader: Arc<dyn ResourceLoader>,
_hbs: &Handlebars, }
_ctx: &Context,
_rc: &mut RenderContext, impl HelperDef for IncludeFileHelper {
fn call<'reg: 'rc, 'rc>(
&self,
h: &Helper<'rc>,
_handlebars: &Handlebars<'reg>,
_context: &Context,
_rc: &mut RenderContext<'reg, 'rc>,
out: &mut dyn Output, out: &mut dyn Output,
) -> HelperResult { ) -> HelperResult {
let file_path = h.param(0).and_then(|v| v.value().as_str()).unwrap_or(""); let file_path = h.param(0).and_then(|v| v.value().as_str()).unwrap_or("");
match ConfigParser::resolve_path(file_path) { match self.loader.load(file_path) {
Ok(path) => {
match fs::read_to_string(&path) {
Ok(content) => { Ok(content) => {
out.write(&content)?; out.write(&content)?;
} }
@ -309,13 +320,9 @@ fn include_file_helper(
out.write("")?; out.write("")?;
} }
} }
}
Err(_) => {
out.write("")?;
}
}
Ok(()) Ok(())
} }
}
fn workspace_content_helper( fn workspace_content_helper(
_h: &Helper, _h: &Helper,

View File

@ -3,6 +3,6 @@ mod types;
pub use composer::PromptComposer; pub use composer::PromptComposer;
pub use types::{ pub use types::{
ConditionConfig, GitInfo, ModelCapabilities, ModelContext, PartialConfig, PromptContext, ConditionConfig, GitInfo, ModelCapabilities, ModelContext, PartialConfig, ProjectType,
PromptError, ProjectType, Role, SessionContext, SystemInfo, WorkspaceContext, PromptContext, PromptError, ResourceLoader, Role, SessionContext, SystemInfo, WorkspaceContext,
}; };

View File

@ -21,6 +21,11 @@ pub struct PartialConfig {
pub description: Option<String>, pub description: Option<String>,
} }
/// External resource loader used to resolve template includes/partials
pub trait ResourceLoader: Send + Sync {
fn load(&self, identifier: &str) -> Result<String, PromptError>;
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConditionConfig { pub struct ConditionConfig {
pub when: String, pub when: String,
@ -354,7 +359,11 @@ impl Default for SessionContext {
impl Role { impl Role {
/// Create a new Role with name, description, and template /// Create a new Role with name, description, and template
pub fn new(name: impl Into<String>, description: impl Into<String>, template: impl Into<String>) -> Self { pub fn new(
name: impl Into<String>,
description: impl Into<String>,
template: impl Into<String>,
) -> Self {
Self { Self {
name: name.into(), name: name.into(),
description: description.into(), description: description.into(),

View File

@ -1,5 +1,13 @@
use crate::config::ConfigParser; use crate::config::ConfigParser;
use crate::prompt::{
ModelCapabilities, ModelContext, PromptComposer, PromptContext, PromptError, ResourceLoader,
SessionContext, SystemInfo, WorkspaceContext,
};
use crate::types::LlmProvider;
use std::collections::HashMap;
use std::io::Write; use std::io::Write;
use std::path::PathBuf;
use std::sync::Arc;
use tempfile::NamedTempFile; use tempfile::NamedTempFile;
#[test] #[test]
@ -166,30 +174,104 @@ template: "File content {{user_input}}"
assert_eq!(config.template, "File content {{user_input}}"); assert_eq!(config.template, "File content {{user_input}}");
} }
struct InMemoryLoader {
data: HashMap<String, String>,
}
impl InMemoryLoader {
fn new(data: HashMap<String, String>) -> Self {
Self { data }
}
}
impl ResourceLoader for InMemoryLoader {
fn load(&self, identifier: &str) -> Result<String, PromptError> {
self.data
.get(identifier)
.cloned()
.ok_or_else(|| PromptError::FileNotFound(format!("not found: {}", identifier)))
}
}
fn build_prompt_context() -> PromptContext {
let workspace = WorkspaceContext {
root_path: PathBuf::from("."),
nia_md_content: None,
project_type: None,
git_info: None,
has_nia_md: false,
project_name: None,
system_info: SystemInfo::collect(),
};
let capabilities = ModelCapabilities {
supports_tools: false,
supports_function_calling: false,
supports_vision: false,
supports_multimodal: None,
context_length: None,
capabilities: vec![],
needs_verification: None,
};
let model_context = ModelContext {
provider: LlmProvider::Claude,
model_name: "test-model".to_string(),
capabilities,
supports_native_tools: false,
};
let session_context = SessionContext {
conversation_id: None,
message_count: 0,
active_tools: vec![],
user_preferences: None,
};
PromptContext {
workspace,
model: model_context,
session: session_context,
variables: HashMap::new(),
}
}
#[test] #[test]
fn test_resolve_path() { fn test_prompt_composer_uses_resource_loader() {
// #nia/ prefix let yaml_content = r##"
let path = name: "Loader Test"
ConfigParser::resolve_path("#nia/prompts/test.md").expect("Failed to resolve nia path"); description: "Ensure resource loader is used"
assert!( template: |
path.to_string_lossy() {{>header}}
.contains("nia-cli/resources/prompts/test.md") {{include_file "include.md"}}
);
// #workspace/ prefix partials:
let path = ConfigParser::resolve_path("#workspace/config.md") header:
.expect("Failed to resolve workspace path"); path: "missing.md"
assert!(path.to_string_lossy().contains(".nia/config.md")); fallback: "fallback.md"
"##;
// #user/ prefix let role =
let path = ConfigParser::parse_from_string(yaml_content).expect("Failed to parse loader test config");
ConfigParser::resolve_path("#user/settings.md").expect("Failed to resolve user path");
assert!(path.to_string_lossy().contains("settings.md"));
// Regular path let loader = Arc::new(InMemoryLoader::new(HashMap::from([
let path = ("fallback.md".to_string(), "Fallback Partial".to_string()),
ConfigParser::resolve_path("regular/path.md").expect("Failed to resolve regular path"); ("include.md".to_string(), "Included Content".to_string()),
assert_eq!(path.to_string_lossy(), "regular/path.md"); ])));
let prompt_context = build_prompt_context();
let composer = PromptComposer::from_config(role, prompt_context, loader)
.expect("Composer should use provided loader");
let messages = composer
.compose(&[])
.expect("Composer should build system prompt");
assert!(!messages.is_empty());
let system_message = &messages[0];
assert!(system_message.content.contains("Fallback Partial"));
assert!(system_message.content.contains("Included Content"));
} }
#[test] #[test]

View File

@ -25,7 +25,10 @@ pub enum WorkerError {
/// Model not found for the specified provider /// Model not found for the specified provider
#[error("Model not found: {model_name} for provider {provider}")] #[error("Model not found: {model_name} for provider {provider}")]
ModelNotFound { provider: String, model_name: String }, ModelNotFound {
provider: String,
model_name: String,
},
/// JSON serialization/deserialization error /// JSON serialization/deserialization error
#[error("JSON error: {0}")] #[error("JSON error: {0}")]
@ -138,10 +141,7 @@ impl WorkerError {
} }
/// Create a configuration error with context /// Create a configuration error with context
pub fn config_with_context( pub fn config_with_context(message: impl Into<String>, context: impl Into<String>) -> Self {
message: impl Into<String>,
context: impl Into<String>,
) -> Self {
Self::ConfigurationError { Self::ConfigurationError {
message: message.into(), message: message.into(),
context: Some(context.into()), context: Some(context.into()),

View File

@ -1,4 +1,4 @@
use crate::prompt::{WorkspaceContext, PromptError, ProjectType, GitInfo}; use crate::prompt::{GitInfo, ProjectType, PromptError, WorkspaceContext};
use std::fs; use std::fs;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::process::Command; use std::process::Command;