424 lines
13 KiB
Python
424 lines
13 KiB
Python
"""
|
|
Face Detection Inference Server.
|
|
|
|
This FastAPI application runs in a separate process to handle
|
|
GPU-accelerated face detection using ONNX Runtime.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import platform
|
|
|
|
# CRITICAL: Fix LD_LIBRARY_PATH before importing cv2 or torch
# cv2 adds its own lib path to the front, which can override ROCm libraries
def fix_library_path():
    """Ensure ROCm libraries are loaded before cv2's bundled libraries.

    Reorders LD_LIBRARY_PATH so that any ROCm/HIP-related entries come
    first, preserving relative order within each group.  Must run before
    importing cv2 or torch, because cv2 prepends its own bundled library
    directory, which can shadow the ROCm runtime.  No-op when the
    variable contains no ROCm paths.
    """
    ld_path = os.environ.get('LD_LIBRARY_PATH', '')

    # Split and drop empty segments (e.g. from leading/trailing colons)
    paths = [p for p in ld_path.split(':') if p]

    # Separate ROCm/GPU paths from other paths
    rocm_paths = [p for p in paths if 'rocm' in p.lower() or 'clr-' in p or 'hip' in p.lower()]
    other_paths = [p for p in paths if p not in rocm_paths]

    # Rebuild with ROCm paths first
    if rocm_paths:
        os.environ['LD_LIBRARY_PATH'] = ':'.join(rocm_paths + other_paths)
        # Fix: original used an f-string with no placeholders (ruff F541).
        print("[FaceMask] Fixed LD_LIBRARY_PATH to prioritize ROCm libraries")


# Fix library path BEFORE any other imports
fix_library_path()
|
|
|
|
import threading
|
|
import uuid
|
|
import queue
|
|
import traceback
|
|
from typing import Dict, Optional, List
|
|
from pathlib import Path
|
|
|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
|
from pydantic import BaseModel
|
|
import uvicorn
|
|
import cv2
|
|
import numpy as np
|
|
|
|
# Add project root to path for imports if needed
|
|
sys.path.append(str(Path(__file__).parent.parent))
|
|
|
|
from server.detector import YOLOFaceDetector, get_detector
|
|
|
|
app = FastAPI(title="Face Mask Inference Server")
|
|
|
|
# GPU status cache
|
|
_gpu_status_cache = None
|
|
|
|
# Task storage
class TaskStatus:
    """String constants describing the lifecycle of a processing task.

    Used as the value of Task.status; plain strings (not an Enum) so the
    pydantic model serialises them directly.
    """

    PENDING = "pending"        # created, not yet picked up
    PROCESSING = "processing"  # worker is running
    COMPLETED = "completed"    # finished successfully
    FAILED = "failed"          # aborted with an error
    CANCELLED = "cancelled"    # stopped at the user's request
|
|
|
|
class Task(BaseModel):
    """Public status record for a background mask-generation task."""

    # Unique task identifier (uuid4 string)
    id: str
    # One of the TaskStatus constants
    status: str
    # Number of frames processed so far
    progress: int = 0
    # Total number of frames to process (0 until the worker computes it)
    total: int = 0
    # Human-readable status or error message
    message: Optional[str] = None
    # Path to the generated mask video once completed
    result_path: Optional[str] = None
|
|
|
|
# In-memory storage
# task id -> Task record; lives for the lifetime of the server process.
tasks: Dict[str, Task] = {}
# task id -> Event used to signal cooperative cancellation to the worker;
# entries are removed by process_video_task's cleanup.
cancel_events: Dict[str, threading.Event] = {}
|
|
|
|
class GenerateRequest(BaseModel):
    """Request payload for /generate: input clip range and detection params."""

    # Path to the input video file
    video_path: str
    # Directory where mask.mp4 will be written (created if missing)
    output_dir: str
    # First frame index to process (inclusive)
    start_frame: int
    # Last frame index to process (inclusive; clamped to the video length)
    end_frame: int
    # Detection confidence threshold
    conf_threshold: float = 0.5
    # NMS IoU threshold
    iou_threshold: float = 0.45
    # Scale factor applied to detected face boxes when drawing the mask
    mask_scale: float = 1.5
|
|
|
|
def process_video_task(task_id: str, req: GenerateRequest):
    """Background task: detect faces over a frame range and write a
    grayscale mask video (mask.mp4) into req.output_dir.

    Frames are batched for inference and handed to a background writer
    thread so disk I/O overlaps with inference.  Status/progress is
    reported through the module-level ``tasks`` dict; cancellation is
    observed via the task's Event in ``cancel_events``.

    Fixes over the previous version:
    - seeks to ``start_frame`` ONCE and then reads sequentially, instead
      of calling ``cap.set(CAP_PROP_POS_FRAMES, ...)`` before every read
      (a per-frame seek, very slow on long-GOP video);
    - releases the capture, writer, and writer thread in ``finally`` so
      they are not leaked when an exception is raised mid-processing.
    """
    cap = None
    writer = None
    write_queue = None
    writer_thread = None
    writer_running = None

    try:
        tasks[task_id].status = TaskStatus.PROCESSING
        cancel_event = cancel_events.get(task_id)

        # Verify video exists
        if not os.path.exists(req.video_path):
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = f"Video not found: {req.video_path}"
            return

        # Initialize detector (will load model on first run)
        print(f"Loading detector for task {task_id}...")
        detector = get_detector(
            conf_threshold=req.conf_threshold,
            iou_threshold=req.iou_threshold
        )
        # Touch the model to force the lazy load before the timed loop.
        _ = detector.model

        # Open video
        cap = cv2.VideoCapture(req.video_path)
        if not cap.isOpened():
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Failed to open video"
            return

        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_video_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        end_frame = min(req.end_frame, total_video_frames - 1)
        frames_to_process = end_frame - req.start_frame + 1

        tasks[task_id].total = frames_to_process

        # Ensure output directory exists
        os.makedirs(req.output_dir, exist_ok=True)

        # Setup MP4 writer (grayscale)
        output_video_path = os.path.join(req.output_dir, "mask.mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height), isColor=False)

        if not writer.isOpened():
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Failed to create video writer"
            return

        # Async writer setup
        write_queue = queue.Queue(maxsize=30)  # Buffer up to 30 frames
        writer_running = threading.Event()
        writer_running.set()

        def async_writer():
            """Background thread for writing frames to video."""
            # Keep draining until told to stop AND the queue is empty.
            while writer_running.is_set() or not write_queue.empty():
                try:
                    mask = write_queue.get(timeout=0.1)
                    if mask is not None:
                        writer.write(mask)
                    write_queue.task_done()
                except queue.Empty:
                    continue

        writer_thread = threading.Thread(target=async_writer, daemon=True)
        writer_thread.start()

        print(f"Starting processing: {req.video_path} ({frames_to_process} frames) -> {output_video_path}")

        # Batch processing configuration
        BATCH_SIZE = 5  # Optimal batch size for 4K video (72.9% improvement)
        frame_buffer = []

        def process_batch():
            """Run batch inference on the buffered frames and enqueue masks."""
            if not frame_buffer:
                return

            # Batch inference at full resolution
            batch_detections = detector.detect_batch(frame_buffer)

            # Generate masks for each frame
            for i, detections in enumerate(batch_detections):
                frame = frame_buffer[i]

                # Generate mask at original resolution
                mask = detector.generate_mask(
                    frame.shape,
                    detections,
                    mask_scale=req.mask_scale
                )

                # Async write to queue (blocks when the writer falls behind)
                write_queue.put(mask)

            frame_buffer.clear()

        # Seek ONCE, then read sequentially.  NOTE: if a mid-stream read
        # fails, we no longer re-seek per frame; in practice a failed read
        # means end-of-stream.
        cap.set(cv2.CAP_PROP_POS_FRAMES, req.start_frame)

        # Process loop with batching
        current_count = 0
        for frame_idx in range(req.start_frame, end_frame + 1):
            if cancel_event and cancel_event.is_set():
                tasks[task_id].status = TaskStatus.CANCELLED
                tasks[task_id].message = "Cancelled by user"
                break

            ret, frame = cap.read()

            if ret:
                # Store frame for batch processing
                frame_buffer.append(frame)

                # Process batch when full
                if len(frame_buffer) >= BATCH_SIZE:
                    process_batch()

                # Update progress
                current_count += 1
                tasks[task_id].progress = current_count

        # Process remaining frames in buffer
        if frame_buffer:
            process_batch()

        # Drain and stop the writer, then finalize the files BEFORE
        # marking the task completed, so a poller never sees COMPLETED
        # with an unfinalized mp4.
        writer_running.clear()
        write_queue.join()  # Wait for all frames to be written
        if writer_thread:
            writer_thread.join(timeout=5)

        cap.release()
        writer.release()
        cap = None
        writer = None

        if tasks[task_id].status == TaskStatus.PROCESSING:
            tasks[task_id].status = TaskStatus.COMPLETED
            tasks[task_id].result_path = output_video_path  # Return video path
            tasks[task_id].message = "Processing completed successfully"
            print(f"Task {task_id} completed: {output_video_path}")

    except Exception as e:
        tasks[task_id].status = TaskStatus.FAILED
        tasks[task_id].message = str(e)
        print(f"Error in task {task_id}: {e}")
        traceback.print_exc()
    finally:
        # Stop the writer thread even on the exception path so it does not
        # keep writing to a half-finished file.
        if writer_running is not None:
            writer_running.clear()
        if writer_thread is not None:
            writer_thread.join(timeout=5)
        # Release OS resources on every path (previously leaked on error).
        if cap is not None:
            cap.release()
        if writer is not None:
            writer.release()
        if task_id in cancel_events:
            del cancel_events[task_id]
|
|
|
|
def check_gpu_available() -> dict:
    """
    Probe (once) whether a GPU is usable for inference.

    The torch-based probe runs lazily on the first call and the result is
    memoised in the module-level ``_gpu_status_cache``, so repeated
    /status requests do not re-query the device.

    Returns:
        dict with keys:
            available (bool): torch reports a usable GPU
            device_name (str | None): name of device 0 when available
            device_count (int): number of visible devices
            rocm_version (str | None): HIP/ROCm version when present
    """
    global _gpu_status_cache

    # Serve the memoised probe result when we already have one.
    if _gpu_status_cache is not None:
        return _gpu_status_cache

    info = {
        "available": False,
        "device_name": None,
        "device_count": 0,
        "rocm_version": None,
    }

    try:
        import torch

        info["available"] = torch.cuda.is_available()

        if info["available"]:
            count = torch.cuda.device_count()
            info["device_count"] = count
            if count > 0:
                info["device_name"] = torch.cuda.get_device_name(0)

            # torch.version.hip is set on ROCm builds of PyTorch.
            if hasattr(torch.version, 'hip'):
                info["rocm_version"] = torch.version.hip

    except Exception as e:
        print(f"[FaceMask] Warning: GPU detection failed: {e}")
        info["available"] = False

    # Cache for subsequent calls.
    _gpu_status_cache = info
    return info
|
|
|
|
|
|
def log_startup_diagnostics():
    """Log diagnostic information about the environment and GPU.

    Printed once at server startup (see the __main__ guard).  Covers the
    Python runtime, ROCm-related environment variables, and a torch-based
    GPU probe.  Output goes to stdout only; nothing is returned and no
    state is modified.
    """
    print("=" * 70)
    print("[FaceMask Server] Startup Diagnostics")
    print("=" * 70)

    # Python Environment
    print("\n[Python Environment]")
    print(f" Python Version: {sys.version.split()[0]}")
    print(f" Python Executable: {sys.executable}")
    print(f" Platform: {platform.platform()}")
    print(f" Working Directory: {os.getcwd()}")

    # Check if in venv (prefix differs from base_prefix inside a venv)
    in_venv = sys.prefix != sys.base_prefix
    print(f" Virtual Environment: {'Yes' if in_venv else 'No'}")
    if in_venv:
        print(f" venv path: {sys.prefix}")

    # ROCm Environment Variables
    print("\n[ROCm Environment Variables]")
    rocm_vars = [
        'ROCM_PATH',
        'HSA_OVERRIDE_GFX_VERSION',
        'PYTORCH_ROCM_ARCH',
        'ROCBLAS_TENSILE_LIBPATH',
        'LD_LIBRARY_PATH'
    ]

    for var in rocm_vars:
        value = os.environ.get(var)
        if value:
            # For LD_LIBRARY_PATH, show if ROCm paths are included
            if var == 'LD_LIBRARY_PATH':
                has_rocm = 'rocm' in value.lower() or 'clr-' in value
                has_hip = 'hip' in value.lower()
                print(f" {var}: {value[:100]}...")
                print(f" Contains ROCm paths: {has_rocm}")
                print(f" Contains HIP paths: {has_hip}")
                if not has_rocm:
                    print(f" ⚠️ WARNING: ROCm library paths not found!")
            else:
                # Truncate very long values to keep the log readable
                if len(value) > 200:
                    display_value = value[:200] + "... (truncated)"
                else:
                    display_value = value
                print(f" {var}: {display_value}")
        else:
            print(f" {var}: (not set)")

    # GPU Detection
    print("\n[GPU Detection]")
    try:
        import torch

        cuda_available = torch.cuda.is_available()
        print(f" torch.cuda.is_available(): {cuda_available}")

        if cuda_available:
            device_count = torch.cuda.device_count()
            print(f" GPU Device Count: {device_count}")

            if device_count > 0:
                device_name = torch.cuda.get_device_name(0)
                print(f" GPU Device 0: {device_name}")

            # ROCm version
            if hasattr(torch.version, 'hip'):
                print(f" ROCm Version (HIP): {torch.version.hip}")

            # CUDA version (might be emulated by ROCm)
            if torch.version.cuda:
                print(f" CUDA Version: {torch.version.cuda}")
        else:
            print(" WARNING: GPU not detected!")
            print(" Server will use CPU for inference (slower)")
            print(" Troubleshooting:")
            print(" - Check ROCm environment variables above")
            print(" - Run: python -c 'import torch; print(torch.cuda.is_available())'")

    except ImportError as e:
        print(f" ERROR: Cannot import torch: {e}")
        print(" PyTorch must be installed for inference")
    except Exception as e:
        print(f" ERROR during GPU detection: {e}")

    print("=" * 70)
    print()
|
|
|
|
|
|
@app.get("/status")
|
|
def get_status():
|
|
gpu_info = check_gpu_available()
|
|
|
|
return {
|
|
"status": "running",
|
|
"gpu_available": gpu_info["available"],
|
|
"gpu_device": gpu_info["device_name"],
|
|
"gpu_count": gpu_info["device_count"],
|
|
"rocm_version": gpu_info["rocm_version"]
|
|
}
|
|
|
|
@app.post("/generate", response_model=Task)
|
|
def generate_mask_endpoint(req: GenerateRequest, background_tasks: BackgroundTasks):
|
|
task_id = str(uuid.uuid4())
|
|
task = Task(id=task_id, status=TaskStatus.PENDING)
|
|
tasks[task_id] = task
|
|
cancel_events[task_id] = threading.Event()
|
|
|
|
background_tasks.add_task(process_video_task, task_id, req)
|
|
return task
|
|
|
|
@app.get("/tasks/{task_id}", response_model=Task)
|
|
def get_task(task_id: str):
|
|
if task_id not in tasks:
|
|
raise HTTPException(status_code=404, detail="Task not found")
|
|
return tasks[task_id]
|
|
|
|
@app.post("/tasks/{task_id}/cancel")
|
|
def cancel_task(task_id: str):
|
|
if task_id not in tasks:
|
|
raise HTTPException(status_code=404, detail="Task not found")
|
|
|
|
if task_id in cancel_events:
|
|
cancel_events[task_id].set()
|
|
|
|
return {"message": "Cancellation requested"}
|
|
|
|
if __name__ == "__main__":
|
|
log_startup_diagnostics()
|
|
uvicorn.run(app, host="127.0.0.1", port=8181)
|