optimized scrfd code
This commit is contained in:
@@ -12,12 +12,12 @@ def ort_has_batch_dim(session: ort.InferenceSession) -> bool:
|
||||
return session.get_inputs()[0].shape[0] == "batch"
|
||||
|
||||
|
||||
def ort_squeeze_outputs(session: ort.InferenceSession) -> None:
|
||||
def ort_expand_outputs(session: ort.InferenceSession) -> None:
|
||||
original_run = session.run
|
||||
|
||||
def run(output_names: list[str], input_feed: dict[str, NDArray[np.float32]]) -> list[NDArray[np.float32]]:
|
||||
out: list[NDArray[np.float32]] = original_run(output_names, input_feed)
|
||||
out = [o.squeeze(axis=0) for o in out]
|
||||
out = [np.expand_dims(o, axis=0) for o in out]
|
||||
return out
|
||||
|
||||
session.run = run
|
||||
|
||||
Reference in New Issue
Block a user