feat: Implement Worker for LLM turn management/tool call/hooks

This commit is contained in:
Keisuke Hirata 2026-01-06 20:38:08 +09:00
parent a4e2795e56
commit e82e0a3ed9
13 changed files with 1333 additions and 74 deletions

140
worker-types/src/hook.rs Normal file
View File

@ -0,0 +1,140 @@
//! Hook関連の型定義
//!
//! Worker層でのターン制御・介入に使用される型
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
// =============================================================================
// Control Flow Types
// =============================================================================
/// Hook処理の制御フロー
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ControlFlow {
/// 処理を続行
Continue,
/// 現在の処理をスキップTool実行など
Skip,
/// 処理を中断
Abort(String),
}
/// ターン終了時の判定結果
#[derive(Debug, Clone)]
pub enum TurnResult {
/// ターンを終了
Finish,
/// メッセージを追加してターン継続(自己修正など)
ContinueWithMessages(Vec<crate::Message>),
}
// =============================================================================
// Tool Call / Result Types
// =============================================================================
/// ツール呼び出し情報
///
/// LLMからのToolUseブロックを表現し、Hook処理で改変可能
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// ツール呼び出しIDレスポンスとの紐付けに使用
pub id: String,
/// ツール名
pub name: String,
/// 入力引数JSON
pub input: Value,
}
/// ツール実行結果
///
/// ツール実行後の結果を表現し、Hook処理で改変可能
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
/// 対応するツール呼び出しID
pub tool_use_id: String,
/// 結果コンテンツ
pub content: String,
/// エラーかどうか
#[serde(default)]
pub is_error: bool,
}
impl ToolResult {
/// 成功結果を作成
pub fn success(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
tool_use_id: tool_use_id.into(),
content: content.into(),
is_error: false,
}
}
/// エラー結果を作成
pub fn error(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
tool_use_id: tool_use_id.into(),
content: content.into(),
is_error: true,
}
}
}
// =============================================================================
// Hook Error
// =============================================================================
/// Hookエラー
#[derive(Debug, Error)]
pub enum HookError {
/// 処理が中断された
#[error("Aborted: {0}")]
Aborted(String),
/// 内部エラー
#[error("Hook error: {0}")]
Internal(String),
}
// =============================================================================
// WorkerHook Trait
// =============================================================================
/// Worker Hook trait
///
/// ターンの進行・メッセージ・ツール実行に対して介入するためのトレイト。
/// デフォルト実装では何も行わずContinueを返す。
#[async_trait]
pub trait WorkerHook: Send + Sync {
/// メッセージ送信前
///
/// リクエストに含まれるメッセージリストを改変できる。
async fn on_message_send(
&self,
_context: &mut Vec<crate::Message>,
) -> Result<ControlFlow, HookError> {
Ok(ControlFlow::Continue)
}
/// ツール実行前
///
/// 実行をキャンセルしたり、引数を書き換えることができる。
async fn before_tool_call(&self, _tool_call: &mut ToolCall) -> Result<ControlFlow, HookError> {
Ok(ControlFlow::Continue)
}
/// ツール実行後
///
/// 結果を書き換えたり、隠蔽したりできる。
async fn after_tool_call(&self, _tool_result: &mut ToolResult) -> Result<ControlFlow, HookError> {
Ok(ControlFlow::Continue)
}
/// ターン終了時
///
/// 生成されたメッセージを検査し、必要ならリトライを指示できる。
async fn on_turn_end(&self, _messages: &[crate::Message]) -> Result<TurnResult, HookError> {
Ok(TurnResult::Finish)
}
}

View File

@ -3,12 +3,19 @@
//! このクレートは以下を提供します:
//! - Event: llm_client層からのフラットなイベント列挙
//! - Kind/Handler: タイムライン層でのイベント処理トレイト
//! - Tool: ツール定義トレイト
//! - Hook: Worker層での介入用トレイト
//! - Message: メッセージ型
//! - 各種イベント構造体
mod event;
mod handler;
mod hook;
mod message;
mod tool;
pub use event::*;
pub use handler::*;
pub use hook::*;
pub use message::*;
pub use tool::*;

View File

@ -0,0 +1,87 @@
//! メッセージ型定義
//!
//! LLM会話で使用されるメッセージ構造
use serde::{Deserialize, Serialize};
/// メッセージのロール
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
/// ユーザー
User,
/// アシスタント
Assistant,
}
/// メッセージ
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// ロール
pub role: Role,
/// コンテンツ
pub content: MessageContent,
}
/// メッセージコンテンツ
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
/// テキストコンテンツ
Text(String),
/// ツール結果
ToolResult {
tool_use_id: String,
content: String,
},
/// 複合コンテンツ (テキスト + ツール使用等)
Parts(Vec<ContentPart>),
}
/// コンテンツパーツ
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentPart {
/// テキスト
#[serde(rename = "text")]
Text { text: String },
/// ツール使用
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: serde_json::Value,
},
/// ツール結果
#[serde(rename = "tool_result")]
ToolResult { tool_use_id: String, content: String },
}
impl Message {
/// ユーザーメッセージを作成
pub fn user(content: impl Into<String>) -> Self {
Self {
role: Role::User,
content: MessageContent::Text(content.into()),
}
}
/// アシスタントメッセージを作成
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: MessageContent::Text(content.into()),
}
}
/// ツール結果メッセージを作成
pub fn tool_result(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
role: Role::User,
content: MessageContent::ToolResult {
tool_use_id: tool_use_id.into(),
content: content.into(),
},
}
}
}

