llm_worker_rs/llm-worker/tests/tool_macro_test.rs

248 lines
6.6 KiB
Rust

//! ツールマクロのテスト
//!
//! `#[tool_registry]` と `#[tool]` マクロの動作を確認する。
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
// マクロ展開に必要なインポート
use schemars;
use serde;
use llm_worker::tool::{Tool, ToolMeta};
use llm_worker_macros::tool_registry;
// =============================================================================
// Test: Basic Tool Generation
// =============================================================================
/// Simple context struct used to exercise basic `#[tool]` generation.
///
/// Its `prefix` field is read by the `greet` and `get_prefix` tools.
// NOTE(review): `Clone` is presumably required by the factories that
// `#[tool_registry]` generates — confirm against the macro.
#[derive(Clone)]
struct SimpleContext {
    /// Text that `greet` prepends to messages and `get_prefix` returns.
    prefix: String,
}
// Registers each `#[tool]` method below as a tool. As the tests in this file
// show, the macro generates a `<method>_definition()` accessor per method;
// calling the returned factory yields a `(ToolMeta, Tool)` pair. Method names
// become tool names and parameter names become JSON-schema properties.
//
// NOTE: the `///` comments on the methods are intentionally left untranslated.
// The macro copies them into `ToolMeta::description`, and
// `test_basic_tool_generation` asserts on the Japanese text of `greet`'s doc.
#[tool_registry]
impl SimpleContext {
    // (EN) "Adds a greeting to a message. Returns the given message with the
    // prefix attached."
    /// メッセージに挨拶を追加する
    ///
    /// 指定されたメッセージにプレフィックスを付けて返します。
    #[tool]
    async fn greet(&self, message: String) -> String {
        // Produces "<prefix>: <message>".
        format!("{}: {}", self.prefix, message)
    }
    // (EN) "Adds two numbers."
    /// 二つの数を足す
    #[tool]
    async fn add(&self, a: i32, b: i32) -> i32 {
        a + b
    }
    // (EN) "A tool with no arguments."
    /// 引数なしのツール
    #[tool]
    async fn get_prefix(&self) -> String {
        // Clones because the tool must return an owned `String`.
        self.prefix.clone()
    }
}
/// Verifies that `#[tool]` produces the expected metadata for `greet` (name,
/// doc-comment-derived description, JSON schema with properties) and a tool
/// that executes with JSON arguments.
#[tokio::test]
async fn test_basic_tool_generation() {
    let ctx = SimpleContext {
        prefix: "Hello".to_string(),
    };

    // The generated accessor returns a factory; calling the factory yields
    // the metadata and the executable tool.
    let greet_definition = ctx.greet_definition();
    let (meta, tool) = greet_definition();

    // Metadata: the name comes from the method, the description from its doc comment.
    assert_eq!(meta.name, "greet");
    assert!(
        meta.description.contains("メッセージに挨拶を追加する"),
        "Description should contain doc comment: {}",
        meta.description
    );
    assert!(
        meta.input_schema.get("properties").is_some(),
        "Schema should have properties"
    );
    // Printed for manual inspection (`cargo test -- --nocapture`).
    println!(
        "Schema: {}",
        serde_json::to_string_pretty(&meta.input_schema).unwrap()
    );

    // Execution: `expect` reports the tool's actual error on failure instead
    // of a bare "assertion failed".
    let output = tool
        .execute(r#"{"message": "World"}"#)
        .await
        .expect("Should execute successfully");
    assert!(output.contains("Hello"), "Output should contain prefix: {}", output);
    assert!(output.contains("World"), "Output should contain message: {}", output);
}
/// Verifies that a tool with several parameters receives each one from the
/// JSON arguments object.
#[tokio::test]
async fn test_multiple_arguments() {
    let ctx = SimpleContext {
        prefix: String::new(),
    };
    let (meta, tool) = ctx.add_definition()();
    assert_eq!(meta.name, "add");

    // `expect` surfaces the execution error (if any) in the panic message.
    let output = tool
        .execute(r#"{"a": 10, "b": 20}"#)
        .await
        .expect("add should execute successfully");
    assert!(output.contains("30"), "Should contain sum: {}", output);
}
/// Verifies that a tool with no parameters accepts an empty JSON object and
/// can read the context's state.
#[tokio::test]
async fn test_no_arguments() {
    let ctx = SimpleContext {
        prefix: "TestPrefix".to_string(),
    };
    let (meta, tool) = ctx.get_prefix_definition()();
    assert_eq!(meta.name, "get_prefix");

    // A parameterless tool is called with an empty JSON object.
    let output = tool
        .execute(r#"{}"#)
        .await
        .expect("get_prefix should execute successfully");
    assert!(
        output.contains("TestPrefix"),
        "Should contain prefix: {}",
        output
    );
}
/// Verifies that `execute` returns an error when the arguments do not match
/// the tool's parameters or are not JSON at all.
#[tokio::test]
async fn test_invalid_arguments() {
    let ctx = SimpleContext {
        prefix: String::new(),
    };
    let (_, tool) = ctx.greet_definition()();

    // Well-formed JSON, but the required `message` field is missing.
    let result = tool.execute(r#"{"wrong_field": "value"}"#).await;
    assert!(
        result.is_err(),
        "Should fail with invalid arguments: {:?}",
        result
    );

    // Input that is not JSON at all must also be rejected.
    let result = tool.execute("not json").await;
    assert!(
        result.is_err(),
        "Should fail with malformed JSON: {:?}",
        result
    );
}
// =============================================================================
// Test: Result Return Type
// =============================================================================
/// Stateless context whose `validate` tool returns a `Result`. It is used to
/// test how `Ok` and `Err` return values surface from the generated tool.
#[derive(Clone)]
struct FallibleContext;
/// Error type returned by `FallibleContext::validate`.
///
/// Its `Display` output is exactly the wrapped message.
#[derive(Debug)]
struct MyError(String);

impl std::fmt::Display for MyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Write the message verbatim; formatter width/padding is not applied.
        f.write_str(&self.0)
    }
}
// Registers `validate`, a tool whose Rust return type is `Result`. The tests
// below show that an `Ok` value becomes successful tool output, and that an
// `Err` becomes an execution error whose string form contains the error message.
//
// NOTE: the `///` doc comment is left untranslated because the macro uses it
// as the tool description.
#[tool_registry]
impl FallibleContext {
    // (EN) "Validates the given value."
    /// 与えられた値を検証する
    #[tool]
    async fn validate(&self, value: i32) -> Result<String, MyError> {
        // Only strictly positive values pass; zero and negatives are rejected.
        if value > 0 {
            Ok(format!("Valid: {}", value))
        } else {
            Err(MyError("Value must be positive".to_string()))
        }
    }
}
/// Verifies that an `Ok` return value is reported as successful tool output.
#[tokio::test]
async fn test_result_return_type_success() {
    let ctx = FallibleContext;
    let (_, tool) = ctx.validate_definition()();

    let output = tool
        .execute(r#"{"value": 42}"#)
        .await
        .expect("Should succeed for positive value");
    assert!(output.contains("Valid"), "Should contain Valid: {}", output);
    // The tool formats the input back into its message ("Valid: <value>").
    assert!(output.contains("42"), "Should echo the value: {}", output);
}
/// Verifies that an `Err` return value is reported as a tool execution error
/// that carries the original error message.
#[tokio::test]
async fn test_result_return_type_error() {
    let ctx = FallibleContext;
    let (_, tool) = ctx.validate_definition()();

    // `validate` accepts only `value > 0`, so test a negative value and the
    // zero boundary.
    for args in [r#"{"value": -1}"#, r#"{"value": 0}"#] {
        let err = match tool.execute(args).await {
            Ok(output) => panic!(
                "Should fail for non-positive value {}, got: {}",
                args, output
            ),
            Err(err) => err,
        };
        assert!(
            err.to_string().contains("positive"),
            "Error should mention positive: {}",
            err
        );
    }
}
// =============================================================================
// Test: Synchronous Methods
// =============================================================================
/// Context with a counter, used to test `#[tool]` on a synchronous
/// (non-`async`) method.
#[derive(Clone)]
struct SyncContext {
    /// Wrapped in `Arc` so every clone of the context shares the same count.
    // NOTE(review): this presumably includes the copy held by the generated
    // tool; `test_sync_method` relies on observing its increments via `ctx`.
    counter: Arc<AtomicUsize>,
}
// Registers a synchronous (non-`async`) method as a tool. `test_sync_method`
// checks that it runs through the same awaited `execute` API as async tools.
//
// NOTE: the `///` doc comment is left untranslated because the macro uses it
// as the tool description.
#[tool_registry]
impl SyncContext {
    // (EN) "Increments the counter and returns it (non-async)."
    /// カウンターをインクリメントして返す (非async)
    #[tool]
    fn increment(&self) -> usize {
        // `fetch_add` returns the previous value, so add 1 to return the new count.
        self.counter.fetch_add(1, Ordering::SeqCst) + 1
    }
}
/// Verifies that a synchronous `#[tool]` method can be executed repeatedly,
/// that each call returns the new count, and that its side effects reach the
/// shared counter.
#[tokio::test]
async fn test_sync_method() {
    let ctx = SyncContext {
        counter: Arc::new(AtomicUsize::new(0)),
    };
    let (_, tool) = ctx.increment_definition()();

    // Execute three times; `increment` returns the post-increment value.
    for expected in 1..=3usize {
        let output = tool
            .execute(r#"{}"#)
            .await
            .expect("increment should execute successfully");
        assert!(
            output.contains(&expected.to_string()),
            "Call {} should return the new count: {}",
            expected,
            output
        );
    }

    // The counter is shared with `ctx` through the `Arc`, so it must now be 3.
    assert_eq!(ctx.counter.load(Ordering::SeqCst), 3);
}
// =============================================================================
// Test: ToolMeta Immutability
// =============================================================================
/// Checks that building the `greet` definition twice yields identical
/// metadata (name, description, input schema), i.e. the generated `ToolMeta`
/// is deterministic.
#[tokio::test]
async fn test_tool_meta_immutability() {
    let ctx = SimpleContext {
        prefix: "Test".to_string(),
    };

    let (first, _) = ctx.greet_definition()();
    let (second, _) = ctx.greet_definition()();

    // Compare all three fields in one assertion.
    assert_eq!(
        (&first.name, &first.description, &first.input_schema),
        (&second.name, &second.description, &second.input_schema)
    );
}