From 16fda38039c2b14f2a6b08147b140e98e798696b Mon Sep 17 00:00:00 2001 From: Hare Date: Sat, 10 Jan 2026 00:31:14 +0900 Subject: [PATCH] feat: Redesign the tool system --- docs/spec/hooks_design.md | 206 +++++++---- docs/spec/tools_design.md | 191 +++++++++++ llm-worker-macros/src/lib.rs | 31 +- llm-worker/examples/worker_cancel_demo.rs | 12 +- llm-worker/examples/worker_cli.rs | 29 +- llm-worker/src/hook.rs | 108 +++++- llm-worker/src/lib.rs | 5 +- llm-worker/src/tool.rs | 128 +++++-- llm-worker/src/worker.rs | 359 +++++++++++++------- llm-worker/tests/parallel_execution_test.rs | 97 +++--- llm-worker/tests/tool_macro_test.rs | 82 +++-- llm-worker/tests/validation_test.rs | 1 - llm-worker/tests/worker_fixtures.rs | 44 ++- 13 files changed, 897 insertions(+), 396 deletions(-) create mode 100644 docs/spec/tools_design.md diff --git a/docs/spec/hooks_design.md b/docs/spec/hooks_design.md index 08354a6..2d8b64b 100644 --- a/docs/spec/hooks_design.md +++ b/docs/spec/hooks_design.md @@ -14,13 +14,14 @@ HookはWorker層でのターン制御に介入するためのメカニズムで ## Hook一覧 -| Hook | タイミング | 主な用途 | 戻り値 | -| ------------------ | -------------------------- | --------------------- | ---------------------- | -| `on_message_send` | LLM送信前 | コンテキスト改変/検証 | `OnMessageSendResult` | -| `before_tool_call` | ツール実行前 | 実行許可/引数改変 | `BeforeToolCallResult` | -| `after_tool_call` | ツール実行後 | 結果加工/マスキング | `AfterToolCallResult` | -| `on_turn_end` | ツールなしでターン終了直前 | 検証/リトライ指示 | `OnTurnEndResult` | -| `on_abort` | 中断時 | クリーンアップ/通知 | `()` | +| Hook | タイミング | 主な用途 | 戻り値 | +| ------------------ | -------------------------- | -------------------------- | ---------------------- | +| `on_prompt_submit` | `run()` 呼び出し時 | ユーザーメッセージの前処理 | `OnPromptSubmitResult` | +| `pre_llm_request` | 各ターンのLLM送信前 | コンテキスト改変/検証 | `PreLlmRequestResult` | +| `pre_tool_call` | ツール実行前 | 実行許可/引数改変 | `PreToolCallResult` | +| `post_tool_call` | ツール実行後 | 結果加工/マスキング | `PostToolCallResult` | +| `on_turn_end` | ツールなしでターン終了直前 | 検証/リトライ指示 | `OnTurnEndResult` | +| `on_abort` | 中断時 | クリーンアップ/通知 | `()` | ## Hook Trait @@ -43,25 +44,31 @@ pub trait HookEventKind { type Output; } -pub struct OnMessageSend; -pub struct BeforeToolCall; -pub struct AfterToolCall; +pub struct OnPromptSubmit; +pub struct PreLlmRequest; +pub struct PreToolCall; +pub struct PostToolCall; pub struct OnTurnEnd; pub struct OnAbort; -pub enum OnMessageSendResult { +pub enum OnPromptSubmitResult { Continue, Cancel(String), } -pub enum BeforeToolCallResult { +pub enum PreLlmRequestResult { + Continue, + Cancel(String), +} + +pub enum PreToolCallResult { Continue, Skip, Abort(String), Pause, } -pub enum AfterToolCallResult { +pub enum PostToolCallResult { Continue, Abort(String), } @@ -75,7 +82,7 @@ pub enum OnTurnEndResult { ### Tool Call Context -`before_tool_call` / `after_tool_call` は、ツール実行の文脈を含む入力を受け取る。 +`pre_tool_call` / `post_tool_call` は、ツール実行の文脈を含む入力を受け取る。 ```rust pub struct ToolCallContext { @@ -84,7 +91,8 @@ pub struct ToolCallContext { pub tool: Arc, // 状態アクセス用 } -pub struct ToolResultContext { +pub struct PostToolCallContext { + pub call: ToolCall, pub result: ToolResult, pub meta: ToolMeta, pub tool: Arc, @@ -94,40 +102,84 @@ pub struct ToolResultContext { ## 呼び出しタイミング ``` -Worker::run() ループ +Worker::run(user_input) │ -├─▶ on_message_send ──────────────────────────────┐ -│ コンテキストの改変、バリデーション、 │ -│ システムプロンプト注入などが可能 │ -│ │ -├─▶ LLMリクエスト送信 & ストリーム処理 │ -│ │ -├─▶ ツール呼び出しがある場合: │ -│ │ │ -│ ├─▶ before_tool_call (各ツールごと・逐次) │ -│ │ 実行可否の判定、引数の改変 │ -│ │ │ -│ ├─▶ ツール並列実行 (join_all) │ -│ │ │ -│ └─▶ after_tool_call (各結果ごと・逐次) │ -│ 結果の確認、加工、ログ出力 │ -│ │ -├─▶ ツール結果をコンテキストに追加 → ループ先頭へ │ -│ │ -└─▶ ツールなしの場合: │ - │ │ - └─▶ on_turn_end ─────────────────────────────┘ - 最終応答のチェック(Lint/Fmt等) - エラーがあればContinueWithMessagesでリトライ +├─▶ on_prompt_submit ───────────────────────────┐ +│ ユーザーメッセージの前処理・検証 │ +│ (最初の1回のみ) │ +│ │ +└─▶ loop { + │ + ├─▶ pre_llm_request ──────────────────────│ + │ コンテキストの改変、バリデーション、 │ + │ システムプロンプト注入などが可能 │ + │ (毎ターン実行) │ + │ │ + ├─▶ LLMリクエスト送信 & ストリーム処理 │ + │ │ + ├─▶ ツール呼び出しがある場合: │ + │ │ │ + │ ├─▶ pre_tool_call (各ツールごと・逐次) │ + │ │ 実行可否の判定、引数の改変 │ + │ │ │ + │ ├─▶ ツール並列実行 (join_all) │ + │ │ │ + │ └─▶ post_tool_call (各結果ごと・逐次) │ + │ 結果の確認、加工、ログ出力 │ + │ │ + ├─▶ ツール結果をコンテキストに追加 │ + │ → ループ先頭へ │ + │ │ + └─▶ ツールなしの場合: │ + │ │ + └─▶ on_turn_end ───────────────────┘ + 最終応答のチェック(Lint/Fmt等) + エラーがあればContinueWithMessagesでリトライ +} ※ 中断時は on_abort が呼ばれる ``` ## 各Hookの詳細 -### on_message_send +### on_prompt_submit -**呼び出しタイミング**: LLMへリクエスト送信前(ターンループの冒頭) +**呼び出しタイミング**: `run()` +でユーザーメッセージを受け取った直後(最初の1回のみ) + +**用途**: + +- ユーザー入力のバリデーション +- 入力のサニタイズ・フィルタリング +- ログ出力 +- `OnPromptSubmitResult::Cancel` による実行キャンセル + +**入力**: `&mut Message` - ユーザーメッセージ(改変可能) + +**例**: 入力のバリデーション + +```rust +struct InputValidator; + +#[async_trait] +impl Hook for InputValidator { + async fn call( + &self, + message: &mut Message, + ) -> Result { + if let MessageContent::Text(text) = &message.content { + if text.trim().is_empty() { + return Ok(OnPromptSubmitResult::Cancel("Empty input".to_string())); + } + } + Ok(OnPromptSubmitResult::Continue) + } +} +``` + +### pre_llm_request + +**呼び出しタイミング**: 各ターンのLLMリクエスト送信前(ループの毎回) **用途**: @@ -135,7 +187,9 @@ Worker::run() ループ - メッセージのバリデーション - 機密情報のフィルタリング - リクエスト内容のログ出力 -- `OnMessageSendResult::Cancel` による送信キャンセル +- `PreLlmRequestResult::Cancel` による送信キャンセル + +**入力**: `&mut Vec` - コンテキスト全体(改変可能) **例**: メッセージにタイムスタンプを追加 @@ -143,19 +197,19 @@ Worker::run() ループ struct TimestampHook; #[async_trait] -impl Hook for TimestampHook { +impl Hook for TimestampHook { async fn call( &self, context: &mut Vec, - ) -> Result { + ) -> Result { let timestamp = chrono::Local::now().to_rfc3339(); context.insert(0, Message::user(format!("[{}]", timestamp))); - Ok(OnMessageSendResult::Continue) + Ok(PreLlmRequestResult::Continue) } } ``` -### before_tool_call +### pre_tool_call **呼び出しタイミング**: 各ツール実行前(並列実行フェーズの前) @@ -165,9 +219,10 @@ impl Hook for TimestampHook { - 引数のサニタイズ - 確認プロンプトの表示(UIとの連携) - 実行ログの記録 -- `BeforeToolCallResult::Pause` による一時停止 +- `PreToolCallResult::Pause` による一時停止 **入力**: + - `ToolCallContext`(`ToolCall` + `ToolMeta` + `Arc`) **例**: 特定ツールをブロック @@ -178,22 +233,22 @@ struct ToolBlocker { } #[async_trait] -impl Hook for ToolBlocker { +impl Hook for ToolBlocker { async fn call( &self, ctx: &mut ToolCallContext, - ) -> Result { + ) -> Result { if self.blocked_tools.contains(&ctx.call.name) { println!("Blocked tool: {}", ctx.call.name); - Ok(BeforeToolCallResult::Skip) + Ok(PreToolCallResult::Skip) } else { - Ok(BeforeToolCallResult::Continue) + Ok(PreToolCallResult::Continue) } } } ``` -### after_tool_call +### post_tool_call **呼び出しタイミング**: 各ツール実行後(並列実行フェーズの後) @@ -203,8 +258,11 @@ impl Hook for ToolBlocker { - 機密情報のマスキング - 結果のキャッシュ - 実行結果のログ出力 + **入力**: -- `ToolResultContext`(`ToolResult` + `ToolMeta` + `Arc`) + +- `PostToolCallContext`(`ToolCall` + `ToolResult` + `ToolMeta` + + `Arc`) **例**: 結果にプレフィックスを追加 @@ -212,15 +270,15 @@ impl Hook for ToolBlocker { struct ResultFormatter; #[async_trait] -impl Hook for ResultFormatter { +impl Hook for ResultFormatter { async fn call( &self, - ctx: &mut ToolResultContext, - ) -> Result { + ctx: &mut PostToolCallContext, + ) -> Result { if !ctx.result.is_error { ctx.result.content = format!("[OK] {}", ctx.result.content); } - Ok(AfterToolCallResult::Continue) + Ok(PostToolCallResult::Continue) } } ``` @@ -283,9 +341,9 @@ impl Hook for JsonValidator { Hookは**イベントごとに登録順**に実行されます。 ```rust -worker.add_before_tool_call_hook(HookA); // 1番目に実行 -worker.add_before_tool_call_hook(HookB); // 2番目に実行 -worker.add_before_tool_call_hook(HookC); // 3番目に実行 +worker.add_pre_tool_call_hook(HookA); // 1番目に実行 +worker.add_pre_tool_call_hook(HookB); // 2番目に実行 +worker.add_pre_tool_call_hook(HookC); // 3番目に実行 ``` ### 制御フローの伝播 @@ -323,15 +381,15 @@ Hook A: Continue → Hook B: Pause async fn call(&self, ctx: &mut ToolCallContext) -> ... { // 引数を直接書き換え ctx.call.input["sanitized"] = json!(true); - Ok(BeforeToolCallResult::Continue) + Ok(PreToolCallResult::Continue) } ``` ### 3. 並列実行との統合 -- `before_tool_call`: 並列実行**前**に逐次実行(許可判定のため) +- `pre_tool_call`: 並列実行**前**に逐次実行(許可判定のため) - ツール実行: `join_all`で**並列**実行 -- `after_tool_call`: 並列実行**後**に逐次実行(結果加工のため) +- `post_tool_call`: 並列実行**後**に逐次実行(結果加工のため) ### 4. Send + Sync 要件 @@ -344,24 +402,24 @@ struct CountingHook { } #[async_trait] -impl Hook for CountingHook { - async fn call(&self, _: &mut ToolCallContext) -> Result { +impl Hook for CountingHook { + async fn call(&self, _: &mut ToolCallContext) -> Result { self.count.fetch_add(1, Ordering::SeqCst); - Ok(BeforeToolCallResult::Continue) + Ok(PreToolCallResult::Continue) } } ``` ## 典型的なユースケース -| ユースケース | 使用Hook | 処理内容 | -| ------------------ | ------------------------ | -------------------------- | -| ツール許可制御 | `before_tool_call` | 危険なツールをSkip | -| 実行ログ | `before/after_tool_call` | 呼び出しと結果を記録 | -| 出力バリデーション | `on_turn_end` | 形式チェック、リトライ指示 | -| コンテキスト注入 | `on_message_send` | システムメッセージ追加 | -| 結果のサニタイズ | `after_tool_call` | 機密情報のマスキング | -| レート制限 | `before_tool_call` | 呼び出し頻度の制御 | +| ユースケース | 使用Hook | 処理内容 | +| ------------------ | -------------------- | -------------------------- | +| ツール許可制御 | `pre_tool_call` | 危険なツールをSkip | +| 実行ログ | `pre/post_tool_call` | 呼び出しと結果を記録 | +| 出力バリデーション | `on_turn_end` | 形式チェック、リトライ指示 | +| コンテキスト注入 | `on_message_send` | システムメッセージ追加 | +| 結果のサニタイズ | `post_tool_call` | 機密情報のマスキング | +| レート制限 | `pre_tool_call` | 呼び出し頻度の制御 | ## TODO diff --git a/docs/spec/tools_design.md b/docs/spec/tools_design.md new file mode 100644 index 0000000..80af202 --- /dev/null +++ b/docs/spec/tools_design.md @@ -0,0 +1,191 @@ +# Tool 設計 + +## 概要 + +`llm-worker`のツールシステムは、LLMが外部リソースにアクセスしたり計算を実行するための仕組みを提供する。 +メタ情報の不変性とセッションスコープの状態管理を両立させる設計となっている。 + +## 主要な型 + +``` +type ToolDefinition + Fn() -> (ToolMeta, Arc) + +worker.register_tool() で呼び出し + + ▼ + +- struct ToolMeta (name, desc, schema) + 不変・登録時固定 +- trait Tool (executer) + 登録時生成・セッション中再利用 +``` + +### ToolMeta + +ツールのメタ情報を保持する不変構造体。登録時に固定され、Worker内で変更されない。 + +```rust +pub struct ToolMeta { + pub name: String, + pub description: String, + pub input_schema: Value, +} +``` + +**目的:** + +- LLM へのツール定義として送信 +- Hook からの参照(読み取り専用) +- 登録後の不変性を保証 + +### Tool trait + +ツールの実行ロジックのみを定義するトレイト。 + +```rust +#[async_trait] +pub trait Tool: Send + Sync { + async fn execute(&self, input_json: &str) -> Result; +} +``` + +**設計方針:** + +- メタ情報(name, description, schema)は含まない +- 状態を持つことが可能(セッション中のカウンターなど) +- `Send + Sync` で並列実行に対応 + +**インスタンスのライフサイクル:** + +1. `register_tool()` 呼び出し時にファクトリが実行され、インスタンスが生成される +2. LLM がツールを呼び出すと、既存インスタンスの `execute()` が実行される +3. 同じセッション中は同一インスタンスが再利用される + +※ 「最初に呼ばれたとき」の遅延初期化ではなく、**登録時の即時初期化**である。 + +### ToolDefinition + +メタ情報とツールインスタンスを生成するファクトリ。 + +```rust +pub type ToolDefinition = Arc (ToolMeta, Arc) + Send + Sync>; +``` + +**なぜファクトリか:** + +- Worker への登録時に一度だけ呼び出される +- メタ情報とインスタンスを同時に生成し、整合性を保証 +- クロージャでコンテキスト(`self.clone()`)をキャプチャ可能 + +## Worker でのツール管理 + +```rust +// Worker 内部 +tools: HashMap)> + +// 登録 API +pub fn register_tool(&mut self, factory: ToolDefinition) -> Result<(), ToolRegistryError> +``` + +登録時の処理: + +1. ファクトリを呼び出し `(meta, instance)` を取得 +2. 同名ツールが既に登録されていればエラー +3. HashMap に `(meta, instance)` を保存 + +## マクロによる自動生成 + +`#[tool_registry]` マクロは `{method}_definition()` メソッドを生成する。 + +```rust +#[tool_registry] +impl MyApp { + /// 検索を実行する + #[tool] + async fn search(&self, query: String) -> String { + // 実装 + } +} + +// 生成されるコード: +impl MyApp { + pub fn search_definition(&self) -> ToolDefinition { + let ctx = self.clone(); + Arc::new(move || { + let meta = ToolMeta::new("search") + .description("検索を実行する") + .input_schema(/* schemars で生成 */); + let tool = Arc::new(ToolSearch { ctx: ctx.clone() }); + (meta, tool) + }) + } +} +``` + +## Hook との連携 + +Hook は `ToolCallContext` / `AfterToolCallContext` +を通じてメタ情報とインスタンスにアクセスできる。 + +```rust +pub struct ToolCallContext { + pub call: ToolCall, // 呼び出し情報(改変可能) + pub meta: ToolMeta, // メタ情報(読み取り専用) + pub tool: Arc, // インスタンス(状態アクセス用) +} +``` + +**用途:** + +- `meta` で名前やスキーマを確認 +- `tool` でツールの内部状態を読み取り(ダウンキャスト必要) +- `call` の引数を改変してツールに渡す + +## 使用例 + +### 手動実装 + +```rust +struct Counter { count: AtomicUsize } + +impl Tool for Counter { + async fn execute(&self, _: &str) -> Result { + let n = self.count.fetch_add(1, Ordering::SeqCst); + Ok(format!("count: {}", n)) + } +} + +let def: ToolDefinition = Arc::new(|| { + let meta = ToolMeta::new("counter") + .description("カウンターを増加") + .input_schema(json!({"type": "object"})); + (meta, Arc::new(Counter { count: AtomicUsize::new(0) })) +}); + +worker.register_tool(def)?; +``` + +### マクロ使用(推奨) + +```rust +#[tool_registry] +impl App { + #[tool] + async fn greet(&self, name: String) -> String { + format!("Hello, {}!", name) + } +} + +let app = App; +worker.register_tool(app.greet_definition())?; +``` + +## 設計上の決定 + +| 問題 | 決定 | 理由 | +| -------------------- | ------------------------------ | ---------------------------------------------- | +| メタ情報の変更可能性 | ToolMeta を分離・不変化 | 登録後の整合性を保証 | +| 状態管理 | 登録時にインスタンス生成 | セッション中の状態保持、同一インスタンス再利用 | +| Factory vs Instance | Factory + 登録時即時呼び出し | コンテキストキャプチャと登録時検証 | +| Hook からのアクセス | Context に meta と tool を含む | 柔軟な介入を可能に | diff --git a/llm-worker-macros/src/lib.rs b/llm-worker-macros/src/lib.rs index 5c9b73f..b1a2e45 100644 --- a/llm-worker-macros/src/lib.rs +++ b/llm-worker-macros/src/lib.rs @@ -113,7 +113,7 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2:: let pascal_name = to_pascal_case(&method_name.to_string()); let tool_struct_name = format_ident!("Tool{}", pascal_name); let args_struct_name = format_ident!("{}Args", pascal_name); - let factory_name = format_ident!("{}_tool", method_name); + let definition_name = format_ident!("{}_definition", method_name); // ドキュメントコメントから説明を取得 let description = extract_doc_comment(&method.attrs); @@ -247,29 +247,24 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2:: #[async_trait::async_trait] impl ::llm_worker::tool::Tool for #tool_struct_name { - fn name(&self) -> &str { - #tool_name - } - - fn description(&self) -> &str { - #description - } - - fn input_schema(&self) -> serde_json::Value { - let schema = schemars::schema_for!(#args_struct_name); - serde_json::to_value(schema).unwrap_or(serde_json::json!({})) - } - async fn execute(&self, input_json: &str) -> Result { #execute_body } } impl #self_ty { - pub fn #factory_name(&self) -> #tool_struct_name { - #tool_struct_name { - ctx: self.clone() - } + /// ToolDefinition を取得(Worker への登録用) + pub fn #definition_name(&self) -> ::llm_worker::tool::ToolDefinition { + let ctx = self.clone(); + ::std::sync::Arc::new(move || { + let schema = schemars::schema_for!(#args_struct_name); + let meta = ::llm_worker::tool::ToolMeta::new(#tool_name) + .description(#description) + .input_schema(serde_json::to_value(schema).unwrap_or(serde_json::json!({}))); + let tool: ::std::sync::Arc = + ::std::sync::Arc::new(#tool_struct_name { ctx: ctx.clone() }); + (meta, tool) + }) } } } diff --git a/llm-worker/examples/worker_cancel_demo.rs b/llm-worker/examples/worker_cancel_demo.rs index ad9173c..b4b7114 100644 --- a/llm-worker/examples/worker_cancel_demo.rs +++ b/llm-worker/examples/worker_cancel_demo.rs @@ -2,11 +2,11 @@ //! //! ストリーミング受信中に別スレッドからキャンセルする例 +use llm_worker::llm_client::providers::anthropic::AnthropicClient; +use llm_worker::{Worker, WorkerResult}; use std::sync::Arc; use std::time::Duration; use tokio::sync::Mutex; -use llm_worker::{Worker, WorkerResult}; -use llm_worker::llm_client::providers::anthropic::AnthropicClient; #[tokio::main] async fn main() -> Result<(), Box> { @@ -21,8 +21,8 @@ async fn main() -> Result<(), Box> { ) .init(); - let api_key = std::env::var("ANTHROPIC_API_KEY") - .expect("ANTHROPIC_API_KEY environment variable not set"); + let api_key = + std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY environment variable not set"); let client = AnthropicClient::new(&api_key, "claude-sonnet-4-20250514"); let worker = Arc::new(Mutex::new(Worker::new(client))); @@ -41,7 +41,7 @@ async fn main() -> Result<(), Box> { let task = tokio::spawn(async move { let mut w = worker_clone.lock().await; println!("📡 Sending request to LLM..."); - + match w.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await { Ok(WorkerResult::Finished(_)) => { println!("✅ Task completed normally"); @@ -66,6 +66,6 @@ async fn main() -> Result<(), Box> { task.await?; println!("\n✨ Demo complete!"); - + Ok(()) } diff --git a/llm-worker/examples/worker_cli.rs b/llm-worker/examples/worker_cli.rs index a576871..fbc03e4 100644 --- a/llm-worker/examples/worker_cli.rs +++ b/llm-worker/examples/worker_cli.rs @@ -41,7 +41,7 @@ use tracing_subscriber::EnvFilter; use clap::{Parser, ValueEnum}; use llm_worker::{ Worker, - hook::{AfterToolCall, AfterToolCallResult, Hook, HookError, ToolResult}, + hook::{Hook, HookError, PostToolCall, PostToolCallContext, PostToolCallResult}, llm_client::{ LlmClient, providers::{ @@ -282,25 +282,22 @@ impl ToolResultPrinterHook { } #[async_trait] -impl Hook for ToolResultPrinterHook { - async fn call( - &self, - tool_result: &mut ToolResult, - ) -> Result { +impl Hook for ToolResultPrinterHook { + async fn call(&self, ctx: &mut PostToolCallContext) -> Result { let name = self .call_names .lock() .unwrap() - .remove(&tool_result.tool_use_id) - .unwrap_or_else(|| tool_result.tool_use_id.clone()); + .remove(&ctx.result.tool_use_id) + .unwrap_or_else(|| ctx.result.tool_use_id.clone()); - if tool_result.is_error { - println!(" Result ({}): ❌ {}", name, tool_result.content); + if ctx.result.is_error { + println!(" Result ({}): ❌ {}", name, ctx.result.content); } else { - println!(" Result ({}): ✅ {}", name, tool_result.content); + println!(" Result ({}): ✅ {}", name, ctx.result.content); } - Ok(AfterToolCallResult::Continue) + Ok(PostToolCallResult::Continue) } } @@ -441,8 +438,10 @@ async fn main() -> Result<(), Box> { // ツール登録(--no-tools でなければ) if !args.no_tools { let app = AppContext; - worker.register_tool(app.get_current_time_tool()); - worker.register_tool(app.calculate_tool()); + worker + .register_tool(app.get_current_time_definition()) + .unwrap(); + worker.register_tool(app.calculate_definition()).unwrap(); } // ストリーミング表示用ハンドラーを登録 @@ -451,7 +450,7 @@ async fn main() -> Result<(), Box> { .on_text_block(StreamingPrinter::new()) .on_tool_use_block(ToolCallPrinter::new(tool_call_names.clone())); - worker.add_after_tool_call_hook(ToolResultPrinterHook::new(tool_call_names)); + worker.add_post_tool_call_hook(ToolResultPrinterHook::new(tool_call_names)); // ワンショットモード if let Some(prompt) = args.prompt { diff --git a/llm-worker/src/hook.rs b/llm-worker/src/hook.rs index 1007dec..9fed3af 100644 --- a/llm-worker/src/hook.rs +++ b/llm-worker/src/hook.rs @@ -16,20 +16,27 @@ pub trait HookEventKind: Send + Sync + 'static { type Output; } -pub struct OnMessageSend; -pub struct BeforeToolCall; -pub struct AfterToolCall; +pub struct OnPromptSubmit; +pub struct PreLlmRequest; +pub struct PreToolCall; +pub struct PostToolCall; pub struct OnTurnEnd; pub struct OnAbort; #[derive(Debug, Clone, PartialEq, Eq)] -pub enum OnMessageSendResult { +pub enum OnPromptSubmitResult { Continue, Cancel(String), } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum BeforeToolCallResult { +pub enum PreLlmRequestResult { + Continue, + Cancel(String), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PreToolCallResult { Continue, Skip, Abort(String), @@ -37,7 +44,7 @@ pub enum BeforeToolCallResult { } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum AfterToolCallResult { +pub enum PostToolCallResult { Continue, Abort(String), } @@ -49,19 +56,50 @@ pub enum OnTurnEndResult { Paused, } -impl HookEventKind for OnMessageSend { +use std::sync::Arc; + +use crate::tool::{Tool, ToolMeta}; + +/// PreToolCall の入力コンテキスト +pub struct ToolCallContext { + /// ツール呼び出し情報(改変可能) + pub call: ToolCall, + /// ツールメタ情報(不変) + pub meta: ToolMeta, + /// ツールインスタンス(状態アクセス用) + pub tool: Arc, +} + +/// PostToolCall の入力コンテキスト +pub struct PostToolCallContext { + /// ツール呼び出し情報 + pub call: ToolCall, + /// ツール実行結果(改変可能) + pub result: ToolResult, + /// ツールメタ情報(不変) + pub meta: ToolMeta, + /// ツールインスタンス(状態アクセス用) + pub tool: Arc, +} + +impl HookEventKind for OnPromptSubmit { + type Input = crate::Message; + type Output = OnPromptSubmitResult; +} + +impl HookEventKind for PreLlmRequest { type Input = Vec; - type Output = OnMessageSendResult; + type Output = PreLlmRequestResult; } -impl HookEventKind for BeforeToolCall { - type Input = ToolCall; - type Output = BeforeToolCallResult; +impl HookEventKind for PreToolCall { + type Input = ToolCallContext; + type Output = PreToolCallResult; } -impl HookEventKind for AfterToolCall { - type Input = ToolResult; - type Output = AfterToolCallResult; +impl HookEventKind for PostToolCall { + type Input = PostToolCallContext; + type Output = PostToolCallResult; } impl HookEventKind for OnTurnEnd { @@ -151,3 +189,45 @@ pub enum HookError { pub trait Hook: Send + Sync { async fn call(&self, input: &mut E::Input) -> Result; } + +// ============================================================================= +// Hook Registry +// ============================================================================= + +/// 全 Hook を保持するレジストリ +/// +/// Worker 内部で使用され、各種 Hook を一括管理する。 +pub struct HookRegistry { + /// on_prompt_submit Hook + pub(crate) on_prompt_submit: Vec>>, + /// pre_llm_request Hook + pub(crate) pre_llm_request: Vec>>, + /// pre_tool_call Hook + pub(crate) pre_tool_call: Vec>>, + /// post_tool_call Hook + pub(crate) post_tool_call: Vec>>, + /// on_turn_end Hook + pub(crate) on_turn_end: Vec>>, + /// on_abort Hook + pub(crate) on_abort: Vec>>, +} + +impl Default for HookRegistry { + fn default() -> Self { + Self::new() + } +} + +impl HookRegistry { + /// 空の HookRegistry を作成 + pub fn new() -> Self { + Self { + on_prompt_submit: Vec::new(), + pre_llm_request: Vec::new(), + pre_tool_call: Vec::new(), + post_tool_call: Vec::new(), + on_turn_end: Vec::new(), + on_abort: Vec::new(), + } + } +} diff --git a/llm-worker/src/lib.rs b/llm-worker/src/lib.rs index ef743b4..2a311ab 100644 --- a/llm-worker/src/lib.rs +++ b/llm-worker/src/lib.rs @@ -19,8 +19,7 @@ //! .system_prompt("You are a helpful assistant."); //! //! // ツールを登録(オプション) -//! use llm_worker::tool::Tool; -//! worker.register_tool(my_tool); +//! // worker.register_tool(my_tool_definition)?; //! //! // 対話を実行 //! let history = worker.run("Hello!").await?; @@ -49,4 +48,4 @@ pub mod timeline; pub mod tool; pub use message::{ContentPart, Message, MessageContent, Role}; -pub use worker::{Worker, WorkerConfig, WorkerError, WorkerResult}; +pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult}; diff --git a/llm-worker/src/tool.rs b/llm-worker/src/tool.rs index 9585eca..eabe57f 100644 --- a/llm-worker/src/tool.rs +++ b/llm-worker/src/tool.rs @@ -3,6 +3,8 @@ //! LLMから呼び出し可能なツールを定義するためのトレイト。 //! 通常は`#[tool]`マクロを使用して自動実装します。 +use std::sync::Arc; + use async_trait::async_trait; use serde_json::Value; use thiserror::Error; @@ -21,64 +23,126 @@ pub enum ToolError { Internal(String), } +// ============================================================================= +// ToolMeta - 不変のメタ情報 +// ============================================================================= + +/// ツールのメタ情報(登録時に固定、不変) +/// +/// `ToolDefinition` ファクトリから生成され、Worker に登録後は変更されません。 +/// LLM へのツール定義送信に使用されます。 +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ToolMeta { + /// ツール名(LLMが識別に使用) + pub name: String, + /// ツールの説明(LLMへのプロンプトに含まれる) + pub description: String, + /// 引数のJSON Schema + pub input_schema: Value, +} + +impl ToolMeta { + /// 新しい ToolMeta を作成 + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + description: String::new(), + input_schema: Value::Object(Default::default()), + } + } + + /// 説明を設定 + pub fn description(mut self, desc: impl Into) -> Self { + self.description = desc.into(); + self + } + + /// 引数スキーマを設定 + pub fn input_schema(mut self, schema: Value) -> Self { + self.input_schema = schema; + self + } +} + +// ============================================================================= +// ToolDefinition - ファクトリ型 +// ============================================================================= + +/// ツール定義ファクトリ +/// +/// 呼び出すと `(ToolMeta, Arc)` を返します。 +/// Worker への登録時に一度だけ呼び出され、メタ情報とインスタンスが +/// セッションスコープでキャッシュされます。 +/// +/// # Examples +/// +/// ```ignore +/// let def: ToolDefinition = Arc::new(|| { +/// ( +/// ToolMeta::new("my_tool") +/// .description("My tool description") +/// .input_schema(json!({"type": "object"})), +/// Arc::new(MyToolImpl { state: 0 }) as Arc, +/// ) +/// }); +/// worker.register_tool(def)?; +/// ``` +pub type ToolDefinition = Arc (ToolMeta, Arc) + Send + Sync>; + +// ============================================================================= +// Tool trait +// ============================================================================= + /// LLMから呼び出し可能なツールを定義するトレイト /// /// ツールはLLMが外部リソースにアクセスしたり、 /// 計算を実行したりするために使用します。 +/// セッション中の状態を保持できます。 /// /// # 実装方法 /// -/// 通常は`#[tool]`マクロを使用して自動実装します: +/// 通常は`#[tool_registry]`マクロを使用して自動実装します: /// /// ```ignore -/// use llm_worker::tool; -/// -/// #[tool(description = "Search the web for information")] -/// async fn search(query: String) -> String { -/// // 検索処理 -/// format!("Results for: {}", query) +/// #[tool_registry] +/// impl MyApp { +/// #[tool] +/// async fn search(&self, query: String) -> String { +/// format!("Results for: {}", query) +/// } /// } +/// +/// // 登録 +/// worker.register_tool(app.search_definition())?; /// ``` /// /// # 手動実装 /// /// ```ignore -/// use llm_worker::tool::{Tool, ToolError}; -/// use serde_json::{json, Value}; +/// use llm_worker::tool::{Tool, ToolError, ToolMeta, ToolDefinition}; +/// use std::sync::Arc; /// -/// struct MyTool; +/// struct MyTool { counter: std::sync::atomic::AtomicUsize } /// /// #[async_trait::async_trait] /// impl Tool for MyTool { -/// fn name(&self) -> &str { "my_tool" } -/// fn description(&self) -> &str { "My custom tool" } -/// fn input_schema(&self) -> Value { -/// json!({ -/// "type": "object", -/// "properties": { -/// "query": { "type": "string" } -/// }, -/// "required": ["query"] -/// }) -/// } /// async fn execute(&self, input: &str) -> Result { +/// self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst); /// Ok("result".to_string()) /// } /// } +/// +/// let def: ToolDefinition = Arc::new(|| { +/// ( +/// ToolMeta::new("my_tool") +/// .description("My custom tool") +/// .input_schema(serde_json::json!({"type": "object"})), +/// Arc::new(MyTool { counter: Default::default() }) as Arc, +/// ) +/// }); /// ``` #[async_trait] pub trait Tool: Send + Sync { - /// ツール名(LLMが識別に使用) - fn name(&self) -> &str; - - /// ツールの説明(LLMへのプロンプトに含まれる) - fn description(&self) -> &str; - - /// 引数のJSON Schema - /// - /// LLMはこのスキーマに従って引数を生成します。 - fn input_schema(&self) -> Value; - /// ツールを実行する /// /// # Arguments diff --git a/llm-worker/src/worker.rs b/llm-worker/src/worker.rs index 0b413c0..d8c32a0 100644 --- a/llm-worker/src/worker.rs +++ b/llm-worker/src/worker.rs @@ -9,18 +9,21 @@ use tracing::{debug, info, trace, warn}; use crate::{ ContentPart, Message, MessageContent, Role, hook::{ - AfterToolCall, AfterToolCallResult, BeforeToolCall, BeforeToolCallResult, Hook, HookError, - OnAbort, OnMessageSend, OnMessageSendResult, OnTurnEnd, OnTurnEndResult, ToolCall, - ToolResult, + Hook, HookError, HookRegistry, OnAbort, OnPromptSubmit, OnPromptSubmitResult, OnTurnEnd, + OnTurnEndResult, PostToolCall, PostToolCallContext, PostToolCallResult, PreLlmRequest, + PreLlmRequestResult, PreToolCall, PreToolCallResult, ToolCall, ToolCallContext, ToolResult, + }, + llm_client::{ + ClientError, ConfigWarning, LlmClient, Request, RequestConfig, + ToolDefinition as LlmToolDefinition, }, - llm_client::{ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ToolDefinition}, state::{Locked, Mutable, WorkerState}, subscriber::{ ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter, ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, WorkerSubscriber, }, timeline::{TextBlockCollector, Timeline, ToolCallCollector}, - tool::{Tool, ToolError}, + tool::{Tool, ToolDefinition, ToolError, ToolMeta}, }; // ============================================================================= @@ -50,6 +53,14 @@ pub enum WorkerError { ConfigWarnings(Vec), } +/// ツール登録エラー +#[derive(Debug, thiserror::Error)] +pub enum ToolRegistryError { + /// 同名のツールが既に登録されている + #[error("Tool with name '{0}' already registered")] + DuplicateName(String), +} + // ============================================================================= // Worker Config // ============================================================================= @@ -155,18 +166,10 @@ pub struct Worker { text_block_collector: TextBlockCollector, /// ツールコールコレクター(Timeline用ハンドラ) tool_call_collector: ToolCallCollector, - /// 登録されたツール - tools: HashMap>, - /// on_message_send Hook - hooks_on_message_send: Vec>>, - /// before_tool_call Hook - hooks_before_tool_call: Vec>>, - /// after_tool_call Hook - hooks_after_tool_call: Vec>>, - /// on_turn_end Hook - hooks_on_turn_end: Vec>>, - /// on_abort Hook - hooks_on_abort: Vec>>, + /// 登録されたツール (meta, instance) + tools: HashMap)>, + /// Hook レジストリ + hooks: HookRegistry, /// システムプロンプト system_prompt: Option, /// メッセージ履歴(Workerが所有) @@ -248,51 +251,71 @@ impl Worker { /// ツールを登録する /// /// 登録されたツールはLLMからの呼び出しで自動的に実行されます。 - /// 同名のツールを登録した場合、後から登録したものが優先されます。 + /// 同名のツールを登録するとエラーになります。 /// /// # Examples /// /// ```ignore - /// use llm_worker::Worker; - /// use my_tools::SearchTool; + /// use llm_worker::tool::{ToolMeta, ToolDefinition, Tool}; + /// use std::sync::Arc; /// - /// worker.register_tool(SearchTool::new()); + /// let def: ToolDefinition = Arc::new(|| { + /// (ToolMeta::new("search").description("..."), Arc::new(MyTool) as Arc) + /// }); + /// worker.register_tool(def)?; /// ``` - pub fn register_tool(&mut self, tool: impl Tool + 'static) { - let name = tool.name().to_string(); - self.tools.insert(name, Arc::new(tool)); + pub fn register_tool(&mut self, factory: ToolDefinition) -> Result<(), ToolRegistryError> { + let (meta, instance) = factory(); + if self.tools.contains_key(&meta.name) { + return Err(ToolRegistryError::DuplicateName(meta.name.clone())); + } + self.tools.insert(meta.name.clone(), (meta, instance)); + Ok(()) } /// 複数のツールを登録 - pub fn register_tools(&mut self, tools: impl IntoIterator) { - for tool in tools { - self.register_tool(tool); + pub fn register_tools( + &mut self, + factories: impl IntoIterator, + ) -> Result<(), ToolRegistryError> { + for factory in factories { + self.register_tool(factory)?; } + Ok(()) } - /// on_message_send Hookを追加する - pub fn add_on_message_send_hook(&mut self, hook: impl Hook + 'static) { - self.hooks_on_message_send.push(Box::new(hook)); + /// on_prompt_submit Hookを追加する + /// + /// `run()` でユーザーメッセージを受け取った直後に呼び出される。 + pub fn add_on_prompt_submit_hook(&mut self, hook: impl Hook + 'static) { + self.hooks.on_prompt_submit.push(Box::new(hook)); } - /// before_tool_call Hookを追加する - pub fn add_before_tool_call_hook(&mut self, hook: impl Hook + 'static) { - self.hooks_before_tool_call.push(Box::new(hook)); + /// pre_llm_request Hookを追加する + /// + /// 各ターンのLLMリクエスト送信前に呼び出される。 + pub fn add_pre_llm_request_hook(&mut self, hook: impl Hook + 'static) { + self.hooks.pre_llm_request.push(Box::new(hook)); } - /// after_tool_call Hookを追加する - pub fn add_after_tool_call_hook(&mut self, hook: impl Hook + 'static) { - self.hooks_after_tool_call.push(Box::new(hook)); + /// pre_tool_call Hookを追加する + pub fn add_pre_tool_call_hook(&mut self, hook: impl Hook + 'static) { + self.hooks.pre_tool_call.push(Box::new(hook)); + } + + /// post_tool_call Hookを追加する + pub fn add_post_tool_call_hook(&mut self, hook: impl Hook + 'static) { + self.hooks.post_tool_call.push(Box::new(hook)); } /// on_turn_end Hookを追加する pub fn add_on_turn_end_hook(&mut self, hook: impl Hook + 'static) { - self.hooks_on_turn_end.push(Box::new(hook)); + self.hooks.on_turn_end.push(Box::new(hook)); } /// on_abort Hookを追加する pub fn add_on_abort_hook(&mut self, hook: impl Hook + 'static) { - self.hooks_on_abort.push(Box::new(hook)); + self.hooks.on_abort.push(Box::new(hook)); } /// タイムラインへの可変参照を取得(追加ハンドラ登録用) @@ -427,14 +450,14 @@ impl Worker { &self.cancellation_token } - /// 登録されたツールからToolDefinitionのリストを生成 - fn build_tool_definitions(&self) -> Vec { + /// 登録されたツールからLLM用ToolDefinitionのリストを生成 + fn build_tool_definitions(&self) -> Vec { self.tools .values() - .map(|tool| { - ToolDefinition::new(tool.name()) - .description(tool.description()) - .input_schema(tool.input_schema()) + .map(|(meta, _)| { + LlmToolDefinition::new(&meta.name) + .description(&meta.description) + .input_schema(meta.input_schema.clone()) }) .collect() } @@ -482,7 +505,11 @@ impl Worker { } /// リクエストを構築 - fn build_request(&self, tool_definitions: &[ToolDefinition], context: &[Message]) -> Request { + fn build_request( + &self, + tool_definitions: &[LlmToolDefinition], + context: &[Message], + ) -> Request { let mut request = Request::new(); // システムプロンプトを設定 @@ -546,27 +573,48 @@ impl Worker { request } - /// Hooks: on_message_send - async fn run_on_message_send_hooks( + /// Hooks: on_prompt_submit + /// + /// `run()` でユーザーメッセージを受け取った直後に呼び出される(最初だけ)。 + async fn run_on_prompt_submit_hooks( &self, - ) -> Result<(OnMessageSendResult, Vec), WorkerError> { - let mut temp_context = self.history.clone(); - for hook in &self.hooks_on_message_send { - let result = hook.call(&mut temp_context).await?; + message: &mut Message, + ) -> Result { + for hook in &self.hooks.on_prompt_submit { + let result = hook.call(message).await?; match result { - OnMessageSendResult::Continue => continue, - OnMessageSendResult::Cancel(reason) => { - return Ok((OnMessageSendResult::Cancel(reason), temp_context)); + OnPromptSubmitResult::Continue => continue, + OnPromptSubmitResult::Cancel(reason) => { + return Ok(OnPromptSubmitResult::Cancel(reason)); } } } - Ok((OnMessageSendResult::Continue, temp_context)) + Ok(OnPromptSubmitResult::Continue) + } + + /// Hooks: pre_llm_request + /// + /// 各ターンのLLMリクエスト送信前に呼び出される(毎ターン)。 + async fn run_pre_llm_request_hooks( + &self, + ) -> Result<(PreLlmRequestResult, Vec), WorkerError> { + let mut temp_context = self.history.clone(); + for hook in &self.hooks.pre_llm_request { + let result = hook.call(&mut temp_context).await?; + match result { + PreLlmRequestResult::Continue => continue, + PreLlmRequestResult::Cancel(reason) => { + return Ok((PreLlmRequestResult::Cancel(reason), temp_context)); + } + } + } + Ok((PreLlmRequestResult::Continue, temp_context)) } /// Hooks: on_turn_end async fn run_on_turn_end_hooks(&self) -> Result { let mut temp_messages = self.history.clone(); - for hook in &self.hooks_on_turn_end { + for hook in &self.hooks.on_turn_end { let result = hook.call(&mut temp_messages).await?; match result { OnTurnEndResult::Finish => continue, @@ -582,7 +630,7 @@ impl Worker { /// Hooks: on_abort async fn run_on_abort_hooks(&self, reason: &str) -> Result<(), WorkerError> { let mut reason = reason.to_string(); - for hook in &self.hooks_on_abort { + for hook in &self.hooks.on_abort { hook.call(&mut reason).await?; } Ok(()) @@ -608,44 +656,67 @@ impl Worker { } } - if calls.is_empty() { - None - } else { - Some(calls) - } + if calls.is_empty() { None } else { Some(calls) } } /// ツールを並列実行 /// - /// 全てのツールに対してbefore_tool_callフックを実行後、 - /// 許可されたツールを並列に実行し、結果にafter_tool_callフックを適用する。 + /// 全てのツールに対してpre_tool_callフックを実行後、 + /// 許可されたツールを並列に実行し、結果にpost_tool_callフックを適用する。 async fn execute_tools( &mut self, tool_calls: Vec, ) -> Result { use futures::future::join_all; - // Phase 1: before_tool_call フックを適用(スキップ/中断を判定) + // ツール呼び出しIDから (ToolCall, Meta, Tool) へのマップ + // PostToolCallフックで必要になるため保持する + let mut call_info_map = HashMap::new(); + + // Phase 1: pre_tool_call フックを適用(スキップ/中断を判定) let mut approved_calls = Vec::new(); for mut tool_call in tool_calls { - let mut skip = false; - for hook in &self.hooks_before_tool_call { - let result = hook.call(&mut tool_call).await?; - match result { - BeforeToolCallResult::Continue => {} - BeforeToolCallResult::Skip => { - skip = true; - break; - } - BeforeToolCallResult::Abort(reason) => { - return Err(WorkerError::Aborted(reason)); - } - BeforeToolCallResult::Pause => { - return Ok(ToolExecutionResult::Paused); + // ツール定義を取得 + if let Some((meta, tool)) = self.tools.get(&tool_call.name) { + // コンテキストを作成 + let mut context = ToolCallContext { + call: tool_call.clone(), + meta: meta.clone(), + tool: tool.clone(), + }; + + let mut skip = false; + for hook in &self.hooks.pre_tool_call { + let result = hook.call(&mut context).await?; + match result { + PreToolCallResult::Continue => {} + PreToolCallResult::Skip => { + skip = true; + break; + } + PreToolCallResult::Abort(reason) => { + return Err(WorkerError::Aborted(reason)); + } + PreToolCallResult::Pause => { + return Ok(ToolExecutionResult::Paused); + } } } - } - if !skip { + + // フックで変更された内容を反映 + tool_call = context.call; + + // マップに保存(実行する場合のみ) + if !skip { + call_info_map.insert( + tool_call.id.clone(), + (tool_call.clone(), meta.clone(), tool.clone()), + ); + approved_calls.push(tool_call); + } + } else { + // 未知のツールはそのまま承認リストに入れる(実行時にエラーになる) + // Hookは適用しない(Metaがないため) approved_calls.push(tool_call); } } @@ -656,7 +727,7 @@ impl Worker { .map(|tool_call| { let tools = &self.tools; async move { - if let Some(tool) = tools.get(&tool_call.name) { + if let Some((_, tool)) = tools.get(&tool_call.name) { let input_json = serde_json::to_string(&tool_call.input).unwrap_or_default(); match tool.execute(&input_json).await { @@ -684,16 +755,28 @@ impl Worker { } }; - // Phase 3: after_tool_call フックを適用 + // Phase 3: post_tool_call フックを適用 for tool_result in &mut results { - for hook in &self.hooks_after_tool_call { - let result = hook.call(tool_result).await?; - match result { - AfterToolCallResult::Continue => {} - AfterToolCallResult::Abort(reason) => { - return Err(WorkerError::Aborted(reason)); + // 保存しておいた情報を取得 + if let Some((tool_call, meta, tool)) = call_info_map.get(&tool_result.tool_use_id) { + let mut context = PostToolCallContext { + call: tool_call.clone(), + result: tool_result.clone(), + meta: meta.clone(), + tool: tool.clone(), + }; + + for hook in &self.hooks.post_tool_call { + let result = hook.call(&mut context).await?; + match result { + PostToolCallResult::Continue => {} + PostToolCallResult::Abort(reason) => { + return Err(WorkerError::Aborted(reason)); + } } } + // フックで変更された結果を反映 + *tool_result = context.result; } } @@ -712,16 +795,17 @@ impl Worker { // Resume check: Pending tool calls if let Some(tool_calls) = self.get_pending_tool_calls() { - info!("Resuming pending tool calls"); - match self.execute_tools(tool_calls).await? { - ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)), - ToolExecutionResult::Completed(results) => { - for result in results { - self.history.push(Message::tool_result(&result.tool_use_id, &result.content)); - } - // Continue to loop - } - } + info!("Resuming pending tool calls"); + match self.execute_tools(tool_calls).await? { + ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)), + ToolExecutionResult::Completed(results) => { + for result in results { + self.history + .push(Message::tool_result(&result.tool_use_id, &result.content)); + } + // Continue to loop + } + } } loop { @@ -740,10 +824,10 @@ impl Worker { notifier.on_turn_start(current_turn); } - // Hook: on_message_send - let (control, request_context) = self.run_on_message_send_hooks().await?; + // Hook: pre_llm_request + let (control, request_context) = self.run_pre_llm_request_hooks().await?; match control { - OnMessageSendResult::Cancel(reason) => { + PreLlmRequestResult::Cancel(reason) => { info!(reason = %reason, "Aborted by hook"); for notifier in &self.turn_notifiers { notifier.on_turn_end(current_turn); @@ -751,7 +835,7 @@ impl Worker { self.run_on_abort_hooks(&reason).await?; return Err(WorkerError::Aborted(reason)); } - OnMessageSendResult::Continue => {} + PreLlmRequestResult::Continue => {} } // リクエスト構築 @@ -766,7 +850,7 @@ impl Worker { // ストリーム処理 debug!("Starting stream..."); let mut event_count = 0; - + // ストリームを取得(キャンセル可能) let mut stream = tokio::select! { stream_result = self.client.stream(request) => stream_result?, @@ -777,7 +861,7 @@ impl Worker { return Err(WorkerError::Cancelled); } }; - + loop { tokio::select! { // ストリームからイベントを受信 @@ -846,12 +930,13 @@ impl Worker { // ツール実行 match self.execute_tools(tool_calls).await? { - ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)), - ToolExecutionResult::Completed(results) => { - for result in results { - self.history.push(Message::tool_result(&result.tool_use_id, &result.content)); - } - } + ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)), + ToolExecutionResult::Completed(results) => { + for result in results { + self.history + .push(Message::tool_result(&result.tool_use_id, &result.content)); + } + } } } } @@ -885,11 +970,7 @@ impl Worker { text_block_collector, tool_call_collector, tools: HashMap::new(), - hooks_on_message_send: Vec::new(), - hooks_before_tool_call: Vec::new(), - hooks_after_tool_call: Vec::new(), - hooks_on_turn_end: Vec::new(), - hooks_on_abort: Vec::new(), + hooks: HookRegistry::new(), system_prompt: None, history: Vec::new(), locked_prefix_len: 0, @@ -1058,11 +1139,7 @@ impl Worker { text_block_collector: self.text_block_collector, tool_call_collector: self.tool_call_collector, tools: self.tools, - hooks_on_message_send: self.hooks_on_message_send, - hooks_before_tool_call: self.hooks_before_tool_call, - hooks_after_tool_call: self.hooks_after_tool_call, - hooks_on_turn_end: self.hooks_on_turn_end, - hooks_on_abort: self.hooks_on_abort, + hooks: self.hooks, system_prompt: self.system_prompt, history: self.history, locked_prefix_len, @@ -1081,8 +1158,21 @@ impl Worker { /// /// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は /// `lock()` を呼んでからLocked状態で `run` を使用すること。 - pub async fn run(&mut self, user_input: impl Into) -> Result, WorkerError> { - self.history.push(Message::user(user_input)); + pub async fn run( + &mut self, + user_input: impl Into, + ) -> Result, WorkerError> { + // Hook: on_prompt_submit + let mut user_message = Message::user(user_input); + let result = self.run_on_prompt_submit_hooks(&mut user_message).await?; + match result { + OnPromptSubmitResult::Cancel(reason) => { + self.run_on_abort_hooks(&reason).await?; + return Err(WorkerError::Aborted(reason)); + } + OnPromptSubmitResult::Continue => {} + } + self.history.push(user_message); self.run_turn_loop().await } @@ -1107,8 +1197,21 @@ impl Worker { /// /// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。 /// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。 - pub async fn run(&mut self, user_input: impl Into) -> Result, WorkerError> { - self.history.push(Message::user(user_input)); + pub async fn run( + &mut self, + user_input: impl Into, + ) -> Result, WorkerError> { + // Hook: on_prompt_submit + let mut user_message = Message::user(user_input); + let result = self.run_on_prompt_submit_hooks(&mut user_message).await?; + match result { + OnPromptSubmitResult::Cancel(reason) => { + self.run_on_abort_hooks(&reason).await?; + return Err(WorkerError::Aborted(reason)); + } + OnPromptSubmitResult::Continue => {} + } + self.history.push(user_message); self.run_turn_loop().await } @@ -1137,11 +1240,7 @@ impl Worker { text_block_collector: self.text_block_collector, tool_call_collector: self.tool_call_collector, tools: self.tools, - hooks_on_message_send: self.hooks_on_message_send, - hooks_before_tool_call: self.hooks_before_tool_call, - hooks_after_tool_call: self.hooks_after_tool_call, - hooks_on_turn_end: self.hooks_on_turn_end, - hooks_on_abort: self.hooks_on_abort, + hooks: self.hooks, system_prompt: self.system_prompt, history: self.history, locked_prefix_len: 0, diff --git a/llm-worker/tests/parallel_execution_test.rs b/llm-worker/tests/parallel_execution_test.rs index 485df23..f3c2769 100644 --- a/llm-worker/tests/parallel_execution_test.rs +++ b/llm-worker/tests/parallel_execution_test.rs @@ -9,11 +9,11 @@ use std::time::{Duration, Instant}; use async_trait::async_trait; use llm_worker::Worker; use llm_worker::hook::{ - AfterToolCall, AfterToolCallResult, BeforeToolCall, BeforeToolCallResult, Hook, HookError, - ToolCall, ToolResult, + Hook, HookError, PostToolCall, PostToolCallContext, PostToolCallResult, PreToolCall, + PreToolCallResult, ToolCallContext, }; use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent}; -use llm_worker::tool::{Tool, ToolError}; +use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta}; mod common; use common::MockLlmClient; @@ -42,25 +42,24 @@ impl SlowTool { fn call_count(&self) -> usize { self.call_count.load(Ordering::SeqCst) } + + /// ToolDefinition を作成 + fn definition(&self) -> ToolDefinition { + let tool = self.clone(); + Arc::new(move || { + let meta = ToolMeta::new(&tool.name) + .description("A tool that waits before responding") + .input_schema(serde_json::json!({ + "type": "object", + "properties": {} + })); + (meta, Arc::new(tool.clone()) as Arc) + }) + } } #[async_trait] impl Tool for SlowTool { - fn name(&self) -> &str { - &self.name - } - - fn description(&self) -> &str { - "A tool that waits before responding" - } - - fn input_schema(&self) -> serde_json::Value { - serde_json::json!({ - "type": "object", - "properties": {} - }) - } - async fn execute(&self, _input_json: &str) -> Result { self.call_count.fetch_add(1, Ordering::SeqCst); tokio::time::sleep(Duration::from_millis(self.delay_ms)).await; @@ -106,9 +105,9 @@ async fn test_parallel_tool_execution() { let tool2_clone = tool2.clone(); let tool3_clone = tool3.clone(); - worker.register_tool(tool1); - worker.register_tool(tool2); - worker.register_tool(tool3); + worker.register_tool(tool1.definition()).unwrap(); + worker.register_tool(tool2.definition()).unwrap(); + worker.register_tool(tool3.definition()).unwrap(); let start = Instant::now(); let _result = worker.run("Run all tools").await; @@ -130,7 +129,7 @@ async fn test_parallel_tool_execution() { println!("Parallel execution completed in {:?}", elapsed); } -/// Hook: before_tool_call でスキップされたツールは実行されないことを確認 +/// Hook: pre_tool_call でスキップされたツールは実行されないことを確認 #[tokio::test] async fn test_before_tool_call_skip() { let events = vec![ @@ -154,24 +153,24 @@ async fn test_before_tool_call_skip() { let allowed_clone = allowed_tool.clone(); let blocked_clone = blocked_tool.clone(); - worker.register_tool(allowed_tool); - worker.register_tool(blocked_tool); + worker.register_tool(allowed_tool.definition()).unwrap(); + worker.register_tool(blocked_tool.definition()).unwrap(); // "blocked_tool" をスキップするHook struct BlockingHook; #[async_trait] - impl Hook for BlockingHook { - async fn call(&self, tool_call: &mut ToolCall) -> Result { - if tool_call.name == "blocked_tool" { - Ok(BeforeToolCallResult::Skip) + impl Hook for BlockingHook { + async fn call(&self, ctx: &mut ToolCallContext) -> Result { + if ctx.call.name == "blocked_tool" { + Ok(PreToolCallResult::Skip) } else { - Ok(BeforeToolCallResult::Continue) + Ok(PreToolCallResult::Continue) } } } - worker.add_before_tool_call_hook(BlockingHook); + worker.add_pre_tool_call_hook(BlockingHook); let _result = worker.run("Test hook").await; @@ -188,9 +187,9 @@ async fn test_before_tool_call_skip() { ); } -/// Hook: after_tool_call で結果が改変されることを確認 +/// Hook: post_tool_call で結果が改変されることを確認 #[tokio::test] -async fn test_after_tool_call_modification() { +async fn test_post_tool_call_modification() { // 複数リクエストに対応するレスポンスを準備 let client = MockLlmClient::with_responses(vec![ // 1回目のリクエスト: ツール呼び出し @@ -220,21 +219,21 @@ async fn test_after_tool_call_modification() { #[async_trait] impl Tool for SimpleTool { - fn name(&self) -> &str { - "test_tool" - } - fn description(&self) -> &str { - "Test" - } - fn input_schema(&self) -> serde_json::Value { - serde_json::json!({}) - } async fn execute(&self, _: &str) -> Result { Ok("Original Result".to_string()) } } - worker.register_tool(SimpleTool); + fn simple_tool_definition() -> ToolDefinition { + Arc::new(|| { + let meta = ToolMeta::new("test_tool") + .description("Test") + .input_schema(serde_json::json!({})); + (meta, Arc::new(SimpleTool) as Arc) + }) + } + + worker.register_tool(simple_tool_definition()).unwrap(); // 結果を改変するHook struct ModifyingHook { @@ -242,19 +241,19 @@ async fn test_after_tool_call_modification() { } #[async_trait] - impl Hook for ModifyingHook { + impl Hook for ModifyingHook { async fn call( &self, - tool_result: &mut ToolResult, - ) -> Result { - tool_result.content = format!("[Modified] {}", tool_result.content); - *self.modified_content.lock().unwrap() = Some(tool_result.content.clone()); - Ok(AfterToolCallResult::Continue) + ctx: &mut PostToolCallContext, + ) -> Result { + ctx.result.content = format!("[Modified] {}", ctx.result.content); + *self.modified_content.lock().unwrap() = Some(ctx.result.content.clone()); + Ok(PostToolCallResult::Continue) } } let modified_content = Arc::new(std::sync::Mutex::new(None)); - worker.add_after_tool_call_hook(ModifyingHook { + worker.add_post_tool_call_hook(ModifyingHook { modified_content: modified_content.clone(), }); diff --git a/llm-worker/tests/tool_macro_test.rs b/llm-worker/tests/tool_macro_test.rs index 04f8f64..7b04cfe 100644 --- a/llm-worker/tests/tool_macro_test.rs +++ b/llm-worker/tests/tool_macro_test.rs @@ -9,7 +9,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use schemars; use serde; -use llm_worker::tool::Tool; +use llm_worker::tool::{Tool, ToolMeta}; use llm_worker_macros::tool_registry; // ============================================================================= @@ -51,30 +51,31 @@ async fn test_basic_tool_generation() { prefix: "Hello".to_string(), }; - // ファクトリメソッドでツールを取得 - let greet_tool = ctx.greet_tool(); + // ファクトリメソッドでToolDefinitionを取得 + let greet_definition = ctx.greet_definition(); - // 名前の確認 - assert_eq!(greet_tool.name(), "greet"); + // ファクトリを呼び出してMetaとToolを取得 + let (meta, tool) = greet_definition(); - // 説明の確認(docコメントから取得) - let desc = greet_tool.description(); + // メタ情報の確認 + assert_eq!(meta.name, "greet"); assert!( - desc.contains("メッセージに挨拶を追加する"), + meta.description.contains("メッセージに挨拶を追加する"), "Description should contain doc comment: {}", - desc + meta.description ); - - // スキーマの確認 - let schema = greet_tool.input_schema(); - println!("Schema: {}", serde_json::to_string_pretty(&schema).unwrap()); assert!( - schema.get("properties").is_some(), + meta.input_schema.get("properties").is_some(), "Schema should have properties" ); + println!( + "Schema: {}", + serde_json::to_string_pretty(&meta.input_schema).unwrap() + ); + // 実行テスト - let result = greet_tool.execute(r#"{"message": "World"}"#).await; + let result = tool.execute(r#"{"message": "World"}"#).await; assert!(result.is_ok(), "Should execute successfully"); let output = result.unwrap(); assert!(output.contains("Hello"), "Output should contain prefix"); @@ -87,11 +88,11 @@ async fn test_multiple_arguments() { prefix: "".to_string(), }; - let add_tool = ctx.add_tool(); + let (meta, tool) = ctx.add_definition()(); - assert_eq!(add_tool.name(), "add"); + assert_eq!(meta.name, "add"); - let result = add_tool.execute(r#"{"a": 10, "b": 20}"#).await; + let result = tool.execute(r#"{"a": 10, "b": 20}"#).await; assert!(result.is_ok()); let output = result.unwrap(); assert!(output.contains("30"), "Should contain sum: {}", output); @@ -103,12 +104,12 @@ async fn test_no_arguments() { prefix: "TestPrefix".to_string(), }; - let get_prefix_tool = ctx.get_prefix_tool(); + let (meta, tool) = ctx.get_prefix_definition()(); - assert_eq!(get_prefix_tool.name(), "get_prefix"); + assert_eq!(meta.name, "get_prefix"); // 空のJSONオブジェクトで呼び出し - let result = get_prefix_tool.execute(r#"{}"#).await; + let result = tool.execute(r#"{}"#).await; assert!(result.is_ok()); let output = result.unwrap(); assert!( @@ -124,10 +125,10 @@ async fn test_invalid_arguments() { prefix: "".to_string(), }; - let greet_tool = ctx.greet_tool(); + let (_, tool) = ctx.greet_definition()(); // 不正なJSON - let result = greet_tool.execute(r#"{"wrong_field": "value"}"#).await; + let result = tool.execute(r#"{"wrong_field": "value"}"#).await; assert!(result.is_err(), "Should fail with invalid arguments"); } @@ -163,9 +164,9 @@ impl FallibleContext { #[tokio::test] async fn test_result_return_type_success() { let ctx = FallibleContext; - let validate_tool = ctx.validate_tool(); + let (_, tool) = ctx.validate_definition()(); - let result = validate_tool.execute(r#"{"value": 42}"#).await; + let result = tool.execute(r#"{"value": 42}"#).await; assert!(result.is_ok(), "Should succeed for positive value"); let output = result.unwrap(); assert!(output.contains("Valid"), "Should contain Valid: {}", output); @@ -174,9 +175,9 @@ async fn test_result_return_type_success() { #[tokio::test] async fn test_result_return_type_error() { let ctx = FallibleContext; - let validate_tool = ctx.validate_tool(); + let (_, tool) = ctx.validate_definition()(); - let result = validate_tool.execute(r#"{"value": -1}"#).await; + let result = tool.execute(r#"{"value": -1}"#).await; assert!(result.is_err(), "Should fail for negative value"); let err = result.unwrap_err(); @@ -211,12 +212,12 @@ async fn test_sync_method() { counter: Arc::new(AtomicUsize::new(0)), }; - let increment_tool = ctx.increment_tool(); + let (_, tool) = ctx.increment_definition()(); // 3回実行 - let result1 = increment_tool.execute(r#"{}"#).await; - let result2 = increment_tool.execute(r#"{}"#).await; - let result3 = increment_tool.execute(r#"{}"#).await; + let result1 = tool.execute(r#"{}"#).await; + let result2 = tool.execute(r#"{}"#).await; + let result3 = tool.execute(r#"{}"#).await; assert!(result1.is_ok()); assert!(result2.is_ok()); @@ -225,3 +226,22 @@ async fn test_sync_method() { // カウンターは3になっているはず assert_eq!(ctx.counter.load(Ordering::SeqCst), 3); } + +// ============================================================================= +// Test: ToolMeta Immutability +// ============================================================================= + +#[tokio::test] +async fn test_tool_meta_immutability() { + let ctx = SimpleContext { + prefix: "Test".to_string(), + }; + + // 2回取得しても同じメタ情報が得られることを確認 + let (meta1, _) = ctx.greet_definition()(); + let (meta2, _) = ctx.greet_definition()(); + + assert_eq!(meta1.name, meta2.name); + assert_eq!(meta1.description, meta2.description); + assert_eq!(meta1.input_schema, meta2.input_schema); +} diff --git a/llm-worker/tests/validation_test.rs b/llm-worker/tests/validation_test.rs index 9ba3017..71e08e7 100644 --- a/llm-worker/tests/validation_test.rs +++ b/llm-worker/tests/validation_test.rs @@ -1,4 +1,3 @@ -use llm_worker::llm_client::LlmClient; use llm_worker::llm_client::providers::openai::OpenAIClient; use llm_worker::{Worker, WorkerError}; diff --git a/llm-worker/tests/worker_fixtures.rs b/llm-worker/tests/worker_fixtures.rs index edb26b5..ce777fa 100644 --- a/llm-worker/tests/worker_fixtures.rs +++ b/llm-worker/tests/worker_fixtures.rs @@ -12,7 +12,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use async_trait::async_trait; use common::MockLlmClient; use llm_worker::Worker; -use llm_worker::tool::{Tool, ToolError}; +use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta}; /// フィクスチャディレクトリのパス fn fixtures_dir() -> std::path::PathBuf { @@ -35,31 +35,29 @@ impl MockWeatherTool { fn get_call_count(&self) -> usize { self.call_count.load(Ordering::SeqCst) } + + fn definition(&self) -> ToolDefinition { + let tool = self.clone(); + Arc::new(move || { + let meta = ToolMeta::new("get_weather") + .description("Get the current weather for a city") + .input_schema(serde_json::json!({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name" + } + }, + "required": ["city"] + })); + (meta, Arc::new(tool.clone()) as Arc) + }) + } } #[async_trait] impl Tool for MockWeatherTool { - fn name(&self) -> &str { - "get_weather" - } - - fn description(&self) -> &str { - "Get the current weather for a city" - } - - fn input_schema(&self) -> serde_json::Value { - serde_json::json!({ - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "The city name" - } - }, - "required": ["city"] - }) - } - async fn execute(&self, input_json: &str) -> Result { self.call_count.fetch_add(1, Ordering::SeqCst); @@ -158,7 +156,7 @@ async fn test_worker_tool_call() { // ツールを登録 let weather_tool = MockWeatherTool::new(); let tool_for_check = weather_tool.clone(); - worker.register_tool(weather_tool); + worker.register_tool(weather_tool.definition()).unwrap(); // メッセージを送信 let _result = worker.run("What's the weather in Tokyo?").await;