blender-mask-peoples/core/async_generator.py
2026-02-17 00:20:02 +09:00

255 lines
8.5 KiB
Python

"""
Async Mask Generator using Thread + Queue + Timer pattern.
This module provides non-blocking face mask generation for Blender.
Heavy processing (face detection) is delegated to an inference server;
a worker thread polls that server while Blender's UI remains responsive
via bpy.app.timers.
"""
import os
import threading
import queue
from typing import Optional, Callable
from pathlib import Path
# Will be imported when running inside Blender.
# Module-level placeholder for Blender's ``bpy`` module. Nothing imports bpy at
# module import time; AsyncMaskGenerator.start() binds it via ``global bpy``,
# so it stays None until start() has been called at least once.
bpy = None
class AsyncMaskGenerator:
    """
    Asynchronous mask generator that doesn't block Blender's UI.

    Uses Thread + Queue + Timer pattern:
    - Worker thread: submits the job to the inference server and polls its
      status. It must not call ``bpy``; it reports only through
      ``progress_queue`` and ``result_queue``.
    - Main thread timer (``bpy.app.timers``): drains the queues and invokes
      the user callbacks.
    """

    def __init__(self):
        # Exactly one ("done" | "error" | "cancelled", data) message per run.
        self.result_queue: queue.Queue = queue.Queue()
        # ("progress", value) messages; value becomes ``current_frame``.
        self.progress_queue: queue.Queue = queue.Queue()
        self.worker_thread: Optional[threading.Thread] = None
        # Set by start(); cleared by cancel() or once a result is consumed.
        # The worker polls this flag to detect cancellation.
        self.is_running: bool = False
        self.total_frames: int = 0
        self.current_frame: int = 0
        self._on_complete: Optional[Callable] = None
        self._on_progress: Optional[Callable] = None

    def start(
        self,
        video_path: str,
        output_dir: str,
        start_frame: int,
        end_frame: int,
        fps: float,
        conf_threshold: float = 0.5,
        iou_threshold: float = 0.45,
        mask_scale: float = 1.5,
        on_complete: Optional[Callable] = None,
        on_progress: Optional[Callable] = None,
    ):
        """
        Start asynchronous mask generation with YOLO GPU acceleration.

        The detection itself runs on the inference server; this method only
        launches the polling worker thread and registers the main-thread timer.

        Args:
            video_path: Path to source video file
            output_dir: Directory to save mask images (created if missing)
            start_frame: First frame to process
            end_frame: Last frame to process
            fps: Video frame rate. Forwarded to the worker but currently not
                sent to the inference server.
            conf_threshold: YOLO confidence threshold
            iou_threshold: YOLO NMS IoU threshold
            mask_scale: Mask region scale factor
            on_complete: Callback ``(msg_type, data)`` when processing
                completes (called from main thread)
            on_progress: Callback ``(current_frame, total_frames)`` for
                progress updates (called from main thread)

        Raises:
            RuntimeError: If a generation is already running, or the worker
                of a previous (cancelled) run has not exited yet.
            OSError: If ``output_dir`` cannot be created.
        """
        global bpy
        import bpy as _bpy
        bpy = _bpy

        if self.is_running:
            raise RuntimeError("Mask generation already in progress")
        if self.worker_thread is not None and self.worker_thread.is_alive():
            # cancel() waits only 2 s for the worker. A worker still alive here
            # would see ``is_running`` become True again and keep polling its
            # old task, and could post into the new run's queues.
            raise RuntimeError("Previous mask generation is still shutting down")

        print(f"[FaceMask] Starting mask generation: {video_path}")
        print(f"[FaceMask] Output directory: {output_dir}")
        print(f"[FaceMask] Frame range: {start_frame} - {end_frame}")

        # Ensure output directory exists *before* marking the generator busy,
        # so a failure here cannot leave ``is_running`` stuck at True.
        Path(output_dir).mkdir(parents=True, exist_ok=True)

        # Fresh queues: a cancelled run whose timer already unregistered may
        # have left a late "cancelled" message that must not leak into this run.
        self.result_queue = queue.Queue()
        self.progress_queue = queue.Queue()

        self.is_running = True
        self.total_frames = end_frame - start_frame + 1
        self.current_frame = 0
        self._on_complete = on_complete
        self._on_progress = on_progress

        self.worker_thread = threading.Thread(
            target=self._worker,
            args=(
                video_path,
                output_dir,
                start_frame,
                end_frame,
                fps,
                conf_threshold,
                iou_threshold,
                mask_scale,
            ),
            daemon=True,
        )
        try:
            self.worker_thread.start()
            # Register timer for main thread callbacks
            bpy.app.timers.register(
                self._check_progress,
                first_interval=0.1,
            )
        except Exception:
            # Roll back; a worker that did start will see the cleared flag,
            # leave its poll loop and cancel its server task.
            self.is_running = False
            raise

    def cancel(self):
        """
        Cancel the current processing.

        Clears ``is_running`` so the worker leaves its poll loop and asks the
        server to cancel the task, then waits up to 2 seconds for the worker
        thread to exit. The wait blocks the calling thread.
        """
        self.is_running = False
        if self.worker_thread and self.worker_thread.is_alive():
            self.worker_thread.join(timeout=2.0)

    def _worker(
        self,
        video_path: str,
        output_dir: str,
        start_frame: int,
        end_frame: int,
        fps: float,
        conf_threshold: float,
        iou_threshold: float,
        mask_scale: float,
    ):
        """
        Worker thread function. Delegates to inference server and polls status.

        Runs off the main thread, so it must not call ``bpy``. Posts progress
        to ``progress_queue`` and exactly one terminal message to
        ``result_queue``. ``fps`` is accepted but not sent to the server.
        """
        import time
        from .inference_client import get_client

        client = None
        task_id = None
        # True while the server task may still be running and nobody else
        # will cancel it; used for best-effort cleanup on errors.
        task_open = False
        try:
            client = get_client()

            # Start task on server
            print("[FaceMask] Requesting generation on server...")
            task_id = client.generate_mask(
                video_path=video_path,
                output_dir=output_dir,
                start_frame=start_frame,
                end_frame=end_frame,
                conf_threshold=conf_threshold,
                iou_threshold=iou_threshold,
                mask_scale=mask_scale,
            )
            task_open = True
            print(f"[FaceMask] Task started: {task_id}")

            # Poll loop; cancel() ends it by clearing ``is_running``.
            while self.is_running:
                status = client.get_task_status(task_id)
                state = status.get("status")
                # ``or 0`` also covers an explicit null from the server.
                total = status.get("total") or 0
                if total > 0:
                    self.total_frames = total

                if state in ("completed", "failed", "cancelled"):
                    task_open = False

                if state == "completed":
                    final_progress = status.get("progress", self.total_frames)
                    if final_progress is not None and final_progress >= 0:
                        self.progress_queue.put(("progress", final_progress))
                    result_path = status.get(
                        "result_path",
                        os.path.join(output_dir, "detections.msgpack"),
                    )
                    self.result_queue.put(("done", result_path))
                    return
                if state == "failed":
                    error_msg = status.get("message", "Unknown server error")
                    print(f"[FaceMask] Server task failed: {error_msg}")
                    self.result_queue.put(("error", error_msg))
                    return
                if state == "cancelled":
                    self.result_queue.put(("cancelled", None))
                    return

                # Report progress
                progress = status.get("progress", 0)
                if progress is not None and progress >= 0:
                    self.progress_queue.put(("progress", progress))

                time.sleep(0.5)

            # If loop exited but task not done, cancel server task
            print("[FaceMask] Cancelling server task...")
            task_open = False  # don't retry below if this call itself raises
            client.cancel_task(task_id)
            self.result_queue.put(("cancelled", None))
        except Exception as e:
            import traceback
            print(f"[FaceMask] Error: {e}")
            traceback.print_exc()
            if task_open:
                # Best effort: after this error nothing polls the task any
                # more, so don't leave the server processing it.
                try:
                    client.cancel_task(task_id)
                except Exception as cancel_err:
                    print(f"[FaceMask] Failed to cancel server task: {cancel_err}")
            self.result_queue.put(("error", str(e)))

    @staticmethod
    def _invoke_callback(callback: Callable, *args) -> None:
        """Call a user callback, printing (not propagating) any exception."""
        try:
            callback(*args)
        except Exception:
            import traceback
            print("[FaceMask] Error in callback:")
            traceback.print_exc()

    def _check_progress(self) -> Optional[float]:
        """
        Timer callback for checking progress from main thread.

        Drains all pending progress messages, then handles at most one result
        message. User callbacks are guarded: an exception escaping a
        ``bpy.app.timers`` callback unregisters the timer, which would leave
        ``is_running`` stuck at True.

        Returns:
            Time until next call, or None to unregister.
        """
        # Process all pending progress updates
        while True:
            try:
                msg_type, data = self.progress_queue.get_nowait()
            except queue.Empty:
                break
            if msg_type == "progress":
                self.current_frame = data
                if self._on_progress:
                    self._invoke_callback(
                        self._on_progress, self.current_frame, self.total_frames
                    )

        # Check for completion
        try:
            msg_type, data = self.result_queue.get_nowait()
        except queue.Empty:
            # Continue checking if still running
            return 0.1 if self.is_running else None

        self.is_running = False
        # Ensure UI receives a final progress update before completion.
        if (
            msg_type == "done"
            and self.total_frames > 0
            and self.current_frame < self.total_frames
            and self._on_progress
        ):
            self.current_frame = self.total_frames
            self._invoke_callback(
                self._on_progress, self.current_frame, self.total_frames
            )
        if self._on_complete:
            self._invoke_callback(self._on_complete, msg_type, data)
        return None  # Unregister timer
# Lazily-created, module-wide generator shared by all operators.
_generator: Optional[AsyncMaskGenerator] = None


def get_generator() -> AsyncMaskGenerator:
    """Return the shared AsyncMaskGenerator, constructing it on first call."""
    global _generator
    if _generator is not None:
        return _generator
    _generator = AsyncMaskGenerator()
    return _generator