183 lines
5.9 KiB
Rust
183 lines
5.9 KiB
Rust
//! Hook関連の型定義
|
||
//!
|
||
//! Worker層でのターン制御・介入に使用される型
|
||
|
||
use async_trait::async_trait;
|
||
use serde::{Deserialize, Serialize};
|
||
use serde_json::Value;
|
||
use thiserror::Error;
|
||
|
||
// =============================================================================
|
||
// Control Flow Types
|
||
// =============================================================================
|
||
|
||
/// Hook処理の制御フロー
|
||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||
pub enum ControlFlow {
|
||
/// 処理を続行
|
||
Continue,
|
||
/// 現在の処理をスキップ(Tool実行など)
|
||
Skip,
|
||
/// 処理を中断
|
||
Abort(String),
|
||
}
|
||
|
||
/// ターン終了時の判定結果
|
||
#[derive(Debug, Clone)]
|
||
pub enum TurnResult {
|
||
/// ターンを終了
|
||
Finish,
|
||
/// メッセージを追加してターン継続(自己修正など)
|
||
ContinueWithMessages(Vec<crate::Message>),
|
||
}
|
||
|
||
// =============================================================================
|
||
// Tool Call / Result Types
|
||
// =============================================================================
|
||
|
||
/// ツール呼び出し情報
|
||
///
|
||
/// LLMからのToolUseブロックを表現し、Hook処理で改変可能
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct ToolCall {
|
||
/// ツール呼び出しID(レスポンスとの紐付けに使用)
|
||
pub id: String,
|
||
/// ツール名
|
||
pub name: String,
|
||
/// 入力引数(JSON)
|
||
pub input: Value,
|
||
}
|
||
|
||
/// ツール実行結果
|
||
///
|
||
/// ツール実行後の結果を表現し、Hook処理で改変可能
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct ToolResult {
|
||
/// 対応するツール呼び出しID
|
||
pub tool_use_id: String,
|
||
/// 結果コンテンツ
|
||
pub content: String,
|
||
/// エラーかどうか
|
||
#[serde(default)]
|
||
pub is_error: bool,
|
||
}
|
||
|
||
impl ToolResult {
|
||
/// 成功結果を作成
|
||
pub fn success(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||
Self {
|
||
tool_use_id: tool_use_id.into(),
|
||
content: content.into(),
|
||
is_error: false,
|
||
}
|
||
}
|
||
|
||
/// エラー結果を作成
|
||
pub fn error(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||
Self {
|
||
tool_use_id: tool_use_id.into(),
|
||
content: content.into(),
|
||
is_error: true,
|
||
}
|
||
}
|
||
}
|
||
|
||
// =============================================================================
|
||
// Hook Error
|
||
// =============================================================================
|
||
|
||
/// Hookエラー
|
||
#[derive(Debug, Error)]
|
||
pub enum HookError {
|
||
/// 処理が中断された
|
||
#[error("Aborted: {0}")]
|
||
Aborted(String),
|
||
/// 内部エラー
|
||
#[error("Hook error: {0}")]
|
||
Internal(String),
|
||
}
|
||
|
||
// =============================================================================
|
||
// WorkerHook Trait
|
||
// =============================================================================
|
||
|
||
/// ターンの進行・ツール実行に介入するためのトレイト
|
||
///
|
||
/// Hookを使うと、メッセージ送信前、ツール実行前後、ターン終了時に
|
||
/// 処理を挟んだり、実行をキャンセルしたりできます。
|
||
///
|
||
/// # Examples
|
||
///
|
||
/// ```ignore
|
||
/// use llm_worker::hook::{ControlFlow, HookError, ToolCall, TurnResult, WorkerHook};
|
||
/// use llm_worker::Message;
|
||
///
|
||
/// struct ValidationHook;
|
||
///
|
||
/// #[async_trait::async_trait]
|
||
/// impl WorkerHook for ValidationHook {
|
||
/// async fn before_tool_call(&self, call: &mut ToolCall) -> Result<ControlFlow, HookError> {
|
||
/// // 危険なツールをブロック
|
||
/// if call.name == "delete_all" {
|
||
/// return Ok(ControlFlow::Skip);
|
||
/// }
|
||
/// Ok(ControlFlow::Continue)
|
||
/// }
|
||
///
|
||
/// async fn on_turn_end(&self, messages: &[Message]) -> Result<TurnResult, HookError> {
|
||
/// // 条件を満たさなければ追加メッセージで継続
|
||
/// if messages.len() < 3 {
|
||
/// return Ok(TurnResult::ContinueWithMessages(vec![
|
||
/// Message::user("Please elaborate.")
|
||
/// ]));
|
||
/// }
|
||
/// Ok(TurnResult::Finish)
|
||
/// }
|
||
/// }
|
||
/// ```
|
||
///
|
||
/// # デフォルト実装
|
||
///
|
||
/// すべてのメソッドにはデフォルト実装があり、何も行わず`Continue`を返します。
|
||
/// 必要なメソッドのみオーバーライドしてください。
|
||
#[async_trait]
|
||
pub trait WorkerHook: Send + Sync {
|
||
/// メッセージ送信前に呼ばれる
|
||
///
|
||
/// リクエストに含まれるメッセージリストを参照・改変できます。
|
||
/// `ControlFlow::Abort`を返すとターンが中断されます。
|
||
async fn on_message_send(
|
||
&self,
|
||
_context: &mut Vec<crate::Message>,
|
||
) -> Result<ControlFlow, HookError> {
|
||
Ok(ControlFlow::Continue)
|
||
}
|
||
|
||
/// ツール実行前に呼ばれる
|
||
///
|
||
/// ツール呼び出しの引数を書き換えたり、実行をスキップしたりできます。
|
||
/// `ControlFlow::Skip`を返すとこのツールの実行がスキップされます。
|
||
async fn before_tool_call(&self, _tool_call: &mut ToolCall) -> Result<ControlFlow, HookError> {
|
||
Ok(ControlFlow::Continue)
|
||
}
|
||
|
||
/// ツール実行後に呼ばれる
|
||
///
|
||
/// ツールの実行結果を書き換えたり、隠蔽したりできます。
|
||
async fn after_tool_call(
|
||
&self,
|
||
_tool_result: &mut ToolResult,
|
||
) -> Result<ControlFlow, HookError> {
|
||
Ok(ControlFlow::Continue)
|
||
}
|
||
|
||
/// ターン終了時に呼ばれる
|
||
///
|
||
/// 生成されたメッセージを検査し、必要なら追加メッセージで継続を指示できます。
|
||
/// `TurnResult::ContinueWithMessages`を返すと、指定したメッセージを追加して
|
||
/// 次のターンに進みます。
|
||
async fn on_turn_end(&self, _messages: &[crate::Message]) -> Result<TurnResult, HookError> {
|
||
Ok(TurnResult::Finish)
|
||
}
|
||
}
|