View File

@ -0,0 +1,99 @@
//! テストフィクスチャ記録ツール
//!
//! 定義されたシナリオのAPIレスポンスを記録する。
//!
//! ## 使用方法
//!
//! ```bash
//! # 利用可能なシナリオを表示
//! cargo run --example record_test_fixtures
//!
//! # 特定のシナリオを記録
//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- simple_text
//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- tool_call
//!
//! # 全シナリオを記録
//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- --all
//! ```
mod recorder;
mod scenarios;
use worker::llm_client::providers::anthropic::AnthropicClient;
fn print_usage() {
println!("Usage: cargo run --example record_test_fixtures -- <scenario_name>");
println!(" cargo run --example record_test_fixtures -- --all");
println!();
println!("Available scenarios:");
for scenario in scenarios::scenarios() {
println!(" {:20} - {}", scenario.output_name, scenario.name);
}
println!();
println!("Options:");
println!(" --all Record all scenarios");
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let args: Vec<String> = std::env::args().collect();
// 引数がなければ使い方を表示
if args.len() < 2 {
print_usage();
return Ok(());
}
let arg = &args[1];
// 全シナリオを取得
let all_scenarios = scenarios::scenarios();
// 実行するシナリオを決定
let scenarios_to_run: Vec<_> = if arg == "--all" {
all_scenarios
} else {
// 指定されたシナリオを検索
let found: Vec<_> = all_scenarios
.into_iter()
.filter(|s| s.output_name == arg)
.collect();
if found.is_empty() {
eprintln!("Error: Unknown scenario '{}'", arg);
println!();
print_usage();
std::process::exit(1);
}
found
};
// APIキーを取得
let api_key = std::env::var("ANTHROPIC_API_KEY")
.expect("ANTHROPIC_API_KEY environment variable must be set");
let model = "claude-sonnet-4-20250514";
println!("=== Test Fixture Generator ===");
println!("Model: {}", model);
println!("Scenarios: {}\n", scenarios_to_run.len());
let client = AnthropicClient::new(&api_key, model);
// シナリオを記録
for scenario in scenarios_to_run {
recorder::record_request(
&client,
scenario.request,
scenario.name,
scenario.output_name,
model,
)
.await?;
}
println!("\n✅ Done!");
println!("Run tests with: cargo test -p worker");
Ok(())
}

View File

@ -0,0 +1,100 @@
//! テストフィクスチャ記録機構
//!
//! イベントをJSONLフォーマットでファイルに保存する
use std::fs::{self, File};
use std::io::{BufWriter, Write};
use std::path::Path;
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use futures::StreamExt;
use worker::llm_client::{LlmClient, Request};
/// 記録されたイベント
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct RecordedEvent {
pub elapsed_ms: u64,
pub event_type: String,
pub data: String,
}
/// セッションメタデータ
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct SessionMetadata {
pub timestamp: u64,
pub model: String,
pub description: String,
}
/// イベントシーケンスをファイルに保存
pub fn save_fixture(
path: impl AsRef<Path>,
metadata: &SessionMetadata,
events: &[RecordedEvent],
) -> std::io::Result<()> {
let file = File::create(path)?;
let mut writer = BufWriter::new(file);
writeln!(writer, "{}", serde_json::to_string(metadata)?)?;
for event in events {
writeln!(writer, "{}", serde_json::to_string(event)?)?;
}
writer.flush()?;
Ok(())
}
/// リクエストを送信してイベントを記録
pub async fn record_request<C: LlmClient>(
client: &C,
request: Request,
description: &str,
output_name: &str,
model: &str,
) -> Result<usize, Box<dyn std::error::Error>> {
println!("\n📝 Recording: {}", description);
let start_time = Instant::now();
let mut events: Vec<RecordedEvent> = Vec::new();
let mut stream = client.stream(request).await?;
while let Some(result) = stream.next().await {
let elapsed = start_time.elapsed().as_millis() as u64;
match result {
Ok(event) => {
let event_json = serde_json::to_string(&event)?;
println!(" [{:>6}ms] {:?}", elapsed, event);
events.push(RecordedEvent {
elapsed_ms: elapsed,
event_type: format!("{:?}", std::mem::discriminant(&event)),
data: event_json,
});
}
Err(e) => {
eprintln!(" Error: {}", e);
break;
}
}
}
// 保存
let fixtures_dir = Path::new("worker/tests/fixtures");
fs::create_dir_all(fixtures_dir)?;
let filepath = fixtures_dir.join(format!("{}.jsonl", output_name));
let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
let metadata = SessionMetadata {
timestamp,
model: model.to_string(),
description: description.to_string(),
};
save_fixture(&filepath, &metadata, &events)?;
let event_count = events.len();
println!(" 💾 Saved: {}", filepath.display());
println!(" 📊 {} events recorded", event_count);
Ok(event_count)
}

View File

