add cli
This commit is contained in:
@@ -1,28 +1,35 @@
|
|||||||
FROM mambaorg/micromamba:bookworm-slim@sha256:333f7598ff2c2400fb10bfe057709c68b7daab5d847143af85abcf224a07271a as builder
|
FROM mambaorg/micromamba:bookworm-slim@sha256:333f7598ff2c2400fb10bfe057709c68b7daab5d847143af85abcf224a07271a as builder
|
||||||
|
|
||||||
ENV TRANSFORMERS_CACHE=/cache \
|
|
||||||
PYTHONDONTWRITEBYTECODE=1 \
|
|
||||||
PYTHONUNBUFFERED=1 \
|
|
||||||
PATH="/opt/venv/bin:$PATH"
|
|
||||||
|
|
||||||
WORKDIR /export/ann
|
|
||||||
|
|
||||||
USER root
|
USER root
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
build-essential \
|
build-essential \
|
||||||
|
cmake \
|
||||||
curl \
|
curl \
|
||||||
git
|
git
|
||||||
|
|
||||||
USER $MAMBA_USER
|
USER $MAMBA_USER
|
||||||
COPY --chown=$MAMBA_USER:$MAMBA_USER env.yaml ./
|
|
||||||
RUN micromamba install -y -f env.yaml
|
|
||||||
COPY --chown=$MAMBA_USER:$MAMBA_USER *.sh *.cpp ./
|
|
||||||
|
|
||||||
ENV ARMNN_PATH=/export/ann/armnn
|
WORKDIR /home/mambauser
|
||||||
|
ENV ARMNN_PATH=armnn
|
||||||
|
COPY --chown=$MAMBA_USER:$MAMBA_USER scripts/* .
|
||||||
RUN ./download-armnn.sh && \
|
RUN ./download-armnn.sh && \
|
||||||
./build-converter.sh && \
|
./build-converter.sh && \
|
||||||
./build.sh
|
./build.sh
|
||||||
COPY --chown=$MAMBA_USER:$MAMBA_USER run.py ./
|
|
||||||
|
|
||||||
ENTRYPOINT ["/usr/local/bin/_entrypoint.sh"]
|
COPY --chown=$MAMBA_USER:$MAMBA_USER conda-lock.yml .
|
||||||
CMD ["python", "run.py"]
|
RUN micromamba create -y -p /home/mambauser/venv -f conda-lock.yml && \
|
||||||
|
micromamba clean --all --yes
|
||||||
|
ENV PATH="/home/mambauser/venv/bin:${PATH}"
|
||||||
|
|
||||||
|
FROM gcr.io/distroless/base-debian12
|
||||||
|
# FROM mambaorg/micromamba:bookworm-slim@sha256:333f7598ff2c2400fb10bfe057709c68b7daab5d847143af85abcf224a07271a
|
||||||
|
|
||||||
|
WORKDIR /export/ann
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||||
|
LD_LIBRARY_PATH=/export/ann/armnn \
|
||||||
|
PATH="/opt/venv/bin:${PATH}"
|
||||||
|
|
||||||
|
COPY --from=builder /home/mambauser/armnnconverter /home/mambauser/armnn ./
|
||||||
|
COPY --from=builder /home/mambauser/venv /opt/venv
|
||||||
|
COPY --chown=$MAMBA_USER:$MAMBA_USER onnx2ann onnx2ann
|
||||||
|
|
||||||
|
ENTRYPOINT ["python", "-m", "onnx2ann"]
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,201 +1,21 @@
|
|||||||
name: annexport
|
name: onnx2ann
|
||||||
channels:
|
channels:
|
||||||
- pytorch
|
|
||||||
- nvidia
|
|
||||||
- conda-forge
|
- conda-forge
|
||||||
dependencies:
|
dependencies:
|
||||||
- _libgcc_mutex=0.1=conda_forge
|
- python>=3.11,<4.0
|
||||||
- _openmp_mutex=4.5=2_kmp_llvm
|
- onnx>=1.16.1
|
||||||
- aiohttp=3.9.1=py310h2372a71_0
|
# - onnxruntime>=1.18.1 # conda only has gpu version
|
||||||
- aiosignal=1.3.1=pyhd8ed1ab_0
|
- psutil>=6.0.0
|
||||||
- arpack=3.8.0=nompi_h0baa96a_101
|
- flatbuffers>=24.3.25
|
||||||
- async-timeout=4.0.3=pyhd8ed1ab_0
|
- ml_dtypes>=0.3.1
|
||||||
- attrs=23.1.0=pyh71513ae_1
|
- typer-slim>=0.12.3
|
||||||
- aws-c-auth=0.7.3=h28f7589_1
|
- huggingface_hub>=0.23.4
|
||||||
- aws-c-cal=0.6.1=hc309b26_1
|
- pip
|
||||||
- aws-c-common=0.9.0=hd590300_0
|
|
||||||
- aws-c-compression=0.2.17=h4d4d85c_2
|
|
||||||
- aws-c-event-stream=0.3.1=h2e3709c_4
|
|
||||||
- aws-c-http=0.7.11=h00aa349_4
|
|
||||||
- aws-c-io=0.13.32=he9a53bd_1
|
|
||||||
- aws-c-mqtt=0.9.3=hb447be9_1
|
|
||||||
- aws-c-s3=0.3.14=hf3aad02_1
|
|
||||||
- aws-c-sdkutils=0.1.12=h4d4d85c_1
|
|
||||||
- aws-checksums=0.1.17=h4d4d85c_1
|
|
||||||
- aws-crt-cpp=0.21.0=hb942446_5
|
|
||||||
- aws-sdk-cpp=1.10.57=h85b1a90_19
|
|
||||||
- blas=2.120=openblas
|
|
||||||
- blas-devel=3.9.0=20_linux64_openblas
|
|
||||||
- brotli-python=1.0.9=py310hd8f1fbe_9
|
|
||||||
- bzip2=1.0.8=hd590300_5
|
|
||||||
- c-ares=1.23.0=hd590300_0
|
|
||||||
- ca-certificates=2023.11.17=hbcca054_0
|
|
||||||
- certifi=2023.11.17=pyhd8ed1ab_0
|
|
||||||
- charset-normalizer=3.3.2=pyhd8ed1ab_0
|
|
||||||
- click=8.1.7=unix_pyh707e725_0
|
|
||||||
- colorama=0.4.6=pyhd8ed1ab_0
|
|
||||||
- coloredlogs=15.0.1=pyhd8ed1ab_3
|
|
||||||
- cuda-cudart=11.7.99=0
|
|
||||||
- cuda-cupti=11.7.101=0
|
|
||||||
- cuda-libraries=11.7.1=0
|
|
||||||
- cuda-nvrtc=11.7.99=0
|
|
||||||
- cuda-nvtx=11.7.91=0
|
|
||||||
- cuda-runtime=11.7.1=0
|
|
||||||
- dataclasses=0.8=pyhc8e2a94_3
|
|
||||||
- datasets=2.14.7=pyhd8ed1ab_0
|
|
||||||
- dill=0.3.7=pyhd8ed1ab_0
|
|
||||||
- filelock=3.13.1=pyhd8ed1ab_0
|
|
||||||
- flatbuffers=23.5.26=h59595ed_1
|
|
||||||
- freetype=2.12.1=h267a509_2
|
|
||||||
- frozenlist=1.4.0=py310h2372a71_1
|
|
||||||
- fsspec=2023.10.0=pyhca7485f_0
|
|
||||||
- ftfy=6.1.3=pyhd8ed1ab_0
|
|
||||||
- gflags=2.2.2=he1b5a44_1004
|
|
||||||
- glog=0.6.0=h6f12383_0
|
|
||||||
- glpk=5.0=h445213a_0
|
|
||||||
- gmp=6.3.0=h59595ed_0
|
|
||||||
- gmpy2=2.1.2=py310h3ec546c_1
|
|
||||||
- huggingface_hub=0.17.3=pyhd8ed1ab_0
|
|
||||||
- humanfriendly=10.0=pyhd8ed1ab_6
|
|
||||||
- icu=73.2=h59595ed_0
|
|
||||||
- idna=3.6=pyhd8ed1ab_0
|
|
||||||
- importlib-metadata=7.0.0=pyha770c72_0
|
|
||||||
- importlib_metadata=7.0.0=hd8ed1ab_0
|
|
||||||
- joblib=1.3.2=pyhd8ed1ab_0
|
|
||||||
- keyutils=1.6.1=h166bdaf_0
|
|
||||||
- krb5=1.21.2=h659d440_0
|
|
||||||
- lcms2=2.15=h7f713cb_2
|
|
||||||
- ld_impl_linux-64=2.40=h41732ed_0
|
|
||||||
- lerc=4.0.0=h27087fc_0
|
|
||||||
- libabseil=20230125.3=cxx17_h59595ed_0
|
|
||||||
- libarrow=12.0.1=hb87d912_8_cpu
|
|
||||||
- libblas=3.9.0=20_linux64_openblas
|
|
||||||
- libbrotlicommon=1.0.9=h166bdaf_9
|
|
||||||
- libbrotlidec=1.0.9=h166bdaf_9
|
|
||||||
- libbrotlienc=1.0.9=h166bdaf_9
|
|
||||||
- libcblas=3.9.0=20_linux64_openblas
|
|
||||||
- libcrc32c=1.1.2=h9c3ff4c_0
|
|
||||||
- libcublas=11.10.3.66=0
|
|
||||||
- libcufft=10.7.2.124=h4fbf590_0
|
|
||||||
- libcufile=1.8.1.2=0
|
|
||||||
- libcurand=10.3.4.101=0
|
|
||||||
- libcurl=8.5.0=hca28451_0
|
|
||||||
- libcusolver=11.4.0.1=0
|
|
||||||
- libcusparse=11.7.4.91=0
|
|
||||||
- libdeflate=1.19=hd590300_0
|
|
||||||
- libedit=3.1.20191231=he28a2e2_2
|
|
||||||
- libev=4.33=hd590300_2
|
|
||||||
- libevent=2.1.12=hf998b51_1
|
|
||||||
- libffi=3.4.2=h7f98852_5
|
|
||||||
- libgcc-ng=13.2.0=h807b86a_3
|
|
||||||
- libgfortran-ng=13.2.0=h69a702a_3
|
|
||||||
- libgfortran5=13.2.0=ha4646dd_3
|
|
||||||
- libgoogle-cloud=2.12.0=hac9eb74_1
|
|
||||||
- libgrpc=1.54.3=hb20ce57_0
|
|
||||||
- libhwloc=2.9.3=default_h554bfaf_1009
|
|
||||||
- libiconv=1.17=hd590300_1
|
|
||||||
- libjpeg-turbo=2.1.5.1=hd590300_1
|
|
||||||
- liblapack=3.9.0=20_linux64_openblas
|
|
||||||
- liblapacke=3.9.0=20_linux64_openblas
|
|
||||||
- libnghttp2=1.58.0=h47da74e_1
|
|
||||||
- libnpp=11.7.4.75=0
|
|
||||||
- libnsl=2.0.1=hd590300_0
|
|
||||||
- libnuma=2.0.16=h0b41bf4_1
|
|
||||||
- libnvjpeg=11.8.0.2=0
|
|
||||||
- libopenblas=0.3.25=pthreads_h413a1c8_0
|
|
||||||
- libpng=1.6.39=h753d276_0
|
|
||||||
- libprotobuf=3.21.12=hfc55251_2
|
|
||||||
- libsentencepiece=0.1.99=h180e1df_0
|
|
||||||
- libsqlite=3.44.2=h2797004_0
|
|
||||||
- libssh2=1.11.0=h0841786_0
|
|
||||||
- libstdcxx-ng=13.2.0=h7e041cc_3
|
|
||||||
- libthrift=0.18.1=h8fd135c_2
|
|
||||||
- libtiff=4.6.0=h29866fb_1
|
|
||||||
- libutf8proc=2.8.0=h166bdaf_0
|
|
||||||
- libuuid=2.38.1=h0b41bf4_0
|
|
||||||
- libwebp-base=1.3.2=hd590300_0
|
|
||||||
- libxcb=1.15=h0b41bf4_0
|
|
||||||
- libxml2=2.11.6=h232c23b_0
|
|
||||||
- libzlib=1.2.13=hd590300_5
|
|
||||||
- llvm-openmp=17.0.6=h4dfa4b3_0
|
|
||||||
- lz4-c=1.9.4=hcb278e6_0
|
|
||||||
- mkl=2022.2.1=h84fe81f_16997
|
|
||||||
- mkl-devel=2022.2.1=ha770c72_16998
|
|
||||||
- mkl-include=2022.2.1=h84fe81f_16997
|
|
||||||
- mpc=1.3.1=hfe3b2da_0
|
|
||||||
- mpfr=4.2.1=h9458935_0
|
|
||||||
- mpmath=1.3.0=pyhd8ed1ab_0
|
|
||||||
- multidict=6.0.4=py310h2372a71_1
|
|
||||||
- multiprocess=0.70.15=py310h2372a71_1
|
|
||||||
- ncurses=6.4=h59595ed_2
|
|
||||||
- numpy=1.26.2=py310hb13e2d6_0
|
|
||||||
- onnx=1.14.0=py310ha3deec4_1
|
|
||||||
- onnx2torch=1.5.13=pyhd8ed1ab_0
|
|
||||||
- onnxruntime=1.16.3=py310hd4b7fbc_1_cpu
|
|
||||||
- open-clip-torch=2.23.0=pyhd8ed1ab_1
|
|
||||||
- openblas=0.3.25=pthreads_h7a3da1a_0
|
|
||||||
- openjpeg=2.5.0=h488ebb8_3
|
|
||||||
- openssl=3.2.0=hd590300_1
|
|
||||||
- orc=1.9.0=h2f23424_1
|
|
||||||
- packaging=23.2=pyhd8ed1ab_0
|
|
||||||
- pandas=2.1.4=py310hcc13569_0
|
|
||||||
- pillow=10.0.1=py310h29da1c1_1
|
|
||||||
- pip=23.3.1=pyhd8ed1ab_0
|
|
||||||
- protobuf=4.21.12=py310heca2aa9_0
|
|
||||||
- pthread-stubs=0.4=h36c2ea0_1001
|
|
||||||
- pyarrow=12.0.1=py310h0576679_8_cpu
|
|
||||||
- pyarrow-hotfix=0.6=pyhd8ed1ab_0
|
|
||||||
- pysocks=1.7.1=pyha2e5f31_6
|
|
||||||
- python=3.10.13=hd12c33a_0_cpython
|
|
||||||
- python-dateutil=2.8.2=pyhd8ed1ab_0
|
|
||||||
- python-flatbuffers=23.5.26=pyhd8ed1ab_0
|
|
||||||
- python-tzdata=2023.3=pyhd8ed1ab_0
|
|
||||||
- python-xxhash=3.4.1=py310h2372a71_0
|
|
||||||
- python_abi=3.10=4_cp310
|
|
||||||
- pytorch=1.13.1=cpu_py310hd11e9c7_1
|
|
||||||
- pytorch-cuda=11.7=h778d358_5
|
|
||||||
- pytorch-mutex=1.0=cuda
|
|
||||||
- pytz=2023.3.post1=pyhd8ed1ab_0
|
|
||||||
- pyyaml=6.0.1=py310h2372a71_1
|
|
||||||
- rdma-core=28.9=h59595ed_1
|
|
||||||
- re2=2023.03.02=h8c504da_0
|
|
||||||
- readline=8.2=h8228510_1
|
|
||||||
- regex=2023.10.3=py310h2372a71_0
|
|
||||||
- requests=2.31.0=pyhd8ed1ab_0
|
|
||||||
- s2n=1.3.49=h06160fa_0
|
|
||||||
- sacremoses=0.0.53=pyhd8ed1ab_0
|
|
||||||
- safetensors=0.3.3=py310hcb5633a_1
|
|
||||||
- sentencepiece=0.1.99=hff52083_0
|
|
||||||
- sentencepiece-python=0.1.99=py310hebdb9f0_0
|
|
||||||
- sentencepiece-spm=0.1.99=h180e1df_0
|
|
||||||
- setuptools=68.2.2=pyhd8ed1ab_0
|
|
||||||
- six=1.16.0=pyh6c4a22f_0
|
|
||||||
- sleef=3.5.1=h9b69904_2
|
|
||||||
- snappy=1.1.10=h9fff704_0
|
|
||||||
- sympy=1.12=pypyh9d50eac_103
|
|
||||||
- tbb=2021.11.0=h00ab1b0_0
|
|
||||||
- texttable=1.7.0=pyhd8ed1ab_0
|
|
||||||
- timm=0.9.12=pyhd8ed1ab_0
|
|
||||||
- tk=8.6.13=noxft_h4845f30_101
|
|
||||||
- tokenizers=0.14.1=py310h320607d_2
|
|
||||||
- torchvision=0.14.1=cpu_py310hd3d2ac3_1
|
|
||||||
- tqdm=4.66.1=pyhd8ed1ab_0
|
|
||||||
- transformers=4.35.2=pyhd8ed1ab_0
|
|
||||||
- typing-extensions=4.9.0=hd8ed1ab_0
|
|
||||||
- typing_extensions=4.9.0=pyha770c72_0
|
|
||||||
- tzdata=2023c=h71feb2d_0
|
|
||||||
- ucx=1.14.1=h64cca9d_5
|
|
||||||
- urllib3=2.1.0=pyhd8ed1ab_0
|
|
||||||
- wcwidth=0.2.12=pyhd8ed1ab_0
|
|
||||||
- wheel=0.42.0=pyhd8ed1ab_0
|
|
||||||
- xorg-libxau=1.0.11=hd590300_0
|
|
||||||
- xorg-libxdmcp=1.1.3=h7f98852_0
|
|
||||||
- xxhash=0.8.2=hd590300_0
|
|
||||||
- xz=5.2.6=h166bdaf_0
|
|
||||||
- yaml=0.2.5=h7f98852_2
|
|
||||||
- yarl=1.9.3=py310h2372a71_0
|
|
||||||
- zipp=3.17.0=pyhd8ed1ab_0
|
|
||||||
- zlib=1.2.13=hd590300_5
|
|
||||||
- zstd=1.5.5=hfc55251_0
|
|
||||||
- pip:
|
- pip:
|
||||||
- git+https://github.com/fyfrey/TinyNeuralNetwork.git
|
- onnxruntime>=1.18.1 # conda only has gpu version
|
||||||
|
- onnxsim>=0.4.36
|
||||||
|
- onnx2tf>=1.24.1
|
||||||
|
- onnx_graphsurgeon>=0.5.2
|
||||||
|
- simple_onnx_processing_tools>=1.1.32
|
||||||
|
- tf_keras>=2.16.0
|
||||||
|
- git+https://github.com/microsoft/onnxconverter-common.git
|
||||||
|
|||||||
@@ -0,0 +1,99 @@
|
|||||||
|
import os
|
||||||
|
import platform
|
||||||
|
from typing import Annotated, Optional
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from onnx2ann.export import Exporter, ModelType, Precision
|
||||||
|
|
||||||
|
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def export(
|
||||||
|
model_name: Annotated[
|
||||||
|
str, typer.Argument(..., help="The name of the model to be exported as it exists in Hugging Face.")
|
||||||
|
],
|
||||||
|
model_type: Annotated[ModelType, typer.Option(..., "--type", "-t", help="The type of model to be exported.")],
|
||||||
|
input_shapes: Annotated[
|
||||||
|
list[str],
|
||||||
|
typer.Option(
|
||||||
|
...,
|
||||||
|
"--input-shape",
|
||||||
|
"-s",
|
||||||
|
help="The shape of an input tensor to the model, each dimension separated by commas. "
|
||||||
|
"Multiple shapes can be provided for multiple inputs.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
precision: Annotated[
|
||||||
|
Precision,
|
||||||
|
typer.Option(
|
||||||
|
...,
|
||||||
|
"--precision",
|
||||||
|
"-p",
|
||||||
|
help="The precision of the exported model. `float16` requires a GPU.",
|
||||||
|
),
|
||||||
|
] = Precision.FLOAT32,
|
||||||
|
cache_dir: Annotated[
|
||||||
|
str,
|
||||||
|
typer.Option(
|
||||||
|
...,
|
||||||
|
"--cache-dir",
|
||||||
|
"-c",
|
||||||
|
help="Directory where pre-export models will be stored.",
|
||||||
|
envvar="CACHE_DIR",
|
||||||
|
show_envvar=True,
|
||||||
|
),
|
||||||
|
] = "~/.cache/huggingface",
|
||||||
|
output_dir: Annotated[
|
||||||
|
str,
|
||||||
|
typer.Option(
|
||||||
|
...,
|
||||||
|
"--output-dir",
|
||||||
|
"-o",
|
||||||
|
help="Directory where exported models will be stored.",
|
||||||
|
),
|
||||||
|
] = "output",
|
||||||
|
auth_token: Annotated[
|
||||||
|
Optional[str],
|
||||||
|
typer.Option(
|
||||||
|
...,
|
||||||
|
"--auth-token",
|
||||||
|
"-t",
|
||||||
|
help="If uploading models to Hugging Face, the auth token of the user or organisation.",
|
||||||
|
envvar="HF_AUTH_TOKEN",
|
||||||
|
show_envvar=True,
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
force_export: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
...,
|
||||||
|
"--force-export",
|
||||||
|
"-f",
|
||||||
|
help="Export the model even if an exported model already exists in the output directory.",
|
||||||
|
),
|
||||||
|
] = False,
|
||||||
|
) -> None:
|
||||||
|
if platform.machine() not in ("x86_64", "AMD64"):
|
||||||
|
msg = f"Can only run on x86_64 / AMD64, not {platform.machine()}"
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
os.environ.setdefault("LD_LIBRARY_PATH", "armnn")
|
||||||
|
parsed_input_shapes = [tuple(map(int, shape.split(","))) for shape in input_shapes]
|
||||||
|
model = Exporter(
|
||||||
|
model_name, model_type, input_shapes=parsed_input_shapes, cache_dir=cache_dir, force_export=force_export
|
||||||
|
)
|
||||||
|
model_dir = os.path.join("output", model_name)
|
||||||
|
output_dir = os.path.join(model_dir, model_type)
|
||||||
|
armnn_model = model.to_armnn(output_dir, precision)
|
||||||
|
|
||||||
|
if not auth_token:
|
||||||
|
return
|
||||||
|
|
||||||
|
from huggingface_hub import upload_file
|
||||||
|
|
||||||
|
relative_path = os.path.relpath(armnn_model, start=model_dir)
|
||||||
|
upload_file(path_or_fileobj=armnn_model, path_in_repo=relative_path, repo_id=model.repo_name, token=auth_token)
|
||||||
|
|
||||||
|
|
||||||
|
app()
|
||||||
@@ -0,0 +1,129 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
from onnx2ann.helpers import onnx_make_armnn_compatible, onnx_make_inputs_fixed
|
||||||
|
|
||||||
|
|
||||||
|
class ModelType(StrEnum):
|
||||||
|
VISUAL = "visual"
|
||||||
|
TEXTUAL = "textual"
|
||||||
|
RECOGNITION = "recognition"
|
||||||
|
DETECTION = "detection"
|
||||||
|
|
||||||
|
|
||||||
|
class Precision(StrEnum):
|
||||||
|
FLOAT16 = "float16"
|
||||||
|
FLOAT32 = "float32"
|
||||||
|
|
||||||
|
|
||||||
|
class Exporter:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
model_type: str,
|
||||||
|
input_shapes: list[tuple[int, ...]],
|
||||||
|
optimization_level: int = 5,
|
||||||
|
cache_dir: str = os.environ.get("CACHE_DIR", "~/.cache/huggingface"),
|
||||||
|
force_export: bool = False,
|
||||||
|
):
|
||||||
|
self.model_name = model_name.split("/")[-1]
|
||||||
|
self.model_type = model_type
|
||||||
|
self.optimize = optimization_level
|
||||||
|
self.input_shapes = input_shapes
|
||||||
|
self.cache_dir = os.path.join(cache_dir, self.repo_name)
|
||||||
|
self.force_export = force_export
|
||||||
|
|
||||||
|
def download(self) -> str:
|
||||||
|
model_path = os.path.join(self.cache_dir, self.model_type, "model.onnx")
|
||||||
|
if os.path.isfile(model_path):
|
||||||
|
print(f"Model is already downloaded at {model_path}")
|
||||||
|
return model_path
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
snapshot_download(
|
||||||
|
self.repo_name, cache_dir=self.cache_dir, local_dir=self.cache_dir, local_dir_use_symlinks=False
|
||||||
|
)
|
||||||
|
return model_path
|
||||||
|
|
||||||
|
def to_onnx_static(self, precision: Precision) -> str:
|
||||||
|
import onnx
|
||||||
|
from onnxconverter_common import float16
|
||||||
|
onnx_path_original = self.download()
|
||||||
|
static_dir = os.path.join(self.cache_dir, self.model_type, "static")
|
||||||
|
|
||||||
|
static_path = os.path.join(static_dir, f"model.onnx")
|
||||||
|
if self.force_export and not os.path.isfile(static_path):
|
||||||
|
print(f"Making {self} static")
|
||||||
|
os.makedirs(static_dir, exist_ok=True)
|
||||||
|
onnx_make_inputs_fixed(onnx_path_original, static_path, self.input_shapes)
|
||||||
|
onnx_make_armnn_compatible(static_path)
|
||||||
|
print(f"Finished making {self} static")
|
||||||
|
|
||||||
|
model = onnx.load(static_path)
|
||||||
|
self.inputs = [input_.name for input_ in model.graph.input]
|
||||||
|
self.outputs = [output_.name for output_ in model.graph.output]
|
||||||
|
if precision == Precision.FLOAT16:
|
||||||
|
static_path = os.path.join(static_dir, f"model_{precision}.onnx")
|
||||||
|
print(f"Converting {self} to {precision} precision")
|
||||||
|
model = float16.convert_float_to_float16(model, keep_io_types=True, disable_shape_infer=True)
|
||||||
|
onnx.save(model, static_path)
|
||||||
|
print(f"Finished converting {self} to {precision} precision")
|
||||||
|
# self.inputs, self.outputs = onnx_get_inputs_outputs(static_path)
|
||||||
|
return static_path
|
||||||
|
|
||||||
|
def to_tflite(self, output_dir: str, precision: Precision) -> str:
|
||||||
|
onnx_model = self.to_onnx_static(precision)
|
||||||
|
tflite_dir = os.path.join(output_dir, precision)
|
||||||
|
tflite_model = os.path.join(tflite_dir, f"model_{precision}.tflite")
|
||||||
|
if self.force_export or not os.path.isfile(tflite_model):
|
||||||
|
import onnx2tf
|
||||||
|
|
||||||
|
print(f"Exporting {self} to TFLite with {precision} precision (this might take a few minutes)")
|
||||||
|
onnx2tf.convert(
|
||||||
|
input_onnx_file_path=onnx_model,
|
||||||
|
output_folder_path=tflite_dir,
|
||||||
|
keep_shape_absolutely_input_names=self.inputs,
|
||||||
|
# verbosity="warn",
|
||||||
|
copy_onnx_input_output_names_to_tflite=True,
|
||||||
|
output_signaturedefs=True,
|
||||||
|
not_use_onnxsim=True,
|
||||||
|
)
|
||||||
|
print(f"Finished exporting {self} to TFLite with {precision} precision")
|
||||||
|
|
||||||
|
return tflite_model
|
||||||
|
|
||||||
|
def to_armnn(self, output_dir: str, precision: Precision) -> tuple[str, str]:
|
||||||
|
armnn_model = os.path.join(output_dir, "model.armnn")
|
||||||
|
if not self.force_export and os.path.isfile(armnn_model):
|
||||||
|
return armnn_model
|
||||||
|
|
||||||
|
tflite_model_dir = os.path.join(output_dir, "tflite")
|
||||||
|
tflite_model = self.to_tflite(tflite_model_dir, precision)
|
||||||
|
|
||||||
|
args = ["./armnnconverter", "-f", "tflite-binary", "-m", tflite_model, "-p", armnn_model]
|
||||||
|
args.append("-i")
|
||||||
|
args.extend(self.inputs)
|
||||||
|
args.append("-o")
|
||||||
|
args.extend(self.outputs)
|
||||||
|
|
||||||
|
print(f"Exporting {self} to ARM NN with {precision} precision")
|
||||||
|
try:
|
||||||
|
if (stdout := subprocess.check_output(args, stderr=subprocess.STDOUT).decode()):
|
||||||
|
print(stdout)
|
||||||
|
print(f"Finished exporting {self} to ARM NN with {precision} precision")
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(e.output.decode())
|
||||||
|
try:
|
||||||
|
from shutil import rmtree
|
||||||
|
|
||||||
|
rmtree(tflite_model_dir, ignore_errors=True)
|
||||||
|
finally:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
@property
|
||||||
|
def repo_name(self) -> str:
|
||||||
|
return f"immich-app/{self.model_name}"
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{self.model_name} ({self.model_type})"
|
||||||
@@ -0,0 +1,260 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def onnx_make_armnn_compatible(model_path: str) -> None:
|
||||||
|
"""
|
||||||
|
i can explain
|
||||||
|
armnn only supports up to 4d tranposes, but the model has a 5d transpose due to a redundant unsqueeze
|
||||||
|
this function folds the unsqueeze+transpose+squeeze into a single 4d transpose
|
||||||
|
it also switches from gather ops to slices since armnn has different dimension semantics for gathers
|
||||||
|
also fixes batch normalization being in training mode
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import onnx
|
||||||
|
from onnx_graphsurgeon import Constant, Node, Variable, export_onnx, import_onnx
|
||||||
|
|
||||||
|
proto = onnx.load(model_path)
|
||||||
|
graph = import_onnx(proto)
|
||||||
|
|
||||||
|
gather_idx = 1
|
||||||
|
squeeze_idx = 1
|
||||||
|
for node in graph.nodes:
|
||||||
|
for link1 in node.outputs:
|
||||||
|
if "Unsqueeze" in link1.name:
|
||||||
|
for node1 in link1.outputs:
|
||||||
|
for link2 in node1.outputs:
|
||||||
|
if "Transpose" in link2.name:
|
||||||
|
for node2 in link2.outputs:
|
||||||
|
if node2.attrs.get("perm") == [3, 1, 2, 0, 4]:
|
||||||
|
node2.attrs["perm"] = [2, 0, 1, 3]
|
||||||
|
link2.shape = link1.shape
|
||||||
|
for link3 in node2.outputs:
|
||||||
|
if "Squeeze" in link3.name:
|
||||||
|
link3.shape = [link3.shape[x] for x in [0, 1, 2, 4]]
|
||||||
|
for node3 in link3.outputs:
|
||||||
|
for link4 in node3.outputs:
|
||||||
|
link4.shape = link3.shape
|
||||||
|
try:
|
||||||
|
idx = link2.inputs.index(node1)
|
||||||
|
link2.inputs[idx] = node
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
node.outputs = [link2]
|
||||||
|
if "Gather" in link4.name:
|
||||||
|
for node4 in link4.outputs:
|
||||||
|
axis = node1.attrs.get("axis", 0)
|
||||||
|
index = node4.inputs[1].values
|
||||||
|
slice_link = Variable(
|
||||||
|
f"onnx::Slice_123{gather_idx}",
|
||||||
|
dtype=link4.dtype,
|
||||||
|
shape=[1] + link3.shape[1:],
|
||||||
|
)
|
||||||
|
slice_node = Node(
|
||||||
|
op="Slice",
|
||||||
|
inputs=[
|
||||||
|
link3,
|
||||||
|
Constant(
|
||||||
|
f"SliceStart_123{gather_idx}",
|
||||||
|
np.array([index]),
|
||||||
|
),
|
||||||
|
Constant(
|
||||||
|
f"SliceEnd_123{gather_idx}",
|
||||||
|
np.array([index + 1]),
|
||||||
|
),
|
||||||
|
Constant(
|
||||||
|
f"SliceAxis_123{gather_idx}",
|
||||||
|
np.array([axis]),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[slice_link],
|
||||||
|
name=f"Slice_123{gather_idx}",
|
||||||
|
)
|
||||||
|
graph.nodes.append(slice_node)
|
||||||
|
gather_idx += 1
|
||||||
|
|
||||||
|
for link5 in node4.outputs:
|
||||||
|
for node5 in link5.outputs:
|
||||||
|
try:
|
||||||
|
idx = node5.inputs.index(link5)
|
||||||
|
node5.inputs[idx] = slice_link
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
elif node.op == "LayerNormalization":
|
||||||
|
for node1 in link1.outputs:
|
||||||
|
if node1.op == "Gather":
|
||||||
|
for link2 in node1.outputs:
|
||||||
|
for node2 in link2.outputs:
|
||||||
|
axis = node1.attrs.get("axis", 0)
|
||||||
|
index = node1.inputs[1].values
|
||||||
|
slice_link = Variable(
|
||||||
|
f"onnx::Slice_123{gather_idx}",
|
||||||
|
dtype=link2.dtype,
|
||||||
|
shape=[1, *link2.shape],
|
||||||
|
)
|
||||||
|
slice_node = Node(
|
||||||
|
op="Slice",
|
||||||
|
inputs=[
|
||||||
|
node1.inputs[0],
|
||||||
|
Constant(
|
||||||
|
f"SliceStart_123{gather_idx}",
|
||||||
|
np.array([index]),
|
||||||
|
),
|
||||||
|
Constant(
|
||||||
|
f"SliceEnd_123{gather_idx}",
|
||||||
|
np.array([index + 1]),
|
||||||
|
),
|
||||||
|
Constant(
|
||||||
|
f"SliceAxis_123{gather_idx}",
|
||||||
|
np.array([axis]),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[slice_link],
|
||||||
|
name=f"Slice_123{gather_idx}",
|
||||||
|
)
|
||||||
|
graph.nodes.append(slice_node)
|
||||||
|
gather_idx += 1
|
||||||
|
|
||||||
|
squeeze_link = Variable(
|
||||||
|
f"onnx::Squeeze_123{squeeze_idx}",
|
||||||
|
dtype=link2.dtype,
|
||||||
|
shape=link2.shape,
|
||||||
|
)
|
||||||
|
squeeze_node = Node(
|
||||||
|
op="Squeeze",
|
||||||
|
inputs=[
|
||||||
|
slice_link,
|
||||||
|
Constant(
|
||||||
|
f"SqueezeAxis_123{squeeze_idx}",
|
||||||
|
np.array([0]),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[squeeze_link],
|
||||||
|
name=f"Squeeze_123{squeeze_idx}",
|
||||||
|
)
|
||||||
|
graph.nodes.append(squeeze_node)
|
||||||
|
squeeze_idx += 1
|
||||||
|
try:
|
||||||
|
idx = node2.inputs.index(link2)
|
||||||
|
node2.inputs[idx] = squeeze_link
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
elif node.op == "Reshape":
|
||||||
|
for node1 in link1.outputs:
|
||||||
|
if node1.op == "Gather":
|
||||||
|
node2s = [n for link in node1.outputs for n in link.outputs]
|
||||||
|
if any(n.op == "Abs" for n in node2s):
|
||||||
|
axis = node1.attrs.get("axis", 0)
|
||||||
|
index = node1.inputs[1].values
|
||||||
|
slice_link = Variable(
|
||||||
|
f"onnx::Slice_123{gather_idx}",
|
||||||
|
dtype=node1.outputs[0].dtype,
|
||||||
|
shape=[1, *node1.outputs[0].shape],
|
||||||
|
)
|
||||||
|
slice_node = Node(
|
||||||
|
op="Slice",
|
||||||
|
inputs=[
|
||||||
|
node1.inputs[0],
|
||||||
|
Constant(
|
||||||
|
f"SliceStart_123{gather_idx}",
|
||||||
|
np.array([index]),
|
||||||
|
),
|
||||||
|
Constant(
|
||||||
|
f"SliceEnd_123{gather_idx}",
|
||||||
|
np.array([index + 1]),
|
||||||
|
),
|
||||||
|
Constant(
|
||||||
|
f"SliceAxis_123{gather_idx}",
|
||||||
|
np.array([axis]),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[slice_link],
|
||||||
|
name=f"Slice_123{gather_idx}",
|
||||||
|
)
|
||||||
|
graph.nodes.append(slice_node)
|
||||||
|
gather_idx += 1
|
||||||
|
|
||||||
|
squeeze_link = Variable(
|
||||||
|
f"onnx::Squeeze_123{squeeze_idx}",
|
||||||
|
dtype=node1.outputs[0].dtype,
|
||||||
|
shape=node1.outputs[0].shape,
|
||||||
|
)
|
||||||
|
squeeze_node = Node(
|
||||||
|
op="Squeeze",
|
||||||
|
inputs=[
|
||||||
|
slice_link,
|
||||||
|
Constant(
|
||||||
|
f"SqueezeAxis_123{squeeze_idx}",
|
||||||
|
np.array([0]),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[squeeze_link],
|
||||||
|
name=f"Squeeze_123{squeeze_idx}",
|
||||||
|
)
|
||||||
|
graph.nodes.append(squeeze_node)
|
||||||
|
squeeze_idx += 1
|
||||||
|
for node2 in node2s:
|
||||||
|
node2.inputs[0] = squeeze_link
|
||||||
|
elif node.op == "BatchNormalization" and node.attrs.get("training_mode") == 1:
|
||||||
|
node.attrs["training_mode"] = 0
|
||||||
|
node.outputs = node.outputs[:1]
|
||||||
|
|
||||||
|
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
|
||||||
|
graph.toposort()
|
||||||
|
graph.fold_constants()
|
||||||
|
updated = export_onnx(graph)
|
||||||
|
onnx_save(updated, model_path)
|
||||||
|
|
||||||
|
# for some reason, reloading the model is necessary to apply the correct shape
|
||||||
|
proto = onnx.load(model_path)
|
||||||
|
graph = import_onnx(proto)
|
||||||
|
for node in graph.nodes:
|
||||||
|
if node.op == "Slice":
|
||||||
|
for link in node.outputs:
|
||||||
|
if "Slice_123" in link.name and link.shape[0] == 3: # noqa: PLR2004
|
||||||
|
link.shape[0] = 1
|
||||||
|
|
||||||
|
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
|
||||||
|
graph.toposort()
|
||||||
|
graph.fold_constants()
|
||||||
|
updated = export_onnx(graph)
|
||||||
|
onnx_save(updated, model_path)
|
||||||
|
onnx.shape_inference.infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)
|
||||||
|
|
||||||
|
|
||||||
|
def onnx_make_inputs_fixed(input_path: str, output_path: str, input_shapes: list[tuple[int, ...]]) -> None:
|
||||||
|
import onnx
|
||||||
|
import onnxsim
|
||||||
|
from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
|
||||||
|
|
||||||
|
model, success = onnxsim.simplify(input_path)
|
||||||
|
if not success:
|
||||||
|
msg = f"Failed to simplify {input_path}"
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
onnx_save(model, output_path)
|
||||||
|
onnx.shape_inference.infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
|
||||||
|
model = onnx.load_model(output_path)
|
||||||
|
for input_node, shape in zip(model.graph.input, input_shapes, strict=False):
|
||||||
|
make_input_shape_fixed(model.graph, input_node.name, shape)
|
||||||
|
fix_output_shapes(model)
|
||||||
|
onnx_save(model, output_path)
|
||||||
|
onnx.shape_inference.infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
|
||||||
|
|
||||||
|
|
||||||
|
def onnx_get_inputs_outputs(model_path: str) -> tuple[list[str], list[str]]:
|
||||||
|
import onnx
|
||||||
|
|
||||||
|
model = onnx.load(model_path)
|
||||||
|
inputs = [input_.name for input_ in model.graph.input]
|
||||||
|
outputs = [output_.name for output_ in model.graph.output]
|
||||||
|
return inputs, outputs
|
||||||
|
|
||||||
|
|
||||||
|
def onnx_save(model: Any, output_path: str) -> None:
|
||||||
|
import onnx
|
||||||
|
|
||||||
|
try:
|
||||||
|
onnx.save(model, output_path)
|
||||||
|
except:
|
||||||
|
onnx.save(model, output_path, save_as_external_data=True, all_tensors_to_one_file=False, size_threshold=1_000_000)
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
[project]
|
||||||
|
name = "onnx2ann"
|
||||||
|
version = "1.107.2"
|
||||||
|
dependencies = [
|
||||||
|
"onnx>=1.16.1",
|
||||||
|
"psutil>=6.0.0",
|
||||||
|
"flatbuffers>=24.3.25",
|
||||||
|
"ml_dtypes>=0.3.1,<1.0.0",
|
||||||
|
"typer-slim>=0.12.3,<1.0.0",
|
||||||
|
"huggingface_hub>=0.23.4,<1.0.0",
|
||||||
|
"onnxruntime>=1.18.1",
|
||||||
|
"onnxsim>=0.4.36,<1.0.0",
|
||||||
|
"onnx2tf>=1.24.0",
|
||||||
|
"onnx_graphsurgeon>=0.5.2,<1.0.0",
|
||||||
|
"simple_onnx_processing_tools>=1.1.32",
|
||||||
|
"tf_keras>=2.16.0",
|
||||||
|
"onnxconverter-common @ git+https://github.com/microsoft/onnxconverter-common"
|
||||||
|
]
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.sdist]
|
||||||
|
only-include = ["onnx2ann"]
|
||||||
|
|
||||||
|
[tool.hatch.metadata]
|
||||||
|
allow-direct-references = true
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.12"
|
||||||
|
follow_imports = "silent"
|
||||||
|
warn_redundant_casts = true
|
||||||
|
disallow_any_generics = true
|
||||||
|
check_untyped_defs = true
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[tool.pydantic-mypy]
|
||||||
|
init_forbid_extra = true
|
||||||
|
init_typed = true
|
||||||
|
warn_required_dynamic_aliases = true
|
||||||
|
warn_untyped_fields = true
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 120
|
||||||
|
target-version = "py312"
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
extend-select = ["E", "F", "I"]
|
||||||
|
extend-ignore = ["FBT001", "FBT002"]
|
||||||
|
|
||||||
|
[tool.black]
|
||||||
|
line-length = 120
|
||||||
|
target-version = ['py312']
|
||||||
@@ -1,475 +0,0 @@
|
|||||||
import os
|
|
||||||
import platform
|
|
||||||
import subprocess
|
|
||||||
from typing import Callable, ClassVar
|
|
||||||
|
|
||||||
import onnx
|
|
||||||
from onnx_graphsurgeon import Constant, Node, Variable, import_onnx, export_onnx
|
|
||||||
from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from onnx.shape_inference import infer_shapes_path
|
|
||||||
from huggingface_hub import login, upload_file
|
|
||||||
import onnx2tf
|
|
||||||
import numpy as np
|
|
||||||
import onnxsim
|
|
||||||
from shutil import rmtree
|
|
||||||
|
|
||||||
# hack: changed Mul op in onnx2tf to skip broadcast if graph_node.o().op == 'Sigmoid'
|
|
||||||
|
|
||||||
# i can explain
|
|
||||||
# armnn only supports up to 4d tranposes, but the model has a 5d transpose due to a redundant unsqueeze
|
|
||||||
# this function folds the unsqueeze+transpose+squeeze into a single 4d transpose
|
|
||||||
# it also switches from gather ops to slices since armnn has different dimension semantics for gathers
|
|
||||||
# also fixes batch normalization being in training mode
|
|
||||||
def make_onnx_armnn_compatible(model_path: str):
|
|
||||||
proto = onnx.load(model_path)
|
|
||||||
graph = import_onnx(proto)
|
|
||||||
|
|
||||||
gather_idx = 1
|
|
||||||
squeeze_idx = 1
|
|
||||||
for node in graph.nodes:
|
|
||||||
for link1 in node.outputs:
|
|
||||||
if "Unsqueeze" in link1.name:
|
|
||||||
for node1 in link1.outputs:
|
|
||||||
for link2 in node1.outputs:
|
|
||||||
if "Transpose" in link2.name:
|
|
||||||
for node2 in link2.outputs:
|
|
||||||
if node2.attrs.get("perm") == [3, 1, 2, 0, 4]:
|
|
||||||
node2.attrs["perm"] = [2, 0, 1, 3]
|
|
||||||
link2.shape = link1.shape
|
|
||||||
for link3 in node2.outputs:
|
|
||||||
if "Squeeze" in link3.name:
|
|
||||||
link3.shape = [link3.shape[x] for x in [0, 1, 2, 4]]
|
|
||||||
for node3 in link3.outputs:
|
|
||||||
for link4 in node3.outputs:
|
|
||||||
link4.shape = link3.shape
|
|
||||||
try:
|
|
||||||
idx = link2.inputs.index(node1)
|
|
||||||
link2.inputs[idx] = node
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
node.outputs = [link2]
|
|
||||||
if "Gather" in link4.name:
|
|
||||||
for node4 in link4.outputs:
|
|
||||||
axis = node1.attrs.get("axis", 0)
|
|
||||||
index = node4.inputs[1].values
|
|
||||||
slice_link = Variable(
|
|
||||||
f"onnx::Slice_123{gather_idx}",
|
|
||||||
dtype=link4.dtype,
|
|
||||||
shape=[1] + link3.shape[1:],
|
|
||||||
)
|
|
||||||
slice_node = Node(
|
|
||||||
op="Slice",
|
|
||||||
inputs=[
|
|
||||||
link3,
|
|
||||||
Constant(
|
|
||||||
f"SliceStart_123{gather_idx}",
|
|
||||||
np.array([index]),
|
|
||||||
),
|
|
||||||
Constant(
|
|
||||||
f"SliceEnd_123{gather_idx}",
|
|
||||||
np.array([index + 1]),
|
|
||||||
),
|
|
||||||
Constant(
|
|
||||||
f"SliceAxis_123{gather_idx}",
|
|
||||||
np.array([axis]),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
outputs=[slice_link],
|
|
||||||
name=f"Slice_123{gather_idx}",
|
|
||||||
)
|
|
||||||
graph.nodes.append(slice_node)
|
|
||||||
gather_idx += 1
|
|
||||||
|
|
||||||
for link5 in node4.outputs:
|
|
||||||
for node5 in link5.outputs:
|
|
||||||
try:
|
|
||||||
idx = node5.inputs.index(link5)
|
|
||||||
node5.inputs[idx] = slice_link
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
elif node.op == "LayerNormalization":
|
|
||||||
for node1 in link1.outputs:
|
|
||||||
if node1.op == "Gather":
|
|
||||||
for link2 in node1.outputs:
|
|
||||||
for node2 in link2.outputs:
|
|
||||||
axis = node1.attrs.get("axis", 0)
|
|
||||||
index = node1.inputs[1].values
|
|
||||||
slice_link = Variable(
|
|
||||||
f"onnx::Slice_123{gather_idx}",
|
|
||||||
dtype=link2.dtype,
|
|
||||||
shape=[1] + link2.shape,
|
|
||||||
)
|
|
||||||
slice_node = Node(
|
|
||||||
op="Slice",
|
|
||||||
inputs=[
|
|
||||||
node1.inputs[0],
|
|
||||||
Constant(
|
|
||||||
f"SliceStart_123{gather_idx}",
|
|
||||||
np.array([index]),
|
|
||||||
),
|
|
||||||
Constant(
|
|
||||||
f"SliceEnd_123{gather_idx}",
|
|
||||||
np.array([index + 1]),
|
|
||||||
),
|
|
||||||
Constant(
|
|
||||||
f"SliceAxis_123{gather_idx}",
|
|
||||||
np.array([axis]),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
outputs=[slice_link],
|
|
||||||
name=f"Slice_123{gather_idx}",
|
|
||||||
)
|
|
||||||
graph.nodes.append(slice_node)
|
|
||||||
gather_idx += 1
|
|
||||||
|
|
||||||
squeeze_link = Variable(
|
|
||||||
f"onnx::Squeeze_123{squeeze_idx}",
|
|
||||||
dtype=link2.dtype,
|
|
||||||
shape=link2.shape,
|
|
||||||
)
|
|
||||||
squeeze_node = Node(
|
|
||||||
op="Squeeze",
|
|
||||||
inputs=[slice_link, Constant(f"SqueezeAxis_123{squeeze_idx}",np.array([0]),)],
|
|
||||||
outputs=[squeeze_link],
|
|
||||||
name=f"Squeeze_123{squeeze_idx}",
|
|
||||||
)
|
|
||||||
graph.nodes.append(squeeze_node)
|
|
||||||
squeeze_idx += 1
|
|
||||||
try:
|
|
||||||
idx = node2.inputs.index(link2)
|
|
||||||
node2.inputs[idx] = squeeze_link
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
elif node.op == "Reshape":
|
|
||||||
for node1 in link1.outputs:
|
|
||||||
if node1.op == "Gather":
|
|
||||||
node2s = [n for l in node1.outputs for n in l.outputs]
|
|
||||||
if any(n.op == "Abs" for n in node2s):
|
|
||||||
axis = node1.attrs.get("axis", 0)
|
|
||||||
index = node1.inputs[1].values
|
|
||||||
slice_link = Variable(
|
|
||||||
f"onnx::Slice_123{gather_idx}",
|
|
||||||
dtype=node1.outputs[0].dtype,
|
|
||||||
shape=[1] + node1.outputs[0].shape,
|
|
||||||
)
|
|
||||||
slice_node = Node(
|
|
||||||
op="Slice",
|
|
||||||
inputs=[
|
|
||||||
node1.inputs[0],
|
|
||||||
Constant(
|
|
||||||
f"SliceStart_123{gather_idx}",
|
|
||||||
np.array([index]),
|
|
||||||
),
|
|
||||||
Constant(
|
|
||||||
f"SliceEnd_123{gather_idx}",
|
|
||||||
np.array([index + 1]),
|
|
||||||
),
|
|
||||||
Constant(
|
|
||||||
f"SliceAxis_123{gather_idx}",
|
|
||||||
np.array([axis]),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
outputs=[slice_link],
|
|
||||||
name=f"Slice_123{gather_idx}",
|
|
||||||
)
|
|
||||||
graph.nodes.append(slice_node)
|
|
||||||
gather_idx += 1
|
|
||||||
|
|
||||||
squeeze_link = Variable(
|
|
||||||
f"onnx::Squeeze_123{squeeze_idx}",
|
|
||||||
dtype=node1.outputs[0].dtype,
|
|
||||||
shape=node1.outputs[0].shape,
|
|
||||||
)
|
|
||||||
squeeze_node = Node(
|
|
||||||
op="Squeeze",
|
|
||||||
inputs=[slice_link, Constant(f"SqueezeAxis_123{squeeze_idx}",np.array([0]),)],
|
|
||||||
outputs=[squeeze_link],
|
|
||||||
name=f"Squeeze_123{squeeze_idx}",
|
|
||||||
)
|
|
||||||
graph.nodes.append(squeeze_node)
|
|
||||||
squeeze_idx += 1
|
|
||||||
for node2 in node2s:
|
|
||||||
node2.inputs[0] = squeeze_link
|
|
||||||
elif node.op == "BatchNormalization":
|
|
||||||
if node.attrs.get("training_mode") == 1:
|
|
||||||
node.attrs["training_mode"] = 0
|
|
||||||
node.outputs = node.outputs[:1]
|
|
||||||
|
|
||||||
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
|
|
||||||
graph.toposort()
|
|
||||||
graph.fold_constants()
|
|
||||||
updated = export_onnx(graph)
|
|
||||||
onnx.save(updated, model_path)
|
|
||||||
# infer_shapes_path(updated, check_type=True, strict_mode=False, data_prop=True)
|
|
||||||
|
|
||||||
# for some reason, reloading the model is necessary to apply the correct shape
|
|
||||||
proto = onnx.load(model_path)
|
|
||||||
graph = import_onnx(proto)
|
|
||||||
for node in graph.nodes:
|
|
||||||
if node.op == "Slice":
|
|
||||||
for link in node.outputs:
|
|
||||||
if "Slice_123" in link.name and link.shape[0] == 3:
|
|
||||||
link.shape[0] = 1
|
|
||||||
|
|
||||||
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
|
|
||||||
graph.toposort()
|
|
||||||
graph.fold_constants()
|
|
||||||
updated = export_onnx(graph)
|
|
||||||
onnx.save(updated, model_path)
|
|
||||||
infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)
|
|
||||||
|
|
||||||
|
|
||||||
def onnx_make_fixed(input_path: str, output_path: str, input_shape: tuple[int, ...]):
|
|
||||||
simplified, success = onnxsim.simplify(input_path)
|
|
||||||
if not success:
|
|
||||||
raise RuntimeError(f"Failed to simplify {input_path}")
|
|
||||||
try:
|
|
||||||
onnx.save(simplified, output_path)
|
|
||||||
except:
|
|
||||||
onnx.save(simplified, output_path, save_as_external_data=True, all_tensors_to_one_file=False)
|
|
||||||
infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
|
|
||||||
model = onnx.load_model(output_path)
|
|
||||||
make_input_shape_fixed(model.graph, model.graph.input[0].name, input_shape)
|
|
||||||
fix_output_shapes(model)
|
|
||||||
try:
|
|
||||||
onnx.save(model, output_path)
|
|
||||||
except:
|
|
||||||
onnx.save(model, output_path, save_as_external_data=True, all_tensors_to_one_file=False)
|
|
||||||
onnx.save(model, output_path)
|
|
||||||
infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
|
|
||||||
|
|
||||||
|
|
||||||
class ExportBase:
|
|
||||||
task: ClassVar[str]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
input_shape: tuple[int, ...],
|
|
||||||
pretrained: str | None = None,
|
|
||||||
optimization_level: int = 5,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.name = name
|
|
||||||
self.optimize = optimization_level
|
|
||||||
self.input_shape = input_shape
|
|
||||||
self.pretrained = pretrained
|
|
||||||
self.cache_dir = os.path.join(os.environ["CACHE_DIR"], self.model_name)
|
|
||||||
|
|
||||||
def download(self) -> str:
|
|
||||||
model_path = os.path.join(self.cache_dir, self.task, "model.onnx")
|
|
||||||
if not os.path.isfile(model_path):
|
|
||||||
print(f"Downloading {self.model_name}...")
|
|
||||||
snapshot_download(self.repo_name, cache_dir=self.cache_dir, local_dir=self.cache_dir, local_dir_use_symlinks=False)
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
def to_onnx_static(self) -> str:
|
|
||||||
onnx_path_original = self.download()
|
|
||||||
static_dir = os.path.join(self.cache_dir, self.task, "static")
|
|
||||||
os.makedirs(static_dir, exist_ok=True)
|
|
||||||
|
|
||||||
static_path = os.path.join(static_dir, "model.onnx")
|
|
||||||
if not os.path.isfile(static_path):
|
|
||||||
print(f"Making {self.model_name} ({self.task}) static")
|
|
||||||
onnx_make_fixed(onnx_path_original, static_path, self.input_shape)
|
|
||||||
make_onnx_armnn_compatible(static_path)
|
|
||||||
static_model = onnx.load_model(static_path)
|
|
||||||
self.inputs = [input_.name for input_ in static_model.graph.input]
|
|
||||||
self.outputs = [output_.name for output_ in static_model.graph.output]
|
|
||||||
return static_path
|
|
||||||
|
|
||||||
def to_tflite(self, output_dir: str) -> tuple[str, str]:
|
|
||||||
input_path = self.to_onnx_static()
|
|
||||||
tflite_fp32 = os.path.join(output_dir, "model_float32.tflite")
|
|
||||||
tflite_fp16 = os.path.join(output_dir, "model_float16.tflite")
|
|
||||||
if not os.path.isfile(tflite_fp32) or not os.path.isfile(tflite_fp16):
|
|
||||||
print(f"Exporting {self.model_name} ({self.task}) to TFLite (this might take a few minutes)")
|
|
||||||
onnx2tf.convert(
|
|
||||||
input_onnx_file_path=input_path,
|
|
||||||
output_folder_path=output_dir,
|
|
||||||
keep_shape_absolutely_input_names=self.inputs,
|
|
||||||
verbosity="warn",
|
|
||||||
copy_onnx_input_output_names_to_tflite=True,
|
|
||||||
output_signaturedefs=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return tflite_fp32, tflite_fp16
|
|
||||||
|
|
||||||
def to_armnn(self, output_dir: str) -> tuple[str, str]:
|
|
||||||
output_dir = os.path.abspath(output_dir)
|
|
||||||
tflite_model_dir = os.path.join(output_dir, "tflite")
|
|
||||||
tflite_fp32, tflite_fp16 = self.to_tflite(tflite_model_dir)
|
|
||||||
|
|
||||||
fp16_dir = os.path.join(output_dir, "fp16")
|
|
||||||
os.makedirs(fp16_dir, exist_ok=True)
|
|
||||||
armnn_fp32 = os.path.join(output_dir, "model.armnn")
|
|
||||||
armnn_fp16 = os.path.join(fp16_dir, "model.armnn")
|
|
||||||
|
|
||||||
args = ["./armnnconverter", "-f", "tflite-binary"]
|
|
||||||
args.append("-i")
|
|
||||||
args.extend(self.inputs)
|
|
||||||
args.append("-o")
|
|
||||||
args.extend(self.outputs)
|
|
||||||
|
|
||||||
fp32_args = args.copy()
|
|
||||||
fp32_args.extend(["-m", tflite_fp32, "-p", armnn_fp32])
|
|
||||||
|
|
||||||
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp32 precision")
|
|
||||||
try:
|
|
||||||
print(subprocess.check_output(fp32_args, stderr=subprocess.STDOUT).decode())
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
print(e.output.decode())
|
|
||||||
try:
|
|
||||||
rmtree(tflite_model_dir, ignore_errors=True)
|
|
||||||
finally:
|
|
||||||
raise e
|
|
||||||
print(f"Finished exporting {self.model_name} ({self.task}) with fp32 precision")
|
|
||||||
|
|
||||||
fp16_args = args.copy()
|
|
||||||
fp16_args.extend(["-m", tflite_fp16, "-p", armnn_fp16])
|
|
||||||
|
|
||||||
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp16 precision")
|
|
||||||
try:
|
|
||||||
print(subprocess.check_output(fp16_args, stderr=subprocess.STDOUT).decode())
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
print(e.output.decode())
|
|
||||||
try:
|
|
||||||
rmtree(tflite_model_dir, ignore_errors=True)
|
|
||||||
finally:
|
|
||||||
raise e
|
|
||||||
print(f"Finished exporting {self.model_name} ({self.task}) with fp16 precision")
|
|
||||||
|
|
||||||
return armnn_fp32, armnn_fp16
|
|
||||||
|
|
||||||
@property
|
|
||||||
def model_name(self) -> str:
|
|
||||||
return f"{self.name}__{self.pretrained}" if self.pretrained else self.name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def repo_name(self) -> str:
|
|
||||||
return f"immich-app/{self.model_name}"
|
|
||||||
|
|
||||||
class ArcFace(ExportBase):
|
|
||||||
task = "recognition"
|
|
||||||
|
|
||||||
|
|
||||||
class RetinaFace(ExportBase):
|
|
||||||
task = "detection"
|
|
||||||
|
|
||||||
|
|
||||||
class OpenClipVisual(ExportBase):
|
|
||||||
task = "visual"
|
|
||||||
|
|
||||||
|
|
||||||
class OpenClipTextual(ExportBase):
|
|
||||||
task = "textual"
|
|
||||||
|
|
||||||
|
|
||||||
class MClipTextual(ExportBase):
|
|
||||||
task = "textual"
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
if platform.machine() not in ("x86_64", "AMD64"):
|
|
||||||
raise RuntimeError(f"Can only run on x86_64 / AMD64, not {platform.machine()}")
|
|
||||||
hf_token = os.environ.get("HF_AUTH_TOKEN")
|
|
||||||
if hf_token:
|
|
||||||
login(token=hf_token)
|
|
||||||
os.environ["LD_LIBRARY_PATH"] = "armnn"
|
|
||||||
failed: list[Callable[[], ExportBase]] = [
|
|
||||||
lambda: OpenClipVisual("ViT-H-14-378-quickgelu", (1, 3, 378, 378), pretrained="dfn5b"), # flatbuffers: cannot grow buffer beyond 2 gigabytes (will probably work with fp16)
|
|
||||||
lambda: OpenClipVisual("ViT-H-14-quickgelu", (1, 3, 224, 224), pretrained="dfn5b"), # flatbuffers: cannot grow buffer beyond 2 gigabytes (will probably work with fp16)
|
|
||||||
lambda: OpenClipVisual("ViT-H-14", (1, 3, 224, 224), pretrained="laion2b-s32b-b79k"),
|
|
||||||
lambda: OpenClipTextual("ViT-H-14", (1, 77), pretrained="laion2b-s32b-b79k"),
|
|
||||||
lambda: OpenClipVisual("ViT-g-14", (1, 3, 224, 224), pretrained="laion2b-s12b-b42k"),
|
|
||||||
lambda: OpenClipTextual("ViT-g-14", (1, 77), pretrained="laion2b-s12b-b42k"),
|
|
||||||
lambda: OpenClipVisual("XLM-Roberta-Large-Vit-B-16Plus", (1, 3, 240, 240)),
|
|
||||||
lambda: OpenClipVisual("XLM-Roberta-Large-ViT-H-14", (1, 3, 224, 224), pretrained="frozen_laion5b_s13b_b90k"),
|
|
||||||
lambda: MClipTextual("XLM-Roberta-Large-Vit-L-14", (1, 77)), # Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element, but got normalized_shape = []
|
|
||||||
lambda: MClipTextual("XLM-Roberta-Large-Vit-B-16Plus", (1, 77)), # Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element, but got normalized_shape = []
|
|
||||||
lambda: MClipTextual("LABSE-Vit-L-14", (1, 77)), # Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element, but got normalized_shape = []
|
|
||||||
lambda: OpenClipTextual("XLM-Roberta-Large-ViT-H-14", (1, 77), pretrained="frozen_laion5b_s13b_b90k"), # Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element, but got normalized_shape = []
|
|
||||||
]
|
|
||||||
|
|
||||||
oom = [
|
|
||||||
lambda: OpenClipVisual("nllb-clip-base-siglip", (1, 3, 384, 384), pretrained="v1"),
|
|
||||||
lambda: OpenClipTextual("nllb-clip-base-siglip", (1, 77), pretrained="v1"),
|
|
||||||
lambda: OpenClipVisual("nllb-clip-large-siglip", (1, 3, 384, 384), pretrained="v1"),
|
|
||||||
lambda: OpenClipTextual("nllb-clip-large-siglip", (1, 77), pretrained="v1"), # ERROR (tinynn.converter.base) Unsupported ops: aten::logical_not
|
|
||||||
# lambda: OpenClipTextual("ViT-H-14-quickgelu", (1, 77), pretrained="dfn5b"),
|
|
||||||
# lambda: OpenClipTextual("ViT-H-14-378-quickgelu", (1, 77), pretrained="dfn5b"),
|
|
||||||
# lambda: OpenClipVisual("XLM-Roberta-Large-Vit-L-14", (1, 3, 224, 224)),
|
|
||||||
]
|
|
||||||
|
|
||||||
succeeded: list[Callable[[], ExportBase]] = [
|
|
||||||
# lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="laion2b_e16"),
|
|
||||||
# lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="laion2b_e16"),
|
|
||||||
# lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="laion400m_e31"),
|
|
||||||
# lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="laion400m_e31"),
|
|
||||||
# lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="laion400m_e32"),
|
|
||||||
# lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="laion400m_e32"),
|
|
||||||
# lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="laion2b-s34b-b79k"),
|
|
||||||
# lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="laion2b-s34b-b79k"),
|
|
||||||
# lambda: OpenClipVisual("ViT-B-16", (1, 3, 224, 224), pretrained="laion400m_e31"),
|
|
||||||
# lambda: OpenClipTextual("ViT-B-16", (1, 77), pretrained="laion400m_e31"),
|
|
||||||
# lambda: OpenClipVisual("ViT-B-16", (1, 3, 224, 224), pretrained="laion400m_e32"),
|
|
||||||
# lambda: OpenClipTextual("ViT-B-16", (1, 77), pretrained="laion400m_e32"),
|
|
||||||
# lambda: OpenClipVisual("ViT-B-16-plus-240", (1, 3, 240, 240), pretrained="laion400m_e31"),
|
|
||||||
# lambda: OpenClipTextual("ViT-B-16-plus-240", (1, 77), pretrained="laion400m_e31"),
|
|
||||||
# lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="openai"),
|
|
||||||
# lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="openai"),
|
|
||||||
# lambda: OpenClipVisual("ViT-B-16", (1, 3, 224, 224), pretrained="openai"),
|
|
||||||
# lambda: OpenClipTextual("ViT-B-16", (1, 77), pretrained="openai"),
|
|
||||||
# lambda: OpenClipVisual("RN50", (1, 3, 224, 224), pretrained="openai"),
|
|
||||||
# lambda: OpenClipTextual("RN50", (1, 77), pretrained="openai"),
|
|
||||||
# lambda: OpenClipVisual("RN50", (1, 3, 224, 224), pretrained="yfcc15m"),
|
|
||||||
# lambda: OpenClipTextual("RN50", (1, 77), pretrained="yfcc15m"),
|
|
||||||
# lambda: OpenClipVisual("RN50", (1, 3, 224, 224), pretrained="cc12m"),
|
|
||||||
# lambda: OpenClipTextual("RN50", (1, 77), pretrained="cc12m"),
|
|
||||||
# lambda: OpenClipVisual("XLM-Roberta-Large-Vit-B-32", (1, 3, 224, 224)),
|
|
||||||
# lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="openai"),
|
|
||||||
# lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="openai"),
|
|
||||||
lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="laion400m_e31"),
|
|
||||||
lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="laion400m_e31"),
|
|
||||||
lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="laion400m_e32"),
|
|
||||||
lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="laion400m_e32"),
|
|
||||||
lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="laion2b-s32b-b82k"),
|
|
||||||
lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="laion2b-s32b-b82k"),
|
|
||||||
# lambda: OpenClipVisual("ViT-L-14-336", (1, 3, 336, 336), pretrained="openai"),
|
|
||||||
# lambda: OpenClipTextual("ViT-L-14-336", (1, 77), pretrained="openai"),
|
|
||||||
# lambda: ArcFace("buffalo_s", (1, 3, 112, 112), optimization_level=3),
|
|
||||||
# lambda: RetinaFace("buffalo_s", (1, 3, 640, 640), optimization_level=3),
|
|
||||||
# lambda: ArcFace("buffalo_m", (1, 3, 112, 112), optimization_level=3),
|
|
||||||
# lambda: RetinaFace("buffalo_m", (1, 3, 640, 640), optimization_level=3),
|
|
||||||
# lambda: ArcFace("buffalo_l", (1, 3, 112, 112), optimization_level=3),
|
|
||||||
# lambda: RetinaFace("buffalo_l", (1, 3, 640, 640), optimization_level=3),
|
|
||||||
# lambda: ArcFace("antelopev2", (1, 3, 112, 112), optimization_level=3),
|
|
||||||
# lambda: RetinaFace("antelopev2", (1, 3, 640, 640), optimization_level=3),
|
|
||||||
]
|
|
||||||
|
|
||||||
models: list[Callable[[], ExportBase]] = [*failed, *succeeded]
|
|
||||||
for _model in succeeded:
|
|
||||||
model = _model()
|
|
||||||
try:
|
|
||||||
model_dir = os.path.join("output", model.model_name)
|
|
||||||
output_dir = os.path.join(model_dir, model.task)
|
|
||||||
armnn_fp32, armnn_fp16 = model.to_armnn(output_dir)
|
|
||||||
relative_fp32 = os.path.relpath(armnn_fp32, start=model_dir)
|
|
||||||
relative_fp16 = os.path.relpath(armnn_fp16, start=model_dir)
|
|
||||||
if hf_token and os.path.isfile(armnn_fp32):
|
|
||||||
print(f"Uploading {model.model_name} ({model.task}) ARM NN model with fp32 precision")
|
|
||||||
upload_file(path_or_fileobj=armnn_fp32, path_in_repo=relative_fp32, repo_id=model.repo_name)
|
|
||||||
print(f"Finished uploading {model.model_name} ({model.task}) ARM NN model with fp32 precision")
|
|
||||||
if hf_token and os.path.isfile(armnn_fp16):
|
|
||||||
print(f"Uploading {model.model_name} ({model.task}) ARM NN model with fp16 precision")
|
|
||||||
upload_file(path_or_fileobj=armnn_fp16, path_in_repo=relative_fp16, repo_id=model.repo_name)
|
|
||||||
print(f"Finished uploading {model.model_name} ({model.task}) ARM NN model with fp16 precision")
|
|
||||||
except Exception as exc:
|
|
||||||
print(f"Failed to export {model.model_name} ({model.task}): {exc}")
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
Reference in New Issue
Block a user