//! Integration tests exercising Worker tool execution and hook dispatch
//! against mocked LLM provider HTTP endpoints (OpenAI, Gemini, Anthropic,
//! Ollama) served by wiremock.
use std::sync::{Arc, Mutex, OnceLock};
|
|
|
|
use async_trait::async_trait;
|
|
use futures::StreamExt;
|
|
use serde_json::{json, Value};
|
|
use wiremock::matchers::{method, path};
|
|
use wiremock::{Mock, MockGuard, MockServer, ResponseTemplate};
|
|
|
|
use worker::{
|
|
HookContext, HookResult, LlmProvider, PromptError, StreamEvent, Tool, ToolResult, Worker,
|
|
WorkerBlueprint, WorkerHook,
|
|
};
|
|
use worker_types::Role;
|
|
|
|
/// Name under which the recording tool is registered on every blueprint and
/// referenced by the mock provider responses.
const SAMPLE_TOOL_NAME: &str = "sample_tool";

/// Serializes base-URL env-var mutation across test bodies; `std::env::set_var`
/// is process-global, so overlapping tests must not interleave.
static ENV_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();
|
|
|
|
/// One mocked-provider scenario driven by the integration test.
struct ProviderCase {
    /// Short label used in assertions, tool args, and hook log entries.
    name: &'static str,
    /// Provider variant under test.
    provider: LlmProvider,
    /// Model identifier expected to appear in the outgoing request body.
    model: &'static str,
    /// Env var that redirects the provider client to the mock server.
    env_var: &'static str,
    /// HTTP path the completion request is expected to hit on the mock.
    completion_path: &'static str,
}
|
|
|
|
/// Test tool that records every invocation and answers with a canned payload.
struct SampleTool {
    /// Tool name reported to the provider (always `SAMPLE_TOOL_NAME`).
    name: String,
    /// Human-readable description embedded in the tool declaration.
    description: String,
    /// Shared log of raw argument values, one entry per `execute` call.
    calls: Arc<Mutex<Vec<Value>>>,
    /// Fixed JSON value returned from every `execute` call.
    response: Value,
}
|
|
|
|
impl SampleTool {
|
|
fn new(provider_label: &str, calls: Arc<Mutex<Vec<Value>>>) -> Self {
|
|
Self {
|
|
name: SAMPLE_TOOL_NAME.to_string(),
|
|
description: format!("Records invocations for {}", provider_label),
|
|
calls,
|
|
response: json!({
|
|
"status": "ok",
|
|
"provider": provider_label,
|
|
}),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
impl Tool for SampleTool {
    /// Tool name as registered with the worker.
    fn name(&self) -> &str {
        &self.name
    }

    /// Description set in `SampleTool::new`.
    fn description(&self) -> &str {
        &self.description
    }

    /// JSON schema advertised to the model: an object requiring a string
    /// `provider` and an integer `request_id` — matching the `expected_args`
    /// the mocks instruct the model to send back.
    fn parameters_schema(&self) -> Value {
        json!({
            "type": "object",
            "properties": {
                "provider": {"type": "string"},
                "request_id": {"type": "integer"}
            },
            "required": ["provider", "request_id"]
        })
    }

    /// Records the raw arguments into the shared call log, then returns the
    /// canned success payload.
    async fn execute(&self, args: Value) -> ToolResult<Value> {
        self.calls.lock().unwrap().push(args.clone());
        Ok(self.response.clone())
    }
}
|
|
|
|
/// PostToolUse hook that logs tool observations into a shared event list.
struct RecordingHook {
    /// Tool name this hook matches on (always `SAMPLE_TOOL_NAME`).
    tool_name: String,
    /// Provider label prefixed onto every recorded event.
    provider_label: String,
    /// Shared log of `"<provider>::<tool>::<content>"` entries.
    events: Arc<Mutex<Vec<String>>>,
}
|
|
|
|
impl RecordingHook {
|
|
fn new(provider_label: &str, events: Arc<Mutex<Vec<String>>>) -> Self {
|
|
Self {
|
|
tool_name: SAMPLE_TOOL_NAME.to_string(),
|
|
provider_label: provider_label.to_string(),
|
|
events,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl WorkerHook for RecordingHook {
|
|
fn name(&self) -> &str {
|
|
"recording_hook"
|
|
}
|
|
|
|
fn hook_type(&self) -> &str {
|
|
"PostToolUse"
|
|
}
|
|
|
|
fn matcher(&self) -> &str {
|
|
&self.tool_name
|
|
}
|
|
|
|
async fn execute(&self, context: HookContext) -> (HookContext, HookResult) {
|
|
let tool = context
|
|
.get_variable("current_tool")
|
|
.cloned()
|
|
.unwrap_or_else(|| self.tool_name.clone());
|
|
|
|
let entry = format!(
|
|
"{}::{}::{}",
|
|
self.provider_label, tool, context.content
|
|
);
|
|
self.events.lock().unwrap().push(entry);
|
|
|
|
let message = format!("{} hook observed {}", self.provider_label, tool);
|
|
(
|
|
context,
|
|
HookResult::AddMessage(message, Role::Assistant),
|
|
)
|
|
}
|
|
}
|
|
|
|
/// RAII override of a process env var: `set` installs a value, `Drop`
/// restores whatever was there before (or removes the var entirely).
struct EnvOverride {
    /// Env var being overridden.
    key: String,
    /// Value present before the override, if any.
    previous: Option<String>,
}
|
|
|
|
impl EnvOverride {
|
|
fn set(key: &str, value: String) -> Self {
|
|
let previous = std::env::var(key).ok();
|
|
std::env::set_var(key, &value);
|
|
Self {
|
|
key: key.to_string(),
|
|
previous,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Drop for EnvOverride {
|
|
fn drop(&mut self) {
|
|
if let Some(prev) = self.previous.take() {
|
|
std::env::set_var(&self.key, prev);
|
|
} else {
|
|
std::env::remove_var(&self.key);
|
|
}
|
|
}
|
|
}
|
|
|
|
fn build_blueprint(
|
|
provider: LlmProvider,
|
|
model: &str,
|
|
provider_label: &str,
|
|
tool_calls: Arc<Mutex<Vec<Value>>>,
|
|
hook_events: Arc<Mutex<Vec<String>>>,
|
|
) -> WorkerBlueprint {
|
|
let mut blueprint = Worker::blueprint();
|
|
blueprint
|
|
.provider(provider)
|
|
.model(model)
|
|
.api_key(provider.as_str(), "test-key")
|
|
.system_prompt_fn(|_, _| Ok::<_, PromptError>("Integration test system prompt.".into()))
|
|
.add_tool(SampleTool::new(provider_label, Arc::clone(&tool_calls)))
|
|
.attach_hook(RecordingHook::new(provider_label, Arc::clone(&hook_events)));
|
|
blueprint
|
|
}
|
|
|
|
async fn setup_mock_response(
|
|
case: &ProviderCase,
|
|
server: &MockServer,
|
|
expected_args: &Value,
|
|
) -> MockGuard {
|
|
match case.provider {
|
|
LlmProvider::OpenAI => {
|
|
let arguments = expected_args.to_string();
|
|
let event_body = json!({
|
|
"choices": [{
|
|
"delta": {
|
|
"tool_calls": [{
|
|
"function": {
|
|
"name": SAMPLE_TOOL_NAME,
|
|
"arguments": arguments
|
|
}
|
|
}]
|
|
}
|
|
}]
|
|
});
|
|
let sse = format!(
|
|
"data: {}\n\ndata: [DONE]\n\n",
|
|
event_body
|
|
);
|
|
Mock::given(method("POST"))
|
|
.and(path(case.completion_path))
|
|
.respond_with(
|
|
ResponseTemplate::new(200)
|
|
.set_body_raw(sse, "text/event-stream"),
|
|
)
|
|
.mount(server)
|
|
.await
|
|
}
|
|
LlmProvider::Claude => {
|
|
let event_tool = json!({
|
|
"type": "content_block_start",
|
|
"data": {
|
|
"content_block": {
|
|
"type": "tool_use",
|
|
"name": SAMPLE_TOOL_NAME,
|
|
"input": expected_args
|
|
}
|
|
}
|
|
});
|
|
let event_stop = json!({
|
|
"type": "message_stop",
|
|
"data": {}
|
|
});
|
|
let sse = format!(
|
|
"data: {}\n\ndata: {}\n\ndata: [DONE]\n\n",
|
|
event_tool, event_stop
|
|
);
|
|
Mock::given(method("POST"))
|
|
.and(path(case.completion_path))
|
|
.respond_with(
|
|
ResponseTemplate::new(200)
|
|
.set_body_raw(sse, "text/event-stream"),
|
|
)
|
|
.mount(server)
|
|
.await
|
|
}
|
|
LlmProvider::Gemini => {
|
|
let body = json!({
|
|
"candidates": [{
|
|
"content": {
|
|
"role": "model",
|
|
"parts": [{
|
|
"functionCall": {
|
|
"name": SAMPLE_TOOL_NAME,
|
|
"args": expected_args
|
|
}
|
|
}]
|
|
},
|
|
"finishReason": "STOP"
|
|
}]
|
|
});
|
|
Mock::given(method("POST"))
|
|
.and(path(case.completion_path))
|
|
.respond_with(ResponseTemplate::new(200).set_body_json(body))
|
|
.mount(server)
|
|
.await
|
|
}
|
|
LlmProvider::Ollama => {
|
|
let first = json!({
|
|
"message": {
|
|
"role": "assistant",
|
|
"content": "",
|
|
"tool_calls": [{
|
|
"function": {
|
|
"name": SAMPLE_TOOL_NAME,
|
|
"arguments": expected_args
|
|
}
|
|
}]
|
|
},
|
|
"done": false
|
|
});
|
|
let second = json!({
|
|
"message": {
|
|
"role": "assistant",
|
|
"content": "finished"
|
|
},
|
|
"done": true
|
|
});
|
|
let body = format!("{}\n{}\n", first, second);
|
|
Mock::given(method("POST"))
|
|
.and(path(case.completion_path))
|
|
.respond_with(
|
|
ResponseTemplate::new(200)
|
|
.set_body_raw(body, "application/x-ndjson"),
|
|
)
|
|
.mount(server)
|
|
.await
|
|
}
|
|
other => panic!("Unsupported provider in test: {:?}", other),
|
|
}
|
|
}
|
|
|
|
fn provider_cases() -> Vec<ProviderCase> {
|
|
vec![
|
|
ProviderCase {
|
|
name: "openai",
|
|
provider: LlmProvider::OpenAI,
|
|
model: "gpt-4o-mini",
|
|
env_var: "OPENAI_BASE_URL",
|
|
completion_path: "/v1/chat/completions",
|
|
},
|
|
ProviderCase {
|
|
name: "gemini",
|
|
provider: LlmProvider::Gemini,
|
|
model: "gemini-1.5-flash",
|
|
env_var: "GEMINI_BASE_URL",
|
|
completion_path: "/v1beta/models/gemini-1.5-flash:streamGenerateContent",
|
|
},
|
|
ProviderCase {
|
|
name: "anthropic",
|
|
provider: LlmProvider::Claude,
|
|
model: "claude-3-opus-20240229",
|
|
env_var: "ANTHROPIC_BASE_URL",
|
|
completion_path: "/v1/messages",
|
|
},
|
|
ProviderCase {
|
|
name: "ollama",
|
|
provider: LlmProvider::Ollama,
|
|
model: "llama3",
|
|
env_var: "OLLAMA_BASE_URL",
|
|
completion_path: "/api/chat",
|
|
},
|
|
]
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn worker_executes_tools_and_hooks_across_mocked_providers() {
|
|
let env_lock = ENV_MUTEX.get_or_init(|| Mutex::new(()));
|
|
|
|
for case in provider_cases() {
|
|
let tool_calls = Arc::new(Mutex::new(Vec::<Value>::new()));
|
|
let hook_events = Arc::new(Mutex::new(Vec::<String>::new()));
|
|
|
|
let _env_guard = env_lock.lock().unwrap();
|
|
|
|
let server = MockServer::start().await;
|
|
let _env_override = EnvOverride::set(case.env_var, server.uri());
|
|
|
|
let expected_args = json!({
|
|
"provider": case.name,
|
|
"request_id": 1
|
|
});
|
|
|
|
let _mock = setup_mock_response(&case, &server, &expected_args).await;
|
|
|
|
let mut blueprint = build_blueprint(
|
|
case.provider,
|
|
case.model,
|
|
case.name,
|
|
Arc::clone(&tool_calls),
|
|
Arc::clone(&hook_events),
|
|
);
|
|
|
|
let mut worker = blueprint.instantiate().expect("worker to instantiate");
|
|
|
|
let mut stream = worker
|
|
.process_task_stream(
|
|
"Trigger the sample tool".to_string(),
|
|
None,
|
|
)
|
|
.await;
|
|
|
|
let mut events = Vec::new();
|
|
while let Some(event) = stream.next().await {
|
|
events.push(event.expect("stream event"));
|
|
}
|
|
|
|
let requests = server
|
|
.received_requests()
|
|
.await
|
|
.expect("to inspect received requests");
|
|
assert_eq!(
|
|
requests.len(),
|
|
1,
|
|
"expected exactly one request for provider {}",
|
|
case.name
|
|
);
|
|
let body: Value =
|
|
serde_json::from_slice(&requests[0].body).expect("request body to be JSON");
|
|
|
|
match case.provider {
|
|
LlmProvider::OpenAI => {
|
|
assert_eq!(body["model"], case.model);
|
|
assert_eq!(body["stream"], true);
|
|
assert_eq!(
|
|
body["tools"][0]["function"]["name"],
|
|
SAMPLE_TOOL_NAME
|
|
);
|
|
}
|
|
LlmProvider::Claude => {
|
|
assert_eq!(body["model"], case.model);
|
|
assert_eq!(body["stream"], true);
|
|
assert_eq!(body["tools"][0]["name"], SAMPLE_TOOL_NAME);
|
|
}
|
|
LlmProvider::Gemini => {
|
|
assert_eq!(
|
|
body["contents"]
|
|
.as_array()
|
|
.expect("contents to be array")
|
|
.len(),
|
|
2,
|
|
"system + user messages should be present"
|
|
);
|
|
let tools = body["tools"][0]["functionDeclarations"]
|
|
.as_array()
|
|
.expect("function declarations to exist");
|
|
assert_eq!(tools[0]["name"], SAMPLE_TOOL_NAME);
|
|
}
|
|
LlmProvider::Ollama => {
|
|
assert_eq!(body["model"], case.model);
|
|
assert_eq!(body["stream"], true);
|
|
assert_eq!(
|
|
body["tools"][0]["function"]["name"],
|
|
SAMPLE_TOOL_NAME
|
|
);
|
|
}
|
|
_ => unreachable!(),
|
|
}
|
|
|
|
let recorded_calls = tool_calls.lock().unwrap().clone();
|
|
assert_eq!(
|
|
recorded_calls.len(),
|
|
1,
|
|
"tool should execute exactly once for {}",
|
|
case.name
|
|
);
|
|
assert_eq!(
|
|
recorded_calls[0], expected_args,
|
|
"tool arguments should match for {}",
|
|
case.name
|
|
);
|
|
|
|
let recorded_hooks = hook_events.lock().unwrap().clone();
|
|
assert!(
|
|
recorded_hooks
|
|
.iter()
|
|
.any(|entry| entry.contains(case.name) && entry.contains(SAMPLE_TOOL_NAME)),
|
|
"hook should capture tool usage for {}: {:?}",
|
|
case.name,
|
|
recorded_hooks
|
|
);
|
|
|
|
let mut saw_tool_call = false;
|
|
let mut saw_tool_result = false;
|
|
let mut saw_hook_message = false;
|
|
let mut saw_completion = false;
|
|
|
|
for event in &events {
|
|
match event {
|
|
StreamEvent::ToolCall(call) => {
|
|
saw_tool_call = true;
|
|
assert_eq!(call.name, SAMPLE_TOOL_NAME);
|
|
assert_eq!(
|
|
serde_json::from_str::<Value>(&call.arguments)
|
|
.expect("tool arguments to be JSON"),
|
|
expected_args
|
|
);
|
|
}
|
|
StreamEvent::ToolResult { tool_name, result } => {
|
|
if tool_name == SAMPLE_TOOL_NAME {
|
|
saw_tool_result = true;
|
|
let value = result
|
|
.as_ref()
|
|
.expect("tool execution should succeed");
|
|
assert_eq!(value["status"], "ok");
|
|
assert_eq!(value["provider"], case.name);
|
|
}
|
|
}
|
|
StreamEvent::HookMessage { hook_name, content, .. } => {
|
|
if hook_name == "recording_hook" {
|
|
saw_hook_message = true;
|
|
assert!(
|
|
content.contains(case.name),
|
|
"hook content should mention provider"
|
|
);
|
|
}
|
|
}
|
|
StreamEvent::Completion(_) => {
|
|
saw_completion = true;
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
assert!(saw_tool_call, "missing tool call event for {}", case.name);
|
|
assert!(
|
|
saw_tool_result,
|
|
"missing tool result event for {}",
|
|
case.name
|
|
);
|
|
assert!(
|
|
saw_hook_message,
|
|
"missing hook message for {}",
|
|
case.name
|
|
);
|
|
assert!(
|
|
saw_completion,
|
|
"missing completion event for {}",
|
|
case.name
|
|
);
|
|
|
|
std::env::remove_var(case.env_var);
|
|
}
|
|
}
|