blender-mask-peoples/server/main.py
2026-02-07 07:47:23 +09:00

181 lines
5.5 KiB
Python

"""
Face Detection Inference Server.
This FastAPI application runs in a separate process to handle
GPU-accelerated face detection using ONNX Runtime.
"""
import os
import sys
import threading
import uuid
import queue
import traceback
from typing import Dict, Optional, List
from pathlib import Path
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
import uvicorn
import cv2
import numpy as np
# Add project root to path for imports if needed
sys.path.append(str(Path(__file__).parent.parent))
from server.detector import YOLOFaceDetector, get_detector
# ASGI application object; the route decorators further down register their
# handlers on it, and the __main__ guard hands it to uvicorn.
app = FastAPI(title="Face Mask Inference Server")
# Task storage
class TaskStatus:
    """String constants naming the lifecycle states stored in ``Task.status``."""

    # Registered by the /generate endpoint, not yet picked up by the worker.
    PENDING = "pending"
    # The background worker has started on the task.
    PROCESSING = "processing"

    # Terminal states.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class Task(BaseModel):
    """Progress record for one mask-generation job, returned by the task endpoints."""

    # Identifier generated by the /generate endpoint (a UUID4 string).
    id: str
    # One of the TaskStatus string constants.
    status: str

    # Frame counters maintained by the background worker.
    progress: int = 0
    total: int = 0

    # Human-readable outcome or error text; None until the worker sets it.
    message: Optional[str] = None
    # Output directory holding the masks; set only on successful completion.
    result_path: Optional[str] = None
