diff --git a/llm-worker/src/lib.rs b/llm-worker/src/lib.rs index b26413b..a958bfe 100644 --- a/llm-worker/src/lib.rs +++ b/llm-worker/src/lib.rs @@ -46,6 +46,7 @@ pub mod state; pub mod subscriber; pub mod timeline; pub mod tool; +pub mod tool_server; pub use message::{ContentPart, Item, Message, Role}; pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult}; diff --git a/llm-worker/src/tool_server.rs b/llm-worker/src/tool_server.rs new file mode 100644 index 0000000..8ebad01 --- /dev/null +++ b/llm-worker/src/tool_server.rs @@ -0,0 +1,179 @@ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +use thiserror::Error; + +use crate::llm_client::ToolDefinition as LlmToolDefinition; +use crate::tool::{Tool, ToolDefinition as WorkerToolDefinition, ToolMeta}; + +type ToolMap = HashMap)>; + +/// Errors produced by ToolServer operations. +#[derive(Debug, Error, PartialEq, Eq)] +pub enum ToolServerError { + /// A tool with the same name already exists. + #[error("Tool with name '{0}' already registered")] + DuplicateName(String), + /// Requested tool was not found. + #[error("Tool '{0}' not found")] + ToolNotFound(String), + /// Tool execution failed. + #[error("Tool execution failed: {0}")] + ToolExecution(String), +} + +/// In-memory tool server. +#[derive(Clone, Default)] +pub struct ToolServer { + tools: Arc>, +} + +impl ToolServer { + /// Create a new empty tool server. + pub fn new() -> Self { + Self::default() + } + + /// Create a handle for shared access. + pub fn handle(&self) -> ToolServerHandle { + ToolServerHandle { + tools: Arc::clone(&self.tools), + } + } +} + +/// Shareable handle to a tool server. +#[derive(Clone, Default)] +pub struct ToolServerHandle { + tools: Arc>, +} + +impl ToolServerHandle { + /// Register one tool. + pub fn register_tool(&self, factory: WorkerToolDefinition) -> Result<(), ToolServerError> { + let (meta, instance) = factory(); + let mut guard = self.tools.lock().unwrap_or_else(|e| e.into_inner()); + if guard.contains_key(&meta.name) { + return Err(ToolServerError::DuplicateName(meta.name)); + } + guard.insert(meta.name.clone(), (meta, instance)); + Ok(()) + } + + /// Register many tools. + pub fn register_tools( + &self, + factories: impl IntoIterator, + ) -> Result<(), ToolServerError> { + for factory in factories { + self.register_tool(factory)?; + } + Ok(()) + } + + /// Get a tool by name for hook contexts. + pub fn get_tool(&self, name: &str) -> Option<(ToolMeta, Arc)> { + let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner()); + guard.get(name).map(|(meta, tool)| (meta.clone(), Arc::clone(tool))) + } + + /// Execute a tool by name. + pub async fn call_tool(&self, name: &str, input_json: &str) -> Result { + let tool = { + let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner()); + let (_, tool) = guard + .get(name) + .ok_or_else(|| ToolServerError::ToolNotFound(name.to_string()))?; + Arc::clone(tool) + }; + tool.execute(input_json) + .await + .map_err(|e| ToolServerError::ToolExecution(e.to_string())) + } + + /// Build deterministic tool definitions sorted by tool name. + pub fn tool_definitions_sorted(&self) -> Vec { + let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner()); + let mut defs: Vec<_> = guard + .values() + .map(|(meta, _)| { + LlmToolDefinition::new(&meta.name) + .description(&meta.description) + .input_schema(meta.input_schema.clone()) + }) + .collect(); + defs.sort_by(|a, b| a.name.cmp(&b.name)); + defs + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use async_trait::async_trait; + use serde_json::json; + + use super::*; + use crate::tool::{Tool, ToolDefinition, ToolError, ToolMeta}; + + struct EchoTool; + + #[async_trait] + impl Tool for EchoTool { + async fn execute(&self, input_json: &str) -> Result { + Ok(input_json.to_string()) + } + } + + fn def(name: &'static str) -> ToolDefinition { + Arc::new(move || { + ( + ToolMeta::new(name) + .description(format!("desc-{name}")) + .input_schema(json!({"type":"object"})), + Arc::new(EchoTool) as Arc, + ) + }) + } + + #[test] + fn register_duplicate_name_fails() { + let handle = ToolServer::new().handle(); + handle.register_tool(def("alpha")).expect("first register"); + let err = handle + .register_tool(def("alpha")) + .expect_err("duplicate should fail"); + assert_eq!(err, ToolServerError::DuplicateName("alpha".to_string())); + } + + #[tokio::test] + async fn call_tool_success_and_not_found() { + let handle = ToolServer::new().handle(); + handle.register_tool(def("echo")).expect("register"); + + let out = handle.call_tool("echo", r#"{"x":1}"#).await.expect("call"); + assert_eq!(out, r#"{"x":1}"#); + + let err = handle + .call_tool("missing", "{}") + .await + .expect_err("missing tool"); + assert_eq!(err, ToolServerError::ToolNotFound("missing".to_string())); + } + + #[test] + fn tool_definitions_are_sorted() { + let handle = ToolServer::new().handle(); + handle.register_tool(def("zeta")).expect("register zeta"); + handle.register_tool(def("alpha")).expect("register alpha"); + handle.register_tool(def("beta")).expect("register beta"); + + let names: Vec<_> = handle + .tool_definitions_sorted() + .into_iter() + .map(|d| d.name) + .collect(); + assert_eq!(names, vec!["alpha", "beta", "zeta"]); + } +} diff --git a/llm-worker/src/worker.rs b/llm-worker/src/worker.rs index b3bec79..7b42087 100644 --- a/llm-worker/src/worker.rs +++ b/llm-worker/src/worker.rs @@ -19,8 +19,9 @@ use crate::{ ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter, ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, WorkerSubscriber, }, + tool_server::{ToolServer, ToolServerError, ToolServerHandle}, timeline::{TextBlockCollector, Timeline, ToolCallCollector}, - tool::{Tool, ToolDefinition as WorkerToolDefinition, ToolError, ToolMeta}, + tool::{ToolDefinition as WorkerToolDefinition, ToolError}, }; // ============================================================================= @@ -163,8 +164,8 @@ pub struct Worker { text_block_collector: TextBlockCollector, /// Tool call collector (Timeline handler) tool_call_collector: ToolCallCollector, - /// Registered tools (meta, instance) - tools: HashMap)>, + /// Tool server handle + tool_server: ToolServerHandle, /// Hook registry hooks: HookRegistry, /// System prompt @@ -316,12 +317,13 @@ impl Worker { /// worker.register_tool(def)?; /// ``` pub fn register_tool(&mut self, factory: WorkerToolDefinition) -> Result<(), ToolRegistryError> { - let (meta, instance) = factory(); - if self.tools.contains_key(&meta.name) { - return Err(ToolRegistryError::DuplicateName(meta.name.clone())); + match self.tool_server.register_tool(factory) { + Ok(()) => Ok(()), + Err(ToolServerError::DuplicateName(name)) => Err(ToolRegistryError::DuplicateName(name)), + Err(ToolServerError::ToolNotFound(_) | ToolServerError::ToolExecution(_)) => { + unreachable!("register_tool should only fail with DuplicateName") + } } - self.tools.insert(meta.name.clone(), (meta, instance)); - Ok(()) } /// Register multiple tools @@ -329,10 +331,18 @@ impl Worker { &mut self, factories: impl IntoIterator, ) -> Result<(), ToolRegistryError> { - for factory in factories { - self.register_tool(factory)?; + match self.tool_server.register_tools(factories) { + Ok(()) => Ok(()), + Err(ToolServerError::DuplicateName(name)) => Err(ToolRegistryError::DuplicateName(name)), + Err(ToolServerError::ToolNotFound(_) | ToolServerError::ToolExecution(_)) => { + unreachable!("register_tools should only fail with DuplicateName") + } } - Ok(()) + } + + /// Get a shared tool server handle. + pub fn tool_server_handle(&self) -> ToolServerHandle { + self.tool_server.clone() } /// Add an on_prompt_submit Hook @@ -508,14 +518,7 @@ impl Worker { /// Generate list of ToolDefinitions for LLM from registered tools fn build_tool_definitions(&self) -> Vec { - self.tools - .values() - .map(|(meta, _)| { - ToolDefinition::new(&meta.name) - .description(&meta.description) - .input_schema(meta.input_schema.clone()) - }) - .collect() + self.tool_server.tool_definitions_sorted() } /// Build assistant response items from text blocks and tool calls @@ -715,12 +718,12 @@ impl Worker { let mut approved_calls = Vec::new(); for mut tool_call in tool_calls { // Get tool definition - if let Some((meta, tool)) = self.tools.get(&tool_call.name) { + if let Some((meta, tool)) = self.tool_server.get_tool(&tool_call.name) { // Create context let mut context = ToolCallContext { call: tool_call.clone(), - meta: meta.clone(), - tool: tool.clone(), + meta, + tool, }; let mut skip = false; @@ -753,7 +756,7 @@ impl Worker { if !skip { call_info_map.insert( tool_call.id.clone(), - (tool_call.clone(), meta.clone(), tool.clone()), + (tool_call.clone(), context.meta.clone(), context.tool.clone()), ); approved_calls.push(tool_call); } @@ -768,20 +771,12 @@ impl Worker { let futures: Vec<_> = approved_calls .into_iter() .map(|tool_call| { - let tools = &self.tools; + let tool_server = self.tool_server.clone(); async move { - if let Some((_, tool)) = tools.get(&tool_call.name) { - let input_json = - serde_json::to_string(&tool_call.input).unwrap_or_default(); - match tool.execute(&input_json).await { - Ok(content) => ToolResult::success(&tool_call.id, content), - Err(e) => ToolResult::error(&tool_call.id, e.to_string()), - } - } else { - ToolResult::error( - &tool_call.id, - format!("Tool '{}' not found", tool_call.name), - ) + let input_json = serde_json::to_string(&tool_call.input).unwrap_or_default(); + match tool_server.call_tool(&tool_call.name, &input_json).await { + Ok(content) => ToolResult::success(&tool_call.id, content), + Err(e) => ToolResult::error(&tool_call.id, e.to_string()), } } }) @@ -1049,7 +1044,7 @@ impl Worker { timeline, text_block_collector, tool_call_collector, - tools: HashMap::new(), + tool_server: ToolServer::new().handle(), hooks: HookRegistry::new(), system_prompt: None, history: Vec::new(), @@ -1220,7 +1215,7 @@ impl Worker { timeline: self.timeline, text_block_collector: self.text_block_collector, tool_call_collector: self.tool_call_collector, - tools: self.tools, + tool_server: self.tool_server, hooks: self.hooks, system_prompt: self.system_prompt, history: self.history, @@ -1256,7 +1251,7 @@ impl Worker { timeline: self.timeline, text_block_collector: self.text_block_collector, tool_call_collector: self.tool_call_collector, - tools: self.tools, + tool_server: self.tool_server, hooks: self.hooks, system_prompt: self.system_prompt, history: self.history,