diff --git a/Cargo.lock b/Cargo.lock index 71f9c698..9a55d81f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -203,6 +203,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "31b698c5f9a010f6573133b09e0de5408834d0c82f8d7475a89fc1867a71cd90" dependencies = [ "axum-core", + "base64", "bytes", "form_urlencoded", "futures-util", @@ -221,8 +222,10 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper", "tokio", + "tokio-tungstenite 0.29.0", "tower", "tower-layer", "tower-service", @@ -4422,7 +4425,19 @@ dependencies = [ "native-tls", "tokio", "tokio-native-tls", - "tungstenite", + "tungstenite 0.28.0", +] + +[[package]] +name = "tokio-tungstenite" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f72a05e828585856dacd553fba484c242c46e391fb0e58917c942ee9202915c" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.29.0", ] [[package]] @@ -4709,6 +4724,22 @@ dependencies = [ "utf-8", ] +[[package]] +name = "tungstenite" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c01152af293afb9c7c2a57e4b559c5620b421f6d133261c60dd2d0cdb38e6b8" +dependencies = [ + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand 0.9.4", + "sha1", + "thiserror 2.0.18", +] + [[package]] name = "type1-encoding-parser" version = "0.1.1" @@ -5889,11 +5920,11 @@ dependencies = [ "thiserror 2.0.18", "ticket", "tokio", - "tokio-tungstenite", + "tokio-tungstenite 0.28.0", "toml", "tools", "tracing", - "tungstenite", + "tungstenite 0.28.0", "uuid", "wasmtime", "wat", @@ -5906,10 +5937,13 @@ name = "worker-runtime" version = "0.1.0" dependencies = [ "axum", + "futures", + "protocol", "serde", "serde_json", "thiserror 2.0.18", "tokio", + "tokio-tungstenite 0.29.0", "tower", ] @@ -5997,9 +6031,11 @@ dependencies = [ "async-trait", "axum", "chrono", + "futures", "manifest", "pod-store", "project-record", + "protocol", "rusqlite", "serde", "serde_json", @@ -6009,10 +6045,12 @@ dependencies = [ "thiserror 2.0.18", "ticket", "tokio", + "tokio-tungstenite 0.29.0", "toml", "tower", "tracing", "uuid", + "worker-runtime", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 36662456..51e91e5f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -104,6 +104,7 @@ sha2 = "0.11" tempfile = "3.27" thiserror = "2.0" tokio = "1.52" +tokio-tungstenite = "0.29" tower = "0.5" toml = "1.1" tracing = "0.1" diff --git a/crates/worker-runtime/Cargo.toml b/crates/worker-runtime/Cargo.toml index d8588ab2..cac97606 100644 --- a/crates/worker-runtime/Cargo.toml +++ b/crates/worker-runtime/Cargo.toml @@ -14,11 +14,18 @@ required-features = ["http-server"] default = [] fs-store = ["dep:serde_json"] http-server = ["dep:axum", "dep:serde_json", "dep:tokio", "dep:tower"] +ws-server = ["http-server", "axum/ws", "dep:futures", "dep:protocol", "tokio/sync"] [dependencies] axum = { workspace = true, optional = true } +futures = { workspace = true, optional = true } +protocol = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, optional = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["net", "rt"], optional = true } tower = { workspace = true, features = ["util"], optional = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } +tokio-tungstenite.workspace = true diff --git a/crates/worker-runtime/src/http_server.rs b/crates/worker-runtime/src/http_server.rs index 11624fd8..dd9a3357 100644 --- a/crates/worker-runtime/src/http_server.rs +++ b/crates/worker-runtime/src/http_server.rs @@ -14,15 +14,21 @@ use crate::fs_store::FsRuntimeStoreOptions; use crate::identity::{RuntimeId, WorkerId, WorkerRef}; use crate::interaction::{WorkerInput, WorkerInteractionAck}; use crate::management::{RuntimeLimits, RuntimeOptions, RuntimeSummary}; +#[cfg(feature = "ws-server")] +use crate::observation::WorkerObservationCursor; use crate::observation::{TranscriptProjection, TranscriptQuery}; use axum::body::{Body, Bytes}; use axum::extract::rejection::{JsonRejection, QueryRejection}; +#[cfg(feature = "ws-server")] +use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}; use axum::extract::{Path, Query, State}; use axum::http::{Request, StatusCode, header}; use axum::middleware::{self, Next}; use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; use axum::{Json, Router}; +#[cfg(feature = "ws-server")] +use futures::StreamExt; use serde::{Deserialize, Serialize}; use std::fmt; use std::net::SocketAddr; @@ -157,7 +163,7 @@ pub fn runtime_http_router(runtime: Runtime, local_token: Option) -> Rou local_token: local_token.map(Arc::::from), }; - Router::new() + let router = Router::new() .route("/v1/runtime", get(get_runtime)) .route("/v1/workers", get(list_workers).post(create_worker)) .route("/v1/workers/{worker_id}", get(get_worker)) @@ -167,7 +173,12 @@ pub fn runtime_http_router(runtime: Runtime, local_token: Option) -> Rou .route( "/v1/workers/{worker_id}/transcript", get(get_worker_transcript), - ) + ); + + #[cfg(feature = "ws-server")] + let router = router.route("/v1/workers/{worker_id}/events/ws", get(worker_events_ws)); + + router .with_state(state.clone()) .layer(middleware::from_fn_with_state(state, require_local_token)) } @@ -255,6 +266,43 @@ pub struct RuntimeHttpErrorDetail { pub message: String, } +/// Runtime-owned WebSocket frame for worker-scoped observation. +#[cfg(feature = "ws-server")] +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum RuntimeWorkerEventWsFrame { + Event { + envelope: RuntimeWorkerEventWsEnvelope, + }, + Diagnostic { + diagnostic: RuntimeWorkerEventWsDiagnostic, + }, +} + +/// Runtime-local protocol event envelope. +#[cfg(feature = "ws-server")] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RuntimeWorkerEventWsEnvelope { + pub cursor: String, + pub event_id: String, + pub worker_id: WorkerId, + pub payload: protocol::Event, +} + +/// Runtime-local observation diagnostic. +#[cfg(feature = "ws-server")] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct RuntimeWorkerEventWsDiagnostic { + pub code: String, + pub message: String, +} + +#[cfg(feature = "ws-server")] +#[derive(Clone, Debug, Default, Deserialize)] +struct RuntimeWorkerEventsWsQuery { + cursor: Option, +} + #[derive(Clone, Debug, Deserialize)] struct RuntimeHttpTranscriptQuery { #[serde(default)] @@ -267,6 +315,51 @@ fn default_transcript_limit() -> usize { 256 } +#[cfg(feature = "ws-server")] +impl RuntimeWorkerEventWsFrame { + fn event( + cursor: String, + event_id: String, + worker_id: WorkerId, + payload: protocol::Event, + ) -> Self { + Self::Event { + envelope: RuntimeWorkerEventWsEnvelope { + cursor, + event_id, + worker_id, + payload, + }, + } + } + + fn diagnostic(code: impl Into, message: impl Into) -> Self { + Self::Diagnostic { + diagnostic: RuntimeWorkerEventWsDiagnostic { + code: code.into(), + message: message.into(), + }, + } + } +} + +#[cfg(feature = "ws-server")] +async fn send_ws_frame(socket: &mut WebSocket, frame: &RuntimeWorkerEventWsFrame) -> bool { + match serde_json::to_string(frame) { + Ok(text) => socket.send(WsMessage::Text(text.into())).await.is_ok(), + Err(error) => { + let fallback = RuntimeWorkerEventWsFrame::diagnostic( + "runtime.serialize_failed", + format!("failed to serialize observation frame: {error}"), + ); + let Ok(text) = serde_json::to_string(&fallback) else { + return false; + }; + socket.send(WsMessage::Text(text.into())).await.is_ok() + } + } +} + type RestResult = Result, RuntimeHttpRestError>; async fn get_runtime( @@ -313,6 +406,182 @@ async fn create_worker( Ok(Json(RuntimeHttpWorkerResponse { worker })) } +#[cfg(feature = "ws-server")] +async fn worker_events_ws( + State(state): State, + Path(worker_id): Path, + Query(query): Query, + ws: WebSocketUpgrade, +) -> Result { + let worker_ref = worker_ref_for(&state.runtime, worker_id)?; + state + .runtime + .worker_detail(&worker_ref) + .map_err(RuntimeHttpRestError::runtime)?; + Ok(ws + .on_upgrade(move |socket| { + worker_events_ws_session(state.runtime, worker_ref, query, socket) + }) + .into_response()) +} + +#[cfg(feature = "ws-server")] +async fn worker_events_ws_session( + runtime: Runtime, + worker_ref: WorkerRef, + query: RuntimeWorkerEventsWsQuery, + mut socket: WebSocket, +) { + let mut cursor = match query.cursor.as_deref() { + Some(raw) => match WorkerObservationCursor::decode(raw) { + Some(cursor) => cursor, + None => { + let frame = RuntimeWorkerEventWsFrame::diagnostic( + "runtime.cursor_malformed", + format!("malformed worker observation cursor: {raw}"), + ); + let _ = send_ws_frame(&mut socket, &frame).await; + return; + } + }, + None => match runtime.worker_observation_cursor_now(&worker_ref) { + Ok(cursor) => cursor, + Err(error) => { + let frame = RuntimeWorkerEventWsFrame::diagnostic( + "runtime.worker_not_found", + error.to_string(), + ); + let _ = send_ws_frame(&mut socket, &frame).await; + return; + } + }, + }; + + let mut receiver = match runtime.subscribe_worker_observation() { + Ok(receiver) => receiver, + Err(error) => { + let frame = RuntimeWorkerEventWsFrame::diagnostic( + "runtime.unavailable", + format!("runtime observation bus unavailable: {error}"), + ); + let _ = send_ws_frame(&mut socket, &frame).await; + return; + } + }; + + let snapshot = match runtime.worker_observation_snapshot(&worker_ref) { + Ok(snapshot) => snapshot, + Err(error) => { + let frame = RuntimeWorkerEventWsFrame::diagnostic( + "runtime.worker_not_found", + error.to_string(), + ); + let _ = send_ws_frame(&mut socket, &frame).await; + return; + } + }; + let snapshot_cursor = cursor.encode(); + let snapshot_frame = RuntimeWorkerEventWsFrame::event( + snapshot_cursor.clone(), + format!("snapshot:{snapshot_cursor}"), + worker_ref.worker_id.clone(), + snapshot, + ); + if !send_ws_frame(&mut socket, &snapshot_frame).await { + return; + } + + match runtime.read_worker_observation_events(&worker_ref, cursor) { + Ok(backlog) => { + for event in backlog { + cursor = WorkerObservationCursor::new(event.sequence); + let frame = RuntimeWorkerEventWsFrame::event( + event.cursor, + event.event_id, + event.worker_ref.worker_id, + event.payload, + ); + if !send_ws_frame(&mut socket, &frame).await { + return; + } + } + } + Err(error) => { + let frame = RuntimeWorkerEventWsFrame::diagnostic( + "runtime.cursor_unknown_or_expired", + error.to_string(), + ); + let _ = send_ws_frame(&mut socket, &frame).await; + return; + } + } + + loop { + tokio::select! { + inbound = socket.next() => { + match inbound { + Some(Ok(WsMessage::Close(_))) | None => return, + Some(Ok(WsMessage::Ping(payload))) => { + if socket.send(WsMessage::Pong(payload)).await.is_err() { + return; + } + } + Some(Ok(WsMessage::Pong(_))) => {} + Some(Ok(_)) => { + let frame = RuntimeWorkerEventWsFrame::diagnostic( + "runtime.observation_only", + "runtime worker event WebSocket is observation-only", + ); + let _ = send_ws_frame(&mut socket, &frame).await; + return; + } + Some(Err(error)) => { + let frame = RuntimeWorkerEventWsFrame::diagnostic( + "runtime.websocket_error", + format!("runtime WebSocket receive error: {error}"), + ); + let _ = send_ws_frame(&mut socket, &frame).await; + return; + } + } + } + event = receiver.recv() => { + match event { + Ok(event) if event.worker_ref == worker_ref && event.sequence > cursor.sequence => { + cursor = WorkerObservationCursor::new(event.sequence); + let frame = RuntimeWorkerEventWsFrame::event( + event.cursor, + event.event_id, + event.worker_ref.worker_id, + event.payload, + ); + if !send_ws_frame(&mut socket, &frame).await { + return; + } + } + Ok(_) => {} + Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => { + let frame = RuntimeWorkerEventWsFrame::diagnostic( + "runtime.cursor_expired", + "runtime observation backlog was overrun", + ); + let _ = send_ws_frame(&mut socket, &frame).await; + return; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + let frame = RuntimeWorkerEventWsFrame::diagnostic( + "runtime.upstream_closed", + "runtime observation bus closed", + ); + let _ = send_ws_frame(&mut socket, &frame).await; + return; + } + } + } + } + } +} + async fn send_worker_input( State(state): State, Path(worker_id): Path, @@ -688,3 +957,159 @@ mod tests { assert!(error.error.message.contains("worker-missing")); } } + +#[cfg(all(test, feature = "ws-server"))] +mod ws_tests { + use super::*; + use futures::{SinkExt, StreamExt}; + use tokio_tungstenite::connect_async; + use tokio_tungstenite::tungstenite::Message; + + async fn spawn_runtime_server() -> (Runtime, WorkerRef, String) { + let runtime = Runtime::new_memory(); + let worker = runtime + .create_worker(CreateWorkerRequest::default()) + .unwrap(); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn({ + let runtime = runtime.clone(); + async move { serve_runtime_http(runtime, listener, None).await.unwrap() } + }); + ( + runtime, + worker.worker_ref.clone(), + format!( + "ws://{addr}/v1/workers/{}/events/ws", + worker.worker_ref.worker_id + ), + ) + } + + async fn next_frame( + stream: &mut tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, + ) -> RuntimeWorkerEventWsFrame { + let message = stream.next().await.unwrap().unwrap(); + let Message::Text(text) = message else { + panic!("expected text frame"); + }; + serde_json::from_str(&text).unwrap() + } + + #[tokio::test] + async fn runtime_ws_connect_sends_snapshot_and_live_worker_events() { + let (runtime, worker_ref, url) = spawn_runtime_server().await; + let (mut stream, _) = connect_async(&url).await.unwrap(); + + match next_frame(&mut stream).await { + RuntimeWorkerEventWsFrame::Event { envelope } => { + assert_eq!(envelope.worker_id, worker_ref.worker_id); + assert!(matches!(envelope.payload, protocol::Event::Snapshot { .. })); + } + RuntimeWorkerEventWsFrame::Diagnostic { diagnostic } => { + panic!("unexpected diagnostic: {diagnostic:?}"); + } + } + + let stored = runtime + .observe_worker_event( + &worker_ref, + protocol::Event::TextDelta { + text: "started".into(), + }, + ) + .unwrap(); + match next_frame(&mut stream).await { + RuntimeWorkerEventWsFrame::Event { envelope } => { + assert_eq!(envelope.worker_id, worker_ref.worker_id); + assert_eq!(envelope.cursor, stored.cursor); + assert!(matches!( + envelope.payload, + protocol::Event::TextDelta { .. } + )); + } + RuntimeWorkerEventWsFrame::Diagnostic { diagnostic } => { + panic!("unexpected diagnostic: {diagnostic:?}"); + } + } + } + + #[tokio::test] + async fn runtime_ws_cursor_resume_is_duplicate_safe_and_filters_workers() { + let (runtime, worker_ref, url) = spawn_runtime_server().await; + let other = runtime + .create_worker(CreateWorkerRequest::default()) + .unwrap(); + let first = runtime + .observe_worker_event( + &worker_ref, + protocol::Event::TextDelta { + text: "started".into(), + }, + ) + .unwrap(); + runtime + .observe_worker_event( + &other.worker_ref, + protocol::Event::TextDelta { + text: "started".into(), + }, + ) + .unwrap(); + + let (mut stream, _) = connect_async(format!("{url}?cursor={}", first.cursor)) + .await + .unwrap(); + assert!(matches!( + next_frame(&mut stream).await, + RuntimeWorkerEventWsFrame::Event { envelope } if matches!(envelope.payload, protocol::Event::Snapshot { .. }) + )); + + let second = runtime + .observe_worker_event( + &worker_ref, + protocol::Event::TextDone { + text: "done".into(), + }, + ) + .unwrap(); + match next_frame(&mut stream).await { + RuntimeWorkerEventWsFrame::Event { envelope } => { + assert_eq!(envelope.cursor, second.cursor); + assert_ne!(envelope.cursor, first.cursor); + assert!(matches!(envelope.payload, protocol::Event::TextDone { .. })); + } + RuntimeWorkerEventWsFrame::Diagnostic { diagnostic } => { + panic!("unexpected diagnostic: {diagnostic:?}"); + } + } + } + + #[tokio::test] + async fn runtime_ws_reports_malformed_cursor_and_observation_only_input() { + let (_runtime, _worker_ref, url) = spawn_runtime_server().await; + let (mut malformed, _) = connect_async(format!("{url}?cursor=bad")).await.unwrap(); + match next_frame(&mut malformed).await { + RuntimeWorkerEventWsFrame::Diagnostic { diagnostic } => { + assert_eq!(diagnostic.code, "runtime.cursor_malformed"); + } + RuntimeWorkerEventWsFrame::Event { envelope } => { + panic!("unexpected event: {envelope:?}"); + } + } + + let (mut stream, _) = connect_async(&url).await.unwrap(); + let _ = next_frame(&mut stream).await; + stream.send(Message::Text("{}".into())).await.unwrap(); + match next_frame(&mut stream).await { + RuntimeWorkerEventWsFrame::Diagnostic { diagnostic } => { + assert_eq!(diagnostic.code, "runtime.observation_only"); + } + RuntimeWorkerEventWsFrame::Event { envelope } => { + panic!("unexpected event: {envelope:?}"); + } + } + } +} diff --git a/crates/worker-runtime/src/observation.rs b/crates/worker-runtime/src/observation.rs index 50ee2318..fc99e918 100644 --- a/crates/worker-runtime/src/observation.rs +++ b/crates/worker-runtime/src/observation.rs @@ -93,3 +93,62 @@ pub struct RuntimeEventBatch { pub events: Vec, pub has_more: bool, } + +/// Runtime-local cursor for worker-scoped WebSocket observation. +#[cfg(feature = "ws-server")] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct WorkerObservationCursor { + pub sequence: u64, +} + +#[cfg(feature = "ws-server")] +impl WorkerObservationCursor { + pub const PREFIX: &'static str = "wo"; + + pub fn new(sequence: u64) -> Self { + Self { sequence } + } + + pub fn zero() -> Self { + Self { sequence: 0 } + } + + pub fn encode(self) -> String { + format!("{}_{:016x}", Self::PREFIX, self.sequence) + } + + pub fn decode(value: &str) -> Option { + let encoded = value.strip_prefix("wo_")?; + if encoded.len() != 16 { + return None; + } + u64::from_str_radix(encoded, 16) + .ok() + .map(|sequence| Self { sequence }) + } +} + +/// One protocol event observed from a runtime Worker. +#[cfg(feature = "ws-server")] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct WorkerObservationEvent { + pub cursor: String, + pub event_id: String, + pub sequence: u64, + pub worker_ref: WorkerRef, + pub payload: protocol::Event, +} + +#[cfg(feature = "ws-server")] +impl WorkerObservationEvent { + pub fn new(sequence: u64, worker_ref: WorkerRef, payload: protocol::Event) -> Self { + let cursor = WorkerObservationCursor::new(sequence).encode(); + Self { + event_id: cursor.clone(), + cursor, + sequence, + worker_ref, + payload, + } + } +} diff --git a/crates/worker-runtime/src/runtime.rs b/crates/worker-runtime/src/runtime.rs index 75eccdc5..172b4f65 100644 --- a/crates/worker-runtime/src/runtime.rs +++ b/crates/worker-runtime/src/runtime.rs @@ -16,9 +16,15 @@ use crate::observation::{ EventCursor, EventSubscription, EventSubscriptionMode, RuntimeEvent, RuntimeEventBatch, RuntimeEventKind, TranscriptEntry, TranscriptProjection, TranscriptQuery, TranscriptRole, }; +#[cfg(feature = "ws-server")] +use crate::observation::{WorkerObservationCursor, WorkerObservationEvent}; use std::collections::BTreeMap; +#[cfg(feature = "ws-server")] +use std::collections::VecDeque; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, Mutex, MutexGuard}; +#[cfg(feature = "ws-server")] +use tokio::sync::broadcast; static NEXT_RUNTIME_SEQUENCE: AtomicU64 = AtomicU64::new(1); @@ -395,6 +401,88 @@ impl Runtime { }) } + /// Cursor pointing after the current worker-scoped protocol observation event. + #[cfg(feature = "ws-server")] + pub fn worker_observation_cursor_now( + &self, + worker_ref: &WorkerRef, + ) -> Result { + let state = self.lock()?; + state.ensure_worker_ref(worker_ref)?; + let sequence = state + .observation_events + .iter() + .rev() + .find(|event| &event.worker_ref == worker_ref) + .map(|event| event.sequence) + .unwrap_or(0); + Ok(WorkerObservationCursor::new(sequence)) + } + + /// Build the current Worker Snapshot event used as the first observation frame. + #[cfg(feature = "ws-server")] + pub fn worker_observation_snapshot( + &self, + worker_ref: &WorkerRef, + ) -> Result { + let state = self.lock()?; + let _worker = state.worker(worker_ref)?; + Ok(protocol::Event::Snapshot { + entries: Vec::new(), + greeting: protocol::Greeting { + worker_name: worker_ref.worker_id.to_string(), + cwd: String::new(), + provider: "worker-runtime".to_string(), + model: "worker-runtime".to_string(), + scope_summary: "runtime worker observation".to_string(), + tools: Vec::new(), + context_window: 0, + context_tokens: 0, + }, + status: protocol::WorkerStatus::Idle, + in_flight: protocol::InFlightSnapshot { blocks: Vec::new() }, + }) + } + + /// Replay retained worker-scoped protocol observation events after a cursor. + #[cfg(feature = "ws-server")] + pub fn read_worker_observation_events( + &self, + worker_ref: &WorkerRef, + cursor: WorkerObservationCursor, + ) -> Result, RuntimeError> { + let state = self.lock()?; + state.ensure_worker_ref(worker_ref)?; + state.validate_worker_observation_cursor(worker_ref, cursor)?; + Ok(state + .observation_events + .iter() + .filter(|event| &event.worker_ref == worker_ref && event.sequence > cursor.sequence) + .cloned() + .collect()) + } + + /// Subscribe to live protocol observation events. + #[cfg(feature = "ws-server")] + pub fn subscribe_worker_observation( + &self, + ) -> Result, RuntimeError> { + Ok(self.lock()?.observation_tx.subscribe()) + } + + /// Append a Worker protocol event to the observation bus. + #[cfg(feature = "ws-server")] + pub fn observe_worker_event( + &self, + worker_ref: &WorkerRef, + payload: protocol::Event, + ) -> Result { + let mut state = self.lock()?; + state.ensure_worker_ref(worker_ref)?; + let event = state.push_worker_observation_event(worker_ref.clone(), payload); + Ok(event) + } + /// Snapshot current diagnostics. pub fn diagnostics(&self) -> Result, RuntimeError> { Ok(self.lock()?.diagnostics.clone()) @@ -465,6 +553,12 @@ struct RuntimeState { workers: BTreeMap, events: Vec, diagnostics: Vec, + #[cfg(feature = "ws-server")] + next_observation_sequence: u64, + #[cfg(feature = "ws-server")] + observation_events: VecDeque, + #[cfg(feature = "ws-server")] + observation_tx: broadcast::Sender, } impl RuntimeState { @@ -482,6 +576,12 @@ impl RuntimeState { workers: BTreeMap::new(), events: Vec::new(), diagnostics: Vec::new(), + #[cfg(feature = "ws-server")] + next_observation_sequence: 1, + #[cfg(feature = "ws-server")] + observation_events: VecDeque::new(), + #[cfg(feature = "ws-server")] + observation_tx: broadcast::channel(256).0, } } @@ -505,6 +605,12 @@ impl RuntimeState { workers: BTreeMap::new(), events: Vec::new(), diagnostics: Vec::new(), + #[cfg(feature = "ws-server")] + next_observation_sequence: 1, + #[cfg(feature = "ws-server")] + observation_events: VecDeque::new(), + #[cfg(feature = "ws-server")] + observation_tx: broadcast::channel(256).0, } } @@ -762,6 +868,54 @@ impl RuntimeState { self.next_event_id.saturating_sub(1) } + #[cfg(feature = "ws-server")] + fn validate_worker_observation_cursor( + &self, + worker_ref: &WorkerRef, + cursor: WorkerObservationCursor, + ) -> Result<(), RuntimeError> { + if let Some(first) = self + .observation_events + .iter() + .find(|event| &event.worker_ref == worker_ref) + { + if cursor.sequence != 0 && cursor.sequence < first.sequence { + return Err(RuntimeError::InvalidRequest(format!( + "worker observation cursor {} is expired for worker {}", + cursor.encode(), + worker_ref.worker_id + ))); + } + } + if cursor.sequence >= self.next_observation_sequence { + return Err(RuntimeError::InvalidRequest(format!( + "worker observation cursor {} is unknown for worker {}", + cursor.encode(), + worker_ref.worker_id + ))); + } + Ok(()) + } + + #[cfg(feature = "ws-server")] + fn push_worker_observation_event( + &mut self, + worker_ref: WorkerRef, + payload: protocol::Event, + ) -> WorkerObservationEvent { + const MAX_OBSERVATION_BACKLOG: usize = 1024; + + let sequence = self.next_observation_sequence; + self.next_observation_sequence += 1; + let event = WorkerObservationEvent::new(sequence, worker_ref, payload); + self.observation_events.push_back(event.clone()); + while self.observation_events.len() > MAX_OBSERVATION_BACKLOG { + self.observation_events.pop_front(); + } + let _ = self.observation_tx.send(event.clone()); + event + } + fn push_diagnostic( &mut self, severity: DiagnosticSeverity, diff --git a/crates/workspace-server/Cargo.toml b/crates/workspace-server/Cargo.toml index cab80896..4fa7b487 100644 --- a/crates/workspace-server/Cargo.toml +++ b/crates/workspace-server/Cargo.toml @@ -7,10 +7,12 @@ publish = false [dependencies] async-trait.workspace = true -axum.workspace = true +axum = { workspace = true, features = ["ws"] } chrono = { version = "0.4", default-features = false, features = ["clock"] } manifest = { workspace = true } +futures.workspace = true pod-store = { workspace = true } +protocol = { workspace = true } project-record.workspace = true rusqlite.workspace = true serde = { workspace = true, features = ["derive"] } @@ -20,6 +22,8 @@ sha2.workspace = true thiserror.workspace = true ticket.workspace = true tokio = { workspace = true, features = ["fs", "macros", "net", "rt-multi-thread", "sync"] } +tokio-tungstenite.workspace = true +worker-runtime = { workspace = true, features = ["ws-server"] } toml.workspace = true tracing.workspace = true uuid = { workspace = true, features = ["v7"] } diff --git a/crates/workspace-server/src/lib.rs b/crates/workspace-server/src/lib.rs index 817c6f1e..5928c3b0 100644 --- a/crates/workspace-server/src/lib.rs +++ b/crates/workspace-server/src/lib.rs @@ -6,6 +6,7 @@ pub mod hosts; pub mod identity; +pub mod observation; pub mod records; pub mod repositories; pub mod server; diff --git a/crates/workspace-server/src/observation.rs b/crates/workspace-server/src/observation.rs new file mode 100644 index 00000000..a750877f --- /dev/null +++ b/crates/workspace-server/src/observation.rs @@ -0,0 +1,477 @@ +use std::collections::{BTreeMap, VecDeque}; +use std::sync::{Arc, Mutex}; + +use axum::http::StatusCode; +use futures::{SinkExt, StreamExt}; +use serde::{Deserialize, Serialize}; +use tokio_tungstenite::connect_async; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_tungstenite::tungstenite::{Error as TungsteniteError, Message as TungsteniteMessage}; +use worker_runtime::http_server::{RuntimeWorkerEventWsEnvelope, RuntimeWorkerEventWsFrame}; + +/// Backend-private source for a runtime worker observation stream. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct RuntimeObservationSourceConfig { + pub runtime_id: String, + pub worker_id: String, + pub endpoint: String, + pub bearer_token: Option, +} + +/// Event consumed from a Runtime-owned worker observation WebSocket. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RuntimeObservationUpstreamEvent { + pub runtime_id: String, + pub worker_id: String, + pub runtime_cursor: String, + pub payload: protocol::Event, +} + +/// Backend-local frame exposed to browser/future-TUI clients. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum ClientWorkerEventWsFrame { + Event { + envelope: ClientWorkerEventWsEnvelope, + }, + Diagnostic { + diagnostic: ClientWorkerEventWsDiagnostic, + }, +} + +/// Backend-owned opaque event envelope. It intentionally omits Runtime endpoints, +/// credentials, sockets and session paths. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ClientWorkerEventWsEnvelope { + pub cursor: String, + pub event_id: String, + pub runtime_id: String, + pub worker_id: String, + pub payload: protocol::Event, +} + +/// Client-facing typed observation diagnostic. +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct ClientWorkerEventWsDiagnostic { + pub code: String, + pub message: String, +} + +#[derive(Clone, Debug, Default, Deserialize)] +pub struct ClientWorkerEventsWsQuery { + pub cursor: Option, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ObservationProxyError { + RuntimeUnavailable(String), + WorkerNotFound(String), + CursorMalformed(String), + CursorUnknownOrExpired(String), + UpstreamDisconnect(String), + MalformedFrame(String), + ObservationOnly, +} + +impl ObservationProxyError { + pub fn code(&self) -> &'static str { + match self { + ObservationProxyError::RuntimeUnavailable(_) => "backend.runtime_unavailable", + ObservationProxyError::WorkerNotFound(_) => "backend.worker_not_found", + ObservationProxyError::CursorMalformed(_) => "backend.cursor_malformed", + ObservationProxyError::CursorUnknownOrExpired(_) => "backend.cursor_unknown_or_expired", + ObservationProxyError::UpstreamDisconnect(_) => "backend.upstream_disconnect", + ObservationProxyError::MalformedFrame(_) => "backend.malformed_frame", + ObservationProxyError::ObservationOnly => "backend.observation_only", + } + } + + pub fn message(&self) -> &str { + match self { + ObservationProxyError::RuntimeUnavailable(message) + | ObservationProxyError::WorkerNotFound(message) + | ObservationProxyError::CursorMalformed(message) + | ObservationProxyError::CursorUnknownOrExpired(message) + | ObservationProxyError::UpstreamDisconnect(message) + | ObservationProxyError::MalformedFrame(message) => message, + ObservationProxyError::ObservationOnly => { + "backend worker event WebSocket is observation-only" + } + } + } +} + +impl ClientWorkerEventWsFrame { + pub fn event(envelope: ClientWorkerEventWsEnvelope) -> Self { + Self::Event { envelope } + } + + pub fn diagnostic(error: ObservationProxyError) -> Self { + Self::Diagnostic { + diagnostic: ClientWorkerEventWsDiagnostic { + code: error.code().to_string(), + message: error.message().to_string(), + }, + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct BackendObservationCursor { + pub sequence: u64, +} + +impl BackendObservationCursor { + pub fn new(sequence: u64) -> Self { + Self { sequence } + } + + pub fn zero() -> Self { + Self { sequence: 0 } + } + + pub fn encode(self) -> String { + format!("bo_{:016x}", self.sequence) + } + + pub fn decode(value: &str) -> Option { + let encoded = value.strip_prefix("bo_")?; + if encoded.len() != 16 { + return None; + } + u64::from_str_radix(encoded, 16) + .ok() + .map(|sequence| Self { sequence }) + } +} + +#[derive(Debug, Default)] +struct BackendObservationState { + next_sequence: u64, + history: BTreeMap>, +} + +impl BackendObservationState { + fn new() -> Self { + Self { + next_sequence: 1, + history: BTreeMap::new(), + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +struct ObservationKey { + runtime_id: String, + worker_id: String, +} + +#[derive(Clone, Debug)] +struct StoredBackendEvent { + sequence: u64, + runtime_cursor: String, + envelope: ClientWorkerEventWsEnvelope, +} + +#[derive(Clone, Debug)] +pub struct BackendObservationOpen { + pub replay: Vec, + pub runtime_cursor: Option, + pub backend_cursor: BackendObservationCursor, +} + +/// Backend-owned in-memory v0 observation proxy state. +#[derive(Clone, Debug)] +pub struct BackendObservationProxy { + sources: Arc>, + state: Arc>, +} + +impl BackendObservationProxy { + pub fn new(sources: Vec) -> Self { + let sources = sources + .into_iter() + .map(|source| { + ( + ObservationKey { + runtime_id: source.runtime_id.clone(), + worker_id: source.worker_id.clone(), + }, + source, + ) + }) + .collect(); + Self { + sources: Arc::new(sources), + state: Arc::new(Mutex::new(BackendObservationState::new())), + } + } + + pub fn source( + &self, + runtime_id: &str, + worker_id: &str, + ) -> Result { + self.sources + .get(&ObservationKey { + runtime_id: runtime_id.to_string(), + worker_id: worker_id.to_string(), + }) + .cloned() + .ok_or_else(|| { + ObservationProxyError::WorkerNotFound(format!( + "worker {worker_id} is not registered for runtime {runtime_id}" + )) + }) + } + + pub fn open( + &self, + runtime_id: &str, + worker_id: &str, + cursor: Option<&str>, + ) -> Result { + let key = ObservationKey { + runtime_id: runtime_id.to_string(), + worker_id: worker_id.to_string(), + }; + let cursor = match cursor { + Some(raw) => BackendObservationCursor::decode(raw).ok_or_else(|| { + ObservationProxyError::CursorMalformed(format!( + "malformed backend observation cursor: {raw}" + )) + })?, + None => BackendObservationCursor::zero(), + }; + let state = self.state.lock().map_err(|_| { + ObservationProxyError::RuntimeUnavailable( + "backend observation state lock poisoned".into(), + ) + })?; + let history = state.history.get(&key); + let replay: Vec<_> = history + .into_iter() + .flat_map(|events| events.iter()) + .filter(|event| event.sequence > cursor.sequence) + .cloned() + .collect(); + if cursor.sequence != 0 { + let found = history + .into_iter() + .flat_map(|events| events.iter()) + .any(|event| event.sequence == cursor.sequence); + if !found { + return Err(ObservationProxyError::CursorUnknownOrExpired(format!( + "backend observation cursor {} is unknown or expired for runtime {runtime_id} worker {worker_id}", + cursor.encode() + ))); + } + } + let runtime_cursor = replay + .last() + .map(|event| event.runtime_cursor.clone()) + .or_else(|| { + history.and_then(|events| { + events + .iter() + .find(|event| event.sequence == cursor.sequence) + .map(|event| event.runtime_cursor.clone()) + }) + }); + Ok(BackendObservationOpen { + replay: replay.into_iter().map(|event| event.envelope).collect(), + runtime_cursor, + backend_cursor: cursor, + }) + } + + pub fn store( + &self, + event: RuntimeObservationUpstreamEvent, + ) -> Result { + let mut state = self.state.lock().map_err(|_| { + ObservationProxyError::RuntimeUnavailable( + "backend observation state lock poisoned".into(), + ) + })?; + let sequence = state.next_sequence; + state.next_sequence += 1; + let cursor = BackendObservationCursor::new(sequence).encode(); + let envelope = ClientWorkerEventWsEnvelope { + cursor: cursor.clone(), + event_id: cursor, + runtime_id: event.runtime_id.clone(), + worker_id: event.worker_id.clone(), + payload: event.payload, + }; + let key = ObservationKey { + runtime_id: event.runtime_id, + worker_id: event.worker_id, + }; + let history = state.history.entry(key).or_default(); + history.push_back(StoredBackendEvent { + sequence, + runtime_cursor: event.runtime_cursor, + envelope: envelope.clone(), + }); + while history.len() > 1024 { + history.pop_front(); + } + Ok(envelope) + } +} + +fn map_runtime_connect_error(error: TungsteniteError) -> ObservationProxyError { + match error { + TungsteniteError::Http(response) if response.status() == StatusCode::NOT_FOUND => { + ObservationProxyError::WorkerNotFound( + "runtime worker observation endpoint returned 404 not found".into(), + ) + } + TungsteniteError::Http(response) if response.status() == StatusCode::BAD_REQUEST => { + ObservationProxyError::CursorMalformed( + "runtime worker observation endpoint rejected the request as malformed".into(), + ) + } + TungsteniteError::Http(response) => ObservationProxyError::RuntimeUnavailable(format!( + "runtime worker observation endpoint rejected WebSocket upgrade with status {}", + response.status() + )), + error => ObservationProxyError::RuntimeUnavailable(format!( + "failed to connect runtime WebSocket: {error}" + )), + } +} + +fn map_runtime_diagnostic(code: String, message: String) -> ObservationProxyError { + match code.as_str() { + "runtime.worker_not_found" => ObservationProxyError::WorkerNotFound(message), + "runtime.cursor_malformed" => ObservationProxyError::CursorMalformed(message), + "runtime.cursor_unknown_or_expired" | "runtime.cursor_expired" => { + ObservationProxyError::CursorUnknownOrExpired(message) + } + "runtime.unavailable" => ObservationProxyError::RuntimeUnavailable(message), + "runtime.upstream_closed" | "runtime.websocket_error" => { + ObservationProxyError::UpstreamDisconnect(message) + } + "runtime.serialize_failed" => ObservationProxyError::MalformedFrame(message), + "runtime.observation_only" => ObservationProxyError::ObservationOnly, + _ => ObservationProxyError::RuntimeUnavailable(format!( + "runtime diagnostic {code}: {message}" + )), + } +} + +pub struct RuntimeWsObservationClient { + runtime_id: String, + worker_id: String, + stream: tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, +} + +impl RuntimeWsObservationClient { + pub async fn connect( + source: &RuntimeObservationSourceConfig, + runtime_cursor: Option<&str>, + ) -> Result { + let mut endpoint = source.endpoint.clone(); + if let Some(cursor) = runtime_cursor { + let separator = if endpoint.contains('?') { '&' } else { '?' }; + endpoint.push(separator); + endpoint.push_str("cursor="); + endpoint.push_str(cursor); + } + let mut request = endpoint.into_client_request().map_err(|error| { + ObservationProxyError::RuntimeUnavailable(format!( + "failed to build runtime WebSocket request: {error}" + )) + })?; + if let Some(token) = &source.bearer_token { + request.headers_mut().insert( + "authorization", + format!("Bearer {token}").parse().map_err(|error| { + ObservationProxyError::RuntimeUnavailable(format!( + "failed to build runtime authorization header: {error}" + )) + })?, + ); + } + let (stream, _) = connect_async(request) + .await + .map_err(map_runtime_connect_error)?; + Ok(Self { + runtime_id: source.runtime_id.clone(), + worker_id: source.worker_id.clone(), + stream, + }) + } + + pub async fn next_event( + &mut self, + ) -> Result { + loop { + let Some(message) = self.stream.next().await else { + return Err(ObservationProxyError::UpstreamDisconnect( + "runtime WebSocket closed".into(), + )); + }; + let message = message.map_err(|error| { + ObservationProxyError::UpstreamDisconnect(format!( + "runtime WebSocket receive error: {error}" + )) + })?; + let text = match message { + TungsteniteMessage::Text(text) => text, + TungsteniteMessage::Close(_) => { + return Err(ObservationProxyError::UpstreamDisconnect( + "runtime WebSocket closed".into(), + )); + } + TungsteniteMessage::Ping(payload) => { + self.stream + .send(TungsteniteMessage::Pong(payload)) + .await + .map_err(|error| { + ObservationProxyError::UpstreamDisconnect(format!( + "failed to reply to runtime ping: {error}" + )) + })?; + continue; + } + TungsteniteMessage::Pong(_) => continue, + TungsteniteMessage::Binary(_) | TungsteniteMessage::Frame(_) => { + return Err(ObservationProxyError::MalformedFrame( + "runtime sent a non-text observation frame".into(), + )); + } + }; + let frame: RuntimeWorkerEventWsFrame = + serde_json::from_str(&text).map_err(|error| { + ObservationProxyError::MalformedFrame(format!( + "failed to decode runtime observation frame: {error}" + )) + })?; + match frame { + RuntimeWorkerEventWsFrame::Event { envelope } => { + return Ok(self.map_envelope(envelope)); + } + RuntimeWorkerEventWsFrame::Diagnostic { diagnostic } => { + return Err(map_runtime_diagnostic(diagnostic.code, diagnostic.message)); + } + } + } + } + + fn map_envelope( + &self, + envelope: RuntimeWorkerEventWsEnvelope, + ) -> RuntimeObservationUpstreamEvent { + RuntimeObservationUpstreamEvent { + runtime_id: self.runtime_id.clone(), + worker_id: self.worker_id.clone(), + runtime_cursor: envelope.cursor, + payload: envelope.payload, + } + } +} diff --git a/crates/workspace-server/src/server.rs b/crates/workspace-server/src/server.rs index 983f8f9d..2bdf19cd 100644 --- a/crates/workspace-server/src/server.rs +++ b/crates/workspace-server/src/server.rs @@ -1,12 +1,14 @@ use std::path::{Component, Path, PathBuf}; use std::sync::Arc; +use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}; use axum::extract::{Path as AxumPath, Query, State}; use axum::http::header::CONTENT_TYPE; use axum::http::{StatusCode, Uri}; use axum::response::{IntoResponse, Response}; use axum::routing::get; use axum::{Json, Router}; +use futures::StreamExt; use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; @@ -15,6 +17,10 @@ use crate::hosts::{ RuntimeSummary, WorkerSummary, }; use crate::identity::WorkspaceIdentity; +use crate::observation::{ + BackendObservationProxy, ClientWorkerEventWsFrame, ClientWorkerEventsWsQuery, + ObservationProxyError, RuntimeObservationSourceConfig, RuntimeWsObservationClient, +}; use crate::records::{ LocalProjectRecordReader, ObjectiveDetail, ProjectRecordList, TicketDetail, TicketSummary, }; @@ -39,6 +45,7 @@ pub struct ServerConfig { pub auth: AuthConfig, pub max_records: usize, pub local_runtime_data_dir: Option, + pub runtime_event_sources: Vec, } impl ServerConfig { @@ -55,6 +62,7 @@ impl ServerConfig { }, max_records: 200, local_runtime_data_dir: manifest::paths::data_dir(), + runtime_event_sources: Vec::new(), } } } @@ -65,6 +73,7 @@ pub struct WorkspaceApi { store: Arc, records: LocalProjectRecordReader, runtime: Arc, + observation_proxy: BackendObservationProxy, } impl WorkspaceApi { @@ -83,11 +92,13 @@ impl WorkspaceApi { config.workspace_root.clone(), config.local_runtime_data_dir.clone(), ))); + let observation_proxy = BackendObservationProxy::new(config.runtime_event_sources.clone()); Ok(Self { records: LocalProjectRecordReader::new(config.workspace_root.clone()), config, store, runtime, + observation_proxy, }) } @@ -128,6 +139,10 @@ pub fn build_router(api: WorkspaceApi) -> Router { .route("/api/hosts", get(list_hosts)) .route("/api/runtimes", get(list_runtimes)) .route("/api/workers", get(list_workers)) + .route( + "/api/runtimes/{runtime_id}/workers/{worker_id}/events/ws", + get(worker_observation_ws), + ) .route("/api/hosts/{host_id}/workers", get(list_host_workers)) .fallback(get(static_or_spa_fallback)) .with_state(api) @@ -423,6 +438,144 @@ async fn list_workers( workers_response(api).map(Json) } +async fn worker_observation_ws( + State(api): State, + AxumPath((runtime_id, worker_id)): AxumPath<(String, String)>, + Query(query): Query, + ws: WebSocketUpgrade, +) -> impl IntoResponse { + match api.observation_proxy.source(&runtime_id, &worker_id) { + Ok(source) => ws.on_upgrade(move |socket| { + worker_observation_ws_session(api.observation_proxy, source, query, socket) + }), + Err(error) => { + let status = match error { + ObservationProxyError::WorkerNotFound(_) => StatusCode::NOT_FOUND, + _ => StatusCode::BAD_REQUEST, + }; + ( + status, + Json(serde_json::json!({ + "error": error.code(), + "message": error.message(), + })), + ) + .into_response() + } + } +} + +async fn worker_observation_ws_session( + proxy: BackendObservationProxy, + source: RuntimeObservationSourceConfig, + query: ClientWorkerEventsWsQuery, + mut socket: WebSocket, +) { + let open = match proxy.open( + &source.runtime_id, + &source.worker_id, + query.cursor.as_deref(), + ) { + Ok(open) => open, + Err(error) => { + let _ = send_client_ws_frame(&mut socket, ClientWorkerEventWsFrame::diagnostic(error)) + .await; + return; + } + }; + + let mut backend_cursor = open.backend_cursor; + for envelope in open.replay { + backend_cursor = crate::observation::BackendObservationCursor::decode(&envelope.cursor) + .unwrap_or(backend_cursor); + if !send_client_ws_frame(&mut socket, ClientWorkerEventWsFrame::event(envelope)).await { + return; + } + } + + let mut upstream = + match RuntimeWsObservationClient::connect(&source, open.runtime_cursor.as_deref()).await { + Ok(client) => client, + Err(error) => { + let _ = + send_client_ws_frame(&mut socket, ClientWorkerEventWsFrame::diagnostic(error)) + .await; + return; + } + }; + + loop { + tokio::select! { + inbound = socket.next() => { + match inbound { + Some(Ok(WsMessage::Close(_))) | None => return, + Some(Ok(WsMessage::Ping(payload))) => { + if socket.send(WsMessage::Pong(payload)).await.is_err() { + return; + } + } + Some(Ok(WsMessage::Pong(_))) => {} + Some(Ok(_)) => { + let _ = send_client_ws_frame( + &mut socket, + ClientWorkerEventWsFrame::diagnostic(ObservationProxyError::ObservationOnly), + ).await; + return; + } + Some(Err(error)) => { + let _ = send_client_ws_frame( + &mut socket, + ClientWorkerEventWsFrame::diagnostic( + ObservationProxyError::MalformedFrame(format!( + "client WebSocket receive error: {error}" + )), + ), + ).await; + return; + } + } + } + upstream_event = upstream.next_event() => { + match upstream_event { + Ok(event) => match proxy.store(event) { + Ok(envelope) => { + backend_cursor = crate::observation::BackendObservationCursor::decode(&envelope.cursor) + .unwrap_or(backend_cursor); + if !send_client_ws_frame(&mut socket, ClientWorkerEventWsFrame::event(envelope)).await { + return; + } + } + Err(error) => { + let _ = send_client_ws_frame(&mut socket, ClientWorkerEventWsFrame::diagnostic(error)).await; + return; + } + }, + Err(error) => { + let _ = send_client_ws_frame(&mut socket, ClientWorkerEventWsFrame::diagnostic(error)).await; + return; + } + } + } + } + } +} + +async fn send_client_ws_frame(socket: &mut WebSocket, frame: ClientWorkerEventWsFrame) -> bool { + match serde_json::to_string(&frame) { + Ok(text) => socket.send(WsMessage::Text(text.into())).await.is_ok(), + Err(error) => { + let fallback = + ClientWorkerEventWsFrame::diagnostic(ObservationProxyError::MalformedFrame( + format!("failed to serialize backend observation frame: {error}"), + )); + let Ok(text) = serde_json::to_string(&fallback) else { + return false; + }; + socket.send(WsMessage::Text(text.into())).await.is_ok() + } + } +} + async fn list_host_workers( State(api): State, AxumPath(host_id): AxumPath, @@ -636,9 +789,13 @@ mod tests { use super::*; use axum::body::{Body, to_bytes}; use axum::http::Request; + use futures::{SinkExt, StreamExt}; use serde_json::Value; + use tokio_tungstenite::connect_async; + use tokio_tungstenite::tungstenite::Message; use tower::ServiceExt; + use crate::observation::ClientWorkerEventWsDiagnostic; use crate::store::SqliteWorkspaceStore; const TEST_WORKSPACE_ID: &str = "0192f0e8-4d84-7d6e-a000-000000000001"; @@ -844,6 +1001,263 @@ mod tests { ); } + #[tokio::test] + async fn proxies_worker_observation_ws_with_backend_cursors_and_diagnostics() { + let runtime = worker_runtime::Runtime::new_memory(); + let worker = runtime + .create_worker(worker_runtime::catalog::CreateWorkerRequest::default()) + .unwrap(); + let runtime_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let runtime_addr = runtime_listener.local_addr().unwrap(); + tokio::spawn({ + let runtime = runtime.clone(); + async move { + worker_runtime::http_server::serve_runtime_http(runtime, runtime_listener, None) + .await + .unwrap() + } + }); + + let dir = tempfile::tempdir().unwrap(); + let store = SqliteWorkspaceStore::in_memory().unwrap(); + let mut config = ServerConfig::local_dev(dir.path(), test_identity()); + config.local_runtime_data_dir = Some(dir.path().join("data")); + config + .runtime_event_sources + .push(RuntimeObservationSourceConfig { + runtime_id: "runtime-a".into(), + worker_id: "worker-a".into(), + endpoint: format!( + "ws://{runtime_addr}/v1/workers/{}/events/ws", + worker.worker_ref.worker_id + ), + bearer_token: None, + }); + let api = WorkspaceApi::new(config, Arc::new(store)).await.unwrap(); + let app_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let app_addr = app_listener.local_addr().unwrap(); + tokio::spawn(async move { axum::serve(app_listener, build_router(api)).await.unwrap() }); + + let url = format!("ws://{app_addr}/api/runtimes/runtime-a/workers/worker-a/events/ws"); + let (mut stream, _) = connect_async(&url).await.unwrap(); + let snapshot = next_client_frame(&mut stream).await; + let ClientWorkerEventWsFrame::Event { envelope: snapshot } = snapshot else { + panic!("expected snapshot event"); + }; + assert_eq!(snapshot.runtime_id, "runtime-a"); + assert_eq!(snapshot.worker_id, "worker-a"); + assert!(matches!(snapshot.payload, protocol::Event::Snapshot { .. })); + + runtime + .observe_worker_event( + &worker.worker_ref, + protocol::Event::TextDelta { + text: "live".into(), + }, + ) + .unwrap(); + let live = next_client_frame(&mut stream).await; + let ClientWorkerEventWsFrame::Event { envelope: live } = live else { + panic!("expected live event"); + }; + assert_eq!(live.runtime_id, "runtime-a"); + assert_eq!(live.worker_id, "worker-a"); + assert!(matches!(live.payload, protocol::Event::TextDelta { .. })); + + let (mut resumed, _) = connect_async(format!("{url}?cursor={}", live.cursor)) + .await + .unwrap(); + let _snapshot = next_client_frame(&mut resumed).await; + runtime + .observe_worker_event( + &worker.worker_ref, + protocol::Event::TextDone { + text: "done".into(), + }, + ) + .unwrap(); + let resumed_event = next_client_frame(&mut resumed).await; + let ClientWorkerEventWsFrame::Event { + envelope: resumed_event, + } = resumed_event + else { + panic!("expected resumed live event"); + }; + assert_ne!(resumed_event.cursor, live.cursor); + assert!(matches!( + resumed_event.payload, + protocol::Event::TextDone { .. } + )); + + let (mut malformed, _) = connect_async(format!("{url}?cursor=bad")).await.unwrap(); + let diagnostic = next_client_frame(&mut malformed).await; + let ClientWorkerEventWsFrame::Diagnostic { diagnostic } = diagnostic else { + panic!("expected malformed cursor diagnostic"); + }; + assert_eq!(diagnostic.code, "backend.cursor_malformed"); + + stream.send(Message::Text("{}".into())).await.unwrap(); + let mut saw_observation_only = false; + for _ in 0..3 { + if let ClientWorkerEventWsFrame::Diagnostic { diagnostic } = + next_client_frame(&mut stream).await + { + assert_eq!(diagnostic.code, "backend.observation_only"); + saw_observation_only = true; + break; + } + } + assert!(saw_observation_only, "expected observation-only diagnostic"); + } + + #[tokio::test] + async fn proxy_reports_unknown_backend_cursor_before_upstream_connect() { + let source = RuntimeObservationSourceConfig { + runtime_id: "runtime-a".into(), + worker_id: "worker-a".into(), + endpoint: "ws://127.0.0.1:9/not-used".into(), + bearer_token: None, + }; + let (url, _dir) = spawn_workspace_proxy(source).await; + let (mut stream, _) = connect_async(format!("{url}?cursor=bo_ffffffffffffffff")) + .await + .unwrap(); + let diagnostic = next_client_diagnostic(&mut stream).await; + assert_eq!(diagnostic.code, "backend.cursor_unknown_or_expired"); + } + + #[tokio::test] + async fn proxy_maps_runtime_cursor_diagnostic_to_typed_backend_diagnostic() { + let (_runtime, _worker_ref, endpoint) = spawn_runtime_worker().await; + let source = RuntimeObservationSourceConfig { + runtime_id: "runtime-a".into(), + worker_id: "worker-a".into(), + endpoint: format!("{endpoint}?cursor=wo_ffffffffffffffff"), + bearer_token: None, + }; + let (url, _dir) = spawn_workspace_proxy(source).await; + let (mut stream, _) = connect_async(&url).await.unwrap(); + assert!(matches!( + next_client_frame(&mut stream).await, + ClientWorkerEventWsFrame::Event { envelope } if matches!(envelope.payload, protocol::Event::Snapshot { .. }) + )); + let diagnostic = next_client_diagnostic(&mut stream).await; + assert_eq!(diagnostic.code, "backend.cursor_unknown_or_expired"); + } + + #[tokio::test] + async fn proxy_maps_runtime_worker_not_found_http_404_to_typed_backend_diagnostic() { + let (_runtime, _worker_ref, endpoint) = spawn_runtime_worker().await; + let endpoint = endpoint.replace("/events/ws", "/missing-worker/events/ws"); + let source = RuntimeObservationSourceConfig { + runtime_id: "runtime-a".into(), + worker_id: "worker-a".into(), + endpoint, + bearer_token: None, + }; + let (url, _dir) = spawn_workspace_proxy(source).await; + let (mut stream, _) = connect_async(&url).await.unwrap(); + let diagnostic = next_client_diagnostic(&mut stream).await; + assert_eq!(diagnostic.code, "backend.worker_not_found"); + } + + #[tokio::test] + async fn proxy_reports_actual_upstream_disconnect_separately() { + let endpoint = spawn_closing_runtime_ws().await; + let source = RuntimeObservationSourceConfig { + runtime_id: "runtime-a".into(), + worker_id: "worker-a".into(), + endpoint, + bearer_token: None, + }; + let (url, _dir) = spawn_workspace_proxy(source).await; + let (mut stream, _) = connect_async(&url).await.unwrap(); + let diagnostic = next_client_diagnostic(&mut stream).await; + assert_eq!(diagnostic.code, "backend.upstream_disconnect"); + } + + async fn next_client_frame( + stream: &mut tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, + ) -> ClientWorkerEventWsFrame { + let message = stream.next().await.unwrap().unwrap(); + let Message::Text(text) = message else { + panic!("expected text frame"); + }; + serde_json::from_str(&text).unwrap() + } + + async fn next_client_diagnostic( + stream: &mut tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, + ) -> ClientWorkerEventWsDiagnostic { + match next_client_frame(stream).await { + ClientWorkerEventWsFrame::Diagnostic { diagnostic } => diagnostic, + ClientWorkerEventWsFrame::Event { envelope } => { + panic!("expected diagnostic, got event: {envelope:?}") + } + } + } + + async fn spawn_runtime_worker() -> ( + worker_runtime::Runtime, + worker_runtime::identity::WorkerRef, + String, + ) { + let runtime = worker_runtime::Runtime::new_memory(); + let worker = runtime + .create_worker(worker_runtime::catalog::CreateWorkerRequest::default()) + .unwrap(); + let runtime_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let runtime_addr = runtime_listener.local_addr().unwrap(); + tokio::spawn({ + let runtime = runtime.clone(); + async move { + worker_runtime::http_server::serve_runtime_http(runtime, runtime_listener, None) + .await + .unwrap() + } + }); + let endpoint = format!( + "ws://{runtime_addr}/v1/workers/{}/events/ws", + worker.worker_ref.worker_id + ); + (runtime, worker.worker_ref, endpoint) + } + + async fn spawn_workspace_proxy( + source: RuntimeObservationSourceConfig, + ) -> (String, tempfile::TempDir) { + let dir = tempfile::tempdir().unwrap(); + let store = SqliteWorkspaceStore::in_memory().unwrap(); + let mut config = ServerConfig::local_dev(dir.path(), test_identity()); + config.local_runtime_data_dir = Some(dir.path().join("data")); + let runtime_id = source.runtime_id.clone(); + let worker_id = source.worker_id.clone(); + config.runtime_event_sources.push(source); + let api = WorkspaceApi::new(config, Arc::new(store)).await.unwrap(); + let app_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let app_addr = app_listener.local_addr().unwrap(); + tokio::spawn(async move { axum::serve(app_listener, build_router(api)).await.unwrap() }); + ( + format!("ws://{app_addr}/api/runtimes/{runtime_id}/workers/{worker_id}/events/ws"), + dir, + ) + } + + async fn spawn_closing_runtime_ws() -> String { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let mut websocket = tokio_tungstenite::accept_async(stream).await.unwrap(); + let _ = websocket.close(None).await; + }); + format!("ws://{addr}/events/ws") + } + async fn get_json(app: Router, uri: &str) -> Value { let response = app .oneshot(Request::builder().uri(uri).body(Body::empty()).unwrap()) diff --git a/package.nix b/package.nix index c7d0e37d..662a3d63 100644 --- a/package.nix +++ b/package.nix @@ -43,7 +43,7 @@ rustPlatform.buildRustPackage rec { filter = sourceFilter; }; - cargoHash = "sha256-dv2MrgL0IB+ZisZQ9QnA0kdvKJtzEm0pKUpvofgqSB8="; + cargoHash = "sha256-5vmZTzO5PSRPHvQfiK0rNiBkHNyc0y3BCeDJNFJaAqA="; depsExtraArgs = { # Older fetchCargoVendor utilities used crates.io's API download endpoint,