Max Turnの実装
This commit is contained in:
parent
60505f206b
commit
0fe05e502e
|
|
@ -320,6 +320,7 @@ impl<C: LlmClient, St: Store> Session<C, St> {
|
||||||
let outcome = match result {
|
let outcome = match result {
|
||||||
Ok(WorkerResult::Finished) => Outcome::Finished,
|
Ok(WorkerResult::Finished) => Outcome::Finished,
|
||||||
Ok(WorkerResult::Paused) => Outcome::Paused,
|
Ok(WorkerResult::Paused) => Outcome::Paused,
|
||||||
|
Ok(WorkerResult::LimitReached) => Outcome::LimitReached,
|
||||||
Err(e) => Outcome::Error {
|
Err(e) => Outcome::Error {
|
||||||
message: e.to_string(),
|
message: e.to_string(),
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -67,6 +67,7 @@ pub enum LogEntry {
|
||||||
pub enum Outcome {
|
pub enum Outcome {
|
||||||
Finished,
|
Finished,
|
||||||
Paused,
|
Paused,
|
||||||
|
LimitReached,
|
||||||
Error { message: String },
|
Error { message: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,3 +20,4 @@ LLM との対話を管理する低レベル基盤クレート。会話履歴、
|
||||||
- `timeline` — イベントストリームのディスパッチ(`Handler` トレイト、各ブロックコレクター)
|
- `timeline` — イベントストリームのディスパッチ(`Handler` トレイト、各ブロックコレクター)
|
||||||
- `event` — ストリーミングイベント型(`Event`, `BlockStart`, `BlockDelta` など)
|
- `event` — ストリーミングイベント型(`Event`, `BlockStart`, `BlockDelta` など)
|
||||||
- `state` — 型状態パターンによるキャッシュ保護(`Mutable` / `CacheLocked`)
|
- `state` — 型状態パターンによるキャッシュ保護(`Mutable` / `CacheLocked`)
|
||||||
|
cratesの整理Add READMEsRE to all crates@@
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
Ok(WorkerResult::Paused) => {
|
Ok(WorkerResult::Paused) => {
|
||||||
println!("⏸️ Task paused");
|
println!("⏸️ Task paused");
|
||||||
}
|
}
|
||||||
|
Ok(WorkerResult::LimitReached) => {
|
||||||
|
println!("🔒 Turn limit reached");
|
||||||
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
println!("❌ Task error: {}", e);
|
println!("❌ Task error: {}", e);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -84,6 +84,8 @@ pub enum WorkerResult {
|
||||||
Finished,
|
Finished,
|
||||||
/// Paused (can be resumed)
|
/// Paused (can be resumed)
|
||||||
Paused,
|
Paused,
|
||||||
|
/// Turn limit reached (max_turns exceeded)
|
||||||
|
LimitReached,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Internal: tool execution result
|
/// Internal: tool execution result
|
||||||
|
|
@ -179,6 +181,8 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
||||||
locked_prefix_len: usize,
|
locked_prefix_len: usize,
|
||||||
/// Turn count
|
/// Turn count
|
||||||
turn_count: usize,
|
turn_count: usize,
|
||||||
|
/// Maximum number of turns (None = unlimited)
|
||||||
|
max_turns: Option<u32>,
|
||||||
/// Turn notification callbacks
|
/// Turn notification callbacks
|
||||||
turn_notifiers: Vec<Box<dyn TurnNotifier>>,
|
turn_notifiers: Vec<Box<dyn TurnNotifier>>,
|
||||||
/// Request configuration (max_tokens, temperature, etc.)
|
/// Request configuration (max_tokens, temperature, etc.)
|
||||||
|
|
@ -1097,6 +1101,15 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
return Err(err);
|
return Err(err);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check turn limit (after assistant items and tool results are in history)
|
||||||
|
if let Some(max) = self.max_turns {
|
||||||
|
if self.turn_count >= max as usize {
|
||||||
|
info!(turn_count = self.turn_count, max_turns = max, "Turn limit reached");
|
||||||
|
self.last_run_interrupted = false;
|
||||||
|
return Ok(WorkerResult::LimitReached);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1137,6 +1150,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
history: Vec::new(),
|
history: Vec::new(),
|
||||||
locked_prefix_len: 0,
|
locked_prefix_len: 0,
|
||||||
turn_count: 0,
|
turn_count: 0,
|
||||||
|
max_turns: None,
|
||||||
turn_notifiers: Vec::new(),
|
turn_notifiers: Vec::new(),
|
||||||
request_config: RequestConfig::default(),
|
request_config: RequestConfig::default(),
|
||||||
last_run_interrupted: false,
|
last_run_interrupted: false,
|
||||||
|
|
@ -1330,6 +1344,11 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
self.turn_count = count;
|
self.turn_count = count;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the maximum number of turns. None means unlimited.
|
||||||
|
pub fn set_max_turns(&mut self, max_turns: Option<u32>) {
|
||||||
|
self.max_turns = max_turns;
|
||||||
|
}
|
||||||
|
|
||||||
/// Set the last_run_interrupted flag (for session restoration)
|
/// Set the last_run_interrupted flag (for session restoration)
|
||||||
pub fn set_last_run_interrupted(&mut self, interrupted: bool) {
|
pub fn set_last_run_interrupted(&mut self, interrupted: bool) {
|
||||||
self.last_run_interrupted = interrupted;
|
self.last_run_interrupted = interrupted;
|
||||||
|
|
@ -1366,6 +1385,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
history: self.history,
|
history: self.history,
|
||||||
locked_prefix_len,
|
locked_prefix_len,
|
||||||
turn_count: self.turn_count,
|
turn_count: self.turn_count,
|
||||||
|
max_turns: self.max_turns,
|
||||||
turn_notifiers: self.turn_notifiers,
|
turn_notifiers: self.turn_notifiers,
|
||||||
request_config: self.request_config,
|
request_config: self.request_config,
|
||||||
last_run_interrupted: self.last_run_interrupted,
|
last_run_interrupted: self.last_run_interrupted,
|
||||||
|
|
@ -1403,6 +1423,7 @@ impl<C: LlmClient> Worker<C, CacheLocked> {
|
||||||
history: self.history,
|
history: self.history,
|
||||||
locked_prefix_len: 0,
|
locked_prefix_len: 0,
|
||||||
turn_count: self.turn_count,
|
turn_count: self.turn_count,
|
||||||
|
max_turns: self.max_turns,
|
||||||
turn_notifiers: self.turn_notifiers,
|
turn_notifiers: self.turn_notifiers,
|
||||||
request_config: self.request_config,
|
request_config: self.request_config,
|
||||||
last_run_interrupted: self.last_run_interrupted,
|
last_run_interrupted: self.last_run_interrupted,
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ mod scope;
|
||||||
|
|
||||||
pub use scope::Scope;
|
pub use scope::Scope;
|
||||||
|
|
||||||
|
use std::num::NonZeroU32;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
@ -56,6 +57,8 @@ pub struct WorkerManifest {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
pub max_turns: Option<NonZeroU32>,
|
||||||
|
#[serde(default)]
|
||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -151,6 +154,55 @@ model = "llama3"
|
||||||
assert!(manifest.provider.api_key_env.is_none());
|
assert!(manifest.provider.api_key_env.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_max_turns() {
|
||||||
|
let toml = r#"
|
||||||
|
[pod]
|
||||||
|
name = "test"
|
||||||
|
|
||||||
|
[provider]
|
||||||
|
kind = "anthropic"
|
||||||
|
model = "claude-sonnet-4-20250514"
|
||||||
|
|
||||||
|
[worker]
|
||||||
|
max_turns = 50
|
||||||
|
"#;
|
||||||
|
let manifest = PodManifest::from_toml(toml).unwrap();
|
||||||
|
assert_eq!(manifest.worker.max_turns.unwrap().get(), 50);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn omitted_max_turns_is_none() {
|
||||||
|
let toml = r#"
|
||||||
|
[pod]
|
||||||
|
name = "test"
|
||||||
|
|
||||||
|
[provider]
|
||||||
|
kind = "anthropic"
|
||||||
|
model = "claude-sonnet-4-20250514"
|
||||||
|
|
||||||
|
[worker]
|
||||||
|
"#;
|
||||||
|
let manifest = PodManifest::from_toml(toml).unwrap();
|
||||||
|
assert!(manifest.worker.max_turns.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn reject_max_turns_zero() {
|
||||||
|
let toml = r#"
|
||||||
|
[pod]
|
||||||
|
name = "test"
|
||||||
|
|
||||||
|
[provider]
|
||||||
|
kind = "anthropic"
|
||||||
|
model = "claude-sonnet-4-20250514"
|
||||||
|
|
||||||
|
[worker]
|
||||||
|
max_turns = 0
|
||||||
|
"#;
|
||||||
|
assert!(PodManifest::from_toml(toml).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn reject_unknown_provider() {
|
fn reject_unknown_provider() {
|
||||||
let toml = r#"
|
let toml = r#"
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
match result {
|
match result {
|
||||||
PodRunResult::Finished => println!("(finished)"),
|
PodRunResult::Finished => println!("(finished)"),
|
||||||
PodRunResult::Paused => println!("(paused)"),
|
PodRunResult::Paused => println!("(paused)"),
|
||||||
|
PodRunResult::LimitReached => println!("(turn limit reached)"),
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. Extract the assistant's reply from history
|
// 5. Extract the assistant's reply from history
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ use llm_worker_persistence::Store;
|
||||||
use tokio::sync::{broadcast, mpsc};
|
use tokio::sync::{broadcast, mpsc};
|
||||||
|
|
||||||
use crate::pod::{Pod, PodRunResult, PodError};
|
use crate::pod::{Pod, PodRunResult, PodError};
|
||||||
use protocol::{ErrorCode, Event, Method, TurnResult};
|
use protocol::{ErrorCode, Event, Method, RunResult, TurnResult};
|
||||||
use crate::runtime_dir::RuntimeDir;
|
use crate::runtime_dir::RuntimeDir;
|
||||||
use crate::shared_state::{PodSharedState, PodStatus};
|
use crate::shared_state::{PodSharedState, PodStatus};
|
||||||
use crate::socket_server::SocketServer;
|
use crate::socket_server::SocketServer;
|
||||||
|
|
@ -193,10 +193,15 @@ where
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
result = &mut pod_future => {
|
result = &mut pod_future => {
|
||||||
return match result {
|
return match result {
|
||||||
Ok(r) => match r {
|
Ok(r) => {
|
||||||
PodRunResult::Finished => PodStatus::Idle,
|
let (status, run_result) = match r {
|
||||||
PodRunResult::Paused => PodStatus::Paused,
|
PodRunResult::Finished => (PodStatus::Idle, RunResult::Finished),
|
||||||
},
|
PodRunResult::Paused => (PodStatus::Paused, RunResult::Paused),
|
||||||
|
PodRunResult::LimitReached => (PodStatus::Idle, RunResult::LimitReached),
|
||||||
|
};
|
||||||
|
let _ = event_tx.send(Event::RunEnd { result: run_result });
|
||||||
|
status
|
||||||
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let code = worker_error_code(&e);
|
let code = worker_error_code(&e);
|
||||||
let _ = event_tx.send(Event::Error {
|
let _ = event_tx.send(Event::Error {
|
||||||
|
|
|
||||||
|
|
@ -142,6 +142,7 @@ pub fn apply_worker_manifest<C: LlmClient>(worker: &mut Worker<C>, wm: &WorkerMa
|
||||||
config.temperature = Some(temperature);
|
config.temperature = Some(temperature);
|
||||||
}
|
}
|
||||||
worker.set_request_config(config);
|
worker.set_request_config(config);
|
||||||
|
worker.set_max_turns(wm.max_turns.map(|n| n.get()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Result of a Pod run.
|
/// Result of a Pod run.
|
||||||
|
|
@ -151,6 +152,8 @@ pub enum PodRunResult {
|
||||||
Finished,
|
Finished,
|
||||||
/// The LLM paused (e.g. awaiting user confirmation via a hook).
|
/// The LLM paused (e.g. awaiting user confirmation via a hook).
|
||||||
Paused,
|
Paused,
|
||||||
|
/// The worker reached its configured max_turns limit.
|
||||||
|
LimitReached,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<llm_worker::WorkerResult> for PodRunResult {
|
impl From<llm_worker::WorkerResult> for PodRunResult {
|
||||||
|
|
@ -158,6 +161,7 @@ impl From<llm_worker::WorkerResult> for PodRunResult {
|
||||||
match r {
|
match r {
|
||||||
llm_worker::WorkerResult::Finished => PodRunResult::Finished,
|
llm_worker::WorkerResult::Finished => PodRunResult::Finished,
|
||||||
llm_worker::WorkerResult::Paused => PodRunResult::Paused,
|
llm_worker::WorkerResult::Paused => PodRunResult::Paused,
|
||||||
|
llm_worker::WorkerResult::LimitReached => PodRunResult::LimitReached,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -60,6 +60,9 @@ pub enum Event {
|
||||||
input_tokens: Option<u64>,
|
input_tokens: Option<u64>,
|
||||||
output_tokens: Option<u64>,
|
output_tokens: Option<u64>,
|
||||||
},
|
},
|
||||||
|
RunEnd {
|
||||||
|
result: RunResult,
|
||||||
|
},
|
||||||
Error {
|
Error {
|
||||||
code: ErrorCode,
|
code: ErrorCode,
|
||||||
message: String,
|
message: String,
|
||||||
|
|
@ -83,6 +86,14 @@ pub enum TurnResult {
|
||||||
Paused,
|
Paused,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum RunResult {
|
||||||
|
Finished,
|
||||||
|
Paused,
|
||||||
|
LimitReached,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub enum ErrorCode {
|
pub enum ErrorCode {
|
||||||
|
|
@ -126,6 +137,17 @@ mod tests {
|
||||||
assert_eq!(parsed["data"]["text"], "Hello");
|
assert_eq!(parsed["data"]["text"], "Hello");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn event_run_end_format() {
|
||||||
|
let event = Event::RunEnd {
|
||||||
|
result: RunResult::LimitReached,
|
||||||
|
};
|
||||||
|
let json = event.to_json_line().unwrap();
|
||||||
|
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
|
||||||
|
assert_eq!(parsed["event"], "run_end");
|
||||||
|
assert_eq!(parsed["data"]["result"], "limit_reached");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn event_error_format() {
|
fn event_error_format() {
|
||||||
let event = Event::Error {
|
let event = Event::Error {
|
||||||
|
|
|
||||||
|
|
@ -131,6 +131,9 @@ impl App {
|
||||||
});
|
});
|
||||||
self.scroll_to_bottom();
|
self.scroll_to_bottom();
|
||||||
}
|
}
|
||||||
|
Event::RunEnd { result } => {
|
||||||
|
self.push_status(format!("[run end] {result:?}"));
|
||||||
|
}
|
||||||
Event::ToolCallArgsDelta { .. } => {}
|
Event::ToolCallArgsDelta { .. } => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user