181 lines
5.5 KiB
Python
181 lines
5.5 KiB
Python
"""
|
|
Face Detection Inference Server.
|
|
|
|
This FastAPI application runs in a separate process to handle
|
|
GPU-accelerated face detection using ONNX Runtime.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import threading
|
|
import uuid
|
|
import queue
|
|
import traceback
|
|
from typing import Dict, Optional, List
|
|
from pathlib import Path
|
|
|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
|
from pydantic import BaseModel
|
|
import uvicorn
|
|
import cv2
|
|
import numpy as np
|
|
|
|
# Append the directory above this file's folder to the import path, presumably
# so the ``server`` package imported below resolves when this file is launched
# directly as a script.
sys.path.append(os.fspath(Path(__file__).parent.parent))

from server.detector import YOLOFaceDetector, get_detector

# The ASGI application served by uvicorn.
app = FastAPI(title="Face Mask Inference Server")
|
|
|
|
# Task lifecycle states (values stored in ``Task.status``).
class TaskStatus:
    """String constants naming the lifecycle states of a task.

    A task is created ``pending``, becomes ``processing`` once the background
    worker starts it, and finishes as ``completed``, ``failed`` or
    ``cancelled``.
    """

    PENDING: str = "pending"
    PROCESSING: str = "processing"
    COMPLETED: str = "completed"
    FAILED: str = "failed"
    CANCELLED: str = "cancelled"
|
|
|
|
class Task(BaseModel):
    """Record describing one mask-generation job.

    Returned by ``POST /generate`` and ``GET /tasks/{task_id}``.
    """

    # Identifier assigned when the task is created (a UUID4 string).
    id: str
    # One of the TaskStatus constants.
    status: str
    # Number of masks written so far.
    progress: int = 0
    # Number of frames scheduled for processing (0 until it is known).
    total: int = 0
    # Human-readable status detail or error text, if any.
    message: Optional[str] = None
    # Output directory holding the masks once the task has completed.
    result_path: Optional[str] = None
|
|
|
|
# In-memory registries; their contents are lost when the process exits.
# Task id -> Task record returned by the API.
tasks: Dict[str, Task] = dict()
# Task id -> Event used to request cancellation of that task; the entry is
# removed when the background worker finishes.
cancel_events: Dict[str, threading.Event] = dict()
|
|
|
|
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``."""

    # Path of the input video on this machine (the task fails if it is missing).
    video_path: str
    # Directory that receives the mask PNG files (created if missing).
    output_dir: str
    # First frame to process (0-based, inclusive).
    start_frame: int
    # Last frame to process (inclusive; may be clamped to the video's last frame).
    end_frame: int
    # Confidence threshold forwarded to get_detector().
    conf_threshold: float = 0.5
    # IoU threshold forwarded to get_detector() (presumably used for NMS — confirm).
    iou_threshold: float = 0.45
    # Forwarded to detector.generate_mask() as mask_scale
    # (presumably enlarges each face region — confirm).
    mask_scale: float = 1.5
|
|
|
|
def _resolve_end_frame(requested_end: int, total_video_frames: int) -> int:
    """Return the last frame index to process, clamped to the video when its length is known."""
    # CAP_PROP_FRAME_COUNT may be 0 (or negative) when the backend cannot
    # determine the length. Clamping to ``total - 1`` would then give -1 and
    # silently process nothing, so trust the request and let the read loop
    # stop at end of stream instead.
    if total_video_frames > 0:
        return min(requested_end, total_video_frames - 1)
    return requested_end


