llm_worker_rs/worker/examples/worker_cli.rs

497 lines
15 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! Worker を用いた対話型 CLI クライアント
//!
//! 複数のLLMプロバイダ(Anthropic, Gemini, OpenAI, Ollama)と対話するCLIアプリケーション。
//! ツールの登録と実行、ストリーミングレスポンスの表示をデモする。
//!
//! ## 使用方法
//!
//! ```bash
//! # .envファイルにAPIキーを設定
//! echo "ANTHROPIC_API_KEY=your-api-key" > .env
//! echo "GEMINI_API_KEY=your-api-key" >> .env
//! echo "OPENAI_API_KEY=your-api-key" >> .env
//!
//! # Anthropic (デフォルト)
//! cargo run --example worker_cli
//!
//! # Gemini
//! cargo run --example worker_cli -- --provider gemini
//!
//! # OpenAI
//! cargo run --example worker_cli -- --provider openai --model gpt-4o
//!
//! # Ollama (ローカル)
//! cargo run --example worker_cli -- --provider ollama --model llama3.2
//!
//! # オプション指定
//! cargo run --example worker_cli -- --provider anthropic --model claude-3-haiku-20240307 --system "You are a helpful assistant."
//!
//! # ヘルプ表示
//! cargo run --example worker_cli -- --help
//! ```
use std::collections::HashMap;
use std::io::{self, Write};
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use tracing::info;
use tracing_subscriber::EnvFilter;
use clap::{Parser, ValueEnum};
use worker::{
ControlFlow, Handler, HookError, TextBlockEvent, TextBlockKind, ToolResult, ToolUseBlockEvent,
ToolUseBlockKind, Worker, WorkerHook,
llm_client::{
LlmClient,
providers::{
anthropic::AnthropicClient, gemini::GeminiClient, ollama::OllamaClient,
openai::OpenAIClient,
},
},
};
use worker_macros::tool_registry;
// 必要なマクロ展開用インポート
use schemars;
use serde;
// =============================================================================
// プロバイダ定義
// =============================================================================
// NOTE: clap's `ValueEnum` derive uses the `///` comments on the variants below as
// the possible-value help text, so treat them as user-facing strings.
/// LLM backends selectable with `--provider`.
#[derive(Debug, Clone, Copy, ValueEnum, Default)]
enum Provider {
    /// Anthropic Claude
    #[default]
    Anthropic,
    /// Google Gemini
    Gemini,
    /// OpenAI GPT
    Openai,
    /// Ollama (ローカル)
    Ollama,
}

impl Provider {
    /// Per-provider constants as `(default model, display name, API-key env var)`.
    ///
    /// Keeping them in one table means adding a provider touches a single `match`.
    fn profile(self) -> (&'static str, &'static str, Option<&'static str>) {
        match self {
            Self::Anthropic => (
                "claude-sonnet-4-20250514",
                "Anthropic Claude",
                Some("ANTHROPIC_API_KEY"),
            ),
            Self::Gemini => ("gemini-2.0-flash", "Google Gemini", Some("GEMINI_API_KEY")),
            Self::Openai => ("gpt-4o", "OpenAI GPT", Some("OPENAI_API_KEY")),
            // Ollama talks to a local server, so there is no API key to look up.
            Self::Ollama => ("llama3.2", "Ollama (Local)", None),
        }
    }

