import string
from io import BytesIO
from typing import IO

import cv2
import numpy as np
from numpy.typing import NDArray
from PIL import Image

# Lower-cased Pillow resampling filter name -> enum member.
_PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling}
# Translation table that deletes every character in string.punctuation.
_PUNCTUATION_TRANS = str.maketrans("", "", string.punctuation)


def resize_pil(img: Image.Image, size: int) -> Image.Image:
    """Resize ``img`` with bicubic resampling so that its shorter side equals ``size``.

    The other side is scaled by the same ratio and truncated to an integer.
    """
    if img.width < img.height:
        new_size = (size, int((img.height / img.width) * size))
    else:
        new_size = (int((img.width / img.height) * size), size)
    return img.resize(new_size, resample=Image.Resampling.BICUBIC)


# https://stackoverflow.com/a/60883103
def crop_pil(img: Image.Image, size: int) -> Image.Image:
    """Return the ``size`` x ``size`` square crop taken around the center of ``img``."""
    left = int((img.size[0] / 2) - (size / 2))
    upper = int((img.size[1] / 2) - (size / 2))
    right = left + size
    lower = upper + size
    return img.crop((left, upper, right, lower))


def to_numpy(img: Image.Image) -> NDArray[np.float32]:
    """Convert ``img`` to a float32 array of RGB values divided by 255.

    Images whose mode is not "RGB" are converted to RGB first.
    """
    return np.asarray(img if img.mode == "RGB" else img.convert("RGB"), dtype=np.float32) / 255.0


def normalize(
    img: NDArray[np.float32], mean: float | NDArray[np.float32], std: float | NDArray[np.float32]
) -> NDArray[np.float32]:
    """Return ``(img - mean) / std`` computed as float32."""
    return np.divide(img - mean, std, dtype=np.float32)


def get_pil_resampling(resample: str) -> Image.Resampling:
    """Look up a Pillow resampling filter by its case-insensitive name.

    Raises:
        KeyError: If ``resample`` does not name a member of ``Image.Resampling``.
    """
    return _PIL_RESAMPLING_METHODS[resample.lower()]


def pil_to_cv2(image: Image.Image) -> NDArray[np.uint8]:
    """Convert a Pillow image to a BGR array as used by OpenCV.

    Images whose mode is not "RGB" are converted to RGB first, so grayscale,
    palette and CMYK inputs produce a correct three-channel BGR result.
    """
    # COLOR_RGB2BGR interprets the array as RGB(A): single-channel modes would
    # make OpenCV raise and CMYK would be misread as RGBA, so normalise first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)  # type: ignore


def decode_pil(image_bytes: bytes | IO[bytes] | Image.Image) -> Image.Image:
    """Decode encoded image data into an RGB Pillow image.

    Args:
        image_bytes: Raw encoded bytes, a binary file-like object, or an
            already decoded image. An ``Image.Image`` is returned unchanged,
            whatever its mode.

    Returns:
        The decoded image, converted to "RGB" unless it was passed in as an
        ``Image.Image``.
    """
    if isinstance(image_bytes, Image.Image):
        return image_bytes
    image: Image.Image = Image.open(BytesIO(image_bytes) if isinstance(image_bytes, bytes) else image_bytes)
    # Image.open is lazy; load() reads the pixel data now.
    image.load()
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image


def decode_cv2(image_bytes: NDArray[np.uint8] | bytes | Image.Image) -> NDArray[np.uint8]:
    """Return the input as an OpenCV-style BGR array.

    Bytes are decoded with Pillow and Pillow images are converted with
    ``pil_to_cv2``; anything else is returned unchanged.
    """
    if isinstance(image_bytes, bytes):
        image_bytes = decode_pil(image_bytes)  # pillow is much faster than cv2
    if isinstance(image_bytes, Image.Image):
        return pil_to_cv2(image_bytes)
    return image_bytes


def clean_text(text: str, canonicalize: bool = False) -> str:
    """Collapse whitespace in ``text`` and optionally canonicalize it.

    Runs of whitespace become single spaces and leading/trailing whitespace is
    dropped. With ``canonicalize`` set, the characters in ``string.punctuation``
    are then removed and the result is lower-cased. Punctuation is removed
    after the whitespace pass, so deleting a punctuation character that stood
    between two spaces leaves both spaces in place.
    """
    text = " ".join(text.split())
    if canonicalize:
        text = text.translate(_PUNCTUATION_TRANS).lower()
    return text