From 1ad348c40710a53c0c576caf95f7fcd923fceeed Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Sun, 7 Jul 2024 18:29:18 -0400 Subject: [PATCH] gather -> slice --- machine-learning/export/ann/run.py | 198 +++++++++++++++++++---------- 1 file changed, 133 insertions(+), 65 deletions(-) diff --git a/machine-learning/export/ann/run.py b/machine-learning/export/ann/run.py index 1475f1d670..e7a15c10fe 100644 --- a/machine-learning/export/ann/run.py +++ b/machine-learning/export/ann/run.py @@ -4,22 +4,27 @@ import subprocess from typing import Callable, ClassVar import onnx -from onnx_graphsurgeon import import_onnx, export_onnx +from onnx_graphsurgeon import Constant, Node, Variable, import_onnx, export_onnx from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed from huggingface_hub import snapshot_download from onnx.shape_inference import infer_shapes_path from huggingface_hub import login, upload_file import onnx2tf +from itertools import chain +import numpy as np +import onnxsim # i can explain # armnn only supports up to 4d tranposes, but the model has a 5d transpose due to a redundant unsqueeze # this function folds the unsqueeze+transpose+squeeze into a single 4d transpose +# it also switches from gather ops to slices since armnn doesn't support 3d gather def onnx_transpose_4d(model_path: str): proto = onnx.load(model_path) graph = import_onnx(proto) + gather_idx = 1 for node in graph.nodes: - for i, link1 in enumerate(node.outputs): + for link1 in node.outputs: if "Unsqueeze" in link1.name: for node1 in link1.outputs: for link2 in node1.outputs: @@ -30,31 +35,87 @@ def onnx_transpose_4d(model_path: str): link2.shape = link1.shape for link3 in node2.outputs: if "Squeeze" in link3.name: + link3.shape = [link3.shape[x] for x in [0, 1, 2, 4]] for node3 in link3.outputs: for link4 in node3.outputs: - link4.shape = [link3.shape[x] for x in [0, 1, 2, 4]] - for inputs in link4.inputs: - if inputs.name == node3.name: - i = link2.inputs.index(node1) - if i >= 0: - link2.inputs[i] = node - - i = link4.inputs.index(node3) - if i >= 0: - link4.inputs[i] = node2 - - node.outputs = [link2] - node1.inputs = [] - node1.outputs = [] - node3.inputs = [] - node3.outputs = [] - + link4.shape = link3.shape + try: + idx = link2.inputs.index(node1) + link2.inputs[idx] = node + except ValueError: + pass + + node.outputs = [link2] + if "Gather" in link4.name: + for node4 in link4.outputs: + index = node4.inputs[1].values + slice_link = Variable( + f"onnx::Slice_123{gather_idx}", + dtype=link4.dtype, + shape=[1] + link3.shape[1:], + ) + slice_node = Node( + op="Slice", + inputs=[ + link3, + Constant( + f"SliceStart_123{gather_idx}", + np.array([index, 0, 0, 0]), + ), + Constant( + f"SliceEnd_123{gather_idx}", + np.array([index + 1] + link3.shape[1:]), + ), + ], + outputs=[slice_link], + name=f"Slice_123{gather_idx}", + ) + graph.nodes.append(slice_node) + gather_idx += 1 + + for link5 in node4.outputs: + for node5 in link5.outputs: + try: + idx = node5.inputs.index(link5) + node5.inputs[idx] = slice_link + except ValueError: + pass + graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True) graph.toposort() graph.fold_constants() updated = export_onnx(graph) - onnx.save(updated, model_path, save_as_external_data=True, all_tensors_to_one_file=False) + onnx.save(updated, model_path) + # infer_shapes_path(updated, check_type=True, strict_mode=False, data_prop=True) + + # for some reason, reloading the model is necessary to apply the correct shape + proto = onnx.load(model_path) + graph = import_onnx(proto) + for node in graph.nodes: + if node.op == "Slice": + for link in node.outputs: + if "Slice_123" in link.name and link.shape[0] == 3: + link.shape[0] = 1 + + graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True) + graph.toposort() + graph.fold_constants() + updated = export_onnx(graph) + onnx.save(updated, model_path) infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True) + + +def onnx_make_fixed(input_path: str, output_path: str, input_shape: tuple[int, ...]): + simplified, success = onnxsim.simplify(input_path) + if not success: + raise RuntimeError(f"Failed to simplify {input_path}") + onnx.save(simplified, input_path) + infer_shapes_path(input_path, check_type=True, strict_mode=True, data_prop=True) + model = onnx.load_model(input_path) + make_input_shape_fixed(model.graph, model.graph.input[0].name, input_shape) + fix_output_shapes(model) + onnx.save(model, output_path, save_as_external_data=True, all_tensors_to_one_file=False) + infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True) class ExportBase: @@ -73,29 +134,27 @@ class ExportBase: self.nchw_transpose = False self.input_shape = input_shape self.pretrained = pretrained - - def to_onnx_static(self) -> str: - cache_dir = os.path.join(os.environ["CACHE_DIR"], self.model_name) - task_path = os.path.join(cache_dir, self.task) - model_path = os.path.join(task_path, "model.onnx") + self.cache_dir = os.path.join(os.environ["CACHE_DIR"], self.model_name) + + def download(self) -> str: + model_path = os.path.join(self.cache_dir, self.task, "model.onnx") if not os.path.isfile(model_path): print(f"Downloading {self.model_name}...") - snapshot_download(self.repo_name, cache_dir=cache_dir, local_dir=cache_dir) - - static_dir = os.path.join(task_path, "static") - static_path = os.path.join(static_dir, "model.onnx") + snapshot_download(self.repo_name, cache_dir=self.cache_dir, local_dir=self.cache_dir, local_dir_use_symlinks=False) + return model_path + + def to_onnx_static(self) -> str: + onnx_path_original = self.download() + static_dir = os.path.join(self.cache_dir, self.task, "static") os.makedirs(static_dir, exist_ok=True) - if not os.path.isfile(static_path): - print(f"Making {self.model_name} ({self.task}) static") - infer_shapes_path(onnx_path_original, check_type=True, strict_mode=True, data_prop=True) - onnx_path_original = os.path.join(cache_dir, "model.onnx") - static_model = onnx.load_model(onnx_path_original) - make_input_shape_fixed(static_model.graph, static_model.graph.input[0].name, (1, 3, 224, 224)) - fix_output_shapes(static_model) - onnx.save(static_model, static_path, save_as_external_data=True, all_tensors_to_one_file=False) - infer_shapes_path(static_path, check_type=True, strict_mode=True, data_prop=True) - onnx_transpose_4d(static_path) + static_path = os.path.join(static_dir, "model.onnx") + print(f"Making {self.model_name} ({self.task}) static") + onnx_make_fixed(onnx_path_original, static_path, self.input_shape) + onnx_transpose_4d(static_path) + static_model = onnx.load_model(static_path) + self.inputs = [input_.name for input_ in static_model.graph.input] + self.outputs = [output_.name for output_ in static_model.graph.output] return static_path def to_tflite(self, output_dir: str) -> tuple[str, str]: @@ -122,40 +181,48 @@ class ExportBase: armnn_fp32 = os.path.join(output_dir, "model.armnn") armnn_fp16 = os.path.join(fp16_dir, "model.armnn") + input_tensors = list(chain.from_iterable(("-i", input_) for input_ in self.inputs)), + output_tensors = list(chain.from_iterable(("-o", output_) for output_ in self.outputs)), + print(f"{input_tensors=}") + print(f"{output_tensors=}") + args = [ + "./armnnconverter", + "-f", + "tflite-binary", + "-m", + tflite_fp32, + "-p", + armnn_fp32, + ] + for input_ in self.inputs: + args.extend(["-i", input_]) + for output_ in self.outputs: + args.extend(["-o", output_]) + print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp32 precision") subprocess.run( - [ - "./armnnconverter", - "-f", - "tflite-binary", - "-m", - tflite_fp32, - "-i", - "input_tensor", - "-o", - "output_tensor", - "-p", - armnn_fp32, - ], + args, capture_output=True, ) print(f"Finished exporting {self.name} ({self.task}) with fp32 precision") + args = [ + "./armnnconverter", + "-f", + "tflite-binary", + "-m", + tflite_fp16, + "-p", + armnn_fp16, + ] + for input_ in self.inputs: + args.extend(["-i", input_]) + for output_ in self.outputs: + args.extend(["-o", output_]) + print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp16 precision") subprocess.run( - [ - "./armnnconverter", - "-f", - "tflite-binary", - "-m", - tflite_fp16, - "-i", - "input_tensor", - "-o", - "output_tensor", - "-p", - armnn_fp16, - ], + args, capture_output=True, ) print(f"Finished exporting {self.name} ({self.task}) with fp16 precision") @@ -280,6 +347,7 @@ def main() -> None: upload_file(path_or_fileobj=armnn_fp16, path_in_repo=relative_fp16, repo_id=model.repo_name) except Exception as exc: print(f"Failed to export {model.model_name} ({model.task}): {exc}") + raise exc if __name__ == "__main__":