Files
immich/server/src/repositories/search.repository.ts
T
Thomas 3ce353393a chore(server): don't insert embeddings if the model has changed (#17885)
* chore(server): don't insert embeddings if the model has changed

We're moving away from the heuristic of waiting for queues to complete. The job
which inserts embeddings can simply check if the model has changed before
inserting, rather than attempting to lock the queue.

* more robust dim size update

* use check constraint

* index command cleanup

* add create statement

* update medium test, create appropriate extension

* new line

* set dimension size when running on all assets

* why does it want braces smh

* take 2

---------

Co-authored-by: mertalev <101130780+mertalev@users.noreply.github.com>
2025-04-29 14:23:01 -04:00

562 lines
17 KiB
TypeScript

import { Injectable } from '@nestjs/common';
import { Kysely, OrderByDirection, Selectable, sql } from 'kysely';
import { InjectKysely } from 'nestjs-kysely';
import { randomUUID } from 'node:crypto';
import { DB, Exif } from 'src/db';
import { DummyValue, GenerateSql } from 'src/decorators';
import { MapAsset } from 'src/dtos/asset-response.dto';
import { AssetStatus, AssetType } from 'src/enum';
import { ConfigRepository } from 'src/repositories/config.repository';
import { anyUuid, asUuid, searchAssetBuilder, vectorIndexQuery } from 'src/utils/database';
import { isValidInteger } from 'src/validation';
/** Shape of a paged search response whose items are of type `T`. */
export interface SearchResult<T> {
/** total matches */
total: number;
/** collection size */
count: number;
/** current page */
page: number;
/** items for page */
items: T[];
/** score — NOTE(review): presumably one entry per element of `items`; confirm with the producer */
distances: number[];
/** facet breakdowns accompanying the result (see `SearchFacet`) */
facets: SearchFacet[];
}
/** Per-field value counts attached to a `SearchResult`. */
export interface SearchFacet {
/** name of the field the counts belong to */
fieldName: string;
/** one entry per distinct value, with the number of times it occurred */
counts: Array<{
count: number;
value: string;
}>;
}
/** List of `{ value, data }` pairs, where `data` is a payload of type `T` associated with `value`. */
export type SearchExploreItemSet<T> = Array<{
value: string;
data: T;
}>;
/** A named field together with the `SearchExploreItemSet` collected for it. */
export interface SearchExploreItem<T> {
fieldName: string;
items: SearchExploreItemSet<T>;
}
/** Optional identifiers for pinpointing an asset; all fields may be omitted. */
export interface SearchAssetIDOptions {
checksum?: Buffer;
deviceAssetId?: string;
id?: string;
}
/** Optional device / library / user scoping for a search. */
export interface SearchUserIdOptions {
deviceId?: string;
// NOTE(review): `null` is accepted as well as a string; how `null` is interpreted is decided by the consumer — confirm there
libraryId?: string | null;
userIds?: string[];
}
/** Asset identifiers combined with device / library / user scoping. */
export type SearchIdOptions = SearchAssetIDOptions & SearchUserIdOptions;
/**
 * Optional asset state filters. They form part of `AssetSearchOptions` and `SmartSearchOptions`,
 * which this file hands to `searchAssetBuilder`; the exact meaning of each flag is defined there.
 */
export interface SearchStatusOptions {
isArchived?: boolean;
isEncoded?: boolean;
isFavorite?: boolean;
isMotion?: boolean;
isOffline?: boolean;
isVisible?: boolean;
isNotInAlbum?: boolean;
type?: AssetType;
status?: AssetStatus;
withArchived?: boolean;
withDeleted?: boolean;
}
/** Flags requesting one-to-one related data (exif, stack) alongside each asset. */
export interface SearchOneToOneRelationOptions {
withExif?: boolean;
withStacked?: boolean;
}
/** Extends the one-to-one relation flags with flags for faces and people. */
export interface SearchRelationOptions extends SearchOneToOneRelationOptions {
withFaces?: boolean;
withPeople?: boolean;
}
/**
 * Optional before/after bounds for four timestamps (created, taken, trashed, updated).
 * NOTE(review): whether the bounds are inclusive is decided by the consuming query builder — confirm there.
 */
