HookのPod側への移動・Interceptorの実装
This commit is contained in:
parent
496038307f
commit
fc8ff9362e
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -1,4 +1,4 @@
|
|||
/target
|
||||
.direnv
|
||||
*.local
|
||||
*.local*
|
||||
.env
|
||||
|
|
|
|||
2
TODO.md
2
TODO.md
|
|
@ -10,3 +10,5 @@
|
|||
- [x] セッションエントリのハッシュチェーン
|
||||
- [x] Subscriber → クロージャ API 移行
|
||||
- [x] JSONL ストリーム変換ユーティリティ (protocol::stream)
|
||||
- [x] Hook モジュールの llm-worker からの除去 → [tickets/remove-hook-module.md](tickets/remove-hook-module.md)
|
||||
- [ ] api_key_file: ファイルパスによるAPIキー解決 → [tickets/api-key-file.md](tickets/api-key-file.md)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use std::sync::Arc;
|
|||
|
||||
use async_trait::async_trait;
|
||||
use common::MockLlmClient;
|
||||
use llm_worker::hook::{Hook, HookError, OnTurnEnd, OnTurnEndResult};
|
||||
use llm_worker::interceptor::{Interceptor, TurnEndAction};
|
||||
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
|
||||
use llm_worker::llm_client::types::{Item, RequestConfig};
|
||||
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
||||
|
|
@ -76,13 +76,13 @@ fn weather_tool_definition() -> ToolDefinition {
|
|||
})
|
||||
}
|
||||
|
||||
/// Hook that forces Pause on the first turn end.
|
||||
struct PauseOnFirstTurnEnd;
|
||||
/// Policy that forces Pause on every turn end.
|
||||
struct PausePolicy;
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<OnTurnEnd> for PauseOnFirstTurnEnd {
|
||||
async fn call(&self, _input: &mut Vec<Item>) -> Result<OnTurnEndResult, HookError> {
|
||||
Ok(OnTurnEndResult::Paused)
|
||||
impl Interceptor for PausePolicy {
|
||||
async fn on_turn_end(&self, _history: &[Item]) -> TurnEndAction {
|
||||
TurnEndAction::Pause
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -214,7 +214,7 @@ async fn session_resume_after_pause() {
|
|||
let client = MockLlmClient::with_responses(tool_call_events());
|
||||
let mut worker = Worker::new(client);
|
||||
worker.register_tool(weather_tool_definition()).unwrap();
|
||||
worker.add_on_turn_end_hook(PauseOnFirstTurnEnd);
|
||||
worker.set_interceptor(PausePolicy);
|
||||
|
||||
let mut session = Session::new(worker, store.clone(), SessionConfig::default())
|
||||
.await
|
||||
|
|
|
|||
|
|
@ -41,7 +41,6 @@ use tracing_subscriber::EnvFilter;
|
|||
use clap::{Parser, ValueEnum};
|
||||
use llm_worker::{
|
||||
Worker,
|
||||
hook::{Hook, HookError, PostToolCall, PostToolCallContext, PostToolCallResult},
|
||||
llm_client::{
|
||||
LlmClient,
|
||||
providers::{
|
||||
|
|
@ -49,6 +48,7 @@ use llm_worker::{
|
|||
openai::OpenAIClient,
|
||||
},
|
||||
},
|
||||
interceptor::{Interceptor, PostToolAction, ToolResultInfo},
|
||||
timeline::{Handler, TextBlockEvent, TextBlockKind, ToolUseBlockEvent, ToolUseBlockKind},
|
||||
};
|
||||
use llm_worker_macros::tool_registry;
|
||||
|
|
@ -270,34 +270,34 @@ impl Handler<ToolUseBlockKind> for ToolCallPrinter {
|
|||
}
|
||||
}
|
||||
|
||||
/// Hook that displays tool execution results
|
||||
struct ToolResultPrinterHook {
|
||||
/// Policy that displays tool execution results.
|
||||
struct ToolResultPrinterPolicy {
|
||||
call_names: Arc<Mutex<HashMap<String, String>>>,
|
||||
}
|
||||
|
||||
impl ToolResultPrinterHook {
|
||||
impl ToolResultPrinterPolicy {
|
||||
fn new(call_names: Arc<Mutex<HashMap<String, String>>>) -> Self {
|
||||
Self { call_names }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<PostToolCall> for ToolResultPrinterHook {
|
||||
async fn call(&self, ctx: &mut PostToolCallContext) -> Result<PostToolCallResult, HookError> {
|
||||
impl Interceptor for ToolResultPrinterPolicy {
|
||||
async fn post_tool_call(&self, info: &mut ToolResultInfo) -> PostToolAction {
|
||||
let name = self
|
||||
.call_names
|
||||
.lock()
|
||||
.unwrap()
|
||||
.remove(&ctx.result.tool_use_id)
|
||||
.unwrap_or_else(|| ctx.result.tool_use_id.clone());
|
||||
.remove(&info.result.tool_use_id)
|
||||
.unwrap_or_else(|| info.result.tool_use_id.clone());
|
||||
|
||||
if ctx.result.is_error {
|
||||
println!(" Result ({}): ❌ {}", name, ctx.result.content);
|
||||
if info.result.is_error {
|
||||
println!(" Result ({}): ❌ {}", name, info.result.content);
|
||||
} else {
|
||||
println!(" Result ({}): ✅ {}", name, ctx.result.content);
|
||||
println!(" Result ({}): ✅ {}", name, info.result.content);
|
||||
}
|
||||
|
||||
Ok(PostToolCallResult::Continue)
|
||||
PostToolAction::Continue
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -450,7 +450,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
.on_text_block(StreamingPrinter::new())
|
||||
.on_tool_use_block(ToolCallPrinter::new(tool_call_names.clone()));
|
||||
|
||||
worker.add_post_tool_call_hook(ToolResultPrinterHook::new(tool_call_names));
|
||||
worker.set_interceptor(ToolResultPrinterPolicy::new(tool_call_names));
|
||||
|
||||
// One-shot mode
|
||||
if let Some(prompt) = args.prompt {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ use crate::handler::{
|
|||
Handler, Kind, TextBlockEvent, TextBlockKind, ToolUseBlockEvent, ToolUseBlockKind,
|
||||
ToolUseBlockStart,
|
||||
};
|
||||
use crate::hook::ToolCall;
|
||||
use crate::tool::ToolCall;
|
||||
|
||||
// =============================================================================
|
||||
// TextBlock Closure Handler
|
||||
|
|
|
|||
|
|
@ -1,310 +0,0 @@
|
|||
//! Hook-related type definitions
|
||||
//!
|
||||
//! Types used for turn control and intervention in the Worker layer
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use thiserror::Error;
|
||||
|
||||
// =============================================================================
|
||||
// Hook Event Kinds
|
||||
// =============================================================================
|
||||
|
||||
pub trait HookEventKind: Send + Sync + 'static {
|
||||
type Input;
|
||||
type Output;
|
||||
}
|
||||
|
||||
pub struct OnPromptSubmit;
|
||||
pub struct PreLlmRequest;
|
||||
pub struct PreToolCall;
|
||||
pub struct PostToolCall;
|
||||
pub struct OnTurnEnd;
|
||||
pub struct OnAbort;
|
||||
pub struct OnTextDelta;
|
||||
pub struct OnToolCallDelta;
|
||||
pub struct OnStreamChunk;
|
||||
pub struct OnStreamComplete;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum OnPromptSubmitResult {
|
||||
Continue,
|
||||
Cancel(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PreLlmRequestResult {
|
||||
Continue,
|
||||
Cancel(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PreToolCallResult {
|
||||
Continue,
|
||||
Skip,
|
||||
Abort(String),
|
||||
Pause,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PostToolCallResult {
|
||||
Continue,
|
||||
Abort(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum OnTurnEndResult {
|
||||
Finish,
|
||||
ContinueWithMessages(Vec<crate::Item>),
|
||||
Paused,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum StreamHookResult {
|
||||
Continue,
|
||||
Abort(String),
|
||||
Pause,
|
||||
}
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::tool::{Tool, ToolMeta};
|
||||
|
||||
/// Input context for PreToolCall
|
||||
pub struct ToolCallContext {
|
||||
/// Tool call information (modifiable)
|
||||
pub call: ToolCall,
|
||||
/// Tool meta information (immutable)
|
||||
pub meta: ToolMeta,
|
||||
/// Tool instance (for state access)
|
||||
pub tool: Arc<dyn Tool>,
|
||||
}
|
||||
|
||||
/// Input context for PostToolCall
|
||||
pub struct PostToolCallContext {
|
||||
/// Tool call information
|
||||
pub call: ToolCall,
|
||||
/// Tool execution result (modifiable)
|
||||
pub result: ToolResult,
|
||||
/// Tool meta information (immutable)
|
||||
pub meta: ToolMeta,
|
||||
/// Tool instance (for state access)
|
||||
pub tool: Arc<dyn Tool>,
|
||||
}
|
||||
|
||||
/// Input context for OnTextDelta
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TextDeltaContext {
|
||||
/// Block index
|
||||
pub index: usize,
|
||||
/// Text delta content
|
||||
pub delta: String,
|
||||
}
|
||||
|
||||
/// Input context for OnToolCallDelta
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolCallDeltaContext {
|
||||
/// Block index
|
||||
pub index: usize,
|
||||
/// Partial JSON fragment
|
||||
pub delta_json_fragment: String,
|
||||
}
|
||||
|
||||
/// Input context for OnStreamChunk
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamChunkContext {
|
||||
/// Public worker-level event
|
||||
pub event: crate::event::Event,
|
||||
}
|
||||
|
||||
/// Input context for OnStreamComplete
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamCompleteContext {
|
||||
/// Current turn number
|
||||
pub turn: usize,
|
||||
/// Number of streamed events in this request
|
||||
pub event_count: usize,
|
||||
}
|
||||
|
||||
impl HookEventKind for OnPromptSubmit {
|
||||
type Input = crate::Item;
|
||||
type Output = OnPromptSubmitResult;
|
||||
}
|
||||
|
||||
impl HookEventKind for PreLlmRequest {
|
||||
type Input = Vec<crate::Item>;
|
||||
type Output = PreLlmRequestResult;
|
||||
}
|
||||
|
||||
impl HookEventKind for PreToolCall {
|
||||
type Input = ToolCallContext;
|
||||
type Output = PreToolCallResult;
|
||||
}
|
||||
|
||||
impl HookEventKind for PostToolCall {
|
||||
type Input = PostToolCallContext;
|
||||
type Output = PostToolCallResult;
|
||||
}
|
||||
|
||||
impl HookEventKind for OnTurnEnd {
|
||||
type Input = Vec<crate::Item>;
|
||||
type Output = OnTurnEndResult;
|
||||
}
|
||||
|
||||
impl HookEventKind for OnAbort {
|
||||
type Input = String;
|
||||
type Output = ();
|
||||
}
|
||||
|
||||
impl HookEventKind for OnTextDelta {
|
||||
type Input = TextDeltaContext;
|
||||
type Output = StreamHookResult;
|
||||
}
|
||||
|
||||
impl HookEventKind for OnToolCallDelta {
|
||||
type Input = ToolCallDeltaContext;
|
||||
type Output = StreamHookResult;
|
||||
}
|
||||
|
||||
impl HookEventKind for OnStreamChunk {
|
||||
type Input = StreamChunkContext;
|
||||
type Output = StreamHookResult;
|
||||
}
|
||||
|
||||
impl HookEventKind for OnStreamComplete {
|
||||
type Input = StreamCompleteContext;
|
||||
type Output = StreamHookResult;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Tool Call / Result Types
|
||||
// =============================================================================
|
||||
|
||||
/// Tool call information
|
||||
///
|
||||
/// Represents a ToolUse block from LLM, modifiable in Hook processing
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
/// Tool call ID (used for linking with response)
|
||||
pub id: String,
|
||||
/// Tool name
|
||||
pub name: String,
|
||||
/// Input arguments (JSON)
|
||||
pub input: Value,
|
||||
}
|
||||
|
||||
/// Tool execution result
|
||||
///
|
||||
/// Represents the result after tool execution, modifiable in Hook processing
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolResult {
|
||||
/// Corresponding tool call ID
|
||||
pub tool_use_id: String,
|
||||
/// Result content
|
||||
pub content: String,
|
||||
/// Whether this is an error
|
||||
#[serde(default)]
|
||||
pub is_error: bool,
|
||||
}
|
||||
|
||||
impl ToolResult {
|
||||
/// Create a success result
|
||||
pub fn success(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
tool_use_id: tool_use_id.into(),
|
||||
content: content.into(),
|
||||
is_error: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an error result
|
||||
pub fn error(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
tool_use_id: tool_use_id.into(),
|
||||
content: content.into(),
|
||||
is_error: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Hook Error
|
||||
// =============================================================================
|
||||
|
||||
/// Hook error
|
||||
#[derive(Debug, Error)]
|
||||
pub enum HookError {
|
||||
/// Processing was aborted
|
||||
#[error("Aborted: {0}")]
|
||||
Aborted(String),
|
||||
/// Internal error
|
||||
#[error("Hook error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Hook Trait
|
||||
// =============================================================================
|
||||
|
||||
/// Trait for handling Hook events
|
||||
///
|
||||
/// Each event type has a different return type, constrained via `HookEventKind`.
|
||||
#[async_trait]
|
||||
pub trait Hook<E: HookEventKind>: Send + Sync {
|
||||
async fn call(&self, input: &mut E::Input) -> Result<E::Output, HookError>;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Hook Registry
|
||||
// =============================================================================
|
||||
|
||||
/// Registry holding all Hooks
|
||||
///
|
||||
/// Used internally by Worker to manage all Hook types.
|
||||
pub struct HookRegistry {
|
||||
/// on_prompt_submit Hook
|
||||
pub(crate) on_prompt_submit: Vec<Box<dyn Hook<OnPromptSubmit>>>,
|
||||
/// pre_llm_request Hook
|
||||
pub(crate) pre_llm_request: Vec<Box<dyn Hook<PreLlmRequest>>>,
|
||||
/// pre_tool_call Hook
|
||||
pub(crate) pre_tool_call: Vec<Box<dyn Hook<PreToolCall>>>,
|
||||
/// post_tool_call Hook
|
||||
pub(crate) post_tool_call: Vec<Box<dyn Hook<PostToolCall>>>,
|
||||
/// on_turn_end Hook
|
||||
pub(crate) on_turn_end: Vec<Box<dyn Hook<OnTurnEnd>>>,
|
||||
/// on_abort Hook
|
||||
pub(crate) on_abort: Vec<Box<dyn Hook<OnAbort>>>,
|
||||
/// on_text_delta Hook
|
||||
pub(crate) on_text_delta: Vec<Box<dyn Hook<OnTextDelta>>>,
|
||||
/// on_tool_call_delta Hook
|
||||
pub(crate) on_tool_call_delta: Vec<Box<dyn Hook<OnToolCallDelta>>>,
|
||||
/// on_stream_chunk Hook
|
||||
pub(crate) on_stream_chunk: Vec<Box<dyn Hook<OnStreamChunk>>>,
|
||||
/// on_stream_complete Hook
|
||||
pub(crate) on_stream_complete: Vec<Box<dyn Hook<OnStreamComplete>>>,
|
||||
}
|
||||
|
||||
impl Default for HookRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl HookRegistry {
|
||||
/// Create an empty HookRegistry
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
on_prompt_submit: Vec::new(),
|
||||
pre_llm_request: Vec::new(),
|
||||
pre_tool_call: Vec::new(),
|
||||
post_tool_call: Vec::new(),
|
||||
on_turn_end: Vec::new(),
|
||||
on_abort: Vec::new(),
|
||||
on_text_delta: Vec::new(),
|
||||
on_tool_call_delta: Vec::new(),
|
||||
on_stream_chunk: Vec::new(),
|
||||
on_stream_complete: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
142
crates/llm-worker/src/interceptor.rs
Normal file
142
crates/llm-worker/src/interceptor.rs
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
//! Interceptor - control flow delegation for the Worker execution loop
|
||||
//!
|
||||
//! Defines the [`Interceptor`] trait that upper layers (e.g. Pod) implement
|
||||
//! to inject orchestration decisions (approval, skip, pause, abort)
|
||||
//! into the Worker's turn loop without the Worker knowing about
|
||||
//! higher-level concepts.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::tool::{Tool, ToolCall, ToolMeta, ToolResult};
|
||||
use crate::Item;
|
||||
|
||||
// =============================================================================
|
||||
// Action Enums
|
||||
// =============================================================================
|
||||
|
||||
/// Action after prompt submission.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PromptAction {
|
||||
/// Proceed normally.
|
||||
Continue,
|
||||
/// Cancel with a reason.
|
||||
Cancel(String),
|
||||
}
|
||||
|
||||
/// Action before an LLM request.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PreRequestAction {
|
||||
/// Proceed normally.
|
||||
Continue,
|
||||
/// Cancel with a reason.
|
||||
Cancel(String),
|
||||
}
|
||||
|
||||
/// Action before a tool call.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PreToolAction {
|
||||
/// Proceed with execution.
|
||||
Continue,
|
||||
/// Skip this tool call (do not execute).
|
||||
Skip,
|
||||
/// Abort the entire run.
|
||||
Abort(String),
|
||||
/// Pause execution (can be resumed later).
|
||||
Pause,
|
||||
}
|
||||
|
||||
/// Action after a tool call.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PostToolAction {
|
||||
/// Proceed normally.
|
||||
Continue,
|
||||
/// Abort the entire run.
|
||||
Abort(String),
|
||||
}
|
||||
|
||||
/// Action at the end of a turn (when LLM produces no tool calls).
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TurnEndAction {
|
||||
/// Turn is finished, return to caller.
|
||||
Finish,
|
||||
/// Continue with additional messages injected into history.
|
||||
ContinueWithMessages(Vec<Item>),
|
||||
/// Pause execution (can be resumed later).
|
||||
Pause,
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Context Types
|
||||
// =============================================================================
|
||||
|
||||
/// Context for pre-tool-call decisions.
|
||||
pub struct ToolCallInfo {
|
||||
/// Tool call information (modifiable).
|
||||
pub call: ToolCall,
|
||||
/// Tool meta information.
|
||||
pub meta: ToolMeta,
|
||||
/// Tool instance (for state access).
|
||||
pub tool: Arc<dyn Tool>,
|
||||
}
|
||||
|
||||
/// Context for post-tool-call decisions.
|
||||
pub struct ToolResultInfo {
|
||||
/// Original tool call.
|
||||
pub call: ToolCall,
|
||||
/// Tool execution result (modifiable).
|
||||
pub result: ToolResult,
|
||||
/// Tool meta information.
|
||||
pub meta: ToolMeta,
|
||||
/// Tool instance (for state access).
|
||||
pub tool: Arc<dyn Tool>,
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Interceptor Trait
|
||||
// =============================================================================
|
||||
|
||||
/// Intercepts the Worker execution loop at key decision points.
|
||||
///
|
||||
/// All methods have default implementations that let the Worker
|
||||
/// proceed without intervention. Upper layers (e.g. Pod) provide
|
||||
/// richer implementations for approval flows, permission checks, etc.
|
||||
#[async_trait]
|
||||
pub trait Interceptor: Send + Sync {
|
||||
/// Called after receiving user input, before adding to history.
|
||||
async fn on_prompt_submit(&self, _item: &mut Item) -> PromptAction {
|
||||
PromptAction::Continue
|
||||
}
|
||||
|
||||
/// Called before each LLM request. The context can be modified
|
||||
/// (e.g. for context compaction).
|
||||
async fn pre_llm_request(&self, _context: &mut Vec<Item>) -> PreRequestAction {
|
||||
PreRequestAction::Continue
|
||||
}
|
||||
|
||||
/// Called before each tool is executed.
|
||||
async fn pre_tool_call(&self, _info: &mut ToolCallInfo) -> PreToolAction {
|
||||
PreToolAction::Continue
|
||||
}
|
||||
|
||||
/// Called after each tool completes.
|
||||
async fn post_tool_call(&self, _info: &mut ToolResultInfo) -> PostToolAction {
|
||||
PostToolAction::Continue
|
||||
}
|
||||
|
||||
/// Called when a turn ends with no tool calls.
|
||||
async fn on_turn_end(&self, _history: &[Item]) -> TurnEndAction {
|
||||
TurnEndAction::Finish
|
||||
}
|
||||
|
||||
/// Called when execution is interrupted (abort or cancel).
|
||||
async fn on_abort(&self, _reason: &str) {}
|
||||
}
|
||||
|
||||
/// Default interceptor: no intervention. Worker proceeds through the loop
|
||||
/// without any external control flow decisions.
|
||||
pub(crate) struct DefaultInterceptor;
|
||||
|
||||
#[async_trait]
|
||||
impl Interceptor for DefaultInterceptor {}
|
||||
|
|
@ -6,7 +6,7 @@
|
|||
//!
|
||||
//! - [`Worker`] - Central component for managing LLM interactions
|
||||
//! - [`tool::Tool`] - Tools that can be invoked by the LLM
|
||||
//! - [`hook::Hook`] - Hooks for intercepting turn progression
|
||||
//! - [`interceptor::Interceptor`] - Control-flow delegation for the execution loop
|
||||
//! - Closure-based event callbacks via `Worker::on_text_block()`, `on_tool_use_block()`, etc.
|
||||
//!
|
||||
//! # Quick Start
|
||||
|
|
@ -41,8 +41,8 @@ mod worker;
|
|||
|
||||
pub(crate) mod callback;
|
||||
pub mod event;
|
||||
pub mod hook;
|
||||
pub mod llm_client;
|
||||
pub mod interceptor;
|
||||
pub mod state;
|
||||
pub mod timeline;
|
||||
pub mod tool;
|
||||
|
|
@ -51,4 +51,6 @@ pub mod tool_server;
|
|||
pub use callback::{TextBlockScope, ToolUseBlockScope};
|
||||
pub use handler::ToolUseBlockStart;
|
||||
pub use message::{ContentPart, Item, Message, Role};
|
||||
pub use interceptor::Interceptor;
|
||||
pub use tool::{ToolCall, ToolResult};
|
||||
pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult};
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
use crate::{
|
||||
handler::{Handler, ToolUseBlockEvent, ToolUseBlockKind},
|
||||
hook::ToolCall,
|
||||
tool::ToolCall,
|
||||
};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
|
|
|
|||
|
|
@ -370,3 +370,54 @@ pub trait ToolOutputProcessor: Send + Sync {
|
|||
/// For large outputs, this should be a summary with a blob reference.
|
||||
async fn process(&self, output: String) -> Result<String, ToolError>;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Tool Call / Result Types
|
||||
// =============================================================================
|
||||
|
||||
/// Tool call information
|
||||
///
|
||||
/// Represents a ToolUse block from LLM.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
/// Tool call ID (used for linking with response)
|
||||
pub id: String,
|
||||
/// Tool name
|
||||
pub name: String,
|
||||
/// Input arguments (JSON)
|
||||
pub input: Value,
|
||||
}
|
||||
|
||||
/// Tool execution result
|
||||
///
|
||||
/// Represents the result after tool execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolResult {
|
||||
/// Corresponding tool call ID
|
||||
pub tool_use_id: String,
|
||||
/// Result content
|
||||
pub content: String,
|
||||
/// Whether this is an error
|
||||
#[serde(default)]
|
||||
pub is_error: bool,
|
||||
}
|
||||
|
||||
impl ToolResult {
|
||||
/// Create a success result
|
||||
pub fn success(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
tool_use_id: tool_use_id.into(),
|
||||
content: content.into(),
|
||||
is_error: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an error result
|
||||
pub fn error(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
tool_use_id: tool_use_id.into(),
|
||||
content: content.into(),
|
||||
is_error: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,15 +8,11 @@ use tracing::{debug, info, trace, warn};
|
|||
|
||||
use crate::{
|
||||
Item,
|
||||
hook::{
|
||||
Hook, HookError, HookRegistry, OnAbort, OnPromptSubmit, OnPromptSubmitResult,
|
||||
OnStreamChunk, OnStreamComplete, OnTextDelta, OnToolCallDelta, OnTurnEnd, OnTurnEndResult,
|
||||
PostToolCall, PostToolCallContext, PostToolCallResult, PreLlmRequest, PreLlmRequestResult,
|
||||
PreToolCall, PreToolCallResult, StreamChunkContext, StreamCompleteContext,
|
||||
StreamHookResult, TextDeltaContext, ToolCall, ToolCallContext, ToolCallDeltaContext,
|
||||
ToolResult,
|
||||
},
|
||||
llm_client::{ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ToolDefinition},
|
||||
interceptor::{
|
||||
DefaultInterceptor, Interceptor, PostToolAction, PreRequestAction, PreToolAction,
|
||||
PromptAction, ToolCallInfo, ToolResultInfo, TurnEndAction,
|
||||
},
|
||||
state::{CacheLocked, Mutable, WorkerState},
|
||||
callback::{
|
||||
ClosureMetaHandler, ClosureTextBlockHandler, ClosureToolUseBlockHandler, TextBlockScope,
|
||||
|
|
@ -25,7 +21,7 @@ use crate::{
|
|||
handler::{ErrorKind, StatusKind, ToolUseBlockStart, UsageKind},
|
||||
timeline::{TextBlockCollector, Timeline, ToolCallCollector},
|
||||
timeline::event::{ErrorEvent, StatusEvent, UsageEvent},
|
||||
tool::{ToolDefinition as WorkerToolDefinition, ToolError, ToolOutputProcessor},
|
||||
tool::{ToolCall, ToolDefinition as WorkerToolDefinition, ToolError, ToolOutputProcessor, ToolResult},
|
||||
tool_server::{ToolServer, ToolServerError, ToolServerHandle},
|
||||
};
|
||||
|
||||
|
|
@ -42,9 +38,6 @@ pub enum WorkerError {
|
|||
/// Tool error
|
||||
#[error("Tool error: {0}")]
|
||||
Tool(#[from] ToolError),
|
||||
/// Hook error
|
||||
#[error("Hook error: {0}")]
|
||||
Hook(#[from] HookError),
|
||||
/// Execution was aborted
|
||||
#[error("Aborted: {0}")]
|
||||
Aborted(String),
|
||||
|
|
@ -145,8 +138,8 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
|||
tool_call_collector: ToolCallCollector,
|
||||
/// Tool server handle
|
||||
tool_server: ToolServerHandle,
|
||||
/// Hook registry
|
||||
hooks: HookRegistry,
|
||||
/// Interceptor for control-flow decisions
|
||||
interceptor: Box<dyn Interceptor>,
|
||||
/// System prompt
|
||||
system_prompt: Option<String>,
|
||||
/// Item history (owned by Worker)
|
||||
|
|
@ -192,21 +185,16 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
user_input: impl Into<String>,
|
||||
) -> Result<WorkerResult, WorkerError> {
|
||||
self.reset_interruption_state();
|
||||
// Hook: on_prompt_submit
|
||||
// Interceptor: on_prompt_submit
|
||||
let mut user_item = Item::user_message(user_input);
|
||||
let result = self.run_on_prompt_submit_hooks(&mut user_item).await;
|
||||
let result = match result {
|
||||
Ok(value) => value,
|
||||
Err(err) => return self.finalize_interruption(Err(err)).await,
|
||||
};
|
||||
match result {
|
||||
OnPromptSubmitResult::Cancel(reason) => {
|
||||
match self.interceptor.on_prompt_submit(&mut user_item).await {
|
||||
PromptAction::Cancel(reason) => {
|
||||
self.last_run_interrupted = true;
|
||||
return self
|
||||
.finalize_interruption(Err(WorkerError::Aborted(reason)))
|
||||
.await;
|
||||
}
|
||||
OnPromptSubmitResult::Continue => {}
|
||||
PromptAction::Continue => {}
|
||||
}
|
||||
self.history.push(user_item);
|
||||
let result = self.run_turn_loop().await;
|
||||
|
|
@ -335,58 +323,13 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
self.tool_server.clone()
|
||||
}
|
||||
|
||||
/// Add an on_prompt_submit Hook
|
||||
/// Set the interceptor for control-flow decisions.
|
||||
///
|
||||
/// Called immediately after receiving a user message in `run()`.
|
||||
pub fn add_on_prompt_submit_hook(&mut self, hook: impl Hook<OnPromptSubmit> + 'static) {
|
||||
self.hooks.on_prompt_submit.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// Add a pre_llm_request Hook
|
||||
///
|
||||
/// Called before sending an LLM request for each turn.
|
||||
pub fn add_pre_llm_request_hook(&mut self, hook: impl Hook<PreLlmRequest> + 'static) {
|
||||
self.hooks.pre_llm_request.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// Add a pre_tool_call Hook
|
||||
pub fn add_pre_tool_call_hook(&mut self, hook: impl Hook<PreToolCall> + 'static) {
|
||||
self.hooks.pre_tool_call.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// Add a post_tool_call Hook
|
||||
pub fn add_post_tool_call_hook(&mut self, hook: impl Hook<PostToolCall> + 'static) {
|
||||
self.hooks.post_tool_call.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// Add an on_turn_end Hook
|
||||
pub fn add_on_turn_end_hook(&mut self, hook: impl Hook<OnTurnEnd> + 'static) {
|
||||
self.hooks.on_turn_end.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// Add an on_abort Hook
|
||||
pub fn add_on_abort_hook(&mut self, hook: impl Hook<OnAbort> + 'static) {
|
||||
self.hooks.on_abort.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// Add an on_text_delta Hook
|
||||
pub fn add_on_text_delta_hook(&mut self, hook: impl Hook<OnTextDelta> + 'static) {
|
||||
self.hooks.on_text_delta.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// Add an on_tool_call_delta Hook
|
||||
pub fn add_on_tool_call_delta_hook(&mut self, hook: impl Hook<OnToolCallDelta> + 'static) {
|
||||
self.hooks.on_tool_call_delta.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// Add an on_stream_chunk Hook
|
||||
pub fn add_on_stream_chunk_hook(&mut self, hook: impl Hook<OnStreamChunk> + 'static) {
|
||||
self.hooks.on_stream_chunk.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// Add an on_stream_complete Hook
|
||||
pub fn add_on_stream_complete_hook(&mut self, hook: impl Hook<OnStreamComplete> + 'static) {
|
||||
self.hooks.on_stream_complete.push(Box::new(hook));
|
||||
/// The interceptor governs approval, skip, pause, and abort decisions
|
||||
/// at key points in the execution loop. If not set, the default
|
||||
/// interceptor is used (all Continue / Finish).
|
||||
pub fn set_interceptor(&mut self, interceptor: impl Interceptor + 'static) {
|
||||
self.interceptor = Box::new(interceptor);
|
||||
}
|
||||
|
||||
/// Get a mutable reference to the timeline (for additional handler registration)
|
||||
|
|
@ -578,131 +521,6 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
|
||||
/// Hooks: on_prompt_submit
|
||||
///
|
||||
/// Called immediately after receiving a user message in `run()` (first time only).
|
||||
async fn run_on_prompt_submit_hooks(
|
||||
&self,
|
||||
item: &mut Item,
|
||||
) -> Result<OnPromptSubmitResult, WorkerError> {
|
||||
for hook in &self.hooks.on_prompt_submit {
|
||||
let result = hook.call(item).await?;
|
||||
match result {
|
||||
OnPromptSubmitResult::Continue => continue,
|
||||
OnPromptSubmitResult::Cancel(reason) => {
|
||||
return Ok(OnPromptSubmitResult::Cancel(reason));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(OnPromptSubmitResult::Continue)
|
||||
}
|
||||
|
||||
/// Hooks: pre_llm_request
|
||||
///
|
||||
/// Called before sending an LLM request for each turn.
|
||||
async fn run_pre_llm_request_hooks(
|
||||
&self,
|
||||
) -> Result<(PreLlmRequestResult, Vec<Item>), WorkerError> {
|
||||
let mut temp_context = self.history.clone();
|
||||
for hook in &self.hooks.pre_llm_request {
|
||||
let result = hook.call(&mut temp_context).await?;
|
||||
match result {
|
||||
PreLlmRequestResult::Continue => continue,
|
||||
PreLlmRequestResult::Cancel(reason) => {
|
||||
return Ok((PreLlmRequestResult::Cancel(reason), temp_context));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok((PreLlmRequestResult::Continue, temp_context))
|
||||
}
|
||||
|
||||
/// Hooks: on_turn_end
|
||||
async fn run_on_turn_end_hooks(&self) -> Result<OnTurnEndResult, WorkerError> {
|
||||
let mut temp_items = self.history.clone();
|
||||
for hook in &self.hooks.on_turn_end {
|
||||
let result = hook.call(&mut temp_items).await?;
|
||||
match result {
|
||||
OnTurnEndResult::Finish => continue,
|
||||
OnTurnEndResult::ContinueWithMessages(items) => {
|
||||
return Ok(OnTurnEndResult::ContinueWithMessages(items));
|
||||
}
|
||||
OnTurnEndResult::Paused => return Ok(OnTurnEndResult::Paused),
|
||||
}
|
||||
}
|
||||
Ok(OnTurnEndResult::Finish)
|
||||
}
|
||||
|
||||
/// Hooks: on_abort
|
||||
async fn run_on_abort_hooks(&self, reason: &str) -> Result<(), WorkerError> {
|
||||
let mut reason = reason.to_string();
|
||||
for hook in &self.hooks.on_abort {
|
||||
hook.call(&mut reason).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn apply_stream_hook_result(result: StreamHookResult) -> Result<(), WorkerError> {
|
||||
match result {
|
||||
StreamHookResult::Continue => Ok(()),
|
||||
StreamHookResult::Abort(reason) => Err(WorkerError::Aborted(reason)),
|
||||
StreamHookResult::Pause => {
|
||||
Err(WorkerError::Aborted("Paused by stream hook".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_on_stream_chunk_hooks(
|
||||
&self,
|
||||
event: crate::event::Event,
|
||||
) -> Result<(), WorkerError> {
|
||||
let mut context = StreamChunkContext { event };
|
||||
for hook in &self.hooks.on_stream_chunk {
|
||||
let result = hook.call(&mut context).await?;
|
||||
Self::apply_stream_hook_result(result)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_on_text_delta_hooks(
|
||||
&self,
|
||||
index: usize,
|
||||
delta: String,
|
||||
) -> Result<(), WorkerError> {
|
||||
let mut context = TextDeltaContext { index, delta };
|
||||
for hook in &self.hooks.on_text_delta {
|
||||
let result = hook.call(&mut context).await?;
|
||||
Self::apply_stream_hook_result(result)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_on_tool_call_delta_hooks(
|
||||
&self,
|
||||
index: usize,
|
||||
delta_json_fragment: String,
|
||||
) -> Result<(), WorkerError> {
|
||||
let mut context = ToolCallDeltaContext {
|
||||
index,
|
||||
delta_json_fragment,
|
||||
};
|
||||
for hook in &self.hooks.on_tool_call_delta {
|
||||
let result = hook.call(&mut context).await?;
|
||||
Self::apply_stream_hook_result(result)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_on_stream_complete_hooks(
|
||||
&self,
|
||||
turn: usize,
|
||||
event_count: usize,
|
||||
) -> Result<(), WorkerError> {
|
||||
let mut context = StreamCompleteContext { turn, event_count };
|
||||
for hook in &self.hooks.on_stream_complete {
|
||||
let result = hook.call(&mut context).await?;
|
||||
Self::apply_stream_hook_result(result)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn finalize_interruption<T>(
|
||||
&mut self,
|
||||
result: Result<T, WorkerError>,
|
||||
|
|
@ -716,10 +534,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
WorkerError::Cancelled => "Cancelled".to_string(),
|
||||
_ => err.to_string(),
|
||||
};
|
||||
if let Err(hook_err) = self.run_on_abort_hooks(&reason).await {
|
||||
self.last_run_interrupted = true;
|
||||
return Err(hook_err);
|
||||
}
|
||||
self.interceptor.on_abort(&reason).await;
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
|
|
@ -780,59 +595,41 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
// Retained because it's needed for PostToolCall hooks
|
||||
let mut call_info_map = HashMap::new();
|
||||
|
||||
// Phase 1: Apply pre_tool_call hooks (determine skip/abort)
|
||||
// Phase 1: Apply pre_tool_call interceptor (determine skip/abort)
|
||||
let mut approved_calls = Vec::new();
|
||||
for mut tool_call in tool_calls {
|
||||
// Get tool definition
|
||||
if let Some((meta, tool)) = self.tool_server.get_tool(&tool_call.name) {
|
||||
// Create context
|
||||
let mut context = ToolCallContext {
|
||||
let mut info = ToolCallInfo {
|
||||
call: tool_call.clone(),
|
||||
meta,
|
||||
tool,
|
||||
};
|
||||
|
||||
let mut skip = false;
|
||||
for hook in &self.hooks.pre_tool_call {
|
||||
let result = hook
|
||||
.call(&mut context)
|
||||
.await
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||
match result {
|
||||
PreToolCallResult::Continue => {}
|
||||
PreToolCallResult::Skip => {
|
||||
skip = true;
|
||||
break;
|
||||
}
|
||||
PreToolCallResult::Abort(reason) => {
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
PreToolCallResult::Pause => {
|
||||
self.last_run_interrupted = true;
|
||||
return Ok(ToolExecutionResult::Paused);
|
||||
}
|
||||
match self.interceptor.pre_tool_call(&mut info).await {
|
||||
PreToolAction::Continue => {}
|
||||
PreToolAction::Skip => {
|
||||
continue;
|
||||
}
|
||||
PreToolAction::Abort(reason) => {
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
PreToolAction::Pause => {
|
||||
self.last_run_interrupted = true;
|
||||
return Ok(ToolExecutionResult::Paused);
|
||||
}
|
||||
}
|
||||
|
||||
// Reflect changes made by hooks
|
||||
tool_call = context.call;
|
||||
// Reflect changes made by interceptor
|
||||
tool_call = info.call;
|
||||
|
||||
// Save to map (only if executing)
|
||||
if !skip {
|
||||
call_info_map.insert(
|
||||
tool_call.id.clone(),
|
||||
(
|
||||
tool_call.clone(),
|
||||
context.meta.clone(),
|
||||
context.tool.clone(),
|
||||
),
|
||||
);
|
||||
approved_calls.push(tool_call);
|
||||
}
|
||||
call_info_map.insert(
|
||||
tool_call.id.clone(),
|
||||
(tool_call.clone(), info.meta.clone(), info.tool.clone()),
|
||||
);
|
||||
approved_calls.push(tool_call);
|
||||
} else {
|
||||
// Unknown tools go into approved list as-is (will error at execution)
|
||||
// Hooks are not applied (no Meta available)
|
||||
approved_calls.push(tool_call);
|
||||
}
|
||||
}
|
||||
|
|
@ -879,32 +676,25 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
}
|
||||
|
||||
// Phase 3: Apply post_tool_call hooks
|
||||
// Phase 3: Apply post_tool_call interceptor
|
||||
for tool_result in &mut results {
|
||||
// Get saved information
|
||||
if let Some((tool_call, meta, tool)) = call_info_map.get(&tool_result.tool_use_id) {
|
||||
let mut context = PostToolCallContext {
|
||||
let mut info = ToolResultInfo {
|
||||
call: tool_call.clone(),
|
||||
result: tool_result.clone(),
|
||||
meta: meta.clone(),
|
||||
tool: tool.clone(),
|
||||
};
|
||||
|
||||
for hook in &self.hooks.post_tool_call {
|
||||
let result = hook
|
||||
.call(&mut context)
|
||||
.await
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||
match result {
|
||||
PostToolCallResult::Continue => {}
|
||||
PostToolCallResult::Abort(reason) => {
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
match self.interceptor.post_tool_call(&mut info).await {
|
||||
PostToolAction::Continue => {}
|
||||
PostToolAction::Abort(reason) => {
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
}
|
||||
// Reflect hook-modified results
|
||||
*tool_result = context.result;
|
||||
// Reflect interceptor-modified results
|
||||
*tool_result = info.result;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -963,21 +753,18 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
cb(current_turn);
|
||||
}
|
||||
|
||||
// Hook: pre_llm_request
|
||||
let (control, request_context) = self
|
||||
.run_pre_llm_request_hooks()
|
||||
.await
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||
match control {
|
||||
PreLlmRequestResult::Cancel(reason) => {
|
||||
info!(reason = %reason, "Aborted by hook");
|
||||
// Interceptor: pre_llm_request
|
||||
let mut request_context = self.history.clone();
|
||||
match self.interceptor.pre_llm_request(&mut request_context).await {
|
||||
PreRequestAction::Cancel(reason) => {
|
||||
info!(reason = %reason, "Aborted by interceptor");
|
||||
for cb in &self.turn_end_cbs {
|
||||
cb(current_turn);
|
||||
}
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
PreLlmRequestResult::Continue => {}
|
||||
PreRequestAction::Continue => {}
|
||||
}
|
||||
|
||||
// Build request
|
||||
|
|
@ -1025,26 +812,6 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
let event = result
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||
self.timeline.dispatch(&event);
|
||||
|
||||
self.run_on_stream_chunk_hooks(event.clone())
|
||||
.await
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||
|
||||
if let crate::llm_client::event::Event::BlockDelta(delta) = &event {
|
||||
match &delta.delta {
|
||||
crate::llm_client::event::DeltaContent::Text(text) => {
|
||||
self.run_on_text_delta_hooks(delta.index, text.clone())
|
||||
.await
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||
}
|
||||
crate::llm_client::event::DeltaContent::InputJson(json_fragment) => {
|
||||
self.run_on_tool_call_delta_hooks(delta.index, json_fragment.clone())
|
||||
.await
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||
}
|
||||
crate::llm_client::event::DeltaContent::Thinking(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
None => break, // Stream ended
|
||||
}
|
||||
|
|
@ -1060,9 +827,6 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
}
|
||||
}
|
||||
self.run_on_stream_complete_hooks(current_turn, event_count)
|
||||
.await
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||
debug!(event_count = event_count, "Stream completed");
|
||||
|
||||
// Notify turn end
|
||||
|
|
@ -1080,21 +844,17 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
self.history.extend(assistant_items);
|
||||
|
||||
if tool_calls.is_empty() {
|
||||
// No tool calls → determine turn end
|
||||
let turn_result = self
|
||||
.run_on_turn_end_hooks()
|
||||
.await
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||
match turn_result {
|
||||
OnTurnEndResult::Finish => {
|
||||
// No tool calls → determine turn end via interceptor
|
||||
match self.interceptor.on_turn_end(&self.history).await {
|
||||
TurnEndAction::Finish => {
|
||||
self.last_run_interrupted = false;
|
||||
return Ok(WorkerResult::Finished);
|
||||
}
|
||||
OnTurnEndResult::ContinueWithMessages(additional) => {
|
||||
TurnEndAction::ContinueWithMessages(additional) => {
|
||||
self.history.extend(additional);
|
||||
continue;
|
||||
}
|
||||
OnTurnEndResult::Paused => {
|
||||
TurnEndAction::Pause => {
|
||||
self.last_run_interrupted = true;
|
||||
return Ok(WorkerResult::Paused);
|
||||
}
|
||||
|
|
@ -1164,7 +924,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
text_block_collector,
|
||||
tool_call_collector,
|
||||
tool_server: ToolServer::new().handle(),
|
||||
hooks: HookRegistry::new(),
|
||||
interceptor: Box::new(DefaultInterceptor),
|
||||
system_prompt: None,
|
||||
history: Vec::new(),
|
||||
locked_prefix_len: 0,
|
||||
|
|
@ -1400,7 +1160,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
text_block_collector: self.text_block_collector,
|
||||
tool_call_collector: self.tool_call_collector,
|
||||
tool_server: self.tool_server,
|
||||
hooks: self.hooks,
|
||||
interceptor: self.interceptor,
|
||||
system_prompt: self.system_prompt,
|
||||
history: self.history,
|
||||
locked_prefix_len,
|
||||
|
|
@ -1439,7 +1199,7 @@ impl<C: LlmClient> Worker<C, CacheLocked> {
|
|||
text_block_collector: self.text_block_collector,
|
||||
tool_call_collector: self.tool_call_collector,
|
||||
tool_server: self.tool_server,
|
||||
hooks: self.hooks,
|
||||
interceptor: self.interceptor,
|
||||
system_prompt: self.system_prompt,
|
||||
history: self.history,
|
||||
locked_prefix_len: 0,
|
||||
|
|
|
|||
|
|
@ -8,11 +8,8 @@ use std::time::{Duration, Instant};
|
|||
|
||||
use async_trait::async_trait;
|
||||
use llm_worker::Worker;
|
||||
use llm_worker::hook::{
|
||||
Hook, HookError, PostToolCall, PostToolCallContext, PostToolCallResult, PreToolCall,
|
||||
PreToolCallResult, ToolCallContext,
|
||||
};
|
||||
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
|
||||
use llm_worker::interceptor::{Interceptor, PostToolAction, PreToolAction, ToolCallInfo, ToolResultInfo};
|
||||
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
||||
|
||||
mod common;
|
||||
|
|
@ -156,21 +153,21 @@ async fn test_before_tool_call_skip() {
|
|||
worker.register_tool(allowed_tool.definition()).unwrap();
|
||||
worker.register_tool(blocked_tool.definition()).unwrap();
|
||||
|
||||
// Hook to skip "blocked_tool"
|
||||
struct BlockingHook;
|
||||
// Policy to skip "blocked_tool"
|
||||
struct BlockingPolicy;
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<PreToolCall> for BlockingHook {
|
||||
async fn call(&self, ctx: &mut ToolCallContext) -> Result<PreToolCallResult, HookError> {
|
||||
if ctx.call.name == "blocked_tool" {
|
||||
Ok(PreToolCallResult::Skip)
|
||||
impl Interceptor for BlockingPolicy {
|
||||
async fn pre_tool_call(&self, info: &mut ToolCallInfo) -> PreToolAction {
|
||||
if info.call.name == "blocked_tool" {
|
||||
PreToolAction::Skip
|
||||
} else {
|
||||
Ok(PreToolCallResult::Continue)
|
||||
PreToolAction::Continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
worker.add_pre_tool_call_hook(BlockingHook);
|
||||
worker.set_interceptor(BlockingPolicy);
|
||||
|
||||
let _result = worker.run("Test hook").await;
|
||||
|
||||
|
|
@ -235,25 +232,22 @@ async fn test_post_tool_call_modification() {
|
|||
|
||||
worker.register_tool(simple_tool_definition()).unwrap();
|
||||
|
||||
// Hook to modify results
|
||||
struct ModifyingHook {
|
||||
// Policy to modify results
|
||||
struct ModifyingPolicy {
|
||||
modified_content: Arc<std::sync::Mutex<Option<String>>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<PostToolCall> for ModifyingHook {
|
||||
async fn call(
|
||||
&self,
|
||||
ctx: &mut PostToolCallContext,
|
||||
) -> Result<PostToolCallResult, HookError> {
|
||||
ctx.result.content = format!("[Modified] {}", ctx.result.content);
|
||||
*self.modified_content.lock().unwrap() = Some(ctx.result.content.clone());
|
||||
Ok(PostToolCallResult::Continue)
|
||||
impl Interceptor for ModifyingPolicy {
|
||||
async fn post_tool_call(&self, info: &mut ToolResultInfo) -> PostToolAction {
|
||||
info.result.content = format!("[Modified] {}", info.result.content);
|
||||
*self.modified_content.lock().unwrap() = Some(info.result.content.clone());
|
||||
PostToolAction::Continue
|
||||
}
|
||||
}
|
||||
|
||||
let modified_content = Arc::new(std::sync::Mutex::new(None));
|
||||
worker.add_post_tool_call_hook(ModifyingHook {
|
||||
worker.set_interceptor(ModifyingPolicy {
|
||||
modified_content: modified_content.clone(),
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1,194 +0,0 @@
|
|||
//! Streaming hook tests
|
||||
|
||||
mod common;
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use common::MockLlmClient;
|
||||
use llm_worker::hook::{
|
||||
Hook, HookError, OnStreamChunk, OnStreamComplete, OnTextDelta, OnToolCallDelta,
|
||||
StreamChunkContext, StreamCompleteContext, StreamHookResult, TextDeltaContext,
|
||||
ToolCallDeltaContext,
|
||||
};
|
||||
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
|
||||
use llm_worker::{Worker, WorkerError};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_delta_hooks_run_in_registration_order() {
|
||||
let events = vec![
|
||||
Event::text_block_start(0),
|
||||
Event::text_delta(0, "A"),
|
||||
Event::text_delta(0, "B"),
|
||||
Event::text_block_stop(0, None),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
];
|
||||
|
||||
let client = MockLlmClient::new(events);
|
||||
let mut worker = Worker::new(client);
|
||||
|
||||
struct RecorderHook {
|
||||
label: &'static str,
|
||||
records: Arc<Mutex<Vec<String>>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<OnTextDelta> for RecorderHook {
|
||||
async fn call(&self, input: &mut TextDeltaContext) -> Result<StreamHookResult, HookError> {
|
||||
self.records
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push(format!("{}:{}", self.label, input.delta));
|
||||
Ok(StreamHookResult::Continue)
|
||||
}
|
||||
}
|
||||
|
||||
let records = Arc::new(Mutex::new(Vec::new()));
|
||||
worker.add_on_text_delta_hook(RecorderHook {
|
||||
label: "first",
|
||||
records: records.clone(),
|
||||
});
|
||||
worker.add_on_text_delta_hook(RecorderHook {
|
||||
label: "second",
|
||||
records: records.clone(),
|
||||
});
|
||||
|
||||
let result = worker.run("hello").await;
|
||||
assert!(result.is_ok(), "run should succeed: {result:?}");
|
||||
|
||||
let got = records.lock().unwrap().clone();
|
||||
assert_eq!(
|
||||
got,
|
||||
vec![
|
||||
"first:A".to_string(),
|
||||
"second:A".to_string(),
|
||||
"first:B".to_string(),
|
||||
"second:B".to_string(),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stream_chunk_and_stream_complete_hooks_are_called() {
|
||||
let events = vec![
|
||||
Event::ping(),
|
||||
Event::text_block_start(0),
|
||||
Event::text_delta(0, "hi"),
|
||||
Event::text_block_stop(0, None),
|
||||
Event::usage(10, 5),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
];
|
||||
|
||||
let client = MockLlmClient::new(events);
|
||||
let mut worker = Worker::new(client);
|
||||
|
||||
struct ChunkCounter(Arc<Mutex<usize>>);
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<OnStreamChunk> for ChunkCounter {
|
||||
async fn call(
|
||||
&self,
|
||||
_input: &mut StreamChunkContext,
|
||||
) -> Result<StreamHookResult, HookError> {
|
||||
let mut guard = self.0.lock().unwrap();
|
||||
*guard += 1;
|
||||
Ok(StreamHookResult::Continue)
|
||||
}
|
||||
}
|
||||
|
||||
struct CompleteRecorder(Arc<Mutex<Vec<(usize, usize)>>>);
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<OnStreamComplete> for CompleteRecorder {
|
||||
async fn call(
|
||||
&self,
|
||||
input: &mut StreamCompleteContext,
|
||||
) -> Result<StreamHookResult, HookError> {
|
||||
self.0.lock().unwrap().push((input.turn, input.event_count));
|
||||
Ok(StreamHookResult::Continue)
|
||||
}
|
||||
}
|
||||
|
||||
let chunk_count = Arc::new(Mutex::new(0usize));
|
||||
let completes = Arc::new(Mutex::new(Vec::new()));
|
||||
|
||||
worker.add_on_stream_chunk_hook(ChunkCounter(chunk_count.clone()));
|
||||
worker.add_on_stream_complete_hook(CompleteRecorder(completes.clone()));
|
||||
|
||||
let result = worker.run("hello").await;
|
||||
assert!(result.is_ok(), "run should succeed: {result:?}");
|
||||
|
||||
assert_eq!(*chunk_count.lock().unwrap(), 6);
|
||||
assert_eq!(completes.lock().unwrap().as_slice(), &[(0usize, 6usize)]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tool_call_delta_hook_can_abort_run() {
|
||||
let events = vec![
|
||||
Event::tool_use_start(0, "call_1", "unknown_tool"),
|
||||
Event::tool_input_delta(0, r#"{"x":1}"#),
|
||||
Event::tool_use_stop(0),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
];
|
||||
|
||||
let client = MockLlmClient::new(events);
|
||||
let mut worker = Worker::new(client);
|
||||
|
||||
struct AbortToolDelta;
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<OnToolCallDelta> for AbortToolDelta {
|
||||
async fn call(
|
||||
&self,
|
||||
_input: &mut ToolCallDeltaContext,
|
||||
) -> Result<StreamHookResult, HookError> {
|
||||
Ok(StreamHookResult::Abort("blocked by tool delta".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
worker.add_on_tool_call_delta_hook(AbortToolDelta);
|
||||
|
||||
let result = worker.run("hello").await;
|
||||
match result {
|
||||
Err(WorkerError::Aborted(reason)) => assert_eq!(reason, "blocked by tool delta"),
|
||||
other => panic!("expected aborted result, got: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stream_hook_pause_is_mapped_to_aborted() {
|
||||
let events = vec![
|
||||
Event::text_block_start(0),
|
||||
Event::text_delta(0, "pause me"),
|
||||
Event::text_block_stop(0, None),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
];
|
||||
|
||||
let client = MockLlmClient::new(events);
|
||||
let mut worker = Worker::new(client);
|
||||
|
||||
struct PauseHook;
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<OnTextDelta> for PauseHook {
|
||||
async fn call(&self, _input: &mut TextDeltaContext) -> Result<StreamHookResult, HookError> {
|
||||
Ok(StreamHookResult::Pause)
|
||||
}
|
||||
}
|
||||
|
||||
worker.add_on_text_delta_hook(PauseHook);
|
||||
|
||||
let result = worker.run("hello").await;
|
||||
match result {
|
||||
Err(WorkerError::Aborted(reason)) => assert_eq!(reason, "Paused by stream hook"),
|
||||
other => panic!("expected aborted result, got: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
|
@ -5,6 +5,7 @@ edition.workspace = true
|
|||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.89"
|
||||
clap = { version = "4.6.0", features = ["derive"] }
|
||||
llm-worker = { version = "0.2.1", path = "../llm-worker" }
|
||||
llm-worker-persistence = { version = "0.1.0", path = "../llm-worker-persistence" }
|
||||
|
|
|
|||
160
crates/pod/src/hook.rs
Normal file
160
crates/pod/src/hook.rs
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
//! Pod-layer hook infrastructure
|
||||
//!
|
||||
//! Provides the `Hook<E>` trait and `HookRegistry` for orchestration hooks
|
||||
//! that govern control-flow decisions in the Worker execution loop.
|
||||
//!
|
||||
//! The type system (`HookEventKind` / `Hook<E>`) mirrors the pattern
|
||||
//! originally in llm-worker, now at the insomnia layer where orchestration
|
||||
//! concerns belong.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use llm_worker::interceptor::{
|
||||
PostToolAction, PreRequestAction, PreToolAction, PromptAction, ToolCallInfo, ToolResultInfo,
|
||||
TurnEndAction,
|
||||
};
|
||||
use llm_worker::Item;
|
||||
|
||||
// =============================================================================
|
||||
// Hook Event Kinds
|
||||
// =============================================================================
|
||||
|
||||
/// Marker trait for hook event kinds.
|
||||
///
|
||||
/// Each event kind specifies its input (passed mutably to hooks) and
|
||||
/// output (the control-flow action returned by hooks).
|
||||
pub trait HookEventKind: Send + Sync + 'static {
|
||||
/// Mutable input passed to the hook.
|
||||
type Input;
|
||||
/// Control-flow action returned by the hook.
|
||||
type Output;
|
||||
}
|
||||
|
||||
// --- Event kind markers ---
|
||||
|
||||
/// After receiving user input, before adding to history.
|
||||
pub struct OnPromptSubmit;
|
||||
/// Before each LLM request.
|
||||
pub struct PreLlmRequest;
|
||||
/// Before each tool is executed.
|
||||
pub struct PreToolCall;
|
||||
/// After each tool completes.
|
||||
pub struct PostToolCall;
|
||||
/// When a turn ends with no tool calls.
|
||||
pub struct OnTurnEnd;
|
||||
/// When execution is interrupted.
|
||||
pub struct OnAbort;
|
||||
|
||||
impl HookEventKind for OnPromptSubmit {
|
||||
type Input = Item;
|
||||
type Output = PromptAction;
|
||||
}
|
||||
|
||||
impl HookEventKind for PreLlmRequest {
|
||||
type Input = Vec<Item>;
|
||||
type Output = PreRequestAction;
|
||||
}
|
||||
|
||||
impl HookEventKind for PreToolCall {
|
||||
type Input = ToolCallInfo;
|
||||
type Output = PreToolAction;
|
||||
}
|
||||
|
||||
impl HookEventKind for PostToolCall {
|
||||
type Input = ToolResultInfo;
|
||||
type Output = PostToolAction;
|
||||
}
|
||||
|
||||
impl HookEventKind for OnTurnEnd {
|
||||
type Input = Vec<Item>;
|
||||
type Output = TurnEndAction;
|
||||
}
|
||||
|
||||
impl HookEventKind for OnAbort {
|
||||
type Input = String;
|
||||
type Output = ();
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Hook Trait
|
||||
// =============================================================================
|
||||
|
||||
/// Async hook for a specific event kind.
|
||||
///
|
||||
/// Hooks receive mutable access to the event's input and return a
|
||||
/// control-flow action. Multiple hooks can be registered per event;
|
||||
/// they are evaluated in registration order and short-circuit on the
|
||||
/// first non-Continue result.
|
||||
#[async_trait]
|
||||
pub trait Hook<E: HookEventKind>: Send + Sync {
|
||||
async fn call(&self, input: &mut E::Input) -> E::Output;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Hook Registry
|
||||
// =============================================================================
|
||||
|
||||
/// Builder for constructing a frozen `HookRegistry`.
|
||||
///
|
||||
/// Hooks are added during setup, then `build()` produces an immutable
|
||||
/// registry that can be shared via `Arc`.
|
||||
#[derive(Default)]
|
||||
pub struct HookRegistryBuilder {
|
||||
on_prompt_submit: Vec<Box<dyn Hook<OnPromptSubmit>>>,
|
||||
pre_llm_request: Vec<Box<dyn Hook<PreLlmRequest>>>,
|
||||
pre_tool_call: Vec<Box<dyn Hook<PreToolCall>>>,
|
||||
post_tool_call: Vec<Box<dyn Hook<PostToolCall>>>,
|
||||
on_turn_end: Vec<Box<dyn Hook<OnTurnEnd>>>,
|
||||
on_abort: Vec<Box<dyn Hook<OnAbort>>>,
|
||||
}
|
||||
|
||||
impl HookRegistryBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn add_on_prompt_submit(&mut self, hook: impl Hook<OnPromptSubmit> + 'static) {
|
||||
self.on_prompt_submit.push(Box::new(hook));
|
||||
}
|
||||
|
||||
pub fn add_pre_llm_request(&mut self, hook: impl Hook<PreLlmRequest> + 'static) {
|
||||
self.pre_llm_request.push(Box::new(hook));
|
||||
}
|
||||
|
||||
pub fn add_pre_tool_call(&mut self, hook: impl Hook<PreToolCall> + 'static) {
|
||||
self.pre_tool_call.push(Box::new(hook));
|
||||
}
|
||||
|
||||
pub fn add_post_tool_call(&mut self, hook: impl Hook<PostToolCall> + 'static) {
|
||||
self.post_tool_call.push(Box::new(hook));
|
||||
}
|
||||
|
||||
pub fn add_on_turn_end(&mut self, hook: impl Hook<OnTurnEnd> + 'static) {
|
||||
self.on_turn_end.push(Box::new(hook));
|
||||
}
|
||||
|
||||
pub fn add_on_abort(&mut self, hook: impl Hook<OnAbort> + 'static) {
|
||||
self.on_abort.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// Freeze the builder into an immutable registry.
|
||||
pub fn build(self) -> HookRegistry {
|
||||
HookRegistry {
|
||||
on_prompt_submit: self.on_prompt_submit,
|
||||
pre_llm_request: self.pre_llm_request,
|
||||
pre_tool_call: self.pre_tool_call,
|
||||
post_tool_call: self.post_tool_call,
|
||||
on_turn_end: self.on_turn_end,
|
||||
on_abort: self.on_abort,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Frozen registry of hooks. Constructed via [`HookRegistryBuilder::build()`].
|
||||
pub struct HookRegistry {
|
||||
pub(crate) on_prompt_submit: Vec<Box<dyn Hook<OnPromptSubmit>>>,
|
||||
pub(crate) pre_llm_request: Vec<Box<dyn Hook<PreLlmRequest>>>,
|
||||
pub(crate) pre_tool_call: Vec<Box<dyn Hook<PreToolCall>>>,
|
||||
pub(crate) post_tool_call: Vec<Box<dyn Hook<PostToolCall>>>,
|
||||
pub(crate) on_turn_end: Vec<Box<dyn Hook<OnTurnEnd>>>,
|
||||
pub(crate) on_abort: Vec<Box<dyn Hook<OnAbort>>>,
|
||||
}
|
||||
87
crates/pod/src/hook_interceptor.rs
Normal file
87
crates/pod/src/hook_interceptor.rs
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
//! HookInterceptor — bridges Pod-layer hooks to Worker's Interceptor trait.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use llm_worker::interceptor::{
|
||||
Interceptor, PostToolAction, PreRequestAction, PreToolAction, PromptAction, ToolCallInfo,
|
||||
ToolResultInfo, TurnEndAction,
|
||||
};
|
||||
use llm_worker::Item;
|
||||
|
||||
use crate::hook::HookRegistry;
|
||||
|
||||
/// An `Interceptor` implementation that delegates to a `HookRegistry`.
|
||||
///
|
||||
/// Each method iterates the registered hooks in order and short-circuits
|
||||
/// on the first non-Continue (or non-Finish) result.
|
||||
pub(crate) struct HookInterceptor {
|
||||
registry: Arc<HookRegistry>,
|
||||
}
|
||||
|
||||
impl HookInterceptor {
|
||||
pub(crate) fn new(registry: Arc<HookRegistry>) -> Self {
|
||||
Self { registry }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Interceptor for HookInterceptor {
|
||||
async fn on_prompt_submit(&self, item: &mut Item) -> PromptAction {
|
||||
for hook in &self.registry.on_prompt_submit {
|
||||
let action = hook.call(item).await;
|
||||
if !matches!(action, PromptAction::Continue) {
|
||||
return action;
|
||||
}
|
||||
}
|
||||
PromptAction::Continue
|
||||
}
|
||||
|
||||
async fn pre_llm_request(&self, context: &mut Vec<Item>) -> PreRequestAction {
|
||||
for hook in &self.registry.pre_llm_request {
|
||||
let action = hook.call(context).await;
|
||||
if !matches!(action, PreRequestAction::Continue) {
|
||||
return action;
|
||||
}
|
||||
}
|
||||
PreRequestAction::Continue
|
||||
}
|
||||
|
||||
async fn pre_tool_call(&self, info: &mut ToolCallInfo) -> PreToolAction {
|
||||
for hook in &self.registry.pre_tool_call {
|
||||
let action = hook.call(info).await;
|
||||
if !matches!(action, PreToolAction::Continue) {
|
||||
return action;
|
||||
}
|
||||
}
|
||||
PreToolAction::Continue
|
||||
}
|
||||
|
||||
async fn post_tool_call(&self, info: &mut ToolResultInfo) -> PostToolAction {
|
||||
for hook in &self.registry.post_tool_call {
|
||||
let action = hook.call(info).await;
|
||||
if !matches!(action, PostToolAction::Continue) {
|
||||
return action;
|
||||
}
|
||||
}
|
||||
PostToolAction::Continue
|
||||
}
|
||||
|
||||
async fn on_turn_end(&self, history: &[Item]) -> TurnEndAction {
|
||||
let mut history_vec = history.to_vec();
|
||||
for hook in &self.registry.on_turn_end {
|
||||
let action = hook.call(&mut history_vec).await;
|
||||
if !matches!(action, TurnEndAction::Finish) {
|
||||
return action;
|
||||
}
|
||||
}
|
||||
TurnEndAction::Finish
|
||||
}
|
||||
|
||||
async fn on_abort(&self, reason: &str) {
|
||||
let mut reason_string = reason.to_string();
|
||||
for hook in &self.registry.on_abort {
|
||||
hook.call(&mut reason_string).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,12 +1,15 @@
|
|||
pub mod controller;
|
||||
pub mod hook;
|
||||
pub mod runtime_dir;
|
||||
pub mod shared_state;
|
||||
pub mod socket_server;
|
||||
|
||||
mod hook_interceptor;
|
||||
mod pod;
|
||||
|
||||
pub use controller::{PodController, PodHandle};
|
||||
pub use manifest::{PodManifest, ProviderConfig, ProviderKind, Scope};
|
||||
pub use hook::{Hook, HookEventKind, HookRegistryBuilder};
|
||||
pub use pod::{Pod, PodError, PodRunResult, apply_worker_manifest};
|
||||
pub use protocol::{ErrorCode, Event, Method, TurnResult};
|
||||
pub use provider::{ProviderError, build_client};
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use llm_worker::llm_client::client::LlmClient;
|
||||
use llm_worker::llm_client::RequestConfig;
|
||||
use llm_worker::Worker;
|
||||
|
|
@ -7,6 +9,12 @@ use llm_worker_persistence::{
|
|||
|
||||
use manifest::{PodManifest, Scope, WorkerManifest};
|
||||
|
||||
use crate::hook::{
|
||||
Hook, HookRegistryBuilder, OnAbort, OnPromptSubmit, OnTurnEnd, PostToolCall, PreLlmRequest,
|
||||
PreToolCall,
|
||||
};
|
||||
use crate::hook_interceptor::HookInterceptor;
|
||||
|
||||
/// An independent agent execution unit.
|
||||
///
|
||||
/// Wraps a persistent [`Session`] with manifest metadata and an optional
|
||||
|
|
@ -15,6 +23,8 @@ pub struct Pod<C: LlmClient, St: Store> {
|
|||
manifest: PodManifest,
|
||||
session: Session<C, St>,
|
||||
scope: Option<Scope>,
|
||||
hook_builder: HookRegistryBuilder,
|
||||
interceptor_installed: bool,
|
||||
}
|
||||
|
||||
impl<C: LlmClient, St: Store> Pod<C, St> {
|
||||
|
|
@ -34,6 +44,8 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
|||
manifest,
|
||||
session,
|
||||
scope,
|
||||
hook_builder: HookRegistryBuilder::new(),
|
||||
interceptor_installed: false,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -50,6 +62,8 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
|||
manifest,
|
||||
session,
|
||||
scope,
|
||||
hook_builder: HookRegistryBuilder::new(),
|
||||
interceptor_installed: false,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -76,14 +90,75 @@ impl<C: LlmClient, St: Store> Pod<C, St> {
|
|||
&mut self.session
|
||||
}
|
||||
|
||||
// --- Hook registration ---
|
||||
//
|
||||
// Hooks must be registered before the first call to `run()` or `resume()`.
|
||||
// Attempting to add a hook after execution has started will panic.
|
||||
|
||||
fn assert_hooks_open(&self) {
|
||||
assert!(
|
||||
!self.interceptor_installed,
|
||||
"cannot add hooks after run() or resume() has been called"
|
||||
);
|
||||
}
|
||||
|
||||
/// Register a hook that runs after receiving user input.
|
||||
pub fn add_on_prompt_submit_hook(&mut self, hook: impl Hook<OnPromptSubmit> + 'static) {
|
||||
self.assert_hooks_open();
|
||||
self.hook_builder.add_on_prompt_submit(hook);
|
||||
}
|
||||
|
||||
/// Register a hook that runs before each LLM request.
|
||||
pub fn add_pre_llm_request_hook(&mut self, hook: impl Hook<PreLlmRequest> + 'static) {
|
||||
self.assert_hooks_open();
|
||||
self.hook_builder.add_pre_llm_request(hook);
|
||||
}
|
||||
|
||||
/// Register a hook that runs before each tool call.
|
||||
pub fn add_pre_tool_call_hook(&mut self, hook: impl Hook<PreToolCall> + 'static) {
|
||||
self.assert_hooks_open();
|
||||
self.hook_builder.add_pre_tool_call(hook);
|
||||
}
|
||||
|
||||
/// Register a hook that runs after each tool call.
|
||||
pub fn add_post_tool_call_hook(&mut self, hook: impl Hook<PostToolCall> + 'static) {
|
||||
self.assert_hooks_open();
|
||||
self.hook_builder.add_post_tool_call(hook);
|
||||
}
|
||||
|
||||
/// Register a hook that runs at the end of a turn.
|
||||
pub fn add_on_turn_end_hook(&mut self, hook: impl Hook<OnTurnEnd> + 'static) {
|
||||
self.assert_hooks_open();
|
||||
self.hook_builder.add_on_turn_end(hook);
|
||||
}
|
||||
|
||||
/// Register a hook that runs when execution is aborted.
|
||||
pub fn add_on_abort_hook(&mut self, hook: impl Hook<OnAbort> + 'static) {
|
||||
self.assert_hooks_open();
|
||||
self.hook_builder.add_on_abort(hook);
|
||||
}
|
||||
|
||||
/// Install the hook-based interceptor on the Worker if not already done.
|
||||
fn ensure_interceptor_installed(&mut self) {
|
||||
if !self.interceptor_installed {
|
||||
let builder = std::mem::take(&mut self.hook_builder);
|
||||
let registry = Arc::new(builder.build());
|
||||
let interceptor = HookInterceptor::new(registry);
|
||||
self.session.worker.set_interceptor(interceptor);
|
||||
self.interceptor_installed = true;
|
||||
}
|
||||
}
|
||||
|
||||
/// Send user input and run until the LLM turn completes.
|
||||
pub async fn run(&mut self, input: impl Into<String>) -> Result<PodRunResult, PodError> {
|
||||
self.ensure_interceptor_installed();
|
||||
let result = self.session.run(input).await?;
|
||||
Ok(result.into())
|
||||
}
|
||||
|
||||
/// Resume from a paused state.
|
||||
pub async fn resume(&mut self) -> Result<PodRunResult, PodError> {
|
||||
self.ensure_interceptor_installed();
|
||||
let result = self.session.resume().await?;
|
||||
Ok(result.into())
|
||||
}
|
||||
|
|
@ -107,6 +182,8 @@ impl<St: Store> Pod<Box<dyn LlmClient>, St> {
|
|||
manifest,
|
||||
session,
|
||||
scope,
|
||||
hook_builder: HookRegistryBuilder::new(),
|
||||
interceptor_installed: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
44
tickets/api-key-file.md
Normal file
44
tickets/api-key-file.md
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
# api_key_file: ファイルパスによるAPIキー解決
|
||||
|
||||
## 背景
|
||||
|
||||
現状、APIキーの取得手段は `api_key_env`(環境変数名の指定)のみ。
|
||||
永続化やインタラクティブ入力の仕組みがなく、キー管理をユーザーのシェル環境に完全依存している。
|
||||
|
||||
## やること
|
||||
|
||||
マニフェストの `ProviderConfig` に `api_key_file: Option<PathBuf>` を追加し、ファイルからAPIキーを読み取れるようにする。
|
||||
|
||||
### マニフェスト
|
||||
|
||||
```toml
|
||||
[provider]
|
||||
kind = "anthropic"
|
||||
model = "claude-sonnet-4-20250514"
|
||||
api_key_file = "~/.config/insomnia/keys/anthropic"
|
||||
```
|
||||
|
||||
- ファイルにはキーのみを記載(読み込み時に trim)
|
||||
- `~` 展開が必要
|
||||
- 相対パスはマニフェストファイルの位置基準
|
||||
|
||||
### api_key_env との関係
|
||||
|
||||
- 排他。両方指定されたらエラー
|
||||
- Ollama は両方不要のまま
|
||||
|
||||
### 変更箇所
|
||||
|
||||
1. **manifest**: `ProviderConfig` に `api_key_file: Option<PathBuf>` を追加
|
||||
2. **provider**: `build_client()` でファイル読み取りロジックを追加。排他バリデーション
|
||||
3. **provider**: `ProviderError` にキー不在を明示するバリアント追加(将来の TUI フォールバック用)
|
||||
|
||||
### 暗号化について
|
||||
|
||||
現段階では扱わない。ファイルパーミッション(0600)で十分。
|
||||
将来エンドユーザー向けに暗号化が必要になった場合、provider の手前に復号レイヤーを挟む形で対応できる。`api_key_file` の設計自体は変更不要。
|
||||
|
||||
### 将来の拡張
|
||||
|
||||
- TUI サブコマンド(`insomnia key set anthropic` 等)がこのファイルに書き込むラッパーになる
|
||||
- `api_key_cmd`(コマンド実行によるキー取得)は `api_key_file` で不足が生じた時点で検討
|
||||
|
|
@ -64,4 +64,4 @@ history 内のツール出力を走査:
|
|||
|
||||
## 依存チケット
|
||||
|
||||
- [remove-hook-module.md](remove-hook-module.md) — Hook が insomnia 層に移動した後、PreLlmRequest で Prune を差し込む
|
||||
- ~~[remove-hook-module.md](remove-hook-module.md)~~ — 完了。PreLlmRequest は Pod 層の `hook::Hook<PreLlmRequest>` として利用可能
|
||||
|
|
|
|||
|
|
@ -51,4 +51,4 @@ action = "allow"
|
|||
|
||||
## 依存チケット
|
||||
|
||||
- [remove-hook-module.md](remove-hook-module.md) — PreToolCall が insomnia 層に移動してから実装
|
||||
- ~~[remove-hook-module.md](remove-hook-module.md)~~ — 完了。PreToolCall は Pod 層の `hook::Hook<PreToolCall>` として利用可能
|
||||
|
|
|
|||
|
|
@ -1,42 +0,0 @@
|
|||
# Hook モジュールの llm-worker からの除去
|
||||
|
||||
## 背景
|
||||
|
||||
llm-worker は低レベル基盤に徹するべきだが、現行の `hook` モジュールは
|
||||
高レベルのオーケストレーション関心(承認フロー、ターン制御等)を含んでいる。
|
||||
Claude Code の Hooks のような機能は insomnia 層の責務。
|
||||
|
||||
低レベルのストリーム介入は、クロージャベースの Subscriber API で十分カバーできる。
|
||||
|
||||
## 方針
|
||||
|
||||
`hook` モジュールを llm-worker から除去し、責務を分離する。
|
||||
|
||||
### Subscriber で代替(削除)
|
||||
|
||||
ストリーム観測・介入はクロージャ Subscriber で対応:
|
||||
|
||||
- `OnTextDelta`
|
||||
- `OnToolCallDelta`
|
||||
- `OnStreamChunk`
|
||||
- `OnStreamComplete`
|
||||
|
||||
### insomnia 層に移動
|
||||
|
||||
高レベルオーケストレーションは上位層が担う:
|
||||
|
||||
- `OnPromptSubmit`
|
||||
- `PreLlmRequest`
|
||||
- `PreToolCall` / `PostToolCall`
|
||||
- `OnTurnEnd`
|
||||
- `OnAbort`
|
||||
|
||||
## 設計ポイント
|
||||
|
||||
- Worker の実行ループは「ストリーム受信 → ツール実行 → 次ターン」に集中させる
|
||||
- 介入ポイント(承認、中断、ターン継続判断)は insomnia 層が提供する
|
||||
- `HookEventKind` / `Hook<E>` の型設計自体は良いので、insomnia 層で再利用可能
|
||||
|
||||
## 依存チケット
|
||||
|
||||
- [subscriber-closure-api.md](subscriber-closure-api.md) — ストリーム系 Hook の代替先
|
||||
Loading…
Reference in New Issue
Block a user