diff --git a/Cargo.lock b/Cargo.lock index ef3de4e9..0274ebac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2078,6 +2078,19 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "mcp" +version = "0.1.0" +dependencies = [ + "libc", + "manifest", + "secrets", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", +] + [[package]] name = "memchr" version = "2.8.0" diff --git a/Cargo.toml b/Cargo.toml index 2dadc297..8ce48d23 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "crates/session-store", "crates/secrets", "crates/manifest", + "crates/mcp", "crates/pod", "crates/plugin-pdk", "crates/yoi", @@ -34,6 +35,7 @@ default-members = [ "crates/session-store", "crates/secrets", "crates/manifest", + "crates/mcp", "crates/pod", "crates/plugin-pdk", "crates/yoi", @@ -62,6 +64,7 @@ client = { path = "crates/client" } llm-worker = { path = "crates/llm-worker", version = "0.2" } llm-worker-macros = { path = "crates/llm-worker-macros", version = "0.2" } manifest = { path = "crates/manifest" } +mcp = { path = "crates/mcp" } lint-common = { path = "crates/lint-common" } memory = { path = "crates/memory" } ticket = { path = "crates/ticket" } diff --git a/crates/mcp/Cargo.toml b/crates/mcp/Cargo.toml new file mode 100644 index 00000000..6e82fd6b --- /dev/null +++ b/crates/mcp/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "mcp" +version = "0.1.0" +edition.workspace = true + +[dependencies] +libc = "0.2" +manifest = { workspace = true } +secrets = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = ["io-util", "process", "sync", "time"] } + +[dev-dependencies] +tokio = { workspace = true, features = ["io-util", "macros", "process", "rt-multi-thread", "sync", "time"] } + +[[bin]] +name = "mcp-stdio-mock-server" +path = "tests/fixtures/mock_server.rs" +test = false +bench = false +doc = false diff --git a/crates/mcp/src/lib.rs b/crates/mcp/src/lib.rs new file mode 100644 index 00000000..f382de19 --- /dev/null +++ b/crates/mcp/src/lib.rs @@ -0,0 +1,7 @@ +//! Model Context Protocol client foundations. +//! +//! This crate intentionally only owns protocol/lifecycle plumbing. It does not +//! register MCP tools, resources, or prompts into Yoi's model-visible tool +//! surface. + +pub mod stdio; diff --git a/crates/mcp/src/stdio.rs b/crates/mcp/src/stdio.rs new file mode 100644 index 00000000..26171ed0 --- /dev/null +++ b/crates/mcp/src/stdio.rs @@ -0,0 +1,1112 @@ +use std::collections::{BTreeMap, VecDeque}; +use std::env; +use std::fmt; +use std::path::PathBuf; +use std::process::ExitStatus; +use std::sync::Arc; +use std::time::Duration; + +use manifest::{McpConfig, McpEnvValue, McpStdioCwdPolicy, McpStdioServerConfig}; +use secrets::SecretStore; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader}; +use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command}; +use tokio::sync::{Mutex, mpsc}; +use tokio::task::JoinHandle; +use tokio::time::{Instant, timeout}; + +const MCP_PROTOCOL_VERSION: &str = "2025-11-25"; +const CLIENT_NAME: &str = "yoi"; +const CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION"); +const JSONRPC_VERSION: &str = "2.0"; +const ERR_METHOD_NOT_FOUND: i64 = -32601; + +/// Resource limits for a local stdio MCP server lifecycle. +#[derive(Debug, Clone)] +pub struct McpStdioLimits { + pub max_stdout_line_bytes: usize, + pub max_stderr_line_bytes: usize, + pub max_diagnostic_lines: usize, + pub max_protocol_bytes: usize, + pub startup_timeout: Duration, + pub request_timeout: Duration, + pub shutdown_timeout: Duration, + pub kill_timeout: Duration, +} + +impl Default for McpStdioLimits { + fn default() -> Self { + Self { + max_stdout_line_bytes: 1024 * 1024, + max_stderr_line_bytes: 16 * 1024, + max_diagnostic_lines: 32, + max_protocol_bytes: 1024 * 1024, + startup_timeout: Duration::from_secs(10), + request_timeout: Duration::from_secs(10), + shutdown_timeout: Duration::from_secs(2), + kill_timeout: Duration::from_secs(2), + } + } +} + +/// A resolved, explicit local stdio MCP server process specification. +#[derive(Debug, Clone)] +pub struct McpStdioServerSpec { + pub name: String, + pub command: String, + pub args: Vec, + pub cwd: Option, + pub env: BTreeMap, + redactions: Vec, +} + +impl McpStdioServerSpec { + pub fn new(name: impl Into, command: impl Into) -> Self { + Self { + name: name.into(), + command: command.into(), + args: Vec::new(), + cwd: None, + env: BTreeMap::new(), + redactions: Vec::new(), + } + } + + pub fn arg(mut self, arg: impl Into) -> Self { + self.args.push(arg.into()); + self + } + + pub fn args(mut self, args: impl IntoIterator>) -> Self { + self.args.extend(args.into_iter().map(Into::into)); + self + } + + pub fn cwd(mut self, cwd: impl Into) -> Self { + self.cwd = Some(cwd.into()); + self + } + + pub fn env(mut self, name: impl Into, value: impl Into) -> Self { + let value = value.into(); + self.redact_value(&value); + self.env.insert(name.into(), value); + self + } + + pub fn redact_value(&mut self, value: &str) { + if !value.is_empty() && !self.redactions.iter().any(|existing| existing == value) { + self.redactions.push(value.to_owned()); + } + } + + fn redactor(&self) -> Redactor { + Redactor::new(self.redactions.clone()) + } +} + +/// Resolve one explicitly named stdio server from typed MCP config. +pub fn resolve_named_stdio_server( + config: &McpConfig, + name: &str, + workspace_root: impl Into, + secret_store: Option<&SecretStore>, +) -> Result { + let server = config + .stdio_servers + .iter() + .find(|server| server.name == name) + .ok_or_else(|| { + McpClientError::new( + name, + McpPhase::Spawn, + McpErrorKind::Config(format!("stdio server `{name}` is not configured")), + ) + })?; + resolve_stdio_server(server, workspace_root, secret_store) +} + +/// Resolve one typed stdio server into process IO settings without starting it. +pub fn resolve_stdio_server( + server: &McpStdioServerConfig, + workspace_root: impl Into, + secret_store: Option<&SecretStore>, +) -> Result { + let mut spec = McpStdioServerSpec::new(server.name.clone(), server.command.clone()) + .args(server.args.clone()); + let _workspace_root = workspace_root.into(); + spec.cwd = match &server.cwd { + Some(McpStdioCwdPolicy::Path { path }) => Some(path.clone()), + Some(McpStdioCwdPolicy::Inherit) | None => None, + }; + + for name in &server.env.inherit { + if let Ok(value) = env::var(name) { + spec.redact_value(&value); + spec.env.insert(name.clone(), value); + } + } + + for (name, value) in &server.env.set { + let resolved = match value { + McpEnvValue::Literal { value } => value.clone(), + McpEnvValue::EnvRef { name } => env::var(name).map_err(|err| { + McpClientError::new( + &server.name, + McpPhase::Spawn, + McpErrorKind::Config(format!( + "environment variable `{name}` is unavailable: {err}" + )), + ) + })?, + McpEnvValue::SecretRef { ref_ } => { + let store = secret_store.ok_or_else(|| { + McpClientError::new( + &server.name, + McpPhase::Spawn, + McpErrorKind::Config(format!( + "secret `{ref_}` requires a configured secret store" + )), + ) + })?; + store + .get(ref_) + .map_err(|err| { + McpClientError::new( + &server.name, + McpPhase::Spawn, + McpErrorKind::Config(format!( + "failed to resolve secret `{ref_}`: {err}" + )), + ) + })? + .into_string() + } + }; + spec.redact_value(&resolved); + spec.env.insert(name.clone(), resolved); + } + + Ok(spec) +} + +/// A running initialized stdio MCP client. +pub struct McpStdioClient { + server_name: String, + limits: McpStdioLimits, + redactor: Redactor, + diagnostics: Arc>, + stdin: Arc>>, + child: Option, + responses: mpsc::Receiver, + reader_task: JoinHandle<()>, + stderr_task: JoinHandle<()>, + next_id: u64, + initialized: Option, + shutdown_started: bool, +} + +impl McpStdioClient { + /// Spawn, initialize, negotiate capabilities, and send notifications/initialized. + pub async fn connect( + spec: McpStdioServerSpec, + limits: McpStdioLimits, + ) -> Result { + let started = Instant::now(); + let mut client = Self::spawn(spec, limits).await?; + match timeout(client.limits.startup_timeout, client.initialize()).await { + Ok(Ok(())) => Ok(client), + Ok(Err(err)) => { + let _ = client.shutdown().await; + Err(err.with_diagnostics(client.snapshot_diagnostics().await)) + } + Err(_) => { + let err = McpClientError::new( + &client.server_name, + McpPhase::Initialize, + McpErrorKind::Timeout { + operation: "startup".to_string(), + elapsed: started.elapsed(), + }, + ) + .with_diagnostics(client.snapshot_diagnostics().await); + let _ = client.shutdown().await; + Err(err) + } + } + } + + async fn spawn( + spec: McpStdioServerSpec, + limits: McpStdioLimits, + ) -> Result { + let redactor = spec.redactor(); + let mut command = Command::new(&spec.command); + command.args(&spec.args); + if let Some(cwd) = &spec.cwd { + command.current_dir(cwd); + } + command.env_clear(); + command.envs(&spec.env); + command.stdin(std::process::Stdio::piped()); + command.stdout(std::process::Stdio::piped()); + command.stderr(std::process::Stdio::piped()); + command.kill_on_drop(true); + + let mut child = command.spawn().map_err(|err| { + McpClientError::new( + &spec.name, + McpPhase::Spawn, + McpErrorKind::Io(redactor.redact(&err.to_string())), + ) + })?; + let stdin = child.stdin.take().ok_or_else(|| { + McpClientError::new( + &spec.name, + McpPhase::Spawn, + McpErrorKind::Protocol("child stdin was not piped".into()), + ) + })?; + let stdout = child.stdout.take().ok_or_else(|| { + McpClientError::new( + &spec.name, + McpPhase::Spawn, + McpErrorKind::Protocol("child stdout was not piped".into()), + ) + })?; + let stderr = child.stderr.take().ok_or_else(|| { + McpClientError::new( + &spec.name, + McpPhase::Spawn, + McpErrorKind::Protocol("child stderr was not piped".into()), + ) + })?; + + let stdin = Arc::new(Mutex::new(Some(stdin))); + let diagnostics = Arc::new(Mutex::new(BoundedDiagnostics::new( + spec.name.clone(), + limits.max_diagnostic_lines, + redactor.clone(), + ))); + let (tx, rx) = mpsc::channel(16); + let reader_task = spawn_stdout_reader( + spec.name.clone(), + stdout, + stdin.clone(), + tx, + limits.clone(), + redactor.clone(), + ); + let stderr_task = spawn_stderr_reader(stderr, diagnostics.clone(), limits.clone()); + + Ok(Self { + server_name: spec.name, + limits, + redactor, + diagnostics, + stdin, + child: Some(child), + responses: rx, + reader_task, + stderr_task, + next_id: 1, + initialized: None, + shutdown_started: false, + }) + } + + async fn initialize(&mut self) -> Result<(), McpClientError> { + let result: InitializeResult = self + .request( + McpPhase::Initialize, + "initialize", + json!({ + "protocolVersion": MCP_PROTOCOL_VERSION, + "capabilities": {}, + "clientInfo": { + "name": CLIENT_NAME, + "version": CLIENT_VERSION, + } + }), + ) + .await?; + self.initialized = Some(result); + self.write_notification( + McpPhase::Initialized, + "notifications/initialized", + json!({}), + ) + .await?; + Ok(()) + } + + pub fn initialize_result(&self) -> Option<&InitializeResult> { + self.initialized.as_ref() + } + + pub async fn snapshot_diagnostics(&self) -> McpDiagnostics { + self.diagnostics.lock().await.snapshot() + } + + pub async fn request Deserialize<'de>>( + &mut self, + phase: McpPhase, + method: &str, + params: Value, + ) -> Result { + let id = self.next_id; + self.next_id += 1; + let request = ClientRequest { + jsonrpc: JSONRPC_VERSION, + id, + method, + params, + }; + self.write_protocol(phase, &request).await?; + let response = match timeout( + self.limits.request_timeout, + self.wait_for_response(id, phase), + ) + .await + { + Ok(result) => result?, + Err(_) => { + return Err(McpClientError::new( + &self.server_name, + phase, + McpErrorKind::Timeout { + operation: method.to_owned(), + elapsed: self.limits.request_timeout, + }, + ) + .with_diagnostics(self.snapshot_diagnostics().await)); + } + }; + if let Some(error) = response.error { + return Err(McpClientError::new( + &self.server_name, + phase, + McpErrorKind::JsonRpcError { + code: error.code, + message: self.redactor.redact(&error.message), + }, + ) + .with_diagnostics(self.snapshot_diagnostics().await)); + } + let result = response.result.ok_or_else(|| { + McpClientError::new( + &self.server_name, + phase, + McpErrorKind::Protocol(format!("response to `{method}` did not contain result")), + ) + })?; + serde_json::from_value(result).map_err(|err| { + McpClientError::new( + &self.server_name, + phase, + McpErrorKind::Protocol(format!("invalid `{method}` result: {err}")), + ) + }) + } + + async fn wait_for_response( + &mut self, + id: u64, + phase: McpPhase, + ) -> Result { + while let Some(event) = self.responses.recv().await { + match event { + ReaderEvent::Response(response) if response.id == id => return Ok(response), + ReaderEvent::Response(_) | ReaderEvent::Notification => continue, + ReaderEvent::Error(err) => return Err(err.with_phase(phase)), + ReaderEvent::Eof => { + return Err(McpClientError::new( + &self.server_name, + phase, + McpErrorKind::Protocol("server stdout closed before response".into()), + ) + .with_diagnostics(self.snapshot_diagnostics().await)); + } + } + } + Err(McpClientError::new( + &self.server_name, + phase, + McpErrorKind::Protocol("stdout reader stopped before response".into()), + ) + .with_diagnostics(self.snapshot_diagnostics().await)) + } + + async fn write_notification( + &mut self, + phase: McpPhase, + method: &str, + params: T, + ) -> Result<(), McpClientError> { + self.write_protocol( + phase, + &ClientNotification { + jsonrpc: JSONRPC_VERSION, + method, + params, + }, + ) + .await + } + + async fn write_protocol( + &mut self, + phase: McpPhase, + value: &T, + ) -> Result<(), McpClientError> { + write_json_line( + &self.server_name, + phase, + &self.stdin, + value, + self.limits.max_protocol_bytes, + &self.redactor, + ) + .await + } + + /// Close stdin and wait for process exit, falling back to terminate and kill. + pub async fn shutdown(&mut self) -> Result { + self.shutdown_started = true; + { + let mut stdin = self.stdin.lock().await; + stdin.take(); + } + + let mut child = match self.child.take() { + Some(child) => child, + None => return Ok(ShutdownReport::already_finished()), + }; + + if let Ok(Some(status)) = child.try_wait() { + self.reader_task.abort(); + self.stderr_task.abort(); + return Ok(ShutdownReport { + exit_status: Some(status), + terminated: false, + killed: false, + }); + } + + match timeout(self.limits.shutdown_timeout, child.wait()).await { + Ok(Ok(status)) => { + self.reader_task.abort(); + self.stderr_task.abort(); + Ok(ShutdownReport { + exit_status: Some(status), + terminated: false, + killed: false, + }) + } + Ok(Err(err)) => Err(McpClientError::new( + &self.server_name, + McpPhase::Shutdown, + McpErrorKind::Io(self.redactor.redact(&err.to_string())), + ) + .with_diagnostics(self.snapshot_diagnostics().await)), + Err(_) => self.terminate_then_kill(child).await, + } + } + + async fn terminate_then_kill( + &mut self, + mut child: Child, + ) -> Result { + let mut terminated = false; + let mut killed = false; + if send_terminate(&mut child).is_ok() { + terminated = true; + } + match timeout(self.limits.kill_timeout, child.wait()).await { + Ok(Ok(status)) => { + self.reader_task.abort(); + self.stderr_task.abort(); + Ok(ShutdownReport { + exit_status: Some(status), + terminated, + killed, + }) + } + Ok(Err(err)) => Err(McpClientError::new( + &self.server_name, + McpPhase::Shutdown, + McpErrorKind::Io(self.redactor.redact(&err.to_string())), + ) + .with_diagnostics(self.snapshot_diagnostics().await)), + Err(_) => { + child.start_kill().map_err(|err| { + McpClientError::new( + &self.server_name, + McpPhase::Shutdown, + McpErrorKind::Io(self.redactor.redact(&err.to_string())), + ) + })?; + killed = true; + let status = timeout(self.limits.kill_timeout, child.wait()) + .await + .map_err(|_| { + McpClientError::new( + &self.server_name, + McpPhase::Shutdown, + McpErrorKind::Timeout { + operation: "kill".to_string(), + elapsed: self.limits.kill_timeout, + }, + ) + })? + .map_err(|err| { + McpClientError::new( + &self.server_name, + McpPhase::Shutdown, + McpErrorKind::Io(self.redactor.redact(&err.to_string())), + ) + })?; + self.reader_task.abort(); + self.stderr_task.abort(); + Ok(ShutdownReport { + exit_status: Some(status), + terminated, + killed, + }) + } + } + } +} + +impl Drop for McpStdioClient { + fn drop(&mut self) { + if !self.shutdown_started { + if let Some(child) = &mut self.child { + let _ = child.start_kill(); + } + } + self.reader_task.abort(); + self.stderr_task.abort(); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum McpPhase { + Spawn, + Initialize, + Initialized, + Running, + Shutdown, +} + +impl fmt::Display for McpPhase { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Spawn => f.write_str("spawn"), + Self::Initialize => f.write_str("initialize"), + Self::Initialized => f.write_str("initialized"), + Self::Running => f.write_str("running"), + Self::Shutdown => f.write_str("shutdown"), + } + } +} + +#[derive(Debug, Clone)] +pub struct ShutdownReport { + pub exit_status: Option, + pub terminated: bool, + pub killed: bool, +} + +impl ShutdownReport { + fn already_finished() -> Self { + Self { + exit_status: None, + terminated: false, + killed: false, + } + } +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InitializeResult { + pub protocol_version: String, + #[serde(default)] + pub capabilities: Value, + pub server_info: ImplementationInfo, + #[serde(default)] + pub instructions: Option, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ImplementationInfo { + pub name: String, + pub version: String, +} + +#[derive(Debug, Clone)] +pub struct McpDiagnostics { + pub server_name: String, + pub stderr: Vec, + pub dropped_stderr_lines: usize, + pub truncated_stderr_lines: usize, +} + +#[derive(Debug, Error, Clone)] +#[error("MCP stdio server `{server_name}` failed during phase `{phase}`: {kind}")] +pub struct McpClientError { + pub server_name: String, + pub phase: McpPhase, + pub kind: McpErrorKind, + diagnostics: Option, +} + +impl McpClientError { + fn new(server_name: impl Into, phase: McpPhase, kind: McpErrorKind) -> Self { + Self { + server_name: server_name.into(), + phase, + kind, + diagnostics: None, + } + } + + fn with_phase(mut self, phase: McpPhase) -> Self { + self.phase = phase; + self + } + + pub fn with_diagnostics(mut self, diagnostics: McpDiagnostics) -> Self { + self.diagnostics = Some(diagnostics); + self + } + + pub fn diagnostics(&self) -> Option<&McpDiagnostics> { + self.diagnostics.as_ref() + } +} + +#[derive(Debug, Error, Clone)] +pub enum McpErrorKind { + #[error("configuration error: {0}")] + Config(String), + #[error("I/O error: {0}")] + Io(String), + #[error("protocol error: {0}")] + Protocol(String), + #[error("JSON-RPC error {code}: {message}")] + JsonRpcError { code: i64, message: String }, + #[error("timed out during {operation} after {elapsed:?}")] + Timeout { + operation: String, + elapsed: Duration, + }, +} + +#[derive(Debug)] +enum ReaderEvent { + Response(ServerResponse), + Notification, + Error(McpClientError), + Eof, +} + +#[derive(Debug, Serialize)] +struct ClientRequest<'a> { + jsonrpc: &'static str, + id: u64, + method: &'a str, + params: Value, +} + +#[derive(Debug, Serialize)] +struct ClientNotification<'a, T> { + jsonrpc: &'static str, + method: &'a str, + params: T, +} + +#[derive(Debug, Deserialize)] +struct IncomingMessage { + #[allow(dead_code)] + jsonrpc: Option, + id: Option, + method: Option, + result: Option, + error: Option, + #[allow(dead_code)] + params: Option, +} + +#[derive(Debug, Deserialize)] +struct ServerResponse { + id: u64, + result: Option, + error: Option, +} + +#[derive(Debug, Deserialize)] +struct RpcError { + code: i64, + message: String, + #[allow(dead_code)] + data: Option, +} + +#[derive(Debug, Serialize)] +struct ErrorResponse<'a> { + jsonrpc: &'static str, + id: Value, + error: ErrorObject<'a>, +} + +#[derive(Debug, Serialize)] +struct ErrorObject<'a> { + code: i64, + message: &'a str, +} + +fn spawn_stdout_reader( + server_name: String, + stdout: ChildStdout, + stdin: Arc>>, + tx: mpsc::Sender, + limits: McpStdioLimits, + redactor: Redactor, +) -> JoinHandle<()> { + tokio::spawn(async move { + let mut stdout = BufReader::new(stdout); + loop { + match read_protocol_line(&mut stdout, limits.max_stdout_line_bytes).await { + Ok(Some(line)) => match serde_json::from_slice::(&line) { + Ok(message) => { + handle_incoming_message( + &server_name, + &stdin, + &tx, + &limits, + &redactor, + message, + ) + .await + } + Err(err) => { + let _ = tx + .send(ReaderEvent::Error(McpClientError::new( + &server_name, + McpPhase::Running, + McpErrorKind::Protocol(format!( + "invalid stdout JSON-RPC message: {err}" + )), + ))) + .await; + break; + } + }, + Ok(None) => { + let _ = tx.send(ReaderEvent::Eof).await; + break; + } + Err(err) => { + let _ = tx + .send(ReaderEvent::Error(McpClientError::new( + &server_name, + McpPhase::Running, + McpErrorKind::Protocol(err), + ))) + .await; + break; + } + } + } + }) +} + +async fn handle_incoming_message( + server_name: &str, + stdin: &Arc>>, + tx: &mpsc::Sender, + limits: &McpStdioLimits, + redactor: &Redactor, + message: IncomingMessage, +) { + if message.method.is_some() && message.id.is_some() { + if let Some(id) = message.id { + let response = ErrorResponse { + jsonrpc: JSONRPC_VERSION, + id, + error: ErrorObject { + code: ERR_METHOD_NOT_FOUND, + message: "server-to-client requests are not supported by this client", + }, + }; + let _ = write_json_line( + server_name, + McpPhase::Running, + stdin, + &response, + limits.max_protocol_bytes, + redactor, + ) + .await; + } + return; + } + + if message.method.is_some() { + let _ = tx.send(ReaderEvent::Notification).await; + return; + } + + if let Some(id) = message.id.as_ref().and_then(Value::as_u64) { + let _ = tx + .send(ReaderEvent::Response(ServerResponse { + id, + result: message.result, + error: message.error, + })) + .await; + return; + } + + let _ = tx + .send(ReaderEvent::Error(McpClientError::new( + server_name, + McpPhase::Running, + McpErrorKind::Protocol( + "JSON-RPC response id was missing or not an unsigned integer".into(), + ), + ))) + .await; +} + +fn spawn_stderr_reader( + stderr: ChildStderr, + diagnostics: Arc>, + limits: McpStdioLimits, +) -> JoinHandle<()> { + tokio::spawn(async move { + let mut stderr = BufReader::new(stderr); + loop { + match read_diagnostic_line(&mut stderr, limits.max_stderr_line_bytes).await { + Ok(Some((line, truncated))) => diagnostics.lock().await.push(line, truncated), + Ok(None) => break, + Err(err) => { + diagnostics + .lock() + .await + .push(format!("stderr read error: {err}"), false); + break; + } + } + } + }) +} + +async fn write_json_line( + server_name: &str, + phase: McpPhase, + stdin: &Arc>>, + value: &T, + max_protocol_bytes: usize, + redactor: &Redactor, +) -> Result<(), McpClientError> { + let mut bytes = serde_json::to_vec(value).map_err(|err| { + McpClientError::new( + server_name, + phase, + McpErrorKind::Protocol(format!("failed to encode JSON-RPC message: {err}")), + ) + })?; + if bytes.len() > max_protocol_bytes { + return Err(McpClientError::new( + server_name, + phase, + McpErrorKind::Protocol(format!( + "JSON-RPC payload exceeded {max_protocol_bytes} bytes" + )), + )); + } + bytes.push(b'\n'); + let mut guard = stdin.lock().await; + let Some(stdin) = guard.as_mut() else { + return Err(McpClientError::new( + server_name, + phase, + McpErrorKind::Io("child stdin is closed".into()), + )); + }; + stdin.write_all(&bytes).await.map_err(|err| { + McpClientError::new( + server_name, + phase, + McpErrorKind::Io(redactor.redact(&err.to_string())), + ) + })?; + stdin.flush().await.map_err(|err| { + McpClientError::new( + server_name, + phase, + McpErrorKind::Io(redactor.redact(&err.to_string())), + ) + }) +} + +async fn read_protocol_line( + reader: &mut R, + max_bytes: usize, +) -> Result>, String> { + let mut buf = Vec::new(); + let mut byte = [0u8; 1]; + loop { + let read = reader + .read(&mut byte) + .await + .map_err(|err| err.to_string())?; + if read == 0 { + return if buf.is_empty() { + Ok(None) + } else { + Ok(Some(trim_newline(buf))) + }; + } + if byte[0] == b'\n' { + return Ok(Some(trim_newline(buf))); + } + if buf.len() >= max_bytes { + return Err(format!("stdout line exceeded {max_bytes} bytes")); + } + buf.push(byte[0]); + } +} + +async fn read_diagnostic_line( + reader: &mut R, + max_bytes: usize, +) -> Result, String> { + let mut buf = Vec::new(); + let mut truncated = false; + let mut byte = [0u8; 1]; + loop { + let read = reader + .read(&mut byte) + .await + .map_err(|err| err.to_string())?; + if read == 0 { + if buf.is_empty() && !truncated { + return Ok(None); + } + return Ok(Some(( + String::from_utf8_lossy(&trim_newline(buf)).into_owned(), + truncated, + ))); + } + if byte[0] == b'\n' { + return Ok(Some(( + String::from_utf8_lossy(&trim_newline(buf)).into_owned(), + truncated, + ))); + } + if buf.len() < max_bytes { + buf.push(byte[0]); + } else { + truncated = true; + } + } +} + +fn trim_newline(mut buf: Vec) -> Vec { + if buf.last() == Some(&b'\r') { + buf.pop(); + } + buf +} + +#[derive(Debug)] +struct BoundedDiagnostics { + server_name: String, + max_lines: usize, + lines: VecDeque, + dropped_lines: usize, + truncated_lines: usize, + redactor: Redactor, +} + +impl BoundedDiagnostics { + fn new(server_name: String, max_lines: usize, redactor: Redactor) -> Self { + Self { + server_name, + max_lines, + lines: VecDeque::new(), + dropped_lines: 0, + truncated_lines: 0, + redactor, + } + } + + fn push(&mut self, line: String, truncated: bool) { + if truncated { + self.truncated_lines += 1; + } + if self.max_lines == 0 { + self.dropped_lines += 1; + return; + } + if self.lines.len() == self.max_lines { + self.lines.pop_front(); + self.dropped_lines += 1; + } + let suffix = if truncated { "… [truncated]" } else { "" }; + self.lines + .push_back(format!("{}{suffix}", self.redactor.redact(&line))); + } + + fn snapshot(&self) -> McpDiagnostics { + McpDiagnostics { + server_name: self.server_name.clone(), + stderr: self.lines.iter().cloned().collect(), + dropped_stderr_lines: self.dropped_lines, + truncated_stderr_lines: self.truncated_lines, + } + } +} + +#[derive(Debug, Clone)] +struct Redactor { + values: Vec, +} + +impl Redactor { + fn new(mut values: Vec) -> Self { + values.retain(|value| !value.is_empty()); + values.sort_by_key(|value| std::cmp::Reverse(value.len())); + values.dedup(); + Self { values } + } + + fn redact(&self, input: &str) -> String { + let mut output = input.to_owned(); + for value in &self.values { + output = output.replace(value, "[redacted]"); + } + output + } +} + +#[cfg(unix)] +fn send_terminate(child: &mut Child) -> Result<(), ()> { + let Some(pid) = child.id() else { + return Err(()); + }; + let result = unsafe { libc::kill(pid as libc::pid_t, libc::SIGTERM) }; + if result == 0 { Ok(()) } else { Err(()) } +} + +#[cfg(not(unix))] +fn send_terminate(child: &mut Child) -> Result<(), ()> { + child.start_kill().map_err(|_| ()) +} diff --git a/crates/mcp/tests/fixtures/mock_server.rs b/crates/mcp/tests/fixtures/mock_server.rs new file mode 100644 index 00000000..74dcea70 --- /dev/null +++ b/crates/mcp/tests/fixtures/mock_server.rs @@ -0,0 +1,116 @@ +use std::env; +use std::io::{self, BufRead, Write}; +use std::thread; +use std::time::Duration; + +use serde_json::{Value, json}; + +fn main() { + let mode = env::var("YOI_MCP_MOCK_MODE").unwrap_or_else(|_| "success".to_string()); + match mode.as_str() { + "success" => success(), + "fail-init" => fail_init(), + "sampling" => sampling_request(), + "shutdown-hang" => shutdown_hang(), + other => panic!("unknown mock mode: {other}"), + } +} + +fn success() { + let init = read_json(); + assert_eq!(init["method"], "initialize"); + assert!(init["params"]["capabilities"].get("sampling").is_none()); + assert!(init["params"]["capabilities"].get("elicitation").is_none()); + write_json(json!({ + "jsonrpc": "2.0", + "id": init["id"], + "result": initialize_result(), + })); + let initialized = read_json(); + assert_eq!(initialized["method"], "notifications/initialized"); + drain_stdin(); +} + +fn fail_init() { + let secret = env::var("MCP_TEST_SECRET").unwrap_or_default(); + for idx in 0..5 { + eprintln!("diagnostic {idx}: secret={secret}"); + } + let init = read_json(); + write_json(json!({ + "jsonrpc": "2.0", + "id": init["id"], + "error": { + "code": -32000, + "message": format!("init rejected with {secret}"), + } + })); +} + +fn sampling_request() { + let init = read_json(); + write_json(json!({ + "jsonrpc": "2.0", + "id": init["id"], + "result": initialize_result(), + })); + let initialized = read_json(); + assert_eq!(initialized["method"], "notifications/initialized"); + write_json(json!({ + "jsonrpc": "2.0", + "id": 99, + "method": "sampling/createMessage", + "params": {}, + })); + let response = read_json(); + assert_eq!(response["id"], 99); + assert_eq!(response["error"]["code"], -32601); +} + +fn shutdown_hang() { + let init = read_json(); + write_json(json!({ + "jsonrpc": "2.0", + "id": init["id"], + "result": initialize_result(), + })); + let initialized = read_json(); + assert_eq!(initialized["method"], "notifications/initialized"); + loop { + thread::sleep(Duration::from_secs(60)); + } +} + +fn initialize_result() -> Value { + json!({ + "protocolVersion": "2025-11-25", + "capabilities": { + "tools": { "listChanged": true } + }, + "serverInfo": { + "name": "mock-mcp", + "version": "0.1.0" + } + }) +} + +fn read_json() -> Value { + let mut line = String::new(); + let read = io::stdin().lock().read_line(&mut line).expect("read stdin"); + assert_ne!(read, 0, "stdin closed before JSON-RPC message"); + serde_json::from_str(&line).expect("valid JSON-RPC line") +} + +fn write_json(value: Value) { + let mut stdout = io::stdout().lock(); + serde_json::to_writer(&mut stdout, &value).expect("write JSON"); + stdout.write_all(b"\n").expect("write newline"); + stdout.flush().expect("flush stdout"); +} + +fn drain_stdin() { + let mut line = String::new(); + while io::stdin().lock().read_line(&mut line).unwrap_or(0) != 0 { + line.clear(); + } +} diff --git a/crates/mcp/tests/stdio_lifecycle.rs b/crates/mcp/tests/stdio_lifecycle.rs new file mode 100644 index 00000000..d805c57e --- /dev/null +++ b/crates/mcp/tests/stdio_lifecycle.rs @@ -0,0 +1,94 @@ +use std::time::Duration; + +use mcp::stdio::{McpErrorKind, McpPhase, McpStdioClient, McpStdioLimits, McpStdioServerSpec}; + +fn mock_server(mode: &str) -> McpStdioServerSpec { + McpStdioServerSpec::new("mock", env!("CARGO_BIN_EXE_mcp-stdio-mock-server")) + .env("YOI_MCP_MOCK_MODE", mode) +} + +fn tight_limits() -> McpStdioLimits { + McpStdioLimits { + startup_timeout: Duration::from_secs(2), + request_timeout: Duration::from_secs(2), + shutdown_timeout: Duration::from_millis(100), + kill_timeout: Duration::from_millis(100), + max_diagnostic_lines: 2, + max_stderr_line_bytes: 256, + ..Default::default() + } +} + +#[tokio::test] +async fn initializes_mock_stdio_server() { + let mut client = McpStdioClient::connect(mock_server("success"), tight_limits()) + .await + .expect("initialize succeeds"); + let result = client.initialize_result().expect("initialize result"); + assert_eq!(result.protocol_version, "2025-11-25"); + assert_eq!(result.server_info.name, "mock-mcp"); + let shutdown = client.shutdown().await.expect("shutdown succeeds"); + assert!(!shutdown.terminated); + assert!(!shutdown.killed); + assert!(shutdown.exit_status.is_some_and(|status| status.success())); +} + +#[tokio::test] +async fn initialize_failure_reports_server_phase_and_redacted_bounded_stderr() { + let spec = mock_server("fail-init").env("MCP_TEST_SECRET", "super-secret-token"); + let err = match McpStdioClient::connect(spec, tight_limits()).await { + Ok(mut client) => { + let _ = client.shutdown().await; + panic!("initialize unexpectedly succeeded"); + } + Err(err) => err, + }; + assert_eq!(err.server_name, "mock"); + assert_eq!(err.phase, McpPhase::Initialize); + match &err.kind { + McpErrorKind::JsonRpcError { code, message } => { + assert_eq!(*code, -32000); + assert!(!message.contains("super-secret-token")); + assert!(message.contains("[redacted]")); + } + other => panic!("unexpected error kind: {other:?}"), + } + let rendered = err.to_string(); + assert!(rendered.contains("mock")); + assert!(rendered.contains("initialize")); + let diagnostics = err.diagnostics().expect("diagnostics"); + assert_eq!(diagnostics.server_name, "mock"); + assert_eq!(diagnostics.stderr.len(), 2); + assert!(diagnostics.dropped_stderr_lines >= 3); + assert!( + diagnostics + .stderr + .iter() + .all(|line| !line.contains("super-secret-token")) + ); + assert!( + diagnostics + .stderr + .iter() + .any(|line| line.contains("[redacted]")) + ); +} + +#[tokio::test] +async fn shutdown_terminates_or_kills_uncooperative_server() { + let mut client = McpStdioClient::connect(mock_server("shutdown-hang"), tight_limits()) + .await + .expect("initialize succeeds"); + let shutdown = client.shutdown().await.expect("shutdown succeeds"); + assert!(shutdown.terminated || shutdown.killed); +} + +#[tokio::test] +async fn sampling_requests_fail_closed_and_are_not_advertised() { + let mut client = McpStdioClient::connect(mock_server("sampling"), tight_limits()) + .await + .expect("initialize succeeds"); + tokio::time::sleep(Duration::from_millis(50)).await; + let shutdown = client.shutdown().await.expect("shutdown succeeds"); + assert!(shutdown.exit_status.is_some_and(|status| status.success())); +} diff --git a/package.nix b/package.nix index d027b05a..ec9c2f6d 100644 --- a/package.nix +++ b/package.nix @@ -40,7 +40,7 @@ rustPlatform.buildRustPackage rec { filter = sourceFilter; }; - cargoHash = "sha256-Q+z7HDTkLtflth79ptEFy1lkDR9Y5VRrmX0m9NtLVqM="; + cargoHash = "sha256-EH4zdakrFxqVrgaNBx3dICN6KoLqskTEGYnU73XMVsU="; depsExtraArgs = { # Older fetchCargoVendor utilities used crates.io's API download endpoint,