diff --git a/crates/client/src/pod_client.rs b/crates/client/src/pod_client.rs index 0988727f..e56324ef 100644 --- a/crates/client/src/pod_client.rs +++ b/crates/client/src/pod_client.rs @@ -5,10 +5,12 @@ use protocol::stream::{JsonLineReader, JsonLineWriter}; use protocol::{Event, Method}; use tokio::net::UnixStream; use tokio::sync::mpsc; +use tokio::task::JoinHandle; pub struct PodClient { writer: JsonLineWriter>, event_rx: mpsc::Receiver, + reader_task: JoinHandle<()>, } impl PodClient { @@ -19,7 +21,7 @@ impl PodClient { let (event_tx, event_rx) = mpsc::channel::(256); - tokio::spawn(async move { + let reader_task = tokio::spawn(async move { let mut reader = JsonLineReader::new(reader); while let Ok(Some(event)) = reader.next::().await { if event_tx.send(event).await.is_err() { @@ -28,7 +30,11 @@ impl PodClient { } }); - Ok(Self { writer, event_rx }) + Ok(Self { + writer, + event_rx, + reader_task, + }) } pub async fn send(&mut self, method: &Method) -> Result<(), io::Error> { @@ -43,3 +49,138 @@ impl PodClient { self.event_rx.recv().await } } + +impl Drop for PodClient { + fn drop(&mut self) { + self.reader_task.abort(); + } +} + +#[cfg(test)] +mod tests { + use std::io::ErrorKind; + use std::time::Duration; + + use protocol::{PodStatus, Segment}; + use tempfile::tempdir; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::UnixListener; + + use super::*; + + async fn assert_peer_closed(stream: &mut UnixStream, reason: &str) { + let mut buf = [0_u8; 1]; + match tokio::time::timeout(Duration::from_secs(1), stream.read(&mut buf)) + .await + .expect(reason) + { + Ok(0) => {} + Err(error) if error.kind() == ErrorKind::ConnectionReset => {} + Ok(n) => panic!("server should observe peer close, read {n} byte(s)"), + Err(error) => panic!("server read failed unexpectedly: {error}"), + } + } + + #[tokio::test] + async fn receives_events_while_client_is_alive() { + let socket_dir = tempdir().unwrap(); + let socket_path = socket_dir.path().join("events.sock"); + let listener = UnixListener::bind(&socket_path).unwrap(); + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let mut writer = JsonLineWriter::new(stream); + writer + .write(&Event::Status { + status: PodStatus::Idle, + }) + .await + .unwrap(); + }); + + let mut client = PodClient::connect(&socket_path).await.unwrap(); + + let event = tokio::time::timeout(Duration::from_secs(1), client.next_event()) + .await + .expect("client should receive event while alive"); + assert!(matches!( + event, + Some(Event::Status { + status: PodStatus::Idle + }) + )); + server.await.unwrap(); + } + + #[tokio::test] + async fn send_writes_methods_while_client_is_alive() { + let socket_dir = tempdir().unwrap(); + let socket_path = socket_dir.path().join("send.sock"); + let listener = UnixListener::bind(&socket_path).unwrap(); + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let mut reader = JsonLineReader::new(stream); + reader.next::().await.unwrap() + }); + + let mut client = PodClient::connect(&socket_path).await.unwrap(); + let method = Method::Run { + input: vec![Segment::text("hello")], + }; + client.send(&method).await.unwrap(); + + let received = tokio::time::timeout(Duration::from_secs(1), server) + .await + .expect("server should receive method while client is alive") + .unwrap(); + match received { + Some(Method::Run { input }) => assert_eq!(input, vec![Segment::text("hello")]), + other => panic!("expected Run method, got {other:?}"), + } + } + + #[tokio::test] + async fn dropping_repeated_clients_closes_server_connections() { + let socket_dir = tempdir().unwrap(); + let socket_path = socket_dir.path().join("drop.sock"); + let listener = UnixListener::bind(&socket_path).unwrap(); + let server = tokio::spawn(async move { + for _ in 0..16 { + let (mut stream, _) = listener.accept().await.unwrap(); + assert_peer_closed( + &mut stream, + "dropped client should close its socket promptly", + ) + .await; + } + }); + + for _ in 0..16 { + let client = PodClient::connect(&socket_path).await.unwrap(); + drop(client); + } + + server.await.unwrap(); + } + + #[tokio::test] + async fn dropping_client_aborts_blocked_reader_task() { + let socket_dir = tempdir().unwrap(); + let socket_path = socket_dir.path().join("blocked-reader.sock"); + let listener = UnixListener::bind(&socket_path).unwrap(); + let server = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + stream.write_all(b"{\"event\"").await.unwrap(); + assert_peer_closed( + &mut stream, + "aborting the blocked client reader should close the socket", + ) + .await; + }); + + let client = PodClient::connect(&socket_path).await.unwrap(); + tokio::task::yield_now().await; + drop(client); + + server.await.unwrap(); + } +}