set probes

This commit is contained in:
mertalev
2025-05-09 15:01:11 -04:00
parent c80b16d24e
commit b750440f90
8 changed files with 180 additions and 169 deletions
+53 -11
View File
@@ -25,7 +25,7 @@ import { vectorIndexQuery } from 'src/utils/database';
import { isValidInteger } from 'src/validation';
import { DataSource, QueryRunner } from 'typeorm';
let cachedVectorExtension: VectorExtension | undefined;
export let cachedVectorExtension: VectorExtension | undefined;
export async function getVectorExtension(runner: Kysely<DB> | QueryRunner): Promise<VectorExtension> {
if (cachedVectorExtension) {
return cachedVectorExtension;
@@ -50,6 +50,11 @@ export async function getVectorExtension(runner: Kysely<DB> | QueryRunner): Prom
return cachedVectorExtension;
}
/**
 * Probe count per vector index, held as mutable module-level state.
 *
 * Every index starts at 1. Code in this file overwrites an entry after it
 * inspects or rebuilds the corresponding index, and resets the CLIP entry to 1
 * when the embedding dimension size is changed.
 *
 * NOTE(review): presumably read by the search queries when they run — confirm
 * against the importing modules.
 */
export const probes: Record<VectorIndex, number> = {
[VectorIndex.CLIP]: 1,
[VectorIndex.FACE]: 1,
};
@Injectable()
export class DatabaseRepository {
private readonly asyncLock = new AsyncLock();
@@ -183,21 +188,17 @@ export class DatabaseRepository {
for (const indexName of names) {
const row = rows.find((index) => index.indexname === indexName);
const table = VECTOR_INDEX_TABLES[indexName];
if (!row) {
promises.push(this.reindexVectors(indexName));
continue;
}
switch (vectorExtension) {
case DatabaseExtension.VECTOR:
case DatabaseExtension.VECTORS: {
if (!row.indexdef.toLowerCase().includes(keyword)) {
if (!row?.indexdef.toLowerCase().includes(keyword)) {
promises.push(this.reindexVectors(indexName));
}
break;
}
case DatabaseExtension.VECTORCHORD: {
const matches = row.indexdef.match(/(?<=lists = \[)\d+/g);
const matches = row?.indexdef.match(/(?<=lists = \[)\d+/g);
const lists = matches && matches.length > 0 ? Number(matches[0]) : 1;
promises.push(
this.db
@@ -208,11 +209,14 @@ export class DatabaseRepository {
const targetLists = this.targetListCount(count);
this.logger.log(`targetLists=${targetLists}, current=${lists} for ${indexName} of ${count} rows`);
if (
!row.indexdef.toLowerCase().includes(keyword) ||
!row?.indexdef.toLowerCase().includes(keyword) ||
// slack factor is to avoid frequent reindexing if the count is borderline
(lists !== targetLists && lists !== this.targetListCount(count * VECTORCHORD_LIST_SLACK_FACTOR))
) {
probes[indexName] = this.targetProbeCount(targetLists);
return this.reindexVectors(indexName, { lists: targetLists });
} else {
probes[indexName] = this.targetProbeCount(lists);
}
}),
);
@@ -239,7 +243,7 @@ export class DatabaseRepository {
);
return;
}
const dimSize = await this.getDimSize(table);
const dimSize = await this.getDimensionSize(table);
await this.db.transaction().execute(async (tx) => {
await sql`DROP INDEX IF EXISTS ${sql.raw(indexName)}`.execute(tx);
if (!rows.some((row) => row.columnName === 'embedding')) {
@@ -261,7 +265,7 @@ export class DatabaseRepository {
await sql`SET search_path TO "$user", public, vectors`.execute(tx);
}
private async getDimSize(table: string, column = 'embedding'): Promise<number> {
async getDimensionSize(table: string, column = 'embedding'): Promise<number> {
const { rows } = await sql<{ dimsize: number }>`
SELECT atttypmod as dimsize
FROM pg_attribute f
@@ -280,7 +284,41 @@ export class DatabaseRepository {
return dimSize;
}
// TODO: set probes in queries
/**
 * Switches the `smart_search.embedding` column to a new vector dimension and
 * recreates the CLIP index for it.
 *
 * Every existing row in `smart_search` is deleted as part of the change. Once
 * the index is rebuilt, the CLIP probe count is reset to 1 and the table is
 * vacuumed and analyzed.
 *
 * @param dimSize new embedding dimension; must pass `isValidInteger` with
 *   bounds 1 to 2 ** 16
 * @throws Error when `dimSize` fails that validation
 */
async setDimensionSize(dimSize: number): Promise<void> {
  const isAcceptable = isValidInteger(dimSize, { min: 1, max: 2 ** 16 });
  if (!isAcceptable) {
    throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
  }

  const table = 'smart_search';
  const constraint = 'dim_size_constraint';

  // The work is split across two transactions to handle concurrent writes.
  // Phase one: empty the table and install a check constraint on the new length.
  await this.db.transaction().execute(async (tx) => {
    await sql`delete from ${sql.table(table)}`.execute(tx);
    await tx.schema.alterTable(table).dropConstraint(constraint).ifExists().execute();
    const addConstraint = sql`alter table ${sql.table(table)} add constraint dim_size_constraint check (array_length(embedding::real[], 1) = ${sql.lit(dimSize)})`;
    await addConstraint.execute(tx);
  });

  const vectorExtension = await this.getVectorExtension();

  // Phase two: retype the column, rebuild the index, then drop the check constraint again.
  await this.db.transaction().execute(async (tx) => {
    await sql`drop index if exists clip_index`.execute(tx);

    const retypeColumn = tx.schema
      .alterTable(table)
      .alterColumn('embedding', (column) => column.setDataType(sql.raw(`vector(${dimSize})`)));
    await retypeColumn.execute();

    const createIndex = vectorIndexQuery({ vectorExtension, table, indexName: VectorIndex.CLIP });
    await sql.raw(createIndex).execute(tx);

    await tx.schema.alterTable(table).dropConstraint(constraint).ifExists().execute();
  });

  probes[VectorIndex.CLIP] = 1;

  await sql`vacuum analyze ${sql.table(table)}`.execute(this.db);
}
/** Removes every row from the `smart_search` table by issuing a `truncate`. */
async deleteAllSearchEmbeddings(): Promise<void> {
  const smartSearch = sql.table('smart_search');
  const statement = sql`truncate ${smartSearch}`;
  await statement.execute(this.db);
}
private targetListCount(count: number) {
if (count < 128_000) {
return 1;
@@ -291,6 +329,10 @@ export class DatabaseRepository {
}
}
/**
 * Derives the probe count for an index from its list count.
 *
 * @param lists number of lists the index is built with
 * @returns one eighth of `lists`, rounded up
 */
private targetProbeCount(lists: number) {
  const listsPerProbe = 8;
  const ratio = lists / listsPerProbe;
  return Math.ceil(ratio);
}
async runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise<void> {
const { database } = this.configRepository.getEnv();