update: Implement cancellation notification using mpsc::channel

This commit is contained in:
Keisuke Hirata 2026-01-10 22:45:01 +09:00
parent 16fda38039
commit a2f53d7879
4 changed files with 192 additions and 148 deletions

View File

@ -4,7 +4,7 @@ Workerの非同期キャンセル機構についての設計ドキュメント
## 概要 ## 概要
`tokio_util::sync::CancellationToken`を用いて、別タスクからWorkerの実行を安全にキャンセルできる。 `tokio::sync::mpsc`の通知チャネルを用いて、別タスクからWorkerの実行を安全にキャンセルできる。
```rust ```rust
let worker = Arc::new(Mutex::new(Worker::new(client))); let worker = Arc::new(Mutex::new(Worker::new(client)));
@ -19,15 +19,6 @@ let handle = tokio::spawn(async move {
worker.lock().await.cancel(); worker.lock().await.cancel();
``` ```
## キャンセルポイント
キャンセルは以下のタイミングでチェックされる:
1. **ターンループ先頭**`is_cancelled()`で即座にチェック
2. **ストリーム開始前**`client.stream()`呼び出し時
3. **ストリーム受信中**`tokio::select!`で各イベント受信と並行監視
4. **ツール実行中**`join_all()`と並行監視
## キャンセル時の処理フロー ## キャンセル時の処理フロー
``` ```
@ -42,11 +33,10 @@ Err(WorkerError::Cancelled) // エラー返却
## API ## API
| メソッド | 説明 | | メソッド | 説明 |
| ---------------------- | --------------------------------------------------------- | | ----------------- | ------------------------------ |
| `cancel()` | キャンセルをトリガー | | `cancel()` | キャンセルをトリガー |
| `is_cancelled()` | キャンセル状態を確認 | | `cancel_sender()` | キャンセル通知用のSenderを取得 |
| `cancellation_token()` | トークンへの参照を取得(`clone()`してタスク間で共有可能) |
## on_abort フック ## on_abort フック
@ -69,22 +59,12 @@ async fn on_abort(&self, reason: &str) -> Result<(), HookError> {
## 既知の問題 ## 既知の問題
### 1. キャンセルトークンの再利用不可 ### on_abort の発火基準
`CancellationToken`は一度キャンセルされると永続的にキャンセル状態になる。 `on_abort`**interrupt中断** された場合に必ず発火する。
同じWorkerインスタンスで再度`run()`を呼ぶと即座に`Cancelled`エラーになる。
**対応案:** interrupt の例:
- `run()`開始時に新しいトークンを生成する - `WorkerError::Cancelled`(キャンセル)
- `reset_cancellation()`メソッドを提供する - `WorkerError::Aborted`フックによるAbort
- ストリーム/ツール/クライアント/Hook の各種エラーで処理が中断された場合
### 2. Sync バウンドの追加(破壊的変更)
`tokio::select!`使用のため、Handler/Scope型に`Sync`バウンドを追加した。
既存のユーザーコードで`Sync`未実装の型を使用している場合、コンパイルエラーになる。
### 3. エラー時のon_abort呼び出し
現在、`on_abort`はキャンセルとフックAbort時のみ呼ばれる。
ストリームエラー等のその他エラー時には呼ばれないため、一貫性に欠ける可能性がある。

View File

@ -30,10 +30,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("🚀 Starting Worker..."); println!("🚀 Starting Worker...");
println!("💡 Will cancel after 2 seconds\n"); println!("💡 Will cancel after 2 seconds\n");
// キャンセルトークンを先に取得(ロックを保持しない) // キャンセルSenderを先に取得(ロックを保持しない)
let cancel_token = { let cancel_tx = {
let w = worker.lock().await; let w = worker.lock().await;
w.cancellation_token().clone() w.cancel_sender()
}; };
// タスク1: Workerを実行 // タスク1: Workerを実行
@ -43,10 +43,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("📡 Sending request to LLM..."); println!("📡 Sending request to LLM...");
match w.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await { match w.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await {
Ok(WorkerResult::Finished(_)) => { Ok(WorkerResult::Finished) => {
println!("✅ Task completed normally"); println!("✅ Task completed normally");
} }
Ok(WorkerResult::Paused(_)) => { Ok(WorkerResult::Paused) => {
println!("⏸️ Task paused"); println!("⏸️ Task paused");
} }
Err(e) => { Err(e) => {
@ -59,7 +59,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
tokio::spawn(async move { tokio::spawn(async move {
tokio::time::sleep(Duration::from_secs(2)).await; tokio::time::sleep(Duration::from_secs(2)).await;
println!("\n🛑 Cancelling worker..."); println!("\n🛑 Cancelling worker...");
cancel_token.cancel(); let _ = cancel_tx.send(()).await;
}); });
// タスク完了を待つ // タスク完了を待つ

View File

@ -3,7 +3,7 @@ use std::marker::PhantomData;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use futures::StreamExt; use futures::StreamExt;
use tokio_util::sync::CancellationToken; use tokio::sync::mpsc;
use tracing::{debug, info, trace, warn}; use tracing::{debug, info, trace, warn};
use crate::{ use crate::{
@ -78,11 +78,11 @@ pub struct WorkerConfig {
/// Workerの実行結果ステータス /// Workerの実行結果ステータス
#[derive(Debug)] #[derive(Debug)]
pub enum WorkerResult<'a> { pub enum WorkerResult {
/// 完了(ユーザー入力待ち状態) /// 完了(ユーザー入力待ち状態)
Finished(&'a [Message]), Finished,
/// 一時停止(再開可能) /// 一時停止(再開可能)
Paused(&'a [Message]), Paused,
} }
/// 内部用: ツール実行結果 /// 内部用: ツール実行結果
@ -182,8 +182,11 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
turn_notifiers: Vec<Box<dyn TurnNotifier>>, turn_notifiers: Vec<Box<dyn TurnNotifier>>,
/// リクエスト設定max_tokens, temperature等 /// リクエスト設定max_tokens, temperature等
request_config: RequestConfig, request_config: RequestConfig,
/// キャンセレーショントークン(実行中断用) /// 前回の実行が中断されたかどうか
cancellation_token: CancellationToken, last_run_interrupted: bool,
/// キャンセル通知用チャネル(実行中断用)
cancel_tx: mpsc::Sender<()>,
cancel_rx: mpsc::Receiver<()>,
/// 状態マーカー /// 状態マーカー
_state: PhantomData<S>, _state: PhantomData<S>,
} }
@ -193,6 +196,57 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
// ============================================================================= // =============================================================================
impl<C: LlmClient, S: WorkerState> Worker<C, S> { impl<C: LlmClient, S: WorkerState> Worker<C, S> {
fn reset_interruption_state(&mut self) {
self.last_run_interrupted = false;
}
/// ターンを実行
///
/// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。
/// ツール呼び出しがある場合は自動的にループする。
pub async fn run(
&mut self,
user_input: impl Into<String>,
) -> Result<WorkerResult, WorkerError> {
self.reset_interruption_state();
// Hook: on_prompt_submit
let mut user_message = Message::user(user_input);
let result = self.run_on_prompt_submit_hooks(&mut user_message).await;
let result = match result {
Ok(value) => value,
Err(err) => return self.finalize_interruption(Err(err)).await,
};
match result {
OnPromptSubmitResult::Cancel(reason) => {
self.last_run_interrupted = true;
return self.finalize_interruption(Err(WorkerError::Aborted(reason))).await;
}
OnPromptSubmitResult::Continue => {}
}
self.history.push(user_message);
let result = self.run_turn_loop().await;
self.finalize_interruption(result).await
}
fn drain_cancel_queue(&mut self) {
use tokio::sync::mpsc::error::TryRecvError;
loop {
match self.cancel_rx.try_recv() {
Ok(()) => continue,
Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
}
}
}
fn try_cancelled(&mut self) -> bool {
use tokio::sync::mpsc::error::TryRecvError;
match self.cancel_rx.try_recv() {
Ok(()) => true,
Err(TryRecvError::Empty) => false,
Err(TryRecvError::Disconnected) => true,
}
}
/// イベント購読者を登録する /// イベント購読者を登録する
/// ///
/// 登録したSubscriberは、LLMからのストリーミングイベントを /// 登録したSubscriberは、LLMからのストリーミングイベントを
@ -410,6 +464,11 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
self.request_config.stop_sequences.clear(); self.request_config.stop_sequences.clear();
} }
/// キャンセル通知用Senderを取得する
pub fn cancel_sender(&self) -> mpsc::Sender<()> {
self.cancel_tx.clone()
}
/// リクエスト設定を一括で設定 /// リクエスト設定を一括で設定
pub fn set_request_config(&mut self, config: RequestConfig) { pub fn set_request_config(&mut self, config: RequestConfig) {
self.request_config = config; self.request_config = config;
@ -437,17 +496,17 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
/// worker.lock().unwrap().cancel(); /// worker.lock().unwrap().cancel();
/// ``` /// ```
pub fn cancel(&self) { pub fn cancel(&self) {
self.cancellation_token.cancel(); let _ = self.cancel_tx.try_send(());
} }
/// キャンセルされているかチェック /// キャンセルされているかチェック
pub fn is_cancelled(&self) -> bool { pub fn is_cancelled(&mut self) -> bool {
self.cancellation_token.is_cancelled() self.try_cancelled()
} }
/// キャンセレーショントークンへの参照を取得 /// 前回の実行が中断されたかどうか
pub fn cancellation_token(&self) -> &CancellationToken { pub fn last_run_interrupted(&self) -> bool {
&self.cancellation_token self.last_run_interrupted
} }
/// 登録されたツールからLLM用ToolDefinitionのリストを生成 /// 登録されたツールからLLM用ToolDefinitionのリストを生成
@ -636,6 +695,28 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
Ok(()) Ok(())
} }
async fn finalize_interruption<T>(
&mut self,
result: Result<T, WorkerError>,
) -> Result<T, WorkerError> {
match result {
Ok(value) => Ok(value),
Err(err) => {
self.last_run_interrupted = true;
let reason = match &err {
WorkerError::Aborted(reason) => reason.clone(),
WorkerError::Cancelled => "Cancelled".to_string(),
_ => err.to_string(),
};
if let Err(hook_err) = self.run_on_abort_hooks(&reason).await {
self.last_run_interrupted = true;
return Err(hook_err);
}
Err(err)
}
}
}
/// 未実行のツール呼び出しがあるかチェックPauseからの復帰用 /// 未実行のツール呼び出しがあるかチェックPauseからの復帰用
fn get_pending_tool_calls(&self) -> Option<Vec<ToolCall>> { fn get_pending_tool_calls(&self) -> Option<Vec<ToolCall>> {
let last_msg = self.history.last()?; let last_msg = self.history.last()?;
@ -687,7 +768,10 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
let mut skip = false; let mut skip = false;
for hook in &self.hooks.pre_tool_call { for hook in &self.hooks.pre_tool_call {
let result = hook.call(&mut context).await?; let result = hook
.call(&mut context)
.await
.inspect_err(|_| self.last_run_interrupted = true)?;
match result { match result {
PreToolCallResult::Continue => {} PreToolCallResult::Continue => {}
PreToolCallResult::Skip => { PreToolCallResult::Skip => {
@ -695,9 +779,11 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
break; break;
} }
PreToolCallResult::Abort(reason) => { PreToolCallResult::Abort(reason) => {
self.last_run_interrupted = true;
return Err(WorkerError::Aborted(reason)); return Err(WorkerError::Aborted(reason));
} }
PreToolCallResult::Pause => { PreToolCallResult::Pause => {
self.last_run_interrupted = true;
return Ok(ToolExecutionResult::Paused); return Ok(ToolExecutionResult::Paused);
} }
} }
@ -747,10 +833,12 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
// ツール実行をキャンセル可能にする // ツール実行をキャンセル可能にする
let mut results = tokio::select! { let mut results = tokio::select! {
results = join_all(futures) => results, results = join_all(futures) => results,
_ = self.cancellation_token.cancelled() => { cancel = self.cancel_rx.recv() => {
info!("Tool execution cancelled"); if cancel.is_some() {
info!("Tool execution cancelled");
}
self.timeline.abort_current_block(); self.timeline.abort_current_block();
self.run_on_abort_hooks("Cancelled").await?; self.last_run_interrupted = true;
return Err(WorkerError::Cancelled); return Err(WorkerError::Cancelled);
} }
}; };
@ -767,10 +855,14 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
}; };
for hook in &self.hooks.post_tool_call { for hook in &self.hooks.post_tool_call {
let result = hook.call(&mut context).await?; let result = hook
.call(&mut context)
.await
.inspect_err(|_| self.last_run_interrupted = true)?;
match result { match result {
PostToolCallResult::Continue => {} PostToolCallResult::Continue => {}
PostToolCallResult::Abort(reason) => { PostToolCallResult::Abort(reason) => {
self.last_run_interrupted = true;
return Err(WorkerError::Aborted(reason)); return Err(WorkerError::Aborted(reason));
} }
} }
@ -784,7 +876,9 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
} }
/// 内部で使用するターン実行ロジック /// 内部で使用するターン実行ロジック
async fn run_turn_loop(&mut self) -> Result<WorkerResult<'_>, WorkerError> { async fn run_turn_loop(&mut self) -> Result<WorkerResult, WorkerError> {
self.reset_interruption_state();
self.drain_cancel_queue();
let tool_definitions = self.build_tool_definitions(); let tool_definitions = self.build_tool_definitions();
info!( info!(
@ -796,24 +890,31 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
// Resume check: Pending tool calls // Resume check: Pending tool calls
if let Some(tool_calls) = self.get_pending_tool_calls() { if let Some(tool_calls) = self.get_pending_tool_calls() {
info!("Resuming pending tool calls"); info!("Resuming pending tool calls");
match self.execute_tools(tool_calls).await? { match self.execute_tools(tool_calls).await {
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)), Ok(ToolExecutionResult::Paused) => {
ToolExecutionResult::Completed(results) => { self.last_run_interrupted = true;
return Ok(WorkerResult::Paused);
}
Ok(ToolExecutionResult::Completed(results)) => {
for result in results { for result in results {
self.history self.history
.push(Message::tool_result(&result.tool_use_id, &result.content)); .push(Message::tool_result(&result.tool_use_id, &result.content));
} }
// Continue to loop // Continue to loop
} }
Err(err) => {
self.last_run_interrupted = true;
return Err(err);
}
} }
} }
loop { loop {
// キャンセルチェック // キャンセルチェック
if self.cancellation_token.is_cancelled() { if self.try_cancelled() {
info!("Execution cancelled"); info!("Execution cancelled");
self.timeline.abort_current_block(); self.timeline.abort_current_block();
self.run_on_abort_hooks("Cancelled").await?; self.last_run_interrupted = true;
return Err(WorkerError::Cancelled); return Err(WorkerError::Cancelled);
} }
@ -825,14 +926,17 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
} }
// Hook: pre_llm_request // Hook: pre_llm_request
let (control, request_context) = self.run_pre_llm_request_hooks().await?; let (control, request_context) = self
.run_pre_llm_request_hooks()
.await
.inspect_err(|_| self.last_run_interrupted = true)?;
match control { match control {
PreLlmRequestResult::Cancel(reason) => { PreLlmRequestResult::Cancel(reason) => {
info!(reason = %reason, "Aborted by hook"); info!(reason = %reason, "Aborted by hook");
for notifier in &self.turn_notifiers { for notifier in &self.turn_notifiers {
notifier.on_turn_end(current_turn); notifier.on_turn_end(current_turn);
} }
self.run_on_abort_hooks(&reason).await?; self.last_run_interrupted = true;
return Err(WorkerError::Aborted(reason)); return Err(WorkerError::Aborted(reason));
} }
PreLlmRequestResult::Continue => {} PreLlmRequestResult::Continue => {}
@ -853,11 +957,14 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
// ストリームを取得(キャンセル可能) // ストリームを取得(キャンセル可能)
let mut stream = tokio::select! { let mut stream = tokio::select! {
stream_result = self.client.stream(request) => stream_result?, stream_result = self.client.stream(request) => stream_result
_ = self.cancellation_token.cancelled() => { .inspect_err(|_| self.last_run_interrupted = true)?,
info!("Cancelled before stream started"); cancel = self.cancel_rx.recv() => {
if cancel.is_some() {
info!("Cancelled before stream started");
}
self.timeline.abort_current_block(); self.timeline.abort_current_block();
self.run_on_abort_hooks("Cancelled").await?; self.last_run_interrupted = true;
return Err(WorkerError::Cancelled); return Err(WorkerError::Cancelled);
} }
}; };
@ -877,7 +984,8 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
warn!(error = %e, "Stream error"); warn!(error = %e, "Stream error");
} }
} }
let event = result?; let event = result
.inspect_err(|_| self.last_run_interrupted = true)?;
let timeline_event: crate::timeline::event::Event = event.into(); let timeline_event: crate::timeline::event::Event = event.into();
self.timeline.dispatch(&timeline_event); self.timeline.dispatch(&timeline_event);
} }
@ -885,10 +993,12 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
} }
} }
// キャンセル待機 // キャンセル待機
_ = self.cancellation_token.cancelled() => { cancel = self.cancel_rx.recv() => {
info!("Stream cancelled"); if cancel.is_some() {
info!("Stream cancelled");
}
self.timeline.abort_current_block(); self.timeline.abort_current_block();
self.run_on_abort_hooks("Cancelled").await?; self.last_run_interrupted = true;
return Err(WorkerError::Cancelled); return Err(WorkerError::Cancelled);
} }
} }
@ -913,30 +1023,42 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
if tool_calls.is_empty() { if tool_calls.is_empty() {
// ツール呼び出しなし → ターン終了判定 // ツール呼び出しなし → ターン終了判定
let turn_result = self.run_on_turn_end_hooks().await?; let turn_result = self
.run_on_turn_end_hooks()
.await
.inspect_err(|_| self.last_run_interrupted = true)?;
match turn_result { match turn_result {
OnTurnEndResult::Finish => { OnTurnEndResult::Finish => {
return Ok(WorkerResult::Finished(&self.history)); self.last_run_interrupted = false;
return Ok(WorkerResult::Finished);
} }
OnTurnEndResult::ContinueWithMessages(additional) => { OnTurnEndResult::ContinueWithMessages(additional) => {
self.history.extend(additional); self.history.extend(additional);
continue; continue;
} }
OnTurnEndResult::Paused => { OnTurnEndResult::Paused => {
return Ok(WorkerResult::Paused(&self.history)); self.last_run_interrupted = true;
return Ok(WorkerResult::Paused);
} }
} }
} }
// ツール実行 // ツール実行
match self.execute_tools(tool_calls).await? { match self.execute_tools(tool_calls).await {
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)), Ok(ToolExecutionResult::Paused) => {
ToolExecutionResult::Completed(results) => { self.last_run_interrupted = true;
return Ok(WorkerResult::Paused);
}
Ok(ToolExecutionResult::Completed(results)) => {
for result in results { for result in results {
self.history self.history
.push(Message::tool_result(&result.tool_use_id, &result.content)); .push(Message::tool_result(&result.tool_use_id, &result.content));
} }
} }
Err(err) => {
self.last_run_interrupted = true;
return Err(err);
}
} }
} }
} }
@ -944,8 +1066,10 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
/// 実行を再開Pause状態からの復帰 /// 実行を再開Pause状態からの復帰
/// ///
/// 新しいユーザーメッセージを履歴に追加せず、現在の状態からターン処理を再開する。 /// 新しいユーザーメッセージを履歴に追加せず、現在の状態からターン処理を再開する。
pub async fn resume(&mut self) -> Result<WorkerResult<'_>, WorkerError> { pub async fn resume(&mut self) -> Result<WorkerResult, WorkerError> {
self.run_turn_loop().await self.reset_interruption_state();
let result = self.run_turn_loop().await;
self.finalize_interruption(result).await
} }
} }
@ -959,6 +1083,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
let text_block_collector = TextBlockCollector::new(); let text_block_collector = TextBlockCollector::new();
let tool_call_collector = ToolCallCollector::new(); let tool_call_collector = ToolCallCollector::new();
let mut timeline = Timeline::new(); let mut timeline = Timeline::new();
let (cancel_tx, cancel_rx) = mpsc::channel(1);
// コレクターをTimelineに登録 // コレクターをTimelineに登録
timeline.on_text_block(text_block_collector.clone()); timeline.on_text_block(text_block_collector.clone());
@ -977,7 +1102,9 @@ impl<C: LlmClient> Worker<C, Mutable> {
turn_count: 0, turn_count: 0,
turn_notifiers: Vec::new(), turn_notifiers: Vec::new(),
request_config: RequestConfig::default(), request_config: RequestConfig::default(),
cancellation_token: CancellationToken::new(), last_run_interrupted: false,
cancel_tx,
cancel_rx,
_state: PhantomData, _state: PhantomData,
} }
} }
@ -1146,46 +1273,13 @@ impl<C: LlmClient> Worker<C, Mutable> {
turn_count: self.turn_count, turn_count: self.turn_count,
turn_notifiers: self.turn_notifiers, turn_notifiers: self.turn_notifiers,
request_config: self.request_config, request_config: self.request_config,
cancellation_token: self.cancellation_token, last_run_interrupted: self.last_run_interrupted,
cancel_tx: self.cancel_tx,
cancel_rx: self.cancel_rx,
_state: PhantomData, _state: PhantomData,
} }
} }
/// ターンを実行Mutable状態
///
/// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。
/// ツール呼び出しがある場合は自動的にループする。
///
/// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は
/// `lock()` を呼んでからLocked状態で `run` を使用すること。
pub async fn run(
&mut self,
user_input: impl Into<String>,
) -> Result<WorkerResult<'_>, WorkerError> {
// Hook: on_prompt_submit
let mut user_message = Message::user(user_input);
let result = self.run_on_prompt_submit_hooks(&mut user_message).await?;
match result {
OnPromptSubmitResult::Cancel(reason) => {
self.run_on_abort_hooks(&reason).await?;
return Err(WorkerError::Aborted(reason));
}
OnPromptSubmitResult::Continue => {}
}
self.history.push(user_message);
self.run_turn_loop().await
}
/// 複数メッセージでターンを実行Mutable状態
///
/// 指定されたメッセージを履歴に追加してから実行する。
pub async fn run_with_messages(
&mut self,
messages: Vec<Message>,
) -> Result<WorkerResult<'_>, WorkerError> {
self.history.extend(messages);
self.run_turn_loop().await
}
} }
// ============================================================================= // =============================================================================
@ -1193,37 +1287,6 @@ impl<C: LlmClient> Worker<C, Mutable> {
// ============================================================================= // =============================================================================
impl<C: LlmClient> Worker<C, Locked> { impl<C: LlmClient> Worker<C, Locked> {
/// ターンを実行Locked状態
///
/// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。
/// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。
pub async fn run(
&mut self,
user_input: impl Into<String>,
) -> Result<WorkerResult<'_>, WorkerError> {
// Hook: on_prompt_submit
let mut user_message = Message::user(user_input);
let result = self.run_on_prompt_submit_hooks(&mut user_message).await?;
match result {
OnPromptSubmitResult::Cancel(reason) => {
self.run_on_abort_hooks(&reason).await?;
return Err(WorkerError::Aborted(reason));
}
OnPromptSubmitResult::Continue => {}
}
self.history.push(user_message);
self.run_turn_loop().await
}
/// 複数メッセージでターンを実行Locked状態
pub async fn run_with_messages(
&mut self,
messages: Vec<Message>,
) -> Result<WorkerResult<'_>, WorkerError> {
self.history.extend(messages);
self.run_turn_loop().await
}
/// ロック時点のプレフィックス長を取得 /// ロック時点のプレフィックス長を取得
pub fn locked_prefix_len(&self) -> usize { pub fn locked_prefix_len(&self) -> usize {
self.locked_prefix_len self.locked_prefix_len
@ -1247,7 +1310,9 @@ impl<C: LlmClient> Worker<C, Locked> {
turn_count: self.turn_count, turn_count: self.turn_count,
turn_notifiers: self.turn_notifiers, turn_notifiers: self.turn_notifiers,
request_config: self.request_config, request_config: self.request_config,
cancellation_token: self.cancellation_token, last_run_interrupted: self.last_run_interrupted,
cancel_tx: self.cancel_tx,
cancel_rx: self.cancel_rx,
_state: PhantomData, _state: PhantomData,
} }
} }

View File

@ -9,7 +9,6 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use schemars; use schemars;
use serde; use serde;
use llm_worker::tool::{Tool, ToolMeta};
use llm_worker_macros::tool_registry; use llm_worker_macros::tool_registry;
// ============================================================================= // =============================================================================