1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-19 18:26:46 +01:00
immich/machine-learning/app/models/transforms.py

72 lines
2.3 KiB
Python
Raw Normal View History

import string
from io import BytesIO
from typing import IO
import cv2
import numpy as np
from numpy.typing import NDArray
from PIL import Image
# Maps each Pillow resampling filter's lower-cased member name (e.g. "bicubic")
# to the corresponding ``Image.Resampling`` member.
_PIL_RESAMPLING_METHODS = {method.name.lower(): method for method in Image.Resampling}
# Translation table that deletes every character in ``string.punctuation``.
_PUNCTUATION_TRANS = str.maketrans("", "", string.punctuation)
def resize_pil(img: Image.Image, size: int) -> Image.Image:
    """Resize ``img`` so that its shorter side becomes ``size`` pixels.

    The longer side is scaled by the same ratio and truncated to an integer,
    so the aspect ratio is kept up to rounding. Bicubic resampling is used.
    """
    if img.width < img.height:
        # Portrait: width is the short side and is pinned to ``size``.
        target = (size, int((img.height / img.width) * size))
    else:
        # Landscape or square: height is pinned to ``size``.
        target = (int((img.width / img.height) * size), size)
    return img.resize(target, resample=Image.Resampling.BICUBIC)
# https://stackoverflow.com/a/60883103
def crop_pil(img: Image.Image, size: int) -> Image.Image:
    """Crop a ``size`` x ``size`` square around the centre of ``img``.

    The box origin is truncated towards zero with ``int`` and the box is not
    clamped to the image bounds.
    """
    half = size / 2
    left = int((img.size[0] / 2) - half)
    upper = int((img.size[1] / 2) - half)
    box = (left, upper, left + size, upper + size)
    return img.crop(box)
def to_numpy(img: Image.Image) -> NDArray[np.float32]:
    """Return ``img`` as a float32 array with values divided by 255.

    Images whose mode is not ``"RGB"`` are converted to RGB first.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    pixels = np.asarray(img, dtype=np.float32)
    return pixels / 255.0
def normalize(
img: NDArray[np.float32], mean: float | NDArray[np.float32], std: float | NDArray[np.float32]
) -> NDArray[np.float32]:
return np.divide(img - mean, std, dtype=np.float32)
def get_pil_resampling(resample: str) -> Image.Resampling:
    """Look up a Pillow resampling filter by name, ignoring case.

    Raises:
        KeyError: If ``resample`` does not name a member of ``Image.Resampling``.
    """
    key = resample.lower()
    return _PIL_RESAMPLING_METHODS[key]
def pil_to_cv2(image: Image.Image) -> NDArray[np.uint8]:
    """Copy ``image`` into a NumPy array and swap its channels from RGB to BGR.

    NOTE(review): presumably callers always pass an RGB-mode image; confirm.
    """
    rgb_pixels = np.array(image)
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)  # type: ignore
def decode_pil(image_bytes: bytes | IO[bytes] | Image.Image) -> Image.Image:
    """Decode raw bytes or a binary stream into an RGB Pillow image.

    An argument that is already an ``Image.Image`` is returned unchanged,
    without any mode conversion.
    """
    if isinstance(image_bytes, Image.Image):
        return image_bytes

    # Raw bytes are wrapped in a stream; anything else is handed to Pillow as is.
    source = BytesIO(image_bytes) if isinstance(image_bytes, bytes) else image_bytes
    image: Image.Image = Image.open(source)
    # Force the pixel data to be read now rather than on first access.
    image.load()
    return image if image.mode == "RGB" else image.convert("RGB")
def decode_cv2(image_bytes: NDArray[np.uint8] | bytes | Image.Image) -> NDArray[np.uint8]:
    """Return the input as an array suitable for OpenCV.

    Raw bytes are decoded with ``decode_pil`` and then converted with
    ``pil_to_cv2``; Pillow images are converted directly. Any other input
    (expected to be an array already) is returned untouched.
    """
    # pillow is much faster than cv2
    decoded = decode_pil(image_bytes) if isinstance(image_bytes, bytes) else image_bytes
    if not isinstance(decoded, Image.Image):
        return decoded
    return pil_to_cv2(decoded)
def clean_text(text: str, canonicalize: bool = False) -> str:
    """Collapse whitespace in ``text`` and optionally canonicalize it.

    Runs of whitespace become single spaces and leading/trailing whitespace
    is dropped. With ``canonicalize=True`` the result additionally has
    punctuation removed and is lower-cased.
    """
    collapsed = " ".join(text.split())
    if not canonicalize:
        return collapsed
    return collapsed.translate(_PUNCTUATION_TRANS).lower()