From 32be6075c12aa44b4a4594b1f6bfa66af327addc Mon Sep 17 00:00:00 2001 From: Hare Date: Wed, 10 Jun 2026 18:16:23 +0900 Subject: [PATCH] feat: add setup model command --- Cargo.lock | 1 + crates/tui/Cargo.toml | 1 + crates/tui/src/lib.rs | 1 + crates/tui/src/setup_model.rs | 330 ++++++++++++++++++++++++++++++++++ crates/yoi/src/main.rs | 24 +++ 5 files changed, 357 insertions(+) create mode 100644 crates/tui/src/setup_model.rs diff --git a/Cargo.lock b/Cargo.lock index 7b091d61..6db38a35 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3965,6 +3965,7 @@ dependencies = [ "pod-registry", "pod-store", "protocol", + "provider", "pulldown-cmark", "ratatui", "secrets", diff --git a/crates/tui/Cargo.toml b/crates/tui/Cargo.toml index dade66c8..0fc0b5de 100644 --- a/crates/tui/Cargo.toml +++ b/crates/tui/Cargo.toml @@ -19,6 +19,7 @@ secrets = { workspace = true } session-store = { workspace = true } pod-store = { workspace = true } pod-registry = { workspace = true } +provider = { workspace = true } ticket = { workspace = true } serde = { workspace = true, features = ["derive"] } pulldown-cmark = { version = "0.13.3", default-features = false } diff --git a/crates/tui/src/lib.rs b/crates/tui/src/lib.rs index 1232759d..6f71e229 100644 --- a/crates/tui/src/lib.rs +++ b/crates/tui/src/lib.rs @@ -12,6 +12,7 @@ mod picker; mod pod_list; mod role_session_registry; mod scroll; +pub mod setup_model; mod single_pod; mod spawn; mod task; diff --git a/crates/tui/src/setup_model.rs b/crates/tui/src/setup_model.rs new file mode 100644 index 00000000..33c43d07 --- /dev/null +++ b/crates/tui/src/setup_model.rs @@ -0,0 +1,330 @@ +use std::collections::BTreeMap; +use std::io::{self, Write}; +use std::path::{Path, PathBuf}; +use std::process::ExitCode; + +use provider::catalog::{self, AuthHint, ModelEntry, ProviderEntry}; +use toml::Value; +use toml::map::Map; + +const GENERATED_PROFILE_NAME: &str = "default"; +const GENERATED_PROFILE_PATH: &str = "profiles/default.lua"; +const GENERATED_PROFILE_DESCRIPTION: &str = "Generated by yoi setup-model"; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ModelChoice { + pub model_ref: String, + pub provider_id: String, + pub provider_display_name: String, + pub model_id: String, + pub auth_hint: AuthHint, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct WrittenSetupConfig { + pub registry_path: PathBuf, + pub profile_path: PathBuf, + pub model_ref: String, +} + +pub async fn launch() -> ExitCode { + match run_interactive_setup() { + Ok(written) => { + println!( + "Saved default Profile `{}` for model `{}`", + GENERATED_PROFILE_NAME, written.model_ref + ); + println!("registry: {}", written.registry_path.display()); + println!("profile: {}", written.profile_path.display()); + ExitCode::SUCCESS + } + Err(err) => { + eprintln!("yoi setup-model: {err}"); + ExitCode::FAILURE + } + } +} + +fn run_interactive_setup() -> Result> { + let config_dir = manifest::paths::config_dir().ok_or_else(|| { + io::Error::new( + io::ErrorKind::NotFound, + "could not determine yoi config directory", + ) + })?; + let choices = load_model_choices()?; + let choice = prompt_model_choice(&choices)?; + write_default_profile_config(&config_dir, &choice.model_ref) +} + +pub fn load_model_choices() -> Result, catalog::CatalogError> { + let providers = catalog::load_providers()? + .into_iter() + .map(|provider| (provider.id.clone(), provider)) + .collect::>(); + let mut choices = catalog::load_models()? + .into_iter() + .filter_map(|model| choice_from_model(model, &providers)) + .collect::>(); + choices.sort_by(|a, b| { + a.provider_display_name + .cmp(&b.provider_display_name) + .then_with(|| a.model_id.cmp(&b.model_id)) + .then_with(|| a.model_ref.cmp(&b.model_ref)) + }); + Ok(choices) +} + +fn choice_from_model( + model: ModelEntry, + providers: &BTreeMap, +) -> Option { + let provider = providers.get(&model.provider)?; + Some(ModelChoice { + model_ref: format!("{}/{}", model.provider, model.id), + provider_id: provider.id.clone(), + provider_display_name: provider.display_name.clone(), + model_id: model.id, + auth_hint: provider.auth_hint.clone(), + }) +} + +fn prompt_model_choice( + choices: &[ModelChoice], +) -> Result<&ModelChoice, Box> { + if choices.is_empty() { + return Err("no models are configured in the model catalog".into()); + } + + println!("yoi setup-model"); + println!(); + println!("Choose the default model Profile to write under the user config directory."); + println!("This command only writes Profile config; it does not start or attach a Pod."); + println!(); + for (idx, choice) in choices.iter().enumerate() { + println!( + "{:>2}. {:<42} {} ({})", + idx + 1, + choice.model_ref, + choice.provider_display_name, + auth_hint_label(&choice.auth_hint), + ); + } + println!(); + print!("Select model [1]: "); + io::stdout().flush()?; + + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + let selection = input.trim(); + let index = if selection.is_empty() { + 0 + } else { + selection + .parse::() + .map_err(|_| format!("invalid model selection `{selection}`"))? + .checked_sub(1) + .ok_or("model selection starts at 1")? + }; + choices + .get(index) + .ok_or_else(|| format!("model selection {} is out of range", index + 1).into()) +} + +fn auth_hint_label(hint: &AuthHint) -> String { + match hint { + AuthHint::None => "no auth".to_string(), + AuthHint::ApiKey => "API key file".to_string(), + AuthHint::SecretRef { ref_ } => format!("secret `{ref_}`"), + AuthHint::CodexOAuth => "Codex OAuth".to_string(), + } +} + +pub fn write_default_profile_config( + config_dir: &Path, + model_ref: &str, +) -> Result> { + std::fs::create_dir_all(config_dir)?; + let profiles_dir = config_dir.join("profiles"); + std::fs::create_dir_all(&profiles_dir)?; + + let profile_path = config_dir.join(GENERATED_PROFILE_PATH); + std::fs::write(&profile_path, generated_profile_lua(model_ref))?; + + let registry_path = config_dir.join("profiles.toml"); + let mut document = read_registry_document(®istry_path)?; + set_default_profile_entry(&mut document)?; + std::fs::write(®istry_path, toml::to_string_pretty(&document)?)?; + + Ok(WrittenSetupConfig { + registry_path, + profile_path, + model_ref: model_ref.to_string(), + }) +} + +fn read_registry_document(path: &Path) -> Result> { + if !path.exists() { + return Ok(Value::Table(Map::new())); + } + let text = std::fs::read_to_string(path)?; + let value: Value = toml::from_str(&text)?; + if !value.is_table() { + return Err(format!("{} must contain a TOML table", path.display()).into()); + } + Ok(value) +} + +fn set_default_profile_entry(document: &mut Value) -> Result<(), Box> { + let table = document + .as_table_mut() + .ok_or("profiles.toml root must be a TOML table")?; + table.insert( + "default".to_string(), + Value::String(format!("user:{GENERATED_PROFILE_NAME}")), + ); + + let profile_value = table + .entry("profile".to_string()) + .or_insert_with(|| Value::Table(Map::new())); + let profile_table = profile_value + .as_table_mut() + .ok_or("profiles.toml `profile` must be a TOML table")?; + let mut entry = Map::new(); + entry.insert( + "path".to_string(), + Value::String(GENERATED_PROFILE_PATH.to_string()), + ); + entry.insert( + "description".to_string(), + Value::String(GENERATED_PROFILE_DESCRIPTION.to_string()), + ); + profile_table.insert(GENERATED_PROFILE_NAME.to_string(), Value::Table(entry)); + Ok(()) +} + +fn generated_profile_lua(model_ref: &str) -> String { + format!( + r#"local profile = require("yoi.profile") +local scope = require("yoi.scope") +local compact = require("yoi.compact") + +return profile {{ + slug = "default", + description = "Generated by yoi setup-model", + + scope = scope.workspace_write(), + + session = {{ + record_event_trace = true, + }}, + + worker = {{ + reasoning = "high", + }}, + + model = {{ + ref = "{}", + }}, + + compaction = compact.tokens {{ + threshold = 240000, + request_threshold = 270000, + worker_context_max_tokens = 100000, + }}, + + feature = {{ + task = {{ enabled = true }}, + memory = {{ enabled = true }}, + web = {{ enabled = true }}, + pods = {{ enabled = false }}, + ticket = {{ enabled = false, access = "lifecycle" }}, + ticket_orchestration = {{ enabled = false }}, + }}, + + memory = {{ + extract_threshold = 50000, + consolidation_threshold_files = 5, + consolidation_threshold_bytes = 50000, + }}, + + web = {{ + enabled = true, + search = {{ + provider = "brave", + api_key_secret = "web/brave/default", + }}, + }}, +}} +"#, + escape_lua_string(model_ref) + ) +} + +fn escape_lua_string(value: &str) -> String { + value + .chars() + .flat_map(|c| match c { + '\\' => "\\\\".chars().collect::>(), + '"' => "\\\"".chars().collect::>(), + '\n' => "\\n".chars().collect::>(), + '\r' => "\\r".chars().collect::>(), + '\t' => "\\t".chars().collect::>(), + other => vec![other], + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn write_default_profile_config_creates_registry_and_profile() { + let dir = tempfile::tempdir().unwrap(); + + let written = write_default_profile_config(dir.path(), "codex-oauth/gpt-5.5").unwrap(); + + assert_eq!(written.registry_path, dir.path().join("profiles.toml")); + assert_eq!( + written.profile_path, + dir.path().join("profiles/default.lua") + ); + + let registry = std::fs::read_to_string(&written.registry_path).unwrap(); + assert!(registry.contains("default = \"user:default\"")); + assert!(registry.contains("[profile.default]")); + assert!(registry.contains("path = \"profiles/default.lua\"")); + + let profile = std::fs::read_to_string(&written.profile_path).unwrap(); + assert!(profile.contains("slug = \"default\"")); + assert!(profile.contains("ref = \"codex-oauth/gpt-5.5\"")); + assert!(profile.contains("scope = scope.workspace_write()")); + } + + #[test] + fn write_default_profile_config_preserves_other_profile_entries() { + let dir = tempfile::tempdir().unwrap(); + std::fs::write( + dir.path().join("profiles.toml"), + r#"[profile.other] +path = "profiles/other.lua" +description = "keep me" +"#, + ) + .unwrap(); + + write_default_profile_config(dir.path(), "anthropic/claude-sonnet-4-6").unwrap(); + + let registry = std::fs::read_to_string(dir.path().join("profiles.toml")).unwrap(); + assert!(registry.contains("[profile.other]")); + assert!(registry.contains("path = \"profiles/other.lua\"")); + assert!(registry.contains("[profile.default]")); + assert!(registry.contains("default = \"user:default\"")); + } + + #[test] + fn escape_lua_string_escapes_quotes_and_slashes() { + assert_eq!(escape_lua_string("a\\b\"c"), "a\\\\b\\\"c"); + } +} diff --git a/crates/yoi/src/main.rs b/crates/yoi/src/main.rs index a86cc9cc..86eab03e 100644 --- a/crates/yoi/src/main.rs +++ b/crates/yoi/src/main.rs @@ -22,6 +22,7 @@ enum Mode { Ticket(ticket_cli::TicketCli), PodRuntime(Vec), Keys, + SetupModel, Tui { mode: LaunchMode, workspace_root: PathBuf, @@ -107,6 +108,7 @@ async fn main() -> ExitCode { }, Mode::PodRuntime(args) => pod::entrypoint::run_cli_from("yoi pod", args).await, Mode::Keys => tui::keys::launch().await, + Mode::SetupModel => tui::setup_model::launch().await, Mode::Tui { mode, workspace_root, @@ -183,6 +185,14 @@ fn parse_args_slice(args: &[String]) -> Result { } return Ok(Mode::Keys); } + "setup-model" => { + if args.len() != 1 { + return Err(ParseError( + "yoi setup-model does not accept arguments".into(), + )); + } + return Ok(Mode::SetupModel); + } "memory" if args.get(1).map(String::as_str) == Some("lint") => { let lint_args = &args[2..]; if lint_args.iter().any(|arg| arg == "--help" || arg == "-h") { @@ -544,6 +554,20 @@ mod tests { } } + #[test] + fn parse_setup_model_subcommand() { + match parse_args_from(["setup-model"]).unwrap() { + Mode::SetupModel => {} + _ => panic!("expected SetupModel mode"), + } + } + + #[test] + fn parse_setup_model_rejects_arguments() { + let err = parse_args_from(["setup-model", "extra"]).unwrap_err(); + assert_eq!(err.to_string(), "yoi setup-model does not accept arguments"); + } + #[test] fn parse_literal_pod_name_still_available_with_flag() { match parse_args_from(["--pod", "pod"]).unwrap() {