    /// Model used when `--model` is not given.
    fn default_model(&self) -> &'static str {
        let (model, _, _) = self.profile();
        model
    }

    /// Human-readable provider name shown in the interactive banner.
    fn display_name(&self) -> &'static str {
        let (_, name, _) = self.profile();
        name
    }

    /// Environment variable holding the API key, or `None` if the provider needs no key.
    fn env_var_name(&self) -> Option<&'static str> {
        let (_, _, env_var) = self.profile();
        env_var
    }
}
// =============================================================================
// CLI引数定義
// =============================================================================
/// 複数のLLMプロバイダに対応した対話型CLIクライアント
//
// Command-line arguments. clap's derive turns the `///` comments on the fields below
// into `--help` text, so they are user-facing strings rather than source-only docs.
#[derive(Parser, Debug)]
#[command(name = "worker-cli")]
#[command(about = "Interactive CLI client for multiple LLM providers using Worker")]
#[command(version)]
struct Args {
    /// 使用するプロバイダ
    #[arg(long, value_enum, default_value_t = Provider::Anthropic)]
    provider: Provider,
    /// 使用するモデル名(未指定時はプロバイダのデフォルト)
    #[arg(short, long)]
    model: Option<String>,
    /// システムプロンプト
    #[arg(short, long)]
    system: Option<String>,
    /// ツールを無効化
    #[arg(long, default_value = "false")]
    no_tools: bool,
    /// 最初のメッセージ(指定するとそれを送信して終了)
    #[arg(short = 'p', long)]
    prompt: Option<String>,
    // `--api-key` takes precedence over the provider's environment variable
    // (see `get_api_key`); the help text must say so unambiguously.
    /// APIキー(環境変数より優先)
    #[arg(long)]
    api_key: Option<String>,
}
// =============================================================================
// ツール定義
// =============================================================================
/// Tool host for this example.
///
/// `#[tool_registry]` generates a `<name>_tool()` constructor for each `#[tool]` method;
/// `main` registers those on the worker. The `///` text on each tool is presumably used
/// as the tool description sent to the model (TODO confirm against `worker_macros`), so
/// it is kept verbatim.
#[derive(Clone)]
struct AppContext;

#[tool_registry]
impl AppContext {
    /// 現在の日時を取得する
    ///
    /// システムの現在の日付と時刻を返します。
    #[tool]
    fn get_current_time(&self) -> String {
        let secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_secs(),
            // Only possible when the system clock is set before 1970. Report it to the
            // model rather than panicking and taking the whole CLI down mid-conversation.
            Err(_) => return "System clock is set before the Unix epoch".to_string(),
        };
        // The description promises a date and time, so give a readable UTC timestamp
        // and keep the raw epoch seconds alongside it.
        format!(
            "Current time: {} (Unix timestamp: {})",
            format_utc_timestamp(secs),
            secs
        )
    }

    /// 簡単な計算を行う
    ///
    /// 2つの数値の四則演算を実行します。
    #[tool]
    fn calculate(&self, a: f64, b: f64, operation: String) -> Result<String, String> {
        // Accept both word and symbol spellings of each operator.
        let result = match operation.as_str() {
            "add" | "+" => a + b,
            "subtract" | "-" => a - b,
            "multiply" | "*" => a * b,
            "divide" | "/" => {
                // Reject x/0 explicitly instead of returning inf/NaN to the model.
                if b == 0.0 {
                    return Err("Cannot divide by zero".to_string());
                }
                a / b
            }
            _ => return Err(format!("Unknown operation: {}", operation)),
        };
        Ok(format!("{} {} {} = {}", a, operation, b, result))
    }
}

/// Format seconds since the Unix epoch as ISO-8601 UTC (`YYYY-MM-DDTHH:MM:SSZ`).
///
/// Implements Howard Hinnant's `civil_from_days` algorithm, specialised to non-negative
/// day counts, so no date crate is needed.
fn format_utc_timestamp(secs: u64) -> String {
    let days = secs / 86_400;
    let rem = secs % 86_400;
    let (hour, minute, second) = (rem / 3_600, rem % 3_600 / 60, rem % 60);

    let z = days + 719_468; // shift the epoch from 1970-01-01 to 0000-03-01
    let era = z / 146_097; // 400-year eras
    let doe = z - era * 146_097; // day of era, [0, 146096]
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // year of era, [0, 399]
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year, [0, 365]
    let mp = (5 * doy + 2) / 153; // month index with March = 0, [0, 11]
    let day = doy - (153 * mp + 2) / 5 + 1; // [1, 31]
    let month = if mp < 10 { mp + 3 } else { mp - 9 }; // [1, 12]
    // January and February belong to the following civil year in the March-based scheme.
    let year = yoe + era * 400 + u64::from(month <= 2);

    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
        year, month, day, hour, minute, second
    )
}
// =============================================================================
// ストリーミング表示用ハンドラー
// =============================================================================
/// Handler that streams assistant text to stdout as it arrives.
///
/// Each text block starts on a fresh line prefixed with `🤖 `. Deltas are written and
/// flushed immediately, and the block ends with a newline.
#[derive(Default)]
struct StreamingPrinter;

