993 lines
34 KiB
Python
993 lines
34 KiB
Python
"""
|
||
Face Detection Inference Server.
|
||
|
||
This FastAPI application runs in a separate process to handle
|
||
GPU-accelerated face detection using ONNX Runtime.
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import platform
|
||
|
||
# CRITICAL: Fix LD_LIBRARY_PATH before importing cv2 or torch
|
||
# cv2 adds its own lib path to the front, which can override ROCm libraries
|
||
def fix_library_path():
    """Ensure ROCm libraries are loaded before cv2's bundled libraries."""
    entries = [entry for entry in os.environ.get('LD_LIBRARY_PATH', '').split(':') if entry]

    # Partition the search path: ROCm/HIP directories go first so they
    # cannot be shadowed by cv2's bundled copies.
    gpu_entries = []
    plain_entries = []
    for entry in entries:
        lowered = entry.lower()
        if 'rocm' in lowered or 'clr-' in entry or 'hip' in lowered:
            gpu_entries.append(entry)
        else:
            plain_entries.append(entry)

    if gpu_entries:
        os.environ['LD_LIBRARY_PATH'] = ':'.join(gpu_entries + plain_entries)
        print("[FaceMask] Fixed LD_LIBRARY_PATH to prioritize ROCm libraries")
|
||
|
||
# Fix library path BEFORE any other imports
# (cv2 prepends its own lib dir to LD_LIBRARY_PATH, which can shadow ROCm).
fix_library_path()
|
||
|
||
import queue
|
||
import threading
|
||
import uuid
|
||
import traceback
|
||
import subprocess
|
||
from typing import Dict, Optional, List
|
||
from pathlib import Path
|
||
|
||
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
||
from pydantic import BaseModel
|
||
import uvicorn
|
||
import cv2
|
||
import numpy as np
|
||
import msgpack
|
||
|
||
# Add project root to path for imports if needed
|
||
sys.path.append(str(Path(__file__).parent.parent))
|
||
|
||
from server.detector import get_detector
|
||
|
||
# HTTP application object; served by uvicorn in the __main__ guard below.
app = FastAPI(title="Face Mask Inference Server")
|
||
|
||
|
||
def _get_r_frame_rate(video_path: str) -> tuple:
|
||
"""ffprobe でコンテナ宣言の r_frame_rate を取得する。
|
||
|
||
Returns:
|
||
(fps_float, fps_str): fps_str は "120/1" のような分数文字列。
|
||
取得失敗時は (0.0, "")。
|
||
"""
|
||
try:
|
||
result = subprocess.run(
|
||
[
|
||
"ffprobe", "-v", "error",
|
||
"-select_streams", "v:0",
|
||
"-show_entries", "stream=r_frame_rate",
|
||
"-of", "default=noprint_wrappers=1:nokey=1",
|
||
video_path,
|
||
],
|
||
capture_output=True,
|
||
text=True,
|
||
timeout=10,
|
||
)
|
||
if result.returncode == 0:
|
||
rate_str = result.stdout.strip()
|
||
if "/" in rate_str:
|
||
num, den = rate_str.split("/")
|
||
fps_float = float(num) / float(den)
|
||
else:
|
||
fps_float = float(rate_str)
|
||
rate_str = str(fps_float)
|
||
return fps_float, rate_str
|
||
except Exception:
|
||
pass
|
||
return 0.0, ""
|
||
|
||
# GPU status cache — populated once by check_gpu_available() and reused
# for the lifetime of the process (probing torch is expensive).
_gpu_status_cache = None
|
||
|
||
# Task storage
class TaskStatus:
    """String constants for the lifecycle states of a background task."""
    PENDING = "pending"        # created, not yet picked up by a worker
    PROCESSING = "processing"  # worker is running
    COMPLETED = "completed"    # finished successfully; result_path is set
    FAILED = "failed"          # aborted with an error; message explains why
    CANCELLED = "cancelled"    # stopped via the cancel endpoint
|
||
|
||
class Task(BaseModel):
    """Status record for a background job, returned by the task endpoints."""
    id: str                            # UUID4 string identifying the task
    status: str                        # one of the TaskStatus constants
    progress: int = 0                  # frames processed so far
    total: int = 0                     # total frames to process
    message: Optional[str] = None      # human-readable status / error detail
    result_path: Optional[str] = None  # output file path once completed
|
||
|
||
# In-memory storage (single-process only; state is lost on restart)
tasks: Dict[str, Task] = {}
cancel_events: Dict[str, threading.Event] = {}  # set() requests cancellation
|
||
|
||
class GenerateRequest(BaseModel):
    """Request body for /generate: detect faces over a frame range."""
    video_path: str               # source video on the server's filesystem
    output_dir: str               # directory that receives detections.msgpack
    start_frame: int              # first frame index (inclusive)
    end_frame: int                # last frame index (inclusive)
    conf_threshold: float = 0.5   # detector confidence cut-off
    iou_threshold: float = 0.45   # NMS IoU threshold
