feat(server): search across own+partner assets (#5966)
* feat(server): search across own+partner assets * generate sql * fix sql parameter
This commit is contained in:
committed by
GitHub
parent
2688e05033
commit
cc7ba3c21a
@@ -334,7 +334,7 @@ export class PersonService {
|
||||
|
||||
for (const { embedding, ...rest } of faces) {
|
||||
const matches = await this.smartInfoRepository.searchFaces({
|
||||
ownerId: asset.ownerId,
|
||||
userIds: [asset.ownerId],
|
||||
embedding,
|
||||
numResults: 1,
|
||||
maxDistance: machineLearning.facialRecognition.maxDistance,
|
||||
|
||||
@@ -199,5 +199,5 @@ export interface IAssetRepository {
|
||||
search(options: AssetSearchOptions): Promise<AssetEntity[]>;
|
||||
getAssetIdByCity(userId: string, options: AssetExploreFieldOptions): Promise<SearchExploreItem<string>>;
|
||||
getAssetIdByTag(userId: string, options: AssetExploreFieldOptions): Promise<SearchExploreItem<string>>;
|
||||
searchMetadata(query: string, userId: string, options: MetadataSearchOptions): Promise<AssetEntity[]>;
|
||||
searchMetadata(query: string, userIds: string[], options: MetadataSearchOptions): Promise<AssetEntity[]>;
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ export const ISmartInfoRepository = 'ISmartInfoRepository';
|
||||
export type Embedding = number[];
|
||||
|
||||
export interface EmbeddingSearch {
|
||||
ownerId: string;
|
||||
userIds: string[];
|
||||
embedding: Embedding;
|
||||
numResults: number;
|
||||
maxDistance?: number;
|
||||
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
authStub,
|
||||
newAssetRepositoryMock,
|
||||
newMachineLearningRepositoryMock,
|
||||
newPartnerRepositoryMock,
|
||||
newPersonRepositoryMock,
|
||||
newSmartInfoRepositoryMock,
|
||||
newSystemConfigRepositoryMock,
|
||||
@@ -13,6 +14,7 @@ import { mapAsset } from '../asset';
|
||||
import {
|
||||
IAssetRepository,
|
||||
IMachineLearningRepository,
|
||||
IPartnerRepository,
|
||||
IPersonRepository,
|
||||
ISmartInfoRepository,
|
||||
ISystemConfigRepository,
|
||||
@@ -29,6 +31,7 @@ describe(SearchService.name, () => {
|
||||
let machineMock: jest.Mocked<IMachineLearningRepository>;
|
||||
let personMock: jest.Mocked<IPersonRepository>;
|
||||
let smartInfoMock: jest.Mocked<ISmartInfoRepository>;
|
||||
let partnerMock: jest.Mocked<IPartnerRepository>;
|
||||
|
||||
beforeEach(() => {
|
||||
assetMock = newAssetRepositoryMock();
|
||||
@@ -36,7 +39,8 @@ describe(SearchService.name, () => {
|
||||
machineMock = newMachineLearningRepositoryMock();
|
||||
personMock = newPersonRepositoryMock();
|
||||
smartInfoMock = newSmartInfoRepositoryMock();
|
||||
sut = new SearchService(configMock, machineMock, personMock, smartInfoMock, assetMock);
|
||||
partnerMock = newPartnerRepositoryMock();
|
||||
sut = new SearchService(configMock, machineMock, personMock, smartInfoMock, assetMock, partnerMock);
|
||||
});
|
||||
|
||||
it('should work', () => {
|
||||
@@ -87,6 +91,7 @@ describe(SearchService.name, () => {
|
||||
it('should search by metadata if `clip` option is false', async () => {
|
||||
const dto: SearchDto = { q: 'test query', clip: false };
|
||||
assetMock.searchMetadata.mockResolvedValueOnce([assetStub.image]);
|
||||
partnerMock.getAll.mockResolvedValueOnce([]);
|
||||
const expectedResponse = {
|
||||
albums: {
|
||||
total: 0,
|
||||
@@ -105,7 +110,7 @@ describe(SearchService.name, () => {
|
||||
const result = await sut.search(authStub.user1, dto);
|
||||
|
||||
expect(result).toEqual(expectedResponse);
|
||||
expect(assetMock.searchMetadata).toHaveBeenCalledWith(dto.q, authStub.user1.user.id, { numResults: 250 });
|
||||
expect(assetMock.searchMetadata).toHaveBeenCalledWith(dto.q, [authStub.user1.user.id], { numResults: 250 });
|
||||
expect(smartInfoMock.searchCLIP).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
@@ -114,6 +119,7 @@ describe(SearchService.name, () => {
|
||||
const embedding = [1, 2, 3];
|
||||
smartInfoMock.searchCLIP.mockResolvedValueOnce([assetStub.image]);
|
||||
machineMock.encodeText.mockResolvedValueOnce(embedding);
|
||||
partnerMock.getAll.mockResolvedValueOnce([]);
|
||||
const expectedResponse = {
|
||||
albums: {
|
||||
total: 0,
|
||||
@@ -133,7 +139,7 @@ describe(SearchService.name, () => {
|
||||
|
||||
expect(result).toEqual(expectedResponse);
|
||||
expect(smartInfoMock.searchCLIP).toHaveBeenCalledWith({
|
||||
ownerId: authStub.user1.user.id,
|
||||
userIds: [authStub.user1.user.id],
|
||||
embedding,
|
||||
numResults: 100,
|
||||
});
|
||||
|
||||
@@ -7,6 +7,7 @@ import { PersonResponseDto } from '../person';
|
||||
import {
|
||||
IAssetRepository,
|
||||
IMachineLearningRepository,
|
||||
IPartnerRepository,
|
||||
IPersonRepository,
|
||||
ISmartInfoRepository,
|
||||
ISystemConfigRepository,
|
||||
@@ -28,6 +29,7 @@ export class SearchService {
|
||||
@Inject(IPersonRepository) private personRepository: IPersonRepository,
|
||||
@Inject(ISmartInfoRepository) private smartInfoRepository: ISmartInfoRepository,
|
||||
@Inject(IAssetRepository) private assetRepository: IAssetRepository,
|
||||
@Inject(IPartnerRepository) private partnerRepository: IPartnerRepository,
|
||||
) {
|
||||
this.configCore = SystemConfigCore.create(configRepository);
|
||||
}
|
||||
@@ -64,6 +66,7 @@ export class SearchService {
|
||||
throw new Error('CLIP is not enabled');
|
||||
}
|
||||
const strategy = dto.clip ? SearchStrategy.CLIP : SearchStrategy.TEXT;
|
||||
const userIds = await this.getUserIdsToSearch(auth);
|
||||
|
||||
let assets: AssetEntity[] = [];
|
||||
|
||||
@@ -74,10 +77,10 @@ export class SearchService {
|
||||
{ text: query },
|
||||
machineLearning.clip,
|
||||
);
|
||||
assets = await this.smartInfoRepository.searchCLIP({ ownerId: auth.user.id, embedding, numResults: 100 });
|
||||
assets = await this.smartInfoRepository.searchCLIP({ userIds: userIds, embedding, numResults: 100 });
|
||||
break;
|
||||
case SearchStrategy.TEXT:
|
||||
assets = await this.assetRepository.searchMetadata(query, auth.user.id, { numResults: 250 });
|
||||
assets = await this.assetRepository.searchMetadata(query, userIds, { numResults: 250 });
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -97,4 +100,14 @@ export class SearchService {
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
private async getUserIdsToSearch(auth: AuthDto): Promise<string[]> {
|
||||
const userIds: string[] = [auth.user.id];
|
||||
const partners = await this.partnerRepository.getAll(auth.user.id);
|
||||
const partnersIds = partners
|
||||
.filter((partner) => partner.sharedBy && partner.inTimeline)
|
||||
.map((partner) => partner.sharedById);
|
||||
userIds.push(...partnersIds);
|
||||
return userIds;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user