feat(ml): improved ARM-NN support (#11233)

This commit is contained in:
Fynn Petersen-Frey
2024-07-20 21:59:27 +02:00
committed by GitHub
parent 7c3326b662
commit 54488b1016
8 changed files with 70 additions and 32 deletions
+2
View File
@@ -30,6 +30,8 @@ class Settings(BaseSettings):
model_inter_op_threads: int = 0
model_intra_op_threads: int = 0
ann: bool = True
ann_fp16_turbo: bool = False
ann_tuning_level: int = 2
preload: PreloadModelData | None = None
class Config:
+2 -1
View File
@@ -20,12 +20,13 @@ class AnnSession:
def __init__(self, model_path: Path, cache_dir: Path = settings.cache_folder) -> None:
self.model_path = model_path
self.cache_dir = cache_dir
self.ann = Ann(tuning_level=3, tuning_file=(cache_dir / "gpu-tuning.ann").as_posix())
self.ann = Ann(tuning_level=settings.ann_tuning_level, tuning_file=(cache_dir / "gpu-tuning.ann").as_posix())
log.info("Loading ANN model %s ...", model_path)
self.model = self.ann.load(
model_path.as_posix(),
cached_network_path=model_path.with_suffix(".anncache").as_posix(),
fp16=settings.ann_fp16_turbo,
)
log.info("Loaded ANN model with ID %d", self.model)
+2 -2
View File
@@ -268,9 +268,9 @@ class TestAnnSession:
AnnSession(model_path, cache_dir)
ann_session.assert_called_once_with(tuning_level=3, tuning_file=(cache_dir / "gpu-tuning.ann").as_posix())
ann_session.assert_called_once_with(tuning_level=2, tuning_file=(cache_dir / "gpu-tuning.ann").as_posix())
ann_session.return_value.load.assert_called_once_with(
model_path.as_posix(), cached_network_path=model_path.with_suffix(".anncache").as_posix()
model_path.as_posix(), cached_network_path=model_path.with_suffix(".anncache").as_posix(), fp16=False
)
info.assert_has_calls(
[