247 lines
6.6 KiB
Rust
247 lines
6.6 KiB
Rust
//! ツールマクロのテスト
//!
//! `#[tool_registry]` と `#[tool]` マクロの動作を確認する。
|
|
|
|
use std::sync::Arc;
|
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
|
|
// マクロ展開に必要なインポート
|
|
use schemars;
|
|
use serde;
|
|
|
|
use llm_worker_macros::tool_registry;
|
|
|
|
// =============================================================================
|
|
// Test: Basic Tool Generation
|
|
// =============================================================================
|
|
|
|
/// Minimal context used by the basic tool-generation tests.
///
/// Its tools (`greet`, `add`, `get_prefix`) are declared in the
/// `#[tool_registry]` impl for this type.
#[derive(Clone)]
struct SimpleContext {
    /// Text read by the `greet` and `get_prefix` tools.
    prefix: String,
}
|
|
|
|
#[tool_registry]
|
|
impl SimpleContext {
|
|
/// メッセージに挨拶を追加する
|
|
///
|
|
/// 指定されたメッセージにプレフィックスを付けて返します。
|
|
#[tool]
|
|
async fn greet(&self, message: String) -> String {
|
|
format!("{}: {}", self.prefix, message)
|
|
}
|
|
|
|
/// 二つの数を足す
|
|
#[tool]
|
|
async fn add(&self, a: i32, b: i32) -> i32 {
|
|
a + b
|
|
}
|
|
|
|
/// 引数なしのツール
|
|
#[tool]
|
|
async fn get_prefix(&self) -> String {
|
|
self.prefix.clone()
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_basic_tool_generation() {
|
|
let ctx = SimpleContext {
|
|
prefix: "Hello".to_string(),
|
|
};
|
|
|
|
// ファクトリメソッドでToolDefinitionを取得
|
|
let greet_definition = ctx.greet_definition();
|
|
|
|
// ファクトリを呼び出してMetaとToolを取得
|
|
let (meta, tool) = greet_definition();
|
|
|
|
// メタ情報の確認
|
|
assert_eq!(meta.name, "greet");
|
|
assert!(
|
|
meta.description.contains("メッセージに挨拶を追加する"),
|
|
"Description should contain doc comment: {}",
|
|
meta.description
|
|
);
|
|
assert!(
|
|
meta.input_schema.get("properties").is_some(),
|
|
"Schema should have properties"
|
|
);
|
|
|
|
println!(
|
|
"Schema: {}",
|
|
serde_json::to_string_pretty(&meta.input_schema).unwrap()
|
|
);
|
|
|
|
// 実行テスト
|
|
let result = tool.execute(r#"{"message": "World"}"#).await;
|
|
assert!(result.is_ok(), "Should execute successfully");
|
|
let output = result.unwrap();
|
|
assert!(output.contains("Hello"), "Output should contain prefix");
|
|
assert!(output.contains("World"), "Output should contain message");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_multiple_arguments() {
|
|
let ctx = SimpleContext {
|
|
prefix: "".to_string(),
|
|
};
|
|
|
|
let (meta, tool) = ctx.add_definition()();
|
|
|
|
assert_eq!(meta.name, "add");
|
|
|
|
let result = tool.execute(r#"{"a": 10, "b": 20}"#).await;
|
|
assert!(result.is_ok());
|
|
let output = result.unwrap();
|
|
assert!(output.contains("30"), "Should contain sum: {}", output);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_no_arguments() {
|
|
let ctx = SimpleContext {
|
|
prefix: "TestPrefix".to_string(),
|
|
};
|
|
|
|
let (meta, tool) = ctx.get_prefix_definition()();
|
|
|
|
assert_eq!(meta.name, "get_prefix");
|
|
|
|
// 空のJSONオブジェクトで呼び出し
|
|
let result = tool.execute(r#"{}"#).await;
|
|
assert!(result.is_ok());
|
|
let output = result.unwrap();
|
|
assert!(
|
|
output.contains("TestPrefix"),
|
|
"Should contain prefix: {}",
|
|
output
|
|
);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_invalid_arguments() {
|
|
let ctx = SimpleContext {
|
|
prefix: "".to_string(),
|
|
};
|
|
|
|
let (_, tool) = ctx.greet_definition()();
|
|
|
|
// 不正なJSON
|
|
let result = tool.execute(r#"{"wrong_field": "value"}"#).await;
|
|
assert!(result.is_err(), "Should fail with invalid arguments");
|
|
}
|
|
|
|
// =============================================================================
|
|
// Test: Result Return Type
|
|
// =============================================================================
|
|
|
|
/// Stateless context whose only tool, `validate`, can fail.
#[derive(Clone)]
struct FallibleContext;
|
|
|
|
/// Error returned by the `validate` tool; wraps a human-readable message.
#[derive(Debug)]
struct MyError(String);

impl std::fmt::Display for MyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Display is exactly the wrapped message; the error test looks for its
        // text in the tool's error.
        f.write_str(&self.0)
    }
}

// Implementing `Error` (Debug + Display already exist) makes this a
// conventional error type, e.g. convertible into `Box<dyn Error>` via `?`.
impl std::error::Error for MyError {}
|
|
|
|
#[tool_registry]
|
|
impl FallibleContext {
|
|
/// 与えられた値を検証する
|
|
#[tool]
|
|
async fn validate(&self, value: i32) -> Result<String, MyError> {
|
|
if value > 0 {
|
|
Ok(format!("Valid: {}", value))
|
|
} else {
|
|
Err(MyError("Value must be positive".to_string()))
|
|
}
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_result_return_type_success() {
|
|
let ctx = FallibleContext;
|
|
let (_, tool) = ctx.validate_definition()();
|
|
|
|
let result = tool.execute(r#"{"value": 42}"#).await;
|
|
assert!(result.is_ok(), "Should succeed for positive value");
|
|
let output = result.unwrap();
|
|
assert!(output.contains("Valid"), "Should contain Valid: {}", output);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_result_return_type_error() {
|
|
let ctx = FallibleContext;
|
|
let (_, tool) = ctx.validate_definition()();
|
|
|
|
let result = tool.execute(r#"{"value": -1}"#).await;
|
|
assert!(result.is_err(), "Should fail for negative value");
|
|
|
|
let err = result.unwrap_err();
|
|
assert!(
|
|
err.to_string().contains("positive"),
|
|
"Error should mention positive: {}",
|
|
err
|
|
);
|
|
}
|
|
|
|
// =============================================================================
|
|
// Test: Synchronous Methods
|
|
// =============================================================================
|
|
|
|
/// Context with a shared counter, used to exercise a non-`async` tool method.
#[derive(Clone)]
struct SyncContext {
    /// Held in an `Arc`, so every clone of the context shares one count.
    counter: Arc<AtomicUsize>,
}
|
|
|
|
#[tool_registry]
|
|
impl SyncContext {
|
|
/// カウンターをインクリメントして返す (非async)
|
|
#[tool]
|
|
fn increment(&self) -> usize {
|
|
self.counter.fetch_add(1, Ordering::SeqCst) + 1
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_sync_method() {
|
|
let ctx = SyncContext {
|
|
counter: Arc::new(AtomicUsize::new(0)),
|
|
};
|
|
|
|
let (_, tool) = ctx.increment_definition()();
|
|
|
|
// 3回実行
|
|
let result1 = tool.execute(r#"{}"#).await;
|
|
let result2 = tool.execute(r#"{}"#).await;
|
|
let result3 = tool.execute(r#"{}"#).await;
|
|
|
|
assert!(result1.is_ok());
|
|
assert!(result2.is_ok());
|
|
assert!(result3.is_ok());
|
|
|
|
// カウンターは3になっているはず
|
|
assert_eq!(ctx.counter.load(Ordering::SeqCst), 3);
|
|
}
|
|
|
|
// =============================================================================
|
|
// Test: ToolMeta Immutability
|
|
// =============================================================================
|
|
|
|
#[tokio::test]
|
|
async fn test_tool_meta_immutability() {
|
|
let ctx = SimpleContext {
|
|
prefix: "Test".to_string(),
|
|
};
|
|
|
|
// 2回取得しても同じメタ情報が得られることを確認
|
|
let (meta1, _) = ctx.greet_definition()();
|
|
let (meta2, _) = ctx.greet_definition()();
|
|
|
|
assert_eq!(meta1.name, meta2.name);
|
|
assert_eq!(meta1.description, meta2.description);
|
|
assert_eq!(meta1.input_schema, meta2.input_schema);
|
|
}
|