mirror of
https://github.com/immich-app/immich.git
synced 2025-01-16 16:56:46 +01:00
fix(ml): tokenization for webli models (#11881)
This commit is contained in:
parent
5ab92f346a
commit
036676d501
3 changed files with 48 additions and 3 deletions
|
@ -10,6 +10,7 @@ from tokenizers import Encoding, Tokenizer
|
|||
|
||||
from app.config import log
|
||||
from app.models.base import InferenceModel
|
||||
from app.models.transforms import clean_text
|
||||
from app.schemas import ModelSession, ModelTask, ModelType
|
||||
|
||||
|
||||
|
@ -25,6 +26,8 @@ class BaseCLIPTextualEncoder(InferenceModel):
|
|||
session = super()._load()
|
||||
log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
|
||||
self.tokenizer = self._load_tokenizer()
|
||||
tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
|
||||
self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
|
||||
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
|
||||
|
||||
return session
|
||||
|
@ -56,6 +59,11 @@ class BaseCLIPTextualEncoder(InferenceModel):
|
|||
log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
|
||||
return model_cfg
|
||||
|
||||
@property
|
||||
def text_cfg(self) -> dict[str, Any]:
|
||||
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
|
||||
return text_cfg
|
||||
|
||||
@cached_property
|
||||
def tokenizer_file(self) -> dict[str, Any]:
|
||||
log.debug(f"Loading tokenizer file for CLIP model '{self.model_name}'")
|
||||
|
@ -73,8 +81,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
|
|||
|
||||
class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
|
||||
def _load_tokenizer(self) -> Tokenizer:
|
||||
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
|
||||
context_length: int = text_cfg.get("context_length", 77)
|
||||
context_length: int = self.text_cfg.get("context_length", 77)
|
||||
pad_token: str = self.tokenizer_cfg["pad_token"]
|
||||
|
||||
tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
|
||||
|
@ -86,12 +93,14 @@ class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
|
|||
return tokenizer
|
||||
|
||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||
text = clean_text(text, canonicalize=self.canonicalize)
|
||||
tokens: Encoding = self.tokenizer.encode(text)
|
||||
return {"text": np.array([tokens.ids], dtype=np.int32)}
|
||||
|
||||
|
||||
class MClipTextualEncoder(OpenClipTextualEncoder):
|
||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||
text = clean_text(text, canonicalize=self.canonicalize)
|
||||
tokens: Encoding = self.tokenizer.encode(text)
|
||||
return {
|
||||
"input_ids": np.array([tokens.ids], dtype=np.int32),
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import string
|
||||
from io import BytesIO
|
||||
from typing import IO
|
||||
|
||||
|
@ -7,6 +8,7 @@ from numpy.typing import NDArray
|
|||
from PIL import Image
|
||||
|
||||
_PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling}
|
||||
_PUNCTUATION_TRANS = str.maketrans("", "", string.punctuation)
|
||||
|
||||
|
||||
def resize_pil(img: Image.Image, size: int) -> Image.Image:
|
||||
|
@ -60,3 +62,10 @@ def decode_cv2(image_bytes: NDArray[np.uint8] | bytes | Image.Image) -> NDArray[
|
|||
if isinstance(image_bytes, Image.Image):
|
||||
return pil_to_cv2(image_bytes)
|
||||
return image_bytes
|
||||
|
||||
|
||||
def clean_text(text: str, canonicalize: bool = False) -> str:
|
||||
text = " ".join(text.split())
|
||||
if canonicalize:
|
||||
text = text.translate(_PUNCTUATION_TRANS).lower()
|
||||
return text
|
||||
|
|
|
@ -379,13 +379,40 @@ class TestCLIP:
|
|||
|
||||
clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
|
||||
clip_encoder._load()
|
||||
tokens = clip_encoder.tokenize("test search query")
|
||||
tokens = clip_encoder.tokenize("test search query")
|
||||
|
||||
assert "text" in tokens
|
||||
assert isinstance(tokens["text"], np.ndarray)
|
||||
assert tokens["text"].shape == (1, 77)
|
||||
assert tokens["text"].dtype == np.int32
|
||||
assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
|
||||
mock_tokenizer.encode.assert_called_once_with("test search query")
|
||||
|
||||
def test_openclip_tokenizer_canonicalizes_text(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
clip_model_cfg: dict[str, Any],
|
||||
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
||||
) -> None:
|
||||
clip_model_cfg["text_cfg"]["tokenizer_kwargs"] = {"clean": "canonicalize"}
|
||||
mocker.patch.object(OpenClipTextualEncoder, "download")
|
||||
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
|
||||
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
||||
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
|
||||
mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
|
||||
mock_ids = [randint(0, 50000) for _ in range(77)]
|
||||
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
|
||||
|
||||
clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
|
||||
clip_encoder._load()
|
||||
tokens = clip_encoder.tokenize("Test Search Query!")
|
||||
|
||||
assert "text" in tokens
|
||||
assert isinstance(tokens["text"], np.ndarray)
|
||||
assert tokens["text"].shape == (1, 77)
|
||||
assert tokens["text"].dtype == np.int32
|
||||
assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
|
||||
mock_tokenizer.encode.assert_called_once_with("test search query")
|
||||
|
||||
def test_mclip_tokenizer(
|
||||
self,
|
||||
|
|
Loading…
Reference in a new issue