1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2024-12-29 15:11:58 +00:00
This commit is contained in:
mertalev 2024-07-08 18:19:35 -04:00
parent 3d62011ae3
commit b39cca1b43
No known key found for this signature in database
GPG key ID: 9181CD92C0A1C5E3

View file

@ -12,6 +12,7 @@ from huggingface_hub import login, upload_file
import onnx2tf
import numpy as np
import onnxsim
from shutil import rmtree
# i can explain
# armnn only supports up to 4d tranposes, but the model has a 5d transpose due to a redundant unsqueeze
@ -167,9 +168,9 @@ def onnx_make_fixed(input_path: str, output_path: str, input_shape: tuple[int, .
simplified, success = onnxsim.simplify(input_path)
if not success:
raise RuntimeError(f"Failed to simplify {input_path}")
onnx.save(simplified, input_path)
infer_shapes_path(input_path, check_type=True, strict_mode=True, data_prop=True)
model = onnx.load_model(input_path)
onnx.save(simplified, output_path, save_as_external_data=True, all_tensors_to_one_file=False)
infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
model = onnx.load_model(output_path)
make_input_shape_fixed(model.graph, model.graph.input[0].name, input_shape)
fix_output_shapes(model)
onnx.save(model, output_path, save_as_external_data=True, all_tensors_to_one_file=False)
@ -218,20 +219,23 @@ class ExportBase:
def to_tflite(self, output_dir: str) -> tuple[str, str]:
input_path = self.to_onnx_static()
os.makedirs(output_dir, exist_ok=True)
tflite_fp32 = os.path.join(output_dir, "model_float32.tflite")
tflite_fp16 = os.path.join(output_dir, "model_float16.tflite")
if not os.path.isfile(tflite_fp32) or not os.path.isfile(tflite_fp16):
print(f"Exporting {self.model_name} ({self.task}) to TFLite")
print(f"Exporting {self.model_name} ({self.task}) to TFLite (this might take a few minutes)")
onnx2tf.convert(
input_onnx_file_path=input_path,
output_folder_path=output_dir,
keep_shape_absolutely_input_names=self.inputs,
verbosity="warn",
copy_onnx_input_output_names_to_tflite=True,
output_signaturedefs=True,
)
return tflite_fp32, tflite_fp16
def to_armnn(self, output_dir: str) -> tuple[str, str]:
output_dir = os.path.abspath(output_dir)
tflite_model_dir = os.path.join(output_dir, "tflite")
tflite_fp32, tflite_fp16 = self.to_tflite(tflite_model_dir)
@ -240,28 +244,38 @@ class ExportBase:
armnn_fp32 = os.path.join(output_dir, "model.armnn")
armnn_fp16 = os.path.join(fp16_dir, "model.armnn")
args = [
"./armnnconverter",
"-f",
"tflite-binary",
]
args = ["./armnnconverter", "-f", "tflite-binary"]
for input_ in self.inputs:
args.extend(["-i", input_])
for output_ in self.outputs:
args.extend(["-o", output_])
fp32_args = args.copy()
fp32_args.extend(["-m", tflite_fp32, "-p", tflite_fp32])
fp32_args.extend(["-m", tflite_fp32, "-p", armnn_fp32])
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp32 precision")
subprocess.run(fp32_args, capture_output=True)
try:
print(subprocess.check_output(fp32_args, stderr=subprocess.STDOUT).decode())
except subprocess.CalledProcessError as e:
print(e.output.decode())
try:
rmtree(tflite_model_dir, ignore_errors=True)
finally:
raise e
print(f"Finished exporting {self.name} ({self.task}) with fp32 precision")
fp16_args = args.copy()
fp32_args.extend(["-m", tflite_fp16, "-p", tflite_fp16])
fp16_args.extend(["-m", tflite_fp16, "-p", armnn_fp16])
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp16 precision")
subprocess.run(fp16_args, capture_output=True)
try:
print(subprocess.check_output(fp16_args, stderr=subprocess.STDOUT).decode())
except subprocess.CalledProcessError as e:
print(e.output.decode())
try:
rmtree(tflite_model_dir, ignore_errors=True)
finally:
raise e
print(f"Finished exporting {self.name} ({self.task}) with fp16 precision")
return armnn_fp32, armnn_fp16