姿勢推定と頭部検知を合成する手段を追加
This commit is contained in:
parent
de99aef9ad
commit
eb028ed278
|
|
@ -141,6 +141,90 @@ class AsyncMaskGenerator:
|
||||||
self.worker_thread.start()
|
self.worker_thread.start()
|
||||||
bpy.app.timers.register(self._check_progress, first_interval=0.1)
|
bpy.app.timers.register(self._check_progress, first_interval=0.1)
|
||||||
|
|
||||||
|
def start_augment_pose(
    self,
    detections_path: str,
    total_frames: int,
    conf_threshold: float = 0.5,
    iou_threshold: float = 0.45,
    on_complete=None,
    on_progress=None,
):
    """Start asynchronous pose augmentation of an existing detection cache.

    Spawns a worker thread that asks the inference server to merge
    pose-derived head boxes into the cache at *detections_path*, then
    registers a Blender timer so progress is polled on the main thread.
    """
    # Rebind the module-level bpy name so the timer registration below
    # works even if this module was first imported outside Blender.
    global bpy
    import bpy as _bpy
    bpy = _bpy

    if self.is_running:
        raise RuntimeError("Mask generation already in progress")

    self.is_running = True
    self.total_frames = total_frames
    self.current_frame = 0
    self._on_complete = on_complete
    self._on_progress = on_progress

    # Daemon thread: must not keep Blender alive on shutdown.
    worker = threading.Thread(
        target=self._worker_augment_pose,
        args=(detections_path, conf_threshold, iou_threshold),
        daemon=True,
    )
    self.worker_thread = worker
    worker.start()
    # Poll result/progress queues from Blender's main thread.
    bpy.app.timers.register(self._check_progress, first_interval=0.1)
|
||||||
|
|
||||||
|
def _worker_augment_pose(
    self,
    detections_path: str,
    conf_threshold: float,
    iou_threshold: float,
):
    """Worker thread: call client.augment_pose() and poll the task by id.

    Runs off the main thread; communicates with the UI exclusively via
    self.progress_queue / self.result_queue.
    """
    import time
    from .inference_client import get_client

    try:
        client = get_client()
        task_id = client.augment_pose(
            detections_path=detections_path,
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
        )

        while self.is_running:
            status = client.get_task_status(task_id)
            state = status.get("status")

            # The server learns the real frame count from the cache, so
            # adopt it as soon as it is reported.
            reported_total = status.get("total", 0)
            if reported_total > 0:
                self.total_frames = reported_total

            if state == "completed":
                final_progress = status.get("progress", self.total_frames)
                if final_progress >= 0:
                    self.progress_queue.put(("progress", final_progress))
                self.result_queue.put(
                    ("done", status.get("result_path", detections_path))
                )
                return
            elif state == "failed":
                self.result_queue.put(
                    ("error", status.get("message", "Unknown error"))
                )
                return
            elif state == "cancelled":
                self.result_queue.put(("cancelled", None))
                return

            current = status.get("progress", 0)
            if current >= 0:
                self.progress_queue.put(("progress", current))
            time.sleep(0.5)

        # is_running was cleared externally: propagate the cancellation
        # to the server and report it back to the UI.
        client.cancel_task(task_id)
        self.result_queue.put(("cancelled", None))
    except Exception as e:
        import traceback
        print(f"[FaceMask] Error: {e}")
        traceback.print_exc()
        self.result_queue.put(("error", str(e)))
|
||||||
|
|
||||||
def cancel(self):
    """Cancel the current processing.

    Clears the running flag; the worker loop notices on its next poll
    and shuts the task down cooperatively.
    """
    self.is_running = False
|
||||||
|
|
|
||||||
|
|
@ -237,6 +237,36 @@ class InferenceClient:
|
||||||
except urllib.error.HTTPError as e:
|
except urllib.error.HTTPError as e:
|
||||||
raise RuntimeError(f"Server error: {e.read().decode('utf-8')}")
|
raise RuntimeError(f"Server error: {e.read().decode('utf-8')}")
|
||||||
|
|
||||||
|
def augment_pose(
    self,
    detections_path: str,
    conf_threshold: float,
    iou_threshold: float,
) -> str:
    """Ask the server to merge pose detections into an existing cache.

    Starts the server if it is not already running. Returns the
    server-side task id used for status polling.
    """
    if not self.is_server_running():
        self.start_server()

    body = json.dumps({
        "detections_path": detections_path,
        "conf_threshold": conf_threshold,
        "iou_threshold": iou_threshold,
    }).encode("utf-8")

    request = urllib.request.Request(
        f"{self.SERVER_URL}/augment_pose",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )

    try:
        with urllib.request.urlopen(request) as response:
            payload = json.loads(response.read().decode("utf-8"))
            return payload["id"]
    except urllib.error.HTTPError as e:
        raise RuntimeError(f"Server error: {e.read().decode('utf-8')}")
