set probes
This commit is contained in:
@@ -25,7 +25,7 @@ import { vectorIndexQuery } from 'src/utils/database';
|
||||
import { isValidInteger } from 'src/validation';
|
||||
import { DataSource, QueryRunner } from 'typeorm';
|
||||
|
||||
let cachedVectorExtension: VectorExtension | undefined;
|
||||
export let cachedVectorExtension: VectorExtension | undefined;
|
||||
export async function getVectorExtension(runner: Kysely<DB> | QueryRunner): Promise<VectorExtension> {
|
||||
if (cachedVectorExtension) {
|
||||
return cachedVectorExtension;
|
||||
@@ -50,6 +50,11 @@ export async function getVectorExtension(runner: Kysely<DB> | QueryRunner): Prom
|
||||
return cachedVectorExtension;
|
||||
}
|
||||
|
||||
export const probes: Record<VectorIndex, number> = {
|
||||
[VectorIndex.CLIP]: 1,
|
||||
[VectorIndex.FACE]: 1,
|
||||
};
|
||||
|
||||
@Injectable()
|
||||
export class DatabaseRepository {
|
||||
private readonly asyncLock = new AsyncLock();
|
||||
@@ -183,21 +188,17 @@ export class DatabaseRepository {
|
||||
for (const indexName of names) {
|
||||
const row = rows.find((index) => index.indexname === indexName);
|
||||
const table = VECTOR_INDEX_TABLES[indexName];
|
||||
if (!row) {
|
||||
promises.push(this.reindexVectors(indexName));
|
||||
continue;
|
||||
}
|
||||
|
||||
switch (vectorExtension) {
|
||||
case DatabaseExtension.VECTOR:
|
||||
case DatabaseExtension.VECTORS: {
|
||||
if (!row.indexdef.toLowerCase().includes(keyword)) {
|
||||
if (!row?.indexdef.toLowerCase().includes(keyword)) {
|
||||
promises.push(this.reindexVectors(indexName));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DatabaseExtension.VECTORCHORD: {
|
||||
const matches = row.indexdef.match(/(?<=lists = \[)\d+/g);
|
||||
const matches = row?.indexdef.match(/(?<=lists = \[)\d+/g);
|
||||
const lists = matches && matches.length > 0 ? Number(matches[0]) : 1;
|
||||
promises.push(
|
||||
this.db
|
||||
@@ -208,11 +209,14 @@ export class DatabaseRepository {
|
||||
const targetLists = this.targetListCount(count);
|
||||
this.logger.log(`targetLists=${targetLists}, current=${lists} for ${indexName} of ${count} rows`);
|
||||
if (
|
||||
!row.indexdef.toLowerCase().includes(keyword) ||
|
||||
!row?.indexdef.toLowerCase().includes(keyword) ||
|
||||
// slack factor is to avoid frequent reindexing if the count is borderline
|
||||
(lists !== targetLists && lists !== this.targetListCount(count * VECTORCHORD_LIST_SLACK_FACTOR))
|
||||
) {
|
||||
probes[indexName] = this.targetProbeCount(targetLists);
|
||||
return this.reindexVectors(indexName, { lists: targetLists });
|
||||
} else {
|
||||
probes[indexName] = this.targetProbeCount(lists);
|
||||
}
|
||||
}),
|
||||
);
|
||||
@@ -239,7 +243,7 @@ export class DatabaseRepository {
|
||||
);
|
||||
return;
|
||||
}
|
||||
const dimSize = await this.getDimSize(table);
|
||||
const dimSize = await this.getDimensionSize(table);
|
||||
await this.db.transaction().execute(async (tx) => {
|
||||
await sql`DROP INDEX IF EXISTS ${sql.raw(indexName)}`.execute(tx);
|
||||
if (!rows.some((row) => row.columnName === 'embedding')) {
|
||||
@@ -261,7 +265,7 @@ export class DatabaseRepository {
|
||||
await sql`SET search_path TO "$user", public, vectors`.execute(tx);
|
||||
}
|
||||
|
||||
private async getDimSize(table: string, column = 'embedding'): Promise<number> {
|
||||
async getDimensionSize(table: string, column = 'embedding'): Promise<number> {
|
||||
const { rows } = await sql<{ dimsize: number }>`
|
||||
SELECT atttypmod as dimsize
|
||||
FROM pg_attribute f
|
||||
@@ -280,7 +284,41 @@ export class DatabaseRepository {
|
||||
return dimSize;
|
||||
}
|
||||
|
||||
// TODO: set probes in queries
|
||||
async setDimensionSize(dimSize: number): Promise<void> {
|
||||
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
|
||||
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
|
||||
}
|
||||
|
||||
// this is done in two transactions to handle concurrent writes
|
||||
await this.db.transaction().execute(async (trx) => {
|
||||
await sql`delete from ${sql.table('smart_search')}`.execute(trx);
|
||||
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
|
||||
await sql`alter table ${sql.table('smart_search')} add constraint dim_size_constraint check (array_length(embedding::real[], 1) = ${sql.lit(dimSize)})`.execute(
|
||||
trx,
|
||||
);
|
||||
});
|
||||
|
||||
const vectorExtension = await this.getVectorExtension();
|
||||
await this.db.transaction().execute(async (trx) => {
|
||||
await sql`drop index if exists clip_index`.execute(trx);
|
||||
await trx.schema
|
||||
.alterTable('smart_search')
|
||||
.alterColumn('embedding', (col) => col.setDataType(sql.raw(`vector(${dimSize})`)))
|
||||
.execute();
|
||||
await sql
|
||||
.raw(vectorIndexQuery({ vectorExtension, table: 'smart_search', indexName: VectorIndex.CLIP }))
|
||||
.execute(trx);
|
||||
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
|
||||
});
|
||||
probes[VectorIndex.CLIP] = 1;
|
||||
|
||||
await sql`vacuum analyze ${sql.table('smart_search')}`.execute(this.db);
|
||||
}
|
||||
|
||||
async deleteAllSearchEmbeddings(): Promise<void> {
|
||||
await sql`truncate ${sql.table('smart_search')}`.execute(this.db);
|
||||
}
|
||||
|
||||
private targetListCount(count: number) {
|
||||
if (count < 128_000) {
|
||||
return 1;
|
||||
@@ -291,6 +329,10 @@ export class DatabaseRepository {
|
||||
}
|
||||
}
|
||||
|
||||
private targetProbeCount(lists: number) {
|
||||
return Math.ceil(lists / 8);
|
||||
}
|
||||
|
||||
async runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise<void> {
|
||||
const { database } = this.configRepository.getEnv();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user