use std::collections::HashMap; use std::marker::PhantomData; use std::sync::{Arc, Mutex}; use futures::StreamExt; use tracing::{debug, info, trace, warn}; use crate::Timeline; use crate::llm_client::{ClientError, LlmClient, Request, ToolDefinition}; use crate::subscriber_adapter::{ ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter, ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, }; use crate::text_block_collector::TextBlockCollector; use crate::tool_call_collector::ToolCallCollector; use worker_types::{ ContentPart, ControlFlow, HookError, Locked, Message, MessageContent, Mutable, Tool, ToolCall, ToolError, ToolResult, TurnResult, WorkerHook, WorkerState, WorkerSubscriber, }; // ============================================================================= // Worker Error // ============================================================================= /// Workerエラー #[derive(Debug, thiserror::Error)] pub enum WorkerError { /// クライアントエラー #[error("Client error: {0}")] Client(#[from] ClientError), /// ツールエラー #[error("Tool error: {0}")] Tool(#[from] ToolError), /// Hookエラー #[error("Hook error: {0}")] Hook(#[from] HookError), /// 処理が中断された #[error("Aborted: {0}")] Aborted(String), } // ============================================================================= // Worker Config // ============================================================================= /// Worker設定 #[derive(Debug, Clone, Default)] pub struct WorkerConfig { // 将来の拡張用(現在は空) _private: (), } // ============================================================================= // ターン制御用コールバック保持 // ============================================================================= /// ターンイベントを通知するためのコールバック (型消去) trait TurnNotifier: Send { fn on_turn_start(&self, turn: usize); fn on_turn_end(&self, turn: usize); } struct SubscriberTurnNotifier { subscriber: Arc>, } impl TurnNotifier for SubscriberTurnNotifier { fn on_turn_start(&self, turn: usize) { if let Ok(mut s) = self.subscriber.lock() { s.on_turn_start(turn); } } fn on_turn_end(&self, turn: usize) { if let Ok(mut s) = self.subscriber.lock() { s.on_turn_end(turn); } } } // ============================================================================= // Worker // ============================================================================= /// LLMとの対話を管理する中心コンポーネント /// /// ユーザーからの入力を受け取り、LLMにリクエストを送信し、 /// ツール呼び出しがあれば自動的に実行してターンを進行させます。 /// /// # 状態遷移(Type-state) /// /// - [`Mutable`]: 初期状態。システムプロンプトや履歴を自由に編集可能。 /// - [`Locked`]: キャッシュ保護状態。`lock()`で遷移。前方コンテキストは不変。 /// /// # Examples /// /// ```ignore /// use worker::{Worker, Message}; /// /// // Workerを作成してツールを登録 /// let mut worker = Worker::new(client) /// .system_prompt("You are a helpful assistant."); /// worker.register_tool(my_tool); /// /// // 対話を実行 /// let history = worker.run("Hello!").await?; /// ``` /// /// # キャッシュ保護が必要な場合 /// /// ```ignore /// let mut worker = Worker::new(client) /// .system_prompt("..."); /// /// // 履歴を設定後、ロックしてキャッシュを保護 /// let mut locked = worker.lock(); /// locked.run("user input").await?; /// ``` pub struct Worker { /// LLMクライアント client: C, /// イベントタイムライン timeline: Timeline, /// テキストブロックコレクター(Timeline用ハンドラ) text_block_collector: TextBlockCollector, /// ツールコールコレクター(Timeline用ハンドラ) tool_call_collector: ToolCallCollector, /// 登録されたツール tools: HashMap>, /// 登録されたHook hooks: Vec>, /// システムプロンプト system_prompt: Option, /// メッセージ履歴(Workerが所有) history: Vec, /// ロック時点での履歴長(Locked状態でのみ意味を持つ) locked_prefix_len: usize, /// ターンカウント turn_count: usize, /// ターン通知用のコールバック turn_notifiers: Vec>, /// 状態マーカー _state: PhantomData, } // ============================================================================= // 共通実装(全状態で利用可能) // ============================================================================= impl Worker { /// イベント購読者を登録する /// /// 登録したSubscriberは、LLMからのストリーミングイベントを /// リアルタイムで受信できます。UIへのストリーム表示などに利用します。 /// /// # 受信できるイベント /// /// - **ブロックイベント**: `on_text_block`, `on_tool_use_block` /// - **メタイベント**: `on_usage`, `on_status`, `on_error` /// - **完了イベント**: `on_text_complete`, `on_tool_call_complete` /// - **ターン制御**: `on_turn_start`, `on_turn_end` /// /// # Examples /// /// ```ignore /// use worker::{Worker, WorkerSubscriber, TextBlockEvent}; /// /// struct MyPrinter; /// impl WorkerSubscriber for MyPrinter { /// type TextBlockScope = (); /// type ToolUseBlockScope = (); /// /// fn on_text_block(&mut self, _: &mut (), event: &TextBlockEvent) { /// if let TextBlockEvent::Delta(text) = event { /// print!("{}", text); /// } /// } /// } /// /// worker.subscribe(MyPrinter); /// ``` pub fn subscribe(&mut self, subscriber: Sub) { let subscriber = Arc::new(Mutex::new(subscriber)); // TextBlock用ハンドラを登録 self.timeline .on_text_block(TextBlockSubscriberAdapter::new(subscriber.clone())); // ToolUseBlock用ハンドラを登録 self.timeline .on_tool_use_block(ToolUseBlockSubscriberAdapter::new(subscriber.clone())); // Meta系ハンドラを登録 self.timeline .on_usage(UsageSubscriberAdapter::new(subscriber.clone())); self.timeline .on_status(StatusSubscriberAdapter::new(subscriber.clone())); self.timeline .on_error(ErrorSubscriberAdapter::new(subscriber.clone())); // ターン制御用コールバックを登録 self.turn_notifiers .push(Box::new(SubscriberTurnNotifier { subscriber })); } /// ツールを登録する /// /// 登録されたツールはLLMからの呼び出しで自動的に実行されます。 /// 同名のツールを登録した場合、後から登録したものが優先されます。 /// /// # Examples /// /// ```ignore /// use worker::Worker; /// use my_tools::SearchTool; /// /// worker.register_tool(SearchTool::new()); /// ``` pub fn register_tool(&mut self, tool: impl Tool + 'static) { let name = tool.name().to_string(); self.tools.insert(name, Arc::new(tool)); } /// 複数のツールを登録 pub fn register_tools(&mut self, tools: impl IntoIterator) { for tool in tools { self.register_tool(tool); } } /// Hookを追加する /// /// Hookはターンの進行・ツール実行に介入できます。 /// 複数のHookを登録した場合、登録順に実行されます。 /// /// # Examples /// /// ```ignore /// use worker::{Worker, WorkerHook, ControlFlow, ToolCall}; /// /// struct LoggingHook; /// /// #[async_trait::async_trait] /// impl WorkerHook for LoggingHook { /// async fn before_tool_call(&self, call: &mut ToolCall) -> Result { /// println!("Calling tool: {}", call.name); /// Ok(ControlFlow::Continue) /// } /// } /// /// worker.add_hook(LoggingHook); /// ``` pub fn add_hook(&mut self, hook: impl WorkerHook + 'static) { self.hooks.push(Box::new(hook)); } /// タイムラインへの可変参照を取得(追加ハンドラ登録用) pub fn timeline_mut(&mut self) -> &mut Timeline { &mut self.timeline } /// 履歴への参照を取得 pub fn history(&self) -> &[Message] { &self.history } /// システムプロンプトへの参照を取得 pub fn get_system_prompt(&self) -> Option<&str> { self.system_prompt.as_deref() } /// 現在のターンカウントを取得 pub fn turn_count(&self) -> usize { self.turn_count } /// 登録されたツールからToolDefinitionのリストを生成 fn build_tool_definitions(&self) -> Vec { self.tools .values() .map(|tool| { ToolDefinition::new(tool.name()) .description(tool.description()) .input_schema(tool.input_schema()) }) .collect() } /// テキストブロックとツール呼び出しからアシスタントメッセージを構築 fn build_assistant_message( &self, text_blocks: &[String], tool_calls: &[ToolCall], ) -> Option { // テキストもツール呼び出しもない場合はNone if text_blocks.is_empty() && tool_calls.is_empty() { return None; } // テキストのみの場合はシンプルなテキストメッセージ if tool_calls.is_empty() { let text = text_blocks.join(""); return Some(Message::assistant(text)); } // ツール呼び出しがある場合は Parts として構築 let mut parts = Vec::new(); // テキストパーツを追加 for text in text_blocks { if !text.is_empty() { parts.push(ContentPart::Text { text: text.clone() }); } } // ツール呼び出しパーツを追加 for call in tool_calls { parts.push(ContentPart::ToolUse { id: call.id.clone(), name: call.name.clone(), input: call.input.clone(), }); } Some(Message { role: worker_types::Role::Assistant, content: MessageContent::Parts(parts), }) } /// リクエストを構築 fn build_request(&self, tool_definitions: &[ToolDefinition]) -> Request { let mut request = Request::new(); // システムプロンプトを設定 if let Some(ref system) = self.system_prompt { request = request.system(system); } // メッセージを追加 for msg in &self.history { // worker-types::Message から llm_client::Message への変換 request = request.message(crate::llm_client::Message { role: match msg.role { worker_types::Role::User => crate::llm_client::Role::User, worker_types::Role::Assistant => crate::llm_client::Role::Assistant, }, content: match &msg.content { worker_types::MessageContent::Text(t) => { crate::llm_client::MessageContent::Text(t.clone()) } worker_types::MessageContent::ToolResult { tool_use_id, content, } => crate::llm_client::MessageContent::ToolResult { tool_use_id: tool_use_id.clone(), content: content.clone(), }, worker_types::MessageContent::Parts(parts) => { crate::llm_client::MessageContent::Parts( parts .iter() .map(|p| match p { worker_types::ContentPart::Text { text } => { crate::llm_client::ContentPart::Text { text: text.clone() } } worker_types::ContentPart::ToolUse { id, name, input } => { crate::llm_client::ContentPart::ToolUse { id: id.clone(), name: name.clone(), input: input.clone(), } } worker_types::ContentPart::ToolResult { tool_use_id, content, } => crate::llm_client::ContentPart::ToolResult { tool_use_id: tool_use_id.clone(), content: content.clone(), }, }) .collect(), ) } }, }); } // ツール定義を追加 for tool_def in tool_definitions { request = request.tool(tool_def.clone()); } request } /// Hooks: on_message_send async fn run_on_message_send_hooks(&self) -> Result { for hook in &self.hooks { // Note: Locked状態でも履歴全体を参照として渡す(変更は不可) // HookのAPIを変更し、immutable参照のみを渡すようにする必要があるかもしれない // 現在は空のVecを渡して回避(要検討) let mut temp_context = self.history.clone(); let result = hook.on_message_send(&mut temp_context).await?; match result { ControlFlow::Continue => continue, ControlFlow::Skip => return Ok(ControlFlow::Skip), ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)), } } Ok(ControlFlow::Continue) } /// Hooks: on_turn_end async fn run_on_turn_end_hooks(&self) -> Result { for hook in &self.hooks { let result = hook.on_turn_end(&self.history).await?; match result { TurnResult::Finish => continue, TurnResult::ContinueWithMessages(msgs) => { return Ok(TurnResult::ContinueWithMessages(msgs)); } } } Ok(TurnResult::Finish) } /// ツールを並列実行 /// /// 全てのツールに対してbefore_tool_callフックを実行後、 /// 許可されたツールを並列に実行し、結果にafter_tool_callフックを適用する。 async fn execute_tools( &self, tool_calls: Vec, ) -> Result, WorkerError> { use futures::future::join_all; // Phase 1: before_tool_call フックを適用(スキップ/中断を判定) let mut approved_calls = Vec::new(); for mut tool_call in tool_calls { let mut skip = false; for hook in &self.hooks { let result = hook.before_tool_call(&mut tool_call).await?; match result { ControlFlow::Continue => {} ControlFlow::Skip => { skip = true; break; } ControlFlow::Abort(reason) => { return Err(WorkerError::Aborted(reason)); } } } if !skip { approved_calls.push(tool_call); } } // Phase 2: 許可されたツールを並列実行 let futures: Vec<_> = approved_calls .into_iter() .map(|tool_call| { let tools = &self.tools; async move { if let Some(tool) = tools.get(&tool_call.name) { let input_json = serde_json::to_string(&tool_call.input).unwrap_or_default(); match tool.execute(&input_json).await { Ok(content) => ToolResult::success(&tool_call.id, content), Err(e) => ToolResult::error(&tool_call.id, e.to_string()), } } else { ToolResult::error( &tool_call.id, format!("Tool '{}' not found", tool_call.name), ) } } }) .collect(); let mut results = join_all(futures).await; // Phase 3: after_tool_call フックを適用 for tool_result in &mut results { for hook in &self.hooks { let result = hook.after_tool_call(tool_result).await?; match result { ControlFlow::Continue => {} ControlFlow::Skip => break, ControlFlow::Abort(reason) => { return Err(WorkerError::Aborted(reason)); } } } } Ok(results) } /// 内部で使用するターン実行ロジック async fn run_turn_loop(&mut self) -> Result<(), WorkerError> { let tool_definitions = self.build_tool_definitions(); info!( message_count = self.history.len(), tool_count = tool_definitions.len(), "Starting worker run" ); loop { // ターン開始を通知 let current_turn = self.turn_count; debug!(turn = current_turn, "Turn start"); for notifier in &self.turn_notifiers { notifier.on_turn_start(current_turn); } // Hook: on_message_send let control = self.run_on_message_send_hooks().await?; if let ControlFlow::Abort(reason) = control { warn!(reason = %reason, "Aborted by hook"); // ターン終了を通知(異常終了) for notifier in &self.turn_notifiers { notifier.on_turn_end(current_turn); } return Err(WorkerError::Aborted(reason)); } // リクエスト構築 let request = self.build_request(&tool_definitions); debug!( message_count = request.messages.len(), tool_count = request.tools.len(), has_system = request.system_prompt.is_some(), "Sending request to LLM" ); // ストリーム処理 debug!("Starting stream..."); let mut stream = self.client.stream(request).await?; let mut event_count = 0; while let Some(event_result) = stream.next().await { match &event_result { Ok(event) => { trace!(event = ?event, "Received event"); event_count += 1; } Err(e) => { warn!(error = %e, "Stream error"); } } let event = event_result?; self.timeline.dispatch(&event); } debug!(event_count = event_count, "Stream completed"); // ターン終了を通知 for notifier in &self.turn_notifiers { notifier.on_turn_end(current_turn); } self.turn_count += 1; // 収集結果を取得 let text_blocks = self.text_block_collector.take_collected(); let tool_calls = self.tool_call_collector.take_collected(); // アシスタントメッセージを履歴に追加 let assistant_message = self.build_assistant_message(&text_blocks, &tool_calls); if let Some(msg) = assistant_message { self.history.push(msg); } if tool_calls.is_empty() { // ツール呼び出しなし → ターン終了判定 let turn_result = self.run_on_turn_end_hooks().await?; match turn_result { TurnResult::Finish => { return Ok(()); } TurnResult::ContinueWithMessages(additional) => { self.history.extend(additional); continue; } } } // ツール実行 let tool_results = self.execute_tools(tool_calls).await?; // ツール結果を履歴に追加 for result in tool_results { self.history .push(Message::tool_result(&result.tool_use_id, &result.content)); } } } } // ============================================================================= // Mutable状態専用の実装 // ============================================================================= impl Worker { /// 新しいWorkerを作成(Mutable状態) pub fn new(client: C) -> Self { let text_block_collector = TextBlockCollector::new(); let tool_call_collector = ToolCallCollector::new(); let mut timeline = Timeline::new(); // コレクターをTimelineに登録 timeline.on_text_block(text_block_collector.clone()); timeline.on_tool_use_block(tool_call_collector.clone()); Self { client, timeline, text_block_collector, tool_call_collector, tools: HashMap::new(), hooks: Vec::new(), system_prompt: None, history: Vec::new(), locked_prefix_len: 0, turn_count: 0, turn_notifiers: Vec::new(), _state: PhantomData, } } /// システムプロンプトを設定(ビルダーパターン) pub fn system_prompt(mut self, prompt: impl Into) -> Self { self.system_prompt = Some(prompt.into()); self } /// システムプロンプトを設定(可変参照版) pub fn set_system_prompt(&mut self, prompt: impl Into) { self.system_prompt = Some(prompt.into()); } /// 履歴への可変参照を取得 /// /// Mutable状態でのみ利用可能。履歴を自由に編集できる。 pub fn history_mut(&mut self) -> &mut Vec { &mut self.history } /// 履歴を設定 pub fn set_history(&mut self, messages: Vec) { self.history = messages; } /// 履歴にメッセージを追加(ビルダーパターン) pub fn with_message(mut self, message: Message) -> Self { self.history.push(message); self } /// 履歴にメッセージを追加 pub fn push_message(&mut self, message: Message) { self.history.push(message); } /// 複数のメッセージを履歴に追加(ビルダーパターン) pub fn with_messages(mut self, messages: impl IntoIterator) -> Self { self.history.extend(messages); self } /// 複数のメッセージを履歴に追加 pub fn extend_history(&mut self, messages: impl IntoIterator) { self.history.extend(messages); } /// 履歴をクリア pub fn clear_history(&mut self) { self.history.clear(); } /// 設定を適用(将来の拡張用) #[allow(dead_code)] pub fn config(self, _config: WorkerConfig) -> Self { self } /// ロックしてLocked状態へ遷移 /// /// この操作により、現在のシステムプロンプトと履歴が「確定済みプレフィックス」として /// 固定される。以降は履歴への追記のみが可能となり、キャッシュヒットが保証される。 pub fn lock(self) -> Worker { let locked_prefix_len = self.history.len(); Worker { client: self.client, timeline: self.timeline, text_block_collector: self.text_block_collector, tool_call_collector: self.tool_call_collector, tools: self.tools, hooks: self.hooks, system_prompt: self.system_prompt, history: self.history, locked_prefix_len, turn_count: self.turn_count, turn_notifiers: self.turn_notifiers, _state: PhantomData, } } /// ターンを実行(Mutable状態) /// /// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。 /// ツール呼び出しがある場合は自動的にループする。 /// /// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は /// `lock()` を呼んでからLocked状態で `run` を使用すること。 pub async fn run(&mut self, user_input: impl Into) -> Result<&[Message], WorkerError> { self.history.push(Message::user(user_input)); self.run_turn_loop().await?; Ok(&self.history) } /// 複数メッセージでターンを実行(Mutable状態) /// /// 指定されたメッセージを履歴に追加してから実行する。 pub async fn run_with_messages( &mut self, messages: Vec, ) -> Result<&[Message], WorkerError> { self.history.extend(messages); self.run_turn_loop().await?; Ok(&self.history) } } // ============================================================================= // Locked状態専用の実装 // ============================================================================= impl Worker { /// ターンを実行(Locked状態) /// /// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。 /// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。 pub async fn run(&mut self, user_input: impl Into) -> Result<&[Message], WorkerError> { self.history.push(Message::user(user_input)); self.run_turn_loop().await?; Ok(&self.history) } /// 複数メッセージでターンを実行(Locked状態) pub async fn run_with_messages( &mut self, messages: Vec, ) -> Result<&[Message], WorkerError> { self.history.extend(messages); self.run_turn_loop().await?; Ok(&self.history) } /// ロック時点のプレフィックス長を取得 pub fn locked_prefix_len(&self) -> usize { self.locked_prefix_len } /// ロックを解除してMutable状態へ戻す /// /// 注意: この操作を行うと、以降のリクエストでキャッシュがヒットしなくなる可能性がある。 /// 履歴を編集する必要がある場合にのみ使用すること。 pub fn unlock(self) -> Worker { Worker { client: self.client, timeline: self.timeline, text_block_collector: self.text_block_collector, tool_call_collector: self.tool_call_collector, tools: self.tools, hooks: self.hooks, system_prompt: self.system_prompt, history: self.history, locked_prefix_len: 0, turn_count: self.turn_count, turn_notifiers: self.turn_notifiers, _state: PhantomData, } } } #[cfg(test)] mod tests { // 基本的なテストのみ。LlmClientを使ったテストは統合テストで行う。 }