From b39cca1b43cc73fff21d7e26f469923adb95fc28 Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Mon, 8 Jul 2024 18:19:35 -0400 Subject: [PATCH] fixes --- machine-learning/export/ann/run.py | 42 ++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/machine-learning/export/ann/run.py b/machine-learning/export/ann/run.py index 7cbe8de3ae..0912f18338 100644 --- a/machine-learning/export/ann/run.py +++ b/machine-learning/export/ann/run.py @@ -12,6 +12,7 @@ from huggingface_hub import login, upload_file import onnx2tf import numpy as np import onnxsim +from shutil import rmtree # i can explain # armnn only supports up to 4d tranposes, but the model has a 5d transpose due to a redundant unsqueeze @@ -167,9 +168,9 @@ def onnx_make_fixed(input_path: str, output_path: str, input_shape: tuple[int, . simplified, success = onnxsim.simplify(input_path) if not success: raise RuntimeError(f"Failed to simplify {input_path}") - onnx.save(simplified, input_path) - infer_shapes_path(input_path, check_type=True, strict_mode=True, data_prop=True) - model = onnx.load_model(input_path) + onnx.save(simplified, output_path, save_as_external_data=True, all_tensors_to_one_file=False) + infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True) + model = onnx.load_model(output_path) make_input_shape_fixed(model.graph, model.graph.input[0].name, input_shape) fix_output_shapes(model) onnx.save(model, output_path, save_as_external_data=True, all_tensors_to_one_file=False) @@ -218,20 +219,23 @@ class ExportBase: def to_tflite(self, output_dir: str) -> tuple[str, str]: input_path = self.to_onnx_static() - os.makedirs(output_dir, exist_ok=True) tflite_fp32 = os.path.join(output_dir, "model_float32.tflite") tflite_fp16 = os.path.join(output_dir, "model_float16.tflite") if not os.path.isfile(tflite_fp32) or not os.path.isfile(tflite_fp16): - print(f"Exporting {self.model_name} ({self.task}) to TFLite") + print(f"Exporting {self.model_name} ({self.task}) to TFLite (this might take a few minutes)") onnx2tf.convert( input_onnx_file_path=input_path, output_folder_path=output_dir, + keep_shape_absolutely_input_names=self.inputs, + verbosity="warn", copy_onnx_input_output_names_to_tflite=True, + output_signaturedefs=True, ) return tflite_fp32, tflite_fp16 def to_armnn(self, output_dir: str) -> tuple[str, str]: + output_dir = os.path.abspath(output_dir) tflite_model_dir = os.path.join(output_dir, "tflite") tflite_fp32, tflite_fp16 = self.to_tflite(tflite_model_dir) @@ -240,28 +244,38 @@ class ExportBase: armnn_fp32 = os.path.join(output_dir, "model.armnn") armnn_fp16 = os.path.join(fp16_dir, "model.armnn") - args = [ - "./armnnconverter", - "-f", - "tflite-binary", - ] + args = ["./armnnconverter", "-f", "tflite-binary"] for input_ in self.inputs: args.extend(["-i", input_]) for output_ in self.outputs: args.extend(["-o", output_]) fp32_args = args.copy() - fp32_args.extend(["-m", tflite_fp32, "-p", tflite_fp32]) + fp32_args.extend(["-m", tflite_fp32, "-p", armnn_fp32]) print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp32 precision") - subprocess.run(fp32_args, capture_output=True) + try: + print(subprocess.check_output(fp32_args, stderr=subprocess.STDOUT).decode()) + except subprocess.CalledProcessError as e: + print(e.output.decode()) + try: + rmtree(tflite_model_dir, ignore_errors=True) + finally: + raise e print(f"Finished exporting {self.name} ({self.task}) with fp32 precision") fp16_args = args.copy() - fp32_args.extend(["-m", tflite_fp16, "-p", tflite_fp16]) + fp16_args.extend(["-m", tflite_fp16, "-p", armnn_fp16]) print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp16 precision") - subprocess.run(fp16_args, capture_output=True) + try: + print(subprocess.check_output(fp16_args, stderr=subprocess.STDOUT).decode()) + except subprocess.CalledProcessError as e: + print(e.output.decode()) + try: + rmtree(tflite_model_dir, ignore_errors=True) + finally: + raise e print(f"Finished exporting {self.name} ({self.task}) with fp16 precision") return armnn_fp32, armnn_fp16