"""
Client for interacting with the external inference server.

Manages the server process and handles HTTP communication
using standard library (avoiding requests dependency).
"""

import json
import os
import signal
import subprocess
import sys
import tempfile
import threading
import time
import urllib.error
import urllib.request
from typing import Any, Dict, Optional


|
class InferenceClient:
    """Client for the YOLO inference server.

    Owns the lifecycle of the server subprocess (start/stop) and wraps its
    HTTP API using only the standard library's urllib, avoiding a
    ``requests`` dependency.
    """

    # Server binds to localhost only; the port is fixed by server/main.py.
    SERVER_URL = "http://127.0.0.1:8181"

    def __init__(self):
        # Handle to the spawned server process; None when not running.
        self.server_process: Optional[subprocess.Popen] = None
        # Serializes start/stop so concurrent callers don't race.
        self._server_lock = threading.Lock()
        # Open file object that receives the server's stdout/stderr.
        self.log_file = None
        # Filesystem path of that log file (used for error reporting).
        self.log_file_path = None

    # ------------------------------------------------------------------
    # Process lifecycle
    # ------------------------------------------------------------------

    def start_server(self):
        """Start the inference server process and wait until it responds.

        Raises:
            RuntimeError: if the process exits during startup or does not
                answer ``/status`` within ~10 seconds.
        """
        with self._server_lock:
            if self.is_server_running():
                return

            print("[FaceMask] Starting inference server...")

            # Project root, assuming this file is core/inference_client.py.
            root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            server_script = os.path.join(root_dir, "server", "main.py")

            server_env = self._build_server_env(root_dir)
            self._open_log_file()

            self.server_process = subprocess.Popen(
                [self._select_python(root_dir), "-u", server_script],  # -u: unbuffered
                cwd=root_dir,
                text=True,
                env=server_env,
                stdout=self.log_file,  # all server output goes to the log file
                stderr=subprocess.STDOUT,  # merge stderr into stdout
                preexec_fn=os.setsid,  # new process group so we can killpg it later
            )

            # Poll /status for up to 10 seconds (20 x 0.5 s).
            for _ in range(20):
                if self.is_server_running():
                    print("[FaceMask] Server started successfully")
                    return

                # Bail out early if the process already died.
                if self.server_process.poll() is not None:
                    error_msg = (
                        "Server failed to start "
                        f"(exit code: {self.server_process.returncode})"
                    )
                    print(f"[FaceMask] ERROR: {error_msg}")
                    self._dump_server_log("Server log:", tail=50)
                    self.server_process = None
                    raise RuntimeError(error_msg)

                time.sleep(0.5)

            # Startup timed out: report, then terminate the half-started
            # process instead of leaking it.
            print("[FaceMask] Server startup timed out")
            self._dump_server_log("Server log (partial):", tail=30)
            self._kill_process_group()
            raise RuntimeError("Server startup timed out")

    def stop_server(self):
        """Stop the inference server and close its log file."""
        with self._server_lock:
            if self.server_process:
                print("[FaceMask] Stopping server...")
                self._kill_process_group()

            if self.log_file:
                try:
                    self.log_file.close()
                except Exception:
                    pass  # best-effort: the file may already be closed
                self.log_file = None

    def is_server_running(self) -> bool:
        """Return True if the server answers ``/status`` with HTTP 200."""
        try:
            with urllib.request.urlopen(
                f"{self.SERVER_URL}/status", timeout=1
            ) as response:
                return response.status == 200
        except (urllib.error.URLError, ConnectionRefusedError, TimeoutError):
            return False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _build_server_env(self, root_dir: str) -> Dict[str, str]:
        """Build a sanitized environment for the server process.

        Loads ``.env`` overrides, pins PYTHONPATH to the project root, and
        prefers a project-local virtualenv on PATH while filtering out Nix
        Python 3.11 entries that cause interpreter/package conflicts.
        """
        server_env = os.environ.copy()

        # Load KEY=VALUE pairs from a project-level .env file, if present.
        env_file = os.path.join(root_dir, ".env")
        if os.path.exists(env_file):
            with open(env_file, "r") as f:
                for line in f:
                    line = line.strip()
                    if line and not line.startswith("#") and "=" in line:
                        key, value = line.split("=", 1)
                        server_env[key] = value
            print(f"[FaceMask] Loaded environment from: {env_file}")

        # Clean PYTHONPATH to avoid conflicts with Nix Python packages;
        # only the project root is needed for local imports.
        server_env["PYTHONPATH"] = root_dir

        # Drop variables that can make a venv import packages from Nix
        # instead of its own site-packages.
        for var in ("PYTHONUNBUFFERED", "__PYVENV_LAUNCHER__", "VIRTUAL_ENV"):
            server_env.pop(var, None)

        # If the project has a venv, put its bin dir first on PATH and
        # filter out Nix Python 3.11 paths to avoid version conflicts.
        venv_bin = os.path.join(root_dir, ".venv", "bin")
        if os.path.isdir(venv_bin):
            entries = server_env.get("PATH", "").split(":")
            kept = [
                p
                for p in entries
                if not ("/python3.11/" in p.lower() or "/python3-3.11" in p.lower())
            ]
            server_env["PATH"] = ":".join([venv_bin] + kept)
            print(f"[FaceMask] Using venv from: {venv_bin}")

        return server_env

    def _select_python(self, root_dir: str) -> str:
        """Return the venv's python when present, else this interpreter."""
        venv_python = os.path.join(root_dir, ".venv", "bin", "python")
        return venv_python if os.path.isfile(venv_python) else sys.executable

    def _open_log_file(self):
        """Open a line-buffered log file in the temp dir for server output."""
        self.log_file_path = os.path.join(
            tempfile.gettempdir(), "facemask_server.log"
        )
        self.log_file = open(self.log_file_path, "w", buffering=1)  # line buffered
        print(f"[FaceMask] Server log: {self.log_file_path}")

    def _dump_server_log(self, header: str, tail: int):
        """Close the log file and print its last ``tail`` lines."""
        try:
            if self.log_file:
                self.log_file.close()
                self.log_file = None
            with open(self.log_file_path, "r") as f:
                log_content = f.read()
            if log_content.strip():
                print(f"[FaceMask] {header}")
                for line in log_content.strip().split("\n")[-tail:]:
                    print(line)
        except Exception as e:
            print(f"[FaceMask] Could not read log file: {e}")

    def _kill_process_group(self):
        """SIGTERM the server's process group; escalate to SIGKILL if needed."""
        if self.server_process is None:
            return
        try:
            os.killpg(os.getpgid(self.server_process.pid), signal.SIGTERM)
            self.server_process.wait(timeout=3)
        except ProcessLookupError:
            pass  # process group already gone
        except subprocess.TimeoutExpired:
            # Did not exit in time: force-kill the whole group.
            try:
                os.killpg(os.getpgid(self.server_process.pid), signal.SIGKILL)
            except ProcessLookupError:
                pass
        finally:
            self.server_process = None

    def _post_json(self, endpoint: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` as JSON to ``endpoint`` and decode the reply.

        Raises:
            RuntimeError: if the server responds with an HTTP error.
        """
        req = urllib.request.Request(
            f"{self.SERVER_URL}{endpoint}",
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        try:
            with urllib.request.urlopen(req) as response:
                return json.loads(response.read().decode("utf-8"))
        except urllib.error.HTTPError as e:
            raise RuntimeError(f"Server error: {e.read().decode('utf-8')}")

    # ------------------------------------------------------------------
    # HTTP API
    # ------------------------------------------------------------------

    def generate_mask(
        self,
        video_path: str,
        output_dir: str,
        start_frame: int,
        end_frame: int,
        conf_threshold: float,
        iou_threshold: float,
    ) -> str:
        """Request mask generation for a frame range of a video.

        Returns:
            task_id (str) identifying the asynchronous server task.

        Raises:
            RuntimeError: on server-side HTTP errors.
        """
        if not self.is_server_running():
            self.start_server()

        result = self._post_json(
            "/generate",
            {
                "video_path": video_path,
                "output_dir": output_dir,
                "start_frame": start_frame,
                "end_frame": end_frame,
                "conf_threshold": conf_threshold,
                "iou_threshold": iou_threshold,
            },
        )
        return result["id"]

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Get status of a task; ``{"status": "unknown"}`` on HTTP error."""
        try:
            with urllib.request.urlopen(
                f"{self.SERVER_URL}/tasks/{task_id}"
            ) as response:
                return json.loads(response.read().decode("utf-8"))
        except urllib.error.HTTPError:
            return {"status": "unknown"}

    def get_video_info(self, video_path: str) -> Dict[str, Any]:
        """Get video metadata from the inference server.

        Raises:
            RuntimeError: on server-side HTTP errors.
        """
        if not self.is_server_running():
            self.start_server()
        return self._post_json("/video_info", {"video_path": video_path})

    def bake_blur(
        self,
        video_path: str,
        detections_path: str,
        output_path: str,
        blur_size: int,
        display_scale: float,
        fmt: str,
    ) -> str:
        """Request blur bake for a source video + mask video.

        Returns:
            task_id (str) identifying the asynchronous server task.

        Raises:
            RuntimeError: on server-side HTTP errors.
        """
        if not self.is_server_running():
            self.start_server()

        result = self._post_json(
            "/bake_blur",
            {
                "video_path": video_path,
                "detections_path": detections_path,
                "output_path": output_path,
                "blur_size": blur_size,
                "display_scale": display_scale,
                "format": fmt,
            },
        )
        return result["id"]

    def cancel_task(self, task_id: str):
        """Cancel a task; HTTP errors are ignored (best-effort)."""
        try:
            req = urllib.request.Request(
                f"{self.SERVER_URL}/tasks/{task_id}/cancel", method="POST"
            )
            with urllib.request.urlopen(req):
                pass
        except urllib.error.HTTPError:
            pass
|
|
|
|
|
|
# Module-level singleton: one InferenceClient shared by the whole process.
_client: Optional[InferenceClient] = None


def get_client() -> InferenceClient:
    """Return the process-wide InferenceClient, creating it on first use."""
    global _client
    _client = _client or InferenceClient()
    return _client
|