// llm_worker_rs/llm-worker/examples/worker_cli.rs
//! Interactive CLI client using Worker
//!
//! A CLI application for interacting with multiple LLM providers (Anthropic, Gemini, OpenAI, Ollama).
//! Demonstrates tool registration and execution, as well as streaming response display.
//!
//! ## Usage
//!
//! ```bash
//! # Set API keys in .env file
//! echo "ANTHROPIC_API_KEY=your-api-key" > .env
//! echo "GEMINI_API_KEY=your-api-key" >> .env
//! echo "OPENAI_API_KEY=your-api-key" >> .env
//!
//! # Anthropic (default)
//! cargo run --example worker_cli
//!
//! # Gemini
//! cargo run --example worker_cli -- --provider gemini
//!
//! # OpenAI
//! cargo run --example worker_cli -- --provider openai --model gpt-4o
//!
//! # Ollama (local)
//! cargo run --example worker_cli -- --provider ollama --model llama3.2
//!
//! # With options
//! cargo run --example worker_cli -- --provider anthropic --model claude-3-haiku-20240307 --system "You are a helpful assistant."
//!
//! # Show help
//! cargo run --example worker_cli -- --help
//! ```
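//!
//! Set `RUST_LOG=debug` for verbose logs; the default log level is `warn`.
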
use std::collections::HashMap;
use std::io::{self, Write};
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use tracing::info;
use tracing_subscriber::EnvFilter;
use clap::{Parser, ValueEnum};
use llm_worker::{
    Worker,
    hook::{Hook, HookError, PostToolCall, PostToolCallContext, PostToolCallResult},
    llm_client::{
        LlmClient,
        providers::{
            anthropic::AnthropicClient, gemini::GeminiClient, ollama::OllamaClient,
            openai::OpenAIClient,
        },
    },
    timeline::{Handler, TextBlockEvent, TextBlockKind, ToolUseBlockEvent, ToolUseBlockKind},
};
use llm_worker_macros::tool_registry;
// Required imports for macro expansion
use schemars;
use serde;

// =============================================================================
// Provider Definition
// =============================================================================

/// Available LLM providers
#[derive(Debug, Clone, Copy, ValueEnum, Default)]
enum Provider {
    /// Anthropic Claude
    #[default]
    Anthropic,
    /// Google Gemini
    Gemini,
    /// OpenAI GPT
    Openai,
    /// Ollama (local)
    Ollama,
}

impl Provider {
    /// Default model for the provider
    fn default_model(&self) -> &'static str {
        match self {
            Provider::Anthropic => "claude-sonnet-4-20250514",
            Provider::Gemini => "gemini-2.0-flash",
            Provider::Openai => "gpt-4o",
            Provider::Ollama => "llama3.2",
        }
    }

    /// Display name for the provider
    fn display_name(&self) -> &'static str {
        match self {
            Provider::Anthropic => "Anthropic Claude",
            Provider::Gemini => "Google Gemini",
            Provider::Openai => "OpenAI GPT",
            Provider::Ollama => "Ollama (Local)",
        }
    }

    /// Environment variable name for the API key
    fn env_var_name(&self) -> Option<&'static str> {
        match self {
            Provider::Anthropic => Some("ANTHROPIC_API_KEY"),
            Provider::Gemini => Some("GEMINI_API_KEY"),
            Provider::Openai => Some("OPENAI_API_KEY"),
            Provider::Ollama => None, // Ollama runs locally, no key needed
        }
    }
}

// =============================================================================
// CLI Argument Definition
// =============================================================================

/// Interactive CLI client supporting multiple LLM providers
#[derive(Parser, Debug)]
#[command(name = "worker-cli")]
#[command(about = "Interactive CLI client for multiple LLM providers using Worker")]
#[command(version)]
struct Args {
    /// Provider to use
    #[arg(long, value_enum, default_value_t = Provider::Anthropic)]
    provider: Provider,

    /// Model name to use (defaults to the provider's default if not specified)
    #[arg(short, long)]
    model: Option<String>,

    /// System prompt
    #[arg(short, long)]
    system: Option<String>,

    /// Disable tools
    #[arg(long, default_value = "false")]
    no_tools: bool,

    /// Initial message (if specified, sends it and exits)
    #[arg(short = 'p', long)]
    prompt: Option<String>,

    /// API key (takes precedence over the environment variable)
    #[arg(long)]
    api_key: Option<String>,
}

// =============================================================================
// Tool Definition
// =============================================================================

/// Application context
#[derive(Clone)]
struct AppContext;

#[tool_registry]
impl AppContext {
    /// Get the current date and time
    ///
    /// Returns the system's current date and time.
    #[tool]
    fn get_current_time(&self) -> String {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        // Simple conversion from Unix timestamp
        format!("Current Unix timestamp: {}", now)
    }
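
    // NOTE: this returns a raw Unix timestamp; a real application would likely
    // format it as a human-readable date/time (e.g. with the `chrono` crate).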

    /// Perform a simple calculation
    ///
    /// Executes arithmetic operations on two numbers.
    #[tool]
    fn calculate(&self, a: f64, b: f64, operation: String) -> Result<String, String> {
        let result = match operation.as_str() {
            "add" | "+" => a + b,
            "subtract" | "-" => a - b,
            "multiply" | "*" => a * b,
            "divide" | "/" => {
                if b == 0.0 {
                    return Err("Cannot divide by zero".to_string());
                }
                a / b
            }
            _ => return Err(format!("Unknown operation: {}", operation)),
        };
        Ok(format!("{} {} {} = {}", a, operation, b, result))
    }
}
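
// NOTE: #[tool_registry] generates a `<name>_definition()` method for each
// #[tool] function (e.g. `get_current_time_definition()`); main() passes these
// definitions to `Worker::register_tool`.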

// =============================================================================
// Streaming Display Handlers
// =============================================================================

/// Handler that outputs text in real-time
struct StreamingPrinter {
    is_first_delta: Arc<Mutex<bool>>,
}

impl StreamingPrinter {
    fn new() -> Self {
        Self {
            is_first_delta: Arc::new(Mutex::new(true)),
        }
    }
}
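
// Text blocks stream as Start -> Delta* -> Stop: the "🤖 " prefix is printed
// on the first Start only, and stdout is flushed on every Delta so the
// response appears live.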
impl Handler<TextBlockKind> for StreamingPrinter {
    type Scope = ();

    fn on_event(&mut self, _scope: &mut (), event: &TextBlockEvent) {
        match event {
            TextBlockEvent::Start(_) => {
                let mut first = self.is_first_delta.lock().unwrap();
                if *first {
                    print!("\n🤖 ");
                    *first = false;
                }
            }
            TextBlockEvent::Delta(text) => {
                print!("{}", text);
                io::stdout().flush().ok();
            }
            TextBlockEvent::Stop(_) => {
                println!();
            }
        }
    }
}

/// Handler that displays tool calls
struct ToolCallPrinter {
    call_names: Arc<Mutex<HashMap<String, String>>>,
}

impl ToolCallPrinter {
    fn new(call_names: Arc<Mutex<HashMap<String, String>>>) -> Self {
        Self { call_names }
    }
}
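
/// Per-call scope: buffers the input JSON fragments streamed for the current
/// tool-use block (cleared on both Start and Stop).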
#[derive(Default)]
struct ToolCallPrinterScope {
    input_json: String,
}

impl Handler<ToolUseBlockKind> for ToolCallPrinter {
    type Scope = ToolCallPrinterScope;

    fn on_event(&mut self, scope: &mut Self::Scope, event: &ToolUseBlockEvent) {
        match event {
            ToolUseBlockEvent::Start(start) => {
                scope.input_json.clear();
                self.call_names
                    .lock()
                    .unwrap()
                    .insert(start.id.clone(), start.name.clone());
                println!("\n🔧 Calling tool: {}", start.name);
            }
            ToolUseBlockEvent::InputJsonDelta(json) => {
                scope.input_json.push_str(json);
            }
            ToolUseBlockEvent::Stop(_) => {
                if scope.input_json.is_empty() {
                    println!("   Args: {{}}");
                } else {
                    println!("   Args: {}", scope.input_json);
                }
                scope.input_json.clear();
            }
        }
    }
}

/// Hook that displays tool execution results
struct ToolResultPrinterHook {
    call_names: Arc<Mutex<HashMap<String, String>>>,
}

impl ToolResultPrinterHook {
    fn new(call_names: Arc<Mutex<HashMap<String, String>>>) -> Self {
        Self { call_names }
    }
}
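
// PostToolCall hooks run after each tool execution; the tool name recorded by
// ToolCallPrinter is looked up (and removed) by `tool_use_id` so the result
// can be labeled with the originating tool's name.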
#[async_trait]
impl Hook<PostToolCall> for ToolResultPrinterHook {
    async fn call(&self, ctx: &mut PostToolCallContext) -> Result<PostToolCallResult, HookError> {
        let name = self
            .call_names
            .lock()
            .unwrap()
            .remove(&ctx.result.tool_use_id)
            .unwrap_or_else(|| ctx.result.tool_use_id.clone());
        if ctx.result.is_error {
            println!("   Result ({}): ❌ {}", name, ctx.result.content);
        } else {
            println!("   Result ({}): ✅ {}", name, ctx.result.content);
        }
        Ok(PostToolCallResult::Continue)
    }
}

// =============================================================================
// Client Creation
// =============================================================================

/// Get API key based on provider
fn get_api_key(args: &Args) -> Result<String, String> {
    // CLI argument API key takes precedence
    if let Some(ref key) = args.api_key {
        return Ok(key.clone());
    }

    // Check the environment variable based on the provider
    if let Some(env_var) = args.provider.env_var_name() {
        std::env::var(env_var).map_err(|_| {
            format!(
                "API key required. Set {} environment variable or use --api-key",
                env_var
            )
        })
    } else {
        // Ollama etc. don't need a key
        Ok(String::new())
    }
}

/// Create client based on provider
fn create_client(args: &Args) -> Result<Box<dyn LlmClient>, String> {
    let model = args
        .model
        .clone()
        .unwrap_or_else(|| args.provider.default_model().to_string());
    let api_key = get_api_key(args)?;

    match args.provider {
        Provider::Anthropic => {
            let client = AnthropicClient::new(&api_key, &model);
            Ok(Box::new(client))
        }
        Provider::Gemini => {
            let client = GeminiClient::new(&api_key, &model);
            Ok(Box::new(client))
        }
        Provider::Openai => {
            let client = OpenAIClient::new(&api_key, &model);
            Ok(Box::new(client))
        }
        Provider::Ollama => {
            let client = OllamaClient::new(&model);
            Ok(Box::new(client))
        }
    }
}

// =============================================================================
// Main
// =============================================================================

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Load .env file
    dotenv::dotenv().ok();

    // Initialize logging.
    // Use `RUST_LOG=debug cargo run --example worker_cli ...` for detailed logs.
    // Default is warn level; override with the RUST_LOG environment variable.
    let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("warn"));
    tracing_subscriber::fmt()
        .with_env_filter(filter)
        .with_target(true)
        .init();

    // Parse CLI arguments
    let args = Args::parse();
    info!(
        provider = ?args.provider,
        model = ?args.model,
        "Starting worker CLI"
    );

    // Interactive mode or one-shot mode
    let is_interactive = args.prompt.is_none();

    // Model name (for display)
    let model_name = args
        .model
        .clone()
        .unwrap_or_else(|| args.provider.default_model().to_string());

    if is_interactive {
        let title = format!("Worker CLI - {}", args.provider.display_name());
        let border_len = title.len() + 6;
        println!("{}", "═".repeat(border_len));
        println!("{}", title);
        println!("{}", "═".repeat(border_len));
        println!();
        println!("Provider: {}", args.provider.display_name());
        println!("Model: {}", model_name);
        if let Some(ref system) = args.system {
            println!("System: {}", system);
        }
        if args.no_tools {
            println!("Tools: disabled");
        } else {
            println!("Tools:");
            println!("  • get_current_time - Get the current timestamp");
            println!("  • calculate - Perform arithmetic (add, subtract, multiply, divide)");
        }
        println!();
        println!("Type 'quit' or 'exit' to end the session.");
        println!("─────────────────────────────────────────────────");
    }

    // Create client
    let client = match create_client(&args) {
        Ok(c) => c,
        Err(e) => {
            eprintln!("❌ Error: {}", e);
            std::process::exit(1);
        }
    };

    // Create Worker
    let mut worker = Worker::new(client);
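
    // Shared map from tool_use_id -> tool name: ToolCallPrinter records each
    // call's name here and ToolResultPrinterHook reads it back to label results.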
    let tool_call_names = Arc::new(Mutex::new(HashMap::new()));

    // Set system prompt
    if let Some(ref system_prompt) = args.system {
        worker.set_system_prompt(system_prompt);
    }

    // Register tools (unless --no-tools)
    if !args.no_tools {
        let app = AppContext;
        worker
            .register_tool(app.get_current_time_definition())
            .unwrap();
        worker.register_tool(app.calculate_definition()).unwrap();
    }

    // Register streaming display handlers
    worker
        .timeline_mut()
        .on_text_block(StreamingPrinter::new())
        .on_tool_use_block(ToolCallPrinter::new(tool_call_names.clone()));
    worker.add_post_tool_call_hook(ToolResultPrinterHook::new(tool_call_names));

    // One-shot mode
    if let Some(prompt) = args.prompt {
        match worker.run(&prompt).await {
            Ok(_) => {}
            Err(e) => {
                eprintln!("\n❌ Error: {}", e);
                std::process::exit(1);
            }
        }
        return Ok(());
    }

    // Interactive loop
    loop {
        print!("\n👤 You: ");
        io::stdout().flush()?;

        let mut input = String::new();
        io::stdin().read_line(&mut input)?;
        let input = input.trim();

        if input.is_empty() {
            continue;
        }
        if input == "quit" || input == "exit" {
            println!("\n👋 Goodbye!");
            break;
        }

        // Run Worker (Worker manages history)
        match worker.run(input).await {
            Ok(_) => {}
            Err(e) => {
                eprintln!("\n❌ Error: {}", e);
            }
        }
    }

    Ok(())
}