From 81107c6f5ce68a1bbe24a7ab8651f90e826bfe3c Mon Sep 17 00:00:00 2001 From: Hare Date: Fri, 9 Jan 2026 01:19:16 +0900 Subject: [PATCH] feat: Implement RequestConfig validation --- llm-worker/src/lib.rs | 10 +- llm-worker/src/llm_client/client.rs | 49 ++++- llm-worker/src/llm_client/providers/openai.rs | 14 +- .../llm_client/scheme/anthropic/request.rs | 3 + .../src/llm_client/scheme/gemini/request.rs | 4 + llm-worker/src/llm_client/types.rs | 65 ++++++- llm-worker/src/worker.rs | 179 +++++++++++++++++- llm-worker/tests/validation_test.rs | 40 ++++ 8 files changed, 355 insertions(+), 9 deletions(-) create mode 100644 llm-worker/tests/validation_test.rs diff --git a/llm-worker/src/lib.rs b/llm-worker/src/lib.rs index 550eb13..5668503 100644 --- a/llm-worker/src/lib.rs +++ b/llm-worker/src/lib.rs @@ -36,17 +36,17 @@ //! locked.run("user input").await?; //! ``` -pub mod llm_client; -pub mod timeline; +mod handler; +mod message; +mod worker; pub mod event; -mod handler; pub mod hook; -mod message; +pub mod llm_client; pub mod state; pub mod subscriber; +pub mod timeline; pub mod tool; -mod worker; // ============================================================================= // トップレベル公開(最も頻繁に使う型) diff --git a/llm-worker/src/llm_client/client.rs b/llm-worker/src/llm_client/client.rs index 77d9ea3..a148da9 100644 --- a/llm-worker/src/llm_client/client.rs +++ b/llm-worker/src/llm_client/client.rs @@ -2,10 +2,40 @@ use std::pin::Pin; -use crate::llm_client::{ClientError, Request, event::Event}; +use crate::llm_client::{ClientError, Request, RequestConfig, event::Event}; use async_trait::async_trait; use futures::Stream; +/// 設定に関する警告 +/// +/// プロバイダがサポートしていない設定を使用した場合に返される。 +#[derive(Debug, Clone)] +pub struct ConfigWarning { + /// 設定オプション名 + pub option_name: &'static str, + /// 警告メッセージ + pub message: String, +} + +impl ConfigWarning { + /// 新しい警告を作成 + pub fn unsupported(option_name: &'static str, provider_name: &str) -> Self { + Self { + option_name, + message: format!( + "'{}' is not supported by {} and will be ignored", + option_name, provider_name + ), + } + } +} + +impl std::fmt::Display for ConfigWarning { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}: {}", self.option_name, self.message) + } +} + /// LLMクライアントのtrait /// /// 各プロバイダはこのtraitを実装し、統一されたインターフェースを提供する。 @@ -23,6 +53,19 @@ pub trait LlmClient: Send + Sync { &self, request: Request, ) -> Result> + Send>>, ClientError>; + + /// 設定をバリデーションし、未サポートの設定があれば警告を返す + /// + /// # Arguments + /// * `config` - バリデーション対象の設定 + /// + /// # Returns + /// サポートされていない設定に対する警告のリスト + fn validate_config(&self, config: &RequestConfig) -> Vec { + // デフォルト実装: 全ての設定をサポート + let _ = config; + Vec::new() + } } /// `Box` に対する `LlmClient` の実装 @@ -36,4 +79,8 @@ impl LlmClient for Box { ) -> Result> + Send>>, ClientError> { (**self).stream(request).await } + + fn validate_config(&self, config: &RequestConfig) -> Vec { + (**self).validate_config(config) + } } diff --git a/llm-worker/src/llm_client/providers/openai.rs b/llm-worker/src/llm_client/providers/openai.rs index eec1f19..bf27d0c 100644 --- a/llm-worker/src/llm_client/providers/openai.rs +++ b/llm-worker/src/llm_client/providers/openai.rs @@ -5,7 +5,8 @@ use std::pin::Pin; use crate::llm_client::{ - ClientError, LlmClient, Request, event::Event, scheme::openai::OpenAIScheme, + ClientError, ConfigWarning, LlmClient, Request, RequestConfig, event::Event, + scheme::openai::OpenAIScheme, }; use async_trait::async_trait; use eventsource_stream::Eventsource; @@ -197,4 +198,15 @@ impl LlmClient for OpenAIClient { Ok(Box::pin(stream)) } + + fn validate_config(&self, config: &RequestConfig) -> Vec { + let mut warnings = Vec::new(); + + // OpenAI does not support top_k + if config.top_k.is_some() { + warnings.push(ConfigWarning::unsupported("top_k", "OpenAI")); + } + + warnings + } } diff --git a/llm-worker/src/llm_client/scheme/anthropic/request.rs b/llm-worker/src/llm_client/scheme/anthropic/request.rs index c554257..39e48f5 100644 --- a/llm-worker/src/llm_client/scheme/anthropic/request.rs +++ b/llm-worker/src/llm_client/scheme/anthropic/request.rs @@ -23,6 +23,8 @@ pub(crate) struct AnthropicRequest { pub temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, #[serde(skip_serializing_if = "Vec::is_empty")] pub stop_sequences: Vec, pub stream: bool, @@ -90,6 +92,7 @@ impl AnthropicScheme { tools, temperature: request.config.temperature, top_p: request.config.top_p, + top_k: request.config.top_k, stop_sequences: request.config.stop_sequences.clone(), stream: true, } diff --git a/llm-worker/src/llm_client/scheme/gemini/request.rs b/llm-worker/src/llm_client/scheme/gemini/request.rs index 6785ea8..6c2febb 100644 --- a/llm-worker/src/llm_client/scheme/gemini/request.rs +++ b/llm-worker/src/llm_client/scheme/gemini/request.rs @@ -133,6 +133,9 @@ pub(crate) struct GeminiGenerationConfig { /// Top P #[serde(skip_serializing_if = "Option::is_none")] pub top_p: Option, + /// Top K + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, /// ストップシーケンス #[serde(skip_serializing_if = "Vec::is_empty")] pub stop_sequences: Vec, @@ -183,6 +186,7 @@ impl GeminiScheme { max_output_tokens: request.config.max_tokens, temperature: request.config.temperature, top_p: request.config.top_p, + top_k: request.config.top_k, stop_sequences: request.config.stop_sequences.clone(), }); diff --git a/llm-worker/src/llm_client/types.rs b/llm-worker/src/llm_client/types.rs index ae71c57..749c5d2 100644 --- a/llm-worker/src/llm_client/types.rs +++ b/llm-worker/src/llm_client/types.rs @@ -62,6 +62,30 @@ impl Request { self.config.max_tokens = Some(max_tokens); self } + + /// temperatureを設定 + pub fn temperature(mut self, temperature: f32) -> Self { + self.config.temperature = Some(temperature); + self + } + + /// top_pを設定 + pub fn top_p(mut self, top_p: f32) -> Self { + self.config.top_p = Some(top_p); + self + } + + /// top_kを設定 + pub fn top_k(mut self, top_k: u32) -> Self { + self.config.top_k = Some(top_k); + self + } + + /// ストップシーケンスを追加 + pub fn stop_sequence(mut self, sequence: impl Into) -> Self { + self.config.stop_sequences.push(sequence.into()); + self + } } /// メッセージ @@ -191,8 +215,47 @@ pub struct RequestConfig { pub max_tokens: Option, /// Temperature pub temperature: Option, - /// Top P + /// Top P (nucleus sampling) pub top_p: Option, + /// Top K + pub top_k: Option, /// ストップシーケンス pub stop_sequences: Vec, } + +impl RequestConfig { + /// 新しいデフォルト設定を作成 + pub fn new() -> Self { + Self::default() + } + + /// 最大トークン数を設定 + pub fn with_max_tokens(mut self, max_tokens: u32) -> Self { + self.max_tokens = Some(max_tokens); + self + } + + /// temperatureを設定 + pub fn with_temperature(mut self, temperature: f32) -> Self { + self.temperature = Some(temperature); + self + } + + /// top_pを設定 + pub fn with_top_p(mut self, top_p: f32) -> Self { + self.top_p = Some(top_p); + self + } + + /// top_kを設定 + pub fn with_top_k(mut self, top_k: u32) -> Self { + self.top_k = Some(top_k); + self + } + + /// ストップシーケンスを追加 + pub fn with_stop_sequence(mut self, sequence: impl Into) -> Self { + self.stop_sequences.push(sequence.into()); + self + } +} diff --git a/llm-worker/src/worker.rs b/llm-worker/src/worker.rs index 8d898fb..3b87429 100644 --- a/llm-worker/src/worker.rs +++ b/llm-worker/src/worker.rs @@ -8,7 +8,7 @@ use tracing::{debug, info, trace, warn}; use crate::{ ContentPart, Message, MessageContent, Role, hook::{ControlFlow, HookError, ToolCall, ToolResult, TurnResult, WorkerHook}, - llm_client::{ClientError, LlmClient, Request, ToolDefinition}, + llm_client::{ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ToolDefinition}, state::{Locked, Mutable, WorkerState}, subscriber::{ ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter, @@ -37,6 +37,9 @@ pub enum WorkerError { /// 処理が中断された #[error("Aborted: {0}")] Aborted(String), + /// 設定に関する警告(未サポートのオプション) + #[error("Config warnings: {}", .0.iter().map(|w| w.to_string()).collect::>().join(", "))] + ConfigWarnings(Vec), } // ============================================================================= @@ -139,6 +142,8 @@ pub struct Worker { turn_count: usize, /// ターン通知用のコールバック turn_notifiers: Vec>, + /// リクエスト設定(max_tokens, temperature等) + request_config: RequestConfig, /// 状態マーカー _state: PhantomData, } @@ -274,6 +279,83 @@ impl Worker { self.turn_count } + /// 現在のリクエスト設定への参照を取得 + pub fn request_config(&self) -> &RequestConfig { + &self.request_config + } + + /// 最大トークン数を設定 + /// + /// この設定はキャッシュロックとは独立しており、各リクエストに適用されます。 + /// + /// # Examples + /// + /// ```ignore + /// worker.set_max_tokens(4096); + /// ``` + pub fn set_max_tokens(&mut self, max_tokens: u32) { + self.request_config.max_tokens = Some(max_tokens); + } + + /// temperatureを設定 + /// + /// 0.0から1.0(または2.0)の範囲で設定します。 + /// 低い値はより決定的な出力を、高い値はより多様な出力を生成します。 + /// + /// # Examples + /// + /// ```ignore + /// worker.set_temperature(0.7); + /// ``` + pub fn set_temperature(&mut self, temperature: f32) { + self.request_config.temperature = Some(temperature); + } + + /// top_pを設定(nucleus sampling) + /// + /// # Examples + /// + /// ```ignore + /// worker.set_top_p(0.9); + /// ``` + pub fn set_top_p(&mut self, top_p: f32) { + self.request_config.top_p = Some(top_p); + } + + /// top_kを設定 + /// + /// トークン選択時に考慮する上位k個のトークンを指定します。 + /// + /// # Examples + /// + /// ```ignore + /// worker.set_top_k(40); + /// ``` + pub fn set_top_k(&mut self, top_k: u32) { + self.request_config.top_k = Some(top_k); + } + + /// ストップシーケンスを追加 + /// + /// # Examples + /// + /// ```ignore + /// worker.add_stop_sequence("\n\n"); + /// ``` + pub fn add_stop_sequence(&mut self, sequence: impl Into) { + self.request_config.stop_sequences.push(sequence.into()); + } + + /// ストップシーケンスをクリア + pub fn clear_stop_sequences(&mut self) { + self.request_config.stop_sequences.clear(); + } + + /// リクエスト設定を一括で設定 + pub fn set_request_config(&mut self, config: RequestConfig) { + self.request_config = config; + } + /// 登録されたツールからToolDefinitionのリストを生成 fn build_tool_definitions(&self) -> Vec { self.tools @@ -387,6 +469,9 @@ impl Worker { request = request.tool(tool_def.clone()); } + // リクエスト設定を適用 + request = request.config(self.request_config.clone()); + request } @@ -622,6 +707,7 @@ impl Worker { locked_prefix_len: 0, turn_count: 0, turn_notifiers: Vec::new(), + request_config: RequestConfig::default(), _state: PhantomData, } } @@ -637,6 +723,95 @@ impl Worker { self.system_prompt = Some(prompt.into()); } + /// 最大トークン数を設定(ビルダーパターン) + /// + /// # Examples + /// + /// ```ignore + /// let worker = Worker::new(client) + /// .system_prompt("You are a helpful assistant.") + /// .max_tokens(4096); + /// ``` + pub fn max_tokens(mut self, max_tokens: u32) -> Self { + self.request_config.max_tokens = Some(max_tokens); + self + } + + /// temperatureを設定(ビルダーパターン) + /// + /// # Examples + /// + /// ```ignore + /// let worker = Worker::new(client) + /// .temperature(0.7); + /// ``` + pub fn temperature(mut self, temperature: f32) -> Self { + self.request_config.temperature = Some(temperature); + self + } + + /// top_pを設定(ビルダーパターン) + pub fn top_p(mut self, top_p: f32) -> Self { + self.request_config.top_p = Some(top_p); + self + } + + /// top_kを設定(ビルダーパターン) + pub fn top_k(mut self, top_k: u32) -> Self { + self.request_config.top_k = Some(top_k); + self + } + + /// ストップシーケンスを追加(ビルダーパターン) + pub fn stop_sequence(mut self, sequence: impl Into) -> Self { + self.request_config.stop_sequences.push(sequence.into()); + self + } + + /// リクエスト設定をまとめて設定(ビルダーパターン) + /// + /// # Examples + /// + /// ```ignore + /// let config = RequestConfig::new() + /// .with_max_tokens(4096) + /// .with_temperature(0.7); + /// + /// let worker = Worker::new(client) + /// .system_prompt("...") + /// .with_config(config); + /// ``` + pub fn with_config(mut self, config: RequestConfig) -> Self { + self.request_config = config; + self + } + + /// 現在の設定をプロバイダに対してバリデーションする + /// + /// 未サポートの設定があればエラーを返す。 + /// チェーンの最後で呼び出すことで、設定の問題を早期に検出できる。 + /// + /// # Examples + /// + /// ```ignore + /// let worker = Worker::new(client) + /// .temperature(0.7) + /// .top_k(40) + /// .validate()?; // OpenAIならtop_kがサポートされないためエラー + /// ``` + /// + /// # Returns + /// * `Ok(Self)` - バリデーション成功 + /// * `Err(WorkerError::ConfigWarnings)` - 未サポートの設定がある + pub fn validate(self) -> Result { + let warnings = self.client.validate_config(&self.request_config); + if warnings.is_empty() { + Ok(self) + } else { + Err(WorkerError::ConfigWarnings(warnings)) + } + } + /// 履歴への可変参照を取得 /// /// Mutable状態でのみ利用可能。履歴を自由に編集できる。 @@ -700,6 +875,7 @@ impl Worker { locked_prefix_len, turn_count: self.turn_count, turn_notifiers: self.turn_notifiers, + request_config: self.request_config, _state: PhantomData, } } @@ -777,6 +953,7 @@ impl Worker { locked_prefix_len: 0, turn_count: self.turn_count, turn_notifiers: self.turn_notifiers, + request_config: self.request_config, _state: PhantomData, } } diff --git a/llm-worker/tests/validation_test.rs b/llm-worker/tests/validation_test.rs new file mode 100644 index 0000000..9ba3017 --- /dev/null +++ b/llm-worker/tests/validation_test.rs @@ -0,0 +1,40 @@ +use llm_worker::llm_client::LlmClient; +use llm_worker::llm_client::providers::openai::OpenAIClient; +use llm_worker::{Worker, WorkerError}; + +#[test] +fn test_openai_top_k_warning() { + // ダミーキーでクライアント作成(validate_configは通信しないため安全) + let client = OpenAIClient::new("dummy-key", "gpt-4o"); + + // top_kを設定したWorkerを作成 + let worker = Worker::new(client).top_k(50); // OpenAIはtop_k非対応 + + // validate()を実行 + let result = worker.validate(); + + // エラーが返り、ConfigWarningsが含まれていることを確認 + match result { + Err(WorkerError::ConfigWarnings(warnings)) => { + assert_eq!(warnings.len(), 1); + assert_eq!(warnings[0].option_name, "top_k"); + println!("Got expected warning: {}", warnings[0]); + } + Ok(_) => panic!("Should have returned validation error"), + Err(e) => panic!("Unexpected error type: {:?}", e), + } +} + +#[test] +fn test_openai_valid_config() { + let client = OpenAIClient::new("dummy-key", "gpt-4o"); + + // validな設定(temperatureのみ) + let worker = Worker::new(client).temperature(0.7); + + // validate()を実行 + let result = worker.validate(); + + // 成功を確認 + assert!(result.is_ok()); +}