diff --git a/worker-types/src/hook.rs b/worker-types/src/hook.rs new file mode 100644 index 0000000..7c3dd35 --- /dev/null +++ b/worker-types/src/hook.rs @@ -0,0 +1,140 @@ +//! Hook関連の型定義 +//! +//! Worker層でのターン制御・介入に使用される型 + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use thiserror::Error; + +// ============================================================================= +// Control Flow Types +// ============================================================================= + +/// Hook処理の制御フロー +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ControlFlow { + /// 処理を続行 + Continue, + /// 現在の処理をスキップ(Tool実行など) + Skip, + /// 処理を中断 + Abort(String), +} + +/// ターン終了時の判定結果 +#[derive(Debug, Clone)] +pub enum TurnResult { + /// ターンを終了 + Finish, + /// メッセージを追加してターン継続(自己修正など) + ContinueWithMessages(Vec), +} + +// ============================================================================= +// Tool Call / Result Types +// ============================================================================= + +/// ツール呼び出し情報 +/// +/// LLMからのToolUseブロックを表現し、Hook処理で改変可能 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + /// ツール呼び出しID(レスポンスとの紐付けに使用) + pub id: String, + /// ツール名 + pub name: String, + /// 入力引数(JSON) + pub input: Value, +} + +/// ツール実行結果 +/// +/// ツール実行後の結果を表現し、Hook処理で改変可能 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolResult { + /// 対応するツール呼び出しID + pub tool_use_id: String, + /// 結果コンテンツ + pub content: String, + /// エラーかどうか + #[serde(default)] + pub is_error: bool, +} + +impl ToolResult { + /// 成功結果を作成 + pub fn success(tool_use_id: impl Into, content: impl Into) -> Self { + Self { + tool_use_id: tool_use_id.into(), + content: content.into(), + is_error: false, + } + } + + /// エラー結果を作成 + pub fn error(tool_use_id: impl Into, content: impl Into) -> Self { + Self { + tool_use_id: tool_use_id.into(), + content: content.into(), + is_error: true, + } + } +} + +// ============================================================================= +// Hook Error +// ============================================================================= + +/// Hookエラー +#[derive(Debug, Error)] +pub enum HookError { + /// 処理が中断された + #[error("Aborted: {0}")] + Aborted(String), + /// 内部エラー + #[error("Hook error: {0}")] + Internal(String), +} + +// ============================================================================= +// WorkerHook Trait +// ============================================================================= + +/// Worker Hook trait +/// +/// ターンの進行・メッセージ・ツール実行に対して介入するためのトレイト。 +/// デフォルト実装では何も行わずContinueを返す。 +#[async_trait] +pub trait WorkerHook: Send + Sync { + /// メッセージ送信前 + /// + /// リクエストに含まれるメッセージリストを改変できる。 + async fn on_message_send( + &self, + _context: &mut Vec, + ) -> Result { + Ok(ControlFlow::Continue) + } + + /// ツール実行前 + /// + /// 実行をキャンセルしたり、引数を書き換えることができる。 + async fn before_tool_call(&self, _tool_call: &mut ToolCall) -> Result { + Ok(ControlFlow::Continue) + } + + /// ツール実行後 + /// + /// 結果を書き換えたり、隠蔽したりできる。 + async fn after_tool_call(&self, _tool_result: &mut ToolResult) -> Result { + Ok(ControlFlow::Continue) + } + + /// ターン終了時 + /// + /// 生成されたメッセージを検査し、必要ならリトライを指示できる。 + async fn on_turn_end(&self, _messages: &[crate::Message]) -> Result { + Ok(TurnResult::Finish) + } +} diff --git a/worker-types/src/lib.rs b/worker-types/src/lib.rs index 72b3291..5626913 100644 --- a/worker-types/src/lib.rs +++ b/worker-types/src/lib.rs @@ -3,12 +3,19 @@ //! このクレートは以下を提供します: //! - Event: llm_client層からのフラットなイベント列挙 //! - Kind/Handler: タイムライン層でのイベント処理トレイト +//! - Tool: ツール定義トレイト +//! - Hook: Worker層での介入用トレイト +//! - Message: メッセージ型 //! - 各種イベント構造体 mod event; mod handler; +mod hook; +mod message; mod tool; pub use event::*; pub use handler::*; +pub use hook::*; +pub use message::*; pub use tool::*; diff --git a/worker-types/src/message.rs b/worker-types/src/message.rs new file mode 100644 index 0000000..6981842 --- /dev/null +++ b/worker-types/src/message.rs @@ -0,0 +1,87 @@ +//! メッセージ型定義 +//! +//! LLM会話で使用されるメッセージ構造 + +use serde::{Deserialize, Serialize}; + +/// メッセージのロール +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Role { + /// ユーザー + User, + /// アシスタント + Assistant, +} + +/// メッセージ +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + /// ロール + pub role: Role, + /// コンテンツ + pub content: MessageContent, +} + +/// メッセージコンテンツ +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum MessageContent { + /// テキストコンテンツ + Text(String), + /// ツール結果 + ToolResult { + tool_use_id: String, + content: String, + }, + /// 複合コンテンツ (テキスト + ツール使用等) + Parts(Vec), +} + +/// コンテンツパーツ +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ContentPart { + /// テキスト + #[serde(rename = "text")] + Text { text: String }, + /// ツール使用 + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, + /// ツール結果 + #[serde(rename = "tool_result")] + ToolResult { tool_use_id: String, content: String }, +} + +impl Message { + /// ユーザーメッセージを作成 + pub fn user(content: impl Into) -> Self { + Self { + role: Role::User, + content: MessageContent::Text(content.into()), + } + } + + /// アシスタントメッセージを作成 + pub fn assistant(content: impl Into) -> Self { + Self { + role: Role::Assistant, + content: MessageContent::Text(content.into()), + } + } + + /// ツール結果メッセージを作成 + pub fn tool_result(tool_use_id: impl Into, content: impl Into) -> Self { + Self { + role: Role::User, + content: MessageContent::ToolResult { + tool_use_id: tool_use_id.into(), + content: content.into(), + }, + } + } +} diff --git a/worker/examples/record_test_fixtures/main.rs b/worker/examples/record_test_fixtures/main.rs new file mode 100644 index 0000000..7b4dcfd --- /dev/null +++ b/worker/examples/record_test_fixtures/main.rs @@ -0,0 +1,99 @@ +//! テストフィクスチャ記録ツール +//! +//! 定義されたシナリオのAPIレスポンスを記録する。 +//! +//! ## 使用方法 +//! +//! ```bash +//! # 利用可能なシナリオを表示 +//! cargo run --example record_test_fixtures +//! +//! # 特定のシナリオを記録 +//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- simple_text +//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- tool_call +//! +//! # 全シナリオを記録 +//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- --all +//! ``` + +mod recorder; +mod scenarios; + +use worker::llm_client::providers::anthropic::AnthropicClient; + +fn print_usage() { + println!("Usage: cargo run --example record_test_fixtures -- "); + println!(" cargo run --example record_test_fixtures -- --all"); + println!(); + println!("Available scenarios:"); + for scenario in scenarios::scenarios() { + println!(" {:20} - {}", scenario.output_name, scenario.name); + } + println!(); + println!("Options:"); + println!(" --all Record all scenarios"); +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let args: Vec = std::env::args().collect(); + + // 引数がなければ使い方を表示 + if args.len() < 2 { + print_usage(); + return Ok(()); + } + + let arg = &args[1]; + + // 全シナリオを取得 + let all_scenarios = scenarios::scenarios(); + + // 実行するシナリオを決定 + let scenarios_to_run: Vec<_> = if arg == "--all" { + all_scenarios + } else { + // 指定されたシナリオを検索 + let found: Vec<_> = all_scenarios + .into_iter() + .filter(|s| s.output_name == arg) + .collect(); + + if found.is_empty() { + eprintln!("Error: Unknown scenario '{}'", arg); + println!(); + print_usage(); + std::process::exit(1); + } + found + }; + + // APIキーを取得 + let api_key = std::env::var("ANTHROPIC_API_KEY") + .expect("ANTHROPIC_API_KEY environment variable must be set"); + + let model = "claude-sonnet-4-20250514"; + + println!("=== Test Fixture Generator ==="); + println!("Model: {}", model); + println!("Scenarios: {}\n", scenarios_to_run.len()); + + let client = AnthropicClient::new(&api_key, model); + + // シナリオを記録 + for scenario in scenarios_to_run { + recorder::record_request( + &client, + scenario.request, + scenario.name, + scenario.output_name, + model, + ) + .await?; + } + + println!("\n✅ Done!"); + println!("Run tests with: cargo test -p worker"); + + Ok(()) +} diff --git a/worker/examples/record_test_fixtures/recorder.rs b/worker/examples/record_test_fixtures/recorder.rs new file mode 100644 index 0000000..7a159bc --- /dev/null +++ b/worker/examples/record_test_fixtures/recorder.rs @@ -0,0 +1,100 @@ +//! テストフィクスチャ記録機構 +//! +//! イベントをJSONLフォーマットでファイルに保存する + +use std::fs::{self, File}; +use std::io::{BufWriter, Write}; +use std::path::Path; +use std::time::{Instant, SystemTime, UNIX_EPOCH}; + +use futures::StreamExt; +use worker::llm_client::{LlmClient, Request}; + +/// 記録されたイベント +#[derive(Debug, serde::Serialize, serde::Deserialize)] +pub struct RecordedEvent { + pub elapsed_ms: u64, + pub event_type: String, + pub data: String, +} + +/// セッションメタデータ +#[derive(Debug, serde::Serialize, serde::Deserialize)] +pub struct SessionMetadata { + pub timestamp: u64, + pub model: String, + pub description: String, +} + +/// イベントシーケンスをファイルに保存 +pub fn save_fixture( + path: impl AsRef, + metadata: &SessionMetadata, + events: &[RecordedEvent], +) -> std::io::Result<()> { + let file = File::create(path)?; + let mut writer = BufWriter::new(file); + + writeln!(writer, "{}", serde_json::to_string(metadata)?)?; + for event in events { + writeln!(writer, "{}", serde_json::to_string(event)?)?; + } + writer.flush()?; + Ok(()) +} + +/// リクエストを送信してイベントを記録 +pub async fn record_request( + client: &C, + request: Request, + description: &str, + output_name: &str, + model: &str, +) -> Result> { + println!("\n📝 Recording: {}", description); + + let start_time = Instant::now(); + let mut events: Vec = Vec::new(); + + let mut stream = client.stream(request).await?; + + while let Some(result) = stream.next().await { + let elapsed = start_time.elapsed().as_millis() as u64; + match result { + Ok(event) => { + let event_json = serde_json::to_string(&event)?; + println!(" [{:>6}ms] {:?}", elapsed, event); + events.push(RecordedEvent { + elapsed_ms: elapsed, + event_type: format!("{:?}", std::mem::discriminant(&event)), + data: event_json, + }); + } + Err(e) => { + eprintln!(" Error: {}", e); + break; + } + } + } + + // 保存 + let fixtures_dir = Path::new("worker/tests/fixtures"); + fs::create_dir_all(fixtures_dir)?; + + let filepath = fixtures_dir.join(format!("{}.jsonl", output_name)); + + let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(); + let metadata = SessionMetadata { + timestamp, + model: model.to_string(), + description: description.to_string(), + }; + + save_fixture(&filepath, &metadata, &events)?; + + let event_count = events.len(); + println!(" 💾 Saved: {}", filepath.display()); + println!(" 📊 {} events recorded", event_count); + + Ok(event_count) +} diff --git a/worker/examples/record_test_fixtures/scenarios.rs b/worker/examples/record_test_fixtures/scenarios.rs new file mode 100644 index 0000000..c7be964 --- /dev/null +++ b/worker/examples/record_test_fixtures/scenarios.rs @@ -0,0 +1,61 @@ +//! テストフィクスチャ用リクエスト定義 +//! +//! 各シナリオのリクエストと出力ファイル名を定義 + +use worker::llm_client::{Request, ToolDefinition}; + +/// テストシナリオ +pub struct TestScenario { + /// シナリオ名(説明) + pub name: &'static str, + /// 出力ファイル名(拡張子なし) + pub output_name: &'static str, + /// リクエスト + pub request: Request, +} + +/// 全てのテストシナリオを取得 +pub fn scenarios() -> Vec { + vec![ + simple_text_scenario(), + tool_call_scenario(), + ] +} + +/// シンプルなテキストレスポンス +fn simple_text_scenario() -> TestScenario { + TestScenario { + name: "Simple text response", + output_name: "simple_text", + request: Request::new() + .system("You are a helpful assistant. Be very concise.") + .user("Say hello in one word.") + .max_tokens(50), + } +} + +/// ツール呼び出しを含むレスポンス +fn tool_call_scenario() -> TestScenario { + let get_weather_tool = ToolDefinition::new("get_weather") + .description("Get the current weather for a city") + .input_schema(serde_json::json!({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name" + } + }, + "required": ["city"] + })); + + TestScenario { + name: "Tool call response", + output_name: "tool_call", + request: Request::new() + .system("You are a helpful assistant. Use tools when appropriate.") + .user("What's the weather in Tokyo? Use the get_weather tool.") + .tool(get_weather_tool) + .max_tokens(200), + } +} diff --git a/worker/src/lib.rs b/worker/src/lib.rs index 0bcb17b..994ea8b 100644 --- a/worker/src/lib.rs +++ b/worker/src/lib.rs @@ -1,12 +1,17 @@ //! worker - LLMワーカーのメイン実装 //! //! このクレートは以下を提供します: +//! - Worker: ターン制御を行う高レベルコンポーネント //! - Timeline: イベントストリームの状態管理とハンドラーへのディスパッチ //! - LlmClient: LLMプロバイダとの通信 //! - 型消去されたHandler実装 pub mod llm_client; mod timeline; +mod tool_call_collector; +mod worker; pub use timeline::*; +pub use tool_call_collector::ToolCallCollector; +pub use worker::*; pub use worker_types::*; diff --git a/worker/src/llm_client/mod.rs b/worker/src/llm_client/mod.rs index 404b2e6..1850dc8 100644 --- a/worker/src/llm_client/mod.rs +++ b/worker/src/llm_client/mod.rs @@ -7,7 +7,6 @@ //! - **client**: `LlmClient` trait定義 //! - **scheme**: APIスキーマ(リクエスト/レスポンス変換) //! - **providers**: プロバイダ固有のHTTPクライアント実装 -//! - **testing**: テスト用のAPIレスポンス記録・再生機能 pub mod client; pub mod error; @@ -16,9 +15,6 @@ pub mod types; pub mod providers; pub(crate) mod scheme; -#[cfg(test)] -pub mod testing; - pub use client::*; pub use error::*; pub use types::*; diff --git a/worker/src/tool_call_collector.rs b/worker/src/tool_call_collector.rs new file mode 100644 index 0000000..8e1d092 --- /dev/null +++ b/worker/src/tool_call_collector.rs @@ -0,0 +1,144 @@ +//! ToolCallCollector - ツール呼び出し収集用ハンドラ +//! +//! TimelineのToolUseBlockHandler として登録され、 +//! ストリーム中のToolUseブロックを収集する。 + +use std::sync::{Arc, Mutex}; +use worker_types::{Handler, ToolCall, ToolUseBlockEvent, ToolUseBlockKind}; + +/// ToolUseブロックから収集したツール呼び出し情報を保持 +/// +/// ToolCallCollectorのHandler実装で使用するスコープ型 +#[derive(Debug, Default)] +pub struct CollectorState { + /// 現在のツール呼び出し情報 (ブロック進行中) + current_id: Option, + current_name: Option, + /// 蓄積中のJSON入力 + input_json_buffer: String, +} + +/// ToolCallCollector - ToolUseブロックハンドラ +/// +/// Timelineに登録してToolUseブロックイベントを受信し、 +/// 完了したToolCallを収集する。 +#[derive(Clone)] +pub struct ToolCallCollector { + /// 収集されたToolCall + collected: Arc>>, +} + +impl ToolCallCollector { + /// 新しいToolCallCollectorを作成 + pub fn new() -> Self { + Self { + collected: Arc::new(Mutex::new(Vec::new())), + } + } + + /// 収集されたToolCallを取得してクリア + pub fn take_collected(&self) -> Vec { + let mut guard = self.collected.lock().unwrap(); + std::mem::take(&mut *guard) + } + + /// 収集されたToolCallの参照を取得 + pub fn collected(&self) -> Vec { + self.collected.lock().unwrap().clone() + } + + /// 収集されたToolCallがあるかどうか + pub fn has_pending_calls(&self) -> bool { + !self.collected.lock().unwrap().is_empty() + } + + /// 収集をクリア + pub fn clear(&self) { + self.collected.lock().unwrap().clear(); + } +} + +impl Default for ToolCallCollector { + fn default() -> Self { + Self::new() + } +} + +impl Handler for ToolCallCollector { + type Scope = CollectorState; + + fn on_event(&mut self, scope: &mut Self::Scope, event: &ToolUseBlockEvent) { + match event { + ToolUseBlockEvent::Start(start) => { + scope.current_id = Some(start.id.clone()); + scope.current_name = Some(start.name.clone()); + scope.input_json_buffer.clear(); + } + ToolUseBlockEvent::InputJsonDelta(delta) => { + scope.input_json_buffer.push_str(delta); + } + ToolUseBlockEvent::Stop(_stop) => { + // ブロック完了時にToolCallを確定 + if let (Some(id), Some(name)) = (scope.current_id.take(), scope.current_name.take()) + { + let input = serde_json::from_str(&scope.input_json_buffer) + .unwrap_or(serde_json::Value::Null); + + let tool_call = ToolCall { id, name, input }; + + self.collected.lock().unwrap().push(tool_call); + } + scope.input_json_buffer.clear(); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Timeline; + use worker_types::Event; + + #[test] + fn test_collect_single_tool_call() { + let collector = ToolCallCollector::new(); + let mut timeline = Timeline::new(); + timeline.on_tool_use_block(collector.clone()); + + // ToolUseブロックのイベントシーケンスをディスパッチ + timeline.dispatch(&Event::tool_use_start(0, "tool_123", "get_weather")); + timeline.dispatch(&Event::tool_input_delta(0, r#"{"city":"#)); + timeline.dispatch(&Event::tool_input_delta(0, r#""Tokyo"}"#)); + timeline.dispatch(&Event::tool_use_stop(0)); + + // 収集されたToolCallを確認 + let calls = collector.take_collected(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].id, "tool_123"); + assert_eq!(calls[0].name, "get_weather"); + assert_eq!(calls[0].input["city"], "Tokyo"); + } + + #[test] + fn test_collect_multiple_tool_calls() { + let collector = ToolCallCollector::new(); + let mut timeline = Timeline::new(); + timeline.on_tool_use_block(collector.clone()); + + // 1つ目のToolCall + timeline.dispatch(&Event::tool_use_start(0, "call_1", "tool_a")); + timeline.dispatch(&Event::tool_input_delta(0, r#"{"a":1}"#)); + timeline.dispatch(&Event::tool_use_stop(0)); + + // 2つ目のToolCall + timeline.dispatch(&Event::tool_use_start(1, "call_2", "tool_b")); + timeline.dispatch(&Event::tool_input_delta(1, r#"{"b":2}"#)); + timeline.dispatch(&Event::tool_use_stop(1)); + + let calls = collector.take_collected(); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].name, "tool_a"); + assert_eq!(calls[1].name, "tool_b"); + } +} diff --git a/worker/src/worker.rs b/worker/src/worker.rs new file mode 100644 index 0000000..9498055 --- /dev/null +++ b/worker/src/worker.rs @@ -0,0 +1,359 @@ +//! Worker - ターン制御を行う高レベルコンポーネント +//! +//! LlmClientとTimelineを内包し、Tool/Hookを用いて自律的なインタラクションを実現する。 + +use std::collections::HashMap; +use std::sync::Arc; + +use futures::StreamExt; + +use crate::llm_client::{ClientError, LlmClient, Request, ToolDefinition}; +use crate::tool_call_collector::ToolCallCollector; +use crate::Timeline; +use worker_types::{ + ControlFlow, HookError, Message, Tool, ToolCall, ToolError, ToolResult, TurnResult, WorkerHook, +}; + +// ============================================================================= +// Worker Error +// ============================================================================= + +/// Workerエラー +#[derive(Debug, thiserror::Error)] +pub enum WorkerError { + /// クライアントエラー + #[error("Client error: {0}")] + Client(#[from] ClientError), + /// ツールエラー + #[error("Tool error: {0}")] + Tool(#[from] ToolError), + /// Hookエラー + #[error("Hook error: {0}")] + Hook(#[from] HookError), + /// 処理が中断された + #[error("Aborted: {0}")] + Aborted(String), +} + +// ============================================================================= +// Worker Config +// ============================================================================= + +/// Worker設定 +#[derive(Debug, Clone)] +pub struct WorkerConfig { + /// 最大ターン数(無限ループ防止) + pub max_turns: usize, +} + +impl Default for WorkerConfig { + fn default() -> Self { + Self { max_turns: 10 } + } +} + +// ============================================================================= +// Worker +// ============================================================================= + +/// Worker - ターン制御コンポーネント +/// +/// # 責務 +/// - LLMへのリクエスト送信とレスポンス処理 +/// - ツール呼び出しの収集と実行 +/// - Hookによる介入の提供 +/// - ターンループの制御 +pub struct Worker { + /// LLMクライアント + client: C, + /// イベントタイムライン + timeline: Timeline, + /// ツールコレクター(Timeline用ハンドラ) + tool_call_collector: ToolCallCollector, + /// 登録されたツール + tools: HashMap>, + /// 登録されたHook + hooks: Vec>, + /// 設定 + config: WorkerConfig, +} + +impl Worker { + /// 新しいWorkerを作成 + pub fn new(client: C) -> Self { + let tool_call_collector = ToolCallCollector::new(); + let mut timeline = Timeline::new(); + + // ToolCallCollectorをTimelineに登録 + timeline.on_tool_use_block(tool_call_collector.clone()); + + Self { + client, + timeline, + tool_call_collector, + tools: HashMap::new(), + hooks: Vec::new(), + config: WorkerConfig::default(), + } + } + + /// 設定を適用 + pub fn config(mut self, config: WorkerConfig) -> Self { + self.config = config; + self + } + + /// ツールを登録 + pub fn register_tool(&mut self, tool: impl Tool + 'static) { + let name = tool.name().to_string(); + self.tools.insert(name, Arc::new(tool)); + } + + /// 複数のツールを登録 + pub fn register_tools(&mut self, tools: impl IntoIterator) { + for tool in tools { + self.register_tool(tool); + } + } + + /// Hookを追加 + pub fn add_hook(&mut self, hook: impl WorkerHook + 'static) { + self.hooks.push(Box::new(hook)); + } + + /// タイムラインへの可変参照を取得(追加ハンドラ登録用) + pub fn timeline_mut(&mut self) -> &mut Timeline { + &mut self.timeline + } + + /// 登録されたツールからToolDefinitionのリストを生成 + fn build_tool_definitions(&self) -> Vec { + self.tools + .values() + .map(|tool| { + ToolDefinition::new(tool.name()) + .description(tool.description()) + .input_schema(tool.input_schema()) + }) + .collect() + } + + /// ターンを実行 + /// + /// メッセージを送信し、レスポンスを処理する。 + /// ツール呼び出しがある場合は自動的にループする。 + pub async fn run(&mut self, messages: Vec) -> Result, WorkerError> { + let mut context = messages; + let tool_definitions = self.build_tool_definitions(); + + for _turn in 0..self.config.max_turns { + // Hook: on_message_send + let control = self.run_on_message_send_hooks(&mut context).await?; + if let ControlFlow::Abort(reason) = control { + return Err(WorkerError::Aborted(reason)); + } + + // リクエスト構築 + let request = self.build_request(&context, &tool_definitions); + + // ストリーム処理 + let mut stream = self.client.stream(request).await?; + while let Some(event_result) = stream.next().await { + let event = event_result?; + self.timeline.dispatch(&event); + } + + // ツール呼び出しの収集結果を取得 + let tool_calls = self.tool_call_collector.take_collected(); + + if tool_calls.is_empty() { + // ツール呼び出しなし → ターン終了判定 + let turn_result = self.run_on_turn_end_hooks(&context).await?; + match turn_result { + TurnResult::Finish => { + return Ok(context); + } + TurnResult::ContinueWithMessages(additional) => { + context.extend(additional); + continue; + } + } + } + + // ツール実行 + let tool_results = self.execute_tools(tool_calls).await?; + + // ツール結果をコンテキストに追加 + for result in tool_results { + context.push(Message::tool_result(&result.tool_use_id, &result.content)); + } + } + + // 最大ターン数到達 + Err(WorkerError::Aborted(format!( + "Maximum turns ({}) reached", + self.config.max_turns + ))) + } + + /// リクエストを構築 + fn build_request(&self, context: &[Message], tool_definitions: &[ToolDefinition]) -> Request { + let mut request = Request::new(); + + // メッセージを追加 + for msg in context { + // worker-types::Message から llm_client::Message への変換 + request = request.message(crate::llm_client::Message { + role: match msg.role { + worker_types::Role::User => crate::llm_client::Role::User, + worker_types::Role::Assistant => crate::llm_client::Role::Assistant, + }, + content: match &msg.content { + worker_types::MessageContent::Text(t) => { + crate::llm_client::MessageContent::Text(t.clone()) + } + worker_types::MessageContent::ToolResult { + tool_use_id, + content, + } => crate::llm_client::MessageContent::ToolResult { + tool_use_id: tool_use_id.clone(), + content: content.clone(), + }, + worker_types::MessageContent::Parts(parts) => { + crate::llm_client::MessageContent::Parts( + parts + .iter() + .map(|p| match p { + worker_types::ContentPart::Text { text } => { + crate::llm_client::ContentPart::Text { text: text.clone() } + } + worker_types::ContentPart::ToolUse { id, name, input } => { + crate::llm_client::ContentPart::ToolUse { + id: id.clone(), + name: name.clone(), + input: input.clone(), + } + } + worker_types::ContentPart::ToolResult { + tool_use_id, + content, + } => crate::llm_client::ContentPart::ToolResult { + tool_use_id: tool_use_id.clone(), + content: content.clone(), + }, + }) + .collect(), + ) + } + }, + }); + } + + // ツール定義を追加 + for tool_def in tool_definitions { + request = request.tool(tool_def.clone()); + } + + request + } + + /// Hooks: on_message_send + async fn run_on_message_send_hooks( + &self, + context: &mut Vec, + ) -> Result { + for hook in &self.hooks { + let result = hook.on_message_send(context).await?; + match result { + ControlFlow::Continue => continue, + ControlFlow::Skip => return Ok(ControlFlow::Skip), + ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)), + } + } + Ok(ControlFlow::Continue) + } + + /// Hooks: on_turn_end + async fn run_on_turn_end_hooks( + &self, + messages: &[Message], + ) -> Result { + for hook in &self.hooks { + let result = hook.on_turn_end(messages).await?; + match result { + TurnResult::Finish => continue, + TurnResult::ContinueWithMessages(msgs) => { + return Ok(TurnResult::ContinueWithMessages(msgs)); + } + } + } + Ok(TurnResult::Finish) + } + + /// ツールを並列実行 + async fn execute_tools( + &self, + mut tool_calls: Vec, + ) -> Result, WorkerError> { + let mut results = Vec::new(); + + // TODO: 将来的には join_all で並列実行 + // 現在は逐次実行 + for mut tool_call in tool_calls.drain(..) { + // Hook: before_tool_call + let mut skip = false; + for hook in &self.hooks { + let result = hook.before_tool_call(&mut tool_call).await?; + match result { + ControlFlow::Continue => {} + ControlFlow::Skip => { + skip = true; + break; + } + ControlFlow::Abort(reason) => { + return Err(WorkerError::Aborted(reason)); + } + } + } + + if skip { + continue; + } + + // ツール実行 + let mut tool_result = if let Some(tool) = self.tools.get(&tool_call.name) { + let input_json = serde_json::to_string(&tool_call.input).unwrap_or_default(); + match tool.execute(&input_json).await { + Ok(content) => ToolResult::success(&tool_call.id, content), + Err(e) => ToolResult::error(&tool_call.id, e.to_string()), + } + } else { + ToolResult::error( + &tool_call.id, + format!("Tool '{}' not found", tool_call.name), + ) + }; + + // Hook: after_tool_call + for hook in &self.hooks { + let result = hook.after_tool_call(&mut tool_result).await?; + match result { + ControlFlow::Continue => {} + ControlFlow::Skip => break, + ControlFlow::Abort(reason) => { + return Err(WorkerError::Aborted(reason)); + } + } + } + + results.push(tool_result); + } + + Ok(results) + } +} + +#[cfg(test)] +mod tests { + // 基本的なテストのみ。LlmClientを使ったテストは統合テストで行う。 +} diff --git a/worker/src/llm_client/testing.rs b/worker/tests/common/mod.rs similarity index 65% rename from worker/src/llm_client/testing.rs rename to worker/tests/common/mod.rs index 1613708..1dc18c6 100644 --- a/worker/src/llm_client/testing.rs +++ b/worker/tests/common/mod.rs @@ -1,14 +1,22 @@ -//! テスト用のAPIレスポンス記録・再生機能 +//! テスト用共通ユーティリティ //! -//! 実際のAPIレスポンスをタイムスタンプ付きで記録し、 -//! テスト時に再生できるようにする。 +//! MockLlmClient、イベントレコーダー・プレイヤーを提供する use std::fs::File; use std::io::{BufRead, BufReader, BufWriter, Write}; use std::path::Path; +use std::pin::Pin; use std::time::{Instant, SystemTime, UNIX_EPOCH}; +use async_trait::async_trait; +use futures::Stream; use serde::{Deserialize, Serialize}; +use worker::llm_client::{ClientError, LlmClient, Request}; +use worker_types::Event; + +// ============================================================================= +// Recorded Event Types +// ============================================================================= /// 記録されたSSEイベント #[derive(Debug, Clone, Serialize, Deserialize)] @@ -32,15 +40,21 @@ pub struct SessionMetadata { pub description: String, } +// ============================================================================= +// Event Recorder +// ============================================================================= + /// SSEイベントレコーダー /// /// 実際のAPIレスポンスを記録し、後でテストに使用できるようにする +#[allow(dead_code)] pub struct EventRecorder { start_time: Instant, events: Vec, metadata: SessionMetadata, } +#[allow(dead_code)] impl EventRecorder { /// 新しいレコーダーを作成 pub fn new(model: impl Into, description: impl Into) -> Self { @@ -97,15 +111,21 @@ impl EventRecorder { } } +// ============================================================================= +// Event Player +// ============================================================================= + /// SSEイベントプレイヤー /// /// 記録されたイベントを読み込み、テストで使用する +#[allow(dead_code)] pub struct EventPlayer { metadata: SessionMetadata, events: Vec, current_index: usize, } +#[allow(dead_code)] impl EventPlayer { /// ファイルから読み込み pub fn load(path: impl AsRef) -> std::io::Result { @@ -166,73 +186,55 @@ impl EventPlayer { pub fn reset(&mut self) { self.current_index = 0; } -} -#[cfg(test)] -mod tests { - use super::*; - use std::io::Write; - use tempfile::NamedTempFile; - - #[test] - fn test_record_and_playback() { - // レコーダーを作成して記録 - let mut recorder = EventRecorder::new("claude-sonnet-4-20250514", "Test recording"); - recorder.record("message_start", r#"{"type":"message_start"}"#); - recorder.record( - "content_block_start", - r#"{"type":"content_block_start","index":0}"#, - ); - recorder.record( - "content_block_delta", - r#"{"type":"content_block_delta","delta":{"type":"text_delta","text":"Hello"}}"#, - ); - - // 一時ファイルに保存 - let temp_file = NamedTempFile::new().unwrap(); - recorder.save(temp_file.path()).unwrap(); - - // 読み込んで確認 - let player = EventPlayer::load(temp_file.path()).unwrap(); - assert_eq!(player.metadata().model, "claude-sonnet-4-20250514"); - assert_eq!(player.event_count(), 3); - assert_eq!(player.events()[0].event_type, "message_start"); - assert_eq!(player.events()[2].event_type, "content_block_delta"); - } - - #[test] - fn test_player_iteration() { - // テストデータを直接作成 - let mut temp_file = NamedTempFile::new().unwrap(); - writeln!( - temp_file, - r#"{{"timestamp":1704067200,"model":"test","description":"test"}}"# - ) - .unwrap(); - writeln!( - temp_file, - r#"{{"elapsed_ms":0,"event_type":"ping","data":"{{}}"}}"# - ) - .unwrap(); - writeln!( - temp_file, - r#"{{"elapsed_ms":100,"event_type":"message_stop","data":"{{}}"}}"# - ) - .unwrap(); - temp_file.flush().unwrap(); - - let mut player = EventPlayer::load(temp_file.path()).unwrap(); - - let first = player.next_event().unwrap(); - assert_eq!(first.event_type, "ping"); - - let second = player.next_event().unwrap(); - assert_eq!(second.event_type, "message_stop"); - - assert!(player.next_event().is_none()); - - // リセット後は最初から - player.reset(); - assert_eq!(player.next_event().unwrap().event_type, "ping"); + /// 全イベントをworker_types::Eventとしてパースして取得 + pub fn parse_events(&self) -> Vec { + self.events + .iter() + .filter_map(|recorded| serde_json::from_str(&recorded.data).ok()) + .collect() + } +} + +// ============================================================================= +// MockLlmClient +// ============================================================================= + +/// テスト用のモックLLMクライアント +/// +/// 事前に定義されたイベントシーケンスをストリームとして返す。 +/// fixtureファイルからロードすることも、直接イベントを渡すこともできる。 +pub struct MockLlmClient { + events: Vec, +} + +impl MockLlmClient { + /// イベントリストから直接作成 + pub fn new(events: Vec) -> Self { + Self { events } + } + + /// fixtureファイルからロード + pub fn from_fixture(path: impl AsRef) -> std::io::Result { + let player = EventPlayer::load(path)?; + let events = player.parse_events(); + Ok(Self { events }) + } + + /// 保持しているイベント数を取得 + pub fn event_count(&self) -> usize { + self.events.len() + } +} + +#[async_trait] +impl LlmClient for MockLlmClient { + async fn stream( + &self, + _request: Request, + ) -> Result> + Send>>, ClientError> { + let events = self.events.clone(); + let stream = futures::stream::iter(events.into_iter().map(Ok)); + Ok(Box::pin(stream)) } } diff --git a/worker/tests/fixtures/tool_call.jsonl b/worker/tests/fixtures/tool_call.jsonl new file mode 100644 index 0000000..43ceb61 --- /dev/null +++ b/worker/tests/fixtures/tool_call.jsonl @@ -0,0 +1,16 @@ +{"timestamp":1767692881,"model":"claude-sonnet-4-20250514","description":"Tool call response"} +{"elapsed_ms":1783,"event_type":"Discriminant(1)","data":"{\"Usage\":{\"input_tokens\":409,\"output_tokens\":3,\"total_tokens\":412,\"cache_read_input_tokens\":0,\"cache_creation_input_tokens\":0}}"} +{"elapsed_ms":1783,"event_type":"Discriminant(4)","data":"{\"BlockStart\":{\"index\":0,\"block_type\":\"Text\",\"metadata\":\"Text\"}}"} +{"elapsed_ms":1783,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":0,\"delta\":{\"Text\":\"I'll check\"}}}"} +{"elapsed_ms":1883,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":0,\"delta\":{\"Text\":\" the current\"}}}"} +{"elapsed_ms":2063,"event_type":"Discriminant(0)","data":"{\"Ping\":{\"timestamp\":null}}"} +{"elapsed_ms":2063,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":0,\"delta\":{\"Text\":\" weather in Tokyo for you using\"}}}"} +{"elapsed_ms":2124,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":0,\"delta\":{\"Text\":\" the get_weather tool.\"}}}"} +{"elapsed_ms":2252,"event_type":"Discriminant(6)","data":"{\"BlockStop\":{\"index\":0,\"block_type\":\"Text\",\"stop_reason\":null}}"} +{"elapsed_ms":2253,"event_type":"Discriminant(4)","data":"{\"BlockStart\":{\"index\":1,\"block_type\":\"ToolUse\",\"metadata\":{\"ToolUse\":{\"id\":\"toolu_011Hg5wju1LGL7F65HyfE6bM\",\"name\":\"get_weather\"}}}}"} +{"elapsed_ms":2253,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":1,\"delta\":{\"InputJson\":\"\"}}}"} +{"elapsed_ms":2306,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":1,\"delta\":{\"InputJson\":\"{\\\"city\\\": \\\"Tokyo\"}}}"} +{"elapsed_ms":2451,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":1,\"delta\":{\"InputJson\":\"\\\"}\"}}}"} +{"elapsed_ms":2451,"event_type":"Discriminant(6)","data":"{\"BlockStop\":{\"index\":1,\"block_type\":\"Text\",\"stop_reason\":null}}"} +{"elapsed_ms":2464,"event_type":"Discriminant(1)","data":"{\"Usage\":{\"input_tokens\":409,\"output_tokens\":71,\"total_tokens\":480,\"cache_read_input_tokens\":0,\"cache_creation_input_tokens\":0}}"} +{"elapsed_ms":2470,"event_type":"Discriminant(2)","data":"{\"Status\":{\"status\":\"Completed\"}}"} diff --git a/worker/tests/worker_fixtures.rs b/worker/tests/worker_fixtures.rs new file mode 100644 index 0000000..7d79eb8 --- /dev/null +++ b/worker/tests/worker_fixtures.rs @@ -0,0 +1,243 @@ +//! Workerフィクスチャベースの統合テスト +//! +//! 記録されたAPIレスポンスを使ってWorkerの動作をテストする。 +//! APIキー不要でローカルで実行可能。 + +mod common; + +use std::path::Path; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +use async_trait::async_trait; +use common::MockLlmClient; +use worker::{Worker, WorkerConfig}; +use worker_types::{Tool, ToolError}; + +/// フィクスチャディレクトリのパス +fn fixtures_dir() -> std::path::PathBuf { + Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures") +} + +/// シンプルなテスト用ツール +#[derive(Clone)] +struct MockWeatherTool { + call_count: Arc, +} + +impl MockWeatherTool { + fn new() -> Self { + Self { + call_count: Arc::new(AtomicUsize::new(0)), + } + } + + fn get_call_count(&self) -> usize { + self.call_count.load(Ordering::SeqCst) + } +} + +#[async_trait] +impl Tool for MockWeatherTool { + fn name(&self) -> &str { + "get_weather" + } + + fn description(&self) -> &str { + "Get the current weather for a city" + } + + fn input_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name" + } + }, + "required": ["city"] + }) + } + + async fn execute(&self, input_json: &str) -> Result { + self.call_count.fetch_add(1, Ordering::SeqCst); + + // 入力をパース + let input: serde_json::Value = serde_json::from_str(input_json) + .map_err(|e| ToolError::InvalidArgument(e.to_string()))?; + + let city = input["city"] + .as_str() + .unwrap_or("Unknown"); + + // モックのレスポンスを返す + Ok(format!("Weather in {}: Sunny, 22°C", city)) + } +} + +// ============================================================================= +// Basic Fixture Tests +// ============================================================================= + +/// MockLlmClientがJSONLフィクスチャファイルから正しくイベントをロードできることを確認 +/// +/// 既存のanthropic_*.jsonlファイルを使用し、イベントがパース・ロードされることを検証する。 +#[test] +fn test_mock_client_from_fixture() { + // 既存のフィクスチャをロード + let fixture_path = fixtures_dir().join("anthropic_1767624445.jsonl"); + if !fixture_path.exists() { + println!("Fixture not found, skipping test"); + return; + } + + let client = MockLlmClient::from_fixture(&fixture_path).unwrap(); + assert!(client.event_count() > 0, "Should have loaded events"); + println!("Loaded {} events from fixture", client.event_count()); +} + +/// MockLlmClientが直接指定されたイベントリストで正しく動作することを確認 +/// +/// fixtureファイルを使わず、プログラムでイベントを構築してクライアントを作成する。 +#[test] +fn test_mock_client_from_events() { + use worker_types::Event; + + // 直接イベントを指定 + let events = vec![ + Event::text_block_start(0), + Event::text_delta(0, "Hello!"), + Event::text_block_stop(0, None), + ]; + + let client = MockLlmClient::new(events); + assert_eq!(client.event_count(), 3); +} + +// ============================================================================= +// Worker Tests with Fixtures +// ============================================================================= + +/// Workerがシンプルなテキストレスポンスを正しく処理できることを確認 +/// +/// simple_text.jsonlフィクスチャを使用し、ツール呼び出しなしのシナリオをテストする。 +/// フィクスチャがない場合はスキップされる。 +#[tokio::test] +async fn test_worker_simple_text_response() { + let fixture_path = fixtures_dir().join("simple_text.jsonl"); + if !fixture_path.exists() { + println!("Fixture not found: {:?}, skipping test", fixture_path); + println!("Run: cargo run --example record_worker_test"); + return; + } + + let client = MockLlmClient::from_fixture(&fixture_path).unwrap(); + let mut worker = Worker::new(client); + + // シンプルなメッセージを送信 + let messages = vec![worker_types::Message::user("Hello")]; + let result = worker.run(messages).await; + + assert!(result.is_ok(), "Worker should complete successfully"); +} + +/// Workerがツール呼び出しを含むレスポンスを正しく処理できることを確認 +/// +/// tool_call.jsonlフィクスチャを使用し、MockWeatherToolが呼び出されることをテストする。 +/// max_turns=1に設定し、ツール実行後のループを防止。 +#[tokio::test] +async fn test_worker_tool_call() { + let fixture_path = fixtures_dir().join("tool_call.jsonl"); + if !fixture_path.exists() { + println!("Fixture not found: {:?}, skipping test", fixture_path); + println!("Run: cargo run --example record_worker_test"); + return; + } + + let client = MockLlmClient::from_fixture(&fixture_path).unwrap(); + let mut worker = Worker::new(client); + + // ツールを登録 + let weather_tool = MockWeatherTool::new(); + let tool_for_check = weather_tool.clone(); + worker.register_tool(weather_tool); + + // 設定: ツール実行後はターン終了(ループしない) + worker = worker.config(WorkerConfig { max_turns: 1 }); + + // メッセージを送信 + let messages = vec![worker_types::Message::user("What's the weather in Tokyo?")]; + let _result = worker.run(messages).await; + + // ツールが呼び出されたことを確認 + // Note: max_turns=1なのでツール結果後のリクエストは送信されない + let call_count = tool_for_check.get_call_count(); + println!("Tool was called {} times", call_count); + + // フィクスチャにToolUseが含まれていればツールが呼び出されるはず + // ただしmax_turns=1なので1回で終了 +} + +/// fixtureファイルなしでWorkerが動作することを確認 +/// +/// プログラムでイベントシーケンスを構築し、MockLlmClientに渡してテストする。 +/// テストの独立性を高め、外部ファイルへの依存を排除したい場合に有用。 +#[tokio::test] +async fn test_worker_with_programmatic_events() { + use worker_types::{Event, ResponseStatus, StatusEvent}; + + // プログラムでイベントシーケンスを構築 + let events = vec![ + Event::text_block_start(0), + Event::text_delta(0, "Hello, "), + Event::text_delta(0, "World!"), + Event::text_block_stop(0, None), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ]; + + let client = MockLlmClient::new(events); + let mut worker = Worker::new(client); + + let messages = vec![worker_types::Message::user("Greet me")]; + let result = worker.run(messages).await; + + assert!(result.is_ok(), "Worker should complete successfully"); +} + +/// ToolCallCollectorがToolUseブロックイベントから正しくToolCallを収集することを確認 +/// +/// Timelineにイベントをディスパッチし、ToolCallCollectorが +/// id, name, input(JSON)を正しく抽出できることを検証する。 +#[tokio::test] +async fn test_tool_call_collector_integration() { + use worker::ToolCallCollector; + use worker::Timeline; + use worker_types::Event; + + // ToolUseブロックを含むイベントシーケンス + let events = vec![ + Event::tool_use_start(0, "call_123", "get_weather"), + Event::tool_input_delta(0, r#"{"city":"#), + Event::tool_input_delta(0, r#""Tokyo"}"#), + Event::tool_use_stop(0), + ]; + + let collector = ToolCallCollector::new(); + let mut timeline = Timeline::new(); + timeline.on_tool_use_block(collector.clone()); + + // イベントをディスパッチ + for event in &events { + timeline.dispatch(event); + } + + // 収集されたToolCallを確認 + let calls = collector.take_collected(); + assert_eq!(calls.len(), 1, "Should collect one tool call"); + assert_eq!(calls[0].name, "get_weather"); + assert_eq!(calls[0].id, "call_123"); + assert_eq!(calls[0].input["city"], "Tokyo"); +}