Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cba96e4f46 | |||
| d63b4ea470 | |||
| b12785ed93 | |||
| 3c62970967 | |||
| 6c43ac9969 | |||
| 6d87da90d1 |
|
|
@ -1,6 +1,7 @@
|
||||||
# llm-worker-rs 開発instruction
|
# llm-worker-rs Development Instructions
|
||||||
|
|
||||||
## パッケージ管理ルール
|
## Package Management Rules
|
||||||
|
|
||||||
- クレートに依存関係を追加・更新する際は必ず
|
- When adding or updating crate dependencies, always use the `cargo` command. Do
|
||||||
`cargo`コマンドを使い、`Cargo.toml`を直接手で書き換えず、必ずコマンド経由で管理すること。
|
not manually edit `Cargo.toml` directly; always manage dependencies via
|
||||||
|
commands.
|
||||||
|
|
|
||||||
102
Cargo.lock
generated
102
Cargo.lock
generated
|
|
@ -384,6 +384,12 @@ dependencies = [
|
||||||
"wasip2",
|
"wasip2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "glob"
|
||||||
|
version = "0.3.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "h2"
|
name = "h2"
|
||||||
version = "0.4.13"
|
version = "0.4.13"
|
||||||
|
|
@ -708,7 +714,7 @@ checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "llm-worker"
|
name = "llm-worker"
|
||||||
version = "0.2.0"
|
version = "0.2.1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"clap",
|
"clap",
|
||||||
|
|
@ -726,6 +732,7 @@ dependencies = [
|
||||||
"tokio-util",
|
"tokio-util",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
|
"trybuild",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -1178,6 +1185,15 @@ dependencies = [
|
||||||
"zmij",
|
"zmij",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_spanned"
|
||||||
|
version = "1.0.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776"
|
||||||
|
dependencies = [
|
||||||
|
"serde_core",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "sharded-slab"
|
name = "sharded-slab"
|
||||||
version = "0.1.7"
|
version = "0.1.7"
|
||||||
|
|
@ -1264,6 +1280,12 @@ dependencies = [
|
||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "target-triple"
|
||||||
|
version = "1.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "591ef38edfb78ca4771ee32cf494cb8771944bee237a9b91fc9c1424ac4b777b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tempfile"
|
name = "tempfile"
|
||||||
version = "3.24.0"
|
version = "3.24.0"
|
||||||
|
|
@ -1277,6 +1299,15 @@ dependencies = [
|
||||||
"windows-sys 0.61.2",
|
"windows-sys 0.61.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "termcolor"
|
||||||
|
version = "1.4.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
|
||||||
|
dependencies = [
|
||||||
|
"winapi-util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror"
|
name = "thiserror"
|
||||||
version = "2.0.17"
|
version = "2.0.17"
|
||||||
|
|
@ -1375,6 +1406,45 @@ dependencies = [
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "toml"
|
||||||
|
version = "1.0.3+spec-1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c7614eaf19ad818347db24addfa201729cf2a9b6fdfd9eb0ab870fcacc606c0c"
|
||||||
|
dependencies = [
|
||||||
|
"indexmap",
|
||||||
|
"serde_core",
|
||||||
|
"serde_spanned",
|
||||||
|
"toml_datetime",
|
||||||
|
"toml_parser",
|
||||||
|
"toml_writer",
|
||||||
|
"winnow",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "toml_datetime"
|
||||||
|
version = "1.0.0+spec-1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e"
|
||||||
|
dependencies = [
|
||||||
|
"serde_core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "toml_parser"
|
||||||
|
version = "1.0.9+spec-1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4"
|
||||||
|
dependencies = [
|
||||||
|
"winnow",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "toml_writer"
|
||||||
|
version = "1.0.6+spec-1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tower"
|
name = "tower"
|
||||||
version = "0.5.2"
|
version = "0.5.2"
|
||||||
|
|
@ -1487,6 +1557,21 @@ version = "0.2.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "trybuild"
|
||||||
|
version = "1.0.116"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "47c635f0191bd3a2941013e5062667100969f8c4e9cd787c14f977265d73616e"
|
||||||
|
dependencies = [
|
||||||
|
"glob",
|
||||||
|
"serde",
|
||||||
|
"serde_derive",
|
||||||
|
"serde_json",
|
||||||
|
"target-triple",
|
||||||
|
"termcolor",
|
||||||
|
"toml",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-ident"
|
name = "unicode-ident"
|
||||||
version = "1.0.22"
|
version = "1.0.22"
|
||||||
|
|
@ -1640,6 +1725,15 @@ dependencies = [
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi-util"
|
||||||
|
version = "0.1.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
|
||||||
|
dependencies = [
|
||||||
|
"windows-sys 0.61.2",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows-link"
|
name = "windows-link"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
|
|
@ -1802,6 +1896,12 @@ version = "0.53.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650"
|
checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winnow"
|
||||||
|
version = "0.7.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wit-bindgen"
|
name = "wit-bindgen"
|
||||||
version = "0.46.0"
|
version = "0.46.0"
|
||||||
|
|
|
||||||
83
docs/plan/worker_api_plan.md
Normal file
83
docs/plan/worker_api_plan.md
Normal file
|
|
@ -0,0 +1,83 @@
|
||||||
|
# Worker API/DSL 実装計画
|
||||||
|
|
||||||
|
## 目的
|
||||||
|
|
||||||
|
- [Open Responses](https://www.openresponses.org)(以後"OR")に準拠した正規化を前提に、
|
||||||
|
Item/Part の2段スコープを扱える Worker API を設計する。
|
||||||
|
- APIの煩雑化を防ぐため、worker.on_xxx として公開するのを避けつつ、
|
||||||
|
Text/Thinking/Tool など型の違いを静的に扱える DSL を提供する。
|
||||||
|
|
||||||
|
## 方針
|
||||||
|
|
||||||
|
- 内部は Timeline が Event を正規化し、Item/Part/Meta
|
||||||
|
を単一ストリームとして扱う。
|
||||||
|
- API では Item/Part 型ごとに ctx を持てるようにし、DSL
|
||||||
|
で記述の冗長さを削減する。
|
||||||
|
- まず macro_rules! 版を作り、必要なら proc-macro に拡張する。
|
||||||
|
- Item/Part の型パラメータはクレートが公開する Kind 型を使う。
|
||||||
|
|
||||||
|
## 仕様の前提
|
||||||
|
|
||||||
|
- Item は OR の item (message, function_call, reasoning など) に対応する。
|
||||||
|
- Part は OR の content part (output_text, reasoning_text など) に対応する。
|
||||||
|
- Item は必ず start/stop を持つ。Part は Item 内で複数発生し得る。
|
||||||
|
- Item/Part の型指定は `Item<Message>` / `Part<ReasoningText>` のように書く。
|
||||||
|
|
||||||
|
## 設計ステップ
|
||||||
|
|
||||||
|
### 1. 内部イベントモデルの整理
|
||||||
|
|
||||||
|
- Event を Item/Part/Meta の3層に整理する。
|
||||||
|
- ItemEvent / PartEvent は型パラメータで区別する。
|
||||||
|
- 例: ItemEvent<Message>, PartEvent<Message, OutputText>
|
||||||
|
|
||||||
|
### 2. スコープの二段化
|
||||||
|
|
||||||
|
- Item ctx: Item 型ごとに1つ
|
||||||
|
- Part ctx: Part 型ごとに1つ
|
||||||
|
- Part のイベントでは常に item ctx と part ctx の両方を渡す。
|
||||||
|
|
||||||
|
### 3. Handler trait の再定義
|
||||||
|
|
||||||
|
- Item/Part を型で指定できる trait を導入する。
|
||||||
|
- 例:
|
||||||
|
- trait ItemHandler<I>
|
||||||
|
- trait PartHandler<I, P>
|
||||||
|
- PartHandler には ItemHandler の ItemCtx を必須で渡す。
|
||||||
|
- Part の ctx 型は `PartKind::Ctx` 方式 or enum 方式で切り替える。
|
||||||
|
|
||||||
|
### 4. Timeline との結合
|
||||||
|
|
||||||
|
- Timeline は ItemStart で ItemCtx を生成
|
||||||
|
- PartStart で PartCtx を生成
|
||||||
|
- Delta/Stop は対応 ctx に流す
|
||||||
|
- ItemStop で ItemCtx を破棄
|
||||||
|
|
||||||
|
### 5. DSL (macro_rules!) の導入
|
||||||
|
|
||||||
|
- まず宣言的 DSL を提供する。
|
||||||
|
- 例:
|
||||||
|
- handler! { Item<Message> { type ItemCtx = ...; Part<OutputText> { type
|
||||||
|
PartCtx = ...; } } }
|
||||||
|
- DSL は ItemHandler / PartHandler 実装を生成する。
|
||||||
|
- Item/Part の Kind 型はクレートが公開する型を参照する。
|
||||||
|
|
||||||
|
### 6. 拡張ポイント
|
||||||
|
|
||||||
|
- 追加 Part (output_image など) を DSL に追加しやすい形にする。
|
||||||
|
- 必要なら proc-macro に移行して構文自由度を上げる。
|
||||||
|
|
||||||
|
## 実装順序
|
||||||
|
|
||||||
|
1. Event/Item/Part の型定義の整理
|
||||||
|
2. Item/Part ctx を持つ Timeline 実装
|
||||||
|
3. Handler trait の定義・既存コードの移行
|
||||||
|
4. macro_rules! DSL の実装
|
||||||
|
5. 既存ユースケースの移植
|
||||||
|
|
||||||
|
## TODO
|
||||||
|
|
||||||
|
- Item と Part の型対応表を整理する
|
||||||
|
- OR と既存 llm_client の差分を再確認する
|
||||||
|
- Tool args の delta を OR 拡張として扱うか検討する
|
||||||
|
- macro_rules! で表現可能な DSL の最小文法を確定する
|
||||||
80
docs/research/openresponses_mapping.md
Normal file
80
docs/research/openresponses_mapping.md
Normal file
|
|
@ -0,0 +1,80 @@
|
||||||
|
# Open Responses mapping (llm_client -> Open Responses)
|
||||||
|
|
||||||
|
This document maps the current `llm_client` event model to Open Responses items
|
||||||
|
and streaming events. It focuses on output streaming; input items are noted
|
||||||
|
where they are the closest semantic match.
|
||||||
|
|
||||||
|
## Legend
|
||||||
|
|
||||||
|
- **OR item**: Open Responses item types used in `response.output`.
|
||||||
|
- **OR event**: Open Responses streaming events (`response.*`).
|
||||||
|
- **Note**: Gaps or required adaptation decisions.
|
||||||
|
|
||||||
|
## Response lifecycle / meta events
|
||||||
|
|
||||||
|
| llm_client | Open Responses | Note |
|
||||||
|
| ------------------------ | ------------------------------------------------------------- | ---------------------------------------------------------------------------------------------- |
|
||||||
|
| `StatusEvent::Started` | `response.created`, `response.queued`, `response.in_progress` | OR has finer-grained lifecycle states; pick a subset or map Started -> `response.in_progress`. |
|
||||||
|
| `StatusEvent::Completed` | `response.completed` | |
|
||||||
|
| `StatusEvent::Failed` | `response.failed` | |
|
||||||
|
| `StatusEvent::Cancelled` | (no direct event) | Could map to `response.incomplete` or `response.failed` depending on semantics. |
|
||||||
|
| `UsageEvent` | `response.completed` payload usage | OR reports usage on the response object, not as a dedicated streaming event. |
|
||||||
|
| `ErrorEvent` | `error` event | OR has a dedicated error streaming event. |
|
||||||
|
| `PingEvent` | (no direct event) | OR does not define a heartbeat event. |
|
||||||
|
|
||||||
|
## Output block lifecycle
|
||||||
|
|
||||||
|
### Text block
|
||||||
|
|
||||||
|
| llm_client | Open Responses | Note |
|
||||||
|
| ------------------------------------------------- | ---------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------- |
|
||||||
|
| `BlockStart { block_type: Text, metadata: Text }` | `response.output_item.added` with item type `message` (assistant) | OR output items are message/function_call/reasoning. This creates the message item. |
|
||||||
|
| `BlockDelta { delta: Text(..) }` | `response.output_text.delta` | Text deltas map 1:1 to output text deltas. |
|
||||||
|
| `BlockStop { block_type: Text }` | `response.output_text.done` + `response.content_part.done` + `response.output_item.done` | OR emits separate done events for content parts and items. |
|
||||||
|
|
||||||
|
### Tool use (function call)
|
||||||
|
|
||||||
|
| llm_client | Open Responses | Note |
|
||||||
|
| -------------------------------------------------------------------- | --------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------- |
|
||||||
|
| `BlockStart { block_type: ToolUse, metadata: ToolUse { id, name } }` | `response.output_item.added` with item type `function_call` | OR uses `call_id` + `name` + `arguments` string. Map `id` -> `call_id`. |
|
||||||
|
| `BlockDelta { delta: InputJson(..) }` | `response.function_call_arguments.delta` | OR spec does not explicitly require argument deltas; treat as OpenAI-compatible extension if adopted. |
|
||||||
|
| `BlockStop { block_type: ToolUse }` | `response.function_call_arguments.done` + `response.output_item.done` | Item status can be set to `completed` or `incomplete`. |
|
||||||
|
|
||||||
|
### Tool result (function call output)
|
||||||
|
|
||||||
|
| llm_client | Open Responses | Note |
|
||||||
|
| ----------------------------------------------------------------------------- | ------------------------------------- | ---------------------------------------------------------------------------------------- |
|
||||||
|
| `BlockStart { block_type: ToolResult, metadata: ToolResult { tool_use_id } }` | **Input item** `function_call_output` | OR treats tool results as input items, not output items. This is a request-side mapping. |
|
||||||
|
| `BlockDelta` | (no direct output event) | OR does not stream tool output deltas as response events. |
|
||||||
|
| `BlockStop` | (no direct output event) | Tool output lives on the next request as an input item. |
|
||||||
|
|
||||||
|
### Thinking / reasoning
|
||||||
|
|
||||||
|
| llm_client | Open Responses | Note |
|
||||||
|
| --------------------------------------------------------- | ------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `BlockStart { block_type: Thinking, metadata: Thinking }` | `response.output_item.added` with item type `reasoning` | OR models reasoning as a separate item type. |
|
||||||
|
| `BlockDelta { delta: Thinking(..) }` | `response.reasoning.delta` | OR has dedicated reasoning delta events. |
|
||||||
|
| `BlockStop { block_type: Thinking }` | `response.reasoning.done` | OR separates reasoning summary events (`response.reasoning_summary_*`) from reasoning deltas. Decide whether Thinking maps to full reasoning or summary only. |
|
||||||
|
|
||||||
|
## Stop reasons
|
||||||
|
|
||||||
|
| llm_client `StopReason` | Open Responses | Note |
|
||||||
|
| ----------------------- | ------------------------------------------------------------------------------ | ---------------------------------------------- |
|
||||||
|
| `EndTurn` | `response.completed` + item status `completed` | |
|
||||||
|
| `MaxTokens` | `response.incomplete` + item status `incomplete` | |
|
||||||
|
| `StopSequence` | `response.completed` | |
|
||||||
|
| `ToolUse` | `response.completed` for message item, followed by `function_call` output item | OR models tool call as a separate output item. |
|
||||||
|
|
||||||
|
## Gaps / open decisions
|
||||||
|
|
||||||
|
- `PingEvent` has no OR equivalent. If needed, keep as internal only.
|
||||||
|
- `Cancelled` status needs a policy: map to `response.incomplete` or
|
||||||
|
`response.failed`.
|
||||||
|
- OR has `response.refusal.delta` / `response.refusal.done`. `llm_client` has no
|
||||||
|
refusal delta type; consider adding a new block or delta variant if needed.
|
||||||
|
- OR splits _item_ and _content part_ lifecycles. `llm_client` currently has a
|
||||||
|
single block lifecycle, so mapping should decide whether to synthesize
|
||||||
|
`content_part.*` events or ignore them.
|
||||||
|
- The OR specification does not state how `function_call.arguments` stream
|
||||||
|
deltas; `response.function_call_arguments.*` should be treated as a compatible
|
||||||
|
extension if required.
|
||||||
|
|
@ -33,39 +33,46 @@ Workerは以下のループ(ターン)を実行します。
|
||||||
|
|
||||||
1. **Start Turn**: `Worker::run(messages)` 呼び出し
|
1. **Start Turn**: `Worker::run(messages)` 呼び出し
|
||||||
2. **Hook: OnMessageSend**:
|
2. **Hook: OnMessageSend**:
|
||||||
* ユーザーメッセージの改変、バリデーション、キャンセルが可能。
|
- ユーザーメッセージの改変、バリデーション、キャンセルが可能。
|
||||||
* コンテキストへのシステムプロンプト注入などもここで行う想定。
|
- コンテキストへのシステムプロンプト注入などもここで行う想定。
|
||||||
3. **Request & Stream**:
|
3. **Request & Stream**:
|
||||||
* LLMへリクエスト送信。イベントストリーム開始。
|
- LLMへリクエスト送信。イベントストリーム開始。
|
||||||
* `Timeline`によるイベント処理。
|
- `Timeline`によるイベント処理。
|
||||||
4. **Tool Handling (Parallel)**:
|
4. **Tool Handling (Parallel)**:
|
||||||
* レスポンス内に含まれる全てのTool Callを収集。
|
- レスポンス内に含まれる全てのTool Callを収集。
|
||||||
* 各Toolに対して **Hook: BeforeToolCall** を実行(実行可否、引数改変)。
|
- 各Toolに対して **Hook: BeforeToolCall** を実行(実行可否、引数改変)。
|
||||||
* 許可されたToolを**並列実行 (`join_all`)**。
|
- 許可されたToolを**並列実行 (`join_all`)**。
|
||||||
* 各Tool実行後に **Hook: AfterToolCall** を実行(結果の確認、加工)。
|
- 各Tool実行後に **Hook: AfterToolCall** を実行(結果の確認、加工)。
|
||||||
5. **Next Request Decision**:
|
5. **Next Request Decision**:
|
||||||
* Tool実行結果がある場合 -> 結果をMessageとしてContextに追加し、**Step 3へ戻る** (自動ループ)。
|
- Tool実行結果がある場合 -> 結果をMessageとしてContextに追加し、**Step
|
||||||
* Tool実行がない場合 -> Step 6へ。
|
3へ戻る** (自動ループ)。
|
||||||
|
- Tool実行がない場合 -> Step 6へ。
|
||||||
6. **Hook: OnTurnEnd**:
|
6. **Hook: OnTurnEnd**:
|
||||||
* 最終的な応答に対するチェック(Lint/Fmt)。
|
- 最終的な応答に対するチェック(Lint/Fmt)。
|
||||||
* エラーがある場合、エラーメッセージをContextに追加して **Step 3へ戻る** ことで自己修正を促せる。
|
- エラーがある場合、エラーメッセージをContextに追加して **Step 3へ戻る**
|
||||||
* 問題なければターン終了。
|
ことで自己修正を促せる。
|
||||||
|
- 問題なければターン終了。
|
||||||
|
|
||||||
## Tool 設計
|
## Tool 設計
|
||||||
|
|
||||||
### アーキテクチャ概要
|
### アーキテクチャ概要
|
||||||
|
|
||||||
Rustの静的型付けシステムとLLMの動的なツール呼び出し(文字列による指定)を、**Trait Object** と **動的ディスパッチ** を用いて接続します。
|
Rustの静的型付けシステムとLLMの動的なツール呼び出し(文字列による指定)を、**Trait
|
||||||
|
Object** と **動的ディスパッチ** を用いて接続します。
|
||||||
|
|
||||||
1. **共通インターフェース (`Tool` Trait)**: 全てのツールが実装すべき共通の振る舞い(メタデータ取得と実行)を定義します。
|
1. **共通インターフェース (`Tool` Trait)**:
|
||||||
2. **ラッパー生成 (`#[tool]` Macro)**: ユーザー定義のメソッドをラップし、`Tool` Traitを実装した構造体を自動生成します。
|
全てのツールが実装すべき共通の振る舞い(メタデータ取得と実行)を定義します。
|
||||||
3. **レジストリ (`HashMap`)**: Workerは動的ディスパッチ用に `HashMap<String, Box<dyn Tool>>` でツールを管理します。
|
2. **ラッパー生成 (`#[tool]` Macro)**: ユーザー定義のメソッドをラップし、`Tool`
|
||||||
|
Traitを実装した構造体を自動生成します。
|
||||||
|
3. **レジストリ (`HashMap`)**: Workerは動的ディスパッチ用に
|
||||||
|
`HashMap<String, Box<dyn Tool>>` でツールを管理します。
|
||||||
|
|
||||||
この仕組みにより、「名前からツールを探し、JSON引数を型変換して関数を実行する」フローを安全に実現します。
|
この仕組みにより、「名前からツールを探し、JSON引数を型変換して関数を実行する」フローを安全に実現します。
|
||||||
|
|
||||||
### 1. Tool Trait 定義
|
### 1. Tool Trait 定義
|
||||||
|
|
||||||
ツールが最低限持つべきインターフェースです。`Send + Sync` を必須とし、マルチスレッド(並列実行)に対応します。
|
ツールが最低限持つべきインターフェースです。`Send + Sync`
|
||||||
|
を必須とし、マルチスレッド(並列実行)に対応します。
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
@ -113,7 +120,8 @@ impl MyApp {
|
||||||
|
|
||||||
**マクロ展開後のイメージ (擬似コード):**
|
**マクロ展開後のイメージ (擬似コード):**
|
||||||
|
|
||||||
マクロは、元のメソッドに対応する**ラッパー構造体**を生成します。このラッパーが `Tool` Trait を実装します。
|
マクロは、元のメソッドに対応する**ラッパー構造体**を生成します。このラッパーが
|
||||||
|
`Tool` Trait を実装します。
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
// 1. 引数をデシリアライズ用の中間構造体に変換
|
// 1. 引数をデシリアライズ用の中間構造体に変換
|
||||||
|
|
@ -155,15 +163,18 @@ impl Tool for GetUserTool {
|
||||||
|
|
||||||
### 3. Workerによる実行フロー
|
### 3. Workerによる実行フロー
|
||||||
|
|
||||||
Workerは生成されたラッパー構造体を `Box<dyn Tool>` として保持し、以下のフローで実行します。
|
Workerは生成されたラッパー構造体を `Box<dyn Tool>`
|
||||||
|
として保持し、以下のフローで実行します。
|
||||||
|
|
||||||
1. **登録**: アプリケーション開始時、コンテキスト(`MyApp`)から各ツールのラッパー(`GetUserTool`)を生成し、WorkerのMapに登録。
|
1. **登録**:
|
||||||
2. **解決**: LLMからのレスポンスに含まれる `ToolUse { name: "get_user", ... }` を受け取る。
|
アプリケーション開始時、コンテキスト(`MyApp`)から各ツールのラッパー(`GetUserTool`)を生成し、WorkerのMapに登録。
|
||||||
|
2. **解決**: LLMからのレスポンスに含まれる `ToolUse { name: "get_user", ... }`
|
||||||
|
を受け取る。
|
||||||
3. **検索**: `name` をキーに Map から `Box<dyn Tool>` を取得。
|
3. **検索**: `name` をキーに Map から `Box<dyn Tool>` を取得。
|
||||||
4. **実行**:
|
4. **実行**:
|
||||||
* `tool.execute(json)` を呼び出す。
|
- `tool.execute(json)` を呼び出す。
|
||||||
* 内部で `serde_json` による型変換とメソッド実行が行われる。
|
- 内部で `serde_json` による型変換とメソッド実行が行われる。
|
||||||
* 結果が返る。
|
- 結果が返る。
|
||||||
|
|
||||||
これにより、型安全性を保ちつつ、動的なツール実行が可能になります。
|
これにより、型安全性を保ちつつ、動的なツール実行が可能になります。
|
||||||
|
|
||||||
|
|
@ -171,8 +182,9 @@ Workerは生成されたラッパー構造体を `Box<dyn Tool>` として保持
|
||||||
|
|
||||||
### コンセプト
|
### コンセプト
|
||||||
|
|
||||||
* **制御の介入**: ターンの進行、メッセージの内容、ツールの実行に対して介入します。
|
- **制御の介入**:
|
||||||
* **Contextへのアクセス**: メッセージ履歴(Context)を読み書きできます。
|
ターンの進行、メッセージの内容、ツールの実行に対して介入します。
|
||||||
|
- **Contextへのアクセス**: メッセージ履歴(Context)を読み書きできます。
|
||||||
|
|
||||||
### Hook Trait
|
### Hook Trait
|
||||||
|
|
||||||
|
|
@ -219,7 +231,8 @@ pub enum OnTurnEndResult {
|
||||||
|
|
||||||
### Tool Call Context
|
### Tool Call Context
|
||||||
|
|
||||||
`before_tool_call` / `after_tool_call` は、ツール実行の文脈を含む入力を受け取る。
|
`before_tool_call` / `after_tool_call`
|
||||||
|
は、ツール実行の文脈を含む入力を受け取る。
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
pub struct ToolCallContext {
|
pub struct ToolCallContext {
|
||||||
|
|
@ -238,16 +251,19 @@ pub struct ToolResultContext {
|
||||||
## 実装方針
|
## 実装方針
|
||||||
|
|
||||||
1. **Worker Struct**:
|
1. **Worker Struct**:
|
||||||
* `Timeline`を所有。
|
- `Timeline`を所有。
|
||||||
* `Handler`として「ToolCallCollector」をTimelineに登録。
|
- `Handler`として「ToolCallCollector」をTimelineに登録。
|
||||||
* `stream`終了後に収集したToolCallを処理するロジックを持つ。
|
- `stream`終了後に収集したToolCallを処理するロジックを持つ。
|
||||||
|
- **履歴管理**: `set_history`, `with_messages`, `history_mut`
|
||||||
|
等を通じて、会話履歴の注入や編集を可能にする。
|
||||||
|
|
||||||
2. **Tool Executor Handler**:
|
2. **Tool Executor Handler**:
|
||||||
* Timeline上ではツール実行を行わず、あくまで「ToolCallブロックの収集」に徹する(Toolの実行は非同期かつ並列で、ストリーム終了後あるいはブロック確定後に行うため)。
|
- Timeline上ではツール実行を行わず、あくまで「ToolCallブロックの収集」に徹する(Toolの実行は非同期かつ並列で、ストリーム終了後あるいはブロック確定後に行うため)。
|
||||||
* ただし、リアルタイム性を重視する場合(ストリーミング中にToolを実行開始等)は将来的な拡張とするが、現状は「結果が揃うのを待って」という要件に従い、収集フェーズと実行フェーズを分ける。
|
- ただし、リアルタイム性を重視する場合(ストリーミング中にToolを実行開始等)は将来的な拡張とするが、現状は「結果が揃うのを待って」という要件に従い、収集フェーズと実行フェーズを分ける。
|
||||||
|
|
||||||
3. **worker-macros**:
|
3. **worker-macros**:
|
||||||
* `syn`, `quote` を用いて、関数定義から `Tool` トレイト実装と `InputSchema` (schemars利用) を生成。
|
- `syn`, `quote` を用いて、関数定義から `Tool` トレイト実装と `InputSchema`
|
||||||
|
(schemars利用) を生成。
|
||||||
|
|
||||||
## Worker Event API 設計
|
## Worker Event API 設計
|
||||||
|
|
||||||
|
|
@ -256,6 +272,7 @@ pub struct ToolResultContext {
|
||||||
Workerは内部でイベントを処理し結果を返しますが、UIへのストリーミング表示やリアルタイムフィードバックには、イベントを外部に公開する仕組みが必要です。
|
Workerは内部でイベントを処理し結果を返しますが、UIへのストリーミング表示やリアルタイムフィードバックには、イベントを外部に公開する仕組みが必要です。
|
||||||
|
|
||||||
**要件**:
|
**要件**:
|
||||||
|
|
||||||
1. テキストデルタをリアルタイムでUIに表示
|
1. テキストデルタをリアルタイムでUIに表示
|
||||||
2. ツール呼び出しの進行状況を表示
|
2. ツール呼び出しの進行状況を表示
|
||||||
3. ブロック完了時に累積結果を受け取る
|
3. ブロック完了時に累積結果を受け取る
|
||||||
|
|
@ -265,7 +282,7 @@ Workerは内部でイベントを処理し結果を返しますが、UIへのス
|
||||||
Worker APIは **Timeline層のHandler機構の薄いラッパー** として設計します。
|
Worker APIは **Timeline層のHandler機構の薄いラッパー** として設計します。
|
||||||
|
|
||||||
| 層 | 目的 | 提供するもの |
|
| 層 | 目的 | 提供するもの |
|
||||||
|---|------|-------------|
|
| ------------------------ | ------------------ | ---------------------------------- |
|
||||||
| **Handler (Timeline層)** | 内部実装、役割分離 | スコープ管理 + Deltaイベント |
|
| **Handler (Timeline層)** | 内部実装、役割分離 | スコープ管理 + Deltaイベント |
|
||||||
| **Worker Event API** | ユーザー向け利便性 | Handler露出 + Completeイベント追加 |
|
| **Worker Event API** | ユーザー向け利便性 | Handler露出 + Completeイベント追加 |
|
||||||
|
|
||||||
|
|
@ -448,7 +465,8 @@ impl<C: LlmClient> Worker<C> {
|
||||||
### 設計上のポイント
|
### 設計上のポイント
|
||||||
|
|
||||||
1. **Handlerの再利用**: 既存のHandler traitをそのまま活用
|
1. **Handlerの再利用**: 既存のHandler traitをそのまま活用
|
||||||
2. **スコープ管理の維持**: ブロックイベントはStart→Delta→Endのライフサイクルを保持
|
2. **スコープ管理の維持**:
|
||||||
|
ブロックイベントはStart→Delta→Endのライフサイクルを保持
|
||||||
3. **選択的購読**: on_*で必要なイベントだけ、またはSubscriberで一括
|
3. **選択的購読**: on_*で必要なイベントだけ、またはSubscriberで一括
|
||||||
4. **累積イベントの追加**: Worker層でComplete系イベントを追加提供
|
4. **累積イベントの追加**: Worker層でComplete系イベントを追加提供
|
||||||
5. **後方互換性**: 従来の`run()`も引き続き使用可能
|
5. **後方互換性**: 従来の`run()`も引き続き使用可能
|
||||||
|
|
|
||||||
|
|
@ -35,11 +35,11 @@
|
||||||
},
|
},
|
||||||
"nixpkgs": {
|
"nixpkgs": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1767116409,
|
"lastModified": 1771369470,
|
||||||
"narHash": "sha256-5vKw92l1GyTnjoLzEagJy5V5mDFck72LiQWZSOnSicw=",
|
"narHash": "sha256-0NBlEBKkN3lufyvFegY4TYv5mCNHbi5OmBDrzihbBMQ=",
|
||||||
"owner": "nixos",
|
"owner": "nixos",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "cad22e7d996aea55ecab064e84834289143e44a0",
|
"rev": "0182a361324364ae3f436a63005877674cf45efb",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
//! llm-worker-macros - Tool生成用手続きマクロ
|
//! llm-worker-macros - Procedural macros for Tool generation
|
||||||
//!
|
//!
|
||||||
//! `#[tool_registry]` と `#[tool]` マクロを提供し、
|
//! Provides `#[tool_registry]` and `#[tool]` macros to
|
||||||
//! ユーザー定義のメソッドから `Tool` トレイト実装を自動生成する。
|
//! automatically generate `Tool` trait implementations from user-defined methods.
|
||||||
|
|
||||||
use proc_macro::TokenStream;
|
use proc_macro::TokenStream;
|
||||||
use quote::{format_ident, quote};
|
use quote::{format_ident, quote};
|
||||||
|
|
@ -9,22 +9,22 @@ use syn::{
|
||||||
Attribute, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, ReturnType, Type, parse_macro_input,
|
Attribute, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, ReturnType, Type, parse_macro_input,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// `impl` ブロックに付与し、内部の `#[tool]` 属性がついたメソッドからツールを生成するマクロ。
|
/// Macro applied to an `impl` block that generates tools from methods marked with `#[tool]`.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
/// ```ignore
|
/// ```ignore
|
||||||
/// #[tool_registry]
|
/// #[tool_registry]
|
||||||
/// impl MyApp {
|
/// impl MyApp {
|
||||||
/// /// ユーザー情報を取得する
|
/// /// Get user information
|
||||||
/// /// 指定されたIDのユーザーをDBから検索します。
|
/// /// Retrieves a user from the database by their ID.
|
||||||
/// #[tool]
|
/// #[tool]
|
||||||
/// async fn get_user(&self, user_id: String) -> Result<User, Error> { ... }
|
/// async fn get_user(&self, user_id: String) -> Result<User, Error> { ... }
|
||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
///
|
///
|
||||||
/// これにより以下が生成されます:
|
/// This generates:
|
||||||
/// - `GetUserArgs` 構造体(引数用)
|
/// - `GetUserArgs` struct (for arguments)
|
||||||
/// - `Tool_get_user` 構造体(Toolラッパー)
|
/// - `Tool_get_user` struct (Tool wrapper)
|
||||||
/// - `impl Tool for Tool_get_user`
|
/// - `impl Tool for Tool_get_user`
|
||||||
/// - `impl MyApp { fn get_user_tool(&self) -> Tool_get_user }`
|
/// - `impl MyApp { fn get_user_tool(&self) -> Tool_get_user }`
|
||||||
#[proc_macro_attribute]
|
#[proc_macro_attribute]
|
||||||
|
|
@ -36,14 +36,14 @@ pub fn tool_registry(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||||
|
|
||||||
for item in &mut impl_block.items {
|
for item in &mut impl_block.items {
|
||||||
if let ImplItem::Fn(method) = item {
|
if let ImplItem::Fn(method) = item {
|
||||||
// #[tool] 属性を探す
|
// Look for #[tool] attribute
|
||||||
let mut is_tool = false;
|
let mut is_tool = false;
|
||||||
|
|
||||||
// 属性を走査してtoolがあるか確認し、削除する
|
// Iterate through attributes to check for tool and remove it
|
||||||
method.attrs.retain(|attr| {
|
method.attrs.retain(|attr| {
|
||||||
if attr.path().is_ident("tool") {
|
if attr.path().is_ident("tool") {
|
||||||
is_tool = true;
|
is_tool = true;
|
||||||
false // 属性を削除
|
false // Remove the attribute
|
||||||
} else {
|
} else {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
@ -65,7 +65,7 @@ pub fn tool_registry(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||||
TokenStream::from(expanded)
|
TokenStream::from(expanded)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ドキュメントコメントから説明文を抽出
|
/// Extract description from doc comments
|
||||||
fn extract_doc_comment(attrs: &[Attribute]) -> String {
|
fn extract_doc_comment(attrs: &[Attribute]) -> String {
|
||||||
let mut lines = Vec::new();
|
let mut lines = Vec::new();
|
||||||
|
|
||||||
|
|
@ -75,7 +75,7 @@ fn extract_doc_comment(attrs: &[Attribute]) -> String {
|
||||||
if let syn::Expr::Lit(expr_lit) = &meta.value {
|
if let syn::Expr::Lit(expr_lit) = &meta.value {
|
||||||
if let Lit::Str(lit_str) = &expr_lit.lit {
|
if let Lit::Str(lit_str) = &expr_lit.lit {
|
||||||
let line = lit_str.value();
|
let line = lit_str.value();
|
||||||
// 先頭の空白を1つだけ除去(/// の後のスペース)
|
// Remove only the leading space (after ///)
|
||||||
let trimmed = line.strip_prefix(' ').unwrap_or(&line);
|
let trimmed = line.strip_prefix(' ').unwrap_or(&line);
|
||||||
lines.push(trimmed.to_string());
|
lines.push(trimmed.to_string());
|
||||||
}
|
}
|
||||||
|
|
@ -87,7 +87,7 @@ fn extract_doc_comment(attrs: &[Attribute]) -> String {
|
||||||
lines.join("\n")
|
lines.join("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
/// #[description = "..."] 属性から説明を抽出
|
/// Extract description from #[description = "..."] attribute
|
||||||
fn extract_description_attr(attrs: &[syn::Attribute]) -> Option<String> {
|
fn extract_description_attr(attrs: &[syn::Attribute]) -> Option<String> {
|
||||||
for attr in attrs {
|
for attr in attrs {
|
||||||
if attr.path().is_ident("description") {
|
if attr.path().is_ident("description") {
|
||||||
|
|
@ -103,19 +103,19 @@ fn extract_description_attr(attrs: &[syn::Attribute]) -> Option<String> {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
/// メソッドからTool実装を生成
|
/// Generate Tool implementation from a method
|
||||||
fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::TokenStream {
|
fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::TokenStream {
|
||||||
let sig = &method.sig;
|
let sig = &method.sig;
|
||||||
let method_name = &sig.ident;
|
let method_name = &sig.ident;
|
||||||
let tool_name = method_name.to_string();
|
let tool_name = method_name.to_string();
|
||||||
|
|
||||||
// 構造体名を生成(PascalCase変換)
|
// Generate struct names (convert to PascalCase)
|
||||||
let pascal_name = to_pascal_case(&method_name.to_string());
|
let pascal_name = to_pascal_case(&method_name.to_string());
|
||||||
let tool_struct_name = format_ident!("Tool{}", pascal_name);
|
let tool_struct_name = format_ident!("Tool{}", pascal_name);
|
||||||
let args_struct_name = format_ident!("{}Args", pascal_name);
|
let args_struct_name = format_ident!("{}Args", pascal_name);
|
||||||
let definition_name = format_ident!("{}_definition", method_name);
|
let definition_name = format_ident!("{}_definition", method_name);
|
||||||
|
|
||||||
// ドキュメントコメントから説明を取得
|
// Get description from doc comments
|
||||||
let description = extract_doc_comment(&method.attrs);
|
let description = extract_doc_comment(&method.attrs);
|
||||||
let description = if description.is_empty() {
|
let description = if description.is_empty() {
|
||||||
format!("Tool: {}", tool_name)
|
format!("Tool: {}", tool_name)
|
||||||
|
|
@ -123,7 +123,7 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
||||||
description
|
description
|
||||||
};
|
};
|
||||||
|
|
||||||
// 引数を解析(selfを除く)
|
// Parse arguments (excluding self)
|
||||||
let args: Vec<_> = sig
|
let args: Vec<_> = sig
|
||||||
.inputs
|
.inputs
|
||||||
.iter()
|
.iter()
|
||||||
|
|
@ -131,12 +131,12 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
||||||
if let FnArg::Typed(pat_type) = arg {
|
if let FnArg::Typed(pat_type) = arg {
|
||||||
Some(pat_type)
|
Some(pat_type)
|
||||||
} else {
|
} else {
|
||||||
None // selfを除外
|
None // Exclude self
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// 引数構造体のフィールドを生成
|
// Generate argument struct fields
|
||||||
let arg_fields: Vec<_> = args
|
let arg_fields: Vec<_> = args
|
||||||
.iter()
|
.iter()
|
||||||
.map(|pat_type| {
|
.map(|pat_type| {
|
||||||
|
|
@ -144,14 +144,14 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
||||||
let ty = &pat_type.ty;
|
let ty = &pat_type.ty;
|
||||||
let desc = extract_description_attr(&pat_type.attrs);
|
let desc = extract_description_attr(&pat_type.attrs);
|
||||||
|
|
||||||
// パターンから識別子を抽出
|
// Extract identifier from pattern
|
||||||
let field_name = if let Pat::Ident(pat_ident) = pat.as_ref() {
|
let field_name = if let Pat::Ident(pat_ident) = pat.as_ref() {
|
||||||
&pat_ident.ident
|
&pat_ident.ident
|
||||||
} else {
|
} else {
|
||||||
panic!("Only simple identifiers are supported for tool arguments");
|
panic!("Only simple identifiers are supported for tool arguments");
|
||||||
};
|
};
|
||||||
|
|
||||||
// #[description] があればschemarsのdocに変換
|
// Convert #[description] to schemars doc if present
|
||||||
if let Some(desc_str) = desc {
|
if let Some(desc_str) = desc {
|
||||||
quote! {
|
quote! {
|
||||||
#[schemars(description = #desc_str)]
|
#[schemars(description = #desc_str)]
|
||||||
|
|
@ -165,7 +165,7 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// execute内で引数を展開するコード
|
// Code to expand arguments in execute
|
||||||
let arg_names: Vec<_> = args
|
let arg_names: Vec<_> = args
|
||||||
.iter()
|
.iter()
|
||||||
.map(|pat_type| {
|
.map(|pat_type| {
|
||||||
|
|
@ -178,17 +178,17 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// メソッドが非同期かどうか
|
// Check if method is async
|
||||||
let is_async = sig.asyncness.is_some();
|
let is_async = sig.asyncness.is_some();
|
||||||
|
|
||||||
// 戻り値の型を解析してResult判定
|
// Parse return type and determine if Result
|
||||||
let awaiter = if is_async {
|
let awaiter = if is_async {
|
||||||
quote! { .await }
|
quote! { .await }
|
||||||
} else {
|
} else {
|
||||||
quote! {}
|
quote! {}
|
||||||
};
|
};
|
||||||
|
|
||||||
// 戻り値がResultかどうかを判定
|
// Determine if return type is Result
|
||||||
let result_handling = if is_result_type(&sig.output) {
|
let result_handling = if is_result_type(&sig.output) {
|
||||||
quote! {
|
quote! {
|
||||||
match result {
|
match result {
|
||||||
|
|
@ -202,7 +202,7 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// 引数がない場合は空のArgs構造体を作成
|
// Create empty Args struct if no arguments
|
||||||
let args_struct_def = if arg_fields.is_empty() {
|
let args_struct_def = if arg_fields.is_empty() {
|
||||||
quote! {
|
quote! {
|
||||||
#[derive(serde::Deserialize, schemars::JsonSchema)]
|
#[derive(serde::Deserialize, schemars::JsonSchema)]
|
||||||
|
|
@ -217,10 +217,10 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// 引数がない場合のexecute処理
|
// Execute body handling for no arguments case
|
||||||
let execute_body = if args.is_empty() {
|
let execute_body = if args.is_empty() {
|
||||||
quote! {
|
quote! {
|
||||||
// 引数なしでも空のJSONオブジェクトを許容
|
// Allow empty JSON object even with no arguments
|
||||||
let _: #args_struct_name = serde_json::from_str(input_json)
|
let _: #args_struct_name = serde_json::from_str(input_json)
|
||||||
.unwrap_or(#args_struct_name {});
|
.unwrap_or(#args_struct_name {});
|
||||||
|
|
||||||
|
|
@ -253,7 +253,7 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
||||||
}
|
}
|
||||||
|
|
||||||
impl #self_ty {
|
impl #self_ty {
|
||||||
/// ToolDefinition を取得(Worker への登録用)
|
/// Get ToolDefinition (for registering with Worker)
|
||||||
pub fn #definition_name(&self) -> ::llm_worker::tool::ToolDefinition {
|
pub fn #definition_name(&self) -> ::llm_worker::tool::ToolDefinition {
|
||||||
let ctx = self.clone();
|
let ctx = self.clone();
|
||||||
::std::sync::Arc::new(move || {
|
::std::sync::Arc::new(move || {
|
||||||
|
|
@ -270,12 +270,12 @@ fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 戻り値の型がResultかどうかを判定
|
/// Determine if return type is Result
|
||||||
fn is_result_type(return_type: &ReturnType) -> bool {
|
fn is_result_type(return_type: &ReturnType) -> bool {
|
||||||
match return_type {
|
match return_type {
|
||||||
ReturnType::Default => false,
|
ReturnType::Default => false,
|
||||||
ReturnType::Type(_, ty) => {
|
ReturnType::Type(_, ty) => {
|
||||||
// Type::Pathの場合、最後のセグメントが"Result"かチェック
|
// For Type::Path, check if last segment is "Result"
|
||||||
if let Type::Path(type_path) = ty.as_ref() {
|
if let Type::Path(type_path) = ty.as_ref() {
|
||||||
if let Some(segment) = type_path.path.segments.last() {
|
if let Some(segment) = type_path.path.segments.last() {
|
||||||
return segment.ident == "Result";
|
return segment.ident == "Result";
|
||||||
|
|
@ -286,7 +286,7 @@ fn is_result_type(return_type: &ReturnType) -> bool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// snake_case を PascalCase に変換
|
/// Convert snake_case to PascalCase
|
||||||
fn to_pascal_case(s: &str) -> String {
|
fn to_pascal_case(s: &str) -> String {
|
||||||
s.split('_')
|
s.split('_')
|
||||||
.map(|part| {
|
.map(|part| {
|
||||||
|
|
@ -299,20 +299,20 @@ fn to_pascal_case(s: &str) -> String {
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// マーカー属性。`tool_registry` によって処理されるため、ここでは何もしない。
|
/// Marker attribute. Does nothing here as it's processed by `tool_registry`.
|
||||||
#[proc_macro_attribute]
|
#[proc_macro_attribute]
|
||||||
pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||||
item
|
item
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 引数属性用のマーカー。パース時に`tool_registry`で解釈される。
|
/// Marker for argument attributes. Interpreted by `tool_registry` during parsing.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
/// ```ignore
|
/// ```ignore
|
||||||
/// #[tool]
|
/// #[tool]
|
||||||
/// async fn get_user(
|
/// async fn get_user(
|
||||||
/// &self,
|
/// &self,
|
||||||
/// #[description = "取得したいユーザーのID"] user_id: String
|
/// #[description = "The ID of the user to retrieve"] user_id: String
|
||||||
/// ) -> Result<User, Error> { ... }
|
/// ) -> Result<User, Error> { ... }
|
||||||
/// ```
|
/// ```
|
||||||
#[proc_macro_attribute]
|
#[proc_macro_attribute]
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[package]
|
[package]
|
||||||
name = "llm-worker"
|
name = "llm-worker"
|
||||||
description = "A library for building autonomous LLM-powered systems"
|
description = "A library for building autonomous LLM-powered systems"
|
||||||
version = "0.2.0"
|
version = "0.2.1"
|
||||||
publish.workspace = true
|
publish.workspace = true
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
@ -26,3 +26,4 @@ schemars = "1.2"
|
||||||
tempfile = "3.24"
|
tempfile = "3.24"
|
||||||
dotenv = "0.15"
|
dotenv = "0.15"
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
|
trybuild = "1.0.116"
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,18 @@
|
||||||
//! テストフィクスチャ記録ツール
|
//! Test fixture recording tool
|
||||||
//!
|
//!
|
||||||
//! 定義されたシナリオのAPIレスポンスを記録する。
|
//! Records API responses for defined scenarios.
|
||||||
//!
|
//!
|
||||||
//! ## 使用方法
|
//! ## Usage
|
||||||
//!
|
//!
|
||||||
//! ```bash
|
//! ```bash
|
||||||
//! # 利用可能なシナリオを表示
|
//! # Show available scenarios
|
||||||
//! cargo run --example record_test_fixtures
|
//! cargo run --example record_test_fixtures
|
||||||
//!
|
//!
|
||||||
//! # 特定のシナリオを記録
|
//! # Record specific scenario
|
||||||
//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- simple_text
|
//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- simple_text
|
||||||
//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- tool_call
|
//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- tool_call
|
||||||
//!
|
//!
|
||||||
//! # 全シナリオを記録
|
//! # Record all scenarios
|
||||||
//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- --all
|
//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- --all
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
|
|
@ -193,8 +193,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
ClientType::Ollama => "ollama",
|
ClientType::Ollama => "ollama",
|
||||||
};
|
};
|
||||||
|
|
||||||
// シナリオのフィルタリングは main.rs のロジックで実行済み
|
// Scenario filtering is already done in main.rs logic
|
||||||
// ここでは単純なループで実行
|
// Here we just execute in a simple loop
|
||||||
for scenario in scenarios_to_run {
|
for scenario in scenarios_to_run {
|
||||||
match args.client {
|
match args.client {
|
||||||
ClientType::Anthropic => {
|
ClientType::Anthropic => {
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
//! テストフィクスチャ記録機構
|
//! Test fixture recording mechanism
|
||||||
//!
|
//!
|
||||||
//! イベントをJSONLフォーマットでファイルに保存する
|
//! Saves events to files in JSONL format
|
||||||
|
|
||||||
use std::fs::{self, File};
|
use std::fs::{self, File};
|
||||||
use std::io::{BufWriter, Write};
|
use std::io::{BufWriter, Write};
|
||||||
|
|
@ -10,7 +10,7 @@ use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use llm_worker::llm_client::{LlmClient, Request};
|
use llm_worker::llm_client::{LlmClient, Request};
|
||||||
|
|
||||||
/// 記録されたイベント
|
/// Recorded event
|
||||||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
||||||
pub struct RecordedEvent {
|
pub struct RecordedEvent {
|
||||||
pub elapsed_ms: u64,
|
pub elapsed_ms: u64,
|
||||||
|
|
@ -18,7 +18,7 @@ pub struct RecordedEvent {
|
||||||
pub data: String,
|
pub data: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// セッションメタデータ
|
/// Session metadata
|
||||||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
||||||
pub struct SessionMetadata {
|
pub struct SessionMetadata {
|
||||||
pub timestamp: u64,
|
pub timestamp: u64,
|
||||||
|
|
@ -26,7 +26,7 @@ pub struct SessionMetadata {
|
||||||
pub description: String,
|
pub description: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// イベントシーケンスをファイルに保存
|
/// Save event sequence to file
|
||||||
pub fn save_fixture(
|
pub fn save_fixture(
|
||||||
path: impl AsRef<Path>,
|
path: impl AsRef<Path>,
|
||||||
metadata: &SessionMetadata,
|
metadata: &SessionMetadata,
|
||||||
|
|
@ -43,7 +43,7 @@ pub fn save_fixture(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// リクエストを送信してイベントを記録
|
/// Send request and record events
|
||||||
pub async fn record_request<C: LlmClient>(
|
pub async fn record_request<C: LlmClient>(
|
||||||
client: &C,
|
client: &C,
|
||||||
request: Request,
|
request: Request,
|
||||||
|
|
@ -78,7 +78,7 @@ pub async fn record_request<C: LlmClient>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存
|
// Save
|
||||||
let fixtures_dir = Path::new("worker/tests/fixtures").join(subdir);
|
let fixtures_dir = Path::new("worker/tests/fixtures").join(subdir);
|
||||||
fs::create_dir_all(&fixtures_dir)?;
|
fs::create_dir_all(&fixtures_dir)?;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,20 @@
|
||||||
//! テストフィクスチャ用リクエスト定義
|
//! Test fixture request definitions
|
||||||
//!
|
//!
|
||||||
//! 各シナリオのリクエストと出力ファイル名を定義
|
//! Defines requests and output file names for each scenario
|
||||||
|
|
||||||
use llm_worker::llm_client::{Request, ToolDefinition};
|
use llm_worker::llm_client::{Request, ToolDefinition};
|
||||||
|
|
||||||
/// テストシナリオ
|
/// Test scenario
|
||||||
pub struct TestScenario {
|
pub struct TestScenario {
|
||||||
/// シナリオ名(説明)
|
/// Scenario name (description)
|
||||||
pub name: &'static str,
|
pub name: &'static str,
|
||||||
/// 出力ファイル名(拡張子なし)
|
/// Output file name (without extension)
|
||||||
pub output_name: &'static str,
|
pub output_name: &'static str,
|
||||||
/// リクエスト
|
/// Request
|
||||||
pub request: Request,
|
pub request: Request,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 全てのテストシナリオを取得
|
/// Get all test scenarios
|
||||||
pub fn scenarios() -> Vec<TestScenario> {
|
pub fn scenarios() -> Vec<TestScenario> {
|
||||||
vec![
|
vec![
|
||||||
simple_text_scenario(),
|
simple_text_scenario(),
|
||||||
|
|
@ -23,7 +23,7 @@ pub fn scenarios() -> Vec<TestScenario> {
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
/// シンプルなテキストレスポンス
|
/// Simple text response
|
||||||
fn simple_text_scenario() -> TestScenario {
|
fn simple_text_scenario() -> TestScenario {
|
||||||
TestScenario {
|
TestScenario {
|
||||||
name: "Simple text response",
|
name: "Simple text response",
|
||||||
|
|
@ -35,7 +35,7 @@ fn simple_text_scenario() -> TestScenario {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ツール呼び出しを含むレスポンス
|
/// Response with tool call
|
||||||
fn tool_call_scenario() -> TestScenario {
|
fn tool_call_scenario() -> TestScenario {
|
||||||
let get_weather_tool = ToolDefinition::new("get_weather")
|
let get_weather_tool = ToolDefinition::new("get_weather")
|
||||||
.description("Get the current weather for a city")
|
.description("Get the current weather for a city")
|
||||||
|
|
@ -61,7 +61,7 @@ fn tool_call_scenario() -> TestScenario {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 長文生成シナリオ
|
/// Long text generation scenario
|
||||||
fn long_text_scenario() -> TestScenario {
|
fn long_text_scenario() -> TestScenario {
|
||||||
TestScenario {
|
TestScenario {
|
||||||
name: "Long text response",
|
name: "Long text response",
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
//! Worker のキャンセル機能のデモンストレーション
|
//! Worker cancellation demo
|
||||||
//!
|
//!
|
||||||
//! ストリーミング受信中に別スレッドからキャンセルする例
|
//! Example of cancelling from another thread during streaming
|
||||||
|
|
||||||
use llm_worker::llm_client::providers::anthropic::AnthropicClient;
|
use llm_worker::llm_client::providers::anthropic::AnthropicClient;
|
||||||
use llm_worker::{Worker, WorkerResult};
|
use llm_worker::{Worker, WorkerResult};
|
||||||
|
|
@ -10,10 +10,10 @@ use tokio::sync::Mutex;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// .envファイルを読み込む
|
// Load .env file
|
||||||
dotenv::dotenv().ok();
|
dotenv::dotenv().ok();
|
||||||
|
|
||||||
// ロギング初期化
|
// Initialize logging
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
.with_env_filter(
|
.with_env_filter(
|
||||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||||
|
|
@ -30,13 +30,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
println!("🚀 Starting Worker...");
|
println!("🚀 Starting Worker...");
|
||||||
println!("💡 Will cancel after 2 seconds\n");
|
println!("💡 Will cancel after 2 seconds\n");
|
||||||
|
|
||||||
// キャンセルSenderを先に取得(ロックを保持しない)
|
// Get cancel sender first (without holding lock)
|
||||||
let cancel_tx = {
|
let cancel_tx = {
|
||||||
let w = worker.lock().await;
|
let w = worker.lock().await;
|
||||||
w.cancel_sender()
|
w.cancel_sender()
|
||||||
};
|
};
|
||||||
|
|
||||||
// タスク1: Workerを実行
|
// Task 1: Run Worker
|
||||||
let worker_clone = worker.clone();
|
let worker_clone = worker.clone();
|
||||||
let task = tokio::spawn(async move {
|
let task = tokio::spawn(async move {
|
||||||
let mut w = worker_clone.lock().await;
|
let mut w = worker_clone.lock().await;
|
||||||
|
|
@ -55,14 +55,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// タスク2: 2秒後にキャンセル
|
// Task 2: Cancel after 2 seconds
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||||
println!("\n🛑 Cancelling worker...");
|
println!("\n🛑 Cancelling worker...");
|
||||||
let _ = cancel_tx.send(()).await;
|
let _ = cancel_tx.send(()).await;
|
||||||
});
|
});
|
||||||
|
|
||||||
// タスク完了を待つ
|
// Wait for task completion
|
||||||
task.await?;
|
task.await?;
|
||||||
|
|
||||||
println!("\n✨ Demo complete!");
|
println!("\n✨ Demo complete!");
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,17 @@
|
||||||
//! Worker を用いた対話型 CLI クライアント
|
//! Interactive CLI client using Worker
|
||||||
//!
|
//!
|
||||||
//! 複数のLLMプロバイダ(Anthropic, Gemini, OpenAI, Ollama)と対話するCLIアプリケーション。
|
//! A CLI application for interacting with multiple LLM providers (Anthropic, Gemini, OpenAI, Ollama).
|
||||||
//! ツールの登録と実行、ストリーミングレスポンスの表示をデモする。
|
//! Demonstrates tool registration and execution, and streaming response display.
|
||||||
//!
|
//!
|
||||||
//! ## 使用方法
|
//! ## Usage
|
||||||
//!
|
//!
|
||||||
//! ```bash
|
//! ```bash
|
||||||
//! # .envファイルにAPIキーを設定
|
//! # Set API keys in .env file
|
||||||
//! echo "ANTHROPIC_API_KEY=your-api-key" > .env
|
//! echo "ANTHROPIC_API_KEY=your-api-key" > .env
|
||||||
//! echo "GEMINI_API_KEY=your-api-key" >> .env
|
//! echo "GEMINI_API_KEY=your-api-key" >> .env
|
||||||
//! echo "OPENAI_API_KEY=your-api-key" >> .env
|
//! echo "OPENAI_API_KEY=your-api-key" >> .env
|
||||||
//!
|
//!
|
||||||
//! # Anthropic (デフォルト)
|
//! # Anthropic (default)
|
||||||
//! cargo run --example worker_cli
|
//! cargo run --example worker_cli
|
||||||
//!
|
//!
|
||||||
//! # Gemini
|
//! # Gemini
|
||||||
|
|
@ -20,13 +20,13 @@
|
||||||
//! # OpenAI
|
//! # OpenAI
|
||||||
//! cargo run --example worker_cli -- --provider openai --model gpt-4o
|
//! cargo run --example worker_cli -- --provider openai --model gpt-4o
|
||||||
//!
|
//!
|
||||||
//! # Ollama (ローカル)
|
//! # Ollama (local)
|
||||||
//! cargo run --example worker_cli -- --provider ollama --model llama3.2
|
//! cargo run --example worker_cli -- --provider ollama --model llama3.2
|
||||||
//!
|
//!
|
||||||
//! # オプション指定
|
//! # With options
|
||||||
//! cargo run --example worker_cli -- --provider anthropic --model claude-3-haiku-20240307 --system "You are a helpful assistant."
|
//! cargo run --example worker_cli -- --provider anthropic --model claude-3-haiku-20240307 --system "You are a helpful assistant."
|
||||||
//!
|
//!
|
||||||
//! # ヘルプ表示
|
//! # Show help
|
||||||
//! cargo run --example worker_cli -- --help
|
//! cargo run --example worker_cli -- --help
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
|
|
@ -53,15 +53,15 @@ use llm_worker::{
|
||||||
};
|
};
|
||||||
use llm_worker_macros::tool_registry;
|
use llm_worker_macros::tool_registry;
|
||||||
|
|
||||||
// 必要なマクロ展開用インポート
|
// Required imports for macro expansion
|
||||||
use schemars;
|
use schemars;
|
||||||
use serde;
|
use serde;
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// プロバイダ定義
|
// Provider Definition
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// 利用可能なLLMプロバイダ
|
/// Available LLM providers
|
||||||
#[derive(Debug, Clone, Copy, ValueEnum, Default)]
|
#[derive(Debug, Clone, Copy, ValueEnum, Default)]
|
||||||
enum Provider {
|
enum Provider {
|
||||||
/// Anthropic Claude
|
/// Anthropic Claude
|
||||||
|
|
@ -71,12 +71,12 @@ enum Provider {
|
||||||
Gemini,
|
Gemini,
|
||||||
/// OpenAI GPT
|
/// OpenAI GPT
|
||||||
Openai,
|
Openai,
|
||||||
/// Ollama (ローカル)
|
/// Ollama (local)
|
||||||
Ollama,
|
Ollama,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Provider {
|
impl Provider {
|
||||||
/// プロバイダのデフォルトモデル
|
/// Default model for the provider
|
||||||
fn default_model(&self) -> &'static str {
|
fn default_model(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Provider::Anthropic => "claude-sonnet-4-20250514",
|
Provider::Anthropic => "claude-sonnet-4-20250514",
|
||||||
|
|
@ -86,7 +86,7 @@ impl Provider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// プロバイダの表示名
|
/// Display name for the provider
|
||||||
fn display_name(&self) -> &'static str {
|
fn display_name(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Provider::Anthropic => "Anthropic Claude",
|
Provider::Anthropic => "Anthropic Claude",
|
||||||
|
|
@ -96,78 +96,78 @@ impl Provider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// APIキーの環境変数名
|
/// Environment variable name for API key
|
||||||
fn env_var_name(&self) -> Option<&'static str> {
|
fn env_var_name(&self) -> Option<&'static str> {
|
||||||
match self {
|
match self {
|
||||||
Provider::Anthropic => Some("ANTHROPIC_API_KEY"),
|
Provider::Anthropic => Some("ANTHROPIC_API_KEY"),
|
||||||
Provider::Gemini => Some("GEMINI_API_KEY"),
|
Provider::Gemini => Some("GEMINI_API_KEY"),
|
||||||
Provider::Openai => Some("OPENAI_API_KEY"),
|
Provider::Openai => Some("OPENAI_API_KEY"),
|
||||||
Provider::Ollama => None, // Ollamaはローカルなので不要
|
Provider::Ollama => None, // Ollama is local, no key needed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// CLI引数定義
|
// CLI Argument Definition
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// 複数のLLMプロバイダに対応した対話型CLIクライアント
|
/// Interactive CLI client supporting multiple LLM providers
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(name = "worker-cli")]
|
#[command(name = "worker-cli")]
|
||||||
#[command(about = "Interactive CLI client for multiple LLM providers using Worker")]
|
#[command(about = "Interactive CLI client for multiple LLM providers using Worker")]
|
||||||
#[command(version)]
|
#[command(version)]
|
||||||
struct Args {
|
struct Args {
|
||||||
/// 使用するプロバイダ
|
/// Provider to use
|
||||||
#[arg(long, value_enum, default_value_t = Provider::Anthropic)]
|
#[arg(long, value_enum, default_value_t = Provider::Anthropic)]
|
||||||
provider: Provider,
|
provider: Provider,
|
||||||
|
|
||||||
/// 使用するモデル名(未指定時はプロバイダのデフォルト)
|
/// Model name to use (defaults to provider's default if not specified)
|
||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
model: Option<String>,
|
model: Option<String>,
|
||||||
|
|
||||||
/// システムプロンプト
|
/// System prompt
|
||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
system: Option<String>,
|
system: Option<String>,
|
||||||
|
|
||||||
/// ツールを無効化
|
/// Disable tools
|
||||||
#[arg(long, default_value = "false")]
|
#[arg(long, default_value = "false")]
|
||||||
no_tools: bool,
|
no_tools: bool,
|
||||||
|
|
||||||
/// 最初のメッセージ(指定するとそれを送信して終了)
|
/// Initial message (if specified, sends it and exits)
|
||||||
#[arg(short = 'p', long)]
|
#[arg(short = 'p', long)]
|
||||||
prompt: Option<String>,
|
prompt: Option<String>,
|
||||||
|
|
||||||
/// APIキー(環境変数より優先)
|
/// API key (takes precedence over environment variable)
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
api_key: Option<String>,
|
api_key: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// ツール定義
|
// Tool Definition
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// アプリケーションコンテキスト
|
/// Application context
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct AppContext;
|
struct AppContext;
|
||||||
|
|
||||||
#[tool_registry]
|
#[tool_registry]
|
||||||
impl AppContext {
|
impl AppContext {
|
||||||
/// 現在の日時を取得する
|
/// Get the current date and time
|
||||||
///
|
///
|
||||||
/// システムの現在の日付と時刻を返します。
|
/// Returns the system's current date and time.
|
||||||
#[tool]
|
#[tool]
|
||||||
fn get_current_time(&self) -> String {
|
fn get_current_time(&self) -> String {
|
||||||
let now = std::time::SystemTime::now()
|
let now = std::time::SystemTime::now()
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.as_secs();
|
.as_secs();
|
||||||
// シンプルなUnixタイムスタンプからの変換
|
// Simple conversion from Unix timestamp
|
||||||
format!("Current Unix timestamp: {}", now)
|
format!("Current Unix timestamp: {}", now)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 簡単な計算を行う
|
/// Perform a simple calculation
|
||||||
///
|
///
|
||||||
/// 2つの数値の四則演算を実行します。
|
/// Executes arithmetic operations on two numbers.
|
||||||
#[tool]
|
#[tool]
|
||||||
fn calculate(&self, a: f64, b: f64, operation: String) -> Result<String, String> {
|
fn calculate(&self, a: f64, b: f64, operation: String) -> Result<String, String> {
|
||||||
let result = match operation.as_str() {
|
let result = match operation.as_str() {
|
||||||
|
|
@ -187,10 +187,10 @@ impl AppContext {
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// ストリーミング表示用ハンドラー
|
// Streaming Display Handlers
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// テキストをリアルタイムで出力するハンドラー
|
/// Handler that outputs text in real-time
|
||||||
struct StreamingPrinter {
|
struct StreamingPrinter {
|
||||||
is_first_delta: Arc<Mutex<bool>>,
|
is_first_delta: Arc<Mutex<bool>>,
|
||||||
}
|
}
|
||||||
|
|
@ -226,7 +226,7 @@ impl Handler<TextBlockKind> for StreamingPrinter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ツール呼び出しを表示するハンドラー
|
/// Handler that displays tool calls
|
||||||
struct ToolCallPrinter {
|
struct ToolCallPrinter {
|
||||||
call_names: Arc<Mutex<HashMap<String, String>>>,
|
call_names: Arc<Mutex<HashMap<String, String>>>,
|
||||||
}
|
}
|
||||||
|
|
@ -270,7 +270,7 @@ impl Handler<ToolUseBlockKind> for ToolCallPrinter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ツール実行結果を表示するHook
|
/// Hook that displays tool execution results
|
||||||
struct ToolResultPrinterHook {
|
struct ToolResultPrinterHook {
|
||||||
call_names: Arc<Mutex<HashMap<String, String>>>,
|
call_names: Arc<Mutex<HashMap<String, String>>>,
|
||||||
}
|
}
|
||||||
|
|
@ -302,17 +302,17 @@ impl Hook<PostToolCall> for ToolResultPrinterHook {
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// クライアント作成
|
// Client Creation
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// プロバイダに応じたAPIキーを取得
|
/// Get API key based on provider
|
||||||
fn get_api_key(args: &Args) -> Result<String, String> {
|
fn get_api_key(args: &Args) -> Result<String, String> {
|
||||||
// CLI引数のAPIキーが優先
|
// CLI argument API key takes precedence
|
||||||
if let Some(ref key) = args.api_key {
|
if let Some(ref key) = args.api_key {
|
||||||
return Ok(key.clone());
|
return Ok(key.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
// プロバイダに応じた環境変数を確認
|
// Check environment variable based on provider
|
||||||
if let Some(env_var) = args.provider.env_var_name() {
|
if let Some(env_var) = args.provider.env_var_name() {
|
||||||
std::env::var(env_var).map_err(|_| {
|
std::env::var(env_var).map_err(|_| {
|
||||||
format!(
|
format!(
|
||||||
|
|
@ -321,12 +321,12 @@ fn get_api_key(args: &Args) -> Result<String, String> {
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
// Ollamaなどはキー不要
|
// Ollama etc. don't need a key
|
||||||
Ok(String::new())
|
Ok(String::new())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// プロバイダに応じたクライアントを作成
|
/// Create client based on provider
|
||||||
fn create_client(args: &Args) -> Result<Box<dyn LlmClient>, String> {
|
fn create_client(args: &Args) -> Result<Box<dyn LlmClient>, String> {
|
||||||
let model = args
|
let model = args
|
||||||
.model
|
.model
|
||||||
|
|
@ -356,17 +356,17 @@ fn create_client(args: &Args) -> Result<Box<dyn LlmClient>, String> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// メイン
|
// Main
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// .envファイルを読み込む
|
// Load .env file
|
||||||
dotenv::dotenv().ok();
|
dotenv::dotenv().ok();
|
||||||
|
|
||||||
// ロギング初期化
|
// Initialize logging
|
||||||
// RUST_LOG=debug cargo run --example worker_cli ... で詳細ログ表示
|
// Use RUST_LOG=debug cargo run --example worker_cli ... for detailed logs
|
||||||
// デフォルトは warn レベル、RUST_LOG 環境変数で上書き可能
|
// Default is warn level, can be overridden with RUST_LOG environment variable
|
||||||
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("warn"));
|
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("warn"));
|
||||||
|
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
|
|
@ -374,7 +374,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
.with_target(true)
|
.with_target(true)
|
||||||
.init();
|
.init();
|
||||||
|
|
||||||
// CLI引数をパース
|
// Parse CLI arguments
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
|
|
@ -383,10 +383,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
"Starting worker CLI"
|
"Starting worker CLI"
|
||||||
);
|
);
|
||||||
|
|
||||||
// 対話モードかワンショットモードか
|
// Interactive mode or one-shot mode
|
||||||
let is_interactive = args.prompt.is_none();
|
let is_interactive = args.prompt.is_none();
|
||||||
|
|
||||||
// モデル名(表示用)
|
// Model name (for display)
|
||||||
let model_name = args
|
let model_name = args
|
||||||
.model
|
.model
|
||||||
.clone()
|
.clone()
|
||||||
|
|
@ -416,7 +416,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
println!("─────────────────────────────────────────────────");
|
println!("─────────────────────────────────────────────────");
|
||||||
}
|
}
|
||||||
|
|
||||||
// クライアント作成
|
// Create client
|
||||||
let client = match create_client(&args) {
|
let client = match create_client(&args) {
|
||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
|
@ -425,17 +425,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Worker作成
|
// Create Worker
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
|
|
||||||
let tool_call_names = Arc::new(Mutex::new(HashMap::new()));
|
let tool_call_names = Arc::new(Mutex::new(HashMap::new()));
|
||||||
|
|
||||||
// システムプロンプトを設定
|
// Set system prompt
|
||||||
if let Some(ref system_prompt) = args.system {
|
if let Some(ref system_prompt) = args.system {
|
||||||
worker.set_system_prompt(system_prompt);
|
worker.set_system_prompt(system_prompt);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ツール登録(--no-tools でなければ)
|
// Register tools (unless --no-tools)
|
||||||
if !args.no_tools {
|
if !args.no_tools {
|
||||||
let app = AppContext;
|
let app = AppContext;
|
||||||
worker
|
worker
|
||||||
|
|
@ -444,7 +444,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
worker.register_tool(app.calculate_definition()).unwrap();
|
worker.register_tool(app.calculate_definition()).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// ストリーミング表示用ハンドラーを登録
|
// Register streaming display handlers
|
||||||
worker
|
worker
|
||||||
.timeline_mut()
|
.timeline_mut()
|
||||||
.on_text_block(StreamingPrinter::new())
|
.on_text_block(StreamingPrinter::new())
|
||||||
|
|
@ -452,7 +452,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
|
||||||
worker.add_post_tool_call_hook(ToolResultPrinterHook::new(tool_call_names));
|
worker.add_post_tool_call_hook(ToolResultPrinterHook::new(tool_call_names));
|
||||||
|
|
||||||
// ワンショットモード
|
// One-shot mode
|
||||||
if let Some(prompt) = args.prompt {
|
if let Some(prompt) = args.prompt {
|
||||||
match worker.run(&prompt).await {
|
match worker.run(&prompt).await {
|
||||||
Ok(_) => {}
|
Ok(_) => {}
|
||||||
|
|
@ -465,7 +465,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
// 対話ループ
|
// Interactive loop
|
||||||
loop {
|
loop {
|
||||||
print!("\n👤 You: ");
|
print!("\n👤 You: ");
|
||||||
io::stdout().flush()?;
|
io::stdout().flush()?;
|
||||||
|
|
@ -483,7 +483,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Workerを実行(Workerが履歴を管理)
|
// Run Worker (Worker manages history)
|
||||||
match worker.run(input).await {
|
match worker.run(input).await {
|
||||||
Ok(_) => {}
|
Ok(_) => {}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
//! Worker層の公開イベント型
|
//! Public event types for Worker layer
|
||||||
//!
|
//!
|
||||||
//! 外部利用者に公開するためのイベント表現。
|
//! Event representation exposed to external users.
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
|
@ -8,38 +8,38 @@ use serde::{Deserialize, Serialize};
|
||||||
// Core Event Types (from llm_client layer)
|
// Core Event Types (from llm_client layer)
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// LLMからのストリーミングイベント
|
/// Streaming events from LLM
|
||||||
///
|
///
|
||||||
/// 各LLMプロバイダからのレスポンスは、この`Event`のストリームとして
|
/// Responses from each LLM provider are processed uniformly
|
||||||
/// 統一的に処理されます。
|
/// as a stream of `Event`.
|
||||||
///
|
///
|
||||||
/// # イベントの種類
|
/// # Event Types
|
||||||
///
|
///
|
||||||
/// - **メタイベント**: `Ping`, `Usage`, `Status`, `Error`
|
/// - **Meta events**: `Ping`, `Usage`, `Status`, `Error`
|
||||||
/// - **ブロックイベント**: `BlockStart`, `BlockDelta`, `BlockStop`, `BlockAbort`
|
/// - **Block events**: `BlockStart`, `BlockDelta`, `BlockStop`, `BlockAbort`
|
||||||
///
|
///
|
||||||
/// # ブロックのライフサイクル
|
/// # Block Lifecycle
|
||||||
///
|
///
|
||||||
/// テキストやツール呼び出しは、`BlockStart` → `BlockDelta`(複数) → `BlockStop`
|
/// Text and tool calls have events in the order of
|
||||||
/// の順序でイベントが発生します。
|
/// `BlockStart` → `BlockDelta`(multiple) → `BlockStop`.
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub enum Event {
|
pub enum Event {
|
||||||
/// ハートビート
|
/// Heartbeat
|
||||||
Ping(PingEvent),
|
Ping(PingEvent),
|
||||||
/// トークン使用量
|
/// Token usage
|
||||||
Usage(UsageEvent),
|
Usage(UsageEvent),
|
||||||
/// ストリームのステータス変化
|
/// Stream status change
|
||||||
Status(StatusEvent),
|
Status(StatusEvent),
|
||||||
/// エラー発生
|
/// Error occurred
|
||||||
Error(ErrorEvent),
|
Error(ErrorEvent),
|
||||||
|
|
||||||
/// ブロック開始(テキスト、ツール使用等)
|
/// Block start (text, tool use, etc.)
|
||||||
BlockStart(BlockStart),
|
BlockStart(BlockStart),
|
||||||
/// ブロックの差分データ
|
/// Block delta data
|
||||||
BlockDelta(BlockDelta),
|
BlockDelta(BlockDelta),
|
||||||
/// ブロック正常終了
|
/// Block normal end
|
||||||
BlockStop(BlockStop),
|
BlockStop(BlockStop),
|
||||||
/// ブロック中断
|
/// Block abort
|
||||||
BlockAbort(BlockAbort),
|
BlockAbort(BlockAbort),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -47,47 +47,47 @@ pub enum Event {
|
||||||
// Meta Events
|
// Meta Events
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// Pingイベント(ハートビート)
|
/// Ping event (heartbeat)
|
||||||
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
|
||||||
pub struct PingEvent {
|
pub struct PingEvent {
|
||||||
pub timestamp: Option<u64>,
|
pub timestamp: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 使用量イベント
|
/// Usage event
|
||||||
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
|
||||||
pub struct UsageEvent {
|
pub struct UsageEvent {
|
||||||
/// 入力トークン数
|
/// Input token count
|
||||||
pub input_tokens: Option<u64>,
|
pub input_tokens: Option<u64>,
|
||||||
/// 出力トークン数
|
/// Output token count
|
||||||
pub output_tokens: Option<u64>,
|
pub output_tokens: Option<u64>,
|
||||||
/// 合計トークン数
|
/// Total token count
|
||||||
pub total_tokens: Option<u64>,
|
pub total_tokens: Option<u64>,
|
||||||
/// キャッシュ読み込みトークン数
|
/// Cache read token count
|
||||||
pub cache_read_input_tokens: Option<u64>,
|
pub cache_read_input_tokens: Option<u64>,
|
||||||
/// キャッシュ作成トークン数
|
/// Cache creation token count
|
||||||
pub cache_creation_input_tokens: Option<u64>,
|
pub cache_creation_input_tokens: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ステータスイベント
|
/// Status event
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub struct StatusEvent {
|
pub struct StatusEvent {
|
||||||
pub status: ResponseStatus,
|
pub status: ResponseStatus,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// レスポンスステータス
|
/// Response status
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub enum ResponseStatus {
|
pub enum ResponseStatus {
|
||||||
/// ストリーム開始
|
/// Stream started
|
||||||
Started,
|
Started,
|
||||||
/// 正常完了
|
/// Completed normally
|
||||||
Completed,
|
Completed,
|
||||||
/// キャンセルされた
|
/// Cancelled
|
||||||
Cancelled,
|
Cancelled,
|
||||||
/// エラー発生
|
/// Error occurred
|
||||||
Failed,
|
Failed,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// エラーイベント
|
/// Error event
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub struct ErrorEvent {
|
pub struct ErrorEvent {
|
||||||
pub code: Option<String>,
|
pub code: Option<String>,
|
||||||
|
|
@ -98,27 +98,27 @@ pub struct ErrorEvent {
|
||||||
// Block Types
|
// Block Types
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// ブロックの種別
|
/// Block type
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||||
pub enum BlockType {
|
pub enum BlockType {
|
||||||
/// テキスト生成
|
/// Text generation
|
||||||
Text,
|
Text,
|
||||||
/// 思考 (Claude Extended Thinking等)
|
/// Thinking (Claude Extended Thinking, etc.)
|
||||||
Thinking,
|
Thinking,
|
||||||
/// ツール呼び出し
|
/// Tool call
|
||||||
ToolUse,
|
ToolUse,
|
||||||
/// ツール結果
|
/// Tool result
|
||||||
ToolResult,
|
ToolResult,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ブロック開始イベント
|
/// Block start event
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub struct BlockStart {
|
pub struct BlockStart {
|
||||||
/// ブロックのインデックス
|
/// Block index
|
||||||
pub index: usize,
|
pub index: usize,
|
||||||
/// ブロックの種別
|
/// Block type
|
||||||
pub block_type: BlockType,
|
pub block_type: BlockType,
|
||||||
/// ブロック固有のメタデータ
|
/// Block-specific metadata
|
||||||
pub metadata: BlockMetadata,
|
pub metadata: BlockMetadata,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -128,7 +128,7 @@ impl BlockStart {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ブロックのメタデータ
|
/// Block metadata
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub enum BlockMetadata {
|
pub enum BlockMetadata {
|
||||||
Text,
|
Text,
|
||||||
|
|
@ -137,28 +137,28 @@ pub enum BlockMetadata {
|
||||||
ToolResult { tool_use_id: String },
|
ToolResult { tool_use_id: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ブロックデルタイベント
|
/// Block delta event
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub struct BlockDelta {
|
pub struct BlockDelta {
|
||||||
/// ブロックのインデックス
|
/// Block index
|
||||||
pub index: usize,
|
pub index: usize,
|
||||||
/// デルタの内容
|
/// Delta content
|
||||||
pub delta: DeltaContent,
|
pub delta: DeltaContent,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// デルタの内容
|
/// Delta content
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub enum DeltaContent {
|
pub enum DeltaContent {
|
||||||
/// テキストデルタ
|
/// Text delta
|
||||||
Text(String),
|
Text(String),
|
||||||
/// 思考デルタ
|
/// Thinking delta
|
||||||
Thinking(String),
|
Thinking(String),
|
||||||
/// ツール引数のJSON部分文字列
|
/// JSON substring of tool arguments
|
||||||
InputJson(String),
|
InputJson(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DeltaContent {
|
impl DeltaContent {
|
||||||
/// デルタのブロック種別を取得
|
/// Get block type of the delta
|
||||||
pub fn block_type(&self) -> BlockType {
|
pub fn block_type(&self) -> BlockType {
|
||||||
match self {
|
match self {
|
||||||
DeltaContent::Text(_) => BlockType::Text,
|
DeltaContent::Text(_) => BlockType::Text,
|
||||||
|
|
@ -168,14 +168,14 @@ impl DeltaContent {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ブロック停止イベント
|
/// Block stop event
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub struct BlockStop {
|
pub struct BlockStop {
|
||||||
/// ブロックのインデックス
|
/// Block index
|
||||||
pub index: usize,
|
pub index: usize,
|
||||||
/// ブロックの種別
|
/// Block type
|
||||||
pub block_type: BlockType,
|
pub block_type: BlockType,
|
||||||
/// 停止理由
|
/// Stop reason
|
||||||
pub stop_reason: Option<StopReason>,
|
pub stop_reason: Option<StopReason>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -185,14 +185,14 @@ impl BlockStop {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ブロック中断イベント
|
/// Block abort event
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub struct BlockAbort {
|
pub struct BlockAbort {
|
||||||
/// ブロックのインデックス
|
/// Block index
|
||||||
pub index: usize,
|
pub index: usize,
|
||||||
/// ブロックの種別
|
/// Block type
|
||||||
pub block_type: BlockType,
|
pub block_type: BlockType,
|
||||||
/// 中断理由
|
/// Abort reason
|
||||||
pub reason: String,
|
pub reason: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -202,16 +202,16 @@ impl BlockAbort {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 停止理由
|
/// Stop reason
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub enum StopReason {
|
pub enum StopReason {
|
||||||
/// 自然終了
|
/// Natural end
|
||||||
EndTurn,
|
EndTurn,
|
||||||
/// 最大トークン数到達
|
/// Max tokens reached
|
||||||
MaxTokens,
|
MaxTokens,
|
||||||
/// ストップシーケンス到達
|
/// Stop sequence reached
|
||||||
StopSequence,
|
StopSequence,
|
||||||
/// ツール使用
|
/// Tool use
|
||||||
ToolUse,
|
ToolUse,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -220,7 +220,7 @@ pub enum StopReason {
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
impl Event {
|
impl Event {
|
||||||
/// テキストブロック開始イベントを作成
|
/// Create text block start event
|
||||||
pub fn text_block_start(index: usize) -> Self {
|
pub fn text_block_start(index: usize) -> Self {
|
||||||
Event::BlockStart(BlockStart {
|
Event::BlockStart(BlockStart {
|
||||||
index,
|
index,
|
||||||
|
|
@ -229,7 +229,7 @@ impl Event {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// テキストデルタイベントを作成
|
/// Create text delta event
|
||||||
pub fn text_delta(index: usize, text: impl Into<String>) -> Self {
|
pub fn text_delta(index: usize, text: impl Into<String>) -> Self {
|
||||||
Event::BlockDelta(BlockDelta {
|
Event::BlockDelta(BlockDelta {
|
||||||
index,
|
index,
|
||||||
|
|
@ -237,7 +237,7 @@ impl Event {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// テキストブロック停止イベントを作成
|
/// Create text block stop event
|
||||||
pub fn text_block_stop(index: usize, stop_reason: Option<StopReason>) -> Self {
|
pub fn text_block_stop(index: usize, stop_reason: Option<StopReason>) -> Self {
|
||||||
Event::BlockStop(BlockStop {
|
Event::BlockStop(BlockStop {
|
||||||
index,
|
index,
|
||||||
|
|
@ -246,7 +246,7 @@ impl Event {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ツール使用ブロック開始イベントを作成
|
/// Create tool use block start event
|
||||||
pub fn tool_use_start(index: usize, id: impl Into<String>, name: impl Into<String>) -> Self {
|
pub fn tool_use_start(index: usize, id: impl Into<String>, name: impl Into<String>) -> Self {
|
||||||
Event::BlockStart(BlockStart {
|
Event::BlockStart(BlockStart {
|
||||||
index,
|
index,
|
||||||
|
|
@ -258,7 +258,7 @@ impl Event {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ツール引数デルタイベントを作成
|
/// Create tool input delta event
|
||||||
pub fn tool_input_delta(index: usize, json: impl Into<String>) -> Self {
|
pub fn tool_input_delta(index: usize, json: impl Into<String>) -> Self {
|
||||||
Event::BlockDelta(BlockDelta {
|
Event::BlockDelta(BlockDelta {
|
||||||
index,
|
index,
|
||||||
|
|
@ -266,7 +266,7 @@ impl Event {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ツール使用ブロック停止イベントを作成
|
/// Create tool use block stop event
|
||||||
pub fn tool_use_stop(index: usize) -> Self {
|
pub fn tool_use_stop(index: usize) -> Self {
|
||||||
Event::BlockStop(BlockStop {
|
Event::BlockStop(BlockStop {
|
||||||
index,
|
index,
|
||||||
|
|
@ -275,7 +275,7 @@ impl Event {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 使用量イベントを作成
|
/// Create usage event
|
||||||
pub fn usage(input_tokens: u64, output_tokens: u64) -> Self {
|
pub fn usage(input_tokens: u64, output_tokens: u64) -> Self {
|
||||||
Event::Usage(UsageEvent {
|
Event::Usage(UsageEvent {
|
||||||
input_tokens: Some(input_tokens),
|
input_tokens: Some(input_tokens),
|
||||||
|
|
@ -286,7 +286,7 @@ impl Event {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Pingイベントを作成
|
/// Create ping event
|
||||||
pub fn ping() -> Self {
|
pub fn ping() -> Self {
|
||||||
Event::Ping(PingEvent { timestamp: None })
|
Event::Ping(PingEvent { timestamp: None })
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
//! Handler/Kind型
|
//! Handler/Kind Types
|
||||||
//!
|
//!
|
||||||
//! Timeline層でイベントを処理するためのトレイト。
|
//! Traits for processing events in the Timeline layer.
|
||||||
//! カスタムハンドラを実装してTimelineに登録することで、
|
//! By implementing custom handlers and registering them with Timeline,
|
||||||
//! ストリームイベントを受信できます。
|
//! you can receive stream events.
|
||||||
|
|
||||||
use crate::timeline::event::*;
|
use crate::timeline::event::*;
|
||||||
|
|
||||||
|
|
@ -10,13 +10,13 @@ use crate::timeline::event::*;
|
||||||
// Kind Trait
|
// Kind Trait
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// イベント種別を定義するマーカートレイト
|
/// Marker trait defining event types
|
||||||
///
|
///
|
||||||
/// 各Kindは対応するイベント型を指定します。
|
/// Each Kind specifies its corresponding event type.
|
||||||
/// HandlerはこのKindに対して実装され、同じKindに対して
|
/// Handlers are implemented for this Kind, and multiple Handlers
|
||||||
/// 異なるScope型を持つ複数のHandlerを登録できます。
|
/// with different Scope types can be registered for the same Kind.
|
||||||
pub trait Kind {
|
pub trait Kind {
|
||||||
/// このKindに対応するイベント型
|
/// Event type corresponding to this Kind
|
||||||
type Event;
|
type Event;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -24,10 +24,10 @@ pub trait Kind {
|
||||||
// Handler Trait
|
// Handler Trait
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// イベントを処理するハンドラトレイト
|
/// Handler trait for processing events
|
||||||
///
|
///
|
||||||
/// 特定の`Kind`に対するイベント処理を定義します。
|
/// Defines event processing for a specific `Kind`.
|
||||||
/// `Scope`はブロックのライフサイクル中に保持される状態です。
|
/// `Scope` is state held during the block's lifecycle.
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
|
|
@ -39,7 +39,7 @@ pub trait Kind {
|
||||||
/// }
|
/// }
|
||||||
///
|
///
|
||||||
/// impl Handler<TextBlockKind> for TextCollector {
|
/// impl Handler<TextBlockKind> for TextCollector {
|
||||||
/// type Scope = String; // ブロックごとのバッファ
|
/// type Scope = String; // Buffer per block
|
||||||
///
|
///
|
||||||
/// fn on_event(&mut self, buffer: &mut String, event: &TextBlockEvent) {
|
/// fn on_event(&mut self, buffer: &mut String, event: &TextBlockEvent) {
|
||||||
/// match event {
|
/// match event {
|
||||||
|
|
@ -53,13 +53,13 @@ pub trait Kind {
|
||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
pub trait Handler<K: Kind> {
|
pub trait Handler<K: Kind> {
|
||||||
/// Handler固有のスコープ型
|
/// Handler-specific scope type
|
||||||
///
|
///
|
||||||
/// ブロック開始時に`Default::default()`で生成され、
|
/// Generated with `Default::default()` at block start,
|
||||||
/// ブロック終了時に破棄されます。
|
/// and destroyed at block end.
|
||||||
type Scope: Default;
|
type Scope: Default;
|
||||||
|
|
||||||
/// イベントを処理する
|
/// Process the event
|
||||||
fn on_event(&mut self, scope: &mut Self::Scope, event: &K::Event);
|
fn on_event(&mut self, scope: &mut Self::Scope, event: &K::Event);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -67,25 +67,25 @@ pub trait Handler<K: Kind> {
|
||||||
// Meta Kind Definitions
|
// Meta Kind Definitions
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// Usage Kind - 使用量イベント用
|
/// Usage Kind - for usage events
|
||||||
pub struct UsageKind;
|
pub struct UsageKind;
|
||||||
impl Kind for UsageKind {
|
impl Kind for UsageKind {
|
||||||
type Event = UsageEvent;
|
type Event = UsageEvent;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Ping Kind - Pingイベント用
|
/// Ping Kind - for ping events
|
||||||
pub struct PingKind;
|
pub struct PingKind;
|
||||||
impl Kind for PingKind {
|
impl Kind for PingKind {
|
||||||
type Event = PingEvent;
|
type Event = PingEvent;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Status Kind - ステータスイベント用
|
/// Status Kind - for status events
|
||||||
pub struct StatusKind;
|
pub struct StatusKind;
|
||||||
impl Kind for StatusKind {
|
impl Kind for StatusKind {
|
||||||
type Event = StatusEvent;
|
type Event = StatusEvent;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Error Kind - エラーイベント用
|
/// Error Kind - for error events
|
||||||
pub struct ErrorKind;
|
pub struct ErrorKind;
|
||||||
impl Kind for ErrorKind {
|
impl Kind for ErrorKind {
|
||||||
type Event = ErrorEvent;
|
type Event = ErrorEvent;
|
||||||
|
|
@ -95,13 +95,13 @@ impl Kind for ErrorKind {
|
||||||
// Block Kind Definitions
|
// Block Kind Definitions
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// TextBlock Kind - テキストブロック用
|
/// TextBlock Kind - for text blocks
|
||||||
pub struct TextBlockKind;
|
pub struct TextBlockKind;
|
||||||
impl Kind for TextBlockKind {
|
impl Kind for TextBlockKind {
|
||||||
type Event = TextBlockEvent;
|
type Event = TextBlockEvent;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// テキストブロックのイベント
|
/// Text block events
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub enum TextBlockEvent {
|
pub enum TextBlockEvent {
|
||||||
Start(TextBlockStart),
|
Start(TextBlockStart),
|
||||||
|
|
@ -120,13 +120,13 @@ pub struct TextBlockStop {
|
||||||
pub stop_reason: Option<StopReason>,
|
pub stop_reason: Option<StopReason>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ThinkingBlock Kind - 思考ブロック用
|
/// ThinkingBlock Kind - for thinking blocks
|
||||||
pub struct ThinkingBlockKind;
|
pub struct ThinkingBlockKind;
|
||||||
impl Kind for ThinkingBlockKind {
|
impl Kind for ThinkingBlockKind {
|
||||||
type Event = ThinkingBlockEvent;
|
type Event = ThinkingBlockEvent;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 思考ブロックのイベント
|
/// Thinking block events
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub enum ThinkingBlockEvent {
|
pub enum ThinkingBlockEvent {
|
||||||
Start(ThinkingBlockStart),
|
Start(ThinkingBlockStart),
|
||||||
|
|
@ -144,17 +144,17 @@ pub struct ThinkingBlockStop {
|
||||||
pub index: usize,
|
pub index: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ToolUseBlock Kind - ツール使用ブロック用
|
/// ToolUseBlock Kind - for tool use blocks
|
||||||
pub struct ToolUseBlockKind;
|
pub struct ToolUseBlockKind;
|
||||||
impl Kind for ToolUseBlockKind {
|
impl Kind for ToolUseBlockKind {
|
||||||
type Event = ToolUseBlockEvent;
|
type Event = ToolUseBlockEvent;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ツール使用ブロックのイベント
|
/// Tool use block events
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub enum ToolUseBlockEvent {
|
pub enum ToolUseBlockEvent {
|
||||||
Start(ToolUseBlockStart),
|
Start(ToolUseBlockStart),
|
||||||
/// ツール引数のJSON部分文字列
|
/// JSON substring of tool arguments
|
||||||
InputJsonDelta(String),
|
InputJsonDelta(String),
|
||||||
Stop(ToolUseBlockStop),
|
Stop(ToolUseBlockStop),
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
//! Hook関連の型定義
|
//! Hook-related type definitions
|
||||||
//!
|
//!
|
||||||
//! Worker層でのターン制御・介入に使用される型
|
//! Types used for turn control and intervention in the Worker layer
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
@ -52,7 +52,7 @@ pub enum PostToolCallResult {
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum OnTurnEndResult {
|
pub enum OnTurnEndResult {
|
||||||
Finish,
|
Finish,
|
||||||
ContinueWithMessages(Vec<crate::Message>),
|
ContinueWithMessages(Vec<crate::Item>),
|
||||||
Paused,
|
Paused,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -60,35 +60,35 @@ use std::sync::Arc;
|
||||||
|
|
||||||
use crate::tool::{Tool, ToolMeta};
|
use crate::tool::{Tool, ToolMeta};
|
||||||
|
|
||||||
/// PreToolCall の入力コンテキスト
|
/// Input context for PreToolCall
|
||||||
pub struct ToolCallContext {
|
pub struct ToolCallContext {
|
||||||
/// ツール呼び出し情報(改変可能)
|
/// Tool call information (modifiable)
|
||||||
pub call: ToolCall,
|
pub call: ToolCall,
|
||||||
/// ツールメタ情報(不変)
|
/// Tool meta information (immutable)
|
||||||
pub meta: ToolMeta,
|
pub meta: ToolMeta,
|
||||||
/// ツールインスタンス(状態アクセス用)
|
/// Tool instance (for state access)
|
||||||
pub tool: Arc<dyn Tool>,
|
pub tool: Arc<dyn Tool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// PostToolCall の入力コンテキスト
|
/// Input context for PostToolCall
|
||||||
pub struct PostToolCallContext {
|
pub struct PostToolCallContext {
|
||||||
/// ツール呼び出し情報
|
/// Tool call information
|
||||||
pub call: ToolCall,
|
pub call: ToolCall,
|
||||||
/// ツール実行結果(改変可能)
|
/// Tool execution result (modifiable)
|
||||||
pub result: ToolResult,
|
pub result: ToolResult,
|
||||||
/// ツールメタ情報(不変)
|
/// Tool meta information (immutable)
|
||||||
pub meta: ToolMeta,
|
pub meta: ToolMeta,
|
||||||
/// ツールインスタンス(状態アクセス用)
|
/// Tool instance (for state access)
|
||||||
pub tool: Arc<dyn Tool>,
|
pub tool: Arc<dyn Tool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HookEventKind for OnPromptSubmit {
|
impl HookEventKind for OnPromptSubmit {
|
||||||
type Input = crate::Message;
|
type Input = crate::Item;
|
||||||
type Output = OnPromptSubmitResult;
|
type Output = OnPromptSubmitResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HookEventKind for PreLlmRequest {
|
impl HookEventKind for PreLlmRequest {
|
||||||
type Input = Vec<crate::Message>;
|
type Input = Vec<crate::Item>;
|
||||||
type Output = PreLlmRequestResult;
|
type Output = PreLlmRequestResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -103,7 +103,7 @@ impl HookEventKind for PostToolCall {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HookEventKind for OnTurnEnd {
|
impl HookEventKind for OnTurnEnd {
|
||||||
type Input = Vec<crate::Message>;
|
type Input = Vec<crate::Item>;
|
||||||
type Output = OnTurnEndResult;
|
type Output = OnTurnEndResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -116,35 +116,35 @@ impl HookEventKind for OnAbort {
|
||||||
// Tool Call / Result Types
|
// Tool Call / Result Types
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// ツール呼び出し情報
|
/// Tool call information
|
||||||
///
|
///
|
||||||
/// LLMからのToolUseブロックを表現し、Hook処理で改変可能
|
/// Represents a ToolUse block from LLM, modifiable in Hook processing
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ToolCall {
|
pub struct ToolCall {
|
||||||
/// ツール呼び出しID(レスポンスとの紐付けに使用)
|
/// Tool call ID (used for linking with response)
|
||||||
pub id: String,
|
pub id: String,
|
||||||
/// ツール名
|
/// Tool name
|
||||||
pub name: String,
|
pub name: String,
|
||||||
/// 入力引数(JSON)
|
/// Input arguments (JSON)
|
||||||
pub input: Value,
|
pub input: Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ツール実行結果
|
/// Tool execution result
|
||||||
///
|
///
|
||||||
/// ツール実行後の結果を表現し、Hook処理で改変可能
|
/// Represents the result after tool execution, modifiable in Hook processing
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ToolResult {
|
pub struct ToolResult {
|
||||||
/// 対応するツール呼び出しID
|
/// Corresponding tool call ID
|
||||||
pub tool_use_id: String,
|
pub tool_use_id: String,
|
||||||
/// 結果コンテンツ
|
/// Result content
|
||||||
pub content: String,
|
pub content: String,
|
||||||
/// エラーかどうか
|
/// Whether this is an error
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub is_error: bool,
|
pub is_error: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolResult {
|
impl ToolResult {
|
||||||
/// 成功結果を作成
|
/// Create a success result
|
||||||
pub fn success(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
|
pub fn success(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
tool_use_id: tool_use_id.into(),
|
tool_use_id: tool_use_id.into(),
|
||||||
|
|
@ -153,7 +153,7 @@ impl ToolResult {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// エラー結果を作成
|
/// Create an error result
|
||||||
pub fn error(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
|
pub fn error(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
tool_use_id: tool_use_id.into(),
|
tool_use_id: tool_use_id.into(),
|
||||||
|
|
@ -167,13 +167,13 @@ impl ToolResult {
|
||||||
// Hook Error
|
// Hook Error
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// Hookエラー
|
/// Hook error
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum HookError {
|
pub enum HookError {
|
||||||
/// 処理が中断された
|
/// Processing was aborted
|
||||||
#[error("Aborted: {0}")]
|
#[error("Aborted: {0}")]
|
||||||
Aborted(String),
|
Aborted(String),
|
||||||
/// 内部エラー
|
/// Internal error
|
||||||
#[error("Hook error: {0}")]
|
#[error("Hook error: {0}")]
|
||||||
Internal(String),
|
Internal(String),
|
||||||
}
|
}
|
||||||
|
|
@ -182,9 +182,9 @@ pub enum HookError {
|
||||||
// Hook Trait
|
// Hook Trait
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// Hookイベントの処理を行うトレイト
|
/// Trait for handling Hook events
|
||||||
///
|
///
|
||||||
/// 各イベント種別は戻り値型が異なるため、`HookEventKind`を介して型を制約する。
|
/// Each event type has a different return type, constrained via `HookEventKind`.
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Hook<E: HookEventKind>: Send + Sync {
|
pub trait Hook<E: HookEventKind>: Send + Sync {
|
||||||
async fn call(&self, input: &mut E::Input) -> Result<E::Output, HookError>;
|
async fn call(&self, input: &mut E::Input) -> Result<E::Output, HookError>;
|
||||||
|
|
@ -194,9 +194,9 @@ pub trait Hook<E: HookEventKind>: Send + Sync {
|
||||||
// Hook Registry
|
// Hook Registry
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// 全 Hook を保持するレジストリ
|
/// Registry holding all Hooks
|
||||||
///
|
///
|
||||||
/// Worker 内部で使用され、各種 Hook を一括管理する。
|
/// Used internally by Worker to manage all Hook types.
|
||||||
pub struct HookRegistry {
|
pub struct HookRegistry {
|
||||||
/// on_prompt_submit Hook
|
/// on_prompt_submit Hook
|
||||||
pub(crate) on_prompt_submit: Vec<Box<dyn Hook<OnPromptSubmit>>>,
|
pub(crate) on_prompt_submit: Vec<Box<dyn Hook<OnPromptSubmit>>>,
|
||||||
|
|
@ -219,7 +219,7 @@ impl Default for HookRegistry {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HookRegistry {
|
impl HookRegistry {
|
||||||
/// 空の HookRegistry を作成
|
/// Create an empty HookRegistry
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
on_prompt_submit: Vec::new(),
|
on_prompt_submit: Vec::new(),
|
||||||
|
|
|
||||||
|
|
@ -1,34 +1,34 @@
|
||||||
//! llm-worker - LLMワーカーライブラリ
|
//! llm-worker - LLM Worker Library
|
||||||
//!
|
//!
|
||||||
//! LLMとの対話を管理するコンポーネントを提供します。
|
//! Provides components for managing interactions with LLMs.
|
||||||
//!
|
//!
|
||||||
//! # 主要なコンポーネント
|
//! # Main Components
|
||||||
//!
|
//!
|
||||||
//! - [`Worker`] - LLMとの対話を管理する中心コンポーネント
|
//! - [`Worker`] - Central component for managing LLM interactions
|
||||||
//! - [`tool::Tool`] - LLMから呼び出し可能なツール
|
//! - [`tool::Tool`] - Tools that can be invoked by the LLM
|
||||||
//! - [`hook::Hook`] - ターン進行への介入
|
//! - [`hook::Hook`] - Hooks for intercepting turn progression
|
||||||
//! - [`subscriber::WorkerSubscriber`] - ストリーミングイベントの購読
|
//! - [`subscriber::WorkerSubscriber`] - Subscribing to streaming events
|
||||||
//!
|
//!
|
||||||
//! # Quick Start
|
//! # Quick Start
|
||||||
//!
|
//!
|
||||||
//! ```ignore
|
//! ```ignore
|
||||||
//! use llm_worker::{Worker, Message};
|
//! use llm_worker::{Worker, Item};
|
||||||
//!
|
//!
|
||||||
//! // Workerを作成
|
//! // Create a Worker
|
||||||
//! let mut worker = Worker::new(client)
|
//! let mut worker = Worker::new(client)
|
||||||
//! .system_prompt("You are a helpful assistant.");
|
//! .system_prompt("You are a helpful assistant.");
|
||||||
//!
|
//!
|
||||||
//! // ツールを登録(オプション)
|
//! // Register tools (optional)
|
||||||
//! // worker.register_tool(my_tool_definition)?;
|
//! // worker.register_tool(my_tool_definition)?;
|
||||||
//!
|
//!
|
||||||
//! // 対話を実行
|
//! // Run the interaction
|
||||||
//! let history = worker.run("Hello!").await?;
|
//! let history = worker.run("Hello!").await?;
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! # キャッシュ保護
|
//! # Cache Protection
|
||||||
//!
|
//!
|
||||||
//! KVキャッシュのヒット率を最大化するには、[`Worker::lock()`]で
|
//! To maximize KV cache hit rate, transition to the locked state
|
||||||
//! ロック状態に遷移してから実行してください。
|
//! with [`Worker::lock()`] before execution.
|
||||||
//!
|
//!
|
||||||
//! ```ignore
|
//! ```ignore
|
||||||
//! let mut locked = worker.lock();
|
//! let mut locked = worker.lock();
|
||||||
|
|
@ -46,6 +46,7 @@ pub mod state;
|
||||||
pub mod subscriber;
|
pub mod subscriber;
|
||||||
pub mod timeline;
|
pub mod timeline;
|
||||||
pub mod tool;
|
pub mod tool;
|
||||||
|
pub mod tool_server;
|
||||||
|
|
||||||
pub use message::{ContentPart, Message, MessageContent, Role};
|
pub use message::{ContentPart, Item, Message, Role};
|
||||||
pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult};
|
pub use worker::{ToolRegistryError, Worker, WorkerConfig, WorkerError, WorkerResult};
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,17 @@
|
||||||
//! Anthropic リクエスト生成
|
//! Anthropic Request Builder
|
||||||
|
//!
|
||||||
|
//! Converts Open Responses native Item model to Anthropic Messages API format.
|
||||||
|
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
use crate::llm_client::{
|
use crate::llm_client::{
|
||||||
|
types::{ContentPart, Item, Role, ToolDefinition},
|
||||||
Request,
|
Request,
|
||||||
types::{ContentPart, Message, MessageContent, Role, ToolDefinition},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::AnthropicScheme;
|
use super::AnthropicScheme;
|
||||||
|
|
||||||
/// Anthropic APIへのリクエストボディ
|
/// Anthropic API request body
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub(crate) struct AnthropicRequest {
|
pub(crate) struct AnthropicRequest {
|
||||||
pub model: String,
|
pub model: String,
|
||||||
|
|
@ -30,14 +32,14 @@ pub(crate) struct AnthropicRequest {
|
||||||
pub stream: bool,
|
pub stream: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Anthropic メッセージ
|
/// Anthropic message
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub(crate) struct AnthropicMessage {
|
pub(crate) struct AnthropicMessage {
|
||||||
pub role: String,
|
pub role: String,
|
||||||
pub content: AnthropicContent,
|
pub content: AnthropicContent,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Anthropic コンテンツ
|
/// Anthropic content
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub(crate) enum AnthropicContent {
|
pub(crate) enum AnthropicContent {
|
||||||
|
|
@ -45,7 +47,7 @@ pub(crate) enum AnthropicContent {
|
||||||
Parts(Vec<AnthropicContentPart>),
|
Parts(Vec<AnthropicContentPart>),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Anthropic コンテンツパーツ
|
/// Anthropic content part
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
pub(crate) enum AnthropicContentPart {
|
pub(crate) enum AnthropicContentPart {
|
||||||
|
|
@ -58,13 +60,10 @@ pub(crate) enum AnthropicContentPart {
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
},
|
},
|
||||||
#[serde(rename = "tool_result")]
|
#[serde(rename = "tool_result")]
|
||||||
ToolResult {
|
ToolResult { tool_use_id: String, content: String },
|
||||||
tool_use_id: String,
|
|
||||||
content: String,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Anthropic ツール定義
|
/// Anthropic tool definition
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub(crate) struct AnthropicTool {
|
pub(crate) struct AnthropicTool {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
|
|
@ -74,14 +73,9 @@ pub(crate) struct AnthropicTool {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AnthropicScheme {
|
impl AnthropicScheme {
|
||||||
/// RequestからAnthropicのリクエストボディを構築
|
/// Build Anthropic request from Request
|
||||||
pub(crate) fn build_request(&self, model: &str, request: &Request) -> AnthropicRequest {
|
pub(crate) fn build_request(&self, model: &str, request: &Request) -> AnthropicRequest {
|
||||||
let messages = request
|
let messages = self.convert_items_to_messages(&request.items);
|
||||||
.messages
|
|
||||||
.iter()
|
|
||||||
.map(|m| self.convert_message(m))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let tools = request.tools.iter().map(|t| self.convert_tool(t)).collect();
|
let tools = request.tools.iter().map(|t| self.convert_tool(t)).collect();
|
||||||
|
|
||||||
AnthropicRequest {
|
AnthropicRequest {
|
||||||
|
|
@ -98,49 +92,160 @@ impl AnthropicScheme {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_message(&self, message: &Message) -> AnthropicMessage {
|
/// Convert Open Responses Items to Anthropic Messages
|
||||||
let role = match message.role {
|
///
|
||||||
|
/// Anthropic uses a message-based model where:
|
||||||
|
/// - User messages have role "user"
|
||||||
|
/// - Assistant messages have role "assistant"
|
||||||
|
/// - Tool calls are content parts within assistant messages
|
||||||
|
/// - Tool results are content parts within user messages
|
||||||
|
fn convert_items_to_messages(&self, items: &[Item]) -> Vec<AnthropicMessage> {
|
||||||
|
let mut messages = Vec::new();
|
||||||
|
let mut pending_assistant_parts: Vec<AnthropicContentPart> = Vec::new();
|
||||||
|
let mut pending_user_parts: Vec<AnthropicContentPart> = Vec::new();
|
||||||
|
|
||||||
|
for item in items {
|
||||||
|
match item {
|
||||||
|
Item::Message { role, content, .. } => {
|
||||||
|
// Flush pending parts before a new message
|
||||||
|
self.flush_pending_parts(
|
||||||
|
&mut messages,
|
||||||
|
&mut pending_assistant_parts,
|
||||||
|
&mut pending_user_parts,
|
||||||
|
);
|
||||||
|
|
||||||
|
let anthropic_role = match role {
|
||||||
Role::User => "user",
|
Role::User => "user",
|
||||||
Role::Assistant => "assistant",
|
Role::Assistant => "assistant",
|
||||||
|
Role::System => continue, // Skip system role items
|
||||||
};
|
};
|
||||||
|
|
||||||
let content = match &message.content {
|
let parts: Vec<AnthropicContentPart> = content
|
||||||
MessageContent::Text(text) => AnthropicContent::Text(text.clone()),
|
|
||||||
MessageContent::ToolResult {
|
|
||||||
tool_use_id,
|
|
||||||
content,
|
|
||||||
} => AnthropicContent::Parts(vec![AnthropicContentPart::ToolResult {
|
|
||||||
tool_use_id: tool_use_id.clone(),
|
|
||||||
content: content.clone(),
|
|
||||||
}]),
|
|
||||||
MessageContent::Parts(parts) => {
|
|
||||||
let converted: Vec<_> = parts
|
|
||||||
.iter()
|
.iter()
|
||||||
.map(|p| match p {
|
.map(|p| match p {
|
||||||
ContentPart::Text { text } => {
|
ContentPart::InputText { text } => {
|
||||||
AnthropicContentPart::Text { text: text.clone() }
|
AnthropicContentPart::Text { text: text.clone() }
|
||||||
}
|
}
|
||||||
ContentPart::ToolUse { id, name, input } => AnthropicContentPart::ToolUse {
|
ContentPart::OutputText { text } => {
|
||||||
id: id.clone(),
|
AnthropicContentPart::Text { text: text.clone() }
|
||||||
name: name.clone(),
|
}
|
||||||
input: input.clone(),
|
ContentPart::Refusal { refusal } => {
|
||||||
},
|
AnthropicContentPart::Text {
|
||||||
ContentPart::ToolResult {
|
text: refusal.clone(),
|
||||||
tool_use_id,
|
}
|
||||||
content,
|
}
|
||||||
} => AnthropicContentPart::ToolResult {
|
|
||||||
tool_use_id: tool_use_id.clone(),
|
|
||||||
content: content.clone(),
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
AnthropicContent::Parts(converted)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
AnthropicMessage {
|
if parts.len() == 1 {
|
||||||
role: role.to_string(),
|
if let AnthropicContentPart::Text { text } = &parts[0] {
|
||||||
content,
|
messages.push(AnthropicMessage {
|
||||||
|
role: anthropic_role.to_string(),
|
||||||
|
content: AnthropicContent::Text(text.clone()),
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
messages.push(AnthropicMessage {
|
||||||
|
role: anthropic_role.to_string(),
|
||||||
|
content: AnthropicContent::Parts(parts),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
messages.push(AnthropicMessage {
|
||||||
|
role: anthropic_role.to_string(),
|
||||||
|
content: AnthropicContent::Parts(parts),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Item::FunctionCall {
|
||||||
|
call_id,
|
||||||
|
name,
|
||||||
|
arguments,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
// Flush pending user parts first
|
||||||
|
if !pending_user_parts.is_empty() {
|
||||||
|
messages.push(AnthropicMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: AnthropicContent::Parts(std::mem::take(
|
||||||
|
&mut pending_user_parts,
|
||||||
|
)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse arguments JSON string to Value
|
||||||
|
let input = serde_json::from_str(arguments)
|
||||||
|
.unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
|
||||||
|
|
||||||
|
pending_assistant_parts.push(AnthropicContentPart::ToolUse {
|
||||||
|
id: call_id.clone(),
|
||||||
|
name: name.clone(),
|
||||||
|
input,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Item::FunctionCallOutput { call_id, output, .. } => {
|
||||||
|
// Flush pending assistant parts first
|
||||||
|
if !pending_assistant_parts.is_empty() {
|
||||||
|
messages.push(AnthropicMessage {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: AnthropicContent::Parts(std::mem::take(
|
||||||
|
&mut pending_assistant_parts,
|
||||||
|
)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
pending_user_parts.push(AnthropicContentPart::ToolResult {
|
||||||
|
tool_use_id: call_id.clone(),
|
||||||
|
content: output.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Item::Reasoning { text, .. } => {
|
||||||
|
// Flush pending user parts first
|
||||||
|
if !pending_user_parts.is_empty() {
|
||||||
|
messages.push(AnthropicMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: AnthropicContent::Parts(std::mem::take(
|
||||||
|
&mut pending_user_parts,
|
||||||
|
)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reasoning is treated as assistant text in Anthropic
|
||||||
|
// (actual thinking blocks are handled differently in streaming)
|
||||||
|
pending_assistant_parts.push(AnthropicContentPart::Text { text: text.clone() });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush remaining pending parts
|
||||||
|
self.flush_pending_parts(
|
||||||
|
&mut messages,
|
||||||
|
&mut pending_assistant_parts,
|
||||||
|
&mut pending_user_parts,
|
||||||
|
);
|
||||||
|
|
||||||
|
messages
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flush_pending_parts(
|
||||||
|
&self,
|
||||||
|
messages: &mut Vec<AnthropicMessage>,
|
||||||
|
pending_assistant_parts: &mut Vec<AnthropicContentPart>,
|
||||||
|
pending_user_parts: &mut Vec<AnthropicContentPart>,
|
||||||
|
) {
|
||||||
|
if !pending_assistant_parts.is_empty() {
|
||||||
|
messages.push(AnthropicMessage {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: AnthropicContent::Parts(std::mem::take(pending_assistant_parts)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if !pending_user_parts.is_empty() {
|
||||||
|
messages.push(AnthropicMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: AnthropicContent::Parts(std::mem::take(pending_user_parts)),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -195,4 +300,24 @@ mod tests {
|
||||||
assert_eq!(anthropic_req.tools.len(), 1);
|
assert_eq!(anthropic_req.tools.len(), 1);
|
||||||
assert_eq!(anthropic_req.tools[0].name, "get_weather");
|
assert_eq!(anthropic_req.tools[0].name, "get_weather");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_function_call_and_output() {
|
||||||
|
let scheme = AnthropicScheme::new();
|
||||||
|
let request = Request::new()
|
||||||
|
.user("What's the weather?")
|
||||||
|
.item(Item::function_call(
|
||||||
|
"call_123",
|
||||||
|
"get_weather",
|
||||||
|
r#"{"city":"Tokyo"}"#,
|
||||||
|
))
|
||||||
|
.item(Item::function_call_output("call_123", "Sunny, 25°C"));
|
||||||
|
|
||||||
|
let anthropic_req = scheme.build_request("claude-sonnet-4-20250514", &request);
|
||||||
|
|
||||||
|
assert_eq!(anthropic_req.messages.len(), 3);
|
||||||
|
assert_eq!(anthropic_req.messages[0].role, "user");
|
||||||
|
assert_eq!(anthropic_req.messages[1].role, "assistant");
|
||||||
|
assert_eq!(anthropic_req.messages[2].role, "user");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,130 +1,130 @@
|
||||||
//! Gemini リクエスト生成
|
//! Gemini Request Builder
|
||||||
//!
|
//!
|
||||||
//! Google Gemini APIへのリクエストボディを構築
|
//! Converts Open Responses native Item model to Google Gemini API format.
|
||||||
|
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::llm_client::{
|
use crate::llm_client::{
|
||||||
|
types::{Item, Role, ToolDefinition},
|
||||||
Request,
|
Request,
|
||||||
types::{ContentPart, Message, MessageContent, Role, ToolDefinition},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::GeminiScheme;
|
use super::GeminiScheme;
|
||||||
|
|
||||||
/// Gemini APIへのリクエストボディ
|
/// Gemini API request body
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub(crate) struct GeminiRequest {
|
pub(crate) struct GeminiRequest {
|
||||||
/// コンテンツ(会話履歴)
|
/// Contents (conversation history)
|
||||||
pub contents: Vec<GeminiContent>,
|
pub contents: Vec<GeminiContent>,
|
||||||
/// システム指示
|
/// System instruction
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub system_instruction: Option<GeminiContent>,
|
pub system_instruction: Option<GeminiContent>,
|
||||||
/// ツール定義
|
/// Tool definitions
|
||||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||||
pub tools: Vec<GeminiTool>,
|
pub tools: Vec<GeminiTool>,
|
||||||
/// ツール設定
|
/// Tool config
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_config: Option<GeminiToolConfig>,
|
pub tool_config: Option<GeminiToolConfig>,
|
||||||
/// 生成設定
|
/// Generation config
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub generation_config: Option<GeminiGenerationConfig>,
|
pub generation_config: Option<GeminiGenerationConfig>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gemini コンテンツ
|
/// Gemini content
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub(crate) struct GeminiContent {
|
pub(crate) struct GeminiContent {
|
||||||
/// ロール
|
/// Role
|
||||||
pub role: String,
|
pub role: String,
|
||||||
/// パーツ
|
/// Parts
|
||||||
pub parts: Vec<GeminiPart>,
|
pub parts: Vec<GeminiPart>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gemini パーツ
|
/// Gemini part
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub(crate) enum GeminiPart {
|
pub(crate) enum GeminiPart {
|
||||||
/// テキストパーツ
|
/// Text part
|
||||||
Text { text: String },
|
Text { text: String },
|
||||||
/// 関数呼び出しパーツ
|
/// Function call part
|
||||||
FunctionCall {
|
FunctionCall {
|
||||||
#[serde(rename = "functionCall")]
|
#[serde(rename = "functionCall")]
|
||||||
function_call: GeminiFunctionCall,
|
function_call: GeminiFunctionCall,
|
||||||
},
|
},
|
||||||
/// 関数レスポンスパーツ
|
/// Function response part
|
||||||
FunctionResponse {
|
FunctionResponse {
|
||||||
#[serde(rename = "functionResponse")]
|
#[serde(rename = "functionResponse")]
|
||||||
function_response: GeminiFunctionResponse,
|
function_response: GeminiFunctionResponse,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gemini 関数呼び出し
|
/// Gemini function call
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub(crate) struct GeminiFunctionCall {
|
pub(crate) struct GeminiFunctionCall {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub args: Value,
|
pub args: Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gemini 関数レスポンス
|
/// Gemini function response
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub(crate) struct GeminiFunctionResponse {
|
pub(crate) struct GeminiFunctionResponse {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub response: GeminiFunctionResponseContent,
|
pub response: GeminiFunctionResponseContent,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gemini 関数レスポンス内容
|
/// Gemini function response content
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub(crate) struct GeminiFunctionResponseContent {
|
pub(crate) struct GeminiFunctionResponseContent {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub content: Value,
|
pub content: Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gemini ツール定義
|
/// Gemini tool definition
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub(crate) struct GeminiTool {
|
pub(crate) struct GeminiTool {
|
||||||
/// 関数宣言
|
/// Function declarations
|
||||||
pub function_declarations: Vec<GeminiFunctionDeclaration>,
|
pub function_declarations: Vec<GeminiFunctionDeclaration>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gemini 関数宣言
|
/// Gemini function declaration
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub(crate) struct GeminiFunctionDeclaration {
|
pub(crate) struct GeminiFunctionDeclaration {
|
||||||
/// 関数名
|
/// Function name
|
||||||
pub name: String,
|
pub name: String,
|
||||||
/// 説明
|
/// Description
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub description: Option<String>,
|
pub description: Option<String>,
|
||||||
/// パラメータスキーマ
|
/// Parameter schema
|
||||||
pub parameters: Value,
|
pub parameters: Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gemini ツール設定
|
/// Gemini tool config
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub(crate) struct GeminiToolConfig {
|
pub(crate) struct GeminiToolConfig {
|
||||||
/// 関数呼び出し設定
|
/// Function calling config
|
||||||
pub function_calling_config: GeminiFunctionCallingConfig,
|
pub function_calling_config: GeminiFunctionCallingConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gemini 関数呼び出し設定
|
/// Gemini function calling config
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub(crate) struct GeminiFunctionCallingConfig {
|
pub(crate) struct GeminiFunctionCallingConfig {
|
||||||
/// モード: AUTO, ANY, NONE
|
/// Mode: AUTO, ANY, NONE
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub mode: Option<String>,
|
pub mode: Option<String>,
|
||||||
/// ストリーミング関数呼び出し引数を有効にするか
|
/// Enable streaming function call arguments
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub stream_function_call_arguments: Option<bool>,
|
pub stream_function_call_arguments: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gemini 生成設定
|
/// Gemini generation config
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub(crate) struct GeminiGenerationConfig {
|
pub(crate) struct GeminiGenerationConfig {
|
||||||
/// 最大出力トークン数
|
/// Max output tokens
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub max_output_tokens: Option<u32>,
|
pub max_output_tokens: Option<u32>,
|
||||||
/// Temperature
|
/// Temperature
|
||||||
|
|
@ -136,27 +136,23 @@ pub(crate) struct GeminiGenerationConfig {
|
||||||
/// Top K
|
/// Top K
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub top_k: Option<u32>,
|
pub top_k: Option<u32>,
|
||||||
/// ストップシーケンス
|
/// Stop sequences
|
||||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||||
pub stop_sequences: Vec<String>,
|
pub stop_sequences: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GeminiScheme {
|
impl GeminiScheme {
|
||||||
/// RequestからGeminiのリクエストボディを構築
|
/// Build Gemini request from Request
|
||||||
pub(crate) fn build_request(&self, request: &Request) -> GeminiRequest {
|
pub(crate) fn build_request(&self, request: &Request) -> GeminiRequest {
|
||||||
let mut contents = Vec::new();
|
let contents = self.convert_items_to_contents(&request.items);
|
||||||
|
|
||||||
for message in &request.messages {
|
// System prompt
|
||||||
contents.push(self.convert_message(message));
|
|
||||||
}
|
|
||||||
|
|
||||||
// システムプロンプト
|
|
||||||
let system_instruction = request.system_prompt.as_ref().map(|s| GeminiContent {
|
let system_instruction = request.system_prompt.as_ref().map(|s| GeminiContent {
|
||||||
role: "user".to_string(), // system_instructionではroleは"user"か省略
|
role: "user".to_string(),
|
||||||
parts: vec![GeminiPart::Text { text: s.clone() }],
|
parts: vec![GeminiPart::Text { text: s.clone() }],
|
||||||
});
|
});
|
||||||
|
|
||||||
// ツール
|
// Tools
|
||||||
let tools = if request.tools.is_empty() {
|
let tools = if request.tools.is_empty() {
|
||||||
vec![]
|
vec![]
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -165,7 +161,7 @@ impl GeminiScheme {
|
||||||
}]
|
}]
|
||||||
};
|
};
|
||||||
|
|
||||||
// ツール設定
|
// Tool config
|
||||||
let tool_config = if !request.tools.is_empty() {
|
let tool_config = if !request.tools.is_empty() {
|
||||||
Some(GeminiToolConfig {
|
Some(GeminiToolConfig {
|
||||||
function_calling_config: GeminiFunctionCallingConfig {
|
function_calling_config: GeminiFunctionCallingConfig {
|
||||||
|
|
@ -181,7 +177,7 @@ impl GeminiScheme {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
// 生成設定
|
// Generation config
|
||||||
let generation_config = Some(GeminiGenerationConfig {
|
let generation_config = Some(GeminiGenerationConfig {
|
||||||
max_output_tokens: request.config.max_tokens,
|
max_output_tokens: request.config.max_tokens,
|
||||||
temperature: request.config.temperature,
|
temperature: request.config.temperature,
|
||||||
|
|
@ -199,58 +195,126 @@ impl GeminiScheme {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_message(&self, message: &Message) -> GeminiContent {
|
/// Convert Open Responses Items to Gemini Contents
|
||||||
let role = match message.role {
|
///
|
||||||
|
/// Gemini uses:
|
||||||
|
/// - role "user" for user messages and function responses
|
||||||
|
/// - role "model" for assistant messages and function calls
|
||||||
|
fn convert_items_to_contents(&self, items: &[Item]) -> Vec<GeminiContent> {
|
||||||
|
let mut contents = Vec::new();
|
||||||
|
let mut pending_model_parts: Vec<GeminiPart> = Vec::new();
|
||||||
|
let mut pending_user_parts: Vec<GeminiPart> = Vec::new();
|
||||||
|
|
||||||
|
for item in items {
|
||||||
|
match item {
|
||||||
|
Item::Message { role, content, .. } => {
|
||||||
|
// Flush pending parts
|
||||||
|
self.flush_pending_parts(
|
||||||
|
&mut contents,
|
||||||
|
&mut pending_model_parts,
|
||||||
|
&mut pending_user_parts,
|
||||||
|
);
|
||||||
|
|
||||||
|
let gemini_role = match role {
|
||||||
Role::User => "user",
|
Role::User => "user",
|
||||||
Role::Assistant => "model",
|
Role::Assistant => "model",
|
||||||
|
Role::System => continue, // Skip system role items
|
||||||
};
|
};
|
||||||
|
|
||||||
let parts = match &message.content {
|
let parts: Vec<GeminiPart> = content
|
||||||
MessageContent::Text(text) => vec![GeminiPart::Text { text: text.clone() }],
|
|
||||||
MessageContent::ToolResult {
|
|
||||||
tool_use_id,
|
|
||||||
content,
|
|
||||||
} => {
|
|
||||||
// Geminiでは関数レスポンスとしてマップ
|
|
||||||
vec![GeminiPart::FunctionResponse {
|
|
||||||
function_response: GeminiFunctionResponse {
|
|
||||||
name: tool_use_id.clone(),
|
|
||||||
response: GeminiFunctionResponseContent {
|
|
||||||
name: tool_use_id.clone(),
|
|
||||||
content: serde_json::Value::String(content.clone()),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
MessageContent::Parts(parts) => parts
|
|
||||||
.iter()
|
.iter()
|
||||||
.map(|p| match p {
|
.map(|p| GeminiPart::Text {
|
||||||
ContentPart::Text { text } => GeminiPart::Text { text: text.clone() },
|
text: p.as_text().to_string(),
|
||||||
ContentPart::ToolUse { id: _, name, input } => GeminiPart::FunctionCall {
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
contents.push(GeminiContent {
|
||||||
|
role: gemini_role.to_string(),
|
||||||
|
parts,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Item::FunctionCall {
|
||||||
|
name, arguments, ..
|
||||||
|
} => {
|
||||||
|
// Flush pending user parts first
|
||||||
|
if !pending_user_parts.is_empty() {
|
||||||
|
contents.push(GeminiContent {
|
||||||
|
role: "user".to_string(),
|
||||||
|
parts: std::mem::take(&mut pending_user_parts),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse arguments
|
||||||
|
let args = serde_json::from_str(arguments)
|
||||||
|
.unwrap_or_else(|_| Value::Object(serde_json::Map::new()));
|
||||||
|
|
||||||
|
pending_model_parts.push(GeminiPart::FunctionCall {
|
||||||
function_call: GeminiFunctionCall {
|
function_call: GeminiFunctionCall {
|
||||||
name: name.clone(),
|
name: name.clone(),
|
||||||
args: input.clone(),
|
args,
|
||||||
},
|
},
|
||||||
},
|
});
|
||||||
ContentPart::ToolResult {
|
}
|
||||||
tool_use_id,
|
|
||||||
content,
|
|
||||||
} => GeminiPart::FunctionResponse {
|
|
||||||
function_response: GeminiFunctionResponse {
|
|
||||||
name: tool_use_id.clone(),
|
|
||||||
response: GeminiFunctionResponseContent {
|
|
||||||
name: tool_use_id.clone(),
|
|
||||||
content: serde_json::Value::String(content.clone()),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
};
|
|
||||||
|
|
||||||
GeminiContent {
|
Item::FunctionCallOutput { call_id, output, .. } => {
|
||||||
role: role.to_string(),
|
// Flush pending model parts first
|
||||||
parts,
|
if !pending_model_parts.is_empty() {
|
||||||
|
contents.push(GeminiContent {
|
||||||
|
role: "model".to_string(),
|
||||||
|
parts: std::mem::take(&mut pending_model_parts),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
pending_user_parts.push(GeminiPart::FunctionResponse {
|
||||||
|
function_response: GeminiFunctionResponse {
|
||||||
|
name: call_id.clone(),
|
||||||
|
response: GeminiFunctionResponseContent {
|
||||||
|
name: call_id.clone(),
|
||||||
|
content: Value::String(output.clone()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Item::Reasoning { text, .. } => {
|
||||||
|
// Flush pending user parts first
|
||||||
|
if !pending_user_parts.is_empty() {
|
||||||
|
contents.push(GeminiContent {
|
||||||
|
role: "user".to_string(),
|
||||||
|
parts: std::mem::take(&mut pending_user_parts),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reasoning is treated as model text in Gemini
|
||||||
|
pending_model_parts.push(GeminiPart::Text { text: text.clone() });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush remaining pending parts
|
||||||
|
self.flush_pending_parts(&mut contents, &mut pending_model_parts, &mut pending_user_parts);
|
||||||
|
|
||||||
|
contents
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flush_pending_parts(
|
||||||
|
&self,
|
||||||
|
contents: &mut Vec<GeminiContent>,
|
||||||
|
pending_model_parts: &mut Vec<GeminiPart>,
|
||||||
|
pending_user_parts: &mut Vec<GeminiPart>,
|
||||||
|
) {
|
||||||
|
if !pending_model_parts.is_empty() {
|
||||||
|
contents.push(GeminiContent {
|
||||||
|
role: "model".to_string(),
|
||||||
|
parts: std::mem::take(pending_model_parts),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if !pending_user_parts.is_empty() {
|
||||||
|
contents.push(GeminiContent {
|
||||||
|
role: "user".to_string(),
|
||||||
|
parts: std::mem::take(pending_user_parts),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -318,4 +382,24 @@ mod tests {
|
||||||
assert_eq!(gemini_req.contents[0].role, "user");
|
assert_eq!(gemini_req.contents[0].role, "user");
|
||||||
assert_eq!(gemini_req.contents[1].role, "model");
|
assert_eq!(gemini_req.contents[1].role, "model");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_function_call_and_output() {
|
||||||
|
let scheme = GeminiScheme::new();
|
||||||
|
let request = Request::new()
|
||||||
|
.user("What's the weather?")
|
||||||
|
.item(Item::function_call(
|
||||||
|
"call_123",
|
||||||
|
"get_weather",
|
||||||
|
r#"{"city":"Tokyo"}"#,
|
||||||
|
))
|
||||||
|
.item(Item::function_call_output("call_123", "Sunny, 25°C"));
|
||||||
|
|
||||||
|
let gemini_req = scheme.build_request(&request);
|
||||||
|
|
||||||
|
assert_eq!(gemini_req.contents.len(), 3);
|
||||||
|
assert_eq!(gemini_req.contents[0].role, "user");
|
||||||
|
assert_eq!(gemini_req.contents[1].role, "model");
|
||||||
|
assert_eq!(gemini_req.contents[2].role, "user");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,23 @@
|
||||||
//! OpenAI リクエスト生成
|
//! OpenAI Request Builder
|
||||||
|
//!
|
||||||
|
//! Converts Open Responses native Item model to OpenAI Chat Completions API format.
|
||||||
|
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::llm_client::{
|
use crate::llm_client::{
|
||||||
|
types::{Item, Role, ToolDefinition},
|
||||||
Request,
|
Request,
|
||||||
types::{ContentPart, Message, MessageContent, Role, ToolDefinition},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::OpenAIScheme;
|
use super::OpenAIScheme;
|
||||||
|
|
||||||
/// OpenAI APIへのリクエストボディ
|
/// OpenAI API request body
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub(crate) struct OpenAIRequest {
|
pub(crate) struct OpenAIRequest {
|
||||||
pub model: String,
|
pub model: String,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub max_completion_tokens: Option<u32>, // max_tokens is deprecated for newer models, generally max_completion_tokens is preferred
|
pub max_completion_tokens: Option<u32>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub max_tokens: Option<u32>, // Legacy field for compatibility (e.g. Ollama)
|
pub max_tokens: Option<u32>, // Legacy field for compatibility (e.g. Ollama)
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
|
@ -31,7 +33,7 @@ pub(crate) struct OpenAIRequest {
|
||||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||||
pub tools: Vec<OpenAITool>,
|
pub tools: Vec<OpenAITool>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_choice: Option<String>, // "auto", "none", or specific
|
pub tool_choice: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
|
|
@ -39,20 +41,21 @@ pub(crate) struct StreamOptions {
|
||||||
pub include_usage: bool,
|
pub include_usage: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// OpenAI メッセージ
|
/// OpenAI message
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub(crate) struct OpenAIMessage {
|
pub(crate) struct OpenAIMessage {
|
||||||
pub role: String,
|
pub role: String,
|
||||||
pub content: Option<OpenAIContent>, // Optional for assistant tool calls
|
pub content: Option<OpenAIContent>,
|
||||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||||
pub tool_calls: Vec<OpenAIToolCall>,
|
pub tool_calls: Vec<OpenAIToolCall>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_call_id: Option<String>, // For tool_result (role: tool)
|
pub tool_call_id: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub name: Option<String>, // Optional name
|
pub name: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// OpenAI コンテンツ
|
/// OpenAI content
|
||||||
|
#[allow(dead_code)]
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub(crate) enum OpenAIContent {
|
pub(crate) enum OpenAIContent {
|
||||||
|
|
@ -60,7 +63,7 @@ pub(crate) enum OpenAIContent {
|
||||||
Parts(Vec<OpenAIContentPart>),
|
Parts(Vec<OpenAIContentPart>),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// OpenAI コンテンツパーツ
|
/// OpenAI content part
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
|
|
@ -76,7 +79,7 @@ pub(crate) struct ImageUrl {
|
||||||
pub url: String,
|
pub url: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// OpenAI ツール定義
|
/// OpenAI tool definition
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub(crate) struct OpenAITool {
|
pub(crate) struct OpenAITool {
|
||||||
pub r#type: String,
|
pub r#type: String,
|
||||||
|
|
@ -91,7 +94,7 @@ pub(crate) struct OpenAIToolFunction {
|
||||||
pub parameters: Value,
|
pub parameters: Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// OpenAI ツール呼び出し(メッセージ内)
|
/// OpenAI tool call in message
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub(crate) struct OpenAIToolCall {
|
pub(crate) struct OpenAIToolCall {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
|
|
@ -106,10 +109,11 @@ pub(crate) struct OpenAIToolCallFunction {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAIScheme {
|
impl OpenAIScheme {
|
||||||
/// RequestからOpenAIのリクエストボディを構築
|
/// Build OpenAI request from Request
|
||||||
pub(crate) fn build_request(&self, model: &str, request: &Request) -> OpenAIRequest {
|
pub(crate) fn build_request(&self, model: &str, request: &Request) -> OpenAIRequest {
|
||||||
let mut messages = Vec::new();
|
let mut messages = Vec::new();
|
||||||
|
|
||||||
|
// Add system message if present
|
||||||
if let Some(system) = &request.system_prompt {
|
if let Some(system) = &request.system_prompt {
|
||||||
messages.push(OpenAIMessage {
|
messages.push(OpenAIMessage {
|
||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
|
|
@ -120,7 +124,8 @@ impl OpenAIScheme {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
messages.extend(request.messages.iter().map(|m| self.convert_message(m)));
|
// Convert items to messages
|
||||||
|
messages.extend(self.convert_items_to_messages(&request.items));
|
||||||
|
|
||||||
let tools = request.tools.iter().map(|t| self.convert_tool(t)).collect();
|
let tools = request.tools.iter().map(|t| self.convert_tool(t)).collect();
|
||||||
|
|
||||||
|
|
@ -143,106 +148,122 @@ impl OpenAIScheme {
|
||||||
}),
|
}),
|
||||||
messages,
|
messages,
|
||||||
tools,
|
tools,
|
||||||
tool_choice: None, // Default to auto if tools are present? Or let API decide (which is auto)
|
tool_choice: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_message(&self, message: &Message) -> OpenAIMessage {
|
/// Convert Open Responses Items to OpenAI Messages
|
||||||
match &message.content {
|
///
|
||||||
MessageContent::ToolResult {
|
/// OpenAI uses a message-based model where:
|
||||||
tool_use_id,
|
/// - User messages have role "user"
|
||||||
content,
|
/// - Assistant messages have role "assistant"
|
||||||
} => OpenAIMessage {
|
/// - Tool calls are within assistant messages as tool_calls array
|
||||||
role: "tool".to_string(),
|
/// - Tool results have role "tool" with tool_call_id
|
||||||
content: Some(OpenAIContent::Text(content.clone())),
|
fn convert_items_to_messages(&self, items: &[Item]) -> Vec<OpenAIMessage> {
|
||||||
tool_calls: vec![],
|
let mut messages = Vec::new();
|
||||||
tool_call_id: Some(tool_use_id.clone()),
|
let mut pending_tool_calls: Vec<OpenAIToolCall> = Vec::new();
|
||||||
name: None,
|
let mut pending_assistant_text: Option<String> = None;
|
||||||
},
|
|
||||||
MessageContent::Text(text) => {
|
for item in items {
|
||||||
let role = match message.role {
|
match item {
|
||||||
|
Item::Message { role, content, .. } => {
|
||||||
|
// Flush pending tool calls
|
||||||
|
self.flush_pending_assistant(
|
||||||
|
&mut messages,
|
||||||
|
&mut pending_tool_calls,
|
||||||
|
&mut pending_assistant_text,
|
||||||
|
);
|
||||||
|
|
||||||
|
let openai_role = match role {
|
||||||
Role::User => "user",
|
Role::User => "user",
|
||||||
Role::Assistant => "assistant",
|
Role::Assistant => "assistant",
|
||||||
|
Role::System => "system",
|
||||||
};
|
};
|
||||||
OpenAIMessage {
|
|
||||||
role: role.to_string(),
|
let text_content: String = content
|
||||||
content: Some(OpenAIContent::Text(text.clone())),
|
.iter()
|
||||||
|
.map(|p| p.as_text())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("");
|
||||||
|
|
||||||
|
messages.push(OpenAIMessage {
|
||||||
|
role: openai_role.to_string(),
|
||||||
|
content: Some(OpenAIContent::Text(text_content)),
|
||||||
tool_calls: vec![],
|
tool_calls: vec![],
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
name: None,
|
name: None,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
|
||||||
MessageContent::Parts(parts) => {
|
|
||||||
let role = match message.role {
|
|
||||||
Role::User => "user",
|
|
||||||
Role::Assistant => "assistant",
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut content_parts = Vec::new();
|
Item::FunctionCall {
|
||||||
let mut tool_calls = Vec::new();
|
call_id,
|
||||||
let mut is_tool_result = false;
|
name,
|
||||||
let mut tool_result_id = None;
|
arguments,
|
||||||
let mut tool_result_content = String::new();
|
..
|
||||||
|
} => {
|
||||||
for part in parts {
|
pending_tool_calls.push(OpenAIToolCall {
|
||||||
match part {
|
id: call_id.clone(),
|
||||||
ContentPart::Text { text } => {
|
|
||||||
content_parts.push(OpenAIContentPart::Text { text: text.clone() });
|
|
||||||
}
|
|
||||||
ContentPart::ToolUse { id, name, input } => {
|
|
||||||
tool_calls.push(OpenAIToolCall {
|
|
||||||
id: id.clone(),
|
|
||||||
r#type: "function".to_string(),
|
r#type: "function".to_string(),
|
||||||
function: OpenAIToolCallFunction {
|
function: OpenAIToolCallFunction {
|
||||||
name: name.clone(),
|
name: name.clone(),
|
||||||
arguments: input.to_string(),
|
arguments: arguments.clone(),
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
ContentPart::ToolResult {
|
|
||||||
tool_use_id,
|
|
||||||
content,
|
|
||||||
} => {
|
|
||||||
// OpenAI doesn't support mixed content with ToolResult in the same message easily if not careful
|
|
||||||
// But strictly speaking, a Message with ToolResult should be its own message with role "tool"
|
|
||||||
is_tool_result = true;
|
|
||||||
tool_result_id = Some(tool_use_id.clone());
|
|
||||||
tool_result_content = content.clone();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if is_tool_result {
|
Item::FunctionCallOutput { call_id, output, .. } => {
|
||||||
OpenAIMessage {
|
// Flush pending tool calls before tool result
|
||||||
|
self.flush_pending_assistant(
|
||||||
|
&mut messages,
|
||||||
|
&mut pending_tool_calls,
|
||||||
|
&mut pending_assistant_text,
|
||||||
|
);
|
||||||
|
|
||||||
|
messages.push(OpenAIMessage {
|
||||||
role: "tool".to_string(),
|
role: "tool".to_string(),
|
||||||
content: Some(OpenAIContent::Text(tool_result_content)),
|
content: Some(OpenAIContent::Text(output.clone())),
|
||||||
tool_calls: vec![],
|
tool_calls: vec![],
|
||||||
tool_call_id: tool_result_id,
|
tool_call_id: Some(call_id.clone()),
|
||||||
name: None,
|
name: None,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
let content = if content_parts.is_empty() {
|
|
||||||
None
|
|
||||||
} else if content_parts.len() == 1 {
|
|
||||||
// Simplify single text part to just Text content if preferred, or keep as Parts
|
|
||||||
if let OpenAIContentPart::Text { text } = &content_parts[0] {
|
|
||||||
Some(OpenAIContent::Text(text.clone()))
|
|
||||||
} else {
|
|
||||||
Some(OpenAIContent::Parts(content_parts))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Some(OpenAIContent::Parts(content_parts))
|
|
||||||
};
|
|
||||||
|
|
||||||
OpenAIMessage {
|
Item::Reasoning { text, .. } => {
|
||||||
role: role.to_string(),
|
// Reasoning is treated as assistant text in OpenAI
|
||||||
content,
|
// (OpenAI doesn't have native reasoning support like Claude)
|
||||||
tool_calls,
|
if let Some(ref mut existing) = pending_assistant_text {
|
||||||
|
existing.push_str(text);
|
||||||
|
} else {
|
||||||
|
pending_assistant_text = Some(text.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush remaining pending items
|
||||||
|
self.flush_pending_assistant(
|
||||||
|
&mut messages,
|
||||||
|
&mut pending_tool_calls,
|
||||||
|
&mut pending_assistant_text,
|
||||||
|
);
|
||||||
|
|
||||||
|
messages
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flush_pending_assistant(
|
||||||
|
&self,
|
||||||
|
messages: &mut Vec<OpenAIMessage>,
|
||||||
|
pending_tool_calls: &mut Vec<OpenAIToolCall>,
|
||||||
|
pending_assistant_text: &mut Option<String>,
|
||||||
|
) {
|
||||||
|
if !pending_tool_calls.is_empty() || pending_assistant_text.is_some() {
|
||||||
|
messages.push(OpenAIMessage {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: pending_assistant_text.take().map(OpenAIContent::Text),
|
||||||
|
tool_calls: std::mem::take(pending_tool_calls),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
name: None,
|
name: None,
|
||||||
}
|
});
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -274,7 +295,6 @@ mod tests {
|
||||||
assert_eq!(body.messages[0].role, "system");
|
assert_eq!(body.messages[0].role, "system");
|
||||||
assert_eq!(body.messages[1].role, "user");
|
assert_eq!(body.messages[1].role, "user");
|
||||||
|
|
||||||
// Check system content
|
|
||||||
if let Some(OpenAIContent::Text(text)) = &body.messages[0].content {
|
if let Some(OpenAIContent::Text(text)) = &body.messages[0].content {
|
||||||
assert_eq!(text, "System prompt");
|
assert_eq!(text, "System prompt");
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -301,20 +321,39 @@ mod tests {
|
||||||
|
|
||||||
let body = scheme.build_request("llama3", &request);
|
let body = scheme.build_request("llama3", &request);
|
||||||
|
|
||||||
// max_tokens should be set, max_completion_tokens should be None
|
|
||||||
assert_eq!(body.max_tokens, Some(100));
|
assert_eq!(body.max_tokens, Some(100));
|
||||||
assert!(body.max_completion_tokens.is_none());
|
assert!(body.max_completion_tokens.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_build_request_modern_max_tokens() {
|
fn test_build_request_modern_max_tokens() {
|
||||||
let scheme = OpenAIScheme::new(); // Default matches modern (legacy=false)
|
let scheme = OpenAIScheme::new();
|
||||||
let request = Request::new().user("Hello").max_tokens(100);
|
let request = Request::new().user("Hello").max_tokens(100);
|
||||||
|
|
||||||
let body = scheme.build_request("gpt-4o", &request);
|
let body = scheme.build_request("gpt-4o", &request);
|
||||||
|
|
||||||
// max_completion_tokens should be set, max_tokens should be None
|
|
||||||
assert_eq!(body.max_completion_tokens, Some(100));
|
assert_eq!(body.max_completion_tokens, Some(100));
|
||||||
assert!(body.max_tokens.is_none());
|
assert!(body.max_tokens.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_function_call_and_output() {
|
||||||
|
let scheme = OpenAIScheme::new();
|
||||||
|
let request = Request::new()
|
||||||
|
.user("Check weather")
|
||||||
|
.item(Item::function_call(
|
||||||
|
"call_123",
|
||||||
|
"get_weather",
|
||||||
|
r#"{"city":"Tokyo"}"#,
|
||||||
|
))
|
||||||
|
.item(Item::function_call_output("call_123", "Sunny, 25°C"));
|
||||||
|
|
||||||
|
let body = scheme.build_request("gpt-4o", &request);
|
||||||
|
|
||||||
|
assert_eq!(body.messages.len(), 3);
|
||||||
|
assert_eq!(body.messages[0].role, "user");
|
||||||
|
assert_eq!(body.messages[1].role, "assistant");
|
||||||
|
assert_eq!(body.messages[1].tool_calls.len(), 1);
|
||||||
|
assert_eq!(body.messages[2].role, "tool");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
494
llm-worker/src/llm_client/scheme/openresponses/events.rs
Normal file
494
llm-worker/src/llm_client/scheme/openresponses/events.rs
Normal file
|
|
@ -0,0 +1,494 @@
|
||||||
|
//! Open Responses Event Parser
|
||||||
|
//!
|
||||||
|
//! Parses SSE events from the Open Responses API into internal Event types.
|
||||||
|
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
use crate::llm_client::{
|
||||||
|
event::{
|
||||||
|
BlockMetadata, BlockStart, BlockStop, DeltaContent, ErrorEvent, Event, ResponseStatus,
|
||||||
|
StatusEvent, StopReason, UsageEvent,
|
||||||
|
},
|
||||||
|
ClientError,
|
||||||
|
};
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Open Responses SSE Event Types
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
/// Response created event
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ResponseCreatedEvent {
|
||||||
|
pub response: ResponseObject,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Response object
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ResponseObject {
|
||||||
|
pub id: String,
|
||||||
|
pub status: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub output: Vec<OutputItem>,
|
||||||
|
pub usage: Option<UsageObject>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Output item in response
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum OutputItem {
|
||||||
|
Message {
|
||||||
|
id: String,
|
||||||
|
role: String,
|
||||||
|
#[serde(default)]
|
||||||
|
content: Vec<ContentPartObject>,
|
||||||
|
},
|
||||||
|
FunctionCall {
|
||||||
|
id: String,
|
||||||
|
call_id: String,
|
||||||
|
name: String,
|
||||||
|
arguments: String,
|
||||||
|
},
|
||||||
|
Reasoning {
|
||||||
|
id: String,
|
||||||
|
#[serde(default)]
|
||||||
|
text: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Content part object
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum ContentPartObject {
|
||||||
|
OutputText { text: String },
|
||||||
|
InputText { text: String },
|
||||||
|
Refusal { refusal: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Usage object
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct UsageObject {
|
||||||
|
pub input_tokens: Option<u64>,
|
||||||
|
pub output_tokens: Option<u64>,
|
||||||
|
pub total_tokens: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Output item added event
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct OutputItemAddedEvent {
|
||||||
|
pub output_index: usize,
|
||||||
|
pub item: OutputItem,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Text delta event
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct TextDeltaEvent {
|
||||||
|
pub output_index: usize,
|
||||||
|
pub content_index: usize,
|
||||||
|
pub delta: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Text done event
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct TextDoneEvent {
|
||||||
|
pub output_index: usize,
|
||||||
|
pub content_index: usize,
|
||||||
|
pub text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Function call arguments delta event
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct FunctionCallArgumentsDeltaEvent {
|
||||||
|
pub output_index: usize,
|
||||||
|
pub call_id: String,
|
||||||
|
pub delta: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Function call arguments done event
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct FunctionCallArgumentsDoneEvent {
|
||||||
|
pub output_index: usize,
|
||||||
|
pub call_id: String,
|
||||||
|
pub arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reasoning delta event
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ReasoningDeltaEvent {
|
||||||
|
pub output_index: usize,
|
||||||
|
pub delta: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reasoning done event
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ReasoningDoneEvent {
|
||||||
|
pub output_index: usize,
|
||||||
|
pub text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Content part done event
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ContentPartDoneEvent {
|
||||||
|
pub output_index: usize,
|
||||||
|
pub content_index: usize,
|
||||||
|
pub part: ContentPartObject,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Output item done event
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct OutputItemDoneEvent {
|
||||||
|
pub output_index: usize,
|
||||||
|
pub item: OutputItem,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Response done event
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ResponseDoneEvent {
|
||||||
|
pub response: ResponseObject,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Error event from API
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ApiErrorEvent {
|
||||||
|
pub error: ApiError,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// API error details
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ApiError {
|
||||||
|
pub code: Option<String>,
|
||||||
|
pub message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Event Parsing
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
/// Parse SSE event into internal Event(s)
|
||||||
|
///
|
||||||
|
/// Returns `Ok(None)` for events that should be ignored (e.g., heartbeats)
|
||||||
|
/// Returns `Ok(Some(vec))` for events that produce one or more internal Events
|
||||||
|
pub fn parse_event(event_type: &str, data: &str) -> Result<Option<Vec<Event>>, ClientError> {
|
||||||
|
// Skip empty data
|
||||||
|
if data.is_empty() || data == "[DONE]" {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
let events = match event_type {
|
||||||
|
// Response lifecycle
|
||||||
|
"response.created" => {
|
||||||
|
let _event: ResponseCreatedEvent = parse_json(data)?;
|
||||||
|
Some(vec![Event::Status(StatusEvent {
|
||||||
|
status: ResponseStatus::Started,
|
||||||
|
})])
|
||||||
|
}
|
||||||
|
|
||||||
|
"response.in_progress" => {
|
||||||
|
// Just a status update, no action needed
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
"response.completed" | "response.done" => {
|
||||||
|
let event: ResponseDoneEvent = parse_json(data)?;
|
||||||
|
let mut events = Vec::new();
|
||||||
|
|
||||||
|
// Emit usage if present
|
||||||
|
if let Some(usage) = event.response.usage {
|
||||||
|
events.push(Event::Usage(UsageEvent {
|
||||||
|
input_tokens: usage.input_tokens,
|
||||||
|
output_tokens: usage.output_tokens,
|
||||||
|
total_tokens: usage.total_tokens,
|
||||||
|
cache_read_input_tokens: None,
|
||||||
|
cache_creation_input_tokens: None,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
events.push(Event::Status(StatusEvent {
|
||||||
|
status: ResponseStatus::Completed,
|
||||||
|
}));
|
||||||
|
Some(events)
|
||||||
|
}
|
||||||
|
|
||||||
|
"response.failed" => {
|
||||||
|
// Try to parse error
|
||||||
|
if let Ok(error_event) = parse_json::<ApiErrorEvent>(data) {
|
||||||
|
Some(vec![
|
||||||
|
Event::Error(ErrorEvent {
|
||||||
|
code: error_event.error.code,
|
||||||
|
message: error_event.error.message,
|
||||||
|
}),
|
||||||
|
Event::Status(StatusEvent {
|
||||||
|
status: ResponseStatus::Failed,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
} else {
|
||||||
|
Some(vec![Event::Status(StatusEvent {
|
||||||
|
status: ResponseStatus::Failed,
|
||||||
|
})])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output item events
|
||||||
|
"response.output_item.added" => {
|
||||||
|
let event: OutputItemAddedEvent = parse_json(data)?;
|
||||||
|
Some(vec![convert_item_added(&event)])
|
||||||
|
}
|
||||||
|
|
||||||
|
"response.output_item.done" => {
|
||||||
|
let event: OutputItemDoneEvent = parse_json(data)?;
|
||||||
|
Some(vec![convert_item_done(&event)])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Text content events
|
||||||
|
"response.output_text.delta" => {
|
||||||
|
let event: TextDeltaEvent = parse_json(data)?;
|
||||||
|
Some(vec![Event::text_delta(event.output_index, &event.delta)])
|
||||||
|
}
|
||||||
|
|
||||||
|
"response.output_text.done" => {
|
||||||
|
// Text done - we'll handle stop in output_item.done
|
||||||
|
let _event: TextDoneEvent = parse_json(data)?;
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
// Content part events
|
||||||
|
"response.content_part.added" => {
|
||||||
|
// Content part added - we handle this via output_item.added
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
"response.content_part.done" => {
|
||||||
|
// Content part done - we handle stop in output_item.done
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function call events
|
||||||
|
"response.function_call_arguments.delta" => {
|
||||||
|
let event: FunctionCallArgumentsDeltaEvent = parse_json(data)?;
|
||||||
|
Some(vec![Event::BlockDelta(crate::llm_client::event::BlockDelta {
|
||||||
|
index: event.output_index,
|
||||||
|
delta: DeltaContent::InputJson(event.delta),
|
||||||
|
})])
|
||||||
|
}
|
||||||
|
|
||||||
|
"response.function_call_arguments.done" => {
|
||||||
|
// Arguments done - we handle stop in output_item.done
|
||||||
|
let _event: FunctionCallArgumentsDoneEvent = parse_json(data)?;
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reasoning events
|
||||||
|
"response.reasoning.delta" | "response.reasoning_summary_text.delta" => {
|
||||||
|
let event: ReasoningDeltaEvent = parse_json(data)?;
|
||||||
|
Some(vec![Event::BlockDelta(crate::llm_client::event::BlockDelta {
|
||||||
|
index: event.output_index,
|
||||||
|
delta: DeltaContent::Thinking(event.delta),
|
||||||
|
})])
|
||||||
|
}
|
||||||
|
|
||||||
|
"response.reasoning.done" | "response.reasoning_summary_text.done" => {
|
||||||
|
// Reasoning done - we handle stop in output_item.done
|
||||||
|
let _event: ReasoningDoneEvent = parse_json(data)?;
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error event
|
||||||
|
"error" => {
|
||||||
|
let event: ApiErrorEvent = parse_json(data)?;
|
||||||
|
Some(vec![Event::Error(ErrorEvent {
|
||||||
|
code: event.error.code,
|
||||||
|
message: event.error.message,
|
||||||
|
})])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unknown event type - ignore
|
||||||
|
_ => {
|
||||||
|
tracing::debug!(event_type = event_type, "Unknown Open Responses event type");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(events)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_json<T: serde::de::DeserializeOwned>(data: &str) -> Result<T, ClientError> {
|
||||||
|
serde_json::from_str(data).map_err(|e| ClientError::Parse(e.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_item_added(event: &OutputItemAddedEvent) -> Event {
|
||||||
|
match &event.item {
|
||||||
|
OutputItem::Message { id, role: _, content: _ } => Event::BlockStart(BlockStart {
|
||||||
|
index: event.output_index,
|
||||||
|
block_type: crate::llm_client::event::BlockType::Text,
|
||||||
|
metadata: BlockMetadata::Text,
|
||||||
|
}),
|
||||||
|
|
||||||
|
OutputItem::FunctionCall {
|
||||||
|
id,
|
||||||
|
call_id,
|
||||||
|
name,
|
||||||
|
arguments: _,
|
||||||
|
} => Event::BlockStart(BlockStart {
|
||||||
|
index: event.output_index,
|
||||||
|
block_type: crate::llm_client::event::BlockType::ToolUse,
|
||||||
|
metadata: BlockMetadata::ToolUse {
|
||||||
|
id: call_id.clone(),
|
||||||
|
name: name.clone(),
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
|
||||||
|
OutputItem::Reasoning { id, text: _ } => Event::BlockStart(BlockStart {
|
||||||
|
index: event.output_index,
|
||||||
|
block_type: crate::llm_client::event::BlockType::Thinking,
|
||||||
|
metadata: BlockMetadata::Thinking,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_item_done(event: &OutputItemDoneEvent) -> Event {
|
||||||
|
let stop_reason = match &event.item {
|
||||||
|
OutputItem::FunctionCall { .. } => Some(StopReason::ToolUse),
|
||||||
|
_ => Some(StopReason::EndTurn),
|
||||||
|
};
|
||||||
|
|
||||||
|
Event::BlockStop(BlockStop {
|
||||||
|
index: event.output_index,
|
||||||
|
stop_reason,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_response_created() {
|
||||||
|
let data = r#"{"response":{"id":"resp_123","status":"in_progress","output":[]}}"#;
|
||||||
|
let events = parse_event("response.created", data).unwrap().unwrap();
|
||||||
|
assert_eq!(events.len(), 1);
|
||||||
|
assert!(matches!(
|
||||||
|
events[0],
|
||||||
|
Event::Status(StatusEvent {
|
||||||
|
status: ResponseStatus::Started
|
||||||
|
})
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_text_delta() {
|
||||||
|
let data = r#"{"output_index":0,"content_index":0,"delta":"Hello"}"#;
|
||||||
|
let events = parse_event("response.output_text.delta", data)
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(events.len(), 1);
|
||||||
|
if let Event::BlockDelta(delta) = &events[0] {
|
||||||
|
assert_eq!(delta.index, 0);
|
||||||
|
assert!(matches!(&delta.delta, DeltaContent::Text(t) if t == "Hello"));
|
||||||
|
} else {
|
||||||
|
panic!("Expected BlockDelta");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_output_item_added_message() {
|
||||||
|
let data = r#"{"output_index":0,"item":{"type":"message","id":"msg_123","role":"assistant","content":[]}}"#;
|
||||||
|
let events = parse_event("response.output_item.added", data)
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(events.len(), 1);
|
||||||
|
if let Event::BlockStart(start) = &events[0] {
|
||||||
|
assert_eq!(start.index, 0);
|
||||||
|
assert!(matches!(
|
||||||
|
start.block_type,
|
||||||
|
crate::llm_client::event::BlockType::Text
|
||||||
|
));
|
||||||
|
} else {
|
||||||
|
panic!("Expected BlockStart");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_output_item_added_function_call() {
|
||||||
|
let data = r#"{"output_index":1,"item":{"type":"function_call","id":"fc_123","call_id":"call_456","name":"get_weather","arguments":""}}"#;
|
||||||
|
let events = parse_event("response.output_item.added", data)
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(events.len(), 1);
|
||||||
|
if let Event::BlockStart(start) = &events[0] {
|
||||||
|
assert_eq!(start.index, 1);
|
||||||
|
assert!(matches!(
|
||||||
|
start.block_type,
|
||||||
|
crate::llm_client::event::BlockType::ToolUse
|
||||||
|
));
|
||||||
|
if let BlockMetadata::ToolUse { id, name } = &start.metadata {
|
||||||
|
assert_eq!(id, "call_456");
|
||||||
|
assert_eq!(name, "get_weather");
|
||||||
|
} else {
|
||||||
|
panic!("Expected ToolUse metadata");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
panic!("Expected BlockStart");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_function_call_arguments_delta() {
|
||||||
|
let data = r#"{"output_index":1,"call_id":"call_456","delta":"{\"city\":"}"#;
|
||||||
|
let events = parse_event("response.function_call_arguments.delta", data)
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(events.len(), 1);
|
||||||
|
if let Event::BlockDelta(delta) = &events[0] {
|
||||||
|
assert_eq!(delta.index, 1);
|
||||||
|
assert!(matches!(
|
||||||
|
&delta.delta,
|
||||||
|
DeltaContent::InputJson(s) if s == "{\"city\":"
|
||||||
|
));
|
||||||
|
} else {
|
||||||
|
panic!("Expected BlockDelta");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_response_completed() {
|
||||||
|
let data = r#"{"response":{"id":"resp_123","status":"completed","output":[],"usage":{"input_tokens":10,"output_tokens":20,"total_tokens":30}}}"#;
|
||||||
|
let events = parse_event("response.completed", data).unwrap().unwrap();
|
||||||
|
assert_eq!(events.len(), 2);
|
||||||
|
|
||||||
|
// First event should be usage
|
||||||
|
if let Event::Usage(usage) = &events[0] {
|
||||||
|
assert_eq!(usage.input_tokens, Some(10));
|
||||||
|
assert_eq!(usage.output_tokens, Some(20));
|
||||||
|
assert_eq!(usage.total_tokens, Some(30));
|
||||||
|
} else {
|
||||||
|
panic!("Expected Usage event");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second event should be status
|
||||||
|
assert!(matches!(
|
||||||
|
events[1],
|
||||||
|
Event::Status(StatusEvent {
|
||||||
|
status: ResponseStatus::Completed
|
||||||
|
})
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_error() {
|
||||||
|
let data = r#"{"error":{"code":"rate_limit","message":"Too many requests"}}"#;
|
||||||
|
let events = parse_event("error", data).unwrap().unwrap();
|
||||||
|
assert_eq!(events.len(), 1);
|
||||||
|
if let Event::Error(err) = &events[0] {
|
||||||
|
assert_eq!(err.code, Some("rate_limit".to_string()));
|
||||||
|
assert_eq!(err.message, "Too many requests");
|
||||||
|
} else {
|
||||||
|
panic!("Expected Error event");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_unknown_event() {
|
||||||
|
let data = r#"{}"#;
|
||||||
|
let events = parse_event("some.unknown.event", data).unwrap();
|
||||||
|
assert!(events.is_none());
|
||||||
|
}
|
||||||
|
}
|
||||||
49
llm-worker/src/llm_client/scheme/openresponses/mod.rs
Normal file
49
llm-worker/src/llm_client/scheme/openresponses/mod.rs
Normal file
|
|
@ -0,0 +1,49 @@
|
||||||
|
//! Open Responses Scheme
|
||||||
|
//!
|
||||||
|
//! Handles request/response conversion for the Open Responses API.
|
||||||
|
//! Since our internal types are already Open Responses native, this scheme
|
||||||
|
//! primarily passes through data with minimal transformation.
|
||||||
|
|
||||||
|
mod events;
|
||||||
|
mod request;
|
||||||
|
|
||||||
|
use crate::llm_client::{ClientError, Request};
|
||||||
|
|
||||||
|
pub use events::*;
|
||||||
|
pub use request::*;
|
||||||
|
|
||||||
|
/// Open Responses Scheme
|
||||||
|
///
|
||||||
|
/// Handles conversion between internal types and the Open Responses wire format.
|
||||||
|
#[derive(Debug, Clone, Default)]
|
||||||
|
pub struct OpenResponsesScheme {
|
||||||
|
/// Optional model override
|
||||||
|
pub model: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenResponsesScheme {
|
||||||
|
/// Create a new OpenResponsesScheme
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the model
|
||||||
|
pub fn with_model(mut self, model: impl Into<String>) -> Self {
|
||||||
|
self.model = Some(model.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build Open Responses request from internal Request
|
||||||
|
pub fn build_request(&self, model: &str, request: &Request) -> OpenResponsesRequest {
|
||||||
|
build_request(model, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse SSE event data into internal Event(s)
|
||||||
|
pub fn parse_event(
|
||||||
|
&self,
|
||||||
|
event_type: &str,
|
||||||
|
data: &str,
|
||||||
|
) -> Result<Option<Vec<crate::llm_client::Event>>, ClientError> {
|
||||||
|
parse_event(event_type, data)
|
||||||
|
}
|
||||||
|
}
|
||||||
285
llm-worker/src/llm_client/scheme/openresponses/request.rs
Normal file
285
llm-worker/src/llm_client/scheme/openresponses/request.rs
Normal file
|
|
@ -0,0 +1,285 @@
|
||||||
|
//! Open Responses Request Builder
|
||||||
|
//!
|
||||||
|
//! Converts internal Request/Item types to Open Responses API format.
|
||||||
|
//! Since our internal types are already Open Responses native, this is
|
||||||
|
//! mostly a direct serialization with some field renaming.
|
||||||
|
|
||||||
|
use serde::Serialize;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use crate::llm_client::{types::Item, Request, ToolDefinition};
|
||||||
|
|
||||||
|
/// Open Responses API request body
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct OpenResponsesRequest {
|
||||||
|
/// Model identifier
|
||||||
|
pub model: String,
|
||||||
|
|
||||||
|
/// Input items (conversation history)
|
||||||
|
pub input: Vec<OpenResponsesItem>,
|
||||||
|
|
||||||
|
/// System instructions
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub instructions: Option<String>,
|
||||||
|
|
||||||
|
/// Tool definitions
|
||||||
|
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||||
|
pub tools: Vec<OpenResponsesTool>,
|
||||||
|
|
||||||
|
/// Enable streaming
|
||||||
|
pub stream: bool,
|
||||||
|
|
||||||
|
/// Maximum output tokens
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_output_tokens: Option<u32>,
|
||||||
|
|
||||||
|
/// Temperature
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
|
||||||
|
/// Top P (nucleus sampling)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Open Responses input item
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum OpenResponsesItem {
|
||||||
|
/// Message item
|
||||||
|
Message {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
id: Option<String>,
|
||||||
|
role: String,
|
||||||
|
content: Vec<OpenResponsesContentPart>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Function call item
|
||||||
|
FunctionCall {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
id: Option<String>,
|
||||||
|
call_id: String,
|
||||||
|
name: String,
|
||||||
|
arguments: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Function call output item
|
||||||
|
FunctionCallOutput {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
id: Option<String>,
|
||||||
|
call_id: String,
|
||||||
|
output: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Reasoning item
|
||||||
|
Reasoning {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
id: Option<String>,
|
||||||
|
text: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Open Responses content part
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum OpenResponsesContentPart {
|
||||||
|
/// Input text (for user messages)
|
||||||
|
InputText { text: String },
|
||||||
|
|
||||||
|
/// Output text (for assistant messages)
|
||||||
|
OutputText { text: String },
|
||||||
|
|
||||||
|
/// Refusal
|
||||||
|
Refusal { refusal: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Open Responses tool definition
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct OpenResponsesTool {
|
||||||
|
/// Tool type (always "function")
|
||||||
|
pub r#type: String,
|
||||||
|
|
||||||
|
/// Function definition
|
||||||
|
pub name: String,
|
||||||
|
|
||||||
|
/// Description
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub description: Option<String>,
|
||||||
|
|
||||||
|
/// Parameters schema
|
||||||
|
pub parameters: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build Open Responses request from internal Request
|
||||||
|
pub fn build_request(model: &str, request: &Request) -> OpenResponsesRequest {
|
||||||
|
let input = request.items.iter().map(convert_item).collect();
|
||||||
|
let tools = request.tools.iter().map(convert_tool).collect();
|
||||||
|
|
||||||
|
OpenResponsesRequest {
|
||||||
|
model: model.to_string(),
|
||||||
|
input,
|
||||||
|
instructions: request.system_prompt.clone(),
|
||||||
|
tools,
|
||||||
|
stream: true,
|
||||||
|
max_output_tokens: request.config.max_tokens,
|
||||||
|
temperature: request.config.temperature,
|
||||||
|
top_p: request.config.top_p,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_item(item: &Item) -> OpenResponsesItem {
|
||||||
|
match item {
|
||||||
|
Item::Message {
|
||||||
|
id,
|
||||||
|
role,
|
||||||
|
content,
|
||||||
|
status: _,
|
||||||
|
} => {
|
||||||
|
let role_str = match role {
|
||||||
|
crate::llm_client::types::Role::User => "user",
|
||||||
|
crate::llm_client::types::Role::Assistant => "assistant",
|
||||||
|
crate::llm_client::types::Role::System => "system",
|
||||||
|
};
|
||||||
|
|
||||||
|
let parts = content
|
||||||
|
.iter()
|
||||||
|
.map(|p| match p {
|
||||||
|
crate::llm_client::types::ContentPart::InputText { text } => {
|
||||||
|
OpenResponsesContentPart::InputText { text: text.clone() }
|
||||||
|
}
|
||||||
|
crate::llm_client::types::ContentPart::OutputText { text } => {
|
||||||
|
OpenResponsesContentPart::OutputText { text: text.clone() }
|
||||||
|
}
|
||||||
|
crate::llm_client::types::ContentPart::Refusal { refusal } => {
|
||||||
|
OpenResponsesContentPart::Refusal {
|
||||||
|
refusal: refusal.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
OpenResponsesItem::Message {
|
||||||
|
id: id.clone(),
|
||||||
|
role: role_str.to_string(),
|
||||||
|
content: parts,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Item::FunctionCall {
|
||||||
|
id,
|
||||||
|
call_id,
|
||||||
|
name,
|
||||||
|
arguments,
|
||||||
|
status: _,
|
||||||
|
} => OpenResponsesItem::FunctionCall {
|
||||||
|
id: id.clone(),
|
||||||
|
call_id: call_id.clone(),
|
||||||
|
name: name.clone(),
|
||||||
|
arguments: arguments.clone(),
|
||||||
|
},
|
||||||
|
|
||||||
|
Item::FunctionCallOutput {
|
||||||
|
id,
|
||||||
|
call_id,
|
||||||
|
output,
|
||||||
|
} => OpenResponsesItem::FunctionCallOutput {
|
||||||
|
id: id.clone(),
|
||||||
|
call_id: call_id.clone(),
|
||||||
|
output: output.clone(),
|
||||||
|
},
|
||||||
|
|
||||||
|
Item::Reasoning {
|
||||||
|
id,
|
||||||
|
text,
|
||||||
|
status: _,
|
||||||
|
} => OpenResponsesItem::Reasoning {
|
||||||
|
id: id.clone(),
|
||||||
|
text: text.clone(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_tool(tool: &ToolDefinition) -> OpenResponsesTool {
|
||||||
|
OpenResponsesTool {
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
name: tool.name.clone(),
|
||||||
|
description: tool.description.clone(),
|
||||||
|
parameters: tool.input_schema.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::llm_client::types::Item;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_simple_request() {
|
||||||
|
let request = Request::new()
|
||||||
|
.system("You are a helpful assistant.")
|
||||||
|
.user("Hello!");
|
||||||
|
|
||||||
|
let or_req = build_request("gpt-4o", &request);
|
||||||
|
|
||||||
|
assert_eq!(or_req.model, "gpt-4o");
|
||||||
|
assert_eq!(
|
||||||
|
or_req.instructions,
|
||||||
|
Some("You are a helpful assistant.".to_string())
|
||||||
|
);
|
||||||
|
assert_eq!(or_req.input.len(), 1);
|
||||||
|
assert!(or_req.stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_request_with_tool() {
|
||||||
|
let request = Request::new().user("What's the weather?").tool(
|
||||||
|
ToolDefinition::new("get_weather")
|
||||||
|
.description("Get current weather")
|
||||||
|
.input_schema(serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": { "type": "string" }
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
})),
|
||||||
|
);
|
||||||
|
|
||||||
|
let or_req = build_request("gpt-4o", &request);
|
||||||
|
|
||||||
|
assert_eq!(or_req.tools.len(), 1);
|
||||||
|
assert_eq!(or_req.tools[0].name, "get_weather");
|
||||||
|
assert_eq!(or_req.tools[0].r#type, "function");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_function_call_and_output() {
|
||||||
|
let request = Request::new()
|
||||||
|
.user("What's the weather?")
|
||||||
|
.item(Item::function_call(
|
||||||
|
"call_123",
|
||||||
|
"get_weather",
|
||||||
|
r#"{"city":"Tokyo"}"#,
|
||||||
|
))
|
||||||
|
.item(Item::function_call_output("call_123", "Sunny, 25°C"));
|
||||||
|
|
||||||
|
let or_req = build_request("gpt-4o", &request);
|
||||||
|
|
||||||
|
assert_eq!(or_req.input.len(), 3);
|
||||||
|
|
||||||
|
// Check function call
|
||||||
|
if let OpenResponsesItem::FunctionCall { call_id, name, .. } = &or_req.input[1] {
|
||||||
|
assert_eq!(call_id, "call_123");
|
||||||
|
assert_eq!(name, "get_weather");
|
||||||
|
} else {
|
||||||
|
panic!("Expected FunctionCall");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check function call output
|
||||||
|
if let OpenResponsesItem::FunctionCallOutput { call_id, output, .. } = &or_req.input[2] {
|
||||||
|
assert_eq!(call_id, "call_123");
|
||||||
|
assert_eq!(output, "Sunny, 25°C");
|
||||||
|
} else {
|
||||||
|
panic!("Expected FunctionCallOutput");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,189 +1,491 @@
|
||||||
//! LLMクライアント共通型定義
|
//! LLM Client Common Types - Open Responses Native
|
||||||
|
//!
|
||||||
|
//! This module defines types that are natively aligned with the Open Responses specification.
|
||||||
|
//! The core abstraction is `Item` which represents different types of conversation elements:
|
||||||
|
//! - Message items (user/assistant messages with content parts)
|
||||||
|
//! - FunctionCall items (tool invocations)
|
||||||
|
//! - FunctionCallOutput items (tool results)
|
||||||
|
//! - Reasoning items (extended thinking)
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// リクエスト構造体
|
// ============================================================================
|
||||||
|
// Item - The core unit of conversation
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/// Item ID type for tracking items in a conversation
|
||||||
|
pub type ItemId = String;
|
||||||
|
|
||||||
|
/// Call ID type for linking function calls to their outputs
|
||||||
|
pub type CallId = String;
|
||||||
|
|
||||||
|
/// Conversation item - the primary unit in Open Responses
|
||||||
|
///
|
||||||
|
/// Items represent discrete elements in a conversation. Unlike traditional
|
||||||
|
/// message-based APIs, Open Responses treats tool calls and reasoning as
|
||||||
|
/// first-class items rather than parts of messages.
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```ignore
|
||||||
|
/// use llm_worker::Item;
|
||||||
|
///
|
||||||
|
/// // User message
|
||||||
|
/// let user_item = Item::user_message("Hello!");
|
||||||
|
///
|
||||||
|
/// // Assistant message
|
||||||
|
/// let assistant_item = Item::assistant_message("Hi there!");
|
||||||
|
///
|
||||||
|
/// // Function call
|
||||||
|
/// let call = Item::function_call("call_123", "get_weather", json!({"city": "Tokyo"}));
|
||||||
|
///
|
||||||
|
/// // Function call output
|
||||||
|
/// let result = Item::function_call_output("call_123", "Sunny, 25°C");
|
||||||
|
/// ```
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum Item {
|
||||||
|
/// User or assistant message with content parts
|
||||||
|
Message {
|
||||||
|
/// Optional item ID
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
id: Option<ItemId>,
|
||||||
|
/// Message role
|
||||||
|
role: Role,
|
||||||
|
/// Content parts
|
||||||
|
content: Vec<ContentPart>,
|
||||||
|
/// Item status
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
status: Option<ItemStatus>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Function (tool) call from the assistant
|
||||||
|
FunctionCall {
|
||||||
|
/// Optional item ID
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
id: Option<ItemId>,
|
||||||
|
/// Call ID for linking to output
|
||||||
|
call_id: CallId,
|
||||||
|
/// Function name
|
||||||
|
name: String,
|
||||||
|
/// Function arguments as JSON string
|
||||||
|
arguments: String,
|
||||||
|
/// Item status
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
status: Option<ItemStatus>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Function (tool) call output/result
|
||||||
|
FunctionCallOutput {
|
||||||
|
/// Optional item ID
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
id: Option<ItemId>,
|
||||||
|
/// Call ID linking to the function call
|
||||||
|
call_id: CallId,
|
||||||
|
/// Output content
|
||||||
|
output: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Reasoning/thinking item
|
||||||
|
Reasoning {
|
||||||
|
/// Optional item ID
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
id: Option<ItemId>,
|
||||||
|
/// Reasoning text
|
||||||
|
text: String,
|
||||||
|
/// Item status
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
status: Option<ItemStatus>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Item {
|
||||||
|
// ========================================================================
|
||||||
|
// Message constructors
|
||||||
|
// ========================================================================
|
||||||
|
|
||||||
|
/// Create a user message item with text content
|
||||||
|
pub fn user_message(text: impl Into<String>) -> Self {
|
||||||
|
Self::Message {
|
||||||
|
id: None,
|
||||||
|
role: Role::User,
|
||||||
|
content: vec![ContentPart::InputText {
|
||||||
|
text: text.into(),
|
||||||
|
}],
|
||||||
|
status: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a user message item with multiple content parts
|
||||||
|
pub fn user_message_parts(parts: Vec<ContentPart>) -> Self {
|
||||||
|
Self::Message {
|
||||||
|
id: None,
|
||||||
|
role: Role::User,
|
||||||
|
content: parts,
|
||||||
|
status: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an assistant message item with text content
|
||||||
|
pub fn assistant_message(text: impl Into<String>) -> Self {
|
||||||
|
Self::Message {
|
||||||
|
id: None,
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: vec![ContentPart::OutputText {
|
||||||
|
text: text.into(),
|
||||||
|
}],
|
||||||
|
status: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an assistant message item with multiple content parts
|
||||||
|
pub fn assistant_message_parts(parts: Vec<ContentPart>) -> Self {
|
||||||
|
Self::Message {
|
||||||
|
id: None,
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: parts,
|
||||||
|
status: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========================================================================
|
||||||
|
// Function call constructors
|
||||||
|
// ========================================================================
|
||||||
|
|
||||||
|
/// Create a function call item
|
||||||
|
pub fn function_call(
|
||||||
|
call_id: impl Into<String>,
|
||||||
|
name: impl Into<String>,
|
||||||
|
arguments: impl Into<String>,
|
||||||
|
) -> Self {
|
||||||
|
Self::FunctionCall {
|
||||||
|
id: None,
|
||||||
|
call_id: call_id.into(),
|
||||||
|
name: name.into(),
|
||||||
|
arguments: arguments.into(),
|
||||||
|
status: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a function call item from a JSON value
|
||||||
|
pub fn function_call_json(
|
||||||
|
call_id: impl Into<String>,
|
||||||
|
name: impl Into<String>,
|
||||||
|
arguments: serde_json::Value,
|
||||||
|
) -> Self {
|
||||||
|
Self::function_call(call_id, name, arguments.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a function call output item
|
||||||
|
pub fn function_call_output(call_id: impl Into<String>, output: impl Into<String>) -> Self {
|
||||||
|
Self::FunctionCallOutput {
|
||||||
|
id: None,
|
||||||
|
call_id: call_id.into(),
|
||||||
|
output: output.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========================================================================
|
||||||
|
// Reasoning constructors
|
||||||
|
// ========================================================================
|
||||||
|
|
||||||
|
/// Create a reasoning item
|
||||||
|
pub fn reasoning(text: impl Into<String>) -> Self {
|
||||||
|
Self::Reasoning {
|
||||||
|
id: None,
|
||||||
|
text: text.into(),
|
||||||
|
status: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========================================================================
|
||||||
|
// Builder methods
|
||||||
|
// ========================================================================
|
||||||
|
|
||||||
|
/// Set the item ID
|
||||||
|
pub fn with_id(mut self, id: impl Into<String>) -> Self {
|
||||||
|
match &mut self {
|
||||||
|
Self::Message { id: item_id, .. } => *item_id = Some(id.into()),
|
||||||
|
Self::FunctionCall { id: item_id, .. } => *item_id = Some(id.into()),
|
||||||
|
Self::FunctionCallOutput { id: item_id, .. } => *item_id = Some(id.into()),
|
||||||
|
Self::Reasoning { id: item_id, .. } => *item_id = Some(id.into()),
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the item status
|
||||||
|
pub fn with_status(mut self, new_status: ItemStatus) -> Self {
|
||||||
|
match &mut self {
|
||||||
|
Self::Message { status, .. } => *status = Some(new_status),
|
||||||
|
Self::FunctionCall { status, .. } => *status = Some(new_status),
|
||||||
|
Self::FunctionCallOutput { .. } => {} // Output items don't have status
|
||||||
|
Self::Reasoning { status, .. } => *status = Some(new_status),
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========================================================================
|
||||||
|
// Accessors
|
||||||
|
// ========================================================================
|
||||||
|
|
||||||
|
/// Get the item ID if set
|
||||||
|
pub fn id(&self) -> Option<&str> {
|
||||||
|
match self {
|
||||||
|
Self::Message { id, .. } => id.as_deref(),
|
||||||
|
Self::FunctionCall { id, .. } => id.as_deref(),
|
||||||
|
Self::FunctionCallOutput { id, .. } => id.as_deref(),
|
||||||
|
Self::Reasoning { id, .. } => id.as_deref(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the item type as a string
|
||||||
|
pub fn item_type(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::Message { .. } => "message",
|
||||||
|
Self::FunctionCall { .. } => "function_call",
|
||||||
|
Self::FunctionCallOutput { .. } => "function_call_output",
|
||||||
|
Self::Reasoning { .. } => "reasoning",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if this is a user message
|
||||||
|
pub fn is_user_message(&self) -> bool {
|
||||||
|
matches!(self, Self::Message { role: Role::User, .. })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if this is an assistant message
|
||||||
|
pub fn is_assistant_message(&self) -> bool {
|
||||||
|
matches!(self, Self::Message { role: Role::Assistant, .. })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if this is a function call
|
||||||
|
pub fn is_function_call(&self) -> bool {
|
||||||
|
matches!(self, Self::FunctionCall { .. })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if this is a function call output
|
||||||
|
pub fn is_function_call_output(&self) -> bool {
|
||||||
|
matches!(self, Self::FunctionCallOutput { .. })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if this is a reasoning item
|
||||||
|
pub fn is_reasoning(&self) -> bool {
|
||||||
|
matches!(self, Self::Reasoning { .. })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get text content if this is a simple text message
|
||||||
|
pub fn as_text(&self) -> Option<&str> {
|
||||||
|
match self {
|
||||||
|
Self::Message { content, .. } if content.len() == 1 => match &content[0] {
|
||||||
|
ContentPart::InputText { text } => Some(text),
|
||||||
|
ContentPart::OutputText { text } => Some(text),
|
||||||
|
_ => None,
|
||||||
|
},
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Content Parts - Components within message items
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/// Content part within a message item
|
||||||
|
///
|
||||||
|
/// Open Responses distinguishes between input and output content types.
|
||||||
|
/// Input types are used in user messages, output types in assistant messages.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum ContentPart {
|
||||||
|
/// Input text (for user messages)
|
||||||
|
InputText {
|
||||||
|
/// The text content
|
||||||
|
text: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Output text (for assistant messages)
|
||||||
|
OutputText {
|
||||||
|
/// The text content
|
||||||
|
text: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Refusal content (for assistant messages)
|
||||||
|
Refusal {
|
||||||
|
/// The refusal message
|
||||||
|
refusal: String,
|
||||||
|
},
|
||||||
|
// Future: InputAudio, OutputAudio, etc.
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ContentPart {
|
||||||
|
/// Create an input text part
|
||||||
|
pub fn input_text(text: impl Into<String>) -> Self {
|
||||||
|
Self::InputText { text: text.into() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an output text part
|
||||||
|
pub fn output_text(text: impl Into<String>) -> Self {
|
||||||
|
Self::OutputText { text: text.into() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a refusal part
|
||||||
|
pub fn refusal(refusal: impl Into<String>) -> Self {
|
||||||
|
Self::Refusal {
|
||||||
|
refusal: refusal.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the text content regardless of type
|
||||||
|
pub fn as_text(&self) -> &str {
|
||||||
|
match self {
|
||||||
|
Self::InputText { text } => text,
|
||||||
|
Self::OutputText { text } => text,
|
||||||
|
Self::Refusal { refusal } => refusal,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Role and Status
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/// Message role
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum Role {
|
||||||
|
/// User
|
||||||
|
User,
|
||||||
|
/// Assistant
|
||||||
|
Assistant,
|
||||||
|
/// System (for system prompts, not typically used in items)
|
||||||
|
System,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Item status
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum ItemStatus {
|
||||||
|
/// Item is being generated
|
||||||
|
InProgress,
|
||||||
|
/// Item completed successfully
|
||||||
|
Completed,
|
||||||
|
/// Item was truncated (e.g., max tokens)
|
||||||
|
Incomplete,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Request Types
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/// LLM Request
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
pub struct Request {
|
pub struct Request {
|
||||||
/// システムプロンプト
|
/// System prompt (instructions)
|
||||||
pub system_prompt: Option<String>,
|
pub system_prompt: Option<String>,
|
||||||
/// メッセージ履歴
|
/// Input items (conversation history)
|
||||||
pub messages: Vec<Message>,
|
pub items: Vec<Item>,
|
||||||
/// ツール定義
|
/// Tool definitions
|
||||||
pub tools: Vec<ToolDefinition>,
|
pub tools: Vec<ToolDefinition>,
|
||||||
/// リクエスト設定
|
/// Request configuration
|
||||||
pub config: RequestConfig,
|
pub config: RequestConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Request {
|
impl Request {
|
||||||
/// 新しいリクエストを作成
|
/// Create a new empty request
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self::default()
|
Self::default()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// システムプロンプトを設定
|
/// Set the system prompt
|
||||||
pub fn system(mut self, prompt: impl Into<String>) -> Self {
|
pub fn system(mut self, prompt: impl Into<String>) -> Self {
|
||||||
self.system_prompt = Some(prompt.into());
|
self.system_prompt = Some(prompt.into());
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ユーザーメッセージを追加
|
/// Add a user message
|
||||||
pub fn user(mut self, content: impl Into<String>) -> Self {
|
pub fn user(mut self, content: impl Into<String>) -> Self {
|
||||||
self.messages.push(Message::user(content));
|
self.items.push(Item::user_message(content));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// アシスタントメッセージを追加
|
/// Add an assistant message
|
||||||
pub fn assistant(mut self, content: impl Into<String>) -> Self {
|
pub fn assistant(mut self, content: impl Into<String>) -> Self {
|
||||||
self.messages.push(Message::assistant(content));
|
self.items.push(Item::assistant_message(content));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// メッセージを追加
|
/// Add an item
|
||||||
pub fn message(mut self, message: Message) -> Self {
|
pub fn item(mut self, item: Item) -> Self {
|
||||||
self.messages.push(message);
|
self.items.push(item);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ツールを追加
|
/// Add multiple items
|
||||||
|
pub fn items(mut self, items: impl IntoIterator<Item = Item>) -> Self {
|
||||||
|
self.items.extend(items);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a tool definition
|
||||||
pub fn tool(mut self, tool: ToolDefinition) -> Self {
|
pub fn tool(mut self, tool: ToolDefinition) -> Self {
|
||||||
self.tools.push(tool);
|
self.tools.push(tool);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 設定を適用
|
/// Set the request config
|
||||||
pub fn config(mut self, config: RequestConfig) -> Self {
|
pub fn config(mut self, config: RequestConfig) -> Self {
|
||||||
self.config = config;
|
self.config = config;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// max_tokensを設定
|
/// Set max tokens
|
||||||
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
|
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
|
||||||
self.config.max_tokens = Some(max_tokens);
|
self.config.max_tokens = Some(max_tokens);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// temperatureを設定
|
/// Set temperature
|
||||||
pub fn temperature(mut self, temperature: f32) -> Self {
|
pub fn temperature(mut self, temperature: f32) -> Self {
|
||||||
self.config.temperature = Some(temperature);
|
self.config.temperature = Some(temperature);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// top_pを設定
|
/// Set top_p
|
||||||
pub fn top_p(mut self, top_p: f32) -> Self {
|
pub fn top_p(mut self, top_p: f32) -> Self {
|
||||||
self.config.top_p = Some(top_p);
|
self.config.top_p = Some(top_p);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// top_kを設定
|
/// Set top_k
|
||||||
pub fn top_k(mut self, top_k: u32) -> Self {
|
pub fn top_k(mut self, top_k: u32) -> Self {
|
||||||
self.config.top_k = Some(top_k);
|
self.config.top_k = Some(top_k);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ストップシーケンスを追加
|
/// Add a stop sequence
|
||||||
pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||||||
self.config.stop_sequences.push(sequence.into());
|
self.config.stop_sequences.push(sequence.into());
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// メッセージ
|
// ============================================================================
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
// Tool Definition
|
||||||
pub struct Message {
|
// ============================================================================
|
||||||
/// ロール
|
|
||||||
pub role: Role,
|
|
||||||
/// コンテンツ
|
|
||||||
pub content: MessageContent,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Message {
|
/// Tool (function) definition
|
||||||
/// ユーザーメッセージを作成
|
|
||||||
pub fn user(content: impl Into<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
role: Role::User,
|
|
||||||
content: MessageContent::Text(content.into()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// アシスタントメッセージを作成
|
|
||||||
pub fn assistant(content: impl Into<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
role: Role::Assistant,
|
|
||||||
content: MessageContent::Text(content.into()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// ツール結果メッセージを作成
|
|
||||||
pub fn tool_result(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
role: Role::User,
|
|
||||||
content: MessageContent::ToolResult {
|
|
||||||
tool_use_id: tool_use_id.into(),
|
|
||||||
content: content.into(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// ロール
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "lowercase")]
|
|
||||||
pub enum Role {
|
|
||||||
User,
|
|
||||||
Assistant,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// メッセージコンテンツ
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum MessageContent {
|
|
||||||
/// テキストコンテンツ
|
|
||||||
Text(String),
|
|
||||||
/// ツール結果
|
|
||||||
ToolResult {
|
|
||||||
tool_use_id: String,
|
|
||||||
content: String,
|
|
||||||
},
|
|
||||||
/// 複合コンテンツ (テキスト + ツール使用等)
|
|
||||||
Parts(Vec<ContentPart>),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// コンテンツパーツ
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
#[serde(tag = "type")]
|
|
||||||
pub enum ContentPart {
|
|
||||||
/// テキスト
|
|
||||||
#[serde(rename = "text")]
|
|
||||||
Text { text: String },
|
|
||||||
/// ツール使用
|
|
||||||
#[serde(rename = "tool_use")]
|
|
||||||
ToolUse {
|
|
||||||
id: String,
|
|
||||||
name: String,
|
|
||||||
input: serde_json::Value,
|
|
||||||
},
|
|
||||||
/// ツール結果
|
|
||||||
#[serde(rename = "tool_result")]
|
|
||||||
ToolResult {
|
|
||||||
tool_use_id: String,
|
|
||||||
content: String,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
/// ツール定義
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ToolDefinition {
|
pub struct ToolDefinition {
|
||||||
/// ツール名
|
/// Tool name
|
||||||
pub name: String,
|
pub name: String,
|
||||||
/// 説明
|
/// Tool description
|
||||||
pub description: Option<String>,
|
pub description: Option<String>,
|
||||||
/// 入力スキーマ (JSON Schema)
|
/// Input schema (JSON Schema)
|
||||||
pub input_schema: serde_json::Value,
|
pub input_schema: serde_json::Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolDefinition {
|
impl ToolDefinition {
|
||||||
/// 新しいツール定義を作成
|
/// Create a new tool definition
|
||||||
pub fn new(name: impl Into<String>) -> Self {
|
pub fn new(name: impl Into<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
name: name.into(),
|
name: name.into(),
|
||||||
|
|
@ -195,65 +497,69 @@ impl ToolDefinition {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 説明を設定
|
/// Set the description
|
||||||
pub fn description(mut self, desc: impl Into<String>) -> Self {
|
pub fn description(mut self, desc: impl Into<String>) -> Self {
|
||||||
self.description = Some(desc.into());
|
self.description = Some(desc.into());
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 入力スキーマを設定
|
/// Set the input schema
|
||||||
pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
|
pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
|
||||||
self.input_schema = schema;
|
self.input_schema = schema;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// リクエスト設定
|
// ============================================================================
|
||||||
|
// Request Config
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/// Request configuration
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
pub struct RequestConfig {
|
pub struct RequestConfig {
|
||||||
/// 最大トークン数
|
/// Maximum tokens to generate
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
/// Temperature
|
/// Temperature (randomness)
|
||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
/// Top P (nucleus sampling)
|
/// Top P (nucleus sampling)
|
||||||
pub top_p: Option<f32>,
|
pub top_p: Option<f32>,
|
||||||
/// Top K
|
/// Top K
|
||||||
pub top_k: Option<u32>,
|
pub top_k: Option<u32>,
|
||||||
/// ストップシーケンス
|
/// Stop sequences
|
||||||
pub stop_sequences: Vec<String>,
|
pub stop_sequences: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RequestConfig {
|
impl RequestConfig {
|
||||||
/// 新しいデフォルト設定を作成
|
/// Create a new default config
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self::default()
|
Self::default()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 最大トークン数を設定
|
/// Set max tokens
|
||||||
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||||
self.max_tokens = Some(max_tokens);
|
self.max_tokens = Some(max_tokens);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// temperatureを設定
|
/// Set temperature
|
||||||
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||||
self.temperature = Some(temperature);
|
self.temperature = Some(temperature);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// top_pを設定
|
/// Set top_p
|
||||||
pub fn with_top_p(mut self, top_p: f32) -> Self {
|
pub fn with_top_p(mut self, top_p: f32) -> Self {
|
||||||
self.top_p = Some(top_p);
|
self.top_p = Some(top_p);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// top_kを設定
|
/// Set top_k
|
||||||
pub fn with_top_k(mut self, top_k: u32) -> Self {
|
pub fn with_top_k(mut self, top_k: u32) -> Self {
|
||||||
self.top_k = Some(top_k);
|
self.top_k = Some(top_k);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ストップシーケンスを追加
|
/// Add a stop sequence
|
||||||
pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||||||
self.stop_sequences.push(sequence.into());
|
self.stop_sequences.push(sequence.into());
|
||||||
self
|
self
|
||||||
|
|
|
||||||
|
|
@ -1,116 +1,16 @@
|
||||||
//! メッセージ型
|
//! Message and Item Types
|
||||||
//!
|
//!
|
||||||
//! LLMとの会話で使用されるメッセージ構造。
|
//! This module provides the core types for representing conversation items
|
||||||
//! [`Message::user`]や[`Message::assistant`]で簡単に作成できます。
|
//! in the Open Responses format.
|
||||||
|
//!
|
||||||
|
//! The primary type is [`Item`], which represents different kinds of conversation
|
||||||
|
//! elements: messages, function calls, function call outputs, and reasoning.
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
// Re-export all types from llm_client::types
|
||||||
|
pub use crate::llm_client::types::{ContentPart, Item, Role};
|
||||||
|
|
||||||
/// メッセージのロール
|
/// Convenience alias for backward compatibility
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "lowercase")]
|
|
||||||
pub enum Role {
|
|
||||||
/// ユーザー
|
|
||||||
User,
|
|
||||||
/// アシスタント
|
|
||||||
Assistant,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 会話のメッセージ
|
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// In the Open Responses model, messages are just one type of Item.
|
||||||
///
|
/// This alias allows code that expects a "Message" type to continue working.
|
||||||
/// ```ignore
|
pub type Message = Item;
|
||||||
/// use llm_worker::Message;
|
|
||||||
///
|
|
||||||
/// // ユーザーメッセージ
|
|
||||||
/// let user_msg = Message::user("Hello!");
|
|
||||||
///
|
|
||||||
/// // アシスタントメッセージ
|
|
||||||
/// let assistant_msg = Message::assistant("Hi there!");
|
|
||||||
/// ```
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct Message {
|
|
||||||
/// ロール
|
|
||||||
pub role: Role,
|
|
||||||
/// コンテンツ
|
|
||||||
pub content: MessageContent,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// メッセージコンテンツ
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum MessageContent {
|
|
||||||
/// テキストコンテンツ
|
|
||||||
Text(String),
|
|
||||||
/// ツール結果
|
|
||||||
ToolResult {
|
|
||||||
tool_use_id: String,
|
|
||||||
content: String,
|
|
||||||
},
|
|
||||||
/// 複合コンテンツ (テキスト + ツール使用等)
|
|
||||||
Parts(Vec<ContentPart>),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// コンテンツパーツ
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
#[serde(tag = "type")]
|
|
||||||
pub enum ContentPart {
|
|
||||||
/// テキスト
|
|
||||||
#[serde(rename = "text")]
|
|
||||||
Text { text: String },
|
|
||||||
/// ツール使用
|
|
||||||
#[serde(rename = "tool_use")]
|
|
||||||
ToolUse {
|
|
||||||
id: String,
|
|
||||||
name: String,
|
|
||||||
input: serde_json::Value,
|
|
||||||
},
|
|
||||||
/// ツール結果
|
|
||||||
#[serde(rename = "tool_result")]
|
|
||||||
ToolResult {
|
|
||||||
tool_use_id: String,
|
|
||||||
content: String,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Message {
|
|
||||||
/// ユーザーメッセージを作成
|
|
||||||
///
|
|
||||||
/// # Examples
|
|
||||||
///
|
|
||||||
/// ```ignore
|
|
||||||
/// use llm_worker::Message;
|
|
||||||
/// let msg = Message::user("こんにちは");
|
|
||||||
/// ```
|
|
||||||
pub fn user(content: impl Into<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
role: Role::User,
|
|
||||||
content: MessageContent::Text(content.into()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// アシスタントメッセージを作成
|
|
||||||
///
|
|
||||||
/// 通常はWorker内部で自動生成されますが、
|
|
||||||
/// 履歴の初期化などで手動作成も可能です。
|
|
||||||
pub fn assistant(content: impl Into<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
role: Role::Assistant,
|
|
||||||
content: MessageContent::Text(content.into()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// ツール結果メッセージを作成
|
|
||||||
///
|
|
||||||
/// Worker内部でツール実行後に自動生成されます。
|
|
||||||
/// 通常は直接作成する必要はありません。
|
|
||||||
pub fn tool_result(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
role: Role::User,
|
|
||||||
content: MessageContent::ToolResult {
|
|
||||||
tool_use_id: tool_use_id.into(),
|
|
||||||
content: content.into(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -1,25 +1,25 @@
|
||||||
//! Worker状態
|
//! Worker State
|
||||||
//!
|
//!
|
||||||
//! Type-stateパターンによるキャッシュ保護のための状態マーカー型。
|
//! State marker types for cache protection using the Type-state pattern.
|
||||||
//! Workerは`Mutable` → `CacheLocked`の状態遷移を持ちます。
|
//! Worker has state transitions from `Mutable` → `CacheLocked`.
|
||||||
|
|
||||||
/// Worker状態を表すマーカートレイト
|
/// Marker trait representing Worker state
|
||||||
///
|
///
|
||||||
/// このトレイトはシールされており、外部から実装することはできません。
|
/// This trait is sealed and cannot be implemented externally.
|
||||||
pub trait WorkerState: private::Sealed + Send + Sync + 'static {}
|
pub trait WorkerState: private::Sealed + Send + Sync + 'static {}
|
||||||
|
|
||||||
mod private {
|
mod private {
|
||||||
pub trait Sealed {}
|
pub trait Sealed {}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 編集可能状態
|
/// Mutable state (editable)
|
||||||
///
|
///
|
||||||
/// この状態では以下の操作が可能です:
|
/// In this state, the following operations are available:
|
||||||
/// - システムプロンプトの設定・変更
|
/// - Setting/changing system prompt
|
||||||
/// - メッセージ履歴の編集(追加、削除、クリア)
|
/// - Editing message history (add, delete, clear)
|
||||||
/// - ツール・Hookの登録
|
/// - Registering tools and hooks
|
||||||
///
|
///
|
||||||
/// `Worker::lock()`により[`CacheLocked`]状態へ遷移できます。
|
/// Can transition to [`CacheLocked`] state via `Worker::lock()`.
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
|
|
@ -29,11 +29,11 @@ mod private {
|
||||||
/// let mut worker = Worker::new(client)
|
/// let mut worker = Worker::new(client)
|
||||||
/// .system_prompt("You are helpful.");
|
/// .system_prompt("You are helpful.");
|
||||||
///
|
///
|
||||||
/// // 履歴を編集可能
|
/// // History can be edited
|
||||||
/// worker.push_message(Message::user("Hello"));
|
/// worker.push_message(Message::user("Hello"));
|
||||||
/// worker.clear_history();
|
/// worker.clear_history();
|
||||||
///
|
///
|
||||||
/// // ロックして保護状態へ
|
/// // Lock to protected state
|
||||||
/// let locked = worker.lock();
|
/// let locked = worker.lock();
|
||||||
/// ```
|
/// ```
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
|
@ -42,17 +42,17 @@ pub struct Mutable;
|
||||||
impl private::Sealed for Mutable {}
|
impl private::Sealed for Mutable {}
|
||||||
impl WorkerState for Mutable {}
|
impl WorkerState for Mutable {}
|
||||||
|
|
||||||
/// キャッシュロック状態(キャッシュ保護)
|
/// Cache locked state (cache protected)
|
||||||
///
|
///
|
||||||
/// この状態では以下の制限があります:
|
/// In this state, the following restrictions apply:
|
||||||
/// - システムプロンプトの変更不可
|
/// - System prompt cannot be changed
|
||||||
/// - 既存メッセージ履歴の変更不可(末尾への追記のみ)
|
/// - Existing message history cannot be modified (only appending to the end)
|
||||||
///
|
///
|
||||||
/// LLM APIのKVキャッシュヒットを保証するため、
|
/// To ensure LLM API KV cache hits,
|
||||||
/// 実行時にはこの状態の使用が推奨されます。
|
/// using this state during execution is recommended.
|
||||||
///
|
///
|
||||||
/// `Worker::unlock()`により[`Mutable`]状態へ戻せますが、
|
/// Can return to [`Mutable`] state via `Worker::unlock()`,
|
||||||
/// キャッシュ保護が解除されることに注意してください。
|
/// but note that cache protection will be released.
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
pub struct CacheLocked;
|
pub struct CacheLocked;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
//! イベント購読
|
//! Event Subscription
|
||||||
//!
|
//!
|
||||||
//! LLMからのストリーミングイベントをリアルタイムで受信するためのトレイト。
|
//! Trait for receiving streaming events from LLM in real-time.
|
||||||
//! UIへのストリーム表示やプログレス表示に使用します。
|
//! Used for stream display to UI and progress display.
|
||||||
|
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
|
@ -18,17 +18,17 @@ use crate::{
|
||||||
// WorkerSubscriber Trait
|
// WorkerSubscriber Trait
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// LLMからのストリーミングイベントを購読するトレイト
|
/// Trait for subscribing to streaming events from LLM
|
||||||
///
|
///
|
||||||
/// Workerに登録すると、テキスト生成やツール呼び出しのイベントを
|
/// When registered with Worker, you can receive events from text generation
|
||||||
/// リアルタイムで受信できます。UIへのストリーム表示に最適です。
|
/// and tool calls in real-time. Ideal for stream display to UI.
|
||||||
///
|
///
|
||||||
/// # 受信できるイベント
|
/// # Available Events
|
||||||
///
|
///
|
||||||
/// - **ブロックイベント**: テキスト、ツール使用(スコープ付き)
|
/// - **Block events**: Text, tool use (with scope)
|
||||||
/// - **メタイベント**: 使用量、ステータス、エラー
|
/// - **Meta events**: Usage, status, error
|
||||||
/// - **完了イベント**: テキスト完了、ツール呼び出し完了
|
/// - **Completion events**: Text complete, tool call complete
|
||||||
/// - **ターン制御**: ターン開始、ターン終了
|
/// - **Turn control**: Turn start, turn end
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
|
|
@ -44,7 +44,7 @@ use crate::{
|
||||||
///
|
///
|
||||||
/// fn on_text_block(&mut self, _: &mut (), event: &TextBlockEvent) {
|
/// fn on_text_block(&mut self, _: &mut (), event: &TextBlockEvent) {
|
||||||
/// if let TextBlockEvent::Delta(text) = event {
|
/// if let TextBlockEvent::Delta(text) = event {
|
||||||
/// print!("{}", text); // リアルタイム出力
|
/// print!("{}", text); // Real-time output
|
||||||
/// }
|
/// }
|
||||||
/// }
|
/// }
|
||||||
///
|
///
|
||||||
|
|
@ -53,37 +53,37 @@ use crate::{
|
||||||
/// }
|
/// }
|
||||||
/// }
|
/// }
|
||||||
///
|
///
|
||||||
/// // Workerに登録
|
/// // Register with Worker
|
||||||
/// worker.subscribe(StreamPrinter);
|
/// worker.subscribe(StreamPrinter);
|
||||||
/// ```
|
/// ```
|
||||||
pub trait WorkerSubscriber: Send {
|
pub trait WorkerSubscriber: Send {
|
||||||
// =========================================================================
|
// =========================================================================
|
||||||
// スコープ型(ブロックイベント用)
|
// Scope Types (for block events)
|
||||||
// =========================================================================
|
// =========================================================================
|
||||||
|
|
||||||
/// テキストブロック処理用のスコープ型
|
/// Scope type for text block processing
|
||||||
///
|
///
|
||||||
/// ブロック開始時にDefault::default()で生成され、
|
/// Generated with Default::default() at block start,
|
||||||
/// ブロック終了時に破棄される。
|
/// destroyed at block end.
|
||||||
type TextBlockScope: Default + Send + Sync;
|
type TextBlockScope: Default + Send + Sync;
|
||||||
|
|
||||||
/// ツール使用ブロック処理用のスコープ型
|
/// Scope type for tool use block processing
|
||||||
type ToolUseBlockScope: Default + Send + Sync;
|
type ToolUseBlockScope: Default + Send + Sync;
|
||||||
|
|
||||||
// =========================================================================
|
// =========================================================================
|
||||||
// ブロックイベント(スコープ管理あり)
|
// Block Events (with scope management)
|
||||||
// =========================================================================
|
// =========================================================================
|
||||||
|
|
||||||
/// テキストブロックイベント
|
/// Text block event
|
||||||
///
|
///
|
||||||
/// Start/Delta/Stopのライフサイクルを持つ。
|
/// Has Start/Delta/Stop lifecycle.
|
||||||
/// scopeはブロック開始時に生成され、終了時に破棄される。
|
/// Scope is generated at block start and destroyed at end.
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
fn on_text_block(&mut self, scope: &mut Self::TextBlockScope, event: &TextBlockEvent) {}
|
fn on_text_block(&mut self, scope: &mut Self::TextBlockScope, event: &TextBlockEvent) {}
|
||||||
|
|
||||||
/// ツール使用ブロックイベント
|
/// Tool use block event
|
||||||
///
|
///
|
||||||
/// Start/InputJsonDelta/Stopのライフサイクルを持つ。
|
/// Has Start/InputJsonDelta/Stop lifecycle.
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
fn on_tool_use_block(
|
fn on_tool_use_block(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
|
@ -93,62 +93,62 @@ pub trait WorkerSubscriber: Send {
|
||||||
}
|
}
|
||||||
|
|
||||||
// =========================================================================
|
// =========================================================================
|
||||||
// 単発イベント(スコープ不要)
|
// Single Events (no scope needed)
|
||||||
// =========================================================================
|
// =========================================================================
|
||||||
|
|
||||||
/// 使用量イベント
|
/// Usage event
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
fn on_usage(&mut self, event: &UsageEvent) {}
|
fn on_usage(&mut self, event: &UsageEvent) {}
|
||||||
|
|
||||||
/// ステータスイベント
|
/// Status event
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
fn on_status(&mut self, event: &StatusEvent) {}
|
fn on_status(&mut self, event: &StatusEvent) {}
|
||||||
|
|
||||||
/// エラーイベント
|
/// Error event
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
fn on_error(&mut self, event: &ErrorEvent) {}
|
fn on_error(&mut self, event: &ErrorEvent) {}
|
||||||
|
|
||||||
// =========================================================================
|
// =========================================================================
|
||||||
// 累積イベント(Worker層で追加)
|
// Accumulated Events (added in Worker layer)
|
||||||
// =========================================================================
|
// =========================================================================
|
||||||
|
|
||||||
/// テキスト完了イベント
|
/// Text complete event
|
||||||
///
|
///
|
||||||
/// テキストブロックが完了した時点で、累積されたテキスト全体が渡される。
|
/// When a text block completes, the entire accumulated text is passed.
|
||||||
/// ブロック処理後の最終結果を受け取るのに便利。
|
/// Convenient for receiving the final result after block processing.
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
fn on_text_complete(&mut self, text: &str) {}
|
fn on_text_complete(&mut self, text: &str) {}
|
||||||
|
|
||||||
/// ツール呼び出し完了イベント
|
/// Tool call complete event
|
||||||
///
|
///
|
||||||
/// ツール使用ブロックが完了した時点で、完全なToolCallが渡される。
|
/// When a tool use block completes, the complete ToolCall is passed.
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
fn on_tool_call_complete(&mut self, call: &ToolCall) {}
|
fn on_tool_call_complete(&mut self, call: &ToolCall) {}
|
||||||
|
|
||||||
// =========================================================================
|
// =========================================================================
|
||||||
// ターン制御
|
// Turn Control
|
||||||
// =========================================================================
|
// =========================================================================
|
||||||
|
|
||||||
/// ターン開始時
|
/// On turn start
|
||||||
///
|
///
|
||||||
/// `turn`は0から始まるターン番号。
|
/// `turn` is a 0-based turn number.
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
fn on_turn_start(&mut self, turn: usize) {}
|
fn on_turn_start(&mut self, turn: usize) {}
|
||||||
|
|
||||||
/// ターン終了時
|
/// On turn end
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
fn on_turn_end(&mut self, turn: usize) {}
|
fn on_turn_end(&mut self, turn: usize) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// SubscriberAdapter - WorkerSubscriberをTimelineハンドラにブリッジ
|
// SubscriberAdapter - Bridge WorkerSubscriber to Timeline handlers
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// TextBlock Handler Adapter
|
// TextBlock Handler Adapter
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// TextBlockKind用のSubscriberアダプター
|
/// Subscriber adapter for TextBlockKind
|
||||||
pub(crate) struct TextBlockSubscriberAdapter<S: WorkerSubscriber> {
|
pub(crate) struct TextBlockSubscriberAdapter<S: WorkerSubscriber> {
|
||||||
subscriber: Arc<Mutex<S>>,
|
subscriber: Arc<Mutex<S>>,
|
||||||
}
|
}
|
||||||
|
|
@ -167,10 +167,10 @@ impl<S: WorkerSubscriber> Clone for TextBlockSubscriberAdapter<S> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// TextBlockのスコープをラップ
|
/// Wrapper for TextBlock scope
|
||||||
pub struct TextBlockScopeWrapper<S: WorkerSubscriber> {
|
pub struct TextBlockScopeWrapper<S: WorkerSubscriber> {
|
||||||
inner: S::TextBlockScope,
|
inner: S::TextBlockScope,
|
||||||
buffer: String, // on_text_complete用のバッファ
|
buffer: String, // Buffer for on_text_complete
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: WorkerSubscriber> Default for TextBlockScopeWrapper<S> {
|
impl<S: WorkerSubscriber> Default for TextBlockScopeWrapper<S> {
|
||||||
|
|
@ -186,16 +186,16 @@ impl<S: WorkerSubscriber + 'static> Handler<TextBlockKind> for TextBlockSubscrib
|
||||||
type Scope = TextBlockScopeWrapper<S>;
|
type Scope = TextBlockScopeWrapper<S>;
|
||||||
|
|
||||||
fn on_event(&mut self, scope: &mut Self::Scope, event: &TextBlockEvent) {
|
fn on_event(&mut self, scope: &mut Self::Scope, event: &TextBlockEvent) {
|
||||||
// Deltaの場合はバッファに蓄積
|
// Accumulate deltas into buffer
|
||||||
if let TextBlockEvent::Delta(text) = event {
|
if let TextBlockEvent::Delta(text) = event {
|
||||||
scope.buffer.push_str(text);
|
scope.buffer.push_str(text);
|
||||||
}
|
}
|
||||||
|
|
||||||
// SubscriberのTextBlockイベントハンドラを呼び出し
|
// Call Subscriber's TextBlock event handler
|
||||||
if let Ok(mut subscriber) = self.subscriber.lock() {
|
if let Ok(mut subscriber) = self.subscriber.lock() {
|
||||||
subscriber.on_text_block(&mut scope.inner, event);
|
subscriber.on_text_block(&mut scope.inner, event);
|
||||||
|
|
||||||
// Stopの場合はon_text_completeも呼び出し
|
// Also call on_text_complete on Stop
|
||||||
if matches!(event, TextBlockEvent::Stop(_)) {
|
if matches!(event, TextBlockEvent::Stop(_)) {
|
||||||
subscriber.on_text_complete(&scope.buffer);
|
subscriber.on_text_complete(&scope.buffer);
|
||||||
}
|
}
|
||||||
|
|
@ -207,7 +207,7 @@ impl<S: WorkerSubscriber + 'static> Handler<TextBlockKind> for TextBlockSubscrib
|
||||||
// ToolUseBlock Handler Adapter
|
// ToolUseBlock Handler Adapter
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// ToolUseBlockKind用のSubscriberアダプター
|
/// Subscriber adapter for ToolUseBlockKind
|
||||||
pub(crate) struct ToolUseBlockSubscriberAdapter<S: WorkerSubscriber> {
|
pub(crate) struct ToolUseBlockSubscriberAdapter<S: WorkerSubscriber> {
|
||||||
subscriber: Arc<Mutex<S>>,
|
subscriber: Arc<Mutex<S>>,
|
||||||
}
|
}
|
||||||
|
|
@ -226,12 +226,12 @@ impl<S: WorkerSubscriber> Clone for ToolUseBlockSubscriberAdapter<S> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ToolUseBlockのスコープをラップ
|
/// Wrapper for ToolUseBlock scope
|
||||||
pub struct ToolUseBlockScopeWrapper<S: WorkerSubscriber> {
|
pub struct ToolUseBlockScopeWrapper<S: WorkerSubscriber> {
|
||||||
inner: S::ToolUseBlockScope,
|
inner: S::ToolUseBlockScope,
|
||||||
id: String,
|
id: String,
|
||||||
name: String,
|
name: String,
|
||||||
input_json: String, // JSON蓄積用
|
input_json: String, // JSON accumulation
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: WorkerSubscriber> Default for ToolUseBlockScopeWrapper<S> {
|
impl<S: WorkerSubscriber> Default for ToolUseBlockScopeWrapper<S> {
|
||||||
|
|
@ -249,22 +249,22 @@ impl<S: WorkerSubscriber + 'static> Handler<ToolUseBlockKind> for ToolUseBlockSu
|
||||||
type Scope = ToolUseBlockScopeWrapper<S>;
|
type Scope = ToolUseBlockScopeWrapper<S>;
|
||||||
|
|
||||||
fn on_event(&mut self, scope: &mut Self::Scope, event: &ToolUseBlockEvent) {
|
fn on_event(&mut self, scope: &mut Self::Scope, event: &ToolUseBlockEvent) {
|
||||||
// Start時にメタデータを保存
|
// Save metadata on Start
|
||||||
if let ToolUseBlockEvent::Start(start) = event {
|
if let ToolUseBlockEvent::Start(start) = event {
|
||||||
scope.id = start.id.clone();
|
scope.id = start.id.clone();
|
||||||
scope.name = start.name.clone();
|
scope.name = start.name.clone();
|
||||||
}
|
}
|
||||||
|
|
||||||
// InputJsonDeltaの場合はバッファに蓄積
|
// Accumulate InputJsonDelta into buffer
|
||||||
if let ToolUseBlockEvent::InputJsonDelta(json) = event {
|
if let ToolUseBlockEvent::InputJsonDelta(json) = event {
|
||||||
scope.input_json.push_str(json);
|
scope.input_json.push_str(json);
|
||||||
}
|
}
|
||||||
|
|
||||||
// SubscriberのToolUseBlockイベントハンドラを呼び出し
|
// Call Subscriber's ToolUseBlock event handler
|
||||||
if let Ok(mut subscriber) = self.subscriber.lock() {
|
if let Ok(mut subscriber) = self.subscriber.lock() {
|
||||||
subscriber.on_tool_use_block(&mut scope.inner, event);
|
subscriber.on_tool_use_block(&mut scope.inner, event);
|
||||||
|
|
||||||
// Stopの場合はon_tool_call_completeも呼び出し
|
// Also call on_tool_call_complete on Stop
|
||||||
if matches!(event, ToolUseBlockEvent::Stop(_)) {
|
if matches!(event, ToolUseBlockEvent::Stop(_)) {
|
||||||
let input: serde_json::Value =
|
let input: serde_json::Value =
|
||||||
serde_json::from_str(&scope.input_json).unwrap_or_default();
|
serde_json::from_str(&scope.input_json).unwrap_or_default();
|
||||||
|
|
@ -283,7 +283,7 @@ impl<S: WorkerSubscriber + 'static> Handler<ToolUseBlockKind> for ToolUseBlockSu
|
||||||
// Meta Event Handler Adapters
|
// Meta Event Handler Adapters
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// UsageKind用のSubscriberアダプター
|
/// Subscriber adapter for UsageKind
|
||||||
pub(crate) struct UsageSubscriberAdapter<S: WorkerSubscriber> {
|
pub(crate) struct UsageSubscriberAdapter<S: WorkerSubscriber> {
|
||||||
subscriber: Arc<Mutex<S>>,
|
subscriber: Arc<Mutex<S>>,
|
||||||
}
|
}
|
||||||
|
|
@ -312,7 +312,7 @@ impl<S: WorkerSubscriber + 'static> Handler<UsageKind> for UsageSubscriberAdapte
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// StatusKind用のSubscriberアダプター
|
/// Subscriber adapter for StatusKind
|
||||||
pub(crate) struct StatusSubscriberAdapter<S: WorkerSubscriber> {
|
pub(crate) struct StatusSubscriberAdapter<S: WorkerSubscriber> {
|
||||||
subscriber: Arc<Mutex<S>>,
|
subscriber: Arc<Mutex<S>>,
|
||||||
}
|
}
|
||||||
|
|
@ -341,7 +341,7 @@ impl<S: WorkerSubscriber + 'static> Handler<StatusKind> for StatusSubscriberAdap
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ErrorKind用のSubscriberアダプター
|
/// Subscriber adapter for ErrorKind
|
||||||
pub(crate) struct ErrorSubscriberAdapter<S: WorkerSubscriber> {
|
pub(crate) struct ErrorSubscriberAdapter<S: WorkerSubscriber> {
|
||||||
subscriber: Arc<Mutex<S>>,
|
subscriber: Arc<Mutex<S>>,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
//! ツール定義
|
//! Tool Definition
|
||||||
//!
|
//!
|
||||||
//! LLMから呼び出し可能なツールを定義するためのトレイト。
|
//! Traits for defining tools callable by LLM.
|
||||||
//! 通常は`#[tool]`マクロを使用して自動実装します。
|
//! Usually auto-implemented using the `#[tool]` macro.
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
|
@ -9,40 +9,40 @@ use async_trait::async_trait;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
/// ツール実行時のエラー
|
/// Error during tool execution
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum ToolError {
|
pub enum ToolError {
|
||||||
/// 引数が不正
|
/// Invalid argument
|
||||||
#[error("Invalid argument: {0}")]
|
#[error("Invalid argument: {0}")]
|
||||||
InvalidArgument(String),
|
InvalidArgument(String),
|
||||||
/// 実行に失敗
|
/// Execution failed
|
||||||
#[error("Execution failed: {0}")]
|
#[error("Execution failed: {0}")]
|
||||||
ExecutionFailed(String),
|
ExecutionFailed(String),
|
||||||
/// 内部エラー
|
/// Internal error
|
||||||
#[error("Internal error: {0}")]
|
#[error("Internal error: {0}")]
|
||||||
Internal(String),
|
Internal(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// ToolMeta - 不変のメタ情報
|
// ToolMeta - Immutable Meta Information
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// ツールのメタ情報(登録時に固定、不変)
|
/// Tool meta information (fixed at registration, immutable)
|
||||||
///
|
///
|
||||||
/// `ToolDefinition` ファクトリから生成され、Worker に登録後は変更されません。
|
/// Generated from `ToolDefinition` factory and does not change after registration with Worker.
|
||||||
/// LLM へのツール定義送信に使用されます。
|
/// Used for sending tool definitions to LLM.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct ToolMeta {
|
pub struct ToolMeta {
|
||||||
/// ツール名(LLMが識別に使用)
|
/// Tool name (used by LLM for identification)
|
||||||
pub name: String,
|
pub name: String,
|
||||||
/// ツールの説明(LLMへのプロンプトに含まれる)
|
/// Tool description (included in prompt to LLM)
|
||||||
pub description: String,
|
pub description: String,
|
||||||
/// 引数のJSON Schema
|
/// JSON Schema for arguments
|
||||||
pub input_schema: Value,
|
pub input_schema: Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolMeta {
|
impl ToolMeta {
|
||||||
/// 新しい ToolMeta を作成
|
/// Create a new ToolMeta
|
||||||
pub fn new(name: impl Into<String>) -> Self {
|
pub fn new(name: impl Into<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
name: name.into(),
|
name: name.into(),
|
||||||
|
|
@ -51,13 +51,13 @@ impl ToolMeta {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 説明を設定
|
/// Set the description
|
||||||
pub fn description(mut self, desc: impl Into<String>) -> Self {
|
pub fn description(mut self, desc: impl Into<String>) -> Self {
|
||||||
self.description = desc.into();
|
self.description = desc.into();
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 引数スキーマを設定
|
/// Set the argument schema
|
||||||
pub fn input_schema(mut self, schema: Value) -> Self {
|
pub fn input_schema(mut self, schema: Value) -> Self {
|
||||||
self.input_schema = schema;
|
self.input_schema = schema;
|
||||||
self
|
self
|
||||||
|
|
@ -65,14 +65,14 @@ impl ToolMeta {
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// ToolDefinition - ファクトリ型
|
// ToolDefinition - Factory Type
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// ツール定義ファクトリ
|
/// Tool definition factory
|
||||||
///
|
///
|
||||||
/// 呼び出すと `(ToolMeta, Arc<dyn Tool>)` を返します。
|
/// When called, returns `(ToolMeta, Arc<dyn Tool>)`.
|
||||||
/// Worker への登録時に一度だけ呼び出され、メタ情報とインスタンスが
|
/// Called once during Worker registration, and the meta information and instance
|
||||||
/// セッションスコープでキャッシュされます。
|
/// are cached at session scope.
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
|
|
@ -93,15 +93,15 @@ pub type ToolDefinition = Arc<dyn Fn() -> (ToolMeta, Arc<dyn Tool>) + Send + Syn
|
||||||
// Tool trait
|
// Tool trait
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// LLMから呼び出し可能なツールを定義するトレイト
|
/// Trait for defining tools callable by LLM
|
||||||
///
|
///
|
||||||
/// ツールはLLMが外部リソースにアクセスしたり、
|
/// Tools are used by LLM to access external resources
|
||||||
/// 計算を実行したりするために使用します。
|
/// or execute computations.
|
||||||
/// セッション中の状態を保持できます。
|
/// Can maintain state during the session.
|
||||||
///
|
///
|
||||||
/// # 実装方法
|
/// # How to Implement
|
||||||
///
|
///
|
||||||
/// 通常は`#[tool_registry]`マクロを使用して自動実装します:
|
/// Usually auto-implemented using the `#[tool_registry]` macro:
|
||||||
///
|
///
|
||||||
/// ```ignore
|
/// ```ignore
|
||||||
/// #[tool_registry]
|
/// #[tool_registry]
|
||||||
|
|
@ -112,11 +112,11 @@ pub type ToolDefinition = Arc<dyn Fn() -> (ToolMeta, Arc<dyn Tool>) + Send + Syn
|
||||||
/// }
|
/// }
|
||||||
/// }
|
/// }
|
||||||
///
|
///
|
||||||
/// // 登録
|
/// // Register
|
||||||
/// worker.register_tool(app.search_definition())?;
|
/// worker.register_tool(app.search_definition())?;
|
||||||
/// ```
|
/// ```
|
||||||
///
|
///
|
||||||
/// # 手動実装
|
/// # Manual Implementation
|
||||||
///
|
///
|
||||||
/// ```ignore
|
/// ```ignore
|
||||||
/// use llm_worker::tool::{Tool, ToolError, ToolMeta, ToolDefinition};
|
/// use llm_worker::tool::{Tool, ToolError, ToolMeta, ToolDefinition};
|
||||||
|
|
@ -143,12 +143,12 @@ pub type ToolDefinition = Arc<dyn Fn() -> (ToolMeta, Arc<dyn Tool>) + Send + Syn
|
||||||
/// ```
|
/// ```
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Tool: Send + Sync {
|
pub trait Tool: Send + Sync {
|
||||||
/// ツールを実行する
|
/// Execute the tool
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `input_json` - LLMが生成したJSON形式の引数
|
/// * `input_json` - JSON-formatted arguments generated by LLM
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// 実行結果の文字列。この内容がLLMに返されます。
|
/// Result string from execution. This content is returned to LLM.
|
||||||
async fn execute(&self, input_json: &str) -> Result<String, ToolError>;
|
async fn execute(&self, input_json: &str) -> Result<String, ToolError>;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
182
llm-worker/src/tool_server.rs
Normal file
182
llm-worker/src/tool_server.rs
Normal file
|
|
@ -0,0 +1,182 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
use crate::llm_client::ToolDefinition as LlmToolDefinition;
|
||||||
|
use crate::tool::{Tool, ToolDefinition as WorkerToolDefinition, ToolMeta};
|
||||||
|
|
||||||
|
type ToolMap = HashMap<String, (ToolMeta, Arc<dyn Tool>)>;
|
||||||
|
|
||||||
|
/// Errors produced by ToolServer operations.
|
||||||
|
#[derive(Debug, Error, PartialEq, Eq)]
|
||||||
|
pub enum ToolServerError {
|
||||||
|
/// A tool with the same name already exists.
|
||||||
|
#[error("Tool with name '{0}' already registered")]
|
||||||
|
DuplicateName(String),
|
||||||
|
/// Requested tool was not found.
|
||||||
|
#[error("Tool '{0}' not found")]
|
||||||
|
ToolNotFound(String),
|
||||||
|
/// Tool execution failed.
|
||||||
|
#[error("Tool execution failed: {0}")]
|
||||||
|
ToolExecution(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// In-memory tool server.
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
pub struct ToolServer {
|
||||||
|
tools: Arc<Mutex<ToolMap>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolServer {
|
||||||
|
/// Create a new empty tool server.
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a handle for shared access.
|
||||||
|
pub fn handle(&self) -> ToolServerHandle {
|
||||||
|
ToolServerHandle {
|
||||||
|
tools: Arc::clone(&self.tools),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shareable handle to a tool server.
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
pub struct ToolServerHandle {
|
||||||
|
tools: Arc<Mutex<ToolMap>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolServerHandle {
|
||||||
|
/// Register one tool.
|
||||||
|
pub(crate) fn register_tool(
|
||||||
|
&self,
|
||||||
|
factory: WorkerToolDefinition,
|
||||||
|
) -> Result<(), ToolServerError> {
|
||||||
|
let (meta, instance) = factory();
|
||||||
|
let mut guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
|
||||||
|
if guard.contains_key(&meta.name) {
|
||||||
|
return Err(ToolServerError::DuplicateName(meta.name));
|
||||||
|
}
|
||||||
|
guard.insert(meta.name.clone(), (meta, instance));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register many tools.
|
||||||
|
pub(crate) fn register_tools(
|
||||||
|
&self,
|
||||||
|
factories: impl IntoIterator<Item = WorkerToolDefinition>,
|
||||||
|
) -> Result<(), ToolServerError> {
|
||||||
|
for factory in factories {
|
||||||
|
self.register_tool(factory)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a tool by name for hook contexts.
|
||||||
|
pub fn get_tool(&self, name: &str) -> Option<(ToolMeta, Arc<dyn Tool>)> {
|
||||||
|
let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
|
||||||
|
guard.get(name).map(|(meta, tool)| (meta.clone(), Arc::clone(tool)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute a tool by name.
|
||||||
|
pub async fn call_tool(&self, name: &str, input_json: &str) -> Result<String, ToolServerError> {
|
||||||
|
let tool = {
|
||||||
|
let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
|
||||||
|
let (_, tool) = guard
|
||||||
|
.get(name)
|
||||||
|
.ok_or_else(|| ToolServerError::ToolNotFound(name.to_string()))?;
|
||||||
|
Arc::clone(tool)
|
||||||
|
};
|
||||||
|
tool.execute(input_json)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ToolServerError::ToolExecution(e.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build deterministic tool definitions sorted by tool name.
|
||||||
|
pub fn tool_definitions_sorted(&self) -> Vec<LlmToolDefinition> {
|
||||||
|
let guard = self.tools.lock().unwrap_or_else(|e| e.into_inner());
|
||||||
|
let mut defs: Vec<_> = guard
|
||||||
|
.values()
|
||||||
|
.map(|(meta, _)| {
|
||||||
|
LlmToolDefinition::new(&meta.name)
|
||||||
|
.description(&meta.description)
|
||||||
|
.input_schema(meta.input_schema.clone())
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
defs.sort_by(|a, b| a.name.cmp(&b.name));
|
||||||
|
defs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use crate::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
||||||
|
|
||||||
|
struct EchoTool;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for EchoTool {
|
||||||
|
async fn execute(&self, input_json: &str) -> Result<String, ToolError> {
|
||||||
|
Ok(input_json.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn def(name: &'static str) -> ToolDefinition {
|
||||||
|
Arc::new(move || {
|
||||||
|
(
|
||||||
|
ToolMeta::new(name)
|
||||||
|
.description(format!("desc-{name}"))
|
||||||
|
.input_schema(json!({"type":"object"})),
|
||||||
|
Arc::new(EchoTool) as Arc<dyn Tool>,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn register_duplicate_name_fails() {
|
||||||
|
let handle = ToolServer::new().handle();
|
||||||
|
handle.register_tool(def("alpha")).expect("first register");
|
||||||
|
let err = handle
|
||||||
|
.register_tool(def("alpha"))
|
||||||
|
.expect_err("duplicate should fail");
|
||||||
|
assert_eq!(err, ToolServerError::DuplicateName("alpha".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn call_tool_success_and_not_found() {
|
||||||
|
let handle = ToolServer::new().handle();
|
||||||
|
handle.register_tool(def("echo")).expect("register");
|
||||||
|
|
||||||
|
let out = handle.call_tool("echo", r#"{"x":1}"#).await.expect("call");
|
||||||
|
assert_eq!(out, r#"{"x":1}"#);
|
||||||
|
|
||||||
|
let err = handle
|
||||||
|
.call_tool("missing", "{}")
|
||||||
|
.await
|
||||||
|
.expect_err("missing tool");
|
||||||
|
assert_eq!(err, ToolServerError::ToolNotFound("missing".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tool_definitions_are_sorted() {
|
||||||
|
let handle = ToolServer::new().handle();
|
||||||
|
handle.register_tool(def("zeta")).expect("register zeta");
|
||||||
|
handle.register_tool(def("alpha")).expect("register alpha");
|
||||||
|
handle.register_tool(def("beta")).expect("register beta");
|
||||||
|
|
||||||
|
let names: Vec<_> = handle
|
||||||
|
.tool_definitions_sorted()
|
||||||
|
.into_iter()
|
||||||
|
.map(|d| d.name)
|
||||||
|
.collect();
|
||||||
|
assert_eq!(names, vec!["alpha", "beta", "zeta"]);
|
||||||
|
}
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,4 +1,4 @@
|
||||||
//! Anthropic フィクスチャベースの統合テスト
|
//! Anthropic fixture-based integration tests
|
||||||
|
|
||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
|
|
|
||||||
6
llm-worker/tests/compile_fail.rs
Normal file
6
llm-worker/tests/compile_fail.rs
Normal file
|
|
@ -0,0 +1,6 @@
|
||||||
|
#[test]
|
||||||
|
fn compile_fail_state_constraints() {
|
||||||
|
let t = trybuild::TestCases::new();
|
||||||
|
t.compile_fail("tests/ui/cache_locked_register_tool.rs");
|
||||||
|
t.compile_fail("tests/ui/tool_server_handle_register_tool.rs");
|
||||||
|
}
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//! Gemini フィクスチャベースの統合テスト
|
//! Gemini fixture-based integration tests
|
||||||
|
|
||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//! Ollama フィクスチャベースの統合テスト
|
//! Ollama fixture-based integration tests
|
||||||
|
|
||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//! OpenAI フィクスチャベースの統合テスト
|
//! OpenAI fixture-based integration tests
|
||||||
|
|
||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
//! 並列ツール実行のテスト
|
//! Parallel tool execution tests
|
||||||
//!
|
//!
|
||||||
//! Workerが複数のツールを並列に実行することを確認する。
|
//! Verify that Worker executes multiple tools in parallel.
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
|
@ -22,7 +22,7 @@ use common::MockLlmClient;
|
||||||
// Parallel Execution Test Tools
|
// Parallel Execution Test Tools
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// 一定時間待機してから応答するツール
|
/// Tool that waits for a specified time before responding
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct SlowTool {
|
struct SlowTool {
|
||||||
name: String,
|
name: String,
|
||||||
|
|
@ -43,7 +43,7 @@ impl SlowTool {
|
||||||
self.call_count.load(Ordering::SeqCst)
|
self.call_count.load(Ordering::SeqCst)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ToolDefinition を作成
|
/// Create ToolDefinition
|
||||||
fn definition(&self) -> ToolDefinition {
|
fn definition(&self) -> ToolDefinition {
|
||||||
let tool = self.clone();
|
let tool = self.clone();
|
||||||
Arc::new(move || {
|
Arc::new(move || {
|
||||||
|
|
@ -71,13 +71,13 @@ impl Tool for SlowTool {
|
||||||
// Tests
|
// Tests
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// 複数のツールが並列に実行されることを確認
|
/// Verify that multiple tools are executed in parallel
|
||||||
///
|
///
|
||||||
/// 各ツールが100msかかる場合、逐次実行なら300ms以上かかるが、
|
/// If each tool takes 100ms, sequential execution would take 300ms+,
|
||||||
/// 並列実行なら100ms程度で完了するはず。
|
/// but parallel execution should complete in about 100ms.
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_parallel_tool_execution() {
|
async fn test_parallel_tool_execution() {
|
||||||
// 3つのツール呼び出しを含むイベントシーケンス
|
// Event sequence containing 3 tool calls
|
||||||
let events = vec![
|
let events = vec![
|
||||||
Event::tool_use_start(0, "call_1", "slow_tool_1"),
|
Event::tool_use_start(0, "call_1", "slow_tool_1"),
|
||||||
Event::tool_input_delta(0, r#"{}"#),
|
Event::tool_input_delta(0, r#"{}"#),
|
||||||
|
|
@ -96,7 +96,7 @@ async fn test_parallel_tool_execution() {
|
||||||
let client = MockLlmClient::new(events);
|
let client = MockLlmClient::new(events);
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
|
|
||||||
// 各ツールは100ms待機
|
// Each tool waits 100ms
|
||||||
let tool1 = SlowTool::new("slow_tool_1", 100);
|
let tool1 = SlowTool::new("slow_tool_1", 100);
|
||||||
let tool2 = SlowTool::new("slow_tool_2", 100);
|
let tool2 = SlowTool::new("slow_tool_2", 100);
|
||||||
let tool3 = SlowTool::new("slow_tool_3", 100);
|
let tool3 = SlowTool::new("slow_tool_3", 100);
|
||||||
|
|
@ -113,13 +113,13 @@ async fn test_parallel_tool_execution() {
|
||||||
let _result = worker.run("Run all tools").await;
|
let _result = worker.run("Run all tools").await;
|
||||||
let elapsed = start.elapsed();
|
let elapsed = start.elapsed();
|
||||||
|
|
||||||
// 全ツールが呼び出されたことを確認
|
// Verify all tools were called
|
||||||
assert_eq!(tool1_clone.call_count(), 1, "Tool 1 should be called once");
|
assert_eq!(tool1_clone.call_count(), 1, "Tool 1 should be called once");
|
||||||
assert_eq!(tool2_clone.call_count(), 1, "Tool 2 should be called once");
|
assert_eq!(tool2_clone.call_count(), 1, "Tool 2 should be called once");
|
||||||
assert_eq!(tool3_clone.call_count(), 1, "Tool 3 should be called once");
|
assert_eq!(tool3_clone.call_count(), 1, "Tool 3 should be called once");
|
||||||
|
|
||||||
// 並列実行なら200ms以下で完了するはず(逐次なら300ms以上)
|
// Parallel execution should complete in under 200ms (sequential would be 300ms+)
|
||||||
// マージン込みで250msをしきい値とする
|
// Using 250ms as threshold with margin
|
||||||
assert!(
|
assert!(
|
||||||
elapsed < Duration::from_millis(250),
|
elapsed < Duration::from_millis(250),
|
||||||
"Parallel execution should complete in ~100ms, but took {:?}",
|
"Parallel execution should complete in ~100ms, but took {:?}",
|
||||||
|
|
@ -129,7 +129,7 @@ async fn test_parallel_tool_execution() {
|
||||||
println!("Parallel execution completed in {:?}", elapsed);
|
println!("Parallel execution completed in {:?}", elapsed);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Hook: pre_tool_call でスキップされたツールは実行されないことを確認
|
/// Hook: pre_tool_call - verify that skipped tools are not executed
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_before_tool_call_skip() {
|
async fn test_before_tool_call_skip() {
|
||||||
let events = vec![
|
let events = vec![
|
||||||
|
|
@ -156,7 +156,7 @@ async fn test_before_tool_call_skip() {
|
||||||
worker.register_tool(allowed_tool.definition()).unwrap();
|
worker.register_tool(allowed_tool.definition()).unwrap();
|
||||||
worker.register_tool(blocked_tool.definition()).unwrap();
|
worker.register_tool(blocked_tool.definition()).unwrap();
|
||||||
|
|
||||||
// "blocked_tool" をスキップするHook
|
// Hook to skip "blocked_tool"
|
||||||
struct BlockingHook;
|
struct BlockingHook;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
@ -174,7 +174,7 @@ async fn test_before_tool_call_skip() {
|
||||||
|
|
||||||
let _result = worker.run("Test hook").await;
|
let _result = worker.run("Test hook").await;
|
||||||
|
|
||||||
// allowed_tool は呼び出されるが、blocked_tool は呼び出されない
|
// allowed_tool is called, but blocked_tool is not
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
allowed_clone.call_count(),
|
allowed_clone.call_count(),
|
||||||
1,
|
1,
|
||||||
|
|
@ -187,12 +187,12 @@ async fn test_before_tool_call_skip() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Hook: post_tool_call で結果が改変されることを確認
|
/// Hook: post_tool_call - verify that results can be modified
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_post_tool_call_modification() {
|
async fn test_post_tool_call_modification() {
|
||||||
// 複数リクエストに対応するレスポンスを準備
|
// Prepare responses for multiple requests
|
||||||
let client = MockLlmClient::with_responses(vec![
|
let client = MockLlmClient::with_responses(vec![
|
||||||
// 1回目のリクエスト: ツール呼び出し
|
// First request: tool call
|
||||||
vec![
|
vec![
|
||||||
Event::tool_use_start(0, "call_1", "test_tool"),
|
Event::tool_use_start(0, "call_1", "test_tool"),
|
||||||
Event::tool_input_delta(0, r#"{}"#),
|
Event::tool_input_delta(0, r#"{}"#),
|
||||||
|
|
@ -201,7 +201,7 @@ async fn test_post_tool_call_modification() {
|
||||||
status: ResponseStatus::Completed,
|
status: ResponseStatus::Completed,
|
||||||
}),
|
}),
|
||||||
],
|
],
|
||||||
// 2回目のリクエスト: ツール結果を受けてテキストレスポンス
|
// Second request: text response after receiving tool result
|
||||||
vec![
|
vec![
|
||||||
Event::text_block_start(0),
|
Event::text_block_start(0),
|
||||||
Event::text_delta(0, "Done!"),
|
Event::text_delta(0, "Done!"),
|
||||||
|
|
@ -235,7 +235,7 @@ async fn test_post_tool_call_modification() {
|
||||||
|
|
||||||
worker.register_tool(simple_tool_definition()).unwrap();
|
worker.register_tool(simple_tool_definition()).unwrap();
|
||||||
|
|
||||||
// 結果を改変するHook
|
// Hook to modify results
|
||||||
struct ModifyingHook {
|
struct ModifyingHook {
|
||||||
modified_content: Arc<std::sync::Mutex<Option<String>>>,
|
modified_content: Arc<std::sync::Mutex<Option<String>>>,
|
||||||
}
|
}
|
||||||
|
|
@ -261,7 +261,7 @@ async fn test_post_tool_call_modification() {
|
||||||
|
|
||||||
assert!(result.is_ok(), "Worker should complete: {:?}", result);
|
assert!(result.is_ok(), "Worker should complete: {:?}", result);
|
||||||
|
|
||||||
// Hookが呼ばれて内容が改変されたことを確認
|
// Verify hook was called and content was modified
|
||||||
let content = modified_content.lock().unwrap().clone();
|
let content = modified_content.lock().unwrap().clone();
|
||||||
assert!(content.is_some(), "Hook should have been called");
|
assert!(content.is_some(), "Hook should have been called");
|
||||||
assert!(
|
assert!(
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
//! WorkerSubscriberのテスト
|
//! WorkerSubscriber tests
|
||||||
//!
|
//!
|
||||||
//! WorkerSubscriberを使ってイベントを購読するテスト
|
//! Tests for subscribing to events using WorkerSubscriber
|
||||||
|
|
||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
|
|
@ -18,9 +18,9 @@ use llm_worker::timeline::{TextBlockEvent, ToolUseBlockEvent};
|
||||||
// Test Subscriber
|
// Test Subscriber
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// テスト用のシンプルなSubscriber実装
|
/// Simple Subscriber implementation for testing
|
||||||
struct TestSubscriber {
|
struct TestSubscriber {
|
||||||
// 記録用のバッファ
|
// Recording buffers
|
||||||
text_deltas: Arc<Mutex<Vec<String>>>,
|
text_deltas: Arc<Mutex<Vec<String>>>,
|
||||||
text_completes: Arc<Mutex<Vec<String>>>,
|
text_completes: Arc<Mutex<Vec<String>>>,
|
||||||
tool_call_completes: Arc<Mutex<Vec<ToolCall>>>,
|
tool_call_completes: Arc<Mutex<Vec<ToolCall>>>,
|
||||||
|
|
@ -60,7 +60,7 @@ impl WorkerSubscriber for TestSubscriber {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_tool_use_block(&mut self, _scope: &mut (), _event: &ToolUseBlockEvent) {
|
fn on_tool_use_block(&mut self, _scope: &mut (), _event: &ToolUseBlockEvent) {
|
||||||
// 必要に応じて処理
|
// Process as needed
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_tool_call_complete(&mut self, call: &ToolCall) {
|
fn on_tool_call_complete(&mut self, call: &ToolCall) {
|
||||||
|
|
@ -76,7 +76,7 @@ impl WorkerSubscriber for TestSubscriber {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_error(&mut self, _event: &ErrorEvent) {
|
fn on_error(&mut self, _event: &ErrorEvent) {
|
||||||
// 必要に応じて処理
|
// Process as needed
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_turn_start(&mut self, turn: usize) {
|
fn on_turn_start(&mut self, turn: usize) {
|
||||||
|
|
@ -92,10 +92,10 @@ impl WorkerSubscriber for TestSubscriber {
|
||||||
// Tests
|
// Tests
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// WorkerSubscriberがテキストブロックイベントを正しく受け取ることを確認
|
/// Verify that WorkerSubscriber correctly receives text block events
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_subscriber_text_block_events() {
|
async fn test_subscriber_text_block_events() {
|
||||||
// テキストレスポンスを含むイベントシーケンス
|
// Event sequence containing text response
|
||||||
let events = vec![
|
let events = vec![
|
||||||
Event::text_block_start(0),
|
Event::text_block_start(0),
|
||||||
Event::text_delta(0, "Hello, "),
|
Event::text_delta(0, "Hello, "),
|
||||||
|
|
@ -109,33 +109,33 @@ async fn test_subscriber_text_block_events() {
|
||||||
let client = MockLlmClient::new(events);
|
let client = MockLlmClient::new(events);
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
|
|
||||||
// Subscriberを登録
|
// Register Subscriber
|
||||||
let subscriber = TestSubscriber::new();
|
let subscriber = TestSubscriber::new();
|
||||||
let text_deltas = subscriber.text_deltas.clone();
|
let text_deltas = subscriber.text_deltas.clone();
|
||||||
let text_completes = subscriber.text_completes.clone();
|
let text_completes = subscriber.text_completes.clone();
|
||||||
worker.subscribe(subscriber);
|
worker.subscribe(subscriber);
|
||||||
|
|
||||||
// 実行
|
// Execute
|
||||||
let result = worker.run("Greet me").await;
|
let result = worker.run("Greet me").await;
|
||||||
|
|
||||||
assert!(result.is_ok(), "Worker should complete: {:?}", result);
|
assert!(result.is_ok(), "Worker should complete: {:?}", result);
|
||||||
|
|
||||||
// デルタが収集されていることを確認
|
// Verify deltas were collected
|
||||||
let deltas = text_deltas.lock().unwrap();
|
let deltas = text_deltas.lock().unwrap();
|
||||||
assert_eq!(deltas.len(), 2);
|
assert_eq!(deltas.len(), 2);
|
||||||
assert_eq!(deltas[0], "Hello, ");
|
assert_eq!(deltas[0], "Hello, ");
|
||||||
assert_eq!(deltas[1], "World!");
|
assert_eq!(deltas[1], "World!");
|
||||||
|
|
||||||
// 完了テキストが収集されていることを確認
|
// Verify complete text was collected
|
||||||
let completes = text_completes.lock().unwrap();
|
let completes = text_completes.lock().unwrap();
|
||||||
assert_eq!(completes.len(), 1);
|
assert_eq!(completes.len(), 1);
|
||||||
assert_eq!(completes[0], "Hello, World!");
|
assert_eq!(completes[0], "Hello, World!");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// WorkerSubscriberがツール呼び出し完了イベントを正しく受け取ることを確認
|
/// Verify that WorkerSubscriber correctly receives tool call complete events
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_subscriber_tool_call_complete() {
|
async fn test_subscriber_tool_call_complete() {
|
||||||
// ツール呼び出しを含むイベントシーケンス
|
// Event sequence containing tool call
|
||||||
let events = vec![
|
let events = vec![
|
||||||
Event::tool_use_start(0, "call_123", "get_weather"),
|
Event::tool_use_start(0, "call_123", "get_weather"),
|
||||||
Event::tool_input_delta(0, r#"{"city":"#),
|
Event::tool_input_delta(0, r#"{"city":"#),
|
||||||
|
|
@ -149,15 +149,15 @@ async fn test_subscriber_tool_call_complete() {
|
||||||
let client = MockLlmClient::new(events);
|
let client = MockLlmClient::new(events);
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
|
|
||||||
// Subscriberを登録
|
// Register Subscriber
|
||||||
let subscriber = TestSubscriber::new();
|
let subscriber = TestSubscriber::new();
|
||||||
let tool_call_completes = subscriber.tool_call_completes.clone();
|
let tool_call_completes = subscriber.tool_call_completes.clone();
|
||||||
worker.subscribe(subscriber);
|
worker.subscribe(subscriber);
|
||||||
|
|
||||||
// 実行
|
// Execute
|
||||||
let _ = worker.run("Weather please").await;
|
let _ = worker.run("Weather please").await;
|
||||||
|
|
||||||
// ツール呼び出し完了が収集されていることを確認
|
// Verify tool call complete was collected
|
||||||
let completes = tool_call_completes.lock().unwrap();
|
let completes = tool_call_completes.lock().unwrap();
|
||||||
assert_eq!(completes.len(), 1);
|
assert_eq!(completes.len(), 1);
|
||||||
assert_eq!(completes[0].name, "get_weather");
|
assert_eq!(completes[0].name, "get_weather");
|
||||||
|
|
@ -165,7 +165,7 @@ async fn test_subscriber_tool_call_complete() {
|
||||||
assert_eq!(completes[0].input["city"], "Tokyo");
|
assert_eq!(completes[0].input["city"], "Tokyo");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// WorkerSubscriberがターンイベントを正しく受け取ることを確認
|
/// Verify that WorkerSubscriber correctly receives turn events
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_subscriber_turn_events() {
|
async fn test_subscriber_turn_events() {
|
||||||
let events = vec![
|
let events = vec![
|
||||||
|
|
@ -180,29 +180,29 @@ async fn test_subscriber_turn_events() {
|
||||||
let client = MockLlmClient::new(events);
|
let client = MockLlmClient::new(events);
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
|
|
||||||
// Subscriberを登録
|
// Register Subscriber
|
||||||
let subscriber = TestSubscriber::new();
|
let subscriber = TestSubscriber::new();
|
||||||
let turn_starts = subscriber.turn_starts.clone();
|
let turn_starts = subscriber.turn_starts.clone();
|
||||||
let turn_ends = subscriber.turn_ends.clone();
|
let turn_ends = subscriber.turn_ends.clone();
|
||||||
worker.subscribe(subscriber);
|
worker.subscribe(subscriber);
|
||||||
|
|
||||||
// 実行
|
// Execute
|
||||||
let result = worker.run("Do something").await;
|
let result = worker.run("Do something").await;
|
||||||
|
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
// ターンイベントが収集されていることを確認
|
// Verify turn events were collected
|
||||||
let starts = turn_starts.lock().unwrap();
|
let starts = turn_starts.lock().unwrap();
|
||||||
let ends = turn_ends.lock().unwrap();
|
let ends = turn_ends.lock().unwrap();
|
||||||
|
|
||||||
assert_eq!(starts.len(), 1);
|
assert_eq!(starts.len(), 1);
|
||||||
assert_eq!(starts[0], 0); // 最初のターン
|
assert_eq!(starts[0], 0); // First turn
|
||||||
|
|
||||||
assert_eq!(ends.len(), 1);
|
assert_eq!(ends.len(), 1);
|
||||||
assert_eq!(ends[0], 0);
|
assert_eq!(ends[0], 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// WorkerSubscriberがUsageイベントを正しく受け取ることを確認
|
/// Verify that WorkerSubscriber correctly receives Usage events
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_subscriber_usage_events() {
|
async fn test_subscriber_usage_events() {
|
||||||
let events = vec![
|
let events = vec![
|
||||||
|
|
@ -218,15 +218,15 @@ async fn test_subscriber_usage_events() {
|
||||||
let client = MockLlmClient::new(events);
|
let client = MockLlmClient::new(events);
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
|
|
||||||
// Subscriberを登録
|
// Register Subscriber
|
||||||
let subscriber = TestSubscriber::new();
|
let subscriber = TestSubscriber::new();
|
||||||
let usage_events = subscriber.usage_events.clone();
|
let usage_events = subscriber.usage_events.clone();
|
||||||
worker.subscribe(subscriber);
|
worker.subscribe(subscriber);
|
||||||
|
|
||||||
// 実行
|
// Execute
|
||||||
let _ = worker.run("Hello").await;
|
let _ = worker.run("Hello").await;
|
||||||
|
|
||||||
// Usageイベントが収集されていることを確認
|
// Verify Usage events were collected
|
||||||
let usages = usage_events.lock().unwrap();
|
let usages = usage_events.lock().unwrap();
|
||||||
assert_eq!(usages.len(), 1);
|
assert_eq!(usages.len(), 1);
|
||||||
assert_eq!(usages[0].input_tokens, Some(100));
|
assert_eq!(usages[0].input_tokens, Some(100));
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
//! ツールマクロのテスト
|
//! Tool macro tests
|
||||||
//!
|
//!
|
||||||
//! `#[tool_registry]` と `#[tool]` マクロの動作を確認する。
|
//! Verify the behavior of `#[tool_registry]` and `#[tool]` macros.
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
|
||||||
// マクロ展開に必要なインポート
|
// Imports needed for macro expansion
|
||||||
use schemars;
|
use schemars;
|
||||||
use serde;
|
use serde;
|
||||||
|
|
||||||
|
|
@ -15,7 +15,7 @@ use llm_worker_macros::tool_registry;
|
||||||
// Test: Basic Tool Generation
|
// Test: Basic Tool Generation
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// シンプルなコンテキスト構造体
|
/// Simple context struct
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct SimpleContext {
|
struct SimpleContext {
|
||||||
prefix: String,
|
prefix: String,
|
||||||
|
|
@ -23,21 +23,21 @@ struct SimpleContext {
|
||||||
|
|
||||||
#[tool_registry]
|
#[tool_registry]
|
||||||
impl SimpleContext {
|
impl SimpleContext {
|
||||||
/// メッセージに挨拶を追加する
|
/// Add greeting to message
|
||||||
///
|
///
|
||||||
/// 指定されたメッセージにプレフィックスを付けて返します。
|
/// Returns the message with a prefix added.
|
||||||
#[tool]
|
#[tool]
|
||||||
async fn greet(&self, message: String) -> String {
|
async fn greet(&self, message: String) -> String {
|
||||||
format!("{}: {}", self.prefix, message)
|
format!("{}: {}", self.prefix, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 二つの数を足す
|
/// Add two numbers
|
||||||
#[tool]
|
#[tool]
|
||||||
async fn add(&self, a: i32, b: i32) -> i32 {
|
async fn add(&self, a: i32, b: i32) -> i32 {
|
||||||
a + b
|
a + b
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 引数なしのツール
|
/// Tool with no arguments
|
||||||
#[tool]
|
#[tool]
|
||||||
async fn get_prefix(&self) -> String {
|
async fn get_prefix(&self) -> String {
|
||||||
self.prefix.clone()
|
self.prefix.clone()
|
||||||
|
|
@ -50,16 +50,16 @@ async fn test_basic_tool_generation() {
|
||||||
prefix: "Hello".to_string(),
|
prefix: "Hello".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// ファクトリメソッドでToolDefinitionを取得
|
// Get ToolDefinition from factory method
|
||||||
let greet_definition = ctx.greet_definition();
|
let greet_definition = ctx.greet_definition();
|
||||||
|
|
||||||
// ファクトリを呼び出してMetaとToolを取得
|
// Call factory to get Meta and Tool
|
||||||
let (meta, tool) = greet_definition();
|
let (meta, tool) = greet_definition();
|
||||||
|
|
||||||
// メタ情報の確認
|
// Verify meta information
|
||||||
assert_eq!(meta.name, "greet");
|
assert_eq!(meta.name, "greet");
|
||||||
assert!(
|
assert!(
|
||||||
meta.description.contains("メッセージに挨拶を追加する"),
|
meta.description.contains("Add greeting to message"),
|
||||||
"Description should contain doc comment: {}",
|
"Description should contain doc comment: {}",
|
||||||
meta.description
|
meta.description
|
||||||
);
|
);
|
||||||
|
|
@ -73,7 +73,7 @@ async fn test_basic_tool_generation() {
|
||||||
serde_json::to_string_pretty(&meta.input_schema).unwrap()
|
serde_json::to_string_pretty(&meta.input_schema).unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
// 実行テスト
|
// Execution test
|
||||||
let result = tool.execute(r#"{"message": "World"}"#).await;
|
let result = tool.execute(r#"{"message": "World"}"#).await;
|
||||||
assert!(result.is_ok(), "Should execute successfully");
|
assert!(result.is_ok(), "Should execute successfully");
|
||||||
let output = result.unwrap();
|
let output = result.unwrap();
|
||||||
|
|
@ -107,7 +107,7 @@ async fn test_no_arguments() {
|
||||||
|
|
||||||
assert_eq!(meta.name, "get_prefix");
|
assert_eq!(meta.name, "get_prefix");
|
||||||
|
|
||||||
// 空のJSONオブジェクトで呼び出し
|
// Call with empty JSON object
|
||||||
let result = tool.execute(r#"{}"#).await;
|
let result = tool.execute(r#"{}"#).await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
let output = result.unwrap();
|
let output = result.unwrap();
|
||||||
|
|
@ -126,7 +126,7 @@ async fn test_invalid_arguments() {
|
||||||
|
|
||||||
let (_, tool) = ctx.greet_definition()();
|
let (_, tool) = ctx.greet_definition()();
|
||||||
|
|
||||||
// 不正なJSON
|
// Invalid JSON
|
||||||
let result = tool.execute(r#"{"wrong_field": "value"}"#).await;
|
let result = tool.execute(r#"{"wrong_field": "value"}"#).await;
|
||||||
assert!(result.is_err(), "Should fail with invalid arguments");
|
assert!(result.is_err(), "Should fail with invalid arguments");
|
||||||
}
|
}
|
||||||
|
|
@ -149,7 +149,7 @@ impl std::fmt::Display for MyError {
|
||||||
|
|
||||||
#[tool_registry]
|
#[tool_registry]
|
||||||
impl FallibleContext {
|
impl FallibleContext {
|
||||||
/// 与えられた値を検証する
|
/// Validate the given value
|
||||||
#[tool]
|
#[tool]
|
||||||
async fn validate(&self, value: i32) -> Result<String, MyError> {
|
async fn validate(&self, value: i32) -> Result<String, MyError> {
|
||||||
if value > 0 {
|
if value > 0 {
|
||||||
|
|
@ -198,7 +198,7 @@ struct SyncContext {
|
||||||
|
|
||||||
#[tool_registry]
|
#[tool_registry]
|
||||||
impl SyncContext {
|
impl SyncContext {
|
||||||
/// カウンターをインクリメントして返す (非async)
|
/// Increment counter and return (non-async)
|
||||||
#[tool]
|
#[tool]
|
||||||
fn increment(&self) -> usize {
|
fn increment(&self) -> usize {
|
||||||
self.counter.fetch_add(1, Ordering::SeqCst) + 1
|
self.counter.fetch_add(1, Ordering::SeqCst) + 1
|
||||||
|
|
@ -213,7 +213,7 @@ async fn test_sync_method() {
|
||||||
|
|
||||||
let (_, tool) = ctx.increment_definition()();
|
let (_, tool) = ctx.increment_definition()();
|
||||||
|
|
||||||
// 3回実行
|
// Execute 3 times
|
||||||
let result1 = tool.execute(r#"{}"#).await;
|
let result1 = tool.execute(r#"{}"#).await;
|
||||||
let result2 = tool.execute(r#"{}"#).await;
|
let result2 = tool.execute(r#"{}"#).await;
|
||||||
let result3 = tool.execute(r#"{}"#).await;
|
let result3 = tool.execute(r#"{}"#).await;
|
||||||
|
|
@ -222,7 +222,7 @@ async fn test_sync_method() {
|
||||||
assert!(result2.is_ok());
|
assert!(result2.is_ok());
|
||||||
assert!(result3.is_ok());
|
assert!(result3.is_ok());
|
||||||
|
|
||||||
// カウンターは3になっているはず
|
// Counter should be 3
|
||||||
assert_eq!(ctx.counter.load(Ordering::SeqCst), 3);
|
assert_eq!(ctx.counter.load(Ordering::SeqCst), 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -236,7 +236,7 @@ async fn test_tool_meta_immutability() {
|
||||||
prefix: "Test".to_string(),
|
prefix: "Test".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// 2回取得しても同じメタ情報が得られることを確認
|
// Verify same meta info is returned on multiple calls
|
||||||
let (meta1, _) = ctx.greet_definition()();
|
let (meta1, _) = ctx.greet_definition()();
|
||||||
let (meta2, _) = ctx.greet_definition()();
|
let (meta2, _) = ctx.greet_definition()();
|
||||||
|
|
||||||
|
|
|
||||||
11
llm-worker/tests/ui/cache_locked_register_tool.rs
Normal file
11
llm-worker/tests/ui/cache_locked_register_tool.rs
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
use llm_worker::Worker;
|
||||||
|
use llm_worker::llm_client::providers::ollama::OllamaClient;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let client = OllamaClient::new("dummy-model");
|
||||||
|
let worker = Worker::new(client);
|
||||||
|
let mut locked = worker.lock();
|
||||||
|
let def: llm_worker::tool::ToolDefinition = Arc::new(|| panic!("unused"));
|
||||||
|
let _ = locked.register_tool(def);
|
||||||
|
}
|
||||||
8
llm-worker/tests/ui/cache_locked_register_tool.stderr
Normal file
8
llm-worker/tests/ui/cache_locked_register_tool.stderr
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
error[E0599]: no method named `register_tool` found for struct `Worker<OllamaClient, CacheLocked>` in the current scope
|
||||||
|
--> tests/ui/cache_locked_register_tool.rs:10:20
|
||||||
|
|
|
||||||
|
10 | let _ = locked.register_tool(def);
|
||||||
|
| ^^^^^^^^^^^^^ method not found in `Worker<OllamaClient, CacheLocked>`
|
||||||
|
|
|
||||||
|
= note: the method was found for
|
||||||
|
- `Worker<C>`
|
||||||
11
llm-worker/tests/ui/tool_server_handle_register_tool.rs
Normal file
11
llm-worker/tests/ui/tool_server_handle_register_tool.rs
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
use llm_worker::Worker;
|
||||||
|
use llm_worker::llm_client::providers::ollama::OllamaClient;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let client = OllamaClient::new("dummy-model");
|
||||||
|
let worker = Worker::new(client);
|
||||||
|
let handle = worker.tool_server_handle();
|
||||||
|
let def: llm_worker::tool::ToolDefinition = Arc::new(|| panic!("unused"));
|
||||||
|
let _ = handle.register_tool(def);
|
||||||
|
}
|
||||||
13
llm-worker/tests/ui/tool_server_handle_register_tool.stderr
Normal file
13
llm-worker/tests/ui/tool_server_handle_register_tool.stderr
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
error[E0624]: method `register_tool` is private
|
||||||
|
--> tests/ui/tool_server_handle_register_tool.rs:10:20
|
||||||
|
|
|
||||||
|
10 | let _ = handle.register_tool(def);
|
||||||
|
| ^^^^^^^^^^^^^ private method
|
||||||
|
|
|
||||||
|
::: src/tool_server.rs
|
||||||
|
|
|
||||||
|
| / pub(crate) fn register_tool(
|
||||||
|
| | &self,
|
||||||
|
| | factory: WorkerToolDefinition,
|
||||||
|
| | ) -> Result<(), ToolServerError> {
|
||||||
|
| |____________________________________- private method defined here
|
||||||
|
|
@ -3,16 +3,16 @@ use llm_worker::{Worker, WorkerError};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_openai_top_k_warning() {
|
fn test_openai_top_k_warning() {
|
||||||
// ダミーキーでクライアント作成(validate_configは通信しないため安全)
|
// Create client with dummy key (validate_config doesn't make network calls, so safe)
|
||||||
let client = OpenAIClient::new("dummy-key", "gpt-4o");
|
let client = OpenAIClient::new("dummy-key", "gpt-4o");
|
||||||
|
|
||||||
// top_kを設定したWorkerを作成
|
// Create Worker with top_k set (OpenAI doesn't support top_k)
|
||||||
let worker = Worker::new(client).top_k(50); // OpenAIはtop_k非対応
|
let worker = Worker::new(client).top_k(50);
|
||||||
|
|
||||||
// validate()を実行
|
// Run validate()
|
||||||
let result = worker.validate();
|
let result = worker.validate();
|
||||||
|
|
||||||
// エラーが返り、ConfigWarningsが含まれていることを確認
|
// Verify error is returned and ConfigWarnings is included
|
||||||
match result {
|
match result {
|
||||||
Err(WorkerError::ConfigWarnings(warnings)) => {
|
Err(WorkerError::ConfigWarnings(warnings)) => {
|
||||||
assert_eq!(warnings.len(), 1);
|
assert_eq!(warnings.len(), 1);
|
||||||
|
|
@ -28,12 +28,12 @@ fn test_openai_top_k_warning() {
|
||||||
fn test_openai_valid_config() {
|
fn test_openai_valid_config() {
|
||||||
let client = OpenAIClient::new("dummy-key", "gpt-4o");
|
let client = OpenAIClient::new("dummy-key", "gpt-4o");
|
||||||
|
|
||||||
// validな設定(temperatureのみ)
|
// Valid configuration (temperature only)
|
||||||
let worker = Worker::new(client).temperature(0.7);
|
let worker = Worker::new(client).temperature(0.7);
|
||||||
|
|
||||||
// validate()を実行
|
// Run validate()
|
||||||
let result = worker.validate();
|
let result = worker.validate();
|
||||||
|
|
||||||
// 成功を確認
|
// Verify success
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
//! Workerフィクスチャベースの統合テスト
|
//! Worker fixture-based integration tests
|
||||||
//!
|
//!
|
||||||
//! 記録されたAPIレスポンスを使ってWorkerの動作をテストする。
|
//! Tests Worker behavior using recorded API responses.
|
||||||
//! APIキー不要でローカルで実行可能。
|
//! Can run locally without API keys.
|
||||||
|
|
||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
|
|
@ -14,12 +14,12 @@ use common::MockLlmClient;
|
||||||
use llm_worker::Worker;
|
use llm_worker::Worker;
|
||||||
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
||||||
|
|
||||||
/// フィクスチャディレクトリのパス
|
/// Fixture directory path
|
||||||
fn fixtures_dir() -> std::path::PathBuf {
|
fn fixtures_dir() -> std::path::PathBuf {
|
||||||
Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/anthropic")
|
Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/anthropic")
|
||||||
}
|
}
|
||||||
|
|
||||||
/// シンプルなテスト用ツール
|
/// Simple test tool
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct MockWeatherTool {
|
struct MockWeatherTool {
|
||||||
call_count: Arc<AtomicUsize>,
|
call_count: Arc<AtomicUsize>,
|
||||||
|
|
@ -61,13 +61,13 @@ impl Tool for MockWeatherTool {
|
||||||
async fn execute(&self, input_json: &str) -> Result<String, ToolError> {
|
async fn execute(&self, input_json: &str) -> Result<String, ToolError> {
|
||||||
self.call_count.fetch_add(1, Ordering::SeqCst);
|
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||||
|
|
||||||
// 入力をパース
|
// Parse input
|
||||||
let input: serde_json::Value = serde_json::from_str(input_json)
|
let input: serde_json::Value = serde_json::from_str(input_json)
|
||||||
.map_err(|e| ToolError::InvalidArgument(e.to_string()))?;
|
.map_err(|e| ToolError::InvalidArgument(e.to_string()))?;
|
||||||
|
|
||||||
let city = input["city"].as_str().unwrap_or("Unknown");
|
let city = input["city"].as_str().unwrap_or("Unknown");
|
||||||
|
|
||||||
// モックのレスポンスを返す
|
// Return mock response
|
||||||
Ok(format!("Weather in {}: Sunny, 22°C", city))
|
Ok(format!("Weather in {}: Sunny, 22°C", city))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -76,12 +76,12 @@ impl Tool for MockWeatherTool {
|
||||||
// Basic Fixture Tests
|
// Basic Fixture Tests
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// MockLlmClientがJSONLフィクスチャファイルから正しくイベントをロードできることを確認
|
/// Verify that MockLlmClient can correctly load events from JSONL fixture files
|
||||||
///
|
///
|
||||||
/// 既存のanthropic_*.jsonlファイルを使用し、イベントがパース・ロードされることを検証する。
|
/// Uses existing anthropic_*.jsonl files to verify events are parsed and loaded.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_mock_client_from_fixture() {
|
fn test_mock_client_from_fixture() {
|
||||||
// 既存のフィクスチャをロード
|
// Load existing fixture
|
||||||
let fixture_path = fixtures_dir().join("anthropic_1767624445.jsonl");
|
let fixture_path = fixtures_dir().join("anthropic_1767624445.jsonl");
|
||||||
if !fixture_path.exists() {
|
if !fixture_path.exists() {
|
||||||
println!("Fixture not found, skipping test");
|
println!("Fixture not found, skipping test");
|
||||||
|
|
@ -93,14 +93,14 @@ fn test_mock_client_from_fixture() {
|
||||||
println!("Loaded {} events from fixture", client.event_count());
|
println!("Loaded {} events from fixture", client.event_count());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// MockLlmClientが直接指定されたイベントリストで正しく動作することを確認
|
/// Verify that MockLlmClient works correctly with directly specified event lists
|
||||||
///
|
///
|
||||||
/// fixtureファイルを使わず、プログラムでイベントを構築してクライアントを作成する。
|
/// Creates a client with programmatically constructed events instead of using fixture files.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_mock_client_from_events() {
|
fn test_mock_client_from_events() {
|
||||||
use llm_worker::llm_client::event::Event;
|
use llm_worker::llm_client::event::Event;
|
||||||
|
|
||||||
// 直接イベントを指定
|
// Specify events directly
|
||||||
let events = vec![
|
let events = vec![
|
||||||
Event::text_block_start(0),
|
Event::text_block_start(0),
|
||||||
Event::text_delta(0, "Hello!"),
|
Event::text_delta(0, "Hello!"),
|
||||||
|
|
@ -115,10 +115,10 @@ fn test_mock_client_from_events() {
|
||||||
// Worker Tests with Fixtures
|
// Worker Tests with Fixtures
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// Workerがシンプルなテキストレスポンスを正しく処理できることを確認
|
/// Verify that Worker can correctly process simple text responses
|
||||||
///
|
///
|
||||||
/// simple_text.jsonlフィクスチャを使用し、ツール呼び出しなしのシナリオをテストする。
|
/// Uses simple_text.jsonl fixture to test scenarios without tool calls.
|
||||||
/// フィクスチャがない場合はスキップされる。
|
/// Skipped if fixture is not present.
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_worker_simple_text_response() {
|
async fn test_worker_simple_text_response() {
|
||||||
let fixture_path = fixtures_dir().join("simple_text.jsonl");
|
let fixture_path = fixtures_dir().join("simple_text.jsonl");
|
||||||
|
|
@ -131,16 +131,16 @@ async fn test_worker_simple_text_response() {
|
||||||
let client = MockLlmClient::from_fixture(&fixture_path).unwrap();
|
let client = MockLlmClient::from_fixture(&fixture_path).unwrap();
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
|
|
||||||
// シンプルなメッセージを送信
|
// Send a simple message
|
||||||
let result = worker.run("Hello").await;
|
let result = worker.run("Hello").await;
|
||||||
|
|
||||||
assert!(result.is_ok(), "Worker should complete successfully");
|
assert!(result.is_ok(), "Worker should complete successfully");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Workerがツール呼び出しを含むレスポンスを正しく処理できることを確認
|
/// Verify that Worker can correctly process responses containing tool calls
|
||||||
///
|
///
|
||||||
/// tool_call.jsonlフィクスチャを使用し、MockWeatherToolが呼び出されることをテストする。
|
/// Uses tool_call.jsonl fixture to test that MockWeatherTool is called.
|
||||||
/// max_turns=1に設定し、ツール実行後のループを防止。
|
/// Sets max_turns=1 to prevent loop after tool execution.
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_worker_tool_call() {
|
async fn test_worker_tool_call() {
|
||||||
let fixture_path = fixtures_dir().join("tool_call.jsonl");
|
let fixture_path = fixtures_dir().join("tool_call.jsonl");
|
||||||
|
|
@ -153,32 +153,32 @@ async fn test_worker_tool_call() {
|
||||||
let client = MockLlmClient::from_fixture(&fixture_path).unwrap();
|
let client = MockLlmClient::from_fixture(&fixture_path).unwrap();
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
|
|
||||||
// ツールを登録
|
// Register tool
|
||||||
let weather_tool = MockWeatherTool::new();
|
let weather_tool = MockWeatherTool::new();
|
||||||
let tool_for_check = weather_tool.clone();
|
let tool_for_check = weather_tool.clone();
|
||||||
worker.register_tool(weather_tool.definition()).unwrap();
|
worker.register_tool(weather_tool.definition()).unwrap();
|
||||||
|
|
||||||
// メッセージを送信
|
// Send message
|
||||||
let _result = worker.run("What's the weather in Tokyo?").await;
|
let _result = worker.run("What's the weather in Tokyo?").await;
|
||||||
|
|
||||||
// ツールが呼び出されたことを確認
|
// Verify tool was called
|
||||||
// Note: max_turns=1なのでツール結果後のリクエストは送信されない
|
// Note: max_turns=1 so no request is sent after tool result
|
||||||
let call_count = tool_for_check.get_call_count();
|
let call_count = tool_for_check.get_call_count();
|
||||||
println!("Tool was called {} times", call_count);
|
println!("Tool was called {} times", call_count);
|
||||||
|
|
||||||
// フィクスチャにToolUseが含まれていればツールが呼び出されるはず
|
// Tool should be called if fixture contains ToolUse
|
||||||
// ただしmax_turns=1なので1回で終了
|
// But ends after 1 turn due to max_turns=1
|
||||||
}
|
}
|
||||||
|
|
||||||
/// fixtureファイルなしでWorkerが動作することを確認
|
/// Verify that Worker works without fixture files
|
||||||
///
|
///
|
||||||
/// プログラムでイベントシーケンスを構築し、MockLlmClientに渡してテストする。
|
/// Constructs event sequence programmatically and passes to MockLlmClient.
|
||||||
/// テストの独立性を高め、外部ファイルへの依存を排除したい場合に有用。
|
/// Useful when test independence is needed and external file dependency should be eliminated.
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_worker_with_programmatic_events() {
|
async fn test_worker_with_programmatic_events() {
|
||||||
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
|
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
|
||||||
|
|
||||||
// プログラムでイベントシーケンスを構築
|
// Construct event sequence programmatically
|
||||||
let events = vec![
|
let events = vec![
|
||||||
Event::text_block_start(0),
|
Event::text_block_start(0),
|
||||||
Event::text_delta(0, "Hello, "),
|
Event::text_delta(0, "Hello, "),
|
||||||
|
|
@ -197,16 +197,16 @@ async fn test_worker_with_programmatic_events() {
|
||||||
assert!(result.is_ok(), "Worker should complete successfully");
|
assert!(result.is_ok(), "Worker should complete successfully");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ToolCallCollectorがToolUseブロックイベントから正しくToolCallを収集することを確認
|
/// Verify that ToolCallCollector correctly collects ToolCall from ToolUse block events
|
||||||
///
|
///
|
||||||
/// Timelineにイベントをディスパッチし、ToolCallCollectorが
|
/// Dispatches events to Timeline and verifies ToolCallCollector
|
||||||
/// id, name, input(JSON)を正しく抽出できることを検証する。
|
/// correctly extracts id, name, and input (JSON).
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_tool_call_collector_integration() {
|
async fn test_tool_call_collector_integration() {
|
||||||
use llm_worker::llm_client::event::Event;
|
use llm_worker::llm_client::event::Event;
|
||||||
use llm_worker::timeline::{Timeline, ToolCallCollector};
|
use llm_worker::timeline::{Timeline, ToolCallCollector};
|
||||||
|
|
||||||
// ToolUseブロックを含むイベントシーケンス
|
// Event sequence containing ToolUse block
|
||||||
let events = vec![
|
let events = vec![
|
||||||
Event::tool_use_start(0, "call_123", "get_weather"),
|
Event::tool_use_start(0, "call_123", "get_weather"),
|
||||||
Event::tool_input_delta(0, r#"{"city":"#),
|
Event::tool_input_delta(0, r#"{"city":"#),
|
||||||
|
|
@ -218,13 +218,13 @@ async fn test_tool_call_collector_integration() {
|
||||||
let mut timeline = Timeline::new();
|
let mut timeline = Timeline::new();
|
||||||
timeline.on_tool_use_block(collector.clone());
|
timeline.on_tool_use_block(collector.clone());
|
||||||
|
|
||||||
// イベントをディスパッチ
|
// Dispatch events
|
||||||
for event in &events {
|
for event in &events {
|
||||||
let timeline_event: llm_worker::timeline::event::Event = event.clone().into();
|
let timeline_event: llm_worker::timeline::event::Event = event.clone().into();
|
||||||
timeline.dispatch(&timeline_event);
|
timeline.dispatch(&timeline_event);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 収集されたToolCallを確認
|
// Verify collected ToolCall
|
||||||
let calls = collector.take_collected();
|
let calls = collector.take_collected();
|
||||||
assert_eq!(calls.len(), 1, "Should collect one tool call");
|
assert_eq!(calls.len(), 1, "Should collect one tool call");
|
||||||
assert_eq!(calls[0].name, "get_weather");
|
assert_eq!(calls[0].name, "get_weather");
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,25 @@
|
||||||
//! Worker状態管理のテスト
|
//! Worker state management tests
|
||||||
//!
|
//!
|
||||||
//! Type-stateパターン(Mutable/CacheLocked)による状態遷移と
|
//! Tests for state transitions using the Type-state pattern (Mutable/CacheLocked)
|
||||||
//! ターン間の状態保持をテストする。
|
//! and state preservation between turns.
|
||||||
|
|
||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
use common::MockLlmClient;
|
use common::MockLlmClient;
|
||||||
use llm_worker::Worker;
|
use llm_worker::Worker;
|
||||||
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
|
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
|
||||||
use llm_worker::{Message, MessageContent};
|
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};
|
||||||
|
use llm_worker::Item;
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// Mutable状態のテスト
|
// Mutable State Tests
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// Mutable状態でシステムプロンプトを設定できることを確認
|
/// Verify that system prompt can be set in Mutable state
|
||||||
#[test]
|
#[test]
|
||||||
fn test_mutable_set_system_prompt() {
|
fn test_mutable_set_system_prompt() {
|
||||||
let client = MockLlmClient::new(vec![]);
|
let client = MockLlmClient::new(vec![]);
|
||||||
|
|
@ -29,114 +34,164 @@ fn test_mutable_set_system_prompt() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mutable状態で履歴を自由に編集できることを確認
|
/// Verify that history can be freely edited in Mutable state
|
||||||
#[test]
|
#[test]
|
||||||
fn test_mutable_history_manipulation() {
|
fn test_mutable_history_manipulation() {
|
||||||
let client = MockLlmClient::new(vec![]);
|
let client = MockLlmClient::new(vec![]);
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
|
|
||||||
// 初期状態は空
|
// Initial state is empty
|
||||||
assert!(worker.history().is_empty());
|
assert!(worker.history().is_empty());
|
||||||
|
|
||||||
// 履歴を追加
|
// Add to history
|
||||||
worker.push_message(Message::user("Hello"));
|
worker.push_item(Item::user_message("Hello"));
|
||||||
worker.push_message(Message::assistant("Hi there!"));
|
worker.push_item(Item::assistant_message("Hi there!"));
|
||||||
assert_eq!(worker.history().len(), 2);
|
assert_eq!(worker.history().len(), 2);
|
||||||
|
|
||||||
// 履歴への可変アクセス
|
// Mutable access to history
|
||||||
worker.history_mut().push(Message::user("How are you?"));
|
worker.history_mut().push(Item::user_message("How are you?"));
|
||||||
assert_eq!(worker.history().len(), 3);
|
assert_eq!(worker.history().len(), 3);
|
||||||
|
|
||||||
// 履歴をクリア
|
// Clear history
|
||||||
worker.clear_history();
|
worker.clear_history();
|
||||||
assert!(worker.history().is_empty());
|
assert!(worker.history().is_empty());
|
||||||
|
|
||||||
// 履歴を設定
|
// Set history
|
||||||
let messages = vec![Message::user("Test"), Message::assistant("Response")];
|
let items = vec![Item::user_message("Test"), Item::assistant_message("Response")];
|
||||||
worker.set_history(messages);
|
worker.set_history(items);
|
||||||
assert_eq!(worker.history().len(), 2);
|
assert_eq!(worker.history().len(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ビルダーパターンでWorkerを構築できることを確認
|
/// Verify that Worker can be constructed using builder pattern
|
||||||
#[test]
|
#[test]
|
||||||
fn test_mutable_builder_pattern() {
|
fn test_mutable_builder_pattern() {
|
||||||
let client = MockLlmClient::new(vec![]);
|
let client = MockLlmClient::new(vec![]);
|
||||||
let worker = Worker::new(client)
|
let worker = Worker::new(client)
|
||||||
.system_prompt("System prompt")
|
.system_prompt("System prompt")
|
||||||
.with_message(Message::user("Hello"))
|
.with_item(Item::user_message("Hello"))
|
||||||
.with_message(Message::assistant("Hi!"))
|
.with_item(Item::assistant_message("Hi!"))
|
||||||
.with_messages(vec![
|
.with_items(vec![
|
||||||
Message::user("How are you?"),
|
Item::user_message("How are you?"),
|
||||||
Message::assistant("I'm fine!"),
|
Item::assistant_message("I'm fine!"),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
assert_eq!(worker.get_system_prompt(), Some("System prompt"));
|
assert_eq!(worker.get_system_prompt(), Some("System prompt"));
|
||||||
assert_eq!(worker.history().len(), 4);
|
assert_eq!(worker.history().len(), 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// extend_historyで複数メッセージを追加できることを確認
|
/// Verify that multiple items can be added with extend_history
|
||||||
#[test]
|
#[test]
|
||||||
fn test_mutable_extend_history() {
|
fn test_mutable_extend_history() {
|
||||||
let client = MockLlmClient::new(vec![]);
|
let client = MockLlmClient::new(vec![]);
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
|
|
||||||
worker.push_message(Message::user("First"));
|
worker.push_item(Item::user_message("First"));
|
||||||
|
|
||||||
worker.extend_history(vec![
|
worker.extend_history(vec![
|
||||||
Message::assistant("Response 1"),
|
Item::assistant_message("Response 1"),
|
||||||
Message::user("Second"),
|
Item::user_message("Second"),
|
||||||
Message::assistant("Response 2"),
|
Item::assistant_message("Response 2"),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
assert_eq!(worker.history().len(), 4);
|
assert_eq!(worker.history().len(), 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct CountingTool {
|
||||||
|
name: String,
|
||||||
|
calls: Arc<AtomicUsize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CountingTool {
|
||||||
|
fn new(name: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
name: name.into(),
|
||||||
|
calls: Arc::new(AtomicUsize::new(0)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn definition(&self) -> ToolDefinition {
|
||||||
|
let tool = self.clone();
|
||||||
|
Arc::new(move || {
|
||||||
|
(
|
||||||
|
ToolMeta::new(&tool.name)
|
||||||
|
.description("Counting tool")
|
||||||
|
.input_schema(serde_json::json!({"type":"object","properties":{}})),
|
||||||
|
Arc::new(tool.clone()) as Arc<dyn Tool>,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call_count(&self) -> usize {
|
||||||
|
self.calls.load(Ordering::SeqCst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for CountingTool {
|
||||||
|
async fn execute(&self, _input_json: &str) -> Result<String, ToolError> {
|
||||||
|
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||||
|
Ok(format!("{}-ok", self.name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verify that tools can be registered in Mutable state.
|
||||||
|
#[test]
|
||||||
|
fn test_mutable_can_register_tool() {
|
||||||
|
let client = MockLlmClient::new(vec![]);
|
||||||
|
let mut worker = Worker::new(client);
|
||||||
|
let tool = CountingTool::new("count_tool");
|
||||||
|
|
||||||
|
let result = worker.register_tool(tool.definition());
|
||||||
|
assert!(result.is_ok(), "Mutable should allow tool registration");
|
||||||
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// 状態遷移テスト
|
// State Transition Tests
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// lock()でMutable -> CacheLocked状態に遷移することを確認
|
/// Verify that lock() transitions from Mutable -> CacheLocked state
|
||||||
#[test]
|
#[test]
|
||||||
fn test_lock_transition() {
|
fn test_lock_transition() {
|
||||||
let client = MockLlmClient::new(vec![]);
|
let client = MockLlmClient::new(vec![]);
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
|
|
||||||
worker.set_system_prompt("System");
|
worker.set_system_prompt("System");
|
||||||
worker.push_message(Message::user("Hello"));
|
worker.push_item(Item::user_message("Hello"));
|
||||||
worker.push_message(Message::assistant("Hi"));
|
worker.push_item(Item::assistant_message("Hi"));
|
||||||
|
|
||||||
// ロック
|
// Lock
|
||||||
let locked_worker = worker.lock();
|
let locked_worker = worker.lock();
|
||||||
|
|
||||||
// CacheLocked状態でも履歴とシステムプロンプトにアクセス可能
|
// History and system prompt are still accessible in CacheLocked state
|
||||||
assert_eq!(locked_worker.get_system_prompt(), Some("System"));
|
assert_eq!(locked_worker.get_system_prompt(), Some("System"));
|
||||||
assert_eq!(locked_worker.history().len(), 2);
|
assert_eq!(locked_worker.history().len(), 2);
|
||||||
assert_eq!(locked_worker.locked_prefix_len(), 2);
|
assert_eq!(locked_worker.locked_prefix_len(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// unlock()でCacheLocked -> Mutable状態に遷移することを確認
|
/// Verify that unlock() transitions from CacheLocked -> Mutable state
|
||||||
#[test]
|
#[test]
|
||||||
fn test_unlock_transition() {
|
fn test_unlock_transition() {
|
||||||
let client = MockLlmClient::new(vec![]);
|
let client = MockLlmClient::new(vec![]);
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
|
|
||||||
worker.push_message(Message::user("Hello"));
|
worker.push_item(Item::user_message("Hello"));
|
||||||
let locked_worker = worker.lock();
|
let locked_worker = worker.lock();
|
||||||
|
|
||||||
// アンロック
|
// Unlock
|
||||||
let mut worker = locked_worker.unlock();
|
let mut worker = locked_worker.unlock();
|
||||||
|
|
||||||
// Mutable状態に戻ったので履歴操作が可能
|
// History operations are available again in Mutable state
|
||||||
worker.push_message(Message::assistant("Hi"));
|
worker.push_item(Item::assistant_message("Hi"));
|
||||||
worker.clear_history();
|
worker.clear_history();
|
||||||
assert!(worker.history().is_empty());
|
assert!(worker.history().is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// ターン実行と状態保持のテスト
|
// Turn Execution and State Preservation Tests
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// Mutable状態でターンを実行し、履歴が正しく更新されることを確認
|
/// Verify that history is correctly updated after running a turn in Mutable state
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_mutable_run_updates_history() {
|
async fn test_mutable_run_updates_history() {
|
||||||
let events = vec![
|
let events = vec![
|
||||||
|
|
@ -151,33 +206,27 @@ async fn test_mutable_run_updates_history() {
|
||||||
let client = MockLlmClient::new(events);
|
let client = MockLlmClient::new(events);
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
|
|
||||||
// 実行
|
// Execute
|
||||||
let result = worker.run("Hi there").await;
|
let result = worker.run("Hi there").await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
// 履歴が更新されている
|
// History is updated
|
||||||
let history = worker.history();
|
let history = worker.history();
|
||||||
assert_eq!(history.len(), 2); // user + assistant
|
assert_eq!(history.len(), 2); // user + assistant
|
||||||
|
|
||||||
// ユーザーメッセージ
|
// User message
|
||||||
assert!(matches!(
|
assert_eq!(history[0].as_text(), Some("Hi there"));
|
||||||
&history[0].content,
|
|
||||||
MessageContent::Text(t) if t == "Hi there"
|
|
||||||
));
|
|
||||||
|
|
||||||
// アシスタントメッセージ
|
// Assistant message
|
||||||
assert!(matches!(
|
assert_eq!(history[1].as_text(), Some("Hello, I'm an assistant!"));
|
||||||
&history[1].content,
|
|
||||||
MessageContent::Text(t) if t == "Hello, I'm an assistant!"
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// CacheLocked状態で複数ターンを実行し、履歴が正しく累積することを確認
|
/// Verify that history accumulates correctly over multiple turns in CacheLocked state
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_locked_multi_turn_history_accumulation() {
|
async fn test_locked_multi_turn_history_accumulation() {
|
||||||
// 2回のリクエストに対応するレスポンスを準備
|
// Prepare responses for 2 requests
|
||||||
let client = MockLlmClient::with_responses(vec![
|
let client = MockLlmClient::with_responses(vec![
|
||||||
// 1回目のレスポンス
|
// First response
|
||||||
vec![
|
vec![
|
||||||
Event::text_block_start(0),
|
Event::text_block_start(0),
|
||||||
Event::text_delta(0, "Nice to meet you!"),
|
Event::text_delta(0, "Nice to meet you!"),
|
||||||
|
|
@ -186,7 +235,7 @@ async fn test_locked_multi_turn_history_accumulation() {
|
||||||
status: ResponseStatus::Completed,
|
status: ResponseStatus::Completed,
|
||||||
}),
|
}),
|
||||||
],
|
],
|
||||||
// 2回目のレスポンス
|
// Second response
|
||||||
vec![
|
vec![
|
||||||
Event::text_block_start(0),
|
Event::text_block_start(0),
|
||||||
Event::text_delta(0, "I can help with that."),
|
Event::text_delta(0, "I can help with that."),
|
||||||
|
|
@ -199,37 +248,37 @@ async fn test_locked_multi_turn_history_accumulation() {
|
||||||
|
|
||||||
let worker = Worker::new(client).system_prompt("You are helpful.");
|
let worker = Worker::new(client).system_prompt("You are helpful.");
|
||||||
|
|
||||||
// ロック(システムプロンプト設定後)
|
// Lock (after setting system prompt)
|
||||||
let mut locked_worker = worker.lock();
|
let mut locked_worker = worker.lock();
|
||||||
assert_eq!(locked_worker.locked_prefix_len(), 0); // メッセージはまだない
|
assert_eq!(locked_worker.locked_prefix_len(), 0); // No items yet
|
||||||
|
|
||||||
// 1ターン目
|
// Turn 1
|
||||||
let result1 = locked_worker.run("Hello!").await;
|
let result1 = locked_worker.run("Hello!").await;
|
||||||
assert!(result1.is_ok());
|
assert!(result1.is_ok());
|
||||||
assert_eq!(locked_worker.history().len(), 2); // user + assistant
|
assert_eq!(locked_worker.history().len(), 2); // user + assistant
|
||||||
|
|
||||||
// 2ターン目
|
// Turn 2
|
||||||
let result2 = locked_worker.run("Can you help me?").await;
|
let result2 = locked_worker.run("Can you help me?").await;
|
||||||
assert!(result2.is_ok());
|
assert!(result2.is_ok());
|
||||||
assert_eq!(locked_worker.history().len(), 4); // 2 * (user + assistant)
|
assert_eq!(locked_worker.history().len(), 4); // 2 * (user + assistant)
|
||||||
|
|
||||||
// 履歴の内容を確認
|
// Verify history contents
|
||||||
let history = locked_worker.history();
|
let history = locked_worker.history();
|
||||||
|
|
||||||
// 1ターン目のユーザーメッセージ
|
// Turn 1 user message
|
||||||
assert!(matches!(&history[0].content, MessageContent::Text(t) if t == "Hello!"));
|
assert_eq!(history[0].as_text(), Some("Hello!"));
|
||||||
|
|
||||||
// 1ターン目のアシスタントメッセージ
|
// Turn 1 assistant message
|
||||||
assert!(matches!(&history[1].content, MessageContent::Text(t) if t == "Nice to meet you!"));
|
assert_eq!(history[1].as_text(), Some("Nice to meet you!"));
|
||||||
|
|
||||||
// 2ターン目のユーザーメッセージ
|
// Turn 2 user message
|
||||||
assert!(matches!(&history[2].content, MessageContent::Text(t) if t == "Can you help me?"));
|
assert_eq!(history[2].as_text(), Some("Can you help me?"));
|
||||||
|
|
||||||
// 2ターン目のアシスタントメッセージ
|
// Turn 2 assistant message
|
||||||
assert!(matches!(&history[3].content, MessageContent::Text(t) if t == "I can help with that."));
|
assert_eq!(history[3].as_text(), Some("I can help with that."));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// locked_prefix_lenがロック時点の履歴長を正しく記録することを確認
|
/// Verify that locked_prefix_len correctly records history length at lock time
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_locked_prefix_len_tracking() {
|
async fn test_locked_prefix_len_tracking() {
|
||||||
let client = MockLlmClient::with_responses(vec![
|
let client = MockLlmClient::with_responses(vec![
|
||||||
|
|
@ -253,25 +302,25 @@ async fn test_locked_prefix_len_tracking() {
|
||||||
|
|
||||||
let mut worker = Worker::new(client);
|
let mut worker = Worker::new(client);
|
||||||
|
|
||||||
// 事前にメッセージを追加
|
// Add items beforehand
|
||||||
worker.push_message(Message::user("Pre-existing message 1"));
|
worker.push_item(Item::user_message("Pre-existing message 1"));
|
||||||
worker.push_message(Message::assistant("Pre-existing response 1"));
|
worker.push_item(Item::assistant_message("Pre-existing response 1"));
|
||||||
|
|
||||||
assert_eq!(worker.history().len(), 2);
|
assert_eq!(worker.history().len(), 2);
|
||||||
|
|
||||||
// ロック
|
// Lock
|
||||||
let mut locked_worker = worker.lock();
|
let mut locked_worker = worker.lock();
|
||||||
assert_eq!(locked_worker.locked_prefix_len(), 2); // ロック時点で2メッセージ
|
assert_eq!(locked_worker.locked_prefix_len(), 2); // 2 items at lock time
|
||||||
|
|
||||||
// ターン実行
|
// Execute turn
|
||||||
locked_worker.run("New message").await.unwrap();
|
locked_worker.run("New message").await.unwrap();
|
||||||
|
|
||||||
// 履歴は増えるが、locked_prefix_lenは変わらない
|
// History grows but locked_prefix_len remains unchanged
|
||||||
assert_eq!(locked_worker.history().len(), 4); // 2 + 2
|
assert_eq!(locked_worker.history().len(), 4); // 2 + 2
|
||||||
assert_eq!(locked_worker.locked_prefix_len(), 2); // 変わらない
|
assert_eq!(locked_worker.locked_prefix_len(), 2); // Unchanged
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ターンカウントが正しくインクリメントされることを確認
|
/// Verify that turn count is correctly incremented
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_turn_count_increment() {
|
async fn test_turn_count_increment() {
|
||||||
let client = MockLlmClient::with_responses(vec![
|
let client = MockLlmClient::with_responses(vec![
|
||||||
|
|
@ -304,7 +353,7 @@ async fn test_turn_count_increment() {
|
||||||
assert_eq!(worker.turn_count(), 2);
|
assert_eq!(worker.turn_count(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// unlock後に履歴を編集し、再度lockできることを確認
|
/// Verify that history can be edited after unlock and re-locked
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_unlock_edit_relock() {
|
async fn test_unlock_edit_relock() {
|
||||||
let client = MockLlmClient::with_responses(vec![vec![
|
let client = MockLlmClient::with_responses(vec![vec![
|
||||||
|
|
@ -317,30 +366,91 @@ async fn test_unlock_edit_relock() {
|
||||||
]]);
|
]]);
|
||||||
|
|
||||||
let worker = Worker::new(client)
|
let worker = Worker::new(client)
|
||||||
.with_message(Message::user("Hello"))
|
.with_item(Item::user_message("Hello"))
|
||||||
.with_message(Message::assistant("Hi"));
|
.with_item(Item::assistant_message("Hi"));
|
||||||
|
|
||||||
// ロック -> アンロック
|
// Lock -> Unlock
|
||||||
let locked = worker.lock();
|
let locked = worker.lock();
|
||||||
assert_eq!(locked.locked_prefix_len(), 2);
|
assert_eq!(locked.locked_prefix_len(), 2);
|
||||||
|
|
||||||
let mut unlocked = locked.unlock();
|
let mut unlocked = locked.unlock();
|
||||||
|
|
||||||
// 履歴を編集
|
// Edit history
|
||||||
unlocked.clear_history();
|
unlocked.clear_history();
|
||||||
unlocked.push_message(Message::user("Fresh start"));
|
unlocked.push_item(Item::user_message("Fresh start"));
|
||||||
|
|
||||||
// 再ロック
|
// Re-lock
|
||||||
let relocked = unlocked.lock();
|
let relocked = unlocked.lock();
|
||||||
assert_eq!(relocked.history().len(), 1);
|
assert_eq!(relocked.history().len(), 1);
|
||||||
assert_eq!(relocked.locked_prefix_len(), 1);
|
assert_eq!(relocked.locked_prefix_len(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Verify that tools registered before lock and after unlock remain effective.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_lock_unlock_relock_tools_remain_effective() {
|
||||||
|
let client = MockLlmClient::with_responses(vec![
|
||||||
|
vec![
|
||||||
|
Event::tool_use_start(0, "call_1", "tool_a"),
|
||||||
|
Event::tool_input_delta(0, r#"{}"#),
|
||||||
|
Event::tool_use_stop(0),
|
||||||
|
Event::Status(StatusEvent {
|
||||||
|
status: ResponseStatus::Completed,
|
||||||
|
}),
|
||||||
|
],
|
||||||
|
vec![
|
||||||
|
Event::text_block_start(0),
|
||||||
|
Event::text_delta(0, "done-a"),
|
||||||
|
Event::text_block_stop(0, None),
|
||||||
|
Event::Status(StatusEvent {
|
||||||
|
status: ResponseStatus::Completed,
|
||||||
|
}),
|
||||||
|
],
|
||||||
|
vec![
|
||||||
|
Event::tool_use_start(0, "call_2", "tool_b"),
|
||||||
|
Event::tool_input_delta(0, r#"{}"#),
|
||||||
|
Event::tool_use_stop(0),
|
||||||
|
Event::Status(StatusEvent {
|
||||||
|
status: ResponseStatus::Completed,
|
||||||
|
}),
|
||||||
|
],
|
||||||
|
vec![
|
||||||
|
Event::text_block_start(0),
|
||||||
|
Event::text_delta(0, "done-b"),
|
||||||
|
Event::text_block_stop(0, None),
|
||||||
|
Event::Status(StatusEvent {
|
||||||
|
status: ResponseStatus::Completed,
|
||||||
|
}),
|
||||||
|
],
|
||||||
|
]);
|
||||||
|
|
||||||
|
let mut worker = Worker::new(client);
|
||||||
|
let tool_a = CountingTool::new("tool_a");
|
||||||
|
worker
|
||||||
|
.register_tool(tool_a.definition())
|
||||||
|
.expect("register tool_a should succeed");
|
||||||
|
|
||||||
|
let mut locked = worker.lock();
|
||||||
|
locked.run("first").await.expect("first run");
|
||||||
|
assert_eq!(tool_a.call_count(), 1, "tool_a should be called once");
|
||||||
|
|
||||||
|
let mut unlocked = locked.unlock();
|
||||||
|
let tool_b = CountingTool::new("tool_b");
|
||||||
|
unlocked
|
||||||
|
.register_tool(tool_b.definition())
|
||||||
|
.expect("register tool_b after unlock should succeed");
|
||||||
|
|
||||||
|
let mut relocked = unlocked.lock();
|
||||||
|
relocked.run("second").await.expect("second run");
|
||||||
|
|
||||||
|
assert_eq!(tool_a.call_count(), 1, "tool_a should not be called again");
|
||||||
|
assert_eq!(tool_b.call_count(), 1, "tool_b should be called once");
|
||||||
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// システムプロンプト保持のテスト
|
// System Prompt Preservation Tests
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
/// CacheLocked状態でもシステムプロンプトが保持されることを確認
|
/// Verify that system prompt is preserved in CacheLocked state
|
||||||
#[test]
|
#[test]
|
||||||
fn test_system_prompt_preserved_in_locked_state() {
|
fn test_system_prompt_preserved_in_locked_state() {
|
||||||
let client = MockLlmClient::new(vec![]);
|
let client = MockLlmClient::new(vec![]);
|
||||||
|
|
@ -356,7 +466,7 @@ fn test_system_prompt_preserved_in_locked_state() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// unlock -> 再lock でシステムプロンプトを変更できることを確認
|
/// Verify that system prompt can be changed after unlock -> re-lock
|
||||||
#[test]
|
#[test]
|
||||||
fn test_system_prompt_change_after_unlock() {
|
fn test_system_prompt_change_after_unlock() {
|
||||||
let client = MockLlmClient::new(vec![]);
|
let client = MockLlmClient::new(vec![]);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user