"""
Client for interacting with the external inference server.

Manages the server process and handles HTTP communication
using the standard library (avoiding a requests dependency).
"""
|
|
|
|
import subprocess
|
|
import time
|
|
import json
|
|
import urllib.request
|
|
import urllib.error
|
|
import threading
|
|
import os
|
|
import signal
|
|
from typing import Optional, Dict, Any, Tuple
|
|
|
|
|
|
class InferenceClient:
    """Client for the YOLO inference server.

    Optionally owns the server subprocess (see ``start_server``) and talks to
    it over HTTP at ``SERVER_URL`` using only the standard library.
    """

    SERVER_URL = "http://127.0.0.1:8181"

    # Seconds to wait for a starting server to answer ``/status``.
    STARTUP_TIMEOUT = 10.0
    # Seconds between ``/status`` probes while waiting for startup.
    POLL_INTERVAL = 0.5
    # Seconds to wait for the process to exit after each stop signal.
    STOP_TIMEOUT = 3.0

    def __init__(self):
        # Handle to the server process spawned by this client, if any.
        self.server_process: Optional[subprocess.Popen] = None
        # Serializes start_server/stop_server on this instance.
        self._server_lock = threading.Lock()

    def start_server(self):
        """Start the inference server process unless it already responds.

        Spawns ``python server/main.py`` (relative to the project root) in a
        new session, then polls ``/status`` until it answers. If a process
        previously spawned by this client is still alive, it is waited on
        again instead of spawning a duplicate that would fight over the port.

        Raises:
            RuntimeError: If the process exits before answering, or does not
                answer within ``STARTUP_TIMEOUT`` seconds. On timeout the
                process is left running and kept in ``server_process`` so a
                later call keeps waiting on it and ``stop_server`` can still
                terminate it.
        """
        with self._server_lock:
            if self.is_server_running():
                return

            print("[FaceMask] Starting inference server...")

            process = self.server_process
            if process is None or process.poll() is not None:
                # Find project root.
                # Assuming this file is in core/inference_client.py
                root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
                server_script = os.path.join(root_dir, "server", "main.py")

                # Use the `python` on PATH, assumed to be the environment
                # (Nix/venv) interpreter that has the server's dependencies.
                python_cmd = "python"

                process = subprocess.Popen(
                    [python_cmd, server_script],
                    cwd=root_dir,
                    text=True,
                    # New session => new process group, so stop_server can
                    # signal the whole group. Unlike preexec_fn=os.setsid,
                    # this is safe when the parent has other threads.
                    start_new_session=True,
                )
                self.server_process = process

            # Deadline-based wait: each probe can itself take up to 1s, so
            # counting iterations would not bound the total wait.
            deadline = time.monotonic() + self.STARTUP_TIMEOUT
            while True:
                if self.is_server_running():
                    print("[FaceMask] Server started successfully")
                    return

                returncode = process.poll()
                if returncode is not None:
                    # Already reaped by poll(); nothing left to stop.
                    self.server_process = None
                    raise RuntimeError(f"Server failed to start (rc={returncode})")

                if time.monotonic() >= deadline:
                    raise RuntimeError("Server startup timed out")

                time.sleep(self.POLL_INTERVAL)

    @staticmethod
    def _signal_group(process: subprocess.Popen, sig: int) -> None:
        """Send *sig* to *process*'s process group, ignoring a vanished group."""
        try:
            os.killpg(os.getpgid(process.pid), sig)
        except ProcessLookupError:
            pass

    def stop_server(self):
        """Stop the server process spawned by this client, if any.

        Sends SIGTERM to the server's process group. If it has not exited
        after ``STOP_TIMEOUT`` seconds, sends SIGKILL and waits again. A
        server this client did not spawn is left untouched.
        ``server_process`` is always reset to None.
        """
        with self._server_lock:
            process = self.server_process
            if process is None:
                return

            print("[FaceMask] Stopping server...")
            try:
                self._signal_group(process, signal.SIGTERM)
                try:
                    process.wait(timeout=self.STOP_TIMEOUT)
                except subprocess.TimeoutExpired:
                    # SIGTERM ignored or shutdown stuck: force it.
                    self._signal_group(process, signal.SIGKILL)
                    try:
                        process.wait(timeout=self.STOP_TIMEOUT)
                    except subprocess.TimeoutExpired:
                        pass
            finally:
                self.server_process = None

    def is_server_running(self) -> bool:
        """Return True if ``GET /status`` answers with HTTP 200 within 1 second.

        Any network-level failure counts as "not running" rather than
        propagating: refused, reset or dropped connection, timeout, or an
        HTTP error status.
        """
        try:
            with urllib.request.urlopen(f"{self.SERVER_URL}/status", timeout=1) as response:
                return response.status == 200
        except OSError:
            # URLError/HTTPError, ConnectionRefusedError, TimeoutError and the
            # ConnectionResetError/RemoteDisconnected that urlopen raises
            # unwrapped while reading the response are all OSError subclasses.
            return False

    def generate_mask(
        self,
        video_path: str,
        output_dir: str,
        start_frame: int,
        end_frame: int,
        conf_threshold: float,
        iou_threshold: float,
        mask_scale: float,
    ) -> str:
        """
        Request mask generation, starting the server first if needed.

        All arguments are forwarded unchanged as the JSON body of
        ``POST /generate``.

        Returns:
            task_id (str): the ``id`` field of the server's JSON reply.

        Raises:
            RuntimeError: If the server cannot be started, or replies with an
                HTTP error status (the reply body is included in the message).
            urllib.error.URLError: If the server cannot be reached.
        """
        if not self.is_server_running():
            self.start_server()

        data = {
            "video_path": video_path,
            "output_dir": output_dir,
            "start_frame": start_frame,
            "end_frame": end_frame,
            "conf_threshold": conf_threshold,
            "iou_threshold": iou_threshold,
            "mask_scale": mask_scale,
        }

        req = urllib.request.Request(
            f"{self.SERVER_URL}/generate",
            data=json.dumps(data).encode('utf-8'),
            headers={'Content-Type': 'application/json'},
            method='POST',
        )

        # NOTE(review): no timeout here, so this blocks until the server replies.
        try:
            with urllib.request.urlopen(req) as response:
                result = json.loads(response.read().decode('utf-8'))
        except urllib.error.HTTPError as e:
            # Surface the server's error body; keep the HTTPError as the cause.
            detail = e.read().decode('utf-8', errors='replace')
            raise RuntimeError(f"Server error: {detail}") from e
        return result['id']

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Get the status of a task via ``GET /tasks/<task_id>``.

        Returns the server's decoded JSON reply, or ``{"status": "unknown"}``
        if the server answers with an HTTP error status. Connection failures
        are not caught and propagate as ``urllib.error.URLError``.
        """
        try:
            with urllib.request.urlopen(f"{self.SERVER_URL}/tasks/{task_id}") as response:
                return json.loads(response.read().decode('utf-8'))
        except urllib.error.HTTPError:
            return {"status": "unknown"}

    def cancel_task(self, task_id: str):
        """Ask the server to cancel a task via ``POST /tasks/<task_id>/cancel``.

        HTTP error replies (e.g. an unknown task) are ignored. Connection
        failures propagate as ``urllib.error.URLError``.
        """
        try:
            req = urllib.request.Request(
                f"{self.SERVER_URL}/tasks/{task_id}/cancel",
                method='POST',
            )
            with urllib.request.urlopen(req):
                pass
        except urllib.error.HTTPError:
            pass
|
|
|
|
|
|
# Singleton: one shared InferenceClient, created lazily on first request.
_client: Optional[InferenceClient] = None


def get_client() -> InferenceClient:
    """Return the module-wide InferenceClient, creating it on first call."""
    global _client
    if _client is not None:
        return _client
    _client = InferenceClient()
    return _client
|