//! 並列ツール実行のテスト
//!
//! Workerが複数のツールを並列に実行することを確認する。
//! あわせて、before_tool_call / after_tool_call フックの動作も確認する。

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

use async_trait::async_trait;
use worker::Worker;
use worker_types::{
    ControlFlow, Event, HookError, Message, ResponseStatus, StatusEvent, Tool, ToolCall, ToolError,
    ToolResult, WorkerHook,
};

mod common;

use common::MockLlmClient;

// =============================================================================
// Parallel Execution Test Tools
// =============================================================================

/// Test tool that sleeps for a fixed delay before replying.
///
/// The invocation counter lives behind an `Arc`, so every clone of a
/// `SlowTool` shares the same count. A test can therefore keep a clone after
/// handing the original to the worker, and still observe how many times the
/// tool was executed.
#[derive(Clone)]
struct SlowTool {
    /// Name reported through `Tool::name`.
    name: String,
    /// Time each execution sleeps, in milliseconds.
    delay_ms: u64,
    /// Number of times `execute` has been entered; shared by all clones.
    call_count: Arc<AtomicUsize>,
}

impl SlowTool {
|
||
fn new(name: impl Into<String>, delay_ms: u64) -> Self {
|
||
Self {
|
||
name: name.into(),
|
||
delay_ms,
|
||
call_count: Arc::new(AtomicUsize::new(0)),
|
||
}
|
||
}
|
||
|
||
fn call_count(&self) -> usize {
|
||
self.call_count.load(Ordering::SeqCst)
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl Tool for SlowTool {
|
||
fn name(&self) -> &str {
|
||
&self.name
|
||
}
|
||
|
||
fn description(&self) -> &str {
|
||
"A tool that waits before responding"
|
||
}
|
||
|
||
fn input_schema(&self) -> serde_json::Value {
|
||
serde_json::json!({
|
||
"type": "object",
|
||
"properties": {}
|
||
})
|
||
}
|
||
|
||
async fn execute(&self, _input_json: &str) -> Result<String, ToolError> {
|
||
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||
tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
|
||
Ok(format!("Completed after {}ms", self.delay_ms))
|
||
}
|
||
}
|
||
|
||
// =============================================================================
// Tests
// =============================================================================

/// 複数のツールが並列に実行されることを確認
|
||
///
|
||
/// 各ツールが100msかかる場合、逐次実行なら300ms以上かかるが、
|
||
/// 並列実行なら100ms程度で完了するはず。
|
||
#[tokio::test]
|
||
async fn test_parallel_tool_execution() {
|
||
// 3つのツール呼び出しを含むイベントシーケンス
|
||
let events = vec![
|
||
Event::tool_use_start(0, "call_1", "slow_tool_1"),
|
||
Event::tool_input_delta(0, r#"{}"#),
|
||
Event::tool_use_stop(0),
|
||
Event::tool_use_start(1, "call_2", "slow_tool_2"),
|
||
Event::tool_input_delta(1, r#"{}"#),
|
||
Event::tool_use_stop(1),
|
||
Event::tool_use_start(2, "call_3", "slow_tool_3"),
|
||
Event::tool_input_delta(2, r#"{}"#),
|
||
Event::tool_use_stop(2),
|
||
Event::Status(StatusEvent {
|
||
status: ResponseStatus::Completed,
|
||
}),
|
||
];
|
||
|
||
let client = MockLlmClient::new(events);
|
||
let mut worker = Worker::new(client);
|
||
|
||
// 各ツールは100ms待機
|
||
let tool1 = SlowTool::new("slow_tool_1", 100);
|
||
let tool2 = SlowTool::new("slow_tool_2", 100);
|
||
let tool3 = SlowTool::new("slow_tool_3", 100);
|
||
|
||
let tool1_clone = tool1.clone();
|
||
let tool2_clone = tool2.clone();
|
||
let tool3_clone = tool3.clone();
|
||
|
||
worker.register_tool(tool1);
|
||
worker.register_tool(tool2);
|
||
worker.register_tool(tool3);
|
||
|
||
|
||
|
||
let messages = vec![Message::user("Run all tools")];
|
||
|
||
let start = Instant::now();
|
||
let _result = worker.run(messages).await;
|
||
let elapsed = start.elapsed();
|
||
|
||
// 全ツールが呼び出されたことを確認
|
||
assert_eq!(tool1_clone.call_count(), 1, "Tool 1 should be called once");
|
||
assert_eq!(tool2_clone.call_count(), 1, "Tool 2 should be called once");
|
||
assert_eq!(tool3_clone.call_count(), 1, "Tool 3 should be called once");
|
||
|
||
// 並列実行なら200ms以下で完了するはず(逐次なら300ms以上)
|
||
// マージン込みで250msをしきい値とする
|
||
assert!(
|
||
elapsed < Duration::from_millis(250),
|
||
"Parallel execution should complete in ~100ms, but took {:?}",
|
||
elapsed
|
||
);
|
||
|
||
println!("Parallel execution completed in {:?}", elapsed);
|
||
}
|
||
|
||
/// Hook: before_tool_call でスキップされたツールは実行されないことを確認
|
||
#[tokio::test]
|
||
async fn test_before_tool_call_skip() {
|
||
let events = vec![
|
||
Event::tool_use_start(0, "call_1", "allowed_tool"),
|
||
Event::tool_input_delta(0, r#"{}"#),
|
||
Event::tool_use_stop(0),
|
||
Event::tool_use_start(1, "call_2", "blocked_tool"),
|
||
Event::tool_input_delta(1, r#"{}"#),
|
||
Event::tool_use_stop(1),
|
||
Event::Status(StatusEvent {
|
||
status: ResponseStatus::Completed,
|
||
}),
|
||
];
|
||
|
||
let client = MockLlmClient::new(events);
|
||
let mut worker = Worker::new(client);
|
||
|
||
let allowed_tool = SlowTool::new("allowed_tool", 10);
|
||
let blocked_tool = SlowTool::new("blocked_tool", 10);
|
||
|
||
let allowed_clone = allowed_tool.clone();
|
||
let blocked_clone = blocked_tool.clone();
|
||
|
||
worker.register_tool(allowed_tool);
|
||
worker.register_tool(blocked_tool);
|
||
|
||
// "blocked_tool" をスキップするHook
|
||
struct BlockingHook;
|
||
|
||
#[async_trait]
|
||
impl WorkerHook for BlockingHook {
|
||
async fn before_tool_call(&self, tool_call: &mut ToolCall) -> Result<ControlFlow, HookError> {
|
||
if tool_call.name == "blocked_tool" {
|
||
Ok(ControlFlow::Skip)
|
||
} else {
|
||
Ok(ControlFlow::Continue)
|
||
}
|
||
}
|
||
}
|
||
|
||
worker.add_hook(BlockingHook);
|
||
|
||
let messages = vec![Message::user("Test hook")];
|
||
let _result = worker.run(messages).await;
|
||
|
||
// allowed_tool は呼び出されるが、blocked_tool は呼び出されない
|
||
assert_eq!(allowed_clone.call_count(), 1, "Allowed tool should be called");
|
||
assert_eq!(blocked_clone.call_count(), 0, "Blocked tool should not be called");
|
||
}
|
||
|
||
/// Hook: after_tool_call で結果が改変されることを確認
|
||
#[tokio::test]
|
||
async fn test_after_tool_call_modification() {
|
||
// 複数リクエストに対応するレスポンスを準備
|
||
let client = MockLlmClient::with_responses(vec![
|
||
// 1回目のリクエスト: ツール呼び出し
|
||
vec![
|
||
Event::tool_use_start(0, "call_1", "test_tool"),
|
||
Event::tool_input_delta(0, r#"{}"#),
|
||
Event::tool_use_stop(0),
|
||
Event::Status(StatusEvent {
|
||
status: ResponseStatus::Completed,
|
||
}),
|
||
],
|
||
// 2回目のリクエスト: ツール結果を受けてテキストレスポンス
|
||
vec![
|
||
Event::text_block_start(0),
|
||
Event::text_delta(0, "Done!"),
|
||
Event::text_block_stop(0, None),
|
||
Event::Status(StatusEvent {
|
||
status: ResponseStatus::Completed,
|
||
}),
|
||
],
|
||
]);
|
||
|
||
let mut worker = Worker::new(client);
|
||
|
||
#[derive(Clone)]
|
||
struct SimpleTool;
|
||
|
||
#[async_trait]
|
||
impl Tool for SimpleTool {
|
||
fn name(&self) -> &str { "test_tool" }
|
||
fn description(&self) -> &str { "Test" }
|
||
fn input_schema(&self) -> serde_json::Value { serde_json::json!({}) }
|
||
async fn execute(&self, _: &str) -> Result<String, ToolError> {
|
||
Ok("Original Result".to_string())
|
||
}
|
||
}
|
||
|
||
worker.register_tool(SimpleTool);
|
||
|
||
// 結果を改変するHook
|
||
struct ModifyingHook {
|
||
modified_content: Arc<std::sync::Mutex<Option<String>>>,
|
||
}
|
||
|
||
#[async_trait]
|
||
impl WorkerHook for ModifyingHook {
|
||
async fn after_tool_call(&self, tool_result: &mut ToolResult) -> Result<ControlFlow, HookError> {
|
||
tool_result.content = format!("[Modified] {}", tool_result.content);
|
||
*self.modified_content.lock().unwrap() = Some(tool_result.content.clone());
|
||
Ok(ControlFlow::Continue)
|
||
}
|
||
}
|
||
|
||
let modified_content = Arc::new(std::sync::Mutex::new(None));
|
||
worker.add_hook(ModifyingHook { modified_content: modified_content.clone() });
|
||
|
||
let messages = vec![Message::user("Test modification")];
|
||
let result = worker.run(messages).await;
|
||
|
||
assert!(result.is_ok(), "Worker should complete: {:?}", result);
|
||
|
||
// Hookが呼ばれて内容が改変されたことを確認
|
||
let content = modified_content.lock().unwrap().clone();
|
||
assert!(content.is_some(), "Hook should have been called");
|
||
assert!(
|
||
content.unwrap().contains("[Modified]"),
|
||
"Result should be modified"
|
||
);
|
||
}
|