姿勢推定と頭部検知を合成する手段を追加

This commit is contained in:
Keisuke Hirata 2026-02-23 04:21:53 +09:00
parent de99aef9ad
commit eb028ed278
6 changed files with 540 additions and 1 deletion

View File

@ -141,6 +141,90 @@ class AsyncMaskGenerator:
self.worker_thread.start()
bpy.app.timers.register(self._check_progress, first_interval=0.1)
def start_augment_pose(
    self,
    detections_path: str,
    total_frames: int,
    conf_threshold: float = 0.5,
    iou_threshold: float = 0.45,
    on_complete=None,
    on_progress=None,
):
    """Kick off asynchronous pose augmentation of an existing detection cache.

    Spawns a daemon worker thread that drives the server-side task and
    registers a Blender timer that polls the worker's queues.
    """
    # Rebind the module-level ``bpy`` so the timer callback can use it later.
    global bpy
    import bpy as _bpy
    bpy = _bpy

    if self.is_running:
        raise RuntimeError("Mask generation already in progress")

    # Reset per-run bookkeeping before the worker starts.
    self.is_running = True
    self.current_frame = 0
    self.total_frames = total_frames
    self._on_complete = on_complete
    self._on_progress = on_progress

    worker_args = (detections_path, conf_threshold, iou_threshold)
    self.worker_thread = threading.Thread(
        target=self._worker_augment_pose,
        args=worker_args,
        daemon=True,
    )
    self.worker_thread.start()
    bpy.app.timers.register(self._check_progress, first_interval=0.1)
def _worker_augment_pose(
    self,
    detections_path: str,
    conf_threshold: float,
    iou_threshold: float,
):
    """Worker thread: start an augment_pose task and poll it by task_id."""
    import time
    from .inference_client import get_client
    try:
        client = get_client()
        task_id = client.augment_pose(
            detections_path=detections_path,
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
        )
        while self.is_running:
            info = client.get_task_status(task_id)
            phase = info.get("status")

            # The server learns the real frame count only after opening
            # the cache, so pick it up as soon as it is reported.
            reported_total = info.get("total", 0)
            if reported_total > 0:
                self.total_frames = reported_total

            if phase == "completed":
                final_progress = info.get("progress", self.total_frames)
                if final_progress >= 0:
                    self.progress_queue.put(("progress", final_progress))
                self.result_queue.put(
                    ("done", info.get("result_path", detections_path))
                )
                return
            if phase == "failed":
                self.result_queue.put(
                    ("error", info.get("message", "Unknown error"))
                )
                return
            if phase == "cancelled":
                self.result_queue.put(("cancelled", None))
                return

            interim = info.get("progress", 0)
            if interim >= 0:
                self.progress_queue.put(("progress", interim))
            time.sleep(0.5)

        # Loop exited because is_running was cleared: propagate cancellation
        # to the server and report it locally.
        client.cancel_task(task_id)
        self.result_queue.put(("cancelled", None))
    except Exception as e:
        import traceback
        print(f"[FaceMask] Error: {e}")
        traceback.print_exc()
        self.result_queue.put(("error", str(e)))
def cancel(self):
    """Request that the running job stop by clearing the run flag."""
    self.is_running = False

View File

