develop #3

Merged
Hare merged 13 commits from develop into master 2026-01-11 00:00:22 +09:00
8 changed files with 355 additions and 9 deletions
Showing only changes of commit 81107c6f5c - Show all commits

View File

@ -36,17 +36,17 @@
//! locked.run("user input").await?;
//! ```
pub mod llm_client;
pub mod timeline;
mod handler;
mod message;
mod worker;
pub mod event;
mod handler;
pub mod hook;
mod message;
pub mod llm_client;
pub mod state;
pub mod subscriber;
pub mod timeline;
pub mod tool;
mod worker;
// =============================================================================
// トップレベル公開(最も頻繁に使う型)

View File

@ -2,10 +2,40 @@
use std::pin::Pin;
use crate::llm_client::{ClientError, Request, event::Event};
use crate::llm_client::{ClientError, Request, RequestConfig, event::Event};
use async_trait::async_trait;
use futures::Stream;
/// 設定に関する警告
///
/// プロバイダがサポートしていない設定を使用した場合に返される。
#[derive(Debug, Clone)]
pub struct ConfigWarning {
/// 設定オプション名
pub option_name: &'static str,
/// 警告メッセージ
pub message: String,
}
impl ConfigWarning {
/// 新しい警告を作成
pub fn unsupported(option_name: &'static str, provider_name: &str) -> Self {
Self {
option_name,
message: format!(
"'{}' is not supported by {} and will be ignored",
option_name, provider_name
),
}
}
}
impl std::fmt::Display for ConfigWarning {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}: {}", self.option_name, self.message)
}
}
/// LLMクライアントのtrait
///
/// 各プロバイダはこのtraitを実装し、統一されたインターフェースを提供する。
@ -23,6 +53,19 @@ pub trait LlmClient: Send + Sync {
&self,
request: Request,
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError>;
/// 設定をバリデーションし、未サポートの設定があれば警告を返す
///
/// # Arguments
/// * `config` - バリデーション対象の設定
///
/// # Returns
/// サポートされていない設定に対する警告のリスト
fn validate_config(&self, config: &RequestConfig) -> Vec<ConfigWarning> {
// デフォルト実装: 全ての設定をサポート
let _ = config;
Vec::new()
}
}
/// `Box<dyn LlmClient>` に対する `LlmClient` の実装
@ -36,4 +79,8 @@ impl LlmClient for Box<dyn LlmClient> {
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError> {
(**self).stream(request).await
}
fn validate_config(&self, config: &RequestConfig) -> Vec<ConfigWarning> {
(**self).validate_config(config)
}
}

View File

@ -5,7 +5,8 @@
use std::pin::Pin;
use crate::llm_client::{
ClientError, LlmClient, Request, event::Event, scheme::openai::OpenAIScheme,
ClientError, ConfigWarning, LlmClient, Request, RequestConfig, event::Event,
scheme::openai::OpenAIScheme,
};
use async_trait::async_trait;
use eventsource_stream::Eventsource;
@ -197,4 +198,15 @@ impl LlmClient for OpenAIClient {
Ok(Box::pin(stream))
}
fn validate_config(&self, config: &RequestConfig) -> Vec<ConfigWarning> {
let mut warnings = Vec::new();
// OpenAI does not support top_k
if config.top_k.is_some() {
warnings.push(ConfigWarning::unsupported("top_k", "OpenAI"));
}
warnings
}
}

View File

@ -23,6 +23,8 @@ pub(crate) struct AnthropicRequest {
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<u32>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub stop_sequences: Vec<String>,
pub stream: bool,
@ -90,6 +92,7 @@ impl AnthropicScheme {
tools,
temperature: request.config.temperature,
top_p: request.config.top_p,
top_k: request.config.top_k,
stop_sequences: request.config.stop_sequences.clone(),
stream: true,
}

View File

@ -133,6 +133,9 @@ pub(crate) struct GeminiGenerationConfig {
/// Top P
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// Top K
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<u32>,
/// ストップシーケンス
#[serde(skip_serializing_if = "Vec::is_empty")]
pub stop_sequences: Vec<String>,
@ -183,6 +186,7 @@ impl GeminiScheme {
max_output_tokens: request.config.max_tokens,
temperature: request.config.temperature,
top_p: request.config.top_p,
top_k: request.config.top_k,
stop_sequences: request.config.stop_sequences.clone(),
});

View File

@ -62,6 +62,30 @@ impl Request {
self.config.max_tokens = Some(max_tokens);
self
}
/// temperatureを設定
pub fn temperature(mut self, temperature: f32) -> Self {
self.config.temperature = Some(temperature);
self
}
/// top_pを設定
pub fn top_p(mut self, top_p: f32) -> Self {
self.config.top_p = Some(top_p);
self
}
/// top_kを設定
pub fn top_k(mut self, top_k: u32) -> Self {
self.config.top_k = Some(top_k);
self
}
/// ストップシーケンスを追加
pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
self.config.stop_sequences.push(sequence.into());
self
}
}
/// メッセージ
@ -191,8 +215,47 @@ pub struct RequestConfig {
pub max_tokens: Option<u32>,
/// Temperature
pub temperature: Option<f32>,
/// Top P
/// Top P (nucleus sampling)
pub top_p: Option<f32>,
/// Top K
pub top_k: Option<u32>,
/// ストップシーケンス
pub stop_sequences: Vec<String>,
}
impl RequestConfig {
/// 新しいデフォルト設定を作成
pub fn new() -> Self {
Self::default()
}
/// 最大トークン数を設定
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
/// temperatureを設定
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
/// top_pを設定
pub fn with_top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
/// top_kを設定
pub fn with_top_k(mut self, top_k: u32) -> Self {
self.top_k = Some(top_k);
self
}
/// ストップシーケンスを追加
pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
self.stop_sequences.push(sequence.into());
self
}
}

