From c15cd659e35e84d01e3159d60652e9ec65905e1e Mon Sep 17 00:00:00 2001 From: Hare Date: Thu, 12 Feb 2026 22:52:00 +0900 Subject: [PATCH] =?UTF-8?q?mp4=E4=BF=9D=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- core/inference_client.py | 102 ++++++++++++++-------------- flake.nix | 9 ++- operators/generate_mask.py | 107 ++++++++++++++++++++++-------- server/main.py | 132 +++++++++++++++++++++++++++++-------- 4 files changed, 242 insertions(+), 108 deletions(-) diff --git a/core/inference_client.py b/core/inference_client.py index 6ecb09c..6dcfb35 100644 --- a/core/inference_client.py +++ b/core/inference_client.py @@ -5,36 +5,36 @@ Manages the server process and handles HTTP communication using standard library (avoiding requests dependency). """ -import subprocess -import time import json -import urllib.request -import urllib.error -import threading import os import signal -from typing import Optional, Dict, Any, Tuple +import subprocess +import threading +import time +import urllib.error +import urllib.request +from typing import Any, Dict, Optional, Tuple class InferenceClient: """Client for the YOLO inference server.""" - + SERVER_URL = "http://127.0.0.1:8181" - + def __init__(self): self.server_process: Optional[subprocess.Popen] = None self._server_lock = threading.Lock() self.log_file = None self.log_file_path = None - + def start_server(self): """Start the inference server process.""" with self._server_lock: if self.is_server_running(): return - + print("[FaceMask] Starting inference server...") - + # Find project root # Assuming this file is in core/inference_client.py root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -46,24 +46,24 @@ class InferenceClient: # Load environment variables from .env file if it exists env_file = os.path.join(root_dir, ".env") if os.path.exists(env_file): - with open(env_file, 'r') as f: + with open(env_file, "r") as f: for line in f: line = line.strip() - if line and not line.startswith('#') and '=' in line: - key, value = line.split('=', 1) + if line and not line.startswith("#") and "=" in line: + key, value = line.split("=", 1) server_env[key] = value print(f"[FaceMask] Loaded environment from: {env_file}") # Clean PYTHONPATH to avoid conflicts with Nix Python packages # Only include project root to allow local imports - server_env['PYTHONPATH'] = root_dir + server_env["PYTHONPATH"] = root_dir # Remove Python-related environment variables that might cause conflicts # These can cause venv to import packages from Nix instead of venv env_vars_to_remove = [ - 'PYTHONUNBUFFERED', - '__PYVENV_LAUNCHER__', # macOS venv variable - 'VIRTUAL_ENV', # Will be set by venv's Python automatically + "PYTHONUNBUFFERED", + "__PYVENV_LAUNCHER__", # macOS venv variable + "VIRTUAL_ENV", # Will be set by venv's Python automatically ] for var in env_vars_to_remove: server_env.pop(var, None) @@ -73,25 +73,27 @@ class InferenceClient: if os.path.isdir(venv_bin): # Build a clean PATH with venv first, then essential system paths # Filter out any Nix Python-specific paths to avoid version conflicts - current_path = server_env.get('PATH', '') - path_entries = current_path.split(':') + current_path = server_env.get("PATH", "") + path_entries = current_path.split(":") # Filter out Nix Python 3.11 paths filtered_paths = [ - p for p in path_entries - if not ('/python3.11/' in p.lower() or '/python3-3.11' in p.lower()) + p + for p in path_entries + if not ("/python3.11/" in p.lower() or "/python3-3.11" in p.lower()) ] # Reconstruct PATH with venv first - clean_path = ':'.join([venv_bin] + filtered_paths) - server_env['PATH'] = clean_path + clean_path = ":".join([venv_bin] + filtered_paths) + server_env["PATH"] = clean_path print(f"[FaceMask] Using venv from: {venv_bin}") # Prepare log file for server output import tempfile + log_dir = tempfile.gettempdir() self.log_file_path = os.path.join(log_dir, "facemask_server.log") - self.log_file = open(self.log_file_path, 'w', buffering=1) # Line buffered + self.log_file = open(self.log_file_path, "w", buffering=1) # Line buffered print(f"[FaceMask] Server log: {self.log_file_path}") # Start process with 'python' command (will use venv if PATH is set correctly) @@ -120,12 +122,12 @@ class InferenceClient: try: if self.log_file: self.log_file.close() - with open(self.log_file_path, 'r') as f: + with open(self.log_file_path, "r") as f: log_content = f.read() if log_content.strip(): print("[FaceMask] Server log:") # Show last 50 lines - lines = log_content.strip().split('\n') + lines = log_content.strip().split("\n") for line in lines[-50:]: print(line) except Exception as e: @@ -143,18 +145,18 @@ class InferenceClient: try: if self.log_file: self.log_file.close() - with open(self.log_file_path, 'r') as f: + with open(self.log_file_path, "r") as f: log_content = f.read() if log_content.strip(): print("[FaceMask] Server log (partial):") - lines = log_content.strip().split('\n') + lines = log_content.strip().split("\n") for line in lines[-30:]: print(line) except Exception: pass raise RuntimeError("Server startup timed out") - + def stop_server(self): """Stop the inference server.""" with self._server_lock: @@ -175,15 +177,17 @@ class InferenceClient: except Exception: pass self.log_file = None - + def is_server_running(self) -> bool: """Check if server is responding.""" try: - with urllib.request.urlopen(f"{self.SERVER_URL}/status", timeout=1) as response: + with urllib.request.urlopen( + f"{self.SERVER_URL}/status", timeout=1 + ) as response: return response.status == 200 except (urllib.error.URLError, ConnectionRefusedError, TimeoutError): return False - + def generate_mask( self, video_path: str, @@ -196,13 +200,13 @@ class InferenceClient: ) -> str: """ Request mask generation. - + Returns: task_id (str) """ if not self.is_server_running(): self.start_server() - + data = { "video_path": video_path, "output_dir": output_dir, @@ -212,35 +216,36 @@ class InferenceClient: "iou_threshold": iou_threshold, "mask_scale": mask_scale, } - + req = urllib.request.Request( f"{self.SERVER_URL}/generate", - data=json.dumps(data).encode('utf-8'), - headers={'Content-Type': 'application/json'}, - method='POST' + data=json.dumps(data).encode("utf-8"), + headers={"Content-Type": "application/json"}, + method="POST", ) - + try: with urllib.request.urlopen(req) as response: - result = json.loads(response.read().decode('utf-8')) - return result['id'] + result = json.loads(response.read().decode("utf-8")) + return result["id"] except urllib.error.HTTPError as e: raise RuntimeError(f"Server error: {e.read().decode('utf-8')}") - + def get_task_status(self, task_id: str) -> Dict[str, Any]: """Get status of a task.""" try: - with urllib.request.urlopen(f"{self.SERVER_URL}/tasks/{task_id}") as response: - return json.loads(response.read().decode('utf-8')) + with urllib.request.urlopen( + f"{self.SERVER_URL}/tasks/{task_id}" + ) as response: + return json.loads(response.read().decode("utf-8")) except urllib.error.HTTPError: return {"status": "unknown"} - + def cancel_task(self, task_id: str): """Cancel a task.""" try: req = urllib.request.Request( - f"{self.SERVER_URL}/tasks/{task_id}/cancel", - method='POST' + f"{self.SERVER_URL}/tasks/{task_id}/cancel", method="POST" ) with urllib.request.urlopen(req): pass @@ -251,6 +256,7 @@ class InferenceClient: # Singleton _client: Optional[InferenceClient] = None + def get_client() -> InferenceClient: global _client if _client is None: diff --git a/flake.nix b/flake.nix index 35faf44..7b81f1b 100644 --- a/flake.nix +++ b/flake.nix @@ -48,8 +48,8 @@ export ROCM_PATH="${pkgs.rocmPackages.clr}" export HSA_OVERRIDE_GFX_VERSION="11.0.0" # RX 7900 (RDNA 3 / gfx1100) - # LD_LIBRARY_PATH: ROCm、C++標準ライブラリ、その他必要なライブラリ - export LD_LIBRARY_PATH="${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib:${pkgs.zstd.out}/lib:${pkgs.rocmPackages.clr}/lib:${pkgs.rocmPackages.rocm-runtime}/lib:$LD_LIBRARY_PATH" + # LD_LIBRARY_PATH: ROCm libraries FIRST (critical for GPU inference) + export LD_LIBRARY_PATH="${pkgs.rocmPackages.clr}/lib:${pkgs.rocmPackages.rocm-runtime}/lib:${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib:${pkgs.zstd.out}/lib:$LD_LIBRARY_PATH" # venvのセットアップ VENV_DIR="$PWD/.venv" @@ -89,12 +89,11 @@ export BLENDER_USER_ADDONS="$BLENDER_USER_SCRIPTS/addons" # 環境変数をファイルに保存(サーバープロセス用) + # CRITICAL: ROCm library paths MUST come first for GPU inference cat > "$PWD/.env" << EOF -LD_LIBRARY_PATH=${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib:${pkgs.zstd.out}/lib:${pkgs.rocmPackages.clr}/lib:${pkgs.rocmPackages.rocm-runtime}/lib +LD_LIBRARY_PATH=${pkgs.rocmPackages.clr}/lib:${pkgs.rocmPackages.rocm-runtime}/lib:${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib:${pkgs.zstd.out}/lib ROCM_PATH=${pkgs.rocmPackages.clr} HSA_OVERRIDE_GFX_VERSION=11.0.0 -PYTORCH_ROCM_ARCH=gfx1100 -ROCBLAS_TENSILE_LIBPATH=${pkgs.rocmPackages.clr}/lib/rocblas/library EOF echo "[Setup] Environment ready with GPU support" diff --git a/operators/generate_mask.py b/operators/generate_mask.py index ec868fc..7c034e5 100644 --- a/operators/generate_mask.py +++ b/operators/generate_mask.py @@ -143,54 +143,109 @@ class SEQUENCER_OT_generate_face_mask(Operator): def _check_cache(self, cache_dir: str, expected_frames: int) -> bool: """Check if cached masks exist and are complete. - + Args: cache_dir: Path to cache directory expected_frames: Number of frames expected - + Returns: - True if cache exists and has at least 90% of expected frames + True if cache exists and is valid """ if not os.path.exists(cache_dir): return False - - mask_files = [f for f in os.listdir(cache_dir) + + # Check for MP4 video (new format) + mask_video = os.path.join(cache_dir, "mask.mp4") + if os.path.exists(mask_video): + # Verify video has expected number of frames + import cv2 + cap = cv2.VideoCapture(mask_video) + if cap.isOpened(): + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + cap.release() + # Accept cache if at least 90% of frames exist + return frame_count >= expected_frames * 0.9 + cap.release() + return False + + # Fallback: check for PNG sequence (backward compatibility) + mask_files = [f for f in os.listdir(cache_dir) if f.startswith("mask_") and f.endswith(".png")] - + # Accept cache if at least 90% of frames exist - # (some frames may have been skipped due to read errors) return len(mask_files) >= expected_frames * 0.9 - def _add_mask_strip(self, context, source_strip_name: str, mask_dir: str): - """Add mask image sequence as a new strip.""" + def _add_mask_strip(self, context, source_strip_name: str, mask_path: str): + """Add mask video as a new strip. + + Args: + context: Blender context + source_strip_name: Name of the source video strip + mask_path: Path to mask video file or directory (for backward compatibility) + """ scene = context.scene seq_editor = scene.sequence_editor - + if not seq_editor: return - + # Find source strip (Blender 5.0 uses 'strips' instead of 'sequences') source_strip = seq_editor.strips.get(source_strip_name) if not source_strip: return - - # Get first mask image - mask_files = sorted([ - f for f in os.listdir(mask_dir) - if f.startswith("mask_") and f.endswith(".png") - ]) - - if not mask_files: - return - - first_mask = os.path.join(mask_dir, mask_files[0]) - + + # Check if mask_path is a video file or directory (backward compatibility) + if os.path.isfile(mask_path): + # New format: single MP4 file + mask_video = mask_path + else: + # Old format: directory with PNG sequence (backward compatibility) + mask_video = os.path.join(mask_path, "mask.mp4") + if not os.path.exists(mask_video): + # Fallback to PNG sequence + mask_files = sorted([ + f for f in os.listdir(mask_path) + if f.startswith("mask_") and f.endswith(".png") + ]) + if not mask_files: + return + first_mask = os.path.join(mask_path, mask_files[0]) + self._add_mask_strip_png_sequence(context, source_strip_name, mask_path, mask_files, first_mask) + return + # Find an empty channel used_channels = {s.channel for s in seq_editor.strips} new_channel = source_strip.channel + 1 while new_channel in used_channels: new_channel += 1 - + + # Add movie strip (Blender 5.0 API) + mask_strip = seq_editor.strips.new_movie( + name=f"{source_strip_name}_mask", + filepath=mask_video, + channel=new_channel, + frame_start=source_strip.frame_final_start, + ) + + # Set blend mode for mask + mask_strip.blend_type = 'ALPHA_OVER' + mask_strip.blend_alpha = 0.5 + + def _add_mask_strip_png_sequence(self, context, source_strip_name, mask_dir, mask_files, first_mask): + """Backward compatibility: Add PNG sequence as mask strip.""" + scene = context.scene + seq_editor = scene.sequence_editor + source_strip = seq_editor.strips.get(source_strip_name) + + if not source_strip: + return + + # Find an empty channel + used_channels = {s.channel for s in seq_editor.strips} + new_channel = source_strip.channel + 1 + while new_channel in used_channels: + new_channel += 1 + # Add image sequence (Blender 5.0 API) mask_strip = seq_editor.strips.new_image( name=f"{source_strip_name}_mask", @@ -198,11 +253,11 @@ class SEQUENCER_OT_generate_face_mask(Operator): channel=new_channel, frame_start=source_strip.frame_final_start, ) - + # Add remaining frames for mask_file in mask_files[1:]: mask_strip.elements.append(mask_file) - + # Set blend mode for mask mask_strip.blend_type = 'ALPHA_OVER' mask_strip.blend_alpha = 0.5 diff --git a/server/main.py b/server/main.py index 78f6a2f..0f9a42f 100644 --- a/server/main.py +++ b/server/main.py @@ -8,6 +8,29 @@ GPU-accelerated face detection using ONNX Runtime. import os import sys import platform + +# CRITICAL: Fix LD_LIBRARY_PATH before importing cv2 or torch +# cv2 adds its own lib path to the front, which can override ROCm libraries +def fix_library_path(): + """Ensure ROCm libraries are loaded before cv2's bundled libraries.""" + ld_path = os.environ.get('LD_LIBRARY_PATH', '') + + # Split and filter paths + paths = [p for p in ld_path.split(':') if p] + + # Separate ROCm/GPU paths from other paths + rocm_paths = [p for p in paths if 'rocm' in p.lower() or 'clr-' in p or 'hip' in p.lower()] + other_paths = [p for p in paths if p not in rocm_paths] + + # Rebuild with ROCm paths first + if rocm_paths: + new_ld_path = ':'.join(rocm_paths + other_paths) + os.environ['LD_LIBRARY_PATH'] = new_ld_path + print(f"[FaceMask] Fixed LD_LIBRARY_PATH to prioritize ROCm libraries") + +# Fix library path BEFORE any other imports +fix_library_path() + import threading import uuid import queue @@ -61,11 +84,15 @@ class GenerateRequest(BaseModel): mask_scale: float = 1.5 def process_video_task(task_id: str, req: GenerateRequest): - """Background task to process video.""" + """Background task to process video with async MP4 output.""" + writer = None + write_queue = None + writer_thread = None + try: tasks[task_id].status = TaskStatus.PROCESSING cancel_event = cancel_events.get(task_id) - + # Verify video exists if not os.path.exists(req.video_path): tasks[task_id].status = TaskStatus.FAILED @@ -78,28 +105,60 @@ def process_video_task(task_id: str, req: GenerateRequest): conf_threshold=req.conf_threshold, iou_threshold=req.iou_threshold ) - # Ensure model is loaded _ = detector.model - + # Open video cap = cv2.VideoCapture(req.video_path) if not cap.isOpened(): tasks[task_id].status = TaskStatus.FAILED tasks[task_id].message = "Failed to open video" return - - # Determine frame range + + # Get video properties + fps = cap.get(cv2.CAP_PROP_FPS) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) total_video_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) end_frame = min(req.end_frame, total_video_frames - 1) frames_to_process = end_frame - req.start_frame + 1 - + tasks[task_id].total = frames_to_process - + # Ensure output directory exists os.makedirs(req.output_dir, exist_ok=True) - - print(f"Starting processing: {req.video_path} ({frames_to_process} frames)") - + + # Setup MP4 writer (grayscale) + output_video_path = os.path.join(req.output_dir, "mask.mp4") + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + writer = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height), isColor=False) + + if not writer.isOpened(): + tasks[task_id].status = TaskStatus.FAILED + tasks[task_id].message = "Failed to create video writer" + cap.release() + return + + # Async writer setup + write_queue = queue.Queue(maxsize=30) # Buffer up to 30 frames + writer_running = threading.Event() + writer_running.set() + + def async_writer(): + """Background thread for writing frames to video.""" + while writer_running.is_set() or not write_queue.empty(): + try: + mask = write_queue.get(timeout=0.1) + if mask is not None: + writer.write(mask) + write_queue.task_done() + except queue.Empty: + continue + + writer_thread = threading.Thread(target=async_writer, daemon=True) + writer_thread.start() + + print(f"Starting processing: {req.video_path} ({frames_to_process} frames) -> {output_video_path}") + # Process loop current_count = 0 for frame_idx in range(req.start_frame, end_frame + 1): @@ -107,38 +166,44 @@ def process_video_task(task_id: str, req: GenerateRequest): tasks[task_id].status = TaskStatus.CANCELLED tasks[task_id].message = "Cancelled by user" break - + # Read frame cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) ret, frame = cap.read() - + if ret: # Detect detections = detector.detect(frame) - + # Generate mask mask = detector.generate_mask( frame.shape, detections, mask_scale=req.mask_scale ) - - # Save - mask_filename = f"mask_{current_count:06d}.png" # Note: using relative index for filename - mask_path = os.path.join(req.output_dir, mask_filename) - cv2.imwrite(mask_path, mask) - + + # Async write to queue + write_queue.put(mask) + # Update progress current_count += 1 tasks[task_id].progress = current_count - + + # Cleanup + writer_running.clear() + write_queue.join() # Wait for all frames to be written + if writer_thread: + writer_thread.join(timeout=5) + cap.release() - + if writer: + writer.release() + if tasks[task_id].status == TaskStatus.PROCESSING: tasks[task_id].status = TaskStatus.COMPLETED - tasks[task_id].result_path = req.output_dir + tasks[task_id].result_path = output_video_path # Return video path tasks[task_id].message = "Processing completed successfully" - print(f"Task {task_id} completed.") + print(f"Task {task_id} completed: {output_video_path}") except Exception as e: tasks[task_id].status = TaskStatus.FAILED @@ -230,12 +295,21 @@ def log_startup_diagnostics(): for var in rocm_vars: value = os.environ.get(var) if value: - # Truncate very long values - if len(value) > 200: - display_value = value[:200] + "... (truncated)" + # For LD_LIBRARY_PATH, show if ROCm paths are included + if var == 'LD_LIBRARY_PATH': + has_rocm = 'rocm' in value.lower() or 'clr-' in value + has_hip = 'hip' in value.lower() + print(f" {var}: {value[:100]}...") + print(f" Contains ROCm paths: {has_rocm}") + print(f" Contains HIP paths: {has_hip}") + if not has_rocm: + print(f" ⚠️ WARNING: ROCm library paths not found!") else: - display_value = value - print(f" {var}: {display_value}") + if len(value) > 200: + display_value = value[:200] + "... (truncated)" + else: + display_value = value + print(f" {var}: {display_value}") else: print(f" {var}: (not set)")