export interface SearchDateOptions {
createdBefore?: Date;
createdAfter?: Date;
takenBefore?: Date;
takenAfter?: Date;
trashedBefore?: Date;
trashedAfter?: Date;
updatedBefore?: Date;
updatedAfter?: Date;
}
/** Optional path and file-name filters. */
export interface SearchPathOptions {
encodedVideoPath?: string;
originalFileName?: string;
originalPath?: string;
previewPath?: string;
thumbnailPath?: string;
}
/**
 * Optional exif-based filters.
 * NOTE(review): every field also accepts `null`; presumably that selects assets where the
 * field is unset — confirm in the consuming query builder.
 */
export interface SearchExifOptions {
city?: string | null;
country?: string | null;
lensModel?: string | null;
make?: string | null;
model?: string | null;
state?: string | null;
description?: string | null;
rating?: number | null;
}
/** Inputs shared by embedding-based searches. */
export interface SearchEmbeddingOptions {
/** embedding in string form; the repository passes it as a query parameter to the `<=>` operator */
embedding: string;
/** owners whose assets are searched */
userIds: string[];
}
/** Optional filter by person ids. */
export interface SearchPeopleOptions {
personIds?: string[];
}
/** Optional filter by tag ids. */
export interface SearchTagOptions {
tagIds?: string[];
}
/** Sort direction for ordered searches; `searchMetadata` falls back to `'desc'` when it is not set. */
export interface SearchOrderOptions {
orderDirection?: 'asc' | 'desc';
}
/** Page selection for paged searches. */
export interface SearchPaginationOptions {
/** 1-based page number: the repository computes the row offset as `(page - 1) * size` */
page: number;
/** number of items per page */
size: number;
}
/**
 * Every asset filter group except the relation flags.
 * `SearchUserIdOptions` is listed directly and is also part of `SearchIdOptions`;
 * intersecting it twice does not change the resulting type.
 */
type BaseAssetSearchOptions = SearchDateOptions &
SearchIdOptions &
SearchExifOptions &
SearchOrderOptions &
SearchPathOptions &
SearchStatusOptions &
SearchUserIdOptions &
SearchPeopleOptions &
SearchTagOptions;
/** Base filters plus all relation flags; accepted by `searchMetadata` and `searchRandom`. */
export type AssetSearchOptions = BaseAssetSearchOptions & SearchRelationOptions;
/** Base filters plus only the one-to-one relation flags. */
export type AssetSearchOneToOneRelationOptions = BaseAssetSearchOptions & SearchOneToOneRelationOptions;
/** `AssetSearchOptions` without `orderDirection`. */
export type AssetSearchBuilderOptions = Omit<AssetSearchOptions, 'orderDirection'>;
/** Options accepted by `searchSmart`: an embedding plus date, exif, status, user, people and tag filters. */
export type SmartSearchOptions = SearchDateOptions &
SearchEmbeddingOptions &
SearchExifOptions &
SearchOneToOneRelationOptions &
SearchStatusOptions &
SearchUserIdOptions &
SearchPeopleOptions &
SearchTagOptions;
/** Input for `searchFaces`. */
export interface FaceEmbeddingSearch extends SearchEmbeddingOptions {
/** when truthy, only faces with a non-null `personId` are considered */
hasPerson?: boolean;
/** number of nearest candidates fetched; `searchFaces` requires an integer in [1, 1000] */
numResults: number;
/** candidates with a distance greater than this are dropped */
maxDistance: number;
/** when set, keeps faces whose person has no birth date or a birth date on or before this value */
minBirthDate?: Date | null;
}
/** Input for `searchDuplicates`. */
export interface AssetDuplicateSearch {
/** asset being checked; it is excluded from its own results */
assetId: string;
/** embedding that candidates are compared against */
embedding: string;
/** candidates with a distance greater than this are dropped */
maxDistance: number;
/** only assets of this type are considered */
type: AssetType;
/** owners whose assets are considered */
userIds: string[];
}
/** Face match row: the face id, its person (if any) and the distance to the query embedding. */
export interface FaceSearchResult {
distance: number;
id: string;
personId: string | null;
}
/** Duplicate candidate row: the asset id, its current `duplicateId` (if any) and the distance to the query embedding. */
export interface AssetDuplicateResult {
assetId: string;
duplicateId: string | null;
distance: number;
}
/** Optional narrowing for `getStates`. */
export interface GetStatesOptions {
/** when non-empty, only rows with this exact country are considered */
country?: string;
}
/** Optional narrowing for `getCities`; inherits `country`. */
export interface GetCitiesOptions extends GetStatesOptions {
/** when non-empty, only rows with this exact state are considered */
state?: string;
}
/** Optional narrowing for `getCameraModels`. */
export interface GetCameraModelsOptions {
/** when non-empty, only rows with this exact make are considered */
make?: string;
}
/** Optional narrowing for `getCameraMakes`. */
export interface GetCameraMakesOptions {
/** when non-empty, only rows with this exact model are considered */
model?: string;
}
/**
 * Database access for asset, face and place searches, exif field lookups, and
 * maintenance of the `smart_search` embedding table.
 */
