Files
duck-preview/app/inference/worker_manager.py
bartool e9b474b1ed feat: Add video playback functionality and inference support
- Introduced VideoPlayer class to handle local video playback, emitting frames via frame_ready signal.
- Updated MainWindow to switch between camera and video sources, integrating video playback controls.
- Enhanced AppMenuBar with options to open video files and manage inference models.
- Implemented BboxOverlay for displaying detection results on video frames.
- Added InferenceManager to manage YOLO inference in a separate process, with error handling and restart logic.
- Created tests for BboxOverlay and InferenceManager to ensure functionality and robustness.
- Updated pyproject.toml to include optional dependencies for inference support.
2026-05-13 21:30:13 +02:00

351 lines
12 KiB
Python

"""InferenceManager — orchestrates the YOLO worker process from the GUI thread.
Responsibilities:
- Start / stop the worker process
- Submit frames (with drop-if-busy logic)
- Poll result queue via QTimer (never blocks the GUI thread)
- Watch process health via QTimer (auto-restart on crash)
- Emit Qt signals with results for BboxOverlay
"""
from __future__ import annotations

import logging
import multiprocessing
import queue
import time
from pathlib import Path

from PySide6.QtCore import QObject, QTimer, Signal, Slot
from PySide6.QtMultimedia import QVideoFrame

