feat: Implement RequestConfig validation
This commit is contained in:
parent
3b5c7e2d46
commit
81107c6f5c
|
|
@ -36,17 +36,17 @@
|
|||
//! locked.run("user input").await?;
|
||||
//! ```
|
||||
|
||||
pub mod llm_client;
|
||||
pub mod timeline;
|
||||
mod handler;
|
||||
mod message;
|
||||
mod worker;
|
||||
|
||||
pub mod event;
|
||||
mod handler;
|
||||
pub mod hook;
|
||||
mod message;
|
||||
pub mod llm_client;
|
||||
pub mod state;
|
||||
pub mod subscriber;
|
||||
pub mod timeline;
|
||||
pub mod tool;
|
||||
mod worker;
|
||||
|
||||
// =============================================================================
|
||||
// トップレベル公開(最も頻繁に使う型)
|
||||
|
|
|
|||
|
|
@ -2,10 +2,40 @@
|
|||
|
||||
use std::pin::Pin;
|
||||
|
||||
use crate::llm_client::{ClientError, Request, event::Event};
|
||||
use crate::llm_client::{ClientError, Request, RequestConfig, event::Event};
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
|
||||
/// A warning about a request configuration option.
///
/// Returned when a configuration uses an option that the target provider
/// does not support.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConfigWarning {
    /// Name of the configuration option the warning refers to.
    pub option_name: &'static str,
    /// Human-readable warning text.
    pub message: String,
}

impl ConfigWarning {
    /// Creates the warning for an option that a provider does not support.
    ///
    /// # Arguments
    /// * `option_name` - Name of the unsupported option (e.g. `"top_k"`).
    /// * `provider_name` - Display name of the provider, embedded in the message.
    ///
    /// # Returns
    /// A warning whose message reads
    /// `'<option_name>' is not supported by <provider_name> and will be ignored`.
    #[must_use]
    pub fn unsupported(option_name: &'static str, provider_name: &str) -> Self {
        Self {
            option_name,
            message: format!(
                "'{}' is not supported by {} and will be ignored",
                option_name, provider_name
            ),
        }
    }
}

