1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-01 08:31:59 +00:00

chore(server,ml): remove object detection job and endpoint (#2627)

* removed object detection job

* removed object detection endpoint
This commit is contained in:
Mert 2023-05-31 21:49:51 -04:00 committed by GitHub
parent 9730bf0acc
commit 631f13cf2f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 0 additions and 75 deletions

View file

@ -22,7 +22,6 @@ class ClipRequestBody(BaseModel):
classification_model = os.getenv( classification_model = os.getenv(
"MACHINE_LEARNING_CLASSIFICATION_MODEL", "microsoft/resnet-50" "MACHINE_LEARNING_CLASSIFICATION_MODEL", "microsoft/resnet-50"
) )
object_model = os.getenv("MACHINE_LEARNING_OBJECT_MODEL", "hustvl/yolos-tiny")
clip_image_model = os.getenv("MACHINE_LEARNING_CLIP_IMAGE_MODEL", "clip-ViT-B-32") clip_image_model = os.getenv("MACHINE_LEARNING_CLIP_IMAGE_MODEL", "clip-ViT-B-32")
clip_text_model = os.getenv("MACHINE_LEARNING_CLIP_TEXT_MODEL", "clip-ViT-B-32") clip_text_model = os.getenv("MACHINE_LEARNING_CLIP_TEXT_MODEL", "clip-ViT-B-32")
facial_recognition_model = os.getenv( facial_recognition_model = os.getenv(
@ -39,7 +38,6 @@ app = FastAPI()
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
# Get all models # Get all models
_get_model(object_model, "object-detection")
_get_model(classification_model, "image-classification") _get_model(classification_model, "image-classification")
_get_model(clip_image_model) _get_model(clip_image_model)
_get_model(clip_text_model) _get_model(clip_text_model)
@ -55,14 +53,6 @@ async def root():
def ping(): def ping():
return "pong" return "pong"
@app.post("/object-detection/detect-object", status_code=200)
def object_detection(payload: MlRequestBody):
model = _get_model(object_model, "object-detection")
assetPath = payload.thumbnailPath
return run_engine(model, assetPath)
@app.post("/image-classifier/tag-image", status_code=200) @app.post("/image-classifier/tag-image", status_code=200)
def image_classification(payload: MlRequestBody): def image_classification(payload: MlRequestBody):
model = _get_model(classification_model, "image-classification") model = _get_model(classification_model, "image-classification")

View file

@ -52,7 +52,6 @@ export class ProcessorService {
[JobName.USER_DELETE_CHECK]: () => this.userService.handleUserDeleteCheck(), [JobName.USER_DELETE_CHECK]: () => this.userService.handleUserDeleteCheck(),
[JobName.USER_DELETION]: (data) => this.userService.handleUserDelete(data), [JobName.USER_DELETION]: (data) => this.userService.handleUserDelete(data),
[JobName.QUEUE_OBJECT_TAGGING]: (data) => this.smartInfoService.handleQueueObjectTagging(data), [JobName.QUEUE_OBJECT_TAGGING]: (data) => this.smartInfoService.handleQueueObjectTagging(data),
[JobName.DETECT_OBJECTS]: (data) => this.smartInfoService.handleDetectObjects(data),
[JobName.CLASSIFY_IMAGE]: (data) => this.smartInfoService.handleClassifyImage(data), [JobName.CLASSIFY_IMAGE]: (data) => this.smartInfoService.handleClassifyImage(data),
[JobName.QUEUE_ENCODE_CLIP]: (data) => this.smartInfoService.handleQueueEncodeClip(data), [JobName.QUEUE_ENCODE_CLIP]: (data) => this.smartInfoService.handleQueueEncodeClip(data),
[JobName.ENCODE_CLIP]: (data) => this.smartInfoService.handleEncodeClip(data), [JobName.ENCODE_CLIP]: (data) => this.smartInfoService.handleEncodeClip(data),

View file

@ -43,7 +43,6 @@ export enum JobName {
// object tagging // object tagging
QUEUE_OBJECT_TAGGING = 'queue-object-tagging', QUEUE_OBJECT_TAGGING = 'queue-object-tagging',
DETECT_OBJECTS = 'detect-objects',
CLASSIFY_IMAGE = 'classify-image', CLASSIFY_IMAGE = 'classify-image',
// facial recognition // facial recognition
@ -105,7 +104,6 @@ export const JOBS_TO_QUEUE: Record<JobName, QueueName> = {
// object tagging // object tagging
[JobName.QUEUE_OBJECT_TAGGING]: QueueName.OBJECT_TAGGING, [JobName.QUEUE_OBJECT_TAGGING]: QueueName.OBJECT_TAGGING,
[JobName.DETECT_OBJECTS]: QueueName.OBJECT_TAGGING,
[JobName.CLASSIFY_IMAGE]: QueueName.OBJECT_TAGGING, [JobName.CLASSIFY_IMAGE]: QueueName.OBJECT_TAGGING,
// facial recognition // facial recognition

View file

@ -52,7 +52,6 @@ export type JobItem =
// Object Tagging // Object Tagging
| { name: JobName.QUEUE_OBJECT_TAGGING; data: IBaseJob } | { name: JobName.QUEUE_OBJECT_TAGGING; data: IBaseJob }
| { name: JobName.DETECT_OBJECTS; data: IEntityJob }
| { name: JobName.CLASSIFY_IMAGE; data: IEntityJob } | { name: JobName.CLASSIFY_IMAGE; data: IEntityJob }
// Recognize Faces // Recognize Faces

View file

@ -119,7 +119,6 @@ export class JobService {
case JobName.GENERATE_JPEG_THUMBNAIL: { case JobName.GENERATE_JPEG_THUMBNAIL: {
await this.jobRepository.queue({ name: JobName.GENERATE_WEBP_THUMBNAIL, data: item.data }); await this.jobRepository.queue({ name: JobName.GENERATE_WEBP_THUMBNAIL, data: item.data });
await this.jobRepository.queue({ name: JobName.CLASSIFY_IMAGE, data: item.data }); await this.jobRepository.queue({ name: JobName.CLASSIFY_IMAGE, data: item.data });
await this.jobRepository.queue({ name: JobName.DETECT_OBJECTS, data: item.data });
await this.jobRepository.queue({ name: JobName.ENCODE_CLIP, data: item.data }); await this.jobRepository.queue({ name: JobName.ENCODE_CLIP, data: item.data });
await this.jobRepository.queue({ name: JobName.RECOGNIZE_FACES, data: item.data }); await this.jobRepository.queue({ name: JobName.RECOGNIZE_FACES, data: item.data });
@ -134,7 +133,6 @@ export class JobService {
// In addition to the above jobs, all of these should queue `SEARCH_INDEX_ASSET` // In addition to the above jobs, all of these should queue `SEARCH_INDEX_ASSET`
switch (item.name) { switch (item.name) {
case JobName.CLASSIFY_IMAGE: case JobName.CLASSIFY_IMAGE:
case JobName.DETECT_OBJECTS:
case JobName.ENCODE_CLIP: case JobName.ENCODE_CLIP:
case JobName.RECOGNIZE_FACES: case JobName.RECOGNIZE_FACES:
case JobName.METADATA_EXTRACTION: case JobName.METADATA_EXTRACTION:

View file

@ -21,7 +21,6 @@ export interface DetectFaceResult {
export interface IMachineLearningRepository { export interface IMachineLearningRepository {
classifyImage(input: MachineLearningInput): Promise<string[]>; classifyImage(input: MachineLearningInput): Promise<string[]>;
detectObjects(input: MachineLearningInput): Promise<string[]>;
encodeImage(input: MachineLearningInput): Promise<number[]>; encodeImage(input: MachineLearningInput): Promise<number[]>;
encodeText(input: string): Promise<number[]>; encodeText(input: string): Promise<number[]>;
detectFaces(input: MachineLearningInput): Promise<DetectFaceResult[]>; detectFaces(input: MachineLearningInput): Promise<DetectFaceResult[]>;

View file

@ -49,7 +49,6 @@ describe(SmartInfoService.name, () => {
expect(jobMock.queue.mock.calls).toEqual([ expect(jobMock.queue.mock.calls).toEqual([
[{ name: JobName.CLASSIFY_IMAGE, data: { id: assetEntityStub.image.id } }], [{ name: JobName.CLASSIFY_IMAGE, data: { id: assetEntityStub.image.id } }],
[{ name: JobName.DETECT_OBJECTS, data: { id: assetEntityStub.image.id } }],
]); ]);
expect(assetMock.getWithout).toHaveBeenCalledWith({ skip: 0, take: 1000 }, WithoutProperty.OBJECT_TAGS); expect(assetMock.getWithout).toHaveBeenCalledWith({ skip: 0, take: 1000 }, WithoutProperty.OBJECT_TAGS);
}); });
@ -64,7 +63,6 @@ describe(SmartInfoService.name, () => {
expect(jobMock.queue.mock.calls).toEqual([ expect(jobMock.queue.mock.calls).toEqual([
[{ name: JobName.CLASSIFY_IMAGE, data: { id: assetEntityStub.image.id } }], [{ name: JobName.CLASSIFY_IMAGE, data: { id: assetEntityStub.image.id } }],
[{ name: JobName.DETECT_OBJECTS, data: { id: assetEntityStub.image.id } }],
]); ]);
expect(assetMock.getAll).toHaveBeenCalled(); expect(assetMock.getAll).toHaveBeenCalled();
}); });
@ -103,39 +101,6 @@ describe(SmartInfoService.name, () => {
}); });
}); });
describe('handleDetectObjects', () => {
it('should skip assets without a resize path', async () => {
const asset = { resizePath: '' } as AssetEntity;
assetMock.getByIds.mockResolvedValue([asset]);
await sut.handleDetectObjects({ id: asset.id });
expect(smartMock.upsert).not.toHaveBeenCalled();
expect(machineMock.detectObjects).not.toHaveBeenCalled();
});
it('should save the returned objects', async () => {
machineMock.detectObjects.mockResolvedValue(['obj1', 'obj2', 'obj3']);
await sut.handleDetectObjects({ id: asset.id });
expect(machineMock.detectObjects).toHaveBeenCalledWith({ thumbnailPath: 'path/to/resize.ext' });
expect(smartMock.upsert).toHaveBeenCalledWith({
assetId: 'asset-1',
objects: ['obj1', 'obj2', 'obj3'],
});
});
it('should no update the smart info if no objects were returned', async () => {
machineMock.detectObjects.mockResolvedValue([]);
await sut.handleDetectObjects({ id: asset.id });
expect(machineMock.detectObjects).toHaveBeenCalled();
expect(smartMock.upsert).not.toHaveBeenCalled();
});
});
describe('handleQueueEncodeClip', () => { describe('handleQueueEncodeClip', () => {
it('should queue the assets without clip embeddings', async () => { it('should queue the assets without clip embeddings', async () => {
assetMock.getWithout.mockResolvedValue({ assetMock.getWithout.mockResolvedValue({

View file

@ -27,30 +27,12 @@ export class SmartInfoService {
for await (const assets of assetPagination) { for await (const assets of assetPagination) {
for (const asset of assets) { for (const asset of assets) {
await this.jobRepository.queue({ name: JobName.CLASSIFY_IMAGE, data: { id: asset.id } }); await this.jobRepository.queue({ name: JobName.CLASSIFY_IMAGE, data: { id: asset.id } });
await this.jobRepository.queue({ name: JobName.DETECT_OBJECTS, data: { id: asset.id } });
} }
} }
return true; return true;
} }
async handleDetectObjects({ id }: IEntityJob) {
const [asset] = await this.assetRepository.getByIds([id]);
if (!MACHINE_LEARNING_ENABLED || !asset.resizePath) {
return false;
}
const objects = await this.machineLearning.detectObjects({ thumbnailPath: asset.resizePath });
if (objects.length === 0) {
return false;
}
await this.repository.upsert({ assetId: asset.id, objects });
return true;
}
async handleClassifyImage({ id }: IEntityJob) { async handleClassifyImage({ id }: IEntityJob) {
const [asset] = await this.assetRepository.getByIds([id]); const [asset] = await this.assetRepository.getByIds([id]);

View file

@ -3,7 +3,6 @@ import { IMachineLearningRepository } from '../src';
export const newMachineLearningRepositoryMock = (): jest.Mocked<IMachineLearningRepository> => { export const newMachineLearningRepositoryMock = (): jest.Mocked<IMachineLearningRepository> => {
return { return {
classifyImage: jest.fn(), classifyImage: jest.fn(),
detectObjects: jest.fn(),
encodeImage: jest.fn(), encodeImage: jest.fn(),
encodeText: jest.fn(), encodeText: jest.fn(),
detectFaces: jest.fn(), detectFaces: jest.fn(),

View file

@ -14,10 +14,6 @@ export class MachineLearningRepository implements IMachineLearningRepository {
return client.post<DetectFaceResult[]>('/facial-recognition/detect-faces', input).then((res) => res.data); return client.post<DetectFaceResult[]>('/facial-recognition/detect-faces', input).then((res) => res.data);
} }
detectObjects(input: MachineLearningInput): Promise<string[]> {
return client.post<string[]>('/object-detection/detect-object', input).then((res) => res.data);
}
encodeImage(input: MachineLearningInput): Promise<number[]> { encodeImage(input: MachineLearningInput): Promise<number[]> {
return client.post<number[]>('/sentence-transformer/encode-image', input).then((res) => res.data); return client.post<number[]>('/sentence-transformer/encode-image', input).then((res) => res.data);
} }