diff --git a/crates/llm-worker-macros/src/lib.rs b/crates/llm-worker-macros/src/lib.rs index 9597bec4..2c8b8812 100644 --- a/crates/llm-worker-macros/src/lib.rs +++ b/crates/llm-worker-macros/src/lib.rs @@ -90,19 +90,27 @@ fn extract_doc_comment(attrs: &[Attribute]) -> String { /// Extract description from #[description = "..."] attribute fn extract_description_attr(attrs: &[syn::Attribute]) -> Option { for attr in attrs { - if attr.path().is_ident("description") { - if let Meta::NameValue(meta) = &attr.meta { - if let syn::Expr::Lit(expr_lit) = &meta.value { - if let Lit::Str(lit_str) = &expr_lit.lit { - return Some(lit_str.value()); - } - } - } + if attr.path().is_ident("description") + && let Meta::NameValue(meta) = &attr.meta + && let syn::Expr::Lit(expr_lit) = &meta.value + && let Lit::Str(lit_str) = &expr_lit.lit + { + return Some(lit_str.value()); } } None } +fn is_tool_execution_context_type(ty: &Type) -> bool { + let Type::Path(path) = ty else { + return false; + }; + path.path + .segments + .last() + .is_some_and(|segment| segment.ident == "ToolExecutionContext") +} + /// Generate Tool implementation from a method fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::TokenStream { let sig = &method.sig; @@ -123,8 +131,10 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2:: description }; - // Parse arguments (excluding self) - let args: Vec<_> = sig + // Parse method arguments (excluding self). A parameter typed as + // ToolExecutionContext is supplied from the execution context and is not + // exposed in the JSON input schema. + let method_args: Vec<_> = sig .inputs .iter() .filter_map(|arg| { @@ -135,9 +145,14 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2:: } }) .collect(); + let json_args: Vec<_> = method_args + .iter() + .copied() + .filter(|pat_type| !is_tool_execution_context_type(pat_type.ty.as_ref())) + .collect(); // Generate argument struct fields - let arg_fields: Vec<_> = args + let arg_fields: Vec<_> = json_args .iter() .map(|pat_type| { let pat = &pat_type.pat; @@ -165,11 +180,13 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2:: }) .collect(); - // Code to expand arguments in execute - let arg_names: Vec<_> = args + // Code to expand method arguments in execute + let call_args: Vec<_> = method_args .iter() .map(|pat_type| { - if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() { + if is_tool_execution_context_type(pat_type.ty.as_ref()) { + quote! { ctx.clone() } + } else if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() { let ident = &pat_ident.ident; quote! { args.#ident } } else { @@ -177,6 +194,11 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2:: } }) .collect(); + let method_call = if call_args.is_empty() { + quote! { self.ctx.#method_name() } + } else { + quote! { self.ctx.#method_name(#(#call_args),*) } + }; // Check if method is async let is_async = sig.asyncness.is_some(); @@ -218,13 +240,13 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2:: }; // Execute body handling for no arguments case - let execute_body = if args.is_empty() { + let execute_body = if json_args.is_empty() { quote! { - // Allow empty JSON object even with no arguments + // Allow empty JSON object even with no JSON arguments let _: #args_struct_name = serde_json::from_str(input_json) .unwrap_or(#args_struct_name {}); - let result = self.ctx.#method_name()#awaiter; + let result = #method_call #awaiter; #result_handling } } else { @@ -232,7 +254,7 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2:: let args: #args_struct_name = serde_json::from_str(input_json) .map_err(|e| ::llm_worker::tool::ToolError::InvalidArgument(e.to_string()))?; - let result = self.ctx.#method_name(#(#arg_names),*)#awaiter; + let result = #method_call #awaiter; #result_handling } }; @@ -247,7 +269,8 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2:: #[async_trait::async_trait] impl ::llm_worker::tool::Tool for #tool_struct_name { - async fn execute(&self, input_json: &str) -> Result<::llm_worker::tool::ToolOutput, ::llm_worker::tool::ToolError> { + async fn execute(&self, input_json: &str, ctx: ::llm_worker::tool::ToolExecutionContext) -> Result<::llm_worker::tool::ToolOutput, ::llm_worker::tool::ToolError> { + let _ = &ctx; #execute_body } } diff --git a/crates/llm-worker/src/interceptor.rs b/crates/llm-worker/src/interceptor.rs index 5b4b2e25..0cb926fc 100644 --- a/crates/llm-worker/src/interceptor.rs +++ b/crates/llm-worker/src/interceptor.rs @@ -10,7 +10,7 @@ use std::sync::Arc; use async_trait::async_trait; use crate::Item; -use crate::tool::{Tool, ToolCall, ToolMeta, ToolResult}; +use crate::tool::{Tool, ToolCall, ToolExecutionContext, ToolMeta, ToolResult}; // ============================================================================= // Action Enums @@ -107,6 +107,8 @@ pub struct ToolCallInfo { pub meta: ToolMeta, /// Tool instance (for state access). pub tool: Arc, + /// Response-local execution context for this call. + pub context: ToolExecutionContext, } /// Context for post-tool-call decisions. @@ -119,6 +121,8 @@ pub struct ToolResultInfo { pub meta: ToolMeta, /// Tool instance (for state access). pub tool: Arc, + /// Response-local execution context for this call. + pub context: ToolExecutionContext, } // ============================================================================= diff --git a/crates/llm-worker/src/lib.rs b/crates/llm-worker/src/lib.rs index cf4b861b..535cfc7b 100644 --- a/crates/llm-worker/src/lib.rs +++ b/crates/llm-worker/src/lib.rs @@ -57,7 +57,7 @@ pub use callback::{TextBlockScope, ThinkingBlockScope, ToolUseBlockScope}; pub use handler::ToolUseBlockStart; pub use interceptor::Interceptor; pub use message::{ContentPart, Item, Message, Role}; -pub use tool::{ToolCall, ToolOutputLimits, ToolResult}; +pub use tool::{ToolCall, ToolExecutionContext, ToolOutputLimits, ToolResult}; pub use usage_record::UsageRecord; pub use worker::{ LlmRetryNotice, RunOutput, ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult, diff --git a/crates/llm-worker/src/tool.rs b/crates/llm-worker/src/tool.rs index 0185954f..231c645d 100644 --- a/crates/llm-worker/src/tool.rs +++ b/crates/llm-worker/src/tool.rs @@ -189,6 +189,44 @@ impl ToolMeta { /// ``` pub type ToolDefinition = Arc (ToolMeta, Arc) + Send + Sync>; +/// Per-call context supplied by the worker when executing a tool call. +/// +/// The context identifies a tool call within one assistant response's tool-call +/// batch without imposing any scheduling policy on the worker. Tool +/// implementations may use it for response-local ordering, diagnostics, or +/// correlation, but it is intentionally not a handle to worker state, history, +/// or session mutation. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ToolExecutionContext { + /// Provider/tool-call id for the call being executed. + pub call_id: String, + /// Worker-local identity shared by all tool calls from one execution batch. + pub batch_id: String, + /// Zero-based order of this call in the model-returned tool-call list. + pub call_index: usize, +} + +impl ToolExecutionContext { + pub fn new(call_id: impl Into, batch_id: impl Into, call_index: usize) -> Self { + Self { + call_id: call_id.into(), + batch_id: batch_id.into(), + call_index, + } + } + + /// Context for direct, non-worker calls in unit tests and low-level callers. + pub fn direct() -> Self { + Self::new("direct", "direct", 0) + } +} + +impl Default for ToolExecutionContext { + fn default() -> Self { + Self::direct() + } +} + // ============================================================================= // Tool trait // ============================================================================= @@ -219,16 +257,16 @@ pub type ToolDefinition = Arc (ToolMeta, Arc) + Send + Syn /// # Manual Implementation /// /// ```ignore -/// use llm_worker::tool::{Tool, ToolError, ToolMeta, ToolDefinition}; +/// use llm_worker::tool::{Tool, ToolError, ToolExecutionContext, ToolMeta, ToolDefinition, ToolOutput}; /// use std::sync::Arc; /// /// struct MyTool { counter: std::sync::atomic::AtomicUsize } /// /// #[async_trait::async_trait] /// impl Tool for MyTool { -/// async fn execute(&self, input: &str) -> Result { +/// async fn execute(&self, input: &str, ctx: ToolExecutionContext) -> Result { /// self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst); -/// Ok("result".to_string()) +/// Ok(format!("call {}: {}", ctx.call_index, input).into()) /// } /// } /// @@ -247,11 +285,16 @@ pub trait Tool: Send + Sync { /// /// # Arguments /// * `input_json` - JSON-formatted arguments generated by LLM + /// * `ctx` - response-local call identity and ordering context /// /// # Returns /// A [`ToolOutput`] with summary and optional detailed content. /// For simple cases, use `From`: `Ok("done".to_string().into())` - async fn execute(&self, input_json: &str) -> Result; + async fn execute( + &self, + input_json: &str, + ctx: ToolExecutionContext, + ) -> Result; } // ============================================================================= diff --git a/crates/llm-worker/src/tool_server.rs b/crates/llm-worker/src/tool_server.rs index a7e73356..fb884df4 100644 --- a/crates/llm-worker/src/tool_server.rs +++ b/crates/llm-worker/src/tool_server.rs @@ -4,7 +4,9 @@ use std::sync::{Arc, Mutex}; use thiserror::Error; use crate::llm_client::ToolDefinition as LlmToolDefinition; -use crate::tool::{Tool, ToolDefinition as WorkerToolDefinition, ToolMeta, ToolOutput}; +use crate::tool::{ + Tool, ToolDefinition as WorkerToolDefinition, ToolExecutionContext, ToolMeta, ToolOutput, +}; type ToolMap = HashMap)>; @@ -117,6 +119,7 @@ impl ToolServerHandle { &self, name: &str, input_json: &str, + ctx: ToolExecutionContext, ) -> Result { let tool = { let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner()); @@ -125,7 +128,7 @@ impl ToolServerHandle { .ok_or_else(|| ToolServerError::ToolNotFound(name.to_string()))?; Arc::clone(tool) }; - tool.execute(input_json) + tool.execute(input_json, ctx) .await .map_err(|e| ToolServerError::ToolExecution(e.to_string())) } @@ -187,7 +190,11 @@ mod tests { #[async_trait] impl Tool for EchoTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: crate::tool::ToolExecutionContext, + ) -> Result { Ok(input_json.to_string().into()) } } @@ -236,12 +243,15 @@ mod tests { handle.register_tool(def("echo")); handle.flush_pending(); - let out = handle.call_tool("echo", r#"{"x":1}"#).await.expect("call"); + let out = handle + .call_tool("echo", r#"{"x":1}"#, Default::default()) + .await + .expect("call"); assert_eq!(out.summary, r#"{"x":1}"#); assert!(out.content.is_none()); let err = handle - .call_tool("missing", "{}") + .call_tool("missing", "{}", Default::default()) .await .expect_err("missing tool"); assert_eq!(err, ToolServerError::ToolNotFound("missing".to_string())); @@ -298,7 +308,11 @@ mod tests { #[async_trait] impl Tool for FixedTool { - async fn execute(&self, _input_json: &str) -> Result { + async fn execute( + &self, + _input_json: &str, + _ctx: crate::tool::ToolExecutionContext, + ) -> Result { Ok("replaced".to_string().into()) } } @@ -327,7 +341,11 @@ mod tests { #[async_trait] impl Tool for ConstTool { - async fn execute(&self, _input_json: &str) -> Result { + async fn execute( + &self, + _input_json: &str, + _ctx: crate::tool::ToolExecutionContext, + ) -> Result { Ok("const".to_string().into()) } } @@ -342,7 +360,10 @@ mod tests { }); handle.replace(replacement).expect("replace"); - let out = handle.call_tool("echo", "{}").await.expect("call"); + let out = handle + .call_tool("echo", "{}", Default::default()) + .await + .expect("call"); assert_eq!(out.summary, "const"); } @@ -360,7 +381,11 @@ mod tests { #[async_trait] impl Tool for GatedTool { - async fn execute(&self, _input_json: &str) -> Result { + async fn execute( + &self, + _input_json: &str, + _ctx: crate::tool::ToolExecutionContext, + ) -> Result { self.started.notify_one(); self.finish.notified().await; Ok("done".to_string().into()) @@ -384,7 +409,7 @@ mod tests { handle.flush_pending(); let h = handle.clone(); - let call = tokio::spawn(async move { h.call_tool("slow", "{}").await }); + let call = tokio::spawn(async move { h.call_tool("slow", "{}", Default::default()).await }); // Wait until the tool is actually executing. started.notified().await; @@ -413,7 +438,11 @@ mod tests { #[async_trait] impl Tool for OldTool { - async fn execute(&self, _input_json: &str) -> Result { + async fn execute( + &self, + _input_json: &str, + _ctx: crate::tool::ToolExecutionContext, + ) -> Result { self.started.notify_one(); self.finish.notified().await; Ok("old".to_string().into()) @@ -437,7 +466,7 @@ mod tests { handle.flush_pending(); let h = handle.clone(); - let call = tokio::spawn(async move { h.call_tool("t", "{}").await }); + let call = tokio::spawn(async move { h.call_tool("t", "{}", Default::default()).await }); // Wait until the old tool is mid-execution. started.notified().await; @@ -447,7 +476,11 @@ mod tests { #[async_trait] impl Tool for NewTool { - async fn execute(&self, _input_json: &str) -> Result { + async fn execute( + &self, + _input_json: &str, + _ctx: crate::tool::ToolExecutionContext, + ) -> Result { Ok("new".to_string().into()) } } @@ -469,7 +502,10 @@ mod tests { assert_eq!(result.expect("call").summary, "old"); // New calls use the replacement. - let out = handle.call_tool("t", "{}").await.expect("call"); + let out = handle + .call_tool("t", "{}", Default::default()) + .await + .expect("call"); assert_eq!(out.summary, "new"); } diff --git a/crates/llm-worker/src/worker.rs b/crates/llm-worker/src/worker.rs index 7c0ccc80..8310bdba 100644 --- a/crates/llm-worker/src/worker.rs +++ b/crates/llm-worker/src/worker.rs @@ -26,8 +26,8 @@ use crate::{ timeline::event::{ErrorEvent, StatusEvent, UsageEvent}, timeline::{TextBlockCollector, ThinkingBlockCollector, Timeline, ToolCallCollector}, tool::{ - ToolCall, ToolDefinition as WorkerToolDefinition, ToolError, ToolOutputLimits, ToolResult, - truncate_content, + ToolCall, ToolDefinition as WorkerToolDefinition, ToolError, ToolExecutionContext, + ToolOutputLimits, ToolResult, truncate_content, }, tool_server::{ToolServer, ToolServerHandle}, }; @@ -187,6 +187,10 @@ pub struct Worker { /// LlmCall count (per-Worker running counter, monotonic). Unlike /// `turn_count` this never collapses retries. llm_call_count: usize, + /// Tool execution batch count (per-Worker running counter, monotonic). + /// Each batch corresponds to one collected assistant tool-call set or one + /// resumed pending tool-call set. + tool_execution_batch_count: usize, /// Maximum number of AgentTurns (None = unlimited) max_turns: Option, /// AgentTurn-start callbacks (1:1 with LlmCall today) @@ -912,19 +916,23 @@ impl Worker { ) -> Result { use futures::future::join_all; - // Map from tool call ID to (ToolCall, Meta, Tool) + // Map from tool call ID to (ToolCall, Meta, Tool, Context) // Retained because it's needed for PostToolCall hooks let mut call_info_map = HashMap::new(); let mut synthetic_results = Vec::new(); + let batch_id = format!("tool-batch-{}", self.tool_execution_batch_count); + self.tool_execution_batch_count += 1; // Phase 1: Apply pre_tool_call interceptor (determine skip/abort/synthetic result) let mut approved_calls = Vec::new(); - for mut tool_call in tool_calls { + for (call_index, mut tool_call) in tool_calls.into_iter().enumerate() { + let context = ToolExecutionContext::new(&tool_call.id, &batch_id, call_index); if let Some((meta, tool)) = self.tool_server.get_tool(&tool_call.name) { let mut info = ToolCallInfo { call: tool_call.clone(), meta, tool, + context, }; match self.interceptor.pre_tool_call(&mut info).await { @@ -934,9 +942,11 @@ impl Worker { } PreToolAction::SyntheticResult(result) => { let tool_call = info.call; + let mut context = info.context; + context.call_id = tool_call.id.clone(); call_info_map.insert( tool_call.id.clone(), - (tool_call, info.meta.clone(), info.tool.clone()), + (tool_call, info.meta.clone(), info.tool.clone(), context), ); synthetic_results.push(result); continue; @@ -953,26 +963,37 @@ impl Worker { // Reflect changes made by interceptor tool_call = info.call; + let mut context = info.context; + context.call_id = tool_call.id.clone(); call_info_map.insert( tool_call.id.clone(), - (tool_call.clone(), info.meta.clone(), info.tool.clone()), + ( + tool_call.clone(), + info.meta.clone(), + info.tool.clone(), + context.clone(), + ), ); - approved_calls.push(tool_call); + approved_calls.push((tool_call, context)); } else { // Unknown tools go into approved list as-is (will error at execution) - approved_calls.push(tool_call); + let context = ToolExecutionContext::new(&tool_call.id, &batch_id, call_index); + approved_calls.push((tool_call, context)); } } // Phase 2: Execute approved tools in parallel (cancellable) let futures: Vec<_> = approved_calls .into_iter() - .map(|tool_call| { + .map(|(tool_call, context)| { let tool_server = self.tool_server.clone(); async move { let input_json = serde_json::to_string(&tool_call.input).unwrap_or_default(); - match tool_server.call_tool(&tool_call.name, &input_json).await { + match tool_server + .call_tool(&tool_call.name, &input_json, context) + .await + { Ok(output) => ToolResult::from_output(&tool_call.id, output), Err(e) => ToolResult::error(&tool_call.id, e.to_string()), } @@ -996,12 +1017,15 @@ impl Worker { // Phase 3: Apply post_tool_call interceptor for tool_result in &mut results { - if let Some((tool_call, meta, tool)) = call_info_map.get(&tool_result.tool_use_id) { + if let Some((tool_call, meta, tool, context)) = + call_info_map.get(&tool_result.tool_use_id) + { let mut info = ToolResultInfo { call: tool_call.clone(), result: tool_result.clone(), meta: meta.clone(), tool: tool.clone(), + context: context.clone(), }; match self.interceptor.post_tool_call(&mut info).await { @@ -1026,7 +1050,7 @@ impl Worker { let Some(content) = tool_result.content.as_mut() else { continue; }; - let Some((tool_call, _, _)) = call_info_map.get(&tool_result.tool_use_id) else { + let Some((tool_call, _, _, _)) = call_info_map.get(&tool_result.tool_use_id) else { continue; }; let limit = limits.limit_for(&tool_call.name); @@ -1628,6 +1652,7 @@ impl Worker { locked_prefix_len: 0, turn_count: 0, llm_call_count: 0, + tool_execution_batch_count: 0, max_turns: None, turn_start_cbs: Vec::new(), turn_end_cbs: Vec::new(), @@ -1892,6 +1917,7 @@ impl Worker { locked_prefix_len, turn_count: self.turn_count, llm_call_count: self.llm_call_count, + tool_execution_batch_count: self.tool_execution_batch_count, max_turns: self.max_turns, turn_start_cbs: self.turn_start_cbs, turn_end_cbs: self.turn_end_cbs, @@ -1984,6 +2010,7 @@ impl Worker { locked_prefix_len: 0, turn_count: self.turn_count, llm_call_count: self.llm_call_count, + tool_execution_batch_count: self.tool_execution_batch_count, max_turns: self.max_turns, turn_start_cbs: self.turn_start_cbs, turn_end_cbs: self.turn_end_cbs, diff --git a/crates/llm-worker/tests/callback_test.rs b/crates/llm-worker/tests/callback_test.rs index 107c558f..20aa3bb8 100644 --- a/crates/llm-worker/tests/callback_test.rs +++ b/crates/llm-worker/tests/callback_test.rs @@ -218,7 +218,11 @@ struct FixedOutputTool { #[async_trait] impl Tool for FixedOutputTool { - async fn execute(&self, _input_json: &str) -> Result { + async fn execute( + &self, + _input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { Ok(self.output.clone()) } } @@ -289,7 +293,11 @@ struct ErroringTool { #[async_trait] impl Tool for ErroringTool { - async fn execute(&self, _input_json: &str) -> Result { + async fn execute( + &self, + _input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { Err(ToolError::ExecutionFailed(self.message.clone())) } } diff --git a/crates/llm-worker/tests/parallel_execution_test.rs b/crates/llm-worker/tests/parallel_execution_test.rs index 2e36c991..2000bfa3 100644 --- a/crates/llm-worker/tests/parallel_execution_test.rs +++ b/crates/llm-worker/tests/parallel_execution_test.rs @@ -2,8 +2,8 @@ //! //! Verify that Worker executes multiple tools in parallel. -use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use async_trait::async_trait; @@ -12,7 +12,9 @@ use llm_worker::interceptor::{ Interceptor, PostToolAction, PreToolAction, ToolCallInfo, ToolResultInfo, }; use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent}; -use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta, ToolOutput, ToolResult}; +use llm_worker::tool::{ + Tool, ToolDefinition, ToolError, ToolExecutionContext, ToolMeta, ToolOutput, ToolResult, +}; mod common; use common::MockLlmClient; @@ -59,13 +61,54 @@ impl SlowTool { #[async_trait] impl Tool for SlowTool { - async fn execute(&self, _input_json: &str) -> Result { + async fn execute( + &self, + _input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { self.call_count.fetch_add(1, Ordering::SeqCst); tokio::time::sleep(Duration::from_millis(self.delay_ms)).await; Ok(format!("Completed after {}ms", self.delay_ms).into()) } } +#[derive(Clone)] +struct ContextRecordingTool { + name: String, + contexts: Arc>>, +} + +impl ContextRecordingTool { + fn new(name: impl Into, contexts: Arc>>) -> Self { + Self { + name: name.into(), + contexts, + } + } + + fn definition(&self) -> ToolDefinition { + let tool = self.clone(); + Arc::new(move || { + let meta = ToolMeta::new(&tool.name) + .description("Records tool execution context") + .input_schema(serde_json::json!({"type": "object"})); + (meta, Arc::new(tool.clone()) as Arc) + }) + } +} + +#[async_trait] +impl Tool for ContextRecordingTool { + async fn execute( + &self, + _input_json: &str, + ctx: ToolExecutionContext, + ) -> Result { + self.contexts.lock().unwrap().push(ctx); + Ok("recorded".to_string().into()) + } +} + // ============================================================================= // Tests // ============================================================================= @@ -92,10 +135,18 @@ async fn test_parallel_tool_execution() { }), ]; - let client = MockLlmClient::new(events); + let client = MockLlmClient::with_responses(vec![ + events, + vec![ + Event::text_block_start(0), + Event::text_delta(0, "Done"), + Event::text_block_stop(0, None), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + ]); let mut worker = Worker::new(client); - - // Each tool waits 100ms let tool1 = SlowTool::new("slow_tool_1", 100); let tool2 = SlowTool::new("slow_tool_2", 100); let tool3 = SlowTool::new("slow_tool_3", 100); @@ -129,7 +180,201 @@ async fn test_parallel_tool_execution() { println!("Parallel execution completed in {:?}", elapsed); } -/// Hook: pre_tool_call - verify that skipped tools are not executed +#[tokio::test] +async fn test_tool_execution_context_order_and_batch_id() { + let client = MockLlmClient::with_responses(vec![ + vec![ + Event::tool_use_start(0, "call_a", "record_a"), + Event::tool_input_delta(0, r#"{}"#), + Event::tool_use_stop(0), + Event::tool_use_start(1, "call_b", "record_b"), + Event::tool_input_delta(1, r#"{}"#), + Event::tool_use_stop(1), + Event::tool_use_start(2, "call_c", "record_c"), + Event::tool_input_delta(2, r#"{}"#), + Event::tool_use_stop(2), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + vec![ + Event::text_block_start(0), + Event::text_delta(0, "Done"), + Event::text_block_stop(0, None), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + ]); + let mut worker = Worker::new(client); + let contexts = Arc::new(Mutex::new(Vec::new())); + + worker.register_tool(ContextRecordingTool::new("record_a", contexts.clone()).definition()); + worker.register_tool(ContextRecordingTool::new("record_b", contexts.clone()).definition()); + worker.register_tool(ContextRecordingTool::new("record_c", contexts.clone()).definition()); + + let _ = worker.run("record contexts").await; + + let mut contexts = contexts.lock().unwrap().clone(); + contexts.sort_by_key(|ctx| ctx.call_index); + + assert_eq!(contexts.len(), 3); + assert_eq!(contexts[0].call_id, "call_a"); + assert_eq!(contexts[0].call_index, 0); + assert_eq!(contexts[1].call_id, "call_b"); + assert_eq!(contexts[1].call_index, 1); + assert_eq!(contexts[2].call_id, "call_c"); + assert_eq!(contexts[2].call_index, 2); + assert_eq!(contexts[0].batch_id, contexts[1].batch_id); + assert_eq!(contexts[1].batch_id, contexts[2].batch_id); +} + +#[tokio::test] +async fn test_tool_execution_context_batch_id_changes_between_batches() { + let client = MockLlmClient::with_responses(vec![ + vec![ + Event::tool_use_start(0, "call_first", "record"), + Event::tool_input_delta(0, r#"{}"#), + Event::tool_use_stop(0), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + vec![ + Event::tool_use_start(0, "call_second", "record"), + Event::tool_input_delta(0, r#"{}"#), + Event::tool_use_stop(0), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + vec![ + Event::text_block_start(0), + Event::text_delta(0, "Done"), + Event::text_block_stop(0, None), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + ]); + let mut worker = Worker::new(client); + let contexts = Arc::new(Mutex::new(Vec::new())); + + worker.register_tool(ContextRecordingTool::new("record", contexts.clone()).definition()); + + let _ = worker.run("record batches").await; + + let contexts = contexts.lock().unwrap().clone(); + assert_eq!(contexts.len(), 2); + assert_eq!(contexts[0].call_id, "call_first"); + assert_eq!(contexts[0].call_index, 0); + assert_eq!(contexts[1].call_id, "call_second"); + assert_eq!(contexts[1].call_index, 0); + assert_ne!(contexts[0].batch_id, contexts[1].batch_id); +} + +#[tokio::test] +async fn test_tool_execution_context_for_skipped_and_synthetic_paths() { + let client = MockLlmClient::with_responses(vec![ + vec![ + Event::tool_use_start(0, "call_run", "record"), + Event::tool_input_delta(0, r#"{}"#), + Event::tool_use_stop(0), + Event::tool_use_start(1, "call_skip", "skip_tool"), + Event::tool_input_delta(1, r#"{}"#), + Event::tool_use_stop(1), + Event::tool_use_start(2, "call_synth", "synthetic_tool"), + Event::tool_input_delta(2, r#"{}"#), + Event::tool_use_stop(2), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + vec![ + Event::text_block_start(0), + Event::text_delta(0, "Done"), + Event::text_block_stop(0, None), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + ]); + let mut worker = Worker::new(client); + let executed_contexts = Arc::new(Mutex::new(Vec::new())); + let pre_contexts = Arc::new(Mutex::new(Vec::new())); + let post_contexts = Arc::new(Mutex::new(Vec::new())); + + worker + .register_tool(ContextRecordingTool::new("record", executed_contexts.clone()).definition()); + worker.register_tool( + ContextRecordingTool::new("skip_tool", executed_contexts.clone()).definition(), + ); + worker.register_tool( + ContextRecordingTool::new("synthetic_tool", executed_contexts.clone()).definition(), + ); + + struct ContextPolicy { + pre_contexts: Arc>>, + post_contexts: Arc>>, + } + + #[async_trait] + impl Interceptor for ContextPolicy { + async fn pre_tool_call(&self, info: &mut ToolCallInfo) -> PreToolAction { + self.pre_contexts.lock().unwrap().push(info.context.clone()); + match info.call.name.as_str() { + "skip_tool" => PreToolAction::Skip, + "synthetic_tool" => PreToolAction::SyntheticResult(ToolResult::from_output( + &info.call.id, + ToolOutput::from("synthetic result".to_string()), + )), + _ => PreToolAction::Continue, + } + } + + async fn post_tool_call(&self, info: &mut ToolResultInfo) -> PostToolAction { + self.post_contexts + .lock() + .unwrap() + .push(info.context.clone()); + PostToolAction::Continue + } + } + + worker.set_interceptor(ContextPolicy { + pre_contexts: pre_contexts.clone(), + post_contexts: post_contexts.clone(), + }); + + let _ = worker.run("record skipped and synthetic contexts").await; + + let mut pre_contexts = pre_contexts.lock().unwrap().clone(); + pre_contexts.sort_by_key(|ctx| ctx.call_index); + assert_eq!(pre_contexts.len(), 3); + assert_eq!(pre_contexts[0].call_id, "call_run"); + assert_eq!(pre_contexts[0].call_index, 0); + assert_eq!(pre_contexts[1].call_id, "call_skip"); + assert_eq!(pre_contexts[1].call_index, 1); + assert_eq!(pre_contexts[2].call_id, "call_synth"); + assert_eq!(pre_contexts[2].call_index, 2); + assert_eq!(pre_contexts[0].batch_id, pre_contexts[1].batch_id); + assert_eq!(pre_contexts[1].batch_id, pre_contexts[2].batch_id); + + let executed_contexts = executed_contexts.lock().unwrap().clone(); + assert_eq!(executed_contexts.len(), 1); + assert_eq!(executed_contexts[0].call_id, "call_run"); + assert_eq!(executed_contexts[0].call_index, 0); + + let mut post_contexts = post_contexts.lock().unwrap().clone(); + post_contexts.sort_by_key(|ctx| ctx.call_index); + assert_eq!(post_contexts.len(), 2); + assert_eq!(post_contexts[0].call_id, "call_run"); + assert_eq!(post_contexts[0].call_index, 0); + assert_eq!(post_contexts[1].call_id, "call_synth"); + assert_eq!(post_contexts[1].call_index, 2); + assert_eq!(post_contexts[0].batch_id, post_contexts[1].batch_id); +} + #[tokio::test] async fn test_before_tool_call_skip() { let events = vec![ @@ -220,7 +465,11 @@ async fn test_post_tool_call_modification() { #[async_trait] impl Tool for SimpleTool { - async fn execute(&self, _: &str) -> Result { + async fn execute( + &self, + _: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { Ok("Original Result".to_string().into()) } } diff --git a/crates/llm-worker/tests/tool_macro_test.rs b/crates/llm-worker/tests/tool_macro_test.rs index 3d326a98..6faab77a 100644 --- a/crates/llm-worker/tests/tool_macro_test.rs +++ b/crates/llm-worker/tests/tool_macro_test.rs @@ -9,6 +9,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use schemars; use serde; +use llm_worker::ToolExecutionContext; use llm_worker_macros::tool_registry; // ============================================================================= @@ -42,6 +43,15 @@ impl SimpleContext { async fn get_prefix(&self) -> String { self.prefix.clone() } + + /// Tool that observes execution context + #[tool] + async fn context_echo(&self, ctx: ToolExecutionContext, message: String) -> String { + format!( + "{}:{}:{}:{}", + ctx.batch_id, ctx.call_index, ctx.call_id, message + ) + } } #[tokio::test] @@ -74,7 +84,9 @@ async fn test_basic_tool_generation() { ); // Execution test - let result = tool.execute(r#"{"message": "World"}"#).await; + let result = tool + .execute(r#"{"message": "World"}"#, Default::default()) + .await; assert!(result.is_ok(), "Should execute successfully"); let output = result.unwrap(); assert!( @@ -97,7 +109,9 @@ async fn test_multiple_arguments() { assert_eq!(meta.name, "add"); - let result = tool.execute(r#"{"a": 10, "b": 20}"#).await; + let result = tool + .execute(r#"{"a": 10, "b": 20}"#, Default::default()) + .await; assert!(result.is_ok()); let output = result.unwrap(); assert!( @@ -118,7 +132,7 @@ async fn test_no_arguments() { assert_eq!(meta.name, "get_prefix"); // Call with empty JSON object - let result = tool.execute(r#"{}"#).await; + let result = tool.execute(r#"{}"#, Default::default()).await; assert!(result.is_ok()); let output = result.unwrap(); assert!( @@ -137,7 +151,9 @@ async fn test_invalid_arguments() { let (_, tool) = ctx.greet_definition()(); // Invalid JSON - let result = tool.execute(r#"{"wrong_field": "value"}"#).await; + let result = tool + .execute(r#"{"wrong_field": "value"}"#, Default::default()) + .await; assert!(result.is_err(), "Should fail with invalid arguments"); } @@ -175,7 +191,7 @@ async fn test_result_return_type_success() { let ctx = FallibleContext; let (_, tool) = ctx.validate_definition()(); - let result = tool.execute(r#"{"value": 42}"#).await; + let result = tool.execute(r#"{"value": 42}"#, Default::default()).await; assert!(result.is_ok(), "Should succeed for positive value"); let output = result.unwrap(); assert!( @@ -190,7 +206,7 @@ async fn test_result_return_type_error() { let ctx = FallibleContext; let (_, tool) = ctx.validate_definition()(); - let result = tool.execute(r#"{"value": -1}"#).await; + let result = tool.execute(r#"{"value": -1}"#, Default::default()).await; assert!(result.is_err(), "Should fail for negative value"); let err = result.unwrap_err(); @@ -228,9 +244,9 @@ async fn test_sync_method() { let (_, tool) = ctx.increment_definition()(); // Execute 3 times - let result1 = tool.execute(r#"{}"#).await; - let result2 = tool.execute(r#"{}"#).await; - let result3 = tool.execute(r#"{}"#).await; + let result1 = tool.execute(r#"{}"#, Default::default()).await; + let result2 = tool.execute(r#"{}"#, Default::default()).await; + let result3 = tool.execute(r#"{}"#, Default::default()).await; assert!(result1.is_ok()); assert!(result2.is_ok()); @@ -240,6 +256,24 @@ async fn test_sync_method() { assert_eq!(ctx.counter.load(Ordering::SeqCst), 3); } +#[tokio::test] +async fn test_tool_macro_passes_execution_context() { + let ctx = SimpleContext { + prefix: "Test".to_string(), + }; + let (_, tool) = ctx.context_echo_definition()(); + + let output = tool + .execute( + r#"{"message":"hello"}"#, + ToolExecutionContext::new("call-ctx", "batch-ctx", 7), + ) + .await + .unwrap(); + + assert_eq!(output.summary, "\"batch-ctx:7:call-ctx:hello\""); +} + // ============================================================================= // Test: ToolMeta Immutability // ============================================================================= diff --git a/crates/llm-worker/tests/worker_fixtures.rs b/crates/llm-worker/tests/worker_fixtures.rs index cef140c8..8023e20b 100644 --- a/crates/llm-worker/tests/worker_fixtures.rs +++ b/crates/llm-worker/tests/worker_fixtures.rs @@ -58,7 +58,11 @@ impl MockWeatherTool { #[async_trait] impl Tool for MockWeatherTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { self.call_count.fetch_add(1, Ordering::SeqCst); // Parse input diff --git a/crates/llm-worker/tests/worker_state_test.rs b/crates/llm-worker/tests/worker_state_test.rs index c8a513ad..58c7cf33 100644 --- a/crates/llm-worker/tests/worker_state_test.rs +++ b/crates/llm-worker/tests/worker_state_test.rs @@ -136,7 +136,11 @@ impl CountingTool { #[async_trait] impl Tool for CountingTool { - async fn execute(&self, _input_json: &str) -> Result { + async fn execute( + &self, + _input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { self.calls.fetch_add(1, Ordering::SeqCst); Ok(format!("{}-ok", self.name).into()) } diff --git a/crates/memory/src/extract/tool.rs b/crates/memory/src/extract/tool.rs index 087b2a20..ef0f8c1e 100644 --- a/crates/memory/src/extract/tool.rs +++ b/crates/memory/src/extract/tool.rs @@ -54,7 +54,11 @@ struct WriteExtractedTool { #[async_trait] impl Tool for WriteExtractedTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let payload: ExtractedPayload = serde_json::from_str(input_json).map_err(|e| { ToolError::InvalidArgument(format!("invalid write_extracted input: {e}")) })?; @@ -122,7 +126,7 @@ mod tests { "requests": [] }) .to_string(); - let out = tool.execute(&input).await.unwrap(); + let out = tool.execute(&input, Default::default()).await.unwrap(); assert!(out.summary.contains("decisions=1")); let payload = ctx.take_payload().unwrap(); assert_eq!(payload.decisions.len(), 1); @@ -137,7 +141,7 @@ mod tests { let first = serde_json::json!({"decisions": [], "discussions": [], "attempts": [], "requests": []}) .to_string(); - tool.execute(&first).await.unwrap(); + tool.execute(&first, Default::default()).await.unwrap(); let second = serde_json::json!({ "decisions": [], @@ -146,7 +150,7 @@ mod tests { "requests": [] }) .to_string(); - tool.execute(&second).await.unwrap(); + tool.execute(&second, Default::default()).await.unwrap(); let payload = ctx.take_payload().unwrap(); assert_eq!(payload.attempts.len(), 1); @@ -157,7 +161,7 @@ mod tests { async fn invalid_json_returns_invalid_argument() { let ctx = Arc::new(ExtractWorkerContext::new()); let tool: Arc = Arc::new(WriteExtractedTool { ctx: ctx.clone() }); - let res = tool.execute("not json").await; + let res = tool.execute("not json", Default::default()).await; assert!(matches!(res, Err(ToolError::InvalidArgument(_)))); assert!(ctx.take_payload().is_none()); } diff --git a/crates/memory/src/tool/delete.rs b/crates/memory/src/tool/delete.rs index 16ffc939..1e36aa85 100644 --- a/crates/memory/src/tool/delete.rs +++ b/crates/memory/src/tool/delete.rs @@ -29,7 +29,11 @@ struct MemoryDeleteTool { #[async_trait] impl Tool for MemoryDeleteTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: DeleteParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryDelete input: {e}")))?; let path = params @@ -139,7 +143,10 @@ mod tests { let (_, tool) = delete_tool(layout.clone())(); let out = tool - .execute(r#"{"kind":"decision","slug":"obsolete"}"#) + .execute( + r#"{"kind":"decision","slug":"obsolete"}"#, + Default::default(), + ) .await .unwrap(); assert!(out.summary.contains("Deleted")); diff --git a/crates/memory/src/tool/edit.rs b/crates/memory/src/tool/edit.rs index 9fa896d2..dc495e3c 100644 --- a/crates/memory/src/tool/edit.rs +++ b/crates/memory/src/tool/edit.rs @@ -47,7 +47,11 @@ struct EditTool { #[async_trait] impl Tool for EditTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: EditParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryEdit input: {e}")))?; @@ -316,7 +320,10 @@ mod tests { "old_string": "body body", "new_string": "edited", }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); assert!(out.summary.contains("1 replacement")); let after = std::fs::read_to_string(&path).unwrap(); assert!(after.contains("edited")); @@ -335,7 +342,10 @@ mod tests { "old_string": "status: open\n", "new_string": "", }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); let msg = format!("{err}"); assert!(msg.contains("status") || msg.contains("missing")); @@ -354,7 +364,10 @@ mod tests { "old_string": "x", "new_string": "y", }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::ExecutionFailed(_))); } @@ -369,7 +382,10 @@ mod tests { "old_string": "x", "new_string": "y", }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } } diff --git a/crates/memory/src/tool/query.rs b/crates/memory/src/tool/query.rs index 2c51e5f5..041882c8 100644 --- a/crates/memory/src/tool/query.rs +++ b/crates/memory/src/tool/query.rs @@ -126,7 +126,11 @@ struct KnowledgeQueryTool { #[async_trait] impl Tool for MemoryQueryTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: MemoryQueryParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryQuery input: {e}")))?; let needle = match params.query.as_deref() { @@ -240,7 +244,11 @@ impl Tool for MemoryQueryTool { #[async_trait] impl Tool for KnowledgeQueryTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: KnowledgeQueryParams = serde_json::from_str(input_json).map_err(|e| { ToolError::InvalidArgument(format!("invalid KnowledgeQuery input: {e}")) })?; @@ -568,7 +576,10 @@ mod tests { write_decision(dir.path(), "beta", "no match here\n"); let (_, tool) = memory_query_tool(layout, QueryConfig::default())(); let inp = serde_json::json!({ "query": "ollama" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let records: Vec = parse_records(&out); assert_eq!(records.len(), 1); assert_eq!(records[0].slug, "alpha"); @@ -596,7 +607,7 @@ mod tests { .unwrap(); let (_, tool) = memory_query_tool(layout, QueryConfig::default())(); - let out = tool.execute("{}").await.unwrap(); + let out = tool.execute("{}", Default::default()).await.unwrap(); let records: Vec = parse_records(&out); let mut slugs: Vec<&str> = records.iter().map(|r| r.slug.as_str()).collect(); slugs.sort(); @@ -616,7 +627,10 @@ mod tests { .unwrap(); let (_, tool) = memory_query_tool(layout, QueryConfig::default())(); let inp = serde_json::json!({ "query": "needle" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let records: Vec = parse_records(&out); assert_eq!(records.len(), 1); assert_eq!(records[0].slug, "summary"); @@ -633,7 +647,10 @@ mod tests { let (_, tool) = memory_query_tool(layout, QueryConfig::default())(); let inp = serde_json::json!({ "query": "needle" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let records: Vec = parse_records(&out); assert!(records.is_empty(), "got records: {:?}", out.content); } @@ -653,8 +670,14 @@ mod tests { let (_, memory_tool) = memory_query_tool(layout.clone(), QueryConfig::default())(); let (_, knowledge_tool) = knowledge_query_tool(layout.clone(), QueryConfig::default())(); let inp = serde_json::json!({ "query": "needle" }); - memory_tool.execute(&inp.to_string()).await.unwrap(); - knowledge_tool.execute(&inp.to_string()).await.unwrap(); + memory_tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); + knowledge_tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let report = crate::usage::build_usage_report(&layout).unwrap(); assert!(report.records.is_empty()); @@ -673,7 +696,10 @@ mod tests { }; let (_, tool) = memory_query_tool(layout, cfg)(); let inp = serde_json::json!({ "query": "needle" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let records: Vec = parse_records(&out); assert_eq!(records.len(), 3); } @@ -692,7 +718,10 @@ mod tests { }; let (_, tool) = memory_query_tool(layout, cfg)(); let inp = serde_json::json!({ "query": "needle" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let records: Vec = parse_records(&out); assert_eq!(records.len(), 1); let e = records[0].excerpt.as_deref().unwrap(); @@ -708,7 +737,10 @@ mod tests { let (_dir, layout) = setup(); let (_, tool) = memory_query_tool(layout, QueryConfig::default())(); let inp = serde_json::json!({ "query": " " }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } @@ -724,7 +756,10 @@ mod tests { ); let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())(); let inp = serde_json::json!({ "query": "ollama" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let records: Vec = parse_records(&out); assert_eq!(records.len(), 1); assert_eq!(records[0].slug, "policy"); @@ -748,7 +783,7 @@ mod tests { write_knowledge(dir.path(), "h1", "howto", "d2", "body\n"); let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())(); - let out = tool.execute("{}").await.unwrap(); + let out = tool.execute("{}", Default::default()).await.unwrap(); let records: Vec = parse_records(&out); let mut slugs: Vec<&str> = records.iter().map(|r| r.slug.as_str()).collect(); slugs.sort(); @@ -764,7 +799,10 @@ mod tests { let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())(); let inp = serde_json::json!({ "query": "needle", "kind": "howto" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let records: Vec = parse_records(&out); assert_eq!(records.len(), 1); assert_eq!(records[0].slug, "h1"); @@ -778,7 +816,10 @@ mod tests { let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())(); let inp = serde_json::json!({ "kind": "howto" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let records: Vec = parse_records(&out); assert_eq!(records.len(), 1); assert_eq!(records[0].slug, "h1"); @@ -792,7 +833,10 @@ mod tests { let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())(); let inp = serde_json::json!({ "query": "xyzzy" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let records: Vec = parse_records(&out); assert_eq!(records.len(), 1); assert_eq!(records[0].slug, "p"); @@ -804,7 +848,10 @@ mod tests { write_knowledge(dir.path(), "p", "policy", "d", "no match\n"); let (_, tool) = knowledge_query_tool(layout, QueryConfig::default())(); let inp = serde_json::json!({ "query": "absent" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let records: Vec = parse_records(&out); assert!(records.is_empty()); } diff --git a/crates/memory/src/tool/read.rs b/crates/memory/src/tool/read.rs index 78033323..9a65f5fe 100644 --- a/crates/memory/src/tool/read.rs +++ b/crates/memory/src/tool/read.rs @@ -45,7 +45,11 @@ struct ReadTool { #[async_trait] impl Tool for ReadTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: ReadParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryRead input: {e}")))?; @@ -225,7 +229,10 @@ mod tests { let (_meta, tool) = read_tool(layout)(); let inp = serde_json::json!({ "kind": "decision", "slug": "foo" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.unwrap(); assert!(body.contains(" 1\talpha")); assert!(body.contains(" 2\tbeta")); @@ -240,7 +247,10 @@ mod tests { let (_, tool) = read_tool(layout)(); let inp = serde_json::json!({ "kind": "summary" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); assert!(out.content.unwrap().contains("summary body")); } @@ -249,7 +259,10 @@ mod tests { let (_dir, layout) = setup(); let (_, tool) = read_tool(layout)(); let inp = serde_json::json!({ "kind": "summary", "slug": "x" }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } @@ -258,7 +271,10 @@ mod tests { let (_dir, layout) = setup(); let (_, tool) = read_tool(layout)(); let inp = serde_json::json!({ "kind": "decision" }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } @@ -267,7 +283,10 @@ mod tests { let (_dir, layout) = setup(); let (_, tool) = read_tool(layout)(); let inp = serde_json::json!({ "kind": "decision", "slug": "Bad-Slug" }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } @@ -280,7 +299,10 @@ mod tests { let (_, tool) = read_tool(layout)(); let inp = serde_json::json!({ "kind": "knowledge", "slug": "policy" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); assert!(out.content.unwrap().contains("k")); } @@ -293,7 +315,9 @@ mod tests { let (_, tool) = read_tool_with_usage(layout.clone(), "session-1")(); let inp = serde_json::json!({ "kind": "decision", "slug": "foo" }); - tool.execute(&inp.to_string()).await.unwrap(); + tool.execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let report = usage::build_usage_report(&layout).unwrap(); assert_eq!(report.records.len(), 1); @@ -310,7 +334,10 @@ mod tests { let (_dir, layout) = setup(); let (_, tool) = read_tool(layout)(); let inp = serde_json::json!({ "kind": "decision", "slug": "missing" }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::ExecutionFailed(_))); } } diff --git a/crates/memory/src/tool/write.rs b/crates/memory/src/tool/write.rs index 7307c1f1..8635f4ad 100644 --- a/crates/memory/src/tool/write.rs +++ b/crates/memory/src/tool/write.rs @@ -42,7 +42,11 @@ struct WriteTool { #[async_trait] impl Tool for WriteTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: WriteParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid MemoryWrite input: {e}")))?; @@ -229,7 +233,10 @@ mod tests { "kind": "summary", "content": content, }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); assert!(out.summary.contains("Created")); assert!(path.exists()); } @@ -249,7 +256,10 @@ mod tests { "slug": "foo", "content": content, }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); let msg = format!("{err}"); assert!(msg.contains("status") || msg.contains("missing"), "{msg}"); } @@ -271,7 +281,10 @@ mod tests { "slug": "foo", "content": initial, }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); assert!(out.summary.contains("Overwrote")); } @@ -283,7 +296,10 @@ mod tests { "kind": "decision", "content": "ignored", }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } @@ -298,7 +314,11 @@ mod tests { "slug": "foo", "content": bad, }); - assert!(tool.execute(&inp.to_string()).await.is_err()); + assert!( + tool.execute(&inp.to_string(), Default::default()) + .await + .is_err() + ); assert!(!path.exists()); } @@ -312,7 +332,10 @@ mod tests { "slug": "wf", "content": "---\n---\n", }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } } diff --git a/crates/pod/src/compact/worker.rs b/crates/pod/src/compact/worker.rs index 118efabe..b74d6f59 100644 --- a/crates/pod/src/compact/worker.rs +++ b/crates/pod/src/compact/worker.rs @@ -151,7 +151,11 @@ struct SearchSessionLogTool { #[async_trait] impl Tool for SearchSessionLogTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: SearchSessionParams = serde_json::from_str(input_json).map_err(|e| { ToolError::InvalidArgument(format!("invalid search_session_log input: {e}")) })?; @@ -206,7 +210,11 @@ struct ReadSessionItemsTool { #[async_trait] impl Tool for ReadSessionItemsTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: ReadSessionParams = serde_json::from_str(input_json).map_err(|e| { ToolError::InvalidArgument(format!("invalid read_session_items input: {e}")) })?; @@ -368,7 +376,11 @@ struct MarkReadRequiredTool { #[async_trait] impl Tool for MarkReadRequiredTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: MarkParams = serde_json::from_str(input_json).map_err(|e| { ToolError::InvalidArgument(format!("invalid mark_read_required input: {e}")) })?; @@ -425,7 +437,11 @@ struct AddReferenceTool { #[async_trait] impl Tool for AddReferenceTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: ReferenceParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid add_reference input: {e}")))?; let mut guard = self.ctx.lock().expect("compact worker context poisoned"); @@ -449,7 +465,11 @@ struct WriteSummaryTool { #[async_trait] impl Tool for WriteSummaryTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: SummaryParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid write_summary input: {e}")))?; let mut guard = self.ctx.lock().expect("compact worker context poisoned"); @@ -749,7 +769,7 @@ mod tests { ctx: ctx.clone(), }); let input = serde_json::json!({ "file_path": path.to_str().unwrap() }).to_string(); - let out = tool.execute(&input).await.unwrap(); + let out = tool.execute(&input, Default::default()).await.unwrap(); assert!(out.summary.starts_with("Marked")); let guard = ctx.lock().unwrap(); @@ -770,7 +790,7 @@ mod tests { ctx: ctx.clone(), }); let input = serde_json::json!({ "file_path": path.to_str().unwrap() }).to_string(); - let res = tool.execute(&input).await; + let res = tool.execute(&input, Default::default()).await; assert!(matches!(res, Err(ToolError::ExecutionFailed(_)))); let guard = ctx.lock().unwrap(); @@ -784,11 +804,11 @@ mod tests { let tool: Arc = Arc::new(WriteSummaryTool { ctx: ctx.clone() }); let first = serde_json::json!({ "text": "first" }).to_string(); - let out1 = tool.execute(&first).await.unwrap(); + let out1 = tool.execute(&first, Default::default()).await.unwrap(); assert!(out1.summary.contains("recorded")); let second = serde_json::json!({ "text": "second" }).to_string(); - let out2 = tool.execute(&second).await.unwrap(); + let out2 = tool.execute(&second, Default::default()).await.unwrap(); assert!(out2.summary.contains("replaced")); assert_eq!(ctx.lock().unwrap().summary.as_deref(), Some("second")); @@ -801,8 +821,8 @@ mod tests { let p = "/abs/path.rs"; let input = serde_json::json!({ "file_path": p }).to_string(); - tool.execute(&input).await.unwrap(); - tool.execute(&input).await.unwrap(); + tool.execute(&input, Default::default()).await.unwrap(); + tool.execute(&input, Default::default()).await.unwrap(); let guard = ctx.lock().unwrap(); assert_eq!(guard.references.len(), 1); @@ -823,7 +843,7 @@ mod tests { state: Arc::new(SessionLogToolState { items }), }); let input = serde_json::json!({ "query": "compact", "limit": 10 }).to_string(); - let out = tool.execute(&input).await.unwrap(); + let out = tool.execute(&input, Default::default()).await.unwrap(); let content = out.content.unwrap(); assert!(content.contains("investigate compact failure")); @@ -842,7 +862,7 @@ mod tests { state: Arc::new(SessionLogToolState { items }), }); let input = serde_json::json!({ "offset": 0, "limit": 1, "mode": "full" }).to_string(); - let out = tool.execute(&input).await.unwrap(); + let out = tool.execute(&input, Default::default()).await.unwrap(); let content = out.content.unwrap(); assert!(content.contains("raw trace detail")); diff --git a/crates/pod/src/discovery.rs b/crates/pod/src/discovery.rs index dfdefb1e..8c87a3cb 100644 --- a/crates/pod/src/discovery.rs +++ b/crates/pod/src/discovery.rs @@ -752,7 +752,11 @@ impl Tool for ListPodsTool where St: PodMetadataStore + Clone + Send + Sync + 'static, { - async fn execute(&self, _input_json: &str) -> Result { + async fn execute( + &self, + _input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let items = self .discovery .list_visible() @@ -775,7 +779,11 @@ impl Tool for RestorePodTool where St: PodMetadataStore + Clone + Send + Sync + 'static, { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let input: PodNameInput = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid RestorePod input: {e}")))?; let result = self @@ -847,7 +855,11 @@ impl Tool for SendToPeerPodTool where St: PodMetadataStore + Clone + Send + Sync + 'static, { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let input: SendToPeerPodInput = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid SendToPeerPod input: {e}")))?; let detail = self @@ -1392,7 +1404,7 @@ mod tests { let (_, tool) = send_to_peer_pod_tool(discovery)(); let output = tool - .execute(r#"{"name":"target","message":"hello"}"#) + .execute(r#"{"name":"target","message":"hello"}"#, Default::default()) .await .unwrap(); assert_eq!(output.summary, "sent peer message to `target`"); diff --git a/crates/pod/src/feature.rs b/crates/pod/src/feature.rs index 6b9d05e1..5c903eb7 100644 --- a/crates/pod/src/feature.rs +++ b/crates/pod/src/feature.rs @@ -1292,7 +1292,11 @@ mod tests { #[async_trait] impl Tool for DummyTool { - async fn execute(&self, _input_json: &str) -> Result { + async fn execute( + &self, + _input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { Ok(ToolOutput::from("ok".to_string())) } } diff --git a/crates/pod/src/feature/builtin/task/tool_impl.rs b/crates/pod/src/feature/builtin/task/tool_impl.rs index 9ef9e221..40d1adf4 100644 --- a/crates/pod/src/feature/builtin/task/tool_impl.rs +++ b/crates/pod/src/feature/builtin/task/tool_impl.rs @@ -73,7 +73,11 @@ step: leave the task as-is, summarize the problem to the user, and end the turn. #[async_trait] impl Tool for TaskCreateTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: TaskCreateParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid TaskCreate input: {e}")))?; let created = self.store.create(params.subject, params.description); @@ -93,7 +97,11 @@ impl Tool for TaskCreateTool { #[async_trait] impl Tool for TaskListTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let _: TaskListParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid TaskList input: {e}")))?; let tasks = self.store.list(); @@ -106,7 +114,11 @@ impl Tool for TaskListTool { #[async_trait] impl Tool for TaskGetTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: TaskGetParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid TaskGet input: {e}")))?; let task = self.store.get(params.taskid).ok_or_else(|| { @@ -122,7 +134,11 @@ impl Tool for TaskGetTool { #[async_trait] impl Tool for TaskUpdateTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: TaskUpdateParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid TaskUpdate input: {e}")))?; let updated = self @@ -241,14 +257,20 @@ mod tests { let update = tool(task_update_tool(store.clone())); let out = create - .execute(r#"{"subject":"implement","description":"write code"}"#) + .execute( + r#"{"subject":"implement","description":"write code"}"#, + Default::default(), + ) .await .unwrap(); assert!(out.summary.contains("Created task 1")); assert_eq!(store.get(1).unwrap().status, TaskStatus::Pending); let out = update - .execute(r#"{"taskid":1,"status":"inprogress","subject":"implement tasks"}"#) + .execute( + r#"{"taskid":1,"status":"inprogress","subject":"implement tasks"}"#, + Default::default(), + ) .await .unwrap(); assert!(out.summary.contains("Updated task 1")); @@ -256,11 +278,14 @@ mod tests { assert_eq!(task.status, TaskStatus::Inprogress); assert_eq!(task.subject, "implement tasks"); - let out = get.execute(r#"{"taskid":1}"#).await.unwrap(); + let out = get + .execute(r#"{"taskid":1}"#, Default::default()) + .await + .unwrap(); assert!(out.summary.contains("Task 1 (inprogress)")); assert!(out.content.unwrap().contains("implement tasks")); - let out = list.execute("{}").await.unwrap(); + let out = list.execute("{}", Default::default()).await.unwrap(); assert!(out.summary.contains("1 task(s)")); let content = out.content.unwrap(); assert!(content.contains("\"taskid\": 1")); @@ -273,11 +298,14 @@ mod tests { store.create("s".into(), "d".into()); let update = tool(task_update_tool(store)); - let err = update.execute(r#"{"taskid":1}"#).await.unwrap_err(); + let err = update + .execute(r#"{"taskid":1}"#, Default::default()) + .await + .unwrap_err(); assert!(err.to_string().contains("at least one")); let err = update - .execute(r#"{"taskid":99,"status":"deleted"}"#) + .execute(r#"{"taskid":99,"status":"deleted"}"#, Default::default()) .await .unwrap_err(); assert!(err.to_string().contains("taskid 99 not found")); diff --git a/crates/pod/src/ipc/interceptor.rs b/crates/pod/src/ipc/interceptor.rs index a73aeafb..41f36e61 100644 --- a/crates/pod/src/ipc/interceptor.rs +++ b/crates/pod/src/ipc/interceptor.rs @@ -491,6 +491,7 @@ mod tests { }, meta, tool, + context: llm_worker::tool::ToolExecutionContext::new("call-id", "test-batch", 0), } } @@ -898,6 +899,7 @@ mod tests { ), meta: info.meta, tool: info.tool, + context: info.context, }; let action = interceptor.post_tool_call(&mut result_info).await; diff --git a/crates/pod/src/spawn/comm_tools.rs b/crates/pod/src/spawn/comm_tools.rs index 44fadb5b..27eed70d 100644 --- a/crates/pod/src/spawn/comm_tools.rs +++ b/crates/pod/src/spawn/comm_tools.rs @@ -62,7 +62,11 @@ struct SendToPodTool { #[async_trait] impl Tool for SendToPodTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let input: SendToPodInput = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid SendToPod input: {e}")))?; let record = self @@ -123,7 +127,11 @@ struct ReadPodOutputTool { #[async_trait] impl Tool for ReadPodOutputTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let input: NameInput = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid ReadPodOutput input: {e}")))?; let record = self @@ -197,7 +205,11 @@ struct StopPodTool { #[async_trait] impl Tool for StopPodTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let input: NameInput = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid StopPod input: {e}")))?; let record = self diff --git a/crates/pod/src/spawn/tool.rs b/crates/pod/src/spawn/tool.rs index 642c78ba..5fc4d13e 100644 --- a/crates/pod/src/spawn/tool.rs +++ b/crates/pod/src/spawn/tool.rs @@ -298,7 +298,11 @@ impl SpawnPodTool { #[async_trait] impl Tool for SpawnPodTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let input: SpawnPodInput = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid SpawnPod input: {e}")))?; diff --git a/crates/pod/tests/controller_test.rs b/crates/pod/tests/controller_test.rs index 54b91334..5d9c88a5 100644 --- a/crates/pod/tests/controller_test.rs +++ b/crates/pod/tests/controller_test.rs @@ -1351,7 +1351,11 @@ struct HangingTool; #[async_trait] impl Tool for HangingTool { - async fn execute(&self, _input: &str) -> Result { + async fn execute( + &self, + _input: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { std::future::pending::<()>().await; unreachable!() } diff --git a/crates/pod/tests/pod_comm_tools_test.rs b/crates/pod/tests/pod_comm_tools_test.rs index 9786017f..fa3206ea 100644 --- a/crates/pod/tests/pod_comm_tools_test.rs +++ b/crates/pod/tests/pod_comm_tools_test.rs @@ -262,7 +262,7 @@ async fn send_to_pod_delivers_run_method() { let def = send_to_pod_tool(registry); let (_meta, tool) = def(); let input = json!({ "name": "child", "message": "hello there" }).to_string(); - let output: ToolOutput = tool.execute(&input).await.unwrap(); + let output: ToolOutput = tool.execute(&input, Default::default()).await.unwrap(); assert!( output.summary.contains("child"), "summary: {}", @@ -285,7 +285,7 @@ async fn send_to_pod_errors_on_unknown_pod() { let def = send_to_pod_tool(registry); let (_meta, tool) = def(); let input = json!({ "name": "nope", "message": "hi" }).to_string(); - let err = tool.execute(&input).await.unwrap_err(); + let err = tool.execute(&input, Default::default()).await.unwrap_err(); assert!(err.to_string().contains("no spawned pod"), "{err}"); } @@ -307,7 +307,7 @@ async fn send_to_pod_errors_when_pod_already_running() { let def = send_to_pod_tool(registry); let (_meta, tool) = def(); let input = json!({ "name": "child", "message": "hi" }).to_string(); - let err = tool.execute(&input).await.unwrap_err(); + let err = tool.execute(&input, Default::default()).await.unwrap_err(); assert!( err.to_string().contains("already running"), "expected AlreadyRunning wording: {err}" @@ -341,13 +341,13 @@ async fn read_pod_output_returns_new_assistant_text_then_empty_on_second_call() let (_meta, tool) = def(); let input = json!({ "name": "child" }).to_string(); - let first: ToolOutput = tool.execute(&input).await.unwrap(); + let first: ToolOutput = tool.execute(&input, Default::default()).await.unwrap(); let body = first.content.expect("first read should have content"); assert!(body.contains("hi back"), "body: {body}"); assert!(body.contains("still working"), "body: {body}"); // Cursor now points past all items โ€” second call returns no new text. - let second: ToolOutput = tool.execute(&input).await.unwrap(); + let second: ToolOutput = tool.execute(&input, Default::default()).await.unwrap(); assert!( second.content.is_none(), "unexpected content: {:?}", @@ -371,7 +371,7 @@ async fn read_pod_output_reports_stopped_on_dead_socket() { let def = read_pod_output_tool(registry); let (_meta, tool) = def(); let input = json!({ "name": "child" }).to_string(); - let output: ToolOutput = tool.execute(&input).await.unwrap(); + let output: ToolOutput = tool.execute(&input, Default::default()).await.unwrap(); assert!(output.summary.contains("stopped"), "{}", output.summary); } @@ -452,7 +452,7 @@ async fn stop_pod_sends_shutdown_and_releases_scope() { let def = stop_pod_tool(registry.clone()); let (_meta, tool) = def(); let input = json!({ "name": "child" }).to_string(); - let output: ToolOutput = tool.execute(&input).await.unwrap(); + let output: ToolOutput = tool.execute(&input, Default::default()).await.unwrap(); assert!(output.summary.contains("stopped"), "{}", output.summary); // The child got a Shutdown. @@ -497,7 +497,7 @@ async fn stop_pod_succeeds_even_when_child_unreachable() { let def = stop_pod_tool(registry.clone()); let (_meta, tool) = def(); let input = json!({ "name": "child" }).to_string(); - let output: ToolOutput = tool.execute(&input).await.unwrap(); + let output: ToolOutput = tool.execute(&input, Default::default()).await.unwrap(); assert!(output.summary.contains("stopped"), "{}", output.summary); // Registry no longer knows about the child. @@ -545,7 +545,7 @@ async fn restored_registry_uses_pod_state_without_runtime_file() { let def = send_to_pod_tool(restored.clone()); let (_meta, tool) = def(); let input = json!({ "name": "child", "message": "after restart" }).to_string(); - tool.execute(&input).await.unwrap(); + tool.execute(&input, Default::default()).await.unwrap(); match received.recv().await.expect("expected Run") { Method::Run { input } => match input.as_slice() { [protocol::Segment::Text { content }] => assert_eq!(content, "after restart"), @@ -556,7 +556,7 @@ async fn restored_registry_uses_pod_state_without_runtime_file() { let def = stop_pod_tool(restored.clone()); let (_meta, tool) = def(); - tool.execute(&json!({ "name": "child" }).to_string()) + tool.execute(&json!({ "name": "child" }).to_string(), Default::default()) .await .unwrap(); assert!(matches!( diff --git a/crates/pod/tests/session_metrics_test.rs b/crates/pod/tests/session_metrics_test.rs index 08bb2eb6..36312215 100644 --- a/crates/pod/tests/session_metrics_test.rs +++ b/crates/pod/tests/session_metrics_test.rs @@ -79,7 +79,11 @@ struct BigContentTool { #[async_trait] impl Tool for BigContentTool { - async fn execute(&self, _input: &str) -> Result { + async fn execute( + &self, + _input: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { Ok(ToolOutput { summary: self.summary.into(), content: Some(self.content.clone()), diff --git a/crates/pod/tests/spawn_pod_test.rs b/crates/pod/tests/spawn_pod_test.rs index 82f7c10a..56a260cb 100644 --- a/crates/pod/tests/spawn_pod_test.rs +++ b/crates/pod/tests/spawn_pod_test.rs @@ -312,7 +312,7 @@ async fn spawn_pod_launches_runtime_in_workspace_and_passes_tool_cwd() { }) .to_string(); - tool.execute(&input).await.unwrap(); + tool.execute(&input, Default::default()).await.unwrap(); assert!(matches!(received.await.unwrap(), Some(Method::Run { .. }))); let invocation = read_recorded_runtime_invocation(&output_path).await; assert_eq!(invocation[0], allow_root.path().to_str().unwrap()); @@ -373,7 +373,7 @@ async fn spawn_pod_omitted_cwd_preserves_spawner_pwd() { }) .to_string(); - tool.execute(&input).await.unwrap(); + tool.execute(&input, Default::default()).await.unwrap(); assert!(matches!(received.await.unwrap(), Some(Method::Run { .. }))); let invocation = read_recorded_runtime_invocation(&output_path).await; assert_eq!(invocation[0], allow_root.path().to_str().unwrap()); @@ -433,7 +433,7 @@ async fn spawn_pod_delegates_scope_and_sends_run() { .is_writable(&allow_root.path().join("a.txt")) ); - let output: ToolOutput = tool.execute(&input).await.unwrap(); + let output: ToolOutput = tool.execute(&input, Default::default()).await.unwrap(); assert!( output.summary.contains("child"), "summary: {}", @@ -519,7 +519,7 @@ async fn spawn_pod_requires_explicit_delegation_even_with_direct_scope() { }) .to_string(); - let err = tool.execute(&input).await.unwrap_err(); + let err = tool.execute(&input, Default::default()).await.unwrap_err(); match err { ToolError::InvalidArgument(message) => { assert!(message.contains("no delegation scope grant"), "{message}"); @@ -587,7 +587,7 @@ async fn spawn_pod_rejects_child_non_recursive_scope_under_parent_non_recursive_ }) .to_string(); - let err = tool.execute(&input).await.unwrap_err(); + let err = tool.execute(&input, Default::default()).await.unwrap_err(); match err { ToolError::InvalidArgument(message) => { assert!( @@ -639,7 +639,7 @@ async fn spawn_pod_rejects_scope_outside_spawner() { }) .to_string(); - let err = tool.execute(&input).await.unwrap_err(); + let err = tool.execute(&input, Default::default()).await.unwrap_err(); match err { ToolError::InvalidArgument(msg) => { assert!( @@ -712,7 +712,7 @@ async fn spawn_pod_rolls_back_reservation_when_socket_never_appears() { }) .to_string(); - let err = tool.execute(&input).await.unwrap_err(); + let err = tool.execute(&input, Default::default()).await.unwrap_err(); match err { ToolError::ExecutionFailed(msg) => { assert!( diff --git a/crates/session-store/tests/session_test.rs b/crates/session-store/tests/session_test.rs index 0c8e50b4..dc8632c9 100644 --- a/crates/session-store/tests/session_test.rs +++ b/crates/session-store/tests/session_test.rs @@ -54,7 +54,11 @@ struct MockWeatherTool; #[async_trait] impl Tool for MockWeatherTool { - async fn execute(&self, _input_json: &str) -> Result { + async fn execute( + &self, + _input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { Ok("Sunny, 25C".to_string().into()) } } diff --git a/crates/ticket/src/tool.rs b/crates/ticket/src/tool.rs index 363cfa31..94825f8f 100644 --- a/crates/ticket/src/tool.rs +++ b/crates/ticket/src/tool.rs @@ -562,7 +562,11 @@ struct TicketDoctorTool { #[async_trait] impl Tool for TicketCreateTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: TicketCreateParams = parse_input("TicketCreate", input_json)?; let mut input = NewTicket::new(params.title); if let Some(body) = params.body { @@ -594,7 +598,11 @@ impl Tool for TicketCreateTool { #[async_trait] impl Tool for TicketListTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: TicketListParams = parse_input("TicketList", input_json)?; let state = params.state.unwrap_or(TicketListStateParam::All); let (filter, state_filter) = state.as_filter(); @@ -629,7 +637,11 @@ impl Tool for TicketListTool { #[async_trait] impl Tool for TicketShowTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: TicketShowParams = parse_input("TicketShow", input_json)?; let query = id_or_query(params.id, params.query)?; let event_limit = bounded(params.event_limit, DEFAULT_EVENT_LIMIT, MAX_EVENT_LIMIT); @@ -661,7 +673,11 @@ impl Tool for TicketShowTool { #[async_trait] impl Tool for TicketCommentTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: TicketCommentParams = parse_input("TicketComment", input_json)?; let kind = match params.role { TicketCommentRoleParam::Comment => TicketEventKind::Comment, @@ -684,7 +700,11 @@ impl Tool for TicketCommentTool { #[async_trait] impl Tool for TicketReviewTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: TicketReviewParams = parse_input("TicketReview", input_json)?; let result = match params.result { TicketReviewResultParam::Approve => TicketReviewResult::Approve, @@ -708,7 +728,11 @@ impl Tool for TicketReviewTool { #[async_trait] impl Tool for TicketIntakeReadyTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: TicketIntakeReadyParams = parse_input("TicketIntakeReady", input_json)?; let from = TicketWorkflowState::Planning; let reason = params @@ -743,7 +767,11 @@ impl Tool for TicketIntakeReadyTool { #[async_trait] impl Tool for TicketWorkflowStateTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: TicketWorkflowStateParams = parse_input("TicketWorkflowState", input_json)?; let from = params.from.into_state(); let to = params.to.into_state(); @@ -778,7 +806,11 @@ impl Tool for TicketWorkflowStateTool { #[async_trait] impl Tool for TicketCloseTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: TicketCloseParams = parse_input("TicketClose", input_json)?; self.backend .close( @@ -795,7 +827,11 @@ impl Tool for TicketCloseTool { #[async_trait] impl Tool for TicketRelationRecordTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: TicketRelationRecordParams = parse_input("TicketRelationRecord", input_json)?; let relation = NewTicketRelation { kind: params.kind.into_kind(), @@ -819,7 +855,11 @@ impl Tool for TicketRelationRecordTool { #[async_trait] impl Tool for TicketRelationQueryTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: TicketRelationQueryParams = parse_input("TicketRelationQuery", input_json)?; let limit = bounded(params.limit, DEFAULT_LIST_LIMIT, MAX_LIST_LIMIT); let ticket = params.ticket.clone().map(TicketIdOrSlug::Id); @@ -853,7 +893,11 @@ impl Tool for TicketRelationQueryTool { #[async_trait] impl Tool for TicketOrchestrationPlanRecordTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: TicketOrchestrationPlanRecordParams = parse_input("TicketOrchestrationPlanRecord", input_json)?; let accepted_plan = params.accepted_plan.map(|plan| AcceptedOrchestrationPlan { @@ -885,7 +929,11 @@ impl Tool for TicketOrchestrationPlanRecordTool { #[async_trait] impl Tool for TicketOrchestrationPlanQueryTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: TicketOrchestrationPlanQueryParams = parse_input("TicketOrchestrationPlanQuery", input_json)?; let limit = bounded(params.limit, DEFAULT_LIST_LIMIT, MAX_LIST_LIMIT); @@ -922,7 +970,11 @@ impl Tool for TicketOrchestrationPlanQueryTool { #[async_trait] impl Tool for TicketDoctorTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: TicketDoctorParams = parse_input("TicketDoctor", input_json)?; let limit = bounded(params.limit, DEFAULT_DIAGNOSTIC_LIMIT, MAX_DIAGNOSTIC_LIMIT); let report = self @@ -1377,6 +1429,7 @@ mod tests { "body": "## Background\n\nCreated by tool.\n" }) .to_string(), + Default::default(), ) .await .unwrap(); @@ -1388,7 +1441,10 @@ mod tests { assert!(!created_text.contains("needs_preflight")); let listed = list - .execute(&json!({ "state": "planning" }).to_string()) + .execute( + &json!({ "state": "planning" }).to_string(), + Default::default(), + ) .await .unwrap(); assert!(listed.summary.contains("Listed 1 ticket")); @@ -1398,7 +1454,10 @@ mod tests { assert!(!listed_content.contains("needs_preflight")); let shown = show - .execute(&json!({ "id": id, "event_limit": 10 }).to_string()) + .execute( + &json!({ "id": id, "event_limit": 10 }).to_string(), + Default::default(), + ) .await .unwrap(); assert!(shown.summary.contains(&id)); @@ -1407,7 +1466,10 @@ mod tests { assert!(!shown_content.contains("legacy_ticket")); assert!(!shown_content.contains("needs_preflight")); - let report = doctor.execute(&json!({}).to_string()).await.unwrap(); + let report = doctor + .execute(&json!({}).to_string(), Default::default()) + .await + .unwrap(); assert!(report.summary.contains("0 error(s)")); } @@ -1431,6 +1493,7 @@ mod tests { "author": "test" }) .to_string(), + Default::default(), ) .await .unwrap(); @@ -1440,7 +1503,10 @@ mod tests { assert_eq!(recorded_json["target"], target.id); let queried = query - .execute(&json!({ "ticket": target.id.clone() }).to_string()) + .execute( + &json!({ "ticket": target.id.clone() }).to_string(), + Default::default(), + ) .await .unwrap(); let queried_json: Value = serde_json::from_str(&queried.content.unwrap()).unwrap(); @@ -1448,7 +1514,10 @@ mod tests { assert_eq!(queried_json["relations"][0]["ticket_id"], source.id); let shown = show - .execute(&json!({ "id": target.id.clone() }).to_string()) + .execute( + &json!({ "id": target.id.clone() }).to_string(), + Default::default(), + ) .await .unwrap(); let shown_json: Value = serde_json::from_str(&shown.content.unwrap()).unwrap(); @@ -1476,6 +1545,7 @@ mod tests { "body": "Implemented." }) .to_string(), + Default::default(), ) .await .unwrap(); @@ -1487,6 +1557,7 @@ mod tests { "body": "Looks good." }) .to_string(), + Default::default(), ) .await .unwrap(); @@ -1494,11 +1565,15 @@ mod tests { .execute( &json!({ "ticket": created.id, "resolution": "Done via TicketClose.\n" }) .to_string(), + Default::default(), ) .await .unwrap(); - let report = doctor.execute(&json!({}).to_string()).await.unwrap(); + let report = doctor + .execute(&json!({}).to_string(), Default::default()) + .await + .unwrap(); assert!(report.summary.contains("0 error(s)")); let closed = backend.show(TicketIdOrSlug::Id(created.id)).unwrap(); assert!(closed.resolution.is_some()); @@ -1538,6 +1613,7 @@ mod tests { "author": "intake-pod" }) .to_string(), + Default::default(), ) .await .unwrap(); @@ -1555,6 +1631,7 @@ mod tests { "author": "orchestrator" }) .to_string(), + Default::default(), ) .await .unwrap(); @@ -1569,6 +1646,7 @@ mod tests { "author": "orchestrator" }) .to_string(), + Default::default(), ) .await .unwrap(); @@ -1621,6 +1699,7 @@ mod tests { "author": "orchestrator" }) .to_string(), + Default::default(), ) .await .unwrap(); @@ -1650,6 +1729,7 @@ mod tests { "author": "orchestrator" }) .to_string(), + Default::default(), ) .await .unwrap(); @@ -1685,6 +1765,7 @@ mod tests { "body": "Should not apply.\n" }) .to_string(), + Default::default(), ) .await .unwrap_err(); @@ -1717,6 +1798,7 @@ mod tests { "body": "Should not bypass Queue.\n" }) .to_string(), + Default::default(), ) .await .unwrap_err(); @@ -1735,6 +1817,7 @@ mod tests { "body": "Should not move backwards.\n" }) .to_string(), + Default::default(), ) .await .unwrap_err(); @@ -1753,6 +1836,7 @@ mod tests { "body": "Should not skip inprogress.\n" }) .to_string(), + Default::default(), ) .await .unwrap_err(); @@ -1775,6 +1859,7 @@ mod tests { "intake_summary": "Should not rewrite ready ticket." }) .to_string(), + Default::default(), ) .await .unwrap_err(); @@ -1807,6 +1892,7 @@ mod tests { "author": "orchestrator" }) .to_string(), + Default::default(), ) .await .unwrap(); @@ -1823,6 +1909,7 @@ mod tests { "relation_kind": "blocked_by" }) .to_string(), + Default::default(), ) .await .unwrap(); @@ -1840,7 +1927,10 @@ mod tests { let temp = TempDir::new().unwrap(); let show = tool_by_name(backend(&temp), "TicketShow"); let error = show - .execute(&json!({ "id": "a", "query": "b" }).to_string()) + .execute( + &json!({ "id": "a", "query": "b" }).to_string(), + Default::default(), + ) .await .unwrap_err(); assert!(matches!(error, ToolError::InvalidArgument(_))); @@ -1852,7 +1942,10 @@ mod tests { let backend = backend(&temp); let create = tool_by_name(backend.clone(), "TicketCreate"); let output = create - .execute(&json!({ "title": "Escape" }).to_string()) + .execute( + &json!({ "title": "Escape" }).to_string(), + Default::default(), + ) .await .unwrap(); let value: Value = serde_json::from_str(&output.content.unwrap()).unwrap(); diff --git a/crates/tools/src/bash.rs b/crates/tools/src/bash.rs index 9a2976ff..352d5d7b 100644 --- a/crates/tools/src/bash.rs +++ b/crates/tools/src/bash.rs @@ -101,7 +101,11 @@ impl Drop for BashTool { #[async_trait] impl Tool for BashTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: BashParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid Bash input: {e}")))?; let timeout_secs = params @@ -394,7 +398,10 @@ mod tests { assert_eq!(meta.name, "Bash"); let inp = serde_json::json!({ "command": "echo hello" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); assert_eq!(out.summary, "$ echo hello"); assert_eq!(out.content.as_deref().map(str::trim), Some("hello")); } @@ -407,7 +414,10 @@ mod tests { let inp = serde_json::json!({ "command": "echo out; echo err 1>&2", }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.unwrap(); assert!(body.contains("out")); assert!(body.contains("err")); @@ -419,7 +429,10 @@ mod tests { let tool = make_tool(&h); let inp = serde_json::json!({ "command": "exit 7" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); assert!(out.summary.contains("exit 7"), "summary: {}", out.summary); assert!( out.content.is_none(), @@ -441,12 +454,16 @@ mod tests { "command": format!("cd {}", sub.to_str().unwrap()), }) .to_string(), + Default::default(), ) .await .unwrap(); let pwd_out = tool - .execute(&serde_json::json!({ "command": "pwd" }).to_string()) + .execute( + &serde_json::json!({ "command": "pwd" }).to_string(), + Default::default(), + ) .await .unwrap(); let body = pwd_out.content.unwrap(); @@ -467,7 +484,10 @@ mod tests { "command": "sleep 30", "timeout": 1, }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); assert!( out.summary.contains("timed out"), "summary: {}", @@ -480,7 +500,10 @@ mod tests { let h = setup(); let tool = make_tool(&h); - let err = tool.execute("not json").await.unwrap_err(); + let err = tool + .execute("not json", Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } @@ -494,7 +517,10 @@ mod tests { let inp = serde_json::json!({ "command": "for i in $(seq 1 200); do echo line $i; done", }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.expect("expected content"); assert!( @@ -523,7 +549,10 @@ mod tests { let inp = serde_json::json!({ "command": "printf 'x%.0s' {1..20480}", }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.unwrap(); assert!( body.contains(spill_dir.to_str().unwrap()), @@ -542,7 +571,10 @@ mod tests { "command": "(sleep 0.05; echo bg) &", "timeout": 5, }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); assert!( !out.summary.contains("timed out"), "summary: {}", @@ -559,7 +591,9 @@ mod tests { let inp = serde_json::json!({ "command": "for i in $(seq 1 200); do echo $i; done", }); - tool.execute(&inp.to_string()).await.unwrap(); + tool.execute(&inp.to_string(), Default::default()) + .await + .unwrap(); // The spill dir should now contain exactly one bash-*.log file. let files_before: Vec<_> = std::fs::read_dir(&spill_dir) diff --git a/crates/tools/src/edit.rs b/crates/tools/src/edit.rs index 8ad6acba..32d905d6 100644 --- a/crates/tools/src/edit.rs +++ b/crates/tools/src/edit.rs @@ -36,7 +36,11 @@ pub(crate) struct EditTool { #[async_trait] impl Tool for EditTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: EditParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid Edit input: {e}")))?; @@ -169,7 +173,10 @@ mod tests { let def = read_tool(fs.clone(), tracker.clone()); let (_, reader) = def(); let inp = serde_json::json!({ "file_path": file.to_str().unwrap() }); - reader.execute(&inp.to_string()).await.unwrap(); + reader + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); } #[tokio::test] @@ -188,7 +195,10 @@ mod tests { "old_string": "foo bar", "new_string": "foo baz", }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); assert!(out.summary.contains("1 replacement")); assert_eq!( std::fs::read_to_string(&file).unwrap(), @@ -212,7 +222,10 @@ mod tests { "new_string": "y", "replace_all": true, }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); assert!(out.summary.contains("3 replacements")); assert_eq!(std::fs::read_to_string(&file).unwrap(), "y y y\n"); } @@ -231,7 +244,10 @@ mod tests { "old_string": "a", "new_string": "b", }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } @@ -249,7 +265,10 @@ mod tests { "old_string": "world", "new_string": "x", }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } @@ -266,7 +285,10 @@ mod tests { "old_string": "foo", "new_string": "bar", }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } @@ -287,7 +309,10 @@ mod tests { "old_string": "foo", "new_string": "bar", }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); let msg = format!("{err}"); assert!(msg.contains("modified externally"), "{msg}"); } diff --git a/crates/tools/src/glob.rs b/crates/tools/src/glob.rs index 97bb75cf..d9f9efea 100644 --- a/crates/tools/src/glob.rs +++ b/crates/tools/src/glob.rs @@ -35,7 +35,11 @@ pub(crate) struct GlobTool { #[async_trait] impl Tool for GlobTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: GlobParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid Glob input: {e}")))?; @@ -239,7 +243,10 @@ mod tests { assert_eq!(meta.name, "Glob"); let inp = serde_json::json!({ "pattern": "**/*.rs" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); assert!(out.summary.contains("2 file(s)")); let body = out.content.unwrap(); assert!(body.contains("a.rs")); @@ -261,7 +268,10 @@ mod tests { let def = glob_tool(fs); let (_, tool) = def(); let inp = serde_json::json!({ "pattern": "*.rs" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.unwrap(); let new_pos = body.find("new.rs").unwrap(); let old_pos = body.find("old.rs").unwrap(); @@ -274,7 +284,10 @@ mod tests { let def = glob_tool(fs); let (_, tool) = def(); let inp = serde_json::json!({ "pattern": "**/*.nonexistent" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); assert!(out.summary.contains("No files")); assert!(out.content.is_none()); } @@ -285,7 +298,10 @@ mod tests { let def = glob_tool(fs); let (_, tool) = def(); let inp = serde_json::json!({ "pattern": "[unterminated" }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } @@ -317,7 +333,10 @@ mod tests { let def = glob_tool(fs); let (_, tool) = def(); let inp = serde_json::json!({ "pattern": "**/*.rs" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.unwrap_or_default(); assert!(body.contains("visible.rs")); assert!( @@ -335,7 +354,10 @@ mod tests { let def = glob_tool(fs); let (_, tool) = def(); let inp = serde_json::json!({ "pattern": "*.rs" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.unwrap(); assert!(body.contains(".hidden.rs")); assert!(body.contains("visible.rs")); @@ -358,7 +380,10 @@ mod tests { "path": link.to_str().unwrap(), "pattern": "**/*.rs", }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); let msg = format!("{err}"); assert!( msg.contains("Glob does not follow symlink directories"), diff --git a/crates/tools/src/grep.rs b/crates/tools/src/grep.rs index c8f7fbcc..8d4a63e2 100644 --- a/crates/tools/src/grep.rs +++ b/crates/tools/src/grep.rs @@ -82,7 +82,11 @@ pub(crate) struct GrepTool { #[async_trait] impl Tool for GrepTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: GrepParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid Grep input: {e}")))?; @@ -563,7 +567,10 @@ mod tests { let def = grep_tool(scoped); let (_, tool) = def(); let inp = serde_json::json!({ "pattern": "needle" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.unwrap_or_default(); assert!(body.contains("visible.txt")); assert!( @@ -583,7 +590,10 @@ mod tests { assert_eq!(meta.name, "Grep"); let inp = serde_json::json!({ "pattern": "bravo" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); assert!(out.summary.contains("1 file")); assert!(out.content.unwrap().contains("a.txt")); } @@ -599,7 +609,10 @@ mod tests { "pattern": "two", "output_mode": "content", }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.unwrap(); assert!(body.contains(":2:two")); } @@ -616,7 +629,10 @@ mod tests { "pattern": "x", "output_mode": "count", }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.unwrap(); assert!(body.contains("a.txt:3")); assert!(body.contains("b.txt:1")); @@ -635,7 +651,10 @@ mod tests { "-i": true, "output_mode": "content", }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); assert!(out.content.unwrap().contains("HELLO")); } @@ -654,7 +673,10 @@ mod tests { "output_mode": "content", "-C": 1, }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.unwrap(); // should contain: line2 (before context), MATCH, line4 (after context) assert!(body.contains("line2")); @@ -677,7 +699,10 @@ mod tests { "multiline": true, "output_mode": "content", }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.unwrap(); assert!(body.contains("foo")); } @@ -694,7 +719,10 @@ mod tests { "pattern": "target", "glob": "*.rs", }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.unwrap(); assert!(body.contains("a.rs")); assert!(!body.contains("b.txt")); @@ -712,7 +740,10 @@ mod tests { "pattern": "target", "type": "rust", }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.unwrap(); assert!(body.contains("a.rs")); assert!(!body.contains("b.py")); @@ -731,7 +762,10 @@ mod tests { "pattern": "x", "head_limit": 2, }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.unwrap(); assert_eq!(body.lines().count(), 2); assert!(out.summary.contains("truncated at 2")); @@ -752,7 +786,10 @@ mod tests { "offset": 3, "head_limit": 10, }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.unwrap(); // We skipped 3, so only 2 should remain. assert_eq!(body.lines().count(), 2); @@ -769,7 +806,10 @@ mod tests { let def = grep_tool(fs); let (_, tool) = def(); let inp = serde_json::json!({ "pattern": "needle" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); let body = out.content.unwrap(); assert!(body.contains("b.txt")); assert!(!body.contains("a.bin")); @@ -781,7 +821,10 @@ mod tests { let def = grep_tool(fs); let (_, tool) = def(); let inp = serde_json::json!({ "pattern": "(" }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } @@ -794,7 +837,10 @@ mod tests { "pattern": "x", "type": "nonexistent", }); - let err = tool.execute(&inp.to_string()).await.unwrap_err(); + let err = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } @@ -805,7 +851,10 @@ mod tests { let def = grep_tool(fs); let (_, tool) = def(); let inp = serde_json::json!({ "pattern": "zzz" }); - let out = tool.execute(&inp.to_string()).await.unwrap(); + let out = tool + .execute(&inp.to_string(), Default::default()) + .await + .unwrap(); assert_eq!(out.summary, "No files matched"); assert!(out.content.is_none()); } diff --git a/crates/tools/src/read.rs b/crates/tools/src/read.rs index 2bebbc18..34b8fbdf 100644 --- a/crates/tools/src/read.rs +++ b/crates/tools/src/read.rs @@ -36,7 +36,11 @@ pub(crate) struct ReadTool { #[async_trait] impl Tool for ReadTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: ReadParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid Read input: {e}")))?; let offset = params.offset.unwrap_or(0); @@ -155,7 +159,10 @@ mod tests { assert_eq!(meta.name, "Read"); let input = serde_json::json!({ "file_path": file.to_str().unwrap() }); - let out = tool.execute(&input.to_string()).await.unwrap(); + let out = tool + .execute(&input.to_string(), Default::default()) + .await + .unwrap(); assert!(out.summary.contains("Read 3 line(s)")); let body = out.content.unwrap(); assert!(body.contains(" 1\talpha")); @@ -178,7 +185,10 @@ mod tests { "offset": 1, "limit": 2, }); - let out = tool.execute(&input.to_string()).await.unwrap(); + let out = tool + .execute(&input.to_string(), Default::default()) + .await + .unwrap(); assert!(out.summary.contains("[2..3] of 5")); let body = out.content.unwrap(); assert!(body.contains(" 2\t2")); @@ -193,7 +203,10 @@ mod tests { let input = serde_json::json!({ "file_path": dir.path().join("nope.txt").to_str().unwrap() }); - let err = tool.execute(&input.to_string()).await.unwrap_err(); + let err = tool + .execute(&input.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::ExecutionFailed(_))); } @@ -202,7 +215,10 @@ mod tests { let (_dir, fs, tracker) = setup(); let def = read_tool(fs, tracker); let (_, tool) = def(); - let err = tool.execute("not json").await.unwrap_err(); + let err = tool + .execute("not json", Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } } diff --git a/crates/tools/src/web.rs b/crates/tools/src/web.rs index 9f3a6aea..2f8c4453 100644 --- a/crates/tools/src/web.rs +++ b/crates/tools/src/web.rs @@ -146,7 +146,11 @@ struct WebFetchTool { #[async_trait] impl Tool for WebSearchTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let input: WebSearchInput = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid WebSearch input: {e}")))?; self.web.run_search(input).await @@ -193,7 +197,11 @@ impl WebTools { #[async_trait] impl Tool for WebFetchTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let input: WebFetchInput = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid WebFetch input: {e}")))?; self.web.run_fetch(input).await diff --git a/crates/tools/src/write.rs b/crates/tools/src/write.rs index 762387d3..8902ec8d 100644 --- a/crates/tools/src/write.rs +++ b/crates/tools/src/write.rs @@ -30,7 +30,11 @@ pub(crate) struct WriteTool { #[async_trait] impl Tool for WriteTool { - async fn execute(&self, input_json: &str) -> Result { + async fn execute( + &self, + input_json: &str, + _ctx: llm_worker::tool::ToolExecutionContext, + ) -> Result { let params: WriteParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid Write input: {e}")))?; @@ -118,7 +122,10 @@ mod tests { "file_path": file.to_str().unwrap(), "content": "hello\n", }); - let out = tool.execute(&input.to_string()).await.unwrap(); + let out = tool + .execute(&input.to_string(), Default::default()) + .await + .unwrap(); assert!(out.summary.contains("Created")); assert_eq!(std::fs::read_to_string(&file).unwrap(), "hello\n"); } @@ -135,7 +142,10 @@ mod tests { "file_path": file.to_str().unwrap(), "content": "new", }); - let err = tool.execute(&input.to_string()).await.unwrap_err(); + let err = tool + .execute(&input.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } @@ -148,7 +158,10 @@ mod tests { let read_def = read_tool(fs.clone(), tracker.clone()); let (_, reader) = read_def(); let read_in = serde_json::json!({ "file_path": file.to_str().unwrap() }); - reader.execute(&read_in.to_string()).await.unwrap(); + reader + .execute(&read_in.to_string(), Default::default()) + .await + .unwrap(); let write_def = write_tool(fs, tracker); let (_, writer) = write_def(); @@ -156,7 +169,10 @@ mod tests { "file_path": file.to_str().unwrap(), "content": "new\n", }); - let out = writer.execute(&write_in.to_string()).await.unwrap(); + let out = writer + .execute(&write_in.to_string(), Default::default()) + .await + .unwrap(); assert!(out.summary.contains("Overwrote")); assert_eq!(std::fs::read_to_string(&file).unwrap(), "new\n"); } @@ -171,7 +187,10 @@ mod tests { let read_def = read_tool(fs.clone(), tracker.clone()); let (_, reader) = read_def(); reader - .execute(&serde_json::json!({ "file_path": file.to_str().unwrap() }).to_string()) + .execute( + &serde_json::json!({ "file_path": file.to_str().unwrap() }).to_string(), + Default::default(), + ) .await .unwrap(); @@ -187,6 +206,7 @@ mod tests { "content": "new", }) .to_string(), + Default::default(), ) .await .unwrap_err(); @@ -205,7 +225,10 @@ mod tests { "file_path": outside.path().join("x.txt").to_str().unwrap(), "content": "x", }); - let err = tool.execute(&input.to_string()).await.unwrap_err(); + let err = tool + .execute(&input.to_string(), Default::default()) + .await + .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } } diff --git a/crates/tools/tests/edge_cases.rs b/crates/tools/tests/edge_cases.rs index 181e2f83..7926ac10 100644 --- a/crates/tools/tests/edge_cases.rs +++ b/crates/tools/tests/edge_cases.rs @@ -66,13 +66,17 @@ async fn unicode_path_and_content() { "content": content, }) .to_string(), + Default::default(), ) .await .unwrap(); let read = reg.get("Read"); let out = read - .execute(&json!({ "file_path": file.to_str().unwrap() }).to_string()) + .execute( + &json!({ "file_path": file.to_str().unwrap() }).to_string(), + Default::default(), + ) .await .unwrap(); let body = out.content.unwrap(); @@ -98,7 +102,10 @@ async fn symlink_to_outside_scope_is_rejected_for_write() { // target sits outside the scope. let read = reg.get("Read"); let read_err = read - .execute(&json!({ "file_path": link.to_str().unwrap() }).to_string()) + .execute( + &json!({ "file_path": link.to_str().unwrap() }).to_string(), + Default::default(), + ) .await .unwrap_err(); assert!( @@ -119,6 +126,7 @@ async fn symlink_to_outside_scope_is_rejected_for_write() { "content": "overwritten", }) .to_string(), + Default::default(), ) .await .unwrap_err(); @@ -147,7 +155,10 @@ async fn broken_symlink_reports_target_and_repair_hint() { let read = reg.get("Read"); let err = read - .execute(&json!({ "file_path": link.to_str().unwrap() }).to_string()) + .execute( + &json!({ "file_path": link.to_str().unwrap() }).to_string(), + Default::default(), + ) .await .unwrap_err(); let msg = format!("{err}"); @@ -165,7 +176,10 @@ async fn empty_file_read_and_edit() { let read = reg.get("Read"); let out = read - .execute(&json!({ "file_path": file.to_str().unwrap() }).to_string()) + .execute( + &json!({ "file_path": file.to_str().unwrap() }).to_string(), + Default::default(), + ) .await .unwrap(); assert!(out.summary.contains("0 line")); @@ -180,6 +194,7 @@ async fn empty_file_read_and_edit() { "new_string": "bar", }) .to_string(), + Default::default(), ) .await .unwrap_err(); @@ -196,7 +211,10 @@ async fn very_long_single_line() { let read = reg.get("Read"); let out = read - .execute(&json!({ "file_path": file.to_str().unwrap() }).to_string()) + .execute( + &json!({ "file_path": file.to_str().unwrap() }).to_string(), + Default::default(), + ) .await .unwrap(); // Should return exactly 1 line @@ -208,7 +226,10 @@ async fn relative_path_is_rejected() { let (_dir, _spill, reg) = setup(); let read = reg.get("Read"); let err = read - .execute(&json!({ "file_path": "relative.txt" }).to_string()) + .execute( + &json!({ "file_path": "relative.txt" }).to_string(), + Default::default(), + ) .await .unwrap_err(); assert!(format!("{err}").contains("absolute")); @@ -219,7 +240,10 @@ async fn directory_target_is_rejected_for_read() { let (dir, _spill, reg) = setup(); let read = reg.get("Read"); let err = read - .execute(&json!({ "file_path": dir.path().to_str().unwrap() }).to_string()) + .execute( + &json!({ "file_path": dir.path().to_str().unwrap() }).to_string(), + Default::default(), + ) .await .unwrap_err(); assert!(format!("{err}").contains("directory")); @@ -237,6 +261,7 @@ async fn deeply_nested_new_file_is_created() { "content": "deep\n", }) .to_string(), + Default::default(), ) .await .unwrap(); @@ -250,9 +275,12 @@ async fn replace_preserves_unicode() { std::fs::write(&file, "๐Ÿฆ€ rust ๐Ÿฆ€\n").unwrap(); let read = reg.get("Read"); - read.execute(&json!({ "file_path": file.to_str().unwrap() }).to_string()) - .await - .unwrap(); + read.execute( + &json!({ "file_path": file.to_str().unwrap() }).to_string(), + Default::default(), + ) + .await + .unwrap(); let edit = reg.get("Edit"); edit.execute( @@ -262,6 +290,7 @@ async fn replace_preserves_unicode() { "new_string": "ใƒฉใ‚นใƒˆ", }) .to_string(), + Default::default(), ) .await .unwrap(); @@ -282,6 +311,7 @@ async fn grep_handles_unicode_pattern() { "output_mode": "content", }) .to_string(), + Default::default(), ) .await .unwrap(); diff --git a/crates/tools/tests/integration.rs b/crates/tools/tests/integration.rs index 13cef5df..ce605fe0 100644 --- a/crates/tools/tests/integration.rs +++ b/crates/tools/tests/integration.rs @@ -66,13 +66,13 @@ fn setup() -> (TempDir, TempDir, Registry) { } async fn call(tool: &Arc, input: serde_json::Value) -> llm_worker::tool::ToolOutput { - tool.execute(&input.to_string()) + tool.execute(&input.to_string(), Default::default()) .await .expect("tool execution failed") } async fn call_err(tool: &Arc, input: serde_json::Value) -> llm_worker::tool::ToolError { - tool.execute(&input.to_string()) + tool.execute(&input.to_string(), Default::default()) .await .expect_err("expected error") }