# In-memory storage: module-level dicts, so all state is lost when the server
# process exits.
# Task records keyed by task id. Nothing in the code visible here ever removes
# an entry.
tasks: Dict[str, Task] = {}
# Per-task cancellation flags keyed by task id. The /cancel endpoint sets the
# event; process_video_task polls it and drops the entry when it finishes.
cancel_events: Dict[str, threading.Event] = {}
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``: which frames to mask and how."""

    # Path of the input video; the worker checks it exists and opens it with OpenCV.
    video_path: str
    # Directory the mask PNGs are written to; created by the worker if missing.
    output_dir: str

    # Inclusive frame range. The worker clamps end_frame to the last frame
    # reported by the video container.
    start_frame: int
    end_frame: int

    # Thresholds forwarded unchanged to get_detector().
    conf_threshold: float = 0.5
    iou_threshold: float = 0.45

    # Forwarded unchanged to detector.generate_mask() as mask_scale.
    mask_scale: float = 1.5
def process_video_task(task_id: str, req: GenerateRequest) -> None:
    """Background task: write one face-mask PNG per frame of the requested range.

    The outcome is reported by mutating ``tasks[task_id]`` (status, progress,
    total, message, result_path); nothing is returned and no exception is
    allowed to escape to the caller.

    Args:
        task_id: Key of the task record in ``tasks`` / ``cancel_events``.
        req: The generation request (video path, output directory, inclusive
            frame range, detector thresholds and mask scale).

    Behaviour worth knowing:
        * Masks are named ``mask_{i:06d}.png`` where ``i`` is the frame's
          offset from ``req.start_frame``, so a frame that cannot be decoded
          leaves a gap instead of shifting every later mask by one.
        * ``progress`` counts frames visited (readable or not), so it reaches
          ``total`` when the loop finishes.
        * An empty or negative frame range is reported as FAILED.
        * The capture handle is released and the cancel event removed on
          every exit path, including exceptions.
    """
    task = tasks[task_id]
    cap = None
    try:
        task.status = TaskStatus.PROCESSING
        cancel_event = cancel_events.get(task_id)

        # Verify video exists
        if not os.path.exists(req.video_path):
            task.status = TaskStatus.FAILED
            task.message = f"Video not found: {req.video_path}"
            return

        # Initialize detector (original author's note: loads the model on first run)
        print(f"Loading detector for task {task_id}...")
        detector = get_detector(
            conf_threshold=req.conf_threshold,
            iou_threshold=req.iou_threshold,
        )
        # Touch the attribute up front so model loading happens before the
        # frame loop rather than inside it.
        _ = detector.model

        # Open video
        cap = cv2.VideoCapture(req.video_path)
        if not cap.isOpened():
            task.status = TaskStatus.FAILED
            task.message = "Failed to open video"
            return

        # Determine frame range (end_frame is inclusive and clamped to the video)
        total_video_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        end_frame = min(req.end_frame, total_video_frames - 1)
        frames_to_process = end_frame - req.start_frame + 1
        if req.start_frame < 0 or frames_to_process <= 0:
            # Without this guard the task would report a negative total and
            # then "complete" without producing a single mask.
            task.status = TaskStatus.FAILED
            task.message = (
                f"Invalid frame range {req.start_frame}-{req.end_frame} "
                f"(video has {total_video_frames} frames)"
            )
            return
        task.total = frames_to_process

        # Ensure output directory exists
        os.makedirs(req.output_dir, exist_ok=True)
        print(f"Starting processing: {req.video_path} ({frames_to_process} frames)")

        # Process loop
        skipped = 0
        for frame_idx in range(req.start_frame, end_frame + 1):
            if cancel_event and cancel_event.is_set():
                task.status = TaskStatus.CANCELLED
                task.message = "Cancelled by user"
                break

            # Offset of this frame inside the requested range; used for both
            # the file name and the progress counter.
            relative_idx = frame_idx - req.start_frame

            # Read frame
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if ret:
                # Detect
                detections = detector.detect(frame)
                # Generate mask
                mask = detector.generate_mask(
                    frame.shape,
                    detections,
                    mask_scale=req.mask_scale,
                )
                # Save; imwrite signals failure through its return value.
                mask_filename = f"mask_{relative_idx:06d}.png"
                mask_path = os.path.join(req.output_dir, mask_filename)
                if not cv2.imwrite(mask_path, mask):
                    raise OSError(f"Failed to write mask: {mask_path}")
            else:
                skipped += 1
                print(f"Warning: could not read frame {frame_idx}; skipping")

            # Update progress
            task.progress = relative_idx + 1

        if task.status == TaskStatus.PROCESSING:
            task.status = TaskStatus.COMPLETED
            task.result_path = req.output_dir
            if skipped:
                task.message = (
                    f"Processing completed; {skipped} unreadable frame(s) skipped"
                )
            else:
                task.message = "Processing completed successfully"
            print(f"Task {task_id} completed.")
    except Exception as e:
        # Top-level boundary of the background job: record the failure on the
        # task so the client can see it, and log the traceback server-side.
        task.status = TaskStatus.FAILED
        task.message = str(e)
        print(f"Error in task {task_id}: {e}")
        traceback.print_exc()
    finally:
        # Cleanup
        if cap is not None:
            cap.release()
        # pop() avoids a check-then-delete window against the cancel endpoint.
        cancel_events.pop(task_id, None)
@app.get("/status")
def get_status():
    """Liveness probe: report that the server process is up."""
    payload = {
        "status": "running",
        # TODO: check GPU -- this flag is hard-coded and not actually probed.
        "gpu_available": True,
    }
    return payload
@app.post("/generate", response_model=Task)
def generate_mask_endpoint(req: GenerateRequest, background_tasks: BackgroundTasks):
    """Register a new mask-generation task and schedule it in the background.

    Returns the freshly created ``Task`` record (status ``pending``); clients
    poll ``GET /tasks/{task_id}`` with its ``id`` to follow progress.
    """
    new_id = str(uuid.uuid4())
    record = Task(id=new_id, status=TaskStatus.PENDING)

    # Register the record and its cancel flag before scheduling, so the worker
    # finds both when it starts.
    tasks[new_id] = record
    cancel_events[new_id] = threading.Event()

    background_tasks.add_task(process_video_task, new_id, req)
    return record
@app.get("/tasks/{task_id}", response_model=Task)
def get_task(task_id: str):
    """Return the current record for ``task_id``, or respond 404 if unknown."""
    record = tasks.get(task_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Task not found")
    return record
@app.post("/tasks/{task_id}/cancel")
def cancel_task(task_id: str):
    """Ask the background worker to stop processing ``task_id``.

    Responds 404 when the task id is unknown. Otherwise sets the task's cancel
    event if one is still registered and acknowledges the request. The worker
    removes the event when it finishes, so for an already-finished task this
    is a no-op that still returns the same acknowledgement.
    """
    if task_id not in tasks:
        raise HTTPException(status_code=404, detail="Task not found")

    # Single lookup instead of "in" followed by indexing: the worker may
    # delete the entry between those two steps, which would surface as a
    # KeyError (HTTP 500) here.
    event = cancel_events.get(task_id)
    if event is not None:
        event.set()
    return {"message": "Cancellation requested"}
if __name__ == "__main__":
    # Script entry point: serve the app on the loopback interface, port 8181.
    listen_host = "127.0.0.1"
    listen_port = 8181
    uvicorn.run(app, host=listen_host, port=listen_port)