yoi/crates/llm-worker/src/tool.rs
2026-04-15 04:08:56 +09:00

374 lines
12 KiB
Rust

//! Tool Definition
//!
//! Traits and types for defining tools callable by an LLM.
//! Usually auto-implemented using the `#[tool]` macro.
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
/// Error during tool execution.
///
/// Every variant wraps a free-form message. The `Display` text generated by
/// the `#[error(...)]` attributes is the variant's label followed by that
/// message, e.g. `Invalid argument: <message>`.
#[derive(Debug, Error)]
pub enum ToolError {
/// Invalid argument. Displays as `Invalid argument: <message>`.
#[error("Invalid argument: {0}")]
InvalidArgument(String),
/// Execution failed. Displays as `Execution failed: <message>`.
#[error("Execution failed: {0}")]
ExecutionFailed(String),
/// Internal error. Displays as `Internal error: <message>`.
#[error("Internal error: {0}")]
Internal(String),
}
// =============================================================================
// ToolOutput - Tool execution result with summary + content
// =============================================================================
/// Largest output size, in bytes, that is treated as summary-only (no content).
///
/// The `From<String>` conversion for [`ToolOutput`] compares the string's
/// byte length against this value with `<=`, so an output of exactly this
/// many bytes is still summary-only.
/// Outputs this small don't benefit from pruning.
pub const SUMMARY_THRESHOLD: usize = 200;
/// Byte budgets for the `content` of tool results.
///
/// The caps are applied at the Worker's tool-execution boundary, before a
/// result is added to conversation history. They exist so that one
/// oversized result (e.g. a wide `Glob` scan) cannot push a request past
/// the provider's per-minute input-token rate limit. Tools are not trusted
/// to limit themselves; this is the single chokepoint.
///
/// Budgets are expressed in bytes because accurate pre-send token
/// estimation is not available. They can be migrated to token units later
/// without changing callers.
#[derive(Debug, Clone)]
pub struct ToolOutputLimits {
    /// Fallback cap, in bytes, for any tool without an entry in `per_tool`.
    pub default_max_bytes: usize,
    /// Explicit caps, in bytes, keyed by tool registration name.
    pub per_tool: HashMap<String, usize>,
}

