diff --git a/Cargo.lock b/Cargo.lock index a2d29dc4..7d492b65 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2885,6 +2885,7 @@ dependencies = [ "dotenv", "fs4", "futures", + "futures-util", "include_dir", "libc", "llm-worker", @@ -2906,6 +2907,7 @@ dependencies = [ "thiserror 2.0.18", "ticket", "tokio", + "tokio-tungstenite", "toml", "tools", "tracing", @@ -4471,6 +4473,20 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" +dependencies = [ + "futures-util", + "log", + "native-tls", + "tokio", + "tokio-native-tls", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.18" diff --git a/crates/pod/Cargo.toml b/crates/pod/Cargo.toml index 36cea441..f5ddde35 100644 --- a/crates/pod/Cargo.toml +++ b/crates/pod/Cargo.toml @@ -40,6 +40,8 @@ arc-swap = "1.9.1" wasmi = { version = "0.51.1", default-features = false, features = ["std", "extra-checks"] } wasmtime = { version = "45.0.2", default-features = false, features = ["std", "runtime", "cranelift", "component-model"] } tungstenite = { version = "0.28.0", default-features = false, features = ["handshake", "native-tls", "url"] } +tokio-tungstenite = { version = "0.28.0", default-features = false, features = ["native-tls", "connect"] } +futures-util = { version = "0.3", features = ["sink"] } [dev-dependencies] dotenv = "0.15.0" diff --git a/crates/pod/src/feature/plugin.rs b/crates/pod/src/feature/plugin.rs index 1584e283..60306f19 100644 --- a/crates/pod/src/feature/plugin.rs +++ b/crates/pod/src/feature/plugin.rs @@ -17,6 +17,7 @@ use std::sync::{Arc, Mutex, OnceLock}; use std::time::{Duration, Instant}; use async_trait::async_trait; +use futures_util::{SinkExt, StreamExt}; use llm_worker::tool::{ Tool, ToolDefinition, ToolError, ToolExecutionContext, ToolMeta, ToolOrigin, ToolOutput, }; @@ -29,9 +30,11 @@ use manifest::plugin::{ }; use serde::{Deserialize, Serialize}; use serde_json::Value; -use tungstenite::client::IntoClientRequest; -use tungstenite::protocol::{Message, WebSocketConfig}; -use tungstenite::stream::MaybeTlsStream; +use tokio::runtime::{ + Builder as TokioRuntimeBuilder, Handle as TokioHandle, Runtime as TokioRuntime, +}; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_tungstenite::tungstenite::protocol::{Message, WebSocketConfig}; use super::{ FeatureDescriptor, FeatureId, FeatureInstallContext, FeatureInstallError, FeatureModule, @@ -927,8 +930,14 @@ fn execute_plugin_websocket_open( ) -> Result, PluginWebSocketError> { let (request, url) = validate_plugin_websocket_open_request(record, bytes)?; let limits = PluginWebSocketLimits::default(); + if !client.supports_bounded_open() { + return Err(PluginWebSocketError::new( + "host_api.websocket client cannot guarantee bounded/cancellable open; refusing to dial", + )); + } + let reservation = handles.reserve_open()?; let connection = client.open(&request, &url, limits)?; - let handle = handles.insert(connection)?; + let handle = reservation.commit(connection)?; serde_json::to_vec(&PluginWebSocketOpenResponse { handle, url: safe_url(&url), @@ -2511,6 +2520,8 @@ trait PluginWebSocketConnection: Send { } trait PluginWebSocketClient: Send + Sync { + fn supports_bounded_open(&self) -> bool; + fn open( &self, request: &PluginWebSocketOpenRequest, @@ -2521,13 +2532,19 @@ trait PluginWebSocketClient: Send + Sync { struct TungstenitePluginWebSocketClient; -type SystemWebSocket = tungstenite::WebSocket>; +type AsyncSystemWebSocket = + tokio_tungstenite::WebSocketStream>; struct TungstenitePluginWebSocketConnection { - socket: SystemWebSocket, + runtime: TokioRuntime, + socket: AsyncSystemWebSocket, } impl PluginWebSocketClient for TungstenitePluginWebSocketClient { + fn supports_bounded_open(&self) -> bool { + true + } + fn open( &self, _request: &PluginWebSocketOpenRequest, @@ -2538,31 +2555,55 @@ impl PluginWebSocketClient for TungstenitePluginWebSocketClient { PluginWebSocketError::new(format!("WebSocket request build failed: {error}")) })?; request.headers_mut().insert( - tungstenite::http::header::USER_AGENT, - tungstenite::http::HeaderValue::from_static("yoi-plugin-host/1"), + tokio_tungstenite::tungstenite::http::header::USER_AGENT, + tokio_tungstenite::tungstenite::http::HeaderValue::from_static("yoi-plugin-host/1"), ); let config = WebSocketConfig::default() .max_message_size(Some(limits.max_message_bytes)) .max_frame_size(Some(PLUGIN_WEBSOCKET_MAX_FRAME_BYTES)) .accept_unmasked_frames(false); - let (mut socket, _response) = - tungstenite::client::connect_with_config(request, Some(config), 0).map_err( - |error| { - PluginWebSocketError::new(format!( - "WebSocket connection failed for {}: {error}", - safe_url(url) - )) - }, - )?; - set_system_websocket_timeouts(&mut socket, limits.timeout); - Ok(Box::new(TungstenitePluginWebSocketConnection { socket })) + let runtime = new_websocket_runtime()?; + let open = async { + tokio::time::timeout( + limits.timeout, + tokio_tungstenite::connect_async_tls_with_config( + request, + Some(config), + false, + None, + ), + ) + .await + }; + let (socket, _response) = block_on_websocket_future(&runtime, open) + .map_err(|error| { + PluginWebSocketError::new(format!( + "WebSocket open timed out after {} ms for {}: {error}", + limits.timeout.as_millis(), + safe_url(url) + )) + })? + .map_err(|error| { + PluginWebSocketError::new(format!( + "WebSocket connection failed for {}: {error}", + safe_url(url) + )) + })?; + Ok(Box::new(TungstenitePluginWebSocketConnection { + runtime, + socket, + })) } } impl PluginWebSocketConnection for TungstenitePluginWebSocketConnection { fn send_text(&mut self, text: &str) -> Result<(), PluginWebSocketError> { - self.socket - .send(Message::Text(text.to_string().into())) + let send = tokio::time::timeout( + PLUGIN_WEBSOCKET_DEFAULT_TIMEOUT, + self.socket.send(Message::Text(text.to_string().into())), + ); + block_on_websocket_future(&self.runtime, send) + .map_err(|_| PluginWebSocketError::new("WebSocket send timed out"))? .map_err(|error| PluginWebSocketError::new(format!("WebSocket send failed: {error}"))) } @@ -2571,11 +2612,14 @@ impl PluginWebSocketConnection for TungstenitePluginWebSocketConnection { timeout: Duration, max_message_bytes: usize, ) -> Result { - set_system_websocket_timeouts(&mut self.socket, timeout); for _ in 0..PLUGIN_WEBSOCKET_MAX_CONTROL_FRAMES { - let message = self.socket.read().map_err(|error| { - PluginWebSocketError::new(format!("WebSocket receive failed: {error}")) - })?; + let next = tokio::time::timeout(timeout, self.socket.next()); + let message = block_on_websocket_future(&self.runtime, next) + .map_err(|_| PluginWebSocketError::new("WebSocket receive timed out"))? + .ok_or_else(|| PluginWebSocketError::new("WebSocket stream ended"))? + .map_err(|error| { + PluginWebSocketError::new(format!("WebSocket receive failed: {error}")) + })?; match message { Message::Text(text) => { if text.len() > max_message_bytes { @@ -2595,9 +2639,15 @@ impl PluginWebSocketConnection for TungstenitePluginWebSocketConnection { } Message::Close(_) => return Ok(PluginWebSocketRecvResponse::Closed), Message::Ping(payload) => { - self.socket.send(Message::Pong(payload)).map_err(|error| { - PluginWebSocketError::new(format!("WebSocket pong failed: {error}")) - })? + let send = tokio::time::timeout( + PLUGIN_WEBSOCKET_DEFAULT_TIMEOUT, + self.socket.send(Message::Pong(payload)), + ); + block_on_websocket_future(&self.runtime, send) + .map_err(|_| PluginWebSocketError::new("WebSocket pong timed out"))? + .map_err(|error| { + PluginWebSocketError::new(format!("WebSocket pong failed: {error}")) + })?; } Message::Pong(_) | Message::Frame(_) => continue, } @@ -2608,26 +2658,30 @@ impl PluginWebSocketConnection for TungstenitePluginWebSocketConnection { } fn close(&mut self) -> Result<(), PluginWebSocketError> { - self.socket - .close(None) + let close = tokio::time::timeout(PLUGIN_WEBSOCKET_DEFAULT_TIMEOUT, self.socket.close(None)); + block_on_websocket_future(&self.runtime, close) + .map_err(|_| PluginWebSocketError::new("WebSocket close timed out"))? .map_err(|error| PluginWebSocketError::new(format!("WebSocket close failed: {error}"))) } } -fn set_system_websocket_timeouts(socket: &mut SystemWebSocket, timeout: Duration) { - match socket.get_mut() { - MaybeTlsStream::Plain(stream) => { - let _ = stream.set_read_timeout(Some(timeout)); - let _ = stream.set_write_timeout(Some(timeout)); - } - #[allow(unreachable_patterns)] - MaybeTlsStream::NativeTls(stream) => { - let stream = stream.get_ref(); - let _ = stream.set_read_timeout(Some(timeout)); - let _ = stream.set_write_timeout(Some(timeout)); - } - #[allow(unreachable_patterns)] - _ => {} +fn new_websocket_runtime() -> Result { + TokioRuntimeBuilder::new_current_thread() + .enable_all() + .build() + .map_err(|error| { + PluginWebSocketError::new(format!("WebSocket runtime build failed: {error}")) + }) +} + +fn block_on_websocket_future( + runtime: &TokioRuntime, + future: F, +) -> F::Output { + if TokioHandle::try_current().is_ok() { + tokio::task::block_in_place(|| runtime.block_on(future)) + } else { + runtime.block_on(future) } } @@ -2637,14 +2691,15 @@ struct PluginWebSocketHandles { } impl PluginWebSocketHandles { - fn insert( - &self, - connection: Box, - ) -> Result { + fn reserve_open(&self) -> Result { self.inner .lock() .expect("plugin websocket handle table poisoned") - .insert(connection) + .reserve_open()?; + Ok(PluginWebSocketOpenReservation { + handles: self.clone(), + active: true, + }) } fn with_connection( @@ -2671,11 +2726,54 @@ impl PluginWebSocketHandles { .expect("plugin websocket handle table poisoned") .close_all(); } + + #[cfg(test)] + fn reservation_count(&self) -> usize { + self.inner + .lock() + .expect("plugin websocket handle table poisoned") + .reservations + } +} + +struct PluginWebSocketOpenReservation { + handles: PluginWebSocketHandles, + active: bool, +} + +impl PluginWebSocketOpenReservation { + fn commit( + mut self, + connection: Box, + ) -> Result { + let result = self + .handles + .inner + .lock() + .expect("plugin websocket handle table poisoned") + .insert_reserved(connection); + self.active = false; + result + } +} + +impl Drop for PluginWebSocketOpenReservation { + fn drop(&mut self) { + if self.active { + self.handles + .inner + .lock() + .expect("plugin websocket handle table poisoned") + .release_reservation(); + self.active = false; + } + } } #[derive(Default)] struct PluginWebSocketHandleTable { next: u32, + reservations: usize, connections: HashMap, } @@ -2685,14 +2783,32 @@ struct PluginWebSocketHandleEntry { } impl PluginWebSocketHandleTable { - fn insert( + fn reserve_open(&mut self) -> Result<(), PluginWebSocketError> { + self.expire_stale(); + if self.connections.len() + self.reservations >= PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS { + return Err(PluginWebSocketError::new(format!( + "host_api.websocket open connection limit ({}) exceeded before dialing", + PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS + ))); + } + self.reservations += 1; + Ok(()) + } + + fn release_reservation(&mut self) { + self.reservations = self.reservations.saturating_sub(1); + } + + fn insert_reserved( &mut self, - connection: Box, + mut connection: Box, ) -> Result { + self.release_reservation(); self.expire_stale(); if self.connections.len() >= PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS { + let _ = connection.close(); return Err(PluginWebSocketError::new(format!( - "host_api.websocket open connection limit ({}) exceeded", + "host_api.websocket open connection limit ({}) exceeded while committing reserved handle", PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS ))); } @@ -2741,6 +2857,7 @@ impl PluginWebSocketHandleTable { } fn close_all(&mut self) { + self.reservations = 0; for (_, mut entry) in self.connections.drain() { let _ = entry.connection.close(); } @@ -7288,6 +7405,10 @@ input_schema = { type = "object", additionalProperties = true } } impl PluginWebSocketClient for MockWebSocketClient { + fn supports_bounded_open(&self) -> bool { + true + } + fn open( &self, _request: &PluginWebSocketOpenRequest, @@ -7328,10 +7449,52 @@ input_schema = { type = "object", additionalProperties = true } fn close(&mut self) -> Result<(), PluginWebSocketError> { self.closed .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + Ok(()) } } + #[derive(Clone, Default)] + struct FailingWebSocketClient { + opens: Arc, + } + + impl PluginWebSocketClient for FailingWebSocketClient { + fn supports_bounded_open(&self) -> bool { + true + } + + fn open( + &self, + _request: &PluginWebSocketOpenRequest, + _url: &reqwest::Url, + _limits: PluginWebSocketLimits, + ) -> Result, PluginWebSocketError> { + self.opens.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + Err(PluginWebSocketError::new("simulated bounded open failure")) + } + } + + #[derive(Clone, Default)] + struct UnboundedWebSocketClient { + opens: Arc, + } + + impl PluginWebSocketClient for UnboundedWebSocketClient { + fn supports_bounded_open(&self) -> bool { + false + } + + fn open( + &self, + _request: &PluginWebSocketOpenRequest, + _url: &reqwest::Url, + _limits: PluginWebSocketLimits, + ) -> Result, PluginWebSocketError> { + self.opens.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + Err(PluginWebSocketError::new("should not dial")) + } + } fn websocket_grant( scheme: &str, host: &str, @@ -7357,6 +7520,124 @@ input_schema = { type = "object", additionalProperties = true } record } + #[test] + fn websocket_max_open_connections_rejects_before_network_open() { + let record = record_with_websocket( + vec![websocket_grant( + "wss", + "gateway.example.com", + None, + &["/gateway"], + )], + vec![websocket_grant( + "wss", + "gateway.example.com", + None, + &["/gateway"], + )], + ); + let client = MockWebSocketClient::default(); + let handles = PluginWebSocketHandles::default(); + for _ in 0..PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS { + execute_plugin_websocket_open( + &record, + &client, + &handles, + br#"{"url":"wss://gateway.example.com/gateway"}"#, + ) + .unwrap(); + } + assert_eq!( + client.opens.load(std::sync::atomic::Ordering::SeqCst), + PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS + ); + let error = execute_plugin_websocket_open( + &record, + &client, + &handles, + br#"{"url":"wss://gateway.example.com/gateway"}"#, + ) + .unwrap_err(); + assert!(error.0.contains("before dialing")); + assert_eq!( + client.opens.load(std::sync::atomic::Ordering::SeqCst), + PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS + ); + } + + #[test] + fn websocket_open_failure_releases_capacity_reservation() { + let record = record_with_websocket( + vec![websocket_grant( + "wss", + "gateway.example.com", + None, + &["/gateway"], + )], + vec![websocket_grant( + "wss", + "gateway.example.com", + None, + &["/gateway"], + )], + ); + let failing = FailingWebSocketClient::default(); + let handles = PluginWebSocketHandles::default(); + let error = execute_plugin_websocket_open( + &record, + &failing, + &handles, + br#"{"url":"wss://gateway.example.com/gateway"}"#, + ) + .unwrap_err(); + assert!(error.0.contains("simulated bounded open failure")); + assert_eq!(handles.reservation_count(), 0); + + let client = MockWebSocketClient::default(); + let open = execute_plugin_websocket_open( + &record, + &client, + &handles, + br#"{"url":"wss://gateway.example.com/gateway"}"#, + ) + .unwrap(); + let open: PluginWebSocketOpenResponse = serde_json::from_slice(&open).unwrap(); + assert_eq!(open.handle, 1); + } + + #[test] + fn websocket_unbounded_open_client_fails_closed_before_dialing() { + let record = record_with_websocket( + vec![websocket_grant( + "wss", + "gateway.example.com", + None, + &["/gateway"], + )], + vec![websocket_grant( + "wss", + "gateway.example.com", + None, + &["/gateway"], + )], + ); + let client = UnboundedWebSocketClient::default(); + let handles = PluginWebSocketHandles::default(); + let error = execute_plugin_websocket_open( + &record, + &client, + &handles, + br#"{"url":"wss://gateway.example.com/gateway"}"#, + ) + .unwrap_err(); + assert!( + error + .0 + .contains("cannot guarantee bounded/cancellable open") + ); + assert_eq!(client.opens.load(std::sync::atomic::Ordering::SeqCst), 0); + assert_eq!(handles.reservation_count(), 0); + } #[test] fn websocket_open_send_recv_close_is_bounded_and_explicit() { let record = record_with_websocket( diff --git a/package.nix b/package.nix index 8a7fe449..a0d9e420 100644 --- a/package.nix +++ b/package.nix @@ -43,7 +43,7 @@ rustPlatform.buildRustPackage rec { filter = sourceFilter; }; - cargoHash = "sha256-TZrw6nJclXVRpFIUlYvimGTDXlxBMaQt6oM5C5DIGIU="; + cargoHash = "sha256-cZxkmM42kbDp1Rv9gn4sCD5WIQLc0wCbjj4GbKjuA9Q="; depsExtraArgs = { # Older fetchCargoVendor utilities used crates.io's API download endpoint,