This repository has been archived on 2026-01-07. You can view files and clone it, but cannot push or open issues or pull requests.
worker/worker-macros/src/lib.rs

328 lines
9.6 KiB
Rust

use proc_macro::TokenStream;
use quote::quote;
use syn::{
Attribute, ItemFn, LitStr,
parse::{Parse, ParseStream},
parse_macro_input,
};
/// Parsed arguments of `#[tool(...)]`.
///
/// Accepted forms are an empty attribute (`#[tool]`) or a single
/// `name = "..."` pair (`#[tool(name = "ReadFile")]`).
struct ToolAttributeArgs {
    /// Explicit tool name; `None` means "derive it from the function name".
    name: Option<String>,
}

impl Parse for ToolAttributeArgs {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        // Bare `#[tool]`: nothing to parse.
        if input.is_empty() {
            return Ok(ToolAttributeArgs { name: None });
        }

        // The only supported key is `name`.
        let key: syn::Ident = input.parse()?;
        if key != "name" {
            return Err(syn::Error::new_spanned(
                key,
                "Only 'name' attribute is supported",
            ));
        }

        input.parse::<syn::Token![=]>()?;
        let value = input.parse::<LitStr>()?.value();
        Ok(ToolAttributeArgs { name: Some(value) })
    }
}
#[proc_macro_attribute]
pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as ToolAttributeArgs);
let func = parse_macro_input!(item as ItemFn);
let description = {
let doc_comments = extract_doc_comments(&func.attrs);
if doc_comments.is_empty() {
format!("Tool function: {}", func.sig.ident)
} else {
doc_comments
}
};
// Validate function signature
if let Err(e) = validate_function_signature(&func) {
return e.to_compile_error().into();
}
let fn_name = &func.sig.ident;
let fn_name_str = fn_name.to_string();
// Use provided name or fall back to CamelCase function name
let tool_name_str = args.name.unwrap_or_else(|| to_camel_case(&fn_name_str));
// Extract arg_type and output_type safely after validation
let arg_type = if let syn::FnArg::Typed(pat_type) = &func.sig.inputs[0] {
&pat_type.ty
} else {
// This case should be caught by validate_function_signature
return syn::Error::new_spanned(&func.sig.inputs[0], "Expected typed argument")
.to_compile_error()
.into();
};
if let syn::ReturnType::Type(_, _) = &func.sig.output {
} else {
// This case should be caught by validate_function_signature
return syn::Error::new_spanned(&func.sig.output, "Expected return type")
.to_compile_error()
.into();
};
// Generate struct name from function name (e.g., read_file -> ReadFileTool)
let tool_struct_name = {
let fn_name_str = fn_name.to_string();
let camel_case = to_camel_case(&fn_name_str);
syn::Ident::new(&format!("{}Tool", camel_case), fn_name.span())
};
let expanded = quote! {
// Keep the original function
#func
// Generate Tool struct
pub struct #tool_struct_name;
impl #tool_struct_name {
pub fn new() -> Self {
Self
}
}
// Implement Tool trait
#[::worker_types::async_trait::async_trait]
impl ::worker_types::Tool for #tool_struct_name {
fn name(&self) -> &str {
#tool_name_str
}
fn description(&self) -> &str {
#description
}
fn parameters_schema(&self) -> ::worker_types::serde_json::Value {
::worker_types::serde_json::to_value(::worker_types::schemars::schema_for!(#arg_type)).unwrap()
}
async fn execute(&self, args: ::worker_types::serde_json::Value) -> ::worker_types::ToolResult<::worker_types::serde_json::Value> {
let typed_args: #arg_type = ::worker_types::serde_json::from_value(args)?;
let result = #fn_name(typed_args).await?;
// Use Display formatting instead of JSON serialization
let formatted_result = format!("{}", result);
Ok(::worker_types::serde_json::Value::String(formatted_result))
}
}
};
TokenStream::from(expanded)
}
fn validate_function_signature(func: &ItemFn) -> syn::Result<()> {
if func.sig.asyncness.is_none() {
return Err(syn::Error::new_spanned(
&func.sig,
"Tool function must be async",
));
}
if func.sig.inputs.len() != 1 {
return Err(syn::Error::new_spanned(
&func.sig.inputs,
"Tool function must have exactly one argument",
));
}
let arg = &func.sig.inputs[0];
if !matches!(arg, syn::FnArg::Typed(_)) {
return Err(syn::Error::new_spanned(
arg,
"Argument must be a typed pattern (e.g., `args: MyArgs`)",
));
}
if let syn::ReturnType::Default = func.sig.output {
return Err(syn::Error::new_spanned(
&func.sig,
"Tool function must have a return type, typically Result<T, E>",
));
}
Ok(())
}
/// Concatenates the string values of all `#[doc = "..."]` attributes.
///
/// `///` comments desugar into this attribute form. Each value has its
/// leading whitespace removed, and the values are joined with `\n`. The
/// combined text is then trimmed. Doc attributes whose value is not a plain
/// string literal are skipped. Returns an empty string when nothing is found.
fn extract_doc_comments(attrs: &[Attribute]) -> String {
    let lines: Vec<String> = attrs
        .iter()
        .filter(|attr| attr.path().is_ident("doc"))
        .filter_map(|attr| match &attr.meta {
            syn::Meta::NameValue(name_value) => match &name_value.value {
                syn::Expr::Lit(syn::ExprLit {
                    lit: syn::Lit::Str(text),
                    ..
                }) => Some(text.value().trim_start().to_string()),
                _ => None,
            },
            _ => None,
        })
        .collect();

    // Joining an empty list yields "", so no special case is needed.
    lines.join("\n").trim().to_string()
}
/// Converts a `snake_case` identifier to UpperCamelCase (`read_file` -> `ReadFile`).
///
/// The input is split on `_`, and empty segments are dropped. The first character
/// of each remaining segment is uppercased (which may expand, e.g. `ß` -> `SS`).
/// The rest of each segment is kept unchanged.
fn to_camel_case(snake_case: &str) -> String {
    let mut out = String::with_capacity(snake_case.len());
    for segment in snake_case.split('_') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            out.extend(head.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}
// Hook attribute arguments parser
struct HookAttributeArgs {
hook_type: String,
matcher: Option<String>,
}
impl Parse for HookAttributeArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut hook_type = None;
let mut matcher = None;
while !input.is_empty() {
let name: syn::Ident = input.parse()?;
input.parse::<syn::Token![=]>()?;
let value: LitStr = input.parse()?;
match name.to_string().as_str() {
"hook_type" => hook_type = Some(value.value()),
"matcher" => matcher = Some(value.value()),
_ => return Err(syn::Error::new_spanned(name, "Unknown hook attribute")),
}
if input.peek(syn::Token![,]) {
input.parse::<syn::Token![,]>()?;
}
}
let hook_type = hook_type.ok_or_else(|| input.error("Hook type is required"))?;
Ok(HookAttributeArgs { hook_type, matcher })
}
}
/// Attribute macro that turns an `async fn` into a `::worker_types::WorkerHook` implementation.
///
/// Usage: `#[hook(hook_type = "...", matcher = "...")]`. `hook_type` is required;
/// `matcher` is optional and defaults to `""`.
///
/// The original function is emitted unchanged. The macro also generates:
/// - a unit struct with a `new()` constructor. Its name is the CamelCase function
///   name with exactly one `Hook` suffix, so `pre_tool` and `pre_tool_hook` both
///   become `PreToolHook`;
/// - an `impl ::worker_types::WorkerHook` for that struct, whose methods are:
///   - `name()`: the function name;
///   - `hook_type()` and `matcher()`: the attribute values;
///   - `execute()`: passes the `HookContext` to the function and awaits its result.
///
/// Signature problems found by `validate_hook_function_signature` are reported as
/// compile errors.
#[proc_macro_attribute]
pub fn hook(attr: TokenStream, item: TokenStream) -> TokenStream {
    let args = parse_macro_input!(attr as HookAttributeArgs);
    let func = parse_macro_input!(item as ItemFn);

    // Validate function signature for hooks
    if let Err(e) = validate_hook_function_signature(&func) {
        return e.to_compile_error().into();
    }

    let fn_name = &func.sig.ident;
    // Strip a raw-identifier prefix (`r#match` -> `match`). Without this the derived
    // struct name would be `R#matchHook`, which makes `syn::Ident::new` panic.
    let fn_name_str = syn::ext::IdentExt::unraw(fn_name).to_string();
    let hook_type = &args.hook_type;
    let matcher = args.matcher.as_deref().unwrap_or("");

    // If the CamelCase name already ends with "Hook" (e.g. from a `_hook` suffix),
    // remove it before appending "Hook" so the struct is not named `...HookHook`.
    let hook_struct_name = {
        let camel_case = to_camel_case(&fn_name_str);
        let base = camel_case.strip_suffix("Hook").unwrap_or(&camel_case);
        syn::Ident::new(&format!("{}Hook", base), fn_name.span())
    };

    let expanded = quote! {
        // Keep the original function
        #func
        // Generate Hook struct
        pub struct #hook_struct_name;
        impl #hook_struct_name {
            pub fn new() -> Self {
                Self
            }
        }
        // Implement WorkerHook trait
        #[::worker_types::async_trait::async_trait]
        impl ::worker_types::WorkerHook for #hook_struct_name {
            fn name(&self) -> &str {
                #fn_name_str
            }
            fn hook_type(&self) -> &str {
                #hook_type
            }
            fn matcher(&self) -> &str {
                #matcher
            }
            async fn execute(&self, context: ::worker_types::HookContext) -> (::worker_types::HookContext, ::worker_types::HookResult) {
                #fn_name(context).await
            }
        }
    };
    TokenStream::from(expanded)
}
fn validate_hook_function_signature(func: &ItemFn) -> syn::Result<()> {
if func.sig.asyncness.is_none() {
return Err(syn::Error::new_spanned(
&func.sig,
"Hook function must be async",
));
}
if func.sig.inputs.len() != 1 {
return Err(syn::Error::new_spanned(
&func.sig.inputs,
"Hook function must have exactly one argument of type HookContext",
));
}
let arg = &func.sig.inputs[0];
if !matches!(arg, syn::FnArg::Typed(_)) {
return Err(syn::Error::new_spanned(
arg,
"Argument must be a typed pattern (e.g., `context: HookContext`)",
));
}
if let syn::ReturnType::Default = func.sig.output {
return Err(syn::Error::new_spanned(
&func.sig,
"Hook function must return (HookContext, HookResult)",
));
}
Ok(())
}