llm_worker_rs/llm-worker/tests/parallel_execution_test.rs
2026-01-16 16:58:03 +09:00

272 lines
8.5 KiB
Rust

//! Parallel tool execution tests
//!
//! Verify that Worker executes multiple tools in parallel.
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use llm_worker::Worker;
use llm_worker::hook::{
Hook, HookError, PostToolCall, PostToolCallContext, PostToolCallResult, PreToolCall,
PreToolCallResult, ToolCallContext,
};
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
mod common;
use common::MockLlmClient;
// =============================================================================
// Parallel Execution Test Tools
// =============================================================================
/// Test tool that sleeps for a fixed delay before returning.
///
/// Clones share one call counter (it lives behind an `Arc`). A test can
/// therefore keep a handle to the tool after registering it and still observe
/// how many times `execute` was entered.
#[derive(Clone)]
struct SlowTool {
    /// Tool name reported through `ToolMeta`.
    name: String,
    /// How long `execute` sleeps, in milliseconds.
    delay_ms: u64,
    /// Number of `execute` calls, shared by every clone of this tool.
    call_count: Arc<AtomicUsize>,
}

impl SlowTool {
    /// Creates a tool called `name` that sleeps `delay_ms` milliseconds per call.
    fn new(name: impl Into<String>, delay_ms: u64) -> Self {
        let counter = Arc::new(AtomicUsize::new(0));
        Self {
            name: name.into(),
            delay_ms,
            call_count: counter,
        }
    }

    /// Returns how many times `execute` has been entered so far.
    fn call_count(&self) -> usize {
        self.call_count.load(Ordering::SeqCst)
    }

    /// Builds a `ToolDefinition` whose factory hands out clones of this tool.
    ///
    /// Every instance the factory produces shares this tool's call counter.
    fn definition(&self) -> ToolDefinition {
        let template = self.clone();
        let factory = move || {
            let schema = serde_json::json!({
                "type": "object",
                "properties": {}
            });
            let meta = ToolMeta::new(&template.name)
                .description("A tool that waits before responding")
                .input_schema(schema);
            let instance: Arc<dyn Tool> = Arc::new(template.clone());
            (meta, instance)
        };
        Arc::new(factory)
    }
}