@ -0,0 +1,61 @@
//! テストフィクスチャ用リクエスト定義
//!
//! 各シナリオのリクエストと出力ファイル名を定義
use worker::llm_client::{Request, ToolDefinition};
/// テストシナリオ
pub struct TestScenario {
/// シナリオ名(説明)
pub name: &'static str,
/// 出力ファイル名(拡張子なし)
pub output_name: &'static str,
/// リクエスト
pub request: Request,
}
/// 全てのテストシナリオを取得
pub fn scenarios() -> Vec<TestScenario> {
vec![
simple_text_scenario(),
tool_call_scenario(),
]
}
/// シンプルなテキストレスポンス
fn simple_text_scenario() -> TestScenario {
TestScenario {
name: "Simple text response",
output_name: "simple_text",
request: Request::new()
.system("You are a helpful assistant. Be very concise.")
.user("Say hello in one word.")
.max_tokens(50),
}
}
/// ツール呼び出しを含むレスポンス
fn tool_call_scenario() -> TestScenario {
let get_weather_tool = ToolDefinition::new("get_weather")
.description("Get the current weather for a city")
.input_schema(serde_json::json!({
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name"
}
},
"required": ["city"]
}));
TestScenario {
name: "Tool call response",
output_name: "tool_call",
request: Request::new()
.system("You are a helpful assistant. Use tools when appropriate.")
.user("What's the weather in Tokyo? Use the get_weather tool.")
.tool(get_weather_tool)
.max_tokens(200),
}
}

View File

@ -1,12 +1,17 @@
//! worker - LLMワーカーのメイン実装
//!
//! このクレートは以下を提供します:
//! - Worker: ターン制御を行う高レベルコンポーネント
//! - Timeline: イベントストリームの状態管理とハンドラーへのディスパッチ
//! - LlmClient: LLMプロバイダとの通信
//! - 型消去されたHandler実装
pub mod llm_client;
mod timeline;
mod tool_call_collector;
mod worker;
pub use timeline::*;
pub use tool_call_collector::ToolCallCollector;
pub use worker::*;
pub use worker_types::*;

View File

@ -7,7 +7,6 @@
//! - **client**: `LlmClient` trait定義
//! - **scheme**: APIスキーマリクエスト/レスポンス変換)
//! - **providers**: プロバイダ固有のHTTPクライアント実装
//! - **testing**: テスト用のAPIレスポンス記録・再生機能
pub mod client;
pub mod error;
@ -16,9 +15,6 @@ pub mod types;
pub mod providers;
pub(crate) mod scheme;
#[cfg(test)]
pub mod testing;
pub use client::*;
pub use error::*;
pub use types::*;

View File

