diff --git a/app/config.py b/app/config.py index 12869a2..fcad544 100644 --- a/app/config.py +++ b/app/config.py @@ -27,3 +27,20 @@ DISPATCHER_MAX_QUEUE_SIZE = 2 # max pending frames per slow subscriber before d LOG_DIR = Path("logs") # relative to CWD (project root) MAX_LOG_FILES = 20 # oldest sessions are deleted when exceeded TELEMETRY_CSV_INTERVAL_S = 5.0 # how often a CSV row is written (seconds) + +# Inference worker +INFERENCE_WORKER_TIMEOUT_S = 10.0 # seconds without response before watchdog fires +INFERENCE_MAX_RESTARTS = 3 # max auto-restart attempts before giving up +INFERENCE_POLL_INTERVAL_MS = 50 # how often GUI thread polls output queue (ms) +INFERENCE_WATCHDOG_INTERVAL_MS = 2000 # how often watchdog checks process health (ms) + +# BBox overlay +BBOX_COLOR = (0, 220, 60, 255) # RGBA — vivid green +BBOX_LABEL_BG_COLOR = (0, 220, 60, 200) # RGBA — semi-transparent green for label bg +BBOX_LABEL_TEXT_COLOR = (0, 0, 0, 255) # RGBA — black text on green bg +BBOX_LINE_WIDTH = 2 +BBOX_FONT_SIZE = 11 + +# Video file source +VIDEO_FILE_EXTENSIONS = "Video Files (*.mp4 *.avi *.mov *.mkv *.m4v *.webm)" +MODEL_FILE_EXTENSIONS = "YOLO Model (*.pt *.pth)" diff --git a/app/inference/__init__.py b/app/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/inference/bbox_overlay.py b/app/inference/bbox_overlay.py new file mode 100644 index 0000000..dd46cdd --- /dev/null +++ b/app/inference/bbox_overlay.py @@ -0,0 +1,154 @@ +"""BboxOverlay — draws YOLO detection bounding boxes on the camera view.""" + +from __future__ import annotations + +import logging +from typing import NamedTuple + +from PySide6.QtCore import QRect, QSize, Qt, Slot +from PySide6.QtGui import QColor, QFont, QPainter, QPen + +from app.config import ( + BBOX_COLOR, + BBOX_FONT_SIZE, + BBOX_LABEL_BG_COLOR, + BBOX_LABEL_TEXT_COLOR, + BBOX_LINE_WIDTH, +) +from app.overlay.overlay_layer import IOverlayLayer + +logger = logging.getLogger(__name__) + + +class 
class Detection(NamedTuple):
    """A single object detection result.

    Coordinates (x1, y1, x2, y2) are in pixels of the *source frame*
    (i.e. the frame that was submitted to inference).  BboxOverlay maps
    them onto the letterboxed video_rect before drawing.
    """

    x1: float
    y1: float
    x2: float
    y2: float
    conf: float
    label: str


class BboxOverlay(IOverlayLayer):
    """Overlay layer that renders detection bounding boxes.

    Usage:
        overlay = BboxOverlay()
        camera_view.add_overlay_layer(overlay)
        inference_manager.detections_ready.connect(overlay.on_detections)

    Thread safety:
        on_detections() is called from the GUI thread (via Qt signal).
        paint() is also called from the GUI thread (paintEvent).
        No locks required.
    """

    # Horizontal padding (px) added on each side of the label text.
    _LABEL_PAD_X = 3

    def __init__(self) -> None:
        super().__init__()
        self._detections: list[Detection] = []
        self._source_size: QSize = QSize(0, 0)

        self._box_color = QColor(*BBOX_COLOR)
        self._bg_color = QColor(*BBOX_LABEL_BG_COLOR)
        self._text_color = QColor(*BBOX_LABEL_TEXT_COLOR)

        self._pen = QPen(self._box_color)
        self._pen.setWidth(BBOX_LINE_WIDTH)
        self._pen.setJoinStyle(Qt.PenJoinStyle.MiterJoin)

        # Built once here instead of once per detection per repaint.
        self._text_pen = QPen(self._text_color)

        self._font = QFont("Monospace")
        self._font.setStyleHint(QFont.StyleHint.TypeWriter)
        self._font.setPointSize(BBOX_FONT_SIZE)
        self._font.setBold(True)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @Slot(object, object)
    def on_detections(
        self,
        detections: list[Detection],
        source_size: tuple[int, int],
    ) -> None:
        """Receive detection results from InferenceManager.

        Args:
            detections:  List of Detection namedtuples (pixel coords).
            source_size: (width, height) of the frame that was inferred.
        """
        # Copy so later mutation of the emitter's list cannot change what
        # is drawn between two repaints.
        self._detections = list(detections or ())
        self._source_size = QSize(*source_size)

    def clear(self) -> None:
        """Remove all currently displayed detections."""
        self._detections = []

    # ------------------------------------------------------------------
    # IOverlayLayer implementation
    # ------------------------------------------------------------------

    def paint(self, painter: QPainter, video_rect: QRect) -> None:
        """Draw every stored detection, scaled from source pixels to video_rect.

        Args:
            painter:    Active painter of the hosting widget.
            video_rect: Rectangle (widget coordinates) the video frame occupies.
        """
        if not self._detections:
            return
        if self._source_size.isEmpty() or video_rect.isEmpty():
            return

        # Scale factors: source-pixel -> video_rect-pixel
        scale_x = video_rect.width() / self._source_size.width()
        scale_y = video_rect.height() / self._source_size.height()
        left = video_rect.x()
        top = video_rect.y()

        # Isolate font/pen/brush changes from whatever is painted afterwards.
        painter.save()
        try:
            painter.setFont(self._font)
            fm = painter.fontMetrics()
            text_h = fm.height() + 2
            baseline = fm.ascent() + 1

            for det in self._detections:
                # Round (not truncate) and order the corners so an inverted
                # box never yields a negative-size rectangle.
                xa = left + round(det.x1 * scale_x)
                xb = left + round(det.x2 * scale_x)
                ya = top + round(det.y1 * scale_y)
                yb = top + round(det.y2 * scale_y)
                wx1, wx2 = min(xa, xb), max(xa, xb)
                wy1, wy2 = min(ya, yb), max(ya, yb)

                painter.setPen(self._pen)
                painter.setBrush(Qt.BrushStyle.NoBrush)
                painter.drawRect(QRect(wx1, wy1, wx2 - wx1, wy2 - wy1))

                # Label text: "label 0.87"
                label_text = f"{det.label} {det.conf:.2f}"
                text_w = fm.horizontalAdvance(label_text) + 2 * self._LABEL_PAD_X

                # Keep the label inside video_rect horizontally; the left
                # edge wins when the label is wider than the video area.
                lx = max(left, min(wx1, video_rect.right() - text_w + 1))
                ly = wy1 - text_h
                if ly < top:
                    ly = wy1  # draw inside box if no room above

                painter.setPen(Qt.PenStyle.NoPen)
                painter.setBrush(self._bg_color)
                painter.drawRect(QRect(lx, ly, text_w, text_h))

                painter.setPen(self._text_pen)
                painter.drawText(lx + self._LABEL_PAD_X, ly + baseline, label_text)
        finally:
            painter.restore()
"""YOLO inference worker — runs in a separate process.

This module contains only plain functions (no Qt, no PySide6) so it can
safely be imported and executed in a child process via multiprocessing.

IPC protocol
------------
input_queue  receives : FramePacket (frame_id, raw_bytes, width, height, channels)
output_queue sends    : ResultPacket (frame_id, detections, width, height),
                        or a single None if the model could not be loaded
stop_event            : multiprocessing.Event — set by parent to request clean exit

Detection format (namedtuple-compatible plain tuple):
    (x1, y1, x2, y2, conf, label)  — all floats/str, x/y in source-frame pixels
"""

