llm_worker_rs/worker/src/worker.rs
2026-01-08 18:23:16 +09:00

792 lines
28 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};
use futures::StreamExt;
use tracing::{debug, info, trace, warn};
use crate::Timeline;
use crate::llm_client::{ClientError, LlmClient, Request, ToolDefinition};
use crate::subscriber_adapter::{
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter,
};
use crate::text_block_collector::TextBlockCollector;
use crate::tool_call_collector::ToolCallCollector;
use worker_types::{
ContentPart, ControlFlow, HookError, Locked, Message, MessageContent, Mutable, Tool, ToolCall,
ToolError, ToolResult, TurnResult, WorkerHook, WorkerState, WorkerSubscriber,
};
// =============================================================================
// Worker Error
// =============================================================================
/// Workerエラー
#[derive(Debug, thiserror::Error)]
pub enum WorkerError {
/// クライアントエラー
#[error("Client error: {0}")]
Client(#[from] ClientError),
/// ツールエラー
#[error("Tool error: {0}")]
Tool(#[from] ToolError),
/// Hookエラー
#[error("Hook error: {0}")]
Hook(#[from] HookError),
/// 処理が中断された
#[error("Aborted: {0}")]
Aborted(String),
}
// =============================================================================
// Worker Config
// =============================================================================
/// Worker設定
#[derive(Debug, Clone, Default)]
pub struct WorkerConfig {
// 将来の拡張用(現在は空)
_private: (),
}
// =============================================================================
// ターン制御用コールバック保持
// =============================================================================
/// ターンイベントを通知するためのコールバック (型消去)
trait TurnNotifier: Send {
fn on_turn_start(&self, turn: usize);
fn on_turn_end(&self, turn: usize);
}
struct SubscriberTurnNotifier<S: WorkerSubscriber + 'static> {
subscriber: Arc<Mutex<S>>,
}
impl<S: WorkerSubscriber + 'static> TurnNotifier for SubscriberTurnNotifier<S> {
fn on_turn_start(&self, turn: usize) {
if let Ok(mut s) = self.subscriber.lock() {
s.on_turn_start(turn);
}
}
fn on_turn_end(&self, turn: usize) {
if let Ok(mut s) = self.subscriber.lock() {
s.on_turn_end(turn);
}
}
}
// =============================================================================
// Worker
// =============================================================================
/// LLMとの対話を管理する中心コンポーネント
///
/// ユーザーからの入力を受け取り、LLMにリクエストを送信し、
/// ツール呼び出しがあれば自動的に実行してターンを進行させます。
///
/// # 状態遷移Type-state
///
/// - [`Mutable`]: 初期状態。システムプロンプトや履歴を自由に編集可能。
/// - [`Locked`]: キャッシュ保護状態。`lock()`で遷移。前方コンテキストは不変。
///
/// # Examples
///
/// ```ignore
/// use worker::{Worker, Message};
///
/// // Workerを作成してツールを登録
/// let mut worker = Worker::new(client)
/// .system_prompt("You are a helpful assistant.");
/// worker.register_tool(my_tool);
///
/// // 対話を実行
/// let history = worker.run("Hello!").await?;
/// ```
///
/// # キャッシュ保護が必要な場合
///
/// ```ignore
/// let mut worker = Worker::new(client)
/// .system_prompt("...");
///
/// // 履歴を設定後、ロックしてキャッシュを保護
/// let mut locked = worker.lock();
/// locked.run("user input").await?;
/// ```
pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
/// LLMクライアント
client: C,
/// イベントタイムライン
timeline: Timeline,
/// テキストブロックコレクターTimeline用ハンドラ
text_block_collector: TextBlockCollector,
/// ツールコールコレクターTimeline用ハンドラ
tool_call_collector: ToolCallCollector,
/// 登録されたツール
tools: HashMap<String, Arc<dyn Tool>>,
/// 登録されたHook
hooks: Vec<Box<dyn WorkerHook>>,
/// システムプロンプト
system_prompt: Option<String>,
/// メッセージ履歴Workerが所有
history: Vec<Message>,
/// ロック時点での履歴長Locked状態でのみ意味を持つ
locked_prefix_len: usize,
/// ターンカウント
turn_count: usize,
/// ターン通知用のコールバック
turn_notifiers: Vec<Box<dyn TurnNotifier>>,
/// 状態マーカー
_state: PhantomData<S>,
}
// =============================================================================
// 共通実装(全状態で利用可能)
// =============================================================================
impl<C: LlmClient, S: WorkerState> Worker<C, S> {
/// イベント購読者を登録する
///
/// 登録したSubscriberは、LLMからのストリーミングイベントを
/// リアルタイムで受信できます。UIへのストリーム表示などに利用します。
///
/// # 受信できるイベント
///
/// - **ブロックイベント**: `on_text_block`, `on_tool_use_block`
/// - **メタイベント**: `on_usage`, `on_status`, `on_error`
/// - **完了イベント**: `on_text_complete`, `on_tool_call_complete`
/// - **ターン制御**: `on_turn_start`, `on_turn_end`
///
/// # Examples
///
/// ```ignore
/// use worker::{Worker, WorkerSubscriber, TextBlockEvent};
///
/// struct MyPrinter;
/// impl WorkerSubscriber for MyPrinter {
/// type TextBlockScope = ();
/// type ToolUseBlockScope = ();
///
/// fn on_text_block(&mut self, _: &mut (), event: &TextBlockEvent) {
/// if let TextBlockEvent::Delta(text) = event {
/// print!("{}", text);
/// }
/// }
/// }
///
/// worker.subscribe(MyPrinter);
/// ```
pub fn subscribe<Sub: WorkerSubscriber + 'static>(&mut self, subscriber: Sub) {
let subscriber = Arc::new(Mutex::new(subscriber));
// TextBlock用ハンドラを登録
self.timeline
.on_text_block(TextBlockSubscriberAdapter::new(subscriber.clone()));
// ToolUseBlock用ハンドラを登録
self.timeline
.on_tool_use_block(ToolUseBlockSubscriberAdapter::new(subscriber.clone()));
// Meta系ハンドラを登録
self.timeline
.on_usage(UsageSubscriberAdapter::new(subscriber.clone()));
self.timeline
.on_status(StatusSubscriberAdapter::new(subscriber.clone()));
self.timeline
.on_error(ErrorSubscriberAdapter::new(subscriber.clone()));
// ターン制御用コールバックを登録
self.turn_notifiers
.push(Box::new(SubscriberTurnNotifier { subscriber }));
}
/// ツールを登録する
///
/// 登録されたツールはLLMからの呼び出しで自動的に実行されます。
/// 同名のツールを登録した場合、後から登録したものが優先されます。
///
/// # Examples
///
/// ```ignore
/// use worker::Worker;
/// use my_tools::SearchTool;
///
/// worker.register_tool(SearchTool::new());
/// ```
pub fn register_tool(&mut self, tool: impl Tool + 'static) {
let name = tool.name().to_string();
self.tools.insert(name, Arc::new(tool));
}
/// 複数のツールを登録
pub fn register_tools(&mut self, tools: impl IntoIterator<Item = impl Tool + 'static>) {
for tool in tools {
self.register_tool(tool);
}
}
/// Hookを追加する
///
/// Hookはターンの進行・ツール実行に介入できます。
/// 複数のHookを登録した場合、登録順に実行されます。
///
/// # Examples
///
/// ```ignore
/// use worker::{Worker, WorkerHook, ControlFlow, ToolCall};
///
/// struct LoggingHook;
///
/// #[async_trait::async_trait]
/// impl WorkerHook for LoggingHook {
/// async fn before_tool_call(&self, call: &mut ToolCall) -> Result<ControlFlow, HookError> {
/// println!("Calling tool: {}", call.name);
/// Ok(ControlFlow::Continue)
/// }
/// }
///
/// worker.add_hook(LoggingHook);
/// ```
pub fn add_hook(&mut self, hook: impl WorkerHook + 'static) {
self.hooks.push(Box::new(hook));
}
/// タイムラインへの可変参照を取得(追加ハンドラ登録用)
pub fn timeline_mut(&mut self) -> &mut Timeline {
&mut self.timeline
}
/// 履歴への参照を取得
pub fn history(&self) -> &[Message] {
&self.history
}
/// システムプロンプトへの参照を取得
pub fn get_system_prompt(&self) -> Option<&str> {
self.system_prompt.as_deref()
}
/// 現在のターンカウントを取得
pub fn turn_count(&self) -> usize {
self.turn_count
}
/// 登録されたツールからToolDefinitionのリストを生成
fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
self.tools
.values()
.map(|tool| {
ToolDefinition::new(tool.name())
.description(tool.description())
.input_schema(tool.input_schema())
})
.collect()
}
/// テキストブロックとツール呼び出しからアシスタントメッセージを構築
fn build_assistant_message(
&self,
text_blocks: &[String],
tool_calls: &[ToolCall],
) -> Option<Message> {
// テキストもツール呼び出しもない場合はNone
if text_blocks.is_empty() && tool_calls.is_empty() {
return None;
}
// テキストのみの場合はシンプルなテキストメッセージ
if tool_calls.is_empty() {
let text = text_blocks.join("");
return Some(Message::assistant(text));
}
// ツール呼び出しがある場合は Parts として構築
let mut parts = Vec::new();
// テキストパーツを追加
for text in text_blocks {
if !text.is_empty() {
parts.push(ContentPart::Text { text: text.clone() });
}
}
// ツール呼び出しパーツを追加
for call in tool_calls {
parts.push(ContentPart::ToolUse {
id: call.id.clone(),
name: call.name.clone(),
input: call.input.clone(),
});
}
Some(Message {
role: worker_types::Role::Assistant,
content: MessageContent::Parts(parts),
})
}
/// リクエストを構築
fn build_request(&self, tool_definitions: &[ToolDefinition]) -> Request {
let mut request = Request::new();
// システムプロンプトを設定
if let Some(ref system) = self.system_prompt {
request = request.system(system);
}
// メッセージを追加
for msg in &self.history {
// worker-types::Message から llm_client::Message への変換
request = request.message(crate::llm_client::Message {
role: match msg.role {
worker_types::Role::User => crate::llm_client::Role::User,
worker_types::Role::Assistant => crate::llm_client::Role::Assistant,
},
content: match &msg.content {
worker_types::MessageContent::Text(t) => {
crate::llm_client::MessageContent::Text(t.clone())
}
worker_types::MessageContent::ToolResult {
tool_use_id,
content,
} => crate::llm_client::MessageContent::ToolResult {
tool_use_id: tool_use_id.clone(),
content: content.clone(),
},
worker_types::MessageContent::Parts(parts) => {
crate::llm_client::MessageContent::Parts(
parts
.iter()
.map(|p| match p {
worker_types::ContentPart::Text { text } => {
crate::llm_client::ContentPart::Text { text: text.clone() }
}
worker_types::ContentPart::ToolUse { id, name, input } => {
crate::llm_client::ContentPart::ToolUse {
id: id.clone(),
name: name.clone(),
input: input.clone(),
}
}
worker_types::ContentPart::ToolResult {
tool_use_id,
content,
} => crate::llm_client::ContentPart::ToolResult {
tool_use_id: tool_use_id.clone(),
content: content.clone(),
},
})
.collect(),
)
}
},
});
}
// ツール定義を追加
for tool_def in tool_definitions {
request = request.tool(tool_def.clone());
}
request
}
/// Hooks: on_message_send
async fn run_on_message_send_hooks(&self) -> Result<ControlFlow, WorkerError> {
for hook in &self.hooks {
// Note: Locked状態でも履歴全体を参照として渡す変更は不可
// HookのAPIを変更し、immutable参照のみを渡すようにする必要があるかもしれない
// 現在は空のVecを渡して回避要検討
let mut temp_context = self.history.clone();
let result = hook.on_message_send(&mut temp_context).await?;
match result {
ControlFlow::Continue => continue,
ControlFlow::Skip => return Ok(ControlFlow::Skip),
ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)),
}
}
Ok(ControlFlow::Continue)
}
/// Hooks: on_turn_end
async fn run_on_turn_end_hooks(&self) -> Result<TurnResult, WorkerError> {
for hook in &self.hooks {
let result = hook.on_turn_end(&self.history).await?;
match result {
TurnResult::Finish => continue,
TurnResult::ContinueWithMessages(msgs) => {
return Ok(TurnResult::ContinueWithMessages(msgs));
}
}
}
Ok(TurnResult::Finish)
}
/// ツールを並列実行
///
/// 全てのツールに対してbefore_tool_callフックを実行後、
/// 許可されたツールを並列に実行し、結果にafter_tool_callフックを適用する。
async fn execute_tools(
&self,
tool_calls: Vec<ToolCall>,
) -> Result<Vec<ToolResult>, WorkerError> {
use futures::future::join_all;
// Phase 1: before_tool_call フックを適用(スキップ/中断を判定)
let mut approved_calls = Vec::new();
for mut tool_call in tool_calls {
let mut skip = false;
for hook in &self.hooks {
let result = hook.before_tool_call(&mut tool_call).await?;
match result {
ControlFlow::Continue => {}
ControlFlow::Skip => {
skip = true;
break;
}
ControlFlow::Abort(reason) => {
return Err(WorkerError::Aborted(reason));
}
}
}
if !skip {
approved_calls.push(tool_call);
}
}
// Phase 2: 許可されたツールを並列実行
let futures: Vec<_> = approved_calls
.into_iter()
.map(|tool_call| {
let tools = &self.tools;
async move {
if let Some(tool) = tools.get(&tool_call.name) {
let input_json =
serde_json::to_string(&tool_call.input).unwrap_or_default();
match tool.execute(&input_json).await {
Ok(content) => ToolResult::success(&tool_call.id, content),
Err(e) => ToolResult::error(&tool_call.id, e.to_string()),
}
} else {
ToolResult::error(
&tool_call.id,
format!("Tool '{}' not found", tool_call.name),
)
}
}
})
.collect();
let mut results = join_all(futures).await;
// Phase 3: after_tool_call フックを適用
for tool_result in &mut results {
for hook in &self.hooks {
let result = hook.after_tool_call(tool_result).await?;
match result {
ControlFlow::Continue => {}
ControlFlow::Skip => break,
ControlFlow::Abort(reason) => {
return Err(WorkerError::Aborted(reason));
}
}
}
}
Ok(results)
}
/// 内部で使用するターン実行ロジック
async fn run_turn_loop(&mut self) -> Result<(), WorkerError> {
let tool_definitions = self.build_tool_definitions();
info!(
message_count = self.history.len(),
tool_count = tool_definitions.len(),
"Starting worker run"
);
loop {
// ターン開始を通知
let current_turn = self.turn_count;
debug!(turn = current_turn, "Turn start");
for notifier in &self.turn_notifiers {
notifier.on_turn_start(current_turn);
}
// Hook: on_message_send
let control = self.run_on_message_send_hooks().await?;
if let ControlFlow::Abort(reason) = control {
warn!(reason = %reason, "Aborted by hook");
// ターン終了を通知(異常終了)
for notifier in &self.turn_notifiers {
notifier.on_turn_end(current_turn);
}
return Err(WorkerError::Aborted(reason));
}
// リクエスト構築
let request = self.build_request(&tool_definitions);
debug!(
message_count = request.messages.len(),
tool_count = request.tools.len(),
has_system = request.system_prompt.is_some(),
"Sending request to LLM"
);
// ストリーム処理
debug!("Starting stream...");
let mut stream = self.client.stream(request).await?;
let mut event_count = 0;
while let Some(event_result) = stream.next().await {
match &event_result {
Ok(event) => {
trace!(event = ?event, "Received event");
event_count += 1;
}
Err(e) => {
warn!(error = %e, "Stream error");
}
}
let event = event_result?;
self.timeline.dispatch(&event);
}
debug!(event_count = event_count, "Stream completed");
// ターン終了を通知
for notifier in &self.turn_notifiers {
notifier.on_turn_end(current_turn);
}
self.turn_count += 1;
// 収集結果を取得
let text_blocks = self.text_block_collector.take_collected();
let tool_calls = self.tool_call_collector.take_collected();
// アシスタントメッセージを履歴に追加
let assistant_message = self.build_assistant_message(&text_blocks, &tool_calls);
if let Some(msg) = assistant_message {
self.history.push(msg);
}
if tool_calls.is_empty() {
// ツール呼び出しなし → ターン終了判定
let turn_result = self.run_on_turn_end_hooks().await?;
match turn_result {
TurnResult::Finish => {
return Ok(());
}
TurnResult::ContinueWithMessages(additional) => {
self.history.extend(additional);
continue;
}
}
}
// ツール実行
let tool_results = self.execute_tools(tool_calls).await?;
// ツール結果を履歴に追加
for result in tool_results {
self.history
.push(Message::tool_result(&result.tool_use_id, &result.content));
}
}
}
}
// =============================================================================
// Mutable状態専用の実装
// =============================================================================
impl<C: LlmClient> Worker<C, Mutable> {
/// 新しいWorkerを作成Mutable状態
pub fn new(client: C) -> Self {
let text_block_collector = TextBlockCollector::new();
let tool_call_collector = ToolCallCollector::new();
let mut timeline = Timeline::new();
// コレクターをTimelineに登録
timeline.on_text_block(text_block_collector.clone());
timeline.on_tool_use_block(tool_call_collector.clone());
Self {
client,
timeline,
text_block_collector,
tool_call_collector,
tools: HashMap::new(),
hooks: Vec::new(),
system_prompt: None,
history: Vec::new(),
locked_prefix_len: 0,
turn_count: 0,
turn_notifiers: Vec::new(),
_state: PhantomData,
}
}
/// システムプロンプトを設定(ビルダーパターン)
pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.system_prompt = Some(prompt.into());
self
}
/// システムプロンプトを設定(可変参照版)
pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
self.system_prompt = Some(prompt.into());
}
/// 履歴への可変参照を取得
///
/// Mutable状態でのみ利用可能。履歴を自由に編集できる。
pub fn history_mut(&mut self) -> &mut Vec<Message> {
&mut self.history
}
/// 履歴を設定
pub fn set_history(&mut self, messages: Vec<Message>) {
self.history = messages;
}
/// 履歴にメッセージを追加(ビルダーパターン)
pub fn with_message(mut self, message: Message) -> Self {
self.history.push(message);
self
}
/// 履歴にメッセージを追加
pub fn push_message(&mut self, message: Message) {
self.history.push(message);
}
/// 複数のメッセージを履歴に追加(ビルダーパターン)
pub fn with_messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
self.history.extend(messages);
self
}
/// 複数のメッセージを履歴に追加
pub fn extend_history(&mut self, messages: impl IntoIterator<Item = Message>) {
self.history.extend(messages);
}
/// 履歴をクリア
pub fn clear_history(&mut self) {
self.history.clear();
}
/// 設定を適用(将来の拡張用)
#[allow(dead_code)]
pub fn config(self, _config: WorkerConfig) -> Self {
self
}
/// ロックしてLocked状態へ遷移
///
/// この操作により、現在のシステムプロンプトと履歴が「確定済みプレフィックス」として
/// 固定される。以降は履歴への追記のみが可能となり、キャッシュヒットが保証される。
pub fn lock(self) -> Worker<C, Locked> {
let locked_prefix_len = self.history.len();
Worker {
client: self.client,
timeline: self.timeline,
text_block_collector: self.text_block_collector,
tool_call_collector: self.tool_call_collector,
tools: self.tools,
hooks: self.hooks,
system_prompt: self.system_prompt,
history: self.history,
locked_prefix_len,
turn_count: self.turn_count,
turn_notifiers: self.turn_notifiers,
_state: PhantomData,
}
}
/// ターンを実行Mutable状態
///
/// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。
/// ツール呼び出しがある場合は自動的にループする。
///
/// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は
/// `lock()` を呼んでからLocked状態で `run` を使用すること。
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<&[Message], WorkerError> {
self.history.push(Message::user(user_input));
self.run_turn_loop().await?;
Ok(&self.history)
}
/// 複数メッセージでターンを実行Mutable状態
///
/// 指定されたメッセージを履歴に追加してから実行する。
pub async fn run_with_messages(
&mut self,
messages: Vec<Message>,
) -> Result<&[Message], WorkerError> {
self.history.extend(messages);
self.run_turn_loop().await?;
Ok(&self.history)
}
}
// =============================================================================
// Locked状態専用の実装
// =============================================================================
impl<C: LlmClient> Worker<C, Locked> {
/// ターンを実行Locked状態
///
/// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。
/// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<&[Message], WorkerError> {
self.history.push(Message::user(user_input));
self.run_turn_loop().await?;
Ok(&self.history)
}
/// 複数メッセージでターンを実行Locked状態
pub async fn run_with_messages(
&mut self,
messages: Vec<Message>,
) -> Result<&[Message], WorkerError> {
self.history.extend(messages);
self.run_turn_loop().await?;
Ok(&self.history)
}
/// ロック時点のプレフィックス長を取得
pub fn locked_prefix_len(&self) -> usize {
self.locked_prefix_len
}
/// ロックを解除してMutable状態へ戻す
///
/// 注意: この操作を行うと、以降のリクエストでキャッシュがヒットしなくなる可能性がある。
/// 履歴を編集する必要がある場合にのみ使用すること。
pub fn unlock(self) -> Worker<C, Mutable> {
Worker {
client: self.client,
timeline: self.timeline,
text_block_collector: self.text_block_collector,
tool_call_collector: self.tool_call_collector,
tools: self.tools,
hooks: self.hooks,
system_prompt: self.system_prompt,
history: self.history,
locked_prefix_len: 0,
turn_count: self.turn_count,
turn_notifiers: self.turn_notifiers,
_state: PhantomData,
}
}
}
#[cfg(test)]
mod tests {
// 基本的なテストのみ。LlmClientを使ったテストは統合テストで行う。
}