Mirror of https://github.com/immich-app/immich.git (synced 2024-12-29 15:11:58 +00:00)
fix(ml): tokenization for webli models (#11881)
parent 5ab92f346a
commit 036676d501

3 changed files with 48 additions and 3 deletions
@@ -10,6 +10,7 @@ from tokenizers import Encoding, Tokenizer
 from app.config import log
 from app.models.base import InferenceModel
+from app.models.transforms import clean_text
 from app.schemas import ModelSession, ModelTask, ModelType
 
 
@@ -25,6 +26,8 @@ class BaseCLIPTextualEncoder(InferenceModel):
         session = super()._load()
         log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
         self.tokenizer = self._load_tokenizer()
+        tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
+        self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
         log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
 
         return session
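
The canonicalize flag added to `_load` depends only on the model's text_cfg. A standalone sketch of that check, with hypothetical config dicts standing in for real model configs (webli/SigLIP configs are expected to request canonicalization this way):

from typing import Any


def wants_canonicalize(text_cfg: dict[str, Any]) -> bool:
    # Same expression as in BaseCLIPTextualEncoder._load above: the flag is set
    # only when tokenizer_kwargs exists and asks for "canonicalize" cleaning.
    tokenizer_kwargs: dict[str, Any] | None = text_cfg.get("tokenizer_kwargs")
    return tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"


# Hypothetical text_cfg values, for illustration only.
assert wants_canonicalize({"context_length": 64, "tokenizer_kwargs": {"clean": "canonicalize"}})
assert not wants_canonicalize({"context_length": 77})  # a model config without tokenizer_kwargs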
@@ -56,6 +59,11 @@ class BaseCLIPTextualEncoder(InferenceModel):
         log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
         return model_cfg
 
+    @property
+    def text_cfg(self) -> dict[str, Any]:
+        text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
+        return text_cfg
+
     @cached_property
     def tokenizer_file(self) -> dict[str, Any]:
         log.debug(f"Loading tokenizer file for CLIP model '{self.model_name}'")
@@ -73,8 +81,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
 
 class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
     def _load_tokenizer(self) -> Tokenizer:
-        text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
-        context_length: int = text_cfg.get("context_length", 77)
+        context_length: int = self.text_cfg.get("context_length", 77)
         pad_token: str = self.tokenizer_cfg["pad_token"]
 
         tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
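
This hunk only changes how context_length is looked up; the padding and truncation setup that follows in `_load_tokenizer` is outside the diff. For orientation, a sketch of how a `tokenizers.Tokenizer` is typically configured to emit fixed-length output (the pad token value and file path here are placeholders, not taken from this commit):

from tokenizers import Tokenizer


def build_fixed_length_tokenizer(
    tokenizer_json: str, context_length: int = 77, pad_token: str = "<|endoftext|>"
) -> Tokenizer:
    # Load the serialized tokenizer, then pad/truncate so every query encodes
    # to exactly `context_length` ids, matching the (1, 77) shape asserted in the tests.
    tokenizer = Tokenizer.from_file(tokenizer_json)
    pad_id = tokenizer.token_to_id(pad_token)
    tokenizer.enable_padding(length=context_length, pad_token=pad_token, pad_id=pad_id)
    tokenizer.enable_truncation(max_length=context_length)
    return tokenizer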
@@ -86,12 +93,14 @@ class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
         return tokenizer
 
     def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+        text = clean_text(text, canonicalize=self.canonicalize)
         tokens: Encoding = self.tokenizer.encode(text)
         return {"text": np.array([tokens.ids], dtype=np.int32)}
 
 
 class MClipTextualEncoder(OpenClipTextualEncoder):
     def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+        text = clean_text(text, canonicalize=self.canonicalize)
         tokens: Encoding = self.tokenizer.encode(text)
         return {
             "input_ids": np.array([tokens.ids], dtype=np.int32),
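
With the flag wired through, both encoders normalize the query before encoding it. A standalone sketch of that flow with the real tokenizer swapped for a stub (the stub and its ids are illustrative, only the cleaning step mirrors the change above, and it assumes the machine-learning app package is importable):

from types import SimpleNamespace

import numpy as np

from app.models.transforms import clean_text  # helper added in this commit


class StubTokenizer:
    def encode(self, text: str) -> SimpleNamespace:
        print(f"encoding: {text!r}")
        return SimpleNamespace(ids=list(range(77)))  # fixed-length ids, like the padded CLIP tokenizer


def tokenize(tokenizer: StubTokenizer, text: str, canonicalize: bool) -> dict[str, np.ndarray]:
    text = clean_text(text, canonicalize=canonicalize)        # new: normalize before encoding
    tokens = tokenizer.encode(text)
    return {"text": np.array([tokens.ids], dtype=np.int32)}   # same output contract as OpenClipTextualEncoder


out = tokenize(StubTokenizer(), "Test Search Query!", canonicalize=True)  # prints: encoding: 'test search query'
assert out["text"].shape == (1, 77) and out["text"].dtype == np.int32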
@@ -1,3 +1,4 @@
+import string
 from io import BytesIO
 from typing import IO
 
@@ -7,6 +8,7 @@ from numpy.typing import NDArray
 from PIL import Image
 
 _PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling}
+_PUNCTUATION_TRANS = str.maketrans("", "", string.punctuation)
 
 
 def resize_pil(img: Image.Image, size: int) -> Image.Image:
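
For reference, the third argument of str.maketrans marks characters to delete, so this table strips ASCII punctuation when passed to str.translate; a quick standalone illustration:

import string

# Characters in the third argument are mapped to None, i.e. deleted on translate().
_PUNCTUATION_TRANS = str.maketrans("", "", string.punctuation)

print("Hello, world!".translate(_PUNCTUATION_TRANS))  # Hello world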
@@ -60,3 +62,10 @@ def decode_cv2(image_bytes: NDArray[np.uint8] | bytes | Image.Image) -> NDArray[
     if isinstance(image_bytes, Image.Image):
         return pil_to_cv2(image_bytes)
     return image_bytes
+
+
+def clean_text(text: str, canonicalize: bool = False) -> str:
+    text = " ".join(text.split())
+    if canonicalize:
+        text = text.translate(_PUNCTUATION_TRANS).lower()
+    return text
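
Putting the new helper through the same input the test below uses; a small illustrative check, assuming the machine-learning app package is importable:

from app.models.transforms import clean_text  # added above

assert clean_text("  Test   Search Query!  ") == "Test Search Query!"               # whitespace collapsed only
assert clean_text("Test Search Query!", canonicalize=True) == "test search query"   # lowercased, punctuation stripped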
@@ -379,13 +379,40 @@ class TestCLIP:
 
         clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
         clip_encoder._load()
         tokens = clip_encoder.tokenize("test search query")
 
         assert "text" in tokens
         assert isinstance(tokens["text"], np.ndarray)
         assert tokens["text"].shape == (1, 77)
         assert tokens["text"].dtype == np.int32
         assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
+        mock_tokenizer.encode.assert_called_once_with("test search query")
+
+    def test_openclip_tokenizer_canonicalizes_text(
+        self,
+        mocker: MockerFixture,
+        clip_model_cfg: dict[str, Any],
+        clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
+    ) -> None:
+        clip_model_cfg["text_cfg"]["tokenizer_kwargs"] = {"clean": "canonicalize"}
+        mocker.patch.object(OpenClipTextualEncoder, "download")
+        mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
+        mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
+        mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
+        mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
+        mock_ids = [randint(0, 50000) for _ in range(77)]
+        mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
+
+        clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
+        clip_encoder._load()
+        tokens = clip_encoder.tokenize("Test Search Query!")
+
+        assert "text" in tokens
+        assert isinstance(tokens["text"], np.ndarray)
+        assert tokens["text"].shape == (1, 77)
+        assert tokens["text"].dtype == np.int32
+        assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
+        mock_tokenizer.encode.assert_called_once_with("test search query")
 
     def test_mclip_tokenizer(
         self,