refactor
This commit is contained in:
@@ -7,6 +7,7 @@ from insightface.model_zoo import RetinaFace
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from app.models.base import InferenceModel
|
||||
from app.models.session import ort_has_batch_dim, ort_squeeze_outputs
|
||||
from app.models.transforms import decode_cv2
|
||||
from app.schemas import FaceDetectionOutput, ModelSession, ModelTask, ModelType
|
||||
|
||||
@@ -27,7 +28,8 @@ class FaceDetector(InferenceModel):
|
||||
|
||||
def _load(self) -> ModelSession:
|
||||
session = self._make_session(self.model_path)
|
||||
self._squeeze_outputs(session)
|
||||
if isinstance(session, ort.InferenceSession) and ort_has_batch_dim(session):
|
||||
ort_squeeze_outputs(session)
|
||||
self.model = RetinaFace(session=session)
|
||||
self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640))
|
||||
|
||||
@@ -46,15 +48,5 @@ class FaceDetector(InferenceModel):
|
||||
def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
|
||||
return self.model.detect(inputs) # type: ignore
|
||||
|
||||
def _squeeze_outputs(self, session: ort.InferenceSession) -> None:
|
||||
original_run = session.run
|
||||
|
||||
def run(output_names: list[str], input_feed: dict[str, NDArray[np.float32]]) -> list[NDArray[np.float32]]:
|
||||
out: list[NDArray[np.float32]] = original_run(output_names, input_feed)
|
||||
out = [o.squeeze(axis=0) for o in out]
|
||||
return out
|
||||
|
||||
session.run = run
|
||||
|
||||
def configure(self, **kwargs: Any) -> None:
|
||||
self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh)
|
||||
|
||||
Reference in New Issue
Block a user