fix: bound plugin websocket open

This commit is contained in:
Keisuke Hirata 2026-06-21 22:14:58 +09:00
parent ce62d23502
commit a766048f29
No known key found for this signature in database
4 changed files with 352 additions and 53 deletions

16
Cargo.lock generated
View File

@ -2885,6 +2885,7 @@ dependencies = [
"dotenv", "dotenv",
"fs4", "fs4",
"futures", "futures",
"futures-util",
"include_dir", "include_dir",
"libc", "libc",
"llm-worker", "llm-worker",
@ -2906,6 +2907,7 @@ dependencies = [
"thiserror 2.0.18", "thiserror 2.0.18",
"ticket", "ticket",
"tokio", "tokio",
"tokio-tungstenite",
"toml", "toml",
"tools", "tools",
"tracing", "tracing",
@ -4471,6 +4473,20 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "tokio-tungstenite"
version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857"
dependencies = [
"futures-util",
"log",
"native-tls",
"tokio",
"tokio-native-tls",
"tungstenite",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.7.18" version = "0.7.18"

View File

@ -40,6 +40,8 @@ arc-swap = "1.9.1"
wasmi = { version = "0.51.1", default-features = false, features = ["std", "extra-checks"] } wasmi = { version = "0.51.1", default-features = false, features = ["std", "extra-checks"] }
wasmtime = { version = "45.0.2", default-features = false, features = ["std", "runtime", "cranelift", "component-model"] } wasmtime = { version = "45.0.2", default-features = false, features = ["std", "runtime", "cranelift", "component-model"] }
tungstenite = { version = "0.28.0", default-features = false, features = ["handshake", "native-tls", "url"] } tungstenite = { version = "0.28.0", default-features = false, features = ["handshake", "native-tls", "url"] }
tokio-tungstenite = { version = "0.28.0", default-features = false, features = ["native-tls", "connect"] }
futures-util = { version = "0.3", features = ["sink"] }
[dev-dependencies] [dev-dependencies]
dotenv = "0.15.0" dotenv = "0.15.0"

View File

@ -17,6 +17,7 @@ use std::sync::{Arc, Mutex, OnceLock};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use async_trait::async_trait; use async_trait::async_trait;
use futures_util::{SinkExt, StreamExt};
use llm_worker::tool::{ use llm_worker::tool::{
Tool, ToolDefinition, ToolError, ToolExecutionContext, ToolMeta, ToolOrigin, ToolOutput, Tool, ToolDefinition, ToolError, ToolExecutionContext, ToolMeta, ToolOrigin, ToolOutput,
}; };
@ -29,9 +30,11 @@ use manifest::plugin::{
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use tungstenite::client::IntoClientRequest; use tokio::runtime::{
use tungstenite::protocol::{Message, WebSocketConfig}; Builder as TokioRuntimeBuilder, Handle as TokioHandle, Runtime as TokioRuntime,
use tungstenite::stream::MaybeTlsStream; };
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::protocol::{Message, WebSocketConfig};
use super::{ use super::{
FeatureDescriptor, FeatureId, FeatureInstallContext, FeatureInstallError, FeatureModule, FeatureDescriptor, FeatureId, FeatureInstallContext, FeatureInstallError, FeatureModule,
@ -927,8 +930,14 @@ fn execute_plugin_websocket_open(
) -> Result<Vec<u8>, PluginWebSocketError> { ) -> Result<Vec<u8>, PluginWebSocketError> {
let (request, url) = validate_plugin_websocket_open_request(record, bytes)?; let (request, url) = validate_plugin_websocket_open_request(record, bytes)?;
let limits = PluginWebSocketLimits::default(); let limits = PluginWebSocketLimits::default();
if !client.supports_bounded_open() {
return Err(PluginWebSocketError::new(
"host_api.websocket client cannot guarantee bounded/cancellable open; refusing to dial",
));
}
let reservation = handles.reserve_open()?;
let connection = client.open(&request, &url, limits)?; let connection = client.open(&request, &url, limits)?;
let handle = handles.insert(connection)?; let handle = reservation.commit(connection)?;
serde_json::to_vec(&PluginWebSocketOpenResponse { serde_json::to_vec(&PluginWebSocketOpenResponse {
handle, handle,
url: safe_url(&url), url: safe_url(&url),
@ -2511,6 +2520,8 @@ trait PluginWebSocketConnection: Send {
} }
trait PluginWebSocketClient: Send + Sync { trait PluginWebSocketClient: Send + Sync {
fn supports_bounded_open(&self) -> bool;
fn open( fn open(
&self, &self,
request: &PluginWebSocketOpenRequest, request: &PluginWebSocketOpenRequest,
@ -2521,13 +2532,19 @@ trait PluginWebSocketClient: Send + Sync {
struct TungstenitePluginWebSocketClient; struct TungstenitePluginWebSocketClient;
type SystemWebSocket = tungstenite::WebSocket<MaybeTlsStream<std::net::TcpStream>>; type AsyncSystemWebSocket =
tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
struct TungstenitePluginWebSocketConnection { struct TungstenitePluginWebSocketConnection {
socket: SystemWebSocket, runtime: TokioRuntime,
socket: AsyncSystemWebSocket,
} }
impl PluginWebSocketClient for TungstenitePluginWebSocketClient { impl PluginWebSocketClient for TungstenitePluginWebSocketClient {
fn supports_bounded_open(&self) -> bool {
true
}
fn open( fn open(
&self, &self,
_request: &PluginWebSocketOpenRequest, _request: &PluginWebSocketOpenRequest,
@ -2538,31 +2555,55 @@ impl PluginWebSocketClient for TungstenitePluginWebSocketClient {
PluginWebSocketError::new(format!("WebSocket request build failed: {error}")) PluginWebSocketError::new(format!("WebSocket request build failed: {error}"))
})?; })?;
request.headers_mut().insert( request.headers_mut().insert(
tungstenite::http::header::USER_AGENT, tokio_tungstenite::tungstenite::http::header::USER_AGENT,
tungstenite::http::HeaderValue::from_static("yoi-plugin-host/1"), tokio_tungstenite::tungstenite::http::HeaderValue::from_static("yoi-plugin-host/1"),
); );
let config = WebSocketConfig::default() let config = WebSocketConfig::default()
.max_message_size(Some(limits.max_message_bytes)) .max_message_size(Some(limits.max_message_bytes))
.max_frame_size(Some(PLUGIN_WEBSOCKET_MAX_FRAME_BYTES)) .max_frame_size(Some(PLUGIN_WEBSOCKET_MAX_FRAME_BYTES))
.accept_unmasked_frames(false); .accept_unmasked_frames(false);
let (mut socket, _response) = let runtime = new_websocket_runtime()?;
tungstenite::client::connect_with_config(request, Some(config), 0).map_err( let open = async {
|error| { tokio::time::timeout(
limits.timeout,
tokio_tungstenite::connect_async_tls_with_config(
request,
Some(config),
false,
None,
),
)
.await
};
let (socket, _response) = block_on_websocket_future(&runtime, open)
.map_err(|error| {
PluginWebSocketError::new(format!(
"WebSocket open timed out after {} ms for {}: {error}",
limits.timeout.as_millis(),
safe_url(url)
))
})?
.map_err(|error| {
PluginWebSocketError::new(format!( PluginWebSocketError::new(format!(
"WebSocket connection failed for {}: {error}", "WebSocket connection failed for {}: {error}",
safe_url(url) safe_url(url)
)) ))
}, })?;
)?; Ok(Box::new(TungstenitePluginWebSocketConnection {
set_system_websocket_timeouts(&mut socket, limits.timeout); runtime,
Ok(Box::new(TungstenitePluginWebSocketConnection { socket })) socket,
}))
} }
} }
impl PluginWebSocketConnection for TungstenitePluginWebSocketConnection { impl PluginWebSocketConnection for TungstenitePluginWebSocketConnection {
fn send_text(&mut self, text: &str) -> Result<(), PluginWebSocketError> { fn send_text(&mut self, text: &str) -> Result<(), PluginWebSocketError> {
self.socket let send = tokio::time::timeout(
.send(Message::Text(text.to_string().into())) PLUGIN_WEBSOCKET_DEFAULT_TIMEOUT,
self.socket.send(Message::Text(text.to_string().into())),
);
block_on_websocket_future(&self.runtime, send)
.map_err(|_| PluginWebSocketError::new("WebSocket send timed out"))?
.map_err(|error| PluginWebSocketError::new(format!("WebSocket send failed: {error}"))) .map_err(|error| PluginWebSocketError::new(format!("WebSocket send failed: {error}")))
} }
@ -2571,9 +2612,12 @@ impl PluginWebSocketConnection for TungstenitePluginWebSocketConnection {
timeout: Duration, timeout: Duration,
max_message_bytes: usize, max_message_bytes: usize,
) -> Result<PluginWebSocketRecvResponse, PluginWebSocketError> { ) -> Result<PluginWebSocketRecvResponse, PluginWebSocketError> {
set_system_websocket_timeouts(&mut self.socket, timeout);
for _ in 0..PLUGIN_WEBSOCKET_MAX_CONTROL_FRAMES { for _ in 0..PLUGIN_WEBSOCKET_MAX_CONTROL_FRAMES {
let message = self.socket.read().map_err(|error| { let next = tokio::time::timeout(timeout, self.socket.next());
let message = block_on_websocket_future(&self.runtime, next)
.map_err(|_| PluginWebSocketError::new("WebSocket receive timed out"))?
.ok_or_else(|| PluginWebSocketError::new("WebSocket stream ended"))?
.map_err(|error| {
PluginWebSocketError::new(format!("WebSocket receive failed: {error}")) PluginWebSocketError::new(format!("WebSocket receive failed: {error}"))
})?; })?;
match message { match message {
@ -2595,9 +2639,15 @@ impl PluginWebSocketConnection for TungstenitePluginWebSocketConnection {
} }
Message::Close(_) => return Ok(PluginWebSocketRecvResponse::Closed), Message::Close(_) => return Ok(PluginWebSocketRecvResponse::Closed),
Message::Ping(payload) => { Message::Ping(payload) => {
self.socket.send(Message::Pong(payload)).map_err(|error| { let send = tokio::time::timeout(
PLUGIN_WEBSOCKET_DEFAULT_TIMEOUT,
self.socket.send(Message::Pong(payload)),
);
block_on_websocket_future(&self.runtime, send)
.map_err(|_| PluginWebSocketError::new("WebSocket pong timed out"))?
.map_err(|error| {
PluginWebSocketError::new(format!("WebSocket pong failed: {error}")) PluginWebSocketError::new(format!("WebSocket pong failed: {error}"))
})? })?;
} }
Message::Pong(_) | Message::Frame(_) => continue, Message::Pong(_) | Message::Frame(_) => continue,
} }
@ -2608,26 +2658,30 @@ impl PluginWebSocketConnection for TungstenitePluginWebSocketConnection {
} }
fn close(&mut self) -> Result<(), PluginWebSocketError> { fn close(&mut self) -> Result<(), PluginWebSocketError> {
self.socket let close = tokio::time::timeout(PLUGIN_WEBSOCKET_DEFAULT_TIMEOUT, self.socket.close(None));
.close(None) block_on_websocket_future(&self.runtime, close)
.map_err(|_| PluginWebSocketError::new("WebSocket close timed out"))?
.map_err(|error| PluginWebSocketError::new(format!("WebSocket close failed: {error}"))) .map_err(|error| PluginWebSocketError::new(format!("WebSocket close failed: {error}")))
} }
} }
fn set_system_websocket_timeouts(socket: &mut SystemWebSocket, timeout: Duration) { fn new_websocket_runtime() -> Result<TokioRuntime, PluginWebSocketError> {
match socket.get_mut() { TokioRuntimeBuilder::new_current_thread()
MaybeTlsStream::Plain(stream) => { .enable_all()
let _ = stream.set_read_timeout(Some(timeout)); .build()
let _ = stream.set_write_timeout(Some(timeout)); .map_err(|error| {
} PluginWebSocketError::new(format!("WebSocket runtime build failed: {error}"))
#[allow(unreachable_patterns)] })
MaybeTlsStream::NativeTls(stream) => { }
let stream = stream.get_ref();
let _ = stream.set_read_timeout(Some(timeout)); fn block_on_websocket_future<F: std::future::Future>(
let _ = stream.set_write_timeout(Some(timeout)); runtime: &TokioRuntime,
} future: F,
#[allow(unreachable_patterns)] ) -> F::Output {
_ => {} if TokioHandle::try_current().is_ok() {
tokio::task::block_in_place(|| runtime.block_on(future))
} else {
runtime.block_on(future)
} }
} }
@ -2637,14 +2691,15 @@ struct PluginWebSocketHandles {
} }
impl PluginWebSocketHandles { impl PluginWebSocketHandles {
fn insert( fn reserve_open(&self) -> Result<PluginWebSocketOpenReservation, PluginWebSocketError> {
&self,
connection: Box<dyn PluginWebSocketConnection>,
) -> Result<u32, PluginWebSocketError> {
self.inner self.inner
.lock() .lock()
.expect("plugin websocket handle table poisoned") .expect("plugin websocket handle table poisoned")
.insert(connection) .reserve_open()?;
Ok(PluginWebSocketOpenReservation {
handles: self.clone(),
active: true,
})
} }
fn with_connection<T>( fn with_connection<T>(
@ -2671,11 +2726,54 @@ impl PluginWebSocketHandles {
.expect("plugin websocket handle table poisoned") .expect("plugin websocket handle table poisoned")
.close_all(); .close_all();
} }
#[cfg(test)]
fn reservation_count(&self) -> usize {
self.inner
.lock()
.expect("plugin websocket handle table poisoned")
.reservations
}
}
struct PluginWebSocketOpenReservation {
handles: PluginWebSocketHandles,
active: bool,
}
impl PluginWebSocketOpenReservation {
fn commit(
mut self,
connection: Box<dyn PluginWebSocketConnection>,
) -> Result<u32, PluginWebSocketError> {
let result = self
.handles
.inner
.lock()
.expect("plugin websocket handle table poisoned")
.insert_reserved(connection);
self.active = false;
result
}
}
impl Drop for PluginWebSocketOpenReservation {
fn drop(&mut self) {
if self.active {
self.handles
.inner
.lock()
.expect("plugin websocket handle table poisoned")
.release_reservation();
self.active = false;
}
}
} }
#[derive(Default)] #[derive(Default)]
struct PluginWebSocketHandleTable { struct PluginWebSocketHandleTable {
next: u32, next: u32,
reservations: usize,
connections: HashMap<u32, PluginWebSocketHandleEntry>, connections: HashMap<u32, PluginWebSocketHandleEntry>,
} }
@ -2685,14 +2783,32 @@ struct PluginWebSocketHandleEntry {
} }
impl PluginWebSocketHandleTable { impl PluginWebSocketHandleTable {
fn insert( fn reserve_open(&mut self) -> Result<(), PluginWebSocketError> {
self.expire_stale();
if self.connections.len() + self.reservations >= PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS {
return Err(PluginWebSocketError::new(format!(
"host_api.websocket open connection limit ({}) exceeded before dialing",
PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS
)));
}
self.reservations += 1;
Ok(())
}
fn release_reservation(&mut self) {
self.reservations = self.reservations.saturating_sub(1);
}
fn insert_reserved(
&mut self, &mut self,
connection: Box<dyn PluginWebSocketConnection>, mut connection: Box<dyn PluginWebSocketConnection>,
) -> Result<u32, PluginWebSocketError> { ) -> Result<u32, PluginWebSocketError> {
self.release_reservation();
self.expire_stale(); self.expire_stale();
if self.connections.len() >= PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS { if self.connections.len() >= PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS {
let _ = connection.close();
return Err(PluginWebSocketError::new(format!( return Err(PluginWebSocketError::new(format!(
"host_api.websocket open connection limit ({}) exceeded", "host_api.websocket open connection limit ({}) exceeded while committing reserved handle",
PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS
))); )));
} }
@ -2741,6 +2857,7 @@ impl PluginWebSocketHandleTable {
} }
fn close_all(&mut self) { fn close_all(&mut self) {
self.reservations = 0;
for (_, mut entry) in self.connections.drain() { for (_, mut entry) in self.connections.drain() {
let _ = entry.connection.close(); let _ = entry.connection.close();
} }
@ -7288,6 +7405,10 @@ input_schema = { type = "object", additionalProperties = true }
} }
impl PluginWebSocketClient for MockWebSocketClient { impl PluginWebSocketClient for MockWebSocketClient {
fn supports_bounded_open(&self) -> bool {
true
}
fn open( fn open(
&self, &self,
_request: &PluginWebSocketOpenRequest, _request: &PluginWebSocketOpenRequest,
@ -7328,10 +7449,52 @@ input_schema = { type = "object", additionalProperties = true }
fn close(&mut self) -> Result<(), PluginWebSocketError> { fn close(&mut self) -> Result<(), PluginWebSocketError> {
self.closed self.closed
.fetch_add(1, std::sync::atomic::Ordering::SeqCst); .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
Ok(()) Ok(())
} }
} }
#[derive(Clone, Default)]
struct FailingWebSocketClient {
opens: Arc<std::sync::atomic::AtomicUsize>,
}
impl PluginWebSocketClient for FailingWebSocketClient {
fn supports_bounded_open(&self) -> bool {
true
}
fn open(
&self,
_request: &PluginWebSocketOpenRequest,
_url: &reqwest::Url,
_limits: PluginWebSocketLimits,
) -> Result<Box<dyn PluginWebSocketConnection>, PluginWebSocketError> {
self.opens.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
Err(PluginWebSocketError::new("simulated bounded open failure"))
}
}
#[derive(Clone, Default)]
struct UnboundedWebSocketClient {
opens: Arc<std::sync::atomic::AtomicUsize>,
}
impl PluginWebSocketClient for UnboundedWebSocketClient {
fn supports_bounded_open(&self) -> bool {
false
}
fn open(
&self,
_request: &PluginWebSocketOpenRequest,
_url: &reqwest::Url,
_limits: PluginWebSocketLimits,
) -> Result<Box<dyn PluginWebSocketConnection>, PluginWebSocketError> {
self.opens.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
Err(PluginWebSocketError::new("should not dial"))
}
}
fn websocket_grant( fn websocket_grant(
scheme: &str, scheme: &str,
host: &str, host: &str,
@ -7357,6 +7520,124 @@ input_schema = { type = "object", additionalProperties = true }
record record
} }
#[test]
fn websocket_max_open_connections_rejects_before_network_open() {
let record = record_with_websocket(
vec![websocket_grant(
"wss",
"gateway.example.com",
None,
&["/gateway"],
)],
vec![websocket_grant(
"wss",
"gateway.example.com",
None,
&["/gateway"],
)],
);
let client = MockWebSocketClient::default();
let handles = PluginWebSocketHandles::default();
for _ in 0..PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS {
execute_plugin_websocket_open(
&record,
&client,
&handles,
br#"{"url":"wss://gateway.example.com/gateway"}"#,
)
.unwrap();
}
assert_eq!(
client.opens.load(std::sync::atomic::Ordering::SeqCst),
PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS
);
let error = execute_plugin_websocket_open(
&record,
&client,
&handles,
br#"{"url":"wss://gateway.example.com/gateway"}"#,
)
.unwrap_err();
assert!(error.0.contains("before dialing"));
assert_eq!(
client.opens.load(std::sync::atomic::Ordering::SeqCst),
PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS
);
}
#[test]
fn websocket_open_failure_releases_capacity_reservation() {
let record = record_with_websocket(
vec![websocket_grant(
"wss",
"gateway.example.com",
None,
&["/gateway"],
)],
vec![websocket_grant(
"wss",
"gateway.example.com",
None,
&["/gateway"],
)],
);
let failing = FailingWebSocketClient::default();
let handles = PluginWebSocketHandles::default();
let error = execute_plugin_websocket_open(
&record,
&failing,
&handles,
br#"{"url":"wss://gateway.example.com/gateway"}"#,
)
.unwrap_err();
assert!(error.0.contains("simulated bounded open failure"));
assert_eq!(handles.reservation_count(), 0);
let client = MockWebSocketClient::default();
let open = execute_plugin_websocket_open(
&record,
&client,
&handles,
br#"{"url":"wss://gateway.example.com/gateway"}"#,
)
.unwrap();
let open: PluginWebSocketOpenResponse = serde_json::from_slice(&open).unwrap();
assert_eq!(open.handle, 1);
}
#[test]
fn websocket_unbounded_open_client_fails_closed_before_dialing() {
let record = record_with_websocket(
vec![websocket_grant(
"wss",
"gateway.example.com",
None,
&["/gateway"],
)],
vec![websocket_grant(
"wss",
"gateway.example.com",
None,
&["/gateway"],
)],
);
let client = UnboundedWebSocketClient::default();
let handles = PluginWebSocketHandles::default();
let error = execute_plugin_websocket_open(
&record,
&client,
&handles,
br#"{"url":"wss://gateway.example.com/gateway"}"#,
)
.unwrap_err();
assert!(
error
.0
.contains("cannot guarantee bounded/cancellable open")
);
assert_eq!(client.opens.load(std::sync::atomic::Ordering::SeqCst), 0);
assert_eq!(handles.reservation_count(), 0);
}
#[test] #[test]
fn websocket_open_send_recv_close_is_bounded_and_explicit() { fn websocket_open_send_recv_close_is_bounded_and_explicit() {
let record = record_with_websocket( let record = record_with_websocket(

View File

@ -43,7 +43,7 @@ rustPlatform.buildRustPackage rec {
filter = sourceFilter; filter = sourceFilter;
}; };
cargoHash = "sha256-TZrw6nJclXVRpFIUlYvimGTDXlxBMaQt6oM5C5DIGIU="; cargoHash = "sha256-cZxkmM42kbDp1Rv9gn4sCD5WIQLc0wCbjj4GbKjuA9Q=";
depsExtraArgs = { depsExtraArgs = {
# Older fetchCargoVendor utilities used crates.io's API download endpoint, # Older fetchCargoVendor utilities used crates.io's API download endpoint,