impl StreamingPrinter {
    fn new() -> Self {
        StreamingPrinter
    }
}

impl Handler<TextBlockKind> for StreamingPrinter {
    type Scope = ();

    fn on_event(&mut self, _scope: &mut (), event: &TextBlockEvent) {
        match event {
            // Prefix every block. A "first block only" flag that lives for the whole
            // session is never reset, so every reply after the first interactive turn
            // would be printed with no marker at all.
            TextBlockEvent::Start(_) => {
                print!("\n🤖 ");
                io::stdout().flush().ok();
            }
            TextBlockEvent::Delta(text) => {
                print!("{}", text);
                // stdout is line-buffered; flush so partial deltas show up immediately.
                io::stdout().flush().ok();
            }
            TextBlockEvent::Stop(_) => println!(),
        }
    }
}
/// Handler that echoes tool invocations as they stream in.
///
/// At the start of each tool-use block it records `id -> name` in the shared map and
/// prints the tool name. It buffers the streamed input JSON and prints the accumulated
/// arguments when the block stops.
struct ToolCallPrinter {
    call_names: Arc<Mutex<HashMap<String, String>>>,
}

impl ToolCallPrinter {
    fn new(call_names: Arc<Mutex<HashMap<String, String>>>) -> Self {
        ToolCallPrinter { call_names }
    }
}

/// Per-block state: the input JSON accumulated from `InputJsonDelta` events.
#[derive(Default)]
struct ToolCallPrinterScope {
    input_json: String,
}

impl Handler<ToolUseBlockKind> for ToolCallPrinter {
    type Scope = ToolCallPrinterScope;

    fn on_event(&mut self, scope: &mut Self::Scope, event: &ToolUseBlockEvent) {
        match event {
            ToolUseBlockEvent::Start(start) => {
                scope.input_json.clear();
                // Hold the lock only for the insert, not while printing.
                let mut names = self.call_names.lock().unwrap();
                names.insert(start.id.clone(), start.name.clone());
                drop(names);
                println!("\n🔧 Calling tool: {}", start.name);
            }
            ToolUseBlockEvent::InputJsonDelta(fragment) => scope.input_json.push_str(fragment),
            ToolUseBlockEvent::Stop(_) => {
                // If no input JSON was received, show `{}` rather than nothing.
                let shown = if scope.input_json.is_empty() {
                    "{}"
                } else {
                    scope.input_json.as_str()
                };
                println!(" Args: {}", shown);
                scope.input_json.clear();
            }
        }
    }
}
/// Worker hook that prints each tool result after the tool has run.
///
/// The result is labelled with the tool name recorded under its tool-use id in the
/// shared map. If there is no entry, the raw id is shown instead.
struct ToolResultPrinterHook {
    call_names: Arc<Mutex<HashMap<String, String>>>,
}

impl ToolResultPrinterHook {
    fn new(call_names: Arc<Mutex<HashMap<String, String>>>) -> Self {
        ToolResultPrinterHook { call_names }
    }
}