@Injectable()
export class SearchRepository {
  constructor(
    @InjectKysely() private db: Kysely<DB>,
    private configRepository: ConfigRepository,
  ) {}

  /**
   * Returns one page of assets matching `options`, ordered by `assets.fileCreatedAt`.
   *
   * @param pagination 1-based page number and page size.
   * @param options filters handed to `searchAssetBuilder`; `orderDirection` defaults to `'desc'`.
   * @returns the page's items and whether at least one more row exists after them.
   */
  @GenerateSql({
    params: [
      { page: 1, size: 100 },
      {
        takenAfter: DummyValue.DATE,
        lensModel: DummyValue.STRING,
        withStacked: true,
        isFavorite: true,
        userIds: [DummyValue.UUID],
      },
    ],
  })
  async searchMetadata(pagination: SearchPaginationOptions, options: AssetSearchOptions) {
    const orderDirection = (options.orderDirection?.toLowerCase() || 'desc') as OrderByDirection;
    // One row more than the page size is requested so the presence of a next page can be detected.
    const items = await searchAssetBuilder(this.db, options)
      .orderBy('assets.fileCreatedAt', orderDirection)
      .limit(pagination.size + 1)
      .offset((pagination.page - 1) * pagination.size)
      .execute();
    const hasNextPage = items.length > pagination.size;
    // Drop the look-ahead row (no-op when fewer rows came back).
    items.splice(pagination.size);
    return { items, hasNextPage };
  }

  /**
   * Returns up to `size` assets matching `options`, chosen around a random UUID pivot.
   *
   * Two sub-queries select randomly ordered assets with an id below and above the pivot
   * (each capped at `size`); they are combined with `union all` and capped at `size` again.
   *
   * @param size maximum number of rows returned.
   * @param options filters handed to `searchAssetBuilder`.
   */
  @GenerateSql({
    params: [
      100,
      {
        takenAfter: DummyValue.DATE,
        lensModel: DummyValue.STRING,
        withStacked: true,
        isFavorite: true,
        userIds: [DummyValue.UUID],
      },
    ],
  })
  async searchRandom(size: number, options: AssetSearchOptions) {
    const uuid = randomUUID();
    const builder = searchAssetBuilder(this.db, options);
    const lessThan = builder
      .where('assets.id', '<', uuid)
      .orderBy(sql`random()`)
      .limit(size);
    const greaterThan = builder
      .where('assets.id', '>', uuid)
      .orderBy(sql`random()`)
      .limit(size);
    const { rows } = await sql<MapAsset>`${lessThan} union all ${greaterThan} limit ${size}`.execute(this.db);
    return rows;
  }

  /**
   * Returns one page of assets matching `options`, ordered by the `<=>` distance between each
   * asset's `smart_search.embedding` and `options.embedding` (closest first).
   *
   * @param pagination 1-based page number and page size; `size` must be an integer in [1, 1000].
   * @param options embedding plus filters handed to `searchAssetBuilder`.
   * @returns the page's items and whether at least one more row exists after them.
   * @throws Error when `pagination.size` is outside the allowed range.
   */
  @GenerateSql({
    params: [
      { page: 1, size: 200 },
      {
        takenAfter: DummyValue.DATE,
        embedding: DummyValue.VECTOR,
        lensModel: DummyValue.STRING,
        withStacked: true,
        isFavorite: true,
        userIds: [DummyValue.UUID],
      },
    ],
  })
  async searchSmart(pagination: SearchPaginationOptions, options: SmartSearchOptions) {
    if (!isValidInteger(pagination.size, { min: 1, max: 1000 })) {
      throw new Error(`Invalid value for 'size': ${pagination.size}`);
    }
    // One row more than the page size is requested so the presence of a next page can be detected.
    const items = await searchAssetBuilder(this.db, options)
      .innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
      .orderBy(sql`smart_search.embedding <=> ${options.embedding}`)
      .limit(pagination.size + 1)
      .offset((pagination.page - 1) * pagination.size)
      .execute();
    const hasNextPage = items.length > pagination.size;
    items.splice(pagination.size);
    return { items, hasNextPage };
  }