|
||||||
|
|
||||||
def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||||
"""Get status of a task."""
|
"""Get status of a task."""
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -205,10 +205,88 @@ class SEQUENCER_OT_cancel_mask_generation(Operator):
|
||||||
return {'FINISHED'}
|
return {'FINISHED'}
|
||||||
|
|
||||||
|
|
||||||
|
class SEQUENCER_OT_augment_pose_mask(Operator):
    """Add pose-based head detections to existing detection cache."""

    bl_idname = "sequencer.augment_pose_mask"
    bl_label = "Augment with Pose"
    bl_description = "Run pose estimation and merge results into existing detection cache"
    bl_options = {'REGISTER'}

    @classmethod
    def poll(cls, context):
        # Enabled only for an active MOVIE strip that already has a
        # detection cache to augment.
        seq_editor = context.scene.sequence_editor
        if not seq_editor:
            return False
        active = seq_editor.active_strip
        if not active or active.type != 'MOVIE':
            return False
        return check_detection_cache(active.name)

    def execute(self, context):
        strip = context.scene.sequence_editor.active_strip
        cache_dir = get_cache_dir_for_strip(strip.name)
        detections_path = os.path.join(cache_dir, "detections.msgpack")

        if not os.path.exists(detections_path):
            self.report({'ERROR'}, f"Detection cache not found: {detections_path}")
            return {'CANCELLED'}

        generator = get_generator()
        scene = context.scene
        wm = context.window_manager
        wm.mask_progress = 0
        wm.mask_total = 0  # updated to the real value reported by the server

        def _redraw_sequencers():
            # Force the sequencer UI to repaint so progress is visible.
            for area in context.screen.areas:
                if area.type == 'SEQUENCE_EDITOR':
                    area.tag_redraw()

        def on_complete(status, data):
            # Keep the progress bar consistent with the final frame count.
            wm.mask_total = max(wm.mask_total, generator.total_frames)
            if status == "done":
                wm.mask_progress = wm.mask_total
            elif status in {"error", "cancelled"}:
                wm.mask_progress = min(wm.mask_progress, wm.mask_total)

            if status == "done":
                print(f"[FaceMask] Pose augmentation completed: {data}")
            elif status == "error":
                print(f"[FaceMask] Error: {data}")
            elif status == "cancelled":
                print("[FaceMask] Pose augmentation cancelled")

            _redraw_sequencers()

        def on_progress(current, total_f):
            wm.mask_progress = current
            wm.mask_total = total_f
            _redraw_sequencers()

        try:
            generator.start_augment_pose(
                detections_path=detections_path,
                total_frames=0,
                conf_threshold=scene.facemask_conf_threshold,
                iou_threshold=scene.facemask_iou_threshold,
                on_complete=on_complete,
                on_progress=on_progress,
            )
        except RuntimeError as e:
            # Typically "already in progress" — warn, don't error.
            self.report({'WARNING'}, str(e))
            return {'CANCELLED'}
        except Exception as e:
            self.report({'ERROR'}, f"Failed to start pose augmentation: {e}")
            return {'CANCELLED'}

        self.report({'INFO'}, f"Started pose augmentation for {strip.name}")
        return {'FINISHED'}
