feat: preload textual model
This commit is contained in:
@@ -122,7 +122,6 @@ export interface SystemConfig {
|
||||
modelName: string;
|
||||
loadTextualModelOnConnection: {
|
||||
enabled: boolean;
|
||||
ttl: number;
|
||||
};
|
||||
};
|
||||
duplicateDetection: {
|
||||
@@ -276,7 +275,6 @@ export const defaults = Object.freeze<SystemConfig>({
|
||||
modelName: 'ViT-B-32__openai',
|
||||
loadTextualModelOnConnection: {
|
||||
enabled: false,
|
||||
ttl: 300,
|
||||
},
|
||||
},
|
||||
duplicateDetection: {
|
||||
|
||||
@@ -14,12 +14,9 @@ export class ModelConfig extends TaskConfig {
|
||||
modelName!: string;
|
||||
}
|
||||
|
||||
export class LoadTextualModelOnConnection extends TaskConfig {
|
||||
@IsNumber()
|
||||
@Min(0)
|
||||
@Type(() => Number)
|
||||
@ApiProperty({ type: 'number', format: 'int64' })
|
||||
ttl!: number;
|
||||
export class LoadTextualModelOnConnection {
|
||||
@ValidateBoolean()
|
||||
enabled!: boolean;
|
||||
}
|
||||
|
||||
export class CLIPConfig extends ModelConfig {
|
||||
|
||||
@@ -24,17 +24,13 @@ export type ModelPayload = { imagePath: string } | { text: string };
|
||||
|
||||
/** Base options for a model request: carries the model name. */
type ModelOptions = { modelName: string };
|
||||
|
||||
/**
 * Model options extended with a `ttl` value.
 *
 * NOTE(review): the unit of `ttl` is not visible here; presumably seconds — confirm
 * against the consumer of this value.
 */
export interface LoadModelOptions extends ModelOptions {
  ttl: number;
}
|
||||
|
||||
/** Model options extended with a numeric `minScore`. */
export type FaceDetectionOptions = ModelOptions & { minScore: number };
|
||||
|
||||
/** Image height and width, intersected into the visual response types below. */
type VisualResponse = { imageHeight: number; imageWidth: number };
|
||||
/** Request shape keyed by the search task and the visual model type, carrying the model options. */
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
|
||||
/** Response shape holding a number array under the search task key, plus the image dimensions. */
export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;
|
||||
|
||||
export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions | LoadModelOptions } };
|
||||
export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } };
|
||||
/** Response shape holding a number array under the search task key. */
export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };
|
||||
|
||||
export type FacialRecognitionRequest = {
|
||||
@@ -50,6 +46,11 @@ export interface Face {
|
||||
score: number;
|
||||
}
|
||||
|
||||
/**
 * Actions that can be passed to `prepareTextModel`.
 *
 * This is a numeric enum (LOAD = 0, UNLOAD = 1); the values are used as keys of
 * a lookup table that maps each action to a request path.
 */
export enum LoadTextModelActions {
  /** Mapped to the `/load` path. */
  LOAD,
  /** Mapped to the `/unload` path. */
  UNLOAD,
}
|
||||
|
||||
/** Response shape holding a list of faces under the facial recognition task key, plus the image dimensions. */
export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
|
||||
/** A list of faces under `faces`, plus the image dimensions; returned by `detectFaces`. */
export type DetectedFaces = { faces: Face[] } & VisualResponse;
|
||||
/** Union of the request shapes: CLIP visual, CLIP textual and facial recognition. */
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
|
||||
@@ -58,5 +59,5 @@ export interface IMachineLearningRepository {
|
||||
encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
|
||||
encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
|
||||
detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
|
||||
loadTextModel(url: string, config: ModelOptions): Promise<void>;
|
||||
prepareTextModel(url: string, config: ModelOptions, action: LoadTextModelActions): Promise<void>;
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import {
|
||||
ServerEventMap,
|
||||
} from 'src/interfaces/event.interface';
|
||||
import { ILoggerRepository } from 'src/interfaces/logger.interface';
|
||||
import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
|
||||
import { IMachineLearningRepository, LoadTextModelActions } from 'src/interfaces/machine-learning.interface';
|
||||
import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface';
|
||||
import { AuthService } from 'src/services/auth.service';
|
||||
import { Instrumentation } from 'src/utils/instrumentation';
|
||||
@@ -79,7 +79,12 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
|
||||
const { machineLearning } = await this.configCore.getConfig({ withCache: true });
|
||||
if (machineLearning.clip.loadTextualModelOnConnection.enabled) {
|
||||
try {
|
||||
this.machineLearningRepository.loadTextModel(machineLearning.url, machineLearning.clip);
|
||||
console.log(this.server);
|
||||
this.machineLearningRepository.prepareTextModel(
|
||||
machineLearning.url,
|
||||
machineLearning.clip,
|
||||
LoadTextModelActions.LOAD,
|
||||
);
|
||||
} catch (error) {
|
||||
this.logger.warn(error);
|
||||
}
|
||||
@@ -100,6 +105,21 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
|
||||
/**
 * Handles a websocket client disconnecting.
 *
 * Logs the disconnect and makes the client leave the room named after its
 * namespace. When the handshake query has `background` set to `'false'`, the
 * `loadTextualModelOnConnection` option is enabled, and the engine reports
 * zero remaining clients, the machine learning repository is asked to unload
 * the textual model.
 *
 * Failures of the unload request are logged as warnings and never rethrown.
 *
 * @param client Socket that disconnected.
 */
async handleDisconnect(client: Socket) {
  this.logger.log(`Websocket Disconnect: ${client.id}`);
  await client.leave(client.nsp.name);

  // A missing key yields `undefined`, which already fails the strict comparison,
  // so no separate `in` check is required.
  if (client.handshake.query.background !== 'false') {
    return;
  }

  const { machineLearning } = await this.configCore.getConfig({ withCache: true });
  const { clip, url } = machineLearning;

  // `this.server` may be unset; `undefined !== 0` keeps the model loaded in that case.
  if (!clip.loadTextualModelOnConnection.enabled || this.server?.engine.clientsCount !== 0) {
    return;
  }

  try {
    // Awaited so that a rejected request reaches the catch below instead of
    // surfacing as an unhandled promise rejection.
    await this.machineLearningRepository.prepareTextModel(url, clip, LoadTextModelActions.UNLOAD);
    this.logger.debug('sent request to unload text model');
  } catch (error) {
    this.logger.warn(error);
  }
}
|
||||
|
||||
on<T extends EmitEvent>(event: T, handler: EmitHandler<T>): void {
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
FaceDetectionOptions,
|
||||
FacialRecognitionResponse,
|
||||
IMachineLearningRepository,
|
||||
LoadTextModelActions,
|
||||
MachineLearningRequest,
|
||||
ModelPayload,
|
||||
ModelTask,
|
||||
@@ -38,11 +39,16 @@ export class MachineLearningRepository implements IMachineLearningRepository {
|
||||
return res;
|
||||
}
|
||||
|
||||
async loadTextModel(url: string, { modelName, loadTextualModelOnConnection: { ttl } }: CLIPConfig) {
|
||||
/** Maps each text-model action to the path handed to `fetchData` for that action. */
private readonly prepareTextModelUrl: Record<LoadTextModelActions, string> = {
  [LoadTextModelActions.LOAD]: '/load',
  [LoadTextModelActions.UNLOAD]: '/unload',
};
|
||||
|
||||
async prepareTextModel(url: string, { modelName }: CLIPConfig, actions: LoadTextModelActions) {
|
||||
try {
|
||||
const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName, ttl } } };
|
||||
const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
|
||||
const formData = await this.getFormData(request);
|
||||
const res = await this.fetchData(url, '/load', formData);
|
||||
const res = await this.fetchData(url, this.prepareTextModelUrl[actions], formData);
|
||||
if (res.status >= 400) {
|
||||
throw new Error(`${errorPrefix} Loadings textual model failed with status ${res.status}: ${res.statusText}`);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user