// yoi/crates/llm-worker-macros/src/lib.rs
//
// 322 lines
// 9.9 KiB
// Rust

//! llm-worker-macros - Procedural macros for Tool generation
//!
//! Provides `#[tool_registry]` and `#[tool]` macros to
//! automatically generate `Tool` trait implementations from user-defined methods.
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{
Attribute, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, ReturnType, Type, parse_macro_input,
};
/// Macro applied to an `impl` block that generates tools from methods marked with `#[tool]`.
///
/// # Example
/// ```ignore
/// #[tool_registry]
/// impl MyApp {
/// /// Get user information
/// /// Retrieves a user from the database by their ID.
/// #[tool]
/// async fn get_user(&self, user_id: String) -> Result<User, Error> { ... }
/// }
/// ```
///
/// This generates:
/// - `GetUserArgs` struct (for arguments)
/// - `Tool_get_user` struct (Tool wrapper)
/// - `impl Tool for Tool_get_user`
/// - `impl MyApp { fn get_user_tool(&self) -> Tool_get_user }`
#[proc_macro_attribute]
pub fn tool_registry(_attr: TokenStream, item: TokenStream) -> TokenStream {
let mut impl_block = parse_macro_input!(item as ItemImpl);
let self_ty = &impl_block.self_ty;
let mut generated_items = Vec::new();
for item in &mut impl_block.items {
if let ImplItem::Fn(method) = item {
// Look for #[tool] attribute
let mut is_tool = false;
// Iterate through attributes to check for tool and remove it
method.attrs.retain(|attr| {
if attr.path().is_ident("tool") {
is_tool = true;
false // Remove the attribute
} else {
true
}
});
if is_tool {
let tool_impl = generate_tool_impl(self_ty, method);
generated_items.push(tool_impl);
}
}
}
let expanded = quote! {
#impl_block
#(#generated_items)*
};
TokenStream::from(expanded)
}
/// Build a description string from the `#[doc = "..."]` attributes in `attrs`.
///
/// Each string-literal doc attribute contributes one line; a single leading
/// space (the one that normally follows `///`) is dropped from each line and
/// the lines are joined with `\n`. Returns an empty string when there are no
/// doc attributes.
fn extract_doc_comment(attrs: &[Attribute]) -> String {
    attrs
        .iter()
        .filter(|attr| attr.path().is_ident("doc"))
        .filter_map(|attr| match &attr.meta {
            Meta::NameValue(name_value) => match &name_value.value {
                syn::Expr::Lit(syn::ExprLit {
                    lit: Lit::Str(text),
                    ..
                }) => Some(text.value()),
                _ => None,
            },
            _ => None,
        })
        // Strip at most one leading space; deeper indentation is preserved.
        .map(|raw| raw.strip_prefix(' ').unwrap_or(&raw).to_string())
        .collect::<Vec<_>>()
        .join("\n")
}
/// Find the text of the first `#[description = "..."]` attribute in `attrs`.
///
/// Only the name-value form with a string literal is recognised; attributes
/// named `description` in any other shape are skipped. Returns `None` when no
/// matching attribute exists.
fn extract_description_attr(attrs: &[syn::Attribute]) -> Option<String> {
    attrs.iter().find_map(|attr| {
        if !attr.path().is_ident("description") {
            return None;
        }
        match &attr.meta {
            Meta::NameValue(name_value) => match &name_value.value {
                syn::Expr::Lit(syn::ExprLit {
                    lit: Lit::Str(text),
                    ..
                }) => Some(text.value()),
                _ => None,
            },
            _ => None,
        }
    })
}
/// Generate Tool implementation from a method
fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::TokenStream {
let sig = &method.sig;
let method_name = &sig.ident;
let tool_name = method_name.to_string();
// Generate struct names (convert to PascalCase)
let pascal_name = to_pascal_case(&method_name.to_string());
let tool_struct_name = format_ident!("Tool{}", pascal_name);
let args_struct_name = format_ident!("{}Args", pascal_name);
let definition_name = format_ident!("{}_definition", method_name);
// Get description from doc comments
let description = extract_doc_comment(&method.attrs);
let description = if description.is_empty() {
format!("Tool: {}", tool_name)
} else {
description
};
// Parse arguments (excluding self)
let args: Vec<_> = sig
.inputs
.iter()
.filter_map(|arg| {
if let FnArg::Typed(pat_type) = arg {
Some(pat_type)
} else {
None // Exclude self
}
})
.collect();
// Generate argument struct fields
let arg_fields: Vec<_> = args
.iter()
.map(|pat_type| {
let pat = &pat_type.pat;
let ty = &pat_type.ty;
let desc = extract_description_attr(&pat_type.attrs);
// Extract identifier from pattern
let field_name = if let Pat::Ident(pat_ident) = pat.as_ref() {
&pat_ident.ident
} else {
panic!("Only simple identifiers are supported for tool arguments");
};
// Convert #[description] to schemars doc if present
if let Some(desc_str) = desc {
quote! {
#[schemars(description = #desc_str)]
pub #field_name: #ty
}
} else {
quote! {
pub #field_name: #ty
}
}
})
.collect();
// Code to expand arguments in execute
let arg_names: Vec<_> = args
.iter()
.map(|pat_type| {
if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
let ident = &pat_ident.ident;
quote! { args.#ident }
} else {
panic!("Only simple identifiers are supported");
}
})
.collect();
// Check if method is async
let is_async = sig.asyncness.is_some();
// Parse return type and determine if Result
let awaiter = if is_async {
quote! { .await }
} else {
quote! {}
};
// Determine if return type is Result
let result_handling = if is_result_type(&sig.output) {
quote! {
match result {
Ok(val) => Ok(format!("{:?}", val).into()),
Err(e) => Err(::llm_worker::tool::ToolError::ExecutionFailed(format!("{}", e))),
}
}
} else {
quote! {
Ok(format!("{:?}", result).into())
}
};
// Create empty Args struct if no arguments
let args_struct_def = if arg_fields.is_empty() {
quote! {
#[derive(serde::Deserialize, schemars::JsonSchema)]
struct #args_struct_name {}
}
} else {
quote! {
#[derive(serde::Deserialize, schemars::JsonSchema)]
struct #args_struct_name {
#(#arg_fields),*
}
}
};
// Execute body handling for no arguments case
let execute_body = if args.is_empty() {
quote! {
// Allow empty JSON object even with no arguments
let _: #args_struct_name = serde_json::from_str(input_json)
.unwrap_or(#args_struct_name {});
let result = self.ctx.#method_name()#awaiter;
#result_handling
}
} else {
quote! {
let args: #args_struct_name = serde_json::from_str(input_json)
.map_err(|e| ::llm_worker::tool::ToolError::InvalidArgument(e.to_string()))?;
let result = self.ctx.#method_name(#(#arg_names),*)#awaiter;
#result_handling
}
};
quote! {
#args_struct_def
#[derive(Clone)]
pub struct #tool_struct_name {
ctx: #self_ty,
}
#[async_trait::async_trait]
impl ::llm_worker::tool::Tool for #tool_struct_name {
async fn execute(&self, input_json: &str) -> Result<::llm_worker::tool::ToolOutput, ::llm_worker::tool::ToolError> {
#execute_body
}
}
impl #self_ty {
/// Get ToolDefinition (for registering with Worker)
pub fn #definition_name(&self) -> ::llm_worker::tool::ToolDefinition {
let ctx = self.clone();
::std::sync::Arc::new(move || {
let schema = schemars::schema_for!(#args_struct_name);
let meta = ::llm_worker::tool::ToolMeta::new(#tool_name)
.description(#description)
.input_schema(serde_json::to_value(schema).unwrap_or(serde_json::json!({})));
let tool: ::std::sync::Arc<dyn ::llm_worker::tool::Tool> =
::std::sync::Arc::new(#tool_struct_name { ctx: ctx.clone() });
(meta, tool)
})
}
}
}
}
/// Report whether a function's return type is syntactically a `Result`.
///
/// This is a purely textual check: it is true only when the return type is a
/// path whose last segment is the identifier `Result` (e.g. `Result<T, E>` or
/// `std::io::Result<T>`). A missing return type or any non-path type yields
/// `false`.
fn is_result_type(return_type: &ReturnType) -> bool {
    match return_type {
        ReturnType::Type(_, ty) => match ty.as_ref() {
            Type::Path(type_path) => type_path
                .path
                .segments
                .last()
                .map_or(false, |segment| segment.ident == "Result"),
            _ => false,
        },
        ReturnType::Default => false,
    }
}
/// Convert a `snake_case` name into `PascalCase`.
///
/// The input is split on `_`; the first character of every non-empty piece is
/// upper-cased and the rest of the piece is kept as-is. Empty pieces (from
/// leading, trailing or doubled underscores) contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(initial) = rest.next() {
            // `to_uppercase` may yield more than one char (e.g. 'ß' -> "SS").
            pascal.extend(initial.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
/// Marker attribute for methods that should be exposed as tools.
///
/// On its own this macro returns the annotated item unchanged. The actual
/// code generation happens in `tool_registry`, which looks for `#[tool]` on
/// the methods of the `impl` block it is applied to.
#[proc_macro_attribute]
pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
item
}
/// Marker attribute for describing tool method parameters.
///
/// On its own this macro returns the annotated item unchanged. `tool_registry`
/// reads `#[description = "..."]` on the parameters of `#[tool]` methods and
/// uses the text as the schemars description of the matching argument field.
///
/// # Example
/// ```ignore
/// #[tool]
/// async fn get_user(
///     &self,
///     #[description = "The ID of the user to retrieve"] user_id: String
/// ) -> Result<User, Error> { ... }
/// ```
#[proc_macro_attribute]
pub fn description(_attr: TokenStream, item: TokenStream) -> TokenStream {
item
}