328 lines
9.6 KiB
Rust
328 lines
9.6 KiB
Rust
use proc_macro::TokenStream;
|
|
use quote::quote;
|
|
use syn::{
|
|
Attribute, ItemFn, LitStr,
|
|
parse::{Parse, ParseStream},
|
|
parse_macro_input,
|
|
};
|
|
|
|
/// Parsed arguments of the `#[tool(...)]` attribute.
///
/// The only recognised key is `name`; when it is absent, the `tool` macro
/// falls back to a CamelCase form of the annotated function's name.
struct ToolAttributeArgs {
    /// Explicit tool name supplied as `name = "..."`, if any.
    name: Option<String>,
}
|
|
|
|
impl Parse for ToolAttributeArgs {
|
|
fn parse(input: ParseStream) -> syn::Result<Self> {
|
|
let mut name = None;
|
|
|
|
if !input.is_empty() {
|
|
let name_ident: syn::Ident = input.parse()?;
|
|
if name_ident != "name" {
|
|
return Err(syn::Error::new_spanned(
|
|
name_ident,
|
|
"Only 'name' attribute is supported",
|
|
));
|
|
}
|
|
input.parse::<syn::Token![=]>()?;
|
|
let name_str: LitStr = input.parse()?;
|
|
name = Some(name_str.value());
|
|
}
|
|
|
|
Ok(ToolAttributeArgs { name })
|
|
}
|
|
}
|
|
|
|
#[proc_macro_attribute]
|
|
pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
|
|
let args = parse_macro_input!(attr as ToolAttributeArgs);
|
|
let func = parse_macro_input!(item as ItemFn);
|
|
|
|
let description = {
|
|
let doc_comments = extract_doc_comments(&func.attrs);
|
|
if doc_comments.is_empty() {
|
|
format!("Tool function: {}", func.sig.ident)
|
|
} else {
|
|
doc_comments
|
|
}
|
|
};
|
|
|
|
// Validate function signature
|
|
if let Err(e) = validate_function_signature(&func) {
|
|
return e.to_compile_error().into();
|
|
}
|
|
|
|
let fn_name = &func.sig.ident;
|
|
let fn_name_str = fn_name.to_string();
|
|
|
|
// Use provided name or fall back to CamelCase function name
|
|
let tool_name_str = args.name.unwrap_or_else(|| to_camel_case(&fn_name_str));
|
|
|
|
// Extract arg_type and output_type safely after validation
|
|
let arg_type = if let syn::FnArg::Typed(pat_type) = &func.sig.inputs[0] {
|
|
&pat_type.ty
|
|
} else {
|
|
// This case should be caught by validate_function_signature
|
|
return syn::Error::new_spanned(&func.sig.inputs[0], "Expected typed argument")
|
|
.to_compile_error()
|
|
.into();
|
|
};
|
|
|
|
if let syn::ReturnType::Type(_, _) = &func.sig.output {
|
|
} else {
|
|
// This case should be caught by validate_function_signature
|
|
return syn::Error::new_spanned(&func.sig.output, "Expected return type")
|
|
.to_compile_error()
|
|
.into();
|
|
};
|
|
|
|
// Generate struct name from function name (e.g., read_file -> ReadFileTool)
|
|
let tool_struct_name = {
|
|
let fn_name_str = fn_name.to_string();
|
|
let camel_case = to_camel_case(&fn_name_str);
|
|
syn::Ident::new(&format!("{}Tool", camel_case), fn_name.span())
|
|
};
|
|
|
|
let expanded = quote! {
|
|
// Keep the original function
|
|
#func
|
|
|
|
// Generate Tool struct
|
|
pub struct #tool_struct_name;
|
|
|
|
impl #tool_struct_name {
|
|
pub fn new() -> Self {
|
|
Self
|
|
}
|
|
}
|
|
|
|
// Implement Tool trait
|
|
#[::worker::types::async_trait::async_trait]
|
|
impl ::worker::types::Tool for #tool_struct_name {
|
|
fn name(&self) -> &str {
|
|
#tool_name_str
|
|
}
|
|
|
|
fn description(&self) -> &str {
|
|
#description
|
|
}
|
|
|
|
fn parameters_schema(&self) -> ::worker::types::serde_json::Value {
|
|
::worker::types::serde_json::to_value(::worker::types::schemars::schema_for!(#arg_type)).unwrap()
|
|
}
|
|
|
|
async fn execute(&self, args: ::worker::types::serde_json::Value) -> ::worker::types::ToolResult<::worker::types::serde_json::Value> {
|
|
let typed_args: #arg_type = ::worker::types::serde_json::from_value(args)?;
|
|
let result = #fn_name(typed_args).await?;
|
|
// Use Display formatting instead of JSON serialization
|
|
let formatted_result = format!("{}", result);
|
|
Ok(::worker::types::serde_json::Value::String(formatted_result))
|
|
}
|
|
}
|
|
|
|
};
|
|
|
|
TokenStream::from(expanded)
|
|
}
|
|
|
|
fn validate_function_signature(func: &ItemFn) -> syn::Result<()> {
|
|
if func.sig.asyncness.is_none() {
|
|
return Err(syn::Error::new_spanned(
|
|
&func.sig,
|
|
"Tool function must be async",
|
|
));
|
|
}
|
|
|
|
if func.sig.inputs.len() != 1 {
|
|
return Err(syn::Error::new_spanned(
|
|
&func.sig.inputs,
|
|
"Tool function must have exactly one argument",
|
|
));
|
|
}
|
|
|
|
let arg = &func.sig.inputs[0];
|
|
if !matches!(arg, syn::FnArg::Typed(_)) {
|
|
return Err(syn::Error::new_spanned(
|
|
arg,
|
|
"Argument must be a typed pattern (e.g., `args: MyArgs`)",
|
|
));
|
|
}
|
|
|
|
if let syn::ReturnType::Default = func.sig.output {
|
|
return Err(syn::Error::new_spanned(
|
|
&func.sig,
|
|
"Tool function must have a return type, typically Result<T, E>",
|
|
));
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn extract_doc_comments(attrs: &[Attribute]) -> String {
|
|
let mut doc_lines = Vec::new();
|
|
|
|
for attr in attrs {
|
|
if attr.path().is_ident("doc") {
|
|
if let syn::Meta::NameValue(meta) = &attr.meta {
|
|
if let syn::Expr::Lit(syn::ExprLit {
|
|
lit: syn::Lit::Str(lit_str),
|
|
..
|
|
}) = &meta.value
|
|
{
|
|
let content = lit_str.value();
|
|
let trimmed = content.trim_start();
|
|
doc_lines.push(trimmed.to_string());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if doc_lines.is_empty() {
|
|
return String::new();
|
|
}
|
|
|
|
doc_lines.join("\n").trim().to_string()
|
|
}
|
|
|
|
/// Converts a `snake_case` identifier into `UpperCamelCase`
/// (e.g. `read_file` -> `ReadFile`).
///
/// A leading raw-identifier prefix (`r#`, as produced by `Ident::to_string`
/// for raw identifiers) is dropped first. Without this, `r#match` would become
/// `R#match`, which is not a valid identifier.
///
/// Empty segments are skipped. These come from leading, trailing or repeated
/// underscores.
fn to_camel_case(snake_case: &str) -> String {
    let snake_case = snake_case.strip_prefix("r#").unwrap_or(snake_case);

    let mut camel = String::with_capacity(snake_case.len());
    for word in snake_case.split('_') {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            camel.extend(first.to_uppercase());
            camel.push_str(chars.as_str());
        }
    }
    camel
}
|
|
|
|
/// Parsed arguments of the `#[hook(...)]` attribute.
struct HookAttributeArgs {
    /// Value of the required `hook_type = "..."` entry.
    hook_type: String,
    /// Value of the optional `matcher = "..."` entry, if one was given.
    matcher: Option<String>,
}
|
|
|
|
impl Parse for HookAttributeArgs {
|
|
fn parse(input: ParseStream) -> syn::Result<Self> {
|
|
let mut hook_type = None;
|
|
let mut matcher = None;
|
|
|
|
while !input.is_empty() {
|
|
let name: syn::Ident = input.parse()?;
|
|
input.parse::<syn::Token![=]>()?;
|
|
let value: LitStr = input.parse()?;
|
|
|
|
match name.to_string().as_str() {
|
|
"hook_type" => hook_type = Some(value.value()),
|
|
"matcher" => matcher = Some(value.value()),
|
|
_ => return Err(syn::Error::new_spanned(name, "Unknown hook attribute")),
|
|
}
|
|
|
|
if input.peek(syn::Token![,]) {
|
|
input.parse::<syn::Token![,]>()?;
|
|
}
|
|
}
|
|
|
|
let hook_type = hook_type.ok_or_else(|| input.error("Hook type is required"))?;
|
|
|
|
Ok(HookAttributeArgs { hook_type, matcher })
|
|
}
|
|
}
|
|
|
|
#[proc_macro_attribute]
|
|
pub fn hook(attr: TokenStream, item: TokenStream) -> TokenStream {
|
|
let args = parse_macro_input!(attr as HookAttributeArgs);
|
|
let func = parse_macro_input!(item as ItemFn);
|
|
|
|
// Validate function signature for hooks
|
|
if let Err(e) = validate_hook_function_signature(&func) {
|
|
return e.to_compile_error().into();
|
|
}
|
|
|
|
let fn_name = &func.sig.ident;
|
|
let fn_name_str = fn_name.to_string();
|
|
let hook_type = &args.hook_type;
|
|
let matcher = args.matcher.as_deref().unwrap_or("");
|
|
|
|
// Generate struct name from function name
|
|
let hook_struct_name = {
|
|
let fn_name_str = fn_name.to_string();
|
|
let camel_case = to_camel_case(&fn_name_str);
|
|
// 既に "_hook" で終わっている場合は、それを削除してから "Hook" を追加
|
|
let cleaned_name = if camel_case.ends_with("Hook") {
|
|
camel_case.strip_suffix("Hook").unwrap_or(&camel_case)
|
|
} else {
|
|
&camel_case
|
|
};
|
|
syn::Ident::new(&format!("{}Hook", cleaned_name), fn_name.span())
|
|
};
|
|
|
|
let expanded = quote! {
|
|
// Keep the original function
|
|
#func
|
|
|
|
// Generate Hook struct
|
|
pub struct #hook_struct_name;
|
|
|
|
impl #hook_struct_name {
|
|
pub fn new() -> Self {
|
|
Self
|
|
}
|
|
}
|
|
|
|
// Implement WorkerHook trait
|
|
#[::worker::types::async_trait::async_trait]
|
|
impl ::worker::types::WorkerHook for #hook_struct_name {
|
|
fn name(&self) -> &str {
|
|
#fn_name_str
|
|
}
|
|
|
|
fn hook_type(&self) -> &str {
|
|
#hook_type
|
|
}
|
|
|
|
fn matcher(&self) -> &str {
|
|
#matcher
|
|
}
|
|
|
|
async fn execute(&self, context: ::worker::types::HookContext) -> (::worker::types::HookContext, ::worker::types::HookResult) {
|
|
#fn_name(context).await
|
|
}
|
|
}
|
|
};
|
|
|
|
TokenStream::from(expanded)
|
|
}
|
|
|
|
fn validate_hook_function_signature(func: &ItemFn) -> syn::Result<()> {
|
|
if func.sig.asyncness.is_none() {
|
|
return Err(syn::Error::new_spanned(
|
|
&func.sig,
|
|
"Hook function must be async",
|
|
));
|
|
}
|
|
|
|
if func.sig.inputs.len() != 1 {
|
|
return Err(syn::Error::new_spanned(
|
|
&func.sig.inputs,
|
|
"Hook function must have exactly one argument of type HookContext",
|
|
));
|
|
}
|
|
|
|
let arg = &func.sig.inputs[0];
|
|
if !matches!(arg, syn::FnArg::Typed(_)) {
|
|
return Err(syn::Error::new_spanned(
|
|
arg,
|
|
"Argument must be a typed pattern (e.g., `context: HookContext`)",
|
|
));
|
|
}
|
|
|
|
if let syn::ReturnType::Default = func.sig.output {
|
|
return Err(syn::Error::new_spanned(
|
|
&func.sig,
|
|
"Hook function must return (HookContext, HookResult)",
|
|
));
|
|
}
|
|
|
|
Ok(())
|
|
}
|