mirror of
https://github.com/immich-app/immich.git
synced 2025-01-07 20:36:48 +01:00
52 lines
1.3 KiB
Python
52 lines
1.3 KiB
Python
|
from typing import Optional
|
||
|
from pydantic import BaseModel
|
||
|
import numpy as np
|
||
|
from fastapi import FastAPI
|
||
|
import tensorflow as tf
|
||
|
from tensorflow.keras.applications import InceptionV3
|
||
|
from tensorflow.keras.applications.inception_v3 import preprocess_input, decode_predictions
|
||
|
from tensorflow.keras.preprocessing import image
|
||
|
|
||
|
# Side length, in pixels, of the square images passed to the model
# (used as target_size=(IMG_SIZE, IMG_SIZE) when loading images).
IMG_SIZE = 299

# Built once at import time and shared by every prediction call.
PREDICTION_MODEL = InceptionV3(weights="imagenet")
|
||
|
|
||
|
|
||
|
def warm_up():
    """Run one throw-away prediction on the bundled test image.

    Loads ``./app/test.png`` at ``IMG_SIZE`` x ``IMG_SIZE``, preprocesses it
    the same way the request handler does, and feeds it to
    ``PREDICTION_MODEL``. The prediction result is discarded.
    """
    sample = image.load_img('./app/test.png', target_size=(IMG_SIZE, IMG_SIZE))
    # Add a leading axis of size 1 so the single image forms a batch of one.
    batch = np.expand_dims(image.img_to_array(sample), axis=0)
    PREDICTION_MODEL.predict(preprocess_input(batch))
|
||
|
|
||
|
|
||
|
# Run one prediction at import time so the model has already been exercised
# before the application object is created.
warm_up()

app = FastAPI()
|
||
|
|
||
|
|
||
|
class TagImagePayload(BaseModel):
    """Request body accepted by the ``/tagImage`` endpoint."""

    # Path of the thumbnail to classify; the handler resolves it under ./app/.
    thumbnail_path: str
|
||
|
|
||
|
|
||
|
@app.post("/tagImage")
async def post_root(payload: TagImagePayload):
    """Return the top-3 predicted labels for a thumbnail image.

    Args:
        payload: Request body. ``thumbnail_path`` is resolved under
            ``./app/``; a single leading ``./`` is removed first.

    Returns:
        A list of three strings: the middle element of each tuple returned
        by ``decode_predictions(..., top=3)``, in the order returned.

    Raises:
        Whatever ``image.load_img`` raises when the path cannot be opened
        as an image (e.g. missing file).
    """
    relative_path = payload.thumbnail_path

    # Strip only a literal "./" prefix. The previous first-character check
    # raised IndexError on an empty path and chopped two characters off any
    # dot-leading name such as ".thumbs/a.jpg".
    if relative_path.startswith('./'):
        relative_path = relative_path[2:]

    # NOTE(review): the path comes straight from the request body and is not
    # checked for ".." segments -- confirm that callers are trusted.
    img_path = f'./app/{relative_path}'
    img = image.load_img(img_path, target_size=(IMG_SIZE, IMG_SIZE))
    x = image.img_to_array(img)
    # Add a leading axis of size 1 so the single image forms a batch of one.
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)

    # NOTE(review): predict() is a synchronous call inside an async handler,
    # so it presumably blocks the event loop while it runs -- confirm whether
    # that is acceptable for this service.
    preds = PREDICTION_MODEL.predict(x)

    # One batch entry was submitted, so take the first (only) result list.
    # Per the Keras docs each item is (class_name, class_description, score);
    # only the description is returned.
    top_predictions = decode_predictions(preds, top=3)[0]
    return [description for _, description, _ in top_predictions]
|