mirror of
https://github.com/immich-app/immich.git
synced 2025-01-07 20:36:48 +01:00
158 lines
4.9 KiB
Python
158 lines
4.9 KiB
Python
|
import logging
|
||
|
import os
|
||
|
import platform
|
||
|
import subprocess
|
||
|
from abc import abstractmethod
|
||
|
|
||
|
import onnx
|
||
|
import open_clip
|
||
|
import torch
|
||
|
from onnx2torch import convert
|
||
|
from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
|
||
|
from tinynn.converter import TFLiteConverter
|
||
|
|
||
|
|
||
|
class ExportBase(torch.nn.Module):
|
||
|
input_shape: tuple[int, ...]
|
||
|
|
||
|
def __init__(self, device: torch.device, name: str):
|
||
|
super().__init__()
|
||
|
self.device = device
|
||
|
self.name = name
|
||
|
self.optimize = 5
|
||
|
self.nchw_transpose = False
|
||
|
|
||
|
@abstractmethod
|
||
|
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor]:
|
||
|
pass
|
||
|
|
||
|
def dummy_input(self) -> torch.FloatTensor:
|
||
|
return torch.rand((1, 3, 224, 224), device=self.device)
|
||
|
|
||
|
|
||
|
class ArcFace(ExportBase):
    """Exportable wrapper around an ONNX model converted to PyTorch.

    The ONNX graph's first input is pinned to ``input_shape`` before the
    conversion. On a CUDA device the converted model runs in half precision.
    """

    input_shape = (1, 3, 112, 112)

    def __init__(self, onnx_model_path: str, device: torch.device):
        """Load, shape-fix and convert the ONNX model.

        Args:
            onnx_model_path: Path of the ONNX file; its base name without the
                extension becomes the model name.
            device: Device the converted model is moved to.
        """
        name, _ = os.path.splitext(os.path.basename(onnx_model_path))
        super().__init__(device, name)
        onnx_model = onnx.load_model(onnx_model_path)
        # Give the first graph input a fixed shape, then let onnxruntime's
        # helper update the output shapes to match.
        make_input_shape_fixed(onnx_model.graph, onnx_model.graph.input[0].name, self.input_shape)
        fix_output_shapes(onnx_model)
        self.model = convert(onnx_model).to(device)
        if self.device.type == "cuda":
            self.model = self.model.half()

    def forward(self, input_tensor: torch.Tensor) -> torch.FloatTensor:
        """Return the model output for ``input_tensor`` as a float32 tensor.

        The input is cast to half precision on CUDA to match the model, and
        the result is always cast back to float32.
        """
        embedding: torch.FloatTensor = self.model(
            input_tensor.half() if self.device.type == "cuda" else input_tensor
        ).float()
        # ``torch.FloatTensor`` is the CPU tensor type, so an isinstance check
        # against it fails for CUDA tensors; compare the dtype instead.
        assert embedding.dtype == torch.float32
        return embedding

    def dummy_input(self) -> torch.FloatTensor:
        """Return a random tensor of shape ``input_shape`` on ``self.device``."""
        return torch.rand(self.input_shape, device=self.device)
|
||
|
|
||
|
|
||
|
class RetinaFace(ExportBase):
    """Exportable wrapper around an ONNX model converted straight from its path.

    Uses ``optimize = 3`` instead of the base default, and half precision when
    the device is CUDA.
    """

    input_shape = (1, 3, 640, 640)

    def __init__(self, onnx_model_path: str, device: torch.device):
        """Convert the ONNX file at ``onnx_model_path`` and move it to ``device``.

        The model name is the file's base name without its extension.
        """
        file_name = os.path.basename(onnx_model_path)
        stem = os.path.splitext(file_name)[0]
        super().__init__(device, stem)
        self.optimize = 3
        converted = convert(onnx_model_path)
        converted = converted.eval()
        self.model = converted.to(device)
        if self.device.type == "cuda":
            self.model = self.model.half()

    def forward(self, input_tensor: torch.Tensor) -> tuple[torch.FloatTensor]:
        """Run the model and return every output cast to float32, as a tuple."""
        if self.device.type == "cuda":
            model_input = input_tensor.half()
        else:
            model_input = input_tensor
        outputs: torch.Tensor = self.model(model_input)
        return tuple(output.float() for output in outputs)

    def dummy_input(self) -> torch.FloatTensor:
        """Return a random tensor of shape ``input_shape`` on ``self.device``."""
        shape = self.input_shape
        return torch.rand(shape, device=self.device)