  /**
   * Finds assets whose embedding is within `maxDistance` of `embedding`.
   *
   * The inner query takes the 64 closest visible, non-deleted, unstacked assets of the given
   * `type` for the given owners (excluding `assetId` itself); the outer query then keeps only
   * those whose distance is at most `maxDistance`.
   */
  @GenerateSql({
    params: [
      {
        assetId: DummyValue.UUID,
        embedding: DummyValue.VECTOR,
        maxDistance: 0.6,
        type: AssetType.IMAGE,
        userIds: [DummyValue.UUID],
      },
    ],
  })
  searchDuplicates({ assetId, embedding, maxDistance, type, userIds }: AssetDuplicateSearch) {
    return this.db
      .with('cte', (qb) =>
        qb
          .selectFrom('assets')
          .select([
            'assets.id as assetId',
            'assets.duplicateId',
            sql<number>`smart_search.embedding <=> ${embedding}`.as('distance'),
          ])
          .innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
          .where('assets.ownerId', '=', anyUuid(userIds))
          .where('assets.deletedAt', 'is', null)
          .where('assets.isVisible', '=', true)
          .where('assets.type', '=', type)
          .where('assets.id', '!=', asUuid(assetId))
          .where('assets.stackId', 'is', null)
          .orderBy(sql`smart_search.embedding <=> ${embedding}`)
          .limit(64),
      )
      .selectFrom('cte')
      .selectAll()
      .where('cte.distance', '<=', maxDistance as number)
      .execute();
  }

  /**
   * Finds faces whose embedding is within `maxDistance` of `embedding`.
   *
   * The inner query takes the `numResults` closest faces on non-deleted assets of the given
   * owners, optionally restricted to faces with a person (`hasPerson`) and to persons whose
   * birth date is unset or on/before `minBirthDate`; the outer query applies `maxDistance`.
   *
   * @throws Error when `numResults` is not an integer in [1, 1000].
   */
  @GenerateSql({
    params: [
      {
        userIds: [DummyValue.UUID],
        embedding: DummyValue.VECTOR,
        numResults: 10,
        maxDistance: 0.6,
      },
    ],
  })
  searchFaces({ userIds, embedding, numResults, maxDistance, hasPerson, minBirthDate }: FaceEmbeddingSearch) {
    if (!isValidInteger(numResults, { min: 1, max: 1000 })) {
      throw new Error(`Invalid value for 'numResults': ${numResults}`);
    }
    return this.db
      .with('cte', (qb) =>
        qb
          .selectFrom('asset_faces')
          .select([
            'asset_faces.id',
            'asset_faces.personId',
            sql<number>`face_search.embedding <=> ${embedding}`.as('distance'),
          ])
          .innerJoin('assets', 'assets.id', 'asset_faces.assetId')
          .innerJoin('face_search', 'face_search.faceId', 'asset_faces.id')
          .leftJoin('person', 'person.id', 'asset_faces.personId')
          .where('assets.ownerId', '=', anyUuid(userIds))
          .where('assets.deletedAt', 'is', null)
          .$if(!!hasPerson, (qb) => qb.where('asset_faces.personId', 'is not', null))
          .$if(!!minBirthDate, (qb) =>
            qb.where((eb) => eb.or([eb('person.birthDate', 'is', null), eb('person.birthDate', '<=', minBirthDate!)])),
          )
          .orderBy(sql`face_search.embedding <=> ${embedding}`)
          .limit(numResults),
      )
      .selectFrom('cte')
      .selectAll()
      .where('cte.distance', '<=', maxDistance)
      .execute();
  }

  /**
   * Returns up to 20 `geodata_places` rows whose name, admin names or alternate names match
   * `placeName`, ordered by the sum of the per-column distances (a missing column counts as 0.1).
   */
  @GenerateSql({ params: [DummyValue.STRING] })
  searchPlaces(placeName: string) {
    return this.db
      .selectFrom('geodata_places')
      .selectAll()
      .where(
        () =>
          // kysely doesn't support trigram %>> or <->>> operators
          sql`
            f_unaccent(name) %>> f_unaccent(${placeName}) or
            f_unaccent("admin2Name") %>> f_unaccent(${placeName}) or
            f_unaccent("admin1Name") %>> f_unaccent(${placeName}) or
            f_unaccent("alternateNames") %>> f_unaccent(${placeName})
          `,
      )
      .orderBy(
        sql`
          coalesce(f_unaccent(name) <->>> f_unaccent(${placeName}), 0.1) +
          coalesce(f_unaccent("admin2Name") <->>> f_unaccent(${placeName}), 0.1) +
          coalesce(f_unaccent("admin1Name") <->>> f_unaccent(${placeName}), 0.1) +
          coalesce(f_unaccent("alternateNames") <->>> f_unaccent(${placeName}), 0.1)
        `,
      )
      .limit(20)
      .execute();
  }

