From 401301438d3012e00106f286d2de1f498f65d714 Mon Sep 17 00:00:00 2001 From: Hare Date: Wed, 10 Jun 2026 18:24:53 +0900 Subject: [PATCH] fix: serialize same-file mutations --- crates/tools/Cargo.toml | 2 +- crates/tools/src/edit.rs | 4 +- crates/tools/src/tracker.rs | 124 +++++++++++++++++++++++++++++++++++- crates/tools/src/write.rs | 90 +++++++++++++++++++++++++- 4 files changed, 216 insertions(+), 4 deletions(-) diff --git a/crates/tools/Cargo.toml b/crates/tools/Cargo.toml index 6331af67..bee31826 100644 --- a/crates/tools/Cargo.toml +++ b/crates/tools/Cargo.toml @@ -23,7 +23,7 @@ serde_json = { workspace = true } sha2 = { workspace = true } tempfile = { workspace = true } thiserror = { workspace = true } -tokio = { workspace = true, features = ["process", "rt", "time"] } +tokio = { workspace = true, features = ["process", "rt", "sync", "time"] } tracing = { workspace = true } [dev-dependencies] diff --git a/crates/tools/src/edit.rs b/crates/tools/src/edit.rs index 32d905d6..13ede1b3 100644 --- a/crates/tools/src/edit.rs +++ b/crates/tools/src/edit.rs @@ -39,7 +39,7 @@ impl Tool for EditTool { async fn execute( &self, input_json: &str, - _ctx: llm_worker::tool::ToolExecutionContext, + ctx: llm_worker::tool::ToolExecutionContext, ) -> Result { let params: EditParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid Edit input: {e}")))?; @@ -61,6 +61,8 @@ impl Tool for EditTool { )); } + let _mutation_permit = self.tracker.acquire_mutation(&params.file_path, &ctx).await; + // Load current content and verify it matches the recorded hash. let current_bytes = self.fs.read_bytes(&params.file_path)?; self.tracker.verify(&params.file_path, &current_bytes)?; diff --git a/crates/tools/src/tracker.rs b/crates/tools/src/tracker.rs index fcfdbe2e..867b0423 100644 --- a/crates/tools/src/tracker.rs +++ b/crates/tools/src/tracker.rs @@ -38,7 +38,9 @@ //! 
``` use std::collections::{HashMap, VecDeque}; -use std::path::{Path, PathBuf}; +use std::path::{Component, Path, PathBuf}; + +use llm_worker::tool::ToolExecutionContext; use std::sync::{Arc, Mutex}; use sha2::{Digest, Sha256}; @@ -58,6 +60,60 @@ fn hash_bytes(bytes: &[u8]) -> ContentHash { hasher.finalize().into() } +#[derive(Debug, Clone, Default)] +struct FileMutationCoordinator { + locks: Arc<tokio::sync::Mutex<HashMap<PathBuf, Arc<tokio::sync::Mutex<()>>>>>, +} + +pub(crate) struct FileMutationPermit { + _guard: tokio::sync::OwnedMutexGuard<()>, +} + +impl FileMutationCoordinator { + async fn acquire(&self, path: &Path) -> FileMutationPermit { + let key = file_mutation_key(path); + let lock = { + let mut locks = self.locks.lock().await; + locks + .entry(key) + .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(()))) + .clone() + }; + FileMutationPermit { + _guard: lock.lock_owned().await, + } + } +} + +fn file_mutation_key(path: &Path) -> PathBuf { + if let Ok(canonical) = path.canonicalize() { + return canonical; + } + if let (Some(parent), Some(file_name)) = (path.parent(), path.file_name()) + && let Ok(canonical_parent) = parent.canonicalize() + { + return canonical_parent.join(file_name); + } + normalize_path_lexically(path) +} + +fn normalize_path_lexically(path: &Path) -> PathBuf { + let mut normalized = PathBuf::new(); + for component in path.components() { + match component { + Component::CurDir => {} + Component::ParentDir => { + if !normalized.pop() { + normalized.push(component.as_os_str()); + } + } + Component::Normal(part) => normalized.push(part), + Component::RootDir | Component::Prefix(_) => normalized.push(component.as_os_str()), + } + } + normalized +} + #[derive(Debug, Default)] struct Inner { /// Hash of each file's last observed contents, keyed by canonical path. 
@@ -74,6 +130,7 @@ struct Inner { #[derive(Debug, Clone, Default)] pub struct Tracker { inner: Arc<Mutex<Inner>>, + mutations: FileMutationCoordinator, } impl Tracker { @@ -82,6 +139,25 @@ impl Tracker { Self::default() } + /// Acquire the per-target-file mutation guard shared by `Write` and `Edit`. + /// + /// The guard is keyed by canonical target path where possible so equivalent + /// paths serialize through the same lock. Worker still executes tool calls in + /// parallel; this only gates the critical filesystem mutation section for + /// builtin file mutation tools. + pub(crate) async fn acquire_mutation( + &self, + path: &Path, + ctx: &ToolExecutionContext, + ) -> FileMutationPermit { + tracing::debug!( + batch_id = %ctx.batch_id, + call_index = ctx.call_index, + "acquire file mutation guard" + ); + self.mutations.acquire(path).await + } + /// Record that `path` has been observed with the given content bytes. /// /// Called by the `Read` tool after a successful read, and by the @@ -347,4 +423,50 @@ mod tests { assert!(recent.iter().all(|p| !p.ends_with(&name))); } } + + #[tokio::test] + async fn mutation_guard_blocks_equivalent_paths_until_drop() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("target.txt"); + fs::write(&file, "x").unwrap(); + let equivalent = dir.path().join("sub").join("..").join("target.txt"); + fs::create_dir(dir.path().join("sub")).unwrap(); + let tracker = Tracker::new(); + let first = tracker + .acquire_mutation(&file, &ToolExecutionContext::new("a", "batch", 0)) + .await; + + let second_ctx = ToolExecutionContext::new("b", "batch", 1); + let second = tracker.acquire_mutation(&equivalent, &second_ctx); + assert!( + tokio::time::timeout(std::time::Duration::from_millis(10), second) + .await + .is_err() + ); + + drop(first); + tracker + .acquire_mutation(&equivalent, &ToolExecutionContext::new("b", "batch", 1)) + .await; + } + + #[tokio::test] + async fn mutation_guard_does_not_block_different_files() { + let dir = 
tempfile::tempdir().unwrap(); + let first_file = dir.path().join("a.txt"); + let second_file = dir.path().join("b.txt"); + fs::write(&first_file, "a").unwrap(); + fs::write(&second_file, "b").unwrap(); + let tracker = Tracker::new(); + let _first = tracker + .acquire_mutation(&first_file, &ToolExecutionContext::new("a", "batch", 0)) + .await; + + tokio::time::timeout( + std::time::Duration::from_millis(100), + tracker.acquire_mutation(&second_file, &ToolExecutionContext::new("b", "batch", 1)), + ) + .await + .expect("different files should not share a mutation guard"); + } } diff --git a/crates/tools/src/write.rs b/crates/tools/src/write.rs index 8902ec8d..5df7d335 100644 --- a/crates/tools/src/write.rs +++ b/crates/tools/src/write.rs @@ -33,7 +33,7 @@ impl Tool for WriteTool { async fn execute( &self, input_json: &str, - _ctx: llm_worker::tool::ToolExecutionContext, + ctx: llm_worker::tool::ToolExecutionContext, ) -> Result { let params: WriteParams = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(format!("invalid Write input: {e}")))?; @@ -44,6 +44,8 @@ impl Tool for WriteTool { "Write" ); + let _mutation_permit = self.tracker.acquire_mutation(&params.file_path, &ctx).await; + // Policy check: if the target already exists, it must have been // observed by the Read tool (via the tracker) and its current // contents must match the recorded hash. 
@@ -231,4 +233,90 @@ mod tests { .unwrap_err(); assert!(matches!(err, ToolError::InvalidArgument(_))); } + + #[tokio::test] + async fn write_then_edit_same_file_same_batch_uses_call_order() { + use crate::edit::edit_tool; + use llm_worker::tool::ToolExecutionContext; + + let (dir, fs, tracker) = setup(); + let file = dir.path().join("ordered.txt"); + + let write_def = write_tool(fs.clone(), tracker.clone()); + let (_, writer) = write_def(); + let edit_def = edit_tool(fs, tracker); + let (_, editor) = edit_def(); + + let write_in = serde_json::json!({ + "file_path": file.to_str().unwrap(), + "content": "hello", + }); + let edit_in = serde_json::json!({ + "file_path": file.to_str().unwrap(), + "old_string": "hello", + "new_string": "goodbye", + }); + + let write_json = write_in.to_string(); + let edit_json = edit_in.to_string(); + let (write_out, edit_out) = tokio::join!( + writer.execute(&write_json, ToolExecutionContext::new("write", "batch", 0),), + editor.execute(&edit_json, ToolExecutionContext::new("edit", "batch", 1)), + ); + + write_out.unwrap(); + edit_out.unwrap(); + assert_eq!(std::fs::read_to_string(&file).unwrap(), "goodbye"); + } + + #[tokio::test] + async fn failed_same_file_mutation_releases_guard_for_followup() { + use crate::edit::edit_tool; + use llm_worker::tool::ToolExecutionContext; + + let (dir, fs, tracker) = setup(); + let file = dir.path().join("release.txt"); + std::fs::write(&file, "alpha").unwrap(); + + let read_def = read_tool(fs.clone(), tracker.clone()); + let (_, reader) = read_def(); + reader + .execute( + &serde_json::json!({ "file_path": file.to_str().unwrap() }).to_string(), + ToolExecutionContext::new("read", "pre", 0), + ) + .await + .unwrap(); + + let edit_def = edit_tool(fs, tracker); + let (_, editor) = edit_def(); + let bad_edit = serde_json::json!({ + "file_path": file.to_str().unwrap(), + "old_string": "missing", + "new_string": "beta", + }); + let good_edit = serde_json::json!({ + "file_path": file.to_str().unwrap(), + 
"old_string": "alpha", + "new_string": "beta", + }); + + assert!( + editor + .execute( + &bad_edit.to_string(), + ToolExecutionContext::new("bad", "batch", 0), + ) + .await + .is_err() + ); + editor + .execute( + &good_edit.to_string(), + ToolExecutionContext::new("good", "batch", 1), + ) + .await + .unwrap(); + assert_eq!(std::fs::read_to_string(&file).unwrap(), "beta"); + } }