use std::sync::{Arc, Mutex, OnceLock}; use async_trait::async_trait; use futures::StreamExt; use serde_json::{json, Value}; use wiremock::matchers::{method, path}; use wiremock::{Mock, MockGuard, MockServer, ResponseTemplate}; use worker::{ HookContext, HookResult, LlmProvider, PromptError, StreamEvent, Tool, ToolResult, Worker, WorkerBlueprint, WorkerHook, }; use worker_types::Role; const SAMPLE_TOOL_NAME: &str = "sample_tool"; static ENV_MUTEX: OnceLock> = OnceLock::new(); struct ProviderCase { name: &'static str, provider: LlmProvider, model: &'static str, env_var: &'static str, completion_path: &'static str, } struct SampleTool { name: String, description: String, calls: Arc>>, response: Value, } impl SampleTool { fn new(provider_label: &str, calls: Arc>>) -> Self { Self { name: SAMPLE_TOOL_NAME.to_string(), description: format!("Records invocations for {}", provider_label), calls, response: json!({ "status": "ok", "provider": provider_label, }), } } } #[async_trait] impl Tool for SampleTool { fn name(&self) -> &str { &self.name } fn description(&self) -> &str { &self.description } fn parameters_schema(&self) -> Value { json!({ "type": "object", "properties": { "provider": {"type": "string"}, "request_id": {"type": "integer"} }, "required": ["provider", "request_id"] }) } async fn execute(&self, args: Value) -> ToolResult { self.calls.lock().unwrap().push(args.clone()); Ok(self.response.clone()) } } struct RecordingHook { tool_name: String, provider_label: String, events: Arc>>, } impl RecordingHook { fn new(provider_label: &str, events: Arc>>) -> Self { Self { tool_name: SAMPLE_TOOL_NAME.to_string(), provider_label: provider_label.to_string(), events, } } } #[async_trait] impl WorkerHook for RecordingHook { fn name(&self) -> &str { "recording_hook" } fn hook_type(&self) -> &str { "PostToolUse" } fn matcher(&self) -> &str { &self.tool_name } async fn execute(&self, context: HookContext) -> (HookContext, HookResult) { let tool = context .get_variable("current_tool") .cloned() .unwrap_or_else(|| self.tool_name.clone()); let entry = format!( "{}::{}::{}", self.provider_label, tool, context.content ); self.events.lock().unwrap().push(entry); let message = format!("{} hook observed {}", self.provider_label, tool); ( context, HookResult::AddMessage(message, Role::Assistant), ) } } struct EnvOverride { key: String, previous: Option, } impl EnvOverride { fn set(key: &str, value: String) -> Self { let previous = std::env::var(key).ok(); std::env::set_var(key, &value); Self { key: key.to_string(), previous, } } } impl Drop for EnvOverride { fn drop(&mut self) { if let Some(prev) = self.previous.take() { std::env::set_var(&self.key, prev); } else { std::env::remove_var(&self.key); } } } fn build_blueprint( provider: LlmProvider, model: &str, provider_label: &str, tool_calls: Arc>>, hook_events: Arc>>, ) -> WorkerBlueprint { let mut blueprint = Worker::blueprint(); blueprint .provider(provider) .model(model) .api_key(provider.as_str(), "test-key") .system_prompt_fn(|_, _| Ok::<_, PromptError>("Integration test system prompt.".into())) .add_tool(SampleTool::new(provider_label, Arc::clone(&tool_calls))) .attach_hook(RecordingHook::new(provider_label, Arc::clone(&hook_events))); blueprint } async fn setup_mock_response( case: &ProviderCase, server: &MockServer, expected_args: &Value, ) -> MockGuard { match case.provider { LlmProvider::OpenAI => { let arguments = expected_args.to_string(); let event_body = json!({ "choices": [{ "delta": { "tool_calls": [{ "function": { "name": SAMPLE_TOOL_NAME, "arguments": arguments } }] } }] }); let sse = format!( "data: {}\n\ndata: [DONE]\n\n", event_body ); Mock::given(method("POST")) .and(path(case.completion_path)) .respond_with( ResponseTemplate::new(200) .set_body_raw(sse, "text/event-stream"), ) .mount(server) .await } LlmProvider::Claude => { let event_tool = json!({ "type": "content_block_start", "data": { "content_block": { "type": "tool_use", "name": SAMPLE_TOOL_NAME, "input": expected_args } } }); let event_stop = json!({ "type": "message_stop", "data": {} }); let sse = format!( "data: {}\n\ndata: {}\n\ndata: [DONE]\n\n", event_tool, event_stop ); Mock::given(method("POST")) .and(path(case.completion_path)) .respond_with( ResponseTemplate::new(200) .set_body_raw(sse, "text/event-stream"), ) .mount(server) .await } LlmProvider::Gemini => { let body = json!({ "candidates": [{ "content": { "role": "model", "parts": [{ "functionCall": { "name": SAMPLE_TOOL_NAME, "args": expected_args } }] }, "finishReason": "STOP" }] }); Mock::given(method("POST")) .and(path(case.completion_path)) .respond_with(ResponseTemplate::new(200).set_body_json(body)) .mount(server) .await } LlmProvider::Ollama => { let first = json!({ "message": { "role": "assistant", "content": "", "tool_calls": [{ "function": { "name": SAMPLE_TOOL_NAME, "arguments": expected_args } }] }, "done": false }); let second = json!({ "message": { "role": "assistant", "content": "finished" }, "done": true }); let body = format!("{}\n{}\n", first, second); Mock::given(method("POST")) .and(path(case.completion_path)) .respond_with( ResponseTemplate::new(200) .set_body_raw(body, "application/x-ndjson"), ) .mount(server) .await } other => panic!("Unsupported provider in test: {:?}", other), } } fn provider_cases() -> Vec { vec![ ProviderCase { name: "openai", provider: LlmProvider::OpenAI, model: "gpt-4o-mini", env_var: "OPENAI_BASE_URL", completion_path: "/v1/chat/completions", }, ProviderCase { name: "gemini", provider: LlmProvider::Gemini, model: "gemini-1.5-flash", env_var: "GEMINI_BASE_URL", completion_path: "/v1beta/models/gemini-1.5-flash:streamGenerateContent", }, ProviderCase { name: "anthropic", provider: LlmProvider::Claude, model: "claude-3-opus-20240229", env_var: "ANTHROPIC_BASE_URL", completion_path: "/v1/messages", }, ProviderCase { name: "ollama", provider: LlmProvider::Ollama, model: "llama3", env_var: "OLLAMA_BASE_URL", completion_path: "/api/chat", }, ] } #[tokio::test] async fn worker_executes_tools_and_hooks_across_mocked_providers() { let env_lock = ENV_MUTEX.get_or_init(|| Mutex::new(())); for case in provider_cases() { let tool_calls = Arc::new(Mutex::new(Vec::::new())); let hook_events = Arc::new(Mutex::new(Vec::::new())); let _env_guard = env_lock.lock().unwrap(); let server = MockServer::start().await; let _env_override = EnvOverride::set(case.env_var, server.uri()); let expected_args = json!({ "provider": case.name, "request_id": 1 }); let _mock = setup_mock_response(&case, &server, &expected_args).await; let mut blueprint = build_blueprint( case.provider, case.model, case.name, Arc::clone(&tool_calls), Arc::clone(&hook_events), ); let mut worker = blueprint.instantiate().expect("worker to instantiate"); let mut stream = worker .process_task_stream( "Trigger the sample tool".to_string(), None, ) .await; let mut events = Vec::new(); while let Some(event) = stream.next().await { events.push(event.expect("stream event")); } let requests = server .received_requests() .await .expect("to inspect received requests"); assert_eq!( requests.len(), 1, "expected exactly one request for provider {}", case.name ); let body: Value = serde_json::from_slice(&requests[0].body).expect("request body to be JSON"); match case.provider { LlmProvider::OpenAI => { assert_eq!(body["model"], case.model); assert_eq!(body["stream"], true); assert_eq!( body["tools"][0]["function"]["name"], SAMPLE_TOOL_NAME ); } LlmProvider::Claude => { assert_eq!(body["model"], case.model); assert_eq!(body["stream"], true); assert_eq!(body["tools"][0]["name"], SAMPLE_TOOL_NAME); } LlmProvider::Gemini => { assert_eq!( body["contents"] .as_array() .expect("contents to be array") .len(), 2, "system + user messages should be present" ); let tools = body["tools"][0]["functionDeclarations"] .as_array() .expect("function declarations to exist"); assert_eq!(tools[0]["name"], SAMPLE_TOOL_NAME); } LlmProvider::Ollama => { assert_eq!(body["model"], case.model); assert_eq!(body["stream"], true); assert_eq!( body["tools"][0]["function"]["name"], SAMPLE_TOOL_NAME ); } _ => unreachable!(), } let recorded_calls = tool_calls.lock().unwrap().clone(); assert_eq!( recorded_calls.len(), 1, "tool should execute exactly once for {}", case.name ); assert_eq!( recorded_calls[0], expected_args, "tool arguments should match for {}", case.name ); let recorded_hooks = hook_events.lock().unwrap().clone(); assert!( recorded_hooks .iter() .any(|entry| entry.contains(case.name) && entry.contains(SAMPLE_TOOL_NAME)), "hook should capture tool usage for {}: {:?}", case.name, recorded_hooks ); let mut saw_tool_call = false; let mut saw_tool_result = false; let mut saw_hook_message = false; let mut saw_completion = false; for event in &events { match event { StreamEvent::ToolCall(call) => { saw_tool_call = true; assert_eq!(call.name, SAMPLE_TOOL_NAME); assert_eq!( serde_json::from_str::(&call.arguments) .expect("tool arguments to be JSON"), expected_args ); } StreamEvent::ToolResult { tool_name, result } => { if tool_name == SAMPLE_TOOL_NAME { saw_tool_result = true; let value = result .as_ref() .expect("tool execution should succeed"); assert_eq!(value["status"], "ok"); assert_eq!(value["provider"], case.name); } } StreamEvent::HookMessage { hook_name, content, .. } => { if hook_name == "recording_hook" { saw_hook_message = true; assert!( content.contains(case.name), "hook content should mention provider" ); } } StreamEvent::Completion(_) => { saw_completion = true; } _ => {} } } assert!(saw_tool_call, "missing tool call event for {}", case.name); assert!( saw_tool_result, "missing tool result event for {}", case.name ); assert!( saw_hook_message, "missing hook message for {}", case.name ); assert!( saw_completion, "missing completion event for {}", case.name ); std::env::remove_var(case.env_var); } }