From 9a2454037fe2e30d8ac3dade5e5072e2e9e7f9e9 Mon Sep 17 00:00:00 2001 From: Hare Date: Sat, 20 Jun 2026 18:07:21 +0900 Subject: [PATCH] mcp: execute stdio tool calls --- crates/mcp/src/stdio.rs | 64 +++ crates/mcp/tests/fixtures/mock_server.rs | 93 ++++- crates/mcp/tests/stdio_lifecycle.rs | 71 +++- crates/pod/src/feature/mcp.rs | 473 ++++++++++++++++++++++- 4 files changed, 689 insertions(+), 12 deletions(-) diff --git a/crates/mcp/src/stdio.rs b/crates/mcp/src/stdio.rs index 9ed55560..9dfe0979 100644 --- a/crates/mcp/src/stdio.rs +++ b/crates/mcp/src/stdio.rs @@ -99,6 +99,51 @@ pub struct ListToolsResult { pub extra: BTreeMap, } +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CallToolRequest { + pub name: String, + #[serde(default, skip_serializing_if = "Value::is_null")] + pub arguments: Value, +} + +impl CallToolRequest { + pub fn new(name: impl Into, arguments: Value) -> Self { + Self { + name: name.into(), + arguments, + } + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CallToolResult { + #[serde(default)] + pub content: Vec, + #[serde(default)] + pub structured_content: Option, + #[serde(default)] + pub is_error: bool, + #[serde(default, rename = "_meta")] + pub meta: Option, + #[serde(flatten)] + pub extra: BTreeMap, +} + +/// One untrusted MCP `tools/call` content block. +/// +/// The `type` discriminator is kept explicit and all server-owned fields stay +/// data in `fields`; this crate does not turn rich MCP content into hidden host +/// context. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct McpContentBlock { + #[serde(rename = "type")] + pub kind: String, + #[serde(flatten)] + pub fields: BTreeMap, +} + /// A resolved, explicit local stdio MCP server process specification. #[derive(Clone)] pub struct McpStdioServerSpec { @@ -426,6 +471,25 @@ impl McpStdioClient { self.request(McpPhase::Running, "tools/list", params).await } + /// Execute an initialized MCP `tools/call` request. + /// + /// The caller is responsible for applying Yoi tool permissions before this + /// method is reached and for bounding/serializing the untrusted result before + /// it is exposed to model-visible tool history. + pub async fn call_tool( + &mut self, + request: CallToolRequest, + ) -> Result { + let params = serde_json::to_value(request).map_err(|err| { + McpClientError::new( + &self.server_name, + McpPhase::Running, + McpErrorKind::Protocol(format!("failed to serialize tools/call request: {err}")), + ) + })?; + self.request(McpPhase::Running, "tools/call", params).await + } + /// Request pages from `tools/list` up to a host-supplied page/tool bound. /// /// Bounds are enforced by the host so a server cannot make startup discovery diff --git a/crates/mcp/tests/fixtures/mock_server.rs b/crates/mcp/tests/fixtures/mock_server.rs index de986b8c..fcb5c8dd 100644 --- a/crates/mcp/tests/fixtures/mock_server.rs +++ b/crates/mcp/tests/fixtures/mock_server.rs @@ -10,7 +10,10 @@ fn main() { match mode.as_str() { "success" => success(), "tools" => tools_list(), - "tools-call-forbidden" => tools_list(), + "tools-call-normal" => tools_call_normal(), + "tools-call-is-error" => tools_call_is_error(), + "tools-call-protocol-error" => tools_call_protocol_error(), + "tools-call-forbidden" => tools_call_forbidden(), "fail-init" => fail_init(), "sampling" => sampling_request(), "shutdown-hang" => shutdown_hang(), @@ -96,6 +99,94 @@ fn tools_list() { } } +fn tools_call_normal() { + tools_call(|request| { + assert_eq!(request["params"]["name"], "search-files"); + assert_eq!(request["params"]["arguments"]["query"], "needle"); + json!({ + "jsonrpc": "2.0", + "id": request["id"], + "result": { + "content": [{"type": "text", "text": "found needle"}], + "structuredContent": {"matches": ["needle.rs"]}, + "_meta": {"server": "mock"} + } + }) + }); +} + +fn tools_call_is_error() { + tools_call(|request| { + assert_eq!(request["params"]["name"], "search-files"); + json!({ + "jsonrpc": "2.0", + "id": request["id"], + "result": { + "isError": true, + "content": [{"type": "text", "text": "tool-level failure"}] + } + }) + }); +} + +fn tools_call_protocol_error() { + tools_call(|request| { + json!({ + "jsonrpc": "2.0", + "id": request["id"], + "error": {"code": -32010, "message": "server refused tools/call"} + }) + }); +} + +fn tools_call_forbidden() { + let init = read_json(); + assert_eq!(init["method"], "initialize"); + write_json(json!({ + "jsonrpc": "2.0", + "id": init["id"], + "result": initialize_result(), + })); + let initialized = read_json(); + assert_eq!(initialized["method"], "notifications/initialized"); + + loop { + let request = read_json(); + assert_ne!( + request["method"], "tools/call", + "permission denial path must not send MCP tools/call" + ); + if request["method"] == "shutdown" { + write_json(json!({"jsonrpc":"2.0", "id": request["id"], "result": {}})); + let notification = read_json(); + assert_eq!(notification["method"], "exit"); + break; + } + } +} + +fn tools_call(response: impl FnOnce(&Value) -> Value) { + let init = read_json(); + assert_eq!(init["method"], "initialize"); + write_json(json!({ + "jsonrpc": "2.0", + "id": init["id"], + "result": initialize_result(), + })); + let initialized = read_json(); + assert_eq!(initialized["method"], "notifications/initialized"); + + let call = read_json(); + assert_eq!(call["method"], "tools/call"); + write_json(response(&call)); + + let shutdown = read_json(); + assert_eq!(shutdown["method"], "shutdown"); + write_json(json!({"jsonrpc":"2.0", "id": shutdown["id"], "result": {}})); + let notification = read_json(); + assert_eq!(notification["method"], "exit"); +} + fn fail_init() { let secret = env::var("MCP_TEST_SECRET").unwrap_or_default(); for idx in 0..5 { diff --git a/crates/mcp/tests/stdio_lifecycle.rs b/crates/mcp/tests/stdio_lifecycle.rs index 44c9d3ef..d7220a16 100644 --- a/crates/mcp/tests/stdio_lifecycle.rs +++ b/crates/mcp/tests/stdio_lifecycle.rs @@ -1,7 +1,8 @@ use std::time::Duration; use mcp::stdio::{ - McpErrorKind, McpPhase, McpStdioClient, McpStdioLimits, McpStdioServerSpec, McpToolListLimits, + CallToolRequest, McpErrorKind, McpPhase, McpStdioClient, McpStdioLimits, McpStdioServerSpec, + McpToolListLimits, }; fn mock_server(mode: &str) -> McpStdioServerSpec { @@ -161,6 +162,74 @@ async fn initialize_failure_reports_server_phase_and_redacted_bounded_stderr() { ); } +#[tokio::test] +async fn call_tool_returns_normal_result() { + let mut client = McpStdioClient::connect(mock_server("tools-call-normal"), tight_limits()) + .await + .expect("connect"); + let result = client + .call_tool(CallToolRequest::new( + "search-files", + serde_json::json!({"query": "needle"}), + )) + .await + .expect("call tool"); + assert!(!result.is_error); + assert_eq!(result.content.len(), 1); + assert_eq!(result.content[0].kind, "text"); + assert_eq!(result.content[0].fields["text"], "found needle"); + assert_eq!( + result.structured_content.as_ref().unwrap()["matches"][0], + "needle.rs" + ); + assert_eq!(result.meta.as_ref().unwrap()["server"], "mock"); + client.shutdown().await.expect("shutdown"); +} + +#[tokio::test] +async fn call_tool_preserves_mcp_is_error_result() { + let mut client = McpStdioClient::connect(mock_server("tools-call-is-error"), tight_limits()) + .await + .expect("connect"); + let result = client + .call_tool(CallToolRequest::new( + "search-files", + serde_json::json!({"query": "needle"}), + )) + .await + .expect("call tool"); + assert!(result.is_error); + assert_eq!(result.content[0].fields["text"], "tool-level failure"); + client.shutdown().await.expect("shutdown"); +} + +#[tokio::test] +async fn call_tool_reports_json_rpc_protocol_error_distinctly() { + let mut client = + McpStdioClient::connect(mock_server("tools-call-protocol-error"), tight_limits()) + .await + .expect("connect"); + let err = client + .call_tool(CallToolRequest::new( + "search-files", + serde_json::json!({"query": "needle"}), + )) + .await + .expect_err("protocol error"); + assert!(matches!(err.kind, McpErrorKind::JsonRpcError { .. })); + client.shutdown().await.expect("shutdown"); +} + +#[tokio::test] +async fn permission_denial_style_shutdown_sends_no_tools_call() { + let mut client = McpStdioClient::connect(mock_server("tools-call-forbidden"), tight_limits()) + .await + .expect("connect"); + // This mirrors Worker pre-tool-call denial: the ordinary Tool execution body + // is never entered, so the MCP server sees lifecycle shutdown but no call. + client.shutdown().await.expect("shutdown"); +} + #[tokio::test] async fn shutdown_terminates_or_kills_uncooperative_server() { let mut client = McpStdioClient::connect(mock_server("shutdown-hang"), tight_limits()) diff --git a/crates/pod/src/feature/mcp.rs b/crates/pod/src/feature/mcp.rs index 0a927902..15580be5 100644 --- a/crates/pod/src/feature/mcp.rs +++ b/crates/pod/src/feature/mcp.rs @@ -8,7 +8,8 @@ use llm_worker::tool::{ }; use manifest::McpConfig; use mcp::stdio::{ - ListToolsResult, McpStdioClient, McpStdioLimits, McpStdioServerSpec, McpToolDefinition, + CallToolRequest, CallToolResult, ListToolsResult, McpClientError, McpContentBlock, + McpErrorKind, McpStdioClient, McpStdioLimits, McpStdioServerSpec, McpToolDefinition, McpToolListLimits, resolve_stdio_server, }; use serde_json::{Map, Value}; @@ -29,6 +30,12 @@ const MAX_SCHEMA_STRING_CHARS: usize = 4096; const MAX_DIAGNOSTIC_CHARS: usize = 512; const MAX_TOOL_PAGES: usize = 8; const MAX_TOOLS_PER_SERVER: usize = 128; +const MAX_RESULT_CONTENT_BLOCKS: usize = 16; +const MAX_RESULT_TEXT_CHARS: usize = 8192; +const MAX_RESULT_JSON_DEPTH: usize = 12; +const MAX_RESULT_JSON_NODES: usize = 512; +const MAX_RESULT_STRING_CHARS: usize = 4096; +const MAX_RESULT_OUTPUT_BYTES: usize = 64 * 1024; /// Discover enabled MCP stdio server tools and return a single feature module /// containing startup contributions for normal ToolRegistry installation. @@ -62,6 +69,7 @@ async fn discover_server_tools(spec: McpStdioServerSpec) -> ProtocolProviderCont let declaration = provider_declaration(&spec.name, None); let mut contribution = ProtocolProviderContribution::ready(declaration.clone()); let server_namespace = sanitize_segment(&spec.name); + let execution_spec = spec.clone(); let mut client = match McpStdioClient::connect(spec, McpStdioLimits::default()).await { Ok(client) => client, @@ -122,6 +130,7 @@ async fn discover_server_tools(spec: McpStdioServerSpec) -> ProtocolProviderCont } contribution = normalize_listed_tools( + execution_spec, contribution, declaration, server_namespace, @@ -132,6 +141,7 @@ async fn discover_server_tools(spec: McpStdioServerSpec) -> ProtocolProviderCont } fn normalize_listed_tools( + execution_spec: McpStdioServerSpec, mut contribution: ProtocolProviderContribution, declaration: ProtocolProviderDeclaration, server_namespace: String, @@ -143,6 +153,7 @@ fn normalize_listed_tools( for tool in list.tools { match mcp_tool_contribution( + execution_spec.clone(), &declaration, &server_namespace, server_version.as_deref(), @@ -177,11 +188,13 @@ fn normalize_listed_tools( } fn mcp_tool_contribution( + execution_spec: McpStdioServerSpec, declaration: &ProtocolProviderDeclaration, server_namespace: &str, server_version: Option<&str>, tool: McpToolDefinition, ) -> Result<(String, ToolContribution), String> { + let mcp_tool_name = tool.name.clone(); let tool_segment = sanitize_segment(&tool.name); if tool_segment == "unnamed" { return Err(bounded_diagnostic( @@ -215,13 +228,18 @@ fn mcp_tool_contribution( let description = description.clone(); let schema = schema.clone(); let origin = origin.clone(); + let execution_spec = execution_spec.clone(); + let mcp_tool_name = mcp_tool_name.clone(); move || { ( ToolMeta::new(name.clone()) .description(description.clone()) .input_schema(schema.clone()) .origin(origin.clone()), - Arc::new(McpDiscoveryOnlyTool) as Arc, + Arc::new(McpStdioTool { + server_spec: execution_spec.clone(), + mcp_tool_name: mcp_tool_name.clone(), + }) as Arc, ) } }); @@ -232,22 +250,336 @@ fn mcp_tool_contribution( } #[derive(Debug)] -struct McpDiscoveryOnlyTool; +struct McpStdioTool { + server_spec: McpStdioServerSpec, + mcp_tool_name: String, +} #[async_trait] -impl Tool for McpDiscoveryOnlyTool { +impl Tool for McpStdioTool { async fn execute( &self, - _input_json: &str, + input_json: &str, _ctx: ToolExecutionContext, ) -> Result { - Err(ToolError::ExecutionFailed( - "MCP tool execution is not implemented in this release; registration is discovery-only" - .to_string(), - )) + let arguments = parse_tool_arguments(input_json)?; + let mut client = + McpStdioClient::connect(self.server_spec.clone(), McpStdioLimits::default()) + .await + .map_err(|err| ToolError::ExecutionFailed(mcp_call_error_message(&err)))?; + + let call_result = client + .call_tool(CallToolRequest::new(self.mcp_tool_name.clone(), arguments)) + .await; + let shutdown_result = client.shutdown().await; + + match call_result { + Ok(result) => { + let mut output = render_call_tool_result(&self.mcp_tool_name, result)?; + if let Err(err) = shutdown_result { + let warning = bounded_diagnostic(format!( + "MCP server shutdown after tools/call failed: {err}" + )); + output.summary.push_str("; shutdown warning recorded"); + output.content = Some(match output.content.take() { + Some(content) => format!("{content}\n\nShutdown warning: {warning}"), + None => format!("Shutdown warning: {warning}"), + }); + } + Ok(output) + } + Err(err) => { + let mut message = mcp_call_error_message(&err); + if let Err(shutdown_err) = shutdown_result { + message.push_str("; shutdown after failure also failed: "); + message.push_str(&bounded_diagnostic(shutdown_err.to_string())); + } + Err(ToolError::ExecutionFailed(message)) + } + } } } +fn parse_tool_arguments(input_json: &str) -> Result { + let input = input_json.trim(); + if input.is_empty() { + return Ok(Value::Object(Map::new())); + } + let value: Value = serde_json::from_str(input).map_err(|err| { + ToolError::InvalidArgument(format!("invalid MCP tool arguments JSON: {err}")) + })?; + Ok(match value { + Value::Null => Value::Object(Map::new()), + other => other, + }) +} + +fn mcp_call_error_message(err: &McpClientError) -> String { + match &err.kind { + McpErrorKind::JsonRpcError { .. } => { + format!("MCP tools/call JSON-RPC protocol error: {err}") + } + _ => format!("MCP tools/call transport/protocol failure: {err}"), + } +} + +fn render_call_tool_result( + mcp_tool_name: &str, + result: CallToolResult, +) -> Result { + let mut truncated = false; + let omitted_blocks = result + .content + .len() + .saturating_sub(MAX_RESULT_CONTENT_BLOCKS); + let blocks: Vec = result + .content + .iter() + .take(MAX_RESULT_CONTENT_BLOCKS) + .map(|block| serialize_content_block(block, &mut truncated)) + .collect(); + + let mut budget = ResultJsonBudget { + nodes: 0, + truncated: false, + }; + let structured_content = result + .structured_content + .as_ref() + .map(|value| bound_result_json(value, 0, &mut budget)); + let meta = result + .meta + .as_ref() + .map(|value| bound_result_json(value, 0, &mut budget)); + let extra = if result.extra.is_empty() { + None + } else { + Some(bound_result_json( + &Value::Object(result.extra.into_iter().collect()), + 0, + &mut budget, + )) + }; + truncated |= budget.truncated || omitted_blocks > 0; + + let status = if result.is_error { + "mcp_is_error" + } else { + "ok" + }; + let mut root = Map::new(); + root.insert("untrusted_mcp_tools_call_result".into(), Value::Bool(true)); + root.insert( + "tool".into(), + Value::String(bounded_plain_text(mcp_tool_name, 256)), + ); + root.insert("status".into(), Value::String(status.to_string())); + root.insert("isError".into(), Value::Bool(result.is_error)); + root.insert("content".into(), Value::Array(blocks)); + if omitted_blocks > 0 { + root.insert("omittedContentBlocks".into(), Value::from(omitted_blocks)); + } + if let Some(value) = structured_content { + root.insert("structuredContent".into(), value); + } + if let Some(value) = meta { + root.insert("_meta".into(), value); + } + if let Some(value) = extra { + root.insert("extra".into(), value); + } + if truncated { + root.insert("truncated".into(), Value::Bool(true)); + } + + let mut content = serde_json::to_string_pretty(&Value::Object(root)).map_err(|err| { + ToolError::ExecutionFailed(format!("failed to serialize MCP tools/call result: {err}")) + })?; + if content.len() > MAX_RESULT_OUTPUT_BYTES { + truncate_utf8(&mut content, MAX_RESULT_OUTPUT_BYTES); + truncated = true; + } + + let status_label = if result.is_error { + "MCP isError=true" + } else { + "success" + }; + let mut summary = format!( + "MCP tool `{}` returned {status_label} ({} content block(s)", + bounded_plain_text(mcp_tool_name, 96), + result.content.len() + ); + if result.structured_content.is_some() { + summary.push_str(", structuredContent"); + } + if result.meta.is_some() { + summary.push_str(", _meta"); + } + if truncated { + summary.push_str(", truncated"); + } + summary.push(')'); + + Ok(ToolOutput { + summary, + content: Some(content), + }) +} + +fn serialize_content_block(block: &McpContentBlock, truncated: &mut bool) -> Value { + let mut out = Map::new(); + out.insert( + "type".to_string(), + Value::String(bounded_plain_text(&block.kind, 64)), + ); + match block.kind.as_str() { + "text" => { + if let Some(text) = block.fields.get("text").and_then(Value::as_str) { + out.insert( + "text".to_string(), + Value::String(bounded_text_field(text, MAX_RESULT_TEXT_CHARS, truncated)), + ); + } + } + "image" | "audio" => { + copy_bounded_field(&mut out, block, "mimeType", truncated); + if let Some(data) = block.fields.get("data").and_then(Value::as_str) { + out.insert("dataBytes".to_string(), Value::from(data.len())); + out.insert("dataOmitted".to_string(), Value::Bool(true)); + *truncated = true; + } + } + "resource_link" => { + for key in ["uri", "name", "title", "description", "mimeType"] { + copy_bounded_field(&mut out, block, key, truncated); + } + } + "resource" => { + if let Some(resource) = block.fields.get("resource") { + let mut budget = ResultJsonBudget { + nodes: 0, + truncated: false, + }; + out.insert( + "resource".to_string(), + bound_result_json(resource, 0, &mut budget), + ); + *truncated |= budget.truncated; + } + } + _ => { + let mut budget = ResultJsonBudget { + nodes: 0, + truncated: false, + }; + out.insert( + "fields".to_string(), + bound_result_json( + &Value::Object(block.fields.clone().into_iter().collect()), + 0, + &mut budget, + ), + ); + *truncated |= budget.truncated; + } + } + Value::Object(out) +} + +fn copy_bounded_field( + out: &mut Map, + block: &McpContentBlock, + key: &str, + truncated: &mut bool, +) { + if let Some(value) = block.fields.get(key) { + let mut budget = ResultJsonBudget { + nodes: 0, + truncated: false, + }; + out.insert(key.to_string(), bound_result_json(value, 0, &mut budget)); + *truncated |= budget.truncated; + } +} + +struct ResultJsonBudget { + nodes: usize, + truncated: bool, +} + +fn bound_result_json(value: &Value, depth: usize, budget: &mut ResultJsonBudget) -> Value { + budget.nodes += 1; + if depth > MAX_RESULT_JSON_DEPTH || budget.nodes > MAX_RESULT_JSON_NODES { + budget.truncated = true; + return Value::String("[truncated: MCP JSON result bounds exceeded]".to_string()); + } + match value { + Value::Null | Value::Bool(_) | Value::Number(_) => value.clone(), + Value::String(text) => Value::String(bounded_text_field( + text, + MAX_RESULT_STRING_CHARS, + &mut budget.truncated, + )), + Value::Array(values) => { + let remaining = MAX_RESULT_JSON_NODES.saturating_sub(budget.nodes).max(1); + if values.len() > remaining { + budget.truncated = true; + } + Value::Array( + values + .iter() + .take(remaining) + .map(|item| bound_result_json(item, depth + 1, budget)) + .collect(), + ) + } + Value::Object(map) => { + let mut out = Map::new(); + for (key, value) in map { + if budget.nodes > MAX_RESULT_JSON_NODES { + budget.truncated = true; + break; + } + out.insert( + bounded_plain_text(key, 128), + bound_result_json(value, depth + 1, budget), + ); + } + Value::Object(out) + } + } +} + +fn bounded_text_field(input: &str, max_chars: usize, truncated: &mut bool) -> String { + let total = input.chars().count(); + if total <= max_chars { + input.to_string() + } else { + *truncated = true; + let mut output: String = input.chars().take(max_chars).collect(); + output.push_str(&format!( + "\n[truncated: {} chars omitted]", + total - max_chars + )); + output + } +} + +fn truncate_utf8(input: &mut String, max_bytes: usize) { + if input.len() <= max_bytes { + return; + } + let marker = format!("\n[truncated: {} bytes omitted]", input.len() - max_bytes); + let keep = max_bytes.saturating_sub(marker.len()); + let mut boundary = keep; + while !input.is_char_boundary(boundary) { + boundary -= 1; + } + input.truncate(boundary); + input.push_str(&marker); +} + fn provider_declaration(name: &str, version: Option<&str>) -> ProtocolProviderDeclaration { ProtocolProviderDeclaration::new( ProviderId::new(format!("mcp:stdio:{}", sanitize_segment(name))) @@ -282,7 +614,7 @@ impl McpStdioToolFeature { impl FeatureModule for McpStdioToolFeature { fn descriptor(&self) -> FeatureDescriptor { let mut descriptor = FeatureDescriptor::builtin(FEATURE_ID, "MCP stdio tools") - .with_description("Discovery-only MCP stdio tool registration"); + .with_description("MCP stdio tool discovery and ordinary tool execution"); descriptor.runtime = FeatureRuntimeKind::ProtocolProvider; for contribution in &self.contributions { descriptor = descriptor.with_protocol_provider(contribution.declaration.clone()); @@ -477,6 +809,10 @@ mod tests { use crate::feature::{FeatureDiagnosticSeverity, FeatureRegistryBuilder}; use crate::hook::HookRegistryBuilder; + fn server_spec() -> McpStdioServerSpec { + McpStdioServerSpec::new("demo", "mock-mcp-server") + } + fn mcp_tool(name: &str, description: &str, schema: Value) -> McpToolDefinition { McpToolDefinition { name: name.to_string(), @@ -494,6 +830,7 @@ mod tests { fn valid_mcp_tool_normalizes_to_model_visible_definition() { let declaration = provider_declaration("demo server", Some("1.2.3")); let (name, contribution) = mcp_tool_contribution( + server_spec(), &declaration, "demo_server", Some("1.2.3"), @@ -523,6 +860,7 @@ mod tests { fn valid_mcp_tool_installs_as_pending_model_visible_tool() { let declaration = provider_declaration("demo", Some("1.0.0")); let (_, tool) = mcp_tool_contribution( + server_spec(), &declaration, "demo", Some("1.0.0"), @@ -554,6 +892,7 @@ mod tests { fn invalid_schema_is_rejected_with_bounded_diagnostic() { let declaration = provider_declaration("demo", None); let error = match mcp_tool_contribution( + server_spec(), &declaration, "demo", None, @@ -580,6 +919,7 @@ mod tests { extra: BTreeMap::new(), }; let contribution = normalize_listed_tools( + server_spec(), ProtocolProviderContribution::ready(declaration.clone()), declaration, "demo".to_string(), @@ -613,6 +953,119 @@ mod tests { assert!(names.iter().any(|name| name == "Mcp_demo_unique")); } + fn shell_tool_server(response: &str) -> McpStdioServerSpec { + let script = format!( + r#"read init || exit 1 +printf '%s\n' '{{"jsonrpc":"2.0","id":1,"result":{{"protocolVersion":"2025-06-18","capabilities":{{"tools":{{}}}},"serverInfo":{{"name":"shell-mock","version":"1"}}}}}}' +read initialized || exit 1 +read call || exit 1 +case "$call" in *'"method":"tools/call"'*|*'"method": "tools/call"'*) ;; *) echo "expected tools/call, got $call" >&2; exit 2;; esac +printf '%s\n' '{}' +read shutdown || exit 1 +printf '%s\n' '{{"jsonrpc":"2.0","id":3,"result":{{}}}}' +read exit_notification || true +"#, + response.replace('\\', "\\\\").replace('\'', "'\\''") + ); + McpStdioServerSpec::new("shell-mock", "/bin/sh").args(["-c".to_string(), script]) + } + + #[tokio::test] + async fn stdio_tool_execute_returns_normal_result_through_tool_output() { + let response = r#"{"jsonrpc":"2.0","id":2,"result":{"content":[{"type":"text","text":"ordinary result"}],"structuredContent":{"ok":true}}}"#; + let tool = McpStdioTool { + server_spec: shell_tool_server(response), + mcp_tool_name: "demo-tool".to_string(), + }; + let output = tool + .execute(r#"{"query":"needle"}"#, ToolExecutionContext::direct()) + .await + .expect("execute"); + assert!(output.summary.contains("returned success")); + let content = output.content.unwrap(); + assert!(content.contains("untrusted_mcp_tools_call_result")); + assert!(content.contains("ordinary result")); + assert!(content.contains("structuredContent")); + } + + #[tokio::test] + async fn stdio_tool_execute_reports_protocol_failure_distinctly() { + let response = r#"{"jsonrpc":"2.0","id":2,"error":{"code":-32000,"message":"boom"}}"#; + let tool = McpStdioTool { + server_spec: shell_tool_server(response), + mcp_tool_name: "demo-tool".to_string(), + }; + let err = tool + .execute(r#"{}"#, ToolExecutionContext::direct()) + .await + .expect_err("protocol error"); + assert!( + err.to_string() + .contains("MCP tools/call JSON-RPC protocol error") + ); + } + + #[test] + fn call_tool_result_renderer_marks_untrusted_mcp_error() { + let result = CallToolResult { + content: vec![McpContentBlock { + kind: "text".to_string(), + fields: BTreeMap::from([( + "text".to_string(), + Value::String("tool-level failure".to_string()), + )]), + }], + structured_content: Some(json!({"diagnostic": "visible"})), + is_error: true, + meta: Some(json!({"trace": "metadata"})), + extra: BTreeMap::new(), + }; + let output = render_call_tool_result("search-files", result).expect("render"); + assert!(output.summary.contains("MCP isError=true")); + let content = output.content.unwrap(); + assert!(content.contains("untrusted_mcp_tools_call_result")); + assert!(content.contains("\"status\": \"mcp_is_error\"")); + assert!(content.contains("tool-level failure")); + assert!(content.contains("structuredContent")); + assert!(content.contains("_meta")); + } + + #[test] + fn call_tool_result_renderer_bounds_rich_outputs() { + let result = CallToolResult { + content: vec![ + McpContentBlock { + kind: "text".to_string(), + fields: BTreeMap::from([( + "text".to_string(), + Value::String("x".repeat(MAX_RESULT_TEXT_CHARS + 128)), + )]), + }, + McpContentBlock { + kind: "image".to_string(), + fields: BTreeMap::from([ + ( + "mimeType".to_string(), + Value::String("image/png".to_string()), + ), + ("data".to_string(), Value::String("A".repeat(1024))), + ]), + }, + ], + structured_content: Some(json!({"long": "y".repeat(MAX_RESULT_STRING_CHARS + 64)})), + is_error: false, + meta: None, + extra: BTreeMap::new(), + }; + let output = render_call_tool_result("rich", result).expect("render"); + assert!(output.summary.contains("truncated")); + let content = output.content.unwrap(); + assert!(content.len() <= MAX_RESULT_OUTPUT_BYTES + 128); + assert!(content.contains("dataOmitted")); + assert!(content.contains("truncated")); + assert!(!content.contains(&"A".repeat(512))); + } + #[test] fn schema_references_are_rejected() { let error = normalize_input_schema(json!({