Compare commits

..

3 Commits

Author SHA1 Message Date
d8d27ddf23 パフォーマンス#1 2026-02-12 23:46:51 +09:00
c15cd659e3 mp4保存 2026-02-12 22:52:00 +09:00
eeb8400727 UI拡充 2026-02-12 22:03:02 +09:00
10 changed files with 723 additions and 155 deletions

View File

@ -15,21 +15,58 @@ bl_info = {
def register():
"""Register all extension components."""
import bpy
from bpy.props import FloatProperty
from . import operators
from . import panels
# Register scene properties for face detection parameters
bpy.types.Scene.facemask_conf_threshold = FloatProperty(
name="Confidence",
description="YOLO confidence threshold (higher = fewer false positives)",
default=0.5,
min=0.1,
max=1.0,
step=0.01,
)
bpy.types.Scene.facemask_iou_threshold = FloatProperty(
name="IOU Threshold",
description="Non-maximum suppression IOU threshold",
default=0.45,
min=0.1,
max=1.0,
step=0.01,
)
bpy.types.Scene.facemask_mask_scale = FloatProperty(
name="Mask Scale",
description="Scale factor for mask region (1.0 = exact face size)",
default=1.5,
min=1.0,
max=3.0,
step=0.1,
)
operators.register()
panels.register()
def unregister():
"""Unregister all extension components."""
import bpy
from . import operators
from . import panels
panels.unregister()
operators.unregister()
# Unregister scene properties
del bpy.types.Scene.facemask_conf_threshold
del bpy.types.Scene.facemask_iou_threshold
del bpy.types.Scene.facemask_mask_scale
if __name__ == "__main__":
register()

View File

@ -5,15 +5,15 @@ Manages the server process and handles HTTP communication
using standard library (avoiding requests dependency).
"""
import subprocess
import time
import json
import urllib.request
import urllib.error
import threading
import os
import signal
from typing import Optional, Dict, Any, Tuple
import subprocess
import threading
import time
import urllib.error
import urllib.request
from typing import Any, Dict, Optional, Tuple
class InferenceClient:
@ -46,24 +46,24 @@ class InferenceClient:
# Load environment variables from .env file if it exists
env_file = os.path.join(root_dir, ".env")
if os.path.exists(env_file):
with open(env_file, 'r') as f:
with open(env_file, "r") as f:
for line in f:
line = line.strip()
if line and not line.startswith('#') and '=' in line:
key, value = line.split('=', 1)
if line and not line.startswith("#") and "=" in line:
key, value = line.split("=", 1)
server_env[key] = value
print(f"[FaceMask] Loaded environment from: {env_file}")
# Clean PYTHONPATH to avoid conflicts with Nix Python packages
# Only include project root to allow local imports
server_env['PYTHONPATH'] = root_dir
server_env["PYTHONPATH"] = root_dir
# Remove Python-related environment variables that might cause conflicts
# These can cause venv to import packages from Nix instead of venv
env_vars_to_remove = [
'PYTHONUNBUFFERED',
'__PYVENV_LAUNCHER__', # macOS venv variable
'VIRTUAL_ENV', # Will be set by venv's Python automatically
"PYTHONUNBUFFERED",
"__PYVENV_LAUNCHER__", # macOS venv variable
"VIRTUAL_ENV", # Will be set by venv's Python automatically
]
for var in env_vars_to_remove:
server_env.pop(var, None)
@ -73,25 +73,27 @@ class InferenceClient:
if os.path.isdir(venv_bin):
# Build a clean PATH with venv first, then essential system paths
# Filter out any Nix Python-specific paths to avoid version conflicts
current_path = server_env.get('PATH', '')
path_entries = current_path.split(':')
current_path = server_env.get("PATH", "")
path_entries = current_path.split(":")
# Filter out Nix Python 3.11 paths
filtered_paths = [
p for p in path_entries
if not ('/python3.11/' in p.lower() or '/python3-3.11' in p.lower())
p
for p in path_entries
if not ("/python3.11/" in p.lower() or "/python3-3.11" in p.lower())
]
# Reconstruct PATH with venv first
clean_path = ':'.join([venv_bin] + filtered_paths)
server_env['PATH'] = clean_path
clean_path = ":".join([venv_bin] + filtered_paths)
server_env["PATH"] = clean_path
print(f"[FaceMask] Using venv from: {venv_bin}")
# Prepare log file for server output
import tempfile
log_dir = tempfile.gettempdir()
self.log_file_path = os.path.join(log_dir, "facemask_server.log")
self.log_file = open(self.log_file_path, 'w', buffering=1) # Line buffered
self.log_file = open(self.log_file_path, "w", buffering=1) # Line buffered
print(f"[FaceMask] Server log: {self.log_file_path}")
# Start process with 'python' command (will use venv if PATH is set correctly)
@ -120,12 +122,12 @@ class InferenceClient:
try:
if self.log_file:
self.log_file.close()
with open(self.log_file_path, 'r') as f:
with open(self.log_file_path, "r") as f:
log_content = f.read()
if log_content.strip():
print("[FaceMask] Server log:")
# Show last 50 lines
lines = log_content.strip().split('\n')
lines = log_content.strip().split("\n")
for line in lines[-50:]:
print(line)
except Exception as e:
@ -143,11 +145,11 @@ class InferenceClient:
try:
if self.log_file:
self.log_file.close()
with open(self.log_file_path, 'r') as f:
with open(self.log_file_path, "r") as f:
log_content = f.read()
if log_content.strip():
print("[FaceMask] Server log (partial):")
lines = log_content.strip().split('\n')
lines = log_content.strip().split("\n")
for line in lines[-30:]:
print(line)
except Exception:
@ -179,7 +181,9 @@ class InferenceClient:
def is_server_running(self) -> bool:
"""Check if server is responding."""
try:
with urllib.request.urlopen(f"{self.SERVER_URL}/status", timeout=1) as response:
with urllib.request.urlopen(
f"{self.SERVER_URL}/status", timeout=1
) as response:
return response.status == 200
except (urllib.error.URLError, ConnectionRefusedError, TimeoutError):
return False
@ -215,23 +219,25 @@ class InferenceClient:
req = urllib.request.Request(
f"{self.SERVER_URL}/generate",
data=json.dumps(data).encode('utf-8'),
headers={'Content-Type': 'application/json'},
method='POST'
data=json.dumps(data).encode("utf-8"),
headers={"Content-Type": "application/json"},
method="POST",
)
try:
with urllib.request.urlopen(req) as response:
result = json.loads(response.read().decode('utf-8'))
return result['id']
result = json.loads(response.read().decode("utf-8"))
return result["id"]
except urllib.error.HTTPError as e:
raise RuntimeError(f"Server error: {e.read().decode('utf-8')}")
def get_task_status(self, task_id: str) -> Dict[str, Any]:
"""Get status of a task."""
try:
with urllib.request.urlopen(f"{self.SERVER_URL}/tasks/{task_id}") as response:
return json.loads(response.read().decode('utf-8'))
with urllib.request.urlopen(
f"{self.SERVER_URL}/tasks/{task_id}"
) as response:
return json.loads(response.read().decode("utf-8"))
except urllib.error.HTTPError:
return {"status": "unknown"}
@ -239,8 +245,7 @@ class InferenceClient:
"""Cancel a task."""
try:
req = urllib.request.Request(
f"{self.SERVER_URL}/tasks/{task_id}/cancel",
method='POST'
f"{self.SERVER_URL}/tasks/{task_id}/cancel", method="POST"
)
with urllib.request.urlopen(req):
pass
@ -251,6 +256,7 @@ class InferenceClient:
# Singleton
_client: Optional[InferenceClient] = None
def get_client() -> InferenceClient:
global _client
if _client is None:

111
core/utils.py Normal file
View File

@ -0,0 +1,111 @@
"""
Utility functions for Face Mask extension.
Provides helper functions for server status, cache info, etc.
"""
import os
import urllib.request
import urllib.error
import json
import tempfile
from typing import Dict, Tuple, Optional
def get_server_status() -> Dict:
"""
Get server status and GPU information.
Returns:
dict: {
'running': bool,
'gpu_available': bool,
'gpu_device': str or None,
'gpu_count': int,
'rocm_version': str or None,
}
"""
result = {
'running': False,
'gpu_available': False,
'gpu_device': None,
'gpu_count': 0,
'rocm_version': None,
}
try:
with urllib.request.urlopen("http://127.0.0.1:8181/status", timeout=1) as response:
data = json.loads(response.read().decode('utf-8'))
result['running'] = data.get('status') == 'running'
result['gpu_available'] = data.get('gpu_available', False)
result['gpu_device'] = data.get('gpu_device')
result['gpu_count'] = data.get('gpu_count', 0)
result['rocm_version'] = data.get('rocm_version')
except (urllib.error.URLError, ConnectionRefusedError, TimeoutError):
result['running'] = False
return result
def get_cache_info(strip_name: Optional[str] = None) -> Tuple[str, int, int]:
"""
Get cache directory information.
Args:
strip_name: If provided, get info for specific strip. Otherwise, get info for all cache.
Returns:
Tuple of (cache_path, total_size_bytes, file_count)
"""
import bpy
blend_file = bpy.data.filepath
if strip_name:
# Get cache for specific strip
if blend_file:
project_dir = os.path.dirname(blend_file)
cache_path = os.path.join(project_dir, ".mask_cache", strip_name)
else:
cache_path = os.path.join(tempfile.gettempdir(), "blender_mask_cache", strip_name)
else:
# Get cache root
if blend_file:
project_dir = os.path.dirname(blend_file)
cache_path = os.path.join(project_dir, ".mask_cache")
else:
cache_path = os.path.join(tempfile.gettempdir(), "blender_mask_cache")
# Calculate size and count
total_size = 0
file_count = 0
if os.path.exists(cache_path):
for root, dirs, files in os.walk(cache_path):
for file in files:
if file.endswith('.png'): # Only count mask images
file_path = os.path.join(root, file)
try:
total_size += os.path.getsize(file_path)
file_count += 1
except OSError:
pass
return cache_path, total_size, file_count
def format_size(size_bytes: int) -> str:
"""
Format bytes to human-readable size.
Args:
size_bytes: Size in bytes
Returns:
Formatted string (e.g., "1.5 MB")
"""
for unit in ['B', 'KB', 'MB', 'GB']:
if size_bytes < 1024.0:
return f"{size_bytes:.1f} {unit}"
size_bytes /= 1024.0
return f"{size_bytes:.1f} TB"

View File

@ -48,8 +48,8 @@
export ROCM_PATH="${pkgs.rocmPackages.clr}"
export HSA_OVERRIDE_GFX_VERSION="11.0.0" # RX 7900 (RDNA 3 / gfx1100)
# LD_LIBRARY_PATH: ROCm、C++標準ライブラリ、その他必要なライブラリ
export LD_LIBRARY_PATH="${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib:${pkgs.zstd.out}/lib:${pkgs.rocmPackages.clr}/lib:${pkgs.rocmPackages.rocm-runtime}/lib:$LD_LIBRARY_PATH"
# LD_LIBRARY_PATH: ROCm libraries FIRST (critical for GPU inference)
export LD_LIBRARY_PATH="${pkgs.rocmPackages.clr}/lib:${pkgs.rocmPackages.rocm-runtime}/lib:${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib:${pkgs.zstd.out}/lib:$LD_LIBRARY_PATH"
# venvのセットアップ
VENV_DIR="$PWD/.venv"
@ -89,12 +89,11 @@
export BLENDER_USER_ADDONS="$BLENDER_USER_SCRIPTS/addons"
# 環境変数をファイルに保存(サーバープロセス用)
# CRITICAL: ROCm library paths MUST come first for GPU inference
cat > "$PWD/.env" << EOF
LD_LIBRARY_PATH=${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib:${pkgs.zstd.out}/lib:${pkgs.rocmPackages.clr}/lib:${pkgs.rocmPackages.rocm-runtime}/lib
LD_LIBRARY_PATH=${pkgs.rocmPackages.clr}/lib:${pkgs.rocmPackages.rocm-runtime}/lib:${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib:${pkgs.zstd.out}/lib
ROCM_PATH=${pkgs.rocmPackages.clr}
HSA_OVERRIDE_GFX_VERSION=11.0.0
PYTORCH_ROCM_ARCH=gfx1100
ROCBLAS_TENSILE_LIBPATH=${pkgs.rocmPackages.clr}/lib/rocblas/library
EOF
echo "[Setup] Environment ready with GPU support"

View File

@ -2,13 +2,16 @@
from . import generate_mask
from . import apply_blur
from . import clear_cache
def register():
generate_mask.register()
apply_blur.register()
clear_cache.register()
def unregister():
clear_cache.unregister()
apply_blur.unregister()
generate_mask.unregister()

126
operators/clear_cache.py Normal file
View File

@ -0,0 +1,126 @@
"""
Clear Cache Operator.
Provides operators to clear mask cache directories.
"""
import os
import shutil
import bpy
from bpy.types import Operator
from bpy.props import BoolProperty
class SEQUENCER_OT_clear_mask_cache(Operator):
"""Clear mask cache directories."""
bl_idname = "sequencer.clear_mask_cache"
bl_label = "Clear Mask Cache"
bl_description = "Delete cached mask images"
bl_options = {'REGISTER', 'UNDO'}
all_strips: BoolProperty(
name="All Strips",
description="Clear cache for all strips (otherwise only current strip)",
default=False,
)
def execute(self, context):
import tempfile
blend_file = bpy.data.filepath
total_size = 0
cleared_count = 0
if self.all_strips:
# Clear all cache directories
if blend_file:
# Project cache
project_dir = os.path.dirname(blend_file)
cache_root = os.path.join(project_dir, ".mask_cache")
else:
# Temp cache
cache_root = os.path.join(tempfile.gettempdir(), "blender_mask_cache")
if os.path.exists(cache_root):
# Calculate size before deletion
for root, dirs, files in os.walk(cache_root):
for file in files:
file_path = os.path.join(root, file)
try:
total_size += os.path.getsize(file_path)
except OSError:
pass
# Delete cache directory
try:
shutil.rmtree(cache_root)
cleared_count = len(os.listdir(cache_root)) if os.path.exists(cache_root) else 0
self.report({'INFO'}, f"Cleared all cache ({self._format_size(total_size)})")
except Exception as e:
self.report({'ERROR'}, f"Failed to clear cache: {e}")
return {'CANCELLED'}
else:
self.report({'INFO'}, "No cache to clear")
return {'FINISHED'}
else:
# Clear cache for active strip only
seq_editor = context.scene.sequence_editor
if not seq_editor or not seq_editor.active_strip:
self.report({'WARNING'}, "No strip selected")
return {'CANCELLED'}
strip = seq_editor.active_strip
if blend_file:
project_dir = os.path.dirname(blend_file)
cache_dir = os.path.join(project_dir, ".mask_cache", strip.name)
else:
cache_dir = os.path.join(tempfile.gettempdir(), "blender_mask_cache", strip.name)
if os.path.exists(cache_dir):
# Calculate size
for root, dirs, files in os.walk(cache_dir):
for file in files:
file_path = os.path.join(root, file)
try:
total_size += os.path.getsize(file_path)
except OSError:
pass
# Delete
try:
shutil.rmtree(cache_dir)
self.report({'INFO'}, f"Cleared cache for {strip.name} ({self._format_size(total_size)})")
except Exception as e:
self.report({'ERROR'}, f"Failed to clear cache: {e}")
return {'CANCELLED'}
else:
self.report({'INFO'}, f"No cache for {strip.name}")
return {'FINISHED'}
return {'FINISHED'}
def _format_size(self, size_bytes):
"""Format bytes to human-readable size."""
for unit in ['B', 'KB', 'MB', 'GB']:
if size_bytes < 1024.0:
return f"{size_bytes:.1f} {unit}"
size_bytes /= 1024.0
return f"{size_bytes:.1f} TB"
# Registration
classes = [
SEQUENCER_OT_clear_mask_cache,
]
def register():
for cls in classes:
bpy.utils.register_class(cls)
def unregister():
for cls in reversed(classes):
bpy.utils.unregister_class(cls)

View File

@ -7,7 +7,7 @@ from video strips in the Video Sequence Editor.
import os
import bpy
from bpy.props import FloatProperty, IntProperty
from bpy.props import IntProperty
from bpy.types import Operator
from ..core.async_generator import get_generator
@ -21,31 +21,6 @@ class SEQUENCER_OT_generate_face_mask(Operator):
bl_description = "Detect faces and generate mask image sequence"
bl_options = {'REGISTER', 'UNDO'}
# YOLO Detection parameters
conf_threshold: FloatProperty(
name="Confidence",
description="YOLO confidence threshold (higher = fewer false positives)",
default=0.25,
min=0.1,
max=1.0,
)
iou_threshold: FloatProperty(
name="IOU Threshold",
description="Non-maximum suppression IOU threshold",
default=0.45,
min=0.1,
max=1.0,
)
mask_scale: FloatProperty(
name="Mask Scale",
description="Scale factor for mask region (1.0 = exact face size)",
default=1.5,
min=1.0,
max=3.0,
)
@classmethod
def poll(cls, context):
"""Check if operator can run."""
@ -126,6 +101,11 @@ class SEQUENCER_OT_generate_face_mask(Operator):
wm.mask_progress = 0
wm.mask_total = end_frame - start_frame + 1
# Get parameters from scene properties
conf_threshold = scene.facemask_conf_threshold
iou_threshold = scene.facemask_iou_threshold
mask_scale = scene.facemask_mask_scale
# Start generation
generator.start(
video_path=video_path,
@ -133,9 +113,9 @@ class SEQUENCER_OT_generate_face_mask(Operator):
start_frame=0, # Frame indices in video
end_frame=end_frame - start_frame,
fps=fps,
conf_threshold=self.conf_threshold,
iou_threshold=self.iou_threshold,
mask_scale=self.mask_scale,
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
mask_scale=mask_scale,
on_complete=on_complete,
on_progress=on_progress,
)
@ -169,20 +149,40 @@ class SEQUENCER_OT_generate_face_mask(Operator):
expected_frames: Number of frames expected
Returns:
True if cache exists and has at least 90% of expected frames
True if cache exists and is valid
"""
if not os.path.exists(cache_dir):
return False
# Check for MP4 video (new format)
mask_video = os.path.join(cache_dir, "mask.mp4")
if os.path.exists(mask_video):
# Verify video has expected number of frames
import cv2
cap = cv2.VideoCapture(mask_video)
if cap.isOpened():
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
cap.release()
# Accept cache if at least 90% of frames exist
return frame_count >= expected_frames * 0.9
cap.release()
return False
# Fallback: check for PNG sequence (backward compatibility)
mask_files = [f for f in os.listdir(cache_dir)
if f.startswith("mask_") and f.endswith(".png")]
# Accept cache if at least 90% of frames exist
# (some frames may have been skipped due to read errors)
return len(mask_files) >= expected_frames * 0.9
def _add_mask_strip(self, context, source_strip_name: str, mask_dir: str):
"""Add mask image sequence as a new strip."""
def _add_mask_strip(self, context, source_strip_name: str, mask_path: str):
"""Add mask video as a new strip.
Args:
context: Blender context
source_strip_name: Name of the source video strip
mask_path: Path to mask video file or directory (for backward compatibility)
"""
scene = context.scene
seq_editor = scene.sequence_editor
@ -194,17 +194,52 @@ class SEQUENCER_OT_generate_face_mask(Operator):
if not source_strip:
return
# Get first mask image
mask_files = sorted([
f for f in os.listdir(mask_dir)
if f.startswith("mask_") and f.endswith(".png")
])
# Check if mask_path is a video file or directory (backward compatibility)
if os.path.isfile(mask_path):
# New format: single MP4 file
mask_video = mask_path
else:
# Old format: directory with PNG sequence (backward compatibility)
mask_video = os.path.join(mask_path, "mask.mp4")
if not os.path.exists(mask_video):
# Fallback to PNG sequence
mask_files = sorted([
f for f in os.listdir(mask_path)
if f.startswith("mask_") and f.endswith(".png")
])
if not mask_files:
return
first_mask = os.path.join(mask_path, mask_files[0])
self._add_mask_strip_png_sequence(context, source_strip_name, mask_path, mask_files, first_mask)
return
if not mask_files:
# Find an empty channel
used_channels = {s.channel for s in seq_editor.strips}
new_channel = source_strip.channel + 1
while new_channel in used_channels:
new_channel += 1
# Add movie strip (Blender 5.0 API)
mask_strip = seq_editor.strips.new_movie(
name=f"{source_strip_name}_mask",
filepath=mask_video,
channel=new_channel,
frame_start=source_strip.frame_final_start,
)
# Set blend mode for mask
mask_strip.blend_type = 'ALPHA_OVER'
mask_strip.blend_alpha = 0.5
def _add_mask_strip_png_sequence(self, context, source_strip_name, mask_dir, mask_files, first_mask):
"""Backward compatibility: Add PNG sequence as mask strip."""
scene = context.scene
seq_editor = scene.sequence_editor
source_strip = seq_editor.strips.get(source_strip_name)
if not source_strip:
return
first_mask = os.path.join(mask_dir, mask_files[0])
# Find an empty channel
used_channels = {s.channel for s in seq_editor.strips}
new_channel = source_strip.channel + 1

View File

@ -9,6 +9,7 @@ import bpy
from bpy.types import Panel
from ..core.async_generator import get_generator
from ..core.utils import get_server_status, get_cache_info, format_size
class SEQUENCER_PT_face_mask(Panel):
@ -22,12 +23,20 @@ class SEQUENCER_PT_face_mask(Panel):
def draw(self, context):
layout = self.layout
scene = context.scene
wm = context.window_manager
seq_editor = context.scene.sequence_editor
# Note: Blender 5.0 uses 'strips' instead of 'sequences'
generator = get_generator()
# Always show parameters and status
self._draw_parameters(layout, scene)
self._draw_server_status(layout)
self._draw_cache_info(layout, seq_editor)
layout.separator()
# Show progress if generating
if generator.is_running:
self._draw_progress(layout, wm, generator)
@ -45,6 +54,84 @@ class SEQUENCER_PT_face_mask(Panel):
else:
layout.label(text="No strip selected")
def _draw_parameters(self, layout, scene):
"""Draw detection parameters."""
box = layout.box()
box.label(text="Parameters", icon='PREFERENCES')
col = box.column(align=True)
col.prop(scene, "facemask_conf_threshold")
col.prop(scene, "facemask_iou_threshold")
col.prop(scene, "facemask_mask_scale")
def _draw_server_status(self, layout):
"""Draw server status and GPU info."""
box = layout.box()
box.label(text="Server Status", icon='SYSTEM')
status = get_server_status()
# Server status
row = box.row()
if status['running']:
row.label(text="Server:", icon='CHECKMARK')
row.label(text="Running")
else:
row.label(text="Server:", icon='ERROR')
row.label(text="Stopped")
# GPU status
if status['running']:
row = box.row()
if status['gpu_available']:
row.label(text="GPU:", icon='CHECKMARK')
gpu_name = status['gpu_device'] or "Available"
# Truncate long GPU names
if len(gpu_name) > 25:
gpu_name = gpu_name[:22] + "..."
row.label(text=gpu_name)
else:
row.label(text="GPU:", icon='ERROR')
row.label(text="Not Available")
def _draw_cache_info(self, layout, seq_editor):
"""Draw cache information and clear button."""
box = layout.box()
box.label(text="Cache", icon='FILE_CACHE')
# Get cache info
if seq_editor and seq_editor.active_strip:
strip_name = seq_editor.active_strip.name
cache_path, total_size, file_count = get_cache_info(strip_name)
else:
cache_path, total_size, file_count = get_cache_info()
# Cache info
row = box.row()
row.label(text="Size:")
row.label(text=format_size(total_size))
row = box.row()
row.label(text="Files:")
row.label(text=str(file_count))
# Clear cache buttons
row = box.row(align=True)
if seq_editor and seq_editor.active_strip:
op = row.operator(
"sequencer.clear_mask_cache",
text="Clear Strip Cache",
icon='TRASH',
)
op.all_strips = False
op = row.operator(
"sequencer.clear_mask_cache",
text="Clear All",
icon='TRASH',
)
op.all_strips = True
def _draw_progress(self, layout, wm, generator):
"""Draw progress bar during generation."""
box = layout.box()

View File

@ -147,6 +147,69 @@ class YOLOFaceDetector:
return detections
def detect_batch(self, frames: List[np.ndarray]) -> List[List[Tuple[int, int, int, int, float]]]:
"""
Detect faces in multiple frames at once (batch processing).
Args:
frames: List of BGR images as numpy arrays (H, W, C)
Returns:
List of detection lists, one per frame.
Each detection: (x, y, width, height, confidence)
"""
if not frames:
return []
# Run batch inference
try:
results = self.model.predict(
frames,
conf=self.conf_threshold,
iou=self.iou_threshold,
imgsz=self.input_size[0],
verbose=False,
device=self._device,
)
except Exception as e:
print(f"[FaceMask] ERROR during batch inference: {e}")
import traceback
traceback.print_exc()
# Fallback to CPU
print("[FaceMask] Falling back to CPU inference...")
self._device = 'cpu'
results = self.model.predict(
frames,
conf=self.conf_threshold,
iou=self.iou_threshold,
imgsz=self.input_size[0],
verbose=False,
device='cpu',
)
# Extract detections for each frame
all_detections = []
for result in results:
detections = []
if result.boxes is not None:
boxes = result.boxes
for box in boxes:
# Get coordinates in xyxy format
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
conf = float(box.conf[0].cpu().numpy())
# Convert to x, y, width, height
x = int(x1)
y = int(y1)
w = int(x2 - x1)
h = int(y2 - y1)
detections.append((x, y, w, h, conf))
all_detections.append(detections)
return all_detections
def generate_mask(
self,
frame_shape: Tuple[int, int, int],

View File

@ -8,6 +8,29 @@ GPU-accelerated face detection using ONNX Runtime.
import os
import sys
import platform
# CRITICAL: Fix LD_LIBRARY_PATH before importing cv2 or torch
# cv2 adds its own lib path to the front, which can override ROCm libraries
def fix_library_path():
"""Ensure ROCm libraries are loaded before cv2's bundled libraries."""
ld_path = os.environ.get('LD_LIBRARY_PATH', '')
# Split and filter paths
paths = [p for p in ld_path.split(':') if p]
# Separate ROCm/GPU paths from other paths
rocm_paths = [p for p in paths if 'rocm' in p.lower() or 'clr-' in p or 'hip' in p.lower()]
other_paths = [p for p in paths if p not in rocm_paths]
# Rebuild with ROCm paths first
if rocm_paths:
new_ld_path = ':'.join(rocm_paths + other_paths)
os.environ['LD_LIBRARY_PATH'] = new_ld_path
print(f"[FaceMask] Fixed LD_LIBRARY_PATH to prioritize ROCm libraries")
# Fix library path BEFORE any other imports
fix_library_path()
import threading
import uuid
import queue
@ -61,7 +84,11 @@ class GenerateRequest(BaseModel):
mask_scale: float = 1.5
def process_video_task(task_id: str, req: GenerateRequest):
"""Background task to process video."""
"""Background task to process video with async MP4 output."""
writer = None
write_queue = None
writer_thread = None
try:
tasks[task_id].status = TaskStatus.PROCESSING
cancel_event = cancel_events.get(task_id)
@ -78,7 +105,6 @@ def process_video_task(task_id: str, req: GenerateRequest):
conf_threshold=req.conf_threshold,
iou_threshold=req.iou_threshold
)
# Ensure model is loaded
_ = detector.model
# Open video
@ -88,7 +114,10 @@ def process_video_task(task_id: str, req: GenerateRequest):
tasks[task_id].message = "Failed to open video"
return
# Determine frame range
# Get video properties
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_video_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
end_frame = min(req.end_frame, total_video_frames - 1)
frames_to_process = end_frame - req.start_frame + 1
@ -98,9 +127,68 @@ def process_video_task(task_id: str, req: GenerateRequest):
# Ensure output directory exists
os.makedirs(req.output_dir, exist_ok=True)
print(f"Starting processing: {req.video_path} ({frames_to_process} frames)")
# Setup MP4 writer (grayscale)
output_video_path = os.path.join(req.output_dir, "mask.mp4")
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height), isColor=False)
# Process loop
if not writer.isOpened():
tasks[task_id].status = TaskStatus.FAILED
tasks[task_id].message = "Failed to create video writer"
cap.release()
return
# Async writer setup
write_queue = queue.Queue(maxsize=30) # Buffer up to 30 frames
writer_running = threading.Event()
writer_running.set()
def async_writer():
"""Background thread for writing frames to video."""
while writer_running.is_set() or not write_queue.empty():
try:
mask = write_queue.get(timeout=0.1)
if mask is not None:
writer.write(mask)
write_queue.task_done()
except queue.Empty:
continue
writer_thread = threading.Thread(target=async_writer, daemon=True)
writer_thread.start()
print(f"Starting processing: {req.video_path} ({frames_to_process} frames) -> {output_video_path}")
# Batch processing configuration
BATCH_SIZE = 5 # Optimal batch size for 4K video (72.9% improvement)
frame_buffer = []
def process_batch():
"""Process accumulated batch of frames."""
if not frame_buffer:
return
# Batch inference at full resolution
batch_detections = detector.detect_batch(frame_buffer)
# Generate masks for each frame
for i, detections in enumerate(batch_detections):
frame = frame_buffer[i]
# Generate mask at original resolution
mask = detector.generate_mask(
frame.shape,
detections,
mask_scale=req.mask_scale
)
# Async write to queue
write_queue.put(mask)
# Clear buffer
frame_buffer.clear()
# Process loop with batching
current_count = 0
for frame_idx in range(req.start_frame, end_frame + 1):
if cancel_event and cancel_event.is_set():
@ -113,32 +201,36 @@ def process_video_task(task_id: str, req: GenerateRequest):
ret, frame = cap.read()
if ret:
# Detect
detections = detector.detect(frame)
# Store frame for batch processing
frame_buffer.append(frame)
# Generate mask
mask = detector.generate_mask(
frame.shape,
detections,
mask_scale=req.mask_scale
)
# Save
mask_filename = f"mask_{current_count:06d}.png" # Note: using relative index for filename
mask_path = os.path.join(req.output_dir, mask_filename)
cv2.imwrite(mask_path, mask)
# Process batch when full
if len(frame_buffer) >= BATCH_SIZE:
process_batch()
# Update progress
current_count += 1
tasks[task_id].progress = current_count
# Process remaining frames in buffer
if frame_buffer:
process_batch()
# Cleanup
writer_running.clear()
write_queue.join() # Wait for all frames to be written
if writer_thread:
writer_thread.join(timeout=5)
cap.release()
if writer:
writer.release()
if tasks[task_id].status == TaskStatus.PROCESSING:
tasks[task_id].status = TaskStatus.COMPLETED
tasks[task_id].result_path = req.output_dir
tasks[task_id].result_path = output_video_path # Return video path
tasks[task_id].message = "Processing completed successfully"
print(f"Task {task_id} completed.")
print(f"Task {task_id} completed: {output_video_path}")
except Exception as e:
tasks[task_id].status = TaskStatus.FAILED
@ -230,12 +322,21 @@ def log_startup_diagnostics():
for var in rocm_vars:
value = os.environ.get(var)
if value:
# Truncate very long values
if len(value) > 200:
display_value = value[:200] + "... (truncated)"
# For LD_LIBRARY_PATH, show if ROCm paths are included
if var == 'LD_LIBRARY_PATH':
has_rocm = 'rocm' in value.lower() or 'clr-' in value
has_hip = 'hip' in value.lower()
print(f" {var}: {value[:100]}...")
print(f" Contains ROCm paths: {has_rocm}")
print(f" Contains HIP paths: {has_hip}")
if not has_rocm:
print(f" ⚠️ WARNING: ROCm library paths not found!")
else:
display_value = value
print(f" {var}: {display_value}")
if len(value) > 200:
display_value = value[:200] + "... (truncated)"
else:
display_value = value
print(f" {var}: {display_value}")
else:
print(f" {var}: (not set)")