View File

@ -8,7 +8,7 @@ use tracing::{debug, info, trace, warn};
use crate::{
ContentPart, Message, MessageContent, Role,
hook::{ControlFlow, HookError, ToolCall, ToolResult, TurnResult, WorkerHook},
llm_client::{ClientError, LlmClient, Request, ToolDefinition},
llm_client::{ClientError, ConfigWarning, LlmClient, Request, RequestConfig, ToolDefinition},
state::{Locked, Mutable, WorkerState},
subscriber::{
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
@ -37,6 +37,9 @@ pub enum WorkerError {
/// 処理が中断された
#[error("Aborted: {0}")]
Aborted(String),
/// 設定に関する警告(未サポートのオプション)
#[error("Config warnings: {}", .0.iter().map(|w| w.to_string()).collect::<Vec<_>>().join(", "))]
ConfigWarnings(Vec<ConfigWarning>),
}
// =============================================================================
@ -139,6 +142,8 @@ pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
turn_count: usize,
/// ターン通知用のコールバック
turn_notifiers: Vec<Box<dyn TurnNotifier>>,
/// リクエスト設定max_tokens, temperature等
request_config: RequestConfig,
/// 状態マーカー
_state: PhantomData<S>,
}
@ -274,6 +279,83 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
self.turn_count
}
/// 現在のリクエスト設定への参照を取得
pub fn request_config(&self) -> &RequestConfig {
&self.request_config
}
/// 最大トークン数を設定
///
/// この設定はキャッシュロックとは独立しており、各リクエストに適用されます。
///
/// # Examples
///
/// ```ignore
/// worker.set_max_tokens(4096);
/// ```
pub fn set_max_tokens(&mut self, max_tokens: u32) {
self.request_config.max_tokens = Some(max_tokens);
}
/// temperatureを設定
///
/// 0.0から1.0または2.0)の範囲で設定します。
/// 低い値はより決定的な出力を、高い値はより多様な出力を生成します。
///
/// # Examples
///
/// ```ignore
/// worker.set_temperature(0.7);
/// ```
pub fn set_temperature(&mut self, temperature: f32) {
self.request_config.temperature = Some(temperature);
}
/// top_pを設定nucleus sampling
///
/// # Examples
///
/// ```ignore
/// worker.set_top_p(0.9);
/// ```
pub fn set_top_p(&mut self, top_p: f32) {
self.request_config.top_p = Some(top_p);
}
/// top_kを設定
///
/// トークン選択時に考慮する上位k個のトークンを指定します。
///
/// # Examples
///
/// ```ignore
/// worker.set_top_k(40);
/// ```
pub fn set_top_k(&mut self, top_k: u32) {
self.request_config.top_k = Some(top_k);
}
/// ストップシーケンスを追加
///
/// # Examples
///
/// ```ignore
/// worker.add_stop_sequence("\n\n");
/// ```
pub fn add_stop_sequence(&mut self, sequence: impl Into<String>) {
self.request_config.stop_sequences.push(sequence.into());
}
/// ストップシーケンスをクリア
pub fn clear_stop_sequences(&mut self) {
self.request_config.stop_sequences.clear();
}
/// リクエスト設定を一括で設定
pub fn set_request_config(&mut self, config: RequestConfig) {
self.request_config = config;
}
/// 登録されたツールからToolDefinitionのリストを生成
fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
self.tools
@ -387,6 +469,9 @@ impl<C: LlmClient, S: WorkerState> Worker<C, S> {
request = request.tool(tool_def.clone());
}
// リクエスト設定を適用
request = request.config(self.request_config.clone());
request
}
@ -622,6 +707,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
locked_prefix_len: 0,
turn_count: 0,
turn_notifiers: Vec::new(),
request_config: RequestConfig::default(),
_state: PhantomData,
}
}
@ -637,6 +723,95 @@ impl<C: LlmClient> Worker<C, Mutable> {
self.system_prompt = Some(prompt.into());
}
/// 最大トークン数を設定(ビルダーパターン)
///
/// # Examples
///
/// ```ignore
/// let worker = Worker::new(client)
/// .system_prompt("You are a helpful assistant.")
/// .max_tokens(4096);
/// ```
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
self.request_config.max_tokens = Some(max_tokens);
self
}
/// temperatureを設定ビルダーパターン
///
/// # Examples
///
/// ```ignore
/// let worker = Worker::new(client)
/// .temperature(0.7);
/// ```
pub fn temperature(mut self, temperature: f32) -> Self {
self.request_config.temperature = Some(temperature);
self
}
/// top_pを設定ビルダーパターン
pub fn top_p(mut self, top_p: f32) -> Self {
self.request_config.top_p = Some(top_p);
self
}
/// top_kを設定ビルダーパターン
pub fn top_k(mut self, top_k: u32) -> Self {
self.request_config.top_k = Some(top_k);
self
}
/// ストップシーケンスを追加(ビルダーパターン)
pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
self.request_config.stop_sequences.push(sequence.into());
self
}
/// リクエスト設定をまとめて設定(ビルダーパターン)
///
/// # Examples
///
/// ```ignore
/// let config = RequestConfig::new()
/// .with_max_tokens(4096)
/// .with_temperature(0.7);
///
/// let worker = Worker::new(client)
/// .system_prompt("...")
/// .with_config(config);
/// ```
pub fn with_config(mut self, config: RequestConfig) -> Self {
self.request_config = config;
self
}
/// 現在の設定をプロバイダに対してバリデーションする
///
/// 未サポートの設定があればエラーを返す。
/// チェーンの最後で呼び出すことで、設定の問題を早期に検出できる。
///
/// # Examples
///
/// ```ignore
/// let worker = Worker::new(client)
/// .temperature(0.7)
/// .top_k(40)
/// .validate()?; // OpenAIならtop_kがサポートされないためエラー
/// ```
///
/// # Returns
/// * `Ok(Self)` - バリデーション成功
/// * `Err(WorkerError::ConfigWarnings)` - 未サポートの設定がある
pub fn validate(self) -> Result<Self, WorkerError> {
let warnings = self.client.validate_config(&self.request_config);
if warnings.is_empty() {
Ok(self)
} else {
Err(WorkerError::ConfigWarnings(warnings))
}
}
/// 履歴への可変参照を取得
///
/// Mutable状態でのみ利用可能。履歴を自由に編集できる。
@ -700,6 +875,7 @@ impl<C: LlmClient> Worker<C, Mutable> {
locked_prefix_len,
turn_count: self.turn_count,
turn_notifiers: self.turn_notifiers,
request_config: self.request_config,
_state: PhantomData,
}
}
@ -777,6 +953,7 @@ impl<C: LlmClient> Worker<C, Locked> {
locked_prefix_len: 0,
turn_count: self.turn_count,
turn_notifiers: self.turn_notifiers,
request_config: self.request_config,
_state: PhantomData,
}
}

