tool: add execution context

This commit is contained in:
Keisuke Hirata 2026-06-09 19:31:11 +09:00
parent b21fab82fc
commit d8aed7befe
No known key found for this signature in database
39 changed files with 1212 additions and 259 deletions

View File

@ -90,19 +90,27 @@ fn extract_doc_comment(attrs: &[Attribute]) -> String {
/// Extract description from #[description = "..."] attribute /// Extract description from #[description = "..."] attribute
fn extract_description_attr(attrs: &[syn::Attribute]) -> Option<String> { fn extract_description_attr(attrs: &[syn::Attribute]) -> Option<String> {
for attr in attrs { for attr in attrs {
if attr.path().is_ident("description") { if attr.path().is_ident("description")
if let Meta::NameValue(meta) = &attr.meta { && let Meta::NameValue(meta) = &attr.meta
if let syn::Expr::Lit(expr_lit) = &meta.value { && let syn::Expr::Lit(expr_lit) = &meta.value
if let Lit::Str(lit_str) = &expr_lit.lit { && let Lit::Str(lit_str) = &expr_lit.lit
{
return Some(lit_str.value()); return Some(lit_str.value());
} }
} }
}
}
}
None None
} }
fn is_tool_execution_context_type(ty: &Type) -> bool {
let Type::Path(path) = ty else {
return false;
};
path.path
.segments
.last()
.is_some_and(|segment| segment.ident == "ToolExecutionContext")
}
/// Generate Tool implementation from a method /// Generate Tool implementation from a method
fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::TokenStream { fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::TokenStream {
let sig = &method.sig; let sig = &method.sig;
@ -123,8 +131,10 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
description description
}; };
// Parse arguments (excluding self) // Parse method arguments (excluding self). A parameter typed as
let args: Vec<_> = sig // ToolExecutionContext is supplied from the execution context and is not
// exposed in the JSON input schema.
let method_args: Vec<_> = sig
.inputs .inputs
.iter() .iter()
.filter_map(|arg| { .filter_map(|arg| {
@ -135,9 +145,14 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
} }
}) })
.collect(); .collect();
let json_args: Vec<_> = method_args
.iter()
.copied()
.filter(|pat_type| !is_tool_execution_context_type(pat_type.ty.as_ref()))
.collect();
// Generate argument struct fields // Generate argument struct fields
let arg_fields: Vec<_> = args let arg_fields: Vec<_> = json_args
.iter() .iter()
.map(|pat_type| { .map(|pat_type| {
let pat = &pat_type.pat; let pat = &pat_type.pat;
@ -165,11 +180,13 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
}) })
.collect(); .collect();
// Code to expand arguments in execute // Code to expand method arguments in execute
let arg_names: Vec<_> = args let call_args: Vec<_> = method_args
.iter() .iter()
.map(|pat_type| { .map(|pat_type| {
if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() { if is_tool_execution_context_type(pat_type.ty.as_ref()) {
quote! { ctx.clone() }
} else if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
let ident = &pat_ident.ident; let ident = &pat_ident.ident;
quote! { args.#ident } quote! { args.#ident }
} else { } else {
@ -177,6 +194,11 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
} }
}) })
.collect(); .collect();
let method_call = if call_args.is_empty() {
quote! { self.ctx.#method_name() }
} else {
quote! { self.ctx.#method_name(#(#call_args),*) }
};
// Check if method is async // Check if method is async
let is_async = sig.asyncness.is_some(); let is_async = sig.asyncness.is_some();
@ -218,13 +240,13 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
}; };
// Execute body handling for no arguments case // Execute body handling for no arguments case
let execute_body = if args.is_empty() { let execute_body = if json_args.is_empty() {
quote! { quote! {
// Allow empty JSON object even with no arguments // Allow empty JSON object even with no JSON arguments
let _: #args_struct_name = serde_json::from_str(input_json) let _: #args_struct_name = serde_json::from_str(input_json)
.unwrap_or(#args_struct_name {}); .unwrap_or(#args_struct_name {});
let result = self.ctx.#method_name()#awaiter; let result = #method_call #awaiter;
#result_handling #result_handling
} }
} else { } else {
@ -232,7 +254,7 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
let args: #args_struct_name = serde_json::from_str(input_json) let args: #args_struct_name = serde_json::from_str(input_json)
.map_err(|e| ::llm_worker::tool::ToolError::InvalidArgument(e.to_string()))?; .map_err(|e| ::llm_worker::tool::ToolError::InvalidArgument(e.to_string()))?;
let result = self.ctx.#method_name(#(#arg_names),*)#awaiter; let result = #method_call #awaiter;
#result_handling #result_handling
} }
}; };
@ -247,7 +269,8 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
#[async_trait::async_trait] #[async_trait::async_trait]
impl ::llm_worker::tool::Tool for #tool_struct_name { impl ::llm_worker::tool::Tool for #tool_struct_name {
async fn execute(&self, input_json: &str) -> Result<::llm_worker::tool::ToolOutput, ::llm_worker::tool::ToolError> { async fn execute(&self, input_json: &str, ctx: ::llm_worker::tool::ToolExecutionContext) -> Result<::llm_worker::tool::ToolOutput, ::llm_worker::tool::ToolError> {
let _ = &ctx;
#execute_body #execute_body
} }
} }

View File

@ -10,7 +10,7 @@ use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use crate::Item; use crate::Item;
use crate::tool::{Tool, ToolCall, ToolMeta, ToolResult}; use crate::tool::{Tool, ToolCall, ToolExecutionContext, ToolMeta, ToolResult};
// ============================================================================= // =============================================================================
// Action Enums // Action Enums
@ -107,6 +107,8 @@ pub struct ToolCallInfo {
pub meta: ToolMeta, pub meta: ToolMeta,
/// Tool instance (for state access). /// Tool instance (for state access).
pub tool: Arc<dyn Tool>, pub tool: Arc<dyn Tool>,
/// Response-local execution context for this call.
pub context: ToolExecutionContext,
} }
/// Context for post-tool-call decisions. /// Context for post-tool-call decisions.
@ -119,6 +121,8 @@ pub struct ToolResultInfo {
pub meta: ToolMeta, pub meta: ToolMeta,
/// Tool instance (for state access). /// Tool instance (for state access).
pub tool: Arc<dyn Tool>, pub tool: Arc<dyn Tool>,
/// Response-local execution context for this call.
pub context: ToolExecutionContext,
} }
// ============================================================================= // =============================================================================

View File

