blender-mask-peoples/core/async_bake_generator.py

162 lines
4.8 KiB
Python

"""
Async blur bake generator using Thread + Queue + Timer pattern.
This module mirrors AsyncMaskGenerator's behavior for the bake-and-swap workflow,
so the Blender UI remains responsive during server-side bake processing.
"""
import threading
import queue
from typing import Optional, Callable
# Will be imported when running inside Blender
bpy = None


class AsyncBakeGenerator:
    """Asynchronous bake generator for non-blocking blur bake tasks.

    A daemon worker thread submits a bake job to the inference server and
    polls its status, posting messages onto queues. A ``bpy.app.timers``
    callback (``_check_progress``) drains those queues and invokes the user
    callbacks, so the caller never blocks on the server round-trips.

    Every call to :meth:`start` gets its own queues and its own cancel event.
    A lingering worker or an undelivered message from a previous (cancelled)
    run therefore cannot be mistaken for the result of the current run.
    """

    def __init__(self):
        self.result_queue: queue.Queue = queue.Queue()
        self.progress_queue: queue.Queue = queue.Queue()
        self.worker_thread: Optional[threading.Thread] = None
        self.is_running: bool = False
        self.total_frames: int = 0
        self.current_frame: int = 0
        self._on_complete: Optional[Callable] = None
        self._on_progress: Optional[Callable] = None
        # Cancellation signal for the current run; replaced on every start().
        self._cancel_event: threading.Event = threading.Event()

    def start(
        self,
        video_path: str,
        detections_path: str,
        output_path: str,
        blur_size: int,
        fmt: str,
        on_complete: Optional[Callable] = None,
        on_progress: Optional[Callable] = None,
    ):
        """Start an asynchronous bake request and progress polling.

        Args:
            video_path: Source video forwarded to ``client.bake_blur``.
            detections_path: Detections file forwarded to ``client.bake_blur``.
            output_path: Requested output path. It is also reported as the
                result path when the server's "completed" status carries no
                ``result_path``.
            blur_size: Blur size forwarded to ``client.bake_blur``.
            fmt: Output format forwarded to ``client.bake_blur``.
            on_complete: Called as ``on_complete(kind, data)``. ``kind`` is
                ``"done"`` (data is the result path), ``"error"`` (data is a
                message) or ``"cancelled"`` (data is None).
            on_progress: Called as ``on_progress(current, total)`` for each
                progress update drained by the timer.

        Raises:
            RuntimeError: If a bake is already in progress.
        """
        global bpy
        import bpy as _bpy
        bpy = _bpy
        if self.is_running:
            raise RuntimeError("Blur bake already in progress")

        # Stop any worker left over from a previous run, including one that
        # was "cancelled" only by clearing is_running. Then give this run
        # fresh queues so stale messages cannot be delivered to it.
        self._cancel_event.set()
        self._cancel_event = threading.Event()
        self.result_queue = queue.Queue()
        self.progress_queue = queue.Queue()

        self.is_running = True
        self.total_frames = 0
        self.current_frame = 0
        self._on_complete = on_complete
        self._on_progress = on_progress
        self.worker_thread = threading.Thread(
            target=self._worker,
            args=(video_path, detections_path, output_path, blur_size, fmt),
            # Bind the worker to THIS run's event and queues explicitly, so it
            # never reads the attributes that a later start() will replace.
            kwargs={
                "cancel_event": self._cancel_event,
                "result_queue": self.result_queue,
                "progress_queue": self.progress_queue,
            },
            daemon=True,
        )
        self.worker_thread.start()
        bpy.app.timers.register(
            self._check_progress,
            first_interval=0.1,
        )

    def cancel(self):
        """Cancel the current bake processing.

        Signals the worker to stop and waits up to 2 seconds for its thread
        to exit. On stopping, the worker asks the server to cancel its task
        (if one was created) and queues a ``("cancelled", None)`` result.
        """
        self._cancel_event.set()
        self.is_running = False
        if self.worker_thread and self.worker_thread.is_alive():
            self.worker_thread.join(timeout=2.0)

    def _worker(
        self,
        video_path: str,
        detections_path: str,
        output_path: str,
        blur_size: int,
        fmt: str,
        cancel_event: Optional[threading.Event] = None,
        result_queue: Optional[queue.Queue] = None,
        progress_queue: Optional[queue.Queue] = None,
    ):
        """Submit the bake task and poll its status (runs on the worker thread).

        Posts ``("progress", n)`` messages to *progress_queue* and one final
        message to *result_queue*: ``("done", path)``, ``("error", msg)`` or
        ``("cancelled", None)``. *cancel_event*, *result_queue* and
        *progress_queue* default to the instance's current ones.
        """
        if cancel_event is None:
            cancel_event = self._cancel_event
        if result_queue is None:
            result_queue = self.result_queue
        if progress_queue is None:
            progress_queue = self.progress_queue

        task_id = None
        try:
            # Imported inside the try so that an import or client-setup failure
            # is reported as an error. Otherwise it would silently kill the
            # thread and leave the generator stuck in the running state.
            from .inference_client import get_client

            client = get_client()
            task_id = client.bake_blur(
                video_path=video_path,
                detections_path=detections_path,
                output_path=output_path,
                blur_size=blur_size,
                fmt=fmt,
            )
            # Stop on either signal: cancel() sets the event, while callers
            # that only clear is_running are still honored.
            while self.is_running and not cancel_event.is_set():
                status = client.get_task_status(task_id)
                state = status.get("status")
                # `or 0` also guards against an explicit None from the server.
                total = status.get("total") or 0
                if total > 0:
                    self.total_frames = total
                progress = status.get("progress") or 0
                if progress >= 0:
                    progress_queue.put(("progress", progress))
                if state == "completed":
                    result_path = status.get("result_path", output_path)
                    result_queue.put(("done", result_path))
                    return
                if state == "failed":
                    error_msg = status.get("message", "Unknown server error")
                    result_queue.put(("error", error_msg))
                    return
                if state == "cancelled":
                    result_queue.put(("cancelled", None))
                    return
                # Poll interval; returns early as soon as cancel is requested.
                cancel_event.wait(0.5)
            # Local cancel path
            if task_id:
                client.cancel_task(task_id)
            result_queue.put(("cancelled", None))
        except Exception as e:
            result_queue.put(("error", str(e)))

    def _check_progress(self) -> Optional[float]:
        """Timer callback: drain worker messages and dispatch user callbacks.

        Forwards every queued progress value to ``on_progress``. If a final
        result is queued, clears ``is_running``, passes the result to
        ``on_complete`` and returns None.

        Returns:
            0.1 while the bake is still running, otherwise None. This follows
            the ``bpy.app.timers`` convention: a float is the delay before the
            next call, and None unregisters the timer.
        """
        while not self.progress_queue.empty():
            try:
                msg_type, data = self.progress_queue.get_nowait()
                if msg_type == "progress":
                    self.current_frame = data
                    if self._on_progress:
                        self._on_progress(self.current_frame, self.total_frames)
            except queue.Empty:
                break
        if not self.result_queue.empty():
            try:
                msg_type, data = self.result_queue.get_nowait()
                self.is_running = False
                if self._on_complete:
                    self._on_complete(msg_type, data)
                return None
            except queue.Empty:
                pass
        if self.is_running:
            return 0.1
        return None
# Lazily-created, module-wide generator instance (see get_bake_generator).
_bake_generator: Optional[AsyncBakeGenerator] = None


def get_bake_generator() -> AsyncBakeGenerator:
    """Return the shared AsyncBakeGenerator, creating it on the first call."""
    global _bake_generator
    if _bake_generator is not None:
        return _bake_generator
    _bake_generator = AsyncBakeGenerator()
    return _bake_generator