//! 並列ツール実行のテスト //! //! Workerが複数のツールを並列に実行することを確認する。 use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::{Duration, Instant}; use async_trait::async_trait; use worker::Worker; use worker_types::{ ControlFlow, Event, HookError, ResponseStatus, StatusEvent, Tool, ToolCall, ToolError, ToolResult, WorkerHook, }; mod common; use common::MockLlmClient; // ============================================================================= // Parallel Execution Test Tools // ============================================================================= /// 一定時間待機してから応答するツール #[derive(Clone)] struct SlowTool { name: String, delay_ms: u64, call_count: Arc, } impl SlowTool { fn new(name: impl Into, delay_ms: u64) -> Self { Self { name: name.into(), delay_ms, call_count: Arc::new(AtomicUsize::new(0)), } } fn call_count(&self) -> usize { self.call_count.load(Ordering::SeqCst) } } #[async_trait] impl Tool for SlowTool { fn name(&self) -> &str { &self.name } fn description(&self) -> &str { "A tool that waits before responding" } fn input_schema(&self) -> serde_json::Value { serde_json::json!({ "type": "object", "properties": {} }) } async fn execute(&self, _input_json: &str) -> Result { self.call_count.fetch_add(1, Ordering::SeqCst); tokio::time::sleep(Duration::from_millis(self.delay_ms)).await; Ok(format!("Completed after {}ms", self.delay_ms)) } } // ============================================================================= // Tests // ============================================================================= /// 複数のツールが並列に実行されることを確認 /// /// 各ツールが100msかかる場合、逐次実行なら300ms以上かかるが、 /// 並列実行なら100ms程度で完了するはず。 #[tokio::test] async fn test_parallel_tool_execution() { // 3つのツール呼び出しを含むイベントシーケンス let events = vec![ Event::tool_use_start(0, "call_1", "slow_tool_1"), Event::tool_input_delta(0, r#"{}"#), Event::tool_use_stop(0), Event::tool_use_start(1, "call_2", "slow_tool_2"), Event::tool_input_delta(1, r#"{}"#), Event::tool_use_stop(1), Event::tool_use_start(2, "call_3", "slow_tool_3"), Event::tool_input_delta(2, r#"{}"#), Event::tool_use_stop(2), Event::Status(StatusEvent { status: ResponseStatus::Completed, }), ]; let client = MockLlmClient::new(events); let mut worker = Worker::new(client); // 各ツールは100ms待機 let tool1 = SlowTool::new("slow_tool_1", 100); let tool2 = SlowTool::new("slow_tool_2", 100); let tool3 = SlowTool::new("slow_tool_3", 100); let tool1_clone = tool1.clone(); let tool2_clone = tool2.clone(); let tool3_clone = tool3.clone(); worker.register_tool(tool1); worker.register_tool(tool2); worker.register_tool(tool3); let start = Instant::now(); let _result = worker.run("Run all tools").await; let elapsed = start.elapsed(); // 全ツールが呼び出されたことを確認 assert_eq!(tool1_clone.call_count(), 1, "Tool 1 should be called once"); assert_eq!(tool2_clone.call_count(), 1, "Tool 2 should be called once"); assert_eq!(tool3_clone.call_count(), 1, "Tool 3 should be called once"); // 並列実行なら200ms以下で完了するはず(逐次なら300ms以上) // マージン込みで250msをしきい値とする assert!( elapsed < Duration::from_millis(250), "Parallel execution should complete in ~100ms, but took {:?}", elapsed ); println!("Parallel execution completed in {:?}", elapsed); } /// Hook: before_tool_call でスキップされたツールは実行されないことを確認 #[tokio::test] async fn test_before_tool_call_skip() { let events = vec![ Event::tool_use_start(0, "call_1", "allowed_tool"), Event::tool_input_delta(0, r#"{}"#), Event::tool_use_stop(0), Event::tool_use_start(1, "call_2", "blocked_tool"), Event::tool_input_delta(1, r#"{}"#), Event::tool_use_stop(1), Event::Status(StatusEvent { status: ResponseStatus::Completed, }), ]; let client = MockLlmClient::new(events); let mut worker = Worker::new(client); let allowed_tool = SlowTool::new("allowed_tool", 10); let blocked_tool = SlowTool::new("blocked_tool", 10); let allowed_clone = allowed_tool.clone(); let blocked_clone = blocked_tool.clone(); worker.register_tool(allowed_tool); worker.register_tool(blocked_tool); // "blocked_tool" をスキップするHook struct BlockingHook; #[async_trait] impl WorkerHook for BlockingHook { async fn before_tool_call( &self, tool_call: &mut ToolCall, ) -> Result { if tool_call.name == "blocked_tool" { Ok(ControlFlow::Skip) } else { Ok(ControlFlow::Continue) } } } worker.add_hook(BlockingHook); let _result = worker.run("Test hook").await; // allowed_tool は呼び出されるが、blocked_tool は呼び出されない assert_eq!( allowed_clone.call_count(), 1, "Allowed tool should be called" ); assert_eq!( blocked_clone.call_count(), 0, "Blocked tool should not be called" ); } /// Hook: after_tool_call で結果が改変されることを確認 #[tokio::test] async fn test_after_tool_call_modification() { // 複数リクエストに対応するレスポンスを準備 let client = MockLlmClient::with_responses(vec![ // 1回目のリクエスト: ツール呼び出し vec![ Event::tool_use_start(0, "call_1", "test_tool"), Event::tool_input_delta(0, r#"{}"#), Event::tool_use_stop(0), Event::Status(StatusEvent { status: ResponseStatus::Completed, }), ], // 2回目のリクエスト: ツール結果を受けてテキストレスポンス vec![ Event::text_block_start(0), Event::text_delta(0, "Done!"), Event::text_block_stop(0, None), Event::Status(StatusEvent { status: ResponseStatus::Completed, }), ], ]); let mut worker = Worker::new(client); #[derive(Clone)] struct SimpleTool; #[async_trait] impl Tool for SimpleTool { fn name(&self) -> &str { "test_tool" } fn description(&self) -> &str { "Test" } fn input_schema(&self) -> serde_json::Value { serde_json::json!({}) } async fn execute(&self, _: &str) -> Result { Ok("Original Result".to_string()) } } worker.register_tool(SimpleTool); // 結果を改変するHook struct ModifyingHook { modified_content: Arc>>, } #[async_trait] impl WorkerHook for ModifyingHook { async fn after_tool_call( &self, tool_result: &mut ToolResult, ) -> Result { tool_result.content = format!("[Modified] {}", tool_result.content); *self.modified_content.lock().unwrap() = Some(tool_result.content.clone()); Ok(ControlFlow::Continue) } } let modified_content = Arc::new(std::sync::Mutex::new(None)); worker.add_hook(ModifyingHook { modified_content: modified_content.clone(), }); let result = worker.run("Test modification").await; assert!(result.is_ok(), "Worker should complete: {:?}", result); // Hookが呼ばれて内容が改変されたことを確認 let content = modified_content.lock().unwrap().clone(); assert!(content.is_some(), "Hook should have been called"); assert!( content.unwrap().contains("[Modified]"), "Result should be modified" ); }