feat: Implement worker context management and cache protection mechanisms using type-state

This commit is contained in:
Keisuke Hirata 2026-01-08 17:57:03 +09:00
parent 45c8457b71
commit 2487d1ece7
9 changed files with 831 additions and 195 deletions

68
docs/spec/cache_lock.md Normal file
View File

@ -0,0 +1,68 @@
# KVキャッシュを中心とした設計
LLMのKVキャッシュのヒット率を重要なメトリクスであるとし、APIレベルでキャッシュ操作を中心とした設計を行う。
## 前提
リクエスト間キャッシュ(Context Caching)は、複数のリクエストで同じ入力トークン列が繰り返された際、プロバイダ側が計算済みの状態を再利用することでレイテンシと入力コストを下げる仕組みである。
キャッシュは主に**先頭一致 (Common Prefix)** によってHitするため、前提となるシステムプロンプトや、会話ログの過去部分前方を変化させると、以降のキャッシュは無効となる。
## 要件
1. **前方不変性の保証 (Prefix Immutability)**
* 後方に会話が追加されても、前方のデータシステムプロンプトや確定済みのメッセージ履歴が変化しないことをAPIレベルで保証する。
* これにより、意図しないキャッシュミスCache Missを防ぐ。
2. **データ上の再現性**
* コンテキストのデータ構造が同一であれば、生成されるリクエスト構造も同一であることを保証する。
* シリアライズ結果のバイト単位の完全一致までは求めないが、論理的なリクエスト構造は保たれる必要がある。
## アプローチ: Type-state Pattern
RustのType-stateパターンを利用し、Workerの状態によって利用可能な操作をコンパイル時に制限する。
### 1. 状態定義
* **`Mutable` (初期状態)**
* 自由な編集が可能な状態。
* システムプロンプトの設定・変更が可能。
* メッセージ履歴の初期構築(ロード、編集)が可能。
* **`Locked` (キャッシュ保護状態)**
* キャッシュの有効活用を目的とした、前方不変状態。
* **システムプロンプトの変更不可**。
* **既存メッセージ履歴の変更不可**(追記のみ許可)。
* 実行(`run`)はこの状態で行うことを推奨する。
### 2. 状態遷移とAPIイメージ
`Worker` 自身がコンテキスト(履歴)のオーナーとなり、状態によってアクセサを制限する。
```rust
// 1. Mutable状態で初期化
let mut worker: Worker<Mutable> = Worker::new(client);
// 2. コンテキストの構築 (Mutableなので自由に変更可)
worker.set_system_prompt("You are a helpful assistant.");
worker.history_mut().push(initial_message);
// 3. ロックしてLocked状態へ遷移
// これにより、ここまでのコンテキストが "Fixed Prefix" として扱われる
let mut locked_worker: Worker<Locked> = worker.lock();
// 4. 利用 (Locked状態)
// 実行は可能。新しいメッセージは履歴の末尾に追記される。
// 前方の履歴やシステムプロンプトは変更できないため、キャッシュヒットが保証される。
locked_worker.run(new_user_input).await?;
// NG操作 (コンパイルエラー)
// locked_worker.set_system_prompt("New prompt");
// locked_worker.history_mut().clear();
```
### 3. 実装への影響
現在の `Worker` 実装に対し、以下の変更が必要となる。
* **状態パラメータの導入**: `Worker<S: WorkerState>` の導入。
* **コンテキスト所有権の委譲**: `run` メソッドの引数でコンテキストを受け取るのではなく、`Worker` 内部に `history: Vec<Message>` を保持し管理する形へ移行する。
* **APIの分離**: `Mutable` 特有のメソッドsetter等と、`Locked` でも使えるメソッド(実行、参照等)をトレイト境界で分離する。

View File

@ -12,6 +12,7 @@ mod event;
mod handler; mod handler;
mod hook; mod hook;
mod message; mod message;
mod state;
mod subscriber; mod subscriber;
mod tool; mod tool;
@ -19,5 +20,6 @@ pub use event::*;
pub use handler::*; pub use handler::*;
pub use hook::*; pub use hook::*;
pub use message::*; pub use message::*;
pub use state::*;
pub use subscriber::*; pub use subscriber::*;
pub use tool::*; pub use tool::*;

40
worker-types/src/state.rs Normal file
View File

@ -0,0 +1,40 @@
//! Worker状態マーカー型
//!
//! Type-stateパターンによるキャッシュ保護のための状態定義
/// Worker状態を表すマーカートレイト
///
/// このトレイトはシールされており、外部から実装することはできない。
pub trait WorkerState: private::Sealed + Send + Sync + 'static {}
mod private {
pub trait Sealed {}
}
/// 変更可能状態
///
/// この状態では以下の操作が可能:
/// - システムプロンプトの設定・変更
/// - メッセージ履歴の編集(追加、削除、クリア)
/// - ツール・Hookの登録
///
/// `lock()` によって `Locked` 状態へ遷移できる。
#[derive(Debug, Clone, Copy, Default)]
pub struct Mutable;
impl private::Sealed for Mutable {}
impl WorkerState for Mutable {}
/// ロック状態(キャッシュ保護)
///
/// この状態では以下の制限がある:
/// - システムプロンプトの変更不可
/// - 既存メッセージ履歴の変更不可(末尾への追記のみ)
///
/// 実行(`run`)はこの状態で行うことが推奨される。
/// キャッシュヒットを保証するため、前方のコンテキストは不変となる。
#[derive(Debug, Clone, Copy, Default)]
pub struct Locked;
impl private::Sealed for Locked {}
impl WorkerState for Locked {}

