145 lines
5.1 KiB
Python
145 lines
5.1 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
from app.label_parser import ParsedLabel, parse_label_text
|
|
from app.ocr import create_ocr_engine
|
|
|
|
|
|
@dataclass
class DetectionResult:
    """Outcome of one detection pass, optionally enriched with OCR output.

    Every field has a default, so an instance may carry only an ``error``
    message, only a bounding box, or a fully populated result.
    """

    # Bounding box as (x1, y1, x2, y2); None when no box is attached.
    xyxy: tuple[int, int, int, int] | None = None
    confidence: float | None = None
    class_name: str | None = None
    raw_text: str = ""
    parsed: ParsedLabel | None = None
    error: str | None = None
    ocr_engine: str | None = None
    ocr_confidence: float | None = None
    ocr_elapsed_ms: float | None = None
    # One dictionary per candidate box; a fresh list is created per instance.
    all_boxes: list[dict[str, Any]] = field(default_factory=list)

    def to_metadata(self) -> dict[str, Any]:
        """Return the result as a plain dictionary.

        ``xyxy`` is exported under the ``bbox_xyxy`` key as a list and
        ``parsed`` is converted through its ``to_dict()`` method; both are
        ``None`` when the corresponding field is unset. The remaining fields
        are passed through unchanged (``all_boxes`` is the same list object,
        not a copy).
        """
        # Test against None explicitly: a value that is present but happens to
        # be falsy (e.g. a parsed object defining __len__/__bool__) must still
        # be exported rather than silently replaced by None.
        bbox = list(self.xyxy) if self.xyxy is not None else None
        parsed = self.parsed.to_dict() if self.parsed is not None else None
        return {
            "bbox_xyxy": bbox,
            "confidence": self.confidence,
            "class_name": self.class_name,
            "raw_text": self.raw_text,
            "parsed": parsed,
            "error": self.error,
            "ocr_engine": self.ocr_engine,
            "ocr_confidence": self.ocr_confidence,
            "ocr_elapsed_ms": self.ocr_elapsed_ms,
            "all_boxes": self.all_boxes,
        }
|
|
|
|
|
|
class YoloLabelDetector:
    """Run a YOLO model on a frame and report the most confident box."""

    def __init__(self, config: dict[str, Any], app_config: Any) -> None:
        """Store the configuration and try to load the model right away.

        Args:
            config: Application settings; the ``"detection"`` section is read.
            app_config: Object providing ``resolve_path`` for the model path.

        A missing model file, or a failure while importing or constructing
        YOLO, is recorded in ``load_error`` instead of being raised.
        """
        self.config = config
        self.app_config = app_config
        self.model = None
        self.load_error: str | None = None
        self._load_model()

    def _load_model(self) -> None:
        """Load the YOLO weights, recording any failure in ``load_error``."""
        model_path = self.app_config.resolve_path(self.config["detection"]["model_path"])
        if not model_path.exists():
            self.load_error = f"Brak modelu: {model_path}"
            return

        try:
            # Imported lazily so the module can be imported without ultralytics.
            from ultralytics import YOLO

            self.model = YOLO(str(model_path))
        except Exception as exc:  # pragma: no cover - depends on optional runtime deps
            self.load_error = f"Nie mozna zaladowac YOLO: {exc}"

    def detect(self, frame_bgr: np.ndarray) -> DetectionResult:
        """Detect boxes in ``frame_bgr`` and return the best one.

        Returns:
            A ``DetectionResult`` whose ``xyxy``/``confidence``/``class_name``
            describe the highest-confidence box and whose ``all_boxes`` lists
            every box in descending confidence order. When the model is not
            loaded, inference fails, or no box is found, only ``error`` is set.
        """
        if self.model is None:
            return DetectionResult(error=self.load_error or "Model YOLO nie jest zaladowany")

        detection_cfg = self.config["detection"]
        try:
            predictions = self.model.predict(
                source=frame_bgr,
                conf=float(detection_cfg["confidence_threshold"]),
                imgsz=int(detection_cfg["image_size"]),
                device=detection_cfg.get("device", "cpu"),
                verbose=False,
            )
        except Exception as exc:  # pragma: no cover - depends on model runtime
            return DetectionResult(error=f"Blad YOLO: {exc}")

        boxes = self._collect_boxes(predictions)
        if not boxes:
            return DetectionResult(error="Nie wykryto etykiety")

        # Highest confidence first; the sort is stable for equal scores.
        boxes.sort(key=lambda item: item["confidence"], reverse=True)
        best = boxes[0]
        return DetectionResult(
            xyxy=best["xyxy"],
            confidence=best["confidence"],
            class_name=best["class_name"],
            # Coordinates are exported as lists here; key order is preserved.
            all_boxes=[{**item, "xyxy": list(item["xyxy"])} for item in boxes],
        )

    def _collect_boxes(self, predictions: Any) -> list[dict[str, Any]]:
        """Flatten prediction results into a list of plain box dictionaries."""
        names = getattr(self.model, "names", {})
        collected: list[dict[str, Any]] = []
        for prediction in predictions:
            if prediction.boxes is None:
                continue

            for box in prediction.boxes:
                x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].tolist())
                confidence = float(box.conf[0])
                class_id = int(box.cls[0]) if box.cls is not None else -1
                # Fall back to the numeric id when no name mapping is available.
                if isinstance(names, dict):
                    class_name = names.get(class_id, str(class_id))
                else:
                    class_name = str(class_id)
                collected.append(
                    {
                        "xyxy": (x1, y1, x2, y2),
                        "confidence": confidence,
                        "class_name": class_name,
                    }
                )
        return collected
|
|
|
|
|
|
class DetectionPipeline:
    """Chain label detection, OCR and text parsing for a single frame."""

    def __init__(self, config: dict[str, Any], app_config: Any) -> None:
        """Build the detector and the OCR engine from ``config``."""
        self.config = config
        self.detector = YoloLabelDetector(config, app_config)
        self.ocr = create_ocr_engine(config)

    def process(self, frame_bgr: np.ndarray) -> DetectionResult:
        """Detect a label in ``frame_bgr``, read it with OCR and parse the text.

        Returns:
            The detector's result unchanged when it carries no bounding box.
            Otherwise the same object with the OCR fields and ``parsed``
            filled in; an OCR error, if any, replaces ``error``.
        """
        result = self.detector.detect(frame_bgr)
        if result.xyxy is None:
            return result

        ocr_result = self.ocr.read_label(frame_bgr, result.xyxy)
        result.raw_text = ocr_result.text
        result.ocr_engine = ocr_result.engine
        result.ocr_confidence = ocr_result.confidence
        result.ocr_elapsed_ms = ocr_result.elapsed_ms

        # Every key in this section already has a default, so a missing or
        # null section is treated as empty instead of raising per frame.
        label_cfg = self.config.get("label_data") or {}
        result.parsed = parse_label_text(
            ocr_result.text,
            label_cfg.get("colors", []),
            label_cfg.get("models", []),
            model_min_score=float(label_cfg.get("model_min_score", 0.72)),
            color_min_score=float(label_cfg.get("color_min_score", 0.72)),
        )
        if ocr_result.error:
            result.error = ocr_result.error
        return result
|