blender-mask-peoples/core/inference_client.py
2026-02-06 10:13:26 +09:00

160 lines
5.1 KiB
Python

"""
Client for interacting with the external inference server.
Manages the server process and handles HTTP communication
using standard library (avoiding requests dependency).
"""
import subprocess
import time
import json
import urllib.request
import urllib.error
import threading
import os
import signal
from typing import Optional, Dict, Any, Tuple
class InferenceClient:
    """Client for the YOLO inference server.

    Manages the lifetime of the server subprocess and communicates with it
    over HTTP using only the standard library (no ``requests`` dependency).
    """

    # The server is assumed to listen locally on this fixed address/port.
    SERVER_URL = "http://127.0.0.1:8181"

    def __init__(self):
        # Handle to the spawned server process; None when not started.
        self.server_process: Optional[subprocess.Popen] = None
        # Serializes start/stop so concurrent callers cannot race.
        self._server_lock = threading.Lock()

    def start_server(self) -> None:
        """Start the inference server process and wait until it responds.

        Raises:
            RuntimeError: If the process exits during startup or does not
                respond within ~10 seconds.
        """
        with self._server_lock:
            if self.is_server_running():
                return
            print("[FaceMask] Starting inference server...")
            # Project root: this file lives in core/inference_client.py.
            root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            server_script = os.path.join(root_dir, "server", "main.py")
            # Use system python (assumed to have dependencies via Nix/venv).
            # In the user's environment, 'python' should refer to the
            # environment python.
            python_cmd = "python"
            # New process group (POSIX-only: preexec_fn/os.setsid are not
            # available on Windows) so stop_server can signal the server and
            # any children it spawns in one killpg call.
            self.server_process = subprocess.Popen(
                [python_cmd, server_script],
                cwd=root_dir,
                text=True,
                preexec_fn=os.setsid,
            )
            # Poll the /status endpoint for up to ~10 seconds (20 * 0.5s).
            for _ in range(20):
                if self.is_server_running():
                    print("[FaceMask] Server started successfully")
                    return
                # Fail fast if the process already died.
                if self.server_process.poll() is not None:
                    raise RuntimeError(
                        f"Server failed to start (rc={self.server_process.returncode})"
                    )
                time.sleep(0.5)
            raise RuntimeError("Server startup timed out")

    def stop_server(self) -> None:
        """Stop the inference server (best-effort).

        Always attempts to reap the child via wait() even when the process
        group is already gone, so a finished server is not left as a zombie.
        """
        with self._server_lock:
            if self.server_process:
                print("[FaceMask] Stopping server...")
                try:
                    # SIGTERM the whole process group created in start_server.
                    os.killpg(os.getpgid(self.server_process.pid), signal.SIGTERM)
                except ProcessLookupError:
                    # Group already gone; still fall through to wait() below
                    # so the child is reaped (the original code skipped it).
                    pass
                try:
                    self.server_process.wait(timeout=3)
                except subprocess.TimeoutExpired:
                    # Best-effort shutdown; give up and let the OS clean up.
                    pass
                finally:
                    self.server_process = None

    def is_server_running(self) -> bool:
        """Return True if the server answers its /status endpoint with 200."""
        try:
            with urllib.request.urlopen(f"{self.SERVER_URL}/status", timeout=1) as response:
                return response.status == 200
        except (urllib.error.URLError, ConnectionRefusedError, TimeoutError):
            # URLError covers refused/failed connections; timeout covers a
            # hung or not-yet-listening server.
            return False

    def generate_mask(
        self,
        video_path: str,
        output_dir: str,
        start_frame: int,
        end_frame: int,
        conf_threshold: float,
        iou_threshold: float,
        mask_scale: float,
    ) -> str:
        """Request mask generation, starting the server if it is not running.

        Args:
            video_path: Path of the input video to process.
            output_dir: Directory where the server writes masks.
            start_frame: First frame to process.
            end_frame: Last frame to process (inclusive/exclusive is decided
                by the server -- TODO confirm against server/main.py).
            conf_threshold: Detection confidence threshold.
            iou_threshold: NMS IoU threshold.
            mask_scale: Scale factor applied to generated masks.

        Returns:
            The server-assigned task id.

        Raises:
            RuntimeError: If the server rejects the request.
        """
        if not self.is_server_running():
            self.start_server()
        data = {
            "video_path": video_path,
            "output_dir": output_dir,
            "start_frame": start_frame,
            "end_frame": end_frame,
            "conf_threshold": conf_threshold,
            "iou_threshold": iou_threshold,
            "mask_scale": mask_scale,
        }
        req = urllib.request.Request(
            f"{self.SERVER_URL}/generate",
            data=json.dumps(data).encode('utf-8'),
            headers={'Content-Type': 'application/json'},
            method='POST',
        )
        try:
            with urllib.request.urlopen(req) as response:
                result = json.loads(response.read().decode('utf-8'))
            return result['id']
        except urllib.error.HTTPError as e:
            # Chain the HTTPError so tracebacks keep the underlying cause
            # (the original raise dropped it).
            raise RuntimeError(f"Server error: {e.read().decode('utf-8')}") from e

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Return the server's status dict for *task_id*.

        A failed lookup yields {"status": "unknown"} instead of raising.
        """
        try:
            with urllib.request.urlopen(f"{self.SERVER_URL}/tasks/{task_id}") as response:
                return json.loads(response.read().decode('utf-8'))
        except urllib.error.HTTPError:
            return {"status": "unknown"}

    def cancel_task(self, task_id: str) -> None:
        """Ask the server to cancel *task_id* (best-effort; errors ignored)."""
        try:
            req = urllib.request.Request(
                f"{self.SERVER_URL}/tasks/{task_id}/cancel",
                method='POST',
            )
            with urllib.request.urlopen(req):
                pass
        except urllib.error.HTTPError:
            pass
# Module-level singleton: one shared client per Blender session.
_client: Optional[InferenceClient] = None


def get_client() -> InferenceClient:
    """Return the process-wide InferenceClient, constructing it lazily."""
    global _client
    client = _client
    if client is None:
        client = InferenceClient()
        _client = client
    return client