diff --git a/core/async_generator.py b/core/async_generator.py index c1e8a7a..70ffc46 100644 --- a/core/async_generator.py +++ b/core/async_generator.py @@ -141,6 +141,90 @@ class AsyncMaskGenerator: self.worker_thread.start() bpy.app.timers.register(self._check_progress, first_interval=0.1) + def start_augment_pose( + self, + detections_path: str, + total_frames: int, + conf_threshold: float = 0.5, + iou_threshold: float = 0.45, + on_complete=None, + on_progress=None, + ): + """既存キャッシュへの pose 補完を非同期で開始する。""" + global bpy + import bpy as _bpy + bpy = _bpy + + if self.is_running: + raise RuntimeError("Mask generation already in progress") + + self.is_running = True + self.total_frames = total_frames + self.current_frame = 0 + self._on_complete = on_complete + self._on_progress = on_progress + + self.worker_thread = threading.Thread( + target=self._worker_augment_pose, + args=(detections_path, conf_threshold, iou_threshold), + daemon=True, + ) + self.worker_thread.start() + bpy.app.timers.register(self._check_progress, first_interval=0.1) + + def _worker_augment_pose( + self, + detections_path: str, + conf_threshold: float, + iou_threshold: float, + ): + """client.augment_pose() を呼んで task_id でポーリング。""" + import time + from .inference_client import get_client + + try: + client = get_client() + task_id = client.augment_pose( + detections_path=detections_path, + conf_threshold=conf_threshold, + iou_threshold=iou_threshold, + ) + + while self.is_running: + status = client.get_task_status(task_id) + state = status.get("status") + + total = status.get("total", 0) + if total > 0: + self.total_frames = total + + if state == "completed": + progress = status.get("progress", self.total_frames) + if progress >= 0: + self.progress_queue.put(("progress", progress)) + result_path = status.get("result_path", detections_path) + self.result_queue.put(("done", result_path)) + return + elif state == "failed": + self.result_queue.put(("error", status.get("message", "Unknown error"))) + return + elif state == "cancelled": + self.result_queue.put(("cancelled", None)) + return + + progress = status.get("progress", 0) + if progress >= 0: + self.progress_queue.put(("progress", progress)) + time.sleep(0.5) + + client.cancel_task(task_id) + self.result_queue.put(("cancelled", None)) + except Exception as e: + import traceback + print(f"[FaceMask] Error: {e}") + traceback.print_exc() + self.result_queue.put(("error", str(e))) + def cancel(self): """Cancel the current processing.""" self.is_running = False diff --git a/core/inference_client.py b/core/inference_client.py index acc8ea2..2213fe7 100644 --- a/core/inference_client.py +++ b/core/inference_client.py @@ -237,6 +237,36 @@ class InferenceClient: except urllib.error.HTTPError as e: raise RuntimeError(f"Server error: {e.read().decode('utf-8')}") + def augment_pose( + self, + detections_path: str, + conf_threshold: float, + iou_threshold: float, + ) -> str: + """既存キャッシュに pose 推定結果を追加合成する。task_id を返す。""" + if not self.is_server_running(): + self.start_server() + + data = { + "detections_path": detections_path, + "conf_threshold": conf_threshold, + "iou_threshold": iou_threshold, + } + + req = urllib.request.Request( + f"{self.SERVER_URL}/augment_pose", + data=json.dumps(data).encode("utf-8"), + headers={"Content-Type": "application/json"}, + method="POST", + ) + + try: + with urllib.request.urlopen(req) as response: + result = json.loads(response.read().decode("utf-8")) + return result["id"] + except urllib.error.HTTPError as e: + raise RuntimeError(f"Server error: {e.read().decode('utf-8')}") + def get_task_status(self, task_id: str) -> Dict[str, Any]: """Get status of a task.""" try: diff --git a/operators/generate_mask.py b/operators/generate_mask.py index 5649820..3843445 100644 --- a/operators/generate_mask.py +++ b/operators/generate_mask.py @@ -205,10 +205,88 @@ class SEQUENCER_OT_cancel_mask_generation(Operator): return {'FINISHED'} +class SEQUENCER_OT_augment_pose_mask(Operator): + """Add pose-based head detections to existing detection cache.""" + + bl_idname = "sequencer.augment_pose_mask" + bl_label = "Augment with Pose" + bl_description = "Run pose estimation and merge results into existing detection cache" + bl_options = {'REGISTER'} + + @classmethod + def poll(cls, context): + if not context.scene.sequence_editor: + return False + strip = context.scene.sequence_editor.active_strip + if not strip or strip.type != 'MOVIE': + return False + return check_detection_cache(strip.name) + + def execute(self, context): + strip = context.scene.sequence_editor.active_strip + output_dir = get_cache_dir_for_strip(strip.name) + detections_path = os.path.join(output_dir, "detections.msgpack") + + if not os.path.exists(detections_path): + self.report({'ERROR'}, f"Detection cache not found: {detections_path}") + return {'CANCELLED'} + + generator = get_generator() + scene = context.scene + wm = context.window_manager + wm.mask_progress = 0 + wm.mask_total = 0 # サーバー側から実際の値に更新される + + def on_complete(status, data): + wm.mask_total = max(wm.mask_total, generator.total_frames) + if status == "done": + wm.mask_progress = wm.mask_total + elif status in {"error", "cancelled"}: + wm.mask_progress = min(wm.mask_progress, wm.mask_total) + + if status == "done": + print(f"[FaceMask] Pose augmentation completed: {data}") + elif status == "error": + print(f"[FaceMask] Error: {data}") + elif status == "cancelled": + print("[FaceMask] Pose augmentation cancelled") + + for area in context.screen.areas: + if area.type == 'SEQUENCE_EDITOR': + area.tag_redraw() + + def on_progress(current, total_f): + wm.mask_progress = current + wm.mask_total = total_f + for area in context.screen.areas: + if area.type == 'SEQUENCE_EDITOR': + area.tag_redraw() + + try: + generator.start_augment_pose( + detections_path=detections_path, + total_frames=0, + conf_threshold=scene.facemask_conf_threshold, + iou_threshold=scene.facemask_iou_threshold, + on_complete=on_complete, + on_progress=on_progress, + ) + except RuntimeError as e: + self.report({'WARNING'}, str(e)) + return {'CANCELLED'} + except Exception as e: + self.report({'ERROR'}, f"Failed to start pose augmentation: {e}") + return {'CANCELLED'} + + self.report({'INFO'}, f"Started pose augmentation for {strip.name}") + return {'FINISHED'} + + # Registration classes = [ SEQUENCER_OT_generate_face_mask, SEQUENCER_OT_cancel_mask_generation, + SEQUENCER_OT_augment_pose_mask, ] diff --git a/panels/vse_panel.py b/panels/vse_panel.py index af03f39..0cf1d78 100644 --- a/panels/vse_panel.py +++ b/panels/vse_panel.py @@ -289,6 +289,12 @@ class SEQUENCER_PT_face_mask(Panel): icon='FILE_REFRESH', ) op.force = True + if strip.type == 'MOVIE': + box.operator( + "sequencer.augment_pose_mask", + text="Augment with Pose", + icon='MOD_ARMATURE', + ) def _draw_blur_controls(self, layout, context, strip): """Draw blur application controls.""" diff --git a/server/detector.py b/server/detector.py index 1454ca7..33f6efe 100644 --- a/server/detector.py +++ b/server/detector.py @@ -228,3 +228,205 @@ def get_detector(**kwargs) -> YOLOHeadDetector: if _detector is None: _detector = YOLOHeadDetector(**kwargs) return _detector + + +# --------------------------------------------------------------------------- +# Pose-based head detector (YOLOv8 pose estimation) +# --------------------------------------------------------------------------- + +# COCO pose keypoint indices +_HEAD_KP = [0, 1, 2, 3, 4] # nose, left_eye, right_eye, left_ear, right_ear +_SHOULDER_KP = [5, 6] # left_shoulder, right_shoulder +_KP_CONF_THRESH = 0.3 + + +def _head_bbox_from_pose( + kp_xy: np.ndarray, + kp_conf: np.ndarray, + person_x1: float, + person_y1: float, + person_x2: float, + person_y2: float, +) -> Tuple[int, int, int, int]: + """ + Estimate head bounding box (x, y, w, h) from COCO pose keypoints. + + Strategy: + 1. Use head keypoints (0-4: nose, eyes, ears) if visible. + 2. Fall back to shoulder keypoints (5-6) to infer head position. + 3. Last resort: use top of the person bounding box. + """ + person_w = max(person_x2 - person_x1, 1.0) + + # --- Step 1: head keypoints --- + visible_head = [ + (float(kp_xy[i][0]), float(kp_xy[i][1])) + for i in _HEAD_KP + if float(kp_conf[i]) > _KP_CONF_THRESH + ] + if visible_head: + xs = [p[0] for p in visible_head] + ys = [p[1] for p in visible_head] + kp_x1, kp_y1 = min(xs), min(ys) + kp_x2, kp_y2 = max(xs), max(ys) + span = max(kp_x2 - kp_x1, kp_y2 - kp_y1, 1.0) + cx = (kp_x1 + kp_x2) / 2.0 + cy = (kp_y1 + kp_y2) / 2.0 + r = max(span * 0.5, person_w * 0.10) + x1 = int(cx - r) + y1 = int(cy - r) + x2 = int(cx + r) + y2 = int(cy + r) + return x1, y1, x2 - x1, y2 - y1 + + # --- Step 2: shoulder keypoints --- + visible_shoulder = [ + (float(kp_xy[i][0]), float(kp_xy[i][1])) + for i in _SHOULDER_KP + if float(kp_conf[i]) > _KP_CONF_THRESH + ] + if visible_shoulder: + cx = sum(p[0] for p in visible_shoulder) / len(visible_shoulder) + cy_sh = sum(p[1] for p in visible_shoulder) / len(visible_shoulder) + if len(visible_shoulder) == 2: + sh_width = abs(visible_shoulder[1][0] - visible_shoulder[0][0]) + else: + sh_width = person_w * 0.5 + r = max(sh_width * 0.3, person_w * 0.12) + cy = cy_sh - r * 1.3 + x1 = int(cx - r) + y1 = int(cy - r) + x2 = int(cx + r) + y2 = int(cy + r) + return x1, y1, x2 - x1, y2 - y1 + + # --- Step 3: person bbox top --- + r = max(person_w * 0.15, 20.0) + cx = (person_x1 + person_x2) / 2.0 + x1 = int(cx - r) + y1 = int(person_y1) + x2 = int(cx + r) + y2 = int(person_y1 + r * 2.0) + return x1, y1, x2 - x1, y2 - y1 + + +class YOLOPoseHeadDetector: + """ + Head detector using YOLOv8 pose estimation with PyTorch ROCm support. + + Extracts head bounding boxes from COCO pose keypoints (nose, eyes, ears). + yolov8l-pose.pt is auto-downloaded by Ultralytics on first use. + """ + + DEFAULT_MODEL = os.path.join("models", "yolov8l-pose.pt") + + def __init__( + self, + model_path: Optional[str] = None, + conf_threshold: float = 0.25, + iou_threshold: float = 0.45, + input_size: Tuple[int, int] = (640, 640), + ): + self.conf_threshold = conf_threshold + self.iou_threshold = iou_threshold + self.input_size = input_size + self._model = None + self._model_path = model_path + self._device = None + + @property + def model(self): + """Lazy-load YOLO pose model.""" + if self._model is None: + from ultralytics import YOLO + import torch + + model_path = self._model_path if self._model_path is not None else self.DEFAULT_MODEL + + if torch.cuda.is_available(): + self._device = 'cuda' + device_name = torch.cuda.get_device_name(0) + print(f"[FaceMask] Using ROCm GPU for pose inference: {device_name}") + else: + self._device = 'cpu' + print("[FaceMask] Using CPU for pose inference (ROCm GPU not available)") + + try: + self._model = YOLO(model_path) + print(f"[FaceMask] Pose model loaded: {model_path}") + print(f"[FaceMask] Device: {self._device}") + except Exception as e: + print(f"[FaceMask] Error loading pose model: {e}") + import traceback + traceback.print_exc() + raise + + return self._model + + def _results_to_detections(self, result) -> List[Tuple[int, int, int, int, float]]: + """Convert a single YOLO pose result to (x, y, w, h, conf) tuples.""" + detections = [] + if result.boxes is None or result.keypoints is None: + return detections + + boxes = result.boxes + keypoints = result.keypoints + + for i, box in enumerate(boxes): + conf = float(box.conf[0].cpu().numpy()) + x1, y1, x2, y2 = box.xyxy[0].cpu().numpy() + + kp_data = keypoints.data[i].cpu().numpy() # shape (17, 3): x, y, conf + kp_xy = kp_data[:, :2] + kp_conf = kp_data[:, 2] + + hx, hy, hw, hh = _head_bbox_from_pose( + kp_xy, kp_conf, + float(x1), float(y1), float(x2), float(y2), + ) + detections.append((hx, hy, hw, hh, conf)) + + return detections + + def detect_batch(self, frames: List[np.ndarray]) -> List[List[Tuple[int, int, int, int, float]]]: + """Detect heads in multiple frames at once (batch processing).""" + if not frames: + return [] + + try: + results = self.model.predict( + frames, + conf=self.conf_threshold, + iou=self.iou_threshold, + imgsz=self.input_size[0], + verbose=False, + device=self._device, + ) + except Exception as e: + print(f"[FaceMask] ERROR during pose batch inference: {e}") + import traceback + traceback.print_exc() + print("[FaceMask] Falling back to CPU inference...") + self._device = 'cpu' + results = self.model.predict( + frames, + conf=self.conf_threshold, + iou=self.iou_threshold, + imgsz=self.input_size[0], + verbose=False, + device='cpu', + ) + + return [self._results_to_detections(r) for r in results] + + +# Pose detector singleton +_pose_detector: Optional[YOLOPoseHeadDetector] = None + + +def get_pose_detector(**kwargs) -> YOLOPoseHeadDetector: + """Get or create the global YOLO pose head detector instance.""" + global _pose_detector + if _pose_detector is None: + _pose_detector = YOLOPoseHeadDetector(**kwargs) + return _pose_detector diff --git a/server/main.py b/server/main.py index 1932576..78249ad 100644 --- a/server/main.py +++ b/server/main.py @@ -49,7 +49,7 @@ import msgpack # noqa: E402 # Add project root to path for imports if needed sys.path.append(str(Path(__file__).parent.parent)) -from server.detector import get_detector # noqa: E402 +from server.detector import get_detector, get_pose_detector # noqa: E402 app = FastAPI(title="Face Mask Inference Server") @@ -142,6 +142,12 @@ class GenerateImagesRequest(BaseModel): iou_threshold: float = 0.45 +class AugmentPoseRequest(BaseModel): + detections_path: str + conf_threshold: float = 0.5 + iou_threshold: float = 0.45 + + class BakeImagesRequest(BaseModel): image_dir: str filenames: List[str] @@ -582,6 +588,129 @@ def process_bake_images_task(task_id: str, req: BakeImagesRequest): del cancel_events[task_id] +def augment_pose_task(task_id: str, req: AugmentPoseRequest): + """Background task: run pose estimation and merge results into existing cache.""" + cap = None + try: + tasks[task_id].status = TaskStatus.PROCESSING + cancel_event = cancel_events.get(task_id) + + if not os.path.exists(req.detections_path): + tasks[task_id].status = TaskStatus.FAILED + tasks[task_id].message = f"Detections file not found: {req.detections_path}" + return + + with open(req.detections_path, "rb") as f: + payload = msgpack.unpackb(f.read(), raw=False) + + existing_frames: List[List[List[float]]] = payload.get("frames", []) + video_path = payload.get("video_path") + start_frame = int(payload.get("start_frame", 0)) + total = len(existing_frames) + + if not video_path: + tasks[task_id].status = TaskStatus.FAILED + tasks[task_id].message = "Cache does not contain video_path (image caches not supported)" + return + if not os.path.exists(video_path): + tasks[task_id].status = TaskStatus.FAILED + tasks[task_id].message = f"Video not found: {video_path}" + return + if total == 0: + tasks[task_id].status = TaskStatus.FAILED + tasks[task_id].message = "Cache has no frames" + return + + tasks[task_id].total = total + + detector = get_pose_detector( + conf_threshold=req.conf_threshold, + iou_threshold=req.iou_threshold, + ) + _ = detector.model + + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + tasks[task_id].status = TaskStatus.FAILED + tasks[task_id].message = "Failed to open video" + return + + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + if start_frame > 0: + seek_ok = cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) + if not seek_ok: + for _ in range(start_frame): + ret, _ = cap.read() + if not ret: + tasks[task_id].status = TaskStatus.FAILED + tasks[task_id].message = f"Failed to seek to start frame: {start_frame}" + return + + frame_buffer: List[np.ndarray] = [] + buffer_indices: List[int] = [] # existing_frames インデックス対応 + current_count = 0 + batch_size = 5 + + def process_pose_batch(): + nonlocal current_count + if not frame_buffer: + return + batch_detections = detector.detect_batch(frame_buffer) + for idx, detections in zip(buffer_indices, batch_detections): + for x, y, w, h, conf in detections: + bx, by, bw, bh = int(x), int(y), int(w), int(h) + bx = max(0, bx) + by = max(0, by) + bw = min(width - bx, bw) + bh = min(height - by, bh) + if bw > 0 and bh > 0: + existing_frames[idx].append([bx, by, bw, bh, float(conf)]) + current_count += 1 + tasks[task_id].progress = current_count + frame_buffer.clear() + buffer_indices.clear() + + for i in range(total): + if cancel_event and cancel_event.is_set(): + tasks[task_id].status = TaskStatus.CANCELLED + tasks[task_id].message = "Cancelled by user" + break + + ret, frame = cap.read() + if not ret: + break + + frame_buffer.append(frame) + buffer_indices.append(i) + if len(frame_buffer) >= batch_size: + process_pose_batch() + + if frame_buffer: + process_pose_batch() + + if tasks[task_id].status == TaskStatus.PROCESSING: + payload["frames"] = existing_frames + with open(req.detections_path, "wb") as f: + f.write(msgpack.packb(payload, use_bin_type=True)) + + tasks[task_id].status = TaskStatus.COMPLETED + tasks[task_id].result_path = req.detections_path + tasks[task_id].message = "Pose augmentation completed" + print(f"[FaceMask] Pose augmentation completed: {req.detections_path}") + + except Exception as e: + tasks[task_id].status = TaskStatus.FAILED + tasks[task_id].message = str(e) + traceback.print_exc() + finally: + if cap: + cap.release() + if task_id in cancel_events: + del cancel_events[task_id] + + def process_video_task(task_id: str, req: GenerateRequest): """Background task to detect faces and save bbox cache as msgpack.""" cap = None @@ -1261,6 +1390,16 @@ def generate_images_endpoint(req: GenerateImagesRequest, background_tasks: Backg return task +@app.post("/augment_pose", response_model=Task) +def augment_pose_endpoint(req: AugmentPoseRequest, background_tasks: BackgroundTasks): + task_id = str(uuid.uuid4()) + task = Task(id=task_id, status=TaskStatus.PENDING) + tasks[task_id] = task + cancel_events[task_id] = threading.Event() + background_tasks.add_task(augment_pose_task, task_id, req) + return task + + @app.post("/bake_image_blur", response_model=Task) def bake_image_blur_endpoint(req: BakeImagesRequest, background_tasks: BackgroundTasks): task_id = str(uuid.uuid4())