develop #3
|
|
@ -4,7 +4,7 @@ Workerの非同期キャンセル機構についての設計ドキュメント
|
|||
|
||||
## 概要
|
||||
|
||||
`tokio_util::sync::CancellationToken`を用いて、別タスクからWorkerの実行を安全にキャンセルできる。
|
||||
`tokio::sync::mpsc`の通知チャネルを用いて、別タスクからWorkerの実行を安全にキャンセルできる。
|
||||
|
||||
```rust
|
||||
let worker = Arc::new(Mutex::new(Worker::new(client)));
|
||||
|
|
@ -19,15 +19,6 @@ let handle = tokio::spawn(async move {
|
|||
worker.lock().await.cancel();
|
||||
```
|
||||
|
||||
## キャンセルポイント
|
||||
|
||||
キャンセルは以下のタイミングでチェックされる:
|
||||
|
||||
1. **ターンループ先頭** — `is_cancelled()`で即座にチェック
|
||||
2. **ストリーム開始前** — `client.stream()`呼び出し時
|
||||
3. **ストリーム受信中** — `tokio::select!`で各イベント受信と並行監視
|
||||
4. **ツール実行中** — `join_all()`と並行監視
|
||||
|
||||
## キャンセル時の処理フロー
|
||||
|
||||
```
|
||||
|
|
@ -42,11 +33,10 @@ Err(WorkerError::Cancelled) // エラー返却
|
|||
|
||||
## API
|
||||
|
||||
| メソッド | 説明 |
|
||||
| ---------------------- | --------------------------------------------------------- |
|
||||
| `cancel()` | キャンセルをトリガー |
|
||||
| `is_cancelled()` | キャンセル状態を確認 |
|
||||
| `cancellation_token()` | トークンへの参照を取得(`clone()`してタスク間で共有可能) |
|
||||
| メソッド | 説明 |
|
||||
| ----------------- | ------------------------------ |
|
||||
| `cancel()` | キャンセルをトリガー |
|
||||
| `cancel_sender()` | キャンセル通知用のSenderを取得 |
|
||||
|
||||
## on_abort フック
|
||||
|
||||
|
|
@ -69,22 +59,12 @@ async fn on_abort(&self, reason: &str) -> Result<(), HookError> {
|
|||
|
||||
## 既知の問題
|
||||
|
||||
### 1. キャンセルトークンの再利用不可
|
||||
### on_abort の発火基準
|
||||
|
||||
`CancellationToken`は一度キャンセルされると永続的にキャンセル状態になる。
|
||||
同じWorkerインスタンスで再度`run()`を呼ぶと即座に`Cancelled`エラーになる。
|
||||
`on_abort` は **interrupt(中断)** された場合に必ず発火する。
|
||||
|
||||
**対応案:**
|
||||
interrupt の例:
|
||||
|
||||
- `run()`開始時に新しいトークンを生成する
|
||||
- `reset_cancellation()`メソッドを提供する
|
||||
|
||||
### 2. Sync バウンドの追加(破壊的変更)
|
||||
|
||||
`tokio::select!`使用のため、Handler/Scope型に`Sync`バウンドを追加した。
|
||||
既存のユーザーコードで`Sync`未実装の型を使用している場合、コンパイルエラーになる。
|
||||
|
||||
### 3. エラー時のon_abort呼び出し
|
||||
|
||||
現在、`on_abort`はキャンセルとフックAbort時のみ呼ばれる。
|
||||
ストリームエラー等のその他エラー時には呼ばれないため、一貫性に欠ける可能性がある。
|
||||
- `WorkerError::Cancelled`(キャンセル)
|
||||
- `WorkerError::Aborted`(フックによるAbort)
|
||||
- ストリーム/ツール/クライアント/Hook の各種エラーで処理が中断された場合
|
||||
|
|
|
|||
|
|
@ -30,10 +30,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
println!("🚀 Starting Worker...");
|
||||
println!("💡 Will cancel after 2 seconds\n");
|
||||
|
||||
// キャンセルトークンを先に取得(ロックを保持しない)
|
||||
let cancel_token = {
|
||||
// キャンセルSenderを先に取得(ロックを保持しない)
|
||||
let cancel_tx = {
|
||||
let w = worker.lock().await;
|
||||
w.cancellation_token().clone()
|
||||
w.cancel_sender()
|
||||
};
|
||||
|
||||
// タスク1: Workerを実行
|
||||
|
|
@ -43,10 +43,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
println!("📡 Sending request to LLM...");
|
||||
|
||||
match w.run("Tell me a very long story about a brave knight. Make it as detailed as possible with many paragraphs.").await {
|
||||
Ok(WorkerResult::Finished(_)) => {
|
||||
Ok(WorkerResult::Finished) => {
|
||||
println!("✅ Task completed normally");
|
||||
}
|
||||
Ok(WorkerResult::Paused(_)) => {
|
||||
Ok(WorkerResult::Paused) => {
|
||||
println!("⏸️ Task paused");
|
||||
}
|
||||
Err(e) => {
|
||||
|
|
@ -59,7 +59,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
tokio::spawn(async move {
|
||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||
println!("\n🛑 Cancelling worker...");
|
||||
cancel_token.cancel();
|
||||
let _ = cancel_tx.send(()).await;
|
||||
});
|
||||
|
||||
// タスク完了を待つ
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ use std::marker::PhantomData;
|
|||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use futures::StreamExt;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, info, trace, warn};
|
||||
|
||||
use crate::{
|
||||
|
|
@ -78,11 +78,11 @@ pub struct WorkerConfig {
|
|||
|
||||
/// Workerの実行結果(ステータス)
|
||||
#[derive(Debug)]
|
||||
pub enum WorkerResult<'a> {
|
||||
pub enum WorkerResult {
|
||||
/// 完了(ユーザー入力待ち状態)
|
||||
Finished(&'a [Message]),
|
||||
Finished,
|
||||
/// 一時停止(再開可能)
|
||||
Paused(&'a [Message]),
|
||||
Paused,
|
||||
}
|
||||
|
||||
/// 内部用: ツール実行結果
|
||||
|
|
@ -182,8 +182,11 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
|||
turn_notifiers: Vec<Box<dyn TurnNotifier>>,
|
||||
/// リクエスト設定(max_tokens, temperature等)
|
||||
request_config: RequestConfig,
|
||||
/// キャンセレーショントークン(実行中断用)
|
||||
cancellation_token: CancellationToken,
|
||||
/// 前回の実行が中断されたかどうか
|
||||
last_run_interrupted: bool,
|
||||
/// キャンセル通知用チャネル(実行中断用)
|
||||
cancel_tx: mpsc::Sender<()>,
|
||||
cancel_rx: mpsc::Receiver<()>,
|
||||
/// 状態マーカー
|
||||
_state: PhantomData<S>,
|
||||
}
|
||||
|
|
@ -193,6 +196,57 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
|||
// =============================================================================
|
||||
|
||||
impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||
fn reset_interruption_state(&mut self) {
|
||||
self.last_run_interrupted = false;
|
||||
}
|
||||
|
||||
/// ターンを実行
|
||||
///
|
||||
/// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。
|
||||
/// ツール呼び出しがある場合は自動的にループする。
|
||||
pub async fn run(
|
||||
&mut self,
|
||||
user_input: impl Into<String>,
|
||||
) -> Result<WorkerResult, WorkerError> {
|
||||
self.reset_interruption_state();
|
||||
// Hook: on_prompt_submit
|
||||
let mut user_message = Message::user(user_input);
|
||||
let result = self.run_on_prompt_submit_hooks(&mut user_message).await;
|
||||
let result = match result {
|
||||
Ok(value) => value,
|
||||
Err(err) => return self.finalize_interruption(Err(err)).await,
|
||||
};
|
||||
match result {
|
||||
OnPromptSubmitResult::Cancel(reason) => {
|
||||
self.last_run_interrupted = true;
|
||||
return self.finalize_interruption(Err(WorkerError::Aborted(reason))).await;
|
||||
}
|
||||
OnPromptSubmitResult::Continue => {}
|
||||
}
|
||||
self.history.push(user_message);
|
||||
let result = self.run_turn_loop().await;
|
||||
self.finalize_interruption(result).await
|
||||
}
|
||||
|
||||
fn drain_cancel_queue(&mut self) {
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
loop {
|
||||
match self.cancel_rx.try_recv() {
|
||||
Ok(()) => continue,
|
||||
Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_cancelled(&mut self) -> bool {
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
match self.cancel_rx.try_recv() {
|
||||
Ok(()) => true,
|
||||
Err(TryRecvError::Empty) => false,
|
||||
Err(TryRecvError::Disconnected) => true,
|
||||
}
|
||||
}
|
||||
|
||||
/// イベント購読者を登録する
|
||||
///
|
||||
/// 登録したSubscriberは、LLMからのストリーミングイベントを
|
||||
|
|
@ -410,6 +464,11 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
self.request_config.stop_sequences.clear();
|
||||
}
|
||||
|
||||
/// キャンセル通知用Senderを取得する
|
||||
pub fn cancel_sender(&self) -> mpsc::Sender<()> {
|
||||
self.cancel_tx.clone()
|
||||
}
|
||||
|
||||
/// リクエスト設定を一括で設定
|
||||
pub fn set_request_config(&mut self, config: RequestConfig) {
|
||||
self.request_config = config;
|
||||
|
|
@ -437,17 +496,17 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
/// worker.lock().unwrap().cancel();
|
||||
/// ```
|
||||
pub fn cancel(&self) {
|
||||
self.cancellation_token.cancel();
|
||||
let _ = self.cancel_tx.try_send(());
|
||||
}
|
||||
|
||||
/// キャンセルされているかチェック
|
||||
pub fn is_cancelled(&self) -> bool {
|
||||
self.cancellation_token.is_cancelled()
|
||||
pub fn is_cancelled(&mut self) -> bool {
|
||||
self.try_cancelled()
|
||||
}
|
||||
|
||||
/// キャンセレーショントークンへの参照を取得
|
||||
pub fn cancellation_token(&self) -> &CancellationToken {
|
||||
&self.cancellation_token
|
||||
/// 前回の実行が中断されたかどうか
|
||||
pub fn last_run_interrupted(&self) -> bool {
|
||||
self.last_run_interrupted
|
||||
}
|
||||
|
||||
/// 登録されたツールからLLM用ToolDefinitionのリストを生成
|
||||
|
|
@ -636,6 +695,28 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn finalize_interruption<T>(
|
||||
&mut self,
|
||||
result: Result<T, WorkerError>,
|
||||
) -> Result<T, WorkerError> {
|
||||
match result {
|
||||
Ok(value) => Ok(value),
|
||||
Err(err) => {
|
||||
self.last_run_interrupted = true;
|
||||
let reason = match &err {
|
||||
WorkerError::Aborted(reason) => reason.clone(),
|
||||
WorkerError::Cancelled => "Cancelled".to_string(),
|
||||
_ => err.to_string(),
|
||||
};
|
||||
if let Err(hook_err) = self.run_on_abort_hooks(&reason).await {
|
||||
self.last_run_interrupted = true;
|
||||
return Err(hook_err);
|
||||
}
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 未実行のツール呼び出しがあるかチェック(Pauseからの復帰用)
|
||||
fn get_pending_tool_calls(&self) -> Option<Vec<ToolCall>> {
|
||||
let last_msg = self.history.last()?;
|
||||
|
|
@ -687,7 +768,10 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
|
||||
let mut skip = false;
|
||||
for hook in &self.hooks.pre_tool_call {
|
||||
let result = hook.call(&mut context).await?;
|
||||
let result = hook
|
||||
.call(&mut context)
|
||||
.await
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||
match result {
|
||||
PreToolCallResult::Continue => {}
|
||||
PreToolCallResult::Skip => {
|
||||
|
|
@ -695,9 +779,11 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
break;
|
||||
}
|
||||
PreToolCallResult::Abort(reason) => {
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
PreToolCallResult::Pause => {
|
||||
self.last_run_interrupted = true;
|
||||
return Ok(ToolExecutionResult::Paused);
|
||||
}
|
||||
}
|
||||
|
|
@ -747,10 +833,12 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
// ツール実行をキャンセル可能にする
|
||||
let mut results = tokio::select! {
|
||||
results = join_all(futures) => results,
|
||||
_ = self.cancellation_token.cancelled() => {
|
||||
info!("Tool execution cancelled");
|
||||
cancel = self.cancel_rx.recv() => {
|
||||
if cancel.is_some() {
|
||||
info!("Tool execution cancelled");
|
||||
}
|
||||
self.timeline.abort_current_block();
|
||||
self.run_on_abort_hooks("Cancelled").await?;
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Cancelled);
|
||||
}
|
||||
};
|
||||
|
|
@ -767,10 +855,14 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
};
|
||||
|
||||
for hook in &self.hooks.post_tool_call {
|
||||
let result = hook.call(&mut context).await?;
|
||||
let result = hook
|
||||
.call(&mut context)
|
||||
.await
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||
match result {
|
||||
PostToolCallResult::Continue => {}
|
||||
PostToolCallResult::Abort(reason) => {
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
}
|
||||
|
|
@ -784,7 +876,9 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
|
||||
/// 内部で使用するターン実行ロジック
|
||||
async fn run_turn_loop(&mut self) -> Result<WorkerResult<'_>, WorkerError> {
|
||||
async fn run_turn_loop(&mut self) -> Result<WorkerResult, WorkerError> {
|
||||
self.reset_interruption_state();
|
||||
self.drain_cancel_queue();
|
||||
let tool_definitions = self.build_tool_definitions();
|
||||
|
||||
info!(
|
||||
|
|
@ -796,24 +890,31 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
// Resume check: Pending tool calls
|
||||
if let Some(tool_calls) = self.get_pending_tool_calls() {
|
||||
info!("Resuming pending tool calls");
|
||||
match self.execute_tools(tool_calls).await? {
|
||||
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)),
|
||||
ToolExecutionResult::Completed(results) => {
|
||||
match self.execute_tools(tool_calls).await {
|
||||
Ok(ToolExecutionResult::Paused) => {
|
||||
self.last_run_interrupted = true;
|
||||
return Ok(WorkerResult::Paused);
|
||||
}
|
||||
Ok(ToolExecutionResult::Completed(results)) => {
|
||||
for result in results {
|
||||
self.history
|
||||
.push(Message::tool_result(&result.tool_use_id, &result.content));
|
||||
}
|
||||
// Continue to loop
|
||||
}
|
||||
Err(err) => {
|
||||
self.last_run_interrupted = true;
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
// キャンセルチェック
|
||||
if self.cancellation_token.is_cancelled() {
|
||||
if self.try_cancelled() {
|
||||
info!("Execution cancelled");
|
||||
self.timeline.abort_current_block();
|
||||
self.run_on_abort_hooks("Cancelled").await?;
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Cancelled);
|
||||
}
|
||||
|
||||
|
|
@ -825,14 +926,17 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
|
||||
// Hook: pre_llm_request
|
||||
let (control, request_context) = self.run_pre_llm_request_hooks().await?;
|
||||
let (control, request_context) = self
|
||||
.run_pre_llm_request_hooks()
|
||||
.await
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||
match control {
|
||||
PreLlmRequestResult::Cancel(reason) => {
|
||||
info!(reason = %reason, "Aborted by hook");
|
||||
for notifier in &self.turn_notifiers {
|
||||
notifier.on_turn_end(current_turn);
|
||||
}
|
||||
self.run_on_abort_hooks(&reason).await?;
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
PreLlmRequestResult::Continue => {}
|
||||
|
|
@ -853,11 +957,14 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
|
||||
// ストリームを取得(キャンセル可能)
|
||||
let mut stream = tokio::select! {
|
||||
stream_result = self.client.stream(request) => stream_result?,
|
||||
_ = self.cancellation_token.cancelled() => {
|
||||
info!("Cancelled before stream started");
|
||||
stream_result = self.client.stream(request) => stream_result
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?,
|
||||
cancel = self.cancel_rx.recv() => {
|
||||
if cancel.is_some() {
|
||||
info!("Cancelled before stream started");
|
||||
}
|
||||
self.timeline.abort_current_block();
|
||||
self.run_on_abort_hooks("Cancelled").await?;
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Cancelled);
|
||||
}
|
||||
};
|
||||
|
|
@ -877,7 +984,8 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
warn!(error = %e, "Stream error");
|
||||
}
|
||||
}
|
||||
let event = result?;
|
||||
let event = result
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||
let timeline_event: crate::timeline::event::Event = event.into();
|
||||
self.timeline.dispatch(&timeline_event);
|
||||
}
|
||||
|
|
@ -885,10 +993,12 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
}
|
||||
}
|
||||
// キャンセル待機
|
||||
_ = self.cancellation_token.cancelled() => {
|
||||
info!("Stream cancelled");
|
||||
cancel = self.cancel_rx.recv() => {
|
||||
if cancel.is_some() {
|
||||
info!("Stream cancelled");
|
||||
}
|
||||
self.timeline.abort_current_block();
|
||||
self.run_on_abort_hooks("Cancelled").await?;
|
||||
self.last_run_interrupted = true;
|
||||
return Err(WorkerError::Cancelled);
|
||||
}
|
||||
}
|
||||
|
|
@ -913,30 +1023,42 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
|
||||
if tool_calls.is_empty() {
|
||||
// ツール呼び出しなし → ターン終了判定
|
||||
let turn_result = self.run_on_turn_end_hooks().await?;
|
||||
let turn_result = self
|
||||
.run_on_turn_end_hooks()
|
||||
.await
|
||||
.inspect_err(|_| self.last_run_interrupted = true)?;
|
||||
match turn_result {
|
||||
OnTurnEndResult::Finish => {
|
||||
return Ok(WorkerResult::Finished(&self.history));
|
||||
self.last_run_interrupted = false;
|
||||
return Ok(WorkerResult::Finished);
|
||||
}
|
||||
OnTurnEndResult::ContinueWithMessages(additional) => {
|
||||
self.history.extend(additional);
|
||||
continue;
|
||||
}
|
||||
OnTurnEndResult::Paused => {
|
||||
return Ok(WorkerResult::Paused(&self.history));
|
||||
self.last_run_interrupted = true;
|
||||
return Ok(WorkerResult::Paused);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ツール実行
|
||||
match self.execute_tools(tool_calls).await? {
|
||||
ToolExecutionResult::Paused => return Ok(WorkerResult::Paused(&self.history)),
|
||||
ToolExecutionResult::Completed(results) => {
|
||||
match self.execute_tools(tool_calls).await {
|
||||
Ok(ToolExecutionResult::Paused) => {
|
||||
self.last_run_interrupted = true;
|
||||
return Ok(WorkerResult::Paused);
|
||||
}
|
||||
Ok(ToolExecutionResult::Completed(results)) => {
|
||||
for result in results {
|
||||
self.history
|
||||
.push(Message::tool_result(&result.tool_use_id, &result.content));
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
self.last_run_interrupted = true;
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -944,8 +1066,10 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
/// 実行を再開(Pause状態からの復帰)
|
||||
///
|
||||
/// 新しいユーザーメッセージを履歴に追加せず、現在の状態からターン処理を再開する。
|
||||
pub async fn resume(&mut self) -> Result<WorkerResult<'_>, WorkerError> {
|
||||
self.run_turn_loop().await
|
||||
pub async fn resume(&mut self) -> Result<WorkerResult, WorkerError> {
|
||||
self.reset_interruption_state();
|
||||
let result = self.run_turn_loop().await;
|
||||
self.finalize_interruption(result).await
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -959,6 +1083,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
let text_block_collector = TextBlockCollector::new();
|
||||
let tool_call_collector = ToolCallCollector::new();
|
||||
let mut timeline = Timeline::new();
|
||||
let (cancel_tx, cancel_rx) = mpsc::channel(1);
|
||||
|
||||
// コレクターをTimelineに登録
|
||||
timeline.on_text_block(text_block_collector.clone());
|
||||
|
|
@ -977,7 +1102,9 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
turn_count: 0,
|
||||
turn_notifiers: Vec::new(),
|
||||
request_config: RequestConfig::default(),
|
||||
cancellation_token: CancellationToken::new(),
|
||||
last_run_interrupted: false,
|
||||
cancel_tx,
|
||||
cancel_rx,
|
||||
_state: PhantomData,
|
||||
}
|
||||
}
|
||||
|
|
@ -1146,46 +1273,13 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
turn_count: self.turn_count,
|
||||
turn_notifiers: self.turn_notifiers,
|
||||
request_config: self.request_config,
|
||||
cancellation_token: self.cancellation_token,
|
||||
last_run_interrupted: self.last_run_interrupted,
|
||||
cancel_tx: self.cancel_tx,
|
||||
cancel_rx: self.cancel_rx,
|
||||
_state: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// ターンを実行(Mutable状態)
|
||||
///
|
||||
/// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。
|
||||
/// ツール呼び出しがある場合は自動的にループする。
|
||||
///
|
||||
/// 注意: この関数は履歴を変更するため、キャッシュ保護が必要な場合は
|
||||
/// `lock()` を呼んでからLocked状態で `run` を使用すること。
|
||||
pub async fn run(
|
||||
&mut self,
|
||||
user_input: impl Into<String>,
|
||||
) -> Result<WorkerResult<'_>, WorkerError> {
|
||||
// Hook: on_prompt_submit
|
||||
let mut user_message = Message::user(user_input);
|
||||
let result = self.run_on_prompt_submit_hooks(&mut user_message).await?;
|
||||
match result {
|
||||
OnPromptSubmitResult::Cancel(reason) => {
|
||||
self.run_on_abort_hooks(&reason).await?;
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
OnPromptSubmitResult::Continue => {}
|
||||
}
|
||||
self.history.push(user_message);
|
||||
self.run_turn_loop().await
|
||||
}
|
||||
|
||||
/// 複数メッセージでターンを実行(Mutable状態)
|
||||
///
|
||||
/// 指定されたメッセージを履歴に追加してから実行する。
|
||||
pub async fn run_with_messages(
|
||||
&mut self,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<WorkerResult<'_>, WorkerError> {
|
||||
self.history.extend(messages);
|
||||
self.run_turn_loop().await
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
|
@ -1193,37 +1287,6 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
// =============================================================================
|
||||
|
||||
impl<C: LlmClient> Worker<C, Locked> {
|
||||
/// ターンを実行(Locked状態)
|
||||
///
|
||||
/// 新しいユーザーメッセージを履歴の末尾に追加し、LLMにリクエストを送信する。
|
||||
/// ロック時点より前の履歴(プレフィックス)は不変であるため、キャッシュヒットが保証される。
|
||||
pub async fn run(
|
||||
&mut self,
|
||||
user_input: impl Into<String>,
|
||||
) -> Result<WorkerResult<'_>, WorkerError> {
|
||||
// Hook: on_prompt_submit
|
||||
let mut user_message = Message::user(user_input);
|
||||
let result = self.run_on_prompt_submit_hooks(&mut user_message).await?;
|
||||
match result {
|
||||
OnPromptSubmitResult::Cancel(reason) => {
|
||||
self.run_on_abort_hooks(&reason).await?;
|
||||
return Err(WorkerError::Aborted(reason));
|
||||
}
|
||||
OnPromptSubmitResult::Continue => {}
|
||||
}
|
||||
self.history.push(user_message);
|
||||
self.run_turn_loop().await
|
||||
}
|
||||
|
||||
/// 複数メッセージでターンを実行(Locked状態)
|
||||
pub async fn run_with_messages(
|
||||
&mut self,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<WorkerResult<'_>, WorkerError> {
|
||||
self.history.extend(messages);
|
||||
self.run_turn_loop().await
|
||||
}
|
||||
|
||||
/// ロック時点のプレフィックス長を取得
|
||||
pub fn locked_prefix_len(&self) -> usize {
|
||||
self.locked_prefix_len
|
||||
|
|
@ -1247,7 +1310,9 @@ impl<C: LlmClient> Worker<C, Locked> {
|
|||
turn_count: self.turn_count,
|
||||
turn_notifiers: self.turn_notifiers,
|
||||
request_config: self.request_config,
|
||||
cancellation_token: self.cancellation_token,
|
||||
last_run_interrupted: self.last_run_interrupted,
|
||||
cancel_tx: self.cancel_tx,
|
||||
cancel_rx: self.cancel_rx,
|
||||
_state: PhantomData,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ use std::sync::atomic::{AtomicUsize, Ordering};
|
|||
use schemars;
|
||||
use serde;
|
||||
|
||||
use llm_worker::tool::{Tool, ToolMeta};
|
||||
use llm_worker_macros::tool_registry;
|
||||
|
||||
// =============================================================================
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user