247 lines
6.3 KiB
Rust
247 lines
6.3 KiB
Rust
//! Tool macro tests
//!
//! Verifies the behavior of the `#[tool_registry]` and `#[tool]` attribute macros.
|
|
|
|
use std::sync::Arc;
|
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
|
|
// Imports needed for macro expansion
|
|
use schemars;
|
|
use serde;
|
|
|
|
use llm_worker_macros::tool_registry;
|
|
|
|
// =============================================================================
// Test: Basic Tool Generation
// =============================================================================
|
|
|
|
/// Shared state for the basic tool-generation tests.
///
/// `greet` prepends `prefix` to its message and `get_prefix` returns it.
#[derive(Clone)]
struct SimpleContext {
    /// Text used by the `greet` and `get_prefix` tools.
    prefix: String,
}
|
|
|
|
#[tool_registry]
|
|
impl SimpleContext {
|
|
/// Add greeting to message
|
|
///
|
|
/// Returns the message with a prefix added.
|
|
#[tool]
|
|
async fn greet(&self, message: String) -> String {
|
|
format!("{}: {}", self.prefix, message)
|
|
}
|
|
|
|
/// Add two numbers
|
|
#[tool]
|
|
async fn add(&self, a: i32, b: i32) -> i32 {
|
|
a + b
|
|
}
|
|
|
|
/// Tool with no arguments
|
|
#[tool]
|
|
async fn get_prefix(&self) -> String {
|
|
self.prefix.clone()
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_basic_tool_generation() {
|
|
let ctx = SimpleContext {
|
|
prefix: "Hello".to_string(),
|
|
};
|
|
|
|
// Get ToolDefinition from factory method
|
|
let greet_definition = ctx.greet_definition();
|
|
|
|
// Call factory to get Meta and Tool
|
|
let (meta, tool) = greet_definition();
|
|
|
|
// Verify meta information
|
|
assert_eq!(meta.name, "greet");
|
|
assert!(
|
|
meta.description.contains("Add greeting to message"),
|
|
"Description should contain doc comment: {}",
|
|
meta.description
|
|
);
|
|
assert!(
|
|
meta.input_schema.get("properties").is_some(),
|
|
"Schema should have properties"
|
|
);
|
|
|
|
println!(
|
|
"Schema: {}",
|
|
serde_json::to_string_pretty(&meta.input_schema).unwrap()
|
|
);
|
|
|
|
// Execution test
|
|
let result = tool.execute(r#"{"message": "World"}"#).await;
|
|
assert!(result.is_ok(), "Should execute successfully");
|
|
let output = result.unwrap();
|
|
assert!(output.contains("Hello"), "Output should contain prefix");
|
|
assert!(output.contains("World"), "Output should contain message");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_multiple_arguments() {
|
|
let ctx = SimpleContext {
|
|
prefix: "".to_string(),
|
|
};
|
|
|
|
let (meta, tool) = ctx.add_definition()();
|
|
|
|
assert_eq!(meta.name, "add");
|
|
|
|
let result = tool.execute(r#"{"a": 10, "b": 20}"#).await;
|
|
assert!(result.is_ok());
|
|
let output = result.unwrap();
|
|
assert!(output.contains("30"), "Should contain sum: {}", output);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_no_arguments() {
|
|
let ctx = SimpleContext {
|
|
prefix: "TestPrefix".to_string(),
|
|
};
|
|
|
|
let (meta, tool) = ctx.get_prefix_definition()();
|
|
|
|
assert_eq!(meta.name, "get_prefix");
|
|
|
|
// Call with empty JSON object
|
|
let result = tool.execute(r#"{}"#).await;
|
|
assert!(result.is_ok());
|
|
let output = result.unwrap();
|
|
assert!(
|
|
output.contains("TestPrefix"),
|
|
"Should contain prefix: {}",
|
|
output
|
|
);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_invalid_arguments() {
|
|
let ctx = SimpleContext {
|
|
prefix: "".to_string(),
|
|
};
|
|
|
|
let (_, tool) = ctx.greet_definition()();
|
|
|
|
// Invalid JSON
|
|
let result = tool.execute(r#"{"wrong_field": "value"}"#).await;
|
|
assert!(result.is_err(), "Should fail with invalid arguments");
|
|
}
|
|
|
|
// =============================================================================
// Test: Result Return Type
// =============================================================================
|
|
|
|
/// Stateless context whose `validate` tool may return an error.
#[derive(Clone)]
struct FallibleContext;
|
|
|
|
/// Error produced by the fallible tool; its display form is just the inner message.
#[derive(Debug)]
struct MyError(String);

impl std::fmt::Display for MyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emit the message verbatim (no width/fill handling, same as a bare "{}").
        f.write_str(&self.0)
    }
}
|
|
|
|
#[tool_registry]
|
|
impl FallibleContext {
|
|
/// Validate the given value
|
|
#[tool]
|
|
async fn validate(&self, value: i32) -> Result<String, MyError> {
|
|
if value > 0 {
|
|
Ok(format!("Valid: {}", value))
|
|
} else {
|
|
Err(MyError("Value must be positive".to_string()))
|
|
}
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_result_return_type_success() {
|
|
let ctx = FallibleContext;
|
|
let (_, tool) = ctx.validate_definition()();
|
|
|
|
let result = tool.execute(r#"{"value": 42}"#).await;
|
|
assert!(result.is_ok(), "Should succeed for positive value");
|
|
let output = result.unwrap();
|
|
assert!(output.contains("Valid"), "Should contain Valid: {}", output);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_result_return_type_error() {
|
|
let ctx = FallibleContext;
|
|
let (_, tool) = ctx.validate_definition()();
|
|
|
|
let result = tool.execute(r#"{"value": -1}"#).await;
|
|
assert!(result.is_err(), "Should fail for negative value");
|
|
|
|
let err = result.unwrap_err();
|
|
assert!(
|
|
err.to_string().contains("positive"),
|
|
"Error should mention positive: {}",
|
|
err
|
|
);
|
|
}
|
|
|
|
// =============================================================================
// Test: Synchronous Methods
// =============================================================================
|
|
|
|
/// Context for the synchronous-tool test.
///
/// The counter lives behind an `Arc`, so every clone of the context observes the
/// same count.
#[derive(Clone)]
struct SyncContext {
    /// Incremented by the `increment` tool.
    counter: Arc<AtomicUsize>,
}
|
|
|
|
#[tool_registry]
|
|
impl SyncContext {
|
|
/// Increment counter and return (non-async)
|
|
#[tool]
|
|
fn increment(&self) -> usize {
|
|
self.counter.fetch_add(1, Ordering::SeqCst) + 1
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_sync_method() {
|
|
let ctx = SyncContext {
|
|
counter: Arc::new(AtomicUsize::new(0)),
|
|
};
|
|
|
|
let (_, tool) = ctx.increment_definition()();
|
|
|
|
// Execute 3 times
|
|
let result1 = tool.execute(r#"{}"#).await;
|
|
let result2 = tool.execute(r#"{}"#).await;
|
|
let result3 = tool.execute(r#"{}"#).await;
|
|
|
|
assert!(result1.is_ok());
|
|
assert!(result2.is_ok());
|
|
assert!(result3.is_ok());
|
|
|
|
// Counter should be 3
|
|
assert_eq!(ctx.counter.load(Ordering::SeqCst), 3);
|
|
}
|
|
|
|
// =============================================================================
// Test: ToolMeta Immutability
// =============================================================================
|
|
|
|
#[tokio::test]
|
|
async fn test_tool_meta_immutability() {
|
|
let ctx = SimpleContext {
|
|
prefix: "Test".to_string(),
|
|
};
|
|
|
|
// Verify same meta info is returned on multiple calls
|
|
let (meta1, _) = ctx.greet_definition()();
|
|
let (meta2, _) = ctx.greet_definition()();
|
|
|
|
assert_eq!(meta1.name, meta2.name);
|
|
assert_eq!(meta1.description, meta2.description);
|
|
assert_eq!(meta1.input_schema, meta2.input_schema);
|
|
}
|