diff --git a/Cargo.lock b/Cargo.lock index 908b007..523b78a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,16 @@ version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "async-stream" version = "0.3.6" @@ -59,6 +69,12 @@ dependencies = [ "syn", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "autocfg" version = "1.5.0" @@ -71,6 +87,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "1.3.2" @@ -141,6 +163,24 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "deadpool" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b" +dependencies = [ + "deadpool-runtime", + "lazy_static", + "num_cpus", + "tokio", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" + [[package]] name = "dirs" version = "6.0.0" @@ -365,7 +405,26 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.12", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "h2" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http 1.3.1", "indexmap", "slab", "tokio", @@ -385,6 +444,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "http" version = "0.2.12" @@ -396,6 +461,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.6" @@ -403,7 +479,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", - "http", + "http 0.2.12", + "pin-project-lite", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http 1.3.1", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http 1.3.1", + "http-body 1.0.1", "pin-project-lite", ] @@ -429,9 +528,9 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", - "http", - "http-body", + "h2 0.3.27", + "http 0.2.12", + "http-body 0.4.6", "httparse", "httpdate", "itoa", @@ -443,6 +542,29 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "h2 0.4.12", + "http 1.3.1", + "http-body 1.0.1", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", + "want", +] + [[package]] name = "hyper-rustls" version = "0.24.2" @@ -450,13 +572,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", - "http", - "hyper", + "http 0.2.12", + "hyper 0.14.32", "rustls", "tokio", "tokio-rustls", ] +[[package]] +name = "hyper-util" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" +dependencies = [ + "bytes", + "futures-core", + "http 1.3.1", + "http-body 1.0.1", + "hyper 1.7.0", + "pin-project-lite", + "tokio", +] + [[package]] name = "iana-time-zone" version = "0.1.64" @@ -736,6 +873,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -897,15 +1044,15 @@ version = "0.11.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" dependencies = [ - "base64", + "base64 0.21.7", "bytes", "encoding_rs", "futures-core", "futures-util", - "h2", - "http", - "http-body", - "hyper", + "h2 0.3.27", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.32", "hyper-rustls", "ipnet", "js-sys", @@ -979,7 +1126,7 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ - "base64", + "base64 0.21.7", ] [[package]] @@ -1957,6 +2104,29 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "wiremock" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08db1edfb05d9b3c1542e521aea074442088292f00b5f28e435c714a98f85031" +dependencies = [ + "assert-json-diff", + "base64 0.22.1", + "deadpool", + "futures", + "http 1.3.1", + "http-body-util", + "hyper 1.7.0", + "hyper-util", + "log", + "once_cell", + "regex", + "serde", + "serde_json", + "tokio", + "url", +] + [[package]] name = "wit-bindgen" version = "0.46.0" @@ -1994,6 +2164,7 @@ dependencies = [ "tracing", "tracing-subscriber", "uuid", + "wiremock", "worker-macros", "worker-types", "xdg", diff --git a/flake.lock b/flake.lock index 3cebd2b..e8ba29b 100644 --- a/flake.lock +++ b/flake.lock @@ -2,11 +2,11 @@ "nodes": { "flake-compat": { "locked": { - "lastModified": 1747046372, - "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", "owner": "edolstra", "repo": "flake-compat", - "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", "type": "github" }, "original": { @@ -35,11 +35,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1751011381, - "narHash": "sha256-krGXKxvkBhnrSC/kGBmg5MyupUUT5R6IBCLEzx9jhMM=", + "lastModified": 1761373498, + "narHash": "sha256-Q/uhWNvd7V7k1H1ZPMy/vkx3F8C13ZcdrKjO7Jv7v0c=", "owner": "nixos", "repo": "nixpkgs", - "rev": "30e2e2857ba47844aa71991daa6ed1fc678bcbb7", + "rev": "6a08e6bb4e46ff7fcbb53d409b253f6bad8a28ce", "type": "github" }, "original": { diff --git a/worker/Cargo.toml b/worker/Cargo.toml index 17ad650..3c4e4d3 100644 --- a/worker/Cargo.toml +++ b/worker/Cargo.toml @@ -45,3 +45,4 @@ dynamic-loading = ["libloading"] [dev-dependencies] tempfile = "3.10.1" tracing-subscriber = "0.3" +wiremock = "0.6" diff --git a/worker/src/mcp/config.rs b/worker/src/mcp/config.rs index a81892f..b408a64 100644 --- a/worker/src/mcp/config.rs +++ b/worker/src/mcp/config.rs @@ -126,131 +126,3 @@ fn expand_environment_variables(input: &str) -> Result { Ok(result) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_default_config_creation() { - let mut config = McpConfig::default(); - config.servers.insert( - "brave_search".to_string(), - McpServerDefinition { - command: "npx".to_string(), - args: vec![ - "-y".to_string(), - "@brave/brave-search-mcp-server".to_string(), - ], - env: HashMap::new(), - description: None, - enabled: true, - integration_mode: IntegrationMode::Individual, - }, - ); - - assert!(!config.servers.is_empty()); - assert!(config.servers.contains_key("brave_search")); - } - - #[test] - fn test_config_serialization() { - let mut servers = HashMap::new(); - servers.insert( - "filesystem".to_string(), - McpServerDefinition { - command: "npx".to_string(), - args: vec![ - "-y".to_string(), - "@modelcontextprotocol/server-filesystem".to_string(), - ], - env: HashMap::new(), - description: Some("Filesystem operations".to_string()), - enabled: true, - integration_mode: IntegrationMode::Proxy, - }, - ); - - let config = McpConfig { servers }; - let yaml = serde_yaml::to_string(&config).unwrap(); - - // YAML形式で正しくシリアライズされることを確認 - assert!(yaml.contains("servers:")); - assert!(yaml.contains("filesystem:")); - assert!(yaml.contains("command:")); - } - - #[test] - fn test_config_deserialization() { - let yaml_content = r#" -servers: - test_server: - command: "python3" - args: ["test.py"] - env: - TEST_VAR: "test_value" - description: "Test server" - enabled: true - integration_mode: "individual" -"#; - - let config: McpConfig = serde_yaml::from_str(yaml_content).unwrap(); - assert_eq!(config.servers.len(), 1); - - let server = config.servers.get("test_server").unwrap(); - assert_eq!(server.command, "python3"); - assert_eq!(server.args, vec!["test.py"]); - assert_eq!(server.env.get("TEST_VAR").unwrap(), "test_value"); - assert!(server.enabled); - } - - #[test] - fn test_environment_variable_expansion() { - // SAFETY: Setting test environment variables in a single-threaded test context - unsafe { - std::env::set_var("TEST_VAR", "test_value"); - } - - let result = expand_environment_variables("prefix_${TEST_VAR}_suffix").unwrap(); - assert_eq!(result, "prefix_test_value_suffix"); - - // 存在しない環境変数の場合はそのまま残る - let result = expand_environment_variables("${NON_EXISTENT_VAR}").unwrap(); - assert_eq!(result, "${NON_EXISTENT_VAR}"); - } - - #[test] - fn test_enabled_servers_filter() { - let mut config = McpConfig::default(); - - // 有効なサーバーを追加 - config.servers.insert( - "enabled_server".to_string(), - McpServerDefinition { - command: "test".to_string(), - args: vec![], - env: HashMap::new(), - description: None, - enabled: true, - integration_mode: IntegrationMode::Individual, - }, - ); - - // 無効なサーバーを追加 - config.servers.insert( - "disabled_server".to_string(), - McpServerDefinition { - command: "test".to_string(), - args: vec![], - env: HashMap::new(), - description: None, - enabled: false, - integration_mode: IntegrationMode::Individual, - }, - ); - - let enabled_servers = config.get_enabled_servers(); - assert_eq!(enabled_servers.len(), 1); - assert_eq!(enabled_servers[0].0, "enabled_server"); - } -} diff --git a/worker/src/mcp/protocol.rs b/worker/src/mcp/protocol.rs index e5271ad..6f64488 100644 --- a/worker/src/mcp/protocol.rs +++ b/worker/src/mcp/protocol.rs @@ -417,33 +417,3 @@ impl Drop for McpClient { } } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_json_rpc_serialization() { - let request = JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: Value::Number(serde_json::Number::from(1)), - method: "tools/list".to_string(), - params: None, - }; - - let json = serde_json::to_string(&request).unwrap(); - assert!(json.contains("\"jsonrpc\":\"2.0\"")); - assert!(json.contains("\"id\":1")); - assert!(json.contains("\"method\":\"tools/list\"")); - } - - #[test] - fn test_error_response() { - let response_json = - r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"Method not found"}}"#; - let response: JsonRpcResponse = serde_json::from_str(response_json).unwrap(); - - assert!(response.error.is_some()); - assert_eq!(response.error.unwrap().code, -32601); - } -} diff --git a/worker/src/mcp/tool.rs b/worker/src/mcp/tool.rs index 62d25d3..5aaf7e4 100644 --- a/worker/src/mcp/tool.rs +++ b/worker/src/mcp/tool.rs @@ -451,36 +451,3 @@ pub async fn test_mcp_connection( } } } - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_mcp_server_config() { - let config = - McpServerConfig::new("npx", vec!["-y", "@modelcontextprotocol/server-everything"]); - assert_eq!(config.command, "npx"); - assert_eq!( - config.args, - vec!["-y", "@modelcontextprotocol/server-everything"] - ); - assert_eq!( - config.name, - "npx(-y @modelcontextprotocol/server-everything)" - ); - } - - #[tokio::test] - async fn test_mcp_tool_creation() { - let config = McpServerConfig::new("echo", vec!["test"]); - let tool = McpDynamicTool::new(config); - - assert_eq!(tool.name(), "mcp_proxy"); - assert!(!tool.description().is_empty()); - - let schema = tool.parameters_schema(); - assert!(schema.is_object()); - assert!(schema.get("properties").is_some()); - } -} diff --git a/worker/src/plugin/example_provider.rs b/worker/src/plugin/example_provider.rs index 296e601..cb5df32 100644 --- a/worker/src/plugin/example_provider.rs +++ b/worker/src/plugin/example_provider.rs @@ -206,42 +206,3 @@ impl LlmClientTrait for CustomLlmClient { pub extern "C" fn create_plugin() -> Box { Box::new(CustomProviderPlugin::new()) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_plugin_metadata() { - let plugin = CustomProviderPlugin::new(); - let metadata = plugin.metadata(); - - assert_eq!(metadata.id, "custom-provider"); - assert_eq!(metadata.name, "Custom LLM Provider"); - assert!(metadata.requires_api_key); - assert_eq!(metadata.supported_models.len(), 3); - } - - #[test] - fn test_api_key_validation() { - let plugin = CustomProviderPlugin::new(); - - assert!(plugin.validate_api_key("custom-1234567890abcdefghij")); - assert!(!plugin.validate_api_key("invalid-key")); - assert!(!plugin.validate_api_key("custom-short")); - assert!(!plugin.validate_api_key("")); - } - - #[tokio::test] - async fn test_plugin_initialization() { - let mut plugin = CustomProviderPlugin::new(); - let mut config = HashMap::new(); - config.insert( - "base_url".to_string(), - Value::String("https://api.example.com".to_string()), - ); - - let result = plugin.initialize(config).await; - assert!(result.is_ok()); - } -} diff --git a/worker/src/tests/mock_llm_integration.rs b/worker/src/tests/mock_llm_integration.rs new file mode 100644 index 0000000..6554737 --- /dev/null +++ b/worker/src/tests/mock_llm_integration.rs @@ -0,0 +1,496 @@ +use std::sync::{Arc, Mutex, OnceLock}; + +use async_trait::async_trait; +use futures::StreamExt; +use serde_json::{json, Value}; +use wiremock::matchers::{method, path}; +use wiremock::{Mock, MockGuard, MockServer, ResponseTemplate}; + +use worker::{ + HookContext, HookResult, LlmProvider, PromptError, StreamEvent, Tool, ToolResult, Worker, + WorkerBlueprint, WorkerHook, +}; +use worker_types::Role; + +const SAMPLE_TOOL_NAME: &str = "sample_tool"; +static ENV_MUTEX: OnceLock> = OnceLock::new(); + +struct ProviderCase { + name: &'static str, + provider: LlmProvider, + model: &'static str, + env_var: &'static str, + completion_path: &'static str, +} + +struct SampleTool { + name: String, + description: String, + calls: Arc>>, + response: Value, +} + +impl SampleTool { + fn new(provider_label: &str, calls: Arc>>) -> Self { + Self { + name: SAMPLE_TOOL_NAME.to_string(), + description: format!("Records invocations for {}", provider_label), + calls, + response: json!({ + "status": "ok", + "provider": provider_label, + }), + } + } +} + +#[async_trait] +impl Tool for SampleTool { + fn name(&self) -> &str { + &self.name + } + + fn description(&self) -> &str { + &self.description + } + + fn parameters_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "provider": {"type": "string"}, + "request_id": {"type": "integer"} + }, + "required": ["provider", "request_id"] + }) + } + + async fn execute(&self, args: Value) -> ToolResult { + self.calls.lock().unwrap().push(args.clone()); + Ok(self.response.clone()) + } +} + +struct RecordingHook { + tool_name: String, + provider_label: String, + events: Arc>>, +} + +impl RecordingHook { + fn new(provider_label: &str, events: Arc>>) -> Self { + Self { + tool_name: SAMPLE_TOOL_NAME.to_string(), + provider_label: provider_label.to_string(), + events, + } + } +} + +#[async_trait] +impl WorkerHook for RecordingHook { + fn name(&self) -> &str { + "recording_hook" + } + + fn hook_type(&self) -> &str { + "PostToolUse" + } + + fn matcher(&self) -> &str { + &self.tool_name + } + + async fn execute(&self, context: HookContext) -> (HookContext, HookResult) { + let tool = context + .get_variable("current_tool") + .cloned() + .unwrap_or_else(|| self.tool_name.clone()); + + let entry = format!( + "{}::{}::{}", + self.provider_label, tool, context.content + ); + self.events.lock().unwrap().push(entry); + + let message = format!("{} hook observed {}", self.provider_label, tool); + ( + context, + HookResult::AddMessage(message, Role::Assistant), + ) + } +} + +struct EnvOverride { + key: String, + previous: Option, +} + +impl EnvOverride { + fn set(key: &str, value: String) -> Self { + let previous = std::env::var(key).ok(); + std::env::set_var(key, &value); + Self { + key: key.to_string(), + previous, + } + } +} + +impl Drop for EnvOverride { + fn drop(&mut self) { + if let Some(prev) = self.previous.take() { + std::env::set_var(&self.key, prev); + } else { + std::env::remove_var(&self.key); + } + } +} + +fn build_blueprint( + provider: LlmProvider, + model: &str, + provider_label: &str, + tool_calls: Arc>>, + hook_events: Arc>>, +) -> WorkerBlueprint { + let mut blueprint = Worker::blueprint(); + blueprint + .provider(provider) + .model(model) + .api_key(provider.as_str(), "test-key") + .system_prompt_fn(|_, _| Ok::<_, PromptError>("Integration test system prompt.".into())) + .add_tool(SampleTool::new(provider_label, Arc::clone(&tool_calls))) + .attach_hook(RecordingHook::new(provider_label, Arc::clone(&hook_events))); + blueprint +} + +async fn setup_mock_response( + case: &ProviderCase, + server: &MockServer, + expected_args: &Value, +) -> MockGuard { + match case.provider { + LlmProvider::OpenAI => { + let arguments = expected_args.to_string(); + let event_body = json!({ + "choices": [{ + "delta": { + "tool_calls": [{ + "function": { + "name": SAMPLE_TOOL_NAME, + "arguments": arguments + } + }] + } + }] + }); + let sse = format!( + "data: {}\n\ndata: [DONE]\n\n", + event_body + ); + Mock::given(method("POST")) + .and(path(case.completion_path)) + .respond_with( + ResponseTemplate::new(200) + .set_body_raw(sse, "text/event-stream"), + ) + .mount(server) + .await + } + LlmProvider::Claude => { + let event_tool = json!({ + "type": "content_block_start", + "data": { + "content_block": { + "type": "tool_use", + "name": SAMPLE_TOOL_NAME, + "input": expected_args + } + } + }); + let event_stop = json!({ + "type": "message_stop", + "data": {} + }); + let sse = format!( + "data: {}\n\ndata: {}\n\ndata: [DONE]\n\n", + event_tool, event_stop + ); + Mock::given(method("POST")) + .and(path(case.completion_path)) + .respond_with( + ResponseTemplate::new(200) + .set_body_raw(sse, "text/event-stream"), + ) + .mount(server) + .await + } + LlmProvider::Gemini => { + let body = json!({ + "candidates": [{ + "content": { + "role": "model", + "parts": [{ + "functionCall": { + "name": SAMPLE_TOOL_NAME, + "args": expected_args + } + }] + }, + "finishReason": "STOP" + }] + }); + Mock::given(method("POST")) + .and(path(case.completion_path)) + .respond_with(ResponseTemplate::new(200).set_body_json(body)) + .mount(server) + .await + } + LlmProvider::Ollama => { + let first = json!({ + "message": { + "role": "assistant", + "content": "", + "tool_calls": [{ + "function": { + "name": SAMPLE_TOOL_NAME, + "arguments": expected_args + } + }] + }, + "done": false + }); + let second = json!({ + "message": { + "role": "assistant", + "content": "finished" + }, + "done": true + }); + let body = format!("{}\n{}\n", first, second); + Mock::given(method("POST")) + .and(path(case.completion_path)) + .respond_with( + ResponseTemplate::new(200) + .set_body_raw(body, "application/x-ndjson"), + ) + .mount(server) + .await + } + other => panic!("Unsupported provider in test: {:?}", other), + } +} + +fn provider_cases() -> Vec { + vec![ + ProviderCase { + name: "openai", + provider: LlmProvider::OpenAI, + model: "gpt-4o-mini", + env_var: "OPENAI_BASE_URL", + completion_path: "/v1/chat/completions", + }, + ProviderCase { + name: "gemini", + provider: LlmProvider::Gemini, + model: "gemini-1.5-flash", + env_var: "GEMINI_BASE_URL", + completion_path: "/v1beta/models/gemini-1.5-flash:streamGenerateContent", + }, + ProviderCase { + name: "anthropic", + provider: LlmProvider::Claude, + model: "claude-3-opus-20240229", + env_var: "ANTHROPIC_BASE_URL", + completion_path: "/v1/messages", + }, + ProviderCase { + name: "ollama", + provider: LlmProvider::Ollama, + model: "llama3", + env_var: "OLLAMA_BASE_URL", + completion_path: "/api/chat", + }, + ] +} + +#[tokio::test] +async fn worker_executes_tools_and_hooks_across_mocked_providers() { + let env_lock = ENV_MUTEX.get_or_init(|| Mutex::new(())); + + for case in provider_cases() { + let tool_calls = Arc::new(Mutex::new(Vec::::new())); + let hook_events = Arc::new(Mutex::new(Vec::::new())); + + let _env_guard = env_lock.lock().unwrap(); + + let server = MockServer::start().await; + let _env_override = EnvOverride::set(case.env_var, server.uri()); + + let expected_args = json!({ + "provider": case.name, + "request_id": 1 + }); + + let _mock = setup_mock_response(&case, &server, &expected_args).await; + + let mut blueprint = build_blueprint( + case.provider, + case.model, + case.name, + Arc::clone(&tool_calls), + Arc::clone(&hook_events), + ); + + let mut worker = blueprint.instantiate().expect("worker to instantiate"); + + let mut stream = worker + .process_task_stream( + "Trigger the sample tool".to_string(), + None, + ) + .await; + + let mut events = Vec::new(); + while let Some(event) = stream.next().await { + events.push(event.expect("stream event")); + } + + let requests = server + .received_requests() + .await + .expect("to inspect received requests"); + assert_eq!( + requests.len(), + 1, + "expected exactly one request for provider {}", + case.name + ); + let body: Value = + serde_json::from_slice(&requests[0].body).expect("request body to be JSON"); + + match case.provider { + LlmProvider::OpenAI => { + assert_eq!(body["model"], case.model); + assert_eq!(body["stream"], true); + assert_eq!( + body["tools"][0]["function"]["name"], + SAMPLE_TOOL_NAME + ); + } + LlmProvider::Claude => { + assert_eq!(body["model"], case.model); + assert_eq!(body["stream"], true); + assert_eq!(body["tools"][0]["name"], SAMPLE_TOOL_NAME); + } + LlmProvider::Gemini => { + assert_eq!( + body["contents"] + .as_array() + .expect("contents to be array") + .len(), + 2, + "system + user messages should be present" + ); + let tools = body["tools"][0]["functionDeclarations"] + .as_array() + .expect("function declarations to exist"); + assert_eq!(tools[0]["name"], SAMPLE_TOOL_NAME); + } + LlmProvider::Ollama => { + assert_eq!(body["model"], case.model); + assert_eq!(body["stream"], true); + assert_eq!( + body["tools"][0]["function"]["name"], + SAMPLE_TOOL_NAME + ); + } + _ => unreachable!(), + } + + let recorded_calls = tool_calls.lock().unwrap().clone(); + assert_eq!( + recorded_calls.len(), + 1, + "tool should execute exactly once for {}", + case.name + ); + assert_eq!( + recorded_calls[0], expected_args, + "tool arguments should match for {}", + case.name + ); + + let recorded_hooks = hook_events.lock().unwrap().clone(); + assert!( + recorded_hooks + .iter() + .any(|entry| entry.contains(case.name) && entry.contains(SAMPLE_TOOL_NAME)), + "hook should capture tool usage for {}: {:?}", + case.name, + recorded_hooks + ); + + let mut saw_tool_call = false; + let mut saw_tool_result = false; + let mut saw_hook_message = false; + let mut saw_completion = false; + + for event in &events { + match event { + StreamEvent::ToolCall(call) => { + saw_tool_call = true; + assert_eq!(call.name, SAMPLE_TOOL_NAME); + assert_eq!( + serde_json::from_str::(&call.arguments) + .expect("tool arguments to be JSON"), + expected_args + ); + } + StreamEvent::ToolResult { tool_name, result } => { + if tool_name == SAMPLE_TOOL_NAME { + saw_tool_result = true; + let value = result + .as_ref() + .expect("tool execution should succeed"); + assert_eq!(value["status"], "ok"); + assert_eq!(value["provider"], case.name); + } + } + StreamEvent::HookMessage { hook_name, content, .. } => { + if hook_name == "recording_hook" { + saw_hook_message = true; + assert!( + content.contains(case.name), + "hook content should mention provider" + ); + } + } + StreamEvent::Completion(_) => { + saw_completion = true; + } + _ => {} + } + } + + assert!(saw_tool_call, "missing tool call event for {}", case.name); + assert!( + saw_tool_result, + "missing tool result event for {}", + case.name + ); + assert!( + saw_hook_message, + "missing hook message for {}", + case.name + ); + assert!( + saw_completion, + "missing completion event for {}", + case.name + ); + + std::env::remove_var(case.env_var); + } +}