yoi/crates/mcp/src/stdio.rs

1687 lines
50 KiB
Rust

use std::collections::{BTreeMap, BTreeSet, VecDeque};
use std::env;
use std::fmt;
use std::path::PathBuf;
use std::process::ExitStatus;
use std::sync::Arc;
use std::time::Duration;
use manifest::{McpConfig, McpEnvValue, McpStdioCwdPolicy, McpStdioServerConfig};
use secrets::SecretStore;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
use tokio::sync::{Mutex, mpsc};
use tokio::task::JoinHandle;
use tokio::time::{Instant, timeout};
/// MCP protocol revision sent as `protocolVersion` in the `initialize` request.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
/// Client name reported as `clientInfo.name` in the `initialize` request.
const CLIENT_NAME: &str = "yoi";
/// Client version reported as `clientInfo.version`; taken from the Cargo
/// package version at compile time.
const CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Value of the `jsonrpc` field on every outgoing request, notification, and
/// error response.
const JSONRPC_VERSION: &str = "2.0";
/// JSON-RPC error code used when replying to server-to-client requests, which
/// this client does not support.
const ERR_METHOD_NOT_FOUND: i64 = -32601;
/// Resource limits for a local stdio MCP server lifecycle.
///
/// Byte and line limits bound what is read from or written to the child
/// process; the timeouts bound each lifecycle stage.
#[derive(Debug, Clone)]
pub struct McpStdioLimits {
    /// Maximum size of one protocol line read from the server's stdout.
    pub max_stdout_line_bytes: usize,
    /// Maximum size of one diagnostic line read from the server's stderr.
    pub max_stderr_line_bytes: usize,
    /// Capacity handed to the bounded stderr diagnostics buffer.
    pub max_diagnostic_lines: usize,
    /// Maximum size of one protocol message written to the server's stdin.
    pub max_protocol_bytes: usize,
    /// Time allowed for the `initialize` handshake during connect.
    pub startup_timeout: Duration,
    /// Time allowed to wait for the response to a single request.
    pub request_timeout: Duration,
    /// Time allowed for the process to exit after its stdin is closed.
    pub shutdown_timeout: Duration,
    /// Time allowed for the process to exit after terminate, and again after kill.
    pub kill_timeout: Duration,
}

impl Default for McpStdioLimits {
    fn default() -> Self {
        const KIB: usize = 1024;
        const MIB: usize = 1024 * KIB;
        let handshake_or_request = Duration::from_secs(10);
        let process_exit = Duration::from_secs(2);

        Self {
            max_stdout_line_bytes: MIB,
            max_stderr_line_bytes: 16 * KIB,
            max_diagnostic_lines: 32,
            max_protocol_bytes: MIB,
            startup_timeout: handshake_or_request,
            request_timeout: handshake_or_request,
            shutdown_timeout: process_exit,
            kill_timeout: process_exit,
        }
    }
}
/// Host bounds for MCP `tools/list` pagination during discovery.
#[derive(Debug, Clone, Copy)]
pub struct McpToolListLimits {
    /// Maximum number of `tools/list` pages to request.
    pub max_pages: usize,
    /// Maximum number of tools accepted across all pages.
    pub max_tools: usize,
}

impl Default for McpToolListLimits {
    fn default() -> Self {
        const DEFAULT_MAX_PAGES: usize = 8;
        const DEFAULT_MAX_TOOLS: usize = 128;

        Self {
            max_pages: DEFAULT_MAX_PAGES,
            max_tools: DEFAULT_MAX_TOOLS,
        }
    }
}
/// One tool entry as carried in [`ListToolsResult::tools`].
///
/// Field names use camelCase on the wire. `name` and `inputSchema` are
/// required when deserializing; every other known field is optional, and any
/// field not listed here is preserved in `extra`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct McpToolDefinition {
/// Tool name.
pub name: String,
/// Optional title.
#[serde(default)]
pub title: Option<String>,
/// Optional description.
#[serde(default)]
pub description: Option<String>,
/// Input schema, kept as raw JSON (wire name `inputSchema`).
pub input_schema: Value,
/// Optional output schema, kept as raw JSON (wire name `outputSchema`).
#[serde(default)]
pub output_schema: Option<Value>,
/// Optional annotations, kept as raw JSON.
#[serde(default)]
pub annotations: Option<Value>,
/// Optional `_meta` object, kept as raw JSON.
#[serde(default, rename = "_meta")]
pub meta: Option<Value>,
/// Any additional fields sent by the server.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Result payload of a `tools/list` request.
///
/// All fields are optional when deserializing; a missing `tools` array
/// becomes an empty vector.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ListToolsResult {
/// Tools returned on this page.
#[serde(default)]
pub tools: Vec<McpToolDefinition>,
/// Cursor for the next page (wire name `nextCursor`); `None` ends pagination.
#[serde(default)]
pub next_cursor: Option<String>,
/// Optional `_meta` object, kept as raw JSON.
#[serde(default, rename = "_meta")]
pub meta: Option<Value>,
/// Any additional fields sent by the server.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Page and entry caps for MCP resource listing.
#[derive(Debug, Clone, Copy)]
pub struct McpResourceListLimits {
    /// Maximum number of pages to request.
    pub max_pages: usize,
    /// Maximum number of resources to accept.
    pub max_resources: usize,
    /// Maximum number of resource templates to accept.
    pub max_resource_templates: usize,
}

impl Default for McpResourceListLimits {
    fn default() -> Self {
        const DEFAULT_MAX_PAGES: usize = 8;
        const DEFAULT_MAX_ENTRIES: usize = 128;

        Self {
            max_pages: DEFAULT_MAX_PAGES,
            max_resources: DEFAULT_MAX_ENTRIES,
            max_resource_templates: DEFAULT_MAX_ENTRIES,
        }
    }
}
/// Page and entry caps for MCP prompt listing.
#[derive(Debug, Clone, Copy)]
pub struct McpPromptListLimits {
    /// Maximum number of pages to request.
    pub max_pages: usize,
    /// Maximum number of prompts to accept.
    pub max_prompts: usize,
}

impl Default for McpPromptListLimits {
    fn default() -> Self {
        const DEFAULT_MAX_PAGES: usize = 8;
        const DEFAULT_MAX_PROMPTS: usize = 128;

        Self {
            max_pages: DEFAULT_MAX_PAGES,
            max_prompts: DEFAULT_MAX_PROMPTS,
        }
    }
}
/// One resource entry as carried in [`ListResourcesResult::resources`].
///
/// Only `uri` is required when deserializing. Unknown fields are preserved
/// in `extra`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct McpResourceDefinition {
/// Resource URI.
pub uri: String,
/// Optional name.
#[serde(default)]
pub name: Option<String>,
/// Optional title.
#[serde(default)]
pub title: Option<String>,
/// Optional description.
#[serde(default)]
pub description: Option<String>,
/// Optional MIME type (wire name `mimeType`).
#[serde(default)]
pub mime_type: Option<String>,
/// Optional annotations, kept as raw JSON.
#[serde(default)]
pub annotations: Option<Value>,
/// Optional `_meta` object, kept as raw JSON.
#[serde(default, rename = "_meta")]
pub meta: Option<Value>,
/// Any additional fields sent by the server.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// One resource template entry as carried in
/// [`ListResourcesResult::resource_templates`].
///
/// Only `uriTemplate` is required when deserializing. Unknown fields are
/// preserved in `extra`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct McpResourceTemplateDefinition {
/// URI template string (wire name `uriTemplate`).
pub uri_template: String,
/// Optional name.
#[serde(default)]
pub name: Option<String>,
/// Optional title.
#[serde(default)]
pub title: Option<String>,
/// Optional description.
#[serde(default)]
pub description: Option<String>,
/// Optional MIME type (wire name `mimeType`).
#[serde(default)]
pub mime_type: Option<String>,
/// Optional annotations, kept as raw JSON.
#[serde(default)]
pub annotations: Option<Value>,
/// Optional `_meta` object, kept as raw JSON.
#[serde(default, rename = "_meta")]
pub meta: Option<Value>,
/// Any additional fields sent by the server.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Result payload of a `resources/list` request.
///
/// All fields are optional when deserializing; missing arrays become empty
/// vectors.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ListResourcesResult {
/// Resources returned on this page.
#[serde(default)]
pub resources: Vec<McpResourceDefinition>,
/// Resource templates returned on this page (wire name `resourceTemplates`).
#[serde(default)]
pub resource_templates: Vec<McpResourceTemplateDefinition>,
/// Cursor for the next page (wire name `nextCursor`), if any.
#[serde(default)]
pub next_cursor: Option<String>,
/// Optional `_meta` object, kept as raw JSON.
#[serde(default, rename = "_meta")]
pub meta: Option<Value>,
/// Any additional fields sent by the server.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Parameters of a `resources/read` request.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ReadResourceRequest {
    /// URI of the resource to read.
    pub uri: String,
}

