mp4保存
This commit is contained in:
parent
eeb8400727
commit
c15cd659e3
|
|
@ -5,15 +5,15 @@ Manages the server process and handles HTTP communication
|
|||
using standard library (avoiding requests dependency).
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import time
|
||||
import json
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
import threading
|
||||
import os
|
||||
import signal
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
|
||||
class InferenceClient:
|
||||
|
|
@ -46,24 +46,24 @@ class InferenceClient:
|
|||
# Load environment variables from .env file if it exists
|
||||
env_file = os.path.join(root_dir, ".env")
|
||||
if os.path.exists(env_file):
|
||||
with open(env_file, 'r') as f:
|
||||
with open(env_file, "r") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#') and '=' in line:
|
||||
key, value = line.split('=', 1)
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
key, value = line.split("=", 1)
|
||||
server_env[key] = value
|
||||
print(f"[FaceMask] Loaded environment from: {env_file}")
|
||||
|
||||
# Clean PYTHONPATH to avoid conflicts with Nix Python packages
|
||||
# Only include project root to allow local imports
|
||||
server_env['PYTHONPATH'] = root_dir
|
||||
server_env["PYTHONPATH"] = root_dir
|
||||
|
||||
# Remove Python-related environment variables that might cause conflicts
|
||||
# These can cause venv to import packages from Nix instead of venv
|
||||
env_vars_to_remove = [
|
||||
'PYTHONUNBUFFERED',
|
||||
'__PYVENV_LAUNCHER__', # macOS venv variable
|
||||
'VIRTUAL_ENV', # Will be set by venv's Python automatically
|
||||
"PYTHONUNBUFFERED",
|
||||
"__PYVENV_LAUNCHER__", # macOS venv variable
|
||||
"VIRTUAL_ENV", # Will be set by venv's Python automatically
|
||||
]
|
||||
for var in env_vars_to_remove:
|
||||
server_env.pop(var, None)
|
||||
|
|
@ -73,25 +73,27 @@ class InferenceClient:
|
|||
if os.path.isdir(venv_bin):
|
||||
# Build a clean PATH with venv first, then essential system paths
|
||||
# Filter out any Nix Python-specific paths to avoid version conflicts
|
||||
current_path = server_env.get('PATH', '')
|
||||
path_entries = current_path.split(':')
|
||||
current_path = server_env.get("PATH", "")
|
||||
path_entries = current_path.split(":")
|
||||
|
||||
# Filter out Nix Python 3.11 paths
|
||||
filtered_paths = [
|
||||
p for p in path_entries
|
||||
if not ('/python3.11/' in p.lower() or '/python3-3.11' in p.lower())
|
||||
p
|
||||
for p in path_entries
|
||||
if not ("/python3.11/" in p.lower() or "/python3-3.11" in p.lower())
|
||||
]
|
||||
|
||||
# Reconstruct PATH with venv first
|
||||
clean_path = ':'.join([venv_bin] + filtered_paths)
|
||||
server_env['PATH'] = clean_path
|
||||
clean_path = ":".join([venv_bin] + filtered_paths)
|
||||
server_env["PATH"] = clean_path
|
||||
print(f"[FaceMask] Using venv from: {venv_bin}")
|
||||
|
||||
# Prepare log file for server output
|
||||
import tempfile
|
||||
|
||||
log_dir = tempfile.gettempdir()
|
||||
self.log_file_path = os.path.join(log_dir, "facemask_server.log")
|
||||
self.log_file = open(self.log_file_path, 'w', buffering=1) # Line buffered
|
||||
self.log_file = open(self.log_file_path, "w", buffering=1) # Line buffered
|
||||
print(f"[FaceMask] Server log: {self.log_file_path}")
|
||||
|
||||
# Start process with 'python' command (will use venv if PATH is set correctly)
|
||||
|
|
@ -120,12 +122,12 @@ class InferenceClient:
|
|||
try:
|
||||
if self.log_file:
|
||||
self.log_file.close()
|
||||
with open(self.log_file_path, 'r') as f:
|
||||
with open(self.log_file_path, "r") as f:
|
||||
log_content = f.read()
|
||||
if log_content.strip():
|
||||
print("[FaceMask] Server log:")
|
||||
# Show last 50 lines
|
||||
lines = log_content.strip().split('\n')
|
||||
lines = log_content.strip().split("\n")
|
||||
for line in lines[-50:]:
|
||||
print(line)
|
||||
except Exception as e:
|
||||
|
|
@ -143,11 +145,11 @@ class InferenceClient:
|
|||
try:
|
||||
if self.log_file:
|
||||
self.log_file.close()
|
||||
with open(self.log_file_path, 'r') as f:
|
||||
with open(self.log_file_path, "r") as f:
|
||||
log_content = f.read()
|
||||
if log_content.strip():
|
||||
print("[FaceMask] Server log (partial):")
|
||||
lines = log_content.strip().split('\n')
|
||||
lines = log_content.strip().split("\n")
|
||||
for line in lines[-30:]:
|
||||
print(line)
|
||||
except Exception:
|
||||
|
|
@ -179,7 +181,9 @@ class InferenceClient:
|
|||
def is_server_running(self) -> bool:
|
||||
"""Check if server is responding."""
|
||||
try:
|
||||
with urllib.request.urlopen(f"{self.SERVER_URL}/status", timeout=1) as response:
|
||||
with urllib.request.urlopen(
|
||||
f"{self.SERVER_URL}/status", timeout=1
|
||||
) as response:
|
||||
return response.status == 200
|
||||
except (urllib.error.URLError, ConnectionRefusedError, TimeoutError):
|
||||
return False
|
||||
|
|
@ -215,23 +219,25 @@ class InferenceClient:
|
|||
|
||||
req = urllib.request.Request(
|
||||
f"{self.SERVER_URL}/generate",
|
||||
data=json.dumps(data).encode('utf-8'),
|
||||
headers={'Content-Type': 'application/json'},
|
||||
method='POST'
|
||||
data=json.dumps(data).encode("utf-8"),
|
||||
headers={"Content-Type": "application/json"},
|
||||
method="POST",
|
||||
)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req) as response:
|
||||
result = json.loads(response.read().decode('utf-8'))
|
||||
return result['id']
|
||||
result = json.loads(response.read().decode("utf-8"))
|
||||
return result["id"]
|
||||
except urllib.error.HTTPError as e:
|
||||
raise RuntimeError(f"Server error: {e.read().decode('utf-8')}")
|
||||
|
||||
def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""Get status of a task."""
|
||||
try:
|
||||
with urllib.request.urlopen(f"{self.SERVER_URL}/tasks/{task_id}") as response:
|
||||
return json.loads(response.read().decode('utf-8'))
|
||||
with urllib.request.urlopen(
|
||||
f"{self.SERVER_URL}/tasks/{task_id}"
|
||||
) as response:
|
||||
return json.loads(response.read().decode("utf-8"))
|
||||
except urllib.error.HTTPError:
|
||||
return {"status": "unknown"}
|
||||
|
||||
|
|
@ -239,8 +245,7 @@ class InferenceClient:
|
|||
"""Cancel a task."""
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
f"{self.SERVER_URL}/tasks/{task_id}/cancel",
|
||||
method='POST'
|
||||
f"{self.SERVER_URL}/tasks/{task_id}/cancel", method="POST"
|
||||
)
|
||||
with urllib.request.urlopen(req):
|
||||
pass
|
||||
|
|
@ -251,6 +256,7 @@ class InferenceClient:
|
|||
# Singleton
|
||||
_client: Optional[InferenceClient] = None
|
||||
|
||||
|
||||
def get_client() -> InferenceClient:
|
||||
global _client
|
||||
if _client is None:
|
||||
|
|
|
|||
|
|
@ -48,8 +48,8 @@
|
|||
export ROCM_PATH="${pkgs.rocmPackages.clr}"
|
||||
export HSA_OVERRIDE_GFX_VERSION="11.0.0" # RX 7900 (RDNA 3 / gfx1100)
|
||||
|
||||
# LD_LIBRARY_PATH: ROCm、C++標準ライブラリ、その他必要なライブラリ
|
||||
export LD_LIBRARY_PATH="${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib:${pkgs.zstd.out}/lib:${pkgs.rocmPackages.clr}/lib:${pkgs.rocmPackages.rocm-runtime}/lib:$LD_LIBRARY_PATH"
|
||||
# LD_LIBRARY_PATH: ROCm libraries FIRST (critical for GPU inference)
|
||||
export LD_LIBRARY_PATH="${pkgs.rocmPackages.clr}/lib:${pkgs.rocmPackages.rocm-runtime}/lib:${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib:${pkgs.zstd.out}/lib:$LD_LIBRARY_PATH"
|
||||
|
||||
# venvのセットアップ
|
||||
VENV_DIR="$PWD/.venv"
|
||||
|
|
@ -89,12 +89,11 @@
|
|||
export BLENDER_USER_ADDONS="$BLENDER_USER_SCRIPTS/addons"
|
||||
|
||||
# 環境変数をファイルに保存(サーバープロセス用)
|
||||
# CRITICAL: ROCm library paths MUST come first for GPU inference
|
||||
cat > "$PWD/.env" << EOF
|
||||
LD_LIBRARY_PATH=${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib:${pkgs.zstd.out}/lib:${pkgs.rocmPackages.clr}/lib:${pkgs.rocmPackages.rocm-runtime}/lib
|
||||
LD_LIBRARY_PATH=${pkgs.rocmPackages.clr}/lib:${pkgs.rocmPackages.rocm-runtime}/lib:${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib:${pkgs.zstd.out}/lib
|
||||
ROCM_PATH=${pkgs.rocmPackages.clr}
|
||||
HSA_OVERRIDE_GFX_VERSION=11.0.0
|
||||
PYTORCH_ROCM_ARCH=gfx1100
|
||||
ROCBLAS_TENSILE_LIBPATH=${pkgs.rocmPackages.clr}/lib/rocblas/library
|
||||
EOF
|
||||
|
||||
echo "[Setup] Environment ready with GPU support"
|
||||
|
|
|
|||
|
|
@ -149,20 +149,40 @@ class SEQUENCER_OT_generate_face_mask(Operator):
|
|||
expected_frames: Number of frames expected
|
||||
|
||||
Returns:
|
||||
True if cache exists and has at least 90% of expected frames
|
||||
True if cache exists and is valid
|
||||
"""
|
||||
if not os.path.exists(cache_dir):
|
||||
return False
|
||||
|
||||
# Check for MP4 video (new format)
|
||||
mask_video = os.path.join(cache_dir, "mask.mp4")
|
||||
if os.path.exists(mask_video):
|
||||
# Verify video has expected number of frames
|
||||
import cv2
|
||||
cap = cv2.VideoCapture(mask_video)
|
||||
if cap.isOpened():
|
||||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
cap.release()
|
||||
# Accept cache if at least 90% of frames exist
|
||||
return frame_count >= expected_frames * 0.9
|
||||
cap.release()
|
||||
return False
|
||||
|
||||
# Fallback: check for PNG sequence (backward compatibility)
|
||||
mask_files = [f for f in os.listdir(cache_dir)
|
||||
if f.startswith("mask_") and f.endswith(".png")]
|
||||
|
||||
# Accept cache if at least 90% of frames exist
|
||||
# (some frames may have been skipped due to read errors)
|
||||
return len(mask_files) >= expected_frames * 0.9
|
||||
|
||||
def _add_mask_strip(self, context, source_strip_name: str, mask_dir: str):
|
||||
"""Add mask image sequence as a new strip."""
|
||||
def _add_mask_strip(self, context, source_strip_name: str, mask_path: str):
|
||||
"""Add mask video as a new strip.
|
||||
|
||||
Args:
|
||||
context: Blender context
|
||||
source_strip_name: Name of the source video strip
|
||||
mask_path: Path to mask video file or directory (for backward compatibility)
|
||||
"""
|
||||
scene = context.scene
|
||||
seq_editor = scene.sequence_editor
|
||||
|
||||
|
|
@ -174,17 +194,52 @@ class SEQUENCER_OT_generate_face_mask(Operator):
|
|||
if not source_strip:
|
||||
return
|
||||
|
||||
# Get first mask image
|
||||
mask_files = sorted([
|
||||
f for f in os.listdir(mask_dir)
|
||||
if f.startswith("mask_") and f.endswith(".png")
|
||||
])
|
||||
# Check if mask_path is a video file or directory (backward compatibility)
|
||||
if os.path.isfile(mask_path):
|
||||
# New format: single MP4 file
|
||||
mask_video = mask_path
|
||||
else:
|
||||
# Old format: directory with PNG sequence (backward compatibility)
|
||||
mask_video = os.path.join(mask_path, "mask.mp4")
|
||||
if not os.path.exists(mask_video):
|
||||
# Fallback to PNG sequence
|
||||
mask_files = sorted([
|
||||
f for f in os.listdir(mask_path)
|
||||
if f.startswith("mask_") and f.endswith(".png")
|
||||
])
|
||||
if not mask_files:
|
||||
return
|
||||
first_mask = os.path.join(mask_path, mask_files[0])
|
||||
self._add_mask_strip_png_sequence(context, source_strip_name, mask_path, mask_files, first_mask)
|
||||
return
|
||||
|
||||
if not mask_files:
|
||||
# Find an empty channel
|
||||
used_channels = {s.channel for s in seq_editor.strips}
|
||||
new_channel = source_strip.channel + 1
|
||||
while new_channel in used_channels:
|
||||
new_channel += 1
|
||||
|
||||
# Add movie strip (Blender 5.0 API)
|
||||
mask_strip = seq_editor.strips.new_movie(
|
||||
name=f"{source_strip_name}_mask",
|
||||
filepath=mask_video,
|
||||
channel=new_channel,
|
||||
frame_start=source_strip.frame_final_start,
|
||||
)
|
||||
|
||||
# Set blend mode for mask
|
||||
mask_strip.blend_type = 'ALPHA_OVER'
|
||||
mask_strip.blend_alpha = 0.5
|
||||
|
||||
def _add_mask_strip_png_sequence(self, context, source_strip_name, mask_dir, mask_files, first_mask):
|
||||
"""Backward compatibility: Add PNG sequence as mask strip."""
|
||||
scene = context.scene
|
||||
seq_editor = scene.sequence_editor
|
||||
source_strip = seq_editor.strips.get(source_strip_name)
|
||||
|
||||
if not source_strip:
|
||||
return
|
||||
|
||||
first_mask = os.path.join(mask_dir, mask_files[0])
|
||||
|
||||
# Find an empty channel
|
||||
used_channels = {s.channel for s in seq_editor.strips}
|
||||
new_channel = source_strip.channel + 1
|
||||
|
|
|
|||
104
server/main.py
104
server/main.py
|
|
@ -8,6 +8,29 @@ GPU-accelerated face detection using ONNX Runtime.
|
|||
import os
|
||||
import sys
|
||||
import platform
|
||||
|
||||
# CRITICAL: Fix LD_LIBRARY_PATH before importing cv2 or torch
|
||||
# cv2 adds its own lib path to the front, which can override ROCm libraries
|
||||
def fix_library_path():
|
||||
"""Ensure ROCm libraries are loaded before cv2's bundled libraries."""
|
||||
ld_path = os.environ.get('LD_LIBRARY_PATH', '')
|
||||
|
||||
# Split and filter paths
|
||||
paths = [p for p in ld_path.split(':') if p]
|
||||
|
||||
# Separate ROCm/GPU paths from other paths
|
||||
rocm_paths = [p for p in paths if 'rocm' in p.lower() or 'clr-' in p or 'hip' in p.lower()]
|
||||
other_paths = [p for p in paths if p not in rocm_paths]
|
||||
|
||||
# Rebuild with ROCm paths first
|
||||
if rocm_paths:
|
||||
new_ld_path = ':'.join(rocm_paths + other_paths)
|
||||
os.environ['LD_LIBRARY_PATH'] = new_ld_path
|
||||
print(f"[FaceMask] Fixed LD_LIBRARY_PATH to prioritize ROCm libraries")
|
||||
|
||||
# Fix library path BEFORE any other imports
|
||||
fix_library_path()
|
||||
|
||||
import threading
|
||||
import uuid
|
||||
import queue
|
||||
|
|
@ -61,7 +84,11 @@ class GenerateRequest(BaseModel):
|
|||
mask_scale: float = 1.5
|
||||
|
||||
def process_video_task(task_id: str, req: GenerateRequest):
|
||||
"""Background task to process video."""
|
||||
"""Background task to process video with async MP4 output."""
|
||||
writer = None
|
||||
write_queue = None
|
||||
writer_thread = None
|
||||
|
||||
try:
|
||||
tasks[task_id].status = TaskStatus.PROCESSING
|
||||
cancel_event = cancel_events.get(task_id)
|
||||
|
|
@ -78,7 +105,6 @@ def process_video_task(task_id: str, req: GenerateRequest):
|
|||
conf_threshold=req.conf_threshold,
|
||||
iou_threshold=req.iou_threshold
|
||||
)
|
||||
# Ensure model is loaded
|
||||
_ = detector.model
|
||||
|
||||
# Open video
|
||||
|
|
@ -88,7 +114,10 @@ def process_video_task(task_id: str, req: GenerateRequest):
|
|||
tasks[task_id].message = "Failed to open video"
|
||||
return
|
||||
|
||||
# Determine frame range
|
||||
# Get video properties
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
total_video_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
end_frame = min(req.end_frame, total_video_frames - 1)
|
||||
frames_to_process = end_frame - req.start_frame + 1
|
||||
|
|
@ -98,7 +127,37 @@ def process_video_task(task_id: str, req: GenerateRequest):
|
|||
# Ensure output directory exists
|
||||
os.makedirs(req.output_dir, exist_ok=True)
|
||||
|
||||
print(f"Starting processing: {req.video_path} ({frames_to_process} frames)")
|
||||
# Setup MP4 writer (grayscale)
|
||||
output_video_path = os.path.join(req.output_dir, "mask.mp4")
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
writer = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height), isColor=False)
|
||||
|
||||
if not writer.isOpened():
|
||||
tasks[task_id].status = TaskStatus.FAILED
|
||||
tasks[task_id].message = "Failed to create video writer"
|
||||
cap.release()
|
||||
return
|
||||
|
||||
# Async writer setup
|
||||
write_queue = queue.Queue(maxsize=30) # Buffer up to 30 frames
|
||||
writer_running = threading.Event()
|
||||
writer_running.set()
|
||||
|
||||
def async_writer():
|
||||
"""Background thread for writing frames to video."""
|
||||
while writer_running.is_set() or not write_queue.empty():
|
||||
try:
|
||||
mask = write_queue.get(timeout=0.1)
|
||||
if mask is not None:
|
||||
writer.write(mask)
|
||||
write_queue.task_done()
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
writer_thread = threading.Thread(target=async_writer, daemon=True)
|
||||
writer_thread.start()
|
||||
|
||||
print(f"Starting processing: {req.video_path} ({frames_to_process} frames) -> {output_video_path}")
|
||||
|
||||
# Process loop
|
||||
current_count = 0
|
||||
|
|
@ -123,22 +182,28 @@ def process_video_task(task_id: str, req: GenerateRequest):
|
|||
mask_scale=req.mask_scale
|
||||
)
|
||||
|
||||
# Save
|
||||
mask_filename = f"mask_{current_count:06d}.png" # Note: using relative index for filename
|
||||
mask_path = os.path.join(req.output_dir, mask_filename)
|
||||
cv2.imwrite(mask_path, mask)
|
||||
# Async write to queue
|
||||
write_queue.put(mask)
|
||||
|
||||
# Update progress
|
||||
current_count += 1
|
||||
tasks[task_id].progress = current_count
|
||||
|
||||
# Cleanup
|
||||
writer_running.clear()
|
||||
write_queue.join() # Wait for all frames to be written
|
||||
if writer_thread:
|
||||
writer_thread.join(timeout=5)
|
||||
|
||||
cap.release()
|
||||
if writer:
|
||||
writer.release()
|
||||
|
||||
if tasks[task_id].status == TaskStatus.PROCESSING:
|
||||
tasks[task_id].status = TaskStatus.COMPLETED
|
||||
tasks[task_id].result_path = req.output_dir
|
||||
tasks[task_id].result_path = output_video_path # Return video path
|
||||
tasks[task_id].message = "Processing completed successfully"
|
||||
print(f"Task {task_id} completed.")
|
||||
print(f"Task {task_id} completed: {output_video_path}")
|
||||
|
||||
except Exception as e:
|
||||
tasks[task_id].status = TaskStatus.FAILED
|
||||
|
|
@ -230,12 +295,21 @@ def log_startup_diagnostics():
|
|||
for var in rocm_vars:
|
||||
value = os.environ.get(var)
|
||||
if value:
|
||||
# Truncate very long values
|
||||
if len(value) > 200:
|
||||
display_value = value[:200] + "... (truncated)"
|
||||
# For LD_LIBRARY_PATH, show if ROCm paths are included
|
||||
if var == 'LD_LIBRARY_PATH':
|
||||
has_rocm = 'rocm' in value.lower() or 'clr-' in value
|
||||
has_hip = 'hip' in value.lower()
|
||||
print(f" {var}: {value[:100]}...")
|
||||
print(f" Contains ROCm paths: {has_rocm}")
|
||||
print(f" Contains HIP paths: {has_hip}")
|
||||
if not has_rocm:
|
||||
print(f" ⚠️ WARNING: ROCm library paths not found!")
|
||||
else:
|
||||
display_value = value
|
||||
print(f" {var}: {display_value}")
|
||||
if len(value) > 200:
|
||||
display_value = value[:200] + "... (truncated)"
|
||||
else:
|
||||
display_value = value
|
||||
print(f" {var}: {display_value}")
|
||||
else:
|
||||
print(f" {var}: (not set)")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user