"""InferenceManager — orchestrates the YOLO worker process from the GUI thread.

Responsibilities:
- Start / stop the worker process
- Submit frames (with drop-if-busy logic)
- Poll result queue via QTimer (never blocks the GUI thread)
- Watch process health via QTimer (auto-restart on crash)
- Emit Qt signals with results for BboxOverlay
"""

from __future__ import annotations

import logging
import multiprocessing
import queue
import time
from pathlib import Path

from PySide6.QtCore import QObject, QTimer, Signal, Slot
from PySide6.QtMultimedia import QVideoFrame

from app.config import (
    INFERENCE_MAX_RESTARTS,
    INFERENCE_POLL_INTERVAL_MS,
    INFERENCE_WATCHDOG_INTERVAL_MS,
    INFERENCE_WORKER_TIMEOUT_S,
)
from app.inference.bbox_overlay import Detection
from app.inference.worker import FramePacket, ResultPacket, run_worker

logger = logging.getLogger(__name__)


class InferenceManager(QObject):
    """
    Manages the YOLO worker subprocess.

    Signals:
        detections_ready(detections, source_size)
            Emitted in the GUI thread when a result arrives.
            detections  : list[Detection]
            source_size : tuple[int, int] — (width, height) of inferred frame
        detection_count_updated(int)
            Number of frames with at least one detection since the last start().
        inference_started()   — worker process launched (emitted right after
                                Process.start(); the model may still be loading)
        inference_stopped()   — worker shut down via stop()/start(); emitted
                                even if no worker was running
        inference_error(str)  — fatal error (model file missing, or max
                                restarts exceeded)

    Usage:
        mgr = InferenceManager(parent=self)
        mgr.detections_ready.connect(bbox_overlay.on_detections)
        mgr.start("path/to/model.pt")
        # ...
        mgr.submit_frame(video_frame)   # called by FrameDispatcher subscriber
        # ...
        mgr.stop()
    """

    detections_ready = Signal(object, object)  # list[Detection], tuple[int,int]
    detection_count_updated = Signal(int)  # total frames with detections so far
    inference_started = Signal()
    inference_stopped = Signal()
    inference_error = Signal(str)

    def __init__(self, parent: QObject | None = None) -> None:
        super().__init__(parent)

        self._model_path: str | None = None
        self._process: multiprocessing.Process | None = None
        self._input_queue: multiprocessing.Queue | None = None
        self._output_queue: multiprocessing.Queue | None = None
        self._stop_event: multiprocessing.Event | None = None

        # Drop-if-busy flag — True while worker is processing a frame
        self._busy: bool = False
        # monotonic() timestamp of the frame currently outstanding; the
        # watchdog measures "hung" time from here, not from the last result.
        self._busy_since: float = 0.0
        self._frame_id: int = 0

        # Restart tracking
        self._restart_count: int = 0
        self._last_result_time: float = 0.0

        # Paused flag — inference can be suspended without stopping the process
        self._paused: bool = False

        # Detection counter — frames on which at least one detection occurred
        self._detection_frame_count: int = 0

        # QTimers (GUI thread)
        self._poll_timer = QTimer(self)
        self._poll_timer.setInterval(INFERENCE_POLL_INTERVAL_MS)
        self._poll_timer.timeout.connect(self._poll_output)

        self._watchdog_timer = QTimer(self)
        self._watchdog_timer.setInterval(INFERENCE_WATCHDOG_INTERVAL_MS)
        self._watchdog_timer.timeout.connect(self._watchdog_check)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def start(self, model_path: str) -> None:
        """Load model and start the worker process.

        Emits inference_error (and does nothing else) if *model_path* does not
        exist. Otherwise any running worker is stopped first and counters are
        reset.
        """
        if not Path(model_path).exists():
            msg = f"Model file not found: {model_path}"
            logger.error(msg)
            self.inference_error.emit(msg)
            return

        self._stop_worker()
        self._model_path = model_path
        self._restart_count = 0
        self._paused = False
        self._detection_frame_count = 0
        self._start_worker()

    def stop(self) -> None:
        """Stop the worker process and reset state."""
        self._stop_worker()
        self._model_path = None
        self._restart_count = 0
        self._paused = False

    def pause(self) -> None:
        """Suspend frame submission without stopping the process."""
        self._paused = True
        logger.debug("InferenceManager: paused")

    def resume(self) -> None:
        """Resume frame submission."""
        self._paused = False
        logger.debug("InferenceManager: resumed")

    @property
    def is_running(self) -> bool:
        return self._process is not None and self._process.is_alive()

    @property
    def is_paused(self) -> bool:
        return self._paused

    @property
    def model_path(self) -> str | None:
        return self._model_path

    @Slot(QVideoFrame)
    def submit_frame(self, frame: QVideoFrame) -> None:
        """
        Attempt to submit a frame for inference.

        Drops the frame silently if:
        - manager is not running
        - manager is paused
        - worker is still busy with previous frame (drop_if_busy)

        Frame conversion strategy:
            Use QVideoFrame.toImage() → QImage.Format_RGB32 → bits().
            This handles all pixel formats (NV12, YUV420P, BGRA, MJPG, etc.)
            because Qt decodes them internally. The cost is a CPU colour-space
            conversion, but it only happens when the worker is idle (drop_if_busy).
        """
        if not self.is_running or self._paused or self._busy:
            return
        if not frame.isValid():
            return

        # Convert frame to RGB via Qt's built-in decoder.
        # toImage() handles NV12, YUV420P, MJPG, BGRA — any pixel format.
        image = frame.toImage()
        if image.isNull():
            logger.warning("InferenceManager: toImage() returned null")
            return

        width = image.width()
        height = image.height()

        # Ensure we have packed RGB32 (BGRX on little-endian, 4 bytes/pixel)
        from PySide6.QtGui import QImage  # noqa: PLC0415

        if image.format() != QImage.Format.Format_RGB32:
            image = image.convertToFormat(QImage.Format.Format_RGB32)

        # Extract RGB bytes (drop alpha/padding channel)
        try:
            import numpy as np  # noqa: PLC0415

            # Scanlines may be padded past width*4 bytes, so walk the buffer
            # with the image's real stride and keep only the visible pixels.
            stride = image.bytesPerLine()
            flat = np.frombuffer(image.bits(), dtype=np.uint8, count=height * stride)
            bgrx = flat.reshape((height, stride))[:, : width * 4].reshape(
                (height, width, 4)
            )
            # B G R X → R G B: reverse the first three channels, drop X.
            # tobytes() makes the single copy, detached from the QImage buffer.
            raw = bgrx[:, :, 2::-1].tobytes()
        except Exception as exc:
            logger.warning("InferenceManager: frame conversion failed: %s", exc)
            return

        channels = 3
        self._frame_id += 1
        packet = FramePacket(
            frame_id=self._frame_id,
            raw_bytes=raw,
            width=width,
            height=height,
            channels=channels,
        )

        try:
            self._input_queue.put_nowait(packet)
        except Exception as exc:
            logger.warning("InferenceManager: could not enqueue frame: %s", exc)
        else:
            self._busy = True
            self._busy_since = time.monotonic()

    # ------------------------------------------------------------------
    # Private — worker lifecycle
    # ------------------------------------------------------------------

    def _start_worker(self) -> None:
        """Spawn a fresh worker with new queues and start both timers."""
        ctx = multiprocessing.get_context("spawn")
        self._input_queue = ctx.Queue(maxsize=1)
        self._output_queue = ctx.Queue(maxsize=4)
        self._stop_event = ctx.Event()

        self._process = ctx.Process(
            target=run_worker,
            args=(
                self._model_path,
                self._input_queue,
                self._output_queue,
                self._stop_event,
                logging.WARNING,
            ),
            daemon=True,
            name="inference-worker",
        )
        self._process.start()
        self._busy = False
        self._last_result_time = time.monotonic()

        self._poll_timer.start()
        self._watchdog_timer.start()

        logger.info(
            "Inference worker started (pid=%d, model=%s)",
            self._process.pid,
            self._model_path,
        )
        self.inference_started.emit()

    def _stop_worker(self) -> None:
        """Stop timers, shut the worker down and release its queues.

        Always emits inference_stopped, even when no worker was running.
        """
        self._poll_timer.stop()
        self._watchdog_timer.stop()

        self._terminate_process(join_timeout=3.0)
        self._release_ipc()

        logger.info("Inference worker stopped")
        self.inference_stopped.emit()

    def _terminate_process(
        self, join_timeout: float = 3.0, terminate_timeout: float = 2.0
    ) -> None:
        """Ask the worker to exit, wait for it, and terminate it if it lingers.

        Args:
            join_timeout: seconds to wait after setting the stop event.
            terminate_timeout: seconds to wait after terminate().
        """
        if self._stop_event is not None:
            self._stop_event.set()

        process = self._process
        if process is None:
            return
        process.join(timeout=join_timeout)
        if process.is_alive():
            logger.warning("Worker did not stop cleanly — terminating")
            process.terminate()
            process.join(timeout=terminate_timeout)

    def _release_ipc(self) -> None:
        """Close this side of the IPC queues and drop every worker handle."""
        if self._input_queue is not None:
            # The consumer is gone: a frame still buffered by this queue's
            # feeder thread would otherwise block interpreter exit while it
            # waits for a reader that will never come.
            self._input_queue.cancel_join_thread()
            self._input_queue.close()
        if self._output_queue is not None:
            self._output_queue.close()

        self._process = None
        self._input_queue = None
        self._output_queue = None
        self._stop_event = None
        self._busy = False

    # ------------------------------------------------------------------
    # Private — timers
    # ------------------------------------------------------------------

    @Slot()
    def _poll_output(self) -> None:
        """Drain the output queue (called every INFERENCE_POLL_INTERVAL_MS ms)."""
        # Re-check on every pass: a connected slot may have called stop() or
        # start() while handling a result, replacing or clearing the queue.
        while self._output_queue is not None:
            try:
                item = self._output_queue.get_nowait()
            except queue.Empty:
                return  # drained — the normal exit path
            except Exception:
                logger.exception("InferenceManager: failed to read result queue")
                return

            if item is None:
                # Worker signalled a fatal load error
                logger.error("Worker reported model load failure")
                self._handle_crash("Model failed to load in worker process")
                return

            try:
                self._handle_result(item)
            except Exception:
                logger.exception("InferenceManager: could not process result packet")

    def _handle_result(self, packet: ResultPacket) -> None:
        """Turn one worker result into Detections and emit the result signals."""
        self._busy = False
        self._last_result_time = time.monotonic()

        detections = [
            Detection(x1, y1, x2, y2, conf, label)
            for x1, y1, x2, y2, conf, label in packet.detections
        ]
        source_size = (packet.width, packet.height)

        if detections:
            self._detection_frame_count += 1
            conf_summary = ", ".join(f"{d.label} {d.conf:.2f}" for d in detections)
            logger.info(
                "frame %d: %d detection(s) in %.1f ms — %s",
                packet.frame_id,
                len(detections),
                packet.elapsed_ms,
                conf_summary,
            )
            self.detection_count_updated.emit(self._detection_frame_count)

        self.detections_ready.emit(detections, source_size)

    @Slot()
    def _watchdog_check(self) -> None:
        """Detect crashed or hung worker process."""
        if self._process is None:
            return

        # Process died unexpectedly
        if not self._process.is_alive():
            exit_code = self._process.exitcode
            logger.error("Worker process died (exitcode=%s)", exit_code)
            self._handle_crash(f"Worker process exited with code {exit_code}")
            return

        # Worker alive but hasn't responded for too long (hung during inference).
        # Measured from when the outstanding frame was submitted: after a pause
        # or idle stretch the last result can be arbitrarily old, and measuring
        # from it would flag the very first new frame as a timeout.
        if self._busy:
            elapsed = time.monotonic() - self._busy_since
            if elapsed > INFERENCE_WORKER_TIMEOUT_S:
                logger.error(
                    "Worker timeout: no response for %.1f s — restarting", elapsed
                )
                self._process.terminate()
                self._process.join(timeout=2.0)
                self._handle_crash("Worker timed out (hung during inference)")

    def _handle_crash(self, reason: str) -> None:
        """Tear down the failed worker, then auto-restart or give up.

        Emits inference_error once INFERENCE_MAX_RESTARTS has been reached.
        """
        self._poll_timer.stop()
        self._watchdog_timer.stop()

        # The worker may still be exiting (e.g. after reporting a load
        # failure); make sure it is gone and its queues are released before
        # a replacement is spawned.
        self._terminate_process(join_timeout=1.0)
        self._release_ipc()

        if self._restart_count < INFERENCE_MAX_RESTARTS:
            self._restart_count += 1
            logger.warning(
                "Auto-restarting worker (attempt %d/%d): %s",
                self._restart_count,
                INFERENCE_MAX_RESTARTS,
                reason,
            )
            self._start_worker()
        else:
            msg = (
                f"Inference worker failed after {INFERENCE_MAX_RESTARTS} restarts. "
                f"Last error: {reason}"
            )
            logger.error(msg)
            self.inference_error.emit(msg)