|
||
|
|
||
|
|
||
|
class ClipVision(ExportBase):
    """Exportable wrapper around the image encoder of an ``open_clip`` model."""

    input_shape = (1, 3, 224, 224)

    def __init__(self, model_name: str, weights: str, device: torch.device):
        """Create the pretrained open_clip model on ``device``.

        The export name is ``model_name`` and ``weights`` joined by ``"__"``.
        The model is created with fp16 precision on CUDA and fp32 otherwise.
        """
        super().__init__(device, "__".join((model_name, weights)))
        if device.type == "cuda":
            precision = "fp16"
        else:
            precision = "fp32"
        self.model = open_clip.create_model(
            model_name,
            weights,
            precision=precision,
            jit=False,
            require_pretrained=True,
            device=device,
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.FloatTensor:
        """Return the normalized image embedding of ``input_tensor`` as float32."""
        if self.device.type == "cuda":
            image = input_tensor.half()
        else:
            image = input_tensor
        features: torch.Tensor = self.model.encode_image(image, normalize=True)
        return features.float()
|
||
|
|
||
|
|
||
|
def export(model: ExportBase) -> None:
    """Trace ``model`` and write it as TFLite and ArmNN files under ``output/``.

    The model is put in eval mode with gradients disabled, traced with its own
    dummy input, converted to ``output/<name>.tflite`` and then passed to the
    ``./armnnconverter`` executable to produce ``output/<name>.armnn``.

    Args:
        model: The wrapped model to export.

    Raises:
        subprocess.CalledProcessError: If ``armnnconverter`` exits with a
            non-zero status.
    """
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    dummy_input = model.dummy_input()
    # One eager forward pass before tracing.
    model(dummy_input)
    jit = torch.jit.trace(model, dummy_input)  # type: ignore[no-untyped-call,attr-defined]
    tflite_model_path = f"output/{model.name}.tflite"
    os.makedirs("output", exist_ok=True)

    converter = TFLiteConverter(
        jit,
        dummy_input,
        tflite_model_path,
        optimize=model.optimize,
        nchw_transpose=model.nchw_transpose,
    )
    # segfaults on ARM, must run on x86_64 / AMD64
    converter.convert()

    armnn_model_path = f"output/{model.name}.armnn"
    # Set LD_LIBRARY_PATH for the converter process only, rather than
    # overwriting it in this process's environment.
    converter_env = {**os.environ, "LD_LIBRARY_PATH": "armnn"}
    subprocess.run(
        [
            "./armnnconverter",
            "-f",
            "tflite-binary",
            "-m",
            tflite_model_path,
            "-i",
            "input_tensor",
            "-o",
            "output_tensor",
            "-p",
            armnn_model_path,
        ],
        env=converter_env,
        # Surface a failed conversion instead of silently continuing.
        check=True,
    )
|
||
|
|
||
|
|
||
|
def main() -> None:
    """Export every supported model, using CUDA when it is available.

    Raises:
        RuntimeError: If the machine architecture is not x86_64 / AMD64.
    """
    machine = platform.machine()
    if machine not in ("x86_64", "AMD64"):
        raise RuntimeError(f"Can only run on x86_64 / AMD64, not {machine}")

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        logging.warning(
            "No CUDA available, cannot create fp16 model! proceeding to create a fp32 model (use only for testing)"
        )

    # All wrappers are constructed first, then exported in the same order.
    exportables = (
        ClipVision("ViT-B-32", "openai", device),
        ArcFace("buffalo_l_rec.onnx", device),
        RetinaFace("buffalo_l_det.onnx", device),
    )
    for exportable in exportables:
        export(exportable)
|
||
|
|
||
|
|
||
|
# Script entry point: run the whole export inside torch.no_grad().
if __name__ == "__main__":
    with torch.no_grad():
        main()
|