yoi/crates/pod/tests/controller_test.rs

427 lines
13 KiB
Rust

use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use futures::Stream;
use llm_worker::llm_client::event::{Event as LlmEvent, ResponseStatus, StatusEvent};
use llm_worker::llm_client::{ClientError, LlmClient, Request};
use llm_worker::Worker;
use llm_worker_persistence::FsStore;
use pod::{
Event, Method, Pod, PodController, PodManifest, PodStatus,
};
// ---------------------------------------------------------------------------
// Mock LLM Client
// ---------------------------------------------------------------------------
/// Scripted `LlmClient` used by these tests.
///
/// Each `stream` call replays the next entry of `responses`. Once the script is
/// exhausted, further calls fail with an API error. Clones share both the script
/// and the call counter through `Arc`.
#[derive(Clone)]
struct MockClient {
    /// Number of `stream` calls made so far (shared across clones).
    call_count: Arc<AtomicUsize>,
    /// One scripted event sequence per expected `stream` call.
    responses: Arc<Vec<Vec<LlmEvent>>>,
}
impl MockClient {
    /// Builds a client whose script holds a single response, `events`.
    ///
    /// The first `stream` call replays it; every later call fails.
    fn new(events: Vec<LlmEvent>) -> Self {
        let script = vec![events];
        Self {
            call_count: Arc::new(AtomicUsize::new(0)),
            responses: Arc::new(script),
        }
    }
}
#[async_trait]
impl LlmClient for MockClient {
    /// Replays the next scripted response as a stream of `Ok` events.
    ///
    /// The request is ignored. Once every scripted response has been consumed,
    /// this returns `ClientError::Api` with status 500.
    async fn stream(
        &self,
        _request: Request,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmEvent, ClientError>> + Send>>, ClientError>
    {
        // The atomic increment gives every call its own script index.
        let index = self.call_count.fetch_add(1, Ordering::SeqCst);
        let events = match self.responses.get(index) {
            Some(script) => script.clone(),
            None => {
                return Err(ClientError::Api {
                    status: Some(500),
                    code: Some("mock".into()),
                    message: "No more responses".into(),
                })
            }
        };
        let replay = futures::stream::iter(events.into_iter().map(Ok));
        Ok(Box::pin(replay))
    }
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Event script for one successful turn.
///
/// A single text block (index 0) streams "Hello" then " World", closes, and is
/// followed by a `Completed` status event.
fn simple_text_events() -> Vec<LlmEvent> {
    let completed = LlmEvent::Status(StatusEvent {
        status: ResponseStatus::Completed,
    });
    vec![
        LlmEvent::text_block_start(0),
        LlmEvent::text_delta(0, "Hello"),
        LlmEvent::text_delta(0, " World"),
        LlmEvent::text_block_stop(0, None),
        completed,
    ]
}
/// Minimal pod manifest shared by every test.
///
/// It defines pod `test-pod`, provider kind `anthropic` with model `test-model`,
/// and `max_tokens = 100`. `make_pod` hands a `MockClient` to `Worker::new`
/// directly, so the provider settings are presumably only parsed metadata here
/// (NOTE(review): confirm no real client is built from them).
const MANIFEST_TOML: &str = r#"
[pod]
name = "test-pod"
[provider]
kind = "anthropic"
model = "test-model"
[worker]
max_tokens = 100
"#;
/// Builds a `Pod` from `MANIFEST_TOML`, driven by `client` and persisted to an
/// `FsStore` rooted in a fresh temporary directory.
///
/// The temporary directory is deliberately leaked, never deleted, so the store's
/// files stay on disk for the rest of the test.
async fn make_pod(client: MockClient) -> Pod<MockClient, FsStore> {
    let manifest = PodManifest::from_toml(MANIFEST_TOML).unwrap();
    let store_dir = tempfile::tempdir().unwrap();
    let store = FsStore::new(store_dir.path()).await.unwrap();
    // `forget` skips TempDir's Drop, which would otherwise remove the directory.
    std::mem::forget(store_dir);
    Pod::new(manifest, Worker::new(client), store, None)
        .await
        .unwrap()
}
use pod::PodHandle;
/// Spawns a `PodController` for `pod` and returns its handle.
///
/// The runtime directory is rooted in a temporary directory that is leaked on
/// purpose, so it survives for the whole test.
async fn spawn_controller(pod: Pod<MockClient, FsStore>) -> PodHandle {
    let runtime_base = {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().to_owned();
        // Leak the TempDir so its Drop doesn't delete the directory mid-test.
        std::mem::forget(tmp);
        path
    };
    PodController::spawn(pod, &runtime_base).await.unwrap()
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
/// A freshly spawned controller reports `Idle` before any method is sent.
#[tokio::test]
async fn shared_state_starts_idle() {
    let pod = make_pod(MockClient::new(simple_text_events())).await;
    let handle = spawn_controller(pod).await;
    let status = handle.shared_state.get_status();
    assert_eq!(status, PodStatus::Idle);
}
#[tokio::test]
async fn run_updates_shared_state_to_idle_after_completion() {
let client = MockClient::new(simple_text_events());
let pod = make_pod(client).await;
let handle = spawn_controller(pod).await;
handle
.send(Method::Run {
input: "Hello".into(),
})
.await
.unwrap();
// Wait for the run to complete
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert_eq!(handle.shared_state.get_status(), PodStatus::Idle);
}
/// A completed `Run` leaves at least a user and an assistant entry in the
/// shared history, serialized as a JSON array.
#[tokio::test]
async fn run_populates_history() {
    let client = MockClient::new(simple_text_events());
    let pod = make_pod(client).await;
    let handle = spawn_controller(pod).await;
    handle
        .send(Method::Run {
            input: "Hello".into(),
        })
        .await
        .unwrap();

    // Poll with a deadline rather than a fixed sleep. A fixed 100 ms wait is
    // flaky on loaded CI machines and needlessly slow on fast ones.
    let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(2);
    let entries = loop {
        let history = handle.shared_state.history_json();
        let parsed: serde_json::Value = serde_json::from_str(&history).unwrap();
        let len = parsed
            .as_array()
            .expect("history should serialize as a JSON array")
            .len();
        if len >= 2 || tokio::time::Instant::now() >= deadline {
            break len;
        }
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    };
    // Expect at least one user entry and one assistant entry.
    assert!(
        entries >= 2,
        "expected user + assistant history entries, got {}",
        entries
    );
}
#[tokio::test]
async fn events_are_broadcast() {
let client = MockClient::new(simple_text_events());
let pod = make_pod(client).await;
let handle = spawn_controller(pod).await;
let mut rx = handle.subscribe();
handle
.send(Method::Run {
input: "Hello".into(),
})
.await
.unwrap();
let mut saw_turn_start = false;
let mut saw_text_delta = false;
let mut saw_text_done = false;
let mut saw_turn_end = false;
// Collect events with a timeout
let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(2);
loop {
tokio::select! {
event = rx.recv() => {
match event {
Ok(Event::TurnStart { .. }) => saw_turn_start = true,
Ok(Event::TextDelta { .. }) => saw_text_delta = true,
Ok(Event::TextDone { .. }) => saw_text_done = true,
Ok(Event::TurnEnd { .. }) => {
saw_turn_end = true;
break;
}
Err(_) => break,
_ => {}
}
}
_ = tokio::time::sleep_until(deadline) => break,
}
}
assert!(saw_turn_start, "should see turn_start");
assert!(saw_text_delta, "should see text_delta");
assert!(saw_text_done, "should see text_done");
assert!(saw_turn_end, "should see turn_end");
}
/// A second `Run` sent while the first is still in progress yields an
/// `AlreadyRunning` error event.
///
/// NOTE(review): `MockClient` replays its events immediately, so nothing here
/// actually keeps the first run busy. The test depends on the controller
/// handling the second `Run` before the first finishes, which is a timing race.
/// A mock stream that never completes (e.g. ending in a never-ready stream)
/// would make it deterministic.
#[tokio::test]
async fn double_run_returns_error() {
    // A truncated turn: the text block starts but no block stop or `Completed`
    // status follows. The mock still delivers these events without delay.
    let events = vec![
        LlmEvent::text_block_start(0),
        LlmEvent::text_delta(0, "slow..."),
    ];
    let client = MockClient::new(events);
    let pod = make_pod(client).await;
    let handle = spawn_controller(pod).await;
    let mut rx = handle.subscribe();
    // First run.
    handle
        .send(Method::Run {
            input: "first".into(),
        })
        .await
        .unwrap();
    // Second run, sent immediately, is expected to be rejected.
    handle
        .send(Method::Run {
            input: "second".into(),
        })
        .await
        .unwrap();

    // Look for the AlreadyRunning error; other error codes are ignored.
    let mut saw_already_running = false;
    let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(2);
    loop {
        tokio::select! {
            event = rx.recv() => {
                match event {
                    Ok(Event::Error { code, .. }) if code == pod::ErrorCode::AlreadyRunning => {
                        saw_already_running = true;
                        break;
                    }
                    Err(_) => break,
                    _ => {}
                }
            }
            _ = tokio::time::sleep_until(deadline) => break,
        }
    }
    assert!(saw_already_running, "should see already_running error");
}
#[tokio::test]
async fn resume_without_pause_returns_error() {
let client = MockClient::new(simple_text_events());
let pod = make_pod(client).await;
let handle = spawn_controller(pod).await;
let mut rx = handle.subscribe();
handle.send(Method::Resume).await.unwrap();
let mut saw_not_paused = false;
let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(1);
loop {
tokio::select! {
event = rx.recv() => {
match event {
Ok(Event::Error { code, .. }) if code == pod::ErrorCode::NotPaused => {
saw_not_paused = true;
break;
}
Err(_) => break,
_ => {}
}
}
_ = tokio::time::sleep_until(deadline) => break,
}
}
assert!(saw_not_paused, "should see not_paused error");
}
#[tokio::test]
async fn cancel_without_run_returns_error() {
let client = MockClient::new(simple_text_events());
let pod = make_pod(client).await;
let handle = spawn_controller(pod).await;
let mut rx = handle.subscribe();
handle.send(Method::Cancel).await.unwrap();
let mut saw_not_running = false;
let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(1);
loop {
tokio::select! {
event = rx.recv() => {
match event {
Ok(Event::Error { code, .. }) if code == pod::ErrorCode::NotRunning => {
saw_not_running = true;
break;
}
Err(_) => break,
_ => {}
}
}
_ = tokio::time::sleep_until(deadline) => break,
}
}
assert!(saw_not_running, "should see not_running error");
}
/// `status_json` exposes the manifest's pod name under the `pod_name` key.
#[tokio::test]
async fn status_json_reflects_pod_name() {
    let pod = make_pod(MockClient::new(simple_text_events())).await;
    let handle = spawn_controller(pod).await;
    let status: serde_json::Value =
        serde_json::from_str(&handle.shared_state.status_json()).unwrap();
    assert_eq!(status["pod_name"], "test-pod");
}
// ---------------------------------------------------------------------------
// Socket transport tests
// ---------------------------------------------------------------------------
/// A client on the pod's Unix socket can send `Run` as a JSON line and receives
/// the turn's events back over the same connection.
#[tokio::test]
async fn socket_run_receives_events() {
    use protocol::stream::{JsonLineReader, JsonLineWriter};
    use tokio::net::UnixStream;
    let client = MockClient::new(simple_text_events());
    let pod = make_pod(client).await;
    let handle = spawn_controller(pod).await;
    let sock_path = handle.runtime_dir.socket_path();
    // The socket may not be bound yet when `spawn` returns. Retry the connect
    // until a deadline instead of sleeping a fixed 50 ms, which is flaky under load.
    let connect_deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(2);
    let stream = loop {
        match UnixStream::connect(&sock_path).await {
            Ok(stream) => break stream,
            Err(_) if tokio::time::Instant::now() < connect_deadline => {
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            }
            Err(err) => panic!("could not connect to {:?}: {}", sock_path, err),
        }
    };
    let (reader, writer) = stream.into_split();
    let mut reader = JsonLineReader::new(reader);
    let mut writer = JsonLineWriter::new(writer);
    // Send the run method over the socket.
    writer
        .write(&Method::Run {
            input: "Hello".into(),
        })
        .await
        .unwrap();
    // Collect events until TurnEnd, EOF/error, or the deadline.
    let mut saw_turn_start = false;
    let mut saw_text_delta = false;
    let mut saw_turn_end = false;
    let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(2);
    loop {
        tokio::select! {
            event = reader.next::<Event>() => {
                match event {
                    Ok(Some(Event::TurnStart { .. })) => saw_turn_start = true,
                    Ok(Some(Event::TextDelta { .. })) => saw_text_delta = true,
                    Ok(Some(Event::TurnEnd { .. })) => {
                        saw_turn_end = true;
                        break;
                    }
                    Ok(None) | Err(_) => break,
                    _ => {}
                }
            }
            _ = tokio::time::sleep_until(deadline) => break,
        }
    }
    assert!(saw_turn_start, "should see turn_start via socket");
    assert!(saw_text_delta, "should see text_delta via socket");
    assert!(saw_turn_end, "should see turn_end via socket");
}
/// A line that does not decode as a `Method` makes the server reply with an
/// `Error` event on the same connection.
#[tokio::test]
async fn socket_invalid_method_returns_error() {
    use protocol::stream::JsonLineReader;
    use tokio::io::AsyncWriteExt;
    use tokio::net::UnixStream;
    let client = MockClient::new(simple_text_events());
    let pod = make_pod(client).await;
    let handle = spawn_controller(pod).await;
    let sock_path = handle.runtime_dir.socket_path();
    // The socket may not be bound yet when `spawn` returns. Retry the connect
    // until a deadline instead of sleeping a fixed 50 ms, which is flaky under load.
    let connect_deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(2);
    let stream = loop {
        match UnixStream::connect(&sock_path).await {
            Ok(stream) => break stream,
            Err(_) if tokio::time::Instant::now() < connect_deadline => {
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            }
            Err(err) => panic!("could not connect to {:?}: {}", sock_path, err),
        }
    };
    let (reader, mut writer) = stream.into_split();
    let mut reader = JsonLineReader::new(reader);
    // Valid JSON, but not a known method.
    writer.write_all(b"{\"bad\":\"json\"}\n").await.unwrap();
    let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(1);
    let mut saw_error = false;
    loop {
        tokio::select! {
            event = reader.next::<Event>() => {
                match event {
                    Ok(Some(Event::Error { .. })) => {
                        saw_error = true;
                        break;
                    }
                    Ok(None) | Err(_) => break,
                    _ => {}
                }
            }
            _ = tokio::time::sleep_until(deadline) => break,
        }
    }
    assert!(saw_error, "should see error for invalid method");
}