onnx2tf, 4d transpose

Author: mertalev
Date: 2024-07-06 21:17:55 -04:00
Parent: 956480ab2c
Commit: 5dae920ac6
Diffstat: +140 -151
@@ -1,52 +1,63 @@
 import os
 import platform
 import subprocess
-from tempfile import TemporaryDirectory
 from typing import Callable, ClassVar
 import onnx
-import torch
-from onnx2torch import convert
-from onnx2torch.node_converters.registry import add_converter
+from onnx_graphsurgeon import import_onnx, export_onnx
 from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
-from tinynn.converter import TFLiteConverter
 from huggingface_hub import snapshot_download
-from onnx2torch.onnx_graph import OnnxGraph
-from onnx2torch.onnx_node import OnnxNode
-from onnx2torch.utils.common import OperationConverterResult, onnx_mapping_from_node
 from onnx.shape_inference import infer_shapes_path
 from huggingface_hub import login, upload_file
+import onnx2tf

-# egregious hacks:
-# changed `Clip`'s min/max logic to skip empty strings
-# changed OnnxSqueezeDynamicAxes to use `sorted` instead of `torch.sort`
-# commented out shape inference in `fix_output_shapes`
+# i can explain
+# armnn only supports up to 4d transposes, but the model has a 5d transpose due to a redundant unsqueeze
+# this function folds the unsqueeze+transpose+squeeze into a single 4d transpose
+def onnx_transpose_4d(model_path: str):
+    proto = onnx.load(model_path)
+    graph = import_onnx(proto)
+    for node in graph.nodes:
+        for i, link1 in enumerate(node.outputs):
+            if "Unsqueeze" in link1.name:
+                for node1 in link1.outputs:
+                    for link2 in node1.outputs:
+                        if "Transpose" in link2.name:
+                            for node2 in link2.outputs:
+                                if node2.attrs.get("perm") == [3, 1, 2, 0, 4]:
+                                    node2.attrs["perm"] = [2, 0, 1, 3]
+                                    link2.shape = link1.shape
+                                    for link3 in node2.outputs:
+                                        if "Squeeze" in link3.name:
+                                            for node3 in link3.outputs:
+                                                for link4 in node3.outputs:
+                                                    link4.shape = [link3.shape[x] for x in [0, 1, 2, 4]]
+                                                    for inputs in link4.inputs:
+                                                        if inputs.name == node3.name:
+                                                            i = link2.inputs.index(node1)
+                                                            if i >= 0:
+                                                                link2.inputs[i] = node
+                                                            i = link4.inputs.index(node3)
+                                                            if i >= 0:
+                                                                link4.inputs[i] = node2
+                                                            node.outputs = [link2]
+                                                            node1.inputs = []
+                                                            node1.outputs = []
+                                                            node3.inputs = []
+                                                            node3.outputs = []
+    graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
+    graph.toposort()
+    graph.fold_constants()
+    updated = export_onnx(graph)
+    onnx.save(updated, model_path, save_as_external_data=True, all_tensors_to_one_file=False)
+    infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)
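Note: the folded permutation can be checked directly with numpy. A minimal sketch, assuming the redundant Unsqueeze inserts its new axis at position 0 (the only placement consistent with the Squeeze removing output axis 3 and the folded perm coming out to [2, 0, 1, 3]):

    import numpy as np

    x = np.random.rand(2, 3, 4, 5)  # any 4d input; the shape here is illustrative

    # original subgraph: Unsqueeze(axes=[0]) -> Transpose(perm=[3, 1, 2, 0, 4]) -> Squeeze(axes=[3])
    via_5d = np.expand_dims(x, 0).transpose(3, 1, 2, 0, 4).squeeze(3)

    # replacement: a single 4d Transpose(perm=[2, 0, 1, 3])
    via_4d = x.transpose(2, 0, 1, 3)

    assert np.array_equal(via_5d, via_4d)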
-class ArgMax(torch.nn.Module):
-    def __init__(self, dim: int = -1, keepdim: bool = False):
-        super().__init__()
-        self.dim = dim
-        self.keepdim = keepdim
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return torch.argmax(input, dim=self.dim, keepdim=self.keepdim)
-
-
-class Erf(torch.nn.Module):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return torch.erf(input)
-
-
-@add_converter(operation_type="ArgMax", version=13)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
-    return OperationConverterResult(
-        torch_module=ArgMax(),
-        onnx_mapping=onnx_mapping_from_node(node=node),
-    )
-
-
-class ExportBase(torch.nn.Module):
+class ExportBase:
     task: ClassVar[str]

     def __init__(
@@ -62,84 +73,94 @@ class ExportBase(torch.nn.Module):
         self.nchw_transpose = False
         self.input_shape = input_shape
         self.pretrained = pretrained
-        self.dummy_param = torch.nn.Parameter(torch.empty(0))
-        self.model = self.load().eval()
-        for param in self.parameters():
-            param.requires_grad_(False)
-        self.eval()

-    def load(self) -> torch.nn.Module:
+    def to_onnx_static(self) -> str:
         cache_dir = os.path.join(os.environ["CACHE_DIR"], self.model_name)
         task_path = os.path.join(cache_dir, self.task)
         model_path = os.path.join(task_path, "model.onnx")
         if not os.path.isfile(model_path):
-            print(f"Downloading {self.model_name}...")
             snapshot_download(self.repo_name, cache_dir=cache_dir, local_dir=cache_dir)
-        infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)
-        onnx_model = onnx.load_model(model_path)
-        make_input_shape_fixed(onnx_model.graph, onnx_model.graph.input[0].name, self.input_shape)
-        fix_output_shapes(onnx_model)
-        # try:
-        #     onnx.save(onnx_model, model_path)
-        # except:
-        #     onnx.save(onnx_model, model_path, save_as_external_data=True, all_tensors_to_one_file=False)
-        #     infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)
-        #     onnx_model = onnx.load_model(model_path)
-        #     onnx_model = infer_shapes(onnx_model, check_type=True, strict_mode=True, data_prop=True)
-        return convert(onnx_model)
-
-    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor]:
-        if self.precision == "fp16":
-            inputs = tuple(i.half() for i in inputs)
-        out = self._forward(*inputs)
-        if self.precision == "fp16":
-            if isinstance(out, tuple):
-                return tuple(o.float() for o in out)
-            return out.float()
-        return out
+        static_dir = os.path.join(task_path, "static")
+        static_path = os.path.join(static_dir, "model.onnx")
+        os.makedirs(static_dir, exist_ok=True)
+        if not os.path.isfile(static_path):
+            print(f"Making {self.model_name} ({self.task}) static")
+            onnx_path_original = os.path.join(cache_dir, "model.onnx")
+            infer_shapes_path(onnx_path_original, check_type=True, strict_mode=True, data_prop=True)
+            static_model = onnx.load_model(onnx_path_original)
+            make_input_shape_fixed(static_model.graph, static_model.graph.input[0].name, (1, 3, 224, 224))
+            fix_output_shapes(static_model)
+            onnx.save(static_model, static_path, save_as_external_data=True, all_tensors_to_one_file=False)
+            infer_shapes_path(static_path, check_type=True, strict_mode=True, data_prop=True)
+            onnx_transpose_4d(static_path)
+        return static_path
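Note: the shape-fixing step in isolation, as a minimal sketch against a hypothetical model whose first input has a dynamic batch dimension. These are the same onnxruntime utilities the method uses; save_as_external_data=True keeps initializers outside the .onnx file, sidestepping protobuf's 2 GB single-file limit (the same ceiling the flatbuffers comments below hit for TFLite):

    import onnx
    from onnx.shape_inference import infer_shapes_path
    from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed

    model = onnx.load_model("model.onnx")  # hypothetical path
    # pin e.g. ("N", 3, 224, 224) down to (1, 3, 224, 224)
    make_input_shape_fixed(model.graph, model.graph.input[0].name, (1, 3, 224, 224))
    fix_output_shapes(model)  # re-derive output shapes from the now-fixed input
    onnx.save(model, "static.onnx", save_as_external_data=True, all_tensors_to_one_file=False)
    infer_shapes_path("static.onnx", check_type=True, strict_mode=True, data_prop=True)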
-    def _forward(self, *inputs: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor]:
-        return self.model(*inputs)
-
-    def to_armnn(self, output_path: str) -> None:
-        output_dir = os.path.dirname(output_path)
+    def to_tflite(self, output_dir: str) -> tuple[str, str]:
+        input_path = self.to_onnx_static()
         os.makedirs(output_dir, exist_ok=True)
-        self(*self.dummy_inputs)
-        print(f"Exporting {self.model_name} ({self.task}) with {self.precision} precision")
-        jit = torch.jit.trace(self, self.dummy_inputs).eval()
-        with TemporaryDirectory() as tmpdir:
-            tflite_model_path = os.path.join(tmpdir, "model.tflite")
-            converter = TFLiteConverter(
-                jit,
-                self.dummy_inputs,
-                tflite_model_path,
-                optimize=self.optimize,
-                nchw_transpose=self.nchw_transpose,
-            )
-            # segfaults on ARM, must run on x86_64 / AMD64
-            converter.convert()
+        tflite_fp32 = os.path.join(output_dir, "model_float32.tflite")
+        tflite_fp16 = os.path.join(output_dir, "model_float16.tflite")
+        if not os.path.isfile(tflite_fp32) or not os.path.isfile(tflite_fp16):
+            print(f"Exporting {self.model_name} ({self.task}) to TFLite")
+            onnx2tf.convert(
+                input_onnx_file_path=input_path,
+                output_folder_path=output_dir,
+                copy_onnx_input_output_names_to_tflite=True,
+            )
+        return tflite_fp32, tflite_fp16
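Note: onnx2tf writes several precision variants into output_folder_path; model_float32.tflite and model_float16.tflite are the two consumed here. A converted file can be smoke-tested with the TFLite interpreter — a sketch assuming tensorflow is installed and a single-input model (onnx2tf also transposes NCHW image inputs to NHWC by default, so the input details are the source of truth for the shape):

    import numpy as np
    import tensorflow as tf

    interp = tf.lite.Interpreter(model_path="model_float32.tflite")  # hypothetical path
    interp.allocate_tensors()
    (inp,) = interp.get_input_details()   # assumes a single input
    (out,) = interp.get_output_details()  # assumes a single output
    interp.set_tensor(inp["index"], np.random.rand(*inp["shape"]).astype(inp["dtype"]))
    interp.invoke()
    print(out["name"], interp.get_tensor(out["index"]).shape)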
+    def to_armnn(self, output_dir: str) -> tuple[str, str]:
+        tflite_model_dir = os.path.join(output_dir, "tflite")
+        tflite_fp32, tflite_fp16 = self.to_tflite(tflite_model_dir)
+        fp16_dir = os.path.join(output_dir, "fp16")
+        os.makedirs(fp16_dir, exist_ok=True)
+        armnn_fp32 = os.path.join(output_dir, "model.armnn")
+        armnn_fp16 = os.path.join(fp16_dir, "model.armnn")
+        print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp32 precision")
         subprocess.run(
             [
                 "./armnnconverter",
                 "-f",
                 "tflite-binary",
                 "-m",
-                tflite_model_path,
+                tflite_fp32,
                 "-i",
                 "input_tensor",
                 "-o",
                 "output_tensor",
                 "-p",
-                output_path,
+                armnn_fp32,
             ],
             capture_output=True,
         )
-        print(f"Finished exporting {self.name} ({self.task}) with {self.precision} precision")
-
-    @property
-    def dummy_inputs(self) -> tuple[torch.FloatTensor]:
-        return (torch.rand(self.input_shape, device=self.device, dtype=self.dtype),)
+        print(f"Finished exporting {self.name} ({self.task}) with fp32 precision")
+        print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp16 precision")
+        subprocess.run(
+            [
+                "./armnnconverter",
+                "-f",
+                "tflite-binary",
+                "-m",
+                tflite_fp16,
+                "-i",
+                "input_tensor",
+                "-o",
+                "output_tensor",
+                "-p",
+                armnn_fp16,
+            ],
+            capture_output=True,
+        )
+        print(f"Finished exporting {self.name} ({self.task}) with fp16 precision")
+        return armnn_fp32, armnn_fp16
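Note: both converter invocations run with capture_output=True, so a failed conversion exits silently and only shows up later as a missing .armnn file (which is why main below guards the uploads with os.path.isfile). A stricter variant might look like this sketch, where convert_to_armnn is a hypothetical helper, not part of the commit:

    import subprocess

    def convert_to_armnn(tflite_path: str, armnn_path: str) -> None:
        # same armnnconverter invocation as above, but failing loudly
        result = subprocess.run(
            ["./armnnconverter", "-f", "tflite-binary", "-m", tflite_path,
             "-i", "input_tensor", "-o", "output_tensor", "-p", armnn_path],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            raise RuntimeError(f"armnnconverter failed ({result.returncode}): {result.stderr}")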
     @property
     def model_name(self) -> str:
@@ -149,25 +170,6 @@ class ExportBase(torch.nn.Module):
     def repo_name(self) -> str:
         return f"immich-app/{self.model_name}"

-    @property
-    def device(self) -> torch.device:
-        return self.dummy_param.device
-
-    @property
-    def dtype(self) -> torch.dtype:
-        return self.dummy_param.dtype
-
-    @property
-    def precision(self) -> str:
-        match self.dtype:
-            case torch.float32:
-                return "fp32"
-            case torch.float16:
-                return "fp16"
-            case _:
-                raise ValueError(f"Unsupported dtype {self.dtype}")

 class ArcFace(ExportBase):
     task = "recognition"
@@ -183,28 +185,18 @@ class OpenClipVisual(ExportBase):
 class OpenClipTextual(ExportBase):
     task = "textual"

-    @property
-    def dummy_inputs(self) -> tuple[torch.LongTensor]:
-        return (torch.randint(0, 5000, self.input_shape, device=self.device, dtype=torch.int32),)

 class MClipTextual(ExportBase):
     task = "textual"

-    @property
-    def dummy_inputs(self) -> tuple[torch.LongTensor]:
-        return (
-            torch.randint(0, 5000, self.input_shape, device=self.device, dtype=torch.int32),
-            torch.randint(0, 1, self.input_shape, device=self.device, dtype=torch.int32),
-        )
 def main() -> None:
     if platform.machine() not in ("x86_64", "AMD64"):
         raise RuntimeError(f"Can only run on x86_64 / AMD64, not {platform.machine()}")
-    login(token=os.environ["HF_AUTH_TOKEN"])
+    upload_to_hf = "HF_AUTH_TOKEN" in os.environ
+    if upload_to_hf:
+        login(token=os.environ["HF_AUTH_TOKEN"])
     os.environ["LD_LIBRARY_PATH"] = "armnn"
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     failed: list[Callable[[], ExportBase]] = [
         lambda: OpenClipVisual("ViT-H-14-378-quickgelu", (1, 3, 378, 378), pretrained="dfn5b"),  # flatbuffers: cannot grow buffer beyond 2 gigabytes (will probably work with fp16)
         lambda: OpenClipVisual("ViT-H-14-quickgelu", (1, 3, 224, 224), pretrained="dfn5b"),  # flatbuffers: cannot grow buffer beyond 2 gigabytes (will probably work with fp16)
@@ -249,49 +241,46 @@ def main() -> None:
     ]
     succeeded: list[Callable[[], ExportBase]] = [
-        lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="openai"),
-        lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="openai"),
+        # lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="openai"),
+        # lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="openai"),
         lambda: OpenClipVisual("ViT-B-16", (1, 3, 224, 224), pretrained="openai"),
         lambda: OpenClipTextual("ViT-B-16", (1, 77), pretrained="openai"),
-        lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="openai"),
-        lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="openai"),
-        lambda: OpenClipVisual("ViT-L-14-336", (1, 3, 336, 336), pretrained="openai"),
-        lambda: OpenClipTextual("ViT-L-14-336", (1, 77), pretrained="openai"),
-        lambda: OpenClipVisual("RN50", (1, 3, 224, 224), pretrained="openai"),
-        lambda: OpenClipTextual("RN50", (1, 77), pretrained="openai"),
-        lambda: OpenClipTextual("ViT-H-14-quickgelu", (1, 77), pretrained="dfn5b"),
-        lambda: OpenClipTextual("ViT-H-14-378-quickgelu", (1, 77), pretrained="dfn5b"),
-        lambda: OpenClipVisual("XLM-Roberta-Large-Vit-L-14", (1, 3, 224, 224)),
-        lambda: OpenClipVisual("XLM-Roberta-Large-Vit-B-32", (1, 3, 224, 224)),
-        lambda: ArcFace("buffalo_s", (1, 3, 112, 112), optimization_level=3),
-        lambda: RetinaFace("buffalo_s", (1, 3, 640, 640), optimization_level=3),
-        lambda: ArcFace("buffalo_m", (1, 3, 112, 112), optimization_level=3),
-        lambda: RetinaFace("buffalo_m", (1, 3, 640, 640), optimization_level=3),
-        lambda: ArcFace("buffalo_l", (1, 3, 112, 112), optimization_level=3),
-        lambda: RetinaFace("buffalo_l", (1, 3, 640, 640), optimization_level=3),
-        lambda: ArcFace("antelopev2", (1, 3, 112, 112), optimization_level=3),
-        lambda: RetinaFace("antelopev2", (1, 3, 640, 640), optimization_level=3),
+        # lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="openai"),
+        # lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="openai"),
+        # lambda: OpenClipVisual("ViT-L-14-336", (1, 3, 336, 336), pretrained="openai"),
+        # lambda: OpenClipTextual("ViT-L-14-336", (1, 77), pretrained="openai"),
+        # lambda: OpenClipVisual("RN50", (1, 3, 224, 224), pretrained="openai"),
+        # lambda: OpenClipTextual("RN50", (1, 77), pretrained="openai"),
+        # lambda: OpenClipTextual("ViT-H-14-quickgelu", (1, 77), pretrained="dfn5b"),
+        # lambda: OpenClipTextual("ViT-H-14-378-quickgelu", (1, 77), pretrained="dfn5b"),
+        # lambda: OpenClipVisual("XLM-Roberta-Large-Vit-L-14", (1, 3, 224, 224)),
+        # lambda: OpenClipVisual("XLM-Roberta-Large-Vit-B-32", (1, 3, 224, 224)),
+        # lambda: ArcFace("buffalo_s", (1, 3, 112, 112), optimization_level=3),
+        # lambda: RetinaFace("buffalo_s", (1, 3, 640, 640), optimization_level=3),
+        # lambda: ArcFace("buffalo_m", (1, 3, 112, 112), optimization_level=3),
+        # lambda: RetinaFace("buffalo_m", (1, 3, 640, 640), optimization_level=3),
+        # lambda: ArcFace("buffalo_l", (1, 3, 112, 112), optimization_level=3),
+        # lambda: RetinaFace("buffalo_l", (1, 3, 640, 640), optimization_level=3),
+        # lambda: ArcFace("antelopev2", (1, 3, 112, 112), optimization_level=3),
+        # lambda: RetinaFace("antelopev2", (1, 3, 640, 640), optimization_level=3),
     ]
     models: list[Callable[[], ExportBase]] = [*failed, *succeeded]
     for _model in succeeded:
-        model = _model().to(device)
+        model = _model()
         try:
-            relative_path = os.path.join(model.task, "model.armnn")
-            output_path = os.path.join("output", model.model_name, relative_path)
-            model.to_armnn(output_path)
-            upload_file(path_or_fileobj=output_path, path_in_repo=relative_path, repo_id=model.repo_name)
-            if device == torch.device("cuda"):
-                model.half()
-                relative_path = os.path.join(model.task, "fp16", "model.armnn")
-                output_path = os.path.join("output", model.model_name, relative_path)
-                model.to_armnn(output_path)
-                upload_file(path_or_fileobj=output_path, path_in_repo=relative_path, repo_id=model.repo_name)
+            model_dir = os.path.join("output", model.model_name)
+            output_dir = os.path.join(model_dir, model.task)
+            armnn_fp32, armnn_fp16 = model.to_armnn(output_dir)
+            relative_fp32 = os.path.relpath(armnn_fp32, start=model_dir)
+            relative_fp16 = os.path.relpath(armnn_fp16, start=model_dir)
+            if upload_to_hf and os.path.isfile(armnn_fp32):
+                upload_file(path_or_fileobj=armnn_fp32, path_in_repo=relative_fp32, repo_id=model.repo_name)
+            if upload_to_hf and os.path.isfile(armnn_fp16):
+                upload_file(path_or_fileobj=armnn_fp16, path_in_repo=relative_fp16, repo_id=model.repo_name)
         except Exception as exc:
             print(f"Failed to export {model.model_name} ({model.task}): {exc}")


 if __name__ == "__main__":
-    with torch.no_grad():
-        main()
+    main()
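Note: a rough end-to-end parity check between the static ONNX model and the fp32 TFLite conversion — a sketch assuming onnxruntime and tensorflow are installed, with illustrative paths; the transpose accounts for onnx2tf's default NCHW-to-NHWC input conversion. For 4d outputs (e.g. RetinaFace feature maps) the output layout may differ as well, so the comparison is most direct for the 2d embedding outputs:

    import numpy as np
    import onnxruntime as ort
    import tensorflow as tf

    ONNX_PATH = "static/model.onnx"       # hypothetical: result of to_onnx_static
    TFLITE_PATH = "model_float32.tflite"  # hypothetical: result of to_tflite

    interp = tf.lite.Interpreter(model_path=TFLITE_PATH)
    interp.allocate_tensors()
    (inp,) = interp.get_input_details()
    (out,) = interp.get_output_details()
    x_tf = np.random.rand(*inp["shape"]).astype(np.float32)
    interp.set_tensor(inp["index"], x_tf)
    interp.invoke()
    tflite_out = interp.get_tensor(out["index"])

    sess = ort.InferenceSession(ONNX_PATH)
    x_onnx = x_tf.transpose(0, 3, 1, 2) if x_tf.ndim == 4 else x_tf  # NHWC -> NCHW for image inputs
    onnx_out = sess.run(None, {sess.get_inputs()[0].name: x_onnx})[0]

    print("max abs diff:", np.abs(onnx_out - tflite_out).max())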