327 lines
9.9 KiB
Rust
327 lines
9.9 KiB
Rust
//! worker-macros - Tool生成用手続きマクロ
|
||
//!
|
||
//! `#[tool_registry]` と `#[tool]` マクロを提供し、
|
||
//! ユーザー定義のメソッドから `Tool` トレイト実装を自動生成する。
|
||
|
||
use proc_macro::TokenStream;
|
||
use quote::{format_ident, quote};
|
||
use syn::{
|
||
Attribute, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, ReturnType, Type, parse_macro_input,
|
||
};
|
||
|
||
/// `impl` ブロックに付与し、内部の `#[tool]` 属性がついたメソッドからツールを生成するマクロ。
|
||
///
|
||
/// # Example
|
||
/// ```ignore
|
||
/// #[tool_registry]
|
||
/// impl MyApp {
|
||
/// /// ユーザー情報を取得する
|
||
/// /// 指定されたIDのユーザーをDBから検索します。
|
||
/// #[tool]
|
||
/// async fn get_user(&self, user_id: String) -> Result<User, Error> { ... }
|
||
/// }
|
||
/// ```
|
||
///
|
||
/// これにより以下が生成されます:
|
||
/// - `GetUserArgs` 構造体(引数用)
|
||
/// - `Tool_get_user` 構造体(Toolラッパー)
|
||
/// - `impl Tool for Tool_get_user`
|
||
/// - `impl MyApp { fn get_user_tool(&self) -> Tool_get_user }`
|
||
#[proc_macro_attribute]
|
||
pub fn tool_registry(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||
let mut impl_block = parse_macro_input!(item as ItemImpl);
|
||
let self_ty = &impl_block.self_ty;
|
||
|
||
let mut generated_items = Vec::new();
|
||
|
||
for item in &mut impl_block.items {
|
||
if let ImplItem::Fn(method) = item {
|
||
// #[tool] 属性を探す
|
||
let mut is_tool = false;
|
||
|
||
// 属性を走査してtoolがあるか確認し、削除する
|
||
method.attrs.retain(|attr| {
|
||
if attr.path().is_ident("tool") {
|
||
is_tool = true;
|
||
false // 属性を削除
|
||
} else {
|
||
true
|
||
}
|
||
});
|
||
|
||
if is_tool {
|
||
let tool_impl = generate_tool_impl(self_ty, method);
|
||
generated_items.push(tool_impl);
|
||
}
|
||
}
|
||
}
|
||
|
||
let expanded = quote! {
|
||
#impl_block
|
||
|
||
#(#generated_items)*
|
||
};
|
||
|
||
TokenStream::from(expanded)
|
||
}
|
||
|
||
/// ドキュメントコメントから説明文を抽出
|
||
fn extract_doc_comment(attrs: &[Attribute]) -> String {
|
||
let mut lines = Vec::new();
|
||
|
||
for attr in attrs {
|
||
if attr.path().is_ident("doc") {
|
||
if let Meta::NameValue(meta) = &attr.meta {
|
||
if let syn::Expr::Lit(expr_lit) = &meta.value {
|
||
if let Lit::Str(lit_str) = &expr_lit.lit {
|
||
let line = lit_str.value();
|
||
// 先頭の空白を1つだけ除去(/// の後のスペース)
|
||
let trimmed = line.strip_prefix(' ').unwrap_or(&line);
|
||
lines.push(trimmed.to_string());
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
lines.join("\n")
|
||
}
|
||
|
||
/// #[description = "..."] 属性から説明を抽出
|
||
fn extract_description_attr(attrs: &[syn::Attribute]) -> Option<String> {
|
||
for attr in attrs {
|
||
if attr.path().is_ident("description") {
|
||
if let Meta::NameValue(meta) = &attr.meta {
|
||
if let syn::Expr::Lit(expr_lit) = &meta.value {
|
||
if let Lit::Str(lit_str) = &expr_lit.lit {
|
||
return Some(lit_str.value());
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
None
|
||
}
|
||
|
||
/// メソッドからTool実装を生成
|
||
fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::TokenStream {
|
||
let sig = &method.sig;
|
||
let method_name = &sig.ident;
|
||
let tool_name = method_name.to_string();
|
||
|
||
// 構造体名を生成(PascalCase変換)
|
||
let pascal_name = to_pascal_case(&method_name.to_string());
|
||
let tool_struct_name = format_ident!("Tool{}", pascal_name);
|
||
let args_struct_name = format_ident!("{}Args", pascal_name);
|
||
let factory_name = format_ident!("{}_tool", method_name);
|
||
|
||
// ドキュメントコメントから説明を取得
|
||
let description = extract_doc_comment(&method.attrs);
|
||
let description = if description.is_empty() {
|
||
format!("Tool: {}", tool_name)
|
||
} else {
|
||
description
|
||
};
|
||
|
||
// 引数を解析(selfを除く)
|
||
let args: Vec<_> = sig
|
||
.inputs
|
||
.iter()
|
||
.filter_map(|arg| {
|
||
if let FnArg::Typed(pat_type) = arg {
|
||
Some(pat_type)
|
||
} else {
|
||
None // selfを除外
|
||
}
|
||
})
|
||
.collect();
|
||
|
||
// 引数構造体のフィールドを生成
|
||
let arg_fields: Vec<_> = args
|
||
.iter()
|
||
.map(|pat_type| {
|
||
let pat = &pat_type.pat;
|
||
let ty = &pat_type.ty;
|
||
let desc = extract_description_attr(&pat_type.attrs);
|
||
|
||
// パターンから識別子を抽出
|
||
let field_name = if let Pat::Ident(pat_ident) = pat.as_ref() {
|
||
&pat_ident.ident
|
||
} else {
|
||
panic!("Only simple identifiers are supported for tool arguments");
|
||
};
|
||
|
||
// #[description] があればschemarsのdocに変換
|
||
if let Some(desc_str) = desc {
|
||
quote! {
|
||
#[schemars(description = #desc_str)]
|
||
pub #field_name: #ty
|
||
}
|
||
} else {
|
||
quote! {
|
||
pub #field_name: #ty
|
||
}
|
||
}
|
||
})
|
||
.collect();
|
||
|
||
// execute内で引数を展開するコード
|
||
let arg_names: Vec<_> = args
|
||
.iter()
|
||
.map(|pat_type| {
|
||
if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
|
||
let ident = &pat_ident.ident;
|
||
quote! { args.#ident }
|
||
} else {
|
||
panic!("Only simple identifiers are supported");
|
||
}
|
||
})
|
||
.collect();
|
||
|
||
// メソッドが非同期かどうか
|
||
let is_async = sig.asyncness.is_some();
|
||
|
||
// 戻り値の型を解析してResult判定
|
||
let awaiter = if is_async {
|
||
quote! { .await }
|
||
} else {
|
||
quote! {}
|
||
};
|
||
|
||
// 戻り値がResultかどうかを判定
|
||
let result_handling = if is_result_type(&sig.output) {
|
||
quote! {
|
||
match result {
|
||
Ok(val) => Ok(format!("{:?}", val)),
|
||
Err(e) => Err(::worker::tool::ToolError::ExecutionFailed(format!("{}", e))),
|
||
}
|
||
}
|
||
} else {
|
||
quote! {
|
||
Ok(format!("{:?}", result))
|
||
}
|
||
};
|
||
|
||
// 引数がない場合は空のArgs構造体を作成
|
||
let args_struct_def = if arg_fields.is_empty() {
|
||
quote! {
|
||
#[derive(serde::Deserialize, schemars::JsonSchema)]
|
||
struct #args_struct_name {}
|
||
}
|
||
} else {
|
||
quote! {
|
||
#[derive(serde::Deserialize, schemars::JsonSchema)]
|
||
struct #args_struct_name {
|
||
#(#arg_fields),*
|
||
}
|
||
}
|
||
};
|
||
|
||
// 引数がない場合のexecute処理
|
||
let execute_body = if args.is_empty() {
|
||
quote! {
|
||
// 引数なしでも空のJSONオブジェクトを許容
|
||
let _: #args_struct_name = serde_json::from_str(input_json)
|
||
.unwrap_or(#args_struct_name {});
|
||
|
||
let result = self.ctx.#method_name()#awaiter;
|
||
#result_handling
|
||
}
|
||
} else {
|
||
quote! {
|
||
let args: #args_struct_name = serde_json::from_str(input_json)
|
||
.map_err(|e| ::worker::tool::ToolError::InvalidArgument(e.to_string()))?;
|
||
|
||
let result = self.ctx.#method_name(#(#arg_names),*)#awaiter;
|
||
#result_handling
|
||
}
|
||
};
|
||
|
||
quote! {
|
||
#args_struct_def
|
||
|
||
#[derive(Clone)]
|
||
pub struct #tool_struct_name {
|
||
ctx: #self_ty,
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl ::worker::tool::Tool for #tool_struct_name {
|
||
fn name(&self) -> &str {
|
||
#tool_name
|
||
}
|
||
|
||
fn description(&self) -> &str {
|
||
#description
|
||
}
|
||
|
||
fn input_schema(&self) -> serde_json::Value {
|
||
let schema = schemars::schema_for!(#args_struct_name);
|
||
serde_json::to_value(schema).unwrap_or(serde_json::json!({}))
|
||
}
|
||
|
||
async fn execute(&self, input_json: &str) -> Result<String, ::worker::tool::ToolError> {
|
||
#execute_body
|
||
}
|
||
}
|
||
|
||
impl #self_ty {
|
||
pub fn #factory_name(&self) -> #tool_struct_name {
|
||
#tool_struct_name {
|
||
ctx: self.clone()
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 戻り値の型がResultかどうかを判定
|
||
fn is_result_type(return_type: &ReturnType) -> bool {
|
||
match return_type {
|
||
ReturnType::Default => false,
|
||
ReturnType::Type(_, ty) => {
|
||
// Type::Pathの場合、最後のセグメントが"Result"かチェック
|
||
if let Type::Path(type_path) = ty.as_ref() {
|
||
if let Some(segment) = type_path.path.segments.last() {
|
||
return segment.ident == "Result";
|
||
}
|
||
}
|
||
false
|
||
}
|
||
}
|
||
}
|
||
|
||
/// snake_case を PascalCase に変換
|
||
fn to_pascal_case(s: &str) -> String {
|
||
s.split('_')
|
||
.map(|part| {
|
||
let mut chars = part.chars();
|
||
match chars.next() {
|
||
None => String::new(),
|
||
Some(first) => first.to_uppercase().chain(chars).collect(),
|
||
}
|
||
})
|
||
.collect()
|
||
}
|
||
|
||
/// マーカー属性。`tool_registry` によって処理されるため、ここでは何もしない。
|
||
#[proc_macro_attribute]
|
||
pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||
item
|
||
}
|
||
|
||
/// 引数属性用のマーカー。パース時に`tool_registry`で解釈される。
|
||
///
|
||
/// # Example
|
||
/// ```ignore
|
||
/// #[tool]
|
||
/// async fn get_user(
|
||
/// &self,
|
||
/// #[description = "取得したいユーザーのID"] user_id: String
|
||
/// ) -> Result<User, Error> { ... }
|
||
/// ```
|
||
#[proc_macro_attribute]
|
||
pub fn description(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||
item
|
||
}
|