feat: Implement WorkerCLI to use multiple providers

parent a26d43c52d
commit 1fbd4c8380

Cargo.lock (generated): 118 lines changed
@@ -2,6 +2,15 @@
 # It is not intended for manual editing.
 version = 4
 
+[[package]]
+name = "aho-corasick"
+version = "1.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301"
+dependencies = [
+ "memchr",
+]
+
 [[package]]
 name = "anstream"
 version = "0.6.21"
@@ -766,6 +775,12 @@ dependencies = [
  "wasm-bindgen",
 ]
 
+[[package]]
+name = "lazy_static"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
+
 [[package]]
 name = "libc"
 version = "0.2.179"
@@ -796,6 +811,15 @@ version = "0.1.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
 
+[[package]]
+name = "matchers"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9"
+dependencies = [
+ "regex-automata",
+]
+
 [[package]]
 name = "memchr"
 version = "2.7.6"
@@ -835,6 +859,15 @@ dependencies = [
  "minimal-lexical",
 ]
 
+[[package]]
+name = "nu-ansi-term"
+version = "0.50.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
+dependencies = [
+ "windows-sys 0.61.2",
+]
+
 [[package]]
 name = "once_cell"
 version = "1.21.3"
@@ -1018,6 +1051,23 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "regex-automata"
+version = "0.4.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c"
+dependencies = [
+ "aho-corasick",
+ "memchr",
+ "regex-syntax",
+]
+
+[[package]]
+name = "regex-syntax"
+version = "0.8.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58"
+
 [[package]]
 name = "reqwest"
 version = "0.13.1"
@@ -1295,6 +1345,15 @@ dependencies = [
  "zmij",
 ]
 
+[[package]]
+name = "sharded-slab"
+version = "0.1.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
+dependencies = [
+ "lazy_static",
+]
+
 [[package]]
 name = "shlex"
 version = "1.3.0"
@@ -1446,6 +1505,15 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "thread_local"
+version = "1.1.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185"
+dependencies = [
+ "cfg-if",
+]
+
 [[package]]
 name = "tinystr"
 version = "0.8.2"
@@ -1572,9 +1640,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
 dependencies = [
  "pin-project-lite",
+ "tracing-attributes",
  "tracing-core",
 ]
 
+[[package]]
+name = "tracing-attributes"
+version = "0.1.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
 [[package]]
 name = "tracing-core"
 version = "0.1.36"
@@ -1582,6 +1662,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a"
 dependencies = [
  "once_cell",
  "valuable",
 ]
 
+[[package]]
+name = "tracing-log"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
+dependencies = [
+ "log",
+ "once_cell",
+ "tracing-core",
+]
+
+[[package]]
+name = "tracing-subscriber"
+version = "0.3.22"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e"
+dependencies = [
+ "matchers",
+ "nu-ansi-term",
+ "once_cell",
+ "regex-automata",
+ "sharded-slab",
+ "smallvec",
+ "thread_local",
+ "tracing",
+ "tracing-core",
+ "tracing-log",
+]
+
 [[package]]
@@ -1626,6 +1736,12 @@ version = "0.2.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
 
+[[package]]
+name = "valuable"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
+
 [[package]]
 name = "walkdir"
 version = "2.5.0"
@@ -2048,6 +2164,8 @@ dependencies = [
  "tempfile",
  "thiserror 1.0.69",
  "tokio",
+ "tracing",
+ "tracing-subscriber",
  "worker-macros",
  "worker-types",
 ]
@@ -12,6 +12,7 @@ serde = { version = "1.0.228", features = ["derive"] }
 serde_json = "1.0"
 thiserror = "1.0"
 tokio = { version = "1.49.0", features = ["macros", "rt-multi-thread"] }
+tracing = "0.1"
 worker-macros = { path = "../worker-macros" }
 worker-types = { path = "../worker-types" }
@@ -20,3 +21,4 @@ clap = { version = "4.5.54", features = ["derive", "env"] }
 schemars = "1.2.0"
 tempfile = "3.24.0"
 dotenv = "0.15.0"
+tracing-subscriber = { version = "0.3", features = ["env-filter"] }
@@ -1,6 +1,6 @@
 //! Interactive CLI client built on Worker
 //!
-//! A simple CLI application that talks to the Anthropic Claude API.
+//! A CLI application that talks to multiple LLM providers (Anthropic, Gemini, OpenAI, Ollama).
 //! Demonstrates tool registration and execution, and streaming response display.
 //!
 //! ## Usage
@@ -8,12 +8,23 @@
 //! ```bash
 //! # Put API keys in a .env file
 //! echo "ANTHROPIC_API_KEY=your-api-key" > .env
+//! echo "GEMINI_API_KEY=your-api-key" >> .env
+//! echo "OPENAI_API_KEY=your-api-key" >> .env
 //!
-//! # Basic run
+//! # Anthropic (default)
 //! cargo run --example worker_cli
 //!
+//! # Gemini
+//! cargo run --example worker_cli -- --provider gemini
+//!
+//! # OpenAI
+//! cargo run --example worker_cli -- --provider openai --model gpt-4o
+//!
+//! # Ollama (local)
+//! cargo run --example worker_cli -- --provider ollama --model llama3.2
+//!
 //! # With explicit options
-//! cargo run --example worker_cli -- --model claude-3-haiku-20240307 --system "You are a helpful assistant."
+//! cargo run --example worker_cli -- --provider anthropic --model claude-3-haiku-20240307 --system "You are a helpful assistant."
 //!
 //! # Show help
 //! cargo run --example worker_cli -- --help
@@ -22,10 +33,19 @@
 use std::io::{self, Write};
 use std::sync::{Arc, Mutex};
 
-use clap::Parser;
+use clap::{Parser, ValueEnum};
+use tracing::info;
+use tracing_subscriber::EnvFilter;
+
 use worker::{
-    llm_client::providers::anthropic::AnthropicClient, Handler, TextBlockEvent, TextBlockKind,
-    ToolUseBlockEvent, ToolUseBlockKind, Worker,
+    llm_client::{
+        providers::{
+            anthropic::AnthropicClient, gemini::GeminiClient, ollama::OllamaClient,
+            openai::OpenAIClient,
+        },
+        LlmClient,
+    },
+    Handler, TextBlockEvent, TextBlockKind, ToolUseBlockEvent, ToolUseBlockKind, Worker,
 };
 use worker_macros::tool_registry;
 use worker_types::Message;
@@ -34,19 +54,73 @@ use worker_types::Message;
 use schemars;
 use serde;
 
+// =============================================================================
+// Provider definitions
+// =============================================================================
+
+/// Available LLM providers
+#[derive(Debug, Clone, Copy, ValueEnum, Default)]
+enum Provider {
+    /// Anthropic Claude
+    #[default]
+    Anthropic,
+    /// Google Gemini
+    Gemini,
+    /// OpenAI GPT
+    Openai,
+    /// Ollama (local)
+    Ollama,
+}
+
+impl Provider {
+    /// The provider's default model
+    fn default_model(&self) -> &'static str {
+        match self {
+            Provider::Anthropic => "claude-sonnet-4-20250514",
+            Provider::Gemini => "gemini-2.0-flash",
+            Provider::Openai => "gpt-4o",
+            Provider::Ollama => "llama3.2",
+        }
+    }
+
+    /// The provider's display name
+    fn display_name(&self) -> &'static str {
+        match self {
+            Provider::Anthropic => "Anthropic Claude",
+            Provider::Gemini => "Google Gemini",
+            Provider::Openai => "OpenAI GPT",
+            Provider::Ollama => "Ollama (Local)",
+        }
+    }
+
+    /// Environment variable name for the API key
+    fn env_var_name(&self) -> Option<&'static str> {
+        match self {
+            Provider::Anthropic => Some("ANTHROPIC_API_KEY"),
+            Provider::Gemini => Some("GEMINI_API_KEY"),
+            Provider::Openai => Some("OPENAI_API_KEY"),
+            Provider::Ollama => None, // Ollama runs locally, so no key is needed
+        }
+    }
+}
+
 // =============================================================================
 // CLI argument definitions
 // =============================================================================
 
-/// Interactive CLI client for the Anthropic Claude API
+/// Interactive CLI client supporting multiple LLM providers
 #[derive(Parser, Debug)]
 #[command(name = "worker-cli")]
-#[command(about = "Interactive CLI client for Anthropic Claude API using Worker")]
+#[command(about = "Interactive CLI client for multiple LLM providers using Worker")]
 #[command(version)]
 struct Args {
-    /// Model name to use
-    #[arg(short, long, default_value = "claude-sonnet-4-20250514")]
-    model: String,
+    /// Provider to use
+    #[arg(long, value_enum, default_value_t = Provider::Anthropic)]
+    provider: Provider,
+
+    /// Model name to use (the provider's default when unspecified)
+    #[arg(short, long)]
+    model: Option<String>,
 
     /// System prompt
     #[arg(short, long)]
@@ -60,9 +134,9 @@ struct Args {
     #[arg(short = 'p', long)]
     prompt: Option<String>,
 
-    /// API key (takes precedence over the ANTHROPIC_API_KEY environment variable)
-    #[arg(long, env = "ANTHROPIC_API_KEY")]
-    api_key: String,
+    /// API key (takes precedence over environment variables)
+    #[arg(long)]
+    api_key: Option<String>,
 }
 
 // =============================================================================
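A minimal sketch (not part of the commit) of how the two flags above resolve to a concrete model, assuming the Provider and Args definitions in this diff; clap's Parser::parse_from is used purely for illustration:

    let args = Args::parse_from(["worker-cli", "--provider", "gemini"]);
    let model = args
        .model
        .clone()
        .unwrap_or_else(|| args.provider.default_model().to_string());
    assert_eq!(model, "gemini-2.0-flash"); // the provider default applies when --model is omitted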
@@ -170,24 +244,107 @@ impl Handler<ToolUseBlockKind> for ToolCallPrinter {
     }
 }
 
+// =============================================================================
+// Client creation
+// =============================================================================
+
+/// Resolve the API key for the selected provider
+fn get_api_key(args: &Args) -> Result<String, String> {
+    // An API key given on the command line takes precedence
+    if let Some(ref key) = args.api_key {
+        return Ok(key.clone());
+    }
+
+    // Otherwise check the provider-specific environment variable
+    if let Some(env_var) = args.provider.env_var_name() {
+        std::env::var(env_var).map_err(|_| {
+            format!(
+                "API key required. Set {} environment variable or use --api-key",
+                env_var
+            )
+        })
+    } else {
+        // Providers such as Ollama need no key
+        Ok(String::new())
+    }
+}
+
+/// Create a client for the selected provider
+fn create_client(args: &Args) -> Result<Box<dyn LlmClient>, String> {
+    let model = args
+        .model
+        .clone()
+        .unwrap_or_else(|| args.provider.default_model().to_string());
+
+    let api_key = get_api_key(args)?;
+
+    match args.provider {
+        Provider::Anthropic => {
+            let client = AnthropicClient::new(&api_key, &model);
+            Ok(Box::new(client))
+        }
+        Provider::Gemini => {
+            let client = GeminiClient::new(&api_key, &model);
+            Ok(Box::new(client))
+        }
+        Provider::Openai => {
+            let client = OpenAIClient::new(&api_key, &model);
+            Ok(Box::new(client))
+        }
+        Provider::Ollama => {
+            let client = OllamaClient::new(&model);
+            Ok(Box::new(client))
+        }
+    }
+}
+
 // =============================================================================
 // Main
 // =============================================================================
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Load the .env file
     dotenv::dotenv().ok();
 
+    // Initialize logging
+    // RUST_LOG=debug cargo run --example worker_cli ... shows detailed logs
+    // Defaults to the warn level; override via the RUST_LOG environment variable
+    let filter = EnvFilter::try_from_default_env()
+        .unwrap_or_else(|_| EnvFilter::new("warn"));
+
+    tracing_subscriber::fmt()
+        .with_env_filter(filter)
+        .with_target(true)
+        .init();
+
     // Parse CLI arguments
     let args = Args::parse();
 
+    info!(
+        provider = ?args.provider,
+        model = ?args.model,
+        "Starting worker CLI"
+    );
+
     // Interactive or one-shot mode
     let is_interactive = args.prompt.is_none();
 
+    // Model name (for display)
+    let model_name = args
+        .model
+        .clone()
+        .unwrap_or_else(|| args.provider.default_model().to_string());
+
     if is_interactive {
-        println!("╔════════════════════════════════════════════════╗");
-        println!("║     Worker CLI - Anthropic Claude Client       ║");
-        println!("╚════════════════════════════════════════════════╝");
+        let title = format!("Worker CLI - {}", args.provider.display_name());
+        let border_len = title.len() + 6;
+        println!("╔{}╗", "═".repeat(border_len));
+        println!("║   {}   ║", title); // three spaces each side so the box edges line up with border_len
+        println!("╚{}╝", "═".repeat(border_len));
         println!();
-        println!("Model: {}", args.model);
+        println!("Provider: {}", args.provider.display_name());
+        println!("Model: {}", model_name);
         if let Some(ref system) = args.system {
             println!("System: {}", system);
         }
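A minimal sketch (not part of the commit) of the key-resolution order get_api_key implements, assuming the definitions above; the key values are hypothetical:

    let mut args = Args::parse_from(["worker-cli", "--provider", "ollama"]);
    assert_eq!(get_api_key(&args).unwrap(), ""); // Ollama: no key required
    args.api_key = Some("from-flag".into());
    assert_eq!(get_api_key(&args).unwrap(), "from-flag"); // an explicit --api-key always wins over the environment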
@@ -204,7 +361,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     }
 
     // Create the client
-    let client = AnthropicClient::new(&args.api_key, &args.model);
+    let client = match create_client(&args) {
+        Ok(c) => c,
+        Err(e) => {
+            eprintln!("❌ Error: {}", e);
+            std::process::exit(1);
+        }
+    };
 
     // Create the Worker
     let mut worker = Worker::new(client);
@@ -26,3 +26,16 @@ pub trait LlmClient: Send + Sync {
         request: Request,
     ) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError>;
 }
+
+/// `LlmClient` implementation for `Box<dyn LlmClient>`
+///
+/// This lets clients that use dynamic dispatch be used with `Worker` as well.
+#[async_trait]
+impl LlmClient for Box<dyn LlmClient> {
+    async fn stream(
+        &self,
+        request: Request,
+    ) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError> {
+        (**self).stream(request).await
+    }
+}
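A usage sketch (not part of the commit) of why this blanket impl matters, assuming the paths shown in the worker_cli imports above: the provider can be chosen at runtime while Worker keeps a single client type.

    use worker::llm_client::providers::{anthropic::AnthropicClient, ollama::OllamaClient};
    use worker::llm_client::LlmClient;
    use worker::Worker;

    fn pick_client(use_local: bool, api_key: &str) -> Box<dyn LlmClient> {
        if use_local {
            Box::new(OllamaClient::new("llama3.2"))
        } else {
            Box::new(AnthropicClient::new(api_key, "claude-sonnet-4-20250514"))
        }
    }

    // Worker<Box<dyn LlmClient>> compiles because of the impl above:
    // let worker = Worker::new(pick_client(true, ""));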
@@ -137,9 +137,6 @@ impl LlmClient for AnthropicClient {
             .map_err(|e| std::io::Error::other(e));
         let event_stream = byte_stream.eventsource();
 
-        // State for tracking the current block type
-        // Note: a Stream cannot hold mutable state directly, so extra handling
-        // would be needed to set block_type correctly on the BlockStop event
         let stream = event_stream.map(move |result| {
             match result {
                 Ok(event) => {
@@ -162,14 +159,6 @@ impl LlmClient for AnthropicClient {
     }
 }
 
-impl Clone for AnthropicScheme {
-    fn clone(&self) -> Self {
-        Self {
-            api_version: self.api_version.clone(),
-            fine_grained_tool_streaming: self.fine_grained_tool_streaming,
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
@@ -194,11 +194,11 @@ impl AnthropicScheme {
             }
             AnthropicEventType::ContentBlockStop => {
                 let event: ContentBlockStopEvent = serde_json::from_str(data)?;
-                // Note: BlockStop needs a block_type, but it is not tracked here;
-                // the provider layer has to track that state
+                // Note: BlockStop needs a block_type, but Anthropic does not include it in the Stop event
+                // The Timeline layer tracks BlockStart to know the correct block_type
                 Ok(Some(Event::BlockStop(BlockStop {
                     index: event.index,
-                    block_type: BlockType::Text, // overwritten by the provider layer
+                    block_type: BlockType::Text, // overwritten by the Timeline layer
                     stop_reason: None,
                 })))
             }
@@ -9,6 +9,7 @@ mod request;
 /// Anthropic scheme
 ///
 /// Handles request/response conversion for the Anthropic Messages API
+#[derive(Debug, Clone)]
 pub struct AnthropicScheme {
     /// API version
     pub api_version: String,
@@ -9,20 +9,12 @@ mod request;
 /// Gemini scheme
 ///
 /// Handles request/response conversion for the Google Gemini API
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Default)]
 pub struct GeminiScheme {
     /// Whether streaming function-call arguments are enabled
     pub stream_function_call_arguments: bool,
 }
 
-impl Default for GeminiScheme {
-    fn default() -> Self {
-        Self {
-            stream_function_call_arguments: false,
-        }
-    }
-}
-
 impl GeminiScheme {
     /// Create a new scheme
     pub fn new() -> Self {
@@ -1,7 +1,7 @@
 //! OpenAI SSE event parsing
 
 use serde::Deserialize;
-use worker_types::{BlockType, Event, StopReason, UsageEvent};
+use worker_types::{Event, StopReason, UsageEvent};
 
 use crate::llm_client::ClientError;
 
@@ -12,46 +12,48 @@ use super::OpenAIScheme;
 #[derive(Debug, Deserialize)]
 pub(crate) struct ChatCompletionChunk {
     pub id: String,
-    pub choices: Vec<ChatCompletionChoice>,
     pub object: String,
     pub created: u64,
     pub model: String,
     pub system_fingerprint: Option<String>,
-    pub usage: Option<Usage>, // present if stream_options: { include_usage: true }
+    pub choices: Vec<ChunkChoice>,
+    pub usage: Option<ChunkUsage>,
 }
 
 #[allow(dead_code)]
 #[derive(Debug, Deserialize)]
-pub(crate) struct ChatCompletionChoice {
+pub(crate) struct ChunkChoice {
     pub index: usize,
-    pub delta: ChatCompletionDelta,
+    pub delta: ChunkDelta,
     pub finish_reason: Option<String>,
 }
 
 #[allow(dead_code)]
 #[derive(Debug, Deserialize)]
-pub(crate) struct ChatCompletionDelta {
+pub(crate) struct ChunkDelta {
     pub role: Option<String>,
     pub content: Option<String>,
-    pub tool_calls: Option<Vec<ChatCompletionToolCallDelta>>,
     pub refusal: Option<String>,
+    pub tool_calls: Option<Vec<ChunkToolCall>>,
 }
 
 #[allow(dead_code)]
 #[derive(Debug, Deserialize)]
-pub(crate) struct ChatCompletionToolCallDelta {
+pub(crate) struct ChunkToolCall {
     pub index: usize,
     pub id: Option<String>,
-    pub r#type: Option<String>, // "function"
-    pub function: Option<ChatCompletionFunctionDelta>,
+    #[serde(rename = "type")]
+    pub call_type: Option<String>,
+    pub function: Option<ChunkFunction>,
 }
 
 #[allow(dead_code)]
 #[derive(Debug, Deserialize)]
-pub(crate) struct ChatCompletionFunctionDelta {
+pub(crate) struct ChunkFunction {
     pub name: Option<String>,
     pub arguments: Option<String>,
 }
 
 #[derive(Debug, Deserialize)]
-pub(crate) struct Usage {
+pub(crate) struct ChunkUsage {
     pub prompt_tokens: u64,
     pub completion_tokens: u64,
     pub total_tokens: u64,
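A minimal in-module test sketch (hypothetical, not part of the commit) confirming that #[serde(rename = "type")] maps OpenAI's `type` field onto call_type; serde_json is assumed available, as in the crate's dependencies:

    #[cfg(test)]
    mod chunk_shape_sketch {
        use super::*;

        #[test]
        fn tool_call_type_is_renamed() {
            let json = r#"{"index":0,"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":""}}"#;
            let call: ChunkToolCall = serde_json::from_str(json).unwrap();
            assert_eq!(call.call_type.as_deref(), Some("function")); // mapped via the rename attribute
            assert_eq!(call.id.as_deref(), Some("call_abc"));
        }
    }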
@@ -59,6 +61,9 @@ pub(crate) struct Usage {
 
 impl OpenAIScheme {
     /// Parse SSE data and convert it into Events
+    ///
+    /// The OpenAI API does not send explicit BlockStart events;
+    /// the Timeline layer handles the implicit BlockStart.
     pub fn parse_event(&self, data: &str) -> Result<Option<Vec<Event>>, ClientError> {
         if data == "[DONE]" {
             return Ok(None);
@@ -87,26 +92,8 @@ impl OpenAIScheme {
         for choice in chunk.choices {
             // Text Content Delta
             if let Some(content) = choice.delta.content {
-                // OpenAI splits "start" and "delta", but for text it usually just streams content.
-                // We don't distinctly get "BlockStart" from OpenAI for text usually, unless we track it manually.
-                // We'll optimistically emit BlockDelta(Text). The consumer (Timeline) should handle implicit starts if needed,
-                // OR we need to maintain state in the Scheme struct to know if we sent start.
-                // However, LlmClient usually just emits generic events.
-                // Let's assume index 0 for text if implicit.
-                // Actually, choice.index could be the block index? No, choice index is candidate index.
-                // OpenAI only generates 1 candidate usually in streaming unless n > 1.
-                // We map choice.index to Event index, hoping consumer handles it.
-
-                // NOTE: We might need to emit BlockStart if this is the first chunk for this choice index.
-                // But Scheme is stateless per event parse call usually.
-                // Timeline handles accumulating text. We can just emit Delta.
-                // BUT wait, `worker_types::Event` expects explicit `BlockStart` before `BlockDelta`?
-                // Let's check `events.rs` in anthropic. It seems to rely on explicit events from API.
-                // OpenAI API key diff: No explicit "start_block" event.
-                // So we might need to emit TextDelta, and if the consumer sees it without start, it handles it?
-                // Re-checking `worker_types::Event`: `BlockDelta` exists.
-
-                // For now, let's map content to `BlockDelta(Text)`.
+                // The OpenAI API does not send BlockStart, so emit only the delta;
+                // the Timeline layer handles the implicit BlockStart
                 events.push(Event::text_delta(choice.index, content));
             }
@@ -115,20 +102,16 @@ impl OpenAIScheme {
             for tool_call in tool_calls {
                 // Start of tool call (has ID)
                 if let Some(id) = tool_call.id {
                     let name = tool_call.function.as_ref().and_then(|f| f.name.clone()).unwrap_or_default();
-                    // Assuming tool_call.index is sequential for the choice.
-                    // We might want to map (choice.index, tool_call.index) to a flat block index?
-                    // OpenAI's tool_call.index is 0, 1, 2... within the message.
-                    // Timeline expects usize index. We can use tool_call.index.
                     events.push(Event::tool_use_start(tool_call.index, id, name));
                 }
 
                 // Arguments delta
                 if let Some(function) = tool_call.function {
                     if let Some(args) = function.arguments {
                         if !args.is_empty() {
                             events.push(Event::tool_input_delta(tool_call.index, args));
                         }
                     }
                 }
             }
@@ -140,84 +123,26 @@ impl OpenAIScheme {
                 "stop" => Some(StopReason::EndTurn),
                 "length" => Some(StopReason::MaxTokens),
                 "tool_calls" | "function_call" => Some(StopReason::ToolUse),
                 // "content_filter" => ...
                 _ => Some(StopReason::EndTurn),
             };
 
-            // We need to know WHAT block stopped.
-            // OpenAI doesn't tell us "Text block stopped" vs "Tool block stopped" easily in the finish_reason event alone without context.
-            // But usually finish_reason comes at the end.
-            // If `stop` or `length`, it's likely the Text block (index 0) or the last active block.
-            // If `tool_calls`, it means the ToolUse blocks are done.
+            let is_tool_finish = finish_reason == "tool_calls" || finish_reason == "function_call";
 
-            // We'll emit BlockStop for the choice index.
-            // For tool calls, we might have emitted ToolUseStart for explicit indices.
-            // If finish_reason is tool_calls, we might need to close all open tool blocks?
-            // The generic BlockStop event takes an index and type.
-
-            // Simplified strategy:
-            // If tool_calls, we assume the last tool call index we saw?
-            // Or better, we emit a generic BlockStop logic in Timeline?
-            // Provide a "generic" stop for now?
-            // Event::BlockStop requires type.
-
-            let block_type = if finish_reason == "tool_calls" || finish_reason == "function_call" {
-                BlockType::ToolUse
+            if is_tool_finish {
+                // Tool call finished
+                // Note: OpenAI does not say which tool call finished, so the
+                // Timeline layer has to handle this appropriately
             } else {
-                BlockType::Text
-            };
-
-            // We use choice.index as the block index for Text, but Tool Calls have their own indices.
-            // This mismatch is tricky without state.
-            // However, for Text (standard), choice.index usually 0.
-            // For Tool calls, they have indices 0, 1, 2...
-            // If we finish with tool_calls, strictly speaking we should close the tool blocks.
-            // But we don't know WHICH ones are open without state.
-
-            // Let's defer to emitting a Stop for choice.index (Text) or 0 (Text) if text,
-            // But for ToolUse, we might not emit BlockStop here if we rely on the consumer to close based on ToolUseStart/Delta flow completion?
-            // OpenAI doesn't stream "Tool call 0 finished", it just starts "Tool call 1" or ends message.
-
-            // Actually, we can check if `tool_calls` field was present in ANY chunk to know if we are in tool mode? No.
-
-            // Tentative: Emit BlockStop for Text if NOT tool_calls.
-            if block_type == BlockType::Text {
                 // Text finished
                 events.push(Event::text_block_stop(choice.index, stop_reason));
-            } else {
-                // For tool calls, we don't emit a stop here?
-                // Or we emit `Event::tool_use_stop` for the *last* known index? impossible to know.
-                // IMPORTANT: The `worker-types::Event::tool_use_stop` requires an index.
-                // We might need to assume the `Timeline` layer handles implicit stops for tools when the turn ends?
-                // OR we modify this parser to specific logic later.
-
-                // Let's assume mostly 1 tool call for now or that we don't explicitly close them here
-                // and rely on `BlockStop` with `StopReason::ToolUse` at index 0 to signal "Message finished due to tool use"?
-                // No, that confuses Block/Message levels.
-
-                // Re-read `worker_types`: `BlockStop` is per block.
-                // If we have multiple tools, we need multiple stops.
-                // But we only get one `finish_reason`.
-
-                // Ideally, we'd emit stops for all tools.
-                // Without state, we can't.
-                // We will emit NOTHING for tool stops here and hope Timeline handles it via `finish_reason` on the message?
-                // Events are flat.
-
-                // Workaround: Emit a generic status event or specific stop if we can.
-                // Anthropic emits `content_block_stop`. OpenAI doesn't.
-                // We might need a stateful parser for OpenAI to be perfect.
-                // But `OpenAIScheme` is methods-only.
-
-                // We will skip emitting specific BlockStop for tools for now,
-                // but we will emit Status(Completed) if finish_reason is stop/length.
             }
         }
     }
 
     if events.is_empty() {
         Ok(None)
     } else {
         Ok(Some(events))
     }
 }
@@ -233,14 +158,16 @@ mod tests {
         let data = r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}"#;
 
         let events = scheme.parse_event(data).unwrap().unwrap();
+        // OpenAI does not emit BlockStart, so there is only the delta
         assert_eq!(events.len(), 1);
 
         if let Event::BlockDelta(delta) = &events[0] {
             assert_eq!(delta.index, 0);
             if let DeltaContent::Text(text) = &delta.delta {
                 assert_eq!(text, "Hello");
             } else {
                 panic!("Expected text delta");
             }
         } else {
             panic!("Expected BlockDelta");
         }
@@ -253,28 +180,27 @@ mod tests {
         let data_start = r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":""}}]},"finish_reason":null}]}"#;
 
         let events = scheme.parse_event(data_start).unwrap().unwrap();
         // Should have tool_use_start
         assert_eq!(events.len(), 1);
         if let Event::BlockStart(start) = &events[0] {
-            assert_eq!(start.index, 0); // tool_call index is 0
+            assert_eq!(start.index, 0);
             if let worker_types::BlockMetadata::ToolUse { id, name } = &start.metadata {
                 assert_eq!(id, "call_abc");
                 assert_eq!(name, "get_weather");
             } else {
                 panic!("Expected ToolUse metadata");
             }
         }
 
         // Tool arguments delta
         let data_arg = r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{}}"}}]},"finish_reason":null}]}"#;
         let events = scheme.parse_event(data_arg).unwrap().unwrap();
         assert_eq!(events.len(), 1);
         if let Event::BlockDelta(delta) = &events[0] {
             if let DeltaContent::InputJson(json) = &delta.delta {
                 assert_eq!(json, "{}}");
             } else {
                 panic!("Expected input json delta");
             }
         }
     }
 }
@@ -9,7 +9,7 @@ mod request;
 /// OpenAI scheme
 ///
 /// Handles request/response conversion for the OpenAI Chat Completions API (and compatible APIs)
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Default)]
 pub struct OpenAIScheme {
     /// Model name (given per request, but can also be kept as a default)
     pub model: Option<String>,
@@ -17,15 +17,6 @@ pub struct OpenAIScheme {
     pub use_legacy_max_tokens: bool,
 }
 
-impl Default for OpenAIScheme {
-    fn default() -> Self {
-        Self {
-            model: None,
-            use_legacy_max_tokens: false,
-        }
-    }
-}
-
 impl OpenAIScheme {
     /// Create a new scheme
     pub fn new() -> Self {
@@ -81,6 +81,8 @@ trait ErasedBlockHandler: Send {
     fn dispatch_abort(&mut self, abort: &BlockAbort);
     fn start_scope(&mut self);
     fn end_scope(&mut self);
+    /// Whether a scope is currently active
+    fn has_scope(&self) -> bool;
 }
 
 /// Wrapper for TextBlockKind
@@ -150,6 +152,10 @@ where
     fn end_scope(&mut self) {
         self.scope = None;
     }
+
+    fn has_scope(&self) -> bool {
+        self.scope.is_some()
+    }
 }
 
 /// Wrapper for ThinkingBlockKind
@@ -214,6 +220,10 @@ where
     fn end_scope(&mut self) {
         self.scope = None;
     }
+
+    fn has_scope(&self) -> bool {
+        self.scope.is_some()
+    }
 }
 
 /// Wrapper for ToolUseBlockKind
@@ -296,6 +306,10 @@ where
         self.scope = None;
         self.current_tool = None;
     }
+
+    fn has_scope(&self) -> bool {
+        self.scope.is_some()
+    }
 }
 
 // =============================================================================
@@ -488,8 +502,19 @@ impl Timeline {
 
     fn handle_block_delta(&mut self, delta: &BlockDelta) {
         let block_type = delta.delta.block_type();
 
+        // Providers such as OpenAI may not send BlockStart, so if a delta
+        // arrives while no block is open, start one implicitly
+        if self.current_block.is_none() {
+            self.current_block = Some(block_type);
+        }
+
         let handlers = self.get_block_handlers_mut(block_type);
         for handler in handlers {
+            // Start the scope implicitly if there is none
+            if !handler.has_scope() {
+                handler.start_scope();
+            }
             handler.dispatch_delta(delta);
         }
     }
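The implicit-start rule above, reduced to a self-contained illustrative analog (not the crate's API): if a delta arrives while no scope is open, open one first, then dispatch.

    struct ToyHandler {
        scope: Option<String>,
    }

    impl ToyHandler {
        fn has_scope(&self) -> bool {
            self.scope.is_some()
        }
        fn start_scope(&mut self) {
            self.scope = Some(String::new());
        }
        fn dispatch_delta(&mut self, delta: &str) {
            if !self.has_scope() {
                self.start_scope(); // implicit BlockStart for providers that omit it
            }
            self.scope.as_mut().unwrap().push_str(delta);
        }
    }

    fn main() {
        let mut h = ToyHandler { scope: None };
        h.dispatch_delta("Hel"); // no explicit start: the scope opens here
        h.dispatch_delta("lo");
        assert_eq!(h.scope.as_deref(), Some("Hello"));
    }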
@@ -2,6 +2,7 @@ use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
 use futures::StreamExt;
+use tracing::{debug, info, trace, warn};
 
 use crate::llm_client::{ClientError, LlmClient, Request, ToolDefinition};
 use crate::subscriber_adapter::{
@@ -223,9 +224,16 @@ impl<C: LlmClient> Worker<C> {
         let mut context = messages;
         let tool_definitions = self.build_tool_definitions();
 
+        info!(
+            message_count = context.len(),
+            tool_count = tool_definitions.len(),
+            "Starting worker run"
+        );
+
         loop {
             // Notify turn start
             let current_turn = self.turn_count;
+            debug!(turn = current_turn, "Turn start");
             for notifier in &self.turn_notifiers {
                 notifier.on_turn_start(current_turn);
             }
@@ -233,6 +241,7 @@ impl<C: LlmClient> Worker<C> {
             // Hook: on_message_send
             let control = self.run_on_message_send_hooks(&mut context).await?;
             if let ControlFlow::Abort(reason) = control {
+                warn!(reason = %reason, "Aborted by hook");
                 // Notify turn end (abnormal termination)
                 for notifier in &self.turn_notifiers {
                     notifier.on_turn_end(current_turn);
@@ -242,13 +251,31 @@ impl<C: LlmClient> Worker<C> {
 
             // Build the request
             let request = self.build_request(&context, &tool_definitions);
+            debug!(
+                message_count = request.messages.len(),
+                tool_count = request.tools.len(),
+                has_system = request.system_prompt.is_some(),
+                "Sending request to LLM"
+            );
 
             // Process the stream
+            debug!("Starting stream...");
             let mut stream = self.client.stream(request).await?;
+            let mut event_count = 0;
             while let Some(event_result) = stream.next().await {
+                match &event_result {
+                    Ok(event) => {
+                        trace!(event = ?event, "Received event");
+                        event_count += 1;
+                    }
+                    Err(e) => {
+                        warn!(error = %e, "Stream error");
+                    }
+                }
                 let event = event_result?;
                 self.timeline.dispatch(&event);
             }
+            debug!(event_count = event_count, "Stream completed");
 
             // Notify turn end
             for notifier in &self.turn_notifiers {