|
||
|
||
|
||
class VideoInfoRequest(BaseModel):
    """Request body for /video_info."""
    video_path: str  # video file to probe
|
||
|
||
|
||
class BakeRequest(BaseModel):
    """Request body for /bake_blur: burn blur into a copy of the video."""
    video_path: str             # source video
    detections_path: str        # detections.msgpack produced by /generate
    output_path: str            # destination video path
    blur_size: int = 50         # Gaussian kernel size (forced odd, >= 1)
    display_scale: float = 1.0  # scale applied to each bbox around its center
    format: str = "mp4"         # output container: mp4 / mov / avi
|
||
|
||
|
||
class _FFmpegPipeWriter:
|
||
"""Write BGR frames to ffmpeg stdin."""
|
||
|
||
def __init__(self, cmd: List[str]):
|
||
self._proc = subprocess.Popen(
|
||
cmd,
|
||
stdin=subprocess.PIPE,
|
||
stdout=subprocess.DEVNULL,
|
||
stderr=subprocess.PIPE,
|
||
)
|
||
if self._proc.stdin is None:
|
||
self._proc.kill()
|
||
raise RuntimeError("Failed to open ffmpeg stdin")
|
||
|
||
def write(self, frame: np.ndarray) -> None:
|
||
if self._proc.stdin is None:
|
||
raise RuntimeError("ffmpeg stdin is not available")
|
||
self._proc.stdin.write(frame.tobytes())
|
||
|
||
def release(self) -> None:
|
||
if self._proc.stdin:
|
||
try:
|
||
self._proc.stdin.close()
|
||
except Exception:
|
||
pass
|
||
|
||
stderr_output = ""
|
||
if self._proc.stderr:
|
||
try:
|
||
stderr_output = self._proc.stderr.read().decode("utf-8", errors="replace")
|
||
except Exception:
|
||
stderr_output = ""
|
||
|
||
rc = self._proc.wait()
|
||
if rc != 0:
|
||
tail = "\n".join(stderr_output.strip().splitlines()[-8:])
|
||
raise RuntimeError(f"ffmpeg writer failed (code={rc}): {tail}")
|
||
|
||
|
||
def _build_ffmpeg_vaapi_writer(
    output_path: str,
    fps: float,
    width: int,
    height: int,
    out_fps_str: str = "",
) -> _FFmpegPipeWriter:
    """Create an ffmpeg h264_vaapi writer with QP=24 (balanced quality/speed).

    Args:
        fps: avg_frame_rate of the source video (input rate of the raw pipe).
        out_fps_str: r_frame_rate to declare on the output container
            (e.g. "120/1"). When it differs from the source rate, an fps
            filter is inserted to fill in frames.
    """
    # Insert the fps filter only when the declared output rate differs
    # measurably from the source's average rate.
    use_fps_filter = bool(out_fps_str)
    if use_fps_filter:
        try:
            if "/" in out_fps_str:
                numerator, denominator = out_fps_str.split("/")
                declared_fps = float(numerator) / float(denominator)
            else:
                declared_fps = float(out_fps_str)
            use_fps_filter = abs(declared_fps - fps) > 0.01
        except ValueError:
            use_fps_filter = False

    if use_fps_filter:
        video_filter = f"format=nv12,fps={out_fps_str},hwupload"
        print(f"[FaceMask] fps filter: {fps:.3f} -> {out_fps_str}")
    else:
        video_filter = "format=nv12,hwupload"

    command = [
        "ffmpeg",
        "-hide_banner",
        "-loglevel", "error",
        "-y",
        "-vaapi_device", "/dev/dri/renderD128",
        # Raw BGR frames arrive on stdin at the source's average rate.
        "-f", "rawvideo",
        "-pix_fmt", "bgr24",
        "-s", f"{width}x{height}",
        "-r", f"{fps}",
        "-i", "-",
        "-an",
        "-vf", video_filter,
        "-c:v", "h264_vaapi",
        "-qp", "24",
        output_path,
    ]
    return _FFmpegPipeWriter(command)
