feat: Enhance inference management with device tracking and telemetry updates

This commit is contained in:
2026-05-13 22:39:08 +02:00
parent 83346dc985
commit 6c401b62bb
7 changed files with 630 additions and 27 deletions

View File

@@ -5,8 +5,10 @@ safely be imported and executed in a child process via multiprocessing.
IPC protocol
------------
input_queue receives : FramePacket (frame_id, raw_bytes, width, height, channels)
output_queue sends : ResultPacket (frame_id, detections, width, height)
input_queue receives : FramePacket (frame_id, raw_bytes, width, height, channels)
output_queue sends : WorkerReadyPacket (device) — once after model load
: ResultPacket (frame_id, detections, width, height, elapsed_ms)
: None — on fatal load failure
stop_event : multiprocessing.Event — set by parent to request clean exit
Detection format (namedtuple-compatible plain tuple):
@@ -37,6 +39,14 @@ class FramePacket(NamedTuple):
channels: int # always 3 (RGB)
class WorkerReadyPacket(NamedTuple):
    """
    One-shot message the worker posts right after the model has loaded.

    Its single field is the device string, which lets the GUI display
    the device in use.
    """

    device: str  # e.g. "cpu", "mps"
class ResultPacket(NamedTuple):
frame_id: int
detections: list # list of (x1, y1, x2, y2, conf, label) tuples
@@ -59,8 +69,9 @@ def run_worker(
"""
Main loop of the inference worker process.
Loads the YOLO model once, then processes frames from input_queue
until stop_event is set. Results are posted to output_queue.
Loads the YOLO model once, sends WorkerReadyPacket, then processes
frames from input_queue until stop_event is set.
Results are posted to output_queue.
This function is designed to be the target of multiprocessing.Process.
It must NOT import PySide6 or any Qt module.
@@ -68,15 +79,21 @@ def run_worker(
_configure_worker_logging(log_level)
logger.info("Inference worker starting (pid=%d)", _getpid())
# Select device once — never changes during the lifetime of this process
device = _select_device()
try:
model = _load_model(model_path)
model = _load_model(model_path, device)
except Exception as exc:
logger.error("Failed to load model '%s': %s", model_path, exc)
# Signal failure by putting None — manager treats it as error
output_queue.put(None)
return
logger.info("Model loaded: %s", model_path)
logger.info("Model loaded: %s device=%s", model_path, device)
# Notify GUI thread of the device being used
output_queue.put(WorkerReadyPacket(device=device))
while not stop_event.is_set():
try:
@@ -88,7 +105,7 @@ def run_worker(
break
try:
result = _infer(model, packet)
result = _infer(model, packet, device)
output_queue.put(result)
except Exception as exc:
logger.error("Inference error (frame %d): %s", packet.frame_id, exc)
@@ -107,11 +124,10 @@ def run_worker(
# Helpers
# ---------------------------------------------------------------------------
def _load_model(model_path: str):
"""Load YOLO model with best available device."""
def _load_model(model_path: str, device: str):
"""Load YOLO model and warm up on the selected device."""
from ultralytics import YOLO # noqa: PLC0415
device = _select_device()
logger.info("Loading YOLO model on device='%s'", device)
model = YOLO(model_path)
# Warm up — run on a tiny dummy to JIT-compile kernels
@@ -126,11 +142,13 @@ def _load_model(model_path: str):
def _select_device() -> str:
"""
Choose inference device.
Choose the best available inference device.
Priority:
- macOS → "mps" if available (Metal GPU), else "cpu"
- macOS → "mps" if torch.backends.mps.is_available(), else "cpu"
- others → "cpu"
Called once at worker startup — not per frame.
"""
system = platform.system()
if system == "Darwin":
@@ -145,7 +163,7 @@ def _select_device() -> str:
return "cpu"
def _infer(model, packet: FramePacket) -> ResultPacket:
def _infer(model, packet: FramePacket, device: str) -> ResultPacket:
"""Run model on one frame, return ResultPacket with elapsed_ms."""
import time # noqa: PLC0415
@@ -155,7 +173,6 @@ def _infer(model, packet: FramePacket) -> ResultPacket:
(packet.height, packet.width, packet.channels)
)
device = _select_device()
t0 = time.perf_counter()
results = model(frame_np, device=device, verbose=False)
elapsed_ms = (time.perf_counter() - t0) * 1000.0

View File

@@ -5,11 +5,12 @@ Responsibilities:
- Submit frames (with drop-if-busy logic)
- Poll result queue via QTimer (never blocks the GUI thread)
- Watch process health via QTimer (auto-restart on crash)
- Emit Qt signals with results for BboxOverlay
- Emit Qt signals with results for BboxOverlay and TelemetryCollector
"""
from __future__ import annotations
import collections
import logging
import multiprocessing
import time
@@ -25,10 +26,13 @@ from app.config import (
INFERENCE_WORKER_TIMEOUT_S,
)
from app.inference.bbox_overlay import Detection
from app.inference.worker import FramePacket, ResultPacket, run_worker
from app.inference.worker import FramePacket, ResultPacket, WorkerReadyPacket, run_worker
logger = logging.getLogger(__name__)
# Number of recent inference times to average for the overlay display
_ELAPSED_WINDOW = 10
class InferenceManager(QObject):
"""
@@ -40,22 +44,26 @@ class InferenceManager(QObject):
detections : list[Detection]
source_size : tuple[int, int] — (width, height) of inferred frame
detection_count_updated(int)
Total number of frames on which at least one detection occurred.
inference_stats_updated(device, avg_ms)
Emitted after every result packet.
device : str — e.g. "cpu", "mps"
avg_ms : float — rolling average of inference time (last 10 frames)
inference_device_changed(str)
Emitted once when the worker reports its device after model load.
inference_started() — worker is up and model is loaded
inference_stopped() — worker has exited cleanly
inference_error(str) — fatal error (max restarts exceeded)
Usage:
mgr = InferenceManager(parent=self)
mgr.detections_ready.connect(bbox_overlay.on_detections)
mgr.start("path/to/model.pt")
# ...
mgr.submit_frame(video_frame) # called by FrameDispatcher subscriber
# ...
mgr.stop()
"""
detections_ready = Signal(object, object) # list[Detection], tuple[int,int]
detection_count_updated = Signal(int) # total frames with detections so far
inference_stats_updated = Signal(str, float) # device, avg_elapsed_ms
inference_device_changed = Signal(str) # emitted once on WorkerReadyPacket
inference_started = Signal()
inference_stopped = Signal()
inference_error = Signal(str)
@@ -83,6 +91,14 @@ class InferenceManager(QObject):
# Detection counter — frames on which at least one detection occurred
self._detection_frame_count: int = 0
# Device reported by the worker after model load
self._current_device: str = "cpu"
# Rolling window of recent elapsed_ms values for averaging
self._elapsed_window: collections.deque[float] = collections.deque(
maxlen=_ELAPSED_WINDOW
)
# QTimers (GUI thread)
self._poll_timer = QTimer(self)
self._poll_timer.setInterval(INFERENCE_POLL_INTERVAL_MS)
@@ -109,6 +125,8 @@ class InferenceManager(QObject):
self._restart_count = 0
self._paused = False
self._detection_frame_count = 0
self._elapsed_window.clear()
self._current_device = "cpu"
self._start_worker()
def stop(self) -> None:
@@ -140,6 +158,10 @@ class InferenceManager(QObject):
def model_path(self) -> str | None:
return self._model_path
@property
def current_device(self) -> str:
    """Return the device string currently stored on this manager."""
    device = self._current_device
    return device
@Slot(QVideoFrame)
def submit_frame(self, frame: QVideoFrame) -> None:
"""
@@ -204,7 +226,6 @@ class InferenceManager(QObject):
try:
self._input_queue.put_nowait(packet)
self._busy = True
# logger.debug("InferenceManager: submitted frame %d", self._frame_id)
except Exception as exc:
logger.warning("InferenceManager: could not enqueue frame: %s", exc)
@@ -278,16 +299,33 @@ class InferenceManager(QObject):
try:
while True:
item = self._output_queue.get_nowait()
if item is None:
# Worker signalled a fatal load error
logger.error("Worker reported model load failure")
self._handle_crash("Model failed to load in worker process")
return
# ----------------------------------------------------------
# WorkerReadyPacket — sent once after model load
# ----------------------------------------------------------
if isinstance(item, WorkerReadyPacket):
self._current_device = item.device
logger.info("Inference device: %s", item.device)
self.inference_device_changed.emit(item.device)
continue
# ----------------------------------------------------------
# ResultPacket — regular inference result
# ----------------------------------------------------------
packet: ResultPacket = item
self._busy = False
self._last_result_time = time.monotonic()
# Update rolling average of elapsed time
self._elapsed_window.append(packet.elapsed_ms)
avg_ms = sum(self._elapsed_window) / len(self._elapsed_window)
detections = [
Detection(x1, y1, x2, y2, conf, label)
for x1, y1, x2, y2, conf, label in packet.detections
@@ -308,6 +346,8 @@ class InferenceManager(QObject):
)
self.detection_count_updated.emit(self._detection_frame_count)
# Always emit stats so overlay stays current
self.inference_stats_updated.emit(self._current_device, avg_ms)
self.detections_ready.emit(detections, source_size)
except Exception:
@@ -340,7 +380,6 @@ class InferenceManager(QObject):
def _handle_crash(self, reason: str) -> None:
"""Decide whether to auto-restart or give up."""
# Clean up process handles (already dead)
self._poll_timer.stop()
self._watchdog_timer.stop()
self._process = None

View File

@@ -33,6 +33,8 @@ class TelemetryOverlay(IOverlayLayer):
CPU sys 14.8 % ← normalised by cpu_count (matches Task Manager)
CPU core 118.4 % ← per single core (can exceed 100%)
Mem 68 MB
Inf.dev mps ← inference device (only when model loaded)
Inf.time 87 ms ← rolling average of model() call time
"""
def __init__(self) -> None:
@@ -106,4 +108,10 @@ class TelemetryOverlay(IOverlayLayer):
if snap.memory_mb is not None:
lines.append(f"Mem {snap.memory_mb:>5.0f} MB")
if snap.inference_device is not None:
lines.append(f"Inf.dev {snap.inference_device:>6s}")
if snap.inference_time_ms is not None:
lines.append(f"Inf.time {snap.inference_time_ms:>5.0f} ms")
return lines

View File

@@ -26,6 +26,9 @@ class TelemetrySnapshot:
cpu_percent_core: float # process CPU per single core — can exceed 100%
memory_mb: float | None # process private working set in MB
timestamp: float # time.perf_counter() when snapshot was taken
# Inference fields — None when inference is disabled / model not loaded
inference_device: str | None = None # e.g. "cpu", "mps"
inference_time_ms: float | None = None # rolling average of model() call time
class TelemetryCollector(QObject):
@@ -69,6 +72,10 @@ class TelemetryCollector(QObject):
self._process.cpu_percent() # first call always returns 0.0; discard
self._cpu_count: int = max(psutil.cpu_count(logical=True) or 1, 1)
# Inference stats (updated externally via set_inference_stats)
self._inference_device: str | None = None
self._inference_time_ms: float | None = None
# periodic snapshot timer
self._timer = QTimer(self)
self._timer.setInterval(update_interval_ms)
@@ -85,6 +92,16 @@ class TelemetryCollector(QObject):
"""Record the FPS that was requested from the camera."""
self._target_fps = fps
def set_inference_stats(self, device: str, avg_ms: float) -> None:
    """
    Record the inference device and the average inference time.

    Called from MainWindow. The stored values are the ones reported in
    the ``inference_device`` / ``inference_time_ms`` snapshot fields.

    Args:
        device: Device string, e.g. "cpu" or "mps".
        avg_ms: Average inference time in milliseconds.
    """
    # Plain assignments on purpose: both attributes are already declared
    # (as ``str | None`` / ``float | None``) in ``__init__``; annotating them
    # a second time here is an attribute redefinition for type checkers.
    self._inference_device = device
    self._inference_time_ms = avg_ms
def clear_inference_stats(self) -> None:
    """Reset both inference fields to None; used when inference is disabled."""
    # Chained assignment binds the targets left to right: device, then time.
    self._inference_device = self._inference_time_ms = None
# ------------------------------------------------------------------
# Frame subscriber callback
# ------------------------------------------------------------------
@@ -175,6 +192,12 @@ class TelemetryCollector(QObject):
cpu_percent_core=round(cpu_core, 1),
memory_mb=round(mem_mb, 1) if mem_mb is not None else None,
timestamp=now,
inference_device=self._inference_device,
inference_time_ms=(
round(self._inference_time_ms, 1)
if self._inference_time_ms is not None
else None
),
)
def _make_empty_snapshot(self) -> TelemetrySnapshot:
@@ -187,4 +210,6 @@ class TelemetryCollector(QObject):
cpu_percent_core=0.0,
memory_mb=None,
timestamp=time.perf_counter(),
inference_device=None,
inference_time_ms=None,
)

