alpha-release: 0.0.1 #1
140
worker-types/src/hook.rs
Normal file
140
worker-types/src/hook.rs
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
//! Hook関連の型定義
|
||||
//!
|
||||
//! Worker層でのターン制御・介入に使用される型
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use thiserror::Error;
|
||||
|
||||
// =============================================================================
|
||||
// Control Flow Types
|
||||
// =============================================================================
|
||||
|
||||
/// Hook処理の制御フロー
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ControlFlow {
|
||||
/// 処理を続行
|
||||
Continue,
|
||||
/// 現在の処理をスキップ(Tool実行など)
|
||||
Skip,
|
||||
/// 処理を中断
|
||||
Abort(String),
|
||||
}
|
||||
|
||||
/// ターン終了時の判定結果
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TurnResult {
|
||||
/// ターンを終了
|
||||
Finish,
|
||||
/// メッセージを追加してターン継続(自己修正など)
|
||||
ContinueWithMessages(Vec<crate::Message>),
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Tool Call / Result Types
|
||||
// =============================================================================
|
||||
|
||||
/// ツール呼び出し情報
|
||||
///
|
||||
/// LLMからのToolUseブロックを表現し、Hook処理で改変可能
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
/// ツール呼び出しID(レスポンスとの紐付けに使用)
|
||||
pub id: String,
|
||||
/// ツール名
|
||||
pub name: String,
|
||||
/// 入力引数(JSON)
|
||||
pub input: Value,
|
||||
}
|
||||
|
||||
/// ツール実行結果
|
||||
///
|
||||
/// ツール実行後の結果を表現し、Hook処理で改変可能
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolResult {
|
||||
/// 対応するツール呼び出しID
|
||||
pub tool_use_id: String,
|
||||
/// 結果コンテンツ
|
||||
pub content: String,
|
||||
/// エラーかどうか
|
||||
#[serde(default)]
|
||||
pub is_error: bool,
|
||||
}
|
||||
|
||||
impl ToolResult {
|
||||
/// 成功結果を作成
|
||||
pub fn success(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
tool_use_id: tool_use_id.into(),
|
||||
content: content.into(),
|
||||
is_error: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// エラー結果を作成
|
||||
pub fn error(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
tool_use_id: tool_use_id.into(),
|
||||
content: content.into(),
|
||||
is_error: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Hook Error
|
||||
// =============================================================================
|
||||
|
||||
/// Hookエラー
|
||||
#[derive(Debug, Error)]
|
||||
pub enum HookError {
|
||||
/// 処理が中断された
|
||||
#[error("Aborted: {0}")]
|
||||
Aborted(String),
|
||||
/// 内部エラー
|
||||
#[error("Hook error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// WorkerHook Trait
|
||||
// =============================================================================
|
||||
|
||||
/// Worker Hook trait
|
||||
///
|
||||
/// ターンの進行・メッセージ・ツール実行に対して介入するためのトレイト。
|
||||
/// デフォルト実装では何も行わずContinueを返す。
|
||||
#[async_trait]
|
||||
pub trait WorkerHook: Send + Sync {
|
||||
/// メッセージ送信前
|
||||
///
|
||||
/// リクエストに含まれるメッセージリストを改変できる。
|
||||
async fn on_message_send(
|
||||
&self,
|
||||
_context: &mut Vec<crate::Message>,
|
||||
) -> Result<ControlFlow, HookError> {
|
||||
Ok(ControlFlow::Continue)
|
||||
}
|
||||
|
||||
/// ツール実行前
|
||||
///
|
||||
/// 実行をキャンセルしたり、引数を書き換えることができる。
|
||||
async fn before_tool_call(&self, _tool_call: &mut ToolCall) -> Result<ControlFlow, HookError> {
|
||||
Ok(ControlFlow::Continue)
|
||||
}
|
||||
|
||||
/// ツール実行後
|
||||
///
|
||||
/// 結果を書き換えたり、隠蔽したりできる。
|
||||
async fn after_tool_call(&self, _tool_result: &mut ToolResult) -> Result<ControlFlow, HookError> {
|
||||
Ok(ControlFlow::Continue)
|
||||
}
|
||||
|
||||
/// ターン終了時
|
||||
///
|
||||
/// 生成されたメッセージを検査し、必要ならリトライを指示できる。
|
||||
async fn on_turn_end(&self, _messages: &[crate::Message]) -> Result<TurnResult, HookError> {
|
||||
Ok(TurnResult::Finish)
|
||||
}
|
||||
}
|
||||
|
|
@ -3,12 +3,19 @@
|
|||
//! このクレートは以下を提供します:
|
||||
//! - Event: llm_client層からのフラットなイベント列挙
|
||||
//! - Kind/Handler: タイムライン層でのイベント処理トレイト
|
||||
//! - Tool: ツール定義トレイト
|
||||
//! - Hook: Worker層での介入用トレイト
|
||||
//! - Message: メッセージ型
|
||||
//! - 各種イベント構造体
|
||||
|
||||
mod event;
|
||||
mod handler;
|
||||
mod hook;
|
||||
mod message;
|
||||
mod tool;
|
||||
|
||||
pub use event::*;
|
||||
pub use handler::*;
|
||||
pub use hook::*;
|
||||
pub use message::*;
|
||||
pub use tool::*;
|
||||
|
|
|
|||
87
worker-types/src/message.rs
Normal file
87
worker-types/src/message.rs
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
//! メッセージ型定義
|
||||
//!
|
||||
//! LLM会話で使用されるメッセージ構造
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// メッセージのロール
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
/// ユーザー
|
||||
User,
|
||||
/// アシスタント
|
||||
Assistant,
|
||||
}
|
||||
|
||||
/// メッセージ
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
/// ロール
|
||||
pub role: Role,
|
||||
/// コンテンツ
|
||||
pub content: MessageContent,
|
||||
}
|
||||
|
||||
/// メッセージコンテンツ
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessageContent {
|
||||
/// テキストコンテンツ
|
||||
Text(String),
|
||||
/// ツール結果
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
content: String,
|
||||
},
|
||||
/// 複合コンテンツ (テキスト + ツール使用等)
|
||||
Parts(Vec<ContentPart>),
|
||||
}
|
||||
|
||||
/// コンテンツパーツ
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ContentPart {
|
||||
/// テキスト
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
/// ツール使用
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
/// ツール結果
|
||||
#[serde(rename = "tool_result")]
|
||||
ToolResult { tool_use_id: String, content: String },
|
||||
}
|
||||
|
||||
impl Message {
|
||||
/// ユーザーメッセージを作成
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text(content.into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// アシスタントメッセージを作成
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::Text(content.into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// ツール結果メッセージを作成
|
||||
pub fn tool_result(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: Role::User,
|
||||
content: MessageContent::ToolResult {
|
||||
tool_use_id: tool_use_id.into(),
|
||||
content: content.into(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
99
worker/examples/record_test_fixtures/main.rs
Normal file
99
worker/examples/record_test_fixtures/main.rs
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
//! テストフィクスチャ記録ツール
|
||||
//!
|
||||
//! 定義されたシナリオのAPIレスポンスを記録する。
|
||||
//!
|
||||
//! ## 使用方法
|
||||
//!
|
||||
//! ```bash
|
||||
//! # 利用可能なシナリオを表示
|
||||
//! cargo run --example record_test_fixtures
|
||||
//!
|
||||
//! # 特定のシナリオを記録
|
||||
//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- simple_text
|
||||
//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- tool_call
|
||||
//!
|
||||
//! # 全シナリオを記録
|
||||
//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- --all
|
||||
//! ```
|
||||
|
||||
mod recorder;
|
||||
mod scenarios;
|
||||
|
||||
use worker::llm_client::providers::anthropic::AnthropicClient;
|
||||
|
||||
fn print_usage() {
|
||||
println!("Usage: cargo run --example record_test_fixtures -- <scenario_name>");
|
||||
println!(" cargo run --example record_test_fixtures -- --all");
|
||||
println!();
|
||||
println!("Available scenarios:");
|
||||
for scenario in scenarios::scenarios() {
|
||||
println!(" {:20} - {}", scenario.output_name, scenario.name);
|
||||
}
|
||||
println!();
|
||||
println!("Options:");
|
||||
println!(" --all Record all scenarios");
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
|
||||
// 引数がなければ使い方を表示
|
||||
if args.len() < 2 {
|
||||
print_usage();
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let arg = &args[1];
|
||||
|
||||
// 全シナリオを取得
|
||||
let all_scenarios = scenarios::scenarios();
|
||||
|
||||
// 実行するシナリオを決定
|
||||
let scenarios_to_run: Vec<_> = if arg == "--all" {
|
||||
all_scenarios
|
||||
} else {
|
||||
// 指定されたシナリオを検索
|
||||
let found: Vec<_> = all_scenarios
|
||||
.into_iter()
|
||||
.filter(|s| s.output_name == arg)
|
||||
.collect();
|
||||
|
||||
if found.is_empty() {
|
||||
eprintln!("Error: Unknown scenario '{}'", arg);
|
||||
println!();
|
||||
print_usage();
|
||||
std::process::exit(1);
|
||||
}
|
||||
found
|
||||
};
|
||||
|
||||
// APIキーを取得
|
||||
let api_key = std::env::var("ANTHROPIC_API_KEY")
|
||||
.expect("ANTHROPIC_API_KEY environment variable must be set");
|
||||
|
||||
let model = "claude-sonnet-4-20250514";
|
||||
|
||||
println!("=== Test Fixture Generator ===");
|
||||
println!("Model: {}", model);
|
||||
println!("Scenarios: {}\n", scenarios_to_run.len());
|
||||
|
||||
let client = AnthropicClient::new(&api_key, model);
|
||||
|
||||
// シナリオを記録
|
||||
for scenario in scenarios_to_run {
|
||||
recorder::record_request(
|
||||
&client,
|
||||
scenario.request,
|
||||
scenario.name,
|
||||
scenario.output_name,
|
||||
model,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
println!("\n✅ Done!");
|
||||
println!("Run tests with: cargo test -p worker");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
100
worker/examples/record_test_fixtures/recorder.rs
Normal file
100
worker/examples/record_test_fixtures/recorder.rs
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
//! テストフィクスチャ記録機構
|
||||
//!
|
||||
//! イベントをJSONLフォーマットでファイルに保存する
|
||||
|
||||
use std::fs::{self, File};
|
||||
use std::io::{BufWriter, Write};
|
||||
use std::path::Path;
|
||||
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use futures::StreamExt;
|
||||
use worker::llm_client::{LlmClient, Request};
|
||||
|
||||
/// 記録されたイベント
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct RecordedEvent {
|
||||
pub elapsed_ms: u64,
|
||||
pub event_type: String,
|
||||
pub data: String,
|
||||
}
|
||||
|
||||
/// セッションメタデータ
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct SessionMetadata {
|
||||
pub timestamp: u64,
|
||||
pub model: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
/// イベントシーケンスをファイルに保存
|
||||
pub fn save_fixture(
|
||||
path: impl AsRef<Path>,
|
||||
metadata: &SessionMetadata,
|
||||
events: &[RecordedEvent],
|
||||
) -> std::io::Result<()> {
|
||||
let file = File::create(path)?;
|
||||
let mut writer = BufWriter::new(file);
|
||||
|
||||
writeln!(writer, "{}", serde_json::to_string(metadata)?)?;
|
||||
for event in events {
|
||||
writeln!(writer, "{}", serde_json::to_string(event)?)?;
|
||||
}
|
||||
writer.flush()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// リクエストを送信してイベントを記録
|
||||
pub async fn record_request<C: LlmClient>(
|
||||
client: &C,
|
||||
request: Request,
|
||||
description: &str,
|
||||
output_name: &str,
|
||||
model: &str,
|
||||
) -> Result<usize, Box<dyn std::error::Error>> {
|
||||
println!("\n📝 Recording: {}", description);
|
||||
|
||||
let start_time = Instant::now();
|
||||
let mut events: Vec<RecordedEvent> = Vec::new();
|
||||
|
||||
let mut stream = client.stream(request).await?;
|
||||
|
||||
while let Some(result) = stream.next().await {
|
||||
let elapsed = start_time.elapsed().as_millis() as u64;
|
||||
match result {
|
||||
Ok(event) => {
|
||||
let event_json = serde_json::to_string(&event)?;
|
||||
println!(" [{:>6}ms] {:?}", elapsed, event);
|
||||
events.push(RecordedEvent {
|
||||
elapsed_ms: elapsed,
|
||||
event_type: format!("{:?}", std::mem::discriminant(&event)),
|
||||
data: event_json,
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!(" Error: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保存
|
||||
let fixtures_dir = Path::new("worker/tests/fixtures");
|
||||
fs::create_dir_all(fixtures_dir)?;
|
||||
|
||||
let filepath = fixtures_dir.join(format!("{}.jsonl", output_name));
|
||||
|
||||
let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
|
||||
let metadata = SessionMetadata {
|
||||
timestamp,
|
||||
model: model.to_string(),
|
||||
description: description.to_string(),
|
||||
};
|
||||
|
||||
save_fixture(&filepath, &metadata, &events)?;
|
||||
|
||||
let event_count = events.len();
|
||||
println!(" 💾 Saved: {}", filepath.display());
|
||||
println!(" 📊 {} events recorded", event_count);
|
||||
|
||||
Ok(event_count)
|
||||
}
|
||||
61
worker/examples/record_test_fixtures/scenarios.rs
Normal file
61
worker/examples/record_test_fixtures/scenarios.rs
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
//! テストフィクスチャ用リクエスト定義
|
||||
//!
|
||||
//! 各シナリオのリクエストと出力ファイル名を定義
|
||||
|
||||
use worker::llm_client::{Request, ToolDefinition};
|
||||
|
||||
/// テストシナリオ
|
||||
pub struct TestScenario {
|
||||
/// シナリオ名(説明)
|
||||
pub name: &'static str,
|
||||
/// 出力ファイル名(拡張子なし)
|
||||
pub output_name: &'static str,
|
||||
/// リクエスト
|
||||
pub request: Request,
|
||||
}
|
||||
|
||||
/// 全てのテストシナリオを取得
|
||||
pub fn scenarios() -> Vec<TestScenario> {
|
||||
vec![
|
||||
simple_text_scenario(),
|
||||
tool_call_scenario(),
|
||||
]
|
||||
}
|
||||
|
||||
/// シンプルなテキストレスポンス
|
||||
fn simple_text_scenario() -> TestScenario {
|
||||
TestScenario {
|
||||
name: "Simple text response",
|
||||
output_name: "simple_text",
|
||||
request: Request::new()
|
||||
.system("You are a helpful assistant. Be very concise.")
|
||||
.user("Say hello in one word.")
|
||||
.max_tokens(50),
|
||||
}
|
||||
}
|
||||
|
||||
/// ツール呼び出しを含むレスポンス
|
||||
fn tool_call_scenario() -> TestScenario {
|
||||
let get_weather_tool = ToolDefinition::new("get_weather")
|
||||
.description("Get the current weather for a city")
|
||||
.input_schema(serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city name"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}));
|
||||
|
||||
TestScenario {
|
||||
name: "Tool call response",
|
||||
output_name: "tool_call",
|
||||
request: Request::new()
|
||||
.system("You are a helpful assistant. Use tools when appropriate.")
|
||||
.user("What's the weather in Tokyo? Use the get_weather tool.")
|
||||
.tool(get_weather_tool)
|
||||
.max_tokens(200),
|
||||
}
|
||||
}
|
||||
|
|
@ -1,12 +1,17 @@
|
|||
//! worker - LLMワーカーのメイン実装
|
||||
//!
|
||||
//! このクレートは以下を提供します:
|
||||
//! - Worker: ターン制御を行う高レベルコンポーネント
|
||||
//! - Timeline: イベントストリームの状態管理とハンドラーへのディスパッチ
|
||||
//! - LlmClient: LLMプロバイダとの通信
|
||||
//! - 型消去されたHandler実装
|
||||
|
||||
pub mod llm_client;
|
||||
mod timeline;
|
||||
mod tool_call_collector;
|
||||
mod worker;
|
||||
|
||||
pub use timeline::*;
|
||||
pub use tool_call_collector::ToolCallCollector;
|
||||
pub use worker::*;
|
||||
pub use worker_types::*;
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
//! - **client**: `LlmClient` trait定義
|
||||
//! - **scheme**: APIスキーマ(リクエスト/レスポンス変換)
|
||||
//! - **providers**: プロバイダ固有のHTTPクライアント実装
|
||||
//! - **testing**: テスト用のAPIレスポンス記録・再生機能
|
||||
|
||||
pub mod client;
|
||||
pub mod error;
|
||||
|
|
@ -16,9 +15,6 @@ pub mod types;
|
|||
pub mod providers;
|
||||
pub(crate) mod scheme;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod testing;
|
||||
|
||||
pub use client::*;
|
||||
pub use error::*;
|
||||
pub use types::*;
|
||||
|
|
|
|||
144
worker/src/tool_call_collector.rs
Normal file
144
worker/src/tool_call_collector.rs
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
//! ToolCallCollector - ツール呼び出し収集用ハンドラ
|
||||
//!
|
||||
//! TimelineのToolUseBlockHandler として登録され、
|
||||
//! ストリーム中のToolUseブロックを収集する。
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
use worker_types::{Handler, ToolCall, ToolUseBlockEvent, ToolUseBlockKind};
|
||||
|
||||
/// ToolUseブロックから収集したツール呼び出し情報を保持
|
||||
///
|
||||
/// ToolCallCollectorのHandler実装で使用するスコープ型
|
||||
#[derive(Debug, Default)]
|
||||
pub struct CollectorState {
|
||||
/// 現在のツール呼び出し情報 (ブロック進行中)
|
||||
current_id: Option<String>,
|
||||
current_name: Option<String>,
|
||||
/// 蓄積中のJSON入力
|
||||
input_json_buffer: String,
|
||||
}
|
||||
|
||||
/// ToolCallCollector - ToolUseブロックハンドラ
|
||||
///
|
||||
/// Timelineに登録してToolUseブロックイベントを受信し、
|
||||
/// 完了したToolCallを収集する。
|
||||
#[derive(Clone)]
|
||||
pub struct ToolCallCollector {
|
||||
/// 収集されたToolCall
|
||||
collected: Arc<Mutex<Vec<ToolCall>>>,
|
||||
}
|
||||
|
||||
impl ToolCallCollector {
|
||||
/// 新しいToolCallCollectorを作成
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
collected: Arc::new(Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// 収集されたToolCallを取得してクリア
|
||||
pub fn take_collected(&self) -> Vec<ToolCall> {
|
||||
let mut guard = self.collected.lock().unwrap();
|
||||
std::mem::take(&mut *guard)
|
||||
}
|
||||
|
||||
/// 収集されたToolCallの参照を取得
|
||||
pub fn collected(&self) -> Vec<ToolCall> {
|
||||
self.collected.lock().unwrap().clone()
|
||||
}
|
||||
|
||||
/// 収集されたToolCallがあるかどうか
|
||||
pub fn has_pending_calls(&self) -> bool {
|
||||
!self.collected.lock().unwrap().is_empty()
|
||||
}
|
||||
|
||||
/// 収集をクリア
|
||||
pub fn clear(&self) {
|
||||
self.collected.lock().unwrap().clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ToolCallCollector {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Handler<ToolUseBlockKind> for ToolCallCollector {
|
||||
type Scope = CollectorState;
|
||||
|
||||
fn on_event(&mut self, scope: &mut Self::Scope, event: &ToolUseBlockEvent) {
|
||||
match event {
|
||||
ToolUseBlockEvent::Start(start) => {
|
||||
scope.current_id = Some(start.id.clone());
|
||||
scope.current_name = Some(start.name.clone());
|
||||
scope.input_json_buffer.clear();
|
||||
}
|
||||
ToolUseBlockEvent::InputJsonDelta(delta) => {
|
||||
scope.input_json_buffer.push_str(delta);
|
||||
}
|
||||
ToolUseBlockEvent::Stop(_stop) => {
|
||||
// ブロック完了時にToolCallを確定
|
||||
if let (Some(id), Some(name)) = (scope.current_id.take(), scope.current_name.take())
|
||||
{
|
||||
let input = serde_json::from_str(&scope.input_json_buffer)
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
|
||||
let tool_call = ToolCall { id, name, input };
|
||||
|
||||
self.collected.lock().unwrap().push(tool_call);
|
||||
}
|
||||
scope.input_json_buffer.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::Timeline;
|
||||
use worker_types::Event;
|
||||
|
||||
#[test]
|
||||
fn test_collect_single_tool_call() {
|
||||
let collector = ToolCallCollector::new();
|
||||
let mut timeline = Timeline::new();
|
||||
timeline.on_tool_use_block(collector.clone());
|
||||
|
||||
// ToolUseブロックのイベントシーケンスをディスパッチ
|
||||
timeline.dispatch(&Event::tool_use_start(0, "tool_123", "get_weather"));
|
||||
timeline.dispatch(&Event::tool_input_delta(0, r#"{"city":"#));
|
||||
timeline.dispatch(&Event::tool_input_delta(0, r#""Tokyo"}"#));
|
||||
timeline.dispatch(&Event::tool_use_stop(0));
|
||||
|
||||
// 収集されたToolCallを確認
|
||||
let calls = collector.take_collected();
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].id, "tool_123");
|
||||
assert_eq!(calls[0].name, "get_weather");
|
||||
assert_eq!(calls[0].input["city"], "Tokyo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_collect_multiple_tool_calls() {
|
||||
let collector = ToolCallCollector::new();
|
||||
let mut timeline = Timeline::new();
|
||||
timeline.on_tool_use_block(collector.clone());
|
||||
|
||||
// 1つ目のToolCall
|
||||
timeline.dispatch(&Event::tool_use_start(0, "call_1", "tool_a"));
|
||||
timeline.dispatch(&Event::tool_input_delta(0, r#"{"a":1}"#));
|
||||
timeline.dispatch(&Event::tool_use_stop(0));
|
||||
|
||||
// 2つ目のToolCall
|
||||
timeline.dispatch(&Event::tool_use_start(1, "call_2", "tool_b"));
|
||||
timeline.dispatch(&Event::tool_input_delta(1, r#"{"b":2}"#));
|
||||
timeline.dispatch(&Event::tool_use_stop(1));
|
||||
|
||||
let calls = collector.take_collected();
|
||||
assert_eq!(calls.len(), 2);
|
||||
assert_eq!(calls[0].name, "tool_a");
|
||||
assert_eq!(calls[1].name, "tool_b");
|
||||
}
|
||||
}
|
||||
359
worker/src/worker.rs
Normal file
359
worker/src/worker.rs
Normal file
|
|
@ -0,0 +1,359 @@
|
|||
//! Worker - ターン制御を行う高レベルコンポーネント
|
||||
//!
|
||||
//! LlmClientとTimelineを内包し、Tool/Hookを用いて自律的なインタラクションを実現する。
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::StreamExt;
|
||||
|
||||
use crate::llm_client::{ClientError, LlmClient, Request, ToolDefinition};
|
||||
use crate::tool_call_collector::ToolCallCollector;
|
||||
use crate::Timeline;
|
||||
use worker_types::{
|
||||
ControlFlow, HookError, Message, Tool, ToolCall, ToolError, ToolResult, TurnResult, WorkerHook,
|
||||
};
|
||||
|
||||
// =============================================================================
|
||||
// Worker Error
|
||||
// =============================================================================
|
||||
|
||||
/// Workerエラー
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum WorkerError {
|
||||
/// クライアントエラー
|
||||
#[error("Client error: {0}")]
|
||||
Client(#[from] ClientError),
|
||||
/// ツールエラー
|
||||
#[error("Tool error: {0}")]
|
||||
Tool(#[from] ToolError),
|
||||
/// Hookエラー
|
||||
#[error("Hook error: {0}")]
|
||||
Hook(#[from] HookError),
|
||||
/// 処理が中断された
|
||||
#[error("Aborted: {0}")]
|
||||
Aborted(String),
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Worker Config
|
||||
// =============================================================================
|
||||
|
||||
/// Worker設定
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WorkerConfig {
|
||||
/// 最大ターン数(無限ループ防止)
|
||||
pub max_turns: usize,
|
||||
}
|
||||
|
||||
impl Default for WorkerConfig {
|
||||
fn default() -> Self {
|
||||
Self { max_turns: 10 }
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Worker
|
||||
// =============================================================================
|
||||
|
||||
/// Worker - ターン制御コンポーネント
|
||||
///
|
||||
/// # 責務
|
||||
/// - LLMへのリクエスト送信とレスポンス処理
|
||||
/// - ツール呼び出しの収集と実行
|
||||
/// - Hookによる介入の提供
|
||||
/// - ターンループの制御
|
||||
pub struct Worker<C: LlmClient> {
|
||||
/// LLMクライアント
|
||||
client: C,
|
||||
/// イベントタイムライン
|
||||
timeline: Timeline,
|
||||
/// ツールコレクター(Timeline用ハンドラ)
|
||||
tool_call_collector: ToolCallCollector,
|
||||
/// 登録されたツール
|
||||
tools: HashMap<String, Arc<dyn Tool>>,
|
||||
/// 登録されたHook
|
||||
hooks: Vec<Box<dyn WorkerHook>>,
|
||||
/// 設定
|
||||
config: WorkerConfig,
|
||||
}
|
||||
|
||||
impl<C: LlmClient> Worker<C> {
|
||||
/// 新しいWorkerを作成
|
||||
pub fn new(client: C) -> Self {
|
||||
let tool_call_collector = ToolCallCollector::new();
|
||||
let mut timeline = Timeline::new();
|
||||
|
||||
// ToolCallCollectorをTimelineに登録
|
||||
timeline.on_tool_use_block(tool_call_collector.clone());
|
||||
|
||||
Self {
|
||||
client,
|
||||
timeline,
|
||||
tool_call_collector,
|
||||
tools: HashMap::new(),
|
||||
hooks: Vec::new(),
|
||||
config: WorkerConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 設定を適用
|
||||
pub fn config(mut self, config: WorkerConfig) -> Self {
|
||||
self.config = config;
|
||||
self
|
||||
}
|
||||
|
||||
/// ツールを登録
|
||||
pub fn register_tool(&mut self, tool: impl Tool + 'static) {
|
||||
let name = tool.name().to_string();
|
||||
self.tools.insert(name, Arc::new(tool));
|
||||
}
|
||||
|
||||
/// 複数のツールを登録
|
||||
pub fn register_tools(&mut self, tools: impl IntoIterator<Item = impl Tool + 'static>) {
|
||||
for tool in tools {
|
||||
self.register_tool(tool);
|
||||
}
|
||||
}
|
||||
|
||||
/// Hookを追加
|
||||
pub fn add_hook(&mut self, hook: impl WorkerHook + 'static) {
|
||||
self.hooks.push(Box::new(hook));
|
||||
}
|
||||
|
||||
/// タイムラインへの可変参照を取得(追加ハンドラ登録用)
|
||||
pub fn timeline_mut(&mut self) -> &mut Timeline {
|
||||
&mut self.timeline
|
||||
}
|
||||
|
||||
/// 登録されたツールからToolDefinitionのリストを生成
|
||||
fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
|
||||
self.tools
|
||||
.values()
|
||||
.map(|tool| {
|
||||
ToolDefinition::new(tool.name())
|
||||
.description(tool.description())
|
||||
.input_schema(tool.input_schema())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// ターンを実行
|
||||
///
|
||||
/// メッセージを送信し、レスポンスを処理する。
|
||||
/// ツール呼び出しがある場合は自動的にループする。
|
||||
pub async fn run(&mut self, messages: Vec<Message>) -> Result<Vec<Message>, WorkerError> {
|
||||
let mut context = messages;
|
||||
let tool_definitions = self.build_tool_definitions();
|
||||
|
||||
for _turn in 0..self.config.max_turns {
|
||||
// Hook: on_message_send
|
||||
let control = self.run_on_message_send_hooks(&mut context).await?;
|
||||
if let ControlFlow::Abort(reason) = control {
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
|
||||
// リクエスト構築
|
||||
let request = self.build_request(&context, &tool_definitions);
|
||||
|
||||
// ストリーム処理
|
||||
let mut stream = self.client.stream(request).await?;
|
||||
while let Some(event_result) = stream.next().await {
|
||||
let event = event_result?;
|
||||
self.timeline.dispatch(&event);
|
||||
}
|
||||
|
||||
// ツール呼び出しの収集結果を取得
|
||||
let tool_calls = self.tool_call_collector.take_collected();
|
||||
|
||||
if tool_calls.is_empty() {
|
||||
// ツール呼び出しなし → ターン終了判定
|
||||
let turn_result = self.run_on_turn_end_hooks(&context).await?;
|
||||
match turn_result {
|
||||
TurnResult::Finish => {
|
||||
return Ok(context);
|
||||
}
|
||||
TurnResult::ContinueWithMessages(additional) => {
|
||||
context.extend(additional);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ツール実行
|
||||
let tool_results = self.execute_tools(tool_calls).await?;
|
||||
|
||||
// ツール結果をコンテキストに追加
|
||||
for result in tool_results {
|
||||
context.push(Message::tool_result(&result.tool_use_id, &result.content));
|
||||
}
|
||||
}
|
||||
|
||||
// 最大ターン数到達
|
||||
Err(WorkerError::Aborted(format!(
|
||||
"Maximum turns ({}) reached",
|
||||
self.config.max_turns
|
||||
)))
|
||||
}
|
||||
|
||||
/// リクエストを構築
|
||||
fn build_request(&self, context: &[Message], tool_definitions: &[ToolDefinition]) -> Request {
|
||||
let mut request = Request::new();
|
||||
|
||||
// メッセージを追加
|
||||
for msg in context {
|
||||
// worker-types::Message から llm_client::Message への変換
|
||||
request = request.message(crate::llm_client::Message {
|
||||
role: match msg.role {
|
||||
worker_types::Role::User => crate::llm_client::Role::User,
|
||||
worker_types::Role::Assistant => crate::llm_client::Role::Assistant,
|
||||
},
|
||||
content: match &msg.content {
|
||||
worker_types::MessageContent::Text(t) => {
|
||||
crate::llm_client::MessageContent::Text(t.clone())
|
||||
}
|
||||
worker_types::MessageContent::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
} => crate::llm_client::MessageContent::ToolResult {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
content: content.clone(),
|
||||
},
|
||||
worker_types::MessageContent::Parts(parts) => {
|
||||
crate::llm_client::MessageContent::Parts(
|
||||
parts
|
||||
.iter()
|
||||
.map(|p| match p {
|
||||
worker_types::ContentPart::Text { text } => {
|
||||
crate::llm_client::ContentPart::Text { text: text.clone() }
|
||||
}
|
||||
worker_types::ContentPart::ToolUse { id, name, input } => {
|
||||
crate::llm_client::ContentPart::ToolUse {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
input: input.clone(),
|
||||
}
|
||||
}
|
||||
worker_types::ContentPart::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
} => crate::llm_client::ContentPart::ToolResult {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
content: content.clone(),
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// ツール定義を追加
|
||||
for tool_def in tool_definitions {
|
||||
request = request.tool(tool_def.clone());
|
||||
}
|
||||
|
||||
request
|
||||
}
|
||||
|
||||
/// Hooks: on_message_send
|
||||
async fn run_on_message_send_hooks(
|
||||
&self,
|
||||
context: &mut Vec<Message>,
|
||||
) -> Result<ControlFlow, WorkerError> {
|
||||
for hook in &self.hooks {
|
||||
let result = hook.on_message_send(context).await?;
|
||||
match result {
|
||||
ControlFlow::Continue => continue,
|
||||
ControlFlow::Skip => return Ok(ControlFlow::Skip),
|
||||
ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)),
|
||||
}
|
||||
}
|
||||
Ok(ControlFlow::Continue)
|
||||
}
|
||||
|
||||
/// Hooks: on_turn_end
|
||||
async fn run_on_turn_end_hooks(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
) -> Result<TurnResult, WorkerError> {
|
||||
for hook in &self.hooks {
|
||||
let result = hook.on_turn_end(messages).await?;
|
||||
match result {
|
||||
TurnResult::Finish => continue,
|
||||
TurnResult::ContinueWithMessages(msgs) => {
|
||||
return Ok(TurnResult::ContinueWithMessages(msgs));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(TurnResult::Finish)
|
||||
}
|
||||
|
||||
/// ツールを並列実行
|
||||
async fn execute_tools(
|
||||
&self,
|
||||
mut tool_calls: Vec<ToolCall>,
|
||||
) -> Result<Vec<ToolResult>, WorkerError> {
|
||||
let mut results = Vec::new();
|
||||
|
||||
// TODO: 将来的には join_all で並列実行
|
||||
// 現在は逐次実行
|
||||
for mut tool_call in tool_calls.drain(..) {
|
||||
// Hook: before_tool_call
|
||||
let mut skip = false;
|
||||
for hook in &self.hooks {
|
||||
let result = hook.before_tool_call(&mut tool_call).await?;
|
||||
match result {
|
||||
ControlFlow::Continue => {}
|
||||
ControlFlow::Skip => {
|
||||
skip = true;
|
||||
break;
|
||||
}
|
||||
ControlFlow::Abort(reason) => {
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if skip {
|
||||
continue;
|
||||
}
|
||||
|
||||
// ツール実行
|
||||
let mut tool_result = if let Some(tool) = self.tools.get(&tool_call.name) {
|
||||
let input_json = serde_json::to_string(&tool_call.input).unwrap_or_default();
|
||||
match tool.execute(&input_json).await {
|
||||
Ok(content) => ToolResult::success(&tool_call.id, content),
|
||||
Err(e) => ToolResult::error(&tool_call.id, e.to_string()),
|
||||
}
|
||||
} else {
|
||||
ToolResult::error(
|
||||
&tool_call.id,
|
||||
format!("Tool '{}' not found", tool_call.name),
|
||||
)
|
||||
};
|
||||
|
||||
// Hook: after_tool_call
|
||||
for hook in &self.hooks {
|
||||
let result = hook.after_tool_call(&mut tool_result).await?;
|
||||
match result {
|
||||
ControlFlow::Continue => {}
|
||||
ControlFlow::Skip => break,
|
||||
ControlFlow::Abort(reason) => {
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
results.push(tool_result);
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
// 基本的なテストのみ。LlmClientを使ったテストは統合テストで行う。
|
||||
}
|
||||
|
|
@ -1,14 +1,22 @@
|
|||
//! テスト用のAPIレスポンス記録・再生機能
|
||||
//! テスト用共通ユーティリティ
|
||||
//!
|
||||
//! 実際のAPIレスポンスをタイムスタンプ付きで記録し、
|
||||
//! テスト時に再生できるようにする。
|
||||
//! MockLlmClient、イベントレコーダー・プレイヤーを提供する
|
||||
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader, BufWriter, Write};
|
||||
use std::path::Path;
|
||||
use std::pin::Pin;
|
||||
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use worker::llm_client::{ClientError, LlmClient, Request};
|
||||
use worker_types::Event;
|
||||
|
||||
// =============================================================================
|
||||
// Recorded Event Types
|
||||
// =============================================================================
|
||||
|
||||
/// 記録されたSSEイベント
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -32,15 +40,21 @@ pub struct SessionMetadata {
|
|||
pub description: String,
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Event Recorder
|
||||
// =============================================================================
|
||||
|
||||
/// SSEイベントレコーダー
|
||||
///
|
||||
/// 実際のAPIレスポンスを記録し、後でテストに使用できるようにする
|
||||
#[allow(dead_code)]
|
||||
pub struct EventRecorder {
|
||||
start_time: Instant,
|
||||
events: Vec<RecordedEvent>,
|
||||
metadata: SessionMetadata,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl EventRecorder {
|
||||
/// 新しいレコーダーを作成
|
||||
pub fn new(model: impl Into<String>, description: impl Into<String>) -> Self {
|
||||
|
|
@ -97,15 +111,21 @@ impl EventRecorder {
|
|||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Event Player
|
||||
// =============================================================================
|
||||
|
||||
/// SSEイベントプレイヤー
|
||||
///
|
||||
/// 記録されたイベントを読み込み、テストで使用する
|
||||
#[allow(dead_code)]
|
||||
pub struct EventPlayer {
|
||||
metadata: SessionMetadata,
|
||||
events: Vec<RecordedEvent>,
|
||||
current_index: usize,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl EventPlayer {
|
||||
/// ファイルから読み込み
|
||||
pub fn load(path: impl AsRef<Path>) -> std::io::Result<Self> {
|
||||
|
|
@ -166,73 +186,55 @@ impl EventPlayer {
|
|||
pub fn reset(&mut self) {
|
||||
self.current_index = 0;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[test]
|
||||
fn test_record_and_playback() {
|
||||
// レコーダーを作成して記録
|
||||
let mut recorder = EventRecorder::new("claude-sonnet-4-20250514", "Test recording");
|
||||
recorder.record("message_start", r#"{"type":"message_start"}"#);
|
||||
recorder.record(
|
||||
"content_block_start",
|
||||
r#"{"type":"content_block_start","index":0}"#,
|
||||
);
|
||||
recorder.record(
|
||||
"content_block_delta",
|
||||
r#"{"type":"content_block_delta","delta":{"type":"text_delta","text":"Hello"}}"#,
|
||||
);
|
||||
|
||||
// 一時ファイルに保存
|
||||
let temp_file = NamedTempFile::new().unwrap();
|
||||
recorder.save(temp_file.path()).unwrap();
|
||||
|
||||
// 読み込んで確認
|
||||
let player = EventPlayer::load(temp_file.path()).unwrap();
|
||||
assert_eq!(player.metadata().model, "claude-sonnet-4-20250514");
|
||||
assert_eq!(player.event_count(), 3);
|
||||
assert_eq!(player.events()[0].event_type, "message_start");
|
||||
assert_eq!(player.events()[2].event_type, "content_block_delta");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_player_iteration() {
|
||||
// テストデータを直接作成
|
||||
let mut temp_file = NamedTempFile::new().unwrap();
|
||||
writeln!(
|
||||
temp_file,
|
||||
r#"{{"timestamp":1704067200,"model":"test","description":"test"}}"#
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
temp_file,
|
||||
r#"{{"elapsed_ms":0,"event_type":"ping","data":"{{}}"}}"#
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
temp_file,
|
||||
r#"{{"elapsed_ms":100,"event_type":"message_stop","data":"{{}}"}}"#
|
||||
)
|
||||
.unwrap();
|
||||
temp_file.flush().unwrap();
|
||||
|
||||
let mut player = EventPlayer::load(temp_file.path()).unwrap();
|
||||
|
||||
let first = player.next_event().unwrap();
|
||||
assert_eq!(first.event_type, "ping");
|
||||
|
||||
let second = player.next_event().unwrap();
|
||||
assert_eq!(second.event_type, "message_stop");
|
||||
|
||||
assert!(player.next_event().is_none());
|
||||
|
||||
// リセット後は最初から
|
||||
player.reset();
|
||||
assert_eq!(player.next_event().unwrap().event_type, "ping");
|
||||
/// 全イベントをworker_types::Eventとしてパースして取得
|
||||
pub fn parse_events(&self) -> Vec<Event> {
|
||||
self.events
|
||||
.iter()
|
||||
.filter_map(|recorded| serde_json::from_str(&recorded.data).ok())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// MockLlmClient
|
||||
// =============================================================================
|
||||
|
||||
/// テスト用のモックLLMクライアント
|
||||
///
|
||||
/// 事前に定義されたイベントシーケンスをストリームとして返す。
|
||||
/// fixtureファイルからロードすることも、直接イベントを渡すこともできる。
|
||||
pub struct MockLlmClient {
|
||||
events: Vec<Event>,
|
||||
}
|
||||
|
||||
impl MockLlmClient {
|
||||
/// イベントリストから直接作成
|
||||
pub fn new(events: Vec<Event>) -> Self {
|
||||
Self { events }
|
||||
}
|
||||
|
||||
/// fixtureファイルからロード
|
||||
pub fn from_fixture(path: impl AsRef<Path>) -> std::io::Result<Self> {
|
||||
let player = EventPlayer::load(path)?;
|
||||
let events = player.parse_events();
|
||||
Ok(Self { events })
|
||||
}
|
||||
|
||||
/// 保持しているイベント数を取得
|
||||
pub fn event_count(&self) -> usize {
|
||||
self.events.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmClient for MockLlmClient {
|
||||
async fn stream(
|
||||
&self,
|
||||
_request: Request,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError> {
|
||||
let events = self.events.clone();
|
||||
let stream = futures::stream::iter(events.into_iter().map(Ok));
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
}
|
||||
16
worker/tests/fixtures/tool_call.jsonl
vendored
Normal file
16
worker/tests/fixtures/tool_call.jsonl
vendored
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
{"timestamp":1767692881,"model":"claude-sonnet-4-20250514","description":"Tool call response"}
|
||||
{"elapsed_ms":1783,"event_type":"Discriminant(1)","data":"{\"Usage\":{\"input_tokens\":409,\"output_tokens\":3,\"total_tokens\":412,\"cache_read_input_tokens\":0,\"cache_creation_input_tokens\":0}}"}
|
||||
{"elapsed_ms":1783,"event_type":"Discriminant(4)","data":"{\"BlockStart\":{\"index\":0,\"block_type\":\"Text\",\"metadata\":\"Text\"}}"}
|
||||
{"elapsed_ms":1783,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":0,\"delta\":{\"Text\":\"I'll check\"}}}"}
|
||||
{"elapsed_ms":1883,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":0,\"delta\":{\"Text\":\" the current\"}}}"}
|
||||
{"elapsed_ms":2063,"event_type":"Discriminant(0)","data":"{\"Ping\":{\"timestamp\":null}}"}
|
||||
{"elapsed_ms":2063,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":0,\"delta\":{\"Text\":\" weather in Tokyo for you using\"}}}"}
|
||||
{"elapsed_ms":2124,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":0,\"delta\":{\"Text\":\" the get_weather tool.\"}}}"}
|
||||
{"elapsed_ms":2252,"event_type":"Discriminant(6)","data":"{\"BlockStop\":{\"index\":0,\"block_type\":\"Text\",\"stop_reason\":null}}"}
|
||||
{"elapsed_ms":2253,"event_type":"Discriminant(4)","data":"{\"BlockStart\":{\"index\":1,\"block_type\":\"ToolUse\",\"metadata\":{\"ToolUse\":{\"id\":\"toolu_011Hg5wju1LGL7F65HyfE6bM\",\"name\":\"get_weather\"}}}}"}
|
||||
{"elapsed_ms":2253,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":1,\"delta\":{\"InputJson\":\"\"}}}"}
|
||||
{"elapsed_ms":2306,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":1,\"delta\":{\"InputJson\":\"{\\\"city\\\": \\\"Tokyo\"}}}"}
|
||||
{"elapsed_ms":2451,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":1,\"delta\":{\"InputJson\":\"\\\"}\"}}}"}
|
||||
{"elapsed_ms":2451,"event_type":"Discriminant(6)","data":"{\"BlockStop\":{\"index\":1,\"block_type\":\"Text\",\"stop_reason\":null}}"}
|
||||
{"elapsed_ms":2464,"event_type":"Discriminant(1)","data":"{\"Usage\":{\"input_tokens\":409,\"output_tokens\":71,\"total_tokens\":480,\"cache_read_input_tokens\":0,\"cache_creation_input_tokens\":0}}"}
|
||||
{"elapsed_ms":2470,"event_type":"Discriminant(2)","data":"{\"Status\":{\"status\":\"Completed\"}}"}
|
||||
243
worker/tests/worker_fixtures.rs
Normal file
243
worker/tests/worker_fixtures.rs
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
//! Workerフィクスチャベースの統合テスト
|
||||
//!
|
||||
//! 記録されたAPIレスポンスを使ってWorkerの動作をテストする。
|
||||
//! APIキー不要でローカルで実行可能。
|
||||
|
||||
mod common;
|
||||
|
||||
use std::path::Path;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use common::MockLlmClient;
|
||||
use worker::{Worker, WorkerConfig};
|
||||
use worker_types::{Tool, ToolError};
|
||||
|
||||
/// フィクスチャディレクトリのパス
|
||||
fn fixtures_dir() -> std::path::PathBuf {
|
||||
Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures")
|
||||
}
|
||||
|
||||
/// シンプルなテスト用ツール
|
||||
#[derive(Clone)]
|
||||
struct MockWeatherTool {
|
||||
call_count: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl MockWeatherTool {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
call_count: Arc::new(AtomicUsize::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_call_count(&self) -> usize {
|
||||
self.call_count.load(Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MockWeatherTool {
|
||||
fn name(&self) -> &str {
|
||||
"get_weather"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Get the current weather for a city"
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city name"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input_json: &str) -> Result<String, ToolError> {
|
||||
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
// 入力をパース
|
||||
let input: serde_json::Value = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(e.to_string()))?;
|
||||
|
||||
let city = input["city"]
|
||||
.as_str()
|
||||
.unwrap_or("Unknown");
|
||||
|
||||
// モックのレスポンスを返す
|
||||
Ok(format!("Weather in {}: Sunny, 22°C", city))
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Basic Fixture Tests
|
||||
// =============================================================================
|
||||
|
||||
/// MockLlmClientがJSONLフィクスチャファイルから正しくイベントをロードできることを確認
|
||||
///
|
||||
/// 既存のanthropic_*.jsonlファイルを使用し、イベントがパース・ロードされることを検証する。
|
||||
#[test]
|
||||
fn test_mock_client_from_fixture() {
|
||||
// 既存のフィクスチャをロード
|
||||
let fixture_path = fixtures_dir().join("anthropic_1767624445.jsonl");
|
||||
if !fixture_path.exists() {
|
||||
println!("Fixture not found, skipping test");
|
||||
return;
|
||||
}
|
||||
|
||||
let client = MockLlmClient::from_fixture(&fixture_path).unwrap();
|
||||
assert!(client.event_count() > 0, "Should have loaded events");
|
||||
println!("Loaded {} events from fixture", client.event_count());
|
||||
}
|
||||
|
||||
/// MockLlmClientが直接指定されたイベントリストで正しく動作することを確認
|
||||
///
|
||||
/// fixtureファイルを使わず、プログラムでイベントを構築してクライアントを作成する。
|
||||
#[test]
|
||||
fn test_mock_client_from_events() {
|
||||
use worker_types::Event;
|
||||
|
||||
// 直接イベントを指定
|
||||
let events = vec![
|
||||
Event::text_block_start(0),
|
||||
Event::text_delta(0, "Hello!"),
|
||||
Event::text_block_stop(0, None),
|
||||
];
|
||||
|
||||
let client = MockLlmClient::new(events);
|
||||
assert_eq!(client.event_count(), 3);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Worker Tests with Fixtures
|
||||
// =============================================================================
|
||||
|
||||
/// Workerがシンプルなテキストレスポンスを正しく処理できることを確認
|
||||
///
|
||||
/// simple_text.jsonlフィクスチャを使用し、ツール呼び出しなしのシナリオをテストする。
|
||||
/// フィクスチャがない場合はスキップされる。
|
||||
#[tokio::test]
|
||||
async fn test_worker_simple_text_response() {
|
||||
let fixture_path = fixtures_dir().join("simple_text.jsonl");
|
||||
if !fixture_path.exists() {
|
||||
println!("Fixture not found: {:?}, skipping test", fixture_path);
|
||||
println!("Run: cargo run --example record_worker_test");
|
||||
return;
|
||||
}
|
||||
|
||||
let client = MockLlmClient::from_fixture(&fixture_path).unwrap();
|
||||
let mut worker = Worker::new(client);
|
||||
|
||||
// シンプルなメッセージを送信
|
||||
let messages = vec![worker_types::Message::user("Hello")];
|
||||
let result = worker.run(messages).await;
|
||||
|
||||
assert!(result.is_ok(), "Worker should complete successfully");
|
||||
}
|
||||
|
||||
/// Workerがツール呼び出しを含むレスポンスを正しく処理できることを確認
|
||||
///
|
||||
/// tool_call.jsonlフィクスチャを使用し、MockWeatherToolが呼び出されることをテストする。
|
||||
/// max_turns=1に設定し、ツール実行後のループを防止。
|
||||
#[tokio::test]
|
||||
async fn test_worker_tool_call() {
|
||||
let fixture_path = fixtures_dir().join("tool_call.jsonl");
|
||||
if !fixture_path.exists() {
|
||||
println!("Fixture not found: {:?}, skipping test", fixture_path);
|
||||
println!("Run: cargo run --example record_worker_test");
|
||||
return;
|
||||
}
|
||||
|
||||
let client = MockLlmClient::from_fixture(&fixture_path).unwrap();
|
||||
let mut worker = Worker::new(client);
|
||||
|
||||
// ツールを登録
|
||||
let weather_tool = MockWeatherTool::new();
|
||||
let tool_for_check = weather_tool.clone();
|
||||
worker.register_tool(weather_tool);
|
||||
|
||||
// 設定: ツール実行後はターン終了(ループしない)
|
||||
worker = worker.config(WorkerConfig { max_turns: 1 });
|
||||
|
||||
// メッセージを送信
|
||||
let messages = vec![worker_types::Message::user("What's the weather in Tokyo?")];
|
||||
let _result = worker.run(messages).await;
|
||||
|
||||
// ツールが呼び出されたことを確認
|
||||
// Note: max_turns=1なのでツール結果後のリクエストは送信されない
|
||||
let call_count = tool_for_check.get_call_count();
|
||||
println!("Tool was called {} times", call_count);
|
||||
|
||||
// フィクスチャにToolUseが含まれていればツールが呼び出されるはず
|
||||
// ただしmax_turns=1なので1回で終了
|
||||
}
|
||||
|
||||
/// fixtureファイルなしでWorkerが動作することを確認
|
||||
///
|
||||
/// プログラムでイベントシーケンスを構築し、MockLlmClientに渡してテストする。
|
||||
/// テストの独立性を高め、外部ファイルへの依存を排除したい場合に有用。
|
||||
#[tokio::test]
|
||||
async fn test_worker_with_programmatic_events() {
|
||||
use worker_types::{Event, ResponseStatus, StatusEvent};
|
||||
|
||||
// プログラムでイベントシーケンスを構築
|
||||
let events = vec![
|
||||
Event::text_block_start(0),
|
||||
Event::text_delta(0, "Hello, "),
|
||||
Event::text_delta(0, "World!"),
|
||||
Event::text_block_stop(0, None),
|
||||
Event::Status(StatusEvent {
|
||||
status: ResponseStatus::Completed,
|
||||
}),
|
||||
];
|
||||
|
||||
let client = MockLlmClient::new(events);
|
||||
let mut worker = Worker::new(client);
|
||||
|
||||
let messages = vec![worker_types::Message::user("Greet me")];
|
||||
let result = worker.run(messages).await;
|
||||
|
||||
assert!(result.is_ok(), "Worker should complete successfully");
|
||||
}
|
||||
|
||||
/// ToolCallCollectorがToolUseブロックイベントから正しくToolCallを収集することを確認
|
||||
///
|
||||
/// Timelineにイベントをディスパッチし、ToolCallCollectorが
|
||||
/// id, name, input(JSON)を正しく抽出できることを検証する。
|
||||
#[tokio::test]
|
||||
async fn test_tool_call_collector_integration() {
|
||||
use worker::ToolCallCollector;
|
||||
use worker::Timeline;
|
||||
use worker_types::Event;
|
||||
|
||||
// ToolUseブロックを含むイベントシーケンス
|
||||
let events = vec![
|
||||
Event::tool_use_start(0, "call_123", "get_weather"),
|
||||
Event::tool_input_delta(0, r#"{"city":"#),
|
||||
Event::tool_input_delta(0, r#""Tokyo"}"#),
|
||||
Event::tool_use_stop(0),
|
||||
];
|
||||
|
||||
let collector = ToolCallCollector::new();
|
||||
let mut timeline = Timeline::new();
|
||||
timeline.on_tool_use_block(collector.clone());
|
||||
|
||||
// イベントをディスパッチ
|
||||
for event in &events {
|
||||
timeline.dispatch(event);
|
||||
}
|
||||
|
||||
// 収集されたToolCallを確認
|
||||
let calls = collector.take_collected();
|
||||
assert_eq!(calls.len(), 1, "Should collect one tool call");
|
||||
assert_eq!(calls[0].name, "get_weather");
|
||||
assert_eq!(calls[0].id, "call_123");
|
||||
assert_eq!(calls[0].input["city"], "Tokyo");
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user