|
||||||
|
|
||||||
|
|
||||||
# Registration
# Operator classes handed to bpy.utils.register_class at add-on load.
classes = [
    SEQUENCER_OT_generate_face_mask,
    SEQUENCER_OT_cancel_mask_generation,
    SEQUENCER_OT_augment_pose_mask,
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -289,6 +289,12 @@ class SEQUENCER_PT_face_mask(Panel):
|
||||||
icon='FILE_REFRESH',
|
icon='FILE_REFRESH',
|
||||||
)
|
)
|
||||||
op.force = True
|
op.force = True
|
||||||
|
if strip.type == 'MOVIE':
|
||||||
|
box.operator(
|
||||||
|
"sequencer.augment_pose_mask",
|
||||||
|
text="Augment with Pose",
|
||||||
|
icon='MOD_ARMATURE',
|
||||||
|
)
|
||||||
|
|
||||||
def _draw_blur_controls(self, layout, context, strip):
|
def _draw_blur_controls(self, layout, context, strip):
|
||||||
"""Draw blur application controls."""
|
"""Draw blur application controls."""
|
||||||
|
|
|
||||||
|
|
@ -228,3 +228,205 @@ def get_detector(**kwargs) -> YOLOHeadDetector:
|
||||||
if _detector is None:
|
if _detector is None:
|
||||||
_detector = YOLOHeadDetector(**kwargs)
|
_detector = YOLOHeadDetector(**kwargs)
|
||||||
return _detector
|
return _detector
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pose-based head detector (YOLOv8 pose estimation)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# COCO pose keypoint indices
|
||||||
|
_HEAD_KP = [0, 1, 2, 3, 4] # nose, left_eye, right_eye, left_ear, right_ear
|
||||||
|
_SHOULDER_KP = [5, 6] # left_shoulder, right_shoulder
|
||||||
|
_KP_CONF_THRESH = 0.3
|
||||||
|
|
||||||
|
|
||||||
|
def _head_bbox_from_pose(
|
||||||
|
kp_xy: np.ndarray,
|
||||||
|
kp_conf: np.ndarray,
|
||||||
|
person_x1: float,
|
||||||
|
person_y1: float,
|
||||||
|
person_x2: float,
|
||||||
|
person_y2: float,
|
||||||
|
) -> Tuple[int, int, int, int]:
|
||||||
|
"""
|
||||||
|
Estimate head bounding box (x, y, w, h) from COCO pose keypoints.
|
||||||
|
|
||||||
|
Strategy:
|
||||||
|
1. Use head keypoints (0-4: nose, eyes, ears) if visible.
|
||||||
|
2. Fall back to shoulder keypoints (5-6) to infer head position.
|
||||||
|
3. Last resort: use top of the person bounding box.
|
||||||
|
"""
|
||||||
|
person_w = max(person_x2 - person_x1, 1.0)
|
||||||
|
|
||||||
|
# --- Step 1: head keypoints ---
|
||||||
|
visible_head = [
|
||||||
|
(float(kp_xy[i][0]), float(kp_xy[i][1]))
|
||||||
|
for i in _HEAD_KP
|
||||||
|
if float(kp_conf[i]) > _KP_CONF_THRESH
|
||||||
|
]
|
||||||
|
if visible_head:
|
||||||
|
xs = [p[0] for p in visible_head]
|
||||||
|
ys = [p[1] for p in visible_head]
|
||||||
|
kp_x1, kp_y1 = min(xs), min(ys)
|
||||||
|
kp_x2, kp_y2 = max(xs), max(ys)
|
||||||
|
span = max(kp_x2 - kp_x1, kp_y2 - kp_y1, 1.0)
|
||||||
|
cx = (kp_x1 + kp_x2) / 2.0
|
||||||
|
cy = (kp_y1 + kp_y2) / 2.0
|
||||||
|
r = max(span * 0.5, person_w * 0.10)
|
||||||
|
x1 = int(cx - r)
|
||||||
|
y1 = int(cy - r)
|
||||||
|
x2 = int(cx + r)
|
||||||
|
y2 = int(cy + r)
|
||||||
|
return x1, y1, x2 - x1, y2 - y1
|
||||||
|
|
||||||
|
# --- Step 2: shoulder keypoints ---
|
||||||
|
visible_shoulder = [
|
||||||
|
(float(kp_xy[i][0]), float(kp_xy[i][1]))
|
||||||
|
for i in _SHOULDER_KP
|
||||||
|
if float(kp_conf[i]) > _KP_CONF_THRESH
|
||||||
|
]
|
||||||
|
if visible_shoulder:
|
||||||
|
cx = sum(p[0] for p in visible_shoulder) / len(visible_shoulder)
|
||||||
|
cy_sh = sum(p[1] for p in visible_shoulder) / len(visible_shoulder)
|
||||||
|
if len(visible_shoulder) == 2:
|
||||||
|
sh_width = abs(visible_shoulder[1][0] - visible_shoulder[0][0])
|
||||||
|
else:
|
||||||
|
sh_width = person_w * 0.5
|
||||||
|
r = max(sh_width * 0.3, person_w * 0.12)
|
||||||
|
cy = cy_sh - r * 1.3
|
||||||
|
x1 = int(cx - r)
|
||||||
|
y1 = int(cy - r)
|
||||||
|
x2 = int(cx + r)
|
||||||
|
y2 = int(cy + r)
|
||||||
|
return x1, y1, x2 - x1, y2 - y1
|
||||||
|
|
||||||
|
# --- Step 3: person bbox top ---
|
||||||
|
r = max(person_w * 0.15, 20.0)
|
||||||
|
cx = (person_x1 + person_x2) / 2.0
|
||||||
|
x1 = int(cx - r)
|
||||||
|
y1 = int(person_y1)
|
||||||
|
x2 = int(cx + r)
|
||||||
|
y2 = int(person_y1 + r * 2.0)
|
||||||
|
return x1, y1, x2 - x1, y2 - y1
|
||||||
|
|
||||||
|
|
||||||
|
class YOLOPoseHeadDetector:
    """
    Head detector using YOLOv8 pose estimation with PyTorch ROCm support.

    Extracts head bounding boxes from COCO pose keypoints (nose, eyes, ears).
    yolov8l-pose.pt is auto-downloaded by Ultralytics on first use.
    """

    # Default weights location, relative to the working directory.
    DEFAULT_MODEL = os.path.join("models", "yolov8l-pose.pt")

    def __init__(
        self,
        model_path: Optional[str] = None,
        conf_threshold: float = 0.25,
        iou_threshold: float = 0.45,
        input_size: Tuple[int, int] = (640, 640),
    ):
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        self.input_size = input_size
        self._model = None  # created lazily by the `model` property
        self._model_path = model_path
        self._device = None  # 'cuda' or 'cpu', chosen on first model load

    @property
    def model(self):
        """Lazy-load YOLO pose model."""
        if self._model is not None:
            return self._model

        from ultralytics import YOLO
        import torch

        model_path = self._model_path if self._model_path is not None else self.DEFAULT_MODEL

        # Prefer the ROCm/CUDA device when torch reports one available.
        if torch.cuda.is_available():
            self._device = 'cuda'
            device_name = torch.cuda.get_device_name(0)
            print(f"[FaceMask] Using ROCm GPU for pose inference: {device_name}")
        else:
            self._device = 'cpu'
            print("[FaceMask] Using CPU for pose inference (ROCm GPU not available)")

        try:
            self._model = YOLO(model_path)
            print(f"[FaceMask] Pose model loaded: {model_path}")
            print(f"[FaceMask] Device: {self._device}")
        except Exception as e:
            print(f"[FaceMask] Error loading pose model: {e}")
            import traceback
            traceback.print_exc()
            raise

        return self._model

    def _results_to_detections(self, result) -> List[Tuple[int, int, int, int, float]]:
        """Convert a single YOLO pose result to (x, y, w, h, conf) tuples."""
        if result.boxes is None or result.keypoints is None:
            return []

        detections = []
        for i, box in enumerate(result.boxes):
            conf = float(box.conf[0].cpu().numpy())
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()

            # Keypoint row i has shape (17, 3): x, y, confidence per joint.
            kp_data = result.keypoints.data[i].cpu().numpy()
            head = _head_bbox_from_pose(
                kp_data[:, :2], kp_data[:, 2],
                float(x1), float(y1), float(x2), float(y2),
            )
            detections.append((*head, conf))

        return detections

    def detect_batch(self, frames: List[np.ndarray]) -> List[List[Tuple[int, int, int, int, float]]]:
        """Detect heads in multiple frames at once (batch processing)."""
        if not frames:
            return []

        def _predict(device):
            # Single place for the predict call, shared by the GPU attempt
            # and the CPU fallback.
            return self.model.predict(
                frames,
                conf=self.conf_threshold,
                iou=self.iou_threshold,
                imgsz=self.input_size[0],
                verbose=False,
                device=device,
            )

        try:
            results = _predict(self._device)
        except Exception as e:
            print(f"[FaceMask] ERROR during pose batch inference: {e}")
            import traceback
            traceback.print_exc()
            print("[FaceMask] Falling back to CPU inference...")
            self._device = 'cpu'
            results = _predict('cpu')

        return [self._results_to_detections(r) for r in results]


