0.5.2: テストの更新

This commit is contained in:
Keisuke Hirata 2025-11-01 05:27:47 +09:00
parent 90edd3828b
commit b206acc3d3
8 changed files with 687 additions and 249 deletions

197
Cargo.lock generated
View File

@ -26,6 +26,16 @@ version = "1.0.100"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
[[package]]
name = "assert-json-diff"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12"
dependencies = [
"serde",
"serde_json",
]
[[package]] [[package]]
name = "async-stream" name = "async-stream"
version = "0.3.6" version = "0.3.6"
@ -59,6 +69,12 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "atomic-waker"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]] [[package]]
name = "autocfg" name = "autocfg"
version = "1.5.0" version = "1.5.0"
@ -71,6 +87,12 @@ version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]]
name = "base64"
version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.3.2" version = "1.3.2"
@ -141,6 +163,24 @@ version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
[[package]]
name = "deadpool"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b"
dependencies = [
"deadpool-runtime",
"lazy_static",
"num_cpus",
"tokio",
]
[[package]]
name = "deadpool-runtime"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b"
[[package]] [[package]]
name = "dirs" name = "dirs"
version = "6.0.0" version = "6.0.0"
@ -365,7 +405,26 @@ dependencies = [
"futures-core", "futures-core",
"futures-sink", "futures-sink",
"futures-util", "futures-util",
"http", "http 0.2.12",
"indexmap",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "h2"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386"
dependencies = [
"atomic-waker",
"bytes",
"fnv",
"futures-core",
"futures-sink",
"http 1.3.1",
"indexmap", "indexmap",
"slab", "slab",
"tokio", "tokio",
@ -385,6 +444,12 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
[[package]] [[package]]
name = "http" name = "http"
version = "0.2.12" version = "0.2.12"
@ -396,6 +461,17 @@ dependencies = [
"itoa", "itoa",
] ]
[[package]]
name = "http"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]] [[package]]
name = "http-body" name = "http-body"
version = "0.4.6" version = "0.4.6"
@ -403,7 +479,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2"
dependencies = [ dependencies = [
"bytes", "bytes",
"http", "http 0.2.12",
"pin-project-lite",
]
[[package]]
name = "http-body"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
dependencies = [
"bytes",
"http 1.3.1",
]
[[package]]
name = "http-body-util"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
dependencies = [
"bytes",
"futures-core",
"http 1.3.1",
"http-body 1.0.1",
"pin-project-lite", "pin-project-lite",
] ]
@ -429,9 +528,9 @@ dependencies = [
"futures-channel", "futures-channel",
"futures-core", "futures-core",
"futures-util", "futures-util",
"h2", "h2 0.3.27",
"http", "http 0.2.12",
"http-body", "http-body 0.4.6",
"httparse", "httparse",
"httpdate", "httpdate",
"itoa", "itoa",
@ -443,6 +542,29 @@ dependencies = [
"want", "want",
] ]
[[package]]
name = "hyper"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e"
dependencies = [
"atomic-waker",
"bytes",
"futures-channel",
"futures-core",
"h2 0.4.12",
"http 1.3.1",
"http-body 1.0.1",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"pin-utils",
"smallvec",
"tokio",
"want",
]
[[package]] [[package]]
name = "hyper-rustls" name = "hyper-rustls"
version = "0.24.2" version = "0.24.2"
@ -450,13 +572,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590"
dependencies = [ dependencies = [
"futures-util", "futures-util",
"http", "http 0.2.12",
"hyper", "hyper 0.14.32",
"rustls", "rustls",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
] ]
[[package]]
name = "hyper-util"
version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8"
dependencies = [
"bytes",
"futures-core",
"http 1.3.1",
"http-body 1.0.1",
"hyper 1.7.0",
"pin-project-lite",
"tokio",
]
[[package]] [[package]]
name = "iana-time-zone" name = "iana-time-zone"
version = "0.1.64" version = "0.1.64"
@ -736,6 +873,16 @@ dependencies = [
"autocfg", "autocfg",
] ]
[[package]]
name = "num_cpus"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b"
dependencies = [
"hermit-abi",
"libc",
]
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.21.3" version = "1.21.3"
@ -897,15 +1044,15 @@ version = "0.11.27"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62"
dependencies = [ dependencies = [
"base64", "base64 0.21.7",
"bytes", "bytes",
"encoding_rs", "encoding_rs",
"futures-core", "futures-core",
"futures-util", "futures-util",
"h2", "h2 0.3.27",
"http", "http 0.2.12",
"http-body", "http-body 0.4.6",
"hyper", "hyper 0.14.32",
"hyper-rustls", "hyper-rustls",
"ipnet", "ipnet",
"js-sys", "js-sys",
@ -979,7 +1126,7 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c"
dependencies = [ dependencies = [
"base64", "base64 0.21.7",
] ]
[[package]] [[package]]
@ -1957,6 +2104,29 @@ dependencies = [
"windows-sys 0.48.0", "windows-sys 0.48.0",
] ]
[[package]]
name = "wiremock"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08db1edfb05d9b3c1542e521aea074442088292f00b5f28e435c714a98f85031"
dependencies = [
"assert-json-diff",
"base64 0.22.1",
"deadpool",
"futures",
"http 1.3.1",
"http-body-util",
"hyper 1.7.0",
"hyper-util",
"log",
"once_cell",
"regex",
"serde",
"serde_json",
"tokio",
"url",
]
[[package]] [[package]]
name = "wit-bindgen" name = "wit-bindgen"
version = "0.46.0" version = "0.46.0"
@ -1994,6 +2164,7 @@ dependencies = [
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"uuid", "uuid",
"wiremock",
"worker-macros", "worker-macros",
"worker-types", "worker-types",
"xdg", "xdg",

View File

@ -2,11 +2,11 @@
"nodes": { "nodes": {
"flake-compat": { "flake-compat": {
"locked": { "locked": {
"lastModified": 1747046372, "lastModified": 1761588595,
"narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=",
"owner": "edolstra", "owner": "edolstra",
"repo": "flake-compat", "repo": "flake-compat",
"rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -35,11 +35,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1751011381, "lastModified": 1761373498,
"narHash": "sha256-krGXKxvkBhnrSC/kGBmg5MyupUUT5R6IBCLEzx9jhMM=", "narHash": "sha256-Q/uhWNvd7V7k1H1ZPMy/vkx3F8C13ZcdrKjO7Jv7v0c=",
"owner": "nixos", "owner": "nixos",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "30e2e2857ba47844aa71991daa6ed1fc678bcbb7", "rev": "6a08e6bb4e46ff7fcbb53d409b253f6bad8a28ce",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -45,3 +45,4 @@ dynamic-loading = ["libloading"]
[dev-dependencies] [dev-dependencies]
tempfile = "3.10.1" tempfile = "3.10.1"
tracing-subscriber = "0.3" tracing-subscriber = "0.3"
wiremock = "0.6"

View File

@ -126,131 +126,3 @@ fn expand_environment_variables(input: &str) -> Result<String, WorkerError> {
Ok(result) Ok(result)
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_config_creation() {
let mut config = McpConfig::default();
config.servers.insert(
"brave_search".to_string(),
McpServerDefinition {
command: "npx".to_string(),
args: vec![
"-y".to_string(),
"@brave/brave-search-mcp-server".to_string(),
],
env: HashMap::new(),
description: None,
enabled: true,
integration_mode: IntegrationMode::Individual,
},
);
assert!(!config.servers.is_empty());
assert!(config.servers.contains_key("brave_search"));
}
#[test]
fn test_config_serialization() {
let mut servers = HashMap::new();
servers.insert(
"filesystem".to_string(),
McpServerDefinition {
command: "npx".to_string(),
args: vec![
"-y".to_string(),
"@modelcontextprotocol/server-filesystem".to_string(),
],
env: HashMap::new(),
description: Some("Filesystem operations".to_string()),
enabled: true,
integration_mode: IntegrationMode::Proxy,
},
);
let config = McpConfig { servers };
let yaml = serde_yaml::to_string(&config).unwrap();
// YAML形式で正しくシリアライズされることを確認
assert!(yaml.contains("servers:"));
assert!(yaml.contains("filesystem:"));
assert!(yaml.contains("command:"));
}
#[test]
fn test_config_deserialization() {
let yaml_content = r#"
servers:
test_server:
command: "python3"
args: ["test.py"]
env:
TEST_VAR: "test_value"
description: "Test server"
enabled: true
integration_mode: "individual"
"#;
let config: McpConfig = serde_yaml::from_str(yaml_content).unwrap();
assert_eq!(config.servers.len(), 1);
let server = config.servers.get("test_server").unwrap();
assert_eq!(server.command, "python3");
assert_eq!(server.args, vec!["test.py"]);
assert_eq!(server.env.get("TEST_VAR").unwrap(), "test_value");
assert!(server.enabled);
}
#[test]
fn test_environment_variable_expansion() {
// SAFETY: Setting test environment variables in a single-threaded test context
unsafe {
std::env::set_var("TEST_VAR", "test_value");
}
let result = expand_environment_variables("prefix_${TEST_VAR}_suffix").unwrap();
assert_eq!(result, "prefix_test_value_suffix");
// 存在しない環境変数の場合はそのまま残る
let result = expand_environment_variables("${NON_EXISTENT_VAR}").unwrap();
assert_eq!(result, "${NON_EXISTENT_VAR}");
}
#[test]
fn test_enabled_servers_filter() {
let mut config = McpConfig::default();
// 有効なサーバーを追加
config.servers.insert(
"enabled_server".to_string(),
McpServerDefinition {
command: "test".to_string(),
args: vec![],
env: HashMap::new(),
description: None,
enabled: true,
integration_mode: IntegrationMode::Individual,
},
);
// 無効なサーバーを追加
config.servers.insert(
"disabled_server".to_string(),
McpServerDefinition {
command: "test".to_string(),
args: vec![],
env: HashMap::new(),
description: None,
enabled: false,
integration_mode: IntegrationMode::Individual,
},
);
let enabled_servers = config.get_enabled_servers();
assert_eq!(enabled_servers.len(), 1);
assert_eq!(enabled_servers[0].0, "enabled_server");
}
}

View File

@ -417,33 +417,3 @@ impl Drop for McpClient {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_json_rpc_serialization() {
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Value::Number(serde_json::Number::from(1)),
method: "tools/list".to_string(),
params: None,
};
let json = serde_json::to_string(&request).unwrap();
assert!(json.contains("\"jsonrpc\":\"2.0\""));
assert!(json.contains("\"id\":1"));
assert!(json.contains("\"method\":\"tools/list\""));
}
#[test]
fn test_error_response() {
let response_json =
r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"Method not found"}}"#;
let response: JsonRpcResponse = serde_json::from_str(response_json).unwrap();
assert!(response.error.is_some());
assert_eq!(response.error.unwrap().code, -32601);
}
}

View File

@ -451,36 +451,3 @@ pub async fn test_mcp_connection(
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_mcp_server_config() {
let config =
McpServerConfig::new("npx", vec!["-y", "@modelcontextprotocol/server-everything"]);
assert_eq!(config.command, "npx");
assert_eq!(
config.args,
vec!["-y", "@modelcontextprotocol/server-everything"]
);
assert_eq!(
config.name,
"npx(-y @modelcontextprotocol/server-everything)"
);
}
#[tokio::test]
async fn test_mcp_tool_creation() {
let config = McpServerConfig::new("echo", vec!["test"]);
let tool = McpDynamicTool::new(config);
assert_eq!(tool.name(), "mcp_proxy");
assert!(!tool.description().is_empty());
let schema = tool.parameters_schema();
assert!(schema.is_object());
assert!(schema.get("properties").is_some());
}
}

View File

@ -206,42 +206,3 @@ impl LlmClientTrait for CustomLlmClient {
pub extern "C" fn create_plugin() -> Box<dyn ProviderPlugin> { pub extern "C" fn create_plugin() -> Box<dyn ProviderPlugin> {
Box::new(CustomProviderPlugin::new()) Box::new(CustomProviderPlugin::new())
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_plugin_metadata() {
let plugin = CustomProviderPlugin::new();
let metadata = plugin.metadata();
assert_eq!(metadata.id, "custom-provider");
assert_eq!(metadata.name, "Custom LLM Provider");
assert!(metadata.requires_api_key);
assert_eq!(metadata.supported_models.len(), 3);
}
#[test]
fn test_api_key_validation() {
let plugin = CustomProviderPlugin::new();
assert!(plugin.validate_api_key("custom-1234567890abcdefghij"));
assert!(!plugin.validate_api_key("invalid-key"));
assert!(!plugin.validate_api_key("custom-short"));
assert!(!plugin.validate_api_key(""));
}
#[tokio::test]
async fn test_plugin_initialization() {
let mut plugin = CustomProviderPlugin::new();
let mut config = HashMap::new();
config.insert(
"base_url".to_string(),
Value::String("https://api.example.com".to_string()),
);
let result = plugin.initialize(config).await;
assert!(result.is_ok());
}
}

View File

@ -0,0 +1,496 @@
use std::sync::{Arc, Mutex, OnceLock};
use async_trait::async_trait;
use futures::StreamExt;
use serde_json::{json, Value};
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockGuard, MockServer, ResponseTemplate};
use worker::{
HookContext, HookResult, LlmProvider, PromptError, StreamEvent, Tool, ToolResult, Worker,
WorkerBlueprint, WorkerHook,
};
use worker_types::Role;
const SAMPLE_TOOL_NAME: &str = "sample_tool";
static ENV_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();
struct ProviderCase {
name: &'static str,
provider: LlmProvider,
model: &'static str,
env_var: &'static str,
completion_path: &'static str,
}
struct SampleTool {
name: String,
description: String,
calls: Arc<Mutex<Vec<Value>>>,
response: Value,
}
impl SampleTool {
fn new(provider_label: &str, calls: Arc<Mutex<Vec<Value>>>) -> Self {
Self {
name: SAMPLE_TOOL_NAME.to_string(),
description: format!("Records invocations for {}", provider_label),
calls,
response: json!({
"status": "ok",
"provider": provider_label,
}),
}
}
}
#[async_trait]
impl Tool for SampleTool {
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
&self.description
}
fn parameters_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"provider": {"type": "string"},
"request_id": {"type": "integer"}
},
"required": ["provider", "request_id"]
})
}
async fn execute(&self, args: Value) -> ToolResult<Value> {
self.calls.lock().unwrap().push(args.clone());
Ok(self.response.clone())
}
}
struct RecordingHook {
tool_name: String,
provider_label: String,
events: Arc<Mutex<Vec<String>>>,
}
impl RecordingHook {
fn new(provider_label: &str, events: Arc<Mutex<Vec<String>>>) -> Self {
Self {
tool_name: SAMPLE_TOOL_NAME.to_string(),
provider_label: provider_label.to_string(),
events,
}
}
}
#[async_trait]
impl WorkerHook for RecordingHook {
fn name(&self) -> &str {
"recording_hook"
}
fn hook_type(&self) -> &str {
"PostToolUse"
}
fn matcher(&self) -> &str {
&self.tool_name
}
async fn execute(&self, context: HookContext) -> (HookContext, HookResult) {
let tool = context
.get_variable("current_tool")
.cloned()
.unwrap_or_else(|| self.tool_name.clone());
let entry = format!(
"{}::{}::{}",
self.provider_label, tool, context.content
);
self.events.lock().unwrap().push(entry);
let message = format!("{} hook observed {}", self.provider_label, tool);
(
context,
HookResult::AddMessage(message, Role::Assistant),
)
}
}
struct EnvOverride {
key: String,
previous: Option<String>,
}
impl EnvOverride {
fn set(key: &str, value: String) -> Self {
let previous = std::env::var(key).ok();
std::env::set_var(key, &value);
Self {
key: key.to_string(),
previous,
}
}
}
impl Drop for EnvOverride {
fn drop(&mut self) {
if let Some(prev) = self.previous.take() {
std::env::set_var(&self.key, prev);
} else {
std::env::remove_var(&self.key);
}
}
}
fn build_blueprint(
provider: LlmProvider,
model: &str,
provider_label: &str,
tool_calls: Arc<Mutex<Vec<Value>>>,
hook_events: Arc<Mutex<Vec<String>>>,
) -> WorkerBlueprint {
let mut blueprint = Worker::blueprint();
blueprint
.provider(provider)
.model(model)
.api_key(provider.as_str(), "test-key")
.system_prompt_fn(|_, _| Ok::<_, PromptError>("Integration test system prompt.".into()))
.add_tool(SampleTool::new(provider_label, Arc::clone(&tool_calls)))
.attach_hook(RecordingHook::new(provider_label, Arc::clone(&hook_events)));
blueprint
}
async fn setup_mock_response(
case: &ProviderCase,
server: &MockServer,
expected_args: &Value,
) -> MockGuard {
match case.provider {
LlmProvider::OpenAI => {
let arguments = expected_args.to_string();
let event_body = json!({
"choices": [{
"delta": {
"tool_calls": [{
"function": {
"name": SAMPLE_TOOL_NAME,
"arguments": arguments
}
}]
}
}]
});
let sse = format!(
"data: {}\n\ndata: [DONE]\n\n",
event_body
);
Mock::given(method("POST"))
.and(path(case.completion_path))
.respond_with(
ResponseTemplate::new(200)
.set_body_raw(sse, "text/event-stream"),
)
.mount(server)
.await
}
LlmProvider::Claude => {
let event_tool = json!({
"type": "content_block_start",
"data": {
"content_block": {
"type": "tool_use",
"name": SAMPLE_TOOL_NAME,
"input": expected_args
}
}
});
let event_stop = json!({
"type": "message_stop",
"data": {}
});
let sse = format!(
"data: {}\n\ndata: {}\n\ndata: [DONE]\n\n",
event_tool, event_stop
);
Mock::given(method("POST"))
.and(path(case.completion_path))
.respond_with(
ResponseTemplate::new(200)
.set_body_raw(sse, "text/event-stream"),
)
.mount(server)
.await
}
LlmProvider::Gemini => {
let body = json!({
"candidates": [{
"content": {
"role": "model",
"parts": [{
"functionCall": {
"name": SAMPLE_TOOL_NAME,
"args": expected_args
}
}]
},
"finishReason": "STOP"
}]
});
Mock::given(method("POST"))
.and(path(case.completion_path))
.respond_with(ResponseTemplate::new(200).set_body_json(body))
.mount(server)
.await
}
LlmProvider::Ollama => {
let first = json!({
"message": {
"role": "assistant",
"content": "",
"tool_calls": [{
"function": {
"name": SAMPLE_TOOL_NAME,
"arguments": expected_args
}
}]
},
"done": false
});
let second = json!({
"message": {
"role": "assistant",
"content": "finished"
},
"done": true
});
let body = format!("{}\n{}\n", first, second);
Mock::given(method("POST"))
.and(path(case.completion_path))
.respond_with(
ResponseTemplate::new(200)
.set_body_raw(body, "application/x-ndjson"),
)
.mount(server)
.await
}
other => panic!("Unsupported provider in test: {:?}", other),
}
}
fn provider_cases() -> Vec<ProviderCase> {
vec![
ProviderCase {
name: "openai",
provider: LlmProvider::OpenAI,
model: "gpt-4o-mini",
env_var: "OPENAI_BASE_URL",
completion_path: "/v1/chat/completions",
},
ProviderCase {
name: "gemini",
provider: LlmProvider::Gemini,
model: "gemini-1.5-flash",
env_var: "GEMINI_BASE_URL",
completion_path: "/v1beta/models/gemini-1.5-flash:streamGenerateContent",
},
ProviderCase {
name: "anthropic",
provider: LlmProvider::Claude,
model: "claude-3-opus-20240229",
env_var: "ANTHROPIC_BASE_URL",
completion_path: "/v1/messages",
},
ProviderCase {
name: "ollama",
provider: LlmProvider::Ollama,
model: "llama3",
env_var: "OLLAMA_BASE_URL",
completion_path: "/api/chat",
},
]
}
#[tokio::test]
async fn worker_executes_tools_and_hooks_across_mocked_providers() {
let env_lock = ENV_MUTEX.get_or_init(|| Mutex::new(()));
for case in provider_cases() {
let tool_calls = Arc::new(Mutex::new(Vec::<Value>::new()));
let hook_events = Arc::new(Mutex::new(Vec::<String>::new()));
let _env_guard = env_lock.lock().unwrap();
let server = MockServer::start().await;
let _env_override = EnvOverride::set(case.env_var, server.uri());
let expected_args = json!({
"provider": case.name,
"request_id": 1
});
let _mock = setup_mock_response(&case, &server, &expected_args).await;
let mut blueprint = build_blueprint(
case.provider,
case.model,
case.name,
Arc::clone(&tool_calls),
Arc::clone(&hook_events),
);
let mut worker = blueprint.instantiate().expect("worker to instantiate");
let mut stream = worker
.process_task_stream(
"Trigger the sample tool".to_string(),
None,
)
.await;
let mut events = Vec::new();
while let Some(event) = stream.next().await {
events.push(event.expect("stream event"));
}
let requests = server
.received_requests()
.await
.expect("to inspect received requests");
assert_eq!(
requests.len(),
1,
"expected exactly one request for provider {}",
case.name
);
let body: Value =
serde_json::from_slice(&requests[0].body).expect("request body to be JSON");
match case.provider {
LlmProvider::OpenAI => {
assert_eq!(body["model"], case.model);
assert_eq!(body["stream"], true);
assert_eq!(
body["tools"][0]["function"]["name"],
SAMPLE_TOOL_NAME
);
}
LlmProvider::Claude => {
assert_eq!(body["model"], case.model);
assert_eq!(body["stream"], true);
assert_eq!(body["tools"][0]["name"], SAMPLE_TOOL_NAME);
}
LlmProvider::Gemini => {
assert_eq!(
body["contents"]
.as_array()
.expect("contents to be array")
.len(),
2,
"system + user messages should be present"
);
let tools = body["tools"][0]["functionDeclarations"]
.as_array()
.expect("function declarations to exist");
assert_eq!(tools[0]["name"], SAMPLE_TOOL_NAME);
}
LlmProvider::Ollama => {
assert_eq!(body["model"], case.model);
assert_eq!(body["stream"], true);
assert_eq!(
body["tools"][0]["function"]["name"],
SAMPLE_TOOL_NAME
);
}
_ => unreachable!(),
}
let recorded_calls = tool_calls.lock().unwrap().clone();
assert_eq!(
recorded_calls.len(),
1,
"tool should execute exactly once for {}",
case.name
);
assert_eq!(
recorded_calls[0], expected_args,
"tool arguments should match for {}",
case.name
);
let recorded_hooks = hook_events.lock().unwrap().clone();
assert!(
recorded_hooks
.iter()
.any(|entry| entry.contains(case.name) && entry.contains(SAMPLE_TOOL_NAME)),
"hook should capture tool usage for {}: {:?}",
case.name,
recorded_hooks
);
let mut saw_tool_call = false;
let mut saw_tool_result = false;
let mut saw_hook_message = false;
let mut saw_completion = false;
for event in &events {
match event {
StreamEvent::ToolCall(call) => {
saw_tool_call = true;
assert_eq!(call.name, SAMPLE_TOOL_NAME);
assert_eq!(
serde_json::from_str::<Value>(&call.arguments)
.expect("tool arguments to be JSON"),
expected_args
);
}
StreamEvent::ToolResult { tool_name, result } => {
if tool_name == SAMPLE_TOOL_NAME {
saw_tool_result = true;
let value = result
.as_ref()
.expect("tool execution should succeed");
assert_eq!(value["status"], "ok");
assert_eq!(value["provider"], case.name);
}
}
StreamEvent::HookMessage { hook_name, content, .. } => {
if hook_name == "recording_hook" {
saw_hook_message = true;
assert!(
content.contains(case.name),
"hook content should mention provider"
);
}
}
StreamEvent::Completion(_) => {
saw_completion = true;
}
_ => {}
}
}
assert!(saw_tool_call, "missing tool call event for {}", case.name);
assert!(
saw_tool_result,
"missing tool result event for {}",
case.name
);
assert!(
saw_hook_message,
"missing hook message for {}",
case.name
);
assert!(
saw_completion,
"missing completion event for {}",
case.name
);
std::env::remove_var(case.env_var);
}
}