feat: Redesign the tool system
This commit is contained in:
parent
5691b09fc8
commit
16fda38039
|
|
@ -15,10 +15,11 @@ HookはWorker層でのターン制御に介入するためのメカニズムで
|
|||
## Hook一覧
|
||||
|
||||
| Hook | タイミング | 主な用途 | 戻り値 |
|
||||
| ------------------ | -------------------------- | --------------------- | ---------------------- |
|
||||
| `on_message_send` | LLM送信前 | コンテキスト改変/検証 | `OnMessageSendResult` |
|
||||
| `before_tool_call` | ツール実行前 | 実行許可/引数改変 | `BeforeToolCallResult` |
|
||||
| `after_tool_call` | ツール実行後 | 結果加工/マスキング | `AfterToolCallResult` |
|
||||
| ------------------ | -------------------------- | -------------------------- | ---------------------- |
|
||||
| `on_prompt_submit` | `run()` 呼び出し時 | ユーザーメッセージの前処理 | `OnPromptSubmitResult` |
|
||||
| `pre_llm_request` | 各ターンのLLM送信前 | コンテキスト改変/検証 | `PreLlmRequestResult` |
|
||||
| `pre_tool_call` | ツール実行前 | 実行許可/引数改変 | `PreToolCallResult` |
|
||||
| `post_tool_call` | ツール実行後 | 結果加工/マスキング | `PostToolCallResult` |
|
||||
| `on_turn_end` | ツールなしでターン終了直前 | 検証/リトライ指示 | `OnTurnEndResult` |
|
||||
| `on_abort` | 中断時 | クリーンアップ/通知 | `()` |
|
||||
|
||||
|
|
@ -43,25 +44,31 @@ pub trait HookEventKind {
|
|||
type Output;
|
||||
}
|
||||
|
||||
pub struct OnMessageSend;
|
||||
pub struct BeforeToolCall;
|
||||
pub struct AfterToolCall;
|
||||
pub struct OnPromptSubmit;
|
||||
pub struct PreLlmRequest;
|
||||
pub struct PreToolCall;
|
||||
pub struct PostToolCall;
|
||||
pub struct OnTurnEnd;
|
||||
pub struct OnAbort;
|
||||
|
||||
pub enum OnMessageSendResult {
|
||||
pub enum OnPromptSubmitResult {
|
||||
Continue,
|
||||
Cancel(String),
|
||||
}
|
||||
|
||||
pub enum BeforeToolCallResult {
|
||||
pub enum PreLlmRequestResult {
|
||||
Continue,
|
||||
Cancel(String),
|
||||
}
|
||||
|
||||
pub enum PreToolCallResult {
|
||||
Continue,
|
||||
Skip,
|
||||
Abort(String),
|
||||
Pause,
|
||||
}
|
||||
|
||||
pub enum AfterToolCallResult {
|
||||
pub enum PostToolCallResult {
|
||||
Continue,
|
||||
Abort(String),
|
||||
}
|
||||
|
|
@ -75,7 +82,7 @@ pub enum OnTurnEndResult {
|
|||
|
||||
### Tool Call Context
|
||||
|
||||
`before_tool_call` / `after_tool_call` は、ツール実行の文脈を含む入力を受け取る。
|
||||
`pre_tool_call` / `post_tool_call` は、ツール実行の文脈を含む入力を受け取る。
|
||||
|
||||
```rust
|
||||
pub struct ToolCallContext {
|
||||
|
|
@ -84,7 +91,8 @@ pub struct ToolCallContext {
|
|||
pub tool: Arc<dyn Tool>, // 状態アクセス用
|
||||
}
|
||||
|
||||
pub struct ToolResultContext {
|
||||
pub struct PostToolCallContext {
|
||||
pub call: ToolCall,
|
||||
pub result: ToolResult,
|
||||
pub meta: ToolMeta,
|
||||
pub tool: Arc<dyn Tool>,
|
||||
|
|
@ -94,40 +102,84 @@ pub struct ToolResultContext {
|
|||
## 呼び出しタイミング
|
||||
|
||||
```
|
||||
Worker::run() ループ
|
||||
Worker::run(user_input)
|
||||
│
|
||||
├─▶ on_message_send ──────────────────────────────┐
|
||||
│ コンテキストの改変、バリデーション、 │
|
||||
│ システムプロンプト注入などが可能 │
|
||||
├─▶ on_prompt_submit ───────────────────────────┐
|
||||
│ ユーザーメッセージの前処理・検証 │
|
||||
│ (最初の1回のみ) │
|
||||
│ │
|
||||
├─▶ LLMリクエスト送信 & ストリーム処理 │
|
||||
│ │
|
||||
├─▶ ツール呼び出しがある場合: │
|
||||
│ │ │
|
||||
│ ├─▶ before_tool_call (各ツールごと・逐次) │
|
||||
│ │ 実行可否の判定、引数の改変 │
|
||||
│ │ │
|
||||
│ ├─▶ ツール並列実行 (join_all) │
|
||||
│ │ │
|
||||
│ └─▶ after_tool_call (各結果ごと・逐次) │
|
||||
│ 結果の確認、加工、ログ出力 │
|
||||
│ │
|
||||
├─▶ ツール結果をコンテキストに追加 → ループ先頭へ │
|
||||
│ │
|
||||
└─▶ ツールなしの場合: │
|
||||
└─▶ loop {
|
||||
│
|
||||
├─▶ pre_llm_request ──────────────────────│
|
||||
│ コンテキストの改変、バリデーション、 │
|
||||
│ システムプロンプト注入などが可能 │
|
||||
│ (毎ターン実行) │
|
||||
│ │
|
||||
└─▶ on_turn_end ─────────────────────────────┘
|
||||
├─▶ LLMリクエスト送信 & ストリーム処理 │
|
||||
│ │
|
||||
├─▶ ツール呼び出しがある場合: │
|
||||
│ │ │
|
||||
│ ├─▶ pre_tool_call (各ツールごと・逐次) │
|
||||
│ │ 実行可否の判定、引数の改変 │
|
||||
│ │ │
|
||||
│ ├─▶ ツール並列実行 (join_all) │
|
||||
│ │ │
|
||||
│ └─▶ post_tool_call (各結果ごと・逐次) │
|
||||
│ 結果の確認、加工、ログ出力 │
|
||||
│ │
|
||||
├─▶ ツール結果をコンテキストに追加 │
|
||||
│ → ループ先頭へ │
|
||||
│ │
|
||||
└─▶ ツールなしの場合: │
|
||||
│ │
|
||||
└─▶ on_turn_end ───────────────────┘
|
||||
最終応答のチェック(Lint/Fmt等)
|
||||
エラーがあればContinueWithMessagesでリトライ
|
||||
}
|
||||
|
||||
※ 中断時は on_abort が呼ばれる
|
||||
```
|
||||
|
||||
## 各Hookの詳細
|
||||
|
||||
### on_message_send
|
||||
### on_prompt_submit
|
||||
|
||||
**呼び出しタイミング**: LLMへリクエスト送信前(ターンループの冒頭)
|
||||
**呼び出しタイミング**: `run()`
|
||||
でユーザーメッセージを受け取った直後(最初の1回のみ)
|
||||
|
||||
**用途**:
|
||||
|
||||
- ユーザー入力のバリデーション
|
||||
- 入力のサニタイズ・フィルタリング
|
||||
- ログ出力
|
||||
- `OnPromptSubmitResult::Cancel` による実行キャンセル
|
||||
|
||||
**入力**: `&mut Message` - ユーザーメッセージ(改変可能)
|
||||
|
||||
**例**: 入力のバリデーション
|
||||
|
||||
```rust
|
||||
struct InputValidator;
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<OnPromptSubmit> for InputValidator {
|
||||
async fn call(
|
||||
&self,
|
||||
message: &mut Message,
|
||||
) -> Result<OnPromptSubmitResult, HookError> {
|
||||
if let MessageContent::Text(text) = &message.content {
|
||||
if text.trim().is_empty() {
|
||||
return Ok(OnPromptSubmitResult::Cancel("Empty input".to_string()));
|
||||
}
|
||||
}
|
||||
Ok(OnPromptSubmitResult::Continue)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### pre_llm_request
|
||||
|
||||
**呼び出しタイミング**: 各ターンのLLMリクエスト送信前(ループの毎回)
|
||||
|
||||
**用途**:
|
||||
|
||||
|
|
@ -135,7 +187,9 @@ Worker::run() ループ
|
|||
- メッセージのバリデーション
|
||||
- 機密情報のフィルタリング
|
||||
- リクエスト内容のログ出力
|
||||
- `OnMessageSendResult::Cancel` による送信キャンセル
|
||||
- `PreLlmRequestResult::Cancel` による送信キャンセル
|
||||
|
||||
**入力**: `&mut Vec<Message>` - コンテキスト全体(改変可能)
|
||||
|
||||
**例**: メッセージにタイムスタンプを追加
|
||||
|
||||
|
|
@ -143,19 +197,19 @@ Worker::run() ループ
|
|||
struct TimestampHook;
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<OnMessageSend> for TimestampHook {
|
||||
impl Hook<PreLlmRequest> for TimestampHook {
|
||||
async fn call(
|
||||
&self,
|
||||
context: &mut Vec<Message>,
|
||||
) -> Result<OnMessageSendResult, HookError> {
|
||||
) -> Result<PreLlmRequestResult, HookError> {
|
||||
let timestamp = chrono::Local::now().to_rfc3339();
|
||||
context.insert(0, Message::user(format!("[{}]", timestamp)));
|
||||
Ok(OnMessageSendResult::Continue)
|
||||
Ok(PreLlmRequestResult::Continue)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### before_tool_call
|
||||
### pre_tool_call
|
||||
|
||||
**呼び出しタイミング**: 各ツール実行前(並列実行フェーズの前)
|
||||
|
||||
|
|
@ -165,9 +219,10 @@ impl Hook<OnMessageSend> for TimestampHook {
|
|||
- 引数のサニタイズ
|
||||
- 確認プロンプトの表示(UIとの連携)
|
||||
- 実行ログの記録
|
||||
- `BeforeToolCallResult::Pause` による一時停止
|
||||
- `PreToolCallResult::Pause` による一時停止
|
||||
|
||||
**入力**:
|
||||
|
||||
- `ToolCallContext`(`ToolCall` + `ToolMeta` + `Arc<dyn Tool>`)
|
||||
|
||||
**例**: 特定ツールをブロック
|
||||
|
|
@ -178,22 +233,22 @@ struct ToolBlocker {
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<BeforeToolCall> for ToolBlocker {
|
||||
impl Hook<PreToolCall> for ToolBlocker {
|
||||
async fn call(
|
||||
&self,
|
||||
ctx: &mut ToolCallContext,
|
||||
) -> Result<BeforeToolCallResult, HookError> {
|
||||
) -> Result<PreToolCallResult, HookError> {
|
||||
if self.blocked_tools.contains(&ctx.call.name) {
|
||||
println!("Blocked tool: {}", ctx.call.name);
|
||||
Ok(BeforeToolCallResult::Skip)
|
||||
Ok(PreToolCallResult::Skip)
|
||||
} else {
|
||||
Ok(BeforeToolCallResult::Continue)
|
||||
Ok(PreToolCallResult::Continue)
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### after_tool_call
|
||||
### post_tool_call
|
||||
|
||||
**呼び出しタイミング**: 各ツール実行後(並列実行フェーズの後)
|
||||
|
||||
|
|
@ -203,8 +258,11 @@ impl Hook<BeforeToolCall> for ToolBlocker {
|
|||
- 機密情報のマスキング
|
||||
- 結果のキャッシュ
|
||||
- 実行結果のログ出力
|
||||
|
||||
**入力**:
|
||||
- `ToolResultContext`(`ToolResult` + `ToolMeta` + `Arc<dyn Tool>`)
|
||||
|
||||
- `PostToolCallContext`(`ToolCall` + `ToolResult` + `ToolMeta` +
|
||||
`Arc<dyn Tool>`)
|
||||
|
||||
**例**: 結果にプレフィックスを追加
|
||||
|
||||
|
|
@ -212,15 +270,15 @@ impl Hook<BeforeToolCall> for ToolBlocker {
|
|||
struct ResultFormatter;
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<AfterToolCall> for ResultFormatter {
|
||||
impl Hook<PostToolCall> for ResultFormatter {
|
||||
async fn call(
|
||||
&self,
|
||||
ctx: &mut ToolResultContext,
|
||||
) -> Result<AfterToolCallResult, HookError> {
|
||||
ctx: &mut PostToolCallContext,
|
||||
) -> Result<PostToolCallResult, HookError> {
|
||||
if !ctx.result.is_error {
|
||||
ctx.result.content = format!("[OK] {}", ctx.result.content);
|
||||
}
|
||||
Ok(AfterToolCallResult::Continue)
|
||||
Ok(PostToolCallResult::Continue)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
|
@ -283,9 +341,9 @@ impl Hook<OnTurnEnd> for JsonValidator {
|
|||
Hookは**イベントごとに登録順**に実行されます。
|
||||
|
||||
```rust
|
||||
worker.add_before_tool_call_hook(HookA); // 1番目に実行
|
||||
worker.add_before_tool_call_hook(HookB); // 2番目に実行
|
||||
worker.add_before_tool_call_hook(HookC); // 3番目に実行
|
||||
worker.add_pre_tool_call_hook(HookA); // 1番目に実行
|
||||
worker.add_pre_tool_call_hook(HookB); // 2番目に実行
|
||||
worker.add_pre_tool_call_hook(HookC); // 3番目に実行
|
||||
```
|
||||
|
||||
### 制御フローの伝播
|
||||
|
|
@ -323,15 +381,15 @@ Hook A: Continue → Hook B: Pause
|
|||
async fn call(&self, ctx: &mut ToolCallContext) -> ... {
|
||||
// 引数を直接書き換え
|
||||
ctx.call.input["sanitized"] = json!(true);
|
||||
Ok(BeforeToolCallResult::Continue)
|
||||
Ok(PreToolCallResult::Continue)
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 並列実行との統合
|
||||
|
||||
- `before_tool_call`: 並列実行**前**に逐次実行(許可判定のため)
|
||||
- `pre_tool_call`: 並列実行**前**に逐次実行(許可判定のため)
|
||||
- ツール実行: `join_all`で**並列**実行
|
||||
- `after_tool_call`: 並列実行**後**に逐次実行(結果加工のため)
|
||||
- `post_tool_call`: 並列実行**後**に逐次実行(結果加工のため)
|
||||
|
||||
### 4. Send + Sync 要件
|
||||
|
||||
|
|
@ -344,10 +402,10 @@ struct CountingHook {
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<BeforeToolCall> for CountingHook {
|
||||
async fn call(&self, _: &mut ToolCallContext) -> Result<BeforeToolCallResult, HookError> {
|
||||
impl Hook<PreToolCall> for CountingHook {
|
||||
async fn call(&self, _: &mut ToolCallContext) -> Result<PreToolCallResult, HookError> {
|
||||
self.count.fetch_add(1, Ordering::SeqCst);
|
||||
Ok(BeforeToolCallResult::Continue)
|
||||
Ok(PreToolCallResult::Continue)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
|
@ -355,13 +413,13 @@ impl Hook<BeforeToolCall> for CountingHook {
|
|||
## 典型的なユースケース
|
||||
|
||||
| ユースケース | 使用Hook | 処理内容 |
|
||||
| ------------------ | ------------------------ | -------------------------- |
|
||||
| ツール許可制御 | `before_tool_call` | 危険なツールをSkip |
|
||||
| 実行ログ | `before/after_tool_call` | 呼び出しと結果を記録 |
|
||||
| ------------------ | -------------------- | -------------------------- |
|
||||
| ツール許可制御 | `pre_tool_call` | 危険なツールをSkip |
|
||||
| 実行ログ | `pre/post_tool_call` | 呼び出しと結果を記録 |
|
||||
| 出力バリデーション | `on_turn_end` | 形式チェック、リトライ指示 |
|
||||
| コンテキスト注入 | `on_message_send` | システムメッセージ追加 |
|
||||
| 結果のサニタイズ | `after_tool_call` | 機密情報のマスキング |
|
||||
| レート制限 | `before_tool_call` | 呼び出し頻度の制御 |
|
||||
| 結果のサニタイズ | `post_tool_call` | 機密情報のマスキング |
|
||||
| レート制限 | `pre_tool_call` | 呼び出し頻度の制御 |
|
||||
|
||||
## TODO
|
||||
|
||||
|
|
|
|||
191
docs/spec/tools_design.md
Normal file
191
docs/spec/tools_design.md
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
# Tool 設計
|
||||
|
||||
## 概要
|
||||
|
||||
`llm-worker`のツールシステムは、LLMが外部リソースにアクセスしたり計算を実行するための仕組みを提供する。
|
||||
メタ情報の不変性とセッションスコープの状態管理を両立させる設計となっている。
|
||||
|
||||
## 主要な型
|
||||
|
||||
```
|
||||
type ToolDefinition
|
||||
Fn() -> (ToolMeta, Arc<dyn Tool>)
|
||||
|
||||
worker.register_tool() で呼び出し
|
||||
|
||||
▼
|
||||
|
||||
- struct ToolMeta (name, desc, schema)
|
||||
不変・登録時固定
|
||||
- trait Tool (executer)
|
||||
登録時生成・セッション中再利用
|
||||
```
|
||||
|
||||
### ToolMeta
|
||||
|
||||
ツールのメタ情報を保持する不変構造体。登録時に固定され、Worker内で変更されない。
|
||||
|
||||
```rust
|
||||
pub struct ToolMeta {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub input_schema: Value,
|
||||
}
|
||||
```
|
||||
|
||||
**目的:**
|
||||
|
||||
- LLM へのツール定義として送信
|
||||
- Hook からの参照(読み取り専用)
|
||||
- 登録後の不変性を保証
|
||||
|
||||
### Tool trait
|
||||
|
||||
ツールの実行ロジックのみを定義するトレイト。
|
||||
|
||||
```rust
|
||||
#[async_trait]
|
||||
pub trait Tool: Send + Sync {
|
||||
async fn execute(&self, input_json: &str) -> Result<String, ToolError>;
|
||||
}
|
||||
```
|
||||
|
||||
**設計方針:**
|
||||
|
||||
- メタ情報(name, description, schema)は含まない
|
||||
- 状態を持つことが可能(セッション中のカウンターなど)
|
||||
- `Send + Sync` で並列実行に対応
|
||||
|
||||
**インスタンスのライフサイクル:**
|
||||
|
||||
1. `register_tool()` 呼び出し時にファクトリが実行され、インスタンスが生成される
|
||||
2. LLM がツールを呼び出すと、既存インスタンスの `execute()` が実行される
|
||||
3. 同じセッション中は同一インスタンスが再利用される
|
||||
|
||||
※ 「最初に呼ばれたとき」の遅延初期化ではなく、**登録時の即時初期化**である。
|
||||
|
||||
### ToolDefinition
|
||||
|
||||
メタ情報とツールインスタンスを生成するファクトリ。
|
||||
|
||||
```rust
|
||||
pub type ToolDefinition = Arc<dyn Fn() -> (ToolMeta, Arc<dyn Tool>) + Send + Sync>;
|
||||
```
|
||||
|
||||
**なぜファクトリか:**
|
||||
|
||||
- Worker への登録時に一度だけ呼び出される
|
||||
- メタ情報とインスタンスを同時に生成し、整合性を保証
|
||||
- クロージャでコンテキスト(`self.clone()`)をキャプチャ可能
|
||||
|
||||
## Worker でのツール管理
|
||||
|
||||
```rust
|
||||
// Worker 内部
|
||||
tools: HashMap<String, (ToolMeta, Arc<dyn Tool>)>
|
||||
|
||||
// 登録 API
|
||||
pub fn register_tool(&mut self, factory: ToolDefinition) -> Result<(), ToolRegistryError>
|
||||
```
|
||||
|
||||
登録時の処理:
|
||||
|
||||
1. ファクトリを呼び出し `(meta, instance)` を取得
|
||||
2. 同名ツールが既に登録されていればエラー
|
||||
3. HashMap に `(meta, instance)` を保存
|
||||
|
||||
## マクロによる自動生成
|
||||
|
||||
`#[tool_registry]` マクロは `{method}_definition()` メソッドを生成する。
|
||||
|
||||
```rust
|
||||
#[tool_registry]
|
||||
impl MyApp {
|
||||
/// 検索を実行する
|
||||
#[tool]
|
||||
async fn search(&self, query: String) -> String {
|
||||
// 実装
|
||||
}
|
||||
}
|
||||
|
||||
// 生成されるコード:
|
||||
impl MyApp {
|
||||
pub fn search_definition(&self) -> ToolDefinition {
|
||||
let ctx = self.clone();
|
||||
Arc::new(move || {
|
||||
let meta = ToolMeta::new("search")
|
||||
.description("検索を実行する")
|
||||
.input_schema(/* schemars で生成 */);
|
||||
let tool = Arc::new(ToolSearch { ctx: ctx.clone() });
|
||||
(meta, tool)
|
||||
})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Hook との連携
|
||||
|
||||
Hook は `ToolCallContext` / `AfterToolCallContext`
|
||||
を通じてメタ情報とインスタンスにアクセスできる。
|
||||
|
||||
```rust
|
||||
pub struct ToolCallContext {
|
||||
pub call: ToolCall, // 呼び出し情報(改変可能)
|
||||
pub meta: ToolMeta, // メタ情報(読み取り専用)
|
||||
pub tool: Arc<dyn Tool>, // インスタンス(状態アクセス用)
|
||||
}
|
||||
```
|
||||
|
||||
**用途:**
|
||||
|
||||
- `meta` で名前やスキーマを確認
|
||||
- `tool` でツールの内部状態を読み取り(ダウンキャスト必要)
|
||||
- `call` の引数を改変してツールに渡す
|
||||
|
||||
## 使用例
|
||||
|
||||
### 手動実装
|
||||
|
||||
```rust
|
||||
struct Counter { count: AtomicUsize }
|
||||
|
||||
impl Tool for Counter {
|
||||
async fn execute(&self, _: &str) -> Result<String, ToolError> {
|
||||
let n = self.count.fetch_add(1, Ordering::SeqCst);
|
||||
Ok(format!("count: {}", n))
|
||||
}
|
||||
}
|
||||
|
||||
let def: ToolDefinition = Arc::new(|| {
|
||||
let meta = ToolMeta::new("counter")
|
||||
.description("カウンターを増加")
|
||||
.input_schema(json!({"type": "object"}));
|
||||
(meta, Arc::new(Counter { count: AtomicUsize::new(0) }))
|
||||
});
|
||||
|
||||
worker.register_tool(def)?;
|
||||
```
|
||||
|
||||
### マクロ使用(推奨)
|
||||
|
||||
```rust
|
||||
#[tool_registry]
|
||||
impl App {
|
||||
#[tool]
|
||||
async fn greet(&self, name: String) -> String {
|
||||
format!("Hello, {}!", name)
|
||||
}
|
||||
}
|
||||
|
||||
let app = App;
|
||||
worker.register_tool(app.greet_definition())?;
|
||||
```
|
||||
|
||||
## 設計上の決定
|
||||
|
||||
| 問題 | 決定 | 理由 |
|
||||
| -------------------- | ------------------------------ | ---------------------------------------------- |
|
||||
| メタ情報の変更可能性 | ToolMeta を分離・不変化 | 登録後の整合性を保証 |
|
||||
| 状態管理 | 登録時にインスタンス生成 | セッション中の状態保持、同一インスタンス再利用 |
|
||||
| Factory vs Instance | Factory + 登録時即時呼び出し | コンテキストキャプチャと登録時検証 |
|
||||
| Hook からのアクセス | Context に meta と tool を含む | 柔軟な介入を可能に |
|
||||
|
|
@ -113,7 +113,7 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
|||
let pascal_name = to_pascal_case(&method_name.to_string());
|
||||
let tool_struct_name = format_ident!("Tool{}", pascal_name);
|
||||
let args_struct_name = format_ident!("{}Args", pascal_name);
|
||||
let factory_name = format_ident!("{}_tool", method_name);
|
||||
let definition_name = format_ident!("{}_definition", method_name);
|
||||
|
||||
// ドキュメントコメントから説明を取得
|
||||
let description = extract_doc_comment(&method.attrs);
|
||||
|
|
@ -247,29 +247,24 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
|||
|
||||
#[async_trait::async_trait]
|
||||
impl ::llm_worker::tool::Tool for #tool_struct_name {
|
||||
fn name(&self) -> &str {
|
||||
#tool_name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
#description
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
let schema = schemars::schema_for!(#args_struct_name);
|
||||
serde_json::to_value(schema).unwrap_or(serde_json::json!({}))
|
||||
}
|
||||
|
||||
async fn execute(&self, input_json: &str) -> Result<String, ::llm_worker::tool::ToolError> {
|
||||
#execute_body
|
||||
}
|
||||
}
|
||||
|
||||
impl #self_ty {
|
||||
pub fn #factory_name(&self) -> #tool_struct_name {
|
||||
#tool_struct_name {
|
||||
ctx: self.clone()
|
||||
}
|
||||
/// ToolDefinition を取得(Worker への登録用)
|
||||
pub fn #definition_name(&self) -> ::llm_worker::tool::ToolDefinition {
|
||||
let ctx = self.clone();
|
||||
::std::sync::Arc::new(move || {
|
||||
let schema = schemars::schema_for!(#args_struct_name);
|
||||
let meta = ::llm_worker::tool::ToolMeta::new(#tool_name)
|
||||
.description(#description)
|
||||
.input_schema(serde_json::to_value(schema).unwrap_or(serde_json::json!({})));
|
||||
let tool: ::std::sync::Arc<dyn ::llm_worker::tool::Tool> =
|
||||
::std::sync::Arc::new(#tool_struct_name { ctx: ctx.clone() });
|
||||
(meta, tool)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@
|
|||
//!
|
||||
//! ストリーミング受信中に別スレッドからキャンセルする例
|
||||
|
||||
use llm_worker::llm_client::providers::anthropic::AnthropicClient;
|
||||
use llm_worker::{Worker, WorkerResult};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use llm_worker::{Worker, WorkerResult};
|
||||
use llm_worker::llm_client::providers::anthropic::AnthropicClient;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
|
@ -21,8 +21,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
)
|
||||
.init();
|
||||
|
||||
let api_key = std::env::var("ANTHROPIC_API_KEY")
|
||||
.expect("ANTHROPIC_API_KEY environment variable not set");
|
||||
let api_key =
|
||||
std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY environment variable not set");
|
||||
|
||||
let client = AnthropicClient::new(&api_key, "claude-sonnet-4-20250514");
|
||||
let worker = Arc::new(Mutex::new(Worker::new(client)));
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ use tracing_subscriber::EnvFilter;
|
|||
use clap::{Parser, ValueEnum};
|
||||
use llm_worker::{
|
||||
Worker,
|
||||
hook::{AfterToolCall, AfterToolCallResult, Hook, HookError, ToolResult},
|
||||
hook::{Hook, HookError, PostToolCall, PostToolCallContext, PostToolCallResult},
|
||||
llm_client::{
|
||||
LlmClient,
|
||||
providers::{
|
||||
|
|
@ -282,25 +282,22 @@ impl ToolResultPrinterHook {
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<AfterToolCall> for ToolResultPrinterHook {
|
||||
async fn call(
|
||||
&self,
|
||||
tool_result: &mut ToolResult,
|
||||
) -> Result<AfterToolCallResult, HookError> {
|
||||
impl Hook<PostToolCall> for ToolResultPrinterHook {
|
||||
async fn call(&self, ctx: &mut PostToolCallContext) -> Result<PostToolCallResult, HookError> {
|
||||
let name = self
|
||||
.call_names
|
||||
.lock()
|
||||
.unwrap()
|
||||
.remove(&tool_result.tool_use_id)
|
||||
.unwrap_or_else(|| tool_result.tool_use_id.clone());
|
||||
.remove(&ctx.result.tool_use_id)
|
||||
.unwrap_or_else(|| ctx.result.tool_use_id.clone());
|
||||
|
||||
if tool_result.is_error {
|
||||
println!(" Result ({}): ❌ {}", name, tool_result.content);
|
||||
if ctx.result.is_error {
|
||||
println!(" Result ({}): ❌ {}", name, ctx.result.content);
|
||||
} else {
|
||||
println!(" Result ({}): ✅ {}", name, tool_result.content);
|
||||
println!(" Result ({}): ✅ {}", name, ctx.result.content);
|
||||
}
|
||||
|
||||
Ok(AfterToolCallResult::Continue)
|
||||
Ok(PostToolCallResult::Continue)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -441,8 +438,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
// ツール登録(--no-tools でなければ)
|
||||
if !args.no_tools {
|
||||
let app = AppContext;
|
||||
worker.register_tool(app.get_current_time_tool());
|
||||
worker.register_tool(app.calculate_tool());
|
||||
worker
|
||||
.register_tool(app.get_current_time_definition())
|
||||
.unwrap();
|
||||
worker.register_tool(app.calculate_definition()).unwrap();
|
||||
}
|
||||
|
||||
// ストリーミング表示用ハンドラーを登録
|
||||
|
|
@ -451,7 +450,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
.on_text_block(StreamingPrinter::new())
|
||||
.on_tool_use_block(ToolCallPrinter::new(tool_call_names.clone()));
|
||||
|
||||
worker.add_after_tool_call_hook(ToolResultPrinterHook::new(tool_call_names));
|
||||
worker.add_post_tool_call_hook(ToolResultPrinterHook::new(tool_call_names));
|
||||
|
||||
// ワンショットモード
|
||||
if let Some(prompt) = args.prompt {
|
||||
|
|
|
|||
|
|
@ -16,20 +16,27 @@ pub trait HookEventKind: Send + Sync + 'static {
|
|||
type Output;
|
||||
}
|
||||
|
||||
pub struct OnMessageSend;
|
||||
pub struct BeforeToolCall;
|
||||
pub struct AfterToolCall;
|
||||
pub struct OnPromptSubmit;
|
||||
pub struct PreLlmRequest;
|
||||
pub struct PreToolCall;
|
||||
pub struct PostToolCall;
|
||||
pub struct OnTurnEnd;
|
||||
pub struct OnAbort;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum OnMessageSendResult {
|
||||
pub enum OnPromptSubmitResult {
|
||||
Continue,
|
||||
Cancel(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum BeforeToolCallResult {
|
||||
pub enum PreLlmRequestResult {
|
||||
Continue,
|
||||
Cancel(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PreToolCallResult {
|
||||
Continue,
|
||||
Skip,
|
||||
Abort(String),
|
||||
|
|
@ -37,7 +44,7 @@ pub enum BeforeToolCallResult {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum AfterToolCallResult {
|
||||
pub enum PostToolCallResult {
|
||||
Continue,
|
||||
Abort(String),
|
||||
}
|
||||
|
|
@ -49,19 +56,50 @@ pub enum OnTurnEndResult {
|
|||
Paused,
|
||||
}
|
||||
|
||||
impl HookEventKind for OnMessageSend {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::tool::{Tool, ToolMeta};
|
||||
|
||||
/// PreToolCall の入力コンテキスト
|
||||
pub struct ToolCallContext {
|
||||
/// ツール呼び出し情報(改変可能)
|
||||
pub call: ToolCall,
|
||||
/// ツールメタ情報(不変)
|
||||
pub meta: ToolMeta,
|
||||
/// ツールインスタンス(状態アクセス用)
|
||||
pub tool: Arc<dyn Tool>,
|
||||
}
|
||||
|
||||
/// PostToolCall の入力コンテキスト
|
||||
pub struct PostToolCallContext {
|
||||
/// ツール呼び出し情報
|
||||
pub call: ToolCall,
|
||||
/// ツール実行結果(改変可能)
|
||||
pub result: ToolResult,
|
||||
/// ツールメタ情報(不変)
|
||||
pub meta: ToolMeta,
|
||||
/// ツールインスタンス(状態アクセス用)
|
||||
pub tool: Arc<dyn Tool>,
|
||||
}
|
||||
|
||||
impl HookEventKind for OnPromptSubmit {
|
||||
type Input = crate::Message;
|
||||
type Output = OnPromptSubmitResult;
|
||||
}
|
||||
|
||||
impl HookEventKind for PreLlmRequest {
|
||||
type Input = Vec<crate::Message>;
|
||||
type Output = OnMessageSendResult;
|
||||
type Output = PreLlmRequestResult;
|
||||
}
|
||||
|
||||
impl HookEventKind for BeforeToolCall {
|
||||
type Input = ToolCall;
|
||||
type Output = BeforeToolCallResult;
|
||||
impl HookEventKind for PreToolCall {
|
||||
type Input = ToolCallContext;
|
||||
type Output = PreToolCallResult;
|
||||
}
|
||||
|
||||
impl HookEventKind for AfterToolCall {
|
||||
type Input = ToolResult;
|
||||
type Output = AfterToolCallResult;
|
||||
impl HookEventKind for PostToolCall {
|
||||
type Input = PostToolCallContext;
|
||||
type Output = PostToolCallResult;
|
||||
}
|
||||
|
||||
impl HookEventKind for OnTurnEnd {
|
||||
|
|
@ -151,3 +189,45 @@ pub enum HookError {
|
|||
pub trait Hook<E: HookEventKind>: Send + Sync {
|
||||
async fn call(&self, input: &mut E::Input) -> Result<E::Output, HookError>;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Hook Registry
|
||||
// =============================================================================
|
||||
|
||||
/// 全 Hook を保持するレジストリ
|
||||
///
|
||||
/// Worker 内部で使用され、各種 Hook を一括管理する。
|
||||
pub struct HookRegistry {
|
||||
/// on_prompt_submit Hook
|
||||
pub(crate) on_prompt_submit: Vec<Box<dyn Hook<OnPromptSubmit>>>,
|
||||
/// pre_llm_request Hook
|
||||
pub(crate) pre_llm_request: Vec<Box<dyn Hook<PreLlmRequest>>>,
|
||||
/// pre_tool_call Hook
|
||||
pub(crate) pre_tool_call: Vec<Box<dyn Hook<PreToolCall>>>,
|
||||
/// post_tool_call Hook
|
||||
pub(crate) post_tool_call: Vec<Box<dyn Hook<PostToolCall>>>,
|
||||
/// on_turn_end Hook
|
||||
pub(crate) on_turn_end: Vec<Box<dyn Hook<OnTurnEnd>>>,
|
||||
/// on_abort Hook
|
||||
pub(crate) on_abort: Vec<Box<dyn Hook<OnAbort>>>,
|
||||
}
|
||||
|
||||
impl Default for HookRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl HookRegistry {
|
||||
/// 空の HookRegistry を作成
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
on_prompt_submit: Vec::new(),
|
||||
pre_llm_request: Vec::new(),
|
||||
pre_tool_call: Vec::new(),
|
||||
post_tool_call: Vec::new(),
|
||||
on_turn_end: Vec::new(),
|
||||
on_abort: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,8 +19,7 @@
|
|||
//! .system_prompt("You are a helpful assistant.");
|
||||
//!
|
||||
//! // ツールを登録(オプション)
|
||||
//! use llm_worker::tool::Tool;
|
||||
//! worker.register_tool(my_tool);
|
||||
//! // worker.register_tool(my_tool_definition)?;
|
||||
//!
|
||||
//! // 対話を実行
|
||||
//! let history = worker.run("Hello!").await?;
|
||||
|
|
@ -49,4 +48,4 @@ pub mod timeline;
|
|||
pub mod tool;
|
||||
|
||||
pub use message::{ContentPart, Message, MessageContent, Role};
|
||||
pub use worker::{Worker, WorkerConfig, WorkerError, WorkerResult};
|
||||
pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult};
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
//! LLMから呼び出し可能なツールを定義するためのトレイト。
|
||||
//! 通常は`#[tool]`マクロを使用して自動実装します。
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use thiserror::Error;
|
||||
|
|
@ -21,64 +23,126 @@ pub enum ToolError {
|
|||
Internal(String),
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// ToolMeta - 不変のメタ情報
|
||||
// =============================================================================
|
||||
|
||||
/// ツールのメタ情報(登録時に固定、不変)
|
||||
///
|
||||
/// `ToolDefinition` ファクトリから生成され、Worker に登録後は変更されません。
|
||||
/// LLM へのツール定義送信に使用されます。
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ToolMeta {
|
||||
/// ツール名(LLMが識別に使用)
|
||||
pub name: String,
|
||||
/// ツールの説明(LLMへのプロンプトに含まれる)
|
||||
pub description: String,
|
||||
/// 引数のJSON Schema
|
||||
pub input_schema: Value,
|
||||
}
|
||||
|
||||
impl ToolMeta {
|
||||
/// 新しい ToolMeta を作成
|
||||
pub fn new(name: impl Into<String>) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
description: String::new(),
|
||||
input_schema: Value::Object(Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
/// 説明を設定
|
||||
pub fn description(mut self, desc: impl Into<String>) -> Self {
|
||||
self.description = desc.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// 引数スキーマを設定
|
||||
pub fn input_schema(mut self, schema: Value) -> Self {
|
||||
self.input_schema = schema;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// ToolDefinition - ファクトリ型
|
||||
// =============================================================================
|
||||
|
||||
/// ツール定義ファクトリ
|
||||
///
|
||||
/// 呼び出すと `(ToolMeta, Arc<dyn Tool>)` を返します。
|
||||
/// Worker への登録時に一度だけ呼び出され、メタ情報とインスタンスが
|
||||
/// セッションスコープでキャッシュされます。
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// let def: ToolDefinition = Arc::new(|| {
|
||||
/// (
|
||||
/// ToolMeta::new("my_tool")
|
||||
/// .description("My tool description")
|
||||
/// .input_schema(json!({"type": "object"})),
|
||||
/// Arc::new(MyToolImpl { state: 0 }) as Arc<dyn Tool>,
|
||||
/// )
|
||||
/// });
|
||||
/// worker.register_tool(def)?;
|
||||
/// ```
|
||||
pub type ToolDefinition = Arc<dyn Fn() -> (ToolMeta, Arc<dyn Tool>) + Send + Sync>;
|
||||
|
||||
// =============================================================================
|
||||
// Tool trait
|
||||
// =============================================================================
|
||||
|
||||
/// LLMから呼び出し可能なツールを定義するトレイト
|
||||
///
|
||||
/// ツールはLLMが外部リソースにアクセスしたり、
|
||||
/// 計算を実行したりするために使用します。
|
||||
/// セッション中の状態を保持できます。
|
||||
///
|
||||
/// # 実装方法
|
||||
///
|
||||
/// 通常は`#[tool]`マクロを使用して自動実装します:
|
||||
/// 通常は`#[tool_registry]`マクロを使用して自動実装します:
|
||||
///
|
||||
/// ```ignore
|
||||
/// use llm_worker::tool;
|
||||
///
|
||||
/// #[tool(description = "Search the web for information")]
|
||||
/// async fn search(query: String) -> String {
|
||||
/// // 検索処理
|
||||
/// #[tool_registry]
|
||||
/// impl MyApp {
|
||||
/// #[tool]
|
||||
/// async fn search(&self, query: String) -> String {
|
||||
/// format!("Results for: {}", query)
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// // 登録
|
||||
/// worker.register_tool(app.search_definition())?;
|
||||
/// ```
|
||||
///
|
||||
/// # 手動実装
|
||||
///
|
||||
/// ```ignore
|
||||
/// use llm_worker::tool::{Tool, ToolError};
|
||||
/// use serde_json::{json, Value};
|
||||
/// use llm_worker::tool::{Tool, ToolError, ToolMeta, ToolDefinition};
|
||||
/// use std::sync::Arc;
|
||||
///
|
||||
/// struct MyTool;
|
||||
/// struct MyTool { counter: std::sync::atomic::AtomicUsize }
|
||||
///
|
||||
/// #[async_trait::async_trait]
|
||||
/// impl Tool for MyTool {
|
||||
/// fn name(&self) -> &str { "my_tool" }
|
||||
/// fn description(&self) -> &str { "My custom tool" }
|
||||
/// fn input_schema(&self) -> Value {
|
||||
/// json!({
|
||||
/// "type": "object",
|
||||
/// "properties": {
|
||||
/// "query": { "type": "string" }
|
||||
/// },
|
||||
/// "required": ["query"]
|
||||
/// })
|
||||
/// }
|
||||
/// async fn execute(&self, input: &str) -> Result<String, ToolError> {
|
||||
/// self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||
/// Ok("result".to_string())
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// let def: ToolDefinition = Arc::new(|| {
|
||||
/// (
|
||||
/// ToolMeta::new("my_tool")
|
||||
/// .description("My custom tool")
|
||||
/// .input_schema(serde_json::json!({"type": "object"})),
|
||||
/// Arc::new(MyTool { counter: Default::default() }) as Arc<dyn Tool>,
|
||||
/// )
|
||||
/// });
|
||||
/// ```
|
||||
#[async_trait]
|
||||
pub trait Tool: Send + Sync {
|
||||
/// ツール名(LLMが識別に使用)
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// ツールの説明(LLMへのプロンプトに含まれる)
|
||||
fn description(&self) -> &str;
|
||||
|
||||
/// 引数のJSON Schema
|
||||
///
|
||||
/// LLMはこのスキーマに従って引数を生成します。
|
||||
fn input_schema(&self) -> Value;
|
||||
|
||||
/// ツールを実行する
|
||||
///
|
||||
/// # Arguments
|
||||
|
|
|
|||
|
|
@ -9,18 +9,21 @@ use tracing::{debug, info, trace, warn};
|
|||
use crate::{
|
||||
ContentPart, Message, MessageContent, Role,
|
||||
hook::{
|
||||
AfterToolCall, AfterToolCallResult, BeforeToolCall, BeforeToolCallResult, Hook, HookError,
|
||||
OnAbort, OnMessageSend, OnMessageSendResult, OnTurnEnd, OnTurnEndResult, ToolCall,
|
||||
ToolResult,
|
||||
Hook, HookError, HookRegistry, OnAbort, OnPromptSubmit, OnPromptSubmitResult, OnTurnEnd,
|
||||
OnTurnEndResult, PostToolCall, PostToolCallContext, PostToolCallResult, PreLlmRequest,
|
||||
PreLlmRequestResult, PreToolCall, PreToolCallResult, ToolCall, ToolCallContext, ToolResult,
|
||||
},
|
||||
llm_client::{
|
||||
ClientError, ConfigWarning, LlmClient, Request, RequestConfig,
|
||||
ToolDefinition as LlmToolDefinition,
|
||||
},
|
||||
llm_client::{ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ToolDefinition},
|
||||
state::{Locked, Mutable, WorkerState},
|
||||
subscriber::{
|
||||
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
|
||||
ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, WorkerSubscriber,
|
||||
},
|
||||
timeline::{TextBlockCollector, Timeline, ToolCallCollector},
|
||||
tool::{Tool, ToolError},
|
||||
tool::{Tool, ToolDefinition, ToolError, ToolMeta},
|
||||
};
|
||||
|
||||
// =============================================================================
|
||||
|
|
@ -50,6 +53,14 @@ pub enum WorkerError {
|
|||
ConfigWarnings(Vec<ConfigWarning>),
|
||||
}
|
||||
|
||||
/// ツール登録エラー
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ToolRegistryError {
|
||||
/// 同名のツールが既に登録されている
|
||||
#[error("Tool with name '{0}' already registered")]
|
||||
DuplicateName(String),
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Worker Config
|
||||
// =============================================================================
|
||||
|
|
@ -155,18 +166,10 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
|||
text_block_collector: TextBlockCollector,
|
||||
/// ツールコールコレクター(Timeline用ハンドラ)
|
||||
tool_call_collector: ToolCallCollector,
|
||||
/// 登録されたツール
|
||||
tools: HashMap<String, Arc<dyn Tool>>,
|
||||
/// on_message_send Hook
|
||||
hooks_on_message_send: Vec<Box<dyn Hook<OnMessageSend>>>,
|
||||
/// before_tool_call Hook
|
||||
hooks_before_tool_call: Vec<Box<dyn Hook<BeforeToolCall>>>,
|
||||
/// after_tool_call Hook
|
||||
hooks_after_tool_call: Vec<Box<dyn Hook<AfterToolCall>>>,
|
||||
/// on_turn_end Hook
|
||||
hooks_on_turn_end: Vec<Box<dyn Hook<OnTurnEnd>>>,
|
||||
/// on_abort Hook
|
||||
hooks_on_abort: Vec<Box<dyn Hook<OnAbort>>>,
|
||||
/// 登録されたツール (meta, instance)
|
||||
tools: HashMap<String, (ToolMeta, Arc<dyn Tool>)>,
|
||||
/// Hook レジストリ
|
||||
hooks: HookRegistry,
|
||||
/// システムプロンプト
|
||||
system_prompt: Option<String>,
|
||||
/// メッセージ履歴(Workerが所有)
|
||||
|
|
@ -248,51 +251,71 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
/// ツールを登録する
|
||||
///
|
||||
/// 登録されたツールはLLMからの呼び出しで自動的に実行されます。
|
||||
/// 同名のツールを登録した場合、後から登録したものが優先されます。
|
||||
/// 同名のツールを登録するとエラーになります。
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// use llm_worker::Worker;
|
||||
/// use my_tools::SearchTool;
|
||||
/// use llm_worker::tool::{ToolMeta, ToolDefinition, Tool};
|
||||
/// use std::sync::Arc;
|
||||
///
|
||||
/// worker.register_tool(SearchTool::new());
|
||||
/// let def: ToolDefinition = Arc::new(|| {
|
||||
/// (ToolMeta::new("search").description("..."), Arc::new(MyTool) as Arc<dyn Tool>)
|
||||
/// });
|
||||
/// worker.register_tool(def)?;
|
||||
/// ```
|
||||
pub fn register_tool(&mut self, tool: impl Tool + 'static) {
|
||||
let name = tool.name().to_string();
|
||||
self.tools.insert(name, Arc::new(tool));
|
||||
pub fn register_tool(&mut self, factory: ToolDefinition) -> Result<(), ToolRegistryError> {
|
||||
let (meta, instance) = factory();
|
||||
if self.tools.contains_key(&meta.name) {
|
||||
return Err(ToolRegistryError::DuplicateName(meta.name.clone()));
|
||||
}
|
||||
self.tools.insert(meta.name.clone(), (meta, instance));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 複数のツールを登録
|
||||
pub fn register_tools(&mut self, tools: impl IntoIterator<Item = impl Tool + 'static>) {
|
||||
for tool in tools {
|
||||
self.register_tool(tool);
|
||||
pub fn register_tools(
|
||||
&mut self,
|
||||
factories: impl IntoIterator<Item = ToolDefinition>,
|
||||
) -> Result<(), ToolRegistryError> {
|
||||
for factory in factories {
|
||||
self.register_tool(factory)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// on_message_send Hookを追加する
|
||||
pub fn add_on_message_send_hook(&mut self, hook: impl Hook<OnMessageSend> + 'static) {
|
||||
self.hooks_on_message_send.push(Box::new(hook));
|
||||
/// on_prompt_submit Hookを追加する
|
||||
///
|
||||
/// `run()` でユーザーメッセージを受け取った直後に呼び出される。
|
||||
pub fn add_on_prompt_submit_hook(&mut self, hook: impl Hook<OnPromptSubmit> + 'static) {
|
||||
self.hooks.on_prompt_submit.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// before_tool_call Hookを追加する
|
||||
pub fn add_before_tool_call_hook(&mut self, hook: impl Hook<BeforeToolCall> + 'static) {
|
||||
self.hooks_before_tool_call.push(Box::new(hook));
|
||||
/// pre_llm_request Hookを追加する
|
||||
///
|
||||
/// 各ターンのLLMリクエスト送信前に呼び出される。
|
||||
pub fn add_pre_llm_request_hook(&mut self, hook: impl Hook<PreLlmRequest> + 'static) {
|
||||
self.hooks.pre_llm_request.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// after_tool_call Hookを追加する
|
||||
pub fn add_after_tool_call_hook(&mut self, hook: impl Hook<AfterToolCall> + 'static) {
|
||||
self.hooks_after_tool_call.push(Box::new(hook));
|
||||
/// pre_tool_call Hookを追加する
|
||||
pub fn add_pre_tool_call_hook(&mut self, hook: impl Hook<PreToolCall> + 'static) {
|
||||
self.hooks.pre_tool_call.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// post_tool_call Hookを追加する
|
||||
pub fn add_post_tool_call_hook(&mut self, hook: impl Hook<PostToolCall> + 'static) {
|
||||
self.hooks.post_tool_call.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// on_turn_end Hookを追加する
|
||||
pub fn add_on_turn_end_hook(&mut self, hook: impl Hook<OnTurnEnd> + 'static) {
|
||||
self.hooks_on_turn_end.push(Box::new(hook));
|
||||
self.hooks.on_turn_end.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// on_abort Hookを追加する
|
||||
pub fn add_on_abort_hook(&mut self, hook: impl Hook<OnAbort> + 'static) {
|
||||
self.hooks_on_abort.push(Box::new(hook));
|
||||
self.hooks.on_abort.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// タイムラインへの可変参照を取得(追加ハンドラ登録用)
|
||||
|
|
@ -427,14 +450,14 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
&self.cancellation_token
|
||||
}
|
||||
|
||||
/// 登録されたツールからToolDefinitionのリストを生成
|
||||
fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
|
||||
/// 登録されたツールからLLM用ToolDefinitionのリストを生成
|
||||
fn build_tool_definitions(&self) -> Vec<LlmToolDefinition> {
|
||||
self.tools
|
||||
.values()
|
||||
.map(|tool| {
|
||||
ToolDefinition::new(tool.name())
|
||||
.description(tool.description())
|
||||
.input_schema(tool.input_schema())
|
||||
.map(|(meta, _)| {
|
||||
LlmToolDefinition::new(&meta.name)
|
||||
.description(&meta.description)
|
||||
.input_schema(meta.input_schema.clone())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
|
@ -482,7 +505,11 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
|
||||
/// リクエストを構築
|
||||
fn build_request(&self, tool_definitions: &[ToolDefinition], context: &[Message]) -> Request {
|
||||
fn build_request(
|
||||
&self,
|
||||
tool_definitions: &[LlmToolDefinition],
|
||||
context: &[Message],
|
||||
) -> Request {
|
||||
let mut request = Request::new();
|
||||
|
||||
// システムプロンプトを設定
|
||||
|
|
@ -546,27 +573,48 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
request
|
||||
}
|
||||
|
||||
/// Hooks: on_message_send
|
||||
async fn run_on_message_send_hooks(
|
||||
/// Hooks: on_prompt_submit
|
||||
///
|
||||
/// `run()` でユーザーメッセージを受け取った直後に呼び出される(最初だけ)。
|
||||
async fn run_on_prompt_submit_hooks(
|
||||
&self,
|
||||
) -> Result<(OnMessageSendResult, Vec<Message>), WorkerError> {
|
||||
message: &mut Message,
|
||||
) -> Result<OnPromptSubmitResult, WorkerError> {
|
||||
for hook in &self.hooks.on_prompt_submit {
|
||||
let result = hook.call(message).await?;
|
||||
match result {
|
||||
OnPromptSubmitResult::Continue => continue,
|
||||
OnPromptSubmitResult::Cancel(reason) => {
|
||||
return Ok(OnPromptSubmitResult::Cancel(reason));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(OnPromptSubmitResult::Continue)
|
||||
}
|
||||
|
||||
/// Hooks: pre_llm_request
|
||||
///
|
||||
/// 各ターンのLLMリクエスト送信前に呼び出される(毎ターン)。
|
||||
async fn run_pre_llm_request_hooks(
|
||||
&self,
|
||||
) -> Result<(PreLlmRequestResult, Vec<Message>), WorkerError> {
|
||||
let mut temp_context = self.history.clone();
|
||||
for hook in &self.hooks_on_message_send {
|
||||
for hook in &self.hooks.pre_llm_request {
|
||||
let result = hook.call(&mut temp_context).await?;
|
||||
match result {
|
||||
OnMessageSendResult::Continue => continue,
|
||||
OnMessageSendResult::Cancel(reason) => {
|
||||
return Ok((OnMessageSendResult::Cancel(reason), temp_context));
|
||||
PreLlmRequestResult::Continue => continue,
|
||||
PreLlmRequestResult::Cancel(reason) => {
|
||||
return Ok((PreLlmRequestResult::Cancel(reason), temp_context));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok((OnMessageSendResult::Continue, temp_context))
|
||||
Ok((PreLlmRequestResult::Continue, temp_context))
|
||||
}
|
||||
|
||||
/// Hooks: on_turn_end
|
||||
async fn run_on_turn_end_hooks(&self) -> Result<OnTurnEndResult, WorkerError> {
|
||||
let mut temp_messages = self.history.clone();
|
||||
for hook in &self.hooks_on_turn_end {
|
||||
for hook in &self.hooks.on_turn_end {
|
||||
let result = hook.call(&mut temp_messages).await?;
|
||||
match result {
|
||||
OnTurnEndResult::Finish => continue,
|
||||
|
|
@ -582,7 +630,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
/// Hooks: on_abort
|
||||
async fn run_on_abort_hooks(&self, reason: &str) -> Result<(), WorkerError> {
|
||||
let mut reason = reason.to_string();
|
||||
for hook in &self.hooks_on_abort {
|
||||
for hook in &self.hooks.on_abort {
|
||||
hook.call(&mut reason).await?;
|
||||
}
|
||||
Ok(())
|
||||
|
|
@ -608,44 +656,67 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
}
|
||||
|
||||
if calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(calls)
|
||||
}
|
||||
if calls.is_empty() { None } else { Some(calls) }
|
||||
}
|
||||
|
||||
/// ツールを並列実行
|
||||
///
|
||||
/// 全てのツールに対してbefore_tool_callフックを実行後、
|
||||
/// 許可されたツールを並列に実行し、結果にafter_tool_callフックを適用する。
|
||||
/// 全てのツールに対してpre_tool_callフックを実行後、
|
||||
/// 許可されたツールを並列に実行し、結果にpost_tool_callフックを適用する。
|
||||
async fn execute_tools(
|
||||
&mut self,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
) -> Result<ToolExecutionResult, WorkerError> {
|
||||
use futures::future::join_all;
|
||||
|
||||
// Phase 1: before_tool_call フックを適用(スキップ/中断を判定)
|
||||
// ツール呼び出しIDから (ToolCall, Meta, Tool) へのマップ
|
||||
// PostToolCallフックで必要になるため保持する
|
||||
let mut call_info_map = HashMap::new();
|
||||
|
||||
// Phase 1: pre_tool_call フックを適用(スキップ/中断を判定)
|
||||
let mut approved_calls = Vec::new();
|
||||
for mut tool_call in tool_calls {
|
||||
// ツール定義を取得
|
||||
if let Some((meta, tool)) = self.tools.get(&tool_call.name) {
|
||||
// コンテキストを作成
|
||||
let mut context = ToolCallContext {
|
||||
call: tool_call.clone(),
|
||||
meta: meta.clone(),
|
||||
tool: tool.clone(),
|
||||
};
|
||||
|
||||
let mut skip = false;
|
||||
for hook in &self.hooks_before_tool_call {
|
||||
let result = hook.call(&mut tool_call).await?;
|
||||
for hook in &self.hooks.pre_tool_call {
|
||||
let result = hook.call(&mut context).await?;
|
||||
match result {
|
||||
BeforeToolCallResult::Continue => {}
|
||||
BeforeToolCallResult::Skip => {
|
||||
PreToolCallResult::Continue => {}
|
||||
PreToolCallResult::Skip => {
|
||||
skip = true;
|
||||
break;
|
||||
}
|
||||
BeforeToolCallResult::Abort(reason) => {
|
||||
PreToolCallResult::Abort(reason) => {
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
BeforeToolCallResult::Pause => {
|
||||
PreToolCallResult::Pause => {
|
||||
return Ok(ToolExecutionResult::Paused);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// フックで変更された内容を反映
|
||||
tool_call = context.call;
|
||||
|
||||
// マップに保存(実行する場合のみ)
|
||||
if !skip {
|
||||
call_info_map.insert(
|
||||
tool_call.id.clone(),
|
||||
(tool_call.clone(), meta.clone(), tool.clone()),
|
||||
);
|
||||
approved_calls.push(tool_call);
|
||||
}
|
||||
} else {
|
||||
// 未知のツールはそのまま承認リストに入れる(実行時にエラーになる)
|
||||
// Hookは適用しない(Metaがないため)
|
||||
approved_calls.push(tool_call);
|
||||
}
|
||||
}
|
||||
|
|
@ -656,7 +727,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
.map(|tool_call| {
|
||||
let tools = &self.tools;
|
||||
async move {
|
||||
if let Some(tool) = tools.get(&tool_call.name) {
|
||||
if let Some((_, tool)) = tools.get(&tool_call.name) {
|
||||
let input_json =
|
||||
serde_json::to_string(&tool_call.input).unwrap_or_default();
|
||||
match tool.execute(&input_json).await {
|
||||
|
|
@ -684,17 +755,29 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
};
|
||||
|
||||
// Phase 3: after_tool_call フックを適用
|
||||
// Phase 3: post_tool_call フックを適用
|
||||
for tool_result in &mut results {
|
||||
for hook in &self.hooks_after_tool_call {
|
||||
let result = hook.call(tool_result).await?;
|
||||
// 保存しておいた情報を取得
|
||||
if let Some((tool_call, meta, tool)) = call_info_map.get(&tool_result.tool_use_id) {
|
||||
let mut context = PostToolCallContext {
|
||||
call: tool_call.clone(),
|
||||
result: tool_result.clone(),
|
||||
meta: meta.clone(),
|
||||
tool: tool.clone(),
|
||||
};
|
||||
|
||||
for hook in &self.hooks.post_tool_call {
|
||||
let result = hook.call(&mut context).await?;
|
||||
match result {
|
||||
AfterToolCallResult::Continue => {}
|
||||
AfterToolCallResult::Abort(reason) => {
|
||||
PostToolCallResult::Continue => {}
|
||||
PostToolCallResult::Abort(reason) => {
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
}
|
||||
}
|
||||
// フックで変更された結果を反映
|
||||
*tool_result = context.result;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ToolExecutionResult::Completed(results))
|
||||
|
|
@ -717,7 +800,8 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)),
|
||||
ToolExecutionResult::Completed(results) => {
|
||||
for result in results {
|
||||
self.history.push(Message::tool_result(&result.tool_use_id, &result.content));
|
||||
self.history
|
||||
.push(Message::tool_result(&result.tool_use_id, &result.content));
|
||||
}
|
||||
// Continue to loop
|
||||
}
|
||||
|
|
@ -740,10 +824,10 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
notifier.on_turn_start(current_turn);
|
||||
}
|
||||
|
||||
// Hook: on_message_send
|
||||
let (control, request_context) = self.run_on_message_send_hooks().await?;
|
||||
// Hook: pre_llm_request
|
||||
let (control, request_context) = self.run_pre_llm_request_hooks().await?;
|
||||
match control {
|
||||
OnMessageSendResult::Cancel(reason) => {
|
||||
PreLlmRequestResult::Cancel(reason) => {
|
||||
info!(reason = %reason, "Aborted by hook");
|
||||
for notifier in &self.turn_notifiers {
|
||||
notifier.on_turn_end(current_turn);
|
||||
|
|
@ -751,7 +835,7 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
self.run_on_abort_hooks(&reason).await?;
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
OnMessageSendResult::Continue => {}
|
||||
PreLlmRequestResult::Continue => {}
|
||||
}
|
||||
|
||||
// リクエスト構築
|
||||
|
|
@ -849,7 +933,8 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)),
|
||||
ToolExecutionResult::Completed(results) => {
|
||||
for result in results {
|
||||
self.history.push(Message::tool_result(&result.tool_use_id, &result.content));
|
||||
self.history
|
||||
.push(Message::tool_result(&result.tool_use_id, &result.content));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -885,11 +970,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
text_block_collector,
|
||||
tool_call_collector,
|
||||
tools: HashMap::new(),
|
||||
hooks_on_message_send: Vec::new(),
|
||||
hooks_before_tool_call: Vec::new(),
|
||||
hooks_after_tool_call: Vec::new(),
|
||||
hooks_on_turn_end: Vec::new(),
|
||||
hooks_on_abort: Vec::new(),
|
||||
hooks: HookRegistry::new(),
|
||||
system_prompt: None,
|
||||
history: Vec::new(),
|
||||
locked_prefix_len: 0,
|
||||
|
|
@ -1058,11 +1139,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
text_block_collector: self.text_block_collector,
|
||||
tool_call_collector: self.tool_call_collector,
|
||||
tools: self.tools,
|
||||
hooks_on_message_send: self.hooks_on_message_send,
|
||||
hooks_before_tool_call: self.hooks_before_tool_call,
|
||||
hooks_after_tool_call: self.hooks_after_tool_call,
|
||||
hooks_on_turn_end: self.hooks_on_turn_end,
|
||||
hooks_on_abort: self.hooks_on_abort,
|
||||
hooks: self.hooks,
|
||||
system_prompt: self.system_prompt,
|
||||
history: self.history,
|
||||
locked_prefix_len,
|
||||
|
|
@ -1081,8 +1158,21 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
///
|
||||
/// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は
|
||||
/// `lock()` を呼んでからLocked状態で `run` を使用すること。
|
||||
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<WorkerResult<'_>, WorkerError> {
|
||||
self.history.push(Message::user(user_input));
|
||||
pub async fn run(
|
||||
&mut self,
|
||||
user_input: impl Into<String>,
|
||||
) -> Result<WorkerResult<'_>, WorkerError> {
|
||||
// Hook: on_prompt_submit
|
||||
let mut user_message = Message::user(user_input);
|
||||
let result = self.run_on_prompt_submit_hooks(&mut user_message).await?;
|
||||
match result {
|
||||
OnPromptSubmitResult::Cancel(reason) => {
|
||||
self.run_on_abort_hooks(&reason).await?;
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
OnPromptSubmitResult::Continue => {}
|
||||
}
|
||||
self.history.push(user_message);
|
||||
self.run_turn_loop().await
|
||||
}
|
||||
|
||||
|
|
@ -1107,8 +1197,21 @@ impl<C: LlmClient> Worker<C, Locked> {
|
|||
///
|
||||
/// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。
|
||||
/// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。
|
||||
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<WorkerResult<'_>, WorkerError> {
|
||||
self.history.push(Message::user(user_input));
|
||||
pub async fn run(
|
||||
&mut self,
|
||||
user_input: impl Into<String>,
|
||||
) -> Result<WorkerResult<'_>, WorkerError> {
|
||||
// Hook: on_prompt_submit
|
||||
let mut user_message = Message::user(user_input);
|
||||
let result = self.run_on_prompt_submit_hooks(&mut user_message).await?;
|
||||
match result {
|
||||
OnPromptSubmitResult::Cancel(reason) => {
|
||||
self.run_on_abort_hooks(&reason).await?;
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
OnPromptSubmitResult::Continue => {}
|
||||
}
|
||||
self.history.push(user_message);
|
||||
self.run_turn_loop().await
|
||||
}
|
||||
|
||||
|
|
@ -1137,11 +1240,7 @@ impl<C: LlmClient> Worker<C, Locked> {
|
|||
text_block_collector: self.text_block_collector,
|
||||
tool_call_collector: self.tool_call_collector,
|
||||
tools: self.tools,
|
||||
hooks_on_message_send: self.hooks_on_message_send,
|
||||
hooks_before_tool_call: self.hooks_before_tool_call,
|
||||
hooks_after_tool_call: self.hooks_after_tool_call,
|
||||
hooks_on_turn_end: self.hooks_on_turn_end,
|
||||
hooks_on_abort: self.hooks_on_abort,
|
||||
hooks: self.hooks,
|
||||
system_prompt: self.system_prompt,
|
||||
history: self.history,
|
||||
locked_prefix_len: 0,
|
||||
|
|
|
|||
|
|
@ -9,11 +9,11 @@ use std::time::{Duration, Instant};
|
|||
use async_trait::async_trait;
|
||||
use llm_worker::Worker;
|
||||
use llm_worker::hook::{
|
||||
AfterToolCall, AfterToolCallResult, BeforeToolCall, BeforeToolCallResult, Hook, HookError,
|
||||
ToolCall, ToolResult,
|
||||
Hook, HookError, PostToolCall, PostToolCallContext, PostToolCallResult, PreToolCall,
|
||||
PreToolCallResult, ToolCallContext,
|
||||
};
|
||||
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
|
||||
use llm_worker::tool::{Tool, ToolError};
|
||||
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
||||
|
||||
mod common;
|
||||
use common::MockLlmClient;
|
||||
|
|
@ -42,25 +42,24 @@ impl SlowTool {
|
|||
fn call_count(&self) -> usize {
|
||||
self.call_count.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// ToolDefinition を作成
|
||||
fn definition(&self) -> ToolDefinition {
|
||||
let tool = self.clone();
|
||||
Arc::new(move || {
|
||||
let meta = ToolMeta::new(&tool.name)
|
||||
.description("A tool that waits before responding")
|
||||
.input_schema(serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}));
|
||||
(meta, Arc::new(tool.clone()) as Arc<dyn Tool>)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SlowTool {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"A tool that waits before responding"
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, _input_json: &str) -> Result<String, ToolError> {
|
||||
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||
tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
|
||||
|
|
@ -106,9 +105,9 @@ async fn test_parallel_tool_execution() {
|
|||
let tool2_clone = tool2.clone();
|
||||
let tool3_clone = tool3.clone();
|
||||
|
||||
worker.register_tool(tool1);
|
||||
worker.register_tool(tool2);
|
||||
worker.register_tool(tool3);
|
||||
worker.register_tool(tool1.definition()).unwrap();
|
||||
worker.register_tool(tool2.definition()).unwrap();
|
||||
worker.register_tool(tool3.definition()).unwrap();
|
||||
|
||||
let start = Instant::now();
|
||||
let _result = worker.run("Run all tools").await;
|
||||
|
|
@ -130,7 +129,7 @@ async fn test_parallel_tool_execution() {
|
|||
println!("Parallel execution completed in {:?}", elapsed);
|
||||
}
|
||||
|
||||
/// Hook: before_tool_call でスキップされたツールは実行されないことを確認
|
||||
/// Hook: pre_tool_call でスキップされたツールは実行されないことを確認
|
||||
#[tokio::test]
|
||||
async fn test_before_tool_call_skip() {
|
||||
let events = vec![
|
||||
|
|
@ -154,24 +153,24 @@ async fn test_before_tool_call_skip() {
|
|||
let allowed_clone = allowed_tool.clone();
|
||||
let blocked_clone = blocked_tool.clone();
|
||||
|
||||
worker.register_tool(allowed_tool);
|
||||
worker.register_tool(blocked_tool);
|
||||
worker.register_tool(allowed_tool.definition()).unwrap();
|
||||
worker.register_tool(blocked_tool.definition()).unwrap();
|
||||
|
||||
// "blocked_tool" をスキップするHook
|
||||
struct BlockingHook;
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<BeforeToolCall> for BlockingHook {
|
||||
async fn call(&self, tool_call: &mut ToolCall) -> Result<BeforeToolCallResult, HookError> {
|
||||
if tool_call.name == "blocked_tool" {
|
||||
Ok(BeforeToolCallResult::Skip)
|
||||
impl Hook<PreToolCall> for BlockingHook {
|
||||
async fn call(&self, ctx: &mut ToolCallContext) -> Result<PreToolCallResult, HookError> {
|
||||
if ctx.call.name == "blocked_tool" {
|
||||
Ok(PreToolCallResult::Skip)
|
||||
} else {
|
||||
Ok(BeforeToolCallResult::Continue)
|
||||
Ok(PreToolCallResult::Continue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
worker.add_before_tool_call_hook(BlockingHook);
|
||||
worker.add_pre_tool_call_hook(BlockingHook);
|
||||
|
||||
let _result = worker.run("Test hook").await;
|
||||
|
||||
|
|
@ -188,9 +187,9 @@ async fn test_before_tool_call_skip() {
|
|||
);
|
||||
}
|
||||
|
||||
/// Hook: after_tool_call で結果が改変されることを確認
|
||||
/// Hook: post_tool_call で結果が改変されることを確認
|
||||
#[tokio::test]
|
||||
async fn test_after_tool_call_modification() {
|
||||
async fn test_post_tool_call_modification() {
|
||||
// 複数リクエストに対応するレスポンスを準備
|
||||
let client = MockLlmClient::with_responses(vec![
|
||||
// 1回目のリクエスト: ツール呼び出し
|
||||
|
|
@ -220,21 +219,21 @@ async fn test_after_tool_call_modification() {
|
|||
|
||||
#[async_trait]
|
||||
impl Tool for SimpleTool {
|
||||
fn name(&self) -> &str {
|
||||
"test_tool"
|
||||
}
|
||||
fn description(&self) -> &str {
|
||||
"Test"
|
||||
}
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({})
|
||||
}
|
||||
async fn execute(&self, _: &str) -> Result<String, ToolError> {
|
||||
Ok("Original Result".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
worker.register_tool(SimpleTool);
|
||||
fn simple_tool_definition() -> ToolDefinition {
|
||||
Arc::new(|| {
|
||||
let meta = ToolMeta::new("test_tool")
|
||||
.description("Test")
|
||||
.input_schema(serde_json::json!({}));
|
||||
(meta, Arc::new(SimpleTool) as Arc<dyn Tool>)
|
||||
})
|
||||
}
|
||||
|
||||
worker.register_tool(simple_tool_definition()).unwrap();
|
||||
|
||||
// 結果を改変するHook
|
||||
struct ModifyingHook {
|
||||
|
|
@ -242,19 +241,19 @@ async fn test_after_tool_call_modification() {
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hook<AfterToolCall> for ModifyingHook {
|
||||
impl Hook<PostToolCall> for ModifyingHook {
|
||||
async fn call(
|
||||
&self,
|
||||
tool_result: &mut ToolResult,
|
||||
) -> Result<AfterToolCallResult, HookError> {
|
||||
tool_result.content = format!("[Modified] {}", tool_result.content);
|
||||
*self.modified_content.lock().unwrap() = Some(tool_result.content.clone());
|
||||
Ok(AfterToolCallResult::Continue)
|
||||
ctx: &mut PostToolCallContext,
|
||||
) -> Result<PostToolCallResult, HookError> {
|
||||
ctx.result.content = format!("[Modified] {}", ctx.result.content);
|
||||
*self.modified_content.lock().unwrap() = Some(ctx.result.content.clone());
|
||||
Ok(PostToolCallResult::Continue)
|
||||
}
|
||||
}
|
||||
|
||||
let modified_content = Arc::new(std::sync::Mutex::new(None));
|
||||
worker.add_after_tool_call_hook(ModifyingHook {
|
||||
worker.add_post_tool_call_hook(ModifyingHook {
|
||||
modified_content: modified_content.clone(),
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
|
|||
use schemars;
|
||||
use serde;
|
||||
|
||||
use llm_worker::tool::Tool;
|
||||
use llm_worker::tool::{Tool, ToolMeta};
|
||||
use llm_worker_macros::tool_registry;
|
||||
|
||||
// =============================================================================
|
||||
|
|
@ -51,30 +51,31 @@ async fn test_basic_tool_generation() {
|
|||
prefix: "Hello".to_string(),
|
||||
};
|
||||
|
||||
// ファクトリメソッドでツールを取得
|
||||
let greet_tool = ctx.greet_tool();
|
||||
// ファクトリメソッドでToolDefinitionを取得
|
||||
let greet_definition = ctx.greet_definition();
|
||||
|
||||
// 名前の確認
|
||||
assert_eq!(greet_tool.name(), "greet");
|
||||
// ファクトリを呼び出してMetaとToolを取得
|
||||
let (meta, tool) = greet_definition();
|
||||
|
||||
// 説明の確認(docコメントから取得)
|
||||
let desc = greet_tool.description();
|
||||
// メタ情報の確認
|
||||
assert_eq!(meta.name, "greet");
|
||||
assert!(
|
||||
desc.contains("メッセージに挨拶を追加する"),
|
||||
meta.description.contains("メッセージに挨拶を追加する"),
|
||||
"Description should contain doc comment: {}",
|
||||
desc
|
||||
meta.description
|
||||
);
|
||||
|
||||
// スキーマの確認
|
||||
let schema = greet_tool.input_schema();
|
||||
println!("Schema: {}", serde_json::to_string_pretty(&schema).unwrap());
|
||||
assert!(
|
||||
schema.get("properties").is_some(),
|
||||
meta.input_schema.get("properties").is_some(),
|
||||
"Schema should have properties"
|
||||
);
|
||||
|
||||
println!(
|
||||
"Schema: {}",
|
||||
serde_json::to_string_pretty(&meta.input_schema).unwrap()
|
||||
);
|
||||
|
||||
// 実行テスト
|
||||
let result = greet_tool.execute(r#"{"message": "World"}"#).await;
|
||||
let result = tool.execute(r#"{"message": "World"}"#).await;
|
||||
assert!(result.is_ok(), "Should execute successfully");
|
||||
let output = result.unwrap();
|
||||
assert!(output.contains("Hello"), "Output should contain prefix");
|
||||
|
|
@ -87,11 +88,11 @@ async fn test_multiple_arguments() {
|
|||
prefix: "".to_string(),
|
||||
};
|
||||
|
||||
let add_tool = ctx.add_tool();
|
||||
let (meta, tool) = ctx.add_definition()();
|
||||
|
||||
assert_eq!(add_tool.name(), "add");
|
||||
assert_eq!(meta.name, "add");
|
||||
|
||||
let result = add_tool.execute(r#"{"a": 10, "b": 20}"#).await;
|
||||
let result = tool.execute(r#"{"a": 10, "b": 20}"#).await;
|
||||
assert!(result.is_ok());
|
||||
let output = result.unwrap();
|
||||
assert!(output.contains("30"), "Should contain sum: {}", output);
|
||||
|
|
@ -103,12 +104,12 @@ async fn test_no_arguments() {
|
|||
prefix: "TestPrefix".to_string(),
|
||||
};
|
||||
|
||||
let get_prefix_tool = ctx.get_prefix_tool();
|
||||
let (meta, tool) = ctx.get_prefix_definition()();
|
||||
|
||||
assert_eq!(get_prefix_tool.name(), "get_prefix");
|
||||
assert_eq!(meta.name, "get_prefix");
|
||||
|
||||
// 空のJSONオブジェクトで呼び出し
|
||||
let result = get_prefix_tool.execute(r#"{}"#).await;
|
||||
let result = tool.execute(r#"{}"#).await;
|
||||
assert!(result.is_ok());
|
||||
let output = result.unwrap();
|
||||
assert!(
|
||||
|
|
@ -124,10 +125,10 @@ async fn test_invalid_arguments() {
|
|||
prefix: "".to_string(),
|
||||
};
|
||||
|
||||
let greet_tool = ctx.greet_tool();
|
||||
let (_, tool) = ctx.greet_definition()();
|
||||
|
||||
// 不正なJSON
|
||||
let result = greet_tool.execute(r#"{"wrong_field": "value"}"#).await;
|
||||
let result = tool.execute(r#"{"wrong_field": "value"}"#).await;
|
||||
assert!(result.is_err(), "Should fail with invalid arguments");
|
||||
}
|
||||
|
||||
|
|
@ -163,9 +164,9 @@ impl FallibleContext {
|
|||
#[tokio::test]
|
||||
async fn test_result_return_type_success() {
|
||||
let ctx = FallibleContext;
|
||||
let validate_tool = ctx.validate_tool();
|
||||
let (_, tool) = ctx.validate_definition()();
|
||||
|
||||
let result = validate_tool.execute(r#"{"value": 42}"#).await;
|
||||
let result = tool.execute(r#"{"value": 42}"#).await;
|
||||
assert!(result.is_ok(), "Should succeed for positive value");
|
||||
let output = result.unwrap();
|
||||
assert!(output.contains("Valid"), "Should contain Valid: {}", output);
|
||||
|
|
@ -174,9 +175,9 @@ async fn test_result_return_type_success() {
|
|||
#[tokio::test]
|
||||
async fn test_result_return_type_error() {
|
||||
let ctx = FallibleContext;
|
||||
let validate_tool = ctx.validate_tool();
|
||||
let (_, tool) = ctx.validate_definition()();
|
||||
|
||||
let result = validate_tool.execute(r#"{"value": -1}"#).await;
|
||||
let result = tool.execute(r#"{"value": -1}"#).await;
|
||||
assert!(result.is_err(), "Should fail for negative value");
|
||||
|
||||
let err = result.unwrap_err();
|
||||
|
|
@ -211,12 +212,12 @@ async fn test_sync_method() {
|
|||
counter: Arc::new(AtomicUsize::new(0)),
|
||||
};
|
||||
|
||||
let increment_tool = ctx.increment_tool();
|
||||
let (_, tool) = ctx.increment_definition()();
|
||||
|
||||
// 3回実行
|
||||
let result1 = increment_tool.execute(r#"{}"#).await;
|
||||
let result2 = increment_tool.execute(r#"{}"#).await;
|
||||
let result3 = increment_tool.execute(r#"{}"#).await;
|
||||
let result1 = tool.execute(r#"{}"#).await;
|
||||
let result2 = tool.execute(r#"{}"#).await;
|
||||
let result3 = tool.execute(r#"{}"#).await;
|
||||
|
||||
assert!(result1.is_ok());
|
||||
assert!(result2.is_ok());
|
||||
|
|
@ -225,3 +226,22 @@ async fn test_sync_method() {
|
|||
// カウンターは3になっているはず
|
||||
assert_eq!(ctx.counter.load(Ordering::SeqCst), 3);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Test: ToolMeta Immutability
|
||||
// =============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tool_meta_immutability() {
|
||||
let ctx = SimpleContext {
|
||||
prefix: "Test".to_string(),
|
||||
};
|
||||
|
||||
// 2回取得しても同じメタ情報が得られることを確認
|
||||
let (meta1, _) = ctx.greet_definition()();
|
||||
let (meta2, _) = ctx.greet_definition()();
|
||||
|
||||
assert_eq!(meta1.name, meta2.name);
|
||||
assert_eq!(meta1.description, meta2.description);
|
||||
assert_eq!(meta1.input_schema, meta2.input_schema);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
use llm_worker::llm_client::LlmClient;
|
||||
use llm_worker::llm_client::providers::openai::OpenAIClient;
|
||||
use llm_worker::{Worker, WorkerError};
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
|
|||
use async_trait::async_trait;
|
||||
use common::MockLlmClient;
|
||||
use llm_worker::Worker;
|
||||
use llm_worker::tool::{Tool, ToolError};
|
||||
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
||||
|
||||
/// フィクスチャディレクトリのパス
|
||||
fn fixtures_dir() -> std::path::PathBuf {
|
||||
|
|
@ -35,20 +35,13 @@ impl MockWeatherTool {
|
|||
fn get_call_count(&self) -> usize {
|
||||
self.call_count.load(Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MockWeatherTool {
|
||||
fn name(&self) -> &str {
|
||||
"get_weather"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Get the current weather for a city"
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
fn definition(&self) -> ToolDefinition {
|
||||
let tool = self.clone();
|
||||
Arc::new(move || {
|
||||
let meta = ToolMeta::new("get_weather")
|
||||
.description("Get the current weather for a city")
|
||||
.input_schema(serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
|
|
@ -57,9 +50,14 @@ impl Tool for MockWeatherTool {
|
|||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}));
|
||||
(meta, Arc::new(tool.clone()) as Arc<dyn Tool>)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MockWeatherTool {
|
||||
async fn execute(&self, input_json: &str) -> Result<String, ToolError> {
|
||||
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
|
|
@ -158,7 +156,7 @@ async fn test_worker_tool_call() {
|
|||
// ツールを登録
|
||||
let weather_tool = MockWeatherTool::new();
|
||||
let tool_for_check = weather_tool.clone();
|
||||
worker.register_tool(weather_tool);
|
||||
worker.register_tool(weather_tool.definition()).unwrap();
|
||||
|
||||
// メッセージを送信
|
||||
let _result = worker.run("What's the weather in Tokyo?").await;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user