diff --git a/.gitignore b/.gitignore index 5382a4e6..0de651ce 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ /target .direnv -*.local +*.local* .env diff --git a/TODO.md b/TODO.md index 5e0f1b66..22147d6f 100644 --- a/TODO.md +++ b/TODO.md @@ -10,3 +10,5 @@ - [x] セッションエントリのハッシュチェーン - [x] Subscriber → クロージャ API 移行 - [x] JSONL ストリーム変換ユーティリティ (protocol::stream) +- [x] Hook モジュールの llm-worker からの除去 → [tickets/remove-hook-module.md](tickets/remove-hook-module.md) +- [ ] api_key_file: ファイルパスによるAPIキー解決 → [tickets/api-key-file.md](tickets/api-key-file.md) diff --git a/crates/llm-worker-persistence/tests/session_test.rs b/crates/llm-worker-persistence/tests/session_test.rs index 99149a17..25aece74 100644 --- a/crates/llm-worker-persistence/tests/session_test.rs +++ b/crates/llm-worker-persistence/tests/session_test.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use async_trait::async_trait; use common::MockLlmClient; -use llm_worker::hook::{Hook, HookError, OnTurnEnd, OnTurnEndResult}; +use llm_worker::interceptor::{Interceptor, TurnEndAction}; use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent}; use llm_worker::llm_client::types::{Item, RequestConfig}; use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta}; @@ -76,13 +76,13 @@ fn weather_tool_definition() -> ToolDefinition { }) } -/// Hook that forces Pause on the first turn end. -struct PauseOnFirstTurnEnd; +/// Policy that forces Pause on every turn end. +struct PausePolicy; #[async_trait] -impl Hook for PauseOnFirstTurnEnd { - async fn call(&self, _input: &mut Vec) -> Result { - Ok(OnTurnEndResult::Paused) +impl Interceptor for PausePolicy { + async fn on_turn_end(&self, _history: &[Item]) -> TurnEndAction { + TurnEndAction::Pause } } @@ -214,7 +214,7 @@ async fn session_resume_after_pause() { let client = MockLlmClient::with_responses(tool_call_events()); let mut worker = Worker::new(client); worker.register_tool(weather_tool_definition()).unwrap(); - worker.add_on_turn_end_hook(PauseOnFirstTurnEnd); + worker.set_interceptor(PausePolicy); let mut session = Session::new(worker, store.clone(), SessionConfig::default()) .await diff --git a/crates/llm-worker/examples/worker_cli.rs b/crates/llm-worker/examples/worker_cli.rs index 1b3b6ea4..ed6c606d 100644 --- a/crates/llm-worker/examples/worker_cli.rs +++ b/crates/llm-worker/examples/worker_cli.rs @@ -41,7 +41,6 @@ use tracing_subscriber::EnvFilter; use clap::{Parser, ValueEnum}; use llm_worker::{ Worker, - hook::{Hook, HookError, PostToolCall, PostToolCallContext, PostToolCallResult}, llm_client::{ LlmClient, providers::{ @@ -49,6 +48,7 @@ use llm_worker::{ openai::OpenAIClient, }, }, + interceptor::{Interceptor, PostToolAction, ToolResultInfo}, timeline::{Handler, TextBlockEvent, TextBlockKind, ToolUseBlockEvent, ToolUseBlockKind}, }; use llm_worker_macros::tool_registry; @@ -270,34 +270,34 @@ impl Handler for ToolCallPrinter { } } -/// Hook that displays tool execution results -struct ToolResultPrinterHook { +/// Policy that displays tool execution results. +struct ToolResultPrinterPolicy { call_names: Arc>>, } -impl ToolResultPrinterHook { +impl ToolResultPrinterPolicy { fn new(call_names: Arc>>) -> Self { Self { call_names } } } #[async_trait] -impl Hook for ToolResultPrinterHook { - async fn call(&self, ctx: &mut PostToolCallContext) -> Result { +impl Interceptor for ToolResultPrinterPolicy { + async fn post_tool_call(&self, info: &mut ToolResultInfo) -> PostToolAction { let name = self .call_names .lock() .unwrap() - .remove(&ctx.result.tool_use_id) - .unwrap_or_else(|| ctx.result.tool_use_id.clone()); + .remove(&info.result.tool_use_id) + .unwrap_or_else(|| info.result.tool_use_id.clone()); - if ctx.result.is_error { - println!(" Result ({}): ❌ {}", name, ctx.result.content); + if info.result.is_error { + println!(" Result ({}): ❌ {}", name, info.result.content); } else { - println!(" Result ({}): ✅ {}", name, ctx.result.content); + println!(" Result ({}): ✅ {}", name, info.result.content); } - Ok(PostToolCallResult::Continue) + PostToolAction::Continue } } @@ -450,7 +450,7 @@ async fn main() -> Result<(), Box> { .on_text_block(StreamingPrinter::new()) .on_tool_use_block(ToolCallPrinter::new(tool_call_names.clone())); - worker.add_post_tool_call_hook(ToolResultPrinterHook::new(tool_call_names)); + worker.set_interceptor(ToolResultPrinterPolicy::new(tool_call_names)); // One-shot mode if let Some(prompt) = args.prompt { diff --git a/crates/llm-worker/src/callback.rs b/crates/llm-worker/src/callback.rs index a66de95f..7d431bbd 100644 --- a/crates/llm-worker/src/callback.rs +++ b/crates/llm-worker/src/callback.rs @@ -10,7 +10,7 @@ use crate::handler::{ Handler, Kind, TextBlockEvent, TextBlockKind, ToolUseBlockEvent, ToolUseBlockKind, ToolUseBlockStart, }; -use crate::hook::ToolCall; +use crate::tool::ToolCall; // ============================================================================= // TextBlock Closure Handler diff --git a/crates/llm-worker/src/hook.rs b/crates/llm-worker/src/hook.rs deleted file mode 100644 index 1a8dc05f..00000000 --- a/crates/llm-worker/src/hook.rs +++ /dev/null @@ -1,310 +0,0 @@ -//! Hook-related type definitions -//! -//! Types used for turn control and intervention in the Worker layer - -use async_trait::async_trait; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use thiserror::Error; - -// ============================================================================= -// Hook Event Kinds -// ============================================================================= - -pub trait HookEventKind: Send + Sync + 'static { - type Input; - type Output; -} - -pub struct OnPromptSubmit; -pub struct PreLlmRequest; -pub struct PreToolCall; -pub struct PostToolCall; -pub struct OnTurnEnd; -pub struct OnAbort; -pub struct OnTextDelta; -pub struct OnToolCallDelta; -pub struct OnStreamChunk; -pub struct OnStreamComplete; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum OnPromptSubmitResult { - Continue, - Cancel(String), -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum PreLlmRequestResult { - Continue, - Cancel(String), -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum PreToolCallResult { - Continue, - Skip, - Abort(String), - Pause, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum PostToolCallResult { - Continue, - Abort(String), -} - -#[derive(Debug, Clone)] -pub enum OnTurnEndResult { - Finish, - ContinueWithMessages(Vec), - Paused, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum StreamHookResult { - Continue, - Abort(String), - Pause, -} - -use std::sync::Arc; - -use crate::tool::{Tool, ToolMeta}; - -/// Input context for PreToolCall -pub struct ToolCallContext { - /// Tool call information (modifiable) - pub call: ToolCall, - /// Tool meta information (immutable) - pub meta: ToolMeta, - /// Tool instance (for state access) - pub tool: Arc, -} - -/// Input context for PostToolCall -pub struct PostToolCallContext { - /// Tool call information - pub call: ToolCall, - /// Tool execution result (modifiable) - pub result: ToolResult, - /// Tool meta information (immutable) - pub meta: ToolMeta, - /// Tool instance (for state access) - pub tool: Arc, -} - -/// Input context for OnTextDelta -#[derive(Debug, Clone)] -pub struct TextDeltaContext { - /// Block index - pub index: usize, - /// Text delta content - pub delta: String, -} - -/// Input context for OnToolCallDelta -#[derive(Debug, Clone)] -pub struct ToolCallDeltaContext { - /// Block index - pub index: usize, - /// Partial JSON fragment - pub delta_json_fragment: String, -} - -/// Input context for OnStreamChunk -#[derive(Debug, Clone)] -pub struct StreamChunkContext { - /// Public worker-level event - pub event: crate::event::Event, -} - -/// Input context for OnStreamComplete -#[derive(Debug, Clone)] -pub struct StreamCompleteContext { - /// Current turn number - pub turn: usize, - /// Number of streamed events in this request - pub event_count: usize, -} - -impl HookEventKind for OnPromptSubmit { - type Input = crate::Item; - type Output = OnPromptSubmitResult; -} - -impl HookEventKind for PreLlmRequest { - type Input = Vec; - type Output = PreLlmRequestResult; -} - -impl HookEventKind for PreToolCall { - type Input = ToolCallContext; - type Output = PreToolCallResult; -} - -impl HookEventKind for PostToolCall { - type Input = PostToolCallContext; - type Output = PostToolCallResult; -} - -impl HookEventKind for OnTurnEnd { - type Input = Vec; - type Output = OnTurnEndResult; -} - -impl HookEventKind for OnAbort { - type Input = String; - type Output = (); -} - -impl HookEventKind for OnTextDelta { - type Input = TextDeltaContext; - type Output = StreamHookResult; -} - -impl HookEventKind for OnToolCallDelta { - type Input = ToolCallDeltaContext; - type Output = StreamHookResult; -} - -impl HookEventKind for OnStreamChunk { - type Input = StreamChunkContext; - type Output = StreamHookResult; -} - -impl HookEventKind for OnStreamComplete { - type Input = StreamCompleteContext; - type Output = StreamHookResult; -} - -// ============================================================================= -// Tool Call / Result Types -// ============================================================================= - -/// Tool call information -/// -/// Represents a ToolUse block from LLM, modifiable in Hook processing -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolCall { - /// Tool call ID (used for linking with response) - pub id: String, - /// Tool name - pub name: String, - /// Input arguments (JSON) - pub input: Value, -} - -/// Tool execution result -/// -/// Represents the result after tool execution, modifiable in Hook processing -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolResult { - /// Corresponding tool call ID - pub tool_use_id: String, - /// Result content - pub content: String, - /// Whether this is an error - #[serde(default)] - pub is_error: bool, -} - -impl ToolResult { - /// Create a success result - pub fn success(tool_use_id: impl Into, content: impl Into) -> Self { - Self { - tool_use_id: tool_use_id.into(), - content: content.into(), - is_error: false, - } - } - - /// Create an error result - pub fn error(tool_use_id: impl Into, content: impl Into) -> Self { - Self { - tool_use_id: tool_use_id.into(), - content: content.into(), - is_error: true, - } - } -} - -// ============================================================================= -// Hook Error -// ============================================================================= - -/// Hook error -#[derive(Debug, Error)] -pub enum HookError { - /// Processing was aborted - #[error("Aborted: {0}")] - Aborted(String), - /// Internal error - #[error("Hook error: {0}")] - Internal(String), -} - -// ============================================================================= -// Hook Trait -// ============================================================================= - -/// Trait for handling Hook events -/// -/// Each event type has a different return type, constrained via `HookEventKind`. -#[async_trait] -pub trait Hook: Send + Sync { - async fn call(&self, input: &mut E::Input) -> Result; -} - -// ============================================================================= -// Hook Registry -// ============================================================================= - -/// Registry holding all Hooks -/// -/// Used internally by Worker to manage all Hook types. -pub struct HookRegistry { - /// on_prompt_submit Hook - pub(crate) on_prompt_submit: Vec>>, - /// pre_llm_request Hook - pub(crate) pre_llm_request: Vec>>, - /// pre_tool_call Hook - pub(crate) pre_tool_call: Vec>>, - /// post_tool_call Hook - pub(crate) post_tool_call: Vec>>, - /// on_turn_end Hook - pub(crate) on_turn_end: Vec>>, - /// on_abort Hook - pub(crate) on_abort: Vec>>, - /// on_text_delta Hook - pub(crate) on_text_delta: Vec>>, - /// on_tool_call_delta Hook - pub(crate) on_tool_call_delta: Vec>>, - /// on_stream_chunk Hook - pub(crate) on_stream_chunk: Vec>>, - /// on_stream_complete Hook - pub(crate) on_stream_complete: Vec>>, -} - -impl Default for HookRegistry { - fn default() -> Self { - Self::new() - } -} - -impl HookRegistry { - /// Create an empty HookRegistry - pub fn new() -> Self { - Self { - on_prompt_submit: Vec::new(), - pre_llm_request: Vec::new(), - pre_tool_call: Vec::new(), - post_tool_call: Vec::new(), - on_turn_end: Vec::new(), - on_abort: Vec::new(), - on_text_delta: Vec::new(), - on_tool_call_delta: Vec::new(), - on_stream_chunk: Vec::new(), - on_stream_complete: Vec::new(), - } - } -} diff --git a/crates/llm-worker/src/interceptor.rs b/crates/llm-worker/src/interceptor.rs new file mode 100644 index 00000000..725fd93d --- /dev/null +++ b/crates/llm-worker/src/interceptor.rs @@ -0,0 +1,142 @@ +//! Interceptor - control flow delegation for the Worker execution loop +//! +//! Defines the [`Interceptor`] trait that upper layers (e.g. Pod) implement +//! to inject orchestration decisions (approval, skip, pause, abort) +//! into the Worker's turn loop without the Worker knowing about +//! higher-level concepts. + +use std::sync::Arc; + +use async_trait::async_trait; + +use crate::tool::{Tool, ToolCall, ToolMeta, ToolResult}; +use crate::Item; + +// ============================================================================= +// Action Enums +// ============================================================================= + +/// Action after prompt submission. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PromptAction { + /// Proceed normally. + Continue, + /// Cancel with a reason. + Cancel(String), +} + +/// Action before an LLM request. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PreRequestAction { + /// Proceed normally. + Continue, + /// Cancel with a reason. + Cancel(String), +} + +/// Action before a tool call. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PreToolAction { + /// Proceed with execution. + Continue, + /// Skip this tool call (do not execute). + Skip, + /// Abort the entire run. + Abort(String), + /// Pause execution (can be resumed later). + Pause, +} + +/// Action after a tool call. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PostToolAction { + /// Proceed normally. + Continue, + /// Abort the entire run. + Abort(String), +} + +/// Action at the end of a turn (when LLM produces no tool calls). +#[derive(Debug, Clone)] +pub enum TurnEndAction { + /// Turn is finished, return to caller. + Finish, + /// Continue with additional messages injected into history. + ContinueWithMessages(Vec), + /// Pause execution (can be resumed later). + Pause, +} + +// ============================================================================= +// Context Types +// ============================================================================= + +/// Context for pre-tool-call decisions. +pub struct ToolCallInfo { + /// Tool call information (modifiable). + pub call: ToolCall, + /// Tool meta information. + pub meta: ToolMeta, + /// Tool instance (for state access). + pub tool: Arc, +} + +/// Context for post-tool-call decisions. +pub struct ToolResultInfo { + /// Original tool call. + pub call: ToolCall, + /// Tool execution result (modifiable). + pub result: ToolResult, + /// Tool meta information. + pub meta: ToolMeta, + /// Tool instance (for state access). + pub tool: Arc, +} + +// ============================================================================= +// Interceptor Trait +// ============================================================================= + +/// Intercepts the Worker execution loop at key decision points. +/// +/// All methods have default implementations that let the Worker +/// proceed without intervention. Upper layers (e.g. Pod) provide +/// richer implementations for approval flows, permission checks, etc. +#[async_trait] +pub trait Interceptor: Send + Sync { + /// Called after receiving user input, before adding to history. + async fn on_prompt_submit(&self, _item: &mut Item) -> PromptAction { + PromptAction::Continue + } + + /// Called before each LLM request. The context can be modified + /// (e.g. for context compaction). + async fn pre_llm_request(&self, _context: &mut Vec) -> PreRequestAction { + PreRequestAction::Continue + } + + /// Called before each tool is executed. + async fn pre_tool_call(&self, _info: &mut ToolCallInfo) -> PreToolAction { + PreToolAction::Continue + } + + /// Called after each tool completes. + async fn post_tool_call(&self, _info: &mut ToolResultInfo) -> PostToolAction { + PostToolAction::Continue + } + + /// Called when a turn ends with no tool calls. + async fn on_turn_end(&self, _history: &[Item]) -> TurnEndAction { + TurnEndAction::Finish + } + + /// Called when execution is interrupted (abort or cancel). + async fn on_abort(&self, _reason: &str) {} +} + +/// Default interceptor: no intervention. Worker proceeds through the loop +/// without any external control flow decisions. +pub(crate) struct DefaultInterceptor; + +#[async_trait] +impl Interceptor for DefaultInterceptor {} diff --git a/crates/llm-worker/src/lib.rs b/crates/llm-worker/src/lib.rs index 0f0fb052..771da62b 100644 --- a/crates/llm-worker/src/lib.rs +++ b/crates/llm-worker/src/lib.rs @@ -6,7 +6,7 @@ //! //! - [`Worker`] - Central component for managing LLM interactions //! - [`tool::Tool`] - Tools that can be invoked by the LLM -//! - [`hook::Hook`] - Hooks for intercepting turn progression +//! - [`interceptor::Interceptor`] - Control-flow delegation for the execution loop //! - Closure-based event callbacks via `Worker::on_text_block()`, `on_tool_use_block()`, etc. //! //! # Quick Start @@ -41,8 +41,8 @@ mod worker; pub(crate) mod callback; pub mod event; -pub mod hook; pub mod llm_client; +pub mod interceptor; pub mod state; pub mod timeline; pub mod tool; @@ -51,4 +51,6 @@ pub mod tool_server; pub use callback::{TextBlockScope, ToolUseBlockScope}; pub use handler::ToolUseBlockStart; pub use message::{ContentPart, Item, Message, Role}; +pub use interceptor::Interceptor; +pub use tool::{ToolCall, ToolResult}; pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult}; diff --git a/crates/llm-worker/src/timeline/tool_call_collector.rs b/crates/llm-worker/src/timeline/tool_call_collector.rs index 853ec5e7..e81a6b6f 100644 --- a/crates/llm-worker/src/timeline/tool_call_collector.rs +++ b/crates/llm-worker/src/timeline/tool_call_collector.rs @@ -5,7 +5,7 @@ use crate::{ handler::{Handler, ToolUseBlockEvent, ToolUseBlockKind}, - hook::ToolCall, + tool::ToolCall, }; use std::sync::{Arc, Mutex}; diff --git a/crates/llm-worker/src/tool.rs b/crates/llm-worker/src/tool.rs index ec1c5d44..622470f4 100644 --- a/crates/llm-worker/src/tool.rs +++ b/crates/llm-worker/src/tool.rs @@ -370,3 +370,54 @@ pub trait ToolOutputProcessor: Send + Sync { /// For large outputs, this should be a summary with a blob reference. async fn process(&self, output: String) -> Result; } + +// ============================================================================= +// Tool Call / Result Types +// ============================================================================= + +/// Tool call information +/// +/// Represents a ToolUse block from LLM. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + /// Tool call ID (used for linking with response) + pub id: String, + /// Tool name + pub name: String, + /// Input arguments (JSON) + pub input: Value, +} + +/// Tool execution result +/// +/// Represents the result after tool execution. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolResult { + /// Corresponding tool call ID + pub tool_use_id: String, + /// Result content + pub content: String, + /// Whether this is an error + #[serde(default)] + pub is_error: bool, +} + +impl ToolResult { + /// Create a success result + pub fn success(tool_use_id: impl Into, content: impl Into) -> Self { + Self { + tool_use_id: tool_use_id.into(), + content: content.into(), + is_error: false, + } + } + + /// Create an error result + pub fn error(tool_use_id: impl Into, content: impl Into) -> Self { + Self { + tool_use_id: tool_use_id.into(), + content: content.into(), + is_error: true, + } + } +} diff --git a/crates/llm-worker/src/worker.rs b/crates/llm-worker/src/worker.rs index c895e918..79f232cc 100644 --- a/crates/llm-worker/src/worker.rs +++ b/crates/llm-worker/src/worker.rs @@ -8,15 +8,11 @@ use tracing::{debug, info, trace, warn}; use crate::{ Item, - hook::{ - Hook, HookError, HookRegistry, OnAbort, OnPromptSubmit, OnPromptSubmitResult, - OnStreamChunk, OnStreamComplete, OnTextDelta, OnToolCallDelta, OnTurnEnd, OnTurnEndResult, - PostToolCall, PostToolCallContext, PostToolCallResult, PreLlmRequest, PreLlmRequestResult, - PreToolCall, PreToolCallResult, StreamChunkContext, StreamCompleteContext, - StreamHookResult, TextDeltaContext, ToolCall, ToolCallContext, ToolCallDeltaContext, - ToolResult, - }, llm_client::{ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ToolDefinition}, + interceptor::{ + DefaultInterceptor, Interceptor, PostToolAction, PreRequestAction, PreToolAction, + PromptAction, ToolCallInfo, ToolResultInfo, TurnEndAction, + }, state::{CacheLocked, Mutable, WorkerState}, callback::{ ClosureMetaHandler, ClosureTextBlockHandler, ClosureToolUseBlockHandler, TextBlockScope, @@ -25,7 +21,7 @@ use crate::{ handler::{ErrorKind, StatusKind, ToolUseBlockStart, UsageKind}, timeline::{TextBlockCollector, Timeline, ToolCallCollector}, timeline::event::{ErrorEvent, StatusEvent, UsageEvent}, - tool::{ToolDefinition as WorkerToolDefinition, ToolError, ToolOutputProcessor}, + tool::{ToolCall, ToolDefinition as WorkerToolDefinition, ToolError, ToolOutputProcessor, ToolResult}, tool_server::{ToolServer, ToolServerError, ToolServerHandle}, }; @@ -42,9 +38,6 @@ pub enum WorkerError { /// Tool error #[error("Tool error: {0}")] Tool(#[from] ToolError), - /// Hook error - #[error("Hook error: {0}")] - Hook(#[from] HookError), /// Execution was aborted #[error("Aborted: {0}")] Aborted(String), @@ -145,8 +138,8 @@ pub struct Worker { tool_call_collector: ToolCallCollector, /// Tool server handle tool_server: ToolServerHandle, - /// Hook registry - hooks: HookRegistry, + /// Interceptor for control-flow decisions + interceptor: Box, /// System prompt system_prompt: Option, /// Item history (owned by Worker) @@ -192,21 +185,16 @@ impl Worker { user_input: impl Into, ) -> Result { self.reset_interruption_state(); - // Hook: on_prompt_submit + // Interceptor: on_prompt_submit let mut user_item = Item::user_message(user_input); - let result = self.run_on_prompt_submit_hooks(&mut user_item).await; - let result = match result { - Ok(value) => value, - Err(err) => return self.finalize_interruption(Err(err)).await, - }; - match result { - OnPromptSubmitResult::Cancel(reason) => { + match self.interceptor.on_prompt_submit(&mut user_item).await { + PromptAction::Cancel(reason) => { self.last_run_interrupted = true; return self .finalize_interruption(Err(WorkerError::Aborted(reason))) .await; } - OnPromptSubmitResult::Continue => {} + PromptAction::Continue => {} } self.history.push(user_item); let result = self.run_turn_loop().await; @@ -335,58 +323,13 @@ impl Worker { self.tool_server.clone() } - /// Add an on_prompt_submit Hook + /// Set the interceptor for control-flow decisions. /// - /// Called immediately after receiving a user message in `run()`. - pub fn add_on_prompt_submit_hook(&mut self, hook: impl Hook + 'static) { - self.hooks.on_prompt_submit.push(Box::new(hook)); - } - - /// Add a pre_llm_request Hook - /// - /// Called before sending an LLM request for each turn. - pub fn add_pre_llm_request_hook(&mut self, hook: impl Hook + 'static) { - self.hooks.pre_llm_request.push(Box::new(hook)); - } - - /// Add a pre_tool_call Hook - pub fn add_pre_tool_call_hook(&mut self, hook: impl Hook + 'static) { - self.hooks.pre_tool_call.push(Box::new(hook)); - } - - /// Add a post_tool_call Hook - pub fn add_post_tool_call_hook(&mut self, hook: impl Hook + 'static) { - self.hooks.post_tool_call.push(Box::new(hook)); - } - - /// Add an on_turn_end Hook - pub fn add_on_turn_end_hook(&mut self, hook: impl Hook + 'static) { - self.hooks.on_turn_end.push(Box::new(hook)); - } - - /// Add an on_abort Hook - pub fn add_on_abort_hook(&mut self, hook: impl Hook + 'static) { - self.hooks.on_abort.push(Box::new(hook)); - } - - /// Add an on_text_delta Hook - pub fn add_on_text_delta_hook(&mut self, hook: impl Hook + 'static) { - self.hooks.on_text_delta.push(Box::new(hook)); - } - - /// Add an on_tool_call_delta Hook - pub fn add_on_tool_call_delta_hook(&mut self, hook: impl Hook + 'static) { - self.hooks.on_tool_call_delta.push(Box::new(hook)); - } - - /// Add an on_stream_chunk Hook - pub fn add_on_stream_chunk_hook(&mut self, hook: impl Hook + 'static) { - self.hooks.on_stream_chunk.push(Box::new(hook)); - } - - /// Add an on_stream_complete Hook - pub fn add_on_stream_complete_hook(&mut self, hook: impl Hook + 'static) { - self.hooks.on_stream_complete.push(Box::new(hook)); + /// The interceptor governs approval, skip, pause, and abort decisions + /// at key points in the execution loop. If not set, the default + /// interceptor is used (all Continue / Finish). + pub fn set_interceptor(&mut self, interceptor: impl Interceptor + 'static) { + self.interceptor = Box::new(interceptor); } /// Get a mutable reference to the timeline (for additional handler registration) @@ -578,131 +521,6 @@ impl Worker { /// Hooks: on_prompt_submit /// - /// Called immediately after receiving a user message in `run()` (first time only). - async fn run_on_prompt_submit_hooks( - &self, - item: &mut Item, - ) -> Result { - for hook in &self.hooks.on_prompt_submit { - let result = hook.call(item).await?; - match result { - OnPromptSubmitResult::Continue => continue, - OnPromptSubmitResult::Cancel(reason) => { - return Ok(OnPromptSubmitResult::Cancel(reason)); - } - } - } - Ok(OnPromptSubmitResult::Continue) - } - - /// Hooks: pre_llm_request - /// - /// Called before sending an LLM request for each turn. - async fn run_pre_llm_request_hooks( - &self, - ) -> Result<(PreLlmRequestResult, Vec), WorkerError> { - let mut temp_context = self.history.clone(); - for hook in &self.hooks.pre_llm_request { - let result = hook.call(&mut temp_context).await?; - match result { - PreLlmRequestResult::Continue => continue, - PreLlmRequestResult::Cancel(reason) => { - return Ok((PreLlmRequestResult::Cancel(reason), temp_context)); - } - } - } - Ok((PreLlmRequestResult::Continue, temp_context)) - } - - /// Hooks: on_turn_end - async fn run_on_turn_end_hooks(&self) -> Result { - let mut temp_items = self.history.clone(); - for hook in &self.hooks.on_turn_end { - let result = hook.call(&mut temp_items).await?; - match result { - OnTurnEndResult::Finish => continue, - OnTurnEndResult::ContinueWithMessages(items) => { - return Ok(OnTurnEndResult::ContinueWithMessages(items)); - } - OnTurnEndResult::Paused => return Ok(OnTurnEndResult::Paused), - } - } - Ok(OnTurnEndResult::Finish) - } - - /// Hooks: on_abort - async fn run_on_abort_hooks(&self, reason: &str) -> Result<(), WorkerError> { - let mut reason = reason.to_string(); - for hook in &self.hooks.on_abort { - hook.call(&mut reason).await?; - } - Ok(()) - } - - fn apply_stream_hook_result(result: StreamHookResult) -> Result<(), WorkerError> { - match result { - StreamHookResult::Continue => Ok(()), - StreamHookResult::Abort(reason) => Err(WorkerError::Aborted(reason)), - StreamHookResult::Pause => { - Err(WorkerError::Aborted("Paused by stream hook".to_string())) - } - } - } - - async fn run_on_stream_chunk_hooks( - &self, - event: crate::event::Event, - ) -> Result<(), WorkerError> { - let mut context = StreamChunkContext { event }; - for hook in &self.hooks.on_stream_chunk { - let result = hook.call(&mut context).await?; - Self::apply_stream_hook_result(result)?; - } - Ok(()) - } - - async fn run_on_text_delta_hooks( - &self, - index: usize, - delta: String, - ) -> Result<(), WorkerError> { - let mut context = TextDeltaContext { index, delta }; - for hook in &self.hooks.on_text_delta { - let result = hook.call(&mut context).await?; - Self::apply_stream_hook_result(result)?; - } - Ok(()) - } - - async fn run_on_tool_call_delta_hooks( - &self, - index: usize, - delta_json_fragment: String, - ) -> Result<(), WorkerError> { - let mut context = ToolCallDeltaContext { - index, - delta_json_fragment, - }; - for hook in &self.hooks.on_tool_call_delta { - let result = hook.call(&mut context).await?; - Self::apply_stream_hook_result(result)?; - } - Ok(()) - } - - async fn run_on_stream_complete_hooks( - &self, - turn: usize, - event_count: usize, - ) -> Result<(), WorkerError> { - let mut context = StreamCompleteContext { turn, event_count }; - for hook in &self.hooks.on_stream_complete { - let result = hook.call(&mut context).await?; - Self::apply_stream_hook_result(result)?; - } - Ok(()) - } - async fn finalize_interruption( &mut self, result: Result, @@ -716,10 +534,7 @@ impl Worker { WorkerError::Cancelled => "Cancelled".to_string(), _ => err.to_string(), }; - if let Err(hook_err) = self.run_on_abort_hooks(&reason).await { - self.last_run_interrupted = true; - return Err(hook_err); - } + self.interceptor.on_abort(&reason).await; Err(err) } } @@ -780,59 +595,41 @@ impl Worker { // Retained because it's needed for PostToolCall hooks let mut call_info_map = HashMap::new(); - // Phase 1: Apply pre_tool_call hooks (determine skip/abort) + // Phase 1: Apply pre_tool_call interceptor (determine skip/abort) let mut approved_calls = Vec::new(); for mut tool_call in tool_calls { - // Get tool definition if let Some((meta, tool)) = self.tool_server.get_tool(&tool_call.name) { - // Create context - let mut context = ToolCallContext { + let mut info = ToolCallInfo { call: tool_call.clone(), meta, tool, }; - let mut skip = false; - for hook in &self.hooks.pre_tool_call { - let result = hook - .call(&mut context) - .await - .inspect_err(|_| self.last_run_interrupted = true)?; - match result { - PreToolCallResult::Continue => {} - PreToolCallResult::Skip => { - skip = true; - break; - } - PreToolCallResult::Abort(reason) => { - self.last_run_interrupted = true; - return Err(WorkerError::Aborted(reason)); - } - PreToolCallResult::Pause => { - self.last_run_interrupted = true; - return Ok(ToolExecutionResult::Paused); - } + match self.interceptor.pre_tool_call(&mut info).await { + PreToolAction::Continue => {} + PreToolAction::Skip => { + continue; + } + PreToolAction::Abort(reason) => { + self.last_run_interrupted = true; + return Err(WorkerError::Aborted(reason)); + } + PreToolAction::Pause => { + self.last_run_interrupted = true; + return Ok(ToolExecutionResult::Paused); } } - // Reflect changes made by hooks - tool_call = context.call; + // Reflect changes made by interceptor + tool_call = info.call; - // Save to map (only if executing) - if !skip { - call_info_map.insert( - tool_call.id.clone(), - ( - tool_call.clone(), - context.meta.clone(), - context.tool.clone(), - ), - ); - approved_calls.push(tool_call); - } + call_info_map.insert( + tool_call.id.clone(), + (tool_call.clone(), info.meta.clone(), info.tool.clone()), + ); + approved_calls.push(tool_call); } else { // Unknown tools go into approved list as-is (will error at execution) - // Hooks are not applied (no Meta available) approved_calls.push(tool_call); } } @@ -879,32 +676,25 @@ impl Worker { } } - // Phase 3: Apply post_tool_call hooks + // Phase 3: Apply post_tool_call interceptor for tool_result in &mut results { - // Get saved information if let Some((tool_call, meta, tool)) = call_info_map.get(&tool_result.tool_use_id) { - let mut context = PostToolCallContext { + let mut info = ToolResultInfo { call: tool_call.clone(), result: tool_result.clone(), meta: meta.clone(), tool: tool.clone(), }; - for hook in &self.hooks.post_tool_call { - let result = hook - .call(&mut context) - .await - .inspect_err(|_| self.last_run_interrupted = true)?; - match result { - PostToolCallResult::Continue => {} - PostToolCallResult::Abort(reason) => { - self.last_run_interrupted = true; - return Err(WorkerError::Aborted(reason)); - } + match self.interceptor.post_tool_call(&mut info).await { + PostToolAction::Continue => {} + PostToolAction::Abort(reason) => { + self.last_run_interrupted = true; + return Err(WorkerError::Aborted(reason)); } } - // Reflect hook-modified results - *tool_result = context.result; + // Reflect interceptor-modified results + *tool_result = info.result; } } @@ -963,21 +753,18 @@ impl Worker { cb(current_turn); } - // Hook: pre_llm_request - let (control, request_context) = self - .run_pre_llm_request_hooks() - .await - .inspect_err(|_| self.last_run_interrupted = true)?; - match control { - PreLlmRequestResult::Cancel(reason) => { - info!(reason = %reason, "Aborted by hook"); + // Interceptor: pre_llm_request + let mut request_context = self.history.clone(); + match self.interceptor.pre_llm_request(&mut request_context).await { + PreRequestAction::Cancel(reason) => { + info!(reason = %reason, "Aborted by interceptor"); for cb in &self.turn_end_cbs { cb(current_turn); } self.last_run_interrupted = true; return Err(WorkerError::Aborted(reason)); } - PreLlmRequestResult::Continue => {} + PreRequestAction::Continue => {} } // Build request @@ -1025,26 +812,6 @@ impl Worker { let event = result .inspect_err(|_| self.last_run_interrupted = true)?; self.timeline.dispatch(&event); - - self.run_on_stream_chunk_hooks(event.clone()) - .await - .inspect_err(|_| self.last_run_interrupted = true)?; - - if let crate::llm_client::event::Event::BlockDelta(delta) = &event { - match &delta.delta { - crate::llm_client::event::DeltaContent::Text(text) => { - self.run_on_text_delta_hooks(delta.index, text.clone()) - .await - .inspect_err(|_| self.last_run_interrupted = true)?; - } - crate::llm_client::event::DeltaContent::InputJson(json_fragment) => { - self.run_on_tool_call_delta_hooks(delta.index, json_fragment.clone()) - .await - .inspect_err(|_| self.last_run_interrupted = true)?; - } - crate::llm_client::event::DeltaContent::Thinking(_) => {} - } - } } None => break, // Stream ended } @@ -1060,9 +827,6 @@ impl Worker { } } } - self.run_on_stream_complete_hooks(current_turn, event_count) - .await - .inspect_err(|_| self.last_run_interrupted = true)?; debug!(event_count = event_count, "Stream completed"); // Notify turn end @@ -1080,21 +844,17 @@ impl Worker { self.history.extend(assistant_items); if tool_calls.is_empty() { - // No tool calls → determine turn end - let turn_result = self - .run_on_turn_end_hooks() - .await - .inspect_err(|_| self.last_run_interrupted = true)?; - match turn_result { - OnTurnEndResult::Finish => { + // No tool calls → determine turn end via interceptor + match self.interceptor.on_turn_end(&self.history).await { + TurnEndAction::Finish => { self.last_run_interrupted = false; return Ok(WorkerResult::Finished); } - OnTurnEndResult::ContinueWithMessages(additional) => { + TurnEndAction::ContinueWithMessages(additional) => { self.history.extend(additional); continue; } - OnTurnEndResult::Paused => { + TurnEndAction::Pause => { self.last_run_interrupted = true; return Ok(WorkerResult::Paused); } @@ -1164,7 +924,7 @@ impl Worker { text_block_collector, tool_call_collector, tool_server: ToolServer::new().handle(), - hooks: HookRegistry::new(), + interceptor: Box::new(DefaultInterceptor), system_prompt: None, history: Vec::new(), locked_prefix_len: 0, @@ -1400,7 +1160,7 @@ impl Worker { text_block_collector: self.text_block_collector, tool_call_collector: self.tool_call_collector, tool_server: self.tool_server, - hooks: self.hooks, + interceptor: self.interceptor, system_prompt: self.system_prompt, history: self.history, locked_prefix_len, @@ -1439,7 +1199,7 @@ impl Worker { text_block_collector: self.text_block_collector, tool_call_collector: self.tool_call_collector, tool_server: self.tool_server, - hooks: self.hooks, + interceptor: self.interceptor, system_prompt: self.system_prompt, history: self.history, locked_prefix_len: 0, diff --git a/crates/llm-worker/tests/parallel_execution_test.rs b/crates/llm-worker/tests/parallel_execution_test.rs index ccb9872c..65bff220 100644 --- a/crates/llm-worker/tests/parallel_execution_test.rs +++ b/crates/llm-worker/tests/parallel_execution_test.rs @@ -8,11 +8,8 @@ use std::time::{Duration, Instant}; use async_trait::async_trait; use llm_worker::Worker; -use llm_worker::hook::{ - Hook, HookError, PostToolCall, PostToolCallContext, PostToolCallResult, PreToolCall, - PreToolCallResult, ToolCallContext, -}; use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent}; +use llm_worker::interceptor::{Interceptor, PostToolAction, PreToolAction, ToolCallInfo, ToolResultInfo}; use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta}; mod common; @@ -156,21 +153,21 @@ async fn test_before_tool_call_skip() { worker.register_tool(allowed_tool.definition()).unwrap(); worker.register_tool(blocked_tool.definition()).unwrap(); - // Hook to skip "blocked_tool" - struct BlockingHook; + // Policy to skip "blocked_tool" + struct BlockingPolicy; #[async_trait] - impl Hook for BlockingHook { - async fn call(&self, ctx: &mut ToolCallContext) -> Result { - if ctx.call.name == "blocked_tool" { - Ok(PreToolCallResult::Skip) + impl Interceptor for BlockingPolicy { + async fn pre_tool_call(&self, info: &mut ToolCallInfo) -> PreToolAction { + if info.call.name == "blocked_tool" { + PreToolAction::Skip } else { - Ok(PreToolCallResult::Continue) + PreToolAction::Continue } } } - worker.add_pre_tool_call_hook(BlockingHook); + worker.set_interceptor(BlockingPolicy); let _result = worker.run("Test hook").await; @@ -235,25 +232,22 @@ async fn test_post_tool_call_modification() { worker.register_tool(simple_tool_definition()).unwrap(); - // Hook to modify results - struct ModifyingHook { + // Policy to modify results + struct ModifyingPolicy { modified_content: Arc>>, } #[async_trait] - impl Hook for ModifyingHook { - async fn call( - &self, - ctx: &mut PostToolCallContext, - ) -> Result { - ctx.result.content = format!("[Modified] {}", ctx.result.content); - *self.modified_content.lock().unwrap() = Some(ctx.result.content.clone()); - Ok(PostToolCallResult::Continue) + impl Interceptor for ModifyingPolicy { + async fn post_tool_call(&self, info: &mut ToolResultInfo) -> PostToolAction { + info.result.content = format!("[Modified] {}", info.result.content); + *self.modified_content.lock().unwrap() = Some(info.result.content.clone()); + PostToolAction::Continue } } let modified_content = Arc::new(std::sync::Mutex::new(None)); - worker.add_post_tool_call_hook(ModifyingHook { + worker.set_interceptor(ModifyingPolicy { modified_content: modified_content.clone(), }); diff --git a/crates/llm-worker/tests/streaming_hook_test.rs b/crates/llm-worker/tests/streaming_hook_test.rs deleted file mode 100644 index dad4632a..00000000 --- a/crates/llm-worker/tests/streaming_hook_test.rs +++ /dev/null @@ -1,194 +0,0 @@ -//! Streaming hook tests - -mod common; - -use std::sync::{Arc, Mutex}; - -use async_trait::async_trait; -use common::MockLlmClient; -use llm_worker::hook::{ - Hook, HookError, OnStreamChunk, OnStreamComplete, OnTextDelta, OnToolCallDelta, - StreamChunkContext, StreamCompleteContext, StreamHookResult, TextDeltaContext, - ToolCallDeltaContext, -}; -use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent}; -use llm_worker::{Worker, WorkerError}; - -#[tokio::test] -async fn test_text_delta_hooks_run_in_registration_order() { - let events = vec![ - Event::text_block_start(0), - Event::text_delta(0, "A"), - Event::text_delta(0, "B"), - Event::text_block_stop(0, None), - Event::Status(StatusEvent { - status: ResponseStatus::Completed, - }), - ]; - - let client = MockLlmClient::new(events); - let mut worker = Worker::new(client); - - struct RecorderHook { - label: &'static str, - records: Arc>>, - } - - #[async_trait] - impl Hook for RecorderHook { - async fn call(&self, input: &mut TextDeltaContext) -> Result { - self.records - .lock() - .unwrap() - .push(format!("{}:{}", self.label, input.delta)); - Ok(StreamHookResult::Continue) - } - } - - let records = Arc::new(Mutex::new(Vec::new())); - worker.add_on_text_delta_hook(RecorderHook { - label: "first", - records: records.clone(), - }); - worker.add_on_text_delta_hook(RecorderHook { - label: "second", - records: records.clone(), - }); - - let result = worker.run("hello").await; - assert!(result.is_ok(), "run should succeed: {result:?}"); - - let got = records.lock().unwrap().clone(); - assert_eq!( - got, - vec![ - "first:A".to_string(), - "second:A".to_string(), - "first:B".to_string(), - "second:B".to_string(), - ] - ); -} - -#[tokio::test] -async fn test_stream_chunk_and_stream_complete_hooks_are_called() { - let events = vec![ - Event::ping(), - Event::text_block_start(0), - Event::text_delta(0, "hi"), - Event::text_block_stop(0, None), - Event::usage(10, 5), - Event::Status(StatusEvent { - status: ResponseStatus::Completed, - }), - ]; - - let client = MockLlmClient::new(events); - let mut worker = Worker::new(client); - - struct ChunkCounter(Arc>); - - #[async_trait] - impl Hook for ChunkCounter { - async fn call( - &self, - _input: &mut StreamChunkContext, - ) -> Result { - let mut guard = self.0.lock().unwrap(); - *guard += 1; - Ok(StreamHookResult::Continue) - } - } - - struct CompleteRecorder(Arc>>); - - #[async_trait] - impl Hook for CompleteRecorder { - async fn call( - &self, - input: &mut StreamCompleteContext, - ) -> Result { - self.0.lock().unwrap().push((input.turn, input.event_count)); - Ok(StreamHookResult::Continue) - } - } - - let chunk_count = Arc::new(Mutex::new(0usize)); - let completes = Arc::new(Mutex::new(Vec::new())); - - worker.add_on_stream_chunk_hook(ChunkCounter(chunk_count.clone())); - worker.add_on_stream_complete_hook(CompleteRecorder(completes.clone())); - - let result = worker.run("hello").await; - assert!(result.is_ok(), "run should succeed: {result:?}"); - - assert_eq!(*chunk_count.lock().unwrap(), 6); - assert_eq!(completes.lock().unwrap().as_slice(), &[(0usize, 6usize)]); -} - -#[tokio::test] -async fn test_tool_call_delta_hook_can_abort_run() { - let events = vec![ - Event::tool_use_start(0, "call_1", "unknown_tool"), - Event::tool_input_delta(0, r#"{"x":1}"#), - Event::tool_use_stop(0), - Event::Status(StatusEvent { - status: ResponseStatus::Completed, - }), - ]; - - let client = MockLlmClient::new(events); - let mut worker = Worker::new(client); - - struct AbortToolDelta; - - #[async_trait] - impl Hook for AbortToolDelta { - async fn call( - &self, - _input: &mut ToolCallDeltaContext, - ) -> Result { - Ok(StreamHookResult::Abort("blocked by tool delta".to_string())) - } - } - - worker.add_on_tool_call_delta_hook(AbortToolDelta); - - let result = worker.run("hello").await; - match result { - Err(WorkerError::Aborted(reason)) => assert_eq!(reason, "blocked by tool delta"), - other => panic!("expected aborted result, got: {other:?}"), - } -} - -#[tokio::test] -async fn test_stream_hook_pause_is_mapped_to_aborted() { - let events = vec![ - Event::text_block_start(0), - Event::text_delta(0, "pause me"), - Event::text_block_stop(0, None), - Event::Status(StatusEvent { - status: ResponseStatus::Completed, - }), - ]; - - let client = MockLlmClient::new(events); - let mut worker = Worker::new(client); - - struct PauseHook; - - #[async_trait] - impl Hook for PauseHook { - async fn call(&self, _input: &mut TextDeltaContext) -> Result { - Ok(StreamHookResult::Pause) - } - } - - worker.add_on_text_delta_hook(PauseHook); - - let result = worker.run("hello").await; - match result { - Err(WorkerError::Aborted(reason)) => assert_eq!(reason, "Paused by stream hook"), - other => panic!("expected aborted result, got: {other:?}"), - } -} diff --git a/crates/pod/Cargo.toml b/crates/pod/Cargo.toml index 598552e6..15c0e38a 100644 --- a/crates/pod/Cargo.toml +++ b/crates/pod/Cargo.toml @@ -5,6 +5,7 @@ edition.workspace = true license.workspace = true [dependencies] +async-trait = "0.1.89" clap = { version = "4.6.0", features = ["derive"] } llm-worker = { version = "0.2.1", path = "../llm-worker" } llm-worker-persistence = { version = "0.1.0", path = "../llm-worker-persistence" } diff --git a/crates/pod/src/hook.rs b/crates/pod/src/hook.rs new file mode 100644 index 00000000..e3163a6a --- /dev/null +++ b/crates/pod/src/hook.rs @@ -0,0 +1,160 @@ +//! Pod-layer hook infrastructure +//! +//! Provides the `Hook` trait and `HookRegistry` for orchestration hooks +//! that govern control-flow decisions in the Worker execution loop. +//! +//! The type system (`HookEventKind` / `Hook`) mirrors the pattern +//! originally in llm-worker, now at the insomnia layer where orchestration +//! concerns belong. + +use async_trait::async_trait; +use llm_worker::interceptor::{ + PostToolAction, PreRequestAction, PreToolAction, PromptAction, ToolCallInfo, ToolResultInfo, + TurnEndAction, +}; +use llm_worker::Item; + +// ============================================================================= +// Hook Event Kinds +// ============================================================================= + +/// Marker trait for hook event kinds. +/// +/// Each event kind specifies its input (passed mutably to hooks) and +/// output (the control-flow action returned by hooks). +pub trait HookEventKind: Send + Sync + 'static { + /// Mutable input passed to the hook. + type Input; + /// Control-flow action returned by the hook. + type Output; +} + +// --- Event kind markers --- + +/// After receiving user input, before adding to history. +pub struct OnPromptSubmit; +/// Before each LLM request. +pub struct PreLlmRequest; +/// Before each tool is executed. +pub struct PreToolCall; +/// After each tool completes. +pub struct PostToolCall; +/// When a turn ends with no tool calls. +pub struct OnTurnEnd; +/// When execution is interrupted. +pub struct OnAbort; + +impl HookEventKind for OnPromptSubmit { + type Input = Item; + type Output = PromptAction; +} + +impl HookEventKind for PreLlmRequest { + type Input = Vec; + type Output = PreRequestAction; +} + +impl HookEventKind for PreToolCall { + type Input = ToolCallInfo; + type Output = PreToolAction; +} + +impl HookEventKind for PostToolCall { + type Input = ToolResultInfo; + type Output = PostToolAction; +} + +impl HookEventKind for OnTurnEnd { + type Input = Vec; + type Output = TurnEndAction; +} + +impl HookEventKind for OnAbort { + type Input = String; + type Output = (); +} + +// ============================================================================= +// Hook Trait +// ============================================================================= + +/// Async hook for a specific event kind. +/// +/// Hooks receive mutable access to the event's input and return a +/// control-flow action. Multiple hooks can be registered per event; +/// they are evaluated in registration order and short-circuit on the +/// first non-Continue result. +#[async_trait] +pub trait Hook: Send + Sync { + async fn call(&self, input: &mut E::Input) -> E::Output; +} + +// ============================================================================= +// Hook Registry +// ============================================================================= + +/// Builder for constructing a frozen `HookRegistry`. +/// +/// Hooks are added during setup, then `build()` produces an immutable +/// registry that can be shared via `Arc`. +#[derive(Default)] +pub struct HookRegistryBuilder { + on_prompt_submit: Vec>>, + pre_llm_request: Vec>>, + pre_tool_call: Vec>>, + post_tool_call: Vec>>, + on_turn_end: Vec>>, + on_abort: Vec>>, +} + +impl HookRegistryBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn add_on_prompt_submit(&mut self, hook: impl Hook + 'static) { + self.on_prompt_submit.push(Box::new(hook)); + } + + pub fn add_pre_llm_request(&mut self, hook: impl Hook + 'static) { + self.pre_llm_request.push(Box::new(hook)); + } + + pub fn add_pre_tool_call(&mut self, hook: impl Hook + 'static) { + self.pre_tool_call.push(Box::new(hook)); + } + + pub fn add_post_tool_call(&mut self, hook: impl Hook + 'static) { + self.post_tool_call.push(Box::new(hook)); + } + + pub fn add_on_turn_end(&mut self, hook: impl Hook + 'static) { + self.on_turn_end.push(Box::new(hook)); + } + + pub fn add_on_abort(&mut self, hook: impl Hook + 'static) { + self.on_abort.push(Box::new(hook)); + } + + /// Freeze the builder into an immutable registry. + pub fn build(self) -> HookRegistry { + HookRegistry { + on_prompt_submit: self.on_prompt_submit, + pre_llm_request: self.pre_llm_request, + pre_tool_call: self.pre_tool_call, + post_tool_call: self.post_tool_call, + on_turn_end: self.on_turn_end, + on_abort: self.on_abort, + } + } +} + +/// Frozen registry of hooks. Constructed via [`HookRegistryBuilder::build()`]. +pub struct HookRegistry { + pub(crate) on_prompt_submit: Vec>>, + pub(crate) pre_llm_request: Vec>>, + pub(crate) pre_tool_call: Vec>>, + pub(crate) post_tool_call: Vec>>, + pub(crate) on_turn_end: Vec>>, + pub(crate) on_abort: Vec>>, +} diff --git a/crates/pod/src/hook_interceptor.rs b/crates/pod/src/hook_interceptor.rs new file mode 100644 index 00000000..4d15550d --- /dev/null +++ b/crates/pod/src/hook_interceptor.rs @@ -0,0 +1,87 @@ +//! HookInterceptor — bridges Pod-layer hooks to Worker's Interceptor trait. + +use std::sync::Arc; + +use async_trait::async_trait; +use llm_worker::interceptor::{ + Interceptor, PostToolAction, PreRequestAction, PreToolAction, PromptAction, ToolCallInfo, + ToolResultInfo, TurnEndAction, +}; +use llm_worker::Item; + +use crate::hook::HookRegistry; + +/// An `Interceptor` implementation that delegates to a `HookRegistry`. +/// +/// Each method iterates the registered hooks in order and short-circuits +/// on the first non-Continue (or non-Finish) result. +pub(crate) struct HookInterceptor { + registry: Arc, +} + +impl HookInterceptor { + pub(crate) fn new(registry: Arc) -> Self { + Self { registry } + } +} + +#[async_trait] +impl Interceptor for HookInterceptor { + async fn on_prompt_submit(&self, item: &mut Item) -> PromptAction { + for hook in &self.registry.on_prompt_submit { + let action = hook.call(item).await; + if !matches!(action, PromptAction::Continue) { + return action; + } + } + PromptAction::Continue + } + + async fn pre_llm_request(&self, context: &mut Vec) -> PreRequestAction { + for hook in &self.registry.pre_llm_request { + let action = hook.call(context).await; + if !matches!(action, PreRequestAction::Continue) { + return action; + } + } + PreRequestAction::Continue + } + + async fn pre_tool_call(&self, info: &mut ToolCallInfo) -> PreToolAction { + for hook in &self.registry.pre_tool_call { + let action = hook.call(info).await; + if !matches!(action, PreToolAction::Continue) { + return action; + } + } + PreToolAction::Continue + } + + async fn post_tool_call(&self, info: &mut ToolResultInfo) -> PostToolAction { + for hook in &self.registry.post_tool_call { + let action = hook.call(info).await; + if !matches!(action, PostToolAction::Continue) { + return action; + } + } + PostToolAction::Continue + } + + async fn on_turn_end(&self, history: &[Item]) -> TurnEndAction { + let mut history_vec = history.to_vec(); + for hook in &self.registry.on_turn_end { + let action = hook.call(&mut history_vec).await; + if !matches!(action, TurnEndAction::Finish) { + return action; + } + } + TurnEndAction::Finish + } + + async fn on_abort(&self, reason: &str) { + let mut reason_string = reason.to_string(); + for hook in &self.registry.on_abort { + hook.call(&mut reason_string).await; + } + } +} diff --git a/crates/pod/src/lib.rs b/crates/pod/src/lib.rs index 3aec7b49..1da428b8 100644 --- a/crates/pod/src/lib.rs +++ b/crates/pod/src/lib.rs @@ -1,12 +1,15 @@ pub mod controller; +pub mod hook; pub mod runtime_dir; pub mod shared_state; pub mod socket_server; +mod hook_interceptor; mod pod; pub use controller::{PodController, PodHandle}; pub use manifest::{PodManifest, ProviderConfig, ProviderKind, Scope}; +pub use hook::{Hook, HookEventKind, HookRegistryBuilder}; pub use pod::{Pod, PodError, PodRunResult, apply_worker_manifest}; pub use protocol::{ErrorCode, Event, Method, TurnResult}; pub use provider::{ProviderError, build_client}; diff --git a/crates/pod/src/pod.rs b/crates/pod/src/pod.rs index 3db409a6..f593416e 100644 --- a/crates/pod/src/pod.rs +++ b/crates/pod/src/pod.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use llm_worker::llm_client::client::LlmClient; use llm_worker::llm_client::RequestConfig; use llm_worker::Worker; @@ -7,6 +9,12 @@ use llm_worker_persistence::{ use manifest::{PodManifest, Scope, WorkerManifest}; +use crate::hook::{ + Hook, HookRegistryBuilder, OnAbort, OnPromptSubmit, OnTurnEnd, PostToolCall, PreLlmRequest, + PreToolCall, +}; +use crate::hook_interceptor::HookInterceptor; + /// An independent agent execution unit. /// /// Wraps a persistent [`Session`] with manifest metadata and an optional @@ -15,6 +23,8 @@ pub struct Pod { manifest: PodManifest, session: Session, scope: Option, + hook_builder: HookRegistryBuilder, + interceptor_installed: bool, } impl Pod { @@ -34,6 +44,8 @@ impl Pod { manifest, session, scope, + hook_builder: HookRegistryBuilder::new(), + interceptor_installed: false, }) } @@ -50,6 +62,8 @@ impl Pod { manifest, session, scope, + hook_builder: HookRegistryBuilder::new(), + interceptor_installed: false, }) } @@ -76,14 +90,75 @@ impl Pod { &mut self.session } + // --- Hook registration --- + // + // Hooks must be registered before the first call to `run()` or `resume()`. + // Attempting to add a hook after execution has started will panic. + + fn assert_hooks_open(&self) { + assert!( + !self.interceptor_installed, + "cannot add hooks after run() or resume() has been called" + ); + } + + /// Register a hook that runs after receiving user input. + pub fn add_on_prompt_submit_hook(&mut self, hook: impl Hook + 'static) { + self.assert_hooks_open(); + self.hook_builder.add_on_prompt_submit(hook); + } + + /// Register a hook that runs before each LLM request. + pub fn add_pre_llm_request_hook(&mut self, hook: impl Hook + 'static) { + self.assert_hooks_open(); + self.hook_builder.add_pre_llm_request(hook); + } + + /// Register a hook that runs before each tool call. + pub fn add_pre_tool_call_hook(&mut self, hook: impl Hook + 'static) { + self.assert_hooks_open(); + self.hook_builder.add_pre_tool_call(hook); + } + + /// Register a hook that runs after each tool call. + pub fn add_post_tool_call_hook(&mut self, hook: impl Hook + 'static) { + self.assert_hooks_open(); + self.hook_builder.add_post_tool_call(hook); + } + + /// Register a hook that runs at the end of a turn. + pub fn add_on_turn_end_hook(&mut self, hook: impl Hook + 'static) { + self.assert_hooks_open(); + self.hook_builder.add_on_turn_end(hook); + } + + /// Register a hook that runs when execution is aborted. + pub fn add_on_abort_hook(&mut self, hook: impl Hook + 'static) { + self.assert_hooks_open(); + self.hook_builder.add_on_abort(hook); + } + + /// Install the hook-based interceptor on the Worker if not already done. + fn ensure_interceptor_installed(&mut self) { + if !self.interceptor_installed { + let builder = std::mem::take(&mut self.hook_builder); + let registry = Arc::new(builder.build()); + let interceptor = HookInterceptor::new(registry); + self.session.worker.set_interceptor(interceptor); + self.interceptor_installed = true; + } + } + /// Send user input and run until the LLM turn completes. pub async fn run(&mut self, input: impl Into) -> Result { + self.ensure_interceptor_installed(); let result = self.session.run(input).await?; Ok(result.into()) } /// Resume from a paused state. pub async fn resume(&mut self) -> Result { + self.ensure_interceptor_installed(); let result = self.session.resume().await?; Ok(result.into()) } @@ -107,6 +182,8 @@ impl Pod, St> { manifest, session, scope, + hook_builder: HookRegistryBuilder::new(), + interceptor_installed: false, }) } } diff --git a/tickets/api-key-file.md b/tickets/api-key-file.md new file mode 100644 index 00000000..1f20db4d --- /dev/null +++ b/tickets/api-key-file.md @@ -0,0 +1,44 @@ +# api_key_file: ファイルパスによるAPIキー解決 + +## 背景 + +現状、APIキーの取得手段は `api_key_env`(環境変数名の指定)のみ。 +永続化やインタラクティブ入力の仕組みがなく、キー管理をユーザーのシェル環境に完全依存している。 + +## やること + +マニフェストの `ProviderConfig` に `api_key_file: Option` を追加し、ファイルからAPIキーを読み取れるようにする。 + +### マニフェスト + +```toml +[provider] +kind = "anthropic" +model = "claude-sonnet-4-20250514" +api_key_file = "~/.config/insomnia/keys/anthropic" +``` + +- ファイルにはキーのみを記載(読み込み時に trim) +- `~` 展開が必要 +- 相対パスはマニフェストファイルの位置基準 + +### api_key_env との関係 + +- 排他。両方指定されたらエラー +- Ollama は両方不要のまま + +### 変更箇所 + +1. **manifest**: `ProviderConfig` に `api_key_file: Option` を追加 +2. **provider**: `build_client()` でファイル読み取りロジックを追加。排他バリデーション +3. **provider**: `ProviderError` にキー不在を明示するバリアント追加(将来の TUI フォールバック用) + +### 暗号化について + +現段階では扱わない。ファイルパーミッション(0600)で十分。 +将来エンドユーザー向けに暗号化が必要になった場合、provider の手前に復号レイヤーを挟む形で対応できる。`api_key_file` の設計自体は変更不要。 + +### 将来の拡張 + +- TUI サブコマンド(`insomnia key set anthropic` 等)がこのファイルに書き込むラッパーになる +- `api_key_cmd`(コマンド実行によるキー取得)は `api_key_file` で不足が生じた時点で検討 diff --git a/tickets/context-compaction.md b/tickets/context-compaction.md index 509f0177..65b82034 100644 --- a/tickets/context-compaction.md +++ b/tickets/context-compaction.md @@ -64,4 +64,4 @@ history 内のツール出力を走査: ## 依存チケット -- [remove-hook-module.md](remove-hook-module.md) — Hook が insomnia 層に移動した後、PreLlmRequest で Prune を差し込む +- ~~[remove-hook-module.md](remove-hook-module.md)~~ — 完了。PreLlmRequest は Pod 層の `hook::Hook` として利用可能 diff --git a/tickets/permission-extension-point.md b/tickets/permission-extension-point.md index 70e2e934..d8b4ab8e 100644 --- a/tickets/permission-extension-point.md +++ b/tickets/permission-extension-point.md @@ -51,4 +51,4 @@ action = "allow" ## 依存チケット -- [remove-hook-module.md](remove-hook-module.md) — PreToolCall が insomnia 層に移動してから実装 +- ~~[remove-hook-module.md](remove-hook-module.md)~~ — 完了。PreToolCall は Pod 層の `hook::Hook` として利用可能 diff --git a/tickets/remove-hook-module.md b/tickets/remove-hook-module.md deleted file mode 100644 index 8f3faa2a..00000000 --- a/tickets/remove-hook-module.md +++ /dev/null @@ -1,42 +0,0 @@ -# Hook モジュールの llm-worker からの除去 - -## 背景 - -llm-worker は低レベル基盤に徹するべきだが、現行の `hook` モジュールは -高レベルのオーケストレーション関心(承認フロー、ターン制御等)を含んでいる。 -Claude Code の Hooks のような機能は insomnia 層の責務。 - -低レベルのストリーム介入は、クロージャベースの Subscriber API で十分カバーできる。 - -## 方針 - -`hook` モジュールを llm-worker から除去し、責務を分離する。 - -### Subscriber で代替(削除) - -ストリーム観測・介入はクロージャ Subscriber で対応: - -- `OnTextDelta` -- `OnToolCallDelta` -- `OnStreamChunk` -- `OnStreamComplete` - -### insomnia 層に移動 - -高レベルオーケストレーションは上位層が担う: - -- `OnPromptSubmit` -- `PreLlmRequest` -- `PreToolCall` / `PostToolCall` -- `OnTurnEnd` -- `OnAbort` - -## 設計ポイント - -- Worker の実行ループは「ストリーム受信 → ツール実行 → 次ターン」に集中させる -- 介入ポイント(承認、中断、ターン継続判断)は insomnia 層が提供する -- `HookEventKind` / `Hook` の型設計自体は良いので、insomnia 層で再利用可能 - -## 依存チケット - -- [subscriber-closure-api.md](subscriber-closure-api.md) — ストリーム系 Hook の代替先