merge: serialize file mutations
This commit is contained in:
commit
29960c1589
|
|
@ -23,7 +23,7 @@ serde_json = { workspace = true }
|
|||
sha2 = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true, features = ["process", "rt", "time"] }
|
||||
tokio = { workspace = true, features = ["process", "rt", "sync", "time"] }
|
||||
tracing = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ impl Tool for EditTool {
|
|||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: EditParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid Edit input: {e}")))?;
|
||||
|
|
@ -61,6 +61,8 @@ impl Tool for EditTool {
|
|||
));
|
||||
}
|
||||
|
||||
let _mutation_permit = self.tracker.acquire_mutation(&params.file_path, &ctx).await;
|
||||
|
||||
// Load current content and verify it matches the recorded hash.
|
||||
let current_bytes = self.fs.read_bytes(&params.file_path)?;
|
||||
self.tracker.verify(&params.file_path, &current_bytes)?;
|
||||
|
|
|
|||
|
|
@ -38,7 +38,9 @@
|
|||
//! ```
|
||||
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::path::{Component, Path, PathBuf};
|
||||
|
||||
use llm_worker::tool::ToolExecutionContext;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use sha2::{Digest, Sha256};
|
||||
|
|
@ -58,6 +60,60 @@ fn hash_bytes(bytes: &[u8]) -> ContentHash {
|
|||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
struct FileMutationCoordinator {
|
||||
locks: Arc<tokio::sync::Mutex<HashMap<PathBuf, Arc<tokio::sync::Mutex<()>>>>>,
|
||||
}
|
||||
|
||||
pub(crate) struct FileMutationPermit {
|
||||
_guard: tokio::sync::OwnedMutexGuard<()>,
|
||||
}
|
||||
|
||||
impl FileMutationCoordinator {
|
||||
async fn acquire(&self, path: &Path) -> FileMutationPermit {
|
||||
let key = file_mutation_key(path);
|
||||
let lock = {
|
||||
let mut locks = self.locks.lock().await;
|
||||
locks
|
||||
.entry(key)
|
||||
.or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
|
||||
.clone()
|
||||
};
|
||||
FileMutationPermit {
|
||||
_guard: lock.lock_owned().await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn file_mutation_key(path: &Path) -> PathBuf {
|
||||
if let Ok(canonical) = path.canonicalize() {
|
||||
return canonical;
|
||||
}
|
||||
if let (Some(parent), Some(file_name)) = (path.parent(), path.file_name())
|
||||
&& let Ok(canonical_parent) = parent.canonicalize()
|
||||
{
|
||||
return canonical_parent.join(file_name);
|
||||
}
|
||||
normalize_path_lexically(path)
|
||||
}
|
||||
|
||||
/// Normalizes `path` without touching the filesystem.
///
/// * `.` components are dropped.
/// * `..` cancels the preceding normal component (`a/b/..` becomes `a`).
/// * `..` directly after the root is dropped, because the parent of the root
///   is the root itself (`/..` becomes `/`).
/// * `..` that has nothing to cancel is kept, so leading parent references
///   survive (`../../a` stays `../../a`).
///
/// Symlinks are not resolved; this is only a fallback for paths that cannot
/// be canonicalized.
fn normalize_path_lexically(path: &Path) -> PathBuf {
    let mut normalized = PathBuf::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => {
                // Inspect the tail first: blindly popping would also remove a
                // previously kept `..`, turning `../..` into an empty path.
                let tail = normalized.components().next_back();
                let tail_is_normal = matches!(tail, Some(Component::Normal(_)));
                let tail_is_root = matches!(tail, Some(Component::RootDir));
                if tail_is_normal {
                    normalized.pop();
                } else if !tail_is_root {
                    normalized.push(component.as_os_str());
                }
            }
            Component::Normal(part) => normalized.push(part),
            Component::RootDir | Component::Prefix(_) => normalized.push(component.as_os_str()),
        }
    }
    normalized
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct Inner {
|
||||
/// Hash of each file's last observed contents, keyed by canonical path.
|
||||
|
|
@ -74,6 +130,7 @@ struct Inner {
|
|||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Tracker {
|
||||
inner: Arc<Mutex<Inner>>,
|
||||
mutations: FileMutationCoordinator,
|
||||
}
|
||||
|
||||
impl Tracker {
|
||||
|
|
@ -82,6 +139,25 @@ impl Tracker {
|
|||
Self::default()
|
||||
}
|
||||
|
||||
/// Acquire the per-target-file mutation guard shared by `Write` and `Edit`.
|
||||
///
|
||||
/// The guard is keyed by canonical target path where possible so equivalent
|
||||
/// paths serialize through the same lock. Worker still executes tool calls in
|
||||
/// parallel; this only gates the critical filesystem mutation section for
|
||||
/// builtin file mutation tools.
|
||||
pub(crate) async fn acquire_mutation(
|
||||
&self,
|
||||
path: &Path,
|
||||
ctx: &ToolExecutionContext,
|
||||
) -> FileMutationPermit {
|
||||
tracing::debug!(
|
||||
batch_id = %ctx.batch_id,
|
||||
call_index = ctx.call_index,
|
||||
"acquire file mutation guard"
|
||||
);
|
||||
self.mutations.acquire(path).await
|
||||
}
|
||||
|
||||
/// Record that `path` has been observed with the given content bytes.
|
||||
///
|
||||
/// Called by the `Read` tool after a successful read, and by the
|
||||
|
|
@ -347,4 +423,50 @@ mod tests {
|
|||
assert!(recent.iter().all(|p| !p.ends_with(&name)));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mutation_guard_blocks_equivalent_paths_until_drop() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let file = dir.path().join("target.txt");
|
||||
fs::write(&file, "x").unwrap();
|
||||
let equivalent = dir.path().join("sub").join("..").join("target.txt");
|
||||
fs::create_dir(dir.path().join("sub")).unwrap();
|
||||
let tracker = Tracker::new();
|
||||
let first = tracker
|
||||
.acquire_mutation(&file, &ToolExecutionContext::new("a", "batch", 0))
|
||||
.await;
|
||||
|
||||
let second_ctx = ToolExecutionContext::new("b", "batch", 1);
|
||||
let second = tracker.acquire_mutation(&equivalent, &second_ctx);
|
||||
assert!(
|
||||
tokio::time::timeout(std::time::Duration::from_millis(10), second)
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
drop(first);
|
||||
tracker
|
||||
.acquire_mutation(&equivalent, &ToolExecutionContext::new("b", "batch", 1))
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mutation_guard_does_not_block_different_files() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let first_file = dir.path().join("a.txt");
|
||||
let second_file = dir.path().join("b.txt");
|
||||
fs::write(&first_file, "a").unwrap();
|
||||
fs::write(&second_file, "b").unwrap();
|
||||
let tracker = Tracker::new();
|
||||
let _first = tracker
|
||||
.acquire_mutation(&first_file, &ToolExecutionContext::new("a", "batch", 0))
|
||||
.await;
|
||||
|
||||
tokio::time::timeout(
|
||||
std::time::Duration::from_millis(100),
|
||||
tracker.acquire_mutation(&second_file, &ToolExecutionContext::new("b", "batch", 1)),
|
||||
)
|
||||
.await
|
||||
.expect("different files should not share a mutation guard");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ impl Tool for WriteTool {
|
|||
async fn execute(
|
||||
&self,
|
||||
input_json: &str,
|
||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
||||
ctx: llm_worker::tool::ToolExecutionContext,
|
||||
) -> Result<ToolOutput, ToolError> {
|
||||
let params: WriteParams = serde_json::from_str(input_json)
|
||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid Write input: {e}")))?;
|
||||
|
|
@ -44,6 +44,8 @@ impl Tool for WriteTool {
|
|||
"Write"
|
||||
);
|
||||
|
||||
let _mutation_permit = self.tracker.acquire_mutation(&params.file_path, &ctx).await;
|
||||
|
||||
// Policy check: if the target already exists, it must have been
|
||||
// observed by the Read tool (via the tracker) and its current
|
||||
// contents must match the recorded hash.
|
||||
|
|
@ -231,4 +233,90 @@ mod tests {
|
|||
.unwrap_err();
|
||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_then_edit_same_file_same_batch_uses_call_order() {
|
||||
use crate::edit::edit_tool;
|
||||
use llm_worker::tool::ToolExecutionContext;
|
||||
|
||||
let (dir, fs, tracker) = setup();
|
||||
let file = dir.path().join("ordered.txt");
|
||||
|
||||
let write_def = write_tool(fs.clone(), tracker.clone());
|
||||
let (_, writer) = write_def();
|
||||
let edit_def = edit_tool(fs, tracker);
|
||||
let (_, editor) = edit_def();
|
||||
|
||||
let write_in = serde_json::json!({
|
||||
"file_path": file.to_str().unwrap(),
|
||||
"content": "hello",
|
||||
});
|
||||
let edit_in = serde_json::json!({
|
||||
"file_path": file.to_str().unwrap(),
|
||||
"old_string": "hello",
|
||||
"new_string": "goodbye",
|
||||
});
|
||||
|
||||
let write_json = write_in.to_string();
|
||||
let edit_json = edit_in.to_string();
|
||||
let (write_out, edit_out) = tokio::join!(
|
||||
writer.execute(&write_json, ToolExecutionContext::new("write", "batch", 0),),
|
||||
editor.execute(&edit_json, ToolExecutionContext::new("edit", "batch", 1)),
|
||||
);
|
||||
|
||||
write_out.unwrap();
|
||||
edit_out.unwrap();
|
||||
assert_eq!(std::fs::read_to_string(&file).unwrap(), "goodbye");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn failed_same_file_mutation_releases_guard_for_followup() {
|
||||
use crate::edit::edit_tool;
|
||||
use llm_worker::tool::ToolExecutionContext;
|
||||
|
||||
let (dir, fs, tracker) = setup();
|
||||
let file = dir.path().join("release.txt");
|
||||
std::fs::write(&file, "alpha").unwrap();
|
||||
|
||||
let read_def = read_tool(fs.clone(), tracker.clone());
|
||||
let (_, reader) = read_def();
|
||||
reader
|
||||
.execute(
|
||||
&serde_json::json!({ "file_path": file.to_str().unwrap() }).to_string(),
|
||||
ToolExecutionContext::new("read", "pre", 0),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let edit_def = edit_tool(fs, tracker);
|
||||
let (_, editor) = edit_def();
|
||||
let bad_edit = serde_json::json!({
|
||||
"file_path": file.to_str().unwrap(),
|
||||
"old_string": "missing",
|
||||
"new_string": "beta",
|
||||
});
|
||||
let good_edit = serde_json::json!({
|
||||
"file_path": file.to_str().unwrap(),
|
||||
"old_string": "alpha",
|
||||
"new_string": "beta",
|
||||
});
|
||||
|
||||
assert!(
|
||||
editor
|
||||
.execute(
|
||||
&bad_edit.to_string(),
|
||||
ToolExecutionContext::new("bad", "batch", 0),
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
editor
|
||||
.execute(
|
||||
&good_edit.to_string(),
|
||||
ToolExecutionContext::new("good", "batch", 1),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(std::fs::read_to_string(&file).unwrap(), "beta");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user