refactor worker tool execution through tool server
This commit is contained in:
parent
3c62970967
commit
b12785ed93
|
|
@ -46,6 +46,7 @@ pub mod state;
|
||||||
pub mod subscriber;
|
pub mod subscriber;
|
||||||
pub mod timeline;
|
pub mod timeline;
|
||||||
pub mod tool;
|
pub mod tool;
|
||||||
|
pub mod tool_server;
|
||||||
|
|
||||||
pub use message::{ContentPart, Item, Message, Role};
|
pub use message::{ContentPart, Item, Message, Role};
|
||||||
pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult};
|
pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult};
|
||||||
|
|
|
||||||
179
llm-worker/src/tool_server.rs
Normal file
179
llm-worker/src/tool_server.rs
Normal file
|
|
@ -0,0 +1,179 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
use crate::llm_client::ToolDefinition as LlmToolDefinition;
|
||||||
|
use crate::tool::{Tool, ToolDefinition as WorkerToolDefinition, ToolMeta};
|
||||||
|
|
||||||
|
type ToolMap = HashMap<String, (ToolMeta, Arc<dyn Tool>)>;
|
||||||
|
|
||||||
|
/// Errors produced by ToolServer operations.
|
||||||
|
#[derive(Debug, Error, PartialEq, Eq)]
|
||||||
|
pub enum ToolServerError {
|
||||||
|
/// A tool with the same name already exists.
|
||||||
|
#[error("Tool with name '{0}' already registered")]
|
||||||
|
DuplicateName(String),
|
||||||
|
/// Requested tool was not found.
|
||||||
|
#[error("Tool '{0}' not found")]
|
||||||
|
ToolNotFound(String),
|
||||||
|
/// Tool execution failed.
|
||||||
|
#[error("Tool execution failed: {0}")]
|
||||||
|
ToolExecution(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// In-memory tool server.
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
pub struct ToolServer {
|
||||||
|
tools: Arc<Mutex<ToolMap>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolServer {
|
||||||
|
/// Create a new empty tool server.
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a handle for shared access.
|
||||||
|
pub fn handle(&self) -> ToolServerHandle {
|
||||||
|
ToolServerHandle {
|
||||||
|
tools: Arc::clone(&self.tools),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shareable handle to a tool server.
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
pub struct ToolServerHandle {
|
||||||
|
tools: Arc<Mutex<ToolMap>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolServerHandle {
|
||||||
|
/// Register one tool.
|
||||||
|
pub fn register_tool(&self, factory: WorkerToolDefinition) -> Result<(), ToolServerError> {
|
||||||
|
let (meta, instance) = factory();
|
||||||
|
let mut guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
|
||||||
|
if guard.contains_key(&meta.name) {
|
||||||
|
return Err(ToolServerError::DuplicateName(meta.name));
|
||||||
|
}
|
||||||
|
guard.insert(meta.name.clone(), (meta, instance));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register many tools.
|
||||||
|
pub fn register_tools(
|
||||||
|
&self,
|
||||||
|
factories: impl IntoIterator<Item = WorkerToolDefinition>,
|
||||||
|
) -> Result<(), ToolServerError> {
|
||||||
|
for factory in factories {
|
||||||
|
self.register_tool(factory)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a tool by name for hook contexts.
|
||||||
|
pub fn get_tool(&self, name: &str) -> Option<(ToolMeta, Arc<dyn Tool>)> {
|
||||||
|
let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
|
||||||
|
guard.get(name).map(|(meta, tool)| (meta.clone(), Arc::clone(tool)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute a tool by name.
|
||||||
|
pub async fn call_tool(&self, name: &str, input_json: &str) -> Result<String, ToolServerError> {
|
||||||
|
let tool = {
|
||||||
|
let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
|
||||||
|
let (_, tool) = guard
|
||||||
|
.get(name)
|
||||||
|
.ok_or_else(|| ToolServerError::ToolNotFound(name.to_string()))?;
|
||||||
|
Arc::clone(tool)
|
||||||
|
};
|
||||||
|
tool.execute(input_json)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ToolServerError::ToolExecution(e.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build deterministic tool definitions sorted by tool name.
|
||||||
|
pub fn tool_definitions_sorted(&self) -> Vec<LlmToolDefinition> {
|
||||||
|
let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
|
||||||
|
let mut defs: Vec<_> = guard
|
||||||
|
.values()
|
||||||
|
.map(|(meta, _)| {
|
||||||
|
LlmToolDefinition::new(&meta.name)
|
||||||
|
.description(&meta.description)
|
||||||
|
.input_schema(meta.input_schema.clone())
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
defs.sort_by(|a, b| a.name.cmp(&b.name));
|
||||||
|
defs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use crate::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
||||||
|
|
||||||
|
struct EchoTool;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for EchoTool {
|
||||||
|
async fn execute(&self, input_json: &str) -> Result<String, ToolError> {
|
||||||
|
Ok(input_json.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn def(name: &'static str) -> ToolDefinition {
|
||||||
|
Arc::new(move || {
|
||||||
|
(
|
||||||
|
ToolMeta::new(name)
|
||||||
|
.description(format!("desc-{name}"))
|
||||||
|
.input_schema(json!({"type":"object"})),
|
||||||
|
Arc::new(EchoTool) as Arc<dyn Tool>,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn register_duplicate_name_fails() {
|
||||||
|
let handle = ToolServer::new().handle();
|
||||||
|
handle.register_tool(def("alpha")).expect("first register");
|
||||||
|
let err = handle
|
||||||
|
.register_tool(def("alpha"))
|
||||||
|
.expect_err("duplicate should fail");
|
||||||
|
assert_eq!(err, ToolServerError::DuplicateName("alpha".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn call_tool_success_and_not_found() {
|
||||||
|
let handle = ToolServer::new().handle();
|
||||||
|
handle.register_tool(def("echo")).expect("register");
|
||||||
|
|
||||||
|
let out = handle.call_tool("echo", r#"{"x":1}"#).await.expect("call");
|
||||||
|
assert_eq!(out, r#"{"x":1}"#);
|
||||||
|
|
||||||
|
let err = handle
|
||||||
|
.call_tool("missing", "{}")
|
||||||
|
.await
|
||||||
|
.expect_err("missing tool");
|
||||||
|
assert_eq!(err, ToolServerError::ToolNotFound("missing".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tool_definitions_are_sorted() {
|
||||||
|
let handle = ToolServer::new().handle();
|
||||||
|
handle.register_tool(def("zeta")).expect("register zeta");
|
||||||
|
handle.register_tool(def("alpha")).expect("register alpha");
|
||||||
|
handle.register_tool(def("beta")).expect("register beta");
|
||||||
|
|
||||||
|
let names: Vec<_> = handle
|
||||||
|
.tool_definitions_sorted()
|
||||||
|
.into_iter()
|
||||||
|
.map(|d| d.name)
|
||||||
|
.collect();
|
||||||
|
assert_eq!(names, vec!["alpha", "beta", "zeta"]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -19,8 +19,9 @@ use crate::{
|
||||||
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
|
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
|
||||||
ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, WorkerSubscriber,
|
ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, WorkerSubscriber,
|
||||||
},
|
},
|
||||||
|
tool_server::{ToolServer, ToolServerError, ToolServerHandle},
|
||||||
timeline::{TextBlockCollector, Timeline, ToolCallCollector},
|
timeline::{TextBlockCollector, Timeline, ToolCallCollector},
|
||||||
tool::{Tool, ToolDefinition as WorkerToolDefinition, ToolError, ToolMeta},
|
tool::{ToolDefinition as WorkerToolDefinition, ToolError},
|
||||||
};
|
};
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
@ -163,8 +164,8 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
||||||
text_block_collector: TextBlockCollector,
|
text_block_collector: TextBlockCollector,
|
||||||
/// Tool call collector (Timeline handler)
|
/// Tool call collector (Timeline handler)
|
||||||
tool_call_collector: ToolCallCollector,
|
tool_call_collector: ToolCallCollector,
|
||||||
/// Registered tools (meta, instance)
|
/// Tool server handle
|
||||||
tools: HashMap<String, (ToolMeta, Arc<dyn Tool>)>,
|
tool_server: ToolServerHandle,
|
||||||
/// Hook registry
|
/// Hook registry
|
||||||
hooks: HookRegistry,
|
hooks: HookRegistry,
|
||||||
/// System prompt
|
/// System prompt
|
||||||
|
|
@ -316,12 +317,13 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
/// worker.register_tool(def)?;
|
/// worker.register_tool(def)?;
|
||||||
/// ```
|
/// ```
|
||||||
pub fn register_tool(&mut self, factory: WorkerToolDefinition) -> Result<(), ToolRegistryError> {
|
pub fn register_tool(&mut self, factory: WorkerToolDefinition) -> Result<(), ToolRegistryError> {
|
||||||
let (meta, instance) = factory();
|
match self.tool_server.register_tool(factory) {
|
||||||
if self.tools.contains_key(&meta.name) {
|
Ok(()) => Ok(()),
|
||||||
return Err(ToolRegistryError::DuplicateName(meta.name.clone()));
|
Err(ToolServerError::DuplicateName(name)) => Err(ToolRegistryError::DuplicateName(name)),
|
||||||
|
Err(ToolServerError::ToolNotFound(_) | ToolServerError::ToolExecution(_)) => {
|
||||||
|
unreachable!("register_tool should only fail with DuplicateName")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
self.tools.insert(meta.name.clone(), (meta, instance));
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register multiple tools
|
/// Register multiple tools
|
||||||
|
|
@ -329,10 +331,18 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
&mut self,
|
&mut self,
|
||||||
factories: impl IntoIterator<Item = WorkerToolDefinition>,
|
factories: impl IntoIterator<Item = WorkerToolDefinition>,
|
||||||
) -> Result<(), ToolRegistryError> {
|
) -> Result<(), ToolRegistryError> {
|
||||||
for factory in factories {
|
match self.tool_server.register_tools(factories) {
|
||||||
self.register_tool(factory)?;
|
Ok(()) => Ok(()),
|
||||||
|
Err(ToolServerError::DuplicateName(name)) => Err(ToolRegistryError::DuplicateName(name)),
|
||||||
|
Err(ToolServerError::ToolNotFound(_) | ToolServerError::ToolExecution(_)) => {
|
||||||
|
unreachable!("register_tools should only fail with DuplicateName")
|
||||||
}
|
}
|
||||||
Ok(())
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a shared tool server handle.
|
||||||
|
pub fn tool_server_handle(&self) -> ToolServerHandle {
|
||||||
|
self.tool_server.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add an on_prompt_submit Hook
|
/// Add an on_prompt_submit Hook
|
||||||
|
|
@ -508,14 +518,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
|
|
||||||
/// Generate list of ToolDefinitions for LLM from registered tools
|
/// Generate list of ToolDefinitions for LLM from registered tools
|
||||||
fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
|
fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
|
||||||
self.tools
|
self.tool_server.tool_definitions_sorted()
|
||||||
.values()
|
|
||||||
.map(|(meta, _)| {
|
|
||||||
ToolDefinition::new(&meta.name)
|
|
||||||
.description(&meta.description)
|
|
||||||
.input_schema(meta.input_schema.clone())
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build assistant response items from text blocks and tool calls
|
/// Build assistant response items from text blocks and tool calls
|
||||||
|
|
@ -715,12 +718,12 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
let mut approved_calls = Vec::new();
|
let mut approved_calls = Vec::new();
|
||||||
for mut tool_call in tool_calls {
|
for mut tool_call in tool_calls {
|
||||||
// Get tool definition
|
// Get tool definition
|
||||||
if let Some((meta, tool)) = self.tools.get(&tool_call.name) {
|
if let Some((meta, tool)) = self.tool_server.get_tool(&tool_call.name) {
|
||||||
// Create context
|
// Create context
|
||||||
let mut context = ToolCallContext {
|
let mut context = ToolCallContext {
|
||||||
call: tool_call.clone(),
|
call: tool_call.clone(),
|
||||||
meta: meta.clone(),
|
meta,
|
||||||
tool: tool.clone(),
|
tool,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut skip = false;
|
let mut skip = false;
|
||||||
|
|
@ -753,7 +756,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
if !skip {
|
if !skip {
|
||||||
call_info_map.insert(
|
call_info_map.insert(
|
||||||
tool_call.id.clone(),
|
tool_call.id.clone(),
|
||||||
(tool_call.clone(), meta.clone(), tool.clone()),
|
(tool_call.clone(), context.meta.clone(), context.tool.clone()),
|
||||||
);
|
);
|
||||||
approved_calls.push(tool_call);
|
approved_calls.push(tool_call);
|
||||||
}
|
}
|
||||||
|
|
@ -768,21 +771,13 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
let futures: Vec<_> = approved_calls
|
let futures: Vec<_> = approved_calls
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|tool_call| {
|
.map(|tool_call| {
|
||||||
let tools = &self.tools;
|
let tool_server = self.tool_server.clone();
|
||||||
async move {
|
async move {
|
||||||
if let Some((_, tool)) = tools.get(&tool_call.name) {
|
let input_json = serde_json::to_string(&tool_call.input).unwrap_or_default();
|
||||||
let input_json =
|
match tool_server.call_tool(&tool_call.name, &input_json).await {
|
||||||
serde_json::to_string(&tool_call.input).unwrap_or_default();
|
|
||||||
match tool.execute(&input_json).await {
|
|
||||||
Ok(content) => ToolResult::success(&tool_call.id, content),
|
Ok(content) => ToolResult::success(&tool_call.id, content),
|
||||||
Err(e) => ToolResult::error(&tool_call.id, e.to_string()),
|
Err(e) => ToolResult::error(&tool_call.id, e.to_string()),
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
ToolResult::error(
|
|
||||||
&tool_call.id,
|
|
||||||
format!("Tool '{}' not found", tool_call.name),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
@ -1049,7 +1044,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
timeline,
|
timeline,
|
||||||
text_block_collector,
|
text_block_collector,
|
||||||
tool_call_collector,
|
tool_call_collector,
|
||||||
tools: HashMap::new(),
|
tool_server: ToolServer::new().handle(),
|
||||||
hooks: HookRegistry::new(),
|
hooks: HookRegistry::new(),
|
||||||
system_prompt: None,
|
system_prompt: None,
|
||||||
history: Vec::new(),
|
history: Vec::new(),
|
||||||
|
|
@ -1220,7 +1215,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
timeline: self.timeline,
|
timeline: self.timeline,
|
||||||
text_block_collector: self.text_block_collector,
|
text_block_collector: self.text_block_collector,
|
||||||
tool_call_collector: self.tool_call_collector,
|
tool_call_collector: self.tool_call_collector,
|
||||||
tools: self.tools,
|
tool_server: self.tool_server,
|
||||||
hooks: self.hooks,
|
hooks: self.hooks,
|
||||||
system_prompt: self.system_prompt,
|
system_prompt: self.system_prompt,
|
||||||
history: self.history,
|
history: self.history,
|
||||||
|
|
@ -1256,7 +1251,7 @@ impl<C: LlmClient> Worker<C, CacheLocked> {
|
||||||
timeline: self.timeline,
|
timeline: self.timeline,
|
||||||
text_block_collector: self.text_block_collector,
|
text_block_collector: self.text_block_collector,
|
||||||
tool_call_collector: self.tool_call_collector,
|
tool_call_collector: self.tool_call_collector,
|
||||||
tools: self.tools,
|
tool_server: self.tool_server,
|
||||||
hooks: self.hooks,
|
hooks: self.hooks,
|
||||||
system_prompt: self.system_prompt,
|
system_prompt: self.system_prompt,
|
||||||
history: self.history,
|
history: self.history,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user