mirror of
https://github.com/immich-app/immich.git
synced 2025-01-04 02:46:47 +01:00
gather -> slice
This commit is contained in:
parent
5dae920ac6
commit
1ad348c407
1 changed files with 133 additions and 65 deletions
|
@ -4,22 +4,27 @@ import subprocess
|
||||||
from typing import Callable, ClassVar
|
from typing import Callable, ClassVar
|
||||||
|
|
||||||
import onnx
|
import onnx
|
||||||
from onnx_graphsurgeon import import_onnx, export_onnx
|
from onnx_graphsurgeon import Constant, Node, Variable, import_onnx, export_onnx
|
||||||
from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
|
from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from onnx.shape_inference import infer_shapes_path
|
from onnx.shape_inference import infer_shapes_path
|
||||||
from huggingface_hub import login, upload_file
|
from huggingface_hub import login, upload_file
|
||||||
import onnx2tf
|
import onnx2tf
|
||||||
|
from itertools import chain
|
||||||
|
import numpy as np
|
||||||
|
import onnxsim
|
||||||
|
|
||||||
# i can explain
|
# i can explain
|
||||||
# armnn only supports up to 4d tranposes, but the model has a 5d transpose due to a redundant unsqueeze
|
# armnn only supports up to 4d tranposes, but the model has a 5d transpose due to a redundant unsqueeze
|
||||||
# this function folds the unsqueeze+transpose+squeeze into a single 4d transpose
|
# this function folds the unsqueeze+transpose+squeeze into a single 4d transpose
|
||||||
|
# it also switches from gather ops to slices since armnn doesn't support 3d gather
|
||||||
def onnx_transpose_4d(model_path: str):
|
def onnx_transpose_4d(model_path: str):
|
||||||
proto = onnx.load(model_path)
|
proto = onnx.load(model_path)
|
||||||
graph = import_onnx(proto)
|
graph = import_onnx(proto)
|
||||||
|
|
||||||
|
gather_idx = 1
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
for i, link1 in enumerate(node.outputs):
|
for link1 in node.outputs:
|
||||||
if "Unsqueeze" in link1.name:
|
if "Unsqueeze" in link1.name:
|
||||||
for node1 in link1.outputs:
|
for node1 in link1.outputs:
|
||||||
for link2 in node1.outputs:
|
for link2 in node1.outputs:
|
||||||
|
@ -30,33 +35,89 @@ def onnx_transpose_4d(model_path: str):
|
||||||
link2.shape = link1.shape
|
link2.shape = link1.shape
|
||||||
for link3 in node2.outputs:
|
for link3 in node2.outputs:
|
||||||
if "Squeeze" in link3.name:
|
if "Squeeze" in link3.name:
|
||||||
|
link3.shape = [link3.shape[x] for x in [0, 1, 2, 4]]
|
||||||
for node3 in link3.outputs:
|
for node3 in link3.outputs:
|
||||||
for link4 in node3.outputs:
|
for link4 in node3.outputs:
|
||||||
link4.shape = [link3.shape[x] for x in [0, 1, 2, 4]]
|
link4.shape = link3.shape
|
||||||
for inputs in link4.inputs:
|
try:
|
||||||
if inputs.name == node3.name:
|
idx = link2.inputs.index(node1)
|
||||||
i = link2.inputs.index(node1)
|
link2.inputs[idx] = node
|
||||||
if i >= 0:
|
except ValueError:
|
||||||
link2.inputs[i] = node
|
pass
|
||||||
|
|
||||||
i = link4.inputs.index(node3)
|
node.outputs = [link2]
|
||||||
if i >= 0:
|
if "Gather" in link4.name:
|
||||||
link4.inputs[i] = node2
|
for node4 in link4.outputs:
|
||||||
|
index = node4.inputs[1].values
|
||||||
|
slice_link = Variable(
|
||||||
|
f"onnx::Slice_123{gather_idx}",
|
||||||
|
dtype=link4.dtype,
|
||||||
|
shape=[1] + link3.shape[1:],
|
||||||
|
)
|
||||||
|
slice_node = Node(
|
||||||
|
op="Slice",
|
||||||
|
inputs=[
|
||||||
|
link3,
|
||||||
|
Constant(
|
||||||
|
f"SliceStart_123{gather_idx}",
|
||||||
|
np.array([index, 0, 0, 0]),
|
||||||
|
),
|
||||||
|
Constant(
|
||||||
|
f"SliceEnd_123{gather_idx}",
|
||||||
|
np.array([index + 1] + link3.shape[1:]),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[slice_link],
|
||||||
|
name=f"Slice_123{gather_idx}",
|
||||||
|
)
|
||||||
|
graph.nodes.append(slice_node)
|
||||||
|
gather_idx += 1
|
||||||
|
|
||||||
node.outputs = [link2]
|
for link5 in node4.outputs:
|
||||||
node1.inputs = []
|
for node5 in link5.outputs:
|
||||||
node1.outputs = []
|
try:
|
||||||
node3.inputs = []
|
idx = node5.inputs.index(link5)
|
||||||
node3.outputs = []
|
node5.inputs[idx] = slice_link
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
|
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
|
||||||
graph.toposort()
|
graph.toposort()
|
||||||
graph.fold_constants()
|
graph.fold_constants()
|
||||||
updated = export_onnx(graph)
|
updated = export_onnx(graph)
|
||||||
onnx.save(updated, model_path, save_as_external_data=True, all_tensors_to_one_file=False)
|
onnx.save(updated, model_path)
|
||||||
|
# infer_shapes_path(updated, check_type=True, strict_mode=False, data_prop=True)
|
||||||
|
|
||||||
|
# for some reason, reloading the model is necessary to apply the correct shape
|
||||||
|
proto = onnx.load(model_path)
|
||||||
|
graph = import_onnx(proto)
|
||||||
|
for node in graph.nodes:
|
||||||
|
if node.op == "Slice":
|
||||||
|
for link in node.outputs:
|
||||||
|
if "Slice_123" in link.name and link.shape[0] == 3:
|
||||||
|
link.shape[0] = 1
|
||||||
|
|
||||||
|
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
|
||||||
|
graph.toposort()
|
||||||
|
graph.fold_constants()
|
||||||
|
updated = export_onnx(graph)
|
||||||
|
onnx.save(updated, model_path)
|
||||||
infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)
|
infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)
|
||||||
|
|
||||||
|
|
||||||
|
def onnx_make_fixed(input_path: str, output_path: str, input_shape: tuple[int, ...]):
|
||||||
|
simplified, success = onnxsim.simplify(input_path)
|
||||||
|
if not success:
|
||||||
|
raise RuntimeError(f"Failed to simplify {input_path}")
|
||||||
|
onnx.save(simplified, input_path)
|
||||||
|
infer_shapes_path(input_path, check_type=True, strict_mode=True, data_prop=True)
|
||||||
|
model = onnx.load_model(input_path)
|
||||||
|
make_input_shape_fixed(model.graph, model.graph.input[0].name, input_shape)
|
||||||
|
fix_output_shapes(model)
|
||||||
|
onnx.save(model, output_path, save_as_external_data=True, all_tensors_to_one_file=False)
|
||||||
|
infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
|
||||||
|
|
||||||
|
|
||||||
class ExportBase:
|
class ExportBase:
|
||||||
task: ClassVar[str]
|
task: ClassVar[str]
|
||||||
|
|
||||||
|
@ -73,29 +134,27 @@ class ExportBase:
|
||||||
self.nchw_transpose = False
|
self.nchw_transpose = False
|
||||||
self.input_shape = input_shape
|
self.input_shape = input_shape
|
||||||
self.pretrained = pretrained
|
self.pretrained = pretrained
|
||||||
|
self.cache_dir = os.path.join(os.environ["CACHE_DIR"], self.model_name)
|
||||||
|
|
||||||
def to_onnx_static(self) -> str:
|
def download(self) -> str:
|
||||||
cache_dir = os.path.join(os.environ["CACHE_DIR"], self.model_name)
|
model_path = os.path.join(self.cache_dir, self.task, "model.onnx")
|
||||||
task_path = os.path.join(cache_dir, self.task)
|
|
||||||
model_path = os.path.join(task_path, "model.onnx")
|
|
||||||
if not os.path.isfile(model_path):
|
if not os.path.isfile(model_path):
|
||||||
print(f"Downloading {self.model_name}...")
|
print(f"Downloading {self.model_name}...")
|
||||||
snapshot_download(self.repo_name, cache_dir=cache_dir, local_dir=cache_dir)
|
snapshot_download(self.repo_name, cache_dir=self.cache_dir, local_dir=self.cache_dir, local_dir_use_symlinks=False)
|
||||||
|
return model_path
|
||||||
|
|
||||||
static_dir = os.path.join(task_path, "static")
|
def to_onnx_static(self) -> str:
|
||||||
static_path = os.path.join(static_dir, "model.onnx")
|
onnx_path_original = self.download()
|
||||||
|
static_dir = os.path.join(self.cache_dir, self.task, "static")
|
||||||
os.makedirs(static_dir, exist_ok=True)
|
os.makedirs(static_dir, exist_ok=True)
|
||||||
|
|
||||||
if not os.path.isfile(static_path):
|
static_path = os.path.join(static_dir, "model.onnx")
|
||||||
print(f"Making {self.model_name} ({self.task}) static")
|
print(f"Making {self.model_name} ({self.task}) static")
|
||||||
infer_shapes_path(onnx_path_original, check_type=True, strict_mode=True, data_prop=True)
|
onnx_make_fixed(onnx_path_original, static_path, self.input_shape)
|
||||||
onnx_path_original = os.path.join(cache_dir, "model.onnx")
|
onnx_transpose_4d(static_path)
|
||||||
static_model = onnx.load_model(onnx_path_original)
|
static_model = onnx.load_model(static_path)
|
||||||
make_input_shape_fixed(static_model.graph, static_model.graph.input[0].name, (1, 3, 224, 224))
|
self.inputs = [input_.name for input_ in static_model.graph.input]
|
||||||
fix_output_shapes(static_model)
|
self.outputs = [output_.name for output_ in static_model.graph.output]
|
||||||
onnx.save(static_model, static_path, save_as_external_data=True, all_tensors_to_one_file=False)
|
|
||||||
infer_shapes_path(static_path, check_type=True, strict_mode=True, data_prop=True)
|
|
||||||
onnx_transpose_4d(static_path)
|
|
||||||
return static_path
|
return static_path
|
||||||
|
|
||||||
def to_tflite(self, output_dir: str) -> tuple[str, str]:
|
def to_tflite(self, output_dir: str) -> tuple[str, str]:
|
||||||
|
@ -122,40 +181,48 @@ class ExportBase:
|
||||||
armnn_fp32 = os.path.join(output_dir, "model.armnn")
|
armnn_fp32 = os.path.join(output_dir, "model.armnn")
|
||||||
armnn_fp16 = os.path.join(fp16_dir, "model.armnn")
|
armnn_fp16 = os.path.join(fp16_dir, "model.armnn")
|
||||||
|
|
||||||
|
input_tensors = list(chain.from_iterable(("-i", input_) for input_ in self.inputs)),
|
||||||
|
output_tensors = list(chain.from_iterable(("-o", output_) for output_ in self.outputs)),
|
||||||
|
print(f"{input_tensors=}")
|
||||||
|
print(f"{output_tensors=}")
|
||||||
|
args = [
|
||||||
|
"./armnnconverter",
|
||||||
|
"-f",
|
||||||
|
"tflite-binary",
|
||||||
|
"-m",
|
||||||
|
tflite_fp32,
|
||||||
|
"-p",
|
||||||
|
armnn_fp32,
|
||||||
|
]
|
||||||
|
for input_ in self.inputs:
|
||||||
|
args.extend(["-i", input_])
|
||||||
|
for output_ in self.outputs:
|
||||||
|
args.extend(["-o", output_])
|
||||||
|
|
||||||
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp32 precision")
|
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp32 precision")
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
[
|
args,
|
||||||
"./armnnconverter",
|
|
||||||
"-f",
|
|
||||||
"tflite-binary",
|
|
||||||
"-m",
|
|
||||||
tflite_fp32,
|
|
||||||
"-i",
|
|
||||||
"input_tensor",
|
|
||||||
"-o",
|
|
||||||
"output_tensor",
|
|
||||||
"-p",
|
|
||||||
armnn_fp32,
|
|
||||||
],
|
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
)
|
)
|
||||||
print(f"Finished exporting {self.name} ({self.task}) with fp32 precision")
|
print(f"Finished exporting {self.name} ({self.task}) with fp32 precision")
|
||||||
|
|
||||||
|
args = [
|
||||||
|
"./armnnconverter",
|
||||||
|
"-f",
|
||||||
|
"tflite-binary",
|
||||||
|
"-m",
|
||||||
|
tflite_fp16,
|
||||||
|
"-p",
|
||||||
|
armnn_fp16,
|
||||||
|
]
|
||||||
|
for input_ in self.inputs:
|
||||||
|
args.extend(["-i", input_])
|
||||||
|
for output_ in self.outputs:
|
||||||
|
args.extend(["-o", output_])
|
||||||
|
|
||||||
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp16 precision")
|
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp16 precision")
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
[
|
args,
|
||||||
"./armnnconverter",
|
|
||||||
"-f",
|
|
||||||
"tflite-binary",
|
|
||||||
"-m",
|
|
||||||
tflite_fp16,
|
|
||||||
"-i",
|
|
||||||
"input_tensor",
|
|
||||||
"-o",
|
|
||||||
"output_tensor",
|
|
||||||
"-p",
|
|
||||||
armnn_fp16,
|
|
||||||
],
|
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
)
|
)
|
||||||
print(f"Finished exporting {self.name} ({self.task}) with fp16 precision")
|
print(f"Finished exporting {self.name} ({self.task}) with fp16 precision")
|
||||||
|
@ -280,6 +347,7 @@ def main() -> None:
|
||||||
upload_file(path_or_fileobj=armnn_fp16, path_in_repo=relative_fp16, repo_id=model.repo_name)
|
upload_file(path_or_fileobj=armnn_fp16, path_in_repo=relative_fp16, repo_id=model.repo_name)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
print(f"Failed to export {model.model_name} ({model.task}): {exc}")
|
print(f"Failed to export {model.model_name} ({model.task}): {exc}")
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in a new issue