impl ReadResourceRequest {
    /// Build a request for the resource at `uri`.
    pub fn new(uri: impl Into<String>) -> Self {
        let uri = uri.into();
        Self { uri }
    }
}
/// One content entry as carried in [`ReadResourceResult::contents`].
///
/// Only `uri` is required when deserializing. All fields other than `uri`,
/// `mimeType`, and `_meta` are collected into `fields`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct McpResourceContent {
/// URI of the resource this content belongs to.
pub uri: String,
/// Optional MIME type (wire name `mimeType`).
#[serde(default)]
pub mime_type: Option<String>,
/// Optional `_meta` object, kept as raw JSON.
#[serde(default, rename = "_meta")]
pub meta: Option<Value>,
/// Remaining server-supplied fields, kept as raw JSON.
#[serde(flatten)]
pub fields: BTreeMap<String, Value>,
}
/// Result payload of a `resources/read` request.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ReadResourceResult {
/// Content entries; empty when the field is absent.
#[serde(default)]
pub contents: Vec<McpResourceContent>,
/// Optional `_meta` object, kept as raw JSON.
#[serde(default, rename = "_meta")]
pub meta: Option<Value>,
/// Any additional fields sent by the server.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// One argument entry as carried in [`McpPromptDefinition::arguments`].
///
/// Only `name` is required when deserializing. Unknown fields are preserved
/// in `extra`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct McpPromptArgumentDefinition {
/// Argument name.
pub name: String,
/// Optional title.
#[serde(default)]
pub title: Option<String>,
/// Optional description.
#[serde(default)]
pub description: Option<String>,
/// Optional `required` flag; `None` when the server omitted it.
#[serde(default)]
pub required: Option<bool>,
/// Optional `_meta` object, kept as raw JSON.
#[serde(default, rename = "_meta")]
pub meta: Option<Value>,
/// Any additional fields sent by the server.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// One prompt entry as carried in [`ListPromptsResult::prompts`].
///
/// Only `name` is required when deserializing. Unknown fields are preserved
/// in `extra`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct McpPromptDefinition {
/// Prompt name.
pub name: String,
/// Optional title.
#[serde(default)]
pub title: Option<String>,
/// Optional description.
#[serde(default)]
pub description: Option<String>,
/// Declared arguments; empty when the field is absent.
#[serde(default)]
pub arguments: Vec<McpPromptArgumentDefinition>,
/// Optional `_meta` object, kept as raw JSON.
#[serde(default, rename = "_meta")]
pub meta: Option<Value>,
/// Any additional fields sent by the server.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Result payload of a `prompts/list` request.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ListPromptsResult {
/// Prompts returned on this page; empty when the field is absent.
#[serde(default)]
pub prompts: Vec<McpPromptDefinition>,
/// Cursor for the next page (wire name `nextCursor`), if any.
#[serde(default)]
pub next_cursor: Option<String>,
/// Optional `_meta` object, kept as raw JSON.
#[serde(default, rename = "_meta")]
pub meta: Option<Value>,
/// Any additional fields sent by the server.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Parameters of a `prompts/get` request.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GetPromptRequest {
    /// Name of the prompt to fetch.
    pub name: String,
    /// Prompt arguments as raw JSON; left out of the serialized request when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub arguments: Option<Value>,
}

impl GetPromptRequest {
    /// Build a request for prompt `name` with optional `arguments`.
    pub fn new(name: impl Into<String>, arguments: Option<Value>) -> Self {
        let name = name.into();
        Self { name, arguments }
    }
}
/// One message entry as carried in [`GetPromptResult::messages`].
///
/// `role` and `content` are required when deserializing.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct McpPromptMessage {
/// Message role, kept as the server-supplied string.
pub role: String,
/// Message content block.
pub content: McpContentBlock,
/// Any additional fields sent by the server.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Result payload of a `prompts/get` request.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GetPromptResult {
/// Optional description.
#[serde(default)]
pub description: Option<String>,
/// Prompt messages; empty when the field is absent.
#[serde(default)]
pub messages: Vec<McpPromptMessage>,
/// Optional `_meta` object, kept as raw JSON.
#[serde(default, rename = "_meta")]
pub meta: Option<Value>,
/// Any additional fields sent by the server.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Parameters of a `tools/call` request.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CallToolRequest {
    /// Name of the tool to invoke.
    pub name: String,
    /// Tool arguments as raw JSON; left out of the serialized request when null.
    #[serde(default, skip_serializing_if = "Value::is_null")]
    pub arguments: Value,
}

impl CallToolRequest {
    /// Build a call to tool `name` with the given `arguments`.
    pub fn new(name: impl Into<String>, arguments: Value) -> Self {
        let name = name.into();
        Self { name, arguments }
    }
}
/// Result payload of a `tools/call` request.
///
/// All fields are optional when deserializing; `isError` defaults to `false`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CallToolResult {
/// Content blocks; empty when the field is absent.
#[serde(default)]
pub content: Vec<McpContentBlock>,
/// Optional structured content, kept as raw JSON (wire name `structuredContent`).
#[serde(default)]
pub structured_content: Option<Value>,
/// Server-reported error flag (wire name `isError`).
#[serde(default)]
pub is_error: bool,
/// Optional `_meta` object, kept as raw JSON.
#[serde(default, rename = "_meta")]
pub meta: Option<Value>,
/// Any additional fields sent by the server.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// One untrusted MCP content block, as found in `tools/call` results and in
/// prompt messages.
///
/// The `type` discriminator is kept explicit and all server-owned fields stay
/// data in `fields`; this crate does not turn rich MCP content into hidden host
/// context.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpContentBlock {
/// Value of the block's `type` field, kept as the server-supplied string.
#[serde(rename = "type")]
pub kind: String,
/// Every other field of the block, kept as raw JSON.
#[serde(flatten)]
pub fields: BTreeMap<String, Value>,
}
/// MCP list surface whose `notifications/*/list_changed` signal was observed.
///
/// The notification is only a freshness signal. The stdio client records this
/// bounded enum state and deliberately ignores notification params so a server
/// cannot inject resource/prompt content or alter model-visible tool schemas
/// through an out-of-band notification.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum McpListChangedKind {
    Tools,
    Resources,
    Prompts,
}

impl McpListChangedKind {
    /// Every variant, used to map a notification method back to its kind.
    const ALL: [Self; 3] = [Self::Tools, Self::Resources, Self::Prompts];

    /// Map a notification method name to the list surface it refers to.
    ///
    /// Returns `None` for any method that is not one of the three
    /// `list_changed` notifications.
    fn from_notification_method(method: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|kind| kind.notification_method() == method)
    }

    /// Notification method name that signals a change of this list.
    pub fn notification_method(self) -> &'static str {
        let method = match self {
            Self::Tools => "notifications/tools/list_changed",
            Self::Resources => "notifications/resources/list_changed",
            Self::Prompts => "notifications/prompts/list_changed",
        };
        method
    }

    /// Request method name that lists this surface.
    pub fn list_method(self) -> &'static str {
        let method = match self {
            Self::Tools => "tools/list",
            Self::Resources => "resources/list",
            Self::Prompts => "prompts/list",
        };
        method
    }
}
/// Bounded snapshot of list-change signals observed from one stdio server.
///
/// Holds at most one entry per [`McpListChangedKind`] because the kinds are
/// stored in a set.
#[derive(Debug, Clone)]
pub struct McpListChangedSnapshot {
/// Name of the server the signals were observed from.
pub server_name: String,
/// Distinct list surfaces for which a change was signalled.
kinds: BTreeSet<McpListChangedKind>,
}
impl McpListChangedSnapshot {
/// Returns `true` when no list-change signal is recorded.
pub fn is_empty(&self) -> bool {
self.kinds.is_empty()
}
/// Returns `true` when a change was signalled for `kind`.
pub fn contains(&self, kind: McpListChangedKind) -> bool {
self.kinds.contains(&kind)
}
/// Iterates over the recorded kinds in the enum's derived order.
pub fn kinds(&self) -> impl Iterator<Item = McpListChangedKind> + '_ {
self.kinds.iter().copied()
}
}
/// A resolved, explicit local stdio MCP server process specification.
///
/// `Debug` is implemented by hand so environment values are never printed.
#[derive(Clone)]
pub struct McpStdioServerSpec {
/// Server name used in errors and diagnostics.
pub name: String,
/// Program to execute.
pub command: String,
/// Arguments passed to the program.
pub args: Vec<String>,
/// Working directory for the child; `None` leaves it unset on the command.
pub cwd: Option<PathBuf>,
/// Complete environment of the child (the process environment is cleared
/// before these are applied).
pub env: BTreeMap<String, String>,
/// Distinct non-empty values handed to the redactor built by `redactor()`.
redactions: Vec<String>,
}
impl McpStdioServerSpec {
    /// Create a spec for `command` with no arguments, no working directory,
    /// an empty environment, and nothing registered for redaction.
    pub fn new(name: impl Into<String>, command: impl Into<String>) -> Self {
        let name = name.into();
        let command = command.into();
        Self {
            name,
            command,
            args: Vec::new(),
            cwd: None,
            env: BTreeMap::new(),
            redactions: Vec::new(),
        }
    }