View File

@@ -184,6 +184,7 @@ class MainWindow(QMainWindow):
# ---- InferenceManager ----
self._inference.detections_ready.connect(self._bbox_overlay.on_detections)
self._inference.detection_count_updated.connect(self._on_detection_count_updated)
self._inference.inference_stats_updated.connect(self._on_inference_stats_updated)
self._inference.inference_started.connect(self._on_inference_started)
self._inference.inference_stopped.connect(self._on_inference_stopped)
self._inference.inference_error.connect(self._on_inference_error)
@@ -267,6 +268,9 @@ class MainWindow(QMainWindow):
def _on_detection_count_updated(self, count: int) -> None:
    """Show the reported detection-frame count in the detection label."""
    label_text = f"Detections: {count} frames"
    self._detection_label.setText(label_text)
def _on_inference_stats_updated(self, device: str, avg_ms: float) -> None:
    """Forward the inference device and average time to the telemetry collector."""
    collector = self._telemetry
    collector.set_inference_stats(device, avg_ms)
def _on_inference_stopped(self) -> None:
self._bbox_overlay.clear()
@@ -276,6 +280,7 @@ class MainWindow(QMainWindow):
self._menu.set_inference_checked(False)
self._bbox_overlay.visible = False
self._detection_label.setVisible(False)
self._telemetry.clear_inference_stats()
QMessageBox.critical(self, "Inference Error", message)
# ------------------------------------------------------------------
@@ -350,6 +355,7 @@ class MainWindow(QMainWindow):
self._bbox_overlay.clear()
self._bbox_overlay.visible = False
self._detection_label.setVisible(False)
self._telemetry.clear_inference_stats()
self._status_label.setText("Inference disabled")
logger.info("Inference disabled")