feat: Enhance inference management with device tracking and telemetry updates
This commit is contained in:
@@ -5,8 +5,10 @@ safely be imported and executed in a child process via multiprocessing.
|
||||
|
||||
IPC protocol
|
||||
------------
|
||||
input_queue receives : FramePacket (frame_id, raw_bytes, width, height, channels)
|
||||
output_queue sends : ResultPacket (frame_id, detections, width, height)
|
||||
input_queue receives : FramePacket (frame_id, raw_bytes, width, height, channels)
|
||||
output_queue sends : WorkerReadyPacket (device) — once after model load
|
||||
: ResultPacket (frame_id, detections, width, height, elapsed_ms)
|
||||
: None — on fatal load failure
|
||||
stop_event : multiprocessing.Event — set by parent to request clean exit
|
||||
|
||||
Detection format (namedtuple-compatible plain tuple):
|
||||
@@ -37,6 +39,14 @@ class FramePacket(NamedTuple):
|
||||
channels: int # always 3 (RGB)
|
||||
|
||||
|
||||
class WorkerReadyPacket(NamedTuple):
    """
    Handshake message posted on output_queue right after the model is loaded.

    The worker sends it once per process, before entering its frame loop,
    so the GUI side can learn which device inference runs on and display it.
    """

    device: str  # device string selected by the worker, e.g. "cpu", "mps"
