feat: Implement RequestConfig validation
This commit is contained in:
parent
3b5c7e2d46
commit
81107c6f5c
|
|
@ -36,17 +36,17 @@
|
||||||
//! locked.run("user input").await?;
|
//! locked.run("user input").await?;
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
pub mod llm_client;
|
mod handler;
|
||||||
pub mod timeline;
|
mod message;
|
||||||
|
mod worker;
|
||||||
|
|
||||||
pub mod event;
|
pub mod event;
|
||||||
mod handler;
|
|
||||||
pub mod hook;
|
pub mod hook;
|
||||||
mod message;
|
pub mod llm_client;
|
||||||
pub mod state;
|
pub mod state;
|
||||||
pub mod subscriber;
|
pub mod subscriber;
|
||||||
|
pub mod timeline;
|
||||||
pub mod tool;
|
pub mod tool;
|
||||||
mod worker;
|
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// トップレベル公開(最も頻繁に使う型)
|
// トップレベル公開(最も頻繁に使う型)
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,40 @@
|
||||||
|
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
|
|
||||||
use crate::llm_client::{ClientError, Request, event::Event};
|
use crate::llm_client::{ClientError, Request, RequestConfig, event::Event};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
|
|
||||||
|
/// 設定に関する警告
|
||||||
|
///
|
||||||
|
/// プロバイダがサポートしていない設定を使用した場合に返される。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ConfigWarning {
|
||||||
|
/// 設定オプション名
|
||||||
|
pub option_name: &'static str,
|
||||||
|
/// 警告メッセージ
|
||||||
|
pub message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConfigWarning {
|
||||||
|
/// 新しい警告を作成
|
||||||
|
pub fn unsupported(option_name: &'static str, provider_name: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
option_name,
|
||||||
|
message: format!(
|
||||||
|
"'{}' is not supported by {} and will be ignored",
|
||||||
|
option_name, provider_name
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for ConfigWarning {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "{}: {}", self.option_name, self.message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// LLMクライアントのtrait
|
/// LLMクライアントのtrait
|
||||||
///
|
///
|
||||||
/// 各プロバイダはこのtraitを実装し、統一されたインターフェースを提供する。
|
/// 各プロバイダはこのtraitを実装し、統一されたインターフェースを提供する。
|
||||||
|
|
@ -23,6 +53,19 @@ pub trait LlmClient: Send + Sync {
|
||||||
&self,
|
&self,
|
||||||
request: Request,
|
request: Request,
|
||||||
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError>;
|
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError>;
|
||||||
|
|
||||||
|
/// 設定をバリデーションし、未サポートの設定があれば警告を返す
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `config` - バリデーション対象の設定
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// サポートされていない設定に対する警告のリスト
|
||||||
|
fn validate_config(&self, config: &RequestConfig) -> Vec<ConfigWarning> {
|
||||||
|
// デフォルト実装: 全ての設定をサポート
|
||||||
|
let _ = config;
|
||||||
|
Vec::new()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// `Box<dyn LlmClient>` に対する `LlmClient` の実装
|
/// `Box<dyn LlmClient>` に対する `LlmClient` の実装
|
||||||
|
|
@ -36,4 +79,8 @@ impl LlmClient for Box<dyn LlmClient> {
|
||||||
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError> {
|
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError> {
|
||||||
(**self).stream(request).await
|
(**self).stream(request).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn validate_config(&self, config: &RequestConfig) -> Vec<ConfigWarning> {
|
||||||
|
(**self).validate_config(config)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,8 @@
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
|
|
||||||
use crate::llm_client::{
|
use crate::llm_client::{
|
||||||
ClientError, LlmClient, Request, event::Event, scheme::openai::OpenAIScheme,
|
ClientError, ConfigWarning, LlmClient, Request, RequestConfig, event::Event,
|
||||||
|
scheme::openai::OpenAIScheme,
|
||||||
};
|
};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use eventsource_stream::Eventsource;
|
use eventsource_stream::Eventsource;
|
||||||
|
|
@ -197,4 +198,15 @@ impl LlmClient for OpenAIClient {
|
||||||
|
|
||||||
Ok(Box::pin(stream))
|
Ok(Box::pin(stream))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn validate_config(&self, config: &RequestConfig) -> Vec<ConfigWarning> {
|
||||||
|
let mut warnings = Vec::new();
|
||||||
|
|
||||||
|
// OpenAI does not support top_k
|
||||||
|
if config.top_k.is_some() {
|
||||||
|
warnings.push(ConfigWarning::unsupported("top_k", "OpenAI"));
|
||||||
|
}
|
||||||
|
|
||||||
|
warnings
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,8 @@ pub(crate) struct AnthropicRequest {
|
||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub top_p: Option<f32>,
|
pub top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_k: Option<u32>,
|
||||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||||
pub stop_sequences: Vec<String>,
|
pub stop_sequences: Vec<String>,
|
||||||
pub stream: bool,
|
pub stream: bool,
|
||||||
|
|
@ -90,6 +92,7 @@ impl AnthropicScheme {
|
||||||
tools,
|
tools,
|
||||||
temperature: request.config.temperature,
|
temperature: request.config.temperature,
|
||||||
top_p: request.config.top_p,
|
top_p: request.config.top_p,
|
||||||
|
top_k: request.config.top_k,
|
||||||
stop_sequences: request.config.stop_sequences.clone(),
|
stop_sequences: request.config.stop_sequences.clone(),
|
||||||
stream: true,
|
stream: true,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -133,6 +133,9 @@ pub(crate) struct GeminiGenerationConfig {
|
||||||
/// Top P
|
/// Top P
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub top_p: Option<f32>,
|
pub top_p: Option<f32>,
|
||||||
|
/// Top K
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_k: Option<u32>,
|
||||||
/// ストップシーケンス
|
/// ストップシーケンス
|
||||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||||
pub stop_sequences: Vec<String>,
|
pub stop_sequences: Vec<String>,
|
||||||
|
|
@ -183,6 +186,7 @@ impl GeminiScheme {
|
||||||
max_output_tokens: request.config.max_tokens,
|
max_output_tokens: request.config.max_tokens,
|
||||||
temperature: request.config.temperature,
|
temperature: request.config.temperature,
|
||||||
top_p: request.config.top_p,
|
top_p: request.config.top_p,
|
||||||
|
top_k: request.config.top_k,
|
||||||
stop_sequences: request.config.stop_sequences.clone(),
|
stop_sequences: request.config.stop_sequences.clone(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,30 @@ impl Request {
|
||||||
self.config.max_tokens = Some(max_tokens);
|
self.config.max_tokens = Some(max_tokens);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// temperatureを設定
|
||||||
|
pub fn temperature(mut self, temperature: f32) -> Self {
|
||||||
|
self.config.temperature = Some(temperature);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// top_pを設定
|
||||||
|
pub fn top_p(mut self, top_p: f32) -> Self {
|
||||||
|
self.config.top_p = Some(top_p);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// top_kを設定
|
||||||
|
pub fn top_k(mut self, top_k: u32) -> Self {
|
||||||
|
self.config.top_k = Some(top_k);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// ストップシーケンスを追加
|
||||||
|
pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||||||
|
self.config.stop_sequences.push(sequence.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// メッセージ
|
/// メッセージ
|
||||||
|
|
@ -191,8 +215,47 @@ pub struct RequestConfig {
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
/// Temperature
|
/// Temperature
|
||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
/// Top P
|
/// Top P (nucleus sampling)
|
||||||
pub top_p: Option<f32>,
|
pub top_p: Option<f32>,
|
||||||
|
/// Top K
|
||||||
|
pub top_k: Option<u32>,
|
||||||
/// ストップシーケンス
|
/// ストップシーケンス
|
||||||
pub stop_sequences: Vec<String>,
|
pub stop_sequences: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl RequestConfig {
|
||||||
|
/// 新しいデフォルト設定を作成
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 最大トークン数を設定
|
||||||
|
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||||
|
self.max_tokens = Some(max_tokens);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// temperatureを設定
|
||||||
|
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||||
|
self.temperature = Some(temperature);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// top_pを設定
|
||||||
|
pub fn with_top_p(mut self, top_p: f32) -> Self {
|
||||||
|
self.top_p = Some(top_p);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// top_kを設定
|
||||||
|
pub fn with_top_k(mut self, top_k: u32) -> Self {
|
||||||
|
self.top_k = Some(top_k);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// ストップシーケンスを追加
|
||||||
|
pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||||||
|
self.stop_sequences.push(sequence.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ use tracing::{debug, info, trace, warn};
|
||||||
use crate::{
|
use crate::{
|
||||||
ContentPart, Message, MessageContent, Role,
|
ContentPart, Message, MessageContent, Role,
|
||||||
hook::{ControlFlow, HookError, ToolCall, ToolResult, TurnResult, WorkerHook},
|
hook::{ControlFlow, HookError, ToolCall, ToolResult, TurnResult, WorkerHook},
|
||||||
llm_client::{ClientError, LlmClient, Request, ToolDefinition},
|
llm_client::{ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ToolDefinition},
|
||||||
state::{Locked, Mutable, WorkerState},
|
state::{Locked, Mutable, WorkerState},
|
||||||
subscriber::{
|
subscriber::{
|
||||||
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
|
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
|
||||||
|
|
@ -37,6 +37,9 @@ pub enum WorkerError {
|
||||||
/// 処理が中断された
|
/// 処理が中断された
|
||||||
#[error("Aborted: {0}")]
|
#[error("Aborted: {0}")]
|
||||||
Aborted(String),
|
Aborted(String),
|
||||||
|
/// 設定に関する警告(未サポートのオプション)
|
||||||
|
#[error("Config warnings: {}", .0.iter().map(|w| w.to_string()).collect::<Vec<_>>().join(", "))]
|
||||||
|
ConfigWarnings(Vec<ConfigWarning>),
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
@ -139,6 +142,8 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
|
||||||
turn_count: usize,
|
turn_count: usize,
|
||||||
/// ターン通知用のコールバック
|
/// ターン通知用のコールバック
|
||||||
turn_notifiers: Vec<Box<dyn TurnNotifier>>,
|
turn_notifiers: Vec<Box<dyn TurnNotifier>>,
|
||||||
|
/// リクエスト設定(max_tokens, temperature等)
|
||||||
|
request_config: RequestConfig,
|
||||||
/// 状態マーカー
|
/// 状態マーカー
|
||||||
_state: PhantomData<S>,
|
_state: PhantomData<S>,
|
||||||
}
|
}
|
||||||
|
|
@ -274,6 +279,83 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
self.turn_count
|
self.turn_count
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 現在のリクエスト設定への参照を取得
|
||||||
|
pub fn request_config(&self) -> &RequestConfig {
|
||||||
|
&self.request_config
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 最大トークン数を設定
|
||||||
|
///
|
||||||
|
/// この設定はキャッシュロックとは独立しており、各リクエストに適用されます。
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```ignore
|
||||||
|
/// worker.set_max_tokens(4096);
|
||||||
|
/// ```
|
||||||
|
pub fn set_max_tokens(&mut self, max_tokens: u32) {
|
||||||
|
self.request_config.max_tokens = Some(max_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// temperatureを設定
|
||||||
|
///
|
||||||
|
/// 0.0から1.0(または2.0)の範囲で設定します。
|
||||||
|
/// 低い値はより決定的な出力を、高い値はより多様な出力を生成します。
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```ignore
|
||||||
|
/// worker.set_temperature(0.7);
|
||||||
|
/// ```
|
||||||
|
pub fn set_temperature(&mut self, temperature: f32) {
|
||||||
|
self.request_config.temperature = Some(temperature);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// top_pを設定(nucleus sampling)
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```ignore
|
||||||
|
/// worker.set_top_p(0.9);
|
||||||
|
/// ```
|
||||||
|
pub fn set_top_p(&mut self, top_p: f32) {
|
||||||
|
self.request_config.top_p = Some(top_p);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// top_kを設定
|
||||||
|
///
|
||||||
|
/// トークン選択時に考慮する上位k個のトークンを指定します。
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```ignore
|
||||||
|
/// worker.set_top_k(40);
|
||||||
|
/// ```
|
||||||
|
pub fn set_top_k(&mut self, top_k: u32) {
|
||||||
|
self.request_config.top_k = Some(top_k);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// ストップシーケンスを追加
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```ignore
|
||||||
|
/// worker.add_stop_sequence("\n\n");
|
||||||
|
/// ```
|
||||||
|
pub fn add_stop_sequence(&mut self, sequence: impl Into<String>) {
|
||||||
|
self.request_config.stop_sequences.push(sequence.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// ストップシーケンスをクリア
|
||||||
|
pub fn clear_stop_sequences(&mut self) {
|
||||||
|
self.request_config.stop_sequences.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// リクエスト設定を一括で設定
|
||||||
|
pub fn set_request_config(&mut self, config: RequestConfig) {
|
||||||
|
self.request_config = config;
|
||||||
|
}
|
||||||
|
|
||||||
/// 登録されたツールからToolDefinitionのリストを生成
|
/// 登録されたツールからToolDefinitionのリストを生成
|
||||||
fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
|
fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
|
||||||
self.tools
|
self.tools
|
||||||
|
|
@ -387,6 +469,9 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
|
||||||
request = request.tool(tool_def.clone());
|
request = request.tool(tool_def.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// リクエスト設定を適用
|
||||||
|
request = request.config(self.request_config.clone());
|
||||||
|
|
||||||
request
|
request
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -622,6 +707,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
locked_prefix_len: 0,
|
locked_prefix_len: 0,
|
||||||
turn_count: 0,
|
turn_count: 0,
|
||||||
turn_notifiers: Vec::new(),
|
turn_notifiers: Vec::new(),
|
||||||
|
request_config: RequestConfig::default(),
|
||||||
_state: PhantomData,
|
_state: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -637,6 +723,95 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
self.system_prompt = Some(prompt.into());
|
self.system_prompt = Some(prompt.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 最大トークン数を設定(ビルダーパターン)
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```ignore
|
||||||
|
/// let worker = Worker::new(client)
|
||||||
|
/// .system_prompt("You are a helpful assistant.")
|
||||||
|
/// .max_tokens(4096);
|
||||||
|
/// ```
|
||||||
|
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
|
||||||
|
self.request_config.max_tokens = Some(max_tokens);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// temperatureを設定(ビルダーパターン)
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```ignore
|
||||||
|
/// let worker = Worker::new(client)
|
||||||
|
/// .temperature(0.7);
|
||||||
|
/// ```
|
||||||
|
pub fn temperature(mut self, temperature: f32) -> Self {
|
||||||
|
self.request_config.temperature = Some(temperature);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// top_pを設定(ビルダーパターン)
|
||||||
|
pub fn top_p(mut self, top_p: f32) -> Self {
|
||||||
|
self.request_config.top_p = Some(top_p);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// top_kを設定(ビルダーパターン)
|
||||||
|
pub fn top_k(mut self, top_k: u32) -> Self {
|
||||||
|
self.request_config.top_k = Some(top_k);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// ストップシーケンスを追加(ビルダーパターン)
|
||||||
|
pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||||||
|
self.request_config.stop_sequences.push(sequence.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// リクエスト設定をまとめて設定(ビルダーパターン)
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```ignore
|
||||||
|
/// let config = RequestConfig::new()
|
||||||
|
/// .with_max_tokens(4096)
|
||||||
|
/// .with_temperature(0.7);
|
||||||
|
///
|
||||||
|
/// let worker = Worker::new(client)
|
||||||
|
/// .system_prompt("...")
|
||||||
|
/// .with_config(config);
|
||||||
|
/// ```
|
||||||
|
pub fn with_config(mut self, config: RequestConfig) -> Self {
|
||||||
|
self.request_config = config;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 現在の設定をプロバイダに対してバリデーションする
|
||||||
|
///
|
||||||
|
/// 未サポートの設定があればエラーを返す。
|
||||||
|
/// チェーンの最後で呼び出すことで、設定の問題を早期に検出できる。
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```ignore
|
||||||
|
/// let worker = Worker::new(client)
|
||||||
|
/// .temperature(0.7)
|
||||||
|
/// .top_k(40)
|
||||||
|
/// .validate()?; // OpenAIならtop_kがサポートされないためエラー
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// * `Ok(Self)` - バリデーション成功
|
||||||
|
/// * `Err(WorkerError::ConfigWarnings)` - 未サポートの設定がある
|
||||||
|
pub fn validate(self) -> Result<Self, WorkerError> {
|
||||||
|
let warnings = self.client.validate_config(&self.request_config);
|
||||||
|
if warnings.is_empty() {
|
||||||
|
Ok(self)
|
||||||
|
} else {
|
||||||
|
Err(WorkerError::ConfigWarnings(warnings))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// 履歴への可変参照を取得
|
/// 履歴への可変参照を取得
|
||||||
///
|
///
|
||||||
/// Mutable状態でのみ利用可能。履歴を自由に編集できる。
|
/// Mutable状態でのみ利用可能。履歴を自由に編集できる。
|
||||||
|
|
@ -700,6 +875,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
|
||||||
locked_prefix_len,
|
locked_prefix_len,
|
||||||
turn_count: self.turn_count,
|
turn_count: self.turn_count,
|
||||||
turn_notifiers: self.turn_notifiers,
|
turn_notifiers: self.turn_notifiers,
|
||||||
|
request_config: self.request_config,
|
||||||
_state: PhantomData,
|
_state: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -777,6 +953,7 @@ impl<C: LlmClient> Worker<C, Locked> {
|
||||||
locked_prefix_len: 0,
|
locked_prefix_len: 0,
|
||||||
turn_count: self.turn_count,
|
turn_count: self.turn_count,
|
||||||
turn_notifiers: self.turn_notifiers,
|
turn_notifiers: self.turn_notifiers,
|
||||||
|
request_config: self.request_config,
|
||||||
_state: PhantomData,
|
_state: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
40
llm-worker/tests/validation_test.rs
Normal file
40
llm-worker/tests/validation_test.rs
Normal file
|
|
@ -0,0 +1,40 @@
|
||||||
|
use llm_worker::llm_client::LlmClient;
|
||||||
|
use llm_worker::llm_client::providers::openai::OpenAIClient;
|
||||||
|
use llm_worker::{Worker, WorkerError};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_openai_top_k_warning() {
|
||||||
|
// ダミーキーでクライアント作成(validate_configは通信しないため安全)
|
||||||
|
let client = OpenAIClient::new("dummy-key", "gpt-4o");
|
||||||
|
|
||||||
|
// top_kを設定したWorkerを作成
|
||||||
|
let worker = Worker::new(client).top_k(50); // OpenAIはtop_k非対応
|
||||||
|
|
||||||
|
// validate()を実行
|
||||||
|
let result = worker.validate();
|
||||||
|
|
||||||
|
// エラーが返り、ConfigWarningsが含まれていることを確認
|
||||||
|
match result {
|
||||||
|
Err(WorkerError::ConfigWarnings(warnings)) => {
|
||||||
|
assert_eq!(warnings.len(), 1);
|
||||||
|
assert_eq!(warnings[0].option_name, "top_k");
|
||||||
|
println!("Got expected warning: {}", warnings[0]);
|
||||||
|
}
|
||||||
|
Ok(_) => panic!("Should have returned validation error"),
|
||||||
|
Err(e) => panic!("Unexpected error type: {:?}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_openai_valid_config() {
|
||||||
|
let client = OpenAIClient::new("dummy-key", "gpt-4o");
|
||||||
|
|
||||||
|
// validな設定(temperatureのみ)
|
||||||
|
let worker = Worker::new(client).temperature(0.7);
|
||||||
|
|
||||||
|
// validate()を実行
|
||||||
|
let result = worker.validate();
|
||||||
|
|
||||||
|
// 成功を確認
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user