refactor worker tool execution through tool server

This commit is contained in:
Keisuke Hirata 2026-02-19 18:19:16 +09:00
parent 3c62970967
commit b12785ed93
3 changed files with 214 additions and 39 deletions

View File

@ -46,6 +46,7 @@ pub mod state;
pub mod subscriber; pub mod subscriber;
pub mod timeline; pub mod timeline;
pub mod tool; pub mod tool;
pub mod tool_server;
pub use message::{ContentPart, Item, Message, Role}; pub use message::{ContentPart, Item, Message, Role};
pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult}; pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult};

View File

@ -0,0 +1,179 @@
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use thiserror::Error;
use crate::llm_client::ToolDefinition as LlmToolDefinition;
use crate::tool::{Tool, ToolDefinition as WorkerToolDefinition, ToolMeta};
type ToolMap = HashMap<String, (ToolMeta, Arc<dyn Tool>)>;
/// Errors produced by ToolServer operations.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum ToolServerError {
/// A tool with the same name already exists.
#[error("Tool with name '{0}' already registered")]
DuplicateName(String),
/// Requested tool was not found.
#[error("Tool '{0}' not found")]
ToolNotFound(String),
/// Tool execution failed.
#[error("Tool execution failed: {0}")]
ToolExecution(String),
}
/// In-memory tool server.
#[derive(Clone, Default)]
pub struct ToolServer {
tools: Arc<Mutex<ToolMap>>,
}
impl ToolServer {
/// Create a new empty tool server.
pub fn new() -> Self {
Self::default()
}
/// Create a handle for shared access.
pub fn handle(&self) -> ToolServerHandle {
ToolServerHandle {
tools: Arc::clone(&self.tools),
}
}
}
/// Shareable handle to a tool server.
#[derive(Clone, Default)]
pub struct ToolServerHandle {
tools: Arc<Mutex<ToolMap>>,
}
impl ToolServerHandle {
/// Register one tool.
pub fn register_tool(&self, factory: WorkerToolDefinition) -> Result<(), ToolServerError> {
let (meta, instance) = factory();
let mut guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
if guard.contains_key(&meta.name) {
return Err(ToolServerError::DuplicateName(meta.name));
}
guard.insert(meta.name.clone(), (meta, instance));
Ok(())
}
/// Register many tools.
pub fn register_tools(
&self,
factories: impl IntoIterator<Item = WorkerToolDefinition>,
) -> Result<(), ToolServerError> {
for factory in factories {
self.register_tool(factory)?;
}
Ok(())
}
/// Get a tool by name for hook contexts.
pub fn get_tool(&self, name: &str) -> Option<(ToolMeta, Arc<dyn Tool>)> {
let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
guard.get(name).map(|(meta, tool)| (meta.clone(), Arc::clone(tool)))
}
/// Execute a tool by name.
pub async fn call_tool(&self, name: &str, input_json: &str) -> Result<String, ToolServerError> {
let tool = {
let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
let (_, tool) = guard
.get(name)
.ok_or_else(|| ToolServerError::ToolNotFound(name.to_string()))?;
Arc::clone(tool)
};
tool.execute(input_json)
.await
.map_err(|e| ToolServerError::ToolExecution(e.to_string()))
}
/// Build deterministic tool definitions sorted by tool name.
pub fn tool_definitions_sorted(&self) -> Vec<LlmToolDefinition> {
let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
let mut defs: Vec<_> = guard
.values()
.map(|(meta, _)| {
LlmToolDefinition::new(&meta.name)
.description(&meta.description)
.input_schema(meta.input_schema.clone())
})
.collect();
defs.sort_by(|a, b| a.name.cmp(&b.name));
defs
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::json;
use super::*;
use crate::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
struct EchoTool;
#[async_trait]
impl Tool for EchoTool {
async fn execute(&self, input_json: &str) -> Result<String, ToolError> {
Ok(input_json.to_string())
}
}
fn def(name: &'static str) -> ToolDefinition {
Arc::new(move || {
(
ToolMeta::new(name)
.description(format!("desc-{name}"))
.input_schema(json!({"type":"object"})),
Arc::new(EchoTool) as Arc<dyn Tool>,
)
})
}
#[test]
fn register_duplicate_name_fails() {
let handle = ToolServer::new().handle();
handle.register_tool(def("alpha")).expect("first register");
let err = handle
.register_tool(def("alpha"))
.expect_err("duplicate should fail");
assert_eq!(err, ToolServerError::DuplicateName("alpha".to_string()));
}
#[tokio::test]
async fn call_tool_success_and_not_found() {
let handle = ToolServer::new().handle();
handle.register_tool(def("echo")).expect("register");
let out = handle.call_tool("echo", r#"{"x":1}"#).await.expect("call");
assert_eq!(out, r#"{"x":1}"#);
let err = handle
.call_tool("missing", "{}")
.await
.expect_err("missing tool");
assert_eq!(err, ToolServerError::ToolNotFound("missing".to_string()));
}
#[test]
fn tool_definitions_are_sorted() {
let handle = ToolServer::new().handle();
handle.register_tool(def("zeta")).expect("register zeta");
handle.register_tool(def("alpha")).expect("register alpha");
handle.register_tool(def("beta")).expect("register beta");
let names: Vec<_> = handle
.tool_definitions_sorted()
.into_iter()
.map(|d| d.name)
.collect();
assert_eq!(names, vec!["alpha", "beta", "zeta"]);
}
}

View File

@ -19,8 +19,9 @@ use crate::{
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter, ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, WorkerSubscriber, ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, WorkerSubscriber,
}, },
tool_server::{ToolServer, ToolServerError, ToolServerHandle},
timeline::{TextBlockCollector, Timeline, ToolCallCollector}, timeline::{TextBlockCollector, Timeline, ToolCallCollector},
tool::{Tool, ToolDefinition as WorkerToolDefinition, ToolError, ToolMeta}, tool::{ToolDefinition as WorkerToolDefinition, ToolError},
}; };
// ============================================================================= // =============================================================================
@ -163,8 +164,8 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
text_block_collector: TextBlockCollector, text_block_collector: TextBlockCollector,
/// Tool call collector (Timeline handler) /// Tool call collector (Timeline handler)
tool_call_collector: ToolCallCollector, tool_call_collector: ToolCallCollector,
/// Registered tools (meta, instance) /// Tool server handle
tools: HashMap<String, (ToolMeta, Arc<dyn Tool>)>, tool_server: ToolServerHandle,
/// Hook registry /// Hook registry
hooks: HookRegistry, hooks: HookRegistry,
/// System prompt /// System prompt
@ -316,12 +317,13 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
/// worker.register_tool(def)?; /// worker.register_tool(def)?;
/// ``` /// ```
pub fn register_tool(&mut self, factory: WorkerToolDefinition) -> Result<(), ToolRegistryError> { pub fn register_tool(&mut self, factory: WorkerToolDefinition) -> Result<(), ToolRegistryError> {
let (meta, instance) = factory(); match self.tool_server.register_tool(factory) {
if self.tools.contains_key(&meta.name) { Ok(()) => Ok(()),
return Err(ToolRegistryError::DuplicateName(meta.name.clone())); Err(ToolServerError::DuplicateName(name)) => Err(ToolRegistryError::DuplicateName(name)),
Err(ToolServerError::ToolNotFound(_) | ToolServerError::ToolExecution(_)) => {
unreachable!("register_tool should only fail with DuplicateName")
}
} }
self.tools.insert(meta.name.clone(), (meta, instance));
Ok(())
} }
/// Register multiple tools /// Register multiple tools
@ -329,10 +331,18 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
&mut self, &mut self,
factories: impl IntoIterator<Item = WorkerToolDefinition>, factories: impl IntoIterator<Item = WorkerToolDefinition>,
) -> Result<(), ToolRegistryError> { ) -> Result<(), ToolRegistryError> {
for factory in factories { match self.tool_server.register_tools(factories) {
self.register_tool(factory)?; Ok(()) => Ok(()),
Err(ToolServerError::DuplicateName(name)) => Err(ToolRegistryError::DuplicateName(name)),
Err(ToolServerError::ToolNotFound(_) | ToolServerError::ToolExecution(_)) => {
unreachable!("register_tools should only fail with DuplicateName")
} }
Ok(()) }
}
/// Get a shared tool server handle.
pub fn tool_server_handle(&self) -> ToolServerHandle {
self.tool_server.clone()
} }
/// Add an on_prompt_submit Hook /// Add an on_prompt_submit Hook
@ -508,14 +518,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
/// Generate list of ToolDefinitions for LLM from registered tools /// Generate list of ToolDefinitions for LLM from registered tools
fn build_tool_definitions(&self) -> Vec<ToolDefinition> { fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
self.tools self.tool_server.tool_definitions_sorted()
.values()
.map(|(meta, _)| {
ToolDefinition::new(&meta.name)
.description(&meta.description)
.input_schema(meta.input_schema.clone())
})
.collect()
} }
/// Build assistant response items from text blocks and tool calls /// Build assistant response items from text blocks and tool calls
@ -715,12 +718,12 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
let mut approved_calls = Vec::new(); let mut approved_calls = Vec::new();
for mut tool_call in tool_calls { for mut tool_call in tool_calls {
// Get tool definition // Get tool definition
if let Some((meta, tool)) = self.tools.get(&tool_call.name) { if let Some((meta, tool)) = self.tool_server.get_tool(&tool_call.name) {
// Create context // Create context
let mut context = ToolCallContext { let mut context = ToolCallContext {
call: tool_call.clone(), call: tool_call.clone(),
meta: meta.clone(), meta,
tool: tool.clone(), tool,
}; };
let mut skip = false; let mut skip = false;
@ -753,7 +756,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
if !skip { if !skip {
call_info_map.insert( call_info_map.insert(
tool_call.id.clone(), tool_call.id.clone(),
(tool_call.clone(), meta.clone(), tool.clone()), (tool_call.clone(), context.meta.clone(), context.tool.clone()),
); );
approved_calls.push(tool_call); approved_calls.push(tool_call);
} }
@ -768,21 +771,13 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
let futures: Vec<_> = approved_calls let futures: Vec<_> = approved_calls
.into_iter() .into_iter()
.map(|tool_call| { .map(|tool_call| {
let tools = &self.tools; let tool_server = self.tool_server.clone();
async move { async move {
if let Some((_, tool)) = tools.get(&tool_call.name) { let input_json = serde_json::to_string(&tool_call.input).unwrap_or_default();
let input_json = match tool_server.call_tool(&tool_call.name, &input_json).await {
serde_json::to_string(&tool_call.input).unwrap_or_default();
match tool.execute(&input_json).await {
Ok(content) => ToolResult::success(&tool_call.id, content), Ok(content) => ToolResult::success(&tool_call.id, content),
Err(e) => ToolResult::error(&tool_call.id, e.to_string()), Err(e) => ToolResult::error(&tool_call.id, e.to_string()),
} }
} else {
ToolResult::error(
&tool_call.id,
format!("Tool '{}' not found", tool_call.name),
)
}
} }
}) })
.collect(); .collect();
@ -1049,7 +1044,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
timeline, timeline,
text_block_collector, text_block_collector,
tool_call_collector, tool_call_collector,
tools: HashMap::new(), tool_server: ToolServer::new().handle(),
hooks: HookRegistry::new(), hooks: HookRegistry::new(),
system_prompt: None, system_prompt: None,
history: Vec::new(), history: Vec::new(),
@ -1220,7 +1215,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
timeline: self.timeline, timeline: self.timeline,
text_block_collector: self.text_block_collector, text_block_collector: self.text_block_collector,
tool_call_collector: self.tool_call_collector, tool_call_collector: self.tool_call_collector,
tools: self.tools, tool_server: self.tool_server,
hooks: self.hooks, hooks: self.hooks,
system_prompt: self.system_prompt, system_prompt: self.system_prompt,
history: self.history, history: self.history,
@ -1256,7 +1251,7 @@ impl<C: LlmClient> Worker<C, CacheLocked> {
timeline: self.timeline, timeline: self.timeline,
text_block_collector: self.text_block_collector, text_block_collector: self.text_block_collector,
tool_call_collector: self.tool_call_collector, tool_call_collector: self.tool_call_collector,
tools: self.tools, tool_server: self.tool_server,
hooks: self.hooks, hooks: self.hooks,
system_prompt: self.system_prompt, system_prompt: self.system_prompt,
history: self.history, history: self.history,