llm_worker_rs/llm-worker/tests/parallel_execution_test.rs

273 lines
8.4 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! 並列ツール実行のテスト
//!
//! Workerが複数のツールを並列に実行することを確認する。
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use llm_worker::Worker;
use llm_worker::hook::{
AfterToolCall, AfterToolCallResult, BeforeToolCall, BeforeToolCallResult, Hook, HookError,
ToolCall, ToolResult,
};
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
use llm_worker::tool::{Tool, ToolError};
mod common;
use common::MockLlmClient;
// =============================================================================
// Parallel Execution Test Tools
// =============================================================================
/// Test tool that sleeps for a fixed delay before answering.
///
/// The invocation counter is held in an `Arc`, so every clone of a `SlowTool`
/// shares the same counter: a clone kept by the test sees the calls made on
/// the instance that was moved into the worker.
#[derive(Clone)]
struct SlowTool {
    /// Name reported through the `Tool` trait.
    name: String,
    /// How long each execution sleeps, in milliseconds.
    delay_ms: u64,
    /// Shared count of executions, incremented once per `execute` call.
    call_count: Arc<AtomicUsize>,
}

impl SlowTool {
    /// Builds a tool called `name` that waits `delay_ms` milliseconds per call.
    fn new(name: impl Into<String>, delay_ms: u64) -> Self {
        let counter = Arc::new(AtomicUsize::new(0));
        SlowTool {
            call_count: counter,
            delay_ms,
            name: name.into(),
        }
    }

    /// Returns the number of executions recorded so far.
    fn call_count(&self) -> usize {
        let shared = &self.call_count;
        shared.load(Ordering::SeqCst)
    }
}
#[async_trait]
impl Tool for SlowTool {
    /// Reports the name supplied at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Fixed human-readable description.
    fn description(&self) -> &str {
        "A tool that waits before responding"
    }

    /// Schema for an object with no declared properties.
    fn input_schema(&self) -> serde_json::Value {
        serde_json::json!({ "type": "object", "properties": {} })
    }