|
||
|
||
|
||
def _build_video_writer(
    output_path: str,
    fmt: str,
    fps: float,
    width: int,
    height: int,
    out_fps_str: str = "",
) -> object:
    """Create a video writer, preferring hardware VAAPI encoding.

    mp4/mov first try the ffmpeg h264_vaapi pipe writer; on failure (or
    for any other format) fall back to cv2.VideoWriter, trying each
    candidate codec in order.

    Raises:
        RuntimeError: when no writer could be opened for the format.
    """
    format_key = fmt.lower()

    if format_key in {"mp4", "mov"}:
        try:
            vaapi_writer = _build_ffmpeg_vaapi_writer(output_path, fps, width, height, out_fps_str)
        except Exception as e:
            print(f"[FaceMask] VAAPI writer unavailable, fallback to OpenCV: {e}")
        else:
            print("[FaceMask] Using output encoder: ffmpeg h264_vaapi (-qp 24)")
            return vaapi_writer

    candidates = {
        "mp4": ["avc1", "mp4v"],
        "mov": ["avc1", "mp4v"],
        "avi": ["MJPG", "XVID"],
    }.get(format_key, ["mp4v"])

    for fourcc in candidates:
        cv_writer = cv2.VideoWriter(
            output_path,
            cv2.VideoWriter_fourcc(*fourcc),
            fps,
            (width, height),
            isColor=True,
        )
        if cv_writer.isOpened():
            print(f"[FaceMask] Using output codec: {fourcc}")
            return cv_writer
        cv_writer.release()

    raise RuntimeError(f"Failed to create video writer for format='{fmt}'")
|
||
|
||
def _scale_bbox(
|
||
x: int,
|
||
y: int,
|
||
w: int,
|
||
h: int,
|
||
scale: float,
|
||
frame_width: int,
|
||
frame_height: int,
|
||
) -> Optional[List[int]]:
|
||
"""Scale bbox around center and clamp to frame boundaries."""
|
||
if w <= 0 or h <= 0:
|
||
return None
|
||
|
||
center_x = x + (w * 0.5)
|
||
center_y = y + (h * 0.5)
|
||
scaled_w = max(1, int(w * scale))
|
||
scaled_h = max(1, int(h * scale))
|
||
|
||
x1 = max(0, int(center_x - scaled_w * 0.5))
|
||
y1 = max(0, int(center_y - scaled_h * 0.5))
|
||
x2 = min(frame_width, x1 + scaled_w)
|
||
y2 = min(frame_height, y1 + scaled_h)
|
||
out_w = x2 - x1
|
||
out_h = y2 - y1
|
||
if out_w <= 0 or out_h <= 0:
|
||
return None
|
||
return [x1, y1, out_w, out_h]
|
||
|
||
|
||
def process_video_task(task_id: str, req: GenerateRequest):
    """Background task to detect faces and save bbox cache as msgpack.

    Reads frames [req.start_frame, end_frame] from req.video_path, runs
    the batched face detector, and writes per-frame bbox lists to
    <output_dir>/detections.msgpack. Progress/status are reported via
    the module-level `tasks` dict; cancellation is polled from
    `cancel_events[task_id]`.
    """
    cap = None

    try:
        tasks[task_id].status = TaskStatus.PROCESSING
        cancel_event = cancel_events.get(task_id)

        if not os.path.exists(req.video_path):
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = f"Video not found: {req.video_path}"
            return

        print(f"Loading detector for task {task_id}...")
        detector = get_detector(
            conf_threshold=req.conf_threshold,
            iou_threshold=req.iou_threshold,
        )
        # Touch the model property to force any lazy loading up front,
        # before timing-sensitive frame processing starts.
        _ = detector.model

        cap = cv2.VideoCapture(req.video_path)
        if not cap.isOpened():
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Failed to open video"
            return

        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_video_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Clamp the requested range to what the video actually contains.
        end_frame = min(req.end_frame, total_video_frames - 1)
        frames_to_process = end_frame - req.start_frame + 1
        if frames_to_process <= 0:
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Invalid frame range"
            return

        tasks[task_id].total = frames_to_process
        os.makedirs(req.output_dir, exist_ok=True)
        output_msgpack_path = os.path.join(req.output_dir, "detections.msgpack")

        if req.start_frame > 0:
            # Prefer a direct seek; fall back to reading and discarding
            # frames when the backend cannot seek this container.
            seek_ok = cap.set(cv2.CAP_PROP_POS_FRAMES, req.start_frame)
            if not seek_ok:
                for _ in range(req.start_frame):
                    ret, _ = cap.read()
                    if not ret:
                        tasks[task_id].status = TaskStatus.FAILED
                        tasks[task_id].message = (
                            f"Failed to seek to start frame: {req.start_frame}"
                        )
                        return

        frame_buffer: List[np.ndarray] = []
        frame_detections: List[List[List[float]]] = []
        batch_size = 5
        current_count = 0

        def process_batch():
            # Run the detector on the buffered frames, clamp each bbox to
            # the frame bounds, and append one list of boxes per frame.
            nonlocal current_count
            if not frame_buffer:
                return

            batch_detections = detector.detect_batch(frame_buffer)
            for detections in batch_detections:
                packed_detections: List[List[float]] = []
                for x, y, w, h, conf in detections:
                    # Store the bbox as-is (display scale is applied at bake time).
                    bx, by, bw, bh = int(x), int(y), int(w), int(h)
                    bx = max(0, bx)
                    by = max(0, by)
                    bw = min(width - bx, bw)
                    bh = min(height - by, bh)
                    if bw <= 0 or bh <= 0:
                        continue
                    packed_detections.append([bx, by, bw, bh, float(conf)])
                frame_detections.append(packed_detections)
                current_count += 1
                tasks[task_id].progress = current_count

            frame_buffer.clear()

        print(
            f"Starting detection cache generation: {req.video_path} "
            f"({frames_to_process} frames) -> {output_msgpack_path}"
        )

        for _ in range(req.start_frame, end_frame + 1):
            if cancel_event and cancel_event.is_set():
                tasks[task_id].status = TaskStatus.CANCELLED
                tasks[task_id].message = "Cancelled by user"
                break

            ret, frame = cap.read()
            if not ret:
                # Video ended earlier than the reported frame count.
                break

            frame_buffer.append(frame)
            if len(frame_buffer) >= batch_size:
                process_batch()

        # Flush any remaining partial batch.
        if frame_buffer:
            process_batch()

        # Only persist when the loop finished without cancel/failure.
        if tasks[task_id].status == TaskStatus.PROCESSING:
            payload = {
                "version": 1,
                "video_path": req.video_path,
                "start_frame": req.start_frame,
                # end_frame reflects frames actually processed, which may
                # be fewer than requested if the video ended early.
                "end_frame": req.start_frame + len(frame_detections) - 1,
                "width": width,
                "height": height,
                "fps": fps,
                "mask_scale": 1.0,
                "frames": frame_detections,
            }
            with open(output_msgpack_path, "wb") as f:
                f.write(msgpack.packb(payload, use_bin_type=True))

            tasks[task_id].status = TaskStatus.COMPLETED
            tasks[task_id].result_path = output_msgpack_path
            tasks[task_id].message = "Detection cache completed"
            print(f"Task {task_id} completed: {output_msgpack_path}")

    except Exception as e:
        tasks[task_id].status = TaskStatus.FAILED
        tasks[task_id].message = str(e)
        print(f"Error in task {task_id}: {e}")
        traceback.print_exc()
    finally:
        if cap:
            cap.release()
        if task_id in cancel_events:
            del cancel_events[task_id]