import logging
import os
import platform
import sys
from functools import lru_cache
from multiprocessing import Event, Queue
from queue import Empty, Full
from typing import NamedTuple

logger = logging.getLogger(__name__)

# How long a blocking queue operation waits before stop_event is re-checked.
_QUEUE_TIMEOUT_S = 0.1


# ---------------------------------------------------------------------------
# Data structures shared between worker and manager
# ---------------------------------------------------------------------------

class FramePacket(NamedTuple):
    frame_id: int
    raw_bytes: bytes   # RGB bytes, row-major, shape = (height, width, channels)
    width: int
    height: int
    channels: int      # always 3 (RGB)


class ResultPacket(NamedTuple):
    frame_id: int
    detections: list   # list of (x1, y1, x2, y2, conf, label) tuples
    width: int         # source frame width (for overlay scaling)
    height: int        # source frame height


# ---------------------------------------------------------------------------
# Worker entry point
# ---------------------------------------------------------------------------

def run_worker(
    model_path: str,
    input_queue: Queue,
    output_queue: Queue,
    stop_event: Event,
    log_level: int = logging.WARNING,
) -> None:
    """Main loop of the inference worker process.

    Loads the YOLO model once, then processes frames from input_queue
    until stop_event is set.  Results are posted to output_queue.

    This function is designed to be the target of multiprocessing.Process.
    It must NOT import PySide6 or any Qt module.

    Args:
        model_path:   Path of the model file handed to ultralytics.YOLO.
        input_queue:  Source of FramePacket items.
        output_queue: Receives one ResultPacket per FramePacket, or a single
                      None if the model could not be loaded.
        stop_event:   Set by the parent to request a clean exit.
        log_level:    Logging level configured for this process.
    """
    _configure_worker_logging(log_level)
    logger.info("Inference worker starting (pid=%d)", _getpid())

    try:
        model = _load_model(model_path)
    except Exception:  # process boundary: report, never propagate
        logger.exception("Failed to load model '%s'", model_path)
        # Signal failure by putting None — manager treats it as error
        _put_result(output_queue, None, stop_event)
        return

    logger.info("Model loaded: %s", model_path)

    while not stop_event.is_set():
        try:
            packet: FramePacket = input_queue.get(timeout=_QUEUE_TIMEOUT_S)
        except Empty:
            continue
        except Exception:
            logger.exception("Error reading input queue")
            break

        try:
            result = _infer(model, packet)
        except Exception:
            logger.exception("Inference error (frame %d)", packet.frame_id)
            # Empty result so the manager knows we're still alive
            result = ResultPacket(
                frame_id=packet.frame_id,
                detections=[],
                width=packet.width,
                height=packet.height,
            )

        if not _put_result(output_queue, result, stop_event):
            break

    logger.info("Inference worker stopping")


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _put_result(output_queue, item, stop_event) -> bool:
    """Post *item*, giving up only when the queue stays full after a stop request.

    A plain blocking put() would hang forever if the parent stopped draining
    the queue, and stop_event would then never be observed.

    Returns:
        True if the item was enqueued, False if it was abandoned.
    """
    while True:
        try:
            output_queue.put(item, timeout=_QUEUE_TIMEOUT_S)
            return True
        except Full:
            if stop_event.is_set():
                return False


def _load_model(model_path: str):
    """Load the YOLO model and run one warm-up inference on the chosen device."""
    from ultralytics import YOLO  # noqa: PLC0415

    device = _select_device()
    logger.info("Loading YOLO model on device='%s'", device)
    model = YOLO(model_path)
    # Warm up — run on a tiny dummy so the first real frame is not the slow one
    try:
        import numpy as np  # noqa: PLC0415
        dummy = np.zeros((64, 64, 3), dtype=np.uint8)
        model(dummy, device=device, verbose=False)
    except Exception as exc:
        logger.warning("Warm-up failed (non-fatal): %s", exc)
    return model


@lru_cache(maxsize=1)
def _select_device() -> str:
    """Choose the inference device (resolved once per process, then cached).

    Priority:
      - macOS  → "mps" if available (Metal GPU), else "cpu"
      - others → "cpu"
    """
    if platform.system() == "Darwin":
        try:
            import torch  # noqa: PLC0415
            if torch.backends.mps.is_available():
                logger.info("MPS (Metal) available — using GPU")
                return "mps"
        except Exception as exc:
            logger.debug("MPS probe failed: %s", exc)
        logger.info("MPS not available — using CPU")
    return "cpu"


def _infer(model, packet: FramePacket) -> ResultPacket:
    """Run the model on one frame and return its detections.

    Raises:
        ValueError: if raw_bytes does not hold exactly height*width*channels bytes.
    """
    import numpy as np  # noqa: PLC0415

    expected = packet.height * packet.width * packet.channels
    if len(packet.raw_bytes) != expected:
        raise ValueError(
            f"frame {packet.frame_id}: got {len(packet.raw_bytes)} bytes, "
            f"expected {expected} for {packet.width}x{packet.height}x{packet.channels}"
        )

    frame = np.frombuffer(packet.raw_bytes, dtype=np.uint8).reshape(
        (packet.height, packet.width, packet.channels)
    )
    if packet.channels == 3:
        # Packets carry RGB, but Ultralytics interprets numpy input as BGR
        # (OpenCV convention) — reverse the channel axis before inference.
        frame = np.ascontiguousarray(frame[:, :, ::-1])

    results = model(frame, device=_select_device(), verbose=False)

    detections = []
    for r in results:
        boxes = r.boxes
        if boxes is None:
            continue
        names = r.names
        # One bulk conversion per tensor instead of one per box.
        rows = zip(boxes.xyxy.tolist(), boxes.conf.tolist(), boxes.cls.tolist())
        for (x1, y1, x2, y2), conf, cls in rows:
            cls_idx = int(cls)
            label = names[cls_idx] if names and cls_idx in names else str(cls_idx)
            detections.append(
                (float(x1), float(y1), float(x2), float(y2), float(conf), label)
            )

    return ResultPacket(
        frame_id=packet.frame_id,
        detections=detections,
        width=packet.width,
        height=packet.height,
    )


def _configure_worker_logging(level: int) -> None:
    """Send this process's log records to stderr, tagged with its pid."""
    logging.basicConfig(
        level=level,
        format="[worker %(process)d] %(levelname)s %(name)s: %(message)s",
        stream=sys.stderr,
    )


def _getpid() -> int:
    """Return the current process id."""
    return os.getpid()
"""InferenceManager — orchestrates the YOLO worker process from the GUI thread.

Responsibilities:
  - Start / stop the worker process
  - Submit frames (with drop-if-busy logic)
  - Poll result queue via QTimer (never blocks the GUI thread)
  - Watch process health via QTimer (auto-restart on crash)
  - Emit Qt signals with results for BboxOverlay
"""

from __future__ import annotations

import logging
import multiprocessing
import time
from pathlib import Path
from queue import Empty, Full

from PySide6.QtCore import QObject, QTimer, Signal, Slot
from PySide6.QtGui import QImage
from PySide6.QtMultimedia import QVideoFrame

from app.config import (
    INFERENCE_MAX_RESTARTS,
    INFERENCE_POLL_INTERVAL_MS,
    INFERENCE_WATCHDOG_INTERVAL_MS,
    INFERENCE_WORKER_TIMEOUT_S,
)
from app.inference.bbox_overlay import Detection
from app.inference.worker import FramePacket, ResultPacket, run_worker

logger = logging.getLogger(__name__)


