add cli
This commit is contained in:
@@ -0,0 +1,99 @@
|
||||
import os
|
||||
import platform
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import typer
|
||||
|
||||
from onnx2ann.export import Exporter, ModelType, Precision
|
||||
|
||||
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
|
||||
|
||||
|
||||
@app.command()
|
||||
def export(
|
||||
model_name: Annotated[
|
||||
str, typer.Argument(..., help="The name of the model to be exported as it exists in Hugging Face.")
|
||||
],
|
||||
model_type: Annotated[ModelType, typer.Option(..., "--type", "-t", help="The type of model to be exported.")],
|
||||
input_shapes: Annotated[
|
||||
list[str],
|
||||
typer.Option(
|
||||
...,
|
||||
"--input-shape",
|
||||
"-s",
|
||||
help="The shape of an input tensor to the model, each dimension separated by commas. "
|
||||
"Multiple shapes can be provided for multiple inputs.",
|
||||
),
|
||||
],
|
||||
precision: Annotated[
|
||||
Precision,
|
||||
typer.Option(
|
||||
...,
|
||||
"--precision",
|
||||
"-p",
|
||||
help="The precision of the exported model. `float16` requires a GPU.",
|
||||
),
|
||||
] = Precision.FLOAT32,
|
||||
cache_dir: Annotated[
|
||||
str,
|
||||
typer.Option(
|
||||
...,
|
||||
"--cache-dir",
|
||||
"-c",
|
||||
help="Directory where pre-export models will be stored.",
|
||||
envvar="CACHE_DIR",
|
||||
show_envvar=True,
|
||||
),
|
||||
] = "~/.cache/huggingface",
|
||||
output_dir: Annotated[
|
||||
str,
|
||||
typer.Option(
|
||||
...,
|
||||
"--output-dir",
|
||||
"-o",
|
||||
help="Directory where exported models will be stored.",
|
||||
),
|
||||
] = "output",
|
||||
auth_token: Annotated[
|
||||
Optional[str],
|
||||
typer.Option(
|
||||
...,
|
||||
"--auth-token",
|
||||
"-t",
|
||||
help="If uploading models to Hugging Face, the auth token of the user or organisation.",
|
||||
envvar="HF_AUTH_TOKEN",
|
||||
show_envvar=True,
|
||||
),
|
||||
] = None,
|
||||
force_export: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
...,
|
||||
"--force-export",
|
||||
"-f",
|
||||
help="Export the model even if an exported model already exists in the output directory.",
|
||||
),
|
||||
] = False,
|
||||
) -> None:
|
||||
if platform.machine() not in ("x86_64", "AMD64"):
|
||||
msg = f"Can only run on x86_64 / AMD64, not {platform.machine()}"
|
||||
raise RuntimeError(msg)
|
||||
os.environ.setdefault("LD_LIBRARY_PATH", "armnn")
|
||||
parsed_input_shapes = [tuple(map(int, shape.split(","))) for shape in input_shapes]
|
||||
model = Exporter(
|
||||
model_name, model_type, input_shapes=parsed_input_shapes, cache_dir=cache_dir, force_export=force_export
|
||||
)
|
||||
model_dir = os.path.join("output", model_name)
|
||||
output_dir = os.path.join(model_dir, model_type)
|
||||
armnn_model = model.to_armnn(output_dir, precision)
|
||||
|
||||
if not auth_token:
|
||||
return
|
||||
|
||||
from huggingface_hub import upload_file
|
||||
|
||||
relative_path = os.path.relpath(armnn_model, start=model_dir)
|
||||
upload_file(path_or_fileobj=armnn_model, path_in_repo=relative_path, repo_id=model.repo_name, token=auth_token)
|
||||
|
||||
|
||||
app()
|
||||
@@ -0,0 +1,129 @@
|
||||
import os
|
||||
import subprocess
|
||||
from enum import StrEnum
|
||||
|
||||
from onnx2ann.helpers import onnx_make_armnn_compatible, onnx_make_inputs_fixed
|
||||
|
||||
|
||||
class ModelType(StrEnum):
|
||||
VISUAL = "visual"
|
||||
TEXTUAL = "textual"
|
||||
RECOGNITION = "recognition"
|
||||
DETECTION = "detection"
|
||||
|
||||
|
||||
class Precision(StrEnum):
|
||||
FLOAT16 = "float16"
|
||||
FLOAT32 = "float32"
|
||||
|
||||
|
||||
class Exporter:
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: str,
|
||||
input_shapes: list[tuple[int, ...]],
|
||||
optimization_level: int = 5,
|
||||
cache_dir: str = os.environ.get("CACHE_DIR", "~/.cache/huggingface"),
|
||||
force_export: bool = False,
|
||||
):
|
||||
self.model_name = model_name.split("/")[-1]
|
||||
self.model_type = model_type
|
||||
self.optimize = optimization_level
|
||||
self.input_shapes = input_shapes
|
||||
self.cache_dir = os.path.join(cache_dir, self.repo_name)
|
||||
self.force_export = force_export
|
||||
|
||||
def download(self) -> str:
|
||||
model_path = os.path.join(self.cache_dir, self.model_type, "model.onnx")
|
||||
if os.path.isfile(model_path):
|
||||
print(f"Model is already downloaded at {model_path}")
|
||||
return model_path
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
snapshot_download(
|
||||
self.repo_name, cache_dir=self.cache_dir, local_dir=self.cache_dir, local_dir_use_symlinks=False
|
||||
)
|
||||
return model_path
|
||||
|
||||
def to_onnx_static(self, precision: Precision) -> str:
|
||||
import onnx
|
||||
from onnxconverter_common import float16
|
||||
onnx_path_original = self.download()
|
||||
static_dir = os.path.join(self.cache_dir, self.model_type, "static")
|
||||
|
||||
static_path = os.path.join(static_dir, f"model.onnx")
|
||||
if self.force_export and not os.path.isfile(static_path):
|
||||
print(f"Making {self} static")
|
||||
os.makedirs(static_dir, exist_ok=True)
|
||||
onnx_make_inputs_fixed(onnx_path_original, static_path, self.input_shapes)
|
||||
onnx_make_armnn_compatible(static_path)
|
||||
print(f"Finished making {self} static")
|
||||
|
||||
model = onnx.load(static_path)
|
||||
self.inputs = [input_.name for input_ in model.graph.input]
|
||||
self.outputs = [output_.name for output_ in model.graph.output]
|
||||
if precision == Precision.FLOAT16:
|
||||
static_path = os.path.join(static_dir, f"model_{precision}.onnx")
|
||||
print(f"Converting {self} to {precision} precision")
|
||||
model = float16.convert_float_to_float16(model, keep_io_types=True, disable_shape_infer=True)
|
||||
onnx.save(model, static_path)
|
||||
print(f"Finished converting {self} to {precision} precision")
|
||||
# self.inputs, self.outputs = onnx_get_inputs_outputs(static_path)
|
||||
return static_path
|
||||
|
||||
def to_tflite(self, output_dir: str, precision: Precision) -> str:
|
||||
onnx_model = self.to_onnx_static(precision)
|
||||
tflite_dir = os.path.join(output_dir, precision)
|
||||
tflite_model = os.path.join(tflite_dir, f"model_{precision}.tflite")
|
||||
if self.force_export or not os.path.isfile(tflite_model):
|
||||
import onnx2tf
|
||||
|
||||
print(f"Exporting {self} to TFLite with {precision} precision (this might take a few minutes)")
|
||||
onnx2tf.convert(
|
||||
input_onnx_file_path=onnx_model,
|
||||
output_folder_path=tflite_dir,
|
||||
keep_shape_absolutely_input_names=self.inputs,
|
||||
# verbosity="warn",
|
||||
copy_onnx_input_output_names_to_tflite=True,
|
||||
output_signaturedefs=True,
|
||||
not_use_onnxsim=True,
|
||||
)
|
||||
print(f"Finished exporting {self} to TFLite with {precision} precision")
|
||||
|
||||
return tflite_model
|
||||
|
||||
def to_armnn(self, output_dir: str, precision: Precision) -> tuple[str, str]:
|
||||
armnn_model = os.path.join(output_dir, "model.armnn")
|
||||
if not self.force_export and os.path.isfile(armnn_model):
|
||||
return armnn_model
|
||||
|
||||
tflite_model_dir = os.path.join(output_dir, "tflite")
|
||||
tflite_model = self.to_tflite(tflite_model_dir, precision)
|
||||
|
||||
args = ["./armnnconverter", "-f", "tflite-binary", "-m", tflite_model, "-p", armnn_model]
|
||||
args.append("-i")
|
||||
args.extend(self.inputs)
|
||||
args.append("-o")
|
||||
args.extend(self.outputs)
|
||||
|
||||
print(f"Exporting {self} to ARM NN with {precision} precision")
|
||||
try:
|
||||
if (stdout := subprocess.check_output(args, stderr=subprocess.STDOUT).decode()):
|
||||
print(stdout)
|
||||
print(f"Finished exporting {self} to ARM NN with {precision} precision")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e.output.decode())
|
||||
try:
|
||||
from shutil import rmtree
|
||||
|
||||
rmtree(tflite_model_dir, ignore_errors=True)
|
||||
finally:
|
||||
raise e
|
||||
|
||||
@property
|
||||
def repo_name(self) -> str:
|
||||
return f"immich-app/{self.model_name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.model_name} ({self.model_type})"
|
||||
@@ -0,0 +1,260 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
def onnx_make_armnn_compatible(model_path: str) -> None:
|
||||
"""
|
||||
i can explain
|
||||
armnn only supports up to 4d tranposes, but the model has a 5d transpose due to a redundant unsqueeze
|
||||
this function folds the unsqueeze+transpose+squeeze into a single 4d transpose
|
||||
it also switches from gather ops to slices since armnn has different dimension semantics for gathers
|
||||
also fixes batch normalization being in training mode
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
from onnx_graphsurgeon import Constant, Node, Variable, export_onnx, import_onnx
|
||||
|
||||
proto = onnx.load(model_path)
|
||||
graph = import_onnx(proto)
|
||||
|
||||
gather_idx = 1
|
||||
squeeze_idx = 1
|
||||
for node in graph.nodes:
|
||||
for link1 in node.outputs:
|
||||
if "Unsqueeze" in link1.name:
|
||||
for node1 in link1.outputs:
|
||||
for link2 in node1.outputs:
|
||||
if "Transpose" in link2.name:
|
||||
for node2 in link2.outputs:
|
||||
if node2.attrs.get("perm") == [3, 1, 2, 0, 4]:
|
||||
node2.attrs["perm"] = [2, 0, 1, 3]
|
||||
link2.shape = link1.shape
|
||||
for link3 in node2.outputs:
|
||||
if "Squeeze" in link3.name:
|
||||
link3.shape = [link3.shape[x] for x in [0, 1, 2, 4]]
|
||||
for node3 in link3.outputs:
|
||||
for link4 in node3.outputs:
|
||||
link4.shape = link3.shape
|
||||
try:
|
||||
idx = link2.inputs.index(node1)
|
||||
link2.inputs[idx] = node
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
node.outputs = [link2]
|
||||
if "Gather" in link4.name:
|
||||
for node4 in link4.outputs:
|
||||
axis = node1.attrs.get("axis", 0)
|
||||
index = node4.inputs[1].values
|
||||
slice_link = Variable(
|
||||
f"onnx::Slice_123{gather_idx}",
|
||||
dtype=link4.dtype,
|
||||
shape=[1] + link3.shape[1:],
|
||||
)
|
||||
slice_node = Node(
|
||||
op="Slice",
|
||||
inputs=[
|
||||
link3,
|
||||
Constant(
|
||||
f"SliceStart_123{gather_idx}",
|
||||
np.array([index]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceEnd_123{gather_idx}",
|
||||
np.array([index + 1]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceAxis_123{gather_idx}",
|
||||
np.array([axis]),
|
||||
),
|
||||
],
|
||||
outputs=[slice_link],
|
||||
name=f"Slice_123{gather_idx}",
|
||||
)
|
||||
graph.nodes.append(slice_node)
|
||||
gather_idx += 1
|
||||
|
||||
for link5 in node4.outputs:
|
||||
for node5 in link5.outputs:
|
||||
try:
|
||||
idx = node5.inputs.index(link5)
|
||||
node5.inputs[idx] = slice_link
|
||||
except ValueError:
|
||||
pass
|
||||
elif node.op == "LayerNormalization":
|
||||
for node1 in link1.outputs:
|
||||
if node1.op == "Gather":
|
||||
for link2 in node1.outputs:
|
||||
for node2 in link2.outputs:
|
||||
axis = node1.attrs.get("axis", 0)
|
||||
index = node1.inputs[1].values
|
||||
slice_link = Variable(
|
||||
f"onnx::Slice_123{gather_idx}",
|
||||
dtype=link2.dtype,
|
||||
shape=[1, *link2.shape],
|
||||
)
|
||||
slice_node = Node(
|
||||
op="Slice",
|
||||
inputs=[
|
||||
node1.inputs[0],
|
||||
Constant(
|
||||
f"SliceStart_123{gather_idx}",
|
||||
np.array([index]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceEnd_123{gather_idx}",
|
||||
np.array([index + 1]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceAxis_123{gather_idx}",
|
||||
np.array([axis]),
|
||||
),
|
||||
],
|
||||
outputs=[slice_link],
|
||||
name=f"Slice_123{gather_idx}",
|
||||
)
|
||||
graph.nodes.append(slice_node)
|
||||
gather_idx += 1
|
||||
|
||||
squeeze_link = Variable(
|
||||
f"onnx::Squeeze_123{squeeze_idx}",
|
||||
dtype=link2.dtype,
|
||||
shape=link2.shape,
|
||||
)
|
||||
squeeze_node = Node(
|
||||
op="Squeeze",
|
||||
inputs=[
|
||||
slice_link,
|
||||
Constant(
|
||||
f"SqueezeAxis_123{squeeze_idx}",
|
||||
np.array([0]),
|
||||
),
|
||||
],
|
||||
outputs=[squeeze_link],
|
||||
name=f"Squeeze_123{squeeze_idx}",
|
||||
)
|
||||
graph.nodes.append(squeeze_node)
|
||||
squeeze_idx += 1
|
||||
try:
|
||||
idx = node2.inputs.index(link2)
|
||||
node2.inputs[idx] = squeeze_link
|
||||
except ValueError:
|
||||
pass
|
||||
elif node.op == "Reshape":
|
||||
for node1 in link1.outputs:
|
||||
if node1.op == "Gather":
|
||||
node2s = [n for link in node1.outputs for n in link.outputs]
|
||||
if any(n.op == "Abs" for n in node2s):
|
||||
axis = node1.attrs.get("axis", 0)
|
||||
index = node1.inputs[1].values
|
||||
slice_link = Variable(
|
||||
f"onnx::Slice_123{gather_idx}",
|
||||
dtype=node1.outputs[0].dtype,
|
||||
shape=[1, *node1.outputs[0].shape],
|
||||
)
|
||||
slice_node = Node(
|
||||
op="Slice",
|
||||
inputs=[
|
||||
node1.inputs[0],
|
||||
Constant(
|
||||
f"SliceStart_123{gather_idx}",
|
||||
np.array([index]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceEnd_123{gather_idx}",
|
||||
np.array([index + 1]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceAxis_123{gather_idx}",
|
||||
np.array([axis]),
|
||||
),
|
||||
],
|
||||
outputs=[slice_link],
|
||||
name=f"Slice_123{gather_idx}",
|
||||
)
|
||||
graph.nodes.append(slice_node)
|
||||
gather_idx += 1
|
||||
|
||||
squeeze_link = Variable(
|
||||
f"onnx::Squeeze_123{squeeze_idx}",
|
||||
dtype=node1.outputs[0].dtype,
|
||||
shape=node1.outputs[0].shape,
|
||||
)
|
||||
squeeze_node = Node(
|
||||
op="Squeeze",
|
||||
inputs=[
|
||||
slice_link,
|
||||
Constant(
|
||||
f"SqueezeAxis_123{squeeze_idx}",
|
||||
np.array([0]),
|
||||
),
|
||||
],
|
||||
outputs=[squeeze_link],
|
||||
name=f"Squeeze_123{squeeze_idx}",
|
||||
)
|
||||
graph.nodes.append(squeeze_node)
|
||||
squeeze_idx += 1
|
||||
for node2 in node2s:
|
||||
node2.inputs[0] = squeeze_link
|
||||
elif node.op == "BatchNormalization" and node.attrs.get("training_mode") == 1:
|
||||
node.attrs["training_mode"] = 0
|
||||
node.outputs = node.outputs[:1]
|
||||
|
||||
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
|
||||
graph.toposort()
|
||||
graph.fold_constants()
|
||||
updated = export_onnx(graph)
|
||||
onnx_save(updated, model_path)
|
||||
|
||||
# for some reason, reloading the model is necessary to apply the correct shape
|
||||
proto = onnx.load(model_path)
|
||||
graph = import_onnx(proto)
|
||||
for node in graph.nodes:
|
||||
if node.op == "Slice":
|
||||
for link in node.outputs:
|
||||
if "Slice_123" in link.name and link.shape[0] == 3: # noqa: PLR2004
|
||||
link.shape[0] = 1
|
||||
|
||||
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
|
||||
graph.toposort()
|
||||
graph.fold_constants()
|
||||
updated = export_onnx(graph)
|
||||
onnx_save(updated, model_path)
|
||||
onnx.shape_inference.infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)
|
||||
|
||||
|
||||
def onnx_make_inputs_fixed(input_path: str, output_path: str, input_shapes: list[tuple[int, ...]]) -> None:
|
||||
import onnx
|
||||
import onnxsim
|
||||
from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
|
||||
|
||||
model, success = onnxsim.simplify(input_path)
|
||||
if not success:
|
||||
msg = f"Failed to simplify {input_path}"
|
||||
raise RuntimeError(msg)
|
||||
onnx_save(model, output_path)
|
||||
onnx.shape_inference.infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
|
||||
model = onnx.load_model(output_path)
|
||||
for input_node, shape in zip(model.graph.input, input_shapes, strict=False):
|
||||
make_input_shape_fixed(model.graph, input_node.name, shape)
|
||||
fix_output_shapes(model)
|
||||
onnx_save(model, output_path)
|
||||
onnx.shape_inference.infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
|
||||
|
||||
|
||||
def onnx_get_inputs_outputs(model_path: str) -> tuple[list[str], list[str]]:
|
||||
import onnx
|
||||
|
||||
model = onnx.load(model_path)
|
||||
inputs = [input_.name for input_ in model.graph.input]
|
||||
outputs = [output_.name for output_ in model.graph.output]
|
||||
return inputs, outputs
|
||||
|
||||
|
||||
def onnx_save(model: Any, output_path: str) -> None:
|
||||
import onnx
|
||||
|
||||
try:
|
||||
onnx.save(model, output_path)
|
||||
except:
|
||||
onnx.save(model, output_path, save_as_external_data=True, all_tensors_to_one_file=False, size_threshold=1_000_000)
|
||||
Reference in New Issue
Block a user