feat(ml): ARM NN acceleration

This commit is contained in:
Fynn Petersen-Frey
2023-11-04 09:34:19 +01:00
committed by Fynn Petersen-Frey
parent 767fe87b2e
commit 5f6ad9e239
4 changed files with 86 additions and 6 deletions

View File

@@ -1,12 +1,18 @@
import json
from enum import Enum
from pathlib import Path
from typing import Any
def get_model_path(output_dir: Path | str) -> Path:
class ModelType(Enum):
ONNX = "onnx"
TFLITE = "tflite"
def get_model_path(output_dir: Path | str, model_type: ModelType = ModelType.ONNX) -> Path:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
return output_dir / "model.onnx"
return output_dir / f"model.{model_type.value}"
def save_config(config: Any, output_path: Path | str) -> None: