- Introduced VideoPlayer class to handle local video playback, emitting frames via frame_ready signal. - Updated MainWindow to switch between camera and video sources, integrating video playback controls. - Enhanced AppMenuBar with options to open video files and manage inference models. - Implemented BboxOverlay for displaying detection results on video frames. - Added InferenceManager to manage YOLO inference in a separate process, with error handling and restart logic. - Created tests for BboxOverlay and InferenceManager to ensure functionality and robustness. - Updated pyproject.toml to include optional dependencies for inference support.
197 lines
6.0 KiB
Python
197 lines
6.0 KiB
Python
"""YOLO inference worker — runs in a separate process.

This module contains only plain functions (no Qt, no PySide6) so it can
safely be imported and executed in a child process via multiprocessing.

IPC protocol
------------
input_queue receives : FramePacket (frame_id, raw_bytes, width, height, channels)
output_queue sends   : ResultPacket (frame_id, detections, width, height),
                       or a single None if the model failed to load
stop_event           : multiprocessing.Event — set by the parent to request a clean exit

Detection format (plain tuple):
    (x1, y1, x2, y2, conf, label) — x/y/conf are floats, label is a str;
    x/y are in source-frame pixels
"""
|
|
|
|
from __future__ import annotations

import functools
import logging
import platform
import sys
from multiprocessing import Event, Queue
from queue import Empty
from typing import NamedTuple
|
|
|
|
# Module logger. Records propagate to the root logger, which run_worker()
# configures at startup via _configure_worker_logging().
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Data structures shared between worker and manager
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class FramePacket(NamedTuple):
    """One video frame sent from the manager to the worker via input_queue."""

    frame_id: int  # echoed back unchanged in the matching ResultPacket
    raw_bytes: bytes  # RGB bytes, row-major, shape = (height, width, channels)
    width: int
    height: int
    channels: int  # always 3 (RGB)
|
|
|
|
|
|
class ResultPacket(NamedTuple):
    """Detections for one frame, sent from the worker back via output_queue."""

    frame_id: int  # frame_id of the FramePacket this result answers
    detections: list  # list of (x1, y1, x2, y2, conf, label) tuples
    width: int  # source frame width (for overlay scaling)
    height: int  # source frame height
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Worker entry point
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def run_worker(
    model_path: str,
    input_queue: Queue,
    output_queue: Queue,
    stop_event: Event,
    log_level: int = logging.WARNING,
) -> None:
    """
    Main loop of the inference worker process.

    Loads the YOLO model once, then processes frames from input_queue
    until stop_event is set. Results are posted to output_queue.

    This function is designed to be the target of multiprocessing.Process.
    It must NOT import PySide6 or any Qt module.

    Args:
        model_path: Model path/name handed to ultralytics.YOLO.
        input_queue: Source of FramePacket items.
        output_queue: Receives one ResultPacket per FramePacket consumed
            (an empty-detection packet if inference failed). If the model
            cannot be loaded, a single None is posted instead and the
            function returns immediately.
        stop_event: Checked before each poll of input_queue; once set, the
            loop exits.
        log_level: Logging level applied in the worker process.
    """
    _configure_worker_logging(log_level)
    logger.info("Inference worker starting (pid=%d)", _getpid())

    try:
        model = _load_model(model_path)
    except Exception as exc:
        # logger.exception keeps the traceback, which is otherwise very hard
        # to recover from a child process.
        logger.exception("Failed to load model '%s': %s", model_path, exc)
        # Signal failure by putting None — manager treats it as error
        output_queue.put(None)
        return

    logger.info("Model loaded: %s", model_path)

    while not stop_event.is_set():
        try:
            packet: FramePacket = input_queue.get(timeout=0.1)
        except Empty:
            # Short timeout so stop_event is re-checked regularly.
            continue
        except Exception as exc:
            logger.exception("Error reading input queue: %s", exc)
            break

        try:
            result = _infer(model, packet)
        except Exception as exc:
            logger.exception("Inference error (frame %d): %s", packet.frame_id, exc)
            # Post an empty result so the manager knows we're still alive
            result = ResultPacket(
                frame_id=packet.frame_id,
                detections=[],
                width=packet.width,
                height=packet.height,
            )
        output_queue.put(result)

    logger.info("Inference worker stopping")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _load_model(model_path: str):
    """Load the YOLO model and warm it up on the device from _select_device().

    The warm-up pass is best-effort: any failure is logged as a warning and
    the model is still returned.
    """
    from ultralytics import YOLO  # noqa: PLC0415

    target = _select_device()
    logger.info("Loading YOLO model on device='%s'", target)
    yolo = YOLO(model_path)

    # Warm-up: one throwaway inference on a blank 64x64 frame so kernel
    # compilation happens here rather than on the first real frame.
    try:
        import numpy as np  # noqa: PLC0415

        yolo(np.zeros((64, 64, 3), dtype=np.uint8), device=target, verbose=False)
    except Exception as err:
        logger.warning("Warm-up failed (non-fatal): %s", err)
    return yolo
