tool: add execution context
This commit is contained in:
parent
b21fab82fc
commit
d8aed7befe
|
|
@ -90,19 +90,27 @@ fn extract_doc_comment(attrs: &[Attribute]) -> String {
|
|||
/// Extract description from #[description = "..."] attribute
|
||||
fn extract_description_attr(attrs: &[syn::Attribute]) -> Option<String> {
|
||||
for attr in attrs {
|
||||
if attr.path().is_ident("description") {
|
||||
if let Meta::NameValue(meta) = &attr.meta {
|
||||
if let syn::Expr::Lit(expr_lit) = &meta.value {
|
||||
if let Lit::Str(lit_str) = &expr_lit.lit {
|
||||
return Some(lit_str.value());
|
||||
}
|
||||
}
|
||||
}
|
||||
if attr.path().is_ident("description")
|
||||
&& let Meta::NameValue(meta) = &attr.meta
|
||||
&& let syn::Expr::Lit(expr_lit) = &meta.value
|
||||
&& let Lit::Str(lit_str) = &expr_lit.lit
|
||||
{
|
||||
return Some(lit_str.value());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn is_tool_execution_context_type(ty: &Type) -> bool {
|
||||
let Type::Path(path) = ty else {
|
||||
return false;
|
||||
};
|
||||
path.path
|
||||
.segments
|
||||
.last()
|
||||
.is_some_and(|segment| segment.ident == "ToolExecutionContext")
|
||||
}
|
||||
|
||||
/// Generate Tool implementation from a method
|
||||
fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::TokenStream {
|
||||
let sig = &method.sig;
|
||||
|
|
@ -123,8 +131,10 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
|||
description
|
||||
};
|
||||
|
||||
// Parse arguments (excluding self)
|
||||
let args: Vec<_> = sig
|
||||
// Parse method arguments (excluding self). A parameter typed as
|
||||
// ToolExecutionContext is supplied from the execution context and is not
|
||||
// exposed in the JSON input schema.
|
||||
let method_args: Vec<_> = sig
|
||||
.inputs
|
||||
.iter()
|
||||
.filter_map(|arg| {
|
||||
|
|
@ -135,9 +145,14 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
|||
}
|
||||
})
|
||||
.collect();
|
||||
let json_args: Vec<_> = method_args
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|pat_type| !is_tool_execution_context_type(pat_type.ty.as_ref()))
|
||||
.collect();
|
||||
|
||||
// Generate argument struct fields
|
||||
let arg_fields: Vec<_> = args
|
||||
let arg_fields: Vec<_> = json_args
|
||||
.iter()
|
||||
.map(|pat_type| {
|
||||
let pat = &pat_type.pat;
|
||||
|
|
@ -165,11 +180,13 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
|||
})
|
||||
.collect();
|
||||
|
||||
// Code to expand arguments in execute
|
||||
let arg_names: Vec<_> = args
|
||||
// Code to expand method arguments in execute
|
||||
let call_args: Vec<_> = method_args
|
||||
.iter()
|
||||
.map(|pat_type| {
|
||||
if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
|
||||
if is_tool_execution_context_type(pat_type.ty.as_ref()) {
|
||||
quote! { ctx.clone() }
|
||||
} else if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
|
||||
let ident = &pat_ident.ident;
|
||||
quote! { args.#ident }
|
||||
} else {
|
||||
|
|
@ -177,6 +194,11 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
|||
}
|
||||
})
|
||||
.collect();
|
||||
let method_call = if call_args.is_empty() {
|
||||
quote! { self.ctx.#method_name() }
|
||||
} else {
|
||||
quote! { self.ctx.#method_name(#(#call_args),*) }
|
||||
};
|
||||
|
||||
// Check if method is async
|
||||
let is_async = sig.asyncness.is_some();
|
||||
|
|
@ -218,13 +240,13 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
|||
};
|
||||
|
||||
// Execute body handling for no arguments case
|
||||
let execute_body = if args.is_empty() {
|
||||
let execute_body = if json_args.is_empty() {
|
||||
quote! {
|
||||
// Allow empty JSON object even with no arguments
|
||||
// Allow empty JSON object even with no JSON arguments
|
||||
let _: #args_struct_name = serde_json::from_str(input_json)
|
||||
.unwrap_or(#args_struct_name {});
|
||||
|
||||
let result = self.ctx.#method_name()#awaiter;
|
||||
let result = #method_call #awaiter;
|
||||
#result_handling
|
||||
}
|
||||
} else {
|
||||
|
|
@ -232,7 +254,7 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
|||
let args: #args_struct_name = serde_json::from_str(input_json)
|
||||
.map_err(|e| ::llm_worker::tool::ToolError::InvalidArgument(e.to_string()))?;
|
||||
|
||||
let result = self.ctx.#method_name(#(#arg_names),*)#awaiter;
|
||||
let result = #method_call #awaiter;
|
||||
#result_handling
|
||||
}
|
||||
};
|
||||
|
|
@ -247,7 +269,8 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
|||
|
||||
#[async_trait::async_trait]
|
||||
impl ::llm_worker::tool::Tool for #tool_struct_name {
|
||||
async fn execute(&self, input_json: &str) -> Result<::llm_worker::tool::ToolOutput, ::llm_worker::tool::ToolError> {
|
||||
async fn execute(&self, input_json: &str, ctx: ::llm_worker::tool::ToolExecutionContext) -> Result<::llm_worker::tool::ToolOutput, ::llm_worker::tool::ToolError> {
|
||||
let _ = &ctx;
|
||||
#execute_body
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ use std::sync::Arc;
|
|||
use async_trait::async_trait;
|
||||
|
||||
use crate::Item;
|
||||
use crate::tool::{Tool, ToolCall, ToolMeta, ToolResult};
|
||||
use crate::tool::{Tool, ToolCall, ToolExecutionContext, ToolMeta, ToolResult};
|
||||
|
||||
// =============================================================================
|
||||
// Action Enums
|
||||
|
|
@ -107,6 +107,8 @@ pub struct ToolCallInfo {
|
|||
pub meta: ToolMeta,
|
||||
/// Tool instance (for state access).
|
||||
pub tool: Arc<dyn Tool>,
|
||||
/// Response-local execution context for this call.
|
||||
pub context: ToolExecutionContext,
|
||||
}
|
||||
|
||||
/// Context for post-tool-call decisions.
|
||||
|
|
@ -119,6 +121,8 @@ pub struct ToolResultInfo {
|
|||
pub meta: ToolMeta,
|
||||
/// Tool instance (for state access).
|
||||
pub tool: Arc<dyn Tool>,
|
||||
/// Response-local execution context for this call.
|
||||
pub context: ToolExecutionContext,
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ pub use callback::{TextBlockScope, ThinkingBlockScope, ToolUseBlockScope};
|
|||
pub use handler::ToolUseBlockStart;
|
||||
pub use interceptor::Interceptor;
|
||||
pub use message::{ContentPart, Item, Message, Role};
|
||||
pub use tool::{ToolCall, ToolOutputLimits, ToolResult};
|
||||
pub use tool::{ToolCall, ToolExecutionContext, ToolOutputLimits, ToolResult};
|
||||
pub use usage_record::UsageRecord;
|
||||
pub use worker::{
|
||||
LlmRetryNotice, RunOutput, ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult,
|
||||
|
|
|
|||
|
|
@ -189,6 +189,44 @@ impl ToolMeta {
|
|||
/// ```
|
||||
pub type ToolDefinition = Arc<dyn Fn() -> (ToolMeta, Arc<dyn Tool>) + Send + Sync>;
|
||||
|
||||
/// Per-call context supplied by the worker when executing a tool call.
|
||||
///
|
||||
/// The context identifies a tool call within one assistant response's tool-call
|
||||
/// batch without imposing any scheduling policy on the worker. Tool
|
||||
/// implementations may use it for response-local ordering, diagnostics, or
|
||||
/// correlation, but it is intentionally not a handle to worker state, history,
|
||||
/// or session mutation.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ToolExecutionContext {
|
||||
/// Provider/tool-call id for the call being executed.
|
||||
pub call_id: String,
|
||||
/// Worker-local identity shared by all tool calls from one execution batch.
|
||||
pub batch_id: String,
|
||||
/// Zero-based order of this call in the model-returned tool-call list.
|
||||
pub call_index: usize,
|
||||
}
|
||||
|
||||
impl ToolExecutionContext {
|
||||
pub fn new(call_id: impl Into<String>, batch_id: impl Into<String>, call_index: usize) -> Self {
|
||||
Self {
|
||||
call_id: call_id.into(),
|
||||
batch_id: batch_id.into(),
|
||||
call_index,
|
||||
}
|
||||
}
|
||||
|
||||
/// Context for direct, non-worker calls in unit tests and low-level callers.
|
||||
pub fn direct() -> Self {
|
||||
Self::new("direct", "direct", 0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ToolExecutionContext {
|
||||
fn default() -> Self {
|
||||
Self::direct()
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Tool trait
|
||||
// =============================================================================
|
||||
|
|
@ -219,16 +257,16 @@ pub type ToolDefinition = Arc<dyn Fn() -> (ToolMeta, Arc<dyn Tool>) + Send + Syn
|
|||
/// # Manual Implementation
|
||||
///
|
||||
/// ```ignore
|
||||
/// use llm_worker::tool::{Tool, ToolError, ToolMeta, ToolDefinition};
|
||||
/// use llm_worker::tool::{Tool, ToolError, ToolExecutionContext, ToolMeta, ToolDefinition, ToolOutput};
|
||||
/// use std::sync::Arc;
|
||||
///
|
||||
/// struct MyTool { counter: std::sync::atomic::AtomicUsize }
|
||||
///
|
||||
/// #[async_trait::async_trait]
|
||||
/// impl Tool for MyTool {
|
||||
/// async fn execute(&self, input: &str) -> Result<String, ToolError> {
|
||||
/// async fn execute(&self, input: &str, ctx: ToolExecutionContext) -> Result<ToolOutput, ToolError> {
|
||||
/// self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||
/// Ok("result".to_string())
|
||||
/// Ok(format!("call {}: {}", ctx.call_index, input).into())
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
|
|
@ -247,11 +285,16 @@ pub trait Tool: Send + Sync {
|
|||
///
|
||||
/// # Arguments
|
||||
/// * `input_json` - JSON-formatted arguments generated by LLM
|
||||
/// * `ctx` - response-local call identity and ordering context
|
||||
///
|
||||
/// # Returns
|
||||
/// A [`ToolOutput`] with summary and optional detailed content.
|
||||
/// For simple cases, use `From<String>`: `Ok("done".to_string().into())`
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError>;
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
ctx: ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError>;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@ use std::sync::{Arc, Mutex};
|
|||
use thiserror::Error;
|
||||
|
||||
use crate::llm_client::ToolDefinition as LlmToolDefinition;
|
||||
use crate::tool::{Tool, ToolDefinition as WorkerToolDefinition, ToolMeta, ToolOutput};
|
||||
use crate::tool::{
|
||||
Tool, ToolDefinition as WorkerToolDefinition, ToolExecutionContext, ToolMeta, ToolOutput,
|
||||
};
|
||||
|
||||
type ToolMap = HashMap<String, (ToolMeta, Arc<dyn Tool>)>;
|
||||
|
||||
|
|
@ -117,6 +119,7 @@ impl ToolServerHandle {
|
|||
&self,
|
||||
name: &str,
|
||||
input_json: &str,
|
||||
ctx: ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolServerError> {
|
||||
let tool = {
|
||||
let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
|
||||
|
|
@ -125,7 +128,7 @@ impl ToolServerHandle {
|
|||
.ok_or_else(|| ToolServerError::ToolNotFound(name.to_string()))?;
|
||||
Arc::clone(tool)
|
||||
};
|
||||
tool.execute(input_json)
|
||||
tool.execute(input_json, ctx)
|
||||
.await
|
||||
.map_err(|e| ToolServerError::ToolExecution(e.to_string()))
|
||||
}
|
||||
|
|
@ -187,7 +190,11 @@ mod tests {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for EchoTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: crate::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
Ok(input_json.to_string().into())
|
||||
}
|
||||
}
|
||||
|
|
@ -236,12 +243,15 @@ mod tests {
|
|||
handle.register_tool(def("echo"));
|
||||
handle.flush_pending();
|
||||
|
||||
let out = handle.call_tool("echo", r#"{"x":1}"#).await.expect("call");
|
||||
let out = handle
|
||||
.call_tool("echo", r#"{"x":1}"#, Default::default())
|
||||
.await
|
||||
.expect("call");
|
||||
assert_eq!(out.summary, r#"{"x":1}"#);
|
||||
assert!(out.content.is_none());
|
||||
|
||||
let err = handle
|
||||
.call_tool("missing", "{}")
|
||||
.call_tool("missing", "{}", Default::default())
|
||||
.await
|
||||
.expect_err("missing tool");
|
||||
assert_eq!(err, ToolServerError::ToolNotFound("missing".to_string()));
|
||||
|
|
@ -298,7 +308,11 @@ mod tests {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for FixedTool {
|
||||
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
_input_json: &str,
|
||||
_ctx: crate::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
Ok("replaced".to_string().into())
|
||||
}
|
||||
}
|
||||
|
|
@ -327,7 +341,11 @@ mod tests {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for ConstTool {
|
||||
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
_input_json: &str,
|
||||
_ctx: crate::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
Ok("const".to_string().into())
|
||||
}
|
||||
}
|
||||
|
|
@ -342,7 +360,10 @@ mod tests {
|
|||
});
|
||||
handle.replace(replacement).expect("replace");
|
||||
|
||||
let out = handle.call_tool("echo", "{}").await.expect("call");
|
||||
let out = handle
|
||||
.call_tool("echo", "{}", Default::default())
|
||||
.await
|
||||
.expect("call");
|
||||
assert_eq!(out.summary, "const");
|
||||
}
|
||||
|
||||
|
|
@ -360,7 +381,11 @@ mod tests {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for GatedTool {
|
||||
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
_input_json: &str,
|
||||
_ctx: crate::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
self.started.notify_one();
|
||||
self.finish.notified().await;
|
||||
Ok("done".to_string().into())
|
||||
|
|
@ -384,7 +409,7 @@ mod tests {
|
|||
handle.flush_pending();
|
||||
|
||||
let h = handle.clone();
|
||||
let call = tokio::spawn(async move { h.call_tool("slow", "{}").await });
|
||||
let call = tokio::spawn(async move { h.call_tool("slow", "{}", Default::default()).await });
|
||||
|
||||
// Wait until the tool is actually executing.
|
||||
started.notified().await;
|
||||
|
|
@ -413,7 +438,11 @@ mod tests {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for OldTool {
|
||||
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
_input_json: &str,
|
||||
_ctx: crate::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
self.started.notify_one();
|
||||
self.finish.notified().await;
|
||||
Ok("old".to_string().into())
|
||||
|
|
@ -437,7 +466,7 @@ mod tests {
|
|||
handle.flush_pending();
|
||||
|
||||
let h = handle.clone();
|
||||
let call = tokio::spawn(async move { h.call_tool("t", "{}").await });
|
||||
let call = tokio::spawn(async move { h.call_tool("t", "{}", Default::default()).await });
|
||||
|
||||
// Wait until the old tool is mid-execution.
|
||||
started.notified().await;
|
||||
|
|
@ -447,7 +476,11 @@ mod tests {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for NewTool {
|
||||
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
_input_json: &str,
|
||||
_ctx: crate::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
Ok("new".to_string().into())
|
||||
}
|
||||
}
|
||||
|
|
@ -469,7 +502,10 @@ mod tests {
|
|||
assert_eq!(result.expect("call").summary, "old");
|
||||
|
||||
// New calls use the replacement.
|
||||
let out = handle.call_tool("t", "{}").await.expect("call");
|
||||
let out = handle
|
||||
.call_tool("t", "{}", Default::default())
|
||||
.await
|
||||
.expect("call");
|
||||
assert_eq!(out.summary, "new");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,8 +26,8 @@ use crate::{
|
|||
timeline::event::{ErrorEvent, StatusEvent, UsageEvent},
|
||||
timeline::{TextBlockCollector, ThinkingBlockCollector, Timeline, ToolCallCollector},
|
||||
tool::{
|
||||
ToolCall, ToolDefinition as WorkerToolDefinition, ToolError, ToolOutputLimits, ToolResult,
|
||||
truncate_content,
|
||||
ToolCall, ToolDefinition as WorkerToolDefinition, ToolError, ToolExecutionContext,
|
||||
ToolOutputLimits, ToolResult, truncate_content,
|
||||
},
|
||||
tool_server::{ToolServer, ToolServerHandle},
|
||||
};
|
||||
|
|
@ -187,6 +187,10 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
|||
/// LlmCall count (per-Worker running counter, monotonic). Unlike
|
||||
/// `turn_count` this never collapses retries.
|
||||
llm_call_count: usize,
|
||||
/// Tool execution batch count (per-Worker running counter, monotonic).
|
||||
/// Each batch corresponds to one collected assistant tool-call set or one
|
||||
/// resumed pending tool-call set.
|
||||
tool_execution_batch_count: usize,
|
||||
/// Maximum number of AgentTurns (None = unlimited)
|
||||
max_turns: Option<u32>,
|
||||
/// AgentTurn-start callbacks (1:1 with LlmCall today)
|
||||
|
|
@ -912,19 +916,23 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
) -> Result<ToolExecutionResult, WorkerError> {
|
||||
use futures::future::join_all;
|
||||
|
||||
// Map from tool call ID to (ToolCall, Meta, Tool)
|
||||
// Map from tool call ID to (ToolCall, Meta, Tool, Context)
|
||||
// Retained because it's needed for PostToolCall hooks
|
||||
let mut call_info_map = HashMap::new();
|
||||
let mut synthetic_results = Vec::new();
|
||||
let batch_id = format!("tool-batch-{}", self.tool_execution_batch_count);
|
||||
self.tool_execution_batch_count += 1;
|
||||
|
||||
// Phase 1: Apply pre_tool_call interceptor (determine skip/abort/synthetic result)
|
||||
let mut approved_calls = Vec::new();
|
||||
for mut tool_call in tool_calls {
|
||||
for (call_index, mut tool_call) in tool_calls.into_iter().enumerate() {
|
||||
let context = ToolExecutionContext::new(&tool_call.id, &batch_id, call_index);
|
||||
if let Some((meta, tool)) = self.tool_server.get_tool(&tool_call.name) {
|
||||
let mut info = ToolCallInfo {
|
||||
call: tool_call.clone(),
|
||||
meta,
|
||||
tool,
|
||||
context,
|
||||
};
|
||||
|
||||
match self.interceptor.pre_tool_call(&mut info).await {
|
||||
|
|
@ -934,9 +942,11 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
PreToolAction::SyntheticResult(result) => {
|
||||
let tool_call = info.call;
|
||||
let mut context = info.context;
|
||||
context.call_id = tool_call.id.clone();
|
||||
call_info_map.insert(
|
||||
tool_call.id.clone(),
|
||||
(tool_call, info.meta.clone(), info.tool.clone()),
|
||||
(tool_call, info.meta.clone(), info.tool.clone(), context),
|
||||
);
|
||||
synthetic_results.push(result);
|
||||
continue;
|
||||
|
|
@ -953,26 +963,37 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
|
||||
// Reflect changes made by interceptor
|
||||
tool_call = info.call;
|
||||
let mut context = info.context;
|
||||
context.call_id = tool_call.id.clone();
|
||||
|
||||
call_info_map.insert(
|
||||
tool_call.id.clone(),
|
||||
(tool_call.clone(), info.meta.clone(), info.tool.clone()),
|
||||
(
|
||||
tool_call.clone(),
|
||||
info.meta.clone(),
|
||||
info.tool.clone(),
|
||||
context.clone(),
|
||||
),
|
||||
);
|
||||
approved_calls.push(tool_call);
|
||||
approved_calls.push((tool_call, context));
|
||||
} else {
|
||||
// Unknown tools go into approved list as-is (will error at execution)
|
||||
approved_calls.push(tool_call);
|
||||
let context = ToolExecutionContext::new(&tool_call.id, &batch_id, call_index);
|
||||
approved_calls.push((tool_call, context));
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: Execute approved tools in parallel (cancellable)
|
||||
let futures: Vec<_> = approved_calls
|
||||
.into_iter()
|
||||
.map(|tool_call| {
|
||||
.map(|(tool_call, context)| {
|
||||
let tool_server = self.tool_server.clone();
|
||||
async move {
|
||||
let input_json = serde_json::to_string(&tool_call.input).unwrap_or_default();
|
||||
match tool_server.call_tool(&tool_call.name, &input_json).await {
|
||||
match tool_server
|
||||
.call_tool(&tool_call.name, &input_json, context)
|
||||
.await
|
||||
{
|
||||
Ok(output) => ToolResult::from_output(&tool_call.id, output),
|
||||
Err(e) => ToolResult::error(&tool_call.id, e.to_string()),
|
||||
}
|
||||
|
|
@ -996,12 +1017,15 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
|
||||
// Phase 3: Apply post_tool_call interceptor
|
||||
for tool_result in &mut results {
|
||||
if let Some((tool_call, meta, tool)) = call_info_map.get(&tool_result.tool_use_id) {
|
||||
if let Some((tool_call, meta, tool, context)) =
|
||||
call_info_map.get(&tool_result.tool_use_id)
|
||||
{
|
||||
let mut info = ToolResultInfo {
|
||||
call: tool_call.clone(),
|
||||
result: tool_result.clone(),
|
||||
meta: meta.clone(),
|
||||
tool: tool.clone(),
|
||||
context: context.clone(),
|
||||
};
|
||||
|
||||
match self.interceptor.post_tool_call(&mut info).await {
|
||||
|
|
@ -1026,7 +1050,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
let Some(content) = tool_result.content.as_mut() else {
|
||||
continue;
|
||||
};
|
||||
let Some((tool_call, _, _)) = call_info_map.get(&tool_result.tool_use_id) else {
|
||||
let Some((tool_call, _, _, _)) = call_info_map.get(&tool_result.tool_use_id) else {
|
||||
continue;
|
||||
};
|
||||
let limit = limits.limit_for(&tool_call.name);
|
||||
|
|
@ -1628,6 +1652,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
locked_prefix_len: 0,
|
||||
turn_count: 0,
|
||||
llm_call_count: 0,
|
||||
tool_execution_batch_count: 0,
|
||||
max_turns: None,
|
||||
turn_start_cbs: Vec::new(),
|
||||
turn_end_cbs: Vec::new(),
|
||||
|
|
@ -1892,6 +1917,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
locked_prefix_len,
|
||||
turn_count: self.turn_count,
|
||||
llm_call_count: self.llm_call_count,
|
||||
tool_execution_batch_count: self.tool_execution_batch_count,
|
||||
max_turns: self.max_turns,
|
||||
turn_start_cbs: self.turn_start_cbs,
|
||||
turn_end_cbs: self.turn_end_cbs,
|
||||
|
|
@ -1984,6 +2010,7 @@ impl<C: LlmClient> Worker<C, Locked> {
|
|||
locked_prefix_len: 0,
|
||||
turn_count: self.turn_count,
|
||||
llm_call_count: self.llm_call_count,
|
||||
tool_execution_batch_count: self.tool_execution_batch_count,
|
||||
max_turns: self.max_turns,
|
||||
turn_start_cbs: self.turn_start_cbs,
|
||||
turn_end_cbs: self.turn_end_cbs,
|
||||
|
|
|
|||
|
|
@ -218,7 +218,11 @@ struct FixedOutputTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for FixedOutputTool {
|
||||
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
_input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
Ok(self.output.clone())
|
||||
}
|
||||
}
|
||||
|
|
@ -289,7 +293,11 @@ struct ErroringTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for ErroringTool {
|
||||
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
_input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
Err(ToolError::ExecutionFailed(self.message.clone()))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@
|
|||
//!
|
||||
//! Verify that Worker executes multiple tools in parallel.
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
|
@ -12,7 +12,9 @@ use llm_worker::interceptor::{
|
|||
Interceptor, PostToolAction, PreToolAction, ToolCallInfo, ToolResultInfo,
|
||||
};
|
||||
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
|
||||
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta, ToolOutput, ToolResult};
|
||||
use llm_worker::tool::{
|
||||
Tool, ToolDefinition, ToolError, ToolExecutionContext, ToolMeta, ToolOutput, ToolResult,
|
||||
};
|
||||
|
||||
mod common;
|
||||
use common::MockLlmClient;
|
||||
|
|
@ -59,13 +61,54 @@ impl SlowTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for SlowTool {
|
||||
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
_input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||
tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
|
||||
Ok(format!("Completed after {}ms", self.delay_ms).into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ContextRecordingTool {
|
||||
name: String,
|
||||
contexts: Arc<Mutex<Vec<ToolExecutionContext>>>,
|
||||
}
|
||||
|
||||
impl ContextRecordingTool {
|
||||
fn new(name: impl Into<String>, contexts: Arc<Mutex<Vec<ToolExecutionContext>>>) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
contexts,
|
||||
}
|
||||
}
|
||||
|
||||
fn definition(&self) -> ToolDefinition {
|
||||
let tool = self.clone();
|
||||
Arc::new(move || {
|
||||
let meta = ToolMeta::new(&tool.name)
|
||||
.description("Records tool execution context")
|
||||
.input_schema(serde_json::json!({"type": "object"}));
|
||||
(meta, Arc::new(tool.clone()) as Arc<dyn Tool>)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ContextRecordingTool {
|
||||
async fn execute(
|
||||
&self,
|
||||
_input_json: &str,
|
||||
ctx: ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
self.contexts.lock().unwrap().push(ctx);
|
||||
Ok("recorded".to_string().into())
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Tests
|
||||
// =============================================================================
|
||||
|
|
@ -92,10 +135,18 @@ async fn test_parallel_tool_execution() {
|
|||
}),
|
||||
];
|
||||
|
||||
let client = MockLlmClient::new(events);
|
||||
let client = MockLlmClient::with_responses(vec![
|
||||
events,
|
||||
vec![
|
||||
Event::text_block_start(0),
|
||||
Event::text_delta(0, "Done"),
|
||||
Event::text_block_stop(0, None),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
],
|
||||
]);
|
||||
let mut worker = Worker::new(client);
|
||||
|
||||
// Each tool waits 100ms
|
||||
let tool1 = SlowTool::new("slow_tool_1", 100);
|
||||
let tool2 = SlowTool::new("slow_tool_2", 100);
|
||||
let tool3 = SlowTool::new("slow_tool_3", 100);
|
||||
|
|
@ -129,7 +180,201 @@ async fn test_parallel_tool_execution() {
|
|||
println!("Parallel execution completed in {:?}", elapsed);
|
||||
}
|
||||
|
||||
/// Hook: pre_tool_call - verify that skipped tools are not executed
|
||||
#[tokio::test]
|
||||
async fn test_tool_execution_context_order_and_batch_id() {
|
||||
let client = MockLlmClient::with_responses(vec![
|
||||
vec![
|
||||
Event::tool_use_start(0, "call_a", "record_a"),
|
||||
Event::tool_input_delta(0, r#"{}"#),
|
||||
Event::tool_use_stop(0),
|
||||
Event::tool_use_start(1, "call_b", "record_b"),
|
||||
Event::tool_input_delta(1, r#"{}"#),
|
||||
Event::tool_use_stop(1),
|
||||
Event::tool_use_start(2, "call_c", "record_c"),
|
||||
Event::tool_input_delta(2, r#"{}"#),
|
||||
Event::tool_use_stop(2),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
],
|
||||
vec![
|
||||
Event::text_block_start(0),
|
||||
Event::text_delta(0, "Done"),
|
||||
Event::text_block_stop(0, None),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
],
|
||||
]);
|
||||
let mut worker = Worker::new(client);
|
||||
let contexts = Arc::new(Mutex::new(Vec::new()));
|
||||
|
||||
worker.register_tool(ContextRecordingTool::new("record_a", contexts.clone()).definition());
|
||||
worker.register_tool(ContextRecordingTool::new("record_b", contexts.clone()).definition());
|
||||
worker.register_tool(ContextRecordingTool::new("record_c", contexts.clone()).definition());
|
||||
|
||||
let _ = worker.run("record contexts").await;
|
||||
|
||||
let mut contexts = contexts.lock().unwrap().clone();
|
||||
contexts.sort_by_key(|ctx| ctx.call_index);
|
||||
|
||||
assert_eq!(contexts.len(), 3);
|
||||
assert_eq!(contexts[0].call_id, "call_a");
|
||||
assert_eq!(contexts[0].call_index, 0);
|
||||
assert_eq!(contexts[1].call_id, "call_b");
|
||||
assert_eq!(contexts[1].call_index, 1);
|
||||
assert_eq!(contexts[2].call_id, "call_c");
|
||||
assert_eq!(contexts[2].call_index, 2);
|
||||
assert_eq!(contexts[0].batch_id, contexts[1].batch_id);
|
||||
assert_eq!(contexts[1].batch_id, contexts[2].batch_id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tool_execution_context_batch_id_changes_between_batches() {
|
||||
let client = MockLlmClient::with_responses(vec![
|
||||
vec![
|
||||
Event::tool_use_start(0, "call_first", "record"),
|
||||
Event::tool_input_delta(0, r#"{}"#),
|
||||
Event::tool_use_stop(0),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
],
|
||||
vec![
|
||||
Event::tool_use_start(0, "call_second", "record"),
|
||||
Event::tool_input_delta(0, r#"{}"#),
|
||||
Event::tool_use_stop(0),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
],
|
||||
vec![
|
||||
Event::text_block_start(0),
|
||||
Event::text_delta(0, "Done"),
|
||||
Event::text_block_stop(0, None),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
],
|
||||
]);
|
||||
let mut worker = Worker::new(client);
|
||||
let contexts = Arc::new(Mutex::new(Vec::new()));
|
||||
|
||||
worker.register_tool(ContextRecordingTool::new("record", contexts.clone()).definition());
|
||||
|
||||
let _ = worker.run("record batches").await;
|
||||
|
||||
let contexts = contexts.lock().unwrap().clone();
|
||||
assert_eq!(contexts.len(), 2);
|
||||
assert_eq!(contexts[0].call_id, "call_first");
|
||||
assert_eq!(contexts[0].call_index, 0);
|
||||
assert_eq!(contexts[1].call_id, "call_second");
|
||||
assert_eq!(contexts[1].call_index, 0);
|
||||
assert_ne!(contexts[0].batch_id, contexts[1].batch_id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tool_execution_context_for_skipped_and_synthetic_paths() {
|
||||
let client = MockLlmClient::with_responses(vec![
|
||||
vec![
|
||||
Event::tool_use_start(0, "call_run", "record"),
|
||||
Event::tool_input_delta(0, r#"{}"#),
|
||||
Event::tool_use_stop(0),
|
||||
Event::tool_use_start(1, "call_skip", "skip_tool"),
|
||||
Event::tool_input_delta(1, r#"{}"#),
|
||||
Event::tool_use_stop(1),
|
||||
Event::tool_use_start(2, "call_synth", "synthetic_tool"),
|
||||
Event::tool_input_delta(2, r#"{}"#),
|
||||
Event::tool_use_stop(2),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
],
|
||||
vec![
|
||||
Event::text_block_start(0),
|
||||
Event::text_delta(0, "Done"),
|
||||
Event::text_block_stop(0, None),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
],
|
||||
]);
|
||||
let mut worker = Worker::new(client);
|
||||
let executed_contexts = Arc::new(Mutex::new(Vec::new()));
|
||||
let pre_contexts = Arc::new(Mutex::new(Vec::new()));
|
||||
let post_contexts = Arc::new(Mutex::new(Vec::new()));
|
||||
|
||||
worker
|
||||
.register_tool(ContextRecordingTool::new("record", executed_contexts.clone()).definition());
|
||||
worker.register_tool(
|
||||
ContextRecordingTool::new("skip_tool", executed_contexts.clone()).definition(),
|
||||
);
|
||||
worker.register_tool(
|
||||
ContextRecordingTool::new("synthetic_tool", executed_contexts.clone()).definition(),
|
||||
);
|
||||
|
||||
struct ContextPolicy {
|
||||
pre_contexts: Arc<Mutex<Vec<ToolExecutionContext>>>,
|
||||
post_contexts: Arc<Mutex<Vec<ToolExecutionContext>>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Interceptor for ContextPolicy {
|
||||
async fn pre_tool_call(&self, info: &mut ToolCallInfo) -> PreToolAction {
|
||||
self.pre_contexts.lock().unwrap().push(info.context.clone());
|
||||
match info.call.name.as_str() {
|
||||
"skip_tool" => PreToolAction::Skip,
|
||||
"synthetic_tool" => PreToolAction::SyntheticResult(ToolResult::from_output(
|
||||
&info.call.id,
|
||||
ToolOutput::from("synthetic result".to_string()),
|
||||
)),
|
||||
_ => PreToolAction::Continue,
|
||||
}
|
||||
}
|
||||
|
||||
async fn post_tool_call(&self, info: &mut ToolResultInfo) -> PostToolAction {
|
||||
self.post_contexts
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push(info.context.clone());
|
||||
PostToolAction::Continue
|
||||
}
|
||||
}
|
||||
|
||||
worker.set_interceptor(ContextPolicy {
|
||||
pre_contexts: pre_contexts.clone(),
|
||||
post_contexts: post_contexts.clone(),
|
||||
});
|
||||
|
||||
let _ = worker.run("record skipped and synthetic contexts").await;
|
||||
|
||||
let mut pre_contexts = pre_contexts.lock().unwrap().clone();
|
||||
pre_contexts.sort_by_key(|ctx| ctx.call_index);
|
||||
assert_eq!(pre_contexts.len(), 3);
|
||||
assert_eq!(pre_contexts[0].call_id, "call_run");
|
||||
assert_eq!(pre_contexts[0].call_index, 0);
|
||||
assert_eq!(pre_contexts[1].call_id, "call_skip");
|
||||
assert_eq!(pre_contexts[1].call_index, 1);
|
||||
assert_eq!(pre_contexts[2].call_id, "call_synth");
|
||||
assert_eq!(pre_contexts[2].call_index, 2);
|
||||
assert_eq!(pre_contexts[0].batch_id, pre_contexts[1].batch_id);
|
||||
assert_eq!(pre_contexts[1].batch_id, pre_contexts[2].batch_id);
|
||||
|
||||
let executed_contexts = executed_contexts.lock().unwrap().clone();
|
||||
assert_eq!(executed_contexts.len(), 1);
|
||||
assert_eq!(executed_contexts[0].call_id, "call_run");
|
||||
assert_eq!(executed_contexts[0].call_index, 0);
|
||||
|
||||
let mut post_contexts = post_contexts.lock().unwrap().clone();
|
||||
post_contexts.sort_by_key(|ctx| ctx.call_index);
|
||||
assert_eq!(post_contexts.len(), 2);
|
||||
assert_eq!(post_contexts[0].call_id, "call_run");
|
||||
assert_eq!(post_contexts[0].call_index, 0);
|
||||
assert_eq!(post_contexts[1].call_id, "call_synth");
|
||||
assert_eq!(post_contexts[1].call_index, 2);
|
||||
assert_eq!(post_contexts[0].batch_id, post_contexts[1].batch_id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_before_tool_call_skip() {
|
||||
let events = vec![
|
||||
|
|
@ -220,7 +465,11 @@ async fn test_post_tool_call_modification() {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for SimpleTool {
|
||||
async fn execute(&self, _: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
_: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
Ok("Original Result".to_string().into())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
|
|||
use schemars;
|
||||
use serde;
|
||||
|
||||
use llm_worker::ToolExecutionContext;
|
||||
use llm_worker_macros::tool_registry;
|
||||
|
||||
// =============================================================================
|
||||
|
|
@ -42,6 +43,15 @@ impl SimpleContext {
|
|||
async fn get_prefix(&self) -> String {
|
||||
self.prefix.clone()
|
||||
}
|
||||
|
||||
/// Tool that observes execution context
|
||||
#[tool]
|
||||
async fn context_echo(&self, ctx: ToolExecutionContext, message: String) -> String {
|
||||
format!(
|
||||
"{}:{}:{}:{}",
|
||||
ctx.batch_id, ctx.call_index, ctx.call_id, message
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -74,7 +84,9 @@ async fn test_basic_tool_generation() {
|
|||
);
|
||||
|
||||
// Execution test
|
||||
let result = tool.execute(r#"{"message": "World"}"#).await;
|
||||
let result = tool
|
||||
.execute(r#"{"message": "World"}"#, Default::default())
|
||||
.await;
|
||||
assert!(result.is_ok(), "Should execute successfully");
|
||||
let output = result.unwrap();
|
||||
assert!(
|
||||
|
|
@ -97,7 +109,9 @@ async fn test_multiple_arguments() {
|
|||
|
||||
assert_eq!(meta.name, "add");
|
||||
|
||||
let result = tool.execute(r#"{"a": 10, "b": 20}"#).await;
|
||||
let result = tool
|
||||
.execute(r#"{"a": 10, "b": 20}"#, Default::default())
|
||||
.await;
|
||||
assert!(result.is_ok());
|
||||
let output = result.unwrap();
|
||||
assert!(
|
||||
|
|
@ -118,7 +132,7 @@ async fn test_no_arguments() {
|
|||
assert_eq!(meta.name, "get_prefix");
|
||||
|
||||
// Call with empty JSON object
|
||||
let result = tool.execute(r#"{}"#).await;
|
||||
let result = tool.execute(r#"{}"#, Default::default()).await;
|
||||
assert!(result.is_ok());
|
||||
let output = result.unwrap();
|
||||
assert!(
|
||||
|
|
@ -137,7 +151,9 @@ async fn test_invalid_arguments() {
|
|||
let (_, tool) = ctx.greet_definition()();
|
||||
|
||||
// Invalid JSON
|
||||
let result = tool.execute(r#"{"wrong_field": "value"}"#).await;
|
||||
let result = tool
|
||||
.execute(r#"{"wrong_field": "value"}"#, Default::default())
|
||||
.await;
|
||||
assert!(result.is_err(), "Should fail with invalid arguments");
|
||||
}
|
||||
|
||||
|
|
@ -175,7 +191,7 @@ async fn test_result_return_type_success() {
|
|||
let ctx = FallibleContext;
|
||||
let (_, tool) = ctx.validate_definition()();
|
||||
|
||||
let result = tool.execute(r#"{"value": 42}"#).await;
|
||||
let result = tool.execute(r#"{"value": 42}"#, Default::default()).await;
|
||||
assert!(result.is_ok(), "Should succeed for positive value");
|
||||
let output = result.unwrap();
|
||||
assert!(
|
||||
|
|
@ -190,7 +206,7 @@ async fn test_result_return_type_error() {
|
|||
let ctx = FallibleContext;
|
||||
let (_, tool) = ctx.validate_definition()();
|
||||
|
||||
let result = tool.execute(r#"{"value": -1}"#).await;
|
||||
let result = tool.execute(r#"{"value": -1}"#, Default::default()).await;
|
||||
assert!(result.is_err(), "Should fail for negative value");
|
||||
|
||||
let err = result.unwrap_err();
|
||||
|
|
@ -228,9 +244,9 @@ async fn test_sync_method() {
|
|||
let (_, tool) = ctx.increment_definition()();
|
||||
|
||||
// Execute 3 times
|
||||
let result1 = tool.execute(r#"{}"#).await;
|
||||
let result2 = tool.execute(r#"{}"#).await;
|
||||
let result3 = tool.execute(r#"{}"#).await;
|
||||
let result1 = tool.execute(r#"{}"#, Default::default()).await;
|
||||
let result2 = tool.execute(r#"{}"#, Default::default()).await;
|
||||
let result3 = tool.execute(r#"{}"#, Default::default()).await;
|
||||
|
||||
assert!(result1.is_ok());
|
||||
assert!(result2.is_ok());
|
||||
|
|
@ -240,6 +256,24 @@ async fn test_sync_method() {
|
|||
assert_eq!(ctx.counter.load(Ordering::SeqCst), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tool_macro_passes_execution_context() {
|
||||
let ctx = SimpleContext {
|
||||
prefix: "Test".to_string(),
|
||||
};
|
||||
let (_, tool) = ctx.context_echo_definition()();
|
||||
|
||||
let output = tool
|
||||
.execute(
|
||||
r#"{"message":"hello"}"#,
|
||||
ToolExecutionContext::new("call-ctx", "batch-ctx", 7),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(output.summary, "\"batch-ctx:7:call-ctx:hello\"");
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Test: ToolMeta Immutability
|
||||
// =============================================================================
|
||||
|
|
|
|||
|
|
@ -58,7 +58,11 @@ impl MockWeatherTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for MockWeatherTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
// Parse input
|
||||
|
|
|
|||
|
|
@ -136,7 +136,11 @@ impl CountingTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for CountingTool {
|
||||
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
_input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||
Ok(format!("{}-ok", self.name).into())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -54,7 +54,11 @@ struct WriteExtractedTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for WriteExtractedTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let payload: ExtractedPayload = serde_json::from_str(input_json).map_err(|e| {
|
||||
ToolError::InvalidArgument(format!("invalid write_extracted input: {e}"))
|
||||
})?;
|
||||
|
|
@ -122,7 +126,7 @@ mod tests {
|
|||
"requests": []
|
||||
})
|
||||
.to_string();
|
||||
let out = tool.execute(&input).await.unwrap();
|
||||
let out = tool.execute(&input, Default::default()).await.unwrap();
|
||||
assert!(out.summary.contains("decisions=1"));
|
||||
let payload = ctx.take_payload().unwrap();
|
||||
assert_eq!(payload.decisions.len(), 1);
|
||||
|
|
@ -137,7 +141,7 @@ mod tests {
|
|||
let first =
|
||||
serde_json::json!({"decisions": [], "discussions": [], "attempts": [], "requests": []})
|
||||
.to_string();
|
||||
tool.execute(&first).await.unwrap();
|
||||
tool.execute(&first, Default::default()).await.unwrap();
|
||||
|
||||
let second = serde_json::json!({
|
||||
"decisions": [],
|
||||
|
|
@ -146,7 +150,7 @@ mod tests {
|
|||
"requests": []
|
||||
})
|
||||
.to_string();
|
||||
tool.execute(&second).await.unwrap();
|
||||
tool.execute(&second, Default::default()).await.unwrap();
|
||||
|
||||
let payload = ctx.take_payload().unwrap();
|
||||
assert_eq!(payload.attempts.len(), 1);
|
||||
|
|
@ -157,7 +161,7 @@ mod tests {
|
|||
async fn invalid_json_returns_invalid_argument() {
|
||||
let ctx = Arc::new(ExtractWorkerContext::new());
|
||||
let tool: Arc<dyn Tool> = Arc::new(WriteExtractedTool { ctx: ctx.clone() });
|
||||
let res = tool.execute("not json").await;
|
||||
let res = tool.execute("not json", Default::default()).await;
|
||||
assert!(matches!(res, Err(ToolError::InvalidArgument(_))));
|
||||
assert!(ctx.take_payload().is_none());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,7 +29,11 @@ struct MemoryDeleteTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for MemoryDeleteTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: DeleteParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryDelete input: {e}")))?;
|
||||
let path = params
|
||||
|
|
@ -139,7 +143,10 @@ mod tests {
|
|||
|
||||
let (_, tool) = delete_tool(layout.clone())();
|
||||
let out = tool
|
||||
.execute(r#"{"kind":"decision","slug":"obsolete"}"#)
|
||||
.execute(
|
||||
r#"{"kind":"decision","slug":"obsolete"}"#,
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("Deleted"));
|
||||
|
|
|
|||
|
|
@ -47,7 +47,11 @@ struct EditTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for EditTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: EditParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryEdit input: {e}")))?;
|
||||
|
||||
|
|
@ -316,7 +320,10 @@ mod tests {
|
|||
"old_string": "body body",
|
||||
"new_string": "edited",
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("1 replacement"));
|
||||
let after = std::fs::read_to_string(&path).unwrap();
|
||||
assert!(after.contains("edited"));
|
||||
|
|
@ -335,7 +342,10 @@ mod tests {
|
|||
"old_string": "status: open\n",
|
||||
"new_string": "",
|
||||
});
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
let msg = format!("{err}");
|
||||
assert!(msg.contains("status") || msg.contains("missing"));
|
||||
|
||||
|
|
@ -354,7 +364,10 @@ mod tests {
|
|||
"old_string": "x",
|
||||
"new_string": "y",
|
||||
});
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::ExecutionFailed(_)));
|
||||
}
|
||||
|
||||
|
|
@ -369,7 +382,10 @@ mod tests {
|
|||
"old_string": "x",
|
||||
"new_string": "y",
|
||||
});
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -126,7 +126,11 @@ struct KnowledgeQueryTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for MemoryQueryTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: MemoryQueryParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryQuery input: {e}")))?;
|
||||
let needle = match params.query.as_deref() {
|
||||
|
|
@ -240,7 +244,11 @@ impl Tool for MemoryQueryTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for KnowledgeQueryTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: KnowledgeQueryParams = serde_json::from_str(input_json).map_err(|e| {
|
||||
ToolError::InvalidArgument(format!("invalid KnowledgeQuery input: {e}"))
|
||||
})?;
|
||||
|
|
@ -568,7 +576,10 @@ mod tests {
|
|||
write_decision(dir.path(), "beta", "no match here\n");
|
||||
let (_, tool) = memory_query_tool(layout, QueryConfig::default())();
|
||||
let inp = serde_json::json!({ "query": "ollama" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let records: Vec<OwnedMemoryRecord> = parse_records(&out);
|
||||
assert_eq!(records.len(), 1);
|
||||
assert_eq!(records[0].slug, "alpha");
|
||||
|
|
@ -596,7 +607,7 @@ mod tests {
|
|||
.unwrap();
|
||||
|
||||
let (_, tool) = memory_query_tool(layout, QueryConfig::default())();
|
||||
let out = tool.execute("{}").await.unwrap();
|
||||
let out = tool.execute("{}", Default::default()).await.unwrap();
|
||||
let records: Vec<OwnedMemoryRecord> = parse_records(&out);
|
||||
let mut slugs: Vec<&str> = records.iter().map(|r| r.slug.as_str()).collect();
|
||||
slugs.sort();
|
||||
|
|
@ -616,7 +627,10 @@ mod tests {
|
|||
.unwrap();
|
||||
let (_, tool) = memory_query_tool(layout, QueryConfig::default())();
|
||||
let inp = serde_json::json!({ "query": "needle" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let records: Vec<OwnedMemoryRecord> = parse_records(&out);
|
||||
assert_eq!(records.len(), 1);
|
||||
assert_eq!(records[0].slug, "summary");
|
||||
|
|
@ -633,7 +647,10 @@ mod tests {
|
|||
|
||||
let (_, tool) = memory_query_tool(layout, QueryConfig::default())();
|
||||
let inp = serde_json::json!({ "query": "needle" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let records: Vec<OwnedMemoryRecord> = parse_records(&out);
|
||||
assert!(records.is_empty(), "got records: {:?}", out.content);
|
||||
}
|
||||
|
|
@ -653,8 +670,14 @@ mod tests {
|
|||
let (_, memory_tool) = memory_query_tool(layout.clone(), QueryConfig::default())();
|
||||
let (_, knowledge_tool) = knowledge_query_tool(layout.clone(), QueryConfig::default())();
|
||||
let inp = serde_json::json!({ "query": "needle" });
|
||||
memory_tool.execute(&inp.to_string()).await.unwrap();
|
||||
knowledge_tool.execute(&inp.to_string()).await.unwrap();
|
||||
memory_tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
knowledge_tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let report = crate::usage::build_usage_report(&layout).unwrap();
|
||||
assert!(report.records.is_empty());
|
||||
|
|
@ -673,7 +696,10 @@ mod tests {
|
|||
};
|
||||
let (_, tool) = memory_query_tool(layout, cfg)();
|
||||
let inp = serde_json::json!({ "query": "needle" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let records: Vec<OwnedMemoryRecord> = parse_records(&out);
|
||||
assert_eq!(records.len(), 3);
|
||||
}
|
||||
|
|
@ -692,7 +718,10 @@ mod tests {
|
|||
};
|
||||
let (_, tool) = memory_query_tool(layout, cfg)();
|
||||
let inp = serde_json::json!({ "query": "needle" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let records: Vec<OwnedMemoryRecord> = parse_records(&out);
|
||||
assert_eq!(records.len(), 1);
|
||||
let e = records[0].excerpt.as_deref().unwrap();
|
||||
|
|
@ -708,7 +737,10 @@ mod tests {
|
|||
let (_dir, layout) = setup();
|
||||
let (_, tool) = memory_query_tool(layout, QueryConfig::default())();
|
||||
let inp = serde_json::json!({ "query": " " });
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
|
||||
|
|
@ -724,7 +756,10 @@ mod tests {
|
|||
);
|
||||
let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())();
|
||||
let inp = serde_json::json!({ "query": "ollama" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let records: Vec<OwnedKnowledgeRecord> = parse_records(&out);
|
||||
assert_eq!(records.len(), 1);
|
||||
assert_eq!(records[0].slug, "policy");
|
||||
|
|
@ -748,7 +783,7 @@ mod tests {
|
|||
write_knowledge(dir.path(), "h1", "howto", "d2", "body\n");
|
||||
|
||||
let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())();
|
||||
let out = tool.execute("{}").await.unwrap();
|
||||
let out = tool.execute("{}", Default::default()).await.unwrap();
|
||||
let records: Vec<OwnedKnowledgeRecord> = parse_records(&out);
|
||||
let mut slugs: Vec<&str> = records.iter().map(|r| r.slug.as_str()).collect();
|
||||
slugs.sort();
|
||||
|
|
@ -764,7 +799,10 @@ mod tests {
|
|||
|
||||
let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())();
|
||||
let inp = serde_json::json!({ "query": "needle", "kind": "howto" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let records: Vec<OwnedKnowledgeRecord> = parse_records(&out);
|
||||
assert_eq!(records.len(), 1);
|
||||
assert_eq!(records[0].slug, "h1");
|
||||
|
|
@ -778,7 +816,10 @@ mod tests {
|
|||
|
||||
let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())();
|
||||
let inp = serde_json::json!({ "kind": "howto" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let records: Vec<OwnedKnowledgeRecord> = parse_records(&out);
|
||||
assert_eq!(records.len(), 1);
|
||||
assert_eq!(records[0].slug, "h1");
|
||||
|
|
@ -792,7 +833,10 @@ mod tests {
|
|||
|
||||
let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())();
|
||||
let inp = serde_json::json!({ "query": "xyzzy" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let records: Vec<OwnedKnowledgeRecord> = parse_records(&out);
|
||||
assert_eq!(records.len(), 1);
|
||||
assert_eq!(records[0].slug, "p");
|
||||
|
|
@ -804,7 +848,10 @@ mod tests {
|
|||
write_knowledge(dir.path(), "p", "policy", "d", "no match\n");
|
||||
let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())();
|
||||
let inp = serde_json::json!({ "query": "absent" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let records: Vec<OwnedKnowledgeRecord> = parse_records(&out);
|
||||
assert!(records.is_empty());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -45,7 +45,11 @@ struct ReadTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for ReadTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: ReadParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryRead input: {e}")))?;
|
||||
|
||||
|
|
@ -225,7 +229,10 @@ mod tests {
|
|||
|
||||
let (_meta, tool) = read_tool(layout)();
|
||||
let inp = serde_json::json!({ "kind": "decision", "slug": "foo" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap();
|
||||
assert!(body.contains(" 1\talpha"));
|
||||
assert!(body.contains(" 2\tbeta"));
|
||||
|
|
@ -240,7 +247,10 @@ mod tests {
|
|||
|
||||
let (_, tool) = read_tool(layout)();
|
||||
let inp = serde_json::json!({ "kind": "summary" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.content.unwrap().contains("summary body"));
|
||||
}
|
||||
|
||||
|
|
@ -249,7 +259,10 @@ mod tests {
|
|||
let (_dir, layout) = setup();
|
||||
let (_, tool) = read_tool(layout)();
|
||||
let inp = serde_json::json!({ "kind": "summary", "slug": "x" });
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
|
||||
|
|
@ -258,7 +271,10 @@ mod tests {
|
|||
let (_dir, layout) = setup();
|
||||
let (_, tool) = read_tool(layout)();
|
||||
let inp = serde_json::json!({ "kind": "decision" });
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
|
||||
|
|
@ -267,7 +283,10 @@ mod tests {
|
|||
let (_dir, layout) = setup();
|
||||
let (_, tool) = read_tool(layout)();
|
||||
let inp = serde_json::json!({ "kind": "decision", "slug": "Bad-Slug" });
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
|
||||
|
|
@ -280,7 +299,10 @@ mod tests {
|
|||
|
||||
let (_, tool) = read_tool(layout)();
|
||||
let inp = serde_json::json!({ "kind": "knowledge", "slug": "policy" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.content.unwrap().contains("k"));
|
||||
}
|
||||
|
||||
|
|
@ -293,7 +315,9 @@ mod tests {
|
|||
|
||||
let (_, tool) = read_tool_with_usage(layout.clone(), "session-1")();
|
||||
let inp = serde_json::json!({ "kind": "decision", "slug": "foo" });
|
||||
tool.execute(&inp.to_string()).await.unwrap();
|
||||
tool.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let report = usage::build_usage_report(&layout).unwrap();
|
||||
assert_eq!(report.records.len(), 1);
|
||||
|
|
@ -310,7 +334,10 @@ mod tests {
|
|||
let (_dir, layout) = setup();
|
||||
let (_, tool) = read_tool(layout)();
|
||||
let inp = serde_json::json!({ "kind": "decision", "slug": "missing" });
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::ExecutionFailed(_)));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,7 +42,11 @@ struct WriteTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for WriteTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: WriteParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryWrite input: {e}")))?;
|
||||
|
||||
|
|
@ -229,7 +233,10 @@ mod tests {
|
|||
"kind": "summary",
|
||||
"content": content,
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("Created"));
|
||||
assert!(path.exists());
|
||||
}
|
||||
|
|
@ -249,7 +256,10 @@ mod tests {
|
|||
"slug": "foo",
|
||||
"content": content,
|
||||
});
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
let msg = format!("{err}");
|
||||
assert!(msg.contains("status") || msg.contains("missing"), "{msg}");
|
||||
}
|
||||
|
|
@ -271,7 +281,10 @@ mod tests {
|
|||
"slug": "foo",
|
||||
"content": initial,
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("Overwrote"));
|
||||
}
|
||||
|
||||
|
|
@ -283,7 +296,10 @@ mod tests {
|
|||
"kind": "decision",
|
||||
"content": "ignored",
|
||||
});
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
|
||||
|
|
@ -298,7 +314,11 @@ mod tests {
|
|||
"slug": "foo",
|
||||
"content": bad,
|
||||
});
|
||||
assert!(tool.execute(&inp.to_string()).await.is_err());
|
||||
assert!(
|
||||
tool.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
assert!(!path.exists());
|
||||
}
|
||||
|
||||
|
|
@ -312,7 +332,10 @@ mod tests {
|
|||
"slug": "wf",
|
||||
"content": "---\n---\n",
|
||||
});
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -151,7 +151,11 @@ struct SearchSessionLogTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for SearchSessionLogTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: SearchSessionParams = serde_json::from_str(input_json).map_err(|e| {
|
||||
ToolError::InvalidArgument(format!("invalid search_session_log input: {e}"))
|
||||
})?;
|
||||
|
|
@ -206,7 +210,11 @@ struct ReadSessionItemsTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for ReadSessionItemsTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: ReadSessionParams = serde_json::from_str(input_json).map_err(|e| {
|
||||
ToolError::InvalidArgument(format!("invalid read_session_items input: {e}"))
|
||||
})?;
|
||||
|
|
@ -368,7 +376,11 @@ struct MarkReadRequiredTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for MarkReadRequiredTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: MarkParams = serde_json::from_str(input_json).map_err(|e| {
|
||||
ToolError::InvalidArgument(format!("invalid mark_read_required input: {e}"))
|
||||
})?;
|
||||
|
|
@ -425,7 +437,11 @@ struct AddReferenceTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for AddReferenceTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: ReferenceParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid add_reference input: {e}")))?;
|
||||
let mut guard = self.ctx.lock().expect("compact worker context poisoned");
|
||||
|
|
@ -449,7 +465,11 @@ struct WriteSummaryTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for WriteSummaryTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: SummaryParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid write_summary input: {e}")))?;
|
||||
let mut guard = self.ctx.lock().expect("compact worker context poisoned");
|
||||
|
|
@ -749,7 +769,7 @@ mod tests {
|
|||
ctx: ctx.clone(),
|
||||
});
|
||||
let input = serde_json::json!({ "file_path": path.to_str().unwrap() }).to_string();
|
||||
let out = tool.execute(&input).await.unwrap();
|
||||
let out = tool.execute(&input, Default::default()).await.unwrap();
|
||||
|
||||
assert!(out.summary.starts_with("Marked"));
|
||||
let guard = ctx.lock().unwrap();
|
||||
|
|
@ -770,7 +790,7 @@ mod tests {
|
|||
ctx: ctx.clone(),
|
||||
});
|
||||
let input = serde_json::json!({ "file_path": path.to_str().unwrap() }).to_string();
|
||||
let res = tool.execute(&input).await;
|
||||
let res = tool.execute(&input, Default::default()).await;
|
||||
|
||||
assert!(matches!(res, Err(ToolError::ExecutionFailed(_))));
|
||||
let guard = ctx.lock().unwrap();
|
||||
|
|
@ -784,11 +804,11 @@ mod tests {
|
|||
let tool: Arc<dyn Tool> = Arc::new(WriteSummaryTool { ctx: ctx.clone() });
|
||||
|
||||
let first = serde_json::json!({ "text": "first" }).to_string();
|
||||
let out1 = tool.execute(&first).await.unwrap();
|
||||
let out1 = tool.execute(&first, Default::default()).await.unwrap();
|
||||
assert!(out1.summary.contains("recorded"));
|
||||
|
||||
let second = serde_json::json!({ "text": "second" }).to_string();
|
||||
let out2 = tool.execute(&second).await.unwrap();
|
||||
let out2 = tool.execute(&second, Default::default()).await.unwrap();
|
||||
assert!(out2.summary.contains("replaced"));
|
||||
|
||||
assert_eq!(ctx.lock().unwrap().summary.as_deref(), Some("second"));
|
||||
|
|
@ -801,8 +821,8 @@ mod tests {
|
|||
|
||||
let p = "/abs/path.rs";
|
||||
let input = serde_json::json!({ "file_path": p }).to_string();
|
||||
tool.execute(&input).await.unwrap();
|
||||
tool.execute(&input).await.unwrap();
|
||||
tool.execute(&input, Default::default()).await.unwrap();
|
||||
tool.execute(&input, Default::default()).await.unwrap();
|
||||
|
||||
let guard = ctx.lock().unwrap();
|
||||
assert_eq!(guard.references.len(), 1);
|
||||
|
|
@ -823,7 +843,7 @@ mod tests {
|
|||
state: Arc::new(SessionLogToolState { items }),
|
||||
});
|
||||
let input = serde_json::json!({ "query": "compact", "limit": 10 }).to_string();
|
||||
let out = tool.execute(&input).await.unwrap();
|
||||
let out = tool.execute(&input, Default::default()).await.unwrap();
|
||||
let content = out.content.unwrap();
|
||||
|
||||
assert!(content.contains("investigate compact failure"));
|
||||
|
|
@ -842,7 +862,7 @@ mod tests {
|
|||
state: Arc::new(SessionLogToolState { items }),
|
||||
});
|
||||
let input = serde_json::json!({ "offset": 0, "limit": 1, "mode": "full" }).to_string();
|
||||
let out = tool.execute(&input).await.unwrap();
|
||||
let out = tool.execute(&input, Default::default()).await.unwrap();
|
||||
let content = out.content.unwrap();
|
||||
|
||||
assert!(content.contains("raw trace detail"));
|
||||
|
|
|
|||
|
|
@ -752,7 +752,11 @@ impl<St> Tool for ListPodsTool<St>
|
|||
where
|
||||
St: PodMetadataStore + Clone + Send + Sync + 'static,
|
||||
{
|
||||
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
_input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let items = self
|
||||
.discovery
|
||||
.list_visible()
|
||||
|
|
@ -775,7 +779,11 @@ impl<St> Tool for RestorePodTool<St>
|
|||
where
|
||||
St: PodMetadataStore + Clone + Send + Sync + 'static,
|
||||
{
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let input: PodNameInput = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid RestorePod input: {e}")))?;
|
||||
let result = self
|
||||
|
|
@ -847,7 +855,11 @@ impl<St> Tool for SendToPeerPodTool<St>
|
|||
where
|
||||
St: PodMetadataStore + Clone + Send + Sync + 'static,
|
||||
{
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let input: SendToPeerPodInput = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid SendToPeerPod input: {e}")))?;
|
||||
let detail = self
|
||||
|
|
@ -1392,7 +1404,7 @@ mod tests {
|
|||
|
||||
let (_, tool) = send_to_peer_pod_tool(discovery)();
|
||||
let output = tool
|
||||
.execute(r#"{"name":"target","message":"hello"}"#)
|
||||
.execute(r#"{"name":"target","message":"hello"}"#, Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(output.summary, "sent peer message to `target`");
|
||||
|
|
|
|||
|
|
@ -1292,7 +1292,11 @@ mod tests {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for DummyTool {
|
||||
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
_input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
Ok(ToolOutput::from("ok".to_string()))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -73,7 +73,11 @@ step: leave the task as-is, summarize the problem to the user, and end the turn.
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TaskCreateTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: TaskCreateParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid TaskCreate input: {e}")))?;
|
||||
let created = self.store.create(params.subject, params.description);
|
||||
|
|
@ -93,7 +97,11 @@ impl Tool for TaskCreateTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TaskListTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let _: TaskListParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid TaskList input: {e}")))?;
|
||||
let tasks = self.store.list();
|
||||
|
|
@ -106,7 +114,11 @@ impl Tool for TaskListTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TaskGetTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: TaskGetParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid TaskGet input: {e}")))?;
|
||||
let task = self.store.get(params.taskid).ok_or_else(|| {
|
||||
|
|
@ -122,7 +134,11 @@ impl Tool for TaskGetTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TaskUpdateTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: TaskUpdateParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid TaskUpdate input: {e}")))?;
|
||||
let updated = self
|
||||
|
|
@ -241,14 +257,20 @@ mod tests {
|
|||
let update = tool(task_update_tool(store.clone()));
|
||||
|
||||
let out = create
|
||||
.execute(r#"{"subject":"implement","description":"write code"}"#)
|
||||
.execute(
|
||||
r#"{"subject":"implement","description":"write code"}"#,
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("Created task 1"));
|
||||
assert_eq!(store.get(1).unwrap().status, TaskStatus::Pending);
|
||||
|
||||
let out = update
|
||||
.execute(r#"{"taskid":1,"status":"inprogress","subject":"implement tasks"}"#)
|
||||
.execute(
|
||||
r#"{"taskid":1,"status":"inprogress","subject":"implement tasks"}"#,
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("Updated task 1"));
|
||||
|
|
@ -256,11 +278,14 @@ mod tests {
|
|||
assert_eq!(task.status, TaskStatus::Inprogress);
|
||||
assert_eq!(task.subject, "implement tasks");
|
||||
|
||||
let out = get.execute(r#"{"taskid":1}"#).await.unwrap();
|
||||
let out = get
|
||||
.execute(r#"{"taskid":1}"#, Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("Task 1 (inprogress)"));
|
||||
assert!(out.content.unwrap().contains("implement tasks"));
|
||||
|
||||
let out = list.execute("{}").await.unwrap();
|
||||
let out = list.execute("{}", Default::default()).await.unwrap();
|
||||
assert!(out.summary.contains("1 task(s)"));
|
||||
let content = out.content.unwrap();
|
||||
assert!(content.contains("\"taskid\": 1"));
|
||||
|
|
@ -273,11 +298,14 @@ mod tests {
|
|||
store.create("s".into(), "d".into());
|
||||
let update = tool(task_update_tool(store));
|
||||
|
||||
let err = update.execute(r#"{"taskid":1}"#).await.unwrap_err();
|
||||
let err = update
|
||||
.execute(r#"{"taskid":1}"#, Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(err.to_string().contains("at least one"));
|
||||
|
||||
let err = update
|
||||
.execute(r#"{"taskid":99,"status":"deleted"}"#)
|
||||
.execute(r#"{"taskid":99,"status":"deleted"}"#, Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(err.to_string().contains("taskid 99 not found"));
|
||||
|
|
|
|||
|
|
@ -491,6 +491,7 @@ mod tests {
|
|||
},
|
||||
meta,
|
||||
tool,
|
||||
context: llm_worker::tool::ToolExecutionContext::new("call-id", "test-batch", 0),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -898,6 +899,7 @@ mod tests {
|
|||
),
|
||||
meta: info.meta,
|
||||
tool: info.tool,
|
||||
context: info.context,
|
||||
};
|
||||
|
||||
let action = interceptor.post_tool_call(&mut result_info).await;
|
||||
|
|
|
|||
|
|
@ -62,7 +62,11 @@ struct SendToPodTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for SendToPodTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let input: SendToPodInput = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid SendToPod input: {e}")))?;
|
||||
let record = self
|
||||
|
|
@ -123,7 +127,11 @@ struct ReadPodOutputTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for ReadPodOutputTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let input: NameInput = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid ReadPodOutput input: {e}")))?;
|
||||
let record = self
|
||||
|
|
@ -197,7 +205,11 @@ struct StopPodTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for StopPodTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let input: NameInput = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid StopPod input: {e}")))?;
|
||||
let record = self
|
||||
|
|
|
|||
|
|
@ -298,7 +298,11 @@ impl SpawnPodTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for SpawnPodTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let input: SpawnPodInput = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid SpawnPod input: {e}")))?;
|
||||
|
||||
|
|
|
|||
|
|
@ -1351,7 +1351,11 @@ struct HangingTool;
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for HangingTool {
|
||||
async fn execute(&self, _input: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
_input: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
std::future::pending::<()>().await;
|
||||
unreachable!()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -262,7 +262,7 @@ async fn send_to_pod_delivers_run_method() {
|
|||
let def = send_to_pod_tool(registry);
|
||||
let (_meta, tool) = def();
|
||||
let input = json!({ "name": "child", "message": "hello there" }).to_string();
|
||||
let output: ToolOutput = tool.execute(&input).await.unwrap();
|
||||
let output: ToolOutput = tool.execute(&input, Default::default()).await.unwrap();
|
||||
assert!(
|
||||
output.summary.contains("child"),
|
||||
"summary: {}",
|
||||
|
|
@ -285,7 +285,7 @@ async fn send_to_pod_errors_on_unknown_pod() {
|
|||
let def = send_to_pod_tool(registry);
|
||||
let (_meta, tool) = def();
|
||||
let input = json!({ "name": "nope", "message": "hi" }).to_string();
|
||||
let err = tool.execute(&input).await.unwrap_err();
|
||||
let err = tool.execute(&input, Default::default()).await.unwrap_err();
|
||||
assert!(err.to_string().contains("no spawned pod"), "{err}");
|
||||
}
|
||||
|
||||
|
|
@ -307,7 +307,7 @@ async fn send_to_pod_errors_when_pod_already_running() {
|
|||
let def = send_to_pod_tool(registry);
|
||||
let (_meta, tool) = def();
|
||||
let input = json!({ "name": "child", "message": "hi" }).to_string();
|
||||
let err = tool.execute(&input).await.unwrap_err();
|
||||
let err = tool.execute(&input, Default::default()).await.unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("already running"),
|
||||
"expected AlreadyRunning wording: {err}"
|
||||
|
|
@ -341,13 +341,13 @@ async fn read_pod_output_returns_new_assistant_text_then_empty_on_second_call()
|
|||
let (_meta, tool) = def();
|
||||
let input = json!({ "name": "child" }).to_string();
|
||||
|
||||
let first: ToolOutput = tool.execute(&input).await.unwrap();
|
||||
let first: ToolOutput = tool.execute(&input, Default::default()).await.unwrap();
|
||||
let body = first.content.expect("first read should have content");
|
||||
assert!(body.contains("hi back"), "body: {body}");
|
||||
assert!(body.contains("still working"), "body: {body}");
|
||||
|
||||
// Cursor now points past all items — second call returns no new text.
|
||||
let second: ToolOutput = tool.execute(&input).await.unwrap();
|
||||
let second: ToolOutput = tool.execute(&input, Default::default()).await.unwrap();
|
||||
assert!(
|
||||
second.content.is_none(),
|
||||
"unexpected content: {:?}",
|
||||
|
|
@ -371,7 +371,7 @@ async fn read_pod_output_reports_stopped_on_dead_socket() {
|
|||
let def = read_pod_output_tool(registry);
|
||||
let (_meta, tool) = def();
|
||||
let input = json!({ "name": "child" }).to_string();
|
||||
let output: ToolOutput = tool.execute(&input).await.unwrap();
|
||||
let output: ToolOutput = tool.execute(&input, Default::default()).await.unwrap();
|
||||
assert!(output.summary.contains("stopped"), "{}", output.summary);
|
||||
}
|
||||
|
||||
|
|
@ -452,7 +452,7 @@ async fn stop_pod_sends_shutdown_and_releases_scope() {
|
|||
let def = stop_pod_tool(registry.clone());
|
||||
let (_meta, tool) = def();
|
||||
let input = json!({ "name": "child" }).to_string();
|
||||
let output: ToolOutput = tool.execute(&input).await.unwrap();
|
||||
let output: ToolOutput = tool.execute(&input, Default::default()).await.unwrap();
|
||||
assert!(output.summary.contains("stopped"), "{}", output.summary);
|
||||
|
||||
// The child got a Shutdown.
|
||||
|
|
@ -497,7 +497,7 @@ async fn stop_pod_succeeds_even_when_child_unreachable() {
|
|||
let def = stop_pod_tool(registry.clone());
|
||||
let (_meta, tool) = def();
|
||||
let input = json!({ "name": "child" }).to_string();
|
||||
let output: ToolOutput = tool.execute(&input).await.unwrap();
|
||||
let output: ToolOutput = tool.execute(&input, Default::default()).await.unwrap();
|
||||
assert!(output.summary.contains("stopped"), "{}", output.summary);
|
||||
|
||||
// Registry no longer knows about the child.
|
||||
|
|
@ -545,7 +545,7 @@ async fn restored_registry_uses_pod_state_without_runtime_file() {
|
|||
let def = send_to_pod_tool(restored.clone());
|
||||
let (_meta, tool) = def();
|
||||
let input = json!({ "name": "child", "message": "after restart" }).to_string();
|
||||
tool.execute(&input).await.unwrap();
|
||||
tool.execute(&input, Default::default()).await.unwrap();
|
||||
match received.recv().await.expect("expected Run") {
|
||||
Method::Run { input } => match input.as_slice() {
|
||||
[protocol::Segment::Text { content }] => assert_eq!(content, "after restart"),
|
||||
|
|
@ -556,7 +556,7 @@ async fn restored_registry_uses_pod_state_without_runtime_file() {
|
|||
|
||||
let def = stop_pod_tool(restored.clone());
|
||||
let (_meta, tool) = def();
|
||||
tool.execute(&json!({ "name": "child" }).to_string())
|
||||
tool.execute(&json!({ "name": "child" }).to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
|
|
|
|||
|
|
@ -79,7 +79,11 @@ struct BigContentTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for BigContentTool {
|
||||
async fn execute(&self, _input: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
_input: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
Ok(ToolOutput {
|
||||
summary: self.summary.into(),
|
||||
content: Some(self.content.clone()),
|
||||
|
|
|
|||
|
|
@ -312,7 +312,7 @@ async fn spawn_pod_launches_runtime_in_workspace_and_passes_tool_cwd() {
|
|||
})
|
||||
.to_string();
|
||||
|
||||
tool.execute(&input).await.unwrap();
|
||||
tool.execute(&input, Default::default()).await.unwrap();
|
||||
assert!(matches!(received.await.unwrap(), Some(Method::Run { .. })));
|
||||
let invocation = read_recorded_runtime_invocation(&output_path).await;
|
||||
assert_eq!(invocation[0], allow_root.path().to_str().unwrap());
|
||||
|
|
@ -373,7 +373,7 @@ async fn spawn_pod_omitted_cwd_preserves_spawner_pwd() {
|
|||
})
|
||||
.to_string();
|
||||
|
||||
tool.execute(&input).await.unwrap();
|
||||
tool.execute(&input, Default::default()).await.unwrap();
|
||||
assert!(matches!(received.await.unwrap(), Some(Method::Run { .. })));
|
||||
let invocation = read_recorded_runtime_invocation(&output_path).await;
|
||||
assert_eq!(invocation[0], allow_root.path().to_str().unwrap());
|
||||
|
|
@ -433,7 +433,7 @@ async fn spawn_pod_delegates_scope_and_sends_run() {
|
|||
.is_writable(&allow_root.path().join("a.txt"))
|
||||
);
|
||||
|
||||
let output: ToolOutput = tool.execute(&input).await.unwrap();
|
||||
let output: ToolOutput = tool.execute(&input, Default::default()).await.unwrap();
|
||||
assert!(
|
||||
output.summary.contains("child"),
|
||||
"summary: {}",
|
||||
|
|
@ -519,7 +519,7 @@ async fn spawn_pod_requires_explicit_delegation_even_with_direct_scope() {
|
|||
})
|
||||
.to_string();
|
||||
|
||||
let err = tool.execute(&input).await.unwrap_err();
|
||||
let err = tool.execute(&input, Default::default()).await.unwrap_err();
|
||||
match err {
|
||||
ToolError::InvalidArgument(message) => {
|
||||
assert!(message.contains("no delegation scope grant"), "{message}");
|
||||
|
|
@ -587,7 +587,7 @@ async fn spawn_pod_rejects_child_non_recursive_scope_under_parent_non_recursive_
|
|||
})
|
||||
.to_string();
|
||||
|
||||
let err = tool.execute(&input).await.unwrap_err();
|
||||
let err = tool.execute(&input, Default::default()).await.unwrap_err();
|
||||
match err {
|
||||
ToolError::InvalidArgument(message) => {
|
||||
assert!(
|
||||
|
|
@ -639,7 +639,7 @@ async fn spawn_pod_rejects_scope_outside_spawner() {
|
|||
})
|
||||
.to_string();
|
||||
|
||||
let err = tool.execute(&input).await.unwrap_err();
|
||||
let err = tool.execute(&input, Default::default()).await.unwrap_err();
|
||||
match err {
|
||||
ToolError::InvalidArgument(msg) => {
|
||||
assert!(
|
||||
|
|
@ -712,7 +712,7 @@ async fn spawn_pod_rolls_back_reservation_when_socket_never_appears() {
|
|||
})
|
||||
.to_string();
|
||||
|
||||
let err = tool.execute(&input).await.unwrap_err();
|
||||
let err = tool.execute(&input, Default::default()).await.unwrap_err();
|
||||
match err {
|
||||
ToolError::ExecutionFailed(msg) => {
|
||||
assert!(
|
||||
|
|
|
|||
|
|
@ -54,7 +54,11 @@ struct MockWeatherTool;
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for MockWeatherTool {
|
||||
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
_input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
Ok("Sunny, 25C".to_string().into())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -562,7 +562,11 @@ struct TicketDoctorTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TicketCreateTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: TicketCreateParams = parse_input("TicketCreate", input_json)?;
|
||||
let mut input = NewTicket::new(params.title);
|
||||
if let Some(body) = params.body {
|
||||
|
|
@ -594,7 +598,11 @@ impl Tool for TicketCreateTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TicketListTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: TicketListParams = parse_input("TicketList", input_json)?;
|
||||
let state = params.state.unwrap_or(TicketListStateParam::All);
|
||||
let (filter, state_filter) = state.as_filter();
|
||||
|
|
@ -629,7 +637,11 @@ impl Tool for TicketListTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TicketShowTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: TicketShowParams = parse_input("TicketShow", input_json)?;
|
||||
let query = id_or_query(params.id, params.query)?;
|
||||
let event_limit = bounded(params.event_limit, DEFAULT_EVENT_LIMIT, MAX_EVENT_LIMIT);
|
||||
|
|
@ -661,7 +673,11 @@ impl Tool for TicketShowTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TicketCommentTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: TicketCommentParams = parse_input("TicketComment", input_json)?;
|
||||
let kind = match params.role {
|
||||
TicketCommentRoleParam::Comment => TicketEventKind::Comment,
|
||||
|
|
@ -684,7 +700,11 @@ impl Tool for TicketCommentTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TicketReviewTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: TicketReviewParams = parse_input("TicketReview", input_json)?;
|
||||
let result = match params.result {
|
||||
TicketReviewResultParam::Approve => TicketReviewResult::Approve,
|
||||
|
|
@ -708,7 +728,11 @@ impl Tool for TicketReviewTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TicketIntakeReadyTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: TicketIntakeReadyParams = parse_input("TicketIntakeReady", input_json)?;
|
||||
let from = TicketWorkflowState::Planning;
|
||||
let reason = params
|
||||
|
|
@ -743,7 +767,11 @@ impl Tool for TicketIntakeReadyTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TicketWorkflowStateTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: TicketWorkflowStateParams = parse_input("TicketWorkflowState", input_json)?;
|
||||
let from = params.from.into_state();
|
||||
let to = params.to.into_state();
|
||||
|
|
@ -778,7 +806,11 @@ impl Tool for TicketWorkflowStateTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TicketCloseTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: TicketCloseParams = parse_input("TicketClose", input_json)?;
|
||||
self.backend
|
||||
.close(
|
||||
|
|
@ -795,7 +827,11 @@ impl Tool for TicketCloseTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TicketRelationRecordTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: TicketRelationRecordParams = parse_input("TicketRelationRecord", input_json)?;
|
||||
let relation = NewTicketRelation {
|
||||
kind: params.kind.into_kind(),
|
||||
|
|
@ -819,7 +855,11 @@ impl Tool for TicketRelationRecordTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TicketRelationQueryTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: TicketRelationQueryParams = parse_input("TicketRelationQuery", input_json)?;
|
||||
let limit = bounded(params.limit, DEFAULT_LIST_LIMIT, MAX_LIST_LIMIT);
|
||||
let ticket = params.ticket.clone().map(TicketIdOrSlug::Id);
|
||||
|
|
@ -853,7 +893,11 @@ impl Tool for TicketRelationQueryTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TicketOrchestrationPlanRecordTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: TicketOrchestrationPlanRecordParams =
|
||||
parse_input("TicketOrchestrationPlanRecord", input_json)?;
|
||||
let accepted_plan = params.accepted_plan.map(|plan| AcceptedOrchestrationPlan {
|
||||
|
|
@ -885,7 +929,11 @@ impl Tool for TicketOrchestrationPlanRecordTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TicketOrchestrationPlanQueryTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: TicketOrchestrationPlanQueryParams =
|
||||
parse_input("TicketOrchestrationPlanQuery", input_json)?;
|
||||
let limit = bounded(params.limit, DEFAULT_LIST_LIMIT, MAX_LIST_LIMIT);
|
||||
|
|
@ -922,7 +970,11 @@ impl Tool for TicketOrchestrationPlanQueryTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for TicketDoctorTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: TicketDoctorParams = parse_input("TicketDoctor", input_json)?;
|
||||
let limit = bounded(params.limit, DEFAULT_DIAGNOSTIC_LIMIT, MAX_DIAGNOSTIC_LIMIT);
|
||||
let report = self
|
||||
|
|
@ -1377,6 +1429,7 @@ mod tests {
|
|||
"body": "## Background\n\nCreated by tool.\n"
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -1388,7 +1441,10 @@ mod tests {
|
|||
assert!(!created_text.contains("needs_preflight"));
|
||||
|
||||
let listed = list
|
||||
.execute(&json!({ "state": "planning" }).to_string())
|
||||
.execute(
|
||||
&json!({ "state": "planning" }).to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(listed.summary.contains("Listed 1 ticket"));
|
||||
|
|
@ -1398,7 +1454,10 @@ mod tests {
|
|||
assert!(!listed_content.contains("needs_preflight"));
|
||||
|
||||
let shown = show
|
||||
.execute(&json!({ "id": id, "event_limit": 10 }).to_string())
|
||||
.execute(
|
||||
&json!({ "id": id, "event_limit": 10 }).to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(shown.summary.contains(&id));
|
||||
|
|
@ -1407,7 +1466,10 @@ mod tests {
|
|||
assert!(!shown_content.contains("legacy_ticket"));
|
||||
assert!(!shown_content.contains("needs_preflight"));
|
||||
|
||||
let report = doctor.execute(&json!({}).to_string()).await.unwrap();
|
||||
let report = doctor
|
||||
.execute(&json!({}).to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(report.summary.contains("0 error(s)"));
|
||||
}
|
||||
|
||||
|
|
@ -1431,6 +1493,7 @@ mod tests {
|
|||
"author": "test"
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -1440,7 +1503,10 @@ mod tests {
|
|||
assert_eq!(recorded_json["target"], target.id);
|
||||
|
||||
let queried = query
|
||||
.execute(&json!({ "ticket": target.id.clone() }).to_string())
|
||||
.execute(
|
||||
&json!({ "ticket": target.id.clone() }).to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let queried_json: Value = serde_json::from_str(&queried.content.unwrap()).unwrap();
|
||||
|
|
@ -1448,7 +1514,10 @@ mod tests {
|
|||
assert_eq!(queried_json["relations"][0]["ticket_id"], source.id);
|
||||
|
||||
let shown = show
|
||||
.execute(&json!({ "id": target.id.clone() }).to_string())
|
||||
.execute(
|
||||
&json!({ "id": target.id.clone() }).to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let shown_json: Value = serde_json::from_str(&shown.content.unwrap()).unwrap();
|
||||
|
|
@ -1476,6 +1545,7 @@ mod tests {
|
|||
"body": "Implemented."
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -1487,6 +1557,7 @@ mod tests {
|
|||
"body": "Looks good."
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -1494,11 +1565,15 @@ mod tests {
|
|||
.execute(
|
||||
&json!({ "ticket": created.id, "resolution": "Done via TicketClose.\n" })
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let report = doctor.execute(&json!({}).to_string()).await.unwrap();
|
||||
let report = doctor
|
||||
.execute(&json!({}).to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(report.summary.contains("0 error(s)"));
|
||||
let closed = backend.show(TicketIdOrSlug::Id(created.id)).unwrap();
|
||||
assert!(closed.resolution.is_some());
|
||||
|
|
@ -1538,6 +1613,7 @@ mod tests {
|
|||
"author": "intake-pod"
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -1555,6 +1631,7 @@ mod tests {
|
|||
"author": "orchestrator"
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -1569,6 +1646,7 @@ mod tests {
|
|||
"author": "orchestrator"
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -1621,6 +1699,7 @@ mod tests {
|
|||
"author": "orchestrator"
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -1650,6 +1729,7 @@ mod tests {
|
|||
"author": "orchestrator"
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -1685,6 +1765,7 @@ mod tests {
|
|||
"body": "Should not apply.\n"
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
|
@ -1717,6 +1798,7 @@ mod tests {
|
|||
"body": "Should not bypass Queue.\n"
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
|
@ -1735,6 +1817,7 @@ mod tests {
|
|||
"body": "Should not move backwards.\n"
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
|
@ -1753,6 +1836,7 @@ mod tests {
|
|||
"body": "Should not skip inprogress.\n"
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
|
@ -1775,6 +1859,7 @@ mod tests {
|
|||
"intake_summary": "Should not rewrite ready ticket."
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
|
@ -1807,6 +1892,7 @@ mod tests {
|
|||
"author": "orchestrator"
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -1823,6 +1909,7 @@ mod tests {
|
|||
"relation_kind": "blocked_by"
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -1840,7 +1927,10 @@ mod tests {
|
|||
let temp = TempDir::new().unwrap();
|
||||
let show = tool_by_name(backend(&temp), "TicketShow");
|
||||
let error = show
|
||||
.execute(&json!({ "id": "a", "query": "b" }).to_string())
|
||||
.execute(
|
||||
&json!({ "id": "a", "query": "b" }).to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(error, ToolError::InvalidArgument(_)));
|
||||
|
|
@ -1852,7 +1942,10 @@ mod tests {
|
|||
let backend = backend(&temp);
|
||||
let create = tool_by_name(backend.clone(), "TicketCreate");
|
||||
let output = create
|
||||
.execute(&json!({ "title": "Escape" }).to_string())
|
||||
.execute(
|
||||
&json!({ "title": "Escape" }).to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let value: Value = serde_json::from_str(&output.content.unwrap()).unwrap();
|
||||
|
|
|
|||
|
|
@ -101,7 +101,11 @@ impl Drop for BashTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for BashTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: BashParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid Bash input: {e}")))?;
|
||||
let timeout_secs = params
|
||||
|
|
@ -394,7 +398,10 @@ mod tests {
|
|||
assert_eq!(meta.name, "Bash");
|
||||
|
||||
let inp = serde_json::json!({ "command": "echo hello" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(out.summary, "$ echo hello");
|
||||
assert_eq!(out.content.as_deref().map(str::trim), Some("hello"));
|
||||
}
|
||||
|
|
@ -407,7 +414,10 @@ mod tests {
|
|||
let inp = serde_json::json!({
|
||||
"command": "echo out; echo err 1>&2",
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap();
|
||||
assert!(body.contains("out"));
|
||||
assert!(body.contains("err"));
|
||||
|
|
@ -419,7 +429,10 @@ mod tests {
|
|||
let tool = make_tool(&h);
|
||||
|
||||
let inp = serde_json::json!({ "command": "exit 7" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("exit 7"), "summary: {}", out.summary);
|
||||
assert!(
|
||||
out.content.is_none(),
|
||||
|
|
@ -441,12 +454,16 @@ mod tests {
|
|||
"command": format!("cd {}", sub.to_str().unwrap()),
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let pwd_out = tool
|
||||
.execute(&serde_json::json!({ "command": "pwd" }).to_string())
|
||||
.execute(
|
||||
&serde_json::json!({ "command": "pwd" }).to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let body = pwd_out.content.unwrap();
|
||||
|
|
@ -467,7 +484,10 @@ mod tests {
|
|||
"command": "sleep 30",
|
||||
"timeout": 1,
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(
|
||||
out.summary.contains("timed out"),
|
||||
"summary: {}",
|
||||
|
|
@ -480,7 +500,10 @@ mod tests {
|
|||
let h = setup();
|
||||
let tool = make_tool(&h);
|
||||
|
||||
let err = tool.execute("not json").await.unwrap_err();
|
||||
let err = tool
|
||||
.execute("not json", Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
|
||||
|
|
@ -494,7 +517,10 @@ mod tests {
|
|||
let inp = serde_json::json!({
|
||||
"command": "for i in $(seq 1 200); do echo line $i; done",
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.expect("expected content");
|
||||
|
||||
assert!(
|
||||
|
|
@ -523,7 +549,10 @@ mod tests {
|
|||
let inp = serde_json::json!({
|
||||
"command": "printf 'x%.0s' {1..20480}",
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap();
|
||||
assert!(
|
||||
body.contains(spill_dir.to_str().unwrap()),
|
||||
|
|
@ -542,7 +571,10 @@ mod tests {
|
|||
"command": "(sleep 0.05; echo bg) &",
|
||||
"timeout": 5,
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(
|
||||
!out.summary.contains("timed out"),
|
||||
"summary: {}",
|
||||
|
|
@ -559,7 +591,9 @@ mod tests {
|
|||
let inp = serde_json::json!({
|
||||
"command": "for i in $(seq 1 200); do echo $i; done",
|
||||
});
|
||||
tool.execute(&inp.to_string()).await.unwrap();
|
||||
tool.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// The spill dir should now contain exactly one bash-*.log file.
|
||||
let files_before: Vec<_> = std::fs::read_dir(&spill_dir)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,11 @@ pub(crate) struct EditTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for EditTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: EditParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid Edit input: {e}")))?;
|
||||
|
||||
|
|
@ -169,7 +173,10 @@ mod tests {
|
|||
let def = read_tool(fs.clone(), tracker.clone());
|
||||
let (_, reader) = def();
|
||||
let inp = serde_json::json!({ "file_path": file.to_str().unwrap() });
|
||||
reader.execute(&inp.to_string()).await.unwrap();
|
||||
reader
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -188,7 +195,10 @@ mod tests {
|
|||
"old_string": "foo bar",
|
||||
"new_string": "foo baz",
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("1 replacement"));
|
||||
assert_eq!(
|
||||
std::fs::read_to_string(&file).unwrap(),
|
||||
|
|
@ -212,7 +222,10 @@ mod tests {
|
|||
"new_string": "y",
|
||||
"replace_all": true,
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("3 replacements"));
|
||||
assert_eq!(std::fs::read_to_string(&file).unwrap(), "y y y\n");
|
||||
}
|
||||
|
|
@ -231,7 +244,10 @@ mod tests {
|
|||
"old_string": "a",
|
||||
"new_string": "b",
|
||||
});
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
|
||||
|
|
@ -249,7 +265,10 @@ mod tests {
|
|||
"old_string": "world",
|
||||
"new_string": "x",
|
||||
});
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
|
||||
|
|
@ -266,7 +285,10 @@ mod tests {
|
|||
"old_string": "foo",
|
||||
"new_string": "bar",
|
||||
});
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
|
||||
|
|
@ -287,7 +309,10 @@ mod tests {
|
|||
"old_string": "foo",
|
||||
"new_string": "bar",
|
||||
});
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
let msg = format!("{err}");
|
||||
assert!(msg.contains("modified externally"), "{msg}");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,7 +35,11 @@ pub(crate) struct GlobTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for GlobTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: GlobParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid Glob input: {e}")))?;
|
||||
|
||||
|
|
@ -239,7 +243,10 @@ mod tests {
|
|||
assert_eq!(meta.name, "Glob");
|
||||
|
||||
let inp = serde_json::json!({ "pattern": "**/*.rs" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("2 file(s)"));
|
||||
let body = out.content.unwrap();
|
||||
assert!(body.contains("a.rs"));
|
||||
|
|
@ -261,7 +268,10 @@ mod tests {
|
|||
let def = glob_tool(fs);
|
||||
let (_, tool) = def();
|
||||
let inp = serde_json::json!({ "pattern": "*.rs" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap();
|
||||
let new_pos = body.find("new.rs").unwrap();
|
||||
let old_pos = body.find("old.rs").unwrap();
|
||||
|
|
@ -274,7 +284,10 @@ mod tests {
|
|||
let def = glob_tool(fs);
|
||||
let (_, tool) = def();
|
||||
let inp = serde_json::json!({ "pattern": "**/*.nonexistent" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("No files"));
|
||||
assert!(out.content.is_none());
|
||||
}
|
||||
|
|
@ -285,7 +298,10 @@ mod tests {
|
|||
let def = glob_tool(fs);
|
||||
let (_, tool) = def();
|
||||
let inp = serde_json::json!({ "pattern": "[unterminated" });
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
|
||||
|
|
@ -317,7 +333,10 @@ mod tests {
|
|||
let def = glob_tool(fs);
|
||||
let (_, tool) = def();
|
||||
let inp = serde_json::json!({ "pattern": "**/*.rs" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap_or_default();
|
||||
assert!(body.contains("visible.rs"));
|
||||
assert!(
|
||||
|
|
@ -335,7 +354,10 @@ mod tests {
|
|||
let def = glob_tool(fs);
|
||||
let (_, tool) = def();
|
||||
let inp = serde_json::json!({ "pattern": "*.rs" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap();
|
||||
assert!(body.contains(".hidden.rs"));
|
||||
assert!(body.contains("visible.rs"));
|
||||
|
|
@ -358,7 +380,10 @@ mod tests {
|
|||
"path": link.to_str().unwrap(),
|
||||
"pattern": "**/*.rs",
|
||||
});
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
let msg = format!("{err}");
|
||||
assert!(
|
||||
msg.contains("Glob does not follow symlink directories"),
|
||||
|
|
|
|||
|
|
@ -82,7 +82,11 @@ pub(crate) struct GrepTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for GrepTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: GrepParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid Grep input: {e}")))?;
|
||||
|
||||
|
|
@ -563,7 +567,10 @@ mod tests {
|
|||
let def = grep_tool(scoped);
|
||||
let (_, tool) = def();
|
||||
let inp = serde_json::json!({ "pattern": "needle" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap_or_default();
|
||||
assert!(body.contains("visible.txt"));
|
||||
assert!(
|
||||
|
|
@ -583,7 +590,10 @@ mod tests {
|
|||
assert_eq!(meta.name, "Grep");
|
||||
|
||||
let inp = serde_json::json!({ "pattern": "bravo" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("1 file"));
|
||||
assert!(out.content.unwrap().contains("a.txt"));
|
||||
}
|
||||
|
|
@ -599,7 +609,10 @@ mod tests {
|
|||
"pattern": "two",
|
||||
"output_mode": "content",
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap();
|
||||
assert!(body.contains(":2:two"));
|
||||
}
|
||||
|
|
@ -616,7 +629,10 @@ mod tests {
|
|||
"pattern": "x",
|
||||
"output_mode": "count",
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap();
|
||||
assert!(body.contains("a.txt:3"));
|
||||
assert!(body.contains("b.txt:1"));
|
||||
|
|
@ -635,7 +651,10 @@ mod tests {
|
|||
"-i": true,
|
||||
"output_mode": "content",
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.content.unwrap().contains("HELLO"));
|
||||
}
|
||||
|
||||
|
|
@ -654,7 +673,10 @@ mod tests {
|
|||
"output_mode": "content",
|
||||
"-C": 1,
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap();
|
||||
// should contain: line2 (before context), MATCH, line4 (after context)
|
||||
assert!(body.contains("line2"));
|
||||
|
|
@ -677,7 +699,10 @@ mod tests {
|
|||
"multiline": true,
|
||||
"output_mode": "content",
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap();
|
||||
assert!(body.contains("foo"));
|
||||
}
|
||||
|
|
@ -694,7 +719,10 @@ mod tests {
|
|||
"pattern": "target",
|
||||
"glob": "*.rs",
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap();
|
||||
assert!(body.contains("a.rs"));
|
||||
assert!(!body.contains("b.txt"));
|
||||
|
|
@ -712,7 +740,10 @@ mod tests {
|
|||
"pattern": "target",
|
||||
"type": "rust",
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap();
|
||||
assert!(body.contains("a.rs"));
|
||||
assert!(!body.contains("b.py"));
|
||||
|
|
@ -731,7 +762,10 @@ mod tests {
|
|||
"pattern": "x",
|
||||
"head_limit": 2,
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap();
|
||||
assert_eq!(body.lines().count(), 2);
|
||||
assert!(out.summary.contains("truncated at 2"));
|
||||
|
|
@ -752,7 +786,10 @@ mod tests {
|
|||
"offset": 3,
|
||||
"head_limit": 10,
|
||||
});
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap();
|
||||
// We skipped 3, so only 2 should remain.
|
||||
assert_eq!(body.lines().count(), 2);
|
||||
|
|
@ -769,7 +806,10 @@ mod tests {
|
|||
let def = grep_tool(fs);
|
||||
let (_, tool) = def();
|
||||
let inp = serde_json::json!({ "pattern": "needle" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap();
|
||||
assert!(body.contains("b.txt"));
|
||||
assert!(!body.contains("a.bin"));
|
||||
|
|
@ -781,7 +821,10 @@ mod tests {
|
|||
let def = grep_tool(fs);
|
||||
let (_, tool) = def();
|
||||
let inp = serde_json::json!({ "pattern": "(" });
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
|
||||
|
|
@ -794,7 +837,10 @@ mod tests {
|
|||
"pattern": "x",
|
||||
"type": "nonexistent",
|
||||
});
|
||||
let err = tool.execute(&inp.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
|
||||
|
|
@ -805,7 +851,10 @@ mod tests {
|
|||
let def = grep_tool(fs);
|
||||
let (_, tool) = def();
|
||||
let inp = serde_json::json!({ "pattern": "zzz" });
|
||||
let out = tool.execute(&inp.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&inp.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(out.summary, "No files matched");
|
||||
assert!(out.content.is_none());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,7 +36,11 @@ pub(crate) struct ReadTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for ReadTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: ReadParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid Read input: {e}")))?;
|
||||
let offset = params.offset.unwrap_or(0);
|
||||
|
|
@ -155,7 +159,10 @@ mod tests {
|
|||
assert_eq!(meta.name, "Read");
|
||||
|
||||
let input = serde_json::json!({ "file_path": file.to_str().unwrap() });
|
||||
let out = tool.execute(&input.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&input.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("Read 3 line(s)"));
|
||||
let body = out.content.unwrap();
|
||||
assert!(body.contains(" 1\talpha"));
|
||||
|
|
@ -178,7 +185,10 @@ mod tests {
|
|||
"offset": 1,
|
||||
"limit": 2,
|
||||
});
|
||||
let out = tool.execute(&input.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&input.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("[2..3] of 5"));
|
||||
let body = out.content.unwrap();
|
||||
assert!(body.contains(" 2\t2"));
|
||||
|
|
@ -193,7 +203,10 @@ mod tests {
|
|||
let input = serde_json::json!({
|
||||
"file_path": dir.path().join("nope.txt").to_str().unwrap()
|
||||
});
|
||||
let err = tool.execute(&input.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&input.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::ExecutionFailed(_)));
|
||||
}
|
||||
|
||||
|
|
@ -202,7 +215,10 @@ mod tests {
|
|||
let (_dir, fs, tracker) = setup();
|
||||
let def = read_tool(fs, tracker);
|
||||
let (_, tool) = def();
|
||||
let err = tool.execute("not json").await.unwrap_err();
|
||||
let err = tool
|
||||
.execute("not json", Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -146,7 +146,11 @@ struct WebFetchTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for WebSearchTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let input: WebSearchInput = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid WebSearch input: {e}")))?;
|
||||
self.web.run_search(input).await
|
||||
|
|
@ -193,7 +197,11 @@ impl WebTools {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for WebFetchTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let input: WebFetchInput = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid WebFetch input: {e}")))?;
|
||||
self.web.run_fetch(input).await
|
||||
|
|
|
|||
|
|
@ -30,7 +30,11 @@ pub(crate) struct WriteTool {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for WriteTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
|
||||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: WriteParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid Write input: {e}")))?;
|
||||
|
||||
|
|
@ -118,7 +122,10 @@ mod tests {
|
|||
"file_path": file.to_str().unwrap(),
|
||||
"content": "hello\n",
|
||||
});
|
||||
let out = tool.execute(&input.to_string()).await.unwrap();
|
||||
let out = tool
|
||||
.execute(&input.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("Created"));
|
||||
assert_eq!(std::fs::read_to_string(&file).unwrap(), "hello\n");
|
||||
}
|
||||
|
|
@ -135,7 +142,10 @@ mod tests {
|
|||
"file_path": file.to_str().unwrap(),
|
||||
"content": "new",
|
||||
});
|
||||
let err = tool.execute(&input.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&input.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
|
||||
|
|
@ -148,7 +158,10 @@ mod tests {
|
|||
let read_def = read_tool(fs.clone(), tracker.clone());
|
||||
let (_, reader) = read_def();
|
||||
let read_in = serde_json::json!({ "file_path": file.to_str().unwrap() });
|
||||
reader.execute(&read_in.to_string()).await.unwrap();
|
||||
reader
|
||||
.execute(&read_in.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let write_def = write_tool(fs, tracker);
|
||||
let (_, writer) = write_def();
|
||||
|
|
@ -156,7 +169,10 @@ mod tests {
|
|||
"file_path": file.to_str().unwrap(),
|
||||
"content": "new\n",
|
||||
});
|
||||
let out = writer.execute(&write_in.to_string()).await.unwrap();
|
||||
let out = writer
|
||||
.execute(&write_in.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("Overwrote"));
|
||||
assert_eq!(std::fs::read_to_string(&file).unwrap(), "new\n");
|
||||
}
|
||||
|
|
@ -171,7 +187,10 @@ mod tests {
|
|||
let read_def = read_tool(fs.clone(), tracker.clone());
|
||||
let (_, reader) = read_def();
|
||||
reader
|
||||
.execute(&serde_json::json!({ "file_path": file.to_str().unwrap() }).to_string())
|
||||
.execute(
|
||||
&serde_json::json!({ "file_path": file.to_str().unwrap() }).to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -187,6 +206,7 @@ mod tests {
|
|||
"content": "new",
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
|
@ -205,7 +225,10 @@ mod tests {
|
|||
"file_path": outside.path().join("x.txt").to_str().unwrap(),
|
||||
"content": "x",
|
||||
});
|
||||
let err = tool.execute(&input.to_string()).await.unwrap_err();
|
||||
let err = tool
|
||||
.execute(&input.to_string(), Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,13 +66,17 @@ async fn unicode_path_and_content() {
|
|||
"content": content,
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let read = reg.get("Read");
|
||||
let out = read
|
||||
.execute(&json!({ "file_path": file.to_str().unwrap() }).to_string())
|
||||
.execute(
|
||||
&json!({ "file_path": file.to_str().unwrap() }).to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let body = out.content.unwrap();
|
||||
|
|
@ -98,7 +102,10 @@ async fn symlink_to_outside_scope_is_rejected_for_write() {
|
|||
// target sits outside the scope.
|
||||
let read = reg.get("Read");
|
||||
let read_err = read
|
||||
.execute(&json!({ "file_path": link.to_str().unwrap() }).to_string())
|
||||
.execute(
|
||||
&json!({ "file_path": link.to_str().unwrap() }).to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(
|
||||
|
|
@ -119,6 +126,7 @@ async fn symlink_to_outside_scope_is_rejected_for_write() {
|
|||
"content": "overwritten",
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
|
@ -147,7 +155,10 @@ async fn broken_symlink_reports_target_and_repair_hint() {
|
|||
|
||||
let read = reg.get("Read");
|
||||
let err = read
|
||||
.execute(&json!({ "file_path": link.to_str().unwrap() }).to_string())
|
||||
.execute(
|
||||
&json!({ "file_path": link.to_str().unwrap() }).to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
let msg = format!("{err}");
|
||||
|
|
@ -165,7 +176,10 @@ async fn empty_file_read_and_edit() {
|
|||
|
||||
let read = reg.get("Read");
|
||||
let out = read
|
||||
.execute(&json!({ "file_path": file.to_str().unwrap() }).to_string())
|
||||
.execute(
|
||||
&json!({ "file_path": file.to_str().unwrap() }).to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(out.summary.contains("0 line"));
|
||||
|
|
@ -180,6 +194,7 @@ async fn empty_file_read_and_edit() {
|
|||
"new_string": "bar",
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
|
@ -196,7 +211,10 @@ async fn very_long_single_line() {
|
|||
|
||||
let read = reg.get("Read");
|
||||
let out = read
|
||||
.execute(&json!({ "file_path": file.to_str().unwrap() }).to_string())
|
||||
.execute(
|
||||
&json!({ "file_path": file.to_str().unwrap() }).to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
// Should return exactly 1 line
|
||||
|
|
@ -208,7 +226,10 @@ async fn relative_path_is_rejected() {
|
|||
let (_dir, _spill, reg) = setup();
|
||||
let read = reg.get("Read");
|
||||
let err = read
|
||||
.execute(&json!({ "file_path": "relative.txt" }).to_string())
|
||||
.execute(
|
||||
&json!({ "file_path": "relative.txt" }).to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(format!("{err}").contains("absolute"));
|
||||
|
|
@ -219,7 +240,10 @@ async fn directory_target_is_rejected_for_read() {
|
|||
let (dir, _spill, reg) = setup();
|
||||
let read = reg.get("Read");
|
||||
let err = read
|
||||
.execute(&json!({ "file_path": dir.path().to_str().unwrap() }).to_string())
|
||||
.execute(
|
||||
&json!({ "file_path": dir.path().to_str().unwrap() }).to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(format!("{err}").contains("directory"));
|
||||
|
|
@ -237,6 +261,7 @@ async fn deeply_nested_new_file_is_created() {
|
|||
"content": "deep\n",
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -250,9 +275,12 @@ async fn replace_preserves_unicode() {
|
|||
std::fs::write(&file, "🦀 rust 🦀\n").unwrap();
|
||||
|
||||
let read = reg.get("Read");
|
||||
read.execute(&json!({ "file_path": file.to_str().unwrap() }).to_string())
|
||||
.await
|
||||
.unwrap();
|
||||
read.execute(
|
||||
&json!({ "file_path": file.to_str().unwrap() }).to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let edit = reg.get("Edit");
|
||||
edit.execute(
|
||||
|
|
@ -262,6 +290,7 @@ async fn replace_preserves_unicode() {
|
|||
"new_string": "ラスト",
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -282,6 +311,7 @@ async fn grep_handles_unicode_pattern() {
|
|||
"output_mode": "content",
|
||||
})
|
||||
.to_string(),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
|
|||
|
|
@ -66,13 +66,13 @@ fn setup() -> (TempDir, TempDir, Registry) {
|
|||
}
|
||||
|
||||
async fn call(tool: &Arc<dyn Tool>, input: serde_json::Value) -> llm_worker::tool::ToolOutput {
|
||||
tool.execute(&input.to_string())
|
||||
tool.execute(&input.to_string(), Default::default())
|
||||
.await
|
||||
.expect("tool execution failed")
|
||||
}
|
||||
|
||||
async fn call_err(tool: &Arc<dyn Tool>, input: serde_json::Value) -> llm_worker::tool::ToolError {
|
||||
tool.execute(&input.to_string())
|
||||
tool.execute(&input.to_string(), Default::default())
|
||||
.await
|
||||
.expect_err("expected error")
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user