fix: bound plugin websocket open
This commit is contained in:
parent
ce62d23502
commit
a766048f29
16
Cargo.lock
generated
16
Cargo.lock
generated
|
|
@ -2885,6 +2885,7 @@ dependencies = [
|
||||||
"dotenv",
|
"dotenv",
|
||||||
"fs4",
|
"fs4",
|
||||||
"futures",
|
"futures",
|
||||||
|
"futures-util",
|
||||||
"include_dir",
|
"include_dir",
|
||||||
"libc",
|
"libc",
|
||||||
"llm-worker",
|
"llm-worker",
|
||||||
|
|
@ -2906,6 +2907,7 @@ dependencies = [
|
||||||
"thiserror 2.0.18",
|
"thiserror 2.0.18",
|
||||||
"ticket",
|
"ticket",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-tungstenite",
|
||||||
"toml",
|
"toml",
|
||||||
"tools",
|
"tools",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
|
@ -4471,6 +4473,20 @@ dependencies = [
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-tungstenite"
|
||||||
|
version = "0.28.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857"
|
||||||
|
dependencies = [
|
||||||
|
"futures-util",
|
||||||
|
"log",
|
||||||
|
"native-tls",
|
||||||
|
"tokio",
|
||||||
|
"tokio-native-tls",
|
||||||
|
"tungstenite",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokio-util"
|
name = "tokio-util"
|
||||||
version = "0.7.18"
|
version = "0.7.18"
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,8 @@ arc-swap = "1.9.1"
|
||||||
wasmi = { version = "0.51.1", default-features = false, features = ["std", "extra-checks"] }
|
wasmi = { version = "0.51.1", default-features = false, features = ["std", "extra-checks"] }
|
||||||
wasmtime = { version = "45.0.2", default-features = false, features = ["std", "runtime", "cranelift", "component-model"] }
|
wasmtime = { version = "45.0.2", default-features = false, features = ["std", "runtime", "cranelift", "component-model"] }
|
||||||
tungstenite = { version = "0.28.0", default-features = false, features = ["handshake", "native-tls", "url"] }
|
tungstenite = { version = "0.28.0", default-features = false, features = ["handshake", "native-tls", "url"] }
|
||||||
|
tokio-tungstenite = { version = "0.28.0", default-features = false, features = ["native-tls", "connect"] }
|
||||||
|
futures-util = { version = "0.3", features = ["sink"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
dotenv = "0.15.0"
|
dotenv = "0.15.0"
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ use std::sync::{Arc, Mutex, OnceLock};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use llm_worker::tool::{
|
use llm_worker::tool::{
|
||||||
Tool, ToolDefinition, ToolError, ToolExecutionContext, ToolMeta, ToolOrigin, ToolOutput,
|
Tool, ToolDefinition, ToolError, ToolExecutionContext, ToolMeta, ToolOrigin, ToolOutput,
|
||||||
};
|
};
|
||||||
|
|
@ -29,9 +30,11 @@ use manifest::plugin::{
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use tungstenite::client::IntoClientRequest;
|
use tokio::runtime::{
|
||||||
use tungstenite::protocol::{Message, WebSocketConfig};
|
Builder as TokioRuntimeBuilder, Handle as TokioHandle, Runtime as TokioRuntime,
|
||||||
use tungstenite::stream::MaybeTlsStream;
|
};
|
||||||
|
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||||
|
use tokio_tungstenite::tungstenite::protocol::{Message, WebSocketConfig};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
FeatureDescriptor, FeatureId, FeatureInstallContext, FeatureInstallError, FeatureModule,
|
FeatureDescriptor, FeatureId, FeatureInstallContext, FeatureInstallError, FeatureModule,
|
||||||
|
|
@ -927,8 +930,14 @@ fn execute_plugin_websocket_open(
|
||||||
) -> Result<Vec<u8>, PluginWebSocketError> {
|
) -> Result<Vec<u8>, PluginWebSocketError> {
|
||||||
let (request, url) = validate_plugin_websocket_open_request(record, bytes)?;
|
let (request, url) = validate_plugin_websocket_open_request(record, bytes)?;
|
||||||
let limits = PluginWebSocketLimits::default();
|
let limits = PluginWebSocketLimits::default();
|
||||||
|
if !client.supports_bounded_open() {
|
||||||
|
return Err(PluginWebSocketError::new(
|
||||||
|
"host_api.websocket client cannot guarantee bounded/cancellable open; refusing to dial",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let reservation = handles.reserve_open()?;
|
||||||
let connection = client.open(&request, &url, limits)?;
|
let connection = client.open(&request, &url, limits)?;
|
||||||
let handle = handles.insert(connection)?;
|
let handle = reservation.commit(connection)?;
|
||||||
serde_json::to_vec(&PluginWebSocketOpenResponse {
|
serde_json::to_vec(&PluginWebSocketOpenResponse {
|
||||||
handle,
|
handle,
|
||||||
url: safe_url(&url),
|
url: safe_url(&url),
|
||||||
|
|
@ -2511,6 +2520,8 @@ trait PluginWebSocketConnection: Send {
|
||||||
}
|
}
|
||||||
|
|
||||||
trait PluginWebSocketClient: Send + Sync {
|
trait PluginWebSocketClient: Send + Sync {
|
||||||
|
fn supports_bounded_open(&self) -> bool;
|
||||||
|
|
||||||
fn open(
|
fn open(
|
||||||
&self,
|
&self,
|
||||||
request: &PluginWebSocketOpenRequest,
|
request: &PluginWebSocketOpenRequest,
|
||||||
|
|
@ -2521,13 +2532,19 @@ trait PluginWebSocketClient: Send + Sync {
|
||||||
|
|
||||||
struct TungstenitePluginWebSocketClient;
|
struct TungstenitePluginWebSocketClient;
|
||||||
|
|
||||||
type SystemWebSocket = tungstenite::WebSocket<MaybeTlsStream<std::net::TcpStream>>;
|
type AsyncSystemWebSocket =
|
||||||
|
tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
|
||||||
|
|
||||||
struct TungstenitePluginWebSocketConnection {
|
struct TungstenitePluginWebSocketConnection {
|
||||||
socket: SystemWebSocket,
|
runtime: TokioRuntime,
|
||||||
|
socket: AsyncSystemWebSocket,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PluginWebSocketClient for TungstenitePluginWebSocketClient {
|
impl PluginWebSocketClient for TungstenitePluginWebSocketClient {
|
||||||
|
fn supports_bounded_open(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
fn open(
|
fn open(
|
||||||
&self,
|
&self,
|
||||||
_request: &PluginWebSocketOpenRequest,
|
_request: &PluginWebSocketOpenRequest,
|
||||||
|
|
@ -2538,31 +2555,55 @@ impl PluginWebSocketClient for TungstenitePluginWebSocketClient {
|
||||||
PluginWebSocketError::new(format!("WebSocket request build failed: {error}"))
|
PluginWebSocketError::new(format!("WebSocket request build failed: {error}"))
|
||||||
})?;
|
})?;
|
||||||
request.headers_mut().insert(
|
request.headers_mut().insert(
|
||||||
tungstenite::http::header::USER_AGENT,
|
tokio_tungstenite::tungstenite::http::header::USER_AGENT,
|
||||||
tungstenite::http::HeaderValue::from_static("yoi-plugin-host/1"),
|
tokio_tungstenite::tungstenite::http::HeaderValue::from_static("yoi-plugin-host/1"),
|
||||||
);
|
);
|
||||||
let config = WebSocketConfig::default()
|
let config = WebSocketConfig::default()
|
||||||
.max_message_size(Some(limits.max_message_bytes))
|
.max_message_size(Some(limits.max_message_bytes))
|
||||||
.max_frame_size(Some(PLUGIN_WEBSOCKET_MAX_FRAME_BYTES))
|
.max_frame_size(Some(PLUGIN_WEBSOCKET_MAX_FRAME_BYTES))
|
||||||
.accept_unmasked_frames(false);
|
.accept_unmasked_frames(false);
|
||||||
let (mut socket, _response) =
|
let runtime = new_websocket_runtime()?;
|
||||||
tungstenite::client::connect_with_config(request, Some(config), 0).map_err(
|
let open = async {
|
||||||
|error| {
|
tokio::time::timeout(
|
||||||
|
limits.timeout,
|
||||||
|
tokio_tungstenite::connect_async_tls_with_config(
|
||||||
|
request,
|
||||||
|
Some(config),
|
||||||
|
false,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
};
|
||||||
|
let (socket, _response) = block_on_websocket_future(&runtime, open)
|
||||||
|
.map_err(|error| {
|
||||||
|
PluginWebSocketError::new(format!(
|
||||||
|
"WebSocket open timed out after {} ms for {}: {error}",
|
||||||
|
limits.timeout.as_millis(),
|
||||||
|
safe_url(url)
|
||||||
|
))
|
||||||
|
})?
|
||||||
|
.map_err(|error| {
|
||||||
PluginWebSocketError::new(format!(
|
PluginWebSocketError::new(format!(
|
||||||
"WebSocket connection failed for {}: {error}",
|
"WebSocket connection failed for {}: {error}",
|
||||||
safe_url(url)
|
safe_url(url)
|
||||||
))
|
))
|
||||||
},
|
})?;
|
||||||
)?;
|
Ok(Box::new(TungstenitePluginWebSocketConnection {
|
||||||
set_system_websocket_timeouts(&mut socket, limits.timeout);
|
runtime,
|
||||||
Ok(Box::new(TungstenitePluginWebSocketConnection { socket }))
|
socket,
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PluginWebSocketConnection for TungstenitePluginWebSocketConnection {
|
impl PluginWebSocketConnection for TungstenitePluginWebSocketConnection {
|
||||||
fn send_text(&mut self, text: &str) -> Result<(), PluginWebSocketError> {
|
fn send_text(&mut self, text: &str) -> Result<(), PluginWebSocketError> {
|
||||||
self.socket
|
let send = tokio::time::timeout(
|
||||||
.send(Message::Text(text.to_string().into()))
|
PLUGIN_WEBSOCKET_DEFAULT_TIMEOUT,
|
||||||
|
self.socket.send(Message::Text(text.to_string().into())),
|
||||||
|
);
|
||||||
|
block_on_websocket_future(&self.runtime, send)
|
||||||
|
.map_err(|_| PluginWebSocketError::new("WebSocket send timed out"))?
|
||||||
.map_err(|error| PluginWebSocketError::new(format!("WebSocket send failed: {error}")))
|
.map_err(|error| PluginWebSocketError::new(format!("WebSocket send failed: {error}")))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -2571,9 +2612,12 @@ impl PluginWebSocketConnection for TungstenitePluginWebSocketConnection {
|
||||||
timeout: Duration,
|
timeout: Duration,
|
||||||
max_message_bytes: usize,
|
max_message_bytes: usize,
|
||||||
) -> Result<PluginWebSocketRecvResponse, PluginWebSocketError> {
|
) -> Result<PluginWebSocketRecvResponse, PluginWebSocketError> {
|
||||||
set_system_websocket_timeouts(&mut self.socket, timeout);
|
|
||||||
for _ in 0..PLUGIN_WEBSOCKET_MAX_CONTROL_FRAMES {
|
for _ in 0..PLUGIN_WEBSOCKET_MAX_CONTROL_FRAMES {
|
||||||
let message = self.socket.read().map_err(|error| {
|
let next = tokio::time::timeout(timeout, self.socket.next());
|
||||||
|
let message = block_on_websocket_future(&self.runtime, next)
|
||||||
|
.map_err(|_| PluginWebSocketError::new("WebSocket receive timed out"))?
|
||||||
|
.ok_or_else(|| PluginWebSocketError::new("WebSocket stream ended"))?
|
||||||
|
.map_err(|error| {
|
||||||
PluginWebSocketError::new(format!("WebSocket receive failed: {error}"))
|
PluginWebSocketError::new(format!("WebSocket receive failed: {error}"))
|
||||||
})?;
|
})?;
|
||||||
match message {
|
match message {
|
||||||
|
|
@ -2595,9 +2639,15 @@ impl PluginWebSocketConnection for TungstenitePluginWebSocketConnection {
|
||||||
}
|
}
|
||||||
Message::Close(_) => return Ok(PluginWebSocketRecvResponse::Closed),
|
Message::Close(_) => return Ok(PluginWebSocketRecvResponse::Closed),
|
||||||
Message::Ping(payload) => {
|
Message::Ping(payload) => {
|
||||||
self.socket.send(Message::Pong(payload)).map_err(|error| {
|
let send = tokio::time::timeout(
|
||||||
|
PLUGIN_WEBSOCKET_DEFAULT_TIMEOUT,
|
||||||
|
self.socket.send(Message::Pong(payload)),
|
||||||
|
);
|
||||||
|
block_on_websocket_future(&self.runtime, send)
|
||||||
|
.map_err(|_| PluginWebSocketError::new("WebSocket pong timed out"))?
|
||||||
|
.map_err(|error| {
|
||||||
PluginWebSocketError::new(format!("WebSocket pong failed: {error}"))
|
PluginWebSocketError::new(format!("WebSocket pong failed: {error}"))
|
||||||
})?
|
})?;
|
||||||
}
|
}
|
||||||
Message::Pong(_) | Message::Frame(_) => continue,
|
Message::Pong(_) | Message::Frame(_) => continue,
|
||||||
}
|
}
|
||||||
|
|
@ -2608,26 +2658,30 @@ impl PluginWebSocketConnection for TungstenitePluginWebSocketConnection {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn close(&mut self) -> Result<(), PluginWebSocketError> {
|
fn close(&mut self) -> Result<(), PluginWebSocketError> {
|
||||||
self.socket
|
let close = tokio::time::timeout(PLUGIN_WEBSOCKET_DEFAULT_TIMEOUT, self.socket.close(None));
|
||||||
.close(None)
|
block_on_websocket_future(&self.runtime, close)
|
||||||
|
.map_err(|_| PluginWebSocketError::new("WebSocket close timed out"))?
|
||||||
.map_err(|error| PluginWebSocketError::new(format!("WebSocket close failed: {error}")))
|
.map_err(|error| PluginWebSocketError::new(format!("WebSocket close failed: {error}")))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_system_websocket_timeouts(socket: &mut SystemWebSocket, timeout: Duration) {
|
fn new_websocket_runtime() -> Result<TokioRuntime, PluginWebSocketError> {
|
||||||
match socket.get_mut() {
|
TokioRuntimeBuilder::new_current_thread()
|
||||||
MaybeTlsStream::Plain(stream) => {
|
.enable_all()
|
||||||
let _ = stream.set_read_timeout(Some(timeout));
|
.build()
|
||||||
let _ = stream.set_write_timeout(Some(timeout));
|
.map_err(|error| {
|
||||||
}
|
PluginWebSocketError::new(format!("WebSocket runtime build failed: {error}"))
|
||||||
#[allow(unreachable_patterns)]
|
})
|
||||||
MaybeTlsStream::NativeTls(stream) => {
|
}
|
||||||
let stream = stream.get_ref();
|
|
||||||
let _ = stream.set_read_timeout(Some(timeout));
|
fn block_on_websocket_future<F: std::future::Future>(
|
||||||
let _ = stream.set_write_timeout(Some(timeout));
|
runtime: &TokioRuntime,
|
||||||
}
|
future: F,
|
||||||
#[allow(unreachable_patterns)]
|
) -> F::Output {
|
||||||
_ => {}
|
if TokioHandle::try_current().is_ok() {
|
||||||
|
tokio::task::block_in_place(|| runtime.block_on(future))
|
||||||
|
} else {
|
||||||
|
runtime.block_on(future)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -2637,14 +2691,15 @@ struct PluginWebSocketHandles {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PluginWebSocketHandles {
|
impl PluginWebSocketHandles {
|
||||||
fn insert(
|
fn reserve_open(&self) -> Result<PluginWebSocketOpenReservation, PluginWebSocketError> {
|
||||||
&self,
|
|
||||||
connection: Box<dyn PluginWebSocketConnection>,
|
|
||||||
) -> Result<u32, PluginWebSocketError> {
|
|
||||||
self.inner
|
self.inner
|
||||||
.lock()
|
.lock()
|
||||||
.expect("plugin websocket handle table poisoned")
|
.expect("plugin websocket handle table poisoned")
|
||||||
.insert(connection)
|
.reserve_open()?;
|
||||||
|
Ok(PluginWebSocketOpenReservation {
|
||||||
|
handles: self.clone(),
|
||||||
|
active: true,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn with_connection<T>(
|
fn with_connection<T>(
|
||||||
|
|
@ -2671,11 +2726,54 @@ impl PluginWebSocketHandles {
|
||||||
.expect("plugin websocket handle table poisoned")
|
.expect("plugin websocket handle table poisoned")
|
||||||
.close_all();
|
.close_all();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn reservation_count(&self) -> usize {
|
||||||
|
self.inner
|
||||||
|
.lock()
|
||||||
|
.expect("plugin websocket handle table poisoned")
|
||||||
|
.reservations
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct PluginWebSocketOpenReservation {
|
||||||
|
handles: PluginWebSocketHandles,
|
||||||
|
active: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PluginWebSocketOpenReservation {
|
||||||
|
fn commit(
|
||||||
|
mut self,
|
||||||
|
connection: Box<dyn PluginWebSocketConnection>,
|
||||||
|
) -> Result<u32, PluginWebSocketError> {
|
||||||
|
let result = self
|
||||||
|
.handles
|
||||||
|
.inner
|
||||||
|
.lock()
|
||||||
|
.expect("plugin websocket handle table poisoned")
|
||||||
|
.insert_reserved(connection);
|
||||||
|
self.active = false;
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for PluginWebSocketOpenReservation {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if self.active {
|
||||||
|
self.handles
|
||||||
|
.inner
|
||||||
|
.lock()
|
||||||
|
.expect("plugin websocket handle table poisoned")
|
||||||
|
.release_reservation();
|
||||||
|
self.active = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
struct PluginWebSocketHandleTable {
|
struct PluginWebSocketHandleTable {
|
||||||
next: u32,
|
next: u32,
|
||||||
|
reservations: usize,
|
||||||
connections: HashMap<u32, PluginWebSocketHandleEntry>,
|
connections: HashMap<u32, PluginWebSocketHandleEntry>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -2685,14 +2783,32 @@ struct PluginWebSocketHandleEntry {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PluginWebSocketHandleTable {
|
impl PluginWebSocketHandleTable {
|
||||||
fn insert(
|
fn reserve_open(&mut self) -> Result<(), PluginWebSocketError> {
|
||||||
|
self.expire_stale();
|
||||||
|
if self.connections.len() + self.reservations >= PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS {
|
||||||
|
return Err(PluginWebSocketError::new(format!(
|
||||||
|
"host_api.websocket open connection limit ({}) exceeded before dialing",
|
||||||
|
PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
self.reservations += 1;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn release_reservation(&mut self) {
|
||||||
|
self.reservations = self.reservations.saturating_sub(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_reserved(
|
||||||
&mut self,
|
&mut self,
|
||||||
connection: Box<dyn PluginWebSocketConnection>,
|
mut connection: Box<dyn PluginWebSocketConnection>,
|
||||||
) -> Result<u32, PluginWebSocketError> {
|
) -> Result<u32, PluginWebSocketError> {
|
||||||
|
self.release_reservation();
|
||||||
self.expire_stale();
|
self.expire_stale();
|
||||||
if self.connections.len() >= PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS {
|
if self.connections.len() >= PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS {
|
||||||
|
let _ = connection.close();
|
||||||
return Err(PluginWebSocketError::new(format!(
|
return Err(PluginWebSocketError::new(format!(
|
||||||
"host_api.websocket open connection limit ({}) exceeded",
|
"host_api.websocket open connection limit ({}) exceeded while committing reserved handle",
|
||||||
PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS
|
PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
@ -2741,6 +2857,7 @@ impl PluginWebSocketHandleTable {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn close_all(&mut self) {
|
fn close_all(&mut self) {
|
||||||
|
self.reservations = 0;
|
||||||
for (_, mut entry) in self.connections.drain() {
|
for (_, mut entry) in self.connections.drain() {
|
||||||
let _ = entry.connection.close();
|
let _ = entry.connection.close();
|
||||||
}
|
}
|
||||||
|
|
@ -7288,6 +7405,10 @@ input_schema = { type = "object", additionalProperties = true }
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PluginWebSocketClient for MockWebSocketClient {
|
impl PluginWebSocketClient for MockWebSocketClient {
|
||||||
|
fn supports_bounded_open(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
fn open(
|
fn open(
|
||||||
&self,
|
&self,
|
||||||
_request: &PluginWebSocketOpenRequest,
|
_request: &PluginWebSocketOpenRequest,
|
||||||
|
|
@ -7328,10 +7449,52 @@ input_schema = { type = "object", additionalProperties = true }
|
||||||
fn close(&mut self) -> Result<(), PluginWebSocketError> {
|
fn close(&mut self) -> Result<(), PluginWebSocketError> {
|
||||||
self.closed
|
self.closed
|
||||||
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
struct FailingWebSocketClient {
|
||||||
|
opens: Arc<std::sync::atomic::AtomicUsize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PluginWebSocketClient for FailingWebSocketClient {
|
||||||
|
fn supports_bounded_open(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn open(
|
||||||
|
&self,
|
||||||
|
_request: &PluginWebSocketOpenRequest,
|
||||||
|
_url: &reqwest::Url,
|
||||||
|
_limits: PluginWebSocketLimits,
|
||||||
|
) -> Result<Box<dyn PluginWebSocketConnection>, PluginWebSocketError> {
|
||||||
|
self.opens.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||||
|
Err(PluginWebSocketError::new("simulated bounded open failure"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
struct UnboundedWebSocketClient {
|
||||||
|
opens: Arc<std::sync::atomic::AtomicUsize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PluginWebSocketClient for UnboundedWebSocketClient {
|
||||||
|
fn supports_bounded_open(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn open(
|
||||||
|
&self,
|
||||||
|
_request: &PluginWebSocketOpenRequest,
|
||||||
|
_url: &reqwest::Url,
|
||||||
|
_limits: PluginWebSocketLimits,
|
||||||
|
) -> Result<Box<dyn PluginWebSocketConnection>, PluginWebSocketError> {
|
||||||
|
self.opens.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||||
|
Err(PluginWebSocketError::new("should not dial"))
|
||||||
|
}
|
||||||
|
}
|
||||||
fn websocket_grant(
|
fn websocket_grant(
|
||||||
scheme: &str,
|
scheme: &str,
|
||||||
host: &str,
|
host: &str,
|
||||||
|
|
@ -7357,6 +7520,124 @@ input_schema = { type = "object", additionalProperties = true }
|
||||||
record
|
record
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn websocket_max_open_connections_rejects_before_network_open() {
|
||||||
|
let record = record_with_websocket(
|
||||||
|
vec![websocket_grant(
|
||||||
|
"wss",
|
||||||
|
"gateway.example.com",
|
||||||
|
None,
|
||||||
|
&["/gateway"],
|
||||||
|
)],
|
||||||
|
vec![websocket_grant(
|
||||||
|
"wss",
|
||||||
|
"gateway.example.com",
|
||||||
|
None,
|
||||||
|
&["/gateway"],
|
||||||
|
)],
|
||||||
|
);
|
||||||
|
let client = MockWebSocketClient::default();
|
||||||
|
let handles = PluginWebSocketHandles::default();
|
||||||
|
for _ in 0..PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS {
|
||||||
|
execute_plugin_websocket_open(
|
||||||
|
&record,
|
||||||
|
&client,
|
||||||
|
&handles,
|
||||||
|
br#"{"url":"wss://gateway.example.com/gateway"}"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
assert_eq!(
|
||||||
|
client.opens.load(std::sync::atomic::Ordering::SeqCst),
|
||||||
|
PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS
|
||||||
|
);
|
||||||
|
let error = execute_plugin_websocket_open(
|
||||||
|
&record,
|
||||||
|
&client,
|
||||||
|
&handles,
|
||||||
|
br#"{"url":"wss://gateway.example.com/gateway"}"#,
|
||||||
|
)
|
||||||
|
.unwrap_err();
|
||||||
|
assert!(error.0.contains("before dialing"));
|
||||||
|
assert_eq!(
|
||||||
|
client.opens.load(std::sync::atomic::Ordering::SeqCst),
|
||||||
|
PLUGIN_WEBSOCKET_MAX_OPEN_CONNECTIONS
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn websocket_open_failure_releases_capacity_reservation() {
|
||||||
|
let record = record_with_websocket(
|
||||||
|
vec![websocket_grant(
|
||||||
|
"wss",
|
||||||
|
"gateway.example.com",
|
||||||
|
None,
|
||||||
|
&["/gateway"],
|
||||||
|
)],
|
||||||
|
vec![websocket_grant(
|
||||||
|
"wss",
|
||||||
|
"gateway.example.com",
|
||||||
|
None,
|
||||||
|
&["/gateway"],
|
||||||
|
)],
|
||||||
|
);
|
||||||
|
let failing = FailingWebSocketClient::default();
|
||||||
|
let handles = PluginWebSocketHandles::default();
|
||||||
|
let error = execute_plugin_websocket_open(
|
||||||
|
&record,
|
||||||
|
&failing,
|
||||||
|
&handles,
|
||||||
|
br#"{"url":"wss://gateway.example.com/gateway"}"#,
|
||||||
|
)
|
||||||
|
.unwrap_err();
|
||||||
|
assert!(error.0.contains("simulated bounded open failure"));
|
||||||
|
assert_eq!(handles.reservation_count(), 0);
|
||||||
|
|
||||||
|
let client = MockWebSocketClient::default();
|
||||||
|
let open = execute_plugin_websocket_open(
|
||||||
|
&record,
|
||||||
|
&client,
|
||||||
|
&handles,
|
||||||
|
br#"{"url":"wss://gateway.example.com/gateway"}"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let open: PluginWebSocketOpenResponse = serde_json::from_slice(&open).unwrap();
|
||||||
|
assert_eq!(open.handle, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn websocket_unbounded_open_client_fails_closed_before_dialing() {
|
||||||
|
let record = record_with_websocket(
|
||||||
|
vec![websocket_grant(
|
||||||
|
"wss",
|
||||||
|
"gateway.example.com",
|
||||||
|
None,
|
||||||
|
&["/gateway"],
|
||||||
|
)],
|
||||||
|
vec![websocket_grant(
|
||||||
|
"wss",
|
||||||
|
"gateway.example.com",
|
||||||
|
None,
|
||||||
|
&["/gateway"],
|
||||||
|
)],
|
||||||
|
);
|
||||||
|
let client = UnboundedWebSocketClient::default();
|
||||||
|
let handles = PluginWebSocketHandles::default();
|
||||||
|
let error = execute_plugin_websocket_open(
|
||||||
|
&record,
|
||||||
|
&client,
|
||||||
|
&handles,
|
||||||
|
br#"{"url":"wss://gateway.example.com/gateway"}"#,
|
||||||
|
)
|
||||||
|
.unwrap_err();
|
||||||
|
assert!(
|
||||||
|
error
|
||||||
|
.0
|
||||||
|
.contains("cannot guarantee bounded/cancellable open")
|
||||||
|
);
|
||||||
|
assert_eq!(client.opens.load(std::sync::atomic::Ordering::SeqCst), 0);
|
||||||
|
assert_eq!(handles.reservation_count(), 0);
|
||||||
|
}
|
||||||
#[test]
|
#[test]
|
||||||
fn websocket_open_send_recv_close_is_bounded_and_explicit() {
|
fn websocket_open_send_recv_close_is_bounded_and_explicit() {
|
||||||
let record = record_with_websocket(
|
let record = record_with_websocket(
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ rustPlatform.buildRustPackage rec {
|
||||||
filter = sourceFilter;
|
filter = sourceFilter;
|
||||||
};
|
};
|
||||||
|
|
||||||
cargoHash = "sha256-TZrw6nJclXVRpFIUlYvimGTDXlxBMaQt6oM5C5DIGIU=";
|
cargoHash = "sha256-cZxkmM42kbDp1Rv9gn4sCD5WIQLc0wCbjj4GbKjuA9Q=";
|
||||||
|
|
||||||
depsExtraArgs = {
|
depsExtraArgs = {
|
||||||
# Older fetchCargoVendor utilities used crates.io's API download endpoint,
|
# Older fetchCargoVendor utilities used crates.io's API download endpoint,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user