class InferenceManager(QObject):
    """Manages the YOLO worker subprocess.

    Signals:
        detections_ready(detections, source_size)
            Emitted in the GUI thread when a result arrives.
            detections  : list[Detection]
            source_size : tuple[int, int]  — (width, height) of inferred frame

        inference_started()  — worker process has been spawned (the model may
                               still be loading inside the worker)
        inference_stopped()  — a running worker has been shut down
        inference_error(str) — fatal error (model failed to load, or max
                               restarts exceeded)

    Usage:
        mgr = InferenceManager(parent=self)
        mgr.detections_ready.connect(bbox_overlay.on_detections)
        mgr.start("path/to/model.pt")
        # ...
        mgr.submit_frame(video_frame)   # called by FrameDispatcher subscriber
        # ...
        mgr.stop()
    """

    detections_ready = Signal(object, object)   # list[Detection], tuple[int,int]
    inference_started = Signal()
    inference_stopped = Signal()
    inference_error = Signal(str)

    def __init__(self, parent: QObject | None = None) -> None:
        super().__init__(parent)

        self._model_path: str | None = None
        self._process: multiprocessing.Process | None = None
        self._input_queue: multiprocessing.Queue | None = None
        self._output_queue: multiprocessing.Queue | None = None
        self._stop_event: multiprocessing.Event | None = None

        # Drop-if-busy flag — True while worker is processing a frame
        self._busy: bool = False
        self._frame_id: int = 0

        # Restart tracking.  _last_result_time is refreshed on every submit
        # AND every result, so the watchdog measures one frame's latency.
        self._restart_count: int = 0
        self._last_result_time: float = 0.0

        # Paused flag — inference can be suspended without stopping the process
        self._paused: bool = False

        # QTimers (GUI thread)
        self._poll_timer = QTimer(self)
        self._poll_timer.setInterval(INFERENCE_POLL_INTERVAL_MS)
        self._poll_timer.timeout.connect(self._poll_output)

        self._watchdog_timer = QTimer(self)
        self._watchdog_timer.setInterval(INFERENCE_WATCHDOG_INTERVAL_MS)
        self._watchdog_timer.timeout.connect(self._watchdog_check)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def start(self, model_path: str) -> None:
        """Start (or restart) the worker process for *model_path*.

        Emits inference_error and leaves any running worker untouched when
        the path does not exist.
        """
        if not Path(model_path).exists():
            msg = f"Model file not found: {model_path}"
            logger.error(msg)
            self.inference_error.emit(msg)
            return

        self._stop_worker()
        self._model_path = model_path
        self._restart_count = 0
        self._paused = False
        self._start_worker()

    def stop(self) -> None:
        """Stop the worker process and reset state."""
        self._stop_worker()
        self._model_path = None
        self._restart_count = 0
        self._paused = False

    def pause(self) -> None:
        """Suspend frame submission without stopping the process."""
        self._paused = True
        logger.debug("InferenceManager: paused")

    def resume(self) -> None:
        """Resume frame submission."""
        self._paused = False
        logger.debug("InferenceManager: resumed")

    @property
    def is_running(self) -> bool:
        return self._process is not None and self._process.is_alive()

    @property
    def is_paused(self) -> bool:
        return self._paused

    @property
    def model_path(self) -> str | None:
        return self._model_path

    @Slot(QVideoFrame)
    def submit_frame(self, frame: QVideoFrame) -> None:
        """Attempt to submit a frame for inference.

        Drops the frame silently if:
          - manager is not running
          - manager is paused
          - worker is still busy with previous frame (drop_if_busy)
        """
        if not self.is_running or self._paused or self._busy:
            return
        if self._input_queue is None or not frame.isValid():
            return

        converted = self._frame_to_rgb_bytes(frame)
        if converted is None:
            return
        raw, width, height = converted

        self._frame_id += 1
        packet = FramePacket(
            frame_id=self._frame_id,
            raw_bytes=raw,
            width=width,
            height=height,
            channels=3,
        )

        try:
            self._input_queue.put_nowait(packet)
        except (Full, ValueError, OSError) as exc:
            # Full: worker has not picked up the previous packet yet.
            # ValueError / OSError: queue already closed or pipe broken.
            logger.warning("InferenceManager: could not enqueue frame: %s", exc)
            return

        self._busy = True
        # Start the watchdog clock at submission; otherwise the first frame
        # after an idle or paused period looks like an instant timeout.
        self._last_result_time = time.monotonic()
        logger.debug("InferenceManager: submitted frame %d", self._frame_id)

    # ------------------------------------------------------------------
    # Private — frame conversion
    # ------------------------------------------------------------------

    @staticmethod
    def _frame_to_rgb_bytes(frame: QVideoFrame) -> tuple[bytes, int, int] | None:
        """Convert *frame* to tightly packed RGB888 bytes.

        Qt performs the pixel-format conversion, so the source format and
        row padding do not have to be guessed from the byte count.

        Returns:
            (raw_bytes, width, height), or None if the frame cannot be converted.
        """
        image = frame.toImage()
        if image.isNull():
            logger.warning("InferenceManager: could not convert QVideoFrame to an image")
            return None

        image = image.convertToFormat(QImage.Format.Format_RGB888)
        width = image.width()
        height = image.height()
        row_bytes = width * 3
        stride = image.bytesPerLine()   # rows may carry trailing padding
        if width <= 0 or height <= 0 or stride < row_bytes:
            return None

        buffer = memoryview(bytes(image.constBits()))   # copy out of the QImage
        if len(buffer) < stride * height:
            logger.warning(
                "Unexpected frame size: %d bytes for %dx%d",
                len(buffer), width, height,
            )
            return None

        if stride == row_bytes:
            return bytes(buffer[: row_bytes * height]), width, height

        # Drop the per-row padding.
        packed = b"".join(
            buffer[offset:offset + row_bytes]
            for offset in range(0, stride * height, stride)
        )
        return packed, width, height

    # ------------------------------------------------------------------
    # Private — worker lifecycle
    # ------------------------------------------------------------------

    def _start_worker(self) -> None:
        """Create fresh IPC primitives, spawn the worker and start both timers."""
        ctx = multiprocessing.get_context("spawn")
        self._input_queue = ctx.Queue(maxsize=1)
        self._output_queue = ctx.Queue(maxsize=4)
        self._stop_event = ctx.Event()

        self._process = ctx.Process(
            target=run_worker,
            args=(
                self._model_path,
                self._input_queue,
                self._output_queue,
                self._stop_event,
                logging.WARNING,
            ),
            daemon=True,
            name="inference-worker",
        )
        self._process.start()
        self._busy = False
        self._last_result_time = time.monotonic()

        self._poll_timer.start()
        self._watchdog_timer.start()
        logger.info(
            "Inference worker started (pid=%d, model=%s)",
            self._process.pid, self._model_path,
        )
        self.inference_started.emit()

    def _stop_worker(self) -> None:
        """Shut the worker down; emits inference_stopped only if one existed."""
        self._poll_timer.stop()
        self._watchdog_timer.stop()

        if self._stop_event is not None:
            self._stop_event.set()

        had_worker = self._process is not None
        self._reap_process(graceful_timeout_s=3.0)
        self._release_queues()
        self._stop_event = None
        self._busy = False

        if had_worker:
            logger.info("Inference worker stopped")
            self.inference_stopped.emit()

    def _reap_process(self, graceful_timeout_s: float) -> None:
        """Join the worker, terminating it if it outlives *graceful_timeout_s*."""
        process, self._process = self._process, None
        if process is None:
            return
        process.join(timeout=graceful_timeout_s)
        if process.is_alive():
            logger.warning("Worker did not stop cleanly — terminating")
            process.terminate()
            process.join(timeout=2.0)

    def _release_queues(self) -> None:
        """Close both queues without waiting for unflushed data.

        cancel_join_thread() keeps a feeder thread that still holds a large
        frame for a dead worker from blocking interpreter shutdown.
        """
        for q in (self._input_queue, self._output_queue):
            if q is None:
                continue
            try:
                q.cancel_join_thread()
                q.close()
            except (OSError, ValueError) as exc:
                logger.debug("InferenceManager: queue cleanup failed: %s", exc)
        self._input_queue = None
        self._output_queue = None

    # ------------------------------------------------------------------
    # Private — timers
    # ------------------------------------------------------------------

    @Slot()
    def _poll_output(self) -> None:
        """Drain the output queue (called every INFERENCE_POLL_INTERVAL_MS ms)."""
        out_queue = self._output_queue
        if out_queue is None:
            return

        while True:
            try:
                item = out_queue.get_nowait()
            except Empty:
                return   # nothing pending — the normal case
            except (EOFError, OSError, ValueError) as exc:
                # Broken or closed pipe; the watchdog handles the restart.
                logger.warning("InferenceManager: output queue unreadable: %s", exc)
                return

            if item is None:
                # Worker signalled a fatal load error.  Reloading the same
                # file would fail the same way, so do not auto-restart.
                logger.error("Worker reported model load failure")
                self._handle_crash(
                    "Model failed to load in worker process", restart=False
                )
                return

            packet: ResultPacket = item
            self._busy = False
            self._last_result_time = time.monotonic()

            detections = [Detection(*fields) for fields in packet.detections]
            source_size = (packet.width, packet.height)

            logger.debug(
                "InferenceManager: frame %d → %d detections",
                packet.frame_id, len(detections),
            )
            self.detections_ready.emit(detections, source_size)

            if self._output_queue is not out_queue:
                return   # a connected slot stopped or restarted the worker

    @Slot()
    def _watchdog_check(self) -> None:
        """Detect crashed or hung worker process."""
        process = self._process
        if process is None:
            return

        # Process died unexpectedly
        if not process.is_alive():
            # Pick up a final result / load-failure marker it may have left.
            self._poll_output()
            if self._process is not process:
                return   # already handled by _poll_output
            exit_code = process.exitcode
            logger.error("Worker process died (exitcode=%s)", exit_code)
            self._handle_crash(f"Worker process exited with code {exit_code}")
            return

        # Worker alive but hasn't responded for too long (hung during inference)
        if not self._busy:
            return
        elapsed = time.monotonic() - self._last_result_time
        if elapsed > INFERENCE_WORKER_TIMEOUT_S:
            logger.error(
                "Worker timeout: no response for %.1f s — restarting", elapsed
            )
            process.terminate()
            process.join(timeout=2.0)
            self._handle_crash("Worker timed out (hung during inference)")

    def _handle_crash(self, reason: str, *, restart: bool = True) -> None:
        """Clean up the failed worker, then auto-restart or give up.

        Args:
            reason:  Human-readable cause, used in logs and inference_error.
            restart: False for failures that a restart cannot fix.
        """
        self._poll_timer.stop()
        self._watchdog_timer.stop()
        if self._stop_event is not None:
            self._stop_event.set()
        # The process is normally already dead; make sure it is reaped and
        # its pipes are released before a replacement is created.
        self._reap_process(graceful_timeout_s=0.5)
        self._release_queues()
        self._stop_event = None
        self._busy = False

        can_restart = (
            restart
            and self._model_path is not None
            and self._restart_count < INFERENCE_MAX_RESTARTS
        )
        if can_restart:
            self._restart_count += 1
            logger.warning(
                "Auto-restarting worker (attempt %d/%d): %s",
                self._restart_count, INFERENCE_MAX_RESTARTS, reason,
            )
            self._start_worker()
            return

        if restart:
            msg = (
                f"Inference worker failed after {INFERENCE_MAX_RESTARTS} restarts. "
                f"Last error: {reason}"
            )
        else:
            msg = reason
        logger.error(msg)
        self.inference_error.emit(msg)
