Compare commits
No commits in common. "cba96e4f4685908571ff52b34bf0cd7c72b61284" and "3c6297096745eaa68bbd7755a3c1f5a73b0714fc" have entirely different histories.
cba96e4f46
...
3c62970967
100
Cargo.lock
generated
100
Cargo.lock
generated
|
|
@ -384,12 +384,6 @@ dependencies = [
|
|||
"wasip2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.4.13"
|
||||
|
|
@ -732,7 +726,6 @@ dependencies = [
|
|||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"trybuild",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1185,15 +1178,6 @@ dependencies = [
|
|||
"zmij",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_spanned"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776"
|
||||
dependencies = [
|
||||
"serde_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sharded-slab"
|
||||
version = "0.1.7"
|
||||
|
|
@ -1280,12 +1264,6 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "target-triple"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "591ef38edfb78ca4771ee32cf494cb8771944bee237a9b91fc9c1424ac4b777b"
|
||||
|
||||
[[package]]
|
||||
name = "tempfile"
|
||||
version = "3.24.0"
|
||||
|
|
@ -1299,15 +1277,6 @@ dependencies = [
|
|||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "1.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "2.0.17"
|
||||
|
|
@ -1406,45 +1375,6 @@ dependencies = [
|
|||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "1.0.3+spec-1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7614eaf19ad818347db24addfa201729cf2a9b6fdfd9eb0ab870fcacc606c0c"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"serde_core",
|
||||
"serde_spanned",
|
||||
"toml_datetime",
|
||||
"toml_parser",
|
||||
"toml_writer",
|
||||
"winnow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_datetime"
|
||||
version = "1.0.0+spec-1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e"
|
||||
dependencies = [
|
||||
"serde_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_parser"
|
||||
version = "1.0.9+spec-1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4"
|
||||
dependencies = [
|
||||
"winnow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_writer"
|
||||
version = "1.0.6+spec-1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607"
|
||||
|
||||
[[package]]
|
||||
name = "tower"
|
||||
version = "0.5.2"
|
||||
|
|
@ -1557,21 +1487,6 @@ version = "0.2.5"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||
|
||||
[[package]]
|
||||
name = "trybuild"
|
||||
version = "1.0.116"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "47c635f0191bd3a2941013e5062667100969f8c4e9cd787c14f977265d73616e"
|
||||
dependencies = [
|
||||
"glob",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"target-triple",
|
||||
"termcolor",
|
||||
"toml",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.22"
|
||||
|
|
@ -1725,15 +1640,6 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-util"
|
||||
version = "0.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
|
||||
dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-link"
|
||||
version = "0.2.1"
|
||||
|
|
@ -1896,12 +1802,6 @@ version = "0.53.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650"
|
||||
|
||||
[[package]]
|
||||
name = "winnow"
|
||||
version = "0.7.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829"
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen"
|
||||
version = "0.46.0"
|
||||
|
|
|
|||
|
|
@ -26,4 +26,3 @@ schemars = "1.2"
|
|||
tempfile = "3.24"
|
||||
dotenv = "0.15"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
trybuild = "1.0.116"
|
||||
|
|
|
|||
|
|
@ -46,7 +46,6 @@ pub mod state;
|
|||
pub mod subscriber;
|
||||
pub mod timeline;
|
||||
pub mod tool;
|
||||
pub mod tool_server;
|
||||
|
||||
pub use message::{ContentPart, Item, Message, Role};
|
||||
pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult};
|
||||
|
|
|
|||
|
|
@ -1,182 +0,0 @@
|
|||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::llm_client::ToolDefinition as LlmToolDefinition;
|
||||
use crate::tool::{Tool, ToolDefinition as WorkerToolDefinition, ToolMeta};
|
||||
|
||||
type ToolMap = HashMap<String, (ToolMeta, Arc<dyn Tool>)>;
|
||||
|
||||
/// Errors produced by ToolServer operations.
|
||||
#[derive(Debug, Error, PartialEq, Eq)]
|
||||
pub enum ToolServerError {
|
||||
/// A tool with the same name already exists.
|
||||
#[error("Tool with name '{0}' already registered")]
|
||||
DuplicateName(String),
|
||||
/// Requested tool was not found.
|
||||
#[error("Tool '{0}' not found")]
|
||||
ToolNotFound(String),
|
||||
/// Tool execution failed.
|
||||
#[error("Tool execution failed: {0}")]
|
||||
ToolExecution(String),
|
||||
}
|
||||
|
||||
/// In-memory tool server.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct ToolServer {
|
||||
tools: Arc<Mutex<ToolMap>>,
|
||||
}
|
||||
|
||||
impl ToolServer {
|
||||
/// Create a new empty tool server.
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Create a handle for shared access.
|
||||
pub fn handle(&self) -> ToolServerHandle {
|
||||
ToolServerHandle {
|
||||
tools: Arc::clone(&self.tools),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Shareable handle to a tool server.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct ToolServerHandle {
|
||||
tools: Arc<Mutex<ToolMap>>,
|
||||
}
|
||||
|
||||
impl ToolServerHandle {
|
||||
/// Register one tool.
|
||||
pub(crate) fn register_tool(
|
||||
&self,
|
||||
factory: WorkerToolDefinition,
|
||||
) -> Result<(), ToolServerError> {
|
||||
let (meta, instance) = factory();
|
||||
let mut guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
|
||||
if guard.contains_key(&meta.name) {
|
||||
return Err(ToolServerError::DuplicateName(meta.name));
|
||||
}
|
||||
guard.insert(meta.name.clone(), (meta, instance));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Register many tools.
|
||||
pub(crate) fn register_tools(
|
||||
&self,
|
||||
factories: impl IntoIterator<Item = WorkerToolDefinition>,
|
||||
) -> Result<(), ToolServerError> {
|
||||
for factory in factories {
|
||||
self.register_tool(factory)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a tool by name for hook contexts.
|
||||
pub fn get_tool(&self, name: &str) -> Option<(ToolMeta, Arc<dyn Tool>)> {
|
||||
let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
|
||||
guard.get(name).map(|(meta, tool)| (meta.clone(), Arc::clone(tool)))
|
||||
}
|
||||
|
||||
/// Execute a tool by name.
|
||||
pub async fn call_tool(&self, name: &str, input_json: &str) -> Result<String, ToolServerError> {
|
||||
let tool = {
|
||||
let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let (_, tool) = guard
|
||||
.get(name)
|
||||
.ok_or_else(|| ToolServerError::ToolNotFound(name.to_string()))?;
|
||||
Arc::clone(tool)
|
||||
};
|
||||
tool.execute(input_json)
|
||||
.await
|
||||
.map_err(|e| ToolServerError::ToolExecution(e.to_string()))
|
||||
}
|
||||
|
||||
/// Build deterministic tool definitions sorted by tool name.
|
||||
pub fn tool_definitions_sorted(&self) -> Vec<LlmToolDefinition> {
|
||||
let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let mut defs: Vec<_> = guard
|
||||
.values()
|
||||
.map(|(meta, _)| {
|
||||
LlmToolDefinition::new(&meta.name)
|
||||
.description(&meta.description)
|
||||
.input_schema(meta.input_schema.clone())
|
||||
})
|
||||
.collect();
|
||||
defs.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
defs
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
use crate::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
||||
|
||||
struct EchoTool;
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for EchoTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<String, ToolError> {
|
||||
Ok(input_json.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn def(name: &'static str) -> ToolDefinition {
|
||||
Arc::new(move || {
|
||||
(
|
||||
ToolMeta::new(name)
|
||||
.description(format!("desc-{name}"))
|
||||
.input_schema(json!({"type":"object"})),
|
||||
Arc::new(EchoTool) as Arc<dyn Tool>,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_duplicate_name_fails() {
|
||||
let handle = ToolServer::new().handle();
|
||||
handle.register_tool(def("alpha")).expect("first register");
|
||||
let err = handle
|
||||
.register_tool(def("alpha"))
|
||||
.expect_err("duplicate should fail");
|
||||
assert_eq!(err, ToolServerError::DuplicateName("alpha".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn call_tool_success_and_not_found() {
|
||||
let handle = ToolServer::new().handle();
|
||||
handle.register_tool(def("echo")).expect("register");
|
||||
|
||||
let out = handle.call_tool("echo", r#"{"x":1}"#).await.expect("call");
|
||||
assert_eq!(out, r#"{"x":1}"#);
|
||||
|
||||
let err = handle
|
||||
.call_tool("missing", "{}")
|
||||
.await
|
||||
.expect_err("missing tool");
|
||||
assert_eq!(err, ToolServerError::ToolNotFound("missing".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_definitions_are_sorted() {
|
||||
let handle = ToolServer::new().handle();
|
||||
handle.register_tool(def("zeta")).expect("register zeta");
|
||||
handle.register_tool(def("alpha")).expect("register alpha");
|
||||
handle.register_tool(def("beta")).expect("register beta");
|
||||
|
||||
let names: Vec<_> = handle
|
||||
.tool_definitions_sorted()
|
||||
.into_iter()
|
||||
.map(|d| d.name)
|
||||
.collect();
|
||||
assert_eq!(names, vec!["alpha", "beta", "zeta"]);
|
||||
}
|
||||
}
|
||||
|
|
@ -19,9 +19,8 @@ use crate::{
|
|||
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
|
||||
ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, WorkerSubscriber,
|
||||
},
|
||||
tool_server::{ToolServer, ToolServerError, ToolServerHandle},
|
||||
timeline::{TextBlockCollector, Timeline, ToolCallCollector},
|
||||
tool::{ToolDefinition as WorkerToolDefinition, ToolError},
|
||||
tool::{Tool, ToolDefinition as WorkerToolDefinition, ToolError, ToolMeta},
|
||||
};
|
||||
|
||||
// =============================================================================
|
||||
|
|
@ -164,8 +163,8 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
|||
text_block_collector: TextBlockCollector,
|
||||
/// Tool call collector (Timeline handler)
|
||||
tool_call_collector: ToolCallCollector,
|
||||
/// Tool server handle
|
||||
tool_server: ToolServerHandle,
|
||||
/// Registered tools (meta, instance)
|
||||
tools: HashMap<String, (ToolMeta, Arc<dyn Tool>)>,
|
||||
/// Hook registry
|
||||
hooks: HookRegistry,
|
||||
/// System prompt
|
||||
|
|
@ -300,9 +299,40 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
.push(Box::new(SubscriberTurnNotifier { subscriber }));
|
||||
}
|
||||
|
||||
/// Get a shared tool server handle.
|
||||
pub fn tool_server_handle(&self) -> ToolServerHandle {
|
||||
self.tool_server.clone()
|
||||
/// Register a tool
|
||||
///
|
||||
/// Registered tools are automatically executed when called by the LLM.
|
||||
/// Registering a tool with the same name will result in an error.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// use llm_worker::tool::{ToolMeta, ToolDefinition, Tool};
|
||||
/// use std::sync::Arc;
|
||||
///
|
||||
/// let def: ToolDefinition = Arc::new(|| {
|
||||
/// (ToolMeta::new("search").description("..."), Arc::new(MyTool) as Arc<dyn Tool>)
|
||||
/// });
|
||||
/// worker.register_tool(def)?;
|
||||
/// ```
|
||||
pub fn register_tool(&mut self, factory: WorkerToolDefinition) -> Result<(), ToolRegistryError> {
|
||||
let (meta, instance) = factory();
|
||||
if self.tools.contains_key(&meta.name) {
|
||||
return Err(ToolRegistryError::DuplicateName(meta.name.clone()));
|
||||
}
|
||||
self.tools.insert(meta.name.clone(), (meta, instance));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Register multiple tools
|
||||
pub fn register_tools(
|
||||
&mut self,
|
||||
factories: impl IntoIterator<Item = WorkerToolDefinition>,
|
||||
) -> Result<(), ToolRegistryError> {
|
||||
for factory in factories {
|
||||
self.register_tool(factory)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Add an on_prompt_submit Hook
|
||||
|
|
@ -478,7 +508,14 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
|
||||
/// Generate list of ToolDefinitions for LLM from registered tools
|
||||
fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
|
||||
self.tool_server.tool_definitions_sorted()
|
||||
self.tools
|
||||
.values()
|
||||
.map(|(meta, _)| {
|
||||
ToolDefinition::new(&meta.name)
|
||||
.description(&meta.description)
|
||||
.input_schema(meta.input_schema.clone())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Build assistant response items from text blocks and tool calls
|
||||
|
|
@ -678,12 +715,12 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
let mut approved_calls = Vec::new();
|
||||
for mut tool_call in tool_calls {
|
||||
// Get tool definition
|
||||
if let Some((meta, tool)) = self.tool_server.get_tool(&tool_call.name) {
|
||||
if let Some((meta, tool)) = self.tools.get(&tool_call.name) {
|
||||
// Create context
|
||||
let mut context = ToolCallContext {
|
||||
call: tool_call.clone(),
|
||||
meta,
|
||||
tool,
|
||||
meta: meta.clone(),
|
||||
tool: tool.clone(),
|
||||
};
|
||||
|
||||
let mut skip = false;
|
||||
|
|
@ -716,7 +753,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
if !skip {
|
||||
call_info_map.insert(
|
||||
tool_call.id.clone(),
|
||||
(tool_call.clone(), context.meta.clone(), context.tool.clone()),
|
||||
(tool_call.clone(), meta.clone(), tool.clone()),
|
||||
);
|
||||
approved_calls.push(tool_call);
|
||||
}
|
||||
|
|
@ -731,13 +768,21 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
let futures: Vec<_> = approved_calls
|
||||
.into_iter()
|
||||
.map(|tool_call| {
|
||||
let tool_server = self.tool_server.clone();
|
||||
let tools = &self.tools;
|
||||
async move {
|
||||
let input_json = serde_json::to_string(&tool_call.input).unwrap_or_default();
|
||||
match tool_server.call_tool(&tool_call.name, &input_json).await {
|
||||
if let Some((_, tool)) = tools.get(&tool_call.name) {
|
||||
let input_json =
|
||||
serde_json::to_string(&tool_call.input).unwrap_or_default();
|
||||
match tool.execute(&input_json).await {
|
||||
Ok(content) => ToolResult::success(&tool_call.id, content),
|
||||
Err(e) => ToolResult::error(&tool_call.id, e.to_string()),
|
||||
}
|
||||
} else {
|
||||
ToolResult::error(
|
||||
&tool_call.id,
|
||||
format!("Tool '{}' not found", tool_call.name),
|
||||
)
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
|
@ -1004,7 +1049,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
timeline,
|
||||
text_block_collector,
|
||||
tool_call_collector,
|
||||
tool_server: ToolServer::new().handle(),
|
||||
tools: HashMap::new(),
|
||||
hooks: HookRegistry::new(),
|
||||
system_prompt: None,
|
||||
history: Vec::new(),
|
||||
|
|
@ -1019,41 +1064,6 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Register a tool
|
||||
///
|
||||
/// Registered tools are automatically executed when called by the LLM.
|
||||
/// Registering a tool with the same name will result in an error.
|
||||
///
|
||||
/// Available only in Mutable state.
|
||||
pub fn register_tool(
|
||||
&mut self,
|
||||
factory: WorkerToolDefinition,
|
||||
) -> Result<(), ToolRegistryError> {
|
||||
match self.tool_server.register_tool(factory) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(ToolServerError::DuplicateName(name)) => Err(ToolRegistryError::DuplicateName(name)),
|
||||
Err(ToolServerError::ToolNotFound(_) | ToolServerError::ToolExecution(_)) => {
|
||||
unreachable!("register_tool should only fail with DuplicateName")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Register multiple tools
|
||||
///
|
||||
/// Available only in Mutable state.
|
||||
pub fn register_tools(
|
||||
&mut self,
|
||||
factories: impl IntoIterator<Item = WorkerToolDefinition>,
|
||||
) -> Result<(), ToolRegistryError> {
|
||||
match self.tool_server.register_tools(factories) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(ToolServerError::DuplicateName(name)) => Err(ToolRegistryError::DuplicateName(name)),
|
||||
Err(ToolServerError::ToolNotFound(_) | ToolServerError::ToolExecution(_)) => {
|
||||
unreachable!("register_tools should only fail with DuplicateName")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Set system prompt (builder pattern)
|
||||
pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
|
||||
self.system_prompt = Some(prompt.into());
|
||||
|
|
@ -1210,7 +1220,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
timeline: self.timeline,
|
||||
text_block_collector: self.text_block_collector,
|
||||
tool_call_collector: self.tool_call_collector,
|
||||
tool_server: self.tool_server,
|
||||
tools: self.tools,
|
||||
hooks: self.hooks,
|
||||
system_prompt: self.system_prompt,
|
||||
history: self.history,
|
||||
|
|
@ -1246,7 +1256,7 @@ impl<C: LlmClient> Worker<C, CacheLocked> {
|
|||
timeline: self.timeline,
|
||||
text_block_collector: self.text_block_collector,
|
||||
tool_call_collector: self.tool_call_collector,
|
||||
tool_server: self.tool_server,
|
||||
tools: self.tools,
|
||||
hooks: self.hooks,
|
||||
system_prompt: self.system_prompt,
|
||||
history: self.history,
|
||||
|
|
|
|||
|
|
@ -1,6 +0,0 @@
|
|||
#[test]
|
||||
fn compile_fail_state_constraints() {
|
||||
let t = trybuild::TestCases::new();
|
||||
t.compile_fail("tests/ui/cache_locked_register_tool.rs");
|
||||
t.compile_fail("tests/ui/tool_server_handle_register_tool.rs");
|
||||
}
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
use llm_worker::Worker;
|
||||
use llm_worker::llm_client::providers::ollama::OllamaClient;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn main() {
|
||||
let client = OllamaClient::new("dummy-model");
|
||||
let worker = Worker::new(client);
|
||||
let mut locked = worker.lock();
|
||||
let def: llm_worker::tool::ToolDefinition = Arc::new(|| panic!("unused"));
|
||||
let _ = locked.register_tool(def);
|
||||
}
|
||||
|
|
@ -1,8 +0,0 @@
|
|||
error[E0599]: no method named `register_tool` found for struct `Worker<OllamaClient, CacheLocked>` in the current scope
|
||||
--> tests/ui/cache_locked_register_tool.rs:10:20
|
||||
|
|
||||
10 | let _ = locked.register_tool(def);
|
||||
| ^^^^^^^^^^^^^ method not found in `Worker<OllamaClient, CacheLocked>`
|
||||
|
|
||||
= note: the method was found for
|
||||
- `Worker<C>`
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
use llm_worker::Worker;
|
||||
use llm_worker::llm_client::providers::ollama::OllamaClient;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn main() {
|
||||
let client = OllamaClient::new("dummy-model");
|
||||
let worker = Worker::new(client);
|
||||
let handle = worker.tool_server_handle();
|
||||
let def: llm_worker::tool::ToolDefinition = Arc::new(|| panic!("unused"));
|
||||
let _ = handle.register_tool(def);
|
||||
}
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
error[E0624]: method `register_tool` is private
|
||||
--> tests/ui/tool_server_handle_register_tool.rs:10:20
|
||||
|
|
||||
10 | let _ = handle.register_tool(def);
|
||||
| ^^^^^^^^^^^^^ private method
|
||||
|
|
||||
::: src/tool_server.rs
|
||||
|
|
||||
| / pub(crate) fn register_tool(
|
||||
| | &self,
|
||||
| | factory: WorkerToolDefinition,
|
||||
| | ) -> Result<(), ToolServerError> {
|
||||
| |____________________________________- private method defined here
|
||||
|
|
@ -5,14 +5,9 @@
|
|||
|
||||
mod common;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use common::MockLlmClient;
|
||||
use llm_worker::Worker;
|
||||
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
|
||||
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
||||
use llm_worker::Item;
|
||||
|
||||
// =============================================================================
|
||||
|
|
@ -96,56 +91,6 @@ fn test_mutable_extend_history() {
|
|||
assert_eq!(worker.history().len(), 4);
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CountingTool {
|
||||
name: String,
|
||||
calls: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl CountingTool {
|
||||
fn new(name: impl Into<String>) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
calls: Arc::new(AtomicUsize::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
fn definition(&self) -> ToolDefinition {
|
||||
let tool = self.clone();
|
||||
Arc::new(move || {
|
||||
(
|
||||
ToolMeta::new(&tool.name)
|
||||
.description("Counting tool")
|
||||
.input_schema(serde_json::json!({"type":"object","properties":{}})),
|
||||
Arc::new(tool.clone()) as Arc<dyn Tool>,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn call_count(&self) -> usize {
|
||||
self.calls.load(Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for CountingTool {
|
||||
async fn execute(&self, _input_json: &str) -> Result<String, ToolError> {
|
||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||
Ok(format!("{}-ok", self.name))
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify that tools can be registered in Mutable state.
|
||||
#[test]
|
||||
fn test_mutable_can_register_tool() {
|
||||
let client = MockLlmClient::new(vec![]);
|
||||
let mut worker = Worker::new(client);
|
||||
let tool = CountingTool::new("count_tool");
|
||||
|
||||
let result = worker.register_tool(tool.definition());
|
||||
assert!(result.is_ok(), "Mutable should allow tool registration");
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// State Transition Tests
|
||||
// =============================================================================
|
||||
|
|
@ -385,67 +330,6 @@ async fn test_unlock_edit_relock() {
|
|||
assert_eq!(relocked.locked_prefix_len(), 1);
|
||||
}
|
||||
|
||||
/// Verify that tools registered before lock and after unlock remain effective.
|
||||
#[tokio::test]
|
||||
async fn test_lock_unlock_relock_tools_remain_effective() {
|
||||
let client = MockLlmClient::with_responses(vec![
|
||||
vec![
|
||||
Event::tool_use_start(0, "call_1", "tool_a"),
|
||||
Event::tool_input_delta(0, r#"{}"#),
|
||||
Event::tool_use_stop(0),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
],
|
||||
vec![
|
||||
Event::text_block_start(0),
|
||||
Event::text_delta(0, "done-a"),
|
||||
Event::text_block_stop(0, None),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
],
|
||||
vec![
|
||||
Event::tool_use_start(0, "call_2", "tool_b"),
|
||||
Event::tool_input_delta(0, r#"{}"#),
|
||||
Event::tool_use_stop(0),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
],
|
||||
vec![
|
||||
Event::text_block_start(0),
|
||||
Event::text_delta(0, "done-b"),
|
||||
Event::text_block_stop(0, None),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
],
|
||||
]);
|
||||
|
||||
let mut worker = Worker::new(client);
|
||||
let tool_a = CountingTool::new("tool_a");
|
||||
worker
|
||||
.register_tool(tool_a.definition())
|
||||
.expect("register tool_a should succeed");
|
||||
|
||||
let mut locked = worker.lock();
|
||||
locked.run("first").await.expect("first run");
|
||||
assert_eq!(tool_a.call_count(), 1, "tool_a should be called once");
|
||||
|
||||
let mut unlocked = locked.unlock();
|
||||
let tool_b = CountingTool::new("tool_b");
|
||||
unlocked
|
||||
.register_tool(tool_b.definition())
|
||||
.expect("register tool_b after unlock should succeed");
|
||||
|
||||
let mut relocked = unlocked.lock();
|
||||
relocked.run("second").await.expect("second run");
|
||||
|
||||
assert_eq!(tool_a.call_count(), 1, "tool_a should not be called again");
|
||||
assert_eq!(tool_b.call_count(), 1, "tool_b should be called once");
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// System Prompt Preservation Tests
|
||||
// =============================================================================
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user