@ -0,0 +1,144 @@
//! ToolCallCollector - ツール呼び出し収集用ハンドラ
//!
//! TimelineのToolUseBlockHandler として登録され、
//! ストリーム中のToolUseブロックを収集する。
use std::sync::{Arc, Mutex};
use worker_types::{Handler, ToolCall, ToolUseBlockEvent, ToolUseBlockKind};
/// ToolUseブロックから収集したツール呼び出し情報を保持
///
/// ToolCallCollectorのHandler実装で使用するスコープ型
#[derive(Debug, Default)]
pub struct CollectorState {
/// 現在のツール呼び出し情報 (ブロック進行中)
current_id: Option<String>,
current_name: Option<String>,
/// 蓄積中のJSON入力
input_json_buffer: String,
}
/// ToolCallCollector - ToolUseブロックハンドラ
///
/// Timelineに登録してToolUseブロックイベントを受信し、
/// 完了したToolCallを収集する。
#[derive(Clone)]
pub struct ToolCallCollector {
/// 収集されたToolCall
collected: Arc<Mutex<Vec<ToolCall>>>,
}
impl ToolCallCollector {
/// 新しいToolCallCollectorを作成
pub fn new() -> Self {
Self {
collected: Arc::new(Mutex::new(Vec::new())),
}
}
/// 収集されたToolCallを取得してクリア
pub fn take_collected(&self) -> Vec<ToolCall> {
let mut guard = self.collected.lock().unwrap();
std::mem::take(&mut *guard)
}
/// 収集されたToolCallの参照を取得
pub fn collected(&self) -> Vec<ToolCall> {
self.collected.lock().unwrap().clone()
}
/// 収集されたToolCallがあるかどうか
pub fn has_pending_calls(&self) -> bool {
!self.collected.lock().unwrap().is_empty()
}
/// 収集をクリア
pub fn clear(&self) {
self.collected.lock().unwrap().clear();
}
}
impl Default for ToolCallCollector {
fn default() -> Self {
Self::new()
}
}
impl Handler<ToolUseBlockKind> for ToolCallCollector {
type Scope = CollectorState;
fn on_event(&mut self, scope: &mut Self::Scope, event: &ToolUseBlockEvent) {
match event {
ToolUseBlockEvent::Start(start) => {
scope.current_id = Some(start.id.clone());
scope.current_name = Some(start.name.clone());
scope.input_json_buffer.clear();
}
ToolUseBlockEvent::InputJsonDelta(delta) => {
scope.input_json_buffer.push_str(delta);
}
ToolUseBlockEvent::Stop(_stop) => {
// ブロック完了時にToolCallを確定
if let (Some(id), Some(name)) = (scope.current_id.take(), scope.current_name.take())
{
let input = serde_json::from_str(&scope.input_json_buffer)
.unwrap_or(serde_json::Value::Null);
let tool_call = ToolCall { id, name, input };
self.collected.lock().unwrap().push(tool_call);
}
scope.input_json_buffer.clear();
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Timeline;
use worker_types::Event;
#[test]
fn test_collect_single_tool_call() {
let collector = ToolCallCollector::new();
let mut timeline = Timeline::new();
timeline.on_tool_use_block(collector.clone());
// ToolUseブロックのイベントシーケンスをディスパッチ
timeline.dispatch(&Event::tool_use_start(0, "tool_123", "get_weather"));
timeline.dispatch(&Event::tool_input_delta(0, r#"{"city":"#));
timeline.dispatch(&Event::tool_input_delta(0, r#""Tokyo"}"#));
timeline.dispatch(&Event::tool_use_stop(0));
// 収集されたToolCallを確認
let calls = collector.take_collected();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].id, "tool_123");
assert_eq!(calls[0].name, "get_weather");
assert_eq!(calls[0].input["city"], "Tokyo");
}
#[test]
fn test_collect_multiple_tool_calls() {
let collector = ToolCallCollector::new();
let mut timeline = Timeline::new();
timeline.on_tool_use_block(collector.clone());
// 1つ目のToolCall
timeline.dispatch(&Event::tool_use_start(0, "call_1", "tool_a"));
timeline.dispatch(&Event::tool_input_delta(0, r#"{"a":1}"#));
timeline.dispatch(&Event::tool_use_stop(0));
// 2つ目のToolCall
timeline.dispatch(&Event::tool_use_start(1, "call_2", "tool_b"));
timeline.dispatch(&Event::tool_input_delta(1, r#"{"b":2}"#));
timeline.dispatch(&Event::tool_use_stop(1));
let calls = collector.take_collected();
assert_eq!(calls.len(), 2);
assert_eq!(calls[0].name, "tool_a");
assert_eq!(calls[1].name, "tool_b");
}
}

359
worker/src/worker.rs Normal file
View File

@ -0,0 +1,359 @@
//! Worker - ターン制御を行う高レベルコンポーネント
//!
//! LlmClientとTimelineを内包し、Tool/Hookを用いて自律的なインタラクションを実現する。
use std::collections::HashMap;
use std::sync::Arc;
use futures::StreamExt;
use crate::llm_client::{ClientError, LlmClient, Request, ToolDefinition};
use crate::tool_call_collector::ToolCallCollector;
use crate::Timeline;
use worker_types::{
ControlFlow, HookError, Message, Tool, ToolCall, ToolError, ToolResult, TurnResult, WorkerHook,
};
// =============================================================================
// Worker Error
// =============================================================================
/// Workerエラー
#[derive(Debug, thiserror::Error)]
pub enum WorkerError {
/// クライアントエラー
#[error("Client error: {0}")]
Client(#[from] ClientError),
/// ツールエラー
#[error("Tool error: {0}")]
Tool(#[from] ToolError),
/// Hookエラー
#[error("Hook error: {0}")]
Hook(#[from] HookError),
/// 処理が中断された
#[error("Aborted: {0}")]
Aborted(String),
}
// =============================================================================
// Worker Config
// =============================================================================
/// Worker設定
#[derive(Debug, Clone)]
pub struct WorkerConfig {
/// 最大ターン数(無限ループ防止)
pub max_turns: usize,
}
impl Default for WorkerConfig {
fn default() -> Self {
Self { max_turns: 10 }
}
}
// =============================================================================
// Worker
// =============================================================================
/// Worker - ターン制御コンポーネント
///
/// # 責務
/// - LLMへのリクエスト送信とレスポンス処理
/// - ツール呼び出しの収集と実行
/// - Hookによる介入の提供
/// - ターンループの制御
pub struct Worker<C: LlmClient> {
/// LLMクライアント
client: C,
/// イベントタイムライン
timeline: Timeline,
/// ツールコレクターTimeline用ハンドラ
tool_call_collector: ToolCallCollector,
/// 登録されたツール
tools: HashMap<String, Arc<dyn Tool>>,
/// 登録されたHook
hooks: Vec<Box<dyn WorkerHook>>,
/// 設定
config: WorkerConfig,
}
impl<C: LlmClient> Worker<C> {
/// 新しいWorkerを作成
pub fn new(client: C) -> Self {
let tool_call_collector = ToolCallCollector::new();
let mut timeline = Timeline::new();
// ToolCallCollectorをTimelineに登録
timeline.on_tool_use_block(tool_call_collector.clone());
Self {
client,
timeline,
tool_call_collector,
tools: HashMap::new(),
hooks: Vec::new(),
config: WorkerConfig::default(),
}
}
/// 設定を適用
pub fn config(mut self, config: WorkerConfig) -> Self {
self.config = config;
self
}
/// ツールを登録
pub fn register_tool(&mut self, tool: impl Tool + 'static) {
let name = tool.name().to_string();
self.tools.insert(name, Arc::new(tool));
}
/// 複数のツールを登録
pub fn register_tools(&mut self, tools: impl IntoIterator<Item = impl Tool + 'static>) {
for tool in tools {
self.register_tool(tool);
}
}
/// Hookを追加
pub fn add_hook(&mut self, hook: impl WorkerHook + 'static) {
self.hooks.push(Box::new(hook));
}
/// タイムラインへの可変参照を取得(追加ハンドラ登録用)
pub fn timeline_mut(&mut self) -> &mut Timeline {
&mut self.timeline
}
/// 登録されたツールからToolDefinitionのリストを生成
fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
self.tools
.values()
.map(|tool| {
ToolDefinition::new(tool.name())
.description(tool.description())
.input_schema(tool.input_schema())
})
.collect()
}
/// ターンを実行
///
/// メッセージを送信し、レスポンスを処理する。
/// ツール呼び出しがある場合は自動的にループする。
pub async fn run(&mut self, messages: Vec<Message>) -> Result<Vec<Message>, WorkerError> {
let mut context = messages;
let tool_definitions = self.build_tool_definitions();
for _turn in 0..self.config.max_turns {
// Hook: on_message_send
let control = self.run_on_message_send_hooks(&mut context).await?;
if let ControlFlow::Abort(reason) = control {
return Err(WorkerError::Aborted(reason));
}
// リクエスト構築
let request = self.build_request(&context, &tool_definitions);
// ストリーム処理
let mut stream = self.client.stream(request).await?;
while let Some(event_result) = stream.next().await {
let event = event_result?;
self.timeline.dispatch(&event);
}
// ツール呼び出しの収集結果を取得
let tool_calls = self.tool_call_collector.take_collected();
if tool_calls.is_empty() {
// ツール呼び出しなし → ターン終了判定
let turn_result = self.run_on_turn_end_hooks(&context).await?;
match turn_result {
TurnResult::Finish => {
return Ok(context);
}
TurnResult::ContinueWithMessages(additional) => {
context.extend(additional);
continue;
}
}
}
// ツール実行
let tool_results = self.execute_tools(tool_calls).await?;
// ツール結果をコンテキストに追加
for result in tool_results {
context.push(Message::tool_result(&result.tool_use_id, &result.content));
}
}
// 最大ターン数到達
Err(WorkerError::Aborted(format!(
"Maximum turns ({}) reached",
self.config.max_turns
)))
}
/// リクエストを構築
fn build_request(&self, context: &[Message], tool_definitions: &[ToolDefinition]) -> Request {
let mut request = Request::new();
// メッセージを追加
for msg in context {
// worker-types::Message から llm_client::Message への変換
request = request.message(crate::llm_client::Message {
role: match msg.role {
worker_types::Role::User => crate::llm_client::Role::User,
worker_types::Role::Assistant => crate::llm_client::Role::Assistant,
},
content: match &msg.content {
worker_types::MessageContent::Text(t) => {
crate::llm_client::MessageContent::Text(t.clone())
}
worker_types::MessageContent::ToolResult {
tool_use_id,
content,
} => crate::llm_client::MessageContent::ToolResult {
tool_use_id: tool_use_id.clone(),
content: content.clone(),
},
worker_types::MessageContent::Parts(parts) => {
crate::llm_client::MessageContent::Parts(
parts
.iter()
.map(|p| match p {
worker_types::ContentPart::Text { text } => {
crate::llm_client::ContentPart::Text { text: text.clone() }
}
worker_types::ContentPart::ToolUse { id, name, input } => {
crate::llm_client::ContentPart::ToolUse {
id: id.clone(),
name: name.clone(),
input: input.clone(),
}
}
worker_types::ContentPart::ToolResult {
tool_use_id,
content,
} => crate::llm_client::ContentPart::ToolResult {
tool_use_id: tool_use_id.clone(),
content: content.clone(),
},
})
.collect(),
)
}
},
});
}
// ツール定義を追加
for tool_def in tool_definitions {
request = request.tool(tool_def.clone());
}
request
}
/// Hooks: on_message_send
async fn run_on_message_send_hooks(
&self,
context: &mut Vec<Message>,
) -> Result<ControlFlow, WorkerError> {
for hook in &self.hooks {
let result = hook.on_message_send(context).await?;
match result {
ControlFlow::Continue => continue,
ControlFlow::Skip => return Ok(ControlFlow::Skip),
ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)),
}
}
Ok(ControlFlow::Continue)
}
/// Hooks: on_turn_end
async fn run_on_turn_end_hooks(
&self,
messages: &[Message],
) -> Result<TurnResult, WorkerError> {
for hook in &self.hooks {
let result = hook.on_turn_end(messages).await?;
match result {
TurnResult::Finish => continue,
TurnResult::ContinueWithMessages(msgs) => {
return Ok(TurnResult::ContinueWithMessages(msgs));
}
}
}
Ok(TurnResult::Finish)
}
/// ツールを並列実行
async fn execute_tools(
&self,
mut tool_calls: Vec<ToolCall>,
) -> Result<Vec<ToolResult>, WorkerError> {
let mut results = Vec::new();
// TODO: 将来的には join_all で並列実行
// 現在は逐次実行
for mut tool_call in tool_calls.drain(..) {
// Hook: before_tool_call
let mut skip = false;
for hook in &self.hooks {
let result = hook.before_tool_call(&mut tool_call).await?;
match result {
ControlFlow::Continue => {}
ControlFlow::Skip => {
skip = true;
break;
}
ControlFlow::Abort(reason) => {
return Err(WorkerError::Aborted(reason));
}
}
}
if skip {
continue;
}
// ツール実行
let mut tool_result = if let Some(tool) = self.tools.get(&tool_call.name) {
let input_json = serde_json::to_string(&tool_call.input).unwrap_or_default();
match tool.execute(&input_json).await {
Ok(content) => ToolResult::success(&tool_call.id, content),
Err(e) => ToolResult::error(&tool_call.id, e.to_string()),
}
} else {
ToolResult::error(
&tool_call.id,
format!("Tool '{}' not found", tool_call.name),
)
};
// Hook: after_tool_call
for hook in &self.hooks {
let result = hook.after_tool_call(&mut tool_result).await?;
match result {
ControlFlow::Continue => {}
ControlFlow::Skip => break,
ControlFlow::Abort(reason) => {
return Err(WorkerError::Aborted(reason));
}
}
}
results.push(tool_result);
}
Ok(results)
}
}
#[cfg(test)]
mod tests {
// 基本的なテストのみ。LlmClientを使ったテストは統合テストで行う。
}

View File

@ -1,14 +1,22 @@
//! テスト用のAPIレスポンス記録・再生機能
//! テスト用共通ユーティリティ
//!
//! 実際のAPIレスポンスをタイムスタンプ付きで記録し、
//! テスト時に再生できるようにする。
//! MockLlmClient、イベントレコーダー・プレイヤーを提供する
use std::fs::File;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::Path;
use std::pin::Pin;
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use worker::llm_client::{ClientError, LlmClient, Request};
use worker_types::Event;
// =============================================================================
// Recorded Event Types
// =============================================================================
/// 記録されたSSEイベント
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -32,15 +40,21 @@ pub struct SessionMetadata {
pub description: String,
}
// =============================================================================
// Event Recorder
// =============================================================================
/// SSEイベントレコーダー
///
/// 実際のAPIレスポンスを記録し、後でテストに使用できるようにする
#[allow(dead_code)]
pub struct EventRecorder {
start_time: Instant,
events: Vec<RecordedEvent>,
metadata: SessionMetadata,
}
#[allow(dead_code)]
impl EventRecorder {
/// 新しいレコーダーを作成
pub fn new(model: impl Into<String>, description: impl Into<String>) -> Self {
@ -97,15 +111,21 @@ impl EventRecorder {
}
}
// =============================================================================
// Event Player
// =============================================================================
/// SSEイベントプレイヤー
///
/// 記録されたイベントを読み込み、テストで使用する
#[allow(dead_code)]
pub struct EventPlayer {
metadata: SessionMetadata,
events: Vec<RecordedEvent>,
current_index: usize,
}
#[allow(dead_code)]
impl EventPlayer {
/// ファイルから読み込み
pub fn load(path: impl AsRef<Path>) -> std::io::Result<Self> {
@ -166,73 +186,55 @@ impl EventPlayer {
pub fn reset(&mut self) {
self.current_index = 0;
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
#[test]
fn test_record_and_playback() {
// レコーダーを作成して記録
let mut recorder = EventRecorder::new("claude-sonnet-4-20250514", "Test recording");
recorder.record("message_start", r#"{"type":"message_start"}"#);
recorder.record(
"content_block_start",
r#"{"type":"content_block_start","index":0}"#,
);
recorder.record(
"content_block_delta",
r#"{"type":"content_block_delta","delta":{"type":"text_delta","text":"Hello"}}"#,
);
// 一時ファイルに保存
let temp_file = NamedTempFile::new().unwrap();
recorder.save(temp_file.path()).unwrap();
// 読み込んで確認
let player = EventPlayer::load(temp_file.path()).unwrap();
assert_eq!(player.metadata().model, "claude-sonnet-4-20250514");
assert_eq!(player.event_count(), 3);
assert_eq!(player.events()[0].event_type, "message_start");
assert_eq!(player.events()[2].event_type, "content_block_delta");
}
#[test]
fn test_player_iteration() {
// テストデータを直接作成
let mut temp_file = NamedTempFile::new().unwrap();
writeln!(
temp_file,
r#"{{"timestamp":1704067200,"model":"test","description":"test"}}"#
)
.unwrap();
writeln!(
temp_file,
r#"{{"elapsed_ms":0,"event_type":"ping","data":"{{}}"}}"#
)
.unwrap();
writeln!(
temp_file,
r#"{{"elapsed_ms":100,"event_type":"message_stop","data":"{{}}"}}"#
)
.unwrap();
temp_file.flush().unwrap();
let mut player = EventPlayer::load(temp_file.path()).unwrap();
let first = player.next_event().unwrap();
assert_eq!(first.event_type, "ping");
let second = player.next_event().unwrap();
assert_eq!(second.event_type, "message_stop");
assert!(player.next_event().is_none());
// リセット後は最初から
player.reset();
assert_eq!(player.next_event().unwrap().event_type, "ping");
/// 全イベントをworker_types::Eventとしてパースして取得
pub fn parse_events(&self) -> Vec<Event> {
self.events
.iter()
.filter_map(|recorded| serde_json::from_str(&recorded.data).ok())
.collect()
}
}
// =============================================================================
// MockLlmClient
// =============================================================================
/// テスト用のモックLLMクライアント
///
/// 事前に定義されたイベントシーケンスをストリームとして返す。
/// fixtureファイルからロードすることも、直接イベントを渡すこともできる。
pub struct MockLlmClient {
events: Vec<Event>,
}
impl MockLlmClient {
/// イベントリストから直接作成
pub fn new(events: Vec<Event>) -> Self {
Self { events }
}
/// fixtureファイルからロード
pub fn from_fixture(path: impl AsRef<Path>) -> std::io::Result<Self> {
let player = EventPlayer::load(path)?;
let events = player.parse_events();
Ok(Self { events })
}
/// 保持しているイベント数を取得
pub fn event_count(&self) -> usize {
self.events.len()
}
}
#[async_trait]
impl LlmClient for MockLlmClient {
async fn stream(
&self,
_request: Request,
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError> {
let events = self.events.clone();
let stream = futures::stream::iter(events.into_iter().map(Ok));
Ok(Box::pin(stream))
}
}

16
worker/tests/fixtures/tool_call.jsonl vendored Normal file
View File

@ -0,0 +1,16 @@
{"timestamp":1767692881,"model":"claude-sonnet-4-20250514","description":"Tool call response"}
{"elapsed_ms":1783,"event_type":"Discriminant(1)","data":"{\"Usage\":{\"input_tokens\":409,\"output_tokens\":3,\"total_tokens\":412,\"cache_read_input_tokens\":0,\"cache_creation_input_tokens\":0}}"}
{"elapsed_ms":1783,"event_type":"Discriminant(4)","data":"{\"BlockStart\":{\"index\":0,\"block_type\":\"Text\",\"metadata\":\"Text\"}}"}
{"elapsed_ms":1783,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":0,\"delta\":{\"Text\":\"I'll check\"}}}"}
{"elapsed_ms":1883,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":0,\"delta\":{\"Text\":\" the current\"}}}"}
{"elapsed_ms":2063,"event_type":"Discriminant(0)","data":"{\"Ping\":{\"timestamp\":null}}"}
{"elapsed_ms":2063,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":0,\"delta\":{\"Text\":\" weather in Tokyo for you using\"}}}"}
{"elapsed_ms":2124,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":0,\"delta\":{\"Text\":\" the get_weather tool.\"}}}"}
{"elapsed_ms":2252,"event_type":"Discriminant(6)","data":"{\"BlockStop\":{\"index\":0,\"block_type\":\"Text\",\"stop_reason\":null}}"}
{"elapsed_ms":2253,"event_type":"Discriminant(4)","data":"{\"BlockStart\":{\"index\":1,\"block_type\":\"ToolUse\",\"metadata\":{\"ToolUse\":{\"id\":\"toolu_011Hg5wju1LGL7F65HyfE6bM\",\"name\":\"get_weather\"}}}}"}
{"elapsed_ms":2253,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":1,\"delta\":{\"InputJson\":\"\"}}}"}
{"elapsed_ms":2306,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":1,\"delta\":{\"InputJson\":\"{\\\"city\\\": \\\"Tokyo\"}}}"}
{"elapsed_ms":2451,"event_type":"Discriminant(5)","data":"{\"BlockDelta\":{\"index\":1,\"delta\":{\"InputJson\":\"\\\"}\"}}}"}
{"elapsed_ms":2451,"event_type":"Discriminant(6)","data":"{\"BlockStop\":{\"index\":1,\"block_type\":\"Text\",\"stop_reason\":null}}"}
{"elapsed_ms":2464,"event_type":"Discriminant(1)","data":"{\"Usage\":{\"input_tokens\":409,\"output_tokens\":71,\"total_tokens\":480,\"cache_read_input_tokens\":0,\"cache_creation_input_tokens\":0}}"}
{"elapsed_ms":2470,"event_type":"Discriminant(2)","data":"{\"Status\":{\"status\":\"Completed\"}}"}

View File

@ -0,0 +1,243 @@
//! Workerフィクスチャベースの統合テスト
//!
//! 記録されたAPIレスポンスを使ってWorkerの動作をテストする。
//! APIキー不要でローカルで実行可能。
mod common;
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use common::MockLlmClient;
use worker::{Worker, WorkerConfig};
use worker_types::{Tool, ToolError};
/// フィクスチャディレクトリのパス
fn fixtures_dir() -> std::path::PathBuf {
Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures")
}
/// シンプルなテスト用ツール
#[derive(Clone)]
struct MockWeatherTool {
call_count: Arc<AtomicUsize>,
}
impl MockWeatherTool {
fn new() -> Self {
Self {
call_count: Arc::new(AtomicUsize::new(0)),
}
}
fn get_call_count(&self) -> usize {
self.call_count.load(Ordering::SeqCst)
}
}
#[async_trait]
impl Tool for MockWeatherTool {
fn name(&self) -> &str {
"get_weather"
}
fn description(&self) -> &str {
"Get the current weather for a city"
}
fn input_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name"
}
},
"required": ["city"]
})
}
async fn execute(&self, input_json: &str) -> Result<String, ToolError> {
self.call_count.fetch_add(1, Ordering::SeqCst);
// 入力をパース
let input: serde_json::Value = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(e.to_string()))?;
let city = input["city"]
.as_str()
.unwrap_or("Unknown");
// モックのレスポンスを返す
Ok(format!("Weather in {}: Sunny, 22°C", city))
}
}
// =============================================================================
// Basic Fixture Tests
// =============================================================================
/// MockLlmClientがJSONLフィクスチャファイルから正しくイベントをロードできることを確認
///
/// 既存のanthropic_*.jsonlファイルを使用し、イベントがパース・ロードされることを検証する。
#[test]
fn test_mock_client_from_fixture() {
// 既存のフィクスチャをロード
let fixture_path = fixtures_dir().join("anthropic_1767624445.jsonl");
if !fixture_path.exists() {
println!("Fixture not found, skipping test");
return;
}
let client = MockLlmClient::from_fixture(&fixture_path).unwrap();
assert!(client.event_count() > 0, "Should have loaded events");
println!("Loaded {} events from fixture", client.event_count());
}
/// MockLlmClientが直接指定されたイベントリストで正しく動作することを確認
///
/// fixtureファイルを使わず、プログラムでイベントを構築してクライアントを作成する。
#[test]
fn test_mock_client_from_events() {
use worker_types::Event;
// 直接イベントを指定
let events = vec![
Event::text_block_start(0),
Event::text_delta(0, "Hello!"),
Event::text_block_stop(0, None),
];
let client = MockLlmClient::new(events);
assert_eq!(client.event_count(), 3);
}
// =============================================================================
// Worker Tests with Fixtures
// =============================================================================
/// Workerがシンプルなテキストレスポンスを正しく処理できることを確認
///
/// simple_text.jsonlフィクスチャを使用し、ツール呼び出しなしのシナリオをテストする。
/// フィクスチャがない場合はスキップされる。
#[tokio::test]
async fn test_worker_simple_text_response() {
let fixture_path = fixtures_dir().join("simple_text.jsonl");
if !fixture_path.exists() {
println!("Fixture not found: {:?}, skipping test", fixture_path);
println!("Run: cargo run --example record_worker_test");
return;
}
let client = MockLlmClient::from_fixture(&fixture_path).unwrap();
let mut worker = Worker::new(client);
// シンプルなメッセージを送信
let messages = vec![worker_types::Message::user("Hello")];
let result = worker.run(messages).await;
assert!(result.is_ok(), "Worker should complete successfully");
}
/// Workerがツール呼び出しを含むレスポンスを正しく処理できることを確認
///
/// tool_call.jsonlフィクスチャを使用し、MockWeatherToolが呼び出されることをテストする。
/// max_turns=1に設定し、ツール実行後のループを防止。
#[tokio::test]
async fn test_worker_tool_call() {
let fixture_path = fixtures_dir().join("tool_call.jsonl");
if !fixture_path.exists() {
println!("Fixture not found: {:?}, skipping test", fixture_path);
println!("Run: cargo run --example record_worker_test");
return;
}
let client = MockLlmClient::from_fixture(&fixture_path).unwrap();
let mut worker = Worker::new(client);
// ツールを登録
let weather_tool = MockWeatherTool::new();
let tool_for_check = weather_tool.clone();
worker.register_tool(weather_tool);
// 設定: ツール実行後はターン終了(ループしない)
worker = worker.config(WorkerConfig { max_turns: 1 });
// メッセージを送信
let messages = vec![worker_types::Message::user("What's the weather in Tokyo?")];
let _result = worker.run(messages).await;
// ツールが呼び出されたことを確認
// Note: max_turns=1なのでツール結果後のリクエストは送信されない
let call_count = tool_for_check.get_call_count();
println!("Tool was called {} times", call_count);
// フィクスチャにToolUseが含まれていればツールが呼び出されるはず
// ただしmax_turns=1なので1回で終了
}
/// fixtureファイルなしでWorkerが動作することを確認
///
/// プログラムでイベントシーケンスを構築し、MockLlmClientに渡してテストする。
/// テストの独立性を高め、外部ファイルへの依存を排除したい場合に有用。
#[tokio::test]
async fn test_worker_with_programmatic_events() {
use worker_types::{Event, ResponseStatus, StatusEvent};
// プログラムでイベントシーケンスを構築
let events = vec![
Event::text_block_start(0),
Event::text_delta(0, "Hello, "),
Event::text_delta(0, "World!"),
Event::text_block_stop(0, None),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
];
let client = MockLlmClient::new(events);
let mut worker = Worker::new(client);
let messages = vec![worker_types::Message::user("Greet me")];
let result = worker.run(messages).await;
assert!(result.is_ok(), "Worker should complete successfully");
}
/// ToolCallCollectorがToolUseブロックイベントから正しくToolCallを収集することを確認
///
/// Timelineにイベントをディスパッチし、ToolCallCollectorが
/// id, name, inputJSONを正しく抽出できることを検証する。
#[tokio::test]
async fn test_tool_call_collector_integration() {
use worker::ToolCallCollector;
use worker::Timeline;
use worker_types::Event;
// ToolUseブロックを含むイベントシーケンス
let events = vec![
Event::tool_use_start(0, "call_123", "get_weather"),
Event::tool_input_delta(0, r#"{"city":"#),
Event::tool_input_delta(0, r#""Tokyo"}"#),
Event::tool_use_stop(0),
];
let collector = ToolCallCollector::new();
let mut timeline = Timeline::new();
timeline.on_tool_use_block(collector.clone());
// イベントをディスパッチ
for event in &events {
timeline.dispatch(event);
}
// 収集されたToolCallを確認
let calls = collector.take_collected();
assert_eq!(calls.len(), 1, "Should collect one tool call");
assert_eq!(calls[0].name, "get_weather");
assert_eq!(calls[0].id, "call_123");
assert_eq!(calls[0].input["city"], "Tokyo");
}