feat(ml): composable ml (#9973)

* modularize model classes

* various fixes

* expose port

* change response

* round coordinates

* simplify preload

* update server

* simplify interface

simplify

* update tests

* composable endpoint

* cleanup

fixes

remove unnecessary interface

support text input, cleanup

* ew camelcase

* update server

server fixes

fix typing

* ml fixes

update locustfile

fixes

* cleaner response

* better repo response

* update tests

formatting and typing

rename

* undo compose change

* linting

fix type

actually fix typing

* stricter typing

fix detection-only response

no need for defaultdict

* update spec file

update api

linting

* update e2e

* unnecessary dimension

* remove commented code

* remove duplicate code

* remove unused imports

* add batch dim
This commit is contained in:
Mert
2024-06-06 23:09:47 -04:00
committed by GitHub
parent 7a46f80ddc
commit 2b1b43a7e4
39 changed files with 982 additions and 999 deletions

View File

@@ -7,7 +7,7 @@ import { IAssetRepository, WithoutProperty } from 'src/interfaces/asset.interfac
import { ICryptoRepository } from 'src/interfaces/crypto.interface';
import { IJobRepository, JobName, JobStatus } from 'src/interfaces/job.interface';
import { ILoggerRepository } from 'src/interfaces/logger.interface';
import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
import { DetectedFaces, IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
import { IMediaRepository } from 'src/interfaces/media.interface';
import { IMoveRepository } from 'src/interfaces/move.interface';
import { IPersonRepository } from 'src/interfaces/person.interface';
@@ -46,19 +46,21 @@ const responseDto: PersonResponseDto = {
const statistics = { assets: 3 };
const detectFaceMock = {
assetId: 'asset-1',
personId: 'person-1',
boundingBox: {
x1: 100,
y1: 100,
x2: 200,
y2: 200,
},
const detectFaceMock: DetectedFaces = {
faces: [
{
boundingBox: {
x1: 100,
y1: 100,
x2: 200,
y2: 200,
},
embedding: [1, 2, 3, 4],
score: 0.2,
},
],
imageHeight: 500,
imageWidth: 400,
embedding: [1, 2, 3, 4],
score: 0.2,
};
describe(PersonService.name, () => {
@@ -642,21 +644,13 @@ describe(PersonService.name, () => {
it('should handle no results', async () => {
const start = Date.now();
machineLearningMock.detectFaces.mockResolvedValue([]);
machineLearningMock.detectFaces.mockResolvedValue({ imageHeight: 500, imageWidth: 400, faces: [] });
assetMock.getByIds.mockResolvedValue([assetStub.image]);
await sut.handleDetectFaces({ id: assetStub.image.id });
expect(machineLearningMock.detectFaces).toHaveBeenCalledWith(
'http://immich-machine-learning:3003',
{
imagePath: assetStub.image.previewPath,
},
{
enabled: true,
maxDistance: 0.5,
minScore: 0.7,
minFaces: 3,
modelName: 'buffalo_l',
},
assetStub.image.previewPath,
expect.objectContaining({ minScore: 0.7, modelName: 'buffalo_l' }),
);
expect(personMock.createFaces).not.toHaveBeenCalled();
expect(jobMock.queue).not.toHaveBeenCalled();
@@ -671,7 +665,7 @@ describe(PersonService.name, () => {
it('should create a face with no person and queue recognition job', async () => {
personMock.createFaces.mockResolvedValue([faceStub.face1.id]);
machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]);
machineLearningMock.detectFaces.mockResolvedValue(detectFaceMock);
searchMock.searchFaces.mockResolvedValue([{ face: faceStub.face1, distance: 0.7 }]);
assetMock.getByIds.mockResolvedValue([assetStub.image]);
const face = {

View File

@@ -333,26 +333,28 @@ export class PersonService {
return JobStatus.SKIPPED;
}
const faces = await this.machineLearningRepository.detectFaces(
if (!asset.isVisible) {
return JobStatus.SKIPPED;
}
const { imageHeight, imageWidth, faces } = await this.machineLearningRepository.detectFaces(
machineLearning.url,
{ imagePath: asset.previewPath },
asset.previewPath,
machineLearning.facialRecognition,
);
this.logger.debug(`${faces.length} faces detected in ${asset.previewPath}`);
this.logger.verbose(faces.map((face) => ({ ...face, embedding: `vector(${face.embedding.length})` })));
if (faces.length > 0) {
await this.jobRepository.queue({ name: JobName.QUEUE_FACIAL_RECOGNITION, data: { force: false } });
const mappedFaces = faces.map((face) => ({
assetId: asset.id,
embedding: face.embedding,
imageHeight: face.imageHeight,
imageWidth: face.imageWidth,
imageHeight,
imageWidth,
boundingBoxX1: face.boundingBox.x1,
boundingBoxX2: face.boundingBox.x2,
boundingBoxY1: face.boundingBox.y1,
boundingBoxX2: face.boundingBox.x2,
boundingBoxY2: face.boundingBox.y2,
}));

View File

@@ -102,12 +102,7 @@ export class SearchService {
const userIds = await this.getUserIdsToSearch(auth);
const embedding = await this.machineLearning.encodeText(
machineLearning.url,
{ text: dto.query },
machineLearning.clip,
);
const embedding = await this.machineLearning.encodeText(machineLearning.url, dto.query, machineLearning.clip);
const page = dto.page ?? 1;
const size = dto.size || 100;
const { hasNextPage, items } = await this.searchRepository.searchSmart(

View File

@@ -108,8 +108,8 @@ describe(SmartInfoService.name, () => {
expect(machineMock.encodeImage).toHaveBeenCalledWith(
'http://immich-machine-learning:3003',
{ imagePath: assetStub.image.previewPath },
{ enabled: true, modelName: 'ViT-B-32__openai' },
assetStub.image.previewPath,
expect.objectContaining({ modelName: 'ViT-B-32__openai' }),
);
expect(searchMock.upsert).toHaveBeenCalledWith(assetStub.image.id, [0.01, 0.02, 0.03]);
});

View File

@@ -93,9 +93,9 @@ export class SmartInfoService {
return JobStatus.FAILED;
}
const clipEmbedding = await this.machineLearning.encodeImage(
const embedding = await this.machineLearning.encodeImage(
machineLearning.url,
{ imagePath: asset.previewPath },
asset.previewPath,
machineLearning.clip,
);
@@ -104,7 +104,7 @@ export class SmartInfoService {
await this.databaseRepository.wait(DatabaseLock.CLIPDimSize);
}
await this.repository.upsert(asset.id, clipEmbedding);
await this.repository.upsert(asset.id, embedding);
return JobStatus.SUCCESS;
}