  /**
   * Returns one asset (with its exif row as `exifInfo`) per city, ordered by city.
   *
   * The recursive CTE starts with the first city in sort order and, on each step, looks up the
   * next city strictly greater than the previous one, picking a single asset each time. Only
   * visible, non-archived, non-deleted images of the given owners are considered.
   */
  @GenerateSql({ params: [[DummyValue.UUID]] })
  getAssetsByCity(userIds: string[]) {
    return this.db
      .withRecursive('cte', (qb) => {
        // Anchor row: the asset with the first city in sort order.
        const base = qb
          .selectFrom('exif')
          .select(['city', 'assetId'])
          .innerJoin('assets', 'assets.id', 'exif.assetId')
          .where('assets.ownerId', '=', anyUuid(userIds))
          .where('assets.isVisible', '=', true)
          .where('assets.isArchived', '=', false)
          .where('assets.type', '=', AssetType.IMAGE)
          .where('assets.deletedAt', 'is', null)
          .orderBy('city')
          .limit(1);
        // Recursive step: same filters, but only cities after the one found previously.
        const recursive = qb
          .selectFrom('cte')
          .select(['l.city', 'l.assetId'])
          .innerJoinLateral(
            (qb) =>
              qb
                .selectFrom('exif')
                .select(['city', 'assetId'])
                .innerJoin('assets', 'assets.id', 'exif.assetId')
                .where('assets.ownerId', '=', anyUuid(userIds))
                .where('assets.isVisible', '=', true)
                .where('assets.isArchived', '=', false)
                .where('assets.type', '=', AssetType.IMAGE)
                .where('assets.deletedAt', 'is', null)
                .whereRef('exif.city', '>', 'cte.city')
                .orderBy('city')
                .limit(1)
                .as('l'),
            (join) => join.onTrue(),
          );
        return sql<{ city: string; assetId: string }>`(${base} union all ${recursive})`;
      })
      .selectFrom('assets')
      .innerJoin('exif', 'assets.id', 'exif.assetId')
      .innerJoin('cte', 'assets.id', 'cte.assetId')
      .selectAll('assets')
      .select((eb) =>
        eb
          .fn('to_jsonb', [eb.table('exif')])
          .$castTo<Selectable<Exif>>()
          .as('exifInfo'),
      )
      .orderBy('exif.city')
      .execute();
  }

  /** Inserts the embedding for `assetId`, or overwrites the stored embedding if a row already exists. */
  async upsert(assetId: string, embedding: string): Promise<void> {
    await this.db
      .insertInto('smart_search')
      .values({ assetId, embedding })
      .onConflict((oc) => oc.column('assetId').doUpdateSet((eb) => ({ embedding: eb.ref('excluded.embedding') })))
      .execute();
  }

  /**
   * Reads the dimension size of `smart_search.embedding` from the `atttypmod` value in `pg_attribute`.
   *
   * @returns the dimension size, an integer in [1, 65536].
   * @throws Error when the column is not found or the stored value is outside that range.
   */
  async getDimensionSize(): Promise<number> {
    const { rows } = await sql<{ dimsize: number }>`
      select atttypmod as dimsize
      from pg_attribute f
        join pg_class c ON c.oid = f.attrelid
      where c.relkind = 'r'::char
        and f.attnum > 0
        and c.relname = 'smart_search'
        and f.attname = 'embedding'
    `.execute(this.db);
    // The query returns no rows when the table or column is missing; read defensively so that
    // case reaches the descriptive error below instead of failing with a TypeError.
    const dimSize = rows[0]?.dimsize;
    if (dimSize === undefined || !isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
      throw new Error(`Could not retrieve CLIP dimension size`);
    }
    return dimSize;
  }