from app.config import (
    INFERENCE_MAX_RESTARTS,
    INFERENCE_POLL_INTERVAL_MS,
    INFERENCE_WATCHDOG_INTERVAL_MS,
    INFERENCE_WORKER_TIMEOUT_S,
)
from app.inference.bbox_overlay import Detection
from app.inference.worker import FramePacket, ResultPacket, run_worker
# Module-level logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(__name__)
class InferenceManager(QObject):
    """
    Manages the YOLO worker subprocess.

    Signals:
        detections_ready(detections, source_size)
            Emitted in the GUI thread when a result arrives.
            detections  : list[Detection]
            source_size : tuple[int, int] — (width, height) of inferred frame
        inference_started()  — the worker process has been spawned; the model
                               is loaded inside the worker afterwards, so it
                               may not be ready yet
        inference_stopped()  — the worker was shut down by stop() or by
                               start() replacing it; emitted even when no
                               worker was running
        inference_error(str) — the model file does not exist, or the worker
                               kept failing after INFERENCE_MAX_RESTARTS
                               automatic restarts

    Usage:
        mgr = InferenceManager(parent=self)
        mgr.detections_ready.connect(bbox_overlay.on_detections)
        mgr.start("path/to/model.pt")
        # ...
        mgr.submit_frame(video_frame)   # called by FrameDispatcher subscriber
        # ...
        mgr.stop()
    """

    detections_ready = Signal(object, object)  # list[Detection], tuple[int, int]
    inference_started = Signal()
    inference_stopped = Signal()
    inference_error = Signal(str)

    def __init__(self, parent: QObject | None = None) -> None:
        super().__init__(parent)
        self._model_path: str | None = None
        self._process: multiprocessing.Process | None = None
        self._input_queue: multiprocessing.Queue | None = None
        self._output_queue: multiprocessing.Queue | None = None
        self._stop_event: multiprocessing.Event | None = None

        # Drop-if-busy flag — True from a successful submit until its result
        # (or a worker teardown) arrives
        self._busy: bool = False
        self._frame_id: int = 0

        # Restart tracking; reset only by start()/stop()
        self._restart_count: int = 0
        # Reference time for the hang watchdog: refreshed when the worker is
        # started, when a frame is submitted and when a result arrives
        self._last_result_time: float = 0.0

        # Paused flag — inference can be suspended without stopping the process
        self._paused: bool = False

        # QTimers (GUI thread)
        self._poll_timer = QTimer(self)
        self._poll_timer.setInterval(INFERENCE_POLL_INTERVAL_MS)
        self._poll_timer.timeout.connect(self._poll_output)

        self._watchdog_timer = QTimer(self)
        self._watchdog_timer.setInterval(INFERENCE_WATCHDOG_INTERVAL_MS)
        self._watchdog_timer.timeout.connect(self._watchdog_check)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def start(self, model_path: str) -> None:
        """
        Load *model_path* in a fresh worker process.

        Any running worker is stopped first. If the file does not exist,
        ``inference_error`` is emitted and nothing else changes.
        """
        if not Path(model_path).exists():
            msg = f"Model file not found: {model_path}"
            logger.error(msg)
            self.inference_error.emit(msg)
            return

        self._stop_worker()
        self._model_path = model_path
        self._restart_count = 0
        self._paused = False
        self._start_worker()

    def stop(self) -> None:
        """Stop the worker process and reset state."""
        self._stop_worker()
        self._model_path = None
        self._restart_count = 0
        self._paused = False

    def pause(self) -> None:
        """Suspend frame submission without stopping the process."""
        self._paused = True
        logger.debug("InferenceManager: paused")

    def resume(self) -> None:
        """Resume frame submission."""
        self._paused = False
        logger.debug("InferenceManager: resumed")

    @property
    def is_running(self) -> bool:
        """True while a worker process exists and is alive."""
        return self._process is not None and self._process.is_alive()

    @property
    def is_paused(self) -> bool:
        """True between pause() and the next resume()/start()/stop()."""
        return self._paused

    @property
    def model_path(self) -> str | None:
        """Path passed to the last successful start(), or None after stop()."""
        return self._model_path

    @Slot(QVideoFrame)
    def submit_frame(self, frame: QVideoFrame) -> None:
        """
        Attempt to submit a frame for inference.

        Plane 0 of *frame* is copied, per-row padding (``bytesPerLine``) is
        stripped and 4-byte pixels are reordered from BGRA to RGB, so the
        worker always receives tightly packed 3-channel data.

        Drops the frame silently if:
          - manager is not running
          - manager is paused
          - worker is still busy with previous frame (drop_if_busy)
        and drops it with a warning if it cannot be mapped, has an
        unrecognised memory layout, or cannot be enqueued.
        """
        if not self.is_running or self._paused or self._busy:
            return
        if not frame.isValid():
            return

        # Map frame to read-only memory, copy raw bytes, unmap
        if not frame.map(QVideoFrame.MapMode.ReadOnly):
            logger.warning("InferenceManager: failed to map QVideoFrame")
            return
        try:
            width = frame.width()
            height = frame.height()
            # Row stride must be read while the frame is still mapped
            stride = frame.bytesPerLine(0)
            raw = bytes(frame.bits(0))  # plane 0 — copies data
        finally:
            frame.unmap()

        if not raw:
            return

        try:
            rgb = self._pack_rgb(raw, width, height, stride)
        except (ImportError, ValueError) as exc:
            logger.warning("Frame colour conversion failed: %s", exc)
            return
        if rgb is None:
            logger.warning(
                "Unexpected frame size: %d bytes for %dx%d",
                len(raw), width, height,
            )
            return

        self._frame_id += 1
        packet = FramePacket(
            frame_id=self._frame_id,
            raw_bytes=rgb,
            width=width,
            height=height,
            channels=3,
        )
        try:
            self._input_queue.put_nowait(packet)
        except Exception as exc:
            # Best effort: a frame we cannot enqueue is simply dropped
            logger.warning("InferenceManager: could not enqueue frame: %s", exc)
            return
        self._busy = True
        # Measure the hang timeout from this submit; otherwise an idle or
        # paused period before it would count against a healthy worker.
        self._last_result_time = time.monotonic()
        logger.debug("InferenceManager: submitted frame %d", self._frame_id)

    # ------------------------------------------------------------------
    # Private — frame conversion
    # ------------------------------------------------------------------

    @staticmethod
    def _pack_rgb(
        raw: bytes, width: int, height: int, stride: int
    ) -> bytes | None:
        """
        Convert a copied plane-0 buffer into tightly packed 3-byte pixels.

        Args:
            raw: bytes copied from plane 0 of a mapped frame.
            width: frame width in pixels.
            height: frame height in pixels.
            stride: bytes per row (``bytesPerLine(0)``); a value <= 0 means
                "unknown", in which case rows are assumed to be unpadded.

        Returns:
            ``width * height * 3`` bytes, or None when the buffer does not
            look like 3- or 4-byte-per-pixel data of the given size.
            4-byte pixels are treated as BGRA and reordered to RGB; 3-byte
            pixels keep their original channel order.

        Raises:
            ImportError: numpy is unavailable (4-byte path only).
        """
        if width <= 0 or height <= 0:
            return None
        if stride <= 0:
            stride = len(raw) // height

        if stride >= width * 4:
            bytes_per_pixel = 4
        elif stride >= width * 3:
            bytes_per_pixel = 3
        else:
            return None

        row_bytes = width * bytes_per_pixel
        # The final row may legitimately omit its trailing padding
        if len(raw) < stride * (height - 1) + row_bytes:
            return None

        if stride == row_bytes:
            packed = raw[: row_bytes * height]
        else:
            view = memoryview(raw)
            packed = b"".join(
                view[row * stride : row * stride + row_bytes]
                for row in range(height)
            )

        if bytes_per_pixel == 3:
            return packed

        import numpy as np  # noqa: PLC0415 — optional inference dependency

        pixels = np.frombuffer(packed, dtype=np.uint8).reshape((height, width, 4))
        # Qt delivers BGRA → pick channels 2, 1, 0 to get RGB
        return pixels[:, :, [2, 1, 0]].tobytes()

    # ------------------------------------------------------------------
    # Private — worker lifecycle
    # ------------------------------------------------------------------

    def _start_worker(self) -> None:
        """Spawn a worker with fresh IPC queues and start both timers."""
        ctx = multiprocessing.get_context("spawn")
        # maxsize=1: at most one frame is ever in flight (drop-if-busy)
        self._input_queue = ctx.Queue(maxsize=1)
        self._output_queue = ctx.Queue(maxsize=4)
        self._stop_event = ctx.Event()

        self._process = ctx.Process(
            target=run_worker,
            args=(
                self._model_path,
                self._input_queue,
                self._output_queue,
                self._stop_event,
                logging.WARNING,
            ),
            daemon=True,
            name="inference-worker",
        )
        self._process.start()
        self._busy = False
        self._last_result_time = time.monotonic()

        self._poll_timer.start()
        self._watchdog_timer.start()
        logger.info(
            "Inference worker started (pid=%d, model=%s)",
            self._process.pid, self._model_path,
        )
        self.inference_started.emit()

    def _stop_worker(self) -> None:
        """
        Ask the worker to exit, force it if it does not, and drop its handles.

        Always logs and emits ``inference_stopped``, even if no worker was
        running.
        """
        self._poll_timer.stop()
        self._watchdog_timer.stop()

        if self._stop_event is not None:
            self._stop_event.set()

        process = self._process
        if process is not None:
            process.join(timeout=3.0)
            if process.is_alive():
                logger.warning("Worker did not stop cleanly — terminating")
                self._terminate_process(process)

        self._release_worker()
        logger.info("Inference worker stopped")
        self.inference_stopped.emit()

    @staticmethod
    def _terminate_process(process: multiprocessing.Process) -> None:
        """Terminate *process*; escalate to kill() if it survives terminate()."""
        process.terminate()
        process.join(timeout=2.0)
        if process.is_alive():
            logger.warning("Worker ignored terminate() — killing")
            process.kill()
            process.join(timeout=2.0)

    def _release_worker(self) -> None:
        """Close the current worker's IPC queues and forget every handle."""
        for ipc_queue in (self._input_queue, self._output_queue):
            if ipc_queue is None:
                continue
            # The worker has been stopped, so anything still buffered here
            # will never be read; cancel_join_thread() keeps the queue's
            # feeder thread from blocking interpreter exit while flushing it.
            ipc_queue.cancel_join_thread()
            ipc_queue.close()
        self._process = None
        self._input_queue = None
        self._output_queue = None
        self._stop_event = None
        self._busy = False

    # ------------------------------------------------------------------
    # Private — timers
    # ------------------------------------------------------------------

    @Slot()
    def _poll_output(self) -> None:
        """
        Drain every result currently in the output queue without blocking.

        Called every INFERENCE_POLL_INTERVAL_MS ms by the poll timer. A
        ``None`` item is the worker's model-load-failure marker and is
        treated as a crash.
        """
        while True:
            # Re-read each time: a slot connected to detections_ready may
            # have stopped or restarted the worker synchronously.
            output_queue = self._output_queue
            if output_queue is None:
                return
            try:
                item = output_queue.get_nowait()
            except queue.Empty:
                return

            if item is None:
                # Worker signalled a fatal load error
                logger.error("Worker reported model load failure")
                self._handle_crash("Model failed to load in worker process")
                return

            packet: ResultPacket = item
            self._busy = False
            self._last_result_time = time.monotonic()

            detections = [
                Detection(x1, y1, x2, y2, conf, label)
                for x1, y1, x2, y2, conf, label in packet.detections
            ]
            source_size = (packet.width, packet.height)
            logger.debug(
                "InferenceManager: frame %d: %d detections",
                packet.frame_id, len(detections),
            )
            self.detections_ready.emit(detections, source_size)

    @Slot()
    def _watchdog_check(self) -> None:
        """Detect a worker that died, or one that has been busy too long."""
        process = self._process
        if process is None:
            return

        # Process died unexpectedly
        if not process.is_alive():
            exit_code = process.exitcode
            logger.error("Worker process died (exitcode=%s)", exit_code)
            self._handle_crash(f"Worker process exited with code {exit_code}")
            return

        # Worker alive but hasn't answered the in-flight frame (hung)
        if not self._busy:
            return
        elapsed = time.monotonic() - self._last_result_time
        if elapsed > INFERENCE_WORKER_TIMEOUT_S:
            logger.error(
                "Worker timeout: no response for %.1f s — restarting", elapsed
            )
            self._terminate_process(process)
            self._handle_crash("Worker timed out (hung during inference)")

    def _handle_crash(self, reason: str) -> None:
        """
        Tear down the failed worker, then auto-restart it or give up.

        A restart is made while fewer than INFERENCE_MAX_RESTARTS restarts
        have happened since the last start()/stop(); otherwise
        ``inference_error`` is emitted and no worker is left running.
        """
        self._poll_timer.stop()
        self._watchdog_timer.stop()

        process = self._process
        if process is not None and process.is_alive():
            # Reached from _poll_output on a model-load failure, where the
            # worker process may not have exited yet.
            self._terminate_process(process)
        self._release_worker()

        if self._restart_count < INFERENCE_MAX_RESTARTS:
            self._restart_count += 1
            logger.warning(
                "Auto-restarting worker (attempt %d/%d): %s",
                self._restart_count, INFERENCE_MAX_RESTARTS, reason,
            )
            self._start_worker()
        else:
            msg = (
                f"Inference worker failed after {INFERENCE_MAX_RESTARTS} restarts. "
                f"Last error: {reason}"
            )
            logger.error(msg)
            self.inference_error.emit(msg)