"""
Face Detection Inference Server.

This FastAPI application runs in a separate process to handle
GPU-accelerated face detection using ONNX Runtime.
"""
|
|
|
|
import os
|
|
import sys
|
|
import platform
|
|
import threading
|
|
import uuid
|
|
import queue
|
|
import traceback
|
|
from typing import Dict, Optional, List
|
|
from pathlib import Path
|
|
|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
|
from pydantic import BaseModel
|
|
import uvicorn
|
|
import cv2
|
|
import numpy as np
|
|
|
|
# Put the directory two levels above this file on sys.path so the
# ``server.detector`` import below resolves when the file is run directly.
sys.path.append(str(Path(__file__).parent.parent))

from server.detector import YOLOFaceDetector, get_detector

app = FastAPI(title="Face Mask Inference Server")

# Memoised result of check_gpu_available(); stays None until the first probe.
_gpu_status_cache: Optional[dict] = None
|
|
|
|
# Lifecycle states stored in Task.status (plain strings, since that field is str).
class TaskStatus:
    """Namespace of string constants for a task's lifecycle state.

    A task starts as PENDING when created, moves to PROCESSING once its
    background worker starts, and ends as COMPLETED, FAILED or CANCELLED.
    """

    PENDING: str = "pending"
    PROCESSING: str = "processing"
    COMPLETED: str = "completed"
    FAILED: str = "failed"
    CANCELLED: str = "cancelled"
|
|
|
|
class Task(BaseModel):
    """Status record of one mask-generation job, as returned by the API."""

    id: str  # uuid4 string assigned by /generate
    status: str  # one of the TaskStatus constants
    progress: int = 0  # number of mask files written so far
    total: int = 0  # number of frames in the (clamped) requested range
    message: Optional[str] = None  # outcome or error text, if any
    result_path: Optional[str] = None  # output directory, set on successful completion
