//! ツールマクロのテスト
//!
//! `#[tool_registry]` と `#[tool]` マクロの動作を確認する。

use std::sync::atomic::{AtomicUsize, Ordering};
|
||
use std::sync::Arc;
|
||
|
||
// マクロ展開に必要なインポート
|
||
use schemars;
|
||
use serde;
|
||
|
||
use worker_macros::tool_registry;
|
||
use worker_types::Tool;
|
||
|
||
// =============================================================================
// Test: Basic Tool Generation
// =============================================================================
|
||
|
||
/// Context for the basic tool tests: holds the prefix that the `greet` tool
/// prepends to messages and that the `get_prefix` tool returns.
#[derive(Clone)]
struct SimpleContext {
    /// Text placed in front of greeted messages.
    prefix: String,
}
|
||
|
||
#[tool_registry]
|
||
impl SimpleContext {
|
||
/// メッセージに挨拶を追加する
|
||
///
|
||
/// 指定されたメッセージにプレフィックスを付けて返します。
|
||
#[tool]
|
||
async fn greet(&self, message: String) -> String {
|
||
format!("{}: {}", self.prefix, message)
|
||
}
|
||
|
||
/// 二つの数を足す
|
||
#[tool]
|
||
async fn add(&self, a: i32, b: i32) -> i32 {
|
||
a + b
|
||
}
|
||
|
||
/// 引数なしのツール
|
||
#[tool]
|
||
async fn get_prefix(&self) -> String {
|
||
self.prefix.clone()
|
||
}
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_basic_tool_generation() {
|
||
let ctx = SimpleContext {
|
||
prefix: "Hello".to_string(),
|
||
};
|
||
|
||
// ファクトリメソッドでツールを取得
|
||
let greet_tool = ctx.greet_tool();
|
||
|
||
// 名前の確認
|
||
assert_eq!(greet_tool.name(), "greet");
|
||
|
||
// 説明の確認(docコメントから取得)
|
||
let desc = greet_tool.description();
|
||
assert!(desc.contains("メッセージに挨拶を追加する"), "Description should contain doc comment: {}", desc);
|
||
|
||
// スキーマの確認
|
||
let schema = greet_tool.input_schema();
|
||
println!("Schema: {}", serde_json::to_string_pretty(&schema).unwrap());
|
||
assert!(schema.get("properties").is_some(), "Schema should have properties");
|
||
|
||
// 実行テスト
|
||
let result = greet_tool.execute(r#"{"message": "World"}"#).await;
|
||
assert!(result.is_ok(), "Should execute successfully");
|
||
let output = result.unwrap();
|
||
assert!(output.contains("Hello"), "Output should contain prefix");
|
||
assert!(output.contains("World"), "Output should contain message");
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_multiple_arguments() {
|
||
let ctx = SimpleContext {
|
||
prefix: "".to_string(),
|
||
};
|
||
|
||
let add_tool = ctx.add_tool();
|
||
|
||
assert_eq!(add_tool.name(), "add");
|
||
|
||
let result = add_tool.execute(r#"{"a": 10, "b": 20}"#).await;
|
||
assert!(result.is_ok());
|
||
let output = result.unwrap();
|
||
assert!(output.contains("30"), "Should contain sum: {}", output);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_no_arguments() {
|
||
let ctx = SimpleContext {
|
||
prefix: "TestPrefix".to_string(),
|
||
};
|
||
|
||
let get_prefix_tool = ctx.get_prefix_tool();
|
||
|
||
assert_eq!(get_prefix_tool.name(), "get_prefix");
|
||
|
||
// 空のJSONオブジェクトで呼び出し
|
||
let result = get_prefix_tool.execute(r#"{}"#).await;
|
||
assert!(result.is_ok());
|
||
let output = result.unwrap();
|
||
assert!(output.contains("TestPrefix"), "Should contain prefix: {}", output);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_invalid_arguments() {
|
||
let ctx = SimpleContext {
|
||
prefix: "".to_string(),
|
||
};
|
||
|
||
let greet_tool = ctx.greet_tool();
|
||
|
||
// 不正なJSON
|
||
let result = greet_tool.execute(r#"{"wrong_field": "value"}"#).await;
|
||
assert!(result.is_err(), "Should fail with invalid arguments");
|
||
}
|
||
|
||
// =============================================================================
// Test: Result Return Type
// =============================================================================
|
||
|
||
/// Stateless context used to exercise tools whose methods return `Result`.
#[derive(Clone)]
struct FallibleContext;
|
||
|
||
/// Error returned by the `validate` tool; the wrapped string is the message.
#[derive(Debug)]
struct MyError(String);

impl std::fmt::Display for MyError {
    /// Writes the wrapped message verbatim. Formatter flags such as width or
    /// fill are not applied to it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
|
||
|
||
#[tool_registry]
|
||
impl FallibleContext {
|
||
/// 与えられた値を検証する
|
||
#[tool]
|
||
async fn validate(&self, value: i32) -> Result<String, MyError> {
|
||
if value > 0 {
|
||
Ok(format!("Valid: {}", value))
|
||
} else {
|
||
Err(MyError("Value must be positive".to_string()))
|
||
}
|
||
}
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_result_return_type_success() {
|
||
let ctx = FallibleContext;
|
||
let validate_tool = ctx.validate_tool();
|
||
|
||
let result = validate_tool.execute(r#"{"value": 42}"#).await;
|
||
assert!(result.is_ok(), "Should succeed for positive value");
|
||
let output = result.unwrap();
|
||
assert!(output.contains("Valid"), "Should contain Valid: {}", output);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_result_return_type_error() {
|
||
let ctx = FallibleContext;
|
||
let validate_tool = ctx.validate_tool();
|
||
|
||
let result = validate_tool.execute(r#"{"value": -1}"#).await;
|
||
assert!(result.is_err(), "Should fail for negative value");
|
||
|
||
let err = result.unwrap_err();
|
||
assert!(err.to_string().contains("positive"), "Error should mention positive: {}", err);
|
||
}
|
||
|
||
// =============================================================================
// Test: Synchronous Methods
// =============================================================================
|
||
|
||
/// Context for the synchronous-tool test. The counter sits behind an `Arc`,
/// so every clone of the context shares the same count.
#[derive(Clone)]
struct SyncContext {
    /// Shared counter bumped by the `increment` tool.
    counter: Arc<AtomicUsize>,
}
|
||
|
||
#[tool_registry]
|
||
impl SyncContext {
|
||
/// カウンターをインクリメントして返す (非async)
|
||
#[tool]
|
||
fn increment(&self) -> usize {
|
||
self.counter.fetch_add(1, Ordering::SeqCst) + 1
|
||
}
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_sync_method() {
|
||
let ctx = SyncContext {
|
||
counter: Arc::new(AtomicUsize::new(0)),
|
||
};
|
||
|
||
let increment_tool = ctx.increment_tool();
|
||
|
||
// 3回実行
|
||
let result1 = increment_tool.execute(r#"{}"#).await;
|
||
let result2 = increment_tool.execute(r#"{}"#).await;
|
||
let result3 = increment_tool.execute(r#"{}"#).await;
|
||
|
||
assert!(result1.is_ok());
|
||
assert!(result2.is_ok());
|
||
assert!(result3.is_ok());
|
||
|
||
// カウンターは3になっているはず
|
||
assert_eq!(ctx.counter.load(Ordering::SeqCst), 3);
|
||
}
|