llm_worker_rs/llm-worker/src/hook.rs

183 lines
5.9 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! Hook関連の型定義
//!
//! Worker層でのターン制御・介入に使用される型
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
// =============================================================================
// Control Flow Types
// =============================================================================
/// Hook処理の制御フロー
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ControlFlow {
/// 処理を続行
Continue,
/// 現在の処理をスキップTool実行など
Skip,
/// 処理を中断
Abort(String),
}
/// ターン終了時の判定結果
#[derive(Debug, Clone)]
pub enum TurnResult {
/// ターンを終了
Finish,
/// メッセージを追加してターン継続(自己修正など)
ContinueWithMessages(Vec<crate::Message>),
}
// =============================================================================
// Tool Call / Result Types
// =============================================================================
/// ツール呼び出し情報
///
/// LLMからのToolUseブロックを表現し、Hook処理で改変可能
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// ツール呼び出しIDレスポンスとの紐付けに使用
pub id: String,
/// ツール名
pub name: String,
/// 入力引数JSON
pub input: Value,
}
/// ツール実行結果
///
/// ツール実行後の結果を表現し、Hook処理で改変可能
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
/// 対応するツール呼び出しID
pub tool_use_id: String,
/// 結果コンテンツ
pub content: String,
/// エラーかどうか
#[serde(default)]
pub is_error: bool,
}
impl ToolResult {
/// 成功結果を作成
pub fn success(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
tool_use_id: tool_use_id.into(),
content: content.into(),
is_error: false,
}
}
/// エラー結果を作成
pub fn error(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
tool_use_id: tool_use_id.into(),
content: content.into(),
is_error: true,
}
}
}
// =============================================================================
// Hook Error
// =============================================================================
/// Hookエラー
#[derive(Debug, Error)]
pub enum HookError {
/// 処理が中断された
#[error("Aborted: {0}")]
Aborted(String),
/// 内部エラー
#[error("Hook error: {0}")]
Internal(String),
}
// =============================================================================
// WorkerHook Trait
// =============================================================================
/// ターンの進行・ツール実行に介入するためのトレイト
///
/// Hookを使うと、メッセージ送信前、ツール実行前後、ターン終了時に
/// 処理を挟んだり、実行をキャンセルしたりできます。
///
/// # Examples
///
/// ```ignore
/// use llm_worker::hook::{ControlFlow, HookError, ToolCall, TurnResult, WorkerHook};
/// use llm_worker::Message;
///
/// struct ValidationHook;
///
/// #[async_trait::async_trait]
/// impl WorkerHook for ValidationHook {
/// async fn before_tool_call(&self, call: &mut ToolCall) -> Result<ControlFlow, HookError> {
/// // 危険なツールをブロック
/// if call.name == "delete_all" {
/// return Ok(ControlFlow::Skip);
/// }
/// Ok(ControlFlow::Continue)
/// }
///
/// async fn on_turn_end(&self, messages: &[Message]) -> Result<TurnResult, HookError> {
/// // 条件を満たさなければ追加メッセージで継続
/// if messages.len() < 3 {
/// return Ok(TurnResult::ContinueWithMessages(vec![
/// Message::user("Please elaborate.")
/// ]));
/// }
/// Ok(TurnResult::Finish)
/// }
/// }
/// ```
///
/// # デフォルト実装
///
/// すべてのメソッドにはデフォルト実装があり、何も行わず`Continue`を返します。
/// 必要なメソッドのみオーバーライドしてください。
#[async_trait]
pub trait WorkerHook: Send + Sync {
/// メッセージ送信前に呼ばれる
///
/// リクエストに含まれるメッセージリストを参照・改変できます。
/// `ControlFlow::Abort`を返すとターンが中断されます。
async fn on_message_send(
&self,
_context: &mut Vec<crate::Message>,
) -> Result<ControlFlow, HookError> {
Ok(ControlFlow::Continue)
}
/// ツール実行前に呼ばれる
///
/// ツール呼び出しの引数を書き換えたり、実行をスキップしたりできます。
/// `ControlFlow::Skip`を返すとこのツールの実行がスキップされます。
async fn before_tool_call(&self, _tool_call: &mut ToolCall) -> Result<ControlFlow, HookError> {
Ok(ControlFlow::Continue)
}
/// ツール実行後に呼ばれる
///
/// ツールの実行結果を書き換えたり、隠蔽したりできます。
async fn after_tool_call(
&self,
_tool_result: &mut ToolResult,
) -> Result<ControlFlow, HookError> {
Ok(ControlFlow::Continue)
}
/// ターン終了時に呼ばれる
///
/// 生成されたメッセージを検査し、必要なら追加メッセージで継続を指示できます。
/// `TurnResult::ContinueWithMessages`を返すと、指定したメッセージを追加して
/// 次のターンに進みます。
async fn on_turn_end(&self, _messages: &[crate::Message]) -> Result<TurnResult, HookError> {
Ok(TurnResult::Finish)
}
}