"""
Face Detection Inference Server.

This FastAPI application runs in a separate process to handle
GPU-accelerated face detection using ONNX Runtime.
"""

import os
import sys
import threading
import uuid
import queue
import traceback
from typing import Dict, Optional, List
from pathlib import Path

from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
import uvicorn
import cv2
import numpy as np

# Add project root to path for imports if needed
sys.path.append(str(Path(__file__).parent.parent))

from server.detector import YOLOFaceDetector, get_detector

app = FastAPI(title="Face Mask Inference Server")


# Task storage
class TaskStatus:
    """String constants describing the lifecycle state of a task."""

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class Task(BaseModel):
    """Client-visible state of one mask-generation job."""

    id: str
    status: str  # one of the TaskStatus constants
    progress: int = 0  # number of frames handled so far
    total: int = 0  # number of frames the job will handle
    message: Optional[str] = None  # human-readable outcome / error text
    result_path: Optional[str] = None  # output directory, set on completion


# In-memory storage
tasks: Dict[str, Task] = {}
cancel_events: Dict[str, threading.Event] = {}


class GenerateRequest(BaseModel):
    """Request body for the ``/generate`` endpoint."""

    video_path: str
    output_dir: str
    start_frame: int
    end_frame: int
    conf_threshold: float = 0.5
    iou_threshold: float = 0.45
    mask_scale: float = 1.5


def _fail_task(task_id: str, message: str) -> None:
    """Mark the task as failed and record *message* for the client."""
    tasks[task_id].status = TaskStatus.FAILED
    tasks[task_id].message = message


def process_video_task(task_id: str, req: GenerateRequest):
    """Background task to process video.

    Reads frames ``req.start_frame`` .. ``req.end_frame`` (inclusive, with
    the end clamped to the frame count reported by OpenCV), runs the face
    detector on each one and writes a mask image per frame into
    ``req.output_dir`` as ``mask_NNNNNN.png``, where ``NNNNNN`` is the
    frame's offset from ``req.start_frame``.

    Progress and the final outcome are reported through ``tasks[task_id]``;
    nothing is returned. Any exception is caught, logged and turned into a
    ``FAILED`` status. The video capture is always released and the task's
    cancel event is always removed on exit.

    Args:
        task_id: Key of the task in ``tasks`` / ``cancel_events``.
        req: Parameters of the job.
    """
    cap = None
    try:
        task = tasks[task_id]
        task.status = TaskStatus.PROCESSING
        cancel_event = cancel_events.get(task_id)

        # Verify video exists
        if not os.path.exists(req.video_path):
            _fail_task(task_id, f"Video not found: {req.video_path}")
            return

        # Initialize detector (will load model on first run)
        print(f"Loading detector for task {task_id}...")
        detector = get_detector(
            conf_threshold=req.conf_threshold,
            iou_threshold=req.iou_threshold,
        )
        # Ensure model is loaded (attribute access presumably triggers the
        # load -- confirm against server.detector)
        _ = detector.model

        # Open video
        cap = cv2.VideoCapture(req.video_path)
        if not cap.isOpened():
            _fail_task(task_id, "Failed to open video")
            return

        # Determine frame range
        total_video_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        end_frame = min(req.end_frame, total_video_frames - 1)

        # An empty or negative range would otherwise yield a negative
        # ``total`` and a "completed" task that produced no output.
        if req.start_frame < 0 or req.start_frame > end_frame:
            _fail_task(
                task_id,
                f"Invalid frame range: {req.start_frame}-{req.end_frame} "
                f"(video has {total_video_frames} frames)",
            )
            return

        frames_to_process = end_frame - req.start_frame + 1
        task.total = frames_to_process

        # Ensure output directory exists
        os.makedirs(req.output_dir, exist_ok=True)

        print(f"Starting processing: {req.video_path} ({frames_to_process} frames)")

        # Process loop. Seek only when needed: once before the first frame
        # and again after a failed read. Consecutive successful reads
        # already advance the capture position by one frame, so seeking
        # before every read is redundant work.
        needs_seek = True
        for offset in range(frames_to_process):
            if cancel_event and cancel_event.is_set():
                task.status = TaskStatus.CANCELLED
                task.message = "Cancelled by user"
                break

            # Read frame
            if needs_seek:
                cap.set(cv2.CAP_PROP_POS_FRAMES, req.start_frame + offset)
            ret, frame = cap.read()
            needs_seek = not ret

            if ret:
                # Detect
                detections = detector.detect(frame)

                # Generate mask
                mask = detector.generate_mask(
                    frame.shape,
                    detections,
                    mask_scale=req.mask_scale,
                )

                # Save (filename uses the index relative to start_frame)
                mask_path = os.path.join(req.output_dir, f"mask_{offset:06d}.png")
                # imwrite signals failure through its return value; surface
                # it instead of reporting success with missing files.
                if not cv2.imwrite(mask_path, mask):
                    raise OSError(f"Failed to write mask: {mask_path}")

            # Update progress (unreadable frames still count as handled)
            task.progress = offset + 1

        # Only promote to COMPLETED if the loop was not cancelled.
        if task.status == TaskStatus.PROCESSING:
            task.status = TaskStatus.COMPLETED
            task.result_path = req.output_dir
            task.message = "Processing completed successfully"
            print(f"Task {task_id} completed.")

    except Exception as e:
        # Top-level boundary of the background job: record and log.
        _fail_task(task_id, str(e))
        print(f"Error in task {task_id}: {e}")
        traceback.print_exc()
    finally:
        # Cleanup: release the capture on every exit path, including errors.
        if cap is not None:
            cap.release()
        cancel_events.pop(task_id, None)


@app.get("/status")
def get_status():
    """Report that the server is up."""
    return {"status": "running", "gpu_available": True}  # TODO: check GPU


@app.post("/generate", response_model=Task)
def generate_mask_endpoint(req: GenerateRequest, background_tasks: BackgroundTasks):
    """Register a new mask-generation task and schedule it in the background.

    Returns the freshly created task (status ``pending``); poll
    ``/tasks/{task_id}`` for progress.
    """
    task_id = str(uuid.uuid4())
    task = Task(id=task_id, status=TaskStatus.PENDING)
    tasks[task_id] = task
    cancel_events[task_id] = threading.Event()

    background_tasks.add_task(process_video_task, task_id, req)
    return task


@app.get("/tasks/{task_id}", response_model=Task)
def get_task(task_id: str):
    """Return the current state of a task, or 404 if the id is unknown."""
    if task_id not in tasks:
        raise HTTPException(status_code=404, detail="Task not found")
    return tasks[task_id]


@app.post("/tasks/{task_id}/cancel")
def cancel_task(task_id: str):
    """Request cancellation of a task, or 404 if the id is unknown.

    The request is a no-op when the task has already finished (its cancel
    event has been removed by then).
    """
    if task_id not in tasks:
        raise HTTPException(status_code=404, detail="Task not found")

    # Single lookup: the worker may remove the event concurrently when the
    # task finishes, so a check-then-index sequence could raise KeyError.
    event = cancel_events.get(task_id)
    if event is not None:
        event.set()

    return {"message": "Cancellation requested"}


if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8181)