impl ToolOutputLimits {
    /// Return the byte cap that applies to `tool_name`.
    ///
    /// An entry in `per_tool` wins; otherwise `default_max_bytes` is used.
    pub fn limit_for(&self, tool_name: &str) -> usize {
        match self.per_tool.get(tool_name) {
            Some(&cap) => cap,
            None => self.default_max_bytes,
        }
    }
}
/// Truncate `content` in place so that it never exceeds `limit` bytes.
///
/// If `content` already fits, it is left untouched. Otherwise the tail is
/// dropped and, when there is room, replaced by a short human- and
/// LLM-readable marker so the model can self-correct by narrowing its query:
///
/// ```text
/// \n\n[truncated: <N> bytes dropped, refine your query]
/// ```
///
/// `<N>` is the number of bytes removed from the original content.
///
/// The cut point is walked back to the nearest UTF-8 char boundary so
/// multibyte characters are never split.
///
/// # Small limits
///
/// Space for the marker is reserved using a worst-case 20-digit byte count.
/// When `limit` is smaller than that reservation the marker cannot fit, so
/// the content is cut at `limit` (char-boundary aligned) with no marker.
/// This keeps the result within `limit` in every case, rather than emitting
/// a marker that is longer than the cap (and possibly longer than the
/// original content).
pub(crate) fn truncate_content(content: &mut String, limit: usize) {
    const PLACEHOLDER: &str = "%BYTES%";
    const MARKER_TEMPLATE: &str = "\n\n[truncated: %BYTES% bytes dropped, refine your query]";
    // Upper bound on the rendered byte count: usize::MAX fits in 20 digits.
    const MAX_COUNT_DIGITS: usize = 20;

    let original_len = content.len();
    if original_len <= limit {
        return;
    }

    // Worst-case marker length once the placeholder has been substituted.
    let reserved = MARKER_TEMPLATE.len() - PLACEHOLDER.len() + MAX_COUNT_DIGITS;

    // `None` means the marker does not fit: spend the whole budget on the body.
    let (body_budget, append_marker) = match limit.checked_sub(reserved) {
        Some(budget) => (budget, true),
        None => (limit, false),
    };

    // `body_budget <= limit < original_len`, so `cut` is always in range.
    let mut cut = body_budget;
    while cut > 0 && !content.is_char_boundary(cut) {
        cut -= 1;
    }
    content.truncate(cut);

    if append_marker {
        let dropped = original_len - cut;
        content.push_str(&MARKER_TEMPLATE.replace(PLACEHOLDER, &dropped.to_string()));
    }
}
/// Tool execution result.
///
/// Every output has a mandatory `summary` (1-2 lines) that persists in
/// conversation history even after pruning. The optional `content` carries
/// full details and is removed by the Prune mechanism when the context
/// grows too large.
///
/// # Serialization
///
/// `content` is left out of the serialized form when it is `None`, and a
/// missing `content` field deserializes to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolOutput {
/// Short summary (1-2 lines). Always remains in history.
pub summary: String,
/// Detailed output. Removed by Prune when old enough.
// `default` lets the field be absent on input; `skip_serializing_if`
// omits it on output instead of writing an explicit null.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
}
impl From<String> for ToolOutput {
    /// Wrap a plain string as a [`ToolOutput`].
    ///
    /// Strings of at most [`SUMMARY_THRESHOLD`] bytes become the summary
    /// itself, with no `content`. Longer strings are stored whole in
    /// `content`, and the summary is synthesized as
    /// `"<line count> lines | <first 80 chars of the first line>"`.
    fn from(s: String) -> Self {
        // Small outputs are their own summary; nothing is worth pruning.
        if s.len() <= SUMMARY_THRESHOLD {
            return Self {
                summary: s,
                content: None,
            };
        }

        // Preview is limited to 80 characters (not bytes) of the first line.
        let first_line = s
            .lines()
            .next()
            .map(|line| line.chars().take(80).collect::<String>())
            .unwrap_or_default();
        let lines = s.lines().count();

        Self {
            summary: format!("{lines} lines | {first_line}"),
            content: Some(s),
        }
    }
}
// =============================================================================
// ToolMeta - Immutable Meta Information
// =============================================================================
/// Tool meta information (fixed at registration, immutable)
///
/// Generated from `ToolDefinition` factory and does not change after registration with Worker.
/// Used for sending tool definitions to LLM.
///
/// Built in builder style: start from [`ToolMeta::new`] with the name, then
/// chain `description(...)` and `input_schema(...)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolMeta {
/// Tool name (used by LLM for identification)
pub name: String,
/// Tool description (included in prompt to LLM)
pub description: String,
/// JSON Schema for arguments, held as an arbitrary JSON value
pub input_schema: Value,
}
impl ToolMeta {
    /// Start a `ToolMeta` for the tool called `name`.
    ///
    /// The description starts out empty and the argument schema starts out
    /// as an empty JSON object; use the builder methods to fill them in.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        let description = String::new();
        let input_schema = Value::Object(Default::default());
        Self {
            name,
            description,
            input_schema,
        }
    }

    /// Builder step: replace the description and hand the value back.
    pub fn description(self, desc: impl Into<String>) -> Self {
        Self {
            description: desc.into(),
            ..self
        }
    }

    /// Builder step: replace the argument schema and hand the value back.
    pub fn input_schema(self, schema: Value) -> Self {
        Self {
            input_schema: schema,
            ..self
        }
    }
}
// =============================================================================
// ToolDefinition - Factory Type
// =============================================================================
/// Tool definition factory
///
/// When called, returns `(ToolMeta, Arc<dyn Tool>)`.
/// Called once during Worker registration, and the meta information and instance
/// are cached at session scope.
///
/// The factory is a zero-argument closure behind an `Arc`, and it must be
/// `Send + Sync`.
///
/// # Examples
///
/// ```ignore
/// let def: ToolDefinition = Arc::new(|| {
///     (
///         ToolMeta::new("my_tool")
///             .description("My tool description")
///             .input_schema(json!({"type": "object"})),
///         Arc::new(MyToolImpl { state: 0 }) as Arc<dyn Tool>,
///     )
/// });
/// worker.register_tool(def)?;
/// ```
pub type ToolDefinition = Arc<dyn Fn() -> (ToolMeta, Arc<dyn Tool>) + Send + Sync>;
// =============================================================================
// Tool trait
// =============================================================================
/// Trait for defining tools callable by an LLM.
///
/// Tools are used by the LLM to access external resources
/// or execute computations.
/// Can maintain state during the session. Because `execute` takes `&self`,
/// mutable state needs interior mutability (the manual example below uses an
/// atomic counter).
///
/// Implementors must be `Send + Sync`.
///
/// # How to Implement
///
/// Usually auto-implemented using the `#[tool_registry]` macro:
///
/// ```ignore
/// #[tool_registry]
/// impl MyApp {
///     #[tool]
///     async fn search(&self, query: String) -> String {
///         format!("Results for: {}", query)
///     }
/// }
///
/// // Register
/// worker.register_tool(app.search_definition())?;
/// ```
///
/// # Manual Implementation
///
/// ```ignore
/// use llm_worker::tool::{Tool, ToolError, ToolMeta, ToolDefinition, ToolOutput};
/// use std::sync::Arc;
///
/// struct MyTool { counter: std::sync::atomic::AtomicUsize }
///
/// #[async_trait::async_trait]
/// impl Tool for MyTool {
///     async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError> {
///         self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
///         Ok("result".to_string().into())
///     }
/// }
///
/// let def: ToolDefinition = Arc::new(|| {
///     (
///         ToolMeta::new("my_tool")
///             .description("My custom tool")
///             .input_schema(serde_json::json!({"type": "object"})),
///         Arc::new(MyTool { counter: Default::default() }) as Arc<dyn Tool>,
///     )
/// });
/// ```
#[async_trait]
pub trait Tool: Send + Sync {
/// Execute the tool.
///
/// # Arguments
/// * `input_json` - JSON-formatted arguments generated by LLM
///
/// # Returns
/// A [`ToolOutput`] with summary and optional detailed content.
/// For simple cases, use `From<String>`: `Ok("done".to_string().into())`
///
/// # Errors
/// Returns a [`ToolError`] when the tool cannot produce an output.
async fn execute(&self, input_json: &str) -> Result<ToolOutput, ToolError>;
}
// =============================================================================
// Tool Call / Result Types
// =============================================================================
/// Tool call information
///
/// Represents a ToolUse block from the LLM: which tool to run, with which
/// arguments, and the ID that ties the eventual result back to this call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Tool call ID (used for linking with response)
pub id: String,
/// Tool name
pub name: String,
/// Input arguments, held as a JSON value
pub input: Value,
}
/// Tool execution result
///
/// Intermediate representation between tool execution and history.
/// Carries `summary` + optional `content` from [`ToolOutput`].
///
/// # Serialization
///
/// `content` is left out of the serialized form when it is `None`. On
/// deserialization a missing `content` becomes `None` and a missing
/// `is_error` becomes `false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
/// Corresponding tool call ID
pub tool_use_id: String,
/// Short summary (always kept in history)
pub summary: String,
/// Detailed output (prunable)
#[serde(default, skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
/// Whether this is an error
// `default` makes an absent field read as `false`.
#[serde(default)]
pub is_error: bool,
}
impl ToolResult {
/// Create a success result from a [`ToolOutput`].
pub fn from_output(tool_use_id: impl Into<String>, output: ToolOutput) -> Self {
Self {
tool_use_id: tool_use_id.into(),
summary: output.summary,
content: output.content,
is_error: false,
}
}
/// Create an error result.
pub fn error(tool_use_id: impl Into<String>, message: impl Into<String>) -> Self {
Self {
tool_use_id: tool_use_id.into(),
summary: message.into(),
content: None,
is_error: true,
}
}
}
/// Unit tests for `truncate_content` and `ToolOutputLimits::limit_for`.
#[cfg(test)]
mod truncate_tests {
    use super::*;