View File

@ -0,0 +1,40 @@
use llm_worker::llm_client::LlmClient;
use llm_worker::llm_client::providers::openai::OpenAIClient;
use llm_worker::{Worker, WorkerError};
#[test]
fn test_openai_top_k_warning() {
// ダミーキーでクライアント作成validate_configは通信しないため安全
let client = OpenAIClient::new("dummy-key", "gpt-4o");
// top_kを設定したWorkerを作成
let worker = Worker::new(client).top_k(50); // OpenAIはtop_k非対応
// validate()を実行
let result = worker.validate();
// エラーが返り、ConfigWarningsが含まれていることを確認
match result {
Err(WorkerError::ConfigWarnings(warnings)) => {
assert_eq!(warnings.len(), 1);
assert_eq!(warnings[0].option_name, "top_k");
println!("Got expected warning: {}", warnings[0]);
}
Ok(_) => panic!("Should have returned validation error"),
Err(e) => panic!("Unexpected error type: {:?}", e),
}
}
#[test]
fn test_openai_valid_config() {
let client = OpenAIClient::new("dummy-key", "gpt-4o");
// validな設定temperatureのみ
let worker = Worker::new(client).temperature(0.7);
// validate()を実行
let result = worker.validate();
// 成功を確認
assert!(result.is_ok());
}