|
||
|
||
|
||
def process_bake_task(task_id: str, req: BakeRequest):
    """Bake blur using async pipeline: read/process/write run in parallel for 1.35x speedup.

    Three daemon threads communicate through bounded queues:
    reader -> read_queue -> processor -> process_queue -> writer.
    A None sentinel flows down the pipeline to signal end-of-stream.
    Errors from any stage land in `error_holder`; the final task status
    is decided after all three joins complete.
    """
    try:
        tasks[task_id].status = TaskStatus.PROCESSING
        cancel_event = cancel_events.get(task_id)

        if not os.path.exists(req.video_path):
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = f"Video not found: {req.video_path}"
            return
        if not os.path.exists(req.detections_path):
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = f"Detections file not found: {req.detections_path}"
            return

        with open(req.detections_path, "rb") as f:
            payload = msgpack.unpackb(f.read(), raw=False)
        frames_detections = payload.get("frames")
        if not isinstance(frames_detections, list):
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Invalid detections format: 'frames' is missing"
            return

        # First frame covered by the detection cache (source-video frame index).
        det_start_frame = int(payload.get("start_frame", 0))

        # Get video info
        temp_cap = cv2.VideoCapture(req.video_path)
        if not temp_cap.isOpened():
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Failed to open source video"
            return

        src_fps = temp_cap.get(cv2.CAP_PROP_FPS) or 30.0
        src_width = int(temp_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        src_height = int(temp_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        src_frames = int(temp_cap.get(cv2.CAP_PROP_FRAME_COUNT))
        temp_cap.release()

        # Probe r_frame_rate with ffprobe so the output container declares
        # the same FPS as the source. Example: a 60fps video recorded on a
        # 120fps timebase has r_frame_rate=120/1 while cv2 reports
        # avg_frame_rate=60fps; without this Blender would read a different
        # FPS from the baked file.
        r_fps_float, r_fps_str = _get_r_frame_rate(req.video_path)
        if r_fps_float > 0:
            print(f"[FaceMask] r_frame_rate={r_fps_str}, avg_fps={src_fps:.3f}")
        else:
            r_fps_str = ""

        if src_width <= 0 or src_height <= 0:
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Invalid source video dimensions"
            return

        # Output every frame of the source video (so trim settings keep
        # working after the strip is swapped to the baked file).
        total = src_frames if src_frames > 0 else (det_start_frame + len(frames_detections))
        if total <= 0:
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Source/detections frame count is zero"
            return
        tasks[task_id].total = total

        output_dir = os.path.dirname(req.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Pipeline setup
        blur_size = max(1, int(req.blur_size))
        if blur_size % 2 == 0:
            blur_size += 1  # GaussianBlur requires an odd kernel size
        display_scale = max(0.1, float(req.display_scale))
        # blur_margin is computation-only padding to avoid boundary
        # artifacts (it is never part of the displayed mask).
        blur_margin = blur_size // 2

        # Queues (bounded so a slow stage back-pressures the others)
        queue_size = 8
        read_queue: queue.Queue = queue.Queue(maxsize=queue_size)
        process_queue: queue.Queue = queue.Queue(maxsize=queue_size)

        # Shared state
        error_holder = {"error": None}  # last error wins; checked after joins
        progress_lock = threading.Lock()
        current_progress = [0]

        def _reader_worker():
            """Read frames from video and feed (idx, frame) into read_queue."""
            import time as _time
            cap = cv2.VideoCapture(req.video_path)
            if not cap.isOpened():
                # NOTE(review): returning here never enqueues the None
                # sentinel below, so the processor/writer threads would
                # block on get() and the joins would hang — confirm
                # whether this open can fail after the earlier probe
                # of the same path succeeded.
                error_holder["error"] = "Failed to open video in reader"
                return

            t_read_total = 0.0
            frame_count = 0
            try:
                for idx in range(total):
                    if cancel_event and cancel_event.is_set():
                        break

                    t0 = _time.perf_counter()
                    ok, frame = cap.read()
                    t_read_total += _time.perf_counter() - t0

                    if not ok:
                        break

                    read_queue.put((idx, frame))
                    frame_count += 1
            except Exception as e:
                error_holder["error"] = f"Reader error: {e}"
            finally:
                cap.release()
                read_queue.put(None)  # Sentinel
                if frame_count > 0:
                    print(
                        f"[Perf/Reader] FINAL frame={frame_count}"
                        f" read_avg={t_read_total/frame_count*1000:.1f}ms"
                        f" throughput≈{frame_count/max(t_read_total,1e-9):.1f}fps"
                    )

        def _processor_worker():
            """Process frames with ROI blur and pass them to process_queue."""
            import time as _time
            t_wait_total = 0.0
            t_blur_total = 0.0
            t_blend_total = 0.0
            frame_count = 0
            REPORT_INTERVAL = 50  # frames between perf log lines

            try:
                while True:
                    if cancel_event and cancel_event.is_set():
                        process_queue.put(None)
                        break

                    t0 = _time.perf_counter()
                    item = read_queue.get()
                    t_wait_total += _time.perf_counter() - t0

                    if item is None:
                        process_queue.put(None)
                        break

                    idx, frame = item
                    # Map the source frame index into the detection cache;
                    # frames outside the cached range pass through unblurred.
                    det_idx = idx - det_start_frame
                    frame_boxes = frames_detections[det_idx] if 0 <= det_idx < len(frames_detections) else []

                    if not frame_boxes:
                        process_queue.put((idx, frame))
                        frame_count += 1
                        continue

                    # Blur each person in an individual ROI (one combined ROI
                    # would balloon when people are far apart in the frame).
                    valid_boxes = []
                    for box in frame_boxes:
                        if not isinstance(box, list) or len(box) < 4:
                            continue
                        x, y, w, h = int(box[0]), int(box[1]), int(box[2]), int(box[3])
                        if w <= 0 or h <= 0:
                            continue
                        valid_boxes.append((x, y, w, h))

                    if not valid_boxes:
                        process_queue.put((idx, frame))
                        frame_count += 1
                        continue

                    for x, y, w, h in valid_boxes:
                        # display_scale controls the rendered mask size.
                        cx = x + w / 2
                        cy = y + h / 2
                        dw = max(1, int(w * display_scale))
                        dh = max(1, int(h * display_scale))
                        dx = int(cx - dw / 2)
                        dy = int(cy - dh / 2)

                        # ROI = display size + blur_margin (computation padding
                        # only, to avoid boundary artifacts).
                        roi_x1 = max(0, dx - blur_margin)
                        roi_y1 = max(0, dy - blur_margin)
                        roi_x2 = min(src_width, dx + dw + blur_margin)
                        roi_y2 = min(src_height, dy + dh + blur_margin)
                        roi_width = roi_x2 - roi_x1
                        roi_height = roi_y2 - roi_y1

                        if roi_width <= 0 or roi_height <= 0:
                            continue

                        # Blur the whole ROI (the margin keeps edge accuracy).
                        roi_src = frame[roi_y1:roi_y2, roi_x1:roi_x2]

                        # Downsample -> blur -> upsample: an equivalent blur at
                        # roughly 1/4 the computation.
                        t1 = _time.perf_counter()
                        small_w = max(1, roi_width // 2)
                        small_h = max(1, roi_height // 2)
                        roi_small = cv2.resize(roi_src, (small_w, small_h), interpolation=cv2.INTER_LINEAR)
                        small_blur_size = max(3, (blur_size // 2) | 1)  # keep kernel odd
                        roi_small_blurred = cv2.GaussianBlur(roi_small, (small_blur_size, small_blur_size), 0)
                        roi_blurred = cv2.resize(roi_small_blurred, (roi_width, roi_height), interpolation=cv2.INTER_LINEAR)
                        t_blur_total += _time.perf_counter() - t1

                        # Composite mask is just an ellipse at display size
                        # (no feathering).
                        roi_mask = np.zeros((roi_height, roi_width), dtype=np.uint8)
                        center = (int(cx) - roi_x1, int(cy) - roi_y1)
                        axes = (max(1, dw // 2), max(1, dh // 2))
                        cv2.ellipse(roi_mask, center, axes, 0, 0, 360, 255, -1)

                        # Binary mask, so copyTo gives fast compositing
                        # (no float32 conversion needed).
                        t2 = _time.perf_counter()
                        result = roi_src.copy()
                        cv2.copyTo(roi_blurred, roi_mask, result)
                        frame[roi_y1:roi_y2, roi_x1:roi_x2] = result
                        t_blend_total += _time.perf_counter() - t2

                    process_queue.put((idx, frame))
                    frame_count += 1

                    if frame_count % REPORT_INTERVAL == 0:
                        n = max(frame_count, 1)
                        fps_proc = frame_count / max(t_wait_total + t_blur_total + t_blend_total, 1e-9)
                        print(
                            f"[Perf/Processor] frame={frame_count}"
                            f" wait={t_wait_total/n*1000:.1f}ms"
                            f" blur={t_blur_total/n*1000:.1f}ms"
                            f" blend={t_blend_total/n*1000:.1f}ms"
                            f" ROI={roi_width}x{roi_height}"
                            f" throughput≈{fps_proc:.1f}fps"
                        )

            except Exception as e:
                error_holder["error"] = f"Processor error: {e}"
                process_queue.put(None)
            finally:
                if frame_count > 0:
                    n = max(frame_count, 1)
                    print(
                        f"[Perf/Processor] FINAL frame={frame_count}"
                        f" wait_avg={t_wait_total/n*1000:.1f}ms"
                        f" blur_avg={t_blur_total/n*1000:.1f}ms"
                        f" blend_avg={t_blend_total/n*1000:.1f}ms"
                    )

        def _writer_worker():
            """Write processed frames to the output file and track progress."""
            import time as _time
            t_wait_total = 0.0
            t_write_total = 0.0
            frame_count = 0
            writer = None
            try:
                writer = _build_video_writer(req.output_path, req.format, src_fps, src_width, src_height, r_fps_str)

                while True:
                    if cancel_event and cancel_event.is_set():
                        break

                    t0 = _time.perf_counter()
                    item = process_queue.get()
                    t_wait_total += _time.perf_counter() - t0

                    if item is None:
                        break

                    idx, frame = item
                    t1 = _time.perf_counter()
                    writer.write(frame)
                    t_write_total += _time.perf_counter() - t1
                    frame_count += 1

                    # Progress is reported from the last pipeline stage so it
                    # reflects frames actually written out.
                    with progress_lock:
                        current_progress[0] = idx + 1
                        tasks[task_id].progress = current_progress[0]

            except Exception as e:
                error_holder["error"] = f"Writer error: {e}"
            finally:
                if writer:
                    try:
                        writer.release()
                    except Exception as e:
                        print(f"[FaceMask] Writer release error: {e}")
                if frame_count > 0:
                    n = max(frame_count, 1)
                    print(
                        f"[Perf/Writer] FINAL frame={frame_count}"
                        f" wait_avg={t_wait_total/n*1000:.1f}ms"
                        f" write_avg={t_write_total/n*1000:.1f}ms"
                    )

        print(
            f"[FaceMask] Starting blur bake: {req.video_path} + "
            f"{req.detections_path} -> {req.output_path}"
        )

        # Start threads
        reader_thread = threading.Thread(target=_reader_worker, daemon=True)
        processor_thread = threading.Thread(target=_processor_worker, daemon=True)
        writer_thread = threading.Thread(target=_writer_worker, daemon=True)

        reader_thread.start()
        processor_thread.start()
        writer_thread.start()

        # Wait for completion
        reader_thread.join()
        processor_thread.join()
        writer_thread.join()

        if error_holder["error"]:
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = error_holder["error"]
            print(f"[FaceMask] Bake failed: {error_holder['error']}")
        elif cancel_event and cancel_event.is_set():
            tasks[task_id].status = TaskStatus.CANCELLED
            tasks[task_id].message = "Cancelled by user"
        else:
            tasks[task_id].status = TaskStatus.COMPLETED
            tasks[task_id].result_path = req.output_path
            tasks[task_id].message = "Blur bake completed"
            print(f"[FaceMask] Bake completed: {req.output_path}")

    except Exception as e:
        tasks[task_id].status = TaskStatus.FAILED
        tasks[task_id].message = str(e)
        print(f"Error in async bake task {task_id}: {e}")
        traceback.print_exc()
    finally:
        if task_id in cancel_events:
            del cancel_events[task_id]
|
||
|
||
|
||
def check_gpu_available() -> dict:
    """Check if GPU is available for inference (cached per process).

    Returns a dict with GPU information:
        {
            "available": bool,
            "device_name": str or None,
            "device_count": int,
            "rocm_version": str or None
        }
    """
    global _gpu_status_cache

    # Probing torch is expensive; reuse the first answer.
    if _gpu_status_cache is not None:
        return _gpu_status_cache

    info = {
        "available": False,
        "device_name": None,
        "device_count": 0,
        "rocm_version": None,
    }

    try:
        import torch

        info["available"] = torch.cuda.is_available()
        if info["available"]:
            info["device_count"] = torch.cuda.device_count()
            if info["device_count"] > 0:
                info["device_name"] = torch.cuda.get_device_name(0)
            # On ROCm builds torch.version.hip carries the HIP version.
            if hasattr(torch.version, 'hip'):
                info["rocm_version"] = torch.version.hip
    except Exception as e:
        print(f"[FaceMask] Warning: GPU detection failed: {e}")
        info["available"] = False

    # Cache the result for subsequent /status calls.
    _gpu_status_cache = info
    return info
|
||
|
||
|
||
def log_startup_diagnostics():
    """Log diagnostic information about the environment and GPU.

    Printed once at server start so a support log always captures the
    Python/venv setup, ROCm environment variables, and torch's view of
    the GPU. Purely informational.
    """
    print("=" * 70)
    print("[FaceMask Server] Startup Diagnostics")
    print("=" * 70)

    # Python Environment
    print("\n[Python Environment]")
    print(f" Python Version: {sys.version.split()[0]}")
    print(f" Python Executable: {sys.executable}")
    print(f" Platform: {platform.platform()}")
    print(f" Working Directory: {os.getcwd()}")

    # Check if in venv (sys.prefix differs from base_prefix inside one)
    in_venv = sys.prefix != sys.base_prefix
    print(f" Virtual Environment: {'Yes' if in_venv else 'No'}")
    if in_venv:
        print(f" venv path: {sys.prefix}")

    # ROCm Environment Variables
    print("\n[ROCm Environment Variables]")
    rocm_vars = [
        'ROCM_PATH',
        'HSA_OVERRIDE_GFX_VERSION',
        'PYTORCH_ROCM_ARCH',
        'ROCBLAS_TENSILE_LIBPATH',
        'LD_LIBRARY_PATH'
    ]

    for var in rocm_vars:
        value = os.environ.get(var)
        if value:
            # For LD_LIBRARY_PATH, show if ROCm paths are included
            if var == 'LD_LIBRARY_PATH':
                has_rocm = 'rocm' in value.lower() or 'clr-' in value
                has_hip = 'hip' in value.lower()
                print(f" {var}: {value[:100]}...")
                print(f" Contains ROCm paths: {has_rocm}")
                print(f" Contains HIP paths: {has_hip}")
                if not has_rocm:
                    print(" ⚠️ WARNING: ROCm library paths not found!")
            else:
                # Truncate very long values to keep the log readable.
                if len(value) > 200:
                    display_value = value[:200] + "... (truncated)"
                else:
                    display_value = value
                print(f" {var}: {display_value}")
        else:
            print(f" {var}: (not set)")

    # GPU Detection
    print("\n[GPU Detection]")
    try:
        import torch

        cuda_available = torch.cuda.is_available()
        print(f" torch.cuda.is_available(): {cuda_available}")

        if cuda_available:
            device_count = torch.cuda.device_count()
            print(f" GPU Device Count: {device_count}")

            if device_count > 0:
                device_name = torch.cuda.get_device_name(0)
                print(f" GPU Device 0: {device_name}")

            # ROCm version
            if hasattr(torch.version, 'hip'):
                print(f" ROCm Version (HIP): {torch.version.hip}")

            # CUDA version (might be emulated by ROCm)
            if torch.version.cuda:
                print(f" CUDA Version: {torch.version.cuda}")
        else:
            print(" WARNING: GPU not detected!")
            print(" Server will use CPU for inference (slower)")
            print(" Troubleshooting:")
            print(" - Check ROCm environment variables above")
            print(" - Run: python -c 'import torch; print(torch.cuda.is_available())'")

    except ImportError as e:
        print(f" ERROR: Cannot import torch: {e}")
        print(" PyTorch must be installed for inference")
    except Exception as e:
        print(f" ERROR during GPU detection: {e}")

    print("=" * 70)
    print()
|
||
|
||
|
||
@app.get("/status")
|
||
def get_status():
|
||
gpu_info = check_gpu_available()
|
||
|
||
return {
|
||
"status": "running",
|
||
"gpu_available": gpu_info["available"],
|
||
"gpu_device": gpu_info["device_name"],
|
||
"gpu_count": gpu_info["device_count"],
|
||
"rocm_version": gpu_info["rocm_version"]
|
||
}
|
||
|
||
|
||
@app.post("/video_info")
|
||
def get_video_info(req: VideoInfoRequest):
|
||
if not os.path.exists(req.video_path):
|
||
raise HTTPException(status_code=404, detail=f"Video not found: {req.video_path}")
|
||
|
||
cap = cv2.VideoCapture(req.video_path)
|
||
if not cap.isOpened():
|
||
raise HTTPException(status_code=400, detail="Failed to open video")
|
||
|
||
try:
|
||
avg_fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
|
||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
|
||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
|
||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
|
||
finally:
|
||
cap.release()
|
||
|
||
# Blender は r_frame_rate でタイムライン配置を計算するため、
|
||
# cv2 の avg_frame_rate ではなく r_frame_rate を fps として返す。
|
||
# 例: 120fps タイムベース記録の 60fps 動画で r_frame_rate=120 を返すことで
|
||
# compute_strip_frame_range の fps_ratio が Blender の解釈と一致する。
|
||
r_fps_float, _ = _get_r_frame_rate(req.video_path)
|
||
fps = r_fps_float if r_fps_float > 0 else avg_fps
|
||
|
||
return {
|
||
"video_path": req.video_path,
|
||
"fps": fps,
|
||
"width": width,
|
||
"height": height,
|
||
"frame_count": frame_count,
|
||
}
|
||
|
||
@app.post("/generate", response_model=Task)
|
||
def generate_mask_endpoint(req: GenerateRequest, background_tasks: BackgroundTasks):
|
||
task_id = str(uuid.uuid4())
|
||
task = Task(id=task_id, status=TaskStatus.PENDING)
|
||
tasks[task_id] = task
|
||
cancel_events[task_id] = threading.Event()
|
||
|
||
background_tasks.add_task(process_video_task, task_id, req)
|
||
return task
|
||
|
||
|
||
@app.post("/bake_blur", response_model=Task)
|
||
def bake_blur_endpoint(req: BakeRequest, background_tasks: BackgroundTasks):
|
||
task_id = str(uuid.uuid4())
|
||
task = Task(id=task_id, status=TaskStatus.PENDING)
|
||
tasks[task_id] = task
|
||
cancel_events[task_id] = threading.Event()
|
||
|
||
background_tasks.add_task(process_bake_task, task_id, req)
|
||
return task
|
||
|
||
@app.get("/tasks/{task_id}", response_model=Task)
|
||
def get_task(task_id: str):
|
||
if task_id not in tasks:
|
||
raise HTTPException(status_code=404, detail="Task not found")
|
||
return tasks[task_id]
|
||
|
||
@app.post("/tasks/{task_id}/cancel")
|
||
def cancel_task(task_id: str):
|
||
if task_id not in tasks:
|
||
raise HTTPException(status_code=404, detail="Task not found")
|
||
|
||
if task_id in cancel_events:
|
||
cancel_events[task_id].set()
|
||
|
||
return {"message": "Cancellation requested"}
|
||
|
||
if __name__ == "__main__":
|
||
log_startup_diagnostics()
|
||
uvicorn.run(app, host="127.0.0.1", port=8181)
|