diff --git a/docs/spec/cache_lock.md b/docs/spec/cache_lock.md new file mode 100644 index 0000000..f2e8a00 --- /dev/null +++ b/docs/spec/cache_lock.md @@ -0,0 +1,68 @@ +# KVキャッシュを中心とした設計 + +LLMのKVキャッシュのヒット率を重要なメトリクスであるとし、APIレベルでキャッシュ操作を中心とした設計を行う。 + +## 前提 + +リクエスト間キャッシュ(Context Caching)は、複数のリクエストで同じ入力トークン列が繰り返された際、プロバイダ側が計算済みの状態を再利用することでレイテンシと入力コストを下げる仕組みである。 +キャッシュは主に**先頭一致 (Common Prefix)** によってHitするため、前提となるシステムプロンプトや、会話ログの過去部分(前方)を変化させると、以降のキャッシュは無効となる。 + +## 要件 + +1. **前方不変性の保証 (Prefix Immutability)** + * 後方に会話が追加されても、前方のデータ(システムプロンプトや確定済みのメッセージ履歴)が変化しないことをAPIレベルで保証する。 + * これにより、意図しないキャッシュミス(Cache Miss)を防ぐ。 + +2. **データ上の再現性** + * コンテキストのデータ構造が同一であれば、生成されるリクエスト構造も同一であることを保証する。 + * シリアライズ結果のバイト単位の完全一致までは求めないが、論理的なリクエスト構造は保たれる必要がある。 + +## アプローチ: Type-state Pattern + +RustのType-stateパターンを利用し、Workerの状態によって利用可能な操作をコンパイル時に制限する。 + +### 1. 状態定義 + +* **`Mutable` (初期状態)** + * 自由な編集が可能な状態。 + * システムプロンプトの設定・変更が可能。 + * メッセージ履歴の初期構築(ロード、編集)が可能。 +* **`Locked` (キャッシュ保護状態)** + * キャッシュの有効活用を目的とした、前方不変状態。 + * **システムプロンプトの変更不可**。 + * **既存メッセージ履歴の変更不可**(追記のみ許可)。 + * 実行(`run`)はこの状態で行うことを推奨する。 + +### 2. 状態遷移とAPIイメージ + +`Worker` 自身がコンテキスト(履歴)のオーナーとなり、状態によってアクセサを制限する。 + +```rust +// 1. Mutable状態で初期化 +let mut worker: Worker = Worker::new(client); + +// 2. コンテキストの構築 (Mutableなので自由に変更可) +worker.set_system_prompt("You are a helpful assistant."); +worker.history_mut().push(initial_message); + +// 3. ロックしてLocked状態へ遷移 +// これにより、ここまでのコンテキストが "Fixed Prefix" として扱われる +let mut locked_worker: Worker = worker.lock(); + +// 4. 利用 (Locked状態) +// 実行は可能。新しいメッセージは履歴の末尾に追記される。 +// 前方の履歴やシステムプロンプトは変更できないため、キャッシュヒットが保証される。 +locked_worker.run(new_user_input).await?; + +// NG操作 (コンパイルエラー) +// locked_worker.set_system_prompt("New prompt"); +// locked_worker.history_mut().clear(); +``` + +### 3. 実装への影響 + +現在の `Worker` 実装に対し、以下の変更が必要となる。 + +* **状態パラメータの導入**: `Worker` の導入。 +* **コンテキスト所有権の委譲**: `run` メソッドの引数でコンテキストを受け取るのではなく、`Worker` 内部に `history: Vec` を保持し管理する形へ移行する。 +* **APIの分離**: `Mutable` 特有のメソッド(setter等)と、`Locked` でも使えるメソッド(実行、参照等)をトレイト境界で分離する。 diff --git a/worker-types/src/lib.rs b/worker-types/src/lib.rs index 7497ced..2a69777 100644 --- a/worker-types/src/lib.rs +++ b/worker-types/src/lib.rs @@ -12,6 +12,7 @@ mod event; mod handler; mod hook; mod message; +mod state; mod subscriber; mod tool; @@ -19,5 +20,6 @@ pub use event::*; pub use handler::*; pub use hook::*; pub use message::*; +pub use state::*; pub use subscriber::*; pub use tool::*; diff --git a/worker-types/src/state.rs b/worker-types/src/state.rs new file mode 100644 index 0000000..eecdd73 --- /dev/null +++ b/worker-types/src/state.rs @@ -0,0 +1,40 @@ +//! Worker状態マーカー型 +//! +//! Type-stateパターンによるキャッシュ保護のための状態定義 + +/// Worker状態を表すマーカートレイト +/// +/// このトレイトはシールされており、外部から実装することはできない。 +pub trait WorkerState: private::Sealed + Send + Sync + 'static {} + +mod private { + pub trait Sealed {} +} + +/// 変更可能状態 +/// +/// この状態では以下の操作が可能: +/// - システムプロンプトの設定・変更 +/// - メッセージ履歴の編集(追加、削除、クリア) +/// - ツール・Hookの登録 +/// +/// `lock()` によって `Locked` 状態へ遷移できる。 +#[derive(Debug, Clone, Copy, Default)] +pub struct Mutable; + +impl private::Sealed for Mutable {} +impl WorkerState for Mutable {} + +/// ロック状態(キャッシュ保護) +/// +/// この状態では以下の制限がある: +/// - システムプロンプトの変更不可 +/// - 既存メッセージ履歴の変更不可(末尾への追記のみ) +/// +/// 実行(`run`)はこの状態で行うことが推奨される。 +/// キャッシュヒットを保証するため、前方のコンテキストは不変となる。 +#[derive(Debug, Clone, Copy, Default)] +pub struct Locked; + +impl private::Sealed for Locked {} +impl WorkerState for Locked {} diff --git a/worker/examples/worker_cli.rs b/worker/examples/worker_cli.rs index 5866e7b..be046b2 100644 --- a/worker/examples/worker_cli.rs +++ b/worker/examples/worker_cli.rs @@ -51,7 +51,6 @@ use worker::{ }, }; use worker_macros::tool_registry; -use worker_types::Message; // 必要なマクロ展開用インポート use schemars; @@ -453,14 +452,9 @@ async fn main() -> Result<(), Box> { worker.add_hook(ToolResultPrinterHook::new(tool_call_names)); - // 会話履歴 - let mut history: Vec = Vec::new(); - // ワンショットモード if let Some(prompt) = args.prompt { - history.push(Message::user(&prompt)); - - match worker.run(history).await { + match worker.run(&prompt).await { Ok(_) => {} Err(e) => { eprintln!("\n❌ Error: {}", e); @@ -489,18 +483,11 @@ async fn main() -> Result<(), Box> { break; } - // ユーザーメッセージを履歴に追加 - history.push(Message::user(input)); - - // Workerを実行 - match worker.run(history.clone()).await { - Ok(new_history) => { - history = new_history; - } + // Workerを実行(Workerが履歴を管理) + match worker.run(input).await { + Ok(_) => {} Err(e) => { eprintln!("\n❌ Error: {}", e); - // エラー時は最後のユーザーメッセージを削除 - history.pop(); } } } diff --git a/worker/src/worker.rs b/worker/src/worker.rs index 38ba751..c10bc3e 100644 --- a/worker/src/worker.rs +++ b/worker/src/worker.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::marker::PhantomData; use std::sync::{Arc, Mutex}; use futures::StreamExt; @@ -13,8 +14,8 @@ use crate::subscriber_adapter::{ use crate::text_block_collector::TextBlockCollector; use crate::tool_call_collector::ToolCallCollector; use worker_types::{ - ContentPart, ControlFlow, HookError, Message, MessageContent, Tool, ToolCall, ToolError, - ToolResult, TurnResult, WorkerHook, WorkerSubscriber, + ContentPart, ControlFlow, HookError, Locked, Message, MessageContent, Mutable, Tool, ToolCall, + ToolError, ToolResult, TurnResult, WorkerHook, WorkerState, WorkerSubscriber, }; // ============================================================================= @@ -83,12 +84,19 @@ impl TurnNotifier for SubscriberTurnNotifier { /// Worker - ターン制御コンポーネント /// +/// Type-stateパターンによりキャッシュ保護を実現する。 +/// +/// # 状態 +/// - `Mutable`: 初期状態。システムプロンプトや履歴を自由に編集可能。 +/// - `Locked`: キャッシュ保護状態。前方コンテキストは不変となり、追記のみ可能。 +/// /// # 責務 /// - LLMへのリクエスト送信とレスポンス処理 /// - ツール呼び出しの収集と実行 /// - Hookによる介入の提供 /// - ターンループの制御 -pub struct Worker { +/// - 履歴の所有と管理 +pub struct Worker { /// LLMクライアント client: C, /// イベントタイムライン @@ -103,36 +111,23 @@ pub struct Worker { hooks: Vec>, /// システムプロンプト system_prompt: Option, + /// メッセージ履歴(Workerが所有) + history: Vec, + /// ロック時点での履歴長(Locked状態でのみ意味を持つ) + locked_prefix_len: usize, /// ターンカウント turn_count: usize, /// ターン通知用のコールバック turn_notifiers: Vec>, + /// 状態マーカー + _state: PhantomData, } -impl Worker { - /// 新しいWorkerを作成 - pub fn new(client: C) -> Self { - let text_block_collector = TextBlockCollector::new(); - let tool_call_collector = ToolCallCollector::new(); - let mut timeline = Timeline::new(); - - // コレクターをTimelineに登録 - timeline.on_text_block(text_block_collector.clone()); - timeline.on_tool_use_block(tool_call_collector.clone()); - - Self { - client, - timeline, - text_block_collector, - tool_call_collector, - tools: HashMap::new(), - hooks: Vec::new(), - system_prompt: None, - turn_count: 0, - turn_notifiers: Vec::new(), - } - } +// ============================================================================= +// 共通実装(全状態で利用可能) +// ============================================================================= +impl Worker { /// WorkerSubscriberを登録 /// /// Subscriberは以下のイベントを受け取ることができる: @@ -140,7 +135,7 @@ impl Worker { /// - 単発イベント: on_usage, on_status, on_error /// - 累積イベント: on_text_complete, on_tool_call_complete /// - ターン制御: on_turn_start, on_turn_end - pub fn subscribe(&mut self, subscriber: S) { + pub fn subscribe(&mut self, subscriber: Sub) { let subscriber = Arc::new(Mutex::new(subscriber)); // TextBlock用ハンドラを登録 @@ -164,23 +159,6 @@ impl Worker { .push(Box::new(SubscriberTurnNotifier { subscriber })); } - /// システムプロンプトを設定 - pub fn system_prompt(mut self, prompt: impl Into) -> Self { - self.system_prompt = Some(prompt.into()); - self - } - - /// システムプロンプトを設定(可変参照版) - pub fn set_system_prompt(&mut self, prompt: impl Into) { - self.system_prompt = Some(prompt.into()); - } - - /// 設定を適用(将来の拡張用) - #[allow(dead_code)] - pub fn config(self, _config: WorkerConfig) -> Self { - self - } - /// ツールを登録 pub fn register_tool(&mut self, tool: impl Tool + 'static) { let name = tool.name().to_string(); @@ -204,6 +182,21 @@ impl Worker { &mut self.timeline } + /// 履歴への参照を取得 + pub fn history(&self) -> &[Message] { + &self.history + } + + /// システムプロンプトへの参照を取得 + pub fn get_system_prompt(&self) -> Option<&str> { + self.system_prompt.as_deref() + } + + /// 現在のターンカウントを取得 + pub fn turn_count(&self) -> usize { + self.turn_count + } + /// 登録されたツールからToolDefinitionのリストを生成 fn build_tool_definitions(&self) -> Vec { self.tools @@ -216,107 +209,6 @@ impl Worker { .collect() } - /// ターンを実行 - /// - /// メッセージを送信し、レスポンスを処理する。 - /// ツール呼び出しがある場合は自動的にループする。 - pub async fn run(&mut self, messages: Vec) -> Result, WorkerError> { - let mut context = messages; - let tool_definitions = self.build_tool_definitions(); - - info!( - message_count = context.len(), - tool_count = tool_definitions.len(), - "Starting worker run" - ); - - loop { - // ターン開始を通知 - let current_turn = self.turn_count; - debug!(turn = current_turn, "Turn start"); - for notifier in &self.turn_notifiers { - notifier.on_turn_start(current_turn); - } - - // Hook: on_message_send - let control = self.run_on_message_send_hooks(&mut context).await?; - if let ControlFlow::Abort(reason) = control { - warn!(reason = %reason, "Aborted by hook"); - // ターン終了を通知(異常終了) - for notifier in &self.turn_notifiers { - notifier.on_turn_end(current_turn); - } - return Err(WorkerError::Aborted(reason)); - } - - // リクエスト構築 - let request = self.build_request(&context, &tool_definitions); - debug!( - message_count = request.messages.len(), - tool_count = request.tools.len(), - has_system = request.system_prompt.is_some(), - "Sending request to LLM" - ); - - // ストリーム処理 - debug!("Starting stream..."); - let mut stream = self.client.stream(request).await?; - let mut event_count = 0; - while let Some(event_result) = stream.next().await { - match &event_result { - Ok(event) => { - trace!(event = ?event, "Received event"); - event_count += 1; - } - Err(e) => { - warn!(error = %e, "Stream error"); - } - } - let event = event_result?; - self.timeline.dispatch(&event); - } - debug!(event_count = event_count, "Stream completed"); - - // ターン終了を通知 - for notifier in &self.turn_notifiers { - notifier.on_turn_end(current_turn); - } - self.turn_count += 1; - - // 収集結果を取得 - let text_blocks = self.text_block_collector.take_collected(); - let tool_calls = self.tool_call_collector.take_collected(); - - // アシスタントメッセージをコンテキストに追加 - let assistant_message = self.build_assistant_message(&text_blocks, &tool_calls); - if let Some(msg) = assistant_message { - context.push(msg); - } - - if tool_calls.is_empty() { - // ツール呼び出しなし → ターン終了判定 - let turn_result = self.run_on_turn_end_hooks(&context).await?; - match turn_result { - TurnResult::Finish => { - return Ok(context); - } - TurnResult::ContinueWithMessages(additional) => { - context.extend(additional); - continue; - } - } - } - - // ツール実行 - let tool_results = self.execute_tools(tool_calls).await?; - - // ツール結果をコンテキストに追加 - for result in tool_results { - context.push(Message::tool_result(&result.tool_use_id, &result.content)); - } - } - } - /// テキストブロックとツール呼び出しからアシスタントメッセージを構築 fn build_assistant_message( &self, @@ -360,7 +252,7 @@ impl Worker { } /// リクエストを構築 - fn build_request(&self, context: &[Message], tool_definitions: &[ToolDefinition]) -> Request { + fn build_request(&self, tool_definitions: &[ToolDefinition]) -> Request { let mut request = Request::new(); // システムプロンプトを設定 @@ -369,7 +261,7 @@ impl Worker { } // メッセージを追加 - for msg in context { + for msg in &self.history { // worker-types::Message から llm_client::Message への変換 request = request.message(crate::llm_client::Message { role: match msg.role { @@ -426,12 +318,13 @@ impl Worker { } /// Hooks: on_message_send - async fn run_on_message_send_hooks( - &self, - context: &mut Vec, - ) -> Result { + async fn run_on_message_send_hooks(&self) -> Result { for hook in &self.hooks { - let result = hook.on_message_send(context).await?; + // Note: Locked状態でも履歴全体を参照として渡す(変更は不可) + // HookのAPIを変更し、immutable参照のみを渡すようにする必要があるかもしれない + // 現在は空のVecを渡して回避(要検討) + let mut temp_context = self.history.clone(); + let result = hook.on_message_send(&mut temp_context).await?; match result { ControlFlow::Continue => continue, ControlFlow::Skip => return Ok(ControlFlow::Skip), @@ -442,9 +335,9 @@ impl Worker { } /// Hooks: on_turn_end - async fn run_on_turn_end_hooks(&self, messages: &[Message]) -> Result { + async fn run_on_turn_end_hooks(&self) -> Result { for hook in &self.hooks { - let result = hook.on_turn_end(messages).await?; + let result = hook.on_turn_end(&self.history).await?; match result { TurnResult::Finish => continue, TurnResult::ContinueWithMessages(msgs) => { @@ -528,6 +421,291 @@ impl Worker { Ok(results) } + + /// 内部で使用するターン実行ロジック + async fn run_turn_loop(&mut self) -> Result<(), WorkerError> { + let tool_definitions = self.build_tool_definitions(); + + info!( + message_count = self.history.len(), + tool_count = tool_definitions.len(), + "Starting worker run" + ); + + loop { + // ターン開始を通知 + let current_turn = self.turn_count; + debug!(turn = current_turn, "Turn start"); + for notifier in &self.turn_notifiers { + notifier.on_turn_start(current_turn); + } + + // Hook: on_message_send + let control = self.run_on_message_send_hooks().await?; + if let ControlFlow::Abort(reason) = control { + warn!(reason = %reason, "Aborted by hook"); + // ターン終了を通知(異常終了) + for notifier in &self.turn_notifiers { + notifier.on_turn_end(current_turn); + } + return Err(WorkerError::Aborted(reason)); + } + + // リクエスト構築 + let request = self.build_request(&tool_definitions); + debug!( + message_count = request.messages.len(), + tool_count = request.tools.len(), + has_system = request.system_prompt.is_some(), + "Sending request to LLM" + ); + + // ストリーム処理 + debug!("Starting stream..."); + let mut stream = self.client.stream(request).await?; + let mut event_count = 0; + while let Some(event_result) = stream.next().await { + match &event_result { + Ok(event) => { + trace!(event = ?event, "Received event"); + event_count += 1; + } + Err(e) => { + warn!(error = %e, "Stream error"); + } + } + let event = event_result?; + self.timeline.dispatch(&event); + } + debug!(event_count = event_count, "Stream completed"); + + // ターン終了を通知 + for notifier in &self.turn_notifiers { + notifier.on_turn_end(current_turn); + } + self.turn_count += 1; + + // 収集結果を取得 + let text_blocks = self.text_block_collector.take_collected(); + let tool_calls = self.tool_call_collector.take_collected(); + + // アシスタントメッセージを履歴に追加 + let assistant_message = self.build_assistant_message(&text_blocks, &tool_calls); + if let Some(msg) = assistant_message { + self.history.push(msg); + } + + if tool_calls.is_empty() { + // ツール呼び出しなし → ターン終了判定 + let turn_result = self.run_on_turn_end_hooks().await?; + match turn_result { + TurnResult::Finish => { + return Ok(()); + } + TurnResult::ContinueWithMessages(additional) => { + self.history.extend(additional); + continue; + } + } + } + + // ツール実行 + let tool_results = self.execute_tools(tool_calls).await?; + + // ツール結果を履歴に追加 + for result in tool_results { + self.history + .push(Message::tool_result(&result.tool_use_id, &result.content)); + } + } + } +} + +// ============================================================================= +// Mutable状態専用の実装 +// ============================================================================= + +impl Worker { + /// 新しいWorkerを作成(Mutable状態) + pub fn new(client: C) -> Self { + let text_block_collector = TextBlockCollector::new(); + let tool_call_collector = ToolCallCollector::new(); + let mut timeline = Timeline::new(); + + // コレクターをTimelineに登録 + timeline.on_text_block(text_block_collector.clone()); + timeline.on_tool_use_block(tool_call_collector.clone()); + + Self { + client, + timeline, + text_block_collector, + tool_call_collector, + tools: HashMap::new(), + hooks: Vec::new(), + system_prompt: None, + history: Vec::new(), + locked_prefix_len: 0, + turn_count: 0, + turn_notifiers: Vec::new(), + _state: PhantomData, + } + } + + /// システムプロンプトを設定(ビルダーパターン) + pub fn system_prompt(mut self, prompt: impl Into) -> Self { + self.system_prompt = Some(prompt.into()); + self + } + + /// システムプロンプトを設定(可変参照版) + pub fn set_system_prompt(&mut self, prompt: impl Into) { + self.system_prompt = Some(prompt.into()); + } + + /// 履歴への可変参照を取得 + /// + /// Mutable状態でのみ利用可能。履歴を自由に編集できる。 + pub fn history_mut(&mut self) -> &mut Vec { + &mut self.history + } + + /// 履歴を設定 + pub fn set_history(&mut self, messages: Vec) { + self.history = messages; + } + + /// 履歴にメッセージを追加(ビルダーパターン) + pub fn with_message(mut self, message: Message) -> Self { + self.history.push(message); + self + } + + /// 履歴にメッセージを追加 + pub fn push_message(&mut self, message: Message) { + self.history.push(message); + } + + /// 複数のメッセージを履歴に追加(ビルダーパターン) + pub fn with_messages(mut self, messages: impl IntoIterator) -> Self { + self.history.extend(messages); + self + } + + /// 複数のメッセージを履歴に追加 + pub fn extend_history(&mut self, messages: impl IntoIterator) { + self.history.extend(messages); + } + + /// 履歴をクリア + pub fn clear_history(&mut self) { + self.history.clear(); + } + + /// 設定を適用(将来の拡張用) + #[allow(dead_code)] + pub fn config(self, _config: WorkerConfig) -> Self { + self + } + + /// ロックしてLocked状態へ遷移 + /// + /// この操作により、現在のシステムプロンプトと履歴が「確定済みプレフィックス」として + /// 固定される。以降は履歴への追記のみが可能となり、キャッシュヒットが保証される。 + pub fn lock(self) -> Worker { + let locked_prefix_len = self.history.len(); + Worker { + client: self.client, + timeline: self.timeline, + text_block_collector: self.text_block_collector, + tool_call_collector: self.tool_call_collector, + tools: self.tools, + hooks: self.hooks, + system_prompt: self.system_prompt, + history: self.history, + locked_prefix_len, + turn_count: self.turn_count, + turn_notifiers: self.turn_notifiers, + _state: PhantomData, + } + } + + /// ターンを実行(Mutable状態) + /// + /// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。 + /// ツール呼び出しがある場合は自動的にループする。 + /// + /// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は + /// `lock()` を呼んでからLocked状態で `run` を使用すること。 + pub async fn run(&mut self, user_input: impl Into) -> Result<&[Message], WorkerError> { + self.history.push(Message::user(user_input)); + self.run_turn_loop().await?; + Ok(&self.history) + } + + /// 複数メッセージでターンを実行(Mutable状態) + /// + /// 指定されたメッセージを履歴に追加してから実行する。 + pub async fn run_with_messages( + &mut self, + messages: Vec, + ) -> Result<&[Message], WorkerError> { + self.history.extend(messages); + self.run_turn_loop().await?; + Ok(&self.history) + } +} + +// ============================================================================= +// Locked状態専用の実装 +// ============================================================================= + +impl Worker { + /// ターンを実行(Locked状態) + /// + /// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。 + /// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。 + pub async fn run(&mut self, user_input: impl Into) -> Result<&[Message], WorkerError> { + self.history.push(Message::user(user_input)); + self.run_turn_loop().await?; + Ok(&self.history) + } + + /// 複数メッセージでターンを実行(Locked状態) + pub async fn run_with_messages( + &mut self, + messages: Vec, + ) -> Result<&[Message], WorkerError> { + self.history.extend(messages); + self.run_turn_loop().await?; + Ok(&self.history) + } + + /// ロック時点のプレフィックス長を取得 + pub fn locked_prefix_len(&self) -> usize { + self.locked_prefix_len + } + + /// ロックを解除してMutable状態へ戻す + /// + /// 注意: この操作を行うと、以降のリクエストでキャッシュがヒットしなくなる可能性がある。 + /// 履歴を編集する必要がある場合にのみ使用すること。 + pub fn unlock(self) -> Worker { + Worker { + client: self.client, + timeline: self.timeline, + text_block_collector: self.text_block_collector, + tool_call_collector: self.tool_call_collector, + tools: self.tools, + hooks: self.hooks, + system_prompt: self.system_prompt, + history: self.history, + locked_prefix_len: 0, + turn_count: self.turn_count, + turn_notifiers: self.turn_notifiers, + _state: PhantomData, + } + } } #[cfg(test)] diff --git a/worker/tests/parallel_execution_test.rs b/worker/tests/parallel_execution_test.rs index 49888f6..f01b6e9 100644 --- a/worker/tests/parallel_execution_test.rs +++ b/worker/tests/parallel_execution_test.rs @@ -9,7 +9,7 @@ use std::time::{Duration, Instant}; use async_trait::async_trait; use worker::Worker; use worker_types::{ - ControlFlow, Event, HookError, Message, ResponseStatus, StatusEvent, Tool, ToolCall, ToolError, + ControlFlow, Event, HookError, ResponseStatus, StatusEvent, Tool, ToolCall, ToolError, ToolResult, WorkerHook, }; @@ -108,10 +108,8 @@ async fn test_parallel_tool_execution() { worker.register_tool(tool2); worker.register_tool(tool3); - let messages = vec![Message::user("Run all tools")]; - let start = Instant::now(); - let _result = worker.run(messages).await; + let _result = worker.run("Run all tools").await; let elapsed = start.elapsed(); // 全ツールが呼び出されたことを確認 @@ -176,8 +174,7 @@ async fn test_before_tool_call_skip() { worker.add_hook(BlockingHook); - let messages = vec![Message::user("Test hook")]; - let _result = worker.run(messages).await; + let _result = worker.run("Test hook").await; // allowed_tool は呼び出されるが、blocked_tool は呼び出されない assert_eq!( @@ -262,8 +259,7 @@ async fn test_after_tool_call_modification() { modified_content: modified_content.clone(), }); - let messages = vec![Message::user("Test modification")]; - let result = worker.run(messages).await; + let result = worker.run("Test modification").await; assert!(result.is_ok(), "Worker should complete: {:?}", result); diff --git a/worker/tests/subscriber_test.rs b/worker/tests/subscriber_test.rs index 6cea3f4..34cf4c0 100644 --- a/worker/tests/subscriber_test.rs +++ b/worker/tests/subscriber_test.rs @@ -9,8 +9,8 @@ use std::sync::{Arc, Mutex}; use common::MockLlmClient; use worker::{Worker, WorkerSubscriber}; use worker_types::{ - ErrorEvent, Event, Message, ResponseStatus, StatusEvent, TextBlockEvent, ToolCall, - ToolUseBlockEvent, UsageEvent, + ErrorEvent, Event, ResponseStatus, StatusEvent, TextBlockEvent, ToolCall, ToolUseBlockEvent, + UsageEvent, }; // ============================================================================= @@ -115,8 +115,7 @@ async fn test_subscriber_text_block_events() { worker.subscribe(subscriber); // 実行 - let messages = vec![Message::user("Greet me")]; - let result = worker.run(messages).await; + let result = worker.run("Greet me").await; assert!(result.is_ok(), "Worker should complete: {:?}", result); @@ -155,8 +154,7 @@ async fn test_subscriber_tool_call_complete() { worker.subscribe(subscriber); // 実行 - let messages = vec![Message::user("Weather please")]; - let _ = worker.run(messages).await; + let _ = worker.run("Weather please").await; // ツール呼び出し完了が収集されていることを確認 let completes = tool_call_completes.lock().unwrap(); @@ -188,8 +186,7 @@ async fn test_subscriber_turn_events() { worker.subscribe(subscriber); // 実行 - let messages = vec![Message::user("Do something")]; - let result = worker.run(messages).await; + let result = worker.run("Do something").await; assert!(result.is_ok()); @@ -226,8 +223,7 @@ async fn test_subscriber_usage_events() { worker.subscribe(subscriber); // 実行 - let messages = vec![Message::user("Hello")]; - let _ = worker.run(messages).await; + let _ = worker.run("Hello").await; // Usageイベントが収集されていることを確認 let usages = usage_events.lock().unwrap(); diff --git a/worker/tests/worker_fixtures.rs b/worker/tests/worker_fixtures.rs index b8d7a1a..3e3aee3 100644 --- a/worker/tests/worker_fixtures.rs +++ b/worker/tests/worker_fixtures.rs @@ -134,8 +134,7 @@ async fn test_worker_simple_text_response() { let mut worker = Worker::new(client); // シンプルなメッセージを送信 - let messages = vec![worker_types::Message::user("Hello")]; - let result = worker.run(messages).await; + let result = worker.run("Hello").await; assert!(result.is_ok(), "Worker should complete successfully"); } @@ -162,8 +161,7 @@ async fn test_worker_tool_call() { worker.register_tool(weather_tool); // メッセージを送信 - let messages = vec![worker_types::Message::user("What's the weather in Tokyo?")]; - let _result = worker.run(messages).await; + let _result = worker.run("What's the weather in Tokyo?").await; // ツールが呼び出されたことを確認 // Note: max_turns=1なのでツール結果後のリクエストは送信されない @@ -196,8 +194,7 @@ async fn test_worker_with_programmatic_events() { let client = MockLlmClient::new(events); let mut worker = Worker::new(client); - let messages = vec![worker_types::Message::user("Greet me")]; - let result = worker.run(messages).await; + let result = worker.run("Greet me").await; assert!(result.is_ok(), "Worker should complete successfully"); } diff --git a/worker/tests/worker_state_test.rs b/worker/tests/worker_state_test.rs new file mode 100644 index 0000000..7b7e6ed --- /dev/null +++ b/worker/tests/worker_state_test.rs @@ -0,0 +1,372 @@ +//! Worker状態管理のテスト +//! +//! Type-stateパターン(Mutable/Locked)による状態遷移と +//! ターン間の状態保持をテストする。 + +mod common; + +use common::MockLlmClient; +use worker::Worker; +use worker_types::{Event, Message, MessageContent, ResponseStatus, StatusEvent}; + +// ============================================================================= +// Mutable状態のテスト +// ============================================================================= + +/// Mutable状態でシステムプロンプトを設定できることを確認 +#[test] +fn test_mutable_set_system_prompt() { + let client = MockLlmClient::new(vec![]); + let mut worker = Worker::new(client); + + assert!(worker.get_system_prompt().is_none()); + + worker.set_system_prompt("You are a helpful assistant."); + assert_eq!( + worker.get_system_prompt(), + Some("You are a helpful assistant.") + ); +} + +/// Mutable状態で履歴を自由に編集できることを確認 +#[test] +fn test_mutable_history_manipulation() { + let client = MockLlmClient::new(vec![]); + let mut worker = Worker::new(client); + + // 初期状態は空 + assert!(worker.history().is_empty()); + + // 履歴を追加 + worker.push_message(Message::user("Hello")); + worker.push_message(Message::assistant("Hi there!")); + assert_eq!(worker.history().len(), 2); + + // 履歴への可変アクセス + worker.history_mut().push(Message::user("How are you?")); + assert_eq!(worker.history().len(), 3); + + // 履歴をクリア + worker.clear_history(); + assert!(worker.history().is_empty()); + + // 履歴を設定 + let messages = vec![Message::user("Test"), Message::assistant("Response")]; + worker.set_history(messages); + assert_eq!(worker.history().len(), 2); +} + +/// ビルダーパターンでWorkerを構築できることを確認 +#[test] +fn test_mutable_builder_pattern() { + let client = MockLlmClient::new(vec![]); + let worker = Worker::new(client) + .system_prompt("System prompt") + .with_message(Message::user("Hello")) + .with_message(Message::assistant("Hi!")) + .with_messages(vec![ + Message::user("How are you?"), + Message::assistant("I'm fine!"), + ]); + + assert_eq!(worker.get_system_prompt(), Some("System prompt")); + assert_eq!(worker.history().len(), 4); +} + +/// extend_historyで複数メッセージを追加できることを確認 +#[test] +fn test_mutable_extend_history() { + let client = MockLlmClient::new(vec![]); + let mut worker = Worker::new(client); + + worker.push_message(Message::user("First")); + + worker.extend_history(vec![ + Message::assistant("Response 1"), + Message::user("Second"), + Message::assistant("Response 2"), + ]); + + assert_eq!(worker.history().len(), 4); +} + +// ============================================================================= +// 状態遷移テスト +// ============================================================================= + +/// lock()でMutable -> Locked状態に遷移することを確認 +#[test] +fn test_lock_transition() { + let client = MockLlmClient::new(vec![]); + let mut worker = Worker::new(client); + + worker.set_system_prompt("System"); + worker.push_message(Message::user("Hello")); + worker.push_message(Message::assistant("Hi")); + + // ロック + let locked_worker = worker.lock(); + + // Locked状態でも履歴とシステムプロンプトにアクセス可能 + assert_eq!(locked_worker.get_system_prompt(), Some("System")); + assert_eq!(locked_worker.history().len(), 2); + assert_eq!(locked_worker.locked_prefix_len(), 2); +} + +/// unlock()でLocked -> Mutable状態に遷移することを確認 +#[test] +fn test_unlock_transition() { + let client = MockLlmClient::new(vec![]); + let mut worker = Worker::new(client); + + worker.push_message(Message::user("Hello")); + let locked_worker = worker.lock(); + + // アンロック + let mut worker = locked_worker.unlock(); + + // Mutable状態に戻ったので履歴操作が可能 + worker.push_message(Message::assistant("Hi")); + worker.clear_history(); + assert!(worker.history().is_empty()); +} + +// ============================================================================= +// ターン実行と状態保持のテスト +// ============================================================================= + +/// Mutable状態でターンを実行し、履歴が正しく更新されることを確認 +#[tokio::test] +async fn test_mutable_run_updates_history() { + let events = vec![ + Event::text_block_start(0), + Event::text_delta(0, "Hello, I'm an assistant!"), + Event::text_block_stop(0, None), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ]; + + let client = MockLlmClient::new(events); + let mut worker = Worker::new(client); + + // 実行 + let result = worker.run("Hi there").await; + assert!(result.is_ok()); + + // 履歴が更新されている + let history = worker.history(); + assert_eq!(history.len(), 2); // user + assistant + + // ユーザーメッセージ + assert!(matches!( + &history[0].content, + MessageContent::Text(t) if t == "Hi there" + )); + + // アシスタントメッセージ + assert!(matches!( + &history[1].content, + MessageContent::Text(t) if t == "Hello, I'm an assistant!" + )); +} + +/// Locked状態で複数ターンを実行し、履歴が正しく累積することを確認 +#[tokio::test] +async fn test_locked_multi_turn_history_accumulation() { + // 2回のリクエストに対応するレスポンスを準備 + let client = MockLlmClient::with_responses(vec![ + // 1回目のレスポンス + vec![ + Event::text_block_start(0), + Event::text_delta(0, "Nice to meet you!"), + Event::text_block_stop(0, None), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + // 2回目のレスポンス + vec![ + Event::text_block_start(0), + Event::text_delta(0, "I can help with that."), + Event::text_block_stop(0, None), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + ]); + + let worker = Worker::new(client).system_prompt("You are helpful."); + + // ロック(システムプロンプト設定後) + let mut locked_worker = worker.lock(); + assert_eq!(locked_worker.locked_prefix_len(), 0); // メッセージはまだない + + // 1ターン目 + let result1 = locked_worker.run("Hello!").await; + assert!(result1.is_ok()); + assert_eq!(locked_worker.history().len(), 2); // user + assistant + + // 2ターン目 + let result2 = locked_worker.run("Can you help me?").await; + assert!(result2.is_ok()); + assert_eq!(locked_worker.history().len(), 4); // 2 * (user + assistant) + + // 履歴の内容を確認 + let history = locked_worker.history(); + + // 1ターン目のユーザーメッセージ + assert!(matches!(&history[0].content, MessageContent::Text(t) if t == "Hello!")); + + // 1ターン目のアシスタントメッセージ + assert!(matches!(&history[1].content, MessageContent::Text(t) if t == "Nice to meet you!")); + + // 2ターン目のユーザーメッセージ + assert!(matches!(&history[2].content, MessageContent::Text(t) if t == "Can you help me?")); + + // 2ターン目のアシスタントメッセージ + assert!(matches!(&history[3].content, MessageContent::Text(t) if t == "I can help with that.")); +} + +/// locked_prefix_lenがロック時点の履歴長を正しく記録することを確認 +#[tokio::test] +async fn test_locked_prefix_len_tracking() { + let client = MockLlmClient::with_responses(vec![ + vec![ + Event::text_block_start(0), + Event::text_delta(0, "Response 1"), + Event::text_block_stop(0, None), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + vec![ + Event::text_block_start(0), + Event::text_delta(0, "Response 2"), + Event::text_block_stop(0, None), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + ]); + + let mut worker = Worker::new(client); + + // 事前にメッセージを追加 + worker.push_message(Message::user("Pre-existing message 1")); + worker.push_message(Message::assistant("Pre-existing response 1")); + + assert_eq!(worker.history().len(), 2); + + // ロック + let mut locked_worker = worker.lock(); + assert_eq!(locked_worker.locked_prefix_len(), 2); // ロック時点で2メッセージ + + // ターン実行 + locked_worker.run("New message").await.unwrap(); + + // 履歴は増えるが、locked_prefix_lenは変わらない + assert_eq!(locked_worker.history().len(), 4); // 2 + 2 + assert_eq!(locked_worker.locked_prefix_len(), 2); // 変わらない +} + +/// ターンカウントが正しくインクリメントされることを確認 +#[tokio::test] +async fn test_turn_count_increment() { + let client = MockLlmClient::with_responses(vec![ + vec![ + Event::text_block_start(0), + Event::text_delta(0, "Turn 1"), + Event::text_block_stop(0, None), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + vec![ + Event::text_block_start(0), + Event::text_delta(0, "Turn 2"), + Event::text_block_stop(0, None), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ], + ]); + + let mut worker = Worker::new(client); + + assert_eq!(worker.turn_count(), 0); + + worker.run("First").await.unwrap(); + assert_eq!(worker.turn_count(), 1); + + worker.run("Second").await.unwrap(); + assert_eq!(worker.turn_count(), 2); +} + +/// unlock後に履歴を編集し、再度lockできることを確認 +#[tokio::test] +async fn test_unlock_edit_relock() { + let client = MockLlmClient::with_responses(vec![vec![ + Event::text_block_start(0), + Event::text_delta(0, "Response"), + Event::text_block_stop(0, None), + Event::Status(StatusEvent { + status: ResponseStatus::Completed, + }), + ]]); + + let worker = Worker::new(client) + .with_message(Message::user("Hello")) + .with_message(Message::assistant("Hi")); + + // ロック -> アンロック + let locked = worker.lock(); + assert_eq!(locked.locked_prefix_len(), 2); + + let mut unlocked = locked.unlock(); + + // 履歴を編集 + unlocked.clear_history(); + unlocked.push_message(Message::user("Fresh start")); + + // 再ロック + let relocked = unlocked.lock(); + assert_eq!(relocked.history().len(), 1); + assert_eq!(relocked.locked_prefix_len(), 1); +} + +// ============================================================================= +// システムプロンプト保持のテスト +// ============================================================================= + +/// Locked状態でもシステムプロンプトが保持されることを確認 +#[test] +fn test_system_prompt_preserved_in_locked_state() { + let client = MockLlmClient::new(vec![]); + let worker = Worker::new(client).system_prompt("Important system prompt"); + + let locked = worker.lock(); + assert_eq!(locked.get_system_prompt(), Some("Important system prompt")); + + let unlocked = locked.unlock(); + assert_eq!( + unlocked.get_system_prompt(), + Some("Important system prompt") + ); +} + +/// unlock -> 再lock でシステムプロンプトを変更できることを確認 +#[test] +fn test_system_prompt_change_after_unlock() { + let client = MockLlmClient::new(vec![]); + let worker = Worker::new(client).system_prompt("Original prompt"); + + let locked = worker.lock(); + let mut unlocked = locked.unlock(); + + unlocked.set_system_prompt("New prompt"); + assert_eq!(unlocked.get_system_prompt(), Some("New prompt")); + + let relocked = unlocked.lock(); + assert_eq!(relocked.get_system_prompt(), Some("New prompt")); +}