@ -237,6 +237,36 @@ class InferenceClient:
except urllib.error.HTTPError as e:
raise RuntimeError(f"Server error: {e.read().decode('utf-8')}")
def augment_pose(
    self,
    detections_path: str,
    conf_threshold: float,
    iou_threshold: float,
) -> str:
    """Ask the server to merge pose detections into an existing cache.

    Starts the server if it is not already running. Returns the server-side
    task id for subsequent status polling.
    """
    if not self.is_server_running():
        self.start_server()

    body = json.dumps(
        {
            "detections_path": detections_path,
            "conf_threshold": conf_threshold,
            "iou_threshold": iou_threshold,
        }
    ).encode("utf-8")
    request = urllib.request.Request(
        f"{self.SERVER_URL}/augment_pose",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(request) as response:
            return json.loads(response.read().decode("utf-8"))["id"]
    except urllib.error.HTTPError as e:
        raise RuntimeError(f"Server error: {e.read().decode('utf-8')}")
def get_task_status(self, task_id: str) -> Dict[str, Any]:
"""Get status of a task."""
try:

View File

@ -205,10 +205,88 @@ class SEQUENCER_OT_cancel_mask_generation(Operator):
return {'FINISHED'}
class SEQUENCER_OT_augment_pose_mask(Operator):
    """Add pose-based head detections to existing detection cache."""

    bl_idname = "sequencer.augment_pose_mask"
    bl_label = "Augment with Pose"
    bl_description = "Run pose estimation and merge results into existing detection cache"
    bl_options = {'REGISTER'}

    @classmethod
    def poll(cls, context):
        # Only available for an active MOVIE strip that already has a cache.
        seq_editor = context.scene.sequence_editor
        if not seq_editor:
            return False
        active = seq_editor.active_strip
        if not active or active.type != 'MOVIE':
            return False
        return check_detection_cache(active.name)

    def execute(self, context):
        strip = context.scene.sequence_editor.active_strip
        cache_dir = get_cache_dir_for_strip(strip.name)
        detections_path = os.path.join(cache_dir, "detections.msgpack")
        if not os.path.exists(detections_path):
            self.report({'ERROR'}, f"Detection cache not found: {detections_path}")
            return {'CANCELLED'}

        generator = get_generator()
        scene = context.scene
        wm = context.window_manager
        wm.mask_progress = 0
        wm.mask_total = 0  # updated once the server reports the real total

        def redraw_sequencers():
            # Repaint sequencer areas so the progress display refreshes.
            for area in context.screen.areas:
                if area.type == 'SEQUENCE_EDITOR':
                    area.tag_redraw()

        def on_complete(status, data):
            wm.mask_total = max(wm.mask_total, generator.total_frames)
            if status == "done":
                wm.mask_progress = wm.mask_total
            elif status in {"error", "cancelled"}:
                wm.mask_progress = min(wm.mask_progress, wm.mask_total)
            if status == "done":
                print(f"[FaceMask] Pose augmentation completed: {data}")
            elif status == "error":
                print(f"[FaceMask] Error: {data}")
            elif status == "cancelled":
                print("[FaceMask] Pose augmentation cancelled")
            redraw_sequencers()

        def on_progress(current, total_f):
            wm.mask_progress = current
            wm.mask_total = total_f
            redraw_sequencers()

        try:
            generator.start_augment_pose(
                detections_path=detections_path,
                total_frames=0,
                conf_threshold=scene.facemask_conf_threshold,
                iou_threshold=scene.facemask_iou_threshold,
                on_complete=on_complete,
                on_progress=on_progress,
            )
        except RuntimeError as e:
            self.report({'WARNING'}, str(e))
            return {'CANCELLED'}
        except Exception as e:
            self.report({'ERROR'}, f"Failed to start pose augmentation: {e}")
            return {'CANCELLED'}

        self.report({'INFO'}, f"Started pose augmentation for {strip.name}")
        return {'FINISHED'}
# Registration
# NOTE(review): presumably iterated by this add-on's register()/unregister()
# helpers (not visible in this chunk) — confirm before relying on list order.
classes = [
    SEQUENCER_OT_generate_face_mask,
    SEQUENCER_OT_cancel_mask_generation,
    SEQUENCER_OT_augment_pose_mask,
]

View File

@ -289,6 +289,12 @@ class SEQUENCER_PT_face_mask(Panel):
icon='FILE_REFRESH',
)
op.force = True
if strip.type == 'MOVIE':
box.operator(
"sequencer.augment_pose_mask",
text="Augment with Pose",
icon='MOD_ARMATURE',
)
def _draw_blur_controls(self, layout, context, strip):
"""Draw blur application controls."""

View File