@ -57,7 +57,7 @@ pub use callback::{TextBlockScope, ThinkingBlockScope, ToolUseBlockScope};
pub use handler::ToolUseBlockStart; pub use handler::ToolUseBlockStart;
pub use interceptor::Interceptor; pub use interceptor::Interceptor;
pub use message::{ContentPart, Item, Message, Role}; pub use message::{ContentPart, Item, Message, Role};
pub use tool::{ToolCall, ToolOutputLimits, ToolResult}; pub use tool::{ToolCall, ToolExecutionContext, ToolOutputLimits, ToolResult};
pub use usage_record::UsageRecord; pub use usage_record::UsageRecord;
pub use worker::{ pub use worker::{
LlmRetryNotice, RunOutput, ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult, LlmRetryNotice, RunOutput, ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult,

View File

@ -189,6 +189,44 @@ impl ToolMeta {
/// ``` /// ```
pub type ToolDefinition = Arc<dyn Fn() -> (ToolMeta, Arc<dyn Tool>) + Send + Sync>; pub type ToolDefinition = Arc<dyn Fn() -> (ToolMeta, Arc<dyn Tool>) + Send + Sync>;
/// Per-call context supplied by the worker when executing a tool call.
///
/// The context identifies a tool call within one assistant response's tool-call
/// batch without imposing any scheduling policy on the worker. Tool
/// implementations may use it for response-local ordering, diagnostics, or
/// correlation, but it is intentionally not a handle to worker state, history,
/// or session mutation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolExecutionContext {
/// Provider/tool-call id for the call being executed.
pub call_id: String,
/// Worker-local identity shared by all tool calls from one execution batch.
pub batch_id: String,
/// Zero-based order of this call in the model-returned tool-call list.
pub call_index: usize,
}
impl ToolExecutionContext {
pub fn new(call_id: impl Into<String>, batch_id: impl Into<String>, call_index: usize) -> Self {
Self {
call_id: call_id.into(),
batch_id: batch_id.into(),
call_index,
}
}
/// Context for direct, non-worker calls in unit tests and low-level callers.
pub fn direct() -> Self {
Self::new("direct", "direct", 0)
}
}
impl Default for ToolExecutionContext {
fn default() -> Self {
Self::direct()
}
}
// ============================================================================= // =============================================================================
// Tool trait // Tool trait
// ============================================================================= // =============================================================================
@ -219,16 +257,16 @@ pub type ToolDefinition = Arc<dyn Fn() -> (ToolMeta, Arc<dyn Tool>) + Send + Syn
/// # Manual Implementation /// # Manual Implementation
/// ///
/// ```ignore /// ```ignore
/// use llm_worker::tool::{Tool, ToolError, ToolMeta, ToolDefinition}; /// use llm_worker::tool::{Tool, ToolError, ToolExecutionContext, ToolMeta, ToolDefinition, ToolOutput};
/// use std::sync::Arc; /// use std::sync::Arc;
/// ///
/// struct MyTool { counter: std::sync::atomic::AtomicUsize } /// struct MyTool { counter: std::sync::atomic::AtomicUsize }
/// ///
/// #[async_trait::async_trait] /// #[async_trait::async_trait]
/// impl Tool for MyTool { /// impl Tool for MyTool {
/// async fn execute(&self, input: &str) -> Result<String, ToolError> { /// async fn execute(&self, input: &str, ctx: ToolExecutionContext) -> Result<ToolOutput, ToolError> {
/// self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst); /// self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
/// Ok("result".to_string()) /// Ok(format!("call {}: {}", ctx.call_index, input).into())
/// } /// }
/// } /// }
/// ///
@ -247,11 +285,16 @@ pub trait Tool: Send + Sync {
/// ///
/// # Arguments /// # Arguments
/// * `input_json` - JSON-formatted arguments generated by LLM /// * `input_json` - JSON-formatted arguments generated by LLM
/// * `ctx` - response-local call identity and ordering context
/// ///
/// # Returns /// # Returns
/// A [`ToolOutput`] with summary and optional detailed content. /// A [`ToolOutput`] with summary and optional detailed content.
/// For simple cases, use `From<String>`: `Ok("done".to_string().into())` /// For simple cases, use `From<String>`: `Ok("done".to_string().into())`
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError>; async fn execute(
&self,
input_json: &str,
ctx: ToolExecutionContext,
) -> Result<ToolOutput, ToolError>;
} }
// ============================================================================= // =============================================================================

View File

@ -4,7 +4,9 @@ use std::sync::{Arc, Mutex};
use thiserror::Error; use thiserror::Error;
use crate::llm_client::ToolDefinition as LlmToolDefinition; use crate::llm_client::ToolDefinition as LlmToolDefinition;
use crate::tool::{Tool, ToolDefinition as WorkerToolDefinition, ToolMeta, ToolOutput}; use crate::tool::{
Tool, ToolDefinition as WorkerToolDefinition, ToolExecutionContext, ToolMeta, ToolOutput,
};
type ToolMap = HashMap<String, (ToolMeta, Arc<dyn Tool>)>; type ToolMap = HashMap<String, (ToolMeta, Arc<dyn Tool>)>;
@ -117,6 +119,7 @@ impl ToolServerHandle {
&self, &self,
name: &str, name: &str,
input_json: &str, input_json: &str,
ctx: ToolExecutionContext,
) -> Result<ToolOutput, ToolServerError> { ) -> Result<ToolOutput, ToolServerError> {
let tool = { let tool = {
let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner()); let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
@ -125,7 +128,7 @@ impl ToolServerHandle {
.ok_or_else(|| ToolServerError::ToolNotFound(name.to_string()))?; .ok_or_else(|| ToolServerError::ToolNotFound(name.to_string()))?;
Arc::clone(tool) Arc::clone(tool)
}; };
tool.execute(input_json) tool.execute(input_json, ctx)
.await .await
.map_err(|e| ToolServerError::ToolExecution(e.to_string())) .map_err(|e| ToolServerError::ToolExecution(e.to_string()))
} }
@ -187,7 +190,11 @@ mod tests {
#[async_trait] #[async_trait]
impl Tool for EchoTool { impl Tool for EchoTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: crate::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
Ok(input_json.to_string().into()) Ok(input_json.to_string().into())
} }
} }
@ -236,12 +243,15 @@ mod tests {
handle.register_tool(def("echo")); handle.register_tool(def("echo"));
handle.flush_pending(); handle.flush_pending();
let out = handle.call_tool("echo", r#"{"x":1}"#).await.expect("call"); let out = handle
.call_tool("echo", r#"{"x":1}"#, Default::default())
.await
.expect("call");
assert_eq!(out.summary, r#"{"x":1}"#); assert_eq!(out.summary, r#"{"x":1}"#);
assert!(out.content.is_none()); assert!(out.content.is_none());
let err = handle let err = handle
.call_tool("missing", "{}") .call_tool("missing", "{}", Default::default())
.await .await
.expect_err("missing tool"); .expect_err("missing tool");
assert_eq!(err, ToolServerError::ToolNotFound("missing".to_string())); assert_eq!(err, ToolServerError::ToolNotFound("missing".to_string()));
@ -298,7 +308,11 @@ mod tests {
#[async_trait] #[async_trait]
impl Tool for FixedTool { impl Tool for FixedTool {
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
_input_json: &str,
_ctx: crate::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
Ok("replaced".to_string().into()) Ok("replaced".to_string().into())
} }
} }
@ -327,7 +341,11 @@ mod tests {
#[async_trait] #[async_trait]
impl Tool for ConstTool { impl Tool for ConstTool {
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
_input_json: &str,
_ctx: crate::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
Ok("const".to_string().into()) Ok("const".to_string().into())
} }
} }
@ -342,7 +360,10 @@ mod tests {
}); });
handle.replace(replacement).expect("replace"); handle.replace(replacement).expect("replace");
let out = handle.call_tool("echo", "{}").await.expect("call"); let out = handle
.call_tool("echo", "{}", Default::default())
.await
.expect("call");
assert_eq!(out.summary, "const"); assert_eq!(out.summary, "const");
} }
@ -360,7 +381,11 @@ mod tests {
#[async_trait] #[async_trait]
impl Tool for GatedTool { impl Tool for GatedTool {
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
_input_json: &str,
_ctx: crate::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
self.started.notify_one(); self.started.notify_one();
self.finish.notified().await; self.finish.notified().await;
Ok("done".to_string().into()) Ok("done".to_string().into())
@ -384,7 +409,7 @@ mod tests {
handle.flush_pending(); handle.flush_pending();
let h = handle.clone(); let h = handle.clone();
let call = tokio::spawn(async move { h.call_tool("slow", "{}").await }); let call = tokio::spawn(async move { h.call_tool("slow", "{}", Default::default()).await });
// Wait until the tool is actually executing. // Wait until the tool is actually executing.
started.notified().await; started.notified().await;
@ -413,7 +438,11 @@ mod tests {
#[async_trait] #[async_trait]
impl Tool for OldTool { impl Tool for OldTool {
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
_input_json: &str,
_ctx: crate::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
self.started.notify_one(); self.started.notify_one();
self.finish.notified().await; self.finish.notified().await;
Ok("old".to_string().into()) Ok("old".to_string().into())
@ -437,7 +466,7 @@ mod tests {
handle.flush_pending(); handle.flush_pending();
let h = handle.clone(); let h = handle.clone();
let call = tokio::spawn(async move { h.call_tool("t", "{}").await }); let call = tokio::spawn(async move { h.call_tool("t", "{}", Default::default()).await });
// Wait until the old tool is mid-execution. // Wait until the old tool is mid-execution.
started.notified().await; started.notified().await;
@ -447,7 +476,11 @@ mod tests {
#[async_trait] #[async_trait]
impl Tool for NewTool { impl Tool for NewTool {
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
_input_json: &str,
_ctx: crate::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
Ok("new".to_string().into()) Ok("new".to_string().into())
} }
} }
@ -469,7 +502,10 @@ mod tests {
assert_eq!(result.expect("call").summary, "old"); assert_eq!(result.expect("call").summary, "old");
// New calls use the replacement. // New calls use the replacement.
let out = handle.call_tool("t", "{}").await.expect("call"); let out = handle
.call_tool("t", "{}", Default::default())
.await
.expect("call");
assert_eq!(out.summary, "new"); assert_eq!(out.summary, "new");
} }

View File

@ -26,8 +26,8 @@ use crate::{
timeline::event::{ErrorEvent, StatusEvent, UsageEvent}, timeline::event::{ErrorEvent, StatusEvent, UsageEvent},
timeline::{TextBlockCollector, ThinkingBlockCollector, Timeline, ToolCallCollector}, timeline::{TextBlockCollector, ThinkingBlockCollector, Timeline, ToolCallCollector},
tool::{ tool::{
ToolCall, ToolDefinition as WorkerToolDefinition, ToolError, ToolOutputLimits, ToolResult, ToolCall, ToolDefinition as WorkerToolDefinition, ToolError, ToolExecutionContext,
truncate_content, ToolOutputLimits, ToolResult, truncate_content,
}, },
tool_server::{ToolServer, ToolServerHandle}, tool_server::{ToolServer, ToolServerHandle},
}; };
@ -187,6 +187,10 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
/// LlmCall count (per-Worker running counter, monotonic). Unlike /// LlmCall count (per-Worker running counter, monotonic). Unlike
/// `turn_count` this never collapses retries. /// `turn_count` this never collapses retries.
llm_call_count: usize, llm_call_count: usize,
/// Tool execution batch count (per-Worker running counter, monotonic).
/// Each batch corresponds to one collected assistant tool-call set or one
/// resumed pending tool-call set.
tool_execution_batch_count: usize,
/// Maximum number of AgentTurns (None = unlimited) /// Maximum number of AgentTurns (None = unlimited)
max_turns: Option<u32>, max_turns: Option<u32>,
/// AgentTurn-start callbacks (1:1 with LlmCall today) /// AgentTurn-start callbacks (1:1 with LlmCall today)
@ -912,19 +916,23 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
) -> Result<ToolExecutionResult, WorkerError> { ) -> Result<ToolExecutionResult, WorkerError> {
use futures::future::join_all; use futures::future::join_all;
// Map from tool call ID to (ToolCall, Meta, Tool) // Map from tool call ID to (ToolCall, Meta, Tool, Context)
// Retained because it's needed for PostToolCall hooks // Retained because it's needed for PostToolCall hooks
let mut call_info_map = HashMap::new(); let mut call_info_map = HashMap::new();
let mut synthetic_results = Vec::new(); let mut synthetic_results = Vec::new();
let batch_id = format!("tool-batch-{}", self.tool_execution_batch_count);
self.tool_execution_batch_count += 1;
// Phase 1: Apply pre_tool_call interceptor (determine skip/abort/synthetic result) // Phase 1: Apply pre_tool_call interceptor (determine skip/abort/synthetic result)
let mut approved_calls = Vec::new(); let mut approved_calls = Vec::new();
for mut tool_call in tool_calls { for (call_index, mut tool_call) in tool_calls.into_iter().enumerate() {
let context = ToolExecutionContext::new(&tool_call.id, &batch_id, call_index);
if let Some((meta, tool)) = self.tool_server.get_tool(&tool_call.name) { if let Some((meta, tool)) = self.tool_server.get_tool(&tool_call.name) {
let mut info = ToolCallInfo { let mut info = ToolCallInfo {
call: tool_call.clone(), call: tool_call.clone(),
meta, meta,
tool, tool,
context,
}; };
match self.interceptor.pre_tool_call(&mut info).await { match self.interceptor.pre_tool_call(&mut info).await {
@ -934,9 +942,11 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
} }
PreToolAction::SyntheticResult(result) => { PreToolAction::SyntheticResult(result) => {
let tool_call = info.call; let tool_call = info.call;
let mut context = info.context;
context.call_id = tool_call.id.clone();
call_info_map.insert( call_info_map.insert(
tool_call.id.clone(), tool_call.id.clone(),
(tool_call, info.meta.clone(), info.tool.clone()), (tool_call, info.meta.clone(), info.tool.clone(), context),
); );
synthetic_results.push(result); synthetic_results.push(result);
continue; continue;
@ -953,26 +963,37 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
// Reflect changes made by interceptor // Reflect changes made by interceptor
tool_call = info.call; tool_call = info.call;
let mut context = info.context;
context.call_id = tool_call.id.clone();
call_info_map.insert( call_info_map.insert(
tool_call.id.clone(), tool_call.id.clone(),
(tool_call.clone(), info.meta.clone(), info.tool.clone()), (
tool_call.clone(),
info.meta.clone(),
info.tool.clone(),
context.clone(),
),
); );
approved_calls.push(tool_call); approved_calls.push((tool_call, context));
} else { } else {
// Unknown tools go into approved list as-is (will error at execution) // Unknown tools go into approved list as-is (will error at execution)
approved_calls.push(tool_call); let context = ToolExecutionContext::new(&tool_call.id, &batch_id, call_index);
approved_calls.push((tool_call, context));
} }
} }
// Phase 2: Execute approved tools in parallel (cancellable) // Phase 2: Execute approved tools in parallel (cancellable)
let futures: Vec<_> = approved_calls let futures: Vec<_> = approved_calls
.into_iter() .into_iter()
.map(|tool_call| { .map(|(tool_call, context)| {
let tool_server = self.tool_server.clone(); let tool_server = self.tool_server.clone();
async move { async move {
let input_json = serde_json::to_string(&tool_call.input).unwrap_or_default(); let input_json = serde_json::to_string(&tool_call.input).unwrap_or_default();
match tool_server.call_tool(&tool_call.name, &input_json).await { match tool_server
.call_tool(&tool_call.name, &input_json, context)
.await
{
Ok(output) => ToolResult::from_output(&tool_call.id, output), Ok(output) => ToolResult::from_output(&tool_call.id, output),
Err(e) => ToolResult::error(&tool_call.id, e.to_string()), Err(e) => ToolResult::error(&tool_call.id, e.to_string()),
} }
@ -996,12 +1017,15 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
// Phase 3: Apply post_tool_call interceptor // Phase 3: Apply post_tool_call interceptor
for tool_result in &mut results { for tool_result in &mut results {
if let Some((tool_call, meta, tool)) = call_info_map.get(&tool_result.tool_use_id) { if let Some((tool_call, meta, tool, context)) =
call_info_map.get(&tool_result.tool_use_id)
{
let mut info = ToolResultInfo { let mut info = ToolResultInfo {
call: tool_call.clone(), call: tool_call.clone(),
result: tool_result.clone(), result: tool_result.clone(),
meta: meta.clone(), meta: meta.clone(),
tool: tool.clone(), tool: tool.clone(),
context: context.clone(),
}; };
match self.interceptor.post_tool_call(&mut info).await { match self.interceptor.post_tool_call(&mut info).await {
@ -1026,7 +1050,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
let Some(content) = tool_result.content.as_mut() else { let Some(content) = tool_result.content.as_mut() else {
continue; continue;
}; };
let Some((tool_call, _, _)) = call_info_map.get(&tool_result.tool_use_id) else { let Some((tool_call, _, _, _)) = call_info_map.get(&tool_result.tool_use_id) else {
continue; continue;
}; };
let limit = limits.limit_for(&tool_call.name); let limit = limits.limit_for(&tool_call.name);
@ -1628,6 +1652,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
locked_prefix_len: 0, locked_prefix_len: 0,
turn_count: 0, turn_count: 0,
llm_call_count: 0, llm_call_count: 0,
tool_execution_batch_count: 0,
max_turns: None, max_turns: None,
turn_start_cbs: Vec::new(), turn_start_cbs: Vec::new(),
turn_end_cbs: Vec::new(), turn_end_cbs: Vec::new(),
@ -1892,6 +1917,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
locked_prefix_len, locked_prefix_len,
turn_count: self.turn_count, turn_count: self.turn_count,
llm_call_count: self.llm_call_count, llm_call_count: self.llm_call_count,
tool_execution_batch_count: self.tool_execution_batch_count,
max_turns: self.max_turns, max_turns: self.max_turns,
turn_start_cbs: self.turn_start_cbs, turn_start_cbs: self.turn_start_cbs,
turn_end_cbs: self.turn_end_cbs, turn_end_cbs: self.turn_end_cbs,
@ -1984,6 +2010,7 @@ impl<C: LlmClient> Worker<C, Locked> {
locked_prefix_len: 0, locked_prefix_len: 0,
turn_count: self.turn_count, turn_count: self.turn_count,
llm_call_count: self.llm_call_count, llm_call_count: self.llm_call_count,
tool_execution_batch_count: self.tool_execution_batch_count,
max_turns: self.max_turns, max_turns: self.max_turns,
turn_start_cbs: self.turn_start_cbs, turn_start_cbs: self.turn_start_cbs,
turn_end_cbs: self.turn_end_cbs, turn_end_cbs: self.turn_end_cbs,

View File

@ -218,7 +218,11 @@ struct FixedOutputTool {
#[async_trait] #[async_trait]
impl Tool for FixedOutputTool { impl Tool for FixedOutputTool {
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
_input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
Ok(self.output.clone()) Ok(self.output.clone())
} }
} }
@ -289,7 +293,11 @@ struct ErroringTool {
#[async_trait] #[async_trait]
impl Tool for ErroringTool { impl Tool for ErroringTool {
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
_input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
Err(ToolError::ExecutionFailed(self.message.clone())) Err(ToolError::ExecutionFailed(self.message.clone()))
} }
} }

View File

@ -2,8 +2,8 @@
//! //!
//! Verify that Worker executes multiple tools in parallel. //! Verify that Worker executes multiple tools in parallel.
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use async_trait::async_trait; use async_trait::async_trait;
@ -12,7 +12,9 @@ use llm_worker::interceptor::{
Interceptor, PostToolAction, PreToolAction, ToolCallInfo, ToolResultInfo, Interceptor, PostToolAction, PreToolAction, ToolCallInfo, ToolResultInfo,
}; };
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent}; use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta, ToolOutput, ToolResult}; use llm_worker::tool::{
Tool, ToolDefinition, ToolError, ToolExecutionContext, ToolMeta, ToolOutput, ToolResult,
};
mod common; mod common;
use common::MockLlmClient; use common::MockLlmClient;
@ -59,13 +61,54 @@ impl SlowTool {
#[async_trait] #[async_trait]
impl Tool for SlowTool { impl Tool for SlowTool {
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
_input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
self.call_count.fetch_add(1, Ordering::SeqCst); self.call_count.fetch_add(1, Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(self.delay_ms)).await; tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
Ok(format!("Completed after {}ms", self.delay_ms).into()) Ok(format!("Completed after {}ms", self.delay_ms).into())
} }
} }
#[derive(Clone)]
struct ContextRecordingTool {
name: String,
contexts: Arc<Mutex<Vec<ToolExecutionContext>>>,
}
impl ContextRecordingTool {
fn new(name: impl Into<String>, contexts: Arc<Mutex<Vec<ToolExecutionContext>>>) -> Self {
Self {
name: name.into(),
contexts,
}
}
fn definition(&self) -> ToolDefinition {
let tool = self.clone();
Arc::new(move || {
let meta = ToolMeta::new(&tool.name)
.description("Records tool execution context")
.input_schema(serde_json::json!({"type": "object"}));
(meta, Arc::new(tool.clone()) as Arc<dyn Tool>)
})
}
}
#[async_trait]
impl Tool for ContextRecordingTool {
async fn execute(
&self,
_input_json: &str,
ctx: ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
self.contexts.lock().unwrap().push(ctx);
Ok("recorded".to_string().into())
}
}
// ============================================================================= // =============================================================================
// Tests // Tests
// ============================================================================= // =============================================================================
@ -92,10 +135,18 @@ async fn test_parallel_tool_execution() {
}), }),
]; ];
let client = MockLlmClient::new(events); let client = MockLlmClient::with_responses(vec![
events,
vec![
Event::text_block_start(0),
Event::text_delta(0, "Done"),
Event::text_block_stop(0, None),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
],
]);
let mut worker = Worker::new(client); let mut worker = Worker::new(client);
// Each tool waits 100ms
let tool1 = SlowTool::new("slow_tool_1", 100); let tool1 = SlowTool::new("slow_tool_1", 100);
let tool2 = SlowTool::new("slow_tool_2", 100); let tool2 = SlowTool::new("slow_tool_2", 100);
let tool3 = SlowTool::new("slow_tool_3", 100); let tool3 = SlowTool::new("slow_tool_3", 100);
@ -129,7 +180,201 @@ async fn test_parallel_tool_execution() {
println!("Parallel execution completed in {:?}", elapsed); println!("Parallel execution completed in {:?}", elapsed);
} }
/// Hook: pre_tool_call - verify that skipped tools are not executed #[tokio::test]
async fn test_tool_execution_context_order_and_batch_id() {
let client = MockLlmClient::with_responses(vec![
vec![
Event::tool_use_start(0, "call_a", "record_a"),
Event::tool_input_delta(0, r#"{}"#),
Event::tool_use_stop(0),
Event::tool_use_start(1, "call_b", "record_b"),
Event::tool_input_delta(1, r#"{}"#),
Event::tool_use_stop(1),
Event::tool_use_start(2, "call_c", "record_c"),
Event::tool_input_delta(2, r#"{}"#),
Event::tool_use_stop(2),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
],
vec![
Event::text_block_start(0),
Event::text_delta(0, "Done"),
Event::text_block_stop(0, None),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
],
]);
let mut worker = Worker::new(client);
let contexts = Arc::new(Mutex::new(Vec::new()));
worker.register_tool(ContextRecordingTool::new("record_a", contexts.clone()).definition());
worker.register_tool(ContextRecordingTool::new("record_b", contexts.clone()).definition());
worker.register_tool(ContextRecordingTool::new("record_c", contexts.clone()).definition());
let _ = worker.run("record contexts").await;
let mut contexts = contexts.lock().unwrap().clone();
contexts.sort_by_key(|ctx| ctx.call_index);
assert_eq!(contexts.len(), 3);
assert_eq!(contexts[0].call_id, "call_a");
assert_eq!(contexts[0].call_index, 0);
assert_eq!(contexts[1].call_id, "call_b");
assert_eq!(contexts[1].call_index, 1);
assert_eq!(contexts[2].call_id, "call_c");
assert_eq!(contexts[2].call_index, 2);
assert_eq!(contexts[0].batch_id, contexts[1].batch_id);
assert_eq!(contexts[1].batch_id, contexts[2].batch_id);
}
#[tokio::test]
async fn test_tool_execution_context_batch_id_changes_between_batches() {
let client = MockLlmClient::with_responses(vec![
vec![
Event::tool_use_start(0, "call_first", "record"),
Event::tool_input_delta(0, r#"{}"#),
Event::tool_use_stop(0),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
],
vec![
Event::tool_use_start(0, "call_second", "record"),
Event::tool_input_delta(0, r#"{}"#),
Event::tool_use_stop(0),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
],
vec![
Event::text_block_start(0),
Event::text_delta(0, "Done"),
Event::text_block_stop(0, None),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
],
]);
let mut worker = Worker::new(client);
let contexts = Arc::new(Mutex::new(Vec::new()));
worker.register_tool(ContextRecordingTool::new("record", contexts.clone()).definition());
let _ = worker.run("record batches").await;
let contexts = contexts.lock().unwrap().clone();
assert_eq!(contexts.len(), 2);
assert_eq!(contexts[0].call_id, "call_first");
assert_eq!(contexts[0].call_index, 0);
assert_eq!(contexts[1].call_id, "call_second");
assert_eq!(contexts[1].call_index, 0);
assert_ne!(contexts[0].batch_id, contexts[1].batch_id);
}
#[tokio::test]
async fn test_tool_execution_context_for_skipped_and_synthetic_paths() {
let client = MockLlmClient::with_responses(vec![
vec![
Event::tool_use_start(0, "call_run", "record"),
Event::tool_input_delta(0, r#"{}"#),
Event::tool_use_stop(0),
Event::tool_use_start(1, "call_skip", "skip_tool"),
Event::tool_input_delta(1, r#"{}"#),
Event::tool_use_stop(1),
Event::tool_use_start(2, "call_synth", "synthetic_tool"),
Event::tool_input_delta(2, r#"{}"#),
Event::tool_use_stop(2),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
],
vec![
Event::text_block_start(0),
Event::text_delta(0, "Done"),
Event::text_block_stop(0, None),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
],
]);
let mut worker = Worker::new(client);
let executed_contexts = Arc::new(Mutex::new(Vec::new()));
let pre_contexts = Arc::new(Mutex::new(Vec::new()));
let post_contexts = Arc::new(Mutex::new(Vec::new()));
worker
.register_tool(ContextRecordingTool::new("record", executed_contexts.clone()).definition());
worker.register_tool(
ContextRecordingTool::new("skip_tool", executed_contexts.clone()).definition(),
);
worker.register_tool(
ContextRecordingTool::new("synthetic_tool", executed_contexts.clone()).definition(),
);
struct ContextPolicy {
pre_contexts: Arc<Mutex<Vec<ToolExecutionContext>>>,
post_contexts: Arc<Mutex<Vec<ToolExecutionContext>>>,
}
#[async_trait]
impl Interceptor for ContextPolicy {
async fn pre_tool_call(&self, info: &mut ToolCallInfo) -> PreToolAction {
self.pre_contexts.lock().unwrap().push(info.context.clone());
match info.call.name.as_str() {
"skip_tool" => PreToolAction::Skip,
"synthetic_tool" => PreToolAction::SyntheticResult(ToolResult::from_output(
&info.call.id,
ToolOutput::from("synthetic result".to_string()),
)),
_ => PreToolAction::Continue,
}
}
async fn post_tool_call(&self, info: &mut ToolResultInfo) -> PostToolAction {
self.post_contexts
.lock()
.unwrap()
.push(info.context.clone());
PostToolAction::Continue
}
}
worker.set_interceptor(ContextPolicy {
pre_contexts: pre_contexts.clone(),
post_contexts: post_contexts.clone(),
});
let _ = worker.run("record skipped and synthetic contexts").await;
let mut pre_contexts = pre_contexts.lock().unwrap().clone();
pre_contexts.sort_by_key(|ctx| ctx.call_index);
assert_eq!(pre_contexts.len(), 3);
assert_eq!(pre_contexts[0].call_id, "call_run");
assert_eq!(pre_contexts[0].call_index, 0);
assert_eq!(pre_contexts[1].call_id, "call_skip");
assert_eq!(pre_contexts[1].call_index, 1);
assert_eq!(pre_contexts[2].call_id, "call_synth");
assert_eq!(pre_contexts[2].call_index, 2);
assert_eq!(pre_contexts[0].batch_id, pre_contexts[1].batch_id);
assert_eq!(pre_contexts[1].batch_id, pre_contexts[2].batch_id);
let executed_contexts = executed_contexts.lock().unwrap().clone();
assert_eq!(executed_contexts.len(), 1);
assert_eq!(executed_contexts[0].call_id, "call_run");
assert_eq!(executed_contexts[0].call_index, 0);
let mut post_contexts = post_contexts.lock().unwrap().clone();
post_contexts.sort_by_key(|ctx| ctx.call_index);
assert_eq!(post_contexts.len(), 2);
assert_eq!(post_contexts[0].call_id, "call_run");
assert_eq!(post_contexts[0].call_index, 0);
assert_eq!(post_contexts[1].call_id, "call_synth");
assert_eq!(post_contexts[1].call_index, 2);
assert_eq!(post_contexts[0].batch_id, post_contexts[1].batch_id);
}
#[tokio::test] #[tokio::test]
async fn test_before_tool_call_skip() { async fn test_before_tool_call_skip() {
let events = vec![ let events = vec![
@ -220,7 +465,11 @@ async fn test_post_tool_call_modification() {
#[async_trait] #[async_trait]
impl Tool for SimpleTool { impl Tool for SimpleTool {
async fn execute(&self, _: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
_: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
Ok("Original Result".to_string().into()) Ok("Original Result".to_string().into())
} }
} }

View File

@ -9,6 +9,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use schemars; use schemars;
use serde; use serde;
use llm_worker::ToolExecutionContext;
use llm_worker_macros::tool_registry; use llm_worker_macros::tool_registry;
// ============================================================================= // =============================================================================
@ -42,6 +43,15 @@ impl SimpleContext {
async fn get_prefix(&self) -> String { async fn get_prefix(&self) -> String {
self.prefix.clone() self.prefix.clone()
} }
/// Tool that observes execution context
#[tool]
async fn context_echo(&self, ctx: ToolExecutionContext, message: String) -> String {
format!(
"{}:{}:{}:{}",
ctx.batch_id, ctx.call_index, ctx.call_id, message
)
}
} }
#[tokio::test] #[tokio::test]
@ -74,7 +84,9 @@ async fn test_basic_tool_generation() {
); );
// Execution test // Execution test
let result = tool.execute(r#"{"message": "World"}"#).await; let result = tool
.execute(r#"{"message": "World"}"#, Default::default())
.await;
assert!(result.is_ok(), "Should execute successfully"); assert!(result.is_ok(), "Should execute successfully");
let output = result.unwrap(); let output = result.unwrap();
assert!( assert!(
@ -97,7 +109,9 @@ async fn test_multiple_arguments() {
assert_eq!(meta.name, "add"); assert_eq!(meta.name, "add");
let result = tool.execute(r#"{"a": 10, "b": 20}"#).await; let result = tool
.execute(r#"{"a": 10, "b": 20}"#, Default::default())
.await;
assert!(result.is_ok()); assert!(result.is_ok());
let output = result.unwrap(); let output = result.unwrap();
assert!( assert!(
@ -118,7 +132,7 @@ async fn test_no_arguments() {
assert_eq!(meta.name, "get_prefix"); assert_eq!(meta.name, "get_prefix");
// Call with empty JSON object // Call with empty JSON object
let result = tool.execute(r#"{}"#).await; let result = tool.execute(r#"{}"#, Default::default()).await;
assert!(result.is_ok()); assert!(result.is_ok());
let output = result.unwrap(); let output = result.unwrap();
assert!( assert!(
@ -137,7 +151,9 @@ async fn test_invalid_arguments() {
let (_, tool) = ctx.greet_definition()(); let (_, tool) = ctx.greet_definition()();
// Invalid JSON // Invalid JSON
let result = tool.execute(r#"{"wrong_field": "value"}"#).await; let result = tool
.execute(r#"{"wrong_field": "value"}"#, Default::default())
.await;
assert!(result.is_err(), "Should fail with invalid arguments"); assert!(result.is_err(), "Should fail with invalid arguments");
} }
@ -175,7 +191,7 @@ async fn test_result_return_type_success() {
let ctx = FallibleContext; let ctx = FallibleContext;
let (_, tool) = ctx.validate_definition()(); let (_, tool) = ctx.validate_definition()();
let result = tool.execute(r#"{"value": 42}"#).await; let result = tool.execute(r#"{"value": 42}"#, Default::default()).await;
assert!(result.is_ok(), "Should succeed for positive value"); assert!(result.is_ok(), "Should succeed for positive value");
let output = result.unwrap(); let output = result.unwrap();
assert!( assert!(
@ -190,7 +206,7 @@ async fn test_result_return_type_error() {
let ctx = FallibleContext; let ctx = FallibleContext;
let (_, tool) = ctx.validate_definition()(); let (_, tool) = ctx.validate_definition()();
let result = tool.execute(r#"{"value": -1}"#).await; let result = tool.execute(r#"{"value": -1}"#, Default::default()).await;
assert!(result.is_err(), "Should fail for negative value"); assert!(result.is_err(), "Should fail for negative value");
let err = result.unwrap_err(); let err = result.unwrap_err();
@ -228,9 +244,9 @@ async fn test_sync_method() {
let (_, tool) = ctx.increment_definition()(); let (_, tool) = ctx.increment_definition()();
// Execute 3 times // Execute 3 times
let result1 = tool.execute(r#"{}"#).await; let result1 = tool.execute(r#"{}"#, Default::default()).await;
let result2 = tool.execute(r#"{}"#).await; let result2 = tool.execute(r#"{}"#, Default::default()).await;
let result3 = tool.execute(r#"{}"#).await; let result3 = tool.execute(r#"{}"#, Default::default()).await;
assert!(result1.is_ok()); assert!(result1.is_ok());
assert!(result2.is_ok()); assert!(result2.is_ok());
@ -240,6 +256,24 @@ async fn test_sync_method() {
assert_eq!(ctx.counter.load(Ordering::SeqCst), 3); assert_eq!(ctx.counter.load(Ordering::SeqCst), 3);
} }
#[tokio::test]
async fn test_tool_macro_passes_execution_context() {
let ctx = SimpleContext {
prefix: "Test".to_string(),
};
let (_, tool) = ctx.context_echo_definition()();
let output = tool
.execute(
r#"{"message":"hello"}"#,
ToolExecutionContext::new("call-ctx", "batch-ctx", 7),
)
.await
.unwrap();
assert_eq!(output.summary, "\"batch-ctx:7:call-ctx:hello\"");
}
// ============================================================================= // =============================================================================
// Test: ToolMeta Immutability // Test: ToolMeta Immutability
// ============================================================================= // =============================================================================

View File

@ -58,7 +58,11 @@ impl MockWeatherTool {
#[async_trait] #[async_trait]
impl Tool for MockWeatherTool { impl Tool for MockWeatherTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
self.call_count.fetch_add(1, Ordering::SeqCst); self.call_count.fetch_add(1, Ordering::SeqCst);
// Parse input // Parse input

View File

@ -136,7 +136,11 @@ impl CountingTool {
#[async_trait] #[async_trait]
impl Tool for CountingTool { impl Tool for CountingTool {
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
_input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
self.calls.fetch_add(1, Ordering::SeqCst); self.calls.fetch_add(1, Ordering::SeqCst);
Ok(format!("{}-ok", self.name).into()) Ok(format!("{}-ok", self.name).into())
} }

View File

@ -54,7 +54,11 @@ struct WriteExtractedTool {
#[async_trait] #[async_trait]
impl Tool for WriteExtractedTool { impl Tool for WriteExtractedTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let payload: ExtractedPayload = serde_json::from_str(input_json).map_err(|e| { let payload: ExtractedPayload = serde_json::from_str(input_json).map_err(|e| {
ToolError::InvalidArgument(format!("invalid write_extracted input: {e}")) ToolError::InvalidArgument(format!("invalid write_extracted input: {e}"))
})?; })?;
@ -122,7 +126,7 @@ mod tests {
"requests": [] "requests": []
}) })
.to_string(); .to_string();
let out = tool.execute(&input).await.unwrap(); let out = tool.execute(&input, Default::default()).await.unwrap();
assert!(out.summary.contains("decisions=1")); assert!(out.summary.contains("decisions=1"));
let payload = ctx.take_payload().unwrap(); let payload = ctx.take_payload().unwrap();
assert_eq!(payload.decisions.len(), 1); assert_eq!(payload.decisions.len(), 1);
@ -137,7 +141,7 @@ mod tests {
let first = let first =
serde_json::json!({"decisions": [], "discussions": [], "attempts": [], "requests": []}) serde_json::json!({"decisions": [], "discussions": [], "attempts": [], "requests": []})
.to_string(); .to_string();
tool.execute(&first).await.unwrap(); tool.execute(&first, Default::default()).await.unwrap();
let second = serde_json::json!({ let second = serde_json::json!({
"decisions": [], "decisions": [],
@ -146,7 +150,7 @@ mod tests {
"requests": [] "requests": []
}) })
.to_string(); .to_string();
tool.execute(&second).await.unwrap(); tool.execute(&second, Default::default()).await.unwrap();
let payload = ctx.take_payload().unwrap(); let payload = ctx.take_payload().unwrap();
assert_eq!(payload.attempts.len(), 1); assert_eq!(payload.attempts.len(), 1);
@ -157,7 +161,7 @@ mod tests {
async fn invalid_json_returns_invalid_argument() { async fn invalid_json_returns_invalid_argument() {
let ctx = Arc::new(ExtractWorkerContext::new()); let ctx = Arc::new(ExtractWorkerContext::new());
let tool: Arc<dyn Tool> = Arc::new(WriteExtractedTool { ctx: ctx.clone() }); let tool: Arc<dyn Tool> = Arc::new(WriteExtractedTool { ctx: ctx.clone() });
let res = tool.execute("not json").await; let res = tool.execute("not json", Default::default()).await;
assert!(matches!(res, Err(ToolError::InvalidArgument(_)))); assert!(matches!(res, Err(ToolError::InvalidArgument(_))));
assert!(ctx.take_payload().is_none()); assert!(ctx.take_payload().is_none());
} }

View File

@ -29,7 +29,11 @@ struct MemoryDeleteTool {
#[async_trait] #[async_trait]
impl Tool for MemoryDeleteTool { impl Tool for MemoryDeleteTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: DeleteParams = serde_json::from_str(input_json) let params: DeleteParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryDelete input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryDelete input: {e}")))?;
let path = params let path = params
@ -139,7 +143,10 @@ mod tests {
let (_, tool) = delete_tool(layout.clone())(); let (_, tool) = delete_tool(layout.clone())();
let out = tool let out = tool
.execute(r#"{"kind":"decision","slug":"obsolete"}"#) .execute(
r#"{"kind":"decision","slug":"obsolete"}"#,
Default::default(),
)
.await .await
.unwrap(); .unwrap();
assert!(out.summary.contains("Deleted")); assert!(out.summary.contains("Deleted"));

View File

@ -47,7 +47,11 @@ struct EditTool {
#[async_trait] #[async_trait]
impl Tool for EditTool { impl Tool for EditTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: EditParams = serde_json::from_str(input_json) let params: EditParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryEdit input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryEdit input: {e}")))?;
@ -316,7 +320,10 @@ mod tests {
"old_string": "body body", "old_string": "body body",
"new_string": "edited", "new_string": "edited",
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
assert!(out.summary.contains("1 replacement")); assert!(out.summary.contains("1 replacement"));
let after = std::fs::read_to_string(&path).unwrap(); let after = std::fs::read_to_string(&path).unwrap();
assert!(after.contains("edited")); assert!(after.contains("edited"));
@ -335,7 +342,10 @@ mod tests {
"old_string": "status: open\n", "old_string": "status: open\n",
"new_string": "", "new_string": "",
}); });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
let msg = format!("{err}"); let msg = format!("{err}");
assert!(msg.contains("status") || msg.contains("missing")); assert!(msg.contains("status") || msg.contains("missing"));
@ -354,7 +364,10 @@ mod tests {
"old_string": "x", "old_string": "x",
"new_string": "y", "new_string": "y",
}); });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::ExecutionFailed(_))); assert!(matches!(err, ToolError::ExecutionFailed(_)));
} }
@ -369,7 +382,10 @@ mod tests {
"old_string": "x", "old_string": "x",
"new_string": "y", "new_string": "y",
}); });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
} }

View File

@ -126,7 +126,11 @@ struct KnowledgeQueryTool {
#[async_trait] #[async_trait]
impl Tool for MemoryQueryTool { impl Tool for MemoryQueryTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: MemoryQueryParams = serde_json::from_str(input_json) let params: MemoryQueryParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryQuery input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryQuery input: {e}")))?;
let needle = match params.query.as_deref() { let needle = match params.query.as_deref() {
@ -240,7 +244,11 @@ impl Tool for MemoryQueryTool {
#[async_trait] #[async_trait]
impl Tool for KnowledgeQueryTool { impl Tool for KnowledgeQueryTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: KnowledgeQueryParams = serde_json::from_str(input_json).map_err(|e| { let params: KnowledgeQueryParams = serde_json::from_str(input_json).map_err(|e| {
ToolError::InvalidArgument(format!("invalid KnowledgeQuery input: {e}")) ToolError::InvalidArgument(format!("invalid KnowledgeQuery input: {e}"))
})?; })?;
@ -568,7 +576,10 @@ mod tests {
write_decision(dir.path(), "beta", "no match here\n"); write_decision(dir.path(), "beta", "no match here\n");
let (_, tool) = memory_query_tool(layout, QueryConfig::default())(); let (_, tool) = memory_query_tool(layout, QueryConfig::default())();
let inp = serde_json::json!({ "query": "ollama" }); let inp = serde_json::json!({ "query": "ollama" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let records: Vec<OwnedMemoryRecord> = parse_records(&out); let records: Vec<OwnedMemoryRecord> = parse_records(&out);
assert_eq!(records.len(), 1); assert_eq!(records.len(), 1);
assert_eq!(records[0].slug, "alpha"); assert_eq!(records[0].slug, "alpha");
@ -596,7 +607,7 @@ mod tests {
.unwrap(); .unwrap();
let (_, tool) = memory_query_tool(layout, QueryConfig::default())(); let (_, tool) = memory_query_tool(layout, QueryConfig::default())();
let out = tool.execute("{}").await.unwrap(); let out = tool.execute("{}", Default::default()).await.unwrap();
let records: Vec<OwnedMemoryRecord> = parse_records(&out); let records: Vec<OwnedMemoryRecord> = parse_records(&out);
let mut slugs: Vec<&str> = records.iter().map(|r| r.slug.as_str()).collect(); let mut slugs: Vec<&str> = records.iter().map(|r| r.slug.as_str()).collect();
slugs.sort(); slugs.sort();
@ -616,7 +627,10 @@ mod tests {
.unwrap(); .unwrap();
let (_, tool) = memory_query_tool(layout, QueryConfig::default())(); let (_, tool) = memory_query_tool(layout, QueryConfig::default())();
let inp = serde_json::json!({ "query": "needle" }); let inp = serde_json::json!({ "query": "needle" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let records: Vec<OwnedMemoryRecord> = parse_records(&out); let records: Vec<OwnedMemoryRecord> = parse_records(&out);
assert_eq!(records.len(), 1); assert_eq!(records.len(), 1);
assert_eq!(records[0].slug, "summary"); assert_eq!(records[0].slug, "summary");
@ -633,7 +647,10 @@ mod tests {
let (_, tool) = memory_query_tool(layout, QueryConfig::default())(); let (_, tool) = memory_query_tool(layout, QueryConfig::default())();
let inp = serde_json::json!({ "query": "needle" }); let inp = serde_json::json!({ "query": "needle" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let records: Vec<OwnedMemoryRecord> = parse_records(&out); let records: Vec<OwnedMemoryRecord> = parse_records(&out);
assert!(records.is_empty(), "got records: {:?}", out.content); assert!(records.is_empty(), "got records: {:?}", out.content);
} }
@ -653,8 +670,14 @@ mod tests {
let (_, memory_tool) = memory_query_tool(layout.clone(), QueryConfig::default())(); let (_, memory_tool) = memory_query_tool(layout.clone(), QueryConfig::default())();
let (_, knowledge_tool) = knowledge_query_tool(layout.clone(), QueryConfig::default())(); let (_, knowledge_tool) = knowledge_query_tool(layout.clone(), QueryConfig::default())();
let inp = serde_json::json!({ "query": "needle" }); let inp = serde_json::json!({ "query": "needle" });
memory_tool.execute(&inp.to_string()).await.unwrap(); memory_tool
knowledge_tool.execute(&inp.to_string()).await.unwrap(); .execute(&inp.to_string(), Default::default())
.await
.unwrap();
knowledge_tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let report = crate::usage::build_usage_report(&layout).unwrap(); let report = crate::usage::build_usage_report(&layout).unwrap();
assert!(report.records.is_empty()); assert!(report.records.is_empty());
@ -673,7 +696,10 @@ mod tests {
}; };
let (_, tool) = memory_query_tool(layout, cfg)(); let (_, tool) = memory_query_tool(layout, cfg)();
let inp = serde_json::json!({ "query": "needle" }); let inp = serde_json::json!({ "query": "needle" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let records: Vec<OwnedMemoryRecord> = parse_records(&out); let records: Vec<OwnedMemoryRecord> = parse_records(&out);
assert_eq!(records.len(), 3); assert_eq!(records.len(), 3);
} }
@ -692,7 +718,10 @@ mod tests {
}; };
let (_, tool) = memory_query_tool(layout, cfg)(); let (_, tool) = memory_query_tool(layout, cfg)();
let inp = serde_json::json!({ "query": "needle" }); let inp = serde_json::json!({ "query": "needle" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let records: Vec<OwnedMemoryRecord> = parse_records(&out); let records: Vec<OwnedMemoryRecord> = parse_records(&out);
assert_eq!(records.len(), 1); assert_eq!(records.len(), 1);
let e = records[0].excerpt.as_deref().unwrap(); let e = records[0].excerpt.as_deref().unwrap();
@ -708,7 +737,10 @@ mod tests {
let (_dir, layout) = setup(); let (_dir, layout) = setup();
let (_, tool) = memory_query_tool(layout, QueryConfig::default())(); let (_, tool) = memory_query_tool(layout, QueryConfig::default())();
let inp = serde_json::json!({ "query": " " }); let inp = serde_json::json!({ "query": " " });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
@ -724,7 +756,10 @@ mod tests {
); );
let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())(); let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())();
let inp = serde_json::json!({ "query": "ollama" }); let inp = serde_json::json!({ "query": "ollama" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let records: Vec<OwnedKnowledgeRecord> = parse_records(&out); let records: Vec<OwnedKnowledgeRecord> = parse_records(&out);
assert_eq!(records.len(), 1); assert_eq!(records.len(), 1);
assert_eq!(records[0].slug, "policy"); assert_eq!(records[0].slug, "policy");
@ -748,7 +783,7 @@ mod tests {
write_knowledge(dir.path(), "h1", "howto", "d2", "body\n"); write_knowledge(dir.path(), "h1", "howto", "d2", "body\n");
let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())(); let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())();
let out = tool.execute("{}").await.unwrap(); let out = tool.execute("{}", Default::default()).await.unwrap();
let records: Vec<OwnedKnowledgeRecord> = parse_records(&out); let records: Vec<OwnedKnowledgeRecord> = parse_records(&out);
let mut slugs: Vec<&str> = records.iter().map(|r| r.slug.as_str()).collect(); let mut slugs: Vec<&str> = records.iter().map(|r| r.slug.as_str()).collect();
slugs.sort(); slugs.sort();
@ -764,7 +799,10 @@ mod tests {
let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())(); let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())();
let inp = serde_json::json!({ "query": "needle", "kind": "howto" }); let inp = serde_json::json!({ "query": "needle", "kind": "howto" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let records: Vec<OwnedKnowledgeRecord> = parse_records(&out); let records: Vec<OwnedKnowledgeRecord> = parse_records(&out);
assert_eq!(records.len(), 1); assert_eq!(records.len(), 1);
assert_eq!(records[0].slug, "h1"); assert_eq!(records[0].slug, "h1");
@ -778,7 +816,10 @@ mod tests {
let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())(); let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())();
let inp = serde_json::json!({ "kind": "howto" }); let inp = serde_json::json!({ "kind": "howto" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let records: Vec<OwnedKnowledgeRecord> = parse_records(&out); let records: Vec<OwnedKnowledgeRecord> = parse_records(&out);
assert_eq!(records.len(), 1); assert_eq!(records.len(), 1);
assert_eq!(records[0].slug, "h1"); assert_eq!(records[0].slug, "h1");
@ -792,7 +833,10 @@ mod tests {
let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())(); let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())();
let inp = serde_json::json!({ "query": "xyzzy" }); let inp = serde_json::json!({ "query": "xyzzy" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let records: Vec<OwnedKnowledgeRecord> = parse_records(&out); let records: Vec<OwnedKnowledgeRecord> = parse_records(&out);
assert_eq!(records.len(), 1); assert_eq!(records.len(), 1);
assert_eq!(records[0].slug, "p"); assert_eq!(records[0].slug, "p");
@ -804,7 +848,10 @@ mod tests {
write_knowledge(dir.path(), "p", "policy", "d", "no match\n"); write_knowledge(dir.path(), "p", "policy", "d", "no match\n");
let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())(); let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())();
let inp = serde_json::json!({ "query": "absent" }); let inp = serde_json::json!({ "query": "absent" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let records: Vec<OwnedKnowledgeRecord> = parse_records(&out); let records: Vec<OwnedKnowledgeRecord> = parse_records(&out);
assert!(records.is_empty()); assert!(records.is_empty());
} }

View File

@ -45,7 +45,11 @@ struct ReadTool {
#[async_trait] #[async_trait]
impl Tool for ReadTool { impl Tool for ReadTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: ReadParams = serde_json::from_str(input_json) let params: ReadParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryRead input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryRead input: {e}")))?;
@ -225,7 +229,10 @@ mod tests {
let (_meta, tool) = read_tool(layout)(); let (_meta, tool) = read_tool(layout)();
let inp = serde_json::json!({ "kind": "decision", "slug": "foo" }); let inp = serde_json::json!({ "kind": "decision", "slug": "foo" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.unwrap(); let body = out.content.unwrap();
assert!(body.contains(" 1\talpha")); assert!(body.contains(" 1\talpha"));
assert!(body.contains(" 2\tbeta")); assert!(body.contains(" 2\tbeta"));
@ -240,7 +247,10 @@ mod tests {
let (_, tool) = read_tool(layout)(); let (_, tool) = read_tool(layout)();
let inp = serde_json::json!({ "kind": "summary" }); let inp = serde_json::json!({ "kind": "summary" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
assert!(out.content.unwrap().contains("summary body")); assert!(out.content.unwrap().contains("summary body"));
} }
@ -249,7 +259,10 @@ mod tests {
let (_dir, layout) = setup(); let (_dir, layout) = setup();
let (_, tool) = read_tool(layout)(); let (_, tool) = read_tool(layout)();
let inp = serde_json::json!({ "kind": "summary", "slug": "x" }); let inp = serde_json::json!({ "kind": "summary", "slug": "x" });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
@ -258,7 +271,10 @@ mod tests {
let (_dir, layout) = setup(); let (_dir, layout) = setup();
let (_, tool) = read_tool(layout)(); let (_, tool) = read_tool(layout)();
let inp = serde_json::json!({ "kind": "decision" }); let inp = serde_json::json!({ "kind": "decision" });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
@ -267,7 +283,10 @@ mod tests {
let (_dir, layout) = setup(); let (_dir, layout) = setup();
let (_, tool) = read_tool(layout)(); let (_, tool) = read_tool(layout)();
let inp = serde_json::json!({ "kind": "decision", "slug": "Bad-Slug" }); let inp = serde_json::json!({ "kind": "decision", "slug": "Bad-Slug" });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
@ -280,7 +299,10 @@ mod tests {
let (_, tool) = read_tool(layout)(); let (_, tool) = read_tool(layout)();
let inp = serde_json::json!({ "kind": "knowledge", "slug": "policy" }); let inp = serde_json::json!({ "kind": "knowledge", "slug": "policy" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
assert!(out.content.unwrap().contains("k")); assert!(out.content.unwrap().contains("k"));
} }
@ -293,7 +315,9 @@ mod tests {
let (_, tool) = read_tool_with_usage(layout.clone(), "session-1")(); let (_, tool) = read_tool_with_usage(layout.clone(), "session-1")();
let inp = serde_json::json!({ "kind": "decision", "slug": "foo" }); let inp = serde_json::json!({ "kind": "decision", "slug": "foo" });
tool.execute(&inp.to_string()).await.unwrap(); tool.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let report = usage::build_usage_report(&layout).unwrap(); let report = usage::build_usage_report(&layout).unwrap();
assert_eq!(report.records.len(), 1); assert_eq!(report.records.len(), 1);
@ -310,7 +334,10 @@ mod tests {
let (_dir, layout) = setup(); let (_dir, layout) = setup();
let (_, tool) = read_tool(layout)(); let (_, tool) = read_tool(layout)();
let inp = serde_json::json!({ "kind": "decision", "slug": "missing" }); let inp = serde_json::json!({ "kind": "decision", "slug": "missing" });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::ExecutionFailed(_))); assert!(matches!(err, ToolError::ExecutionFailed(_)));
} }
} }

View File

@ -42,7 +42,11 @@ struct WriteTool {
#[async_trait] #[async_trait]
impl Tool for WriteTool { impl Tool for WriteTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: WriteParams = serde_json::from_str(input_json) let params: WriteParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryWrite input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryWrite input: {e}")))?;
@ -229,7 +233,10 @@ mod tests {
"kind": "summary", "kind": "summary",
"content": content, "content": content,
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
assert!(out.summary.contains("Created")); assert!(out.summary.contains("Created"));
assert!(path.exists()); assert!(path.exists());
} }
@ -249,7 +256,10 @@ mod tests {
"slug": "foo", "slug": "foo",
"content": content, "content": content,
}); });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
let msg = format!("{err}"); let msg = format!("{err}");
assert!(msg.contains("status") || msg.contains("missing"), "{msg}"); assert!(msg.contains("status") || msg.contains("missing"), "{msg}");
} }
@ -271,7 +281,10 @@ mod tests {
"slug": "foo", "slug": "foo",
"content": initial, "content": initial,
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
assert!(out.summary.contains("Overwrote")); assert!(out.summary.contains("Overwrote"));
} }
@ -283,7 +296,10 @@ mod tests {
"kind": "decision", "kind": "decision",
"content": "ignored", "content": "ignored",
}); });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
@ -298,7 +314,11 @@ mod tests {
"slug": "foo", "slug": "foo",
"content": bad, "content": bad,
}); });
assert!(tool.execute(&inp.to_string()).await.is_err()); assert!(
tool.execute(&inp.to_string(), Default::default())
.await
.is_err()
);
assert!(!path.exists()); assert!(!path.exists());
} }
@ -312,7 +332,10 @@ mod tests {
"slug": "wf", "slug": "wf",
"content": "---\n---\n", "content": "---\n---\n",
}); });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
} }

View File

@ -151,7 +151,11 @@ struct SearchSessionLogTool {
#[async_trait] #[async_trait]
impl Tool for SearchSessionLogTool { impl Tool for SearchSessionLogTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: SearchSessionParams = serde_json::from_str(input_json).map_err(|e| { let params: SearchSessionParams = serde_json::from_str(input_json).map_err(|e| {
ToolError::InvalidArgument(format!("invalid search_session_log input: {e}")) ToolError::InvalidArgument(format!("invalid search_session_log input: {e}"))
})?; })?;
@ -206,7 +210,11 @@ struct ReadSessionItemsTool {
#[async_trait] #[async_trait]
impl Tool for ReadSessionItemsTool { impl Tool for ReadSessionItemsTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: ReadSessionParams = serde_json::from_str(input_json).map_err(|e| { let params: ReadSessionParams = serde_json::from_str(input_json).map_err(|e| {
ToolError::InvalidArgument(format!("invalid read_session_items input: {e}")) ToolError::InvalidArgument(format!("invalid read_session_items input: {e}"))
})?; })?;
@ -368,7 +376,11 @@ struct MarkReadRequiredTool {
#[async_trait] #[async_trait]
impl Tool for MarkReadRequiredTool { impl Tool for MarkReadRequiredTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: MarkParams = serde_json::from_str(input_json).map_err(|e| { let params: MarkParams = serde_json::from_str(input_json).map_err(|e| {
ToolError::InvalidArgument(format!("invalid mark_read_required input: {e}")) ToolError::InvalidArgument(format!("invalid mark_read_required input: {e}"))
})?; })?;
@ -425,7 +437,11 @@ struct AddReferenceTool {
#[async_trait] #[async_trait]
impl Tool for AddReferenceTool { impl Tool for AddReferenceTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: ReferenceParams = serde_json::from_str(input_json) let params: ReferenceParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid add_reference input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid add_reference input: {e}")))?;
let mut guard = self.ctx.lock().expect("compact worker context poisoned"); let mut guard = self.ctx.lock().expect("compact worker context poisoned");
@ -449,7 +465,11 @@ struct WriteSummaryTool {
#[async_trait] #[async_trait]
impl Tool for WriteSummaryTool { impl Tool for WriteSummaryTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: SummaryParams = serde_json::from_str(input_json) let params: SummaryParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid write_summary input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid write_summary input: {e}")))?;
let mut guard = self.ctx.lock().expect("compact worker context poisoned"); let mut guard = self.ctx.lock().expect("compact worker context poisoned");
@ -749,7 +769,7 @@ mod tests {
ctx: ctx.clone(), ctx: ctx.clone(),
}); });
let input = serde_json::json!({ "file_path": path.to_str().unwrap() }).to_string(); let input = serde_json::json!({ "file_path": path.to_str().unwrap() }).to_string();
let out = tool.execute(&input).await.unwrap(); let out = tool.execute(&input, Default::default()).await.unwrap();
assert!(out.summary.starts_with("Marked")); assert!(out.summary.starts_with("Marked"));
let guard = ctx.lock().unwrap(); let guard = ctx.lock().unwrap();
@ -770,7 +790,7 @@ mod tests {
ctx: ctx.clone(), ctx: ctx.clone(),
}); });
let input = serde_json::json!({ "file_path": path.to_str().unwrap() }).to_string(); let input = serde_json::json!({ "file_path": path.to_str().unwrap() }).to_string();
let res = tool.execute(&input).await; let res = tool.execute(&input, Default::default()).await;
assert!(matches!(res, Err(ToolError::ExecutionFailed(_)))); assert!(matches!(res, Err(ToolError::ExecutionFailed(_))));
let guard = ctx.lock().unwrap(); let guard = ctx.lock().unwrap();
@ -784,11 +804,11 @@ mod tests {
let tool: Arc<dyn Tool> = Arc::new(WriteSummaryTool { ctx: ctx.clone() }); let tool: Arc<dyn Tool> = Arc::new(WriteSummaryTool { ctx: ctx.clone() });
let first = serde_json::json!({ "text": "first" }).to_string(); let first = serde_json::json!({ "text": "first" }).to_string();
let out1 = tool.execute(&first).await.unwrap(); let out1 = tool.execute(&first, Default::default()).await.unwrap();
assert!(out1.summary.contains("recorded")); assert!(out1.summary.contains("recorded"));
let second = serde_json::json!({ "text": "second" }).to_string(); let second = serde_json::json!({ "text": "second" }).to_string();
let out2 = tool.execute(&second).await.unwrap(); let out2 = tool.execute(&second, Default::default()).await.unwrap();
assert!(out2.summary.contains("replaced")); assert!(out2.summary.contains("replaced"));
assert_eq!(ctx.lock().unwrap().summary.as_deref(), Some("second")); assert_eq!(ctx.lock().unwrap().summary.as_deref(), Some("second"));
@ -801,8 +821,8 @@ mod tests {
let p = "/abs/path.rs"; let p = "/abs/path.rs";
let input = serde_json::json!({ "file_path": p }).to_string(); let input = serde_json::json!({ "file_path": p }).to_string();
tool.execute(&input).await.unwrap(); tool.execute(&input, Default::default()).await.unwrap();
tool.execute(&input).await.unwrap(); tool.execute(&input, Default::default()).await.unwrap();
let guard = ctx.lock().unwrap(); let guard = ctx.lock().unwrap();
assert_eq!(guard.references.len(), 1); assert_eq!(guard.references.len(), 1);
@ -823,7 +843,7 @@ mod tests {
state: Arc::new(SessionLogToolState { items }), state: Arc::new(SessionLogToolState { items }),
}); });
let input = serde_json::json!({ "query": "compact", "limit": 10 }).to_string(); let input = serde_json::json!({ "query": "compact", "limit": 10 }).to_string();
let out = tool.execute(&input).await.unwrap(); let out = tool.execute(&input, Default::default()).await.unwrap();
let content = out.content.unwrap(); let content = out.content.unwrap();
assert!(content.contains("investigate compact failure")); assert!(content.contains("investigate compact failure"));
@ -842,7 +862,7 @@ mod tests {
state: Arc::new(SessionLogToolState { items }), state: Arc::new(SessionLogToolState { items }),
}); });
let input = serde_json::json!({ "offset": 0, "limit": 1, "mode": "full" }).to_string(); let input = serde_json::json!({ "offset": 0, "limit": 1, "mode": "full" }).to_string();
let out = tool.execute(&input).await.unwrap(); let out = tool.execute(&input, Default::default()).await.unwrap();
let content = out.content.unwrap(); let content = out.content.unwrap();
assert!(content.contains("raw trace detail")); assert!(content.contains("raw trace detail"));

View File

@ -752,7 +752,11 @@ impl<St> Tool for ListPodsTool<St>
where where
St: PodMetadataStore + Clone + Send + Sync + 'static, St: PodMetadataStore + Clone + Send + Sync + 'static,
{ {
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
_input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let items = self let items = self
.discovery .discovery
.list_visible() .list_visible()
@ -775,7 +779,11 @@ impl<St> Tool for RestorePodTool<St>
where where
St: PodMetadataStore + Clone + Send + Sync + 'static, St: PodMetadataStore + Clone + Send + Sync + 'static,
{ {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let input: PodNameInput = serde_json::from_str(input_json) let input: PodNameInput = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid RestorePod input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid RestorePod input: {e}")))?;
let result = self let result = self
@ -847,7 +855,11 @@ impl<St> Tool for SendToPeerPodTool<St>
where where
St: PodMetadataStore + Clone + Send + Sync + 'static, St: PodMetadataStore + Clone + Send + Sync + 'static,
{ {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let input: SendToPeerPodInput = serde_json::from_str(input_json) let input: SendToPeerPodInput = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid SendToPeerPod input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid SendToPeerPod input: {e}")))?;
let detail = self let detail = self
@ -1392,7 +1404,7 @@ mod tests {
let (_, tool) = send_to_peer_pod_tool(discovery)(); let (_, tool) = send_to_peer_pod_tool(discovery)();
let output = tool let output = tool
.execute(r#"{"name":"target","message":"hello"}"#) .execute(r#"{"name":"target","message":"hello"}"#, Default::default())
.await .await
.unwrap(); .unwrap();
assert_eq!(output.summary, "sent peer message to `target`"); assert_eq!(output.summary, "sent peer message to `target`");

View File

@ -1292,7 +1292,11 @@ mod tests {
#[async_trait] #[async_trait]
impl Tool for DummyTool { impl Tool for DummyTool {
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
_input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
Ok(ToolOutput::from("ok".to_string())) Ok(ToolOutput::from("ok".to_string()))
} }
} }

View File

@ -73,7 +73,11 @@ step: leave the task as-is, summarize the problem to the user, and end the turn.
#[async_trait] #[async_trait]
impl Tool for TaskCreateTool { impl Tool for TaskCreateTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: TaskCreateParams = serde_json::from_str(input_json) let params: TaskCreateParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid TaskCreate input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid TaskCreate input: {e}")))?;
let created = self.store.create(params.subject, params.description); let created = self.store.create(params.subject, params.description);
@ -93,7 +97,11 @@ impl Tool for TaskCreateTool {
#[async_trait] #[async_trait]
impl Tool for TaskListTool { impl Tool for TaskListTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let _: TaskListParams = serde_json::from_str(input_json) let _: TaskListParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid TaskList input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid TaskList input: {e}")))?;
let tasks = self.store.list(); let tasks = self.store.list();
@ -106,7 +114,11 @@ impl Tool for TaskListTool {
#[async_trait] #[async_trait]
impl Tool for TaskGetTool { impl Tool for TaskGetTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: TaskGetParams = serde_json::from_str(input_json) let params: TaskGetParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid TaskGet input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid TaskGet input: {e}")))?;
let task = self.store.get(params.taskid).ok_or_else(|| { let task = self.store.get(params.taskid).ok_or_else(|| {
@ -122,7 +134,11 @@ impl Tool for TaskGetTool {
#[async_trait] #[async_trait]
impl Tool for TaskUpdateTool { impl Tool for TaskUpdateTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: TaskUpdateParams = serde_json::from_str(input_json) let params: TaskUpdateParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid TaskUpdate input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid TaskUpdate input: {e}")))?;
let updated = self let updated = self
@ -241,14 +257,20 @@ mod tests {
let update = tool(task_update_tool(store.clone())); let update = tool(task_update_tool(store.clone()));
let out = create let out = create
.execute(r#"{"subject":"implement","description":"write code"}"#) .execute(
r#"{"subject":"implement","description":"write code"}"#,
Default::default(),
)
.await .await
.unwrap(); .unwrap();
assert!(out.summary.contains("Created task 1")); assert!(out.summary.contains("Created task 1"));
assert_eq!(store.get(1).unwrap().status, TaskStatus::Pending); assert_eq!(store.get(1).unwrap().status, TaskStatus::Pending);
let out = update let out = update
.execute(r#"{"taskid":1,"status":"inprogress","subject":"implement tasks"}"#) .execute(
r#"{"taskid":1,"status":"inprogress","subject":"implement tasks"}"#,
Default::default(),
)
.await .await
.unwrap(); .unwrap();
assert!(out.summary.contains("Updated task 1")); assert!(out.summary.contains("Updated task 1"));
@ -256,11 +278,14 @@ mod tests {
assert_eq!(task.status, TaskStatus::Inprogress); assert_eq!(task.status, TaskStatus::Inprogress);
assert_eq!(task.subject, "implement tasks"); assert_eq!(task.subject, "implement tasks");
let out = get.execute(r#"{"taskid":1}"#).await.unwrap(); let out = get
.execute(r#"{"taskid":1}"#, Default::default())
.await
.unwrap();
assert!(out.summary.contains("Task 1 (inprogress)")); assert!(out.summary.contains("Task 1 (inprogress)"));
assert!(out.content.unwrap().contains("implement tasks")); assert!(out.content.unwrap().contains("implement tasks"));
let out = list.execute("{}").await.unwrap(); let out = list.execute("{}", Default::default()).await.unwrap();
assert!(out.summary.contains("1 task(s)")); assert!(out.summary.contains("1 task(s)"));
let content = out.content.unwrap(); let content = out.content.unwrap();
assert!(content.contains("\"taskid\": 1")); assert!(content.contains("\"taskid\": 1"));
@ -273,11 +298,14 @@ mod tests {
store.create("s".into(), "d".into()); store.create("s".into(), "d".into());
let update = tool(task_update_tool(store)); let update = tool(task_update_tool(store));
let err = update.execute(r#"{"taskid":1}"#).await.unwrap_err(); let err = update
.execute(r#"{"taskid":1}"#, Default::default())
.await
.unwrap_err();
assert!(err.to_string().contains("at least one")); assert!(err.to_string().contains("at least one"));
let err = update let err = update
.execute(r#"{"taskid":99,"status":"deleted"}"#) .execute(r#"{"taskid":99,"status":"deleted"}"#, Default::default())
.await .await
.unwrap_err(); .unwrap_err();
assert!(err.to_string().contains("taskid 99 not found")); assert!(err.to_string().contains("taskid 99 not found"));

View File

@ -491,6 +491,7 @@ mod tests {
}, },
meta, meta,
tool, tool,
context: llm_worker::tool::ToolExecutionContext::new("call-id", "test-batch", 0),
} }
} }
@ -898,6 +899,7 @@ mod tests {
), ),
meta: info.meta, meta: info.meta,
tool: info.tool, tool: info.tool,
context: info.context,
}; };
let action = interceptor.post_tool_call(&mut result_info).await; let action = interceptor.post_tool_call(&mut result_info).await;

View File

@ -62,7 +62,11 @@ struct SendToPodTool {
#[async_trait] #[async_trait]
impl Tool for SendToPodTool { impl Tool for SendToPodTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let input: SendToPodInput = serde_json::from_str(input_json) let input: SendToPodInput = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid SendToPod input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid SendToPod input: {e}")))?;
let record = self let record = self
@ -123,7 +127,11 @@ struct ReadPodOutputTool {
#[async_trait] #[async_trait]
impl Tool for ReadPodOutputTool { impl Tool for ReadPodOutputTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let input: NameInput = serde_json::from_str(input_json) let input: NameInput = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid ReadPodOutput input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid ReadPodOutput input: {e}")))?;
let record = self let record = self
@ -197,7 +205,11 @@ struct StopPodTool {
#[async_trait] #[async_trait]
impl Tool for StopPodTool { impl Tool for StopPodTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let input: NameInput = serde_json::from_str(input_json) let input: NameInput = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid StopPod input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid StopPod input: {e}")))?;
let record = self let record = self

View File

@ -298,7 +298,11 @@ impl SpawnPodTool {
#[async_trait] #[async_trait]
impl Tool for SpawnPodTool { impl Tool for SpawnPodTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let input: SpawnPodInput = serde_json::from_str(input_json) let input: SpawnPodInput = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid SpawnPod input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid SpawnPod input: {e}")))?;

View File

@ -1351,7 +1351,11 @@ struct HangingTool;
#[async_trait] #[async_trait]
impl Tool for HangingTool { impl Tool for HangingTool {
async fn execute(&self, _input: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
_input: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
std::future::pending::<()>().await; std::future::pending::<()>().await;
unreachable!() unreachable!()
} }

View File

@ -262,7 +262,7 @@ async fn send_to_pod_delivers_run_method() {
let def = send_to_pod_tool(registry); let def = send_to_pod_tool(registry);
let (_meta, tool) = def(); let (_meta, tool) = def();
let input = json!({ "name": "child", "message": "hello there" }).to_string(); let input = json!({ "name": "child", "message": "hello there" }).to_string();
let output: ToolOutput = tool.execute(&input).await.unwrap(); let output: ToolOutput = tool.execute(&input, Default::default()).await.unwrap();
assert!( assert!(
output.summary.contains("child"), output.summary.contains("child"),
"summary: {}", "summary: {}",
@ -285,7 +285,7 @@ async fn send_to_pod_errors_on_unknown_pod() {
let def = send_to_pod_tool(registry); let def = send_to_pod_tool(registry);
let (_meta, tool) = def(); let (_meta, tool) = def();
let input = json!({ "name": "nope", "message": "hi" }).to_string(); let input = json!({ "name": "nope", "message": "hi" }).to_string();
let err = tool.execute(&input).await.unwrap_err(); let err = tool.execute(&input, Default::default()).await.unwrap_err();
assert!(err.to_string().contains("no spawned pod"), "{err}"); assert!(err.to_string().contains("no spawned pod"), "{err}");
} }
@ -307,7 +307,7 @@ async fn send_to_pod_errors_when_pod_already_running() {
let def = send_to_pod_tool(registry); let def = send_to_pod_tool(registry);
let (_meta, tool) = def(); let (_meta, tool) = def();
let input = json!({ "name": "child", "message": "hi" }).to_string(); let input = json!({ "name": "child", "message": "hi" }).to_string();
let err = tool.execute(&input).await.unwrap_err(); let err = tool.execute(&input, Default::default()).await.unwrap_err();
assert!( assert!(
err.to_string().contains("already running"), err.to_string().contains("already running"),
"expected AlreadyRunning wording: {err}" "expected AlreadyRunning wording: {err}"
@ -341,13 +341,13 @@ async fn read_pod_output_returns_new_assistant_text_then_empty_on_second_call()
let (_meta, tool) = def(); let (_meta, tool) = def();
let input = json!({ "name": "child" }).to_string(); let input = json!({ "name": "child" }).to_string();
let first: ToolOutput = tool.execute(&input).await.unwrap(); let first: ToolOutput = tool.execute(&input, Default::default()).await.unwrap();
let body = first.content.expect("first read should have content"); let body = first.content.expect("first read should have content");
assert!(body.contains("hi back"), "body: {body}"); assert!(body.contains("hi back"), "body: {body}");
assert!(body.contains("still working"), "body: {body}"); assert!(body.contains("still working"), "body: {body}");
// Cursor now points past all items — second call returns no new text. // Cursor now points past all items — second call returns no new text.
let second: ToolOutput = tool.execute(&input).await.unwrap(); let second: ToolOutput = tool.execute(&input, Default::default()).await.unwrap();
assert!( assert!(
second.content.is_none(), second.content.is_none(),
"unexpected content: {:?}", "unexpected content: {:?}",
@ -371,7 +371,7 @@ async fn read_pod_output_reports_stopped_on_dead_socket() {
let def = read_pod_output_tool(registry); let def = read_pod_output_tool(registry);
let (_meta, tool) = def(); let (_meta, tool) = def();
let input = json!({ "name": "child" }).to_string(); let input = json!({ "name": "child" }).to_string();
let output: ToolOutput = tool.execute(&input).await.unwrap(); let output: ToolOutput = tool.execute(&input, Default::default()).await.unwrap();
assert!(output.summary.contains("stopped"), "{}", output.summary); assert!(output.summary.contains("stopped"), "{}", output.summary);
} }
@ -452,7 +452,7 @@ async fn stop_pod_sends_shutdown_and_releases_scope() {
let def = stop_pod_tool(registry.clone()); let def = stop_pod_tool(registry.clone());
let (_meta, tool) = def(); let (_meta, tool) = def();
let input = json!({ "name": "child" }).to_string(); let input = json!({ "name": "child" }).to_string();
let output: ToolOutput = tool.execute(&input).await.unwrap(); let output: ToolOutput = tool.execute(&input, Default::default()).await.unwrap();
assert!(output.summary.contains("stopped"), "{}", output.summary); assert!(output.summary.contains("stopped"), "{}", output.summary);
// The child got a Shutdown. // The child got a Shutdown.
@ -497,7 +497,7 @@ async fn stop_pod_succeeds_even_when_child_unreachable() {
let def = stop_pod_tool(registry.clone()); let def = stop_pod_tool(registry.clone());
let (_meta, tool) = def(); let (_meta, tool) = def();
let input = json!({ "name": "child" }).to_string(); let input = json!({ "name": "child" }).to_string();
let output: ToolOutput = tool.execute(&input).await.unwrap(); let output: ToolOutput = tool.execute(&input, Default::default()).await.unwrap();
assert!(output.summary.contains("stopped"), "{}", output.summary); assert!(output.summary.contains("stopped"), "{}", output.summary);
// Registry no longer knows about the child. // Registry no longer knows about the child.
@ -545,7 +545,7 @@ async fn restored_registry_uses_pod_state_without_runtime_file() {
let def = send_to_pod_tool(restored.clone()); let def = send_to_pod_tool(restored.clone());
let (_meta, tool) = def(); let (_meta, tool) = def();
let input = json!({ "name": "child", "message": "after restart" }).to_string(); let input = json!({ "name": "child", "message": "after restart" }).to_string();
tool.execute(&input).await.unwrap(); tool.execute(&input, Default::default()).await.unwrap();
match received.recv().await.expect("expected Run") { match received.recv().await.expect("expected Run") {
Method::Run { input } => match input.as_slice() { Method::Run { input } => match input.as_slice() {
[protocol::Segment::Text { content }] => assert_eq!(content, "after restart"), [protocol::Segment::Text { content }] => assert_eq!(content, "after restart"),
@ -556,7 +556,7 @@ async fn restored_registry_uses_pod_state_without_runtime_file() {
let def = stop_pod_tool(restored.clone()); let def = stop_pod_tool(restored.clone());
let (_meta, tool) = def(); let (_meta, tool) = def();
tool.execute(&json!({ "name": "child" }).to_string()) tool.execute(&json!({ "name": "child" }).to_string(), Default::default())
.await .await
.unwrap(); .unwrap();
assert!(matches!( assert!(matches!(

View File

@ -79,7 +79,11 @@ struct BigContentTool {
#[async_trait] #[async_trait]
impl Tool for BigContentTool { impl Tool for BigContentTool {
async fn execute(&self, _input: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
_input: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
Ok(ToolOutput { Ok(ToolOutput {
summary: self.summary.into(), summary: self.summary.into(),
content: Some(self.content.clone()), content: Some(self.content.clone()),

View File

@ -312,7 +312,7 @@ async fn spawn_pod_launches_runtime_in_workspace_and_passes_tool_cwd() {
}) })
.to_string(); .to_string();
tool.execute(&input).await.unwrap(); tool.execute(&input, Default::default()).await.unwrap();
assert!(matches!(received.await.unwrap(), Some(Method::Run { .. }))); assert!(matches!(received.await.unwrap(), Some(Method::Run { .. })));
let invocation = read_recorded_runtime_invocation(&output_path).await; let invocation = read_recorded_runtime_invocation(&output_path).await;
assert_eq!(invocation[0], allow_root.path().to_str().unwrap()); assert_eq!(invocation[0], allow_root.path().to_str().unwrap());
@ -373,7 +373,7 @@ async fn spawn_pod_omitted_cwd_preserves_spawner_pwd() {
}) })
.to_string(); .to_string();
tool.execute(&input).await.unwrap(); tool.execute(&input, Default::default()).await.unwrap();
assert!(matches!(received.await.unwrap(), Some(Method::Run { .. }))); assert!(matches!(received.await.unwrap(), Some(Method::Run { .. })));
let invocation = read_recorded_runtime_invocation(&output_path).await; let invocation = read_recorded_runtime_invocation(&output_path).await;
assert_eq!(invocation[0], allow_root.path().to_str().unwrap()); assert_eq!(invocation[0], allow_root.path().to_str().unwrap());
@ -433,7 +433,7 @@ async fn spawn_pod_delegates_scope_and_sends_run() {
.is_writable(&allow_root.path().join("a.txt")) .is_writable(&allow_root.path().join("a.txt"))
); );
let output: ToolOutput = tool.execute(&input).await.unwrap(); let output: ToolOutput = tool.execute(&input, Default::default()).await.unwrap();
assert!( assert!(
output.summary.contains("child"), output.summary.contains("child"),
"summary: {}", "summary: {}",
@ -519,7 +519,7 @@ async fn spawn_pod_requires_explicit_delegation_even_with_direct_scope() {
}) })
.to_string(); .to_string();
let err = tool.execute(&input).await.unwrap_err(); let err = tool.execute(&input, Default::default()).await.unwrap_err();
match err { match err {
ToolError::InvalidArgument(message) => { ToolError::InvalidArgument(message) => {
assert!(message.contains("no delegation scope grant"), "{message}"); assert!(message.contains("no delegation scope grant"), "{message}");
@ -587,7 +587,7 @@ async fn spawn_pod_rejects_child_non_recursive_scope_under_parent_non_recursive_
}) })
.to_string(); .to_string();
let err = tool.execute(&input).await.unwrap_err(); let err = tool.execute(&input, Default::default()).await.unwrap_err();
match err { match err {
ToolError::InvalidArgument(message) => { ToolError::InvalidArgument(message) => {
assert!( assert!(
@ -639,7 +639,7 @@ async fn spawn_pod_rejects_scope_outside_spawner() {
}) })
.to_string(); .to_string();
let err = tool.execute(&input).await.unwrap_err(); let err = tool.execute(&input, Default::default()).await.unwrap_err();
match err { match err {
ToolError::InvalidArgument(msg) => { ToolError::InvalidArgument(msg) => {
assert!( assert!(
@ -712,7 +712,7 @@ async fn spawn_pod_rolls_back_reservation_when_socket_never_appears() {
}) })
.to_string(); .to_string();
let err = tool.execute(&input).await.unwrap_err(); let err = tool.execute(&input, Default::default()).await.unwrap_err();
match err { match err {
ToolError::ExecutionFailed(msg) => { ToolError::ExecutionFailed(msg) => {
assert!( assert!(

View File

@ -54,7 +54,11 @@ struct MockWeatherTool;
#[async_trait] #[async_trait]
impl Tool for MockWeatherTool { impl Tool for MockWeatherTool {
async fn execute(&self, _input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
_input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
Ok("Sunny, 25C".to_string().into()) Ok("Sunny, 25C".to_string().into())
} }
} }

View File

@ -562,7 +562,11 @@ struct TicketDoctorTool {
#[async_trait] #[async_trait]
impl Tool for TicketCreateTool { impl Tool for TicketCreateTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: TicketCreateParams = parse_input("TicketCreate", input_json)?; let params: TicketCreateParams = parse_input("TicketCreate", input_json)?;
let mut input = NewTicket::new(params.title); let mut input = NewTicket::new(params.title);
if let Some(body) = params.body { if let Some(body) = params.body {
@ -594,7 +598,11 @@ impl Tool for TicketCreateTool {
#[async_trait] #[async_trait]
impl Tool for TicketListTool { impl Tool for TicketListTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: TicketListParams = parse_input("TicketList", input_json)?; let params: TicketListParams = parse_input("TicketList", input_json)?;
let state = params.state.unwrap_or(TicketListStateParam::All); let state = params.state.unwrap_or(TicketListStateParam::All);
let (filter, state_filter) = state.as_filter(); let (filter, state_filter) = state.as_filter();
@ -629,7 +637,11 @@ impl Tool for TicketListTool {
#[async_trait] #[async_trait]
impl Tool for TicketShowTool { impl Tool for TicketShowTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: TicketShowParams = parse_input("TicketShow", input_json)?; let params: TicketShowParams = parse_input("TicketShow", input_json)?;
let query = id_or_query(params.id, params.query)?; let query = id_or_query(params.id, params.query)?;
let event_limit = bounded(params.event_limit, DEFAULT_EVENT_LIMIT, MAX_EVENT_LIMIT); let event_limit = bounded(params.event_limit, DEFAULT_EVENT_LIMIT, MAX_EVENT_LIMIT);
@ -661,7 +673,11 @@ impl Tool for TicketShowTool {
#[async_trait] #[async_trait]
impl Tool for TicketCommentTool { impl Tool for TicketCommentTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: TicketCommentParams = parse_input("TicketComment", input_json)?; let params: TicketCommentParams = parse_input("TicketComment", input_json)?;
let kind = match params.role { let kind = match params.role {
TicketCommentRoleParam::Comment => TicketEventKind::Comment, TicketCommentRoleParam::Comment => TicketEventKind::Comment,
@ -684,7 +700,11 @@ impl Tool for TicketCommentTool {
#[async_trait] #[async_trait]
impl Tool for TicketReviewTool { impl Tool for TicketReviewTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: TicketReviewParams = parse_input("TicketReview", input_json)?; let params: TicketReviewParams = parse_input("TicketReview", input_json)?;
let result = match params.result { let result = match params.result {
TicketReviewResultParam::Approve => TicketReviewResult::Approve, TicketReviewResultParam::Approve => TicketReviewResult::Approve,
@ -708,7 +728,11 @@ impl Tool for TicketReviewTool {
#[async_trait] #[async_trait]
impl Tool for TicketIntakeReadyTool { impl Tool for TicketIntakeReadyTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: TicketIntakeReadyParams = parse_input("TicketIntakeReady", input_json)?; let params: TicketIntakeReadyParams = parse_input("TicketIntakeReady", input_json)?;
let from = TicketWorkflowState::Planning; let from = TicketWorkflowState::Planning;
let reason = params let reason = params
@ -743,7 +767,11 @@ impl Tool for TicketIntakeReadyTool {
#[async_trait] #[async_trait]
impl Tool for TicketWorkflowStateTool { impl Tool for TicketWorkflowStateTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: TicketWorkflowStateParams = parse_input("TicketWorkflowState", input_json)?; let params: TicketWorkflowStateParams = parse_input("TicketWorkflowState", input_json)?;
let from = params.from.into_state(); let from = params.from.into_state();
let to = params.to.into_state(); let to = params.to.into_state();
@ -778,7 +806,11 @@ impl Tool for TicketWorkflowStateTool {
#[async_trait] #[async_trait]
impl Tool for TicketCloseTool { impl Tool for TicketCloseTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: TicketCloseParams = parse_input("TicketClose", input_json)?; let params: TicketCloseParams = parse_input("TicketClose", input_json)?;
self.backend self.backend
.close( .close(
@ -795,7 +827,11 @@ impl Tool for TicketCloseTool {
#[async_trait] #[async_trait]
impl Tool for TicketRelationRecordTool { impl Tool for TicketRelationRecordTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: TicketRelationRecordParams = parse_input("TicketRelationRecord", input_json)?; let params: TicketRelationRecordParams = parse_input("TicketRelationRecord", input_json)?;
let relation = NewTicketRelation { let relation = NewTicketRelation {
kind: params.kind.into_kind(), kind: params.kind.into_kind(),
@ -819,7 +855,11 @@ impl Tool for TicketRelationRecordTool {
#[async_trait] #[async_trait]
impl Tool for TicketRelationQueryTool { impl Tool for TicketRelationQueryTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: TicketRelationQueryParams = parse_input("TicketRelationQuery", input_json)?; let params: TicketRelationQueryParams = parse_input("TicketRelationQuery", input_json)?;
let limit = bounded(params.limit, DEFAULT_LIST_LIMIT, MAX_LIST_LIMIT); let limit = bounded(params.limit, DEFAULT_LIST_LIMIT, MAX_LIST_LIMIT);
let ticket = params.ticket.clone().map(TicketIdOrSlug::Id); let ticket = params.ticket.clone().map(TicketIdOrSlug::Id);
@ -853,7 +893,11 @@ impl Tool for TicketRelationQueryTool {
#[async_trait] #[async_trait]
impl Tool for TicketOrchestrationPlanRecordTool { impl Tool for TicketOrchestrationPlanRecordTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: TicketOrchestrationPlanRecordParams = let params: TicketOrchestrationPlanRecordParams =
parse_input("TicketOrchestrationPlanRecord", input_json)?; parse_input("TicketOrchestrationPlanRecord", input_json)?;
let accepted_plan = params.accepted_plan.map(|plan| AcceptedOrchestrationPlan { let accepted_plan = params.accepted_plan.map(|plan| AcceptedOrchestrationPlan {
@ -885,7 +929,11 @@ impl Tool for TicketOrchestrationPlanRecordTool {
#[async_trait] #[async_trait]
impl Tool for TicketOrchestrationPlanQueryTool { impl Tool for TicketOrchestrationPlanQueryTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: TicketOrchestrationPlanQueryParams = let params: TicketOrchestrationPlanQueryParams =
parse_input("TicketOrchestrationPlanQuery", input_json)?; parse_input("TicketOrchestrationPlanQuery", input_json)?;
let limit = bounded(params.limit, DEFAULT_LIST_LIMIT, MAX_LIST_LIMIT); let limit = bounded(params.limit, DEFAULT_LIST_LIMIT, MAX_LIST_LIMIT);
@ -922,7 +970,11 @@ impl Tool for TicketOrchestrationPlanQueryTool {
#[async_trait] #[async_trait]
impl Tool for TicketDoctorTool { impl Tool for TicketDoctorTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: TicketDoctorParams = parse_input("TicketDoctor", input_json)?; let params: TicketDoctorParams = parse_input("TicketDoctor", input_json)?;
let limit = bounded(params.limit, DEFAULT_DIAGNOSTIC_LIMIT, MAX_DIAGNOSTIC_LIMIT); let limit = bounded(params.limit, DEFAULT_DIAGNOSTIC_LIMIT, MAX_DIAGNOSTIC_LIMIT);
let report = self let report = self
@ -1377,6 +1429,7 @@ mod tests {
"body": "## Background\n\nCreated by tool.\n" "body": "## Background\n\nCreated by tool.\n"
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
@ -1388,7 +1441,10 @@ mod tests {
assert!(!created_text.contains("needs_preflight")); assert!(!created_text.contains("needs_preflight"));
let listed = list let listed = list
.execute(&json!({ "state": "planning" }).to_string()) .execute(
&json!({ "state": "planning" }).to_string(),
Default::default(),
)
.await .await
.unwrap(); .unwrap();
assert!(listed.summary.contains("Listed 1 ticket")); assert!(listed.summary.contains("Listed 1 ticket"));
@ -1398,7 +1454,10 @@ mod tests {
assert!(!listed_content.contains("needs_preflight")); assert!(!listed_content.contains("needs_preflight"));
let shown = show let shown = show
.execute(&json!({ "id": id, "event_limit": 10 }).to_string()) .execute(
&json!({ "id": id, "event_limit": 10 }).to_string(),
Default::default(),
)
.await .await
.unwrap(); .unwrap();
assert!(shown.summary.contains(&id)); assert!(shown.summary.contains(&id));
@ -1407,7 +1466,10 @@ mod tests {
assert!(!shown_content.contains("legacy_ticket")); assert!(!shown_content.contains("legacy_ticket"));
assert!(!shown_content.contains("needs_preflight")); assert!(!shown_content.contains("needs_preflight"));
let report = doctor.execute(&json!({}).to_string()).await.unwrap(); let report = doctor
.execute(&json!({}).to_string(), Default::default())
.await
.unwrap();
assert!(report.summary.contains("0 error(s)")); assert!(report.summary.contains("0 error(s)"));
} }
@ -1431,6 +1493,7 @@ mod tests {
"author": "test" "author": "test"
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
@ -1440,7 +1503,10 @@ mod tests {
assert_eq!(recorded_json["target"], target.id); assert_eq!(recorded_json["target"], target.id);
let queried = query let queried = query
.execute(&json!({ "ticket": target.id.clone() }).to_string()) .execute(
&json!({ "ticket": target.id.clone() }).to_string(),
Default::default(),
)
.await .await
.unwrap(); .unwrap();
let queried_json: Value = serde_json::from_str(&queried.content.unwrap()).unwrap(); let queried_json: Value = serde_json::from_str(&queried.content.unwrap()).unwrap();
@ -1448,7 +1514,10 @@ mod tests {
assert_eq!(queried_json["relations"][0]["ticket_id"], source.id); assert_eq!(queried_json["relations"][0]["ticket_id"], source.id);
let shown = show let shown = show
.execute(&json!({ "id": target.id.clone() }).to_string()) .execute(
&json!({ "id": target.id.clone() }).to_string(),
Default::default(),
)
.await .await
.unwrap(); .unwrap();
let shown_json: Value = serde_json::from_str(&shown.content.unwrap()).unwrap(); let shown_json: Value = serde_json::from_str(&shown.content.unwrap()).unwrap();
@ -1476,6 +1545,7 @@ mod tests {
"body": "Implemented." "body": "Implemented."
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
@ -1487,6 +1557,7 @@ mod tests {
"body": "Looks good." "body": "Looks good."
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
@ -1494,11 +1565,15 @@ mod tests {
.execute( .execute(
&json!({ "ticket": created.id, "resolution": "Done via TicketClose.\n" }) &json!({ "ticket": created.id, "resolution": "Done via TicketClose.\n" })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
let report = doctor.execute(&json!({}).to_string()).await.unwrap(); let report = doctor
.execute(&json!({}).to_string(), Default::default())
.await
.unwrap();
assert!(report.summary.contains("0 error(s)")); assert!(report.summary.contains("0 error(s)"));
let closed = backend.show(TicketIdOrSlug::Id(created.id)).unwrap(); let closed = backend.show(TicketIdOrSlug::Id(created.id)).unwrap();
assert!(closed.resolution.is_some()); assert!(closed.resolution.is_some());
@ -1538,6 +1613,7 @@ mod tests {
"author": "intake-pod" "author": "intake-pod"
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
@ -1555,6 +1631,7 @@ mod tests {
"author": "orchestrator" "author": "orchestrator"
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
@ -1569,6 +1646,7 @@ mod tests {
"author": "orchestrator" "author": "orchestrator"
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
@ -1621,6 +1699,7 @@ mod tests {
"author": "orchestrator" "author": "orchestrator"
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
@ -1650,6 +1729,7 @@ mod tests {
"author": "orchestrator" "author": "orchestrator"
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
@ -1685,6 +1765,7 @@ mod tests {
"body": "Should not apply.\n" "body": "Should not apply.\n"
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap_err(); .unwrap_err();
@ -1717,6 +1798,7 @@ mod tests {
"body": "Should not bypass Queue.\n" "body": "Should not bypass Queue.\n"
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap_err(); .unwrap_err();
@ -1735,6 +1817,7 @@ mod tests {
"body": "Should not move backwards.\n" "body": "Should not move backwards.\n"
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap_err(); .unwrap_err();
@ -1753,6 +1836,7 @@ mod tests {
"body": "Should not skip inprogress.\n" "body": "Should not skip inprogress.\n"
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap_err(); .unwrap_err();
@ -1775,6 +1859,7 @@ mod tests {
"intake_summary": "Should not rewrite ready ticket." "intake_summary": "Should not rewrite ready ticket."
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap_err(); .unwrap_err();
@ -1807,6 +1892,7 @@ mod tests {
"author": "orchestrator" "author": "orchestrator"
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
@ -1823,6 +1909,7 @@ mod tests {
"relation_kind": "blocked_by" "relation_kind": "blocked_by"
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
@ -1840,7 +1927,10 @@ mod tests {
let temp = TempDir::new().unwrap(); let temp = TempDir::new().unwrap();
let show = tool_by_name(backend(&temp), "TicketShow"); let show = tool_by_name(backend(&temp), "TicketShow");
let error = show let error = show
.execute(&json!({ "id": "a", "query": "b" }).to_string()) .execute(
&json!({ "id": "a", "query": "b" }).to_string(),
Default::default(),
)
.await .await
.unwrap_err(); .unwrap_err();
assert!(matches!(error, ToolError::InvalidArgument(_))); assert!(matches!(error, ToolError::InvalidArgument(_)));
@ -1852,7 +1942,10 @@ mod tests {
let backend = backend(&temp); let backend = backend(&temp);
let create = tool_by_name(backend.clone(), "TicketCreate"); let create = tool_by_name(backend.clone(), "TicketCreate");
let output = create let output = create
.execute(&json!({ "title": "Escape" }).to_string()) .execute(
&json!({ "title": "Escape" }).to_string(),
Default::default(),
)
.await .await
.unwrap(); .unwrap();
let value: Value = serde_json::from_str(&output.content.unwrap()).unwrap(); let value: Value = serde_json::from_str(&output.content.unwrap()).unwrap();

View File

@ -101,7 +101,11 @@ impl Drop for BashTool {
#[async_trait] #[async_trait]
impl Tool for BashTool { impl Tool for BashTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: BashParams = serde_json::from_str(input_json) let params: BashParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid Bash input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid Bash input: {e}")))?;
let timeout_secs = params let timeout_secs = params
@ -394,7 +398,10 @@ mod tests {
assert_eq!(meta.name, "Bash"); assert_eq!(meta.name, "Bash");
let inp = serde_json::json!({ "command": "echo hello" }); let inp = serde_json::json!({ "command": "echo hello" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
assert_eq!(out.summary, "$ echo hello"); assert_eq!(out.summary, "$ echo hello");
assert_eq!(out.content.as_deref().map(str::trim), Some("hello")); assert_eq!(out.content.as_deref().map(str::trim), Some("hello"));
} }
@ -407,7 +414,10 @@ mod tests {
let inp = serde_json::json!({ let inp = serde_json::json!({
"command": "echo out; echo err 1>&2", "command": "echo out; echo err 1>&2",
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.unwrap(); let body = out.content.unwrap();
assert!(body.contains("out")); assert!(body.contains("out"));
assert!(body.contains("err")); assert!(body.contains("err"));
@ -419,7 +429,10 @@ mod tests {
let tool = make_tool(&h); let tool = make_tool(&h);
let inp = serde_json::json!({ "command": "exit 7" }); let inp = serde_json::json!({ "command": "exit 7" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
assert!(out.summary.contains("exit 7"), "summary: {}", out.summary); assert!(out.summary.contains("exit 7"), "summary: {}", out.summary);
assert!( assert!(
out.content.is_none(), out.content.is_none(),
@ -441,12 +454,16 @@ mod tests {
"command": format!("cd {}", sub.to_str().unwrap()), "command": format!("cd {}", sub.to_str().unwrap()),
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
let pwd_out = tool let pwd_out = tool
.execute(&serde_json::json!({ "command": "pwd" }).to_string()) .execute(
&serde_json::json!({ "command": "pwd" }).to_string(),
Default::default(),
)
.await .await
.unwrap(); .unwrap();
let body = pwd_out.content.unwrap(); let body = pwd_out.content.unwrap();
@ -467,7 +484,10 @@ mod tests {
"command": "sleep 30", "command": "sleep 30",
"timeout": 1, "timeout": 1,
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
assert!( assert!(
out.summary.contains("timed out"), out.summary.contains("timed out"),
"summary: {}", "summary: {}",
@ -480,7 +500,10 @@ mod tests {
let h = setup(); let h = setup();
let tool = make_tool(&h); let tool = make_tool(&h);
let err = tool.execute("not json").await.unwrap_err(); let err = tool
.execute("not json", Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
@ -494,7 +517,10 @@ mod tests {
let inp = serde_json::json!({ let inp = serde_json::json!({
"command": "for i in $(seq 1 200); do echo line $i; done", "command": "for i in $(seq 1 200); do echo line $i; done",
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.expect("expected content"); let body = out.content.expect("expected content");
assert!( assert!(
@ -523,7 +549,10 @@ mod tests {
let inp = serde_json::json!({ let inp = serde_json::json!({
"command": "printf 'x%.0s' {1..20480}", "command": "printf 'x%.0s' {1..20480}",
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.unwrap(); let body = out.content.unwrap();
assert!( assert!(
body.contains(spill_dir.to_str().unwrap()), body.contains(spill_dir.to_str().unwrap()),
@ -542,7 +571,10 @@ mod tests {
"command": "(sleep 0.05; echo bg) &", "command": "(sleep 0.05; echo bg) &",
"timeout": 5, "timeout": 5,
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
assert!( assert!(
!out.summary.contains("timed out"), !out.summary.contains("timed out"),
"summary: {}", "summary: {}",
@ -559,7 +591,9 @@ mod tests {
let inp = serde_json::json!({ let inp = serde_json::json!({
"command": "for i in $(seq 1 200); do echo $i; done", "command": "for i in $(seq 1 200); do echo $i; done",
}); });
tool.execute(&inp.to_string()).await.unwrap(); tool.execute(&inp.to_string(), Default::default())
.await
.unwrap();
// The spill dir should now contain exactly one bash-*.log file. // The spill dir should now contain exactly one bash-*.log file.
let files_before: Vec<_> = std::fs::read_dir(&spill_dir) let files_before: Vec<_> = std::fs::read_dir(&spill_dir)

View File

@ -36,7 +36,11 @@ pub(crate) struct EditTool {
#[async_trait] #[async_trait]
impl Tool for EditTool { impl Tool for EditTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: EditParams = serde_json::from_str(input_json) let params: EditParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid Edit input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid Edit input: {e}")))?;
@ -169,7 +173,10 @@ mod tests {
let def = read_tool(fs.clone(), tracker.clone()); let def = read_tool(fs.clone(), tracker.clone());
let (_, reader) = def(); let (_, reader) = def();
let inp = serde_json::json!({ "file_path": file.to_str().unwrap() }); let inp = serde_json::json!({ "file_path": file.to_str().unwrap() });
reader.execute(&inp.to_string()).await.unwrap(); reader
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
} }
#[tokio::test] #[tokio::test]
@ -188,7 +195,10 @@ mod tests {
"old_string": "foo bar", "old_string": "foo bar",
"new_string": "foo baz", "new_string": "foo baz",
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
assert!(out.summary.contains("1 replacement")); assert!(out.summary.contains("1 replacement"));
assert_eq!( assert_eq!(
std::fs::read_to_string(&file).unwrap(), std::fs::read_to_string(&file).unwrap(),
@ -212,7 +222,10 @@ mod tests {
"new_string": "y", "new_string": "y",
"replace_all": true, "replace_all": true,
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
assert!(out.summary.contains("3 replacements")); assert!(out.summary.contains("3 replacements"));
assert_eq!(std::fs::read_to_string(&file).unwrap(), "y y y\n"); assert_eq!(std::fs::read_to_string(&file).unwrap(), "y y y\n");
} }
@ -231,7 +244,10 @@ mod tests {
"old_string": "a", "old_string": "a",
"new_string": "b", "new_string": "b",
}); });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
@ -249,7 +265,10 @@ mod tests {
"old_string": "world", "old_string": "world",
"new_string": "x", "new_string": "x",
}); });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
@ -266,7 +285,10 @@ mod tests {
"old_string": "foo", "old_string": "foo",
"new_string": "bar", "new_string": "bar",
}); });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
@ -287,7 +309,10 @@ mod tests {
"old_string": "foo", "old_string": "foo",
"new_string": "bar", "new_string": "bar",
}); });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
let msg = format!("{err}"); let msg = format!("{err}");
assert!(msg.contains("modified externally"), "{msg}"); assert!(msg.contains("modified externally"), "{msg}");
} }

View File

@ -35,7 +35,11 @@ pub(crate) struct GlobTool {
#[async_trait] #[async_trait]
impl Tool for GlobTool { impl Tool for GlobTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: GlobParams = serde_json::from_str(input_json) let params: GlobParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid Glob input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid Glob input: {e}")))?;
@ -239,7 +243,10 @@ mod tests {
assert_eq!(meta.name, "Glob"); assert_eq!(meta.name, "Glob");
let inp = serde_json::json!({ "pattern": "**/*.rs" }); let inp = serde_json::json!({ "pattern": "**/*.rs" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
assert!(out.summary.contains("2 file(s)")); assert!(out.summary.contains("2 file(s)"));
let body = out.content.unwrap(); let body = out.content.unwrap();
assert!(body.contains("a.rs")); assert!(body.contains("a.rs"));
@ -261,7 +268,10 @@ mod tests {
let def = glob_tool(fs); let def = glob_tool(fs);
let (_, tool) = def(); let (_, tool) = def();
let inp = serde_json::json!({ "pattern": "*.rs" }); let inp = serde_json::json!({ "pattern": "*.rs" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.unwrap(); let body = out.content.unwrap();
let new_pos = body.find("new.rs").unwrap(); let new_pos = body.find("new.rs").unwrap();
let old_pos = body.find("old.rs").unwrap(); let old_pos = body.find("old.rs").unwrap();
@ -274,7 +284,10 @@ mod tests {
let def = glob_tool(fs); let def = glob_tool(fs);
let (_, tool) = def(); let (_, tool) = def();
let inp = serde_json::json!({ "pattern": "**/*.nonexistent" }); let inp = serde_json::json!({ "pattern": "**/*.nonexistent" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
assert!(out.summary.contains("No files")); assert!(out.summary.contains("No files"));
assert!(out.content.is_none()); assert!(out.content.is_none());
} }
@ -285,7 +298,10 @@ mod tests {
let def = glob_tool(fs); let def = glob_tool(fs);
let (_, tool) = def(); let (_, tool) = def();
let inp = serde_json::json!({ "pattern": "[unterminated" }); let inp = serde_json::json!({ "pattern": "[unterminated" });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
@ -317,7 +333,10 @@ mod tests {
let def = glob_tool(fs); let def = glob_tool(fs);
let (_, tool) = def(); let (_, tool) = def();
let inp = serde_json::json!({ "pattern": "**/*.rs" }); let inp = serde_json::json!({ "pattern": "**/*.rs" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.unwrap_or_default(); let body = out.content.unwrap_or_default();
assert!(body.contains("visible.rs")); assert!(body.contains("visible.rs"));
assert!( assert!(
@ -335,7 +354,10 @@ mod tests {
let def = glob_tool(fs); let def = glob_tool(fs);
let (_, tool) = def(); let (_, tool) = def();
let inp = serde_json::json!({ "pattern": "*.rs" }); let inp = serde_json::json!({ "pattern": "*.rs" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.unwrap(); let body = out.content.unwrap();
assert!(body.contains(".hidden.rs")); assert!(body.contains(".hidden.rs"));
assert!(body.contains("visible.rs")); assert!(body.contains("visible.rs"));
@ -358,7 +380,10 @@ mod tests {
"path": link.to_str().unwrap(), "path": link.to_str().unwrap(),
"pattern": "**/*.rs", "pattern": "**/*.rs",
}); });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
let msg = format!("{err}"); let msg = format!("{err}");
assert!( assert!(
msg.contains("Glob does not follow symlink directories"), msg.contains("Glob does not follow symlink directories"),

View File

@ -82,7 +82,11 @@ pub(crate) struct GrepTool {
#[async_trait] #[async_trait]
impl Tool for GrepTool { impl Tool for GrepTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: GrepParams = serde_json::from_str(input_json) let params: GrepParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid Grep input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid Grep input: {e}")))?;
@ -563,7 +567,10 @@ mod tests {
let def = grep_tool(scoped); let def = grep_tool(scoped);
let (_, tool) = def(); let (_, tool) = def();
let inp = serde_json::json!({ "pattern": "needle" }); let inp = serde_json::json!({ "pattern": "needle" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.unwrap_or_default(); let body = out.content.unwrap_or_default();
assert!(body.contains("visible.txt")); assert!(body.contains("visible.txt"));
assert!( assert!(
@ -583,7 +590,10 @@ mod tests {
assert_eq!(meta.name, "Grep"); assert_eq!(meta.name, "Grep");
let inp = serde_json::json!({ "pattern": "bravo" }); let inp = serde_json::json!({ "pattern": "bravo" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
assert!(out.summary.contains("1 file")); assert!(out.summary.contains("1 file"));
assert!(out.content.unwrap().contains("a.txt")); assert!(out.content.unwrap().contains("a.txt"));
} }
@ -599,7 +609,10 @@ mod tests {
"pattern": "two", "pattern": "two",
"output_mode": "content", "output_mode": "content",
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.unwrap(); let body = out.content.unwrap();
assert!(body.contains(":2:two")); assert!(body.contains(":2:two"));
} }
@ -616,7 +629,10 @@ mod tests {
"pattern": "x", "pattern": "x",
"output_mode": "count", "output_mode": "count",
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.unwrap(); let body = out.content.unwrap();
assert!(body.contains("a.txt:3")); assert!(body.contains("a.txt:3"));
assert!(body.contains("b.txt:1")); assert!(body.contains("b.txt:1"));
@ -635,7 +651,10 @@ mod tests {
"-i": true, "-i": true,
"output_mode": "content", "output_mode": "content",
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
assert!(out.content.unwrap().contains("HELLO")); assert!(out.content.unwrap().contains("HELLO"));
} }
@ -654,7 +673,10 @@ mod tests {
"output_mode": "content", "output_mode": "content",
"-C": 1, "-C": 1,
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.unwrap(); let body = out.content.unwrap();
// should contain: line2 (before context), MATCH, line4 (after context) // should contain: line2 (before context), MATCH, line4 (after context)
assert!(body.contains("line2")); assert!(body.contains("line2"));
@ -677,7 +699,10 @@ mod tests {
"multiline": true, "multiline": true,
"output_mode": "content", "output_mode": "content",
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.unwrap(); let body = out.content.unwrap();
assert!(body.contains("foo")); assert!(body.contains("foo"));
} }
@ -694,7 +719,10 @@ mod tests {
"pattern": "target", "pattern": "target",
"glob": "*.rs", "glob": "*.rs",
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.unwrap(); let body = out.content.unwrap();
assert!(body.contains("a.rs")); assert!(body.contains("a.rs"));
assert!(!body.contains("b.txt")); assert!(!body.contains("b.txt"));
@ -712,7 +740,10 @@ mod tests {
"pattern": "target", "pattern": "target",
"type": "rust", "type": "rust",
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.unwrap(); let body = out.content.unwrap();
assert!(body.contains("a.rs")); assert!(body.contains("a.rs"));
assert!(!body.contains("b.py")); assert!(!body.contains("b.py"));
@ -731,7 +762,10 @@ mod tests {
"pattern": "x", "pattern": "x",
"head_limit": 2, "head_limit": 2,
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.unwrap(); let body = out.content.unwrap();
assert_eq!(body.lines().count(), 2); assert_eq!(body.lines().count(), 2);
assert!(out.summary.contains("truncated at 2")); assert!(out.summary.contains("truncated at 2"));
@ -752,7 +786,10 @@ mod tests {
"offset": 3, "offset": 3,
"head_limit": 10, "head_limit": 10,
}); });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.unwrap(); let body = out.content.unwrap();
// We skipped 3, so only 2 should remain. // We skipped 3, so only 2 should remain.
assert_eq!(body.lines().count(), 2); assert_eq!(body.lines().count(), 2);
@ -769,7 +806,10 @@ mod tests {
let def = grep_tool(fs); let def = grep_tool(fs);
let (_, tool) = def(); let (_, tool) = def();
let inp = serde_json::json!({ "pattern": "needle" }); let inp = serde_json::json!({ "pattern": "needle" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
let body = out.content.unwrap(); let body = out.content.unwrap();
assert!(body.contains("b.txt")); assert!(body.contains("b.txt"));
assert!(!body.contains("a.bin")); assert!(!body.contains("a.bin"));
@ -781,7 +821,10 @@ mod tests {
let def = grep_tool(fs); let def = grep_tool(fs);
let (_, tool) = def(); let (_, tool) = def();
let inp = serde_json::json!({ "pattern": "(" }); let inp = serde_json::json!({ "pattern": "(" });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
@ -794,7 +837,10 @@ mod tests {
"pattern": "x", "pattern": "x",
"type": "nonexistent", "type": "nonexistent",
}); });
let err = tool.execute(&inp.to_string()).await.unwrap_err(); let err = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
@ -805,7 +851,10 @@ mod tests {
let def = grep_tool(fs); let def = grep_tool(fs);
let (_, tool) = def(); let (_, tool) = def();
let inp = serde_json::json!({ "pattern": "zzz" }); let inp = serde_json::json!({ "pattern": "zzz" });
let out = tool.execute(&inp.to_string()).await.unwrap(); let out = tool
.execute(&inp.to_string(), Default::default())
.await
.unwrap();
assert_eq!(out.summary, "No files matched"); assert_eq!(out.summary, "No files matched");
assert!(out.content.is_none()); assert!(out.content.is_none());
} }

View File

@ -36,7 +36,11 @@ pub(crate) struct ReadTool {
#[async_trait] #[async_trait]
impl Tool for ReadTool { impl Tool for ReadTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: ReadParams = serde_json::from_str(input_json) let params: ReadParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid Read input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid Read input: {e}")))?;
let offset = params.offset.unwrap_or(0); let offset = params.offset.unwrap_or(0);
@ -155,7 +159,10 @@ mod tests {
assert_eq!(meta.name, "Read"); assert_eq!(meta.name, "Read");
let input = serde_json::json!({ "file_path": file.to_str().unwrap() }); let input = serde_json::json!({ "file_path": file.to_str().unwrap() });
let out = tool.execute(&input.to_string()).await.unwrap(); let out = tool
.execute(&input.to_string(), Default::default())
.await
.unwrap();
assert!(out.summary.contains("Read 3 line(s)")); assert!(out.summary.contains("Read 3 line(s)"));
let body = out.content.unwrap(); let body = out.content.unwrap();
assert!(body.contains(" 1\talpha")); assert!(body.contains(" 1\talpha"));
@ -178,7 +185,10 @@ mod tests {
"offset": 1, "offset": 1,
"limit": 2, "limit": 2,
}); });
let out = tool.execute(&input.to_string()).await.unwrap(); let out = tool
.execute(&input.to_string(), Default::default())
.await
.unwrap();
assert!(out.summary.contains("[2..3] of 5")); assert!(out.summary.contains("[2..3] of 5"));
let body = out.content.unwrap(); let body = out.content.unwrap();
assert!(body.contains(" 2\t2")); assert!(body.contains(" 2\t2"));
@ -193,7 +203,10 @@ mod tests {
let input = serde_json::json!({ let input = serde_json::json!({
"file_path": dir.path().join("nope.txt").to_str().unwrap() "file_path": dir.path().join("nope.txt").to_str().unwrap()
}); });
let err = tool.execute(&input.to_string()).await.unwrap_err(); let err = tool
.execute(&input.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::ExecutionFailed(_))); assert!(matches!(err, ToolError::ExecutionFailed(_)));
} }
@ -202,7 +215,10 @@ mod tests {
let (_dir, fs, tracker) = setup(); let (_dir, fs, tracker) = setup();
let def = read_tool(fs, tracker); let def = read_tool(fs, tracker);
let (_, tool) = def(); let (_, tool) = def();
let err = tool.execute("not json").await.unwrap_err(); let err = tool
.execute("not json", Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
} }

View File

@ -146,7 +146,11 @@ struct WebFetchTool {
#[async_trait] #[async_trait]
impl Tool for WebSearchTool { impl Tool for WebSearchTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let input: WebSearchInput = serde_json::from_str(input_json) let input: WebSearchInput = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid WebSearch input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid WebSearch input: {e}")))?;
self.web.run_search(input).await self.web.run_search(input).await
@ -193,7 +197,11 @@ impl WebTools {
#[async_trait] #[async_trait]
impl Tool for WebFetchTool { impl Tool for WebFetchTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let input: WebFetchInput = serde_json::from_str(input_json) let input: WebFetchInput = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid WebFetch input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid WebFetch input: {e}")))?;
self.web.run_fetch(input).await self.web.run_fetch(input).await

View File

@ -30,7 +30,11 @@ pub(crate) struct WriteTool {
#[async_trait] #[async_trait]
impl Tool for WriteTool { impl Tool for WriteTool {
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> { async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: WriteParams = serde_json::from_str(input_json) let params: WriteParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid Write input: {e}")))?; .map_err(|e| ToolError::InvalidArgument(format!("invalid Write input: {e}")))?;
@ -118,7 +122,10 @@ mod tests {
"file_path": file.to_str().unwrap(), "file_path": file.to_str().unwrap(),
"content": "hello\n", "content": "hello\n",
}); });
let out = tool.execute(&input.to_string()).await.unwrap(); let out = tool
.execute(&input.to_string(), Default::default())
.await
.unwrap();
assert!(out.summary.contains("Created")); assert!(out.summary.contains("Created"));
assert_eq!(std::fs::read_to_string(&file).unwrap(), "hello\n"); assert_eq!(std::fs::read_to_string(&file).unwrap(), "hello\n");
} }
@ -135,7 +142,10 @@ mod tests {
"file_path": file.to_str().unwrap(), "file_path": file.to_str().unwrap(),
"content": "new", "content": "new",
}); });
let err = tool.execute(&input.to_string()).await.unwrap_err(); let err = tool
.execute(&input.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
@ -148,7 +158,10 @@ mod tests {
let read_def = read_tool(fs.clone(), tracker.clone()); let read_def = read_tool(fs.clone(), tracker.clone());
let (_, reader) = read_def(); let (_, reader) = read_def();
let read_in = serde_json::json!({ "file_path": file.to_str().unwrap() }); let read_in = serde_json::json!({ "file_path": file.to_str().unwrap() });
reader.execute(&read_in.to_string()).await.unwrap(); reader
.execute(&read_in.to_string(), Default::default())
.await
.unwrap();
let write_def = write_tool(fs, tracker); let write_def = write_tool(fs, tracker);
let (_, writer) = write_def(); let (_, writer) = write_def();
@ -156,7 +169,10 @@ mod tests {
"file_path": file.to_str().unwrap(), "file_path": file.to_str().unwrap(),
"content": "new\n", "content": "new\n",
}); });
let out = writer.execute(&write_in.to_string()).await.unwrap(); let out = writer
.execute(&write_in.to_string(), Default::default())
.await
.unwrap();
assert!(out.summary.contains("Overwrote")); assert!(out.summary.contains("Overwrote"));
assert_eq!(std::fs::read_to_string(&file).unwrap(), "new\n"); assert_eq!(std::fs::read_to_string(&file).unwrap(), "new\n");
} }
@ -171,7 +187,10 @@ mod tests {
let read_def = read_tool(fs.clone(), tracker.clone()); let read_def = read_tool(fs.clone(), tracker.clone());
let (_, reader) = read_def(); let (_, reader) = read_def();
reader reader
.execute(&serde_json::json!({ "file_path": file.to_str().unwrap() }).to_string()) .execute(
&serde_json::json!({ "file_path": file.to_str().unwrap() }).to_string(),
Default::default(),
)
.await .await
.unwrap(); .unwrap();
@ -187,6 +206,7 @@ mod tests {
"content": "new", "content": "new",
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap_err(); .unwrap_err();
@ -205,7 +225,10 @@ mod tests {
"file_path": outside.path().join("x.txt").to_str().unwrap(), "file_path": outside.path().join("x.txt").to_str().unwrap(),
"content": "x", "content": "x",
}); });
let err = tool.execute(&input.to_string()).await.unwrap_err(); let err = tool
.execute(&input.to_string(), Default::default())
.await
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_))); assert!(matches!(err, ToolError::InvalidArgument(_)));
} }
} }

View File

@ -66,13 +66,17 @@ async fn unicode_path_and_content() {
"content": content, "content": content,
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
let read = reg.get("Read"); let read = reg.get("Read");
let out = read let out = read
.execute(&json!({ "file_path": file.to_str().unwrap() }).to_string()) .execute(
&json!({ "file_path": file.to_str().unwrap() }).to_string(),
Default::default(),
)
.await .await
.unwrap(); .unwrap();
let body = out.content.unwrap(); let body = out.content.unwrap();
@ -98,7 +102,10 @@ async fn symlink_to_outside_scope_is_rejected_for_write() {
// target sits outside the scope. // target sits outside the scope.
let read = reg.get("Read"); let read = reg.get("Read");
let read_err = read let read_err = read
.execute(&json!({ "file_path": link.to_str().unwrap() }).to_string()) .execute(
&json!({ "file_path": link.to_str().unwrap() }).to_string(),
Default::default(),
)
.await .await
.unwrap_err(); .unwrap_err();
assert!( assert!(
@ -119,6 +126,7 @@ async fn symlink_to_outside_scope_is_rejected_for_write() {
"content": "overwritten", "content": "overwritten",
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap_err(); .unwrap_err();
@ -147,7 +155,10 @@ async fn broken_symlink_reports_target_and_repair_hint() {
let read = reg.get("Read"); let read = reg.get("Read");
let err = read let err = read
.execute(&json!({ "file_path": link.to_str().unwrap() }).to_string()) .execute(
&json!({ "file_path": link.to_str().unwrap() }).to_string(),
Default::default(),
)
.await .await
.unwrap_err(); .unwrap_err();
let msg = format!("{err}"); let msg = format!("{err}");
@ -165,7 +176,10 @@ async fn empty_file_read_and_edit() {
let read = reg.get("Read"); let read = reg.get("Read");
let out = read let out = read
.execute(&json!({ "file_path": file.to_str().unwrap() }).to_string()) .execute(
&json!({ "file_path": file.to_str().unwrap() }).to_string(),
Default::default(),
)
.await .await
.unwrap(); .unwrap();
assert!(out.summary.contains("0 line")); assert!(out.summary.contains("0 line"));
@ -180,6 +194,7 @@ async fn empty_file_read_and_edit() {
"new_string": "bar", "new_string": "bar",
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap_err(); .unwrap_err();
@ -196,7 +211,10 @@ async fn very_long_single_line() {
let read = reg.get("Read"); let read = reg.get("Read");
let out = read let out = read
.execute(&json!({ "file_path": file.to_str().unwrap() }).to_string()) .execute(
&json!({ "file_path": file.to_str().unwrap() }).to_string(),
Default::default(),
)
.await .await
.unwrap(); .unwrap();
// Should return exactly 1 line // Should return exactly 1 line
@ -208,7 +226,10 @@ async fn relative_path_is_rejected() {
let (_dir, _spill, reg) = setup(); let (_dir, _spill, reg) = setup();
let read = reg.get("Read"); let read = reg.get("Read");
let err = read let err = read
.execute(&json!({ "file_path": "relative.txt" }).to_string()) .execute(
&json!({ "file_path": "relative.txt" }).to_string(),
Default::default(),
)
.await .await
.unwrap_err(); .unwrap_err();
assert!(format!("{err}").contains("absolute")); assert!(format!("{err}").contains("absolute"));
@ -219,7 +240,10 @@ async fn directory_target_is_rejected_for_read() {
let (dir, _spill, reg) = setup(); let (dir, _spill, reg) = setup();
let read = reg.get("Read"); let read = reg.get("Read");
let err = read let err = read
.execute(&json!({ "file_path": dir.path().to_str().unwrap() }).to_string()) .execute(
&json!({ "file_path": dir.path().to_str().unwrap() }).to_string(),
Default::default(),
)
.await .await
.unwrap_err(); .unwrap_err();
assert!(format!("{err}").contains("directory")); assert!(format!("{err}").contains("directory"));
@ -237,6 +261,7 @@ async fn deeply_nested_new_file_is_created() {
"content": "deep\n", "content": "deep\n",
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
@ -250,7 +275,10 @@ async fn replace_preserves_unicode() {
std::fs::write(&file, "🦀 rust 🦀\n").unwrap(); std::fs::write(&file, "🦀 rust 🦀\n").unwrap();
let read = reg.get("Read"); let read = reg.get("Read");
read.execute(&json!({ "file_path": file.to_str().unwrap() }).to_string()) read.execute(
&json!({ "file_path": file.to_str().unwrap() }).to_string(),
Default::default(),
)
.await .await
.unwrap(); .unwrap();
@ -262,6 +290,7 @@ async fn replace_preserves_unicode() {
"new_string": "ラスト", "new_string": "ラスト",
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
@ -282,6 +311,7 @@ async fn grep_handles_unicode_pattern() {
"output_mode": "content", "output_mode": "content",
}) })
.to_string(), .to_string(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();

View File

@ -66,13 +66,13 @@ fn setup() -> (TempDir, TempDir, Registry) {
} }
async fn call(tool: &Arc<dyn Tool>, input: serde_json::Value) -> llm_worker::tool::ToolOutput { async fn call(tool: &Arc<dyn Tool>, input: serde_json::Value) -> llm_worker::tool::ToolOutput {
tool.execute(&input.to_string()) tool.execute(&input.to_string(), Default::default())
.await .await
.expect("tool execution failed") .expect("tool execution failed")
} }
async fn call_err(tool: &Arc<dyn Tool>, input: serde_json::Value) -> llm_worker::tool::ToolError { async fn call_err(tool: &Arc<dyn Tool>, input: serde_json::Value) -> llm_worker::tool::ToolError {
tool.execute(&input.to_string()) tool.execute(&input.to_string(), Default::default())
.await .await
.expect_err("expected error") .expect_err("expected error")
} }