feat: preload textual model

This commit is contained in:
martabal
2024-09-25 18:22:54 +02:00
parent d34d631dd4
commit 59300d2097
10 changed files with 59 additions and 59 deletions
+9 -4
View File
@@ -28,7 +28,6 @@ from .schemas import (
InferenceEntries,
InferenceEntry,
InferenceResponse,
LoadModelEntry,
MessageResponse,
ModelFormat,
ModelIdentity,
@@ -125,17 +124,16 @@ def get_entries(entries: str = Form()) -> InferenceEntries:
raise HTTPException(422, "Invalid request format.")
def get_entry(entries: str = Form()) -> LoadModelEntry:
def get_entry(entries: str = Form()) -> InferenceEntry:
try:
request: PipelineRequest = orjson.loads(entries)
for task, types in request.items():
for type, entry in types.items():
parsed: LoadModelEntry = {
parsed: InferenceEntry = {
"name": entry["modelName"],
"task": task,
"type": type,
"options": entry.get("options", {}),
"ttl": entry["ttl"] if "ttl" in entry else settings.ttl,
}
return parsed
except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e:
@@ -163,6 +161,13 @@ async def load_model(entry: InferenceEntry = Depends(get_entry)) -> None:
return Response(status_code=200)
@app.post("/unload", response_model=TextResponse)
async def unload_model(entry: InferenceEntry = Depends(get_entry)) -> None:
    """Evict a single model from the in-process model cache.

    The entry (parsed from the request body by ``get_entry``) identifies the
    model by name, type and task. Responds with an empty HTTP 200 whether or
    not the model was actually cached — eviction is idempotent here.
    """
    await model_cache.unload(entry["name"], entry["type"], entry["task"])
    return Response(status_code=200)
@app.post("/predict", dependencies=[Depends(update_state)])
async def predict(
entries: InferenceEntries = Depends(get_entries),
+7
View File
@@ -58,3 +58,10 @@ class ModelCache:
async def revalidate(self, key: str, ttl: int | None) -> None:
if ttl is not None and key in self.cache._handlers:
await self.cache.expire(key, ttl)
async def unload(self, model_name: str, model_type: ModelType, model_task: ModelTask) -> None:
    """Remove a cached model entry, if one exists.

    The cache key is the concatenation of name, type and task (no
    separator). The check-and-delete runs under an optimistic lock so a
    concurrent write to the same key is not clobbered blindly.
    """
    cache_key = f"{model_name}{model_type}{model_task}"
    async with OptimisticLock(self.cache, cache_key):
        if await self.cache.get(cache_key) is not None:
            await self.cache.delete(cache_key)
-11
View File
@@ -109,17 +109,6 @@ class InferenceEntry(TypedDict):
options: dict[str, Any]
class LoadModelEntry(InferenceEntry):
    """Inference entry extended with a cache time-to-live.

    TypedDict subclasses may only declare annotated keys: type checkers
    reject methods on TypedDicts, and at runtime calling a TypedDict class
    just builds a plain dict, so a custom ``__init__`` (and any ``self``
    attribute assignment or validation in it) would never execute. The
    previous ``__init__``-based ttl validation was therefore dead code;
    validate ``ttl > 0`` where the entry is parsed instead.
    """

    # Cache TTL in seconds; presumably falls back to settings.ttl when the
    # request omits it — confirm against the request parser.
    ttl: int
InferenceEntries = tuple[list[InferenceEntry], list[InferenceEntry]]