姿勢推定と頭部検知を合成する手段を追加
This commit is contained in:
parent
de99aef9ad
commit
eb028ed278
|
|
@ -141,6 +141,90 @@ class AsyncMaskGenerator:
|
|||
self.worker_thread.start()
|
||||
bpy.app.timers.register(self._check_progress, first_interval=0.1)
|
||||
|
||||
def start_augment_pose(
|
||||
self,
|
||||
detections_path: str,
|
||||
total_frames: int,
|
||||
conf_threshold: float = 0.5,
|
||||
iou_threshold: float = 0.45,
|
||||
on_complete=None,
|
||||
on_progress=None,
|
||||
):
|
||||
"""既存キャッシュへの pose 補完を非同期で開始する。"""
|
||||
global bpy
|
||||
import bpy as _bpy
|
||||
bpy = _bpy
|
||||
|
||||
if self.is_running:
|
||||
raise RuntimeError("Mask generation already in progress")
|
||||
|
||||
self.is_running = True
|
||||
self.total_frames = total_frames
|
||||
self.current_frame = 0
|
||||
self._on_complete = on_complete
|
||||
self._on_progress = on_progress
|
||||
|
||||
self.worker_thread = threading.Thread(
|
||||
target=self._worker_augment_pose,
|
||||
args=(detections_path, conf_threshold, iou_threshold),
|
||||
daemon=True,
|
||||
)
|
||||
self.worker_thread.start()
|
||||
bpy.app.timers.register(self._check_progress, first_interval=0.1)
|
||||
|
||||
def _worker_augment_pose(
|
||||
self,
|
||||
detections_path: str,
|
||||
conf_threshold: float,
|
||||
iou_threshold: float,
|
||||
):
|
||||
"""client.augment_pose() を呼んで task_id でポーリング。"""
|
||||
import time
|
||||
from .inference_client import get_client
|
||||
|
||||
try:
|
||||
client = get_client()
|
||||
task_id = client.augment_pose(
|
||||
detections_path=detections_path,
|
||||
conf_threshold=conf_threshold,
|
||||
iou_threshold=iou_threshold,
|
||||
)
|
||||
|
||||
while self.is_running:
|
||||
status = client.get_task_status(task_id)
|
||||
state = status.get("status")
|
||||
|
||||
total = status.get("total", 0)
|
||||
if total > 0:
|
||||
self.total_frames = total
|
||||
|
||||
if state == "completed":
|
||||
progress = status.get("progress", self.total_frames)
|
||||
if progress >= 0:
|
||||
self.progress_queue.put(("progress", progress))
|
||||
result_path = status.get("result_path", detections_path)
|
||||
self.result_queue.put(("done", result_path))
|
||||
return
|
||||
elif state == "failed":
|
||||
self.result_queue.put(("error", status.get("message", "Unknown error")))
|
||||
return
|
||||
elif state == "cancelled":
|
||||
self.result_queue.put(("cancelled", None))
|
||||
return
|
||||
|
||||
progress = status.get("progress", 0)
|
||||
if progress >= 0:
|
||||
self.progress_queue.put(("progress", progress))
|
||||
time.sleep(0.5)
|
||||
|
||||
client.cancel_task(task_id)
|
||||
self.result_queue.put(("cancelled", None))
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"[FaceMask] Error: {e}")
|
||||
traceback.print_exc()
|
||||
self.result_queue.put(("error", str(e)))
|
||||
|
||||
def cancel(self):
|
||||
"""Cancel the current processing."""
|
||||
self.is_running = False
|
||||
|
|
|
|||
|
|
@ -237,6 +237,36 @@ class InferenceClient:
|
|||
except urllib.error.HTTPError as e:
|
||||
raise RuntimeError(f"Server error: {e.read().decode('utf-8')}")
|
||||
|
||||
def augment_pose(
|
||||
self,
|
||||
detections_path: str,
|
||||
conf_threshold: float,
|
||||
iou_threshold: float,
|
||||
) -> str:
|
||||
"""既存キャッシュに pose 推定結果を追加合成する。task_id を返す。"""
|
||||
if not self.is_server_running():
|
||||
self.start_server()
|
||||
|
||||
data = {
|
||||
"detections_path": detections_path,
|
||||
"conf_threshold": conf_threshold,
|
||||
"iou_threshold": iou_threshold,
|
||||
}
|
||||
|
||||
req = urllib.request.Request(
|
||||
f"{self.SERVER_URL}/augment_pose",
|
||||
data=json.dumps(data).encode("utf-8"),
|
||||
headers={"Content-Type": "application/json"},
|
||||
method="POST",
|
||||
)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req) as response:
|
||||
result = json.loads(response.read().decode("utf-8"))
|
||||
return result["id"]
|
||||
except urllib.error.HTTPError as e:
|
||||
raise RuntimeError(f"Server error: {e.read().decode('utf-8')}")
|
||||
|
||||
def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""Get status of a task."""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -205,10 +205,88 @@ class SEQUENCER_OT_cancel_mask_generation(Operator):
|
|||
return {'FINISHED'}
|
||||
|
||||
|
||||
class SEQUENCER_OT_augment_pose_mask(Operator):
|
||||
"""Add pose-based head detections to existing detection cache."""
|
||||
|
||||
bl_idname = "sequencer.augment_pose_mask"
|
||||
bl_label = "Augment with Pose"
|
||||
bl_description = "Run pose estimation and merge results into existing detection cache"
|
||||
bl_options = {'REGISTER'}
|
||||
|
||||
@classmethod
|
||||
def poll(cls, context):
|
||||
if not context.scene.sequence_editor:
|
||||
return False
|
||||
strip = context.scene.sequence_editor.active_strip
|
||||
if not strip or strip.type != 'MOVIE':
|
||||
return False
|
||||
return check_detection_cache(strip.name)
|
||||
|
||||
def execute(self, context):
|
||||
strip = context.scene.sequence_editor.active_strip
|
||||
output_dir = get_cache_dir_for_strip(strip.name)
|
||||
detections_path = os.path.join(output_dir, "detections.msgpack")
|
||||
|
||||
if not os.path.exists(detections_path):
|
||||
self.report({'ERROR'}, f"Detection cache not found: {detections_path}")
|
||||
return {'CANCELLED'}
|
||||
|
||||
generator = get_generator()
|
||||
scene = context.scene
|
||||
wm = context.window_manager
|
||||
wm.mask_progress = 0
|
||||
wm.mask_total = 0 # サーバー側から実際の値に更新される
|
||||
|
||||
def on_complete(status, data):
|
||||
wm.mask_total = max(wm.mask_total, generator.total_frames)
|
||||
if status == "done":
|
||||
wm.mask_progress = wm.mask_total
|
||||
elif status in {"error", "cancelled"}:
|
||||
wm.mask_progress = min(wm.mask_progress, wm.mask_total)
|
||||
|
||||
if status == "done":
|
||||
print(f"[FaceMask] Pose augmentation completed: {data}")
|
||||
elif status == "error":
|
||||
print(f"[FaceMask] Error: {data}")
|
||||
elif status == "cancelled":
|
||||
print("[FaceMask] Pose augmentation cancelled")
|
||||
|
||||
for area in context.screen.areas:
|
||||
if area.type == 'SEQUENCE_EDITOR':
|
||||
area.tag_redraw()
|
||||
|
||||
def on_progress(current, total_f):
|
||||
wm.mask_progress = current
|
||||
wm.mask_total = total_f
|
||||
for area in context.screen.areas:
|
||||
if area.type == 'SEQUENCE_EDITOR':
|
||||
area.tag_redraw()
|
||||
|
||||
try:
|
||||
generator.start_augment_pose(
|
||||
detections_path=detections_path,
|
||||
total_frames=0,
|
||||
conf_threshold=scene.facemask_conf_threshold,
|
||||
iou_threshold=scene.facemask_iou_threshold,
|
||||
on_complete=on_complete,
|
||||
on_progress=on_progress,
|
||||
)
|
||||
except RuntimeError as e:
|
||||
self.report({'WARNING'}, str(e))
|
||||
return {'CANCELLED'}
|
||||
except Exception as e:
|
||||
self.report({'ERROR'}, f"Failed to start pose augmentation: {e}")
|
||||
return {'CANCELLED'}
|
||||
|
||||
self.report({'INFO'}, f"Started pose augmentation for {strip.name}")
|
||||
return {'FINISHED'}
|
||||
|
||||
|
||||
# Registration
|
||||
classes = [
|
||||
SEQUENCER_OT_generate_face_mask,
|
||||
SEQUENCER_OT_cancel_mask_generation,
|
||||
SEQUENCER_OT_augment_pose_mask,
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -289,6 +289,12 @@ class SEQUENCER_PT_face_mask(Panel):
|
|||
icon='FILE_REFRESH',
|
||||
)
|
||||
op.force = True
|
||||
if strip.type == 'MOVIE':
|
||||
box.operator(
|
||||
"sequencer.augment_pose_mask",
|
||||
text="Augment with Pose",
|
||||
icon='MOD_ARMATURE',
|
||||
)
|
||||
|
||||
def _draw_blur_controls(self, layout, context, strip):
|
||||
"""Draw blur application controls."""
|
||||
|
|
|
|||
|
|
@ -228,3 +228,205 @@ def get_detector(**kwargs) -> YOLOHeadDetector:
|
|||
if _detector is None:
|
||||
_detector = YOLOHeadDetector(**kwargs)
|
||||
return _detector
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pose-based head detector (YOLOv8 pose estimation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# COCO pose keypoint indices
|
||||
_HEAD_KP = [0, 1, 2, 3, 4] # nose, left_eye, right_eye, left_ear, right_ear
|
||||
_SHOULDER_KP = [5, 6] # left_shoulder, right_shoulder
|
||||
_KP_CONF_THRESH = 0.3
|
||||
|
||||
|
||||
def _head_bbox_from_pose(
|
||||
kp_xy: np.ndarray,
|
||||
kp_conf: np.ndarray,
|
||||
person_x1: float,
|
||||
person_y1: float,
|
||||
person_x2: float,
|
||||
person_y2: float,
|
||||
) -> Tuple[int, int, int, int]:
|
||||
"""
|
||||
Estimate head bounding box (x, y, w, h) from COCO pose keypoints.
|
||||
|
||||
Strategy:
|
||||
1. Use head keypoints (0-4: nose, eyes, ears) if visible.
|
||||
2. Fall back to shoulder keypoints (5-6) to infer head position.
|
||||
3. Last resort: use top of the person bounding box.
|
||||
"""
|
||||
person_w = max(person_x2 - person_x1, 1.0)
|
||||
|
||||
# --- Step 1: head keypoints ---
|
||||
visible_head = [
|
||||
(float(kp_xy[i][0]), float(kp_xy[i][1]))
|
||||
for i in _HEAD_KP
|
||||
if float(kp_conf[i]) > _KP_CONF_THRESH
|
||||
]
|
||||
if visible_head:
|
||||
xs = [p[0] for p in visible_head]
|
||||
ys = [p[1] for p in visible_head]
|
||||
kp_x1, kp_y1 = min(xs), min(ys)
|
||||
kp_x2, kp_y2 = max(xs), max(ys)
|
||||
span = max(kp_x2 - kp_x1, kp_y2 - kp_y1, 1.0)
|
||||
cx = (kp_x1 + kp_x2) / 2.0
|
||||
cy = (kp_y1 + kp_y2) / 2.0
|
||||
r = max(span * 0.5, person_w * 0.10)
|
||||
x1 = int(cx - r)
|
||||
y1 = int(cy - r)
|
||||
x2 = int(cx + r)
|
||||
y2 = int(cy + r)
|
||||
return x1, y1, x2 - x1, y2 - y1
|
||||
|
||||
# --- Step 2: shoulder keypoints ---
|
||||
visible_shoulder = [
|
||||
(float(kp_xy[i][0]), float(kp_xy[i][1]))
|
||||
for i in _SHOULDER_KP
|
||||
if float(kp_conf[i]) > _KP_CONF_THRESH
|
||||
]
|
||||
if visible_shoulder:
|
||||
cx = sum(p[0] for p in visible_shoulder) / len(visible_shoulder)
|
||||
cy_sh = sum(p[1] for p in visible_shoulder) / len(visible_shoulder)
|
||||
if len(visible_shoulder) == 2:
|
||||
sh_width = abs(visible_shoulder[1][0] - visible_shoulder[0][0])
|
||||
else:
|
||||
sh_width = person_w * 0.5
|
||||
r = max(sh_width * 0.3, person_w * 0.12)
|
||||
cy = cy_sh - r * 1.3
|
||||
x1 = int(cx - r)
|
||||
y1 = int(cy - r)
|
||||
x2 = int(cx + r)
|
||||
y2 = int(cy + r)
|
||||
return x1, y1, x2 - x1, y2 - y1
|
||||
|
||||
# --- Step 3: person bbox top ---
|
||||
r = max(person_w * 0.15, 20.0)
|
||||
cx = (person_x1 + person_x2) / 2.0
|
||||
x1 = int(cx - r)
|
||||
y1 = int(person_y1)
|
||||
x2 = int(cx + r)
|
||||
y2 = int(person_y1 + r * 2.0)
|
||||
return x1, y1, x2 - x1, y2 - y1
|
||||
|
||||
|
||||
class YOLOPoseHeadDetector:
|
||||
"""
|
||||
Head detector using YOLOv8 pose estimation with PyTorch ROCm support.
|
||||
|
||||
Extracts head bounding boxes from COCO pose keypoints (nose, eyes, ears).
|
||||
yolov8l-pose.pt is auto-downloaded by Ultralytics on first use.
|
||||
"""
|
||||
|
||||
DEFAULT_MODEL = os.path.join("models", "yolov8l-pose.pt")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: Optional[str] = None,
|
||||
conf_threshold: float = 0.25,
|
||||
iou_threshold: float = 0.45,
|
||||
input_size: Tuple[int, int] = (640, 640),
|
||||
):
|
||||
self.conf_threshold = conf_threshold
|
||||
self.iou_threshold = iou_threshold
|
||||
self.input_size = input_size
|
||||
self._model = None
|
||||
self._model_path = model_path
|
||||
self._device = None
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""Lazy-load YOLO pose model."""
|
||||
if self._model is None:
|
||||
from ultralytics import YOLO
|
||||
import torch
|
||||
|
||||
model_path = self._model_path if self._model_path is not None else self.DEFAULT_MODEL
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self._device = 'cuda'
|
||||
device_name = torch.cuda.get_device_name(0)
|
||||
print(f"[FaceMask] Using ROCm GPU for pose inference: {device_name}")
|
||||
else:
|
||||
self._device = 'cpu'
|
||||
print("[FaceMask] Using CPU for pose inference (ROCm GPU not available)")
|
||||
|
||||
try:
|
||||
self._model = YOLO(model_path)
|
||||
print(f"[FaceMask] Pose model loaded: {model_path}")
|
||||
print(f"[FaceMask] Device: {self._device}")
|
||||
except Exception as e:
|
||||
print(f"[FaceMask] Error loading pose model: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
return self._model
|
||||
|
||||
def _results_to_detections(self, result) -> List[Tuple[int, int, int, int, float]]:
|
||||
"""Convert a single YOLO pose result to (x, y, w, h, conf) tuples."""
|
||||
detections = []
|
||||
if result.boxes is None or result.keypoints is None:
|
||||
return detections
|
||||
|
||||
boxes = result.boxes
|
||||
keypoints = result.keypoints
|
||||
|
||||
for i, box in enumerate(boxes):
|
||||
conf = float(box.conf[0].cpu().numpy())
|
||||
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
|
||||
|
||||
kp_data = keypoints.data[i].cpu().numpy() # shape (17, 3): x, y, conf
|
||||
kp_xy = kp_data[:, :2]
|
||||
kp_conf = kp_data[:, 2]
|
||||
|
||||
hx, hy, hw, hh = _head_bbox_from_pose(
|
||||
kp_xy, kp_conf,
|
||||
float(x1), float(y1), float(x2), float(y2),
|
||||
)
|
||||
detections.append((hx, hy, hw, hh, conf))
|
||||
|
||||
return detections
|
||||
|
||||
def detect_batch(self, frames: List[np.ndarray]) -> List[List[Tuple[int, int, int, int, float]]]:
|
||||
"""Detect heads in multiple frames at once (batch processing)."""
|
||||
if not frames:
|
||||
return []
|
||||
|
||||
try:
|
||||
results = self.model.predict(
|
||||
frames,
|
||||
conf=self.conf_threshold,
|
||||
iou=self.iou_threshold,
|
||||
imgsz=self.input_size[0],
|
||||
verbose=False,
|
||||
device=self._device,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[FaceMask] ERROR during pose batch inference: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
print("[FaceMask] Falling back to CPU inference...")
|
||||
self._device = 'cpu'
|
||||
results = self.model.predict(
|
||||
frames,
|
||||
conf=self.conf_threshold,
|
||||
iou=self.iou_threshold,
|
||||
imgsz=self.input_size[0],
|
||||
verbose=False,
|
||||
device='cpu',
|
||||
)
|
||||
|
||||
return [self._results_to_detections(r) for r in results]
|
||||
|
||||
|
||||
# Pose detector singleton
|
||||
_pose_detector: Optional[YOLOPoseHeadDetector] = None
|
||||
|
||||
|
||||
def get_pose_detector(**kwargs) -> YOLOPoseHeadDetector:
|
||||
"""Get or create the global YOLO pose head detector instance."""
|
||||
global _pose_detector
|
||||
if _pose_detector is None:
|
||||
_pose_detector = YOLOPoseHeadDetector(**kwargs)
|
||||
return _pose_detector
|
||||
|
|
|
|||
141
server/main.py
141
server/main.py
|
|
@ -49,7 +49,7 @@ import msgpack # noqa: E402
|
|||
# Add project root to path for imports if needed
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
from server.detector import get_detector # noqa: E402
|
||||
from server.detector import get_detector, get_pose_detector # noqa: E402
|
||||
|
||||
app = FastAPI(title="Face Mask Inference Server")
|
||||
|
||||
|
|
@ -142,6 +142,12 @@ class GenerateImagesRequest(BaseModel):
|
|||
iou_threshold: float = 0.45
|
||||
|
||||
|
||||
class AugmentPoseRequest(BaseModel):
|
||||
detections_path: str
|
||||
conf_threshold: float = 0.5
|
||||
iou_threshold: float = 0.45
|
||||
|
||||
|
||||
class BakeImagesRequest(BaseModel):
|
||||
image_dir: str
|
||||
filenames: List[str]
|
||||
|
|
@ -582,6 +588,129 @@ def process_bake_images_task(task_id: str, req: BakeImagesRequest):
|
|||
del cancel_events[task_id]
|
||||
|
||||
|
||||
def augment_pose_task(task_id: str, req: AugmentPoseRequest):
|
||||
"""Background task: run pose estimation and merge results into existing cache."""
|
||||
cap = None
|
||||
try:
|
||||
tasks[task_id].status = TaskStatus.PROCESSING
|
||||
cancel_event = cancel_events.get(task_id)
|
||||
|
||||
if not os.path.exists(req.detections_path):
|
||||
tasks[task_id].status = TaskStatus.FAILED
|
||||
tasks[task_id].message = f"Detections file not found: {req.detections_path}"
|
||||
return
|
||||
|
||||
with open(req.detections_path, "rb") as f:
|
||||
payload = msgpack.unpackb(f.read(), raw=False)
|
||||
|
||||
existing_frames: List[List[List[float]]] = payload.get("frames", [])
|
||||
video_path = payload.get("video_path")
|
||||
start_frame = int(payload.get("start_frame", 0))
|
||||
total = len(existing_frames)
|
||||
|
||||
if not video_path:
|
||||
tasks[task_id].status = TaskStatus.FAILED
|
||||
tasks[task_id].message = "Cache does not contain video_path (image caches not supported)"
|
||||
return
|
||||
if not os.path.exists(video_path):
|
||||
tasks[task_id].status = TaskStatus.FAILED
|
||||
tasks[task_id].message = f"Video not found: {video_path}"
|
||||
return
|
||||
if total == 0:
|
||||
tasks[task_id].status = TaskStatus.FAILED
|
||||
tasks[task_id].message = "Cache has no frames"
|
||||
return
|
||||
|
||||
tasks[task_id].total = total
|
||||
|
||||
detector = get_pose_detector(
|
||||
conf_threshold=req.conf_threshold,
|
||||
iou_threshold=req.iou_threshold,
|
||||
)
|
||||
_ = detector.model
|
||||
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
if not cap.isOpened():
|
||||
tasks[task_id].status = TaskStatus.FAILED
|
||||
tasks[task_id].message = "Failed to open video"
|
||||
return
|
||||
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
|
||||
if start_frame > 0:
|
||||
seek_ok = cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
|
||||
if not seek_ok:
|
||||
for _ in range(start_frame):
|
||||
ret, _ = cap.read()
|
||||
if not ret:
|
||||
tasks[task_id].status = TaskStatus.FAILED
|
||||
tasks[task_id].message = f"Failed to seek to start frame: {start_frame}"
|
||||
return
|
||||
|
||||
frame_buffer: List[np.ndarray] = []
|
||||
buffer_indices: List[int] = [] # existing_frames インデックス対応
|
||||
current_count = 0
|
||||
batch_size = 5
|
||||
|
||||
def process_pose_batch():
|
||||
nonlocal current_count
|
||||
if not frame_buffer:
|
||||
return
|
||||
batch_detections = detector.detect_batch(frame_buffer)
|
||||
for idx, detections in zip(buffer_indices, batch_detections):
|
||||
for x, y, w, h, conf in detections:
|
||||
bx, by, bw, bh = int(x), int(y), int(w), int(h)
|
||||
bx = max(0, bx)
|
||||
by = max(0, by)
|
||||
bw = min(width - bx, bw)
|
||||
bh = min(height - by, bh)
|
||||
if bw > 0 and bh > 0:
|
||||
existing_frames[idx].append([bx, by, bw, bh, float(conf)])
|
||||
current_count += 1
|
||||
tasks[task_id].progress = current_count
|
||||
frame_buffer.clear()
|
||||
buffer_indices.clear()
|
||||
|
||||
for i in range(total):
|
||||
if cancel_event and cancel_event.is_set():
|
||||
tasks[task_id].status = TaskStatus.CANCELLED
|
||||
tasks[task_id].message = "Cancelled by user"
|
||||
break
|
||||
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
frame_buffer.append(frame)
|
||||
buffer_indices.append(i)
|
||||
if len(frame_buffer) >= batch_size:
|
||||
process_pose_batch()
|
||||
|
||||
if frame_buffer:
|
||||
process_pose_batch()
|
||||
|
||||
if tasks[task_id].status == TaskStatus.PROCESSING:
|
||||
payload["frames"] = existing_frames
|
||||
with open(req.detections_path, "wb") as f:
|
||||
f.write(msgpack.packb(payload, use_bin_type=True))
|
||||
|
||||
tasks[task_id].status = TaskStatus.COMPLETED
|
||||
tasks[task_id].result_path = req.detections_path
|
||||
tasks[task_id].message = "Pose augmentation completed"
|
||||
print(f"[FaceMask] Pose augmentation completed: {req.detections_path}")
|
||||
|
||||
except Exception as e:
|
||||
tasks[task_id].status = TaskStatus.FAILED
|
||||
tasks[task_id].message = str(e)
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
if cap:
|
||||
cap.release()
|
||||
if task_id in cancel_events:
|
||||
del cancel_events[task_id]
|
||||
|
||||
|
||||
def process_video_task(task_id: str, req: GenerateRequest):
|
||||
"""Background task to detect faces and save bbox cache as msgpack."""
|
||||
cap = None
|
||||
|
|
@ -1261,6 +1390,16 @@ def generate_images_endpoint(req: GenerateImagesRequest, background_tasks: Backg
|
|||
return task
|
||||
|
||||
|
||||
@app.post("/augment_pose", response_model=Task)
|
||||
def augment_pose_endpoint(req: AugmentPoseRequest, background_tasks: BackgroundTasks):
|
||||
task_id = str(uuid.uuid4())
|
||||
task = Task(id=task_id, status=TaskStatus.PENDING)
|
||||
tasks[task_id] = task
|
||||
cancel_events[task_id] = threading.Event()
|
||||
background_tasks.add_task(augment_pose_task, task_id, req)
|
||||
return task
|
||||
|
||||
|
||||
@app.post("/bake_image_blur", response_model=Task)
|
||||
def bake_image_blur_endpoint(req: BakeImagesRequest, background_tasks: BackgroundTasks):
|
||||
task_id = str(uuid.uuid4())
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user