Max Turnの実装

This commit is contained in:
Keisuke Hirata 2026-04-11 03:16:36 +09:00
parent 60505f206b
commit 0fe05e502e
11 changed files with 119 additions and 5 deletions

View File

@ -320,6 +320,7 @@ impl<C: LlmClient, St: Store> Session<C, St> {
let outcome = match result {
Ok(WorkerResult::Finished) => Outcome::Finished,
Ok(WorkerResult::Paused) => Outcome::Paused,
Ok(WorkerResult::LimitReached) => Outcome::LimitReached,
Err(e) => Outcome::Error {
message: e.to_string(),
},

View File

@ -67,6 +67,7 @@ pub enum LogEntry {
pub enum Outcome {
Finished,
Paused,
LimitReached,
Error { message: String },
}

View File

@ -20,3 +20,4 @@ LLM との対話を管理する低レベル基盤クレート。会話履歴、
- `timeline` — イベントストリームのディスパッチ(`Handler` トレイト、各ブロックコレクター)
- `event` — ストリーミングイベント型(`Event`, `BlockStart`, `BlockDelta` など)
- `state` — 型状態パターンによるキャッシュ保護(`Mutable` / `CacheLocked`
cratesの整理Add READMEsRE to all crates@@

View File

@ -49,6 +49,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(WorkerResult::Paused) => {
println!("⏸️ Task paused");
}
Ok(WorkerResult::LimitReached) => {
println!("🔒 Turn limit reached");
}
Err(e) => {
println!("❌ Task error: {}", e);
}

View File

@ -84,6 +84,8 @@ pub enum WorkerResult {
Finished,
/// Paused (can be resumed)
Paused,
/// Turn limit reached (max_turns exceeded)
LimitReached,
}
/// Internal: tool execution result
@ -179,6 +181,8 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
locked_prefix_len: usize,
/// Turn count
turn_count: usize,
/// Maximum number of turns (None = unlimited)
max_turns: Option<u32>,
/// Turn notification callbacks
turn_notifiers: Vec<Box<dyn TurnNotifier>>,
/// Request configuration (max_tokens, temperature, etc.)
@ -1097,6 +1101,15 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
return Err(err);
}
}
// Check turn limit (after assistant items and tool results are in history)
if let Some(max) = self.max_turns {
if self.turn_count >= max as usize {
info!(turn_count = self.turn_count, max_turns = max, "Turn limit reached");
self.last_run_interrupted = false;
return Ok(WorkerResult::LimitReached);
}
}
}
}
@ -1137,6 +1150,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
history: Vec::new(),
locked_prefix_len: 0,
turn_count: 0,
max_turns: None,
turn_notifiers: Vec::new(),
request_config: RequestConfig::default(),
last_run_interrupted: false,
@ -1330,6 +1344,11 @@ impl<C: LlmClient> Worker<C, Mutable> {
self.turn_count = count;
}
/// Set the maximum number of turns. None means unlimited.
pub fn set_max_turns(&mut self, max_turns: Option<u32>) {
self.max_turns = max_turns;
}
/// Set the last_run_interrupted flag (for session restoration)
pub fn set_last_run_interrupted(&mut self, interrupted: bool) {
self.last_run_interrupted = interrupted;
@ -1366,6 +1385,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
history: self.history,
locked_prefix_len,
turn_count: self.turn_count,
max_turns: self.max_turns,
turn_notifiers: self.turn_notifiers,
request_config: self.request_config,
last_run_interrupted: self.last_run_interrupted,
@ -1403,6 +1423,7 @@ impl<C: LlmClient> Worker<C, CacheLocked> {
history: self.history,
locked_prefix_len: 0,
turn_count: self.turn_count,
max_turns: self.max_turns,
turn_notifiers: self.turn_notifiers,
request_config: self.request_config,
last_run_interrupted: self.last_run_interrupted,

View File

@ -2,6 +2,7 @@ mod scope;
pub use scope::Scope;
use std::num::NonZeroU32;
use std::path::PathBuf;
use serde::{Deserialize, Serialize};
@ -56,6 +57,8 @@ pub struct WorkerManifest {
#[serde(default)]
pub max_tokens: Option<u32>,
#[serde(default)]
pub max_turns: Option<NonZeroU32>,
#[serde(default)]
pub temperature: Option<f32>,
}
@ -151,6 +154,55 @@ model = "llama3"
assert!(manifest.provider.api_key_env.is_none());
}
#[test]
fn parse_max_turns() {
let toml = r#"
[pod]
name = "test"
[provider]
kind = "anthropic"
model = "claude-sonnet-4-20250514"
[worker]
max_turns = 50
"#;
let manifest = PodManifest::from_toml(toml).unwrap();
assert_eq!(manifest.worker.max_turns.unwrap().get(), 50);
}
#[test]
fn omitted_max_turns_is_none() {
let toml = r#"
[pod]
name = "test"
[provider]
kind = "anthropic"
model = "claude-sonnet-4-20250514"
[worker]
"#;
let manifest = PodManifest::from_toml(toml).unwrap();
assert!(manifest.worker.max_turns.is_none());
}
#[test]
fn reject_max_turns_zero() {
let toml = r#"
[pod]
name = "test"
[provider]
kind = "anthropic"
model = "claude-sonnet-4-20250514"
[worker]
max_turns = 0
"#;
assert!(PodManifest::from_toml(toml).is_err());
}
#[test]
fn reject_unknown_provider() {
let toml = r#"

View File

@ -49,6 +49,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
match result {
PodRunResult::Finished => println!("(finished)"),
PodRunResult::Paused => println!("(paused)"),
PodRunResult::LimitReached => println!("(turn limit reached)"),
}
// 5. Extract the assistant's reply from history

View File

@ -11,7 +11,7 @@ use llm_worker_persistence::Store;
use tokio::sync::{broadcast, mpsc};
use crate::pod::{Pod, PodRunResult, PodError};
use protocol::{ErrorCode, Event, Method, TurnResult};
use protocol::{ErrorCode, Event, Method, RunResult, TurnResult};
use crate::runtime_dir::RuntimeDir;
use crate::shared_state::{PodSharedState, PodStatus};
use crate::socket_server::SocketServer;
@ -193,10 +193,15 @@ where
tokio::select! {
result = &mut pod_future => {
return match result {
Ok(r) => match r {
PodRunResult::Finished => PodStatus::Idle,
PodRunResult::Paused => PodStatus::Paused,
},
Ok(r) => {
let (status, run_result) = match r {
PodRunResult::Finished => (PodStatus::Idle, RunResult::Finished),
PodRunResult::Paused => (PodStatus::Paused, RunResult::Paused),
PodRunResult::LimitReached => (PodStatus::Idle, RunResult::LimitReached),
};
let _ = event_tx.send(Event::RunEnd { result: run_result });
status
}
Err(e) => {
let code = worker_error_code(&e);
let _ = event_tx.send(Event::Error {

View File

@ -142,6 +142,7 @@ pub fn apply_worker_manifest<C: LlmClient>(worker: &mut Worker<C>, wm: &WorkerMa
config.temperature = Some(temperature);
}
worker.set_request_config(config);
worker.set_max_turns(wm.max_turns.map(|n| n.get()));
}
/// Result of a Pod run.
@ -151,6 +152,8 @@ pub enum PodRunResult {
Finished,
/// The LLM paused (e.g. awaiting user confirmation via a hook).
Paused,
/// The worker reached its configured max_turns limit.
LimitReached,
}
impl From<llm_worker::WorkerResult> for PodRunResult {
@ -158,6 +161,7 @@ impl From<llm_worker::WorkerResult> for PodRunResult {
match r {
llm_worker::WorkerResult::Finished => PodRunResult::Finished,
llm_worker::WorkerResult::Paused => PodRunResult::Paused,
llm_worker::WorkerResult::LimitReached => PodRunResult::LimitReached,
}
}
}

View File

@ -60,6 +60,9 @@ pub enum Event {
input_tokens: Option<u64>,
output_tokens: Option<u64>,
},
RunEnd {
result: RunResult,
},
Error {
code: ErrorCode,
message: String,
@ -83,6 +86,14 @@ pub enum TurnResult {
Paused,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RunResult {
Finished,
Paused,
LimitReached,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ErrorCode {
@ -126,6 +137,17 @@ mod tests {
assert_eq!(parsed["data"]["text"], "Hello");
}
#[test]
fn event_run_end_format() {
let event = Event::RunEnd {
result: RunResult::LimitReached,
};
let json = event.to_json_line().unwrap();
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(parsed["event"], "run_end");
assert_eq!(parsed["data"]["result"], "limit_reached");
}
#[test]
fn event_error_format() {
let event = Event::Error {

View File

@ -131,6 +131,9 @@ impl App {
});
self.scroll_to_bottom();
}
Event::RunEnd { result } => {
self.push_status(format!("[run end] {result:?}"));
}
Event::ToolCallArgsDelta { .. } => {}
}
}