679 lines
22 KiB
Python
679 lines
22 KiB
Python
"""
|
|
Face Detection Inference Server.
|
|
|
|
This FastAPI application runs in a separate process to handle
|
|
GPU-accelerated face detection using ONNX Runtime.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import platform
|
|
|
|
# CRITICAL: Fix LD_LIBRARY_PATH before importing cv2 or torch
|
|
# cv2 adds its own lib path to the front, which can override ROCm libraries
|
|
def fix_library_path():
    """Ensure ROCm libraries are loaded before cv2's bundled libraries.

    cv2 prepends its own bundled lib directory to LD_LIBRARY_PATH, which
    can shadow the ROCm runtime.  Reorder the variable so every ROCm/HIP
    entry comes first; leave it untouched when no such entry exists.
    """
    entries = [entry for entry in os.environ.get('LD_LIBRARY_PATH', '').split(':') if entry]

    def _is_gpu_entry(entry: str) -> bool:
        # Heuristic match for ROCm installs, HIP runtimes, and clr- builds.
        lowered = entry.lower()
        return 'rocm' in lowered or 'hip' in lowered or 'clr-' in entry

    gpu_entries = [entry for entry in entries if _is_gpu_entry(entry)]
    if not gpu_entries:
        return  # nothing to prioritize; keep the environment as-is

    remaining = [entry for entry in entries if not _is_gpu_entry(entry)]
    os.environ['LD_LIBRARY_PATH'] = ':'.join(gpu_entries + remaining)
    print("[FaceMask] Fixed LD_LIBRARY_PATH to prioritize ROCm libraries")


# Fix library path BEFORE any other imports
fix_library_path()
|
|
|
|
import threading
|
|
import uuid
|
|
import queue
|
|
import traceback
|
|
from typing import Dict, Optional, List
|
|
from pathlib import Path
|
|
|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
|
from pydantic import BaseModel
|
|
import uvicorn
|
|
import cv2
|
|
import numpy as np
|
|
|
|
# Add project root to path for imports if needed
|
|
sys.path.append(str(Path(__file__).parent.parent))
|
|
|
|
from server.detector import YOLOFaceDetector, get_detector
|
|
|
|
app = FastAPI(title="Face Mask Inference Server")

# GPU status cache
# Populated lazily by check_gpu_available(); None means "not probed yet".
# Cached for the process lifetime because GPU availability cannot change
# after the runtime has initialized.
_gpu_status_cache = None
|
|
|
|
# Task storage
|
|
class TaskStatus:
    """String constants for the lifecycle states of a background task."""
    PENDING = "pending"        # created, worker not yet started
    PROCESSING = "processing"  # worker currently running
    COMPLETED = "completed"    # finished successfully; result_path set
    FAILED = "failed"          # aborted; message holds the error text
    CANCELLED = "cancelled"    # stopped via the /tasks/{id}/cancel endpoint
|
|
|
|
class Task(BaseModel):
    """Public status record for a background job, returned by the API."""
    # id: UUID4 string assigned when the task is created.
    id: str
    # status: one of the TaskStatus constants.
    status: str
    # progress: frames processed so far.
    progress: int = 0
    # total: frames scheduled for processing.
    total: int = 0
    # message: human-readable status or error detail.
    message: Optional[str] = None
    # result_path: path of the produced file once the task completes.
    result_path: Optional[str] = None
|
|
|
|
# In-memory storage
# tasks: task-id -> Task record.  No persistence; lost on process restart.
tasks: Dict[str, Task] = {}
# cancel_events: task-id -> Event set by the cancel endpoint and polled
# cooperatively by the worker loops; entries are removed when a worker ends.
cancel_events: Dict[str, threading.Event] = {}
|
|
|
|
class GenerateRequest(BaseModel):
    """Parameters for /generate: produce a grayscale face-mask video."""
    # video_path: path of the source video to analyze.
    video_path: str
    # output_dir: directory where mask.mp4 will be written.
    output_dir: str
    # start_frame / end_frame: inclusive frame range to process.
    start_frame: int
    end_frame: int
    # conf_threshold: detector confidence cutoff.
    conf_threshold: float = 0.5
    # iou_threshold: detector NMS IoU cutoff.
    iou_threshold: float = 0.45
    # mask_scale: enlargement factor applied to each detected face region.
    mask_scale: float = 1.5
|
|
|
|
|
|
class BakeRequest(BaseModel):
    """Parameters for /bake_blur: composite blur into a source video."""
    # video_path: path of the original (sharp) video.
    video_path: str
    # mask_path: path of the mask video produced by /generate.
    mask_path: str
    # output_path: destination file for the blurred composite.
    output_path: str
    # blur_size: Gaussian kernel size; normalized to an odd value >= 1.
    blur_size: int = 50
    # format: container hint used to pick an output codec (mp4/mov/avi).
    format: str = "mp4"
|
|
|
|
|
|
def _build_video_writer(
    output_path: str,
    fmt: str,
    fps: float,
    width: int,
    height: int,
) -> cv2.VideoWriter:
    """Create VideoWriter with codec fallback per format.

    Tries each codec known to work for the requested container in order,
    returning the first writer that actually opens.  Raises RuntimeError
    when every candidate fails.
    """
    fallback_table = {
        "mp4": ["avc1", "mp4v"],
        "mov": ["avc1", "mp4v"],
        "avi": ["MJPG", "XVID"],
    }
    candidates = fallback_table.get(fmt.lower(), ["mp4v"])

    for fourcc_name in candidates:
        candidate = cv2.VideoWriter(
            output_path,
            cv2.VideoWriter_fourcc(*fourcc_name),
            fps,
            (width, height),
            isColor=True,
        )
        if candidate.isOpened():
            print(f"[FaceMask] Using output codec: {fourcc_name}")
            return candidate
        # This codec is unavailable on this build of OpenCV; clean up
        # the failed writer before trying the next candidate.
        candidate.release()

    raise RuntimeError(f"Failed to create video writer for format='{fmt}'")
|
|
|
|
def process_video_task(task_id: str, req: GenerateRequest):
    """Background task to process video with async MP4 output.

    Runs batched face detection over frames [req.start_frame, end_frame]
    of req.video_path, turns the detections into grayscale masks, applies
    a 3-frame temporal max-blend, and streams the result to
    <req.output_dir>/mask.mp4 through a dedicated writer thread.
    Progress/status are reported via the module-level ``tasks`` registry;
    cancellation is polled from ``cancel_events``.
    """
    writer = None
    write_queue = None
    writer_thread = None

    try:
        tasks[task_id].status = TaskStatus.PROCESSING
        cancel_event = cancel_events.get(task_id)

        # Verify video exists
        if not os.path.exists(req.video_path):
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = f"Video not found: {req.video_path}"
            return

        # Initialize detector (will load model on first run)
        print(f"Loading detector for task {task_id}...")
        detector = get_detector(
            conf_threshold=req.conf_threshold,
            iou_threshold=req.iou_threshold
        )
        # Touch the lazy model attribute so load errors surface here,
        # before any video I/O starts.
        _ = detector.model

        # Open video
        cap = cv2.VideoCapture(req.video_path)
        if not cap.isOpened():
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Failed to open video"
            return

        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_video_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Clamp the requested range to what the container actually holds.
        end_frame = min(req.end_frame, total_video_frames - 1)
        frames_to_process = end_frame - req.start_frame + 1

        tasks[task_id].total = frames_to_process

        # Ensure output directory exists
        os.makedirs(req.output_dir, exist_ok=True)

        # Setup MP4 writer (grayscale)
        output_video_path = os.path.join(req.output_dir, "mask.mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height), isColor=False)

        if not writer.isOpened():
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Failed to create video writer"
            cap.release()
            return

        # Async writer setup
        # Bounded queue applies back-pressure: detection blocks when the
        # encoder falls more than 30 frames behind.
        write_queue = queue.Queue(maxsize=30)  # Buffer up to 30 frames
        writer_running = threading.Event()
        writer_running.set()

        def async_writer():
            """Background thread for writing frames to video."""
            # Keep draining after the stop flag clears so queued frames
            # are not lost at shutdown.
            while writer_running.is_set() or not write_queue.empty():
                try:
                    mask = write_queue.get(timeout=0.1)
                    if mask is not None:
                        writer.write(mask)
                    write_queue.task_done()
                except queue.Empty:
                    continue

        writer_thread = threading.Thread(target=async_writer, daemon=True)
        writer_thread.start()

        print(f"Starting processing: {req.video_path} ({frames_to_process} frames) -> {output_video_path}")

        # Batch processing configuration
        BATCH_SIZE = 5  # Optimal batch size for 4K video (72.9% improvement)
        frame_buffer = []
        # Weights for the temporal max-blend: neighbors contribute at 0.7,
        # the frame itself at full strength.
        TEMPORAL_SIDE_WEIGHT = 0.7
        TEMPORAL_CENTER_WEIGHT = 1.0

        # Temporal blending state (streaming, low-memory)
        # Only a 2-frame window (prev, curr) is held; output lags input by
        # one frame so each emission can see its successor.
        prev_mask = None
        curr_mask = None
        wrote_first_frame = False

        def _scale_mask(mask: np.ndarray, weight: float) -> np.ndarray:
            """Scale mask intensity for temporal blending."""
            if weight == 1.0:
                return mask
            return cv2.convertScaleAbs(mask, alpha=weight, beta=0)

        def _blend_edge(base: np.ndarray, neighbor: np.ndarray) -> np.ndarray:
            """Blend for first/last frame (one-sided temporal context)."""
            base_w = _scale_mask(base, TEMPORAL_CENTER_WEIGHT)
            neighbor_w = _scale_mask(neighbor, TEMPORAL_SIDE_WEIGHT)
            return cv2.max(base_w, neighbor_w)

        def _blend_middle(prev: np.ndarray, cur: np.ndarray, nxt: np.ndarray) -> np.ndarray:
            """Blend for middle frames (previous/current/next temporal context)."""
            prev_w = _scale_mask(prev, TEMPORAL_SIDE_WEIGHT)
            cur_w = _scale_mask(cur, TEMPORAL_CENTER_WEIGHT)
            nxt_w = _scale_mask(nxt, TEMPORAL_SIDE_WEIGHT)
            return cv2.max(cur_w, cv2.max(prev_w, nxt_w))

        def push_mask_temporal(raw_mask: np.ndarray):
            """Push mask and emit blended output in frame order."""
            nonlocal prev_mask, curr_mask, wrote_first_frame

            # Prime the 2-slot window before anything can be emitted.
            if prev_mask is None:
                prev_mask = raw_mask
                return

            if curr_mask is None:
                curr_mask = raw_mask
                return

            # Frame 0 has no predecessor: emit it once with a one-sided blend.
            if not wrote_first_frame:
                write_queue.put(_blend_edge(prev_mask, curr_mask))
                wrote_first_frame = True

            # Emit blended current frame using prev/current/next
            write_queue.put(_blend_middle(prev_mask, curr_mask, raw_mask))

            # Slide temporal window
            prev_mask = curr_mask
            curr_mask = raw_mask

        def flush_temporal_tail():
            """Flush remaining masks after all frames are processed."""
            if prev_mask is None:
                return

            # Single-frame case
            if curr_mask is None:
                write_queue.put(_scale_mask(prev_mask, TEMPORAL_CENTER_WEIGHT))
                return

            # Two-frame case
            if not wrote_first_frame:
                write_queue.put(_blend_edge(prev_mask, curr_mask))

            # Always emit last frame with one-sided blend
            write_queue.put(_blend_edge(curr_mask, prev_mask))

        def process_batch():
            """Process accumulated batch of frames."""
            if not frame_buffer:
                return

            # Batch inference at full resolution
            batch_detections = detector.detect_batch(frame_buffer)

            # Generate masks for each frame
            for i, detections in enumerate(batch_detections):
                frame = frame_buffer[i]

                # Generate mask at original resolution
                mask = detector.generate_mask(
                    frame.shape,
                    detections,
                    mask_scale=req.mask_scale
                )

                # Temporal blend before async write
                push_mask_temporal(mask)

            # Clear buffer
            frame_buffer.clear()

        # Seek once to the starting frame. Avoid random-access seek on every frame.
        if req.start_frame > 0:
            seek_ok = cap.set(cv2.CAP_PROP_POS_FRAMES, req.start_frame)
            if not seek_ok:
                # Some backends cannot seek; decode-and-discard instead.
                print(
                    f"[FaceMask] Warning: CAP_PROP_POS_FRAMES seek failed, "
                    f"fallback to sequential skip ({req.start_frame} frames)"
                )
                for _ in range(req.start_frame):
                    ret, _ = cap.read()
                    if not ret:
                        tasks[task_id].status = TaskStatus.FAILED
                        tasks[task_id].message = (
                            f"Failed to seek to start frame: {req.start_frame}"
                        )
                        return

        # Process loop with batching
        current_count = 0
        for frame_idx in range(req.start_frame, end_frame + 1):
            # Cooperative cancellation: checked once per frame.
            if cancel_event and cancel_event.is_set():
                tasks[task_id].status = TaskStatus.CANCELLED
                tasks[task_id].message = "Cancelled by user"
                break

            # Read next frame sequentially (after one-time initial seek)
            ret, frame = cap.read()

            # NOTE(review): a failed read is silently skipped and the loop
            # keeps iterating to end_frame; progress then under-counts.
            # Consider breaking on the first failed read — confirm intent.
            if ret:
                # Store frame for batch processing
                frame_buffer.append(frame)

                # Process batch when full
                if len(frame_buffer) >= BATCH_SIZE:
                    process_batch()

                # Update progress
                current_count += 1
                tasks[task_id].progress = current_count

        # Process remaining frames in buffer
        if frame_buffer:
            process_batch()
        flush_temporal_tail()

        # Cleanup
        writer_running.clear()
        write_queue.join()  # Wait for all frames to be written
        if writer_thread:
            writer_thread.join(timeout=5)

        cap.release()
        if writer:
            writer.release()

        # Only mark COMPLETED if nothing (cancel/failure) changed the
        # status during the loop.
        if tasks[task_id].status == TaskStatus.PROCESSING:
            tasks[task_id].status = TaskStatus.COMPLETED
            tasks[task_id].result_path = output_video_path  # Return video path
            tasks[task_id].message = "Processing completed successfully"
            print(f"Task {task_id} completed: {output_video_path}")

    except Exception as e:
        # NOTE(review): on an exception `cap`/`writer` and the writer
        # thread are not released here — only the happy path cleans them
        # up. Consider moving release logic into `finally`.
        tasks[task_id].status = TaskStatus.FAILED
        tasks[task_id].message = str(e)
        print(f"Error in task {task_id}: {e}")
        traceback.print_exc()
    finally:
        # Cleanup
        if task_id in cancel_events:
            del cancel_events[task_id]
|
|
|
|
|
|
def process_bake_task(task_id: str, req: BakeRequest):
    """Background task to bake blur into a regular video file.

    Reads the source and mask videos in lockstep, Gaussian-blurs every
    source frame, and alpha-composites blurred over sharp using the mask's
    grayscale intensity as the per-pixel weight.  Writes the composite to
    req.output_path with a codec chosen by _build_video_writer.
    """
    src_cap = None
    mask_cap = None
    writer = None

    try:
        tasks[task_id].status = TaskStatus.PROCESSING
        cancel_event = cancel_events.get(task_id)

        if not os.path.exists(req.video_path):
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = f"Video not found: {req.video_path}"
            return

        if not os.path.exists(req.mask_path):
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = f"Mask video not found: {req.mask_path}"
            return

        src_cap = cv2.VideoCapture(req.video_path)
        mask_cap = cv2.VideoCapture(req.mask_path)

        if not src_cap.isOpened():
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Failed to open source video"
            return
        if not mask_cap.isOpened():
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Failed to open mask video"
            return

        # Some containers report 0 fps; fall back to 30.
        src_fps = src_cap.get(cv2.CAP_PROP_FPS) or 30.0
        src_width = int(src_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        src_height = int(src_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        src_frames = int(src_cap.get(cv2.CAP_PROP_FRAME_COUNT))
        mask_frames = int(mask_cap.get(cv2.CAP_PROP_FRAME_COUNT))

        if src_width <= 0 or src_height <= 0:
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Invalid source video dimensions"
            return

        # Process only the overlap of the two streams.
        total = min(src_frames, mask_frames) if src_frames > 0 and mask_frames > 0 else 0
        if total <= 0:
            tasks[task_id].status = TaskStatus.FAILED
            tasks[task_id].message = "Source/mask frame count is zero"
            return
        tasks[task_id].total = total

        output_dir = os.path.dirname(req.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        writer = _build_video_writer(req.output_path, req.format, src_fps, src_width, src_height)

        # Kernel size must be odd and >= 1
        blur_size = max(1, int(req.blur_size))
        if blur_size % 2 == 0:
            blur_size += 1

        print(f"[FaceMask] Starting blur bake: {req.video_path} + {req.mask_path} -> {req.output_path}")
        if src_frames != mask_frames:
            print(
                f"[FaceMask] Warning: frame count mismatch "
                f"(src={src_frames}, mask={mask_frames}), processing {total} frames"
            )

        for idx in range(total):
            # Cooperative cancellation: checked once per frame.
            if cancel_event and cancel_event.is_set():
                tasks[task_id].status = TaskStatus.CANCELLED
                tasks[task_id].message = "Cancelled by user"
                break

            src_ok, src_frame = src_cap.read()
            mask_ok, mask_frame = mask_cap.read()
            # Stop early if either stream runs out before `total` frames.
            if not src_ok or not mask_ok:
                break

            # Mask may arrive as BGR (decoder-dependent); collapse to gray.
            if mask_frame.ndim == 3:
                mask_gray = cv2.cvtColor(mask_frame, cv2.COLOR_BGR2GRAY)
            else:
                mask_gray = mask_frame

            # Resize mask to match the source frame when dimensions differ.
            if mask_gray.shape[0] != src_height or mask_gray.shape[1] != src_width:
                mask_gray = cv2.resize(
                    mask_gray,
                    (src_width, src_height),
                    interpolation=cv2.INTER_LINEAR,
                )

            # Per-pixel linear blend: alpha=1 -> fully blurred, 0 -> sharp.
            blurred = cv2.GaussianBlur(src_frame, (blur_size, blur_size), 0)
            alpha = (mask_gray.astype(np.float32) / 255.0)[..., np.newaxis]
            composed = (src_frame.astype(np.float32) * (1.0 - alpha)) + (
                blurred.astype(np.float32) * alpha
            )
            writer.write(np.clip(composed, 0, 255).astype(np.uint8))

            tasks[task_id].progress = idx + 1

        # Only mark COMPLETED if cancel/failure didn't change the status.
        if tasks[task_id].status == TaskStatus.PROCESSING:
            tasks[task_id].status = TaskStatus.COMPLETED
            tasks[task_id].result_path = req.output_path
            tasks[task_id].message = "Blur bake completed"
            print(f"[FaceMask] Bake completed: {req.output_path}")

    except Exception as e:
        tasks[task_id].status = TaskStatus.FAILED
        tasks[task_id].message = str(e)
        print(f"Error in bake task {task_id}: {e}")
        traceback.print_exc()
    finally:
        # Release all handles regardless of outcome.
        if src_cap:
            src_cap.release()
        if mask_cap:
            mask_cap.release()
        if writer:
            writer.release()
        if task_id in cancel_events:
            del cancel_events[task_id]
|
|
|
|
def check_gpu_available() -> dict:
    """
    Check if GPU is available for inference.

    Returns a dict with GPU information:
    {
        "available": bool,
        "device_name": str or None,
        "device_count": int,
        "rocm_version": str or None
    }

    The first probe imports torch and queries the CUDA/ROCm runtime; the
    result is memoized in the module-level cache for subsequent calls.
    """
    global _gpu_status_cache

    # Return cached result if available
    if _gpu_status_cache is not None:
        return _gpu_status_cache

    info = {
        "available": False,
        "device_name": None,
        "device_count": 0,
        "rocm_version": None,
    }

    try:
        import torch

        info["available"] = torch.cuda.is_available()

        if info["available"]:
            info["device_count"] = torch.cuda.device_count()
            if info["device_count"] > 0:
                info["device_name"] = torch.cuda.get_device_name(0)

            # On ROCm builds torch.version.hip carries the HIP version.
            if hasattr(torch.version, 'hip'):
                info["rocm_version"] = torch.version.hip

    except Exception as e:
        # Any probe failure (missing torch, broken runtime) means "no GPU".
        print(f"[FaceMask] Warning: GPU detection failed: {e}")
        info["available"] = False

    # Cache the result
    _gpu_status_cache = info
    return info
|
|
|
|
|
|
def log_startup_diagnostics():
    """Log diagnostic information about the environment and GPU.

    Prints a banner with Python/venv details, the ROCm-relevant environment
    variables, and the result of a torch GPU probe.  Output goes to stdout
    only; nothing is returned and no state is modified.
    """
    print("=" * 70)
    print("[FaceMask Server] Startup Diagnostics")
    print("=" * 70)

    # Python Environment
    print("\n[Python Environment]")
    print(f"  Python Version: {sys.version.split()[0]}")
    print(f"  Python Executable: {sys.executable}")
    print(f"  Platform: {platform.platform()}")
    print(f"  Working Directory: {os.getcwd()}")

    # Check if in venv
    # sys.prefix differs from sys.base_prefix inside a virtual environment.
    in_venv = sys.prefix != sys.base_prefix
    print(f"  Virtual Environment: {'Yes' if in_venv else 'No'}")
    if in_venv:
        print(f"  venv path: {sys.prefix}")

    # ROCm Environment Variables
    print("\n[ROCm Environment Variables]")
    rocm_vars = [
        'ROCM_PATH',
        'HSA_OVERRIDE_GFX_VERSION',
        'PYTORCH_ROCM_ARCH',
        'ROCBLAS_TENSILE_LIBPATH',
        'LD_LIBRARY_PATH'
    ]

    for var in rocm_vars:
        value = os.environ.get(var)
        if value:
            # For LD_LIBRARY_PATH, show if ROCm paths are included
            if var == 'LD_LIBRARY_PATH':
                # Same heuristics as fix_library_path() at module import.
                has_rocm = 'rocm' in value.lower() or 'clr-' in value
                has_hip = 'hip' in value.lower()
                print(f"  {var}: {value[:100]}...")
                print(f"    Contains ROCm paths: {has_rocm}")
                print(f"    Contains HIP paths: {has_hip}")
                if not has_rocm:
                    print(f"    ⚠️ WARNING: ROCm library paths not found!")
            else:
                # Truncate very long values so the banner stays readable.
                if len(value) > 200:
                    display_value = value[:200] + "... (truncated)"
                else:
                    display_value = value
                print(f"  {var}: {display_value}")
        else:
            print(f"  {var}: (not set)")

    # GPU Detection
    print("\n[GPU Detection]")
    try:
        import torch

        cuda_available = torch.cuda.is_available()
        print(f"  torch.cuda.is_available(): {cuda_available}")

        if cuda_available:
            device_count = torch.cuda.device_count()
            print(f"  GPU Device Count: {device_count}")

            if device_count > 0:
                device_name = torch.cuda.get_device_name(0)
                print(f"  GPU Device 0: {device_name}")

            # ROCm version
            if hasattr(torch.version, 'hip'):
                print(f"  ROCm Version (HIP): {torch.version.hip}")

            # CUDA version (might be emulated by ROCm)
            if torch.version.cuda:
                print(f"  CUDA Version: {torch.version.cuda}")
        else:
            print("  WARNING: GPU not detected!")
            print("  Server will use CPU for inference (slower)")
            print("  Troubleshooting:")
            print("    - Check ROCm environment variables above")
            print("    - Run: python -c 'import torch; print(torch.cuda.is_available())'")

    except ImportError as e:
        print(f"  ERROR: Cannot import torch: {e}")
        print("  PyTorch must be installed for inference")
    except Exception as e:
        print(f"  ERROR during GPU detection: {e}")

    print("=" * 70)
    print()
|
|
|
|
|
|
@app.get("/status")
def get_status():
    """Liveness probe: reports server state and cached GPU information."""
    gpu = check_gpu_available()

    payload = {"status": "running"}
    payload["gpu_available"] = gpu["available"]
    payload["gpu_device"] = gpu["device_name"]
    payload["gpu_count"] = gpu["device_count"]
    payload["rocm_version"] = gpu["rocm_version"]
    return payload
|
|
|
|
@app.post("/generate", response_model=Task)
def generate_mask_endpoint(req: GenerateRequest, background_tasks: BackgroundTasks):
    """Register a mask-generation job and launch it in the background."""
    new_id = str(uuid.uuid4())
    new_task = Task(id=new_id, status=TaskStatus.PENDING)

    # Register record and cancellation handle before the worker starts.
    tasks[new_id] = new_task
    cancel_events[new_id] = threading.Event()

    background_tasks.add_task(process_video_task, new_id, req)
    return new_task
|
|
|
|
|
|
@app.post("/bake_blur", response_model=Task)
def bake_blur_endpoint(req: BakeRequest, background_tasks: BackgroundTasks):
    """Register a blur-bake job and launch it in the background."""
    new_id = str(uuid.uuid4())
    new_task = Task(id=new_id, status=TaskStatus.PENDING)

    # Register record and cancellation handle before the worker starts.
    tasks[new_id] = new_task
    cancel_events[new_id] = threading.Event()

    background_tasks.add_task(process_bake_task, new_id, req)
    return new_task
|
|
|
|
@app.get("/tasks/{task_id}", response_model=Task)
def get_task(task_id: str):
    """Return the current record for a task; 404 when the id is unknown."""
    try:
        return tasks[task_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Task not found")
|
|
|
|
@app.post("/tasks/{task_id}/cancel")
def cancel_task(task_id: str):
    """Request cooperative cancellation of a task; workers poll the event."""
    if task_id not in tasks:
        raise HTTPException(status_code=404, detail="Task not found")

    # The event may already be gone if the worker has finished.
    event = cancel_events.get(task_id)
    if event is not None:
        event.set()

    return {"message": "Cancellation requested"}
|
|
|
|
# Entry point: print environment/GPU diagnostics, then start serving.
if __name__ == "__main__":
    log_startup_diagnostics()
    # Loopback-only bind: this server is a local helper process, not
    # intended to be reachable from the network.
    uvicorn.run(app, host="127.0.0.1", port=8181)
|