1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-01 08:31:59 +00:00
This commit is contained in:
mertalev 2024-02-02 23:34:32 -05:00
parent 2c2cf59f09
commit bb56bd3297
No known key found for this signature in database
GPG key ID: 9181CD92C0A1C5E3

View file

@ -9,7 +9,7 @@ from typing import Any
import onnx import onnx
import onnxruntime as ort import onnxruntime as ort
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from onnx.shape_inference import infer_shapes from onnx.shape_inference import infer_shapes_path
from onnx.tools.update_model_dims import update_inputs_outputs_dims from onnx.tools.update_model_dims import update_inputs_outputs_dims
from typing_extensions import Buffer from typing_extensions import Buffer
import ann.ann import ann.ann
@ -117,8 +117,7 @@ class InferenceModel(ABC):
model_path = onnx_path model_path = onnx_path
if any(provider in STATIC_INPUT_PROVIDERS for provider in self.providers): if any(provider in STATIC_INPUT_PROVIDERS for provider in self.providers):
static_path = model_path.parent / "static_1" / "model.onnx" static_path = model_path.parent / "model_static_1.onnx"
static_path.parent.mkdir(parents=True, exist_ok=True)
if not static_path.is_file(): if not static_path.is_file():
self._convert_to_static(model_path, static_path) self._convert_to_static(model_path, static_path)
model_path = static_path model_path = static_path
@ -138,11 +137,12 @@ class InferenceModel(ABC):
return session return session
def _convert_to_static(self, source_path: Path, target_path: Path) -> None: def _convert_to_static(self, source_path: Path, target_path: Path) -> None:
inferred = infer_shapes(onnx.load(source_path)) infer_shapes_path(source_path, strict_mode=True)
inputs = self._get_static_dims(inferred.graph.input) proto = onnx.load(source_path, load_external_data=False)
outputs = self._get_static_dims(inferred.graph.output) inputs = self._get_static_dims(proto.graph.input)
outputs = self._get_static_dims(proto.graph.output)
# check_model gets called in update_inputs_outputs_dims and doesn't work for large models # check_model gets called in update_inputs_outputs_dims
check_model = onnx.checker.check_model check_model = onnx.checker.check_model
try: try:
@ -150,17 +150,11 @@ class InferenceModel(ABC):
pass pass
onnx.checker.check_model = check_model_stub onnx.checker.check_model = check_model_stub
updated_model = update_inputs_outputs_dims(inferred, inputs, outputs) updated_model = update_inputs_outputs_dims(proto, inputs, outputs)
finally: finally:
onnx.checker.check_model = check_model onnx.checker.check_model = check_model
onnx.save( onnx.save(updated_model, target_path)
updated_model,
target_path,
save_as_external_data=True,
all_tensors_to_one_file=False,
size_threshold=1048576,
)
def _get_static_dims(self, graph_io: Any, dim_size: int = 1) -> dict[str, list[int]]: def _get_static_dims(self, graph_io: Any, dim_size: int = 1) -> dict[str, list[int]]:
return { return {