def process_video_task(task_id: str, req: GenerateRequest):
    """Background task: write one face mask PNG per frame of the requested range.

    Frames ``req.start_frame`` .. ``req.end_frame`` (inclusive) of
    ``req.video_path`` are decoded and passed through the face detector. Each
    resulting mask is written to ``req.output_dir`` as ``mask_{i:06d}.png``,
    where ``i`` is the frame's offset from ``req.start_frame``.

    Status, progress and messages are recorded on ``tasks[task_id]``. The
    task's cancel event, if any, is checked before each frame.

    Args:
        task_id: Key of an existing entry in ``tasks``.
        req: Parameters of the generation request.

    Never raises. Any exception is recorded as a FAILED status with the
    exception text as the message. The task's cancel event is always removed
    and the video capture is always released.
    """
    cap = None
    try:
        task = tasks[task_id]
        task.status = TaskStatus.PROCESSING
        cancel_event = cancel_events.get(task_id)

        # Verify video exists
        if not os.path.exists(req.video_path):
            task.status = TaskStatus.FAILED
            task.message = f"Video not found: {req.video_path}"
            return

        # Initialize detector (will load model on first run)
        print(f"Loading detector for task {task_id}...")
        detector = get_detector(
            conf_threshold=req.conf_threshold,
            iou_threshold=req.iou_threshold,
        )
        # Ensure session is loaded
        _ = detector.session

        cap = cv2.VideoCapture(req.video_path)
        if not cap.isOpened():
            task.status = TaskStatus.FAILED
            task.message = "Failed to open video"
            return

        # Determine frame range; reject ranges that cannot yield any frame.
        total_video_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        end_frame = _resolve_end_frame(req.end_frame, total_video_frames)
        if req.start_frame < 0 or req.start_frame > end_frame:
            task.status = TaskStatus.FAILED
            task.message = (
                f"Invalid frame range {req.start_frame}-{req.end_frame} "
                f"(video reports {total_video_frames} frames)"
            )
            return

        frames_to_process = end_frame - req.start_frame + 1
        task.total = frames_to_process

        # Ensure output directory exists
        os.makedirs(req.output_dir, exist_ok=True)

        print(f"Starting processing: {req.video_path} ({frames_to_process} frames)")

        # Seek once, then decode sequentially. Seeking via CAP_PROP_POS_FRAMES
        # before every frame forces a seek plus re-decode per frame, which is
        # far slower for compressed video.
        if req.start_frame > 0:
            cap.set(cv2.CAP_PROP_POS_FRAMES, req.start_frame)

        written = 0
        for offset in range(frames_to_process):
            if cancel_event is not None and cancel_event.is_set():
                task.status = TaskStatus.CANCELLED
                task.message = "Cancelled by user"
                break

            ret, frame = cap.read()
            if not ret:
                # End of stream came before the expected count (the reported
                # frame count can be an estimate), or decoding failed. Stop
                # here so mask file indices keep matching frame offsets.
                break

            detections = detector.detect(frame)
            mask = detector.generate_mask(
                frame.shape,
                detections,
                mask_scale=req.mask_scale,
            )

            # Masks are named by offset from start_frame, so mask_000000.png
            # always corresponds to start_frame.
            mask_path = os.path.join(req.output_dir, f"mask_{offset:06d}.png")
            # cv2.imwrite signals failure through its return value, not an exception.
            if not cv2.imwrite(mask_path, mask):
                raise RuntimeError(f"Failed to write mask: {mask_path}")

            written += 1
            task.progress = written

        if task.status == TaskStatus.PROCESSING:
            task.status = TaskStatus.COMPLETED
            task.result_path = req.output_dir
            if written == frames_to_process:
                task.message = "Processing completed successfully"
            else:
                task.message = (
                    f"Processing completed: only {written} of "
                    f"{frames_to_process} frames could be read"
                )
            print(f"Task {task_id} completed.")

    except Exception as e:
        tasks[task_id].status = TaskStatus.FAILED
        tasks[task_id].message = str(e)
        print(f"Error in task {task_id}: {e}")
        traceback.print_exc()
    finally:
        # Release the capture on every exit path, including exceptions.
        if cap is not None:
            cap.release()
        # The cancel event is no longer needed once the task has ended.
        cancel_events.pop(task_id, None)
|
|
|
|
@app.get("/status")
def get_status():
    """Report that the server is up.

    Note: ``gpu_available`` is currently hard-coded to True rather than probed.
    """
    # TODO: check GPU
    return dict(status="running", gpu_available=True)
|
|
|
|
@app.post("/generate", response_model=Task)
def generate_mask_endpoint(req: GenerateRequest, background_tasks: BackgroundTasks):
    """Register a new mask-generation task and schedule it as a background task.

    Returns the newly created task record with status ``pending``. Its
    progress can then be queried via ``GET /tasks/{task_id}``.
    """
    new_task_id = str(uuid.uuid4())
    new_task = Task(id=new_task_id, status=TaskStatus.PENDING)
    tasks[new_task_id] = new_task
    # Created up front so a cancel request can be recorded before the worker starts.
    cancel_events[new_task_id] = threading.Event()

    background_tasks.add_task(process_video_task, new_task_id, req)
    return new_task
|
|
|
|
@app.get("/tasks/{task_id}", response_model=Task)
def get_task(task_id: str):
    """Return the current record for ``task_id``.

    Raises:
        HTTPException: 404 if no task with that id exists.
    """
    found = tasks.get(task_id)
    if found is None:
        raise HTTPException(status_code=404, detail="Task not found")
    return found
|
|
|
|
@app.post("/tasks/{task_id}/cancel")
def cancel_task(task_id: str):
    """Request cancellation of a task.

    Sets the task's cancel event, which ``process_video_task`` checks before
    each frame. If the task has already finished, its event has been removed
    and the request is a no-op. The same acknowledgement is returned either way.

    Raises:
        HTTPException: 404 if no task with that id exists.
    """
    if task_id not in tasks:
        raise HTTPException(status_code=404, detail="Task not found")

    # Single lookup: the background worker deletes the event when it finishes,
    # possibly concurrently with this request. A membership test followed by
    # indexing could therefore raise KeyError, which surfaces as HTTP 500.
    event = cancel_events.get(task_id)
    if event is not None:
        event.set()

    return {"message": "Cancellation requested"}
|
|
|
|
if __name__ == "__main__":
    # Serve on the loopback interface only, port 8181.
    uvicorn.run(app, port=8181, host="127.0.0.1")
|