#[async_trait]
impl Tool for SlowTool {
    /// Counts the call, sleeps for `delay_ms`, then reports the delay.
    async fn execute(&self, _input_json: &str) -> Result<String, ToolError> {
        let pause = Duration::from_millis(self.delay_ms);
        // The counter is bumped on entry, before the sleep, not on completion.
        self.call_count.fetch_add(1, Ordering::SeqCst);
        tokio::time::sleep(pause).await;
        let message = format!("Completed after {}ms", self.delay_ms);
        Ok(message)
    }
}
// =============================================================================
// Tests
// =============================================================================
/// Verify that multiple tools are executed in parallel
///
/// Three tools each sleep for `DELAY_MS` (100ms). Run one after another, they
/// would need at least 300ms. Run in parallel, the whole batch should finish
/// in roughly 100ms plus scheduling overhead.
#[tokio::test]
async fn test_parallel_tool_execution() {
    // Per-tool sleep, in milliseconds.
    const DELAY_MS: u64 = 100;
    // Upper bound for the whole run. It is well below the 3 * DELAY_MS a
    // sequential executor would need, with headroom for scheduling jitter.
    const PARALLEL_LIMIT_MS: u64 = 250;

    // Event sequence containing 3 tool calls
    let events = vec![
        Event::tool_use_start(0, "call_1", "slow_tool_1"),
        Event::tool_input_delta(0, r#"{}"#),
        Event::tool_use_stop(0),
        Event::tool_use_start(1, "call_2", "slow_tool_2"),
        Event::tool_input_delta(1, r#"{}"#),
        Event::tool_use_stop(1),
        Event::tool_use_start(2, "call_3", "slow_tool_3"),
        Event::tool_input_delta(2, r#"{}"#),
        Event::tool_use_stop(2),
        Event::Status(StatusEvent {
            status: ResponseStatus::Completed,
        }),
    ];
    let client = MockLlmClient::new(events);
    let mut worker = Worker::new(client);

    // `definition()` only borrows each tool, and every registered clone shares
    // its call counter. No extra handles are needed to inspect counts later.
    let tools = [
        SlowTool::new("slow_tool_1", DELAY_MS),
        SlowTool::new("slow_tool_2", DELAY_MS),
        SlowTool::new("slow_tool_3", DELAY_MS),
    ];
    for tool in &tools {
        worker.register_tool(tool.definition()).unwrap();
    }

    let start = Instant::now();
    // Only tool execution and timing are under test here; the run's return
    // value is intentionally not inspected.
    let _result = worker.run("Run all tools").await;
    let elapsed = start.elapsed();

    // Verify all tools were called
    for (index, tool) in tools.iter().enumerate() {
        assert_eq!(
            tool.call_count(),
            1,
            "Tool {} should be called once",
            index + 1
        );
    }

    // Lower bound: `tokio::time::sleep` never completes before its deadline.
    // If the worker really awaited the tools, the run cannot be shorter than
    // one delay. Because `call_count` is bumped before the sleep, this catches
    // a worker that starts tools without waiting for them.
    assert!(
        elapsed >= Duration::from_millis(DELAY_MS),
        "Tools should have been awaited (>= {}ms), but run took {:?}",
        DELAY_MS,
        elapsed
    );
    // Upper bound: sequential execution would take 3 * DELAY_MS or more.
    assert!(
        elapsed < Duration::from_millis(PARALLEL_LIMIT_MS),
        "Parallel execution should complete in ~100ms, but took {:?}",
        elapsed
    );
    println!("Parallel execution completed in {:?}", elapsed);
}
/// Hook: pre_tool_call - verify that skipped tools are not executed
///
/// Two tool calls are streamed. A `PreToolCall` hook answers `Skip` for
/// `blocked_tool`, so only `allowed_tool` should ever reach `execute`.
#[tokio::test]
async fn test_before_tool_call_skip() {
    // Pre-tool-call hook that skips any call named "blocked_tool".
    struct BlockingHook;

    #[async_trait]
    impl Hook<PreToolCall> for BlockingHook {
        async fn call(&self, ctx: &mut ToolCallContext) -> Result<PreToolCallResult, HookError> {
            let verdict = if ctx.call.name == "blocked_tool" {
                PreToolCallResult::Skip
            } else {
                PreToolCallResult::Continue
            };
            Ok(verdict)
        }
    }

    let events = vec![
        Event::tool_use_start(0, "call_1", "allowed_tool"),
        Event::tool_input_delta(0, r#"{}"#),
        Event::tool_use_stop(0),
        Event::tool_use_start(1, "call_2", "blocked_tool"),
        Event::tool_input_delta(1, r#"{}"#),
        Event::tool_use_stop(1),
        Event::Status(StatusEvent {
            status: ResponseStatus::Completed,
        }),
    ];
    let mut worker = Worker::new(MockLlmClient::new(events));

    // The tools are only borrowed by `definition()`, and their call counters
    // are shared with the registered clones, so these bindings stay usable.
    let allowed_tool = SlowTool::new("allowed_tool", 10);
    let blocked_tool = SlowTool::new("blocked_tool", 10);
    for tool in [&allowed_tool, &blocked_tool] {
        worker.register_tool(tool.definition()).unwrap();
    }
    worker.add_pre_tool_call_hook(BlockingHook);

    let _result = worker.run("Test hook").await;

    // allowed_tool ran once; blocked_tool was never executed.
    assert_eq!(
        allowed_tool.call_count(),
        1,
        "Allowed tool should be called"
    );
    assert_eq!(
        blocked_tool.call_count(),
        0,
        "Blocked tool should not be called"
    );
}
/// Hook: post_tool_call - verify that results can be modified
///
/// The first scripted response requests `test_tool`. The second is a plain
/// text reply, sent after the tool result, so the run can complete. A
/// `PostToolCall` hook rewrites the tool result and records what it produced,
/// so the test can check the hook ran on the tool's real output.
#[tokio::test]
async fn test_post_tool_call_modification() {
    // Prepare responses for multiple requests
    let client = MockLlmClient::with_responses(vec![
        // First request: tool call
        vec![
            Event::tool_use_start(0, "call_1", "test_tool"),
            Event::tool_input_delta(0, r#"{}"#),
            Event::tool_use_stop(0),
            Event::Status(StatusEvent {
                status: ResponseStatus::Completed,
            }),
        ],
        // Second request: text response after receiving tool result
        vec![
            Event::text_block_start(0),
            Event::text_delta(0, "Done!"),
            Event::text_block_stop(0, None),
            Event::Status(StatusEvent {
                status: ResponseStatus::Completed,
            }),
        ],
    ]);
    let mut worker = Worker::new(client);

    // Tool that always returns a fixed, recognizable string.
    #[derive(Clone)]
    struct SimpleTool;

    #[async_trait]
    impl Tool for SimpleTool {
        async fn execute(&self, _: &str) -> Result<String, ToolError> {
            Ok("Original Result".to_string())
        }
    }

    fn simple_tool_definition() -> ToolDefinition {
        Arc::new(|| {
            let meta = ToolMeta::new("test_tool")
                .description("Test")
                .input_schema(serde_json::json!({}));
            (meta, Arc::new(SimpleTool) as Arc<dyn Tool>)
        })
    }

    worker.register_tool(simple_tool_definition()).unwrap();

    // Hook to modify results: prefixes the content and stores the rewritten
    // value so the test can read it after the run.
    struct ModifyingHook {
        modified_content: Arc<std::sync::Mutex<Option<String>>>,
    }

    #[async_trait]
    impl Hook<PostToolCall> for ModifyingHook {
        async fn call(
            &self,
            ctx: &mut PostToolCallContext,
        ) -> Result<PostToolCallResult, HookError> {
            ctx.result.content = format!("[Modified] {}", ctx.result.content);
            *self.modified_content.lock().unwrap() = Some(ctx.result.content.clone());
            Ok(PostToolCallResult::Continue)
        }
    }

    let modified_content = Arc::new(std::sync::Mutex::new(None));
    worker.add_post_tool_call_hook(ModifyingHook {
        modified_content: modified_content.clone(),
    });

    let result = worker.run("Test modification").await;
    assert!(result.is_ok(), "Worker should complete: {:?}", result);

    // Verify the hook was called, and that it rewrote the tool's actual output
    // rather than merely producing its own prefix.
    let content = modified_content
        .lock()
        .unwrap()
        .clone()
        .expect("Hook should have been called");
    assert!(
        content.starts_with("[Modified] "),
        "Result should be modified: {:?}",
        content
    );
    assert!(
        content.contains("Original Result"),
        "Hook should have received the tool's original output: {:?}",
        content
    );
}