more linting
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
# This code is from leafqycc/rknn-multi-threaded
|
||||
# Following Apache License 2.0
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
@@ -21,7 +22,7 @@ def get_soc(device_tree_path: Path | str) -> str | None:
|
||||
return soc
|
||||
log.warning("Device is not supported for RKNN")
|
||||
except OSError as e:
|
||||
log.warning("Could not read /proc/device-tree/compatible. Reason: %s", e.msg)
|
||||
log.warning("Could not read /proc/device-tree/compatible. Reason: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
@@ -56,19 +57,24 @@ def init_rknn(model_path: str) -> "RKNNLite":
|
||||
|
||||
|
||||
class RknnPoolExecutor:
|
||||
def __init__(self, model_path: str, tpes: int, func):
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
tpes: int,
|
||||
func: Callable[["RKNNLite", list[NDArray[np.float32]]], list[NDArray[np.float32]]],
|
||||
) -> None:
|
||||
self.tpes = tpes
|
||||
self.queue = Queue()
|
||||
self.queue: Queue[Future[list[NDArray[np.float32]]]] = Queue()
|
||||
self.rknn_pool = [init_rknn(model_path) for _ in range(tpes)]
|
||||
self.pool = ThreadPoolExecutor(max_workers=tpes)
|
||||
self.func = func
|
||||
self.num = 0
|
||||
|
||||
def put(self, inputs) -> None:
|
||||
def put(self, inputs: list[NDArray[np.float32]]) -> None:
|
||||
self.queue.put(self.pool.submit(self.func, self.rknn_pool[self.num % self.tpes], inputs))
|
||||
self.num += 1
|
||||
|
||||
def get(self) -> list[list[NDArray[np.float32]], bool]:
|
||||
def get(self) -> list[NDArray[np.float32]] | None:
|
||||
if self.queue.empty():
|
||||
return None
|
||||
fut = self.queue.get()
|
||||
|
||||
Reference in New Issue
Block a user