363 lines
12 KiB
Python
363 lines
12 KiB
Python
"""InferenceManager — orchestrates the YOLO worker process from the GUI thread.
|
|
|
|
Responsibilities:
|
|
- Start / stop the worker process
|
|
- Submit frames (with drop-if-busy logic)
|
|
- Poll result queue via QTimer (never blocks the GUI thread)
|
|
- Watch process health via QTimer (auto-restart on crash)
|
|
- Emit Qt signals with results for BboxOverlay
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import multiprocessing
|
|
import time
|
|
from pathlib import Path
|
|
|
|
from PySide6.QtCore import QObject, QTimer, Signal, Slot
|
|
from PySide6.QtMultimedia import QVideoFrame
|
|
|
|
from app.config import (
|
|
INFERENCE_MAX_RESTARTS,
|
|
INFERENCE_POLL_INTERVAL_MS,
|
|
INFERENCE_WATCHDOG_INTERVAL_MS,
|
|
INFERENCE_WORKER_TIMEOUT_S,
|
|
)
|
|
from app.inference.bbox_overlay import Detection
|
|
from app.inference.worker import FramePacket, ResultPacket, run_worker
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class InferenceManager(QObject):
    """
    Manages the YOLO worker subprocess from the GUI thread.

    All timers and signals live in the thread that owns this object; the
    result queue is polled (never waited on), so that thread is not blocked
    while waiting for results.

    Signals:
        detections_ready(detections, source_size)
            Emitted when a result arrives.
            detections  : list[Detection]
            source_size : tuple[int, int] — (width, height) of inferred frame

        detection_count_updated(int)
            Number of result frames so far that contained at least one
            detection. The counter is reset by start().

        inference_started()   — worker process has been spawned. Emitted
                                right after Process.start(); the model may
                                still be loading inside the worker.
        inference_stopped()   — emitted whenever start() or stop() tears the
                                worker down (also if none was running).
        inference_error(str)  — model file missing, or max restarts exceeded.

    Usage:
        mgr = InferenceManager(parent=self)
        mgr.detections_ready.connect(bbox_overlay.on_detections)
        mgr.start("path/to/model.pt")
        # ...
        mgr.submit_frame(video_frame)   # called by FrameDispatcher subscriber
        # ...
        mgr.stop()
    """

    detections_ready = Signal(object, object)  # list[Detection], tuple[int,int]
    detection_count_updated = Signal(int)  # total frames with detections so far
    inference_started = Signal()
    inference_stopped = Signal()
    inference_error = Signal(str)

    def __init__(self, parent: QObject | None = None) -> None:
        """Create an idle manager; no process is spawned until start()."""
        super().__init__(parent)

        self._model_path: str | None = None
        self._process: multiprocessing.Process | None = None
        self._input_queue: multiprocessing.Queue | None = None
        self._output_queue: multiprocessing.Queue | None = None
        self._stop_event: multiprocessing.Event | None = None

        # Drop-if-busy flag — True while worker is processing a frame
        self._busy: bool = False
        self._frame_id: int = 0

        # Restart tracking
        self._restart_count: int = 0
        # Baseline for the hung-worker timeout. Refreshed when the worker is
        # started, when a frame is handed to it, and when a result comes back.
        self._last_result_time: float = 0.0

        # Paused flag — inference can be suspended without stopping the process
        self._paused: bool = False

        # Detection counter — frames on which at least one detection occurred
        self._detection_frame_count: int = 0

        # QTimers (owned by this object, so they fire in its thread)
        self._poll_timer = QTimer(self)
        self._poll_timer.setInterval(INFERENCE_POLL_INTERVAL_MS)
        self._poll_timer.timeout.connect(self._poll_output)

        self._watchdog_timer = QTimer(self)
        self._watchdog_timer.setInterval(INFERENCE_WATCHDOG_INTERVAL_MS)
        self._watchdog_timer.timeout.connect(self._watchdog_check)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def start(self, model_path: str) -> None:
        """Start (or restart) the worker process for *model_path*.

        If the file does not exist, ``inference_error`` is emitted and any
        worker that is already running is left untouched. Otherwise the
        current worker (if any) is stopped, the restart/detection counters
        and the paused flag are reset, and a fresh worker is spawned.
        """
        if not Path(model_path).exists():
            msg = f"Model file not found: {model_path}"
            logger.error(msg)
            self.inference_error.emit(msg)
            return

        self._stop_worker()
        self._model_path = model_path
        self._restart_count = 0
        self._paused = False
        self._detection_frame_count = 0
        self._start_worker()

    def stop(self) -> None:
        """Stop the worker process and reset state."""
        self._stop_worker()
        self._model_path = None
        self._restart_count = 0
        self._paused = False

    def pause(self) -> None:
        """Suspend frame submission without stopping the process."""
        self._paused = True
        logger.debug("InferenceManager: paused")

    def resume(self) -> None:
        """Resume frame submission."""
        self._paused = False
        logger.debug("InferenceManager: resumed")

    @property
    def is_running(self) -> bool:
        """True while a worker process exists and is alive."""
        return self._process is not None and self._process.is_alive()

    @property
    def is_paused(self) -> bool:
        """True while frame submission is suspended via pause()."""
        return self._paused

    @property
    def model_path(self) -> str | None:
        """Path passed to the last successful start(), or None when stopped."""
        return self._model_path

    @Slot(QVideoFrame)
    def submit_frame(self, frame: QVideoFrame) -> None:
        """
        Attempt to submit a frame for inference.

        Drops the frame silently if:
          - manager is not running
          - manager is paused
          - worker is still busy with previous frame (drop_if_busy)
          - the frame is invalid

        Frame conversion strategy:
          Use QVideoFrame.toImage() → QImage.Format_RGB32 → bits().
          Qt decodes the source pixel format internally. The cost is a CPU
          colour-space conversion, but it only happens when the worker is
          idle (drop_if_busy).
        """
        if not self.is_running or self._paused or self._busy:
            return
        if not frame.isValid():
            return

        converted = self._frame_to_rgb_bytes(frame)
        if converted is None:
            return
        raw, width, height = converted

        self._frame_id += 1
        packet = FramePacket(
            frame_id=self._frame_id,
            raw_bytes=raw,
            width=width,
            height=height,
            channels=3,
        )

        try:
            self._input_queue.put_nowait(packet)
        except Exception as exc:
            logger.warning("InferenceManager: could not enqueue frame: %s", exc)
            return

        self._busy = True
        # Start the hung-worker clock at submission time. Without this the
        # watchdog would measure from the previous result and could kill a
        # healthy worker right after a pause or any other idle period.
        self._last_result_time = time.monotonic()

    # ------------------------------------------------------------------
    # Private — frame conversion
    # ------------------------------------------------------------------

    def _frame_to_rgb_bytes(
        self, frame: QVideoFrame
    ) -> tuple[bytes, int, int] | None:
        """Convert *frame* to packed RGB bytes.

        Returns ``(raw_bytes, width, height)`` with 3 bytes per pixel, or
        None (after logging a warning) when the conversion fails.
        """
        image = frame.toImage()
        if image.isNull():
            logger.warning("InferenceManager: toImage() returned null")
            return None

        # Ensure we have packed RGB32 (BGRX on little-endian, 4 bytes/pixel)
        from PySide6.QtGui import QImage  # noqa: PLC0415

        if image.format() != QImage.Format.Format_RGB32:
            image = image.convertToFormat(QImage.Format.Format_RGB32)

        width = image.width()
        height = image.height()

        try:
            import numpy as np  # noqa: PLC0415

            # bits() exposes the 4-byte pixels; view them as (H, W, 4).
            pixels = np.frombuffer(image.bits(), dtype=np.uint8).reshape(
                (height, width, 4)
            )
            # Swap B↔R and drop X → RGB. Fancy indexing already yields a new
            # array and tobytes() serialises in C order, so no extra copy.
            raw = pixels[:, :, [2, 1, 0]].tobytes()
        except Exception as exc:
            logger.warning("InferenceManager: frame conversion failed: %s", exc)
            return None

        return raw, width, height

    # ------------------------------------------------------------------
    # Private — worker lifecycle
    # ------------------------------------------------------------------

    def _start_worker(self) -> None:
        """Create fresh queues/event, spawn the worker and start both timers."""
        ctx = multiprocessing.get_context("spawn")
        self._input_queue = ctx.Queue(maxsize=1)
        self._output_queue = ctx.Queue(maxsize=4)
        self._stop_event = ctx.Event()

        self._process = ctx.Process(
            target=run_worker,
            args=(
                self._model_path,
                self._input_queue,
                self._output_queue,
                self._stop_event,
                logging.WARNING,
            ),
            daemon=True,
            name="inference-worker",
        )
        self._process.start()
        self._busy = False
        self._last_result_time = time.monotonic()

        self._poll_timer.start()
        self._watchdog_timer.start()
        logger.info(
            "Inference worker started (pid=%d, model=%s)",
            self._process.pid, self._model_path,
        )
        self.inference_started.emit()

    def _stop_worker(self) -> None:
        """Stop timers, shut the worker down and emit ``inference_stopped``.

        May block the calling thread for up to ~5 s (3 s graceful join plus
        2 s after terminate()).
        """
        self._poll_timer.stop()
        self._watchdog_timer.stop()

        self._reap_process(grace_s=3.0)
        self._release_ipc()
        self._busy = False

        logger.info("Inference worker stopped")
        self.inference_stopped.emit()

    def _reap_process(self, grace_s: float) -> None:
        """Signal the worker to stop, wait *grace_s* seconds, then terminate it.

        Safe to call when no process exists or when it has already exited.
        Always clears ``self._process``.
        """
        proc, self._process = self._process, None

        if self._stop_event is not None:
            self._stop_event.set()

        if proc is None:
            return

        proc.join(timeout=grace_s)
        if proc.is_alive():
            logger.warning("Worker did not stop cleanly — terminating")
            proc.terminate()
            proc.join(timeout=2.0)

    def _release_ipc(self) -> None:
        """Close both queues and drop the queue/event references.

        cancel_join_thread() is called first so that data still buffered for
        a worker that is gone cannot block this process at exit.
        """
        for mp_queue in (self._input_queue, self._output_queue):
            if mp_queue is None:
                continue
            try:
                mp_queue.cancel_join_thread()
                mp_queue.close()
            except (OSError, ValueError):
                logger.debug(
                    "InferenceManager: error while closing queue", exc_info=True
                )

        self._input_queue = None
        self._output_queue = None
        self._stop_event = None

    # ------------------------------------------------------------------
    # Private — timers
    # ------------------------------------------------------------------

    @Slot()
    def _poll_output(self) -> None:
        """Drain the output queue (called every INFERENCE_POLL_INTERVAL_MS ms)."""
        from queue import Empty  # noqa: PLC0415

        while True:
            # Re-read every iteration: a slot connected to one of our signals
            # may call stop()/start() and replace or clear the queue.
            out_queue = self._output_queue
            if out_queue is None:
                return

            try:
                item = out_queue.get_nowait()
            except Empty:
                # Empty queue — normal
                return
            except Exception:
                # Broken pipe / undecodable payload: report it and let the
                # watchdog decide whether the worker needs a restart.
                logger.exception("InferenceManager: failed to read result queue")
                return

            if item is None:
                # Worker signalled a fatal load error
                logger.error("Worker reported model load failure")
                self._handle_crash("Model failed to load in worker process")
                return

            self._busy = False
            self._last_result_time = time.monotonic()

            try:
                self._publish_result(item)
            except Exception:
                # Keep polling, but never hide the failure.
                logger.exception(
                    "InferenceManager: failed to process result packet"
                )

    def _publish_result(self, packet: ResultPacket) -> None:
        """Turn one ResultPacket into Detection objects and emit the signals."""
        detections = [
            Detection(x1, y1, x2, y2, conf, label)
            for x1, y1, x2, y2, conf, label in packet.detections
        ]
        source_size = (packet.width, packet.height)

        if detections:
            self._detection_frame_count += 1
            # Only build the summary string when it will actually be logged.
            if logger.isEnabledFor(logging.INFO):
                conf_summary = ", ".join(
                    f"{d.label} {d.conf:.2f}" for d in detections
                )
                logger.info(
                    "frame %d: %d detection(s) in %.1f ms — %s",
                    packet.frame_id,
                    len(detections),
                    packet.elapsed_ms,
                    conf_summary,
                )
            self.detection_count_updated.emit(self._detection_frame_count)

        self.detections_ready.emit(detections, source_size)

    @Slot()
    def _watchdog_check(self) -> None:
        """Detect crashed or hung worker process."""
        if self._process is None:
            return

        # Process died unexpectedly
        if not self._process.is_alive():
            exit_code = self._process.exitcode
            logger.error("Worker process died (exitcode=%s)", exit_code)
            self._handle_crash(f"Worker process exited with code {exit_code}")
            return

        # Worker alive but hasn't responded for too long (hung during inference)
        if not self._busy:
            return

        elapsed = time.monotonic() - self._last_result_time
        if elapsed > INFERENCE_WORKER_TIMEOUT_S:
            logger.error(
                "Worker timeout: no response for %.1f s — restarting", elapsed
            )
            self._process.terminate()
            self._process.join(timeout=2.0)
            self._handle_crash("Worker timed out (hung during inference)")

    def _handle_crash(self, reason: str) -> None:
        """Decide whether to auto-restart or give up.

        Restarts the worker while fewer than INFERENCE_MAX_RESTARTS restarts
        have been attempted since the last start(); otherwise emits
        ``inference_error`` and leaves the manager without a worker.
        """
        self._poll_timer.stop()
        self._watchdog_timer.stop()

        # The worker is normally dead already, but on the "model load failed"
        # path it may still be exiting — make sure it is gone before its
        # handle and queues are dropped.
        self._reap_process(grace_s=1.0)
        self._release_ipc()
        self._busy = False

        if self._restart_count < INFERENCE_MAX_RESTARTS:
            self._restart_count += 1
            logger.warning(
                "Auto-restarting worker (attempt %d/%d): %s",
                self._restart_count, INFERENCE_MAX_RESTARTS, reason,
            )
            self._start_worker()
        else:
            msg = (
                f"Inference worker failed after {INFERENCE_MAX_RESTARTS} restarts. "
                f"Last error: {reason}"
            )
            logger.error(msg)
            self.inference_error.emit(msg)
|