diff --git a/Cargo.lock b/Cargo.lock index bc5c1dd..7686203 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1908,6 +1908,7 @@ dependencies = [ "eventsource-stream", "futures", "reqwest", + "schemars", "serde", "serde_json", "tempfile", diff --git a/docs/spec/worker_design.md b/docs/spec/worker_design.md index 59ac333..300f7ce 100644 --- a/docs/spec/worker_design.md +++ b/docs/spec/worker_design.md @@ -228,4 +228,209 @@ pub enum TurnResult { * ただし、リアルタイム性を重視する場合(ストリーミング中にToolを実行開始等)は将来的な拡張とするが、現状は「結果が揃うのを待って」という要件に従い、収集フェーズと実行フェーズを分ける。 3. **worker-macros**: - * `syn`, `quote` を用いて、関数定義から `Tool` トレイト実装と `InputInputSchema` (schemars利用) を生成。 + * `syn`, `quote` を用いて、関数定義から `Tool` トレイト実装と `InputSchema` (schemars利用) を生成。 + +## Worker Event API 設計 + +### 背景と目的 + +Workerは内部でイベントを処理し結果を返しますが、UIへのストリーミング表示やリアルタイムフィードバックには、イベントを外部に公開する仕組みが必要です。 + +**要件**: +1. テキストデルタをリアルタイムでUIに表示 +2. ツール呼び出しの進行状況を表示 +3. ブロック完了時に累積結果を受け取る + +### 設計思想 + +Worker APIは **Timeline層のHandler機構の薄いラッパー** として設計します。 + +| 層 | 目的 | 提供するもの | +|---|------|-------------| +| **Handler (Timeline層)** | 内部実装、役割分離 | スコープ管理 + Deltaイベント | +| **Worker Event API** | ユーザー向け利便性 | Handler露出 + Completeイベント追加 | + +Handlerのスコープ管理パターン(Start→Delta→End)をそのまま活かしつつ、累積済みのCompleteイベントも追加提供します。 + +### APIパターン + +#### 1. 個別登録: `worker.on_*(handler)` + +Timelineの`on_*`メソッドを直接露出。必要なイベントだけを個別に登録可能にする。 + +```rust +// ブロックイベント(スコープ管理あり) +worker.on_text_block(my_text_handler); // Handler +worker.on_tool_use_block(my_tool_handler); // Handler + +// 単発イベント(スコープ = ()) +worker.on_usage(my_usage_handler); // Handler +worker.on_status(my_status_handler); // Handler + +// 累積イベント(Worker層で追加、スコープ = ()) +worker.on_text_complete(my_complete_handler); // Handler +worker.on_tool_call_complete(my_tool_complete); // Handler +``` + +#### 2. 一括登録: `worker.subscribe(subscriber)` + +`WorkerSubscriber`トレイトを実装し、全ハンドラをまとめて登録。 + +```rust +/// 統合Subscriberトレイト +pub trait WorkerSubscriber: Send { + // スコープ型(ブロックイベント用) + type TextBlockScope: Default + Send; + type ToolUseBlockScope: Default + Send; + + // === ブロックイベント(スコープ管理あり)=== + fn on_text_block( + &mut self, + _scope: &mut Self::TextBlockScope, + _event: &TextBlockEvent, + ) {} + + fn on_tool_use_block( + &mut self, + _scope: &mut Self::ToolUseBlockScope, + _event: &ToolUseBlockEvent, + ) {} + + // === 単発イベント === + fn on_usage(&mut self, _event: &UsageEvent) {} + fn on_status(&mut self, _event: &StatusEvent) {} + fn on_error(&mut self, _event: &ErrorEvent) {} + + // === 累積イベント(Worker層で追加)=== + fn on_text_complete(&mut self, _text: &str) {} + fn on_tool_call_complete(&mut self, _call: &ToolCall) {} + + // === ターン制御 === + fn on_turn_start(&mut self, _turn: usize) {} + fn on_turn_end(&mut self, _turn: usize) {} +} +``` + +### 使用例: WorkerSubscriber + +```rust +struct MyUI { + chat_view: ChatView, +} + +impl WorkerSubscriber for MyUI { + type TextBlockScope = TextComponent; + type ToolUseBlockScope = ToolComponent; + + fn on_text_block(&mut self, comp: &mut TextComponent, event: &TextBlockEvent) { + match event { + TextBlockEvent::Start(_) => { + // スコープ開始時にコンポーネント初期化(Defaultで自動生成) + } + TextBlockEvent::Delta(text) => { + comp.append(text); + self.chat_view.update(comp); + } + TextBlockEvent::Stop(_) => { + comp.set_immutable(); + // スコープ終了後に自動破棄 + } + } + } + + fn on_text_complete(&mut self, text: &str) { + // 累積済みテキストを履歴に保存 + self.chat_view.add_to_history(text); + } + + fn on_tool_use_block(&mut self, comp: &mut ToolComponent, event: &ToolUseBlockEvent) { + match event { + ToolUseBlockEvent::Start(start) => { + comp.set_name(&start.name); + self.chat_view.show_tool_indicator(comp); + } + ToolUseBlockEvent::InputJsonDelta(delta) => { + comp.append_input(delta); + } + ToolUseBlockEvent::Stop(_) => { + comp.finalize(); + } + } + } + + fn on_tool_call_complete(&mut self, call: &ToolCall) { + self.chat_view.update_tool_result(&call.name, &call.input); + } +} + +// Worker に登録 +let mut worker = Worker::new(client); +worker.subscribe(MyUI::new()); + +let result = worker.run(messages).await?; +``` + +### 使用例: 個別登録 + +```rust +// シンプルなクロージャベース(将来的な糖衣構文として検討) +worker.on_text_complete(|text: &str| { + println!("Complete: {}", text); +}); + +// または Handler実装 +struct TextLogger; +impl Handler for TextLogger { + type Scope = (); + fn on_event(&mut self, _: &mut (), text: &String) { + println!("Complete: {}", text); + } +} +worker.on_text_complete(TextLogger); +``` + +### 累積イベント用Kind定義 + +```rust +/// テキスト完了イベント用Kind +pub struct TextCompleteKind; +impl Kind for TextCompleteKind { + type Event = String; // 累積済みテキスト +} + +/// ツール呼び出し完了イベント用Kind +pub struct ToolCallCompleteKind; +impl Kind for ToolCallCompleteKind { + type Event = ToolCall; // 完全なToolCall +} +``` + +### 内部実装 + +WorkerはSubscriberを内部で分解し、各Kindに対応するHandlerとしてTimelineに登録します。 +累積イベント(TextComplete等)はWorker層で処理し、ブロック終了時に累積結果を渡します。 + +```rust +impl Worker { + pub fn subscribe(&mut self, subscriber: S) { + let subscriber = Arc::new(Mutex::new(subscriber)); + + // TextBlock用ハンドラを登録 + self.timeline.on_text_block(TextBlockAdapter { + subscriber: subscriber.clone(), + }); + + // 累積イベント用の内部ハンドラも登録 + // (TextBlockCollectorのStop時にon_text_completeを呼ぶ) + } +} +``` + +### 設計上のポイント + +1. **Handlerの再利用**: 既存のHandler traitをそのまま活用 +2. **スコープ管理の維持**: ブロックイベントはStart→Delta→Endのライフサイクルを保持 +3. **選択的購読**: on_*で必要なイベントだけ、またはSubscriberで一括 +4. **累積イベントの追加**: Worker層でComplete系イベントを追加提供 +5. **後方互換性**: 従来の`run()`も引き続き使用可能 + diff --git a/worker-macros/src/lib.rs b/worker-macros/src/lib.rs index 921d201..c580437 100644 --- a/worker-macros/src/lib.rs +++ b/worker-macros/src/lib.rs @@ -1,6 +1,13 @@ +//! worker-macros - Tool生成用手続きマクロ +//! +//! `#[tool_registry]` と `#[tool]` マクロを提供し、 +//! ユーザー定義のメソッドから `Tool` トレイト実装を自動生成する。 + use proc_macro::TokenStream; -use quote::quote; -use syn::{parse_macro_input, ImplItem, ItemImpl}; +use quote::{format_ident, quote}; +use syn::{ + parse_macro_input, Attribute, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, ReturnType, Type, +}; /// `impl` ブロックに付与し、内部の `#[tool]` 属性がついたメソッドからツールを生成するマクロ。 /// @@ -8,10 +15,18 @@ use syn::{parse_macro_input, ImplItem, ItemImpl}; /// ```ignore /// #[tool_registry] /// impl MyApp { +/// /// ユーザー情報を取得する +/// /// 指定されたIDのユーザーをDBから検索します。 /// #[tool] -/// async fn my_function(&self, arg: String) -> Result { ... } +/// async fn get_user(&self, user_id: String) -> Result { ... } /// } /// ``` +/// +/// これにより以下が生成されます: +/// - `GetUserArgs` 構造体(引数用) +/// - `Tool_get_user` 構造体(Toolラッパー) +/// - `impl Tool for Tool_get_user` +/// - `impl MyApp { fn get_user_tool(&self) -> Tool_get_user }` #[proc_macro_attribute] pub fn tool_registry(_attr: TokenStream, item: TokenStream) -> TokenStream { let mut impl_block = parse_macro_input!(item as ItemImpl); @@ -23,76 +38,19 @@ pub fn tool_registry(_attr: TokenStream, item: TokenStream) -> TokenStream { if let ImplItem::Fn(method) = item { // #[tool] 属性を探す let mut is_tool = false; - let mut _description = String::new(); - + // 属性を走査してtoolがあるか確認し、削除する - // 同時にドキュメントコメントから説明を取得 method.attrs.retain(|attr| { if attr.path().is_ident("tool") { is_tool = true; false // 属性を削除 - } else if attr.path().is_ident("doc") { - // TODO: docコメントのパース - true } else { true } }); if is_tool { - let sig = &method.sig; - let method_name = &sig.ident; - let tool_name = method_name.to_string(); - let tool_struct_name = syn::Ident::new( - &format!("Tool_{}", method_name), - method_name.span(), - ); - - let factory_name = syn::Ident::new( - &format!("{}_tool", method_name), - method_name.span(), - ); - - // TODO: 引数の解析とArgs構造体の生成 - // TODO: descriptionの取得 - - // 仮の実装: Contextを抱えるTool構造体を作成 - let tool_impl = quote! { - #[derive(Clone)] - pub struct #tool_struct_name { - ctx: #self_ty, - } - - #[async_trait::async_trait] - impl worker_types::Tool for #tool_struct_name { - fn name(&self) -> &str { - #tool_name - } - - fn description(&self) -> &str { - "TODO: description from doc comments" - } - - fn input_schema(&self) -> serde_json::Value { - serde_json::json!({}) // TODO: schemars - } - - async fn execute(&self, input_json: &str) -> Result { - // TODO: Deserialize args and call check - // self.ctx.#method_name(...) - Ok("Not implemented yet".to_string()) - } - } - - impl #self_ty { - pub fn #factory_name(&self) -> #tool_struct_name { - #tool_struct_name { - ctx: self.clone() - } - } - } - }; - + let tool_impl = generate_tool_impl(self_ty, method); generated_items.push(tool_impl); } } @@ -100,15 +58,269 @@ pub fn tool_registry(_attr: TokenStream, item: TokenStream) -> TokenStream { let expanded = quote! { #impl_block - + #(#generated_items)* }; TokenStream::from(expanded) } +/// ドキュメントコメントから説明文を抽出 +fn extract_doc_comment(attrs: &[Attribute]) -> String { + let mut lines = Vec::new(); + + for attr in attrs { + if attr.path().is_ident("doc") { + if let Meta::NameValue(meta) = &attr.meta { + if let syn::Expr::Lit(expr_lit) = &meta.value { + if let Lit::Str(lit_str) = &expr_lit.lit { + let line = lit_str.value(); + // 先頭の空白を1つだけ除去(/// の後のスペース) + let trimmed = line.strip_prefix(' ').unwrap_or(&line); + lines.push(trimmed.to_string()); + } + } + } + } + } + + lines.join("\n") +} + +/// #[description = "..."] 属性から説明を抽出 +fn extract_description_attr(attrs: &[syn::Attribute]) -> Option { + for attr in attrs { + if attr.path().is_ident("description") { + if let Meta::NameValue(meta) = &attr.meta { + if let syn::Expr::Lit(expr_lit) = &meta.value { + if let Lit::Str(lit_str) = &expr_lit.lit { + return Some(lit_str.value()); + } + } + } + } + } + None +} + +/// メソッドからTool実装を生成 +fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::TokenStream { + let sig = &method.sig; + let method_name = &sig.ident; + let tool_name = method_name.to_string(); + + // 構造体名を生成(PascalCase変換) + let pascal_name = to_pascal_case(&method_name.to_string()); + let tool_struct_name = format_ident!("Tool{}", pascal_name); + let args_struct_name = format_ident!("{}Args", pascal_name); + let factory_name = format_ident!("{}_tool", method_name); + + // ドキュメントコメントから説明を取得 + let description = extract_doc_comment(&method.attrs); + let description = if description.is_empty() { + format!("Tool: {}", tool_name) + } else { + description + }; + + // 引数を解析(selfを除く) + let args: Vec<_> = sig + .inputs + .iter() + .filter_map(|arg| { + if let FnArg::Typed(pat_type) = arg { + Some(pat_type) + } else { + None // selfを除外 + } + }) + .collect(); + + // 引数構造体のフィールドを生成 + let arg_fields: Vec<_> = args + .iter() + .map(|pat_type| { + let pat = &pat_type.pat; + let ty = &pat_type.ty; + let desc = extract_description_attr(&pat_type.attrs); + + // パターンから識別子を抽出 + let field_name = if let Pat::Ident(pat_ident) = pat.as_ref() { + &pat_ident.ident + } else { + panic!("Only simple identifiers are supported for tool arguments"); + }; + + // #[description] があればschemarsのdocに変換 + if let Some(desc_str) = desc { + quote! { + #[schemars(description = #desc_str)] + pub #field_name: #ty + } + } else { + quote! { + pub #field_name: #ty + } + } + }) + .collect(); + + // execute内で引数を展開するコード + let arg_names: Vec<_> = args + .iter() + .map(|pat_type| { + if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() { + let ident = &pat_ident.ident; + quote! { args.#ident } + } else { + panic!("Only simple identifiers are supported"); + } + }) + .collect(); + + // メソッドが非同期かどうか + let is_async = sig.asyncness.is_some(); + + // 戻り値の型を解析してResult判定 + let awaiter = if is_async { + quote! { .await } + } else { + quote! {} + }; + + // 戻り値がResultかどうかを判定 + let result_handling = if is_result_type(&sig.output) { + quote! { + match result { + Ok(val) => Ok(format!("{:?}", val)), + Err(e) => Err(worker_types::ToolError::ExecutionFailed(format!("{}", e))), + } + } + } else { + quote! { + Ok(format!("{:?}", result)) + } + }; + + // 引数がない場合は空のArgs構造体を作成 + let args_struct_def = if arg_fields.is_empty() { + quote! { + #[derive(serde::Deserialize, schemars::JsonSchema)] + struct #args_struct_name {} + } + } else { + quote! { + #[derive(serde::Deserialize, schemars::JsonSchema)] + struct #args_struct_name { + #(#arg_fields),* + } + } + }; + + // 引数がない場合のexecute処理 + let execute_body = if args.is_empty() { + quote! { + // 引数なしでも空のJSONオブジェクトを許容 + let _: #args_struct_name = serde_json::from_str(input_json) + .unwrap_or(#args_struct_name {}); + + let result = self.ctx.#method_name()#awaiter; + #result_handling + } + } else { + quote! { + let args: #args_struct_name = serde_json::from_str(input_json) + .map_err(|e| worker_types::ToolError::InvalidArgument(e.to_string()))?; + + let result = self.ctx.#method_name(#(#arg_names),*)#awaiter; + #result_handling + } + }; + + quote! { + #args_struct_def + + #[derive(Clone)] + pub struct #tool_struct_name { + ctx: #self_ty, + } + + #[async_trait::async_trait] + impl worker_types::Tool for #tool_struct_name { + fn name(&self) -> &str { + #tool_name + } + + fn description(&self) -> &str { + #description + } + + fn input_schema(&self) -> serde_json::Value { + let schema = schemars::schema_for!(#args_struct_name); + serde_json::to_value(schema).unwrap_or(serde_json::json!({})) + } + + async fn execute(&self, input_json: &str) -> Result { + #execute_body + } + } + + impl #self_ty { + pub fn #factory_name(&self) -> #tool_struct_name { + #tool_struct_name { + ctx: self.clone() + } + } + } + } +} + +/// 戻り値の型がResultかどうかを判定 +fn is_result_type(return_type: &ReturnType) -> bool { + match return_type { + ReturnType::Default => false, + ReturnType::Type(_, ty) => { + // Type::Pathの場合、最後のセグメントが"Result"かチェック + if let Type::Path(type_path) = ty.as_ref() { + if let Some(segment) = type_path.path.segments.last() { + return segment.ident == "Result"; + } + } + false + } + } +} + +/// snake_case を PascalCase に変換 +fn to_pascal_case(s: &str) -> String { + s.split('_') + .map(|part| { + let mut chars = part.chars(); + match chars.next() { + None => String::new(), + Some(first) => first.to_uppercase().chain(chars).collect(), + } + }) + .collect() +} + /// マーカー属性。`tool_registry` によって処理されるため、ここでは何もしない。 #[proc_macro_attribute] pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream { item } + +/// 引数属性用のマーカー。パース時に`tool_registry`で解釈される。 +/// +/// # Example +/// ```ignore +/// #[tool] +/// async fn get_user( +/// &self, +/// #[description = "取得したいユーザーのID"] user_id: String +/// ) -> Result { ... } +/// ``` +#[proc_macro_attribute] +pub fn description(_attr: TokenStream, item: TokenStream) -> TokenStream { + item +} diff --git a/worker/Cargo.toml b/worker/Cargo.toml index d70c9e0..1cecdaf 100644 --- a/worker/Cargo.toml +++ b/worker/Cargo.toml @@ -16,4 +16,5 @@ worker-macros = { path = "../worker-macros" } worker-types = { path = "../worker-types" } [dev-dependencies] +schemars = "1.2.0" tempfile = "3.24.0" diff --git a/worker/src/lib.rs b/worker/src/lib.rs index 994ea8b..933c58e 100644 --- a/worker/src/lib.rs +++ b/worker/src/lib.rs @@ -7,10 +7,12 @@ //! - 型消去されたHandler実装 pub mod llm_client; +mod text_block_collector; mod timeline; mod tool_call_collector; mod worker; +pub use text_block_collector::TextBlockCollector; pub use timeline::*; pub use tool_call_collector::ToolCallCollector; pub use worker::*; diff --git a/worker/src/text_block_collector.rs b/worker/src/text_block_collector.rs new file mode 100644 index 0000000..7a28663 --- /dev/null +++ b/worker/src/text_block_collector.rs @@ -0,0 +1,131 @@ +//! TextBlockCollector - テキストブロック収集用ハンドラ +//! +//! TimelineのTextBlockHandler として登録され、 +//! ストリーム中のテキストブロックを収集する。 + +use std::sync::{Arc, Mutex}; +use worker_types::{Handler, TextBlockEvent, TextBlockKind}; + +/// TextBlockから収集したテキスト情報を保持 +#[derive(Debug, Default)] +pub struct TextCollectorState { + /// 蓄積中のテキスト + buffer: String, +} + +/// TextBlockCollector - テキストブロックハンドラ +/// +/// Timelineに登録してTextBlockイベントを受信し、 +/// 完了したテキストブロックを収集する。 +#[derive(Clone)] +pub struct TextBlockCollector { + /// 収集されたテキストブロック + collected: Arc>>, +} + +impl TextBlockCollector { + /// 新しいTextBlockCollectorを作成 + pub fn new() -> Self { + Self { + collected: Arc::new(Mutex::new(Vec::new())), + } + } + + /// 収集されたテキストを取得してクリア + pub fn take_collected(&self) -> Vec { + let mut guard = self.collected.lock().unwrap(); + std::mem::take(&mut *guard) + } + + /// 収集されたテキストの参照を取得 + pub fn collected(&self) -> Vec { + self.collected.lock().unwrap().clone() + } + + /// 収集されたテキストがあるかどうか + pub fn has_content(&self) -> bool { + !self.collected.lock().unwrap().is_empty() + } + + /// 収集をクリア + pub fn clear(&self) { + self.collected.lock().unwrap().clear(); + } +} + +impl Default for TextBlockCollector { + fn default() -> Self { + Self::new() + } +} + +impl Handler for TextBlockCollector { + type Scope = TextCollectorState; + + fn on_event(&mut self, scope: &mut Self::Scope, event: &TextBlockEvent) { + match event { + TextBlockEvent::Start(_) => { + scope.buffer.clear(); + } + TextBlockEvent::Delta(text) => { + scope.buffer.push_str(text); + } + TextBlockEvent::Stop(_) => { + // ブロック完了時にテキストを確定 + if !scope.buffer.is_empty() { + let text = std::mem::take(&mut scope.buffer); + self.collected.lock().unwrap().push(text); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Timeline; + use worker_types::Event; + + /// TextBlockCollectorが単一のテキストブロックを正しく収集することを確認 + #[test] + fn test_collect_single_text_block() { + let collector = TextBlockCollector::new(); + let mut timeline = Timeline::new(); + timeline.on_text_block(collector.clone()); + + // テキストブロックのイベントシーケンスをディスパッチ + timeline.dispatch(&Event::text_block_start(0)); + timeline.dispatch(&Event::text_delta(0, "Hello, ")); + timeline.dispatch(&Event::text_delta(0, "World!")); + timeline.dispatch(&Event::text_block_stop(0, None)); + + // 収集されたテキストを確認 + let texts = collector.take_collected(); + assert_eq!(texts.len(), 1); + assert_eq!(texts[0], "Hello, World!"); + } + + /// TextBlockCollectorが複数のテキストブロックを正しく収集することを確認 + #[test] + fn test_collect_multiple_text_blocks() { + let collector = TextBlockCollector::new(); + let mut timeline = Timeline::new(); + timeline.on_text_block(collector.clone()); + + // 1つ目のテキストブロック + timeline.dispatch(&Event::text_block_start(0)); + timeline.dispatch(&Event::text_delta(0, "First")); + timeline.dispatch(&Event::text_block_stop(0, None)); + + // 2つ目のテキストブロック + timeline.dispatch(&Event::text_block_start(1)); + timeline.dispatch(&Event::text_delta(1, "Second")); + timeline.dispatch(&Event::text_block_stop(1, None)); + + let texts = collector.take_collected(); + assert_eq!(texts.len(), 2); + assert_eq!(texts[0], "First"); + assert_eq!(texts[1], "Second"); + } +} diff --git a/worker/src/worker.rs b/worker/src/worker.rs index 9498055..daeea6d 100644 --- a/worker/src/worker.rs +++ b/worker/src/worker.rs @@ -8,10 +8,12 @@ use std::sync::Arc; use futures::StreamExt; use crate::llm_client::{ClientError, LlmClient, Request, ToolDefinition}; +use crate::text_block_collector::TextBlockCollector; use crate::tool_call_collector::ToolCallCollector; use crate::Timeline; use worker_types::{ - ControlFlow, HookError, Message, Tool, ToolCall, ToolError, ToolResult, TurnResult, WorkerHook, + ContentPart, ControlFlow, HookError, Message, MessageContent, Tool, ToolCall, ToolError, + ToolResult, TurnResult, WorkerHook, }; // ============================================================================= @@ -40,16 +42,10 @@ pub enum WorkerError { // ============================================================================= /// Worker設定 -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct WorkerConfig { - /// 最大ターン数(無限ループ防止) - pub max_turns: usize, -} - -impl Default for WorkerConfig { - fn default() -> Self { - Self { max_turns: 10 } - } + // 将来の拡張用(現在は空) + _private: (), } // ============================================================================= @@ -68,38 +64,40 @@ pub struct Worker { client: C, /// イベントタイムライン timeline: Timeline, - /// ツールコレクター(Timeline用ハンドラ) + /// テキストブロックコレクター(Timeline用ハンドラ) + text_block_collector: TextBlockCollector, + /// ツールコールコレクター(Timeline用ハンドラ) tool_call_collector: ToolCallCollector, /// 登録されたツール tools: HashMap>, /// 登録されたHook hooks: Vec>, - /// 設定 - config: WorkerConfig, } impl Worker { /// 新しいWorkerを作成 pub fn new(client: C) -> Self { + let text_block_collector = TextBlockCollector::new(); let tool_call_collector = ToolCallCollector::new(); let mut timeline = Timeline::new(); - // ToolCallCollectorをTimelineに登録 + // コレクターをTimelineに登録 + timeline.on_text_block(text_block_collector.clone()); timeline.on_tool_use_block(tool_call_collector.clone()); Self { client, timeline, + text_block_collector, tool_call_collector, tools: HashMap::new(), hooks: Vec::new(), - config: WorkerConfig::default(), } } - /// 設定を適用 - pub fn config(mut self, config: WorkerConfig) -> Self { - self.config = config; + /// 設定を適用(将来の拡張用) + #[allow(dead_code)] + pub fn config(self, _config: WorkerConfig) -> Self { self } @@ -146,7 +144,7 @@ impl Worker { let mut context = messages; let tool_definitions = self.build_tool_definitions(); - for _turn in 0..self.config.max_turns { + loop { // Hook: on_message_send let control = self.run_on_message_send_hooks(&mut context).await?; if let ControlFlow::Abort(reason) = control { @@ -163,9 +161,16 @@ impl Worker { self.timeline.dispatch(&event); } - // ツール呼び出しの収集結果を取得 + // 収集結果を取得 + let text_blocks = self.text_block_collector.take_collected(); let tool_calls = self.tool_call_collector.take_collected(); + // アシスタントメッセージをコンテキストに追加 + let assistant_message = self.build_assistant_message(&text_blocks, &tool_calls); + if let Some(msg) = assistant_message { + context.push(msg); + } + if tool_calls.is_empty() { // ツール呼び出しなし → ターン終了判定 let turn_result = self.run_on_turn_end_hooks(&context).await?; @@ -188,12 +193,48 @@ impl Worker { context.push(Message::tool_result(&result.tool_use_id, &result.content)); } } + } - // 最大ターン数到達 - Err(WorkerError::Aborted(format!( - "Maximum turns ({}) reached", - self.config.max_turns - ))) + /// テキストブロックとツール呼び出しからアシスタントメッセージを構築 + fn build_assistant_message( + &self, + text_blocks: &[String], + tool_calls: &[ToolCall], + ) -> Option { + // テキストもツール呼び出しもない場合はNone + if text_blocks.is_empty() && tool_calls.is_empty() { + return None; + } + + // テキストのみの場合はシンプルなテキストメッセージ + if tool_calls.is_empty() { + let text = text_blocks.join(""); + return Some(Message::assistant(text)); + } + + // ツール呼び出しがある場合は Parts として構築 + let mut parts = Vec::new(); + + // テキストパーツを追加 + for text in text_blocks { + if !text.is_empty() { + parts.push(ContentPart::Text { text: text.clone() }); + } + } + + // ツール呼び出しパーツを追加 + for call in tool_calls { + parts.push(ContentPart::ToolUse { + id: call.id.clone(), + name: call.name.clone(), + input: call.input.clone(), + }); + } + + Some(Message { + role: worker_types::Role::Assistant, + content: MessageContent::Parts(parts), + }) } /// リクエストを構築 @@ -291,16 +332,18 @@ impl Worker { } /// ツールを並列実行 + /// + /// 全てのツールに対してbefore_tool_callフックを実行後、 + /// 許可されたツールを並列に実行し、結果にafter_tool_callフックを適用する。 async fn execute_tools( &self, - mut tool_calls: Vec, + tool_calls: Vec, ) -> Result, WorkerError> { - let mut results = Vec::new(); + use futures::future::join_all; - // TODO: 将来的には join_all で並列実行 - // 現在は逐次実行 - for mut tool_call in tool_calls.drain(..) { - // Hook: before_tool_call + // Phase 1: before_tool_call フックを適用(スキップ/中断を判定) + let mut approved_calls = Vec::new(); + for mut tool_call in tool_calls { let mut skip = false; for hook in &self.hooks { let result = hook.before_tool_call(&mut tool_call).await?; @@ -315,28 +358,40 @@ impl Worker { } } } - - if skip { - continue; + if !skip { + approved_calls.push(tool_call); } + } - // ツール実行 - let mut tool_result = if let Some(tool) = self.tools.get(&tool_call.name) { - let input_json = serde_json::to_string(&tool_call.input).unwrap_or_default(); - match tool.execute(&input_json).await { - Ok(content) => ToolResult::success(&tool_call.id, content), - Err(e) => ToolResult::error(&tool_call.id, e.to_string()), + // Phase 2: 許可されたツールを並列実行 + let futures: Vec<_> = approved_calls + .into_iter() + .map(|tool_call| { + let tools = &self.tools; + async move { + if let Some(tool) = tools.get(&tool_call.name) { + let input_json = + serde_json::to_string(&tool_call.input).unwrap_or_default(); + match tool.execute(&input_json).await { + Ok(content) => ToolResult::success(&tool_call.id, content), + Err(e) => ToolResult::error(&tool_call.id, e.to_string()), + } + } else { + ToolResult::error( + &tool_call.id, + format!("Tool '{}' not found", tool_call.name), + ) + } } - } else { - ToolResult::error( - &tool_call.id, - format!("Tool '{}' not found", tool_call.name), - ) - }; + }) + .collect(); - // Hook: after_tool_call + let mut results = join_all(futures).await; + + // Phase 3: after_tool_call フックを適用 + for tool_result in &mut results { for hook in &self.hooks { - let result = hook.after_tool_call(&mut tool_result).await?; + let result = hook.after_tool_call(tool_result).await?; match result { ControlFlow::Continue => {} ControlFlow::Skip => break, @@ -345,8 +400,6 @@ impl Worker { } } } - - results.push(tool_result); } Ok(results) diff --git a/worker/tests/common/mod.rs b/worker/tests/common/mod.rs index 1dc18c6..430f4ac 100644 --- a/worker/tests/common/mod.rs +++ b/worker/tests/common/mod.rs @@ -204,26 +204,81 @@ impl EventPlayer { /// /// 事前に定義されたイベントシーケンスをストリームとして返す。 /// fixtureファイルからロードすることも、直接イベントを渡すこともできる。 +/// +/// # 複数リクエスト対応 +/// +/// `with_responses()`を使用して、複数回のリクエストに対して異なるレスポンスを設定できる。 +/// リクエスト回数が設定されたレスポンス数を超えた場合は空のストリームを返す。 pub struct MockLlmClient { - events: Vec, + /// 各リクエストに対するレスポンス(イベントシーケンス) + responses: std::sync::Arc>>>, + /// 現在のリクエストインデックス + request_index: std::sync::Arc, } +#[allow(dead_code)] impl MockLlmClient { - /// イベントリストから直接作成 + /// イベントリストから直接作成(単一レスポンス) + /// + /// すべてのリクエストに対して同じイベントシーケンスを返す(従来の動作) pub fn new(events: Vec) -> Self { - Self { events } + Self { + responses: std::sync::Arc::new(std::sync::Mutex::new(vec![events])), + request_index: std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)), + } } - /// fixtureファイルからロード + /// 複数のレスポンスを設定 + /// + /// 各リクエストに対して順番にイベントシーケンスを返す。 + /// N回目のリクエストにはN番目のレスポンスが使用される。 + /// + /// # Example + /// ```ignore + /// let client = MockLlmClient::with_responses(vec![ + /// // 1回目のリクエスト: ツール呼び出し + /// vec![Event::tool_use_start(0, "call_1", "my_tool"), ...], + /// // 2回目のリクエスト: テキストレスポンス + /// vec![Event::text_block_start(0), ...], + /// ]); + /// ``` + pub fn with_responses(responses: Vec>) -> Self { + Self { + responses: std::sync::Arc::new(std::sync::Mutex::new(responses)), + request_index: std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)), + } + } + + /// fixtureファイルからロード(単一レスポンス) pub fn from_fixture(path: impl AsRef) -> std::io::Result { let player = EventPlayer::load(path)?; let events = player.parse_events(); - Ok(Self { events }) + Ok(Self::new(events)) } - /// 保持しているイベント数を取得 + /// 保持しているレスポンス数を取得 + pub fn response_count(&self) -> usize { + self.responses.lock().unwrap().len() + } + + /// 最初のレスポンスのイベント数を取得(後方互換性) pub fn event_count(&self) -> usize { - self.events.len() + self.responses + .lock() + .unwrap() + .first() + .map(|v| v.len()) + .unwrap_or(0) + } + + /// 現在のリクエストインデックスを取得 + pub fn current_request_index(&self) -> usize { + self.request_index.load(std::sync::atomic::Ordering::SeqCst) + } + + /// リクエストインデックスをリセット + pub fn reset(&self) { + self.request_index.store(0, std::sync::atomic::Ordering::SeqCst); } } @@ -233,8 +288,20 @@ impl LlmClient for MockLlmClient { &self, _request: Request, ) -> Result> + Send>>, ClientError> { - let events = self.events.clone(); + let index = self.request_index.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + + let events = { + let responses = self.responses.lock().unwrap(); + if index < responses.len() { + responses[index].clone() + } else { + // レスポンスが尽きた場合は空のストリーム + Vec::new() + } + }; + let stream = futures::stream::iter(events.into_iter().map(Ok)); Ok(Box::pin(stream)) } } + diff --git a/worker/tests/parallel_execution_test.rs b/worker/tests/parallel_execution_test.rs new file mode 100644 index 0000000..deb0715 --- /dev/null +++ b/worker/tests/parallel_execution_test.rs @@ -0,0 +1,254 @@ +//! 並列ツール実行のテスト +//! +//! Workerが複数のツールを並列に実行することを確認する。 + +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use async_trait::async_trait; +use worker::Worker; +use worker_types::{Event, Message, ResponseStatus, StatusEvent, Tool, ToolError, ToolResult, ToolCall, ControlFlow, HookError, WorkerHook}; + +mod common; +use common::MockLlmClient; + +// ============================================================================= +// Parallel Execution Test Tools +// ============================================================================= + +/// 一定時間待機してから応答するツール +#[derive(Clone)] +struct SlowTool { + name: String, + delay_ms: u64, + call_count: Arc, +} + +impl SlowTool { + fn new(name: impl Into, delay_ms: u64) -> Self { + Self { + name: name.into(), + delay_ms, + call_count: Arc::new(AtomicUsize::new(0)), + } + } + + fn call_count(&self) -> usize { + self.call_count.load(Ordering::SeqCst) + } +} + +#[async_trait] +impl Tool for SlowTool { + fn name(&self) -> &str { + &self.name + } + + fn description(&self) -> &str { + "A tool that waits before responding" + } + + fn input_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": {} + }) + } + + async fn execute(&self, _input_json: &str) -> Result { + self.call_count.fetch_add(1, Ordering::SeqCst); + tokio::time::sleep(Duration::from_millis(self.delay_ms)).await; + Ok(format!("Completed after {}ms", self.delay_ms)) + } +} + +// ============================================================================= +// Tests +// ============================================================================= + +/// 複数のツールが並列に実行されることを確認 +/// +/// 各ツールが100msかかる場合、逐次実行なら300ms以上かかるが、 +/// 並列実行なら100ms程度で完了するはず。 +#[tokio::test] +async fn test_parallel_tool_execution() { + // 3つのツール呼び出しを含むイベントシーケンス + let events = vec![ + Event::tool_use_start(0, "call_1", "slow_tool_1"), + Event::tool_input_delta(0, r#"{}"#), + Event::tool_use_stop(0), + Event::tool_use_start(1, "call_2", "slow_tool_2"), + Event::tool_input_delta(1, r#"{}"#), + Event::tool_use_stop(1), + Event::tool_use_start(2, "call_3", "slow_tool_3"), + Event::tool_input_delta(2, r#"{}"#), + Event::tool_use_stop(2), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ]; + + let client = MockLlmClient::new(events); + let mut worker = Worker::new(client); + + // 各ツールは100ms待機 + let tool1 = SlowTool::new("slow_tool_1", 100); + let tool2 = SlowTool::new("slow_tool_2", 100); + let tool3 = SlowTool::new("slow_tool_3", 100); + + let tool1_clone = tool1.clone(); + let tool2_clone = tool2.clone(); + let tool3_clone = tool3.clone(); + + worker.register_tool(tool1); + worker.register_tool(tool2); + worker.register_tool(tool3); + + + + let messages = vec![Message::user("Run all tools")]; + + let start = Instant::now(); + let _result = worker.run(messages).await; + let elapsed = start.elapsed(); + + // 全ツールが呼び出されたことを確認 + assert_eq!(tool1_clone.call_count(), 1, "Tool 1 should be called once"); + assert_eq!(tool2_clone.call_count(), 1, "Tool 2 should be called once"); + assert_eq!(tool3_clone.call_count(), 1, "Tool 3 should be called once"); + + // 並列実行なら200ms以下で完了するはず(逐次なら300ms以上) + // マージン込みで250msをしきい値とする + assert!( + elapsed < Duration::from_millis(250), + "Parallel execution should complete in ~100ms, but took {:?}", + elapsed + ); + + println!("Parallel execution completed in {:?}", elapsed); +} + +/// Hook: before_tool_call でスキップされたツールは実行されないことを確認 +#[tokio::test] +async fn test_before_tool_call_skip() { + let events = vec![ + Event::tool_use_start(0, "call_1", "allowed_tool"), + Event::tool_input_delta(0, r#"{}"#), + Event::tool_use_stop(0), + Event::tool_use_start(1, "call_2", "blocked_tool"), + Event::tool_input_delta(1, r#"{}"#), + Event::tool_use_stop(1), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ]; + + let client = MockLlmClient::new(events); + let mut worker = Worker::new(client); + + let allowed_tool = SlowTool::new("allowed_tool", 10); + let blocked_tool = SlowTool::new("blocked_tool", 10); + + let allowed_clone = allowed_tool.clone(); + let blocked_clone = blocked_tool.clone(); + + worker.register_tool(allowed_tool); + worker.register_tool(blocked_tool); + + // "blocked_tool" をスキップするHook + struct BlockingHook; + + #[async_trait] + impl WorkerHook for BlockingHook { + async fn before_tool_call(&self, tool_call: &mut ToolCall) -> Result { + if tool_call.name == "blocked_tool" { + Ok(ControlFlow::Skip) + } else { + Ok(ControlFlow::Continue) + } + } + } + + worker.add_hook(BlockingHook); + + let messages = vec![Message::user("Test hook")]; + let _result = worker.run(messages).await; + + // allowed_tool は呼び出されるが、blocked_tool は呼び出されない + assert_eq!(allowed_clone.call_count(), 1, "Allowed tool should be called"); + assert_eq!(blocked_clone.call_count(), 0, "Blocked tool should not be called"); +} + +/// Hook: after_tool_call で結果が改変されることを確認 +#[tokio::test] +async fn test_after_tool_call_modification() { + // 複数リクエストに対応するレスポンスを準備 + let client = MockLlmClient::with_responses(vec![ + // 1回目のリクエスト: ツール呼び出し + vec![ + Event::tool_use_start(0, "call_1", "test_tool"), + Event::tool_input_delta(0, r#"{}"#), + Event::tool_use_stop(0), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + // 2回目のリクエスト: ツール結果を受けてテキストレスポンス + vec![ + Event::text_block_start(0), + Event::text_delta(0, "Done!"), + Event::text_block_stop(0, None), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + ]); + + let mut worker = Worker::new(client); + + #[derive(Clone)] + struct SimpleTool; + + #[async_trait] + impl Tool for SimpleTool { + fn name(&self) -> &str { "test_tool" } + fn description(&self) -> &str { "Test" } + fn input_schema(&self) -> serde_json::Value { serde_json::json!({}) } + async fn execute(&self, _: &str) -> Result { + Ok("Original Result".to_string()) + } + } + + worker.register_tool(SimpleTool); + + // 結果を改変するHook + struct ModifyingHook { + modified_content: Arc>>, + } + + #[async_trait] + impl WorkerHook for ModifyingHook { + async fn after_tool_call(&self, tool_result: &mut ToolResult) -> Result { + tool_result.content = format!("[Modified] {}", tool_result.content); + *self.modified_content.lock().unwrap() = Some(tool_result.content.clone()); + Ok(ControlFlow::Continue) + } + } + + let modified_content = Arc::new(std::sync::Mutex::new(None)); + worker.add_hook(ModifyingHook { modified_content: modified_content.clone() }); + + let messages = vec![Message::user("Test modification")]; + let result = worker.run(messages).await; + + assert!(result.is_ok(), "Worker should complete: {:?}", result); + + // Hookが呼ばれて内容が改変されたことを確認 + let content = modified_content.lock().unwrap().clone(); + assert!(content.is_some(), "Hook should have been called"); + assert!( + content.unwrap().contains("[Modified]"), + "Result should be modified" + ); +} diff --git a/worker/tests/tool_macro_test.rs b/worker/tests/tool_macro_test.rs new file mode 100644 index 0000000..98e1d00 --- /dev/null +++ b/worker/tests/tool_macro_test.rs @@ -0,0 +1,212 @@ +//! ツールマクロのテスト +//! +//! `#[tool_registry]` と `#[tool]` マクロの動作を確認する。 + +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +// マクロ展開に必要なインポート +use schemars; +use serde; + +use worker_macros::tool_registry; +use worker_types::Tool; + +// ============================================================================= +// Test: Basic Tool Generation +// ============================================================================= + +/// シンプルなコンテキスト構造体 +#[derive(Clone)] +struct SimpleContext { + prefix: String, +} + +#[tool_registry] +impl SimpleContext { + /// メッセージに挨拶を追加する + /// + /// 指定されたメッセージにプレフィックスを付けて返します。 + #[tool] + async fn greet(&self, message: String) -> String { + format!("{}: {}", self.prefix, message) + } + + /// 二つの数を足す + #[tool] + async fn add(&self, a: i32, b: i32) -> i32 { + a + b + } + + /// 引数なしのツール + #[tool] + async fn get_prefix(&self) -> String { + self.prefix.clone() + } +} + +#[tokio::test] +async fn test_basic_tool_generation() { + let ctx = SimpleContext { + prefix: "Hello".to_string(), + }; + + // ファクトリメソッドでツールを取得 + let greet_tool = ctx.greet_tool(); + + // 名前の確認 + assert_eq!(greet_tool.name(), "greet"); + + // 説明の確認(docコメントから取得) + let desc = greet_tool.description(); + assert!(desc.contains("メッセージに挨拶を追加する"), "Description should contain doc comment: {}", desc); + + // スキーマの確認 + let schema = greet_tool.input_schema(); + println!("Schema: {}", serde_json::to_string_pretty(&schema).unwrap()); + assert!(schema.get("properties").is_some(), "Schema should have properties"); + + // 実行テスト + let result = greet_tool.execute(r#"{"message": "World"}"#).await; + assert!(result.is_ok(), "Should execute successfully"); + let output = result.unwrap(); + assert!(output.contains("Hello"), "Output should contain prefix"); + assert!(output.contains("World"), "Output should contain message"); +} + +#[tokio::test] +async fn test_multiple_arguments() { + let ctx = SimpleContext { + prefix: "".to_string(), + }; + + let add_tool = ctx.add_tool(); + + assert_eq!(add_tool.name(), "add"); + + let result = add_tool.execute(r#"{"a": 10, "b": 20}"#).await; + assert!(result.is_ok()); + let output = result.unwrap(); + assert!(output.contains("30"), "Should contain sum: {}", output); +} + +#[tokio::test] +async fn test_no_arguments() { + let ctx = SimpleContext { + prefix: "TestPrefix".to_string(), + }; + + let get_prefix_tool = ctx.get_prefix_tool(); + + assert_eq!(get_prefix_tool.name(), "get_prefix"); + + // 空のJSONオブジェクトで呼び出し + let result = get_prefix_tool.execute(r#"{}"#).await; + assert!(result.is_ok()); + let output = result.unwrap(); + assert!(output.contains("TestPrefix"), "Should contain prefix: {}", output); +} + +#[tokio::test] +async fn test_invalid_arguments() { + let ctx = SimpleContext { + prefix: "".to_string(), + }; + + let greet_tool = ctx.greet_tool(); + + // 不正なJSON + let result = greet_tool.execute(r#"{"wrong_field": "value"}"#).await; + assert!(result.is_err(), "Should fail with invalid arguments"); +} + +// ============================================================================= +// Test: Result Return Type +// ============================================================================= + +#[derive(Clone)] +struct FallibleContext; + +#[derive(Debug)] +struct MyError(String); + +impl std::fmt::Display for MyError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +#[tool_registry] +impl FallibleContext { + /// 与えられた値を検証する + #[tool] + async fn validate(&self, value: i32) -> Result { + if value > 0 { + Ok(format!("Valid: {}", value)) + } else { + Err(MyError("Value must be positive".to_string())) + } + } +} + +#[tokio::test] +async fn test_result_return_type_success() { + let ctx = FallibleContext; + let validate_tool = ctx.validate_tool(); + + let result = validate_tool.execute(r#"{"value": 42}"#).await; + assert!(result.is_ok(), "Should succeed for positive value"); + let output = result.unwrap(); + assert!(output.contains("Valid"), "Should contain Valid: {}", output); +} + +#[tokio::test] +async fn test_result_return_type_error() { + let ctx = FallibleContext; + let validate_tool = ctx.validate_tool(); + + let result = validate_tool.execute(r#"{"value": -1}"#).await; + assert!(result.is_err(), "Should fail for negative value"); + + let err = result.unwrap_err(); + assert!(err.to_string().contains("positive"), "Error should mention positive: {}", err); +} + +// ============================================================================= +// Test: Synchronous Methods +// ============================================================================= + +#[derive(Clone)] +struct SyncContext { + counter: Arc, +} + +#[tool_registry] +impl SyncContext { + /// カウンターをインクリメントして返す (非async) + #[tool] + fn increment(&self) -> usize { + self.counter.fetch_add(1, Ordering::SeqCst) + 1 + } +} + +#[tokio::test] +async fn test_sync_method() { + let ctx = SyncContext { + counter: Arc::new(AtomicUsize::new(0)), + }; + + let increment_tool = ctx.increment_tool(); + + // 3回実行 + let result1 = increment_tool.execute(r#"{}"#).await; + let result2 = increment_tool.execute(r#"{}"#).await; + let result3 = increment_tool.execute(r#"{}"#).await; + + assert!(result1.is_ok()); + assert!(result2.is_ok()); + assert!(result3.is_ok()); + + // カウンターは3になっているはず + assert_eq!(ctx.counter.load(Ordering::SeqCst), 3); +} diff --git a/worker/tests/worker_fixtures.rs b/worker/tests/worker_fixtures.rs index 7d79eb8..32a7682 100644 --- a/worker/tests/worker_fixtures.rs +++ b/worker/tests/worker_fixtures.rs @@ -11,7 +11,7 @@ use std::sync::Arc; use async_trait::async_trait; use common::MockLlmClient; -use worker::{Worker, WorkerConfig}; +use worker::Worker; use worker_types::{Tool, ToolError}; /// フィクスチャディレクトリのパス @@ -163,8 +163,7 @@ async fn test_worker_tool_call() { let tool_for_check = weather_tool.clone(); worker.register_tool(weather_tool); - // 設定: ツール実行後はターン終了(ループしない) - worker = worker.config(WorkerConfig { max_turns: 1 }); + // メッセージを送信 let messages = vec![worker_types::Message::user("What's the weather in Tokyo?")];