|
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _select_device() -> str:
    """
    Choose inference device.

    Priority:
      - macOS  → "mps" if available (Metal GPU), else "cpu"
      - others → "cpu"

    The result is memoised for the life of the process. _infer() calls this
    for every frame; without the cache, each frame would repeat the torch/MPS
    probe and emit an INFO log line.
    """
    if platform.system() == "Darwin":
        try:
            import torch  # noqa: PLC0415

            if torch.backends.mps.is_available():
                logger.info("MPS (Metal) available — using GPU")
                return "mps"
        except Exception as exc:
            # Best effort: torch missing or the MPS probe failed → fall back
            # to CPU, but leave a trace instead of swallowing silently.
            logger.debug("MPS probe failed: %s", exc)
    logger.info("MPS not available — using CPU")
    return "cpu"
|
|
|
|
|
|
def _infer(model, packet: FramePacket) -> ResultPacket:
    """Run the model on one frame and package the detections.

    Args:
        model: Loaded ultralytics YOLO model (called with a numpy image).
        packet: Frame to analyse; raw_bytes holds RGB pixels laid out as
            (height, width, channels).

    Returns:
        ResultPacket carrying the packet's frame_id/width/height and one
        (x1, y1, x2, y2, conf, label) tuple per detected box, with
        coordinates in source-frame pixels.

    Raises:
        ValueError: if raw_bytes cannot be reshaped to
            (height, width, channels).
    """
    import numpy as np  # noqa: PLC0415

    frame_np = np.frombuffer(packet.raw_bytes, dtype=np.uint8).reshape(
        (packet.height, packet.width, packet.channels)
    )
    if packet.channels == 3:
        # Ultralytics treats numpy images as BGR (OpenCV convention), while
        # the packet carries RGB. Flip the channel order. ascontiguousarray
        # also turns the read-only, negatively-strided view into a writable
        # C-contiguous copy that OpenCV-based preprocessing accepts.
        frame_np = np.ascontiguousarray(frame_np[..., ::-1])

    device = _select_device()
    results = model(frame_np, device=device, verbose=False)

    detections = []
    for r in results:
        boxes = r.boxes
        if boxes is None:
            continue
        names = r.names
        # Convert each column to Python lists once per result, instead of
        # indexing the tensors per box (each index may force a separate
        # device-to-host read on GPU/MPS tensors).
        for (x1, y1, x2, y2), conf, cls in zip(
            boxes.xyxy.tolist(), boxes.conf.tolist(), boxes.cls.tolist()
        ):
            cls_idx = int(cls)
            label = names[cls_idx] if names and cls_idx in names else str(cls_idx)
            detections.append(
                (float(x1), float(y1), float(x2), float(y2), float(conf), label)
            )

    return ResultPacket(
        frame_id=packet.frame_id,
        detections=detections,
        width=packet.width,
        height=packet.height,
    )
|
|
|
|
|
|
def _configure_worker_logging(level: int) -> None:
    """Configure logging for the worker process.

    If the root logger has no handlers yet, installs a stderr handler whose
    format is tagged with the worker PID. In every case, *level* is applied
    to the root logger.

    Args:
        level: Logging level (e.g. logging.INFO) for the worker process.
    """
    logging.basicConfig(
        level=level,
        format="[worker %(process)d] %(levelname)s %(name)s: %(message)s",
        stream=sys.stderr,
    )
    # basicConfig() does nothing at all when the root logger already has
    # handlers — e.g. when the process was forked and inherited the parent's
    # logging setup. In that case the requested level would be silently
    # ignored, so apply it explicitly.
    logging.getLogger().setLevel(level)
|
|
|
|
|
|
def _getpid() -> int:
    """Return the ID of the current process (logged when the worker starts)."""
    from os import getpid  # noqa: PLC0415

    pid = getpid()
    return pid
|