// llm-worker-rs/worker/src/tests/mock_llm_integration.rs

use std::sync::{Arc, Mutex, OnceLock};
use async_trait::async_trait;
use futures::StreamExt;
use serde_json::{json, Value};
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockGuard, MockServer, ResponseTemplate};
use worker::{
HookContext, HookResult, LlmProvider, PromptError, StreamEvent, Tool, ToolResult, Worker,
WorkerBlueprint, WorkerHook,
};
use worker_types::Role;
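/// Tool name shared by every provider case so request and event assertions can match on it.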
const SAMPLE_TOOL_NAME: &str = "sample_tool";
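/// Serializes the provider cases: the base-URL overrides mutate process-global environment variables.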
static ENV_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();
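/// Static description of one mocked provider: its variant, model, base-URL env var, and completion endpoint.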
struct ProviderCase {
name: &'static str,
provider: LlmProvider,
model: &'static str,
env_var: &'static str,
completion_path: &'static str,
}
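/// Minimal tool that records every argument set it receives and returns a canned JSON payload.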
struct SampleTool {
name: String,
description: String,
calls: Arc<Mutex<Vec<Value>>>,
response: Value,
}
impl SampleTool {
fn new(provider_label: &str, calls: Arc<Mutex<Vec<Value>>>) -> Self {
Self {
name: SAMPLE_TOOL_NAME.to_string(),
description: format!("Records invocations for {}", provider_label),
calls,
response: json!({
"status": "ok",
"provider": provider_label,
}),
}
}
}
#[async_trait]
impl Tool for SampleTool {
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
&self.description
}
fn parameters_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"provider": {"type": "string"},
"request_id": {"type": "integer"}
},
"required": ["provider", "request_id"]
})
}
async fn execute(&self, args: Value) -> ToolResult<Value> {
self.calls.lock().unwrap().push(args.clone());
Ok(self.response.clone())
}
}
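/// PostToolUse hook that logs each matched tool invocation and injects an assistant message.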
struct RecordingHook {
tool_name: String,
provider_label: String,
events: Arc<Mutex<Vec<String>>>,
}
impl RecordingHook {
fn new(provider_label: &str, events: Arc<Mutex<Vec<String>>>) -> Self {
Self {
tool_name: SAMPLE_TOOL_NAME.to_string(),
provider_label: provider_label.to_string(),
events,
}
}
}
#[async_trait]
impl WorkerHook for RecordingHook {
fn name(&self) -> &str {
"recording_hook"
}
fn hook_type(&self) -> &str {
"PostToolUse"
}
fn matcher(&self) -> &str {
&self.tool_name
}
async fn execute(&self, context: HookContext) -> (HookContext, HookResult) {
let tool = context
.get_variable("current_tool")
.cloned()
.unwrap_or_else(|| self.tool_name.clone());
let entry = format!(
"{}::{}::{}",
self.provider_label, tool, context.content
);
self.events.lock().unwrap().push(entry);
let message = format!("{} hook observed {}", self.provider_label, tool);
(
context,
HookResult::AddMessage(message, Role::Assistant),
)
}
}
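/// RAII guard that sets an environment variable and restores the previous value (or removes the variable) on drop.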
struct EnvOverride {
key: String,
previous: Option<String>,
}
impl EnvOverride {
fn set(key: &str, value: String) -> Self {
let previous = std::env::var(key).ok();
std::env::set_var(key, &value);
Self {
key: key.to_string(),
previous,
}
}
}
impl Drop for EnvOverride {
fn drop(&mut self) {
if let Some(prev) = self.previous.take() {
std::env::set_var(&self.key, prev);
} else {
std::env::remove_var(&self.key);
}
}
}
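/// Wires a blueprint with the provider under test, a dummy API key, a fixed system prompt, the sample tool, and the recording hook.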
fn build_blueprint(
provider: LlmProvider,
model: &str,
provider_label: &str,
tool_calls: Arc<Mutex<Vec<Value>>>,
hook_events: Arc<Mutex<Vec<String>>>,
) -> WorkerBlueprint {
let mut blueprint = Worker::blueprint();
blueprint
.provider(provider)
.model(model)
.api_key(provider.as_str(), "test-key")
.system_prompt_fn(|_, _| Ok::<_, PromptError>("Integration test system prompt.".into()))
.add_tool(SampleTool::new(provider_label, Arc::clone(&tool_calls)))
.attach_hook(RecordingHook::new(provider_label, Arc::clone(&hook_events)));
blueprint
}
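/// Mounts a provider-specific mock that instructs the worker to call the sample tool with `expected_args`.
/// The returned guard unmounts the mock when dropped.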
async fn setup_mock_response(
case: &ProviderCase,
server: &MockServer,
expected_args: &Value,
) -> MockGuard {
match case.provider {
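// OpenAI: SSE stream; the tool call arrives as a delta and the stream ends with [DONE].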
LlmProvider::OpenAI => {
let arguments = expected_args.to_string();
let event_body = json!({
"choices": [{
"delta": {
"tool_calls": [{
"function": {
"name": SAMPLE_TOOL_NAME,
"arguments": arguments
}
}]
}
}]
});
let sse = format!(
"data: {}\n\ndata: [DONE]\n\n",
event_body
);
Mock::given(method("POST"))
.and(path(case.completion_path))
.respond_with(
ResponseTemplate::new(200)
.set_body_raw(sse, "text/event-stream"),
)
.mount_as_scoped(server)
.await
}
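// Claude: SSE stream; tool_use arrives in a content_block_start event, followed by message_stop.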
LlmProvider::Claude => {
let event_tool = json!({
"type": "content_block_start",
"data": {
"content_block": {
"type": "tool_use",
"name": SAMPLE_TOOL_NAME,
"input": expected_args
}
}
});
let event_stop = json!({
"type": "message_stop",
"data": {}
});
let sse = format!(
"data: {}\n\ndata: {}\n\ndata: [DONE]\n\n",
event_tool, event_stop
);
Mock::given(method("POST"))
.and(path(case.completion_path))
.respond_with(
ResponseTemplate::new(200)
.set_body_raw(sse, "text/event-stream"),
)
.mount_as_scoped(server)
.await
}
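// Gemini: a single JSON body whose candidate carries a functionCall part.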
LlmProvider::Gemini => {
let body = json!({
"candidates": [{
"content": {
"role": "model",
"parts": [{
"functionCall": {
"name": SAMPLE_TOOL_NAME,
"args": expected_args
}
}]
},
"finishReason": "STOP"
}]
});
Mock::given(method("POST"))
.and(path(case.completion_path))
.respond_with(ResponseTemplate::new(200).set_body_json(body))
.mount_as_scoped(server)
.await
}
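// Ollama: newline-delimited JSON; the first object carries the tool call, the second sets done: true.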
LlmProvider::Ollama => {
let first = json!({
"message": {
"role": "assistant",
"content": "",
"tool_calls": [{
"function": {
"name": SAMPLE_TOOL_NAME,
"arguments": expected_args
}
}]
},
"done": false
});
let second = json!({
"message": {
"role": "assistant",
"content": "finished"
},
"done": true
});
let body = format!("{}\n{}\n", first, second);
Mock::given(method("POST"))
.and(path(case.completion_path))
.respond_with(
ResponseTemplate::new(200)
.set_body_raw(body, "application/x-ndjson"),
)
.mount_as_scoped(server)
.await
}
other => panic!("Unsupported provider in test: {:?}", other),
}
}
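/// The provider matrix exercised by the integration test.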
fn provider_cases() -> Vec<ProviderCase> {
vec![
ProviderCase {
name: "openai",
provider: LlmProvider::OpenAI,
model: "gpt-4o-mini",
env_var: "OPENAI_BASE_URL",
completion_path: "/v1/chat/completions",
},
ProviderCase {
name: "gemini",
provider: LlmProvider::Gemini,
model: "gemini-1.5-flash",
env_var: "GEMINI_BASE_URL",
completion_path: "/v1beta/models/gemini-1.5-flash:streamGenerateContent",
},
ProviderCase {
name: "anthropic",
provider: LlmProvider::Claude,
model: "claude-3-opus-20240229",
env_var: "ANTHROPIC_BASE_URL",
completion_path: "/v1/messages",
},
ProviderCase {
name: "ollama",
provider: LlmProvider::Ollama,
model: "llama3",
env_var: "OLLAMA_BASE_URL",
completion_path: "/api/chat",
},
]
}
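/// Drives one mocked request per provider and asserts that the outbound request shape,
/// tool execution, hook side effects, and stream events all line up.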
#[tokio::test]
async fn worker_executes_tools_and_hooks_across_mocked_providers() {
let env_lock = ENV_MUTEX.get_or_init(|| Mutex::new(()));
for case in provider_cases() {
let tool_calls = Arc::new(Mutex::new(Vec::<Value>::new()));
let hook_events = Arc::new(Mutex::new(Vec::<String>::new()));
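// Hold the lock while the base-URL override is in effect; env vars are process-global.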
let _env_guard = env_lock.lock().unwrap();
let server = MockServer::start().await;
let _env_override = EnvOverride::set(case.env_var, server.uri());
let expected_args = json!({
"provider": case.name,
"request_id": 1
});
let _mock = setup_mock_response(&case, &server, &expected_args).await;
let mut blueprint = build_blueprint(
case.provider,
case.model,
case.name,
Arc::clone(&tool_calls),
Arc::clone(&hook_events),
);
let mut worker = blueprint.instantiate().expect("worker to instantiate");
let mut stream = worker
.process_task_stream(
"Trigger the sample tool".to_string(),
None,
)
.await;
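// Drain the stream so every tool, hook, and completion event is captured.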
let mut events = Vec::new();
while let Some(event) = stream.next().await {
events.push(event.expect("stream event"));
}
let requests = server
.received_requests()
.await
.expect("to inspect received requests");
assert_eq!(
requests.len(),
1,
"expected exactly one request for provider {}",
case.name
);
let body: Value =
serde_json::from_slice(&requests[0].body).expect("request body to be JSON");
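// Each provider serializes the model, stream flag, and tool schema differently.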
match case.provider {
LlmProvider::OpenAI => {
assert_eq!(body["model"], case.model);
assert_eq!(body["stream"], true);
assert_eq!(
body["tools"][0]["function"]["name"],
SAMPLE_TOOL_NAME
);
}
LlmProvider::Claude => {
assert_eq!(body["model"], case.model);
assert_eq!(body["stream"], true);
assert_eq!(body["tools"][0]["name"], SAMPLE_TOOL_NAME);
}
LlmProvider::Gemini => {
assert_eq!(
body["contents"]
.as_array()
.expect("contents to be array")
.len(),
2,
"system + user messages should be present"
);
let tools = body["tools"][0]["functionDeclarations"]
.as_array()
.expect("function declarations to exist");
assert_eq!(tools[0]["name"], SAMPLE_TOOL_NAME);
}
LlmProvider::Ollama => {
assert_eq!(body["model"], case.model);
assert_eq!(body["stream"], true);
assert_eq!(
body["tools"][0]["function"]["name"],
SAMPLE_TOOL_NAME
);
}
_ => unreachable!(),
}
let recorded_calls = tool_calls.lock().unwrap().clone();
assert_eq!(
recorded_calls.len(),
1,
"tool should execute exactly once for {}",
case.name
);
assert_eq!(
recorded_calls[0], expected_args,
"tool arguments should match for {}",
case.name
);
let recorded_hooks = hook_events.lock().unwrap().clone();
assert!(
recorded_hooks
.iter()
.any(|entry| entry.contains(case.name) && entry.contains(SAMPLE_TOOL_NAME)),
"hook should capture tool usage for {}: {:?}",
case.name,
recorded_hooks
);
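// Confirm each expected stream event type surfaced at least once.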
let mut saw_tool_call = false;
let mut saw_tool_result = false;
let mut saw_hook_message = false;
let mut saw_completion = false;
for event in &events {
match event {
StreamEvent::ToolCall(call) => {
saw_tool_call = true;
assert_eq!(call.name, SAMPLE_TOOL_NAME);
assert_eq!(
serde_json::from_str::<Value>(&call.arguments)
.expect("tool arguments to be JSON"),
expected_args
);
}
StreamEvent::ToolResult { tool_name, result } => {
if tool_name == SAMPLE_TOOL_NAME {
saw_tool_result = true;
let value = result
.as_ref()
.expect("tool execution should succeed");
assert_eq!(value["status"], "ok");
assert_eq!(value["provider"], case.name);
}
}
StreamEvent::HookMessage { hook_name, content, .. } => {
if hook_name == "recording_hook" {
saw_hook_message = true;
assert!(
content.contains(case.name),
"hook content should mention provider"
);
}
}
StreamEvent::Completion(_) => {
saw_completion = true;
}
_ => {}
}
}
assert!(saw_tool_call, "missing tool call event for {}", case.name);
assert!(
saw_tool_result,
"missing tool result event for {}",
case.name
);
assert!(
saw_hook_message,
"missing hook message for {}",
case.name
);
assert!(
saw_completion,
"missing completion event for {}",
case.name
);
// EnvOverride restores or removes the variable when it drops at the end of the iteration.
}
}