|
||||
|
||||
|
||||
class ResultPacket(NamedTuple):
|
||||
frame_id: int
|
||||
detections: list # list of (x1, y1, x2, y2, conf, label) tuples
|
||||
@@ -59,8 +69,9 @@ def run_worker(
|
||||
"""
|
||||
Main loop of the inference worker process.
|
||||
|
||||
Loads the YOLO model once, then processes frames from input_queue
|
||||
until stop_event is set. Results are posted to output_queue.
|
||||
Loads the YOLO model once, sends WorkerReadyPacket, then processes
|
||||
frames from input_queue until stop_event is set.
|
||||
Results are posted to output_queue.
|
||||
|
||||
This function is designed to be the target of multiprocessing.Process.
|
||||
It must NOT import PySide6 or any Qt module.
|
||||
@@ -68,15 +79,21 @@ def run_worker(
|
||||
_configure_worker_logging(log_level)
|
||||
logger.info("Inference worker starting (pid=%d)", _getpid())
|
||||
|
||||
# Select device once — never changes during the lifetime of this process
|
||||
device = _select_device()
|
||||
|
||||
try:
|
||||
model = _load_model(model_path)
|
||||
model = _load_model(model_path, device)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to load model '%s': %s", model_path, exc)
|
||||
# Signal failure by putting None — manager treats it as error
|
||||
output_queue.put(None)
|
||||
return
|
||||
|
||||
logger.info("Model loaded: %s", model_path)
|
||||
logger.info("Model loaded: %s device=%s", model_path, device)
|
||||
|
||||
# Notify GUI thread of the device being used
|
||||
output_queue.put(WorkerReadyPacket(device=device))
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
@@ -88,7 +105,7 @@ def run_worker(
|
||||
break
|
||||
|
||||
try:
|
||||
result = _infer(model, packet)
|
||||
result = _infer(model, packet, device)
|
||||
output_queue.put(result)
|
||||
except Exception as exc:
|
||||
logger.error("Inference error (frame %d): %s", packet.frame_id, exc)
|
||||
@@ -107,11 +124,10 @@ def run_worker(
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_model(model_path: str):
|
||||
"""Load YOLO model with best available device."""
|
||||
def _load_model(model_path: str, device: str):
|
||||
"""Load YOLO model and warm up on the selected device."""
|
||||
from ultralytics import YOLO # noqa: PLC0415
|
||||
|
||||
device = _select_device()
|
||||
logger.info("Loading YOLO model on device='%s'", device)
|
||||
model = YOLO(model_path)
|
||||
# Warm up — run on a tiny dummy to JIT-compile kernels
|
||||
@@ -126,11 +142,13 @@ def _load_model(model_path: str):
|
||||
|
||||
def _select_device() -> str:
|
||||
"""
|
||||
Choose inference device.
|
||||
Choose the best available inference device.
|
||||
|
||||
Priority:
|
||||
- macOS → "mps" if available (Metal GPU), else "cpu"
|
||||
- macOS → "mps" if torch.backends.mps.is_available(), else "cpu"
|
||||
- others → "cpu"
|
||||
|
||||
Called once at worker startup — not per frame.
|
||||
"""
|
||||
system = platform.system()
|
||||
if system == "Darwin":
|
||||
@@ -145,7 +163,7 @@ def _select_device() -> str:
|
||||
return "cpu"
|
||||
|
||||
|
||||
def _infer(model, packet: FramePacket) -> ResultPacket:
|
||||
def _infer(model, packet: FramePacket, device: str) -> ResultPacket:
|
||||
"""Run model on one frame, return ResultPacket with elapsed_ms."""
|
||||
import time # noqa: PLC0415
|
||||
|
||||
@@ -155,7 +173,6 @@ def _infer(model, packet: FramePacket) -> ResultPacket:
|
||||
(packet.height, packet.width, packet.channels)
|
||||
)
|
||||
|
||||
device = _select_device()
|
||||
t0 = time.perf_counter()
|
||||
results = model(frame_np, device=device, verbose=False)
|
||||
elapsed_ms = (time.perf_counter() - t0) * 1000.0
|
||||
|
||||
@@ -5,11 +5,12 @@ Responsibilities:
|
||||
- Submit frames (with drop-if-busy logic)
|
||||
- Poll result queue via QTimer (never blocks the GUI thread)
|
||||
- Watch process health via QTimer (auto-restart on crash)
|
||||
- Emit Qt signals with results for BboxOverlay
|
||||
- Emit Qt signals with results for BboxOverlay and TelemetryCollector
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import multiprocessing
|
||||
import time
|
||||
@@ -25,10 +26,13 @@ from app.config import (
|
||||
INFERENCE_WORKER_TIMEOUT_S,
|
||||
)
|
||||
from app.inference.bbox_overlay import Detection
|
||||
from app.inference.worker import FramePacket, ResultPacket, run_worker
|
||||
from app.inference.worker import FramePacket, ResultPacket, WorkerReadyPacket, run_worker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Number of recent inference times to average for the overlay display
|
||||
_ELAPSED_WINDOW = 10
|
||||
|
||||
|
||||
class InferenceManager(QObject):
|
||||
"""
|
||||
@@ -40,22 +44,26 @@ class InferenceManager(QObject):
|
||||
detections : list[Detection]
|
||||
source_size : tuple[int, int] — (width, height) of inferred frame
|
||||
|
||||
detection_count_updated(int)
|
||||
Total number of frames on which at least one detection occurred.
|
||||
|
||||
inference_stats_updated(device, avg_ms)
|
||||
Emitted after every result packet.
|
||||
device : str — e.g. "cpu", "mps"
|
||||
avg_ms : float — rolling average of inference time (last 10 frames)
|
||||
|
||||
inference_device_changed(str)
|
||||
Emitted once per worker start (including auto-restarts), when the worker reports its device after model load.
|
||||
|
||||
inference_started() — worker is up and model is loaded
|
||||
inference_stopped() — worker has exited cleanly
|
||||
inference_error(str) — fatal error (max restarts exceeded)
|
||||
|
||||
Usage:
|
||||
mgr = InferenceManager(parent=self)
|
||||
mgr.detections_ready.connect(bbox_overlay.on_detections)
|
||||
mgr.start("path/to/model.pt")
|
||||
# ...
|
||||
mgr.submit_frame(video_frame) # called by FrameDispatcher subscriber
|
||||
# ...
|
||||
mgr.stop()
|
||||
"""
|
||||
|
||||
detections_ready = Signal(object, object) # list[Detection], tuple[int,int]
|
||||
detection_count_updated = Signal(int) # total frames with detections so far
|
||||
inference_stats_updated = Signal(str, float) # device, avg_elapsed_ms
|
||||
inference_device_changed = Signal(str) # emitted once on WorkerReadyPacket
|
||||
inference_started = Signal()
|
||||
inference_stopped = Signal()
|
||||
inference_error = Signal(str)
|
||||
@@ -83,6 +91,14 @@ class InferenceManager(QObject):
|
||||
# Detection counter — frames on which at least one detection occurred
|
||||
self._detection_frame_count: int = 0
|
||||
|
||||
# Device reported by the worker after model load
|
||||
self._current_device: str = "cpu"
|
||||
|
||||
# Rolling window of recent elapsed_ms values for averaging
|
||||
self._elapsed_window: collections.deque[float] = collections.deque(
|
||||
maxlen=_ELAPSED_WINDOW
|
||||
)
|
||||
|
||||
# QTimers (GUI thread)
|
||||
self._poll_timer = QTimer(self)
|
||||
self._poll_timer.setInterval(INFERENCE_POLL_INTERVAL_MS)
|
||||
@@ -109,6 +125,8 @@ class InferenceManager(QObject):
|
||||
self._restart_count = 0
|
||||
self._paused = False
|
||||
self._detection_frame_count = 0
|
||||
self._elapsed_window.clear()
|
||||
self._current_device = "cpu"
|
||||
self._start_worker()
|
||||
|
||||
def stop(self) -> None:
|
||||
@@ -140,6 +158,10 @@ class InferenceManager(QObject):
|
||||
def model_path(self) -> str | None:
|
||||
return self._model_path
|
||||
|
||||
@property
def current_device(self) -> str:
    """
    Device string most recently reported by the worker.

    Holds the default "cpu" until a WorkerReadyPacket has been received,
    and is reset to "cpu" whenever start() is called.
    """
    return self._current_device
|
||||
|
||||
@Slot(QVideoFrame)
|
||||
def submit_frame(self, frame: QVideoFrame) -> None:
|
||||
"""
|
||||
@@ -204,7 +226,6 @@ class InferenceManager(QObject):
|
||||
try:
|
||||
self._input_queue.put_nowait(packet)
|
||||
self._busy = True
|
||||
# logger.debug("InferenceManager: submitted frame %d", self._frame_id)
|
||||
except Exception as exc:
|
||||
logger.warning("InferenceManager: could not enqueue frame: %s", exc)
|
||||
|
||||
@@ -278,16 +299,33 @@ class InferenceManager(QObject):
|
||||
try:
|
||||
while True:
|
||||
item = self._output_queue.get_nowait()
|
||||
|
||||
if item is None:
|
||||
# Worker signalled a fatal load error
|
||||
logger.error("Worker reported model load failure")
|
||||
self._handle_crash("Model failed to load in worker process")
|
||||
return
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# WorkerReadyPacket — sent once after model load
|
||||
# ----------------------------------------------------------
|
||||
if isinstance(item, WorkerReadyPacket):
|
||||
self._current_device = item.device
|
||||
logger.info("Inference device: %s", item.device)
|
||||
self.inference_device_changed.emit(item.device)
|
||||
continue
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# ResultPacket — regular inference result
|
||||
# ----------------------------------------------------------
|
||||
packet: ResultPacket = item
|
||||
self._busy = False
|
||||
self._last_result_time = time.monotonic()
|
||||
|
||||
# Update rolling average of elapsed time
|
||||
self._elapsed_window.append(packet.elapsed_ms)
|
||||
avg_ms = sum(self._elapsed_window) / len(self._elapsed_window)
|
||||
|
||||
detections = [
|
||||
Detection(x1, y1, x2, y2, conf, label)
|
||||
for x1, y1, x2, y2, conf, label in packet.detections
|
||||
@@ -308,6 +346,8 @@ class InferenceManager(QObject):
|
||||
)
|
||||
self.detection_count_updated.emit(self._detection_frame_count)
|
||||
|
||||
# Always emit stats so overlay stays current
|
||||
self.inference_stats_updated.emit(self._current_device, avg_ms)
|
||||
self.detections_ready.emit(detections, source_size)
|
||||
|
||||
except Exception:
|
||||
@@ -340,7 +380,6 @@ class InferenceManager(QObject):
|
||||
|
||||
def _handle_crash(self, reason: str) -> None:
|
||||
"""Decide whether to auto-restart or give up."""
|
||||
# Clean up process handles (already dead)
|
||||
self._poll_timer.stop()
|
||||
self._watchdog_timer.stop()
|
||||
self._process = None
|
||||
|
||||
@@ -33,6 +33,8 @@ class TelemetryOverlay(IOverlayLayer):
|
||||
CPU sys 14.8 % ← normalised by cpu_count (matches Task Manager)
|
||||
CPU core 118.4 % ← per single core (can exceed 100%)
|
||||
Mem 68 MB
|
||||
Inf.dev mps ← inference device (only when model loaded)
|
||||
Inf.time 87 ms ← rolling average of model() call time
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
@@ -106,4 +108,10 @@ class TelemetryOverlay(IOverlayLayer):
|
||||
if snap.memory_mb is not None:
|
||||
lines.append(f"Mem {snap.memory_mb:>5.0f} MB")
|
||||
|
||||
if snap.inference_device is not None:
|
||||
lines.append(f"Inf.dev {snap.inference_device:>6s}")
|
||||
|
||||
if snap.inference_time_ms is not None:
|
||||
lines.append(f"Inf.time {snap.inference_time_ms:>5.0f} ms")
|
||||
|
||||
return lines
|
||||
|
||||
@@ -26,6 +26,9 @@ class TelemetrySnapshot:
|
||||
cpu_percent_core: float # process CPU per single core — can exceed 100%
|
||||
memory_mb: float | None # process private working set in MB
|
||||
timestamp: float # time.perf_counter() when snapshot was taken
|
||||
# Inference fields — None when inference is disabled / model not loaded
|
||||
inference_device: str | None = None # e.g. "cpu", "mps"
|
||||
inference_time_ms: float | None = None # rolling average of model() call time
|
||||
|
||||
|
||||
class TelemetryCollector(QObject):
|
||||
@@ -69,6 +72,10 @@ class TelemetryCollector(QObject):
|
||||
self._process.cpu_percent() # first call always returns 0.0; discard
|
||||
self._cpu_count: int = max(psutil.cpu_count(logical=True) or 1, 1)
|
||||
|
||||
# Inference stats (updated externally via set_inference_stats)
|
||||
self._inference_device: str | None = None
|
||||
self._inference_time_ms: float | None = None
|
||||
|
||||
# periodic snapshot timer
|
||||
self._timer = QTimer(self)
|
||||
self._timer.setInterval(update_interval_ms)
|
||||
@@ -85,6 +92,16 @@ class TelemetryCollector(QObject):
|
||||
"""Record the FPS that was requested from the camera."""
|
||||
self._target_fps = fps
|
||||
|
||||
def set_inference_stats(self, device: str, avg_ms: float) -> None:
    """
    Record the latest inference device and average inference time.

    Called from MainWindow. The stored values are copied into the
    telemetry snapshots built afterwards.

    device : str   — device string, e.g. "cpu", "mps"
    avg_ms : float — rolling-average inference time in milliseconds
    """
    # Plain assignments on purpose: both attributes are already declared with
    # their Optional types where they are first initialised, so annotating
    # them again here is a duplicate declaration for static type checkers.
    self._inference_device = device
    self._inference_time_ms = avg_ms
|
||||
|
||||
def clear_inference_stats(self) -> None:
    """
    Forget the stored inference device and timing.

    Used when inference is disabled: both fields go back to None, so
    later snapshots carry no inference information.
    """
    self._inference_device = None
    self._inference_time_ms = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Frame subscriber callback
|
||||
# ------------------------------------------------------------------
|
||||
@@ -175,6 +192,12 @@ class TelemetryCollector(QObject):
|
||||
cpu_percent_core=round(cpu_core, 1),
|
||||
memory_mb=round(mem_mb, 1) if mem_mb is not None else None,
|
||||
timestamp=now,
|
||||
inference_device=self._inference_device,
|
||||
inference_time_ms=(
|
||||
round(self._inference_time_ms, 1)
|
||||
if self._inference_time_ms is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
def _make_empty_snapshot(self) -> TelemetrySnapshot:
|
||||
@@ -187,4 +210,6 @@ class TelemetryCollector(QObject):
|
||||
cpu_percent_core=0.0,
|
||||
memory_mb=None,
|
||||
timestamp=time.perf_counter(),
|
||||
inference_device=None,
|
||||
inference_time_ms=None,
|
||||
)
|
||||
|
||||
@@ -184,6 +184,7 @@ class MainWindow(QMainWindow):
|
||||
# ---- InferenceManager ----
|
||||
self._inference.detections_ready.connect(self._bbox_overlay.on_detections)
|
||||
self._inference.detection_count_updated.connect(self._on_detection_count_updated)
|
||||
self._inference.inference_stats_updated.connect(self._on_inference_stats_updated)
|
||||
self._inference.inference_started.connect(self._on_inference_started)
|
||||
self._inference.inference_stopped.connect(self._on_inference_stopped)
|
||||
self._inference.inference_error.connect(self._on_inference_error)
|
||||
@@ -267,6 +268,9 @@ class MainWindow(QMainWindow):
|
||||
def _on_detection_count_updated(self, count: int) -> None:
|
||||
self._detection_label.setText(f"Detections: {count} frames")
|
||||
|
||||
def _on_inference_stats_updated(self, device: str, avg_ms: float) -> None:
|
||||
self._telemetry.set_inference_stats(device, avg_ms)
|
||||
|
||||
def _on_inference_stopped(self) -> None:
|
||||
self._bbox_overlay.clear()
|
||||
|
||||
@@ -276,6 +280,7 @@ class MainWindow(QMainWindow):
|
||||
self._menu.set_inference_checked(False)
|
||||
self._bbox_overlay.visible = False
|
||||
self._detection_label.setVisible(False)
|
||||
self._telemetry.clear_inference_stats()
|
||||
QMessageBox.critical(self, "Inference Error", message)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -350,6 +355,7 @@ class MainWindow(QMainWindow):
|
||||
self._bbox_overlay.clear()
|
||||
self._bbox_overlay.visible = False
|
||||
self._detection_label.setVisible(False)
|
||||
self._telemetry.clear_inference_stats()
|
||||
self._status_label.setText("Inference disabled")
|
||||
logger.info("Inference disabled")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user