mirror of
https://github.com/immich-app/immich.git
synced 2024-12-29 15:11:58 +00:00
fixes
This commit is contained in:
parent
3d62011ae3
commit
b39cca1b43
1 changed files with 28 additions and 14 deletions
|
@ -12,6 +12,7 @@ from huggingface_hub import login, upload_file
|
|||
import onnx2tf
|
||||
import numpy as np
|
||||
import onnxsim
|
||||
from shutil import rmtree
|
||||
|
||||
# i can explain
|
||||
# armnn only supports up to 4d tranposes, but the model has a 5d transpose due to a redundant unsqueeze
|
||||
|
@ -167,9 +168,9 @@ def onnx_make_fixed(input_path: str, output_path: str, input_shape: tuple[int, .
|
|||
simplified, success = onnxsim.simplify(input_path)
|
||||
if not success:
|
||||
raise RuntimeError(f"Failed to simplify {input_path}")
|
||||
onnx.save(simplified, input_path)
|
||||
infer_shapes_path(input_path, check_type=True, strict_mode=True, data_prop=True)
|
||||
model = onnx.load_model(input_path)
|
||||
onnx.save(simplified, output_path, save_as_external_data=True, all_tensors_to_one_file=False)
|
||||
infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
|
||||
model = onnx.load_model(output_path)
|
||||
make_input_shape_fixed(model.graph, model.graph.input[0].name, input_shape)
|
||||
fix_output_shapes(model)
|
||||
onnx.save(model, output_path, save_as_external_data=True, all_tensors_to_one_file=False)
|
||||
|
@ -218,20 +219,23 @@ class ExportBase:
|
|||
|
||||
def to_tflite(self, output_dir: str) -> tuple[str, str]:
|
||||
input_path = self.to_onnx_static()
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
tflite_fp32 = os.path.join(output_dir, "model_float32.tflite")
|
||||
tflite_fp16 = os.path.join(output_dir, "model_float16.tflite")
|
||||
if not os.path.isfile(tflite_fp32) or not os.path.isfile(tflite_fp16):
|
||||
print(f"Exporting {self.model_name} ({self.task}) to TFLite")
|
||||
print(f"Exporting {self.model_name} ({self.task}) to TFLite (this might take a few minutes)")
|
||||
onnx2tf.convert(
|
||||
input_onnx_file_path=input_path,
|
||||
output_folder_path=output_dir,
|
||||
keep_shape_absolutely_input_names=self.inputs,
|
||||
verbosity="warn",
|
||||
copy_onnx_input_output_names_to_tflite=True,
|
||||
output_signaturedefs=True,
|
||||
)
|
||||
|
||||
return tflite_fp32, tflite_fp16
|
||||
|
||||
def to_armnn(self, output_dir: str) -> tuple[str, str]:
|
||||
output_dir = os.path.abspath(output_dir)
|
||||
tflite_model_dir = os.path.join(output_dir, "tflite")
|
||||
tflite_fp32, tflite_fp16 = self.to_tflite(tflite_model_dir)
|
||||
|
||||
|
@ -240,28 +244,38 @@ class ExportBase:
|
|||
armnn_fp32 = os.path.join(output_dir, "model.armnn")
|
||||
armnn_fp16 = os.path.join(fp16_dir, "model.armnn")
|
||||
|
||||
args = [
|
||||
"./armnnconverter",
|
||||
"-f",
|
||||
"tflite-binary",
|
||||
]
|
||||
args = ["./armnnconverter", "-f", "tflite-binary"]
|
||||
for input_ in self.inputs:
|
||||
args.extend(["-i", input_])
|
||||
for output_ in self.outputs:
|
||||
args.extend(["-o", output_])
|
||||
|
||||
fp32_args = args.copy()
|
||||
fp32_args.extend(["-m", tflite_fp32, "-p", tflite_fp32])
|
||||
fp32_args.extend(["-m", tflite_fp32, "-p", armnn_fp32])
|
||||
|
||||
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp32 precision")
|
||||
subprocess.run(fp32_args, capture_output=True)
|
||||
try:
|
||||
print(subprocess.check_output(fp32_args, stderr=subprocess.STDOUT).decode())
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e.output.decode())
|
||||
try:
|
||||
rmtree(tflite_model_dir, ignore_errors=True)
|
||||
finally:
|
||||
raise e
|
||||
print(f"Finished exporting {self.name} ({self.task}) with fp32 precision")
|
||||
|
||||
fp16_args = args.copy()
|
||||
fp32_args.extend(["-m", tflite_fp16, "-p", tflite_fp16])
|
||||
fp16_args.extend(["-m", tflite_fp16, "-p", armnn_fp16])
|
||||
|
||||
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp16 precision")
|
||||
subprocess.run(fp16_args, capture_output=True)
|
||||
try:
|
||||
print(subprocess.check_output(fp16_args, stderr=subprocess.STDOUT).decode())
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e.output.decode())
|
||||
try:
|
||||
rmtree(tflite_model_dir, ignore_errors=True)
|
||||
finally:
|
||||
raise e
|
||||
print(f"Finished exporting {self.name} ({self.task}) with fp16 precision")
|
||||
|
||||
return armnn_fp32, armnn_fp16
|
||||
|
|
Loading…
Reference in a new issue