diff --git a/Cargo.lock b/Cargo.lock index 7ccf0182..de6d211b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -681,6 +681,16 @@ dependencies = [ [[package]] name = "insomnia" version = "0.1.0" +dependencies = [ + "llm-worker", + "llm-worker-persistence", + "serde", + "tempfile", + "thiserror", + "tokio", + "toml 0.8.23", + "uuid", +] [[package]] name = "ipnet" @@ -1250,6 +1260,15 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_spanned" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" +dependencies = [ + "serde", +] + [[package]] name = "serde_spanned" version = "1.1.1" @@ -1471,6 +1490,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" +dependencies = [ + "serde", + "serde_spanned 0.6.9", + "toml_datetime 0.6.11", + "toml_edit", +] + [[package]] name = "toml" version = "1.1.2+spec-1.1.0" @@ -1479,11 +1510,20 @@ checksum = "81f3d15e84cbcd896376e6730314d59fb5a87f31e4b038454184435cd57defee" dependencies = [ "indexmap", "serde_core", - "serde_spanned", - "toml_datetime", + "serde_spanned 1.1.1", + "toml_datetime 1.1.1+spec-1.1.0", "toml_parser", "toml_writer", - "winnow", + "winnow 1.0.1", +] + +[[package]] +name = "toml_datetime" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" +dependencies = [ + "serde", ] [[package]] @@ -1495,15 +1535,35 @@ dependencies = [ "serde_core", ] +[[package]] +name = "toml_edit" +version = "0.22.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" +dependencies = [ + "indexmap", + "serde", + "serde_spanned 0.6.9", + "toml_datetime 0.6.11", 
+ "toml_write", + "winnow 0.7.15", +] + [[package]] name = "toml_parser" version = "1.1.2+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2abe9b86193656635d2411dc43050282ca48aa31c2451210f4202550afb7526" dependencies = [ - "winnow", + "winnow 1.0.1", ] +[[package]] +name = "toml_write" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" + [[package]] name = "toml_writer" version = "1.1.1+spec-1.1.0" @@ -1634,7 +1694,7 @@ dependencies = [ "serde_json", "target-triple", "termcolor", - "toml", + "toml 1.1.2+spec-1.1.0", ] [[package]] @@ -1945,6 +2005,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winnow" +version = "0.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" +dependencies = [ + "memchr", +] + [[package]] name = "winnow" version = "1.0.1" diff --git a/TODO.md b/TODO.md index 3ebb47d9..dcb777d6 100644 --- a/TODO.md +++ b/TODO.md @@ -2,4 +2,6 @@ - [ ] テスト設計 - [x] ツール出力の遅延読み込み設計 (ToolOutput / BlobStore / auto_summarize) - [ ] ツール設計 -- [ ] inspect ツール実装 + - [ ] ツールの動的追加/削除 (unregister, replace) + - [ ] ToolDefinition ファクトリの遅延初期化修正 (現状 register 時に即時呼び出しされている。セッション開始=初回メッセージ送信時まで遅延させる) +- [x] inspect ツール実装 diff --git a/crates/insomnia/Cargo.toml b/crates/insomnia/Cargo.toml index ba926c09..f5ae8066 100644 --- a/crates/insomnia/Cargo.toml +++ b/crates/insomnia/Cargo.toml @@ -1,6 +1,18 @@ [package] name = "insomnia" version = "0.1.0" -edition = "2024" +edition.workspace = true +license.workspace = true [dependencies] +llm-worker = { path = "../llm-worker" } +llm-worker-persistence = { path = "../llm-worker-persistence" } +serde = { version = "1.0", features = ["derive"] } +toml 
= "0.8" +uuid = { version = "1", features = ["v7", "serde"] } +thiserror = "2.0" +tokio = { version = "1.49", features = ["fs"] } + +[dev-dependencies] +tokio = { version = "1.49", features = ["macros", "rt-multi-thread"] } +tempfile = "3.24" diff --git a/crates/insomnia/src/lib.rs b/crates/insomnia/src/lib.rs index e69de29b..52aa4488 100644 --- a/crates/insomnia/src/lib.rs +++ b/crates/insomnia/src/lib.rs @@ -0,0 +1,9 @@ +pub mod manifest; +pub mod pod; +pub mod provider; +pub mod scope; + +pub use manifest::{PodManifest, ProviderConfig, ProviderKind}; +pub use pod::{Pod, PodError, PodId, PodRunResult, apply_worker_manifest, new_pod_id}; +pub use provider::build_client; +pub use scope::Scope; diff --git a/crates/insomnia/src/manifest.rs b/crates/insomnia/src/manifest.rs new file mode 100644 index 00000000..7d414ee8 --- /dev/null +++ b/crates/insomnia/src/manifest.rs @@ -0,0 +1,164 @@ +use std::path::PathBuf; + +use serde::Deserialize; + +/// Declarative configuration for a Pod. +/// +/// Parsed from a TOML manifest file. Describes the provider, model, +/// system prompt, and optional directory scope. +#[derive(Debug, Clone, Deserialize)] +pub struct PodManifest { + pub pod: PodMeta, + pub provider: ProviderConfig, + pub worker: WorkerManifest, + #[serde(default)] + pub scope: Option, +} + +/// Pod metadata. +#[derive(Debug, Clone, Deserialize)] +pub struct PodMeta { + pub name: String, +} + +/// LLM provider configuration. +#[derive(Debug, Clone, Deserialize)] +pub struct ProviderConfig { + pub kind: ProviderKind, + pub model: String, + /// Environment variable name holding the API key. + #[serde(default)] + pub api_key_env: Option, + /// Custom base URL for the provider API. + #[serde(default)] + pub base_url: Option, +} + +/// Supported LLM providers. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ProviderKind { + Anthropic, + Openai, + Gemini, + Ollama, +} + +/// Worker-level configuration embedded in the manifest. +#[derive(Debug, Clone, Deserialize)] +pub struct WorkerManifest { + #[serde(default)] + pub system_prompt: Option, + #[serde(default)] + pub max_tokens: Option, + #[serde(default)] + pub temperature: Option, +} + +/// Directory scope configuration. +#[derive(Debug, Clone, Deserialize)] +pub struct ScopeConfig { + pub root: PathBuf, +} + +impl PodManifest { + /// Parse a manifest from a TOML string. + pub fn from_toml(s: &str) -> Result { + toml::from_str(s) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_minimal_manifest() { + let toml = r#" +[pod] +name = "test-agent" + +[provider] +kind = "anthropic" +model = "claude-sonnet-4-20250514" + +[worker] +"#; + let manifest = PodManifest::from_toml(toml).unwrap(); + assert_eq!(manifest.pod.name, "test-agent"); + assert_eq!(manifest.provider.kind, ProviderKind::Anthropic); + assert_eq!(manifest.provider.model, "claude-sonnet-4-20250514"); + assert!(manifest.provider.api_key_env.is_none()); + assert!(manifest.scope.is_none()); + assert!(manifest.worker.system_prompt.is_none()); + } + + #[test] + fn parse_full_manifest() { + let toml = r#" +[pod] +name = "code-reviewer" + +[provider] +kind = "anthropic" +model = "claude-sonnet-4-20250514" +api_key_env = "ANTHROPIC_API_KEY" + +[worker] +system_prompt = "You are a code reviewer." 
+max_tokens = 4096 +temperature = 0.3 + +[scope] +root = "./src" +"#; + let manifest = PodManifest::from_toml(toml).unwrap(); + assert_eq!(manifest.pod.name, "code-reviewer"); + assert_eq!( + manifest.provider.api_key_env.as_deref(), + Some("ANTHROPIC_API_KEY") + ); + assert_eq!( + manifest.worker.system_prompt.as_deref(), + Some("You are a code reviewer.") + ); + assert_eq!(manifest.worker.max_tokens, Some(4096)); + assert_eq!(manifest.worker.temperature, Some(0.3)); + assert_eq!( + manifest.scope.as_ref().unwrap().root, + PathBuf::from("./src") + ); + } + + #[test] + fn parse_ollama_no_api_key() { + let toml = r#" +[pod] +name = "local-agent" + +[provider] +kind = "ollama" +model = "llama3" + +[worker] +"#; + let manifest = PodManifest::from_toml(toml).unwrap(); + assert_eq!(manifest.provider.kind, ProviderKind::Ollama); + assert!(manifest.provider.api_key_env.is_none()); + } + + #[test] + fn reject_unknown_provider() { + let toml = r#" +[pod] +name = "test" + +[provider] +kind = "unknown_provider" +model = "x" + +[worker] +"#; + assert!(PodManifest::from_toml(toml).is_err()); + } +} diff --git a/crates/insomnia/src/pod.rs b/crates/insomnia/src/pod.rs new file mode 100644 index 00000000..0c3a100e --- /dev/null +++ b/crates/insomnia/src/pod.rs @@ -0,0 +1,180 @@ +use llm_worker::llm_client::client::LlmClient; +use llm_worker::llm_client::RequestConfig; +use llm_worker::Worker; +use llm_worker_persistence::{ + Session, SessionConfig, SessionError, SessionId, Store, StoreError, +}; + +use crate::manifest::{PodManifest, WorkerManifest}; +use crate::scope::Scope; + +/// Pod identifier. UUID v7 (time-ordered). +pub type PodId = uuid::Uuid; + +/// Generate a new Pod ID. +pub fn new_pod_id() -> PodId { + uuid::Uuid::now_v7() +} + +/// An independent agent execution unit. +/// +/// Wraps a persistent [`Session`] with manifest metadata and an optional +/// directory scope. This is the primary abstraction in insomnia. 
+pub struct Pod { + id: PodId, + manifest: PodManifest, + session: Session, + scope: Option, +} + +impl Pod { + /// Create a new Pod from a pre-built Worker and store. + /// + /// The caller is responsible for constructing the `LlmClient` from the + /// manifest's provider config. This keeps Pod free of provider-specific + /// dependencies. + pub async fn new( + manifest: PodManifest, + worker: Worker, + store: St, + scope: Option, + ) -> Result { + let session = Session::new(worker, store, SessionConfig::default()).await?; + Ok(Self { + id: new_pod_id(), + manifest, + session, + scope, + }) + } + + /// Restore a Pod from a persisted session. + pub async fn restore( + id: PodId, + session_id: SessionId, + manifest: PodManifest, + client: C, + store: St, + scope: Option, + ) -> Result { + let session = Session::restore(client, store, session_id, SessionConfig::default()).await?; + Ok(Self { + id, + manifest, + session, + scope, + }) + } + + /// The Pod's unique identifier. + pub fn id(&self) -> PodId { + self.id + } + + /// The session ID used for persistence. + pub fn session_id(&self) -> SessionId { + self.session.session_id() + } + + /// The Pod's manifest. + pub fn manifest(&self) -> &PodManifest { + &self.manifest + } + + /// The Pod's directory scope, if any. + pub fn scope(&self) -> Option<&Scope> { + self.scope.as_ref() + } + + /// Direct access to the underlying session. + /// + /// Use this to register tools, hooks, or subscribers on the worker + /// before calling [`run`](Self::run). + pub fn session_mut(&mut self) -> &mut Session { + &mut self.session + } + + /// Send user input and run until the LLM turn completes. + pub async fn run(&mut self, input: impl Into) -> Result { + let result = self.session.run(input).await?; + Ok(result.into()) + } + + /// Resume from a paused state. 
+ pub async fn resume(&mut self) -> Result { + let result = self.session.resume().await?; + Ok(result.into()) + } +} + +impl Pod, St> { + /// Create a Pod entirely from a manifest. + /// + /// Builds the LLM client from the provider config, applies worker + /// settings, and creates a new persistent session. + pub async fn from_manifest( + manifest: PodManifest, + store: St, + scope: Option, + ) -> Result { + let client = crate::provider::build_client(&manifest.provider)?; + let mut worker = Worker::new(client); + apply_worker_manifest(&mut worker, &manifest.worker); + let session = Session::new(worker, store, SessionConfig::default()).await?; + Ok(Self { + id: new_pod_id(), + manifest, + session, + scope, + }) + } +} + +/// Apply worker-level manifest settings to a Worker. +pub fn apply_worker_manifest(worker: &mut Worker, wm: &WorkerManifest) { + if let Some(ref prompt) = wm.system_prompt { + worker.set_system_prompt(prompt); + } + let mut config = RequestConfig::new(); + if let Some(max_tokens) = wm.max_tokens { + config.max_tokens = Some(max_tokens); + } + if let Some(temperature) = wm.temperature { + config.temperature = Some(temperature); + } + worker.set_request_config(config); +} + +/// Result of a Pod run. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PodRunResult { + /// The LLM finished its turn normally. + Finished, + /// The LLM paused (e.g. awaiting user confirmation via a hook). + Paused, +} + +impl From for PodRunResult { + fn from(r: llm_worker::WorkerResult) -> Self { + match r { + llm_worker::WorkerResult::Finished => PodRunResult::Finished, + llm_worker::WorkerResult::Paused => PodRunResult::Paused, + } + } +} + +/// Pod errors. 
+#[derive(Debug, thiserror::Error)] +pub enum PodError { + #[error(transparent)] + Session(#[from] SessionError), + + #[error(transparent)] + Store(#[from] StoreError), + + #[error("scope violation: {path} is outside the allowed directory")] + ScopeViolation { path: String }, + + #[error("provider configuration error: {0}")] + ProviderConfig(String), +} diff --git a/crates/insomnia/src/provider.rs b/crates/insomnia/src/provider.rs new file mode 100644 index 00000000..5ecf8958 --- /dev/null +++ b/crates/insomnia/src/provider.rs @@ -0,0 +1,60 @@ +use llm_worker::llm_client::client::LlmClient; +use llm_worker::llm_client::providers::anthropic::AnthropicClient; +use llm_worker::llm_client::providers::gemini::GeminiClient; +use llm_worker::llm_client::providers::ollama::OllamaClient; +use llm_worker::llm_client::providers::openai::OpenAIClient; + +use crate::manifest::{ProviderConfig, ProviderKind}; +use crate::pod::PodError; + +/// Build an [`LlmClient`] from a [`ProviderConfig`]. +/// +/// Resolves the API key from the environment variable specified in the config. 
+pub fn build_client(config: &ProviderConfig) -> Result, PodError> { + let api_key = config + .api_key_env + .as_deref() + .map(std::env::var) + .transpose() + .map_err(|e| PodError::ProviderConfig(format!("env var: {e}")))?; + + match config.kind { + ProviderKind::Anthropic => { + let key = api_key.ok_or_else(|| { + PodError::ProviderConfig("anthropic requires api_key_env".into()) + })?; + let mut client = AnthropicClient::new(key, &config.model); + if let Some(ref url) = config.base_url { + client = client.with_base_url(url); + } + Ok(Box::new(client)) + } + ProviderKind::Openai => { + let key = api_key.ok_or_else(|| { + PodError::ProviderConfig("openai requires api_key_env".into()) + })?; + let mut client = OpenAIClient::new(key, &config.model); + if let Some(ref url) = config.base_url { + client = client.with_base_url(url); + } + Ok(Box::new(client)) + } + ProviderKind::Gemini => { + let key = api_key.ok_or_else(|| { + PodError::ProviderConfig("gemini requires api_key_env".into()) + })?; + let mut client = GeminiClient::new(key, &config.model); + if let Some(ref url) = config.base_url { + client = client.with_base_url(url); + } + Ok(Box::new(client)) + } + ProviderKind::Ollama => { + let mut client = OllamaClient::new(&config.model); + if let Some(ref url) = config.base_url { + client = client.with_base_url(url); + } + Ok(Box::new(client)) + } + } +} diff --git a/crates/insomnia/src/scope.rs b/crates/insomnia/src/scope.rs new file mode 100644 index 00000000..6a979ae4 --- /dev/null +++ b/crates/insomnia/src/scope.rs @@ -0,0 +1,101 @@ +use std::path::{Path, PathBuf}; + +/// Directory scope constraining a Pod's write access. +/// +/// Read access is unrestricted — only write operations are checked against the scope. +#[derive(Debug, Clone)] +pub struct Scope { + root: PathBuf, +} + +impl Scope { + /// Create a new scope rooted at the given directory. + /// + /// The path is canonicalized to resolve symlinks and relative components. 
+ pub fn new(root: impl Into) -> std::io::Result { + let root = root.into().canonicalize()?; + Ok(Self { root }) + } + + /// The root directory of this scope. + pub fn root(&self) -> &Path { + &self.root + } + + /// Check whether `path` falls within this scope. + /// + /// The path is canonicalized before comparison. + pub fn contains(&self, path: &Path) -> bool { + match path.canonicalize() { + Ok(canonical) => canonical.starts_with(&self.root), + Err(_) => { + // Path doesn't exist yet — check the parent directory instead. + // This handles write_file to a new file inside the scope. + match path.parent().and_then(|p| p.canonicalize().ok()) { + Some(parent) => parent.starts_with(&self.root), + None => false, + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::TempDir; + + #[test] + fn contains_file_inside_scope() { + let dir = TempDir::new().unwrap(); + let scope = Scope::new(dir.path()).unwrap(); + + let file = dir.path().join("test.txt"); + fs::write(&file, "hello").unwrap(); + + assert!(scope.contains(&file)); + } + + #[test] + fn rejects_file_outside_scope() { + let dir = TempDir::new().unwrap(); + let outside = TempDir::new().unwrap(); + let scope = Scope::new(dir.path()).unwrap(); + + let file = outside.path().join("test.txt"); + fs::write(&file, "hello").unwrap(); + + assert!(!scope.contains(&file)); + } + + #[test] + fn contains_new_file_in_existing_parent() { + let dir = TempDir::new().unwrap(); + let scope = Scope::new(dir.path()).unwrap(); + + // File doesn't exist yet, but parent dir is inside scope + let new_file = dir.path().join("new.txt"); + assert!(scope.contains(&new_file)); + } + + #[test] + fn contains_nested_directory() { + let dir = TempDir::new().unwrap(); + let nested = dir.path().join("a/b/c"); + fs::create_dir_all(&nested).unwrap(); + let scope = Scope::new(dir.path()).unwrap(); + + let file = nested.join("test.txt"); + assert!(scope.contains(&file)); + } + + #[test] + fn 
rejects_traversal_attack() { + let dir = TempDir::new().unwrap(); + let scope = Scope::new(dir.path()).unwrap(); + + let traversal = dir.path().join("../../../etc/passwd"); + assert!(!scope.contains(&traversal)); + } +} diff --git a/crates/llm-worker-persistence/src/inspect_tool.rs b/crates/llm-worker-persistence/src/inspect_tool.rs new file mode 100644 index 00000000..91d01083 --- /dev/null +++ b/crates/llm-worker-persistence/src/inspect_tool.rs @@ -0,0 +1,668 @@ +//! Built-in `inspect` tool for retrieving stored blob content. +//! +//! When large tool outputs are stored in a [`BlobStore`], only a summary +//! with a `[blob:]` reference is placed in conversation history. +//! This tool lets the LLM retrieve details on demand, with optional +//! selectors for partial access. + +use std::sync::Arc; + +use async_trait::async_trait; +use serde::Deserialize; +use serde_json::json; + +use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta}; +use llm_worker::state::Mutable; +use llm_worker::ToolRegistryError; +use llm_worker::Worker; +use llm_worker::llm_client::LlmClient; + +use crate::blob_store::{BlobId, BlobStore}; + +// ─── Constants ─────────────────────────────────────────────────────────────── + +/// Maximum lines shown in the default text preview. +const DEFAULT_PREVIEW_LINES: usize = 50; +/// Maximum array elements shown in the default preview. +const DEFAULT_PREVIEW_ELEMENTS: usize = 5; +/// Maximum object keys whose values are shown in the default preview. +const DEFAULT_PREVIEW_KEYS: usize = 3; + +// ─── Selector ──────────────────────────────────────────────────────────────── + +/// Parsed selector for partial blob content retrieval. +#[derive(Debug, Clone, PartialEq, Eq)] +enum Selector { + /// Extract a range of lines (1-based, inclusive). + Lines { start: usize, end: usize }, + /// Extract a range of array elements (0-based, exclusive end). + Slice { start: usize, end: usize }, + /// Extract a specific key from a JSON object. 
+ Key(String), +} + +fn parse_selector(s: &str) -> Result { + if let Some(rest) = s.strip_prefix("lines:") { + let (a, b) = rest + .split_once('-') + .ok_or_else(|| ToolError::InvalidArgument(format!( + "invalid lines selector '{s}': expected format lines:N-M" + )))?; + let start: usize = a.parse().map_err(|_| { + ToolError::InvalidArgument(format!("invalid start line number: '{a}'")) + })?; + let end: usize = b.parse().map_err(|_| { + ToolError::InvalidArgument(format!("invalid end line number: '{b}'")) + })?; + if start == 0 { + return Err(ToolError::InvalidArgument( + "line numbers are 1-based, got 0".into(), + )); + } + if start > end { + return Err(ToolError::InvalidArgument(format!( + "start line ({start}) must be <= end line ({end})" + ))); + } + Ok(Selector::Lines { start, end }) + } else if let Some(rest) = s.strip_prefix("slice:") { + let (a, b) = rest + .split_once("..") + .ok_or_else(|| ToolError::InvalidArgument(format!( + "invalid slice selector '{s}': expected format slice:N..M" + )))?; + let start: usize = a.parse().map_err(|_| { + ToolError::InvalidArgument(format!("invalid start index: '{a}'")) + })?; + let end: usize = b.parse().map_err(|_| { + ToolError::InvalidArgument(format!("invalid end index: '{b}'")) + })?; + if start > end { + return Err(ToolError::InvalidArgument(format!( + "start index ({start}) must be <= end index ({end})" + ))); + } + Ok(Selector::Slice { start, end }) + } else if let Some(rest) = s.strip_prefix("key:") { + if rest.is_empty() { + return Err(ToolError::InvalidArgument("key name must not be empty".into())); + } + Ok(Selector::Key(rest.to_string())) + } else { + Err(ToolError::InvalidArgument(format!( + "unrecognized selector format: '{s}'. 
Expected: lines:N-M, slice:N..M, or key:NAME" + ))) + } +} + +// ─── InspectTool ───────────────────────────────────────────────────────────── + +#[derive(Deserialize)] +struct InspectArgs { + blob_id: String, + selector: Option, +} + +/// Built-in tool that retrieves stored blob content. +pub struct InspectTool { + blob_store: Arc, +} + +impl InspectTool { + pub fn new(blob_store: Arc) -> Self { + Self { blob_store } + } +} + +impl InspectTool { + /// Create a [`ToolDefinition`] factory for this tool. + pub fn tool_definition(blob_store: Arc) -> ToolDefinition { + Arc::new(move || { + let meta = ToolMeta::new("inspect") + .description( + "Retrieve content from a stored blob referenced by [blob:] in conversation history. \ + Supports selectors for partial access: \ + 'lines:N-M' (text line range, 1-based inclusive), \ + 'slice:N..M' (array element range, 0-based exclusive end), \ + 'key:NAME' (object key lookup). \ + Without a selector, returns metadata and a preview.", + ) + .input_schema(json!({ + "type": "object", + "properties": { + "blob_id": { + "type": "string", + "description": "The blob UUID from a [blob:] reference" + }, + "selector": { + "type": "string", + "description": "Optional: 'lines:N-M', 'slice:N..M', or 'key:NAME'" + } + }, + "required": ["blob_id"] + })); + let tool = Arc::new(InspectTool::new(Arc::clone(&blob_store))) as Arc; + (meta, tool) + }) + } +} + +#[async_trait] +impl Tool for InspectTool { + async fn execute(&self, input_json: &str) -> Result { + let args: InspectArgs = serde_json::from_str(input_json) + .map_err(|e| ToolError::InvalidArgument(format!("invalid arguments: {e}")))?; + + let blob_id: BlobId = args + .blob_id + .parse() + .map_err(|_| ToolError::InvalidArgument(format!( + "invalid blob_id: '{}' is not a valid UUID", args.blob_id + )))?; + + let content = self + .blob_store + .load(blob_id) + .await + .map_err(|e| ToolError::ExecutionFailed(format!("{e}")))?; + + match args.selector { + None => Ok(default_view(&content)), 
+ Some(sel) => { + let selector = parse_selector(&sel)?; + apply_selector(&content, &selector) + } + } + } +} + +// ─── Default view ──────────────────────────────────────────────────────────── + +use llm_worker::tool::Content; + +fn default_view(content: &Content) -> String { + match content { + Content::Text(text) => default_view_text(text), + Content::Structured(value) => default_view_structured(value), + } +} + +fn default_view_text(text: &str) -> String { + let lines: Vec<&str> = text.lines().collect(); + let total = lines.len(); + let size = text.len(); + let preview_end = total.min(DEFAULT_PREVIEW_LINES); + + let mut out = format!("type: text\nlines: {total}\nsize: {size} bytes\n\n"); + out.push_str(&format!("── preview (lines 1-{preview_end}) ──\n")); + for line in &lines[..preview_end] { + out.push_str(line); + out.push('\n'); + } + if total > DEFAULT_PREVIEW_LINES { + out.push_str(&format!("... ({} more lines)\n", total - DEFAULT_PREVIEW_LINES)); + } + out +} + +fn default_view_structured(value: &serde_json::Value) -> String { + use serde_json::Value; + match value { + Value::Array(arr) => { + let total = arr.len(); + let preview_end = total.min(DEFAULT_PREVIEW_ELEMENTS); + let mut out = format!("type: json_array\nentries: {total}\n\n"); + out.push_str(&format!("── preview (0..{preview_end}) ──\n")); + for item in &arr[..preview_end] { + if let Ok(json) = serde_json::to_string_pretty(item) { + out.push_str(&json); + out.push('\n'); + } + } + if total > DEFAULT_PREVIEW_ELEMENTS { + out.push_str(&format!("... 
({} more entries)\n", total - DEFAULT_PREVIEW_ELEMENTS)); + } + out + } + Value::Object(map) => { + let total = map.len(); + let mut out = format!("type: json_object\nkeys: {total}\n\n── keys ──\n"); + for (key, val) in map.iter() { + out.push_str(&format!("{key}: {}\n", value_type_label(val))); + } + // Preview first N key-value pairs + let preview_keys: Vec<_> = map.iter().take(DEFAULT_PREVIEW_KEYS).collect(); + if !preview_keys.is_empty() { + out.push_str("\n── preview ──\n"); + for (key, val) in preview_keys { + if let Ok(json) = serde_json::to_string_pretty(val) { + out.push_str(&format!("{key}: {json}\n")); + } + } + } + out + } + other => { + // Scalar — just show it + serde_json::to_string_pretty(other).unwrap_or_default() + } + } +} + +fn value_type_label(value: &serde_json::Value) -> &'static str { + match value { + serde_json::Value::Null => "null", + serde_json::Value::Bool(_) => "bool", + serde_json::Value::Number(_) => "number", + serde_json::Value::String(_) => "string", + serde_json::Value::Array(_) => "array", + serde_json::Value::Object(_) => "object", + } +} + +// ─── Selector application ──────────────────────────────────────────────────── + +fn apply_selector(content: &Content, selector: &Selector) -> Result { + match (content, selector) { + (Content::Text(text), Selector::Lines { start, end }) => { + let lines: Vec<&str> = text.lines().collect(); + let total = lines.len(); + // Convert 1-based inclusive to 0-based + let from = (*start - 1).min(total); + let to = (*end).min(total); + if from >= total { + return Ok(format!("(no lines — content has {total} lines)")); + } + Ok(lines[from..to].join("\n")) + } + + (Content::Structured(serde_json::Value::Array(arr)), Selector::Slice { start, end }) => { + let total = arr.len(); + let from = (*start).min(total); + let to = (*end).min(total); + let slice = &arr[from..to]; + serde_json::to_string_pretty(slice) + .map_err(|e| ToolError::Internal(format!("JSON serialization error: {e}"))) + } + + 
(Content::Structured(serde_json::Value::Object(map)), Selector::Key(key)) => { + match map.get(key.as_str()) { + Some(val) => serde_json::to_string_pretty(val) + .map_err(|e| ToolError::Internal(format!("JSON serialization error: {e}"))), + None => { + let available: Vec<_> = map.keys().collect(); + Err(ToolError::InvalidArgument(format!( + "key '{key}' not found. Available keys: {available:?}" + ))) + } + } + } + + // Type mismatches + (Content::Text(_), Selector::Slice { .. }) => Err(ToolError::InvalidArgument( + "slice selector only applies to JSON arrays, but this blob contains text. Use 'lines:N-M' instead.".into(), + )), + (Content::Text(_), Selector::Key(_)) => Err(ToolError::InvalidArgument( + "key selector only applies to JSON objects, but this blob contains text. Use 'lines:N-M' instead.".into(), + )), + (Content::Structured(_), Selector::Lines { .. }) => Err(ToolError::InvalidArgument( + "lines selector only applies to text content, but this blob contains JSON. Use 'slice:N..M' or 'key:NAME' instead.".into(), + )), + (Content::Structured(serde_json::Value::Object(_)), Selector::Slice { .. }) => Err(ToolError::InvalidArgument( + "slice selector only applies to JSON arrays, but this blob is a JSON object. Use 'key:NAME' instead.".into(), + )), + (Content::Structured(serde_json::Value::Array(_)), Selector::Key(_)) => Err(ToolError::InvalidArgument( + "key selector only applies to JSON objects, but this blob is a JSON array. Use 'slice:N..M' instead.".into(), + )), + (Content::Structured(_), Selector::Slice { .. }) => Err(ToolError::InvalidArgument( + "slice selector only applies to JSON arrays.".into(), + )), + (Content::Structured(_), Selector::Key(_)) => Err(ToolError::InvalidArgument( + "key selector only applies to JSON objects.".into(), + )), + } +} + +// ─── Registration helper ───────────────────────────────────────────────────── + +/// Register the `inspect` tool on a [`Worker`]. 
+///
+/// Call this alongside [`BlobOutputProcessor`](crate::BlobOutputProcessor)
+/// setup so the LLM can retrieve stored blob content.
+pub fn register_inspect_tool<C, B>(
+    worker: &mut Worker<C>,
+    blob_store: Arc<B>,
+) -> Result<(), ToolRegistryError>
+where
+    C: LlmClient,
+    B: BlobStore + 'static,
+{
+    worker.register_tool(InspectTool::<B>::tool_definition(blob_store))
+}
+
+// ─── Tests ───────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::blob_store::{new_blob_id, BlobStoreError};
+    use llm_worker::tool::Content;
+    use std::collections::HashMap;
+    use tokio::sync::Mutex;
+
+    // ── In-memory BlobStore for tests ────────────────────────────────────
+
+    struct MemBlobStore {
+        blobs: Mutex<HashMap<BlobId, Content>>,
+    }
+
+    impl MemBlobStore {
+        fn new() -> Self {
+            Self {
+                blobs: Mutex::new(HashMap::new()),
+            }
+        }
+    }
+
+    impl BlobStore for MemBlobStore {
+        async fn store(&self, content: &Content) -> Result<BlobId, BlobStoreError> {
+            let id = new_blob_id();
+            self.blobs.lock().await.insert(id, content.clone());
+            Ok(id)
+        }
+
+        async fn load(&self, id: BlobId) -> Result<Content, BlobStoreError> {
+            self.blobs
+                .lock()
+                .await
+                .get(&id)
+                .cloned()
+                .ok_or(BlobStoreError::NotFound(id))
+        }
+
+        async fn exists(&self, id: BlobId) -> Result<bool, BlobStoreError> {
+            Ok(self.blobs.lock().await.contains_key(&id))
+        }
+    }
+
+    // ── Selector parsing ─────────────────────────────────────────────────
+
+    #[test]
+    fn parse_lines_valid() {
+        assert_eq!(
+            parse_selector("lines:1-50").unwrap(),
+            Selector::Lines { start: 1, end: 50 }
+        );
+        assert_eq!(
+            parse_selector("lines:5-5").unwrap(),
+            Selector::Lines { start: 5, end: 5 }
+        );
+    }
+
+    #[test]
+    fn parse_lines_zero_start() {
+        let err = parse_selector("lines:0-5").unwrap_err();
+        assert!(matches!(err, ToolError::InvalidArgument(_)));
+    }
+
+    #[test]
+    fn parse_lines_inverted() {
+        let err = parse_selector("lines:50-20").unwrap_err();
+        assert!(matches!(err, ToolError::InvalidArgument(_)));
+    }
+
+    #[test]
+    fn parse_lines_missing_dash() {
+        let err = parse_selector("lines:20").unwrap_err();
+        assert!(matches!(err, ToolError::InvalidArgument(_)));
+    }
+
+    #[test]
+    fn parse_slice_valid() {
+        assert_eq!(
+            parse_selector("slice:0..10").unwrap(),
+            Selector::Slice { start: 0, end: 10 }
+        );
+        assert_eq!(
+            parse_selector("slice:3..8").unwrap(),
+            Selector::Slice { start: 3, end: 8 }
+        );
+    }
+
+    #[test]
+    fn parse_slice_inverted() {
+        let err = parse_selector("slice:10..3").unwrap_err();
+        assert!(matches!(err, ToolError::InvalidArgument(_)));
+    }
+
+    #[test]
+    fn parse_key_valid() {
+        assert_eq!(
+            parse_selector("key:results").unwrap(),
+            Selector::Key("results".into())
+        );
+        // Key name with colon
+        assert_eq!(
+            parse_selector("key:nested:key").unwrap(),
+            Selector::Key("nested:key".into())
+        );
+    }
+
+    #[test]
+    fn parse_key_empty() {
+        let err = parse_selector("key:").unwrap_err();
+        assert!(matches!(err, ToolError::InvalidArgument(_)));
+    }
+
+    #[test]
+    fn parse_unknown_prefix() {
+        let err = parse_selector("unknown:foo").unwrap_err();
+        assert!(matches!(err, ToolError::InvalidArgument(_)));
+    }
+
+    // ── Default view ─────────────────────────────────────────────────────
+
+    #[test]
+    fn default_view_text_short() {
+        let text = "line1\nline2\nline3\n";
+        let content = Content::Text(text.into());
+        let view = default_view(&content);
+        assert!(view.contains("type: text"));
+        assert!(view.contains("lines: 3"));
+        assert!(view.contains("line1"));
+        assert!(!view.contains("more lines"));
+    }
+
+    #[test]
+    fn default_view_text_long() {
+        let text: String = (1..=100).map(|i| format!("line {i}\n")).collect();
+        let content = Content::Text(text);
+        let view = default_view(&content);
+        assert!(view.contains("type: text"));
+        assert!(view.contains("lines: 100"));
+        assert!(view.contains("line 1"));
+        assert!(view.contains("line 50"));
+        assert!(!view.contains("line 51\n"));
+        assert!(view.contains("50 more lines"));
+    }
+
+    #[test]
+    fn default_view_array() {
+        let arr: Vec<serde_json::Value> = (0..20).map(|i| json!({"id": i})).collect();
+        let content = Content::Structured(json!(arr));
+        let view = default_view(&content);
+        assert!(view.contains("type: json_array"));
+        assert!(view.contains("entries: 20"));
+        assert!(view.contains("15 more entries"));
+    }
+
+    #[test]
+    fn default_view_object() {
+        let content = Content::Structured(json!({
+            "name": "test",
+            "count": 42,
+            "items": [1, 2, 3],
+            "nested": {"a": 1}
+        }));
+        let view = default_view(&content);
+        assert!(view.contains("type: json_object"));
+        assert!(view.contains("keys: 4"));
+        assert!(view.contains("── keys ──"));
+        assert!(view.contains("── preview ──"));
+    }
+
+    // ── Selector application ─────────────────────────────────────────────
+
+    #[test]
+    fn apply_lines_on_text() {
+        let text = "a\nb\nc\nd\ne\nf\n";
+        let content = Content::Text(text.into());
+        let result = apply_selector(&content, &Selector::Lines { start: 2, end: 4 }).unwrap();
+        assert_eq!(result, "b\nc\nd");
+    }
+
+    #[test]
+    fn apply_lines_clamp() {
+        let text = "a\nb\nc\n";
+        let content = Content::Text(text.into());
+        let result = apply_selector(&content, &Selector::Lines { start: 2, end: 100 }).unwrap();
+        assert_eq!(result, "b\nc");
+    }
+
+    #[test]
+    fn apply_lines_beyond_content() {
+        let text = "a\nb\n";
+        let content = Content::Text(text.into());
+        let result = apply_selector(&content, &Selector::Lines { start: 10, end: 20 }).unwrap();
+        assert!(result.contains("no lines"));
+    }
+
+    #[test]
+    fn apply_slice_on_array() {
+        let content = Content::Structured(json!([10, 20, 30, 40, 50]));
+        let result = apply_selector(&content, &Selector::Slice { start: 1, end: 3 }).unwrap();
+        let parsed: Vec<i64> = serde_json::from_str(&result).unwrap();
+        assert_eq!(parsed, vec![20, 30]);
+    }
+
+    #[test]
+    fn apply_slice_clamp() {
+        let content = Content::Structured(json!([10, 20, 30]));
+        let result = apply_selector(&content, &Selector::Slice { start: 1, end: 100 }).unwrap();
+        let parsed: Vec<i64> = serde_json::from_str(&result).unwrap();
+        assert_eq!(parsed, vec![20, 30]);
+    }
+
+    #[test]
+    fn apply_key_on_object() {
+        let content = Content::Structured(json!({"name": "test", "count": 42}));
+        let result = apply_selector(&content, &Selector::Key("name".into())).unwrap();
+        assert_eq!(result.trim(), "\"test\"");
+    }
+
+    #[test]
+    fn apply_key_not_found() {
+        let content = Content::Structured(json!({"name": "test"}));
+        let err = apply_selector(&content, &Selector::Key("missing".into())).unwrap_err();
+        match err {
+            ToolError::InvalidArgument(msg) => {
+                assert!(msg.contains("missing"));
+                assert!(msg.contains("name"));
+            }
+            _ => panic!("expected InvalidArgument"),
+        }
+    }
+
+    // ── Type mismatch errors ─────────────────────────────────────────────
+
+    #[test]
+    fn lines_on_json_error() {
+        let content = Content::Structured(json!([1, 2, 3]));
+        let err = apply_selector(&content, &Selector::Lines { start: 1, end: 3 }).unwrap_err();
+        assert!(matches!(err, ToolError::InvalidArgument(_)));
+    }
+
+    #[test]
+    fn slice_on_text_error() {
+        let content = Content::Text("hello".into());
+        let err = apply_selector(&content, &Selector::Slice { start: 0, end: 3 }).unwrap_err();
+        assert!(matches!(err, ToolError::InvalidArgument(_)));
+    }
+
+    #[test]
+    fn key_on_text_error() {
+        let content = Content::Text("hello".into());
+        let err = apply_selector(&content, &Selector::Key("foo".into())).unwrap_err();
+        assert!(matches!(err, ToolError::InvalidArgument(_)));
+    }
+
+    #[test]
+    fn slice_on_object_error() {
+        let content = Content::Structured(json!({"a": 1}));
+        let err = apply_selector(&content, &Selector::Slice { start: 0, end: 3 }).unwrap_err();
+        assert!(matches!(err, ToolError::InvalidArgument(_)));
+    }
+
+    #[test]
+    fn key_on_array_error() {
+        let content = Content::Structured(json!([1, 2, 3]));
+        let err = apply_selector(&content, &Selector::Key("foo".into())).unwrap_err();
+        assert!(matches!(err, ToolError::InvalidArgument(_)));
+    }
+
+    // ── Integration via execute() ────────────────────────────────────────
+
+    #[tokio::test]
+    async fn execute_default_view() {
+        let store = Arc::new(MemBlobStore::new());
+        let text = (1..=100).map(|i| format!("line {i}")).collect::<Vec<_>>().join("\n");
+        let blob_id = store.store(&Content::Text(text)).await.unwrap();
+
+        let tool = InspectTool::new(store);
+        let result = tool
+            .execute(&json!({"blob_id": blob_id.to_string()}).to_string())
+            .await
+            .unwrap();
+        assert!(result.contains("type: text"));
+        assert!(result.contains("lines: 100"));
+    }
+
+    #[tokio::test]
+    async fn execute_with_selector() {
+        let store = Arc::new(MemBlobStore::new());
+        let blob_id = store
+            .store(&Content::Structured(json!({"name": "test", "value": 42})))
+            .await
+            .unwrap();
+
+        let tool = InspectTool::new(store);
+        let result = tool
+            .execute(&json!({"blob_id": blob_id.to_string(), "selector": "key:name"}).to_string())
+            .await
+            .unwrap();
+        assert_eq!(result.trim(), "\"test\"");
+    }
+
+    #[tokio::test]
+    async fn execute_invalid_blob_id() {
+        let store = Arc::new(MemBlobStore::new());
+        let tool = InspectTool::new(store);
+        let err = tool
+            .execute(&json!({"blob_id": "not-a-uuid"}).to_string())
+            .await
+            .unwrap_err();
+        assert!(matches!(err, ToolError::InvalidArgument(_)));
+    }
+
+    #[tokio::test]
+    async fn execute_blob_not_found() {
+        let store = Arc::new(MemBlobStore::new());
+        let tool = InspectTool::new(store);
+        let fake_id = new_blob_id();
+        let err = tool
+            .execute(&json!({"blob_id": fake_id.to_string()}).to_string())
+            .await
+            .unwrap_err();
+        assert!(matches!(err, ToolError::ExecutionFailed(_)));
+    }
+}
diff --git a/crates/llm-worker-persistence/src/lib.rs b/crates/llm-worker-persistence/src/lib.rs
index b5a2e67a..c329738c 100644
--- a/crates/llm-worker-persistence/src/lib.rs
+++ b/crates/llm-worker-persistence/src/lib.rs
@@ -25,12 +25,14 @@ pub mod blob_store;
 pub mod event_trace;
 pub mod fs_blob_store;
 pub mod fs_store;
+pub mod inspect_tool;
 pub mod session;
 pub mod session_log;
 pub mod store;
 
 pub use blob_output_processor::BlobOutputProcessor;
 pub use blob_store::{BlobId, BlobStore, BlobStoreError};
+pub use inspect_tool::{InspectTool, register_inspect_tool};
 pub use event_trace::TraceEntry;
 pub use fs_blob_store::FsBlobStore;
 pub use fs_store::FsStore;