    /// Append one argument.
    pub fn arg(mut self, arg: impl Into<String>) -> Self {
        let arg: String = arg.into();
        self.args.push(arg);
        self
    }

    /// Append several arguments, preserving their order.
    pub fn args(mut self, args: impl IntoIterator<Item = impl Into<String>>) -> Self {
        let extra = args.into_iter().map(Into::into);
        self.args.extend(extra);
        self
    }

    /// Set the working directory for the child process.
    pub fn cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
        let dir: PathBuf = cwd.into();
        self.cwd = Some(dir);
        self
    }

    /// Set one environment variable and register its value for redaction.
    pub fn env(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let secret = value.into();
        self.redact_value(&secret);
        self.env.insert(name.into(), secret);
        self
    }

    /// Register `value` for redaction.
    ///
    /// Empty values and values that are already registered are ignored.
    pub fn redact_value(&mut self, value: &str) {
        if value.is_empty() {
            return;
        }
        let already_registered = self.redactions.iter().any(|existing| existing == value);
        if !already_registered {
            self.redactions.push(value.to_owned());
        }
    }

    /// Build a redactor over a copy of the registered values.
    fn redactor(&self) -> Redactor {
        let values = self.redactions.clone();
        Redactor::new(values)
    }
}
/// Debug output that shows environment variable names but replaces every
/// value with `[redacted]`, and reports only the number of redaction entries.
impl fmt::Debug for McpStdioServerSpec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut masked_env: BTreeMap<&str, &str> = BTreeMap::new();
        for name in self.env.keys() {
            masked_env.insert(name.as_str(), "[redacted]");
        }

        let mut out = f.debug_struct("McpStdioServerSpec");
        out.field("name", &self.name);
        out.field("command", &self.command);
        out.field("args", &self.args);
        out.field("cwd", &self.cwd);
        out.field("env", &masked_env);
        out.field("redaction_count", &self.redactions.len());
        out.finish()
    }
}
/// Resolve one explicitly named stdio server from typed MCP config.
///
/// Looks up the first entry in `config.stdio_servers` whose name equals
/// `name` and resolves it with [`resolve_stdio_server`].
///
/// # Errors
///
/// Returns a spawn-phase configuration error when no server with that name
/// exists, or whatever error resolution of the matching entry produces.
pub fn resolve_named_stdio_server(
    config: &McpConfig,
    name: &str,
    workspace_root: impl Into<PathBuf>,
    secret_store: Option<&SecretStore>,
) -> Result<McpStdioServerSpec, McpClientError> {
    let found = config
        .stdio_servers
        .iter()
        .find(|candidate| candidate.name == name);
    match found {
        Some(server) => resolve_stdio_server(server, workspace_root, secret_store),
        None => Err(McpClientError::new(
            name,
            McpPhase::Spawn,
            McpErrorKind::Config(format!("stdio server `{name}` is not configured")),
        )),
    }
}
/// Resolve one typed stdio server into process IO settings without starting it.
///
/// Environment handling:
/// * variables listed in `env.inherit` are copied from the current process
///   when they are set and valid unicode; otherwise they are skipped;
/// * entries in `env.set` are resolved from a literal, another environment
///   variable, or the secret store.
///
/// Every resolved value is registered for redaction on the returned spec.
///
/// # Errors
///
/// Returns a spawn-phase configuration error when a referenced environment
/// variable is unavailable, when a secret is referenced without a secret
/// store, or when the secret store fails to resolve a reference.
pub fn resolve_stdio_server(
    server: &McpStdioServerConfig,
    workspace_root: impl Into<PathBuf>,
    secret_store: Option<&SecretStore>,
) -> Result<McpStdioServerSpec, McpClientError> {
    let config_error = |message: String| {
        McpClientError::new(&server.name, McpPhase::Spawn, McpErrorKind::Config(message))
    };

    let mut spec = McpStdioServerSpec::new(server.name.clone(), server.command.clone())
        .args(server.args.clone());

    // NOTE(review): `workspace_root` is accepted but not used; a configured
    // cwd path is taken as-is and is not resolved against it. Confirm intent.
    let _workspace_root = workspace_root.into();
    spec.cwd = match &server.cwd {
        Some(McpStdioCwdPolicy::Path { path }) => Some(path.clone()),
        Some(McpStdioCwdPolicy::Inherit) | None => None,
    };

    for name in &server.env.inherit {
        if let Ok(value) = env::var(name) {
            spec.redact_value(&value);
            spec.env.insert(name.clone(), value);
        }
    }

    for (name, value) in &server.env.set {
        let resolved = match value {
            McpEnvValue::Literal { value } => value.clone(),
            McpEnvValue::EnvRef { name } => env::var(name).map_err(|err| {
                // `VarError::NotUnicode` prints the raw variable value in its
                // `Display` output. That value may be a secret, so only the
                // failure category is reported here.
                let reason = match err {
                    env::VarError::NotPresent => "environment variable not found",
                    env::VarError::NotUnicode(_) => "environment variable was not valid unicode",
                };
                config_error(format!(
                    "environment variable `{name}` is unavailable: {reason}"
                ))
            })?,
            McpEnvValue::SecretRef { ref_ } => {
                let store = secret_store.ok_or_else(|| {
                    config_error(format!(
                        "secret `{ref_}` requires a configured secret store"
                    ))
                })?;
                store
                    .get(ref_)
                    .map_err(|err| {
                        config_error(format!("failed to resolve secret `{ref_}`: {err}"))
                    })?
                    .into_string()
            }
        };
        spec.redact_value(&resolved);
        spec.env.insert(name.clone(), resolved);
    }

    Ok(spec)
}
/// A running initialized stdio MCP client.
///
/// Owns the child process and the two background tasks that read its stdout
/// and stderr.
pub struct McpStdioClient {
/// Server name used when building errors.
server_name: String,
/// Size and timeout limits applied to this connection.
limits: McpStdioLimits,
/// Redactor applied to error text before it is surfaced.
redactor: Redactor,
/// Bounded stderr diagnostics, filled by the stderr reader task.
diagnostics: Arc<Mutex<BoundedDiagnostics>>,
/// List-change signals recorded by the stdout reader task.
list_changes: Arc<Mutex<BoundedListChanged>>,
/// Child stdin, shared with the stdout reader task; `None` once `shutdown`
/// has taken the handle.
stdin: Arc<Mutex<Option<ChildStdin>>>,
/// Child process handle; taken by `shutdown`.
child: Option<Child>,
/// Events produced by the stdout reader task.
responses: mpsc::Receiver<ReaderEvent>,
/// Handle of the stdout reader task.
reader_task: JoinHandle<()>,
/// Handle of the stderr reader task.
stderr_task: JoinHandle<()>,
/// Id assigned to the next outgoing request; starts at 1.
next_id: u64,
/// Result of the `initialize` request once it has succeeded.
initialized: Option<InitializeResult>,
/// Set when `shutdown` begins, so `Drop` does not also issue a kill.
shutdown_started: bool,
}
impl McpStdioClient {
/// Spawn, initialize, negotiate capabilities, and send notifications/initialized.
///
/// The handshake is bounded by `limits.startup_timeout`. On any handshake
/// failure the child is shut down and the error carries a snapshot of the
/// stderr diagnostics.
///
/// # Errors
///
/// Returns the spawn error, the handshake error, or an initialize-phase
/// timeout whose `elapsed` is measured from before the process was spawned.
pub async fn connect(
    spec: McpStdioServerSpec,
    limits: McpStdioLimits,
) -> Result<Self, McpClientError> {
    let started = Instant::now();
    let mut client = Self::spawn(spec, limits).await?;

    let startup_timeout = client.limits.startup_timeout;
    let handshake = timeout(startup_timeout, client.initialize()).await;

    let failure = match handshake {
        Ok(Ok(())) => return Ok(client),
        Ok(Err(err)) => {
            // Shut down first, then snapshot the diagnostics.
            let _ = client.shutdown().await;
            err.with_diagnostics(client.snapshot_diagnostics().await)
        }
        Err(_) => {
            let elapsed = started.elapsed();
            let diagnostics = client.snapshot_diagnostics().await;
            let err = McpClientError::new(
                &client.server_name,
                McpPhase::Initialize,
                McpErrorKind::Timeout {
                    operation: "startup".to_string(),
                    elapsed,
                },
            )
            .with_diagnostics(diagnostics);
            let _ = client.shutdown().await;
            err
        }
    };
    Err(failure)
}
/// Launch the server process with piped stdio and start its reader tasks.
///
/// The child gets a cleared environment containing only `spec.env`, runs in
/// `spec.cwd` when one is set, and is killed if its handle is dropped. No MCP
/// traffic is exchanged here; the returned client is not yet initialized.
async fn spawn(
    spec: McpStdioServerSpec,
    limits: McpStdioLimits,
) -> Result<Self, McpClientError> {
    let redactor = spec.redactor();

    let mut command = Command::new(&spec.command);
    command
        .args(&spec.args)
        .env_clear()
        .envs(&spec.env)
        .stdin(std::process::Stdio::piped())
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .kill_on_drop(true);
    if let Some(dir) = spec.cwd.as_ref() {
        command.current_dir(dir);
    }

    let mut child = match command.spawn() {
        Ok(child) => child,
        Err(err) => {
            return Err(McpClientError::new(
                &spec.name,
                McpPhase::Spawn,
                McpErrorKind::Io(redactor.redact(&err.to_string())),
            ));
        }
    };

    // All three pipes were requested above; a missing handle is reported as
    // an error instead of panicking.
    let missing_pipe = |stream: &str| {
        McpClientError::new(
            &spec.name,
            McpPhase::Spawn,
            McpErrorKind::Protocol(format!("child {stream} was not piped")),
        )
    };
    let child_stdin = child.stdin.take().ok_or_else(|| missing_pipe("stdin"))?;
    let child_stdout = child.stdout.take().ok_or_else(|| missing_pipe("stdout"))?;
    let child_stderr = child.stderr.take().ok_or_else(|| missing_pipe("stderr"))?;

    let stdin = Arc::new(Mutex::new(Some(child_stdin)));
    let diagnostics = Arc::new(Mutex::new(BoundedDiagnostics::new(
        spec.name.clone(),
        limits.max_diagnostic_lines,
        redactor.clone(),
    )));
    let list_changes = Arc::new(Mutex::new(BoundedListChanged::new(spec.name.clone())));

    let (events_tx, events_rx) = mpsc::channel(16);
    let reader_task = spawn_stdout_reader(
        spec.name.clone(),
        child_stdout,
        Arc::clone(&stdin),
        events_tx,
        limits.clone(),
        redactor.clone(),
        Arc::clone(&list_changes),
    );
    let stderr_task = spawn_stderr_reader(child_stderr, Arc::clone(&diagnostics), limits.clone());

    Ok(Self {
        server_name: spec.name,
        limits,
        redactor,
        diagnostics,
        list_changes,
        stdin,
        child: Some(child),
        responses: events_rx,
        reader_task,
        stderr_task,
        next_id: 1,
        initialized: None,
        shutdown_started: false,
    })
}
/// Run the MCP handshake: send `initialize`, store the result, then send the
/// `notifications/initialized` notification.
///
/// The request advertises an empty capabilities object and this client's
/// name and version.
async fn initialize(&mut self) -> Result<(), McpClientError> {
    let params = json!({
        "protocolVersion": MCP_PROTOCOL_VERSION,
        "capabilities": {},
        "clientInfo": {
            "name": CLIENT_NAME,
            "version": CLIENT_VERSION,
        }
    });
    let handshake: InitializeResult = self
        .request(McpPhase::Initialize, "initialize", params)
        .await?;
    self.initialized = Some(handshake);

    self.write_notification(
        McpPhase::Initialized,
        "notifications/initialized",
        json!({}),
    )
    .await
}

/// The server's `initialize` result, or `None` before the handshake succeeded.
pub fn initialize_result(&self) -> Option<&InitializeResult> {
    self.initialized.as_ref()
}
/// Request one page of the MCP `tools/list` surface after initialization.
///
/// This performs discovery only. It never sends `tools/call` and does not
/// expose resources or prompts. When `cursor` is `None` the request params
/// are an empty object.
pub async fn list_tools_page(
    &mut self,
    cursor: Option<String>,
) -> Result<ListToolsResult, McpClientError> {
    let params = match cursor {
        Some(cursor) => json!({ "cursor": cursor }),
        None => json!({}),
    };
    self.request(McpPhase::Running, "tools/list", params).await
}
/// Execute an initialized MCP `tools/call` request.
///
/// The caller is responsible for applying Yoi tool permissions before this
/// method is reached and for bounding/serializing the untrusted result before
/// it is exposed to model-visible tool history.
///
/// # Errors
///
/// Returns a protocol error if the request cannot be serialized, or any
/// error produced while sending the request and waiting for its response.
pub async fn call_tool(
    &mut self,
    request: CallToolRequest,
) -> Result<CallToolResult, McpClientError> {
    let params = match serde_json::to_value(request) {
        Ok(params) => params,
        Err(err) => {
            return Err(McpClientError::new(
                &self.server_name,
                McpPhase::Running,
                McpErrorKind::Protocol(format!("failed to serialize tools/call request: {err}")),
            ));
        }
    };
    self.request(McpPhase::Running, "tools/call", params).await
}
/// Request one page of the MCP `resources/list` surface after initialization.
///
/// When `cursor` is `None` the request params are an empty object.
pub async fn list_resources_page(
    &mut self,
    cursor: Option<String>,
) -> Result<ListResourcesResult, McpClientError> {
    let params = match cursor {
        Some(cursor) => json!({ "cursor": cursor }),
        None => json!({}),
    };
    self.request(McpPhase::Running, "resources/list", params)
        .await
}
/// Read one MCP resource by URI after initialization.
///
/// # Errors
///
/// Returns a protocol error if the request cannot be serialized, or any
/// error produced while sending the request and waiting for its response.
pub async fn read_resource(
    &mut self,
    request: ReadResourceRequest,
) -> Result<ReadResourceResult, McpClientError> {
    let params = match serde_json::to_value(request) {
        Ok(params) => params,
        Err(err) => {
            return Err(McpClientError::new(
                &self.server_name,
                McpPhase::Running,
                McpErrorKind::Protocol(format!(
                    "failed to serialize resources/read request: {err}"
                )),
            ));
        }
    };
    self.request(McpPhase::Running, "resources/read", params)
        .await
}
/// Request one page of the MCP `prompts/list` surface after initialization.
///
/// When `cursor` is `None` the request params are an empty object.
pub async fn list_prompts_page(
    &mut self,
    cursor: Option<String>,
) -> Result<ListPromptsResult, McpClientError> {
    let params = match cursor {
        Some(cursor) => json!({ "cursor": cursor }),
        None => json!({}),
    };
    self.request(McpPhase::Running, "prompts/list", params)
        .await
}
/// Get one MCP prompt template by name after initialization.
///
/// # Errors
///
/// Returns a protocol error if the request cannot be serialized, or any
/// error produced while sending the request and waiting for its response.
pub async fn get_prompt(
    &mut self,
    request: GetPromptRequest,
) -> Result<GetPromptResult, McpClientError> {
    let params = match serde_json::to_value(request) {
        Ok(params) => params,
        Err(err) => {
            return Err(McpClientError::new(
                &self.server_name,
                McpPhase::Running,
                McpErrorKind::Protocol(format!("failed to serialize prompts/get request: {err}")),
            ));
        }
    };
    self.request(McpPhase::Running, "prompts/get", params).await
}
/// Request pages from `tools/list` up to a host-supplied page/tool bound.
///
/// Bounds are enforced by the host so a server cannot make startup discovery
/// unbounded through pagination.
///
/// The returned result holds the tools of all pages, no cursor, the `_meta`
/// of the last page, and an empty `extra` map.
///
/// # Errors
///
/// Returns a protocol error (with diagnostics) when the server still offers
/// a cursor after `limits.max_pages` pages or when more than
/// `limits.max_tools` tools are returned, and propagates any page request
/// error.
pub async fn list_tools_bounded(
    &mut self,
    limits: McpToolListLimits,
) -> Result<ListToolsResult, McpClientError> {
    let mut collected = Vec::new();
    let mut cursor: Option<String> = None;

    for _ in 0..limits.max_pages {
        let ListToolsResult {
            tools: page_tools,
            next_cursor,
            meta,
            ..
        } = self.list_tools_page(cursor.take()).await?;

        // Reject the page as soon as it would push the total past the cap.
        if collected.len() + page_tools.len() > limits.max_tools {
            return Err(McpClientError::new(
                &self.server_name,
                McpPhase::Running,
                McpErrorKind::Protocol(format!(
                    "tools/list exceeded {} tool(s)",
                    limits.max_tools
                )),
            )
            .with_diagnostics(self.snapshot_diagnostics().await));
        }
        collected.extend(page_tools);

        match next_cursor {
            Some(next) => cursor = Some(next),
            None => {
                return Ok(ListToolsResult {
                    tools: collected,
                    next_cursor: None,
                    meta,
                    extra: BTreeMap::new(),
                });
            }
        }
    }

    // The page budget is used up while the server still offers a cursor
    // (or the budget was zero to begin with).
    Err(McpClientError::new(
        &self.server_name,
        McpPhase::Running,
        McpErrorKind::Protocol(format!(
            "tools/list exceeded {} page(s)",
            limits.max_pages
        )),
    )
    .with_diagnostics(self.snapshot_diagnostics().await))
}
/// Return a snapshot of the stderr diagnostics collected so far.
pub async fn snapshot_diagnostics(&self) -> McpDiagnostics {
    let diagnostics = self.diagnostics.lock().await;
    diagnostics.snapshot()
}

/// Return bounded list-change signals observed so far for this connection.
///
/// This is diagnostic/freshness state only. It never contains notification
/// params and must not be used to mutate an active run's model-visible tool
/// schema outside an explicit safe boundary.
pub async fn snapshot_list_changes(&self) -> McpListChangedSnapshot {
    let observed = self.list_changes.lock().await;
    observed.snapshot()
}

/// Clear observed list-change signals before an explicit safe-boundary
/// refresh. New notifications received after this call will be recorded.
pub async fn clear_list_changes(&self) {
    let mut observed = self.list_changes.lock().await;
    observed.clear();
}
/// Send one JSON-RPC request and wait for its typed result.
///
/// A fresh numeric id is assigned, the request is written to the server's
/// stdin, and the matching response is awaited for at most
/// `limits.request_timeout`.
///
/// # Errors
///
/// Every error returned after the request was written carries a snapshot of
/// the stderr diagnostics:
/// * a timeout when no matching response arrives in time;
/// * a JSON-RPC error when the response carries `error` (message redacted);
/// * a protocol error when the response has no `result` or the result does
///   not deserialize into `T` (deserializer message redacted, because it can
///   quote server-supplied values).
///
/// Write failures and reader failures are propagated as produced.
pub async fn request<T: for<'de> Deserialize<'de>>(
    &mut self,
    phase: McpPhase,
    method: &str,
    params: Value,
) -> Result<T, McpClientError> {
    let id = self.next_id;
    self.next_id += 1;
    let request = ClientRequest {
        jsonrpc: JSONRPC_VERSION,
        id,
        method,
        params,
    };
    self.write_protocol(phase, &request).await?;

    let response = match timeout(
        self.limits.request_timeout,
        self.wait_for_response(id, phase),
    )
    .await
    {
        Ok(result) => result?,
        Err(_) => {
            return Err(McpClientError::new(
                &self.server_name,
                phase,
                McpErrorKind::Timeout {
                    operation: method.to_owned(),
                    elapsed: self.limits.request_timeout,
                },
            )
            .with_diagnostics(self.snapshot_diagnostics().await));
        }
    };

    if let Some(error) = response.error {
        return Err(McpClientError::new(
            &self.server_name,
            phase,
            McpErrorKind::JsonRpcError {
                code: error.code,
                message: self.redactor.redact(&error.message),
            },
        )
        .with_diagnostics(self.snapshot_diagnostics().await));
    }

    let result = match response.result {
        Some(result) => result,
        None => {
            return Err(McpClientError::new(
                &self.server_name,
                phase,
                McpErrorKind::Protocol(format!(
                    "response to `{method}` did not contain result"
                )),
            )
            .with_diagnostics(self.snapshot_diagnostics().await));
        }
    };

    match serde_json::from_value(result) {
        Ok(value) => Ok(value),
        Err(err) => {
            // The deserializer may echo offending values from the server's
            // payload, so its message is redacted like any other server text.
            let detail = self.redactor.redact(&err.to_string());
            Err(McpClientError::new(
                &self.server_name,
                phase,
                McpErrorKind::Protocol(format!("invalid `{method}` result: {detail}")),
            )
            .with_diagnostics(self.snapshot_diagnostics().await))
        }
    }
}
async fn wait_for_response(
&mut self,
id: u64,
phase: McpPhase,
) -> Result<ServerResponse, McpClientError> {
while let Some(event) = self.responses.recv().await {
match event {
ReaderEvent::Response(response) if response.id == id => return Ok(response),
ReaderEvent::Response(_) | ReaderEvent::Notification => continue,
ReaderEvent::Error(err) => return Err(err.with_phase(phase)),
ReaderEvent::Eof => {
return Err(McpClientError::new(
&self.server_name,
phase,
McpErrorKind::Protocol("server stdout closed before response".into()),
)
.with_diagnostics(self.snapshot_diagnostics().await));
}
}
}
Err(McpClientError::new(
&self.server_name,
phase,
McpErrorKind::Protocol("stdout reader stopped before response".into()),
)
.with_diagnostics(self.snapshot_diagnostics().await))
}
/// Write one JSON-RPC notification (a message without an id) to the server.
async fn write_notification<T: Serialize>(
    &mut self,
    phase: McpPhase,
    method: &str,
    params: T,
) -> Result<(), McpClientError> {
    let notification = ClientNotification {
        jsonrpc: JSONRPC_VERSION,
        method,
        params,
    };
    self.write_protocol(phase, &notification).await
}

/// Serialize `value` and write it to the server's stdin, bounded by
/// `limits.max_protocol_bytes`.
async fn write_protocol<T: Serialize>(
    &mut self,
    phase: McpPhase,
    value: &T,
) -> Result<(), McpClientError> {
    let max_bytes = self.limits.max_protocol_bytes;
    write_json_line(
        &self.server_name,
        phase,
        &self.stdin,
        value,
        max_bytes,
        &self.redactor,
    )
    .await
}
/// Close stdin and wait for process exit, falling back to terminate and kill.
///
/// Steps, in order:
/// 1. mark shutdown as started and drop the child's stdin handle;
/// 2. if the child handle was already taken, report "already finished";
/// 3. if the child has already exited, report its status;
/// 4. otherwise wait up to `limits.shutdown_timeout` for exit, and on
///    timeout hand over to `terminate_then_kill`.
///
/// The reader tasks are aborted on every path that returns a report from
/// this function.
///
/// # Errors
///
/// Returns a shutdown-phase I/O error (redacted, with diagnostics) when
/// waiting on the child fails, or whatever `terminate_then_kill` returns.
pub async fn shutdown(&mut self) -> Result<ShutdownReport, McpClientError> {
self.shutdown_started = true;
{
// Taking the handle out of the shared slot drops it at the end of this
// block, which closes our end of the child's stdin.
let mut stdin = self.stdin.lock().await;
stdin.take();
}
let mut child = match self.child.take() {
Some(child) => child,
// A previous shutdown already consumed the child handle.
None => return Ok(ShutdownReport::already_finished()),
};
// Fast path: the process exited before we were asked to shut down.
if let Ok(Some(status)) = child.try_wait() {
self.reader_task.abort();
self.stderr_task.abort();
return Ok(ShutdownReport {
exit_status: Some(status),
terminated: false,
killed: false,
});
}
match timeout(self.limits.shutdown_timeout, child.wait()).await {
Ok(Ok(status)) => {
self.reader_task.abort();
self.stderr_task.abort();
Ok(ShutdownReport {
exit_status: Some(status),
terminated: false,
killed: false,
})
}
Ok(Err(err)) => Err(McpClientError::new(
&self.server_name,
McpPhase::Shutdown,
McpErrorKind::Io(self.redactor.redact(&err.to_string())),
)
.with_diagnostics(self.snapshot_diagnostics().await)),
// The process did not exit on its own within the shutdown timeout.
Err(_) => self.terminate_then_kill(child).await,
}
}
/// Escalate shutdown for a child that did not exit after stdin was closed.
///
/// Calls `send_terminate` (defined elsewhere; presumably a graceful
/// termination signal — confirm), waits up to `limits.kill_timeout`, and if
/// the child is still running starts a kill and waits up to
/// `limits.kill_timeout` again. The report records which steps were taken.
///
/// # Errors
///
/// Returns a shutdown-phase I/O error when waiting or starting the kill
/// fails, or a timeout error with operation `"kill"` when the child has not
/// exited after the kill.
async fn terminate_then_kill(
&mut self,
mut child: Child,
) -> Result<ShutdownReport, McpClientError> {
let mut terminated = false;
let mut killed = false;
// A failed terminate request is not an error; the kill fallback below
// still applies after the wait.
if send_terminate(&mut child).is_ok() {
terminated = true;
}
match timeout(self.limits.kill_timeout, child.wait()).await {
Ok(Ok(status)) => {
self.reader_task.abort();
self.stderr_task.abort();
Ok(ShutdownReport {
exit_status: Some(status),
terminated,
killed,
})
}
Ok(Err(err)) => Err(McpClientError::new(
&self.server_name,
McpPhase::Shutdown,
McpErrorKind::Io(self.redactor.redact(&err.to_string())),
)
.with_diagnostics(self.snapshot_diagnostics().await)),
Err(_) => {
// Still running after the terminate window: kill it.
child.start_kill().map_err(|err| {
McpClientError::new(
&self.server_name,
McpPhase::Shutdown,
McpErrorKind::Io(self.redactor.redact(&err.to_string())),
)
})?;
killed = true;
// The outer `map_err` handles the timeout, the inner one the wait error.
let status = timeout(self.limits.kill_timeout, child.wait())
.await
.map_err(|_| {
McpClientError::new(
&self.server_name,
McpPhase::Shutdown,
McpErrorKind::Timeout {
operation: "kill".to_string(),
elapsed: self.limits.kill_timeout,
},
)
})?
.map_err(|err| {
McpClientError::new(
&self.server_name,
McpPhase::Shutdown,
McpErrorKind::Io(self.redactor.redact(&err.to_string())),
)
})?;
self.reader_task.abort();
self.stderr_task.abort();
Ok(ShutdownReport {
exit_status: Some(status),
terminated,
killed,
})
}
}
}
}
/// Best-effort cleanup: if `shutdown` was never started and the child handle
/// is still held, start killing the child; always abort both reader tasks.
impl Drop for McpStdioClient {
    fn drop(&mut self) {
        let needs_kill = !self.shutdown_started;
        match self.child.as_mut() {
            Some(child) if needs_kill => {
                let _ = child.start_kill();
            }
            _ => {}
        }
        self.reader_task.abort();
        self.stderr_task.abort();
    }
}
/// Lifecycle phase of a stdio MCP connection, attached to every client error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum McpPhase {
    /// Resolving configuration and starting the process.
    Spawn,
    /// Sending the `initialize` request and awaiting its result.
    Initialize,
    /// Sending the `notifications/initialized` notification.
    Initialized,
    /// Normal operation after the handshake.
    Running,
    /// Stopping the process.
    Shutdown,
}

