308 lines
9.8 KiB
Rust
308 lines
9.8 KiB
Rust
//! テスト用共通ユーティリティ
|
||
//!
|
||
//! MockLlmClient、イベントレコーダー・プレイヤーを提供する
|
||
|
||
use std::fs::File;
|
||
use std::io::{BufRead, BufReader, BufWriter, Write};
|
||
use std::path::Path;
|
||
use std::pin::Pin;
|
||
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||
|
||
use async_trait::async_trait;
|
||
use futures::Stream;
|
||
use serde::{Deserialize, Serialize};
|
||
use worker::llm_client::{ClientError, LlmClient, Request};
|
||
use worker_types::Event;
|
||
|
||
// =============================================================================
|
||
// Recorded Event Types
|
||
// =============================================================================
|
||
|
||
/// 記録されたSSEイベント
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct RecordedEvent {
|
||
/// イベント受信からの経過時間 (ミリ秒)
|
||
pub elapsed_ms: u64,
|
||
/// SSEイベントタイプ
|
||
pub event_type: String,
|
||
/// SSEイベントデータ
|
||
pub data: String,
|
||
}
|
||
|
||
/// セッションメタデータ
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct SessionMetadata {
|
||
/// 記録開始タイムスタンプ (Unix epoch秒)
|
||
pub timestamp: u64,
|
||
/// モデル名
|
||
pub model: String,
|
||
/// リクエストの説明
|
||
pub description: String,
|
||
}
|
||
|
||
// =============================================================================
|
||
// Event Recorder
|
||
// =============================================================================
|
||
|
||
/// SSEイベントレコーダー
|
||
///
|
||
/// 実際のAPIレスポンスを記録し、後でテストに使用できるようにする
|
||
#[allow(dead_code)]
|
||
pub struct EventRecorder {
|
||
start_time: Instant,
|
||
events: Vec<RecordedEvent>,
|
||
metadata: SessionMetadata,
|
||
}
|
||
|
||
#[allow(dead_code)]
|
||
impl EventRecorder {
|
||
/// 新しいレコーダーを作成
|
||
pub fn new(model: impl Into<String>, description: impl Into<String>) -> Self {
|
||
let timestamp = SystemTime::now()
|
||
.duration_since(UNIX_EPOCH)
|
||
.unwrap()
|
||
.as_secs();
|
||
|
||
Self {
|
||
start_time: Instant::now(),
|
||
events: Vec::new(),
|
||
metadata: SessionMetadata {
|
||
timestamp,
|
||
model: model.into(),
|
||
description: description.into(),
|
||
},
|
||
}
|
||
}
|
||
|
||
/// イベントを記録
|
||
pub fn record(&mut self, event_type: &str, data: &str) {
|
||
let elapsed = self.start_time.elapsed();
|
||
self.events.push(RecordedEvent {
|
||
elapsed_ms: elapsed.as_millis() as u64,
|
||
event_type: event_type.to_string(),
|
||
data: data.to_string(),
|
||
});
|
||
}
|
||
|
||
/// 記録をファイルに保存
|
||
///
|
||
/// フォーマット: JSONL (1行目: metadata, 2行目以降: events)
|
||
pub fn save(&self, path: impl AsRef<Path>) -> std::io::Result<()> {
|
||
let file = File::create(path)?;
|
||
let mut writer = BufWriter::new(file);
|
||
|
||
// メタデータを書き込み
|
||
let metadata_json = serde_json::to_string(&self.metadata)?;
|
||
writeln!(writer, "{}", metadata_json)?;
|
||
|
||
// イベントを書き込み
|
||
for event in &self.events {
|
||
let event_json = serde_json::to_string(event)?;
|
||
writeln!(writer, "{}", event_json)?;
|
||
}
|
||
|
||
writer.flush()?;
|
||
Ok(())
|
||
}
|
||
|
||
/// 記録されたイベント数を取得
|
||
pub fn event_count(&self) -> usize {
|
||
self.events.len()
|
||
}
|
||
}
|
||
|
||
// =============================================================================
|
||
// Event Player
|
||
// =============================================================================
|
||
|
||
/// SSEイベントプレイヤー
|
||
///
|
||
/// 記録されたイベントを読み込み、テストで使用する
|
||
#[allow(dead_code)]
|
||
pub struct EventPlayer {
|
||
metadata: SessionMetadata,
|
||
events: Vec<RecordedEvent>,
|
||
current_index: usize,
|
||
}
|
||
|
||
#[allow(dead_code)]
|
||
impl EventPlayer {
|
||
/// ファイルから読み込み
|
||
pub fn load(path: impl AsRef<Path>) -> std::io::Result<Self> {
|
||
let file = File::open(path)?;
|
||
let reader = BufReader::new(file);
|
||
let mut lines = reader.lines();
|
||
|
||
// メタデータを読み込み
|
||
let metadata_line = lines
|
||
.next()
|
||
.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "Empty file"))??;
|
||
let metadata: SessionMetadata = serde_json::from_str(&metadata_line)?;
|
||
|
||
// イベントを読み込み
|
||
let mut events = Vec::new();
|
||
for line in lines {
|
||
let line = line?;
|
||
if !line.is_empty() {
|
||
let event: RecordedEvent = serde_json::from_str(&line)?;
|
||
events.push(event);
|
||
}
|
||
}
|
||
|
||
Ok(Self {
|
||
metadata,
|
||
events,
|
||
current_index: 0,
|
||
})
|
||
}
|
||
|
||
/// メタデータを取得
|
||
pub fn metadata(&self) -> &SessionMetadata {
|
||
&self.metadata
|
||
}
|
||
|
||
/// 全イベントを取得
|
||
pub fn events(&self) -> &[RecordedEvent] {
|
||
&self.events
|
||
}
|
||
|
||
/// イベント数を取得
|
||
pub fn event_count(&self) -> usize {
|
||
self.events.len()
|
||
}
|
||
|
||
/// 次のイベントを取得(Iterator的に使用)
|
||
pub fn next_event(&mut self) -> Option<&RecordedEvent> {
|
||
if self.current_index < self.events.len() {
|
||
let event = &self.events[self.current_index];
|
||
self.current_index += 1;
|
||
Some(event)
|
||
} else {
|
||
None
|
||
}
|
||
}
|
||
|
||
/// インデックスをリセット
|
||
pub fn reset(&mut self) {
|
||
self.current_index = 0;
|
||
}
|
||
|
||
/// 全イベントをworker_types::Eventとしてパースして取得
|
||
pub fn parse_events(&self) -> Vec<Event> {
|
||
self.events
|
||
.iter()
|
||
.filter_map(|recorded| serde_json::from_str(&recorded.data).ok())
|
||
.collect()
|
||
}
|
||
}
|
||
|
||
// =============================================================================
|
||
// MockLlmClient
|
||
// =============================================================================
|
||
|
||
/// テスト用のモックLLMクライアント
|
||
///
|
||
/// 事前に定義されたイベントシーケンスをストリームとして返す。
|
||
/// fixtureファイルからロードすることも、直接イベントを渡すこともできる。
|
||
///
|
||
/// # 複数リクエスト対応
|
||
///
|
||
/// `with_responses()`を使用して、複数回のリクエストに対して異なるレスポンスを設定できる。
|
||
/// リクエスト回数が設定されたレスポンス数を超えた場合は空のストリームを返す。
|
||
pub struct MockLlmClient {
|
||
/// 各リクエストに対するレスポンス(イベントシーケンス)
|
||
responses: std::sync::Arc<std::sync::Mutex<Vec<Vec<Event>>>>,
|
||
/// 現在のリクエストインデックス
|
||
request_index: std::sync::Arc<std::sync::atomic::AtomicUsize>,
|
||
}
|
||
|
||
#[allow(dead_code)]
|
||
impl MockLlmClient {
|
||
/// イベントリストから直接作成(単一レスポンス)
|
||
///
|
||
/// すべてのリクエストに対して同じイベントシーケンスを返す(従来の動作)
|
||
pub fn new(events: Vec<Event>) -> Self {
|
||
Self {
|
||
responses: std::sync::Arc::new(std::sync::Mutex::new(vec![events])),
|
||
request_index: std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)),
|
||
}
|
||
}
|
||
|
||
/// 複数のレスポンスを設定
|
||
///
|
||
/// 各リクエストに対して順番にイベントシーケンスを返す。
|
||
/// N回目のリクエストにはN番目のレスポンスが使用される。
|
||
///
|
||
/// # Example
|
||
/// ```ignore
|
||
/// let client = MockLlmClient::with_responses(vec![
|
||
/// // 1回目のリクエスト: ツール呼び出し
|
||
/// vec![Event::tool_use_start(0, "call_1", "my_tool"), ...],
|
||
/// // 2回目のリクエスト: テキストレスポンス
|
||
/// vec![Event::text_block_start(0), ...],
|
||
/// ]);
|
||
/// ```
|
||
pub fn with_responses(responses: Vec<Vec<Event>>) -> Self {
|
||
Self {
|
||
responses: std::sync::Arc::new(std::sync::Mutex::new(responses)),
|
||
request_index: std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)),
|
||
}
|
||
}
|
||
|
||
/// fixtureファイルからロード(単一レスポンス)
|
||
pub fn from_fixture(path: impl AsRef<Path>) -> std::io::Result<Self> {
|
||
let player = EventPlayer::load(path)?;
|
||
let events = player.parse_events();
|
||
Ok(Self::new(events))
|
||
}
|
||
|
||
/// 保持しているレスポンス数を取得
|
||
pub fn response_count(&self) -> usize {
|
||
self.responses.lock().unwrap().len()
|
||
}
|
||
|
||
/// 最初のレスポンスのイベント数を取得(後方互換性)
|
||
pub fn event_count(&self) -> usize {
|
||
self.responses
|
||
.lock()
|
||
.unwrap()
|
||
.first()
|
||
.map(|v| v.len())
|
||
.unwrap_or(0)
|
||
}
|
||
|
||
/// 現在のリクエストインデックスを取得
|
||
pub fn current_request_index(&self) -> usize {
|
||
self.request_index.load(std::sync::atomic::Ordering::SeqCst)
|
||
}
|
||
|
||
/// リクエストインデックスをリセット
|
||
pub fn reset(&self) {
|
||
self.request_index.store(0, std::sync::atomic::Ordering::SeqCst);
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl LlmClient for MockLlmClient {
|
||
async fn stream(
|
||
&self,
|
||
_request: Request,
|
||
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError> {
|
||
let index = self.request_index.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||
|
||
let events = {
|
||
let responses = self.responses.lock().unwrap();
|
||
if index < responses.len() {
|
||
responses[index].clone()
|
||
} else {
|
||
// レスポンスが尽きた場合は空のストリーム
|
||
Vec::new()
|
||
}
|
||
};
|
||
|
||
let stream = futures::stream::iter(events.into_iter().map(Ok));
|
||
Ok(Box::pin(stream))
|
||
}
|
||
}
|
||
|