+ Frame source (exclusive): + • CameraService — live camera (default) + • VideoPlayer — local video file + + Inference pipeline (optional): + InferenceManager runs YOLO in a separate process. + Frames submitted via FrameDispatcher subscriber (drop_if_busy). + Results displayed by BboxOverlay. Signal flow: - CameraService.frame_ready + [CameraService | VideoPlayer].frame_ready(QVideoFrame) → FrameDispatcher.dispatch - → CameraView.on_frame (render frame) - → TelemetryCollector.on_frame (measure metrics) - → TelemetryOverlay.on_metrics_updated (overlay data) - → CsvTelemetryLogger.on_metrics_updated (CSV file) + → CameraView.on_frame (render) + → TelemetryCollector.on_frame (metrics) + → TelemetryOverlay (HUD) + → CsvTelemetryLogger (CSV) + → InferenceManager.submit_frame (drop_if_busy, optional) + → [worker process] YOLO + → BboxOverlay.on_detections (draw boxes) """ def __init__(self, log_path: Path | None = None) -> None: @@ -51,22 +60,28 @@ class MainWindow(QMainWindow): self.setMinimumSize(640, 480) self.resize(1280, 720) - # --- Core pipeline components --- + # --- Core pipeline --- self._camera_service = CameraService(self) + self._video_player = VideoPlayer(self) self._dispatcher = FrameDispatcher(self) self._telemetry = TelemetryCollector(parent=self) + self._inference = InferenceManager(self) - # --- UVC controller (platform-specific, lazy-opened per camera) --- + # Track which source is active + self._video_source_active: bool = False + self._current_camera: CameraInfo | None = None + + # --- UVC --- self._uvc: UvcControllerBase = NullUvcController() - # --- CSV telemetry logger --- + # --- CSV logger --- self._csv_logger: CsvTelemetryLogger | None = None if log_path is not None: csv_path = log_path.with_suffix(".csv") self._csv_logger = CsvTelemetryLogger(csv_path) logger.info("Telemetry CSV: %s", csv_path.resolve()) - # --- Camera view (central widget) --- + # --- Camera view --- self._camera_view = CameraView(self) self._camera_view.setSizePolicy( 
QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding @@ -75,7 +90,10 @@ class MainWindow(QMainWindow): # --- Overlay layers --- self._telemetry_overlay = TelemetryOverlay() + self._bbox_overlay = BboxOverlay() self._camera_view.add_overlay_layer(self._telemetry_overlay) + self._camera_view.add_overlay_layer(self._bbox_overlay) + self._bbox_overlay.visible = False # hidden until inference enabled # --- Menu bar --- self._menu = AppMenuBar(self) @@ -92,7 +110,6 @@ class MainWindow(QMainWindow): # --- Wire signals --- self._wire_signals() - # --- Enumerate cameras and start --- QTimer.singleShot(0, self._initialise_cameras) # ------------------------------------------------------------------ @@ -101,21 +118,19 @@ class MainWindow(QMainWindow): def _initialise_cameras(self) -> None: cameras = CameraEnumerator.list_cameras() - if not cameras: self._status_label.setText("No cameras found") logger.warning("No cameras detected") return self._menu.populate_cameras(cameras) - default = CameraEnumerator.default_camera() start_cam = default if default is not None else cameras[0] - self._menu.populate_formats(start_cam) self._start_camera(start_cam) def _start_camera(self, cam: CameraInfo) -> None: + self._current_camera = cam self._telemetry.reset_counters() self._camera_service.start(cam) self._menu.set_active_camera(cam) @@ -123,12 +138,10 @@ class MainWindow(QMainWindow): self._open_uvc(cam) def _open_uvc(self, cam: CameraInfo) -> None: - """Open or reopen the UVC controller for the given camera.""" if self._uvc.is_open(): self._uvc.close() ctrl = make_uvc_controller(cam.name) if not ctrl.is_open(): - # factory may return a pre-opened controller or a NullUvcController ctrl.open(cam.name) self._uvc = ctrl @@ -137,38 +150,73 @@ class MainWindow(QMainWindow): # ------------------------------------------------------------------ def _wire_signals(self) -> None: - # CameraService → FrameDispatcher + # ---- Active source → dispatcher ---- + # (connected dynamically in 
_switch_to_camera / _switch_to_video) self._camera_service.frame_ready.connect(self._dispatcher.dispatch) - # FrameDispatcher → CameraView (render) — drop if busy + # ---- Dispatcher fans out to all consumers ---- self._dispatcher.subscribe(self._camera_view.on_frame, drop_if_busy=True) - - # FrameDispatcher → TelemetryCollector — never drop self._dispatcher.subscribe(self._telemetry.on_frame, drop_if_busy=False) + # InferenceManager subscriber added/removed dynamically on toggle - # TelemetryCollector → overlay + # ---- Telemetry ---- self._telemetry.metrics_updated.connect( self._telemetry_overlay.on_metrics_updated ) - - # TelemetryCollector → CSV logger (throttled internally) if self._csv_logger is not None: self._telemetry.metrics_updated.connect(self._csv_logger.on_metrics_updated) - - # CameraService → TelemetryCollector: keep target FPS in sync self._camera_service.format_changed.connect(self._telemetry.set_target_fps) - # CameraService status + # ---- Camera service status ---- self._camera_service.camera_started.connect(self._on_camera_started) self._camera_service.camera_stopped.connect(self._on_camera_stopped) self._camera_service.camera_error.connect(self._on_camera_error) - # Menu signals + # ---- Video player status ---- + self._video_player.playback_started.connect(self._on_playback_started) + self._video_player.playback_stopped.connect(self._on_playback_stopped) + self._video_player.playback_error.connect(self._on_playback_error) + + # ---- InferenceManager ---- + self._inference.detections_ready.connect(self._bbox_overlay.on_detections) + self._inference.inference_started.connect(self._on_inference_started) + self._inference.inference_stopped.connect(self._on_inference_stopped) + self._inference.inference_error.connect(self._on_inference_error) + + # ---- Menu ---- self._menu.camera_selected.connect(self._on_camera_selected) self._menu.format_selected.connect(self._on_format_selected) 
self._menu.reconnect_requested.connect(self._camera_service.reconnect) self._menu.overlay_toggled.connect(self._camera_view.set_all_overlays_visible) self._menu.camera_settings_requested.connect(self._on_settings_requested) + self._menu.video_file_selected.connect(self._on_video_selected) + self._menu.video_closed.connect(self._on_video_closed) + self._menu.model_file_selected.connect(self._on_model_selected) + self._menu.inference_toggled.connect(self._on_inference_toggled) + + # ------------------------------------------------------------------ + # Source switching + # ------------------------------------------------------------------ + + def _switch_to_camera(self) -> None: + """Disconnect VideoPlayer, connect CameraService to dispatcher.""" + try: + self._video_player.frame_ready.disconnect(self._dispatcher.dispatch) + except RuntimeError: + pass + self._camera_service.frame_ready.connect(self._dispatcher.dispatch) + self._video_source_active = False + self._menu.set_video_source_active(False) + + def _switch_to_video(self) -> None: + """Disconnect CameraService, connect VideoPlayer to dispatcher.""" + try: + self._camera_service.frame_ready.disconnect(self._dispatcher.dispatch) + except RuntimeError: + pass + self._video_player.frame_ready.connect(self._dispatcher.dispatch) + self._video_source_active = True + self._menu.set_video_source_active(True) # ------------------------------------------------------------------ # Camera status slots @@ -187,11 +235,48 @@ class MainWindow(QMainWindow): self._status_label.setText(f"Error: {message}") logger.error("Camera error: %s", message) + # ------------------------------------------------------------------ + # Video player slots + # ------------------------------------------------------------------ + + def _on_playback_started(self) -> None: + path = self._video_player.current_path or "" + name = Path(path).name if path else "video" + self._status_label.setText(f"Playing: {name}") + + def _on_playback_stopped(self) 
-> None: + self._status_label.setText("Playback finished") + + def _on_playback_error(self, message: str) -> None: + self._status_label.setText(f"Video error: {message}") + logger.error(message) + + # ------------------------------------------------------------------ + # Inference slots + # ------------------------------------------------------------------ + + def _on_inference_started(self) -> None: + self._status_label.setText("Inference running") + self._menu.set_inference_checked(True) + + def _on_inference_stopped(self) -> None: + self._bbox_overlay.clear() + + def _on_inference_error(self, message: str) -> None: + logger.error("Inference: %s", message) + self._menu.set_inference_available(False) + self._menu.set_inference_checked(False) + self._bbox_overlay.visible = False + QMessageBox.critical(self, "Inference Error", message) + # ------------------------------------------------------------------ # Menu action slots # ------------------------------------------------------------------ def _on_camera_selected(self, cam: CameraInfo) -> None: + if self._video_source_active: + self._video_player.stop() + self._switch_to_camera() self._start_camera(cam) def _on_format_selected(self, fmt: CameraFormat) -> None: @@ -209,12 +294,61 @@ class MainWindow(QMainWindow): dlg = CameraSettingsDialog(qt_cam, self._uvc, parent=self) dlg.exec() + def _on_video_selected(self, path: str) -> None: + """Switch source to video file.""" + self._camera_service.stop() + self._switch_to_video() + self._video_player.play(path) + logger.info("Video source: %s", path) + + def _on_video_closed(self) -> None: + """Return to camera source.""" + self._video_player.stop() + self._switch_to_camera() + if self._current_camera is not None: + self._start_camera(self._current_camera) + logger.info("Returned to camera source") + + def _on_model_selected(self, path: str) -> None: + """Load YOLO model into inference manager.""" + name = Path(path).name + logger.info("Loading model: %s", path) + 
self._status_label.setText(f"Loading model: {name}\u2026") + self._inference.start(path) + self._menu.set_model_label(name) + self._menu.set_inference_available(True) + self._menu.set_inference_checked(False) # user must explicitly enable + + def _on_inference_toggled(self, enabled: bool) -> None: + if enabled: + if not self._inference.is_running: + # shouldn't happen but be safe + logger.warning("Inference toggle on but manager not running") + self._menu.set_inference_checked(False) + return + self._inference.resume() + self._dispatcher.subscribe( + self._inference.submit_frame, drop_if_busy=True + ) + self._bbox_overlay.visible = True + self._status_label.setText("Inference enabled") + logger.info("Inference enabled") + else: + self._inference.pause() + self._dispatcher.unsubscribe(self._inference.submit_frame) + self._bbox_overlay.clear() + self._bbox_overlay.visible = False + self._status_label.setText("Inference disabled") + logger.info("Inference disabled") + # ------------------------------------------------------------------ # Qt overrides # ------------------------------------------------------------------ def closeEvent(self, event) -> None: # noqa: N802 + self._inference.stop() self._camera_service.stop() + self._video_player.stop() if self._uvc.is_open(): self._uvc.close() if self._csv_logger is not None: diff --git a/app/ui/menu_bar.py b/app/ui/menu_bar.py index fa65051..22bd635 100644 --- a/app/ui/menu_bar.py +++ b/app/ui/menu_bar.py @@ -1,4 +1,4 @@ -"""Menu bar — camera, video format and debug controls.""" +"""Menu bar — File, Camera, Video format, Image, Model and Debug controls.""" from __future__ import annotations @@ -6,9 +6,10 @@ import logging from PySide6.QtCore import Signal from PySide6.QtGui import QAction, QActionGroup -from PySide6.QtWidgets import QMenuBar, QWidget +from PySide6.QtWidgets import QFileDialog, QMenuBar, QWidget from app.camera.camera_enumerator import CameraFormat, CameraInfo +from app.config import MODEL_FILE_EXTENSIONS, 
VIDEO_FILE_EXTENSIONS from app.logging_setup import set_console_level logger = logging.getLogger(__name__) @@ -19,17 +20,32 @@ class AppMenuBar(QMenuBar): Application menu bar. Signals: - camera_selected(CameraInfo) — user picked a camera - format_selected(CameraFormat) — user picked a full format (res+fps+pixel) - reconnect_requested() — user hit Reconnect - overlay_toggled(bool) — overlay show/hide - log_toggled(bool) — console logging on/off - camera_settings_requested() — user opened Image Settings dialog + video_file_selected(str) — user picked a video file path + video_closed() — user chose to close video and return to camera + model_file_selected(str) — user picked a .pt model file path + inference_toggled(bool) — user toggled inference on/off + camera_selected(CameraInfo) + format_selected(CameraFormat) + reconnect_requested() + overlay_toggled(bool) + log_toggled(bool) + camera_settings_requested() """ + # File / video + video_file_selected = Signal(str) + video_closed = Signal() + + # Model / inference + model_file_selected = Signal(str) + inference_toggled = Signal(bool) + + # Camera camera_selected = Signal(object) # CameraInfo format_selected = Signal(object) # CameraFormat reconnect_requested = Signal() + + # View / debug overlay_toggled = Signal(bool) log_toggled = Signal(bool) camera_settings_requested = Signal() @@ -48,7 +64,6 @@ class AppMenuBar(QMenuBar): # ------------------------------------------------------------------ def populate_cameras(self, cameras: list[CameraInfo]) -> None: - """Populate the Camera menu with discovered devices.""" self._cameras = cameras menu = self._camera_menu @@ -71,7 +86,6 @@ class AppMenuBar(QMenuBar): self._camera_group.actions()[0].setChecked(True) def populate_formats(self, camera_info: CameraInfo) -> None: - """Populate the Resolution submenu with full format entries.""" self._populate_format_menu(camera_info) def set_active_camera(self, camera_info: CameraInfo) -> None: @@ -83,7 +97,6 @@ class 
AppMenuBar(QMenuBar): return def set_active_format(self, fmt: CameraFormat) -> None: - """Mark the given format as checked in the Resolution menu.""" if self._format_group is None: return for action in self._format_group.actions(): @@ -98,34 +111,80 @@ class AppMenuBar(QMenuBar): return def set_log_file_path(self, path: str) -> None: - """Display the log file path as a disabled menu item in Debug menu.""" display = path if len(path) <= 60 else "\u2026" + path[-57:] self._log_file_action.setText(f"Log: {display}") self._log_file_action.setToolTip(path) + def set_video_source_active(self, is_video: bool) -> None: + """Update File menu state when source switches between camera and video.""" + self._close_video_action.setEnabled(is_video) + + def set_inference_available(self, available: bool) -> None: + """Enable/disable the inference toggle (requires model to be loaded).""" + self._inference_toggle_action.setEnabled(available) + + def set_inference_checked(self, checked: bool) -> None: + self._inference_toggle_action.setChecked(checked) + + def set_model_label(self, name: str) -> None: + """Show loaded model name as disabled info item.""" + self._model_info_action.setText(f"Model: {name}") + # ------------------------------------------------------------------ # Menu construction # ------------------------------------------------------------------ def _build_menus(self) -> None: - # Camera menu + # --- File menu --- + file_menu = self.addMenu("File") + + open_video_action = QAction("Open Video\u2026", self) + open_video_action.triggered.connect(self._on_open_video) + file_menu.addAction(open_video_action) + + self._close_video_action = QAction("Close Video", self) + self._close_video_action.setEnabled(False) + self._close_video_action.triggered.connect(self.video_closed) + file_menu.addAction(self._close_video_action) + + # --- Camera menu --- self._camera_menu = self.addMenu("Camera") self._cam_separator = self._camera_menu.addSeparator() self._reconnect_action = 
QAction("Reconnect", self) self._reconnect_action.triggered.connect(self.reconnect_requested) self._camera_menu.addAction(self._reconnect_action) - # Video menu + # --- Video menu --- self._video_menu = self.addMenu("Video") self._res_menu = self._video_menu.addMenu("Resolution") - # Image menu (camera controls) + # --- Image menu --- self._image_menu = self.addMenu("Image") self._settings_action = QAction("Camera Settings\u2026", self) self._settings_action.triggered.connect(self.camera_settings_requested) self._image_menu.addAction(self._settings_action) - # Debug menu + # --- Model menu --- + model_menu = self.addMenu("Model") + + load_model_action = QAction("Load Model\u2026", self) + load_model_action.triggered.connect(self._on_load_model) + model_menu.addAction(load_model_action) + + self._inference_toggle_action = QAction("Enable Inference", self) + self._inference_toggle_action.setCheckable(True) + self._inference_toggle_action.setChecked(False) + self._inference_toggle_action.setEnabled(False) # enabled after model loaded + self._inference_toggle_action.toggled.connect(self.inference_toggled) + model_menu.addAction(self._inference_toggle_action) + + model_menu.addSeparator() + + self._model_info_action = QAction("Model: (none)", self) + self._model_info_action.setEnabled(False) + model_menu.addAction(self._model_info_action) + + # --- Debug menu --- debug_menu = self.addMenu("Debug") self._overlay_action = QAction("Show Overlay", self) @@ -147,7 +206,6 @@ class AppMenuBar(QMenuBar): debug_menu.addAction(self._log_file_action) def _populate_format_menu(self, camera_info: CameraInfo) -> None: - """Build Resolution submenu: one action per unique (W, H, FPS, pixel_format).""" self._res_menu.clear() self._format_group = QActionGroup(self) self._format_group.setExclusive(True) @@ -173,6 +231,28 @@ class AppMenuBar(QMenuBar): # Slots # ------------------------------------------------------------------ + def _on_open_video(self) -> None: + path, _ = 
QFileDialog.getOpenFileName( + self.parentWidget(), + "Open Video File", + "", + VIDEO_FILE_EXTENSIONS, + ) + if path: + logger.debug("Video file selected: %s", path) + self.video_file_selected.emit(path) + + def _on_load_model(self) -> None: + path, _ = QFileDialog.getOpenFileName( + self.parentWidget(), + "Load YOLO Model", + "", + MODEL_FILE_EXTENSIONS, + ) + if path: + logger.debug("Model file selected: %s", path) + self.model_file_selected.emit(path) + def _on_camera_action(self) -> None: action = self.sender() if action is None: diff --git a/app/video/__init__.py b/app/video/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/video/video_player.py b/app/video/video_player.py new file mode 100644 index 0000000..507a360 --- /dev/null +++ b/app/video/video_player.py @@ -0,0 +1,117 @@ +"""VideoPlayer — plays a local video file and delivers frames via frame_ready signal. + +The public interface mirrors CameraService so MainWindow can treat both +interchangeably: both emit frame_ready(QVideoFrame). +""" + +from __future__ import annotations + +import logging +from pathlib import Path + +from PySide6.QtCore import QObject, QUrl, Signal, Slot +from PySide6.QtMultimedia import ( + QMediaPlayer, + QVideoFrame, + QVideoSink, +) + +logger = logging.getLogger(__name__) + + +class VideoPlayer(QObject): + """ + Wraps QMediaPlayer + QVideoSink to replay a local video file. + + Signal flow (identical interface to CameraService): + VideoPlayer.frame_ready(QVideoFrame) → FrameDispatcher + + Notes: + - Playback is real-time (1×) — no seek/pause in this version. + - At end-of-file: emits playback_stopped() and stops. + - On any error: emits playback_error(str) then playback_stopped(). 
+ """ + + frame_ready = Signal(QVideoFrame) + playback_started = Signal() + playback_stopped = Signal() + playback_error = Signal(str) + + def __init__(self, parent: QObject | None = None) -> None: + super().__init__(parent) + + self._player = QMediaPlayer(self) + self._sink = QVideoSink(self) + + self._player.setVideoSink(self._sink) + + self._sink.videoFrameChanged.connect(self._on_frame) + self._player.playbackStateChanged.connect(self._on_playback_state_changed) + self._player.errorOccurred.connect(self._on_error) + + self._current_path: str | None = None + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def play(self, path: str) -> None: + """Open and start playing a video file.""" + self.stop() + p = Path(path) + if not p.exists(): + msg = f"Video file not found: {path}" + logger.error(msg) + self.playback_error.emit(msg) + return + + self._current_path = path + url = QUrl.fromLocalFile(str(p.resolve())) + self._player.setSource(url) + self._player.play() + logger.info("VideoPlayer: starting playback of '%s'", p.name) + + def stop(self) -> None: + """Stop playback and clear source.""" + if self._player.playbackState() != QMediaPlayer.PlaybackState.StoppedState: + self._player.stop() + self._player.setSource(QUrl()) + self._current_path = None + + @property + def is_playing(self) -> bool: + return ( + self._player.playbackState() + == QMediaPlayer.PlaybackState.PlayingState + ) + + @property + def current_path(self) -> str | None: + return self._current_path + + # ------------------------------------------------------------------ + # Private slots + # ------------------------------------------------------------------ + + @Slot(QVideoFrame) + def _on_frame(self, frame: QVideoFrame) -> None: + if frame.isValid(): + self.frame_ready.emit(frame) + + @Slot(QMediaPlayer.PlaybackState) + def _on_playback_state_changed(self, state: 
QMediaPlayer.PlaybackState) -> None: + if state == QMediaPlayer.PlaybackState.PlayingState: + logger.info("VideoPlayer: playing") + self.playback_started.emit() + elif state == QMediaPlayer.PlaybackState.StoppedState: + logger.info("VideoPlayer: stopped") + self.playback_stopped.emit() + + @Slot(QMediaPlayer.Error, str) + def _on_error(self, error: QMediaPlayer.Error, error_string: str) -> None: + if error == QMediaPlayer.Error.NoError: + return + msg = f"VideoPlayer error: {error_string}" + logger.error(msg) + self.playback_error.emit(msg) + self.playback_stopped.emit() diff --git a/models/.gitkeep b/models/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/models/.gitkeep @@ -0,0 +1 @@ + diff --git a/models/best_v1.pt b/models/best_v1.pt new file mode 100644 index 0000000..6b6603f Binary files /dev/null and b/models/best_v1.pt differ diff --git a/pyproject.toml b/pyproject.toml index 30de526..a0368d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,14 @@ dependencies = [ "psutil>=6.0", ] +[project.optional-dependencies] +# Install inference support: pip install -e ".[inference]" +# or: pip install ultralytics numpy +inference = [ + "ultralytics>=8.0", + "numpy>=1.24", +] + [project.scripts] duck-preview = "app.main:main" diff --git a/tests/test_bbox_overlay.py b/tests/test_bbox_overlay.py new file mode 100644 index 0000000..f6c6ff3 --- /dev/null +++ b/tests/test_bbox_overlay.py @@ -0,0 +1,180 @@ +"""Tests for BboxOverlay — coordinate mapping and state management.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from PySide6.QtCore import QRect, QSize + +from app.inference.bbox_overlay import BboxOverlay, Detection + + +class TestDetection: + def test_namedtuple_fields(self) -> None: + d = Detection(x1=10.0, y1=20.0, x2=100.0, y2=200.0, conf=0.87, label="label") + assert d.x1 == 10.0 + assert d.label == "label" + assert d.conf == pytest.approx(0.87) + + def test_immutable(self) -> 
None: + d = Detection(0, 0, 1, 1, 0.5, "x") + with pytest.raises(AttributeError): + d.conf = 0.9 # type: ignore[misc] + + +class TestBboxOverlayState: + def setup_method(self) -> None: + self.overlay = BboxOverlay() + + def test_initially_no_detections(self) -> None: + assert self.overlay._detections == [] + + def test_initially_source_size_empty(self) -> None: + assert self.overlay._source_size.isEmpty() + + def test_on_detections_stores_data(self) -> None: + dets = [Detection(0, 0, 100, 100, 0.9, "label")] + self.overlay.on_detections(dets, (640, 480)) + assert self.overlay._detections == dets + assert self.overlay._source_size == QSize(640, 480) + + def test_clear_removes_detections(self) -> None: + self.overlay.on_detections([Detection(0, 0, 10, 10, 0.5, "x")], (100, 100)) + self.overlay.clear() + assert self.overlay._detections == [] + + def test_visible_by_default(self) -> None: + assert self.overlay.visible is True + + def test_multiple_detections_stored(self) -> None: + dets = [ + Detection(0, 0, 50, 50, 0.9, "label"), + Detection(100, 100, 200, 200, 0.75, "label"), + ] + self.overlay.on_detections(dets, (640, 480)) + assert len(self.overlay._detections) == 2 + + def test_replace_detections_on_new_call(self) -> None: + self.overlay.on_detections([Detection(0, 0, 10, 10, 0.5, "x")], (100, 100)) + self.overlay.on_detections([], (100, 100)) + assert self.overlay._detections == [] + + +class TestBboxOverlayCoordinateMapping: + """ + Verify that BboxOverlay correctly maps source-frame pixel coordinates + onto the letterboxed video_rect when painting. + + We don't test actual QPainter output — instead we verify that the + QRect values passed to painter.drawRect() correspond to the expected + scaled coordinates. 
+ """ + + def setup_method(self) -> None: + self.overlay = BboxOverlay() + + def _make_painter_mock(self): + painter = MagicMock() + fm = MagicMock() + fm.height.return_value = 14 + fm.ascent.return_value = 11 + fm.horizontalAdvance.return_value = 60 + painter.fontMetrics.return_value = fm + return painter + + def test_paint_skips_when_no_detections(self) -> None: + painter = self._make_painter_mock() + self.overlay.paint(painter, QRect(0, 0, 640, 480)) + painter.drawRect.assert_not_called() + + def test_paint_skips_when_source_size_empty(self) -> None: + # detections present but source_size not set + self.overlay._detections = [Detection(0, 0, 100, 100, 0.9, "label")] + painter = self._make_painter_mock() + self.overlay.paint(painter, QRect(0, 0, 640, 480)) + painter.drawRect.assert_not_called() + + def test_bbox_scaled_to_full_video_rect(self) -> None: + """ + Source: 640×480, covers full frame. + video_rect: 640×480 at origin. + Detection: full-frame box → should map 1:1. + """ + self.overlay.on_detections( + [Detection(0.0, 0.0, 640.0, 480.0, 0.99, "label")], + (640, 480), + ) + painter = self._make_painter_mock() + video_rect = QRect(0, 0, 640, 480) + self.overlay.paint(painter, video_rect) + + # First drawRect call = the bounding box + first_call_rect: QRect = painter.drawRect.call_args_list[0][0][0] + assert first_call_rect.x() == 0 + assert first_call_rect.y() == 0 + assert first_call_rect.width() == 640 + assert first_call_rect.height() == 480 + + def test_bbox_scaled_with_half_size_video_rect(self) -> None: + """ + Source: 640×480, video_rect: 320×240 at origin (0.5× scale). + Detection at (64, 48)→(128, 96) should map to (32, 24)→(64, 48). 
+ """ + self.overlay.on_detections( + [Detection(64.0, 48.0, 128.0, 96.0, 0.8, "label")], + (640, 480), + ) + painter = self._make_painter_mock() + video_rect = QRect(0, 0, 320, 240) + self.overlay.paint(painter, video_rect) + + first_call_rect: QRect = painter.drawRect.call_args_list[0][0][0] + assert first_call_rect.x() == 32 + assert first_call_rect.y() == 24 + assert first_call_rect.width() == 32 # (128-64) * 0.5 + assert first_call_rect.height() == 24 # (96-48) * 0.5 + + def test_bbox_offset_by_video_rect_origin(self) -> None: + """ + video_rect at (100, 50) — letterboxed with margins. + Detection at origin of source should map to (100, 50). + """ + self.overlay.on_detections( + [Detection(0.0, 0.0, 100.0, 100.0, 0.9, "label")], + (640, 480), + ) + painter = self._make_painter_mock() + # video_rect 320×240 starting at (100, 50) + video_rect = QRect(100, 50, 320, 240) + self.overlay.paint(painter, video_rect) + + first_call_rect: QRect = painter.drawRect.call_args_list[0][0][0] + # x: 100 + int(0 * 320/640) = 100 + # y: 50 + int(0 * 240/480) = 50 + assert first_call_rect.x() == 100 + assert first_call_rect.y() == 50 + + +class TestBboxOverlayWorkerPacket: + """Test FramePacket and ResultPacket data structures.""" + + def test_frame_packet_fields(self) -> None: + from app.inference.worker import FramePacket + pkt = FramePacket( + frame_id=1, + raw_bytes=b"\x00" * 12, + width=2, + height=2, + channels=3, + ) + assert pkt.frame_id == 1 + assert pkt.width == 2 + assert pkt.channels == 3 + + def test_result_packet_fields(self) -> None: + from app.inference.worker import ResultPacket + pkt = ResultPacket(frame_id=5, detections=[], width=640, height=480) + assert pkt.frame_id == 5 + assert pkt.detections == [] + assert pkt.width == 640 diff --git a/tests/test_inference_manager.py b/tests/test_inference_manager.py new file mode 100644 index 0000000..c9f6a4a --- /dev/null +++ b/tests/test_inference_manager.py @@ -0,0 +1,238 @@ +"""Tests for InferenceManager — 
drop-if-busy, restart counter, model validation.""" + +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +import pytest +from PySide6.QtWidgets import QApplication + +from app.inference.worker_manager import InferenceManager + +# Ensure a QApplication exists for tests that create Qt objects +_app = QApplication.instance() or QApplication(sys.argv) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_manager() -> InferenceManager: + """Return an InferenceManager without starting any process.""" + mgr = InferenceManager.__new__(InferenceManager) + mgr._model_path = None + mgr._process = None + mgr._input_queue = None + mgr._output_queue = None + mgr._stop_event = None + mgr._busy = False + mgr._frame_id = 0 + mgr._restart_count = 0 + mgr._last_result_time = 0.0 + mgr._paused = False + return mgr + + +# --------------------------------------------------------------------------- +# Model path validation +# --------------------------------------------------------------------------- + +class TestModelPathValidation: + def test_start_emits_error_for_missing_file(self, tmp_path) -> None: + """start() with non-existent path must NOT spawn a process.""" + mgr = InferenceManager() + errors: list[str] = [] + mgr.inference_error.connect(errors.append) + + mgr.start(str(tmp_path / "nonexistent.pt")) + + assert errors, "Expected inference_error signal" + assert mgr._process is None + + def test_start_does_not_raise_for_existing_file(self, tmp_path) -> None: + """start() with existing file should attempt to start (we mock _start_worker).""" + model_file = tmp_path / "model.pt" + model_file.write_bytes(b"fake") + + mgr = InferenceManager() + with patch.object(mgr, "_start_worker") as mock_start: + mgr.start(str(model_file)) + mock_start.assert_called_once() + + +# 
--------------------------------------------------------------------------- +# Drop-if-busy logic +# --------------------------------------------------------------------------- + +class TestDropIfBusy: + def test_submit_frame_drops_when_busy(self) -> None: + """submit_frame must not enqueue when _busy is True.""" + mgr = _make_manager() + mgr._busy = True + mgr._process = MagicMock() + mgr._process.is_alive.return_value = True + mgr._input_queue = MagicMock() + + frame = MagicMock() + frame.isValid.return_value = True + mgr.submit_frame(frame) + + mgr._input_queue.put_nowait.assert_not_called() + + def test_submit_frame_drops_when_paused(self) -> None: + mgr = _make_manager() + mgr._paused = True + mgr._process = MagicMock() + mgr._process.is_alive.return_value = True + mgr._input_queue = MagicMock() + + frame = MagicMock() + frame.isValid.return_value = True + mgr.submit_frame(frame) + + mgr._input_queue.put_nowait.assert_not_called() + + def test_submit_frame_drops_when_not_running(self) -> None: + mgr = _make_manager() + mgr._process = None + mgr._input_queue = MagicMock() + + frame = MagicMock() + frame.isValid.return_value = True + mgr.submit_frame(frame) + + mgr._input_queue.put_nowait.assert_not_called() + + def test_submit_frame_drops_invalid_frame(self) -> None: + mgr = _make_manager() + mgr._process = MagicMock() + mgr._process.is_alive.return_value = True + mgr._input_queue = MagicMock() + + frame = MagicMock() + frame.isValid.return_value = False + mgr.submit_frame(frame) + + mgr._input_queue.put_nowait.assert_not_called() + + +# --------------------------------------------------------------------------- +# Pause / resume +# --------------------------------------------------------------------------- + +class TestPauseResume: + def test_pause_sets_flag(self) -> None: + mgr = _make_manager() + assert mgr._paused is False + mgr.pause() + assert mgr._paused is True + + def test_resume_clears_flag(self) -> None: + mgr = _make_manager() + mgr.pause() + 
mgr.resume() + assert mgr._paused is False + + def test_is_paused_property(self) -> None: + mgr = _make_manager() + assert mgr.is_paused is False + mgr.pause() + assert mgr.is_paused is True + + +# --------------------------------------------------------------------------- +# Restart counter +# --------------------------------------------------------------------------- + +class TestRestartCounter: + def test_handle_crash_increments_counter(self) -> None: + mgr = InferenceManager() + mgr._model_path = "fake.pt" + mgr._restart_count = 0 + + with ( + patch.object(mgr, "_start_worker"), + patch.object(mgr._poll_timer, "stop"), + patch.object(mgr._watchdog_timer, "stop"), + ): + mgr._handle_crash("test crash") + + assert mgr._restart_count == 1 + + def test_handle_crash_emits_error_after_max_restarts(self) -> None: + from app.config import INFERENCE_MAX_RESTARTS + + mgr = InferenceManager() + mgr._model_path = "fake.pt" + mgr._restart_count = INFERENCE_MAX_RESTARTS + + errors: list[str] = [] + mgr.inference_error.connect(errors.append) + + with ( + patch.object(mgr, "_start_worker") as mock_start, + patch.object(mgr._poll_timer, "stop"), + patch.object(mgr._watchdog_timer, "stop"), + ): + mgr._handle_crash("final crash") + + assert errors, "Expected inference_error signal after max restarts" + mock_start.assert_not_called() + + def test_stop_resets_restart_count(self) -> None: + mgr = InferenceManager() + mgr._restart_count = 2 + + with patch.object(mgr, "_stop_worker"): + mgr.stop() + + assert mgr._restart_count == 0 + + +# --------------------------------------------------------------------------- +# is_running property +# --------------------------------------------------------------------------- + +class TestIsRunning: + def test_not_running_when_process_is_none(self) -> None: + mgr = _make_manager() + assert mgr.is_running is False + + def test_not_running_when_process_dead(self) -> None: + mgr = _make_manager() + proc = MagicMock() + proc.is_alive.return_value = 
False + mgr._process = proc + assert mgr.is_running is False + + def test_running_when_process_alive(self) -> None: + mgr = _make_manager() + proc = MagicMock() + proc.is_alive.return_value = True + mgr._process = proc + assert mgr.is_running is True + + +# --------------------------------------------------------------------------- +# Worker data structures +# --------------------------------------------------------------------------- + +class TestWorkerDataStructures: + def test_frame_packet_is_immutable(self) -> None: + from app.inference.worker import FramePacket + pkt = FramePacket(1, b"", 640, 480, 3) + with pytest.raises(AttributeError): + pkt.frame_id = 2 # type: ignore[misc] + + def test_result_packet_is_immutable(self) -> None: + from app.inference.worker import ResultPacket + pkt = ResultPacket(1, [], 640, 480) + with pytest.raises(AttributeError): + pkt.frame_id = 2 # type: ignore[misc] + + def test_select_device_returns_string(self) -> None: + from app.inference.worker import _select_device + device = _select_device() + assert isinstance(device, str) + assert device in ("cpu", "mps", "cuda")