    /// Records the call, sleeps for `delay_ms`, then reports how long it waited.
    ///
    /// The input JSON is not inspected.
    async fn execute(&self, _input_json: &str) -> Result<String, ToolError> {
        // Count the invocation before sleeping.
        self.call_count.fetch_add(1, Ordering::SeqCst);
        let waited_ms = self.delay_ms;
        tokio::time::sleep(Duration::from_millis(waited_ms)).await;
        Ok(format!("Completed after {waited_ms}ms"))
    }
}
// =============================================================================
// Tests
// =============================================================================
/// Verifies that multiple tool calls from one response are executed in parallel.
///
/// Each tool sleeps for `DELAY_MS` (100ms). Running the three calls one after
/// another would take at least `3 * DELAY_MS` (300ms); running them in parallel
/// should finish in roughly `DELAY_MS` (~100ms).
#[tokio::test]
async fn test_parallel_tool_execution() {
    // Per-tool sleep duration in milliseconds.
    const DELAY_MS: u64 = 100;
    // Upper bound for the whole run: 2.5x a single delay (250ms). This leaves
    // headroom for scheduling overhead while staying below the sequential 300ms.
    const THRESHOLD_MS: u64 = DELAY_MS * 5 / 2;

    // One response containing three tool calls, then completion.
    let events = vec![
        Event::tool_use_start(0, "call_1", "slow_tool_1"),
        Event::tool_input_delta(0, r#"{}"#),
        Event::tool_use_stop(0),
        Event::tool_use_start(1, "call_2", "slow_tool_2"),
        Event::tool_input_delta(1, r#"{}"#),
        Event::tool_use_stop(1),
        Event::tool_use_start(2, "call_3", "slow_tool_3"),
        Event::tool_input_delta(2, r#"{}"#),
        Event::tool_use_stop(2),
        Event::Status(StatusEvent {
            status: ResponseStatus::Completed,
        }),
    ];
    let client = MockLlmClient::new(events);
    let mut worker = Worker::new(client);

    let tools = [
        SlowTool::new("slow_tool_1", DELAY_MS),
        SlowTool::new("slow_tool_2", DELAY_MS),
        SlowTool::new("slow_tool_3", DELAY_MS),
    ];
    // Clones share the call counter, so these handles still observe calls
    // after the originals are moved into the worker.
    let handles = tools.clone();
    for tool in tools {
        worker.register_tool(tool);
    }

    let start = Instant::now();
    // The run's result is deliberately ignored: this test only checks that
    // every tool was invoked and how long the run took.
    let _result = worker.run("Run all tools").await;
    let elapsed = start.elapsed();

    // Every tool must have been called exactly once.
    for (index, tool) in handles.iter().enumerate() {
        assert_eq!(
            tool.call_count(),
            1,
            "Tool {} should be called once",
            index + 1
        );
    }

    // Parallel execution should finish in ~DELAY_MS; sequential would need 3x that.
    assert!(
        elapsed < Duration::from_millis(THRESHOLD_MS),
        "Parallel execution should complete in ~{}ms, but took {:?}",
        DELAY_MS,
        elapsed
    );
    println!("Parallel execution completed in {:?}", elapsed);
}
/// Hook: a tool skipped by `before_tool_call` must not be executed.
///
/// Two tool calls are issued. The hook answers `Skip` for `blocked_tool` and
/// `Continue` for anything else, so only `allowed_tool` should run.
#[tokio::test]
async fn test_before_tool_call_skip() {
    // Hook that skips any call whose tool name is "blocked_tool".
    struct BlockingHook;

    #[async_trait]
    impl Hook<BeforeToolCall> for BlockingHook {
        async fn call(
            &self,
            tool_call: &mut ToolCall,
        ) -> Result<BeforeToolCallResult, HookError> {
            let decision = if tool_call.name != "blocked_tool" {
                BeforeToolCallResult::Continue
            } else {
                BeforeToolCallResult::Skip
            };
            Ok(decision)
        }
    }

    let script = vec![
        Event::tool_use_start(0, "call_1", "allowed_tool"),
        Event::tool_input_delta(0, r#"{}"#),
        Event::tool_use_stop(0),
        Event::tool_use_start(1, "call_2", "blocked_tool"),
        Event::tool_input_delta(1, r#"{}"#),
        Event::tool_use_stop(1),
        Event::Status(StatusEvent {
            status: ResponseStatus::Completed,
        }),
    ];
    let mut worker = Worker::new(MockLlmClient::new(script));

    let allowed = SlowTool::new("allowed_tool", 10);
    let blocked = SlowTool::new("blocked_tool", 10);
    // Keep counter-sharing clones so calls can be inspected after registration.
    let (allowed_probe, blocked_probe) = (allowed.clone(), blocked.clone());
    worker.register_tool(allowed);
    worker.register_tool(blocked);
    worker.add_before_tool_call_hook(BlockingHook);

    let _result = worker.run("Test hook").await;

    // Only the allowed tool may have run.
    assert_eq!(allowed_probe.call_count(), 1, "Allowed tool should be called");
    assert_eq!(blocked_probe.call_count(), 0, "Blocked tool should not be called");
}
/// Hook: `after_tool_call` can rewrite a tool's result.
///
/// The first scripted response requests `test_tool`. The second is a plain text
/// reply, so the run can complete. The hook prefixes the tool output with
/// `[Modified] ` and records the rewritten content so the test can inspect it.
#[tokio::test]
async fn test_after_tool_call_modification() {
    // One scripted response per LLM request.
    let client = MockLlmClient::with_responses(vec![
        // First request: a tool call.
        vec![
            Event::tool_use_start(0, "call_1", "test_tool"),
            Event::tool_input_delta(0, r#"{}"#),
            Event::tool_use_stop(0),
            Event::Status(StatusEvent {
                status: ResponseStatus::Completed,
            }),
        ],
        // Second request: a text reply after receiving the tool result.
        vec![
            Event::text_block_start(0),
            Event::text_delta(0, "Done!"),
            Event::text_block_stop(0, None),
            Event::Status(StatusEvent {
                status: ResponseStatus::Completed,
            }),
        ],
    ]);
    let mut worker = Worker::new(client);

    // Tool with a fixed, recognizable output.
    #[derive(Clone)]
    struct SimpleTool;
    #[async_trait]
    impl Tool for SimpleTool {
        fn name(&self) -> &str {
            "test_tool"
        }
        fn description(&self) -> &str {
            "Test"
        }
        fn input_schema(&self) -> serde_json::Value {
            serde_json::json!({})
        }
        async fn execute(&self, _: &str) -> Result<String, ToolError> {
            Ok("Original Result".to_string())
        }
    }
    worker.register_tool(SimpleTool);

    // Hook that rewrites the result and stores the rewritten content.
    struct ModifyingHook {
        modified_content: Arc<std::sync::Mutex<Option<String>>>,
    }
    #[async_trait]
    impl Hook<AfterToolCall> for ModifyingHook {
        async fn call(
            &self,
            tool_result: &mut ToolResult,
        ) -> Result<AfterToolCallResult, HookError> {
            tool_result.content = format!("[Modified] {}", tool_result.content);
            // The guard is dropped at the end of this statement; no await while held.
            *self.modified_content.lock().unwrap() = Some(tool_result.content.clone());
            Ok(AfterToolCallResult::Continue)
        }
    }
    let modified_content = Arc::new(std::sync::Mutex::new(None));
    worker.add_after_tool_call_hook(ModifyingHook {
        modified_content: Arc::clone(&modified_content),
    });

    let result = worker.run("Test modification").await;
    assert!(result.is_ok(), "Worker should complete: {:?}", result);

    // The hook must have run at all.
    let content = modified_content
        .lock()
        .unwrap()
        .clone()
        .expect("Hook should have been called");
    // The prefix alone is written by the hook itself. Also check that the hook
    // saw the tool's real output, so the test fails if the hook ran on the
    // wrong content.
    assert!(
        content.starts_with("[Modified]"),
        "Result should be modified"
    );
    assert!(
        content.contains("Original Result"),
        "Modified result should keep the tool output: {:?}",
        content
    );
}