add: 0.1.0 code from other project

Keisuke Hirata 2025-08-30 02:58:52 +09:00
commit 9b608f1a54
29 changed files with 12440 additions and 0 deletions

1
.envrc Normal file

@@ -0,0 +1 @@
use flake

2354
Cargo.lock generated Normal file

File diff suppressed because it is too large.

7
Cargo.toml Normal file

@@ -0,0 +1,7 @@
[workspace]
resolver = "2"
members = [
"worker",
"worker-types",
"worker-macros",
]

77
flake.lock Normal file

@@ -0,0 +1,77 @@
{
"nodes": {
"flake-compat": {
"locked": {
"lastModified": 1747046372,
"narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
"owner": "edolstra",
"repo": "flake-compat",
"rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
"type": "github"
},
"original": {
"owner": "edolstra",
"repo": "flake-compat",
"type": "github"
}
},
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1751011381,
"narHash": "sha256-krGXKxvkBhnrSC/kGBmg5MyupUUT5R6IBCLEzx9jhMM=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "30e2e2857ba47844aa71991daa6ed1fc678bcbb7",
"type": "github"
},
"original": {
"owner": "nixos",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"flake-compat": "flake-compat",
"flake-utils": "flake-utils",
"nixpkgs": "nixpkgs"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

31
flake.nix Normal file

@@ -0,0 +1,31 @@
{
inputs = {
nixpkgs.url = "github:nixos/nixpkgs?ref=nixos-unstable";
flake-utils.url = "github:numtide/flake-utils";
flake-compat.url = "github:edolstra/flake-compat";
};
outputs =
{ nixpkgs, flake-utils, ... }:
flake-utils.lib.eachDefaultSystem (
system:
let
pkgs = import nixpkgs { inherit system; };
in
{
devShells.default = pkgs.mkShell {
packages = with pkgs; [
nixfmt
deno
git
rustc
cargo
];
buildInputs = with pkgs; [
pkg-config
openssl
];
};
}
);
}

18
worker-macros/Cargo.toml Normal file

@@ -0,0 +1,18 @@
[package]
name = "worker-macros"
version = "0.1.0"
edition = "2024"
[lib]
proc-macro = true
[dependencies]
syn = { version = "2.0", features = ["full"] }
quote = "1.0"
proc-macro2 = "1.0"
worker-types = { path = "../worker-types" }
[dev-dependencies]
tokio = { version = "1.0", features = ["full"] }
schemars = "1.0.3"

327
worker-macros/src/lib.rs Normal file

@@ -0,0 +1,327 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::{
Attribute, ItemFn, LitStr,
parse::{Parse, ParseStream},
parse_macro_input,
};
struct ToolAttributeArgs {
name: Option<String>,
}
impl Parse for ToolAttributeArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut name = None;
if !input.is_empty() {
let name_ident: syn::Ident = input.parse()?;
if name_ident != "name" {
return Err(syn::Error::new_spanned(
name_ident,
"Only 'name' attribute is supported",
));
}
input.parse::<syn::Token![=]>()?;
let name_str: LitStr = input.parse()?;
name = Some(name_str.value());
}
Ok(ToolAttributeArgs { name })
}
}
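/// Attribute macro that turns an async function into a `Tool` implementation.
///
/// A minimal usage sketch (names are hypothetical; the argument type is
/// assumed to derive `serde::Deserialize` and `schemars::JsonSchema`, and the
/// success value is rendered via `Display`):
///
/// ```ignore
/// /// Reads a file and returns its contents.
/// #[tool(name = "ReadFile")]
/// async fn read_file(args: ReadFileArgs) -> worker_types::ToolResult<String> {
///     Ok(std::fs::read_to_string(&args.path).unwrap_or_default())
/// }
/// // Expands to a `ReadFileTool` struct implementing `worker_types::Tool`.
/// ```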
#[proc_macro_attribute]
pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as ToolAttributeArgs);
let func = parse_macro_input!(item as ItemFn);
let description = {
let doc_comments = extract_doc_comments(&func.attrs);
if doc_comments.is_empty() {
format!("Tool function: {}", func.sig.ident)
} else {
doc_comments
}
};
// Validate function signature
if let Err(e) = validate_function_signature(&func) {
return e.to_compile_error().into();
}
let fn_name = &func.sig.ident;
let fn_name_str = fn_name.to_string();
// Use provided name or fall back to CamelCase function name
let tool_name_str = args.name.unwrap_or_else(|| to_camel_case(&fn_name_str));
// Extract arg_type and output_type safely after validation
let arg_type = if let syn::FnArg::Typed(pat_type) = &func.sig.inputs[0] {
&pat_type.ty
} else {
// This case should be caught by validate_function_signature
return syn::Error::new_spanned(&func.sig.inputs[0], "Expected typed argument")
.to_compile_error()
.into();
};
if !matches!(&func.sig.output, syn::ReturnType::Type(_, _)) {
// This case should be caught by validate_function_signature
return syn::Error::new_spanned(&func.sig.output, "Expected return type")
.to_compile_error()
.into();
}
// Generate struct name from function name (e.g., read_file -> ReadFileTool)
let tool_struct_name = {
let fn_name_str = fn_name.to_string();
let camel_case = to_camel_case(&fn_name_str);
syn::Ident::new(&format!("{}Tool", camel_case), fn_name.span())
};
let expanded = quote! {
// Keep the original function
#func
// Generate Tool struct
pub struct #tool_struct_name;
impl #tool_struct_name {
pub fn new() -> Self {
Self
}
}
// Implement Tool trait
#[::worker_types::async_trait::async_trait]
impl ::worker_types::Tool for #tool_struct_name {
fn name(&self) -> &str {
#tool_name_str
}
fn description(&self) -> &str {
#description
}
fn parameters_schema(&self) -> ::worker_types::serde_json::Value {
::worker_types::serde_json::to_value(::worker_types::schemars::schema_for!(#arg_type)).unwrap()
}
async fn execute(&self, args: ::worker_types::serde_json::Value) -> ::worker_types::ToolResult<::worker_types::serde_json::Value> {
let typed_args: #arg_type = ::worker_types::serde_json::from_value(args)?;
let result = #fn_name(typed_args).await?;
// Use Display formatting instead of JSON serialization
let formatted_result = format!("{}", result);
Ok(::worker_types::serde_json::Value::String(formatted_result))
}
}
};
TokenStream::from(expanded)
}
fn validate_function_signature(func: &ItemFn) -> syn::Result<()> {
if func.sig.asyncness.is_none() {
return Err(syn::Error::new_spanned(
&func.sig,
"Tool function must be async",
));
}
if func.sig.inputs.len() != 1 {
return Err(syn::Error::new_spanned(
&func.sig.inputs,
"Tool function must have exactly one argument",
));
}
let arg = &func.sig.inputs[0];
if !matches!(arg, syn::FnArg::Typed(_)) {
return Err(syn::Error::new_spanned(
arg,
"Argument must be a typed pattern (e.g., `args: MyArgs`)",
));
}
if let syn::ReturnType::Default = func.sig.output {
return Err(syn::Error::new_spanned(
&func.sig,
"Tool function must have a return type, typically Result<T, E>",
));
}
Ok(())
}
fn extract_doc_comments(attrs: &[Attribute]) -> String {
let mut doc_lines = Vec::new();
for attr in attrs {
if attr.path().is_ident("doc") {
if let syn::Meta::NameValue(meta) = &attr.meta {
if let syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(lit_str),
..
}) = &meta.value
{
let content = lit_str.value();
let trimmed = content.trim_start();
doc_lines.push(trimmed.to_string());
}
}
}
}
if doc_lines.is_empty() {
return String::new();
}
doc_lines.join("\n").trim().to_string()
}
fn to_camel_case(snake_case: &str) -> String {
snake_case
.split('_')
.map(|word| {
let mut chars = word.chars();
match chars.next() {
None => String::new(),
Some(first) => first.to_uppercase().collect::<String>() + chars.as_str(),
}
})
.collect()
}
// Hook attribute arguments parser
struct HookAttributeArgs {
hook_type: String,
matcher: Option<String>,
}
impl Parse for HookAttributeArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut hook_type = None;
let mut matcher = None;
while !input.is_empty() {
let name: syn::Ident = input.parse()?;
input.parse::<syn::Token![=]>()?;
let value: LitStr = input.parse()?;
match name.to_string().as_str() {
"hook_type" => hook_type = Some(value.value()),
"matcher" => matcher = Some(value.value()),
_ => return Err(syn::Error::new_spanned(name, "Unknown hook attribute")),
}
if input.peek(syn::Token![,]) {
input.parse::<syn::Token![,]>()?;
}
}
let hook_type = hook_type.ok_or_else(|| input.error("Hook type is required"))?;
Ok(HookAttributeArgs { hook_type, matcher })
}
}
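/// Attribute macro that turns an async function into a `WorkerHook` implementation.
///
/// A minimal usage sketch (`hook_type` is required and `matcher` is optional;
/// `HookContext` and `HookResult` live in `worker-types`, and the
/// `HookResult::default()` below is an assumption for illustration):
///
/// ```ignore
/// #[hook(hook_type = "PostToolUse", matcher = "read_file")]
/// async fn audit(context: HookContext) -> (HookContext, HookResult) {
///     // Inspect or modify the context, then pass it along.
///     (context, HookResult::default())
/// }
/// // Expands to an `AuditHook` struct implementing `worker_types::WorkerHook`.
/// ```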
#[proc_macro_attribute]
pub fn hook(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as HookAttributeArgs);
let func = parse_macro_input!(item as ItemFn);
// Validate function signature for hooks
if let Err(e) = validate_hook_function_signature(&func) {
return e.to_compile_error().into();
}
let fn_name = &func.sig.ident;
let fn_name_str = fn_name.to_string();
let hook_type = &args.hook_type;
let matcher = args.matcher.as_deref().unwrap_or("");
// Generate struct name from function name
let hook_struct_name = {
let fn_name_str = fn_name.to_string();
let camel_case = to_camel_case(&fn_name_str);
// If the name already ends with "Hook", strip it before re-appending "Hook"
let cleaned_name = if camel_case.ends_with("Hook") {
camel_case.strip_suffix("Hook").unwrap_or(&camel_case)
} else {
&camel_case
};
syn::Ident::new(&format!("{}Hook", cleaned_name), fn_name.span())
};
let expanded = quote! {
// Keep the original function
#func
// Generate Hook struct
pub struct #hook_struct_name;
impl #hook_struct_name {
pub fn new() -> Self {
Self
}
}
// Implement WorkerHook trait
#[::worker_types::async_trait::async_trait]
impl ::worker_types::WorkerHook for #hook_struct_name {
fn name(&self) -> &str {
#fn_name_str
}
fn hook_type(&self) -> &str {
#hook_type
}
fn matcher(&self) -> &str {
#matcher
}
async fn execute(&self, context: ::worker_types::HookContext) -> (::worker_types::HookContext, ::worker_types::HookResult) {
#fn_name(context).await
}
}
};
TokenStream::from(expanded)
}
fn validate_hook_function_signature(func: &ItemFn) -> syn::Result<()> {
if func.sig.asyncness.is_none() {
return Err(syn::Error::new_spanned(
&func.sig,
"Hook function must be async",
));
}
if func.sig.inputs.len() != 1 {
return Err(syn::Error::new_spanned(
&func.sig.inputs,
"Hook function must have exactly one argument of type HookContext",
));
}
let arg = &func.sig.inputs[0];
if !matches!(arg, syn::FnArg::Typed(_)) {
return Err(syn::Error::new_spanned(
arg,
"Argument must be a typed pattern (e.g., `context: HookContext`)",
));
}
if let syn::ReturnType::Default = func.sig.output {
return Err(syn::Error::new_spanned(
&func.sig,
"Hook function must return (HookContext, HookResult)",
));
}
Ok(())
}

15
worker-types/Cargo.toml Normal file

@@ -0,0 +1,15 @@
[package]
name = "worker-types"
version = "0.1.0"
edition = "2024"
[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
schemars = "1.0.3"
async-trait = "0.1.88"
thiserror = "2.0.12"
anyhow = "1.0"
chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1.10", features = ["v4", "serde"] }
tracing = "0.1.40"

1120
worker-types/src/lib.rs Normal file

File diff suppressed because it is too large.

42
worker/Cargo.toml Normal file

@@ -0,0 +1,42 @@
[package]
name = "worker"
version = "0.1.0"
edition = "2024"
[dependencies]
worker-types = { path = "../worker-types" }
worker-macros = { path = "../worker-macros" }
schemars = "1.0.3"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1", features = ["full"] }
anyhow = "1.0"
reqwest = { version = "0.11", default-features = false, features = [
"json",
"rustls-tls",
"stream",
] }
toml = "0.8"
thiserror = "2.0.12"
futures-util = "0.3"
async-stream = "0.3"
bytes = "1"
async-trait = "0.1.88"
serde_yaml = "0.9.33"
log = "0.4"
dirs = "6.0.0"
strum = { version = "0.27.1", features = ["derive"] }
strum_macros = "0.27.1"
tracing = "0.1.40"
eventsource-stream = "0.2.3"
xdg = "3.0.0"
chrono = { version = "0.4", features = ["serde"] }
handlebars = "5.1.2"
regex = "1.10.2"
uuid = { version = "1.10", features = ["v4", "serde"] }
tokio-util = { version = "0.7", features = ["codec"] }
futures = "0.3"
[dev-dependencies]
tempfile = "3.10.1"
tracing-subscriber = "0.3"

150
worker/README.md Normal file

@@ -0,0 +1,150 @@
# The `worker` crate
The `worker` crate is the core component that provides backend functionality for applications built on large language models (LLMs). It encapsulates advanced features such as LLM provider abstraction, tool use, flexible prompt management, and a hook system, simplifying application development.
## Key features
- **Multi-provider support**: Use multiple LLM providers, including Gemini, Claude, OpenAI, Ollama, and XAI, through a unified interface.
- **Tool use (function calling)**: Lets the LLM call external tools. You can easily define your own tools and register them with the `Worker`.
- **Streaming**: LLM responses and tool execution results are delivered asynchronously as `StreamEvent`s, enabling real-time UI updates.
- **Hook system**: Inject custom logic at specific points in the `Worker` processing flow (e.g., before a message is sent, after a tool is used).
- **Session management**: Manages and persists conversation history and workspace state.
- **Flexible prompt management**: Builds system prompts dynamically from configuration files according to role and context.
## Core concepts
### `Worker`
The central struct of this crate. It handles all major functionality: conversing with the LLM, registering and executing tools, session management, and more.
### `LlmProvider`
An enum representing the supported LLM providers, such as `Gemini`, `Claude`, and `OpenAI`.
### The `Tool` trait
The interface for defining tools that the `Worker` can use. By implementing this trait, you can add arbitrary functionality to the `Worker` as a tool.
```rust
#[async_trait]
pub trait Tool: Send + Sync {
    fn name(&self) -> &str;
    fn description(&self) -> &str;
    fn parameters_schema(&self) -> serde_json::Value;
    async fn execute(&self, args: serde_json::Value) -> ToolResult<serde_json::Value>;
}
```
### The `WorkerHook` trait
The interface for defining hooks that intervene in `Worker` lifecycle events. You can attach extra processing to specific events (e.g., `OnMessageSend`, `PostToolUse`).
### `StreamEvent`
An enum for receiving `Worker` results as an asynchronous stream. It represents events such as LLM response chunks, tool calls, and errors.
## Integrating into an application
### 1. Initializing the Worker
First, create a `Worker` instance. This requires an `LlmProvider`, a model name, and API keys.
```rust
use worker::{Worker, LlmProvider};
use std::collections::HashMap;
// Prepare the API keys
let mut api_keys = HashMap::new();
api_keys.insert("openai".to_string(), "your_openai_api_key".to_string());
api_keys.insert("claude".to_string(), "your_claude_api_key".to_string());
// Create the Worker
let mut worker = Worker::new(
LlmProvider::OpenAI,
"gpt-4o",
&api_keys,
None // RoleConfig is optional
).expect("failed to create Worker");
```
### 2. Defining and registering tools
Implement the `Tool` trait to create a custom tool and register it with the `Worker`.
```rust
use worker::{Tool, ToolResult};
use worker::schemars::{self, JsonSchema};
use worker::serde_json::{self, json, Value};
use async_trait::async_trait;
// Define the tool's arguments
#[derive(Debug, serde::Deserialize, JsonSchema)]
struct FileSystemToolArgs {
path: String,
}
// Define a custom tool
struct ListFilesTool;
#[async_trait]
impl Tool for ListFilesTool {
fn name(&self) -> &str { "list_files" }
fn description(&self) -> &str { "Lists the files at the given path" }
fn parameters_schema(&self) -> Value {
serde_json::to_value(schemars::schema_for!(FileSystemToolArgs)).unwrap()
}
async fn execute(&self, args: Value) -> ToolResult<Value> {
let tool_args: FileSystemToolArgs = serde_json::from_value(args)?;
// Implement the actual file-listing logic here
let files = vec!["file1.txt", "file2.txt"];
Ok(json!({ "files": files }))
}
}
// Register the tool with the Worker
worker.register_tool(Box::new(ListFilesTool)).unwrap();
```
### 3. Running a conversation
Call the `process_task_with_history` method to process a user message. This method returns a stream of events.
```rust
use futures_util::StreamExt;
let user_message = "List the files in the current directory".to_string();
let mut stream = worker.process_task_with_history(user_message, None).await;
while let Some(event_result) = stream.next().await {
match event_result {
Ok(event) => {
// Handle each StreamEvent variant
match event {
worker::StreamEvent::Chunk(chunk) => {
print!("{}", chunk);
}
worker::StreamEvent::ToolCall(tool_call) => {
println!("\n[Tool Call: {} with args {}]", tool_call.name, tool_call.arguments);
}
worker::StreamEvent::ToolResult { tool_name, result } => {
println!("\n[Tool Result: {} -> {:?}]", tool_name, result);
}
_ => {}
}
}
Err(e) => {
eprintln!("\n[Error: {}]", e);
break;
}
}
}
```
### 4. (Optional) Registering hooks
Implement the `WorkerHook` trait to create custom hooks and register them with the `Worker` to customize the processing flow; a minimal sketch follows the snippet below.
```rust
// (WorkerHook implementation omitted)
// let my_hook = MyCustomHook::new();
// worker.register_hook(Box::new(my_hook));
```
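As a minimal sketch, a hook can also be generated with the `#[hook]` attribute from `worker-macros`. The attribute arguments and the `(HookContext, HookResult)` signature mirror the `worker-macros` code in this commit; the `log_tool_use` name is hypothetical, and `HookResult::default()` is an assumption (the `worker-types` diff is suppressed above):

```rust
use worker_macros::hook;
use worker_types::{HookContext, HookResult};

/// Logs every tool invocation after it completes.
#[hook(hook_type = "PostToolUse")]
async fn log_tool_use(context: HookContext) -> (HookContext, HookResult) {
    tracing::info!("a tool finished executing");
    // Pass the context through unchanged; `HookResult: Default` is assumed here.
    (context, HookResult::default())
}

// The macro expands to a `LogToolUseHook` struct implementing `WorkerHook`:
// worker.register_hook(Box::new(LogToolUseHook::new()));
```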
With this, you can build powerful, `Worker`-centric LLM integration tailored to your application's requirements.

110
worker/src/config_parser.rs Normal file

@@ -0,0 +1,110 @@
use crate::prompt_types::*;
use std::fs;
use std::path::Path;
/// Parser for configuration files
pub struct ConfigParser;
impl ConfigParser {
/// Reads and parses a YAML configuration file
pub fn parse_from_file<P: AsRef<Path>>(path: P) -> Result<PromptRoleConfig, PromptError> {
let content = fs::read_to_string(path.as_ref()).map_err(|e| {
PromptError::FileNotFound(format!("{}: {}", path.as_ref().display(), e))
})?;
Self::parse_from_string(&content)
}
/// Parses a YAML string into a PromptRoleConfig
pub fn parse_from_string(content: &str) -> Result<PromptRoleConfig, PromptError> {
let config: PromptRoleConfig = serde_yaml::from_str(content)?;
// Basic validation
Self::validate_config(&config)?;
Ok(config)
}
/// Basic validation of the configuration
fn validate_config(config: &PromptRoleConfig) -> Result<(), PromptError> {
if config.name.is_empty() {
return Err(PromptError::VariableResolution(
"name field cannot be empty".to_string(),
));
}
if config.template.is_empty() {
return Err(PromptError::TemplateCompilation(
"template field cannot be empty".to_string(),
));
}
// Validate partial paths
if let Some(partials) = &config.partials {
for (name, partial) in partials {
if partial.path.is_empty() {
return Err(PromptError::PartialLoading(format!(
"partial '{}' has empty path",
name
)));
}
}
}
Ok(())
}
/// Resolves a path prefix to a concrete filesystem path.
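/// Supported prefixes (summarizing the branches below):
/// - `#nia/...`: built-in resources, searched under `./resources`, then
///   `./nia-cli/resources`, then `../nia-cli/resources`
/// - `#workspace/...`: workspace-specific files under `./.nia`
/// - `#user/...`: user configuration under the XDG config home for `nia`
/// Anything else is treated as a plain relative or absolute path.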
pub fn resolve_path(path_str: &str) -> Result<std::path::PathBuf, PromptError> {
if path_str.starts_with("#nia/") {
// Built-in resources
let relative_path = path_str.strip_prefix("#nia/").unwrap();
let project_root = std::env::current_dir()
.map_err(|e| PromptError::WorkspaceDetection(e.to_string()))?;
// Priority order: ./resources > ./nia-cli/resources > ../nia-cli/resources
let possible_paths = [
project_root.join("resources").join(relative_path),
project_root
.join("nia-cli")
.join("resources")
.join(relative_path),
project_root
.parent()
.unwrap_or(&project_root)
.join("nia-cli")
.join("resources")
.join(relative_path),
];
for path in &possible_paths {
if path.exists() {
return Ok(path.clone());
}
}
// Fall back to the default path if none of the candidates exist
Ok(project_root
.join("nia-cli")
.join("resources")
.join(relative_path))
} else if path_str.starts_with("#workspace/") {
// Workspace-specific
let relative_path = path_str.strip_prefix("#workspace/").unwrap();
let project_root = std::env::current_dir()
.map_err(|e| PromptError::WorkspaceDetection(e.to_string()))?;
Ok(project_root.join(".nia").join(relative_path))
} else if path_str.starts_with("#user/") {
// User configuration
let relative_path = path_str.strip_prefix("#user/").unwrap();
let base_dirs = xdg::BaseDirectories::with_prefix("nia");
let config_home = base_dirs.get_config_home().ok_or_else(|| {
PromptError::WorkspaceDetection("Could not determine XDG config home".to_string())
})?;
Ok(config_home.join(relative_path))
} else {
// Relative or absolute path
Ok(std::path::PathBuf::from(path_str))
}
}
}

2090
worker/src/lib.rs Normal file

File diff suppressed because it is too large.

393
worker/src/llm/anthropic.rs Normal file

@@ -0,0 +1,393 @@
use crate::{
LlmClientTrait, WorkerError,
types::{LlmProvider, Message, Role, StreamEvent, ToolCall},
url_config::UrlConfig,
};
use async_stream::stream;
use futures_util::{Stream, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Debug, Serialize)]
struct AnthropicRequest {
model: String,
max_tokens: i32,
messages: Vec<AnthropicMessage>,
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<AnthropicTool>>,
#[serde(skip_serializing_if = "Option::is_none")]
system: Option<String>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct AnthropicMessage {
role: String,
content: String,
}
#[derive(Debug, Serialize, Clone)]
struct AnthropicTool {
name: String,
description: String,
input_schema: Value,
}
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
#[serde(rename = "type")]
response_type: String,
content: Vec<AnthropicContent>,
}
#[derive(Debug, Deserialize, Serialize)]
struct AnthropicStreamResponse {
#[serde(rename = "type")]
response_type: String,
#[serde(flatten)]
data: Value,
}
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum AnthropicContent {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: Value,
},
}
pub struct AnthropicClient {
api_key: String,
model: String,
}
impl AnthropicClient {
pub fn new(api_key: &str, model: &str) -> Self {
Self {
api_key: api_key.to_string(),
model: model.to_string(),
}
}
pub fn get_model_name(&self) -> String {
self.model.clone()
}
}
impl AnthropicClient {
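/// Streams a chat completion from the Anthropic Messages API.
///
/// System messages are folded into the request's `system` field, other roles
/// are mapped onto Anthropic's `user`/`assistant` messages, and the SSE
/// response is parsed line by line into `StreamEvent`s.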
pub async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[crate::types::DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
let client = Client::new();
let url = UrlConfig::get_completion_url("anthropic");
// Separate system messages from other messages
let mut system_message: Option<String> = None;
let mut anthropic_messages: Vec<AnthropicMessage> = Vec::new();
for msg in messages {
match msg.role {
Role::System => {
// Combine multiple system messages if they exist
if let Some(existing) = system_message {
system_message = Some(format!("{}\n{}", existing, msg.content));
} else {
system_message = Some(msg.content);
}
}
Role::User => {
anthropic_messages.push(AnthropicMessage {
role: "user".to_string(),
content: msg.content,
});
}
Role::Model => {
anthropic_messages.push(AnthropicMessage {
role: "assistant".to_string(),
content: msg.content,
});
}
Role::Tool => {
anthropic_messages.push(AnthropicMessage {
role: "user".to_string(),
content: msg.content,
});
}
}
}
// Convert tools to Anthropic format
let anthropic_tools = tools.map(|tools| {
tools
.iter()
.map(|tool| AnthropicTool {
name: tool.name.clone(),
description: tool.description.clone(),
input_schema: tool.parameters_schema.clone(),
})
.collect()
});
let request = AnthropicRequest {
model: self.model.clone(),
max_tokens: 4096,
messages: anthropic_messages,
stream: true,
tools: anthropic_tools,
system: system_message,
};
// Log request details for debugging
tracing::debug!(
"Anthropic API request: {}",
serde_json::to_string_pretty(&request).unwrap_or_default()
);
let response = client
.post(url)
.header("Content-Type", "application/json")
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.json(&request)
.send()
.await
.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Claude)
})?;
if !response.status().is_success() {
let status = response.status();
let error_body = response.text().await.unwrap_or_default();
return Err(WorkerError::from_api_error(
format!("Anthropic API error: {} - {}", status, error_body),
&crate::types::LlmProvider::Claude,
));
}
let stream = stream! {
// Emit debug info
if let Some(ref debug) = llm_debug {
if let Some(debug_event) = debug.debug_request(&self.model, "Anthropic", &serde_json::to_value(&request).unwrap_or_default()) {
yield Ok(debug_event);
}
}
let mut stream = response.bytes_stream();
let mut buffer = String::new();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
let chunk_str = String::from_utf8_lossy(&bytes);
buffer.push_str(&chunk_str);
// Process server-sent events
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].to_string();
buffer = buffer[line_end + 1..].to_string();
if line.starts_with("data: ") {
let data = &line[6..];
if data == "[DONE]" {
break;
}
match serde_json::from_str::<AnthropicStreamResponse>(data) {
Ok(stream_response) => {
// Emit debug info
if let Some(ref debug) = llm_debug {
if let Some(debug_event) = debug.debug_response(&self.model, "Anthropic", &serde_json::to_value(&stream_response).unwrap_or_default()) {
yield Ok(debug_event);
}
}
match stream_response.response_type.as_str() {
"content_block_delta" => {
if let Some(delta) = stream_response.data.get("delta") {
if let Some(text) = delta.get("text").and_then(|t| t.as_str()) {
yield Ok(StreamEvent::Chunk(text.to_string()));
}
}
}
"content_block_start" => {
if let Some(content_block) = stream_response.data.get("content_block") {
if let Some(block_type) = content_block.get("type").and_then(|t| t.as_str()) {
if block_type == "tool_use" {
if let (Some(name), Some(input)) = (
content_block.get("name").and_then(|n| n.as_str()),
content_block.get("input")
) {
let tool_call = ToolCall {
name: name.to_string(),
arguments: input.to_string(),
};
yield Ok(StreamEvent::ToolCall(tool_call));
}
}
}
}
}
"message_start" => {
tracing::debug!("Anthropic message stream started");
}
"message_delta" => {
if let Some(delta) = stream_response.data.get("delta") {
if let Some(stop_reason) = delta.get("stop_reason") {
tracing::debug!("Anthropic message stop reason: {}", stop_reason);
}
}
}
"message_stop" => {
tracing::debug!("Anthropic message stream stopped");
yield Ok(StreamEvent::Completion(Message::new(
Role::Model,
"".to_string(),
)));
break;
}
"content_block_stop" => {
tracing::debug!("Anthropic content block stopped");
}
"ping" => {
tracing::debug!("Anthropic ping received");
}
"error" => {
if let Some(error) = stream_response.data.get("error") {
let error_msg = error.get("message")
.and_then(|m| m.as_str())
.unwrap_or("Unknown error");
tracing::error!("Anthropic stream error: {}", error_msg);
yield Err(WorkerError::from_api_error(
format!("Anthropic stream error: {}", error_msg),
&crate::types::LlmProvider::Claude,
));
}
}
_ => {
tracing::debug!("Unhandled Anthropic stream event: {}", stream_response.response_type);
}
}
}
Err(e) => {
tracing::warn!("Failed to parse Anthropic stream response: {} - Raw data: {}", e, data);
}
}
}
}
}
Err(e) => {
yield Err(WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Claude));
break;
}
}
}
};
Ok(Box::new(Box::pin(stream)))
}
pub async fn check_connection(&self) -> Result<(), WorkerError> {
let client = Client::new();
let url = UrlConfig::get_completion_url("anthropic");
// Use a default valid model for connection testing if model is empty
let test_model = if self.model.is_empty() {
"claude-3-haiku-20240307".to_string()
} else {
self.model.clone()
};
tracing::debug!(
"Anthropic connection test: Using model '{}' with API key length: {}",
test_model,
self.api_key.len()
);
let test_request = AnthropicRequest {
model: test_model,
max_tokens: 1,
messages: vec![AnthropicMessage {
role: "user".to_string(),
content: "Hi".to_string(),
}],
stream: false,
tools: None,
system: None,
};
let response = client
.post(url)
.header("Content-Type", "application/json")
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.json(&test_request)
.send()
.await
.map_err(|e| {
tracing::error!("Anthropic connection test network error: {}", e);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Claude)
})?;
if !response.status().is_success() {
let status = response.status();
let error_body = response.text().await.unwrap_or_default();
tracing::error!(
"Anthropic connection test failed: Status={}, Body={}",
status,
error_body
);
return Err(WorkerError::from_api_error(
format!(
"Anthropic connection test failed: {} - {}",
status, error_body
),
&crate::types::LlmProvider::Claude,
));
}
Ok(())
}
}
#[async_trait::async_trait]
impl LlmClientTrait for AnthropicClient {
async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[crate::types::DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
self.chat_stream(messages, tools, llm_debug).await
}
async fn check_connection(&self) -> Result<(), WorkerError> {
self.check_connection().await
}
fn provider(&self) -> LlmProvider {
LlmProvider::Claude
}
fn get_model_name(&self) -> String {
self.get_model_name()
}
}

977
worker/src/llm/gemini.rs Normal file

@@ -0,0 +1,977 @@
use crate::{
LlmClientTrait, WorkerError,
types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall},
url_config::UrlConfig,
};
use futures_util::{Stream, StreamExt, TryStreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing;
/// Extract tool name from Tool message content
fn extract_tool_name_from_content(content: &str) -> Option<String> {
// Look for patterns like "Tool 'tool_name' executed successfully"
if let Some(start) = content.find("Tool '") {
if let Some(end) = content[start + 6..].find("'") {
let tool_name = &content[start + 6..start + 6 + end];
return Some(tool_name.to_string());
}
}
None
}
/// Transforms a JSON schema to be compatible with the Gemini API.
/// Converts 'uint' types to 'integer' types, handles nullable types, and
/// ensures the schema is in the correct shape for Gemini function parameters.
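/// For example, `{ "type": "uint" }` becomes
/// `{ "type": "integer", "format": "int64" }`, and a nullable
/// `{ "type": ["integer", "null"] }` collapses to its non-null variant and is
/// dropped from the parent's `required` list.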
fn transform_schema_for_gemini(schema: serde_json::Value) -> serde_json::Value {
match schema {
serde_json::Value::Object(mut obj) => {
// Remove $schema key as it's not needed for Gemini
obj.remove("$schema");
// Handle type field
if let Some(type_val) = obj.get("type") {
match type_val {
// Convert 'uint' to 'integer'
serde_json::Value::String(s) if s == "uint" => {
obj.insert(
"type".to_string(),
serde_json::Value::String("integer".to_string()),
);
// Add format for integer types as required by Gemini
obj.insert(
"format".to_string(),
serde_json::Value::String("int64".to_string()),
);
}
// Handle array types like ["integer", "null"]
serde_json::Value::Array(arr) => {
if let Some(non_null_type) = arr.iter().find(|&t| t != "null") {
// Use the non-null type
let mut new_type = non_null_type.clone();
// Convert 'uint' to 'integer' if needed
if let serde_json::Value::String(s) = &new_type {
if s == "uint" {
new_type = serde_json::Value::String("integer".to_string());
}
}
obj.insert("type".to_string(), new_type.clone());
// Add format for integer types as required by Gemini
if let serde_json::Value::String(type_str) = &new_type {
if type_str == "integer" {
obj.insert(
"format".to_string(),
serde_json::Value::String("int64".to_string()),
);
}
}
}
}
// Handle existing integer types
serde_json::Value::String(s) if s == "integer" => {
// Add format for integer types as required by Gemini
obj.insert(
"format".to_string(),
serde_json::Value::String("int64".to_string()),
);
}
_ => {}
}
}
// Handle properties and required fields
if let (Some(properties), Some(required)) = (obj.get("properties"), obj.get("required"))
{
if let (serde_json::Value::Object(props), serde_json::Value::Array(req_arr)) =
(properties, required)
{
let mut new_required = Vec::new();
for (prop_name, _) in props {
// Only include in required if it's not nullable
if req_arr.iter().any(|r| r == prop_name) {
// Check if this property has a nullable type
if let Some(prop_schema) = props.get(prop_name) {
if let Some(prop_type) = prop_schema.get("type") {
// If type is an array containing "null", it's nullable
let is_nullable = match prop_type {
serde_json::Value::Array(arr) => {
arr.iter().any(|t| t == "null")
}
_ => false,
};
// Only add to required if not nullable
if !is_nullable {
new_required
.push(serde_json::Value::String(prop_name.clone()));
}
} else {
// No type info, assume required
new_required.push(serde_json::Value::String(prop_name.clone()));
}
}
}
}
obj.insert(
"required".to_string(),
serde_json::Value::Array(new_required),
);
}
}
// Recursively transform nested objects
for (_, value) in obj.iter_mut() {
*value = transform_schema_for_gemini(value.clone());
}
serde_json::Value::Object(obj)
}
serde_json::Value::Array(arr) => {
serde_json::Value::Array(arr.into_iter().map(transform_schema_for_gemini).collect())
}
other => other,
}
}
// --- Request Structures ---
#[derive(Debug, Serialize, Clone)]
pub struct GeminiTool {
#[serde(rename = "functionDeclarations")]
pub function_declarations: Vec<GeminiFunctionDeclaration>,
}
#[derive(Debug, Serialize, Clone)]
pub struct GeminiFunctionDeclaration {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
#[derive(Debug, Serialize, Clone)]
pub struct GeminiRequest {
pub contents: Vec<GeminiContent>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "systemInstruction")]
pub system_instruction: Option<GeminiContent>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<GeminiTool>>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeminiContent {
pub role: String,
#[serde(default)]
pub parts: Vec<GeminiPart>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum GeminiPart {
Text {
text: String,
},
FunctionCall {
#[serde(rename = "functionCall")]
function_call: GeminiFunctionCall,
},
FunctionResponse {
#[serde(rename = "functionResponse")]
function_response: GeminiFunctionResponse,
},
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeminiFunctionCall {
pub name: String,
pub args: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeminiFunctionResponse {
pub name: String,
pub response: serde_json::Value,
}
// --- Response Structures ---
#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiResponse {
#[serde(default)]
pub candidates: Vec<GeminiCandidate>,
}
#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiCandidate {
pub content: GeminiContent,
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<String>,
}
fn build_url(model: &str) -> String {
let base_url = UrlConfig::get_base_url("gemini");
let action = "streamGenerateContent";
format!("{}/v1beta/models/{}:{}", base_url, model, action)
}
/// Finds the start and end indices of the first complete JSON object `{...}` in the buffer.
fn find_first_json_object_bounds(buffer: &[u8]) -> Option<(usize, usize)> {
if let Some(start) = buffer.iter().position(|&b| b == b'{') {
let mut brace_count = 0;
let mut in_string = false;
let mut escaped = false;
for (i, &byte) in buffer.iter().skip(start).enumerate() {
if in_string {
if escaped {
escaped = false;
} else if byte == b'\\' {
escaped = true;
} else if byte == b'"' {
in_string = false;
}
} else {
match byte {
b'"' => in_string = true,
b'{' => brace_count += 1,
b'}' => {
brace_count -= 1;
if brace_count == 0 {
let end = start + i + 1;
return Some((start, end));
}
}
_ => {}
}
}
}
}
None // No complete object found
}
/// Completes a chat request with streaming, yielding StreamEvent objects.
pub(crate) fn stream_events<'a>(
api_key: &'a str,
model: &'a str,
request: GeminiRequest,
llm_debug: Option<crate::types::LlmDebug>,
) -> impl Stream<Item = anyhow::Result<StreamEvent>> + 'a {
let api_key = api_key.to_string();
let model = model.to_string();
async_stream::try_stream! {
let body = serde_json::to_string_pretty(&request).unwrap_or_else(|e| e.to_string());
tracing::debug!("Gemini Request Body: {}", body);
if let Some(debug_settings) = &llm_debug {
if let Some(debug_event) = debug_settings.debug_request(&model, "Gemini", &serde_json::to_value(&request).unwrap_or_default()) {
yield debug_event;
}
}
let client = Client::new();
let url = build_url(&model);
let response = client
.post(&url)
.header("x-goog-api-key", &api_key)
.json(&request)
.send()
.await
.map_err(|e| anyhow::anyhow!("Gemini API request failed: {}", e))?;
let status = response.status();
if !status.is_success() {
let error_body = response.text().await.unwrap_or_else(|_| "Could not read error body".to_string());
let error_msg = format!("Gemini API request failed with status: {} - {}", status, error_body);
tracing::error!("{}", error_msg);
Err(anyhow::anyhow!(error_msg))?;
} else {
let mut byte_stream = response.bytes_stream();
let mut buffer = Vec::new();
let mut full_content = String::new();
while let Some(chunk_result) = byte_stream.next().await {
let chunk = chunk_result?;
buffer.extend_from_slice(&chunk);
while let Some((start, end)) = find_first_json_object_bounds(&buffer) {
let json_slice = &buffer[start..end];
if let Some(debug_settings) = &llm_debug {
if let Ok(response_value) = serde_json::from_slice::<serde_json::Value>(json_slice) {
if let Some(debug_event) = debug_settings.debug_response(&model, "Gemini", &response_value) {
yield debug_event;
}
}
}
match serde_json::from_slice::<GeminiResponse>(json_slice) {
Ok(response) => {
let response_text = String::from_utf8_lossy(json_slice);
tracing::debug!(
response = %response_text,
candidates_count = response.candidates.len(),
"Successfully parsed Gemini response"
);
if response.candidates.is_empty() {
tracing::warn!(
response = %response_text,
"Received empty candidates in Gemini response"
);
} else if let Some(candidate) = response.candidates.get(0) {
// Log finish reason for debugging
if let Some(ref finish_reason) = candidate.finish_reason {
tracing::debug!(
finish_reason = %finish_reason,
"Received finish reason in Gemini response"
);
// Handle specific finish reasons
match finish_reason.as_str() {
"STOP" => {
tracing::debug!("Gemini response completed with STOP");
// Continue processing parts if any, this is normal completion
}
"MAX_TOKENS" => {
tracing::warn!("Gemini response stopped due to MAX_TOKENS");
}
"SAFETY" => {
tracing::warn!("Gemini response stopped due to SAFETY concerns");
}
"RECITATION" => {
tracing::warn!("Gemini response stopped due to RECITATION");
}
other => {
tracing::warn!("Gemini response stopped with unknown reason: {}", other);
}
}
}
if candidate.content.parts.is_empty() {
tracing::warn!(
response = %response_text,
role = %candidate.content.role,
finish_reason = ?candidate.finish_reason,
"Received empty parts in Gemini response"
);
} else {
for part in &candidate.content.parts {
tracing::debug!("Processing Gemini part (type unknown)");
match part {
GeminiPart::Text { text } => {
tracing::debug!("Found Text part with content length: {}", text.len());
full_content.push_str(text);
yield StreamEvent::Chunk(text.clone());
}
GeminiPart::FunctionCall { function_call } => {
tracing::debug!("Found FunctionCall part: name={}, args={:?}", function_call.name, function_call.args);
let tool_call = ToolCall {
name: function_call.name.clone(),
arguments: serde_json::to_string(&function_call.args)
.unwrap_or_else(|_| "{}".to_string()),
};
yield StreamEvent::ToolCall(tool_call.clone());
}
GeminiPart::FunctionResponse { .. } => {
// Function responses in model output are not expected
// as they're part of the input conversation history
tracing::warn!("Unexpected FunctionResponse in model output");
}
}
}
}
}
}
Err(e) => {
let response_text = String::from_utf8_lossy(json_slice);
tracing::warn!(
error = %e,
response = %response_text,
"Failed to deserialize GeminiResponse from slice"
);
}
}
buffer.drain(..end);
}
}
let final_message = Message::new(
Role::Model,
full_content,
);
yield StreamEvent::Completion(final_message);
}
}
}
pub struct GeminiClient {
api_key: String,
model: String,
}
impl GeminiClient {
pub fn new(api_key: &str, model: &str) -> Self {
Self {
api_key: api_key.to_string(),
model: model.to_string(),
}
}
pub fn get_model_name(&self) -> String {
self.model.clone()
}
/// Static method: takes an API key and fetches the list of available models
pub async fn list_models_static(
api_key: &str,
) -> Result<Vec<crate::types::ModelInfo>, WorkerError> {
let client = Client::new();
let url = UrlConfig::get_models_url("gemini");
let response = client
.get(url)
.header("x-goog-api-key", api_key)
.send()
.await
.map_err(|e| {
tracing::error!("Gemini API request failed: {}", e);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Gemini)
})?;
if !response.status().is_success() {
let status = response.status();
let error_body = response.text().await.unwrap_or_default();
tracing::error!(
"Gemini list_models_static failed - Status: {}, Body: {}",
status,
error_body
);
return Err(WorkerError::from_api_error(
format!("Failed to list Gemini models: {} - {}", status, error_body),
&crate::types::LlmProvider::Gemini,
));
}
let models_response: serde_json::Value = response.json().await.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Gemini)
})?;
let mut models = Vec::new();
if let Some(models_array) = models_response.get("models").and_then(|m| m.as_array()) {
for model in models_array {
if let Some(name) = model.get("name").and_then(|n| n.as_str()) {
// "models/" プレフィックスを除去
let model_id = name.strip_prefix("models/").unwrap_or(name);
// Only include models that support the generateContent method
if let Some(supported_methods) = model
.get("supportedGenerationMethods")
.and_then(|m| m.as_array())
{
let supports_generate_content = supported_methods
.iter()
.any(|method| method.as_str() == Some("generateContent"));
if supports_generate_content {
models.push(crate::types::ModelInfo {
id: model_id.to_string(),
name: model
.get("displayName")
.and_then(|d| d.as_str())
.unwrap_or(model_id)
.to_string(),
provider: crate::types::LlmProvider::Gemini,
supports_tools: true,
supports_function_calling: true,
supports_vision: false,
supports_multimodal: false,
context_length: None,
training_cutoff: None,
capabilities: vec!["text_generation".to_string()],
description: model
.get("description")
.and_then(|d| d.as_str())
.map(|s| s.to_string())
.or_else(|| Some(format!("Google Gemini model: {}", model_id))),
});
}
}
}
}
}
tracing::info!(
"Gemini list_models_static found {} models with metadata",
models.len()
);
Ok(models)
}
pub async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
// Separate system messages from regular messages
let (system_messages, regular_messages): (Vec<_>, Vec<_>) = messages
.into_iter()
.partition(|msg| matches!(msg.role, Role::System));
// Create system instruction from system messages
let system_instruction = if !system_messages.is_empty() {
let combined_system_content = system_messages
.into_iter()
.map(|msg| msg.content)
.collect::<Vec<_>>()
.join("\n\n");
Some(GeminiContent {
role: "user".to_string(), // System instruction uses "user" role
parts: vec![GeminiPart::Text {
text: combined_system_content,
}],
})
} else {
None
};
// Process regular messages with proper tool context handling
let contents = regular_messages
.into_iter()
.map(|msg| {
let (role, parts) = match msg.role {
Role::User => (
"user".to_string(),
vec![GeminiPart::Text { text: msg.content }],
),
Role::Model => {
if let Some(tool_calls) = &msg.tool_calls {
// Model message with tool calls - convert to FunctionCall parts
tracing::debug!(
"Converting model message with {} tool calls to FunctionCall parts",
tool_calls.len()
);
let mut parts = Vec::new();
// Add text content if present
if !msg.content.is_empty() {
parts.push(GeminiPart::Text {
text: msg.content.clone(),
});
}
// Add function calls
for tool_call in tool_calls {
tracing::debug!(
"Adding FunctionCall part for tool: {}",
tool_call.name
);
let args = serde_json::from_str(&tool_call.arguments)
.unwrap_or(serde_json::json!({}));
parts.push(GeminiPart::FunctionCall {
function_call: GeminiFunctionCall {
name: tool_call.name.clone(),
args,
},
});
}
("model".to_string(), parts)
} else {
// Regular model message
tracing::debug!("Converting regular model message (no tool calls)");
(
"model".to_string(),
vec![GeminiPart::Text { text: msg.content }],
)
}
}
Role::Tool => {
// Tool responses should be sent as FunctionResponse
if let Some(tool_name) = extract_tool_name_from_content(&msg.content) {
// Extract result from the content
let result_value = if msg.content.contains("Result: ") {
if let Some(result_start) = msg.content.find("Result: ") {
let result_str = &msg.content[result_start + 8..];
// Try to parse as JSON, fallback to string
serde_json::from_str(result_str)
.unwrap_or_else(|_| serde_json::json!(result_str))
} else {
serde_json::json!(msg.content)
}
} else {
serde_json::json!(msg.content)
};
(
"user".to_string(),
vec![GeminiPart::FunctionResponse {
function_response: GeminiFunctionResponse {
name: tool_name,
response: result_value,
},
}],
)
} else {
// Fallback to text response if tool name can't be extracted
(
"user".to_string(),
vec![GeminiPart::Text {
text: format!("Tool Response:\n{}", msg.content),
}],
)
}
}
Role::System => unreachable!(), // Should not reach here after partition
};
GeminiContent { role, parts }
})
.collect();
let tools = tools.map(|tools| {
vec![GeminiTool {
function_declarations: tools
.iter()
.map(|tool| {
let mut transformed_schema =
transform_schema_for_gemini(tool.parameters_schema.clone());
// Ensure the schema has the correct structure for Gemini
match transformed_schema {
serde_json::Value::Object(ref mut obj) => {
// Gemini expects the parameters to be an object with type: "object"
if !obj.contains_key("type") {
obj.insert(
"type".to_string(),
serde_json::Value::String("object".to_string()),
);
}
// If there are no properties, add an empty object
if !obj.contains_key("properties") {
obj.insert(
"properties".to_string(),
serde_json::Value::Object(serde_json::Map::new()),
);
}
// If there are no required fields, add an empty array
if !obj.contains_key("required") {
obj.insert(
"required".to_string(),
serde_json::Value::Array(vec![]),
);
}
}
_ => {
// If it's not an object, create a proper object schema
let mut schema_obj = serde_json::Map::new();
schema_obj.insert(
"type".to_string(),
serde_json::Value::String("object".to_string()),
);
schema_obj.insert(
"properties".to_string(),
serde_json::Value::Object(serde_json::Map::new()),
);
schema_obj.insert(
"required".to_string(),
serde_json::Value::Array(vec![]),
);
transformed_schema = serde_json::Value::Object(schema_obj);
}
}
GeminiFunctionDeclaration {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: transformed_schema,
}
})
.collect(),
}]
});
let request = GeminiRequest {
contents,
system_instruction,
tools,
};
let stream = stream_events(&self.api_key, &self.model, request, llm_debug)
.map_err(|e| WorkerError::LlmApiError(e.to_string()));
Ok(Box::new(Box::pin(stream)))
}
pub async fn get_model_details(
&self,
model_name: &str,
) -> Result<crate::types::ModelInfo, WorkerError> {
let client = Client::new();
let url = UrlConfig::get_model_url("gemini", model_name);
let response = client
.get(&url)
.header("x-goog-api-key", &self.api_key)
.send()
.await
.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Gemini)
})?;
if !response.status().is_success() {
return Err(WorkerError::from_api_error(
format!(
"Gemini model details request failed with status: {}",
response.status()
),
&crate::types::LlmProvider::Gemini,
));
}
let model_data: serde_json::Value = response.json().await.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Gemini)
})?;
let name = model_data
.get("name")
.and_then(|n| n.as_str())
.unwrap_or(model_name);
let display_name = model_data
.get("displayName")
.and_then(|d| d.as_str())
.unwrap_or(name);
let description = model_data
.get("description")
.and_then(|d| d.as_str())
.unwrap_or("");
let version = model_data
.get("version")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let input_token_limit = model_data
.get("inputTokenLimit")
.and_then(|i| i.as_u64())
.map(|i| i as u32);
let _output_token_limit = model_data
.get("outputTokenLimit")
.and_then(|o| o.as_u64())
.map(|o| o as u32);
let empty_vec = Vec::new();
let supported_methods = model_data
.get("supportedGenerationMethods")
.and_then(|s| s.as_array())
.unwrap_or(&empty_vec);
let supports_tools = supported_methods
.iter()
.any(|method| method.as_str() == Some("generateContent"));
let supports_vision = false; // Will be determined dynamically
let capabilities = vec!["text_generation".to_string()]; // Basic default
Ok(crate::types::ModelInfo {
id: model_name.to_string(),
name: display_name.to_string(),
provider: crate::types::LlmProvider::Gemini,
supports_tools,
supports_function_calling: supports_tools,
supports_vision,
supports_multimodal: supports_vision,
context_length: input_token_limit,
training_cutoff: version,
capabilities,
description: Some(if description.is_empty() {
format!("Google Gemini model: {}", display_name)
} else {
description.to_string()
}),
})
}
pub async fn check_connection(&self) -> Result<(), WorkerError> {
// Simple connection check - try to call the API
// For now, just return OK if model is not empty
if self.model.is_empty() {
return Err(WorkerError::ModelNotFound("No model specified".to_string()));
}
Ok(())
}
}
#[async_trait::async_trait]
impl LlmClientTrait for GeminiClient {
async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
self.chat_stream(messages, tools, llm_debug).await
}
async fn check_connection(&self) -> Result<(), WorkerError> {
self.check_connection().await
}
fn provider(&self) -> LlmProvider {
LlmProvider::Gemini
}
fn get_model_name(&self) -> String {
self.get_model_name()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_schema_transformation() {
// Test schema with various type formats including $schema
let schema = serde_json::json!({
"$schema": "https://json-schema.org/draft/2020-12/schema",
"type": "object",
"properties": {
"id": {
"type": "uint"
},
"optional_number": {
"type": ["integer", "null"]
},
"required_string": {
"type": "string"
},
"existing_integer": {
"type": "integer"
},
"nested": {
"type": "object",
"properties": {
"count": {
"type": ["uint", "null"]
}
}
}
},
"required": ["id", "optional_number", "required_string"]
});
let transformed = transform_schema_for_gemini(schema);
// Check that the schema has the correct structure
assert_eq!(transformed["type"], "object");
assert!(transformed["properties"].is_object());
assert!(transformed["required"].is_array());
// Check that $schema key is removed
assert!(transformed.get("$schema").is_none());
// Check that 'uint' was transformed to 'integer'
assert_eq!(transformed["properties"]["id"]["type"], "integer");
assert_eq!(transformed["properties"]["id"]["format"], "int64");
// Check that array types are converted to single types
assert_eq!(
transformed["properties"]["optional_number"]["type"],
"integer"
);
assert_eq!(
transformed["properties"]["optional_number"]["format"],
"int64"
);
assert_eq!(
transformed["properties"]["nested"]["properties"]["count"]["type"],
"integer"
);
assert_eq!(
transformed["properties"]["nested"]["properties"]["count"]["format"],
"int64"
);
// Check that existing integer types also get format
assert_eq!(
transformed["properties"]["existing_integer"]["type"],
"integer"
);
assert_eq!(
transformed["properties"]["existing_integer"]["format"],
"int64"
);
// Check that required array is updated correctly (nullable properties should be removed)
let required: Vec<&str> = transformed["required"]
.as_array()
.unwrap()
.iter()
.map(|v| v.as_str().unwrap())
.collect();
assert!(required.contains(&"id"));
assert!(required.contains(&"required_string"));
assert!(!required.contains(&"optional_number")); // Should be removed because it's nullable
}
#[test]
fn test_empty_schema_transformation() {
// Test with an empty schema as would be processed in tool generation
let schema = serde_json::json!({});
let mut transformed = transform_schema_for_gemini(schema);
// Apply the same logic as in tool generation
match transformed {
serde_json::Value::Object(ref mut obj) => {
if !obj.contains_key("type") {
obj.insert(
"type".to_string(),
serde_json::Value::String("object".to_string()),
);
}
if !obj.contains_key("properties") {
obj.insert(
"properties".to_string(),
serde_json::Value::Object(serde_json::Map::new()),
);
}
if !obj.contains_key("required") {
obj.insert("required".to_string(), serde_json::Value::Array(vec![]));
}
}
_ => {
let mut schema_obj = serde_json::Map::new();
schema_obj.insert(
"type".to_string(),
serde_json::Value::String("object".to_string()),
);
schema_obj.insert(
"properties".to_string(),
serde_json::Value::Object(serde_json::Map::new()),
);
schema_obj.insert("required".to_string(), serde_json::Value::Array(vec![]));
transformed = serde_json::Value::Object(schema_obj);
}
}
// Should be converted to a proper object schema
assert_eq!(transformed["type"], "object");
assert!(transformed["properties"].is_object());
assert!(transformed["required"].is_array());
}
}

5
worker/src/llm/mod.rs Normal file

@@ -0,0 +1,5 @@
pub mod anthropic;
pub mod gemini;
pub mod ollama;
pub mod openai;
pub mod xai;

801
worker/src/llm/ollama.rs Normal file

@@ -0,0 +1,801 @@
use crate::{
LlmClientTrait, WorkerError,
types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall},
url_config::UrlConfig,
};
use futures_util::{Stream, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
// --- Request & Response Structures ---
#[derive(Debug, Serialize, Clone)]
pub struct OllamaRequest {
pub model: String,
pub messages: Vec<OllamaMessage>,
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<OllamaTool>>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OllamaMessage {
pub role: String,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<OllamaToolCall>>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OllamaToolCall {
pub function: OllamaToolCallFunction,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OllamaToolCallFunction {
pub name: String,
#[serde(
serialize_with = "serialize_arguments",
deserialize_with = "deserialize_arguments"
)]
pub arguments: String,
}
/// Custom serializer for arguments field that serializes strings as-is
fn serialize_arguments<S>(arguments: &str, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
// Try to parse as JSON first, if successful serialize as raw JSON
// If not valid JSON, serialize as string
if let Ok(value) = serde_json::from_str::<serde_json::Value>(arguments) {
value.serialize(serializer)
} else {
arguments.serialize(serializer)
}
}
/// Custom deserializer for arguments field that handles both string and object formats
fn deserialize_arguments<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let value: serde_json::Value = serde::Deserialize::deserialize(deserializer)?;
match value {
serde_json::Value::String(s) => Ok(s),
serde_json::Value::Object(_) | serde_json::Value::Array(_) => {
// If it's an object or array, serialize it back to a JSON string
serde_json::to_string(&value).map_err(D::Error::custom)
}
_ => Err(D::Error::custom("arguments must be a string or object")),
}
}
#[derive(Debug, Serialize, Clone)]
pub struct OllamaTool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: OllamaFunction,
}
#[derive(Debug, Serialize, Clone)]
pub struct OllamaFunction {
pub name: String,
pub description: String,
pub parameters: Value,
}
#[derive(Debug, Deserialize)]
pub struct OllamaResponse {
pub message: OllamaMessage,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct OllamaStreamResponse {
pub message: OllamaMessage,
pub done: bool,
}
#[derive(Debug, Deserialize)]
pub struct OllamaModelShowResponse {
pub details: Option<OllamaModelDetails>,
pub model_info: Option<serde_json::Value>,
pub template: Option<String>,
pub system: Option<String>,
pub parameters: Option<serde_json::Value>,
}
#[derive(Debug, Deserialize)]
pub struct OllamaModelDetails {
pub format: Option<String>,
pub family: Option<String>,
pub families: Option<Vec<String>>,
pub parameter_size: Option<String>,
pub quantization_level: Option<String>,
}
// --- Client ---
pub struct OllamaClient {
model: String,
base_url: String,
api_key: Option<String>,
}
impl OllamaClient {
pub fn new(model: &str) -> Self {
Self {
model: model.to_string(),
base_url: UrlConfig::get_base_url("ollama"),
api_key: None,
}
}
pub fn new_with_key(api_key: &str, model: &str) -> Self {
tracing::debug!(
"Ollama: Creating client with API key (length: {}), model: {}",
api_key.len(),
model
);
Self {
model: model.to_string(),
base_url: UrlConfig::get_base_url("ollama"),
api_key: Some(api_key.to_string()),
}
}
pub fn get_model_name(&self) -> String {
self.model.clone()
}
fn add_auth_header(&self, request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
tracing::debug!(
"Ollama: add_auth_header called, api_key present: {}",
self.api_key.is_some()
);
if let Some(ref api_key) = self.api_key {
// Detailed API-key logging removed (for security and readability)
// Only add the header if the API key is non-empty
if !api_key.trim().is_empty() {
// Check whether the API key is already formatted
if api_key.starts_with("Basic ") || api_key.starts_with("Bearer ") {
// Already formatted (e.g., "Basic base64string" or "Bearer token")
// Detailed auth-header logging removed (for security and readability)
request_builder.header("Authorization", api_key)
} else {
// Choose the auth scheme based on the URL
let auth_header = if self.base_url.contains("ollama.com") {
// ollama.comの場合はBearerトークンを使用
format!("Bearer {}", api_key)
} else {
// その他の場合はBasic認証を使用ローカル/プロキシ向け)
format!("Basic {}", api_key)
};
// Auth header詳細ログは削除セキュリティと見づらさ解消のため
request_builder.header("Authorization", auth_header)
}
} else {
tracing::debug!("Ollama: Empty API key, skipping auth header");
request_builder
}
} else {
tracing::debug!(
"Ollama: No API key provided, using unauthenticated request (typical for local Ollama)"
);
request_builder
}
}
/// Static method: fetch the model list from the Ollama server (uses the default URL)
pub async fn list_models_static(
api_key: &str,
) -> Result<Vec<crate::types::ModelInfo>, WorkerError> {
let client = Client::new();
let url = UrlConfig::get_models_url("ollama");
tracing::debug!("Ollama list_models_static requesting: {}", url);
let mut request_builder = client.get(&url);
if !api_key.trim().is_empty() {
// Check whether the API key is already a formatted header value
if api_key.starts_with("Basic ") || api_key.starts_with("Bearer ") {
// Already formatted
request_builder = request_builder.header("Authorization", api_key);
} else {
// Choose the auth scheme based on the URL
let auth_header = if url.contains("ollama.com") {
// ollama.com expects a Bearer token
format!("Bearer {}", api_key)
} else {
// Otherwise use Basic auth (for local/proxy setups)
format!("Basic {}", api_key)
};
// Detailed auth header logging removed (for security and readability)
request_builder = request_builder.header("Authorization", auth_header);
}
}
let response = request_builder.send().await.map_err(|e| {
tracing::error!("Ollama API request failed: {}", e);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
})?;
let status = response.status();
tracing::info!("Ollama list_models_static response status: {}", status);
if !status.is_success() {
let error_body = response.text().await.unwrap_or_default();
tracing::error!(
"Ollama list_models_static failed - Status: {}, Body: {}",
status,
error_body
);
let error_msg = format!("Failed to list Ollama models: {} - {}", status, error_body);
return Err(WorkerError::from_api_error(
error_msg,
&crate::types::LlmProvider::Ollama,
));
}
let response_text = response.text().await.map_err(|e| {
tracing::error!("Failed to read Ollama response text: {}", e);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
})?;
// Detailed raw-response logging removed (for readability)
let models_response: serde_json::Value =
serde_json::from_str(&response_text).map_err(|e| {
tracing::error!(
"Failed to parse Ollama JSON response: {} - Response: {}",
e,
response_text
);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
})?;
let mut models = Vec::new();
if let Some(models_array) = models_response.get("models").and_then(|m| m.as_array()) {
for model in models_array {
if let Some(name) = model.get("name").and_then(|n| n.as_str()) {
models.push(crate::types::ModelInfo {
id: name.to_string(),
name: name.to_string(),
provider: crate::types::LlmProvider::Ollama,
supports_tools: true, // Will be determined by config
supports_function_calling: true,
supports_vision: false, // Will be determined by config
supports_multimodal: false,
context_length: None,
training_cutoff: None,
capabilities: vec!["text_generation".to_string()],
description: Some(format!("Ollama model: {}", name)),
});
}
}
}
tracing::info!(
"Ollama list_models_static found {} models with metadata",
models.len()
);
Ok(models)
}
// list_models_with_info was removed - models should be configured in models.yaml
// This private method is kept for future reference if needed
#[allow(dead_code)]
async fn list_models_with_info_internal(
&self,
) -> Result<Vec<crate::types::ModelInfo>, WorkerError> {
let client = Client::new();
let url = format!("{}/api/tags", self.base_url);
tracing::debug!("Ollama list_models requesting: {}", url);
let request = self.add_auth_header(client.get(&url));
tracing::debug!("Ollama list_models_with_info sending request to: {}", &url);
let response = request.send().await.map_err(|e| {
tracing::error!("Ollama API request failed: {}", e);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
})?;
let status = response.status();
tracing::info!("Ollama list_models response status: {}", status);
if !status.is_success() {
let error_body = response.text().await.unwrap_or_default();
tracing::error!(
"Ollama list_models failed - Status: {}, Body: {}, URL: {}",
status,
error_body,
&url
);
let error_msg = format!("Failed to list Ollama models: {} - {}", status, error_body);
return Err(WorkerError::from_api_error(
error_msg,
&crate::types::LlmProvider::Ollama,
));
}
let response_text = response.text().await.map_err(|e| {
tracing::error!("Failed to read Ollama response text: {}", e);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
})?;
// Detailed raw-response logging removed (for readability)
let models_response: serde_json::Value =
serde_json::from_str(&response_text).map_err(|e| {
tracing::error!(
"Failed to parse Ollama JSON response: {} - Response: {}",
e,
response_text
);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
})?;
let model_names: Vec<String> = models_response
.get("models")
.and_then(|models| models.as_array())
.ok_or_else(|| {
tracing::error!("Invalid Ollama models response format - missing 'models' array");
WorkerError::LlmApiError("Invalid models response format".to_string())
})?
.iter()
.filter_map(|model| {
model
.get("name")
.and_then(|name| name.as_str())
.map(|s| s.to_string())
})
.collect();
// Build a ModelInfo entry for each model name (sequentially; detailed capabilities come from config)
let mut models = Vec::new();
for name in model_names {
models.push(crate::types::ModelInfo {
id: name.clone(),
name: name.clone(),
provider: crate::types::LlmProvider::Ollama,
supports_tools: true, // Will be determined by config
supports_function_calling: true,
supports_vision: false, // Will be determined by config
supports_multimodal: false,
context_length: None,
training_cutoff: None,
capabilities: vec!["text_generation".to_string()],
description: Some(format!("Ollama model: {}", name)),
});
}
tracing::info!(
"Ollama list_models found {} models with dynamic capability detection",
models.len()
);
Ok(models)
}
}
use async_stream::stream;
impl OllamaClient {
pub async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[crate::types::DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
let client = Client::new();
let url = format!("{}/api/chat", self.base_url);
let ollama_messages: Vec<OllamaMessage> = messages
.into_iter()
.map(|msg| {
// Convert tool calls if present
let tool_calls = msg.tool_calls.map(|calls| {
calls
.into_iter()
.map(|call| OllamaToolCall {
function: OllamaToolCallFunction {
name: call.name,
arguments: call.arguments,
},
})
.collect()
});
OllamaMessage {
role: match msg.role {
Role::User => "user".to_string(),
Role::Model => "assistant".to_string(),
Role::System => "system".to_string(),
Role::Tool => "tool".to_string(),
},
content: msg.content,
tool_calls,
}
})
.collect();
// Convert tools to Ollama format (similar to OpenAI)
let ollama_tools = tools.map(|tools| {
tools
.iter()
.map(|tool| OllamaTool {
tool_type: "function".to_string(),
function: OllamaFunction {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.parameters_schema.clone(),
},
})
.collect()
});
let request = OllamaRequest {
model: self.model.clone(),
messages: ollama_messages,
stream: true,
tools: ollama_tools,
};
let stream = stream! {
// Emit debug information
if let Some(ref debug) = llm_debug {
if let Some(debug_event) = debug.debug_request(&self.model, "Ollama", &serde_json::to_value(&request).unwrap_or_default()) {
yield Ok(debug_event);
}
}
// Log request information
tracing::info!("Ollama chat_stream: Sending request to {}", &url);
tracing::debug!("Ollama request model: {}", &request.model);
tracing::debug!("Ollama request messages count: {}", request.messages.len());
if let Some(ref tools) = request.tools {
tracing::debug!("Ollama request tools count: {}", tools.len());
}
// Detailed request logging removed (for readability)
let request_builder = self.add_auth_header(client.post(&url));
let response = request_builder
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| {
tracing::error!("Ollama chat_stream request failed: {}", e);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
});
let response = match response {
Ok(resp) => resp,
Err(e) => {
yield Err(e);
return;
}
};
if !response.status().is_success() {
let status = response.status();
let error_body = response.text().await.unwrap_or_default();
tracing::error!("Ollama chat_stream failed - Status: {}, Body: {}, URL: {}", status, error_body, &url);
yield Err(WorkerError::from_api_error(
format!("Ollama API error: {} - {}", status, error_body),
&crate::types::LlmProvider::Ollama,
));
return;
} else {
tracing::info!("Ollama chat_stream response status: {}", response.status());
}
let mut byte_stream = response.bytes_stream();
let mut buffer = String::new();
let mut full_content = String::new();
let mut chunk_count = 0;
tracing::debug!("Ollama chat_stream: Starting to process response stream");
while let Some(chunk) = byte_stream.next().await {
match chunk {
Ok(bytes) => {
chunk_count += 1;
let chunk_str = String::from_utf8_lossy(&bytes);
// Detailed chunk logging removed (for readability)
buffer.push_str(&chunk_str);
// Process line by line
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].to_string();
buffer = buffer[line_end + 1..].to_string();
if line.trim().is_empty() {
continue;
}
// Detailed stream-line logging removed (for readability)
match serde_json::from_str::<OllamaStreamResponse>(&line) {
Ok(stream_response) => {
// Emit debug information
if let Some(ref debug) = llm_debug {
if let Some(debug_event) = debug.debug_response(&self.model, "Ollama", &serde_json::to_value(&stream_response).unwrap_or_default()) {
yield Ok(debug_event);
}
}
// Handle tool calls
if let Some(tool_calls) = &stream_response.message.tool_calls {
tracing::info!("Ollama stream response contains {} tool calls", tool_calls.len());
for (i, tool_call) in tool_calls.iter().enumerate() {
tracing::debug!("Tool call #{}: name={}, arguments={}",
i + 1, tool_call.function.name, tool_call.function.arguments);
let parsed_tool_call = ToolCall {
name: tool_call.function.name.clone(),
arguments: tool_call.function.arguments.clone(),
};
yield Ok(StreamEvent::ToolCall(parsed_tool_call));
}
}
// Handle regular content
if !stream_response.message.content.is_empty() {
full_content.push_str(&stream_response.message.content);
yield Ok(StreamEvent::Chunk(stream_response.message.content));
}
if stream_response.done {
tracing::info!("Ollama stream completed, total content: {} chars", full_content.len());
tracing::debug!("Ollama complete response content: {}", full_content);
yield Ok(StreamEvent::Completion(Message::new(
Role::Model,
full_content.clone(),
)));
break;
}
}
Err(e) => {
tracing::warn!("Failed to parse Ollama stream response: {} - Line: {}", e, line);
tracing::debug!("Parse error details: line_length={}, error={}", line.len(), e);
}
}
}
}
Err(e) => {
tracing::error!("Ollama stream error after {} chunks: {}", chunk_count, e);
yield Err(WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama));
break;
}
}
}
tracing::debug!("Ollama chat_stream: Stream ended, processed {} chunks", chunk_count);
};
Ok(Box::new(Box::pin(stream)))
}
pub async fn get_model_details(
&self,
model_name: &str,
) -> Result<crate::types::ModelInfo, WorkerError> {
let client = Client::new();
let url = format!("{}/api/show", self.base_url);
let request = serde_json::json!({
"name": model_name
});
let response = self
.add_auth_header(client.post(&url))
.json(&request)
.send()
.await
.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
})?;
if !response.status().is_success() {
return Err(WorkerError::from_api_error(
format!(
"Ollama model details request failed with status: {}",
response.status()
),
&crate::types::LlmProvider::Ollama,
));
}
let model_data: serde_json::Value = response.json().await.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
})?;
let details = model_data
.get("details")
.unwrap_or(&serde_json::Value::Null);
let family = details
.get("family")
.and_then(|f| f.as_str())
.unwrap_or("unknown");
let parameter_size = details
.get("parameter_size")
.and_then(|p| p.as_str())
.unwrap_or("unknown");
let quantization = details
.get("quantization_level")
.and_then(|q| q.as_str())
.unwrap_or("unknown");
let size = model_data.get("size").and_then(|s| s.as_u64()).unwrap_or(0);
let modified_at = model_data
.get("modified_at")
.and_then(|m| m.as_str())
.map(|s| s.to_string());
let supports_tools = true; // Will be determined by config
let context_length = None; // Will be determined by config
let capabilities = vec!["text_generation".to_string()]; // Basic default
let description = format!("Ollama model: {}", model_name);
Ok(crate::types::ModelInfo {
id: model_name.to_string(),
name: format!("{} ({}, {})", model_name, family, parameter_size),
provider: crate::types::LlmProvider::Ollama,
supports_tools,
supports_function_calling: supports_tools,
supports_vision: false, // Will be determined dynamically
supports_multimodal: false,
context_length,
training_cutoff: modified_at,
capabilities,
description: Some(format!(
"{} (Size: {} bytes, Quantization: {})",
description, size, quantization
)),
})
}
pub async fn check_connection(&self) -> Result<(), WorkerError> {
let client = Client::new();
let url = format!("{}/api/tags", self.base_url);
self.add_auth_header(client.get(&url))
.send()
.await
.map_err(|e| WorkerError::LlmApiError(format!("Failed to connect to Ollama: {}", e)))?;
Ok(())
}
}
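// A minimal call-site sketch (hypothetical; assumes a Tokio runtime and the
// crate's Message/Role/StreamEvent types in scope, with a placeholder model name):
//
//   let client = OllamaClient::new("llama3.2");
//   let mut stream = client
//       .chat_stream(vec![Message::new(Role::User, "hi".to_string())], None, None)
//       .await?;
//   while let Some(event) = stream.next().await {
//       match event? {
//           StreamEvent::Chunk(text) => print!("{}", text),
//           StreamEvent::ToolCall(call) => println!("tool: {}", call.name),
//           StreamEvent::Completion(_) => break,
//           _ => {}
//       }
//   }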
#[async_trait::async_trait]
impl LlmClientTrait for OllamaClient {
async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
self.chat_stream(messages, tools, llm_debug).await
}
async fn check_connection(&self) -> Result<(), WorkerError> {
self.check_connection().await
}
fn provider(&self) -> LlmProvider {
LlmProvider::Ollama
}
fn get_model_name(&self) -> String {
self.get_model_name()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{Message, Role, ToolCall};
#[test]
fn test_message_conversion_with_tool_calls() {
let tool_call = ToolCall {
name: "List".to_string(),
arguments: r#"{"path": "./"}"#.to_string(),
};
let message = Message::with_tool_calls(
Role::Model,
"".to_string(), // Empty content, only tool calls
vec![tool_call.clone()],
);
let messages = vec![message];
// Simulate the conversion that happens in chat_stream
let ollama_messages: Vec<OllamaMessage> = messages
.into_iter()
.map(|msg| {
// Convert tool calls if present
let tool_calls = msg.tool_calls.map(|calls| {
calls
.into_iter()
.map(|call| OllamaToolCall {
function: OllamaToolCallFunction {
name: call.name,
arguments: call.arguments,
},
})
.collect()
});
OllamaMessage {
role: "assistant".to_string(),
content: msg.content,
tool_calls,
}
})
.collect();
// Verify the conversion preserved tool calls
assert_eq!(ollama_messages.len(), 1);
let converted_msg = &ollama_messages[0];
assert_eq!(converted_msg.role, "assistant");
assert_eq!(converted_msg.content, "");
assert!(converted_msg.tool_calls.is_some());
let converted_tool_calls = converted_msg.tool_calls.as_ref().unwrap();
assert_eq!(converted_tool_calls.len(), 1);
assert_eq!(converted_tool_calls[0].function.name, "List");
assert_eq!(
converted_tool_calls[0].function.arguments,
r#"{"path": "./"}"#
);
}
#[test]
fn test_message_conversion_without_tool_calls() {
let message = Message::new(Role::User, "Hello".to_string());
let messages = vec![message];
let ollama_messages: Vec<OllamaMessage> = messages
.into_iter()
.map(|msg| {
let tool_calls = msg.tool_calls.map(|calls| {
calls
.into_iter()
.map(|call| OllamaToolCall {
function: OllamaToolCallFunction {
name: call.name,
arguments: call.arguments,
},
})
.collect()
});
OllamaMessage {
role: "user".to_string(),
content: msg.content,
tool_calls,
}
})
.collect();
assert_eq!(ollama_messages.len(), 1);
let converted_msg = &ollama_messages[0];
assert_eq!(converted_msg.role, "user");
assert_eq!(converted_msg.content, "Hello");
assert!(converted_msg.tool_calls.is_none());
}
}

380
worker/src/llm/openai.rs Normal file

@ -0,0 +1,380 @@
use crate::{
LlmClientTrait, WorkerError,
types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall},
url_config::UrlConfig,
};
use futures_util::{Stream, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
// --- Request & Response Structures ---
#[derive(Debug, Serialize)]
pub(crate) struct OpenAIRequest {
pub model: String,
pub messages: Vec<OpenAIMessage>,
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<OpenAITool>>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OpenAIMessage {
pub role: String,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<OpenAIToolCall>>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OpenAIToolCall {
pub id: String,
#[serde(rename = "type")]
pub call_type: String,
pub function: OpenAIFunction,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OpenAIFunction {
pub name: String,
pub arguments: String,
}
#[derive(Debug, Serialize, Clone)]
pub struct OpenAITool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: OpenAIFunctionDef,
}
#[derive(Debug, Serialize, Clone)]
pub struct OpenAIFunctionDef {
pub name: String,
pub description: String,
pub parameters: Value,
}
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAIResponse {
pub choices: Vec<OpenAIChoice>,
}
#[derive(Debug, Deserialize)]
pub struct OpenAIChoice {
pub message: OpenAIMessage,
pub delta: Option<OpenAIDelta>,
}
#[derive(Debug, Deserialize)]
pub struct OpenAIDelta {
pub content: Option<String>,
pub tool_calls: Option<Vec<OpenAIToolCall>>,
}
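// The streaming endpoint sends server-sent events whose payloads look
// roughly like this (illustrative, abbreviated), which is what the delta
// structs above and the parse loop below consume:
//   data: {"choices":[{"delta":{"content":"Hel"}}]}
//   data: {"choices":[{"delta":{"tool_calls":[{"function":{"name":"List","arguments":"{}"}}]}}]}
//   data: [DONE]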
// --- Client ---
pub struct OpenAIClient {
api_key: String,
model: String,
}
impl OpenAIClient {
pub fn new(api_key: &str, model: &str) -> Self {
Self {
api_key: api_key.to_string(),
model: model.to_string(),
}
}
pub fn get_model_name(&self) -> String {
self.model.clone()
}
}
use async_stream::stream;
impl OpenAIClient {
pub async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[crate::types::DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
let client = Client::new();
let url = UrlConfig::get_completion_url("openai");
let openai_messages: Vec<OpenAIMessage> = messages
.into_iter()
.map(|msg| OpenAIMessage {
role: match msg.role {
Role::User => "user".to_string(),
Role::Model => "assistant".to_string(),
Role::System => "system".to_string(),
Role::Tool => "tool".to_string(),
},
content: msg.content,
tool_calls: None,
})
.collect();
// Convert tools to OpenAI format
let openai_tools = tools.map(|tools| {
tools
.iter()
.map(|tool| OpenAITool {
tool_type: "function".to_string(),
function: OpenAIFunctionDef {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.parameters_schema.clone(),
},
})
.collect()
});
let request = OpenAIRequest {
model: self.model.clone(),
messages: openai_messages,
stream: true,
tools: openai_tools,
};
let response = client
.post(url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&request)
.send()
.await
.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::OpenAI)
})?;
if !response.status().is_success() {
let status = response.status();
let error_body = response.text().await.unwrap_or_default();
return Err(WorkerError::from_api_error(
format!("OpenAI API error: {} - {}", status, error_body),
&crate::types::LlmProvider::OpenAI,
));
}
let stream = stream! {
// Emit debug information
if let Some(ref debug) = llm_debug {
if let Some(debug_event) = debug.debug_request(&self.model, "OpenAI", &serde_json::to_value(&request).unwrap_or_default()) {
yield Ok(debug_event);
}
}
let mut stream = response.bytes_stream();
let mut buffer = String::new();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
let chunk_str = String::from_utf8_lossy(&bytes);
buffer.push_str(&chunk_str);
// Process server-sent events
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].to_string();
buffer = buffer[line_end + 1..].to_string();
if line.starts_with("data: ") {
let data = &line[6..];
if data == "[DONE]" {
yield Ok(StreamEvent::Completion(Message::new(
Role::Model,
"".to_string(),
)));
break;
}
match serde_json::from_str::<Value>(data) {
Ok(json_data) => {
// Emit debug information
if let Some(ref debug) = llm_debug {
if let Some(debug_event) = debug.debug_response(&self.model, "OpenAI", &json_data) {
yield Ok(debug_event);
}
}
if let Some(choices) = json_data.get("choices").and_then(|c| c.as_array()) {
for choice in choices {
if let Some(delta) = choice.get("delta") {
// Handle content
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
yield Ok(StreamEvent::Chunk(content.to_string()));
}
// Handle tool calls
if let Some(tool_calls) = delta.get("tool_calls").and_then(|tc| tc.as_array()) {
for tool_call in tool_calls {
if let Some(function) = tool_call.get("function") {
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
let arguments = function.get("arguments")
.and_then(|a| a.as_str())
.unwrap_or("");
let tool_call = ToolCall {
name: name.to_string(),
arguments: arguments.to_string(),
};
yield Ok(StreamEvent::ToolCall(tool_call));
}
}
}
}
}
}
}
}
Err(e) => {
tracing::warn!("Failed to parse OpenAI stream response: {}", e);
}
}
}
}
}
Err(e) => {
yield Err(WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::OpenAI));
break;
}
}
}
};
Ok(Box::new(Box::pin(stream)))
}
pub async fn get_model_details(
&self,
model_id: &str,
) -> Result<crate::types::ModelInfo, WorkerError> {
let client = Client::new();
let url = UrlConfig::get_model_url("openai", model_id);
let response = client
.get(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await
.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::OpenAI)
})?;
if !response.status().is_success() {
return Err(WorkerError::from_api_error(
format!(
"OpenAI model details request failed with status: {}",
response.status()
),
&crate::types::LlmProvider::OpenAI,
));
}
let model_data: serde_json::Value = response.json().await.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::OpenAI)
})?;
let id = model_data
.get("id")
.and_then(|id| id.as_str())
.unwrap_or(model_id);
let owned_by = model_data
.get("owned_by")
.and_then(|owner| owner.as_str())
.unwrap_or("openai");
let created = model_data
.get("created")
.and_then(|c| c.as_i64())
.map(|timestamp| format!("{}", timestamp));
let supports_tools = true; // Default to true, will be determined by config
let context_length = None; // Will be determined by config
let capabilities = vec!["text_generation".to_string()]; // Basic default
let description = format!("OpenAI model: {}", id);
Ok(crate::types::ModelInfo {
id: id.to_string(),
name: format!("{} ({})", id, owned_by),
provider: crate::types::LlmProvider::OpenAI,
supports_tools,
supports_function_calling: supports_tools,
supports_vision: false, // Will be determined dynamically
supports_multimodal: false,
context_length,
training_cutoff: created,
capabilities,
description: Some(description),
})
}
pub async fn check_connection(&self) -> Result<(), WorkerError> {
let client = Client::new();
let url = UrlConfig::get_completion_url("openai");
let test_request = OpenAIRequest {
model: self.model.clone(),
messages: vec![OpenAIMessage {
role: "user".to_string(),
content: "Hi".to_string(),
tool_calls: None,
}],
stream: false,
tools: None,
};
let response = client
.post(url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&test_request)
.send()
.await
.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::OpenAI)
})?;
if !response.status().is_success() {
let status = response.status();
let error_body = response.text().await.unwrap_or_default();
return Err(WorkerError::from_api_error(
format!("OpenAI connection test failed: {} - {}", status, error_body),
&crate::types::LlmProvider::OpenAI,
));
}
Ok(())
}
}
#[async_trait::async_trait]
impl LlmClientTrait for OpenAIClient {
async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
self.chat_stream(messages, tools, llm_debug).await
}
async fn check_connection(&self) -> Result<(), WorkerError> {
self.check_connection().await
}
fn provider(&self) -> LlmProvider {
LlmProvider::OpenAI
}
fn get_model_name(&self) -> String {
self.get_model_name()
}
}

386
worker/src/llm/xai.rs Normal file

@ -0,0 +1,386 @@
use crate::{
LlmClientTrait, WorkerError,
types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall},
url_config::UrlConfig,
};
use futures_util::{Stream, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Debug, Serialize)]
pub(crate) struct XAIRequest {
pub model: String,
pub messages: Vec<XAIMessage>,
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<XAITool>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct XAIMessage {
pub role: String,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<XAIToolCall>>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct XAIToolCall {
pub id: String,
#[serde(rename = "type")]
pub call_type: String,
pub function: XAIFunction,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct XAIFunction {
pub name: String,
pub arguments: String,
}
#[derive(Debug, Serialize, Clone)]
pub struct XAITool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: XAIFunctionDef,
}
#[derive(Debug, Serialize, Clone)]
pub struct XAIFunctionDef {
pub name: String,
pub description: String,
pub parameters: Value,
}
#[derive(Debug, Deserialize)]
pub(crate) struct XAIResponse {
pub choices: Vec<XAIChoice>,
}
#[derive(Debug, Deserialize)]
pub struct XAIChoice {
pub message: XAIMessage,
pub delta: Option<XAIDelta>,
}
#[derive(Debug, Deserialize)]
pub struct XAIDelta {
pub content: Option<String>,
pub tool_calls: Option<Vec<XAIToolCall>>,
}
#[derive(Debug, Deserialize)]
pub struct XAIModel {
pub id: String,
pub object: String,
pub created: i64,
pub owned_by: String,
}
#[derive(Debug, Deserialize)]
pub struct XAIModelsResponse {
pub object: String,
pub data: Vec<XAIModel>,
}
pub struct XAIClient {
api_key: String,
model: String,
}
impl XAIClient {
pub fn new(api_key: &str, model: &str) -> Self {
Self {
api_key: api_key.to_string(),
model: model.to_string(),
}
}
pub fn get_model_name(&self) -> String {
self.model.clone()
}
}
use async_stream::stream;
impl XAIClient {
pub async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[crate::types::DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
let client = Client::new();
let url = UrlConfig::get_completion_url("xai");
let xai_messages: Vec<XAIMessage> = messages
.into_iter()
.map(|msg| XAIMessage {
role: match msg.role {
Role::User => "user".to_string(),
Role::Model => "assistant".to_string(),
Role::System => "system".to_string(),
Role::Tool => "tool".to_string(),
},
content: msg.content,
tool_calls: None,
})
.collect();
let xai_tools = tools.map(|tools| {
tools
.iter()
.map(|tool| XAITool {
tool_type: "function".to_string(),
function: XAIFunctionDef {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.parameters_schema.clone(),
},
})
.collect()
});
let request = XAIRequest {
model: self.model.clone(),
messages: xai_messages,
stream: true,
tools: xai_tools,
max_tokens: None,
temperature: None,
};
let response = client
.post(url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&request)
.send()
.await
.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::XAI)
})?;
if !response.status().is_success() {
let status = response.status();
let error_body = response.text().await.unwrap_or_default();
return Err(WorkerError::from_api_error(
format!("xAI API error: {} - {}", status, error_body),
&crate::types::LlmProvider::XAI,
));
}
let stream = stream! {
if let Some(ref debug) = llm_debug {
if let Some(debug_event) = debug.debug_request(&self.model, "xAI", &serde_json::to_value(&request).unwrap_or_default()) {
yield Ok(debug_event);
}
}
let mut stream = response.bytes_stream();
let mut buffer = String::new();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
let chunk_str = String::from_utf8_lossy(&bytes);
buffer.push_str(&chunk_str);
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].to_string();
buffer = buffer[line_end + 1..].to_string();
if line.starts_with("data: ") {
let data = &line[6..];
if data == "[DONE]" {
yield Ok(StreamEvent::Completion(Message::new(
Role::Model,
"".to_string(),
)));
break;
}
match serde_json::from_str::<Value>(data) {
Ok(json_data) => {
if let Some(ref debug) = llm_debug {
if let Some(debug_event) = debug.debug_response(&self.model, "xAI", &json_data) {
yield Ok(debug_event);
}
}
if let Some(choices) = json_data.get("choices").and_then(|c| c.as_array()) {
for choice in choices {
if let Some(delta) = choice.get("delta") {
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
yield Ok(StreamEvent::Chunk(content.to_string()));
}
if let Some(tool_calls) = delta.get("tool_calls").and_then(|tc| tc.as_array()) {
for tool_call in tool_calls {
if let Some(function) = tool_call.get("function") {
if let (Some(name), Some(arguments)) = (
function.get("name").and_then(|n| n.as_str()),
function.get("arguments").and_then(|a| a.as_str())
) {
let tool_call = ToolCall {
name: name.to_string(),
arguments: arguments.to_string(),
};
yield Ok(StreamEvent::ToolCall(tool_call));
}
}
}
}
}
}
}
}
Err(e) => {
tracing::warn!("Failed to parse xAI stream response: {}", e);
}
}
}
}
}
Err(e) => {
yield Err(WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::XAI));
break;
}
}
}
};
Ok(Box::new(Box::pin(stream)))
}
pub async fn get_model_details(
&self,
model_id: &str,
) -> Result<crate::types::ModelInfo, WorkerError> {
let client = Client::new();
let url = UrlConfig::get_model_url("xai", model_id);
let response = client
.get(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await
.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::XAI)
})?;
if !response.status().is_success() {
return Err(WorkerError::from_api_error(
format!(
"xAI model details request failed with status: {}",
response.status()
),
&crate::types::LlmProvider::XAI,
));
}
let model_data: XAIModel = response.json().await.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::XAI)
})?;
let supports_tools = true; // Will be determined by config
let supports_vision = false; // Will be determined by config
let context_length = None; // Will be determined by config
let capabilities = vec!["text_generation".to_string()]; // Basic default
let description = format!("xAI {} model ({})", model_data.id, model_data.owned_by);
Ok(crate::types::ModelInfo {
id: model_data.id.clone(),
name: format!("{} ({})", model_data.id, model_data.owned_by),
provider: crate::types::LlmProvider::XAI,
supports_tools,
supports_function_calling: supports_tools,
supports_vision,
supports_multimodal: supports_vision,
context_length,
training_cutoff: Some(
chrono::DateTime::from_timestamp(model_data.created, 0)
.map(|dt| dt.format("%Y-%m-%d").to_string())
.unwrap_or_else(|| "2024-12-12".to_string()),
),
capabilities,
description: Some(description),
})
}
pub async fn check_connection(&self) -> Result<(), WorkerError> {
let client = Client::new();
let url = UrlConfig::get_completion_url("xai");
let test_request = XAIRequest {
model: self.model.clone(),
messages: vec![XAIMessage {
role: "user".to_string(),
content: "Hi".to_string(),
tool_calls: None,
}],
stream: false,
tools: None,
max_tokens: Some(10),
temperature: Some(0.1),
};
let response = client
.post(url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&test_request)
.send()
.await
.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::XAI)
})?;
if !response.status().is_success() {
let status = response.status();
let error_body = response.text().await.unwrap_or_default();
return Err(WorkerError::from_api_error(
format!("xAI connection test failed: {} - {}", status, error_body),
&crate::types::LlmProvider::XAI,
));
}
Ok(())
}
}
#[async_trait::async_trait]
impl LlmClientTrait for XAIClient {
async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
self.chat_stream(messages, tools, llm_debug).await
}
async fn check_connection(&self) -> Result<(), WorkerError> {
self.check_connection().await
}
fn provider(&self) -> LlmProvider {
LlmProvider::XAI
}
fn get_model_name(&self) -> String {
self.get_model_name()
}
}

364
worker/src/mcp_config.rs Normal file

@ -0,0 +1,364 @@
use crate::WorkerError;
use crate::mcp_tool::McpServerConfig;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use tracing::{debug, info, warn};
/// Structure of the MCP configuration file
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpConfig {
/// Map of MCP server configurations
pub servers: HashMap<String, McpServerDefinition>,
}
/// Definition of an individual MCP server
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerDefinition {
/// Command to execute
pub command: String,
/// Command-line arguments
#[serde(default)]
pub args: Vec<String>,
/// Environment variables to set
#[serde(default)]
pub env: HashMap<String, String>,
/// Description of the server (optional)
pub description: Option<String>,
/// Whether the server is enabled (default: true)
#[serde(default = "default_enabled")]
pub enabled: bool,
/// Integration mode (proxy or individual)
#[serde(default = "default_integration_mode")]
pub integration_mode: IntegrationMode,
}
/// MCP integration mode
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum IntegrationMode {
/// Proxy mode - access all MCP tools through a single tool
Proxy,
/// Individual mode - register each MCP tool as its own Worker tool
Individual,
}
fn default_enabled() -> bool {
true
}
fn default_integration_mode() -> IntegrationMode {
IntegrationMode::Individual
}
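// Example mcp.yaml entry (a sketch; key names follow the serde attributes
// above, and IntegrationMode serializes in snake_case):
//
//   servers:
//     filesystem:
//       command: npx
//       args: ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
//       enabled: true
//       integration_mode: individual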
impl McpConfig {
/// Load the configuration file
pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self, WorkerError> {
let path = path.as_ref();
if !path.exists() {
debug!(
"MCP config file not found at {:?}, returning empty config",
path
);
return Ok(Self::default());
}
info!("Loading MCP config from: {:?}", path);
let content = std::fs::read_to_string(path).map_err(|e| {
WorkerError::ConfigurationError(format!(
"Failed to read MCP config file {:?}: {}",
path, e
))
})?;
let config: McpConfig = serde_yaml::from_str(&content).map_err(|e| {
WorkerError::ConfigurationError(format!(
"Failed to parse MCP config file {:?}: {}",
path, e
))
})?;
info!("Loaded {} MCP server configurations", config.servers.len());
Ok(config)
}
/// Save to the configuration file
pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<(), WorkerError> {
let path = path.as_ref();
// Create the config directory if it does not exist
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).map_err(|e| {
WorkerError::ConfigurationError(format!(
"Failed to create config directory {:?}: {}",
parent, e
))
})?;
}
let content = serde_yaml::to_string(self).map_err(|e| {
WorkerError::ConfigurationError(format!("Failed to serialize MCP config: {}", e))
})?;
std::fs::write(path, content).map_err(|e| {
WorkerError::ConfigurationError(format!(
"Failed to write MCP config file {:?}: {}",
path, e
))
})?;
info!("Saved MCP config to: {:?}", path);
Ok(())
}
/// Get the enabled server configurations
pub fn get_enabled_servers(&self) -> Vec<(&String, &McpServerDefinition)> {
self.servers.iter().filter(|(_, def)| def.enabled).collect()
}
/// Generate the default configuration
pub fn create_default_config() -> Self {
let mut servers = HashMap::new();
// Example configuration for the Brave Search MCP server
servers.insert(
"brave_search".to_string(),
McpServerDefinition {
command: "npx".to_string(),
args: vec![
"-y".to_string(),
"@brave/brave-search-mcp-server".to_string(),
],
env: {
let mut env = HashMap::new();
env.insert("BRAVE_API_KEY".to_string(), "${BRAVE_API_KEY}".to_string());
env
},
description: Some("Brave Search API for web searching".to_string()),
enabled: false, // Disabled by default (requires an API key)
integration_mode: IntegrationMode::Individual,
},
);
// Example configuration for the filesystem MCP server
servers.insert(
"filesystem".to_string(),
McpServerDefinition {
command: "npx".to_string(),
args: vec![
"-y".to_string(),
"@modelcontextprotocol/server-filesystem".to_string(),
"/tmp".to_string(),
],
env: HashMap::new(),
description: Some("Filesystem operations in /tmp directory".to_string()),
enabled: false, // Disabled by default
integration_mode: IntegrationMode::Individual,
},
);
// Example configuration for the Git MCP server
servers.insert(
"git".to_string(),
McpServerDefinition {
command: "npx".to_string(),
args: vec![
"-y".to_string(),
"@modelcontextprotocol/server-git".to_string(),
".".to_string(),
],
env: HashMap::new(),
description: Some("Git operations in current directory".to_string()),
enabled: false, // Disabled by default
integration_mode: IntegrationMode::Individual,
},
);
Self { servers }
}
}
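// Typical startup flow (a sketch; the path is illustrative and
// to_mcp_server_config is defined further down in this file):
//
//   let config = McpConfig::load_from_file("config/mcp.yaml")?;
//   for (name, def) in config.get_enabled_servers() {
//       let server_config = def.to_mcp_server_config(name)?;
//       // spawn and register the MCP server using server_config ...
//   }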
impl Default for McpConfig {
fn default() -> Self {
Self {
servers: HashMap::new(),
}
}
}
impl McpServerDefinition {
/// Expand environment variables and convert to an McpServerConfig
pub fn to_mcp_server_config(&self, _name: &str) -> Result<McpServerConfig, WorkerError> {
// Set environment variables
for (key, value) in &self.env {
let expanded_value = expand_environment_variables(value)?;
// SAFETY: Setting environment variables is safe when done in a controlled manner
// for MCP server configuration purposes
unsafe {
std::env::set_var(key, expanded_value);
}
debug!("Set environment variable: {}={}", key, value); // 実際の値はログに出力しない
}
// Expand environment variables in the command and its arguments
let expanded_command = expand_environment_variables(&self.command)?;
let expanded_args: Result<Vec<String>, WorkerError> = self
.args
.iter()
.map(|arg| expand_environment_variables(arg))
.collect();
let expanded_args = expanded_args?;
Ok(McpServerConfig::new(expanded_command, expanded_args))
}
}
/// Expand environment variables (supports the ${VAR_NAME} form)
fn expand_environment_variables(input: &str) -> Result<String, WorkerError> {
let mut result = input.to_string();
// Find and replace ${VAR_NAME} patterns
let re = regex::Regex::new(r"\$\{([^}]+)\}")
.map_err(|e| WorkerError::ConfigurationError(format!("Regex error: {}", e)))?;
for caps in re.captures_iter(input) {
let full_match = &caps[0];
let var_name = &caps[1];
match std::env::var(var_name) {
Ok(value) => {
result = result.replace(full_match, &value);
}
Err(_) => {
warn!(
"Environment variable '{}' not found, leaving unexpanded",
var_name
);
// If the variable is not found, leave the placeholder as-is
// so the user can set it later if needed
}
}
}
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::tempdir;
#[test]
fn test_default_config_creation() {
let config = McpConfig::create_default_config();
assert!(!config.servers.is_empty());
assert!(config.servers.contains_key("brave_search"));
assert!(config.servers.contains_key("filesystem"));
}
#[test]
fn test_config_serialization() {
let config = McpConfig::create_default_config();
let yaml = serde_yaml::to_string(&config).unwrap();
// Verify the config serializes to YAML correctly
assert!(yaml.contains("servers:"));
assert!(yaml.contains("brave_search:"));
assert!(yaml.contains("command:"));
}
#[test]
fn test_config_deserialization() {
let yaml_content = r#"
servers:
test_server:
command: "python3"
args: ["test.py"]
env:
TEST_VAR: "test_value"
description: "Test server"
enabled: true
integration_mode: "individual"
"#;
let config: McpConfig = serde_yaml::from_str(yaml_content).unwrap();
assert_eq!(config.servers.len(), 1);
let server = config.servers.get("test_server").unwrap();
assert_eq!(server.command, "python3");
assert_eq!(server.args, vec!["test.py"]);
assert_eq!(server.env.get("TEST_VAR").unwrap(), "test_value");
assert!(server.enabled);
}
#[test]
fn test_environment_variable_expansion() {
std::env::set_var("TEST_VAR", "test_value");
let result = expand_environment_variables("prefix_${TEST_VAR}_suffix").unwrap();
assert_eq!(result, "prefix_test_value_suffix");
// A nonexistent environment variable is left unexpanded
let result = expand_environment_variables("${NON_EXISTENT_VAR}").unwrap();
assert_eq!(result, "${NON_EXISTENT_VAR}");
}
#[test]
fn test_config_file_operations() {
let dir = tempdir().unwrap();
let config_path = dir.path().join("mcp.yaml");
// Create and save a config
let config = McpConfig::create_default_config();
config.save_to_file(&config_path).unwrap();
// Verify the file was created
assert!(config_path.exists());
// Load the config back
let loaded_config = McpConfig::load_from_file(&config_path).unwrap();
assert_eq!(config.servers.len(), loaded_config.servers.len());
}
#[test]
fn test_enabled_servers_filter() {
let mut config = McpConfig::default();
// Add an enabled server
config.servers.insert(
"enabled_server".to_string(),
McpServerDefinition {
command: "test".to_string(),
args: vec![],
env: HashMap::new(),
description: None,
enabled: true,
integration_mode: IntegrationMode::Individual,
},
);
// Add a disabled server
config.servers.insert(
"disabled_server".to_string(),
McpServerDefinition {
command: "test".to_string(),
args: vec![],
env: HashMap::new(),
description: None,
enabled: false,
integration_mode: IntegrationMode::Individual,
},
);
let enabled_servers = config.get_enabled_servers();
assert_eq!(enabled_servers.len(), 1);
assert_eq!(enabled_servers[0].0, "enabled_server");
}
}

449
worker/src/mcp_protocol.rs Normal file

@ -0,0 +1,449 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout};
use tracing::{debug, info, trace, warn};
/// JSON-RPC 2.0 Request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest {
pub jsonrpc: String,
pub id: Value,
pub method: String,
pub params: Option<Value>,
}
/// JSON-RPC 2.0 Response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcResponse {
pub jsonrpc: String,
pub id: Value,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcError>,
}
/// JSON-RPC 2.0 Error
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError {
pub code: i32,
pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<Value>,
}
/// MCP Tool definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolDefinition {
pub name: String,
pub description: Option<String>,
#[serde(rename = "inputSchema")]
pub input_schema: Value,
}
/// MCP Initialize request parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitializeParams {
#[serde(rename = "protocolVersion")]
pub protocol_version: String,
pub capabilities: InitializeCapabilities,
#[serde(rename = "clientInfo")]
pub client_info: ClientInfo,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitializeCapabilities {
pub tools: Option<ToolCapabilities>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCapabilities {
#[serde(rename = "listChanged")]
pub list_changed: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientInfo {
pub name: String,
pub version: String,
}
/// MCP Initialize response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitializeResult {
#[serde(rename = "protocolVersion")]
pub protocol_version: String,
pub capabilities: ServerCapabilities,
#[serde(rename = "serverInfo")]
pub server_info: ServerInfo,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerCapabilities {
pub tools: Option<ToolCapabilities>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerInfo {
pub name: String,
pub version: String,
}
/// List tools response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListToolsResult {
pub tools: Vec<McpToolDefinition>,
}
/// Call tool parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallToolParams {
pub name: String,
pub arguments: Option<Value>,
}
/// Tool call result content
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolContent {
#[serde(rename = "type")]
pub content_type: String,
pub text: Option<String>,
pub data: Option<Value>,
}
/// Call tool result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallToolResult {
pub content: Vec<ToolContent>,
}
/// MCP Client for JSON-RPC 2.0 communication
pub struct McpClient {
child: Option<Child>,
stdin: Option<ChildStdin>,
stdout: Option<BufReader<ChildStdout>>,
request_id: u64,
initialized: bool,
}
impl McpClient {
/// Create a new MCP client
pub fn new() -> Self {
Self {
child: None,
stdin: None,
stdout: None,
request_id: 0,
initialized: false,
}
}
/// Start MCP server process and initialize connection
pub async fn connect(
&mut self,
command: String,
args: Vec<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
info!("Starting MCP server: {} {:?}", command, args);
// Start the process
let mut cmd = tokio::process::Command::new(&command);
cmd.args(&args);
cmd.stdin(std::process::Stdio::piped());
cmd.stdout(std::process::Stdio::piped());
cmd.stderr(std::process::Stdio::piped());
let mut child = cmd.spawn()?;
let stdin = child.stdin.take().ok_or("Failed to get stdin")?;
let stdout = child.stdout.take().ok_or("Failed to get stdout")?;
let stdout = BufReader::new(stdout);
self.child = Some(child);
self.stdin = Some(stdin);
self.stdout = Some(stdout);
info!("MCP server started successfully");
// Initialize the connection
self.initialize().await?;
Ok(())
}
/// Initialize MCP connection
async fn initialize(
&mut self,
) -> Result<InitializeResult, Box<dyn std::error::Error + Send + Sync>> {
let params = InitializeParams {
protocol_version: "2024-11-05".to_string(),
capabilities: InitializeCapabilities {
tools: Some(ToolCapabilities {
list_changed: Some(true),
}),
},
client_info: ClientInfo {
name: "nia-worker".to_string(),
version: "0.1.0".to_string(),
},
};
debug!("Sending initialize request with params: {:?}", params);
let result: InitializeResult = self
.send_request("initialize", Some(serde_json::to_value(&params)?))
.await?;
debug!("Received initialize result: {:?}", result);
// Send initialized notification
debug!("Sending initialized notification");
self.send_notification("initialized", None).await?;
self.initialized = true;
info!("MCP connection initialized successfully");
Ok(result)
}
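// For reference, the initialize exchange on the wire looks roughly like
// this (a sketch, not captured output; newline-delimited JSON-RPC 2.0):
//   -> {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05",...}}
//   <- {"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{...},...}}
//   -> {"jsonrpc":"2.0","method":"initialized","params":null}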
/// Send a JSON-RPC request and wait for response
async fn send_request<T>(
&mut self,
method: &str,
params: Option<Value>,
) -> Result<T, Box<dyn std::error::Error + Send + Sync>>
where
T: for<'de> Deserialize<'de>,
{
if self.stdin.is_none() || self.stdout.is_none() {
return Err("Not connected to MCP server".into());
}
self.request_id += 1;
let id = Value::Number(serde_json::Number::from(self.request_id));
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: id.clone(),
method: method.to_string(),
params,
};
let request_json = serde_json::to_string(&request)?;
trace!("Sending MCP request: {}", request_json);
// Send request
let stdin = self.stdin.as_mut().unwrap();
stdin.write_all(request_json.as_bytes()).await?;
stdin.write_all(b"\n").await?;
stdin.flush().await?;
// Read response - keep reading until we get a valid JSON-RPC response
let stdout = self.stdout.as_mut().unwrap();
let mut response_line = String::new();
let mut attempts = 0;
const MAX_ATTEMPTS: usize = 20;
let response = loop {
response_line.clear();
// Add timeout for reading each line
let read_result = tokio::time::timeout(
std::time::Duration::from_secs(10),
stdout.read_line(&mut response_line),
)
.await;
let bytes_read = match read_result {
Ok(Ok(bytes)) => bytes,
Ok(Err(e)) => {
return Err(format!("I/O error reading from MCP server: {}", e).into());
}
Err(_) => return Err("Timeout reading from MCP server".into()),
};
trace!("Read {} bytes from MCP server", bytes_read);
trace!("Raw response line: {:?}", response_line);
if bytes_read == 0 {
// Check if the process is still running
if let Some(child) = &mut self.child {
match child.try_wait() {
Ok(Some(exit_status)) => {
return Err(
format!("MCP server exited with status: {}", exit_status).into()
);
}
Ok(None) => {
// Process is still running but closed stdout
return Err("MCP server closed stdout connection".into());
}
Err(e) => {
return Err(format!("Failed to check MCP server status: {}", e).into());
}
}
} else {
return Err("MCP server process not found".into());
}
}
let trimmed = response_line.trim();
if trimmed.is_empty() {
continue;
}
// Try to parse as JSON-RPC response
if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(trimmed) {
debug!("Received valid JSON-RPC response: {}", trimmed);
break response;
} else {
// This is likely a log message or other non-JSON output
debug!("Skipping non-JSON output: {}", trimmed);
attempts += 1;
if attempts >= MAX_ATTEMPTS {
return Err("Too many non-JSON responses from MCP server".into());
}
continue;
}
};
// Check if this is our response
if response.id != id {
return Err(format!(
"Response ID mismatch: expected {:?}, got {:?}",
id, response.id
)
.into());
}
// Handle error response
if let Some(error) = response.error {
return Err(format!("MCP server error: {} ({})", error.message, error.code).into());
}
// Parse result
let result = response.result.ok_or("No result in response")?;
let parsed_result: T = serde_json::from_value(result)?;
Ok(parsed_result)
}
/// Send a JSON-RPC notification (no response expected)
async fn send_notification(
&mut self,
method: &str,
params: Option<Value>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
if self.stdin.is_none() {
return Err("Not connected to MCP server".into());
}
let request = serde_json::json!({
"jsonrpc": "2.0",
"method": method,
"params": params
});
let request_json = serde_json::to_string(&request)?;
trace!("Sending MCP notification: {}", request_json);
let stdin = self.stdin.as_mut().unwrap();
stdin.write_all(request_json.as_bytes()).await?;
stdin.write_all(b"\n").await?;
stdin.flush().await?;
Ok(())
}
/// List available tools from MCP server
pub async fn list_tools(
&mut self,
) -> Result<Vec<McpToolDefinition>, Box<dyn std::error::Error + Send + Sync>> {
if !self.initialized {
return Err("MCP client not initialized".into());
}
// Some MCP servers expect an empty object instead of null for params
let params = serde_json::json!({});
let result: ListToolsResult = self.send_request("tools/list", Some(params)).await?;
Ok(result.tools)
}
/// Call a tool on the MCP server
pub async fn call_tool(
&mut self,
name: &str,
arguments: Option<Value>,
) -> Result<CallToolResult, Box<dyn std::error::Error + Send + Sync>> {
if !self.initialized {
return Err("MCP client not initialized".into());
}
let params = CallToolParams {
name: name.to_string(),
arguments,
};
let result: CallToolResult = self
.send_request("tools/call", Some(serde_json::to_value(&params)?))
.await?;
Ok(result)
}
/// Close the connection and terminate the server process
pub async fn close(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
if let Some(mut child) = self.child.take() {
// Try to terminate gracefully first
match child.kill().await {
Ok(()) => info!("MCP server process terminated"),
Err(e) => warn!("Failed to terminate MCP server process: {}", e),
}
}
self.stdin = None;
self.stdout = None;
self.initialized = false;
Ok(())
}
}
impl Drop for McpClient {
fn drop(&mut self) {
if let Some(mut child) = self.child.take() {
// Best effort cleanup - spawn a task to handle async kill
tokio::spawn(async move {
let _ = child.kill().await;
});
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_json_rpc_serialization() {
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Value::Number(serde_json::Number::from(1)),
method: "tools/list".to_string(),
params: None,
};
let json = serde_json::to_string(&request).unwrap();
assert!(json.contains("\"jsonrpc\":\"2.0\""));
assert!(json.contains("\"id\":1"));
assert!(json.contains("\"method\":\"tools/list\""));
}
#[test]
fn test_error_response() {
let response_json =
r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"Method not found"}}"#;
let response: JsonRpcResponse = serde_json::from_str(response_json).unwrap();
assert!(response.error.is_some());
assert_eq!(response.error.unwrap().code, -32601);
}
}

474
worker/src/mcp_tool.rs Normal file

@ -0,0 +1,474 @@
use crate::mcp_protocol::{CallToolResult, McpClient, McpToolDefinition};
use crate::types::{Tool, ToolResult};
use async_trait::async_trait;
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, error, info, warn};
/// Convert MCP CallToolResult to JSON format
fn convert_mcp_result_to_json(result: &CallToolResult) -> Value {
if result.content.is_empty() {
serde_json::json!({
"success": true,
"content": []
})
} else {
let content_json: Vec<Value> = result
.content
.iter()
.map(|content| {
serde_json::json!({
"type": content.content_type,
"text": content.text,
"data": content.data
})
})
.collect();
serde_json::json!({
"success": true,
"content": content_json
})
}
}
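// A small sanity sketch of the shape produced above for a single text item
// (uses ToolContent/CallToolResult as defined in mcp_protocol):
#[cfg(test)]
mod convert_result_tests {
use super::*;
use crate::mcp_protocol::ToolContent;
#[test]
fn converts_text_content() {
let result = CallToolResult {
content: vec![ToolContent {
content_type: "text".to_string(),
text: Some("hello".to_string()),
data: None,
}],
};
let json = convert_mcp_result_to_json(&result);
assert_eq!(json["success"], true);
assert_eq!(json["content"][0]["text"], "hello");
}
}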
/// MCP server configuration
#[derive(Debug, Clone)]
pub struct McpServerConfig {
pub command: String,
pub args: Vec<String>,
pub name: String,
}
impl McpServerConfig {
pub fn new(command: impl Into<String>, args: Vec<impl Into<String>>) -> Self {
let command = command.into();
let args: Vec<String> = args.into_iter().map(|s| s.into()).collect();
let name = format!("{}({})", command, args.join(" "));
Self {
command,
args,
name,
}
}
}
/// A fully working MCP integration tool
pub struct McpDynamicTool {
config: McpServerConfig,
client: Arc<Mutex<Option<McpClient>>>,
tools_cache: Arc<RwLock<Vec<McpToolDefinition>>>,
}
/// A DynamicTool representing a single MCP tool
pub struct SingleMcpTool {
tool_name: String,
tool_description: String,
tool_schema: Value,
client: Arc<Mutex<Option<McpClient>>>,
}
impl McpDynamicTool {
/// Create a new MCP tool
pub fn new(config: McpServerConfig) -> Self {
Self {
config,
client: Arc::new(Mutex::new(None)),
tools_cache: Arc::new(RwLock::new(Vec::new())),
}
}
/// Connect to the MCP server
async fn ensure_connected(&self) -> ToolResult<()> {
let mut client_guard = self.client.lock().await;
if client_guard.is_none() {
info!("Connecting to MCP server: {}", self.config.name);
let mut mcp_client = McpClient::new();
mcp_client
.connect(self.config.command.clone(), self.config.args.clone())
.await
.map_err(|e| {
crate::WorkerError::ToolExecutionError(format!(
"Failed to connect to MCP server '{}': {}",
self.config.name, e
))
})?;
*client_guard = Some(mcp_client);
info!("Successfully connected to MCP server: {}", self.config.name);
}
Ok(())
}
/// Fetch the list of available tools
async fn fetch_tools(&self) -> ToolResult<Vec<McpToolDefinition>> {
self.ensure_connected().await?;
let mut client_guard = self.client.lock().await;
let client = client_guard.as_mut().ok_or_else(|| {
crate::WorkerError::ToolExecutionError("MCP client not connected".to_string())
})?;
let tools = client.list_tools().await.map_err(|e| {
crate::WorkerError::ToolExecutionError(format!(
"Failed to list tools from MCP server '{}': {}",
self.config.name, e
))
})?;
debug!(
"Retrieved {} tools from MCP server '{}'",
tools.len(),
self.config.name
);
Ok(tools)
}
/// Refresh the tools cache
async fn update_tools_cache(&self) -> ToolResult<()> {
let tools = self.fetch_tools().await?;
let mut cache_guard = self.tools_cache.write().await;
*cache_guard = tools;
Ok(())
}
/// Find a specific tool by name
async fn find_tool_by_name(&self, tool_name: &str) -> ToolResult<Option<McpToolDefinition>> {
let cache_guard = self.tools_cache.read().await;
// キャッシュが空の場合は更新
if cache_guard.is_empty() {
drop(cache_guard);
self.update_tools_cache().await?;
let cache_guard = self.tools_cache.read().await;
let result = cache_guard
.iter()
.find(|tool| tool.name == tool_name)
.cloned();
Ok(result)
} else {
let result = cache_guard
.iter()
.find(|tool| tool.name == tool_name)
.cloned();
Ok(result)
}
}
    /// Execute a tool on the MCP server
async fn call_mcp_tool(&self, tool_name: &str, args: Value) -> ToolResult<Value> {
self.ensure_connected().await?;
let mut client_guard = self.client.lock().await;
let client = client_guard.as_mut().ok_or_else(|| {
crate::WorkerError::ToolExecutionError("MCP client not connected".to_string())
})?;
debug!("Calling MCP tool '{}' with args: {}", tool_name, args);
let result = client.call_tool(tool_name, Some(args)).await.map_err(|e| {
crate::WorkerError::ToolExecutionError(format!(
"Failed to call MCP tool '{}': {}",
tool_name, e
))
})?;
debug!("MCP tool '{}' returned: {:?}", tool_name, result);
// Convert MCP result to JSON
Ok(convert_mcp_result_to_json(&result))
}
}
impl SingleMcpTool {
    /// Create a new single MCP tool
pub fn new(
tool_name: String,
tool_description: String,
tool_schema: Value,
client: Arc<Mutex<Option<McpClient>>>,
) -> Self {
Self {
tool_name,
tool_description,
tool_schema,
client,
}
}
    /// Execute this tool on the MCP server
async fn call_mcp_tool(&self, args: Value) -> ToolResult<Value> {
let mut client_guard = self.client.lock().await;
let client = client_guard.as_mut().ok_or_else(|| {
crate::WorkerError::ToolExecutionError("MCP client not connected".to_string())
})?;
debug!("Calling MCP tool '{}' with args: {}", self.tool_name, args);
let result = client
.call_tool(&self.tool_name, Some(args))
.await
.map_err(|e| {
crate::WorkerError::ToolExecutionError(format!(
"Failed to call MCP tool '{}': {}",
self.tool_name, e
))
})?;
debug!("MCP tool '{}' returned: {:?}", self.tool_name, result);
// Convert MCP result to JSON
Ok(convert_mcp_result_to_json(&result))
}
}
#[async_trait]
impl Tool for SingleMcpTool {
fn name(&self) -> &str {
&self.tool_name
}
fn description(&self) -> &str {
&self.tool_description
}
fn parameters_schema(&self) -> Value {
self.tool_schema.clone()
}
async fn execute(&self, args: Value) -> ToolResult<Value> {
self.call_mcp_tool(args).await
}
}
#[async_trait]
impl Tool for McpDynamicTool {
fn name(&self) -> &str {
"mcp_proxy"
}
fn description(&self) -> &str {
"Execute tools from external MCP servers"
}
fn parameters_schema(&self) -> Value {
serde_json::json!({
"type": "object",
"properties": {
"tool_name": {
"type": "string",
"description": "Name of the MCP tool to execute"
},
"tool_args": {
"type": "object",
"description": "Arguments to pass to the MCP tool",
"additionalProperties": true
}
},
"required": ["tool_name", "tool_args"]
})
}
async fn execute(&self, args: Value) -> ToolResult<Value> {
let tool_name = args
.get("tool_name")
.and_then(|v| v.as_str())
.ok_or_else(|| {
crate::WorkerError::ToolExecutionError(
"Missing required parameter 'tool_name'".to_string(),
)
})?;
let tool_args = args
.get("tool_args")
.ok_or_else(|| {
crate::WorkerError::ToolExecutionError(
"Missing required parameter 'tool_args'".to_string(),
)
})?
.clone();
        // Check that the tool exists
        match self.find_tool_by_name(tool_name).await? {
            Some(_tool) => {
                // Execute the tool
let result = self.call_mcp_tool(tool_name, tool_args).await?;
Ok(serde_json::json!({
"success": true,
"tool_name": tool_name,
"result": result
}))
}
None => {
                // Refresh the tool cache and retry
warn!("Tool '{}' not found in cache, refreshing...", tool_name);
self.update_tools_cache().await?;
match self.find_tool_by_name(tool_name).await? {
Some(_tool) => {
let result = self.call_mcp_tool(tool_name, tool_args).await?;
Ok(serde_json::json!({
"success": true,
"tool_name": tool_name,
"result": result
}))
}
None => Err(Box::new(crate::WorkerError::ToolExecutionError(format!(
"Tool '{}' not found in MCP server '{}'",
tool_name, self.config.name
)))
as Box<dyn std::error::Error + Send + Sync>),
}
}
}
}
}
/// Fetch the tools available on an MCP server as DynamicToolDefinitions
pub async fn get_mcp_tools_as_definitions(
config: &McpServerConfig,
) -> ToolResult<Vec<crate::types::DynamicToolDefinition>> {
let mcp_tool = McpDynamicTool::new(config.clone());
let tools = mcp_tool.fetch_tools().await?;
let mut definitions = Vec::new();
for tool in tools {
let definition = crate::types::DynamicToolDefinition {
name: tool.name.clone(),
description: tool
.description
.unwrap_or_else(|| format!("MCP tool: {}", tool.name)),
parameters_schema: tool.input_schema,
};
definitions.push(definition);
}
info!(
"Converted {} MCP tools to DynamicToolDefinitions",
definitions.len()
);
Ok(definitions)
}
/// Fetch tools from an MCP server and wrap each one in a SingleMcpTool
pub async fn create_single_mcp_tools(config: &McpServerConfig) -> ToolResult<Vec<Box<dyn Tool>>> {
let mcp_tool = McpDynamicTool::new(config.clone());
let tools = mcp_tool.fetch_tools().await?;
    // Create a shared client handle
let shared_client = mcp_tool.client.clone();
let mut single_tools: Vec<Box<dyn Tool>> = Vec::new();
for tool in tools {
let tool_name = tool.name;
let tool_description = tool
.description
.unwrap_or_else(|| format!("MCP tool: {}", tool_name));
let tool_schema = tool.input_schema;
let single_tool = SingleMcpTool::new(
tool_name,
tool_description,
tool_schema,
shared_client.clone(),
);
single_tools.push(Box::new(single_tool));
}
info!(
"Created {} SingleMcpTools from MCP server '{}'",
single_tools.len(),
config.name
);
Ok(single_tools)
}
/// Test the connection to an MCP server
pub async fn test_mcp_connection(
config: &McpServerConfig,
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
info!("Testing MCP connection to server: {}", config.name);
let mut client = McpClient::new();
match client
.connect(config.command.clone(), config.args.clone())
.await
{
Ok(()) => {
info!("Successfully connected to MCP server: {}", config.name);
// Test listing tools
match client.list_tools().await {
Ok(tools) => {
info!(
"MCP server '{}' provides {} tools",
config.name,
tools.len()
);
for tool in &tools {
debug!(
"Available tool: {} - {}",
tool.name,
tool.description.as_deref().unwrap_or("No description")
);
}
}
Err(e) => {
warn!(
"Failed to list tools from MCP server '{}': {}",
config.name, e
);
}
}
// Close connection
let _ = client.close().await;
Ok(true)
}
Err(e) => {
error!("Failed to connect to MCP server '{}': {}", config.name, e);
Ok(false)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_mcp_server_config() {
let config =
McpServerConfig::new("npx", vec!["-y", "@modelcontextprotocol/server-everything"]);
assert_eq!(config.command, "npx");
assert_eq!(
config.args,
vec!["-y", "@modelcontextprotocol/server-everything"]
);
assert_eq!(
config.name,
"npx(-y @modelcontextprotocol/server-everything)"
);
}
#[tokio::test]
async fn test_mcp_tool_creation() {
let config = McpServerConfig::new("echo", vec!["test"]);
let tool = McpDynamicTool::new(config);
assert_eq!(tool.name(), "mcp_proxy");
assert!(!tool.description().is_empty());
let schema = tool.parameters_schema();
assert!(schema.is_object());
assert!(schema.get("properties").is_some());
}
}

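As a usage sketch (not part of this commit): wiring the helpers above together to discover tools from the reference MCP server. It assumes `mcp_tool` is a public module of the `worker` crate and that tokio is available as a dependency.

use worker::mcp_tool::{McpServerConfig, create_single_mcp_tools, test_mcp_connection};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Spawn the reference "everything" server over stdio
    let config = McpServerConfig::new("npx", vec!["-y", "@modelcontextprotocol/server-everything"]);
    // Probe the connection first; this returns Ok(false) on failure instead of an error
    if test_mcp_connection(&config).await? {
        // Each discovered MCP tool becomes its own Box<dyn Tool>, all sharing one client
        for tool in create_single_mcp_tools(&config).await? {
            println!("{}: {}", tool.name(), tool.description());
        }
    }
    Ok(())
}
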
333
worker/src/prompt_composer.rs Normal file
View File

@ -0,0 +1,333 @@
use crate::config_parser::ConfigParser;
use crate::prompt_types::*;
use crate::types::{Message, Role};
use handlebars::{Context, Handlebars, Helper, HelperResult, Output, RenderContext};
use std::fs;
use std::path::Path;
/// Prompt composition system
#[derive(Clone)]
pub struct PromptComposer {
config: PromptRoleConfig,
handlebars: Handlebars<'static>,
context: PromptContext,
system_prompt: Option<String>,
}
impl PromptComposer {
    /// Create a new instance from a config file
pub fn from_config_file<P: AsRef<Path>>(
config_path: P,
context: PromptContext,
) -> Result<Self, PromptError> {
let config = ConfigParser::parse_from_file(config_path)?;
Self::from_config(config, context)
}
    /// Create a new instance from a config object
pub fn from_config(
config: PromptRoleConfig,
context: PromptContext,
) -> Result<Self, PromptError> {
let mut handlebars = Handlebars::new();
        // Register custom helper functions
Self::register_custom_helpers(&mut handlebars)?;
let mut composer = Self {
config,
handlebars,
context,
system_prompt: None,
};
        // Load and register partial templates
composer.load_partials()?;
Ok(composer)
}
    /// Pre-build the system prompt when the session starts
pub fn initialize_session(&mut self, initial_messages: &[Message]) -> Result<(), PromptError> {
let system_prompt = self.compose_system_prompt(initial_messages)?;
self.system_prompt = Some(system_prompt);
Ok(())
}
    /// Main prompt composition entry point
pub fn compose(&self, messages: &[Message]) -> Result<Vec<Message>, PromptError> {
if let Some(system_prompt) = &self.system_prompt {
            // If the system prompt has already been built, use it as-is
            let mut result_messages = vec![Message::new(Role::System, system_prompt.clone())];
            // Append the remaining non-system messages
for msg in messages {
if msg.role != Role::System {
result_messages.push(msg.clone());
}
}
Ok(result_messages)
} else {
            // Fallback: legacy on-the-fly composition
self.compose_with_context(messages, &self.context)
}
}
    /// Initialize the session with tool information
pub fn initialize_session_with_tools(
&mut self,
initial_messages: &[Message],
tools_schema: &serde_json::Value,
) -> Result<(), PromptError> {
        // Copy the context temporarily and add the tools schema
let mut temp_context = self.context.clone();
temp_context
.variables
.insert("tools_schema".to_string(), tools_schema.clone());
let system_prompt =
self.compose_system_prompt_with_context(initial_messages, &temp_context)?;
self.system_prompt = Some(system_prompt);
Ok(())
}
    /// Compose a prompt including tool information (kept for backward compatibility)
pub fn compose_with_tools(
&self,
messages: &[Message],
tools_schema: &serde_json::Value,
) -> Result<Vec<Message>, PromptError> {
if let Some(system_prompt) = &self.system_prompt {
            // If the system prompt has already been built, use it as-is
            let mut result_messages = vec![Message::new(Role::System, system_prompt.clone())];
            // Append the remaining non-system messages
for msg in messages {
if msg.role != Role::System {
result_messages.push(msg.clone());
}
}
Ok(result_messages)
} else {
            // Fallback: legacy on-the-fly composition
let mut temp_context = self.context.clone();
temp_context
.variables
.insert("tools_schema".to_string(), tools_schema.clone());
self.compose_with_context(messages, &temp_context)
}
}
    /// Build only the system prompt (for session initialization)
fn compose_system_prompt(&self, messages: &[Message]) -> Result<String, PromptError> {
self.compose_system_prompt_with_context(messages, &self.context)
}
    /// Build the system prompt using the given context
fn compose_system_prompt_with_context(
&self,
messages: &[Message],
context: &PromptContext,
) -> Result<String, PromptError> {
        // Prepare the context variables
let mut template_data = self.prepare_template_data_with_context(messages, context)?;
        // Evaluate conditions and set variables dynamically
self.apply_conditions(&mut template_data)?;
        // Render the main template
let system_prompt = self
.handlebars
.render_template(&self.config.template, &template_data)
.map_err(PromptError::Handlebars)?;
Ok(system_prompt)
}
    /// Compose a prompt using the given context (kept for backward compatibility)
fn compose_with_context(
&self,
messages: &[Message],
context: &PromptContext,
) -> Result<Vec<Message>, PromptError> {
let system_prompt = self.compose_system_prompt_with_context(messages, context)?;
        // Combine the system message with the conversation messages
        let mut result_messages = vec![Message::new(Role::System, system_prompt)];
        // Append the remaining non-system messages
for msg in messages {
if msg.role != Role::System {
result_messages.push(msg.clone());
}
}
Ok(result_messages)
}
    /// Register custom helper functions
fn register_custom_helpers(handlebars: &mut Handlebars) -> Result<(), PromptError> {
        // Only basic helpers are implemented (avoids tricky lifetime issues)
handlebars.register_helper("include_file", Box::new(include_file_helper));
handlebars.register_helper("workspace_content", Box::new(workspace_content_helper));
Ok(())
}
    /// Load and register partial templates
fn load_partials(&mut self) -> Result<(), PromptError> {
if let Some(partials) = &self.config.partials {
for (name, partial_config) in partials {
let content = self.load_partial_content(partial_config)?;
self.handlebars
.register_partial(name, content)
.map_err(|e| PromptError::PartialLoading(e.to_string()))?;
}
}
Ok(())
}
    /// Load a partial's content (with fallback support)
fn load_partial_content(&self, partial_config: &PartialConfig) -> Result<String, PromptError> {
let primary_path = ConfigParser::resolve_path(&partial_config.path)?;
        // Try the primary path
if let Ok(content) = fs::read_to_string(&primary_path) {
return Ok(content);
}
        // Try the fallback path
if let Some(fallback) = &partial_config.fallback {
let fallback_path = ConfigParser::resolve_path(fallback)?;
if let Ok(content) = fs::read_to_string(&fallback_path) {
return Ok(content);
}
}
Err(PromptError::FileNotFound(format!(
"Could not load partial '{}' from {} (fallback: {:?})",
partial_config.path,
primary_path.display(),
partial_config.fallback
)))
}
    /// Prepare template data using the given context
fn prepare_template_data_with_context(
&self,
messages: &[Message],
context: &PromptContext,
) -> Result<serde_json::Value, PromptError> {
let user_input = messages
.iter()
.filter(|m| m.role == Role::User)
.map(|m| m.content.as_str())
.collect::<Vec<_>>()
.join("\n\n");
let mut data = serde_json::json!({
"workspace": context.workspace,
"model": context.model,
"session": context.session,
"user_input": user_input,
"tools": context.variables.get("tools_schema").unwrap_or(&serde_json::Value::Null),
"tools_schema": context.variables.get("tools_schema").unwrap_or(&serde_json::Value::Null),
});
        // Add variables from the config file
if let Some(variables) = &self.config.variables {
for (key, value_template) in variables {
            // Variable values are themselves rendered as Handlebars templates
let resolved_value = self
.handlebars
.render_template(value_template, &data)
.map_err(PromptError::Handlebars)?;
data[key] = serde_json::Value::String(resolved_value);
}
}
        // Merge in additional context variables
for (key, value) in &context.variables {
data[key] = value.clone();
}
Ok(data)
}
    /// Evaluate conditions and set dynamic variables
fn apply_conditions(&self, data: &mut serde_json::Value) -> Result<(), PromptError> {
if let Some(conditions) = &self.config.conditions {
for (_condition_name, condition_config) in conditions {
                // Evaluate the condition expression
let condition_result = self
.handlebars
.render_template(&condition_config.when, data)
.map_err(PromptError::Handlebars)?;
                // If the condition is true, apply its variables
if condition_result.trim() == "true" {
if let Some(variables) = &condition_config.variables {
for (key, value_template) in variables {
let resolved_value = self
.handlebars
.render_template(value_template, data)
.map_err(PromptError::Handlebars)?;
data[key] = serde_json::Value::String(resolved_value);
}
}
}
}
}
Ok(())
}
}
// Custom helper function implementations
fn include_file_helper(
h: &Helper,
_hbs: &Handlebars,
_ctx: &Context,
_rc: &mut RenderContext,
out: &mut dyn Output,
) -> HelperResult {
let file_path = h.param(0).and_then(|v| v.value().as_str()).unwrap_or("");
match ConfigParser::resolve_path(file_path) {
Ok(path) => {
match fs::read_to_string(&path) {
Ok(content) => {
out.write(&content)?;
}
Err(_) => {
                    // Output an empty string if the file is not found
out.write("")?;
}
}
}
Err(_) => {
out.write("")?;
}
}
Ok(())
}
fn workspace_content_helper(
_h: &Helper,
_hbs: &Handlebars,
ctx: &Context,
_rc: &mut RenderContext,
out: &mut dyn Output,
) -> HelperResult {
if let Some(workspace) = ctx.data().get("workspace") {
if let Some(content) = workspace.get("nia_md_content") {
if let Some(content_str) = content.as_str() {
out.write(content_str)?;
}
}
}
Ok(())
}

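A sketch of the intended call sequence (illustrative, not part of the commit; the module paths and the Claude model name are assumptions, and `PromptRoleConfig::default()` expects `./resources/prompts/cli-assistant.md` to exist on disk):

use std::collections::HashMap;
use worker::prompt_composer::PromptComposer;
use worker::prompt_types::*;
use worker::types::{LlmProvider, Message, Role};

fn build_prompt(user_msg: &str) -> Result<Vec<Message>, PromptError> {
    let context = PromptContext {
        workspace: WorkspaceContext::default(),
        model: ModelContext {
            provider: LlmProvider::Claude,
            model_name: "claude-3-sonnet".to_string(),
            capabilities: ModelCapabilities::default(),
            supports_native_tools: true,
        },
        session: SessionContext::default(),
        variables: HashMap::new(),
    };
    let mut composer = PromptComposer::from_config(PromptRoleConfig::default(), context)?;
    // Pre-build the system prompt once per session...
    composer.initialize_session(&[])?;
    // ...then every compose() call reuses it and appends the conversation
    let messages = vec![Message::new(Role::User, user_msg.to_string())];
    composer.compose(&messages) // -> [System, User]
}
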
377
worker/src/prompt_types.rs Normal file
View File

@ -0,0 +1,377 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
/// Type definitions for role configuration files
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptRoleConfig {
pub name: String,
pub description: String,
pub version: Option<String>,
pub template: String,
pub partials: Option<HashMap<String, PartialConfig>>,
pub variables: Option<HashMap<String, String>>,
pub conditions: Option<HashMap<String, ConditionConfig>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartialConfig {
pub path: String,
pub fallback: Option<String>,
pub description: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConditionConfig {
pub when: String,
pub variables: Option<HashMap<String, String>>,
pub template_override: Option<String>,
}
/// System information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemInfo {
pub os_name: String, // linux, windows, macos
pub kernel_version: String, // Linux 6.15.6
pub distribution: String, // NixOS 25.11 (Xantusia)
pub architecture: String, // x86_64
    pub full_system_info: String, // all of the above combined into one string
pub working_directory: String,
pub current_time: String,
pub timezone: String,
}
/// Workspace context
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkspaceContext {
pub root_path: PathBuf,
pub nia_md_content: Option<String>,
pub project_type: Option<ProjectType>,
pub git_info: Option<GitInfo>,
pub has_nia_md: bool,
pub project_name: Option<String>,
pub system_info: SystemInfo,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitInfo {
pub repo_name: Option<String>,
pub current_branch: Option<String>,
pub last_commit_summary: Option<String>,
pub is_clean: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ProjectType {
Rust,
JavaScript,
TypeScript,
Python,
Go,
Java,
Cpp,
Unknown,
}
/// Model context
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelContext {
pub provider: crate::types::LlmProvider,
pub model_name: String,
pub capabilities: ModelCapabilities,
pub supports_native_tools: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCapabilities {
pub supports_tools: bool,
pub supports_function_calling: bool,
pub supports_vision: bool,
pub supports_multimodal: Option<bool>,
pub context_length: Option<u64>,
pub capabilities: Vec<String>,
pub needs_verification: Option<bool>,
}
/// Session context
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionContext {
pub conversation_id: Option<String>,
pub message_count: usize,
pub active_tools: Vec<String>,
pub user_preferences: Option<HashMap<String, String>>,
}
/// Overall prompt context
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptContext {
pub workspace: WorkspaceContext,
pub model: ModelContext,
pub session: SessionContext,
pub variables: HashMap<String, serde_json::Value>,
}
/// Prompt composition errors
#[derive(Debug, thiserror::Error)]
pub enum PromptError {
#[error("Template compilation error: {0}")]
TemplateCompilation(String),
#[error("Variable resolution error: {0}")]
VariableResolution(String),
#[error("Partial loading error: {0}")]
PartialLoading(String),
#[error("File not found: {0}")]
FileNotFound(String),
#[error("Workspace detection error: {0}")]
WorkspaceDetection(String),
#[error("Git information error: {0}")]
GitInfo(String),
#[error("Handlebars error: {0}")]
Handlebars(#[from] handlebars::RenderError),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("YAML parsing error: {0}")]
YamlParsing(#[from] serde_yaml::Error),
}
impl SystemInfo {
    /// Collect detailed system information
pub fn collect() -> Self {
let current_dir = std::env::current_dir()
.map(|p| p.to_string_lossy().to_string())
.unwrap_or_else(|_| ".".to_string());
let now = chrono::Local::now();
let current_time = now.format("%Y-%m-%d %H:%M:%S").to_string();
let timezone = now.format("%Z").to_string();
let os_name = std::env::consts::OS.to_string();
let architecture = std::env::consts::ARCH.to_string();
let (kernel_version, distribution) = Self::get_system_details();
        // Build the combined system info string
let full_system_info = if distribution.is_empty() {
format!("{} {}", kernel_version, architecture)
} else {
format!("{} - {} {}", kernel_version, distribution, architecture)
};
Self {
os_name,
kernel_version,
distribution,
architecture,
full_system_info,
working_directory: current_dir,
current_time,
timezone,
}
}
    /// Get OS-specific details
fn get_system_details() -> (String, String) {
#[cfg(target_os = "linux")]
{
Self::get_linux_details()
}
#[cfg(target_os = "windows")]
{
Self::get_windows_details()
}
#[cfg(target_os = "macos")]
{
Self::get_macos_details()
}
#[cfg(not(any(target_os = "linux", target_os = "windows", target_os = "macos")))]
{
(std::env::consts::OS.to_string(), String::new())
}
}
#[cfg(target_os = "linux")]
fn get_linux_details() -> (String, String) {
use std::process::Command;
        // Get the kernel version
let kernel_version = Command::new("uname")
.arg("-r")
.output()
.ok()
.and_then(|output| {
if output.status.success() {
Some(format!(
"Linux {}",
String::from_utf8_lossy(&output.stdout).trim()
))
} else {
None
}
})
.unwrap_or_else(|| "Linux".to_string());
        // Get distribution information
let distribution = Self::get_linux_distribution();
(kernel_version, distribution)
}
#[cfg(target_os = "linux")]
fn get_linux_distribution() -> String {
use std::fs;
        // Read /etc/os-release
if let Ok(content) = fs::read_to_string("/etc/os-release") {
let mut name = None;
let mut version = None;
let mut pretty_name = None;
for line in content.lines() {
if let Some(value) = line.strip_prefix("NAME=") {
name = Some(value.trim_matches('"').to_string());
} else if let Some(value) = line.strip_prefix("VERSION=") {
version = Some(value.trim_matches('"').to_string());
} else if let Some(value) = line.strip_prefix("PRETTY_NAME=") {
pretty_name = Some(value.trim_matches('"').to_string());
}
}
            // Use PRETTY_NAME if present, otherwise NAME + VERSION
if let Some(pretty) = pretty_name {
return pretty;
} else if let (Some(n), Some(v)) = (name, version) {
return format!("{} {}", n, v);
}
}
        // Fall back to /etc/issue
if let Ok(content) = fs::read_to_string("/etc/issue") {
let first_line = content.lines().next().unwrap_or("").trim();
if !first_line.is_empty() && !first_line.contains("\\") {
return first_line.to_string();
}
}
String::new()
}
#[cfg(target_os = "windows")]
fn get_windows_details() -> (String, String) {
use std::process::Command;
let version = Command::new("cmd")
.args(&["/C", "ver"])
.output()
.ok()
.and_then(|output| {
if output.status.success() {
Some(String::from_utf8_lossy(&output.stdout).trim().to_string())
} else {
None
}
})
.unwrap_or_else(|| "Windows".to_string());
(version, String::new())
}
#[cfg(target_os = "macos")]
fn get_macos_details() -> (String, String) {
use std::process::Command;
let version = Command::new("sw_vers")
.arg("-productVersion")
.output()
.ok()
.and_then(|output| {
if output.status.success() {
Some(format!(
"macOS {}",
String::from_utf8_lossy(&output.stdout).trim()
))
} else {
None
}
})
.unwrap_or_else(|| "macOS".to_string());
(version, String::new())
}
}
impl Default for SystemInfo {
fn default() -> Self {
Self::collect()
}
}
impl Default for WorkspaceContext {
fn default() -> Self {
Self {
root_path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
nia_md_content: None,
project_type: None,
git_info: None,
has_nia_md: false,
project_name: None,
system_info: SystemInfo::default(),
}
}
}
impl Default for ModelCapabilities {
fn default() -> Self {
Self {
supports_tools: false,
supports_function_calling: false,
supports_vision: false,
supports_multimodal: None,
context_length: None,
capabilities: Vec::new(),
needs_verification: Some(false),
}
}
}
impl Default for SessionContext {
fn default() -> Self {
Self {
conversation_id: None,
message_count: 0,
active_tools: Vec::new(),
user_preferences: None,
}
}
}
impl Default for PromptRoleConfig {
fn default() -> Self {
let mut partials = HashMap::new();
partials.insert(
"role_definition".to_string(),
PartialConfig {
path: "./resources/prompts/cli-assistant.md".to_string(),
fallback: None,
description: Some("Default role definition".to_string()),
},
);
Self {
name: "default".to_string(),
description: "Default dynamic role configuration".to_string(),
version: Some("1.0.0".to_string()),
template: "{{>role_definition}}".to_string(),
partials: Some(partials),
variables: None,
conditions: None,
}
}
}

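For reference, a minimal YAML document these types accept, parsed through `ConfigParser` (a sketch; the module paths and partial paths are placeholders, and the `#nia/` prefix follows the resolve_path convention tested below):

use worker::config_parser::ConfigParser;
use worker::prompt_types::{PromptError, PromptRoleConfig};

fn load_example() -> Result<PromptRoleConfig, PromptError> {
    // Field names mirror PromptRoleConfig; "\n" inside the double-quoted
    // scalar is a YAML escape, so the template really contains a newline.
    let yaml = r##"
name: "cli-assistant"
description: "Example role"
version: "1.0.0"
template: "{{>role_definition}}\nUser request: {{user_input}}"
partials:
  role_definition:
    path: "#nia/prompts/cli-assistant.md"
    fallback: "./resources/prompts/cli-assistant.md"
conditions:
  xml_tools:
    when: "{{not model.supports_native_tools}}"
    variables:
      tool_format: "xml"
"##;
    ConfigParser::parse_from_string(yaml)
}
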
View File

@ -0,0 +1,252 @@
use crate::config_parser::ConfigParser;
use std::io::Write;
use tempfile::NamedTempFile;
#[test]
fn test_parse_basic_config() {
let yaml_content = r##"
name: "Test Assistant"
description: "A test configuration"
version: "1.0"
template: |
# Test Role
{{>role_header}}
{{#if workspace.has_nia_md}}
# Project Context
{{workspace.nia_md_content}}
{{/if}}
partials:
role_header:
path: "#nia/prompts/headers/role.md"
description: "Basic role definition"
variables:
max_context_length: "{{model.context_length}}"
project_name: "{{workspace.project_name}}"
"##;
let config =
ConfigParser::parse_from_string(yaml_content).expect("Failed to parse basic config");
assert_eq!(config.name, "Test Assistant");
assert_eq!(config.description, "A test configuration");
assert_eq!(config.version, Some("1.0".to_string()));
assert!(!config.template.is_empty());
    // Check the partials
let partials = config.partials.expect("Partials should be present");
assert!(partials.contains_key("role_header"));
let role_header = &partials["role_header"];
assert_eq!(role_header.path, "#nia/prompts/headers/role.md");
assert_eq!(
role_header.description,
Some("Basic role definition".to_string())
);
    // Check the variables
let variables = config.variables.expect("Variables should be present");
assert!(variables.contains_key("max_context_length"));
assert!(variables.contains_key("project_name"));
}
#[test]
fn test_parse_minimal_config() {
let yaml_content = r##"
name: "Minimal Assistant"
description: "A minimal configuration"
template: "Hello {{user_input}}"
"##;
let config =
ConfigParser::parse_from_string(yaml_content).expect("Failed to parse minimal config");
assert_eq!(config.name, "Minimal Assistant");
assert_eq!(config.description, "A minimal configuration");
assert_eq!(config.template, "Hello {{user_input}}");
assert!(config.partials.is_none());
assert!(config.variables.is_none());
assert!(config.conditions.is_none());
}
#[test]
fn test_parse_with_conditions() {
let yaml_content = r##"
name: "Conditional Assistant"
description: "Configuration with conditions"
template: "Base template"
conditions:
native_tools_enabled:
when: "{{model.supports_native_tools}}"
variables:
tool_format: "native"
include_tool_schemas: false
xml_tools_enabled:
when: "{{not model.supports_native_tools}}"
variables:
tool_format: "xml"
include_tool_schemas: true
"##;
let config = ConfigParser::parse_from_string(yaml_content)
.expect("Failed to parse config with conditions");
let conditions = config.conditions.expect("Conditions should be present");
assert!(conditions.contains_key("native_tools_enabled"));
assert!(conditions.contains_key("xml_tools_enabled"));
let native_condition = &conditions["native_tools_enabled"];
assert_eq!(native_condition.when, "{{model.supports_native_tools}}");
let variables = native_condition
.variables
.as_ref()
.expect("Variables should be present");
assert_eq!(variables.get("tool_format"), Some(&"native".to_string()));
assert_eq!(
variables.get("include_tool_schemas"),
Some(&"false".to_string())
);
}
#[test]
fn test_validation_errors() {
    // Empty name
let invalid_yaml = r##"
name: ""
description: "Test"
template: "Test template"
"##;
let result = ConfigParser::parse_from_string(invalid_yaml);
assert!(result.is_err());
    // Empty template
let invalid_yaml = r##"
name: "Test"
description: "Test"
template: ""
"##;
let result = ConfigParser::parse_from_string(invalid_yaml);
assert!(result.is_err());
    // Empty partial path
let invalid_yaml = r##"
name: "Test"
description: "Test"
template: "Test template"
partials:
empty_path:
path: ""
"##;
let result = ConfigParser::parse_from_string(invalid_yaml);
assert!(result.is_err());
}
#[test]
fn test_parse_from_file() {
let yaml_content = r##"
name: "File Test Assistant"
description: "Testing file parsing"
template: "File content {{user_input}}"
"##;
let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");
temp_file
.write_all(yaml_content.as_bytes())
.expect("Failed to write to temp file");
let config =
ConfigParser::parse_from_file(temp_file.path()).expect("Failed to parse config from file");
assert_eq!(config.name, "File Test Assistant");
assert_eq!(config.description, "Testing file parsing");
assert_eq!(config.template, "File content {{user_input}}");
}
#[test]
fn test_resolve_path() {
// #nia/ prefix
let path =
ConfigParser::resolve_path("#nia/prompts/test.md").expect("Failed to resolve nia path");
assert!(
path.to_string_lossy()
.contains("nia-cli/resources/prompts/test.md")
);
// #workspace/ prefix
let path = ConfigParser::resolve_path("#workspace/config.md")
.expect("Failed to resolve workspace path");
assert!(path.to_string_lossy().contains(".nia/config.md"));
// #user/ prefix
let path =
ConfigParser::resolve_path("#user/settings.md").expect("Failed to resolve user path");
assert!(path.to_string_lossy().contains("settings.md"));
// Regular path
let path =
ConfigParser::resolve_path("regular/path.md").expect("Failed to resolve regular path");
assert_eq!(path.to_string_lossy(), "regular/path.md");
}
#[test]
fn test_complex_template_syntax() {
let yaml_content = r##"
name: "Complex Template Assistant"
description: "Testing complex Handlebars syntax"
template: |
# Dynamic Role
{{>role_header}}
{{#if workspace.has_nia_md}}
# Project: {{workspace.project_name}}
{{workspace.nia_md_content}}
{{/if}}
{{#if_native_tools model.supports_native_tools}}
Native tools are supported.
{{else}}
Using XML-based tool calls.
Available tools:
```json
{{tools_schema}}
```
{{/if_native_tools}}
{{#model_specific model.provider}}
{{#case "Claude"}}
Claude-specific instructions here.
{{/case}}
{{#case "Gemini"}}
Gemini-specific instructions here.
{{/case}}
{{#default}}
Generic model instructions.
{{/default}}
{{/model_specific}}
partials:
role_header:
path: "#nia/prompts/headers/role.md"
"##;
let config =
ConfigParser::parse_from_string(yaml_content).expect("Failed to parse complex template");
assert!(!config.template.is_empty());
assert!(config.template.contains("{{>role_header}}"));
assert!(config.template.contains("{{#if workspace.has_nia_md}}"));
assert!(
config
.template
.contains("{{#if_native_tools model.supports_native_tools}}")
);
assert!(
config
.template
.contains("{{#model_specific model.provider}}")
);
}

View File

@ -0,0 +1,333 @@
use crate::prompt_composer::PromptComposer;
use crate::prompt_types::*;
use crate::types::{LlmProvider, Message, Role};
use crate::workspace_detector::WorkspaceDetector;
use std::collections::HashMap;
use std::fs;
use tempfile::TempDir;
#[test]
#[ignore] // Temporarily disabled due to missing dependencies
fn test_full_dynamic_prompt_composition() {
    // Create a temporary directory for the test
let temp_dir = TempDir::new().expect("Failed to create temp dir");
let temp_path = temp_dir.path();
    // Create a test NIA.md file
let nia_md_content = r#"# Test Project
This is a test project for dynamic prompt composition.
## Features
- Dynamic prompt generation
- Workspace detection
- Model-specific optimizations
"#;
fs::write(temp_path.join("NIA.md"), nia_md_content).expect("Failed to write NIA.md");
    // Create a test Cargo.toml (so the directory is recognized as a Rust project)
let cargo_toml = r#"[package]
name = "test-project"
version = "0.1.0"
edition = "2021"
"#;
fs::write(temp_path.join("Cargo.toml"), cargo_toml).expect("Failed to write Cargo.toml");
    // Get the workspace context
let workspace = WorkspaceDetector::detect_workspace_from_path(temp_path)
.expect("Failed to detect workspace");
assert!(workspace.has_nia_md);
assert_eq!(workspace.project_type, Some(ProjectType::Rust));
assert!(workspace.nia_md_content.is_some());
assert_eq!(workspace.project_name, Some("test-project".to_string()));
    // Create the model context
let model_context = ModelContext {
provider: LlmProvider::Claude,
model_name: "claude-3-sonnet".to_string(),
capabilities: ModelCapabilities {
supports_tools: true,
supports_function_calling: true,
..Default::default()
},
supports_native_tools: true,
};
    // Create the session context
let session_context = SessionContext {
conversation_id: Some("test-conv".to_string()),
message_count: 1,
active_tools: vec!["file_read".to_string(), "file_write".to_string()],
user_preferences: None,
};
    // Create the overall prompt context
let prompt_context = PromptContext {
workspace,
model: model_context,
session: session_context,
variables: HashMap::new(),
};
    // Create a dynamic configuration
    let config = create_test_dynamic_config();
    // Create a PromptComposer
    let mut composer = PromptComposer::from_config(config, prompt_context)
        .expect("Failed to create composer");
    // Test message
let messages = vec![Message {
role: Role::User,
content: "Please help me understand the project structure".to_string(),
}];
    // Compose the prompt
let result = composer
.compose(&messages)
.expect("Failed to compose prompt");
assert!(!result.is_empty());
assert_eq!(result[0].role, Role::System);
    // Verify the generated prompt contains the workspace information
let system_prompt = &result[0].content;
assert!(system_prompt.contains("Test Project"));
assert!(system_prompt.contains("dynamic prompt generation"));
assert!(system_prompt.contains("test-project"));
}
#[test]
#[ignore] // Temporarily disabled due to missing dependencies
fn test_native_tools_vs_xml_tools() {
let temp_dir = TempDir::new().expect("Failed to create temp dir");
let workspace = WorkspaceDetector::detect_workspace_from_path(temp_dir.path())
.expect("Failed to detect workspace");
    // Model with native tool support
let native_model = ModelContext {
provider: LlmProvider::Claude,
model_name: "claude-3-sonnet".to_string(),
capabilities: ModelCapabilities {
supports_tools: true,
supports_function_calling: true,
..Default::default()
},
supports_native_tools: true,
};
    // Model restricted to XML tools
let xml_model = ModelContext {
provider: LlmProvider::Ollama,
model_name: "llama3".to_string(),
capabilities: ModelCapabilities {
supports_tools: false,
supports_function_calling: false,
..Default::default()
},
supports_native_tools: false,
};
let session = SessionContext::default();
    // Generate prompts for both models
let native_context = PromptContext {
workspace: workspace.clone(),
model: native_model,
session: session.clone(),
variables: HashMap::new(),
};
let xml_context = PromptContext {
workspace: workspace.clone(),
model: xml_model,
session: session.clone(),
variables: HashMap::new(),
};
let config = create_test_dynamic_config();
    let mut native_composer = PromptComposer::from_config(config.clone(), native_context)
.expect("Failed to create native composer");
    let mut xml_composer = PromptComposer::from_config(config, xml_context)
.expect("Failed to create xml composer");
let messages = vec![Message {
role: Role::User,
content: "Test message".to_string(),
}];
let native_result = native_composer
.compose(&messages)
.expect("Failed to compose native prompt");
let xml_result = xml_composer
.compose(&messages)
.expect("Failed to compose xml prompt");
    // Verify that both prompts are generated
assert!(!native_result.is_empty());
assert!(!xml_result.is_empty());
    // Verify the native-tool prompt differs from the XML-tool prompt
assert_ne!(native_result[0].content, xml_result[0].content);
}
#[test]
#[ignore] // Temporarily disabled due to missing dependencies
fn test_workspace_detection_without_nia_md() {
let temp_dir = TempDir::new().expect("Failed to create temp dir");
    // Create only the .nia directory (no NIA.md)
fs::create_dir(temp_dir.path().join(".nia")).expect("Failed to create .nia dir");
let workspace = WorkspaceDetector::detect_workspace_from_path(temp_dir.path())
.expect("Failed to detect workspace");
assert!(!workspace.has_nia_md);
assert!(workspace.nia_md_content.is_none());
assert_eq!(workspace.project_type, Some(ProjectType::Unknown));
}
#[test]
#[ignore] // Temporarily disabled due to missing dependencies
fn test_project_type_detection() {
let temp_dir = TempDir::new().expect("Failed to create temp dir");
let temp_path = temp_dir.path();
    // TypeScript project
fs::write(temp_path.join("package.json"), r#"{"name": "test"}"#)
.expect("Failed to write package.json");
fs::write(temp_path.join("tsconfig.json"), "{}").expect("Failed to write tsconfig.json");
let workspace = WorkspaceDetector::detect_workspace_from_path(temp_path)
.expect("Failed to detect workspace");
assert_eq!(workspace.project_type, Some(ProjectType::TypeScript));
}
#[test]
#[ignore] // Temporarily disabled due to missing dependencies
fn test_tools_schema_integration() {
let temp_dir = TempDir::new().expect("Failed to create temp dir");
let workspace = WorkspaceDetector::detect_workspace_from_path(temp_dir.path())
.expect("Failed to detect workspace");
let model_context = ModelContext {
provider: LlmProvider::Gemini,
model_name: "gemini-1.5-flash".to_string(),
capabilities: ModelCapabilities {
supports_tools: false,
supports_function_calling: false,
..Default::default()
},
supports_native_tools: false,
};
let session_context = SessionContext::default();
let prompt_context = PromptContext {
workspace,
model: model_context,
session: session_context,
variables: HashMap::new(),
};
let config = create_test_dynamic_config();
    let mut composer = PromptComposer::from_config(config, prompt_context)
.expect("Failed to create composer");
    // Create a tools schema
let tools_schema = serde_json::json!([
{
"name": "file_read",
"description": "Read a file",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string"}
}
}
}
]);
let messages = vec![Message {
role: Role::User,
content: "Read a file for me".to_string(),
}];
    // Compose the prompt with tool information
let result = composer
.compose_with_tools(&messages, &tools_schema)
.expect("Failed to compose with tools");
assert!(!result.is_empty());
    // This is an XML-tools model, so tool info should appear in the prompt
let system_prompt = &result[0].content;
assert!(system_prompt.contains("file_read"));
}
// Create the dynamic config used by the tests above
fn create_test_dynamic_config() -> PromptRoleConfig {
let mut variables = HashMap::new();
variables.insert(
"project_name".to_string(),
"{{workspace.project_name}}".to_string(),
);
variables.insert("model_name".to_string(), "{{model.model_name}}".to_string());
let mut conditions = HashMap::new();
    // Native-tools condition
let mut native_vars = HashMap::new();
native_vars.insert("tool_format".to_string(), "native".to_string());
conditions.insert(
"native_tools".to_string(),
ConditionConfig {
when: "{{model.supports_native_tools}}".to_string(),
variables: Some(native_vars),
template_override: None,
},
);
    // XML-tools condition
let mut xml_vars = HashMap::new();
xml_vars.insert("tool_format".to_string(), "xml".to_string());
conditions.insert(
"xml_tools".to_string(),
ConditionConfig {
when: "{{not model.supports_native_tools}}".to_string(),
variables: Some(xml_vars),
template_override: None,
},
);
    PromptRoleConfig {
name: "Test Assistant".to_string(),
description: "A test configuration".to_string(),
version: Some("1.0".to_string()),
template: r#"# Test Role
{{#if workspace.has_nia_md}}
# Project Context
Project: {{workspace.project_name}}
{{workspace.nia_md_content}}
{{/if}}
{{#if model.supports_native_tools}}
Native tools are supported for {{model.model_name}}.
{{else}}
Using XML-based tool calls for {{model.model_name}}.
{{#if tools_schema}}
Available tools: {{tools_schema}}
{{/if}}
{{/if}}
User request: {{user_input}}
"#
.to_string(),
        partials: None, // no partials used
variables: Some(variables),
conditions: Some(conditions),
}
}

50
worker/src/types.rs Normal file
View File

@ -0,0 +1,50 @@
// Re-export all types from worker-types for backwards compatibility
pub use worker_types::*;
// Worker-specific error type
#[derive(Debug, thiserror::Error)]
pub enum WorkerError {
#[error("Tool execution failed: {0}")]
ToolExecution(String),
#[error("Tool execution error: {0}")]
ToolExecutionError(String),
#[error("LLM API error: {0}")]
LlmApiError(String),
#[error("Model not found: {0}")]
ModelNotFound(String),
#[error("JSON serialization/deserialization error: {0}")]
JsonError(#[from] serde_json::Error),
#[error("Serialization error: {0}")]
Serialization(serde_json::Error),
#[error("Network error: {0}")]
Network(String),
#[error("Configuration error: {0}")]
Config(String),
#[error("Configuration error: {0}")]
ConfigurationError(String),
#[error("General error: {0}")]
General(#[from] anyhow::Error),
#[error("Box error: {0}")]
BoxError(Box<dyn std::error::Error + Send + Sync>),
}
impl From<&str> for WorkerError {
fn from(s: &str) -> Self {
WorkerError::General(anyhow::anyhow!(s.to_string()))
}
}
impl From<String> for WorkerError {
fn from(s: String) -> Self {
WorkerError::General(anyhow::anyhow!(s))
}
}
impl From<Box<dyn std::error::Error + Send + Sync>> for WorkerError {
fn from(e: Box<dyn std::error::Error + Send + Sync>) -> Self {
WorkerError::BoxError(e)
}
}
// Update ToolResult to use WorkerError
pub type WorkerToolResult<T> = Result<T, WorkerError>;

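A small sketch (illustrative only) of how the From impls above compose with the `?` operator, assuming the `types` module is public:

use worker::types::WorkerToolResult;

fn parse_object(payload: &str) -> WorkerToolResult<serde_json::Value> {
    // serde_json::Error converts into WorkerError::JsonError via #[from]
    let value: serde_json::Value = serde_json::from_str(payload)?;
    if !value.is_object() {
        // &str converts into WorkerError::General through the From impl above
        return Err("expected a JSON object".into());
    }
    Ok(value)
}
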
209
worker/src/url_config.rs Normal file
View File

@ -0,0 +1,209 @@
//! URL configuration for LLM providers, with environment-variable override support
use std::env;
pub struct UrlConfig;
impl UrlConfig {
/// Get base URL for a provider with environment variable override support
pub fn get_base_url(provider: &str) -> String {
let env_var = format!("{}_BASE_URL", provider.to_uppercase());
// Check environment variable first
if let Ok(url) = env::var(&env_var) {
return url;
}
// Return default URLs (domain part only)
match provider.to_lowercase().as_str() {
"openai" => "https://api.openai.com".to_string(),
"anthropic" | "claude" => "https://api.anthropic.com".to_string(),
"gemini" | "google" => "https://generativelanguage.googleapis.com".to_string(),
"xai" => "https://api.x.ai".to_string(),
"ollama" => "http://localhost:11434".to_string(),
_ => panic!("Unknown LLM provider: {}", provider),
}
}
/// Get models endpoint URL for a provider
pub fn get_models_url(provider: &str) -> String {
let base_url = Self::get_base_url(provider);
match provider.to_lowercase().as_str() {
"openai" => format!("{}/v1/models", base_url),
"anthropic" | "claude" => format!("{}/v1/models", base_url),
"gemini" | "google" => format!("{}/v1beta/models", base_url),
"xai" => format!("{}/v1/models", base_url),
"ollama" => format!("{}/api/tags", base_url),
_ => panic!("Unknown LLM provider: {}", provider),
}
}
/// Get chat/completion endpoint URL for a provider
pub fn get_completion_url(provider: &str) -> String {
let base_url = Self::get_base_url(provider);
match provider.to_lowercase().as_str() {
"openai" => format!("{}/v1/chat/completions", base_url),
"anthropic" | "claude" => format!("{}/v1/messages", base_url),
"gemini" | "google" => format!("{}/v1beta/models/{{model}}:generateContent", base_url),
"xai" => format!("{}/v1/chat/completions", base_url),
"ollama" => format!("{}/api/chat", base_url),
_ => panic!("Unknown LLM provider: {}", provider),
}
}
/// Get model-specific endpoint URL for a provider
pub fn get_model_url(provider: &str, model_id: &str) -> String {
let base_url = Self::get_base_url(provider);
match provider.to_lowercase().as_str() {
"openai" => format!("{}/v1/models/{}", base_url, model_id),
"anthropic" | "claude" => format!("{}/v1/models/{}", base_url, model_id),
"gemini" | "google" => format!("{}/v1beta/models/{}", base_url, model_id),
"xai" => format!("{}/v1/models/{}", base_url, model_id),
"ollama" => format!("{}/api/show", base_url), // Ollama uses different pattern
_ => panic!("Unknown LLM provider: {}", provider),
}
}
/// Get all active URL overrides from environment variables
pub fn get_active_overrides() -> Vec<(String, String)> {
let providers = ["openai", "anthropic", "gemini", "xai", "ollama"];
let mut overrides = Vec::new();
for provider in providers {
let env_var = format!("{}_BASE_URL", provider.to_uppercase());
if let Ok(url) = env::var(&env_var) {
overrides.push((provider.to_string(), url));
}
}
overrides
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::env;
#[test]
fn test_default_urls() {
// Clean up any existing env vars first
env::remove_var("OPENAI_BASE_URL");
env::remove_var("ANTHROPIC_BASE_URL");
env::remove_var("GEMINI_BASE_URL");
env::remove_var("XAI_BASE_URL");
env::remove_var("OLLAMA_BASE_URL");
assert_eq!(UrlConfig::get_base_url("openai"), "https://api.openai.com");
assert_eq!(
UrlConfig::get_base_url("anthropic"),
"https://api.anthropic.com"
);
assert_eq!(
UrlConfig::get_base_url("gemini"),
"https://generativelanguage.googleapis.com"
);
assert_eq!(UrlConfig::get_base_url("xai"), "https://api.x.ai");
assert_eq!(UrlConfig::get_base_url("ollama"), "http://localhost:11434");
}
#[test]
fn test_env_override() {
// Clean up any existing env vars first
env::remove_var("OPENAI_BASE_URL");
env::remove_var("ANTHROPIC_BASE_URL");
env::set_var("OPENAI_BASE_URL", "https://custom.openai.com");
env::set_var("ANTHROPIC_BASE_URL", "https://custom.anthropic.com");
assert_eq!(
UrlConfig::get_base_url("openai"),
"https://custom.openai.com"
);
assert_eq!(
UrlConfig::get_base_url("anthropic"),
"https://custom.anthropic.com"
);
// Clean up
env::remove_var("OPENAI_BASE_URL");
env::remove_var("ANTHROPIC_BASE_URL");
}
#[test]
fn test_models_url() {
// Clean up any existing env vars first
env::remove_var("OPENAI_BASE_URL");
env::remove_var("ANTHROPIC_BASE_URL");
env::remove_var("OLLAMA_BASE_URL");
assert_eq!(
UrlConfig::get_models_url("openai"),
"https://api.openai.com/v1/models"
);
assert_eq!(
UrlConfig::get_models_url("anthropic"),
"https://api.anthropic.com/v1/models"
);
assert_eq!(
UrlConfig::get_models_url("ollama"),
"http://localhost:11434/api/tags"
);
}
#[test]
fn test_completion_url() {
// Clean up any existing env vars first
env::remove_var("OPENAI_BASE_URL");
env::remove_var("ANTHROPIC_BASE_URL");
env::remove_var("OLLAMA_BASE_URL");
assert_eq!(
UrlConfig::get_completion_url("openai"),
"https://api.openai.com/v1/chat/completions"
);
assert_eq!(
UrlConfig::get_completion_url("anthropic"),
"https://api.anthropic.com/v1/messages"
);
assert_eq!(
UrlConfig::get_completion_url("ollama"),
"http://localhost:11434/api/chat"
);
}
#[test]
fn test_get_active_overrides() {
// Clean up any existing env vars first
env::remove_var("OPENAI_BASE_URL");
env::remove_var("ANTHROPIC_BASE_URL");
env::remove_var("GEMINI_BASE_URL");
env::remove_var("XAI_BASE_URL");
env::remove_var("OLLAMA_BASE_URL");
// Should return empty when no overrides are set
assert_eq!(UrlConfig::get_active_overrides().len(), 0);
// Set some overrides
env::set_var("OPENAI_BASE_URL", "https://custom-openai.example.com");
env::set_var("ANTHROPIC_BASE_URL", "https://custom-anthropic.example.com");
let overrides = UrlConfig::get_active_overrides();
assert_eq!(overrides.len(), 2);
// Check if both providers are in the overrides
let providers: Vec<String> = overrides.iter().map(|(p, _)| p.clone()).collect();
assert!(providers.contains(&"openai".to_string()));
assert!(providers.contains(&"anthropic".to_string()));
// Check URLs
let openai_override = overrides.iter().find(|(p, _)| p == "openai").unwrap();
assert_eq!(openai_override.1, "https://custom-openai.example.com");
let anthropic_override = overrides.iter().find(|(p, _)| p == "anthropic").unwrap();
assert_eq!(anthropic_override.1, "https://custom-anthropic.example.com");
// Clean up
env::remove_var("OPENAI_BASE_URL");
env::remove_var("ANTHROPIC_BASE_URL");
}
}

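One wrinkle worth a usage sketch: `get_completion_url("gemini")` returns the URL with a literal `{model}` placeholder, which the caller substitutes before issuing the request (assuming no `GEMINI_BASE_URL` override is set; the module path is an assumption):

use worker::url_config::UrlConfig;

fn gemini_completion_url(model: &str) -> String {
    // The returned URL contains a literal "{model}" placeholder
    UrlConfig::get_completion_url("gemini").replace("{model}", model)
}

// gemini_completion_url("gemini-1.5-flash")
// => "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent"
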
315
worker/src/workspace_detector.rs Normal file
View File

@ -0,0 +1,315 @@
use crate::prompt_types::*;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
/// Workspace detection and project information collection
pub struct WorkspaceDetector;
impl WorkspaceDetector {
    /// Detect the workspace from the current directory and build its context
pub fn detect_workspace() -> Result<WorkspaceContext, PromptError> {
let current_dir =
std::env::current_dir().map_err(|e| PromptError::WorkspaceDetection(e.to_string()))?;
Self::detect_workspace_from_path(&current_dir)
}
    /// Detect the workspace from the given path
pub fn detect_workspace_from_path(start_path: &Path) -> Result<WorkspaceContext, PromptError> {
        // 1. Determine the project root
let root_path = Self::find_project_root(start_path)?;
        // 2. Read .nia/context.md
let nia_md_content = Self::read_nia_md(&root_path);
let has_nia_md = nia_md_content.is_some();
        // 3. Infer the project type
let project_type = Self::detect_project_type(&root_path);
        // 4. Collect Git information
let git_info = Self::get_git_info(&root_path);
        // 5. Determine the project name
let project_name = Self::determine_project_name(&root_path, &git_info);
        // 6. Collect system information
let system_info = crate::prompt_types::SystemInfo::default();
Ok(WorkspaceContext {
root_path,
nia_md_content,
project_type,
git_info,
has_nia_md,
project_name,
system_info,
})
}
    /// Detect the project root (checking Git, then .nia, then the current directory)
fn find_project_root(start_path: &Path) -> Result<PathBuf, PromptError> {
let mut current = start_path.to_path_buf();
loop {
            // Check for a Git repository root
if current.join(".git").exists() {
return Ok(current);
}
            // Check for a .nia directory
if current.join(".nia").exists() {
return Ok(current);
}
            // Move up to the parent directory
match current.parent() {
Some(parent) => current = parent.to_path_buf(),
None => break,
}
}
        // If nothing was found, return the starting path
Ok(start_path.to_path_buf())
}
    /// Read the .nia/context.md file
fn read_nia_md(root_path: &Path) -> Option<String> {
let file_path = root_path.join(".nia/context.md");
if let Ok(content) = fs::read_to_string(&file_path) {
            // Make sure the file size is reasonable (10 MB or less)
if content.len() <= 10 * 1024 * 1024 {
return Some(content);
}
}
None
}
    /// Infer the project type
fn detect_project_type(root_path: &Path) -> Option<ProjectType> {
        // Determine the project type from which files exist
if root_path.join("Cargo.toml").exists() {
return Some(ProjectType::Rust);
}
if root_path.join("package.json").exists() {
            // Decide between TypeScript and JavaScript
if root_path.join("tsconfig.json").exists()
|| root_path.join("src").join("index.ts").exists()
|| Self::check_typescript_files(root_path)
{
return Some(ProjectType::TypeScript);
}
return Some(ProjectType::JavaScript);
}
if root_path.join("pyproject.toml").exists()
|| root_path.join("setup.py").exists()
|| root_path.join("requirements.txt").exists()
{
return Some(ProjectType::Python);
}
if root_path.join("go.mod").exists() {
return Some(ProjectType::Go);
}
if root_path.join("pom.xml").exists()
|| root_path.join("build.gradle").exists()
|| root_path.join("build.gradle.kts").exists()
{
return Some(ProjectType::Java);
}
if root_path.join("CMakeLists.txt").exists() || root_path.join("Makefile").exists() {
return Some(ProjectType::Cpp);
}
Some(ProjectType::Unknown)
}
    /// Check for the presence of TypeScript files
fn check_typescript_files(root_path: &Path) -> bool {
        // Check for .ts files in the src directory
let src_dir = root_path.join("src");
if src_dir.exists() {
if let Ok(entries) = fs::read_dir(&src_dir) {
for entry in entries.flatten() {
if let Some(ext) = entry.path().extension() {
if ext == "ts" || ext == "tsx" {
return true;
}
}
}
}
}
false
}
    /// Collect Git information
fn get_git_info(root_path: &Path) -> Option<GitInfo> {
if !root_path.join(".git").exists() {
return None;
}
let repo_name = Self::get_git_repo_name(root_path);
let current_branch = Self::get_git_current_branch(root_path);
let last_commit_summary = Self::get_git_last_commit(root_path);
let is_clean = Self::is_git_clean(root_path);
Some(GitInfo {
repo_name,
current_branch,
last_commit_summary,
is_clean,
})
}
    /// Get the Git repository name
fn get_git_repo_name(root_path: &Path) -> Option<String> {
        // Get the name from the remote URL
let output = Command::new("git")
.args(&["remote", "get-url", "origin"])
.current_dir(root_path)
.output()
.ok()?;
if output.status.success() {
let url = String::from_utf8_lossy(&output.stdout).trim().to_string();
return Self::extract_repo_name_from_url(&url);
}
        // Fallback: use the directory name
root_path
.file_name()
.and_then(|name| name.to_str())
.map(|s| s.to_string())
}
    /// Extract the repository name from a Git URL
    fn extract_repo_name_from_url(url: &str) -> Option<String> {
        // SSH form first (git@github.com:user/repo.git); the generic pattern
        // below would otherwise capture the "host:user" prefix as well
        if let Some(captures) = regex::Regex::new(r":([^/]+/[^/]+?)(?:\.git)?$")
            .ok()?
            .captures(url)
        {
            return Some(captures[1].to_string());
        }
        // Common HTTPS-style patterns for GitHub/GitLab/Bitbucket and friends
        if let Some(captures) = regex::Regex::new(r"([^/]+/[^/]+?)(?:\.git)?$")
            .ok()?
            .captures(url)
        {
            return Some(captures[1].to_string());
        }
        None
    }
    /// Get the current Git branch
fn get_git_current_branch(root_path: &Path) -> Option<String> {
let output = Command::new("git")
.args(&["branch", "--show-current"])
.current_dir(root_path)
.output()
.ok()?;
if output.status.success() {
let branch = String::from_utf8_lossy(&output.stdout).trim().to_string();
if !branch.is_empty() {
return Some(branch);
}
}
None
}
    /// Get a summary of the most recent commit
fn get_git_last_commit(root_path: &Path) -> Option<String> {
let output = Command::new("git")
.args(&["log", "-1", "--pretty=format:%s"])
.current_dir(root_path)
.output()
.ok()?;
if output.status.success() {
let commit = String::from_utf8_lossy(&output.stdout).trim().to_string();
if !commit.is_empty() {
return Some(commit);
}
}
None
}
    /// Check whether the Git working tree is clean
fn is_git_clean(root_path: &Path) -> Option<bool> {
let output = Command::new("git")
.args(&["status", "--porcelain"])
.current_dir(root_path)
.output()
.ok()?;
if output.status.success() {
let status = String::from_utf8_lossy(&output.stdout);
return Some(status.trim().is_empty());
}
None
}
    /// Determine the project name
fn determine_project_name(root_path: &Path, git_info: &Option<GitInfo>) -> Option<String> {
        // 1. Use the Git repository name
if let Some(git) = git_info {
if let Some(repo_name) = &git.repo_name {
return Some(repo_name.clone());
}
}
        // 2. Use the name field from Cargo.toml
if let Some(cargo_name) = Self::get_cargo_project_name(root_path) {
return Some(cargo_name);
}
        // 3. Use the name field from package.json
if let Some(npm_name) = Self::get_npm_project_name(root_path) {
return Some(npm_name);
}
        // 4. Use the directory name
root_path
.file_name()
.and_then(|name| name.to_str())
.map(|s| s.to_string())
}
    /// Get the project name from Cargo.toml
fn get_cargo_project_name(root_path: &Path) -> Option<String> {
let cargo_toml_path = root_path.join("Cargo.toml");
let content = fs::read_to_string(&cargo_toml_path).ok()?;
        // Extract the name field with a simple line-based parse
for line in content.lines() {
if let Some(captures) = regex::Regex::new(r#"name\s*=\s*"([^"]+)""#)
.ok()?
.captures(line)
{
return Some(captures[1].to_string());
}
}
None
}
    /// Get the project name from package.json
fn get_npm_project_name(root_path: &Path) -> Option<String> {
let package_json_path = root_path.join("package.json");
let content = fs::read_to_string(&package_json_path).ok()?;
        // Parse the JSON and read the name field
let package_json: serde_json::Value = serde_json::from_str(&content).ok()?;
package_json.get("name")?.as_str().map(|s| s.to_string())
}
}
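
A closing usage sketch (illustrative; assumes these modules are public in the `worker` crate):

use worker::prompt_types::PromptError;
use worker::workspace_detector::WorkspaceDetector;

fn show_workspace() -> Result<(), PromptError> {
    // Walks up from the current directory to the Git or .nia root
    let ws = WorkspaceDetector::detect_workspace()?;
    println!("root: {}", ws.root_path.display());
    println!("project: {:?} ({:?})", ws.project_name, ws.project_type);
    if let Some(git) = &ws.git_info {
        println!("branch: {:?}, clean: {:?}", git.current_branch, git.is_clean);
    }
    Ok(())
}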