blender-mask-peoples/core/inference_client.py
2026-02-12 18:52:55 +09:00

220 lines
7.8 KiB
Python

"""
Client for interacting with the external inference server.
Manages the server process and handles HTTP communication
using only the standard library (avoiding a dependency on `requests`).
"""
import subprocess
import time
import json
import urllib.request
import urllib.error
import threading
import os
import signal
from typing import Optional, Dict, Any, Tuple
class InferenceClient:
    """Client for the YOLO inference server.

    Owns the lifecycle of the local server subprocess (spawn, readiness
    polling, shutdown) and wraps the server's HTTP endpoints using only the
    standard library.

    Process management relies on POSIX process groups (``os.killpg``), as the
    original implementation did.
    """

    SERVER_URL = "http://127.0.0.1:8181"

    # Readiness polling after launch: 20 attempts x 0.5 s of sleeping, i.e.
    # roughly 10 seconds plus the time spent in the probes themselves.
    STARTUP_POLL_ATTEMPTS = 20
    STARTUP_POLL_INTERVAL = 0.5

    # Seconds to wait for the server to exit after SIGTERM before SIGKILL.
    SHUTDOWN_GRACE_SECONDS = 3

    # Socket timeout (seconds) for API requests, so that a wedged server
    # cannot block the caller forever.
    REQUEST_TIMEOUT = 30.0

    # Number of most recent output lines retained per server stream.
    MAX_CAPTURED_LINES = 200

    def __init__(self):
        self.server_process: Optional[subprocess.Popen] = None
        self._server_lock = threading.Lock()
        # Most recent stdout/stderr lines of the current server process,
        # filled by the drain threads started in _spawn_server().
        self._server_output: Dict[str, list] = {"stdout": [], "stderr": []}
        self._drain_threads: list = []

    # ------------------------------------------------------------------
    # Server lifecycle
    # ------------------------------------------------------------------

    def start_server(self):
        """Start the inference server process and wait until it responds.

        Returns immediately if the server already answers on SERVER_URL.
        If a process launched by an earlier call is still alive but not yet
        responding, this waits for that process instead of spawning a second
        one (which would orphan the first and compete for the same port).

        Raises:
            RuntimeError: if the server process exits during startup, or does
                not respond within the polling window. On a timeout the
                process is kept in ``self.server_process`` so that a later
                call can keep waiting for it or ``stop_server()`` can kill it.
        """
        with self._server_lock:
            if self.is_server_running():
                return

            proc = self.server_process
            if proc is not None and proc.poll() is None:
                print("[FaceMask] Waiting for inference server to finish starting...")
            else:
                print("[FaceMask] Starting inference server...")
                proc = self._spawn_server()
                self.server_process = proc

            self._await_startup(proc)

    def stop_server(self):
        """Stop the inference server if this client started one.

        Sends SIGTERM to the server's process group and escalates to SIGKILL
        if it has not exited after SHUTDOWN_GRACE_SECONDS.
        """
        with self._server_lock:
            proc = self.server_process
            if proc is None:
                return
            print("[FaceMask] Stopping server...")
            try:
                self._terminate_process_group(proc)
            finally:
                self.server_process = None

    def is_server_running(self) -> bool:
        """Return True if the server answers ``GET /status`` with HTTP 200.

        Any connection-level or HTTP-protocol failure is treated as "not
        running" rather than propagated.
        """
        # Local import keeps this block self-contained (stdlib only).
        import http.client

        try:
            with urllib.request.urlopen(f"{self.SERVER_URL}/status", timeout=1) as response:
                return response.status == 200
        except (OSError, http.client.HTTPException):
            # OSError covers URLError/HTTPError, refused or reset connections
            # and socket timeouts; HTTPException covers malformed replies
            # (e.g. something that is not our server listening on the port).
            return False

    # ------------------------------------------------------------------
    # Server lifecycle helpers (callers hold _server_lock where relevant;
    # none of these acquire it, the lock is not reentrant)
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_env_lines(lines) -> Dict[str, str]:
        """Parse ``KEY=VALUE`` lines of a .env file into a dict.

        Blank lines, ``#`` comments, lines without ``=`` and lines with an
        empty key are skipped. Whitespace around keys and values, an optional
        leading ``export``, and one pair of matching surrounding quotes are
        removed. Only the first ``=`` separates key from value.
        """
        parsed: Dict[str, str] = {}
        for raw in lines:
            entry = raw.strip()
            if not entry or entry.startswith("#") or "=" not in entry:
                continue
            key, value = entry.split("=", 1)
            key = key.strip()
            if key.startswith("export "):
                key = key[len("export "):].strip()
            if not key:
                continue
            value = value.strip()
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            parsed[key] = value
        return parsed

    @classmethod
    def _build_server_env(cls, root_dir: str) -> Dict[str, str]:
        """Build the environment for the server process.

        Starts from a copy of the current environment, overlays variables
        from ``<root_dir>/.env`` if that file exists, puts ``root_dir`` first
        on PYTHONPATH and, if ``<root_dir>/.venv/bin`` exists, first on PATH.
        """
        server_env = os.environ.copy()

        env_file = os.path.join(root_dir, ".env")
        if os.path.isfile(env_file):
            # utf-8-sig also accepts plain UTF-8 and drops a BOM if present.
            with open(env_file, "r", encoding="utf-8-sig") as f:
                server_env.update(cls._parse_env_lines(f))
            print(f"[FaceMask] Loaded environment from: {env_file}")

        # Ensure PYTHONPATH includes project root
        pythonpath = server_env.get("PYTHONPATH", "")
        server_env["PYTHONPATH"] = os.pathsep.join(p for p in (root_dir, pythonpath) if p)

        # If there's a venv in the project, put it first on PATH so that the
        # bare 'python' command resolves to the venv interpreter.
        venv_bin = os.path.join(root_dir, ".venv", "bin")
        if os.path.isdir(venv_bin):
            current_path = server_env.get("PATH", "")
            server_env["PATH"] = os.pathsep.join(p for p in (venv_bin, current_path) if p)
            print(f"[FaceMask] Using venv from: {venv_bin}")

        return server_env

    def _spawn_server(self) -> subprocess.Popen:
        """Launch ``server/main.py`` and start draining its output pipes."""
        # Find project root; this file is assumed to be core/inference_client.py
        root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        server_script = os.path.join(root_dir, "server", "main.py")
        server_env = self._build_server_env(root_dir)

        proc = subprocess.Popen(
            ["python", server_script],
            cwd=root_dir,
            text=True,
            errors="replace",  # undecodable bytes must not kill a drain thread
            env=server_env,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            # New session => new process group, so stop_server() can signal
            # the server and its children together. Used instead of
            # preexec_fn=os.setsid, which is unsafe in threaded programs.
            start_new_session=True,
        )

        # The pipes must be read continuously: once an OS pipe buffer fills
        # up, the server would block on its next write.
        output: Dict[str, list] = {"stdout": [], "stderr": []}
        threads = []
        for name, stream in (("stdout", proc.stdout), ("stderr", proc.stderr)):
            worker = threading.Thread(
                target=self._drain_stream,
                args=(stream, output[name]),
                name=f"facemask-server-{name}",
                daemon=True,
            )
            worker.start()
            threads.append(worker)
        self._server_output = output
        self._drain_threads = threads
        return proc

    def _drain_stream(self, stream, sink: list) -> None:
        """Read ``stream`` until EOF, keeping only the newest lines in ``sink``."""
        try:
            with stream:
                for line in stream:
                    sink.append(line)
                    if len(sink) > self.MAX_CAPTURED_LINES:
                        del sink[0]
        except (OSError, ValueError):
            # The pipe was closed underneath us (e.g. during shutdown).
            pass

    def _print_server_output(self, partial: bool = False) -> None:
        """Print whatever server stdout/stderr has been captured so far."""
        suffix = " (partial)" if partial else ""
        for name in ("stdout", "stderr"):
            text = "".join(list(self._server_output[name])).strip()
            if text:
                print(f"[FaceMask] Server {name}{suffix}:")
                print(text)

    def _await_startup(self, proc: subprocess.Popen) -> None:
        """Poll until the server responds, ``proc`` dies, or time runs out.

        Raises:
            RuntimeError: on early process exit or on timeout.
        """
        for _ in range(self.STARTUP_POLL_ATTEMPTS):
            if self.is_server_running():
                print("[FaceMask] Server started successfully")
                return

            # Check if process died
            if proc.poll() is not None:
                # Give the drain threads a moment to collect the final output.
                for worker in self._drain_threads:
                    worker.join(timeout=1)
                error_msg = f"Server failed to start (exit code: {proc.returncode})"
                print(f"[FaceMask] ERROR: {error_msg}")
                self._print_server_output()
                self.server_process = None
                raise RuntimeError(error_msg)

            time.sleep(self.STARTUP_POLL_INTERVAL)

        # Startup timed out: the process is still alive but not responding.
        self._print_server_output(partial=True)
        raise RuntimeError("Server startup timed out")

    def _terminate_process_group(self, proc: subprocess.Popen) -> None:
        """Terminate the process group of ``proc``: SIGTERM first, then SIGKILL."""
        if proc.poll() is not None:
            # Already exited and reaped; its PID may have been recycled, so
            # signalling it could hit an unrelated process group.
            return
        try:
            pgid = os.getpgid(proc.pid)
            os.killpg(pgid, signal.SIGTERM)
        except ProcessLookupError:
            return
        try:
            proc.wait(timeout=self.SHUTDOWN_GRACE_SECONDS)
            return
        except subprocess.TimeoutExpired:
            pass

        # The server ignored SIGTERM: force-kill the whole group.
        try:
            os.killpg(pgid, signal.SIGKILL)
        except ProcessLookupError:
            return
        try:
            proc.wait(timeout=self.SHUTDOWN_GRACE_SECONDS)
        except subprocess.TimeoutExpired:
            # Nothing more we can do from here; drop the handle.
            pass

    # ------------------------------------------------------------------
    # HTTP API
    # ------------------------------------------------------------------

    def generate_mask(
        self,
        video_path: str,
        output_dir: str,
        start_frame: int,
        end_frame: int,
        conf_threshold: float,
        iou_threshold: float,
        mask_scale: float,
    ) -> str:
        """Request mask generation, starting the server first if necessary.

        The arguments are forwarded unchanged as the JSON body of
        ``POST /generate``.

        Returns:
            The ``id`` field of the server's JSON response (the task id).

        Raises:
            RuntimeError: if the server could not be started, or if it
                answers with an HTTP error status (the message carries the
                response body).
            urllib.error.URLError: on connection-level failures.
        """
        if not self.is_server_running():
            self.start_server()

        data = {
            "video_path": video_path,
            "output_dir": output_dir,
            "start_frame": start_frame,
            "end_frame": end_frame,
            "conf_threshold": conf_threshold,
            "iou_threshold": iou_threshold,
            "mask_scale": mask_scale,
        }
        req = urllib.request.Request(
            f"{self.SERVER_URL}/generate",
            data=json.dumps(data).encode('utf-8'),
            headers={'Content-Type': 'application/json'},
            method='POST',
        )
        try:
            with urllib.request.urlopen(req, timeout=self.REQUEST_TIMEOUT) as response:
                result = json.loads(response.read().decode('utf-8'))
                return result['id']
        except urllib.error.HTTPError as e:
            detail = e.read().decode('utf-8', errors='replace')
            raise RuntimeError(f"Server error: {detail}") from e

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Get status of a task via ``GET /tasks/<task_id>``.

        Returns:
            The decoded JSON response, or ``{"status": "unknown"}`` if the
            server answers with an HTTP error status.

        Raises:
            urllib.error.URLError: on connection-level failures.
        """
        try:
            with urllib.request.urlopen(
                f"{self.SERVER_URL}/tasks/{task_id}", timeout=self.REQUEST_TIMEOUT
            ) as response:
                return json.loads(response.read().decode('utf-8'))
        except urllib.error.HTTPError:
            return {"status": "unknown"}

    def cancel_task(self, task_id: str):
        """Cancel a task via ``POST /tasks/<task_id>/cancel``.

        Best effort: an HTTP error status from the server is ignored.
        Connection-level failures still propagate as ``URLError``.
        """
        req = urllib.request.Request(
            f"{self.SERVER_URL}/tasks/{task_id}/cancel",
            method='POST',
        )
        try:
            with urllib.request.urlopen(req, timeout=self.REQUEST_TIMEOUT):
                pass
        except urllib.error.HTTPError:
            # Deliberately ignored, e.g. the task already finished or is unknown.
            pass
# Module-wide shared client instance; stays None until first requested.
_client: Optional[InferenceClient] = None


def get_client() -> InferenceClient:
    """Return the shared InferenceClient, creating it on first use."""
    global _client
    if _client is not None:
        return _client
    _client = InferenceClient()
    return _client