fix: serialize same-file mutations

This commit is contained in:
Keisuke Hirata 2026-06-10 18:24:53 +09:00
parent 536ff4dd57
commit 401301438d
No known key found for this signature in database
4 changed files with 216 additions and 4 deletions

View File

@ -23,7 +23,7 @@ serde_json = { workspace = true }
sha2 = { workspace = true }
tempfile = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["process", "rt", "time"] }
tokio = { workspace = true, features = ["process", "rt", "sync", "time"] }
tracing = { workspace = true }
[dev-dependencies]

View File

@ -39,7 +39,7 @@ impl Tool for EditTool {
async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: EditParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid Edit input: {e}")))?;
@ -61,6 +61,8 @@ impl Tool for EditTool {
));
}
let _mutation_permit = self.tracker.acquire_mutation(&params.file_path, &ctx).await;
// Load current content and verify it matches the recorded hash.
let current_bytes = self.fs.read_bytes(&params.file_path)?;
self.tracker.verify(&params.file_path, &current_bytes)?;

View File

@ -38,7 +38,9 @@
//! ```
use std::collections::{HashMap, VecDeque};
use std::path::{Path, PathBuf};
use std::path::{Component, Path, PathBuf};
use llm_worker::tool::ToolExecutionContext;
use std::sync::{Arc, Mutex};
use sha2::{Digest, Sha256};
@ -58,6 +60,60 @@ fn hash_bytes(bytes: &[u8]) -> ContentHash {
hasher.finalize().into()
}
#[derive(Debug, Clone, Default)]
struct FileMutationCoordinator {
locks: Arc<tokio::sync::Mutex<HashMap<PathBuf, Arc<tokio::sync::Mutex<()>>>>>,
}
pub(crate) struct FileMutationPermit {
_guard: tokio::sync::OwnedMutexGuard<()>,
}
impl FileMutationCoordinator {
async fn acquire(&self, path: &Path) -> FileMutationPermit {
let key = file_mutation_key(path);
let lock = {
let mut locks = self.locks.lock().await;
locks
.entry(key)
.or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
.clone()
};
FileMutationPermit {
_guard: lock.lock_owned().await,
}
}
}
fn file_mutation_key(path: &Path) -> PathBuf {
if let Ok(canonical) = path.canonicalize() {
return canonical;
}
if let (Some(parent), Some(file_name)) = (path.parent(), path.file_name())
&& let Ok(canonical_parent) = parent.canonicalize()
{
return canonical_parent.join(file_name);
}
normalize_path_lexically(path)
}
fn normalize_path_lexically(path: &Path) -> PathBuf {
let mut normalized = PathBuf::new();
for component in path.components() {
match component {
Component::CurDir => {}
Component::ParentDir => {
if !normalized.pop() {
normalized.push(component.as_os_str());
}
}
Component::Normal(part) => normalized.push(part),
Component::RootDir | Component::Prefix(_) => normalized.push(component.as_os_str()),
}
}
normalized
}
#[derive(Debug, Default)]
struct Inner {
/// Hash of each file's last observed contents, keyed by canonical path.
@ -74,6 +130,7 @@ struct Inner {
#[derive(Debug, Clone, Default)]
pub struct Tracker {
inner: Arc<Mutex<Inner>>,
mutations: FileMutationCoordinator,
}
impl Tracker {
@ -82,6 +139,25 @@ impl Tracker {
Self::default()
}
/// Acquire the per-target-file mutation guard shared by `Write` and `Edit`.
///
/// The guard is keyed by canonical target path where possible so equivalent
/// paths serialize through the same lock. Worker still executes tool calls in
/// parallel; this only gates the critical filesystem mutation section for
/// builtin file mutation tools.
pub(crate) async fn acquire_mutation(
&self,
path: &Path,
ctx: &ToolExecutionContext,
) -> FileMutationPermit {
tracing::debug!(
batch_id = %ctx.batch_id,
call_index = ctx.call_index,
"acquire file mutation guard"
);
self.mutations.acquire(path).await
}
/// Record that `path` has been observed with the given content bytes.
///
/// Called by the `Read` tool after a successful read, and by the
@ -347,4 +423,50 @@ mod tests {
assert!(recent.iter().all(|p| !p.ends_with(&name)));
}
}
#[tokio::test]
async fn mutation_guard_blocks_equivalent_paths_until_drop() {
let dir = tempfile::tempdir().unwrap();
let file = dir.path().join("target.txt");
fs::write(&file, "x").unwrap();
let equivalent = dir.path().join("sub").join("..").join("target.txt");
fs::create_dir(dir.path().join("sub")).unwrap();
let tracker = Tracker::new();
let first = tracker
.acquire_mutation(&file, &ToolExecutionContext::new("a", "batch", 0))
.await;
let second_ctx = ToolExecutionContext::new("b", "batch", 1);
let second = tracker.acquire_mutation(&equivalent, &second_ctx);
assert!(
tokio::time::timeout(std::time::Duration::from_millis(10), second)
.await
.is_err()
);
drop(first);
tracker
.acquire_mutation(&equivalent, &ToolExecutionContext::new("b", "batch", 1))
.await;
}
#[tokio::test]
async fn mutation_guard_does_not_block_different_files() {
let dir = tempfile::tempdir().unwrap();
let first_file = dir.path().join("a.txt");
let second_file = dir.path().join("b.txt");
fs::write(&first_file, "a").unwrap();
fs::write(&second_file, "b").unwrap();
let tracker = Tracker::new();
let _first = tracker
.acquire_mutation(&first_file, &ToolExecutionContext::new("a", "batch", 0))
.await;
tokio::time::timeout(
std::time::Duration::from_millis(100),
tracker.acquire_mutation(&second_file, &ToolExecutionContext::new("b", "batch", 1)),
)
.await
.expect("different files should not share a mutation guard");
}
}

View File

@ -33,7 +33,7 @@ impl Tool for WriteTool {
async fn execute(
&self,
input_json: &str,
_ctx: llm_worker::tool::ToolExecutionContext,
ctx: llm_worker::tool::ToolExecutionContext,
) -> Result<ToolOutput, ToolError> {
let params: WriteParams = serde_json::from_str(input_json)
.map_err(|e| ToolError::InvalidArgument(format!("invalid Write input: {e}")))?;
@ -44,6 +44,8 @@ impl Tool for WriteTool {
"Write"
);
let _mutation_permit = self.tracker.acquire_mutation(&params.file_path, &ctx).await;
// Policy check: if the target already exists, it must have been
// observed by the Read tool (via the tracker) and its current
// contents must match the recorded hash.
@ -231,4 +233,90 @@ mod tests {
.unwrap_err();
assert!(matches!(err, ToolError::InvalidArgument(_)));
}
#[tokio::test]
async fn write_then_edit_same_file_same_batch_uses_call_order() {
use crate::edit::edit_tool;
use llm_worker::tool::ToolExecutionContext;
let (dir, fs, tracker) = setup();
let file = dir.path().join("ordered.txt");
let write_def = write_tool(fs.clone(), tracker.clone());
let (_, writer) = write_def();
let edit_def = edit_tool(fs, tracker);
let (_, editor) = edit_def();
let write_in = serde_json::json!({
"file_path": file.to_str().unwrap(),
"content": "hello",
});
let edit_in = serde_json::json!({
"file_path": file.to_str().unwrap(),
"old_string": "hello",
"new_string": "goodbye",
});
let write_json = write_in.to_string();
let edit_json = edit_in.to_string();
let (write_out, edit_out) = tokio::join!(
writer.execute(&write_json, ToolExecutionContext::new("write", "batch", 0),),
editor.execute(&edit_json, ToolExecutionContext::new("edit", "batch", 1)),
);
write_out.unwrap();
edit_out.unwrap();
assert_eq!(std::fs::read_to_string(&file).unwrap(), "goodbye");
}
#[tokio::test]
async fn failed_same_file_mutation_releases_guard_for_followup() {
use crate::edit::edit_tool;
use llm_worker::tool::ToolExecutionContext;
let (dir, fs, tracker) = setup();
let file = dir.path().join("release.txt");
std::fs::write(&file, "alpha").unwrap();
let read_def = read_tool(fs.clone(), tracker.clone());
let (_, reader) = read_def();
reader
.execute(
&serde_json::json!({ "file_path": file.to_str().unwrap() }).to_string(),
ToolExecutionContext::new("read", "pre", 0),
)
.await
.unwrap();
let edit_def = edit_tool(fs, tracker);
let (_, editor) = edit_def();
let bad_edit = serde_json::json!({
"file_path": file.to_str().unwrap(),
"old_string": "missing",
"new_string": "beta",
});
let good_edit = serde_json::json!({
"file_path": file.to_str().unwrap(),
"old_string": "alpha",
"new_string": "beta",
});
assert!(
editor
.execute(
&bad_edit.to_string(),
ToolExecutionContext::new("bad", "batch", 0),
)
.await
.is_err()
);
editor
.execute(
&good_edit.to_string(),
ToolExecutionContext::new("good", "batch", 1),
)
.await
.unwrap();
assert_eq!(std::fs::read_to_string(&file).unwrap(), "beta");
}
}