255 lines
8.5 KiB
Python
255 lines
8.5 KiB
Python
"""
|
|
Async Mask Generator using Thread + Queue + Timer pattern.
|
|
|
|
This module provides non-blocking face mask generation for Blender.
|
|
Heavy processing (face detection) runs in a worker thread while
|
|
Blender's UI remains responsive via bpy.app.timers.
|
|
"""
|
|
|
|
import os
|
|
import threading
|
|
import queue
|
|
from typing import Optional, Callable
|
|
from pathlib import Path
|
|
|
|
# Will be imported when running inside Blender
|
|
bpy = None  # Placeholder; AsyncMaskGenerator.start() rebinds this to the real bpy module.
|
|
|
|
|
|
class AsyncMaskGenerator:
    """
    Run face-mask generation in the background without blocking Blender's UI.

    Uses the Thread + Queue + Timer pattern:
    - Worker thread: asks the inference server to generate the masks and
      polls the task status. It never touches ``bpy``.
    - Main-thread timer (``bpy.app.timers``): drains the queues and invokes
      the ``on_progress`` / ``on_complete`` callbacks.
    """

    def __init__(self):
        # Worker -> main thread: one terminal message per run, a tuple of
        # ("done", result_path), ("error", message) or ("cancelled", None).
        self.result_queue: queue.Queue = queue.Queue()
        # Worker -> main thread: ("progress", frames_done) messages.
        self.progress_queue: queue.Queue = queue.Queue()
        self.worker_thread: Optional[threading.Thread] = None
        # True while a run is active; clearing it asks the worker to stop.
        self.is_running: bool = False
        self.total_frames: int = 0
        self.current_frame: int = 0
        self._on_complete: Optional[Callable] = None
        self._on_progress: Optional[Callable] = None

    def start(
        self,
        video_path: str,
        output_dir: str,
        start_frame: int,
        end_frame: int,
        fps: float,
        conf_threshold: float = 0.5,
        iou_threshold: float = 0.45,
        mask_scale: float = 1.5,
        on_complete: Optional[Callable] = None,
        on_progress: Optional[Callable] = None,
    ):
        """
        Start asynchronous mask generation.

        Must be called from Blender's main thread: it imports ``bpy`` and
        registers a ``bpy.app.timers`` callback.

        Args:
            video_path: Path to source video file
            output_dir: Directory to save mask images (created if missing)
            start_frame: First frame to process
            end_frame: Last frame to process
            fps: Video frame rate (accepted for API compatibility; the
                worker currently does not forward it to the server)
            conf_threshold: Detection confidence threshold
            iou_threshold: NMS IoU threshold
            mask_scale: Mask region scale factor
            on_complete: Called from the main thread as
                ``on_complete(status, data)`` where status is "done",
                "error" or "cancelled"
            on_progress: Called from the main thread as
                ``on_progress(current_frame, total_frames)``

        Raises:
            RuntimeError: If a run is active or the previous worker thread
                has not finished winding down yet.
        """
        global bpy
        import bpy as _bpy

        bpy = _bpy

        # A worker that is still shutting down after cancel() shares the
        # run flag and the queues, so it counts as "in progress" too.
        if self.is_running or self._worker_alive():
            raise RuntimeError("Mask generation already in progress")

        print(f"[FaceMask] Starting mask generation: {video_path}")
        print(f"[FaceMask] Output directory: {output_dir}")
        print(f"[FaceMask] Frame range: {start_frame} - {end_frame}")

        # Create the directory before flipping any state, so a filesystem
        # error cannot leave the generator stuck in the running state.
        Path(output_dir).mkdir(parents=True, exist_ok=True)

        # Messages left over from an earlier run must not be mistaken for
        # results of this one.
        self._discard_stale_messages()

        self.total_frames = end_frame - start_frame + 1
        self.current_frame = 0
        self._on_complete = on_complete
        self._on_progress = on_progress
        # Must be set before the worker starts: its poll loop checks it.
        self.is_running = True

        try:
            self.worker_thread = threading.Thread(
                target=self._worker,
                args=(
                    video_path,
                    output_dir,
                    start_frame,
                    end_frame,
                    fps,
                    conf_threshold,
                    iou_threshold,
                    mask_scale,
                ),
                daemon=True,
            )
            self.worker_thread.start()

            # Register timer for main thread callbacks
            bpy.app.timers.register(
                self._check_progress,
                first_interval=0.1,
            )
        except Exception:
            # Roll back so a later start() is possible; clearing the flag
            # also stops the worker if it was already launched.
            self.is_running = False
            raise

    def cancel(self):
        """Request cancellation and wait up to 2 seconds for the worker."""
        self.is_running = False
        worker = self.worker_thread
        if worker is not None and worker.is_alive():
            worker.join(timeout=2.0)

    def _worker_alive(self) -> bool:
        """Return True while the most recently started worker thread runs."""
        worker = self.worker_thread
        return worker is not None and worker.is_alive()

    def _discard_stale_messages(self) -> None:
        """Empty both queues without processing their contents."""
        for pending in (self.progress_queue, self.result_queue):
            while True:
                try:
                    pending.get_nowait()
                except queue.Empty:
                    break

    def _worker(
        self,
        video_path: str,
        output_dir: str,
        start_frame: int,
        end_frame: int,
        fps: float,
        conf_threshold: float,
        iou_threshold: float,
        mask_scale: float,
    ):
        """
        Worker thread function. Delegates to inference server and polls status.

        Posts ("progress", n) messages to ``progress_queue`` and exactly one
        terminal message to ``result_queue``. Every failure, including a
        failed import of the inference client, is reported as
        ("error", message) so the main-thread timer can always finish.
        ``fps`` is not forwarded to the server.
        """
        try:
            # Imported inside the try block on purpose: an ImportError here
            # must surface as an "error" result instead of silently killing
            # the thread and leaving the generator marked as running.
            import time
            from .inference_client import get_client

            client = get_client()

            # Start task on server
            print("[FaceMask] Requesting generation on server...")
            task_id = client.generate_mask(
                video_path=video_path,
                output_dir=output_dir,
                start_frame=start_frame,
                end_frame=end_frame,
                conf_threshold=conf_threshold,
                iou_threshold=iou_threshold,
                mask_scale=mask_scale,
            )
            print(f"[FaceMask] Task started: {task_id}")

            # Poll loop
            while self.is_running:
                status = client.get_task_status(task_id)
                state = status.get("status")

                # Prefer the server's frame count once it reports one.
                total = status.get("total", 0)
                if total > 0:
                    self.total_frames = total

                if state == "completed":
                    final_progress = status.get("progress", self.total_frames)
                    if final_progress >= 0:
                        self.progress_queue.put(("progress", final_progress))
                    result_path = status.get(
                        "result_path",
                        os.path.join(output_dir, "detections.msgpack"),
                    )
                    self.result_queue.put(("done", result_path))
                    return

                elif state == "failed":
                    error_msg = status.get("message", "Unknown server error")
                    print(f"[FaceMask] Server task failed: {error_msg}")
                    self.result_queue.put(("error", error_msg))
                    return

                elif state == "cancelled":
                    self.result_queue.put(("cancelled", None))
                    return

                # Report progress
                progress = status.get("progress", 0)
                if progress >= 0:
                    self.progress_queue.put(("progress", progress))

                time.sleep(0.5)

            # If loop exited but task not done, cancel server task
            print("[FaceMask] Cancelling server task...")
            client.cancel_task(task_id)
            self.result_queue.put(("cancelled", None))

        except Exception as e:
            import traceback

            print(f"[FaceMask] Error: {e}")
            traceback.print_exc()
            self.result_queue.put(("error", str(e)))

    def _check_progress(self) -> Optional[float]:
        """
        Timer callback for checking progress from main thread.

        Returns:
            Time until next call, or None to unregister.
        """
        # Sample liveness BEFORE reading the result queue: if the worker was
        # already dead here and the queue is empty below, no result can
        # arrive any more.
        worker_alive = self._worker_alive()

        # Process all pending progress updates
        while True:
            try:
                msg_type, data = self.progress_queue.get_nowait()
            except queue.Empty:
                break
            if msg_type == "progress":
                self.current_frame = data
                if self._on_progress:
                    self._on_progress(self.current_frame, self.total_frames)

        # Check for completion
        try:
            msg_type, data = self.result_queue.get_nowait()
        except queue.Empty:
            # Keep polling while the run is active, and also while a
            # cancelled worker is still winding down so that its final
            # "cancelled"/"error" result is delivered rather than dropped.
            if self.is_running or worker_alive:
                return 0.1  # Check again in 100ms
            return None  # Unregister timer

        self.is_running = False

        # Ensure UI receives a final progress update before completion.
        if (
            msg_type == "done"
            and self.total_frames > 0
            and self.current_frame < self.total_frames
            and self._on_progress
        ):
            self.current_frame = self.total_frames
            self._on_progress(self.current_frame, self.total_frames)

        if self._on_complete:
            self._on_complete(msg_type, data)

        return None  # Unregister timer
|
|
|
|
|
|
# Global instance for easy access from operators
|
|
_generator: Optional[AsyncMaskGenerator] = None  # Created lazily by get_generator().
|
|
|
|
|
|
def get_generator() -> AsyncMaskGenerator:
    """Return the shared mask generator, building it on first access."""
    global _generator
    shared = _generator
    if shared is None:
        # First call: create the instance and remember it for later callers.
        shared = AsyncMaskGenerator()
        _generator = shared
    return shared
|