feat: Enhance inference management with device tracking and telemetry updates
This commit is contained in:
@@ -5,8 +5,10 @@ safely be imported and executed in a child process via multiprocessing.
|
||||
|
||||
IPC protocol
|
||||
------------
|
||||
input_queue receives : FramePacket (frame_id, raw_bytes, width, height, channels)
|
||||
output_queue sends : ResultPacket (frame_id, detections, width, height)
|
||||
input_queue receives : FramePacket (frame_id, raw_bytes, width, height, channels)
|
||||
output_queue sends : WorkerReadyPacket (device) — once after model load
|
||||
: ResultPacket (frame_id, detections, width, height, elapsed_ms)
|
||||
: None — on fatal load failure
|
||||
stop_event : multiprocessing.Event — set by parent to request clean exit
|
||||
|
||||
Detection format (namedtuple-compatible plain tuple):
|
||||
@@ -37,6 +39,14 @@ class FramePacket(NamedTuple):
|
||||
channels: int # always 3 (RGB)
|
||||
|
||||
|
||||
class WorkerReadyPacket(NamedTuple):
|
||||
"""
|
||||
Sent once by the worker right after the model is loaded.
|
||||
Carries the device string so the GUI can display it.
|
||||
"""
|
||||
device: str # e.g. "cpu", "mps"
|
||||
|
||||
|
||||
class ResultPacket(NamedTuple):
|
||||
frame_id: int
|
||||
detections: list # list of (x1, y1, x2, y2, conf, label) tuples
|
||||
@@ -59,8 +69,9 @@ def run_worker(
|
||||
"""
|
||||
Main loop of the inference worker process.
|
||||
|
||||
Loads the YOLO model once, then processes frames from input_queue
|
||||
until stop_event is set. Results are posted to output_queue.
|
||||
Loads the YOLO model once, sends WorkerReadyPacket, then processes
|
||||
frames from input_queue until stop_event is set.
|
||||
Results are posted to output_queue.
|
||||
|
||||
This function is designed to be the target of multiprocessing.Process.
|
||||
It must NOT import PySide6 or any Qt module.
|
||||
@@ -68,15 +79,21 @@ def run_worker(
|
||||
_configure_worker_logging(log_level)
|
||||
logger.info("Inference worker starting (pid=%d)", _getpid())
|
||||
|
||||
# Select device once — never changes during the lifetime of this process
|
||||
device = _select_device()
|
||||
|
||||
try:
|
||||
model = _load_model(model_path)
|
||||
model = _load_model(model_path, device)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to load model '%s': %s", model_path, exc)
|
||||
# Signal failure by putting None — manager treats it as error
|
||||
output_queue.put(None)
|
||||
return
|
||||
|
||||
logger.info("Model loaded: %s", model_path)
|
||||
logger.info("Model loaded: %s device=%s", model_path, device)
|
||||
|
||||
# Notify GUI thread of the device being used
|
||||
output_queue.put(WorkerReadyPacket(device=device))
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
@@ -88,7 +105,7 @@ def run_worker(
|
||||
break
|
||||
|
||||
try:
|
||||
result = _infer(model, packet)
|
||||
result = _infer(model, packet, device)
|
||||
output_queue.put(result)
|
||||
except Exception as exc:
|
||||
logger.error("Inference error (frame %d): %s", packet.frame_id, exc)
|
||||
@@ -107,11 +124,10 @@ def run_worker(
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_model(model_path: str):
|
||||
"""Load YOLO model with best available device."""
|
||||
def _load_model(model_path: str, device: str):
|
||||
"""Load YOLO model and warm up on the selected device."""
|
||||
from ultralytics import YOLO # noqa: PLC0415
|
||||
|
||||
device = _select_device()
|
||||
logger.info("Loading YOLO model on device='%s'", device)
|
||||
model = YOLO(model_path)
|
||||
# Warm up — run on a tiny dummy to JIT-compile kernels
|
||||
@@ -126,11 +142,13 @@ def _load_model(model_path: str):
|
||||
|
||||
def _select_device() -> str:
|
||||
"""
|
||||
Choose inference device.
|
||||
Choose the best available inference device.
|
||||
|
||||
Priority:
|
||||
- macOS → "mps" if available (Metal GPU), else "cpu"
|
||||
- macOS → "mps" if torch.backends.mps.is_available(), else "cpu"
|
||||
- others → "cpu"
|
||||
|
||||
Called once at worker startup — not per frame.
|
||||
"""
|
||||
system = platform.system()
|
||||
if system == "Darwin":
|
||||
@@ -145,7 +163,7 @@ def _select_device() -> str:
|
||||
return "cpu"
|
||||
|
||||
|
||||
def _infer(model, packet: FramePacket) -> ResultPacket:
|
||||
def _infer(model, packet: FramePacket, device: str) -> ResultPacket:
|
||||
"""Run model on one frame, return ResultPacket with elapsed_ms."""
|
||||
import time # noqa: PLC0415
|
||||
|
||||
@@ -155,7 +173,6 @@ def _infer(model, packet: FramePacket) -> ResultPacket:
|
||||
(packet.height, packet.width, packet.channels)
|
||||
)
|
||||
|
||||
device = _select_device()
|
||||
t0 = time.perf_counter()
|
||||
results = model(frame_np, device=device, verbose=False)
|
||||
elapsed_ms = (time.perf_counter() - t0) * 1000.0
|
||||
|
||||
@@ -5,11 +5,12 @@ Responsibilities:
|
||||
- Submit frames (with drop-if-busy logic)
|
||||
- Poll result queue via QTimer (never blocks the GUI thread)
|
||||
- Watch process health via QTimer (auto-restart on crash)
|
||||
- Emit Qt signals with results for BboxOverlay
|
||||
- Emit Qt signals with results for BboxOverlay and TelemetryCollector
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import multiprocessing
|
||||
import time
|
||||
@@ -25,10 +26,13 @@ from app.config import (
|
||||
INFERENCE_WORKER_TIMEOUT_S,
|
||||
)
|
||||
from app.inference.bbox_overlay import Detection
|
||||
from app.inference.worker import FramePacket, ResultPacket, run_worker
|
||||
from app.inference.worker import FramePacket, ResultPacket, WorkerReadyPacket, run_worker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Number of recent inference times to average for the overlay display
|
||||
_ELAPSED_WINDOW = 10
|
||||
|
||||
|
||||
class InferenceManager(QObject):
|
||||
"""
|
||||
@@ -40,22 +44,26 @@ class InferenceManager(QObject):
|
||||
detections : list[Detection]
|
||||
source_size : tuple[int, int] — (width, height) of inferred frame
|
||||
|
||||
detection_count_updated(int)
|
||||
Total number of frames on which at least one detection occurred.
|
||||
|
||||
inference_stats_updated(device, avg_ms)
|
||||
Emitted after every result packet.
|
||||
device : str — e.g. "cpu", "mps"
|
||||
avg_ms : float — rolling average of inference time in milliseconds over the last _ELAPSED_WINDOW (currently 10) result packets
|
||||
|
||||
inference_device_changed(str)
|
||||
Emitted each time a worker reports its device after model load — once per worker start, so it can fire again after an auto-restart.
|
||||
|
||||
inference_started() — worker is up and model is loaded
|
||||
inference_stopped() — worker has exited cleanly
|
||||
inference_error(str) — fatal error (max restarts exceeded)
|
||||
|
||||
Usage:
|
||||
mgr = InferenceManager(parent=self)
|
||||
mgr.detections_ready.connect(bbox_overlay.on_detections)
|
||||
mgr.start("path/to/model.pt")
|
||||
# ...
|
||||
mgr.submit_frame(video_frame) # called by FrameDispatcher subscriber
|
||||
# ...
|
||||
mgr.stop()
|
||||
"""
|
||||
|
||||
detections_ready = Signal(object, object) # list[Detection], tuple[int,int]
|
||||
detection_count_updated = Signal(int) # total frames with detections so far
|
||||
inference_stats_updated = Signal(str, float) # device, avg_elapsed_ms
|
||||
inference_device_changed = Signal(str) # emitted once on WorkerReadyPacket
|
||||
inference_started = Signal()
|
||||
inference_stopped = Signal()
|
||||
inference_error = Signal(str)
|
||||
@@ -83,6 +91,14 @@ class InferenceManager(QObject):
|
||||
# Detection counter — frames on which at least one detection occurred
|
||||
self._detection_frame_count: int = 0
|
||||
|
||||
# Device reported by the worker after model load
|
||||
self._current_device: str = "cpu"
|
||||
|
||||
# Rolling window of recent elapsed_ms values for averaging
|
||||
self._elapsed_window: collections.deque[float] = collections.deque(
|
||||
maxlen=_ELAPSED_WINDOW
|
||||
)
|
||||
|
||||
# QTimers (GUI thread)
|
||||
self._poll_timer = QTimer(self)
|
||||
self._poll_timer.setInterval(INFERENCE_POLL_INTERVAL_MS)
|
||||
@@ -109,6 +125,8 @@ class InferenceManager(QObject):
|
||||
self._restart_count = 0
|
||||
self._paused = False
|
||||
self._detection_frame_count = 0
|
||||
self._elapsed_window.clear()
|
||||
self._current_device = "cpu"
|
||||
self._start_worker()
|
||||
|
||||
def stop(self) -> None:
|
||||
@@ -140,6 +158,10 @@ class InferenceManager(QObject):
|
||||
def model_path(self) -> str | None:
|
||||
return self._model_path
|
||||
|
||||
@property
|
||||
def current_device(self) -> str:
|
||||
return self._current_device
|
||||
|
||||
@Slot(QVideoFrame)
|
||||
def submit_frame(self, frame: QVideoFrame) -> None:
|
||||
"""
|
||||
@@ -204,7 +226,6 @@ class InferenceManager(QObject):
|
||||
try:
|
||||
self._input_queue.put_nowait(packet)
|
||||
self._busy = True
|
||||
# logger.debug("InferenceManager: submitted frame %d", self._frame_id)
|
||||
except Exception as exc:
|
||||
logger.warning("InferenceManager: could not enqueue frame: %s", exc)
|
||||
|
||||
@@ -278,16 +299,33 @@ class InferenceManager(QObject):
|
||||
try:
|
||||
while True:
|
||||
item = self._output_queue.get_nowait()
|
||||
|
||||
if item is None:
|
||||
# Worker signalled a fatal load error
|
||||
logger.error("Worker reported model load failure")
|
||||
self._handle_crash("Model failed to load in worker process")
|
||||
return
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# WorkerReadyPacket — sent once after model load
|
||||
# ----------------------------------------------------------
|
||||
if isinstance(item, WorkerReadyPacket):
|
||||
self._current_device = item.device
|
||||
logger.info("Inference device: %s", item.device)
|
||||
self.inference_device_changed.emit(item.device)
|
||||
continue
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# ResultPacket — regular inference result
|
||||
# ----------------------------------------------------------
|
||||
packet: ResultPacket = item
|
||||
self._busy = False
|
||||
self._last_result_time = time.monotonic()
|
||||
|
||||
# Update rolling average of elapsed time
|
||||
self._elapsed_window.append(packet.elapsed_ms)
|
||||
avg_ms = sum(self._elapsed_window) / len(self._elapsed_window)
|
||||
|
||||
detections = [
|
||||
Detection(x1, y1, x2, y2, conf, label)
|
||||
for x1, y1, x2, y2, conf, label in packet.detections
|
||||
@@ -308,6 +346,8 @@ class InferenceManager(QObject):
|
||||
)
|
||||
self.detection_count_updated.emit(self._detection_frame_count)
|
||||
|
||||
# Always emit stats so overlay stays current
|
||||
self.inference_stats_updated.emit(self._current_device, avg_ms)
|
||||
self.detections_ready.emit(detections, source_size)
|
||||
|
||||
except Exception:
|
||||
@@ -340,7 +380,6 @@ class InferenceManager(QObject):
|
||||
|
||||
def _handle_crash(self, reason: str) -> None:
|
||||
"""Decide whether to auto-restart or give up."""
|
||||
# Clean up process handles (already dead)
|
||||
self._poll_timer.stop()
|
||||
self._watchdog_timer.stop()
|
||||
self._process = None
|
||||
|
||||
Reference in New Issue
Block a user