1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2024-12-29 15:11:58 +00:00

fix(ml): tokenization for webli models (#11881)

This commit is contained in:
Mert 2024-08-18 11:05:10 -04:00 committed by GitHub
parent 5ab92f346a
commit 036676d501
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 48 additions and 3 deletions

View file

@ -10,6 +10,7 @@ from tokenizers import Encoding, Tokenizer
from app.config import log from app.config import log
from app.models.base import InferenceModel from app.models.base import InferenceModel
from app.models.transforms import clean_text
from app.schemas import ModelSession, ModelTask, ModelType from app.schemas import ModelSession, ModelTask, ModelType
@ -25,6 +26,8 @@ class BaseCLIPTextualEncoder(InferenceModel):
session = super()._load() session = super()._load()
log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'") log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
self.tokenizer = self._load_tokenizer() self.tokenizer = self._load_tokenizer()
tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'") log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
return session return session
@ -56,6 +59,11 @@ class BaseCLIPTextualEncoder(InferenceModel):
log.debug(f"Loaded model config for CLIP model '{self.model_name}'") log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
return model_cfg return model_cfg
@property
def text_cfg(self) -> dict[str, Any]:
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
return text_cfg
@cached_property @cached_property
def tokenizer_file(self) -> dict[str, Any]: def tokenizer_file(self) -> dict[str, Any]:
log.debug(f"Loading tokenizer file for CLIP model '{self.model_name}'") log.debug(f"Loading tokenizer file for CLIP model '{self.model_name}'")
@ -73,8 +81,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
class OpenClipTextualEncoder(BaseCLIPTextualEncoder): class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
def _load_tokenizer(self) -> Tokenizer: def _load_tokenizer(self) -> Tokenizer:
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"] context_length: int = self.text_cfg.get("context_length", 77)
context_length: int = text_cfg.get("context_length", 77)
pad_token: str = self.tokenizer_cfg["pad_token"] pad_token: str = self.tokenizer_cfg["pad_token"]
tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix()) tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
@ -86,12 +93,14 @@ class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
return tokenizer return tokenizer
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]: def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
text = clean_text(text, canonicalize=self.canonicalize)
tokens: Encoding = self.tokenizer.encode(text) tokens: Encoding = self.tokenizer.encode(text)
return {"text": np.array([tokens.ids], dtype=np.int32)} return {"text": np.array([tokens.ids], dtype=np.int32)}
class MClipTextualEncoder(OpenClipTextualEncoder): class MClipTextualEncoder(OpenClipTextualEncoder):
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]: def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
text = clean_text(text, canonicalize=self.canonicalize)
tokens: Encoding = self.tokenizer.encode(text) tokens: Encoding = self.tokenizer.encode(text)
return { return {
"input_ids": np.array([tokens.ids], dtype=np.int32), "input_ids": np.array([tokens.ids], dtype=np.int32),

View file

@ -1,3 +1,4 @@
import string
from io import BytesIO from io import BytesIO
from typing import IO from typing import IO
@ -7,6 +8,7 @@ from numpy.typing import NDArray
from PIL import Image from PIL import Image
_PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling} _PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling}
_PUNCTUATION_TRANS = str.maketrans("", "", string.punctuation)
def resize_pil(img: Image.Image, size: int) -> Image.Image: def resize_pil(img: Image.Image, size: int) -> Image.Image:
@ -60,3 +62,10 @@ def decode_cv2(image_bytes: NDArray[np.uint8] | bytes | Image.Image) -> NDArray[
if isinstance(image_bytes, Image.Image): if isinstance(image_bytes, Image.Image):
return pil_to_cv2(image_bytes) return pil_to_cv2(image_bytes)
return image_bytes return image_bytes
def clean_text(text: str, canonicalize: bool = False) -> str:
text = " ".join(text.split())
if canonicalize:
text = text.translate(_PUNCTUATION_TRANS).lower()
return text

View file

@ -379,13 +379,40 @@ class TestCLIP:
clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache") clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
clip_encoder._load() clip_encoder._load()
tokens = clip_encoder.tokenize("test search query") tokens = clip_encoder.tokenize("test search query")
assert "text" in tokens assert "text" in tokens
assert isinstance(tokens["text"], np.ndarray) assert isinstance(tokens["text"], np.ndarray)
assert tokens["text"].shape == (1, 77) assert tokens["text"].shape == (1, 77)
assert tokens["text"].dtype == np.int32 assert tokens["text"].dtype == np.int32
assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0) assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
mock_tokenizer.encode.assert_called_once_with("test search query")
def test_openclip_tokenizer_canonicalizes_text(
self,
mocker: MockerFixture,
clip_model_cfg: dict[str, Any],
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
) -> None:
clip_model_cfg["text_cfg"]["tokenizer_kwargs"] = {"clean": "canonicalize"}
mocker.patch.object(OpenClipTextualEncoder, "download")
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
mock_ids = [randint(0, 50000) for _ in range(77)]
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
clip_encoder._load()
tokens = clip_encoder.tokenize("Test Search Query!")
assert "text" in tokens
assert isinstance(tokens["text"], np.ndarray)
assert tokens["text"].shape == (1, 77)
assert tokens["text"].dtype == np.int32
assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
mock_tokenizer.encode.assert_called_once_with("test search query")
def test_mclip_tokenizer( def test_mclip_tokenizer(
self, self,