|
|
|
|
# In-memory registries keyed by task id (nothing is persisted to disk).
# Every task created by /generate is kept in ``tasks``; nothing in this file
# removes entries from it.
tasks: Dict[str, Task] = dict()
# Cancel flags, one per task; the background worker deletes its own entry
# when it finishes.
cancel_events: Dict[str, threading.Event] = dict()
|
|
|
|
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``.

    ``start_frame`` and ``end_frame`` form an inclusive, zero-based frame
    range. The worker clamps ``end_frame`` to the last frame the video reports.
    ``conf_threshold`` and ``iou_threshold`` are passed to ``get_detector``.
    ``mask_scale`` is passed to ``generate_mask``.
    """

    video_path: str  # path of the source video on the server's filesystem
    output_dir: str  # created if missing; receives mask_NNNNNN.png files
    start_frame: int
    end_frame: int
    conf_threshold: float = 0.5
    iou_threshold: float = 0.45
    mask_scale: float = 1.5
|
|
|
|
def process_video_task(task_id: str, req: GenerateRequest):
    """Background worker that writes one face-mask PNG per frame of a video range.

    Scheduled by ``/generate``. It reads frames ``req.start_frame`` through
    ``req.end_frame`` (inclusive, clamped to the frame count the video reports
    when that count is known), runs the face detector on each frame, and writes
    ``mask_NNNNNN.png`` into ``req.output_dir``. ``NNNNNN`` is the frame's
    offset from ``req.start_frame``.

    Progress and the final outcome (completed / failed / cancelled, plus a
    message) are recorded on ``tasks[task_id]``. The task's cancel event is
    checked before every frame, and its ``cancel_events`` entry is removed when
    this function exits. Exceptions are caught, printed with a traceback, and
    recorded as a failed task instead of being propagated.

    Args:
        task_id: Key of the task record created by ``/generate``.
        req: Source video, output directory, frame range and detector settings.
    """
    cap = None
    try:
        task = tasks[task_id]
        task.status = TaskStatus.PROCESSING
        cancel_event = cancel_events.get(task_id)

        # Verify video exists
        if not os.path.exists(req.video_path):
            task.status = TaskStatus.FAILED
            task.message = f"Video not found: {req.video_path}"
            return

        # Initialize detector (will load model on first run)
        print(f"Loading detector for task {task_id}...")
        detector = get_detector(
            conf_threshold=req.conf_threshold,
            iou_threshold=req.iou_threshold
        )
        # Ensure model is loaded (accessing .model presumably triggers the load)
        _ = detector.model

        # Open video
        cap = cv2.VideoCapture(req.video_path)
        if not cap.isOpened():
            task.status = TaskStatus.FAILED
            task.message = "Failed to open video"
            return

        # CAP_PROP_FRAME_COUNT comes from container metadata. It can be 0 when
        # unknown and can overestimate, so only clamp to it when it is positive.
        # Otherwise, rely on read() failing at the end of the stream.
        total_video_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        end_frame = req.end_frame
        if total_video_frames > 0:
            end_frame = min(end_frame, total_video_frames - 1)

        # Reject ranges that would otherwise give a negative total or seek to a
        # negative position.
        if req.start_frame < 0 or end_frame < req.start_frame:
            task.status = TaskStatus.FAILED
            task.message = (
                f"Invalid frame range: start_frame={req.start_frame}, "
                f"end_frame={end_frame} (video reports {total_video_frames} frames)"
            )
            return

        frames_to_process = end_frame - req.start_frame + 1
        task.total = frames_to_process

        # Ensure output directory exists
        os.makedirs(req.output_dir, exist_ok=True)

        print(f"Starting processing: {req.video_path} ({frames_to_process} frames)")

        # Seek once, then decode sequentially. Re-seeking before every frame
        # can force the backend to re-decode from the nearest keyframe each time.
        cap.set(cv2.CAP_PROP_POS_FRAMES, req.start_frame)

        for offset in range(frames_to_process):
            if cancel_event is not None and cancel_event.is_set():
                task.status = TaskStatus.CANCELLED
                task.message = "Cancelled by user"
                break

            ret, frame = cap.read()
            if not ret:
                # Sequential decoding cannot skip an unreadable frame, so treat
                # this as the end of the stream. Stopping here keeps every mask
                # written so far aligned with its frame offset.
                print(
                    f"Warning: task {task_id} could not read frame "
                    f"{req.start_frame + offset}; stopping after {offset} frames"
                )
                break

            # Detect
            detections = detector.detect(frame)

            # Generate mask
            mask = detector.generate_mask(
                frame.shape,
                detections,
                mask_scale=req.mask_scale
            )

            # Save. The file index is relative to start_frame (mask_000000 == start_frame).
            mask_path = os.path.join(req.output_dir, f"mask_{offset:06d}.png")
            # imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(mask_path, mask):
                raise RuntimeError(f"Failed to write mask: {mask_path}")

            # Update progress
            task.progress = offset + 1

        if task.status == TaskStatus.PROCESSING:
            task.status = TaskStatus.COMPLETED
            task.result_path = req.output_dir
            task.message = "Processing completed successfully"
            print(f"Task {task_id} completed.")

    except Exception as e:
        # The record may be missing if the lookup itself failed.
        task = tasks.get(task_id)
        if task is not None:
            task.status = TaskStatus.FAILED
            task.message = str(e)
        print(f"Error in task {task_id}: {e}")
        traceback.print_exc()
    finally:
        # Release the decoder even when detection or writing raised.
        if cap is not None:
            cap.release()
        # Drop the cancel flag; pop() tolerates it already being absent.
        cancel_events.pop(task_id, None)
|
|
|
|
def check_gpu_available() -> dict:
    """
    Report whether PyTorch can see a GPU, probing only on the first call.

    The first call imports ``torch`` and queries ``torch.cuda``. The resulting
    dict is stored in the module-level ``_gpu_status_cache``, and that same
    object is returned by every later call.

    Returns a dict with GPU information:
    {
        "available": bool,
        "device_name": str or None,
        "device_count": int,
        "rocm_version": str or None
    }

    If the probe raises (including when torch is not installed), a warning is
    printed and "available" is forced to False. Fields set before the failure
    keep their values.
    """
    global _gpu_status_cache

    if _gpu_status_cache is None:
        info = dict(
            available=False,
            device_name=None,
            device_count=0,
            rocm_version=None,
        )
        try:
            import torch

            info["available"] = torch.cuda.is_available()
            if info["available"]:
                info["device_count"] = torch.cuda.device_count()
                if info["device_count"] > 0:
                    info["device_name"] = torch.cuda.get_device_name(0)

            # ROCm builds expose the HIP version here.
            if hasattr(torch.version, 'hip'):
                info["rocm_version"] = torch.version.hip
        except Exception as exc:
            print(f"[FaceMask] Warning: GPU detection failed: {exc}")
            info["available"] = False

        _gpu_status_cache = info

    return _gpu_status_cache
|
|
|
|
|
|
def log_startup_diagnostics():
    """Print a diagnostic report about the Python environment and GPU.

    The report covers:
    - interpreter and platform details, and whether a virtualenv is active;
    - the ROCm-related environment variables (values longer than 200
      characters are truncated);
    - what PyTorch reports about CUDA/HIP devices.

    Failures while importing or querying torch are printed rather than raised.
    """
    print("=" * 70)
    print("[FaceMask Server] Startup Diagnostics")
    print("=" * 70)

    # Python Environment
    print("\n[Python Environment]")
    print(f" Python Version: {sys.version.split()[0]}")
    print(f" Python Executable: {sys.executable}")
    print(f" Platform: {platform.platform()}")
    print(f" Working Directory: {os.getcwd()}")

    # Inside a venv, sys.prefix differs from the base interpreter's prefix.
    in_venv = sys.prefix != sys.base_prefix
    print(f" Virtual Environment: {'Yes' if in_venv else 'No'}")
    if in_venv:
        print(f" venv path: {sys.prefix}")

    # ROCm Environment Variables
    print("\n[ROCm Environment Variables]")
    rocm_vars = [
        'ROCM_PATH',
        'HSA_OVERRIDE_GFX_VERSION',
        'PYTORCH_ROCM_ARCH',
        'ROCBLAS_TENSILE_LIBPATH',
        'LD_LIBRARY_PATH'
    ]

    for var in rocm_vars:
        value = os.environ.get(var)
        if value:
            # Truncate very long values (LD_LIBRARY_PATH in particular)
            if len(value) > 200:
                display_value = value[:200] + "... (truncated)"
            else:
                display_value = value
            print(f" {var}: {display_value}")
        else:
            print(f" {var}: (not set)")

    # GPU Detection
    print("\n[GPU Detection]")
    try:
        import torch

        cuda_available = torch.cuda.is_available()
        print(f" torch.cuda.is_available(): {cuda_available}")

        if cuda_available:
            device_count = torch.cuda.device_count()
            print(f" GPU Device Count: {device_count}")

            if device_count > 0:
                device_name = torch.cuda.get_device_name(0)
                print(f" GPU Device 0: {device_name}")

            # torch.version.hip exists but is None on non-ROCm builds, so
            # test its value rather than its presence to avoid printing "None".
            hip_version = getattr(torch.version, 'hip', None)
            if hip_version:
                print(f" ROCm Version (HIP): {hip_version}")

            # CUDA version (might be emulated by ROCm)
            if torch.version.cuda:
                print(f" CUDA Version: {torch.version.cuda}")
        else:
            print(" WARNING: GPU not detected!")
            print(" Server will use CPU for inference (slower)")
            print(" Troubleshooting:")
            print(" - Check ROCm environment variables above")
            print(" - Run: python -c 'import torch; print(torch.cuda.is_available())'")

    except ImportError as e:
        print(f" ERROR: Cannot import torch: {e}")
        print(" PyTorch must be installed for inference")
    except Exception as e:
        print(f" ERROR during GPU detection: {e}")

    print("=" * 70)
    print()
|
|
|
|
|
|
@app.get("/status")
def get_status():
    """Liveness check that also reports the (cached) GPU detection result."""
    gpu = check_gpu_available()
    return dict(
        status="running",
        gpu_available=gpu["available"],
        gpu_device=gpu["device_name"],
        gpu_count=gpu["device_count"],
        rocm_version=gpu["rocm_version"],
    )
|
|
|
|
@app.post("/generate", response_model=Task)
def generate_mask_endpoint(req: GenerateRequest, background_tasks: BackgroundTasks):
    """Create a pending mask-generation task and schedule its worker.

    Returns the new task record. Clients poll ``/tasks/{id}`` for progress and
    may request cancellation via ``/tasks/{id}/cancel``.
    """
    new_id = str(uuid.uuid4())
    record = Task(id=new_id, status=TaskStatus.PENDING)

    # Register the record and its cancel flag before scheduling the worker,
    # which looks both up by id.
    tasks[new_id] = record
    cancel_events[new_id] = threading.Event()

    background_tasks.add_task(process_video_task, new_id, req)
    return record
|
|
|
|
@app.get("/tasks/{task_id}", response_model=Task)
def get_task(task_id: str):
    """Return the current record for ``task_id``, or respond 404 if unknown."""
    record = tasks.get(task_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Task not found")
    return record
|
|
|
|
@app.post("/tasks/{task_id}/cancel")
def cancel_task(task_id: str):
    """Request cancellation of a task.

    Sets the task's cancel event, which the background worker checks before
    each frame. A worker that has already finished has removed its event, so
    the request is then a no-op. The response is the same either way.

    Raises:
        HTTPException: 404 if ``task_id`` is unknown.
    """
    if task_id not in tasks:
        raise HTTPException(status_code=404, detail="Task not found")

    # Single lookup: the worker deletes its event when it exits, possibly
    # concurrently with this request. A separate ``in`` check followed by
    # indexing could therefore raise KeyError.
    event = cancel_events.get(task_id)
    if event is not None:
        event.set()

    return {"message": "Cancellation requested"}
|
|
|
|
if __name__ == "__main__":
    # Print environment/GPU diagnostics first, then serve on localhost only.
    log_startup_diagnostics()
    uvicorn.run(app, port=8181, host="127.0.0.1")
|