/// Renders the phase as its lowercase name.
impl fmt::Display for McpPhase {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Spawn => "spawn",
            Self::Initialize => "initialize",
            Self::Initialized => "initialized",
            Self::Running => "running",
            Self::Shutdown => "shutdown",
        };
        f.write_str(label)
    }
}
/// Outcome of shutting down a stdio MCP server process.
#[derive(Debug, Clone)]
pub struct ShutdownReport {
    /// Exit status of the process; `None` when no child handle was left to wait on.
    pub exit_status: Option<ExitStatus>,
    /// Whether a terminate request was sent successfully.
    pub terminated: bool,
    /// Whether a kill was started.
    pub killed: bool,
}

impl ShutdownReport {
    /// Report for a client whose child handle was already consumed.
    fn already_finished() -> Self {
        let exit_status = None;
        Self {
            exit_status,
            terminated: false,
            killed: false,
        }
    }
}
/// Result payload of the `initialize` request.
///
/// `protocolVersion` and `serverInfo` are required when deserializing.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResult {
/// Protocol version reported by the server (wire name `protocolVersion`).
pub protocol_version: String,
/// Server capabilities as raw JSON; `Value::Null` when absent.
#[serde(default)]
pub capabilities: Value,
/// Server implementation info (wire name `serverInfo`).
pub server_info: ImplementationInfo,
/// Optional server-provided instructions.
#[serde(default)]
pub instructions: Option<String>,
}
/// Name and version of an MCP implementation, as reported in
/// [`InitializeResult::server_info`]. Both fields are required.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ImplementationInfo {
/// Implementation name.
pub name: String,
/// Implementation version string.
pub version: String,
}
/// Snapshot of stderr diagnostics collected for one server, attached to
/// client errors.
#[derive(Debug, Clone)]
pub struct McpDiagnostics {
/// Name of the server the diagnostics belong to.
pub server_name: String,
/// Retained stderr lines.
pub stderr: Vec<String>,
/// Presumably the number of stderr lines discarded by the bounded buffer —
/// the buffer is defined outside this view; confirm there.
pub dropped_stderr_lines: usize,
/// Presumably the number of stderr lines cut at the per-line byte limit —
/// confirm against the buffer implementation.
pub truncated_stderr_lines: usize,
}
/// Error raised by the stdio MCP client.
///
/// The display form names the server, the lifecycle phase, and the error
/// kind; diagnostics are not part of the display form and are read through
/// `diagnostics()`.
#[derive(Debug, Error, Clone)]
#[error("MCP stdio server `{server_name}` failed during phase `{phase}`: {kind}")]
pub struct McpClientError {
/// Name of the server the error relates to.
pub server_name: String,
/// Lifecycle phase in which the error occurred.
pub phase: McpPhase,
/// What went wrong.
pub kind: McpErrorKind,
/// Optional stderr diagnostics snapshot taken when the error was built.
diagnostics: Option<McpDiagnostics>,
}
impl McpClientError {
    /// Build an error without diagnostics.
    fn new(server_name: impl Into<String>, phase: McpPhase, kind: McpErrorKind) -> Self {
        let server_name = server_name.into();
        Self {
            server_name,
            phase,
            kind,
            diagnostics: None,
        }
    }

