fix: preserve runtime websocket diagnostics

This commit is contained in:
Keisuke Hirata 2026-06-26 14:10:08 +09:00
parent 9807accaf0
commit 8cc9a594f7
No known key found for this signature in database
2 changed files with 184 additions and 10 deletions

View File

@ -1,11 +1,12 @@
use std::collections::{BTreeMap, VecDeque}; use std::collections::{BTreeMap, VecDeque};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use axum::http::StatusCode;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio_tungstenite::connect_async; use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::Message as TungsteniteMessage;
use tokio_tungstenite::tungstenite::client::IntoClientRequest; use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::{Error as TungsteniteError, Message as TungsteniteMessage};
use worker_runtime::http_server::{RuntimeWorkerEventWsEnvelope, RuntimeWorkerEventWsFrame}; use worker_runtime::http_server::{RuntimeWorkerEventWsEnvelope, RuntimeWorkerEventWsFrame};
/// Backend-private source for a runtime worker observation stream. /// Backend-private source for a runtime worker observation stream.
@ -320,6 +321,47 @@ impl BackendObservationProxy {
} }
} }
/// Translates a failed runtime WebSocket connect attempt into an
/// [`ObservationProxyError`].
///
/// `TungsteniteError::Http` errors are classified by the response status:
/// `404 Not Found` becomes `WorkerNotFound`, `400 Bad Request` becomes
/// `CursorMalformed`, and any other status becomes `RuntimeUnavailable` with
/// the status included in the message. Every other error kind is reported as
/// `RuntimeUnavailable` carrying the error's display text.
fn map_runtime_connect_error(error: TungsteniteError) -> ObservationProxyError {
    // Only HTTP-level rejections carry a status worth classifying; anything
    // else is returned immediately as a generic connect failure.
    let status = match error {
        TungsteniteError::Http(response) => response.status(),
        error => {
            return ObservationProxyError::RuntimeUnavailable(format!(
                "failed to connect runtime WebSocket: {error}"
            ));
        }
    };

    if status == StatusCode::NOT_FOUND {
        ObservationProxyError::WorkerNotFound(
            "runtime worker observation endpoint returned 404 not found".into(),
        )
    } else if status == StatusCode::BAD_REQUEST {
        ObservationProxyError::CursorMalformed(
            "runtime worker observation endpoint rejected the request as malformed".into(),
        )
    } else {
        ObservationProxyError::RuntimeUnavailable(format!(
            "runtime worker observation endpoint rejected WebSocket upgrade with status {}",
            status
        ))
    }
}
/// Converts a diagnostic frame's `code` / `message` pair into the matching
/// [`ObservationProxyError`] variant.
///
/// Recognised codes map one-to-one onto a variant that carries `message`
/// unchanged (`runtime.observation_only` maps to the payload-free
/// `ObservationOnly`). Any unrecognised code falls back to
/// `RuntimeUnavailable` with both the code and the message in the text.
fn map_runtime_diagnostic(code: String, message: String) -> ObservationProxyError {
    // Pick the variant constructor first; the two cases that do not simply
    // wrap `message` return directly from the match.
    let wrap: fn(String) -> ObservationProxyError = match code.as_str() {
        "runtime.observation_only" => return ObservationProxyError::ObservationOnly,
        "runtime.worker_not_found" => ObservationProxyError::WorkerNotFound,
        "runtime.cursor_malformed" => ObservationProxyError::CursorMalformed,
        "runtime.cursor_unknown_or_expired" | "runtime.cursor_expired" => {
            ObservationProxyError::CursorUnknownOrExpired
        }
        "runtime.unavailable" => ObservationProxyError::RuntimeUnavailable,
        "runtime.upstream_closed" | "runtime.websocket_error" => {
            ObservationProxyError::UpstreamDisconnect
        }
        "runtime.serialize_failed" => ObservationProxyError::MalformedFrame,
        _ => {
            return ObservationProxyError::RuntimeUnavailable(format!(
                "runtime diagnostic {code}: {message}"
            ));
        }
    };
    wrap(message)
}
pub struct RuntimeWsObservationClient { pub struct RuntimeWsObservationClient {
runtime_id: String, runtime_id: String,
worker_id: String, worker_id: String,
@ -355,11 +397,9 @@ impl RuntimeWsObservationClient {
})?, })?,
); );
} }
let (stream, _) = connect_async(request).await.map_err(|error| { let (stream, _) = connect_async(request)
ObservationProxyError::RuntimeUnavailable(format!( .await
"failed to connect runtime WebSocket: {error}" .map_err(map_runtime_connect_error)?;
))
})?;
Ok(Self { Ok(Self {
runtime_id: source.runtime_id.clone(), runtime_id: source.runtime_id.clone(),
worker_id: source.worker_id.clone(), worker_id: source.worker_id.clone(),
@ -417,10 +457,7 @@ impl RuntimeWsObservationClient {
return Ok(self.map_envelope(envelope)); return Ok(self.map_envelope(envelope));
} }
RuntimeWorkerEventWsFrame::Diagnostic { diagnostic } => { RuntimeWorkerEventWsFrame::Diagnostic { diagnostic } => {
return Err(ObservationProxyError::UpstreamDisconnect(format!( return Err(map_runtime_diagnostic(diagnostic.code, diagnostic.message));
"runtime diagnostic {}: {}",
diagnostic.code, diagnostic.message
)));
} }
} }
} }

