Files
immich/machine-learning/export/ann/onnx2ann/__main__.py
T
mertalev 72269ab58c add cli
2024-07-12 16:50:48 -04:00

100 lines
3.0 KiB
Python

import os
import platform
from typing import Annotated, Optional
import typer
from onnx2ann.export import Exporter, ModelType, Precision
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
@app.command()
def export(
model_name: Annotated[
str, typer.Argument(..., help="The name of the model to be exported as it exists in Hugging Face.")
],
model_type: Annotated[ModelType, typer.Option(..., "--type", "-t", help="The type of model to be exported.")],
input_shapes: Annotated[
list[str],
typer.Option(
...,
"--input-shape",
"-s",
help="The shape of an input tensor to the model, each dimension separated by commas. "
"Multiple shapes can be provided for multiple inputs.",
),
],
precision: Annotated[
Precision,
typer.Option(
...,
"--precision",
"-p",
help="The precision of the exported model. `float16` requires a GPU.",
),
] = Precision.FLOAT32,
cache_dir: Annotated[
str,
typer.Option(
...,
"--cache-dir",
"-c",
help="Directory where pre-export models will be stored.",
envvar="CACHE_DIR",
show_envvar=True,
),
] = "~/.cache/huggingface",
output_dir: Annotated[
str,
typer.Option(
...,
"--output-dir",
"-o",
help="Directory where exported models will be stored.",
),
] = "output",
auth_token: Annotated[
Optional[str],
typer.Option(
...,
"--auth-token",
"-t",
help="If uploading models to Hugging Face, the auth token of the user or organisation.",
envvar="HF_AUTH_TOKEN",
show_envvar=True,
),
] = None,
force_export: Annotated[
bool,
typer.Option(
...,
"--force-export",
"-f",
help="Export the model even if an exported model already exists in the output directory.",
),
] = False,
) -> None:
if platform.machine() not in ("x86_64", "AMD64"):
msg = f"Can only run on x86_64 / AMD64, not {platform.machine()}"
raise RuntimeError(msg)
os.environ.setdefault("LD_LIBRARY_PATH", "armnn")
parsed_input_shapes = [tuple(map(int, shape.split(","))) for shape in input_shapes]
model = Exporter(
model_name, model_type, input_shapes=parsed_input_shapes, cache_dir=cache_dir, force_export=force_export
)
model_dir = os.path.join("output", model_name)
output_dir = os.path.join(model_dir, model_type)
armnn_model = model.to_armnn(output_dir, precision)
if not auth_token:
return
from huggingface_hub import upload_file
relative_path = os.path.relpath(armnn_model, start=model_dir)
upload_file(path_or_fileobj=armnn_model, path_in_repo=relative_path, repo_id=model.repo_name, token=auth_token)
app()