llm_worker_rs/worker/tests/tool_macro_test.rs

213 lines
5.9 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! ツールマクロのテスト
//!
//! `#[tool_registry]` と `#[tool]` マクロの動作を確認する。
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
// マクロ展開に必要なインポート
use schemars;
use serde;
use worker_macros::tool_registry;
use worker_types::Tool;
// =============================================================================
// Test: Basic Tool Generation
// =============================================================================
/// Simple context used to exercise basic tool generation.
///
/// Its tools are generated by the `#[tool_registry]` impl that follows.
#[derive(Clone)]
struct SimpleContext {
/// Prepended to messages by `greet`; returned as-is by `get_prefix`.
prefix: String,
}
// `#[tool_registry]` expands this impl: each `#[tool]` method gets a
// `<method>_tool()` factory (e.g. `greet_tool()`), the method name becomes the
// tool name, and the parameter names become the JSON argument fields.
// NOTE: the `///` doc comments on `#[tool]` methods are NOT ordinary rustdoc —
// they become the runtime tool description (asserted in
// `test_basic_tool_generation`), so they are intentionally left as written.
#[tool_registry]
impl SimpleContext {
/// メッセージに挨拶を追加する
///
/// 指定されたメッセージにプレフィックスを付けて返します。
#[tool]
async fn greet(&self, message: String) -> String {
// Produces "<prefix>: <message>".
format!("{}: {}", self.prefix, message)
}
/// 二つの数を足す
#[tool]
async fn add(&self, a: i32, b: i32) -> i32 {
a + b
}
/// 引数なしのツール
#[tool]
async fn get_prefix(&self) -> String {
// Zero-argument tool: invoked with an empty JSON object.
self.prefix.clone()
}
}
#[tokio::test]
async fn test_basic_tool_generation() {
    let ctx = SimpleContext {
        prefix: "Hello".to_string(),
    };

    // Obtain the tool through the factory method generated by `#[tool_registry]`.
    let greet_tool = ctx.greet_tool();

    // The tool name is derived from the method name.
    assert_eq!(greet_tool.name(), "greet");

    // The description is taken from the method's doc comment.
    let desc = greet_tool.description();
    assert!(
        desc.contains("メッセージに挨拶を追加する"),
        "Description should contain doc comment: {}",
        desc
    );

    // The input schema must describe the `message` argument, not merely exist.
    let schema = greet_tool.input_schema();
    println!("Schema: {}", serde_json::to_string_pretty(&schema).unwrap());
    let properties = schema
        .get("properties")
        .expect("Schema should have properties");
    assert!(
        properties.get("message").is_some(),
        "Schema properties should include `message`: {}",
        properties
    );

    // Execution: `expect` reports the actual error if the call fails,
    // which a bare `is_ok()` assertion would hide.
    let output = greet_tool
        .execute(r#"{"message": "World"}"#)
        .await
        .expect("Should execute successfully");
    assert!(output.contains("Hello"), "Output should contain prefix: {}", output);
    assert!(output.contains("World"), "Output should contain message: {}", output);
}
#[tokio::test]
async fn test_multiple_arguments() {
    let ctx = SimpleContext {
        prefix: "".to_string(),
    };
    let add_tool = ctx.add_tool();
    assert_eq!(add_tool.name(), "add");

    // Each JSON field maps to the parameter of the same name.
    let output = add_tool
        .execute(r#"{"a": 10, "b": 20}"#)
        .await
        .expect("add should execute successfully");
    assert!(output.contains("30"), "Should contain sum: {}", output);
}
#[tokio::test]
async fn test_no_arguments() {
    let ctx = SimpleContext {
        prefix: "TestPrefix".to_string(),
    };
    let get_prefix_tool = ctx.get_prefix_tool();
    assert_eq!(get_prefix_tool.name(), "get_prefix");

    // A zero-argument tool is invoked with an empty JSON object.
    let output = get_prefix_tool
        .execute(r#"{}"#)
        .await
        .expect("get_prefix should execute successfully");
    assert!(output.contains("TestPrefix"), "Should contain prefix: {}", output);
}
#[tokio::test]
async fn test_invalid_arguments() {
    let ctx = SimpleContext {
        prefix: "".to_string(),
    };
    let greet_tool = ctx.greet_tool();

    // Well-formed JSON, but the required `message` field is missing.
    let result = greet_tool.execute(r#"{"wrong_field": "value"}"#).await;
    assert!(
        result.is_err(),
        "Should fail with invalid arguments, got: {:?}",
        result
    );

    // Input that is not JSON at all must be rejected as well.
    let result = greet_tool.execute("not json").await;
    assert!(
        result.is_err(),
        "Should fail with malformed JSON, got: {:?}",
        result
    );
}
// =============================================================================
// Test: Result Return Type
// =============================================================================
/// Stateless context whose only tool returns a `Result`, used to test error propagation.
#[derive(Clone)]
struct FallibleContext;
/// Test-only error type returned by `FallibleContext::validate`.
///
/// Implements `std::error::Error` in addition to `Display`, so it behaves like a
/// conventional Rust error and can be used wherever a `dyn Error` is expected.
#[derive(Debug)]
struct MyError(String);

impl std::fmt::Display for MyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Display is exactly the wrapped message, so callers can match on its text.
        f.write_str(&self.0)
    }
}

impl std::error::Error for MyError {}
// `#[tool_registry]` turns `validate` into a tool (via `validate_tool()`) whose
// `Result` return is surfaced by `execute`: `Ok` on success, otherwise an `Err`
// whose text includes `MyError`'s message (checked by the two tests below).
// NOTE: the `///` line on the `#[tool]` method becomes the runtime tool
// description, so it is intentionally left as written.
#[tool_registry]
impl FallibleContext {
/// 与えられた値を検証する
#[tool]
async fn validate(&self, value: i32) -> Result<String, MyError> {
// Only strictly positive values are accepted; 0 is rejected too.
if value > 0 {
Ok(format!("Valid: {}", value))
} else {
Err(MyError("Value must be positive".to_string()))
}
}
}
#[tokio::test]
async fn test_result_return_type_success() {
    let ctx = FallibleContext;
    let validate_tool = ctx.validate_tool();

    // An `Ok` from the method must come back as a successful execution.
    let output = validate_tool
        .execute(r#"{"value": 42}"#)
        .await
        .expect("Should succeed for positive value");
    assert!(output.contains("Valid"), "Should contain Valid: {}", output);
}
#[tokio::test]
async fn test_result_return_type_error() {
    let ctx = FallibleContext;
    let validate_tool = ctx.validate_tool();

    // An `Err` from the method must come back as a failed execution;
    // `expect_err` prints the unexpected `Ok` value if it does not.
    let err = validate_tool
        .execute(r#"{"value": -1}"#)
        .await
        .expect_err("Should fail for negative value");
    assert!(
        err.to_string().contains("positive"),
        "Error should mention positive: {}",
        err
    );
}
// =============================================================================
// Test: Synchronous Methods
// =============================================================================
/// Context with a shared counter, used to test non-async `#[tool]` methods.
#[derive(Clone)]
struct SyncContext {
/// `Arc` makes every clone of this context share one counter, so increments
/// stay visible through `ctx.counter` even if the generated tool holds a clone
/// (presumably — depends on the macro's expansion).
counter: Arc<AtomicUsize>,
}
// `#[tool_registry]` also accepts non-async `#[tool]` methods; the generated
// tool's `execute` is still awaited by callers (see `test_sync_method`).
// NOTE: the `///` line on the `#[tool]` method becomes the runtime tool
// description, so it is intentionally left as written.
#[tool_registry]
impl SyncContext {
/// カウンターをインクリメントして返す (非async)
#[tool]
fn increment(&self) -> usize {
// `fetch_add` returns the previous value, so `+ 1` yields the new count.
self.counter.fetch_add(1, Ordering::SeqCst) + 1
}
}
#[tokio::test]
async fn test_sync_method() {
    let ctx = SyncContext {
        counter: Arc::new(AtomicUsize::new(0)),
    };
    let increment_tool = ctx.increment_tool();

    // Run three times; each call must succeed and report the post-increment count.
    for expected in 1..=3usize {
        let output = increment_tool
            .execute(r#"{}"#)
            .await
            .expect("increment should execute successfully");
        assert!(
            output.contains(&expected.to_string()),
            "Call #{} should return {}: {}",
            expected,
            expected,
            output
        );
    }

    // The shared counter should now be 3.
    assert_eq!(ctx.counter.load(Ordering::SeqCst), 3);
}