View File

@ -795,6 +795,7 @@ mod tests {
use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::Message;
use tower::ServiceExt; use tower::ServiceExt;
use crate::observation::ClientWorkerEventWsDiagnostic;
use crate::store::SqliteWorkspaceStore; use crate::store::SqliteWorkspaceStore;
const TEST_WORKSPACE_ID: &str = "0192f0e8-4d84-7d6e-a000-000000000001"; const TEST_WORKSPACE_ID: &str = "0192f0e8-4d84-7d6e-a000-000000000001";
@ -1109,6 +1110,72 @@ mod tests {
assert!(saw_observation_only, "expected observation-only diagnostic"); assert!(saw_observation_only, "expected observation-only diagnostic");
} }
/// A client connecting with the cursor `bo_ffffffffffffffff` must receive a
/// `backend.cursor_unknown_or_expired` diagnostic as its first frame.
#[tokio::test]
async fn proxy_reports_unknown_backend_cursor_before_upstream_connect() {
    // NOTE(review): the upstream endpoint is a placeholder ("not-used"); per
    // the test name the proxy is expected to answer without dialing it.
    let (url, _dir) = spawn_workspace_proxy(RuntimeObservationSourceConfig {
        runtime_id: "runtime-a".into(),
        worker_id: "worker-a".into(),
        endpoint: "ws://127.0.0.1:9/not-used".into(),
        bearer_token: None,
    })
    .await;

    let cursor_url = format!("{url}?cursor=bo_ffffffffffffffff");
    let (mut stream, _) = connect_async(cursor_url).await.unwrap();

    let diagnostic = next_client_diagnostic(&mut stream).await;
    assert_eq!(diagnostic.code, "backend.cursor_unknown_or_expired");
}
/// With the runtime endpoint configured to carry the cursor
/// `wo_ffffffffffffffff`, the client must first see a snapshot event and then
/// a `backend.cursor_unknown_or_expired` diagnostic.
#[tokio::test]
async fn proxy_maps_runtime_cursor_diagnostic_to_typed_backend_diagnostic() {
    // `_runtime` is kept bound so the runtime handle lives for the whole test.
    let (_runtime, _worker_ref, endpoint) = spawn_runtime_worker().await;
    let (url, _dir) = spawn_workspace_proxy(RuntimeObservationSourceConfig {
        runtime_id: "runtime-a".into(),
        worker_id: "worker-a".into(),
        endpoint: format!("{endpoint}?cursor=wo_ffffffffffffffff"),
        bearer_token: None,
    })
    .await;
    let (mut stream, _) = connect_async(&url).await.unwrap();

    // First frame: a snapshot event.
    let first_frame = next_client_frame(&mut stream).await;
    assert!(matches!(
        first_frame,
        ClientWorkerEventWsFrame::Event { envelope }
            if matches!(envelope.payload, protocol::Event::Snapshot { .. })
    ));

    // Second frame: the typed backend diagnostic.
    let diagnostic = next_client_diagnostic(&mut stream).await;
    assert_eq!(diagnostic.code, "backend.cursor_unknown_or_expired");
}
/// Pointing the proxy at a rewritten (`/missing-worker/events/ws`) runtime
/// path must surface a `backend.worker_not_found` diagnostic to the client.
#[tokio::test]
async fn proxy_maps_runtime_worker_not_found_http_404_to_typed_backend_diagnostic() {
    // `_runtime` is kept bound so the runtime handle lives for the whole test.
    let (_runtime, _worker_ref, endpoint) = spawn_runtime_worker().await;
    let (url, _dir) = spawn_workspace_proxy(RuntimeObservationSourceConfig {
        runtime_id: "runtime-a".into(),
        worker_id: "worker-a".into(),
        // Rewrite the path so it no longer addresses the created worker.
        endpoint: endpoint.replace("/events/ws", "/missing-worker/events/ws"),
        bearer_token: None,
    })
    .await;
    let (mut stream, _) = connect_async(&url).await.unwrap();

    let diagnostic = next_client_diagnostic(&mut stream).await;
    assert_eq!(diagnostic.code, "backend.worker_not_found");
}
/// When the upstream WebSocket closes right after the handshake, the client
/// must receive a `backend.upstream_disconnect` diagnostic.
#[tokio::test]
async fn proxy_reports_actual_upstream_disconnect_separately() {
    let (url, _dir) = spawn_workspace_proxy(RuntimeObservationSourceConfig {
        runtime_id: "runtime-a".into(),
        worker_id: "worker-a".into(),
        endpoint: spawn_closing_runtime_ws().await,
        bearer_token: None,
    })
    .await;
    let (mut stream, _) = connect_async(&url).await.unwrap();

    let diagnostic = next_client_diagnostic(&mut stream).await;
    assert_eq!(diagnostic.code, "backend.upstream_disconnect");
}
async fn next_client_frame( async fn next_client_frame(
stream: &mut tokio_tungstenite::WebSocketStream< stream: &mut tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>, tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
@ -1121,6 +1188,76 @@ mod tests {
serde_json::from_str(&text).unwrap() serde_json::from_str(&text).unwrap()
} }
/// Reads the next client frame from `stream` and returns its diagnostic.
///
/// # Panics
///
/// Panics if the next frame is an event rather than a diagnostic; the panic
/// message includes the unexpected envelope.
async fn next_client_diagnostic(
    stream: &mut tokio_tungstenite::WebSocketStream<
        tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
    >,
) -> ClientWorkerEventWsDiagnostic {
    // Return as soon as a diagnostic is seen; only the failure case falls
    // through to the panic below.
    let envelope = match next_client_frame(stream).await {
        ClientWorkerEventWsFrame::Diagnostic { diagnostic } => return diagnostic,
        ClientWorkerEventWsFrame::Event { envelope } => envelope,
    };
    panic!("expected diagnostic, got event: {envelope:?}")
}
/// Builds a runtime via `Runtime::new_memory()`, creates one worker with the
/// default request, and serves the runtime's HTTP API from a background task
/// on a loopback listener bound to port 0.
///
/// Returns the runtime handle, the created worker's ref, and the `ws://` URL
/// of that worker's `/events/ws` endpoint.
async fn spawn_runtime_worker() -> (
    worker_runtime::Runtime,
    worker_runtime::identity::WorkerRef,
    String,
) {
    let runtime = worker_runtime::Runtime::new_memory();
    let request = worker_runtime::catalog::CreateWorkerRequest::default();
    let worker = runtime.create_worker(request).unwrap();

    // Port 0 lets the OS pick a free port; read it back to build the URL.
    let runtime_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let runtime_addr = runtime_listener.local_addr().unwrap();

    // The server task gets its own clone so the caller keeps `runtime`.
    let served_runtime = runtime.clone();
    tokio::spawn(async move {
        worker_runtime::http_server::serve_runtime_http(served_runtime, runtime_listener, None)
            .await
            .unwrap()
    });

    let endpoint = format!(
        "ws://{runtime_addr}/v1/workers/{}/events/ws",
        worker.worker_ref.worker_id
    );
    (runtime, worker.worker_ref, endpoint)
}
/// Starts a workspace API server whose configuration includes `source` as a
/// runtime event source, served from a background task on a loopback listener
/// bound to port 0.
///
/// Returns the `ws://` URL of the proxy's events endpoint for the source's
/// runtime/worker pair, together with the temporary directory backing the
/// server config (the caller must keep it alive for the test's duration).
async fn spawn_workspace_proxy(
    source: RuntimeObservationSourceConfig,
) -> (String, tempfile::TempDir) {
    let dir = tempfile::tempdir().unwrap();
    let store = Arc::new(SqliteWorkspaceStore::in_memory().unwrap());

    // Copy the ids before `source` is moved into the config.
    let runtime_id = source.runtime_id.clone();
    let worker_id = source.worker_id.clone();

    let mut config = ServerConfig::local_dev(dir.path(), test_identity());
    config.local_runtime_data_dir = Some(dir.path().join("data"));
    config.runtime_event_sources.push(source);

    let api = WorkspaceApi::new(config, store).await.unwrap();

    let app_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let app_addr = app_listener.local_addr().unwrap();
    tokio::spawn(async move {
        axum::serve(app_listener, build_router(api)).await.unwrap()
    });

    let url = format!("ws://{app_addr}/api/runtimes/{runtime_id}/workers/{worker_id}/events/ws");
    (url, dir)
}
/// Starts a one-shot WebSocket server on a loopback listener bound to port 0
/// that accepts a single connection, completes the handshake, and then
/// closes the socket.
///
/// Returns the `ws://` URL (path `/events/ws`) clients should connect to.
async fn spawn_closing_runtime_ws() -> String {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();

    let accept_then_close = async move {
        let (stream, _) = listener.accept().await.unwrap();
        let mut websocket = tokio_tungstenite::accept_async(stream).await.unwrap();
        // The close result is deliberately ignored: the peer may already be gone.
        let _ = websocket.close(None).await;
    };
    tokio::spawn(accept_then_close);

    format!("ws://{addr}/events/ws")
}
async fn get_json(app: Router, uri: &str) -> Value { async fn get_json(app: Router, uri: &str) -> Value {
let response = app let response = app
.oneshot(Request::builder().uri(uri).body(Body::empty()).unwrap()) .oneshot(Request::builder().uri(uri).body(Body::empty()).unwrap())