  /**
   * Clears `smart_search` and changes its `embedding` column to `vector(dimSize)`.
   *
   * Steps: (1) in one transaction, delete all rows and replace `dim_size_constraint` with a check
   * that embeddings have length `dimSize`; (2) in a second transaction, drop `clip_index`, alter
   * the column type, recreate the index for the configured vector extension and drop the
   * constraint again; (3) run `vacuum analyze` on the table.
   *
   * @param dimSize new dimension size; must be an integer in [1, 65536].
   * @throws Error when `dimSize` is outside that range.
   */
  async setDimensionSize(dimSize: number): Promise<void> {
    if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
      throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
    }
    // this is done in two transactions to handle concurrent writes
    // NOTE(review): presumably the check constraint rejects embeddings of the old size that are
    // inserted between the two transactions — confirm against the embedding insert job.
    await this.db.transaction().execute(async (trx) => {
      await sql`delete from ${sql.table('smart_search')}`.execute(trx);
      await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
      await sql`alter table ${sql.table('smart_search')} add constraint dim_size_constraint check (array_length(embedding::real[], 1) = ${sql.lit(dimSize)})`.execute(
        trx,
      );
    });
    const vectorExtension = this.configRepository.getEnv().database.vectorExtension;
    await this.db.transaction().execute(async (trx) => {
      await sql`drop index if exists clip_index`.execute(trx);
      // `dimSize` was validated as an integer above, so embedding it in raw SQL is safe.
      await trx.schema
        .alterTable('smart_search')
        .alterColumn('embedding', (col) => col.setDataType(sql.raw(`vector(${dimSize})`)))
        .execute();
      await sql.raw(vectorIndexQuery({ vectorExtension, table: 'smart_search', indexName: 'clip_index' })).execute(trx);
      await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
    });
    await sql`vacuum analyze ${sql.table('smart_search')}`.execute(this.db);
  }

  /** Removes every row from `smart_search` using `truncate`. */
  async deleteAllSearchEmbeddings(): Promise<void> {
    await sql`truncate ${sql.table('smart_search')}`.execute(this.db);
  }

  /** Returns the distinct non-null `country` values found by `getExifField` for the given owners. */
  async getCountries(userIds: string[]): Promise<string[]> {
    const res = await this.getExifField('country', userIds).execute();
    return res.map((row) => row.country!);
  }

  /** Returns the distinct non-null `state` values, limited to `country` when it is non-empty. */
  @GenerateSql({ params: [[DummyValue.UUID], DummyValue.STRING] })
  async getStates(userIds: string[], { country }: GetStatesOptions): Promise<string[]> {
    const res = await this.getExifField('state', userIds)
      .$if(!!country, (qb) => qb.where('country', '=', country!))
      .execute();
    return res.map((row) => row.state!);
  }

  /** Returns the distinct non-null `city` values, limited to `country` and/or `state` when non-empty. */
  @GenerateSql({ params: [[DummyValue.UUID], DummyValue.STRING, DummyValue.STRING] })
  async getCities(userIds: string[], { country, state }: GetCitiesOptions): Promise<string[]> {
    const res = await this.getExifField('city', userIds)
      .$if(!!country, (qb) => qb.where('country', '=', country!))
      .$if(!!state, (qb) => qb.where('state', '=', state!))
      .execute();
    return res.map((row) => row.city!);
  }

  /** Returns the distinct non-null `make` values, limited to `model` when it is non-empty. */
  @GenerateSql({ params: [[DummyValue.UUID], DummyValue.STRING] })
  async getCameraMakes(userIds: string[], { model }: GetCameraMakesOptions): Promise<string[]> {
    const res = await this.getExifField('make', userIds)
      .$if(!!model, (qb) => qb.where('model', '=', model!))
      .execute();
    return res.map((row) => row.make!);
  }

  /** Returns the distinct non-null `model` values, limited to `make` when it is non-empty. */
  @GenerateSql({ params: [[DummyValue.UUID], DummyValue.STRING] })
  async getCameraModels(userIds: string[], { make }: GetCameraModelsOptions): Promise<string[]> {
    const res = await this.getExifField('model', userIds)
      .$if(!!make, (qb) => qb.where('make', '=', make!))
      .execute();
    return res.map((row) => row.model!);
  }

  /**
   * Builds (without executing) a query selecting the distinct non-null values of one exif
   * column across visible, non-deleted assets of the given owners.
   */
  private getExifField<K extends 'city' | 'state' | 'country' | 'make' | 'model'>(field: K, userIds: string[]) {
    return this.db
      .selectFrom('exif')
      .select(field)
      .distinctOn(field)
      .innerJoin('assets', 'assets.id', 'exif.assetId')
      .where('ownerId', '=', anyUuid(userIds))
      .where('isVisible', '=', true)
      .where('deletedAt', 'is', null)
      .where(field, 'is not', null);
  }
}