use proc_macro::TokenStream; use quote::quote; use syn::{ Attribute, ItemFn, LitStr, parse::{Parse, ParseStream}, parse_macro_input, }; struct ToolAttributeArgs { name: Option, } impl Parse for ToolAttributeArgs { fn parse(input: ParseStream) -> syn::Result { let mut name = None; if !input.is_empty() { let name_ident: syn::Ident = input.parse()?; if name_ident != "name" { return Err(syn::Error::new_spanned( name_ident, "Only 'name' attribute is supported", )); } input.parse::()?; let name_str: LitStr = input.parse()?; name = Some(name_str.value()); } Ok(ToolAttributeArgs { name }) } } #[proc_macro_attribute] pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream { let args = parse_macro_input!(attr as ToolAttributeArgs); let func = parse_macro_input!(item as ItemFn); let description = { let doc_comments = extract_doc_comments(&func.attrs); if doc_comments.is_empty() { format!("Tool function: {}", func.sig.ident) } else { doc_comments } }; // Validate function signature if let Err(e) = validate_function_signature(&func) { return e.to_compile_error().into(); } let fn_name = &func.sig.ident; let fn_name_str = fn_name.to_string(); // Use provided name or fall back to CamelCase function name let tool_name_str = args.name.unwrap_or_else(|| to_camel_case(&fn_name_str)); // Extract arg_type and output_type safely after validation let arg_type = if let syn::FnArg::Typed(pat_type) = &func.sig.inputs[0] { &pat_type.ty } else { // This case should be caught by validate_function_signature return syn::Error::new_spanned(&func.sig.inputs[0], "Expected typed argument") .to_compile_error() .into(); }; if let syn::ReturnType::Type(_, _) = &func.sig.output { } else { // This case should be caught by validate_function_signature return syn::Error::new_spanned(&func.sig.output, "Expected return type") .to_compile_error() .into(); }; // Generate struct name from function name (e.g., read_file -> ReadFileTool) let tool_struct_name = { let fn_name_str = fn_name.to_string(); let camel_case = to_camel_case(&fn_name_str); syn::Ident::new(&format!("{}Tool", camel_case), fn_name.span()) }; let expanded = quote! { // Keep the original function #func // Generate Tool struct pub struct #tool_struct_name; impl #tool_struct_name { pub fn new() -> Self { Self } } // Implement Tool trait #[::worker_types::async_trait::async_trait] impl ::worker_types::Tool for #tool_struct_name { fn name(&self) -> &str { #tool_name_str } fn description(&self) -> &str { #description } fn parameters_schema(&self) -> ::worker_types::serde_json::Value { ::worker_types::serde_json::to_value(::worker_types::schemars::schema_for!(#arg_type)).unwrap() } async fn execute(&self, args: ::worker_types::serde_json::Value) -> ::worker_types::ToolResult<::worker_types::serde_json::Value> { let typed_args: #arg_type = ::worker_types::serde_json::from_value(args)?; let result = #fn_name(typed_args).await?; // Use Display formatting instead of JSON serialization let formatted_result = format!("{}", result); Ok(::worker_types::serde_json::Value::String(formatted_result)) } } }; TokenStream::from(expanded) } fn validate_function_signature(func: &ItemFn) -> syn::Result<()> { if func.sig.asyncness.is_none() { return Err(syn::Error::new_spanned( &func.sig, "Tool function must be async", )); } if func.sig.inputs.len() != 1 { return Err(syn::Error::new_spanned( &func.sig.inputs, "Tool function must have exactly one argument", )); } let arg = &func.sig.inputs[0]; if !matches!(arg, syn::FnArg::Typed(_)) { return Err(syn::Error::new_spanned( arg, "Argument must be a typed pattern (e.g., `args: MyArgs`)", )); } if let syn::ReturnType::Default = func.sig.output { return Err(syn::Error::new_spanned( &func.sig, "Tool function must have a return type, typically Result", )); } Ok(()) } fn extract_doc_comments(attrs: &[Attribute]) -> String { let mut doc_lines = Vec::new(); for attr in attrs { if attr.path().is_ident("doc") { if let syn::Meta::NameValue(meta) = &attr.meta { if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit_str), .. }) = &meta.value { let content = lit_str.value(); let trimmed = content.trim_start(); doc_lines.push(trimmed.to_string()); } } } } if doc_lines.is_empty() { return String::new(); } doc_lines.join("\n").trim().to_string() } fn to_camel_case(snake_case: &str) -> String { snake_case .split('_') .map(|word| { let mut chars = word.chars(); match chars.next() { None => String::new(), Some(first) => first.to_uppercase().collect::() + chars.as_str(), } }) .collect() } // Hook attribute arguments parser struct HookAttributeArgs { hook_type: String, matcher: Option, } impl Parse for HookAttributeArgs { fn parse(input: ParseStream) -> syn::Result { let mut hook_type = None; let mut matcher = None; while !input.is_empty() { let name: syn::Ident = input.parse()?; input.parse::()?; let value: LitStr = input.parse()?; match name.to_string().as_str() { "hook_type" => hook_type = Some(value.value()), "matcher" => matcher = Some(value.value()), _ => return Err(syn::Error::new_spanned(name, "Unknown hook attribute")), } if input.peek(syn::Token![,]) { input.parse::()?; } } let hook_type = hook_type.ok_or_else(|| input.error("Hook type is required"))?; Ok(HookAttributeArgs { hook_type, matcher }) } } #[proc_macro_attribute] pub fn hook(attr: TokenStream, item: TokenStream) -> TokenStream { let args = parse_macro_input!(attr as HookAttributeArgs); let func = parse_macro_input!(item as ItemFn); // Validate function signature for hooks if let Err(e) = validate_hook_function_signature(&func) { return e.to_compile_error().into(); } let fn_name = &func.sig.ident; let fn_name_str = fn_name.to_string(); let hook_type = &args.hook_type; let matcher = args.matcher.as_deref().unwrap_or(""); // Generate struct name from function name let hook_struct_name = { let fn_name_str = fn_name.to_string(); let camel_case = to_camel_case(&fn_name_str); // 既に "_hook" で終わっている場合は、それを削除してから "Hook" を追加 let cleaned_name = if camel_case.ends_with("Hook") { camel_case.strip_suffix("Hook").unwrap_or(&camel_case) } else { &camel_case }; syn::Ident::new(&format!("{}Hook", cleaned_name), fn_name.span()) }; let expanded = quote! { // Keep the original function #func // Generate Hook struct pub struct #hook_struct_name; impl #hook_struct_name { pub fn new() -> Self { Self } } // Implement WorkerHook trait #[::worker_types::async_trait::async_trait] impl ::worker_types::WorkerHook for #hook_struct_name { fn name(&self) -> &str { #fn_name_str } fn hook_type(&self) -> &str { #hook_type } fn matcher(&self) -> &str { #matcher } async fn execute(&self, context: ::worker_types::HookContext) -> (::worker_types::HookContext, ::worker_types::HookResult) { #fn_name(context).await } } }; TokenStream::from(expanded) } fn validate_hook_function_signature(func: &ItemFn) -> syn::Result<()> { if func.sig.asyncness.is_none() { return Err(syn::Error::new_spanned( &func.sig, "Hook function must be async", )); } if func.sig.inputs.len() != 1 { return Err(syn::Error::new_spanned( &func.sig.inputs, "Hook function must have exactly one argument of type HookContext", )); } let arg = &func.sig.inputs[0]; if !matches!(arg, syn::FnArg::Typed(_)) { return Err(syn::Error::new_spanned( arg, "Argument must be a typed pattern (e.g., `context: HookContext`)", )); } if let syn::ReturnType::Default = func.sig.output { return Err(syn::Error::new_spanned( &func.sig, "Hook function must return (HookContext, HookResult)", )); } Ok(()) }