fix: serialize same-file mutations
This commit is contained in:
parent
536ff4dd57
commit
401301438d
|
|
@ -23,7 +23,7 @@ serde_json = { workspace = true }
|
||||||
sha2 = { workspace = true }
|
sha2 = { workspace = true }
|
||||||
tempfile = { workspace = true }
|
tempfile = { workspace = true }
|
||||||
thiserror = { workspace = true }
|
thiserror = { workspace = true }
|
||||||
tokio = { workspace = true, features = ["process", "rt", "time"] }
|
tokio = { workspace = true, features = ["process", "rt", "sync", "time"] }
|
||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ impl Tool for EditTool {
|
||||||
async fn execute(
|
async fn execute(
|
||||||
&self,
|
&self,
|
||||||
input_json: &str,
|
input_json: &str,
|
||||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
ctx: llm_worker::tool::ToolExecutionContext,
|
||||||
) -> Result<ToolOutput, ToolError> {
|
) -> Result<ToolOutput, ToolError> {
|
||||||
let params: EditParams = serde_json::from_str(input_json)
|
let params: EditParams = serde_json::from_str(input_json)
|
||||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid Edit input: {e}")))?;
|
.map_err(|e| ToolError::InvalidArgument(format!("invalid Edit input: {e}")))?;
|
||||||
|
|
@ -61,6 +61,8 @@ impl Tool for EditTool {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let _mutation_permit = self.tracker.acquire_mutation(&params.file_path, &ctx).await;
|
||||||
|
|
||||||
// Load current content and verify it matches the recorded hash.
|
// Load current content and verify it matches the recorded hash.
|
||||||
let current_bytes = self.fs.read_bytes(&params.file_path)?;
|
let current_bytes = self.fs.read_bytes(&params.file_path)?;
|
||||||
self.tracker.verify(&params.file_path, &current_bytes)?;
|
self.tracker.verify(&params.file_path, &current_bytes)?;
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,9 @@
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
use std::collections::{HashMap, VecDeque};
|
use std::collections::{HashMap, VecDeque};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Component, Path, PathBuf};
|
||||||
|
|
||||||
|
use llm_worker::tool::ToolExecutionContext;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
|
|
@ -58,6 +60,60 @@ fn hash_bytes(bytes: &[u8]) -> ContentHash {
|
||||||
hasher.finalize().into()
|
hasher.finalize().into()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default)]
|
||||||
|
struct FileMutationCoordinator {
|
||||||
|
locks: Arc<tokio::sync::Mutex<HashMap<PathBuf, Arc<tokio::sync::Mutex<()>>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct FileMutationPermit {
|
||||||
|
_guard: tokio::sync::OwnedMutexGuard<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FileMutationCoordinator {
|
||||||
|
async fn acquire(&self, path: &Path) -> FileMutationPermit {
|
||||||
|
let key = file_mutation_key(path);
|
||||||
|
let lock = {
|
||||||
|
let mut locks = self.locks.lock().await;
|
||||||
|
locks
|
||||||
|
.entry(key)
|
||||||
|
.or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
|
||||||
|
.clone()
|
||||||
|
};
|
||||||
|
FileMutationPermit {
|
||||||
|
_guard: lock.lock_owned().await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn file_mutation_key(path: &Path) -> PathBuf {
|
||||||
|
if let Ok(canonical) = path.canonicalize() {
|
||||||
|
return canonical;
|
||||||
|
}
|
||||||
|
if let (Some(parent), Some(file_name)) = (path.parent(), path.file_name())
|
||||||
|
&& let Ok(canonical_parent) = parent.canonicalize()
|
||||||
|
{
|
||||||
|
return canonical_parent.join(file_name);
|
||||||
|
}
|
||||||
|
normalize_path_lexically(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve `.` and `..` components of `path` without touching the filesystem.
///
/// * `.` components are dropped.
/// * `..` removes the preceding component only when that component is a real
///   name; a `..` that follows another `..` (or starts a relative path) is
///   kept, so `../../a` stays `../../a`.
/// * `..` directly under the root is dropped, so `/../a` becomes `/a`.
///
/// Symlinks are not resolved; the result is only a textual normalization.
fn normalize_path_lexically(path: &Path) -> PathBuf {
    let mut normalized = PathBuf::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => {
                // `file_name()` is `Some` only when the path ends in a normal
                // component; it is `None` for an empty path, a root, a prefix
                // or a trailing `..`. `PathBuf::pop` alone cannot be used as
                // the test because it happily removes a trailing `..`.
                if normalized.file_name().is_some() {
                    normalized.pop();
                } else if !normalized.has_root() {
                    normalized.push(component.as_os_str());
                }
                // Otherwise we are at the root: its parent is itself.
            }
            Component::Normal(part) => normalized.push(part),
            Component::RootDir | Component::Prefix(_) => normalized.push(component.as_os_str()),
        }
    }
    normalized
}
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
struct Inner {
|
struct Inner {
|
||||||
/// Hash of each file's last observed contents, keyed by canonical path.
|
/// Hash of each file's last observed contents, keyed by canonical path.
|
||||||
|
|
@ -74,6 +130,7 @@ struct Inner {
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
pub struct Tracker {
|
pub struct Tracker {
|
||||||
inner: Arc<Mutex<Inner>>,
|
inner: Arc<Mutex<Inner>>,
|
||||||
|
mutations: FileMutationCoordinator,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Tracker {
|
impl Tracker {
|
||||||
|
|
@ -82,6 +139,25 @@ impl Tracker {
|
||||||
Self::default()
|
Self::default()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Acquire the per-target-file mutation guard shared by `Write` and `Edit`.
|
||||||
|
///
|
||||||
|
/// The guard is keyed by canonical target path where possible so equivalent
|
||||||
|
/// paths serialize through the same lock. Worker still executes tool calls in
|
||||||
|
/// parallel; this only gates the critical filesystem mutation section for
|
||||||
|
/// builtin file mutation tools.
|
||||||
|
pub(crate) async fn acquire_mutation(
|
||||||
|
&self,
|
||||||
|
path: &Path,
|
||||||
|
ctx: &ToolExecutionContext,
|
||||||
|
) -> FileMutationPermit {
|
||||||
|
tracing::debug!(
|
||||||
|
batch_id = %ctx.batch_id,
|
||||||
|
call_index = ctx.call_index,
|
||||||
|
"acquire file mutation guard"
|
||||||
|
);
|
||||||
|
self.mutations.acquire(path).await
|
||||||
|
}
|
||||||
|
|
||||||
/// Record that `path` has been observed with the given content bytes.
|
/// Record that `path` has been observed with the given content bytes.
|
||||||
///
|
///
|
||||||
/// Called by the `Read` tool after a successful read, and by the
|
/// Called by the `Read` tool after a successful read, and by the
|
||||||
|
|
@ -347,4 +423,50 @@ mod tests {
|
||||||
assert!(recent.iter().all(|p| !p.ends_with(&name)));
|
assert!(recent.iter().all(|p| !p.ends_with(&name)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn mutation_guard_blocks_equivalent_paths_until_drop() {
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let file = dir.path().join("target.txt");
|
||||||
|
fs::write(&file, "x").unwrap();
|
||||||
|
let equivalent = dir.path().join("sub").join("..").join("target.txt");
|
||||||
|
fs::create_dir(dir.path().join("sub")).unwrap();
|
||||||
|
let tracker = Tracker::new();
|
||||||
|
let first = tracker
|
||||||
|
.acquire_mutation(&file, &ToolExecutionContext::new("a", "batch", 0))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let second_ctx = ToolExecutionContext::new("b", "batch", 1);
|
||||||
|
let second = tracker.acquire_mutation(&equivalent, &second_ctx);
|
||||||
|
assert!(
|
||||||
|
tokio::time::timeout(std::time::Duration::from_millis(10), second)
|
||||||
|
.await
|
||||||
|
.is_err()
|
||||||
|
);
|
||||||
|
|
||||||
|
drop(first);
|
||||||
|
tracker
|
||||||
|
.acquire_mutation(&equivalent, &ToolExecutionContext::new("b", "batch", 1))
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn mutation_guard_does_not_block_different_files() {
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let first_file = dir.path().join("a.txt");
|
||||||
|
let second_file = dir.path().join("b.txt");
|
||||||
|
fs::write(&first_file, "a").unwrap();
|
||||||
|
fs::write(&second_file, "b").unwrap();
|
||||||
|
let tracker = Tracker::new();
|
||||||
|
let _first = tracker
|
||||||
|
.acquire_mutation(&first_file, &ToolExecutionContext::new("a", "batch", 0))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
tokio::time::timeout(
|
||||||
|
std::time::Duration::from_millis(100),
|
||||||
|
tracker.acquire_mutation(&second_file, &ToolExecutionContext::new("b", "batch", 1)),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("different files should not share a mutation guard");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ impl Tool for WriteTool {
|
||||||
async fn execute(
|
async fn execute(
|
||||||
&self,
|
&self,
|
||||||
input_json: &str,
|
input_json: &str,
|
||||||
_ctx: llm_worker::tool::ToolExecutionContext,
|
ctx: llm_worker::tool::ToolExecutionContext,
|
||||||
) -> Result<ToolOutput, ToolError> {
|
) -> Result<ToolOutput, ToolError> {
|
||||||
let params: WriteParams = serde_json::from_str(input_json)
|
let params: WriteParams = serde_json::from_str(input_json)
|
||||||
.map_err(|e| ToolError::InvalidArgument(format!("invalid Write input: {e}")))?;
|
.map_err(|e| ToolError::InvalidArgument(format!("invalid Write input: {e}")))?;
|
||||||
|
|
@ -44,6 +44,8 @@ impl Tool for WriteTool {
|
||||||
"Write"
|
"Write"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let _mutation_permit = self.tracker.acquire_mutation(&params.file_path, &ctx).await;
|
||||||
|
|
||||||
// Policy check: if the target already exists, it must have been
|
// Policy check: if the target already exists, it must have been
|
||||||
// observed by the Read tool (via the tracker) and its current
|
// observed by the Read tool (via the tracker) and its current
|
||||||
// contents must match the recorded hash.
|
// contents must match the recorded hash.
|
||||||
|
|
@ -231,4 +233,90 @@ mod tests {
|
||||||
.unwrap_err();
|
.unwrap_err();
|
||||||
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
assert!(matches!(err, ToolError::InvalidArgument(_)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn write_then_edit_same_file_same_batch_uses_call_order() {
|
||||||
|
use crate::edit::edit_tool;
|
||||||
|
use llm_worker::tool::ToolExecutionContext;
|
||||||
|
|
||||||
|
let (dir, fs, tracker) = setup();
|
||||||
|
let file = dir.path().join("ordered.txt");
|
||||||
|
|
||||||
|
let write_def = write_tool(fs.clone(), tracker.clone());
|
||||||
|
let (_, writer) = write_def();
|
||||||
|
let edit_def = edit_tool(fs, tracker);
|
||||||
|
let (_, editor) = edit_def();
|
||||||
|
|
||||||
|
let write_in = serde_json::json!({
|
||||||
|
"file_path": file.to_str().unwrap(),
|
||||||
|
"content": "hello",
|
||||||
|
});
|
||||||
|
let edit_in = serde_json::json!({
|
||||||
|
"file_path": file.to_str().unwrap(),
|
||||||
|
"old_string": "hello",
|
||||||
|
"new_string": "goodbye",
|
||||||
|
});
|
||||||
|
|
||||||
|
let write_json = write_in.to_string();
|
||||||
|
let edit_json = edit_in.to_string();
|
||||||
|
let (write_out, edit_out) = tokio::join!(
|
||||||
|
writer.execute(&write_json, ToolExecutionContext::new("write", "batch", 0),),
|
||||||
|
editor.execute(&edit_json, ToolExecutionContext::new("edit", "batch", 1)),
|
||||||
|
);
|
||||||
|
|
||||||
|
write_out.unwrap();
|
||||||
|
edit_out.unwrap();
|
||||||
|
assert_eq!(std::fs::read_to_string(&file).unwrap(), "goodbye");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn failed_same_file_mutation_releases_guard_for_followup() {
|
||||||
|
use crate::edit::edit_tool;
|
||||||
|
use llm_worker::tool::ToolExecutionContext;
|
||||||
|
|
||||||
|
let (dir, fs, tracker) = setup();
|
||||||
|
let file = dir.path().join("release.txt");
|
||||||
|
std::fs::write(&file, "alpha").unwrap();
|
||||||
|
|
||||||
|
let read_def = read_tool(fs.clone(), tracker.clone());
|
||||||
|
let (_, reader) = read_def();
|
||||||
|
reader
|
||||||
|
.execute(
|
||||||
|
&serde_json::json!({ "file_path": file.to_str().unwrap() }).to_string(),
|
||||||
|
ToolExecutionContext::new("read", "pre", 0),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let edit_def = edit_tool(fs, tracker);
|
||||||
|
let (_, editor) = edit_def();
|
||||||
|
let bad_edit = serde_json::json!({
|
||||||
|
"file_path": file.to_str().unwrap(),
|
||||||
|
"old_string": "missing",
|
||||||
|
"new_string": "beta",
|
||||||
|
});
|
||||||
|
let good_edit = serde_json::json!({
|
||||||
|
"file_path": file.to_str().unwrap(),
|
||||||
|
"old_string": "alpha",
|
||||||
|
"new_string": "beta",
|
||||||
|
});
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
editor
|
||||||
|
.execute(
|
||||||
|
&bad_edit.to_string(),
|
||||||
|
ToolExecutionContext::new("bad", "batch", 0),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.is_err()
|
||||||
|
);
|
||||||
|
editor
|
||||||
|
.execute(
|
||||||
|
&good_edit.to_string(),
|
||||||
|
ToolExecutionContext::new("good", "batch", 1),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(std::fs::read_to_string(&file).unwrap(), "beta");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user