102 lines
3.0 KiB
Rust
102 lines
3.0 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::{Arc, Mutex};
|
|
use worker::{
|
|
PromptError, ResourceLoader, Role, Worker,
|
|
plugin::{PluginRegistry, ProviderPlugin, example_provider::CustomProviderPlugin},
|
|
};
|
|
|
|
struct FsLoader;
|
|
|
|
impl ResourceLoader for FsLoader {
|
|
fn load(&self, identifier: &str) -> Result<String, PromptError> {
|
|
std::fs::read_to_string(identifier)
|
|
.map_err(|e| PromptError::FileNotFound(format!("{}: {}", identifier, e)))
|
|
}
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
// Initialize tracing for debugging
|
|
tracing_subscriber::fmt::init();
|
|
|
|
// Create a plugin registry
|
|
let plugin_registry = Arc::new(Mutex::new(PluginRegistry::new()));
|
|
|
|
// Create and initialize a custom provider plugin
|
|
let mut custom_plugin = CustomProviderPlugin::new();
|
|
|
|
let mut config = HashMap::new();
|
|
config.insert(
|
|
"base_url".to_string(),
|
|
serde_json::Value::String("https://api.custom-provider.com".to_string()),
|
|
);
|
|
config.insert(
|
|
"timeout".to_string(),
|
|
serde_json::Value::Number(serde_json::Number::from(60)),
|
|
);
|
|
|
|
custom_plugin.initialize(config).await?;
|
|
|
|
// Register the plugin
|
|
{
|
|
let mut registry = plugin_registry.lock().unwrap();
|
|
registry.register(Arc::new(custom_plugin))?;
|
|
}
|
|
|
|
// List available plugins
|
|
{
|
|
let registry = plugin_registry.lock().unwrap();
|
|
let plugins = registry.list();
|
|
println!("Available plugins:");
|
|
for plugin in plugins {
|
|
println!(
|
|
" - {} ({}): {}",
|
|
plugin.name, plugin.id, plugin.description
|
|
);
|
|
println!(" Supported models: {:?}", plugin.supported_models);
|
|
}
|
|
}
|
|
|
|
// Create a Worker instance using the plugin
|
|
let role = Role::new(
|
|
"assistant",
|
|
"A helpful AI assistant",
|
|
"You are a helpful, harmless, and honest AI assistant powered by a custom LLM provider.",
|
|
);
|
|
|
|
let worker = Worker::builder()
|
|
.plugin("custom-provider", plugin_registry.clone())
|
|
.model("custom-turbo")
|
|
.api_key("__plugin__", "custom-1234567890abcdefghijklmnop")
|
|
.resource_loader(FsLoader)
|
|
.role(role)
|
|
.build()?;
|
|
|
|
println!("\nWorker created successfully with custom provider plugin!");
|
|
|
|
// Example: List plugins from the worker
|
|
let plugin_list = worker.list_plugins()?;
|
|
println!("\nPlugins registered in worker:");
|
|
for metadata in plugin_list {
|
|
println!(
|
|
" - {}: v{} by {}",
|
|
metadata.name, metadata.version, metadata.author
|
|
);
|
|
}
|
|
|
|
// Load plugins from directory (if dynamic loading is enabled)
|
|
#[cfg(feature = "dynamic-loading")]
|
|
{
|
|
use std::path::Path;
|
|
|
|
let plugin_dir = Path::new("./plugins");
|
|
if plugin_dir.exists() {
|
|
let mut worker = worker;
|
|
worker.load_plugins_from_directory(plugin_dir).await?;
|
|
println!("\nLoaded plugins from directory: {:?}", plugin_dir);
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|