llm_worker_rs/worker-macros/src/lib.rs
2026-01-07 22:04:44 +09:00

327 lines
9.9 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! worker-macros - Tool生成用手続きマクロ
//!
//! `#[tool_registry]` と `#[tool]` マクロを提供し、
//! ユーザー定義のメソッドから `Tool` トレイト実装を自動生成する。
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{
Attribute, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, ReturnType, Type, parse_macro_input,
};
/// `impl` ブロックに付与し、内部の `#[tool]` 属性がついたメソッドからツールを生成するマクロ。
///
/// # Example
/// ```ignore
/// #[tool_registry]
/// impl MyApp {
/// /// ユーザー情報を取得する
/// /// 指定されたIDのユーザーをDBから検索します。
/// #[tool]
/// async fn get_user(&self, user_id: String) -> Result<User, Error> { ... }
/// }
/// ```
///
/// これにより以下が生成されます:
/// - `GetUserArgs` 構造体(引数用)
/// - `Tool_get_user` 構造体Toolラッパー
/// - `impl Tool for Tool_get_user`
/// - `impl MyApp { fn get_user_tool(&self) -> Tool_get_user }`
#[proc_macro_attribute]
pub fn tool_registry(_attr: TokenStream, item: TokenStream) -> TokenStream {
let mut impl_block = parse_macro_input!(item as ItemImpl);
let self_ty = &impl_block.self_ty;
let mut generated_items = Vec::new();
for item in &mut impl_block.items {
if let ImplItem::Fn(method) = item {
// #[tool] 属性を探す
let mut is_tool = false;
// 属性を走査してtoolがあるか確認し、削除する
method.attrs.retain(|attr| {
if attr.path().is_ident("tool") {
is_tool = true;
false // 属性を削除
} else {
true
}
});
if is_tool {
let tool_impl = generate_tool_impl(self_ty, method);
generated_items.push(tool_impl);
}
}
}
let expanded = quote! {
#impl_block
#(#generated_items)*
};
TokenStream::from(expanded)
}
/// ドキュメントコメントから説明文を抽出
fn extract_doc_comment(attrs: &[Attribute]) -> String {
let mut lines = Vec::new();
for attr in attrs {
if attr.path().is_ident("doc") {
if let Meta::NameValue(meta) = &attr.meta {
if let syn::Expr::Lit(expr_lit) = &meta.value {
if let Lit::Str(lit_str) = &expr_lit.lit {
let line = lit_str.value();
// 先頭の空白を1つだけ除去/// の後のスペース)
let trimmed = line.strip_prefix(' ').unwrap_or(&line);
lines.push(trimmed.to_string());
}
}
}
}
}
lines.join("\n")
}
/// #[description = "..."] 属性から説明を抽出
fn extract_description_attr(attrs: &[syn::Attribute]) -> Option<String> {
for attr in attrs {
if attr.path().is_ident("description") {
if let Meta::NameValue(meta) = &attr.meta {
if let syn::Expr::Lit(expr_lit) = &meta.value {
if let Lit::Str(lit_str) = &expr_lit.lit {
return Some(lit_str.value());
}
}
}
}
}
None
}
/// メソッドからTool実装を生成
fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::TokenStream {
let sig = &method.sig;
let method_name = &sig.ident;
let tool_name = method_name.to_string();
// 構造体名を生成PascalCase変換
let pascal_name = to_pascal_case(&method_name.to_string());
let tool_struct_name = format_ident!("Tool{}", pascal_name);
let args_struct_name = format_ident!("{}Args", pascal_name);
let factory_name = format_ident!("{}_tool", method_name);
// ドキュメントコメントから説明を取得
let description = extract_doc_comment(&method.attrs);
let description = if description.is_empty() {
format!("Tool: {}", tool_name)
} else {
description
};
// 引数を解析selfを除く
let args: Vec<_> = sig
.inputs
.iter()
.filter_map(|arg| {
if let FnArg::Typed(pat_type) = arg {
Some(pat_type)
} else {
None // selfを除外
}
})
.collect();
// 引数構造体のフィールドを生成
let arg_fields: Vec<_> = args
.iter()
.map(|pat_type| {
let pat = &pat_type.pat;
let ty = &pat_type.ty;
let desc = extract_description_attr(&pat_type.attrs);
// パターンから識別子を抽出
let field_name = if let Pat::Ident(pat_ident) = pat.as_ref() {
&pat_ident.ident
} else {
panic!("Only simple identifiers are supported for tool arguments");
};
// #[description] があればschemarsのdocに変換
if let Some(desc_str) = desc {
quote! {
#[schemars(description = #desc_str)]
pub #field_name: #ty
}
} else {
quote! {
pub #field_name: #ty
}
}
})
.collect();
// execute内で引数を展開するコード
let arg_names: Vec<_> = args
.iter()
.map(|pat_type| {
if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
let ident = &pat_ident.ident;
quote! { args.#ident }
} else {
panic!("Only simple identifiers are supported");
}
})
.collect();
// メソッドが非同期かどうか
let is_async = sig.asyncness.is_some();
// 戻り値の型を解析してResult判定
let awaiter = if is_async {
quote! { .await }
} else {
quote! {}
};
// 戻り値がResultかどうかを判定
let result_handling = if is_result_type(&sig.output) {
quote! {
match result {
Ok(val) => Ok(format!("{:?}", val)),
Err(e) => Err(worker_types::ToolError::ExecutionFailed(format!("{}", e))),
}
}
} else {
quote! {
Ok(format!("{:?}", result))
}
};
// 引数がない場合は空のArgs構造体を作成
let args_struct_def = if arg_fields.is_empty() {
quote! {
#[derive(serde::Deserialize, schemars::JsonSchema)]
struct #args_struct_name {}
}
} else {
quote! {
#[derive(serde::Deserialize, schemars::JsonSchema)]
struct #args_struct_name {
#(#arg_fields),*
}
}
};
// 引数がない場合のexecute処理
let execute_body = if args.is_empty() {
quote! {
// 引数なしでも空のJSONオブジェクトを許容
let _: #args_struct_name = serde_json::from_str(input_json)
.unwrap_or(#args_struct_name {});
let result = self.ctx.#method_name()#awaiter;
#result_handling
}
} else {
quote! {
let args: #args_struct_name = serde_json::from_str(input_json)
.map_err(|e| worker_types::ToolError::InvalidArgument(e.to_string()))?;
let result = self.ctx.#method_name(#(#arg_names),*)#awaiter;
#result_handling
}
};
quote! {
#args_struct_def
#[derive(Clone)]
pub struct #tool_struct_name {
ctx: #self_ty,
}
#[async_trait::async_trait]
impl worker_types::Tool for #tool_struct_name {
fn name(&self) -> &str {
#tool_name
}
fn description(&self) -> &str {
#description
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(#args_struct_name);
serde_json::to_value(schema).unwrap_or(serde_json::json!({}))
}
async fn execute(&self, input_json: &str) -> Result<String, worker_types::ToolError> {
#execute_body
}
}
impl #self_ty {
pub fn #factory_name(&self) -> #tool_struct_name {
#tool_struct_name {
ctx: self.clone()
}
}
}
}
}
/// 戻り値の型がResultかどうかを判定
fn is_result_type(return_type: &ReturnType) -> bool {
match return_type {
ReturnType::Default => false,
ReturnType::Type(_, ty) => {
// Type::Pathの場合、最後のセグメントが"Result"かチェック
if let Type::Path(type_path) = ty.as_ref() {
if let Some(segment) = type_path.path.segments.last() {
return segment.ident == "Result";
}
}
false
}
}
}
/// snake_case を PascalCase に変換
fn to_pascal_case(s: &str) -> String {
s.split('_')
.map(|part| {
let mut chars = part.chars();
match chars.next() {
None => String::new(),
Some(first) => first.to_uppercase().chain(chars).collect(),
}
})
.collect()
}
/// マーカー属性。`tool_registry` によって処理されるため、ここでは何もしない。
#[proc_macro_attribute]
pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
item
}
/// 引数属性用のマーカー。パース時に`tool_registry`で解釈される。
///
/// # Example
/// ```ignore
/// #[tool]
/// async fn get_user(
/// &self,
/// #[description = "取得したいユーザーのID"] user_id: String
/// ) -> Result<User, Error> { ... }
/// ```
#[proc_macro_attribute]
pub fn description(_attr: TokenStream, item: TokenStream) -> TokenStream {
item
}