    /// Replace the phase, keeping everything else.
    fn with_phase(self, phase: McpPhase) -> Self {
        Self { phase, ..self }
    }

    /// Attach a diagnostics snapshot, replacing any previous one.
    pub fn with_diagnostics(self, diagnostics: McpDiagnostics) -> Self {
        Self {
            diagnostics: Some(diagnostics),
            ..self
        }
    }

    /// Diagnostics attached to this error, if any.
    pub fn diagnostics(&self) -> Option<&McpDiagnostics> {
        self.diagnostics.as_ref()
    }
}
/// Category and detail of a [`McpClientError`].
#[derive(Debug, Error, Clone)]
pub enum McpErrorKind {
/// The server configuration could not be resolved.
#[error("configuration error: {0}")]
Config(String),
/// A process or pipe operation failed; carries the error text.
#[error("I/O error: {0}")]
Io(String),
/// The server (or the transport) violated the expected protocol.
#[error("protocol error: {0}")]
Protocol(String),
/// The server answered a request with a JSON-RPC error object.
#[error("JSON-RPC error {code}: {message}")]
JsonRpcError { code: i64, message: String },
/// An operation did not complete in time.
#[error("timed out during {operation} after {elapsed:?}")]
Timeout {
/// Name of the operation that timed out.
operation: String,
/// Duration reported for the timeout.
elapsed: Duration,
},
}
/// Event sent from the stdout reader task to the client.
#[derive(Debug)]
enum ReaderEvent {
/// A response message with a numeric id.
Response(ServerResponse),
/// A notification was received; its content is not forwarded.
Notification,
/// Reading or decoding stdout failed, or a response had an unusable id.
Error(McpClientError),
/// The server's stdout reached end of file.
Eof,
}
/// Outgoing JSON-RPC request as written to the server's stdin.
#[derive(Debug, Serialize)]
struct ClientRequest<'a> {
/// JSON-RPC version tag.
jsonrpc: &'static str,
/// Request id, used to match the response.
id: u64,
/// Method name.
method: &'a str,
/// Method parameters as raw JSON.
params: Value,
}
/// Outgoing JSON-RPC notification (no id, so no response is expected).
#[derive(Debug, Serialize)]
struct ClientNotification<'a, T> {
/// JSON-RPC version tag.
jsonrpc: &'static str,
/// Notification method name.
method: &'a str,
/// Notification parameters.
params: T,
}
/// Any JSON-RPC message read from the server's stdout, before it is
/// classified as a request, notification, or response.
///
/// Every field is optional so classification can be done on which fields
/// are present.
#[derive(Debug, Deserialize)]
struct IncomingMessage {
/// JSON-RPC version tag; parsed but not inspected.
#[allow(dead_code)]
jsonrpc: Option<String>,
/// Message id; present on requests and responses.
id: Option<Value>,
/// Method name; present on requests and notifications.
method: Option<String>,
/// Result payload of a response.
result: Option<Value>,
/// Error payload of a response.
error: Option<RpcError>,
/// Request or notification params; parsed but deliberately unused.
#[allow(dead_code)]
params: Option<Value>,
}
/// A response from the server with a numeric id.
#[derive(Debug, Deserialize)]
struct ServerResponse {
/// Id of the request this response answers.
id: u64,
/// Result payload, if the request succeeded.
result: Option<Value>,
/// Error payload, if the request failed.
error: Option<RpcError>,
}
/// JSON-RPC error object received from the server.
#[derive(Debug, Deserialize)]
struct RpcError {
/// Numeric error code.
code: i64,
/// Error message.
message: String,
/// Optional additional data; parsed but not used.
#[allow(dead_code)]
data: Option<Value>,
}
/// Outgoing JSON-RPC error response, used to reject server-to-client requests.
#[derive(Debug, Serialize)]
struct ErrorResponse<'a> {
/// JSON-RPC version tag.
jsonrpc: &'static str,
/// Id copied from the request being rejected.
id: Value,
/// Error details.
error: ErrorObject<'a>,
}
/// Error object embedded in an outgoing [`ErrorResponse`].
#[derive(Debug, Serialize)]
struct ErrorObject<'a> {
/// Numeric error code.
code: i64,
/// Error message.
message: &'a str,
}
/// Start the background task that reads protocol lines from the server's
/// stdout.
///
/// Each line is decoded as a JSON-RPC message and routed by
/// `handle_incoming_message`. The task stops after sending one of:
/// * `ReaderEvent::Eof` when stdout reaches end of file;
/// * `ReaderEvent::Error` when a line cannot be read or decoded.
///
/// The decode error text is passed through the redactor before it is
/// surfaced, because the deserializer can quote values from the server's
/// output.
fn spawn_stdout_reader(
    server_name: String,
    stdout: ChildStdout,
    stdin: Arc<Mutex<Option<ChildStdin>>>,
    tx: mpsc::Sender<ReaderEvent>,
    limits: McpStdioLimits,
    redactor: Redactor,
    list_changes: Arc<Mutex<BoundedListChanged>>,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        let mut stdout = BufReader::new(stdout);
        loop {
            let line = match read_protocol_line(&mut stdout, limits.max_stdout_line_bytes).await {
                Ok(Some(line)) => line,
                Ok(None) => {
                    let _ = tx.send(ReaderEvent::Eof).await;
                    break;
                }
                Err(err) => {
                    let _ = tx
                        .send(ReaderEvent::Error(McpClientError::new(
                            &server_name,
                            McpPhase::Running,
                            McpErrorKind::Protocol(err),
                        )))
                        .await;
                    break;
                }
            };

            match serde_json::from_slice::<IncomingMessage>(&line) {
                Ok(message) => {
                    handle_incoming_message(
                        &server_name,
                        &stdin,
                        &tx,
                        &limits,
                        &redactor,
                        &list_changes,
                        message,
                    )
                    .await
                }
                Err(err) => {
                    // A type-mismatch message from the deserializer can echo
                    // the offending value, so redact it like other server text.
                    let detail = redactor.redact(&err.to_string());
                    let _ = tx
                        .send(ReaderEvent::Error(McpClientError::new(
                            &server_name,
                            McpPhase::Running,
                            McpErrorKind::Protocol(format!(
                                "invalid stdout JSON-RPC message: {detail}"
                            )),
                        )))
                        .await;
                    break;
                }
            }
        }
    })
}
async fn handle_incoming_message(
server_name: &str,
stdin: &Arc<Mutex<Option<ChildStdin>>>,
tx: &mpsc::Sender<ReaderEvent>,
limits: &McpStdioLimits,
redactor: &Redactor,
list_changes: &Arc<Mutex<BoundedListChanged>>,
message: IncomingMessage,
) {
if message.method.is_some() && message.id.is_some() {
if let Some(id) = message.id {
let response = ErrorResponse {
jsonrpc: JSONRPC_VERSION,
id,
error: ErrorObject {
code: ERR_METHOD_NOT_FOUND,
message: "server-to-client requests are not supported by this client",
},
};
let _ = write_json_line(
server_name,
McpPhase::Running,
stdin,
&response,
limits.max_protocol_bytes,
redactor,
)
.await;
}
return;
}
if let Some(method) = message.method.as_deref() {
if let Some(kind) = McpListChangedKind::from_notification_method(method) {
list_changes.lock().await.mark(kind);
}
let _ = tx.send(ReaderEvent::Notification).await;
return;
}
if let Some(id) = message.id.as_ref().and_then(Value::as_u64) {
let _ = tx
.send(ReaderEvent::Response(ServerResponse {
id,
result: message.result,
error: message.error,
}))
.await;
return;
}
let _ = tx
.send(ReaderEvent::Error(McpClientError::new(
server_name,
McpPhase::Running,
McpErrorKind::Protocol(
"JSON-RPC response id was missing or not an unsigned integer".into(),
),
)))
.await;
}
/// Set of list-changed notification kinds seen since the last `clear`.
///
/// Kinds are stored in a `BTreeSet`, so repeated notifications of the same
/// kind collapse into a single entry.
#[derive(Debug)]
struct BoundedListChanged {
    /// Name of the server the notifications came from; copied into snapshots.
    server_name: String,
    /// Distinct kinds recorded so far.
    kinds: BTreeSet<McpListChangedKind>,
}