impl std::fmt::Display for ConfigWarning {
    /// Formats the warning as `<option_name>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.option_name, self.message)
    }
}
|
||||
|
||||
/// LLMクライアントのtrait
|
||||
///
|
||||
/// 各プロバイダはこのtraitを実装し、統一されたインターフェースを提供する。
|
||||
|
|
@ -23,6 +53,19 @@ pub trait LlmClient: Send + Sync {
|
|||
&self,
|
||||
request: Request,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError>;
|
||||
|
||||
/// 設定をバリデーションし、未サポートの設定があれば警告を返す
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `config` - バリデーション対象の設定
|
||||
///
|
||||
/// # Returns
|
||||
/// サポートされていない設定に対する警告のリスト
|
||||
fn validate_config(&self, config: &RequestConfig) -> Vec<ConfigWarning> {
|
||||
// デフォルト実装: 全ての設定をサポート
|
||||
let _ = config;
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// `Box<dyn LlmClient>` に対する `LlmClient` の実装
|
||||
|
|
@ -36,4 +79,8 @@ impl LlmClient for Box<dyn LlmClient> {
|
|||
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError> {
|
||||
(**self).stream(request).await
|
||||
}
|
||||
|
||||
fn validate_config(&self, config: &RequestConfig) -> Vec<ConfigWarning> {
|
||||
(**self).validate_config(config)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
use std::pin::Pin;
|
||||
|
||||
use crate::llm_client::{
|
||||
ClientError, LlmClient, Request, event::Event, scheme::openai::OpenAIScheme,
|
||||
ClientError, ConfigWarning, LlmClient, Request, RequestConfig, event::Event,
|
||||
scheme::openai::OpenAIScheme,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use eventsource_stream::Eventsource;
|
||||
|
|
@ -197,4 +198,15 @@ impl LlmClient for OpenAIClient {
|
|||
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
fn validate_config(&self, config: &RequestConfig) -> Vec<ConfigWarning> {
|
||||
let mut warnings = Vec::new();
|
||||
|
||||
// OpenAI does not support top_k
|
||||
if config.top_k.is_some() {
|
||||
warnings.push(ConfigWarning::unsupported("top_k", "OpenAI"));
|
||||
}
|
||||
|
||||
warnings
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,8 @@ pub(crate) struct AnthropicRequest {
|
|||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_k: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub stop_sequences: Vec<String>,
|
||||
pub stream: bool,
|
||||
|
|
@ -90,6 +92,7 @@ impl AnthropicScheme {
|
|||
tools,
|
||||
temperature: request.config.temperature,
|
||||
top_p: request.config.top_p,
|
||||
top_k: request.config.top_k,
|
||||
stop_sequences: request.config.stop_sequences.clone(),
|
||||
stream: true,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -133,6 +133,9 @@ pub(crate) struct GeminiGenerationConfig {
|
|||
/// Top P
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
/// Top K
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_k: Option<u32>,
|
||||
/// ストップシーケンス
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub stop_sequences: Vec<String>,
|
||||
|
|
@ -183,6 +186,7 @@ impl GeminiScheme {
|
|||
max_output_tokens: request.config.max_tokens,
|
||||
temperature: request.config.temperature,
|
||||
top_p: request.config.top_p,
|
||||
top_k: request.config.top_k,
|
||||
stop_sequences: request.config.stop_sequences.clone(),
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -62,6 +62,30 @@ impl Request {
|
|||
self.config.max_tokens = Some(max_tokens);
|
||||
self
|
||||
}
|
||||
|
||||
/// temperatureを設定
|
||||
pub fn temperature(mut self, temperature: f32) -> Self {
|
||||
self.config.temperature = Some(temperature);
|
||||
self
|
||||
}
|
||||
|
||||
/// top_pを設定
|
||||
pub fn top_p(mut self, top_p: f32) -> Self {
|
||||
self.config.top_p = Some(top_p);
|
||||
self
|
||||
}
|
||||
|
||||
/// top_kを設定
|
||||
pub fn top_k(mut self, top_k: u32) -> Self {
|
||||
self.config.top_k = Some(top_k);
|
||||
self
|
||||
}
|
||||
|
||||
/// ストップシーケンスを追加
|
||||
pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||||
self.config.stop_sequences.push(sequence.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// メッセージ
|
||||
|
|
@ -191,8 +215,47 @@ pub struct RequestConfig {
|
|||
pub max_tokens: Option<u32>,
|
||||
/// Temperature
|
||||
pub temperature: Option<f32>,
|
||||
/// Top P
|
||||
/// Top P (nucleus sampling)
|
||||
pub top_p: Option<f32>,
|
||||
/// Top K
|
||||
pub top_k: Option<u32>,
|
||||
/// ストップシーケンス
|
||||
pub stop_sequences: Vec<String>,
|
||||
}
|
||||
|
||||
impl RequestConfig {
|
||||
/// 新しいデフォルト設定を作成
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// 最大トークン数を設定
|
||||
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||
self.max_tokens = Some(max_tokens);
|
||||
self
|
||||
}
|
||||
|
||||
/// temperatureを設定
|
||||
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||
self.temperature = Some(temperature);
|
||||
self
|
||||
}
|
||||
|
||||
/// top_pを設定
|
||||
pub fn with_top_p(mut self, top_p: f32) -> Self {
|
||||
self.top_p = Some(top_p);
|
||||
self
|
||||
}
|
||||
|
||||
/// top_kを設定
|
||||
pub fn with_top_k(mut self, top_k: u32) -> Self {
|
||||
self.top_k = Some(top_k);
|
||||
self
|
||||
}
|
||||
|
||||
/// ストップシーケンスを追加
|
||||
pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||||
self.stop_sequences.push(sequence.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ use tracing::{debug, info, trace, warn};
|
|||
use crate::{
|
||||
ContentPart, Message, MessageContent, Role,
|
||||
hook::{ControlFlow, HookError, ToolCall, ToolResult, TurnResult, WorkerHook},
|
||||
llm_client::{ClientError, LlmClient, Request, ToolDefinition},
|
||||
llm_client::{ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ToolDefinition},
|
||||
state::{Locked, Mutable, WorkerState},
|
||||
subscriber::{
|
||||
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
|
||||
|
|
@ -37,6 +37,9 @@ pub enum WorkerError {
|
|||
/// 処理が中断された
|
||||
#[error("Aborted: {0}")]
|
||||
Aborted(String),
|
||||
/// 設定に関する警告(未サポートのオプション)
|
||||
#[error("Config warnings: {}", .0.iter().map(|w| w.to_string()).collect::<Vec<_>>().join(", "))]
|
||||
ConfigWarnings(Vec<ConfigWarning>),
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
|
@ -139,6 +142,8 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
|||
turn_count: usize,
|
||||
/// ターン通知用のコールバック
|
||||
turn_notifiers: Vec<Box<dyn TurnNotifier>>,
|
||||
/// リクエスト設定(max_tokens, temperature等)
|
||||
request_config: RequestConfig,
|
||||
/// 状態マーカー
|
||||
_state: PhantomData<S>,
|
||||
}
|
||||
|
|
@ -274,6 +279,83 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
self.turn_count
|
||||
}
|
||||
|
||||
/// 現在のリクエスト設定への参照を取得
|
||||
pub fn request_config(&self) -> &RequestConfig {
|
||||
&self.request_config
|
||||
}
|
||||
|
||||
/// 最大トークン数を設定
|
||||
///
|
||||
/// この設定はキャッシュロックとは独立しており、各リクエストに適用されます。
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// worker.set_max_tokens(4096);
|
||||
/// ```
|
||||
pub fn set_max_tokens(&mut self, max_tokens: u32) {
|
||||
self.request_config.max_tokens = Some(max_tokens);
|
||||
}
|
||||
|
||||
/// temperatureを設定
|
||||
///
|
||||
/// 0.0から1.0(または2.0)の範囲で設定します。
|
||||
/// 低い値はより決定的な出力を、高い値はより多様な出力を生成します。
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// worker.set_temperature(0.7);
|
||||
/// ```
|
||||
pub fn set_temperature(&mut self, temperature: f32) {
|
||||
self.request_config.temperature = Some(temperature);
|
||||
}
|
||||
|
||||
/// top_pを設定(nucleus sampling)
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// worker.set_top_p(0.9);
|
||||
/// ```
|
||||
pub fn set_top_p(&mut self, top_p: f32) {
|
||||
self.request_config.top_p = Some(top_p);
|
||||
}
|
||||
|
||||
/// top_kを設定
|
||||
///
|
||||
/// トークン選択時に考慮する上位k個のトークンを指定します。
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// worker.set_top_k(40);
|
||||
/// ```
|
||||
pub fn set_top_k(&mut self, top_k: u32) {
|
||||
self.request_config.top_k = Some(top_k);
|
||||
}
|
||||
|
||||
/// ストップシーケンスを追加
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// worker.add_stop_sequence("\n\n");
|
||||
/// ```
|
||||
pub fn add_stop_sequence(&mut self, sequence: impl Into<String>) {
|
||||
self.request_config.stop_sequences.push(sequence.into());
|
||||
}
|
||||
|
||||
/// ストップシーケンスをクリア
|
||||
pub fn clear_stop_sequences(&mut self) {
|
||||
self.request_config.stop_sequences.clear();
|
||||
}
|
||||
|
||||
/// リクエスト設定を一括で設定
|
||||
pub fn set_request_config(&mut self, config: RequestConfig) {
|
||||
self.request_config = config;
|
||||
}
|
||||
|
||||
/// 登録されたツールからToolDefinitionのリストを生成
|
||||
fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
|
||||
self.tools
|
||||
|
|
@ -387,6 +469,9 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
|||
request = request.tool(tool_def.clone());
|
||||
}
|
||||
|
||||
// リクエスト設定を適用
|
||||
request = request.config(self.request_config.clone());
|
||||
|
||||
request
|
||||
}
|
||||
|
||||
|
|
@ -622,6 +707,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
locked_prefix_len: 0,
|
||||
turn_count: 0,
|
||||
turn_notifiers: Vec::new(),
|
||||
request_config: RequestConfig::default(),
|
||||
_state: PhantomData,
|
||||
}
|
||||
}
|
||||
|
|
@ -637,6 +723,95 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
self.system_prompt = Some(prompt.into());
|
||||
}
|
||||
|
||||
/// 最大トークン数を設定(ビルダーパターン)
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// let worker = Worker::new(client)
|
||||
/// .system_prompt("You are a helpful assistant.")
|
||||
/// .max_tokens(4096);
|
||||
/// ```
|
||||
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
|
||||
self.request_config.max_tokens = Some(max_tokens);
|
||||
self
|
||||
}
|
||||
|
||||
/// temperatureを設定(ビルダーパターン)
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// let worker = Worker::new(client)
|
||||
/// .temperature(0.7);
|
||||
/// ```
|
||||
pub fn temperature(mut self, temperature: f32) -> Self {
|
||||
self.request_config.temperature = Some(temperature);
|
||||
self
|
||||
}
|
||||
|
||||
/// top_pを設定(ビルダーパターン)
|
||||
pub fn top_p(mut self, top_p: f32) -> Self {
|
||||
self.request_config.top_p = Some(top_p);
|
||||
self
|
||||
}
|
||||
|
||||
/// top_kを設定(ビルダーパターン)
|
||||
pub fn top_k(mut self, top_k: u32) -> Self {
|
||||
self.request_config.top_k = Some(top_k);
|
||||
self
|
||||
}
|
||||
|
||||
/// ストップシーケンスを追加(ビルダーパターン)
|
||||
pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||||
self.request_config.stop_sequences.push(sequence.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// リクエスト設定をまとめて設定(ビルダーパターン)
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// let config = RequestConfig::new()
|
||||
/// .with_max_tokens(4096)
|
||||
/// .with_temperature(0.7);
|
||||
///
|
||||
/// let worker = Worker::new(client)
|
||||
/// .system_prompt("...")
|
||||
/// .with_config(config);
|
||||
/// ```
|
||||
pub fn with_config(mut self, config: RequestConfig) -> Self {
|
||||
self.request_config = config;
|
||||
self
|
||||
}
|
||||
|
||||
/// 現在の設定をプロバイダに対してバリデーションする
|
||||
///
|
||||
/// 未サポートの設定があればエラーを返す。
|
||||
/// チェーンの最後で呼び出すことで、設定の問題を早期に検出できる。
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// let worker = Worker::new(client)
|
||||
/// .temperature(0.7)
|
||||
/// .top_k(40)
|
||||
/// .validate()?; // OpenAIならtop_kがサポートされないためエラー
|
||||
/// ```
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(Self)` - バリデーション成功
|
||||
/// * `Err(WorkerError::ConfigWarnings)` - 未サポートの設定がある
|
||||
pub fn validate(self) -> Result<Self, WorkerError> {
|
||||
let warnings = self.client.validate_config(&self.request_config);
|
||||
if warnings.is_empty() {
|
||||
Ok(self)
|
||||
} else {
|
||||
Err(WorkerError::ConfigWarnings(warnings))
|
||||
}
|
||||
}
|
||||
|
||||
/// 履歴への可変参照を取得
|
||||
///
|
||||
/// Mutable状態でのみ利用可能。履歴を自由に編集できる。
|
||||
|
|
@ -700,6 +875,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
|||
locked_prefix_len,
|
||||
turn_count: self.turn_count,
|
||||
turn_notifiers: self.turn_notifiers,
|
||||
request_config: self.request_config,
|
||||
_state: PhantomData,
|
||||
}
|
||||
}
|
||||
|
|
@ -777,6 +953,7 @@ impl<C: LlmClient> Worker<C, Locked> {
|
|||
locked_prefix_len: 0,
|
||||
turn_count: self.turn_count,
|
||||
turn_notifiers: self.turn_notifiers,
|
||||
request_config: self.request_config,
|
||||
_state: PhantomData,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
40
llm-worker/tests/validation_test.rs
Normal file
40
llm-worker/tests/validation_test.rs
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
use llm_worker::llm_client::LlmClient;
|
||||
use llm_worker::llm_client::providers::openai::OpenAIClient;
|
||||
use llm_worker::{Worker, WorkerError};
|
||||
|
||||
#[test]
fn test_openai_top_k_warning() {
    // Build the client with a dummy key (the original author notes that
    // validate_config performs no network communication).
    let client = OpenAIClient::new("dummy-key", "gpt-4o");

    // Configure top_k, which OpenAI does not support, then validate.
    let outcome = Worker::new(client).top_k(50).validate();

    // Validation must fail with exactly one ConfigWarnings entry for top_k.
    let warnings = match outcome {
        Err(WorkerError::ConfigWarnings(warnings)) => warnings,
        Ok(_) => panic!("Should have returned validation error"),
        Err(e) => panic!("Unexpected error type: {:?}", e),
    };
    assert_eq!(warnings.len(), 1);
    assert_eq!(warnings[0].option_name, "top_k");
    println!("Got expected warning: {}", warnings[0]);
}
|
||||
|
||||
#[test]
fn test_openai_valid_config() {
    let client = OpenAIClient::new("dummy-key", "gpt-4o");

    // A valid configuration (temperature only) must pass validation.
    let outcome = Worker::new(client).temperature(0.7).validate();

    assert!(outcome.is_ok());
}
|
||||
Loading…
Reference in New Issue
Block a user