//! 並列ツール実行のテスト //! //! Workerが複数のツールを並列に実行することを確認する。 use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::{Duration, Instant}; use async_trait::async_trait; use llm_worker::Worker; use llm_worker::hook::{ Hook, HookError, PostToolCall, PostToolCallContext, PostToolCallResult, PreToolCall, PreToolCallResult, ToolCallContext, }; use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent}; use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta}; mod common; use common::MockLlmClient; // ============================================================================= // Parallel Execution Test Tools // ============================================================================= /// 一定時間待機してから応答するツール #[derive(Clone)] struct SlowTool { name: String, delay_ms: u64, call_count: Arc, } impl SlowTool { fn new(name: impl Into, delay_ms: u64) -> Self { Self { name: name.into(), delay_ms, call_count: Arc::new(AtomicUsize::new(0)), } } fn call_count(&self) -> usize { self.call_count.load(Ordering::SeqCst) } /// ToolDefinition を作成 fn definition(&self) -> ToolDefinition { let tool = self.clone(); Arc::new(move || { let meta = ToolMeta::new(&tool.name) .description("A tool that waits before responding") .input_schema(serde_json::json!({ "type": "object", "properties": {} })); (meta, Arc::new(tool.clone()) as Arc) }) } } #[async_trait] impl Tool for SlowTool { async fn execute(&self, _input_json: &str) -> Result { self.call_count.fetch_add(1, Ordering::SeqCst); tokio::time::sleep(Duration::from_millis(self.delay_ms)).await; Ok(format!("Completed after {}ms", self.delay_ms)) } } // ============================================================================= // Tests // ============================================================================= /// 複数のツールが並列に実行されることを確認 /// /// 各ツールが100msかかる場合、逐次実行なら300ms以上かかるが、 /// 並列実行なら100ms程度で完了するはず。 #[tokio::test] async fn test_parallel_tool_execution() { // 3つのツール呼び出しを含むイベントシーケンス let events = vec![ Event::tool_use_start(0, "call_1", "slow_tool_1"), Event::tool_input_delta(0, r#"{}"#), Event::tool_use_stop(0), Event::tool_use_start(1, "call_2", "slow_tool_2"), Event::tool_input_delta(1, r#"{}"#), Event::tool_use_stop(1), Event::tool_use_start(2, "call_3", "slow_tool_3"), Event::tool_input_delta(2, r#"{}"#), Event::tool_use_stop(2), Event::Status(StatusEvent { status: ResponseStatus::Completed, }), ]; let client = MockLlmClient::new(events); let mut worker = Worker::new(client); // 各ツールは100ms待機 let tool1 = SlowTool::new("slow_tool_1", 100); let tool2 = SlowTool::new("slow_tool_2", 100); let tool3 = SlowTool::new("slow_tool_3", 100); let tool1_clone = tool1.clone(); let tool2_clone = tool2.clone(); let tool3_clone = tool3.clone(); worker.register_tool(tool1.definition()).unwrap(); worker.register_tool(tool2.definition()).unwrap(); worker.register_tool(tool3.definition()).unwrap(); let start = Instant::now(); let _result = worker.run("Run all tools").await; let elapsed = start.elapsed(); // 全ツールが呼び出されたことを確認 assert_eq!(tool1_clone.call_count(), 1, "Tool 1 should be called once"); assert_eq!(tool2_clone.call_count(), 1, "Tool 2 should be called once"); assert_eq!(tool3_clone.call_count(), 1, "Tool 3 should be called once"); // 並列実行なら200ms以下で完了するはず(逐次なら300ms以上) // マージン込みで250msをしきい値とする assert!( elapsed < Duration::from_millis(250), "Parallel execution should complete in ~100ms, but took {:?}", elapsed ); println!("Parallel execution completed in {:?}", elapsed); } /// Hook: pre_tool_call でスキップされたツールは実行されないことを確認 #[tokio::test] async fn test_before_tool_call_skip() { let events = vec![ Event::tool_use_start(0, "call_1", "allowed_tool"), Event::tool_input_delta(0, r#"{}"#), Event::tool_use_stop(0), Event::tool_use_start(1, "call_2", "blocked_tool"), Event::tool_input_delta(1, r#"{}"#), Event::tool_use_stop(1), Event::Status(StatusEvent { status: ResponseStatus::Completed, }), ]; let client = MockLlmClient::new(events); let mut worker = Worker::new(client); let allowed_tool = SlowTool::new("allowed_tool", 10); let blocked_tool = SlowTool::new("blocked_tool", 10); let allowed_clone = allowed_tool.clone(); let blocked_clone = blocked_tool.clone(); worker.register_tool(allowed_tool.definition()).unwrap(); worker.register_tool(blocked_tool.definition()).unwrap(); // "blocked_tool" をスキップするHook struct BlockingHook; #[async_trait] impl Hook for BlockingHook { async fn call(&self, ctx: &mut ToolCallContext) -> Result { if ctx.call.name == "blocked_tool" { Ok(PreToolCallResult::Skip) } else { Ok(PreToolCallResult::Continue) } } } worker.add_pre_tool_call_hook(BlockingHook); let _result = worker.run("Test hook").await; // allowed_tool は呼び出されるが、blocked_tool は呼び出されない assert_eq!( allowed_clone.call_count(), 1, "Allowed tool should be called" ); assert_eq!( blocked_clone.call_count(), 0, "Blocked tool should not be called" ); } /// Hook: post_tool_call で結果が改変されることを確認 #[tokio::test] async fn test_post_tool_call_modification() { // 複数リクエストに対応するレスポンスを準備 let client = MockLlmClient::with_responses(vec![ // 1回目のリクエスト: ツール呼び出し vec![ Event::tool_use_start(0, "call_1", "test_tool"), Event::tool_input_delta(0, r#"{}"#), Event::tool_use_stop(0), Event::Status(StatusEvent { status: ResponseStatus::Completed, }), ], // 2回目のリクエスト: ツール結果を受けてテキストレスポンス vec![ Event::text_block_start(0), Event::text_delta(0, "Done!"), Event::text_block_stop(0, None), Event::Status(StatusEvent { status: ResponseStatus::Completed, }), ], ]); let mut worker = Worker::new(client); #[derive(Clone)] struct SimpleTool; #[async_trait] impl Tool for SimpleTool { async fn execute(&self, _: &str) -> Result { Ok("Original Result".to_string()) } } fn simple_tool_definition() -> ToolDefinition { Arc::new(|| { let meta = ToolMeta::new("test_tool") .description("Test") .input_schema(serde_json::json!({})); (meta, Arc::new(SimpleTool) as Arc) }) } worker.register_tool(simple_tool_definition()).unwrap(); // 結果を改変するHook struct ModifyingHook { modified_content: Arc>>, } #[async_trait] impl Hook for ModifyingHook { async fn call( &self, ctx: &mut PostToolCallContext, ) -> Result { ctx.result.content = format!("[Modified] {}", ctx.result.content); *self.modified_content.lock().unwrap() = Some(ctx.result.content.clone()); Ok(PostToolCallResult::Continue) } } let modified_content = Arc::new(std::sync::Mutex::new(None)); worker.add_post_tool_call_hook(ModifyingHook { modified_content: modified_content.clone(), }); let result = worker.run("Test modification").await; assert!(result.is_ok(), "Worker should complete: {:?}", result); // Hookが呼ばれて内容が改変されたことを確認 let content = modified_content.lock().unwrap().clone(); assert!(content.is_some(), "Hook should have been called"); assert!( content.unwrap().contains("[Modified]"), "Result should be modified" ); }