View File

@ -51,7 +51,6 @@ use worker::{
}, },
}; };
use worker_macros::tool_registry; use worker_macros::tool_registry;
use worker_types::Message;
// 必要なマクロ展開用インポート // 必要なマクロ展開用インポート
use schemars; use schemars;
@ -453,14 +452,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
worker.add_hook(ToolResultPrinterHook::new(tool_call_names)); worker.add_hook(ToolResultPrinterHook::new(tool_call_names));
// 会話履歴
let mut history: Vec<Message> = Vec::new();
// ワンショットモード // ワンショットモード
if let Some(prompt) = args.prompt { if let Some(prompt) = args.prompt {
history.push(Message::user(&prompt)); match worker.run(&prompt).await {
match worker.run(history).await {
Ok(_) => {} Ok(_) => {}
Err(e) => { Err(e) => {
eprintln!("\n❌ Error: {}", e); eprintln!("\n❌ Error: {}", e);
@ -489,18 +483,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
break; break;
} }
// ユーザーメッセージを履歴に追加 // Workerを実行Workerが履歴を管理
history.push(Message::user(input)); match worker.run(input).await {
Ok(_) => {}
// Workerを実行
match worker.run(history.clone()).await {
Ok(new_history) => {
history = new_history;
}
Err(e) => { Err(e) => {
eprintln!("\n❌ Error: {}", e); eprintln!("\n❌ Error: {}", e);
// エラー時は最後のユーザーメッセージを削除
history.pop();
} }
} }
} }

View File

@ -1,4 +1,5 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use futures::StreamExt; use futures::StreamExt;
@ -13,8 +14,8 @@ use crate::subscriber_adapter::{
use crate::text_block_collector::TextBlockCollector; use crate::text_block_collector::TextBlockCollector;
use crate::tool_call_collector::ToolCallCollector; use crate::tool_call_collector::ToolCallCollector;
use worker_types::{ use worker_types::{
ContentPart, ControlFlow, HookError, Message, MessageContent, Tool, ToolCall, ToolError, ContentPart, ControlFlow, HookError, Locked, Message, MessageContent, Mutable, Tool, ToolCall,
ToolResult, TurnResult, WorkerHook, WorkerSubscriber, ToolError, ToolResult, TurnResult, WorkerHook, WorkerState, WorkerSubscriber,
}; };
// ============================================================================= // =============================================================================
@ -83,12 +84,19 @@ impl<S: WorkerSubscriber + 'static> TurnNotifier for SubscriberTurnNotifier<S> {
/// Worker - ターン制御コンポーネント /// Worker - ターン制御コンポーネント
/// ///
/// Type-stateパターンによりキャッシュ保護を実現する。
///
/// # 状態
/// - `Mutable`: 初期状態。システムプロンプトや履歴を自由に編集可能。
/// - `Locked`: キャッシュ保護状態。前方コンテキストは不変となり、追記のみ可能。
///
/// # 責務 /// # 責務
/// - LLMへのリクエスト送信とレスポンス処理 /// - LLMへのリクエスト送信とレスポンス処理
/// - ツール呼び出しの収集と実行 /// - ツール呼び出しの収集と実行
/// - Hookによる介入の提供 /// - Hookによる介入の提供
/// - ターンループの制御 /// - ターンループの制御
pub struct Worker<C: LlmClient> { /// - 履歴の所有と管理
pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
/// LLMクライアント /// LLMクライアント
client: C, client: C,
/// イベントタイムライン /// イベントタイムライン
@ -103,36 +111,23 @@ pub struct Worker<C: LlmClient> {
hooks: Vec<Box<dyn WorkerHook>>, hooks: Vec<Box<dyn WorkerHook>>,
/// システムプロンプト /// システムプロンプト
system_prompt: Option<String>, system_prompt: Option<String>,
/// メッセージ履歴Workerが所有
history: Vec<Message>,
/// ロック時点での履歴長Locked状態でのみ意味を持つ
locked_prefix_len: usize,
/// ターンカウント /// ターンカウント
turn_count: usize, turn_count: usize,
/// ターン通知用のコールバック /// ターン通知用のコールバック
turn_notifiers: Vec<Box<dyn TurnNotifier>>, turn_notifiers: Vec<Box<dyn TurnNotifier>>,
/// 状態マーカー
_state: PhantomData<S>,
} }
impl<C: LlmClient> Worker<C> { // =============================================================================
/// 新しいWorkerを作成 // 共通実装(全状態で利用可能)
pub fn new(client: C) -> Self { // =============================================================================
let text_block_collector = TextBlockCollector::new();
let tool_call_collector = ToolCallCollector::new();
let mut timeline = Timeline::new();
// コレクターをTimelineに登録
timeline.on_text_block(text_block_collector.clone());
timeline.on_tool_use_block(tool_call_collector.clone());
Self {
client,
timeline,
text_block_collector,
tool_call_collector,
tools: HashMap::new(),
hooks: Vec::new(),
system_prompt: None,
turn_count: 0,
turn_notifiers: Vec::new(),
}
}
impl<C: LlmClient, S: WorkerState> Worker<C, S> {
/// WorkerSubscriberを登録 /// WorkerSubscriberを登録
/// ///
/// Subscriberは以下のイベントを受け取ることができる: /// Subscriberは以下のイベントを受け取ることができる:
@ -140,7 +135,7 @@ impl<C: LlmClient> Worker<C> {
/// - 単発イベント: on_usage, on_status, on_error /// - 単発イベント: on_usage, on_status, on_error
/// - 累積イベント: on_text_complete, on_tool_call_complete /// - 累積イベント: on_text_complete, on_tool_call_complete
/// - ターン制御: on_turn_start, on_turn_end /// - ターン制御: on_turn_start, on_turn_end
pub fn subscribe<S: WorkerSubscriber + 'static>(&mut self, subscriber: S) { pub fn subscribe<Sub: WorkerSubscriber + 'static>(&mut self, subscriber: Sub) {
let subscriber = Arc::new(Mutex::new(subscriber)); let subscriber = Arc::new(Mutex::new(subscriber));
// TextBlock用ハンドラを登録 // TextBlock用ハンドラを登録
@ -164,23 +159,6 @@ impl<C: LlmClient> Worker<C> {
.push(Box::new(SubscriberTurnNotifier { subscriber })); .push(Box::new(SubscriberTurnNotifier { subscriber }));
} }
/// システムプロンプトを設定
pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.system_prompt = Some(prompt.into());
self
}
/// システムプロンプトを設定(可変参照版)
pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
self.system_prompt = Some(prompt.into());
}
/// 設定を適用(将来の拡張用)
#[allow(dead_code)]
pub fn config(self, _config: WorkerConfig) -> Self {
self
}
/// ツールを登録 /// ツールを登録
pub fn register_tool(&mut self, tool: impl Tool + 'static) { pub fn register_tool(&mut self, tool: impl Tool + 'static) {
let name = tool.name().to_string(); let name = tool.name().to_string();
@ -204,6 +182,21 @@ impl<C: LlmClient> Worker<C> {
&mut self.timeline &mut self.timeline
} }
/// 履歴への参照を取得
pub fn history(&self) -> &[Message] {
&self.history
}
/// システムプロンプトへの参照を取得
pub fn get_system_prompt(&self) -> Option<&str> {
self.system_prompt.as_deref()
}
/// 現在のターンカウントを取得
pub fn turn_count(&self) -> usize {
self.turn_count
}
/// 登録されたツールからToolDefinitionのリストを生成 /// 登録されたツールからToolDefinitionのリストを生成
fn build_tool_definitions(&self) -> Vec<ToolDefinition> { fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
self.tools self.tools
@ -216,107 +209,6 @@ impl<C: LlmClient> Worker<C> {
.collect() .collect()
} }
/// ターンを実行
///
/// メッセージを送信し、レスポンスを処理する。
/// ツール呼び出しがある場合は自動的にループする。
pub async fn run(&mut self, messages: Vec<Message>) -> Result<Vec<Message>, WorkerError> {
let mut context = messages;
let tool_definitions = self.build_tool_definitions();
info!(
message_count = context.len(),
tool_count = tool_definitions.len(),
"Starting worker run"
);
loop {
// ターン開始を通知
let current_turn = self.turn_count;
debug!(turn = current_turn, "Turn start");
for notifier in &self.turn_notifiers {
notifier.on_turn_start(current_turn);
}
// Hook: on_message_send
let control = self.run_on_message_send_hooks(&mut context).await?;
if let ControlFlow::Abort(reason) = control {
warn!(reason = %reason, "Aborted by hook");
// ターン終了を通知(異常終了)
for notifier in &self.turn_notifiers {
notifier.on_turn_end(current_turn);
}
return Err(WorkerError::Aborted(reason));
}
// リクエスト構築
let request = self.build_request(&context, &tool_definitions);
debug!(
message_count = request.messages.len(),
tool_count = request.tools.len(),
has_system = request.system_prompt.is_some(),
"Sending request to LLM"
);
// ストリーム処理
debug!("Starting stream...");
let mut stream = self.client.stream(request).await?;
let mut event_count = 0;
while let Some(event_result) = stream.next().await {
match &event_result {
Ok(event) => {
trace!(event = ?event, "Received event");
event_count += 1;
}
Err(e) => {
warn!(error = %e, "Stream error");
}
}
let event = event_result?;
self.timeline.dispatch(&event);
}
debug!(event_count = event_count, "Stream completed");
// ターン終了を通知
for notifier in &self.turn_notifiers {
notifier.on_turn_end(current_turn);
}
self.turn_count += 1;
// 収集結果を取得
let text_blocks = self.text_block_collector.take_collected();
let tool_calls = self.tool_call_collector.take_collected();
// アシスタントメッセージをコンテキストに追加
let assistant_message = self.build_assistant_message(&text_blocks, &tool_calls);
if let Some(msg) = assistant_message {
context.push(msg);
}
if tool_calls.is_empty() {
// ツール呼び出しなし → ターン終了判定
let turn_result = self.run_on_turn_end_hooks(&context).await?;
match turn_result {
TurnResult::Finish => {
return Ok(context);
}
TurnResult::ContinueWithMessages(additional) => {
context.extend(additional);
continue;
}
}
}
// ツール実行
let tool_results = self.execute_tools(tool_calls).await?;
// ツール結果をコンテキストに追加
for result in tool_results {
context.push(Message::tool_result(&result.tool_use_id, &result.content));
}
}
}
/// テキストブロックとツール呼び出しからアシスタントメッセージを構築 /// テキストブロックとツール呼び出しからアシスタントメッセージを構築
fn build_assistant_message( fn build_assistant_message(
&self, &self,
@ -360,7 +252,7 @@ impl<C: LlmClient> Worker<C> {
} }
/// リクエストを構築 /// リクエストを構築
fn build_request(&self, context: &[Message], tool_definitions: &[ToolDefinition]) -> Request { fn build_request(&self, tool_definitions: &[ToolDefinition]) -> Request {
let mut request = Request::new(); let mut request = Request::new();
// システムプロンプトを設定 // システムプロンプトを設定
@ -369,7 +261,7 @@ impl<C: LlmClient> Worker<C> {
} }
// メッセージを追加 // メッセージを追加
for msg in context { for msg in &self.history {
// worker-types::Message から llm_client::Message への変換 // worker-types::Message から llm_client::Message への変換
request = request.message(crate::llm_client::Message { request = request.message(crate::llm_client::Message {
role: match msg.role { role: match msg.role {
@ -426,12 +318,13 @@ impl<C: LlmClient> Worker<C> {
} }
/// Hooks: on_message_send /// Hooks: on_message_send
async fn run_on_message_send_hooks( async fn run_on_message_send_hooks(&self) -> Result<ControlFlow, WorkerError> {
&self,
context: &mut Vec<Message>,
) -> Result<ControlFlow, WorkerError> {
for hook in &self.hooks { for hook in &self.hooks {
let result = hook.on_message_send(context).await?; // Note: Locked状態でも履歴全体を参照として渡す変更は不可
// HookのAPIを変更し、immutable参照のみを渡すようにする必要があるかもしれない
// 現在は空のVecを渡して回避要検討
let mut temp_context = self.history.clone();
let result = hook.on_message_send(&mut temp_context).await?;
match result { match result {
ControlFlow::Continue => continue, ControlFlow::Continue => continue,
ControlFlow::Skip => return Ok(ControlFlow::Skip), ControlFlow::Skip => return Ok(ControlFlow::Skip),
@ -442,9 +335,9 @@ impl<C: LlmClient> Worker<C> {
} }
/// Hooks: on_turn_end /// Hooks: on_turn_end
async fn run_on_turn_end_hooks(&self, messages: &[Message]) -> Result<TurnResult, WorkerError> { async fn run_on_turn_end_hooks(&self) -> Result<TurnResult, WorkerError> {
for hook in &self.hooks { for hook in &self.hooks {
let result = hook.on_turn_end(messages).await?; let result = hook.on_turn_end(&self.history).await?;
match result { match result {
TurnResult::Finish => continue, TurnResult::Finish => continue,
TurnResult::ContinueWithMessages(msgs) => { TurnResult::ContinueWithMessages(msgs) => {
@ -528,6 +421,291 @@ impl<C: LlmClient> Worker<C> {
Ok(results) Ok(results)
} }
/// 内部で使用するターン実行ロジック
async fn run_turn_loop(&mut self) -> Result<(), WorkerError> {
let tool_definitions = self.build_tool_definitions();
info!(
message_count = self.history.len(),
tool_count = tool_definitions.len(),
"Starting worker run"
);
loop {
// ターン開始を通知
let current_turn = self.turn_count;
debug!(turn = current_turn, "Turn start");
for notifier in &self.turn_notifiers {
notifier.on_turn_start(current_turn);
}
// Hook: on_message_send
let control = self.run_on_message_send_hooks().await?;
if let ControlFlow::Abort(reason) = control {
warn!(reason = %reason, "Aborted by hook");
// ターン終了を通知(異常終了)
for notifier in &self.turn_notifiers {
notifier.on_turn_end(current_turn);
}
return Err(WorkerError::Aborted(reason));
}
// リクエスト構築
let request = self.build_request(&tool_definitions);
debug!(
message_count = request.messages.len(),
tool_count = request.tools.len(),
has_system = request.system_prompt.is_some(),
"Sending request to LLM"
);
// ストリーム処理
debug!("Starting stream...");
let mut stream = self.client.stream(request).await?;
let mut event_count = 0;
while let Some(event_result) = stream.next().await {
match &event_result {
Ok(event) => {
trace!(event = ?event, "Received event");
event_count += 1;
}
Err(e) => {
warn!(error = %e, "Stream error");
}
}
let event = event_result?;
self.timeline.dispatch(&event);
}
debug!(event_count = event_count, "Stream completed");
// ターン終了を通知
for notifier in &self.turn_notifiers {
notifier.on_turn_end(current_turn);
}
self.turn_count += 1;
// 収集結果を取得
let text_blocks = self.text_block_collector.take_collected();
let tool_calls = self.tool_call_collector.take_collected();
// アシスタントメッセージを履歴に追加
let assistant_message = self.build_assistant_message(&text_blocks, &tool_calls);
if let Some(msg) = assistant_message {
self.history.push(msg);
}
if tool_calls.is_empty() {
// ツール呼び出しなし → ターン終了判定
let turn_result = self.run_on_turn_end_hooks().await?;
match turn_result {
TurnResult::Finish => {
return Ok(());
}
TurnResult::ContinueWithMessages(additional) => {
self.history.extend(additional);
continue;
}
}
}
// ツール実行
let tool_results = self.execute_tools(tool_calls).await?;
// ツール結果を履歴に追加
for result in tool_results {
self.history
.push(Message::tool_result(&result.tool_use_id, &result.content));
}
}
}
}
// =============================================================================
// Mutable状態専用の実装
// =============================================================================
impl<C: LlmClient> Worker<C, Mutable> {
/// 新しいWorkerを作成Mutable状態
pub fn new(client: C) -> Self {
let text_block_collector = TextBlockCollector::new();
let tool_call_collector = ToolCallCollector::new();
let mut timeline = Timeline::new();
// コレクターをTimelineに登録
timeline.on_text_block(text_block_collector.clone());
timeline.on_tool_use_block(tool_call_collector.clone());
Self {
client,
timeline,
text_block_collector,
tool_call_collector,
tools: HashMap::new(),
hooks: Vec::new(),
system_prompt: None,
history: Vec::new(),
locked_prefix_len: 0,
turn_count: 0,
turn_notifiers: Vec::new(),
_state: PhantomData,
}
}
/// システムプロンプトを設定(ビルダーパターン)
pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.system_prompt = Some(prompt.into());
self
}
/// システムプロンプトを設定(可変参照版)
pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
self.system_prompt = Some(prompt.into());
}
/// 履歴への可変参照を取得
///
/// Mutable状態でのみ利用可能。履歴を自由に編集できる。
pub fn history_mut(&mut self) -> &mut Vec<Message> {
&mut self.history
}
/// 履歴を設定
pub fn set_history(&mut self, messages: Vec<Message>) {
self.history = messages;
}
/// 履歴にメッセージを追加(ビルダーパターン)
pub fn with_message(mut self, message: Message) -> Self {
self.history.push(message);
self
}
/// 履歴にメッセージを追加
pub fn push_message(&mut self, message: Message) {
self.history.push(message);
}
/// 複数のメッセージを履歴に追加(ビルダーパターン)
pub fn with_messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
self.history.extend(messages);
self
}
/// 複数のメッセージを履歴に追加
pub fn extend_history(&mut self, messages: impl IntoIterator<Item = Message>) {
self.history.extend(messages);
}
/// 履歴をクリア
pub fn clear_history(&mut self) {
self.history.clear();
}
/// 設定を適用(将来の拡張用)
#[allow(dead_code)]
pub fn config(self, _config: WorkerConfig) -> Self {
self
}
/// ロックしてLocked状態へ遷移
///
/// この操作により、現在のシステムプロンプトと履歴が「確定済みプレフィックス」として
/// 固定される。以降は履歴への追記のみが可能となり、キャッシュヒットが保証される。
pub fn lock(self) -> Worker<C, Locked> {
let locked_prefix_len = self.history.len();
Worker {
client: self.client,
timeline: self.timeline,
text_block_collector: self.text_block_collector,
tool_call_collector: self.tool_call_collector,
tools: self.tools,
hooks: self.hooks,
system_prompt: self.system_prompt,
history: self.history,
locked_prefix_len,
turn_count: self.turn_count,
turn_notifiers: self.turn_notifiers,
_state: PhantomData,
}
}
/// ターンを実行Mutable状態
///
/// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。
/// ツール呼び出しがある場合は自動的にループする。
///
/// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は
/// `lock()` を呼んでからLocked状態で `run` を使用すること。
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<&[Message], WorkerError> {
self.history.push(Message::user(user_input));
self.run_turn_loop().await?;
Ok(&self.history)
}
/// 複数メッセージでターンを実行Mutable状態
///
/// 指定されたメッセージを履歴に追加してから実行する。
pub async fn run_with_messages(
&mut self,
messages: Vec<Message>,
) -> Result<&[Message], WorkerError> {
self.history.extend(messages);
self.run_turn_loop().await?;
Ok(&self.history)
}
}
// =============================================================================
// Locked状態専用の実装
// =============================================================================
impl<C: LlmClient> Worker<C, Locked> {
/// ターンを実行Locked状態
///
/// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。
/// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<&[Message], WorkerError> {
self.history.push(Message::user(user_input));
self.run_turn_loop().await?;
Ok(&self.history)
}
/// 複数メッセージでターンを実行Locked状態
pub async fn run_with_messages(
&mut self,
messages: Vec<Message>,
) -> Result<&[Message], WorkerError> {
self.history.extend(messages);
self.run_turn_loop().await?;
Ok(&self.history)
}
/// ロック時点のプレフィックス長を取得
pub fn locked_prefix_len(&self) -> usize {
self.locked_prefix_len
}
/// ロックを解除してMutable状態へ戻す
///
/// 注意: この操作を行うと、以降のリクエストでキャッシュがヒットしなくなる可能性がある。
/// 履歴を編集する必要がある場合にのみ使用すること。
pub fn unlock(self) -> Worker<C, Mutable> {
Worker {
client: self.client,
timeline: self.timeline,
text_block_collector: self.text_block_collector,
tool_call_collector: self.tool_call_collector,
tools: self.tools,
hooks: self.hooks,
system_prompt: self.system_prompt,
history: self.history,
locked_prefix_len: 0,
turn_count: self.turn_count,
turn_notifiers: self.turn_notifiers,
_state: PhantomData,
}
}
} }
#[cfg(test)] #[cfg(test)]

View File

@ -9,7 +9,7 @@ use std::time::{Duration, Instant};
use async_trait::async_trait; use async_trait::async_trait;
use worker::Worker; use worker::Worker;
use worker_types::{ use worker_types::{
ControlFlow, Event, HookError, Message, ResponseStatus, StatusEvent, Tool, ToolCall, ToolError, ControlFlow, Event, HookError, ResponseStatus, StatusEvent, Tool, ToolCall, ToolError,
ToolResult, WorkerHook, ToolResult, WorkerHook,
}; };
@ -108,10 +108,8 @@ async fn test_parallel_tool_execution() {
worker.register_tool(tool2); worker.register_tool(tool2);
worker.register_tool(tool3); worker.register_tool(tool3);
let messages = vec![Message::user("Run all tools")];
let start = Instant::now(); let start = Instant::now();
let _result = worker.run(messages).await; let _result = worker.run("Run all tools").await;
let elapsed = start.elapsed(); let elapsed = start.elapsed();
// 全ツールが呼び出されたことを確認 // 全ツールが呼び出されたことを確認
@ -176,8 +174,7 @@ async fn test_before_tool_call_skip() {
worker.add_hook(BlockingHook); worker.add_hook(BlockingHook);
let messages = vec![Message::user("Test hook")]; let _result = worker.run("Test hook").await;
let _result = worker.run(messages).await;
// allowed_tool は呼び出されるが、blocked_tool は呼び出されない // allowed_tool は呼び出されるが、blocked_tool は呼び出されない
assert_eq!( assert_eq!(
@ -262,8 +259,7 @@ async fn test_after_tool_call_modification() {
modified_content: modified_content.clone(), modified_content: modified_content.clone(),
}); });
let messages = vec![Message::user("Test modification")]; let result = worker.run("Test modification").await;
let result = worker.run(messages).await;
assert!(result.is_ok(), "Worker should complete: {:?}", result); assert!(result.is_ok(), "Worker should complete: {:?}", result);

View File

@ -9,8 +9,8 @@ use std::sync::{Arc, Mutex};
use common::MockLlmClient; use common::MockLlmClient;
use worker::{Worker, WorkerSubscriber}; use worker::{Worker, WorkerSubscriber};
use worker_types::{ use worker_types::{
ErrorEvent, Event, Message, ResponseStatus, StatusEvent, TextBlockEvent, ToolCall, ErrorEvent, Event, ResponseStatus, StatusEvent, TextBlockEvent, ToolCall, ToolUseBlockEvent,
ToolUseBlockEvent, UsageEvent, UsageEvent,
}; };
// ============================================================================= // =============================================================================
@ -115,8 +115,7 @@ async fn test_subscriber_text_block_events() {
worker.subscribe(subscriber); worker.subscribe(subscriber);
// 実行 // 実行
let messages = vec![Message::user("Greet me")]; let result = worker.run("Greet me").await;
let result = worker.run(messages).await;
assert!(result.is_ok(), "Worker should complete: {:?}", result); assert!(result.is_ok(), "Worker should complete: {:?}", result);
@ -155,8 +154,7 @@ async fn test_subscriber_tool_call_complete() {
worker.subscribe(subscriber); worker.subscribe(subscriber);
// 実行 // 実行
let messages = vec![Message::user("Weather please")]; let _ = worker.run("Weather please").await;
let _ = worker.run(messages).await;
// ツール呼び出し完了が収集されていることを確認 // ツール呼び出し完了が収集されていることを確認
let completes = tool_call_completes.lock().unwrap(); let completes = tool_call_completes.lock().unwrap();
@ -188,8 +186,7 @@ async fn test_subscriber_turn_events() {
worker.subscribe(subscriber); worker.subscribe(subscriber);
// 実行 // 実行
let messages = vec![Message::user("Do something")]; let result = worker.run("Do something").await;
let result = worker.run(messages).await;
assert!(result.is_ok()); assert!(result.is_ok());
@ -226,8 +223,7 @@ async fn test_subscriber_usage_events() {
worker.subscribe(subscriber); worker.subscribe(subscriber);
// 実行 // 実行
let messages = vec![Message::user("Hello")]; let _ = worker.run("Hello").await;
let _ = worker.run(messages).await;
// Usageイベントが収集されていることを確認 // Usageイベントが収集されていることを確認
let usages = usage_events.lock().unwrap(); let usages = usage_events.lock().unwrap();

View File

@ -134,8 +134,7 @@ async fn test_worker_simple_text_response() {
let mut worker = Worker::new(client); let mut worker = Worker::new(client);
// シンプルなメッセージを送信 // シンプルなメッセージを送信
let messages = vec![worker_types::Message::user("Hello")]; let result = worker.run("Hello").await;
let result = worker.run(messages).await;
assert!(result.is_ok(), "Worker should complete successfully"); assert!(result.is_ok(), "Worker should complete successfully");
} }
@ -162,8 +161,7 @@ async fn test_worker_tool_call() {
worker.register_tool(weather_tool); worker.register_tool(weather_tool);
// メッセージを送信 // メッセージを送信
let messages = vec![worker_types::Message::user("What's the weather in Tokyo?")]; let _result = worker.run("What's the weather in Tokyo?").await;
let _result = worker.run(messages).await;
// ツールが呼び出されたことを確認 // ツールが呼び出されたことを確認
// Note: max_turns=1なのでツール結果後のリクエストは送信されない // Note: max_turns=1なのでツール結果後のリクエストは送信されない
@ -196,8 +194,7 @@ async fn test_worker_with_programmatic_events() {
let client = MockLlmClient::new(events); let client = MockLlmClient::new(events);
let mut worker = Worker::new(client); let mut worker = Worker::new(client);
let messages = vec![worker_types::Message::user("Greet me")]; let result = worker.run("Greet me").await;
let result = worker.run(messages).await;
assert!(result.is_ok(), "Worker should complete successfully"); assert!(result.is_ok(), "Worker should complete successfully");
} }

View File

@ -0,0 +1,372 @@
//! Worker状態管理のテスト
//!
//! Type-stateパターンMutable/Lockedによる状態遷移と
//! ターン間の状態保持をテストする。
mod common;
use common::MockLlmClient;
use worker::Worker;
use worker_types::{Event, Message, MessageContent, ResponseStatus, StatusEvent};
// =============================================================================
// Mutable状態のテスト
// =============================================================================
/// Mutable状態でシステムプロンプトを設定できることを確認
#[test]
fn test_mutable_set_system_prompt() {
let client = MockLlmClient::new(vec![]);
let mut worker = Worker::new(client);
assert!(worker.get_system_prompt().is_none());
worker.set_system_prompt("You are a helpful assistant.");
assert_eq!(
worker.get_system_prompt(),
Some("You are a helpful assistant.")
);
}
/// Mutable状態で履歴を自由に編集できることを確認
#[test]
fn test_mutable_history_manipulation() {
let client = MockLlmClient::new(vec![]);
let mut worker = Worker::new(client);
// 初期状態は空
assert!(worker.history().is_empty());
// 履歴を追加
worker.push_message(Message::user("Hello"));
worker.push_message(Message::assistant("Hi there!"));
assert_eq!(worker.history().len(), 2);
// 履歴への可変アクセス
worker.history_mut().push(Message::user("How are you?"));
assert_eq!(worker.history().len(), 3);
// 履歴をクリア
worker.clear_history();
assert!(worker.history().is_empty());
// 履歴を設定
let messages = vec![Message::user("Test"), Message::assistant("Response")];
worker.set_history(messages);
assert_eq!(worker.history().len(), 2);
}
/// ビルダーパターンでWorkerを構築できることを確認
#[test]
fn test_mutable_builder_pattern() {
let client = MockLlmClient::new(vec![]);
let worker = Worker::new(client)
.system_prompt("System prompt")
.with_message(Message::user("Hello"))
.with_message(Message::assistant("Hi!"))
.with_messages(vec![
Message::user("How are you?"),
Message::assistant("I'm fine!"),
]);
assert_eq!(worker.get_system_prompt(), Some("System prompt"));
assert_eq!(worker.history().len(), 4);
}
/// extend_historyで複数メッセージを追加できることを確認
#[test]
fn test_mutable_extend_history() {
let client = MockLlmClient::new(vec![]);
let mut worker = Worker::new(client);
worker.push_message(Message::user("First"));
worker.extend_history(vec![
Message::assistant("Response 1"),
Message::user("Second"),
Message::assistant("Response 2"),
]);
assert_eq!(worker.history().len(), 4);
}
// =============================================================================
// 状態遷移テスト
// =============================================================================
/// lock()でMutable -> Locked状態に遷移することを確認
#[test]
fn test_lock_transition() {
let client = MockLlmClient::new(vec![]);
let mut worker = Worker::new(client);
worker.set_system_prompt("System");
worker.push_message(Message::user("Hello"));
worker.push_message(Message::assistant("Hi"));
// ロック
let locked_worker = worker.lock();
// Locked状態でも履歴とシステムプロンプトにアクセス可能
assert_eq!(locked_worker.get_system_prompt(), Some("System"));
assert_eq!(locked_worker.history().len(), 2);
assert_eq!(locked_worker.locked_prefix_len(), 2);
}
/// unlock()でLocked -> Mutable状態に遷移することを確認
#[test]
fn test_unlock_transition() {
let client = MockLlmClient::new(vec![]);
let mut worker = Worker::new(client);
worker.push_message(Message::user("Hello"));
let locked_worker = worker.lock();
// アンロック
let mut worker = locked_worker.unlock();
// Mutable状態に戻ったので履歴操作が可能
worker.push_message(Message::assistant("Hi"));
worker.clear_history();
assert!(worker.history().is_empty());
}
// =============================================================================
// ターン実行と状態保持のテスト
// =============================================================================
/// Mutable状態でターンを実行し、履歴が正しく更新されることを確認
#[tokio::test]
async fn test_mutable_run_updates_history() {
let events = vec![
Event::text_block_start(0),
Event::text_delta(0, "Hello, I'm an assistant!"),
Event::text_block_stop(0, None),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
];
let client = MockLlmClient::new(events);
let mut worker = Worker::new(client);
// 実行
let result = worker.run("Hi there").await;
assert!(result.is_ok());
// 履歴が更新されている
let history = worker.history();
assert_eq!(history.len(), 2); // user + assistant
// ユーザーメッセージ
assert!(matches!(
&history[0].content,
MessageContent::Text(t) if t == "Hi there"
));
// アシスタントメッセージ
assert!(matches!(
&history[1].content,
MessageContent::Text(t) if t == "Hello, I'm an assistant!"
));
}
/// Locked状態で複数ターンを実行し、履歴が正しく累積することを確認
#[tokio::test]
async fn test_locked_multi_turn_history_accumulation() {
// 2回のリクエストに対応するレスポンスを準備
let client = MockLlmClient::with_responses(vec![
// 1回目のレスポンス
vec![
Event::text_block_start(0),
Event::text_delta(0, "Nice to meet you!"),
Event::text_block_stop(0, None),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
],
// 2回目のレスポンス
vec![
Event::text_block_start(0),
Event::text_delta(0, "I can help with that."),
Event::text_block_stop(0, None),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
],
]);
let worker = Worker::new(client).system_prompt("You are helpful.");
// ロック(システムプロンプト設定後)
let mut locked_worker = worker.lock();
assert_eq!(locked_worker.locked_prefix_len(), 0); // メッセージはまだない
// 1ターン目
let result1 = locked_worker.run("Hello!").await;
assert!(result1.is_ok());
assert_eq!(locked_worker.history().len(), 2); // user + assistant
// 2ターン目
let result2 = locked_worker.run("Can you help me?").await;
assert!(result2.is_ok());
assert_eq!(locked_worker.history().len(), 4); // 2 * (user + assistant)
// 履歴の内容を確認
let history = locked_worker.history();
// 1ターン目のユーザーメッセージ
assert!(matches!(&history[0].content, MessageContent::Text(t) if t == "Hello!"));
// 1ターン目のアシスタントメッセージ
assert!(matches!(&history[1].content, MessageContent::Text(t) if t == "Nice to meet you!"));
// 2ターン目のユーザーメッセージ
assert!(matches!(&history[2].content, MessageContent::Text(t) if t == "Can you help me?"));
// 2ターン目のアシスタントメッセージ
assert!(matches!(&history[3].content, MessageContent::Text(t) if t == "I can help with that."));
}
/// locked_prefix_lenがロック時点の履歴長を正しく記録することを確認
#[tokio::test]
async fn test_locked_prefix_len_tracking() {
let client = MockLlmClient::with_responses(vec![
vec![
Event::text_block_start(0),
Event::text_delta(0, "Response 1"),
Event::text_block_stop(0, None),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
],
vec![
Event::text_block_start(0),
Event::text_delta(0, "Response 2"),
Event::text_block_stop(0, None),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
],
]);
let mut worker = Worker::new(client);
// 事前にメッセージを追加
worker.push_message(Message::user("Pre-existing message 1"));
worker.push_message(Message::assistant("Pre-existing response 1"));
assert_eq!(worker.history().len(), 2);
// ロック
let mut locked_worker = worker.lock();
assert_eq!(locked_worker.locked_prefix_len(), 2); // ロック時点で2メッセージ
// ターン実行
locked_worker.run("New message").await.unwrap();
// 履歴は増えるが、locked_prefix_lenは変わらない
assert_eq!(locked_worker.history().len(), 4); // 2 + 2
assert_eq!(locked_worker.locked_prefix_len(), 2); // 変わらない
}
/// ターンカウントが正しくインクリメントされることを確認
#[tokio::test]
async fn test_turn_count_increment() {
let client = MockLlmClient::with_responses(vec![
vec![
Event::text_block_start(0),
Event::text_delta(0, "Turn 1"),
Event::text_block_stop(0, None),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
],
vec![
Event::text_block_start(0),
Event::text_delta(0, "Turn 2"),
Event::text_block_stop(0, None),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
],
]);
let mut worker = Worker::new(client);
assert_eq!(worker.turn_count(), 0);
worker.run("First").await.unwrap();
assert_eq!(worker.turn_count(), 1);
worker.run("Second").await.unwrap();
assert_eq!(worker.turn_count(), 2);
}
/// unlock後に履歴を編集し、再度lockできることを確認
#[tokio::test]
async fn test_unlock_edit_relock() {
let client = MockLlmClient::with_responses(vec![vec![
Event::text_block_start(0),
Event::text_delta(0, "Response"),
Event::text_block_stop(0, None),
Event::Status(StatusEvent {
status: ResponseStatus::Completed,
}),
]]);
let worker = Worker::new(client)
.with_message(Message::user("Hello"))
.with_message(Message::assistant("Hi"));
// ロック -> アンロック
let locked = worker.lock();
assert_eq!(locked.locked_prefix_len(), 2);
let mut unlocked = locked.unlock();
// 履歴を編集
unlocked.clear_history();
unlocked.push_message(Message::user("Fresh start"));
// 再ロック
let relocked = unlocked.lock();
assert_eq!(relocked.history().len(), 1);
assert_eq!(relocked.locked_prefix_len(), 1);
}
// =============================================================================
// システムプロンプト保持のテスト
// =============================================================================
/// Locked状態でもシステムプロンプトが保持されることを確認
#[test]
fn test_system_prompt_preserved_in_locked_state() {
let client = MockLlmClient::new(vec![]);
let worker = Worker::new(client).system_prompt("Important system prompt");
let locked = worker.lock();
assert_eq!(locked.get_system_prompt(), Some("Important system prompt"));
let unlocked = locked.unlock();
assert_eq!(
unlocked.get_system_prompt(),
Some("Important system prompt")
);
}
/// unlock -> 再lock でシステムプロンプトを変更できることを確認
#[test]
fn test_system_prompt_change_after_unlock() {
let client = MockLlmClient::new(vec![]);
let worker = Worker::new(client).system_prompt("Original prompt");
let locked = worker.lock();
let mut unlocked = locked.unlock();
unlocked.set_system_prompt("New prompt");
assert_eq!(unlocked.get_system_prompt(), Some("New prompt"));
let relocked = unlocked.lock();
assert_eq!(relocked.get_system_prompt(), Some("New prompt"));
}