feat: add setup model command

This commit is contained in:
Keisuke Hirata 2026-06-10 18:16:23 +09:00
parent 6bb023e9fe
commit 32be6075c1
No known key found for this signature in database
5 changed files with 357 additions and 0 deletions

1
Cargo.lock generated
View File

@ -3965,6 +3965,7 @@ dependencies = [
"pod-registry",
"pod-store",
"protocol",
"provider",
"pulldown-cmark",
"ratatui",
"secrets",

View File

@ -19,6 +19,7 @@ secrets = { workspace = true }
session-store = { workspace = true }
pod-store = { workspace = true }
pod-registry = { workspace = true }
provider = { workspace = true }
ticket = { workspace = true }
serde = { workspace = true, features = ["derive"] }
pulldown-cmark = { version = "0.13.3", default-features = false }

View File

@ -12,6 +12,7 @@ mod picker;
mod pod_list;
mod role_session_registry;
mod scroll;
pub mod setup_model;
mod single_pod;
mod spawn;
mod task;

View File

@ -0,0 +1,330 @@
use std::collections::BTreeMap;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::process::ExitCode;
use provider::catalog::{self, AuthHint, ModelEntry, ProviderEntry};
use toml::Value;
use toml::map::Map;
const GENERATED_PROFILE_NAME: &str = "default";
const GENERATED_PROFILE_PATH: &str = "profiles/default.lua";
const GENERATED_PROFILE_DESCRIPTION: &str = "Generated by yoi setup-model";
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelChoice {
pub model_ref: String,
pub provider_id: String,
pub provider_display_name: String,
pub model_id: String,
pub auth_hint: AuthHint,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WrittenSetupConfig {
pub registry_path: PathBuf,
pub profile_path: PathBuf,
pub model_ref: String,
}
pub async fn launch() -> ExitCode {
match run_interactive_setup() {
Ok(written) => {
println!(
"Saved default Profile `{}` for model `{}`",
GENERATED_PROFILE_NAME, written.model_ref
);
println!("registry: {}", written.registry_path.display());
println!("profile: {}", written.profile_path.display());
ExitCode::SUCCESS
}
Err(err) => {
eprintln!("yoi setup-model: {err}");
ExitCode::FAILURE
}
}
}
fn run_interactive_setup() -> Result<WrittenSetupConfig, Box<dyn std::error::Error>> {
let config_dir = manifest::paths::config_dir().ok_or_else(|| {
io::Error::new(
io::ErrorKind::NotFound,
"could not determine yoi config directory",
)
})?;
let choices = load_model_choices()?;
let choice = prompt_model_choice(&choices)?;
write_default_profile_config(&config_dir, &choice.model_ref)
}
pub fn load_model_choices() -> Result<Vec<ModelChoice>, catalog::CatalogError> {
let providers = catalog::load_providers()?
.into_iter()
.map(|provider| (provider.id.clone(), provider))
.collect::<BTreeMap<_, _>>();
let mut choices = catalog::load_models()?
.into_iter()
.filter_map(|model| choice_from_model(model, &providers))
.collect::<Vec<_>>();
choices.sort_by(|a, b| {
a.provider_display_name
.cmp(&b.provider_display_name)
.then_with(|| a.model_id.cmp(&b.model_id))
.then_with(|| a.model_ref.cmp(&b.model_ref))
});
Ok(choices)
}
fn choice_from_model(
model: ModelEntry,
providers: &BTreeMap<String, ProviderEntry>,
) -> Option<ModelChoice> {
let provider = providers.get(&model.provider)?;
Some(ModelChoice {
model_ref: format!("{}/{}", model.provider, model.id),
provider_id: provider.id.clone(),
provider_display_name: provider.display_name.clone(),
model_id: model.id,
auth_hint: provider.auth_hint.clone(),
})
}
fn prompt_model_choice(
choices: &[ModelChoice],
) -> Result<&ModelChoice, Box<dyn std::error::Error>> {
if choices.is_empty() {
return Err("no models are configured in the model catalog".into());
}
println!("yoi setup-model");
println!();
println!("Choose the default model Profile to write under the user config directory.");
println!("This command only writes Profile config; it does not start or attach a Pod.");
println!();
for (idx, choice) in choices.iter().enumerate() {
println!(
"{:>2}. {:<42} {} ({})",
idx + 1,
choice.model_ref,
choice.provider_display_name,
auth_hint_label(&choice.auth_hint),
);
}
println!();
print!("Select model [1]: ");
io::stdout().flush()?;
let mut input = String::new();
io::stdin().read_line(&mut input)?;
let selection = input.trim();
let index = if selection.is_empty() {
0
} else {
selection
.parse::<usize>()
.map_err(|_| format!("invalid model selection `{selection}`"))?
.checked_sub(1)
.ok_or("model selection starts at 1")?
};
choices
.get(index)
.ok_or_else(|| format!("model selection {} is out of range", index + 1).into())
}
fn auth_hint_label(hint: &AuthHint) -> String {
match hint {
AuthHint::None => "no auth".to_string(),
AuthHint::ApiKey => "API key file".to_string(),
AuthHint::SecretRef { ref_ } => format!("secret `{ref_}`"),
AuthHint::CodexOAuth => "Codex OAuth".to_string(),
}
}
pub fn write_default_profile_config(
config_dir: &Path,
model_ref: &str,
) -> Result<WrittenSetupConfig, Box<dyn std::error::Error>> {
std::fs::create_dir_all(config_dir)?;
let profiles_dir = config_dir.join("profiles");
std::fs::create_dir_all(&profiles_dir)?;
let profile_path = config_dir.join(GENERATED_PROFILE_PATH);
std::fs::write(&profile_path, generated_profile_lua(model_ref))?;
let registry_path = config_dir.join("profiles.toml");
let mut document = read_registry_document(&registry_path)?;
set_default_profile_entry(&mut document)?;
std::fs::write(&registry_path, toml::to_string_pretty(&document)?)?;
Ok(WrittenSetupConfig {
registry_path,
profile_path,
model_ref: model_ref.to_string(),
})
}
fn read_registry_document(path: &Path) -> Result<Value, Box<dyn std::error::Error>> {
if !path.exists() {
return Ok(Value::Table(Map::new()));
}
let text = std::fs::read_to_string(path)?;
let value: Value = toml::from_str(&text)?;
if !value.is_table() {
return Err(format!("{} must contain a TOML table", path.display()).into());
}
Ok(value)
}
fn set_default_profile_entry(document: &mut Value) -> Result<(), Box<dyn std::error::Error>> {
let table = document
.as_table_mut()
.ok_or("profiles.toml root must be a TOML table")?;
table.insert(
"default".to_string(),
Value::String(format!("user:{GENERATED_PROFILE_NAME}")),
);
let profile_value = table
.entry("profile".to_string())
.or_insert_with(|| Value::Table(Map::new()));
let profile_table = profile_value
.as_table_mut()
.ok_or("profiles.toml `profile` must be a TOML table")?;
let mut entry = Map::new();
entry.insert(
"path".to_string(),
Value::String(GENERATED_PROFILE_PATH.to_string()),
);
entry.insert(
"description".to_string(),
Value::String(GENERATED_PROFILE_DESCRIPTION.to_string()),
);
profile_table.insert(GENERATED_PROFILE_NAME.to_string(), Value::Table(entry));
Ok(())
}
fn generated_profile_lua(model_ref: &str) -> String {
format!(
r#"local profile = require("yoi.profile")
local scope = require("yoi.scope")
local compact = require("yoi.compact")
return profile {{
slug = "default",
description = "Generated by yoi setup-model",
scope = scope.workspace_write(),
session = {{
record_event_trace = true,
}},
worker = {{
reasoning = "high",
}},
model = {{
ref = "{}",
}},
compaction = compact.tokens {{
threshold = 240000,
request_threshold = 270000,
worker_context_max_tokens = 100000,
}},
feature = {{
task = {{ enabled = true }},
memory = {{ enabled = true }},
web = {{ enabled = true }},
pods = {{ enabled = false }},
ticket = {{ enabled = false, access = "lifecycle" }},
ticket_orchestration = {{ enabled = false }},
}},
memory = {{
extract_threshold = 50000,
consolidation_threshold_files = 5,
consolidation_threshold_bytes = 50000,
}},
web = {{
enabled = true,
search = {{
provider = "brave",
api_key_secret = "web/brave/default",
}},
}},
}}
"#,
escape_lua_string(model_ref)
)
}
fn escape_lua_string(value: &str) -> String {
value
.chars()
.flat_map(|c| match c {
'\\' => "\\\\".chars().collect::<Vec<_>>(),
'"' => "\\\"".chars().collect::<Vec<_>>(),
'\n' => "\\n".chars().collect::<Vec<_>>(),
'\r' => "\\r".chars().collect::<Vec<_>>(),
'\t' => "\\t".chars().collect::<Vec<_>>(),
other => vec![other],
})
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn write_default_profile_config_creates_registry_and_profile() {
let dir = tempfile::tempdir().unwrap();
let written = write_default_profile_config(dir.path(), "codex-oauth/gpt-5.5").unwrap();
assert_eq!(written.registry_path, dir.path().join("profiles.toml"));
assert_eq!(
written.profile_path,
dir.path().join("profiles/default.lua")
);
let registry = std::fs::read_to_string(&written.registry_path).unwrap();
assert!(registry.contains("default = \"user:default\""));
assert!(registry.contains("[profile.default]"));
assert!(registry.contains("path = \"profiles/default.lua\""));
let profile = std::fs::read_to_string(&written.profile_path).unwrap();
assert!(profile.contains("slug = \"default\""));
assert!(profile.contains("ref = \"codex-oauth/gpt-5.5\""));
assert!(profile.contains("scope = scope.workspace_write()"));
}
#[test]
fn write_default_profile_config_preserves_other_profile_entries() {
let dir = tempfile::tempdir().unwrap();
std::fs::write(
dir.path().join("profiles.toml"),
r#"[profile.other]
path = "profiles/other.lua"
description = "keep me"
"#,
)
.unwrap();
write_default_profile_config(dir.path(), "anthropic/claude-sonnet-4-6").unwrap();
let registry = std::fs::read_to_string(dir.path().join("profiles.toml")).unwrap();
assert!(registry.contains("[profile.other]"));
assert!(registry.contains("path = \"profiles/other.lua\""));
assert!(registry.contains("[profile.default]"));
assert!(registry.contains("default = \"user:default\""));
}
#[test]
fn escape_lua_string_escapes_quotes_and_slashes() {
assert_eq!(escape_lua_string("a\\b\"c"), "a\\\\b\\\"c");
}
}

View File

@ -22,6 +22,7 @@ enum Mode {
Ticket(ticket_cli::TicketCli),
PodRuntime(Vec<String>),
Keys,
SetupModel,
Tui {
mode: LaunchMode,
workspace_root: PathBuf,
@ -107,6 +108,7 @@ async fn main() -> ExitCode {
},
Mode::PodRuntime(args) => pod::entrypoint::run_cli_from("yoi pod", args).await,
Mode::Keys => tui::keys::launch().await,
Mode::SetupModel => tui::setup_model::launch().await,
Mode::Tui {
mode,
workspace_root,
@ -183,6 +185,14 @@ fn parse_args_slice(args: &[String]) -> Result<Mode, ParseError> {
}
return Ok(Mode::Keys);
}
"setup-model" => {
if args.len() != 1 {
return Err(ParseError(
"yoi setup-model does not accept arguments".into(),
));
}
return Ok(Mode::SetupModel);
}
"memory" if args.get(1).map(String::as_str) == Some("lint") => {
let lint_args = &args[2..];
if lint_args.iter().any(|arg| arg == "--help" || arg == "-h") {
@ -544,6 +554,20 @@ mod tests {
}
}
#[test]
fn parse_setup_model_subcommand() {
match parse_args_from(["setup-model"]).unwrap() {
Mode::SetupModel => {}
_ => panic!("expected SetupModel mode"),
}
}
#[test]
fn parse_setup_model_rejects_arguments() {
let err = parse_args_from(["setup-model", "extra"]).unwrap_err();
assert_eq!(err.to_string(), "yoi setup-model does not accept arguments");
}
#[test]
fn parse_literal_pod_name_still_available_with_flag() {
match parse_args_from(["--pod", "pod"]).unwrap() {