2024-07-31 06:34:45 +02:00
|
|
|
import os
|
2023-10-31 11:02:04 +01:00
|
|
|
import tempfile
|
|
|
|
import warnings
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from multilingual_clip.pt_multilingual_clip import MultilingualCLIP
|
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
|
|
from .openclip import OpenCLIPModelConfig
|
|
|
|
from .openclip import to_onnx as openclip_to_onnx
|
|
|
|
from .util import get_model_path
|
|
|
|
|
|
|
|
# Maps each supported M-CLIP model name to the OpenCLIP config that `to_onnx`
# hands to the OpenCLIP exporter to produce the matching visual encoder.
_MCLIP_TO_OPENCLIP: dict[str, OpenCLIPModelConfig] = {
    "M-CLIP/XLM-Roberta-Large-Vit-B-32": OpenCLIPModelConfig("ViT-B-32", "openai"),
    "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus": OpenCLIPModelConfig("ViT-B-16-plus-240", "laion400m_e32"),
    "M-CLIP/LABSE-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
    "M-CLIP/XLM-Roberta-Large-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
}
|
|
|
|
|
|
|
|
|
|
|
|
def to_onnx(
    model_name: str,
    output_dir_visual: Path | str,
    output_dir_textual: Path | str,
) -> tuple[Path, Path]:
    """Export an M-CLIP model as separate visual and textual ONNX models.

    The text encoder and tokenizer are loaded from the M-CLIP checkpoint named
    by ``model_name``. The visual encoder is exported by the OpenCLIP exporter
    from the config that ``_MCLIP_TO_OPENCLIP`` associates with that name.

    Args:
        model_name: M-CLIP model identifier; must be a key of
            ``_MCLIP_TO_OPENCLIP``.
        output_dir_visual: Output directory handed to the OpenCLIP exporter.
        output_dir_textual: Directory that receives the tokenizer files and
            from which the textual model path is derived via
            ``get_model_path``.

    Returns:
        A ``(visual_path, textual_path)`` tuple of the exported model files.

    Raises:
        KeyError: If ``model_name`` has no entry in ``_MCLIP_TO_OPENCLIP``.
            Raised before any checkpoint is loaded or any file is written.
        AssertionError: If the OpenCLIP exporter returns no visual path.
    """
    # Resolve the visual config up front: an unsupported name should fail
    # before the checkpoint is loaded and exported, not afterwards with
    # half of the output already on disk.
    openclip_config = _MCLIP_TO_OPENCLIP[model_name]

    textual_path = get_model_path(output_dir_textual)
    with tempfile.TemporaryDirectory() as tmpdir:
        # CACHE_DIR, when set, takes precedence over the throwaway directory.
        model = MultilingualCLIP.from_pretrained(model_name, cache_dir=os.environ.get("CACHE_DIR", tmpdir))
        AutoTokenizer.from_pretrained(model_name).save_pretrained(output_dir_textual)

        # Inference-only export: eval mode, no gradients tracked.
        model.eval()
        for param in model.parameters():
            param.requires_grad_(False)

        export_text_encoder(model, textual_path)
        visual_path, _ = openclip_to_onnx(openclip_config, output_dir_visual)
        assert visual_path is not None, "Visual model export failed"
        return visual_path, textual_path
|
2023-10-31 11:02:04 +01:00
|
|
|
|
|
|
|
|
|
|
|
def export_text_encoder(model: MultilingualCLIP, output_path: Path | str) -> None:
    """Export the text encoder of an M-CLIP model to an ONNX file.

    ``MultilingualCLIP.forward`` is temporarily replaced with a tensor-only
    implementation taking ``input_ids`` and ``attention_mask`` so the model
    can be traced, and is restored afterwards, including when the export
    raises.

    Args:
        model: The loaded M-CLIP model to export.
        output_path: Destination file for the ONNX model.
    """
    output_path = Path(output_path)

    def forward(self: MultilingualCLIP, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # First output of the underlying transformer (per-token embeddings).
        embs = self.transformer(input_ids, attention_mask)[0]
        # Mean over dim 1 weighted by the attention mask, so masked-out
        # positions do not contribute to the pooled embedding.
        embs = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
        embs = self.LinearTransformation(embs)
        return torch.nn.functional.normalize(embs, dim=-1)

    # unfortunately need to monkeypatch for tracing to work here
    # otherwise it hits the 2GiB protobuf serialization limit
    original_forward = MultilingualCLIP.forward
    MultilingualCLIP.forward = forward
    try:
        # Dummy inputs of shape (1, 77) used to trace the graph.
        args = (torch.ones(1, 77, dtype=torch.int32), torch.ones(1, 77, dtype=torch.int32))
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            torch.onnx.export(
                model,
                args,
                output_path.as_posix(),
                input_names=["input_ids", "attention_mask"],
                output_names=["embedding"],
                opset_version=17,
                # dynamic_axes is not passed, so the exported inputs keep the
                # shape of the dummy arguments above. Previous setting:
                # dynamic_axes={
                #     "input_ids": {0: "batch_size", 1: "sequence_length"},
                #     "attention_mask": {0: "batch_size", 1: "sequence_length"},
                # },
            )
    finally:
        # Undo the class-level patch so it does not leak to other users of
        # MultilingualCLIP in this process.
        MultilingualCLIP.forward = original_forward
|