impl BoundedListChanged {
    /// Creates an empty tracker for `server_name`.
    fn new(server_name: String) -> Self {
        let kinds = BTreeSet::new();
        Self { server_name, kinds }
    }

    /// Records that a notification of `kind` was received.
    fn mark(&mut self, kind: McpListChangedKind) {
        // Whether the kind was already present does not matter to callers.
        let _newly_recorded = self.kinds.insert(kind);
    }

    /// Forgets every recorded kind.
    fn clear(&mut self) {
        self.kinds.clear();
    }

    /// Returns an owned copy of the server name and the recorded kinds.
    fn snapshot(&self) -> McpListChangedSnapshot {
        let Self { server_name, kinds } = self;
        McpListChangedSnapshot {
            server_name: server_name.clone(),
            kinds: kinds.clone(),
        }
    }
}
/// Spawns the task that drains the child's stderr into `diagnostics`.
///
/// Lines are read with `read_diagnostic_line`, capped at
/// `limits.max_stderr_line_bytes`, and pushed together with their truncation
/// flag. The task ends at end of stream, or after recording a single
/// `stderr read error: …` line when reading fails.
fn spawn_stderr_reader(
    stderr: ChildStderr,
    diagnostics: Arc<Mutex<BoundedDiagnostics>>,
    limits: McpStdioLimits,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        let mut reader = BufReader::new(stderr);
        let line_cap = limits.max_stderr_line_bytes;
        loop {
            let (text, was_truncated) = match read_diagnostic_line(&mut reader, line_cap).await {
                Ok(Some(entry)) => entry,
                Ok(None) => return,
                Err(err) => {
                    let mut sink = diagnostics.lock().await;
                    sink.push(format!("stderr read error: {err}"), false);
                    return;
                }
            };
            diagnostics.lock().await.push(text, was_truncated);
        }
    })
}
async fn write_json_line<T: Serialize>(
server_name: &str,
phase: McpPhase,
stdin: &Arc<Mutex<Option<ChildStdin>>>,
value: &T,
max_protocol_bytes: usize,
redactor: &Redactor,
) -> Result<(), McpClientError> {
let mut bytes = serde_json::to_vec(value).map_err(|err| {
McpClientError::new(
server_name,
phase,
McpErrorKind::Protocol(format!("failed to encode JSON-RPC message: {err}")),
)
})?;
if bytes.len() > max_protocol_bytes {
return Err(McpClientError::new(
server_name,
phase,
McpErrorKind::Protocol(format!(
"JSON-RPC payload exceeded {max_protocol_bytes} bytes"
)),
));
}
bytes.push(b'\n');
let mut guard = stdin.lock().await;
let Some(stdin) = guard.as_mut() else {
return Err(McpClientError::new(
server_name,
phase,
McpErrorKind::Io("child stdin is closed".into()),
));
};
stdin.write_all(&bytes).await.map_err(|err| {
McpClientError::new(
server_name,
phase,
McpErrorKind::Io(redactor.redact(&err.to_string())),
)
})?;
stdin.flush().await.map_err(|err| {
McpClientError::new(
server_name,
phase,
McpErrorKind::Io(redactor.redact(&err.to_string())),
)
})
}
/// Reads one `\n`-terminated line of protocol bytes, one byte at a time.
///
/// Returns `Ok(None)` when the stream ends before any byte of a new line was
/// read, and `Ok(Some(bytes))` for a complete line or for a final unterminated
/// line at end of stream. The `\n` is never included and a single trailing
/// `\r` is stripped.
///
/// # Errors
///
/// Returns the I/O error text when reading fails, or
/// `stdout line exceeded {max_bytes} bytes` when another byte arrives after
/// `max_bytes` bytes have already been collected for the current line.
async fn read_protocol_line<R: AsyncRead + Unpin>(
    reader: &mut R,
    max_bytes: usize,
) -> Result<Option<Vec<u8>>, String> {
    let mut line = Vec::new();
    let mut scratch = [0u8; 1];
    loop {
        let count = reader
            .read(&mut scratch)
            .await
            .map_err(|err| err.to_string())?;
        match (count, scratch[0]) {
            // End of stream with nothing pending: there is no line to report.
            (0, _) if line.is_empty() => return Ok(None),
            // End of stream with pending bytes, or a newline: the line is complete.
            (0, _) | (_, b'\n') => return Ok(Some(trim_newline(line))),
            (_, next) => {
                if line.len() >= max_bytes {
                    return Err(format!("stdout line exceeded {max_bytes} bytes"));
                }
                line.push(next);
            }
        }
    }
}
/// Reads one `\n`-terminated diagnostic line, keeping at most `max_bytes` of it.
///
/// Unlike `read_protocol_line`, an over-long line is not an error: bytes past
/// the cap are discarded and the returned flag is set to `true`. The kept
/// bytes are decoded lossily as UTF-8 after stripping a single trailing `\r`.
///
/// Returns `Ok(None)` when the stream ends with nothing kept and nothing
/// discarded for the current line.
///
/// # Errors
///
/// Returns the I/O error text when reading fails.
async fn read_diagnostic_line<R: AsyncRead + Unpin>(
    reader: &mut R,
    max_bytes: usize,
) -> Result<Option<(String, bool)>, String> {
    let mut kept = Vec::new();
    let mut truncated = false;
    let mut scratch = [0u8; 1];
    loop {
        let count = reader
            .read(&mut scratch)
            .await
            .map_err(|err| err.to_string())?;
        let at_eof = count == 0;

        if at_eof && kept.is_empty() && !truncated {
            return Ok(None);
        }
        if at_eof || scratch[0] == b'\n' {
            let text = String::from_utf8_lossy(&trim_newline(kept)).into_owned();
            return Ok(Some((text, truncated)));
        }

        if kept.len() >= max_bytes {
            // Keep consuming the line so the next call starts on a fresh one.
            truncated = true;
        } else {
            kept.push(scratch[0]);
        }
    }
}
/// Removes a single trailing carriage return, if present.
///
/// Callers have already consumed the `\n`, so this turns a CRLF-terminated
/// line into its bare content. At most one `\r` is removed.
fn trim_newline(mut buf: Vec<u8>) -> Vec<u8> {
    if buf.ends_with(b"\r") {
        buf.truncate(buf.len() - 1);
    }
    buf
}
/// Rolling window over the most recent stderr lines of one server.
///
/// At most `max_lines` redacted lines are retained; older lines are evicted
/// from the front and counted in `dropped_lines`.
#[derive(Debug)]
struct BoundedDiagnostics {
    /// Name of the server the lines came from; copied into snapshots.
    server_name: String,
    /// Maximum number of lines retained; `0` retains nothing.
    max_lines: usize,
    /// Retained lines, oldest first, already redacted.
    lines: VecDeque<String>,
    /// Number of lines evicted or never stored.
    dropped_lines: usize,
    /// Number of pushed lines that were flagged as truncated.
    truncated_lines: usize,
    /// Applied to every line before it is stored.
    redactor: Redactor,
}