@ -228,3 +228,205 @@ def get_detector(**kwargs) -> YOLOHeadDetector:
if _detector is None:
_detector = YOLOHeadDetector(**kwargs)
return _detector
# ---------------------------------------------------------------------------
# Pose-based head detector (YOLOv8 pose estimation)
# ---------------------------------------------------------------------------
# COCO pose keypoint indices
_HEAD_KP = [0, 1, 2, 3, 4] # nose, left_eye, right_eye, left_ear, right_ear
_SHOULDER_KP = [5, 6] # left_shoulder, right_shoulder
_KP_CONF_THRESH = 0.3
def _head_bbox_from_pose(
kp_xy: np.ndarray,
kp_conf: np.ndarray,
person_x1: float,
person_y1: float,
person_x2: float,
person_y2: float,
) -> Tuple[int, int, int, int]:
"""
Estimate head bounding box (x, y, w, h) from COCO pose keypoints.
Strategy:
1. Use head keypoints (0-4: nose, eyes, ears) if visible.
2. Fall back to shoulder keypoints (5-6) to infer head position.
3. Last resort: use top of the person bounding box.
"""
person_w = max(person_x2 - person_x1, 1.0)
# --- Step 1: head keypoints ---
visible_head = [
(float(kp_xy[i][0]), float(kp_xy[i][1]))
for i in _HEAD_KP
if float(kp_conf[i]) > _KP_CONF_THRESH
]
if visible_head:
xs = [p[0] for p in visible_head]
ys = [p[1] for p in visible_head]
kp_x1, kp_y1 = min(xs), min(ys)
kp_x2, kp_y2 = max(xs), max(ys)
span = max(kp_x2 - kp_x1, kp_y2 - kp_y1, 1.0)
cx = (kp_x1 + kp_x2) / 2.0
cy = (kp_y1 + kp_y2) / 2.0
r = max(span * 0.5, person_w * 0.10)
x1 = int(cx - r)
y1 = int(cy - r)
x2 = int(cx + r)
y2 = int(cy + r)
return x1, y1, x2 - x1, y2 - y1
# --- Step 2: shoulder keypoints ---
visible_shoulder = [
(float(kp_xy[i][0]), float(kp_xy[i][1]))
for i in _SHOULDER_KP
if float(kp_conf[i]) > _KP_CONF_THRESH
]
if visible_shoulder:
cx = sum(p[0] for p in visible_shoulder) / len(visible_shoulder)
cy_sh = sum(p[1] for p in visible_shoulder) / len(visible_shoulder)
if len(visible_shoulder) == 2:
sh_width = abs(visible_shoulder[1][0] - visible_shoulder[0][0])
else:
sh_width = person_w * 0.5
r = max(sh_width * 0.3, person_w * 0.12)
cy = cy_sh - r * 1.3
x1 = int(cx - r)
y1 = int(cy - r)
x2 = int(cx + r)
y2 = int(cy + r)
return x1, y1, x2 - x1, y2 - y1
# --- Step 3: person bbox top ---
r = max(person_w * 0.15, 20.0)
cx = (person_x1 + person_x2) / 2.0
x1 = int(cx - r)
y1 = int(person_y1)
x2 = int(cx + r)
y2 = int(person_y1 + r * 2.0)
return x1, y1, x2 - x1, y2 - y1
class YOLOPoseHeadDetector:
    """
    Head detector using YOLOv8 pose estimation with PyTorch ROCm support.

    Extracts head bounding boxes from COCO pose keypoints (nose, eyes, ears).
    yolov8l-pose.pt is auto-downloaded by Ultralytics on first use.
    """

    DEFAULT_MODEL = os.path.join("models", "yolov8l-pose.pt")

    def __init__(
        self,
        model_path: Optional[str] = None,
        conf_threshold: float = 0.25,
        iou_threshold: float = 0.45,
        input_size: Tuple[int, int] = (640, 640),
    ):
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        self.input_size = input_size
        # Model and device are resolved lazily on first access to ``model``.
        self._model = None
        self._model_path = model_path
        self._device = None

    @property
    def model(self):
        """Lazy-load YOLO pose model."""
        if self._model is not None:
            return self._model

        from ultralytics import YOLO
        import torch

        model_path = self.DEFAULT_MODEL if self._model_path is None else self._model_path

        # Pick the inference device before loading the weights.
        if torch.cuda.is_available():
            self._device = 'cuda'
            device_name = torch.cuda.get_device_name(0)
            print(f"[FaceMask] Using ROCm GPU for pose inference: {device_name}")
        else:
            self._device = 'cpu'
            print("[FaceMask] Using CPU for pose inference (ROCm GPU not available)")

        try:
            self._model = YOLO(model_path)
            print(f"[FaceMask] Pose model loaded: {model_path}")
            print(f"[FaceMask] Device: {self._device}")
        except Exception as e:
            print(f"[FaceMask] Error loading pose model: {e}")
            import traceback
            traceback.print_exc()
            raise
        return self._model

    def _results_to_detections(self, result) -> List[Tuple[int, int, int, int, float]]:
        """Convert a single YOLO pose result to (x, y, w, h, conf) tuples."""
        if result.boxes is None or result.keypoints is None:
            return []

        detections = []
        keypoints = result.keypoints
        for i, box in enumerate(result.boxes):
            conf = float(box.conf[0].cpu().numpy())
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            kp = keypoints.data[i].cpu().numpy()  # shape (17, 3): x, y, conf
            head = _head_bbox_from_pose(
                kp[:, :2],
                kp[:, 2],
                float(x1), float(y1), float(x2), float(y2),
            )
            detections.append((*head, conf))
        return detections

    def detect_batch(self, frames: List[np.ndarray]) -> List[List[Tuple[int, int, int, int, float]]]:
        """Detect heads in multiple frames at once (batch processing)."""
        if not frames:
            return []

        try:
            results = self.model.predict(
                frames,
                conf=self.conf_threshold,
                iou=self.iou_threshold,
                imgsz=self.input_size[0],
                verbose=False,
                device=self._device,
            )
        except Exception as e:
            # GPU inference failed (e.g. ROCm issue): retry once on CPU.
            print(f"[FaceMask] ERROR during pose batch inference: {e}")
            import traceback
            traceback.print_exc()
            print("[FaceMask] Falling back to CPU inference...")
            self._device = 'cpu'
            results = self.model.predict(
                frames,
                conf=self.conf_threshold,
                iou=self.iou_threshold,
                imgsz=self.input_size[0],
                verbose=False,
                device='cpu',
            )
        return [self._results_to_detections(r) for r in results]
