blender-mask-peoples/core/inference_client.py

312 lines
11 KiB
Python

"""
Client for interacting with the external inference server.
Manages the server process and handles HTTP communication
using standard library (avoiding requests dependency).
"""
import http.client
import json
import os
import signal
import subprocess
import sys
import tempfile
import threading
import time
import urllib.error
import urllib.request
from typing import Any, Dict, Optional


class InferenceClient:
    """Client for the YOLO inference server.

    Manages a locally spawned server process (``<root>/server/main.py``) and
    talks to it over HTTP using only the standard library.
    """

    SERVER_URL = "http://127.0.0.1:8181"
    # Max seconds to wait for a freshly spawned server to answer /status.
    STARTUP_TIMEOUT = 10.0
    # Delay between /status probes while waiting for startup.
    STARTUP_POLL_INTERVAL = 0.5
    # Seconds to wait after SIGTERM (and again after SIGKILL) when stopping.
    SHUTDOWN_TIMEOUT = 3.0

    def __init__(self):
        self.server_process: Optional[subprocess.Popen] = None
        self._server_lock = threading.Lock()
        # Parent-side handle of the file that receives the server's stdout/stderr.
        self.log_file = None
        self.log_file_path: Optional[str] = None

    # ------------------------------------------------------------------
    # Server lifecycle
    # ------------------------------------------------------------------

    def start_server(self):
        """Start the inference server process unless it already answers.

        If a process launched earlier by this client is still alive but not
        yet answering (for example, a previous startup timed out), this waits
        on that process instead of spawning a second copy.

        Raises:
            RuntimeError: if the server exits during startup or does not
                answer within ``STARTUP_TIMEOUT`` seconds.
        """
        with self._server_lock:
            if self.is_server_running():
                return
            process = self.server_process
            if process is not None and process.poll() is None:
                # A second launch would truncate this process's log and
                # compete with it for the server port.
                print("[FaceMask] Waiting for previously started server...")
            else:
                self._spawn_server()
            self._wait_for_startup()

    def _spawn_server(self) -> None:
        """Launch ``server/main.py`` in a new session, logging to a temp file."""
        print("[FaceMask] Starting inference server...")
        # This file lives in <root>/core/, so the project root is two levels up.
        root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        server_script = os.path.join(root_dir, "server", "main.py")
        venv_bin = os.path.join(root_dir, ".venv", "bin")
        server_env = self._build_server_env(root_dir, venv_bin)

        # Prefer the project's venv interpreter; otherwise use the current one.
        venv_python = os.path.join(venv_bin, "python")
        python_executable = (
            venv_python if os.path.isfile(venv_python) else sys.executable
        )

        # Release any handle left over from an earlier launch before
        # reopening; mode "w" truncates the previous log.
        self._close_log_file()
        self.log_file_path = os.path.join(
            tempfile.gettempdir(), "facemask_server.log"
        )
        self.log_file = open(
            self.log_file_path, "w", buffering=1, encoding="utf-8"
        )  # Line buffered
        print(f"[FaceMask] Server log: {self.log_file_path}")

        try:
            self.server_process = subprocess.Popen(
                [python_executable, "-u", server_script],  # -u: unbuffered output
                cwd=root_dir,
                text=True,
                env=server_env,
                stdout=self.log_file,
                stderr=subprocess.STDOUT,  # merge stderr into the same log
                # Run setsid() in the child so the server leads its own
                # process group, which stop_server() signals as a whole.
                # Unlike preexec_fn=os.setsid, this is safe when other
                # threads are running.
                start_new_session=True,
            )
        except Exception:
            self._close_log_file()
            raise

    @classmethod
    def _build_server_env(cls, root_dir: str, venv_bin: str) -> Dict[str, str]:
        """Return the environment to pass to the server process.

        Starts from ``os.environ`` and overlays ``<root_dir>/.env`` if it
        exists. Pins ``PYTHONPATH`` to the project root and drops variables
        that can make the venv interpreter import packages from a foreign
        (e.g. Nix) Python. When *venv_bin* is a directory, it goes first on
        ``PATH``.
        """
        server_env = os.environ.copy()

        env_file = os.path.join(root_dir, ".env")
        if os.path.exists(env_file):
            server_env.update(cls._load_env_file(env_file))
            print(f"[FaceMask] Loaded environment from: {env_file}")

        # Only the project root is needed for the server's local imports;
        # inherited entries can conflict with Nix Python packages.
        server_env["PYTHONPATH"] = root_dir
        for var in (
            "PYTHONUNBUFFERED",
            "__PYVENV_LAUNCHER__",  # macOS venv variable
            "VIRTUAL_ENV",
        ):
            server_env.pop(var, None)

        if os.path.isdir(venv_bin):
            # Drop Nix Python 3.11 entries to avoid version conflicts. Also
            # drop empty entries, which would put the working directory on
            # PATH.
            kept = [
                entry
                for entry in server_env.get("PATH", "").split(os.pathsep)
                if entry
                and "/python3.11/" not in entry.lower()
                and "/python3-3.11" not in entry.lower()
            ]
            server_env["PATH"] = os.pathsep.join([venv_bin, *kept])
            print(f"[FaceMask] Using venv from: {venv_bin}")

        return server_env

    @staticmethod
    def _load_env_file(path: str) -> Dict[str, str]:
        """Parse a simple ``KEY=VALUE`` .env file into a dict.

        Blank lines, ``#`` comments and lines without ``=`` are skipped, and
        a leading ``export`` is ignored. Whitespace around keys and values is
        trimmed, one pair of matching surrounding quotes is removed from the
        value, and entries with an empty key are dropped.
        """
        values: Dict[str, str] = {}
        # utf-8-sig tolerates a BOM written by some editors.
        with open(path, "r", encoding="utf-8-sig") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                key, value = line.split("=", 1)
                key = key.strip()
                if key.startswith("export "):
                    key = key[len("export "):].strip()
                value = value.strip()
                if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                    value = value[1:-1]
                if key:
                    values[key] = value
        return values

    def _wait_for_startup(self) -> None:
        """Poll ``/status`` until the server answers, exits, or times out.

        Raises:
            RuntimeError: if the process exits first (``server_process`` is
                then cleared), or if it does not answer within
                ``STARTUP_TIMEOUT`` seconds. On timeout the process is left
                running, so a later ``start_server()`` can keep waiting for it.
        """
        process = self.server_process
        deadline = time.monotonic() + self.STARTUP_TIMEOUT
        while True:
            if self.is_server_running():
                print("[FaceMask] Server started successfully")
                return
            if process is not None and process.poll() is not None:
                error_msg = (
                    f"Server failed to start (exit code: {process.returncode})"
                )
                print(f"[FaceMask] ERROR: {error_msg}")
                self._close_log_file()
                self._print_log_tail("[FaceMask] Server log:", 50)
                self.server_process = None
                raise RuntimeError(error_msg)
            if time.monotonic() >= deadline:
                break
            time.sleep(self.STARTUP_POLL_INTERVAL)

        print("[FaceMask] Server startup timed out")
        self._close_log_file()
        self._print_log_tail("[FaceMask] Server log (partial):", 30)
        raise RuntimeError("Server startup timed out")

    def _print_log_tail(self, header: str, max_lines: int) -> None:
        """Best effort: print *header* and the last *max_lines* log lines."""
        if not self.log_file_path:
            return
        try:
            with open(
                self.log_file_path, "r", encoding="utf-8", errors="replace"
            ) as f:
                content = f.read().strip()
        except OSError as e:
            print(f"[FaceMask] Could not read log file: {e}")
            return
        if not content:
            return
        print(header)
        for line in content.split("\n")[-max_lines:]:
            print(line)

    def _close_log_file(self) -> None:
        """Close the parent-side log handle if open; safe to call repeatedly."""
        if self.log_file is not None:
            try:
                self.log_file.close()
            except OSError:
                pass
            self.log_file = None

    @staticmethod
    def _signal_group(process: subprocess.Popen, sig: int) -> None:
        """Send *sig* to the process group led by *process* (best effort)."""
        try:
            # start_new_session=True makes the server its own group leader,
            # so its pid is also its process-group id. Signalling the pid
            # directly avoids os.getpgid() resolving a recycled pid to an
            # unrelated group.
            os.killpg(process.pid, sig)
        except (ProcessLookupError, PermissionError):
            # Group already gone; some platforms report EPERM instead of
            # ESRCH when only exited members remain.
            pass

    def stop_server(self):
        """Stop the server process started by this client, if any.

        Sends SIGTERM to the server's process group and waits up to
        ``SHUTDOWN_TIMEOUT`` seconds. If the server is still running, it
        escalates to SIGKILL. Always closes the log file handle.
        """
        with self._server_lock:
            process = self.server_process
            if process is not None:
                print("[FaceMask] Stopping server...")
                try:
                    self._signal_group(process, signal.SIGTERM)
                    try:
                        process.wait(timeout=self.SHUTDOWN_TIMEOUT)
                    except subprocess.TimeoutExpired:
                        # SIGTERM was ignored or too slow; force it so the
                        # server does not keep running after stop_server().
                        self._signal_group(process, signal.SIGKILL)
                        try:
                            process.wait(timeout=self.SHUTDOWN_TIMEOUT)
                        except subprocess.TimeoutExpired:
                            pass
                finally:
                    self.server_process = None
            self._close_log_file()

    def is_server_running(self) -> bool:
        """Return True if ``GET /status`` answers with HTTP 200 within 1 s."""
        try:
            with urllib.request.urlopen(
                f"{self.SERVER_URL}/status", timeout=1
            ) as response:
                return response.status == 200
        except (OSError, http.client.HTTPException):
            # URLError/HTTPError, refused or reset connections and timeouts
            # are all OSError subclasses. HTTPException covers malformed
            # replies, e.g. from another program listening on the port.
            return False

    # ------------------------------------------------------------------
    # HTTP API
    # ------------------------------------------------------------------

    def _post_json(self, endpoint: str, payload: Dict[str, Any]) -> str:
        """POST *payload* as JSON to *endpoint* and return the reply's ``id``.

        Raises:
            RuntimeError: if the server replies with an HTTP error status;
                the message includes the response body.
        """
        req = urllib.request.Request(
            f"{self.SERVER_URL}{endpoint}",
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        try:
            with urllib.request.urlopen(req) as response:
                result = json.loads(response.read().decode("utf-8"))
        except urllib.error.HTTPError as e:
            detail = e.read().decode("utf-8", errors="replace")
            raise RuntimeError(f"Server error: {detail}") from e
        return result["id"]

    def generate_mask(
        self,
        video_path: str,
        output_dir: str,
        start_frame: int,
        end_frame: int,
        conf_threshold: float,
        iou_threshold: float,
    ) -> str:
        """
        Request mask generation, starting the server first if needed.
        Returns:
            task_id (str)
        Raises:
            RuntimeError: if the server cannot be started or rejects the request.
        """
        if not self.is_server_running():
            self.start_server()
        return self._post_json(
            "/generate",
            {
                "video_path": video_path,
                "output_dir": output_dir,
                "start_frame": start_frame,
                "end_frame": end_frame,
                "conf_threshold": conf_threshold,
                "iou_threshold": iou_threshold,
            },
        )

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Return the task's JSON status object.

        Returns ``{"status": "unknown"}`` if the server replies with an HTTP
        error status; connection errors propagate to the caller.
        """
        try:
            with urllib.request.urlopen(
                f"{self.SERVER_URL}/tasks/{task_id}"
            ) as response:
                return json.loads(response.read().decode("utf-8"))
        except urllib.error.HTTPError:
            return {"status": "unknown"}

    def bake_blur(
        self,
        video_path: str,
        detections_path: str,
        output_path: str,
        blur_size: int,
        display_scale: float,
        fmt: str,
    ) -> str:
        """
        Request a blur bake of a source video using a detections file,
        starting the server first if needed.
        Returns:
            task_id (str)
        Raises:
            RuntimeError: if the server cannot be started or rejects the request.
        """
        if not self.is_server_running():
            self.start_server()
        return self._post_json(
            "/bake_blur",
            {
                "video_path": video_path,
                "detections_path": detections_path,
                "output_path": output_path,
                "blur_size": blur_size,
                "display_scale": display_scale,
                "format": fmt,
            },
        )

    def cancel_task(self, task_id: str):
        """Ask the server to cancel a task.

        HTTP error responses are ignored; connection errors propagate.
        """
        try:
            req = urllib.request.Request(
                f"{self.SERVER_URL}/tasks/{task_id}/cancel", method="POST"
            )
            with urllib.request.urlopen(req):
                pass
        except urllib.error.HTTPError:
            pass
# Shared, lazily constructed client instance; access it via get_client().
_client: Optional[InferenceClient] = None


def get_client() -> InferenceClient:
    """Return the module-wide InferenceClient, building it on first call."""
    global _client
    client = _client
    if client is None:
        client = _client = InferenceClient()
    return client