impl BoundedDiagnostics {
    /// Creates an empty window for `server_name` holding up to `max_lines` lines.
    fn new(server_name: String, max_lines: usize, redactor: Redactor) -> Self {
        Self {
            server_name,
            max_lines,
            redactor,
            lines: VecDeque::new(),
            dropped_lines: 0,
            truncated_lines: 0,
        }
    }

    /// Records one stderr line.
    ///
    /// The truncation counter is updated even when the line itself is not
    /// stored. Stored lines are redacted, and truncated ones get a
    /// `… [truncated]` suffix.
    fn push(&mut self, line: String, truncated: bool) {
        if truncated {
            self.truncated_lines += 1;
        }

        match self.max_lines {
            0 => self.dropped_lines += 1,
            capacity => {
                if self.lines.len() == capacity {
                    // Make room by evicting the oldest line.
                    self.lines.pop_front();
                    self.dropped_lines += 1;
                }
                let mut entry = self.redactor.redact(&line);
                if truncated {
                    entry.push_str("… [truncated]");
                }
                self.lines.push_back(entry);
            }
        }
    }

    /// Returns an owned copy of the retained lines and both counters.
    fn snapshot(&self) -> McpDiagnostics {
        let stderr = self.lines.iter().cloned().collect();
        McpDiagnostics {
            server_name: self.server_name.clone(),
            stderr,
            dropped_stderr_lines: self.dropped_lines,
            truncated_stderr_lines: self.truncated_lines,
        }
    }
}
/// Replaces known secret values with `[redacted]` in diagnostic text.
#[derive(Debug, Clone)]
struct Redactor {
    /// Distinct, non-empty secret values, longest first.
    values: Vec<String>,
}