# Pose detector singleton
_pose_detector: Optional[YOLOPoseHeadDetector] = None


def get_pose_detector(**kwargs) -> YOLOPoseHeadDetector:
    """Get or create the global YOLO pose head detector instance.

    kwargs are only honoured on the first call that constructs the
    singleton; later calls return the existing instance unchanged.
    """
    global _pose_detector
    if _pose_detector is None:
        _pose_detector = YOLOPoseHeadDetector(**kwargs)
    return _pose_detector
|
||||||
|
|
|
||||||
141
server/main.py
141
server/main.py
|
|
@ -49,7 +49,7 @@ import msgpack # noqa: E402
|
||||||
# Add project root to path for imports if needed
|
# Add project root to path for imports if needed
|
||||||
sys.path.append(str(Path(__file__).parent.parent))
|
sys.path.append(str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
from server.detector import get_detector # noqa: E402
|
from server.detector import get_detector, get_pose_detector # noqa: E402
|
||||||
|
|
||||||
app = FastAPI(title="Face Mask Inference Server")
|
app = FastAPI(title="Face Mask Inference Server")
|
||||||
|
|
||||||
|
|
@ -142,6 +142,12 @@ class GenerateImagesRequest(BaseModel):
|
||||||
iou_threshold: float = 0.45
|
iou_threshold: float = 0.45
|
||||||
|
|
||||||
|
|
||||||
|
# Request body for POST /augment_pose: merge pose-derived head detections
# into the existing cache file at detections_path.
class AugmentPoseRequest(BaseModel):
    detections_path: str
    conf_threshold: float = 0.5
    iou_threshold: float = 0.45
|
||||||
|
|
||||||
|
|
||||||
class BakeImagesRequest(BaseModel):
|
class BakeImagesRequest(BaseModel):
|
||||||
image_dir: str
|
image_dir: str
|
||||||
filenames: List[str]
|
filenames: List[str]
|
||||||
|
|
@ -582,6 +588,129 @@ def process_bake_images_task(task_id: str, req: BakeImagesRequest):
|
||||||
del cancel_events[task_id]
|
del cancel_events[task_id]
|
||||||
|
|
||||||
|
|
||||||
|
def augment_pose_task(task_id: str, req: AugmentPoseRequest):
    """Background task: run pose estimation and merge results into existing cache.

    Loads the msgpack detection cache at req.detections_path, re-reads the
    source video it references, runs the pose head detector over the same
    frame range, appends the new boxes to each frame's detection list, and
    rewrites the cache in place. Progress/status is reported through the
    module-level `tasks` registry; cancellation via `cancel_events[task_id]`.
    """
    cap = None
    try:
        tasks[task_id].status = TaskStatus.PROCESSING
        cancel_event = cancel_events.get(task_id)

        # --- Load and validate the existing cache ---
        if not os.path.exists(req.detections_path):
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = f"Detections file not found: {req.detections_path}"
            return

        with open(req.detections_path, "rb") as f:
            payload = msgpack.unpackb(f.read(), raw=False)

        # One list of [x, y, w, h, conf] boxes per frame.
        existing_frames: List[List[List[float]]] = payload.get("frames", [])
        video_path = payload.get("video_path")
        start_frame = int(payload.get("start_frame", 0))
        total = len(existing_frames)

        if not video_path:
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Cache does not contain video_path (image caches not supported)"
            return
        if not os.path.exists(video_path):
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = f"Video not found: {video_path}"
            return
        if total == 0:
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Cache has no frames"
            return

        tasks[task_id].total = total

        detector = get_pose_detector(
            conf_threshold=req.conf_threshold,
            iou_threshold=req.iou_threshold,
        )
        _ = detector.model  # force the lazy model load before opening the video

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Failed to open video"
            return

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Seek to the cache's first frame; fall back to reading frames one
        # by one when the backend cannot seek.
        if start_frame > 0:
            seek_ok = cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
            if not seek_ok:
                for _ in range(start_frame):
                    ret, _ = cap.read()
                    if not ret:
                        tasks[task_id].status = TaskStatus.FAILED
                        tasks[task_id].message = f"Failed to seek to start frame: {start_frame}"
                        return

        frame_buffer: List[np.ndarray] = []
        buffer_indices: List[int] = []  # matching indices into existing_frames
        current_count = 0
        batch_size = 5

        def process_pose_batch():
            # Run pose inference on the buffered frames and merge the head
            # boxes (clamped to the frame bounds) into the cache entries.
            nonlocal current_count
            if not frame_buffer:
                return
            batch_detections = detector.detect_batch(frame_buffer)
            for idx, detections in zip(buffer_indices, batch_detections):
                for x, y, w, h, conf in detections:
                    bx, by, bw, bh = int(x), int(y), int(w), int(h)
                    bx = max(0, bx)
                    by = max(0, by)
                    bw = min(width - bx, bw)
                    bh = min(height - by, bh)
                    if bw > 0 and bh > 0:
                        existing_frames[idx].append([bx, by, bw, bh, float(conf)])
                current_count += 1
                tasks[task_id].progress = current_count
            frame_buffer.clear()
            buffer_indices.clear()

        for i in range(total):
            if cancel_event and cancel_event.is_set():
                tasks[task_id].status = TaskStatus.CANCELLED
                tasks[task_id].message = "Cancelled by user"
                break

            ret, frame = cap.read()
            if not ret:
                # Video shorter than the cache: stop early, keep what we have.
                break

            frame_buffer.append(frame)
            buffer_indices.append(i)
            if len(frame_buffer) >= batch_size:
                process_pose_batch()

        # Flush any remaining frames in a final partial batch.
        if frame_buffer:
            process_pose_batch()

        # Persist only when the task was neither cancelled nor failed.
        if tasks[task_id].status == TaskStatus.PROCESSING:
            payload["frames"] = existing_frames
            with open(req.detections_path, "wb") as f:
                f.write(msgpack.packb(payload, use_bin_type=True))

            tasks[task_id].status = TaskStatus.COMPLETED
            tasks[task_id].result_path = req.detections_path
            tasks[task_id].message = "Pose augmentation completed"
            print(f"[FaceMask] Pose augmentation completed: {req.detections_path}")

    except Exception as e:
        tasks[task_id].status = TaskStatus.FAILED
        tasks[task_id].message = str(e)
        traceback.print_exc()
    finally:
        if cap:
            cap.release()
        # Drop the cancel flag so the registry does not grow unboundedly.
        if task_id in cancel_events:
            del cancel_events[task_id]
|
||||||
|
|
||||||
|
|
||||||
def process_video_task(task_id: str, req: GenerateRequest):
|
def process_video_task(task_id: str, req: GenerateRequest):
|
||||||
"""Background task to detect faces and save bbox cache as msgpack."""
|
"""Background task to detect faces and save bbox cache as msgpack."""
|
||||||
cap = None
|
cap = None
|
||||||
|
|
@ -1261,6 +1390,16 @@ def generate_images_endpoint(req: GenerateImagesRequest, background_tasks: Backg
|
||||||
return task
|
return task
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/augment_pose", response_model=Task)
def augment_pose_endpoint(req: AugmentPoseRequest, background_tasks: BackgroundTasks):
    """Queue a background pose-augmentation task and return its descriptor."""
    new_id = str(uuid.uuid4())
    new_task = Task(id=new_id, status=TaskStatus.PENDING)
    tasks[new_id] = new_task
    # Each task gets its own cancel flag, checked by the worker loop.
    cancel_events[new_id] = threading.Event()
    background_tasks.add_task(augment_pose_task, new_id, req)
    return new_task
|
||||||
|
|
||||||
|
|
||||||
@app.post("/bake_image_blur", response_model=Task)
|
@app.post("/bake_image_blur", response_model=Task)
|
||||||
def bake_image_blur_endpoint(req: BakeImagesRequest, background_tasks: BackgroundTasks):
|
def bake_image_blur_endpoint(req: BakeImagesRequest, background_tasks: BackgroundTasks):
|
||||||
task_id = str(uuid.uuid4())
|
task_id = str(uuid.uuid4())
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user