feat(ml): composable ml (#9973)
* modularize model classes
* various fixes
* expose port
* change response
* round coordinates
* simplify preload
* update server
* simplify interface
* update tests
* composable endpoint
* cleanup fixes; remove unnecessary interface; support text input
* ew, camelCase
* update server; server fixes; fix typing
* ml fixes; update locustfile
* cleaner response
* better repo response
* update tests; formatting and typing; rename
* undo compose change
* fix linting; actually fix typing
* stricter typing; fix detection-only response; no need for defaultdict
* update spec file; update api; fix linting
* update e2e
* remove unnecessary dimension
* remove commented code
* remove duplicate code
* remove unused imports
* add batch dim
server/src/services/person.service.spec.ts
@@ -7,7 +7,7 @@ import { IAssetRepository, WithoutProperty } from 'src/interfaces/asset.interface';
 import { ICryptoRepository } from 'src/interfaces/crypto.interface';
 import { IJobRepository, JobName, JobStatus } from 'src/interfaces/job.interface';
 import { ILoggerRepository } from 'src/interfaces/logger.interface';
-import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
+import { DetectedFaces, IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
 import { IMediaRepository } from 'src/interfaces/media.interface';
 import { IMoveRepository } from 'src/interfaces/move.interface';
 import { IPersonRepository } from 'src/interfaces/person.interface';
@@ -46,19 +46,21 @@ const responseDto: PersonResponseDto = {
 
 const statistics = { assets: 3 };
 
-const detectFaceMock = {
-  assetId: 'asset-1',
-  personId: 'person-1',
-  boundingBox: {
-    x1: 100,
-    y1: 100,
-    x2: 200,
-    y2: 200,
-  },
-  embedding: [1, 2, 3, 4],
-  score: 0.2,
+const detectFaceMock: DetectedFaces = {
+  faces: [
+    {
+      boundingBox: {
+        x1: 100,
+        y1: 100,
+        x2: 200,
+        y2: 200,
+      },
+      embedding: [1, 2, 3, 4],
+      score: 0.2,
+    },
+  ],
+  imageHeight: 500,
+  imageWidth: 400,
 };
 
 describe(PersonService.name, () => {
@@ -642,21 +644,13 @@ describe(PersonService.name, () => {
    it('should handle no results', async () => {
      const start = Date.now();
 
-      machineLearningMock.detectFaces.mockResolvedValue([]);
+      machineLearningMock.detectFaces.mockResolvedValue({ imageHeight: 500, imageWidth: 400, faces: [] });
      assetMock.getByIds.mockResolvedValue([assetStub.image]);
      await sut.handleDetectFaces({ id: assetStub.image.id });
      expect(machineLearningMock.detectFaces).toHaveBeenCalledWith(
        'http://immich-machine-learning:3003',
-        {
-          imagePath: assetStub.image.previewPath,
-        },
-        {
-          enabled: true,
-          maxDistance: 0.5,
-          minScore: 0.7,
-          minFaces: 3,
-          modelName: 'buffalo_l',
-        },
+        assetStub.image.previewPath,
+        expect.objectContaining({ minScore: 0.7, modelName: 'buffalo_l' }),
      );
      expect(personMock.createFaces).not.toHaveBeenCalled();
      expect(jobMock.queue).not.toHaveBeenCalled();
@@ -671,7 +665,7 @@ describe(PersonService.name, () => {
 
    it('should create a face with no person and queue recognition job', async () => {
      personMock.createFaces.mockResolvedValue([faceStub.face1.id]);
-      machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]);
+      machineLearningMock.detectFaces.mockResolvedValue(detectFaceMock);
      searchMock.searchFaces.mockResolvedValue([{ face: faceStub.face1, distance: 0.7 }]);
      assetMock.getByIds.mockResolvedValue([assetStub.image]);
      const face = {
server/src/services/person.service.ts
@@ -333,26 +333,28 @@ export class PersonService {
      return JobStatus.SKIPPED;
    }
 
-    const faces = await this.machineLearningRepository.detectFaces(
+    if (!asset.isVisible) {
+      return JobStatus.SKIPPED;
+    }
+
+    const { imageHeight, imageWidth, faces } = await this.machineLearningRepository.detectFaces(
      machineLearning.url,
-      { imagePath: asset.previewPath },
+      asset.previewPath,
      machineLearning.facialRecognition,
    );
 
    this.logger.debug(`${faces.length} faces detected in ${asset.previewPath}`);
    this.logger.verbose(faces.map((face) => ({ ...face, embedding: `vector(${face.embedding.length})` })));
 
    if (faces.length > 0) {
      await this.jobRepository.queue({ name: JobName.QUEUE_FACIAL_RECOGNITION, data: { force: false } });
 
      const mappedFaces = faces.map((face) => ({
        assetId: asset.id,
        embedding: face.embedding,
-        imageHeight: face.imageHeight,
-        imageWidth: face.imageWidth,
+        imageHeight,
+        imageWidth,
        boundingBoxX1: face.boundingBox.x1,
-        boundingBoxX2: face.boundingBox.x2,
        boundingBoxY1: face.boundingBox.y1,
+        boundingBoxX2: face.boundingBox.x2,
        boundingBoxY2: face.boundingBox.y2,
      }));
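Stripped of diff markers, the '+' and context lines above compose into roughly this flow in PersonService.handleDetectFaces; this is a condensed reading of the hunk, not a verbatim copy, and the code before and after the hunk is elided:

    // Condensed from the hunk above; surrounding code elided.
    if (!asset.isVisible) {
      return JobStatus.SKIPPED;
    }

    // One set of image dimensions per response, shared by every face.
    const { imageHeight, imageWidth, faces } = await this.machineLearningRepository.detectFaces(
      machineLearning.url,
      asset.previewPath,
      machineLearning.facialRecognition,
    );

    this.logger.debug(`${faces.length} faces detected in ${asset.previewPath}`);

    if (faces.length > 0) {
      await this.jobRepository.queue({ name: JobName.QUEUE_FACIAL_RECOGNITION, data: { force: false } });

      const mappedFaces = faces.map((face) => ({
        assetId: asset.id,
        embedding: face.embedding,
        imageHeight,
        imageWidth,
        boundingBoxX1: face.boundingBox.x1,
        boundingBoxY1: face.boundingBox.y1,
        boundingBoxX2: face.boundingBox.x2,
        boundingBoxY2: face.boundingBox.y2,
      }));
      // ... persistence of mappedFaces continues past the end of the hunk.
    }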
server/src/services/search.service.ts
@@ -102,12 +102,7 @@ export class SearchService {
    const userIds = await this.getUserIdsToSearch(auth);
 
-    const embedding = await this.machineLearning.encodeText(
-      machineLearning.url,
-      { text: dto.query },
-      machineLearning.clip,
-    );
-
+    const embedding = await this.machineLearning.encodeText(machineLearning.url, dto.query, machineLearning.clip);
    const page = dto.page ?? 1;
    const size = dto.size || 100;
    const { hasNextPage, items } = await this.searchRepository.searchSmart(
server/src/services/smart-info.service.spec.ts
@@ -108,8 +108,8 @@ describe(SmartInfoService.name, () => {
      expect(machineMock.encodeImage).toHaveBeenCalledWith(
        'http://immich-machine-learning:3003',
-        { imagePath: assetStub.image.previewPath },
-        { enabled: true, modelName: 'ViT-B-32__openai' },
+        assetStub.image.previewPath,
+        expect.objectContaining({ modelName: 'ViT-B-32__openai' }),
      );
      expect(searchMock.upsert).toHaveBeenCalledWith(assetStub.image.id, [0.01, 0.02, 0.03]);
    });
server/src/services/smart-info.service.ts
@@ -93,9 +93,9 @@ export class SmartInfoService {
      return JobStatus.FAILED;
    }
 
-    const clipEmbedding = await this.machineLearning.encodeImage(
+    const embedding = await this.machineLearning.encodeImage(
      machineLearning.url,
-      { imagePath: asset.previewPath },
+      asset.previewPath,
      machineLearning.clip,
    );
@@ -104,7 +104,7 @@ export class SmartInfoService {
      await this.databaseRepository.wait(DatabaseLock.CLIPDimSize);
    }
 
-    await this.repository.upsert(asset.id, clipEmbedding);
+    await this.repository.upsert(asset.id, embedding);
 
    return JobStatus.SUCCESS;
  }
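Read together, the two SmartInfoService hunks leave the smart-search indexing path looking roughly like this; this is condensed from the '+' and context lines above, not verbatim, and the CLIP dimension-size handling between the hunks is elided:

    // Condensed from the two SmartInfoService hunks above.
    const embedding = await this.machineLearning.encodeImage(
      machineLearning.url,
      asset.previewPath, // was: { imagePath: asset.previewPath }
      machineLearning.clip,
    );

    // ... DatabaseLock.CLIPDimSize handling elided ...

    await this.repository.upsert(asset.id, embedding); // was: clipEmbedding
    return JobStatus.SUCCESS;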