diff --git a/Cargo.lock b/Cargo.lock index 443dbf8..d641a42 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -384,6 +384,12 @@ dependencies = [ "wasip2", ] +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + [[package]] name = "h2" version = "0.4.13" @@ -726,6 +732,7 @@ dependencies = [ "tokio-util", "tracing", "tracing-subscriber", + "trybuild", ] [[package]] @@ -1178,6 +1185,15 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_spanned" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776" +dependencies = [ + "serde_core", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1264,6 +1280,12 @@ dependencies = [ "syn", ] +[[package]] +name = "target-triple" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "591ef38edfb78ca4771ee32cf494cb8771944bee237a9b91fc9c1424ac4b777b" + [[package]] name = "tempfile" version = "3.24.0" @@ -1277,6 +1299,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + [[package]] name = "thiserror" version = "2.0.17" @@ -1375,6 +1406,45 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "1.0.3+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7614eaf19ad818347db24addfa201729cf2a9b6fdfd9eb0ab870fcacc606c0c" +dependencies = [ + "indexmap", + "serde_core", + "serde_spanned", + "toml_datetime", + "toml_parser", + "toml_writer", + "winnow", +] + +[[package]] +name = "toml_datetime" +version = "1.0.0+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_parser" +version = "1.0.9+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" +dependencies = [ + "winnow", +] + +[[package]] +name = "toml_writer" +version = "1.0.6+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607" + [[package]] name = "tower" version = "0.5.2" @@ -1487,6 +1557,21 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "trybuild" +version = "1.0.116" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47c635f0191bd3a2941013e5062667100969f8c4e9cd787c14f977265d73616e" +dependencies = [ + "glob", + "serde", + "serde_derive", + "serde_json", + "target-triple", + "termcolor", + "toml", +] + [[package]] name = "unicode-ident" version = "1.0.22" @@ -1640,6 +1725,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "windows-link" version = "0.2.1" @@ -1802,6 +1896,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" +[[package]] +name = "winnow" +version = "0.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" + [[package]] name = "wit-bindgen" version = "0.46.0" diff --git a/llm-worker/Cargo.toml b/llm-worker/Cargo.toml index 727ed92..e268175 100644 --- a/llm-worker/Cargo.toml +++ b/llm-worker/Cargo.toml @@ -26,3 +26,4 @@ schemars = "1.2" tempfile = "3.24" dotenv = "0.15" tracing-subscriber = { version = "0.3", features = ["env-filter"] } +trybuild = "1.0.116" diff --git a/llm-worker/tests/compile_fail.rs b/llm-worker/tests/compile_fail.rs new file mode 100644 index 0000000..fa876c2 --- /dev/null +++ b/llm-worker/tests/compile_fail.rs @@ -0,0 +1,6 @@ +#[test] +fn compile_fail_state_constraints() { + let t = trybuild::TestCases::new(); + t.compile_fail("tests/ui/cache_locked_register_tool.rs"); + t.compile_fail("tests/ui/tool_server_handle_register_tool.rs"); +} diff --git a/llm-worker/tests/ui/cache_locked_register_tool.rs b/llm-worker/tests/ui/cache_locked_register_tool.rs new file mode 100644 index 0000000..4f8be01 --- /dev/null +++ b/llm-worker/tests/ui/cache_locked_register_tool.rs @@ -0,0 +1,11 @@ +use llm_worker::Worker; +use llm_worker::llm_client::providers::ollama::OllamaClient; +use std::sync::Arc; + +fn main() { + let client = OllamaClient::new("dummy-model"); + let worker = Worker::new(client); + let mut locked = worker.lock(); + let def: llm_worker::tool::ToolDefinition = Arc::new(|| panic!("unused")); + let _ = locked.register_tool(def); +} diff --git a/llm-worker/tests/ui/cache_locked_register_tool.stderr b/llm-worker/tests/ui/cache_locked_register_tool.stderr new file mode 100644 index 0000000..0c3b097 --- /dev/null +++ b/llm-worker/tests/ui/cache_locked_register_tool.stderr @@ -0,0 +1,8 @@ +error[E0599]: no method named `register_tool` found for struct `Worker` in the current scope + --> tests/ui/cache_locked_register_tool.rs:10:20 + | +10 | let _ = locked.register_tool(def); + | ^^^^^^^^^^^^^ method not found in `Worker` + | + = note: the method was found for + - `Worker` diff --git a/llm-worker/tests/ui/tool_server_handle_register_tool.rs b/llm-worker/tests/ui/tool_server_handle_register_tool.rs new file mode 100644 index 0000000..ff75d0c --- /dev/null +++ b/llm-worker/tests/ui/tool_server_handle_register_tool.rs @@ -0,0 +1,11 @@ +use llm_worker::Worker; +use llm_worker::llm_client::providers::ollama::OllamaClient; +use std::sync::Arc; + +fn main() { + let client = OllamaClient::new("dummy-model"); + let worker = Worker::new(client); + let handle = worker.tool_server_handle(); + let def: llm_worker::tool::ToolDefinition = Arc::new(|| panic!("unused")); + let _ = handle.register_tool(def); +} diff --git a/llm-worker/tests/ui/tool_server_handle_register_tool.stderr b/llm-worker/tests/ui/tool_server_handle_register_tool.stderr new file mode 100644 index 0000000..57d7139 --- /dev/null +++ b/llm-worker/tests/ui/tool_server_handle_register_tool.stderr @@ -0,0 +1,13 @@ +error[E0624]: method `register_tool` is private + --> tests/ui/tool_server_handle_register_tool.rs:10:20 + | +10 | let _ = handle.register_tool(def); + | ^^^^^^^^^^^^^ private method + | + ::: src/tool_server.rs + | + | / pub(crate) fn register_tool( + | | &self, + | | factory: WorkerToolDefinition, + | | ) -> Result<(), ToolServerError> { + | |____________________________________- private method defined here diff --git a/llm-worker/tests/worker_state_test.rs b/llm-worker/tests/worker_state_test.rs index 496f6da..c4f92a0 100644 --- a/llm-worker/tests/worker_state_test.rs +++ b/llm-worker/tests/worker_state_test.rs @@ -5,9 +5,14 @@ mod common; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use async_trait::async_trait; use common::MockLlmClient; use llm_worker::Worker; use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent}; +use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta}; use llm_worker::Item; // ============================================================================= @@ -91,6 +96,56 @@ fn test_mutable_extend_history() { assert_eq!(worker.history().len(), 4); } +#[derive(Clone)] +struct CountingTool { + name: String, + calls: Arc, +} + +impl CountingTool { + fn new(name: impl Into) -> Self { + Self { + name: name.into(), + calls: Arc::new(AtomicUsize::new(0)), + } + } + + fn definition(&self) -> ToolDefinition { + let tool = self.clone(); + Arc::new(move || { + ( + ToolMeta::new(&tool.name) + .description("Counting tool") + .input_schema(serde_json::json!({"type":"object","properties":{}})), + Arc::new(tool.clone()) as Arc, + ) + }) + } + + fn call_count(&self) -> usize { + self.calls.load(Ordering::SeqCst) + } +} + +#[async_trait] +impl Tool for CountingTool { + async fn execute(&self, _input_json: &str) -> Result { + self.calls.fetch_add(1, Ordering::SeqCst); + Ok(format!("{}-ok", self.name)) + } +} + +/// Verify that tools can be registered in Mutable state. +#[test] +fn test_mutable_can_register_tool() { + let client = MockLlmClient::new(vec![]); + let mut worker = Worker::new(client); + let tool = CountingTool::new("count_tool"); + + let result = worker.register_tool(tool.definition()); + assert!(result.is_ok(), "Mutable should allow tool registration"); +} + // ============================================================================= // State Transition Tests // ============================================================================= @@ -330,6 +385,67 @@ async fn test_unlock_edit_relock() { assert_eq!(relocked.locked_prefix_len(), 1); } +/// Verify that tools registered before lock and after unlock remain effective. +#[tokio::test] +async fn test_lock_unlock_relock_tools_remain_effective() { + let client = MockLlmClient::with_responses(vec![ + vec![ + Event::tool_use_start(0, "call_1", "tool_a"), + Event::tool_input_delta(0, r#"{}"#), + Event::tool_use_stop(0), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + vec![ + Event::text_block_start(0), + Event::text_delta(0, "done-a"), + Event::text_block_stop(0, None), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + vec![ + Event::tool_use_start(0, "call_2", "tool_b"), + Event::tool_input_delta(0, r#"{}"#), + Event::tool_use_stop(0), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + vec![ + Event::text_block_start(0), + Event::text_delta(0, "done-b"), + Event::text_block_stop(0, None), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + ]); + + let mut worker = Worker::new(client); + let tool_a = CountingTool::new("tool_a"); + worker + .register_tool(tool_a.definition()) + .expect("register tool_a should succeed"); + + let mut locked = worker.lock(); + locked.run("first").await.expect("first run"); + assert_eq!(tool_a.call_count(), 1, "tool_a should be called once"); + + let mut unlocked = locked.unlock(); + let tool_b = CountingTool::new("tool_b"); + unlocked + .register_tool(tool_b.definition()) + .expect("register tool_b after unlock should succeed"); + + let mut relocked = unlocked.lock(); + relocked.run("second").await.expect("second run"); + + assert_eq!(tool_a.call_count(), 1, "tool_a should not be called again"); + assert_eq!(tool_b.call_count(), 1, "tool_b should be called once"); +} + // ============================================================================= // System Prompt Preservation Tests // =============================================================================