446 lines
13 KiB
Rust
446 lines
13 KiB
Rust
//! Worker を用いた対話型 CLI クライアント
|
||
//!
|
||
//! 複数のLLMプロバイダ(Anthropic, Gemini, OpenAI, Ollama)と対話するCLIアプリケーション。
|
||
//! ツールの登録と実行、ストリーミングレスポンスの表示をデモする。
|
||
//!
|
||
//! ## 使用方法
|
||
//!
|
||
//! ```bash
|
||
//! # .envファイルにAPIキーを設定
|
||
//! echo "ANTHROPIC_API_KEY=your-api-key" > .env
|
||
//! echo "GEMINI_API_KEY=your-api-key" >> .env
|
||
//! echo "OPENAI_API_KEY=your-api-key" >> .env
|
||
//!
|
||
//! # Anthropic (デフォルト)
|
||
//! cargo run --example worker_cli
|
||
//!
|
||
//! # Gemini
|
||
//! cargo run --example worker_cli -- --provider gemini
|
||
//!
|
||
//! # OpenAI
|
||
//! cargo run --example worker_cli -- --provider openai --model gpt-4o
|
||
//!
|
||
//! # Ollama (ローカル)
|
||
//! cargo run --example worker_cli -- --provider ollama --model llama3.2
|
||
//!
|
||
//! # オプション指定
|
||
//! cargo run --example worker_cli -- --provider anthropic --model claude-3-haiku-20240307 --system "You are a helpful assistant."
|
||
//!
|
||
//! # ヘルプ表示
|
||
//! cargo run --example worker_cli -- --help
|
||
//! ```
|
||
|
||
use std::io::{self, Write};
|
||
use std::sync::{Arc, Mutex};
|
||
|
||
use tracing::info;
|
||
use tracing_subscriber::EnvFilter;
|
||
|
||
use clap::{Parser, ValueEnum};
|
||
use worker::{
|
||
Handler, TextBlockEvent, TextBlockKind, ToolUseBlockEvent, ToolUseBlockKind, Worker,
|
||
llm_client::{
|
||
LlmClient,
|
||
providers::{
|
||
anthropic::AnthropicClient, gemini::GeminiClient, ollama::OllamaClient,
|
||
openai::OpenAIClient,
|
||
},
|
||
},
|
||
};
|
||
use worker_macros::tool_registry;
|
||
use worker_types::Message;
|
||
|
||
// 必要なマクロ展開用インポート
|
||
use schemars;
|
||
use serde;
|
||
|
||
// =============================================================================
|
||
// プロバイダ定義
|
||
// =============================================================================
|
||
|
||
/// 利用可能なLLMプロバイダ
|
||
#[derive(Debug, Clone, Copy, ValueEnum, Default)]
|
||
enum Provider {
|
||
/// Anthropic Claude
|
||
#[default]
|
||
Anthropic,
|
||
/// Google Gemini
|
||
Gemini,
|
||
/// OpenAI GPT
|
||
Openai,
|
||
/// Ollama (ローカル)
|
||
Ollama,
|
||
}
|
||
|
||
impl Provider {
|
||
/// プロバイダのデフォルトモデル
|
||
fn default_model(&self) -> &'static str {
|
||
match self {
|
||
Provider::Anthropic => "claude-sonnet-4-20250514",
|
||
Provider::Gemini => "gemini-2.0-flash",
|
||
Provider::Openai => "gpt-4o",
|
||
Provider::Ollama => "llama3.2",
|
||
}
|
||
}
|
||
|
||
/// プロバイダの表示名
|
||
fn display_name(&self) -> &'static str {
|
||
match self {
|
||
Provider::Anthropic => "Anthropic Claude",
|
||
Provider::Gemini => "Google Gemini",
|
||
Provider::Openai => "OpenAI GPT",
|
||
Provider::Ollama => "Ollama (Local)",
|
||
}
|
||
}
|
||
|
||
/// APIキーの環境変数名
|
||
fn env_var_name(&self) -> Option<&'static str> {
|
||
match self {
|
||
Provider::Anthropic => Some("ANTHROPIC_API_KEY"),
|
||
Provider::Gemini => Some("GEMINI_API_KEY"),
|
||
Provider::Openai => Some("OPENAI_API_KEY"),
|
||
Provider::Ollama => None, // Ollamaはローカルなので不要
|
||
}
|
||
}
|
||
}
|
||
|
||
// =============================================================================
|
||
// CLI引数定義
|
||
// =============================================================================
|
||
|
||
/// 複数のLLMプロバイダに対応した対話型CLIクライアント
|
||
#[derive(Parser, Debug)]
|
||
#[command(name = "worker-cli")]
|
||
#[command(about = "Interactive CLI client for multiple LLM providers using Worker")]
|
||
#[command(version)]
|
||
struct Args {
|
||
/// 使用するプロバイダ
|
||
#[arg(long, value_enum, default_value_t = Provider::Anthropic)]
|
||
provider: Provider,
|
||
|
||
/// 使用するモデル名(未指定時はプロバイダのデフォルト)
|
||
#[arg(short, long)]
|
||
model: Option<String>,
|
||
|
||
/// システムプロンプト
|
||
#[arg(short, long)]
|
||
system: Option<String>,
|
||
|
||
/// ツールを無効化
|
||
#[arg(long, default_value = "false")]
|
||
no_tools: bool,
|
||
|
||
/// 最初のメッセージ(指定するとそれを送信して終了)
|
||
#[arg(short = 'p', long)]
|
||
prompt: Option<String>,
|
||
|
||
/// APIキー(環境変数より優先)
|
||
#[arg(long)]
|
||
api_key: Option<String>,
|
||
}
|
||
|
||
// =============================================================================
|
||
// ツール定義
|
||
// =============================================================================
|
||
|
||
/// アプリケーションコンテキスト
|
||
#[derive(Clone)]
|
||
struct AppContext;
|
||
|
||
#[tool_registry]
|
||
impl AppContext {
|
||
/// 現在の日時を取得する
|
||
///
|
||
/// システムの現在の日付と時刻を返します。
|
||
#[tool]
|
||
fn get_current_time(&self) -> String {
|
||
let now = std::time::SystemTime::now()
|
||
.duration_since(std::time::UNIX_EPOCH)
|
||
.unwrap()
|
||
.as_secs();
|
||
// シンプルなUnixタイムスタンプからの変換
|
||
format!("Current Unix timestamp: {}", now)
|
||
}
|
||
|
||
/// 簡単な計算を行う
|
||
///
|
||
/// 2つの数値の四則演算を実行します。
|
||
#[tool]
|
||
fn calculate(&self, a: f64, b: f64, operation: String) -> Result<String, String> {
|
||
let result = match operation.as_str() {
|
||
"add" | "+" => a + b,
|
||
"subtract" | "-" => a - b,
|
||
"multiply" | "*" => a * b,
|
||
"divide" | "/" => {
|
||
if b == 0.0 {
|
||
return Err("Cannot divide by zero".to_string());
|
||
}
|
||
a / b
|
||
}
|
||
_ => return Err(format!("Unknown operation: {}", operation)),
|
||
};
|
||
Ok(format!("{} {} {} = {}", a, operation, b, result))
|
||
}
|
||
}
|
||
|
||
// =============================================================================
|
||
// ストリーミング表示用ハンドラー
|
||
// =============================================================================
|
||
|
||
/// テキストをリアルタイムで出力するハンドラー
|
||
struct StreamingPrinter {
|
||
is_first_delta: Arc<Mutex<bool>>,
|
||
}
|
||
|
||
impl StreamingPrinter {
|
||
fn new() -> Self {
|
||
Self {
|
||
is_first_delta: Arc::new(Mutex::new(true)),
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Handler<TextBlockKind> for StreamingPrinter {
|
||
type Scope = ();
|
||
|
||
fn on_event(&mut self, _scope: &mut (), event: &TextBlockEvent) {
|
||
match event {
|
||
TextBlockEvent::Start(_) => {
|
||
let mut first = self.is_first_delta.lock().unwrap();
|
||
if *first {
|
||
print!("\n🤖 ");
|
||
*first = false;
|
||
}
|
||
}
|
||
TextBlockEvent::Delta(text) => {
|
||
print!("{}", text);
|
||
io::stdout().flush().ok();
|
||
}
|
||
TextBlockEvent::Stop(_) => {
|
||
println!();
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
/// ツール呼び出しを表示するハンドラー
|
||
struct ToolCallPrinter;
|
||
|
||
impl Handler<ToolUseBlockKind> for ToolCallPrinter {
|
||
type Scope = String;
|
||
|
||
fn on_event(&mut self, json_buffer: &mut String, event: &ToolUseBlockEvent) {
|
||
match event {
|
||
ToolUseBlockEvent::Start(start) => {
|
||
println!("\n🔧 Calling tool: {}", start.name);
|
||
}
|
||
ToolUseBlockEvent::InputJsonDelta(json) => {
|
||
json_buffer.push_str(json);
|
||
}
|
||
ToolUseBlockEvent::Stop(_) => {
|
||
println!(" Args: {}", json_buffer);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// =============================================================================
|
||
// クライアント作成
|
||
// =============================================================================
|
||
|
||
/// プロバイダに応じたAPIキーを取得
|
||
fn get_api_key(args: &Args) -> Result<String, String> {
|
||
// CLI引数のAPIキーが優先
|
||
if let Some(ref key) = args.api_key {
|
||
return Ok(key.clone());
|
||
}
|
||
|
||
// プロバイダに応じた環境変数を確認
|
||
if let Some(env_var) = args.provider.env_var_name() {
|
||
std::env::var(env_var).map_err(|_| {
|
||
format!(
|
||
"API key required. Set {} environment variable or use --api-key",
|
||
env_var
|
||
)
|
||
})
|
||
} else {
|
||
// Ollamaなどはキー不要
|
||
Ok(String::new())
|
||
}
|
||
}
|
||
|
||
/// プロバイダに応じたクライアントを作成
|
||
fn create_client(args: &Args) -> Result<Box<dyn LlmClient>, String> {
|
||
let model = args
|
||
.model
|
||
.clone()
|
||
.unwrap_or_else(|| args.provider.default_model().to_string());
|
||
|
||
let api_key = get_api_key(args)?;
|
||
|
||
match args.provider {
|
||
Provider::Anthropic => {
|
||
let client = AnthropicClient::new(&api_key, &model);
|
||
Ok(Box::new(client))
|
||
}
|
||
Provider::Gemini => {
|
||
let client = GeminiClient::new(&api_key, &model);
|
||
Ok(Box::new(client))
|
||
}
|
||
Provider::Openai => {
|
||
let client = OpenAIClient::new(&api_key, &model);
|
||
Ok(Box::new(client))
|
||
}
|
||
Provider::Ollama => {
|
||
let client = OllamaClient::new(&model);
|
||
Ok(Box::new(client))
|
||
}
|
||
}
|
||
}
|
||
|
||
// =============================================================================
|
||
// メイン
|
||
// =============================================================================
|
||
|
||
#[tokio::main]
|
||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||
// .envファイルを読み込む
|
||
dotenv::dotenv().ok();
|
||
|
||
// ロギング初期化
|
||
// RUST_LOG=debug cargo run --example worker_cli ... で詳細ログ表示
|
||
// デフォルトは warn レベル、RUST_LOG 環境変数で上書き可能
|
||
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("warn"));
|
||
|
||
tracing_subscriber::fmt()
|
||
.with_env_filter(filter)
|
||
.with_target(true)
|
||
.init();
|
||
|
||
// CLI引数をパース
|
||
let args = Args::parse();
|
||
|
||
info!(
|
||
provider = ?args.provider,
|
||
model = ?args.model,
|
||
"Starting worker CLI"
|
||
);
|
||
|
||
// 対話モードかワンショットモードか
|
||
let is_interactive = args.prompt.is_none();
|
||
|
||
// モデル名(表示用)
|
||
let model_name = args
|
||
.model
|
||
.clone()
|
||
.unwrap_or_else(|| args.provider.default_model().to_string());
|
||
|
||
if is_interactive {
|
||
let title = format!("Worker CLI - {}", args.provider.display_name());
|
||
let border_len = title.len() + 6;
|
||
println!("╔{}╗", "═".repeat(border_len));
|
||
println!("║ {} ║", title);
|
||
println!("╚{}╝", "═".repeat(border_len));
|
||
println!();
|
||
println!("Provider: {}", args.provider.display_name());
|
||
println!("Model: {}", model_name);
|
||
if let Some(ref system) = args.system {
|
||
println!("System: {}", system);
|
||
}
|
||
if args.no_tools {
|
||
println!("Tools: disabled");
|
||
} else {
|
||
println!("Tools:");
|
||
println!(" • get_current_time - Get the current timestamp");
|
||
println!(" • calculate - Perform arithmetic (add, subtract, multiply, divide)");
|
||
}
|
||
println!();
|
||
println!("Type 'quit' or 'exit' to end the session.");
|
||
println!("─────────────────────────────────────────────────");
|
||
}
|
||
|
||
// クライアント作成
|
||
let client = match create_client(&args) {
|
||
Ok(c) => c,
|
||
Err(e) => {
|
||
eprintln!("❌ Error: {}", e);
|
||
std::process::exit(1);
|
||
}
|
||
};
|
||
|
||
// Worker作成
|
||
let mut worker = Worker::new(client);
|
||
|
||
// システムプロンプトを設定
|
||
if let Some(ref system_prompt) = args.system {
|
||
worker.set_system_prompt(system_prompt);
|
||
}
|
||
|
||
// ツール登録(--no-tools でなければ)
|
||
if !args.no_tools {
|
||
let app = AppContext;
|
||
worker.register_tool(app.get_current_time_tool());
|
||
worker.register_tool(app.calculate_tool());
|
||
}
|
||
|
||
// ストリーミング表示用ハンドラーを登録
|
||
worker
|
||
.timeline_mut()
|
||
.on_text_block(StreamingPrinter::new())
|
||
.on_tool_use_block(ToolCallPrinter);
|
||
|
||
// 会話履歴
|
||
let mut history: Vec<Message> = Vec::new();
|
||
|
||
// ワンショットモード
|
||
if let Some(prompt) = args.prompt {
|
||
history.push(Message::user(&prompt));
|
||
|
||
match worker.run(history).await {
|
||
Ok(_) => {}
|
||
Err(e) => {
|
||
eprintln!("\n❌ Error: {}", e);
|
||
std::process::exit(1);
|
||
}
|
||
}
|
||
|
||
return Ok(());
|
||
}
|
||
|
||
// 対話ループ
|
||
loop {
|
||
print!("\n👤 You: ");
|
||
io::stdout().flush()?;
|
||
|
||
let mut input = String::new();
|
||
io::stdin().read_line(&mut input)?;
|
||
let input = input.trim();
|
||
|
||
if input.is_empty() {
|
||
continue;
|
||
}
|
||
|
||
if input == "quit" || input == "exit" {
|
||
println!("\n👋 Goodbye!");
|
||
break;
|
||
}
|
||
|
||
// ユーザーメッセージを履歴に追加
|
||
history.push(Message::user(input));
|
||
|
||
// Workerを実行
|
||
match worker.run(history.clone()).await {
|
||
Ok(new_history) => {
|
||
history = new_history;
|
||
}
|
||
Err(e) => {
|
||
eprintln!("\n❌ Error: {}", e);
|
||
// エラー時は最後のユーザーメッセージを削除
|
||
history.pop();
|
||
}
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|