# Pose detector singleton
_pose_detector: Optional[YOLOPoseHeadDetector] = None


def get_pose_detector(**kwargs) -> YOLOPoseHeadDetector:
    """Get or create the global YOLO pose head detector instance."""
    global _pose_detector
    detector = _pose_detector
    if detector is None:
        # kwargs only take effect on the very first call, which builds it.
        detector = YOLOPoseHeadDetector(**kwargs)
        _pose_detector = detector
    return detector

View File

@ -49,7 +49,7 @@ import msgpack # noqa: E402
# Add project root to path for imports if needed
sys.path.append(str(Path(__file__).parent.parent))
from server.detector import get_detector # noqa: E402
from server.detector import get_detector, get_pose_detector # noqa: E402
app = FastAPI(title="Face Mask Inference Server")
@ -142,6 +142,12 @@ class GenerateImagesRequest(BaseModel):
iou_threshold: float = 0.45
class AugmentPoseRequest(BaseModel):
    # Path to the msgpack detection cache to augment in place.
    detections_path: str
    # Pose-model confidence threshold (forwarded to the pose detector).
    conf_threshold: float = 0.5
    # IoU threshold forwarded to the pose detector.
    iou_threshold: float = 0.45
class BakeImagesRequest(BaseModel):
image_dir: str
filenames: List[str]
@ -582,6 +588,129 @@ def process_bake_images_task(task_id: str, req: BakeImagesRequest):
del cancel_events[task_id]
def augment_pose_task(task_id: str, req: AugmentPoseRequest):
    """Background task: run pose estimation and merge results into existing cache."""
    cap = None
    try:
        tasks[task_id].status = TaskStatus.PROCESSING
        cancel_event = cancel_events.get(task_id)
        if not os.path.exists(req.detections_path):
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = f"Detections file not found: {req.detections_path}"
            return
        # Load the existing detection cache; new boxes are appended per frame.
        with open(req.detections_path, "rb") as f:
            payload = msgpack.unpackb(f.read(), raw=False)
        existing_frames: List[List[List[float]]] = payload.get("frames", [])
        video_path = payload.get("video_path")
        start_frame = int(payload.get("start_frame", 0))
        total = len(existing_frames)
        if not video_path:
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Cache does not contain video_path (image caches not supported)"
            return
        if not os.path.exists(video_path):
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = f"Video not found: {video_path}"
            return
        if total == 0:
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Cache has no frames"
            return
        tasks[task_id].total = total
        detector = get_pose_detector(
            conf_threshold=req.conf_threshold,
            iou_threshold=req.iou_threshold,
        )
        # Touch the lazy ``model`` property so weights load before streaming.
        _ = detector.model
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Failed to open video"
            return
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        if start_frame > 0:
            seek_ok = cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
            if not seek_ok:
                # Seek failed: step through frames one by one instead.
                for _ in range(start_frame):
                    ret, _ = cap.read()
                    if not ret:
                        tasks[task_id].status = TaskStatus.FAILED
                        tasks[task_id].message = f"Failed to seek to start frame: {start_frame}"
                        return
        frame_buffer: List[np.ndarray] = []
        buffer_indices: List[int] = []  # maps buffered frames to existing_frames indices
        current_count = 0
        batch_size = 5

        def process_pose_batch():
            # Run the pose detector on the buffered frames and merge the
            # resulting boxes into existing_frames in place.
            nonlocal current_count
            if not frame_buffer:
                return
            batch_detections = detector.detect_batch(frame_buffer)
            for idx, detections in zip(buffer_indices, batch_detections):
                for x, y, w, h, conf in detections:
                    bx, by, bw, bh = int(x), int(y), int(w), int(h)
                    # Clamp to the frame and drop degenerate boxes.
                    bx = max(0, bx)
                    by = max(0, by)
                    bw = min(width - bx, bw)
                    bh = min(height - by, bh)
                    if bw > 0 and bh > 0:
                        existing_frames[idx].append([bx, by, bw, bh, float(conf)])
                current_count += 1
            tasks[task_id].progress = current_count
            frame_buffer.clear()
            buffer_indices.clear()

        for i in range(total):
            if cancel_event and cancel_event.is_set():
                tasks[task_id].status = TaskStatus.CANCELLED
                tasks[task_id].message = "Cancelled by user"
                break
            ret, frame = cap.read()
            if not ret:
                # Video ended before the cache's frame count was reached.
                break
            frame_buffer.append(frame)
            buffer_indices.append(i)
            if len(frame_buffer) >= batch_size:
                process_pose_batch()
        if frame_buffer:
            process_pose_batch()
        # Persist only when the task was neither cancelled nor failed above.
        if tasks[task_id].status == TaskStatus.PROCESSING:
            payload["frames"] = existing_frames
            with open(req.detections_path, "wb") as f:
                f.write(msgpack.packb(payload, use_bin_type=True))
            tasks[task_id].status = TaskStatus.COMPLETED
            tasks[task_id].result_path = req.detections_path
            tasks[task_id].message = "Pose augmentation completed"
            print(f"[FaceMask] Pose augmentation completed: {req.detections_path}")
    except Exception as e:
        tasks[task_id].status = TaskStatus.FAILED
        tasks[task_id].message = str(e)
        traceback.print_exc()
    finally:
        if cap:
            cap.release()
        if task_id in cancel_events:
            del cancel_events[task_id]
def process_video_task(task_id: str, req: GenerateRequest):
"""Background task to detect faces and save bbox cache as msgpack."""
cap = None
@ -1261,6 +1390,16 @@ def generate_images_endpoint(req: GenerateImagesRequest, background_tasks: Backg
return task
@app.post("/augment_pose", response_model=Task)
def augment_pose_endpoint(req: AugmentPoseRequest, background_tasks: BackgroundTasks):
    # Register a pending task, then hand the heavy work to a background task.
    task_id = str(uuid.uuid4())
    new_task = Task(id=task_id, status=TaskStatus.PENDING)
    cancel_events[task_id] = threading.Event()
    tasks[task_id] = new_task
    background_tasks.add_task(augment_pose_task, task_id, req)
    return new_task
@app.post("/bake_image_blur", response_model=Task)
def bake_image_blur_endpoint(req: BakeImagesRequest, background_tasks: BackgroundTasks):
task_id = str(uuid.uuid4())