    /// Content shorter than the limit is left untouched.
    #[test]
    fn noop_when_within_limit() {
        let mut s = "hello world".to_string();
        truncate_content(&mut s, 1024);
        assert_eq!(s, "hello world");
    }

    /// Content exactly at the limit is not truncated (the check is `<=`).
    #[test]
    fn noop_at_exact_limit() {
        let mut s = "a".repeat(100);
        truncate_content(&mut s, 100);
        assert_eq!(s.len(), 100);
    }

    /// Oversized ASCII content is cut, gets the marker, stays within the
    /// limit, and reports a dropped-byte count that adds up.
    #[test]
    fn truncates_oversized_ascii_with_marker() {
        let mut s = "a".repeat(1000);
        truncate_content(&mut s, 200);
        assert!(s.contains("[truncated:"));
        assert!(s.contains("refine your query"));
        assert!(s.len() <= 200, "result was {} bytes", s.len());
        let dropped: usize = s
            .split("[truncated: ")
            .nth(1)
            .unwrap()
            .split(' ')
            .next()
            .unwrap()
            .parse()
            .unwrap();
        let body_len = s.find("\n\n[truncated:").unwrap();
        assert_eq!(body_len + dropped, 1000);
    }

    /// Multibyte content is only ever cut between characters.
    #[test]
    fn respects_utf8_char_boundaries() {
        // 100 copies of "あ" (3 bytes each) = 300 bytes.
        let mut s = "あ".repeat(100);
        truncate_content(&mut s, 120);
        assert!(s.contains("[truncated:"));
        assert!(s.len() <= 120, "result was {} bytes", s.len());
        // The kept body must consist of whole "あ" characters only; a split
        // character would leave a length that is not a multiple of 3.
        let body_len = s
            .find("\n\n[truncated:")
            .expect("marker must be present after truncation");
        let body = &s[..body_len];
        assert!(!body.is_empty());
        assert_eq!(body.len() % 3, 0);
        assert!(body.chars().all(|c| c == 'あ'));
    }

    /// A per-tool entry overrides the default; other tools fall back to it.
    #[test]
    fn limits_per_tool_override() {
        let mut limits = ToolOutputLimits {
            default_max_bytes: 1024,
            per_tool: HashMap::new(),
        };
        limits.per_tool.insert("Read".to_string(), 4096);
        assert_eq!(limits.limit_for("Read"), 4096);
        assert_eq!(limits.limit_for("Grep"), 1024);
    }
}