"""
Client for interacting with the external inference server.

Manages the server process and handles HTTP communication using the
standard library (avoiding a ``requests`` dependency).
"""

import http.client
import json
import os
import signal
import subprocess
import sys
import tempfile
import threading
import time
import urllib.error
import urllib.request
from typing import Any, Dict, Optional, Tuple


class InferenceClient:
    """Client for the YOLO inference server.

    Owns the lifetime of the server subprocess (start/stop) and exposes thin
    wrappers around the server's JSON-over-HTTP endpoints.
    """

    SERVER_URL = "http://127.0.0.1:8181"

    def __init__(self):
        # Handle of the server subprocess launched by this client, if any.
        self.server_process: Optional[subprocess.Popen] = None
        # Serializes start/stop so two threads cannot spawn the server twice.
        self._server_lock = threading.Lock()
        # Open handle / path of the file receiving the server's stdout+stderr.
        self.log_file = None
        self.log_file_path = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_env_line(line: str) -> Optional[Tuple[str, str]]:
        """Parse one ``KEY=VALUE`` line of a ``.env`` file.

        Blank lines, ``#`` comments, lines without ``=`` and lines with an
        empty key yield ``None``. Whitespace around the key and the value is
        dropped, and one pair of matching surrounding quotes is removed from
        the value.
        """
        stripped = line.strip()
        if not stripped or stripped.startswith("#") or "=" not in stripped:
            return None

        key, value = stripped.split("=", 1)
        key = key.strip()
        value = value.strip()
        if not key:
            # An empty variable name would make subprocess.Popen fail later.
            return None

        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        return key, value

    @classmethod
    def _build_server_env(cls, root_dir: str, venv_bin: str) -> Dict[str, str]:
        """Return the environment mapping used to launch the server process."""
        server_env = os.environ.copy()

        # Load environment variables from .env file if it exists
        env_file = os.path.join(root_dir, ".env")
        if os.path.exists(env_file):
            with open(env_file, "r", encoding="utf-8") as f:
                for line in f:
                    parsed = cls._parse_env_line(line)
                    if parsed is not None:
                        key, value = parsed
                        server_env[key] = value
            print(f"[FaceMask] Loaded environment from: {env_file}")

        # Clean PYTHONPATH to avoid conflicts with Nix Python packages.
        # Only include project root to allow local imports.
        server_env["PYTHONPATH"] = root_dir

        # Remove Python-related environment variables that might cause
        # conflicts (they can make the venv import packages from Nix).
        for var in (
            "PYTHONUNBUFFERED",
            "__PYVENV_LAUNCHER__",  # macOS venv variable
            "VIRTUAL_ENV",  # Will be set by venv's Python automatically
        ):
            server_env.pop(var, None)

        # If there's a venv in the project, put it first on PATH.
        if os.path.isdir(venv_bin):
            current_path = server_env.get("PATH", "")
            # Drop Nix Python 3.11 paths to avoid version conflicts, and drop
            # empty entries (an empty PATH entry means "current directory").
            filtered_paths = [
                p
                for p in current_path.split(os.pathsep)
                if p
                and "/python3.11/" not in p.lower()
                and "/python3-3.11" not in p.lower()
            ]
            server_env["PATH"] = os.pathsep.join([venv_bin] + filtered_paths)
            print(f"[FaceMask] Using venv from: {venv_bin}")

        return server_env

    def _close_log_file(self) -> None:
        """Close this client's handle on the server log file, if open."""
        if self.log_file is not None:
            try:
                self.log_file.close()
            except OSError:
                # Best effort: nothing useful can be done if close fails.
                pass
            self.log_file = None

    def _print_log_tail(self, header: str, max_lines: int) -> None:
        """Print the last ``max_lines`` lines of the server log under ``header``."""
        # Release our handle before reading the file back.
        self._close_log_file()
        if not self.log_file_path:
            return
        try:
            with open(self.log_file_path, "r", errors="replace") as f:
                log_content = f.read()
        except OSError as e:
            print(f"[FaceMask] Could not read log file: {e}")
            return

        stripped = log_content.strip()
        if stripped:
            print(header)
            for line in stripped.split("\n")[-max_lines:]:
                print(line)

    def _spawn_server(self) -> None:
        """Launch the server subprocess with its output redirected to a log file."""
        # Find project root.
        # Assuming this file is in core/inference_client.py
        root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        server_script = os.path.join(root_dir, "server", "main.py")
        venv_bin = os.path.join(root_dir, ".venv", "bin")
        server_env = self._build_server_env(root_dir, venv_bin)

        # Prepare log file for server output; close any handle left over from
        # a previous launch so it is not leaked.
        self._close_log_file()
        self.log_file_path = os.path.join(
            tempfile.gettempdir(), "facemask_server.log"
        )
        self.log_file = open(self.log_file_path, "w", buffering=1)  # Line buffered
        print(f"[FaceMask] Server log: {self.log_file_path}")

        # Prefer the project venv's interpreter; otherwise reuse our own.
        venv_python = os.path.join(venv_bin, "python")
        python_executable = (
            venv_python if os.path.isfile(venv_python) else sys.executable
        )

        try:
            self.server_process = subprocess.Popen(
                [python_executable, "-u", server_script],  # -u for unbuffered output
                cwd=root_dir,
                text=True,
                env=server_env,
                stdout=self.log_file,  # Write to log file
                stderr=subprocess.STDOUT,  # Merge stderr into stdout
                # New session => new process group, so stop_server() can signal
                # the whole group. Thread-safe replacement for
                # preexec_fn=os.setsid.
                start_new_session=True,
            )
        except OSError:
            self._close_log_file()
            raise

    def _post_json(self, endpoint: str, payload: Dict[str, Any]) -> Any:
        """POST ``payload`` as JSON to ``endpoint`` and return the decoded reply.

        Starts the server first when it is not responding.

        Raises:
            RuntimeError: if the server answers with an HTTP error status; the
                message carries the response body.
        """
        if not self.is_server_running():
            self.start_server()

        req = urllib.request.Request(
            f"{self.SERVER_URL}{endpoint}",
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        try:
            with urllib.request.urlopen(req) as response:
                return json.loads(response.read().decode("utf-8"))
        except urllib.error.HTTPError as e:
            detail = e.read().decode("utf-8", errors="replace")
            raise RuntimeError(f"Server error: {detail}") from e

    # ------------------------------------------------------------------
    # Server lifecycle
    # ------------------------------------------------------------------

    def start_server(self):
        """Start the inference server process and wait until it responds.

        Returns immediately when the server already answers. If a process
        launched earlier by this client is still alive but not yet responding,
        it is waited on instead of spawning a duplicate.

        Raises:
            RuntimeError: if the process exits during startup or does not
                respond within the polling window.
        """
        with self._server_lock:
            if self.is_server_running():
                return

            print("[FaceMask] Starting inference server...")

            if self.server_process is None or self.server_process.poll() is not None:
                self._spawn_server()

            # Poll for startup: 20 probes with 0.5 s pauses (each probe may
            # itself block for up to its 1 s timeout).
            for _ in range(20):
                if self.is_server_running():
                    print("[FaceMask] Server started successfully")
                    return

                # Check if process died
                if self.server_process.poll() is not None:
                    error_msg = (
                        "Server failed to start "
                        f"(exit code: {self.server_process.returncode})"
                    )
                    print(f"[FaceMask] ERROR: {error_msg}")
                    self._print_log_tail("[FaceMask] Server log:", 50)
                    self.server_process = None
                    raise RuntimeError(error_msg)

                time.sleep(0.5)

            # If we get here, startup timed out. The process is left running
            # so a slow server can still come up for a later call.
            print("[FaceMask] Server startup timed out")
            self._print_log_tail("[FaceMask] Server log (partial):", 30)
            raise RuntimeError("Server startup timed out")

    def stop_server(self):
        """Stop the inference server and close the log file handle.

        Sends SIGTERM to the server's process group and escalates to SIGKILL
        when the process has not exited after 3 seconds.
        """
        with self._server_lock:
            process = self.server_process
            if process:
                print("[FaceMask] Stopping server...")
                try:
                    pgid = os.getpgid(process.pid)
                    os.killpg(pgid, signal.SIGTERM)
                    try:
                        process.wait(timeout=3)
                    except subprocess.TimeoutExpired:
                        # Graceful shutdown was ignored; force termination.
                        os.killpg(pgid, signal.SIGKILL)
                        process.wait(timeout=3)
                except (ProcessLookupError, subprocess.TimeoutExpired):
                    pass
                finally:
                    self.server_process = None

            # Close log file
            self._close_log_file()

    def is_server_running(self) -> bool:
        """Check if server is responding (HTTP 200 on ``/status`` within 1 s)."""
        try:
            with urllib.request.urlopen(
                f"{self.SERVER_URL}/status", timeout=1
            ) as response:
                return response.status == 200
        except (OSError, http.client.HTTPException):
            # OSError covers URLError/HTTPError, refused connections and
            # timeouts; HTTPException covers malformed or dropped responses.
            return False

    # ------------------------------------------------------------------
    # Task endpoints
    # ------------------------------------------------------------------

    def generate_mask(
        self,
        video_path: str,
        output_dir: str,
        start_frame: int,
        end_frame: int,
        conf_threshold: float,
        iou_threshold: float,
    ) -> str:
        """
        Request mask generation.

        Returns:
            task_id (str)

        Raises:
            RuntimeError: if the server answers with an HTTP error.
        """
        result = self._post_json(
            "/generate",
            {
                "video_path": video_path,
                "output_dir": output_dir,
                "start_frame": start_frame,
                "end_frame": end_frame,
                "conf_threshold": conf_threshold,
                "iou_threshold": iou_threshold,
            },
        )
        return result["id"]

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Get status of a task.

        Returns ``{"status": "unknown"}`` when the server answers with an HTTP
        error for this task id.
        """
        try:
            with urllib.request.urlopen(
                f"{self.SERVER_URL}/tasks/{task_id}"
            ) as response:
                return json.loads(response.read().decode("utf-8"))
        except urllib.error.HTTPError:
            return {"status": "unknown"}

    def get_video_info(self, video_path: str) -> Dict[str, Any]:
        """Get video metadata from the inference server.

        Raises:
            RuntimeError: if the server answers with an HTTP error.
        """
        return self._post_json("/video_info", {"video_path": video_path})

    def bake_blur(
        self,
        video_path: str,
        detections_path: str,
        output_path: str,
        blur_size: int,
        display_scale: float,
        fmt: str,
    ) -> str:
        """
        Request blur bake for a source video.

        ``fmt`` is sent to the server under the ``"format"`` key.

        Returns:
            task_id (str)

        Raises:
            RuntimeError: if the server answers with an HTTP error.
        """
        result = self._post_json(
            "/bake_blur",
            {
                "video_path": video_path,
                "detections_path": detections_path,
                "output_path": output_path,
                "blur_size": blur_size,
                "display_scale": display_scale,
                "format": fmt,
            },
        )
        return result["id"]

    def generate_mask_images(
        self,
        image_dir: str,
        filenames: list,
        output_dir: str,
        start_index: int,
        end_index: int,
        conf_threshold: float,
        iou_threshold: float,
    ) -> str:
        """Start a face-detection task for an image sequence and return its task_id.

        Raises:
            RuntimeError: if the server answers with an HTTP error.
        """
        result = self._post_json(
            "/generate_images",
            {
                "image_dir": image_dir,
                "filenames": filenames,
                "output_dir": output_dir,
                "start_index": start_index,
                "end_index": end_index,
                "conf_threshold": conf_threshold,
                "iou_threshold": iou_threshold,
            },
        )
        return result["id"]

    def bake_image_blur(
        self,
        image_dir: str,
        filenames: list,
        output_dir: str,
        detections_path: str,
        blur_size: int,
        display_scale: float,
    ) -> str:
        """Start a blur-bake task for an image sequence and return its task_id.

        Raises:
            RuntimeError: if the server answers with an HTTP error.
        """
        result = self._post_json(
            "/bake_image_blur",
            {
                "image_dir": image_dir,
                "filenames": filenames,
                "output_dir": output_dir,
                "detections_path": detections_path,
                "blur_size": blur_size,
                "display_scale": display_scale,
            },
        )
        return result["id"]

    def cancel_task(self, task_id: str):
        """Cancel a task. HTTP error responses are ignored (best effort)."""
        try:
            req = urllib.request.Request(
                f"{self.SERVER_URL}/tasks/{task_id}/cancel", method="POST"
            )
            with urllib.request.urlopen(req):
                pass
        except urllib.error.HTTPError:
            pass


# Singleton
_client: Optional[InferenceClient] = None


def get_client() -> InferenceClient:
    """Return the process-wide InferenceClient, creating it on first use."""
    global _client
    if _client is None:
        _client = InferenceClient()
    return _client