2023-08-29 15:58:00 +02:00
|
|
|
import {
|
|
|
|
CLIPConfig,
|
|
|
|
CLIPMode,
|
|
|
|
DetectFaceResult,
|
|
|
|
IMachineLearningRepository,
|
|
|
|
ModelConfig,
|
|
|
|
ModelType,
|
|
|
|
RecognitionConfig,
|
|
|
|
TextModelInput,
|
|
|
|
VisionModelInput,
|
|
|
|
} from '@app/domain';
|
2023-02-25 15:12:03 +01:00
|
|
|
import { Injectable } from '@nestjs/common';
|
2024-02-02 04:18:00 +01:00
|
|
|
import { readFile } from 'node:fs/promises';
|
2023-02-25 15:12:03 +01:00
|
|
|
|
2024-01-04 21:34:50 +01:00
|
|
|
/** Prefix shared by the transport/HTTP error messages thrown from `MachineLearningRepository.predict`. */
const errorPrefix = 'Machine learning request';

/**
 * Picks the most informative value to report for a rejected `fetch` call.
 *
 * A rejection may carry the underlying reason (e.g. a connection error) in `cause`;
 * that is preferred when present and truthy, otherwise the rejection value itself is used.
 */
function describeFetchError(error: unknown): unknown {
  if (typeof error === 'object' && error !== null && 'cause' in error && error.cause) {
    return error.cause;
  }
  return error;
}

/**
 * HTTP client for the machine-learning service.
 *
 * Every inference request is a multipart POST to `${url}/predict` whose form fields are
 * built by {@link MachineLearningRepository.getFormData}.
 */
@Injectable()
export class MachineLearningRepository implements IMachineLearningRepository {
  /**
   * Sends one inference request and returns the decoded JSON body.
   *
   * @param url - Base URL of the machine-learning service; `/predict` is appended.
   * @param input - Either an image path or a text payload.
   * @param config - Model configuration; `modelType` is also used in error messages.
   * @returns The parsed JSON response. NOTE(review): the body is not validated here and is
   *   trusted to match `T`.
   * @throws Error when the model is disabled or the input is invalid (from `getFormData`),
   *   when the request cannot be sent, or when the service answers with status >= 400.
   */
  private async predict<T>(url: string, input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<T> {
    const formData = await this.getFormData(input, config);

    const res = await fetch(`${url}/predict`, { method: 'POST', body: formData }).catch((error: unknown): never => {
      throw new Error(`${errorPrefix} to "${url}" failed with ${String(describeFetchError(error))}`);
    });

    if (res.status >= 400) {
      // Turn e.g. 'facial-recognition' into 'facial recognition'; replace every hyphen, not just the first.
      const modelType = config.modelType ? ` for ${config.modelType.replace(/-/g, ' ')}` : '';
      throw new Error(`${errorPrefix}${modelType} failed with status ${res.status}: ${res.statusText}`);
    }

    return res.json();
  }

  /**
   * Runs the facial-recognition model on an image.
   *
   * @param url - Base URL of the machine-learning service.
   * @param input - The image to analyse.
   * @param config - Recognition settings; `modelType` is forced to facial recognition.
   * @returns The detected faces as returned by the service.
   */
  detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]> {
    return this.predict<DetectFaceResult[]>(url, input, { ...config, modelType: ModelType.FACIAL_RECOGNITION });
  }

  /**
   * Computes a CLIP embedding for an image.
   *
   * @param url - Base URL of the machine-learning service.
   * @param input - The image to encode.
   * @param config - CLIP settings; `modelType` and `mode` are overridden.
   * @returns The embedding vector.
   */
  encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]> {
    return this.encodeCLIP(url, input, config, CLIPMode.VISION);
  }

  /**
   * Computes a CLIP embedding for a text query.
   *
   * @param url - Base URL of the machine-learning service.
   * @param input - The text to encode.
   * @param config - CLIP settings; `modelType` and `mode` are overridden.
   * @returns The embedding vector.
   */
  encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]> {
    return this.encodeCLIP(url, input, config, CLIPMode.TEXT);
  }

  /**
   * Shared implementation of the CLIP encoders: forces the CLIP model type and the given mode.
   *
   * A declared `CLIPConfig` (rather than an `as` assertion) keeps the object checked by the compiler.
   */
  private encodeCLIP(
    url: string,
    input: TextModelInput | VisionModelInput,
    config: CLIPConfig,
    mode: CLIPMode,
  ): Promise<number[]> {
    const clipConfig: CLIPConfig = { ...config, modelType: ModelType.CLIP, mode };
    return this.predict<number[]>(url, input, clipConfig);
  }

  /**
   * Builds the multipart body for a `/predict` request.
   *
   * Fields: `modelName`; `modelType` (only when set); `options` (JSON of every remaining config
   * property); and either `image` (file contents read from `imagePath`) or `text`.
   *
   * @param input - Either `{ imagePath }` or `{ text }`.
   * @param config - Model configuration; `enabled`, `modelName` and `modelType` are sent as
   *   dedicated fields or used for validation, everything else goes into `options`.
   * @returns The populated form data.
   * @throws Error when the model is not enabled or the input has neither `imagePath` nor `text`.
   */
  async getFormData(input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<FormData> {
    const { enabled, modelName, modelType, ...options } = config;
    if (!enabled) {
      throw new Error(`${modelType} is not enabled`);
    }

    const formData = new FormData();
    formData.append('modelName', modelName);
    if (modelType) {
      formData.append('modelType', modelType);
    }
    // `options` comes from a rest pattern, so it is always an object (possibly `{}`) and is always sent.
    formData.append('options', JSON.stringify(options));

    if ('imagePath' in input) {
      formData.append('image', new Blob([await readFile(input.imagePath)]));
    } else if ('text' in input) {
      formData.append('text', input.text);
    } else {
      throw new Error('Invalid input');
    }

    return formData;
  }
}
|