792 lines
28 KiB
Rust
792 lines
28 KiB
Rust
use std::collections::HashMap;
|
||
use std::marker::PhantomData;
|
||
use std::sync::{Arc, Mutex};
|
||
|
||
use futures::StreamExt;
|
||
use tracing::{debug, info, trace, warn};
|
||
|
||
use crate::Timeline;
|
||
use crate::llm_client::{ClientError, LlmClient, Request, ToolDefinition};
|
||
use crate::subscriber_adapter::{
|
||
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
|
||
ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter,
|
||
};
|
||
use crate::text_block_collector::TextBlockCollector;
|
||
use crate::tool_call_collector::ToolCallCollector;
|
||
use worker_types::{
|
||
ContentPart, ControlFlow, HookError, Locked, Message, MessageContent, Mutable, Tool, ToolCall,
|
||
ToolError, ToolResult, TurnResult, WorkerHook, WorkerState, WorkerSubscriber,
|
||
};
|
||
|
||
// =============================================================================
|
||
// Worker Error
|
||
// =============================================================================
|
||
|
||
/// Workerエラー
|
||
#[derive(Debug, thiserror::Error)]
|
||
pub enum WorkerError {
|
||
/// クライアントエラー
|
||
#[error("Client error: {0}")]
|
||
Client(#[from] ClientError),
|
||
/// ツールエラー
|
||
#[error("Tool error: {0}")]
|
||
Tool(#[from] ToolError),
|
||
/// Hookエラー
|
||
#[error("Hook error: {0}")]
|
||
Hook(#[from] HookError),
|
||
/// 処理が中断された
|
||
#[error("Aborted: {0}")]
|
||
Aborted(String),
|
||
}
|
||
|
||
// =============================================================================
|
||
// Worker Config
|
||
// =============================================================================
|
||
|
||
/// Worker設定
|
||
#[derive(Debug, Clone, Default)]
|
||
pub struct WorkerConfig {
|
||
// 将来の拡張用(現在は空)
|
||
_private: (),
|
||
}
|
||
|
||
// =============================================================================
|
||
// ターン制御用コールバック保持
|
||
// =============================================================================
|
||
|
||
/// ターンイベントを通知するためのコールバック (型消去)
|
||
trait TurnNotifier: Send {
|
||
fn on_turn_start(&self, turn: usize);
|
||
fn on_turn_end(&self, turn: usize);
|
||
}
|
||
|
||
struct SubscriberTurnNotifier<S: WorkerSubscriber + 'static> {
|
||
subscriber: Arc<Mutex<S>>,
|
||
}
|
||
|
||
impl<S: WorkerSubscriber + 'static> TurnNotifier for SubscriberTurnNotifier<S> {
|
||
fn on_turn_start(&self, turn: usize) {
|
||
if let Ok(mut s) = self.subscriber.lock() {
|
||
s.on_turn_start(turn);
|
||
}
|
||
}
|
||
|
||
fn on_turn_end(&self, turn: usize) {
|
||
if let Ok(mut s) = self.subscriber.lock() {
|
||
s.on_turn_end(turn);
|
||
}
|
||
}
|
||
}
|
||
|
||
// =============================================================================
|
||
// Worker
|
||
// =============================================================================
|
||
|
||
/// LLMとの対話を管理する中心コンポーネント
|
||
///
|
||
/// ユーザーからの入力を受け取り、LLMにリクエストを送信し、
|
||
/// ツール呼び出しがあれば自動的に実行してターンを進行させます。
|
||
///
|
||
/// # 状態遷移(Type-state)
|
||
///
|
||
/// - [`Mutable`]: 初期状態。システムプロンプトや履歴を自由に編集可能。
|
||
/// - [`Locked`]: キャッシュ保護状態。`lock()`で遷移。前方コンテキストは不変。
|
||
///
|
||
/// # Examples
|
||
///
|
||
/// ```ignore
|
||
/// use worker::{Worker, Message};
|
||
///
|
||
/// // Workerを作成してツールを登録
|
||
/// let mut worker = Worker::new(client)
|
||
/// .system_prompt("You are a helpful assistant.");
|
||
/// worker.register_tool(my_tool);
|
||
///
|
||
/// // 対話を実行
|
||
/// let history = worker.run("Hello!").await?;
|
||
/// ```
|
||
///
|
||
/// # キャッシュ保護が必要な場合
|
||
///
|
||
/// ```ignore
|
||
/// let mut worker = Worker::new(client)
|
||
/// .system_prompt("...");
|
||
///
|
||
/// // 履歴を設定後、ロックしてキャッシュを保護
|
||
/// let mut locked = worker.lock();
|
||
/// locked.run("user input").await?;
|
||
/// ```
|
||
pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
||
/// LLMクライアント
|
||
client: C,
|
||
/// イベントタイムライン
|
||
timeline: Timeline,
|
||
/// テキストブロックコレクター(Timeline用ハンドラ)
|
||
text_block_collector: TextBlockCollector,
|
||
/// ツールコールコレクター(Timeline用ハンドラ)
|
||
tool_call_collector: ToolCallCollector,
|
||
/// 登録されたツール
|
||
tools: HashMap<String, Arc<dyn Tool>>,
|
||
/// 登録されたHook
|
||
hooks: Vec<Box<dyn WorkerHook>>,
|
||
/// システムプロンプト
|
||
system_prompt: Option<String>,
|
||
/// メッセージ履歴(Workerが所有)
|
||
history: Vec<Message>,
|
||
/// ロック時点での履歴長(Locked状態でのみ意味を持つ)
|
||
locked_prefix_len: usize,
|
||
/// ターンカウント
|
||
turn_count: usize,
|
||
/// ターン通知用のコールバック
|
||
turn_notifiers: Vec<Box<dyn TurnNotifier>>,
|
||
/// 状態マーカー
|
||
_state: PhantomData<S>,
|
||
}
|
||
|
||
// =============================================================================
|
||
// 共通実装(全状態で利用可能)
|
||
// =============================================================================
|
||
|
||
impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||
/// イベント購読者を登録する
|
||
///
|
||
/// 登録したSubscriberは、LLMからのストリーミングイベントを
|
||
/// リアルタイムで受信できます。UIへのストリーム表示などに利用します。
|
||
///
|
||
/// # 受信できるイベント
|
||
///
|
||
/// - **ブロックイベント**: `on_text_block`, `on_tool_use_block`
|
||
/// - **メタイベント**: `on_usage`, `on_status`, `on_error`
|
||
/// - **完了イベント**: `on_text_complete`, `on_tool_call_complete`
|
||
/// - **ターン制御**: `on_turn_start`, `on_turn_end`
|
||
///
|
||
/// # Examples
|
||
///
|
||
/// ```ignore
|
||
/// use worker::{Worker, WorkerSubscriber, TextBlockEvent};
|
||
///
|
||
/// struct MyPrinter;
|
||
/// impl WorkerSubscriber for MyPrinter {
|
||
/// type TextBlockScope = ();
|
||
/// type ToolUseBlockScope = ();
|
||
///
|
||
/// fn on_text_block(&mut self, _: &mut (), event: &TextBlockEvent) {
|
||
/// if let TextBlockEvent::Delta(text) = event {
|
||
/// print!("{}", text);
|
||
/// }
|
||
/// }
|
||
/// }
|
||
///
|
||
/// worker.subscribe(MyPrinter);
|
||
/// ```
|
||
pub fn subscribe<Sub: WorkerSubscriber + 'static>(&mut self, subscriber: Sub) {
|
||
let subscriber = Arc::new(Mutex::new(subscriber));
|
||
|
||
// TextBlock用ハンドラを登録
|
||
self.timeline
|
||
.on_text_block(TextBlockSubscriberAdapter::new(subscriber.clone()));
|
||
|
||
// ToolUseBlock用ハンドラを登録
|
||
self.timeline
|
||
.on_tool_use_block(ToolUseBlockSubscriberAdapter::new(subscriber.clone()));
|
||
|
||
// Meta系ハンドラを登録
|
||
self.timeline
|
||
.on_usage(UsageSubscriberAdapter::new(subscriber.clone()));
|
||
self.timeline
|
||
.on_status(StatusSubscriberAdapter::new(subscriber.clone()));
|
||
self.timeline
|
||
.on_error(ErrorSubscriberAdapter::new(subscriber.clone()));
|
||
|
||
// ターン制御用コールバックを登録
|
||
self.turn_notifiers
|
||
.push(Box::new(SubscriberTurnNotifier { subscriber }));
|
||
}
|
||
|
||
/// ツールを登録する
|
||
///
|
||
/// 登録されたツールはLLMからの呼び出しで自動的に実行されます。
|
||
/// 同名のツールを登録した場合、後から登録したものが優先されます。
|
||
///
|
||
/// # Examples
|
||
///
|
||
/// ```ignore
|
||
/// use worker::Worker;
|
||
/// use my_tools::SearchTool;
|
||
///
|
||
/// worker.register_tool(SearchTool::new());
|
||
/// ```
|
||
pub fn register_tool(&mut self, tool: impl Tool + 'static) {
|
||
let name = tool.name().to_string();
|
||
self.tools.insert(name, Arc::new(tool));
|
||
}
|
||
|
||
/// 複数のツールを登録
|
||
pub fn register_tools(&mut self, tools: impl IntoIterator<Item = impl Tool + 'static>) {
|
||
for tool in tools {
|
||
self.register_tool(tool);
|
||
}
|
||
}
|
||
|
||
/// Hookを追加する
|
||
///
|
||
/// Hookはターンの進行・ツール実行に介入できます。
|
||
/// 複数のHookを登録した場合、登録順に実行されます。
|
||
///
|
||
/// # Examples
|
||
///
|
||
/// ```ignore
|
||
/// use worker::{Worker, WorkerHook, ControlFlow, ToolCall};
|
||
///
|
||
/// struct LoggingHook;
|
||
///
|
||
/// #[async_trait::async_trait]
|
||
/// impl WorkerHook for LoggingHook {
|
||
/// async fn before_tool_call(&self, call: &mut ToolCall) -> Result<ControlFlow, HookError> {
|
||
/// println!("Calling tool: {}", call.name);
|
||
/// Ok(ControlFlow::Continue)
|
||
/// }
|
||
/// }
|
||
///
|
||
/// worker.add_hook(LoggingHook);
|
||
/// ```
|
||
pub fn add_hook(&mut self, hook: impl WorkerHook + 'static) {
|
||
self.hooks.push(Box::new(hook));
|
||
}
|
||
|
||
/// タイムラインへの可変参照を取得(追加ハンドラ登録用)
|
||
pub fn timeline_mut(&mut self) -> &mut Timeline {
|
||
&mut self.timeline
|
||
}
|
||
|
||
/// 履歴への参照を取得
|
||
pub fn history(&self) -> &[Message] {
|
||
&self.history
|
||
}
|
||
|
||
/// システムプロンプトへの参照を取得
|
||
pub fn get_system_prompt(&self) -> Option<&str> {
|
||
self.system_prompt.as_deref()
|
||
}
|
||
|
||
/// 現在のターンカウントを取得
|
||
pub fn turn_count(&self) -> usize {
|
||
self.turn_count
|
||
}
|
||
|
||
/// 登録されたツールからToolDefinitionのリストを生成
|
||
fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
|
||
self.tools
|
||
.values()
|
||
.map(|tool| {
|
||
ToolDefinition::new(tool.name())
|
||
.description(tool.description())
|
||
.input_schema(tool.input_schema())
|
||
})
|
||
.collect()
|
||
}
|
||
|
||
/// テキストブロックとツール呼び出しからアシスタントメッセージを構築
|
||
fn build_assistant_message(
|
||
&self,
|
||
text_blocks: &[String],
|
||
tool_calls: &[ToolCall],
|
||
) -> Option<Message> {
|
||
// テキストもツール呼び出しもない場合はNone
|
||
if text_blocks.is_empty() && tool_calls.is_empty() {
|
||
return None;
|
||
}
|
||
|
||
// テキストのみの場合はシンプルなテキストメッセージ
|
||
if tool_calls.is_empty() {
|
||
let text = text_blocks.join("");
|
||
return Some(Message::assistant(text));
|
||
}
|
||
|
||
// ツール呼び出しがある場合は Parts として構築
|
||
let mut parts = Vec::new();
|
||
|
||
// テキストパーツを追加
|
||
for text in text_blocks {
|
||
if !text.is_empty() {
|
||
parts.push(ContentPart::Text { text: text.clone() });
|
||
}
|
||
}
|
||
|
||
// ツール呼び出しパーツを追加
|
||
for call in tool_calls {
|
||
parts.push(ContentPart::ToolUse {
|
||
id: call.id.clone(),
|
||
name: call.name.clone(),
|
||
input: call.input.clone(),
|
||
});
|
||
}
|
||
|
||
Some(Message {
|
||
role: worker_types::Role::Assistant,
|
||
content: MessageContent::Parts(parts),
|
||
})
|
||
}
|
||
|
||
/// リクエストを構築
|
||
fn build_request(&self, tool_definitions: &[ToolDefinition]) -> Request {
|
||
let mut request = Request::new();
|
||
|
||
// システムプロンプトを設定
|
||
if let Some(ref system) = self.system_prompt {
|
||
request = request.system(system);
|
||
}
|
||
|
||
// メッセージを追加
|
||
for msg in &self.history {
|
||
// worker-types::Message から llm_client::Message への変換
|
||
request = request.message(crate::llm_client::Message {
|
||
role: match msg.role {
|
||
worker_types::Role::User => crate::llm_client::Role::User,
|
||
worker_types::Role::Assistant => crate::llm_client::Role::Assistant,
|
||
},
|
||
content: match &msg.content {
|
||
worker_types::MessageContent::Text(t) => {
|
||
crate::llm_client::MessageContent::Text(t.clone())
|
||
}
|
||
worker_types::MessageContent::ToolResult {
|
||
tool_use_id,
|
||
content,
|
||
} => crate::llm_client::MessageContent::ToolResult {
|
||
tool_use_id: tool_use_id.clone(),
|
||
content: content.clone(),
|
||
},
|
||
worker_types::MessageContent::Parts(parts) => {
|
||
crate::llm_client::MessageContent::Parts(
|
||
parts
|
||
.iter()
|
||
.map(|p| match p {
|
||
worker_types::ContentPart::Text { text } => {
|
||
crate::llm_client::ContentPart::Text { text: text.clone() }
|
||
}
|
||
worker_types::ContentPart::ToolUse { id, name, input } => {
|
||
crate::llm_client::ContentPart::ToolUse {
|
||
id: id.clone(),
|
||
name: name.clone(),
|
||
input: input.clone(),
|
||
}
|
||
}
|
||
worker_types::ContentPart::ToolResult {
|
||
tool_use_id,
|
||
content,
|
||
} => crate::llm_client::ContentPart::ToolResult {
|
||
tool_use_id: tool_use_id.clone(),
|
||
content: content.clone(),
|
||
},
|
||
})
|
||
.collect(),
|
||
)
|
||
}
|
||
},
|
||
});
|
||
}
|
||
|
||
// ツール定義を追加
|
||
for tool_def in tool_definitions {
|
||
request = request.tool(tool_def.clone());
|
||
}
|
||
|
||
request
|
||
}
|
||
|
||
/// Hooks: on_message_send
|
||
async fn run_on_message_send_hooks(&self) -> Result<ControlFlow, WorkerError> {
|
||
for hook in &self.hooks {
|
||
// Note: Locked状態でも履歴全体を参照として渡す(変更は不可)
|
||
// HookのAPIを変更し、immutable参照のみを渡すようにする必要があるかもしれない
|
||
// 現在は空のVecを渡して回避(要検討)
|
||
let mut temp_context = self.history.clone();
|
||
let result = hook.on_message_send(&mut temp_context).await?;
|
||
match result {
|
||
ControlFlow::Continue => continue,
|
||
ControlFlow::Skip => return Ok(ControlFlow::Skip),
|
||
ControlFlow::Abort(reason) => return Ok(ControlFlow::Abort(reason)),
|
||
}
|
||
}
|
||
Ok(ControlFlow::Continue)
|
||
}
|
||
|
||
/// Hooks: on_turn_end
|
||
async fn run_on_turn_end_hooks(&self) -> Result<TurnResult, WorkerError> {
|
||
for hook in &self.hooks {
|
||
let result = hook.on_turn_end(&self.history).await?;
|
||
match result {
|
||
TurnResult::Finish => continue,
|
||
TurnResult::ContinueWithMessages(msgs) => {
|
||
return Ok(TurnResult::ContinueWithMessages(msgs));
|
||
}
|
||
}
|
||
}
|
||
Ok(TurnResult::Finish)
|
||
}
|
||
|
||
/// ツールを並列実行
|
||
///
|
||
/// 全てのツールに対してbefore_tool_callフックを実行後、
|
||
/// 許可されたツールを並列に実行し、結果にafter_tool_callフックを適用する。
|
||
async fn execute_tools(
|
||
&self,
|
||
tool_calls: Vec<ToolCall>,
|
||
) -> Result<Vec<ToolResult>, WorkerError> {
|
||
use futures::future::join_all;
|
||
|
||
// Phase 1: before_tool_call フックを適用(スキップ/中断を判定)
|
||
let mut approved_calls = Vec::new();
|
||
for mut tool_call in tool_calls {
|
||
let mut skip = false;
|
||
for hook in &self.hooks {
|
||
let result = hook.before_tool_call(&mut tool_call).await?;
|
||
match result {
|
||
ControlFlow::Continue => {}
|
||
ControlFlow::Skip => {
|
||
skip = true;
|
||
break;
|
||
}
|
||
ControlFlow::Abort(reason) => {
|
||
return Err(WorkerError::Aborted(reason));
|
||
}
|
||
}
|
||
}
|
||
if !skip {
|
||
approved_calls.push(tool_call);
|
||
}
|
||
}
|
||
|
||
// Phase 2: 許可されたツールを並列実行
|
||
let futures: Vec<_> = approved_calls
|
||
.into_iter()
|
||
.map(|tool_call| {
|
||
let tools = &self.tools;
|
||
async move {
|
||
if let Some(tool) = tools.get(&tool_call.name) {
|
||
let input_json =
|
||
serde_json::to_string(&tool_call.input).unwrap_or_default();
|
||
match tool.execute(&input_json).await {
|
||
Ok(content) => ToolResult::success(&tool_call.id, content),
|
||
Err(e) => ToolResult::error(&tool_call.id, e.to_string()),
|
||
}
|
||
} else {
|
||
ToolResult::error(
|
||
&tool_call.id,
|
||
format!("Tool '{}' not found", tool_call.name),
|
||
)
|
||
}
|
||
}
|
||
})
|
||
.collect();
|
||
|
||
let mut results = join_all(futures).await;
|
||
|
||
// Phase 3: after_tool_call フックを適用
|
||
for tool_result in &mut results {
|
||
for hook in &self.hooks {
|
||
let result = hook.after_tool_call(tool_result).await?;
|
||
match result {
|
||
ControlFlow::Continue => {}
|
||
ControlFlow::Skip => break,
|
||
ControlFlow::Abort(reason) => {
|
||
return Err(WorkerError::Aborted(reason));
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
Ok(results)
|
||
}
|
||
|
||
/// 内部で使用するターン実行ロジック
|
||
async fn run_turn_loop(&mut self) -> Result<(), WorkerError> {
|
||
let tool_definitions = self.build_tool_definitions();
|
||
|
||
info!(
|
||
message_count = self.history.len(),
|
||
tool_count = tool_definitions.len(),
|
||
"Starting worker run"
|
||
);
|
||
|
||
loop {
|
||
// ターン開始を通知
|
||
let current_turn = self.turn_count;
|
||
debug!(turn = current_turn, "Turn start");
|
||
for notifier in &self.turn_notifiers {
|
||
notifier.on_turn_start(current_turn);
|
||
}
|
||
|
||
// Hook: on_message_send
|
||
let control = self.run_on_message_send_hooks().await?;
|
||
if let ControlFlow::Abort(reason) = control {
|
||
warn!(reason = %reason, "Aborted by hook");
|
||
// ターン終了を通知(異常終了)
|
||
for notifier in &self.turn_notifiers {
|
||
notifier.on_turn_end(current_turn);
|
||
}
|
||
return Err(WorkerError::Aborted(reason));
|
||
}
|
||
|
||
// リクエスト構築
|
||
let request = self.build_request(&tool_definitions);
|
||
debug!(
|
||
message_count = request.messages.len(),
|
||
tool_count = request.tools.len(),
|
||
has_system = request.system_prompt.is_some(),
|
||
"Sending request to LLM"
|
||
);
|
||
|
||
// ストリーム処理
|
||
debug!("Starting stream...");
|
||
let mut stream = self.client.stream(request).await?;
|
||
let mut event_count = 0;
|
||
while let Some(event_result) = stream.next().await {
|
||
match &event_result {
|
||
Ok(event) => {
|
||
trace!(event = ?event, "Received event");
|
||
event_count += 1;
|
||
}
|
||
Err(e) => {
|
||
warn!(error = %e, "Stream error");
|
||
}
|
||
}
|
||
let event = event_result?;
|
||
self.timeline.dispatch(&event);
|
||
}
|
||
debug!(event_count = event_count, "Stream completed");
|
||
|
||
// ターン終了を通知
|
||
for notifier in &self.turn_notifiers {
|
||
notifier.on_turn_end(current_turn);
|
||
}
|
||
self.turn_count += 1;
|
||
|
||
// 収集結果を取得
|
||
let text_blocks = self.text_block_collector.take_collected();
|
||
let tool_calls = self.tool_call_collector.take_collected();
|
||
|
||
// アシスタントメッセージを履歴に追加
|
||
let assistant_message = self.build_assistant_message(&text_blocks, &tool_calls);
|
||
if let Some(msg) = assistant_message {
|
||
self.history.push(msg);
|
||
}
|
||
|
||
if tool_calls.is_empty() {
|
||
// ツール呼び出しなし → ターン終了判定
|
||
let turn_result = self.run_on_turn_end_hooks().await?;
|
||
match turn_result {
|
||
TurnResult::Finish => {
|
||
return Ok(());
|
||
}
|
||
TurnResult::ContinueWithMessages(additional) => {
|
||
self.history.extend(additional);
|
||
continue;
|
||
}
|
||
}
|
||
}
|
||
|
||
// ツール実行
|
||
let tool_results = self.execute_tools(tool_calls).await?;
|
||
|
||
// ツール結果を履歴に追加
|
||
for result in tool_results {
|
||
self.history
|
||
.push(Message::tool_result(&result.tool_use_id, &result.content));
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// =============================================================================
|
||
// Mutable状態専用の実装
|
||
// =============================================================================
|
||
|
||
impl<C: LlmClient> Worker<C, Mutable> {
|
||
/// 新しいWorkerを作成(Mutable状態)
|
||
pub fn new(client: C) -> Self {
|
||
let text_block_collector = TextBlockCollector::new();
|
||
let tool_call_collector = ToolCallCollector::new();
|
||
let mut timeline = Timeline::new();
|
||
|
||
// コレクターをTimelineに登録
|
||
timeline.on_text_block(text_block_collector.clone());
|
||
timeline.on_tool_use_block(tool_call_collector.clone());
|
||
|
||
Self {
|
||
client,
|
||
timeline,
|
||
text_block_collector,
|
||
tool_call_collector,
|
||
tools: HashMap::new(),
|
||
hooks: Vec::new(),
|
||
system_prompt: None,
|
||
history: Vec::new(),
|
||
locked_prefix_len: 0,
|
||
turn_count: 0,
|
||
turn_notifiers: Vec::new(),
|
||
_state: PhantomData,
|
||
}
|
||
}
|
||
|
||
/// システムプロンプトを設定(ビルダーパターン)
|
||
pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
|
||
self.system_prompt = Some(prompt.into());
|
||
self
|
||
}
|
||
|
||
/// システムプロンプトを設定(可変参照版)
|
||
pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
|
||
self.system_prompt = Some(prompt.into());
|
||
}
|
||
|
||
/// 履歴への可変参照を取得
|
||
///
|
||
/// Mutable状態でのみ利用可能。履歴を自由に編集できる。
|
||
pub fn history_mut(&mut self) -> &mut Vec<Message> {
|
||
&mut self.history
|
||
}
|
||
|
||
/// 履歴を設定
|
||
pub fn set_history(&mut self, messages: Vec<Message>) {
|
||
self.history = messages;
|
||
}
|
||
|
||
/// 履歴にメッセージを追加(ビルダーパターン)
|
||
pub fn with_message(mut self, message: Message) -> Self {
|
||
self.history.push(message);
|
||
self
|
||
}
|
||
|
||
/// 履歴にメッセージを追加
|
||
pub fn push_message(&mut self, message: Message) {
|
||
self.history.push(message);
|
||
}
|
||
|
||
/// 複数のメッセージを履歴に追加(ビルダーパターン)
|
||
pub fn with_messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
|
||
self.history.extend(messages);
|
||
self
|
||
}
|
||
|
||
/// 複数のメッセージを履歴に追加
|
||
pub fn extend_history(&mut self, messages: impl IntoIterator<Item = Message>) {
|
||
self.history.extend(messages);
|
||
}
|
||
|
||
/// 履歴をクリア
|
||
pub fn clear_history(&mut self) {
|
||
self.history.clear();
|
||
}
|
||
|
||
/// 設定を適用(将来の拡張用)
|
||
#[allow(dead_code)]
|
||
pub fn config(self, _config: WorkerConfig) -> Self {
|
||
self
|
||
}
|
||
|
||
/// ロックしてLocked状態へ遷移
|
||
///
|
||
/// この操作により、現在のシステムプロンプトと履歴が「確定済みプレフィックス」として
|
||
/// 固定される。以降は履歴への追記のみが可能となり、キャッシュヒットが保証される。
|
||
pub fn lock(self) -> Worker<C, Locked> {
|
||
let locked_prefix_len = self.history.len();
|
||
Worker {
|
||
client: self.client,
|
||
timeline: self.timeline,
|
||
text_block_collector: self.text_block_collector,
|
||
tool_call_collector: self.tool_call_collector,
|
||
tools: self.tools,
|
||
hooks: self.hooks,
|
||
system_prompt: self.system_prompt,
|
||
history: self.history,
|
||
locked_prefix_len,
|
||
turn_count: self.turn_count,
|
||
turn_notifiers: self.turn_notifiers,
|
||
_state: PhantomData,
|
||
}
|
||
}
|
||
|
||
/// ターンを実行(Mutable状態)
|
||
///
|
||
/// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。
|
||
/// ツール呼び出しがある場合は自動的にループする。
|
||
///
|
||
/// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は
|
||
/// `lock()` を呼んでからLocked状態で `run` を使用すること。
|
||
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<&[Message], WorkerError> {
|
||
self.history.push(Message::user(user_input));
|
||
self.run_turn_loop().await?;
|
||
Ok(&self.history)
|
||
}
|
||
|
||
/// 複数メッセージでターンを実行(Mutable状態)
|
||
///
|
||
/// 指定されたメッセージを履歴に追加してから実行する。
|
||
pub async fn run_with_messages(
|
||
&mut self,
|
||
messages: Vec<Message>,
|
||
) -> Result<&[Message], WorkerError> {
|
||
self.history.extend(messages);
|
||
self.run_turn_loop().await?;
|
||
Ok(&self.history)
|
||
}
|
||
}
|
||
|
||
// =============================================================================
|
||
// Locked状態専用の実装
|
||
// =============================================================================
|
||
|
||
impl<C: LlmClient> Worker<C, Locked> {
|
||
/// ターンを実行(Locked状態)
|
||
///
|
||
/// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。
|
||
/// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。
|
||
pub async fn run(&mut self, user_input: impl Into<String>) -> Result<&[Message], WorkerError> {
|
||
self.history.push(Message::user(user_input));
|
||
self.run_turn_loop().await?;
|
||
Ok(&self.history)
|
||
}
|
||
|
||
/// 複数メッセージでターンを実行(Locked状態)
|
||
pub async fn run_with_messages(
|
||
&mut self,
|
||
messages: Vec<Message>,
|
||
) -> Result<&[Message], WorkerError> {
|
||
self.history.extend(messages);
|
||
self.run_turn_loop().await?;
|
||
Ok(&self.history)
|
||
}
|
||
|
||
/// ロック時点のプレフィックス長を取得
|
||
pub fn locked_prefix_len(&self) -> usize {
|
||
self.locked_prefix_len
|
||
}
|
||
|
||
/// ロックを解除してMutable状態へ戻す
|
||
///
|
||
/// 注意: この操作を行うと、以降のリクエストでキャッシュがヒットしなくなる可能性がある。
|
||
/// 履歴を編集する必要がある場合にのみ使用すること。
|
||
pub fn unlock(self) -> Worker<C, Mutable> {
|
||
Worker {
|
||
client: self.client,
|
||
timeline: self.timeline,
|
||
text_block_collector: self.text_block_collector,
|
||
tool_call_collector: self.tool_call_collector,
|
||
tools: self.tools,
|
||
hooks: self.hooks,
|
||
system_prompt: self.system_prompt,
|
||
history: self.history,
|
||
locked_prefix_len: 0,
|
||
turn_count: self.turn_count,
|
||
turn_notifiers: self.turn_notifiers,
|
||
_state: PhantomData,
|
||
}
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
|
||
mod tests {
|
||
// 基本的なテストのみ。LlmClientを使ったテストは統合テストで行う。
|
||
}
|