impl Redactor {
    /// Builds a redactor from raw secret values.
    ///
    /// Empty strings are discarded (replacing `""` would insert the marker
    /// between every character). Duplicates are removed wherever they occur,
    /// keeping the first occurrence. The remaining values are stably sorted
    /// longest first so that a secret is replaced before any shorter secret
    /// contained in it.
    fn new(mut values: Vec<String>) -> Self {
        // `Vec::dedup` only drops *consecutive* repeats, and sorting by length
        // alone does not make equal values adjacent, so track seen values
        // explicitly instead.
        let mut seen = BTreeSet::new();
        values.retain(|value| !value.is_empty() && seen.insert(value.clone()));
        values.sort_by_key(|value| std::cmp::Reverse(value.len()));
        Self { values }
    }

    /// Returns `input` with every occurrence of every secret replaced by
    /// `[redacted]`, applying the secrets in stored (longest-first) order.
    fn redact(&self, input: &str) -> String {
        let mut output = input.to_owned();
        for value in &self.values {
            output = output.replace(value, "[redacted]");
        }
        output
    }
}
/// Asks the child to exit by sending it `SIGTERM` (Unix only).
///
/// # Errors
///
/// Returns `Err(())` when the child has no usable process id, or when
/// `kill(2)` reports failure.
#[cfg(unix)]
fn send_terminate(child: &mut Child) -> Result<(), ()> {
    // `kill` gives non-positive pids a special meaning (process groups, or
    // every process the caller may signal), so convert with a checked cast
    // and refuse anything that is not a plain positive pid.
    let pid = child
        .id()
        .and_then(|raw| libc::pid_t::try_from(raw).ok())
        .filter(|pid| *pid > 0)
        .ok_or(())?;
    // SAFETY: `kill` takes two plain integers and dereferences no memory.
    // `pid` is strictly positive, so the signal addresses a single process.
    // NOTE(review): relies on `Child::id` returning `None` once the child has
    // been reaped, so the pid cannot have been recycled — confirm against the
    // tokio documentation.
    let result = unsafe { libc::kill(pid, libc::SIGTERM) };
    if result == 0 { Ok(()) } else { Err(()) }
}
/// Asks the child to exit (non-Unix fallback).
///
/// This build has no `SIGTERM` path, so it calls `Child::start_kill` instead.
/// Any error detail is discarded and reported as `Err(())`.
#[cfg(not(unix))]
fn send_terminate(child: &mut Child) -> Result<(), ()> {
    match child.start_kill() {
        Ok(()) => Ok(()),
        Err(_) => Err(()),
    }
}