Files
immich/server/src/repositories/machine-learning.repository.ts
Mert 4bf1b84cc2 feat(ml): support multiple urls (#14347)
* support multiple url

* update api

* styling

unnecessary `?.`

* update docs, make new url field go first

add load balancing section

* update tests

doc formatting

wording

wording

linting

* small styling

* `url` -> `urls`

* fix tests

* update docs

* make docusaurus happy

---------

Co-authored-by: Alex <alex.tran1502@gmail.com>
2024-12-04 20:17:47 +00:00

87 lines
3.1 KiB
TypeScript

import { Inject, Injectable } from '@nestjs/common';
import { readFile } from 'node:fs/promises';
import { CLIPConfig } from 'src/dtos/model-config.dto';
import { ILoggerRepository } from 'src/interfaces/logger.interface';
import {
ClipTextualResponse,
ClipVisualResponse,
FaceDetectionOptions,
FacialRecognitionResponse,
IMachineLearningRepository,
MachineLearningRequest,
ModelPayload,
ModelTask,
ModelType,
} from 'src/interfaces/machine-learning.interface';
/**
 * Client for the machine learning service.
 *
 * Every public method accepts a list of base URLs. The request is sent to each
 * URL in the order given and the first successful response is used, so later
 * URLs act as fallbacks for earlier ones.
 */
@Injectable()
export class MachineLearningRepository implements IMachineLearningRepository {
  constructor(@Inject(ILoggerRepository) private logger: ILoggerRepository) {
    this.logger.setContext(MachineLearningRepository.name);
  }

  /**
   * POSTs a prediction request to each URL in turn until one succeeds.
   *
   * The form data is built once and reused for every attempt. A failed attempt
   * (non-2xx status, network error, or an unreadable response body) is logged
   * as a warning and the next URL is tried.
   *
   * @param urls Base URLs to try, in order. `/predict` is resolved against each
   *   one as an absolute path, so any path component of the base URL is replaced.
   * @param payload Either an image path or a text to run the models on.
   * @param config Model configuration, sent as the JSON `entries` field.
   * @returns The parsed JSON body of the first successful response.
   * @throws Error if `urls` is empty or every attempt failed. Errors raised
   *   while building the form data (e.g. the image cannot be read) propagate
   *   directly, before any URL is contacted.
   */
  private async predict<T>(urls: string[], payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
    const formData = await this.getFormData(payload, config);
    for (const url of urls) {
      try {
        const response = await fetch(new URL('/predict', url), { method: 'POST', body: formData });
        if (response.ok) {
          // `await` is required here: without it a rejected `json()` promise
          // (malformed or truncated body) would bypass the catch below and
          // abort the failover loop instead of moving on to the next URL.
          return await response.json();
        }

        this.logger.warn(
          `Machine learning request to "${url}" failed with status ${response.status}: ${response.statusText}`,
        );
      } catch (error: unknown) {
        this.logger.warn(
          `Machine learning request to "${url}" failed: ${error instanceof Error ? error.message : error}`,
        );
      }
    }

    throw new Error(`Machine learning request '${JSON.stringify(config)}' failed for all URLs`);
  }

  /**
   * Runs face detection and recognition on an image.
   *
   * The same `modelName` is used for both the detection and the recognition
   * model; `minScore` is passed as an option of the detection model only.
   *
   * @param urls Machine learning base URLs to try, in order.
   * @param imagePath Path of the image file to analyse.
   * @returns The image dimensions reported by the service and the detected faces.
   */
  async detectFaces(urls: string[], imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
    const request = {
      [ModelTask.FACIAL_RECOGNITION]: {
        [ModelType.DETECTION]: { modelName, options: { minScore } },
        [ModelType.RECOGNITION]: { modelName },
      },
    };
    const response = await this.predict<FacialRecognitionResponse>(urls, { imagePath }, request);
    return {
      imageHeight: response.imageHeight,
      imageWidth: response.imageWidth,
      faces: response[ModelTask.FACIAL_RECOGNITION],
    };
  }

  /**
   * Encodes an image with the visual CLIP model.
   *
   * @param urls Machine learning base URLs to try, in order.
   * @param imagePath Path of the image file to encode.
   * @returns The `ModelTask.SEARCH` entry of the service response.
   */
  async encodeImage(urls: string[], imagePath: string, { modelName }: CLIPConfig) {
    const request = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: { modelName } } };
    const response = await this.predict<ClipVisualResponse>(urls, { imagePath }, request);
    return response[ModelTask.SEARCH];
  }

  /**
   * Encodes a text with the textual CLIP model.
   *
   * @param urls Machine learning base URLs to try, in order.
   * @param text Text to encode.
   * @returns The `ModelTask.SEARCH` entry of the service response.
   */
  async encodeText(urls: string[], text: string, { modelName }: CLIPConfig) {
    const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
    const response = await this.predict<ClipTextualResponse>(urls, { text }, request);
    return response[ModelTask.SEARCH];
  }

  /**
   * Builds the multipart body of a prediction request.
   *
   * The body always contains an `entries` field holding the JSON-serialised
   * model configuration, plus either an `image` field (the file contents read
   * from `payload.imagePath`) or a `text` field.
   *
   * @throws Error (`Invalid input`) if the payload has neither `imagePath` nor `text`.
   */
  private async getFormData(payload: ModelPayload, config: MachineLearningRequest): Promise<FormData> {
    const formData = new FormData();
    formData.append('entries', JSON.stringify(config));

    if ('imagePath' in payload) {
      formData.append('image', new Blob([await readFile(payload.imagePath)]));
    } else if ('text' in payload) {
      formData.append('text', payload.text);
    } else {
      throw new Error('Invalid input');
    }

    return formData;
  }
}