401 lines
13 KiB
Rust
401 lines
13 KiB
Rust
//! Persistent session wrapper around [`Worker`].
//!
//! [`Session`] intercepts `Worker` operations and appends [`HashedEntry`] records
//! to a [`Store`]. It does not modify `Worker` internals — all persistence
//! happens by observing state before and after each operation.
//!
//! Each appended entry carries a hash that chains to the previous entry.
//! Before each `run()` / `resume()`, the session checks whether the store's head
//! still matches its own `head_hash`; if not, it auto-forks into a new session
//! and continues logging there.
|
|
|
|
use crate::session_log::{self, EntryHash, HashedEntry, LogEntry, Outcome};
|
|
use crate::store::{Store, StoreError};
|
|
use crate::SessionId;
|
|
use llm_worker::llm_client::client::LlmClient;
|
|
use llm_worker::state::Mutable;
|
|
use llm_worker::{Worker, WorkerError, WorkerResult};
|
|
|
|
/// Configuration for session persistence.
///
/// `Default` is derived: every option starts at its type's default value,
/// which for `record_event_trace` is `false`.
#[derive(Debug, Clone, Default)]
pub struct SessionConfig {
    /// Record raw stream events to a separate trace file.
    /// Default: `false`.
    pub record_event_trace: bool,
}
|
|
|
|
/// Errors from session operations.
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum SessionError {
|
|
#[error(transparent)]
|
|
Worker(#[from] WorkerError),
|
|
|
|
#[error(transparent)]
|
|
Store(#[from] StoreError),
|
|
}
|
|
|
|
/// Persistent session wrapping a [`Worker`].
|
|
///
|
|
/// Use [`worker()`](Self::worker) / [`worker_mut()`](Self::worker_mut) to
|
|
/// access the underlying Worker for configuration (tool registration, etc.).
|
|
/// State-mutating operations (`run`, `resume`) should go through Session
|
|
/// methods to ensure proper logging.
|
|
pub struct Session<C: LlmClient, St: Store> {
|
|
/// Always `Some` outside of `run()` / `resume()`.
|
|
worker: Option<Worker<C, Mutable>>,
|
|
store: St,
|
|
session_id: SessionId,
|
|
head_hash: Option<EntryHash>,
|
|
_config: SessionConfig,
|
|
}
|
|
|
|
impl<C: LlmClient, St: Store> Session<C, St> {
|
|
/// Create a new session, writing the initial `SessionStart` entry.
|
|
pub async fn new(
|
|
worker: Worker<C, Mutable>,
|
|
store: St,
|
|
config: SessionConfig,
|
|
) -> Result<Self, StoreError> {
|
|
let session_id = crate::new_session_id();
|
|
let entry = LogEntry::SessionStart {
|
|
ts: session_log::now_millis(),
|
|
system_prompt: worker.get_system_prompt().map(String::from),
|
|
config: worker.request_config().clone(),
|
|
history: worker.history().to_vec(),
|
|
};
|
|
let hashed = session_log::compute_hash(None, &entry);
|
|
let hashed_entry = HashedEntry {
|
|
hash: hashed.clone(),
|
|
prev_hash: None,
|
|
entry,
|
|
};
|
|
store.append(session_id, &hashed_entry).await?;
|
|
|
|
Ok(Self {
|
|
worker: Some(worker),
|
|
store,
|
|
session_id,
|
|
head_hash: Some(hashed),
|
|
_config: config,
|
|
})
|
|
}
|
|
|
|
/// Restore a session from a stored log.
|
|
pub async fn restore(
|
|
client: C,
|
|
store: St,
|
|
session_id: SessionId,
|
|
config: SessionConfig,
|
|
) -> Result<Self, SessionError> {
|
|
let entries = store.read_all(session_id).await?;
|
|
let state = session_log::collect_state(&entries);
|
|
|
|
let mut worker = Worker::new(client);
|
|
if let Some(ref prompt) = state.system_prompt {
|
|
worker.set_system_prompt(prompt);
|
|
}
|
|
worker.set_history(state.history);
|
|
worker.set_request_config(state.config);
|
|
worker.set_turn_count(state.turn_count);
|
|
worker.set_last_run_interrupted(state.last_run_interrupted);
|
|
|
|
Ok(Self {
|
|
worker: Some(worker),
|
|
store,
|
|
session_id,
|
|
head_hash: state.head_hash,
|
|
_config: config,
|
|
})
|
|
}
|
|
|
|
fn w(&self) -> &Worker<C, Mutable> {
|
|
self.worker.as_ref().expect("worker taken during run")
|
|
}
|
|
|
|
/// Reference to the underlying Worker.
|
|
pub fn worker(&self) -> &Worker<C, Mutable> {
|
|
self.w()
|
|
}
|
|
|
|
/// Mutable reference to the underlying Worker.
|
|
pub fn worker_mut(&mut self) -> &mut Worker<C, Mutable> {
|
|
self.worker.as_mut().expect("worker taken during run")
|
|
}
|
|
|
|
/// The session ID.
|
|
pub fn session_id(&self) -> SessionId {
|
|
self.session_id
|
|
}
|
|
|
|
/// The current head hash of the session log chain.
|
|
pub fn head_hash(&self) -> Option<&EntryHash> {
|
|
self.head_hash.as_ref()
|
|
}
|
|
|
|
/// Reference to the underlying store.
|
|
pub fn store(&self) -> &St {
|
|
&self.store
|
|
}
|
|
|
|
/// Run a user turn, logging all state changes.
|
|
///
|
|
/// Internally locks the Worker (flushing pending tools), runs the turn,
|
|
/// then unlocks back to Mutable state.
|
|
pub async fn run(
|
|
&mut self,
|
|
user_input: impl Into<String>,
|
|
) -> Result<WorkerResult, SessionError> {
|
|
let input = user_input.into();
|
|
self.ensure_head_or_fork().await?;
|
|
|
|
let history_before = self.w().history().len();
|
|
|
|
// lock → run → unlock (use lock() directly to keep worker on error)
|
|
let worker = self.worker.take().expect("worker taken during run");
|
|
let mut locked = worker.lock();
|
|
let result = locked.run(input).await;
|
|
self.worker = Some(locked.unlock());
|
|
|
|
self.log_history_delta(history_before).await?;
|
|
self.log_turn_end().await?;
|
|
self.log_outcome(&result).await?;
|
|
|
|
result.map_err(SessionError::Worker)
|
|
}
|
|
|
|
/// Resume from a paused state, logging all state changes.
|
|
pub async fn resume(&mut self) -> Result<WorkerResult, SessionError> {
|
|
self.ensure_head_or_fork().await?;
|
|
|
|
let history_before = self.w().history().len();
|
|
|
|
// lock → resume → unlock
|
|
let worker = self.worker.take().expect("worker taken during run");
|
|
let mut locked = worker.lock();
|
|
let result = locked.resume().await;
|
|
self.worker = Some(locked.unlock());
|
|
|
|
self.log_history_delta(history_before).await?;
|
|
self.log_turn_end().await?;
|
|
self.log_outcome(&result).await?;
|
|
|
|
result.map_err(SessionError::Worker)
|
|
}
|
|
|
|
/// Fork this session at its current state.
|
|
pub async fn fork(&self) -> Result<SessionId, StoreError> {
|
|
let fork_id = crate::new_session_id();
|
|
let entry = LogEntry::SessionStart {
|
|
ts: session_log::now_millis(),
|
|
system_prompt: self.w().get_system_prompt().map(String::from),
|
|
config: self.w().request_config().clone(),
|
|
history: self.w().history().to_vec(),
|
|
};
|
|
let hashed = session_log::compute_hash(None, &entry);
|
|
let hashed_entry = HashedEntry {
|
|
hash: hashed,
|
|
prev_hash: None,
|
|
entry,
|
|
};
|
|
self.store
|
|
.create_session(fork_id, &[hashed_entry])
|
|
.await?;
|
|
Ok(fork_id)
|
|
}
|
|
|
|
/// Fork from an arbitrary point in a stored session's log.
|
|
pub async fn fork_at(
|
|
store: &St,
|
|
source_id: SessionId,
|
|
at_hash: &EntryHash,
|
|
) -> Result<SessionId, StoreError> {
|
|
let entries = store.read_all(source_id).await?;
|
|
let cut = entries
|
|
.iter()
|
|
.position(|e| &e.hash == at_hash)
|
|
.map(|i| i + 1)
|
|
.unwrap_or(entries.len());
|
|
let state = session_log::collect_state(&entries[..cut]);
|
|
|
|
let fork_id = crate::new_session_id();
|
|
let entry = LogEntry::SessionStart {
|
|
ts: session_log::now_millis(),
|
|
system_prompt: state.system_prompt,
|
|
config: state.config,
|
|
history: state.history,
|
|
};
|
|
let hashed = session_log::compute_hash(None, &entry);
|
|
let hashed_entry = HashedEntry {
|
|
hash: hashed,
|
|
prev_hash: None,
|
|
entry,
|
|
};
|
|
store.create_session(fork_id, &[hashed_entry]).await?;
|
|
Ok(fork_id)
|
|
}
|
|
|
|
/// Log a `Locked` entry.
|
|
pub async fn log_cache_locked(
|
|
&mut self,
|
|
locked_prefix_len: usize,
|
|
) -> Result<(), StoreError> {
|
|
let entry = LogEntry::Locked {
|
|
ts: session_log::now_millis(),
|
|
locked_prefix_len,
|
|
};
|
|
self.append_entry(entry).await
|
|
}
|
|
|
|
/// Log a `CacheUnlocked` entry.
|
|
pub async fn log_cache_unlocked(&mut self) -> Result<(), StoreError> {
|
|
let entry = LogEntry::CacheUnlocked {
|
|
ts: session_log::now_millis(),
|
|
};
|
|
self.append_entry(entry).await
|
|
}
|
|
|
|
/// Log a `ConfigChanged` entry.
|
|
pub async fn log_config_changed(&mut self) -> Result<(), StoreError> {
|
|
let entry = LogEntry::ConfigChanged {
|
|
ts: session_log::now_millis(),
|
|
config: self.w().request_config().clone(),
|
|
};
|
|
self.append_entry(entry).await
|
|
}
|
|
|
|
// ── Private helpers ──────────────────────────────────────────────────
|
|
|
|
async fn append_entry(&mut self, entry: LogEntry) -> Result<(), StoreError> {
|
|
let hash = session_log::compute_hash(self.head_hash.as_ref(), &entry);
|
|
let hashed_entry = HashedEntry {
|
|
hash: hash.clone(),
|
|
prev_hash: self.head_hash.clone(),
|
|
entry,
|
|
};
|
|
self.store
|
|
.append(self.session_id, &hashed_entry)
|
|
.await?;
|
|
self.head_hash = Some(hash);
|
|
Ok(())
|
|
}
|
|
|
|
async fn ensure_head_or_fork(&mut self) -> Result<(), StoreError> {
|
|
let store_head = self.store.read_head_hash(self.session_id).await?;
|
|
if store_head == self.head_hash {
|
|
return Ok(());
|
|
}
|
|
let fork_id = crate::new_session_id();
|
|
let entry = LogEntry::SessionStart {
|
|
ts: session_log::now_millis(),
|
|
system_prompt: self.w().get_system_prompt().map(String::from),
|
|
config: self.w().request_config().clone(),
|
|
history: self.w().history().to_vec(),
|
|
};
|
|
let hash = session_log::compute_hash(None, &entry);
|
|
let hashed_entry = HashedEntry {
|
|
hash: hash.clone(),
|
|
prev_hash: None,
|
|
entry,
|
|
};
|
|
self.store
|
|
.create_session(fork_id, &[hashed_entry])
|
|
.await?;
|
|
self.session_id = fork_id;
|
|
self.head_hash = Some(hash);
|
|
Ok(())
|
|
}
|
|
|
|
async fn log_history_delta(&mut self, before_len: usize) -> Result<(), StoreError> {
|
|
let history = self.w().history();
|
|
if history.len() <= before_len {
|
|
return Ok(());
|
|
}
|
|
|
|
let ts = session_log::now_millis();
|
|
let new_items = history[before_len..].to_vec();
|
|
let mut i = 0;
|
|
|
|
while i < new_items.len() {
|
|
let item = &new_items[i];
|
|
if item.is_user_message() {
|
|
self.append_entry(LogEntry::UserInput {
|
|
ts,
|
|
item: new_items[i].clone(),
|
|
})
|
|
.await?;
|
|
i += 1;
|
|
} else if item.is_tool_result() {
|
|
let start = i;
|
|
while i < new_items.len() && new_items[i].is_tool_result() {
|
|
i += 1;
|
|
}
|
|
self.append_entry(LogEntry::ToolResults {
|
|
ts,
|
|
items: new_items[start..i].to_vec(),
|
|
})
|
|
.await?;
|
|
} else if item.is_assistant_message()
|
|
|| item.is_tool_call()
|
|
|| item.is_reasoning()
|
|
{
|
|
let start = i;
|
|
while i < new_items.len()
|
|
&& (new_items[i].is_assistant_message()
|
|
|| new_items[i].is_tool_call()
|
|
|| new_items[i].is_reasoning())
|
|
{
|
|
i += 1;
|
|
}
|
|
self.append_entry(LogEntry::AssistantItems {
|
|
ts,
|
|
items: new_items[start..i].to_vec(),
|
|
})
|
|
.await?;
|
|
} else {
|
|
self.append_entry(LogEntry::HookInjectedItems {
|
|
ts,
|
|
items: vec![new_items[i].clone()],
|
|
})
|
|
.await?;
|
|
i += 1;
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn log_turn_end(&mut self) -> Result<(), StoreError> {
|
|
self.append_entry(LogEntry::TurnEnd {
|
|
ts: session_log::now_millis(),
|
|
turn_count: self.w().turn_count(),
|
|
})
|
|
.await
|
|
}
|
|
|
|
async fn log_outcome(
|
|
&mut self,
|
|
result: &Result<WorkerResult, WorkerError>,
|
|
) -> Result<(), StoreError> {
|
|
let outcome = match result {
|
|
Ok(WorkerResult::Finished) => Outcome::Finished,
|
|
Ok(WorkerResult::Paused) => Outcome::Paused,
|
|
Ok(WorkerResult::LimitReached) => Outcome::LimitReached,
|
|
Err(e) => Outcome::Error {
|
|
message: e.to_string(),
|
|
},
|
|
};
|
|
self.append_entry(LogEntry::RunOutcome {
|
|
ts: session_log::now_millis(),
|
|
outcome,
|
|
interrupted: self.w().last_run_interrupted(),
|
|
})
|
|
.await
|
|
}
|
|
}
|