#[async_trait]
impl WorkerHook for ToolResultPrinterHook {
    /// Print the result, then always let the worker continue.
    async fn after_tool_call(
        &self,
        tool_result: &mut ToolResult,
    ) -> Result<ControlFlow, HookError> {
        // Remove (not just read) the entry so the map does not grow across calls.
        let recorded = self
            .call_names
            .lock()
            .unwrap()
            .remove(&tool_result.tool_use_id);
        let label = match recorded {
            Some(name) => name,
            None => tool_result.tool_use_id.clone(),
        };
        let marker = if tool_result.is_error { "❌" } else { "✅" };
        println!(" Result ({}): {} {}", label, marker, tool_result.content);
        Ok(ControlFlow::Continue)
    }
}
// =============================================================================
// クライアント作成
// =============================================================================
/// プロバイダに応じたAPIキーを取得
fn get_api_key(args: &Args) -> Result<String, String> {
// CLI引数のAPIキーが優先
if let Some(ref key) = args.api_key {
return Ok(key.clone());
}
// プロバイダに応じた環境変数を確認
if let Some(env_var) = args.provider.env_var_name() {
std::env::var(env_var).map_err(|_| {
format!(
"API key required. Set {} environment variable or use --api-key",
env_var
)
})
} else {
// Ollamaなどはキー不要
Ok(String::new())
}
}
/// Build the LLM client for the selected provider.
///
/// The model is `--model` if given, otherwise the provider's default. The key comes
/// from `get_api_key`; the Ollama client is constructed without one.
///
/// # Errors
/// Propagates the message from `get_api_key` when a required key is missing.
fn create_client(args: &Args) -> Result<Box<dyn LlmClient>, String> {
    let model: String = match &args.model {
        Some(name) => name.clone(),
        None => args.provider.default_model().to_owned(),
    };
    let api_key = get_api_key(args)?;

    let client: Box<dyn LlmClient> = match args.provider {
        Provider::Anthropic => Box::new(AnthropicClient::new(&api_key, &model)),
        Provider::Gemini => Box::new(GeminiClient::new(&api_key, &model)),
        Provider::Openai => Box::new(OpenAIClient::new(&api_key, &model)),
        Provider::Ollama => Box::new(OllamaClient::new(&model)),
    };
    Ok(client)
}
// =============================================================================
// メイン
// =============================================================================
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// .envファイルを読み込む
dotenv::dotenv().ok();
// ロギング初期化
// RUST_LOG=debug cargo run --example worker_cli ... で詳細ログ表示
// デフォルトは warn レベル、RUST_LOG 環境変数で上書き可能
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("warn"));
tracing_subscriber::fmt()
.with_env_filter(filter)
.with_target(true)
.init();
// CLI引数をパース
let args = Args::parse();
info!(
provider = ?args.provider,
model = ?args.model,
"Starting worker CLI"
);
// 対話モードかワンショットモードか
let is_interactive = args.prompt.is_none();
// モデル名(表示用)
let model_name = args
.model
.clone()
.unwrap_or_else(|| args.provider.default_model().to_string());
if is_interactive {
let title = format!("Worker CLI - {}", args.provider.display_name());
let border_len = title.len() + 6;
println!("{}", "".repeat(border_len));
println!("{}", title);
println!("{}", "".repeat(border_len));
println!();
println!("Provider: {}", args.provider.display_name());
println!("Model: {}", model_name);
if let Some(ref system) = args.system {
println!("System: {}", system);
}
if args.no_tools {
println!("Tools: disabled");
} else {
println!("Tools:");
println!(" • get_current_time - Get the current timestamp");
println!(" • calculate - Perform arithmetic (add, subtract, multiply, divide)");
}
println!();
println!("Type 'quit' or 'exit' to end the session.");
println!("─────────────────────────────────────────────────");
}
// クライアント作成
let client = match create_client(&args) {
Ok(c) => c,
Err(e) => {
eprintln!("❌ Error: {}", e);
std::process::exit(1);
}
};
// Worker作成
let mut worker = Worker::new(client);
let tool_call_names = Arc::new(Mutex::new(HashMap::new()));
// システムプロンプトを設定
if let Some(ref system_prompt) = args.system {
worker.set_system_prompt(system_prompt);
}
// ツール登録(--no-tools でなければ)
if !args.no_tools {
let app = AppContext;
worker.register_tool(app.get_current_time_tool());
worker.register_tool(app.calculate_tool());
}
// ストリーミング表示用ハンドラーを登録
worker
.timeline_mut()
.on_text_block(StreamingPrinter::new())
.on_tool_use_block(ToolCallPrinter::new(tool_call_names.clone()));
worker.add_hook(ToolResultPrinterHook::new(tool_call_names));
// ワンショットモード
if let Some(prompt) = args.prompt {
match worker.run(&prompt).await {
Ok(_) => {}
Err(e) => {
eprintln!("\n❌ Error: {}", e);
std::process::exit(1);
}
}
return Ok(());
}
// 対話ループ
loop {
print!("\n👤 You: ");
io::stdout().flush()?;
let mut input = String::new();
io::stdin().read_line(&mut input)?;
let input = input.trim();
if input.is_empty() {
continue;
}
if input == "quit" || input == "exit" {
println!("\n👋 Goodbye!");
break;
}
// Workerを実行Workerが履歴を管理
match worker.run(input).await {
Ok(_) => {}
Err(e) => {
eprintln!("\n❌ Error: {}", e);
}
}
}
Ok(())
}